
PyTorch view, expand, transpose, permute, reshape, repeat, repeat_interleave

Non-contiguous operations

There are a few operations on Tensors in PyTorch that do not change the contents of a tensor, but change the way the data is organized. These operations include:

narrow(), view(), expand(), transpose(), and permute().

This is where the concept of contiguity comes in. Take a contiguous tensor x and let y = x.transpose(0, 1): x is contiguous but y is not, because y's memory layout differs from that of a tensor of the same shape made from scratch. Note that the word "contiguous" is a bit misleading: it's not that the content of the tensor is spread out over disconnected blocks of memory. The bytes are still allocated in one block of memory, but the order of the elements is different!

When you call contiguous(), it actually makes a copy of the tensor such that the order of its elements in memory is the same as if it had been created from scratch with the same data.
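A minimal sketch of this behavior (the tensor names here are illustrative):

import torch

x = torch.arange(12).view(3, 4)      # freshly created: contiguous
y = x.transpose(0, 1)                # same storage, different element order
print(x.is_contiguous())             # True
print(y.is_contiguous())             # False
z = y.contiguous()                   # copies into a fresh, correctly ordered block
print(z.is_contiguous())             # True
print(y.data_ptr() == x.data_ptr())  # True: y still shares x's memory
print(z.data_ptr() == x.data_ptr())  # False: z is a new allocation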

transpose()

permute() and transpose() are similar. transpose() can only swap two dimensions, while permute() can rearrange all of the dimensions at once. For example:

import torch

x = torch.rand(16, 32, 3)
y = x.transpose(0, 2)   # swap dims 0 and 2 -> shape (3, 32, 16)

z = x.permute(2, 1, 0)  # reorder all dims at once -> also (3, 32, 16)

permute()

Returns a view of the original tensor input with its dimensions permuted.

>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> torch.permute(x, (2, 0, 1)).size()
torch.Size([5, 2, 3])
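A quick sketch to confirm that permute() really returns a view, i.e. no data is copied (the data_ptr() check is our own, not from the docs):

>>> y = torch.permute(x, (2, 0, 1))
>>> y.data_ptr() == x.data_ptr()   # True: same underlying storage
True
>>> y.is_contiguous()              # False: only the strides changed
False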

expand()

More than one element of an expanded tensor may refer to a single memory location. As a result, in-place operations (especially ones that are vectorized) may result in incorrect behavior. If you need to write to the tensors, please clone them first.

>>> x = torch.tensor([[1], [2], [3]])
>>> x.size()
torch.Size([3, 1])
>>> x.expand(3, 4)
tensor([[ 1,  1,  1,  1],
        [ 2,  2,  2,  2],
        [ 3,  3,  3,  3]])
>>> x.expand(-1, 4)   # -1 means not changing the size of that dimension
tensor([[ 1,  1,  1,  1],
        [ 2,  2,  2,  2],
        [ 3,  3,  3,  3]])
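Following the warning above, a small sketch of why you should clone() an expanded tensor before writing to it (a hypothetical continuation of the example):

>>> e = x.expand(3, 4)     # every column of a row aliases the same element
>>> w = e.clone()          # independent copy, safe for in-place writes
>>> w[0, 0] = 100
>>> x                      # the original is untouched
tensor([[1],
        [2],
        [3]])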

Difference Between view() and reshape()

1/ view(): does NOT make a copy of the original tensor. It only changes the dimensional interpretation (the strides) on top of the original data. In other words, it shares the same chunk of data with the original tensor, so it ONLY works on contiguous data.

2/ reshape(): returns a view when possible (i.e., when the data is contiguous). If not (i.e., the data is not contiguous), it copies the data into a contiguous chunk. As a copy, it takes up extra memory, and changes to the new tensor do not affect the values in the original tensor.

With contiguous data, reshape() returns a view.

When data is contiguous

x = torch.arange(1,13)
x
>> tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])

reshape() returns a view with the new shape:

y = x.reshape(4,3)
y
>>
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])

How do we know it's a view? Because changing an element in the new tensor y also changes the value in x, and vice versa:

y[0,0] = 100
y
>>
tensor([[100,   2,   3],
        [  4,   5,   6],
        [  7,   8,   9],
        [ 10,  11,  12]])
print(x)
>>
tensor([100,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12])

Next, let’s see how reshape() works on non-contiguous data.

# After transpose(), the data is non-contiguous
x = torch.arange(1,13).view(6,2).transpose(0,1)
x
>>
tensor([[ 1,  3,  5,  7,  9, 11],
        [ 2,  4,  6,  8, 10, 12]])
# reshape() works fine on non-contiguous data
y = x.reshape(4,3)
y
>>
tensor([[ 1,  3,  5],
        [ 7,  9, 11],
        [ 2,  4,  6],
        [ 8, 10, 12]])
# Change an element in y
y[0,0] = 100
y
>>
tensor([[100,   3,   5],
        [  7,   9,  11],
        [  2,   4,   6],
        [  8,  10,  12]])
# Check the original tensor: nothing has changed
x
>>
tensor([[ 1,  3,  5,  7,  9, 11],
        [ 2,  4,  6,  8, 10, 12]])
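One way to verify the copy-vs-view behavior (a sketch using data_ptr(), which returns the address of the first element):

# Non-contiguous input: reshape() had to copy
x = torch.arange(1,13).view(6,2).transpose(0,1)
y = x.reshape(4,3)
print(y.data_ptr() == x.data_ptr())
>> False
# Contiguous input: reshape() returned a view
z = torch.arange(1,13)
w = z.reshape(4,3)
print(w.data_ptr() == z.data_ptr())
>> True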

Finally, let’s see if view() can work on non-contiguous data.
No, it can’t!

# After transpose(), the data is non-contiguous
x = torch.arange(1,13).view(6,2).transpose(0,1)
x
>>
tensor([[ 1,  3,  5,  7,  9, 11],
        [ 2,  4,  6,  8, 10, 12]])
# Try to use view on the non-contiguous data
y = x.view(4,3)
y
>>
-------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
----> 1 y = x.view(4,3)
      2 y

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
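As the error message says, reshape() handles this case. Another option, sketched here, is to make the tensor contiguous first, at the cost of a copy:

# contiguous() copies the data into a properly ordered block, so view() succeeds
y = x.contiguous().view(4,3)
y
>>
tensor([[ 1,  3,  5],
        [ 7,  9, 11],
        [ 2,  4,  6],
        [ 8, 10, 12]])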

Contiguous operations

reshape() returns a view whenever it can; when it cannot, it makes a copy.

>>> a = torch.arange(4.)
>>> torch.reshape(a, (2, 2))
tensor([[ 0.,  1.],
        [ 2.,  3.]])
>>> b = torch.tensor([[0, 1], [2, 3]])
>>> torch.reshape(b, (-1,))
tensor([ 0,  1,  2,  3])

repeat() clones the data into newly allocated memory, whereas expand() only updates the strides in place, without copying.

import torch
a = torch.arange(10).reshape(2,5)
# b = a.expand(4,5)  # This fails: expand() can only enlarge dimensions of size 1; use repeat() for this
b = a.repeat(2,2)
print('b={}'.format(b))
'''
b=tensor([[0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9, 5, 6, 7, 8, 9]])
'''
c = torch.arange(3).reshape(1,3)
print('c={} c.stride()={}'.format(c, c.stride()))
d = c.expand(2,3)
print('d={} d.stride()={}'.format(d, d.stride()))
'''
c=tensor([[0, 1, 2]]) c.stride()=(3, 1), i.e. step 3 elements along dim=0 and 1 element along dim=1
d=tensor([[0, 1, 2],
        [0, 1, 2]]) d.stride()=(0, 1), i.e. step 0 elements along dim=0 and 1 element along dim=1
'''
d[0][0] = 5
print('c={} d={}'.format(c, d))
'''
c=tensor([[5, 1, 2]]) d=tensor([[5, 1, 2],
        [5, 1, 2]])
'''
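To double-check the claim that repeat() copies while expand() aliases, a small sketch continuing the example above:

b[0][0] = 100
print('a={}'.format(a))  # a is unaffected: repeat() made a real copy
'''
a=tensor([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]])
'''
print(d.data_ptr() == c.data_ptr())  # True: expand() shares c's storage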

repeat_interleave() places the repeated copies of each element next to each other, whereas repeat() tiles the tensor as a whole. That is why repeat_interleave() usually needs a dim argument, while a single repeat() call can repeat along several dimensions at once; see the side-by-side sketch after the examples below.

This is different from torch.Tensor.repeat() but similar to numpy.repeat.

>>> x = torch.tensor([1, 2, 3])
>>> x.repeat_interleave(2)
tensor([1, 1, 2, 2, 3, 3])
>>> y = torch.tensor([[1, 2], [3, 4]])
>>> torch.repeat_interleave(y, 2)
tensor([1, 1, 2, 2, 3, 3, 4, 4])
>>> torch.repeat_interleave(y, 3, dim=1)
tensor([[1, 1, 1, 2, 2, 2],
        [3, 3, 3, 4, 4, 4]])
# Repeat the first row once and the second row twice
>>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
tensor([[1, 2],
        [3, 4],
        [3, 4]])
>>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3)
tensor([[1, 2],
        [3, 4],
        [3, 4]])
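A side-by-side sketch of the two on the same input:

>>> x = torch.tensor([1, 2, 3])
>>> x.repeat(2)              # the whole tensor is tiled
tensor([1, 2, 3, 1, 2, 3])
>>> x.repeat_interleave(2)   # each element is repeated in place
tensor([1, 1, 2, 2, 3, 3])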

  1. https://stackoverflow.com/questions/48915810/what-does-contiguous-do-in-pytorch
  2. https://medium.com/analytics-vidhya/pytorch-contiguous-vs-non-contiguous-tensor-view-understanding-view-reshape-73e10cdfa0dd

Original post: https://blog.csdn.net/taoqick/article/details/137646079
