Unexpected key(s) in state_dict: “module.conv1.weight“, “module.bn1.weight“, “module.bn1.bias“,

由于服务器老是断电 所以想加载已经训练好的上一个epoch的模型,但是在加载时遇到了这个问题
这是由于保存模型字典时每一个模块的key都自动加上了‘module’。所以在加载模型参数继续训练时就会与模型对不上号。

RuntimeError: Error(s) in loading state_dict for ResNet:
	Unexpected key(s) in state_dict: "module.conv1.weight", "module.bn1.weight", "module.bn1.bias", "module.bn1.running_mean", "module.bn1.running_var", "module.conv2.weight", "module.bn2.weight", "module.bn2.bias", "module.bn2.running_mean", "module.bn2.running_var", "module.conv3.weight", "module.bn3.weight", "module.bn3.bias", "module.bn3.running_mean", "module.bn3.running_var", "module.layer1.0.conv1.weight", "module.layer1.0.bn1.weight", "module.layer1.0.bn1.bias", "module.layer1.0.bn1.running_mean", "module.layer1.0.bn1.running_var", "module.layer1.0.conv2.weight", "module.layer1.0.bn2.weight", "module.layer1.0.bn2.bias", 

接着查看测试的代码:可以看到在测试中就是把’module.'给去掉,从下标7开始读取key。

    for key, nkey in zip(state_dict_old.keys(), state_dict.keys()):
        if key != nkey:
            # remove the 'module.' in the 'key'
            state_dict[key[7:]] = deepcopy(state_dict_old[key])
        else:
            state_dict[key] = deepcopy(state_dict_old[key])

所以在测试的时候直接去掉前面的module.就好了像这样:new_param是从上一个epoch读取出的参数字典

    deeplab.load_state_dict({k.replace('module.',''):v for k,v in new_params.items()})

你可能感兴趣的:(笔记,深度学习)