SuperPoint 特征点 NMS 代码分析

1. 简介

  SuperPoint 网络中所用到的 NMS 代码部分,输入为 scores,即输入图像每个像素的得分,输出依然是 scores,准特征点位置的得分保留,其余位置的得分置零。
  大体上是根据超参数 nms_radius 使图像每个局部区域都得到一个准特征点,后续再用得分阈值筛选等,从而去除距离很近的冗余特征点。

SuperPoint 中 forward 部分涉及 NMS 的代码

# Compute the dense keypoint scores
cPa = self.relu(self.convPa(x))
scores = self.convPb(cPa)
scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
b, _, h, w = scores.shape
scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8)

scores = simple_nms(scores, self.config['nms_radius'])
# Extract keypoints
keypoints = [
    torch.nonzero(s > self.config['keypoint_threshold'])
    for s in scores]
scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]

2. 代码

def simple_nms(scores, nms_radius: int):
    """ Fast Non-maximum suppression to remove nearby points """
    assert(nms_radius >= 0)

    def max_pool(x):
        return torch.nn.functional.max_pool2d(
            x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius)

    zeros = torch.zeros_like(scores)
    max_mask = scores == max_pool(scores)
    for _ in range(2):
        supp_mask = max_pool(max_mask.float()) > 0
        supp_scores = torch.where(supp_mask, zeros, scores)
        new_max_mask = supp_scores == max_pool(supp_scores)
        max_mask = max_mask | (new_max_mask & (~supp_mask))
    return torch.where(max_mask, scores, zeros)

3. 代码解析与测试

超参数:nms_radius = 4
第一步:
max_mask = scores == max_pool(scores),当该点得分在以自身为中心的9*9范围内最大则为True,得到 初代特征点

第二步:
supp_mask = max_pool(max_mask.float()) > 0,把max_mask中为True的点的周围9*9范围的值都变成True
supp_scores = torch.where(supp_mask, zeros, scores),把supp_mask中为True的点置零,False的点取原scores对应的值
new_max_mask = supp_scores == max_pool(supp_scores),和第一步作用类似,这里整体的作用是将 初代特征点 范围内的得分都置零以后,用剩余范围里的得分得到 二代特征点
max_mask = max_mask | (new_max_mask & (~supp_mask)),合并 初代特征点二代特征点

第三步:
迭代

  由此,NMS所得到的特征点会比较均匀的分布在整个图像(应该每个特征点以自身为中心的9*9范围内不会有别的特征点),后续代码通过得分阈值筛选和得分排序去除可能性较低的特征点。

原始图像:
SuperPoint 特征点 NMS 代码分析_第1张图片
图像 + 特征点:
SuperPoint 特征点 NMS 代码分析_第2张图片
原始 scores(将 scores 按最大最小值等比缩放至0~255可视化):
SuperPoint 特征点 NMS 代码分析_第3张图片
NMS 后 scores(这里为了对比明显,把 nms_radius 调大为16):
SuperPoint 特征点 NMS 代码分析_第4张图片

你可能感兴趣的:(特征点检测与匹配,深度学习,计算机视觉,python,pytorch)