tf.matrix_band_part 与 torch.tril 使用

tf.matrix_band_part 与 torch.tril 使用

这两个函数都是用来取矩阵的某一块值的,简单使用如下

tf.matrix_band_part

tf.matrix_band_part(m, -1, 0) 表示取矩阵的下左三角,不包括斜对角。

同理,tf.matrix_band_part(m, 0, -1)表示取矩阵的右上三角,不包括斜对角。

其他操作详细见tf.matrix_band_part ,搭配transpose可灵活取值;

torch.tril

torch.tril(m, diagonal=-1) 同样表示下三角,不包括斜对角;

torch.tril(m, diagonal=0) 表示下三角,包括斜对角。

具体操作见torch.tril

你可能感兴趣的:(Errors,Pytorch,python)