机器学习周报(RNN的梯度消失和LSTM缓解梯度消失公式推导)
摘要
在深度学习领域,循环神经网络(Recurrent Neural Network, RNN)被广泛应用于处理序列数据,特别是在自然语言处理、时间序列预测等任务中。然而,传统的RNN在长序列数据学习过程中容易出现梯度消失和梯度爆炸问题,使得模型难以捕捉长时间依赖性。梯度消失问题源于RNN的反向传播算法中,多次矩阵相乘导致梯度指数级衰减,从而影响模型性能。为解决这一问题,长短期记忆网络(Long Short-Term Memory, LSTM)应运而生。LSTM通过设计特殊的门结构(输入门、遗忘门和输出门)以及引入细胞状态的传播,有效缓解了梯度消失现象。本文推导了RNN梯度消失的数学公式,并详细说明了LSTM如何利用门结构保持梯度稳定性,从而捕捉长时间依赖。
Abstract
Recurrent Neural Networks (RNNs) are widely used in deep learning for handling sequential data, particularly in tasks such as natural language processing and time series forecasting. However, traditional RNNs often encounter the vanishing and exploding gradient problem when learning from long sequences, which hinders their ability to capture long-term dependencies. The vanishing gradient problem arises in RNNs due to multiple matrix multiplications during backpropagation, causing exponential decay of gradients and impacting model performance. To address this issue, Long Short-Term Memory (LSTM) networks were developed. LSTM alleviates gradient vanishing by introducing specially designed gate structures—input gate, forget gate, and output gate—along with a cell state that propagates through time. This paper derives the mathematical basis for the vanishing gradient in RNNs and explains how LSTM leverages gate structures to maintain gradient stability, enabling the model to capture long-term dependencies effectively.
1 RNN的梯度消失问题
- RNN的缺点
当序列太长时,容易产生梯度消失,参数更新只能捕捉到局部以来关系,没法再捕捉序列之间长期的关联或依赖关系。
如图为RNN连接,输入x,输出o(简单线性输出),权重w,s为生成状态。
根据前向传播可得:
假设使用平方误差作为损失函数,对单个时间点进行求梯度,假设再t=3时刻,损失函数为
L
3
=
1
2
(
Y
3
−
O
3
)
2
L_3=\frac{1}{2}(Y_3-O_3)^2
L3=21(Y3−O3)2,然后根据网络参数Wx,Ws,Wo,b1,b2等求梯度。
- W o W_o Wo求梯度得:
- W x W_x Wx求梯度得:
具体求解过程:首先,所求目标为 L 3 L_3 L3对 W x W_x Wx的偏导,通过链式法则进行展开。对比前向传播公式图可知, O 3 O_3 O3中并不能直接对 W x W_x Wx求偏导,而是包含在 S 3 S_3 S3中,所以要展开成如下形式。
但在 S 3 S_3 S3中又包含 S 2 S_2 S2, S 2 S_2 S2中包含 W x W_x Wx和 S 1 S_1 S1, S 1 S_1 S1中又包含 W x W_x Wx,嵌套了很多层,为了方便表示,我们用 θ 3 \theta_3 θ3来表示 S 3 S_3 S3括号中的内容。进一步简化可得:
由
S
3
S_3
S3演变为
S
2
S_2
S2,同理可递推求出
∂
S
2
∂
W
x
\frac{\partial{S_2}}{\partial{W_x}}
∂Wx∂S2和
∂
S
1
∂
W
x
\frac{\partial{S_1}}{\partial{W_x}}
∂Wx∂S1
梯度的更新同时依赖于x3,x2,x1包括其梯度值。
此为t=3时刻的梯度公式,推广至任意时刻的梯度公式为:
此式括号中的项为求导的连乘,此处求出的导数是介于0-1之间的,有一定的机率导致梯度消失(但非主要原因)。
造成梯度消失和梯度爆炸的主要原因是最后一项:当
W
s
W_s
Ws很小的时候,它的k-1的次方会无限接近于0,而当
W
s
W_s
Ws大于1时,它的k-1次方会很大。
如下为t=20时梯度更新计算的结果:
从式中可以看出,t=3的节点由于连乘过多导致梯度消失(t=3时的信息,
x
3
x_3
x3所乘的
W
s
17
W_s^{17}
Ws17由于
W
s
W_s
Ws介于0,1之间,已经非常接近于0),无法将信息传给t=20,因此t=20的更新无法引入t=3时的信息,认为t=20节点跟t=3的节点无关联。
对于梯度爆炸和梯度消失,可以通过梯度修剪来解决。相对于梯度爆炸,梯度消失更难解决。而LSTM很好的解决了这些问题。
2 LSTM缓解梯度消失
此过程为公式推导(以求 W x f W_{xf} Wxf为例)
故得 ∂ L 1 ∂ W x f \frac{\partial{L_1}}{\partial{W_{xf}}} ∂Wxf∂L1
其中
∂
C
t
∂
C
t
−
1
\frac{\partial{C_t}}{\partial{C_{t-1}}}
∂Ct−1∂Ct
通过调节
W
h
f
W_{hf}
Whf,
W
h
i
W_{hi}
Whi,
W
h
g
W_{hg}
Whg的值,可以灵活控制
C
t
C_t
Ct对
C
t
−
1
C_{t-1}
Ct−1的偏导值,当要从n时刻长期记忆某个东西到m时刻时,该路径上的
∏
t
=
n
m
∂
C
t
∂
C
t
−
1
\quad \prod_{t=n}^m\frac{\partial{C_t}}{\partial{C_{t-1}}}
∏t=nm∂Ct−1∂Ct
≈
\approx
≈ 1×1×1…×1=1从而大大缓解了梯度消失。
总结
传统RNN在处理长序列数据时,由于重复矩阵相乘使梯度呈指数级衰减,导致梯度消失问题。为此,RNN模型难以学习序列中远距离位置的依赖信息。通过对RNN的梯度推导可以看出,当模型深度较大时,梯度逐渐趋向于零,最终导致模型无法学习有效特征。LSTM网络通过引入细胞状态和多个门控机制来缓解这一问题。细胞状态在序列传递中起到信息通路的作用,门控机制则控制信息的增删过程,使得梯度的传递得以有效保留。通过这样的设计,LSTM能够在长序列任务中稳定地传递梯度,从而有效捕捉长时间依赖关系。
原文地址:https://blog.csdn.net/weixin_51923997/article/details/143452721
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!