记录一次 NaN in Loss 的解决过程

前言

训练的模型经常数十个epoch后Loss中出现NaN,困扰了很久终于解决了,记录一下。

检测

我通常会在计算loss.backward()optimizer.step()前,添加条件判断是否出现NaN:

if loss != loss:
    raise Exception('NaN in loss, crack!')

定位

输出参数和梯度

同样放在loss.backward()optimizer.step()前。
用于打印网络所有层参数和梯度的最大和最小值,仅在Debug时使用:

# print grad check
v_n = []
v_v = []
v_g = []
for name, parameter in net.named_parameters():
    v_n.append(name)
    v_v.append(parameter.detach().cpu().numpy() if parameter is not None else [0])
    v_g.append(parameter.grad.detach().cpu().numpy() if parameter.grad is not None else [0])
for i in range(len(v_n)):
    if np.max(v_v[i]).item() - np.min(v_v[i]).item() < 1e-6:
        color = bcolors.FAIL + '*'
    else:
        color = bcolors.OKGREEN + ' '
    print('%svalue %s: %.3e ~ %.3e' % (color, v_n[i], np.min(v_v[i]).item(), np.max(v_v[i]).item()))
    print('%sgrad  %s: %.3e ~ %.3e' % (color, v_n[i], np.min(v_g[i]).item(), np.max(v_g[i]).item()))

部分输出如下:

 value encoder_blocks.0.KPConv.weights: -1.890e-01 ~ 9.013e-02
 grad  encoder_blocks.0.KPConv.weights: nan ~ nan
 value encoder_blocks.0.KPConv.kernel_points: -3.859e-01 ~ 4.054e-01
 grad  encoder_blocks.0.KPConv.kernel_points: 0.000e+00 ~ 0.000e+00
 value encoder_blocks.0.batch_norm.batch_norm.weight: -6.673e-03 ~ 1.823e-01
 grad  encoder_blocks.0.batch_norm.batch_norm.weight: nan ~ nan
 value encoder_blocks.0.batch_norm.batch_norm.bias: -1.067e-01 ~ 9.120e-02
 grad  encoder_blocks.0.batch_norm.batch_norm.bias: nan ~ nan

分析

首先查看是Loss中先出现NaN还是梯度中先出现NaN。
若Loss中先出现,则有可能是正向传播过程出现问题或计算Loss过程出现问题。
但我发现我的模型中梯度中首先出现NaN,说明是反向传播过程出现问题,从后往前查看梯度,找到倒数第一个NaN的地方

 value encoder.ffn0.l1.weight: -1.010e-06 ~ 1.011e-06
 grad  encoder.ffn0.l1.weight: nan ~ nan
 value encoder.ffn0.l1.bias: -9.871e-07 ~ 9.912e-07
 grad  encoder.ffn0.l1.bias: nan ~ nan
 value encoder.ffn0.l2.weight: -9.877e-07 ~ 1.031e-06
 grad  encoder.ffn0.l2.weight: nan ~ nan
 value encoder.ffn0.l2.bias: -9.813e-07 ~ 4.404e-06
 grad  encoder.ffn0.l2.bias: nan ~ nan
*value encoder.norm.weight: 3.158e-05 ~ 3.163e-05
*grad  encoder.norm.weight: -5.224e-15 ~ 9.597e-15
*value encoder.norm.bias: -2.440e-09 ~ 2.480e-09
*grad  encoder.norm.bias: -9.202e-15 ~ 1.174e-14
 value encoder.ffn1.net.0.weight: -1.017e-06 ~ 1.047e-06
 grad  encoder.ffn1.net.0.weight: -4.821e-14 ~ 6.914e-14
 value encoder.ffn1.net.0.bias: -1.174e-06 ~ 1.422e-06
 grad  encoder.ffn1.net.0.bias: -4.267e-10 ~ 6.084e-10

因此我们定位到问题是出现在 encoder.ffn0.l2 之后 encoder.norm.weight 之前,查看代码查找问题。

精确定位

知道了问题存在的位置我还是找不到问题,那就是定位得不够精确,其实可以直接使用pytorch自带的检查功能(严重影响性能,仅在Debug中使用):

with autograd.detect_anomaly():
	# 正向传播
	# 计算Loss
	# 反向传播

当模型中出现NaN时,会抛出异常,并在打印常规的Debug信息前,提供更多的Debug信息。额外提供的信息如下:

  File "~/work/trainer.py", line 252, in train                            
    outputs = net(batch, config)                                                                      
  File "~/anaconda3/envs/pytorch1.8/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl                                                                               
    result = self.forward(*input, **kwargs)                                                           
  File "~/anaconda3/envs/pytorch1.8/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 705, in forward                                                                            
    output = self.module(*inputs[0], **kwargs[0])                                                     
  File "~/anaconda3/envs/pytorch1.8/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl                                                                               
    result = self.forward(*input, **kwargs)                                                             
  File "~/anaconda3/envs/pytorch1.8/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl                                                                               
    result = self.forward(*input, **kwargs)                                                           
  File "~/work/models/crossvitmodule.py", line 672, in forward                  
    x = list(x)                                                                                       
  File "~/work/models/crossvitmodule.py", line 663, in <lambda>                 
    x = map(lambda u, v: u + self.pos_embedding(u, v), x, point)                                      
  File "~/anaconda3/envs/pytorch1.8/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl                                                                               
    result = self.forward(*input, **kwargs)                                                           
  File "~/work/models/crossvitmodule.py", line 423, in forward                  
    point = self.proj2(point)                                                                         
  File "~/anaconda3/envs/pytorch1.8/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl                                                                               
    result = self.forward(*input, **kwargs)                                                           
  File "~/work/models/crossvitmodule.py", line 404, in forward                  
    return self.leaky_relu(self.norm(self.linear(x)))                                                 
  File "~/anaconda3/envs/pytorch1.8/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl                                                                               
    result = self.forward(*input, **kwargs)                                                           
  File "~/work/models/crossvitmodule.py", line 29, in forward                   
    std = torch.var(x, dim=1, unbiased=False, keepdim=True).sqrt()                        
 (function _print_stack)   

直接定位到std = torch.var(x, dim=1, unbiased=False, keepdim=True).sqrt()这一行。
原代码为手动实现的LayerNorm:

class LayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (std + self.eps) * self.g + self.b

当x为0时在计算标准差时正向传播需要计算sqrt(x)
而反向传播时需要计算sqrt(x)的微分1/(2*sqrt(x))此时需要确保x != 0

解决方案

防止除0即可:

std = torch.sqrt(torch.var(x, dim = 1, unbiased = False, keepdim = True) + self.eps)

你可能感兴趣的:(debug,pytorch,pytorch,NaN,loss,debug)