SuperPoint 网络中所用到的 NMS 代码部分,输入为 scores,即输入图像每个像素的得分,输出依然是 scores,准特征点位置的得分保留,其余位置的得分置零。
大体上是根据超参数 nms_radius 使图像每个局部区域都得到一个准特征点,后续再用得分阈值筛选等,从而去除距离很近的冗余特征点。
SuperPoint 中 forward 部分涉及 NMS 的代码
# Compute the dense keypoint scores
cPa = self.relu(self.convPa(x))
scores = self.convPb(cPa)
scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
b, _, h, w = scores.shape
scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8)
scores = simple_nms(scores, self.config['nms_radius'])
# Extract keypoints
keypoints = [
torch.nonzero(s > self.config['keypoint_threshold'])
for s in scores]
scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
def simple_nms(scores, nms_radius: int):
""" Fast Non-maximum suppression to remove nearby points """
assert(nms_radius >= 0)
def max_pool(x):
return torch.nn.functional.max_pool2d(
x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius)
zeros = torch.zeros_like(scores)
max_mask = scores == max_pool(scores)
for _ in range(2):
supp_mask = max_pool(max_mask.float()) > 0
supp_scores = torch.where(supp_mask, zeros, scores)
new_max_mask = supp_scores == max_pool(supp_scores)
max_mask = max_mask | (new_max_mask & (~supp_mask))
return torch.where(max_mask, scores, zeros)
超参数:nms_radius = 4
第一步:
max_mask = scores == max_pool(scores)
,当该点得分在以自身为中心的9*9范围内最大则为True
,得到 初代特征点
第二步:
supp_mask = max_pool(max_mask.float()) > 0
,把max_mask
中为True
的点的周围9*9范围的值都变成True
supp_scores = torch.where(supp_mask, zeros, scores)
,把supp_mask
中为True
的点置零,False
的点取原scores
对应的值
new_max_mask = supp_scores == max_pool(supp_scores)
,和第一步作用类似,这里整体的作用是将 初代特征点 范围内的得分都置零以后,用剩余范围里的得分得到 二代特征点
max_mask = max_mask | (new_max_mask & (~supp_mask))
,合并 初代特征点 与 二代特征点
第三步:
迭代
由此,NMS所得到的特征点会比较均匀的分布在整个图像(应该每个特征点以自身为中心的9*9范围内不会有别的特征点),后续代码通过得分阈值筛选和得分排序去除可能性较低的特征点。
原始图像:
图像 + 特征点:
原始 scores(将 scores 按最大最小值等比缩放至0~255可视化):
NMS 后 scores(这里为了对比明显,把 nms_radius 调大为16):