Pytorch:torch.diag()创建对角线张量方式例子解析

Pytorch:torch.diag()创建对角线张量方式例子解析_第1张图片
在PyTorch中,torch.diag函数可以用于创建对角线张量或提取给定矩阵的对角线元素。以下是一些详细的使用例子:

  1. 创建对角矩阵:如果输入是一个向量(1D张量),torch.diag将返回一个2D方阵,其中输入向量的元素作为对角线元素。例如:

    a = torch.randn(3)
    print(a)
    # 输出:tensor([ 0.5950,-0.0872, 2.3298])
    print(torch.diag(a))
    # 输出:tensor([[ 0.5950, 0.0000, 0.0000],
    #              [ 0.0000,-0.0872, 0.0000],
    #              [ 0.0000, 0.0000, 2.3298]])
    
  2. 提取对角线元素:如果输入是一个矩阵(2D张量),torch.diag将返回一个1D张量,包含输入矩阵的对角线元素。例如:

    a = torch.randn(3, 3)
    print(a)
    # 输出:tensor([[-0.4264, 0.0255,-0.1064],
    #              [ 0.8795,-0.2429, 0.1374],
    #              [ 0.1029,-0.6482,-1.6300]])
    print(torch.diag(a, 0))
    # 输出:tensor([-0.4264, -0.2429, -1.6300])
    
  3. 指定对角线torch.diag函数还允许你通过diagonal参数指定要提取或使用的对角线。diagonal=0表示主对角线,diagonal>0表示主对角线上方的对角线,diagonal<0表示主对角线下方的对角线。例如,提取矩阵的第二条对角线:

    print(torch.diag(a, 1))
    # 输出:tensor([ 0.0255, 0.1374])
    

这些例子展示了如何使用torch.diag函数来创建对角矩阵或提取对角线元素,以及如何通过diagonal参数来指定对角线。这些操作在矩阵分解和转换等数学和深度学习任务中非常有用。

喜欢本文,请点赞、收藏和关注!

你可能感兴趣的:(Python,pytorch,人工智能,python)