SOLO Training Code Walkthrough

I previously wrote a walkthrough of the SOLO demo code; today let's go through the training pipeline.

The entry point is tools/train.py, the file you invoke from the command line to start training:

from mmdet.apis import set_random_seed, train_detector

def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    cfg.gpus = args.gpus

    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * cfg.gpus / 8

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, '{}.log'.format(timestamp))
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # log some basic info
    logger.info('Distributed training: {}'.format(distributed))
    logger.info('MMDetection Version: {}'.format(__version__))
    logger.info('Config:\n{}'.format(cfg.text))

    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}, deterministic: {}'.format(
            args.seed, args.deterministic))
        set_random_seed(args.seed, deterministic=args.deterministic)

    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        datasets.append(build_dataset(cfg.data.val))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=__version__,
            config=cfg.text,
            CLASSES=datasets[0].CLASSES)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=args.validate,
        timestamp=timestamp)

The core of this file is the main function: it loads the config, builds the SOLO model from it (build_detector), loads the dataset (build_dataset), and then enters the SOLO training module (train_detector). There is not much detail in this part. Model construction is the same as in the demo, and dataset loading can be set aside for now, so the key is the train_detector module.
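The autoscale_lr branch above applies the linear scaling rule: the learning rate in the config is assumed to be tuned for 8 GPUs, so it is rescaled in proportion to the GPU count actually used. A minimal sketch of the rule (the base lr of 0.01 is only an illustrative value, not taken from a particular SOLO config):

```python
def autoscale_lr(base_lr, gpus, base_gpus=8):
    """Linear scaling rule (Goyal et al., 2017): lr scales with the total batch size,
    which is proportional to the number of GPUs."""
    return base_lr * gpus / base_gpus

# with a base lr of 0.01 tuned for 8 GPUs:
print(autoscale_lr(0.01, 4))  # -> 0.005 when training on 4 GPUs
```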

Now let's move into mmdet/apis/train.py and take a close look at the train_detector function:

def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None):
    logger = get_root_logger(cfg.log_level)

    # start training
    if distributed:
        _dist_train(
            model,
            dataset,
            cfg,
            validate=validate,
            logger=logger,
            timestamp=timestamp)
    else:
        _non_dist_train(
            model,
            dataset,
            cfg,
            validate=validate,
            logger=logger,
            timestamp=timestamp)

def _non_dist_train(model,
                    dataset,
                    cfg,
                    validate=False,
                    logger=None,
                    timestamp=None):
    if validate:
        raise NotImplementedError('Built-in validation is not implemented '
                                  'yet in not-distributed training. Use '
                                  'distributed training or test.py and '
                                  '*eval.py scripts instead.')
    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            cfg.gpus,
            dist=False) for ds in dataset
    ]
    # put model on gpus
    model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = Runner(
        model, batch_processor, optimizer, cfg.work_dir, logger=logger)
    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp
    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=False)
    else:
        optimizer_config = cfg.optimizer_config
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)

It first branches on whether distributed training is used. Suppose we do not use distributed training, so _non_dist_train is called. This function builds the data loaders, puts the model on the GPU(s), and creates an instance of the Runner class. (Runner is worth looking up if you are interested; it is one of the more important containers in MMDetection. It is fairly low-level, so you need not dig into it — just know that it is the tool that binds data, model, training schedule, evaluation and inference together, and is used during training to unify these modules.) Notice that when Runner is instantiated, the model-related parts — the model, the batch processor, the optimizer — are passed in; and when run is executed on the last line, the data loaders, the workflow, and the number of training epochs are passed in as well.

So execution next enters runner.run. To understand run, we need a quick look at the Runner class, which lives in mmcv (not inside the MMDetection project, but mmcv is required to run MMDetection).

The Runner class is defined in mmcv/runner/epoch_based_runner.py:

@RUNNERS.register_module()
class EpochBasedRunner(BaseRunner):
    """Epoch-based Runner.

    This runner train models epoch by epoch.
    """

    def run_iter(self, data_batch, train_mode, **kwargs):
        if self.batch_processor is not None:	# this branch is taken
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=train_mode, **kwargs)
        elif train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('"batch_processor()" or "model.train_step()"'
                            'and "model.val_step()" must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        # train one epoch
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            ###核心###
            self.run_iter(data_batch, train_mode=True)
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_epochs is not None:
            warnings.warn(
                'setting max_epochs in run is deprecated, '
                'please set max_epochs in runner_config', DeprecationWarning)
            self._max_epochs = max_epochs

        assert self._max_epochs is not None, (
            'max_epochs must be specified during instantiation')

        for i, flow in enumerate(workflow):
            mode, epochs = flow		# if 'train' is present, epochs is its epoch count (this handles interleaved train/val workflows, e.g. train for 2 epochs, then val for 1)
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow,
                         self._max_epochs)
        self.call_hook('before_run')

        ############# the important part starts here #################
        while self.epoch < self._max_epochs:	# train until the maximum epoch
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                # data_loaders[i] is the train-split or val-split dataloader
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)	# returns self.train or self.val
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))

                for _ in range(epochs):
                    # break once training is complete (both conditions below hold)
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')


@RUNNERS.register_module()
class Runner(EpochBasedRunner):
    """Deprecated name of EpochBasedRunner."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            'Runner was deprecated, please use EpochBasedRunner instead')
        super().__init__(*args, **kwargs)

As you can see, the Runner class inherits everything from EpochBasedRunner, so our attention turns to EpochBasedRunner, which itself inherits from BaseRunner (mostly relevant only at initialization). The key is the run function: it first determines the train/val arrangement from the workflow, then starts training. Suppose the workflow contains no val phase; then epoch_runner always resolves to train, so train is called directly. Inside train, a loop runs one epoch, and the most important line of that loop calls run_iter. Our config specifies a batch processor, so batch_processor is invoked — the very function passed in when Runner was instantiated, which we will look at next. A small spoiler: the output obtained here is the loss; once we have the loss, a series of hook operations performs gradient descent and backpropagation to update the network parameters.
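The scheduling done by run can be distilled into a small standalone sketch (a simplified imitation of the loop above in plain Python, not the real mmcv code; hooks and data loading are omitted):

```python
def schedule(workflow, max_epochs):
    """Replay the order of phases that an epoch-based runner would execute.

    workflow: list of (mode, epochs) tuples, e.g. [('train', 2), ('val', 1)].
    Returns the sequence of phases actually run; only 'train' advances the epoch counter.
    """
    phases = []
    epoch = 0
    while epoch < max_epochs:
        for mode, epochs in workflow:
            for _ in range(epochs):
                # stop a train phase once the epoch budget is exhausted
                if mode == 'train' and epoch >= max_epochs:
                    break
                phases.append(mode)
                if mode == 'train':
                    epoch += 1
    return phases

print(schedule([('train', 2), ('val', 1)], 4))
# -> ['train', 'train', 'val', 'train', 'train', 'val']
```

This makes the workflow semantics concrete: [('train', 2), ('val', 1)] means "2 training epochs, then 1 validation epoch, repeated" until max_epochs training epochs have run.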

So let's go back to the batch_processor function in mmdet/apis/train.py:

def batch_processor(model, data, train_mode):
    """Process a data batch.

    This method is required as an argument of Runner, which defines how to
    process a data batch and obtain proper outputs. The first 3 arguments of
    batch_processor are fixed.

    Args:
        model (nn.Module): A PyTorch model.
        data (dict): The data batch in a dict.
        train_mode (bool): Training mode or not. It may be useless for some
            models.

    Returns:
        dict: A dict containing losses and log vars.
    """
    losses = model(**data)		# compute the loss
    loss, log_vars = parse_losses(losses)

    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))

    return outputs


def parse_losses(losses):
    log_vars = OrderedDict()
    for loss_name, loss_value in losses.items():
        if isinstance(loss_value, torch.Tensor):
            log_vars[loss_name] = loss_value.mean()
        elif isinstance(loss_value, list):
            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
        else:
            raise TypeError(
                '{} is not a tensor or list of tensors'.format(loss_name))

    loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)

    log_vars['loss'] = loss
    for loss_name, loss_value in log_vars.items():
        # reduce loss when distributed training
        if dist.is_available() and dist.is_initialized():
            loss_value = loss_value.data.clone()
            dist.all_reduce(loss_value.div_(dist.get_world_size()))
        log_vars[loss_name] = loss_value.item()

    return loss, log_vars

model(**data) is the most important step: it feeds the current batch into SOLO, a SingleStageInsDetector, and executes its forward method; since we are training, forward_train is called. (The parse_losses function is not essential — it just produces a more complete summary of the loss values and can be ignored for now.)
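That said, the aggregation rule in parse_losses fits in one line: every entry whose key contains 'loss' is summed into the total, while other entries are only logged. A toy sketch with plain floats in place of tensors:

```python
def total_loss(losses):
    """Sum every entry whose key contains 'loss', mirroring parse_losses;
    other log entries (e.g. accuracies) are not included in the total."""
    return sum(v for k, v in losses.items() if 'loss' in k)

print(total_loss({'loss_ins': 0.5, 'loss_cate': 0.25, 'acc': 0.9}))  # -> 0.75
```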

The forward_train function is in mmdet/models/detectors/single_stage_ins.py:

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None):
        x = self.extract_feat(img)
        outs = self.bbox_head(x)

        if self.with_mask_feat_head:
            mask_feat_pred = self.mask_feat_head(
                x[self.mask_feat_head.
                  start_level:self.mask_feat_head.end_level + 1])
            loss_inputs = outs + (mask_feat_pred, gt_bboxes, gt_labels, gt_masks, img_metas, self.train_cfg)
        else:
            loss_inputs = outs + (gt_bboxes, gt_labels, gt_masks, img_metas, self.train_cfg)
        losses = self.bbox_head.loss(
            *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        return losses

As shown, the batch of images is first forward-propagated to obtain the predictions outs and mask_feat_pred. Together with the ground-truth annotations of the batch, these are passed as inputs to self.bbox_head.loss to compute the loss. We now have SOLO's category and mask predictions for a batch of images, plus the corresponding ground truth, so the last step is computing the loss value. With the loss in hand, backpropagation can proceed.

So we finally arrive at the loss function (the important part), in mmdet/models/anchor_heads/solo_head.py (heavily annotated):

    def loss(self,
             ins_preds,
             cate_preds,
             gt_bbox_list,
             gt_label_list,
             gt_mask_list,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        featmap_sizes = [featmap.size()[-2:] for featmap in
                         ins_preds]
#        print(featmap_sizes)  [torch.Size([200, 304]), torch.Size([200, 304]), torch.Size([100, 152]), torch.Size([50, 76]), torch.Size([50, 76])]
#        print(gt_label_list)	# (n, 1); n is the number of GT instances, each entry a class index
#        print(gt_bbox_list)   #(n, 4)
#        print(gt_mask_list[0].shape)	# (n, 1216, 800)
#        for i in range(len(featmap_sizes)):
#	        print(ins_preds[i].shape)
#	        print(cate_preds[i].shape)
#"""最后两个维度有很多种,取决于img的大小,这一部分的预处理还需要再看一下,这里以(200,304)这组为例"""
#	     torch.Size([1, 1600, 200, 304])
#		torch.Size([1, 80, 40, 40])
#		torch.Size([1, 1296, 200, 304])
#		torch.Size([1, 80, 36, 36])
#		torch.Size([1, 576, 100, 152])
#		torch.Size([1, 80, 24, 24])
#		torch.Size([1, 256, 50, 76])
#		torch.Size([1, 80, 16, 16])
#		torch.Size([1, 144, 50, 76])
#		torch.Size([1, 80, 12, 12])

        ins_label_list, cate_label_list, ins_ind_label_list = multi_apply(
            self.solo_target_single,
            gt_bbox_list,
            gt_label_list,
            gt_mask_list,
            featmap_sizes=featmap_sizes)

#        for i in range(5):
#            print(ins_label_list[0][i].shape)
#            print(ins_ind_label_list[0][i].shape)
#		torch.Size([1600, 200, 304])
#		torch.Size([1600])
#		torch.Size([1296, 200, 304])
#		torch.Size([1296])
#		torch.Size([576, 100, 152])
#		torch.Size([576])
#		torch.Size([256, 50, 76])
#		torch.Size([256])
#		torch.Size([144, 50, 76])
#		torch.Size([144])

        # ins: for each feature map, use ins_ind_label_list to select the masks whose positions are True, and store them in ins_labels
        ins_labels = [torch.cat([ins_labels_level_img[ins_ind_labels_level_img, ...]
                                 for ins_labels_level_img, ins_ind_labels_level_img in
                                 zip(ins_labels_level, ins_ind_labels_level)], 0)
                      for ins_labels_level, ins_ind_labels_level in zip(zip(*ins_label_list), zip(*ins_ind_label_list))]

#        for i in range(5):
#            print(ins_labels[i].shape)   # the five iterations give m1 m2 m3 m4 m5, the number of True entries in ins_ind_label_list[0][i]
#            print(ins_ind_label_list[0][i].sum())   # m1 m2 m3 m4 m5 in turn

        ###########For Example##########
#         torch.Size([5, 272, 200])
#		tensor(5, device='cuda:0')
#		torch.Size([10, 272, 200])
#		tensor(10, device='cuda:0')
#		torch.Size([2, 136, 100])
#		tensor(2, device='cuda:0')
#		torch.Size([0, 68, 50])
#		tensor(0, device='cuda:0')
#		torch.Size([0, 68, 50])
#		tensor(0, device='cuda:0')

        # same as above, except the GT ins_label_list is replaced by the predictions ins_preds
        ins_preds = [torch.cat([ins_preds_level_img[ins_ind_labels_level_img, ...]
                                for ins_preds_level_img, ins_ind_labels_level_img in
                                zip(ins_preds_level, ins_ind_labels_level)], 0)
                     for ins_preds_level, ins_ind_labels_level in zip(ins_preds, zip(*ins_ind_label_list))]

        ### at this point the GT masks (ins_labels) and the predicted masks (ins_preds) have the same form


        ins_ind_labels = [
            torch.cat([ins_ind_labels_level_img.flatten()
                       for ins_ind_labels_level_img in ins_ind_labels_level])
            for ins_ind_labels_level in zip(*ins_ind_label_list)
        ]
        flatten_ins_ind_labels = torch.cat(ins_ind_labels)

        num_ins = flatten_ins_ind_labels.sum()	# equals m1+m2+m3+m4+m5

        # dice loss
        loss_ins = []
        for input, target in zip(ins_preds, ins_labels):
            if input.size()[0] == 0:
                continue
            input = torch.sigmoid(input)
            loss_ins.append(dice_loss(input, target))
        loss_ins = torch.cat(loss_ins).mean()
        loss_ins = loss_ins * self.ins_loss_weight

        # cate
        cate_labels = [
            torch.cat([cate_labels_level_img.flatten()
                       for cate_labels_level_img in cate_labels_level])
            for cate_labels_level in zip(*cate_label_list)
        ]
        
#        for i in range(5):
#            print(cate_labels[i].shape)
#          torch.Size([1600])
#		torch.Size([1296])
#		torch.Size([576])
#		torch.Size([256])
#		torch.Size([144])
        
        flatten_cate_labels = torch.cat(cate_labels)	
#        print(flatten_cate_labels.shape)	# 3872 = 1600 + 1296 + 576 + 256 + 144

#        for i in range(5):
#            print(cate_preds[i].shape)	# (1, 80, 40, 40)
#
#		torch.Size([1, 80, 40, 40])
#		torch.Size([1, 80, 36, 36])
#		torch.Size([1, 80, 24, 24])
#		torch.Size([1, 80, 16, 16])
#		torch.Size([1, 80, 12, 12])

        cate_preds = [
            cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels)     # (x, 80), where x is 40*40 or 36*36 or ...
            for cate_pred in cate_preds
        ]
        flatten_cate_preds = torch.cat(cate_preds)
#        print(flatten_cate_preds.shape)		# (3872, 80)

        loss_cate = self.loss_cate(flatten_cate_preds, flatten_cate_labels, avg_factor=num_ins + 1)
        return dict(
            loss_ins=loss_ins,
            loss_cate=loss_cate)

    ### process the GT: according to each feature map's scale, assign gt_bbox, gt_label and gt_mask to the feature map of the matching size; len(list) = 5 ###
    def solo_target_single(self,
                               gt_bboxes_raw,
                               gt_labels_raw,
                               gt_masks_raw,
                               featmap_sizes=None):

        device = gt_labels_raw[0].device

        # ins
#        print(gt_bboxes_raw)	#(n, 4)
#        print(gt_labels_raw)  #(n, 1)
#        print(gt_masks_raw.shape)	#(n, 1216, 800)

        gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * (
                gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
#        print(gt_areas)	# (n,): sqrt of the area of each instance's bbox

        ins_label_list = []
        cate_label_list = []
        ins_ind_label_list = []
        # loop over the five feature-map levels
        for (lower_bound, upper_bound), stride, featmap_size, num_grid \
                in zip(self.scale_ranges, self.strides, featmap_sizes, self.seg_num_grids):
#            print(lower_bound)   [1, 48, 96, 192, 384]
#            print(upper_bound)   [96, 192, 384, 768, 2048]
#            print(stride)	   [8, 8, 16, 32, 32]
#            print(featmap_size)  
#            print(num_grid)    [40, 36, 24, 16, 12]
#            print('=========================================')
            ins_label = torch.zeros([num_grid ** 2, featmap_size[0], featmap_size[1]], dtype=torch.uint8, device=device)
#            print(ins_label.shape)		[(1600,200,304), (1296,200,304), (576,100,152), (256,50,76), (144,50,76)]
            cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device)
#            print(cate_label.shape)     [(40,40), (36,36), (24,24), (16,16), (12,12)]
            ins_ind_label = torch.zeros([num_grid ** 2], dtype=torch.bool, device=device)
#            print(ins_ind_label.shape)   [1600, 1296, 576, 256, 144]

            hit_indices = ((gt_areas >= lower_bound) & (gt_areas <= upper_bound)).nonzero().flatten()  # indices of the m boxes (m <= n) whose scale falls within this range
            if len(hit_indices) == 0:	# if no bbox matches this scale range, append all-zero tensors and move on to the next level
                ins_label_list.append(ins_label)
                cate_label_list.append(cate_label)
                ins_ind_label_list.append(ins_ind_label)
                continue
            gt_bboxes = gt_bboxes_raw[hit_indices]
            gt_labels = gt_labels_raw[hit_indices]
            gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...]

            # center sampling: a half-width/half-height region around the GT center, allowing the instance to fall onto more grid cells (on average about 3); sigma controls its size
            half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
            half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
#            print(half_ws)	# [m]
#            print(half_hs)	# [m]

            # mass center
            gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
            center_ws, center_hs = center_of_mass(gt_masks_pt)
#            print(center_ws, center_hs)		# mass centers of the GT instances	[m]
            valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0
#            print(valid_mask_flags)	# a vector of m True values

            output_stride = stride / 2	#[4, 4, 8, 16, 16]
            for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
                if not valid_mask_flag:
                   continue
                upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)
#                print(upsampled_size)	# (800, 1216): 4x the size of the top-level feature map
                # coord: the grid cell that the GT center falls into
                coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
                coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))

                # left, top, right, down: let one instance fall onto multiple nearby grid cells
                top_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
                down_box = min(num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))
                left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
                right_box = min(num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))

                # clamp to at most one grid cell away from coord
                top = max(top_box, coord_h-1)
                down = min(down_box, coord_h+1)
                left = max(coord_w-1, left_box)
                right = min(right_box, coord_w+1)

                cate_label[top:(down+1), left:(right+1)] = gt_label
                # ins
                seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)	# downscale the mask to the feature-map scale
                seg_mask = torch.from_numpy(seg_mask).to(device=device)
                for i in range(top, down+1):
                    for j in range(left, right+1):
                        label = int(i * num_grid + j)
                        ins_label[label, :seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask	# set the corresponding channel to seg_mask
                        ins_ind_label[label] = True	# mark the selected position as True
            ins_label_list.append(ins_label)
            cate_label_list.append(cate_label)
            ins_ind_label_list.append(ins_ind_label)
#        print(len(ins_label_list))	# 5
        return ins_label_list, cate_label_list, ins_ind_label_list
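The center-of-mass and grid-assignment arithmetic used above can be isolated into a tiny self-contained sketch (plain Python; the real center_of_mass works on batched torch tensors, and all shapes here are illustrative only):

```python
def center_of_mass(mask):
    """Mass center (w, h) of a binary mask given as a list of rows."""
    total = sw = sh = 0
    for h, row in enumerate(mask):
        for w, v in enumerate(row):
            total += v
            sw += v * w
            sh += v * h
    return sw / total, sh / total

def assign_grid(center_w, center_h, upsampled_size, num_grid):
    """Which num_grid x num_grid cell the instance center falls into,
    using the same normalized-coordinate arithmetic as solo_target_single."""
    coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
    coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
    return coord_h, coord_w

# a tiny 4x4 mask with a 2x2 blob in the upper-left corner
mask = [[1, 1, 0, 0],
        [1, 1, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]]
cw, ch = center_of_mass(mask)          # (0.5, 0.5)
print(assign_grid(cw, ch, (4, 4), 4))  # -> (0, 0): the top-left grid cell
```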

The comments above are fairly detailed, if a bit messy — apologies. Here is the overall idea. We have cate_preds and ins_preds from the forward pass, and the GT data gt_bbox_list, gt_label_list and gt_mask_list. Keep in mind: cate_preds and ins_preds are the outputs of the five feature maps passed through the head, i.e. a list concatenating five scales, whereas the GT data are annotations for the image as a whole. So: (1) solo_target_single runs first; its purpose is to assign every GT instance to one of the five feature-map levels according to scale_range (comparing the instance's bbox scale against scale_range decides which level the instance belongs to); (2) inside solo_target_single: compute each instance's bbox scale, assign it to the level whose scale_range matches, rescale the GT to that level's feature-map size, find the grid index corresponding to the mass center of the instance mask, and thereby determine which grid cells are responsible for predicting that instance (the positives) — the GT is thus split into a five-level form; (3) back in loss, the positive indices select the masks of all positive grid cells (for both GT and predictions); after a sigmoid on the predicted masks, the two are compared with Dice Loss as the segmentation loss; (4) the GT cate_label is flattened into a [3872] vector and the predicted cate_label into a [3872, 80] matrix, and Focal Loss is computed between them; (5) the Dice Loss and Focal Loss values are returned and flow up to the parse_losses function above to produce the final loss.
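The Dice loss used in step (3) can also be sketched in plain Python (the real dice_loss works on flattened torch tensors; the eps value below is illustrative). Per mask, it computes loss = 1 - 2*sum(p*t) / (sum(p^2) + sum(t^2) + eps):

```python
def dice_loss(pred, target, eps=0.001):
    """Dice loss between a predicted soft mask and a binary GT mask,
    both flattened to 1-D lists of floats."""
    a = sum(p * t for p, t in zip(pred, target))  # intersection term
    b = sum(p * p for p in pred)
    c = sum(t * t for t in target)
    return 1 - 2 * a / (b + c + eps)

# perfect prediction -> loss near 0; disjoint prediction -> loss = 1
print(dice_loss([1.0, 1.0, 0.0, 0.0], [1, 1, 0, 0]))  # ~0.00025
print(dice_loss([0.0, 0.0, 1.0, 1.0], [1, 1, 0, 0]))  # -> 1.0
```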

That's about it for the training code. Some details of Runner and the hooks could not be covered clearly here; dig in further if you are interested. From a purely algorithmic point of view, the solo_head part is without doubt the core. Space is limited, so if anything is unclear or flawed, feel free to discuss in the comments~
