【TensorFlow】gather and gather_nd

TF 1.x — think of these ops as slicing by index

Contents

tf.one_hot: extracting rows

gather: extracting rows

gather: extracting columns

gather_nd


tf.one_hot: extracting rows

Its effect is much the same as an embedding lookup (i.e. what an embedding layer does).

import tensorflow as tf
import numpy as np

embedding1 = tf.constant(
    [
        [0.21,0.41,0.51,0.11],
        [0.22,0.42,0.52,0.12],
        [0.23,0.43,0.53,0.13],
        [0.24,0.44,0.54,0.14]
    ],dtype=tf.float32)
 
feature_batch = tf.constant([2,3,1,0])

# Turn each index into a one-hot row vector of length 4
feature_batch_one_hot = tf.one_hot(feature_batch, depth=4)

# Multiplying the one-hot matrix by embedding1 selects the corresponding rows
get_embedding2 = tf.matmul(feature_batch_one_hot, embedding1)

The tensors above are only evaluated when the graph is run in a session:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    embedding1, embedding2 = sess.run([embedding1, get_embedding2])
    print("embedding1 \n", embedding1)
    print("embedding2 \n", embedding2)


embedding1 
 [[0.21 0.41 0.51 0.11]
 [0.22 0.42 0.52 0.12]
 [0.23 0.43 0.53 0.13]
 [0.24 0.44 0.54 0.14]]


embedding2 
 [[0.23 0.43 0.53 0.13]
 [0.24 0.44 0.54 0.14]
 [0.22 0.42 0.52 0.12]
 [0.21 0.41 0.51 0.11]]
# embedding2 is embedding1 with its rows reordered by the indices [2,3,1,0]
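The one-hot matmul trick above can be sanity-checked with a small NumPy sketch (NumPy is used here only to illustrate the math; it is not part of the original TF code):

```python
import numpy as np

embedding = np.array([
    [0.21, 0.41, 0.51, 0.11],
    [0.22, 0.42, 0.52, 0.12],
    [0.23, 0.43, 0.53, 0.13],
    [0.24, 0.44, 0.54, 0.14],
], dtype=np.float32)

idx = np.array([2, 3, 1, 0])

# np.eye(4)[idx] builds the same matrix as tf.one_hot(idx, depth=4)
one_hot = np.eye(4, dtype=np.float32)[idx]

# Multiplying by a one-hot matrix selects rows: each one-hot row picks one row
picked = one_hot @ embedding

# Identical to direct fancy indexing on the rows
assert np.allclose(picked, embedding[idx])
```

The assertion shows why the matmul works: each one-hot row has a single 1, so the dot product copies exactly one row of the embedding matrix.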

embedding_lookup

- The embedding_lookup function acts like a search operation: given the indices we provide, it retrieves the slices at the corresponding positions of the tensor.

- It is a special case of the gather function.

gather: extracting rows

# gather with axis=0 (rows)
# When params is a 2-D tensor and axis=0, tf.gather is equivalent to embedding_lookup

embedding = tf.constant(
    [
        [0.21,0.41,0.51,0.11],
        [0.22,0.42,0.52,0.12],
        [0.23,0.43,0.53,0.13],
        [0.24,0.44,0.54,0.14]
    ],dtype=tf.float32)
 
index_a = tf.Variable([2,3,1,0])
gather_a = tf.gather(embedding, index_a)


You can see that gather achieves the same effect as the tf.one_hot approach above:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(gather_a))

# The rows of embedding, reordered by the indices [2,3,1,0]:

[[0.23 0.43 0.53 0.13]
 [0.24 0.44 0.54 0.14]
 [0.22 0.42 0.52 0.12]
 [0.21 0.41 0.51 0.11]]
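The same row selection can be written directly in NumPy, which is a handy way to reason about what gather does on axis 0 (a sketch for intuition only):

```python
import numpy as np

embedding = np.array([
    [0.21, 0.41, 0.51, 0.11],
    [0.22, 0.42, 0.52, 0.12],
    [0.23, 0.43, 0.53, 0.13],
    [0.24, 0.44, 0.54, 0.14],
], dtype=np.float32)

# np.take along axis 0 mirrors tf.gather(embedding, [2, 3, 1, 0])
rows = np.take(embedding, [2, 3, 1, 0], axis=0)

# Equivalent to fancy indexing on the first axis
assert np.allclose(rows, embedding[[2, 3, 1, 0]])
```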

gather: extracting columns

# gather with axis=1 (columns); index_a = [2,3,1,0] from above is reused
embedding = tf.constant(
    [
        [0.21,0.41,0.51,0.11],
        [0.22,0.42,0.52,0.12],
        [0.23,0.43,0.53,0.13],
        [0.24,0.44,0.54,0.14]
    ],dtype=tf.float32)
 
gather_a_axis1 = tf.gather(embedding, index_a, axis=1)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(gather_a_axis1))


[[0.51 0.11 0.41 0.21]
 [0.52 0.12 0.42 0.22]
 [0.53 0.13 0.43 0.23]
 [0.54 0.14 0.44 0.24]]
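Column gathering is the same operation applied along axis 1; a NumPy sketch (illustration only, not TF code):

```python
import numpy as np

embedding = np.array([
    [0.21, 0.41, 0.51, 0.11],
    [0.22, 0.42, 0.52, 0.12],
    [0.23, 0.43, 0.53, 0.13],
    [0.24, 0.44, 0.54, 0.14],
], dtype=np.float32)

# Mirrors tf.gather(embedding, [2, 3, 1, 0], axis=1): reorder the columns
cols = np.take(embedding, [2, 3, 1, 0], axis=1)
```

Each row keeps its position but its entries are permuted, so the first row becomes [0.51, 0.11, 0.41, 0.21], matching the TF output above.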

When params is one-dimensional

# When params is a 1-D tensor
b = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
index_b = tf.Variable([2, 4, 6, 8])
gather_b = tf.gather(b, index_b)
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(gather_b))


[3 5 7 9]

# b[2] is 3
# b[4] is 5
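For a 1-D params, gather reduces to plain element lookup; the NumPy equivalent is a one-liner (sketch for intuition):

```python
import numpy as np

b = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

# Mirrors tf.gather(b, [2, 4, 6, 8]): pick the elements at those positions
picked = np.take(b, [2, 4, 6, 8])
print(picked)  # [3 5 7 9]
```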

gather_nd

tf.gather can only take slices along a single axis. To take slices using a joint index across multiple dimensions, use tf.gather_nd.

tf.reset_default_graph()

a = tf.Variable([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]])
index_a = tf.Variable([0, 2])

b = tf.get_variable(name='b', shape=[3, 3, 2],
                    initializer=tf.random_normal_initializer)
index_b = tf.Variable([[0, 1, 1], [2, 2, 0]])

Results

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('-'*10)
    # The index [0, 2] picks a[0][2], which is 3
    print(sess.run(tf.gather_nd(a, index_a))) # 3
    print('-'*10)

    print(sess.run(b)) # b is randomly initialized with shape (3, 3, 2)
    print('-'*10)
    # In b:
    # [0, 1, 1] picks b[0][1][1]
    # [2, 2, 0] picks b[2][2][0]
    print(sess.run(tf.gather_nd(b, index_b)))




----------
3 
----------
[[[ 0.5142302  -0.05901795]
  [-0.04706477  0.08232412]
  [-0.00842589 -1.1469455 ]]

 [[-0.4118051  -0.87490994]
  [-1.5529685   0.5411136 ]
  [ 0.49881363  2.527228  ]]

 [[ 0.19706753 -1.9549321 ]
  [ 0.551086    1.064308  ]
  [ 0.22157238 -2.3003275 ]]]
----------
[0.08232412 0.22157238]
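gather_nd treats the innermost dimension of the index tensor as a full coordinate into params. The same lookups can be written in NumPy by splitting the coordinate columns into an index tuple (a sketch for intuition; the shapes match the TF example above, but here b is filled with arange instead of random values):

```python
import numpy as np

a = np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]])

# A single coordinate [0, 2] picks one scalar, like tf.gather_nd(a, [0, 2])
assert a[0, 2] == 3

# For rank-3 params, each row of index_b is a full (i, j, k) coordinate
b = np.arange(18).reshape(3, 3, 2)
index_b = np.array([[0, 1, 1], [2, 2, 0]])

# Splitting the columns gives b[0,1,1] and b[2,2,0], like tf.gather_nd(b, index_b)
picked = b[index_b[:, 0], index_b[:, 1], index_b[:, 2]]
```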
