自学内容网 自学内容网

【pytorch-02】:张量的索引、形状操作和常见运算函数

1 张量索引

1.1 简单行列索引和列表索引

import torch


# 1. 简单行列索引
def test01():

    # 固定随机数种子
    torch.manual_seed(0)

    data = torch.randint(0, 10, [4, 5])
    print(data)
    print('-' * 30)

    # 1.1 获得指定的某行元素
    # print(data[2])

    # 1.2 获得指定的某个列的元素
    # 逗号前面表示行, 逗号后面表示列

    # 冒号表示所有行或者所有列
    # print(data[:, :])

    # 表示获得第3列的元素
    print(data[:, 2])

    # 获得指定位置的某个元素
    print(data[1, 2], data[1][2])

    # 表示先获得前三行,然后再获得第三列的数据
    print(data[:3, 2])

    # 表示获得前三行的前两列
    print(data[:3, :2])


# 2. 列表索引
def test02():


    # 固定随机数种子
    torch.manual_seed(0)

    data = torch.randint(0, 10, [4, 5])
    print(data)
    print('-' * 30)

    # 如果索引的行列都是一个1维的列表,那么两个列表的长度必须相等
    # print(data[[0, 1, 2], [2, 4]]) # 报错,索引位置都是一维,必须匹配
    # 解决方法:如果不想前后维数一样,就采用二维数组
    # 使用二维数组进行索引得到仍然为二维数组
    print(data[[[0],[1],[2]],[2,4]])


    # 1.表示获得 (0, 0)、(2, 1)、(3, 2) 三个位置的元素
    # 使用一维数组进行索引,得到的是一维
    print(data[[0, 2, 3], [0, 1, 2]])

    # 2。表示获得 0、2、3 行的 0、1、2 列
    # print(data[[[0], [2], [3]], [0, 1, 2]])

1.2 布尔索引和多维索引

import torch


# 1. 布尔索引
def test01():

    torch.manual_seed(0)
    data = torch.randint(0, 10, [4, 5])
    print(data)

    # 1. 希望能够获得该张量中所有大于3的元素
    # 所有元素与3进行比较,大于返回True,小于返回False
    # 返回一个布尔类型的张量
    print(data > 3)

    # 对于张量中的所有元素进行筛选,变为一维的张量
    print(data[data > 3])


    # 2. 希望返回第2列元素大于6的行
    # 先获取到第二列数据,然后进行比较,得到布尔张量
    # 然后再进行行索引

    # 想要获取到行,在行索引的位置传入布尔张量
    print(data[:,1] > 6) # tensor([ True,  True, False, False])
    print(data[data[:, 1] > 6]) # 选择前两行

    # 3. 希望返回第2行元素大于3的所有列
    # 想要获取到列,在列的位置传入布尔索引
    print(data[:, data[1] > 3])


# 2. 多维索引
def test02():

    torch.manual_seed(0)
    data = torch.randint(0, 10, [3, 4, 5])
    print(data)
    print('-' * 30)

    # 按照第0个维度选择第0元素,4行5列元素
    print(data[0, :, :])
    print('-' * 30)

    # 按照第1个维度选择第0元素
    print(data[:, 0, :])
    print('-' * 30)

    # 按照第2个维度选择第0元素
    print(data[:, :, 0])
    print('-' * 30)

2 张量的形状操作

2.1 reshape函数

  • 保证张量元素个数不变的情况下改变张量的形状
  • 在神经网络中,不同层中的数据形状不同
import torch


def test():

    torch.manual_seed(0)
    data = torch.randint(0, 10, [4, 5])

    # 查看张量的形状
    print(data.shape, data.shape[0], data.shape[1])
    # print(data.size(), data.size(0), data.size(1)) # 与上述方法结果一致

    # 修改张量的形状
    new_data = data.reshape(2, 10)
    print(new_data)

    # 注意: 转换之后的形状元素个数得等于原来张量的元素个数
    # 原来有多少个元素,转换之后就有多少个元素
    # new_data = data.reshape(1, 10)
    # print(new_data)

    # 使用-1代替省略的形状
    # 转换为指定的行数,列数指定为-1,可以进行自动匹配列数
    new_data = data.reshape(5, -1)
    print(new_data)

    # 转换为两列,自动进行计算行数
    new_data = data.reshape(-1, 2)
    print(new_data)

2.2 transpose和permute函数的使用

  • reshape函数更改形状,reshape会重新计算张量的维度,有时候不需要重新计算张量的维度,只要调整张量维度的顺序即可,可以使用transpose函数和permute函数
  • transpose函数每次只能交换两个维度
  • permute函数可以一次交换多个维度
import torch


# 1. transpose 函数
def test01():

    torch.manual_seed(0)
    data = torch.randint(0, 10, [3, 4, 5])

    new_data = data.reshape(4, 3, 5)
    print(new_data.shape) # torch.Size([4, 3, 5])

    # 直接交换两个维度的值
    new_data = torch.transpose(data, 0, 1)
    print(new_data.shape) # torch.Size([4, 3, 5])

    # 缺点: 一次只能交换两个维度
    # 把数据的形状变成 (4, 5, 3)
    # 进行第一次交换: (4, 3, 5)
    # 进行第二次交换: (4, 5, 3)
    new_data = torch.transpose(data, 0, 1)
    new_data = torch.transpose(new_data, 1, 2)
    print(new_data.shape)


# 2. permute 函数
def test02():

    torch.manual_seed(0)
    data = torch.randint(0, 10, [3, 4, 5])

    # permute 函数可以一次性交换多个维度
    new_data = torch.permute(data, [1, 2, 0])
    print(new_data.shape)

2.3 view和contiguous函数

  • view函数改变张量的形状,只能用于存储在整块内存中的张量,具有一定的局限性。
  • pytorch中有些张量是由不同的数据块组成,并没有存储在整块的内存中,view函数无法对于这种张量进行变形处理
  • 一个张量经过了transpose或者permute函数的处理之后,就无法使用view函数进行形状操作
  • 先用contiguous将非连续内存空间转换为连续内存空间,然后再使用view函数进行更改张量形状
import torch


# 1. view 函数的使用
def test01():

    data = torch.tensor([[10, 20, 30], [40, 50, 60]])
    data = data.view(3, 2)
    print(data.shape)

    # is_contiguous 函数来判断张量是否是连续内存空间(整块的内存)
    print(data.is_contiguous())


# 2. view 函数使用注意
def test02():

    # 当张量经过 transpose 或者 permute 函数之后,内存空间基本不连续
    # 此时,必须先把空间连续,才能够使用 view 函数进行张量形状操作

    data = torch.tensor([[10, 20, 30], [40, 50, 60]])
    print('是否连续:', data.is_contiguous())
    data = torch.transpose(data, 0, 1)
    print('是否连续:', data.is_contiguous())

    # 此时,在不连续内存的情况使用 view 会怎么样呢?
    data = data.contiguous().view(2, 3)
    print(data) 

2.4 squeeze和unsqueeze函数用法

  • squeeze函数可以将维度为1的维度进行删除
  • unsqueeze函数给张量增加维度为1的维度
import torch


# 1. squeeze 函数使用
def test01():

    # 四维张量
    data = torch.randint(0, 10, [1, 3, 1, 5])
    print(data.shape)

    # 维度压缩, 默认去掉所有的1的维度
    # squeeze() - 默认去掉所有维度为1的函数
    # squeeze(0) - 删除第一个位置的为1的维度
    # 传入维度的索引值
    new_data = data.squeeze(0)
    print(new_data.shape)  # torch.Size([3, 5])

    # 指定去掉某个1的维度
    new_data = data.squeeze(2)
    print(new_data.shape)


# 2. unsqueeze 函数使用
def test02():

    data = torch.randint(0, 10, [3, 5])
    print(data.shape) # torch.Size([1, 3, 1, 5])


    # 可以在指定位置增加维度
    # -1 代表最后一个维度
    new_data = data.unsqueeze(-1)
    print(new_data.shape)

2.5 张量更改形状小结

  1. reshape 函数可以在保证张量数据不变的前提下改变数据的维度.
  2. transpose 函数可以实现交换张量形状的指定维度, permute 可以一次交换更多的维度.
  3. view 函数也可以用于修改张量的形状, 但是它要求被转换的张量内存必须连续,所以一般配合 contiguous 函数使用.
  4. squeeze 和 unsqueeze 函数可以用来增加或者减少维度.

3 常见运算函数

  • mean()
  • sum()
  • pow(n)
  • sqrt()
  • exp()
  • log() - 以e为底的对数
  • log2()
  • log10()
import torch


# 1. 均值
def test01():

    torch.manual_seed(0)
    # data = torch.randint(0, 10, [2, 3], dtype=torch.float64)
    data = torch.randint(0, 10, [2, 3]).double()
    # print(data.dtype)

    print(data)
    # 默认对所有的数据计算均值
    print(data.mean())
    # 按指定的维度计算均值
    print(data.mean(dim=0)) # 竖向计算 按列计算
    print(data.mean(dim=1)) # 横向计算 按行计算


# 2. 求和
def test02():

    torch.manual_seed(0)
    data = torch.randint(0, 10, [2, 3]).double()

    print(data.sum())
    print(data.sum(dim=0))
    print(data.sum(dim=1))


# 3. 平方
def test03():

    torch.manual_seed(0)
    data = torch.randint(0, 10, [2, 3]).double()
    print(data)
    data = data.pow(2)
    print(data)


# 4. 平方根
def test04():

    torch.manual_seed(0)
    data = torch.randint(0, 10, [2, 3]).double()
    print(data)
    data = data.sqrt()
    print(data)


# 5. e多少次方
def test05():

    torch.manual_seed(0)
    data = torch.randint(0, 10, [2, 3]).double()
    print(data)
    data = data.exp()
    print(data)


# 6. 对数
def test06():

    torch.manual_seed(0)
    data = torch.randint(0, 10, [2, 3]).double()
    print(data)
    data = data.log()     # 以e为底
    data = data.log2()    # 以2为底
    data = data.log10()   # 以10为底
    print(data)

原文地址:https://blog.csdn.net/weixin_51385258/article/details/143905539

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!