
[Deep Learning] einsum, the all-purpose matrix-operation function (Einstein summation)

ref:https://blog.csdn.net/zhaohongfei_358/article/details/125273126
While studying the transformer, I ran into this code:

        values = self.values(values)  # (N, value_len, embed_size)
        keys = self.keys(keys)  # (N, key_len, embed_size)
        queries = self.queries(query)  # (N, query_len, embed_size)

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)

        # Einsum does the queries*keys matrix multiplication for every example
        # and every head at once; don't be confused by einsum, it's just a
        # compact way of writing a batched matrix multiplication (bmm)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, heads, head_dim)
        # keys shape: (N, key_len, heads, head_dim)
        # energy: (N, heads, query_len, key_len)

This completely confused me, so this time I decided to study it properly and see what is going on. einsum has a bit of a "you can feel it but can't quite put it into words" quality; clearly I'm more enthusiastic than skilled, and my understanding isn't deep enough yet!
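Before working through the basics, here is a minimal sketch (my own addition, with made-up shapes) showing that the confusing line is nothing mysterious: "nqhd,nkhd->nhqk" is just a permute followed by a batched matrix multiplication.

import torch

N, query_len, key_len, heads, head_dim = 2, 4, 5, 8, 16
queries = torch.rand(N, query_len, heads, head_dim)
keys = torch.rand(N, key_len, heads, head_dim)

# einsum version: d (head_dim) appears in both inputs but not in the output,
# so it is the contracted dimension; n and h are carried along as batch dims
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

# equivalent explicit version: move heads next to the batch dim, then matmul
q = queries.permute(0, 2, 1, 3)   # (N, heads, query_len, head_dim)
k = keys.permute(0, 2, 3, 1)      # (N, heads, head_dim, key_len)
energy_explicit = q @ k           # (N, heads, query_len, key_len)

print(torch.allclose(energy, energy_explicit))  # True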

einsum is available in both numpy and torch. It works by labeling every axis with an index letter: indices that appear in the inputs but are missing from the output are summed over.

import torch
x = torch.rand((2, 3))
v = torch.rand((1, 3))
print(torch.einsum('ij,kj->ik', x, v).shape)   # matrix multiplication x @ v.T -> torch.Size([2, 1])
print(torch.einsum('ij,kj->ki', x, v).shape)   # the same product, transposed -> torch.Size([1, 2])
print(torch.einsum('ij,km->ijkm', x, v).shape) # outer product of every pair of elements -> torch.Size([2, 3, 1, 3])
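To check what those three expressions actually compute, here is a quick sketch (my own addition) that compares them against the equivalent explicit operations:

import torch

x = torch.rand((2, 3))
v = torch.rand((1, 3))

# 'ij,kj->ik': j appears in both inputs but not in the output, so it is
# summed over -- exactly x @ v.T
print(torch.allclose(torch.einsum('ij,kj->ik', x, v), x @ v.T))      # True

# 'ij,kj->ki': the same contraction, with the output axes in the other order
print(torch.allclose(torch.einsum('ij,kj->ki', x, v), (x @ v.T).T))  # True

# 'ij,km->ijkm': no index repeats, so nothing is summed; it is the outer
# product, which broadcasting expresses just as well
outer = x[:, :, None, None] * v[None, None, :, :]
print(torch.allclose(torch.einsum('ij,km->ijkm', x, v), outer))      # True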
import torch
x = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])
y = torch.tensor([
    [7, 8, 9]
])
x,y
(tensor([[1, 2, 3],
         [4, 5, 6]]),
 tensor([[7, 8, 9]]))
result = torch.einsum('ij,km->ijkm', x, y)
result
tensor([[[[ 7,  8,  9]],

         [[14, 16, 18]],

         [[21, 24, 27]]],


        [[[28, 32, 36]],

         [[35, 40, 45]],

         [[42, 48, 54]]]])
a = [
    [[1, 2],   # i=0
     [3, 4]],  # i=0
    [[5, 6],   # i=1
     [7, 8]]   #  i=1
]

b = [
    [[9, 10, 11],   # i=0
     [12, 13, 14]], # i=0
    [[15, 16, 17],  # i=1
     [18, 19, 20]]  # i=1
]

torch.tensor(a[0]).shape,torch.tensor(b[0]).shape


torch.tensor(a[0]) @ torch.tensor(b[0])
torch.tensor(a[1]) @ torch.tensor(b[1])
tensor([[183, 194, 205],
        [249, 264, 279]])
# multiply each pair of matrices one by one with @, then stack into one tensor
res = []
for i in range(len(a)):
    a1 = torch.tensor(a[i])
    b1 = torch.tensor(b[i])
    res.append(a1 @ b1)
res1 = torch.stack(res)
print(res, "\n", res1)
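The loop above is exactly a batched matrix multiplication. As a sketch (my own addition, reusing the a, b, and res1 defined above), the same result falls out of einsum with a shared batch index, or out of torch.bmm:

A = torch.tensor(a)  # shape (2, 2, 2)
B = torch.tensor(b)  # shape (2, 2, 3)

# 'bij,bjk->bik': b is kept as a batch index in every operand, j is contracted
res_einsum = torch.einsum('bij,bjk->bik', A, B)
res_bmm = torch.bmm(A, B)

print(torch.equal(res_einsum, res1))  # True
print(torch.equal(res_bmm, res1))     # True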
x = torch.rand(3, 3)
torch.einsum('ii->i', x), x  # 'ii->i' keeps the repeated index i, i.e. extracts the diagonal
(tensor([0.7127, 0.3843, 0.2046]),
 tensor([[0.7127, 0.0171, 0.9940],
         [0.6781, 0.3843, 0.9031],
         [0.4963, 0.1581, 0.2046]]))
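A few more single-operand patterns follow the same rule of thumb (my own addition): an index that disappears from the right-hand side gets summed, an index that stays is kept.

x = torch.rand(3, 3)
print(torch.einsum('ii->i', x))  # diagonal, as above
print(torch.einsum('ii->', x))   # trace: take the diagonal, then sum it (i disappears)
print(torch.einsum('ij->j', x))  # column sums: i is summed away
print(torch.einsum('ij->ji', x)) # transpose: nothing summed, indices reordered
print(torch.einsum('ij->', x))   # sum of all elements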

Original article: https://blog.csdn.net/weixin_40293999/article/details/142712073
