The main job of test.py is to test or evaluate a model. A large part of the script is devoted to argument parsing, so let's go through the arguments in detail (an example invocation follows the list):
config: path to the test config file
checkpoint: path to the checkpoint file; a checkpoint is the conventional name for the model parameters (weights) saved out during training
work_dir: directory in which the file containing the evaluation metrics is saved
out: dump the predictions to a pickle file for offline evaluation
show: display the prediction results
show_dir: directory in which the painted prediction images are saved
wait_time: display interval between images, in seconds
cfg_options: override some settings in the config file with key=value pairs
launcher: the job launcher for distributed runs; the default value none means no distributed launch
tta: enable Test-Time Augmentation (data augmentation applied at test time)
local_rank: index of the GPU used by the current process
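For illustration, here is a minimal sketch of a typical invocation, driven from Python via subprocess so it stays in the same language as the rest of this walkthrough; all paths below are hypothetical placeholders, not files taken from the repository:

import subprocess

subprocess.run([
    'python', 'tools/test.py',
    'configs/retinanet/retinanet_r50_fpn_1x_coco.py',    # config (positional); hypothetical path
    'work_dirs/retinanet_r50_fpn_1x_coco/epoch_12.pth',  # checkpoint (positional); hypothetical path
    '--work-dir', 'work_dirs/retinanet_eval',            # where the metrics file goes
    '--out', 'results.pkl',                              # dump predictions for offline evaluation
    '--show-dir', 'vis',                                 # save painted prediction images
    '--cfg-options', 'test_dataloader.batch_size=1',     # key=value override merged into the config
    '--launcher', 'none',                                # no distributed launch
], check=True)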
The snippet below means: if the environment variable does not already specify the GPU index for the current process, fall back to the value passed via --local_rank:
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
The script then builds cfg from the config file together with the parsed command-line arguments, and finally instantiates a runner from cfg.
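Before walking through the full source, here is a minimal sketch of how a --cfg-options key=value pair ends up in the config: DictAction turns it into a dict with a dotted key, and Config.merge_from_dict merges it in. The config content below is made up purely for illustration:

from mmengine.config import Config

# A tiny in-memory config standing in for the real test config.
cfg = Config(dict(model=dict(type='RetinaNet'),
                  test_dataloader=dict(batch_size=2)))

# What DictAction produces from `--cfg-options test_dataloader.batch_size=1`.
cfg.merge_from_dict({'test_dataloader.batch_size': 1})
print(cfg.test_dataloader.batch_size)  # -> 1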
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import warnings
from copy import deepcopy
from mmengine import ConfigDict
from mmengine.config import Config, DictAction
from mmengine.runner import Runner
from mmdet.engine.hooks.utils import trigger_visualization_hook
from mmdet.evaluation import DumpDetResults
from mmdet.registry import RUNNERS
from mmdet.utils import setup_cache_size_limit_of_dynamo
# TODO: support fuse_conv_bn and format_only
def parse_args():
parser = argparse.ArgumentParser(
description='MMDet test (and eval) a model')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument(
'--work-dir',
help='the directory to save the file containing evaluation metrics')
parser.add_argument(
'--out',
type=str,
help='dump predictions to a pickle file for offline evaluation')
parser.add_argument(
'--show', action='store_true', help='show prediction results')
parser.add_argument(
'--show-dir',
help='directory where painted images will be saved. '
'If specified, it will be automatically saved '
'to the work_dir/timestamp/show_dir')
parser.add_argument(
'--wait-time', type=float, default=2, help='the interval of show (s)')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument('--tta', action='store_true')
# When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
# will pass the `--local-rank` parameter to `tools/train.py` instead
# of `--local_rank`.
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
def main():
args = parse_args()
# Reduce the number of repeated compilations and improve
# testing speed.
setup_cache_size_limit_of_dynamo()
# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
cfg.load_from = args.checkpoint
if args.show or args.show_dir:
cfg = trigger_visualization_hook(cfg, args)
if args.tta:
if 'tta_model' not in cfg:
warnings.warn('Cannot find ``tta_model`` in config, '
'we will set it as default.')
cfg.tta_model = dict(
type='DetTTAModel',
tta_cfg=dict(
nms=dict(type='nms', iou_threshold=0.5), max_per_img=100))
if 'tta_pipeline' not in cfg:
warnings.warn('Cannot find ``tta_pipeline`` in config, '
'we will set it as default.')
test_data_cfg = cfg.test_dataloader.dataset
while 'dataset' in test_data_cfg:
test_data_cfg = test_data_cfg['dataset']
cfg.tta_pipeline = deepcopy(test_data_cfg.pipeline)
flip_tta = dict(
type='TestTimeAug',
transforms=[
[
dict(type='RandomFlip', prob=1.),
dict(type='RandomFlip', prob=0.)
],
[
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape',
'img_shape', 'scale_factor', 'flip',
'flip_direction'))
],
])
cfg.tta_pipeline[-1] = flip_tta
cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model)
cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
# build the runner from config
if 'runner_type' not in cfg:
# build the default runner
runner = Runner.from_cfg(cfg)
else:
# build customized runner from the registry
# if 'runner_type' is set in the cfg
runner = RUNNERS.build(cfg)
# add `DumpResults` dummy metric
if args.out is not None:
assert args.out.endswith(('.pkl', '.pickle')), \
'The dump file must be a pkl file.'
runner.test_evaluator.metrics.append(
DumpDetResults(out_file_path=args.out))
# start testing
runner.test()
if __name__ == '__main__':
main()
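As the last block above shows, passing --out appends a DumpDetResults metric to the test evaluator, which writes one result dict per test image into the pickle file. A minimal sketch of loading that file back for offline inspection (the file name is whatever was passed to --out; the exact keys depend on the model and task):

from mmengine.fileio import load

results = load('results.pkl')   # list with one dict per test image
print(len(results))
print(results[0].keys())        # typically image meta plus predictions such as 'pred_instances'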
Execution is driven mainly by runner.test(), which launches testing. Let's take a look at runner.test():
class Runner:
def test(self) -> dict:
"""Launch test.
Returns:
dict: A dict of metrics on testing set.
"""
if self._test_loop is None:
raise RuntimeError(
'`self._test_loop` should not be None when calling test '
'method. Please provide `test_dataloader`, `test_cfg` and '
'`test_evaluator` arguments when initializing runner.')
self._test_loop = self.build_test_loop(self._test_loop) # type: ignore
self.call_hook('before_run')
# make sure checkpoint-related hooks are triggered after `before_run`
self.load_or_resume()
metrics = self.test_loop.run() # type: ignore
self.call_hook('after_run')
return metrics
The central object in the method above, self._test_loop, is an instance of mmengine.runner.loops.TestLoop. Its constructor takes the following arguments (a config sketch showing where they typically come from follows the list):
runner (Runner): a reference to the runner object
dataloader (DataLoader or dict): a DataLoader instance, or a dict from which one is built, used to iterate over batches of test data
evaluator (Evaluator or dict or list): the evaluator used to compute the metrics
fp16 (bool): whether to run inference in float16 (half precision)
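These arguments normally come straight from the config: the Runner builds the loop from test_cfg, the dataloader from test_dataloader, and the evaluator from test_evaluator. A sketch with purely illustrative values (not copied from a real config, and abridged; e.g. the dataset pipeline is omitted):

test_dataloader = dict(
    batch_size=1,
    num_workers=2,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='CocoDataset',
        data_root='data/coco/',                          # hypothetical path
        ann_file='annotations/instances_val2017.json',
        data_prefix=dict(img='val2017/')))
test_evaluator = dict(type='CocoMetric', metric='bbox')  # COCO bbox mAP
test_cfg = dict(type='TestLoop')                         # built into the TestLoop instance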
from typing import Dict, List, Sequence, Union

import torch
from torch.utils.data import DataLoader

from mmengine.evaluator import Evaluator
from mmengine.runner.amp import autocast
from mmengine.runner.base_loop import BaseLoop

class TestLoop(BaseLoop):
def __init__(self,
runner,
dataloader: Union[DataLoader, Dict],
evaluator: Union[Evaluator, Dict, List],
                 fp16: bool = False):
        super().__init__(runner, dataloader)
        # Build the evaluator from its config if necessary (simplified here:
        # the real __init__ also passes the dataset metainfo on to the
        # evaluator and the visualizer).
        if isinstance(evaluator, (dict, list)):
            self.evaluator = runner.build_evaluator(evaluator)
        else:
            self.evaluator = evaluator
        self.fp16 = fp16

def run(self) -> dict:
"""Launch test."""
self.runner.call_hook('before_test')
self.runner.call_hook('before_test_epoch')
self.runner.model.eval()
for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch)
# compute metrics
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
self.runner.call_hook('after_test_epoch', metrics=metrics)
self.runner.call_hook('after_test')
return metrics
@torch.no_grad()
def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
"""Iterate one mini-batch.
Args:
data_batch (Sequence[dict]): Batch of data from dataloader.
"""
self.runner.call_hook(
'before_test_iter', batch_idx=idx, data_batch=data_batch)
# predictions should be sequence of BaseDataElement
with autocast(enabled=self.fp16):
outputs = self.runner.model.test_step(data_batch)
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.runner.call_hook(
'after_test_iter',
batch_idx=idx,
data_batch=data_batch,
outputs=outputs)
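run() and run_iter() only rely on the evaluator's process/evaluate contract: process() is fed the outputs of every batch, and evaluate() aggregates everything at the end. A toy sketch of that contract using mmengine's BaseMetric and Evaluator (not mmdet code; the metric and the data are made up):

from mmengine.evaluator import BaseMetric, Evaluator

class ToyAccuracy(BaseMetric):
    def process(self, data_batch, data_samples):
        # Stash one partial result per sample; mmengine gathers self.results
        # across ranks before compute_metrics() is called.
        for sample in data_samples:
            self.results.append(int(sample['pred'] == sample['gt']))

    def compute_metrics(self, results):
        return dict(accuracy=sum(results) / len(results))

evaluator = Evaluator([ToyAccuracy()])
batch = [dict(pred=1, gt=1), dict(pred=0, gt=1)]
evaluator.process(data_samples=batch, data_batch=batch)  # once per batch in run_iter()
print(evaluator.evaluate(size=2))                        # {'accuracy': 0.5}, once in run()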
Below is the base class of TestLoop:
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Union

from torch.utils.data import DataLoader

class BaseLoop(metaclass=ABCMeta):
def __init__(self, runner, dataloader: Union[DataLoader, Dict]) -> None:
self._runner = runner
if isinstance(dataloader, dict):
# Determine whether or not different ranks use different seed.
diff_rank_seed = runner._randomness_cfg.get(
'diff_rank_seed', False)
self.dataloader = runner.build_dataloader(
dataloader, seed=runner.seed, diff_rank_seed=diff_rank_seed)
else:
self.dataloader = dataloader
@property
def runner(self):
return self._runner
@abstractmethod
def run(self) -> Any:
"""Execute loop."""