PyTorch 维度变换

view,reshape。

a = torch.rand(4,1,28,28)
print(a.shape)

b = a.view(4,28*28)
print(b)
print(b.shape)

b = a.view(4*28,28)
print(b.shape)

b = a.view(4*1,28,28)
print(b.shape)

b = a.view(4,784)#不建议使用,会破坏维度信息,恢复时会存在问题
print(b.shape)

PyTorch 维度变换_第1张图片

squeeze,unsqueeze

a = torch.rand(4,1,28,28)
print(a.shape)

print(a.unsqueeze(0).shape)#在最前面插入一个维度
print(a.unsqueeze(-1).shape)#在最后面插入一个维度

b = torch.rand(1,32,1,1)
print(b.squeeze().shape)#不给定参数会将所有可以删减的维度全部删掉
print(b.squeeze(0).shape)#去掉指定维度
print(b.squeeze(-1).shape)
print(b.squeeze(-4).shape)

PyTorch 维度变换_第2张图片

expand,repeat,都是扩展,expand不增加数据,repeat会增加数据,相当于复制拷贝数据,推荐第一种,速度快节约内存,真实需要时才会复制数据

a = torch.rand(4,32,14,14)
print(a.shape)

b = torch.rand(1,32,1,1)
print(b.shape)
print(b.expand(4,32,14,14).shape)
print(b.expand(-1,32,-1,-1).shape)#-1表示不变

 PyTorch 维度变换_第3张图片

b = torch.rand(1,32,1,1)
print(b.shape)

print(b.repeat(4,32,1,1).shape)#表示的是第一个维度拷贝4次,第二个维度拷贝32次
print(b.repeat(4,1,32,32).shape)

PyTorch 维度变换_第4张图片

 Transpose交换指定的维度

a = torch.rand(4,3,32,32)

a1 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,3,32,32)
a2 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,32,32,3).transpose(1,3)
#使用view操作需要人为跟踪每个维度代表的信息,否则数据会被打乱,无法复原
print(a1.shape)
print(a2.shape)

print(torch.all(torch.eq(a,a1)))#比较数据是否相同
print(torch.all(torch.eq(a,a2)))

PyTorch 维度变换_第5张图片

permute

a = torch.rand(4,3,32,32)
print(a.shape)

print(a.permute(0,2,3,1).shape)#里面的数值表示原来的维度号,原来的第一位在新的第四维

 

你可能感兴趣的:(pytorch,机器学习)