CenterFusion/src/lib/opts.py 文件代码详解

文件内容:CenterFusion/src/lib/opts.py
文件作用:train.sh 脚本中参数的处理

  • 这里需要对添加参数部分说明一点点
  • 比如:
self.parser.add_argument('--not_set_cuda_env', action='store_true',
                             help='used when training in slurm clusters.')
  • action:脚本中添加了该参数时,它的值则为 True,没有添加则为 False
    help:对参数的说明

下面的 opts.py 文件中的具体内容:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys

class opts(object):
  def __init__(self):
    self.parser = argparse.ArgumentParser()
    # basic experiment setting
    self.parser.add_argument('task', default='',
                             help='ctdet | ddd | multi_pose '
                             '| tracking or combined with ,')
    self.parser.add_argument('--dataset', default='nuscenes',
                             help='see lib/dataset/dataset_facotry for ' + 
                            'available datasets')
    self.parser.add_argument('--test_dataset', default='',
                             help='coco | kitti | coco_hp | pascal')
    self.parser.add_argument('--exp_id', default='default')
    self.parser.add_argument('--eval', action='store_true',
                             help='only evaluate the val split and quit 只评估val split和quit')
    self.parser.add_argument('--debug', type=int, default=0,
                             help='level of visualization.'
                                  '1: only show the final detection results'
                                  '2: show the network output features'
                                  '3: use matplot to display' # useful when lunching training with ipython notebook
                                  '4: save all visualizations to disk')
    self.parser.add_argument('--no_pause', action='store_true',
                             help='do not pause after debugging visualizations')
    self.parser.add_argument('--demo', default='', 
                             help='path to image/ image folders/ video. '
                                  'or "webcam"')
    self.parser.add_argument('--load_model', default='',
                             help='path to pretrained model')
    self.parser.add_argument('--resume', action='store_true',
                             help='resume an experiment. '
                                  'Reloaded the optimizer parameter and '
                                  'set load_model to model_last.pth '
                                  'in the exp dir if load_model is empty.') 

    # system
    self.parser.add_argument('--gpus', default='0', 
                             help='-1 for CPU, use comma for multiple gpus,-1表示CPU,逗号表示多个gpu')
    self.parser.add_argument('--num_workers', type=int, default=4,
                             help='dataloader threads. 0 for single-thread. dataloader线程。0为单线程')
    self.parser.add_argument('--not_cuda_benchmark', action='store_true',
                             help='disable when the input size is not fixed.')
    self.parser.add_argument('--seed', type=int, default=317, 
                             help='random seed') # from CornerNet
    self.parser.add_argument('--not_set_cuda_env', action='store_true',
                             help='used when training in slurm clusters.在slurm训练时使用')

    # log
    self.parser.add_argument('--print_iter', type=int, default=0, 
                             help='disable progress bar and print to screen.')
    self.parser.add_argument('--save_all', action='store_true',
                             help='save model to disk every 5 epochs. 每5个epoch将模型保存到磁盘')
    self.parser.add_argument('--vis_thresh', type=float, default=0.3,
                             help='visualization threshold.')
    self.parser.add_argument('--debugger_theme', default='white', 
                             choices=['white', 'black'])
    self.parser.add_argument('--run_dataset_eval', action='store_true',
                             help='use dataset specific evaluation function in eval')
    self.parser.add_argument('--save_imgs', default='',
                             help='list of images to save in debug. empty to save all 要在调试中保存的图像列表。为空保存所有')
    self.parser.add_argument('--save_img_suffix', default='', help='')
    self.parser.add_argument('--skip_first', type=int, default=-1,
                             help='skip first n images in demo mode')
    self.parser.add_argument('--save_video', action='store_true')
    self.parser.add_argument('--save_framerate', type=int, default=30)
    self.parser.add_argument('--resize_video', action='store_true')
    self.parser.add_argument('--video_h', type=int, default=512, help='')
    self.parser.add_argument('--video_w', type=int, default=512, help='')
    self.parser.add_argument('--transpose_video', action='store_true')
    self.parser.add_argument('--show_track_color', action='store_true')
    self.parser.add_argument('--not_show_bbox', action='store_true')
    self.parser.add_argument('--not_show_number', action='store_true')
    self.parser.add_argument('--qualitative', action='store_true')
    self.parser.add_argument('--tango_color', action='store_true')

    # model
    self.parser.add_argument('--arch', default='dla_34', 
                             help='model architecture. Currently tested'
                                  'res_18 | res_101 | resdcn_18 | resdcn_101 |'
                                  'dlav0_34 | dla_34 | hourglass')
    self.parser.add_argument('--dla_node', default='dcn') 
    self.parser.add_argument('--head_conv', type=int, default=-1,
                             help='conv layer channels for output head'
                                  '0 for no conv layer'
                                  '-1 for default setting: '
                                  '64 for resnets and 256 for dla.')
    self.parser.add_argument('--num_head_conv', type=int, default=1,
                             help='number of conv layers before each output head')
    self.parser.add_argument('--head_kernel', type=int, default=3, help='')
    self.parser.add_argument('--down_ratio', type=int, default=4,
                             help='output stride. Currently only supports 4.')
    # self.parser.add_argument('--not_idaup', action='store_true')
    self.parser.add_argument('--num_classes', type=int, default=-1)
    self.parser.add_argument('--num_resnet_layers', type=int, default=101)
    self.parser.add_argument('--backbone', default='dla34',
                             help='backbone for the generic detection network')
    self.parser.add_argument('--neck', default='dlaup',
                             help='neck for the generic detection network')
    self.parser.add_argument('--msra_outchannel', type=int, default=256)
    # self.parser.add_argument('--efficient_level', type=int, default=0)
    self.parser.add_argument('--prior_bias', type=float, default=-4.6) # -2.19

    # input
    self.parser.add_argument('--input_res', type=int, default=-1, 
                             help='input height and width. -1 for default from '
                             'dataset. Will be overriden by input_h | input_w')
    self.parser.add_argument('--input_h', type=int, default=-1, 
                             help='input height. -1 for default from dataset.')
    self.parser.add_argument('--input_w', type=int, default=-1, 
                             help='input width. -1 for default from dataset.')
    self.parser.add_argument('--dataset_version', default='')

    # train
    self.parser.add_argument('--optim', default='adam')
    self.parser.add_argument('--lr', type=float, default=1.25e-4, 
                             help='learning rate for batch size 32.')
    self.parser.add_argument('--lr_step', type=str, default='60',
                             help='drop learning rate by 10. 学习速度除以10')
    self.parser.add_argument('--save_point', type=str, default='90',
                             help='when to save the model to disk. 何时将模型保存到磁盘')
    self.parser.add_argument('--num_epochs', type=int, default=70,
                              help='total training epochs.')
    self.parser.add_argument('--batch_size', type=int, default=32,
                             help='batch size')
    self.parser.add_argument('--master_batch_size', type=int, default=-1,
                             help='batch size on the master gpu.')
    self.parser.add_argument('--num_iters', type=int, default=-1,
                             help='default: #samples / batch_size.')
    self.parser.add_argument('--val_intervals', type=int, default=10,
                             help='number of epochs to run validation. 运行验证的纪元数')
    self.parser.add_argument('--trainval', action='store_true',
                             help='include validation in training and '
                                  'test on test set')
    self.parser.add_argument('--ltrb', action='store_true',
                             help='')          
    self.parser.add_argument('--ltrb_weight', type=float, default=0.1,
                             help='')
    self.parser.add_argument('--reset_hm', action='store_true')
    self.parser.add_argument('--reuse_hm', action='store_true')
    # self.parser.add_argument('--use_kpt_center', action='store_true')
    # self.parser.add_argument('--add_05', action='store_true')
    self.parser.add_argument('--dense_reg', type=int, default=1, help='')
    self.parser.add_argument('--shuffle_train', action='store_true',
                             help='shuffle training dataloader')

    # test
    self.parser.add_argument('--flip_test', action='store_true',
                             help='flip data augmentation.')
    self.parser.add_argument('--test_scales', type=str, default='1',
                             help='multi scale test augmentation.')
    self.parser.add_argument('--nms', action='store_true',
                             help='run nms in testing.')
    self.parser.add_argument('--K', type=int, default=100,
                             help='max number of output objects.') 
    self.parser.add_argument('--not_prefetch_test', action='store_true',
                             help='not use parallal data pre-processing.')
    self.parser.add_argument('--fix_short', type=int, default=-1)
    self.parser.add_argument('--keep_res', action='store_true',
                             help='keep the original resolution'
                                  ' during validation. 在验证期间保持原始分辨率')
    # self.parser.add_argument('--map_argoverse_id', action='store_true',
    #                          help='if trained on nuscenes and eval on kitti')
    self.parser.add_argument('--out_thresh', type=float, default=-1,
                             help='')
    self.parser.add_argument('--depth_scale', type=float, default=1,
                             help='')
    self.parser.add_argument('--save_results', action='store_true')
    self.parser.add_argument('--load_results', default='')
    self.parser.add_argument('--use_loaded_results', action='store_true')
    self.parser.add_argument('--ignore_loaded_cats', default='')
    self.parser.add_argument('--model_output_list', action='store_true',
                             help='Used when convert to onnx')
    self.parser.add_argument('--non_block_test', action='store_true')
    self.parser.add_argument('--vis_gt_bev', default='',
                             help='path to gt bev images')
    self.parser.add_argument('--kitti_split', default='3dop',
                             help='different validation split for kitti: '
                                  '3dop | subcnn')
    self.parser.add_argument('--test_focal_length', type=int, default=-1)

    # dataset
    self.parser.add_argument('--not_rand_crop', action='store_true',
                             help='not use the random crop data augmentation'
                                  'from CornerNet.')
    self.parser.add_argument('--not_max_crop', action='store_true',
                             help='used when the training dataset has'
                                  'inbalanced aspect ratios.')
    self.parser.add_argument('--shift', type=float, default=0,
                             help='when not using random crop, 0.1'
                                  'apply shift augmentation.')
    self.parser.add_argument('--scale', type=float, default=0,
                             help='when not using random crop, 0.4'
                                  'apply scale augmentation.')
    self.parser.add_argument('--aug_rot', type=float, default=0, 
                             help='probability of applying '
                                  'rotation augmentation.')
    self.parser.add_argument('--rotate', type=float, default=0,
                             help='when not using random crop'
                                  'apply rotation augmentation.')
    self.parser.add_argument('--flip', type=float, default=0.5,
                             help='probability of applying flip augmentation.')
    self.parser.add_argument('--no_color_aug', action='store_true',
                             help='not use the color augmenation '
                                  'from CornerNet')

    # Tracking
    self.parser.add_argument('--tracking', action='store_true')
    self.parser.add_argument('--pre_hm', action='store_true')
    self.parser.add_argument('--same_aug_pre', action='store_true')
    self.parser.add_argument('--zero_pre_hm', action='store_true')
    self.parser.add_argument('--hm_disturb', type=float, default=0)
    self.parser.add_argument('--lost_disturb', type=float, default=0)
    self.parser.add_argument('--fp_disturb', type=float, default=0)
    self.parser.add_argument('--pre_thresh', type=float, default=-1)
    self.parser.add_argument('--track_thresh', type=float, default=0.3)
    self.parser.add_argument('--new_thresh', type=float, default=0.3)
    self.parser.add_argument('--max_frame_dist', type=int, default=3)
    self.parser.add_argument('--ltrb_amodal', action='store_true')
    self.parser.add_argument('--ltrb_amodal_weight', type=float, default=0.1)
    self.parser.add_argument('--public_det', action='store_true')
    self.parser.add_argument('--no_pre_img', action='store_true')
    self.parser.add_argument('--zero_tracking', action='store_true')
    self.parser.add_argument('--hungarian', action='store_true')
    self.parser.add_argument('--max_age', type=int, default=-1)


    # loss
    self.parser.add_argument('--tracking_weight', type=float, default=1)
    self.parser.add_argument('--reg_loss', default='l1',
                             help='regression loss: sl1 | l1 | l2')
    self.parser.add_argument('--hm_weight', type=float, default=1,
                             help='loss weight for keypoint heatmaps.')
    self.parser.add_argument('--off_weight', type=float, default=1,
                             help='loss weight for keypoint local offsets.')
    self.parser.add_argument('--wh_weight', type=float, default=0.1,
                             help='loss weight for bounding box size.边框尺寸的损失重量')
    self.parser.add_argument('--hp_weight', type=float, default=1,
                             help='loss weight for human pose offset.')
    self.parser.add_argument('--hm_hp_weight', type=float, default=1,
                             help='loss weight for human keypoint heatmap.')
    self.parser.add_argument('--amodel_offset_weight', type=float, default=1,
                             help='Please forgive the typo.')
    self.parser.add_argument('--dep_weight', type=float, default=1,
                             help='loss weight for depth.')
    self.parser.add_argument('--dep_res_weight', type=float, default=1,
                             help='loss weight for depth residual.')
    self.parser.add_argument('--dim_weight', type=float, default=1,
                             help='loss weight for 3d bounding box size.')
    self.parser.add_argument('--rot_weight', type=float, default=1,
                             help='loss weight for orientation.')
    self.parser.add_argument('--nuscenes_att', action='store_true')
    self.parser.add_argument('--nuscenes_att_weight', type=float, default=1)
    self.parser.add_argument('--velocity', action='store_true')
    self.parser.add_argument('--velocity_weight', type=float, default=1)

    # custom dataset
    self.parser.add_argument('--custom_dataset_img_path', default='')
    self.parser.add_argument('--custom_dataset_ann_path', default='')

    # point clouds and nuScenes dataset
    self.parser.add_argument('--pointcloud', action='store_true')
    self.parser.add_argument('--train_split', default='train',
                             choices=['train','mini_train', 'train_detect', 'train_track', 'mini_train_2', 'trainval'])
    self.parser.add_argument('--val_split', default='val',
                             choices=['val','mini_val','test'])
    self.parser.add_argument('--max_pc', type=int, default=1000,
                             help='maximum number of points in the point cloud')
    self.parser.add_argument('--r_a', type=float, default=250,
                             help='alpha parameter for hm size calculation')
    self.parser.add_argument('--r_b', type=float, default=5,
                             help='beta parameter for hm size calculation')
    self.parser.add_argument('--img_format', default='jpg',
                             help='debug image format')
    self.parser.add_argument('--max_pc_dist', type=float, default=100.0,
                             help='remove points beyond max_pc_dist meters')
    self.parser.add_argument('--freeze_backbone', action='store_true',
                             help='freeze the backbone network and only train heads 冻结骨干网络,仅限train头部')
    self.parser.add_argument('--radar_sweeps', type=int, default=1,
                             help='number of radar sweeps in point cloud')
    self.parser.add_argument('--warm_start_weights', action='store_true',
                             help='try to reuse weights even if dimensions dont match')
    self.parser.add_argument('--pc_z_offset', type=float, default=0,
                             help='raise all Radar points in z direction')
    self.parser.add_argument('--eval_n_plots', type=int, default=0,
                             help='number of sample plots drawn in eval')
    self.parser.add_argument('--eval_render_curves', action='store_true',
                             help='render and save evaluation curves')
    self.parser.add_argument('--hm_transparency', type=float, default=0.7,
                             help='heatmap visualization transparency')
    self.parser.add_argument('--iou_thresh', type=float, default=0,
                             help='IOU threshold for filtering overlapping detections')
    self.parser.add_argument('--pillar_dims', type=str, default='2,0.5,0.5',
                             help='Radar pillar dimensions (h,w,l)')
    self.parser.add_argument('--show_velocity', action='store_true')
    
    

  def parse(self, args=''):

    if args == '':
      opt = self.parser.parse_args()
    else:
      opt = self.parser.parse_args(args)
    '''
    把 parser 中设置的所有 "add_argument" 给返回到 args 子类实例当中
    '''
  
    if opt.test_dataset == '':
      opt.test_dataset = opt.dataset
    '''
    设置数据集为 nuscenes
    test_dataset 默认值为 '' ,dataset 默认值为 nuscenes
    '''
    
    opt.gpus_str = opt.gpus
    '''
    为 opt 添加一个新的参数 gpus_str,用来临时保存 gpus 的值
    gpus 默认值为 0,是一个字符串,在 train.sh 中的值为 0,1
    '''

    opt.gpus = [int(gpu) for gpu in opt.gpus.split(',')]
    '''
    split() 函数:拆分字符串,通过指定分隔符对字符串进行切片,并返回分割后的字符串列表(list)
    这里是将字符串 '0,1' 整数化成整数型数组 [0, 1]
    '''

    opt.gpus = [i for i in range(len(opt.gpus))] if opt.gpus[0] >=0 else [-1]
    '''
    重新设置 GPU 索引号,其结果依然没变
    '''

    opt.lr_step = [int(i) for i in opt.lr_step.split(',')]
    '''
    lr_step 默认值为 60,在 train.sh 中的值为 50
    这里也是将字符串处理成整数型数组
    最后 lr_step 值为 [50]
    '''

    opt.save_point = [int(i) for i in opt.save_point.split(',')]
    '''
    save_point 默认值为 50,在 train.sh 中的值为 20,40,50
    这里也是将字符串处理成整数型数组
    save_point 参数意义:何时将模型保存到磁盘
    '''

    opt.test_scales = [float(i) for i in opt.test_scales.split(',')]
    '''
    test_scales 默认值为 1
    这里是将 test_scales 处理成浮点型
    test_scales 参数意义:多尺度测试增强
    '''

    opt.save_imgs = [i for i in opt.save_imgs.split(',')] \
      if opt.save_imgs != '' else []
    '''
    save_imgs 默认值为 ''
    save_imgs 参数意义:要在调试中保存的图像列表。为空保存所有
    '''
    
    opt.ignore_loaded_cats = \
      [int(i) for i in opt.ignore_loaded_cats.split(',')] \
      if opt.ignore_loaded_cats != '' else []
    '''
    ignore_loaded_cats 默认值为 ''
    '''

    opt.num_workers = max(opt.num_workers, 2 * len(opt.gpus))
    '''
    num_workers 默认值为 4
    最后 num_workers 的值为 4
    num_workers 参数意义:dataloader 线程,0 为单线程
    '''

    opt.pre_img = False
    '''
    为 opt 添加了一个新的参数 pre_img,它的值为 False
    '''

    if 'tracking' in opt.task:
      print('Running tracking')
      opt.tracking = True
      opt.out_thresh = max(opt.track_thresh, opt.out_thresh)
      opt.pre_thresh = max(opt.track_thresh, opt.pre_thresh)
      opt.new_thresh = max(opt.track_thresh, opt.new_thresh)
      opt.pre_img = not opt.no_pre_img
      print('Using tracking threshold for out threshold!', opt.track_thresh)
      if 'ddd' in opt.task:
        opt.show_track_color = True
    '''
    tast 默认值为 '',但在 train.sh 中赋值为 ddd,则 opt.task = 'ddd'
    所以该 if 语句没有执行
    '''

    opt.fix_res = not opt.keep_res
    print('Fix size testing.' if opt.fix_res else 'Keep resolution testing.')
    '''
    keep_res 由于 train.sh 没有添加该参数,所以值为 False
    最后新添加的参数 fix_res 的值为 true,则打印 'Fix size testing.'
    keep_res 参数意义:在验证期间保持原始分辨率
    '''

    if opt.head_conv == -1:
      opt.head_conv = 256 if 'dla' in opt.arch else 64
    '''
    head_conv 默认值为 -1
    arch 默认值为 dla_34
    最后 head_conv 的值为 256
    head_conv 参数意义:输出头的转换层通道
                           0 表示没有 conv 层
                          -1 默认设置:
                                resnets 是 64
                                dla 是 256
    '''

    opt.pad = 127 if 'hourglass' in opt.arch else 31
    '''
    新添加参数 pad 的值为 31
    '''

    opt.num_stacks = 2 if opt.arch == 'hourglass' else 1
    '''
    新添加参数 num_stacks 的值为 1
    '''

    if opt.master_batch_size == -1:
      opt.master_batch_size = opt.batch_size // len(opt.gpus)
    '''
    batch_size 默认值为 32,在 train.sh 中的值也为 32
    master_batch_size 默认值为 -1
    最后 master_batch_size 的值为 16
    master_batch_size 参数意义:主图形处理器上的批处理大小
    '''

    rest_batch_size = (opt.batch_size - opt.master_batch_size)
    '''
    rest_batch_size 的值为 32 - 16 = 16
    '''

    opt.chunk_sizes = [opt.master_batch_size]
    '''
    添加新参数 chunk_sizes 的值为 [16]
    '''

    for i in range(len(opt.gpus) - 1):
      slave_chunk_size = rest_batch_size // (len(opt.gpus) - 1)
      if i < rest_batch_size % (len(opt.gpus) - 1):
        slave_chunk_size += 1
      opt.chunk_sizes.append(slave_chunk_size)
    '''
    根据 GPU 的数量设置训练块的大小
    一块 GPU 的训练块大小为 32
    两块 GPU 其中每个训练块大小为 16
    '''

    if opt.debug > 0:
      opt.num_workers = 0
      opt.batch_size = 1
      opt.gpus = [opt.gpus[0]]
      opt.master_batch_size = -1
    '''
    debug 默认值为 0
    该 if 语句没有执行
    debug 参数含义:可视化的水平
                1:只显示最终检测结果
                2:显示网络输出特征
                3:使用 matplot 显示
                4:将所有可视化内容保存到磁盘
    '''

    opt.root_dir = os.path.join(os.path.dirname(__file__), '..', '..')
    opt.data_dir = os.path.join(opt.root_dir, 'data')
    opt.exp_dir = os.path.join(opt.root_dir, 'exp', opt.task)
    opt.save_dir = os.path.join(opt.exp_dir, opt.exp_id)
    opt.debug_dir = os.path.join(opt.save_dir, 'debug')
    '''
    添加路径参数,并设置 log 路径
    root_dir  = ~/CenterFusion/src/lib/../..
    data_dir  = ~/CenterFusion/src/lib/../../data
    exp_dir   = ~/CenterFusion/src/lib/../../exp/ddd
    save_dir  = ~/CenterFusion/src/lib/../../exp/ddd/centerfusion
    debug_dir = ~/CenterFusion/src/lib/../../exp/ddd/centerfusion/debug
    '''
    
    if opt.resume and opt.load_model == '':
      opt.load_model = os.path.join(opt.save_dir, 'model_last.pth')
    '''
    resume 在 train.sh 没有添加该参数,所以为 False
    load_model 默认值为 '' ,在 train.sh 中设置了该参数的值
    所以该 if 语句没有执行
    参数含义:
      resume:重新加载优化器参数,并在 load_model 为空时将 load_model 设置为 model_last.pth
      load_model:预训练模型的路径
    '''

    opt.pc_atts = ['x', 'y', 'z', 'dyn_prop', 'id', 'rcs', 'vx', 'vy', 
                    'vx_comp', 'vy_comp', 'is_quality_valid', 
                    'ambig_state', 'x_rms', 'y_rms', 'invalid_state', 
                    'pdh0', 'vx_rms', 'vy_rms']
    '''
    添加新参数 pc_atts,并设置雷达点云相关属性
    '''

    pc_attr_ind = {x:i for i,x in enumerate(opt.pc_atts)}
    '''
    enumerate():函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中
    这里是对 opt.pc_atts 中的属性递增赋值
    结果 = {'x': 0, 'y': 1, 'z': 2, 'dyn_prop': 3, 'id': 4, 'rcs': 5,......}
    '''

    opt.pillar_dims = [float(i) for i in opt.pillar_dims.split(',')]
    '''
    pillar_dims 默认值为 2.0,0.5,0.5 在 train.sh 中的值为 1.0,0.2,0.2
    这里是将字符串处理成浮点型数组
    pillar_dims 参数含义:雷达柱尺寸(h、w、l)
    '''

    opt.num_img_channels = 3
    opt.hm_dist_thresh = None
    opt.sigmoid_dep_sec = False
    opt.hm_to_box_ratio = 0.3
    opt.secondary_heads = []
    opt.custom_head_convs = {}
    opt.normalize_depth = False
    opt.disable_frustum = False
    opt.layers_to_freeze = [
      'base', 
      'dla_up',
      'ida_up',
      # 'hm'
      # 'reg'
      # 'wh'
      # 'dep'
      # 'rot'
      # 'dim'
      # 'amodel_offset'
      # 'dep_sec'
      # 'nuscenes_att'
      # 'velocity'
    ]
    '''
    添加一些新参数,并设值
    '''
  
    if opt.pointcloud:
      '''
      在 train.sh 中添加了该参数,所以 pointcloud = True
      该 if 语句中添加一些新参数并设值
      '''
      extra_pc_feats = []
      opt.pc_roi_method = "pillars"
      opt.pillar_dims = [1.5,0.2,0.2]
      opt.pc_feat_lvl = [
        'pc_dep',
        'pc_vx',
        'pc_vz',
      ]
      opt.frustumExpansionRatio = 0.0
      opt.disable_frustum = False
      opt.sort_det_by_dist = False
      opt.sigmoid_dep_sec = True
      opt.normalize_depth = True
      opt.secondary_heads = ['velocity', 'nuscenes_att', 'dep_sec', 'rot_sec']
      opt.hm_dist_thresh = {
        'car': 0, 
        'truck': 0,
        'bus': 0,
        'trailer': 0, 
        'construction_vehicle': 0, 
        'pedestrian': 1,
        'motorcycle': 1,
        'bicycle': 1, 
        'traffic_cone': 0, 
        'barrier': 0
      }
      opt.custom_head_convs = {
        'dep_sec': 3,
        'rot_sec': 3,
        'velocity': 3,
        'nuscenes_att': 3,
      }
      opt.pc_feat_channels = {feat: i for i,feat in enumerate(opt.pc_feat_lvl)}

    CATS = ['car', 'truck', 'bus', 'trailer', 'construction_vehicle', 
        'pedestrian', 'motorcycle', 'bicycle', 'traffic_cone', 'barrier']
    CAT_IDS = {v: i for i, v in enumerate(CATS)}
    '''
    设置目标物体类别及 id
    '''
    
    if opt.hm_dist_thresh is not None:
      temp = {}
      for (k,v) in opt.hm_dist_thresh.items():
        temp[CAT_IDS[k]] = v
      opt.hm_dist_thresh = temp
    '''
    items():将一个字典以列表的形式返回
    '''
    
    return opt


  def update_dataset_info_and_set_heads(self, opt, dataset):
    '''
    opt 为 opts 类,被定义在本文件中
    dataset 为 nuScenes 类,被定义在 CenterFusion-master/src/lib/dataset/datasets/nuscenes.py 中
    '''

    opt.num_classes = dataset.num_categories \
                      if opt.num_classes < 0 else opt.num_classes
    '''
    dataset.num_categories 默认值为 10 ,num_classes 值为 10
    这里是为了获取目标物体种类数
    参数含义:
        num_categories :在 nuScenes 中目标物体种类数
        num_classes :目标物体类别数
    '''

    input_h, input_w = dataset.default_resolution
    input_h = opt.input_res if opt.input_res > 0 else input_h
    input_w = opt.input_res if opt.input_res > 0 else input_w
    opt.input_h = opt.input_h if opt.input_h > 0 else input_h
    opt.input_w = opt.input_w if opt.input_w > 0 else input_w
    '''
    默认值 default_rerolution = [448,800],opt.input_res = opt.input_h = opt.input_w = -1
    获取 nuScenes 中的图片像素值
    结果:
        opt.input_h = 448
        opt.input_w = 800
    '''

    opt.output_h = opt.input_h // opt.down_ratio
    opt.output_w = opt.input_w // opt.down_ratio
    '''
    down_ratio 默认值为 4
    结果
        opt.output_h = 112
        opt.output_w = 200
    '''

    opt.input_res = max(opt.input_h, opt.input_w)
    opt.output_res = max(opt.output_h, opt.output_w)
    '''
    结果
        opt.input_res = 800
        opt.output_res = 200
    '''
  
    opt.heads = {'hm': opt.num_classes, 'reg': 2, 'wh': 2}
    if 'tracking' in opt.task:
      opt.heads.update({'tracking': 2})
    if 'ddd' in opt.task:
      opt.heads.update({'dep': 1, 'rot': 8, 'dim': 3, 'amodel_offset': 2})
    if opt.pointcloud:
      opt.heads.update({'dep_sec': 1})
      opt.heads.update({'rot_sec': 8})
    if 'multi_pose' in opt.task:
      opt.heads.update({
        'hps': dataset.num_joints * 2, 'hm_hp': dataset.num_joints,
        'hp_offset': 2})
    if opt.ltrb:
      opt.heads.update({'ltrb': 4})
    if opt.ltrb_amodal:
      opt.heads.update({'ltrb_amodal': 4})
    if opt.nuscenes_att:
      opt.heads.update({'nuscenes_att': 8})
    if opt.velocity:
      opt.heads.update({'velocity': 3})
    '''
    设置 opt.heads 中的内容
    '''

    weight_dict = {'hm': opt.hm_weight, 'wh': opt.wh_weight,
                   'reg': opt.off_weight, 'hps': opt.hp_weight,
                   'hm_hp': opt.hm_hp_weight, 'hp_offset': opt.off_weight,
                   'dep': opt.dep_weight, 'dep_res': opt.dep_res_weight,
                   'rot': opt.rot_weight, 'dep_sec': opt.dep_weight,
                   'dim': opt.dim_weight, 'rot_sec': opt.rot_weight,
                   'amodel_offset': opt.amodel_offset_weight,
                   'ltrb': opt.ltrb_weight,
                   'tracking': opt.tracking_weight,
                   'ltrb_amodal': opt.ltrb_amodal_weight,
                   'nuscenes_att': opt.nuscenes_att_weight,
                   'velocity': opt.velocity_weight}
    '''
    设置 weight_dict 中的内容
    '''

    opt.weights = {head: weight_dict[head] for head in opt.heads}
    '''
    根据 opt.heads 和 weight_dict 设置 opt.weights 中的内容
    '''
      
    for head in opt.weights:
      if opt.weights[head] == 0:
        del opt.heads[head]
    '''
    遍历 weights 并删除值为 0 的属性
    '''
    
    temp_head_conv = opt.head_conv
    opt.head_conv = {head: [opt.head_conv for i in range(opt.num_head_conv if head != 'reg' else 1)] for head in opt.heads}
    '''
    设置 opt.head_conv 中的内容
    '''
    
    if opt.pointcloud:
      temp = {k: [temp_head_conv for i in range(v)] for k,v in opt.custom_head_convs.items()}
      opt.head_conv.update(temp)
    '''
    更新自定义头部变换
    '''
    
    #print('input h w:', opt.input_h, opt.input_w)
    #print('heads', opt.heads)
    #print('weights', opt.weights)
    #print('head conv', opt.head_conv)

    return opt

  def init(self, args=''):
    # only used in demo
    default_dataset_info = {
      'ctdet': 'coco', 'multi_pose': 'coco_hp', 'ddd': 'nuscenes',
      'tracking,ctdet': 'coco', 'tracking,multi_pose': 'coco_hp', 
      'tracking,ddd': 'nuscenes'
    }
    opt = self.parse()
    from dataset.dataset_factory import dataset_factory
    train_dataset = default_dataset_info[opt.task] \
      if opt.task in default_dataset_info else 'coco'
    dataset = dataset_factory[train_dataset]
    opt = self.update_dataset_info_and_set_heads(opt, dataset)
    return opt

你可能感兴趣的:(深度学习,计算机视觉,python)