pytorch 学习 | 多GPU存储模型及加载参数文件的坑(Error(s) in loading state_dict for DataParallel)

个人使用pytorch的时候需要用到多GPU运行,简要说明一下应用情景:

  1. GPU不够用,你需要将模型存储在多个GPU上;
  2. 当模型初始化后运行在多个GPU上,你要加载dict模型参数文件。

第一个情景,我们使用 nn.DataParallel 来解决,直接上例子:

# 使用nvidia-smi查看可用的设备
CUDA_DEVICE_1 = 0  
CUDA_DEVICE_2 = 1
# 模型初始化
net = model()
# 模型进入DataParallel,device_ids指明设备号列表
net = nn.DataParallel(net, device_ids=[CUDA_DEVICE_1, CUDA_DEVICE_2])

这个比较简单。

第二个有点坑,在DataParallel模式中,不能直接使用load_state_dict方法。直接使用会报类似这样的错:

RuntimeError: Error(s) in loading state_dict for DataParallel:
        Missing key(s) in state_dict: "module.layer1.0.weight", "module.layer1.1.weight", "module.layer1.1.bias", "module.layer1.1.running_mean", "module.layer1.1.running_var", "module.layer1.4.0.conv1.weight",

发现有一些key缺失,其实不是确实,而是 key 名前面需要补上 module.

直接使用以下代码即可。

# 先加载模型参数dict文件
state_dict = torch.load(args.trained_model)
from collections import OrderedDict
# 初始化一个空 dict
new_state_dict = OrderedDict()
# 修改 key,没有module字段则需要不上,如果有,则需要修改为 module.features
for k, v in state_dict.items():
    if 'module' not in k:
        k = 'module.'+k
    else:
        k = k.replace('features.module.', 'module.features.')
    new_state_dict[k]=v
# 加载修改后的新参数dict文件
net.load_state_dict(new_state_dict)

你可能感兴趣的:(DL,pytorch)