The training script consists mainly of two functions: parse_args() collects the command-line arguments, and train() defines the whole training procedure.
import argparse
import time
from tqdm import tqdm
import logging
from pathlib import Path
from datetime import datetime
from einops import rearrange
import sys
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.models import resnext50_32x4d
from criterion import LSR
from tensorboardX import SummaryWriter
from utils.logger import setup_root_logger
from utils.collect_env import collect_env_info
from utils.dataset import ClassImageDataset
from utils.model_utils import split_weights
from utils.lr_scheduler import WarmUpLR
def parse_args():
    parser = argparse.ArgumentParser(description='Train phase params')
    parser.add_argument('--work-dir', default='./exp', help='the dir to save logs and models')
    parser.add_argument('--weights', help='pretrained model path')
    parser.add_argument('--data-path', default='data', help='dataset path')
    parser.add_argument('--from-scratch', action='store_true', default=False, help='training from epoch 1')
    parser.add_argument('--batch-size', type=int, default=16, help='batch size for dataloader')
    parser.add_argument('-lr', type=float, default=1e-2, help='learning rate')
    parser.add_argument('-w', type=int, default=4, help='number of workers for dataloader')
    parser.add_argument('--epochs', type=int, default=200, help='training epochs')
    parser.add_argument('-warm', type=int, default=5, help='warm up phase')
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    args = parser.parse_args()
    return args


def train():
    args = parse_args()
    # directories
    work_dir = Path(args.work_dir)
    work_dir.mkdir(parents=True, exist_ok=True)  # make dir
    last, best = 'last_epoch{:d}_acc{:0.2f}.pth', 'best_epoch{:d}_acc{:0.2f}.pth'
    # log
    tensorboard_log_path = work_dir / 'tensorboard_logs'
    tensorboard_log_path.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(log_dir=str(tensorboard_log_path))
    setup_root_logger(work_dir, 0)
    logger = logging.getLogger('class_train')
    logger.info(args)
    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())
    # build dataset and dataloader
    train_data_path = Path(args.data_path) / 'train'
    test_data_path = Path(args.data_path) / 'test'
    train_dataset = ClassImageDataset(train_data_path)
    test_dataset = ClassImageDataset(test_data_path, augment=False)
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.w, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.w)
    # build and initialize the model
    model = resnext50_32x4d(num_classes=8).to(DEVICE)
    # criterion
    lsr_loss = LSR()
    # apply no weight decay on bias parameters
    params = split_weights(model)
    # build optimizer
    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9, weight_decay=1e-4, nesterov=True)
    # set up warmup phase learning rate scheduler
    iter_per_epoch = len(train_dataloader)
    warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * args.warm)
    # set up training phase learning rate scheduler
    train_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60, 90])
    # train_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs - args.warm)
    # resume
    start_epoch, best_acc = 1, 0.0
    if args.weights is not None:
        # pretrained
        try:
            ckpt = torch.load(args.weights, map_location=DEVICE)
        except Exception as e:
            raise RuntimeError(f"Tried to resume from {args.weights}, but got errors!") from e
        # loading model weights
        logger.info(f"Initializing model weights with {args.weights}")
        if 'model' in ckpt:
            model.load_state_dict(ckpt['model'])
        else:
            model.load_state_dict(ckpt)
        if not args.from_scratch:
            # resume optimizer state
            if ckpt.get('optimizer') is not None:
                logger.info(f"Loading optimizer from {args.weights}")
                optimizer.load_state_dict(ckpt['optimizer'])
            else:
                logger.info("No optimizer loaded")
            if ckpt.get('epoch') is not None:
                start_epoch = ckpt['epoch'] + 1
            if ckpt.get('acc') is not None:
                best_acc = ckpt['acc']
    else:
        logger.info("No checkpoint found. Initializing model from scratch")
    # training procedure
    logger.info("Start training")
    for epoch in range(start_epoch, args.epochs + 1):
        # scheduler update
        if epoch > args.warm:
            train_scheduler.step(epoch)
        model.train()
        tn = len(train_dataloader.dataset)
        tb = len(train_dataloader)
        # start batch ------------------------------------------------------------------------------------------------
        start = time.time()
        pre = start
        for batch_idx, (imgs, _, _, y, _) in enumerate(train_dataloader):
            # warmup
            if epoch <= args.warm:
                warmup_scheduler.step()
            # to DEVICE
            imgs = rearrange(imgs, 'b n c h w -> (b n) c h w')
            y = rearrange(y, 'b n -> (b n)')
            imgs, y = imgs.to(DEVICE), y.to(DEVICE)
            # forward
            pred = model(imgs)
            loss = lsr_loss(pred, y)
            correct = (pred.argmax(1) == y).type(torch.float).sum().item()
            optimizer.zero_grad()
            # backward
            loss.backward()
            # optimize
            optimizer.step()
            # train visualization
            n_iter = (epoch - 1) * tb + batch_idx + 1
            writer.add_scalar('Train/loss', loss.item(), n_iter)
            writer.add_scalar('Train/acc', correct / len(imgs) * 100, n_iter)
            # logging (each dataset item is a group of images, the n dim flattened above; the *32 factors assume n == 32)
            cur = time.time()
            logger.info('Training Epoch: {epoch} [{trained_samples}/{total_samples}]\tLoss: {:0.4f}\tAcc: {:0.2f}%\tLR: {:0.8f}\t[{used_time}<{eta}]'.format(
                loss.item(),
                correct / len(imgs) * 100,
                optimizer.param_groups[0]['lr'],
                epoch=epoch,
                trained_samples=batch_idx * args.batch_size * 32 + len(imgs),
                total_samples=tn * 32,
                used_time='%02d:%02d' % ((cur - start) // 60, (cur - start) % 60),
                eta='%02d:%02d' % ((cur - start + (tb - batch_idx - 1) * (cur - pre)) // 60, (cur - start + (tb - batch_idx - 1) * (cur - pre)) % 60),
            ))
            pre = cur
        # end batch ------------------------------------------------------------------------------------------------
        # eval procedure
        model.eval()
        test_loss, correct = 0, 0
        pbar = tqdm(test_dataloader, desc='evaluating', total=len(test_dataloader))
        with torch.no_grad():
            for imgs, _, _, y, _ in pbar:
                # to DEVICE
                imgs = rearrange(imgs, 'b n c h w -> (b n) c h w')
                y = rearrange(y, 'b n -> (b n)')
                imgs, y = imgs.to(DEVICE), y.to(DEVICE)
                # forward
                pred = model(imgs)
                test_loss += lsr_loss(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        test_loss /= len(test_dataloader)
        # accuracy over all flattened test images (assumes the same group size n == 32 as in the training log above)
        acc = correct / (len(test_dataloader.dataset) * 32) * 100
        logger.info('Test metrics: Loss: {:0.4f}\tAcc: {:0.2f}%\t'.format(test_loss, acc))
        # test visualization
        writer.add_scalar('Test/loss', test_loss, epoch)
        writer.add_scalar('Test/acc', acc, epoch)
        # save weights file
        if not args.nosave:
            ckpt = {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'acc': acc,
            }
            last_path = work_dir / last.format(epoch, acc)
            torch.save(ckpt, last_path)
            if epoch > 10 and best_acc < acc:
                best_acc = acc
                best_path = work_dir / best.format(epoch, best_acc)
                torch.save(ckpt, best_path)
        # end epoch ------------------------------------------------------------------------------------------------


if __name__ == '__main__':
    train()
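The LSR loss imported from criterion above is a label-smoothing regularization criterion; its implementation is not shown in this post. A minimal sketch consistent with how it is called here (lsr_loss(pred, y) on raw logits and integer class labels) could look like the following; the smoothing factor e and its default value are assumptions, not the project's actual code:

import torch
from torch import nn
import torch.nn.functional as F

class LSR(nn.Module):
    def __init__(self, e=0.1):
        super().__init__()
        self.e = e  # label smoothing factor (assumed default)

    def forward(self, logits, target):
        n_classes = logits.size(1)
        log_probs = F.log_softmax(logits, dim=1)
        # one-hot targets softened towards a uniform distribution over the other classes
        smoothed = torch.full_like(log_probs, self.e / (n_classes - 1))
        smoothed.scatter_(1, target.unsqueeze(1), 1.0 - self.e)
        return (-smoothed * log_probs).sum(dim=1).mean()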
1. A few of the arguments explained:
--work-dir: the directory in which training logs and model weights are saved
--weights: a weights file used to initialize the model; combined with --from-scratch it covers both initializing from pretrained weights and resuming an interrupted run
--warm: the number of warm-up epochs over which WarmUpLR ramps up the learning rate
2. The training procedure, step by step:
a) Parse the arguments.
b) Create the working directory; this is also a good opportunity to mention pathlib, a very handy path-handling package (see the short example below).
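For instance, the directory handling in the script boils down to a few readable lines with pathlib (a generic sketch, nothing project-specific):

from pathlib import Path

work_dir = Path('./exp')
work_dir.mkdir(parents=True, exist_ok=True)   # create the whole directory tree if it is missing
log_file = work_dir / 'log.txt'               # '/' joins path components
print(log_file.suffix, log_file.stem, log_file.exists())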
c) Set up the logger. Watch out for a pitfall here: loggers are organized into a parent/child hierarchy by their dotted names, for example:
>>> a = logging.getLogger('a')
>>> b = logging.getLogger('a.b')
>>> a, b
(<Logger a (WARNING)>, <Logger a.b (WARNING)>)
>>> b.parent
<Logger a (WARNING)>
All loggers are descendants of the root logger (obtained by calling logging.getLogger() with no name or an empty string), so configuring the root logger controls the behaviour of every named logger:
def setup_root_logger(save_dir, distributed_rank, filename="log.txt"):
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)
    # don't log results for the non-master process
    if distributed_rank > 0:
        return
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
    ch.setFormatter(formatter)
    root_logger.addHandler(ch)
    if save_dir:
        save_dir = Path(save_dir) if not isinstance(save_dir, Path) else save_dir
        fh = logging.FileHandler(save_dir / filename, mode='w')
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        root_logger.addHandler(fh)
d) Build the dataset and dataloader, and take care to choose a suitable num_workers for the DataLoader. A larger num_workers makes fetching batches faster, because the batches for the next iteration have often already been loaded during the previous one or two iterations. The downsides are higher memory usage and extra CPU load (the worker processes that copy data into RAM run on the CPU). A common rule of thumb is to set num_workers to the number of CPU cores on your machine, and to go higher if the CPU is powerful and RAM is plentiful. In practice, start with num_workers equal to the CPU count and slowly increase it until the training speed stops improving, then stop; a small benchmark such as the sketch below makes this easy.
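A rough way to pick the value is to time one full pass over the dataloader for a few candidate settings and keep the fastest; this sketch reuses the train_dataset built in the script (the candidate values are arbitrary):

import time
from torch.utils.data import DataLoader

for nw in (0, 2, 4, 8, 12):
    loader = DataLoader(train_dataset, batch_size=16, num_workers=nw, shuffle=True)
    start = time.time()
    for _ in loader:   # one full pass, data loading only, no training
        pass
    print(f'num_workers={nw}: {time.time() - start:.1f}s')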
e) Build the model and the loss function. Don't forget to initialize the model parameters; this matters a lot when no pretrained weights are available (see the sketch below).
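The torchvision resnext50_32x4d already ships with a reasonable default initialization, but when you define a model yourself a typical scheme looks like this (a generic sketch, not the initialization used in this project):

from torch import nn

def init_weights(model: nn.Module):
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.01)
            nn.init.zeros_(m.bias)

init_weights(model)  # called once, right after building the model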
f) Define the optimizer and the learning-rate schedule. A scheduler with warm-up is best written yourself, since the built-in schedulers only adjust the learning rate on a per-step basis; a minimal WarmUpLR sketch follows.
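The WarmUpLR imported from utils.lr_scheduler is not shown in the post either. A minimal linear warm-up scheduler that matches how it is used above (stepped once per iteration for iter_per_epoch * args.warm steps) could be written like this; it is a sketch, not necessarily the project's exact implementation:

from torch.optim.lr_scheduler import _LRScheduler

class WarmUpLR(_LRScheduler):
    """Linearly ramp the learning rate from ~0 up to the base LR over total_iters steps."""
    def __init__(self, optimizer, total_iters, last_epoch=-1):
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # self.last_epoch counts how many times step() has been called
        return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]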
g) Load the pretrained weights. Ideally, loading and saving weights are handled by a single unified class; see the sketch below.
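For example, a small hypothetical Checkpointer class (the names here are illustrative, not from the project) keeps saving and resuming in one place:

from pathlib import Path
import torch

class Checkpointer:
    def __init__(self, model, optimizer, save_dir):
        self.model, self.optimizer = model, optimizer
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def save(self, name, epoch, acc):
        torch.save({'epoch': epoch, 'acc': acc,
                    'model': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict()},
                   self.save_dir / name)

    def load(self, path, map_location='cpu', resume=True):
        ckpt = torch.load(path, map_location=map_location)
        # accept either a full checkpoint dict or a bare state_dict
        self.model.load_state_dict(ckpt.get('model', ckpt))
        if resume and ckpt.get('optimizer') is not None:
            self.optimizer.load_state_dict(ckpt['optimizer'])
        return ckpt.get('epoch', 0) + 1, ckpt.get('acc', 0.0)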
h) Run the training loop, evaluating on the test set after every epoch.