unsqueeze()函数起升维的作用,参数表示在哪个地方加一个维度。
在第一个维度(中括号)的每个元素加中括号
0表示在张量最外层加一个中括号变成第一维。
直接看例子:
import torch
input=torch.arange(0,6)
print(input)
print(input.shape)
结果:
tensor([0, 1, 2, 3, 4, 5])
torch.Size([6])
print(input.unsqueeze(0))
print(input.unsqueeze(0).shape)
结果:
tensor([[0, 1, 2, 3, 4, 5]])
torch.Size([1, 6])
print(input.unsqueeze(1))
print(input.unsqueeze(1).shape)
结果:
tensor([[0],
[1],
[2],
[3],
[4],
[5]])
torch.Size([6, 1])