在炼丹日常,很多模型为了保持tensor的计算便利和维度统一(例如resnet)会涉及到很多的维度转换,这时候很容易就绕晕,在参加飞浆的transformer课程中,学习了很多灵活变化维度的方法,这些方法能够让我们减少相当一部分的维度转换的复杂度。下面逐个介绍常用的维度转换方法:
简单的reshape方法我们就不进行介绍,主要介绍reshape中-1的灵活运用,我们可以轻松的避免一些计算。
import paddle
import paddle.nn as nn
a = paddle.randn([4, 12, 12, 3])
b = a.reshape([4, -1, 3])
print(a.shape, b.shape)
输出
[4, 12, 12, 3]
[4, 144, 3]
可以看到,在vision transformer中,我们一开始将图像进行embedding,首先需要将图片(n, h, w, c)变换输入到一维的(n, h * w , c)格式,如果我们使用reshape([4, 144, 3]),也可以将图片进行变换,但是当涉及到复杂的维度转换,我们很有可能出错,但是通过使用这个,当我们只需要第一维和最后一维确定时,那么我们只需要简单地将中间的设置为-1。
a = paddle.randn([4, 12, 12, 3])
b = a.flatten(1)
c = a.flatten(2)
d = a.flatten(3)
e = a.flatten()
print(a.shape, b.shape, c.shape, d.shape, e.shape)
输出
[4, 12, 12, 3]
[4, 432]
[4, 12, 36]
[4, 12, 12, 3]
[1728]
利用flatten方法,我们可以轻松的达到将后几个维度进行合并的效果。
a = paddle.randn([4, 12, 12, 3])
b = a.transpose([0, 1, 2, 3])
c = a.transpose([0, 1, 3, 2])
#d 转换的 c
d = c.transpose([0, 1, 3, 2])
print(a.shape, b.shape, c.shape, d.shape)
输出
[4, 12, 12, 3]
[4, 12, 12, 3]
[4, 12, 3, 12]
[4, 12, 12, 3]#回归原维度顺序
可以看到,transpose可以方便的进行多个维度之间替换,对替换后的tensor进行再次相同的替换(例如d,c之间),可以回归到原维度顺序。
x = torch.linspace(-math.pi, math.pi, 2000)
x1 = x
x2 = x.unsqueeze(-1)
x3 = x.squeeze(-1)
print(x1.shape, x2.shape, x3.shape)
输出
torch.Size([2000])
torch.Size([2000, 1])
torch.Size([2000])
这里的squeeze与unsqueeze从字面意义上是一种相反的操作,究其本质,即对某个维度进行升维,如unsqueeze(-1),就是从最后一维增加一维,变成二维,不过这里只对维度为1的进行维度删减。