【pytorch-03】:自动微分模块
文章目录
1 自动微分模块
自动微分(Autograd)模块对张量做了进一步的封装,具有自动求导功能。
自动微分模块是构成神经网络训练的必要模块,在神经网络的反向传播过程中,Autograd 模块基于正向计算的结果对当前的参数进行微分计算,从而实现网络权重参数的更新。
1.1 梯度基本运算
- 采用backward()方法可以进行自动微分
- 采用backward()方法需要f是一个标量,如果不是标量就需要传入一个gradient参数,它是形状匹配的张量(后续在另一篇博客中详述)
1.1.1 标量的梯度计算
import torch
# 1. 标量的梯度计算
# y = x**2 + 20
def test01():
# 1.对于需要求导的张量需要设置 requires_grad = True
# 类型一般设置为torch.float64
x = torch.tensor(10, requires_grad=True, dtype=torch.float64)
# 2.对x的中间计算
# 定义关于x的函数
f = x ** 2 + 20 # 2x
# 3.自动微分
# 调用backward()之后,会根据f进行求导
f.backward()
# 访问梯度
# 求得方程在x处的梯度
print(x.grad)
1.1.2 向量的梯度计算
# 2. 向量的梯度计算
# y = x**2 + 20
def test02():
x = torch.tensor([10, 20, 30, 40], requires_grad=True, dtype=torch.float64)
# 定义变量的计算过程
# y1得到的是一个向量,需要将其处理为标量才能使用backward()进行自动微分
# 采用y1.mean() - 取均值 - 转换为标量
# 对向量进行求导,就相当于对向量中的每个分量都求导
y1 = x ** 2 + 20
# 注意: 自动微分的时候,必须是一个标量
y2 = y1.mean() # 1/4 * y1 ==> 1/4 * 2x
# 自动微分
# 反向传播
y2.backward()
# 打印梯度值
# 梯度计算的结果会保存到x.grad中
print(x.grad) # tensor([ 5., 10., 15., 20.], dtype=torch.float64)
1.1.3 多标量梯度计算
# 3. 多标量梯度计算
# y = x1**2 + x2**2 + x1*2
def test03():
x1 = torch.tensor(10, requires_grad=True, dtype=torch.float64)
x2 = torch.tensor(20, requires_grad=True, dtype=torch.float64)
# 中间计算过程
y = x1**2 + x2**2 + x1*x2
# 自动微分
y.backward()
# 打印梯度值
print(x1.grad)
print(x2.grad)
1.1.4 多向量的梯度计算
# 4. 多向量的梯度计算
def test04():
x1 = torch.tensor([10, 20], requires_grad=True, dtype=torch.float64)
x2 = torch.tensor([30, 40], requires_grad=True, dtype=torch.float64)
# 定义中间计算过程
y = x1**2 + x2**2 + x1*x2
# 将输出结果变为标量
y = y.sum()
# 自动微分
y.backward()
# 打印张量的梯度值
print(x1.grad)
print(x2.grad)
1.2 梯度的控制计算
- 模型训练的时候需要进行梯度计算
- 模型训练完成,进入到另一个阶段之后,不需要进行梯度计算,由此需要控制梯度的计算
1.2.1 控制梯度计算
- 可以通过一定的方式控制是否需要对每个函数进行梯度计算
import torch
# 1. 控制梯度计算
def test01():
# 创建张量x
# requires_grad - 表示x需要进行梯度计算
x = torch.tensor(10, requires_grad=True, dtype=torch.float64)
print(x.requires_grad)
# 1. 第一种方法
# 只想要计算y = x ** 2的数值,而不想计算这个函数的梯度
with torch.no_grad():
y = x**2
print(y.requires_grad) # False,表示不对这个函数进行梯度计算
# 2. 第二种方式,主要针对函数
@torch.no_grad()
def my_func(x):
return x**2
y = my_func(x)
print(y.requires_grad) # False,表示不对这个函数进行梯度计算
# 3. 第三种方式: 全局的方式
torch.set_grad_enabled(False)
y = x ** 2
print(y.requires_grad)
1.2.2 累计梯度和梯度清零
# 2. 累计梯度和梯度清零
def test02():
x = torch.tensor([10, 20, 30, 40], requires_grad=True, dtype=torch.float64)
# 当我们重复对x进行梯度计算的时候,是会将历史的梯度值累加到 x.grad 属性中
# 希望不要去累加历史梯度
for _ in range(10): # 函数在x处计算10次梯度
# 对输入x的计算过程
f1 = x**2 + 20
# 将向量转换为标量
f2 = f1.mean()
# 梯度清零,防止梯度进行累加
if x.grad is not None:
x.grad.data.zero_()
# 自动微分
f2.backward()
print(x.grad)
- 没有设置x.grad.data.zero(),重复对x进行梯度计算的时候,是会将历史的梯度值累加到 x.grad 属性中
1.2.3 案例 - 梯度下降优化函数
y = x ** 2
当x为什么值得时候,y最小
- 初始化x值
- 沿着梯度方向进行迭代,设置迭代次数
# 3. 案例-梯度下降优化函数
def test03():
# 1.初始化
x = torch.tensor(10, requires_grad=True, dtype=torch.float64)
# 进行5000次迭代
for _ in range(5000):
# 正向计算,函数
y = x**2
# 梯度清零
if x.grad is not None:
x.grad.data.zero_()
# 自动微分,得到x处得梯度值
y.backward()
# 更新参数,对x得值进行更新
x.data = x.data - 0.001 * x.grad
# 打印 x 的值
print('%.10f' % x.data)
对于深度学习:
- y通常为损失函数
- 初始化权重参数,求损失函数在指定权重位置处的梯度
- 更新权重参数
1.3 梯度的计算注意事项
当对设置 requires_grad=True 的张量使用 numpy 函数进行转换时, 会出现如下报错:
Can’t call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.
此时, 需要先使用 detach 函数将张量进行分离, 再使用 numpy 函数.
注意: detach 之后会产生一个新的张量, 新的张量作为叶子结点,并且该张量和原来的张量共享数据, 但是分离后的张量不需要计算梯度。
b = a.detach()
a 和 b 共享数据
对a的所有操作都会影响到梯度计算
对b的所有操作都不会影响梯度计算
相当于将一个张量分离出一个数据相同但是用途不同的张量
import torch
# 1. 演示下错误
def test01():
x = torch.tensor([10, 20], requires_grad=True, dtype=torch.float64)
# RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.
# print(x.numpy())
# 下面的是正确的操作
print(x.detach().numpy())
# 2. 共享数据
def test02():
# x是叶子结点
x1 = torch.tensor([10, 20], requires_grad=True, dtype=torch.float64)
# 使用detach 函数分离出一个新的张量
x2 = x1.detach()
print(id(x1.data), id(x2.data))
# 修改分离后产生的新的张量
x2[0] = 100
print(x1)
print(x2)
# 通过结果我们发现,x2 张量不存在 requires_grad=True
# 表示:对 x1 的任何计算都会影响到对 x1 的梯度计算
# 但是,对 x2 的任何计算不会影响到 x1 的梯度计算
print(x1.requires_grad)
print(x2.requires_grad)
原文地址:https://blog.csdn.net/weixin_51385258/article/details/143908012
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!