apex is a PyTorch extension from NVIDIA that accelerates multi-GPU training. To download it:
git clone https://github.com/NVIDIA/apex
cd apex
# Installation ran into pip._internal.operations.install.legacy.LegacyInstallFailure, so check out an older commit instead
git checkout f3a960f80244cf9e80558ab30f7f7e8cbf03c0a0
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
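Once the build finishes, a quick sanity check is to run a toy forward/backward through apex's amp; a minimal sketch (the model, optimizer, and opt_level below are placeholders, not from the original setup):
import torch
from apex import amp

# hypothetical toy model and optimizer, only to confirm the extensions load
model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# patch model/optimizer for mixed precision; "O1" is the usual starting point
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

loss = model(torch.randn(4, 10).cuda()).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()   # gradients are scaled to avoid fp16 underflow
optimizer.step()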
Initialize the process and join it with the others. This call is blocking: no process continues until every process has joined.
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as DDP
from apex import amp
dist.init_process_group(
    backend='nccl',
    init_method='env://',
    world_size=args.world_size,
    rank=rank
)
init_method tells the process group where to look for the rendezvous settings. In this case it reads the environment variables MASTER_ADDR and MASTER_PORT.
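If you would rather not go through environment variables, init_method can also take an explicit rendezvous address; a hedged sketch (the address and port below are made up):
# alternative to env://: pass the master address directly
dist.init_process_group(
    backend='nccl',
    init_method='tcp://127.0.0.1:8888',  # hypothetical host:port
    world_size=args.world_size,
    rank=rank
)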
A related pattern, spawning the worker processes yourself:
import os
import torch.multiprocessing as mp

def train(rank, config):
    # rank is filled in automatically by mp.spawn, followed by the extra args
    pass

def main():
    config = {}  # placeholder training configuration
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '8888'
    mp.spawn(train, nprocs=4, args=(config,))

if __name__ == "__main__":
    main()
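Putting the two snippets together, the train() handed to mp.spawn would typically use its rank to join the process group, pin a GPU, and wrap the model with apex's amp and DDP; a rough sketch with placeholder model, optimizer, and world size:
import torch
import torch.distributed as dist
from apex import amp
from apex.parallel import DistributedDataParallel as DDP

def train(rank, config):
    world_size = 4  # must match nprocs passed to mp.spawn
    dist.init_process_group(backend='nccl', init_method='env://',
                            world_size=world_size, rank=rank)
    torch.cuda.set_device(rank)

    # placeholder model/optimizer; substitute your own
    model = torch.nn.Linear(10, 10).cuda(rank)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # initialize amp first, then wrap with apex DDP, which all-reduces gradients
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    model = DDP(model)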
To launch with torch.distributed.launch instead, use the following command:
python -m torch.distributed.launch --nproc_per_node=4 train.py --a b --c d
Here python -m torch.distributed.launch runs torch/distributed/launch.py to drive the distributed training; -m tells Python to load torch.distributed.launch as a module. --nproc_per_node should usually equal the number of GPUs. train.py is the actual training script, followed by its own arguments.
During execution this is effectively equivalent to automatically running several copies of the same command in the terminal:
python -u train.py --local_rank=3 --a b --c d
except that each copy is assigned a different local_rank, from 0 to gpus-1.
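Inside train.py, that --local_rank argument is normally read with argparse and used to pin the process to its GPU before joining the process group; a minimal sketch:
import argparse
import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)  # supplied by torch.distributed.launch
args, _ = parser.parse_known_args()  # tolerate the script's own extra arguments

torch.cuda.set_device(args.local_rank)
# MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are already set by the launcher
dist.init_process_group(backend='nccl', init_method='env://')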
def create_dataloader():
    if rank == 0:
        dataset = LoadImagesAndLabels()  # only the main process reads and processes the data
    torch.distributed.barrier()          # every process waits here until all have arrived
If rank is not 0, the process calling create_dataloader() is not the main process, so it goes straight to barrier(), which sets up a blocking barrier and makes the process wait until every process has reached it (including the main process after it finishes processing the data). If the calling process is the main process, it first reads and processes the data and then also reaches torch.distributed.barrier(). At that point all processes are at the barrier, so they are synchronized and released together.
To write a loss function that aggregates statistics across GPUs, you can use:
import torch
from torch import distributed
from torch import autograd
from torch.nn.parallel import DistributedDataParallel as DDP

class awesome_allgather_function(autograd.Function):
    @staticmethod
    def forward(ctx, input):
        world_size = distributed.get_world_size()
        # create a destination list for the allgather, one slot per rank
        allgather_list = [torch.empty_like(input) for _ in range(world_size)]
        # input lives on each GPU separately; collect every rank's tensor into allgather_list
        distributed.all_gather(allgather_list, input)
        return torch.cat(allgather_list, dim=0)

    @staticmethod
    def backward(ctx, grad_output):  # distribute the gradient back to each GPU
        grads_per_rank = grad_output.shape[0] // distributed.get_world_size()
        rank = distributed.get_rank()
        # We receive gradients for the entire concatenated forward output, so to mimic
        # DataParallel, return only the slice that corresponds to this process's input:
        sl = slice(rank * grads_per_rank, (rank + 1) * grads_per_rank)
        return grad_output[sl]
nominator = awesome_allgather_function.apply(nominator)
denominator = awesome_allgather_function.apply(denominator)
nominator = nominator.sum(0)
denominator = denominator.sum(0)
l = loss_fn(nominator, denominator)
l.backward()
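As an assumed example of what loss_fn could look like here, nominator and denominator might be the intersection and union terms of a soft Dice loss, already summed over all GPUs by the code above; a hedged sketch:
def loss_fn(nominator, denominator, smooth=1e-5):
    # soft Dice computed from globally aggregated statistics
    dice = (2 * nominator + smooth) / (denominator + smooth)
    return 1 - dice.mean()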
When printing and saving checkpoints, do it only on one process:
def print_msg(msg):
    # avoid shadowing the builtin print; only the first process logs
    if local_rank == 0:
        print(msg)

def save_checkpoint(state, path):
    if local_rank == 0:
        torch.save(state, path)
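Note that local_rank is only unique within one machine; in a multi-node job a safer guard (an assumption beyond the snippet above) is the global rank:
import torch.distributed as dist

def is_main_process():
    # global rank 0 across all nodes, not just local_rank 0 on each machine
    return (not dist.is_initialized()) or dist.get_rank() == 0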
For a DDP program, Ctrl+C alone does not kill the background processes. The following command kills them all:
kill $(ps aux | grep train.py | grep -v grep | awk '{print $2}')