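When using PyTorch or reading other people's code, `reshape` and `view` are two of the most frequently used functions. Both change the shape of a tensor, so what exactly is the difference between them?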
Original post: https://discuss.pytorch.org/t/difference-between-view-reshape-and-permute/54157
Original author: ptrblck
When possible, `reshape` tries to return a view; otherwise it copies the data into a contiguous tensor and returns a view of that copy.
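A minimal sketch of that behavior, comparing `data_ptr()` and checking whether writes propagate back to the original tensor. The shapes are arbitrary choices for illustration, and because the documentation says not to depend on view-vs-copy behavior, this is only a way to observe what happens, not something production code should rely on.

```python
import torch

x = torch.arange(6)

# Contiguous input: reshape can return a view that shares x's storage
r = x.reshape(2, 3)
print(r.data_ptr() == x.data_ptr())  # True
r[0, 0] = 100            # writing through the view...
print(x[0].item())       # ...is visible in x: 100

# Transposed (non-contiguous) input: flattening it requires a copy
t = x.view(2, 3).t()     # shape (3, 2), strides (1, 3)
flat = t.reshape(-1)
print(flat.data_ptr() == x.data_ptr())  # False
flat[0] = -1             # writing to the copy...
print(x[0].item())       # ...leaves x unchanged: 100
```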
That one sentence sums up the official documentation. Let's confirm it against the documentation for `reshape`:
Returns a tensor with the same data and number of elements as input, but with the specified shape. When possible, the returned tensor will be a view of input. Otherwise, it will be a copy. Contiguous inputs and inputs with compatible strides can be reshaped without copying, but you should not depend on the copying vs. viewing behavior.
In other words, `reshape` returns a view whenever the input's strides allow the new shape, and silently falls back to a copy otherwise; which of the two you get is not part of the API contract.
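To illustrate "inputs with compatible strides", here is a sketch with a non-contiguous tensor that can still be reshaped without copying. Permuting the first two dimensions breaks contiguity, but splitting the last dimension only needs a new stride for the new inner dimension, so the result can still be described over the existing storage. The shapes are arbitrary examples, and the `reshape` result is observed behavior rather than a documented guarantee.

```python
import torch

x = torch.arange(24).view(2, 3, 4)
y = x.permute(1, 0, 2)          # shape (3, 2, 4), strides (4, 12, 1)
print(y.is_contiguous())        # False

# Splitting the last dim (4 -> 2 x 2) can be expressed with strides on the
# existing storage, so view() succeeds even though y is not contiguous
v = y.view(3, 2, 2, 2)
print(v.stride())                    # (4, 12, 2, 1)
print(v.data_ptr() == x.data_ptr())  # True: no copy was made

# reshape() can take the same no-copy path here
r = y.reshape(3, 2, 2, 2)
print(r.data_ptr() == x.data_ptr())  # True
```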
See `torch.Tensor.view()` for the exact conditions under which a view can be returned.
>>> import torch
>>> x = torch.arange(4*10*2).view(4, 10, 2)
>>> x.shape
torch.Size([4, 10, 2])
>>> x.is_contiguous()
True
# view works fine on a contiguous tensor
>>> print(x.view(-1))
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
        72, 73, 74, 75, 76, 77, 78, 79])
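The `is_contiguous()` results in this session come down to strides. Roughly, a tensor is contiguous when its strides match the row-major (C-order) strides implied by its shape; `permute()` only reorders the existing strides, so the permuted tensor `y` no longer matches. A sketch using a hypothetical helper `row_major_strides` (not a PyTorch function), which ignores the special treatment PyTorch gives to size-1 dimensions:

```python
import torch

def row_major_strides(shape):
    """Hypothetical helper: strides of a freshly allocated contiguous tensor."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))

x = torch.arange(4 * 10 * 2).view(4, 10, 2)
print(x.stride(), row_major_strides(x.shape))  # (20, 2, 1) (20, 2, 1)

y = x.permute(2, 0, 1)  # strides are reordered, data is not moved
print(y.stride(), row_major_strides(y.shape))  # (1, 20, 2) (40, 10, 1)
print(y.contiguous().stride())                 # (40, 10, 1)
```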
# permute() returns a non-contiguous tensor: view() fails, reshape() works (like contiguous() + view)
>>> y = x.permute(2, 0, 1)
>>> y.shape
torch.Size([2, 4, 10])
>>> y.is_contiguous()
False
>>> try:
...     print(y.view(-1))
... except RuntimeError as e:
...     print(e)
view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
# i.e. the flattened shape cannot be expressed with y's current strides, so view() refuses and suggests reshape()
# reshape works on the non-contiguous tensor
>>> print(y.reshape(-1))
tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34,
        36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, 64, 66, 68, 70,
        72, 74, 76, 78,  1,  3,  5,  7,  9, 11, 13, 15, 17, 19, 21, 23, 25, 27,
        29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63,
        65, 67, 69, 71, 73, 75, 77, 79])
# here reshape gives the same result as calling contiguous() and then view()
>>> print(y.contiguous().view(-1))
tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34,
        36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, 64, 66, 68, 70,
        72, 74, 76, 78,  1,  3,  5,  7,  9, 11, 13, 15, 17, 19, 21, 23, 25, 27,
        29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63,
        65, 67, 69, 71, 73, 75, 77, 79])
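The equivalence between `reshape` and `contiguous()` + `view()` holds for the resulting values, but the memory behavior can differ. `contiguous()` returns the tensor itself when it is already contiguous and otherwise copies, whereas `reshape()` first checks whether the target shape can be expressed with the existing strides. A sketch reusing the tensors from this session; the target shape `(2, 4, 2, 5)` is an arbitrary stride-compatible example, and the view-or-copy outcomes are observed behavior rather than a documented guarantee:

```python
import torch

x = torch.arange(4 * 10 * 2).view(4, 10, 2)
y = x.permute(2, 0, 1)  # non-contiguous, shape (2, 4, 10), strides (1, 20, 2)

# Flattening y cannot be done with strides alone: both paths copy
print(y.reshape(-1).data_ptr() == x.data_ptr())            # False
print(y.contiguous().view(-1).data_ptr() == x.data_ptr())  # False

# Splitting the last dim (10 -> 2 x 5) is stride-compatible:
# reshape can return a view, but contiguous() has already made a copy
print(y.reshape(2, 4, 2, 5).data_ptr() == x.data_ptr())            # True
print(y.contiguous().view(2, 4, 2, 5).data_ptr() == x.data_ptr())  # False
```

So `reshape` is best thought of as "view if possible, copy if necessary", while `contiguous().view()` copies any non-contiguous input up front.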