Reference link: view(*shape)
Description: in one sentence, view() re-lays out the dimensions of a contiguous tensor; no data is moved in memory, and only a view of the same underlying storage is returned.
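A minimal sketch of that claim (script-style rather than the interactive session below; the variable names are illustrative): a view starts at the same memory address as its base tensor, so no data is copied.

import torch

x = torch.randn(4, 4)                  # contiguous tensor
y = x.view(16)                         # same storage, only new shape/stride metadata

print(x.data_ptr() == y.data_ptr())    # True: the view starts at the same address
print(y.is_contiguous())               # True: it still follows the row-major layout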
(base) PS C:\Users\chenxuqi> python
Python 3.7.4 (default, Aug 9 2019, 18:34:13) [MSC v.1915 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> x = torch.randn(4, 4)
>>> x.shape
torch.Size([4, 4])
>>> x.size()
torch.Size([4, 4])
>>> x.stride()
(4, 1)
>>>
>>> y = x.view(16)
>>> y.size()
torch.Size([16])
>>> y.stride()
(1,)
>>>
>>>
>>> z = x.view(-1, 8) # the size -1 is inferred from other dimensions
>>> z.size()
torch.Size([2, 8])
>>> z.stride()
(8, 1)
>>>
>>>
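The strides printed above follow directly from row-major layout: for a contiguous tensor, the stride of each dimension is the product of all sizes to its right, and view(-1, 8) infers the first dimension as numel() // 8 == 2. A small helper (illustrative only, ignoring size-0 corner cases) reproduces the values shown in the transcript.

import torch

def row_major_strides(shape):
    # stride[i] = product of the sizes to the right of dimension i (in elements)
    strides, acc = [], 1
    for size in reversed(shape):
        strides.append(acc)
        acc *= size
    return tuple(reversed(strides))

x = torch.randn(4, 4)
print(row_major_strides((4, 4)), x.stride())              # (4, 1) (4, 1)
print(row_major_strides((16,)), x.view(16).stride())      # (1,) (1,)
print(row_major_strides((2, 8)), x.view(-1, 8).stride())  # (8, 1) (8, 1)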
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x0000019148EC2A90>
>>> a = torch.randn(1, 2, 3, 4)
>>> a.size()
torch.Size([1, 2, 3, 4])
>>> a.stride()
(24, 12, 4, 1)
>>> b = a.transpose(1, 2) # Swaps 2nd and 3rd dimension
>>> b.size()
torch.Size([1, 3, 2, 4])
>>> b.stride()
(24, 4, 12, 1)
>>>
>>>
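transpose() (like permute()) does not move data either; it only swaps the stride entries of the affected dimensions, which is exactly the (24, 12, 4, 1) to (24, 4, 12, 1) change shown above. A quick check (illustrative):

import torch

a = torch.randn(1, 2, 3, 4)
b = a.transpose(1, 2)                  # dims 1 and 2 swap both their sizes and their strides

print(a.data_ptr() == b.data_ptr())    # True: still the same underlying storage
print(b.stride())                      # (24, 4, 12, 1)
print(b.is_contiguous())               # False: strides no longer match row-major order for this shape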
>>> c = a.view(1, 3, 2, 4) # Does not change tensor layout in memory
>>> c.size()
torch.Size([1, 3, 2, 4])
>>>
>>> torch.equal(b, c)
False
>>>
>>>
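b and c have the same shape yet compare unequal because they are built differently: view() keeps the flat memory order and merely re-chunks it, while transpose() changes the logical order of the elements. A sketch of the difference (illustrative):

import torch

torch.manual_seed(20200910)
a = torch.randn(1, 2, 3, 4)
b = a.transpose(1, 2)        # reorders the elements logically
c = a.view(1, 3, 2, 4)       # keeps the flat order, just re-chunks it

print(torch.equal(b, c))                           # False: same shape, different element positions
print(torch.equal(c.reshape(-1), a.reshape(-1)))   # True: a view preserves the flat order
print(torch.equal(b.reshape(-1), a.reshape(-1)))   # False: a transpose does not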
>>> # view() cannot be used on a non-contiguous tensor unless the requested shape is compatible with its stride
...
>>> b.is_contiguous()
False
>>> b.view(1, 3, 2, 4)
tensor([[[[ 0.5816, 2.0060, 1.6013, -0.6379],
[ 1.1088, 1.2914, -1.4494, -1.7273]],
[[-1.1943, 0.1426, 1.3612, -1.4171],
[-0.2370, -1.7016, -0.2565, 1.4568]],
[[ 0.2797, -0.5316, 0.6480, 2.6538],
[ 0.7152, 1.1784, 0.0806, 0.7787]]]])
>>> b.view(1, 3, 4, 2)
Traceback (most recent call last):
File "" , line 1, in <module>
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
>>>
>>> b.contiguous().view(1, 3, 4, 2)
tensor([[[[ 0.5816, 2.0060],
[ 1.6013, -0.6379],
[ 1.1088, 1.2914],
[-1.4494, -1.7273]],
[[-1.1943, 0.1426],
[ 1.3612, -1.4171],
[-0.2370, -1.7016],
[-0.2565, 1.4568]],
[[ 0.2797, -0.5316],
[ 0.6480, 2.6538],
[ 0.7152, 1.1784],
[ 0.0806, 0.7787]]]])
>>>
>>>
>>>
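The RuntimeError above points to reshape() as the escape hatch: reshape() returns a view when the requested shape is compatible with the existing strides, and falls back to a copy (equivalent to contiguous().view(...)) when it is not. For example (illustrative):

import torch

torch.manual_seed(20200910)
a = torch.randn(1, 2, 3, 4)
b = a.transpose(1, 2)                  # non-contiguous

r = b.reshape(1, 3, 4, 2)              # works where b.view(1, 3, 4, 2) raised an error
print(torch.equal(r, b.contiguous().view(1, 3, 4, 2)))   # True: same values, obtained via a copy
print(r.data_ptr() == b.data_ptr())                      # False: the copy lives in new memory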
An experiment on views sharing memory:
(base) PS C:\Users\chenxuqi> python
Python 3.7.4 (default, Aug 9 2019, 18:34:13) [MSC v.1915 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001DE09103A90>
>>>
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.2824, -0.3715, 0.9088],
[-1.7601, -0.1806, 2.0937]])
>>> x.shape
torch.Size([2, 3])
>>> x.stride()
(3, 1)
>>> y = x.view(6)
>>> y.size()
torch.Size([6])
>>> y.stride()
(1,)
>>> y
tensor([ 0.2824, -0.3715, 0.9088, -1.7601, -0.1806, 2.0937])
>>> # the underlying storage is shared
...
>>> y[0] = 20200910
>>> x[0][1] = 8888
>>> x
tensor([[ 2.0201e+07, 8.8880e+03, 9.0878e-01],
[-1.7601e+00, -1.8060e-01, 2.0937e+00]])
>>> y
tensor([ 2.0201e+07, 8.8880e+03, 9.0878e-01, -1.7601e+00, -1.8060e-01,
2.0937e+00])
>>>
>>>
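Note that the sharing above only applies to in-place modifications (indexed assignment, add_(), and similar). An out-of-place expression such as y = y + 1 allocates a new tensor and rebinds the name, leaving the original untouched. A sketch with illustrative names:

import torch

x = torch.zeros(2, 3)
y = x.view(6)

y.add_(1)           # in-place: writes through the shared storage
print(x.sum())      # tensor(6.)

y = y + 1           # out-of-place: allocates a new tensor, x is unaffected
print(x.sum())      # tensor(6.)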
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001DE09103A90>
>>> x = torch.randn(4, 3)
>>> y = x.permute(1, 0)
>>> x.shape
torch.Size([4, 3])
>>> y.shape
torch.Size([3, 4])
>>> x.stride()
(3, 1)
>>> y.stride()
(1, 3)
>>>
>>>
>>> # y is not contiguous
... y.is_contiguous()
False
>>> x.is_contiguous()
True
>>>
>>> x.view(12)
tensor([ 0.2824, -0.3715, 0.9088, -1.7601, -0.1806, 2.0937, 1.0406, -1.7651,
1.1216, 0.8440, 0.1783, 0.6859])
>>> y.view(12)
Traceback (most recent call last):
File "" , line 1, in <module>
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
>>> y.contiguous().view(12)
tensor([ 0.2824, -1.7601, 1.0406, 0.8440, -0.3715, -0.1806, -1.7651, 0.1783,
0.9088, 2.0937, 1.1216, 0.6859])
>>>
>>> z = y.contiguous().view(12)
>>> x
tensor([[ 0.2824, -0.3715, 0.9088],
[-1.7601, -0.1806, 2.0937],
[ 1.0406, -1.7651, 1.1216],
[ 0.8440, 0.1783, 0.6859]])
>>> y
tensor([[ 0.2824, -1.7601, 1.0406, 0.8440],
[-0.3715, -0.1806, -1.7651, 0.1783],
[ 0.9088, 2.0937, 1.1216, 0.6859]])
>>> z
tensor([ 0.2824, -1.7601, 1.0406, 0.8440, -0.3715, -0.1806, -1.7651, 0.1783,
0.9088, 2.0937, 1.1216, 0.6859])
>>> x[0][0] = 888
>>> y[1][1] = 999
>>> z[11] = 777
>>> x
tensor([[ 8.8800e+02, -3.7148e-01, 9.0878e-01],
[-1.7601e+00, 9.9900e+02, 2.0937e+00],
[ 1.0406e+00, -1.7651e+00, 1.1216e+00],
[ 8.4397e-01, 1.7833e-01, 6.8588e-01]])
>>> y
tensor([[ 8.8800e+02, -1.7601e+00, 1.0406e+00, 8.4397e-01],
[-3.7148e-01, 9.9900e+02, -1.7651e+00, 1.7833e-01],
[ 9.0878e-01, 2.0937e+00, 1.1216e+00, 6.8588e-01]])
>>> z
tensor([ 2.8239e-01, -1.7601e+00, 1.0406e+00, 8.4397e-01, -3.7148e-01,
-1.8060e-01, -1.7651e+00, 1.7833e-01, 9.0878e-01, 2.0937e+00,
1.1216e+00, 7.7700e+02])
>>>
>>> # Conclusion: x and y share memory, but z does not share memory with either x or y
...
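That conclusion can also be verified directly from the memory addresses instead of by mutation: permute() returns a view, whereas contiguous() had to copy because y was non-contiguous (illustrative check):

import torch

x = torch.randn(4, 3)
y = x.permute(1, 0)              # a view: only the strides are reordered
z = y.contiguous().view(12)      # contiguous() copies, because y is non-contiguous

print(x.data_ptr() == y.data_ptr())   # True: x and y share storage
print(x.data_ptr() == z.data_ptr())   # False: z has its own freshly allocated storage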