主要借助tf.diag_part
和tf.matrix_diag
两个方法来将方阵对角线置0.
in : inputs = [[1,2,3,4], [2,3,4,5], [3,4,5,6], [4,5,6,7]]
in : sess.run(tf.diag_part(inputs))
out: array([1, 3, 5, 7], dtype=int32)
# 对角线元素
in : x = tf.diag_part(inputs)
in: matrix = tf.matrix_diag(x)
# 原矩阵减去对角矩阵,即可实现对角线元素置0
in: sess.run(inputs- matrix)
out: array([[0, 2, 3, 4],
[2, 0, 4, 5],
[3, 4, 0, 6],
[4, 5, 6, 0]], dtype=int32)