DistributedDataParallel数据不均衡

背景

在使用 DistributedDataParallel 进行数据并行训练时,每次反向传播都需要执行 all_reduce 操作以同步各个进程的梯度。all_reduce 需要进程组中的所有进程参与,如果某一个进程没有执行 all_reduce(一个进程的输入较其他进程少),那么其他进程将会挂起或出错(取决于后端,nccl 后端会挂起,gloo 后端会报错)。

问题

在进行模型训练时,由于数据不均衡,导致不同GPU上训练的轮数不同。比如,0号GPU正在训练第25轮epoch,1号GPU正在训练第30轮epoch。这样训练出来的模型精度不好(0号GPU训练精度92;1号GPU训练精度95,模型只能保存25.pt)。

解决方法

使用model.join方法;
使用Join上下文管理器:with Join([model]);

学习资料

  1. 浅析 PyTorch 的 Join 原理
    https://zhuanlan.zhihu.com/p/630904458
  2. 通信包
    https://www.jianshu.com/p/5f6cd6b50140
  3. 数据不均衡导致GPU挂起
    https://zhuanlan.zhihu.com/p/560490906?utm_id=0
  4. DP与DDP的区别
    https://blog.csdn.net/ytusdc/article/details/122091284
    here
    here

你可能感兴趣的:(pytorch,ddp)