Detection Distillation for Faster R-CNN (Classification, Regression, Feature-Level, Feature-Level + Mask)

The code is heavily borrowed from:

1. Distillation for Faster R-CNN at the classification, regression, and feature level

http://papers.nips.cc/paper/6676-learning-efficient-object-detection-models-with-knowledge-distillation.pdf


2. Distillation for Faster R-CNN at the feature level + mask

http://openaccess.thecvf.com/content_CVPR_2019/papers/Wang_Distilling_Object_Detectors_With_Fine-Grained_Feature_Imitation_CVPR_2019_paper.pdf

Code:

https://github.com/twangnh/Distilling-Object-Detectors

Straight to the core code. (The first two lines below are the network outputs of the student and the teacher, both returned as dictionaries. Compared with ordinary training, distillation additionally requires the R-CNN head to output its classification and regression predictions, plus the backbone feature maps. The student's and teacher's feature maps differ in size, so the student's must be transformed to the same channel count and spatial size; a single convolutional layer is enough, as sketched below. For the concrete code, see my other post: a reproduction of Distilling Object Detectors with Fine-grained Feature Imitation.)
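For reference, here is a minimal sketch of such an adaptation layer (the class name and channel numbers are illustrative, not from the repo): a single 3×3 convolution maps the student's feature map to the teacher's channel count, and the stride can absorb a difference in spatial size.

    import torch.nn as nn

    class AdaptationLayer(nn.Module):
        """Sketch: map the student feature map to the teacher's shape."""
        def __init__(self, in_channels, out_channels, stride=1):
            super().__init__()
            # 3x3 conv with padding=1 keeps the spatial size when stride=1
            self.conv = nn.Conv2d(in_channels, out_channels,
                                  kernel_size=3, stride=stride, padding=1)
            self.relu = nn.ReLU(inplace=True)

        def forward(self, x):
            return self.relu(self.conv(x))

    # e.g. a 256-channel student feature imitating a 1024-channel teacher:
    # model_adap = AdaptationLayer(256, 1024)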

        # forward the same batch through the student and the frozen teacher
        output = model(to_device(input))
        output_teacher = model_teacher(to_device(input))


        '''
        Classification distillation:
        '''
        if cfg_distillation.get('cls_distillation', None):
            cfg_cls_distillation = cfg_distillation.get('cls_distillation')
            rcn_cls_score_t = output_teacher['cls_pred']  # teacher logits
            rcn_cls_score_s = output['cls_pred']          # student logits
            RCNN_loss_cls_s = output['BboxNet.cls_loss']  # hard loss against gt labels
            # anneal the hard/soft balance mu linearly over training
            start_mu = cfg_cls_distillation.get('start_mu')
            end_mu = cfg_cls_distillation.get('end_mu')
            mu = start_mu + (end_mu - start_mu) * (float(epoch) / max_epoch)
            loss_rcn_cls, loss_rcn_cls_soft = compute_loss_classification(
                rcn_cls_score_t, rcn_cls_score_s, mu,
                RCNN_loss_cls_s, T=1, weighted=True)
            output['BboxNet.cls_loss'] = loss_rcn_cls
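
compute_loss_classification itself is not shown in the snippet. The sketch below is one plausible implementation following the weighted cross-entropy soft loss of the NIPS 2017 paper: the final loss is mu * L_hard + (1 - mu) * L_soft, and the background class (assumed to be index 0 here, with weight 1.5 as in the paper) is up-weighted to counter the foreground/background imbalance.

    import torch
    import torch.nn.functional as F

    def compute_loss_classification(score_t, score_s, mu, hard_loss,
                                    T=1.0, weighted=True, w_bg=1.5):
        """Sketch: blend the hard gt loss with a soft teacher loss."""
        p_t = F.softmax(score_t / T, dim=1)         # softened teacher probabilities
        log_p_s = F.log_softmax(score_s / T, dim=1)
        if weighted:
            # up-weight the background class (assumed index 0)
            w = torch.ones(score_t.size(1), device=score_t.device)
            w[0] = w_bg
            soft_loss = -(w * p_t * log_p_s).sum(dim=1).mean()
        else:
            soft_loss = -(p_t * log_p_s).sum(dim=1).mean()
        # mu anneals from start_mu to end_mu in the training loop above
        total = mu * hard_loss + (1.0 - mu) * soft_loss
        return total, soft_loss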

        '''
        Regression distillation:
        '''
        if cfg_distillation.get('loc_distillation', None):
            cfg_loc_distillation = cfg_distillation.get('loc_distillation')
            RCNN_loss_bbox_s = output['BboxNet.loc_loss']  # hard loss against gt targets
            bbox_pred_s = output['loc_pred']               # student box deltas
            bbox_pred_t = output_teacher['loc_pred']       # teacher box deltas
            rois_target_s = output['loc_target']
            rois_target_t = output_teacher['loc_target']

            # anneal the soft-loss weight ni linearly over training
            start_ni = cfg_loc_distillation.get('start_ni')
            end_ni = cfg_loc_distillation.get('end_ni')
            ni = start_ni + (end_ni - start_ni) * (float(epoch) / max_epoch)
            loss_rcn_reg, loss_rcn_reg_soft, _, _ = \
                compute_loss_regression(RCNN_loss_bbox_s, bbox_pred_s, bbox_pred_t,
                                        rois_target_s, rois_target_t, m=0.01, ni=ni)
            output['BboxNet.loc_loss'] = loss_rcn_reg
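
Likewise, compute_loss_regression is not shown; a plausible sketch follows the teacher-bounded regression loss of the NIPS 2017 paper, where the soft term only fires for ROIs whose student error exceeds the teacher's error by more than the margin m. The real function returns four values; the last two are not used in the snippet, so the sketch returns placeholders.

    import torch

    def compute_loss_regression(hard_loss, pred_s, pred_t,
                                target_s, target_t, m=0.01, ni=0.5):
        """Sketch: hard smooth-L1 loss plus a teacher-bounded soft term."""
        # squared L2 error of student and teacher w.r.t. their regression targets
        err_s = ((pred_s - target_s) ** 2).sum(dim=1)
        err_t = ((pred_t - target_t) ** 2).sum(dim=1)
        # teacher-bounded: penalize only where the student is worse than
        # the teacher by more than the margin m
        bound = (err_s + m > err_t).float()
        soft_loss = (err_s * bound).mean()
        total = hard_loss + ni * soft_loss
        return total, soft_loss, None, None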

        '''
        Feature-level distillation:
        '''
        if cfg_distillation.get('feature_distillation', None):
            cfg_feature_distillation = cfg_distillation.get('feature_distillation')
            sup_feature = output_teacher['features'][0]  # teacher feature map
            stu_feature = output['features'][0]          # student feature map
            # adaptation layer: match the student's channels/size to the teacher's
            stu_feature_adap = model_adap(stu_feature)

            # anneal the imitation loss weight linearly over training
            start_weigth = cfg_feature_distillation.get('start_weigth')
            end_weigth = cfg_feature_distillation.get('end_weigth')
            imitation_loss_weigth = start_weigth + (end_weigth - start_weigth) * (float(epoch) / max_epoch)
            if cfg_feature_distillation.get('start_weigth', None):
                # fine-grained imitation: only penalize locations covered by the
                # anchor mask derived from the gt boxes (feature level + mask)
                mask_batch = output_teacher['RoINet.mask_batch']
                mask_list = []
                for mask in mask_batch:
                    mask = (mask > 0).float().unsqueeze(0)  # binarize, add channel dim
                    mask_list.append(mask)
                mask_batch = torch.stack(mask_list, dim=0)
                norms = mask_batch.sum() * 2  # normalize by the masked area
                sup_loss = (torch.pow(sup_feature - stu_feature_adap, 2) * mask_batch).sum() / norms
            else:
                # plain feature-level imitation over the whole feature map
                sup_loss = (torch.pow(sup_feature - stu_feature_adap, 2)).sum()

            sup_loss = sup_loss * imitation_loss_weigth
            output['sup.loss'] = sup_loss

        #################################################################################
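
The snippet takes the imitation mask directly from output_teacher['RoINet.mask_batch']. For completeness, here is a hedged sketch of how such a fine-grained mask can be built per the CVPR 2019 paper (function names are illustrative): for each gt box, compute the IoU between the box and every anchor, keep the locations whose IoU exceeds psi times that box's best IoU (psi = 0.5 in the paper), and OR the per-box masks together.

    import torch

    def box_iou(boxes1, boxes2):
        """IoU between two sets of (x1, y1, x2, y2) boxes -> (N, M)."""
        area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
        lt = torch.max(boxes1[:, None, :2], boxes2[None, :, :2])
        rb = torch.min(boxes1[:, None, 2:], boxes2[None, :, 2:])
        wh = (rb - lt).clamp(min=0)
        inter = wh[..., 0] * wh[..., 1]
        return inter / (area1[:, None] + area2[None, :] - inter)

    def make_imitation_mask(anchors, gt_boxes, feat_h, feat_w, psi=0.5):
        """Sketch: (feat_h, feat_w) binary mask of feature locations to imitate."""
        # assumes anchors are laid out in (H, W, K) order, K per location
        K = anchors.shape[0] // (feat_h * feat_w)
        iou = box_iou(anchors, gt_boxes)       # (H*W*K, G)
        # per-gt threshold: psi times that gt's best IoU over all anchors
        thresh = psi * iou.max(dim=0).values   # (G,)
        keep = (iou > thresh).view(feat_h, feat_w, K, -1)
        # a location is imitated if any anchor passes for any gt box
        return keep.any(dim=2).any(dim=2).float()

The per-image masks, stacked and broadcast over channels, play the role of mask_batch in the snippet above.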

Code issues and details can be discussed on my GitHub:

https://github.com/HqWei/Distillation-of-Faster-rcnn
