PV-RCNN代码解读——train & test

PV-RCNN:paper,code

1. train

tools/train.py中找到以下表示开始训练的代码

    # -----------------------start training---------------------------
    logger.info('**********************Start training %s/%s(%s)**********************'
                % (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))
    train_model(
        model, optimizer, train_loader, model_func=model_fn_decorator(),
        lr_scheduler=lr_scheduler, optim_cfg=cfg.OPTIMIZATION,
        start_epoch=start_epoch, total_epochs=args.epochs,
        start_iter=it, rank=cfg.LOCAL_RANK, tb_log=tb_log, 	
        ckpt_save_dir=ckpt_dir, train_sampler=train_sampler,
        lr_warmup_scheduler=lr_warmup_scheduler,
        ckpt_save_interval=args.ckpt_save_interval,
        max_ckpt_save_num=args.max_ckpt_save_num,
        merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch
    )

train_model定位在头文件中的tools/train_utils/train_utils.py
其关键信息的框架为

for key in epochs:
	train_one_epoch #训练一个epoch
save_trained_model #储存训练好的模型

找到train_model中的train_one_epoch()

accumulated_iter = train_one_epoch(
                model, optimizer, train_loader, model_func,
                lr_scheduler=cur_scheduler,
                accumulated_iter=accumulated_iter, optim_cfg=optim_cfg,
                rank=rank, tbar=tbar, tb_log=tb_log,
                leave_pbar=(cur_epoch + 1 == total_epochs),
                total_it_each_epoch=total_it_each_epoch,
                dataloader_iter=dataloader_iter
            )

train_one_epoch函数的定义在同一py文件中,其关键信息为

def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, accumulated_iter, optim_cfg,
                    rank, tbar, total_it_each_epoch, dataloader_iter, tb_log=None, leave_pbar=False): 
        #找到表示训练和梯度优化等的关键函数
        model.train() #一个固定语句
        optimizer.zero_grad() #梯度清零
        loss, tb_dict, disp_dict = model_func(model, batch) #求loss
        loss.backward() #反向传播
        clip_grad_norm_(model.parameters(), optim_cfg.GRAD_NORM_CLIP) #梯度裁剪
        optimizer.step() #更新

用到了几个Pytorch自带的函数:
model.train()
在训练模型时都会在前面加上model.train()
在测试模型时都会在前面加上model.eval()
如果不写这两个程序也可以运行,这两个方法是针对在训练和测试时采用不同方式的情况,比如Batch NormalizationDropout。详细介绍。

clip_grad_norm_()
功能是梯度裁剪。即为了防止梯度爆炸,当梯度超过阈值optim_cfg时将其设置为阈值。详细介绍。

optimizer.step()
功能是根据网络反向传播的梯度信息,更新网络的参数,以降低loss。详细介绍。

训练结束之后info

    logger.info('**********************End training %s/%s(%s)**********************\n\n\n'
                % (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))

2. test

接下来程序会自动执行测试过程

    logger.info('**********************Start evaluation %s/%s(%s)**********************' %
                (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))
	test_set, test_loader, sampler = build_dataloader(
        dataset_cfg=cfg.DATA_CONFIG,
        class_names=cfg.CLASS_NAMES,
        batch_size=args.batch_size,
        dist=dist_train, workers=args.workers, logger=logger, training=False
    ) # test_loader为一个batch_size大小的Tensor
    
    eval_output_dir = output_dir / 'eval' / 'eval_with_train' #输出路径
    eval_output_dir.mkdir(parents=True, exist_ok=True)
    args.start_epoch = max(args.epochs - 10, 0)  # Only evaluate the last 10 epochs
    # args.start_epoch = max(args.epochs - 1, 0)  # Only evaluate the last epoch

    repeat_eval_ckpt( #测试训练好的模型
        model.module if dist_train else model,
        test_loader, args, eval_output_dir, logger, ckpt_dir,
        dist_test=dist_train
    )

built_dataloader()
dataloaderPyTorch中数据读取的一个重要接口,功能是将自定义的Dataset封装成一个batch_size大小的Tensor,用于后面的训练。

repeat_eval_ckpt()
tools/test中定义的函数,其关键部分为

def repeat_eval_ckpt(model, test_loader, args, eval_output_dir, logger, ckpt_dir, dist_test=False):
		# 加载数据
        model.load_params_from_file(filename=cur_ckpt, logger=logger, to_cpu=dist_test)
        model.cuda()
        # 测试结果的储存路径
        cur_result_dir = eval_output_dir / ('epoch_%s' % cur_epoch_id) / cfg.DATA_CONFIG.DATA_SPLIT['test'] 
        # 测试一个epoch
        tb_dict = eval_utils.eval_one_epoch( 
            cfg, model, test_loader, cur_epoch_id, logger, dist_test=dist_test,
            result_dir=cur_result_dir, save_to_file=args.save_to_file
        )
        # 储存这个测试过的epoch
        with open(ckpt_record_file, 'a') as f:
            print('%s' % cur_epoch_id, file=f)
        logger.info('Epoch %s has been evaluated' % cur_epoch_id)

eval_one_epoch()的定义在tools/eval_utils/eval_utils.py中,其关键部分为

def eval_one_epoch(cfg, model, dataloader, epoch_id, logger, dist_test=False, save_to_file=False, result_dir=None):

	# 储存路径
    result_dir.mkdir(parents=True, exist_ok=True)
    final_output_dir = result_dir / 'final_result' / 'data'
    if save_to_file:
        final_output_dir.mkdir(parents=True, exist_ok=True)
        
	# 开始测试
    logger.info('*************** EPOCH %s EVALUATION *****************' % epoch_id)
 
    '''
    这里省略了把资源分配给多个GPU的代码,详见源码
    '''
    
	# 对于dataloader中的每一个batch_dict产生预测值和反馈值
    start_time = time.time()
    for i, batch_dict in enumerate(dataloader):
        load_data_to_gpu(batch_dict)
        with torch.no_grad():
            pred_dicts, ret_dict = model(batch_dict)
        disp_dict = {}
        
		#产生预测值,对于kitti数据引用的函数定义在pcdet/datasets/kitti/kitti_dataset.py
        statistics_info(cfg, ret_dict, metric, disp_dict)
        annos = dataset.generate_prediction_dicts(
            batch_dict, pred_dicts, class_names,
            output_path=final_output_dir if save_to_file else None
        )
        det_annos += annos
        
        #显示进度条
        if cfg.LOCAL_RANK == 0:
            progress_bar.set_postfix(disp_dict)
            progress_bar.update()

    if cfg.LOCAL_RANK == 0:
        progress_bar.close()
        
    '''
    这里省略了显示进度条&打印测试结果的语句,详见源码
    '''
	
	# 得到测试结果,evaluation()函数我会详细写一个博客写
    result_str, result_dict = dataset.evaluation( 
        det_annos, class_names,
        eval_metric=cfg.MODEL.POST_PROCESSING.EVAL_METRIC,
        output_path=final_output_dir
    ) 

	# 更新结果,打印日志,测试结束
    logger.info(result_str)
    ret_dict.update(result_dict)
    logger.info('Result is save to %s' % result_dir)
    logger.info('****************Evaluation done.*****************')
    return ret_dict

evaluation函数详解

我的其他PV-RCNN代码解读系列文章,如果对你有帮助的话,请给我点赞哦~
PV-RCNN代码解读——输入参数介绍
PV-RCNN代码解读——TP,FP,TN,FN的计算
PV-RCNN代码解读——eval.py
PV-RCNN代码解读——计算iou
PV-RCNN代码解读——数据初始化
PV-RCNN代码解读——从点云到输入神经网络的数据处理

你可能感兴趣的:(PV-RCNN,深度学习,python,神经网络,pytorch,机器学习)