训练大型语言模型 (LLM) 需要巨大的计算资源和内存。为了高效地训练这些模型,我们需要采用各种并行策略,将计算和数据分布到多个 GPU 或设备上。Llama 作为当前最流行的开源大模型之一,其训练代码中采用了多种并行技术。本文将深入 Llama 的训练代码,分析其并行训练方案,主要关注参数并行和部分结构参数共享。
常见的并行训练策略包括:
Llama 主要采用了数据并行、张量并行,以及一些结构参数共享的优化。
Llama 使用了 ZeRO (Zero Redundancy Optimizer) 技术,这是一种强大的内存优化方法,它结合了数据并行和模型并行。ZeRO 的核心思想是将模型状态 (权重、梯度和优化器状态) 分片到多个 GPU 上,从而减少每个 GPU 的内存占用。
ZeRO 有三个阶段:
Llama 主要使用了 ZeRO-3,将模型参数、梯度和优化器状态都分片到多个 GPU 上。
在 Llama 的训练代码中, 以 torch.distributed.fsdp
库为例 (Fully Sharded Data Parallel, FSDP),它实现了 ZeRO-3 的功能。
以下是一个简化的 FSDP 参数分片示例:
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
enable_wrap,
wrap,
)
import functools
# 假设我们有一个简单的 Transformer 模型
class TransformerLayer(torch.nn.Module):
def __init__(self, hidden_dim):
super().__init__()
self.linear1 = torch.nn.Linear(hidden_dim, 4 * hidden_dim)
self.linear2 = torch.nn.Linear(4 * hidden_dim, hidden_dim)
def forward(self, x):
x = self.linear1(x)
x = torch.nn.functional.relu(x)
x = self.linear2(x)
return x
# 初始化分布式环境
dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
device = torch.device(f"cuda:{rank}")
# 模型和优化器
hidden_dim = 768
model = TransformerLayer(hidden_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# 使用 FSDP 包装模型
# 使用自动包装策略
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={TransformerLayer,},
)
model = FSDP(model, fsdp_auto_wrap_policy=auto_wrap_policy, device_id=device)
# 模拟训练数据
x = torch.randn(1, 10, hidden_dim, device=device)
# 前向传播
y = model(x)
# 反向传播
loss = y.sum()
loss.backward()
# 优化器更新
optimizer.step()
# 清空梯度
optimizer.zero_grad()
print(f"Rank {rank}: 训练完成")
代码解释:
dist.init_process_group("nccl")
初始化分布式训练环境,使用 NCCL 后端。FullyShardedDataParallel
包装模型。这会将模型参数分片到多个 GPU 上。
transformer_auto_wrap_policy
会自动将模型的每一层都用 FSDP 包装起来。step
方法。运行方式:
你需要使用 torchrun
(或 torch.distributed.launch
)来启动这个脚本,例如:
torchrun --nproc_per_node=4 fsdp_example.py
这将使用 4 个 GPU 来训练模型。
原理说明:
model = FSDP(model, ...)
这一步,模型参数被分片到 4 个 GPU 上。每个 GPU 只存储一部分参数。通过这种方式,FSDP 显著减少了每个 GPU 的内存占用,使得训练更大的模型成为可能。
除了参数分片,Llama 还采用了一些结构参数共享的优化,以进一步减少内存占用和提高训练效率。
例如在 Transformer 的多头注意力 (Multi-Head Attention) 机制中,不同 head 的 query
, key
, value
矩阵的计算通常是独立的。Llama 通过共享 key 和 value 矩阵,减少了参数量和计算量。更具体地说,llama使用了分组注意力机制(Grouped-Query Attention)。
GQA 介于标准的多头注意力 (MHA) 和 Multi-Query Attention (MQA) 之间。
例如:
假设我们有 8 个 head,可以将它们分成 4 个组,每个组 2 个 head。这样,我们就只有 4 个 K 矩阵和 4 个 V 矩阵,而不是 8 个。
代码示例 (简化版)
import torch
import torch.nn as nn
class GroupedQueryAttention(nn.Module):
def __init__(self, embed_dim, num_heads, num_groups):
super().__init__()
self.num_heads = num_heads
self.num_groups = num_groups
self.head_dim = embed_dim // num_heads
self.group_size = num_heads // num_groups
self.q_proj = nn.Linear(embed_dim, embed_dim)
# 共享 K, V 矩阵
self.k_proj = nn.Linear(embed_dim, num_groups * self.head_dim)
self.v_proj = nn.Linear(embed_dim, num_groups * self.head_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, x):
batch_size, seq_len, embed_dim = x.shape
# 计算 Q, K, V
q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
k = self.k_proj(x).view(batch_size, seq_len, self.num_groups, self.head_dim)
v = self.v_proj(x).view(batch_size, seq_len, self.num_groups, self.head_dim)
# 将 head 分组
q = q.view(batch_size, seq_len, self.num_groups, self.group_size, self.head_dim)
# 计算注意力
attn_scores = torch.einsum("bqlgd,bklhd->bqgkh", q, k) / (self.head_dim ** 0.5)
attn_probs = attn_scores.softmax(dim=-1)
attn_output = torch.einsum("bqgkh,bklhd->bqlgd", attn_probs, v)
# 拼接 head
attn_output = attn_output.reshape(batch_size, seq_len, embed_dim)
# 输出投影
output = self.out_proj(attn_output)
return output
# 示例
embed_dim = 768
num_heads = 8
num_groups = 4
model = GroupedQueryAttention(embed_dim, num_heads, num_groups)
x = torch.randn(1, 10, embed_dim)
y = model(x)
print(y.shape)
代码解释:
num_groups
:将 head 分成多少个组。k_proj
和 v_proj
:只输出 num_groups
个 head 的 K 和 V 矩阵。q
分成 num_groups
个组,每个组 group_size
个 head。GQA 的优势:
llama的GQA实现在llama/model.py
文件中,class Attention(nn.Module)
类下的forward
函数中,更具体地,体现在self.num_heads
和self.num_kv_heads
的参数上, 分别控制query
和kv
的head
数量,num_kv_heads
小于num_heads