MultiBoxLoss源码
SSD论文链接
本文代码涉及很多复杂矩阵索引操作,推荐阅读。
在SSD中,默认框default boxes和真实目标ground truth先进行匹配。
匹配策略细节
然后根据匹配到的一对boxes分别计算分类损失和定位损失。
从上面的描述可以看出,可能有多个default boxes匹配到一个ground truth的情况。其中α为权重系数,论文和代码中取1。
代码中定义了MultiBoxLoss类,其父类为torch.nn.model。"__ init __"函数如下:
class MultiBoxLoss(nn.Module):
def __init__(self, num_classes, overlap_thresh, prior_for_matching,
bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
use_gpu=True):
#(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5,False, args.cuda)
super(MultiBoxLoss, self).__init__()
self.use_gpu = use_gpu
self.num_classes = num_classes
self.threshold = overlap_thresh #匹配时需要的iou阈值
self.negpos_ratio = neg_pos #需要训练的负正样本比例
self.variance = cfg['variance']
forward函数中,为计算损失函数,需要先对数据进行包括匹配、正样本寻找等的操作:
def forward(self, predictions, targets):
"""forward函数第一部分的内容
输入:
predictions (tuple): 一个三元素的元组,包含了预测信息.
loc_data [batch,num_priors,4] 所有默认框预测的offsets.
conf_data [batch,num_priors,num_classes] 所有预测框预测的分类置信度.
priors [num_priors,4] 所有默认框的位置
targets [batch,num_objs,5] (last idx is the label).所有真实目标的信息
返回:
loss_l, loss_c:定位损失和分类损失
"""
loc_data, conf_data, priors = predictions
num = loc_data.size(0) #batchsize
num_priors = (priors.size(0))
num_classes = self.num_classes
"""每个default box匹配一个gt
具体分析见:ssd.pytorch源码分析(四)"""
#[batch, num_priors, 4] 匹配到的真实目标和默认框之间的offset,是learning target
loc_t = torch.Tensor(num, num_priors, 4)
#[batch, num_priors] 匹配后默认框的类别,是learning target
conf_t = torch.LongTensor(num, num_priors)
#对于batch中的每一个图片进行匹配
for idx in range(num):
truths = targets[idx][:, :-1].data #[num_objs,4]
labels = targets[idx][:, -1].data #[num_objs,1]
defaults = priors.data
match(self.threshold, truths, defaults, self.variance, labels,
loc_t, conf_t, idx)
if self.use_gpu:
loc_t = loc_t.cuda()
conf_t = conf_t.cuda()
loc_t = Variable(loc_t, requires_grad=False)
conf_t = Variable(conf_t, requires_grad=False)
#正样本查找,等于0为背景 [batch, num_priors]
pos = conf_t > 0
num_pos = pos.sum(dim=1, keepdim=True)
"""接下来的操作为损失函数的计算等"""
对于默认框定位,论文还是采取了anchor-based检测算法中最常用的bounding box回归,损失函数也采用了和RCNN系列一样的smooth_l1_loss。
对于下图公式中的g和l,对应在代码中分别代表已经encode完成的offset。(g代表一对匹配的真实框和默认框的offset,对应代码中的loc_t,l代表预测框和默认框之间的offset,对应代码中的loc_p)。
"""forward函数第二部分内容"""
# Localization Loss (Smooth L1)
pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) #[batch, num_priors, 4] 4层都相同 正例索引
#[batch*num_positive, 4] loc_data保存了所有默认框的predict offset,loc_p保存其中的正例
loc_p = loc_data[pos_idx].view(-1, 4)
#[batch*num_positive, 4] loc_t保存了所有默认框的target offset,loc_t保存其中的正例
loc_t = loc_t[pos_idx].view(-1, 4)
#计算损失
loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False)
对默认框与真实对象之间匹配后会发现,大部分默认框仍然是背景,正负样本(所有默认框=正样本+负样本)数量差异悬殊。如果将所有默认框拿来训练,将导致对负样本的过拟合。因此只需要“挖掘”那些分类损失最大的负样本来训练,其数量为正样本的三倍。
注意下面代码中的loss_c是在难负样本挖掘中用来给默认框排序的,还不是最终的分类损失loss_class。
"""forward函数第三部分内容"""
# 难负样本挖掘的依据:loss_c
#[batch*num_priors , num_classes]
batch_conf = conf_data.view(-1, self.num_classes)
#[batch*num_priors] 计算所有默认框的分类损失
loss_c = log_sum_exp(batch_conf) -
batch_conf[torch.arange(0,num*num_priors),conf_t.view(-1, 1)]
# [batch*num_priors] 因为是给负样本排序的,所以手动给正样本损失置0
loss_c[pos.view(-1, 1)] = 0
loss_c = loss_c.view(num, -1) #[N,num_priors]
# 难负样本挖掘 Hard Negative Mining
_, loss_idx = loss_c.sort(1, descending=True)
_, idx_rank = loss_idx.sort(1) #各个框loss的排名,从大到小 [batch,num_priors]
num_pos = pos.long().sum(1, keepdim=True) #[batch,1]
num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1) #[batch,1]
neg = idx_rank < num_neg.expand_as(idx_rank) #得到负样本 [batch,num_priors]
上面的代码中涉及了两次排序的方法。总结了一下:
使用一次sort和两次sort的区别:
理解了两次sort,这段代码就不是问题了。
总结:正样本为默认框与真实框根据iou匹配得到,负样本为分类loss值排序得到。
有了正样本和负样本,接下来就可以愉快滴计算分类损失了。
"""forward函数第四部分内容"""
# 首先明确:分类损失包括:n个正样本损失,3n个负样本损失
pos_idx = pos.unsqueeze(2).expand_as(conf_data) #[batch,num_priors,num_classes]
neg_idx = neg.unsqueeze(2).expand_as(conf_data) #[batch,num_priors,num_classes]
conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes)
targets_weighted = conf_t[(pos+neg).gt(0)]
loss_class = F.cross_entropy(conf_p, targets_weighted, size_average=False)
N = float(num_pos.data.sum())
loss_l /= N
loss_class /= N
return loss_l, loss_class