CenterNet:Objects as Points代码解析(八):wh_loss的计算

if opt.wh_weight > 0:
    # Width/height regression loss. The branch taken depends on how the
    # ground-truth boxes are laid out in `batch`.
    if opt.dense_wh:
        # Dense layout: per the original author's notes, batch['dense_wh'] is
        # batch x 2 x output_w x output_h rather than the max_objs x 2 layout
        # of batch['wh'].
        # batch['dense_wh_mask'] marks the positions of the "wh" head that hold
        # an object centre (author's note); multiplying both the prediction and
        # the label by it zeroes every other position before the loss is taken.
        dense_mask = batch['dense_wh_mask']
        mask_weight = dense_mask.sum() + 1e-4
        masked_pred = output['wh'] * dense_mask
        masked_target = batch['dense_wh'] * dense_mask
        dense_term = self.crit_wh(masked_pred, masked_target) / mask_weight
        wh_loss += dense_term / opt.num_stacks
    elif opt.cat_spec_wh:
        # Category-specific layout: the criterion receives the raw head output
        # together with the mask, indices and labels.
        cat_spec_term = self.crit_wh(
            output['wh'],
            batch['cat_spec_mask'],
            batch['ind'],
            batch['cat_spec_wh'],
        )
        wh_loss += cat_spec_term / opt.num_stacks
    else:
        # Default sparse layout: one wh label per object slot.
        sparse_term = self.crit_reg(
            output['wh'],
            batch['reg_mask'],
            batch['ind'],
            batch['wh'],
        )
        wh_loss += sparse_term / opt.num_stacks
def forward(self, output, mask, ind, target):
    """Return the masked L1 regression loss at the indexed locations.

    Args:
        output: Prediction map handed to ``_tranpose_and_gather_feat``
            (per the original notes, e.g. 32 x 2 x 96 x 96 for the ``wh`` head).
        mask: Per-object mask with two axes (per the original notes,
            batch x max_objs); a trailing axis is added and it is expanded to
            the shape of the gathered predictions.
        ind: Location indices used to gather the predictions.
        target: Ground-truth values; expected to match the shape of the
            gathered predictions (per the original notes, batch x max_objs x 2).

    Returns:
        The summed absolute error over the masked entries divided by
        ``mask.sum() + 1e-4``.
    """
    # Pick the predictions at the object locations so they line up with target.
    pred = _tranpose_and_gather_feat(output, ind)
    # Broadcast the per-object mask across the last (channel) axis.
    mask = mask.unsqueeze(2).expand_as(pred).float()
    # reduction='sum' is the supported spelling of the deprecated
    # size_average=False; the result is the same summed L1 error.
    loss = F.l1_loss(pred * mask, target * mask, reduction='sum')
    # Normalise by the number of active mask entries; the epsilon keeps the
    # division defined when the mask is all zeros.
    loss = loss / (mask.sum() + 1e-4)
    return loss
def _tranpose_and_gather_feat(feat, ind):
    """Move channels last, flatten the spatial axes, then gather rows by ``ind``.

    Args:
        feat: 4-axis tensor; axis 1 is moved to the end (per the original
            notes, e.g. 32 x 2 x 96 x 96 -> 32 x 96 x 96 x 2).
        ind: Indices into the flattened spatial axis, passed on to
            ``_gather_feat``.

    Returns:
        The result of ``_gather_feat`` on the flattened map (per the original
        notes, e.g. 32 x 50 x 2, matching the sparse ``batch['wh']`` labels).
    """
    # Channels-last copy: axes (0, 1, 2, 3) -> (0, 2, 3, 1).
    channels_last = feat.permute(0, 2, 3, 1).contiguous()
    batch_size = channels_last.size(0)
    channels = channels_last.size(3)
    # Collapse the two middle axes into one so each location is a single row
    # (e.g. 32 x 96 x 96 x 2 -> 32 x 9216 x 2 in the original notes).
    flattened = channels_last.view(batch_size, -1, channels)
    # The sparse labels hold one entry per object, so only the rows addressed
    # by ind are kept to make prediction and label shapes comparable.
    return _gather_feat(flattened, ind)

def _gather_feat(feat, ind, mask=None):
    #dim = 2
    dim  = feat.size(2)
    #ind维度 :32*50---->32*50*2
    ind  = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
    #从feat的第1个维度,按ind给出的索引提取元素
    feat = feat.gather(1, ind)
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(feat)
        feat = feat[mask]
        feat = feat.view(-1, dim)
    return feat

你可能感兴趣的:(论文解读)