unsqueeze()函数

Pytorch unsqueeze()函数的用法

在这里插入代码片
import torch
a=torch.ones(10)
print("a.shape:",a.shape)
b=a.unsqueeze(0)   #对0维度扩展一维
c=a.unsqueeze(1)   #对一维度扩展一维
print("b.shape",b.shape)
print("c.shape",c.shape)

运行结果在这里插入图片描述

你可能感兴趣的:(python,pytorch,深度学习,机器学习)