【TensorFlow】tf.scatter_update()

在看tensorflow官网的API的时候,看到一个更新数据的函数。该函数的目的是为了能更新tensor的值,这个函数也解决了之前我想要更新tensor值的想法。在网上找了很多关于 tf.scatter_update() 的资料,但是找到的基本都是tensorflow官网上的API介绍和Stack Overflow上的提问,可见关于这个API的中文资料是相当少的,所以我打算写下这篇博客来介绍 tf.scatter_update()。

在这里我简短的介绍一下这个函数的使用:

tf.scatter_update

scatter_update(

ref, 

indices,

updates,

use_locking=None,

name=None 

)

在源码,函数的定义的位置在 tensorflow/python/ops/gen_state_ops.py.

参数介绍:

ref: 原来的tensor;

indices: 原来tensor中要更新的索引值,同样也 tensor;

updates: 用于替代原来tensor的tensor值,注意,这个tensor和原来的tensor的shape要相同。


use_locking=None, name=None,一般情况下,我们使用默认的就好。

返回:依旧是一个tensor,shape和原来的tensor相同,是按照indices更新过tensor值的tensor;


介绍完了这个函数,那么我来举一个示例来让大家明白怎么去用这个函数。

代码如下:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    a = tf.Variable(initial_value=[[0, 0, 0, 0],[0, 0, 0, 0]])
    b = tf.scatter_update(a, [0, 1], [[1, 1, 0, 0], [1, 0, 4, 0]])

with tf.Session(graph=g) as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(a))
    print(sess.run(b))


输出:

[[0 0 0 0]
 [0 0 0 0]]
[[1 1 0 0]
 [1 0 4 0]]

我们能看到原来的tensor是

[[0 0 0 0]

 [0 0 0 0]]

更新tensor值后的tensor是

[[1 1 0 0]

 [1 0 4 0]]


总结:1、对于tf.scatter_update()来说,ref和updates的shape一定要相同,要不然会报错;

   2、indices也是一个tensor,我们需要更新哪一维就写哪一维;

   3、这样的方式适合更新整个tensor的值,特别适合批量化更新tensor;















你可能感兴趣的:(TensorFlow)