自学内容网 自学内容网

深度学习:transpose_qkv()与transpose_output()

transpose_qkv 函数的主要作用是将输入的张量重新排列,使其适合多头注意力的计算。具体来说,它将输入张量的形状从 (batch_size, seq_len, num_hiddens) 转换为 (batch_size * num_heads, seq_len, num_hiddens // num_heads)

详细步骤

  • 输入形状
    假设输入的张量形状为 (batch_size, seq_len, num_hiddens),其中:
    batch_size 是批次大小。
    seq_len 是序列长度。
    num_hiddens 是隐藏层的维度。

  • 拆分多头
    多头注意力机制将 num_hiddens 维度拆分成 num_heads 个头,每个头的维度为 num_hiddens // num_heads。

  • 重新排列
    通过重新排列张量的维度,将 (batch_size, seq_len, num_hiddens) 转换为 (batch_size * num_heads, seq_len, num_hiddens // num_heads)。

具体实现

假设 transpose_qkv 函数的实现如下:

def transpose_qkv(X, num_heads):
    # X: (batch_size, seq_len, num_hiddens)
    batch_size, seq_len, num_hiddens = X.shape
    num_hiddens_per_head = num_hiddens // num_heads
    
    # 将 num_hiddens 维度拆分成 num_heads 个头
    X = X.reshape(batch_size, seq_len, num_heads, num_hiddens_per_head)
    
    # 交换维度,使得每个头的数据连续排列
    X = X.permute(0, 2, 1, 3)  # (batch_size, num_heads, seq_len, num_hiddens_per_head)
    
    # 将 batch_size 和 num_heads 合并
    X = X.reshape(batch_size * num_heads, seq_len, num_hiddens_per_head)
    
    return X
  • 解释
    1. 拆分维度:
      X.reshape(batch_size, seq_len, num_heads, num_hiddens_per_head):
      将 num_hiddens 维度拆分成 num_heads 个头,每个头的维度为 num_hiddens_per_head。
      此时,X 的形状为 (batch_size, seq_len, num_heads, num_hiddens_per_head)。
    2. 交换维度:
      X.permute(0, 2, 1, 3):
      将 num_heads 维度移到第二个位置,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size, num_heads, seq_len, num_hiddens_per_head)。
    3. 合并维度:
      X.reshape(batch_size * num_heads, seq_len, num_hiddens_per_head):
      将 batch_size 和 num_heads 合并,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size * num_heads, seq_len, num_hiddens_per_head)。

总结

transpose_qkv 函数通过以下步骤将输入张量重新排列,使其适合多头注意力的计算:

  • 将 num_hiddens 维度拆分成 num_heads 个头。

  • 交换维度,使得每个头的数据连续排列。

  • 合并 batch_size 和 num_heads 维度,使得每个头的数据连续排列。

最终,transpose_qkv 函数返回形状为 (batch_size * num_heads, seq_len, num_hiddens // num_heads) 的张量,以便进行多头注意力计算。

transpose_output 函数的主要作用是将多头注意力的输出重新排列,使其适合后续的处理。具体来说,它将输入张量的形状从 (batch_size * num_heads, seq_len, num_hiddens // num_heads) 转换为 (batch_size, seq_len, num_hiddens)

具体实现

假设 transpose_output 函数的实现如下:

def transpose_output(X, num_heads):
    # X: (batch_size * num_heads, seq_len, num_hiddens_per_head)
    batch_size_times_num_heads, seq_len, num_hiddens_per_head = X.shape
    batch_size = batch_size_times_num_heads // num_heads
    
    # 将 batch_size 和 num_heads 拆分
    X = X.reshape(batch_size, num_heads, seq_len, num_hiddens_per_head)
    
    # 交换维度,使得每个头的数据连续排列
    X = X.permute(0, 2, 1, 3)  # (batch_size, seq_len, num_heads, num_hiddens_per_head)
    
    # 将 num_heads 和 num_hiddens_per_head 合并
    X = X.reshape(batch_size, seq_len, num_heads * num_hiddens_per_head)
    
    return X
  • 解释
    1. 拆分维度:
      X.reshape(batch_size, num_heads, seq_len, num_hiddens_per_head):
      将 batch_size * num_heads 维度拆分成 batch_size 和 num_heads。
      此时,X 的形状为 (batch_size, num_heads, seq_len, num_hiddens_per_head)。
    2. 交换维度:
      X.permute(0, 2, 1, 3):
      将 seq_len 维度移到第二个位置,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size, seq_len, num_heads, num_hiddens_per_head)。
    3. 合并维度:
      X.reshape(batch_size, seq_len, num_heads * num_hiddens_per_head):
      将 num_heads 和 num_hiddens_per_head 合并,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size, seq_len, num_hiddens)。

总结

transpose_output 函数通过以下步骤将多头注意力的输出重新排列,使其适合后续的处理:

  • 将 batch_size * num_heads 维度拆分成 batch_size 和 num_heads。

  • 交换维度,使得每个头的数据连续排列。

  • 合并 num_heads 和 num_hiddens_per_head 维度,使得每个头的数据连续排列。

最终,transpose_output 函数返回形状为 (batch_size, seq_len, num_hiddens) 的张量,以便进行后续的处理。


原文地址:https://blog.csdn.net/m0_49786943/article/details/143791237

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!