PyTorch Study Notes: Saving and Loading Models, Resuming Training from Checkpoints

1. Saving and Loading Models

In PyTorch, saving a model is serialization (from memory to disk) and loading is deserialization (from disk to memory).

  1. torch.save — main parameters: obj: the object to save; f: the output path
  2. torch.load — main parameters: f: the file path; map_location: the device to map the storages to, cpu or gpu
There are two ways to save:

Method 1: save the entire Module

torch.save(net, path)

Method 2: save only the model parameters

state_dict = net.state_dict()
torch.save(state_dict, path)
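
Loading mirrors the two methods. A minimal sketch, where the .pkl paths are placeholders and net is assumed to be a freshly constructed instance of the same architecture; map_location matters when a file saved on a GPU must be loaded on a CPU-only machine:

# Method 1: load the entire Module (the model's class definition must still be importable)
net = torch.load("./model.pkl", map_location="cpu")

# Method 2: construct the model first, then load the parameters into it
state_dict = torch.load("./state_dict.pkl", map_location="cpu")
net.load_state_dict(state_dict)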

2. Resuming Training from a Checkpoint

A checkpoint gathers everything needed to resume: the model weights, the optimizer state, and the current epoch:

checkpoint = {"model_state_dict": net.state_dict(),
              "optimizer_state_dict": optimizer.state_dict(),
              "epoch": epoch}

Note that the checkpoint is saved inside the epoch loop:

for epoch in range(start_epoch+1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # track classification accuracy
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().item()

        # print training info
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # update the learning rate

    if (epoch+1) % checkpoint_interval == 0:

        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "epoch": epoch}
        path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)
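
The loop above relies on objects created in the earlier steps of the script. A minimal sketch of that setup, using a stand-in linear model and random data purely as placeholders (the full notes build a real CNN and DataLoader):

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

net = nn.Linear(28 * 28, 10)                      # placeholder for the real network
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# placeholder data: 64 random "images", 10 classes
train_loader = DataLoader(
    TensorDataset(torch.randn(64, 28 * 28), torch.randint(0, 10, (64,))),
    batch_size=16)

MAX_EPOCH = 10
start_epoch = -1          # -1 so range(start_epoch+1, MAX_EPOCH) starts at epoch 0
log_interval = 10         # print every 10 iterations
checkpoint_interval = 5   # write a checkpoint every 5 epochs
train_curve = []          # per-iteration loss values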

If training is interrupted unexpectedly, load the saved checkpoint before the training loop when restarting:

# ============================ step 5+/5 resume from checkpoint ============================

path_checkpoint = "./checkpoint_4_epoch.pkl"  # checkpoint path
checkpoint = torch.load(path_checkpoint)  # load the checkpoint

net.load_state_dict(checkpoint['model_state_dict'])  # restore the model's learnable parameters

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # restore the optimizer state

start_epoch = checkpoint['epoch']  # epoch to resume from

scheduler.last_epoch = start_epoch  # sync the learning-rate scheduler's last_epoch

# ============================ step 5/5 training ============================
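
A variant worth noting: instead of setting scheduler.last_epoch by hand, the scheduler's own state can be checkpointed and restored with state_dict()/load_state_dict(), the same way the model and optimizer are handled. A sketch of that alternative (not what the code above does):

# when saving
checkpoint = {"model_state_dict": net.state_dict(),
              "optimizer_state_dict": optimizer.state_dict(),
              "scheduler_state_dict": scheduler.state_dict(),  # persist the scheduler too
              "epoch": epoch}

# when resuming
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])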
