一、tf.scatter_nd_update
函数定义:
tf.scatter_nd_update(
ref,
indices,
updates,
use_locking=True,
name=None
)
Applies sparse updates
to individual values or slices in a Variable.
将稀疏updates
应用于变量中的单个值或切片。
ref
is a Tensor
with rank P and indices
is a Tensor
of rank Q.
ref
是一个rank为P的 Tensor
,indices
是一个rank为Q的 Tensor
.
indices
must be integer tensor, containing indices into ref
. It must be shape [d_0, ..., d_{Q-2}, K]
where 0 < K <= P
.
indices
必须是一个由整数组成的tensor,包含了参考于 ref
的索引信息。shape必须满足下列条件:[d_0, ..., d_{Q-2}, K]
,其中 0 < K <= P
.
The innermost dimension of )的索引。 For example, say we want to update 4 scattered elements to a rank-1 tensor to 8 elements. In Python, that update would look like this: The resulting update to See 二、参数 返回值: 三、代码示例 有关返回值: 结果如下: 二维矩阵的更新: 结果如下: 三维矩阵的更新1: 结果如下: 三维矩阵的更新2: 结果如下:indices
(with length K
) corresponds to indices into elements (if K = P
) or slices (if K < P
) along the K
th dimension of ref
.
indices
的最内层维度(长度为 K
)对应于沿着 ref
的 K
维度的元素(如果是 K = P
)或切片(如果是 K
updates
is Tensor
of rank Q-1+P-K
with shape:
updates
是一个rank为 Q-1+P-K
,shape如下:[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
例如,假设我们想要将4个分散元素更新为1级张量到8个元素。 在Python中,该更新将如下所示:import tensorflow as tf
ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
update = tf.scatter_nd_update(ref, indices, updates)
with tf.Session() as sess:
tf.global_variables_initializer().run()
print sess.run(update)
ref
would look like this:
ref
更新结果如下:[1, 11, 3, 10, 9, 6, 7, 12]
tf.scatter_nd
for more details about how to make updates to slices.
有关如何更新切片的更多详细信息,请参阅 tf.scatter_nd
。
参数
ref
A Variable.
一个变量
indices
A
Tensor
. Must be one of the following types: int32
, int64
. A tensor of indices into ref.一个
Tensor
,必须为以下的数据类型: int32
, int64
, 参考于ref的索引张量。
updates
A
Tensor
. Must have the same type as ref
. A tensor of updated values to add to ref.一个
Tensor
,必须和ref
保持相同的数据类型,含义为需要在ref中更新的值。
use_locking
An optional
bool
. Defaults to True
. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.一个可选的布尔数据, 默认为True。 如果为True,则分配将受锁保护; 否则行为未定义,但可能表现出较少的争用。
name
A name for the operation (optional).
名称,可选。
The value of the variable after the update.
变量更新后的值。(注:返回的是一个Tensor
,而不是一个Variable
!)
该函数的含义还是很好理解的,给定一些特定位置的数据进行更新,函数的意义以及参数在前面都有都有说明。这种稀疏更新的方法在numpy中普遍存在,但是在TensorFlow中,使用起来却十分别扭,往往还需要借助py_func来改写。不过有了这个函数就会方便很多。import tensorflow as tf
import tensorflow.contrib.eager as tfe
ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
update = tf.scatter_nd_update(ref, indices, updates)
print(update)
print(type(update))
Tensor("ScatterNdUpdate:0", shape=(8,), dtype=int32_ref)
import tensorflow as tf
import tensorflow.contrib.eager as tfe
import numpy as np
tf.enable_eager_execution()
ref = tfe.Variable(np.zeros(shape=[6, 6], dtype=np.float32))
indices = tf.constant([[4, 4], [3, 0], [1, 5], [5, 0]], dtype=tf.int32)
updates = tf.constant([9, 10, 11, 12], dtype=tf.float32)
update = tf.scatter_nd_update(ref, indices, updates)
print(update)
tf.Tensor(
[[ 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 11.]
[ 0. 0. 0. 0. 0. 0.]
[10. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 9. 0.]
[12. 0. 0. 0. 0. 0.]], shape=(6, 6), dtype=float32)
import tensorflow as tf
import tensorflow.contrib.eager as tfe
import numpy as np
tf.enable_eager_execution()
ref = tfe.Variable(np.zeros(shape=[6, 6, 3], dtype=np.float32))
indices = tf.constant([[4, 4, 1], [3, 0, 2], [1, 5, 0], [5, 0, 1]], dtype=tf.int32)
updates = tf.constant([9, 10, 11, 12], dtype=tf.float32)
update = tf.scatter_nd_update(ref, indices, updates)
print(update)
tf.Tensor(
[[[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]]
[[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[11. 0. 0.]]
[[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]]
[[ 0. 0. 10.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]]
[[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 9. 0.]
[ 0. 0. 0.]]
[[ 0. 12. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]]], shape=(6, 6, 3), dtype=float32)
import tensorflow as tf
import tensorflow.contrib.eager as tfe
import numpy as np
tf.enable_eager_execution()
ref = tfe.Variable(np.zeros(shape=[6, 6, 3], dtype=np.float32))
indices = tf.constant([[4, 4], [3, 0], [1, 5], [5, 0]], dtype=tf.int32)
updates = tf.constant([[9, 9, 92], [10, 10, 120], [11, 11, 151], [12, 12, 55]], dtype=tf.float32)
update = tf.scatter_nd_update(ref, indices, updates)
print(update)
tf.Tensor(
[[[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]]
[[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 11. 11. 151.]]
[[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]]
[[ 10. 10. 120.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]]
[[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 9. 9. 92.]
[ 0. 0. 0.]]
[[ 12. 12. 55.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]]], shape=(6, 6, 3), dtype=float32)