tensor升维和降维是神经网络的基本操作,比如不同维feature融合等都需要改操作。常用的函数有torch.unsqueeze() 和 torch.unsqueeze()操作。
目录
1. tensor降维操作: torch.squeeze() 和 指定index
2. tensor升维操作: torch.unsqueeze() 和 使用None
3. torch.squeeze和torch.unsqueeze的另一种写法
(1) 使用torch.squeeze(input,dim),默认删除tensor中所有维度为1的维度,也可指定dim。torch.squeeze — PyTorch 1.13 documentation
import torch
if __name__ == '__main__':
a = torch.randn((2, 1, 3, 1, 4))
a1 = torch.squeeze(a)
print(a1.shape) # torch.Size([2, 3, 4])
a2 = torch.squeeze(a, dim=1)
print(a2.shape) # torch.Size([2, 3, 1, 4])
a3 = torch.squeeze(a, dim=3)
print(a3.shape) # torch.Size([2, 1, 3, 4])
(2) 也可使用index=0直接指定,使用torch.equal比较两者相等。
if __name__ == '__main__':
a = torch.randn((2, 1, 3, 1, 4))
a1 = torch.squeeze(a)
print(a1.shape) # torch.Size([2, 3, 4])
a2 = a[:, 0, :, 0]
print(a2.shape) # torch.Size([2, 3, 4])
print(torch.equal(a1, a2)) # True
(1) torch.unsqueeze(input, dim) ,对指定的dim,执行升维操作,具体可参考官方文档以及如下示例。torch.unsqueeze — PyTorch 1.13 documentation
import torch
if __name__ == '__main__':
a = torch.randn((2, 3, 4))
a1 = torch.unsqueeze(a, dim=1)
print(a1.shape) # torch.Size([2, 1, 3, 4])
a2 = torch.unsqueeze(a, dim=2)
print(a2.shape) # torch.Size([2, 3, 1, 4])
(2) 简单用法:使用None,使用None来增加新维度
import torch
if __name__ == '__main__':
a = torch.randn((2, 3, 4))
a1 = a[:, None, ...]
print(a1.shape) # torch.Size([2, 1, 3, 4])
a2 = a[..., None, :]
print(a2.shape) # torch.Size([2, 3, 1, 4])
注意:a1中None后面的三个点可以省略,如下
import torch
if __name__ == '__main__':
a = torch.randn((2, 3, 4))
a1_old = a[:, None, ...]
print(a1_old .shape) # torch.Size([2, 1, 3, 4])
a1_new = a[:, None]
print(a1_new .shape) # torch.Size([2, 1, 3, 4])
print(torch.equal(a1_old, a1_new)) # True
一般情况下使用torch.squeeze(x, dim=?)来进行降维,当然还可以直接使用 x.squeeze(dim=?)。
import torch
if __name__ == '__main__':
a = torch.randn((2, 3, 4))
a1 = torch.unsqueeze(a, dim=0)
print(a1.shape) # torch.Size([1, 2, 3, 4])
# 另一种写法
a2 = a.unsqueeze(dim=0)
print(a2.shape) # torch.Size([1, 2, 3, 4])