TensorFlow-tf.nn.embedding_lookup()函数解析

眼看千遍,不如手动一遍,看了原文再手动整理一遍,代码实际操作一遍,加深理解。相当于高中时做的笔记了。

tf.nn.embedding_lookup的用法主要是选取一个张量里面索引对应的元素

原型:

tf.nn.embedding_lookup(params, ids, partition_strategy='mod', name=None, validate_indices=True, max_norm=None)

params 代表输入的张量,ids代表要选取params里对应的那个维度的数据

简单来个例子(粘贴可直接运行)

代码:

import tensorflow as tf

import numpy as np

a = [[0.1, 0.2, 0.3], [1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3]]

a = np.asarray(a)

idx1 = tf.Variable([0, 2, 3, 1], tf.int32)

idx2 = tf.Variable([[0, 2, 3, 1], [4, 0, 2, 2]], tf.int32)

b = [[0.1, 0.2, 1], [2.1, 1.2, 1]]

b = np.asarray(b)

idx3 = tf.placeholder(tf.int32, [None, 3], name="input_x")

out1 = tf.nn.embedding_lookup(a, idx1)

out2 = tf.nn.embedding_lookup(a, idx2)

out3 = tf.nn.embedding_lookup(a, idx3)

init = tf.global_variables_initializer()

with tf.Session() as sess:

    sess.run(init)

    print (sess.run(out1))

    print (out1)

    print ('==================')

    print (sess.run(out2))

    print (out2)

    print (sess.run(out3, feed_dict ={idx3: b}))

    print (out3)

结果:

TensorFlow-tf.nn.embedding_lookup()函数解析_第1张图片

分析:

1.第一个out1代表从a中依次取第 0, 2, 3, 1维数据进行拼装,拼出来的shape还是(4,3)

2.第二个out2代表从a中依次取 第0, 2, 3, 1维数据拼装一个(4,3)的数据 接着再从a中依次取4, 0, 2, 2 维来进行拼装,之后再把两个(4, 3) 拼装在一起形成(2,4,3)的张量(tensor)

3.第三个使用了placeholder来输入ids,placeholder的shape为(?,3),代表从数据里先取3个数据出来,每个数据有3个元素,最后再 ?个(3, 3)拼接在一起组成(?,3,3)的tensor

    有好多小伙伴在公众号给我留言,我没有及时回复,公众号消息我不常看,过了48小时就不能回复了,如果有啥可以先在CSDN评论里提醒我一下哈(20190830)。我是一个菜鸟,需要学习提高!

你可能感兴趣的:(#,TensorFlow其他)