pytorch之squeeze和unsqueeze的用法及注意事项

pytorch之squeeze和unsqueeze的用法及注意事项

      • 用法
      • 注意事项

用法

  1. squeeze:对数据的维度进行压缩,去掉维数为1的的维度,比如是一行或者一列这种,一个一行三列(1, 10)的数去掉第一个维数为一的维度之后就变成(10)行。用法两种:

b= a.squeeze(dim=index) 输入需要降维的维度index
b = torch.squeeze(a, dim=index) 同上

a = torch.randn([1, 10], )
b= a.squeeze(0)
print(b.size())
# out:torch.Size([10])
  1. 给指定位置加上维数为一的维度,比如是一行或者一列这种,一个一行三列(2, 10)的数去掉第一个维数为一的维度之后就变成(1, 2, 10)行。用法两种:

b= a.unsqueeze(dim=index) 输入需要增维的维度index
b = torch.unsqueeze(a, dim=index) 同上

a = torch.randn([2, 10])
b = a.unsqueeze(dim=0)
print(b.size())
# out:torch.Size([1, 2, 10])

注意事项

  1. 使用squeeze时所删除的维数index不能超过数据维数。
  2. 使用unsqueeze时,只能增加一维,如果还需要耕多可以用repeat、expend方法。

你可能感兴趣的:(Pytorch那些事儿,BUG解决类,学习工具类)