pytorch 梯度NAN异常值解决

gradient 为nan可能原因:

  1. 梯度爆炸
  2. 学习率太大
  3. 数据本身有问题
  4. backward时,某些方法造成0在分母上, 如:使用方法sqrt()

定位造成nan的代码:

import torch
# 异常检测开启
torch.autograd.set_detect_anomaly(True)
# 反向传播时检测是否有异常值,定位code
with torch.autograd.detect_anomaly():
	loss.backward()

你可能感兴趣的:(pytorch,学习笔记,pytorch,深度学习,神经网络,人工智能)