Saving Models in PyTorch

1. Suppose that at some epoch we want to save the model parameters, the optimizer state, and the epoch number.

① First, build a dictionary that holds the three items:

state = {'net': model.state_dict(),
         'optimizer': optimizer.state_dict(),
         'epoch': epoch}

② Then call torch.save():

torch.save(state, path)

where path is the absolute path of the file to save, including the filename.
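For concreteness, here is a minimal sketch of saving such a checkpoint periodically during training; the save directory, the 10-epoch interval, and the filename pattern are illustrative choices, not from the original post:

import os
import torch

save_dir = './checkpoints'                     # illustrative directory
os.makedirs(save_dir, exist_ok=True)

if epoch % 10 == 0:                            # illustrative interval
    state = {'net': model.state_dict(),
             'optimizer': optimizer.state_dict(),
             'epoch': epoch}
    path = os.path.join(save_dir, 'ckpt_epoch_{}.pth'.format(epoch))
    torch.save(state, path)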

2. When you want to resume training from a certain stage (or run a test), you can load the previously saved model parameters and related state:

checkpoint = torch.load(path)
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1
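If the checkpoint was saved on a GPU and is being restored on a different device, torch.load accepts a map_location argument. A resumption sketch under that assumption follows; num_epochs and train_one_epoch are hypothetical names standing in for your own training loop:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load(path, map_location=device)    # remap tensors to the current device

model.load_state_dict(checkpoint['net'])
model.to(device)
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1

for epoch in range(start_epoch, num_epochs):           # num_epochs: placeholder
    train_one_epoch(model, optimizer, epoch)           # hypothetical training step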

3. When part of the network has been modified (for example, some layers were added or removed), some parameters in the checkpoint need to be filtered out before loading. One way to load in that case:

def load_checkpoint(model, checkpoint, optimizer, loadOptimizer):
    if checkpoint != 'No':
        print("loading checkpoint...")
        model_dict = model.state_dict()
        modelCheckpoint = torch.load(checkpoint)
        # The key here must match the one used when saving
        # (e.g. 'net' in the example above, or 'state_dict' in other checkpoints).
        pretrained_dict = modelCheckpoint['state_dict']

        # Keep only the parameters whose names also exist in the current model
        new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
        model_dict.update(new_dict)
        # Report how many parameters were actually updated
        print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))

        model.load_state_dict(model_dict)

        print("loading finished!")

        # Set loadOptimizer to False if the optimizer state should not be restored
        if loadOptimizer:
            optimizer.load_state_dict(modelCheckpoint['optimizer'])
            print('loaded optimizer')
        else:
            print('optimizer not loaded')
    else:
        print('No checkpoint is included')
    return model, optimizer
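A minimal usage sketch follows; MyNet, the learning rate, and the checkpoint filename are placeholders, and the checkpoint file is assumed to contain the 'state_dict' and 'optimizer' keys that load_checkpoint expects:

import torch

model = MyNet()                                               # MyNet: hypothetical model class
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model, optimizer = load_checkpoint(model, 'old_checkpoint.pth', optimizer, loadOptimizer=False)

Note that the filter matches parameters by name only; a layer that keeps its name but changes shape will still cause load_state_dict to raise a size-mismatch error.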
