torch模型的保存和加载

这里介绍只保存加载模型的参数的方法,因为速度快,占内存小。

保存语法:

torch.save(model.state_dict(),'model.pth')

加载模型的语法:

因为只保存了模型的参数,所以需要先定义一个网络对象,然后再加载模型的参数。

例如:

model = ClassNet()

将模型参数加载到新模型中,torch.load返回的时一个OrderedDict,model.state_dict()把模型的所有参数都以OrderdeDict的形式保存了下来。

{'epoch': 1, 'model_name': 'resnet', 'state_dict': OrderedDict([('backbone.conv1.weight', tensor([[[[-9.6635e-03, -5.8054e-03, -1.7499e-03,  ...,  5.6849e-02,
            1.7084e-02, -1.2774e-02],
          [ 1.1954e-02,  1.0023e-02, -1.0967e-01,  ..., -2.7083e-01,
           -1.2892e-01,  3.8908e-03],
          [-6.0705e-03,  5.9568e-02,  2.9577e-01,  ...,  5.2005e-01,
            2.5649e-01,  6.3826e-02],

state_dict = torch.load('model.pth')

model.load_state_dict(state_dict)

这里附上一份自己的torch分类代码中的模型的保存和加载

#自定义模型保存
state = {
           "epoch": epoch + 1,
           "model_name": config.model_name,
           "state_dict": model.state_dict(),
           "F_Score": F_Score,
           "optimizer": optimizer.state_dict(),
            }

filename = config.weights + config.model_name + ".pth" # resnet.pth

torch.save(state, filename)

#test中模型的加载

 model = Model.get_net()
 checkpoint = torch.load(config.weights + config.model_name + '.pth')
 model.load_state_dict(checkpoint["state_dict"])
 Make_Confusion_Matrix(test_loader2, model)  #调用模型

你可能感兴趣的:(深度学习,pytorch,人工智能)