RuntimeError: Error(s) in loading state_dict for RNN:问题解决

RuntimeError: Error(s) in loading state_dict for RNN:问题解决

当我们在进行机器学习时,会出现加载训练好的模型参数出现错误的情况,以下便是我遇到的情况。
在这里插入图片描述
在这里插入图片描述
从上图可见,在加载训练好的模型参数时出现了问题,报错意思很明显,即加载的参数名称和新实例化的模型参数名称不一样。
首先,查看已训练好的模型参数,如下图所示。
RuntimeError: Error(s) in loading state_dict for RNN:问题解决_第1张图片
而新实例化的模型参数如下图所示。
RuntimeError: Error(s) in loading state_dict for RNN:问题解决_第2张图片
很明显,参数列表的名称不一样。
我怀疑我没有保存进去,所以重新训练一遍模型,结果如下所示:

RuntimeError: Error(s) in loading state_dict for RNN:问题解决_第3张图片
果然。。。。。。

接下来介绍几种pytorch保存并加载模型的方法。

  • 方法一:加载/保存整个模型

保存:torch.save(model, PATH)
加载:model = torch.load(PATH)
保存和加载模型都是采用非常直观的语法并且都只需要几行代码即可实现。这种实现保存模型的做法将是采用 Python 的 pickle 模块来保存整个模型,这种做法的缺点就是序列化后的数据是属于特定的类和指定的字典结构,原因就是 pickle 并没有保存模型类别,而是保存一个包含该类的文件路径,因此,当在其他项目或者在 refactors 后采用都可能出现错误。

  • 方法二:加载和保存一个通用的检查点(Checkpoint)
    保存的示例代码:
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

加载的示例代码:

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

当保存一个通用的检查点(checkpoint)时,无论是用于继续训练还是预测,都需要保存更多的信息,不仅仅是 state_dict ,比如说优化器的 state_dict 也是非常重要的,它包含了用于模型训练时需要更新的参数和缓存信息,还可以保存的信息包括 epoch,即中断训练的批次,最后一次的训练 loss,额外的 torch.nn.Embedding 层等等。

上述保存代码就是介绍了如何保存这么多种信息,通过用一个字典来进行组织,然后继续调用 torch.save 方法,一般保存的文件后缀名是 .tar 。

加载代码也如上述代码所示,首先需要初始化模型和优化器,然后加载模型时分别调用 torch.load 加载对应的 state_dict 。然后通过不同的键来获取对应的数值。

加载完后,根据后续步骤,调用 model.eval() 用于预测,model.train() 用于恢复训练。

你可能感兴趣的:(rnn,人工智能,深度学习)