目录
1. torch.reshape(shape) 和 torch.view(shape)函数用法
2. 当处理的tensor是连续性的(contiguous)
3. 当处理的tensor是非连续性的(contiguous)
4. PyTorch中的contiguous
在本文开始之前,需要了解最基础的Tensor存储方式,具体见 Tensor数据类型与存储结构
注:如果不想继续往下看,就无脑使用reshape()函数来进行tensor处理!!
torch.reshape() 和 torch.view()不会修改tensor内部的值,只是对tensor的形状进行变化,里面只包含了shape的参数,shape为当前tensor改变后的形状
示例:将x改变成shape为[2,3] 和 [3,2]的方式,reshape和view均可
x = torch.tensor([1, 2, 3, 4, 5, 6])
y1 = x.reshape(2, 3)
y2 = x.view(3, 2)
print(y1.shape, y2.shape) # torch.Size([2, 3]) torch.Size([3, 2])
当tensor是连续的,torch.reshape() 和 torch.view()这两个函数的处理过程也是相同的,即两者均不会开辟新的内存空间,也不会产生数据的副本,只是新建了一份tensor的头信息区,并在头信息区中指定重新指定某些信息,如stride , 并没有修改这个tensor的存储区 Storage。
示例:尽管y1和y2的stride不同,但他们的storage()是相同的
x = torch.tensor([1, 2, 3, 4, 5, 6])
y1 = x.reshape(2, 3)
y2 = x.view(3, 2)
print(y1.stride(), y2.stride()) # (3, 1) (2, 1)
print(y1.storage().data_ptr() == y2.storage().data_ptr()) # True
view():在调用view()函数之前需要先调用 contiguous()方法,即x.contiguous().view()。但这种方法变换后的tensor就不是与原始tensor共享内存了,而是重新开辟了一个空间。
reshape():仅需要直接调用reshape()函数即可,返回结果等同于 contiguous().view()。
示例:通过transpose将x变成了 uncontiguous,然后使用contiguous().view()和reshape()均会重新开辟一个新空间,不与原始tensor共享内存了。
x = torch.tensor([1, 2, 3, 4, 5, 6]).reshape(2, 3).transpose(0, 1)
print(x.is_contiguous()) # False
print(x.storage().data_ptr()) # 1690151067136
y1 = x.contiguous().view(2, 3)
print(y1.storage().data_ptr()) # 1690151068544
y2 = x.reshape(2, 3)
print(y2.storage().data_ptr()) # 1690151076736
上面说了这么多,什么情况下是连续的,什么时候不连续呢?很简单,在PyTorch中定义了:(不)连续:Tensor底层一维数组元素的存储顺序与Tensor按行优先一维展开的元素顺序是否一致
比如:可以通过storage()查看底层一维数组的存储顺序,通过is_contiguous()判断是否连续
x = torch.tensor([1, 2, 3, 4, 5, 6]).reshape(2, 3)
print(x.storage()) # 1 2 3 4 5 6
print(x.is_contiguous()) # True
xt = x.transpose(0, 1)
print(xt.storage()) # 转置后,底层存储仍为 1 2 3 4 5 6
print(xt.is_contiguous()) # False
x按行展开后与x.stoage()相同,而xt按行展开后与stoage()不同,所以x是连续的,xt是不连续的。那我们继续实验,将xt再次进行转置,变成 xtt,xtt与原始x相同,所以变为连续
x = torch.tensor([1, 2, 3, 4, 5, 6]).reshape(2, 3)
print(x.storage()) # 1 2 3 4 5 6
print(x.is_contiguous()) # True
xtt = x.transpose(0, 1).transpose(0, 1)
print(xtt.storage()) # 转置的转置后,底层存储为 1 2 3 4 5 6,按行展开仍为 1 2 3 4 5 6
print(xtt.is_contiguous()) # True
导致导致不连续的函数常见有:transpose(), permute(), narrow(), expand()等
注:推荐一篇讲contiguous非常好的文章:PyTorch中的contiguous - 知乎