计算机视觉知识点整理:PointNet++之球查询(query ball)代码理解

# query_ball_point函数用于寻找球形邻域中的点。
# 输入中radius为球形邻域的半径,nsample为每个邻域中要采样的点,
# new_xyz为centroids点的数据,xyz为所有的点云数据
# 输出为每个样本的每个球形邻域的nsample个采样点集的索引[B,S,nsample]
def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    # sqrdists: [B, S, N] 记录S个中心点(new_xyz)与所有点(xyz)之间的欧几里德距离
    sqrdists = square_distance(new_xyz, xyz)
    # 找到所有距离大于radius^2的点,其group_idx直接置为N;其余的保留原来的值
    group_idx[sqrdists > radius ** 2] = N
    # 做升序排列,前面大于radius^2的都是N,会是最大值,所以直接在剩下的点中取出前nsample个点
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    # 考虑到有可能前nsample个点中也有被赋值为N的点(即球形区域内不足nsample个点),
    # 这种点需要舍弃,直接用第一个点来代替即可
    # group_first: 实际就是把group_idx中的第一个点的值复制;为[B, S, K]的维度,便于后面的替换
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    # 找到group_idx中值等于N的点
    mask = group_idx == N
    # 将这些点的值替换为第一个点的值
    group_idx[mask] = group_first[mask]
    return group_idx  # S个group

你可能感兴趣的:(计算机视觉,PointNet++)