PyTorch中张量(tensor)的维度变换

1. view/reshape:形状改变,数据不变

Example:

In[1]: x = torch.rand(4,1,28,28)

In[2]: x.size()

Out[2]: torch.Size([4, 1, 28, 28])

 

In[3]: y = x.view(4,28*28)

In[4]: y.size()

Out[4]: torch.Size([4, 784])

 

In[5]: y = x.reshape(4,28*28)

In[6]: y.size()

Out[6]: torch.Size([4, 784])

 

2. squeeze/unsqueeze:压缩/扩展维度

Example:

In[1]: x = torch.rand(4,1,28,28)

In[2]: x.size()

Out[2]: torch.Size([4, 1, 28, 28])

 

In[3]: y = x.squeeze()

In[4]: y.size()

Out[4]: torch.Size([4, 28, 28])   #默认去掉所有为元素个数为1的维度

 

In[5]: y = x.squeeze(1)

In[6]: y.size()

Out[6]: torch.Size([4, 28, 28])

 

In[7]: y = x.squeeze(2)

In[8]: y.size()

Out[8]: torch.Size([4, 1, 28, 28])   #元素个数不为1的维度不能squeeze().

 

In[9]: y = x.unsqueeze(2)

In[10]: y.size()

Out[10]: torch.Size([4, 1, 1, 28, 28])   #参数就是插入位置

 

3. t:转置

Example:

In[1]: x = torch.randn(2, 3)

In[2]: x.size()

Out[2]: torch.Size([2, 3])

 

In[3]: y = x.t()   #转置只针对二维张量

In[4]: y.size()

Out[4]: torch.Size([3, 2])

 

4. tranpose/ permute:两两置换/多次置换

Example:

In[1]: x = torch.rand(4,1,28,28)

In[2]: x.size()

Out[2]: torch.Size([4, 1, 28, 28])

 

In[3]: y = x.transpose(0, 1)

In[4]: y.size()

Out[4]: torch.Size([1, 4, 28, 28])

 

In[5]: y = x.permute(3, 2, 1, 0))

In[6]: y.size()

Out[6]: torch.Size([28, 28, 1, 4])

 

5. expand/repeat:扩大/重复

Example:

In[1]: x = torch.rand(4,1,1,1)

In[2]: x.size()

Out[2]: torch.Size([4, 1, 1, 1])

 

In[3]: y = x.expand(4,1,28,28)

In[4]: y.size()

Out[4]: torch.Size([4, 1, 28, 28])

 

In[5]: y = x.repeat(4,1,28,28)

In[6]: y.size()

Out[6]: torch.Size([16, 1, 28, 28])   #参数是repeat倍数

 

6. broadcasting:unsqueeze+expand

 

此外,在进行维度变换时,要注意数据顺序的实际意义!!!

你可能感兴趣的:(机器学习)