torch.diag()

函数定义:

def diag(input: Tensor, diagonal: _int=0, *, out: Optional[Tensor]=None)

参数:

* input:tensor

* diagonal:选择输出的对角线,默认为0,即输出主对角线

实际上这个函数就是输出一个矩阵的对角线。若diagonal为正的话,输出主对角线右上角的副对角线;若diagonal为负的话,输出主对角线左上角的副对角线。

import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(a)
# output:
# tensor([[1, 2, 3],
#        [4, 5, 6],
#        [7, 8, 9]])

print(torch.diag(a))
# output:
# tensor([1, 5, 9])

print(torch.diag(a, 1))
# output:
# tensor([2, 6])

print(torch.diag(a, -1))
# output:
# tensor([4, 8])

print(torch.diag(a, 2))
# output:
# tensor([3])

print(torch.diag(a, -2))
# output:
# tensor([7])

你可能感兴趣的:(函数,pytorch,深度学习,python)