tensorflow 利用索引获取tensor特定元素

上代码,利用tf.gather_nd函数:

eg1: 常数索引

import tensorflow as tf
sess=tf.Session()
a=tf.constant([[0,1,2],[3,4,5]])#shape:(2,3)
result=tf.gather_nd(a,[0,1])#1
sess.run(result)

eg2:变量索引

import tensorflow as tf
sess=tf.Session()
a=tf.constant([[0,1,2],[3,4,5]])#shape:(2,3)
b=tf.Variable([0,1],dtype=tf.int32)
sess.run(id.initializer)
result=tf.gather_nd(a,b)#1
sess.run(result)#1

你可能感兴趣的:(TensorFlow,笔记,tensorflow)