本文是对 RNN 循环神经网络中的每一个神经元进行反向传播求导的数学推导过程,下面还使用 PyTorch
对导数公式进行编程求证。
一个普通的 RNN 神经网络如下图所示:
其中 x ⟨ t ⟩ x^{\langle t \rangle} x⟨t⟩ 表示某一个输入数据在 t t t 时刻的输入; a ⟨ t ⟩ a^{\langle t \rangle} a⟨t⟩ 表示神经网络在 t t t 时刻时的hidden state,也就是要传送到 t + 1 t+1 t+1 时刻的值; y ⟨ t ⟩ y^{\langle t \rangle} y⟨t⟩ 则表示在第 t t t 时刻输入数据传入以后产生的预测值,在进行预测或 sampling 时 y ⟨ t ⟩ y^{\langle t \rangle} y⟨t⟩ 通常作为下一时刻即 t + 1 t+1 t+1 时刻的输入,也就是说 x ⟨ t ⟩ = y ^ ⟨ t ⟩ x^{\langle t \rangle}=\hat{y}^{\langle t \rangle} x⟨t⟩=y^⟨t⟩ ;下面对数据的维度进行说明。
下图所示的是一个特定的 RNN 神经元:
上图说明了在第 t t t 时刻的神经元中,数据的输入 x ⟨ t ⟩ x^{\langle t \rangle} x⟨t⟩ 和上一层的 hidden state a ⟨ t ⟩ a^{\langle t \rangle} a⟨t⟩ 是如何经过计算得到下一层的 hidden state 和预测输出 y ^ ⟨ t ⟩ \hat{y}^{\langle t \rangle} y^⟨t⟩ 。
下面是对五个参数的维度说明:
计算 t t t 时刻的 hidden state a ⟨ t ⟩ a^{\langle t \rangle} a⟨t⟩ :
KaTeX parse error: No such environment: split at position 8: \begin{̲s̲p̲l̲i̲t̲}̲ z1^{\langle t …
预测 t t t 时刻的输出 y ^ ⟨ t ⟩ \hat{y}^{\langle t \rangle} y^⟨t⟩ :
KaTeX parse error: No such environment: split at position 8: \begin{̲s̲p̲l̲i̲t̲}̲ z2^{\langle t …
在当今流行的深度学习编程框架中,我们只需要编写一个神经网络的结构和负责神经网络的前向传播,至于反向传播的求导和参数更新,完全由框架搞定;即便如此,我们在学习阶段也要自己动手证明一下反向传播的有效性。
下图是 RNN 神经网络中的一个基本的神经元,图中标注了反向传播所需传来的参数和输出等。
就如一个全连接的神经网络一样,损失函数 J J J 的导数通过微积分的链式法则(chain rule)反向传播到每一个时间轴上。
为了方便,我们将损失函数关于神经元中参数的偏导符号简记为 d p a r a m e t e r s \mathrm{d}\mathit{parameters} dparameters ;例如将 ∂ J ∂ W a x \frac{\partial J}{\partial W_{ax}} ∂Wax∂J 记为 d W a x \mathrm{d}W_{ax} dWax 。
上图的反向传播的实现并没有包括全连接层和 Softmax 层。
计算损失函数关于各个参数的偏导数之前,我们先引入一个计算图(computation graph),其演示了一个 RNN 神经元的前向传播和如何利用计算图进行链式法则的反向求导。
因为当进行反向传播求导时,我们需要将整个时间轴的输入全部输入之后,才可以从最后一个时刻开始往前传进行反向传播,所以我们假设 t t t 时刻就为最后一个时刻 T x T_x Tx 。
如果我们想要先计算 ∂ ℓ ∂ W a x \frac{\partial\ell}{\partial W_{ax}} ∂Wax∂ℓ 所以我们可以从计算图中看到,反向传播的路径:
我们需要按部就班的分别对从 W a x W_{ax} Wax 计算到 ℓ \ell ℓ 一路相关的变量进行求偏导,利用链式法则,将红色路线上一路的偏导数相乘到一起,就可以求出偏导数 ∂ ℓ ∂ W a x \frac{\partial\ell}{\partial W_{ax}} ∂Wax∂ℓ ;所以我们得到:
KaTeX parse error: No such environment: split at position 8: \begin{̲s̲p̲l̲i̲t̲}̲ \frac{\partial…
在上面的公式中,我们仅需要分别求出每一个偏导即可,其中红色的部分就是关于 S o f t m a x \mathrm{Softmax} Softmax 的求导,关于 S o f t m a x \mathrm{Softmax} Softmax 求导的推导过程,可以看本人的另一篇博客: 关于 Softmax 回归的反向传播求导数过程
关于 t a n h \mathrm{tanh} tanh 的求导公式如下:
∂ tanh ( x ) ∂ x = 1 − tanh 2 ( x ) \frac{\partial \tanh(x)} {\partial x} = 1 - \tanh^2(x) ∂x∂tanh(x)=1−tanh2(x)
所以上面的式子就得到:
KaTeX parse error: No such environment: split at position 8: \begin{̲s̲p̲l̲i̲t̲}̲ \frac{\partial…
我们就可以得到在最后时刻 t t t 参数 W a x W_{ax} Wax 的偏导数。
关于上面式子中的偏导数的计算,除了标量对矩阵的求导,在后面还包括了两个一个矩阵或向量对另一个矩阵或向量中的求导,实际上这是非常麻烦的一件事。
比如在计算 ∂ z 1 ⟨ t ⟩ ∂ W a x \frac{\partial z1^{\langle t\rangle}}{\partial W_{ax}} ∂Wax∂z1⟨t⟩ 偏导数的时候,我们发现 z 1 ⟨ t ⟩ z1^{\langle t\rangle} z1⟨t⟩ 是一个 R n a × m \mathbb{R}^{n_a\times m} Rna×m 的矩阵,而 W a x W_{ax} Wax 则是一个 R n a × n x \mathbb{R}^{n_a\times n_x} Rna×nx 的矩阵,这一项就是一个矩阵对另一个矩阵求偏导,如果直接对其求导我们将会得到一个四维的矩阵 R n a × n x × n a × m \mathbb{R}^{n_a\times n_x\times n_a\times m} Rna×nx×na×m (雅可比矩阵 Jacobian matrix);只不过这个高维矩阵中偏导数的值有很多 0 0 0 。
在神经网络中,如果直接将这个高维矩阵直接生搬硬套进梯度下降里更新参数是不可行,因为我们需要得到的梯度是关于自变量同型的向量或矩阵而且我们还要处理更高维度的矩阵的乘法;所以我们需要将结果进行一定的处理得到我们仅仅需要的信息。
一般在深度学习框架中都会有自动求梯度的功能包,这些包(比如
PyTorch
)中就只允许一个标量对向量或矩阵求导,其他情况是不允许的,除非在反向传播的函数里传入一个同型的权重向量或矩阵才可以得到导数。
我们先简单求出一个偏导数 ∂ ℓ ∂ W a x \frac{\partial\ell}{\partial W_{ax}} ∂Wax∂ℓ 我们下面使用 PyTorch
中的自动求梯度的包进行验证我们的公式是否正确。
import torch
# 这是神经网络中的一些架构的参数
n_x = 6
n_y = 6
m = 1
T_x = 5
T_y = 5
n_a = 3
# 定义所有参数矩阵
# requires_grad 为 True 表明在涉及这个变量的运算时建立计算图
# 为了之后反向传播求导
W_ax = torch.randn((n_a, n_x), requires_grad=True)
W_aa = torch.randn((n_a, n_a), requires_grad=True)
ba = torch.randn((n_a, 1), requires_grad=True)
W_ya = torch.randn((n_y, n_a), requires_grad=True)
by = torch.randn((n_y, 1), requires_grad=True)
# t 时刻的输入和上一时刻的 hidden state
x_t = torch.randn((n_x, m), requires_grad=True)
a_prev = torch.randn((n_a, m), requires_grad=True)
y_t = torch.randn((n_y, m), requires_grad=True)
# 开始模拟一个神经元 t 时刻的前向传播
# 从输入一直到计算出 loss
z1_t = torch.matmul(W_ax, x_t) + torch.matmul(W_aa, a_prev) + ba
z1_t.retain_grad()
a_t = torch.tanh(z1_t)
a_t.retain_grad()
z2_t = torch.matmul(W_ya, a_t) + by
z2_t.retain_grad()
y_hat = torch.exp(z2_t) / torch.sum(torch.exp(z2_t), dim=0)
y_hat.retain_grad()
loss_t = -torch.sum(y_t * torch.log(y_hat), dim=0)
loss_t.retain_grad()
# 对最后的 loss 标量开始进行反向传播求导
loss_t.backward()
# 我们就可以得到 W_ax 的导数
# 存储在后缀 _autograd 变量中,表明是由框架自动求导得到的
W_ax_autograd = W_ax.grad
# 查看框架计算得到的导数
W_ax_autograd
tensor([[ 0.5252, 1.1938, -0.2352, 1.1571, -1.0168, 0.3195],
[-1.0536, -2.3949, 0.4718, -2.3213, 2.0398, -0.6410],
[-0.0316, -0.0717, 0.0141, -0.0695, 0.0611, -0.0192]])
# 我们对自己推演出的公式进行手动计算导数
# 存储在后缀 _manugrad 变量中,表明是手动由公式计算得到的
W_ax_manugrad = torch.matmul(torch.matmul((y_hat - y_t).T, W_ya).T * (1 - torch.square(torch.tanh(z1_t))), x_t.T)
#torch.matmul(torch.matmul(W_ya.T, y_hat - y_t) * (1 - torch.square(torch.tanh(z1_t))), x_t.T)
# 输出手动计算的导数
W_ax_manugrad
tensor([[ 0.5195, 1.1809, -0.2327, 1.1447, -1.0058, 0.3161],
[-1.0195, -2.3172, 0.4565, -2.2461, 1.9737, -0.6202],
[-0.0309, -0.0703, 0.0138, -0.0681, 0.0599, -0.0188]],
grad_fn=)
# 查看两种求导结果的之差的 L2 范数
torch.norm(W_ax_manugrad - W_ax_autograd)
tensor(0.1356, grad_fn=)
通过上面的编程输出可以看到,我们手动计算的导数和框架自己求出的导数虽然有一定的误差,但是一一对照可以大体看到我们手动求出来的导数大体是对的,并没有说错的非常离谱。
但上面只是当 t = T x t=T_x t=Tx 即 t t t 时刻是最后一个输入单元的时候,也就是说所求的关于 W a x _W{ax} Wax 的导数只是全部导数的一部分,因为参数共享,所以每一时刻的神经元都有对 W a x W_{ax} Wax 的导数,所以需要将所有时刻的神经元关于 W a x W_{ax} Wax 的导数全部加起来。
若 t t t 不是最后一时刻,可能是神经网络里的中间的某一时刻的神经元;也就是说,在进行反向传播的时候,想要求 t t t 时刻的导数,就得等到 t + 1 t+1 t+1 时刻的导数值传进来,然后根据链式法则才可以计算当前时刻参数的导数。
下面是一个简易的计算图,只绘制出了 W a x W_ax Wax 到 ℓ \ell ℓ 的计算中,共涉及到哪些变量(在整个神经网络中的 W a x W_{ax} Wax 的权重参数是共享的):
下面使用一个视频展示整个神经网络中从 W a x W_{ax} Wax 到一个数据批量的损失值 ℓ \ell ℓ 的大体流向:
计算完 ℓ \ell ℓ 之后就可以计算 ∂ ℓ ∂ W a x \frac{\partial\ell}{\partial W_{ax}} ∂Wax∂ℓ 的导数值,但是 RNN 神经网络的反向传播区别于全连接神经网络的。
然后,我们演示一下如何进行反向传播的,注意看每一个时刻的 a ⟨ t ⟩ a^{\langle t\rangle} a⟨t⟩ 的计算都是等 a ⟨ t + 1 ⟩ a^{\langle t+1\rangle} a⟨t+1⟩ 的导数值传进来才进行计算的;同样地, W a x W_{ax} Wax 导数的计算也不是一步到位的,也是需要等到所有时刻的 a a a 的值全部传到才计算完。
所以对于神经网络中间某一个单元 t t t 我们有:
KaTeX parse error: No such environment: split at position 8: \begin{̲s̲p̲l̲i̲t̲}̲ \frac{\partial…
关于红色的部分的意思是需要等到 t + 1 t+1 t+1 时刻的导数值传进来,然后才可以进行对 t + 1 t+1 t+1 时刻关于当前时刻 t t t 的参数求导,最后得到参数梯度的一个分量。其实若仔细展开每一个偏导项,就像是一个递归一样,每次求某一时刻的导数总是要从最后一时刻往前传到当前时刻才可以进行。
多元复合函数的求导法则
如果函数 u = φ ( t ) u=\varphi(t) u=φ(t) 及 v = ψ ( t ) v=\psi(t) v=ψ(t) 都在点 t t t 可导,函数 z = f ( u , v ) z=f(u,v) z=f(u,v) 在对应点 ( u , v ) (u,v) (u,v) 具有连续偏导数,那么复合函数 z = f [ φ ( t ) , ψ ( t ) ] z=f[\varphi(t),\psi(t)] z=f[φ(t),ψ(t)] 在点 t t t 可导,且有
d z d t = ∂ z ∂ u d u d t + ∂ z ∂ v d v d t \frac{\mathrm{d}z}{\mathrm{d}t}=\frac{\partial z}{\partial u}\frac{\mathrm{d}u}{\mathrm{d}t}+\frac{\partial z}{\partial v}\frac{\mathrm{d}v}{\mathrm{d}t} dtdz=∂u∂zdtdu+∂v∂zdtdv
下面使用一张计算图说明 a ⟨ t ⟩ a^{\langle t\rangle} a⟨t⟩ 到 ℓ \ell ℓ 的计算关系。
也就是说第 t t t 时刻 ℓ \ell ℓ 关于 a ⟨ t ⟩ a^{\langle t\rangle} a⟨t⟩ 的导数是由两部分相加组成,也就是说是由两条路径反向传播,这两条路径分别是 ℓ → ℓ ⟨ t ⟩ → y ^ ⟨ t ⟩ → z 2 ⟨ t ⟩ → a ⟨ t ⟩ \ell\to\ell^{\langle t\rangle}\to\hat{y}^{\langle t\rangle}\to z2^{\langle t\rangle}\to a^{\langle t\rangle} ℓ→ℓ⟨t⟩→y^⟨t⟩→z2⟨t⟩→a⟨t⟩ 和 ℓ → ℓ ⟨ t + 1 ⟩ → y ^ ⟨ t + 1 ⟩ → z 2 ⟨ t + 1 ⟩ → a ⟨ t + 1 ⟩ → z 1 ⟨ t + 1 ⟩ → a ⟨ t ⟩ \ell\to\ell^{\langle t+1\rangle}\to\hat{y}^{\langle t+1\rangle}\to z2^{\langle t+1\rangle}\to a^{\langle t+1\rangle}\to z1^{\langle t+1\rangle}\to a^{\langle t\rangle} ℓ→ℓ⟨t+1⟩→y^⟨t+1⟩→z2⟨t+1⟩→a⟨t+1⟩→z1⟨t+1⟩→a⟨t⟩ ,我们将这两条路径导数之和使用 d a n e x t \mathrm{d}a_{\mathrm{next}} danext 表示。
所以我们可以得到在中间某一时刻的神经单元关于 W a x W_{ax} Wax 的导数为:
∂ ℓ ∂ W a x = ( d a n e x t ∗ ( 1 − tanh 2 ( z 1 ⟨ t ⟩ ) ) x ⟨ t ⟩ T \frac{\partial\ell}{\partial W_{ax}}=\left(\mathrm{d}a_{\mathrm{next}} * \left( 1-\tanh^2(z1^{\langle t \rangle}\right)\right) x^{\langle t \rangle T} ∂Wax∂ℓ=(danext∗(1−tanh2(z1⟨t⟩))x⟨t⟩T
通过同样的方法,我们就可以得到其它参数的导数:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ \frac{\partial…
除了传递参数的导数,在第 t t t 时刻还需要传送 ℓ \ell ℓ 关于 z 1 ⟨ t ⟩ z1^{\langle t\rangle} z1⟨t⟩ 的导数到 t − 1 t-1 t−1 时刻,将需要传送到上一时刻的导数记作为 d a p r e v \mathrm{d}a_{\mathrm{prev}} daprev 我们得到:
KaTeX parse error: No such environment: split at position 8: \begin{̲s̲p̲l̲i̲t̲}̲ \mathrm{d}a_{\…
可以看到,一个循环神经网络的反向传播实际上是非常复杂的,因为每一时刻的神经元都与参数有计算关系,所以反向传播时的路径非常杂乱,其中还涉及到了高维的矩阵,所以在计算时需要对高维矩阵进行一定的矩阵代数转换才方便导数和更新参数的计算。