torch.ones_like函数和torch.zero_like函数

torch.ones_like函数和torch.zeros_like函数的基本功能是根据给定张量,生成与其形状相同的全1张量或全0张量,示例如下:

input = torch.rand(2, 3)
print(input)
# 生成与input形状相同、元素全为1的张量
a = torch.ones_like(input)
print(a)
# 生成与input形状相同、元素全为0的张量
b = torch.zeros_like(input)
print(b)

效果如下:

tensor([[0.0881, 0.9002, 0.7084],
        [0.3313, 0.2736, 0.0894]])
tensor([[1., 1., 1.],
        [1., 1., 1.]])
tensor([[0., 0., 0.],
        [0., 0., 0.]])

我们进一步看一下这两个函数在源码中是怎样定义的。

torch.ones_like函数:

@overload
def ones_like(self: Tensor, *, dtype: _dtype=None, layout: layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...

torch.zeros_like函数:

@overload
def zeros_like(self: Tensor, *, dtype: _dtype=None, layout: layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...

可以看到,在这两个函数中,我们还可以指定数据类型、设备、是否计算梯度等信息,可以结合具体场景灵活使用。

你可能感兴趣的:(python,pytorch,torch.ones_like,zeros_like,python,pytorch)