关于PyTorch模型保存与导入的一些注意点:
1.没有使用并行计算:
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(2, 2, 1)
self.linear = nn.Linear(2, 10)
def forward(self, x):
x = self.conv1(x)
x = self.linear(x)
return x
net = Net()
state_dict = net.state_dict()
for key, value in state_dict.items():
print(key)
输出:
conv1.weight
conv1.bias
linear.weight
linear.bias
2.使用并行计算(调用net.state_dict()):
import torch.nn as nn
import torch
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(2, 2, 1)
self.linear = nn.Linear(2, 10)
def forward(self, x):
x = self.conv1(x)
x = self.linear(x)
return x
net = Net()
net = nn.DataParallel(net, device_ids=[0, 3])
net.cuda()
state_dict = net.state_dict()
for key, value in state_dict.items():
print(key)
module.conv1.weight
module.conv1.bias
module.linear.weight
module.linear.bias
可以发现,模型的key前面带了"module."的字符串
3.使用并行计算(调用net.module.state_dict()):
import torch.nn as nn
import torch
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(2, 2, 1)
self.linear = nn.Linear(2, 10)
def forward(self, x):
x = self.conv1(x)
x = self.linear(x)
return x
net = Net()
net = nn.DataParallel(net, device_ids=[0, 3])
net.cuda()
# 这里使用了并行,所以net.module.state_dict()保存的不带module,而net.state_dict()带module()
state_dict = net.module.state_dict()
for key, value in state_dict.items():
print(key)
模型输出为:
conv1.weight
conv1.bias
linear.weight
linear.bias
可以看到没有"module."了。
总结为,如果使用了并行net = nn.DataParallel(net, device_ids=[--]),在保存模型时候:
那么,如果我们在训练的时候使用了DataParallel时:
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model, device_ids=[0, 3])
data_parallel = True
model.to(device)
model.load_state_dict(torch.load('./model_epoch_600.pth'))
如果这里的model_epoch_600.pth的模型为前面第一或第三种情况,及模型中不带有“module.”字样,那么就会报出错误:
RuntimeError: Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: "module.bn.weight", "module.bn.bias", "module.bn.running_mean", "module.bn.running_var".……………………………………
Unexpected key(s) in state_dict: "bn.weight", "bn.bias", "running_mean", "bn.running_var", ……………………………………
类似的错误,即我们需要带有".module"的,而保存的模型不带有。
解决方法:
from collections import OrderedDict
new_state_dict = OrderedDict()
for key, value in torch.load("./model_epoch_600.pth").items():
name = 'module.' + key
new_state_dict[name] = value
model.load_state_dict(new_state_dict)
来手动修改保存的模型即可。
或者情况相反,你多了"module.",可以通过以下方法解决:
from collections import OrderedDict
new_state_dict = OrderedDict()
for key, value in torch.load("./model_epoch_600.pth").items():
name = key[7:]
new_state_dict[name] = value
model.load_state_dict(new_state_dict)
因为通过state_dict()或是module.state_dict()函数保存的模型参数,其本质上是一个OrderedDict!
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(2, 2, 1)
self.linear = nn.Linear(2, 10)
def forward(self, x):
x = self.conv1(x)
x = self.linear(x)
return x
net = Net()
state_dict = net.state_dict()
print(state_dict)
输出为:
OrderedDict([('conv1.weight', tensor([[[[ 0.2986]],
[[-0.3642]]],
[[[ 0.6761]],
[[ 0.1944]]]])), ('conv1.bias', tensor([ 0.6060, -0.4560])), ('linear.weight', tensor([[-0.2554, 0.4958],
[ 0.1802, -0.0579],
[ 0.3246, -0.6828],
[ 0.2968, 0.6336],
[ 0.6546, -0.6072],
[-0.5858, -0.7052],
[ 0.5672, 0.1555],
[-0.1569, 0.5623],
[-0.6982, 0.3347],
[-0.2944, -0.4632]])), ('linear.bias', tensor([ 0.3750, 0.5366, 0.4006, -0.6096, -0.6294, 0.6686, 0.3804, -0.0299,
0.4152, -0.6917]))])