Transformer之位置编码的通俗理解
为什么需要位置编码
在之前介绍的:
Transformer之Token的通俗理解
Transformer之Attention的通俗理解
两篇文章中,我们介绍了Token被作为一个整体送入Attention中进行计算,这样才能得到各个Token之间的关联。
在NLP中,词语的顺序至关重要,比如说"爱做"和"做爱",相同的词语所表达的意思却天差地别,所以编码器会把带有顺序信息的向量一同送入Attention中;在CV中,图像被nn.Conv2d切成一个个小块,然后把小块变成
[
B
,
1
,
1
,
C
]
[B, 1, 1, C]
[B,1,1,C]的点,这些点共同构成送入Attention的patch_embedding,虽然对顺序的要求没有那么高,但是也是有一定要求的。
所以就需要体现顺序的位置编码,融合进要送入Attention的Token之中。
位置编码的本质
位置编码本身是一种偏置,这样就能看出来了
Q
u
e
r
y
=
W
q
×
(
X
T
o
k
e
n
+
P
E
)
=
W
q
×
X
T
o
k
e
n
+
W
q
×
P
E
\begin{array}{ccl} Query &= &W_q \times \Big(X_{Token} + PE \Big) \\ &&\\ & = & W_q \times X_{Token} + W_q \times PE \end{array}
Query==Wq×(XToken+PE)Wq×XToken+Wq×PE
也就是说
X
T
o
k
e
n
X_{Token}
XToken是在做一维平面运动,那么我们就可以这么说,位置编码除了提供序列的时空信息,还具有提供偏置信息的功能。
位置编码是如何起作用的?
这里是重点,建议仔细看看
这里有两句话
“我爱你”和“你爱我”
我(0.5) | 爱(0.6) | 你(0.7) | |
---|---|---|---|
我(0.5) | 0.25 | 0.3 | 0.35 |
爱(0.6) | 0.3 | 0.36 | 0.42 |
你(0.7) | 0.35 | 0.42 | 0.49 |
通过注意力计算,只能计算出“我”、“爱”、“你”之间的关联,例如,如果"我"的嵌入向量是
V
e
c
我
Vec_{我}
Vec我,不管在“我爱你”还是“你爱我”中,其嵌入向量和其他向量所计算的值都是相同的,正如上表所示“爱我” 和 “我爱”的注意力值都是0.3,这就无法区分出语义了;
但是,但是啊,如果加上位置编码,那么:
在“我爱你”中,“我”的嵌入向量就变成了
V
我
+
P
E
1
V_{我}+PE_1
V我+PE1,
在“你爱我”中,“我”的嵌入向量就变成了
V
我
+
P
E
3
V_{我}+PE_3
V我+PE3,
这时,相同的Token所计算到的注意力值就是不同的,因为位置变了,假设
P
E
0
=
0
,
P
E
1
=
1
,
P
E
2
=
2
PE_{0}=0, PE_1=1, PE_2=2
PE0=0,PE1=1,PE2=2,那么上表就会变成下面这张表:
我(0.5+0=0.5) | 爱(0.6+1=1.6) | 你(0.7+2=2.7) | |
---|---|---|---|
我(0.5) | 0.25 | 0.8 | 1.35 |
爱(0.6) | 0.3 | 0.96 | 1.62 |
你(0.7) | 0.35 | 1.12 | 1.89 |
从上表上可以看出,当加上位置编码之后,“爱我”的注意力值是0.8,而“我爱”的注意力值就变成了0.3,这样就很能区分出来是哪种含义了。
Token如何与Position-Embeding融合
通常来说是有两种方法,一种是把Position-Embeding(以后都称之为PE)和Token直接相加,另一种是PE和Token做阿达玛积(对应位置一一相乘),如图所示,其中PE需要具有与Token相同的维度
位置编码有哪些?
1.绝对位置编码
三角式绝对位置编码
P
E
=
{
s
i
n
(
n
1000
0
2
×
i
D
T
o
k
e
n
)
,
d
=
2
i
c
o
s
(
n
1000
0
2
×
i
D
T
o
k
e
n
)
,
d
=
2
i
+
1
PE= \begin{cases} sin\Big(\frac{n}{10000^{2 \times \frac{i}{D_{Token}}}} \Big), & d=2i \\ & \\ cos \Big(\frac{n}{10000^{2 \times \frac{i}{D_{Token}}}} \Big), & d=2i+1 \end{cases}
PE=⎩
⎨
⎧sin(100002×DTokenin),cos(100002×DTokenin),d=2id=2i+1
具体形式如图所示,Token的维度是
[
B
,
N
,
D
i
m
]
[B, N, Dim]
[B,N,Dim],对应的PE也是
[
B
,
N
,
D
i
m
]
[B, N, Dim]
[B,N,Dim]
学习式位置编码
这时最简单的一种位置编码,例如Token的维度是
[
B
,
N
,
D
i
m
]
[B, N, Dim]
[B,N,Dim],那么就在__init__()函数中用nn.Parameter()生成一个维度为
[
B
,
N
,
D
i
m
]
[B, N, Dim]
[B,N,Dim]的初始位置编码,然后在训练中参与更新,最后学习到一组位置编码。
简单如是:
import torch
import timm
# 为什么没有用[B, N, Dim]呢,因为加上Token的时候,PE会因广播机制而复制
# 所以,本质上还是[B, N, Dim]
self.absolute_position_embedding = nn.Parameter(torch.zeros(1, N, Dim))
timm.models.layers.trunc_normal_(self.absolute_position_embedding, std=0.02)
两种绝对编码的对比
- 三角式相对于学习式具有良好的外推性:
三角式的位置编码具有三角函数的周期性,所以当文本或者patch等Token在Inference的长度要比在Train中要长上数倍时,位置编码可以周期性增长,也会有相对良好的效果;
- 三角式位置编码每词向量对应的位置编码( [ 1 , D i m ] [1 , Dim] [1,Dim])之间是正交的,这就意味着他们之间是相互独立的,不会相互干扰;
- 三角式位置编码因为是周期函数,而且频率很高(周期很长),可以容纳相当多的Tokens;
- 三角式位置编码除了关注了相对关系的距离,还有相对关系的角度信息,包含得更丰富;
- 两者都是关注单个位置信息,在输入层之上,简单地和输入向量(Token)相加,区别于相对位置模型,往往是信息对(增加了位置信息的维度);
相对位置编码
A
t
t
=
s
o
f
t
m
a
x
(
Q
×
K
T
D
i
m
+
r
e
l
a
t
i
v
e
_
p
o
s
i
t
i
o
n
_
b
i
a
s
)
×
V
Att = softmax\Big( \frac{Q \times K^T}{\sqrt{Dim}} + relative\_position\_bias\Big) \times V
Att=softmax(DimQ×KT+relative_position_bias)×V
相对位置编码主要是对序列中元素的相对位置关系处理得会更好,但是处理方式也就和绝对位置编码不同了,上面的relative_position_bias就是所谓的相对位置编码
- 绝对位置编码:是一个矩阵,加在Token上的
- 相对位置编码:是一个矩阵,加在注意力得分上的
Q
E
+
P
E
×
K
E
+
P
E
T
=
X
E
+
P
E
×
W
q
×
[
X
E
+
P
E
×
W
k
]
T
=
X
E
+
P
E
×
W
q
×
W
k
T
×
X
E
+
P
E
T
=
(
X
q
+
P
E
q
)
×
W
q
×
W
k
T
×
(
X
k
+
P
E
k
)
T
=
X
q
×
W
q
⏞
Q
u
e
r
y
×
W
k
T
×
X
k
T
⏞
K
e
y
⏟
第一项
+
P
E
q
×
W
q
⏞
a
×
W
k
T
×
X
k
T
⏞
K
e
y
⏟
第二项
+
X
q
×
W
q
⏞
Q
u
e
r
y
×
W
k
T
×
P
E
k
T
⏞
b
⏟
第三项
+
P
E
q
×
W
q
⏞
a
×
W
k
T
×
P
E
k
T
⏞
b
⏟
第四项
\begin{array}{ccl} Q_{E+PE} \times K_{E+PE}^T &= & X_{E + PE} \times W_q \times \Big[X_{E + PE} \times W_k \Big]^T \\ && \\ &= & X_{E + PE} \times W_q \times W_k^T \times X^T_{E + PE} \\ && \\ & = &(X_q+PE_q) \times W_q \times W_k^T \times (X_k+PE_k)^T \\ &&\\ &= &\underbrace{\overbrace{X_q \times W_q}^{Query} \times \overbrace{W_k^T \times X_k^T}^{Key}}_{第一项}+ \underbrace{ \overbrace{PE_q \times W_q}^{a} \times \overbrace{W_k^T \times X_k^T}^{Key}}_{第二项} + \underbrace{\overbrace{X_q \times W_q}^{Query} \times \overbrace{W_k^T \times PE^T_k}^{b}}_{第三项} + \underbrace{\overbrace{PE_q \times W_q}^{a} \times \overbrace{W_k^T \times PE^T_k}^{b}}_{第四项} \end{array}
QE+PE×KE+PET====XE+PE×Wq×[XE+PE×Wk]TXE+PE×Wq×WkT×XE+PET(Xq+PEq)×Wq×WkT×(Xk+PEk)T第一项
Xq×Wq
Query×WkT×XkT
Key+第二项
PEq×Wq
a×WkT×XkT
Key+第三项
Xq×Wq
Query×WkT×PEkT
b+第四项
PEq×Wq
a×WkT×PEkT
b
从这里可以看出,除了第一项是一个二次型,其余三项都是跟位置偏置相关的一次变换,也就是一个数据和数据本身有关,其他的都是跟位置编码有关,我们想办法把两种PE:
P
E
q
PE_q
PEq和
P
E
k
PE_k
PEk通过换元法,换成一个:
P
E
q
×
W
q
×
W
k
T
×
P
E
k
T
=
P
E
q
×
W
×
P
E
k
T
=
P
E
q
×
W
×
[
P
E
q
−
(
P
E
q
−
P
E
k
)
]
T
\begin{array}{ccl} PE_q \times W_q \times W_k^T \times PE^T_k &= &PE_q \times W \times PE^T_k \\ && \\ & = &PE_q \times W \times [PE_q - (PE_{q}-PE_{k} )]^T \\ \end{array}
PEq×Wq×WkT×PEkT==PEq×W×PEkTPEq×W×[PEq−(PEq−PEk)]T
我们可以发现,只需要
Q
u
e
r
y
Query
Query的PE和
Q
u
e
r
y
Query
Query与
K
e
y
Key
Key的PE的相对位置,就可以计算了
那么是不是可以这样:
- 首先计算维度为 [ B , N q , C ] [B, N_q, C] [B,Nq,C]的 Q Q Q和 [ B , N k , C ] [B, N_k, C] [B,Nk,C] 的 K K K
- 然后计算维度为 [ B , N q , N k ] [B, N_q, N_k] [B,Nq,Nk]的 Q × K T Q \times K^T Q×KT
- 最后加上相对位置编码
具体过程如下:
- 生成index矩阵对,分别对应下面两张图
coords_h = torch.arange(h)
coords_w = torch.arange(w)
coords = torch.meshgrid([coord_h, coord_w])
2. 拼接在一起,形成一张图,如下图所示
coords = torch.stack(coords)
3. 拉平,维度变成
[
2
,
W
×
H
]
[2, W\times H]
[2,W×H]
coords_flatten = torch.flatten(coords, 1)
4. 转置相减,利用BroadCast机制去复制维度
# 这里的None就是添加了一个维度
relative_coords = coords_flatten[:, :,None] - coords_flatten[:, None, :]
并得到结果如图:
5. 增加偏置
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += W -1
relative_coords[:, :, 1] += H -1
relative_coords[:, :, 0] *= 2 * H -1
- 求和,把上图括号里面的东西求和
relative_position_index = relative_coords.sun(-1)
以上都是索引,那么位置编码如何生成呢?
self.relative_position_bias_table = nn.Parameter(torch.zeros(2*W-1, 2*H-1, num_heads))
具体就是利用上面生成的Index来索引这里的relative_position_bias_table
relative_position_bias = self.relative_position_bias_table(relative_coords)
原文地址:https://blog.csdn.net/Soonki/article/details/140549929
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!