本文给出门控循环单元GRUCell的定义公式, 并求解其在反向传播中的梯度.
给出的相关公式是完整的, 编程导向的, 可以直接用于代码实现, 已通过 Python 验证.
配套代码, 请参考文章 :
纯 Python 和 PyTorch 对比实现门控循环单元 GRU 及反向传播
Affine 变换的定义和梯度, 请参考文章 :
affine/linear(仿射/线性)变换函数详解及全连接层反向传播的梯度求导
系列文章索引 :
https://blog.csdn.net/oBrightLamp/article/details/85067981
n考虑输入一个 3 阶张量 X l m n X_{lmn} Xlmn, 该张量可以表示为 l l l 个尺寸为 m × n m \times n m×n 的矩阵 X m n X_{mn} Xmn, 同时表明循环单元的输入尺寸为 n n n.
设第一个输入矩阵为 X m n ( 1 ) X_{mn}^{(1)} Xmn(1) , 对应的 3 个变换矩阵分别为 W r , W u , W c W_r,W_u,W_c Wr,Wu,Wc, 偏置向量为 a r , a u , a c a_r,a_u,a_c ar,au,ac .
设初始隐含层矩阵为 H m r ( 0 ) H_{mr}^{(0)} Hmr(0), 对应的 2 个变换矩阵分别为 V r , V u , V c V_r,V_u,V_c Vr,Vu,Vc, 偏置向量为 b r , b u , b c b_r,b_u,b_c br,bu,bc .
则一次 GRUCell 循环变换为 :
A r = X ( 1 ) W r T + a r + H ( 0 ) V r T + b r A u = X ( 1 ) W u T + a u + H ( 0 ) V u T + b u g r = s i g m o i d ( A r ) g u = s i g m o i d ( A u )    A c = X ( 1 ) W c T + a c + g r ⊙ ( H ( 0 ) V c T + b c ) g c = t a n h ( A c ) H ( 1 ) = ( 1 − g u ) ⊙ g c + g u ⊙ H ( 0 ) A_r =X^{(1)}{W_{r}}^T + a_{r} + H^{(0)}V_{r}^T + b_{r}\\ A_u =X^{(1)}{W_{u}}^T + a_{u} + H^{(0)}V_{u}^T + b_{u}\\ g_r =sigmoid(A_r)\\ g_u =sigmoid(A_u)\\ \;\\ A_c =X^{(1)}{W_{c}}^T + a_{c} + g_r \odot (H^{(0)}V_{c}^T + b_{c})\\ g_c =tanh(A_c)\\ H^{(1)} =(1 - g_u)\odot g_c + g_u \odot H^{(0)} Ar=X(1)WrT+ar+H(0)VrT+brAu=X(1)WuT+au+H(0)VuT+bugr=sigmoid(Ar)gu=sigmoid(Au)Ac=X(1)WcT+ac+gr⊙(H(0)VcT+bc)gc=tanh(Ac)H(1)=(1−gu)⊙gc+gu⊙H(0)
上式中的 ⊙ \odot ⊙ 表示 element-wise 元素积, 将以上过程记为 :
H ( 1 ) = G R U C e l l ( X ( 1 ) , H ( 0 ) ) H^{(1)} = GRUCell(X^{(1)},H^{(0)}) H(1)=GRUCell(X(1),H(0))
循环到下一次时, 将 H ( 1 ) , C ( 1 ) H^{(1)},C^{(1)} H(1),C(1) 代入 H ( 0 ) , C ( 0 ) H^{(0)},C^{(0)} H(0),C(0) 的位置, 与下一个 X ( 2 ) X^{(2)} X(2) 重新进行运算.
下面使用迭代记法表示 GRUCell 运算.
使用 H ( 0 ) H^{(0)} H(0) 表示初始隐含层矩阵, 对于 :
X l m n = X m n ( 1 ) , X m n ( 2 ) , X m n ( 3 ) , ⋯   , X m n ( l ) X_{lmn} = X_{mn}^{(1)},X_{mn}^{(2)},X_{mn}^{(3)},\cdots,X_{mn}^{(l)} Xlmn=Xmn(1),Xmn(2),Xmn(3),⋯,Xmn(l)
则 :
H ( 1 ) = G R U C e l l ( X ( 1 ) , H ( 0 ) )    H ( 2 ) = G R U C e l l ( X ( 2 ) , H ( 1 ) )    H ( 3 ) = G R U C e l l ( X ( 3 ) , H ( 2 ) ) ⋮ H ( l ) = G R U C e l l ( X ( l ) , H ( l − 1 ) ) H^{(1)} = GRUCell(X^{(1)},H^{(0)})\\ \;\\ H^{(2)} = GRUCell(X^{(2)},H^{(1)})\\ \;\\ H^{(3)} = GRUCell(X^{(3)},H^{(2)})\\ \vdots\\ H^{(l)} = GRUCell(X^{(l)},H^{(l-1)})\\ H(1)=GRUCell(X(1),H(0))H(2)=GRUCell(X(2),H(1))H(3)=GRUCell(X(3),H(2))⋮H(l)=GRUCell(X(l),H(l−1))
展开最后一层作为示例 :
A r = X ( l ) W r T + a r + H ( l − 1 ) V r T + b r A u = X ( l ) W u T + a u + H ( l − 1 ) V u T + b u g r = s i g m o i d ( A r ) g u = s i g m o i d ( A u )    A c = X ( l ) W c T + a c + g r ⊙ ( H ( l − 1 ) V c T + b c ) g c = t a n h ( A c ) H ( l ) = ( 1 − g u ) ⊙ g c + g u ⊙ H ( l − 1 ) A_r =X^{(l)}{W_{r}}^T + a_{r} + H^{(l-1)}V_{r}^T + b_{r}\\ A_u =X^{(l)}{W_{u}}^T + a_{u} + H^{(l-1)}V_{u}^T + b_{u}\\ g_r =sigmoid(A_r)\\ g_u =sigmoid(A_u)\\ \;\\ A_c =X^{(l)}{W_{c}}^T + a_{c} + g_r \odot (H^{(l-1)}V_{c}^T + b_{c})\\ g_c =tanh(A_c)\\ H^{(l)} =(1 - g_u)\odot g_c + g_u \odot H^{(l-1)} Ar=X(l)WrT+ar+H(l−1)VrT+brAu=X(l)WuT+au+H(l−1)VuT+bugr=sigmoid(Ar)gu=sigmoid(Au)Ac=X(l)WcT+ac+gr⊙(H(l−1)VcT+bc)gc=tanh(Ac)H(l)=(1−gu)⊙gc+gu⊙H(l−1)
在迭代的过程中 W ,    V ,    a ,    b W, \; V , \; a, \; b W,V,a,b 是共享的, 不变的.
使用 3 阶张量表示 :
H l m r = G R U C e l l ( l ) ( X l m n , H m r ( 0 ) ) H_{lmr} = GRUCell^{(l)}(X_{lmn},H_{mr}^{(0)}) Hlmr=GRUCell(l)(Xlmn,Hmr(0))
GRUCell 的上标 ( l ) (l) (l) 表示经过 l l l 次循环迭代计算, 输入尺寸为 l × m × n l \times m \times n l×m×n 的张量 X l m n X_{lmn} Xlmn 将输出尺寸为 l × m × r l \times m \times r l×m×r 的张量 H l m r H_{lmr} Hlmr .
考虑输入一个 3 阶张量 X l m n X_{lmn} Xlmn, 经过 GRUCell 运算后, 输出 3 阶张量 H l m r H_{lmr} Hlmr, 往前 forward 传播得到误差值 error ( 标量 e ), e 对 H l m r H_{lmr} Hlmr 的梯度 ∇ e ( H l m r ) \nabla e_{(H_{lmr})} ∇e(Hlmr) 已由上游给出, 求 e 对 X l m n X_{lmn} Xlmn 的梯度.
H i j n , C i j n = R N N C e l l ( i ) ( X i j k , H j n ( 0 ) , C j n ( 0 ) )    e = f o r w a r d ( H i j n ) H_{ijn},C_{ijn} = RNNCell^{(i)}(X_{ijk},H_{jn}^{(0)},C_{jn}^{(0)})\\ \;\\ e = forward(H_{ijn}) Hijn,Cijn=RNNCell(i)(Xijk,Hjn(0),Cjn(0))e=forward(Hijn)
从 GRUCell 运算的定义可以看出, 每一次循环迭代都是由 Affine 计算和激活函数计算组合而成.
Affine 计算的定义及梯度求导公式已在上面的 <相关> 中给出.
关于 Affine 的梯度 :
A = X W T + b    d e d X = ∇ e ( A ) W    d e d W = ∇ e ( A ) T X    d e d b = s u m ( ∇ e ( A ) ,    a x i s = 0 ) A = XW^T + b\\ \;\\ \frac {d e}{d X} =\nabla e_{(A)}W\\ \;\\ \frac {d e}{d W} =\nabla e_{(A)}^TX\\ \;\\ \frac {de}{db}=sum(\nabla e_{(A)},\; axis=0) A=XWT+bdXde=∇e(A)WdWde=∇e(A)TXdbde=sum(∇e(A),axis=0)
关于 tanh 的梯度 :
y = t a n h ( x ) = e x − e − x e x + e − x    d y d x = 1 − y 2 y = tanh(x)=\frac{e^x-e^{-x}}{e^x+e^{-x}} \\ \;\\ \frac{dy}{dx}= 1-y^2 y=tanh(x)=ex+e−xex−e−xdxdy=1−y2
关于 sigmoid 的梯度 :
y = s i g m o i d ( x ) = 1 1 + e − x    d y d x = y ( 1 − y ) y = sigmoid(x)=\frac{1}{1+e^{-x}} \\ \;\\ \frac{dy}{dx}= y(1-y) y=sigmoid(x)=1+e−x1dxdy=y(1−y)
GRUCell 的运算是循环迭代的, 每一次梯度不仅受到上游 forward 运算的影响, 还受到自身上一步运算的影响.
为了避免符号混乱, 将上游 forward 运算传递到 H 的梯度 ∇ e ( H l m r ) \nabla e_{(H_{lmr})} ∇e(Hlmr) 记为 ∇ e ( F l m r ) \nabla e_{(F_{lmr})} ∇e(Flmr), ∇ e ( H l m r ) \nabla e_{(H_{lmr})} ∇e(Hlmr) 用于迭代过程中的内部计算.
从最后一步开始算起 :
d e d g u ( l ) = d e d F ( l ) ⊙ ( − g c ( l ) + H ( l − 1 ) ) \frac{de}{dg_u^{(l)}}=\frac{de}{dF^{(l)}}\odot(-g_c^{(l)}+H^{(l-1)})\\ dgu(l)de=dF(l)de⊙(−gc(l)+H(l−1))
在这一步, 同样可以得到 d e / d H ( l − 1 ) de / dH^{(l-1)} de/dH(l−1), 这个结果不依赖于 d e / d g u ( l ) de / dg_u^{(l)} de/dgu(l), 是独立的. d e / d H ( l − 1 ) de / dH^{(l-1)} de/dH(l−1) 的计算过程比较长, 放到下文, 这里先拿来使用.
按顺序往下迭代 :
d e d g u ( l ) = d e d F ( l ) ⊙ ( − g c ( l ) + H ( l − 1 ) ) d e d g u ( l − 1 ) = ( d e d F ( l − 1 ) + d e d H ( l − 1 ) ) ⊙ ( − g c ( l − 1 ) + H ( l − 2 ) ) d e d g u ( l − 2 ) = ( d e d F ( l − 2 ) + d e d H ( l − 2 ) ) ⊙ ( − g c ( l − 2 ) + H ( l − 3 ) ) ⋮ d e d g u ( 1 ) = ( d e d F ( 1 ) + d e d H ( 1 ) ) ⊙ ( − g c ( 1 ) + H ( 0 ) ) \frac{de}{dg_u^{(l)}}=\frac{de}{dF^{(l)}}\odot(-g_c^{(l)}+H^{(l-1)})\\ \frac{de}{dg_u^{(l-1)}}=(\frac{de}{dF^{(l-1)}}+\frac{de}{dH^{(l-1)}})\odot(-g_c^{(l-1)}+H^{(l-2)})\\ \frac{de}{dg_u^{(l-2)}}=(\frac{de}{dF^{(l-2)}}+\frac{de}{dH^{(l-2)}})\odot(-g_c^{(l-2)}+H^{(l-3)})\\ \vdots\\ \frac{de}{dg_u^{(1)}}=(\frac{de}{dF^{(1)}}+\frac{de}{dH^{(1)}})\odot(-g_c^{(1)}+H^{(0)}) dgu(l)de=dF(l)de⊙(−gc(l)+H(l−1))dgu(l−1)de=(dF(l−1)de+dH(l−1)de)⊙(−gc(l−1)+H(l−2))dgu(l−2)de=(dF(l−2)de+dH(l−2)de)⊙(−gc(l−2)+H(l−3))⋮dgu(1)de=(dF(1)de+dH(1)de)⊙(−gc(1)+H(0))
d e d g c ( l ) = d e d F ( l ) ⊙ ( 1 − g u ( l ) ) d e d g c ( l − 1 ) = ( d e d F ( l − 1 ) + d e d H ( l − 1 ) ) ⊙ ( 1 − g u ( l − 1 ) ) d e d g c ( l − 2 ) = ( d e d F ( l − 2 ) + d e d H ( l − 2 ) ) ⊙ ( 1 − g u ( l − 2 ) ) ⋮ d e d g c ( 1 ) = ( d e d F ( 1 ) + d e d H ( 1 ) ) ⊙ ( 1 − g u ( 1 ) ) \frac{de}{dg_c^{(l)}}=\frac{de}{dF^{(l)}}\odot(1-g_u^{(l)})\\ \frac{de}{dg_c^{(l-1)}}=(\frac{de}{dF^{(l-1)}}+\frac{de}{dH^{(l-1)}})\odot(1-g_u^{(l-1)})\\ \frac{de}{dg_c^{(l-2)}}=(\frac{de}{dF^{(l-2)}}+\frac{de}{dH^{(l-2)}})\odot(1-g_u^{(l-2)})\\ \vdots\\ \frac{de}{dg_c^{(1)}}=(\frac{de}{dF^{(1)}}+\frac{de}{dH^{(1)}})\odot(1-g_u^{(1)}) dgc(l)de=dF(l)de⊙(1−gu(l))dgc(l−1)de=(dF(l−1)de+dH(l−1)de)⊙(1−gu(l−1))dgc(l−2)de=(dF(l−2)de+dH(l−2)de)⊙(1−gu(l−2))⋮dgc(1)de=(dF(1)de+dH(1)de)⊙(1−gu(1))
d e d A u = d e d g u ⊙ g u ⊙ ( 1 − g u )    d e d A c = d e d g c ⊙ ( 1 − g c 2 ) \frac{de}{dA_u}=\frac{de}{dg_u}\odot g_u\odot(1-g_u)\\ \;\\ \frac{de}{dA_c}=\frac{de}{dg_c}\odot (1-g_c^2)\\ dAude=dgude⊙gu⊙(1−gu)dAcde=dgcde⊙(1−gc2)
这里不涉及迭代, 分步计算即可.
d e d g r = d e d A c ⊙ ( H ( l − 1 ) V c T + b c )    d e d A r = d e d g r ⊙ g r ( 1 − g r ) \frac{de}{dg_r}=\frac{de}{dA_c}\odot (H^{(l-1)}V_c^T + b_c)\\ \;\\ \frac{de}{dA_r}=\frac{de}{dg_r}\odot g_r (1-g_r)\\ dgrde=dAcde⊙(H(l−1)VcT+bc)dArde=dgrde⊙gr(1−gr)
这里不涉及迭代, 分步计算即可.
这里涉及迭代, 按顺序计算 :
d e d H ( l − 1 ) = d e d F ( l ) ⊙ g u ( l ) + d e d A r ( l ) V r + d e d A u ( l ) V u + ( d e d A c ( l ) ⊙ g r ( l ) ) V c d e d H ( l − 2 ) = ( d e d F ( l − 1 ) + d e d H ( l − 1 ) ) ⊙ g u ( l − 1 ) + d e d A r ( l − 1 ) V r + d e d A u ( l − 1 ) V u + ( d e d A c ( l − 1 ) ⊙ g r ( l − 1 ) ) V c d e d H ( l − 3 ) = ( d e d F ( l − 2 ) + d e d H ( l − 2 ) ) ⊙ g u ( l − 2 ) + d e d A r ( l − 2 ) V r + d e d A u ( l − 2 ) V u + ( d e d A c ( l − 2 ) ⊙ g r ( l − 2 ) ) V c ⋮ d e d H ( 0 ) = ( d e d F ( 1 ) + d e d H ( 1 ) ) ⊙ g u ( 1 ) + d e d A r ( 1 ) V r + d e d A u ( 1 ) V u + ( d e d A c ( 1 ) ⊙ g r ( 1 ) ) V c \frac{de}{dH^{(l-1)}}=\frac{de}{dF^{(l)}}\odot g_u^{(l)}+\frac{de}{dA_r^{(l)}}V_r+\frac{de}{dA_u^{(l)}}V_u+(\frac{de}{dA_c^{(l)}}\odot g_r^{(l)}) V_c\\ \frac{de}{dH^{(l-2)}}=(\frac{de}{dF^{(l-1)}}+\frac{de}{dH^{(l-1)}})\odot g_u^{(l-1)}+\frac{de}{dA_r^{(l-1)}}V_r+\frac{de}{dA_u^{(l-1)}}V_u+(\frac{de}{dA_c^{(l-1)}}\odot g_r^{(l-1)}) V_c\\ \frac{de}{dH^{(l-3)}}=(\frac{de}{dF^{(l-2)}}+\frac{de}{dH^{(l-2)}})\odot g_u^{(l-2)}+\frac{de}{dA_r^{(l-2)}}V_r+\frac{de}{dA_u^{(l-2)}}V_u+(\frac{de}{dA_c^{(l-2)}}\odot g_r^{(l-2)}) V_c\\ \vdots\\ \frac{de}{dH^{(0)}}=(\frac{de}{dF^{(1)}}+\frac{de}{dH^{(1)}})\odot g_u^{(1)}+\frac{de}{dA_r^{(1)}}V_r+\frac{de}{dA_u^{(1)}}V_u+(\frac{de}{dA_c^{(1)}}\odot g_r^{(1)}) V_c\\ dH(l−1)de=dF(l)de⊙gu(l)+dAr(l)deVr+dAu(l)deVu+(dAc(l)de⊙gr(l))VcdH(l−2)de=(dF(l−1)de+dH(l−1)de)⊙gu(l−1)+dAr(l−1)deVr+dAu(l−1)deVu+(dAc(l−1)de⊙gr(l−1))VcdH(l−3)de=(dF(l−2)de+dH(l−2)de)⊙gu(l−2)+dAr(l−2)deVr+dAu(l−2)deVu+(dAc(l−2)de⊙gr(l−2))Vc⋮dH(0)de=(dF(1)de+dH(1)de)⊙gu(1)+dAr(1)deVr+dAu(1)deVu+(dAc(1)de⊙gr(1))Vc
d e d X = d e d A r W r + d e d A u W u + ( d e d A c ⊙ g r ) W c \frac{de}{dX}=\frac{de}{dA_r}W_r+\frac{de}{dA_u}W_u+(\frac{de}{dA_c}\odot g_r) W_c dXde=dArdeWr+dAudeWu+(dAcde⊙gr)Wc
这里不涉及迭代, 分步计算即可.
d e d W r = ( d e d A r ) T X    d e d W u = ( d e d A u ) T X    d e d W c = ( d e d A c ) T X \frac{de}{dW_r}=(\frac{de}{dA_r})^T X\\ \;\\ \frac{de}{dW_u}=(\frac{de}{dA_u})^T X\\ \;\\ \frac{de}{dW_c}=(\frac{de}{dA_c})^T X dWrde=(dArde)TXdWude=(dAude)TXdWcde=(dAcde)TX
这里不涉及迭代, 分步计算即可.
d e d W r = ( d e d A r ) T H ( l − 1 )    d e d W u = ( d e d A u ) T H ( l − 1 )    d e d W c = ( d e d A c ⊙ g r ) T H ( l − 1 ) \frac{de}{dW_r}=(\frac{de}{dA_r})^T H^{(l-1)}\\ \;\\ \frac{de}{dW_u}=(\frac{de}{dA_u})^T H^{(l-1)}\\ \;\\ \frac{de}{dW_c}=(\frac{de}{dA_c}\odot g_r)^T H^{(l-1)} dWrde=(dArde)TH(l−1)dWude=(dAude)TH(l−1)dWcde=(dAcde⊙gr)TH(l−1)
这里不涉及迭代, 分步计算即可.
d e d a r = s u m ( ∇ ( d e d A r ) T ,    a x i s = 0 )    d e d a u = s u m ( ∇ ( d e d A u ) T ,    a x i s = 0 )    d e d a c = s u m ( ∇ ( d e d A c ) T ,    a x i s = 0 ) \frac{de}{da_r}=sum(\nabla (\frac{de}{dA_r})^T,\; axis=0)\\ \;\\ \frac{de}{da_u}=sum(\nabla (\frac{de}{dA_u})^T,\; axis=0)\\ \;\\ \frac{de}{da_c}=sum(\nabla (\frac{de}{dA_c})^T,\; axis=0) darde=sum(∇(dArde)T,axis=0)daude=sum(∇(dAude)T,axis=0)dacde=sum(∇(dAcde)T,axis=0)
这里不涉及迭代, 分步计算即可. 同样的 :
d e d a = d e d b \frac{de}{da}=\frac{de}{db} dade=dbde