模型训练-3D并行

目录

1. 数据并行(Data Parallel)

1.1常规数据并行

 1.3 数据并行带来的显存优化效果

2. 模型并行

2.1 原理

 2.2 模型并行带来的显存优化结果

3. ZeRO

3.1 ZeRO1

 3.2 ZeRO2

 3.3 ZeRO3

 3.4 显存优化结果

4. 流水线并行

目录

1. 数据并行(Data Parallel)

1.1常规数据并行

 1.3 数据并行带来的显存优化效果

2. 模型并行

2.1 原理

 2.2 模型并行带来的显存优化结果

3. ZeRO框架

3.1 ZeRO1

 3.2 ZeRO2

 3.3 ZeRO3

 3.4 显存优化结果

4. 流水线并行

 参考资料:



1. 数据并行(Data Parallel)

1.1常规数据并行

        有一张显卡(例如rank 0显卡)专门用于存储模型参数、梯度信息和更新模型参数。将训练数据分成多份(份数等于显卡数量),每张卡上的模型参数相同,进行前向和反向传播后,每张卡上都计算得到对应部分数据的梯度,然后对多张卡上的梯度进行reduce操作,将平均后的梯度结果存放在专门的显卡上,然后在专门的显卡上利用优化器进行参数更新。最后将更新后的参数再broadcast到所有显卡上,重复上述过程

模型训练-3D并行_第1张图片

 1.2  distributed data parallel(分布式数据并行)

        区别:不需要专门的参数服务器

        初始时每张显卡上都有相同的模型参数,同样将训练数据均分成多份,每张卡上利用单独一小份的数据进行前向和反向得到梯度,然后将多张卡上的梯度参数all reduce到所有的显卡上,这样每张显卡上的梯度信息也是完全一致的,同时优化器的历史信息数据也是完全一致的,这样便可以在每一张显卡上单独进行参数更新,并且能够保证每张卡上更新后的模型参数也是完全一致的。

        总结:每张卡上单独进行参数更新

模型训练-3D并行_第2张图片

 1.3 数据并行带来的显存优化效果

transformer中,显卡上存储的模型中间结果(即是每一层的输入,也可以理解成上一层的输出)的维度是[batch,  Len,  Dim], 多卡数据并行后,每张显卡上存储的模型中间结果的维度变成[batch/卡数,  Len,  Dim]

模型训练-3D并行_第3张图片

2. 模型并行

2.1 原理

将参数矩阵(全连接层)按照行进行切分,使得在每张显卡上分别进行计算(前提:需要保证多张显卡上的输入XB是一样的,即多张显卡采用同一batch的输入数据),然后再将计算结果进行拼接

模型训练-3D并行_第4张图片

模型训练-3D并行_第5张图片

 2.2 模型并行带来的显存优化结果

模型参数、梯度和优化器参数,三大部分的显存都降低为原来的1/显卡数量

中间结果显存没有变化,因此不采用数据并行的话可能显存仍然很大导致在一张显卡上放不下。

3. ZeRO框架

3.1 ZeRO1

是基于数据并行的做法,但与原始的数据并行做法不同,ZeRO第1阶段和数据并行的区别在于

(1)每张卡上前向和反向得到梯度之后,采用Reduce Scatter得到每张显卡上对应部分(1/3)的平均梯度

(2)优化器参数也是采用对应部分的(1/3),然后去更新模型的部分参数(1/3)

(3)采用All gather操作将多张显卡上的模型参数拼接在一起

核心:每张显卡只获取一部分的梯度,只进行一部分的参数更新

模型训练-3D并行_第6张图片

 3.2 ZeRO2

与第一阶段不同,不需要全部反向传播完成之后再更新模型参数(反向传播完成之后再删除gradient),而是修改成:当计算出网络最后一层的梯度后,就采用reduce scatter得到最后一层参数的gradient*,然后进行参数更新,并且可以删去最后一层各张显卡上的gradient,然后对倒数第2层进行同样操作,依此类推。

模型训练-3D并行_第7张图片

 3.3 ZeRO3

与第2阶段相比,本质上是以时间换空间,将模型参数也进行了切分,在需要用到的时候再进行All gather,增加了显卡之间频繁通信时间

模型训练-3D并行_第8张图片

 3.4 显存优化结果

到ZeRO3,四大部分显存占用都降低为显卡数分之一。

模型训练-3D并行_第9张图片

4. 流水线并行

与模型并行有些类似,流水线并行是将模型不同的层分到不同的显卡上。每张显卡上只需要保留对应层的参数、梯度、优化器和中间结果,因此显存占用降低为显卡数分之一。

缺点:每个时刻只有一张显卡计算,其余显卡处于空闲状态

模型训练-3D并行_第10张图片

 参考资料:

多张显卡之间通信方式_佛系调参的博客-CSDN博客

5-4 BMTrain--Model Parallel(模型并行)_哔哩哔哩_bilibili

你可能感兴趣的:(深度学习,自然语言处理,人工智能,语言模型)