Assume $X$ is an $N\times m$ matrix, where $N$ is the number of samples (batch size) and $m$ is the feature dimension.
$h_1$ and $Z_1$ have dimension $m_1$, so $W_1$ is an $m\times m_1$ matrix and $b_1\in\mathbb{R}^{m_1}$;
$h_2$ and $Z_2$ have dimension $m_2$, so $W_2$ is an $m_1\times m_2$ matrix and $b_2\in\mathbb{R}^{m_2}$;
$\vdots$
$h_L$ and $Z_L$ have dimension $m_L$, so $W_L$ is an $m_{L-1}\times m_L$ matrix and $b_L\in\mathbb{R}^{m_L}$.
$$
\begin{aligned}
h_1 &= XW_1+\tilde b_1, & Z_1 &= f_1(h_1)\\
h_2 &= Z_1W_2+\tilde b_2, & Z_2 &= f_2(h_2)\\
&\;\;\vdots\\
h_L &= Z_{L-1}W_L+\tilde b_L, & Z_L &= f_L(h_L)\\
\text{out} &= Z_LW_{L+1}+\tilde b_{L+1}
\end{aligned}
$$
where $\tilde b_l$ is $b_l^T$ replicated along the row direction into $N$ rows, and $f_l$ is the activation function of layer $l$.
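A minimal plain-Python sketch of this forward pass, useful for checking the matrix shapes. It assumes a sigmoid activation for every hidden layer (the text leaves $f_l$ unspecified) and a linear output layer; `matmul`, `add_bias` and `forward` are illustrative helper names, not a library API.

```python
import math


def matmul(A, B):
    """Plain-Python product of A (p x q) and B (q x r), both stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def add_bias(H, b):
    """Add the row vector b to every row of H, i.e. H + b~ with b~ = b^T repeated N times."""
    return [[h + bj for h, bj in zip(row, b)] for row in H]


def sigmoid(H):
    """Element-wise sigmoid; assumed here as the hidden activation f_l."""
    return [[1.0 / (1.0 + math.exp(-h)) for h in row] for row in H]


def forward(X, weights, biases, act=sigmoid):
    """Compute h_l = Z_{l-1} W_l + b~_l and Z_l = f_l(h_l) for l = 1..L, then the linear out.

    weights = [W_1, ..., W_{L+1}], biases = [b_1, ..., b_{L+1}].
    Returns (hs, Zs, out) with hs = [h_1, ..., h_L] and Zs = [Z_1, ..., Z_L].
    """
    hs, Zs = [], []
    Z = X  # Z_0 = X
    for W, b in zip(weights[:-1], biases[:-1]):
        h = add_bias(matmul(Z, W), b)
        Z = act(h)
        hs.append(h)
        Zs.append(Z)
    out = add_bias(matmul(Z, weights[-1]), biases[-1])
    return hs, Zs, out


# Usage: N = 2 samples, m = 2 features, one hidden layer with m_1 = 3, and n = 1 output.
X = [[1.0, 2.0], [3.0, 4.0]]
W1 = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]   # m x m_1
b1 = [0.0, 0.0, 0.0]
W2 = [[1.0], [1.0], [1.0]]                # m_1 x n
b2 = [0.5]
hs, Zs, out = forward(X, [W1, W2], [b1, b2])
print(out)  # [[2.0], [2.0]]
```

With all-zero $W_1$ every hidden unit outputs $\mathrm{sigmoid}(0)=0.5$, so each output is $3\times 0.5+0.5=2$, and `out` has the expected $N\times n$ shape.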
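Assume the output is $n$-dimensional. Then $\text{out}$ is an $N\times n$ matrix (so $W_{L+1}$ is $m_L\times n$ and $b_{L+1}\in\mathbb{R}^{n}$), and $\frac{\partial J}{\partial \text{out}}$ can be obtained from the loss criterion: MSE for regression problems, CE (cross-entropy) for classification problems. The form of $\frac{\partial J}{\partial \text{out}}$ therefore differs between the two cases.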
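For reference, a sketch of the two usual forms under assumed normalisations: for regression $J=\frac{1}{2N}\sum_{i,j}(o_{ij}-y_{ij})^2$, which gives $\frac{\partial J}{\partial \text{out}}=\frac{1}{N}(\text{out}-Y)$; for classification, softmax over each row of out followed by the mean cross-entropy with one-hot labels $Y$, which gives $\frac{1}{N}(\mathrm{softmax}(\text{out})-Y)$. Other scalings (for example dropping the $\frac{1}{N}$) only change a constant factor. The target matrix $Y$ and the function names are illustrative.

```python
import math


def mse_grad(out, Y):
    """dJ/dout for the assumed regression loss J = 1/(2N) * sum((out - Y)**2)."""
    N = len(out)
    return [[(o - y) / N for o, y in zip(ro, ry)] for ro, ry in zip(out, Y)]


def softmax(row):
    """Softmax of one row; subtracting the max keeps exp from overflowing."""
    m = max(row)
    e = [math.exp(v - m) for v in row]
    s = sum(e)
    return [v / s for v in e]


def ce_grad(out, Y):
    """dJ/dout for the assumed classification loss
    J = -(1/N) * sum(Y * log(softmax(out))), where each row of Y is one-hot."""
    N = len(out)
    return [[(p - y) / N for p, y in zip(softmax(ro), ry)] for ro, ry in zip(out, Y)]


print(mse_grad([[3.0], [1.0]], [[1.0], [1.0]]))  # [[1.0], [0.0]]
print(ce_grad([[0.0, 0.0]], [[1.0, 0.0]]))       # [[-0.5, 0.5]]
```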
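KL divergence (relative entropy): KL is short for Kullback-Leibler divergence, also called relative entropy. It measures how much two probability distributions over the same event space differ. Its physical meaning: in a given event space, if events drawn from the distribution $P(x)$ are encoded with a code designed for the distribution $Q(x)$, it is the average number of extra bits needed per basic event (symbol), taking the logarithm to base 2. We write the KL divergence as $D(P\|Q)$ and compute it as: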
$$D(P\|Q)=\sum_{x\in \mathcal{X}}P(x)\log\frac{P(x)}{Q(x)}$$
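When the two distributions are identical, i.e. $P(x)=Q(x)$ for every $x$, the relative entropy is 0; otherwise it is positive. Despite the name "KL distance", it is not a true distance: in general $D(P\|Q)\neq D(Q\|P)$.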
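A direct transcription of the formula, as a small sketch. It assumes discrete distributions given as equal-length lists and uses base 2 so the result is in bits; `kl_divergence` is an illustrative name.

```python
import math


def kl_divergence(P, Q, base=2.0):
    """D(P||Q) = sum_x P(x) * log(P(x) / Q(x)); base 2 gives the result in bits.

    Terms with P(x) == 0 contribute 0 by convention; Q(x) must be > 0 wherever P(x) > 0.
    """
    return sum(p * math.log(p / q, base) for p, q in zip(P, Q) if p > 0)


print(kl_divergence([0.5, 0.5], [0.5, 0.5]))  # 0.0: identical distributions
print(kl_divergence([1.0, 0.0], [0.5, 0.5]))  # 1.0: one extra bit per symbol
```

The second case matches the coding interpretation: a source that always emits the same symbol needs 0 bits, while a code built for the uniform $Q$ spends 1 bit per symbol. Swapping the arguments, as in `kl_divergence([0.9, 0.1], [0.5, 0.5])` versus `kl_divergence([0.5, 0.5], [0.9, 0.1])`, generally gives different values.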
For $\text{out}=Z_LW_{L+1}+\tilde b_{L+1}$, to illustrate the backpropagation algorithm in detail, take $Z_L$ to be a $2\times 3$ matrix and $W_{L+1}$ a $3\times 2$ matrix (that is, $N=2$, $m_L=3$, $n=2$):
$$
Z_L=\begin{pmatrix} z_{11} & z_{12} & z_{13}\\ z_{21} & z_{22} & z_{23}\end{pmatrix}_{2\times 3},\quad
W_{L+1}=\begin{pmatrix} w_{11} & w_{12}\\ w_{21} & w_{22}\\ w_{31} & w_{32}\end{pmatrix}_{3\times 2},\quad
\tilde b_{L+1}=\begin{pmatrix} b_1 & b_2\\ b_1 & b_2\end{pmatrix}_{2\times 2},\quad
\text{out}=\begin{pmatrix} o_{11} & o_{12}\\ o_{21} & o_{22}\end{pmatrix}
$$
$$
\Rightarrow Z_LW_{L+1}+\tilde b_{L+1}=\begin{pmatrix} z_{11}w_{11}+z_{12}w_{21}+z_{13}w_{31}+b_1 & z_{11}w_{12}+z_{12}w_{22}+z_{13}w_{32}+b_2\\ z_{21}w_{11}+z_{22}w_{21}+z_{23}w_{31}+b_1 & z_{21}w_{12}+z_{22}w_{22}+z_{23}w_{32}+b_2\end{pmatrix}=\text{out}.
$$
Therefore,
$$
\begin{aligned}
o_{11}&=z_{11}w_{11}+z_{12}w_{21}+z_{13}w_{31}+b_1\\
o_{12}&=z_{11}w_{12}+z_{12}w_{22}+z_{13}w_{32}+b_2\\
o_{21}&=z_{21}w_{11}+z_{22}w_{21}+z_{23}w_{31}+b_1\\
o_{22}&=z_{21}w_{12}+z_{22}w_{22}+z_{23}w_{32}+b_2
\end{aligned}
$$
By the chain rule, $w_{kj}$ appears only in column $j$ of out, with coefficient $z_{1k}$ in $o_{1j}$ and $z_{2k}$ in $o_{2j}$, so:
$$
\begin{aligned}
\frac{\partial J}{\partial w_{11}}&=\frac{\partial J}{\partial o_{11}}z_{11}+\frac{\partial J}{\partial o_{21}}z_{21}, &
\frac{\partial J}{\partial w_{12}}&=\frac{\partial J}{\partial o_{12}}z_{11}+\frac{\partial J}{\partial o_{22}}z_{21}\\
\frac{\partial J}{\partial w_{21}}&=\frac{\partial J}{\partial o_{11}}z_{12}+\frac{\partial J}{\partial o_{21}}z_{22}, &
\frac{\partial J}{\partial w_{22}}&=\frac{\partial J}{\partial o_{12}}z_{12}+\frac{\partial J}{\partial o_{22}}z_{22}\\
\frac{\partial J}{\partial w_{31}}&=\frac{\partial J}{\partial o_{11}}z_{13}+\frac{\partial J}{\partial o_{21}}z_{23}, &
\frac{\partial J}{\partial w_{32}}&=\frac{\partial J}{\partial o_{12}}z_{13}+\frac{\partial J}{\partial o_{22}}z_{23}
\end{aligned}
$$
$$
\Rightarrow\begin{pmatrix}
\frac{\partial J}{\partial w_{11}} & \frac{\partial J}{\partial w_{12}}\\
\frac{\partial J}{\partial w_{21}} & \frac{\partial J}{\partial w_{22}}\\
\frac{\partial J}{\partial w_{31}} & \frac{\partial J}{\partial w_{32}}
\end{pmatrix}
=\begin{pmatrix} z_{11} & z_{21}\\ z_{12} & z_{22}\\ z_{13} & z_{23}\end{pmatrix}
\begin{pmatrix}\frac{\partial J}{\partial o_{11}} & \frac{\partial J}{\partial o_{12}}\\ \frac{\partial J}{\partial o_{21}} & \frac{\partial J}{\partial o_{22}}\end{pmatrix}
$$
That is, $\frac{\partial J}{\partial W_{L+1}}=Z_L^T\frac{\partial J}{\partial \text{out}}$.
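The identity can be checked numerically. The sketch below keeps the same $2\times 3$ and $3\times 2$ shapes, uses a hypothetical loss $J=\frac12\sum_{ij}o_{ij}^2$ (chosen only because then $\frac{\partial J}{\partial \text{out}}=\text{out}$), and compares $Z_L^T\frac{\partial J}{\partial \text{out}}$ entry by entry with central finite differences. The numeric values and helper names are arbitrary.

```python
def matmul(A, B):
    """Plain-Python matrix product (lists of rows)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def forward_out(Z, W, b):
    """out = Z W + b~ (b added to every row)."""
    return [[v + bj for v, bj in zip(row, b)] for row in matmul(Z, W)]


def J(Z, W, b):
    """Hypothetical loss J = 0.5 * sum(out**2), so that dJ/dout = out."""
    return 0.5 * sum(v * v for row in forward_out(Z, W, b) for v in row)


def max_weight_grad_error(Z, W, b, eps=1e-6):
    """Largest gap between Z^T (dJ/dout) and a central finite difference of J in each w_kj."""
    analytic = matmul(transpose(Z), forward_out(Z, W, b))  # dJ/dout = out for this J
    worst = 0.0
    for k in range(len(W)):
        for j in range(len(W[0])):
            Wp = [row[:] for row in W]
            Wm = [row[:] for row in W]
            Wp[k][j] += eps
            Wm[k][j] -= eps
            numeric = (J(Z, Wp, b) - J(Z, Wm, b)) / (2 * eps)
            worst = max(worst, abs(numeric - analytic[k][j]))
    return worst


Z = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]      # Z_L, 2 x 3
W = [[0.1, -0.2], [0.3, 0.4], [-0.5, 0.6]]  # W_{L+1}, 3 x 2
b = [0.1, -0.1]                             # b_{L+1}
print(max_weight_grad_error(Z, W, b))       # expected to be close to 0
```

Because this $J$ is quadratic in each weight, the central difference has no truncation error, so any remaining gap is floating-point rounding.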
For the bias, $b_j$ is added to every entry of column $j$ of out, so:
$$
\begin{cases}
\dfrac{\partial J}{\partial b_1}=\dfrac{\partial J}{\partial o_{11}}+\dfrac{\partial J}{\partial o_{21}}\\[2mm]
\dfrac{\partial J}{\partial b_2}=\dfrac{\partial J}{\partial o_{12}}+\dfrac{\partial J}{\partial o_{22}}
\end{cases}
\Rightarrow\left(\frac{\partial J}{\partial b}\right)^T=\begin{pmatrix}\dfrac{\partial J}{\partial b_1} & \dfrac{\partial J}{\partial b_2}\end{pmatrix}=\begin{pmatrix}\dfrac{\partial J}{\partial o_{11}}+\dfrac{\partial J}{\partial o_{21}} & \dfrac{\partial J}{\partial o_{12}}+\dfrac{\partial J}{\partial o_{22}}\end{pmatrix}
$$
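In matrix form the same result follows from writing the broadcast bias as a product with the all-ones column vector $\mathbf{1}_N$ (notation introduced here for the derivation): each $b_j$ enters all $N$ entries of column $j$ of out, so its gradient is that column's sum.

$$
\tilde b_{L+1}=\mathbf{1}_N\,b_{L+1}^T
\;\Rightarrow\;
\frac{\partial J}{\partial b_j}=\sum_{i=1}^{N}\frac{\partial J}{\partial o_{ij}}
\;\Rightarrow\;
\left(\frac{\partial J}{\partial b_{L+1}}\right)^T=\mathbf{1}_N^T\,\frac{\partial J}{\partial \text{out}}
$$

With $N=2$, $\mathbf{1}_2^T=(1\ \ 1)$ and the product reproduces the two column sums above.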
Likewise, $z_{ik}$ appears only in row $i$ of out, with coefficient $w_{kj}$ in $o_{ij}$:
$$
\begin{aligned}
\frac{\partial J}{\partial z_{11}}&=\frac{\partial J}{\partial o_{11}}w_{11}+\frac{\partial J}{\partial o_{12}}w_{12}, &
\frac{\partial J}{\partial z_{12}}&=\frac{\partial J}{\partial o_{11}}w_{21}+\frac{\partial J}{\partial o_{12}}w_{22}, &
\frac{\partial J}{\partial z_{13}}&=\frac{\partial J}{\partial o_{11}}w_{31}+\frac{\partial J}{\partial o_{12}}w_{32}\\
\frac{\partial J}{\partial z_{21}}&=\frac{\partial J}{\partial o_{21}}w_{11}+\frac{\partial J}{\partial o_{22}}w_{12}, &
\frac{\partial J}{\partial z_{22}}&=\frac{\partial J}{\partial o_{21}}w_{21}+\frac{\partial J}{\partial o_{22}}w_{22}, &
\frac{\partial J}{\partial z_{23}}&=\frac{\partial J}{\partial o_{21}}w_{31}+\frac{\partial J}{\partial o_{22}}w_{32}
\end{aligned}
$$
That is,
$$
\begin{pmatrix}
\frac{\partial J}{\partial z_{11}} & \frac{\partial J}{\partial z_{12}} & \frac{\partial J}{\partial z_{13}}\\
\frac{\partial J}{\partial z_{21}} & \frac{\partial J}{\partial z_{22}} & \frac{\partial J}{\partial z_{23}}
\end{pmatrix}
=\begin{pmatrix}\frac{\partial J}{\partial o_{11}} & \frac{\partial J}{\partial o_{12}}\\ \frac{\partial J}{\partial o_{21}} & \frac{\partial J}{\partial o_{22}}\end{pmatrix}
\begin{pmatrix} w_{11} & w_{21} & w_{31}\\ w_{12} & w_{22} & w_{32}\end{pmatrix}
$$
$$
\Rightarrow\frac{\partial J}{\partial Z_L}=\frac{\partial J}{\partial \text{out}}W_{L+1}^T
$$
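Collecting the three output-layer results, here is a sketch of one backward step through $\text{out}=Z_LW_{L+1}+\tilde b_{L+1}$. `output_layer_backward` is a hypothetical helper, and `G` stands for an arbitrary $\frac{\partial J}{\partial \text{out}}$; the usage checks one entry, $\frac{\partial J}{\partial z_{22}}$, against its scalar formula.

```python
def matmul(A, B):
    """Plain-Python matrix product (lists of rows)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def output_layer_backward(Z, W, G):
    """Backward step through out = Z W + b~, given G = dJ/dout.

    Returns (dJ/dW, dJ/db, dJ/dZ) = (Z^T G, SumCol(G), G W^T).
    """
    dW = matmul(transpose(Z), G)
    db = [sum(col) for col in zip(*G)]  # column sums over the N samples
    dZ = matmul(G, transpose(W))
    return dW, db, dZ


Z = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]    # Z_L (2 x 3)
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # W_{L+1} (3 x 2)
G = [[1.0, 2.0], [3.0, 4.0]]              # an arbitrary dJ/dout (2 x 2)
dW, db, dZ = output_layer_backward(Z, W, G)

# dJ/dz22 = dJ/do21 * w21 + dJ/do22 * w22
print(dZ[1][1], G[1][0] * W[1][0] + G[1][1] * W[1][1])  # 4.0 4.0
print(db)  # [4.0, 6.0]
```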
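Next, $Z_L=f_L(h_L)$. For an element-wise activation $f_L$ (such as sigmoid, tanh or ReLU), the chain rule gives $\frac{\partial J}{\partial h_L}=\frac{\partial J}{\partial Z_L}\odot f_L'(h_L)$, where $\odot$ denotes element-wise multiplication. The three formulas above then apply to layer $L$, with $\frac{\partial J}{\partial h_L}$ in place of $\frac{\partial J}{\partial \text{out}}$ and $Z_{L-1}$ in place of $Z_L$.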
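A sketch of this element-wise step for three common activations. The text does not fix $f_L$, so sigmoid, tanh and ReLU are illustrative choices: the sigmoid $s$ uses $s'(h)=s(h)(1-s(h))$, tanh uses $1-\tanh^2(h)$, and for ReLU the derivative at $h=0$ is taken as 0, which is a convention. The function names are hypothetical.

```python
import math


def sigmoid_backward(dZ, h):
    """dJ/dh for Z = sigmoid(h): dJ/dZ * s(h) * (1 - s(h)), element-wise."""
    result = []
    for rd, rh in zip(dZ, h):
        row = []
        for d, v in zip(rd, rh):
            s = 1.0 / (1.0 + math.exp(-v))
            row.append(d * s * (1.0 - s))
        result.append(row)
    return result


def tanh_backward(dZ, h):
    """dJ/dh for Z = tanh(h): dJ/dZ * (1 - tanh(h)**2), element-wise."""
    return [[d * (1.0 - math.tanh(v) ** 2) for d, v in zip(rd, rh)] for rd, rh in zip(dZ, h)]


def relu_backward(dZ, h):
    """dJ/dh for Z = max(h, 0); the derivative at h == 0 is taken as 0 by convention."""
    return [[d if v > 0 else 0.0 for d, v in zip(rd, rh)] for rd, rh in zip(dZ, h)]


print(sigmoid_backward([[2.0]], [[0.0]]))          # [[0.5]]
print(tanh_backward([[1.0]], [[0.0]]))             # [[1.0]]
print(relu_backward([[3.0, 4.0]], [[-1.0, 2.0]]))  # [[0.0, 4.0]]
```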
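The update step depends on the optimization algorithm. With plain gradient descent (learning rate $\eta$, iteration index $t$), the gradients and updates propagate backward layer by layer as follows, where $\mathrm{SumCol}(\cdot)$ sums each column of a matrix over the $N$ samples: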
$$
\frac{\partial J}{\partial \text{out}}\Rightarrow
\begin{cases}
\frac{\partial J}{\partial W_{L+1}}=Z_L^T\frac{\partial J}{\partial \text{out}}\\
\frac{\partial J}{\partial Z_L}=\frac{\partial J}{\partial \text{out}}W_{L+1}^T\\
\left(\frac{\partial J}{\partial b_{L+1}}\right)^T=\mathrm{SumCol}\left(\frac{\partial J}{\partial \text{out}}\right)\\
W_{L+1}^{t+1}=W_{L+1}^{t}-\eta\frac{\partial J}{\partial W_{L+1}}\\
b_{L+1}^{t+1}=b_{L+1}^{t}-\eta\frac{\partial J}{\partial b_{L+1}}
\end{cases}
\Rightarrow\frac{\partial J}{\partial h_L}=\frac{\partial J}{\partial Z_L}\odot f_L'(h_L)
\Rightarrow
\begin{cases}
\frac{\partial J}{\partial W_L}=Z_{L-1}^T\frac{\partial J}{\partial h_L}\\
\frac{\partial J}{\partial Z_{L-1}}=\frac{\partial J}{\partial h_L}W_L^T\\
\left(\frac{\partial J}{\partial b_L}\right)^T=\mathrm{SumCol}\left(\frac{\partial J}{\partial h_L}\right)\\
\vdots
\end{cases}
\Rightarrow\cdots
$$
Here $\odot$ is element-wise multiplication (assuming $f_L$ acts element-wise), and the pattern repeats down to the first layer, where $\frac{\partial J}{\partial W_1}=X^T\frac{\partial J}{\partial h_1}$.
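Putting the chain together, here is a minimal end-to-end sketch: sigmoid hidden layers, a linear output layer, the assumed loss $J=\frac{1}{2N}\sum(\text{out}-Y)^2$, and the plain gradient-descent updates above. The toy data, layer sizes, learning rate and step count are arbitrary illustrative choices, and all function names are hypothetical.

```python
import math
import random


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def forward(X, Ws, bs):
    """Return [Z_0 = X, Z_1, ..., Z_L, out]; hidden layers use sigmoid, the last is linear."""
    Zs = [X]
    for l, (W, b) in enumerate(zip(Ws, bs)):
        h = [[v + bj for v, bj in zip(row, b)] for row in matmul(Zs[-1], W)]
        if l < len(Ws) - 1:
            h = [[1.0 / (1.0 + math.exp(-v)) for v in row] for row in h]
        Zs.append(h)
    return Zs


def mse(out, Y):
    """Assumed loss J = 1/(2N) * sum((out - Y)**2)."""
    return sum((o - y) ** 2 for ro, ry in zip(out, Y) for o, y in zip(ro, ry)) / (2 * len(out))


def sgd_step(X, Y, Ws, bs, eta):
    """One backward pass plus gradient-descent update; Ws and bs are modified in place."""
    Zs = forward(X, Ws, bs)
    N = len(X)
    G = [[(o - y) / N for o, y in zip(ro, ry)] for ro, ry in zip(Zs[-1], Y)]  # dJ/dout
    for l in range(len(Ws) - 1, -1, -1):      # layer L+1 down to layer 1
        dW = matmul(transpose(Zs[l]), G)      # Z^T G
        db = [sum(col) for col in zip(*G)]    # SumCol(G)
        if l > 0:
            dZ = matmul(G, transpose(Ws[l]))  # G W^T, computed before W is updated
            # dJ/dh = dJ/dZ times f'(h) element-wise, with f'(h) = Z * (1 - Z) for sigmoid
            G = [[d * z * (1.0 - z) for d, z in zip(rd, rz)] for rd, rz in zip(dZ, Zs[l])]
        Ws[l] = [[w - eta * g for w, g in zip(rw, rg)] for rw, rg in zip(Ws[l], dW)]
        bs[l] = [v - eta * g for v, g in zip(bs[l], db)]


random.seed(0)
X = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]  # N = 4, m = 2
Y = [[0.0], [1.0], [1.0], [2.0]]                      # n = 1, target x1 + x2
Ws = [[[random.uniform(-1, 1) for _ in range(3)] for _ in range(2)],  # W_1: 2 x 3
      [[random.uniform(-1, 1)] for _ in range(3)]]                    # W_2: 3 x 1
bs = [[0.0, 0.0, 0.0], [0.0]]

before = mse(forward(X, Ws, bs)[-1], Y)
for _ in range(500):
    sgd_step(X, Y, Ws, bs, eta=0.1)
after = mse(forward(X, Ws, bs)[-1], Y)
print(before, after)  # the loss should go down
```

Each layer's $\frac{\partial J}{\partial Z}$ is computed from the weights before they are updated, which matches the order of the chain above; swapping in another optimizer would only change the two update lines.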
Reference: 深度之眼 WeChat official account