昇腾910上分布式加载模型与增量训练

问题描述:

【功能模块】

mindspore.train.serialization.load_distributed_checkpoint

【操作步骤&问题现象】

1、在Modelarts 16个节点128张昇腾910上从头训练130亿模型,模型代码修改自:https://gitee.com/mindspore/mindspore/tree/r1.3/model_zoo/official/nlp/pangu_alpha

2、使用相同计算资源,和 load_distributed_checkpoint API 加载第 1 步训练得到的模型,并做增量训练

3、第 1 步能正常训练,第 2 步加载过程会因为超出内存而报错,每个节点8张卡会分配2048GB内存,load_distributed_checkpoint 加载模型时会超出该内存

4、注:相同代码,如果先训练13亿模型,然后再使用load_distributed_checkpoint 加载训练的13亿模型,做增量训练,改过程没有问题,损失、准确率都是正常的

【截图信息】

下图是一个节点的资源使用情况,从62分钟开始,模型通过 load_distributed_checkpoint 开始加载之前保存的ckpt,然后memUsage从0增加到95%左右,然后训练进程被系统kill掉。

昇腾910上分布式加载模型与增量训练_第1张图片 

前台日志不会报错,只会多下面这样一条异常,因为捕捉不到外部kill信号,plog也没办法传到桶内。

/bin/sh: line 1: 15 Killed /bin/bash run_train.sh 's3://aix/PanGu/' 'PanGu/train.py' '/tmp/log/aix-b-model.log' --'epoch_size'='2' --'mode'='13B' --'obs_version_suffix'='increment' --'pre_trained'='obs://aix/saved_model/13B_increment_lm' --data_url='s3://aix/data_test/' --train_url='s3://aix/PanGu/saved_model/V0089/'

解答:

您好,您能否提供ckpt和embedding词表大小的相关信息:

1、每个ckpt大小是多大

2、embedding词表大小是多大

我们怀疑是ckpt太大了,多张卡同时load可能导致oom,以下是我们的修改及优化建议:

1. 如果训练前后的卡数不变,不需要调用load_distributed_checkpoint接口,直接调用load_checkpoint接口即可。即,每张卡只load自己的ckpt,不需要load所有的ckpt

2. 如果训练卡数发生变化,例如从128卡到64卡

    a. 没有开启优化器并行,那么每mp份是一份完整的模型,例如mp=8,每8个ckpt是一个完整的模型,每卡只需要调用load_distributed_checkpoint进来8个ckpt。或者每卡调用load_checkpoint装载一个ckpt。

    b. 开启了优化器并行,那么所有的ckpt才是一份完整的模型,需要调用load_distributed_checkpoint接口,并且对ckpt进行瘦身

3. 对ckpt进行瘦身。因为embedding默认是数据并行,假设其本身占用2GB,那么128卡就会占用256GB。一台机器8卡同时去load,就会瞬时需要256GB*8大小的内存。

    a. 瘦身的过程:

        i. 将每个ckpt中的embedding单独存放,即先删除所有ckpt中的embedding变量。将其单独存放一份

        ii. 如果不需要优化器转台,也可以现将所有ckpt中的优化器变量进行删除

你可能感兴趣的:(技术博客,分布式,自然语言处理,人工智能)