Resuming PyTorch Model Training from a Checkpoint

Deep-learning training is very time-consuming, and having a run interrupted midway is a real headache. To avoid retraining from scratch, it is common practice to resume the model from a checkpoint. This requires saving some key state during training, such as the model parameters, the optimizer state, and the current epoch.

Saving a checkpoint

import os

import torch


def save_checkpoint(model, epoch, loss, optimizer):
    model_out_path = "model/" + "model_epoch_{}.pth".format(epoch)
    # Saves the full model object; storing model.state_dict() instead
    # is the lighter, more portable alternative.
    state = {"epoch": epoch,
             "model": model,
             "loss": loss,
             "optimizer": optimizer.state_dict()}
    # create the output directory if it does not exist
    if not os.path.exists("model/"):
        os.makedirs("model/")
    # save the checkpoint
    torch.save(state, model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))
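As a minimal sketch of how this function would be called, the loop below trains a tiny linear model (the model, data, and hyperparameters are made up for illustration) and writes a checkpoint at the end of every epoch:

```python
import os

import torch
import torch.nn as nn


def save_checkpoint(model, epoch, loss, optimizer):
    # same layout as the function above
    model_out_path = "model/model_epoch_{}.pth".format(epoch)
    state = {"epoch": epoch,
             "model": model,
             "loss": loss,
             "optimizer": optimizer.state_dict()}
    os.makedirs("model/", exist_ok=True)
    torch.save(state, model_out_path)


# toy model, optimizer, and data purely for demonstration
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()
x = torch.randn(8, 4)
y = torch.randn(8, 1)

for epoch in range(1, 3):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    # one checkpoint per epoch: model/model_epoch_1.pth, model/model_epoch_2.pth, ...
    save_checkpoint(model, epoch, loss.item(), optimizer)
```

Saving once per epoch keeps things simple; for long epochs you might checkpoint every N steps instead.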

Resuming from a checkpoint

if opt.resume:
    if os.path.isfile(opt.resume):
        print("===> loading checkpoint: {}".format(opt.resume))
        checkpoint = torch.load(opt.resume)
        # continue from the epoch after the one that was saved
        opt.start_epoch = checkpoint["epoch"] + 1
        model.load_state_dict(checkpoint["model"].state_dict())
        optimizer.load_state_dict(checkpoint["optimizer"])
    else:
        print("===> no checkpoint found at {}".format(opt.resume))

# optionally copy weights from a pretrained checkpoint
if opt.pretrained:
    if os.path.isfile(opt.pretrained):
        print("===> loading model {}".format(opt.pretrained))
        weights = torch.load(opt.pretrained)
        model.load_state_dict(weights["model"].state_dict())
    else:
        print("===> no model found at {}".format(opt.pretrained))
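A self-contained round trip makes the resume logic concrete. The sketch below first writes a checkpoint (standing in for a previous run), then restores a fresh model and optimizer from it, mirroring the `opt.resume` branch above; the file path, model, and epoch number are all made up for the example:

```python
import os

import torch
import torch.nn as nn

os.makedirs("model/", exist_ok=True)

# --- stand-in for a previous training run that saved epoch 5 ---
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
torch.save({"epoch": 5,
            "model": model,
            "loss": 0.1,
            "optimizer": optimizer.state_dict()},
           "model/model_epoch_5.pth")

# --- resume into freshly constructed objects ---
resume_path = "model/model_epoch_5.pth"
new_model = nn.Linear(4, 1)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.01)
start_epoch = 1

if os.path.isfile(resume_path):
    # weights_only=False is needed on recent PyTorch versions because
    # the checkpoint pickles a full nn.Module, not just tensors
    checkpoint = torch.load(resume_path, weights_only=False)
    start_epoch = checkpoint["epoch"] + 1
    new_model.load_state_dict(checkpoint["model"].state_dict())
    new_optimizer.load_state_dict(checkpoint["optimizer"])
```

After this, `start_epoch` is 6 and `new_model` carries the saved weights, so the training loop can simply begin at `start_epoch` and continue where it left off.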

