通俗理解torch.distributed.barrier()工作原理

 1、背景介绍

      在pytorch的多卡训练中,通常有两种方式,一种是单机多卡模式(存在一个节点,通过torch.nn.DataParallel(model)实现),一种是多机多卡模式(存在一个节点或者多个节点,通过torch.nn.parallel.DistributedDataParallel(model),在单机多卡环境下使用第二种分布式训练模式具有更快的速度。pytorch在分布式训练过程中,对于数据的读取是采用主进程预读取并缓存,然后其它进程从缓存中读取,不同进程之间的数据同步具体通过torch.distributed.barrier()实现。

 

2、通俗理解torch.distributed.barrier()

      代码示例如下:

def create_dataloader():
    #使用上下文管理器中实现的barrier函数确保分布式中的主进程首先处理数据,然后其它进程直接从缓存中读取
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels()


from contextlib import contextmanager

#定义的用于同步不同进程对数据读取的上下文管理器
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Decorator to make all processes in distributed training wait for each local_master to do something.
    """
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    yield   #中断后执行上下文代码,然后返回到此处继续往下执行
    if local_rank == 0:
        torch.distributed.barrier()

(1)进程号rank理解

在多进程上下文中,我们通常假定rank 0是第一个进程或者主进程,其它进程分别具有0,1,2不同rank号,这样总共具有4个进程。

(2)单一进程数据处理

通常有一些操作是没有必要以并行的方式进行处理的,如数据读取与处理操作,只需要一个进程进行处理并缓存,然后与其它进程共享缓存处理数据,但是由于不同进程是同步执行的,单一进程处理数据必然会导致进程之间出现不同步的现象,为此,torch中采用了barrier()函数对其它非主进程进行阻塞,来达到同步的目的。

(3)barrier()具体原理

在上面的代码示例中,如果执行create_dataloader()函数的进程不是主进程,即rank不等于0或者-1,上下文管理器会执行相应的torch.distributed.barrier(),设置一个阻塞栅栏,让此进程处于等待状态,等待所有进程到达栅栏处(包括主进程数据处理完毕);如果执行create_dataloader()函数的进程是主进程,其会直接去读取数据并处理,然后其处理结束之后会接着遇到torch.distributed.barrier(),此时,所有进程都到达了当前的栅栏处,这样所有进程就达到了同步,并同时得到释放。

参考文章:https://stackoverflow.com/questions/59760328/how-does-torch-distributed-barrier-work

参考代码:https://github.com/ultralytics/yolov5

你可能感兴趣的:(torch,机器学习,分布式,pytorch)