HRNet源码阅读笔记(2),train的命令行main函数

def main():
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
        cfg, is_train=True
    )

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)
    # logger.info(pprint.pformat(model))

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    dump_input = torch.rand(
        (1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0])
    )
    writer_dict['writer'].add_graph(model, (dump_input, ))

    logger.info(get_model_summary(model, dump_input))

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    # define loss function (criterion) and optimizer
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT
    ).cuda()

    # Data loading code
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    )
    valid_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
        shuffle=cfg.TRAIN.SHUFFLE,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY
    )

    best_perf = 0.0
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(
        final_output_dir, 'checkpoint.pth'
    )

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, cfg.TRAIN.LR_STEP, cfg.TRAIN.LR_FACTOR,
        last_epoch=last_epoch
    )

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        lr_scheduler.step()

        # train for one epoch
        train(cfg, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)


        # evaluate on validation set
        perf_indicator = validate(
            cfg, valid_loader, valid_dataset, model, criterion,
            final_output_dir, tb_log_dir, writer_dict
        )

        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint({
            'epoch': epoch + 1,
            'model': cfg.MODEL.NAME,
            'state_dict': model.state_dict(),
            'best_state_dict': model.module.state_dict(),
            'perf': perf_indicator,
            'optimizer': optimizer.state_dict(),
        }, best_model, final_output_dir)

    final_model_state_file = os.path.join(
        final_output_dir, 'final_state.pth'
    )
    logger.info('=> saving final model state to {}'.format(
        final_model_state_file)
    )
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()

上面第2行,就是关于命令行的,没啥可说的。

1、关于update_config

上面第3行,update_config函数,溯源有点麻烦。在train.py中有一句话:

from config import update_config

那我们就看config文件夹,(因为config相当于一个类?),这个文件夹下,有个__init__.py文件,显然这个文件相当于c++中的构造函数,默认首先执行一遍。

这个__init__.py文件如下:

from .default import _C as cfg
from .default import update_config
from .models import MODEL_EXTRAS

可见,第2行告诉我们,update_config函数,需要到default.py文件中去找。

下面看看default.py文件,关键代码如下:

。。。。。。。
def update_config(cfg, args):
    cfg.defrost()
    cfg.merge_from_file(args.cfg)
    cfg.merge_from_list(args.opts)

    if args.modelDir:
        cfg.OUTPUT_DIR = args.modelDir

    if args.logDir:
        cfg.LOG_DIR = args.logDir

    if args.dataDir:
        cfg.DATA_DIR = args.dataDir

    cfg.DATASET.ROOT = os.path.join(
        cfg.DATA_DIR, cfg.DATASET.ROOT
    )

    cfg.MODEL.PRETRAINED = os.path.join(
        cfg.DATA_DIR, cfg.MODEL.PRETRAINED
    )

    if cfg.TEST.MODEL_FILE:
        cfg.TEST.MODEL_FILE = os.path.join(
            cfg.DATA_DIR, cfg.TEST.MODEL_FILE
        )

    cfg.freeze()
。。。。。。。

这里面提到的cfg,就是experiments/coco/hrnet/w32_256x192_adam_lr1e-3.yaml

上面代码的第8行,怎么解释?

在w32_256x192_adam_lr1e-3.yaml文件里,你搜索一下OUTPUT_DIR,就是下面的第9行,

。。。。。。。。。。。。
AUTO_RESUME: true
CUDNN:
  BENCHMARK: true
  DETERMINISTIC: false
  ENABLED: true
DATA_DIR: ''
GPUS: (0,1,2,3)
OUTPUT_DIR: 'output'
LOG_DIR: 'log'
WORKERS: 24
PRINT_FREQ: 100

DATASET:
  COLOR_RGB: true
  DATASET: 'coco'
  DATA_FORMAT: jpg
  FLIP: true
  NUM_JOINTS_HALF_BODY: 8
  PROB_HALF_BODY: 0.3
  ROOT: 'data/coco/'
  ROT_FACTOR: 45
  SCALE_FACTOR: 0.35
  TEST_SET: 'val2017'
  TRAIN_SET: 'train2017'
MODEL:
。。。。。。。。。。。。。

可见,default.py中的

 if args.modelDir:
        cfg.OUTPUT_DIR = args.modelDir

说的是,如果用户在命令行中,指定了modelDir,那么cfg.OUTPUT_DIR就按指定的来,否则,按照default.py文件中的OUTPUT_DIR作为默认值

2、关于create_logger

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

注意,在train.py中,有一句话:

from utils.utils import create_logger

所以,我就去找这个utils文件夹下的utils文件,其中有相关代码如下:

def create_logger(cfg, cfg_name, phase='train'):
    root_output_dir = Path(cfg.OUTPUT_DIR)
    # set up logger
    if not root_output_dir.exists():
        print('=> creating {}'.format(root_output_dir))
        root_output_dir.mkdir()

    dataset = cfg.DATASET.DATASET + '_' + cfg.DATASET.HYBRID_JOINTS_TYPE \
        if cfg.DATASET.HYBRID_JOINTS_TYPE else cfg.DATASET.DATASET
    dataset = dataset.replace(':', '_')
    model = cfg.MODEL.NAME
    cfg_name = os.path.basename(cfg_name).split('.')[0]

    final_output_dir = root_output_dir / dataset / model / cfg_name

    print('=> creating {}'.format(final_output_dir))
    final_output_dir.mkdir(parents=True, exist_ok=True)

    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}_{}.log'.format(cfg_name, time_str, phase)
    final_log_file = final_output_dir / log_file
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(filename=str(final_log_file),
                        format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    console = logging.StreamHandler()
    logging.getLogger('').addHandler(console)

    tensorboard_log_dir = Path(cfg.LOG_DIR) / dataset / model / \
        (cfg_name + '_' + time_str)

    print('=> creating {}'.format(tensorboard_log_dir))
    tensorboard_log_dir.mkdir(parents=True, exist_ok=True)

    return logger, str(final_output_dir), str(tensorboard_log_dir)

这里面用到的那么多变量,还是要到前面说的命令行:

python tools/train.py \
    --cfg experiments/coco/hrnet/w32_256x192_adam_lr1e-3.yaml \

w32_256x192_adam_lr1e-3.yaml文件中去查找。

其实,这个create_logger,就是关于程序执行的制作log的过程。

3、关于cuda相关设定

在上面main函数的12、13、14行

cudnn.benchmark = cfg.CUDNN.BENCHMARK
torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

你还是需要在w32_256x192_adam_lr1e-3.yaml文件中去查找。

(1)设置 torch.backends.cudnn.benchmark=True 将会让程序在开始时花费一点额外时间,为整个网络的每个卷积层搜索最适合它的卷积实现算法,进而实现网络的加速。适用场景是网络结构固定(不是动态变化的),网络的输入形状(包括 batch size,图片大小,输入的通道)是不变的,其实也就是一般情况下都比较适用。反之,如果卷积层的设置一直变化,将会导致程序不停地做优化,反而会耗费更多的时间。

(2)设置cudnn.deterministic,说的是是否需要确定性,因为前馈会有些许的随机性。

(3)cudnn.enabled不解释了。

4、eval函数

上面main函数的16、17行,用到了eval函数:

 model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
        cfg, is_train=True

eval()是python中功能非常强大的一个函数

将字符串当成有效的表达式来求值,并返回计算结果

所谓表达式就是:eval这个函数会把里面的字符串参数的引号去掉,把中间的内容当成Python的代码,eval函数会执行这段代码并且返回执行结果

例如:

1 基本的数学运算

# 1. 基本的数学运算result=eval("1 + 1")print(result)# 2

2 字符串重复

# 2. 字符串重复
result = eval("'+' * 5")
print(result)  # +++++

3 将字符串转换成列表

# 3. 将字符串转换成列表
result = type(eval("[1, 2, 3, 4]"))
print(result)  # 

4 将字符串转换成字典

result = type(eval("{'name': '小夏', 'age': 30}"))
print(result)  # 

那么,这里的eval('models.'+cfg.MODEL.NAME+'.get_pose_net'),会是什么东西呢?

答案:models.pose_hrnet.get_pose_net

也就是,去models文件夹下,找pose_hrnet.py文件,里面有个get_pose_net函数如下:

def get_pose_net(cfg, is_train, **kwargs):
    model = PoseHighResolutionNet(cfg, **kwargs)

    if is_train and cfg['MODEL']['INIT_WEIGHTS']:
        model.init_weights(cfg['MODEL']['PRETRAINED'])

    return model

这个get_pose_net函数,干了两件事

首先, model = PoseHighResolutionNet(cfg, **kwargs),执行了庞大的PoseHighResolutionNet函数,并且把cfg传递进去了。

其次,看w32_256x192_adam_lr1e-3.yaml文件中,['MODEL']['INIT_WEIGHTS']确实是true,那么,就执行model.init_weights(cfg['MODEL']['PRETRAINED'])

也就是:PRETRAINED: 'models/pytorch/imagenet/hrnet_w32-36af842e.pth'预训练模型。

那么,下一篇,就从庞大的PoseHighResolutionNet函数,开始讲起吧!

你可能感兴趣的:(人工智能之HRNet,python)