自学内容网 自学内容网

PyTorch基本使用——张量的索引操作

在操作张量时,经常要去获取某些元素进行处理或者修改操作,在这里需要了解torch中的索引操作。

准备数据:

data = torch.randint(0,10,[4,5])
print('data--->',data)

输出结果:

data---> tensor([[3, 9, 4, 0, 5],
        [7, 5, 9, 9, 7],
        [5, 9, 8, 9, 7],
        [9, 2, 6, 7, 7]])
  • 简单行、列索引

    print('第一行:',data[0])
    print('第一列:',data[:,0])
    

    输出结果:

    第一行: tensor([3, 9, 4, 0, 5])
    第一列: tensor([3, 7, 5, 9])
    
  • 列表索引

    print('-----------------返回(0,1)、(1,2) 2个位置的元素------------------')
    print(data[[0,1],[1,2]])
    print('-----------------返回0、1 行的1、2 列共4个元素------------------')
    print(data[[[0],[1]],[1,2]])
    

    输出结果:

    -----------------返回(0,1)、(1,2) 2个位置的元素------------------
    tensor([9, 9])
    -----------------返回0、1 行的1、2 列共4个元素------------------
    tensor([[9, 4],
            [5, 9]])
    
  • 范围索引

    print('-----------------前3行、前2列的数据------------------')
    print(data[:3,:2])
    print('-----------------第2行到最后的前2列数据------------------')
    print(data[2:,:2])
    

    输出结果:

    -----------------前3行、前2列的数据------------------
    tensor([[3, 9],
            [7, 5],
            [5, 9]])
    -----------------第2行到最后的前2列数据------------------
    tensor([[5, 9],
            [9, 2]])
    
  • 布尔索引

    print('-----------------第三列大于5的行数据------------------')
    print(data[data[:,2] > 5])
    print('-----------------第二行大于5的行数据------------------')
    print(data[:,data[1] > 5])
    

    输出结果:

    -----------------第三列大于5的行数据------------------
    tensor([[7, 5, 9, 9, 7],
            [5, 9, 8, 9, 7],
            [9, 2, 6, 7, 7]])
    -----------------第二行大于5的行数据------------------
    tensor([[3, 4, 0, 5],
            [7, 9, 9, 7],
            [5, 8, 9, 7],
            [9, 6, 7, 7]])
    
  • 多维索引

    data = torch.randint(0,10,[3,4,5])
    print(data)
    # 获取0轴上的第一个数据
    print(data[0,:,:])
    # 获取1轴上的第一个数据
    print(data[:,0,:])
    # 获取2轴上的第一个数据
    print(data[:,:,0])
    

    输出结果:

    tensor([[[8, 3, 6, 1, 5],
             [5, 0, 4, 3, 8],
             [8, 3, 3, 5, 0],
             [6, 4, 0, 8, 4]],
    
            [[7, 2, 3, 8, 5],
             [6, 2, 9, 5, 0],
             [4, 2, 7, 1, 1],
             [5, 4, 4, 1, 1]],
    
            [[2, 4, 7, 2, 5],
             [6, 1, 4, 5, 6],
             [9, 2, 3, 1, 0],
             [2, 1, 2, 7, 9]]])
    tensor([[8, 3, 6, 1, 5],
            [5, 0, 4, 3, 8],
            [8, 3, 3, 5, 0],
            [6, 4, 0, 8, 4]])
    tensor([[8, 3, 6, 1, 5],
            [7, 2, 3, 8, 5],
            [2, 4, 7, 2, 5]])
    tensor([[8, 5, 8, 6],
            [7, 6, 4, 5],
            [2, 6, 9, 2]])
    

原文地址:https://blog.csdn.net/dwjf321/article/details/144331314

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!