torch.Tensor.view(*shape): usage examples

Reference: view(*shape)


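Summary: in one sentence, view() rearranges the dimensions of a contiguous tensor without moving any data in memory; it only returns a view that shares storage with the original tensor. Strictly speaking, the requirement is that the new shape be compatible with the tensor's existing size and stride, so some non-contiguous tensors can also be viewed, as the transcript below shows.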
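A minimal sketch of the summary above, assuming a recent PyTorch: view() keeps the total number of elements and returns a tensor whose first element sits at the same address (data_ptr()) as the original, so nothing is copied. The names t, v and mismatch_rejected are only illustrative.

```python
import torch

t = torch.arange(12)          # contiguous 1-D tensor: 0, 1, ..., 11
v = t.view(3, 4)              # same 12 elements, new shape

# No copy: both tensors start at the same memory address.
print(v.data_ptr() == t.data_ptr())   # True

# The element count must match: 12 elements cannot be viewed as 5 x 2.
mismatch_rejected = False
try:
    t.view(5, 2)
except RuntimeError as err:
    mismatch_rejected = True
    print("RuntimeError:", err)
```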

(base) PS C:\Users\chenxuqi> python
Python 3.7.4 (default, Aug  9 2019, 18:34:13) [MSC v.1915 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> x = torch.randn(4, 4)
>>> x.shape
torch.Size([4, 4])
>>> x.size()
torch.Size([4, 4])
>>> x.stride()
(4, 1)
>>>
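The stride tuple says how many storage elements to skip per step along each dimension, so element [i, j] of a 2-D tensor lives at storage offset i * stride[0] + j * stride[1]. The sketch below checks this for a tensor and its transpose, assuming storage offset 0; storage_index is a hypothetical helper written only for this illustration.

```python
import torch

def storage_index(index, stride):
    """Hypothetical helper: flat storage position of an element, given the strides."""
    return sum(i * s for i, s in zip(index, stride))

x = torch.randn(4, 4)            # contiguous, stride (4, 1)
flat = x.view(-1)                # the storage in memory order
xt = x.t()                       # transpose: same storage, stride (1, 4)

for i in range(4):
    for j in range(4):
        # Both tensors read the same storage, only with different strides.
        assert flat[storage_index((i, j), x.stride())] == x[i, j]
        assert flat[storage_index((i, j), xt.stride())] == xt[i, j]

print(x.stride(), xt.stride())   # (4, 1) (1, 4)
```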
>>> y = x.view(16)
>>> y.size()
torch.Size([16])
>>> y.stride()
(1,)
>>>
>>>
>>> z = x.view(-1, 8)  # the size -1 is inferred from other dimensions
>>> z.size()
torch.Size([2, 8])
>>> z.stride()
(8, 1)
>>>
>>>
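The -1 in x.view(-1, 8) is resolved by dividing the element count by the product of the explicitly given sizes. A small sketch of that arithmetic; infer_view_shape is a hypothetical helper that mimics the rule and is not a PyTorch API (PyTorch itself raises RuntimeError where this sketch raises ValueError).

```python
import torch

def infer_view_shape(numel, shape):
    """Hypothetical helper that mimics how view() resolves a single -1 entry."""
    shape = tuple(shape)
    if shape.count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    known = 1
    for size in shape:
        if size != -1:
            known *= size          # product of the explicitly given sizes
    if -1 in shape:
        if known == 0 or numel % known != 0:
            raise ValueError(f"shape {list(shape)} is invalid for input of size {numel}")
        # The -1 entry takes whatever is left over.
        return tuple(numel // known if size == -1 else size for size in shape)
    if known != numel:
        raise ValueError(f"shape {list(shape)} is invalid for input of size {numel}")
    return shape

x = torch.randn(4, 4)
print(infer_view_shape(x.numel(), (-1, 8)))   # (2, 8)
print(tuple(x.view(-1, 8).shape))              # (2, 8)
```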
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x0000019148EC2A90>
>>> a = torch.randn(1, 2, 3, 4)
>>> a.size()
torch.Size([1, 2, 3, 4])
>>> a.stride()
(24, 12, 4, 1)
>>> b = a.transpose(1, 2)  # Swaps 2nd and 3rd dimension
>>> b.size()
torch.Size([1, 3, 2, 4])
>>> b.stride()
(24, 4, 12, 1)
>>>
>>>
>>> c = a.view(1, 3, 2, 4)  # Does not change tensor layout in memory
>>> c.size()
torch.Size([1, 3, 2, 4])
>>>
>>> torch.equal(b, c)
False
>>>
>>>
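Why torch.equal(b, c) is False: view() keeps the elements in their original memory order and only regroups them, while transpose() changes which element appears at each index. Using torch.arange instead of random numbers makes the difference visible; flattening with reshape(-1) to compare orders is simply our choice for this illustration.

```python
import torch

a = torch.arange(24).view(1, 2, 3, 4)
b = a.transpose(1, 2)        # swaps dims 1 and 2: shape (1, 3, 2, 4)
c = a.view(1, 3, 2, 4)       # same shape as b, but elements are only regrouped

# Read in row-major order, c lists the elements exactly as a does ...
print(torch.equal(c.reshape(-1), a.reshape(-1)))   # True
# ... while b visits them in a different order.
print(torch.equal(b.reshape(-1), a.reshape(-1)))   # False
print(b.reshape(-1)[:8].tolist())                  # [0, 1, 2, 3, 12, 13, 14, 15]
```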
>>> # view() works on a non-contiguous tensor only if the new shape is compatible with its strides
...
>>> b.is_contiguous()
False
>>> b.view(1, 3, 2, 4)
tensor([[[[ 0.5816,  2.0060,  1.6013, -0.6379],
          [ 1.1088,  1.2914, -1.4494, -1.7273]],

         [[-1.1943,  0.1426,  1.3612, -1.4171],
          [-0.2370, -1.7016, -0.2565,  1.4568]],

         [[ 0.2797, -0.5316,  0.6480,  2.6538],
          [ 0.7152,  1.1784,  0.0806,  0.7787]]]])
>>> b.view(1, 3, 4, 2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
>>>
>>> b.contiguous().view(1, 3, 4, 2)
tensor([[[[ 0.5816,  2.0060],
          [ 1.6013, -0.6379],
          [ 1.1088,  1.2914],
          [-1.4494, -1.7273]],

         [[-1.1943,  0.1426],
          [ 1.3612, -1.4171],
          [-0.2370, -1.7016],
          [-0.2565,  1.4568]],

         [[ 0.2797, -0.5316],
          [ 0.6480,  2.6538],
          [ 0.7152,  1.1784],
          [ 0.0806,  0.7787]]]])
>>>
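The error message suggests .reshape(...). According to the PyTorch documentation, reshape() returns a view when the new shape is compatible with the strides and a copy otherwise, and you should not rely on which one you get. The sketch below probes this by writing through the result, assuming current PyTorch behaviour; arange data is used so the effect is easy to read.

```python
import torch

a = torch.arange(24).view(1, 2, 3, 4)
b = a.transpose(1, 2)                 # non-contiguous, stride (24, 4, 12, 1)

# b.view(1, 3, 4, 2) would raise; reshape falls back to a copy instead.
r = b.reshape(1, 3, 4, 2)
r[0, 0, 0, 0] = 100
print(a[0, 0, 0, 0].item())           # 0: writing to the copy leaves a untouched

# On the contiguous tensor a, reshape can return a view, so writes propagate.
v = a.reshape(24)
v[1] = -1
print(a[0, 0, 0, 1].item())           # -1
```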

Experiment: a view shares memory with the original tensor:

(base) PS C:\Users\chenxuqi> python
Python 3.7.4 (default, Aug  9 2019, 18:34:13) [MSC v.1915 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001DE09103A90>
>>>
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.2824, -0.3715,  0.9088],
        [-1.7601, -0.1806,  2.0937]])
>>> x.shape
torch.Size([2, 3])
>>> x.stride()
(3, 1)
>>> y = x.view(6)
>>> y.size()
torch.Size([6])
>>> y.stride()
(1,)
>>> y
tensor([ 0.2824, -0.3715,  0.9088, -1.7601, -0.1806,  2.0937])
>>> # The underlying memory is shared
...
>>> y[0] = 20200910
>>> x[0][1] = 8888
>>> x
tensor([[ 2.0201e+07,  8.8880e+03,  9.0878e-01],
        [-1.7601e+00, -1.8060e-01,  2.0937e+00]])
>>> y
tensor([ 2.0201e+07,  8.8880e+03,  9.0878e-01, -1.7601e+00, -1.8060e-01,
         2.0937e+00])
>>>
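Writing through one tensor and reading the other, as above, is one way to detect shared memory. A non-destructive alternative, sketched here under the assumption that both tensors have the same dtype, is to compare data_ptr(), which returns the address of a tensor's first element and therefore already includes any storage offset.

```python
import torch

x = torch.randn(2, 3)
y = x.view(6)
row = x[1]                     # a view starting at element 3 of the storage
w = x.clone()                  # a clone owns separate memory

print(y.data_ptr() == x.data_ptr())                               # True
print(row.storage_offset())                                       # 3
print(row.data_ptr() == x.data_ptr() + 3 * x.element_size())      # True
print(w.data_ptr() == x.data_ptr())                               # False
```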
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001DE09103A90>
>>> x = torch.randn(4, 3)
>>> y = x.permute(1, 0)
>>> x.shape
torch.Size([4, 3])
>>> y.shape
torch.Size([3, 4])
>>> x.stride()
(3, 1)
>>> y.stride()
(1, 3)
>>>
>>>
>>> # y is not contiguous
... y.is_contiguous()
False
>>> x.is_contiguous()
True
>>>
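"Contiguous" here means the strides match those of a freshly allocated row-major tensor of the same shape. The sketch compares is_contiguous() with that expectation; row_major_strides is a hypothetical helper, and the simple equality test assumes no dimension has size 1 (PyTorch ignores the strides of such dimensions when checking contiguity).

```python
import torch

def row_major_strides(shape):
    """Hypothetical helper: strides a freshly allocated (contiguous) tensor would have."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))

x = torch.randn(4, 3)
y = x.permute(1, 0)

print(row_major_strides(x.shape), x.stride(), x.is_contiguous())  # (3, 1) (3, 1) True
print(row_major_strides(y.shape), y.stride(), y.is_contiguous())  # (4, 1) (1, 3) False
```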
>>> x.view(12)
tensor([ 0.2824, -0.3715,  0.9088, -1.7601, -0.1806,  2.0937,  1.0406, -1.7651,
         1.1216,  0.8440,  0.1783,  0.6859])
>>> y.view(12)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
>>> y.contiguous().view(12)
tensor([ 0.2824, -1.7601,  1.0406,  0.8440, -0.3715, -0.1806, -1.7651,  0.1783,
         0.9088,  2.0937,  1.1216,  0.6859])
>>>
>>> z = y.contiguous().view(12)
>>> x
tensor([[ 0.2824, -0.3715,  0.9088],
        [-1.7601, -0.1806,  2.0937],
        [ 1.0406, -1.7651,  1.1216],
        [ 0.8440,  0.1783,  0.6859]])
>>> y
tensor([[ 0.2824, -1.7601,  1.0406,  0.8440],
        [-0.3715, -0.1806, -1.7651,  0.1783],
        [ 0.9088,  2.0937,  1.1216,  0.6859]])
>>> z
tensor([ 0.2824, -1.7601,  1.0406,  0.8440, -0.3715, -0.1806, -1.7651,  0.1783,
         0.9088,  2.0937,  1.1216,  0.6859])
>>> x[0][0] = 888
>>> y[1][1] = 999
>>> z[11] = 777
>>> x
tensor([[ 8.8800e+02, -3.7148e-01,  9.0878e-01],
        [-1.7601e+00,  9.9900e+02,  2.0937e+00],
        [ 1.0406e+00, -1.7651e+00,  1.1216e+00],
        [ 8.4397e-01,  1.7833e-01,  6.8588e-01]])
>>> y
tensor([[ 8.8800e+02, -1.7601e+00,  1.0406e+00,  8.4397e-01],
        [-3.7148e-01,  9.9900e+02, -1.7651e+00,  1.7833e-01],
        [ 9.0878e-01,  2.0937e+00,  1.1216e+00,  6.8588e-01]])
>>> z
tensor([ 2.8239e-01, -1.7601e+00,  1.0406e+00,  8.4397e-01, -3.7148e-01,
        -1.8060e-01, -1.7651e+00,  1.7833e-01,  9.0878e-01,  2.0937e+00,
         1.1216e+00,  7.7700e+02])
>>>
>>> # Conclusion: x and y share memory, but z shares memory with neither x nor y, because y.contiguous() made a copy
...
>>>
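The conclusion follows from how contiguous() behaves: per the PyTorch documentation it returns the tensor itself when it is already contiguous, and otherwise a new contiguous copy. This sketch checks both cases with data_ptr(); the variable names are only illustrative.

```python
import torch

x = torch.randn(4, 3)
y = x.permute(1, 0)            # view of x, not contiguous

x_c = x.contiguous()           # already contiguous: no copy is made
y_c = y.contiguous()           # not contiguous: data is copied to new memory
z = y_c.view(12)               # a view of the copy, not of x

print(x_c.data_ptr() == x.data_ptr())   # True
print(y_c.data_ptr() == x.data_ptr())   # False
print(z.data_ptr() == y_c.data_ptr())   # True
```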
