load_state_dict报错: Error(s) in loading state_dict for XXX: Missing keys(s) in state_dict: ....

1. 报错内容

/opt/tools/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1221 
   1222         if len(error_msgs) > 0:
-> 1223             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1224                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1225         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for UNet:
	Missing key(s) in state_dict:................(太多了)
	Unexpected key(s) in state_dict: ................(也很多)

2. 解决方案

train的时候我加了数据并行

# train
model = nn.DataParallel(model)

但是test里没加并行,加上就不会报错了

# test
model = UNet(n_channels=3, n_classes=1)
model = nn.DataParallel(model)
model.load_state_dict(torch.load(weight_path))

问题解决了之后我还想探索一下nn.DataParallel()到底做了些什么,与不并行有什么区别。

3. 错误分析

以上的问题大概是导入的权重的键与模型本身的键无法对应。比如相同的地方,导入的权重叫module.inc.conv.conv.0.weight, 而模型这个地方的权重叫inc.conv.conv.0.weight

并且两边权重的键个数也不一样,加了并行的模型权重多了18个结尾是num_batches_tracked的键,可能是用来记录这个分支算了多少个batch(猜想)

所以给test加上数据并行的最省事的解决方案了。

4. 另一种可能的解决方案

今天看pytorch的官方教程中提到,保存torch.nn.DataParallel模型的时候要使用

torch.save(model.module.state_dict(), PATH)

这样推理的时候就不需要model = nn.DataParallel(model)了,推荐使用这样的保存方式!

你可能感兴趣的:(胖琦的pytorch,pytorch,深度学习,python)