
samout: a new design

import torch


class MaxState(torch.nn.Module):
    def __init__(self, hidden_dim, heads):
        super(MaxState, self).__init__()

        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

        self.head_size = hidden_dim // heads
        self.head0 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        # self.h_linear=torch.nn.Parameter(torch.empty(1, 1))
        # torch.nn.init.xavier_uniform_(self.h_linear,0.5)
        # self.layer_nor = torch.nn.LayerNorm(hidden_dim)
        # self.norm = torch.nn.LayerNorm(hidden_dim)
        # self.alpha = torch.nn.Parameter(torch.tensor(0.5))

        self.head_num = heads

        self.hidden = hidden_dim
        self.layer_nor = torch.nn.LayerNorm(hidden_dim)

    def forward(self, input_data, state=None):
        # b: batch, s: sequence length, k: number of heads, h: per-head size
        b, s, k, h = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size
        out = self.head0(input_data)
        out1 = self.head1(input_data)
        out2 = self.head2(input_data)

        out = out.reshape([b, s, k, h]).permute([0, 2, 1, 3])
        out1 = out1.reshape([b, s, k, h]).permute([0, 2, 1, 3])
        # out2 = out2.reshape([b, s, k, h]).permute([0, 2, 1, 3])
        # out1 = self.head1(input_data).reshape([b, s, k, h]).permute([0, 2, 1, 3])

        # causal prefix maximum over the sequence dimension (dim 2 after the permute), scaled by sqrt(head_size)
        out = torch.cummax((out + out1) / h ** 0.5, 2)[0]
        # out = torch.cummin((out + out1)/k**0.5 , 2)[0]
        # out_sum = torch.cumsum((out + out1)/k**0.5 , 2)
        # out=(out-out_min)*out

        out = out.permute([0, 2, 1, 3])
        out1 = out1.permute([0, 2, 1, 3])
        # out2 = out2.permute([0, 2, 1, 3])
        out = out.reshape([b, s, -1])
        out1 = out1.reshape([b, s, -1])
        # out2 = out2.reshape([b, s, -1])
        # out = self.layer_nor(out)
        # out = (out + out2) * out+out1
        # out3=torch.cummax(out,1)[0]
        # out = (out + out2) * out + out1
        # sum the cummax branch with the two linear projections, then normalize
        out = self.layer_nor(out + out2 + out1)
        # out = self.alpha * out * (out + out2) + (1 - self.alpha) * out1
        return out



class FeedForward(torch.nn.Module):
    def __init__(self, hidden_size):
        super(FeedForward, self).__init__()

        self.ffn1 = torch.nn.Linear(hidden_size, hidden_size * 2)
        self.ffn2 = torch.nn.Linear(hidden_size * 2, hidden_size)
        self.gate = torch.nn.Linear(hidden_size, hidden_size * 2)
        # self.h_linear=torch.nn.Parameter(torch.empty(1, 1))
        # self.gate  = torch.nn.Parameter(torch.empty(hidden_size,  hidden_size * 2))
        # torch.nn.init.xavier_uniform_(self.gate,0.5)
        self.relu = torch.nn.ReLU()
        self.dr = torch.nn.Dropout(0.1)

    def forward(self, x):
        x1 = self.ffn1(x)             # value branch
        x2 = self.relu(self.gate(x))  # gate branch
        xx = self.dr(x1 * x2)         # gated activation with dropout
        x = self.ffn2(xx)
        return x



class SamOut(torch.nn.Module):
    def __init__(self, voc_size, hidden_size, num_heads, num_layers):
        super(SamOut, self).__init__()
        # note: num_layers is accepted but unused; the six state/decoder blocks below are created explicitly
        self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)
        self.pos = torch.nn.Embedding(1024, hidden_size)
        self.state = MaxState(hidden_size, num_heads)
        self.state1 = MaxState(hidden_size, num_heads)
        self.state2 = MaxState(hidden_size, num_heads)
        self.state3 = MaxState(hidden_size, num_heads)
        self.state4 = MaxState(hidden_size, num_heads)
        self.state5 = MaxState(hidden_size, num_heads)
        self.decoder = FeedForward(hidden_size)
        self.decoder1 = FeedForward(hidden_size)
        self.decoder2 = FeedForward(hidden_size)
        self.decoder3 = FeedForward(hidden_size)
        self.decoder4 = FeedForward(hidden_size)
        self.decoder5 = FeedForward(hidden_size)


        self.head = torch.nn.Linear(hidden_size, voc_size, bias=False)
        self.layer_nor = torch.nn.LayerNorm(hidden_size)

    def pos_forward(self, x):
        # learned positional embedding; sequences of 1024 or more wrap around by adding a
        # "page" embedding (index // 1024) to an offset embedding (index % 1024)
        if x.shape[1] >= 1024:
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) // 1024).unsqueeze(0)
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) % 1024).unsqueeze(0) + pos
        else:
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device)).unsqueeze(0)
        return pos

    def forward(self, x):

        x = self.em(x)
        pos = self.pos_forward(x)
        x = self.state(x + pos) + x
        x1 = self.decoder(x)
        x2 = self.state1(x1) + x1
        x2 = self.decoder1(x2)
        x3 = self.state2(x1) + x1
        x3 = self.decoder2(x3)
        x = self.layer_nor(x1 + x2 + x3)

        x = self.state3(x) + x
        x1 = self.decoder3(x)
        x2 = self.state4(x1) + x1
        x2 = self.decoder4(x2)
        x3 = self.state5(x1) + x1
        x3 = self.decoder5(x3)
        x = self.layer_nor(x1 + x2 + x3)

        # the second return value is an unused placeholder
        return self.head(x), ""





device = "cuda" if torch.cuda.is_available() else "cpu"
if __name__ == '__main__':
    net = SamOut(235, 256, 16, 4)
    net.to(device)
    net(torch.randint(0, 200, [2, 8 * 13]).to(device))
    #
# epoch___0____loss___8.586270____steps___65760:   0%|          | 0/1 [01:21<?, ?it/s]  cummax
# epoch___0____loss___6.930531____steps___67040:   0%|          | 0/1 [01:21<?, ?it/s]  cummax no layer_nor
# epoch___0____loss___7.680687____steps___77840:   0%|          | 0/1 [01:35<?, ?it/s]  cummax layer_nor
# epoch___0____loss___6.994579____steps___68240:   0%|          | 0/1 [01:25<?, ?it/s]  cummax cos
# epoch___0____loss___6.707716____steps___70640:   0%|          | 0/1 [01:24<?, ?it/s]  cummax no sin no cos
# epoch___0____loss___6.895388____steps___65200:   0%|          | 0/1 [01:21<?, ?it/s]  cummin
# epoch___0____loss___7.079460____steps___66720:   0%|          | 0/1 [01:22<?, ?it/s]  cummax no x
# epoch___0____loss___6.174834____steps___45360:   0%|          | 0/10 [01:00<?, ?it/s] cummax 2   2 no pos
# epoch___0____loss___6.239753____steps___45120:   0%|          | 0/10 [01:00<?, ?it/s] cummax 2   2  pos
# epoch___0____loss___6.547979____steps___36240:   0%|          | 0/10 [01:00<?, ?it/s] cummax 3   3  no pos
# epoch___0____loss___6.947957____steps___17600:   0%|          | 0/10 [01:01<?, ?it/s] src samout
# epoch___0____loss___6.108305____steps___52640:   0%|          | 0/10 [02:54<?, ?it/s] src samout
# epoch___0____loss___6.069768____steps___55280:   0%|          | 0/10 [03:03<?, ?it/s] src samout
# epoch___0____loss___6.058203____steps___54560:   0%|          | 0/10 [01:11<?, ?it/s] current samout
# epoch___0____loss___5.996508____steps___52560:   0%|          | 0/10 [01:27<?, ?it/s]
# epoch___0____loss___6.067177____steps___54400:   0%|          | 0/10 [01:30<?, ?it/s]
# epoch___0____loss___5.974577____steps___52720:   0%|          | 0/10 [01:44<?, ?it/s]
# epoch___0____loss___5.869751____steps___55520:   0%|          | 0/10 [01:57<?, ?it/s]
# epoch___0____loss___5.749324____steps___55440:   0%|          | 0/10 [02:03<?, ?it/s]  maxstate  no cat
# epoch___0____loss___5.715099____steps___55440:   0%|          | 0/10 [02:26<?, ?it/s]  cat
# epoch___0____loss___5.704436____steps___55520:   0%|          | 0/10 [02:04<?, ?it/s] x1 +x2+x3
# epoch___0____loss___5.710885____steps___55360:   0%|          | 0/10 [02:04<?, ?it/s] x1 +x2+x3  beats cat and uses fewer parameters
# epoch___0____loss___5.673217____steps___55360:   0%|          | 0/10 [02:00<?, ?it/s]  out+out1+out2
# epoch___0____loss___5.669157____steps___55360:   0%|          | 0/10 [02:13<?, ?it/s]
# epoch___0____loss___5.677723____steps___55360:   0%|          | 0/10 [02:42<?, ?it/s]
# epoch___0____loss___5.494996____steps___55360:   0%|          | 0/10 [03:43<?, ?it/s]
# epoch___0____loss___5.319009____steps___55280:   0%|          | 0/10 [03:42<?, ?it/s]  0.0003
# epoch___0____loss___4.823767____steps___54160:   0%|          | 0/10 [03:38<?, ?it/s]  0.0003 layer norm at the end
# epoch___0____loss___4.830925____steps___54240:   0%|          | 0/10 [03:39<?, ?it/s]  0.0003 layer norm everywhere
# epoch___0____loss___4.843996____steps___56160:   0%|          | 0/10 [03:46<?, ?it/s]  0.0003 relu in the middle
# epoch___0____loss___4.821821____steps___55520:   0%|          | 0/10 [03:44<?, ?it/s]  0.0003 gelu in the middle
# epoch___0____loss___5.115138____steps___60400:   0%|          | 0/10 [04:03<?, ?it/s]  0.0003 layer norm in the middle
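
The loss logs above come from a training loop that is not included in the post. A minimal sketch of such a loop, assuming next-token prediction with cross-entropy loss, an Adam optimizer at the 0.0003 learning rate mentioned in the logs, and a hypothetical get_batches() iterator of token tensors, might look like this:

import torch
import tqdm

def train(net, get_batches, epochs=1, lr=0.0003, pad_id=3):
    # hypothetical sketch; get_batches() and the hyperparameters are assumptions, not the author's code
    net.train()
    opt = torch.optim.Adam(net.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_id)
    for epoch in tqdm.tqdm(range(epochs)):
        for tokens in get_batches():            # tokens: LongTensor of shape [batch, seq_len]
            tokens = tokens.to(device)
            logits, _ = net(tokens[:, :-1])     # model returns (logits, placeholder)
            loss = loss_fn(logits.reshape(-1, logits.shape[-1]),
                           tokens[:, 1:].reshape(-1))
            opt.zero_grad()
            loss.backward()
            opt.step()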

This LLM design is a PyTorch sequence model built from several custom neural-network modules. An overview of its main components:

  1. MaxState module
    • This module acts as a custom attention-like mechanism: it projects the input through several linear layers and then applies a cumulative maximum (cummax) along the sequence dimension, so each position aggregates information from itself and all earlier positions (a toy cummax example is given after this overview).
    • It uses multiple "heads", each with its own slice of the linear projections, similar to multi-head attention in the Transformer.
    • The module also includes a LayerNorm layer to normalize its output.
  2. FeedForward module
    • A feed-forward network with two linear layers and a ReLU activation.
    • It adds a gating mechanism: a separate linear layer produces a gate signal that is passed through ReLU and multiplied element-wise with the output of the first linear layer, followed by dropout.
  3. SamOut model
    • The top-level model, built from several MaxState and FeedForward modules.
    • Input tokens first pass through an embedding layer and a learned positional-embedding layer.
    • It then applies MaxState and FeedForward modules repeatedly, alternating between them, to process the sequence; each of the two stages sums its branch outputs (x1 + x2 + x3) and normalizes them.
    • Finally, a linear head maps the hidden states to vocabulary logits.
  4. Training and inference
    • The code shown here ends with a small __main__ block that only builds the model and runs a forward pass; the training loop that produced the logged loss values (computing a loss and updating the weights) is not included, though a hedged sketch of one appears right after the loss logs above.
      A few additional points:
  • The model uses dropout to reduce overfitting.
  • The model is meant to run on a GPU: it is moved with .to(device), where device is set to "cuda".
  • The body of the model stacks multiple MaxState and FeedForward modules, which increases its expressive capacity.
  • The comment block records different experimental settings, such as whether LayerNorm is used, along with the resulting loss values.
    Overall, the design resembles the Transformer, but it replaces standard attention with its own elements, such as the cumulative-maximum operation inside the custom MaxState mechanism.
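
To make the cummax step in MaxState concrete, here is a tiny illustration (the values are made up): torch.cummax along the sequence dimension turns each position into the running maximum over that position and all earlier ones, which is how the module aggregates context without attention weights.

import torch

x = torch.tensor([[1.0, 3.0, 2.0, 5.0, 4.0]])   # shape [batch=1, seq_len=5]
running_max, _ = torch.cummax(x, dim=1)          # causal running maximum
print(running_max)                               # tensor([[1., 3., 3., 5., 5.]])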

Original article: https://blog.csdn.net/weixin_32759777/article/details/144016528
