目录如下:
train.py
torch.nn.CrossEntropyLoss类
torch.optim.Adam类
python中enumerate()方法
torch.optim.Adam.zero_grad()方法
我们需要导入time
和datetime
用于计算训练时间;导入torch
用于使用Pytorch框架;导入网络from model import UNet
;导入需要的工具方法from utils import [你需要的方法]
;导入我们自定义的Dataset from dataset import DriveDataset
;导入transforms文件import transforms as T
。
import os
import time
import datetime
import torch
from model import UNet
from utils import train_one_epoch, evaluate, create_lr_scheduler
from dataset import DriveDataset
import transforms as T
train.py
在train.py
中我们对网络进行训练,我们首先使用argparse
配置参数,将参数传入main
进行训练。
案例如下:
# Script entry point: parse the command-line options, then hand them to main().
if __name__ == '__main__':
    args = parse_args()
    main(args)
使用argparse封装需要的参数。
参看https://blog.csdn.net/qq_43369406/article/details/127787799
argparse函数案例如下:
def parse_args():
    """Build the command-line parser for U-Net training and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options (``data_path``, ``num_classes``,
        ``device``, ``batch_size``, ``epochs``, ``lr``, ``momentum``,
        ``weight_decay``, ``print_freq``, ``resume``, ``start_epoch``,
        ``save_best`` and ``amp``).
    """
    import argparse

    def str2bool(value):
        """Convert a command-line token into a real boolean.

        ``type=bool`` must not be used here: ``bool("False")`` is ``True``
        because every non-empty string is truthy, so the option could never
        be switched off from the command line.
        """
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in ("true", "t", "yes", "y", "1", "on"):
            return True
        # The empty string stays False, as it was with ``type=bool``.
        if token in ("false", "f", "no", "n", "0", "off", ""):
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="pytorch unet training")

    parser.add_argument("--data-path", default="./", help="DRIVE root")
    # exclude background
    parser.add_argument("--num-classes", default=1, type=int)
    parser.add_argument("--device", default="cuda", help="training device")
    parser.add_argument("-b", "--batch-size", default=4, type=int)
    parser.add_argument("--epochs", default=200, type=int, metavar="N",
                        help="number of total epochs to train")

    parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--print-freq', default=1, type=int, help='print frequency')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--save-best', default=True, type=str2bool,
                        help='only save best dice weights')
    # Mixed precision training parameters
    parser.add_argument("--amp", default=False, type=str2bool,
                        help="Use torch.cuda.amp for mixed precision training")

    args = parser.parse_args()
    return args
# Parsed options are exposed as attributes of the returned namespace;
# "--data-path" becomes ``args.data_path``.
if __name__ == '__main__':
    args = parse_args()
    # A bare ``args.data_path`` expression has no effect in a script, so print it.
    print(args.data_path)
main函数包括训练的全过程,我们一般这样组织main中结构:DataLoader - 训练参数 - epoch
在DataLoader环节我们需要选择合适的Transforms传入Dataset,向DataLoader中传入Dataset和batch_size,DataLoader就会每次从Dataset中取出batch_size个数据。其中最为重要的就是选定适合的Transforms传入Dataset中,设定合适的DataLoader。
Transforms选定如下:
DataLoader案例如下:
# dataloader
# Use the requested device only when CUDA is available; otherwise fall back to the CPU.
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
batch_size = args.batch_size
# os.cpu_count() returns None when the core count cannot be determined;
# treat that as a single core so min() never compares None with an int.
# A batch size of 1 disables worker processes, and the count is capped at 8.
num_workers = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])

# segmentation num_classes + background
num_classes = args.num_classes + 1

# Per-channel normalisation statistics (per the original note: obtained with compute_mean_std.py).
mean = (0.709, 0.381, 0.224)
std = (0.127, 0.079, 0.043)

train_dataset = DriveDataset(args.data_path,
                             train=True,
                             transforms=get_transform(train=True, mean=mean, std=std))

val_dataset = DriveDataset(args.data_path,
                           train=False,
                           transforms=get_transform(train=False, mean=mean, std=std))

# Training batches are shuffled; each dataset supplies its own collate_fn.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           num_workers=num_workers,
                                           shuffle=True,
                                           pin_memory=True,
                                           collate_fn=train_dataset.collate_fn)

# Validation runs one sample at a time, in dataset order.
val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=1,
                                         num_workers=num_workers,
                                         pin_memory=True,
                                         collate_fn=val_dataset.collate_fn)
在该步骤中需要指定模型,优化器,加载预训练权重(迁移学习)。
# Model: build the network and move it to the selected device.
model = create_model(num_classes=num_classes)
model.to(device)

# Optimizer: SGD over only those parameters that still require gradients.
params_to_optimize = list(filter(lambda p: p.requires_grad, model.parameters()))
sgd_options = {
    "lr": args.lr,
    "momentum": args.momentum,
    "weight_decay": args.weight_decay,
}
optimizer = torch.optim.SGD(params_to_optimize, **sgd_options)
我们使用torch.load()
和torch.save()
来加载和保存训练状态(模型权重、优化器状态等),并在加载和保存时分别使用model.load_state_dict()
和model.state_dict()
,以字典格式读取和存储训练得到的权重参数,如下:
# Save a .pth checkpoint.
# Everything needed to resume training is gathered into one dictionary.
save_file = {"model": model.state_dict(),
             "optimizer": optimizer.state_dict(),  # optimizer state
             "lr_scheduler": lr_scheduler.state_dict(),
             "epoch": epoch,
             "args": args}
# torch.save does not create missing parent directories, so make sure the
# target folder exists first.
os.makedirs("save_weights", exist_ok=True)
torch.save(save_file, "save_weights/best_model.pth")

# Load a .pth checkpoint.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust. The checkpoint stores `args` (an argparse.Namespace); PyTorch releases
# that default to weights_only=True presumably need weights_only=False here --
# confirm against the installed version.
checkpoint = torch.load(args.resume, map_location='cpu')  # e.g. args.resume="save_weights/best_model.pth"; map_location='cpu' maps the stored tensors onto the CPU
# Entries written with .state_dict() are restored with .load_state_dict().
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
# Plain values (not state dicts) are read straight from the dictionary;
# training resumes at the epoch after the saved one.
args.start_epoch = checkpoint['epoch'] + 1
每一次训练都是在epoch中进行,每一个epoch需要进行训练和测试并将训练结果进行存储,并记录每一轮训练时长。
训练的完整代码如下:
# Text file that collects the training and validation information of every epoch.
# NOTE(review): this is an absolute, machine-specific path -- adjust it to your environment.
results_file = "/home/yingmuzhi/unet/results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the first save.
os.makedirs("save_weights", exist_ok=True)

best_dice = 0.
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
    # One pass over the training set, then an evaluation on the validation set.
    mean_loss, lr = train_one_epoch(model, optimizer, train_loader, device, epoch, num_classes,
                                    lr_scheduler=lr_scheduler, print_freq=args.print_freq, scaler=scaler)

    confmat, dice = evaluate(model, val_loader, device=device, num_classes=num_classes)
    val_info = str(confmat)
    print(val_info)
    print(f"dice coefficient: {dice:.3f}")

    # write into txt
    with open(results_file, "a") as f:
        # Record this epoch's train loss, learning rate and validation metrics.
        train_info = (f"[epoch: {epoch}]\n"
                      f"train_loss: {mean_loss:.4f}\n"
                      f"lr: {lr:.6f}\n"
                      f"dice coefficient: {dice:.3f}\n")
        f.write(train_info + val_info + "\n\n")

    if args.save_best:
        # In save-best mode, skip the checkpoint unless the dice score improved.
        if best_dice < dice:
            best_dice = dice
        else:
            continue

    save_file = {"model": model.state_dict(),
                 "optimizer": optimizer.state_dict(),
                 "lr_scheduler": lr_scheduler.state_dict(),
                 "epoch": epoch,
                 "args": args}
    if args.amp:
        # The scaler state is stored only when mixed precision is enabled.
        save_file["scaler"] = scaler.state_dict()

    if args.save_best:
        # A single file is overwritten whenever a new best score is reached.
        torch.save(save_file, "save_weights/best_model.pth")
    else:
        # Otherwise one checkpoint is kept per epoch.
        torch.save(save_file, "save_weights/model_{}.pth".format(epoch))

# Report the total wall-clock training time as H:MM:SS.
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("training time {}".format(total_time_str))