Pytorch:.max(0)和.max(1)的区别?

Pytorch 中.max(0)和.max(1)的区别是什么?

当我们有一个形状为 ( m , n ) (m, n) (m,n) 的 Tensor x x x 时,其中 m m m 表示行数, n n n 表示列数。在 PyTorch 中,max(dim) 函数的参数 dim 表示计算最大值的维度,可以被设置为 0 或 1。那么,.max(0) 和 .max(1) 的区别在于计算的方向不同。具体来说:

对于 .max(0) 函数,计算方向是沿着第 0 维(即行数)的方向,它会返回 x x x 的每列最大值和它们的行索引。例如, x x x 的 shape 是 ( 3 , 2 ) (3, 2) (3,2),则它沿着行的方向计算最大值,返回的结果是两个张量:
一个形状为 ( 2 , ) (2,) (2,) 的张量,其中第 i i i 个元素是第 i i i 列的最大值;
一个形状为 ( 2 , ) (2,) (2,) 的张量,其中第 i i i 个元素是第 i i i 列的最大值所对应的行索引。
对于 .max(1) 函数,计算方向是沿着第 1 维(即列数)的方向,它会返回 x x x 的每行最大值和它们的列索引。例如, x x x 的 shape 是 ( 3 , 2 ) (3, 2) (3,2),则它沿着列的方向计算最大值,返回的结果是两个张量:
一个形状为 ( 3 , ) (3,) (3,) 的张量,其中第 i i i 个元素是第 i i i 行的最大值;
一个形状为 ( 3 , ) (3,) (3,) 的张量,其中第 i i i 个元素是第 i i i 行的最大值所对应的列索引。
下面是一个简单的例子,可以更好地解释 .max(0) 和 .max(1) 的区别:

import torch

# 构造一个 3x2 的 Tensor
x = torch.tensor([[0.5, 0.1], [0.8, 0.4], [0.2, 0.9]])

# 沿着行的方向计算最大值,返回每列的最大值和它们的行索引
max_values, max_indices = x.max(0)
print("max_values =", max_values)    # prints: "max_values = tensor([0.8000, 0.9000])"
print("max_indices =", max_indices)  # prints: "max_indices = tensor([1, 2])"

# 沿着列的方向计算最大值,返回每行的最大值和它们的列索引
max_values, max_indices = x.max(1)
print("max_values =", max_values)    # prints: "max_values = tensor([0.5000, 0.8000, 0.9000])"
print("max_indices =", max_indices)  # prints: "max_indices = tensor([0, 0, 1])"

在上面的例子中,我们首先构造了一个 3x2 的 Tensor x,然后分别使用 .max(0) 和 .max(1) 计算了每个维度上的最大值和最大值所在的维度索引。可以看到,.max(0) 返回了每列最大值和它们的行索引,而 .max(1) 返回了每行最大值和它们的列索引

你可能感兴趣的:(Pytorch学习手册,pytorch,深度学习,python)