详解PyTorch FSDP数据并行(Fully Sharded Data Parallel)

1. 背景介绍

全切片数据并行(Fully Sharded Data Parallel,简称为FSDP)是数据并行的一种新的方式,FSDP最早是在2021年在FairScale-FSDP中提出的,后来合入了PyTorch 1.11版本中。微软之前Deepspeed框架中提出过三种级别的ZERO算法,FSDP可以看成是ZERO-3的实现。

2. 详细介绍

传统的数据并行(DDP)是在每一个GPU卡上保存整个model的参数/梯度/优化器状态, 然后对数据集切分为 N N N 个shard分片给不同的GPU进行训练,计算完梯度后通过all-reduce通信来做梯度的融合。如下图:
详解PyTorch FSDP数据并行(Fully Sharded Data Parallel)_第1张图片

在FSDP中的主要思路是想办法把model的梯度/优化器状态/参数都进行切分操作,每个GPU只存部分的参数信息,也就是在ZERO-3的思路。为了能把所有的参数进行分片处理,核心在于要把DDP中的all-reduce操作拆解为reduce-scatter和all-gather 操作。

详解PyTorch FSDP数据并行(Fully Sharded Data Parallel)_第2张图片

如下图,在进行FSDP前向计算其中的一层Layer时,由于每个GPU都只保存了部分参数,所以需要先通过all-gather操作获得全部的参数;同理,在反向计算过程中,也需要通过all-gather操作,获得全部的参数;最后计算出来的梯度只是部分的结果,需要通过reduce-scatter通信进行累加操作,最终每个GPU卡分别只更新自己那部分参数(也就是local本地weight更新)。

详解PyTorch FSDP数据并行(Fully Sharded Data Parallel)_第3张图片

FSDP的应用是对原有model layers加上了一层wrapper封装,只有在FSDP实例中的layer才会在前向和后向过程中执行gather相关操作,通过切分可以利用相同的显存大小训练更大的模型。为了进一步提升显存利用率,FSDP也支持把不活跃的实例全部offload调出到CPU上去。

FSDP计算过程的伪码如下:

FSDP forward pass:
    for layer_i in layers:
        all-gather full weights for layer_i
        forward pass for layer_i
        discard full weights for layer_i

FSDP backward pass:
    for layer_i in layers:
        all-gather full weights for layer_i
        backward pass for layer_i
        discard full weights for layer_i
        reduce-scatter gradients for layer_i

在PyTorch中的示例如下, 通过FullyShardedDataParallel实现对model的封装,通过CPUOffload来决定采用哪种策略把参数调到CPU上。

from torch.distributed.fsdp import (
   FullyShardedDataParallel,
   CPUOffload,
)
from torch.distributed.fsdp.wrap import (
   default_auto_wrap_policy,
)
import torch.nn as nn
 
class model(nn.Module):
   def __init__(self):
       super().__init__()
       self.layer1 = nn.Linear(8, 4)
       self.layer2 = nn.Linear(4, 16)
       self.layer3 = nn.Linear(16, 4)
 
model = DistributedDataParallel(model())
fsdp_model = FullyShardedDataParallel(
   model(),
   fsdp_auto_wrap_policy=default_auto_wrap_policy,
   cpu_offload=CPUOffload(offload_params=True),
)

使用FSDP训练GPT-175B和GPT-1T参数量大小的模型,词表大小50K,fp16的精度和使用SGD的优化器。

在这里插入图片描述

结果如下,使用FSDP时在GPU卡数增大的情况下,对GPU单卡的吞叶没有影响;在A100-40G机器下增大batch_size 但吞吐没有增加, 瓶颈不在于通信而是CUDA cache的分配到了瓶颈;当换为A100-80G机器时,CUDA cache的分配问题得到解决后,增大batch_size后吞吐进一步增加。

详解PyTorch FSDP数据并行(Fully Sharded Data Parallel)_第4张图片

3. 参考

  • Fully Sharded Data Parallel: faster AI training with fewer GPUs
  • nccl-collectives
  • Introducing PyTorch Fully Sharded Data Parallel (FSDP) API
  • PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel

你可能感兴趣的:(训练框架,大模型,pytorch,人工智能,python)