ssd.pytorch源码分析(四)—default boxes与真实目标的匹配

匹配函数源码
SSD论文链接

default boxes与真实目标匹配介绍

在SSD算法的训练阶段涉及到一种匹配策略。具体来讲,为了计算损失函数,必须要选取一个预测框和一个真实框,两者匹配后其差异体现在损失函数中,这样才可以进行训练。但是default boxes和真实目标都不是唯一的,如何才可以众里寻他千百度,找到一对有缘的框呢?

论文描述:
ssd.pytorch源码分析(四)—default boxes与真实目标的匹配_第1张图片
作者分为2步:

  • a. 首先对每一个ground truth框匹配一个与之iou最大的default box
  • b. 然后对每一个default box与超过0.5iou的任何ground truth匹配

刚开始我是懵逼的,第a步不是匹配好了么,为什么还要匹配。后来结合论文和代码,作者的意思为:匹配的最终目的是“ determine which default boxes correspond to a ground truth detection”,也就是要决定哪些默认框足够“优秀”,并匹配给真实目标。这里有2个细节:1、并不是所有默认框都匹配给了真实目标;2、并不是一个真实目标只能匹配一个默认框。作者也提到了这样的好处:This simplifies the learning problem, allowing the network to predict high scores for multiple overlapping default boxes rather than requiring it to pick only the one with maximum overlap.简化了学习问题,一个真实目标可以匹配多个默认框而不是一个。为了能对应得上代码,我对上述操作进行总结:

  • 一、对每个default box匹配一个最高iou的ground truth;(b)
  • 二、对每一个ground truth框匹配一个与之iou最大的default box;(a)
  • 三、第二步的匹配结果覆盖第一步;
  • 四、匹配结果中除第二步的默认框外,将iou小于阈值的默认框设定为背景,即类别标签为0。(b)

代码解析

对于大型矩阵计算,很重要的一点是注意检查tensor的维度是否正确。因此我时不时地在下面的代码中进行维度信息的注释以帮助读者更好地理解代码。
(注意prior boxes就是default boxes也就是anchor)

"""比较重要的维度信息:
	num_objects:一张图中真实框ground truth的数量;
	num_priors:一张图中默认框default boxes的数量,包括所有层的特征图;
	batch:batch size;
"""

1、

对应于上述a、b步。

# 首先并行计算所有默认框和真实框之间两两的IOU 
#[num_objects,num_priors] truth,defaults
#point_form:[cx,cy,w,h]->[x1,y1,x2,y2]
overlaps = jaccard( truths , point_form(priors) )

# 互相匹配
# [num_objects,1] 每个真实框匹配一个最佳默认框
best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
# [1,num_priors]  每个默认框匹配一个最佳真实框
best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
best_truth_idx.squeeze_(0)
best_truth_overlap.squeeze_(0) #[num_priors]
best_prior_idx.squeeze_(1)
best_prior_overlap.squeeze_(1) #[num_objects]

对于代码中best_prior、best_truth等,为了方便理解,可参照下图:
图中横竖两个绿色的tensor分别代表为best_truth和best_prior,分别存储了每个默认框的最佳真实目标和每个真实目标的最佳默认框
ssd.pytorch源码分析(四)—default boxes与真实目标的匹配_第2张图片

2、

将各个best_prior的best_truth_overlap都修改为2,确保匹配的框不会因为iou太低而在第4步中被过滤掉。

best_truth_overlap.index_fill_(0, best_prior_idx, 2) 

3、

每一个真实框覆盖其匹配到的默认框匹配的真实框,保证每个真实目标和最大iou的框匹配。这一步对应上述c步。

for j in range(best_prior_idx.size(0)):
    best_truth_idx[best_prior_idx[j]] = j

4、

对应上述d步,给一个阈值筛掉不够“优秀”的默认框。

matches = truths[best_truth_idx]# [num_priors,4] 对每一个默认框都匹配一个真实框
conf = labels[best_truth_idx] + 1#[num_priors] 每个默认框匹配到的真实框的类
conf[best_truth_overlap < threshold] = 0
#如果匹配的iou小于阈值则定为背景,因此可以看出即使所有的默认框都匹配到了真实目标,但是由于有些匹配的iou太低仍被认为是负样本

完整代码

def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx):
    """
    输入:
        threshold:匹配boxes的阈值.
        truths: Ground truth boxes  [num_objects,4]
        priors: Prior boxes from priorbox layers, [num_priors,4].
        variances: bbox回归时需要用到的参数,[num_priors, 4].
        labels: Ground truth boxes的类别标签, [num_objects,1].
        loc_t: 存储匹配后各default boxes的offset信息 [batch, num_priors, 4]
        conf_t: 存储匹配后各default boxes的真实类别标记 [batch, num_priors]
        idx: (int) current batch index
    返回:
        函数本身不返回值,但它会把匹配框的位置和置信信息存储在loc_t, conf_t两个tensor中。
    """
    overlaps = jaccard(  #[num_objects,num_priors] truth,defaults
        truths,
        point_form(priors)
    )
    # 互相匹配
    # [num_objects,1]每个真实框对应的默认框
    best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
    # [1,num_priors]每个默认框对应的真实框
    best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
    best_truth_idx.squeeze_(0)
    best_truth_overlap.squeeze_(0) #[num_priors]
    best_prior_idx.squeeze_(1)
    best_prior_overlap.squeeze_(1) #[num_objects]
   
    #各个best_prior的best_truth_overlap都修改为2,确保匹配的框不会因为阈值太低被过滤掉
    best_truth_overlap.index_fill_(0, best_prior_idx, 2) 

    #每一个真实框覆盖其匹配到的默认框匹配的真实框
    for j in range(best_prior_idx.size(0)):
        best_truth_idx[best_prior_idx[j]] = j
        
    matches = truths[best_truth_idx] #[num_priors,4] 对每一个默认框都匹配一个真实框
    conf = labels[best_truth_idx] + 1 #[num_priors] 每个默认框匹配到的真实框的类
    conf[best_truth_overlap < threshold] = 0  #如果匹配的iou小于阈值则定为背景
    loc = encode(matches, priors, variances) #编码,用于训练阶段,生成matched和默认框之间的offset
    loc_t[idx] = loc    # [num_priors,4] encoded offsets to learn
    conf_t[idx] = conf  # [num_priors] top class label for each prior

你可能感兴趣的:(ssd.pytorch源码分析)