[Deep Learning] einsum, the all-purpose matrix-operation function (Einstein summation)
ref: https://blog.csdn.net/zhaohongfei_358/article/details/125273126
While studying the Transformer, I ran into this code:
values = self.values(values) # (N, value_len, embed_size)
keys = self.keys(keys) # (N, key_len, embed_size)
queries = self.queries(query) # (N, query_len, embed_size)
# Split the embedding into self.heads different pieces
values = values.reshape(N, value_len, self.heads, self.head_dim)
keys = keys.reshape(N, key_len, self.heads, self.head_dim)
queries = queries.reshape(N, query_len, self.heads, self.head_dim)
# Einsum does the query*keys matrix multiplication for every example in the
# batch and for every head; don't be confused by einsum, it's just a compact
# way of writing matrix multiplication / bmm
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
# queries shape: (N, query_len, heads, head_dim),
# keys shape: (N, key_len, heads, head_dim)
# energy: (N, heads, query_len, key_len)
It left me completely lost, so this time I decided to learn it properly and see what is going on. einsum has a bit of an "easier to grasp than to explain" feel to it; clearly my understanding is still shallow.
einsum exists in both numpy and torch. Each tensor axis is labeled with an index letter, and any index that appears in the inputs but not to the right of the -> is summed over.
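As a minimal sketch of that rule (toy shapes chosen just for illustration): in 'ij,jk->ik' the index j appears in both inputs but not in the output, so it is summed over, which is exactly ordinary matrix multiplication written out as loops.
import torch
x = torch.rand(2, 3)
y = torch.rand(3, 4)
# 'ij,jk->ik': j is summed over, i and k survive in the given order
out = torch.einsum('ij,jk->ik', x, y)
# the same computation as explicit loops
manual = torch.zeros(2, 4)
for i in range(2):
    for k in range(4):
        for j in range(3):
            manual[i, k] += x[i, j] * y[j, k]
print(torch.allclose(out, manual))  # True
The examples below apply the same rule with different index patterns.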
import torch
x = torch.rand((2, 3))
v = torch.rand((1, 3))
print(torch.einsum('ij,kj->ik', x, v).shape)    # matrix multiplication x @ v.T -> (2, 1)
print(torch.einsum('ij,kj->ki', x, v).shape)    # same product, transposed -> (1, 2)
print(torch.einsum('ij,km->ijkm', x, v).shape)  # outer product, nothing is summed -> (2, 3, 1, 3)
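To make the three cases concrete, a quick sanity check (reusing the x and v defined above) against the plain-torch equivalents:
# 'ij,kj->ik'   == x @ v.T       -> shape (2, 1)
# 'ij,kj->ki'   == (x @ v.T).T   -> shape (1, 2)
# 'ij,km->ijkm' is an outer product over every index pair -> shape (2, 3, 1, 3)
print(torch.allclose(torch.einsum('ij,kj->ik', x, v), x @ v.T))      # True
print(torch.allclose(torch.einsum('ij,kj->ki', x, v), (x @ v.T).T))  # True
print(torch.allclose(torch.einsum('ij,km->ijkm', x, v),
                     x[:, :, None, None] * v[None, None, :, :]))     # True
The outer-product case is easier to see with small integer tensors: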
import torch
x = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])
y = torch.tensor([
    [7, 8, 9]
])
x, y
(tensor([[1, 2, 3],
         [4, 5, 6]]),
 tensor([[7, 8, 9]]))
result = torch.einsum('ij,km->ijkm', x, y)
result
tensor([[[[ 7,  8,  9]],
         [[14, 16, 18]],
         [[21, 24, 27]]],
        [[[28, 32, 36]],
         [[35, 40, 45]],
         [[42, 48, 54]]]])
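Spelled out, result[i, j, k, m] = x[i, j] * y[k, m]: no index is dropped from the output, so nothing is summed and the result is a pure outer product of shape (2, 3, 1, 3). A quick check of one entry, reusing x and y from above:
# result[1, 2, 0, 2] should be x[1, 2] * y[0, 2] = 6 * 9 = 54
print(result.shape)               # torch.Size([2, 3, 1, 3])
print(result[1, 2, 0, 2].item())  # 54
The next pattern is batched matrix multiplication: two stacks of matrices, multiplied pair by pair.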
a = [
    [[1, 2],   # i=0
     [3, 4]],  # i=0
    [[5, 6],   # i=1
     [7, 8]]   # i=1
]
b = [
    [[9, 10, 11],    # i=0
     [12, 13, 14]],  # i=0
    [[15, 16, 17],   # i=1
     [18, 19, 20]]   # i=1
]
torch.tensor(a[0]).shape,torch.tensor(b[0]).shape
(torch.Size([2, 2]), torch.Size([2, 3]))
torch.tensor(a[0]) @ torch.tensor(b[0])
tensor([[33, 36, 39],
        [75, 82, 89]])
torch.tensor(a[1]) @ torch.tensor(b[1])
tensor([[183, 194, 205],
[249, 264, 279]])
res = []
for i in range(len(a)):
    a1 = torch.tensor(a[i])
    b1 = torch.tensor(b[i])
    res.append(a1 @ b1)
res1 = torch.stack(res)
print(res, "\n", res1)
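The whole loop collapses into a single einsum: label the batch dimension b and let it pass through to the output, while j, the shared inner dimension, is summed over. A check against res1, reusing a and b from above:
A = torch.tensor(a)  # shape (2, 2, 2)
B = torch.tensor(b)  # shape (2, 2, 3)
# 'bij,bjk->bik': b is kept, j is summed -> batched matrix multiplication
res2 = torch.einsum('bij,bjk->bik', A, B)
print(torch.equal(res1, res2))  # True
One more pattern: a repeated index on a single operand walks its diagonal.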
x = torch.rand(3, 3)
torch.einsum('ii->i', x),x
(tensor([0.7127, 0.3843, 0.2046]),
 tensor([[0.7127, 0.0171, 0.9940],
         [0.6781, 0.3843, 0.9031],
         [0.4963, 0.1581, 0.2046]]))
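'ii->i' keeps the diagonal as a vector; 'ii->' would sum it, i.e. the trace. A quick check against torch.diagonal and torch.trace (same x as above):
print(torch.allclose(torch.einsum('ii->i', x), torch.diagonal(x)))  # True
print(torch.allclose(torch.einsum('ii->', x), torch.trace(x)))      # True
Coming back to the attention snippet at the top, "nqhd,nkhd->nhqk" sums over d (head_dim) and reorders the remaining axes. Below is a minimal sketch with toy shapes chosen just for illustration, showing that it is the same as a permute followed by a batched matmul:
N, query_len, key_len, heads, head_dim = 2, 4, 5, 3, 8
queries = torch.rand(N, query_len, heads, head_dim)
keys = torch.rand(N, key_len, heads, head_dim)
# einsum version, as in the Transformer code at the top
energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
# equivalent: move heads in front, then batch-multiply
# (N, heads, query_len, head_dim) @ (N, heads, head_dim, key_len)
q = queries.permute(0, 2, 1, 3)
k = keys.permute(0, 2, 1, 3)
energy2 = q @ k.transpose(-2, -1)
print(energy.shape)                     # torch.Size([2, 3, 4, 5])
print(torch.allclose(energy, energy2))  # True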
Original article: https://blog.csdn.net/weixin_40293999/article/details/142712073