【Pytorch】训练中跳过问题样本,解决显存爆炸\波动问题

        最近在训练模型时发现数据集部分数据存在问题,计算得到的loss非常大,这不利于模型的优化,所以就考虑通将loss过大的样本直接跳过,不使用这些数据进行优化。

        考虑让问题数据不在反向传播,用以下语句:

loss = nll_with_covariances(
    xy_future_gt, coordinates, probas, data["target/future/valid"].squeeze(-1),
    covariance_matrices) * loss_coeff
if loss>1e5 and step>100:
    del data
    torch.cuda.empty_cache()
    continue                
else:
    train_losses.append(loss.item())
    loss.backward()
optimizer.step()

        就发生了显存不断上下波动的情况,按照原来的batchsize还发生显存不足的问题,尝试使用torch.cuda.empty_cache()依然不能解决。于是就想搞清楚pytorch训练中显存的管理机制。从知乎上看到一篇很详细的讲解,pytorch占用显存有四部分组成:模型定义、前向传播过程、反向传播过程、参数更新过程。 PyTorch显存机制分析

        我分析反向传播结束后会清除前向传播数据,可以使模型再次前向传播而显存不发生变化,之前直接跳过反向传播,导致有两个batchsize的前向传播数据挤在显存中,导致显存猛增最后不足。

        torch.cuda.empty_cache()只能清除缓存区的数据,不能解决根本问题,会导致预留给下一批次位置被清除,导致训练速度变慢,由平均1.1s增加到1.25s。

       之后在pytorch论坛上看到和我相似的问题,他的解决方法就是对问题样本仍进行反向传播,但是不更新参数,所以修改后的代码为:How to skip backward if the loss is very small in training

loss = nll_with_covariances(
    xy_future_gt, coordinates, probas, data["target/future/valid"].squeeze(-1),
    covariance_matrices) * loss_coeff
train_losses.append(loss.item())
loss.backward()
if loss>1e5 and step>100:
    optimizer.zero_grad()
    del data
    continue                
else: 
    optimizer.step()

         修改后的模型训练中显存就始终保持稳定了,不会像之前一样发生剧烈的显存波动,这时的batchsize可以进一步调大,不用担心训练中显存爆炸中断了。

【Pytorch】训练中跳过问题样本,解决显存爆炸\波动问题_第1张图片

你可能感兴趣的:(pytorch,python,深度学习,算法)