pytorch维度变换

2019独角兽企业重金招聘Python工程师标准>>> hot3.png

import torch

import numpy as np

 

#维度变换1:view容易造成数据存储方式丢失.

a = torch.rand(4,1,28,28)

print(a.shape,a.view(4,28,28))#4,28,28 4张图片,把每张图片都合并在一起,即784,常用于全连接层;

print(a.view(4,28*28).shape)#torch.Size([4, 784])

print(a.view(4*28,28))#把所有通道所有行都放在第一个维度,即channel和行通道合并在一起

print(a.view(4*1,28,28))#

 

#维度展开:unsqueeze,注意能插入范围是[-5,4)这里4代表整个维度,5代表维度加1,比如0代表第一个位置前插入,1代表第二个位置前插入,3代表第三个位置前插入

b = a.unsqueeze(0)#torch.Size([1, 4, 1, 28, 28])

print(b.shape)

c = a.unsqueeze(-1)#torch.Size([4, 1, 28, 28, 1])

print(c.shape)

"""

-5 -4 -3 -2 -1

0 1 2 3 4

4 1 28 28

 

torch.Size([1, 4, 1, 28, 28])

torch.Size([4, 1, 28, 28, 1])

torch.Size([1, 4, 1, 28, 28])

torch.Size([4, 1, 1, 28, 28])

torch.Size([4, 1, 1, 28, 28])

torch.Size([4, 1, 28, 1, 28])

torch.Size([4, 1, 28, 28, 1])

torch.Size([1, 4, 1, 28, 28])

torch.Size([4, 1, 1, 28, 28])

torch.Size([4, 1, 1, 28, 28])

torch.Size([4, 1, 28, 1, 28])

torch.Size([4, 1, 28, 28, 1])

尽量不使用负数

"""

for i in range(-5,5):

d = a.unsqueeze(i)

print(d.shape)

 

b = torch.rand(32)

f = torch.rand(4,32,14,14)

b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)#torch.Size([1, 32, 1, 1])

print(b.shape)


 

#维度删减squeeze

"""

torch.Size([32, 1, 1])

torch.Size([1, 32, 1, 1])

torch.Size([1, 32, 1])

torch.Size([1, 32, 1])

torch.Size([32, 1, 1])

torch.Size([1, 32, 1, 1])

torch.Size([1, 32, 1])

torch.Size([1, 32, 1])

"""

c = b.squeeze()#torch.Size([32])

print(c.shape)

for i in range(-4,4):

print(b.squeeze(i).shape)

 

#维度扩展,即把shape改变 expand改变理解方式,不增加数据,repeat增加数据;注意repeat需要拷贝数据,所以速度慢.

b = torch.rand(1,32,1,1)

a = torch.rand(4,32,14,14)

c = b.expand(4,32,14,14)

print(b.shape,a.shape,c.shape)#torch.Size([1, 32, 1, 1]) torch.Size([4, 32, 14, 14]) torch.Size([4, 32, 14, 14])

 

d = b.repeat(4,32,1,1)#这里4,32,1,1代表数据被拷贝次数;

print(d.shape)#torch.Size([4, 1024, 1, 1])这不是我们想要结果,正确如下:

d = b.repeat(4,1,1,1)

print(d.shape)#torch.Size([4, 32, 1, 1])

 

#矩阵转置

a = torch.randn(4,3)

print(a,a.t())#t只用于二维度

 

a = torch.rand(4,3,32,32)

#b = a.transpose(1,3).view(4,3*32*32).view(4,3,32,32)#数据不连续,错误

#print(a.shape,c.shape)

b=a.transpose(1,3).contiguous().view(4,3*32*32 ).view(4,3,32,32)

c=a.transpose(1,3).contiguous().view(4,3*32*32 ).view(4,32,32,3).transpose(1,3)

print(b.shape,c.shape)#torch.Size([4, 3, 32, 32]) torch.Size([4, 32, 32, 3])

print(torch.all(torch.eq(a,b)),torch.all(torch.eq(a,c)))#tensor(0, dtype=torch.uint8) tensor(1, dtype=torch.uint8) 判断数据内容是否一致

 

d = a.permute(0,2,3,1)

print(d.shape)#torch.Size([4, 32, 32, 3]) 0,2,3,1代表存放维度数

转载于:https://my.oschina.net/u/4131400/blog/3048616

你可能感兴趣的:(pytorch维度变换)