tf.nn.embedding_lookup函数以及对嵌入表示的理解

假设有一组分类,总共有5个类别,我们对所有类别进行哑编码(one-hot),则编码后[1,0,0,0,0]为类别1,[0,1,0,0,0]为类别2,[0,0,1,0,0]为类别3,[0,0,0,1,0]为类别4,[0,0,0,0,1]为类别5.若类别过多,该如何处理?我们用one-hot表示,就有可能使得训练参数过于庞大,而且不能表示不同类别之间的相似度,于是我们想到了嵌入(Embeding)方式,即用一组更短的向量表示原类别

看图说话:
tf.nn.embedding_lookup函数以及对嵌入表示的理解_第1张图片
类别表示向量从5维的one-hot变成了一个三维的向量,而且这个向量是可以计算类别之间的距离的,通过计算距离,我们就可以判断哪两个类别更为相似,比如苹果、香蕉这两个类别的相似度就应该比苹果、汽车两个类别的相似度高,此时我们应该可以得到苹果、香蕉这两个类别向量的距离更近,从而解决了类别过多时训练参数过多的问题,也解决了类别相似度表示的问题
:向量矩阵是我们提前预训练出来的,与其他权重矩阵一样

tf.nn.embedding_lookup函数如下

tf.nn.embedding_lookup(
    params,
    ids,
    partition_strategy='mod',
    name=None,
    validate_indices=True,
    max_norm=None
)
Args(函数参数):
    params:代表完整的嵌入张量的简单张量,即上面所讲向量矩阵
    ids: 要查找的id编号列表,eg:[0,2],即指要查找类别1、类别3在向量矩阵中的表示向量(下标从0开始)
    partition_strategy:切分策略,在 len(params) > 1 的情况下使用.目前支持两种切分方式:"div"和"mod",默认是"mod".
    name:操作的名称(可选)
    validate_indices: 已弃用
    max_norm: 如果所嵌入向量的l2范数大于max_norm,则所嵌入向量将被裁剪
Returns:
    该函数与 params 中的张量具有相同类型的 Tensor

代码实现:
功能:查询类别1,类别3在向量矩阵中的表示向量
:tf.nn.embedding_lookup省去了one-hot向量与向量矩阵(或者叫嵌入矩阵)相乘的过程,one-hot有很多的0,与0相乘毫无意义,故只需取出one-hot向量中1位置的下标,然后到向量矩阵中查找与该下标位置相乘的嵌入向量就可以了,节约了计算成本

# 查询类别1,类别3在向量矩阵中的表示向量
import numpy as np
import tensorflow as tf
params =[[0.11,0.28,0.48],
         [0.33,0.23,0.45],
         [0.11,0.22,0.49],
         [0.36,0.45,0.88],
         [0.28,0.29,0.36]]  #提前训练好的向量矩阵
params = tf.convert_to_tensor(params)
ids = [0,2] #即指要查找类别1、类别3在向量矩阵中的表示向量(下标从0开始)
lookup_embeding_params =tf.nn.embedding_lookup(
	                             params=params,
                                 ids=ids,
                                 partition_strategy='mod',
                                 name="test_embeding_lookup",
                                 validate_indices=True,
                                 max_norm=None)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("在向量矩阵种查询的类别1,类别3的表示向量为:",
          sess.run(lookup_embeding_params))

打印结果如下:

在向量矩阵种查询的类别1,类别3的表示向量为: [[0.11 0.28 0.48]
 [0.11 0.22 0.49]]

你可能感兴趣的:(深度学习,tensorflow,tensorflow,嵌入查表)