
Manually Building Linear Regression (Day 34)
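This post builds linear regression from scratch in PyTorch: synthetic data from sklearn's make_regression, a hand-rolled mini-batch loader, manual parameter initialization, an MSE loss, and a plain SGD update, with the learned weights compared against the true coefficients at the end.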

import torch
import math
from sklearn.datasets import make_regression
import random

def data():
    # generate a synthetic regression problem: 1000 samples, 7 features,
    # a known bias, Gaussian noise, and the true coefficients returned
    bias = 10
    x, y, coef = make_regression(1000, 7, bias=bias, coef=True, noise=2)
    # the data are constants, not trainable parameters, so no requires_grad
    x = torch.tensor(x, dtype=torch.float32)
    y = torch.tensor(y, dtype=torch.float32)
    return x, y, coef, bias

def data_load(x, y, batch_size):
    # yield shuffled mini-batches; the shuffled index list is what actually
    # randomizes the batches, so it must be used when slicing
    data_row = x.shape[0]
    index = list(range(data_row))
    random.shuffle(index)
    part = math.ceil(data_row / batch_size)
    for i in range(part):
        batch_index = index[i * batch_size:(i + 1) * batch_size]
        yield x[batch_index], y[batch_index]
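As a quick illustration (this usage snippet is my addition, not part of the original post), each iteration of the generator yields one shuffled mini-batch:

x, y, coef, bias = data()
for x_batch, y_batch in data_load(x, y, batch_size=15):
    print(x_batch.shape, y_batch.shape)  # torch.Size([15, 7]) torch.Size([15])
    break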

def initialize(n_features):
    # random starting weights, one per feature, plus an arbitrary starting bias;
    # both require gradients since they are the trainable parameters
    w = torch.randn(n_features, requires_grad=True, dtype=torch.float32)
    b = torch.tensor(6.7, requires_grad=True, dtype=torch.float32)
    return w, b


def linear_regression(x, w, b):
    # the model: y_hat = x @ w + b
    return torch.matmul(x, w) + b

def loss_2(y_pred, y_true):
    # mean squared error over the batch
    return torch.mean((y_pred - y_true) ** 2)

def sgd(w, b, dw, db, lr):
    # plain gradient-descent step; loss_2 already averages over the batch,
    # so dw and db are mean gradients and need no further division by batch_size
    w.data -= lr * dw.data
    b.data -= lr * db.data
    return w, b
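The update is the standard rule w ← w − lr · ∂L/∂w (and likewise for b). Because loss_2 averages the squared error, the gradients flowing into sgd are already normalized by the batch size.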

def train():
    x, y, coef, bias = data()
    w, b = initialize(x.shape[1])

    lr = 0.01
    epoch = 120
    batch_size = 15

    for i in range(epoch):
        epoch_loss = 0
        count = 0
        for x_train, y_train in data_load(x, y, batch_size):
            y_pred = linear_regression(x_train, w, b)
            loss = loss_2(y_pred, y_train)
            # .item() detaches the scalar so the running total does not
            # keep every batch's computation graph alive
            epoch_loss += loss.item()
            count += 1
            # clear stale gradients before backprop; PyTorch accumulates them
            if w.grad is not None:
                w.grad.zero_()
            if b.grad is not None:
                b.grad.zero_()
            loss.backward()
            sgd(w, b, w.grad, b.grad, lr)
        print(f'epoch:{i}, epoch_loss:{epoch_loss / count}')
    return w, b, coef, bias
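For comparison, here is a minimal sketch (my addition, not part of the original post) of the same model using PyTorch's built-in nn.Linear, nn.MSELoss, and optim.SGD; it trains full-batch for brevity where the manual version above uses mini-batches:

import torch
from torch import nn, optim
from sklearn.datasets import make_regression

x, y, coef = make_regression(1000, 7, bias=10, coef=True, noise=2)
x = torch.tensor(x, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32).reshape(-1, 1)

model = nn.Linear(7, 1)                              # replaces initialize() and linear_regression()
criterion = nn.MSELoss()                             # replaces loss_2()
optimizer = optim.SGD(model.parameters(), lr=0.01)   # replaces the hand-written sgd()

for epoch in range(120):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()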

if __name__ == '__main__':
    w, b, coef, bias = train()  # capture the trained parameters and ground truth
    print(f"true coefficients: {coef}")
    print(f"learned weights: {w}")
    print(f"true bias: {bias}")
    print(f"learned bias: {b}")

Original article: https://blog.csdn.net/KeKe_L/article/details/144038993
