As of 2023, large language models have grown to hundreds of billions of parameters and are approaching the trillion scale (Google's PaLM 2, for example, is reported at roughly 340 billion parameters), while a single GPU offers only limited memory and compute (an NVIDIA A100 provides 80 GB of HBM and 312 TFLOPS of FP16 tensor throughput). Distributed training closes this gap by combining parallelism strategies along several dimensions.
This article examines the mathematical foundations of the three major parallelism strategies: data parallelism, pipeline parallelism, and tensor parallelism.
Definition: with $K$ workers, worker $k$ computes the local gradient $g_k = \nabla_\theta L_k(\theta)$ on its data shard; with learning rate $\eta$, the synchronous update rule is

$$\theta_{t+1} = \theta_t - \eta \cdot \frac{1}{K}\sum_{k=1}^K g_k^{(t)}$$
Convergence condition (following [Li et al., 2014]): assume the loss $L$ is $L$-smooth and $\mu$-strongly convex. When the learning rate satisfies $\eta < \frac{1}{L}$, the iterate error is bounded by

$$\mathbb{E}\big[\|\theta_t - \theta^*\|^2\big] \leq (1 - \eta\mu)^t \,\|\theta_0 - \theta^*\|^2 + \frac{\eta\sigma^2}{K\mu}$$

where $\mu$ is the strong-convexity constant and $\sigma^2$ is the per-worker gradient variance; averaging over $K$ workers divides the variance term by $K$.
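The update rule above maps directly onto an AllReduce of local gradients. Below is a minimal sketch of one synchronous data-parallel SGD step with `torch.distributed`; it assumes the NCCL process group has already been initialized and that the model and batch live on the local GPU. The function name `data_parallel_step` is illustrative, not a library API.

```python
import torch
import torch.distributed as dist

def data_parallel_step(model, loss_fn, inputs, targets, lr=1e-3):
    """One synchronous data-parallel SGD step: average gradients over K workers, then update."""
    loss = loss_fn(model(inputs), targets)
    loss.backward()                                        # local gradient g_k
    world_size = dist.get_world_size()                     # K
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is None:
                continue
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)  # sum_k g_k across workers
            p.grad.div_(world_size)                        # (1/K) * sum_k g_k
            p.add_(p.grad, alpha=-lr)                      # theta <- theta - eta * averaged gradient
    model.zero_grad()
    return loss.detach()
```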
For $N$ devices and a parameter vector $v \in \mathbb{R}^d$, ring AllReduce has a per-device communication volume of roughly $2d\,(N-1)/N$ elements, so the cost is nearly independent of $N$ for large rings.
Key code path: torch/nn/parallel/distributed.py (simplified excerpt of the synchronization logic; the actual implementation differs across PyTorch versions):

class DistributedDataParallel(Module):
    def _sync_params(self):
        # Broadcast parameters so all replicas start from identical weights (coalesced into flat buffers)
        _broadcast_coalesced(self.device_ids[0], coalesced)

    def _collect_gradients(self):
        # Aggregate gradients bucket by bucket
        for bucket in self._grad_acc_buckets:
            grads = [param.grad for param in bucket]
            # Flatten the bucket into one contiguous tensor to amortize launch overhead
            coalesced = _flatten_dense_tensors(grads)
            # Sum-reduce across all data-parallel ranks
            dist.all_reduce(coalesced, group=self.process_group)
Gradient bucketing optimization: DDP fuses gradients into buckets and launches one AllReduce per bucket; the bucket size is tuned via the `bucket_cap_mb` constructor argument (default 25 MB), as shown in the sketch below.
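A minimal sketch of wrapping a model in DDP with a custom bucket size, assuming a torchrun-style launcher that sets `LOCAL_RANK`; `build_model()` is a placeholder for your own model constructor.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")              # one process per GPU
local_rank = int(os.environ.get("LOCAL_RANK", 0))    # set by torchrun-style launchers
torch.cuda.set_device(local_rank)

model = build_model().cuda(local_rank)               # build_model() is a placeholder
ddp_model = DDP(
    model,
    device_ids=[local_rank],
    bucket_cap_mb=50,    # larger buckets: fewer, bigger AllReduce calls per backward pass
)
```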
FP16 communication optimization with automatic mixed precision:
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()

optimizer.zero_grad()
with autocast():                      # forward pass runs in mixed precision
    output = model(input)
    loss = loss_fn(output, target)
scaler.scale(loss).backward()         # scale the loss so FP16 gradients do not underflow
scaler.step(optimizer)                # unscales gradients; skips the step if inf/nan appear
scaler.update()                       # adapt the loss scale for the next iteration
Memory and bandwidth savings: storing and communicating gradients in FP16 halves both gradient memory and AllReduce traffic relative to FP32 (2 bytes per value instead of 4), e.g. roughly 2 GB instead of 4 GB of gradient traffic per step for a 1B-parameter model.
Suppose the network has $L$ layers split across $K$ devices, with device $k$ holding a contiguous block of layers:

$$\text{Device}_k:\; f_{s_k}^{e_k}(x) = f_{e_k} \circ \dots \circ f_{s_k}(x), \qquad s_k = (k-1)\cdot L/K + 1,\quad e_k = k\cdot L/K$$
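A minimal sketch of this contiguous partition for an `nn.Sequential` model; the helper name `partition_layers`, the assumption that $L$ is divisible by $K$, and the availability of one CUDA device per stage are illustrative, not taken from the original text.

```python
import torch.nn as nn

def partition_layers(layers: nn.Sequential, num_devices: int):
    """Split L layers into K contiguous stages, stage k holding layers [s_k, e_k)."""
    L = len(layers)
    assert L % num_devices == 0, "sketch assumes L is divisible by K"
    per_stage = L // num_devices
    stages = []
    for k in range(num_devices):
        s_k, e_k = k * per_stage, (k + 1) * per_stage   # 0-based half-open range
        stage = nn.Sequential(*list(layers)[s_k:e_k]).to(f"cuda:{k}")
        stages.append(stage)
    return stages
```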
Taking a GEMM as an example, the weight matrix $W \in \mathbb{R}^{m \times n}$ is split column-wise across two devices:

$$W = [\,W_1 \mid W_2\,], \qquad W_1, W_2 \in \mathbb{R}^{m \times n/2}$$

Forward pass, with the input $X$ split row-wise to match:

$$Y = WX = [\,W_1 \; W_2\,]\begin{bmatrix} X_1 \\ X_2 \end{bmatrix} = W_1X_1 + W_2X_2$$

Each device computes one partial product, and a cross-device AllReduce sums the partial results into the full output $Y$.
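A toy, single-process check of the identity above (no GPUs or process groups involved); in a real tensor-parallel setup each term $W_kX_k$ lives on a different device and the final sum is exactly the AllReduce. The matrix sizes are arbitrary examples.

```python
import torch

m, n, b = 4, 6, 3
W = torch.randn(m, n)
X = torch.randn(n, b)
W1, W2 = W[:, : n // 2], W[:, n // 2 :]   # column blocks of W, one per device
X1, X2 = X[: n // 2], X[n // 2 :]         # matching row blocks of X
partial_sum = W1 @ X1 + W2 @ X2           # what the AllReduce would accumulate
assert torch.allclose(W @ X, partial_sum, atol=1e-5)
```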
Megatron-style parallel self-attention (simplified sketch; `ColumnParallelLinear`/`RowParallelLinear` refer to Megatron-LM's tensor-parallel layers, and per-head reshaping is omitted for brevity):

class ParallelSelfAttention(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        # Q/K/V projections are column-split, so each rank holds a subset of the attention heads
        self.query = ColumnParallelLinear(hidden_size, hidden_size)
        self.key = ColumnParallelLinear(hidden_size, hidden_size)
        self.value = ColumnParallelLinear(hidden_size, hidden_size)
        # The output projection is row-split; its forward ends with an AllReduce across ranks
        self.dense = RowParallelLinear(hidden_size, hidden_size)

    def forward(self, x):
        q = self.query(x)   # column-parallel slice
        k = self.key(x)
        v = self.value(x)
        # Attention is computed locally over the heads owned by this rank, with no communication
        attn_out = scaled_dot_product_attention(q, k, v)
        return self.dense(attn_out)   # row-parallel: partial outputs are summed via AllReduce
Mixture-of-experts routing (conceptual sketch: the `all_to_all` dispatch/combine and the final weighting are schematic; DeepSpeed's `deepspeed.moe` package provides a production implementation):

class MoEBlock(nn.Module):
    def __init__(self, experts, gate):
        super().__init__()
        self.experts = nn.ModuleList(experts)   # expert sub-networks, sharded across devices
        self.gate = gate                        # gating network: one logit per expert

    def forward(self, x):
        logits = self.gate(x)
        # Each token is routed to its top-2 experts; the selected logits become mixing weights
        weights, indices = torch.topk(logits, k=2, dim=-1)
        weights = torch.softmax(weights, dim=-1)
        # Cross-device routing (schematic): all_to_all ships tokens to the devices holding their experts
        expert_inputs = all_to_all(x, indices)
        expert_outputs = [self.experts[i](expert_inputs[i]) for i in range(len(self.experts))]
        # Reverse all_to_all returns expert outputs to each token's home device
        outputs = all_to_all(expert_outputs, indices)
        return outputs * weights                # schematic: weighted mix of the selected experts
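To make the gating step concrete, here is a toy, single-device example of top-2 expert selection with `torch.topk` and softmax-normalized mixing weights; the tensor shapes are illustrative only.

```python
import torch
import torch.nn.functional as F

tokens, num_experts = 4, 8
logits = torch.randn(tokens, num_experts)          # gate(x): one logit per expert per token
weights, indices = torch.topk(logits, k=2, dim=-1)
weights = F.softmax(weights, dim=-1)               # normalize over the 2 selected experts
print(indices)   # which experts each token is routed to
print(weights)   # how much each selected expert contributes
```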
Let $S$ be the number of pipeline stages, $M$ the number of micro-batches, and $T_s$ the compute time of stage $s$:

Total time: $T_{\text{total}} = (S + M - 1) \cdot \max_s T_s$

Bubble ratio: $\text{Bubble Ratio} = \dfrac{S-1}{S + M - 1}$

With $M = 4S$, the bubble ratio is $\frac{S-1}{5S-1}$, i.e. just under 20% (approaching 20% as $S$ grows); see the check below.
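A small numeric check of the bubble formula for a few illustrative $(S, M)$ configurations (the specific values are examples, not from the original text).

```python
def bubble_ratio(stages: int, microbatches: int) -> float:
    """Idle ("bubble") fraction of a GPipe-style schedule: (S-1)/(S+M-1)."""
    return (stages - 1) / (stages + microbatches - 1)

for S, M in [(4, 16), (8, 32), (12, 48)]:   # M = 4S in each case
    print(f"S={S:2d} M={M:3d} bubble={bubble_ratio(S, M):.1%}")
# S= 4 M= 16 bubble=15.8%
# S= 8 M= 32 bubble=17.9%
# S=12 M= 48 bubble=18.6%
```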
Activation memory of a traditional (GPipe-style) pipeline, which keeps all $M$ micro-batches in flight: roughly $O(S \cdot M)$
With 1F1B scheduling, which alternates one forward and one backward per stage to bound in-flight micro-batches: roughly $O(S + M)$
Key code path: torchgpipe/pipeline.py. A simplified clock-cycle view of the GPipe schedule (`compute_forward`/`compute_backward` stand in for the actual stage execution):

def pipeline(devices, chunks):
    num_stages = len(devices)
    # Forward phase: at clock tick t, stage s works on micro-batch t - s
    for clock in range(chunks + num_stages - 1):
        for stage in range(num_stages):
            micro_batch = clock - stage
            if 0 <= micro_batch < chunks:
                compute_forward(devices[stage], micro_batch)
    # Backward phase: stages drain in reverse order once all forwards are done
    for clock in range(chunks + num_stages - 1):
        for stage in reversed(range(num_stages)):
            micro_batch = clock - (num_stages - 1 - stage)
            if 0 <= micro_batch < chunks:
                compute_backward(devices[stage], micro_batch)
A dynamic scheduler variant (conceptual sketch; `all_stages_idle`, `advance_clock`, and the stage interface are placeholders):

from collections import deque

class DynamicPipelineScheduler:
    def __init__(self, stages, max_microbatches=8):
        self.stages = stages
        self.max_microbatches = max_microbatches
        self.ready_queue = deque()      # micro-batches waiting to enter the pipeline

    def schedule(self):
        # Greedily feed any stage that has a free slot, one clock tick at a time
        while not all_stages_idle():
            for stage in self.stages:
                if stage.can_accept_microbatch() and self.ready_queue:
                    mb = self.ready_queue.popleft()
                    stage.submit(mb)
            advance_clock()
Example hybrid-parallel configuration:

| Parallelism dimension | Configuration | Communication pattern |
| --- | --- | --- |
| Data parallelism | DP=8 | AllReduce |
| Pipeline parallelism | PP=12 | Point-to-point (P2P) |
| Tensor parallelism | TP=8 | AllReduce/AllGather |
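These three degrees multiply into the total device count; the snippet below only spells out that arithmetic. The mapping to launcher flags noted in the comments (Megatron-LM style `--tensor-model-parallel-size` / `--pipeline-model-parallel-size`) is an assumption for illustration, not quoted from the original.

```python
# World size implied by the 3D-parallel layout above
dp, pp, tp = 8, 12, 8
world_size = dp * pp * tp
print(world_size)   # 768 GPUs: 8-way data parallel, 12 pipeline stages, 8-way tensor parallel
# In a Megatron-LM-style launcher this would roughly correspond to
#   --tensor-model-parallel-size 8 --pipeline-model-parallel-size 12
# with the data-parallel degree inferred from the total world size.
```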
NCCL topology-aware configuration example (shell environment variables):

# Choose the NCCL algorithm and the InfiniBand interfaces to use
export NCCL_ALGO=Tree
export NCCL_SOCKET_IFNAME=ib0
export NCCL_IB_HCA=mlx5
Bandwidth measurement: effective AllReduce bandwidth is usually measured with NCCL's nccl-tests (`all_reduce_perf`) or by timing `torch.distributed.all_reduce` on a large tensor (for example, 1 GB), as in the sketch below.
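A minimal timing sketch, assuming the NCCL process group is already initialized and there is one GPU per process; the function name and the 1 GB payload size are illustrative.

```python
import time
import torch
import torch.distributed as dist

def allreduce_bandwidth_gbps(numel: int = 256 * 1024 * 1024, iters: int = 20) -> float:
    """Time AllReduce on a float32 tensor (256M elements = 1 GB) and report effective bandwidth in GB/s."""
    tensor = torch.ones(numel, dtype=torch.float32, device="cuda")
    for _ in range(5):                     # warm-up iterations
        dist.all_reduce(tensor)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        dist.all_reduce(tensor)
    torch.cuda.synchronize()
    seconds = (time.perf_counter() - start) / iters
    # Ring AllReduce moves roughly 2*(N-1)/N of the payload per device
    n = dist.get_world_size()
    moved_bytes = tensor.numel() * tensor.element_size() * 2 * (n - 1) / n
    return moved_bytes / seconds / 1e9
```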
import torch

def compute_comm_ratio(profiler_data):
    # Ratio of GPU compute time to NCCL communication time; > 1 means the step is compute-bound
    compute_time = profiler_data['cuda_kernel_time']   # aggregated non-NCCL CUDA kernel time
    comm_time = profiler_data['nccl_time']             # aggregated NCCL kernel time
    return compute_time / comm_time

# Collect kernel-level timings with the PyTorch Profiler
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
) as prof:
    for _ in range(5):      # wait(1) + warmup(1) + active(3) profiler steps
        train_step()
        prof.step()         # advance the profiler schedule each iteration
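One way to obtain the two inputs of `compute_comm_ratio` is to aggregate the profiler's kernel averages, treating kernels whose names contain "nccl" as communication. This is a hedged sketch building on the profiler run above (`prof`); the event attribute names may differ slightly across PyTorch versions.

```python
# Aggregate GPU kernel time into "communication" (NCCL kernels) vs "compute"
events = prof.key_averages()
nccl_time = sum(e.cuda_time_total for e in events if "nccl" in e.key.lower())
total_gpu_time = sum(e.cuda_time_total for e in events)
ratio = compute_comm_ratio({
    "cuda_kernel_time": total_gpu_time - nccl_time,
    "nccl_time": max(nccl_time, 1),   # avoid division by zero in single-GPU runs
})
print(f"compute/communication ratio: {ratio:.2f}")
```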
Alpa framework: Alpa automates the search for such hybrid plans by splitting it into an inter-operator pass, which partitions the model into pipeline stages, and an intra-operator pass, which chooses tensor/data-parallel sharding within each stage.
Parameter-server architecture (conceptual sketch; `compute_grad`, the learning rate, and the gradient transport are schematic):

class ParameterServer:
    def __init__(self, model, lr=0.01):
        self.weights = model.state_dict()   # the server owns the authoritative copy of the weights
        self.lr = lr

    def update(self, grad):
        # Apply an SGD update with gradients received from a worker (possibly asynchronously)
        with torch.no_grad():
            for k in grad:
                self.weights[k] -= self.lr * grad[k]

class Worker:
    def step(self):
        grad = compute_grad()               # local forward/backward pass
        # Ship the gradients to the parameter-server process (rank 0), tensor by tensor
        for g in grad.values():
            dist.send(g, dst=0)
A speculative direction: gradient synchronization based on quantum entanglement, which would map the averaged update onto an entangled multi-node state:

$$\Delta \theta = \frac{1}{K} \sum_{k=1}^K \Delta \theta_k \;\rightarrow\; \bigotimes_{k=1}^K |\Delta \theta_k\rangle$$

Reported status: IBM Quantum is said to have demonstrated 4-node parameter synchronization (attributed to a 2023 Nature paper); this line of work remains exploratory.