pytorch DDP模式中总是出现OOM问题。。

主要原因是没有进行及时的内存回收,导致显卡内存暴增:

解决方式:

在每个batch 反向传播后,加上下面的内存回收:

        del loss
        torch.cuda.empty_cache()
        gc.collect()

另外一点是建议用loss.detach().item()来从graph中分离,这样内存占用会少一点,因为如果使用loss.item(),它默认的整个graph

你可能感兴趣的:(pytorch)