pytorch图像分类框架搭建——训练流程的定义xxx_train.py

 训练脚本主要包含两个函数:parse_args()用来收集参数,train()定义了整个训练流程。

import argparse
import time
import tqdm
import logging
from pathlib import Path
from datetime import datetime
from einops import rearrange
import sys
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.models import resnext50_32x4d
from criterion import LSR
from tensorboardX import SummaryWriter
from utils.logger import setup_root_logger
from utils.collect_env import collect_env_info
from utils.dataset import ClassImageDataset
from utils.model_utils import split_weights
from utils.lr_scheduler import WarmUpLR


def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the original
            zero-argument call did.

    Returns:
        argparse.Namespace with attributes ``work_dir``, ``weights``,
        ``data_path``, ``from_scratch``, ``batch_size``, ``lr``, ``w``,
        ``epoches``, ``warm`` and ``nosave``.
    """
    parser = argparse.ArgumentParser(description='Train phase params')
    parser.add_argument('--work-dir', default='./exp', help='the dir to save logs and models')
    parser.add_argument('--weights', help='pretrained model path')
    parser.add_argument('--data-path', default='data', help='dataset path')
    # train() reads `args.from_scratch`. The flag used to be misspelled
    # `--form-scratch` (dest `form_scratch`), which made resuming crash; the old
    # spelling is kept as an alias so existing launch scripts keep working.
    parser.add_argument('--from-scratch', '--form-scratch', dest='from_scratch',
                        action='store_true', default=False,
                        help='training from epoch 1 (only model weights are loaded from --weights)')
    parser.add_argument('--batch-size', type=int, default=16, help='batch size for dataloader')
    parser.add_argument('-lr', type=float, default=1e-2, help='learning rate')
    parser.add_argument('-w', type=int, default=4, help='number of workers for dataloader')
    parser.add_argument('--epoches', type=int, default=200, help='training epoches')
    parser.add_argument('-warm', type=int, default=5, help='warm up phase')
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    return parser.parse_args(argv)

def _fmt_mmss(seconds):
    """Format a duration in seconds as ``MM:SS`` (minutes are not folded into hours)."""
    return '%02d:%02d' % (seconds // 60, seconds % 60)


def _load_checkpoint(path, model, optimizer, device, from_scratch, logger):
    """Load weights from ``path`` into ``model`` and optionally resume training state.

    Accepts either a checkpoint dict as written by ``train`` (keys ``model``,
    ``optimizer``, ``epoch``, ``acc``) or a bare model ``state_dict``.

    Args:
        path: checkpoint file passed via ``--weights``.
        model: model to receive the weights.
        optimizer: optimizer whose state is restored when resuming.
        device: ``map_location`` for ``torch.load``.
        from_scratch: if True only the model weights are loaded and training
            restarts at epoch 1.
        logger: logger for progress messages.

    Returns:
        ``(start_epoch, best_acc)`` to continue training from.

    Raises:
        RuntimeError: if the file cannot be loaded (original error is chained).
    """
    start_epoch, best_acc = 1, 0.0
    try:
        ckpt = torch.load(path, map_location=device)
    except Exception as err:
        raise RuntimeError("Try to resume from %s, but got errors!" % path) from err

    logger.info(f"Initializing model weights with {path}")
    # A checkpoint written by train() wraps the weights under 'model';
    # anything else is treated as a bare state_dict.
    is_full_ckpt = isinstance(ckpt, dict) and ckpt.get('model') is not None
    model.load_state_dict(ckpt['model'] if is_full_ckpt else ckpt)

    if from_scratch:
        return start_epoch, best_acc

    # resume optimizer / epoch / accuracy when the checkpoint carries them
    if is_full_ckpt and ckpt.get('optimizer') is not None:
        logger.info(f"Loading optimizer from {path}")
        optimizer.load_state_dict(ckpt['optimizer'])
    else:
        logger.info("No optimizer loaded")
    if is_full_ckpt and ckpt.get('epoch') is not None:
        start_epoch = ckpt['epoch'] + 1
    if is_full_ckpt and ckpt.get('acc') is not None:
        best_acc = ckpt['acc']
    return start_epoch, best_acc


def _evaluate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradients.

    Each batch is unpacked like the training batches: images are expected as
    ``(b, n, c, h, w)`` and labels as ``(b, n)`` (the einops patterns enforce
    this), and the ``n`` views are flattened into the batch dimension.

    Returns:
        ``(mean loss per batch, accuracy in percent over all evaluated images)``.
    """
    model.eval()
    total_loss, correct, seen = 0.0, 0.0, 0
    pbar = tqdm.tqdm(dataloader, desc='evaling', total=len(dataloader))
    with torch.no_grad():
        for imgs, _, _, y, _ in pbar:
            imgs = rearrange(imgs, 'b n c h w -> (b n) c h w')
            y = rearrange(y, 'b n  -> (b n)')
            imgs, y = imgs.to(device), y.to(device)

            pred = model(imgs)
            total_loss += criterion(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            seen += y.numel()

    # max(..., 1) keeps an empty test set from raising ZeroDivisionError
    mean_loss = total_loss / max(len(dataloader), 1)
    # accuracy is over images, not batches
    acc = correct / max(seen, 1) * 100
    return mean_loss, acc


def train():
    """Run the full training procedure configured by the command line.

    Steps: parse args, create the work dir and loggers (file + TensorBoard),
    build datasets/dataloaders, a ResNeXt-50 with 8 classes, the LSR loss, SGD
    with a warm-up scheduler followed by a MultiStepLR schedule, optionally
    load/resume from ``--weights``, then train and evaluate each epoch,
    writing checkpoints into ``--work-dir``.
    """
    args = parse_args()
    # The flag used to be misspelled `--form-scratch` (dest `form_scratch`);
    # accept either attribute so this works with old and fixed parsers.
    from_scratch = bool(getattr(args, 'from_scratch', False) or getattr(args, 'form_scratch', False))

    # directories
    work_dir = Path(args.work_dir)
    work_dir.mkdir(parents=True, exist_ok=True)
    last, best = 'last_epoch{:d}_acc{:0.2f}.pth', 'best_epoch{:d}_acc{:0.2f}.pth'

    # log
    tensorboard_log_path = work_dir / 'tensorboardLoggs'
    tensorboard_log_path.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(log_dir=str(tensorboard_log_path))
    setup_root_logger(work_dir, 0)

    logger = logging.getLogger('class_train')
    logger.info(args)
    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    # build dataset and dataloader
    train_dataset = ClassImageDataset(Path(args.data_path) / 'train')
    test_dataset = ClassImageDataset(Path(args.data_path) / 'test', augment=False)
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.w, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.w)

    # build model and criterion
    model = resnext50_32x4d(num_classes=8).to(DEVICE)
    lsr_loss = LSR()

    # apply no weight decay on bias
    params = split_weights(model)
    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9, weight_decay=1e-4, nesterov=True)

    # warm-up scheduler is stepped per iteration for the first `args.warm` epochs
    iter_per_epoch = len(train_dataloader)
    warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * args.warm)
    # training-phase scheduler, stepped once per epoch after warm-up
    train_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60, 90])

    # resume
    if args.weights is not None:
        start_epoch, best_acc = _load_checkpoint(
            args.weights, model, optimizer, DEVICE, from_scratch, logger)
    else:
        start_epoch, best_acc = 1, 0.0
        logger.info("No checkpoint found. Initializing model from scratch")

    logger.info("Start training")
    try:
        for epoch in range(start_epoch, args.epoches + 1):
            if epoch > args.warm:
                # NOTE(review): passing `epoch` is deprecated in recent PyTorch;
                # kept because it selects the closed-form MultiStepLR value.
                train_scheduler.step(epoch)

            model.train()
            tn = len(train_dataloader.dataset)
            tb = len(train_dataloader)

            start = time.time()
            pre = start
            for batch_idx, (imgs, _, _, y, _) in enumerate(train_dataloader):
                if epoch <= args.warm:
                    warmup_scheduler.step()

                # every sample carries `n_views` images (dim 1); fold them into the batch
                n_views = imgs.shape[1]
                imgs = rearrange(imgs, 'b n c h w -> (b n) c h w')
                y = rearrange(y, 'b n  -> (b n)')
                imgs, y = imgs.to(DEVICE), y.to(DEVICE)

                # forward / backward / optimize
                pred = model(imgs)
                loss = lsr_loss(pred, y)
                correct = (pred.argmax(1) == y).type(torch.float).sum().item()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # train visualization
                batch_loss = loss.item()
                batch_acc = correct / len(imgs) * 100
                n_iter = (epoch - 1) * tb + batch_idx + 1
                writer.add_scalar('Train/loss', batch_loss, n_iter)
                writer.add_scalar('Train/acc', batch_acc, n_iter)

                # logging; `eta` is elapsed time plus the extrapolated remainder
                cur = time.time()
                elapsed = cur - start
                eta = elapsed + (tb - batch_idx - 1) * (cur - pre)
                logger.info('Training Epoch: {epoch} [{trained_samples}/{total_samples}]\tLoss: {:0.4f}\tAcc: {:0.2f}%\tLR: {:0.8f}\t[{used_time}<{eta}]'.format(
                    batch_loss,
                    batch_acc,
                    optimizer.param_groups[0]['lr'],
                    epoch=epoch,
                    trained_samples=batch_idx * args.batch_size * n_views + len(imgs),
                    total_samples=tn * n_views,
                    used_time=_fmt_mmss(elapsed),
                    eta=_fmt_mmss(eta),
                ))
                pre = cur

            # eval procedure
            test_loss, acc = _evaluate(model, test_dataloader, lsr_loss, DEVICE)
            logger.info('Test metrics: Loss: {:0.4f}\tAcc: {:0.2f}%\t'.format(test_loss, acc))
            writer.add_scalar('Test/loss', test_loss, epoch)
            writer.add_scalar('Test/acc', acc, epoch)

            # best accuracy is tracked even when intermediate saving is disabled
            is_best = epoch > 10 and best_acc < acc
            if is_best:
                best_acc = acc

            # save weights; with --nosave only the final epoch is written
            is_final = epoch == args.epoches
            if args.nosave and not is_final:
                continue
            ckpt = {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'acc': acc,
            }
            torch.save(ckpt, work_dir / last.format(epoch, acc))
            if is_best and not args.nosave:
                torch.save(ckpt, work_dir / best.format(epoch, best_acc))
    finally:
        # flush pending TensorBoard events even if training is interrupted
        writer.close()



if __name__ == "__main__":
    # Script entry point: parse the CLI arguments and run the training loop.
    train()

1,几个参数的介绍:

--work-dir:用来指定保存训练日志、权重参数的文件夹

--weights:初始化模型的权重文件,跟--from-scratch搭配实现预训练或者恢复现场训练。加--from-scratch时只加载模型权重并从第1个epoch开始训练(即用预训练权重初始化);不加时还会恢复优化器状态、epoch和精度,实现断点续训。

-warm:使用WarmUp预热学习率所需要的epoch数(注意该参数是单横线写法)

2,训练流程的定义:

a) 初始化参数

b) 创建工作目录,顺便提一下好用的路径处理包pathlib

c) logger的初始化:这里要注意坑呀,logging类是可以通过名称进行派生的比如:

>>> a = logging.getLogger('a')
>>> b = logging.getLogger('a.b')
>>> a, b
(<Logger a (WARNING)>, <Logger a.b (WARNING)>)
>>> b.parent
<Logger a (WARNING)>

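而所有的logger都是root logger的后代(注意是logger层级上的子logger,而不是Python意义上的子类)。root logger可以通过不传名字得到,即logging.getLogger()或logging.getLogger('');注意不能传入空格' ',那样得到的是一个名为' '的普通logger。由于日志记录默认会沿父链向上传播给root logger的handler,因此可以通过设置root logger的行为来控制所有自定义logger的行为: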

def setup_root_logger(save_dir, distributed_rank, filename="log.txt"):
    """Configure the root logger so every named logger inherits its handlers.

    Args:
        save_dir: directory for the log file (str or Path). A falsy value
            disables file logging. The directory is created if missing.
        distributed_rank: process rank. Ranks > 0 get no handlers, so only the
            master process prints and writes logs.
        filename: name of the log file inside ``save_dir`` (truncated each run).
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)
    # don't log results for the non-master process
    if distributed_rank > 0:
        return

    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    root_logger.addHandler(ch)

    if save_dir:
        save_dir = Path(save_dir)
        # FileHandler fails if the parent directory does not exist
        save_dir.mkdir(parents=True, exist_ok=True)
        # explicit UTF-8 so non-ASCII (e.g. Chinese) messages are written
        # correctly regardless of the platform's default encoding
        fh = logging.FileHandler(save_dir / filename, mode='w', encoding='utf-8')
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        root_logger.addHandler(fh)

d) 初始化dataset以及dataloader,这里注意给dataloader设置合适的num_workers。num_workers设置得大,好处是取batch速度快,因为下一轮迭代的batch很可能在上一轮/上上一轮迭代时已经加载好了。坏处是内存开销大,也加重了CPU负担(worker把数据加载到RAM的工作是由CPU完成的)。num_workers的经验设置值是自己电脑/服务器的CPU核心数,如果CPU很强、RAM也很充足,就可以设置得更大些。一般开始是将num_workers设置为等于计算机上的CPU数量,最好的办法是缓慢增加num_workers,直到训练速度不再提高,就停止增加num_workers的值。

e) 构建模型以及损失函数。不要忘了对模型参数进行初始化,这在没有预训练权重时很重要。

f) 定义优化器和学习率调整策略。带WarmUp的自己定义比较好,官方的只支持按步调整。

g) 加载预训练权重。最好用同一个类(或同一组函数)来实现权重的加载与保存,保证两者的checkpoint格式一致。

h) 执行模型训练,每完成一个epoch后在test数据集上测试。

你可能感兴趣的:(深度学习,计算机视觉,Linux,pytorch,深度学习,caffe)