torch.diagonal() 对于二维张量就是取对角线元素
对于三维张量,比如(6,m, n)
torch.diagonal(tensor, dim1=-2, dim2=-1) 代表分别取6个m*n张量的对角线元素。
不知道什么原理,反正能取。
参数:
input (Tensor) – the input tensor. Must be at least 2-dimensional.
offset (int, optional) – which diagonal to consider. Default: 0 (main diagonal).
dim1 (int, optional) – first dimension with respect to which to take diagonal. Default: 0.
dim2 (int, optional) – second dimension with respect to which to take diagonal. Default: 1
参考:https://blog.csdn.net/weixin_44248411/article/details/115407151