pytorch基础:torch.unsqueeze()

torch.unsqueeze()

torch.unsqueeze()这个函数主要是对数据维度进行扩充。给指定位置加上维数为一的维度

import torch
a = torch.tensor([[1,2,3],[4,5,6]])  # size([2,3])
print(a.unsqueeze(0)) #扩增0维
>>> tensor([[[1,2,3],[4,5,6]]]) # size([1,2,3])

print(a.unsqueeze(1)) #增加1维 
>>> tensor([[[1,2,3]],[[4,5,6]]]) # size(2,1,3)

print(a.unsqueeze(2)) # 增加2维
>>>tensor([[[1],[2],[3]],[[4],[5],[6]]]) # size(2,3,1)

你可能感兴趣的:(pytorch,深度学习,人工智能)