tf.nn.embedding_lookup

tf.nn.embedding_lookup(params, ids, partition_strategy='mod', name=None, validate_indices=True, max_norm=None)
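partition_strategy only matters when params is a list of P shards rather than a single matrix. The sketch below mimics the assignment rule described in the TensorFlow docs. With 'mod', id i goes to shard i % P at row i // P. With 'div', ids are split into contiguous blocks, and the first (number of ids % P) shards each get one extra id. It is a NumPy illustration under those assumptions, not TensorFlow's code. The names assign, split and sharded_lookup are hypothetical.

```python
import numpy as np

def assign(num_ids, num_shards, strategy):
    """Map each id in range(num_ids) to (shard, row_within_shard)."""
    if strategy == "mod":
        # Round-robin: id i lives in shard i % P at row i // P.
        return [(i % num_shards, i // num_shards) for i in range(num_ids)]
    if strategy == "div":
        # Contiguous blocks; the first (num_ids % P) shards get one extra id.
        base, extra = divmod(num_ids, num_shards)
        placement = []
        for p in range(num_shards):
            size = base + 1 if p < extra else base
            placement.extend((p, r) for r in range(size))
        return placement
    raise ValueError("unknown strategy: " + strategy)

def split(full, num_shards, strategy):
    """Split a full embedding matrix into shards (assumes num_ids >= num_shards)."""
    placement = assign(full.shape[0], num_shards, strategy)
    shards = [[] for _ in range(num_shards)]
    for i, (p, _) in enumerate(placement):
        shards[p].append(full[i])  # ids are visited in order, so each row lands at its r
    return [np.stack(s) for s in shards], placement

def sharded_lookup(shards, placement, ids):
    """Gather each id's row from the shard it was assigned to."""
    return np.stack([shards[p][r] for p, r in (placement[i] for i in ids)])

full = np.arange(13 * 2).reshape(13, 2)  # 13 ids, embedding dimension 2
ids = [12, 0, 7, 7, 3]
for strategy in ("mod", "div"):
    shards, placement = split(full, 5, strategy)
    # Which ids each of the 5 shards holds
    print(strategy, [[i for i, (p, _) in enumerate(placement) if p == s] for s in range(5)])
    assert (sharded_lookup(shards, placement, ids) == full[ids]).all()
# prints:
# mod [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]
# div [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]
```

Either way the sharded lookup returns the same rows as indexing the unsharded matrix. The strategy only changes which shard stores which id, so it has to match how the shards were built.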

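Looks up the rows of params at the indices given by ids, i.e. it gathers rows of the embedding matrix by id.
params: the embedding matrix. It can be a tensor or a NumPy array, or a list of tensors that shard the matrix along its first dimension (partition_strategy decides how ids are spread across the shards).
ids: an int32 or int64 tensor of indices into the first dimension of params. The result has shape ids.shape + params.shape[1:].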
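For a single (unsharded) params, the lookup amounts to NumPy integer indexing along the first axis. The following is a minimal sketch. embedding_lookup_ref is a hypothetical helper, and its max_norm branch assumes the clipping rescales each looked-up vector the way tf.clip_by_norm does: a vector whose L2 norm exceeds max_norm is scaled down to that norm, and the others are left alone.

```python
import numpy as np

def embedding_lookup_ref(params, ids, max_norm=None):
    """Hypothetical NumPy reference for a single (unsharded) params array."""
    params = np.asarray(params)
    ids = np.asarray(ids)
    out = params[ids]  # shape: ids.shape + params.shape[1:], same dtype as params
    if max_norm is not None:
        # Assumed clipping rule (like tf.clip_by_norm per looked-up vector):
        # a vector with L2 norm > max_norm is rescaled to norm max_norm.
        out = out.astype(np.float64)
        axes = tuple(range(ids.ndim, out.ndim))  # the embedding dimensions
        norms = np.sqrt(np.sum(out ** 2, axis=axes, keepdims=True))
        out = out * (max_norm / np.maximum(norms, max_norm))
    return out

params = [[3.0, 4.0], [0.3, 0.4], [0.0, 0.0]]  # L2 norms: 5.0, 0.5, 0.0
ids = [[0, 1], [2, 0]]                           # 2-D ids

out = embedding_lookup_ref(params, ids)
print(out.shape)  # (2, 2, 2)

clipped = embedding_lookup_ref(params, ids, max_norm=1.0)
print(clipped[0, 0])  # row 0 scaled from norm 5.0 down to norm 1.0, roughly [0.6 0.8]
```

Rows whose norm is already at most max_norm (rows 1 and 2 here) come back unchanged.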

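import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session); in 2.x use tf.compat.v1

# Embedding matrix: 5 ids (rows), each mapped to a 4-dimensional vector
params = [[0, 0, 0, 0], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]]
params = np.asarray(params)
# 2-D ids: the result has shape index.shape + params.shape[1:] = (2, 5, 4)
index = [[1, 2, 3, 4, 0], [3, 4, 2, 1, 0]]
t0 = tf.nn.embedding_lookup(params, [2, 1, 3, 4, 0])  # 1-D ids -> shape (5, 4)
t1 = tf.nn.embedding_lookup(params, index)

with tf.Session() as sess:
    # No tf.Variable is involved, so this initializer is not strictly needed
    sess.run(tf.global_variables_initializer())
    print(sess.run(t0))
    print('=' * 20)
    print(sess.run(t1))

Output: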
[[2 3 4 5]
 [1 2 3 4]
 [3 4 5 6]
 [4 5 6 7]
 [0 0 0 0]]
====================
[[[1 2 3 4]
  [2 3 4 5]
  [3 4 5 6]
  [4 5 6 7]
  [0 0 0 0]]

 [[3 4 5 6]
  [4 5 6 7]
  [2 3 4 5]
  [1 2 3 4]
  [0 0 0 0]]]
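The printed tensors can be cross-checked without TensorFlow. For a single params matrix, embedding_lookup selects rows the same way as gathering along axis 0 (the TensorFlow docs describe it as a generalization of tf.gather). This sketch reproduces t0 and t1 with np.take on the same data as the example above.

```python
import numpy as np

params = np.asarray([[0, 0, 0, 0], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]])
index = [[1, 2, 3, 4, 0], [3, 4, 2, 1, 0]]

# Gather along axis 0: every id is replaced by the row params[id].
t0 = np.take(params, [2, 1, 3, 4, 0], axis=0)  # shape (5, 4)
t1 = np.take(params, index, axis=0)            # shape (2, 5, 4)
print(t0)
print(t1.shape)  # (2, 5, 4)
```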

