Tensorflow 中的embedding_lookup详解

本文将通俗的进行解释embedding_lookup( )的用法 

首先看一段简单代码:

#!/usr/bin/env/python
# coding=utf-8
import tensorflow as tf
import numpy as np

input_ids = tf.placeholder(dtype=tf.int32, shape=[None])

embedding = tf.Variable(np.identity(5, dtype=np.int32))
input_embedding = tf.nn.embedding_lookup(embedding, input_ids)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
print(embedding.eval())
print(sess.run(input_embedding, feed_dict={input_ids:[1, 2, 3, 0, 3, 2, 1]}))

代码中先使用palceholder定义了一个未知变量input_ids用于存储索引,和一个已知变量embedding,是一个5*5的对角矩阵。 
运行结果为:

embedding = [[1 0 0 0 0]
             [0 1 0 0 0]
             [0 0 1 0 0]
             [0 0 0 1 0]
             [0 0 0 0 1]]
input_embedding = [[0 1 0 0 0]
                   [0 0 1 0 0]
                   [0 0 0 1 0]
                   [1 0 0 0 0]
                   [0 0 0 1 0]
                   [0 0 1 0 0]
                   [0 1 0 0 0]]

简单的讲就是根据input_ids中的id,寻找embedding中的对应元素。比如,input_ids=[1,3,5],则找出embedding中下标为1,3,5的向量组成一个矩阵返回。

如果将input_ids改写成下面的格式:

input_embedding = tf.nn.embedding_lookup(embedding, input_ids)
print(sess.run(input_embedding, feed_dict={input_ids:[[1, 2], [2, 1], [3, 3]]}))

输出结果就会变成如下的格式:

[[[0 1 0 0 0]
  [0 0 1 0 0]]
 [[0 0 1 0 0]
  [0 1 0 0 0]]
 [[0 0 0 1 0]
  [0 0 0 1 0]]]

对比上下两个结果不难发现,相当于在np.array中直接采用下标数组获取数据。需要注意的细节是返回的tensor的dtype和传入的被查询的tensor的dtype保持一致;和ids的dtype无关。

你可能感兴趣的:(类讲解)