Transformer中的自注意力是怎么实现的?
在Transformer模型中,自注意力(Self-Attention)是核心组件,用于捕捉输入序列中不同位置之间的关系。自注意力机制通过计算每个标记与其他所有标记之间的注意力权重,然后根据这些权重对输入序列进行加权求和,从而生成新的表示。下面是实现自注意力机制的代码及其详细说明。
自注意力机制的实现
1. 计算注意力得分(Scaled Dot-Product Attention)
自注意力机制的基本步骤包括以下几个部分:
- 线性变换:将输入序列通过三个不同的线性变换层,得到查询(Query)、键(Key)和值(Value)矩阵。
- 计算注意力得分:通过点积计算查询与键的相似度,再除以一个缩放因子(通常是键的维度的平方根),以稳定梯度。
- 应用掩码:在计算注意力得分后,应用掩码(如果有),避免未来信息泄露(用于解码器中的自注意力)。
- 计算注意力权重:通过softmax函数将注意力得分转换为概率分布。
- 加权求和:使用注意力权重对值进行加权求和,得到新的表示。
2. 多头注意力机制(Multi-Head Attention)
为了捕捉不同子空间的特征,Transformer使用多头注意力机制。通过将查询、键和值分割成多个头,每个头独立地计算注意力,然后将所有头的输出连接起来,并通过一个线性层进行组合。
自注意力机制代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
# Scaled Dot-Product Attention
def scaled_dot_product_attention(query, key, value, mask=None):
d_k = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
print(f"Scores shape: {scores.shape}") # (batch_size, num_heads, seq_length, seq_length)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
print(f"Attention weights shape: {attention_weights.shape}") # (batch_size, num_heads, seq_length, seq_length)
output = torch.matmul(attention_weights, value)
print(f"Output shape after attention: {output.shape}") # (batch_size, num_heads, seq_length, d_k)
return output, attention_weights
# Multi-Head Attention
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.linear_query = nn.Linear(d_model, d_model)
self.linear_key = nn.Linear(d_model, d_model)
self.linear_value = nn.Linear(d_model, d_model)
self.linear_out = nn.Linear(d_model, d_model)
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# Linear projections
query = self.linear_query(query)
key = self.linear_key(key)
value = self.linear_value(value)
print(f"Query shape after linear: {query.shape}") # (batch_size, seq_length, d_model)
print(f"Key shape after linear: {key.shape}") # (batch_size, seq_length, d_model)
print(f"Value shape after linear: {value.shape}") # (batch_size, seq_length, d_model)
# Split into num_heads
query = query.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
key = key.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
value = value.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
print(f"Query shape after split: {query.shape}") # (batch_size, num_heads, seq_length, d_k)
print(f"Key shape after split: {key.shape}") # (batch_size, num_heads, seq_length, d_k)
print(f"Value shape after split: {value.shape}") # (batch_size, num_heads, seq_length, d_k)
# Apply scaled dot-product attention
x, attention_weights = scaled_dot_product_attention(query, key, value, mask)
# Concatenate heads
x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
print(f"Output shape after concatenation: {x.shape}") # (batch_size, seq_length, d_model)
# Final linear layer
x = self.linear_out(x)
print(f"Output shape after final linear: {x.shape}") # (batch_size, seq_length, d_model)
return x, attention_weights
# 示例用法
d_model = 512
num_heads = 8
batch_size = 64
seq_length = 10
# 假设输入是随机生成的张量
query = torch.rand(batch_size, seq_length, d_model)
key = torch.rand(batch_size, seq_length, d_model)
value = torch.rand(batch_size, seq_length, d_model)
# 创建多头注意力层
mha = MultiHeadAttention(d_model, num_heads)
output, attention_weights = mha(query, key, value)
print("最终输出形状:", output.shape) # 最终输出形状: (batch_size, seq_length, d_model)
print("注意力权重形状:", attention_weights.shape) # 注意力权重形状: (batch_size, num_heads, seq_length, seq_length)
每一步的形状解释
-
Linear Projections:
- Query, Key, Value分别经过线性变换。
- 形状:[batch_size, seq_length, d_model]
-
Split into Heads:
- 将Query, Key, Value分割成多个头。
- 形状:[batch_size, num_heads, seq_length, d_k],其中d_k = d_model // num_heads
-
Scaled Dot-Product Attention:
- 计算注意力得分(Scores)。
- 形状:[batch_size, num_heads, seq_length, seq_length]
- 计算注意力权重(Attention Weights)。
- 形状:[batch_size, num_heads, seq_length, seq_length]
- 使用注意力权重对Value进行加权求和。
- 形状:[batch_size, num_heads, seq_length, d_k]
-
Concatenate Heads:
- 将所有头的输出连接起来。
- 形状:[batch_size, seq_length, d_model]
-
Final Linear Layer:
- 通过一个线性层将连接的输出转换为最终的输出。
- 形状:[batch_size, seq_length, d_model]
通过这种方式,我们可以清楚地看到每一步变换后的张量形状,理解自注意力和多头注意力机制的具体实现细节。
代码说明
- scaled_dot_product_attention:实现了缩放点积注意力机制,计算查询和键的点积,应用掩码,计算softmax,然后使用权重对值进行加权求和。
- MultiHeadAttention:实现了多头注意力机制,包括线性变换、分割、缩放点积注意力和最后的线性变换。
多头注意力机制的细节
- 线性变换:将输入序列通过线性层转换为查询、键和值的矩阵。
- 分割头:将查询、键和值的矩阵分割为多个头,每个头的维度是[batch_size, num_heads, seq_length, d_k]。
- 缩放点积注意力:对每个头分别计算缩放点积注意力。
- 连接头:将所有头的输出连接起来,得到[batch_size, seq_length, d_model]的张量。
- 线性变换:通过一个线性层将连接的输出转换为最终的输出。
原文地址:https://blog.csdn.net/SisterRu/article/details/140534709
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!