定义2个测试脚本test.py和test2.py,用于测试保存和加载,models文件夹保存模型,整个测试的项目文件结构如下:
E:.
│ test.py
│ test2.py
└─ models
dongtai.pt
dongtai_state_dict.pt
jingtai.pth
test.py中定义了TheModelClass这个网络结构类,此外写了模型保存和加载的代码,test2.py是想测试在没有定义模型结构的脚本中,是否可以成功加载模型。
test.py
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
# 定义模型
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 4 * 4, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 4 * 4)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
if __name__ == "__main__":
# 初始化模型
model = TheModelClass()
# 模型保存,方法一动态图
torch.save(model,"models/dongtai.pt")
# 模型保存,方法二动态图
torch.save(model.state_dict(),"models/dongtai_state_dict.pt")
# 模型保存,方法三静态图
x = torch.rand(1,3,30,30) #占位符
trace_model = torch.jit.trace(model,x)
torch.jit.save(trace_model,"models/jingtai.pth")
# 模型加载,带模型结构
model_resume = torch.load("models/dongtai.pt")
# 模型加载,只有权重
weights = torch.load("models/dongtai_state_dict.pt")
model.load_state_dict(weights)
# 直接从静态图中恢复,无需模型结构
model = torch.jit.load("models/jingtai.pth")
x = torch.rand(1,3,30,30)
pred = model(x)
print(pred)
经过测试,pytorch可以通过三种方法实现模型的保存和加载:
接下来一个个说明这三种方法需要注意的地方。
# 保存
model = TheModelClass()
# 模型保存
torch.save(model,"models/dongtai.pt")
# 加载,带模型结构
model_resume = torch.load("models/dongtai.pt")
保存:首先实例化网络对象,然后通过torch.save的方式,将模型结构和权重都序列化保存下来,后缀为pt或者pth都可以,不管保存成哪种后缀,都可以解析。
加载:首先必须能访问到网络结构的类TheModelClass,然后通过torch.load的方式就可以完整的将模型结构恢复,同时加载好权重。
这里需要特别注意的点,加载模型的这个文件必须要能找到网络结构的类,不管是在哪里定义网络,都要能导入到当前读取模型的这个文件中做实例化,比如我在test2.py里面导入test.py中的网络类,就可以成功加载,否则会报找不到类的错误。
test2.py
import torch
from test import TheModelClass # 不导入或同级下找不到会有问题
model = TheModelClass()
model_resume = torch.load("models/dongtai.pt")
model.load_state_dict(model_resume)
model.eval()
print()
# 初始化模型
model = TheModelClass()
# 模型保存
torch.save(model.state_dict(),"models/dongtai_state_dict.pth")
# 模型加载,只有权重
weights = torch.load("models/dongtai_state_dict.pth")
model.load_state_dict(weights)
保存:首先实例化网络对象,然后通过torch.save的方式,只将模型权重序列化保存下来,这种方法不用保存模型结构。
加载:首先必须能访问到网络结构的类TheModelClass,并实例化,然后通过torch.load的方式就可以将模型权重反序列化取出,然后将其加载进模型对象中。
注意:必须实例化网络对象,才能加载对应的权重。
model = TheModelClass()
# 模型保存,方法三静态图
x = torch.rand(1,3,30,30) #占位符
trace_model = torch.jit.trace(model,x)
torch.jit.save(trace_model,"models/jingtai.pt")
# 直接从静态图中恢复,无需模型结构
model_ji = torch.jit.load("models/jingtai.pt")
保存:首先实例化网络对象,然后用一个随机的固定尺寸的输入,通过torch.jit.trace,将网络结构前向跑一遍,记录下网络中的节点运行路径,然后通过torch.jit.save将这个运行路径存下来,这种方法会自动记录模型中节点间的数据流动顺序,也就是间接的记录下的模型结构和每个节点的权重。不会单独再保存一个模型类。
加载:直接用torch.jit.load的方法加载模型即可,因为该模型已经记录了网络中模型节点权重和数据流动的路径,因此只要将数据输入,即可“流过”整个模型,得到最终的输出,不用单独再构造模型类的实例。
目前用的最多就是只保存权重的方法(方法二),最后一种用的最少,一般部署的时候也很少用,都是转成onnx再部署。