[mmdetection in Practice] (5) Understanding the Training Process


The previous articles in this series covered:

  • how to define your own dataset
  • how to train your own network
  • how the dataset and the model are constructed

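This article looks in more detail at what happens once the dataset and the model have been built, i.e. how training actually runs. The starting point is tools/train.py: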

# mmdetection/tools/train.py
train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=args.validate,
        timestamp=timestamp,
        meta=meta)

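Following train_detector into mmdet/apis/train.py, we see that the model, batch_processor and optimizer are handed to an mmcv Runner, which then runs the training: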

# mmdetection/mmdet/apis/train.py
runner = Runner(
        model,
        batch_processor,
        optimizer,
        cfg.work_dir,
        logger=logger,
        meta=meta)
...
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
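To make the role of `cfg.workflow` and `cfg.total_epochs` concrete, here is a minimal, torch-free sketch of the loop that `Runner.run` performs. It is modeled on the mmcv Runner of this era. `MiniRunner` is a hypothetical name, and the real class also handles hooks, logging, checkpoints and iteration counters. The sketch assumes each workflow entry is a `(mode, epochs)` pair and that `data_loaders[i]` belongs to `workflow[i]`.

```python
class MiniRunner:
    """Hypothetical, simplified stand-in for mmcv's Runner.run loop."""

    def __init__(self):
        self.epoch = 0
        self.calls = []  # records (mode, loader) for every epoch that runs

    def train(self, data_loader):
        self.calls.append(('train', data_loader))
        self.epoch += 1  # only training epochs advance the epoch counter

    def val(self, data_loader):
        self.calls.append(('val', data_loader))

    def run(self, data_loaders, workflow, max_epochs):
        # workflow looks like [('train', 1)] or [('train', 2), ('val', 1)]
        assert len(data_loaders) == len(workflow)
        while self.epoch < max_epochs:
            for i, (mode, epochs) in enumerate(workflow):
                epoch_runner = getattr(self, mode)  # self.train or self.val
                for _ in range(epochs):
                    # stop as soon as the training budget is used up
                    if mode == 'train' and self.epoch >= max_epochs:
                        return
                    epoch_runner(data_loaders[i])


runner = MiniRunner()
runner.run(['train_loader', 'val_loader'], [('train', 2), ('val', 1)], 3)
print([mode for mode, _ in runner.calls])  # ['train', 'train', 'val', 'train']
```

With the common default `workflow = [('train', 1)]`, this reduces to calling `train` once per epoch until `total_epochs` is reached.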

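Inside the Runner (from mmcv), the train method loops over the data loader and calls batch_processor on every batch: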

# mmcv/runner/runner.py
class Runner:
    def train(self, data_loader, **kwargs):
        ...
        for i, data_batch in enumerate(data_loader):
            ...
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=True, **kwargs)
            ...
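The loop above never calls `backward()` itself. In mmcv, the outputs are stored on the runner, and the parameter update is done by a hook (`OptimizerHook.after_train_iter`, which calls `zero_grad`, then `loss.backward()`, then `step`). The following torch-free sketch illustrates that division of labor. Every class in it (`ToyLoss`, `ToyOptimizer`, `ToyOptimizerHook`, `MiniRunner`) is hypothetical, and the real Runner also fires many more hook points and updates a log buffer.

```python
class ToyLoss:
    """Stand-in for a scalar loss tensor; records backward() calls."""
    def __init__(self, value, log):
        self.value, self.log = value, log

    def backward(self):
        self.log.append('backward')


class ToyOptimizer:
    def __init__(self, log):
        self.log = log

    def zero_grad(self):
        self.log.append('zero_grad')

    def step(self):
        self.log.append('step')


class ToyOptimizerHook:
    # same order as mmcv's OptimizerHook: zero_grad -> backward -> step
    def after_train_iter(self, runner):
        runner.optimizer.zero_grad()
        runner.outputs['loss'].backward()
        runner.optimizer.step()


class MiniRunner:
    def __init__(self, model, batch_processor, optimizer, hooks):
        self.model, self.batch_processor = model, batch_processor
        self.optimizer, self.hooks = optimizer, hooks
        self.outputs = None

    def call_hook(self, name):
        for hook in self.hooks:
            if hasattr(hook, name):
                getattr(hook, name)(self)

    def train(self, data_loader, **kwargs):
        for i, data_batch in enumerate(data_loader):
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=True, **kwargs)
            self.outputs = outputs  # hooks read the loss from here
            self.call_hook('after_train_iter')


log = []
runner = MiniRunner(
    model=lambda x: x * 2,
    batch_processor=lambda model, data, train_mode: {'loss': ToyLoss(model(data), log)},
    optimizer=ToyOptimizer(log),
    hooks=[ToyOptimizerHook()])
runner.train([1, 2])
print(log)  # ['zero_grad', 'backward', 'step', 'zero_grad', 'backward', 'step']
```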

Now look at the batch_processor that was passed to the Runner when it was created:

# mmdetection/mmdet/apis/train.py
def batch_processor(model, data, train_mode):
    losses = model(**data)
    loss, log_vars = parse_losses(losses)
    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
    return outputs
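`parse_losses` is what turns the model's loss dictionary into the single scalar `loss` that gets back-propagated, plus a `log_vars` dictionary for logging. The following is a simplified, torch-free sketch of its behavior. The assumption made for illustration is that a "tensor" is a plain list of floats and that a multi-level loss (for example one entry per FPN level) is a list of such lists. The real function works on torch tensors and converts the logged values with `.item()`.

```python
from collections import OrderedDict


def mean(values):
    return sum(values) / len(values)


def parse_losses(losses):
    """Simplified sketch of mmdet's parse_losses (lists stand in for tensors)."""
    log_vars = OrderedDict()
    for name, value in losses.items():
        if not isinstance(value, list) or not value:
            raise TypeError(f'{name} must be a non-empty tensor or list of tensors')
        if all(isinstance(v, list) for v in value):
            # list of tensors: average each one, then sum across the list
            log_vars[name] = sum(mean(v) for v in value)
        else:
            # a single tensor: average its elements
            log_vars[name] = mean(value)
    # only keys containing 'loss' are summed into the optimized total;
    # others (e.g. 'acc') are logged but not back-propagated
    loss = sum(v for k, v in log_vars.items() if 'loss' in k)
    log_vars['loss'] = loss
    return loss, log_vars


loss, log_vars = parse_losses({
    'loss_cls': [0.5, 1.5],              # one tensor -> mean 1.0
    'loss_rpn_bbox': [[0.2], [0.4, 0.6]], # two levels -> 0.2 + 0.5
    'acc': [80.0],                        # logged only
})
print(round(loss, 6))  # 1.7
```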

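Here model(**data) is called directly. At this point all the surrounding training setup is in place, one forward pass is run, and what comes out of the model is already losses. In other words, the loss computation is part of the forward pass itself. So the next thing to examine is the model; here we look at BaseDetector, the parent class of all detectors: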

# mmdetection/mmdet/models/detectors/base.py
class BaseDetector(nn.Module, metaclass=ABCMeta):
    def forward(self, img, img_meta, return_loss=True, **kwargs):
        if return_loss:
            return self.forward_train(img, img_meta, **kwargs)
        else:
            return self.forward_test(img, img_meta, **kwargs)

    def forward_test(self, imgs, img_metas, **kwargs):
        ...
        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        else:
            return self.aug_test(imgs, img_metas, **kwargs)
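The dispatch logic can be reproduced without torch. `ToyBaseDetector` and `ToyDetector` below are hypothetical classes whose `forward` and `forward_test` follow the structure shown above. `__call__` stands in for `nn.Module.__call__`, which ends up in `forward`. At test time, `img` and `img_meta` are lists with one entry per test-time augmentation, so their length decides between `simple_test` and `aug_test`.

```python
from abc import ABCMeta, abstractmethod


class ToyBaseDetector(metaclass=ABCMeta):
    """Torch-free sketch of BaseDetector's forward dispatch (hypothetical)."""

    def __call__(self, *args, **kwargs):
        # nn.Module.__call__ eventually calls forward(); mimic that here
        return self.forward(*args, **kwargs)

    def forward(self, img, img_meta, return_loss=True, **kwargs):
        if return_loss:
            return self.forward_train(img, img_meta, **kwargs)
        return self.forward_test(img, img_meta, **kwargs)

    def forward_test(self, imgs, img_metas, **kwargs):
        num_augs = len(imgs)  # one entry per test-time augmentation
        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        return self.aug_test(imgs, img_metas, **kwargs)

    @abstractmethod
    def forward_train(self, img, img_meta, **kwargs):
        ...

    @abstractmethod
    def simple_test(self, img, img_meta, **kwargs):
        ...

    @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        ...


class ToyDetector(ToyBaseDetector):
    def forward_train(self, img, img_meta, gt_bboxes=None, **kwargs):
        return {'loss_cls': 0.3, 'loss_bbox': 0.1}

    def simple_test(self, img, img_meta, rescale=False):
        return 'simple'

    def aug_test(self, imgs, img_metas, rescale=False):
        return f'aug x{len(imgs)}'


model = ToyDetector()
print(model(img='x', img_meta={}, gt_bboxes=[]))                    # training: losses
print(model(return_loss=False, img=['x'], img_meta=[{}], rescale=True))  # 'simple'
```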

Next, look at forward_train and simple_test in TwoStageDetector:

# mmdetection/mmdet/models/detectors/two_stage.py
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
                       MaskTestMixin):
    def forward_train(...):
        ...
        return losses

    def simple_test(self, img, img_meta, proposals=None, rescale=False):
        ...
        if not self.with_mask:
            return bbox_results
        else:
            segm_results = self.simple_test_mask(
                x, img_meta, det_bboxes, det_labels, rescale=rescale)
            return bbox_results, segm_results
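The elided body of `forward_train` builds `losses` as a plain dict that each stage adds its own entries to. The following sketch shows that pattern together with the two return shapes of `simple_test`. `StubHead` and `ToyTwoStage` are hypothetical, the heads just return fixed numbers, and the result lists are placeholders. The loss key names follow mmdet's RPN, bbox and mask heads.

```python
class StubHead:
    """Hypothetical head whose loss() returns a fixed dict of losses."""
    def __init__(self, losses):
        self._losses = losses

    def loss(self):
        return dict(self._losses)


class ToyTwoStage:
    def __init__(self, with_mask):
        self.with_mask = with_mask
        self.rpn_head = StubHead({'loss_rpn_cls': 0.2, 'loss_rpn_bbox': 0.1})
        self.bbox_head = StubHead({'loss_cls': 0.5, 'loss_bbox': 0.4, 'acc': 90.0})
        self.mask_head = StubHead({'loss_mask': 0.3})

    def forward_train(self, img, img_meta, **kwargs):
        losses = dict()
        # each stage contributes its own entries to one flat dict
        losses.update(self.rpn_head.loss())
        losses.update(self.bbox_head.loss())
        if self.with_mask:
            losses.update(self.mask_head.loss())
        return losses

    def simple_test(self, img, img_meta, rescale=False):
        bbox_results = ['per-class boxes']  # placeholder for real detections
        if not self.with_mask:
            return bbox_results
        segm_results = ['per-class masks']  # placeholder for real masks
        return bbox_results, segm_results


print(sorted(ToyTwoStage(with_mask=True).forward_train(None, None)))
print(ToyTwoStage(with_mask=False).simple_test(None, None))  # ['per-class boxes']
```

The flat dict is exactly what `parse_losses` in `batch_processor` later reduces to one scalar.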

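From the code above, the behavior of model(**data) depends on return_loss. With return_loss=True, the model is in training mode and returns only losses. This is the default, which is why batch_processor does not need to pass it. With return_loss=False, the model is in inference mode and returns bbox_results, plus segm_results when the detector has a mask branch.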

This can also be confirmed in the test code, which explicitly passes return_loss=False:

# tools/test.py
def single_gpu_test(model, data_loader, show=False):
    model.eval()
    ···
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=not show, **data)
