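tf.gather 只能沿第 0 维取值,所以先用 reshape 把 batch 和 seq_length 两维合并成一维(rank 2 时直接 reshape 成 -1),再给每一行的下标加上该行在合并后的起始偏移量(行号 × seq_length),然后用 gather 去取。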
import tensorflow as tf
def gather_indexes_2d(sequence_tensor, positions):
    """Gather elements at per-row positions from a rank-2 tensor.

    ``tf.gather`` only indexes along axis 0, so the tensor is flattened to
    1-D and every row's position is shifted by ``row_index * seq_length``
    (the row's start in the flattened tensor) before gathering.

    Args:
        sequence_tensor: Rank-2 tensor or variable, shape
            [batch_size, seq_length]. The batch/sequence dims may be
            unknown (None) at graph-build time.
        positions: int32 indices into the second axis, shape [batch_size]
            (one index per row) or [batch_size, k] (k indices per row).

    Returns:
        A 1-D tensor of the gathered elements in row-major order
        (length batch_size, or batch_size * k for 2-D positions).
    """
    static_shape = sequence_tensor.shape.as_list()
    dynamic_shape = tf.shape(sequence_tensor)
    # Use the static size where it is known and the runtime size otherwise
    # (e.g. a placeholder with a None batch dim); tf.range(0, None) fails.
    batch_size, seq_length = [
        dim if dim is not None else dynamic_shape[i]
        for i, dim in enumerate(static_shape)
    ]

    positions = tf.convert_to_tensor(positions, dtype=tf.int32)
    # Start index of each row inside the flattened tensor: [0, L, 2L, ...].
    flat_offsets = tf.range(0, batch_size, dtype=tf.int32) * seq_length
    positions_rank = positions.shape.ndims
    if positions_rank is not None and positions_rank > 1:
        # Broadcast one offset per row across the trailing position dims,
        # instead of (wrongly) broadcasting along the last axis.
        flat_offsets = tf.reshape(
            flat_offsets, [-1] + [1] * (positions_rank - 1))
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(sequence_tensor, [-1])
    return tf.gather(flat_sequence_tensor, flat_positions)
# Demo: pick one element per row of a 3x2 variable.
value = [[0, 1], [2, 3], [4, 5]]
init = tf.constant_initializer(value)
v = tf.get_variable('value', shape=[3, 2], initializer=init, dtype=tf.int32)
p = tf.placeholder(shape=[3], dtype=tf.int32)
v_ = gather_indexes_2d(v, p)
# tf.initialize_all_variables() is deprecated; this is its replacement.
init = tf.global_variables_initializer()
# The `with` block closes the session and releases its resources.
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(v_, feed_dict={p: [1, 1, 0]}))
打印结果:[1 3 4](每行起始偏移为 [0, 2, 4],加上下标 [1, 1, 0] 得到展平后的下标 [1, 3, 4])
rank 为 3 的情况(每个位置取出的是一个长度为 width 的向量):
import tensorflow as tf
def gather_indexes_3d(sequence_tensor, positions):
    """Gather width-sized vectors at per-row positions from a rank-3 tensor.

    The batch and sequence dims are merged so ``tf.gather`` (which indexes
    along axis 0) can select rows. Every position is shifted by
    ``row_index * seq_length``, its row's start in the merged axis.

    Args:
        sequence_tensor: Rank-3 tensor or variable, shape
            [batch_size, seq_length, width]. Any dim may be unknown (None)
            at graph-build time.
        positions: int32 indices into the second axis, shape [batch_size]
            (one index per row) or [batch_size, k] (k indices per row).

    Returns:
        A tensor of shape [batch_size, width], or [batch_size * k, width]
        for 2-D positions.
    """
    static_shape = sequence_tensor.shape.as_list()
    dynamic_shape = tf.shape(sequence_tensor)
    # Use the static size where it is known and the runtime size otherwise
    # (e.g. a placeholder with a None batch dim); tf.range(0, None) fails.
    batch_size, seq_length, width = [
        dim if dim is not None else dynamic_shape[i]
        for i, dim in enumerate(static_shape)
    ]

    positions = tf.convert_to_tensor(positions, dtype=tf.int32)
    # Start index of each row inside the merged batch*seq axis.
    flat_offsets = tf.range(0, batch_size, dtype=tf.int32) * seq_length
    positions_rank = positions.shape.ndims
    if positions_rank is not None and positions_rank > 1:
        # Broadcast one offset per row across the trailing position dims.
        flat_offsets = tf.reshape(
            flat_offsets, [-1] + [1] * (positions_rank - 1))
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(sequence_tensor, [-1, width])
    return tf.gather(flat_sequence_tensor, flat_positions)
# Demo: pick one 2-vector per batch row of a [2, 3, 2] constant.
v = tf.constant([[[1, 1], [2, 2], [3, 3]],
                 [[4, 4], [5, 5], [6, 6]]])  # [2,3,2]
p = tf.placeholder(shape=[2], dtype=tf.int32)
v_ = gather_indexes_3d(v, p)
# tf.initialize_all_variables() is deprecated; this is its replacement.
init = tf.global_variables_initializer()
# The `with` block closes the session and releases its resources.
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(v_, feed_dict={p: [1, 0]}))
打印结果(第 0 行取下标 1 得到 [2 2],第 1 行取下标 0 得到 [4 4]):
[[2 2]
 [4 4]]