hfai.pl
Pytorch Lightning(简称 pl) 是在 PyTorch 基础上进行封装的库,它能帮助开发者脱离 PyTorch 一些繁琐的细节,专注于核心代码的构建,在 PyTorch 社区中备受欢迎。hfai.pl 是 high-flyer 对 pl 的进一步封装,能更加轻松的适配各种集群特性,带来更好的使用体验。本文将为大家详细介绍优化细节。
若集群每个计算节点有 x 张 GPU。用户提交任务时需选定节点数量 N,则该任务可获得 N*x 个 GPU。每个进程中全局环境变量含义如下:
world_size (hfai): 节点数量,用 N 表示
rank (hfai): 节点 id,用 n 表示,n 属于 0 ~ N-1
local_rank (hfai): GPU id, 用 k 表示,k 属于 0 ~ x-1
与 PyTorch init_process_group 的变量有如下对应:
world_size (PyTorch): 进程数量,每个 GPU 对应一个进程,因此总进程数目为总 GPU 数目,计算方法为 N*x
rank (PyTorch): 进程 id,利用节点 id 和 GPU id 可以计算出进程的 id,计算方法为 n*x+k
在一般的 PyTorch 代码中,我们需要进行如下的分布式初始化:
ip = os.environ.get("MASTER_ADDR", "127.0.0.1")
port = os.environ.get("MASTER_PORT", "2223")
hosts = int(os.environ.get("WORLD_SIZE", 1)) # number of nodes
rank = int(os.environ.get("RANK", 0)) # node id
gpus = torch.cuda.device_count() # gpus per node
dist.init_process_group(
backend="nccl",
init_method=f"tcp://{ip}:{port}",
world_size=hosts*gpus,
rank=rank*gpus+local_rank
)
torch.cuda.set_device(local_rank)
pl 对分布式初始化进行了封装,用户不需要自己初始化,只需指定节点数量和 GPU 数量即可。在萤火集群上,我们还提供了 HFAIEnvironment 环境插件,只需要几行代码即可将 pl 的分布式参数无缝适配到萤火集群当中。
from hfai.pl import HFAIEnvironment
trainer = pytorch_lightning.Trainer(
gpus=x, num_nodes=N,
plugins=[HFAIEnvironment()] # 定义 HFai 环境并作为插件输入
)
为了规避跨 numa 传输带来的带宽损耗,建议对训练代码进行 numa 绑定。在一般的 PyTorch 代码中,我们通过使用 hfai.multiprocessing 启动多进程,并指定 bind_numa=True 来进行 numa 的绑定,代码如下:
import torch
import hfai
def main(gpu_id):
torch.cuda.set_device(gpu_id)
# ......
if __name__ == "__main__":
ngpus = hfai.utils.num_gpus() # 调用 cuda 函数会导致子进程产生错误
hfai.multiprocessing.fork(main, args=(), nprocs=ngpus, bind_numa=True)
pl 中对多进程的启动进行了封装,用户只需要在 strategy 中指定使用 ddp 或者 ddp_spawn 就可以指定多进程的启动方式。因此我们提供了两种新的 strategy:ddp_bind_numa 和 ddp_spawn_bind_numa,用户选择这两种 strategy,就可以在指定多进程启动方式的同时绑定 numa,使用方式如下所示:
# 使用 ddp_bind_numa 或者 ddp_spawn_bind_numa
trainer = pytorch_lightning.Trainer(strategy="ddp_spawn_bind_numa")
hfreduce 是幻方 AI 自研的高性能多卡并行通信工具,其能够更高效的在多显卡之间交换梯度信息,加速模型训练。
在一般的 PyTorch 代码中,我们通过使用 hfai.nn.parallel.DistributedDataParallel 替换 PyTorch 自带的 torch.nn.parallel.DistributedDataParallel 即可使用 hfreduce。
由于 pl 中对 DDP 的初始化进行了封装,因此我们提供了 hfreduce_bind_numa 和 hfreduce_spawn_bind_numa 两种 strategy。用户在初始化 Trainer 的时候进行如下指定即可。
# 使用 hfreduce_bind_numa 或者 hfreduce_spawn_bind_numa
trainer = pl.Trainer(strategy="hfreduce_bind_numa")
幻方萤火平台采用任务级分时调度的底层设计,给每个任务分配集群运行时间片。按照分时调度的方案规则,训练任务会被中断、并自动调起。因此在该平台上运行的代码需要在训练循环中进行打断信号的捕获和参数保存,以便之后恢复训练。断点保存的示例代码如下:
for epoch in range(epochs):
for step in range(len(data_batch)):
if hfai.distributed.get_rank() == 0 and gpu_id == 0: # 在0号节点的0号进程上接收集群调度信息
if hfai.client.receive_suspend_command():
model.save() # 保存模型、迭代器等参数到文件
time.sleep(5) # 最多预留5秒完成断点保存,之后会被强制打断
hfai.client.go_suspend() # 发送准备好被打断的信号
pl 对训练循环进行了封装,因此我们提供了 ModelCheckpointHF 作为回调,在收到打断信号的 step 进行模型的保存,在每次任务启动时,检测是否需要从上一个断点恢复训练,使用方法如下所示:
from hfai.pl import ModelCheckpointHF
output_dir = 'hfai_out'
cb = ModelCheckpointHF(dirpath=output_dir)
trainer = pytorch_lightning.Trainer(callbacks=[cb]) # 自动处理集群打断信号
ckpt_path = f'{output_dir}/{cb.CHECKPOINT_NAME_SUSPEND}.ckpt' # 检查是否有断点模型被保存
ckpt_path = ckpt_path if os.path.exists(ckpt_path) else None
trainer.fit(
model_module,
ckpt_path=hfai_suspend_ckpt_path # 自动恢复训练
)
幻方 AI 依托萤火二号集群,对 PyTorch 框架进行了深度优化,结合萤火集群的特点,对一些常用的 AI 算子重新研发,提升效率,进一步提升了模型整体的训练效率。
在 pl 框架中,我们只需要增加如下一行代码,就可以将 PyTorch 中的算子转换成 hfai 优化后的算子:
import pytorch_lightning as pl
class ToyNetModule(pl.LightningModule):
# ...
pl_module = ToyNetModule()
model_module = nn_to_hfai(pl_module) # 将算子转换为 hfai 算子
为了在 PyTorch Lightning 中使用 ffrecord 的 Dataloader,我们需要在 Dataloader 设置 skippable=False:
from ffrecord.torch import Dataset, DataLoader
class MyDataset(Dataset)
# ...
dataset = MyDataset(...)
dataloader = DataLoader(dataset, batch_size, num_workers=num_workers, skippable=False)
下面提供一个完整的示例帮助大家理解:
from hfai.pl import HFAIEnvironment
from hfai.pl import ModelCheckpointHF
import pytorch_lightning as pl
class ToyNetModule(pl.LightningModule):
...
output_dir = 'hfai_out' # 模型保存路径
cb = ModelCheckpointHF(dirpath=output_dir) # 可以接收集群打断信号的回调类
trainer = pl.Trainer(
gpus=x, # 每个节点 x 个 GPU
num_nodes=N, # 节点数量
strategy="ddp_bind_numa", # 支持 ddp_bind_numa, ddp_spawn_bind_numa, hfreduce_bind_numa, hfreduce_spawn_bind_numa
plugins=[HFAIEnvironment()], # 自动适配 HFAI 分布式环境
callbacks=[cb] # 自动处理集群打断信号
)
model_module = nn_to_hfai(ToyNetModule()) # 将算子转换为 hfai 算子
ckpt_path = f'{output_dir}/{cb.CHECKPOINT_NAME_SUSPEND}.ckpt' # 判断之前是否保存了断点模型
ckpt_path = ckpt_path if os.path.exists(ckpt_path) else None
trainer.fit(
model_module,
ckpt_path=ckpt_path # 自动恢复训练
)
通过上面的简单适配,PyTorch Lightning 就能够轻松适配不同集群啦。
END
High-Flyer AI
我们希望让更多“想象力”和“创造力”生长。期待与各方科学家及开发者们一同共建AI时代。
幻方 | 技术博客幻方AI专注前沿科技研发,以AI技术激发创造力和想象力,让人类更多梦想变成现实。幻方AI包含「萤火」深度学习训练平台 、幻方量化(使用AI进行投资的对冲基金)、AI基础科学研究。https://www.high-flyer.cn/blog/