在pytorch中如果使用单机器多块GPU时,会有一些小的注意事项,似乎大部分人都找不到合适的完整的介绍,这里把之前总结的做一个汇总,希望能帮更多人建立完整的知识框架。
类型1:
如果是cpu model或单GPU model,2种形式(sequential model和sequential&OrderedDict model)
Sequential(
(0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
(1): ReLU()
(2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
(3): ReLU())
一种是sequential model如上,引用具体层的方式是model[0],因为sequential会自动对layers编号,可类似于list切片方式调用
Sequential(
(conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
(relu1): ReLU()
(conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
(relu2): ReLU())
另一种是sequential&OrderedDict model,引用引用方式是model.conv1,因为sequential针对ordereddict进行了优化可以直接通过.来调用层
类型2:
如果是data paralle model:可以看到是在原model基础上wrap了一个module外壳如下
DataParallel(
(module): ResNet(
(conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
…))
引用方式类似OrderedDict model的嵌套:
model.module就是引用内部的真实模型
model.module.conv1就是引用具体层(前提是sequential内包含了OrderedDict)
model.module[0]也是引用具体层(前提是sequential内不包含OrderedDict)
通过torch.save(name, dir)保存的就叫checkpoint文件,可以存一个dict或存一个state_dict (OrderedDict)。一般的dict文件用来保存模型运行的状态信息和参数,state_dict用来保存参数是dict的一部分。如下是一个典型的checkpoint数据结构
{meta: dict
state_dict: OrderedDict
optimizer: dict}
state_dict必然是一个OrderedDict数据类型,保存的内容就是所有深度学习需要优化的参数。
如果是Data paralle model在模型的state_dict中保存的内容会额外增加以module作为开头
odict_keys(['module.features.0.weight',
'module.features.0.bias',
'module.features.3.weight',
'module.features.3.bias',
'module.classifier.0.weight',
'module.classifier.0.bias',
'module.classifier.2.weight',
'module.classifier.2.bias',
'module.classifier.4.weight',
'module.classifier.4.bias'])
通过torch.save(name, dir)完成,先组合checkpoint, 然后保存
参考:https://stackoverflow.com/questions/42703500/best-way-to-save-a-trained-model-in-pytorch
方案1:直接save state_dict(OrderedDict),且只save state_dict,后边只是用于inference。这样就会有2种形式state_dict存在,一种不带module前缀,一种带module前缀。
注:state_dict需要先.cpu()
方案2:直接save state_dict(OrderedDict),且只save state_dict,后边只是后边用于inference。但state_dict格式都统一成不带module的形式。
注:torch.save(model.state_dict, filepath) 此时保存的是OrderedDict,但可能有两种形式
注:torch.save(model.module.state_dict, filepath)此时保存的也是state_dict,但data paralle不会再带有module前缀
方案3:间接save整个training status(dict),不只save state_dict,还save epoch/iter/optimizer state_dict等状态参数,后边用于回复训练,这个dict需要自己组合
方案3是最常用的方案,因为他适用范围更广,不但可以training也可以inference
通过torch.load(checkpoint, map_location)完成
通过map_location控制加载位置:
可以加载到cpu: map_location = lambda storage, loc: storage
可以加载到GPU: map_location = lambda storage, loc: storage.cuda(0)
(1)获得checkpoint后需要对checkpoint判断后处理获得state_dict:
如果checkpoint是OrderedDict,那么可以直接得到state_dict
如果checkpoint是dict,那么可以从字典中获得state_dict
如果state_dict包含module前缀,那么需要先去除module前缀,下面是一种处理前缀方式
if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}
核心原因:saving DataParallel wrapped model can cause problems when the model_state_dict is loaded into an unwrapped model. 保存了wrapped model然后加载到unwrapped model所以必然出错。
参考:https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/4
通过load_state_dict(model, state_dict)完成
(1)模型加载state_dict前需要对model判断处理:
如果是不带module的model,则加载不带module的state_dict:load_state_dict(model, state_dict)
如果是带module的model,则取内层model(相当于去掉module)然后加载不带module的state_dict,如下:load_state_dict(model.module, state_dict)
关键:带module wrap的model,可以直接加载带module前缀的state_dict,此时model和state_dict都不需要做去module化处理,model.load_state_dict(state_dict_w/_module)即可;当然也可以model和state_dict同时做去module化处理,model.module.load_state_dict(state_dict_w/o_module)