Tensorflow深度学习之三十二: tf.scatter_nd_update

一、tf.scatter_nd_update
   函数定义:

tf.scatter_nd_update(
    ref,
    indices,
    updates,
    use_locking=True,
    name=None
)

   Applies sparse updates to individual values or slices in a Variable.
   将稀疏updates应用于变量中的单个值或切片。

   ref is a Tensor with rank P and indices is a Tensor of rank Q.
   ref 是一个rank为P的 Tensorindices 是一个rank为Q的 Tensor.

   indices must be integer tensor, containing indices into ref. It must be shape [d_0, ..., d_{Q-2}, K] where 0 < K <= P.
   indices 必须是一个由整数组成的tensor,包含了参考于 ref 的索引信息。shape必须满足下列条件:[d_0, ..., d_{Q-2}, K] ,其中 0 < K <= P.

   The innermost dimension of indices (with length K) corresponds to indices into elements (if K = P) or slices (if K < P) along the Kth dimension of ref.
   indices 的最内层维度(长度为 K)对应于沿着 refK 维度的元素(如果是 K = P)或切片(如果是 K )的索引。
   updates is Tensor of rank Q-1+P-K with shape:
   updates 是一个rank为 Q-1+P-K,shape如下:

[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].

   For example, say we want to update 4 scattered elements to a rank-1 tensor to 8 elements. In Python, that update would look like this:
   例如,假设我们想要将4个分散元素更新为1级张量到8个元素。 在Python中,该更新将如下所示:

import tensorflow as tf

ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
update = tf.scatter_nd_update(ref, indices, updates)
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print sess.run(update)

   The resulting update to ref would look like this:
   ref 更新结果如下:

[1, 11, 3, 10, 9, 6, 7, 12]

   See tf.scatter_nd for more details about how to make updates to slices.
   有关如何更新切片的更多详细信息,请参阅 tf.scatter_nd

二、参数

参数
ref A Variable. 一个变量
indices A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref. 一个 Tensor,必须为以下的数据类型: int32, int64, 参考于ref的索引张量。
updates A Tensor. Must have the same type as ref. A tensor of updated values to add to ref. 一个Tensor,必须和ref保持相同的数据类型,含义为需要在ref中更新的值。
use_locking An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. 一个可选的布尔数据, 默认为True。 如果为True,则分配将受锁保护; 否则行为未定义,但可能表现出较少的争用。
name A name for the operation (optional). 名称,可选。

返回值:
   The value of the variable after the update.
   变量更新后的值。(注:返回的是一个Tensor,而不是一个Variable!)

三、代码示例
   该函数的含义还是很好理解的,给定一些特定位置的数据进行更新,函数的意义以及参数在前面都有都有说明。这种稀疏更新的方法在numpy中普遍存在,但是在TensorFlow中,使用起来却十分别扭,往往还需要借助py_func来改写。不过有了这个函数就会方便很多。

  有关返回值:

import tensorflow as tf
import tensorflow.contrib.eager as tfe

ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
update = tf.scatter_nd_update(ref, indices, updates)
print(update)
print(type(update))

   结果如下:

Tensor("ScatterNdUpdate:0", shape=(8,), dtype=int32_ref)

   二维矩阵的更新:

import tensorflow as tf
import tensorflow.contrib.eager as tfe
import numpy as np
tf.enable_eager_execution()

ref = tfe.Variable(np.zeros(shape=[6, 6], dtype=np.float32))
indices = tf.constant([[4, 4], [3, 0], [1, 5], [5, 0]], dtype=tf.int32)
updates = tf.constant([9, 10, 11, 12], dtype=tf.float32)
update = tf.scatter_nd_update(ref, indices, updates)
print(update)

   结果如下:

tf.Tensor(
[[ 0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0. 11.]
 [ 0.  0.  0.  0.  0.  0.]
 [10.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  9.  0.]
 [12.  0.  0.  0.  0.  0.]], shape=(6, 6), dtype=float32)

   三维矩阵的更新1:

import tensorflow as tf
import tensorflow.contrib.eager as tfe
import numpy as np
tf.enable_eager_execution()

ref = tfe.Variable(np.zeros(shape=[6, 6, 3], dtype=np.float32))
indices = tf.constant([[4, 4, 1], [3, 0, 2], [1, 5, 0], [5, 0, 1]], dtype=tf.int32)
updates = tf.constant([9, 10, 11, 12], dtype=tf.float32)
update = tf.scatter_nd_update(ref, indices, updates)
print(update)

   结果如下:

tf.Tensor(
[[[ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]]

 [[ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [11.  0.  0.]]

 [[ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]]

 [[ 0.  0. 10.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]]

 [[ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  9.  0.]
  [ 0.  0.  0.]]

 [[ 0. 12.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]]], shape=(6, 6, 3), dtype=float32)

   三维矩阵的更新2:

import tensorflow as tf
import tensorflow.contrib.eager as tfe
import numpy as np
tf.enable_eager_execution()

ref = tfe.Variable(np.zeros(shape=[6, 6, 3], dtype=np.float32))
indices = tf.constant([[4, 4], [3, 0], [1, 5], [5, 0]], dtype=tf.int32)
updates = tf.constant([[9, 9, 92], [10, 10, 120], [11, 11, 151], [12, 12, 55]], dtype=tf.float32)
update = tf.scatter_nd_update(ref, indices, updates)
print(update)

   结果如下:

tf.Tensor(
[[[  0.   0.   0.]
  [  0.   0.   0.]
  [  0.   0.   0.]
  [  0.   0.   0.]
  [  0.   0.   0.]
  [  0.   0.   0.]]

 [[  0.   0.   0.]
  [  0.   0.   0.]
  [  0.   0.   0.]
  [  0.   0.   0.]
  [  0.   0.   0.]
  [ 11.  11. 151.]]

 [[  0.   0.   0.]
  [  0.   0.   0.]
  [  0.   0.   0.]
  [  0.   0.   0.]
  [  0.   0.   0.]
  [  0.   0.   0.]]

 [[ 10.  10. 120.]
  [  0.   0.   0.]
  [  0.   0.   0.]
  [  0.   0.   0.]
  [  0.   0.   0.]
  [  0.   0.   0.]]

 [[  0.   0.   0.]
  [  0.   0.   0.]
  [  0.   0.   0.]
  [  0.   0.   0.]
  [  9.   9.  92.]
  [  0.   0.   0.]]

 [[ 12.  12.  55.]
  [  0.   0.   0.]
  [  0.   0.   0.]
  [  0.   0.   0.]
  [  0.   0.   0.]
  [  0.   0.   0.]]], shape=(6, 6, 3), dtype=float32)

你可能感兴趣的:(深度学习,Tensorflow)