# 2019独角兽企业重金招聘Python工程师标准>>>  (CSDN export banner — commented out so the file parses)
import torch
import numpy as np
# Reshaping #1: view. A view only preserves the total element count; the
# original dimension grouping (how storage maps to dims) is lost, so the
# caller must remember what each merged dimension means.
a = torch.rand(4,1,28,28)
print(a.shape,a.view(4,28,28))# merge the size-1 channel dim away: 4 images of 28x28
print(a.view(4,28*28).shape)#torch.Size([4, 784]) -- flatten each image; typical fully-connected-layer input
print(a.view(4*28,28))# merge batch, channel and row dims into the first dimension
print(a.view(4*1,28,28))# merge batch and channel dims
# Adding a dimension: unsqueeze. For this 4-D tensor the valid insert index
# range is [-5, 4]: 0 inserts before the first dim, 1 before the second,
# and so on (the loop further below exercises the whole range).
b = a.unsqueeze(0)#torch.Size([1, 4, 1, 28, 28])
print(b.shape)
c = a.unsqueeze(-1)#torch.Size([4, 1, 28, 28, 1])
print(c.shape)
"""
-5 -4 -3 -2 -1
0 1 2 3 4
4 1 28 28
torch.Size([1, 4, 1, 28, 28])
torch.Size([4, 1, 28, 28, 1])
torch.Size([1, 4, 1, 28, 28])
torch.Size([4, 1, 1, 28, 28])
torch.Size([4, 1, 1, 28, 28])
torch.Size([4, 1, 28, 1, 28])
torch.Size([4, 1, 28, 28, 1])
torch.Size([1, 4, 1, 28, 28])
torch.Size([4, 1, 1, 28, 28])
torch.Size([4, 1, 1, 28, 28])
torch.Size([4, 1, 28, 1, 28])
torch.Size([4, 1, 28, 28, 1])
尽量不使用负数
"""
for i in range(-5,5):
d = a.unsqueeze(i)
print(d.shape)
# Removing size-1 dimensions: squeeze.
b = torch.rand(32)
# f is the broadcast target shape this bias-like tensor is meant to match.
f = torch.rand(4,32,14,14)
# Build (1, 32, 1, 1) from (32,): append two trailing dims, then a leading one.
b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)#torch.Size([1, 32, 1, 1])
print(b.shape)
# Expected output of the squeeze loop below (dims -4..3 of a (1,32,1,1)
# tensor): only size-1 dims are removed; squeezing the size-32 dim is a no-op.
"""
torch.Size([32, 1, 1])
torch.Size([1, 32, 1, 1])
torch.Size([1, 32, 1])
torch.Size([1, 32, 1])
torch.Size([32, 1, 1])
torch.Size([1, 32, 1, 1])
torch.Size([1, 32, 1])
torch.Size([1, 32, 1])
"""
c = b.squeeze()#torch.Size([32]) -- with no dim argument, ALL size-1 dims are removed
print(c.shape)
# BUG FIX: the loop body was not indented, which is a SyntaxError.
for i in range(-4, 4):
    print(b.squeeze(i).shape)
# Broadcasting a tensor to a larger shape, two ways:
#   expand -- returns a view, no data is copied (cheap);
#   repeat -- physically tiles the storage (allocates and copies, slower).
b = torch.rand(1, 32, 1, 1)
a = torch.rand(4, 32, 14, 14)
c = b.expand_as(a)  # same as b.expand(4, 32, 14, 14): zero-copy broadcast view
print(b.shape, a.shape, c.shape)
# Pitfall: repeat's arguments are tile COUNTS per dim, not target sizes.
d = b.repeat(4, 32, 1, 1)
print(d.shape)  # torch.Size([4, 1024, 1, 1]) -- dim 1 was tiled 32x; not what we wanted
d = b.repeat(4, 1, 1, 1)
print(d.shape)  # torch.Size([4, 32, 1, 1]) -- correct
# Transposes and general axis permutations.
a = torch.randn(4, 3)
print(a, a.t())  # .t() is defined for 2-D matrices only
a = torch.rand(4, 3, 32, 32)
# Calling .view() directly on a.transpose(1, 3) raises, because the
# transposed view is not contiguous in memory; .contiguous() materializes it.
swapped = a.transpose(1, 3).contiguous()     # (4, 32, 32, 3), now contiguous
flat = swapped.view(4, 3 * 32 * 32)          # flatten each sample
b = flat.view(4, 3, 32, 32)                  # WRONG unflatten: axes get scrambled
c = flat.view(4, 32, 32, 3).transpose(1, 3)  # correct round trip back to a
print(b.shape, c.shape)
# Content check: b scrambled the data (False), c reconstructs a exactly (True).
print(torch.all(torch.eq(a, b)), torch.all(torch.eq(a, c)))
d = a.permute(0, 2, 3, 1)  # arbitrary reordering: here NCHW -> NHWC
print(d.shape)  # torch.Size([4, 32, 32, 3])