torch.unsqueeze()和torch.squeeze()

torch.unsqueeze()和torch.squeeze()

1.torch.unsqueeze()

原型:

torch.unsqueeze(input, dim, out=None)	

作用:扩展维度,返回一个新的张量,对输入徳既定位置插入维度1。

参数:

  • tensor (Tensor) – 输入张量
  • dim (int) – 插入维度的索引
  • out (Tensor, optional) – 结果张量

2.torch.squeeze()

原型:

torch.squeeze(input, dim=None, out=None)

作用:降低维度,将输入张量形状中的1 去除并返回。 如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)

当给定dim时,那么挤压操作只在给定维度上。例如,输入形状为: (A×1×B), squeeze(input, 0) 将会保持张量不变,只有用 squeeze(input, 1),形状会变成 (A×B)。

参数:

  • input (Tensor) – 输入张量
  • dim (int, optional) – 如果给定,则input只会在给定维度挤压
  • out (Tensor, optional) – 输出张量

你可能感兴趣的:(pytorch,pytorch)