从接触 BP 到真正理解 BP 花了不少时间,真正理解梯度是在 deeplearning.ai 的作业中。而我希望通过本文能让你少走不少弯路。
训练数据(training set):输入数据
标签(label):期望的结果
损失函数(Loss Function):预测值与标签的差值
梯度下降(Gradient Descent):
在我初次接触计算图的时候是排斥的,因为很反感对严格技术的近似估计的描述,有时候
舍本逐末。但是,当我在 cs231n,deeplearning.ai 都讲到计算图的时候,我不得不重新
审视这个问题。发现计算图是个好东西
在正式开始求解 GD 之前,我们先看看计算图,是对神经网络的一个最直观描述,而这种直观的描述,对于理解后续其他一些更详细的主题,如数据归一化,batch normalization,梯度暴涨,梯度消失等问题非常关键。
尝试画出下面函数的计算图
参考 cs231n lecture 4
Add : 将梯度按均匀分发
f = x + y
df/dx = 1
df/dy = 1
Max : 将梯度传递给最大的输入值
f = max(x, y) # x > y
df/dx = 1
df/dy = 0
Multiply : 将梯度放到到对端的倍数
f = x * y
df/dx = y
df/dy = x
多层 : 层与层之间关系
h = f(x)
g = f(h)
dg/dx = dg/dh * dh/dx
直观地理解 BP 的意义,对调试问题有非常大的帮助。
一个最直接的原因就是在遇到问题的时候,你能够依据直观的理解非常快地定位到可能的原因。而且对 GD 的直观理解对于理解数据归一化,batch normalization,梯度暴涨,梯度消失等问题非常关键。
number gradient :慢,估计的,但是很容易实现
analysis gradient :快,准确,但是容易出错
在检查 gradient descent 算法是否正确的方式叫 gradient check,方法是通过求导的方式获取向量 dθ d θ ,通过导数定义获取向量 dθappr d θ a p p r ,计算这两个向量的距离。
所谓 BP(反向传播)就是从最后一层往前算,因此,
如果解决上述两个问题,自然而然,就可以通过遍历,依次获取所有层的偏导。
我花了非常多时间才自己悟到了上面的关键点,而之前,我在网上竟然不能找到一个完全正确的非常好理解的 BP 推导。即使是西瓜书也仅仅是以一个特例进行推导,而且很绕。
下面以单个样本为例,推导
对于损失函数
J(w, b, x, y)
其中
aL=f(zL)=f(WLaL−1+bL) a L = f ( z L ) = f ( W L a L − 1 + b L ) 其中 f(x) 为激活函数,可以是 sigmoid, tanh, relu 等。
第一步,计算最后一层偏导
∂J(w,b,x,y)∂WL=∂J(w,b,x,y)∂aL⋅∂aL∂zL⋅∂zL∂WL ∂ J ( w , b , x , y ) ∂ W L = ∂ J ( w , b , x , y ) ∂ a L ⋅ ∂ a L ∂ z L ⋅ ∂ z L ∂ W L
∂J(w,b,x,y)∂bL=∂J(w,b,x,y)∂aL⋅∂aL∂zL⋅∂zL∂bL ∂ J ( w , b , x , y ) ∂ b L = ∂ J ( w , b , x , y ) ∂ a L ⋅ ∂ a L ∂ z L ⋅ ∂ z L ∂ b L
由于 ∂J(w,b,x,y)∂aL⋅∂aL∂zL ∂ J ( w , b , x , y ) ∂ a L ⋅ ∂ a L ∂ z L 是两个偏导的公共部分。因此,令 δL=∂J(w,b,x,y)∂aL⋅∂aL∂zL=∂J(w,b,x,y)∂zL δ L = ∂ J ( w , b , x , y ) ∂ a L ⋅ ∂ a L ∂ z L = ∂ J ( w , b , x , y ) ∂ z L
于是,上式可以表示为
∂J(w,b,x,y)∂WL=ηL(aL−1)T ∂ J ( w , b , x , y ) ∂ W L = η L ( a L − 1 ) T
∂J(w,b,x,y)∂bL=ηL ∂ J ( w , b , x , y ) ∂ b L = η L
第二步,计算相邻两层之间的偏导关系
δL−1=∂J(w,b,x,y)∂zL∂zL∂zL−1=δLWLf′(zL−1)其中∂ZL∂ZL−1=WL⊙f′(zL−1) δ L − 1 = ∂ J ( w , b , x , y ) ∂ z L ∂ z L ∂ z L − 1 = δ L W L f ′ ( z L − 1 ) 其 中 ∂ Z L ∂ Z L − 1 = W L ⊙ f ′ ( z L − 1 )
注: ⊙ ⊙ 表示矩阵内积,即对应元素相乘
以此类推
δl=∂J(w,b,x,y)∂zl=∂J(w,b,x,y)∂zL∂zL∂zL−1⋯∂zl+1∂zl=δl+1∂zl+1∂zl δ l = ∂ J ( w , b , x , y ) ∂ z l = ∂ J ( w , b , x , y ) ∂ z L ∂ z L ∂ z L − 1 ⋯ ∂ z l + 1 ∂ z l = δ l + 1 ∂ z l + 1 ∂ z l
∂J(w,b,x,y)∂Wl=∂J(w,b,x,y)∂zl⋅∂zl∂Wl=δl(al−1)T ∂ J ( w , b , x , y ) ∂ W l = ∂ J ( w , b , x , y ) ∂ z l ⋅ ∂ z l ∂ W l = δ l ( a l − 1 ) T
∂J(w,b,x,y)∂bl=∂J(w,b,x,y)∂zl⋅∂zl∂bl=δl ∂ J ( w , b , x , y ) ∂ b l = ∂ J ( w , b , x , y ) ∂ z l ⋅ ∂ z l ∂ b l = δ l
至此,整个式子可用。
因此,反向传播算法过程
从 L 到 2,计算
∂J(w,b,x,y)∂Wl=δl(al−1)T ∂ J ( w , b , x , y ) ∂ W l = δ l ( a l − 1 ) T
∂J(w,b,x,y)∂bl=δl ∂ J ( w , b , x , y ) ∂ b l = δ l
计算 f′(zl−1) f ′ ( z l − 1 )
δl−1=δlWlf′(zl−1) δ l − 1 = δ l W l f ′ ( z l − 1 )
注:以上推导没有对任何损失函数和激活函数进行假设,因此可以为任何激活函数,或损失函数。实际上损失函数和激活函数的导数都比较简单,比如平方差损失函数,sigmoid 激活函数。
对于任意的函数,
δL=∂J(w,b,x,y)∂zL δ L = ∂ J ( w , b , x , y ) ∂ z L
∂J(w,b,x,y)∂Wl=∂J(w,b,x,y)∂zl⋅∂zl∂Wl=δl∂zl∂Wl ∂ J ( w , b , x , y ) ∂ W l = ∂ J ( w , b , x , y ) ∂ z l ⋅ ∂ z l ∂ W l = δ l ∂ z l ∂ W l
∂J(w,b,x,y)∂bl=∂J(w,b,x,y)∂zl⋅∂zl∂bl=δl ∂ J ( w , b , x , y ) ∂ b l = ∂ J ( w , b , x , y ) ∂ z l ⋅ ∂ z l ∂ b l = δ l
δl=∂J(w,b,x,y)∂zl=∂J(w,b,x,y)∂zL∂zL∂zL−1⋯∂zl+1∂zl=δl+1∂zl+1∂zl δ l = ∂ J ( w , b , x , y ) ∂ z l = ∂ J ( w , b , x , y ) ∂ z L ∂ z L ∂ z L − 1 ⋯ ∂ z l + 1 ∂ z l = δ l + 1 ∂ z l + 1 ∂ z l
对于 CNN 网络
因此,通过上面,可知,CNN 与 DNN 最大的区别在于 ∂zl∂Wl ∂ z l ∂ W l 和 ∂zl+1∂al ∂ z l + 1 ∂ a l 的求解。
求解 ∂zl+1∂al ∂ z l + 1 ∂ a l
zl+1=W∗al+bl z l + 1 = W ∗ a l + b l
为了方便理解,以一个例子为例
z = a * w
a=[[a11,a12,a13],[a21,a22,a23],[a31,a32,a33]] a = [ [ a 11 , a 12 , a 13 ] , [ a 21 , a 22 , a 23 ] , [ a 31 , a 32 , a 33 ] ]
w=[[w11,w12],[w21,w22]] w = [ [ w 11 , w 12 ] , [ w 21 , w 22 ] ]
z=[[z11,z12],[z21,z22]] z = [ [ z 11 , z 12 ] , [ z 21 , z 22 ] ]
z11=a11w11+a12w12+a21w21+a22w22 z 11 = a 11 w 11 + a 12 w 12 + a 21 w 21 + a 22 w 22
z12=a12w11+a13w12+a22w21+a23w22 z 12 = a 12 w 11 + a 13 w 12 + a 22 w 21 + a 23 w 22
z21=a21w11+a22w12+a31w21+a32w22 z 21 = a 21 w 11 + a 22 w 12 + a 31 w 21 + a 32 w 22
z22=a22w11+a23w12+a32w21+a33w22 z 22 = a 22 w 11 + a 23 w 12 + a 32 w 21 + a 33 w 22
∂z11∂a11=w11 ∂ z 11 ∂ a 11 = w 11
∂z11∂a12=w12+w11 ∂ z 11 ∂ a 12 = w 12 + w 11
∂z11∂a13=w12 ∂ z 11 ∂ a 13 = w 12
h(t)=tanh(Ux(t)+Wh(t−1)+b) h ( t ) = tanh ( U x ( t ) + W h ( t − 1 ) + b )
o(t)=Vh(t)+c o ( t ) = V h ( t ) + c
y(t)=σ(o(t)) y ( t ) = σ ( o ( t ) )
由于是基于时间的,所有也叫 BPTT(back propagation through time)
损失函数
L=∑τt=1L(t) L = ∑ t = 1 τ L ( t )
∂L∂o(t)=∑τt=2(y−ŷ ) ∂ L ∂ o ( t ) = ∑ t = 2 τ ( y − y ^ )
∂L∂c=∑τt=1∂L(t)∂o(t)∂o(t)∂c=∑τt=1∂L(t)∂o(t)=∑τt=1(y−ŷ ) ∂ L ∂ c = ∑ t = 1 τ ∂ L ( t ) ∂ o ( t ) ∂ o ( t ) ∂ c = ∑ t = 1 τ ∂ L ( t ) ∂ o ( t ) = ∑ t = 1 τ ( y − y ^ )
∂L∂V=∑τt=1∂L(t)∂o(t)∂o(t)∂V=∑τt=1∂L(t)∂o(t)(h(t))T=∑τt=1(y−ŷ )(h(t))T ∂ L ∂ V = ∑ t = 1 τ ∂ L ( t ) ∂ o ( t ) ∂ o ( t ) ∂ V = ∑ t = 1 τ ∂ L ( t ) ∂ o ( t ) ( h ( t ) ) T = ∑ t = 1 τ ( y − y ^ ) ( h ( t ) ) T
δ(t)=∂L∂h(t)=∂L(t)∂o(t)∂o(t)∂h(t)+∂L(t+1)∂h(t+1)∂h(t+1)∂h(t)=VT∂L∂o(t)+WTδ(t+1)diag(1−(h(t+1))2) δ ( t ) = ∂ L ∂ h ( t ) = ∂ L ( t ) ∂ o ( t ) ∂ o ( t ) ∂ h ( t ) + ∂ L ( t + 1 ) ∂ h ( t + 1 ) ∂ h ( t + 1 ) ∂ h ( t ) = V T ∂ L ∂ o ( t ) + W T δ ( t + 1 ) d i a g ( 1 − ( h ( t + 1 ) ) 2 )
由于 δ(τ) δ ( τ ) 后没有其他索引了,因此
δ(τ)=VT∂L∂o(t) δ ( τ ) = V T ∂ L ∂ o ( t ) 有了该式,那么根据上式既可以依次计算所有 δ(t) δ ( t ) 啦
∂L∂W=∑τt=1∂L∂h(t)∂h(t)∂W=∑τt=1diag(1−(h(t))2)δ(t)(h(t−1))T ∂ L ∂ W = ∑ t = 1 τ ∂ L ∂ h ( t ) ∂ h ( t ) ∂ W = ∑ t = 1 τ d i a g ( 1 − ( h ( t ) ) 2 ) δ ( t ) ( h ( t − 1 ) ) T
∂L∂U=∑τt=1∂L∂h(t)∂h(t)∂U=∑τt=1diag(1−(h(t))2)δ(t)(x(t))T ∂ L ∂ U = ∑ t = 1 τ ∂ L ∂ h ( t ) ∂ h ( t ) ∂ U = ∑ t = 1 τ d i a g ( 1 − ( h ( t ) ) 2 ) δ ( t ) ( x ( t ) ) T
∂L∂b=∑τt=1∂L∂h(t)∂h(t)∂b=∑τt=1diag(1−(h(t))2)δ(t) ∂ L ∂ b = ∑ t = 1 τ ∂ L ∂ h ( t ) ∂ h ( t ) ∂ b = ∑ t = 1 τ d i a g ( 1 − ( h ( t ) ) 2 ) δ ( t )
由上可知,相邻两个 δ(t) δ ( t ) 和 δ(t+1) δ ( t + 1 ) 之间是 W^T 的关系,因此,随着深度增加,梯度以是指数级增加。
以上就是 BP 的 一些总结,希望对你能有帮助。