torch.expand(-1, -1)的理解

torch.expand(-1, -1)的理解

在expand中的-1表示取当前所在维度的尺寸,也就是表示当前维度不变。
在代码中 一般用这方法解决不想手动计算维度的时候

例:

import torch

x = torch.Tensor([[1], [2], [3]])
x0 = x.size(0)  # 取x第一维的尺寸,x0 = 3
x1 = x.expand(-1, 2)
x2 = x.expand(3, 2)

输出:

x0 =  3
x1 =  tensor([[1., 1.],
        [2., 2.],
        [3., 3.]])
x2 =  tensor([[1., 1.],
        [2., 2.],
        [3., 3.]])

从例子可以看出x1 = x.expand(-1, 2)等价于x2 = x.expand(3, 2)

你可能感兴趣的:(pytorch)