自学内容网 自学内容网

深度学习02-pytorch-07-张量的拼接操作

在 PyTorch 中,张量的拼接操作主要通过 torch.cat()torch.stack() 两个函数来完成。拼接操作允许你将多个张量沿着指定的维度连接在一起,构建更大的张量。以下是详细解释和举例说明:

1. torch.cat()

功能: 沿着指定的维度连接(拼接)多个张量。torch.cat() 是最常用的拼接函数,它不会增加新的维度,只是在指定维度上将张量的值连接在一起。

语法:

torch.cat(tensors, dim=0)
  • tensors: 需要拼接的张量列表。

  • dim: 沿着哪一个维度进行拼接。

示例:

import torch
# 创建两个形状相同的张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.tensor([[7, 8, 9], [10, 11, 12]])
​
# 沿着第0个维度拼接
z0 = torch.cat((x, y), dim=0)
print(z0)

输出:

tensor([[ 1, 2, 3],
      [ 4, 5, 6],
      [ 7, 8, 9],
      [10, 11, 12]])

在这个例子中,xy 沿着第0维拼接(相当于竖直方向连接)。

# 沿着第1个维度拼接
z1 = torch.cat((x, y), dim=1)
print(z1)

输出:

tensor([[ 1, 2, 3, 7, 8, 9],
      [ 4, 5, 6, 10, 11, 12]])

在这个例子中,xy 沿着第1维拼接(相当于水平方向连接)。

2. torch.stack()

功能: 沿着新维度拼接多个张量。与 torch.cat() 不同,torch.stack() 会在指定维度插入一个新的维度,并将张量叠加在该维度上。

语法:

torch.stack(tensors, dim=0)
  • tensors: 需要叠加的张量列表。

  • dim: 在哪一个维度上插入新的维度。

示例:

import torch
# 创建两个相同形状的张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.tensor([[7, 8, 9], [10, 11, 12]])
​
# 沿着新维度叠加张量
z = torch.stack((x, y), dim=0)
print(z)

输出:

tensor([[[ 1, 2, 3],
        [ 4, 5, 6]],
      [[ 7, 8, 9],
        [10, 11, 12]]])

在这个例子中,torch.stack() 插入了一个新的维度,最终的形状是 (2, 2, 3)

# 沿着第1个维度叠加
z1 = torch.stack((x, y), dim=1)
print(z1)

输出:

tensor([[[ 1, 2, 3],
        [ 7, 8, 9]],
      [[ 4, 5, 6],
        [10, 11, 12]]])

在这个例子中,新的维度插入到了第1维,最终的形状是 (2, 2, 3)

3. torch.chunk()

功能: 将一个张量沿着指定的维度分割成若干个小张量。 语法:

torch.chunk(tensor, chunks, dim=0)
  • tensor: 需要被分割的张量。

  • chunks: 分割成多少个张量。

  • dim: 沿着哪一个维度分割。

示例:

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 沿着第0维分割为3块
chunks = torch.chunk(x, 3, dim=0)
for chunk in chunks:
   print(chunk)

输出:

tensor([[1, 2, 3]])
tensor([[4, 5, 6]])
tensor([[7, 8, 9]])

4. torch.split()

功能: 与 torch.chunk() 类似,但它允许指定每个子张量的大小。

语法:

torch.split(tensor, split_size_or_sections, dim=0)
  • tensor: 要分割的张量。

  • split_size_or_sections: 每个子张量的大小,或按指定的切片。

  • dim: 沿着哪一个维度分割。

示例:

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 将张量分割为每块大小为2和1
splits = torch.split(x, [2, 1], dim=0)
for split in splits:
   print(split)

输出:

tensor([[1, 2, 3],
      [4, 5, 6]])
tensor([[7, 8, 9]])

5. torch.unbind()

功能: 沿着指定维度将张量解开为多个子张量。 语法:

torch.unbind(tensor, dim=0)
  • tensor: 要解开的张量。

  • dim: 沿着哪个维度解开。

示例:

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 沿着第0维度解开
unbinded = torch.unbind(x, dim=0)
for t in unbinded:
   print(t)

输出:

tensor([1, 2, 3])
tensor([4, 5, 6])

在这个例子中,张量 x 沿着第0维被解开为两个子张量。

总结

  • torch.cat() 是最常用的拼接方法,用于沿着指定维度拼接多个张量。

  • torch.stack() 可以插入新的维度,叠加多个张量。

  • torch.chunk()torch.split() 用于将张量分割成多个子张量。

  • torch.unbind() 用于沿指定维度解开张量为多个张量。

这些操作允许你灵活地操作张量的维度,方便进行数据预处理和模型设计。


原文地址:https://blog.csdn.net/weixin_41645791/article/details/142406455

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!