tensorflow中去上下三角矩阵:matrix_band_part 和tf.linalg.band_part

tf.linalg.band_part

  • 新版本:tf.matrix_band_part变成tf.linalg.band_par

函数原型:

tf.linalg.band_part(
    input,
    num_lower,
    num_upper,
    name=None
)

参数:

  • 作用:主要功能是以对角线为中心,取它的副对角线部分,其他部分用0填充。
  • input:输入的张量.
  • num_lower:下三角矩阵保留的副对角线数量,从主对角线开始计算,相当于下三角的带宽。取值为负数时,则全部保留。
  • num_upper:上三角矩阵保留的副对角线数量,从主对角线开始计算,相当于上三角的带宽。取值为负数时,则全部保留。

例子:

import tensorflow as tf
tf.enable_eager_execution()
a=tf.constant( [[ 1,  1,  2, 3],[-1,  2,  1, 2],[-2, -1,  3, 1],
                 [-3, -2, -1, 5]],dtype=tf.float32)
b=tf.linalg.band_part(a,2,0)
c=tf.linalg.band_part(a,1,1)
d=tf.linalg.band_part(a,-1,1)
print(a)
print(b)
print(c)
print(d)
输出:
tf.Tensor(
[[ 1.  1.  2.  3.]
 [-1.  2.  1.  2.]
 [-2. -1.  3.  1.]
 [-3. -2. -1.  5.]], shape=(4, 4), dtype=float32)
=============================================================
tf.Tensor(
[[ 1.  0.  0.  0.]
 [-1.  2.  0.  0.]
 [-2. -1.  3.  0.]
 [ 0. -2. -1.  5.]], shape=(4, 4), dtype=float32)
 =============================================================
tf.Tensor(
[[ 1.  1.  0.  0.]
 [-1.  2.  1.  0.]
 [ 0. -1.  3.  1.]
 [ 0.  0. -1.  5.]], shape=(4, 4), dtype=float32)
  =============================================================
tf.Tensor(
[[ 1.  1.  0.  0.]
 [-1.  2.  1.  0.]
 [-2. -1.  3.  1.]
 [-3. -2. -1.  5.]], shape=(4, 4), dtype=float32)

你可能感兴趣的:(Tensorflow,API,学习)