torch.diag_embed代码测试

文章目录

  • 1. 函数说明
  • 2. 代码
  • 3. 结果

1. 函数说明

torch.diag_embed(input, offset=0, dim1=- 2, dim2=- 1) → Tensor

创建一个张量,其某些2D平面(由dim1和dim2指定)的对角线由输入填充。为了方便批量创建对角矩阵,默认选择由返回张量的最后两个维度组成的2D平面。
创建一个对角张量,对角值是给定的张量的值,分布在对角线上,input 后输出的张量在最后一维进行扩充。

  • input = [2,3] -> output=[2,3,3]
  • input = [1,2,4] -> output =[1,2,4,4]
  • input =[3,2,5,2] -> output = [3,2,5,2,2]
    以上可以看出来,扩充是按照给定输入的张量的最后一维度大小进行扩充的。

2. 代码

import torch

x_2_3 = torch.ones(2, 3)
y_2_3 = torch.diag_embed(x_2_3)
x_1_2_4 = torch.ones((1, 2, 4))
y_1_2_4 = torch.diag_embed(x_1_2_4)
x_3_2_5_2 = torch.ones(3, 2, 5, 2)
y_3_2_5_2 = torch.diag_embed(x_3_2_5_2)
print(f'x_2_3.shape={x_2_3.shape}')
print(f'y_2_3.shape={y_2_3.shape}')
print(f'x_1_2_4.shape={x_1_2_4.shape}')
print(f'y_1_2_4.shape={y_1_2_4.shape}')
print(f'x_3_2_5_5.shape={x_3_2_5_2.shape}')
print(f'y_3_2_5_5.shape={y_3_2_5_2.shape}')

3. 结果

x_2_3.shape=torch.Size([2, 3])
y_2_3.shape=torch.Size([2, 3, 3])
x_1_2_4.shape=torch.Size([1, 2, 4])
y_1_2_4.shape=torch.Size([1, 2, 4, 4])
x_3_2_5_5.shape=torch.Size([3, 2, 5, 2])
y_3_2_5_5.shape=torch.Size([3, 2, 5, 2, 2])

你可能感兴趣的:(pytorch,pytorch,深度学习,python)