PyTorch torch.cat
torch
https://pytorch.org/docs/stable/torch.html
torch.cat (Python function, in torch.cat)
1. torch.cat
https://pytorch.org/docs/stable/generated/torch.cat.html
torch.cat(tensors, dim=0, *, out=None) -> Tensor
Concatenates the given sequence of tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be a 1-D empty tensor with size (0,).
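A minimal sketch of that shape rule (the tensor names are illustrative, and the last line assumes the 1-D empty-tensor exception described above):

import torch

a = torch.zeros(2, 3)
b = torch.ones(4, 3)       # same shape as a except in the cat dimension (dim 0)
empty = torch.empty(0)     # 1-D empty tensor with size (0,)

# (2, 3) and (4, 3) agree in every dimension except dim 0 -> result is (6, 3)
print(torch.cat((a, b), dim=0).shape)          # torch.Size([6, 3])

# the 1-D empty tensor is tolerated and contributes nothing to the result
print(torch.cat((a, empty, b), dim=0).shape)   # torch.Size([6, 3])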
torch.cat() can be seen as an inverse operation for torch.split() and torch.chunk().
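A small round-trip sketch of that inverse relationship (illustrative, not taken from the docs): splitting a tensor and concatenating the pieces along the same dimension recovers the original.

import torch

x = torch.arange(12).reshape(4, 3)

# chunk the rows, then cat them back along the same dimension
chunks = torch.chunk(x, 2, dim=0)                # two (2, 3) tensors
print(torch.equal(x, torch.cat(chunks, dim=0)))  # True

# split with explicit sizes, then cat them back
pieces = torch.split(x, [1, 3], dim=0)           # a (1, 3) and a (3, 3) tensor
print(torch.equal(x, torch.cat(pieces, dim=0)))  # True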
torch.cat() can be best understood via examples.
torch.stack() concatenates the given sequence along a new dimension.
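A quick contrast of the two, shapes only (illustrative): cat grows an existing dimension, while stack inserts a new one.

import torch

x = torch.randn(2, 3)
y = torch.randn(2, 3)

print(torch.cat((x, y), dim=0).shape)    # torch.Size([4, 3])    -- existing dim grows
print(torch.stack((x, y), dim=0).shape)  # torch.Size([2, 2, 3]) -- new dim inserted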
- Parameters
tensors (sequence of Tensors) - any Python sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.
dim (int, optional) - the dimension over which the tensors are concatenated
- Keyword Arguments
out (Tensor, optional) - the output tensor.
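A minimal sketch of the out= keyword argument (the buffer name is illustrative): the result is written into a preallocated tensor of the concatenated shape and also returned.

import torch

x = torch.randn(2, 3)
buffer = torch.empty(4, 3)            # preallocated output with the concatenated shape
torch.cat((x, x), dim=0, out=buffer)  # writes the result into buffer
print(buffer.shape)                   # torch.Size([4, 3])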
2. Example
(base) yongqiang@yongqiang:~$ python
Python 3.11.4 (main, Jul 5 2023, 13:45:01) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818]])
>>>
>>> torch.cat((x, x, x), 0)
tensor([[ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818],
        [ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818],
        [ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818]])
>>>
>>> torch.cat((x, x, x), 1)
tensor([[ 0.0811,  0.4571, -1.5260,  0.0811,  0.4571, -1.5260,  0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818,  1.4803, -0.0314, -1.5818,  1.4803, -0.0314, -1.5818]])
>>>
>>> exit()
(base) yongqiang@yongqiang:~$
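As a supplementary note (not part of the original session), dim also accepts negative values counting from the last dimension; for the 2-D x above, dim=-1 behaves exactly like dim=1.

import torch

x = torch.randn(2, 3)
print(torch.cat((x, x, x), dim=-1).shape)  # torch.Size([2, 9]), same as dim=1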
3. Example
https://github.com/karpathy/llama2.c/blob/master/model.py
import torch

# (1, 5) tensor standing in for a batch of one sequence with 5 positions
idxs = torch.randn(1, 5)
print("idxs.shape:", idxs.shape)
print("idxs:\n", idxs)

# (1, 1) tensor standing in for the next position
next_idx = torch.randn(1, 1)
print("\nnext_idx.shape:", next_idx.shape)
print("next_idx:\n", next_idx)

# current sequence length along dim 1
print("\nidxs.size(1):", idxs.size(1))

# append along dim=1: (1, 5) cat (1, 1) -> (1, 6)
idxs_set = torch.cat((idxs, next_idx), dim=1)
print("\nidxs_set.shape:", idxs_set.shape)
print("idxs_set:\n", idxs_set)
/home/yongqiang/miniconda3/bin/python /home/yongqiang/llm_work/llama2.c/yongqiang.py
idxs.shape: torch.Size([1, 5])
idxs:
tensor([[-1.3383, 0.1427, 0.0857, 2.2887, 0.1691]])
next_idx.shape: torch.Size([1, 1])
next_idx:
tensor([[0.4807]])
idxs.size(1): 5
idxs_set.shape: torch.Size([1, 6])
idxs_set:
tensor([[-1.3383, 0.1427, 0.0857, 2.2887, 0.1691, 0.4807]])
Process finished with exit code 0
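In generation loops of the kind found in model.py-style code, this is the pattern that appends each newly sampled token id to the running sequence. The sketch below is illustrative only; the variable names, the random logits, and the greedy argmax are assumptions, not taken from llama2.c.

import torch

torch.manual_seed(0)
vocab_size = 32
idx = torch.randint(vocab_size, (1, 5))        # (batch=1, current sequence length 5)

for _ in range(3):
    logits = torch.randn(1, vocab_size)        # stand-in for the model's last-step logits
    idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (1, 1) next token id
    idx = torch.cat((idx, idx_next), dim=1)    # sequence grows by one each step

print(idx.shape)  # torch.Size([1, 8])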