自学内容网 自学内容网

差分注意力,负注意力的引入

Differential Transformer差分注意力,负注意力的引入

相关链接

ai-algorithms/README.md at main · Jaykef/ai-algorithms (github.com)

unilm/Diff-Transformer at master · microsoft/unilm (github.com)

介绍

在这里插入图片描述

注意力是非负的,导致在长序列时,有效信息淹没在无关信息的海洋中,因此引入负注意力,着重关注序列中的有效部分。因此一半的注意力头用作负注意力头,注意力权重由这两部分的注意力权重的加权差决定,加权系数可学习。加权系数的初始化值和层数有关。加权系数是通过四个可学习参数重参数化而来

lambda_q1, lambda_k1, lambda_q2, lambda_k2

参数维度

d i m _ h e a d ∗ n u m _ h e a d ∗ 2 = e m b e d _ d i m dim\_head * num\_head *2 = embed\_dim dim_headnum_head2=embed_dim

名称定义举例
dim_headembed // num_heads //232//4//2
proj_q(embed_dim, embed_dim)(32, 32)
proj_k(embed_dim,embed_dim)(32, 32)
proj_v(embed_dim, embed_dim)(32, 32)
proj_out(embed_dim, embed_dim)(32, 32)
Q[N, L, C].view(N, L, 2 *num_heads,dim_head)(1024, 256, 2 *4 , 4)
K[N, L, C].view(N, L, 2 *num_heads, dim_head )(1024, 256, 2 *4 , 4)
V[N, L, C].view(N, L, num_heads, 2*dim_head )(1024, 256, 4 , 2 * 4)
attn_weights[N, 2*num_heads, L, L].view(N, 2,num_heads, L, L ) -> [N, num_heads, L, L](1024, 2 , 4 , 256 , 256 ) ->(1024, 4 , 256 , 256 )
attn[N, num_heads, L, 2*dim_heads]->[N, L, C](1024, 4, 256, 8) -> (1024, 256, 32)

初始化函数

def lambda_init_fn(depth):
    return 0.8 - 0.6 * math.exp(-0.3 * depth)

多头差分注意力

class MultiheadDiffAttn(nn.Module):
    def __init__(
        self,
        embed_dim = 32,
        depth = 0,
        num_heads = 8,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        # num_heads set to half of Transformer's #heads
        self.num_heads = num_heads 
        self.head_dim = embed_dim // num_heads // 2
        self.scaling = self.head_dim ** -0.5
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        self.lambda_init = lambda_init_fn(depth)
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))

    
    def forward(
        self,
        x,
    ):
        bsz, tgt_len, embed_dim = x.size()
        src_len = tgt_len

        q = self.q_proj(x) #[bsz, tgt_len, embed_dim ]
        k = self.k_proj(x) #[bsz, tgt_len, embed_dim]
        v = self.v_proj(x) #[bsz, tgt_len, embed_dim]

        q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim) #[bsz, tgt_len, 2 * num_heads, head_dim]  embed_dim = head_dim * num_heads
        k = k.view(bsz, src_len, 2 * self.num_heads, self.head_dim) #[bsz, src_len, 2 * num_heads, head_dim]
        v = v.view(bsz, src_len, self.num_heads, 2 * self.head_dim) #[131072, 2, 8, 8] [bsz, tgt_len, num_heads, 2 * head_dim]

        q = q.transpose(1, 2) #[bsz, 2 * num_heads, tgt_len, head_dim] [131072, 16, 2, 4] 
        q *= self.scaling 

        k = k.transpose(1, 2) #[131072, 16, 2, 4]
        v = v.transpose(1, 2) #[131072, 8, 2, 8]
        attn_weights = torch.matmul(q, k.transpose(-1, -2)) #[131072, 16, 2, 2] [bsz, 2 * num_heads, tgt_len, src_len]

        attn_weights = torch.nan_to_num(attn_weights)
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
            attn_weights
        )

        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init
        #[bsz, 2 * num_heads, tgt_len, src_len] 每一个注意力还是 [bsz, num_heads, tgt_len, src_len]
        attn_weights = attn_weights.view(bsz, self.num_heads, 2, tgt_len, src_len) #[131072, 8, 2, 2, 2] 第一个2是两个差分 
        attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1] # 第一个注意力减去第二个注意力 [131072, 8, 2, 2]
        
        #[bsz, num_heads, tgt_len, src_len]
        attn = torch.matmul(attn_weights, v) # [131072, 8, 2, 8]
        attn = attn * (1 - self.lambda_init)
        attn = attn.transpose(1, 2).reshape(bsz, tgt_len, self.num_heads * 2 * self.head_dim) #[131072, 2, 32]

        attn = self.out_proj(attn)
        return attn

原文地址:https://blog.csdn.net/weixin_45668967/article/details/142865222

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!