文章目录
- train.py
- dataload_five_flower.py
- train_engin.py
- lr_methods.py
- __init__.py
- train_sample.py 和 test.py 见文章:
- 用 parser 方便服务器中的终端操作
- 第三个代码(train_engin.py)把 mac 的 mps 和 cuda 的接口混用了(例如 torch.cuda.amp、torch.cuda.synchronize 只能在 CUDA 上使用),有点问题,看下代码整体思想就行,不用去跑
- 因为我的电脑是 mac 的 mps,还没找到代码的替代方法
- 可以直接用上面那篇文章中的 train_sample.py
- 如果不是训练(例如只对少量图片做推理/测试),cpu 往往比 cuda 更快:把数据拷贝到 GPU 以及初始化 CUDA 的开销可能超过计算上节省的时间
train.py
############################################################################################################
# 相较于简单版本的训练脚本 train_sample 增添了以下功能:
# 1. 使用argparse类实现可以在训练的启动命令中指定超参数
# 2. 可以通过在启动命令中指定 --seed 来固定网络的初始化方式,以达到结果可复现的效果
# 3. 使用了更高级的学习策略 cosine warm up:在训练的第一轮使用一个较小的lr(warm_up),从第二个epoch开始,随训练轮数逐渐减小lr。
# 4. 可以通过在启动命令中指定 --model 来选择使用的模型
# 5. 使用amp包实现混合精度训练,在保证准确率的同时尽可能的减小训练成本
# 6. 实现了数据加载类的自定义实现
# 7. 可以通过在启动命令中指定 --tensorboard 来进行tensorboard可视化, 默认不启用。
# 注意,使用 tensorboard 之前需要使用命令 "tensorboard --logdir=log_path" 来启动,结果通过网页 http://localhost:6006/ 查看(TensorBoard 的默认地址)。
############################################################################################################
# --model 可选的超参如下:
# alexnet zfnet vgg vgg_tiny vgg_small vgg_big googlenet xception resnet_small resnet resnet_big resnext resnext_big
# densenet_tiny densenet_small densenet densenet_big mobilenet_v3 mobilenet_v3_large shufflenet_small shufflenet
# efficient_v2_small efficient_v2 efficient_v2_large convnext_tiny convnext_small convnext convnext_big convnext_huge
# vision_transformer_small vision_transformer vision_transformer_big swin_transformer_tiny swin_transformer_small swin_transformer
# 训练命令示例: # python train.py --model alexnet --num_classes 5
############################################################################################################
import os
import argparse
import math
import shutil
import random
import numpy as np
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler
import classic_models
from utils.lr_methods import warmup
from dataload.dataload_five_flower import Five_Flowers_Load
from utils.train_engin import train_one_epoch, evaluate
# Command-line hyper-parameters, so training can be configured from a server terminal.
parser = argparse.ArgumentParser()
parser.add_argument('--num_classes', type=int, default=5, help='the number of classes')
parser.add_argument('--epochs', type=int, default=50, help='the number of training epoch')
parser.add_argument('--batch_size', type=int, default=64, help='batch_size for training')
parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate')
# main() builds a cosine schedule whose multiplier on --lr decays from 1 towards lrf,
# so lrf is a ratio of the initial learning rate, not an absolute learning rate.
parser.add_argument('--lrf', type=float, default=0.0001,
                    help='final learning-rate factor: the cosine schedule decays lr towards lr * lrf')
parser.add_argument('--seed', default=False, action='store_true', help='fix the initialization of parameters')
parser.add_argument('--tensorboard', default=False, action='store_true', help=' use tensorboard for visualization')
parser.add_argument('--use_amp', default=False, action='store_true', help=' training with mixed precision')
# The default data path is machine-specific; change it to your own dataset location.
parser.add_argument('--data_path', type=str, default="/Users/jiangxiyu/根目录/深度学习/flower",
                    help='dataset root containing train/ and val/ sub-folders')
parser.add_argument('--model', type=str, default="vgg", help=' select a model for training')
# How this string is mapped onto a torch.device is decided in main().
parser.add_argument('--device', default='mps', help='device to train on, e.g. mps or cpu')
# Parse the command line into the options object used by the rest of the script.
opt = parser.parse_args()
def seed_torch(seed=7):
    """Seed the random number generators used during training so runs are reproducible.

    Seeds Python's ``random`` module, NumPy and PyTorch. According to the PyTorch
    documentation ``torch.manual_seed`` seeds the generators of all devices, so no
    separate ``torch.cuda.manual_seed`` / ``manual_seed_all`` call is needed.

    Args:
        seed: the seed value applied to every generator.
    """
    random.seed(seed)  # Python random module
    # Setting the variable here cannot change the hash seed of the running interpreter;
    # it only takes effect in child processes started afterwards (e.g. spawned workers).
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)  # NumPy module
    torch.manual_seed(seed)
    # For bit-exact cuDNN convolutions one could additionally set
    #   torch.backends.cudnn.benchmark = False
    #   torch.backends.cudnn.deterministic = True
    # This is left off on purpose: the numeric differences are only in the last few
    # decimal places, and the deterministic kernels are slower.
    print('random seed has been fixed')


# Only fix the seeds when --seed was passed on the command line.
if opt.seed:
    seed_torch()
def _select_device(requested):
    """Return the ``torch.device`` to train on for the ``--device`` value ``requested``.

    * ``"mps"`` -> MPS when it is available, otherwise CPU.
    * ``"cuda"``, ``"cuda:N"`` or a bare index such as ``"0"`` -> that CUDA device when
      CUDA is available, otherwise CPU. For a list such as ``"0,1"`` only the first id
      is used, because this script trains on a single device.
    * anything else (e.g. ``"cpu"``) -> passed to ``torch.device`` unchanged.
    """
    name = str(requested).strip().lower().split(",")[0]
    if name == "mps":
        return torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    if name.isdigit() or name.startswith("cuda"):
        if not torch.cuda.is_available():
            return torch.device("cpu")
        return torch.device("cuda:" + name if name.isdigit() else name)
    return torch.device(name)


def _build_dataloaders(args, device):
    """Create the train/val DataLoaders under ``args.data_path`` and validate ``args.num_classes``."""
    # Mean/std of 0.5 per channel; for large datasets the ImageNet statistics
    # [0.485, 0.456, 0.406] / [0.229, 0.224, 0.225] are the usual alternative.
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}

    # Five_Flowers_Load plays the role of torchvision's ImageFolder: it applies the given
    # transform and returns (image, label) pairs.
    train_dataset = Five_Flowers_Load(os.path.join(args.data_path, 'train'), transform=data_transform["train"])
    val_dataset = Five_Flowers_Load(os.path.join(args.data_path, 'val'), transform=data_transform["val"])
    if args.num_classes != train_dataset.num_class:
        raise ValueError("dataset have {} classes, but input {}".format(train_dataset.num_class, args.num_classes))

    # os.cpu_count() may return None, hence the "or 1".
    nw = min([os.cpu_count() or 1, args.batch_size if args.batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    # Pinned host memory only speeds up host-to-CUDA copies, so request it only for CUDA.
    pin_memory = device.type == "cuda"
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               pin_memory=pin_memory, num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                                             pin_memory=pin_memory, num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)
    return train_loader, val_loader


def main(args):
    """Train ``args.model`` on the dataset in ``args.data_path`` and keep the best weights.

    ``args.data_path`` must contain ``train`` and ``val`` sub-folders. After every epoch:
    * a summary line is appended to ``results/weights/<model>/AlexNet_log.txt``;
    * the weights with the best validation accuracy so far are saved to
      ``results/weights/<model>/AlexNet.pth``;
    * with ``--tensorboard``, scalars are logged under ``./results/tensorboard/<model>``.

    The ``AlexNet`` file names are kept for every model (presumably other scripts load
    them by that name — TODO confirm before renaming).

    Args:
        args: the parsed command-line options (see the argparse definitions).
    """
    device = _select_device(args.device)
    print(args)

    tb_writer = None
    if args.tensorboard:
        # Relative directory holding the data TensorBoard displays, one folder per model.
        log_path = os.path.join('./results/tensorboard', args.model)
        print('Start Tensorboard with "tensorboard --logdir={}"'.format(log_path))
        if not os.path.exists(log_path):
            os.makedirs(log_path)
            print("tensorboard log save in {}".format(log_path))
        else:
            # Drop the logs of a previous run; SummaryWriter recreates the folder.
            shutil.rmtree(log_path)
        tb_writer = SummaryWriter(log_path)

    train_loader, val_loader = _build_dataloaders(args, device)

    # create model
    model = classic_models.find_model_using_name(args.model, num_classes=args.num_classes).to(device)
    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(pg, lr=args.lr)

    def lf(x):
        # Cosine schedule: the multiplier on args.lr goes from 1 at epoch 0 to args.lrf
        # after args.epochs scheduler steps.
        return ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf

    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    best_acc = 0.
    # save parameters path
    save_path = os.path.join(os.getcwd(), 'results/weights', args.model)
    os.makedirs(save_path, exist_ok=True)

    try:
        for epoch in range(args.epochs):
            # train (train_one_epoch applies the per-iteration warm-up in the first epoch)
            mean_loss, train_acc = train_one_epoch(model=model, optimizer=optimizer, data_loader=train_loader,
                                                   device=device, epoch=epoch, use_amp=args.use_amp,
                                                   lr_method=warmup)
            scheduler.step()

            # validate
            val_acc = evaluate(model=model, data_loader=val_loader, device=device)
            summary = '[epoch %d] train_loss: %.3f train_acc: %.3f val_accuracy: %.3f' % (
                epoch + 1, mean_loss, train_acc, val_acc)
            print(summary)
            with open(os.path.join(save_path, "AlexNet_log.txt"), 'a') as f:
                f.write(summary + '\n')

            if tb_writer is not None:
                tags = ["train_loss", "train_acc", "val_accuracy", "learning_rate"]
                tb_writer.add_scalar(tags[0], mean_loss, epoch)
                tb_writer.add_scalar(tags[1], train_acc, epoch)
                tb_writer.add_scalar(tags[2], val_acc, epoch)
                tb_writer.add_scalar(tags[3], optimizer.param_groups[0]["lr"], epoch)

            # Overwrite the saved weights whenever validation accuracy improves.
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save(model.state_dict(), os.path.join(save_path, "AlexNet.pth"))
    finally:
        # Flush and release the event file even if training stops early.
        if tb_writer is not None:
            tb_writer.close()
if __name__ == "__main__":
    # Script entry point: train using the options parsed from the command line.
    main(opt)
dataload_five_flower.py
- 不同的数据集,torch 自带的数据集类(如 ImageFolder)不一定适配,所以要学会自己封装 Dataset
from PIL import Image
from matplotlib.cbook import ls_mapper
import torch
from torch.utils.data import Dataset
import random
import os
class Five_Flowers_Load(Dataset):
    """Image-classification dataset laid out as ``<data_path>/<class_name>/<image file>``.

    Every non-hidden sub-directory of ``data_path`` is one class. Classes are indexed in
    sorted name order, e.g. ``{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3,
    'tulips': 4}``. Only files with the extensions ``.jpg .JPG .png .PNG`` are used.

    Attributes:
        num_class: number of class directories found.
        images_path: file path of every sample.
        images_label: class index of every sample, parallel to ``images_path``.
        images_num: number of samples per class, in class-index order.
    """

    def __init__(self, data_path: str, transform=None):
        self.data_path = data_path
        self.transform = transform
        # Reseeds Python's global ``random`` module; note this overrides any seed the
        # caller set earlier.
        random.seed(0)
        assert os.path.exists(data_path), "dataset root: {} does not exist.".format(data_path)

        # Only directories are classes. Stray files such as macOS ".DS_Store" would
        # otherwise be counted as a class and then make os.listdir() fail below.
        # Sorting keeps the class -> index mapping stable.
        flower_class = sorted(
            entry for entry in os.listdir(data_path)
            if not entry.startswith(".") and os.path.isdir(os.path.join(data_path, entry))
        )
        self.num_class = len(flower_class)
        class_indices = {cla: idx for idx, cla in enumerate(flower_class)}

        self.images_path = []   # path of every image
        self.images_label = []  # class index of every image
        self.images_num = []    # number of images per class
        supported = [".jpg", ".JPG", ".png", ".PNG"]  # accepted file extensions

        for cla in flower_class:
            cla_path = os.path.join(data_path, cla)
            # Sort so the sample order does not depend on the file system's listing
            # order, which keeps seeded shuffling reproducible across machines.
            images = [os.path.join(cla_path, name) for name in sorted(os.listdir(cla_path))
                      if os.path.splitext(name)[-1] in supported]
            image_class = class_indices[cla]
            self.images_num.append(len(images))
            self.images_path.extend(images)
            self.images_label.extend([image_class] * len(images))
        print("{} images were found in the dataset.".format(sum(self.images_num)))

    def __len__(self):
        # Equals sum(self.images_num): one entry per collected image.
        return len(self.images_path)

    def __getitem__(self, idx):
        """Return ``(transformed image, class index)`` for sample ``idx``.

        Raises:
            ValueError: if the image is not in RGB mode or no transform was given.
        """
        img = Image.open(self.images_path[idx])
        label = self.images_label[idx]
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[idx]))
        if self.transform is not None:
            img = self.transform(img)
        else:
            raise ValueError('Image is not preprocessed')
        return img, label

    # Optional: DataLoader has a default collate function. This one decides how a batch
    # is returned: images stacked along a new first dimension, labels as one tensor.
    @staticmethod
    def collate_fn(batch):
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
train_engin.py
import sys
import torch
from tqdm import tqdm
from utils.distrubute_utils import is_main_process, reduce_value
from utils.lr_methods import warmup
def train_one_epoch(model, optimizer, data_loader, device, epoch, use_amp=False, lr_method=None):
    """Train ``model`` for one epoch and return ``(mean_loss, train_accuracy)``.

    Args:
        model: network to train; switched to ``train()`` mode here.
        optimizer: optimizer over ``model``'s parameters.
        data_loader: iterable of ``(images, labels)`` batches supporting ``len()``.
        device: device the model lives on.
        epoch: zero-based epoch index; used for the warm-up decision and the progress label.
        use_amp: request mixed-precision training. The ``torch.cuda.amp`` helpers used here
            only work on CUDA, so AMP is enabled only when the device is a CUDA device.
        lr_method: pass ``warmup`` to ramp the learning rate linearly during epoch 0.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch losses returned by
        ``reduce_value`` and the fraction of correctly classified samples seen by this
        process. Both are ``0.0`` if ``data_loader`` yields no batches.

    Exits the process with status 1 as soon as a non-finite loss is observed.
    """
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    device_type = torch.device(device).type
    train_loss = torch.zeros(1).to(device)
    acc_num = torch.zeros(1).to(device)
    optimizer.zero_grad()

    # Warm-up only in the first epoch: the lr multiplier ramps from 1/1000 to 1 over
    # min(1000, len(data_loader) - 1) iterations.
    lr_scheduler = None
    if epoch == 0 and lr_method == warmup:
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)
        lr_scheduler = warmup(optimizer, warmup_iters, warmup_factor)

    if is_main_process():
        data_loader = tqdm(data_loader, file=sys.stdout)

    # torch.cuda.amp's autocast and GradScaler only act on CUDA, so mixed precision is
    # enabled only for CUDA devices. Gating on "mps" left AMP off on CUDA and ineffective on MPS.
    enable_amp = use_amp and device_type == "cuda"
    # Scale the loss to reduce the risk of gradient underflow when computing in fp16.
    scaler = torch.cuda.amp.GradScaler(enabled=enable_amp)

    sample_num = 0
    num_steps = 0
    for images, labels in data_loader:
        images = images.to(device)
        labels = labels.to(device)
        sample_num += images.shape[0]
        num_steps += 1

        with torch.cuda.amp.autocast(enabled=enable_amp):
            pred = model(images)
            loss = loss_function(pred, labels)
        pred_class = torch.max(pred, dim=1)[1]
        acc_num += torch.eq(pred_class, labels).sum()

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        # reduce_value presumably averages the loss across processes in distributed
        # runs — see utils.distrubute_utils.
        train_loss += reduce_value(loss, average=True).detach()

        # Show the current learning rate in the tqdm progress bar (main process only).
        if is_main_process():
            data_loader.desc = '[epoch{}]: learning_rate:{:.5f}'.format(
                epoch + 1,
                optimizer.param_groups[0]["lr"]
            )

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        if lr_scheduler is not None:  # warm-up adjusts the lr every iteration
            lr_scheduler.step()

    # torch.cuda.synchronize requires CUDA (it raises on an MPS-only machine), so only
    # call it for CUDA devices; .item() below synchronises the other backends anyway.
    if device_type == "cuda":
        torch.cuda.synchronize(device)

    if num_steps == 0:
        return 0.0, 0.0
    return train_loss.item() / num_steps, acc_num.item() / sample_num
@torch.no_grad()
def evaluate(model, data_loader, device):
    """Return the classification accuracy of ``model`` on ``data_loader``.

    Args:
        model: network to evaluate; switched to ``eval()`` mode here.
        data_loader: iterable of ``(images, labels)`` batches with a ``dataset`` attribute.
        device: device the model lives on.

    Returns:
        Number of correct predictions (after ``reduce_value(..., average=False)``,
        presumably summed across processes) divided by ``len(data_loader.dataset)``;
        ``0.0`` when the dataset is empty.
    """
    model.eval()
    num_samples = len(data_loader.dataset)  # number of validation samples
    sum_num = torch.zeros(1).to(device)     # running count of correct predictions

    for images, labels in data_loader:
        pred = model(images.to(device))
        pred_class = torch.max(pred, dim=1)[1]
        sum_num += torch.eq(pred_class, labels.to(device)).sum()

    # torch.cuda.synchronize requires CUDA (it raises on an MPS-only machine), so only
    # call it for CUDA devices; .item() below synchronises the other backends anyway.
    if torch.device(device).type == "cuda":
        torch.cuda.synchronize(device)

    # Called unconditionally so every process takes part in the reduction.
    sum_num = reduce_value(sum_num, average=False)
    if num_samples == 0:
        return 0.0
    return sum_num.item() / num_samples
lr_methods.py
import torch
def warmup(optimizer, warm_up_iters, warm_up_factor):
    """Build a ``LambdaLR`` that linearly warms up the learning rate.

    For scheduler step ``s < warm_up_iters`` the base learning rate is multiplied by
    ``warm_up_factor * (1 - s / warm_up_iters) + s / warm_up_iters``, which moves from
    ``warm_up_factor`` towards 1. From step ``warm_up_iters`` on, the multiplier is 1.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        warm_up_iters: number of scheduler steps the ramp lasts.
        warm_up_factor: multiplier used at step 0.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` bound to ``optimizer``.
    """
    def lr_multiplier(step):
        if step < warm_up_iters:
            progress = float(step) / warm_up_iters
            return warm_up_factor * (1 - progress) + progress
        return 1

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_multiplier)
__init__.py
from .alexnet import alexnet
from .vggnet import vgg11, vgg13, vgg16, vgg19
from .zfnet import zfnet
from .googlenet_v1 import googlenet
from .xception import xception
from .resnet import resnet34, resnet50, resnet101, resnext50_32x4d, resnext101_32x8d
from .densenet import densenet121, densenet161, densenet169, densenet201
from .dla import dla34
from .mobilenet_v3 import mobilenet_v3_small, mobilenet_v3_large
from .shufflenet_v2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
from .efficientnet_v2 import efficientnetv2_l, efficientnetv2_m, efficientnetv2_s
from .convnext import convnext_tiny, convnext_small, convnext_base, convnext_large, convnext_xlarge
from .vision_transformer import vit_base_patch16_224, vit_base_patch32_224, vit_large_patch16_224
from .swin_transformer import swin_tiny_patch4_window7_224, swin_small_patch4_window7_224, swin_base_patch4_window7_224
# Registry mapping each --model name accepted by train.py to the constructor that builds
# it. find_model_using_name calls the selected constructor as constructor(num_classes).
cfgs = {
    'alexnet': alexnet,
    'zfnet': zfnet,
    'vgg': vgg16,
    'vgg_tiny': vgg11,
    'vgg_small': vgg13,
    'vgg_big': vgg19,
    'googlenet': googlenet,
    'xception': xception,
    'resnet_small': resnet34,
    'resnet': resnet50,
    'resnet_big': resnet101,
    'resnext': resnext50_32x4d,
    'resnext_big': resnext101_32x8d,
    'densenet_tiny': densenet121,
    'densenet_small': densenet161,
    'densenet': densenet169,
    # densenet201 is the deepest imported DenseNet. Mapping this key to densenet121
    # duplicated 'densenet_tiny' and left densenet201 unused.
    'densenet_big': densenet201,
    'dla': dla34,
    'mobilenet_v3': mobilenet_v3_small,
    'mobilenet_v3_large': mobilenet_v3_large,
    'shufflenet_small': shufflenet_v2_x0_5,
    'shufflenet': shufflenet_v2_x1_0,
    'efficient_v2_small': efficientnetv2_s,
    'efficient_v2': efficientnetv2_m,
    'efficient_v2_large': efficientnetv2_l,
    'convnext_tiny': convnext_tiny,
    'convnext_small': convnext_small,
    'convnext': convnext_base,
    'convnext_big': convnext_large,
    'convnext_huge': convnext_xlarge,
    'vision_transformer_small': vit_base_patch32_224,
    'vision_transformer': vit_base_patch16_224,
    'vision_transformer_big': vit_large_patch16_224,
    'swin_transformer_tiny': swin_tiny_patch4_window7_224,
    'swin_transformer_small': swin_small_patch4_window7_224,
    'swin_transformer': swin_base_patch4_window7_224,
}
def find_model_using_name(model_name, num_classes):
    """Build the model registered under ``model_name`` in ``cfgs``.

    Args:
        model_name: a key of ``cfgs`` (the value given to ``--model``).
        num_classes: number of output classes, passed positionally to the constructor.

    Returns:
        Whatever the registered constructor returns for ``num_classes``.

    Raises:
        KeyError: if ``model_name`` is not registered; the message lists the valid names.
    """
    try:
        constructor = cfgs[model_name]
    except KeyError:
        raise KeyError("unknown model '{}'; available models: {}".format(
            model_name, ", ".join(sorted(cfgs)))) from None
    return constructor(num_classes)