自学内容网 自学内容网

samout游跨越一次

在这里插入图片描述

import torch
import numpy as np


class MaxState(torch.nn.Module):
    def __init__(self, hidden_dim, heads, win):
        super(MaxState, self).__init__()

        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

        self.head_size = hidden_dim // heads
        self.head0 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        

        self.head_num = heads

        self.hidden = hidden_dim

    def forward(self, input_data, state=None):
        # self.head.to(device)
        b, s, k, h = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size

        out = self.head0(input_data)
        # 0版
        # out1 = torch.max(torch.concat([1-torch.exp(self.head1(input_data).unsqueeze(-1)),1-torch.exp(out.unsqueeze(-1))], -1), -1)[0]
        # 1版
        # out1 = torch.min(torch.concat(
        #     [1-torch.exp(h ** 0.5-self.head1(input_data).unsqueeze(-1)), 1-torch.exp(h ** 0.5-out.unsqueeze(-1))],
        #     -1), -1)[0]
        # 2版  超过12层
        out1 = torch.min(torch.concat(
            [h ** 0.5 - torch.exp(self.head2(input_data).unsqueeze(-1)), h ** 0.5-torch.exp(h ** 0.5 - out.unsqueeze(-1))],
            -1), -1)[0]

        #
        out = out.reshape([b, s, k, h]).permute([0, 2, 1, 3])
        out1 = out1.reshape([b, s, k, h]).permute([0, 2, 1, 3])
        # out1 = self.head1(input_data).reshape([b, s, k, h]).permute([0, 2, 1, 3])

        out = torch.cummax(out * (torch.exp(out1)+h**0.5), 2)[0]

        out = out.permute([0, 2, 1, 3])
        out = out.reshape([b, s, -1])

        out = torch.min(torch.concat(
            [h **0.5-torch.exp(self.head2(input_data).unsqueeze(-1)), torch.exp(h **0.5-out.unsqueeze(-1))],
            -1), -1)[0]

        # out = torch.min(torch.concat(
        #     [(out-torch.exp(self.head2(input_data))).unsqueeze(-1), torch.exp(h ** 0.5-out.unsqueeze(-1))],
        #     -1), -1)[0]

        return out, state


class KAttention(torch.nn.Module):
    def __init__(self, hidden_dim, heads):
        super(KAttention, self).__init__()

        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

        self.head_size = hidden_dim // heads
        self.q = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.k = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        # self.state = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head_num = heads

    def forward(self, x, state=None):
        b, s, h, d = x.shape[0], x.shape[1], self.head_num, self.head_size
        q = self.q(x).reshape([b, s, h, d]).permute([0, 2, 1, 3])
        k = self.k(x).reshape([b, s, h, d]).permute([0, 2, 1, 3])
        v = self.v(x).reshape([b, s, h, d]).permute([0, 2, 1, 3])
        qk = (q @ k.permute([0, 1, 3, 2])) / d ** 0.5
        mask = torch.triu(torch.ones(s, s).to(device))
        qk = torch.where(mask.T == 1, qk, torch.Tensor([-float('inf')]).to(device))
        qkv = torch.nn.functional.softmax(qk, -1) @ v
        #             v + torch.arange(1, 3 * s, 3).reshape([1, 1, -1, 1]).to(device) / s / 3)
        qkv = qkv.permute([0, 2, 1, 3]).reshape([b, s, -1])
        #
        return qkv, state


class FeedForward(torch.nn.Module):
    def __init__(self, hidden_size):
        super(FeedForward, self).__init__()

        self.ffn1 = torch.nn.Linear(hidden_size, hidden_size * 2)
        self.ffn2 = torch.nn.Linear(hidden_size * 2, hidden_size)
        self.gate = torch.nn.Linear(hidden_size, hidden_size * 2)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x1 = self.ffn1(x)

        x2 = self.relu(self.gate(x))

        x = x1 * x2

        x = self.ffn2(x)
        return x





class DecoderLayer(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(DecoderLayer, self).__init__()
        # self.self_attention = MaskMultiHeadAttention(hidden_size, num_heads)
        self.self_attention = MaxState(hidden_size, num_heads, 8)
        # self.self_attention = KAttention(hidden_size, num_heads)
        self.ffn = FeedForward(hidden_size)
        self.layer_norm = torch.nn.LayerNorm(hidden_size)

    def forward(self, x, state=None, seq_len=None):
        x1, state = self.self_attention(x, state)
        x = self.layer_norm(self.ffn(x1) + x)

        return x, state


class SamOut(torch.nn.Module):
    def __init__(self, voc_size, hidden_size, num_heads, num_layers):
        super(SamOut, self).__init__()
        self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)
        self.pos = torch.nn.Embedding(1024, hidden_size)

        self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])
        self.head = torch.nn.Linear(hidden_size, voc_size, False)
        # self.head_state = torch.nn.Linear(hidden_size, num_layers, False)

        self.down = torch.nn.ModuleList(
            [torch.nn.Linear(2 * hidden_size, hidden_size, False) for _ in range(num_layers)])

    def state_forward(self, state, pos, x):
        if state is None:
            state = [None] * len(self.decoder_layers)
        i = 0
        for ii, decoder_layer in enumerate(self.decoder_layers):
            x = self.down[i](torch.concat([torch.zeros([x.shape[0], 1, 1]).to(device) + pos, x], -1))

            x1, state[i] = decoder_layer(x, state[i])
            x = x1 + x
            i += 1
        return x, state

    def pos_forward(self, x):
        if x.shape[1] >= 1024:
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) // 1024).unsqueeze(0)
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) % 1024).unsqueeze(0) + pos

        else:
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device)).unsqueeze(0)
        return pos

    def forward(self, x0):
        x0, _ = self.one_forward(x0, state=None)

        return x0, _

    def one_forward(self, x, state=None, seq_len=None):
        x = self.em(x)

        pos = self.pos_forward(x)

        x, state = self.state_forward(state, pos, x)

        return self.head(x), state


device = "cuda"
if __name__ == '__main__':
    net = SamOut(235, 256, 16, 4)
    net.to(device)
    net(torch.randint(0, 200, [2, 8 * 13]).to(device))
    #

这段代码定义了一个基于PyTorch的神经网络模型,用于序列到序列的转换任务。以下是代码的主要组成部分和功能概述:

  1. MaxState类:这是一个自定义的注意力机制层,用于处理序列数据。它包含了多个线性层,用于计算注意力权重,并通过累积最大值的方式来更新状态。
  2. KAttention类:这是另一个自定义的注意力机制层,实现了基于键值对的注意力机制。
  3. FeedForward类:这是一个前馈神经网络层,包含两个线性层和一个ReLU激活函数,用于在注意力机制之后处理数据。
  4. DecoderLayer类:这是一个解码器层,包含一个注意力层和一个前馈神经网络层,并使用层归一化。
  5. SamOut类:这是整个模型的主体,包含嵌入层、位置编码、多个解码器层和一个输出层。它还负责处理状态前向传播和位置编码前向传播。
  6. 设备配置:代码最后部分将模型移动到CUDA设备上,以便使用GPU进行加速计算。
  7. 主函数:在主函数中,创建了一个SamOut实例,并将其应用于一个随机整数矩阵,模拟输入数据。
    整体而言,这个模型适用于处理序列数据,如自然语言处理任务中的机器翻译、文本摘要等。通过使用注意力机制和前馈神经网络,模型能够学习输入序列和输出序列之间的复杂关系。

原文地址:https://blog.csdn.net/weixin_32759777/article/details/142702818

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!