Pytorch——基于DistributedDataParallel单机多GPU并行之broadcast_buffers

在使用Pytorch进行单机多GPU并行时,往往有两种方式,一种是基于DataParallel,另一种是基于DistributedDataParallel。前者太简单了,且速度不如后者,所以不作过多讨论。咱们在使用DistributedDataParallel时,其中有个参数broadcast_buffers,官方给出的解释就是

broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
                  the module at beginning of the forward function.
                  (default: ``True``)

也就是说,每次进行forward时,都会对模型中buffers进行统一。因此作了以下小实验验证下这件事!

import os
import argparse

import torch as th
import torch.multiprocessing as mp
import torch.distributed as dist

os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'

class Mod(th.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = th.nn.Conv2d(3, 5, 3, padding=1)
        self.bn = th.nn.BatchNorm2d(5)


    def forward(self, z):
        print('forward_before\t{}'.format(self.bn.running_mean))
        t = self.lin(z)
        return (self.bn(t)).sum()


def main_worker(rank, world_size, seed=0):
    world_size = world_size


    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


    mod = Mod().cuda(rank)
    # process_group = th.distributed.new_group(list(range(world_size)))
    # mod = th.nn.SyncBatchNorm.convert_sync_batchnorm(mod, process_group)

    optim = th.optim.Adam(mod.parameters(), lr=1e-3)


    mod = th.nn.parallel.DistributedDataParallel(mod,
            device_ids=[rank], output_device=rank, broadcast_buffers=True)

    if rank % 2 == 0:
        z1 = th.zeros(7, 3, 5, 5).cuda(rank)
    else:
        z1 = th.ones(7, 3, 5, 5).cuda(rank)

    out = mod(z1)

    # if rank == 1:
    print('forward_after\t{}'.format(mod.module.bn.running_mean))

    # mod(z2) # <<---- The presence of this unused call causes an inplace error in backward() below if dec is a DDP module.

    loss = (out**2).mean()

    optim.zero_grad()
    loss.backward()
    optim.step()

    print('backward_after\t{}'.format(mod.module.bn.running_mean))

    out = mod(z1)

    print(mod.module.bn.running_mean)



if __name__ == "__main__":

    mp.spawn(main_worker, nprocs=2, args=(2, 0))

broadcast_buffers=True时,运行结果为Pytorch——基于DistributedDataParallel单机多GPU并行之broadcast_buffers_第1张图片

可以发现,在第一次运行forward后,BN中的runing mean会统计各自GPU上的batch统计量,因此得到的结果不同;当进行backward时也不会造成这个buffer的不同。再进行第二次forward之前时,cuda1上的runing mean则会与cuda0上的一致!

broadcast_buffers=False时,运行结果为

Pytorch——基于DistributedDataParallel单机多GPU并行之broadcast_buffers_第2张图片

这时,就不会一致了。这种情况下,就跟DataParallel下的BN实现一模一样了!所以我们能不能说当broadcast_buffers=True时,就能达到所谓的同步BN的效果呢??????答案是否定的,同步BN的意思是runing mean是统计所有的分布上GPUs上的batch的统计量,而这个仅仅在每次保持其他GPU上的统计量与GPU0上的一致。所以怎么实现同步BN呢?Pytorch已经实现了!利用

 

process_group = th.distributed.new_group(list(range(world_size)))
mod = th.nn.SyncBatchNorm.convert_sync_batchnorm(mod, process_group)

把上面的程序注释删除即可, 最终运行效果如下!

Pytorch——基于DistributedDataParallel单机多GPU并行之broadcast_buffers_第3张图片

可以发现达到了同步BN效果!需要强调的是,Pytorch的同步BN只能适用于2d形式下,1d不支持!!!!!

你可能感兴趣的:(Pytorch那些事儿)