这里介绍只保存加载模型的参数的方法,因为速度快,占内存小。
保存语法:
torch.save(model.state_dict(),'model.pth')
加载模型的语法:
因为只保存了模型的参数,所以需要先定义一个网络对象,然后再加载模型的参数。
例如:
model = ClassNet()
将模型参数加载到新模型中,torch.load返回的时一个OrderedDict,model.state_dict()把模型的所有参数都以OrderdeDict的形式保存了下来。
{'epoch': 1, 'model_name': 'resnet', 'state_dict': OrderedDict([('backbone.conv1.weight', tensor([[[[-9.6635e-03, -5.8054e-03, -1.7499e-03, ..., 5.6849e-02,
1.7084e-02, -1.2774e-02],
[ 1.1954e-02, 1.0023e-02, -1.0967e-01, ..., -2.7083e-01,
-1.2892e-01, 3.8908e-03],
[-6.0705e-03, 5.9568e-02, 2.9577e-01, ..., 5.2005e-01,
2.5649e-01, 6.3826e-02],
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict)
这里附上一份自己的torch分类代码中的模型的保存和加载
#自定义模型保存
state = {
"epoch": epoch + 1,
"model_name": config.model_name,
"state_dict": model.state_dict(),
"F_Score": F_Score,
"optimizer": optimizer.state_dict(),
}
filename = config.weights + config.model_name + ".pth" # resnet.pth
torch.save(state, filename)
#test中模型的加载
model = Model.get_net()
checkpoint = torch.load(config.weights + config.model_name + '.pth')
model.load_state_dict(checkpoint["state_dict"])
Make_Confusion_Matrix(test_loader2, model) #调用模型