【Pytorch】max函数

在Pytorch中,max函数包括torch中顶级函数torch.max和Tensor对象的max函数,且均实现了overload(函数重载),以泛化其功能。其常见的使用方法包括:

可用于Tensor对象内部元素的极值获取,或者两个Tensor对象的逐元素对比。

1. 张量内部元素的极值

import torch 
a = torch.arange(0, 6).reshape(2,3)    # tensor([[0, 1, 2][3, 4, 5]])

# 所有元素的最大值
torch.max(a)  # 方法一,tensor(5)
a.max()    # 方法二,tensor(5)

# 沿某个dim的最大值 
torch.max(a, dim=1)    # 方法一,(tensor([2, 5]), tensor([2, 2])), 第二个元素即为torch.argmax(a, dim=1)
a.max(dim=1)     # 方法二

2. 逐元素对比两个张量
该操作支持广播

import torch 
a = torch.arange(0, 6).reshape(2,3)    # tensor([[0, 1, 2][3, 4, 5]])

torch.max(a, torch.tensor(2))  # 方法一:tensor([[2, 2, 2], [3, 4, 5]])

a.max(torch.tensor(2))   # 方法二

一个典型应用就是relu激活函数:

def relu(t):
    return torch.max(t, torch.zeros_like(t))

你可能感兴趣的:(Pytorch)