详解Pytorch中view()和reshape()的区别

问题描述:


>>> a = torch.randn(1,1,24,24,24)
>>> b = a.unfold(2,8,8).unfold(3,8,8).unfold(4,8,8)
>>> b = b.view(1,27,512)
Traceback (most recent call last):
  File "", line 1, in <module>
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

在对数据进行维度转换时使用了view()方法,但却报了如下错误

RuntimeError: view size is not compatible with input tensor’s size
and stride (at least one dimension spans across two contiguous
subspaces). Use .reshape(…) instead.

按照错误提示,将view()换为reshape()之后就正常了。


原因分析:

查阅Pytorch的官方文档[1.7.1]后总结如下:

view()

view()方法返回一个具有新的shapetensor,返回的tensor与原始的tensor在Python中是不同的tensor对象,但它们共享同一块内存上的值。验证方法如下所示:

>>> a = torch.randn(2,3)
>>> b = a.view(3,2)
>>>> id(a)
1583361270552
>>> id(b)
1583361345576
>>> a.data_ptr()
1584936860096
>>> b.data_ptr()
1584936860096
  • id()返回对象的内存地址(Python中万物皆对象),data_ptr()如字面意义上的,是tensor的数据指针。

从结果可以看到ab指向同一块内存,因此view()方法虽然返回了一个新的对象,但这个对象的数据指针指向的内存并没有发生改变。如果改变a中的值,则b也会发生改变。

>>> a
tensor([[-2.0974,  0.7608, -0.8227],
        [-1.2951,  2.3984,  0.3418]])
>>> b
tensor([[-2.0974,  0.7608],
        [-0.8227, -1.2951],
        [ 2.3984,  0.3418]])
>>> a[1,2] = 5
>>> a
tensor([[-2.0974,  0.7608, -0.8227],
        [-1.2951,  2.3984,  5.0000]])
>>> b
tensor([[-2.0974,  0.7608],
        [-0.8227, -1.2951],
        [ 2.3984,  5.0000]])

这时我们再回到一开始的问题,去验证tensor经过unfold()操作之后,数据指针是否发生了变化


>>> a = torch.randn(1,1,24,24,24)
>>> b = a.unfold(2,8,8).unfold(3,8,8).unfold(4,8,8)
>>> a.data_ptr()
1584939567424
>>> b.data_ptr()
1584939567424

指向的还是同一块内存,但是注意,data_ptr()它仅仅返回第一个数据的内存地址,之后的内存是什么样的我们并不知道。于是再检查一下内存的连续性

>>> a.is_contiguous()
True
>>> b.is_contiguous()
False

问题找到了,经过unfold()操作后,尽管ab头指针指向的还是同一块内存,但是b中存储的数据的内存已经不连续了。我们可以通过stride()方法看一下tensor对象在内存中读取数据的步长

>>> b.stride()
(13824, 13824, 4608, 192, 8, 576, 24, 1)
>>> c = torch.randn(b.shape)
>>> c.stride()
(13824, 13824, 4608, 1536, 512, 64, 8, 1)

可以看到内存不连续的b与连续内存存储的c读取数据的步长是有区别的,这也就是报错中提到的view size is not compatible with input tensor's size and stride尺寸与步长不匹配,故无法使用view()方法。

reshape()

现在再来看reshape()的文档就很容易理解了。

Returns a tensor with the same data and number of elements as self
but with the specified shape. This method returns a view if shape is
compatible with the current shape. See torch.Tensor.view() on when it
is possible to return a view.

reshape()会首先确认tensor的目标尺寸与内存读取的步长是否匹配,如果匹配,则与view()相同,返回一个指向同一块内存的tensor对象;如果不匹配,则对tensor的数据做一个拷贝,返回指向新的内存的tensor对象。

>>> b.data_ptr()
1584939567424
>>> c = b.reshape(1,27,512)
>>> c.data_ptr() # b的内存不连续,故c会指向新的内存
1584942848064
>>> d = torch.randn(b.shape)
>>> e = d.reshape(1,27,512)
>>> d.data_ptr()
1584937934656
>>> e.data_ptr() # d是连续内存存储,故e与d指向同一块内存
1584937934656

总结:

  • view():不改变内存,只改变tensor的维度信息,要求输入是contiguous的。
  • reshape():对输入进行判断,如果内存连续,则只改变维度信息,否则进行数据拷贝,返回指向新的内存的tensor

你可能感兴趣的:(深度学习,python,pytorch,pytorch,python)