PyTorch Lightning入门教程(二)

文章目录

  • PyTorch Lightning入门教程(二)
    • 前言
    • 单机多卡
    • 多机多卡
    • 半精度训练

PyTorch Lightning入门教程(二)

前言

pytorch lightning提供了比较方便的多GPU训练方式,同时包括多种策略和拓展库,比如ddp,fairscale等,下面将从单机多卡和多机多卡两个角度介绍。

单机多卡

pytorch lightning的官网提供了比较详细的使用方法,可以参考https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu.html

一般来说,只要在trainer的参数中指定了参数gpus,就可以使用多GPU运行了,例如:

Trainer(gpus=4)  # 使用4块显卡进行计算
Trainer(gpus=[0, 2]) # 使用0和2号显卡进行计算

当然这里支持多种写法来加载GPU,这里有说明https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_basic.html,可以参考。

需要注意的是这有一个strategy参数,可以参考https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html,pytorch lightning支持多种训练框架,包括了dp,dpp,horovod,bagua等。所以这里别忘记设定好所使用的多GPU框架。

多机多卡

这部分和单机多卡的区别不大,只是增加了一个参数num_nodes,来设定所使用的机器数量,例如:

Trainer(gpus=4, num_nodes=3) # 表示使用3台机器,共12张显卡 

之后需要注意的是启动进程,可以参考https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html

针对上面那个例子,如果想要启动训练,需要在三台机器上分别输入下面命令:

master_ip=192.168.1.10
机器1:
CUDA_VISIBLE_DEVICES="0,1,2,3" MASTER_ADDR=$master_ip MASTER_PORT=45321 WORLD_SIZE=3 NODE_RANK=0 LOCAL_RANK=0 python main.py
机器2:
CUDA_VISIBLE_DEVICES="0,1,2,3" MASTER_ADDR=$master_ip MASTER_PORT=45321 WORLD_SIZE=3 NODE_RANK=1 LOCAL_RANK=0 python main.py
机器3:
CUDA_VISIBLE_DEVICES="0,1,2,3" MASTER_ADDR=$master_ip MASTER_PORT=45321 WORLD_SIZE=3 NODE_RANK=2 LOCAL_RANK=0 python main.py

这样就可以在3台机器上进行训练了

如果训练的时候报错或者卡死,无法运行,可以试试在上面的命令中加上NCCL_IB_DISABLE=1,将NCCL_IB_DISABLE设置为1来禁止使用IB/RoCE传输方式,转而使用IP传输,对于不支持RDMA技术的服务器,这个值设置为1可以解决部分训练卡死的问题。如果网络接口不能被自动发现,则手工设置NCCL_SOCKET_IFNAME=eth0,如果还有问题,就设置NCCL的debug模式NCCL_DEBUG=INFO

注:需要注意的是,由于目前pytorch lightning还在开发发展中,很多新的功能只有新版本的才有,所以需要注意自己的pytorch lightning的版本,比如本文中提到的strategy策略,笔者在使用的时候,发现对于1.5.10的pytorch lightning版本,使用fairscale策略,无法在多机多卡的情况下使用,后来升级到1.7.4版本之后,才可以正常使用。

半精度训练

这里设置比较简单,只需要新加一个参数precision即可:

Trainer(gpus=4, num_nodes=3, precision=16)

但是需要注意的是并不是所有的情况下都支持半精度训练,比如DataParallel就不支持半精度,而DistributedDataParallel是可以支持的,所以平时多机多卡训练就不要用DataParallel。

你可能感兴趣的:(深度学习,python,pytorch,深度学习,人工智能)