目录
1. torch.linalg.norm()
2. torch.linalg.vector_norm()
3. torch.linalg.matrix_norm()
最近在写程序的时候,用到了L2范数,也因此了解到了这几个函数torch.norm()【已弃用】,torch.linalg.norm(),torch.linalg.matrix_norm(),torch.linalg.vector_norm()。
该函数能够计算向量/矩阵/范数,首先献上官方文档:torch.linalg.norm — PyTorch 1.13
torch.linalg.norm(A, ord=None, dim=None, keepdim=False, *, out=None, dtype=None)
(1) ord是指定的范数,默认None,为F-范数(2-范数)。当p='nuc'时为核范数,剩下的可看文档。
(2) dim为指定求范数的维度
(3) keepdim指是否保留dim指定的维度
示例1:dim和ord均为None,即默认A展平,计算向量的2-范数,下面是向量范数公式(当p=2时)
import torch
from torch import linalg as LA
if __name__ == '__main__':
a = torch.arange(20).reshape(4, -1).float()
f2 = LA.norm(a)
print(f2) # tensor(49.6991)
# 验证
f2_my = torch.sqrt(a.pow(2).sum())
print(f2_my) # tensor(49.6991)
示例2:当指定维度为int时,计算的是该维度的向量范数。以下为示例和验证
import torch
from torch import linalg as LA
if __name__ == '__main__':
a = torch.arange(20).reshape(4, -1).float()
f2 = LA.norm(a, dim=0)
print(f2) # tensor([18.7083, 20.3470, 22.0454, 23.7908, 25.5734])
# 验证
f2_my = torch.sqrt(a.pow(2).sum(dim=0))
print(f2_my) # tensor([18.7083, 20.3470, 22.0454, 23.7908, 25.5734])
示例3:当指定维度为二元组时,计算的是该维度的矩阵范数。
import torch
from torch import linalg as LA
if __name__ == '__main__':
a = torch.arange(8).reshape(2, 2, 2).float()
f2 = LA.norm(a, dim=(1, 2))
print(f2, f2.shape) # tensor([ 3.7417, 11.2250]), torch.Size([2])
f2_keepdim = LA.norm(a, dim=(1, 2), keepdim=True)
print(f2_keepdim, f2_keepdim.shape) # tensor([[[ 3.7417]], [[11.2250]]]), torch.Size([2, 1, 1])
例子中给出了dim=(1, 2)是指计算dim=1和dim=2的矩阵范数,所以结果f2_keepdim是维度1和2的shape是1,维度为0的shape不变。同时例子中也验证了keepdim这个参数的用法,即默认是False,会把结果展平到一维,若为True的话,会保留原始的shape。
torch.linalg.norm()包含矩阵和向量范数,更具体的可以使用 torch.linalg.vector_norm()来计算向量范数,使用torch.linalg.matrix_norm()来计算矩阵范数。
torch.linalg.vector_norm(x, ord=2, dim=None, keepdim=False, *, dtype=None, out=None)
基本用法和norm一样,具体可参考torch.linalg.vector_norm — PyTorch 1.13 documentation
(1) ord是指定的范数:默认是2
(2) dim为指定求范数的维度:
示例4: 直接将示例1中的LA.norm更改成LA.vector_norm即可
from torch import linalg as LA
if __name__ == '__main__':
a = torch.arange(20, dtype=torch.float).reshape((4, -1)
f2 = LA.vector_norm(a, ord=2)
print(f2)
torch.linalg.matrix_norm(A, ord='fro', dim(- 2, - 1), keepdim=False, *, dtype=None, out=None)
基本用法和norm也是一致的,具体参考torch.linalg.matrix_norm — PyTorch 1.13 documentation
(1) ord指指定范数,下图为官网截图
示例5:直接将示例3中的LA.norm更改为LA.matrix_norm即可
import torch
from torch import linalg as LA
if __name__ == '__main__':
a = torch.arange(8).reshape(2, 2, 2).float()
f2 = LA.matrix_norm(a, dim=(1, 2))