我们在使用pytorch的过程,经常会需要加载模型参数,但是pytorch当中,GPU和CPU模型下加载的参数的类型是不同的,不能互相直接调用,下面分情况进行操作说明。
问题:使用GPU训练的模型在CPU下无法运行,显示:
Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same
发现原来是自己使用了DataParallel的问题,我保存的是整个模型。
model=DataParalle(model).cuda()
state = {
'epoch':epoch,
'state_dict':model.state_dict(),
'best_precision':best_precision,
'lowest_loss':lowest_loss,
'stage':stage,
'lr':lr
} # 保存训练中各种参数
#字典
torch.save(state,‘xx.pth.tar’)
然后把模型copy到自己电脑上用cpu跑
model= my_new_nets(*args) #读取自己的网络模型
model=DataParalle(model).cuda()
model.load_state_dict(torch.load('xx.pth.tar' )['state_dict']) # 在模型中导入参数
...
model(data)
# 传入数据的时候出错
解决方案:在CPU环境下,不能直接导入GPU训练的DataParallel模型
所以换个策略,现在把GPU类型转化为CPU类型
model = DataParallel(model)
…
real_model = model.module#这个才是你实际的模型,如果直接报错model的话,其实是保存了DataParallel(model)这个,这样会导致cpu环境下加载出错
state = {
'epoch':epoch,
'state_dict':real_model.state_dict(),
'best_precision':best_precision,
'lowest_loss':lowest_loss,
'stage':stage,
'lr':lr
} # 保存训练中各种参数
#字典
torch.save(state,‘xxx.pth.tar’) #这样才是正确的保存模型方式,这样在cpu环境的模型才不会出错
其实还有一种方案,不需要重新保存模型,只需修改一下读取参数的方式
model=my_new_nets(*args)
model=DataParalle(model) #不用cuda()
# 在模型中导入参数
model.load_state_dict(torch.load('xx.pth.tar',map_location=‘cpu’ )['state_dict'])
model(data)
完整代码:
model = my_new_nets(*args).cuda() # 网络结构,里面传入你自定义的参数
model = torch.nn.DataParallel(model, device_ids=[0]) # 将model转为muit-gpus模式
#载入weights
checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)['state_dict']
model.load_state_dict(checkpoint) # 用weights初始化网络
# 载入为单gpu模型
gpu_model = model.module # GPU-version
# 载入为cpu模型
cpu_model = my_new_nets(*args)
cpu_model.load_state_dict(gpu_model.state_dict())
torch.save(cpu_model.state_dict(), 'cpu_mode.pth.tar')
# cpu模型存储, 注意这里的state_dict后的()必须加上,否则报'function' object has no attribute 'copy'错误
还有另外一种方案:修改权重orderdic
中的名称
上述代码只有在模型在一个GPU上训练时才起作用。如果我在多个GPU上训练我的模型,保存它,然后尝试在CPU上加载,我得到这个错误:KeyError: 'unexpected key "module.conv1.weight" in state_dict'
如何解决?
代码中保存了模型nn.DataParallel
,该模型将模型存储在该模型中module
,而现在您正试图加载模型DataParallel
。您可以nn.DataParallel
在网络中暂时添加一个加载目的,也可以加载权重文件,创建一个没有module
前缀的新的有序字典,然后加载它。同时因为在多gpu上训练的模型在保存时候在参数名前多加了一个“module.”前缀,加载的时候把这个前缀去掉就行了
# original saved file with DataParallel
state_dict = torch.load('myfile.pth.tar')
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # 去掉 `module.`
new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)
(1) 模型在GPU上保存,运行在CPU上
save
torch.save(model.state_dict(), PATH)
load
device = torch.device("cpu")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
(2) 模型在GPU上保存,运行在GPU上
save
torch.save(model.state_dict(), PATH)
load
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
(3) 模型在CPU上保存,运行在GPU上
save
torch.save(model.state_dict(), PATH)
load
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
model.to(device)
模型在GPU1上保存,运行在GPU0和1上
torch.load('modelparameters.pth', map_location={'cuda:1':'cuda:0'})