PyTorch 去除值为 1 的维度

1. 方法

  • torch.squeeze(input, dim=None, out=None) → Tensor
    • 官方文档:https://pytorch.org/docs/master/generated/torch.squeeze.html
    • 参数说明:
      • input (Tensor):输入的张量
      • dim (int, optional) :可选参数,如果不指定,该方法会把所有值为 1 的维度移除,如果指定,该方法则指移除指定的那个维度
      • out (Tensor, optional) :可选,指定输出的张量.

2. 实例

>>> import torch
>>> import numpy as np
>>> labels = np.random.randint(1,7,(10,5,1,2,1))
>>> labels = torch.LongTensor(labels)
>>> labels.size()
torch.Size([10, 5, 1, 2, 1])

# 2.1 移除指定的值为 1 的维度
>>> squeezed_labels1 = torch.squeeze(labels, dim=2)
>>> squeezed_labels1.size()
torch.Size([10, 5, 2, 1])

# 2.2 移除指定的值为 1 的维度
>>> squeezed_labels2 = labels.squeeze(dim=4)
>>> squeezed_labels2.size()
>>> torch.Size([10, 5, 1, 2])

# 2.3 不指定维度的话,默认移除所有的值为 1 的维度
>>> squeezed_labels3 = labels.squeeze()
>>> squeezed_labels3.size()
torch.Size([10, 5, 2])
  • 当然以上都可以使用相应的 tensor.squeeze_() 方法

你可能感兴趣的:(PyTorch,基础)