tensorflow中的concat、gather、gather_nd函数的使用

tf.concat

tf.concat(values,axis,name='concat')

把一组向量从某一维上拼接起来,很向numpy中的Concatenate,官网例子:

t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 0) ==> [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 1) ==> [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]

# tensor t3 with shape [2, 3]
# tensor t4 with shape [2, 3]
tf.shape(tf.concat([t3, t4], 0)) ==> [4, 3]
tf.shape(tf.concat([t3, t4], 1)) ==> [2, 6]

 

tf.gather

tf.gather(params,indices,validate_indices=None,name=None,axis=None,batch_dims=0)

其中, params must be at least rank axis + 1,axis默认为0。类似于数组的索引,可以把向量中某些索引值提取出来,得到新的向量,适用于要提取的索引为不连续的情况。

import tensorflow as tf

a = tf.Variable([[1,2,3,4,5], [6,7,8,9,10], [11,12,13,14,15]])
index_a = tf.Variable([0,2])

b = tf.Variable([1,2,3,4,5,6,7,8,9,10])
index_b = tf.Variable([2,4,6,8])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("axis=0:\t",sess.run(tf.gather(a, index_a)))  #获取行数为index_a的子数组
    print("axis=1:\t",sess.run(tf.gather(a, index_a,axis=1)))  #列
    print(sess.run(tf.gather(b, index_b)))  #当数组为一维时

'''
axis=0:	 [[ 1  2  3  4  5]
 [11 12 13 14 15]]

axis=1:	 [[ 1  3]
 [ 6  8]
 [11 13]]

[3 5 7 9]
'''

tf.gather_nd

tf.gather_nd(params,indices,name=None,batch_dims=0)

返回值:根据indices的具体索引,取出params对应位置的值。

a = tf.Variable([[1,2,3,4,5], [6,7,8,9,10], [11,12,13,14,15]])
index_a = tf.Variable([[0,2], [0,4], [2,2]])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.gather_nd(a, index_a)))

#  [ 3  5 13]


'''
#另一个例子:
    indices = [[1], [0]]
    params = [['a', 'b'], ['c', 'd']]
    output = [['c', 'd'], ['a', 'b']]
'''

 

应用:

1、可以利用tf.concat函数修改tensor指定item的值,例如:

#把tensor_1的第i项修改为0
tensor_1 = tf.constant([x for x in range(1, 10)])
#tensor_1[4] = 0  #TypeError: 'Tensor' object does not support item assignment
# 将原来的张量拆分为3部分,修改位置前的部分,要修改的部分和修改位置之后的部分
i = 4
part1 = tensor_1[:i]
part2 = tensor_1[i + 1:]
val = tf.constant([0])
new_tensor = tf.concat([part1, val, part2], axis=0)

print('new_tensor',tf.Session().run(new_tensor))  #new_tensor [1 2 3 4 0 6 7 8 9]

2、修改二维数组(x,y)处的值

def set_value_first(matrix, x, y, val):
    # 提取出要更新的行
    row = tf.gather(matrix, x)
    # 构造这行的新数据
    new_row = tf.concat([row[:y], [val], row[y+1:]], axis=0)
    # 使用 tf.scatter_update 方法进行替换
    matrix1=tf.scatter_update(matrix, x, new_row)
    return matrix1

matrix=tf.Variable(tf.ones([3,4]))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    matrix_first=set_value_first(matrix,1,2,5.)
    print("matrix_first:\t",sess.run(matrix_first))


'''
matrix_first:	 [[1. 1. 1. 1.]
 [1. 1. 5. 1.]
 [1. 1. 1. 1.]]
'''

3、获取“数组”指定索引的值

arr=tf.constant([[1,2,3,4,5,5,7,8,9,10],
 [11,12,13,14,15,16,17,18,19,20],
 [21,22,23,24,25,26,27,28,29,30],
 [31,32,33,34,35,36,37,38,39,40]])

row=tf.range(4)
colum=tf.constant([2,3,4,5])

#tf.stack和tf.unstack的使用,详见:https://www.jianshu.com/p/25706575f8d4
ss=tf.stack([row,colum],axis=0)    #构成[[0,1,2,3],[2,3,4,5]]
indexs=tf.unstack(ss,axis=1)     #构成[[0,2],[1,3],[2,4],[3,5]]

newarr=tf.gather_nd(arr,indexs)

with tf.Session() as sess:
    print(sess.run(newarr))


'''
[ 3 14 25 36]
'''

 

参考的以下文章:如有侵权,请联删

Tensorflow常用函数笔记

Tensorflow 实现修改张量特定元素的值方法

修改TensorFlow张量特定位置的值

你可能感兴趣的:(TensorFlow)