调整Tensor的shape(通过返回一个新的Tensor),在老版本中这个函数是view()
,功能上都是一样的。
import torch
a = torch.rand(4, 1, 28, 28)
print(a.shape)
print(a.reshape(4 * 1, 28, 28).shape)
print(a.reshape(4, 1 * 28 * 28).shape)
运行结果:
torch.Size([4, 1, 28, 28])
torch.Size([4, 28, 28])
torch.Size([4, 784])
正的索引是在那个维度原本的位置前面插入这个新增加的维度,负的索引是在那个位置之后插入。
print(a.shape)
print(a.unsqueeze(0).shape) # 在0号维度位置插入一个维度
print(a.unsqueeze(-1).shape) # 在最后插入一个维度
print(a.unsqueeze(3).shape) # 在3号维度位置插入一个维度
运行结果:
torch.Size([4, 1, 28, 28])
torch.Size([1, 4, 1, 28, 28])
torch.Size([4, 1, 28, 28, 1])
torch.Size([4, 1, 28, 1, 28])
删减维度实际上是一个压榨的过程,直观地看是把那些多余的[]
给去掉,也就是只是去删除那些size=1的维度。
import torch
a = torch.Tensor(1, 4, 1, 9)
print(a.shape)
print(a.squeeze().shape) # 能删除的都删除掉
print(a.squeeze(0).shape) # 尝试删除0号维度,ok
print(a.squeeze(2).shape) # 尝试删除2号维度,ok
print(a.squeeze(3).shape) # 尝试删除3号维度,3号维度是9不是1,删除失败
运行结果:
torch.Size([1, 4, 1, 9])
torch.Size([4, 9])
torch.Size([4, 1, 9])
torch.Size([1, 4, 9])
torch.Size([1, 4, 1, 9])
expand就是在某个size=1的维度上改变size,改成更大的一个大小,实际就是在每个size=1的维度上的标量的广播操作。
import torch
b = torch.rand(32)
f = torch.rand(4, 32, 14, 14)
# 想要把b加到f上面去
# 先进行维度增加
b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print(b.shape)
# 再进行维度扩展
b = b.expand(4, -1, 14, 14) # -1表示这个维度保持不变,这里写32也可以
print(b.shape)
运行结果:
torch.Size([1, 32, 1, 1])
torch.Size([4, 32, 14, 14])
repeat就是将每个位置的维度都重复至指定的次数,以形成新的Tensor。repeat会重新申请内存空间。
# 维度增加...
print(b.shape)
# 维度重复,32这里不想进行重复,所以就相当于"重复至1次"
b = b.repeat(4, 1, 14, 14)
print(b.shape)
运行结果:
torch.Size([1, 32, 1, 1])
torch.Size([4, 32, 14, 14])
只适用于dim=2的Tensor。
c = torch.Tensor(2, 4)
print(c.t().shape)
运行结果:
torch.Size([4, 2])
注意这种交换使得存储不再连续,再执行一些reshape的操作肯定是执行不了的,所以要调用一下contiguous()
使其变成连续的维度。
d = torch.Tensor(6, 3, 1, 2)
print(d.transpose(1, 3).contiguous().shape) # 1号维度和3号维度交换
运行结果:
torch.Size([6, 2, 1, 3])
下面这个例子比较一下每个位置上的元素都是一致的,来验证一下这个交换->压缩shape->展开shape->交换回去是没有问题的。
e = torch.rand(4, 3, 6, 7)
e2 = e.transpose(1, 3).contiguous().reshape(4, 7 * 6 * 3).reshape(4, 7, 6, 3).transpose(1, 3)
print(e2.shape)
# 比较下两个Tensor所有位置上的元素是否都相等
print(torch.all(torch.eq(e, e2)))
运行结果:
torch.Size([4, 3, 6, 7])
tensor(1, dtype=torch.uint8)
有趣的是,这个例子里的rand改成Tensor结果就是0(表示FALSE)了,因为Tensor根本就是是未初始化的数据,有些可能根本没法比较数值。
如果四个维度表示上节的 [ b a t c h , c h a n n e l , h , w ] [batch,channel,h,w] [batch,channel,h,w],如果想把 c h a n n e l channel channel放到最后去,形成 [ b a t c h , h , w , c h a n n e l ] [batch,h,w,channel] [batch,h,w,channel],那么如果使用前面的维度交换,至少要交换两次(先13交换再12交换)。而使用permute可以直接指定维度新的所处位置,方便很多。
h = torch.rand(4, 3, 6, 7)
print(h.permute(0, 2, 3, 1).shape)
运行结果:
torch.Size([4, 6, 7, 3])