神经网络出现NAN的个人见解

出现场景

网络设计为4层LSTM组成的一个RNN,学习率设为0.1,num_units个数为256,出现NAN。当把层数调成2层的时候,没有出现。

原因分析:
网络层数太深,加上RNN内部计算是循环嵌套,从前往后计算,每层的输入逐渐累积;学习率过大;某些batch产生较大的梯度。
经过大梯度、大学习率、深层网络累积的输入,使得网络参数变得异常。
【注】:反向传播的时候,求的是loss function与当前层参数的梯度,此时计算梯度,输入x作为常量,参与到参数的更新。

出现原因

  1. 学习率较大,若此时反向传回来的梯度也很大的时候,参数可能会更新的非常大,倘若不幸,飞成Inf,前向传播求loss的时候,会报NAN。解决方法调小学习率。
  2. 某些batch的数据产生过大的梯度,解决方法采用梯度裁剪、数据归一化。
  3. 数据出错,网络中出现log0、除以0等不正常的操作。
    【个人认为1、2是互相依赖的,若梯度很大但学习率比较小的话,参数更新值会因为学习率较小而变小,减少NAN出现的概率;若学习率很大但梯度很小的话,参数更新值应该也不会很大,毕竟学习率一般取值不超过1。3是最常出现的错误,也是最容易理解的原因。】

举例分析

给定输入数据x为(1,1)(2,1),…(1,10),输出数据/标签y为(3,4,…,12)。通过训练网络求y= w1x1+w2x2+b w 1 ∗ x 1 + w 2 ∗ x 2 + b w1,w2,w3 w 1 , w 2 , w 3 的值。如下图所示的网络结构图。
神经网络出现NAN的个人见解_第1张图片
它的每个参数关于loss function的梯度如下:
神经网络出现NAN的个人见解_第2张图片
假设参数全部初始化为0.1,batch size为1,以w2的第一次迭代更新举例:
1)lr=1.0 当前batch的数据为(1,1),
w2:=0.1-1.0*2*(3-(0.1*1+0.1*1+0.1))*1=0.1-5,4=-5.3

2)lr=1.0 当前batch的数据为(1,10),
w2:=0.1-1.0*2(12-(0.1*1+0.1*10+0.1))*10=0.1-216=-215.9

【可以看出某些数据可以使得梯度变得很大,导致参数更新的很大;梯度较小的时候,即使学习率很大,对参数更新的影响不会很大。】

3)lr=0.1 当前batch的数据为(1,10),
w2:=0.1-0.1*2(12-(0.1+0.1*1+0.1*10))*10=0.1-21.6=-21.5

4)lr=0.01 当前batch的数据为(1,10),
w2:=0.1-0.01*2(12-(0.1+0.1*1+0.1*10))*10=0.1-2.16=-2.06

【可以看出当lr较小的时候可以限制梯度很大带来的大幅度参数更新】

你可能感兴趣的:(TensorFlow)