tf1.x slicing analogies
Contents
tf.one_hot to extract rows
gather to extract rows
gather to extract columns
gather_nd
All of these do much the same job as embedding_lookup. Reference: embedding层_tensorflow中的Embedding操作详解_weixin_39835321的博客-CSDN博客
import tensorflow as tf
import numpy as np
embedding1 = tf.constant(
    [
        [0.21, 0.41, 0.51, 0.11],
        [0.22, 0.42, 0.52, 0.12],
        [0.23, 0.43, 0.53, 0.13],
        [0.24, 0.44, 0.54, 0.14]
    ], dtype=tf.float32)
feature_batch = tf.constant([2, 3, 1, 0])                      # the row indices to extract
feature_batch_one_hot = tf.one_hot(feature_batch, depth=4)     # 4x4 one-hot matrix
get_embedding2 = tf.matmul(feature_batch_one_hot, embedding1)  # selects rows 2, 3, 1, 0 of embedding1
In TF 1.x graph mode, the ops above only produce values when they are run inside a session:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    embedding1, embedding2 = sess.run([embedding1, get_embedding2])
    print("embedding1 \n", embedding1)
    print("embedding2 \n", embedding2)
embedding1
[[0.21 0.41 0.51 0.11]
[0.22 0.42 0.52 0.12]
[0.23 0.43 0.53 0.13]
[0.24 0.44 0.54 0.14]]
embedding2
[[0.23 0.43 0.53 0.13]
[0.24 0.44 0.54 0.14]
[0.22 0.42 0.52 0.12]
[0.21 0.41 0.51 0.11]]
# embedding2 is embedding1 with its rows reordered by the indices [2, 3, 1, 0]
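Why the matmul works: the one-hot matrix built from [2, 3, 1, 0] is a permutation matrix, so multiplying by it just reorders rows. A minimal NumPy sketch of the same idea (the names here are illustrative):
import numpy as np
emb = np.array([[0.21, 0.41, 0.51, 0.11],
                [0.22, 0.42, 0.52, 0.12],
                [0.23, 0.43, 0.53, 0.13],
                [0.24, 0.44, 0.54, 0.14]], dtype=np.float32)
idx = [2, 3, 1, 0]
one_hot = np.eye(4, dtype=np.float32)[idx]   # the same 4x4 one-hot matrix
assert np.allclose(one_hot @ emb, emb[idx])  # matmul == plain row indexing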
embedding_lookup
- The embedding_lookup function acts more like a search: given the indices we supply, it fetches the slices at those positions from the given tensor.
- It is a special case of the gather function (see the sketch below).
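A minimal sketch of this equivalence, assuming the same data as embedding1 above (tf.nn.embedding_lookup is the TF 1.x entry point; the variable names here are illustrative):
import tensorflow as tf
emb = tf.constant(
    [
        [0.21, 0.41, 0.51, 0.11],
        [0.22, 0.42, 0.52, 0.12],
        [0.23, 0.43, 0.53, 0.13],
        [0.24, 0.44, 0.54, 0.14]
    ], dtype=tf.float32)
ids = tf.constant([2, 3, 1, 0])
looked_up = tf.nn.embedding_lookup(emb, ids)  # fetch rows 2, 3, 1, 0
with tf.Session() as sess:
    print(sess.run(looked_up))  # same matrix as get_embedding2 above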
# gather, axis=0 (rows)
# when params is a 2-D tensor and axis=0, gather is equivalent to the embedding_lookup function above
embedding = tf.constant(
    [
        [0.21, 0.41, 0.51, 0.11],
        [0.22, 0.42, 0.52, 0.12],
        [0.23, 0.43, 0.53, 0.13],
        [0.24, 0.44, 0.54, 0.14]
    ], dtype=tf.float32)
index_a = tf.Variable([2, 3, 1, 0])
gather_a = tf.gather(embedding, index_a)  # rows of embedding in the order given by index_a
gather reproduces the same result as the tf.one_hot approach above:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(gather_a))
# rows of embedding, arranged in the index order [2, 3, 1, 0]:
[[0.23 0.43 0.53 0.13]
[0.24 0.44 0.54 0.14]
[0.22 0.42 0.52 0.12]
[0.21 0.41 0.51 0.11]]
# gather, axis=1 (columns)
embedding = tf.constant(
    [
        [0.21, 0.41, 0.51, 0.11],
        [0.22, 0.42, 0.52, 0.12],
        [0.23, 0.43, 0.53, 0.13],
        [0.24, 0.44, 0.54, 0.14]
    ], dtype=tf.float32)
gather_a_axis1 = tf.gather(embedding, index_a, axis=1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(gather_a_axis1))
[[0.51 0.11 0.41 0.21]
[0.52 0.12 0.42 0.22]
[0.53 0.13 0.43 0.23]
[0.54 0.14 0.44 0.24]]
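An axis=1 gather is the same as gathering rows of the transpose and then transposing back. A minimal sketch of that equivalence, reusing embedding and index_a from above (gather_cols is an illustrative name):
gather_cols = tf.transpose(tf.gather(tf.transpose(embedding), index_a))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(gather_cols))  # identical to gather_a_axis1 above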
When params is a 1-D tensor, gather picks individual elements:
b = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
index_b = tf.Variable([2, 4, 6, 8])
gather_b = tf.gather(b, index_b)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(gather_b))
[3 5 7 9]
# b[2] is 3
# b[4] is 5
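For a 1-D tensor this is exactly NumPy fancy indexing. A minimal sketch (b_np is an illustrative name):
import numpy as np
b_np = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
print(b_np[[2, 4, 6, 8]])  # [3 5 7 9], same as tf.gather(b, index_b)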
tf.gather can only slice along a single axis. To take slices with a joint index across multiple dimensions, use the gather_nd function.
tf.reset_default_graph()
a = tf.Variable([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]])
index_a = tf.Variable([0, 2])                    # one joint index: a[0][2]
b = tf.get_variable(name='b', shape=[3, 3, 2], initializer=tf.random_normal_initializer())
index_b = tf.Variable([[0, 1, 1], [2, 2, 0]])    # two joint indices: b[0][1][1] and b[2][2][0]
Result:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('-'*10)
    # a[0][2] is 3
    print(sess.run(tf.gather_nd(a, index_a)))  # 3
    print('-'*10)
    print(sess.run(b))  # the randomly initialized b, shape (3, 3, 2)
    print('-'*10)
    # [0, 1, 1] picks b[0][1][1]
    # [2, 2, 0] picks b[2][2][0]
    print(sess.run(tf.gather_nd(b, index_b)))
----------
3
----------
[[[ 0.5142302 -0.05901795]
[-0.04706477 0.08232412]
[-0.00842589 -1.1469455 ]]
[[-0.4118051 -0.87490994]
[-1.5529685 0.5411136 ]
[ 0.49881363 2.527228 ]]
[[ 0.19706753 -1.9549321 ]
[ 0.551086 1.064308 ]
[ 0.22157238 -2.3003275 ]]]
----------
[0.08232412 0.22157238]
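One more property worth knowing: gather_nd indices need not address single elements. When each index is shorter than the tensor's rank, the result contains whole slices, which is how gather_nd generalizes gather. A minimal sketch (m, row_idx, rows are illustrative names):
import tensorflow as tf
tf.reset_default_graph()
m = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
row_idx = tf.constant([[0], [2]])   # one-entry indices -> whole rows
rows = tf.gather_nd(m, row_idx)     # same as tf.gather(m, [0, 2])
with tf.Session() as sess:
    print(sess.run(rows))  # [[1 2 3], [7 8 9]]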