【Python】torch.where()解析

【Python】torch.where()解析

文章目录

  • 【Python】torch.where()解析
    • 1. 介绍
    • 2. API
    • 3. 示例

1. 介绍

torch.where(condition, x, y)

  • 函数功能:
    • 将指定 tensor 的满足条件位置设置为想要的数值。
  • 参数:
    • condition:判断条件
    • x:若满足条件,则为 x 中元素
    • y:若不满足条件,则为 y 中元素

2. API

(function)
'''
四种调用方法如下:
'''
def where(
    condition: Tensor,
    input: Tensor,
    other: Tensor,
    *,
    out: Tensor | None = None
) -> Tensor: ...

def where(
    condition: Tensor,
    self: Number,
    other: Tensor
) -> Tensor: ...

def where(
    condition: Tensor,
    input: Tensor,
    other: Number
) -> Tensor: ...

def where(
    condition: Tensor,
    self: Number,
    other: Number
) -> Tensor: ...

3. 示例

import torch
 
# 条件
condition = torch.rand(3, 2)
print(condition)
# 满足条件则取x中对应元素
x = torch.ones(3, 2)
print(x)
# 不满足条件则取y中对应元素
y = torch.zeros(3, 2)
print(y)
# 条件判断后的结果
result = torch.where(condition > 0.5, x, y)
print(result)

输出如下:

tensor([[0.3224, 0.5789],
        [0.8341, 0.1673],
        [0.1668, 0.4933]])
tensor([[1., 1.],
        [1., 1.],
        [1., 1.]])
tensor([[0., 0.],
        [0., 0.],
        [0., 0.]])
tensor([[0., 1.],
        [1., 0.],
        [0., 0.]])

你可能感兴趣的:(Python,使用说明,python,深度学习,numpy)