a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
inds = tf.constant([[0, 2], [2, 1], [1, 1]])
#目的是实现 从[1,2,3]获取index为[0,2]的值也就是[1,3]作为第一行,
从[4,5,6]获取index为[2,1]的值也就是[6,5]作为第二行,
从[7,8,9]获取index[1,1]的值作为第三行,也就是输出是
[[1 3]
[6 5]
[8 8]]
这种需求应该很常见,但是想通过look_up_table好像不行,以及想通过tf.gather_fn似乎可以但是也不好写
本文提供一种写法:
import tensorflow as tf
def gather_batch(v, inds):
return tf.gather(v, inds)
def test2():
a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
inds = tf.constant([[0, 2], [2, 1], [1, 1]])
vs = tf.map_fn(fn=lambda x: gather_batch(x[:3], x[3:]), elems=tf.concat([a, inds], 1))
with tf.Session() as sess:
print(sess.run(vs))
if __name__ == '__main__':
# test1()
test2()
但是上面写法还是用了循环 会很慢 所以更好写法
def test3():
a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
inds = tf.constant([[0, 2], [2, 1], [1, 1]])
batch_size = inds.shape[0]
cnt = inds.shape[1]
left_inds = tf.tile(
tf.expand_dims(tf.range(batch_size), 1),
[1, cnt]
)
ind = tf.squeeze(
tf.stack(
[
tf.expand_dims(left_inds, 2),
tf.expand_dims(inds, 2),
],
2
)
,-1
)
vs = tf.gather_nd(a, ind)
with tf.Session() as sess:
# print(sess.run(ind))
print(sess.run(vs))