The previous note covered using torch.save and torch.load to store and restore an entire trained model; this note is about the other way to save and load a model: via its parameter dictionary (state_dict) rather than the whole model object.
# Example
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 10)

    def forward(self, x):
        ...
# model.state_dict() of the class above:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([10, 400])
fc1.bias torch.Size([10])
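The listing above can be reproduced by iterating over `model.state_dict()`; a minimal runnable sketch (the forward body here is a plausible filler, since the original elides it):

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        return self.fc1(x)

model = Model()
# state_dict() maps each learnable parameter name to its tensor;
# MaxPool2d has no parameters, so it does not appear
for name, tensor in model.state_dict().items():
    print(name, tensor.size())
```

Note that `pool` contributes nothing: only layers with learnable parameters (and registered buffers) have entries in the state_dict.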
# save model
torch.save(model.state_dict(), PATH)
# load: construct the model class first, then load the state_dict into it
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
# A note on this statement:
# load_state_dict expects a dict object, so torch.load(PATH) first deserializes the file back into a dict, which is then passed to load_state_dict
model.load_state_dict(torch.load(PATH))
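To see why the intermediate `torch.load` step is needed, note that it returns an ordinary mapping of parameter names to tensors. A small sketch using an in-memory buffer in place of a file path:

```python
import io
import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# serialize the parameter dict to an in-memory buffer
buf = io.BytesIO()
torch.save(model.state_dict(), buf)
buf.seek(0)

# deserialize: this gives back a name -> tensor dictionary, not a model
state = torch.load(buf)
print(type(state).__name__)

# load_state_dict then copies those tensors into the model's parameters
model.load_state_dict(state)
```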
# Method 1: Save/Load Entire Model
# save
torch.save(model, PATH)
# load
model = torch.load(PATH)
model.eval()
# Method 2: Save/Load state_dict (Recommended)
# save
torch.save(model.state_dict(), PATH)
# load
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
| Save/Load Entire Model | Save/Load state_dict (Recommended) |
|---|---|
| Less code, simpler | More code (must define the model class first, then load the parameters) |
| Inflexible: the serialized data is bound to the specific classes and the exact directory structure used when saving, so loading expects that same fixed layout | More flexible: only the parameters are loaded |
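A round-trip sketch of the recommended method: save a model's state_dict, then restore it into a freshly constructed instance (TinyNet is a hypothetical stand-in for TheModelClass, and the file path is arbitrary):

```python
import os
import tempfile
import torch
import torch.nn as nn

class TinyNet(nn.Module):  # hypothetical stand-in for TheModelClass
    def __init__(self):
        super(TinyNet, self).__init__()
        self.fc = nn.Linear(3, 2)

    def forward(self, x):
        return self.fc(x)

model = TinyNet()

# save only the parameters, not the model object
path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(model.state_dict(), path)

# load: construct the class first, then fill in the parameters
restored = TinyNet()
restored.load_state_dict(torch.load(path))
restored.eval()  # put dropout/batch-norm layers into inference mode

# every restored parameter tensor matches the original exactly
same = all(torch.equal(p, q) for p, q in
           zip(model.state_dict().values(), restored.state_dict().values()))
print(same)
```

Because only tensors are stored, `restored` can be any object whose state_dict has matching names and shapes; this is what makes the state_dict method more flexible than pickling the whole model.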
# References
https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict