Error(s) in loading state_dict for DataParallel

Some notes on saving and loading PyTorch model state dicts:

1. Without DataParallel:

import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 1)
        self.linear = nn.Linear(2, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.linear(x)
        return x


net = Net()
state_dict = net.state_dict()
for key, value in state_dict.items():
    print(key)

Output:

conv1.weight
conv1.bias
linear.weight
linear.bias

2. With DataParallel, calling net.state_dict():

import torch.nn as nn
import torch


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 1)
        self.linear = nn.Linear(2, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.linear(x)
        return x


net = Net()

net = nn.DataParallel(net, device_ids=[0, 3])
net.cuda()

state_dict = net.state_dict()
for key, value in state_dict.items():
    print(key)

Output:

module.conv1.weight
module.conv1.bias
module.linear.weight
module.linear.bias

Notice that every key now carries a "module." prefix.

3. With DataParallel, calling net.module.state_dict():

import torch.nn as nn
import torch


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 1)
        self.linear = nn.Linear(2, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.linear(x)
        return x


net = Net()

net = nn.DataParallel(net, device_ids=[0, 3])
net.cuda()

# Because the model is wrapped in DataParallel, net.module.state_dict() returns
# keys without the "module." prefix, while net.state_dict() returns keys with it
state_dict = net.module.state_dict()
for key, value in state_dict.items():
    print(key)

Output:

conv1.weight
conv1.bias
linear.weight
linear.bias

The "module." prefix is gone.

To summarize: if the model is wrapped with net = nn.DataParallel(net, device_ids=[...]), then when saving:

  1. net.module.state_dict() produces keys without the "module." prefix
  2. net.state_dict() produces keys with the "module." prefix
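One way to avoid the mismatch entirely is to always save the unwrapped weights. The helper below is a sketch (the name unwrapped_state_dict is mine, not a PyTorch API): it works for both plain and DataParallel-wrapped models because the wrapper keeps the real model in its .module attribute.

```python
def unwrapped_state_dict(model):
    # nn.DataParallel stores the original model in .module;
    # a plain model has no such attribute, so fall back to the model itself.
    inner = getattr(model, "module", model)
    return inner.state_dict()
```

Saving with torch.save(unwrapped_state_dict(net), path) then always produces prefix-free keys, whatever the training setup was.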

Now suppose we used DataParallel during training and load a checkpoint like this:

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model, device_ids=[0, 3])
    data_parallel = True
model.to(device)
model.load_state_dict(torch.load('./model_epoch_600.pth'))

If model_epoch_600.pth was saved as in case 1 or case 3 above, i.e. its keys do not carry the "module." prefix, loading raises an error like:

RuntimeError: Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: "module.bn.weight", "module.bn.bias", "module.bn.running_mean", "module.bn.running_var", ...
Unexpected key(s) in state_dict: "bn.weight", "bn.bias", "bn.running_mean", "bn.running_var", ...

In other words, the loader expects keys with the "module." prefix, but the saved checkpoint lacks it.
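Before patching anything, it can help to see exactly which keys disagree. A small sketch (the helper name diff_keys is hypothetical, not part of PyTorch) that reproduces the missing/unexpected split from the error message:

```python
def diff_keys(model_keys, checkpoint_keys):
    # Mirrors the RuntimeError: keys the model wants but the file lacks,
    # and keys the file has that the model does not expect.
    checkpoint_set = set(checkpoint_keys)
    model_set = set(model_keys)
    missing = [k for k in model_keys if k not in checkpoint_set]
    unexpected = [k for k in checkpoint_keys if k not in model_set]
    return missing, unexpected
```

You would call it as diff_keys(model.state_dict().keys(), torch.load(path).keys()) to inspect the mismatch before deciding whether to add or strip the prefix.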

Fix: prepend "module." to every key before loading:

import torch
from collections import OrderedDict


new_state_dict = OrderedDict()
for key, value in torch.load("./model_epoch_600.pth").items():
    name = 'module.' + key  # add the prefix the DataParallel wrapper expects
    new_state_dict[name] = value
model.load_state_dict(new_state_dict)

This rewrites the saved state dict by hand before loading.

In the opposite situation, where the checkpoint has an extra "module." prefix and the model does not, strip it:

import torch
from collections import OrderedDict


new_state_dict = OrderedDict()
for key, value in torch.load("./model_epoch_600.pth").items():
    name = key[7:]  # drop the leading 'module.' (7 characters)
    new_state_dict[name] = value
model.load_state_dict(new_state_dict)
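The two loops above can be folded into a pair of reusable functions. This is a sketch under the assumption that keys either carry the full "module." prefix or none at all (the function names are mine):

```python
from collections import OrderedDict

PREFIX = "module."


def strip_module_prefix(state_dict):
    # DataParallel checkpoint -> plain model: drop the leading 'module.'
    return OrderedDict(
        (k[len(PREFIX):] if k.startswith(PREFIX) else k, v)
        for k, v in state_dict.items()
    )


def add_module_prefix(state_dict):
    # plain checkpoint -> DataParallel model: prepend 'module.'
    return OrderedDict(
        (k if k.startswith(PREFIX) else PREFIX + k, v)
        for k, v in state_dict.items()
    )
```

Both are no-ops on keys already in the target form, so it is safe to apply the one matching your model before every load.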

This works because what state_dict() (or module.state_dict()) returns is, at its core, an OrderedDict!

import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 1)
        self.linear = nn.Linear(2, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.linear(x)
        return x


net = Net()
state_dict = net.state_dict()
print(state_dict)

Output:

OrderedDict([('conv1.weight', tensor([[[[ 0.2986]],

         [[-0.3642]]],


        [[[ 0.6761]],

         [[ 0.1944]]]])), ('conv1.bias', tensor([ 0.6060, -0.4560])), ('linear.weight', tensor([[-0.2554,  0.4958],
        [ 0.1802, -0.0579],
        [ 0.3246, -0.6828],
        [ 0.2968,  0.6336],
        [ 0.6546, -0.6072],
        [-0.5858, -0.7052],
        [ 0.5672,  0.1555],
        [-0.1569,  0.5623],
        [-0.6982,  0.3347],
        [-0.2944, -0.4632]])), ('linear.bias', tensor([ 0.3750,  0.5366,  0.4006, -0.6096, -0.6294,  0.6686,  0.3804, -0.0299,
         0.4152, -0.6917]))])

 
