mmdet之centernet损失函数记录

1.损失函数在

bbox_head中定义,类型为CenterNetHead

具体相关代码

bbox_head=dict(
    type='CenterNetHead',
    num_classes=5,
    in_channel=64,
    feat_channel=64,
    loss_center_heatmap=dict(type='GaussianFocalLoss', loss_weight=1.0),
    loss_wh=dict(type='L1Loss', loss_weight=0.1),
    loss_offset=dict(type='L1Loss', loss_weight=1.0)),

具体有关loss_wh的计算,类型为L1Loss,函数定义在model中,losses下的smooth_l1_loss.py文件,具体调用损失函数为L1Loss

本质为

pred和target做差
在计算loss_wh时,只计算center存在点的loss,这一实现依靠weight权重
weight的计算在CenterNetHead类中实现,实现函数为
def get_targets(self, gt_bboxes, gt_labels, feat_shape, img_shape):

核心代码如下:

wh_offset_target_weight = gt_bboxes[-1].new_zeros(
    [bs, 2, feat_h, feat_w])
wh_offset_target_weight[batch_id, :, cty_int, ctx_int] = 1

初始化全0的weight的tensor,

只把中心点对应处的weight改为1。

你可能感兴趣的:(mmdetection,pytorch,centernet)