Derivation of the Backpropagation Algorithm (Very Detailed)

1. Forward Propagation

Assume $X$ is an $N \times m$ matrix, where $N$ is the number of samples (the batch size) and $m$ is the feature dimension.

$h_1$ and $Z_1$ have dimension $m_1 \rightarrow W_1$ is an $m \times m_1$ matrix, $b_1 \in \mathbb{R}^{m_1}$,
$h_2$ and $Z_2$ have dimension $m_2 \rightarrow W_2$ is an $m_1 \times m_2$ matrix, $b_2 \in \mathbb{R}^{m_2}$,
$\vdots$
$h_L$ and $Z_L$ have dimension $m_L \rightarrow W_L$ is an $m_{L-1} \times m_L$ matrix, $b_L \in \mathbb{R}^{m_L}$.

Forward pass:

$$
\begin{array}{l}
h_{1}=X W_{1}+\tilde{b}_{1}, \quad Z_{1}=f_{1}(h_{1}), \quad \text{where } \tilde{b}_{1} \text{ is } b_{1}^{T} \text{ tiled along the row direction into } N \text{ rows} \\
h_{2}=Z_{1} W_{2}+\tilde{b}_{2}, \quad Z_{2}=f_{2}(h_{2}) \\
\vdots \\
h_{L}=Z_{L-1} W_{L}+\tilde{b}_{L}, \quad Z_{L}=f_{L}(h_{L}) \\
\text{out}=Z_{L} W_{L+1}+\tilde{b}_{L+1}
\end{array}
$$
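As a concrete illustration, here is a minimal NumPy sketch of this forward pass. The layer sizes, the choice of sigmoid for every hidden activation, and the random initialization are assumptions made only for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(h):
    return 1.0 / (1.0 + np.exp(-h))

# Dimensions: batch size N, input dim m, hidden dims m1/m2, output dim n (illustrative)
N, m, m1, m2, n = 4, 5, 8, 6, 3
X = rng.normal(size=(N, m))

W1, b1 = 0.1 * rng.normal(size=(m, m1)), np.zeros(m1)
W2, b2 = 0.1 * rng.normal(size=(m1, m2)), np.zeros(m2)
W3, b3 = 0.1 * rng.normal(size=(m2, n)), np.zeros(n)   # plays the role of W_{L+1}, b_{L+1} with L = 2

h1 = X @ W1 + b1      # broadcasting adds b1 to every row, i.e. the tilde-b above
Z1 = sigmoid(h1)
h2 = Z1 @ W2 + b2
Z2 = sigmoid(h2)
out = Z2 @ W3 + b3    # out has shape N x n
print(out.shape)      # (4, 3)
```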
Suppose the output has $n$ dimensions; then $\text{out}$ is an $N \times n$ matrix. From the MSE or CE criterion we can compute $\frac{\partial J}{\partial \text{out}}$. For regression and classification problems, $\frac{\partial J}{\partial \text{out}}$ is obtained as follows:

[Figure: regression problem]
[Figure: classification problem]
  • For regression, the loss is computed directly on $\text{out}$, using MSE as the loss function. Loss: $J=\frac{1}{2N}\sum_{i=1}^{N}\|y_i-\tilde{y_i}\|^2$ (a code sketch for both losses follows after this list).
    $$\frac{\partial J}{\partial y_i}=\frac{1}{2N}\cdot(y_i-\tilde{y_i})\times 2=\frac{1}{N}(y_i-\tilde{y_i})$$
  • For classification, $\text{out}$ is followed by a softmax layer, and CE (cross entropy) is used to compute the loss: $S_k=\frac{e^{y_k}}{\sum_{i=1}^{n}e^{y_i}}$. For one sample, the network output $S=(s_1,s_2,...,s_n)$ is a probability distribution, while the sample's label $\tilde{S}$ is typically $(0,0,...,1,0,0,...,0)$, which can also be viewed as a (hard) probability distribution. Since $\tilde{S}$ is one-hot its entropy is zero, so the cross entropy can be viewed as the KL divergence between $S$ and $\tilde{S}$:
    $$D(\tilde{S}\|S)=\sum\tilde{S}\log\frac{\tilde{S}}{S}$$
    • Assume $\tilde{S}=(0,0,...,1,0,0,...,0)$, where the 1 is the $k$-th element (indexing from 0), and let $S=(s_0,s_1,...,s_k,...,s_{n-1})$.
      Loss:
      $$\begin{aligned} J=D(\tilde{S}\|S)&=1\times \log\frac{1}{s_k}\\&=-\log s_k \quad(\text{the CE loss; equivalent to maximizing the probability of the target class})\\ &=-\log\frac{e^{y_k}}{\sum_{i=0}^{n-1}e^{y_i}} \end{aligned}$$
      $$\frac{\partial J}{\partial y_m}=\frac{\partial}{\partial y_m}\left(\log \sum_{i=0}^{n-1}e^{y_i}-y_k\right)=\frac{e^{y_m}}{\sum_{i=0}^{n-1}e^{y_i}}-\delta(m=k)=s_m-\delta(m=k)$$
      In vector form: $\frac{\partial J}{\partial y}=S-\tilde{S}$
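The two gradients above can be sketched in NumPy as follows. Here `y`/`y_tilde` are assumed to be $N \times n$ prediction/target arrays, `out` holds the logits, `S_tilde` the one-hot labels, and averaging the CE loss and gradient over the batch is an added convention, not part of the derivation itself.

```python
import numpy as np

def mse_loss_and_grad(y, y_tilde):
    """Regression: J = 1/(2N) * sum_i ||y_i - y_i~||^2 and dJ/dy."""
    N = y.shape[0]
    J = np.sum((y - y_tilde) ** 2) / (2 * N)
    dJ_dy = (y - y_tilde) / N                 # matches dJ/dy_i = (y_i - y_i~)/N
    return J, dJ_dy

def softmax_ce_loss_and_grad(out, S_tilde):
    """Classification: softmax over the logits, mean CE over the batch, and dJ/dout."""
    shifted = out - out.max(axis=1, keepdims=True)            # subtract the max for numerical stability
    S = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    N = out.shape[0]
    J = -np.sum(S_tilde * np.log(S)) / N
    dJ_dout = (S - S_tilde) / N                               # per-sample S - S~, averaged over the batch
    return J, dJ_dout
```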

KL divergence (relative entropy): short for Kullback-Leibler Divergence, also called relative entropy. It measures the difference between two probability distributions over the same event space. Its physical meaning: over the same event space, it is the average number of extra bits needed per basic event (symbol) when events distributed according to $P(x)$ are encoded using the distribution $Q(x)$. Denoting the KL divergence by $D(P\|Q)$, it is computed as:
$$D(P\|Q)=\sum_{x\in X}P(x)\log\frac{P(x)}{Q(x)}$$
When the two distributions are identical, i.e., $P(X)=Q(X)$, the relative entropy is 0.
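A quick numerical illustration of these two facts; the helper name `kl_divergence` and the example distributions are assumptions made for the sketch.

```python
import numpy as np

def kl_divergence(P, Q):
    """D(P||Q) = sum_x P(x) log(P(x)/Q(x)); terms with P(x) = 0 contribute 0."""
    mask = P > 0
    return np.sum(P[mask] * np.log(P[mask] / Q[mask]))

P = np.array([0.5, 0.3, 0.2])
print(kl_divergence(P, P))                           # 0.0 when the distributions match
print(kl_divergence(P, np.array([0.4, 0.4, 0.2])))   # > 0 otherwise
```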

2. Backpropagation

$\text{out}=Z_{L} W_{L+1}+\tilde{b}_{L+1}$. To spell out the backpropagation algorithm in detail, assume $Z_L$ is a $2\times 3$ matrix and $W_{L+1}$ is a $3\times 2$ matrix:
$$
Z_{L}=\begin{pmatrix}z_{11} & z_{12} & z_{13} \\ z_{21} & z_{22} & z_{23}\end{pmatrix}_{2 \times 3},\quad
W_{L+1}=\begin{pmatrix}w_{11} & w_{12} \\ w_{21} & w_{22} \\ w_{31} & w_{32}\end{pmatrix}_{3 \times 2},\quad
\tilde{b}_{L+1}=\begin{pmatrix}b_{1} & b_{2} \\ b_{1} & b_{2}\end{pmatrix}_{2 \times 2},\quad
\text{out}=\begin{pmatrix}o_{11} & o_{12} \\ o_{21} & o_{22}\end{pmatrix}
$$
$$
\Rightarrow Z_{L}W_{L+1}+\tilde{b}_{L+1}=\begin{pmatrix}z_{11} w_{11}+z_{12} w_{21}+z_{13} w_{31}+b_1 & z_{11} w_{12}+z_{12} w_{22}+z_{13} w_{32}+b_2 \\ z_{21} w_{11}+z_{22} w_{21}+z_{23} w_{31}+b_1 & z_{21} w_{12}+z_{22} w_{22}+z_{23} w_{32}+b_2\end{pmatrix}=\text{out}.
$$
Therefore,
$$
\begin{array}{l}
o_{11}=z_{11} w_{11}+z_{12} w_{21}+z_{13} w_{31}+b_{1} \\
o_{12}=z_{11} w_{12}+z_{12} w_{22}+z_{13} w_{32}+b_{2} \\
o_{21}=z_{21} w_{11}+z_{22} w_{21}+z_{23} w_{31}+b_{1} \\
o_{22}=z_{21} w_{12}+z_{22} w_{22}+z_{23} w_{32}+b_{2}
\end{array}
$$

1) Derivative of the loss $J$ with respect to $W$:

$$
\begin{aligned}
\frac{\partial J}{\partial w_{11}} &=\frac{\partial J}{\partial o_{11}} z_{11}+\frac{\partial J}{\partial o_{21}} z_{21}, \quad \frac{\partial J}{\partial w_{12}}=\frac{\partial J}{\partial o_{12}} z_{11}+\frac{\partial J}{\partial o_{22}} z_{21} \\
\frac{\partial J}{\partial w_{21}} &=\frac{\partial J}{\partial o_{11}} z_{12}+\frac{\partial J}{\partial o_{21}} z_{22}, \quad \frac{\partial J}{\partial w_{22}}=\frac{\partial J}{\partial o_{12}} z_{12}+\frac{\partial J}{\partial o_{22}} z_{22} \\
\frac{\partial J}{\partial w_{31}} &=\frac{\partial J}{\partial o_{11}} z_{13}+\frac{\partial J}{\partial o_{21}} z_{23}, \quad \frac{\partial J}{\partial w_{32}}=\frac{\partial J}{\partial o_{12}} z_{13}+\frac{\partial J}{\partial o_{22}} z_{23}
\end{aligned}
$$
$$
\Rightarrow \begin{pmatrix}\frac{\partial J}{\partial w_{11}} & \frac{\partial J}{\partial w_{12}} \\ \frac{\partial J}{\partial w_{21}} & \frac{\partial J}{\partial w_{22}} \\ \frac{\partial J}{\partial w_{31}} & \frac{\partial J}{\partial w_{32}}\end{pmatrix}=\begin{pmatrix}z_{11} & z_{21} \\ z_{12} & z_{22} \\ z_{13} & z_{23}\end{pmatrix}\begin{pmatrix}\frac{\partial J}{\partial o_{11}} & \frac{\partial J}{\partial o_{12}} \\ \frac{\partial J}{\partial o_{21}} & \frac{\partial J}{\partial o_{22}}\end{pmatrix}
$$
That is, $\frac{\partial J}{\partial W_{L+1}}=Z_L^T\frac{\partial J}{\partial \text{out}}$.
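This closed form can be checked numerically. The sketch below uses $J=\sum\text{out}$ purely so that $\frac{\partial J}{\partial \text{out}}$ becomes a matrix of ones; that choice of loss, and the random test matrices, are assumptions made only for the check.

```python
import numpy as np

rng = np.random.default_rng(1)
Z = rng.normal(size=(2, 3))          # plays the role of Z_L
W = rng.normal(size=(3, 2))          # plays the role of W_{L+1}
b = rng.normal(size=2)               # bias, broadcast over rows

dJ_dout = np.ones((2, 2))            # gradient of J = out.sum() w.r.t. out
dJ_dW = Z.T @ dJ_dout                # the closed form derived above

# Finite-difference check of a single entry, e.g. w_21 (index [1, 0]):
eps = 1e-6
W_pert = W.copy()
W_pert[1, 0] += eps
numeric = ((Z @ W_pert + b).sum() - (Z @ W + b).sum()) / eps
print(np.isclose(dJ_dW[1, 0], numeric))   # True
```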

2) The derivative of the loss with respect to the bias $b$ is obtained by summing each column of $\frac{\partial J}{\partial \text{out}}$:

$$
\left\{\begin{array}{l}
\frac{\partial J}{\partial b_{1}}=\frac{\partial J}{\partial o_{11}}+\frac{\partial J}{\partial o_{21}} \\
\frac{\partial J}{\partial b_{2}}=\frac{\partial J}{\partial o_{12}}+\frac{\partial J}{\partial o_{22}}
\end{array}\right.
\Rightarrow\left(\frac{\partial J}{\partial b}\right)^{T}=\left(\frac{\partial J}{\partial b_{1}} \quad \frac{\partial J}{\partial b_{2}}\right)=\left(\frac{\partial J}{\partial o_{11}}+\frac{\partial J}{\partial o_{21}} \quad \frac{\partial J}{\partial o_{12}}+\frac{\partial J}{\partial o_{22}}\right)
$$
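In NumPy this is a single column-wise sum (`axis=0`); the example gradient values below are made up for illustration.

```python
import numpy as np

dJ_dout = np.array([[0.1, -0.2],
                    [0.3,  0.4]])   # an illustrative upstream gradient, one row per sample
dJ_db = dJ_dout.sum(axis=0)         # column-wise sum: [dJ/db_1, dJ/db_2]
print(dJ_db)                        # approximately [0.4  0.2]
```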

3) Derivative of the loss $J$ with respect to $Z$:

$$
\begin{aligned}
\frac{\partial J}{\partial z_{11}} &=\frac{\partial J}{\partial o_{11}} w_{11}+\frac{\partial J}{\partial o_{12}} w_{12}; \quad \frac{\partial J}{\partial z_{12}}=\frac{\partial J}{\partial o_{11}} w_{21}+\frac{\partial J}{\partial o_{12}} w_{22}; \quad \frac{\partial J}{\partial z_{13}}=\frac{\partial J}{\partial o_{11}} w_{31}+\frac{\partial J}{\partial o_{12}} w_{32} \\
\frac{\partial J}{\partial z_{21}} &=\frac{\partial J}{\partial o_{21}} w_{11}+\frac{\partial J}{\partial o_{22}} w_{12}; \quad \frac{\partial J}{\partial z_{22}}=\frac{\partial J}{\partial o_{21}} w_{21}+\frac{\partial J}{\partial o_{22}} w_{22}; \quad \frac{\partial J}{\partial z_{23}}=\frac{\partial J}{\partial o_{21}} w_{31}+\frac{\partial J}{\partial o_{22}} w_{32}
\end{aligned}
$$
That is,
$$
\begin{pmatrix}\frac{\partial J}{\partial z_{11}} & \frac{\partial J}{\partial z_{12}} & \frac{\partial J}{\partial z_{13}} \\ \frac{\partial J}{\partial z_{21}} & \frac{\partial J}{\partial z_{22}} & \frac{\partial J}{\partial z_{23}}\end{pmatrix}=\begin{pmatrix}\frac{\partial J}{\partial o_{11}} & \frac{\partial J}{\partial o_{12}} \\ \frac{\partial J}{\partial o_{21}} & \frac{\partial J}{\partial o_{22}}\end{pmatrix}\begin{pmatrix}w_{11} & w_{21} & w_{31} \\ w_{12} & w_{22} & w_{32}\end{pmatrix}
$$
$$\Rightarrow \frac{\partial J}{\partial Z_{L}}=\frac{\partial J}{\partial \text{out}}W_{L+1}^T$$
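Taken together, results 1) through 3) form the complete backward step of the last linear layer. A minimal sketch follows; the function name `linear_backward` and the example shapes are assumptions made for the illustration.

```python
import numpy as np

def linear_backward(dJ_dout, Z, W):
    """Backward step of out = Z W + b, returning the three gradients derived above."""
    dJ_dW = Z.T @ dJ_dout            # result 1)
    dJ_db = dJ_dout.sum(axis=0)      # result 2)
    dJ_dZ = dJ_dout @ W.T            # result 3)
    return dJ_dW, dJ_db, dJ_dZ

rng = np.random.default_rng(2)
Z = rng.normal(size=(2, 3))
W = rng.normal(size=(3, 2))
dJ_dW, dJ_db, dJ_dZ = linear_backward(np.ones((2, 2)), Z, W)
print(dJ_dW.shape, dJ_db.shape, dJ_dZ.shape)   # (3, 2) (2,) (2, 3)
```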

4) Derivative of the loss $J$ with respect to $h$:

$Z_L = f_L(h_L)$

  • When $f_L$ is sigmoid, $Z_L=\frac{1}{1+e^{-h_L}}$.
    $$\frac{\partial J}{\partial h_{L}}=\frac{\partial J}{\partial Z_{L}} \frac{d Z_{L}}{d h_{L}}=\frac{\partial J}{\partial Z_{L}} \frac{e^{-h_{L}}}{\left(1+e^{-h_{L}}\right)^{2}}=\frac{\partial J}{\partial Z_{L}} \cdot \frac{1}{1+e^{-h_{L}}} \cdot \frac{e^{-h_{L}}}{1+e^{-h_{L}}}=\frac{\partial J}{\partial Z_{L}} Z_{L}\left(1-Z_{L}\right)$$
  • When $f_L$ is tanh, $Z_{L}=\frac{e^{h_{L}}-e^{-h_{L}}}{e^{h_{L}}+e^{-h_{L}}}$.
    $$\frac{\partial J}{\partial h_{L}}=\frac{\partial J}{\partial Z_{L}} \frac{d Z_{L}}{d h_{L}}=\frac{\partial J}{\partial Z_{L}} \frac{4}{\left(e^{h_{L}}+e^{-h_{L}}\right)^{2}}=\frac{\partial J}{\partial Z_{L}}\left[1-\left(\frac{e^{h_{L}}-e^{-h_{L}}}{e^{h_{L}}+e^{-h_{L}}}\right)^{2}\right]=\frac{\partial J}{\partial Z_{L}}\left[1-Z_{L}^{2}\right]$$
  • When $f_L$ is ReLU, $Z_L=\mathrm{relu}(h_L)=\left\{\begin{matrix} 0,&h_L\leq 0 \\ h_L,&h_L > 0 \end{matrix}\right.$.
    $$\frac{\partial J}{\partial h_L}=\frac{\partial J}{\partial Z_L}\frac{\partial Z_L}{\partial h_L}=\left\{\begin{matrix} 0,&h_L\leq 0 \\ \frac{\partial J}{\partial Z_L},&h_L > 0 \end{matrix}\right.$$
(Since the activations act element-wise, all products with $\frac{\partial J}{\partial Z_L}$ above are taken element-wise; see the code sketch after this list.)
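A minimal sketch of the three backward rules, assuming the forward pass cached $Z_L$ (for sigmoid and tanh) or the pre-activation $h_L$ (for ReLU); the helper names are mine.

```python
import numpy as np

def sigmoid_backward(dJ_dZ, Z):      # Z = sigmoid(h), saved during the forward pass
    return dJ_dZ * Z * (1.0 - Z)

def tanh_backward(dJ_dZ, Z):         # Z = tanh(h)
    return dJ_dZ * (1.0 - Z ** 2)

def relu_backward(dJ_dZ, h):         # ReLU needs the pre-activation h, not Z
    return dJ_dZ * (h > 0)
```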

3. Gradient Update

Putting the results above together, the gradients propagate backward layer by layer and the parameters are updated (shown here with plain gradient descent and learning rate $\eta$) as follows:
$$
\frac{\partial J}{\partial \text{out}} \Rightarrow \left\{\begin{matrix}
\frac{\partial J}{\partial W_{L+1}}=Z_L^T\frac{\partial J}{\partial \text{out}} \\
\frac{\partial J}{\partial Z_{L}}=\frac{\partial J}{\partial \text{out}}W_{L+1}^T \\
\left(\frac{\partial J}{\partial b}\right)^{T}=\mathrm{SumCol}\left(\frac{\partial J}{\partial \text{out}}\right) \\
W_{L+1}^{t+1} = W_{L+1}^t-\eta \frac{\partial J}{\partial W_{L+1}} \\
b_{L+1}^{t+1} = b_{L+1}^t-\eta \frac{\partial J}{\partial b_{L+1}}
\end{matrix}\right.
\Rightarrow \frac{\partial J}{\partial h_L}=\frac{\partial J}{\partial Z_L}\frac{\partial Z_L}{\partial h_L}
\Rightarrow \left\{\begin{matrix}
\frac{\partial J}{\partial W_{L}}=Z_{L-1}^T\frac{\partial J}{\partial h_L} \\
\frac{\partial J}{\partial Z_{L-1}}=\frac{\partial J}{\partial h_L}W_{L}^T \\
\vdots
\end{matrix}\right.
\Rightarrow \cdots
$$
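To tie the derivation together, here is a minimal end-to-end sketch of one hidden layer trained with MSE and plain gradient descent, following the recursion above; the network size, the random data, the step count, and the learning rate are all illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(3)
N, m, m1, n, eta = 8, 4, 5, 3, 0.1       # batch size, dims, learning rate (illustrative)

X = rng.normal(size=(N, m))
Y = rng.normal(size=(N, n))              # regression targets
W1, b1 = 0.1 * rng.normal(size=(m, m1)), np.zeros(m1)
W2, b2 = 0.1 * rng.normal(size=(m1, n)), np.zeros(n)

sigmoid = lambda h: 1.0 / (1.0 + np.exp(-h))

for step in range(200):
    # forward pass
    h1 = X @ W1 + b1
    Z1 = sigmoid(h1)
    out = Z1 @ W2 + b2
    J = np.sum((out - Y) ** 2) / (2 * N)

    # backward pass, following the recursion above
    dJ_dout = (out - Y) / N
    dJ_dW2 = Z1.T @ dJ_dout
    dJ_db2 = dJ_dout.sum(axis=0)
    dJ_dZ1 = dJ_dout @ W2.T
    dJ_dh1 = dJ_dZ1 * Z1 * (1.0 - Z1)    # sigmoid backward, element-wise
    dJ_dW1 = X.T @ dJ_dh1
    dJ_db1 = dJ_dh1.sum(axis=0)

    # plain gradient-descent update
    W2 -= eta * dJ_dW2; b2 -= eta * dJ_db2
    W1 -= eta * dJ_dW1; b1 -= eta * dJ_db1

print(J)   # the loss should have decreased over the 200 steps
```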
Reference: the 深度之眼 WeChat official account
