多GPU运行保存加载恢复checkpoint的几个关键

第一部分:认识多GPU的DataParalle model

多GPU运行保存加载恢复checkpoint的几个关键_第1张图片
在pytorch中如果使用单机器多块GPU时,会有一些小的注意事项,似乎大部分人都找不到合适的完整的介绍,这里把之前总结的做一个汇总,希望能帮更多人建立完整的知识框架。

第1层:认识model本身

类型1:
如果是cpu model或单GPU model,2种形式(sequential model和sequential&OrderedDict model)

Sequential(
  (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU())

一种是sequential model如上,引用具体层的方式是model[0],因为sequential会自动对layers编号,可类似于list切片方式调用

Sequential(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (relu1): ReLU()
  (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (relu2): ReLU())

另一种是sequential&OrderedDict model,引用引用方式是model.conv1,因为sequential针对ordereddict进行了优化可以直接通过.来调用层

类型2:
如果是data paralle model:可以看到是在原model基础上wrap了一个module外壳如下

DataParallel(
  (module): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    …))

引用方式类似OrderedDict model的嵌套:
model.module就是引用内部的真实模型
model.module.conv1就是引用具体层(前提是sequential内包含了OrderedDict)
model.module[0]也是引用具体层(前提是sequential内不包含OrderedDict)

第2层:认识checkpoint文件

通过torch.save(name, dir)保存的就叫checkpoint文件,可以存一个dict或存一个state_dict (OrderedDict)。一般的dict文件用来保存模型运行的状态信息和参数,state_dict用来保存参数是dict的一部分。如下是一个典型的checkpoint数据结构
{meta: dict
state_dict: OrderedDict
optimizer: dict}

第3层:认识state_dict字典

state_dict必然是一个OrderedDict数据类型,保存的内容就是所有深度学习需要优化的参数。

如果是Data paralle model在模型的state_dict中保存的内容会额外增加以module作为开头

odict_keys(['module.features.0.weight', 
		      'module.features.0.bias', 
		      'module.features.3.weight', 
		      'module.features.3.bias', 
		      'module.classifier.0.weight', 
	             'module.classifier.0.bias', 
		      'module.classifier.2.weight', 
		      'module.classifier.2.bias', 
		      'module.classifier.4.weight', 
		      'module.classifier.4.bias'])

第二部分:处理模型的保存和加载的流程

1. save:

通过torch.save(name, dir)完成,先组合checkpoint, 然后保存
参考:https://stackoverflow.com/questions/42703500/best-way-to-save-a-trained-model-in-pytorch

方案1:直接save state_dict(OrderedDict),且只save state_dict,后边只是用于inference。这样就会有2种形式state_dict存在,一种不带module前缀,一种带module前缀。
注:state_dict需要先.cpu()

方案2:直接save state_dict(OrderedDict),且只save state_dict,后边只是后边用于inference。但state_dict格式都统一成不带module的形式。
注:torch.save(model.state_dict, filepath) 此时保存的是OrderedDict,但可能有两种形式
注:torch.save(model.module.state_dict, filepath)此时保存的也是state_dict,但data paralle不会再带有module前缀

方案3:间接save整个training status(dict),不只save state_dict,还save epoch/iter/optimizer state_dict等状态参数,后边用于回复训练,这个dict需要自己组合
方案3是最常用的方案,因为他适用范围更广,不但可以training也可以inference

2. load checkpoint:

通过torch.load(checkpoint, map_location)完成

通过map_location控制加载位置:
可以加载到cpu: map_location = lambda storage, loc: storage
可以加载到GPU: map_location = lambda storage, loc: storage.cuda(0)

3. 获得state_dict:

(1)获得checkpoint后需要对checkpoint判断后处理获得state_dict:
如果checkpoint是OrderedDict,那么可以直接得到state_dict
如果checkpoint是dict,那么可以从字典中获得state_dict
如果state_dict包含module前缀,那么需要先去除module前缀,下面是一种处理前缀方式

if list(state_dict.keys())[0].startswith('module.'):
    state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}

核心原因:saving DataParallel wrapped model can cause problems when the model_state_dict is loaded into an unwrapped model. 保存了wrapped model然后加载到unwrapped model所以必然出错。
参考:https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/4

4. load state dict:

通过load_state_dict(model, state_dict)完成
(1)模型加载state_dict前需要对model判断处理:
如果是不带module的model,则加载不带module的state_dict:load_state_dict(model, state_dict)
如果是带module的model,则取内层model(相当于去掉module)然后加载不带module的state_dict,如下:load_state_dict(model.module, state_dict)

关键:带module wrap的model,可以直接加载带module前缀的state_dict,此时model和state_dict都不需要做去module化处理,model.load_state_dict(state_dict_w/_module)即可;当然也可以model和state_dict同时做去module化处理,model.module.load_state_dict(state_dict_w/o_module)

你可能感兴趣的:(DeepLearning)