Pytorch模型的保存及加载

深度学习模型保存模型参数的方法有两种:

1.保存整个网络(模型结构+模型参数):
# 保存整个模型和参数
torch.save(model_object, 'convit_tiny.pth')  
    
# 对应的加载模型代码为
model = torch.load('convit_tiny.pth')
print(model)

此时print的是整个网络的模型结构;
Pytorch模型的保存及加载_第1张图片
若要加载模型的参数:

model = torch.load('convit_tiny.pth')
args = model.state_dict()
print(args)

此时输出的是模型的训练参数:
Pytorch模型的保存及加载_第2张图片

2.直接保存网络的模型参数:
# 将my_resnet模型储存为my_resnet.pth,此时保存的仅仅是模型的参数
torch.save(model.state_dict(), "convit.pth")
# 直接加载参数
args = torch.load("convit.pth")
# 若要加载模型则先需要初始化之前所定义的网络
new_model = Net()
# 再使用load_state_dict方法将权重加载进网络
# 注意:model.state_dict()其实返回的是一个OrderDict,存储了网络结构的名字和对应的参数;而这里是导入参数因此用的是model.load_state_dict()而不是model.state_dict()
new_model.load_state_dict(torch.load('convit.pth'))

你可能感兴趣的:(Python,Pytorch,pytorch,python,深度学习)