Pytorch DDP DistributedDataParallel load_state_dict checkpoint 加载断点 继续训练的正确打开方式

一、网上的一些博客是错误的,我按照他们的方法去改,结果不对,比如:

Pytorch采坑记录:DDP加载之前的checkpoint后loss上升(metric下降)
[原创][深度][PyTorch] DDP系列第二篇:实现原理与源代码解析

二、pytorch官网的说法也是让人摸不着头脑,官方说法,但是保存checkpoint 的做法是对的,直接在主进程共保存即可,这点不再赘述,但是加载断点文件到DDP中的说法不是很清楚。特别是map_location = {‘cuda:%d’ % 0: ‘cuda:%d’ % rank} 很容易让人误解。

三、正确的方法:

1、在你的Model()类初始化和DistributedDataParallel()类初始化中间加载断点:在每个进程运行到这里时都会从硬盘中读取文件到CPU,然后再到每一个显卡,两种方式都可。主要是设置map_location参数,否则会出现如官网描述的错误,主显卡显存略大,多出来几个进程等等。这都是不是正确的打开方式。。。

Pytorch DDP DistributedDataParallel load_state_dict checkpoint 加载断点 继续训练的正确打开方式_第1张图片

2、1为何要去掉module.?因为DistributedDataParallel之前的model的state_dict是单卡的,之后他就会加上module.,如果按下面的保存方式就需要在加载断点之前去掉。当然你也可以保存为单卡的state_dict:orch.save(model.module.state_dict(), pth) (可能不对,网上很多教程,自行查找)。另外如果你在DistributedDataParallel()之后,开始第一个训练的epoch时加载模型参数,那么也不需要去掉module.

在这里插入图片描述

3、加载和保存优化器?这里我感觉很复杂,我没有使用保存优化器,我只是使用保存的断点做预训练和测试,所以不需要优化器断点训练。有懂的大佬可以留言!

4、DDP很难用,各种bug和小技巧,主要是文档写的不好,ε=(´ο`*)))唉

5、请不要忘了看看国外网友的参考,写的都很好:

https://discuss.pytorch.org/t/what-is-the-proper-way-to-checkpoint-during-training-when-using-distributed-data-parallel-ddp-in-pytorch/139575

https://github.com/pytorch/pytorch/issues/23138

你可能感兴趣的:(pytorch,分布式)