Pytorch基础 - 6. torch.reshape() 和 torch.view()

目录

1. torch.reshape(shape) 和 torch.view(shape)函数用法

2. 当处理的tensor是连续性的(contiguous)

3. 当处理的tensor是非连续性的(contiguous)

4. PyTorch中的contiguous


在本文开始之前,需要了解最基础的Tensor存储方式,具体见 Tensor数据类型与存储结构

注:如果不想继续往下看,就无脑使用reshape()函数来进行tensor处理!!

1. torch.reshape(shape) 和 torch.view(shape)函数用法

torch.reshape() 和 torch.view()不会修改tensor内部的值,只是对tensor的形状进行变化,里面只包含了shape的参数,shape为当前tensor改变后的形状

示例:将x改变成shape为[2,3] 和 [3,2]的方式,reshape和view均可 

x = torch.tensor([1, 2, 3, 4, 5, 6])
y1 = x.reshape(2, 3)
y2 = x.view(3, 2)
print(y1.shape, y2.shape)   # torch.Size([2, 3]) torch.Size([3, 2])

2. 当处理的tensor是连续性的(contiguous)

当tensor是连续的,torch.reshape() 和 torch.view()这两个函数的处理过程也是相同的,即两者均不会开辟新的内存空间,也不会产生数据的副本,只是新建了一份tensor的头信息区,并在头信息区中指定重新指定某些信息,如stride , 并没有修改这个tensor的存储区 Storage

示例:尽管y1和y2的stride不同,但他们的storage()是相同的

x = torch.tensor([1, 2, 3, 4, 5, 6])
y1 = x.reshape(2, 3)
y2 = x.view(3, 2)
print(y1.stride(), y2.stride())   # (3, 1) (2, 1)
print(y1.storage().data_ptr() == y2.storage().data_ptr())   # True

3. 当处理的tensor是非连续性的(contiguous)

view():在调用view()函数之前需要先调用 contiguous()方法,即x.contiguous().view()。但这种方法变换后的tensor就不是与原始tensor共享内存了,而是重新开辟了一个空间

reshape():仅需要直接调用reshape()函数即可,返回结果等同于 contiguous().view()

示例:通过transpose将x变成了 uncontiguous,然后使用contiguous().view()和reshape()均会重新开辟一个新空间,不与原始tensor共享内存了。

x = torch.tensor([1, 2, 3, 4, 5, 6]).reshape(2, 3).transpose(0, 1)
print(x.is_contiguous())    # False
print(x.storage().data_ptr())   # 1690151067136
y1 = x.contiguous().view(2, 3)
print(y1.storage().data_ptr())  # 1690151068544
y2 = x.reshape(2, 3)
print(y2.storage().data_ptr())  # 1690151076736

4. PyTorch中的contiguous

上面说了这么多,什么情况下是连续的,什么时候不连续呢?很简单,在PyTorch中定义了:(不)连续:Tensor底层一维数组元素的存储顺序与Tensor按行优先一维展开的元素顺序是否一致

比如:可以通过storage()查看底层一维数组的存储顺序,通过is_contiguous()判断是否连续

x = torch.tensor([1, 2, 3, 4, 5, 6]).reshape(2, 3)
print(x.storage())  #  1 2 3 4 5 6
print(x.is_contiguous())    # True
xt = x.transpose(0, 1)
print(xt.storage()) # 转置后,底层存储仍为 1 2 3 4 5 6
print(xt.is_contiguous())    # False

x按行展开后与x.stoage()相同,而xt按行展开后与stoage()不同,所以x是连续的,xt是不连续的。那我们继续实验,将xt再次进行转置,变成 xtt,xtt与原始x相同,所以变为连续

x = torch.tensor([1, 2, 3, 4, 5, 6]).reshape(2, 3)
print(x.storage())  #  1 2 3 4 5 6
print(x.is_contiguous())   # True
xtt = x.transpose(0, 1).transpose(0, 1)
print(xtt.storage()) # 转置的转置后,底层存储为 1 2 3 4 5 6,按行展开仍为 1 2 3 4 5 6
print(xtt.is_contiguous())  # True

导致导致不连续的函数常见有:transpose(), permute(), narrow(), expand()等

注:推荐一篇讲contiguous非常好的文章:PyTorch中的contiguous - 知乎

你可能感兴趣的:(#,Pytorch操作,pytorch,深度学习,python,reshape,view)