【CenterFusion】测试执行过程CenterFusion/src/test.py

文件作用:CenterFusion 项目验证的执行过程

  • 首先执行 main() 函数中的内容.
if __name__ == '__main__':
  opt = opts().parse() 
  '''
  调用 opts.py 中的 parse() 函数.   
  在 CenterFusion/src/lib/opts.py 第 305 行
  '''           
  if opt.not_prefetch_test: # 如果opt不能预取test则进行 test(opt)
    test(opt) 
  else:                            
    prefetch_test(opt) # 如果可以取,则进行prefetch_test(opt)
  '''
  由于在 test.sh 中并没有添加参数 not_prefetch_test,所以 opt.not_prefetch_test = False
  最后只执行了 else 中的语句,调用 prefetch_test() 函数.   
  '''  
  • 然后执行 test.py 中的 prefetch_test() 函数
def prefetch_test(opt):  
  if not opt.not_set_cuda_env:  
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str
  '''
  由于在 test.sh 中没有添加参数 not_set_cuda_env,所以 opt.not_set_cuda_env = False,再 not 取反
  执行该 if 语句.  
  gpus_str 是 opt 中的一个 GPU 索引号字符串,如:'0,1'
  这里是为了给系统添加 cuda 索引号.      
  '''   
  Dataset = dataset_factory[opt.test_dataset]      
  '''        
  设置数据集对象 nuScenes,Dataset 是一个 nuScenes 类.  
  dataset_factory 在 CenterFusion/src/lib/dataset/dataset_factory.py 第 20 行
  nuScenes对象定义在 CenterFusion/src/lib/dataset/datasets/nuscenes.py 中. 
  '''      
  opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
  '''                    
  设置一些配置信息.      
  update_dataset_info_and_set_heads() 函数在 opys.py 第 458 行
  '''
  print(opt)
  Logger(opt)     
  '''
  新建一个 Logger 对象,记录配置信息.    
  Logger 对象在 CenterFusion/src/lib/logger.py
  '''
  
  split = 'val' if not opt.trainval else 'test'
  '''
  在 test.sh 中没有添加参数 trainval,所以 opt.trainval = False,再 not 取反,为 True
  所以 split = 'val'
  '''
  if split == 'val':
    split = opt.val_split
    '''
    val_split 在 opts.py 中的默认值为 'val',但在 test.sh 中添加了该参数,并赋值为 'mini_val'
    所以 split 的值为 'mini_val'
    '''
  dataset = Dataset(opt, split) 
  ''' 
  传递参数 opt(配置信息)、split(数据集名称)
  '''
  detector = Detector(opt)  
  '''
  Detector 类定义在 CenterFusion/src/lib/detector.py 中
  '''
  if opt.load_results != '':
    load_results = json.load(open(opt.load_results, 'r'))
    for img_id in load_results:
      for k in range(len(load_results[img_id])):
        if load_results[img_id][k]['class'] - 1 in opt.ignore_loaded_cats:
          load_results[img_id][k]['score'] = -1  
  else:
    load_results = {}
  '''
  load_results 默认值为 ''  
  所以执行了 else 语句
  '''
  data_loader = torch.utils.data.DataLoader(  
    PrefetchDataset(opt, dataset, detector.pre_process), 
    batch_size=1, shuffle=False, num_workers=1, pin_memory=True)  
  '''
  torch.utils.data.DataLoader 是一个数据读取的一个接口,参数:
    PrefetchDataset(opt, dataset, detector.pre_process):加载数据的数据集
    batch_size (int, optional):每个 batch 加载多少个样本(默认: 1)
    shuffle (bool, optional):设置为 True 时会在每个 epoch 重新打乱数据(默认: False)
    num_workers (int, optional):用多少个子进程加载数据。0 表示数据将在主进程中加载(默认: 0)
    pin_memory (bool, optional):设置 pin_memory=True,则意味着生成的 Tensor 数据最开始是属于内存中的锁页内存,
                                 这样将内存的 Tensor 转义到 GPU 的显存就会更快一些
  PrefetchDataset 类,在上面,这个类继承了 torch.utils.data.Dataset 类,表示自定义了数据读取方式
  最后返回一个列表给 data_loader,其中有图片 id 以及 tensor 格式的图片数据
  '''

  results = {}
  num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
  '''
  num_iters 默认值为 -1,所以 num_iters = data_loader 列表的长度
  '''
  bar = Bar('{}'.format(opt.exp_id), max=num_iters)
  '''
  定义了一个进度条,如:centerfusion |###       | 3/10
  进度条名称为:centerfusion
  进度条最大值为:num_iters
  '''
  time_stats = ['tot', 'load', 'pre', 'net', 'dec', 'post', 'merge', 'track']
  avg_time_stats = {t: AverageMeter() for t in time_stats}
  '''       
  AverageMeter 类定义在 CenterFusion/src/lib/utils/utils.py 中
  给 time_stats 中的每一个属性赋值为 AverageMeter 类
  '''
  if opt.use_loaded_results:
    for img_id in data_loader.dataset.images:
      results[img_id] = load_results['{}'.format(img_id)]
    num_iters = 0
  ''' 
  在 test.sh 中没有添加参数 use_loaded_results ,所以值为 False
  没有执行该 if 语句  
  '''
  for ind, (img_id, pre_processed_images) in enumerate(data_loader):
    '''
    enumerate() 函数:用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列
                    同时列出数据和数据下标
    其中 ind 表示对 data_loader 列表位置的计数
         img_id 为图片数据的 id
         pre_processed_images 为 tensor 格式的图片数据
    '''
    if ind >= num_iters:
      break
    '''
    如果遍历完 data_loader 就结束.    
    '''
    if opt.tracking and ('is_first_frame' in pre_processed_images):
      if '{}'.format(int(img_id.numpy().astype(np.int32)[0])) in load_results:
        pre_processed_images['meta']['pre_dets'] = \  
          load_results['{}'.format(int(img_id.numpy().astype(np.int32)[0]))]
      else:
        print()
        print('No pre_dets for', int(img_id.numpy().astype(np.int32)[0]), 
          '. Use empty initialization.') 
        pre_processed_images['meta']['pre_dets'] = []
      detector.reset_tracking()
      print('Start tracking video', int(pre_processed_images['video_id']))
    '''
    由于在 test.sh 中没有添加参数 tracking,所以 opt.tracking = False
    没有执行该 if 语句
    '''
    if opt.public_det:
      if '{}'.format(int(img_id.numpy().astype(np.int32)[0])) in load_results:
        pre_processed_images['meta']['cur_dets'] = \
          load_results['{}'.format(int(img_id.numpy().astype(np.int32)[0]))]
      else:
        print('No cur_dets for', int(img_id.numpy().astype(np.int32)[0]))
        pre_processed_images['meta']['cur_dets'] = []
    '''
    由于在 test.sh 中没有添加参数 public_det,所以 opt.public_det = False
    没有执行该 if 语句
    '''
    ret = detector.run(pre_processed_images)
    '''
    run() 函数在 CenterFusion/src/lib/detector.py 第 56 行
    对 tensor 格式的图片数据进行检测,并返回检测结果
    '''
    results[int(img_id.numpy().astype(np.int32)[0])] = ret['results']
    '''
    其中 img_id.numpy().astype(np.int32) 是将 img_id 强制转换成 int32 型的数据
    这里是为了记录对应图片数据的检测结果,results[图片的索引号] = ret['result']
    '''
    
    Bar.suffix = '[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
                   ind, num_iters, total=bar.elapsed_td, eta=bar.eta_td)
    '''
    给 bar 添加一些显示字符串
    显示 ind、num_iters、bar.elapsed_td、bar.eta_td 的值
    '''
    for t in avg_time_stats:  
      avg_time_stats[t].update(ret[t])
      Bar.suffix = Bar.suffix + '|{} {tm.val:.3f}s ({tm.avg:.3f}s) '.format(
        t, tm = avg_time_stats[t])
    '''
    update() 函数在 CenterFusion/src/lib/utils/utils.py 第 18 行
    计算 ret 中每个属性的平均值和当前值,并将其添加后 bar 的后面显示在屏幕上
    '''
    if opt.print_iter > 0:
      if ind % opt.print_iter == 0:
        print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
    else:
      bar.next()
    '''
    print_iter 默认值为 0,所以执行 else 语句
    打印进度条到屏幕上
    '''
  bar.finish()
  '''
  进度条完成
  '''
  if opt.save_results:
    print('saving results to', opt.save_dir + '/save_results_{}{}.json'.format(
      opt.test_dataset, opt.dataset_version))   
    json.dump(_to_list(copy.deepcopy(results)), 
              open(opt.save_dir + '/save_results_{}{}.json'.format(
                opt.test_dataset, opt.dataset_version), 'w'))
  '''
  在 test.sh 中没有添加 save_results 参数,所以 opt.save_results = False
  没有执行该 if 语句   
  '''
  dataset.run_eval(results, opt.save_dir, n_plots=opt.eval_n_plots, 
                   render_curves=opt.eval_render_curves)  
  '''
  对结果进行检测评估,评估结果保存在 ~/CenterFusion/src/lib/../../exp/ddd/centerfusion/nuscenes_eval_det_output_mini_val 下
  run_eval() 函数在 CenterFusion/src/lib/dataset/datasets/nuscenes.py 第 272 行
    results :图片数据的检测结果
    save_dir :保存路径为 ~/CenterFusion/src/lib/../../exp/ddd/centerfusion
    eval_n_plots :默认值为 0
    eval_render_curves :渲染和保存评价曲线,在 test.sh 中没有添加该参数,则为 False
  '''
  • run_eval() 函数在 CenterFusion/src/lib/dataset/datasets/nuscenes.py
  • run_eval() 函数中的内容如下:
def run_eval(self, results, save_dir, n_plots=10, render_curves=False):
    task = 'tracking' if self.opt.tracking else 'det'
    '''
    由于 test.sh 中没有添加参数 tracking,所以 opt.tracking 的值为 False
    所以 task = 'det'
    '''

    split = self.opt.val_split  
    '''
    split = 'mini_val'
    '''

    version = 'v1.0-mini' if 'mini' in split else 'v1.0-trainval'
    #================================  源代码 版本判断有Bug 修改位置如下:======================================
    if version == 'v1.0-mini' :
         flag=1
         print("当前数据集选择的版本是:",version)
    elif version == 'v1.0-test': 
         print("当前数据集选择的版本是:",version)
    elif version == 'v1.0-trainval':
         print("当前数据集选择的版本是:",version)
    else:
         print("您设置的数据集版本不存在,请重新输入!")
    #===========================================================================================================
    ''' 目前存在的数据集名称列表有
    split_names = {
        'mini_train':'mini_train', 
        'mini_val':'mini_val',
        'train': 'train', 
        'train_detect': 'train_detect',
        'train_track':'train_track', 
        'val': 'val',
        'test': 'test',
        'mini_train_2': 'mini_train_2',
        'trainval': 'trainval',
    }
    '''
    '''
    version = 'v1.0-mini'
    '''  
    self.save_results(results, save_dir, task, split)
    '''
    保存结果为 json 文件
    '''

    render_curves = 1 if render_curves else 0
    '''
    render_curves = 0
    '''
    
    if task == 'det':
      output_dir = '{}/nuscenes_eval_det_output_{}/'.format(save_dir, split)
      '''
      设置输出路径
      '''

      os.system('python ' + \
        'tools/nuscenes-devkit/python-sdk/nuscenes/eval/detection/evaluate.py ' + \
        '{}/results_nuscenes_{}_{}.json '.format(save_dir, task, split) + \
        '--output_dir {} '.format(output_dir) + \
        '--eval_set {} '.format(split) + \
        '--dataroot ../data/nuscenes/ ' + \ 
        '--version {} '.format(version) + \
        '--plot_examples {} '.format(n_plots) + \
        '--render_curves {} '.format(render_curves)) 
      '''   
      执行官网 evaluate.py 文件  
      对结果进行检测评估,并输出到 output_dir 路径下.
      ''' 
     
    else:  
      output_dir = '{}/nuscenes_evaltracl__output/'.format(save_dir)
      os.system('python ' + \
        'tools/nuscenes-devkit/python-sdk/nuscenes/eval/tracking/evaluate.py ' + \
        '{}/results_nuscenes_{}_{}.json '.format(save_dir, task, split) + \
        '--output_dir {} '.format(output_dir) + \
        '--dataroot ../data/nuscenes/')
      os.system('python ' + \
        'tools/nuscenes-devkit/python-sdk-alpha02/nuscenes/eval/tracking/evaluate.py ' + \
        '{}/results_nuscenes_{}_{}.json '.format(save_dir, task, split) + \
        '--output_dir {} '.format(output_dir) + \
        '--dataroot ../data/nuscenes/')
    
    return output_dir

你可能感兴趣的:(CenterFusion,人工智能,python,深度学习,目标检测,自动驾驶)