Building a Network with PyTorch - the train.py Part

train.py

Contents:

  • Imports
  • train.py
  • Configuring parameters with argparse
  • The main function
  • The torch.nn.CrossEntropyLoss class
  • The torch.optim.Adam class
  • Python's enumerate() method
  • The torch.optim.Adam.zero_grad() method
  • FP, BP (forward and backward propagation)
  • Open questions
  • Source code

Imports

We need to import time and datetime to measure training time; import torch to use the PyTorch framework; import the network with from model import UNet; import the helper methods we need with from utils import [the methods you need]; import our custom Dataset with from dataset import DriveDataset; and import the transforms file with import transforms as T:

import os
import time
import datetime

import torch

from model import UNet
from utils import train_one_epoch, evaluate, create_lr_scheduler
from dataset import DriveDataset
import transforms as T

train.py

In train.py we train the network: we first configure the parameters with argparse, then pass them into main for training.

For example:

if __name__ == '__main__':
    args = parse_args()
    main(args)

Configuring parameters with argparse

We use argparse to wrap up the parameters we need.

See https://blog.csdn.net/qq_43369406/article/details/127787799

An example argparse function:

def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description="pytorch unet training")

    parser.add_argument("--data-path", default="./", help="DRIVE root")
    # exclude background
    parser.add_argument("--num-classes", default=1, type=int)
    parser.add_argument("--device", default="cuda", help="training device")
    parser.add_argument("-b", "--batch-size", default=4, type=int)
    parser.add_argument("--epochs", default=200, type=int, metavar="N",
                        help="number of total epochs to train")

    parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--print-freq', default=1, type=int, help='print frequency')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--save-best', default=True, type=bool, help='only save best dice weights')
    # Mixed precision training parameters
    parser.add_argument("--amp", default=False, type=bool,
                        help="Use torch.cuda.amp for mixed precision training")
    # note: argparse's type=bool treats any non-empty string as True,
    # so these two flags are best left at their defaults

    args = parser.parse_args()

    return args


if __name__ == '__main__':
    args = parse_args()
    print(args.data_path)  # each configured parameter is available as an attribute on args
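As a quick way to sanity-check the parser without touching the command line, parse_args accepts an explicit argument list. A trimmed, standalone copy of the parser above (the argument values passed in are made up for illustration):

```python
import argparse

# A trimmed copy of the parser above, for a quick standalone check.
parser = argparse.ArgumentParser(description="pytorch unet training")
parser.add_argument("--data-path", default="./", help="DRIVE root")
parser.add_argument("-b", "--batch-size", default=4, type=int)
parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate")
parser.add_argument("--wd", "--weight-decay", default=1e-4, type=float,
                    dest="weight_decay")

# Passing an explicit list bypasses sys.argv, which is handy in notebooks and tests.
args = parser.parse_args(["--data-path", "/data/DRIVE", "-b", "8", "--wd", "5e-4"])
print(args.data_path)     # /data/DRIVE
print(args.batch_size)    # 8
print(args.weight_decay)  # 0.0005
```

Note that hyphens in flag names become underscores in attribute names (--data-path becomes args.data_path), and dest="weight_decay" lets both --wd and --weight-decay fill the same field.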

The main function

The main function contains the whole training process. We usually organize main as: DataLoader - training setup - epoch loop.

DataLoader

In the DataLoader step we choose suitable Transforms to pass into the Dataset, then pass the Dataset and a batch size into the DataLoader; the DataLoader then draws batch_size samples from the Dataset on each iteration. The most important parts are picking the right Transforms for the Dataset and configuring the DataLoader appropriately.

The Transforms are built by the get_transform helper from the project's transforms.py.
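transforms.py is project-specific and not shown here. As a rough, torch-free sketch of the usual shape of such a file (a Compose that chains paired image/mask transforms, with get_transform enabling augmentation only for training), all class names and behaviors below are illustrative stand-ins:

```python
import random

class Compose:
    """Apply a list of transforms in order; each takes and returns (img, target)."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, target):
        for t in self.transforms:
            img, target = t(img, target)
        return img, target

class RandomHorizontalFlip:
    """Flip image and mask together with probability p (placeholder flip)."""
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target):
        if random.random() < self.p:
            img, target = img[::-1], target[::-1]  # stand-in for a real flip
        return img, target

def get_transform(train, mean, std):
    # Training gets augmentation; validation would get only normalization
    # (normalization with mean/std is omitted in this sketch).
    ts = [RandomHorizontalFlip(0.5)] if train else []
    return Compose(ts)
```

The key design point is that image and mask must go through the same random transforms together, which is why segmentation projects define their own transforms instead of using torchvision's image-only ones.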

An example DataLoader setup:

# dataloader
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
batch_size = args.batch_size
num_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
# segmentation num_classes + background
num_classes = args.num_classes + 1
# using compute_mean_std.py
mean = (0.709, 0.381, 0.224)
std = (0.127, 0.079, 0.043)
train_dataset = DriveDataset(args.data_path,
                             train=True,
                             transforms=get_transform(train=True, mean=mean, std=std))
val_dataset = DriveDataset(args.data_path,
                           train=False,
                           transforms=get_transform(train=False, mean=mean, std=std))
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           num_workers=num_workers,
                                           shuffle=True,
                                           pin_memory=True,
                                           collate_fn=train_dataset.collate_fn)
val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=1,
                                         num_workers=num_workers,
                                         pin_memory=True,
                                         collate_fn=val_dataset.collate_fn)

Training setup

In this step we specify the model and the optimizer, and load pretrained weights (transfer learning).

# model
model = create_model(num_classes=num_classes)
model.to(device)

params_to_optimize = [p for p in model.parameters() if p.requires_grad]
# optimizer
optimizer = torch.optim.SGD(
        params_to_optimize,
        lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
)

We use torch.save() and torch.load() to save and load the training state. When saving, model.state_dict() returns the trained weights as a dictionary for storage; when loading, model.load_state_dict() restores them, as follows:

# save a .pth file
# define the layout of the stored data
save_file = {"model": model.state_dict(),
             "optimizer": optimizer.state_dict(),	# the optimizer's internal state
             "lr_scheduler": lr_scheduler.state_dict(),
             "epoch": epoch,
             "args": args}
torch.save(save_file, "save_weights/best_model.pth")

# load a .pth file
# read the data back out of the .pth file
checkpoint = torch.load(args.resume, map_location='cpu')	# args.resume="save_weights/best_model.pth"; map_location='cpu' loads the tensors onto the CPU
model.load_state_dict(checkpoint['model'])			# look up each value by key; anything stored via .state_dict() must be restored via .load_state_dict()
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1			# values not produced by .state_dict() can be used directly
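torch.save() and torch.load() serialize any picklable Python object, so the checkpoint pattern above is essentially a dictionary round-trip. A torch-free sketch of that round-trip using plain dicts and Python's pickle (the values below are invented stand-ins; the real code stores tensors and state_dicts):

```python
import os
import pickle
import tempfile

# Stand-ins for model.state_dict() / optimizer.state_dict().
save_file = {"model": {"conv1.weight": [0.1, 0.2]},
             "optimizer": {"lr": 0.01},
             "epoch": 42}

path = os.path.join(tempfile.mkdtemp(), "best_model.pth")
with open(path, "wb") as f:          # torch.save(save_file, path) pickles much like this
    pickle.dump(save_file, f)

with open(path, "rb") as f:          # checkpoint = torch.load(path, map_location='cpu')
    checkpoint = pickle.load(f)

start_epoch = checkpoint["epoch"] + 1   # plain values come back as-is
print(start_epoch)  # 43
```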

epoch

All training happens inside the epoch loop: each epoch trains, evaluates, stores the results, and records how long the round took.

The complete training loop:

    # used to record information from training and validation
    results_file = "/home/yingmuzhi/unet/results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))   
    best_dice = 0.
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        mean_loss, lr = train_one_epoch(model, optimizer, train_loader, device, epoch, num_classes,
                                        lr_scheduler=lr_scheduler, print_freq=args.print_freq, scaler=scaler)

        confmat, dice = evaluate(model, val_loader, device=device, num_classes=num_classes)
        val_info = str(confmat)
        print(val_info)
        print(f"dice coefficient: {dice:.3f}")
        # write into txt
        with open(results_file, "a") as f:
            # record each epoch's train_loss, lr, and the validation metrics
            train_info = f"[epoch: {epoch}]\n" \
                         f"train_loss: {mean_loss:.4f}\n" \
                         f"lr: {lr:.6f}\n" \
                         f"dice coefficient: {dice:.3f}\n"
            f.write(train_info + val_info + "\n\n")

        if args.save_best is True:
            if best_dice < dice:
                best_dice = dice
            else:
                continue

        save_file = {"model": model.state_dict(),
                     "optimizer": optimizer.state_dict(),
                     "lr_scheduler": lr_scheduler.state_dict(),
                     "epoch": epoch,
                     "args": args}
        if args.amp:
            save_file["scaler"] = scaler.state_dict()

        if args.save_best is True:
            torch.save(save_file, "save_weights/best_model.pth")
        else:
            torch.save(save_file, "save_weights/model_{}.pth".format(epoch))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("training time {}".format(total_time_str))
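str(datetime.timedelta(seconds=...)) produces the H:MM:SS string printed at the end of training; a small check of the formatting (format_elapsed is a helper name invented here):

```python
import datetime

def format_elapsed(seconds):
    """Same formatting as the training-time printout above."""
    return str(datetime.timedelta(seconds=int(seconds)))

print(format_elapsed(3661.9))  # 1:01:01  (the fraction is dropped by int())
print(format_elapsed(90000))   # 1 day, 1:00:00
```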
