【Pytorch】反向传播为NaN报错的排查解决方法,RuntimeError: Function ‘BmmBackward0‘ returned nan values

         最近在训练模型的过程中,反复出现方向传播至为NaN的报错,报错信息如下所示:

File "/home/fu/anaconda3/envs/torch/lib/python3.7/site-packages/torch/autograd/__init__.py", line 156, in backward allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: Function 'BmmBackward0' returned nan values in its 1th output.

        一般情况下,反向传播中有个别Nan值,并不会引起训练发生报错,只有在打开自动微分异常监测时:torch.autograd.detect_anomaly(True),才会出现任意Nan都会引起模型报错。

  在模型正常训练阶段不建议打开autograd.detect_anomaly,会使训练速度大大减慢,以笔者 这里的测试,打开后,原本4个小时的训练被减慢至7.5个小时;打开后可以辅助找到出现Nan值的位置。pytorch官方文档中的表述是这样的:

【Pytorch】反向传播为NaN报错的排查解决方法,RuntimeError: Function ‘BmmBackward0‘ returned nan values_第1张图片

 

       我在autograd.detect_anomaly打开的情况下经过多次记录,发现反向传播出现NaN的位置主要有三处:

1. 开平方根sqrt()函数,其导数为1/2*x^(-0.5),当输入值为0时,会使反向传播值为Nan

2. n次幂pow()函数,尤其是n小于1时,也会出现输入值在分母上的情况

 File "/home/fu/0805_fully_debug/code/model/modules.py", line 281, in forward
    v1 = torch.sqrt(torch.pow(vx1, 2) + torch.pow(vy1, 2)).clone()
 (Triggered internally at  ../torch/csrc/autograd/python_anomaly_mode.cpp:104.)

反向传播时出现NaN

这里可以通过不使用这两个函数来避免该问题,如果还是无法解决,可以在分母上加上微值eps,避免分母为0

3.损失函数中出现Nan

 File "/home/fu/0805_fully_debug/code/model/losses.py", line 14, in nll_with_covariances
    errors = coordinates_delta.permute(0, 1, 2, 4, 3) @ precision_matrices @ coordinates_delta
 (function _print_stack)

这时并不能直接看出问题,这时可以将所有梯度打印出来,参照记录一次 NaN in Loss 的解决过程中的过程,一步步找到问题的根源。

你可能感兴趣的:(pytorch,深度学习,python)