https://blog.csdn.net/mch2869253130/article/details/111034068
https://www.zzsblog.top/coding/2021/08/07/pytorch%E5%AE%9A%E4%BD%8DNaN.html
按照下面的流程来判断。
...
loss = model(input)
# 1. 先看loss是不是nan,如果loss是nan,那么说明可能是在forward的过程中出现了第一条列举的除0或者log0的操作
assert torch.isnan(loss).sum() == 0, print(loss)
optim.zero_grad()
loss.backward()
# 2. 如果loss不是nan,那么说明forward过程没问题,可能是梯度爆炸,所以用梯度裁剪试试
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
# 3.1 在step之前,判断参数是不是nan, 如果不是判断step之后是不是nan
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
optim.step()
# 3.2 在step之后判断,参数和其梯度是不是nan,如果3.1不是nan,而3.2是nan,
# 特别是梯度出现了Nan,考虑学习速率是否太大,调小学习速率或者换个优化器试试。
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
assert torch.isnan(model.mu.grad).sum() == 0, print(model.mu.grad)
————————————————
版权声明:本文为CSDN博主「风吹草地现牛羊的马」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/mch2869253130/article/details/111034068
检查NaN有三板斧, 尽管调试NaN通常需要一定的经验和耐心, 但记住这三个至少不至于手足无措.
torch.autograd.set_detect_anomaly(True)
如题, forward时出现NaN即时报错. 尽管说得好听, 但有的时候并不能准确地定位问题所在. 属于调试NaN的必要辅助.
# loss = model(X)
with torch.autograd.detect_anomaly():
loss.backward()
如题, backward时出现NaN时即时报错. 相比#1来说更难确切定位问题, 往往用于兜底, 即确保出现NaN时程序会尽快抛出异常.
assert是确保程序行为正确的重要手段. 对于一个算法来说, 出现NaN不管怎么说都意味着不正常. 同时, 对debug来说, 最重要的就是找到事发现场, 而assert正是寻找真正现场的利器.
在pytorch中, 检查NaN的函数为torch.isnan(T)
. 于是我们可以构造如下断言:
assert not torch.any(torch.isnan(T))
当然, 这么写其实有一点性能浪费, 但写python, 又是debug专用代码, 何必考虑这么多呢¯\_(ツ)_/¯
将这个断言加在你认为有可能出现NaN的步骤之后. 这样一旦出现NaN, 你至少能抓住一个现场. 哪怕这个现场已经漂移, 配合调试器你也能更有逻辑地找到真正的事发现场.
讲完三板斧总得讲讲NaN的成因, 要不然就是光有方法没有理论(x 尤其是#3, 要求调试者非常充分且熟练地掌握NaN的可能成因.
梯度爆炸, 或者梯度消失都可能导致NaN. 这个问题往往会被#2 反向传播异常检测捕获, 但真正定位到问题却难上加难. 相对来说, 重新推导一遍自己的理论模型、寻找可能导致梯度爆炸的计算显得更有针对性.
这也是NaN最常见的成因. 毕竟大多数的网络, 尤其是复现、组合别人的网络结构一般不会碰到梯度爆炸的问题, 而NaN大多出现于loss计算的部分, 诞生于某个小小的不合法计算, 然后污染它参与计算的所有结果, 最后在你的loss值上表现出来.
常见套路:
尚有其他的一些情况我自己没遇到过, 网上可能会有补充
这种问题运气好的话会被#1 正向异常检测直接找到, 但通常是找到一个漂移了亿点点的位置. 推荐用#3 assert的办法, 尤其是 自己写了loss时, 在关键位置放几个assert守门, 总归是没错的.
注意, 绝大多数时候, inf也是不合常理的存在. 因此你可能也需要同时寻找inf:
assert not torch.any(torch.isnan(T) + torch.isinf(T))
NaN的次常见成因. 顾名思义, 出现NaN仅仅是因为数据里含有NaN. 通常来说直接读图片不会出现NaN, 往往是大意地处理数据后会出现这种情况.
随便举个例子.
mask = mask / mask.max()
# serialize mask
这句话看起来没问题, 把uint8{0, 255}转成float32[0, 1]. 相信很多人都这么写过. 正常来说不会有任何问题, 直到我遇到了一张纯黑的mask :P
毕竟谁也不会想到有一张图没标注还给放数据集里了是吧. 但不管怎么说, 此时我们犯了”除零”的错误. 这个mask会变成携带NaN的脏数据输入模型, 并在计算loss时将loss结果污染. 如果程序没有及时终止, 在仅仅一次反向传播之后, 你的模型参数将变为NaN, 其一切推导将得出NaN ¯\_(ツ)_/¯