tensorflow 常用函数

文章目录

      • 常用函数
        • tf.gather
        • tf.nn.l2_normalizez()

常用函数

tf.gather

tf.gather(params,indices,axis=0) 从params的axis维根据indices的参数值获取切片

t1 = tf.ones(shape=(10,10))
t2 = tf.Variable([2,3])
t3 = tf.gather(t1,t2)
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(sess.run(t3))

[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]

t3 = tf.gather(t1,t2,axis=1)
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(sess.run(t3))

[[1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]]

tf.nn.l2_normalizez()

讲解的很详细tf.nn.l2_normalize的使用

你可能感兴趣的:(Tensorflow)