使用pytorch DDP(DistributedDataParallel,分布式数据并行)可以进行多卡训练,涉及到模型保存与加载问题时,一般会涉及到以下两种需求:
如何无bug且高效的解决以上需求?(假设训练设备为“单机4卡”)
对于需求1,由于DDP在多卡中维护了相同的模型参数(通过在4张GPU上确保模型初始化以及广播相同的梯度来保证4张卡中的模型参数是完全相同的),因此只需要在其中一张卡保存模型即可:
def save_checkpoint(local_rank, ddp_model, path):
#只在GPU 0 上保存模型
if local_rank== 0:
state = {
'model': ddp_model.module.state_dict(),
'optimizer': optimizer.state_dict(),
}
torch.save(state, path)
对于需求2,一般会使用torch.load()方法从磁盘加载文件:
def load_checkpoint(path):
checkpoint = torch.load(path)
model = Net()
model.load_state_dict(checkpoint['model'])
model = DDP(model, device_ids=[gpu])
return model
但是此时往往会遇到多进程在GPU0上占用过多显存的问题:
使用nvidia-smi命令:
上图中,在所有使用GPU0的进程中,除了PID为62250的进程外,还存在其他三个进程,而这三个进程还分别使用GPU1\2\3。这三个额外进程在GPU0占用了725MB*3的显存空间,这可能会导致GPU0在训练时出现爆显存的问题。
在DDP中,会为每张卡单独创建一个进程:
上图的情况是正常的,每个进程只会使用与其对应的一张显卡。
该问题出现的原因是:torch.load()的不正确使用。
在pytorch对torch.
load()
方法的官方文档中,有以下说明:
If
map_location
is missing,torch.load
will first load the module to CPU and then copy each parameter to where it was saved
意思是,如果map_location参数是空的,则torch.load方法会先把模型加载到CPU,然后把模型参数复制到保存它的地方(根据上文,保存模型的位置恰好是GPU 0)。
跑在GPU1上的进程在执行到torch.load方法后,会先加载模型到CPU,之后该进程顺理成章地调用GPU0,把一部分数据复制到GPU0,也就出现了前面图中的问题。
与其说是bug,倒不如说没仔细阅读文档。
两种解决方法方法。
一,将map_location指定为CPU:
def load_checkpoint(path):
#加载到CPU
checkpoint = torch.load(path,map_location='cpu')
model = Net()
model.load_state_dict(checkpoint['model'])
model = DDP(model, device_ids=[gpu])
return model
二,将map_location指定为local_rank对应的GPU:
def load_checkpoint(path):
#加载到CPU
checkpoint = torch.load(path,map_location='cuda:{}'.format(local_rank))
model = Net()
model.load_state_dict(checkpoint['model'])
model = DDP(model, device_ids=[gpu])
return model