自学内容网 自学内容网

Position Embedding总结和Pytorch实现

出现背景

自注意力机制处理数据,并不是采用类似RNN或者LSTM那种递归的结构,这使得模型虽然能够同时查看输入序列中的所有元素(即并行运算),但是也导致了没办法获取当前word在序列种的位置信息,使模型对顺序信息捕捉很差。

PE

位置编码公式

在这里插入图片描述

思路

采用sin和cos函数对word的每一维上进行唯一编码,这样每个word都得到了自己的位置编码信息,并且由于sin和cos都是连续函数,所以针对pos相近的word,他们的位置编码信息也是比较相近的,这样序列的顺序信息就能够获取到了。

TODO遗留问题:sin和cos都是周期函数,会存在位置编码信息重叠吗?

code

import torch
import torch.nn as nn
import math


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        """
        初始化位置编码模块。
        :param d_model: 嵌入的维度
        :param max_len: 最大序列长度
        """
        super(PositionalEncoding, self).__init__()
        # 创建一个足够长的位置编码矩阵 [max_len, d_model]
        pe = torch.zeros(max_len, d_model)
        # 0 到 maxLen - 1 的 张量
        # unsqueeze(1): [maxLen,] => [maxLen, 1],即[0 到 maxLen] => [[0 到 maxLen]]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 除数张量, [,maxLen]
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # 奇数和偶数下标的分别处理
        pe[:, 0::2] = torch.sin(position * div_term) # 因为position是[maxLen, 1],所以会有广播机制
        pe[:, 1::2] = torch.cos(position * div_term)

        # 增加一个维度,将位置编码设置为不可训练
        pe = pe.unsqueeze(0).detach()

        # 注册缓冲区,这样pe不会在训练过程中被视为模型的可训练参数
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        将位置编码添加到输入嵌入中。
        :param x: 输入嵌入,形状为 (Batch size, Sequence length, d_model)
        """
        # x的形状是 [Batch size, Sequence length, d_model]
        # 从缓冲区中取出相应长度的pe,并添加到x上
        x = x + self.pe[:, :x.size(1)]
        return x

原文地址:https://blog.csdn.net/qq_51976556/article/details/142956847

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!