Pytorch加载保存好的模型发现与实际保存模型的参数不一致

Pytorch加载保存好的模型发现与实际保存的参数不一致

  • 发现问题
    • 初衷:
    • 问题:
      • 保存模型并打印参数
      • 加载模型并打印模型参数
      • 问题来了
  • 解决问题
    • 先说结论:
    • 原因:
    • 联想:
    • 问题本质:
  • 小结
  • 参考博客

发现问题

初衷:

最近跑深度学习代码,数据量比较庞大所以训练速度慢的一,一旦有什么问题训练突然停止了,心态都要炸了,所以这时候使用pytorch保存每轮训练好的模型(可以使用一些保存技巧,保存得到的当前最新的模型checkpoint,倒不至于每轮都保存),即使训练停止还能接着去加载保存好的模型得到当前模型输出结果,节省时间,防止每次重复性的从头开始训练。

问题:

加载训练了几天的模型checkpoint用于继续训练,发现网络输出的结果与之前保存模型时候输出的结果差别很大(因为之前也有成功用过加载和保存模型,就一直觉得没什么问题,就没管了,后来因为训练实在太耗时,跟同学讨论时候被启发那说这里肯定是有问题的,所以就开始去找问题了),因此试着去输出保存的模型参数和加载的模型参数,看看二者是否一致。(谷歌很久没有找到解决方案,自己找着找着发现了个解决方案,记录一下,避坑避坑)

保存模型并打印参数

我在此保存的是模型和优化器两个东西:
在这里插入图片描述
保存的模型参数
部分输出结果:(保存的模型参数)
Pytorch加载保存好的模型发现与实际保存模型的参数不一致_第1张图片
Pytorch加载保存好的模型发现与实际保存模型的参数不一致_第2张图片

加载模型并打印模型参数

加载模型和优化器:
Pytorch加载保存好的模型发现与实际保存模型的参数不一致_第3张图片
Pytorch加载保存好的模型发现与实际保存模型的参数不一致_第4张图片
运行代码,发现会报错,看了一下发现可能是模型加载时的参数名字不一致,加载的参数比网络需要的参数多了module字符,谷歌发现可以在加载模型时候多写一个参数strict = False即可。
Pytorch加载保存好的模型发现与实际保存模型的参数不一致_第5张图片
修改代码:
Pytorch加载保存好的模型发现与实际保存模型的参数不一致_第6张图片
成功运行,不报错了,得到部分输出结果:(加载模型参数)
Pytorch加载保存好的模型发现与实际保存模型的参数不一致_第7张图片
Pytorch加载保存好的模型发现与实际保存模型的参数不一致_第8张图片

问题来了

我们知道,保存的模型和加载的模型的参数应该是一模一样的才对,但是通过对比上述我们自己保存的模型和加载这个模型的参数可以发现,二者并不一样,至此就已经发现问题了,最开始发现的加载模型后网络的输出结果和保存时的输出结果不同的原因就是因为二者使用的参数并不一致,而实际上应该是一致的才对。

解决问题

先说结论:

保存模型时候,不能直接保存model.state_dict(),而是保存model.module.state_dict()。优化器optim不需改变,与它无关(保存的和加载得到的optim输出的参数都是一致的)。
在这里插入图片描述

原因:

Pytorch加载保存好的模型发现与实际保存模型的参数不一致_第9张图片
我的代码使用了GPU运行,而且代码里也加了多块GPU并行的操作:model = torch.nn.DataParallel(model),恰巧我在pytorch官方教程中发现,有一节说了保存torch.nn.DataParallel 模型,所以这就说明了对于torch.nn.DataParallel 模型是得用所给的专门保存语句的:torch.save(model.module.state_dict(), PATH)——比普通模型多了一个.module字段。
修改保存模型的代码之后,发现加载的参数和保存的参数一致了!并且网络输出结果也一致了!

联想:

还记得上面提到的,在之前加载模型有见过一个报错就是说模型参数名字不一致,而它们相差的正好就是module这个字段,但是上面用的解决方案就是多加了strict=False参数,虽然不会报错了,但是却没有解决根本问题。因此问题解决之后,也就没有参数字段名不一致的问题了,strict=False参数可以去掉了,也不会报错了(测试过了)。

问题本质:

尚未明确。
自己乱想:我用了多块GPU并行,所以模型参数可能具体放到了各块gpu上面,后来直接加载模型时候没有区分,就不知道加载什么乱七八糟的参数去了,具体我也没弄太懂,也有问题就是可能是我虽然使用了torch.nn.DataParallel 模型,但是实际运行时还是一块GPU,就不知道到底是为啥。我自己还测试了一下保存的model.state_dict()和model.module.state_dict()输出结果有什么不同,发现它们的参数们大小实际是一致的,只是在后面加载模型时候,只有model.module.state_dict()保存的模型能成功加载和保存模型一模一样的参数,而model.state_dict()保存的模型加载出来参数就变了,就有问题了,很迷,但是相信pytorch官方教程总没问题,使用了torch.nn.DataParallel 模型,就使用相应的保存模型语句:torch.save(model.module.state_dict(), PATH) 保存模型参数。
Pytorch加载保存好的模型发现与实际保存模型的参数不一致_第10张图片

小结

写给自己:发现问题时候,不能想当然以为之前 用过没问题就觉得没问题,忽视问题存在, 还是得积极去发现问题、解决问题才对,就像今天这个问题,早点解决的话,那些训练的时间也不会白白浪费了。

参考博客

pytorch读取模型失败RuntimeError: Error(s) in loading state_dict for ResNet: Missing key(s) in state_dict: https://blog.csdn.net/qq_34769162/article/details/115038161.
pytorch官方教程——保存加载模型: https://pytorch123.com/ThirdSection/SaveModel/#3-checkpoint.

你可能感兴趣的:(pytorch-gpu,深度学习,pytorch,神经网络)