姿态估计2-05:PVNet(6D姿态估计)--源码无死角解析(1)-训练代码总览

以下链接是个人关于PVNet(6D姿态估计) 所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:a944284742相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。
姿态估计2-00:PVNet(6D姿态估计)-目录-史上最新无死角讲解

train_net.py注释

下面是对train_net.py文件的注释,该代码十分的简单,所以注释也十分简洁:

from lib.config import cfg, args
from lib.networks import make_network
from lib.train import make_trainer, make_optimizer, make_lr_scheduler, make_recorder, set_lr_scheduler
from lib.datasets import make_data_loader
from lib.utils.net_utils import load_model, save_model, load_network
from lib.evaluators import make_evaluator
import torch.multiprocessing


def train(cfg, network):
    # 如果训练数据为City,这进行文件系统共享
    if cfg.train.dataset[:4] != 'City':
        torch.multiprocessing.set_sharing_strategy('file_system')
    # 制作训练器
    trainer = make_trainer(cfg, network)
    # 制作优化器
    optimizer = make_optimizer(cfg, network)
    # 制作学习率调整器
    scheduler = make_lr_scheduler(cfg, optimizer)
    #  用于记录信息
    recorder = make_recorder(cfg)
    # 用于评估
    evaluator = make_evaluator(cfg)

    # 进行模型加载
    begin_epoch = load_model(network, optimizer, scheduler, recorder, cfg.model_dir, resume=cfg.resume)
    # set_lr_scheduler(cfg, scheduler)

    # 创建训练以及评估数据集
    train_loader = make_data_loader(cfg, is_train=True, max_iter=cfg.ep_iter)
    val_loader = make_data_loader(cfg, is_train=False)
    # train_loader = make_data_loader(cfg, is_train=True, max_iter=100)

    # 循环进行迭代训练
    for epoch in range(begin_epoch, cfg.train.epoch):
        recorder.epoch = epoch
        # 进行一个epoch的迭代训练
        trainer.train(epoch, train_loader, optimizer, recorder)
        # 记录学习了一个epoch,并且根据预设定的参数,看是否需要对学习率进行更改
        scheduler.step()
        # 迭代到指定次数,保存好训练的
        if (epoch + 1) % cfg.save_ep == 0:
            save_model(network, optimizer, scheduler, recorder, epoch, cfg.model_dir)
        # 迭代到指定次数,进行评估训练
        if (epoch + 1) % cfg.eval_ep == 0:
            trainer.val(epoch, val_loader, evaluator, recorder)

    return network


def test(cfg, network):
    # 根据配置创建训练器
    trainer = make_trainer(cfg, network)
    # 创建数据迭代器
    val_loader = make_data_loader(cfg, is_train=False)
    # 创建评估器
    evaluator = make_evaluator(cfg)
    # 加载权重
    epoch = load_network(network, cfg.model_dir, resume=cfg.resume, epoch=cfg.test.epoch)
    # 进行评估
    trainer.val(epoch, val_loader, evaluator)


def main():
    # 根据配置参数,构建网路
    network = make_network(cfg)
    # 根据传入的参数选择测试或者训练
    if args.test:
        test(cfg, network)
    else:
        train(cfg, network)


if __name__ == "__main__":
    main()

总结

训练代码的套路基本都是差不多的,基本就是
1.解析参数
2.构建网络模型
3.加载训练测试数据集迭代器
4.迭代训练
5.模型评估保存

你可能感兴趣的:(姿态估计,CVPR2019,pytorch,PVNet)