在pytorch中expand_dim

在numpy和tensorflow中都有扩展维度操作:expand_dims操作

pytorch中也有这个操作,但是命名不一样,pytorch中的命名是:unsqueeze,直接放在tensor后面即可。

示例如下:

import torch

x1 = torch.zeros(10, 10)
x2 = x1.unsqueeze(0) # 括号里的参数是扩展的维度的位置

print(x2.size())

"""
返回:torch.Size([1, 10, 10])
"""

unsqueeze_与unsqueeze有同样的效果

import torch

x1 = torch.zeros(10, 10)
x2 = x1.unsqueeze_(0) # 括号里的参数是扩展的维度的位置

print(x2.size())

"""
返回:torch.Size([1, 10, 10])
"""

参考:https://jbencook.com/adding-a-dimension-to-a-tensor-in-pytorch/

你可能感兴趣的:(在pytorch中expand_dim)