torch.diagonal取三维张量的各自二维张量的对角线

torch.diagonal() 对于二维张量就是取对角线元素

对于三维张量,比如(6,m,  n)

torch.diagonal(tensor, dim1=-2, dim2=-1) 代表分别取6个m*n张量的对角线元素。

torch.diagonal取三维张量的各自二维张量的对角线_第1张图片

 不知道什么原理,反正能取。

参数:

input (Tensor) – the input tensor. Must be at least 2-dimensional.
offset (int, optional) – which diagonal to consider. Default: 0 (main diagonal).
dim1 (int, optional) – first dimension with respect to which to take diagonal. Default: 0.
dim2 (int, optional) – second dimension with respect to which to take diagonal. Default: 1

参考:https://blog.csdn.net/weixin_44248411/article/details/115407151

你可能感兴趣的:(python,开发语言,pytorch)