tf.scatter_nd详解

关于tf.sactter找了好多blog都没有找到比较详细的说明,一般都是翻译一下官方文档的列子,只对简单情况做了说明,但是稍微复杂一点的没做解释。自己英文没怎么看明白,所以就自己实验一下这个函数到底怎么玩的。下面具体看例子。

简单情况

indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
shape = tf.constant([8])
scatter = tf.scatter_nd(indices, updates, shape)
print(scatter)

结果如下:

[0, 11, 0, 10, 9, 0, 0, 12]

如上所示,indice是shape中的索引,shape是初始化一个0矩阵,将updates中的值按照indices插到具体的位置上。

再看一个复杂情况:

shape = [4, 4]
indices = np.array([[[0,1],[1,3]], 
                    [[2,2],[2,2]]])
indices = tf.constant(indices)
update = tf.constant(np.arange(4).reshape(2,2))
tf.scatter_nd(indices, update, [4,4])

结果如下:

array([[0, 0, 0, 0],
       [0, 0, 0, 1],
       [0, 0, 5, 0],
       [0, 0, 0, 0]])>

结果就很显然了,对于updates来说,update的每一个值的对应的位置对应一个indice的索引,索引是指向shape,例如,shape[0, 1] = update[0,0], shape[2,2] = update[1,0] + update[1,1], indices比update多一个维度。

你可能感兴趣的:(tf.scatter_nd详解)