tf实现用二维的索引从二维数组获取对应值 tf.gather_nd

a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
inds = tf.constant([[0, 2], [2, 1], [1, 1]])

#目的是实现 从[1,2,3]获取index为[0,2]的值也就是[1,3]作为第一行,
从[4,5,6]获取index为[2,1]的值也就是[6,5]作为第二行, 
从[7,8,9]获取index[1,1]的值作为第三行,也就是输出是
[[1 3]
 [6 5]
 [8 8]]






这种需求应该很常见,但是想通过look_up_table好像不行,以及想通过tf.gather_fn似乎可以但是也不好写

本文提供一种写法:

import tensorflow as tf

def gather_batch(v, inds):
    return tf.gather(v, inds)

def test2():
    a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    inds = tf.constant([[0, 2], [2, 1], [1, 1]])
    vs = tf.map_fn(fn=lambda x: gather_batch(x[:3], x[3:]), elems=tf.concat([a, inds], 1))

    with tf.Session() as sess:
        print(sess.run(vs))
 

if __name__ == '__main__':
    # test1()
    test2()

 

但是上面写法还是用了循环 会很慢 所以更好写法

def test3():
    a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    inds = tf.constant([[0, 2], [2, 1], [1, 1]])
    batch_size = inds.shape[0]
    cnt = inds.shape[1]
    left_inds = tf.tile(
        tf.expand_dims(tf.range(batch_size), 1),
        [1, cnt]
    )
    ind = tf.squeeze(
        tf.stack(
            [
                tf.expand_dims(left_inds, 2),
                tf.expand_dims(inds, 2),
            ],
            2
        )
        ,-1
    )

    vs = tf.gather_nd(a, ind)
    with tf.Session() as sess:
        # print(sess.run(ind))
        print(sess.run(vs))

 

 

你可能感兴趣的:(Tensorflow)