pytorch中unsqueeze()函数理解

unsqueeze()函数起升维的作用,参数表示在哪个地方加一个维度。
在第一个维度(中括号)的每个元素加中括号
0表示在张量最外层加一个中括号变成第一维。
直接看例子:

import torch
input=torch.arange(0,6)
print(input)
print(input.shape)
结果:
tensor([0, 1, 2, 3, 4, 5])
torch.Size([6])
print(input.unsqueeze(0))
print(input.unsqueeze(0).shape)
结果:
tensor([[0, 1, 2, 3, 4, 5]])
torch.Size([1, 6])
print(input.unsqueeze(1))
print(input.unsqueeze(1).shape)
结果:
tensor([[0],
        [1],
        [2],
        [3],
        [4],
        [5]])
torch.Size([6, 1])

你可能感兴趣的:(pytorch常用函数简单解析)