自学内容网 自学内容网

练习2-线性回归迭代(李沐函数简要解析)

环境:再练习1中
视频链接:https://www.bilibili.com/video/BV1PX4y1g7KC/?spm_id_from=333.999.0.0

代码与详解

数据库
numpy 数据处理处理
torch.utils 数据加载与数据
d2l 专门的库
nn 包含各种层与激活函数

import numpy as np
import torch 
from torch.utils import data
from d2l import torch as d2l
from torch import nn

生成数据集
w=torch.tensor([2,-3.4]) 生成一维两个向量的张量
features,labels=d2l.synthetic_data(w,b,nume) 生成nume个w为权重,b为偏置的数据

w=torch.tensor([2,-3.4])
b=4.2
features,labels=d2l.synthetic_data(w,b,100)

定义对数据集的读取
data.TensorDataset(*data_arrays) 将多个张量合并为一个 通常用于合并特征值与标签 data_arrays=(features,labels)
data.DataLoader(dataset,batchsize,shuffle=true)每次根据上一个函数返回的对象读取batchsize个值 并打乱数据

def load_arrays(data_arrays,batch_size,is_train=True):
    dataset=data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset,batch_size,is_train)

定义数据加载器 并 调用
next(iter(已初始化的数据加载器)) 重新调用数据加载器

batch_size=10
data_iter=load_arrays((features,labels),batch_size)
next(iter(data_iter))

定义模型

定义为线性模型且只有一层
nn.Sequential() 用于包装层
nn.linear(2,1) 用于定义两输入一输出的线性层

net=nn.Sequential(nn.Linear(2,1))

初始化参数 w,b,lr,epoch,batch_size
net[0].weight.data.normal 正态分布
net[0].bias.data.fill_(0) b赋值

net[0].weight.data.normal_(0,0.01)
net[0].bias.data.fill_(0)

定义损失函数 平方误差
nn.MSELoss()

Loss=nn.MSELoss()

优化算法 小批量梯度下降 torch.optim.SGD(net.parameters(), lr=0.03)

trainer=torch.optim.SGD(net.parameters(),lr=0.03)

训练

epochs=3
for epoch in range(epochs):
    for X,y in data_iter:
        l=Loss(net(X),y)
         # 将梯度清零   
        trainer.zero_grad()
        # 反向传播
        l.backward()
        #更新参数
        trainer.step()
    l=Loss(net(features),labels)
    print(f'epoch {epoch + 1}, loss {l:f}')

相关函数与组成部分

定义模型

定义线性回归模型
from torch import nn
net=nn.Sequential(nn.Linear(2,1))

为模型赋值

w,b正态分布
net[0].weight.data.normal_(0,0.01)
net[0].bias.data.fill_(0)

定义损失函数

Loss=nn.MSELoss()

定义优化算法

trainer = torch.optim.SGD(net.parameters(),lr=0.03)

(训练与反向传播不太了解)

相关的Python语法

def 函数名(变量=True):
return 

for epoch in range(epochs):

原文地址:https://blog.csdn.net/XXxia1XX/article/details/136411930

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!