Person_reID_baseline_pytorch 源码解析之 train.py

源码解析之模型训练

  • 1. 载入数据集
    • 1.1 数据集张量化
    • 1.2 数据集迭代器
  • 2. 开始训练
    • 2.1 训练代码
    • 2.2 模型加载
  • 3. 结果保存
    • 3.1 模型保存
    • 3.2 loss 曲线绘制
  • 参考文献

脚本 train.py 是用来训练模型的脚本,训练模型首先需要载入数据集,然后开始训练过程,训练完成后可以根据训练结果绘制 loss 曲线图,并保存训练好的模型参数。本文将按照训练模型的流程,分别解析对应步骤的代码。

1. 载入数据集

通过执行数据处理脚本 prepare.py ,我们已经将数据集组织成了 datasets.ImageFolder 可以直接使用的数据集结构。要想将数据集载入模型还需要将数据集张量化并生成数据集迭代器。

1.1 数据集张量化

使用 datasets.ImageFolder 可以将图片格式的数据集变为 pytorch 支持的张量 tensor ,如果对 transform 参数进行设置,则会对数据集的图片进行数据增强等变换。

调用 datasets.ImageFolder 后生成了 pytorch 支持的数据集 image_datasets[‘train’] 和 image_datasets[‘val’] 。

image_datasets = {}
image_datasets['train'] = datasets.ImageFolder(os.path.join(data_dir, 'train'),
                                          data_transforms['train'])
image_datasets['val'] = datasets.ImageFolder(os.path.join(data_dir, 'val'),
                                          data_transforms['val'])

可以通过 pytorch 的 transforms 库引入 transform,针对训练集和测试集进行不同的 transform 变化

from torchvision import datasets, transforms
transform_train_list = [
        #transforms.RandomResizedCrop(size=128, scale=(0.75,1.0), ratio=(0.75,1.3333), interpolation=3), #Image.BICUBIC)
        transforms.Resize((h, w), interpolation=3),
        transforms.Pad(10),
        transforms.RandomCrop((h, w)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]

transform_val_list = [
        transforms.Resize(size=(h, w),interpolation=3), #Image.BICUBIC
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]
        
data_transforms = {
    'train': transforms.Compose(transform_train_list),
    'val': transforms.Compose(transform_val_list),
}

1.2 数据集迭代器

训练模型时,一般不会一次性把所有数据都加载到模型中。通常采用 mini_batch 的方法,按照 batchsize 的大小将一个 batch 的数据载入到模型中。pytorch 框架支持用 torch.utils.data.DataLoader 作为 dataloader 载入数据。

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                             shuffle=True, num_workers=0, pin_memory=True) # 8 workers may work faster
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

将 image_datasets[‘train’] 和 image_datasets[‘val’] 输入 torch.utils.data.DataLoader 后,获得了两个迭代器 dataloaders[‘train’] and dataloaders[‘val’] 。

下面来介绍一下 torch.utils.data.DataLoader 的主要参数

class torch.utils.data.DataLoader(dataset, 
								batch_size=1, 
								shuffle=False, 
								sampler=None, 
								num_workers=0, 
								collate_fn=<function default_collate>, 
								pin_memory=False, 
								drop_last=False)

torch.utils.data.DataLoader 将返回一个数据迭代器。

参数说明:

  • dataset (Dataset) – 加载数据的数据集
  • batch_size (int) – 每个batch加载多少个样本(默认: 1)
  • shuffle (bool) – 设置为True时会在每个epoch重新打乱数据(默认: False)
  • sampler (Sampler) – 定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数
  • num_workers (int) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
  • drop_last (bool, optional) – 如果数据集大小不能被 batch size 整除,则设置为 True 后可删除最后一个不完整的batch。如果设为 False 并且数据集的大小不能被 batch size 整除,则最后一个batch将更小。(默认: False)

2. 开始训练

在函数 train_model 中,实现了模型训练过程。网络模型一般会迭代多轮以达到一个很好的训练效果,通常通过循环执行一段训练代码来实现迭代训练。

2.1 训练代码

下面对主要的训练代码进行解析:

			# Iterate over data.
            for data in dataloaders[phase]:
                # 载入一个 batch 的输入
                # 数据迭代器返回一个 batch 的图像及其标签
                inputs, labels = data
                now_batch_size,c,h,w = inputs.shape
                if now_batch_size<opt.batchsize: # skip the last batch
                    continue
                # print(inputs.shape)
                # 变量化输入
                if use_gpu:
                    inputs = Variable(inputs.cuda())
                    labels = Variable(labels.cuda())
                else:
                    inputs, labels = Variable(inputs), Variable(labels)
				# 开始训练
                # 将梯度参数置零
                optimizer.zero_grad()
				
				# 前向传播,计算损失
                #-------- forward --------
                outputs = model(inputs)
                # preds 是 softmax 概率最大的类别的索引, 即模型预测的类别
                _, preds = torch.max(outputs.data, 1)
                loss = criterion(outputs, labels)
				
				# 只在 train 模式下执行,反向传播,梯度下降优化, 
                #-------- backward + optimize -------- 
                # only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

训练过程中,还可以使用 warm_up 等学习率策略。

2.2 模型加载

模型训练过程中,还会涉及到模型加载。在训练模式下,模型的网络参数会发生改变;而在验证模式下,一般不进行梯度下降反向传播等操作,我们希望网络参数保持不变。此时会考虑使用 model.load_state_dict 加载最佳模型参数进行验证。

注意
model.load_state_dict 是深拷贝,可以保证加载的是最佳模型参数
model.state_dict 是浅拷贝,保存的是最后一轮训练的模型参数

另外使用预训练迁移模型的部分层参数时,记得令 strict=False,即
model.load_state_dict(state_dict, strict=False)。strict 默认为 True,表示严格按照名称加载参数,如果出现未定义的名称,就会报错。如果将 strict=False,则会忽略未定义的名称,不会报错。

            # deep copy the model
            if phase == 'val':
                last_model_wts = model.state_dict()
                if epoch%10 == 9:
                    save_network(model, epoch)
                draw_curve(epoch)
            if phase == 'train':
               scheduler.step()
        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    #print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(last_model_wts)
    save_network(model, 'last')

3. 结果保存

训练过程中,一般会保存训练好的模型参数,方便下次训练时加载模型。为了监控训练过程,一般还会绘制 loss 曲线。

3.1 模型保存

baseline 通过 torch.save 实现模型参数的保存,具体代码如下:

# Save model
#---------------------------
def save_network(network, epoch_label):
    save_filename = 'net_%s.pth'% epoch_label
    # save_path = os.path.join('./model',name,save_filename)
    save_path = os.path.join('model', name, save_filename)
    torch.save(network.cpu().state_dict(), save_path)
    if torch.cuda.is_available():
        network.cuda(gpu_ids[0])

pytorch 一般使用如下代码实现模型的保存和加载

# save
torch.save(model.state_dict(), PATH)

# load
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

3.2 loss 曲线绘制

使用 pyplot 库可以实现绘图,loss 曲线绘制代码如下:

# Draw Curve
#---------------------------
import matplotlib.pyplot as plt
x_epoch = []
fig = plt.figure()
ax0 = fig.add_subplot(121, title="loss")
ax1 = fig.add_subplot(122, title="top1err")
def draw_curve(current_epoch):
    x_epoch.append(current_epoch)
    ax0.plot(x_epoch, y_loss['train'], 'bo-', label='train')
    ax0.plot(x_epoch, y_loss['val'], 'ro-', label='val')
    ax1.plot(x_epoch, y_err['train'], 'bo-', label='train')
    ax1.plot(x_epoch, y_err['val'], 'ro-', label='val')
    if current_epoch == 0:
        ax0.legend()
        ax1.legend()
    # fig.savefig( os.path.join('./model',name,'train.jpg'))
    fig.savefig(os.path.join('model', name, 'train.jpg'))

参考文献

  1. 从零开始行人重识别
  2. Person_reID_baseline_pytorch
  3. torch.max()使用讲解
  4. 源码详解Pytorch的state_dict和load_state_dict
  5. Pytorch踩坑记:赋值、浅拷贝、深拷贝三者的区别以及model.state_dict()和model.load_state_dict()的坑点
  6. torch.load_state_dict()函数的用法总结

你可能感兴趣的:(行人重识别,pytorch,深度学习,计算机视觉)