脚本 train.py 是用来训练模型的脚本,训练模型首先需要载入数据集,然后开始训练过程,训练完成后可以根据训练结果绘制 loss 曲线图,并保存训练好的模型参数。本文将按照训练模型的流程,分别解析对应步骤的代码。
通过执行数据处理脚本 prepare.py ,我们已经将数据集组织成了 datasets.ImageFolder 可以直接使用的数据集结构。要想将数据集载入模型还需要将数据集张量化并生成数据集迭代器。
使用 datasets.ImageFolder 可以将图片格式的数据集变为 pytorch 支持的张量 tensor ,如果对 transform 参数进行设置,则会对数据集的图片进行数据增强等变换。
调用 datasets.ImageFolder 后生成了 pytorch 支持的数据集 image_datasets[‘train’] 和 image_datasets[‘val’] 。
image_datasets = {}
image_datasets['train'] = datasets.ImageFolder(os.path.join(data_dir, 'train'),
data_transforms['train'])
image_datasets['val'] = datasets.ImageFolder(os.path.join(data_dir, 'val'),
data_transforms['val'])
可以通过 pytorch 的 transforms 库引入 transform,针对训练集和测试集进行不同的 transform 变化
from torchvision import datasets, transforms
transform_train_list = [
#transforms.RandomResizedCrop(size=128, scale=(0.75,1.0), ratio=(0.75,1.3333), interpolation=3), #Image.BICUBIC)
transforms.Resize((h, w), interpolation=3),
transforms.Pad(10),
transforms.RandomCrop((h, w)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]
transform_val_list = [
transforms.Resize(size=(h, w),interpolation=3), #Image.BICUBIC
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]
data_transforms = {
'train': transforms.Compose(transform_train_list),
'val': transforms.Compose(transform_val_list),
}
训练模型时,一般不会一次性把所有数据都加载到模型中。通常采用 mini_batch 的方法,按照 batchsize 的大小将一个 batch 的数据载入到模型中。pytorch 框架支持用 torch.utils.data.DataLoader 作为 dataloader 载入数据。
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
shuffle=True, num_workers=0, pin_memory=True) # 8 workers may work faster
for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
将 image_datasets[‘train’] 和 image_datasets[‘val’] 输入 torch.utils.data.DataLoader 后,获得了两个迭代器 dataloaders[‘train’] and dataloaders[‘val’] 。
下面来介绍一下 torch.utils.data.DataLoader 的主要参数
class torch.utils.data.DataLoader(dataset,
batch_size=1,
shuffle=False,
sampler=None,
num_workers=0,
collate_fn=<function default_collate>,
pin_memory=False,
drop_last=False)
torch.utils.data.DataLoader 将返回一个数据迭代器。
参数说明:
在函数 train_model 中,实现了模型训练过程。网络模型一般会迭代多轮以达到一个很好的训练效果,通常通过循环执行一段训练代码来实现迭代训练。
下面对主要的训练代码进行解析:
# Iterate over data.
for data in dataloaders[phase]:
# 载入一个 batch 的输入
# 数据迭代器返回一个 batch 的图像及其标签
inputs, labels = data
now_batch_size,c,h,w = inputs.shape
if now_batch_size<opt.batchsize: # skip the last batch
continue
# print(inputs.shape)
# 变量化输入
if use_gpu:
inputs = Variable(inputs.cuda())
labels = Variable(labels.cuda())
else:
inputs, labels = Variable(inputs), Variable(labels)
# 开始训练
# 将梯度参数置零
optimizer.zero_grad()
# 前向传播,计算损失
#-------- forward --------
outputs = model(inputs)
# preds 是 softmax 概率最大的类别的索引, 即模型预测的类别
_, preds = torch.max(outputs.data, 1)
loss = criterion(outputs, labels)
# 只在 train 模式下执行,反向传播,梯度下降优化,
#-------- backward + optimize --------
# only if in training phase
if phase == 'train':
loss.backward()
optimizer.step()
训练过程中,还可以使用 warm_up 等学习率策略。
模型训练过程中,还会涉及到模型加载。在训练模式下,模型的网络参数会发生改变;而在验证模式下,一般不进行梯度下降反向传播等操作,我们希望网络参数保持不变。此时会考虑使用 model.load_state_dict 加载最佳模型参数进行验证。
注意
model.load_state_dict 是深拷贝,可以保证加载的是最佳模型参数
model.state_dict 是浅拷贝,保存的是最后一轮训练的模型参数
另外使用预训练迁移模型的部分层参数时,记得令 strict=False,即
model.load_state_dict(state_dict, strict=False)。strict 默认为 True,表示严格按照名称加载参数,如果出现未定义的名称,就会报错。如果将 strict=False,则会忽略未定义的名称,不会报错。
# deep copy the model
if phase == 'val':
last_model_wts = model.state_dict()
if epoch%10 == 9:
save_network(model, epoch)
draw_curve(epoch)
if phase == 'train':
scheduler.step()
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(
time_elapsed // 60, time_elapsed % 60))
print()
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(
time_elapsed // 60, time_elapsed % 60))
#print('Best val Acc: {:4f}'.format(best_acc))
# load best model weights
model.load_state_dict(last_model_wts)
save_network(model, 'last')
训练过程中,一般会保存训练好的模型参数,方便下次训练时加载模型。为了监控训练过程,一般还会绘制 loss 曲线。
baseline 通过 torch.save 实现模型参数的保存,具体代码如下:
# Save model
#---------------------------
def save_network(network, epoch_label):
save_filename = 'net_%s.pth'% epoch_label
# save_path = os.path.join('./model',name,save_filename)
save_path = os.path.join('model', name, save_filename)
torch.save(network.cpu().state_dict(), save_path)
if torch.cuda.is_available():
network.cuda(gpu_ids[0])
pytorch 一般使用如下代码实现模型的保存和加载
# save
torch.save(model.state_dict(), PATH)
# load
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
使用 pyplot 库可以实现绘图,loss 曲线绘制代码如下:
# Draw Curve
#---------------------------
import matplotlib.pyplot as plt
x_epoch = []
fig = plt.figure()
ax0 = fig.add_subplot(121, title="loss")
ax1 = fig.add_subplot(122, title="top1err")
def draw_curve(current_epoch):
x_epoch.append(current_epoch)
ax0.plot(x_epoch, y_loss['train'], 'bo-', label='train')
ax0.plot(x_epoch, y_loss['val'], 'ro-', label='val')
ax1.plot(x_epoch, y_err['train'], 'bo-', label='train')
ax1.plot(x_epoch, y_err['val'], 'ro-', label='val')
if current_epoch == 0:
ax0.legend()
ax1.legend()
# fig.savefig( os.path.join('./model',name,'train.jpg'))
fig.savefig(os.path.join('model', name, 'train.jpg'))