RNN梯度消失与梯度爆炸推导

2021-07-20-15-04-53.png

梯度消失与爆炸

假设一个只有 3 个输入数据的序列,此时我们的隐藏层 h1、h2、h3 和输出 y1、y2、y3 的计算公式:

\large \begin{aligned} h_1 &= \sigma(Ux_1 + Wh_0 + b) \\ h_2 &= \sigma(Ux_2 + Wh_1 + b) \\ h_3 &= \sigma(Ux_3 + Wh_2 + b) \\ y_1 &= \sigma(Vh_1 + c) \\ y_2 &= \sigma(Vh_2 + c) \\ y_3 &= \sigma(Vh_3 + c) \end{aligned}

RNN 在时刻 t 的损失函数为 Lt,总的损失函数为

t = 3 时刻的损失函数 L3 对于网络参数 U、W、V 的梯度如下:

\begin{aligned} \frac{\partial L_3}{\partial V} &= \frac{\partial L_3}{\partial y_3} \frac{\partial y_3}{\partial V} \\ \frac{\partial L_3}{\partial U} &= \frac{\partial L_3}{\partial y_3} \frac{\partial y_3}{\partial h_3} \frac{\partial h_3}{\partial U} + \frac{\partial L_3}{\partial y_3} \frac{\partial y_3}{\partial h_3} \frac{\partial h_3}{\partial h_2} \frac{\partial h_2}{\partial U} + \frac{\partial L_3}{\partial y_3} \frac{\partial y_3}{\partial h_3} \frac{\partial h_3}{\partial h_2} \frac{\partial h_2}{\partial h_1} \frac{\partial h_1}{\partial U} \\ \frac{\partial L_3}{\partial W} &= \frac{\partial L_3}{\partial y_3} \frac{\partial y_3}{\partial h_3} \frac{\partial h_3}{\partial W} \+ \frac{\partial L_3}{\partial y_3} \frac{\partial y_3}{\partial h_3} \frac{\partial h_3}{\partial h_2} \frac{\partial h_2}{\partial W} \+ \frac{\partial L_3}{\partial y_3} \frac{\partial y_3}{\partial h_3} \frac{\partial h_3}{\partial h_2} \frac{\partial h_2}{\partial h_1} \frac{\partial h_1}{\partial W} \\ \end{aligned}

其实主要就是因为:

  • 对V求偏导时,是常数
  • 对U求偏导时:
    • 里有U,所以要继续对h3应用chain rule
    • 里的是常数,但是里又有U,继续chain rule
    • 以此类推,直到
  • 对W求偏导时一样

所以:

  1. 参数矩阵 V (对应输出 ) 的梯度很显然并没有长期依赖
  2. U和V显然就是连乘()后累加()

\begin{aligned} \frac{\partial L_t}{\partial U} = \sum_{k=0}^{t} \frac{\partial L_t}{\partial y_t} \frac{\partial y_t}{\partial h_t} (\prod_{j=k+1}^{t}\frac{\partial h_j}{\partial h_{j-1}}) \frac{\partial h_k}{\partial U} \\ \frac{\partial L_t}{\partial W} = \sum_{k=0}^{t} \frac{\partial L_t}{\partial y_t} \frac{\partial y_t}{\partial h_t} (\prod_{j=k+1}^{t}\frac{\partial h_j}{\partial h_{j-1}}) \frac{\partial h_k}{\partial W} \end{aligned}

其中的连乘项就是导致 RNN 出现梯度消失与梯度爆炸的罪魁祸首,连乘项可以如下变换:

tanh' 表示 tanh 的导数,可以看到 RNN 求梯度的时候,实际上用到了 (tanh' × W) 的连乘。当 (tanh' × W) > 1 时,多次连乘容易导致梯度爆炸;当 (tanh' × W) < 1 时,多次连乘容易导致梯度消失。

你可能感兴趣的:(RNN梯度消失与梯度爆炸推导)