>>> a = torch.randn(1,1,24,24,24)
>>> b = a.unfold(2,8,8).unfold(3,8,8).unfold(4,8,8)
>>> b = b.view(1,27,512)
Traceback (most recent call last):
File "" , line 1, in <module>
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
在对数据进行维度转换时使用了view()
方法,但却报了如下错误
RuntimeError: view size is not compatible with input tensor’s size
and stride (at least one dimension spans across two contiguous
subspaces). Use .reshape(…) instead.
按照错误提示,将view()
换为reshape()
之后就正常了。
查阅Pytorch的官方文档[1.7.1]后总结如下:
view()
方法返回一个具有新的shape
的tensor
,返回的tensor
与原始的tensor
在Python中是不同的tensor
对象,但它们共享同一块内存上的值。验证方法如下所示:
>>> a = torch.randn(2,3)
>>> b = a.view(3,2)
>>>> id(a)
1583361270552
>>> id(b)
1583361345576
>>> a.data_ptr()
1584936860096
>>> b.data_ptr()
1584936860096
id()
返回对象的内存地址(Python中万物皆对象),data_ptr()
如字面意义上的,是tensor
的数据指针。从结果可以看到a
和b
指向同一块内存,因此view()
方法虽然返回了一个新的对象,但这个对象的数据指针指向的内存并没有发生改变。如果改变a
中的值,则b
也会发生改变。
>>> a
tensor([[-2.0974, 0.7608, -0.8227],
[-1.2951, 2.3984, 0.3418]])
>>> b
tensor([[-2.0974, 0.7608],
[-0.8227, -1.2951],
[ 2.3984, 0.3418]])
>>> a[1,2] = 5
>>> a
tensor([[-2.0974, 0.7608, -0.8227],
[-1.2951, 2.3984, 5.0000]])
>>> b
tensor([[-2.0974, 0.7608],
[-0.8227, -1.2951],
[ 2.3984, 5.0000]])
这时我们再回到一开始的问题,去验证tensor
经过unfold()
操作之后,数据指针是否发生了变化
>>> a = torch.randn(1,1,24,24,24)
>>> b = a.unfold(2,8,8).unfold(3,8,8).unfold(4,8,8)
>>> a.data_ptr()
1584939567424
>>> b.data_ptr()
1584939567424
指向的还是同一块内存,但是注意,data_ptr()
它仅仅返回第一个数据的内存地址,之后的内存是什么样的我们并不知道。于是再检查一下内存的连续性
>>> a.is_contiguous()
True
>>> b.is_contiguous()
False
问题找到了,经过unfold()
操作后,尽管a
和b
的头指针指向的还是同一块内存,但是b
中存储的数据的内存已经不连续了。我们可以通过stride()
方法看一下tensor
对象在内存中读取数据的步长
>>> b.stride()
(13824, 13824, 4608, 192, 8, 576, 24, 1)
>>> c = torch.randn(b.shape)
>>> c.stride()
(13824, 13824, 4608, 1536, 512, 64, 8, 1)
可以看到内存不连续的b
与连续内存存储的c
读取数据的步长是有区别的,这也就是报错中提到的view size is not compatible with input tensor's size and stride
尺寸与步长不匹配,故无法使用view()
方法。
现在再来看reshape()
的文档就很容易理解了。
Returns a tensor with the same data and number of elements as self
but with the specified shape. This method returns a view if shape is
compatible with the current shape. See torch.Tensor.view() on when it
is possible to return a view.
reshape()
会首先确认tensor
的目标尺寸与内存读取的步长是否匹配,如果匹配,则与view()
相同,返回一个指向同一块内存的tensor
对象;如果不匹配,则对tensor
的数据做一个拷贝,返回指向新的内存的tensor
对象。
>>> b.data_ptr()
1584939567424
>>> c = b.reshape(1,27,512)
>>> c.data_ptr() # b的内存不连续,故c会指向新的内存
1584942848064
>>> d = torch.randn(b.shape)
>>> e = d.reshape(1,27,512)
>>> d.data_ptr()
1584937934656
>>> e.data_ptr() # d是连续内存存储,故e与d指向同一块内存
1584937934656
view()
:不改变内存,只改变tensor
的维度信息,要求输入是contiguous
的。reshape()
:对输入进行判断,如果内存连续,则只改变维度信息,否则进行数据拷贝,返回指向新的内存的tensor