CenterFusion/src/test.py 项目验证执行文件详解

目录

  • 一、test.sh 脚本
  • 二、test.py 文件

文件位置:CenterFusion-master/experiments/test.shCenterFusion-master/src/test.py
文件作用:CenterFusion 项目验证的执行过程
注意:本文中的代码都是 CenterFusion 原始代码,一些参数没有修改

一、test.sh 脚本

  • 在 README.md 中训练模型的命令是:bash experiments/test.sh
  • 首先执行的就是 test.sh 脚本
  • 在脚本中 --参数 值 表示可选参数
export CUDA_VISIBLE_DEVICES=1
cd src

## 执行检测和评估
python test.py ddd \
    --exp_id centerfusion \
    '''
    项目名称
    '''
    --dataset nuscenes \
    '''
    设置 nuscenes 数据集
    '''
    --val_split mini_val \
    '''
    验证集
    '''
    --run_dataset_eval \
    '''
    在 eval 中使用数据集特定的计算函数
    '''
    --num_workers 4 \
    '''
    4 线程
    '''
    --nuscenes_att \
    --velocity \
    --gpus 0 \
    '''
    gpu 索引号
    '''
    --pointcloud \
    '''
    雷达点云
    '''
    --radar_sweeps 3 \
    '''
    点云图中雷达扫瞄 3 次
    '''
    --max_pc_dist 60.0 \
    '''
    移除 max_pc_dist 以外的雷达点
    '''
    --pc_z_offset -0.0 \
    '''
    向 z 方向升起所有雷达,高度为 -0.0
    '''
    --load_model ../models/centerfusion_e60.pth \
    '''
    导入模型
    '''
    --flip_test \
    '''
    翻转数据增加
    '''
    # --resume \

二、test.py 文件

  • 首先执行 main() 函数中的内容
if __name__ == '__main__':

  opt = opts().parse()
  '''
  调用 opts.py 中的 parse() 函数
  在 CenterFusion/src/lib/opts.py 第 305 行
  '''

  if opt.not_prefetch_test:
    test(opt)
  else:
    prefetch_test(opt)
  '''
  由于在 test.sh 中并没有添加参数 not_prefetch_test,所以 opt.not_prefetch_test = False
  最后只执行了 else 中的语句,调用 prefetch_test() 函数
  '''
  • opts.py 文件中的代码可以参考博客:CenterFusion/src/lib/opts.py 文件代码详解
  • 然后执行 test.py 中的 prefetch_test() 函数
def prefetch_test(opt):

  if not opt.not_set_cuda_env:
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str
  '''
  由于在 test.sh 中没有添加参数 not_set_cuda_env,所以 opt.not_set_cuda_env = False,再 not 取反
  执行该 if 语句
  gpus_str 是 opt 中的一个 GPU 索引号字符串,如:'0,1'
  这里是为了给系统添加 cuda 索引号
  '''

  Dataset = dataset_factory[opt.test_dataset]
  '''
  设置数据集对象 nuScenes,Dataset 是一个 nuScenes 类
  dataset_factory 在 CenterFusion/src/lib/dataset/dataset_factory.py 第 20 行
  nuScenes 对象定义在 CenterFusion/src/lib/dataset/datasets/nuscenes.py 中
  '''

  opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
  '''
  设置一些配置信息
  update_dataset_info_and_set_heads() 函数在 opys.py 第 458 行
  '''

  Logger(opt)
  '''
  新建一个 Logger 对象,记录配置信息
  Logger 对象在 CenterFusion/src/lib/logger.py
  '''
  
  split = 'val' if not opt.trainval else 'test'
  '''
  在 test.sh 中没有添加参数 trainval,所以 opt.trainval = False,再 not 取反,为 True
  所以 split = 'val'
  '''

  if split == 'val':
    split = opt.val_split
  '''
  val_split 在 opts.py 中的默认值为 'val',但在 test.sh 中添加了该参数,并赋值为 'mini_val'
  所以 split 的值为 'mini_val'
  '''

  dataset = Dataset(opt, split)
  '''
  传递参数 opt(配置信息)、split(数据集名称)
  '''

  detector = Detector(opt)
  '''
  Detector 类定义在 CenterFusion/src/lib/detector.py 中
  '''
  
  if opt.load_results != '':
    load_results = json.load(open(opt.load_results, 'r'))
    for img_id in load_results:
      for k in range(len(load_results[img_id])):
        if load_results[img_id][k]['class'] - 1 in opt.ignore_loaded_cats:
          load_results[img_id][k]['score'] = -1
  else:
    load_results = {}
  '''
  load_results 默认值为 ''
  所以执行了 else 语句
  '''

  data_loader = torch.utils.data.DataLoader(
    PrefetchDataset(opt, dataset, detector.pre_process), 
    batch_size=1, shuffle=False, num_workers=1, pin_memory=True)
  '''
  torch.utils.data.DataLoader 是一个数据读取的一个接口,参数:
    PrefetchDataset(opt, dataset, detector.pre_process):加载数据的数据集
    batch_size (int, optional):每个 batch 加载多少个样本(默认: 1)
    shuffle (bool, optional):设置为 True 时会在每个 epoch 重新打乱数据(默认: False)
    num_workers (int, optional):用多少个子进程加载数据。0 表示数据将在主进程中加载(默认: 0)
    pin_memory (bool, optional):设置 pin_memory=True,则意味着生成的 Tensor 数据最开始是属于内存中的锁页内存,
                                 这样将内存的 Tensor 转义到 GPU 的显存就会更快一些
  PrefetchDataset 类,在上面,这个类继承了 torch.utils.data.Dataset 类,表示自定义了数据读取方式
  最后返回一个列表给 data_loader,其中有图片 id 以及 tensor 格式的图片数据
  '''

  results = {}
  num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
  '''
  num_iters 默认值为 -1,所以 num_iters = data_loader 列表的长度
  '''
  
  bar = Bar('{}'.format(opt.exp_id), max=num_iters)
  '''
  定义了一个进度条,如:centerfusion |###       | 3/10
  进度条名称为:centerfusion
  进度条最大值为:num_iters
  '''

  time_stats = ['tot', 'load', 'pre', 'net', 'dec', 'post', 'merge', 'track']
  avg_time_stats = {t: AverageMeter() for t in time_stats}
  '''
  AverageMeter 类定义在 CenterFusion/src/lib/utils/utils.py 中
  给 time_stats 中的每一个属性赋值为 AverageMeter 类
  '''

  if opt.use_loaded_results:
    for img_id in data_loader.dataset.images:
      results[img_id] = load_results['{}'.format(img_id)]
    num_iters = 0
  '''
  在 test.sh 中没有添加参数 use_loaded_results ,所以值为 False
  没有执行该 if 语句
  '''

  for ind, (img_id, pre_processed_images) in enumerate(data_loader):
    '''
    enumerate() 函数:用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列
                    同时列出数据和数据下标
    其中 ind 表示对 data_loader 列表位置的计数
         img_id 为图片数据的 id
         pre_processed_images 为 tensor 格式的图片数据
    '''
    
    if ind >= num_iters:
      break
    '''
    如果遍历完 data_loader 就结束 
    '''

    if opt.tracking and ('is_first_frame' in pre_processed_images):
      if '{}'.format(int(img_id.numpy().astype(np.int32)[0])) in load_results:
        pre_processed_images['meta']['pre_dets'] = \
          load_results['{}'.format(int(img_id.numpy().astype(np.int32)[0]))]
      else:
        print()
        print('No pre_dets for', int(img_id.numpy().astype(np.int32)[0]), 
          '. Use empty initialization.')
        pre_processed_images['meta']['pre_dets'] = []
      detector.reset_tracking()
      print('Start tracking video', int(pre_processed_images['video_id']))
    '''
    由于在 test.sh 中没有添加参数 tracking,所以 opt.tracking = False
    没有执行该 if 语句
    '''

    if opt.public_det:
      if '{}'.format(int(img_id.numpy().astype(np.int32)[0])) in load_results:
        pre_processed_images['meta']['cur_dets'] = \
          load_results['{}'.format(int(img_id.numpy().astype(np.int32)[0]))]
      else:
        print('No cur_dets for', int(img_id.numpy().astype(np.int32)[0]))
        pre_processed_images['meta']['cur_dets'] = []
    '''
    由于在 test.sh 中没有添加参数 public_det,所以 opt.public_det = False
    没有执行该 if 语句
    '''

    ret = detector.run(pre_processed_images)
    '''
    run() 函数在 CenterFusion/src/lib/detector.py 第 56 行
    对 tensor 格式的图片数据进行检测,并返回检测结果
    '''

    results[int(img_id.numpy().astype(np.int32)[0])] = ret['results']
    '''
    其中 img_id.numpy().astype(np.int32) 是将 img_id 强制转换成 int32 型的数据
    这里是为了记录对应图片数据的检测结果,results[图片的索引号] = ret['result']
    '''
    
    Bar.suffix = '[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
                   ind, num_iters, total=bar.elapsed_td, eta=bar.eta_td)
    '''
    给 bar 添加一些显示字符串
    显示 ind、num_iters、bar.elapsed_td、bar.eta_td 的值
    '''

    for t in avg_time_stats:
      avg_time_stats[t].update(ret[t])
      Bar.suffix = Bar.suffix + '|{} {tm.val:.3f}s ({tm.avg:.3f}s) '.format(
        t, tm = avg_time_stats[t])
    '''
    update() 函数在 CenterFusion/src/lib/utils/utils.py 第 18 行
    计算 ret 中每个属性的平均值和当前值,并将其添加后 bar 的后面显示在屏幕上
    '''

    if opt.print_iter > 0:
      if ind % opt.print_iter == 0:
        print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
    else:
      bar.next()
    '''
    print_iter 默认值为 0,所以执行 else 语句
    打印进度条到屏幕上
    '''

  bar.finish()
  '''
  进度条完成
  '''

  if opt.save_results:
    print('saving results to', opt.save_dir + '/save_results_{}{}.json'.format(
      opt.test_dataset, opt.dataset_version))
    json.dump(_to_list(copy.deepcopy(results)), 
              open(opt.save_dir + '/save_results_{}{}.json'.format(
                opt.test_dataset, opt.dataset_version), 'w'))
  '''
  在 test.sh 中没有添加 save_results 参数,所以 opt.save_results = False
  没有执行该 if 语句
  '''
  
  dataset.run_eval(results, opt.save_dir, n_plots=opt.eval_n_plots, 
                   render_curves=opt.eval_render_curves)
  '''
  对结果进行检测评估,评估结果保存在 ~/CenterFusion/src/lib/../../exp/ddd/centerfusion/nuscenes_eval_det_output_mini_val 下
  run_eval() 函数在 CenterFusion/src/lib/dataset/datasets/nuscenes.py 第 272 行
    results :图片数据的检测结果
    save_dir :保存路径为 ~/CenterFusion/src/lib/../../exp/ddd/centerfusion
    eval_n_plots :默认值为 0
    eval_render_curves :渲染和保存评价曲线,在 test.sh 中没有添加该参数,则为 False
  '''
  • run_eval() 函数中的内容如下:
  def run_eval(self, results, save_dir, n_plots=10, render_curves=False):
    task = 'tracking' if self.opt.tracking else 'det'
    '''
    由于 test.sh 中没有添加参数 tracking,所以 opt.tracking 的值为 False
    所以 task = 'det'
    '''

    split = self.opt.val_split
    '''
    split = 'mini_val'
    '''

    version = 'v1.0-mini' if 'mini' in split else 'v1.0-trainval'
    '''
    version = 'v1.0-mini'
    '''

    self.save_results(results, save_dir, task, split)
    '''
    保存结果为 json 文件
    '''

    render_curves = 1 if render_curves else 0
    '''
    render_curves = 0
    '''
    
    if task == 'det':
      output_dir = '{}/nuscenes_eval_det_output_{}/'.format(save_dir, split)
      '''
      设置输出路径
      '''

      os.system('python ' + \
        'tools/nuscenes-devkit/python-sdk/nuscenes/eval/detection/evaluate.py ' + \
        '{}/results_nuscenes_{}_{}.json '.format(save_dir, task, split) + \
        '--output_dir {} '.format(output_dir) + \
        '--eval_set {} '.format(split) + \
        '--dataroot ../data/nuscenes/ ' + \
        '--version {} '.format(version) + \
        '--plot_examples {} '.format(n_plots) + \
        '--render_curves {} '.format(render_curves))
      '''
      执行官网 evaluate.py 文件
      对结果进行检测评估,并输出到 output_dir 路径下
      '''
      
    else:
      output_dir = '{}/nuscenes_evaltracl__output/'.format(save_dir)
      os.system('python ' + \
        'tools/nuscenes-devkit/python-sdk/nuscenes/eval/tracking/evaluate.py ' + \
        '{}/results_nuscenes_{}_{}.json '.format(save_dir, task, split) + \
        '--output_dir {} '.format(output_dir) + \
        '--dataroot ../data/nuscenes/')
      os.system('python ' + \
        'tools/nuscenes-devkit/python-sdk-alpha02/nuscenes/eval/tracking/evaluate.py ' + \
        '{}/results_nuscenes_{}_{}.json '.format(save_dir, task, split) + \
        '--output_dir {} '.format(output_dir) + \
        '--dataroot ../data/nuscenes/')
    
    return output_dir

你可能感兴趣的:(python,centerfusion)