大体形式:
x[:, :, :, :]
这个操作是最基本,也是初学时最难理解的一个操作。不管是在np.array数组中,还是在torch.tensor中,都可以用这种通用方式去切片出我们需要的矩阵。
简单切片操作:
x[idx_start:idx_end:stride] #x[起点:终点:步长]
带逗号的切片操作:
x[idx_start:idx_end, idx_start:idx_end:stride]
逗号的作用是区分维度(记住这个,基本就理解这类语法了),如果步长取-1,则代表从后往前取,但是要注意一点,逗号前面的不能限定步长。另外,补充一个常识:遇到这种[m:n]语法时,牢记左闭右开,即左侧m能取到,右侧n取不到,只能取到n-1(联想range(0, n)这种语法同理,取值范围是0~n-1)。
示例:
取第三个维度的5-7维,取第四个维度的0, 1维
x = torch.zeros(8, 1, 16, 3)
y = x[:, :, 5:8, :-1]
print(y.shape)
torch.Size([8, 1, 3, 2])
小心连续切片的问题!
还是上面举的例子,取第三个维度的5-7维,取第四个维度的0, 1维,如果按矩阵取值的方式(用多个中括号去分别定位每个维度)去写会是下面这样:
x = torch.zeros(8, 1, 16, 3)
y = x[:][:][5:8][:-1]
print(y.shape)
torch.Size([2, 1, 16, 3])
这个结果显然是错误的!因为把第一个维度中的8给切成了2,别的都没变化。
所以这里理解连续切片的概念,我们debug分析上面的代码:
由此可见,这种连续切片的方式(理解为连续多个中括号) 并没有分别去改变每个维度,一定是 在上一步切片的结果上,进行一次新的切片,和x[3][1][2][4]矩阵取值的思路完全不一样!
如果要实现对每个维度的分别切片,还得用上面那个例子中的写法:
x[:, :, 5:8, :-1]
上面我只提及了常见的用法和坑点,详细的教程和例子可以参考:
大体形式:
x.squeeze(dim=n)
x.unsqueeze(dim=n)
很多时候,我们都需要将矩阵展开维度或压缩维度后进行矩阵的运算。squeeze()函数为压缩操作,将某一个维度值为1的维度进行删减,或将多个维度值为1的维度进行删减;unsqueeze()函数为展开操作,将某个维度补上维度值为1的维度。
示例:
在256和32对应维度之间补上一个维度1
x = torch.zeros(8, 256, 32, 64)
y = x.unsqueeze(dim=2)
print(y.shape)
torch.Size([8, 256, 1, 32, 64])
删掉前面那个维度1
x = torch.zeros(8, 1, 256, 1)
y = x.squeeze(dim=1)
print(y.shape)
torch.Size([8, 256, 1])
同时删掉多个(所有)维度1
x = torch.zeros(8, 1, 256, 1)
y = x.squeeze() # 留空就是删掉所有维度1
print(y.shape)
torch.Size([8, 256])
大体形式:
x.transpose(m, n)
transpose()代表转置,即线性代数中将两个位置进行交换,在高维矩阵中同理。
示例:
将256与32对应维度进行交换
x = torch.zeros(8, 256, 32, 64)
y = x.transpose(1, 2)
print(y.shape)
torch.Size([8, 32, 256, 64])
大体形式:
x.permute(c, b, a, d)
permute()函数是transpose()的更一般形式,因为它可以同时处理多个位置的顺序变换。这个操作非常好理解,假如原矩阵是(8, 256, 32, 64)的维度,那么0位置对应8,1位置对应256,2位置对应32,3位置对应32,如果我们想让矩阵变为(256, 8, 32, 64),那就需要交换位置0和1,于是语法如下:
x.permute(1, 0, 2, 3)
示例:
8对应维度不动,256对应维度移动到末尾,32对应维度移动到第二个位置,64对应维度移动到倒数第二个位置
x = torch.zeros(8, 256, 32, 64)
y = x.permute(0, 2, 3, 1)
print(y.shape)
torch.Size([8, 32, 64, 256])
不改变各个维度的位置
x = torch.zeros(8, 256, 32, 64)
y = x.permute(0, 1, 2, 3) # 注意:这个就是表示原位置
print(y.shape)
torch.Size([8, 256, 32, 64])
大体形式:
x.flattten(m, n)
在神经网络搭建中,时常会在全连接之前将矩阵的某两个维度合并在一起(如H x W),这个操作叫做展平(flatten)。 注意,flatten()函数中的两个位置索引要求m
示例:
展平最后两维
x = torch.zeros(8, 1, 16, 3)
y = x.flatten(-2, -1)
print(y.shape)
torch.Size([8, 1, 48])