PyTorch中torch.norm函数详解

torch.norm() 是 PyTorch 中的一个函数,用于计算输入张量沿指定维度的范数。具体而言,当给定一个输入张量 x 和一个整数 p 时,torch.norm(x, p) 将返回输入张量 x 沿着最后一个维度(默认为所有维度)上所有元素的 p 范数。

除了使用标量 p 之外,torch.norm() 还接受以下参数:

  • dim:指定沿哪个轴计算范数,默认对所有维度计算。
  • keepdim:如果设置为 True,则输出张量维度与输入张量相同,其中指定轴尺寸为 1;否则,将从输出张量中删除指定轴。
  • out:可选输出张量结果。

以下是一个示例:

import torch

# 创建一个形状为 (3, 4) 的二维张量
x = torch.tensor([[2., 3., 5., -1.], [-1., -2., 1., 4.], [0.5, -2., 7., 2.]])

# 计算所有元素的 L2 范数
l2_norm_all = torch.norm(x)
print("L2 norm of all elements:", l2_norm_all.item())

# 计算第一个维度上每个子数组的 L2 范数(即按行计算)
l2_norm_rows = torch.norm(x, dim=1)
print("L2 norm of rows:", l2_norm_rows.numpy())

# 计算最后一个维度上每个子数组的 L1 范数(即按列计算)
l1_norm_cols = torch.norm(x, p=1, dim=-1)
print("L1 norm of columns:", l1_norm_cols.numpy())

在这个示例中,我们首先创建了一个形状为 (3, 4) 的二维张量 x,然后使用 torch.norm() 函数计算了不同维度上的范数。注意,我们将 dim 参数设置为 1 和 -1 以分别按行和按列计算范数,并将 p 参数设置为 1 来计算 L1 范数。在输出结果中,我们使用 .item() 将标量张量转换回 Python 中的浮点数,用 .numpy() 将张量转换回 NumPy 数组。

你可能感兴趣的:(numpy,python,深度学习)