pytorch模型保存与加载

1.模型保存与加载

有两种方式:

# 保存整个网络
torch.save(net, PATH) 
# 加载整个网络
model_dict=torch.load(PATH)
#--------------------------------------------------
# 保存网络中的参数, 速度快,占空间少
torch.save(net.state_dict(),PATH)
# 加载保存的部分参数,前提是model需要先定义
model_dict=model.load_state_dict(torch.load(PATH))

然而,在实验中往往需要保存更多的信息,比如优化器的参数,那么可以采取下面的方法保存:

torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'best_loss': lossMIN,
'optimizer': optimizer.state_dict(),'alpha': loss.alpha, 'gamma': loss.gamma},
checkpoint_path + '/m-' + launchTimestamp + '-' + str("%.4f" % lossMIN) + '.pth.tar')

以上包含的信息有,epochID, state_dict, min loss, optimizer, 自定义损失函数的两个参数;格式以字典的格式存储。

加载的方式:

def load_checkpoint(model, checkpoint_PATH, optimizer):
    if checkpoint != None:
        model_CKPT = torch.load(checkpoint_PATH)
        model.load_state_dict(model_CKPT['state_dict'])
        print('loading checkpoint!')
        optimizer.load_state_dict(model_CKPT['optimizer'])
    return model, optimizer

其他的参数可以通过以字典的方式获得

但是,但是,我们可能修改了一部分网络,比如加了一些,删除一些,等等,那么需要过滤这些参数,加载方式:

def load_checkpoint(model, checkpoint, optimizer, loadOptimizer):
    if checkpoint != 'No':
        print("loading checkpoint...")
        model_dict = model.state_dict()
        modelCheckpoint = torch.load(checkpoint)
        pretrained_dict = modelCheckpoint['state_dict']
        # 过滤操作
        new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
        model_dict.update(new_dict)
        # 打印出来,更新了多少的参数
        print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))
        model.load_state_dict(model_dict)
        print("loaded finished!")
        # 如果不需要更新优化器那么设置为false
        if loadOptimizer == True:
            optimizer.load_state_dict(modelCheckpoint['optimizer'])
            print('loaded! optimizer')
        else:
            print('not loaded optimizer')
    else:
        print('No checkpoint is included')
    return model, optimizer

转自:https://zhuanlan.zhihu.com/p/38056115

你可能感兴趣的:(PyTorch)