tf.gather_nd()的用法

定义

def gather_nd(params, indices, name=None)

功能

根据indeces描述的索引,在params中提取元素,重新组成一个tansor

举例

image.png
data shape is (3, 2, 3)
data rank is 3

indices = np.array([[0, 1], [1, 0]])
indices shape is (2, 2)

最后的切片的结果是indices中表示索引的部分被提取到的值替换后得到的结构。

以上面的例子说明这个思路:

image.png

[0, 1]索引得到[2, 2, 2]
[1, 0]索引得到[3, 3, 3]

把索引的结果替换到indices中得到:[[2, 2, 2], [3, 3, 3]]

当索引indices为 [[[[1,1]]]]时,
先找出[1, 1]的索引结果为[4,4,4]
替换到上面结构中得到 [[[[4, 4, 4]]]]

举例

nn_pts = tf.gather_nd(pts, indices, name=tag + 'nn_pts')  # (N, P, K, 3)

其中:

nn_pts.shape is (32, 1024, 3)
indices.shape is (32, 512, 32, 2)
output.shape is (32, 512, 32, 3)

已知 nn_pts的最小component是某个点的坐标(x, y, z),即 3 代表的含义。
indices的最小component是(a, b), 就是说要取nn_pts第0维的第 a 个,第1维的第b个值,取出来的这个值是一个point。
一共取了(32, 512, 32)这么多的point,所以最后的output形状为(32, 512, 32, 3)

你可能感兴趣的:(tf.gather_nd()的用法)