
Transformer - Attention Mechanism

flyfish

Computation Process

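The code below implements scaled dot-product attention, which computes

\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
\]

where \(d_k\) is the dimension of the key (and query) vectors.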

# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def attention(query, key, value, mask=None, dropout=None):
    # d_k is the size of the last dimension of query, which usually
    # equals the word-embedding dimension.
    d_k = query.size(-1)

    # Scaled dot-product scores: QK^T / sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    print("scores.shape:", scores.shape)  # scores.shape: torch.Size([1, 12, 12])

    # Positions where mask == 0 are filled with a large negative value,
    # so they receive (near-)zero weight after softmax.
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # Normalize the scores into attention weights.
    p_attn = F.softmax(scores, dim=-1)

    if dropout is not None:
        p_attn = dropout(p_attn)

    # Weighted sum of the values, plus the attention weights themselves.
    return torch.matmul(p_attn, value), p_attn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute the sinusoidal position encodings once, up to max_len.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0)  # add a batch dimension
        # A buffer is saved with the module but is not a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x):
        # Add the encodings for the first x.size(1) positions to the input.
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)

# Testing attention requires the positional encoding (PositionalEncoding).
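For reference, the sinusoidal encodings built above are

\[
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right),
\qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right)
\]

so each dimension of the encoding is a sinusoid whose wavelength grows geometrically with the dimension index.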


# The word-embedding dimension is 8
d_model = 8
# The dropout (zeroing) rate is 0.1
dropout = 0.1
# Maximum sentence length
max_len = 12

x = torch.zeros(1, max_len, d_model)
pe = PositionalEncoding(d_model, dropout, max_len)

pe_result = pe(x)

print("pe_result:", pe_result)
query = key = value = pe_result
print("pe_result.shape:",pe_result.shape)

# Output without a mask.
# x is all zeros, so pe_result is just the positional encoding passed through
# dropout: surviving entries are scaled by 1/(1-0.1) ≈ 1.1111 and dropped
# entries are zeroed, which is why the printout below shows 1.1111... and 0s.
# pe_result.shape: torch.Size([1, 12, 8])
attn, p_attn = attention(query, key, value)
print("no mask\n")
print("attn:", attn)
print("p_attn:", p_attn)

# scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# Dividing by math.sqrt(d_k) makes this scaled dot-product attention;
# without the division it would be plain dot-product attention. The scaling
# keeps the score variance near 1 (see the numeric check after the output).
# When Q = K = V, this is called self-attention.

# Output with a mask. An all-zero mask masks every position, so after
# masked_fill all scores are -1e9 and softmax yields uniform weights of
# 1/12 ≈ 0.0833. (A more typical causal mask is sketched after the output.)
print("mask\n")
mask = torch.zeros(1, max_len, max_len)
attn, p_attn = attention(query, key, value, mask=mask)
print("attn:", attn)
print("p_attn:", p_attn)
pe_result: tensor([[[ 0.0000e+00,  1.1111e+00,  0.0000e+00,  1.1111e+00,  0.0000e+00,
           1.1111e+00,  0.0000e+00,  1.1111e+00],
         [ 9.3497e-01,  6.0034e-01,  1.1093e-01,  1.1056e+00,  1.1111e-02,
           1.1111e+00,  1.1111e-03,  1.1111e+00],
         [ 1.0103e+00, -4.6239e-01,  2.2074e-01,  1.0890e+00,  2.2221e-02,
           0.0000e+00,  2.2222e-03,  1.1111e+00],
         [ 1.5680e-01, -1.1000e+00,  0.0000e+00,  1.0615e+00,  3.3328e-02,
           0.0000e+00,  3.3333e-03,  1.1111e+00],
         [-8.4089e-01, -7.2627e-01,  4.3269e-01,  1.0234e+00,  4.4433e-02,
           1.1102e+00,  4.4444e-03,  1.1111e+00],
         [-1.0655e+00,  3.1518e-01,  5.3270e-01,  0.0000e+00,  5.5532e-02,
           1.1097e+00,  5.5555e-03,  1.1111e+00],
         [-3.1046e-01,  1.0669e+00,  6.2738e-01,  9.1704e-01,  0.0000e+00,
           1.1091e+00,  6.6666e-03,  0.0000e+00],
         [ 7.2999e-01,  8.3767e-01,  7.1580e-01,  8.4982e-01,  7.7714e-02,
           1.1084e+00,  7.7777e-03,  1.1111e+00],
         [ 1.0993e+00, -1.6167e-01,  7.9706e-01,  7.7412e-01,  8.8794e-02,
           1.1076e+00,  8.8888e-03,  1.1111e+00],
         [ 4.5791e-01, -0.0000e+00,  8.7036e-01,  6.9068e-01,  9.9865e-02,
           1.1066e+00,  9.9999e-03,  1.1111e+00],
         [-6.0447e-01, -9.3230e-01,  9.3497e-01,  6.0034e-01,  1.1093e-01,
           1.1056e+00,  1.1111e-02,  1.1111e+00],
         [-1.1111e+00,  4.9174e-03,  9.9023e-01,  5.0400e-01,  1.2198e-01,
           1.1044e+00,  1.2222e-02,  1.1110e+00]]])
pe_result.shape: torch.Size([1, 12, 8])
scores.shape: torch.Size([1, 12, 12])
no mask

attn: tensor([[[ 1.0590e-01,  2.7361e-01,  4.9333e-01,  8.3999e-01,  5.0599e-02,
           1.0079e+00,  5.6491e-03,  1.0138e+00],
         [ 2.7554e-01,  2.0916e-01,  4.9203e-01,  8.6593e-01,  5.2177e-02,
           9.7066e-01,  5.6513e-03,  1.0398e+00],
         [ 2.8765e-01, -3.8825e-02,  4.7812e-01,  8.7535e-01,  5.4246e-02,
           8.4157e-01,  5.7015e-03,  1.0659e+00],
         [ 9.3666e-02, -1.8286e-01,  4.8727e-01,  8.5124e-01,  5.7070e-02,
           8.2547e-01,  5.9523e-03,  1.0712e+00],
         [-1.6747e-01, -1.0274e-01,  5.6960e-01,  7.7584e-01,  6.3699e-02,
           9.6958e-01,  6.7169e-03,  1.0546e+00],
         [-2.2646e-01,  6.8462e-02,  5.8668e-01,  7.2227e-01,  6.3119e-02,
           1.0233e+00,  6.8004e-03,  1.0310e+00],
         [ 8.8945e-04,  2.7654e-01,  5.3750e-01,  8.0958e-01,  5.2289e-02,
           1.0259e+00,  6.1360e-03,  9.6094e-01],
         [ 2.2231e-01,  2.2832e-01,  5.2263e-01,  8.4111e-01,  5.4828e-02,
           9.9655e-01,  5.9765e-03,  1.0298e+00],
         [ 2.6388e-01,  7.2239e-02,  5.3800e-01,  8.4070e-01,  5.8958e-02,
           9.5033e-01,  6.2306e-03,  1.0564e+00],
         [ 1.2822e-01,  7.4518e-02,  5.5305e-01,  8.1381e-01,  6.0125e-02,
           9.7442e-01,  6.4089e-03,  1.0462e+00],
         [-1.5757e-01, -1.3194e-01,  5.9562e-01,  7.6069e-01,  6.7079e-02,
           9.7264e-01,  7.0187e-03,  1.0607e+00],
         [-2.3505e-01,  5.6245e-03,  6.0160e-01,  7.3040e-01,  6.5491e-02,
           1.0176e+00,  7.0038e-03,  1.0367e+00]]])
p_attn: tensor([[[0.1488, 0.1215, 0.0514, 0.0396, 0.0698, 0.0703, 0.0875, 0.1205,
          0.0790, 0.0814, 0.0544, 0.0757],
         [0.1170, 0.1434, 0.0757, 0.0489, 0.0590, 0.0460, 0.0642, 0.1304,
          0.1161, 0.0943, 0.0527, 0.0524],
         [0.0716, 0.1094, 0.1341, 0.1067, 0.0716, 0.0379, 0.0407, 0.0930,
          0.1221, 0.0921, 0.0713, 0.0494],
         [0.0597, 0.0765, 0.1155, 0.1397, 0.1127, 0.0506, 0.0359, 0.0627,
          0.0918, 0.0806, 0.1056, 0.0688],
         [0.0692, 0.0607, 0.0509, 0.0740, 0.1475, 0.0846, 0.0509, 0.0607,
          0.0692, 0.0788, 0.1342, 0.1194],
         [0.0887, 0.0601, 0.0343, 0.0423, 0.1076, 0.1341, 0.0721, 0.0748,
          0.0591, 0.0777, 0.1057, 0.1435],
         [0.1232, 0.0938, 0.0411, 0.0335, 0.0722, 0.0804, 0.1351, 0.1103,
          0.0722, 0.0814, 0.0633, 0.0935],
         [0.1124, 0.1263, 0.0623, 0.0388, 0.0571, 0.0553, 0.0731, 0.1388,
          0.1134, 0.1001, 0.0571, 0.0652],
         [0.0758, 0.1157, 0.0841, 0.0584, 0.0670, 0.0450, 0.0492, 0.1166,
          0.1429, 0.1101, 0.0763, 0.0588],
         [0.0822, 0.0989, 0.0668, 0.0540, 0.0803, 0.0622, 0.0584, 0.1084,
          0.1158, 0.1046, 0.0879, 0.0804],
         [0.0548, 0.0551, 0.0515, 0.0705, 0.1364, 0.0845, 0.0454, 0.0617,
          0.0801, 0.0877, 0.1499, 0.1224],
         [0.0763, 0.0548, 0.0357, 0.0459, 0.1213, 0.1146, 0.0669, 0.0703,
          0.0616, 0.0802, 0.1224, 0.1499]]])
mask

scores.shape: torch.Size([1, 12, 12])
attn: tensor([[[0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185]]])
p_attn: tensor([[[0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833]]])
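Since the all-zeros mask blanks every position, each row of p_attn collapses to the uniform 1/12 ≈ 0.0833 seen above. In a decoder one would normally use a causal (lower-triangular) mask instead, so each position attends only to itself and earlier positions. A minimal sketch of that variant, reusing the attention function and tensors from above (torch.tril builds the lower-triangular mask):

# Causal mask: position i may attend only to positions j <= i.
causal_mask = torch.tril(torch.ones(1, max_len, max_len))
attn, p_attn = attention(query, key, value, mask=causal_mask)
print("p_attn[0, 0]:", p_attn[0, 0])  # first row: all weight on position 0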

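And to see why the scores are divided by math.sqrt(d_k): for random vectors with unit-variance components, a dot product over d_k dimensions has variance d_k, so unscaled scores grow with the dimension and push softmax toward a near one-hot distribution with vanishing gradients. A quick illustrative check (the variable names here are just for this demo):

q_demo = torch.randn(10000, 64)
k_demo = torch.randn(10000, 64)
dots = (q_demo * k_demo).sum(-1)
print(dots.var())                     # ≈ 64, i.e. d_k
print((dots / math.sqrt(64)).var())   # ≈ 1 after scaling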
Original article: https://blog.csdn.net/flyfish1986/article/details/137272859
