TensorFlow-HashTable: Integrating Feature-ID Mapping into TF

When building deep learning models, you often need to map raw feature values to numeric IDs, which are then converted to dense vectors via tf.nn.embedding_lookup. At serving time, the mapping is typically stored as a hash table (dict), but with many features this quickly becomes painful to manage.

This post walks through how to implement this whole process inside TensorFlow itself.
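To make the goal concrete, here is a minimal sketch of the target pipeline (TF 1.x). The vocabulary size of 10 and the 4-dimensional embeddings are made-up placeholders; ID 0 is reserved as the default for unseen keys so the embedding lookup stays in range:

import tensorflow as tf

# Minimal sketch: raw string feature -> integer ID -> dense vector.
# default_value=0 reserves ID 0 for keys missing from the table.
table = tf.contrib.lookup.MutableHashTable(
    key_dtype=tf.string, value_dtype=tf.int64, default_value=0)
embeddings = tf.get_variable("emb", shape=[10, 4])  # 10 IDs x 4 dims (made up)

features = tf.placeholder(tf.string, shape=[None])
ids = table.lookup(features)                        # string -> int64 ID
vectors = tf.nn.embedding_lookup(embeddings, ids)   # ID -> dense vector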

MutableHashTable

First, the official API documentation:

tf.contrib.lookup.MutableHashTable(
    key_dtype, value_dtype, default_value, name='MutableHashTable', checkpoint=True
)
Args:
key_dtype: the type of the key tensors.
value_dtype: the type of the value tensors.
default_value: the value to use if a key is missing in the table.
name: a name for the operation (optional).
checkpoint: if True, the contents of the table are saved to and restored from checkpoints. If shared_name is empty for a checkpointed table, it is shared using the table node name.

Raises:
ValueError: if checkpoint is True and no name was specified.

Attributes:
key_dtype: the table key dtype.
name: the name of the table.
resource_handle: the resource handle associated with this Resource.
value_dtype: the table value dtype.

It provides the basic hash-table operations (a quick sketch follows the list):

insert: insert key/value pairs
export: export the table contents as (keys, values) tensors
lookup: look up values by key
remove: delete keys
size: the number of entries in the table
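A quick, hedged sketch of lookup, remove, and size in isolation (TF 1.x, with made-up keys and values; remove requires a TF version recent enough to ship it). Note that looking up a missing key falls back to default_value:

import tensorflow as tf

table = tf.contrib.lookup.MutableHashTable(
    key_dtype=tf.string, value_dtype=tf.int64, default_value=-1)

# Insert two entries, then exercise size / lookup / remove.
insert_op = table.insert(tf.constant(["a", "b"]),
                         tf.constant([1, 2], dtype=tf.int64))
with tf.Session() as sess:
    sess.run(insert_op)
    print(sess.run(table.size()))                           # 2
    # "c" is missing, so it falls back to default_value (-1)
    print(sess.run(table.lookup(tf.constant(["a", "c"]))))  # [ 1 -1]
    sess.run(table.remove(tf.constant(["a"])))
    print(sess.run(table.size()))                           # 1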

Demo code

import tensorflow as tf
import time


def demo():
    """
    insert:插入键值对
    export:导出hashtable
    lookup:key查询
    remove:删除key
    size:hashtable的容量
    :return:
    """
    keys = tf.placeholder(dtype=tf.string, shape=[None])
    values = tf.placeholder(dtype=tf.int64, shape=[None])
    # With multiple tables, give each one an explicit name; otherwise they all
    # get the default name and overwrite each other on save/load.
    table1 = tf.contrib.lookup.MutableHashTable(key_dtype=tf.string, value_dtype=tf.int64, default_value=-1,
                                                name="HashTable_1")
    table2 = tf.contrib.lookup.MutableHashTable(tf.string, tf.int64, -1)
    insert_table1 = table1.insert(keys, values)
    insert_table2 = table2.insert(keys, values)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(insert_table1, feed_dict={keys: ["a"], values: [1]})
        sess.run(insert_table2, feed_dict={keys: ["b"], values: [2]})
        print("table1:", sess.run(table1.export()))
        print("table2:", sess.run(table2.export()))
        saver.save(sess, "checkpoint/test")


def run():
    """
    测试50W容量的hashtable,保存的大小和查询速度
    :return:
    """
    size = 500000
    keys = tf.placeholder(dtype=tf.string, shape=[None])
    values = tf.placeholder(dtype=tf.int64, shape=[None])
    # With multiple tables, give each one an explicit name; otherwise they all
    # get the default name and overwrite each other on save/load.
    table1 = tf.contrib.lookup.MutableHashTable(key_dtype=tf.string, value_dtype=tf.int64, default_value=-1,
                                                name="HashTable_1")
    insert_table1 = table1.insert(keys, values)
    lookup = table1.lookup(keys)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(insert_table1, feed_dict={keys: ["id_" + str(i) for i in range(size)], values: list(range(size))})
        # print("table1:", sess.run(table1.export()))

        # Lookup time: 0.007218122482299805 s
        # Checkpoint size: 8.9M
        s1 = time.time()
        print(sess.run(lookup, feed_dict={keys: ["id_1", "id_100"]}))
        print(time.time() - s1)
        saver.save(sess, "checkpoint/test")


if __name__ == '__main__':
    run()
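Because the tables are created with checkpoint=True, their contents ride along in the checkpoint. Here is a hedged sketch of the restore side, under the assumption that you rebuild a graph containing a table with the same name and then call saver.restore:

import tensorflow as tf


def restore():
    # Rebuild the same graph structure; the table name must match the one
    # used when saving ("HashTable_1" above).
    keys = tf.placeholder(dtype=tf.string, shape=[None])
    table1 = tf.contrib.lookup.MutableHashTable(
        key_dtype=tf.string, value_dtype=tf.int64, default_value=-1,
        name="HashTable_1")
    lookup = table1.lookup(keys)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, "checkpoint/test")
        # Should print the values inserted by run(): [  1 100]
        print(sess.run(lookup, feed_dict={keys: ["id_1", "id_100"]}))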
