在前几节中,使用了小批量随机梯度下降的优化算法来训练模型。实现中,只提供了模型的正向传播(Forward propagation)的计算。
鉴于基于反向传播(Backpropagation)算法的自动求梯度,极大简化了深度学习模型训练算法的实现,本节将使用数学和计算图(Computational graph)两个方式来描述正向传播和反向传播。
具体来说,将以带 L 2 L_2 L2范数正则化的含单隐藏层的多层感知机为样例模型,解释正向传播和反向传播。
正向传播,是指对神经网络沿着从输入层到输出层的顺序,依次计算并存储模型的中间变量(包括输出)。具体举例如下:
(1) 假设输入是一个特征为 x ∈ R d \boldsymbol{x} \in \mathbb{R}^d x∈Rd的样本,且不考虑偏差项,那么有中间变量:
z = W ( 1 ) x \boldsymbol{z} = \boldsymbol{W}^{(1)} \boldsymbol{x} z=W(1)x
其中, W ( 1 ) ∈ R h × d \boldsymbol{W}^{(1)} \in \mathbb{R}^{h \times d} W(1)∈Rh×d是隐藏层的权重参数。
(2) 把中间变量 z ∈ R h \boldsymbol{z} \in \mathbb{R}^h z∈Rh输入按元素运算的激活函数 ϕ \phi ϕ后,将得到向量长度为 h h h的隐藏层变量:
h = ϕ ( z ) \boldsymbol{h} = \phi (\boldsymbol{z}) h=ϕ(z)
其中,隐藏层变量 h \boldsymbol{h} h也是一个中间变量。
(3) 假设输出层参数只有权重 W ( 2 ) ∈ R q × h \boldsymbol{W}^{(2)} \in \mathbb{R}^{q \times h} W(2)∈Rq×h,可以得到向量长度为 q q q的输出层变量:
o = W ( 2 ) h \boldsymbol{o} = \boldsymbol{W}^{(2)} \boldsymbol{h} o=W(2)h
(4) 假设损失函数为 ℓ \ell ℓ,且样本标签为 y y y,可以计算出单个数据样本的损失项:
L = ℓ ( o , y ) L = \ell(\boldsymbol{o}, y) L=ℓ(o,y)
(5) 根据 L 2 L_2 L2范数正则化的定义,给定超参数 λ \lambda λ,有正则化项:
s = λ 2 ( ∣ W ( 1 ) ∣ F 2 + ∣ W ( 2 ) ∣ F 2 ) s = \frac{\lambda}{2} \left(|\boldsymbol{W}^{(1)}|_F^2 + |\boldsymbol{W}^{(2)}|_F^2\right) s=2λ(∣W(1)∣F2+∣W(2)∣F2)
其中,矩阵的Frobenius范数等价于将矩阵变平为向量后计算 L 2 L_2 L2范数,即,对应元素的平方和再开方。
(6) 最终,模型在给定的数据样本上带正则化的损失为:
J = L + s J = L + s J=L+s
因此,称 J J J为有关给定数据样本的目标函数,简称目标函数。
通常绘制计算图,来可视化运算符和变量在计算中的依赖关系。
下图为本节中样例模型正向传播的计算图:
反向传播,指的是计算神经网络参数梯度的方法。
总的来说,反向传播依据微积分中的链式法则,沿着从输出层到输入层的顺序,依次计算并存储目标函数有关神经网络各层的中间变量以及参数的梯度。
对输入或输出为任意形状张量的函数 Y = f ( X ) \mathsf{Y}=f(\mathsf{X}) Y=f(X)和 Z = g ( Y ) \mathsf{Z}=g(\mathsf{Y}) Z=g(Y),通过链式法则,有:
∂ Z ∂ X = prod ( ∂ Z ∂ Y , ∂ Y ∂ X ) \frac{\partial \mathsf{Z}}{\partial \mathsf{X}} = \text{prod}\left(\frac{\partial \mathsf{Z}}{\partial \mathsf{Y}}, \frac{\partial \mathsf{Y}}{\partial \mathsf{X}}\right) ∂X∂Z=prod(∂Y∂Z,∂X∂Y)
其中, prod \text{prod} prod为数组的元素乘积。
沿用上文6.1.1中的样例模型为例,它的参数是 W ( 1 ) \boldsymbol{W}^{(1)} W(1)和 W ( 2 ) \boldsymbol{W}^{(2)} W(2),因此,反向传播的目标是计算 ∂ J / ∂ W ( 1 ) \partial J/\partial \boldsymbol{W}^{(1)} ∂J/∂W(1)和 ∂ J / ∂ W ( 2 ) \partial J/\partial \boldsymbol{W}^{(2)} ∂J/∂W(2)。具体为,应用链式法则依次计算各中间变量和参数的梯度。其中,计算次序与前向传播中相应中间变量的计算次序相反:
(1) 分别计算目标函数 J = L + s J=L+s J=L+s 有关损失项 L L L和正则项 s s s的梯度:
∂ J ∂ L = 1 , ∂ J ∂ s = 1. \frac{\partial J}{\partial L} = 1, \quad \frac{\partial J}{\partial s} = 1. ∂L∂J=1,∂s∂J=1.
(2) 依据链式法则,计算目标函数有关输出层变量的梯度 ∂ J / ∂ o ∈ R q \partial J/\partial \boldsymbol{o} \in \mathbb{R}^q ∂J/∂o∈Rq:
∂ J ∂ o = prod ( ∂ J ∂ L , ∂ L ∂ o ) = ∂ L ∂ o . \frac{\partial J}{\partial \boldsymbol{o}} = \text{prod}\left(\frac{\partial J}{\partial L}, \frac{\partial L}{\partial \boldsymbol{o}}\right) = \frac{\partial L}{\partial \boldsymbol{o}}. ∂o∂J=prod(∂L∂J,∂o∂L)=∂o∂L.
(3) 计算正则项有关两个参数的梯度:
∂ s ∂ W ( 1 ) = λ W ( 1 ) , ∂ s ∂ W ( 2 ) = λ W ( 2 ) . \frac{\partial s}{\partial \boldsymbol{W}^{(1)}} = \lambda \boldsymbol{W}^{(1)},\quad\frac{\partial s}{\partial \boldsymbol{W}^{(2)}} = \lambda \boldsymbol{W}^{(2)}. ∂W(1)∂s=λW(1),∂W(2)∂s=λW(2).
(4) 计算最靠近输出层的模型参数的梯度 ∂ J / ∂ W ( 2 ) ∈ R q × h \partial J/\partial \boldsymbol{W}^{(2)} \in \mathbb{R}^{q \times h} ∂J/∂W(2)∈Rq×h。依据链式法则,有:
∂ J ∂ W ( 2 ) = prod ( ∂ J ∂ o , ∂ o ∂ W ( 2 ) ) + prod ( ∂ J ∂ s , ∂ s ∂ W ( 2 ) ) = ∂ J ∂ o h ⊤ + λ W ( 2 ) . \frac{\partial J}{\partial \boldsymbol{W}^{(2)}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{W}^{(2)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \boldsymbol{W}^{(2)}}\right) = \frac{\partial J}{\partial \boldsymbol{o}} \boldsymbol{h}^\top + \lambda \boldsymbol{W}^{(2)}. ∂W(2)∂J=prod(∂o∂J,∂W(2)∂o)+prod(∂s∂J,∂W(2)∂s)=∂o∂Jh⊤+λW(2).
(5) 沿着输出层向隐藏层继续反向传播,隐藏层变量的梯度 ∂ J / ∂ h ∈ R h \partial J/\partial \boldsymbol{h} \in \mathbb{R}^h ∂J/∂h∈Rh可以这样计算:
∂ J ∂ h = prod ( ∂ J ∂ o , ∂ o ∂ h ) = W ( 2 ) ⊤ ∂ J ∂ o . \frac{\partial J}{\partial \boldsymbol{h}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{h}}\right) = {\boldsymbol{W}^{(2)}}^\top \frac{\partial J}{\partial \boldsymbol{o}}. ∂h∂J=prod(∂o∂J,∂h∂o)=W(2)⊤∂o∂J.
(6) 由于激活函数 ϕ \phi ϕ是按元素运算的,中间变量 z \boldsymbol{z} z的梯度 ∂ J / ∂ z ∈ R h \partial J/\partial \boldsymbol{z} \in \mathbb{R}^h ∂J/∂z∈Rh的计算使用按元素乘法符 ⊙ \odot ⊙:
∂ J ∂ z = prod ( ∂ J ∂ h , ∂ h ∂ z ) = ∂ J ∂ h ⊙ ϕ ′ ( z ) . \frac{\partial J}{\partial \boldsymbol{z}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{h}}, \frac{\partial \boldsymbol{h}}{\partial \boldsymbol{z}}\right) = \frac{\partial J}{\partial \boldsymbol{h}} \odot \phi'\left(\boldsymbol{z}\right). ∂z∂J=prod(∂h∂J,∂z∂h)=∂h∂J⊙ϕ′(z).
(7) 最终,得到最靠近输入层的模型参数的梯度 ∂ J / ∂ W ( 1 ) ∈ R h × d \partial J/\partial \boldsymbol{W}^{(1)} \in \mathbb{R}^{h \times d} ∂J/∂W(1)∈Rh×d。依据链式法则,得到
∂ J ∂ W ( 1 ) = prod ( ∂ J ∂ z , ∂ z ∂ W ( 1 ) ) + prod ( ∂ J ∂ s , ∂ s ∂ W ( 1 ) ) = ∂ J ∂ z x ⊤ + λ W ( 1 ) . \frac{\partial J}{\partial \boldsymbol{W}^{(1)}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{z}}, \frac{\partial \boldsymbol{z}}{\partial \boldsymbol{W}^{(1)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \boldsymbol{W}^{(1)}}\right) = \frac{\partial J}{\partial \boldsymbol{z}} \boldsymbol{x}^\top + \lambda \boldsymbol{W}^{(1)}. ∂W(1)∂J=prod(∂z∂J,∂W(1)∂z)+prod(∂s∂J,∂W(1)∂s)=∂z∂Jx⊤+λW(1).
在训练深度学习模型时,正向传播和反向传播之间相互依赖。
依然以本节的样例模型为例,阐述如下:
一方面,正向传播的计算可能依赖于模型参数的当前值,而这些模型参数是在反向传播的梯度计算后通过优化算法迭代的。
例如,计算正则化项 s = ( λ / 2 ) ( ∣ W ( 1 ) ∣ F 2 + ∣ W ( 2 ) ∣ F 2 ) s = (\lambda/2) \left(|\boldsymbol{W}^{(1)}|_F^2 + |\boldsymbol{W}^{(2)}|_F^2\right) s=(λ/2)(∣W(1)∣F2+∣W(2)∣F2),依赖模型参数 W ( 1 ) \boldsymbol{W}^{(1)} W(1)和 W ( 2 ) \boldsymbol{W}^{(2)} W(2)的当前值,而这些当前值是优化算法最近一次根据反向传播算出梯度后迭代得到的。
另一方面,反向传播的梯度计算可能依赖于各变量的当前值,而这些变量的当前值是通过正向传播计算得到的。
例如,参数梯度 ∂ J / ∂ W ( 2 ) = ( ∂ J / ∂ o ) h ⊤ + λ W ( 2 ) \partial J/\partial \boldsymbol{W}^{(2)} = (\partial J / \partial \boldsymbol{o}) \boldsymbol{h}^\top + \lambda \boldsymbol{W}^{(2)} ∂J/∂W(2)=(∂J/∂o)h⊤+λW(2)的计算需要依赖隐藏层变量的当前值 h \boldsymbol{h} h,该当前值是通过从输入层到输出层的正向传播计算并存储得到的。
因此,在模型参数初始化完成后,需交替进行正向传播和反向传播,并根据反向传播计算的梯度迭代模型参数。
既然在反向传播中使用了正向传播中计算得到的中间变量来避免重复计算,那么这个复用也导致正向传播结束后不能立即释放中间变量内存,这也是训练要比预测占用更多内存的一个重要原因。
另外,需要指出的是,这些中间变量的个数大体上与网络层数线性相关,每个变量的大小跟批量大小和输入个数也是线性相关的,它们是导致较深的神经网络使用较大批量训练时更容易超内存的主要原因。
《动手学深度学习》(TF2.0版)