PyTorch: Saving Models and Resuming Training from a Checkpoint

Insert the following before the epoch loop:

    # these imports belong at the top of the script
    import os
    import torch

    initepoch = 0
    resume = True  # whether to continue training from the last saved state
    if resume:
        if os.path.isfile("./testweights/last_model.pth"):
            print("Resume from checkpoint...")
            checkpoint = torch.load("./testweights/last_model.pth")
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # the checkpoint stores the last finished epoch, so continue with the next one
            initepoch = checkpoint['epoch'] + 1
            print("====>loaded checkpoint (epoch{})".format(checkpoint['epoch']))
        else:
            print("====>no checkpoint found.")
            initepoch = 0  # no previous training, so start from epoch 0
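
The resume step above can also be wrapped in a function and tried out on a throwaway model. The following is a minimal sketch, not the article's code: the `resume_from` and `build` helpers, the `nn.Linear` stand-in model, and the temporary directory are all assumptions made for illustration. It also passes `map_location` to `torch.load`, which the original snippet does not do, so that a checkpoint written on a GPU can still be loaded on a CPU-only machine.

```python
import os
import tempfile

import torch
import torch.nn as nn


def resume_from(path, model, optimizer, device="cpu"):
    """Restore model/optimizer from `path` if it exists; return the epoch to start at."""
    if not os.path.isfile(path):
        return 0  # no checkpoint found: start from epoch 0
    # map_location is an addition here; it remaps saved tensors onto `device`
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"] + 1  # continue with the epoch after the saved one


def build():
    """Hypothetical stand-ins for the real model and optimizer."""
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    return model, optimizer


path = os.path.join(tempfile.mkdtemp(), "last_model.pth")

# With no file on disk, training starts from scratch.
model, optimizer = build()
start_fresh = resume_from(path, model, optimizer)

# Take one optimizer step, then save using the same dictionary keys as the article.
model(torch.randn(8, 4)).sum().backward()
optimizer.step()
torch.save({"model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": 2}, path)

# A freshly built model picks up the saved weights and the next epoch index.
new_model, new_optimizer = build()
start_resumed = resume_from(path, new_model, new_optimizer)
```

Returning the start epoch from a function keeps the training script free of a module-level `checkpoint` variable and makes the "no checkpoint" case explicit.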

Change the epoch loop to:

    for epoch in range(initepoch, args.epochs):
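
One consequence of this loop header is that `args.epochs` is the total number of epochs, not the number of additional epochs to run after resuming. The small sketch below illustrates that with plain Python; the `epochs_to_run` helper is hypothetical and exists only for this illustration.

```python
def epochs_to_run(initepoch, total_epochs):
    """Return the epoch indices that the resumed loop would visit."""
    return list(range(initepoch, total_epochs))


fresh = epochs_to_run(0, 5)     # no checkpoint: epochs 0..4
resumed = epochs_to_run(3, 5)   # checkpoint stored epoch 2, so initepoch = 3
finished = epochs_to_run(5, 5)  # checkpoint stored the final epoch: nothing left
```

If a run has already reached `args.epochs`, resuming it does nothing until the total is raised.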

Insert the following inside the epoch loop:

        # save the best model so far
        # (best_acc must be initialized before the loop, e.g. best_acc = 0.0,
        #  and the ./testweights directory must already exist)
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "./testweights/best_model.pth")
            print("!!--Best Model has been updated--!!")

        # save the model for this epoch
        torch.save(model.state_dict(), "./testweights/model-{}.pth".format(epoch))

        # save the last model together with the optimizer state and epoch index
        checkpoint = {"model_state_dict": model.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "epoch": epoch}
        path_checkpoint = "./testweights/last_model.pth"
        torch.save(checkpoint, path_checkpoint)
        print("!!--Last Model has been updated(-{})--!!".format(epoch))
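
Since the point of `last_model.pth` is to survive an interrupted run, it may be worth guarding the save itself against interruption. The sketch below is one possible approach, not part of the original snippet: it writes to a temporary file and then renames it over the old checkpoint with `os.replace`, which is atomic on most platforms when both paths are on the same filesystem. The `save_checkpoint_atomic` helper, the `nn.Linear` stand-in model, the placeholder values, and the extra `best_acc` key are all assumptions. Storing `best_acc` lets a resumed run keep comparing against the previous best instead of starting over.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint_atomic(state, path):
    """Save `state` to `path` without leaving a half-written file at `path`."""
    directory = os.path.dirname(path) or "."
    os.makedirs(directory, exist_ok=True)  # create the weights folder if missing
    tmp_path = path + ".tmp"
    torch.save(state, tmp_path)  # write to a temporary file first
    os.replace(tmp_path, path)   # then rename it over the previous checkpoint


# Hypothetical stand-ins for the real training objects and values.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
epoch = 2
best_acc = 0.75

weights_dir = os.path.join(tempfile.mkdtemp(), "testweights")
path_checkpoint = os.path.join(weights_dir, "last_model.pth")

save_checkpoint_atomic({"model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "epoch": epoch,
                        "best_acc": best_acc},  # extra key, not in the original
                       path_checkpoint)

saved = torch.load(path_checkpoint, map_location="cpu")
```

On resume, the extra key can be read with `checkpoint.get("best_acc", 0.0)` so that older checkpoints without it still load.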
