自学内容网 自学内容网

PyTorch基本使用-张量的形状操作

文章目录


学习目标:掌握reshape()、squeze()、unsqueeze()、transpose()、permute()、view()、contiguous()等函数使用

  • reshape() 函数

    reshape 函数可以在保证张量数据不变的前提下改版数据维度,将其转换成指定的形状。

    data = torch.tensor([[1,2,3],[4,5,6]])
    print('data ---> ',data)
    # 使用 shape属性或者size方法都可以获取张量的形状
    print(data.shape,data.shape[0],data.shape[1])
    print(data.size(),data.size(0),data.size(1))
    # 使用 reshape 修改张量的形状
    new_data = data.reshape(1,6)
    print(new_data.shape)
    

    输出结果:

    data --->  tensor([[1, 2, 3],
            [4, 5, 6]])
    torch.Size([2, 3]) 2 3
    torch.Size([2, 3]) 2 3
    torch.Size([1, 6])
    
  • squeeze()和unsqueeze()函数

    squeeze 函数删除形状为 1 的维度(降维),unsqueeze 函数添加形状为1的维度(升维)。

    data1 = torch.tensor([1,2,3,4,5,6])
    print('普通的1维数组 ---> ',data1.shape,data1)
    data2 = data1.unsqueeze(dim=0)
    print('在0维上拓展维度',data2.shape,data2) # 1*6
    data3 = data1.unsqueeze(dim=1)
    print('在1维上拓展维度',data3.shape,data3) # 6*1
    data4 = data1.unsqueeze(dim=-1)
    print('在-1维上拓展维度',data4.shape,data4) # 6*1
    data5 = data4.squeeze()
    print('压缩维度',data5.shape,data5) # 1*6
    

    输出结果:

    普通的1维数组 --->  torch.Size([6]) tensor([1, 2, 3, 4, 5, 6])
    在0维上拓展维度 torch.Size([1, 6]) tensor([[1, 2, 3, 4, 5, 6]])
    在1维上拓展维度 torch.Size([6, 1]) tensor([[1],
            [2],
            [3],
            [4],
            [5],
            [6]])
    在-1维上拓展维度 torch.Size([6, 1]) tensor([[1],
            [2],
            [3],
            [4],
            [5],
            [6]])
    压缩维度 torch.Size([6]) tensor([1, 2, 3, 4, 5, 6])
    
  • transpose()和permute()函数

    transpose 函数可以实现交换张量形状的指定维度, 例如: 一个张量的形状为 (2, 3, 4) 可以通过 transpose 函数把 3 和 4
    进行交换, 将张量的形状变为 (2, 4, 3) 。 permute 函数可以一次交换更多的维度。

    data = torch.tensor(np.random.randint(0,10,[3,4,5]))
    print('data.shape ---> ',data.shape)
    # 1. 交换1和2维度
    data2 = torch.transpose(data,1,2)
    print('data2.shape ---> ',data2.shape)
    # 2. 将形状换成 4,5,3 需要多次
    data3 = torch.transpose(data,0,1)
    data4 = torch.transpose(data3,1,2)
    print('data4.shape ---> ',data4.shape)
    # 使用permute将形状换成 4,5,3
    # 方法-1
    data5 = torch.permute(data,[1,2,0])
    print('data5.shape ---> ',data5.shape)
    # 方法-2
    data6 = data.permute([1,2,0])
    print('data6.shape ---> ',data6.shape)
    

    输出结果:

    data.shape --->  torch.Size([3, 4, 5])
    data2.shape --->  torch.Size([3, 5, 4])
    data4.shape --->  torch.Size([4, 5, 3])
    data5.shape --->  torch.Size([4, 5, 3])
    data6.shape --->  torch.Size([4, 5, 3])
    
  • view()和contiguous()函数

    view 函数也可以用于修改张量的形状,只能用于存储在整块内存中的张量。在 PyTorch 中,有些张量是由不同的数据
    块组成的,它们并没有存储在整块的内存中,view 函数无法对这样的张量进行变形处理,例如: 一个张量经过了
    transpose 或者 permute 函数的处理之后,就无法使用 view 函数进行形状操作。

    # 1. 一个张量经过了 transpose 或者 permute 函数的处理之后,就无法使用view 函数进行形状操作
    # 若要使用view函数, 需要使用contiguous() 变成连续以后再使用view函数
    # 2. 判断张量是否使用整块内存
    data1 = torch.tensor([[1,2,3],[4,5,6]])
    print('data1 ---> ',data1.shape,data1)
    # 3. 判断是否使用整块内存
    print(data1.is_contiguous())
    data2 = data1.view(3,2)
    print('data2 ---> ',data2.shape,data2)
    # 4. 判断是否使用整块内存
    print(data2.is_contiguous())
    # 5. 使用 transpose 函数修改形状
    data3 = torch.transpose(data1,0,1)
    print('data3 ---> ',data3.shape,data3)
    # 6. 判断是否使用整块内存
    print(data3.is_contiguous())
    # 7. 需要先使用 contiguous 函数转换为整块内存的张量,再使用 view 函数
    data4 = data3.contiguous().view(2,3)
    print('data4 ---> ',data4.shape,data4)
    # 8. 判断是否使用整块内存
    print(data4.is_contiguous())
    

    输出结果:

    data1 --->  torch.Size([2, 3]) tensor([[1, 2, 3],
            [4, 5, 6]])
    True
    data2 --->  torch.Size([3, 2]) tensor([[1, 2],
            [3, 4],
            [5, 6]])
    True
    data3 --->  torch.Size([3, 2]) tensor([[1, 4],
            [2, 5],
            [3, 6]])
    False
    data4 --->  torch.Size([2, 3]) tensor([[1, 4, 2],
            [5, 3, 6]])
    True
    

原文地址:https://blog.csdn.net/dwjf321/article/details/144354508

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!