pytorch 给tensor增加一维(unsqueeze)或删除一维(squeeze)

给tensor增加一维

b = a.unsqueeze(0)
import torch

a = torch.randn(3, 200, 200)
b = a.unsqueeze(0)
print(a.shape)
print(b.shape)

在这里插入图片描述

删除tensor一维

squeeze只能删除维度为1的某一维。若某个维度不为1,可以用切片取出该维度的一个数据,再用squeeze删除。

b = a.squeeze(0)
import torch

a = torch.randn(1, 3, 200, 200)
b = a.squeeze(0)
print(a.shape)
print(b.shape)

在这里插入图片描述

你可能感兴趣的:(深度学习,#,Pytorch,pytorch,深度学习,python)