pytorch中的张量纬度变化

pytorch中的张量纬度变化

1. 张量元素的顺序

pytorch中的张量纬度变化_第1张图片
我们一般常用的是4维张量,即 ( N ∗ H ∗ W ∗ C ) (N*H*W*C) (NHWC)。张量的排列是从最后一维开始排,然后依次排到第一维。例如对于一个shape为 ( 2 ∗ 3 ∗ 4 ) (2*3*4) (234)的张量,元素为0~23,首先对于将24个元素每4个打包,有6个小包;然后将6个包每3个打包,有2个中包得到最终的张量。

x = torch.arange(24)
print(x,'\n')
# tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
#         18, 19, 20, 21, 22, 23]) 

y = x.reshape(2,3,4)
print(y,'\n')
# tensor([[[ 0,  1,  2,  3],
#          [ 4,  5,  6,  7],
#          [ 8,  9, 10, 11]],

#         [[12, 13, 14, 15],
#          [16, 17, 18, 19],
#          [20, 21, 22, 23]]])

view与其他纬度变更函数

  • view与reshape:都是对shape进行改变。但是view的操作对象地址必须是连续的,而shape没有限制。即如果对tensor使用了permute,transpose之类的函数会使得变量的地址不连续,这种情况下如果要使用view,则需要先对变换后的tensor连续化,即contiguous()。而reshape则不需要。
import torch
x = torch.rand(2,3,4)
y = x.permute(1,0,2).contiguous().view(-1,4)
z = x.permute(1,0,2).reshape(-1,4)
y.equal(z)
True
  • view和reshape不会改变数据的原始排列
x = torch.arange(24)
print(x,'\n')
# tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
#         18, 19, 20, 21, 22, 23]) 

y = x.reshape(2,3,4)
print(y,'\n')
# tensor([[[ 0,  1,  2,  3],
#          [ 4,  5,  6,  7],
#          [ 8,  9, 10, 11]],

#         [[12, 13, 14, 15],
#          [16, 17, 18, 19],
#          [20, 21, 22, 23]]])

z = y.reshape(24)
print(z,'\n')
# tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
#         18, 19, 20, 21, 22, 23]) 
  • 交换张量的纬度函数,例如transpose、permute、flatten等函数会保持纬度上的元素,而改变元素的原始排列。
#y = x.reshape(2,3,4)
print(y,'\n')
# tensor([[[ 0,  1,  2,  3],
#          [ 4,  5,  6,  7],
#          [ 8,  9, 10, 11]],

#         [[12, 13, 14, 15],
#          [16, 17, 18, 19],
#          [20, 21, 22, 23]]]) 

a = y.transpose(1,2)
print(a,'\n')
# tensor([[[ 0,  4,  8],
#          [ 1,  5,  9],
#          [ 2,  6, 10],
#          [ 3,  7, 11]],

#         [[12, 16, 20],
#          [13, 17, 21],
#          [14, 18, 22],
#          [15, 19, 23]]]) 

b = y.transpose(0,2)
print(b,'\n')
# tensor([[[ 0, 12],
#          [ 4, 16],
#          [ 8, 20]],

#         [[ 1, 13],
#          [ 5, 17],
#          [ 9, 21]],

#         [[ 2, 14],
#          [ 6, 18],
#          [10, 22]],

#         [[ 3, 15],
#          [ 7, 19],
#          [11, 23]]]) 


c  = t.transpose(0,2)
print(c,'\n')
# tensor([[[ 0,  1,  2,  3],
#          [ 4,  5,  6,  7],
#          [ 8,  9, 10, 11]],

#         [[12, 13, 14, 15],
#          [16, 17, 18, 19],
#          [20, 21, 22, 23]]]) 
  • 在这个例子中,y使用的是reshape,它与x的顺序相同;a是对第二维和第三维进行了交换,这里的交换类似于矩阵中的转置操作,观察y可以发现,对于纬度长度为4的数据为[0,1,2,3]…;纬度长度为3的数据为[0,4,8]…;纬度长度为2的数据为[0,12]…。对于交换纬度顺序这种操作,纬度上的数据是不变的,但是数据的内部顺序是变化了的。如a,它在最后一维上的数字已经不再连续了。

  • view和reshape是保持数据的内部顺序不变,但是纬度上对应的数据可能会改变;而transpose、permute等则是保持纬度上对应的数据不变,但是内部的顺序可能会改变。因此仅仅使用view或reshape可以变换到初始的样子;仅仅使用transpose和permute这类函数也可以变换到初始的样子;但是如果两种函数有混合,可能无法恢复原数据的顺序。

你可能感兴趣的:(pytorch,pytorch,深度学习,python)