Pytorch中torch.nn.DataParallel负载均衡问题

1. 问题概述

现在Pytorc下进行多卡训练主流的是采用torch.nn.parallel.DistributedDataParallel()(DDP)方法,但是在一些特殊的情况下这样的方法就使用不了了,特别是在进行与GAN相关的训练的时候,假如使用的损失函数是 WGAN-GP(LP),DRAGAN,那么其中会用到基于梯度的惩罚,其使用到的函数为torch.autograd.grad(),但是很不幸的是在实验的过程中该函数使用DDP会报错:

File "/home/work/anaconda3/envs/xxxxx_py/lib/python3.6/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: derivative for batch_norm_backward_elemt is not implemented

那么需要并行(单机多卡)计算那么就只能使用torch.nn.DataParallel()了,但是也带来另外一个问题那就是负载极其不均衡,使用这个并行计算方法会在主GPU上占据较多的现存,而其它的GPU显存则只占用了一部分,这样就使得无法再继续增大batchsize了,下图就是这种方式进行计算,整个数据流的路线:Pytorch中torch.nn.DataParallel负载均衡问题_第1张图片
可以在上图中看到输入数据计算和损失计算过程中都会存在数据汇总的情况,这就难免使得主

你可能感兴趣的:([3],Python相关)