Pytorch 中的 torch.clamp() 函数

作用

将输入input张量的每个元素的夹紧到区间 [min,max],并返回结果得到一个新张量

参数及定义

import torch
torch.clamp(input, min, max, out = None)

参数:

  • input (Tensor) – 输入张量

  • min (Number) – 限制范围下限

  • max (Number) – 限制范围上限

  • out (Tensor, optional) – 输出张量

运算规则

      |-min, if x_i < min
      |
y_i = | x_i, if min <= x_i <= max
      |
      |-max, if x_i > max

示例

import torch

a = torch.randint(low = 0, high = 100, size = (1, 10))
print(a)

a = torch.clamp(a, 10, 20)
print(a)

>>tensor([[ 8, 92, 17, 60, 83, 81, 33, 53, 12, 79]])
>>tensor([[10, 20, 17, 20, 20, 20, 20, 20, 12, 20]])

你可能感兴趣的:(Pytorch,中的各种函数,Pytorch)