文件作用:CenterFusion 项目验证的执行过程
if __name__ == '__main__':
opt = opts().parse()
'''
调用 opts.py 中的 parse() 函数.
在 CenterFusion/src/lib/opts.py 第 305 行
'''
if opt.not_prefetch_test: # 如果opt不能预取test则进行 test(opt)
test(opt)
else:
prefetch_test(opt) # 如果可以取,则进行prefetch_test(opt)
'''
由于在 test.sh 中并没有添加参数 not_prefetch_test,所以 opt.not_prefetch_test = False
最后只执行了 else 中的语句,调用 prefetch_test() 函数.
'''
def prefetch_test(opt):
if not opt.not_set_cuda_env:
os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str
'''
由于在 test.sh 中没有添加参数 not_set_cuda_env,所以 opt.not_set_cuda_env = False,再 not 取反
执行该 if 语句.
gpus_str 是 opt 中的一个 GPU 索引号字符串,如:'0,1'
这里是为了给系统添加 cuda 索引号.
'''
Dataset = dataset_factory[opt.test_dataset]
'''
设置数据集对象 nuScenes,Dataset 是一个 nuScenes 类.
dataset_factory 在 CenterFusion/src/lib/dataset/dataset_factory.py 第 20 行
nuScenes对象定义在 CenterFusion/src/lib/dataset/datasets/nuscenes.py 中.
'''
opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
'''
设置一些配置信息.
update_dataset_info_and_set_heads() 函数在 opys.py 第 458 行
'''
print(opt)
Logger(opt)
'''
新建一个 Logger 对象,记录配置信息.
Logger 对象在 CenterFusion/src/lib/logger.py
'''
split = 'val' if not opt.trainval else 'test'
'''
在 test.sh 中没有添加参数 trainval,所以 opt.trainval = False,再 not 取反,为 True
所以 split = 'val'
'''
if split == 'val':
split = opt.val_split
'''
val_split 在 opts.py 中的默认值为 'val',但在 test.sh 中添加了该参数,并赋值为 'mini_val'
所以 split 的值为 'mini_val'
'''
dataset = Dataset(opt, split)
'''
传递参数 opt(配置信息)、split(数据集名称)
'''
detector = Detector(opt)
'''
Detector 类定义在 CenterFusion/src/lib/detector.py 中
'''
if opt.load_results != '':
load_results = json.load(open(opt.load_results, 'r'))
for img_id in load_results:
for k in range(len(load_results[img_id])):
if load_results[img_id][k]['class'] - 1 in opt.ignore_loaded_cats:
load_results[img_id][k]['score'] = -1
else:
load_results = {}
'''
load_results 默认值为 ''
所以执行了 else 语句
'''
data_loader = torch.utils.data.DataLoader(
PrefetchDataset(opt, dataset, detector.pre_process),
batch_size=1, shuffle=False, num_workers=1, pin_memory=True)
'''
torch.utils.data.DataLoader 是一个数据读取的一个接口,参数:
PrefetchDataset(opt, dataset, detector.pre_process):加载数据的数据集
batch_size (int, optional):每个 batch 加载多少个样本(默认: 1)
shuffle (bool, optional):设置为 True 时会在每个 epoch 重新打乱数据(默认: False)
num_workers (int, optional):用多少个子进程加载数据。0 表示数据将在主进程中加载(默认: 0)
pin_memory (bool, optional):设置 pin_memory=True,则意味着生成的 Tensor 数据最开始是属于内存中的锁页内存,
这样将内存的 Tensor 转义到 GPU 的显存就会更快一些
PrefetchDataset 类,在上面,这个类继承了 torch.utils.data.Dataset 类,表示自定义了数据读取方式
最后返回一个列表给 data_loader,其中有图片 id 以及 tensor 格式的图片数据
'''
results = {}
num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
'''
num_iters 默认值为 -1,所以 num_iters = data_loader 列表的长度
'''
bar = Bar('{}'.format(opt.exp_id), max=num_iters)
'''
定义了一个进度条,如:centerfusion |### | 3/10
进度条名称为:centerfusion
进度条最大值为:num_iters
'''
time_stats = ['tot', 'load', 'pre', 'net', 'dec', 'post', 'merge', 'track']
avg_time_stats = {t: AverageMeter() for t in time_stats}
'''
AverageMeter 类定义在 CenterFusion/src/lib/utils/utils.py 中
给 time_stats 中的每一个属性赋值为 AverageMeter 类
'''
if opt.use_loaded_results:
for img_id in data_loader.dataset.images:
results[img_id] = load_results['{}'.format(img_id)]
num_iters = 0
'''
在 test.sh 中没有添加参数 use_loaded_results ,所以值为 False
没有执行该 if 语句
'''
for ind, (img_id, pre_processed_images) in enumerate(data_loader):
'''
enumerate() 函数:用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列
同时列出数据和数据下标
其中 ind 表示对 data_loader 列表位置的计数
img_id 为图片数据的 id
pre_processed_images 为 tensor 格式的图片数据
'''
if ind >= num_iters:
break
'''
如果遍历完 data_loader 就结束.
'''
if opt.tracking and ('is_first_frame' in pre_processed_images):
if '{}'.format(int(img_id.numpy().astype(np.int32)[0])) in load_results:
pre_processed_images['meta']['pre_dets'] = \
load_results['{}'.format(int(img_id.numpy().astype(np.int32)[0]))]
else:
print()
print('No pre_dets for', int(img_id.numpy().astype(np.int32)[0]),
'. Use empty initialization.')
pre_processed_images['meta']['pre_dets'] = []
detector.reset_tracking()
print('Start tracking video', int(pre_processed_images['video_id']))
'''
由于在 test.sh 中没有添加参数 tracking,所以 opt.tracking = False
没有执行该 if 语句
'''
if opt.public_det:
if '{}'.format(int(img_id.numpy().astype(np.int32)[0])) in load_results:
pre_processed_images['meta']['cur_dets'] = \
load_results['{}'.format(int(img_id.numpy().astype(np.int32)[0]))]
else:
print('No cur_dets for', int(img_id.numpy().astype(np.int32)[0]))
pre_processed_images['meta']['cur_dets'] = []
'''
由于在 test.sh 中没有添加参数 public_det,所以 opt.public_det = False
没有执行该 if 语句
'''
ret = detector.run(pre_processed_images)
'''
run() 函数在 CenterFusion/src/lib/detector.py 第 56 行
对 tensor 格式的图片数据进行检测,并返回检测结果
'''
results[int(img_id.numpy().astype(np.int32)[0])] = ret['results']
'''
其中 img_id.numpy().astype(np.int32) 是将 img_id 强制转换成 int32 型的数据
这里是为了记录对应图片数据的检测结果,results[图片的索引号] = ret['result']
'''
Bar.suffix = '[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
ind, num_iters, total=bar.elapsed_td, eta=bar.eta_td)
'''
给 bar 添加一些显示字符串
显示 ind、num_iters、bar.elapsed_td、bar.eta_td 的值
'''
for t in avg_time_stats:
avg_time_stats[t].update(ret[t])
Bar.suffix = Bar.suffix + '|{} {tm.val:.3f}s ({tm.avg:.3f}s) '.format(
t, tm = avg_time_stats[t])
'''
update() 函数在 CenterFusion/src/lib/utils/utils.py 第 18 行
计算 ret 中每个属性的平均值和当前值,并将其添加后 bar 的后面显示在屏幕上
'''
if opt.print_iter > 0:
if ind % opt.print_iter == 0:
print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
else:
bar.next()
'''
print_iter 默认值为 0,所以执行 else 语句
打印进度条到屏幕上
'''
bar.finish()
'''
进度条完成
'''
if opt.save_results:
print('saving results to', opt.save_dir + '/save_results_{}{}.json'.format(
opt.test_dataset, opt.dataset_version))
json.dump(_to_list(copy.deepcopy(results)),
open(opt.save_dir + '/save_results_{}{}.json'.format(
opt.test_dataset, opt.dataset_version), 'w'))
'''
在 test.sh 中没有添加 save_results 参数,所以 opt.save_results = False
没有执行该 if 语句
'''
dataset.run_eval(results, opt.save_dir, n_plots=opt.eval_n_plots,
render_curves=opt.eval_render_curves)
'''
对结果进行检测评估,评估结果保存在 ~/CenterFusion/src/lib/../../exp/ddd/centerfusion/nuscenes_eval_det_output_mini_val 下
run_eval() 函数在 CenterFusion/src/lib/dataset/datasets/nuscenes.py 第 272 行
results :图片数据的检测结果
save_dir :保存路径为 ~/CenterFusion/src/lib/../../exp/ddd/centerfusion
eval_n_plots :默认值为 0
eval_render_curves :渲染和保存评价曲线,在 test.sh 中没有添加该参数,则为 False
'''
def run_eval(self, results, save_dir, n_plots=10, render_curves=False):
task = 'tracking' if self.opt.tracking else 'det'
'''
由于 test.sh 中没有添加参数 tracking,所以 opt.tracking 的值为 False
所以 task = 'det'
'''
split = self.opt.val_split
'''
split = 'mini_val'
'''
version = 'v1.0-mini' if 'mini' in split else 'v1.0-trainval'
#================================ 源代码 版本判断有Bug 修改位置如下:======================================
if version == 'v1.0-mini' :
flag=1
print("当前数据集选择的版本是:",version)
elif version == 'v1.0-test':
print("当前数据集选择的版本是:",version)
elif version == 'v1.0-trainval':
print("当前数据集选择的版本是:",version)
else:
print("您设置的数据集版本不存在,请重新输入!")
#===========================================================================================================
''' 目前存在的数据集名称列表有
split_names = {
'mini_train':'mini_train',
'mini_val':'mini_val',
'train': 'train',
'train_detect': 'train_detect',
'train_track':'train_track',
'val': 'val',
'test': 'test',
'mini_train_2': 'mini_train_2',
'trainval': 'trainval',
}
'''
'''
version = 'v1.0-mini'
'''
self.save_results(results, save_dir, task, split)
'''
保存结果为 json 文件
'''
render_curves = 1 if render_curves else 0
'''
render_curves = 0
'''
if task == 'det':
output_dir = '{}/nuscenes_eval_det_output_{}/'.format(save_dir, split)
'''
设置输出路径
'''
os.system('python ' + \
'tools/nuscenes-devkit/python-sdk/nuscenes/eval/detection/evaluate.py ' + \
'{}/results_nuscenes_{}_{}.json '.format(save_dir, task, split) + \
'--output_dir {} '.format(output_dir) + \
'--eval_set {} '.format(split) + \
'--dataroot ../data/nuscenes/ ' + \
'--version {} '.format(version) + \
'--plot_examples {} '.format(n_plots) + \
'--render_curves {} '.format(render_curves))
'''
执行官网 evaluate.py 文件
对结果进行检测评估,并输出到 output_dir 路径下.
'''
else:
output_dir = '{}/nuscenes_evaltracl__output/'.format(save_dir)
os.system('python ' + \
'tools/nuscenes-devkit/python-sdk/nuscenes/eval/tracking/evaluate.py ' + \
'{}/results_nuscenes_{}_{}.json '.format(save_dir, task, split) + \
'--output_dir {} '.format(output_dir) + \
'--dataroot ../data/nuscenes/')
os.system('python ' + \
'tools/nuscenes-devkit/python-sdk-alpha02/nuscenes/eval/tracking/evaluate.py ' + \
'{}/results_nuscenes_{}_{}.json '.format(save_dir, task, split) + \
'--output_dir {} '.format(output_dir) + \
'--dataroot ../data/nuscenes/')
return output_dir