torch.nn.Module.load_state_dict in PyTorch

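The previous note covered using torch.save and torch.load to save and load an entire trained model. This note covers the other way to save and load a model: through its parameter dictionary (state_dict) instead of the whole trained model object.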

torch.nn.Module.load_state_dict

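  1. Definitions to understand
    state_dict: a plain Python dictionary object that maps each parameter name (for example conv1.weight) to its tensor, i.e. the weights and biases of the model's layers.
    (Note: only layers with learnable parameters, such as convolutional and linear layers, have entries in the state_dict; layers with nothing to learn, such as pooling layers, do not appear. Registered buffers, such as BatchNorm running statistics, are stored as well.)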
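A minimal sketch (assuming only that PyTorch is installed) to check the note above: a convolution layer exposes its weight and bias in its state_dict, a pooling layer's state_dict is empty, and a BatchNorm layer also lists its running-statistics buffers even though they are not learned by gradient descent.

```python
import torch.nn as nn

conv = nn.Conv2d(3, 6, 5)   # learnable weight and bias
pool = nn.MaxPool2d(2, 2)   # nothing to learn
bn = nn.BatchNorm2d(6)      # learnable affine params plus running-stat buffers

print(list(conv.state_dict().keys()))  # ['weight', 'bias']
print(list(pool.state_dict().keys()))  # []
print(list(bn.state_dict().keys()))
# ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']
```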
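# Example
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)         # no learnable parameters
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 400 inputs, 120 outputs

    def forward(self, x):
        # assumes 3x32x32 inputs: 32 -> 28 -> 14 -> 10 -> 5
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)                # 16 * 5 * 5 = 400 features
        return self.fc1(x)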
# Output of model.state_dict() (parameter name, tensor size)
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias       torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias       torch.Size([16])
fc1.weight       torch.Size([120, 400])
fc1.bias         torch.Size([120])
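The listing above can be reproduced by iterating over model.state_dict(). This sketch assumes fc1 has 120 outputs so the shapes match the listing, and it omits forward() because building the state_dict does not need it. Note that pool produces no entries.

```python
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)         # contributes no entries
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 120 outputs, matching the listing


model = Model()
# state_dict() maps "submodule.parameter" names to tensors, in registration order
for name, tensor in model.state_dict().items():
    print(f"{name:<12} {tensor.size()}")
```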
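  2. Usage
# save: serialize only the parameter dictionary
torch.save(model.state_dict(), PATH)

# load: instantiate the model first, then load the parameters into it
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()  # put dropout and batch-norm layers into evaluation mode before inference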
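# What this statement does:
# load_state_dict expects a dictionary object, so torch.load(PATH) first deserializes
# the saved file back into a dictionary, which is then passed to load_state_dict.
model.load_state_dict(torch.load(PATH))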
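A self-contained round trip of the usage above, as a sketch: a small nn.Sequential stands in for TheModelClass, and make_model and the file name model_weights.pt are hypothetical choices for illustration. It confirms that torch.load returns a dictionary and that the restored model computes the same outputs as the original.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # hypothetical stand-in for TheModelClass(*args, **kwargs)
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


model = make_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_weights.pt")  # hypothetical file name
    torch.save(model.state_dict(), path)           # save only the parameters

    state = torch.load(path)           # deserialize back into a dict of tensors
    print(isinstance(state, dict))     # True

    restored = make_model()            # build the architecture first...
    restored.load_state_dict(state)    # ...then copy the saved tensors into it
    restored.eval()                    # evaluation mode for inference

x = torch.randn(3, 4)
print(torch.equal(model(x), restored(x)))  # True
```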
  3. Differences between the two approaches: Save/Load Entire Model vs. Save/Load state_dict (officially recommended)
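# method1: Save/Load Entire Model
# save
torch.save(model, PATH)
# load (the model's class must be importable from where it was defined at save time;
# since PyTorch 2.6, torch.load defaults to weights_only=True, so loading a whole
# pickled model needs weights_only=False, which should only be used on trusted files)
model = torch.load(PATH, weights_only=False)
model.eval()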
# method2: Save/Load state_dict (Recommended)
# save
torch.save(model.state_dict(), PATH)
# load
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
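| Save/Load Entire Model | Save/Load state_dict (Recommended) |
|---|---|
| Less code and simpler: one call to save, one to load | More steps: the model must be instantiated first, then the parameters loaded |
| Inflexible: the serialized data is bound to the specific classes and the exact directory structure used at save time, because pickle stores a reference to the file containing the model class rather than the class itself | More flexible: only the parameters are saved and loaded |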
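To illustrate the flexibility point, here is a minimal sketch with two hypothetical classes, Backbone and Extended: a state_dict taken from the smaller model is loaded into a larger one that reuses the same fc1 layer. Passing strict=False to load_state_dict makes it report missing and unexpected keys instead of raising an error; tensors under matching keys must still have the same shapes.

```python
import torch
import torch.nn as nn


class Backbone(nn.Module):
    """Hypothetical original model."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)


class Extended(nn.Module):
    """Hypothetical new model: same fc1, plus a new head."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)   # same name and shape as in Backbone
        self.head = nn.Linear(8, 3)  # new layer, absent from the saved state_dict


old_state = Backbone().state_dict()
new_model = Extended()

# strict=False tolerates missing/unexpected keys and returns them for inspection
result = new_model.load_state_dict(old_state, strict=False)
print(result.missing_keys)     # ['head.weight', 'head.bias']
print(result.unexpected_keys)  # []
```

Only fc1 is overwritten here; head keeps its fresh initialization, which is the usual pattern for fine-tuning a model that was extended after training.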

References
https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
