关于pytorch的unsqueeze()的学习笔记

官方解释

Tensor.unsqueeze(dim)

插入一个新的维度到指定位置,并返回一个新的tensor,即对tensor增加维度。

初始化一个tensor

import torch
input = torch.arange(1,7)

print(input)
print(input.shape)
###############  输出  ####################
tensor([1, 2, 3, 4, 5, 6])
torch.Size([6])

input.unsqueeze( 0)

print(input.unsqueeze(0))
###############  输出  ####################
tensor([[1, 2, 3, 4, 5, 6]])

input.unsqueeze( 1)

print(input.unsqueeze(1))
###############  输出  ####################
tensor([[1],
        [2],
        [3],
        [4],
        [5],
        [6]])

input.unsqueeze(-1)

print(input.unsqueeze(-1))
###############  输出  ####################
tensor([[1],
        [2],
        [3],
        [4],
        [5],
        [6]])

你可能感兴趣的:(pytorch,学习)