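As a beginner, I find the RNN derivative formulas too complicated, so I came up with a way to break the derivation into smaller pieces.
The idea is to use a syntax tree.
For the original article, see RNN反向求导详解_格物致知-CSDN博客.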
$$
\begin{aligned}
o_t&=\varphi(Vs_t)=\varphi(V\phi(Ws_{t-1}+Ux_t))\\
L_t&=\text{loss}(o_t,y_t)
\end{aligned}
$$
Let $o_t^*=Vs_t$ and $s_t^*=Ux_t+Ws_{t-1}$.
Then $o_t=\varphi(o_t^*)$ and $s_t=\phi(s_t^*)$.
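To make the definitions concrete, here is a minimal sketch of one forward step in plain Python. The post does not say which activations or loss it uses, so this sketch assumes $\phi=\tanh$, $\varphi=$ sigmoid, and a squared-error loss; the helper names (`matvec`, `rnn_step`, `loss`) and the matrix sizes are made up for illustration.

```python
import math

def matvec(M, v):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def rnn_step(U, W, V, x_t, s_prev):
    """One forward step; returns (s_star, s_t, o_star, o_t)."""
    # s*_t = U x_t + W s_{t-1}
    s_star = [a + b for a, b in zip(matvec(U, x_t), matvec(W, s_prev))]
    s_t = [math.tanh(z) for z in s_star]   # assumption: phi = tanh
    # o*_t = V s_t
    o_star = matvec(V, s_t)
    o_t = [sigmoid(z) for z in o_star]     # assumption: varphi = sigmoid
    return s_star, s_t, o_star, o_t

def loss(o_t, y_t):
    """Assumption: squared-error loss."""
    return 0.5 * sum((o - y) ** 2 for o, y in zip(o_t, y_t))

# Illustrative sizes: 2 inputs, 4 hidden units, 3 outputs.
U = [[0.1, -0.2], [0.0, 0.3], [0.5, 0.1], [-0.4, 0.2]]
W = [[0.1 if i == j else 0.0 for j in range(4)] for i in range(4)]
V = [[0.1, 0.2, -0.3, 0.4], [0.5, -0.6, 0.7, 0.8], [-0.9, 1.0, 0.2, -0.1]]
x_t = [1.0, 2.0]
s_prev = [0.0, 0.0, 0.0, 0.0]
y_t = [1.0, 0.0, 0.0]

s_star, s_t, o_star, o_t = rnn_step(U, W, V, x_t, s_prev)
L_t = loss(o_t, y_t)
```

Keeping `s_star` and `o_star` as separate return values mirrors the starred intermediate variables above, which is what makes the step-by-step derivative possible.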
Now draw $L_t$ as a syntax tree and differentiate it one step at a time.
We use $*$ for element-wise multiplication and $\times$ for matrix multiplication.
$$
\cfrac{\partial L_t}{\partial o_t^*}=\cfrac{\partial L_t}{\partial o_t}*\cfrac{\partial o_t}{\partial o_t^*}=\cfrac{\partial L_t}{\partial o_t}*\varphi'(o_t^*)\tag{1}
$$
The result of Eq. 1 is a vector with the same dimension as $o_t^*$.
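Eq. 1 can be checked numerically. The sketch below assumes that $\varphi$ is the sigmoid (applied element-wise, which is what makes the element-wise product valid) and that the loss is squared error; neither choice comes from the original post. It compares the analytic result with a central finite difference.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def loss_from_o_star(o_star, y):
    """L as a function of o* (assumed: sigmoid output, squared-error loss)."""
    o = [sigmoid(z) for z in o_star]
    return 0.5 * sum((oi - yi) ** 2 for oi, yi in zip(o, y))

def delta_o_star(o_star, y):
    """Eq. 1: dL/do* = dL/do * varphi'(o*), multiplied element-wise."""
    o = [sigmoid(z) for z in o_star]
    dL_do = [oi - yi for oi, yi in zip(o, y)]   # derivative of squared error
    dphi = [oi * (1.0 - oi) for oi in o]        # derivative of sigmoid
    return [g * d for g, d in zip(dL_do, dphi)]

def numeric_grad(f, v, eps=1e-6):
    """Central finite-difference gradient of a scalar function f at v."""
    grad = []
    for i in range(len(v)):
        up = v[:i] + [v[i] + eps] + v[i + 1:]
        down = v[:i] + [v[i] - eps] + v[i + 1:]
        grad.append((f(up) - f(down)) / (2 * eps))
    return grad

o_star = [0.2, -1.0, 0.7]
y = [1.0, 0.0, 0.0]

analytic = delta_o_star(o_star, y)
numeric = numeric_grad(lambda z: loss_from_o_star(z, y), o_star)
max_err = max(abs(a - n) for a, n in zip(analytic, numeric))
```

Here `analytic` has one entry per component of `o_star`, matching the statement that the result has the same dimension as $o_t^*$.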
$$
\cfrac{\partial L_t}{\partial V_t}=\cfrac{\partial L_t}{\partial o_t^*}[?]\cfrac{\partial o_t^*}{\partial V}\tag{2}
$$
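Taken as a whole, Eq. 2 is the derivative of a scalar with respect to a matrix, which simply means differentiating the scalar with respect to every element of the matrix; the intermediate value $o_t^*$ is a vector.
The first factor, $\cfrac{\partial L_t}{\partial o_t^*}$, was already obtained in Eq. 1; the second factor is the derivative of a matrix-vector product.
Since we are differentiating with respect to $V$, the result must have the same shape as $V$.
Let's work through an example to see how the derivative comes out.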
$$
\boldsymbol{o^*}=\boldsymbol{V}\times\boldsymbol{s}=
\begin{bmatrix}V_{11}&V_{12}&V_{13}&V_{14}\\V_{21}&V_{22}&V_{23}&V_{24}\\V_{31}&V_{32}&V_{33}&V_{34}\end{bmatrix}
\times
\begin{bmatrix}s_1\\s_2\\s_3\\s_4\end{bmatrix}=
\begin{bmatrix}V_{11}s_1+V_{12}s_2+V_{13}s_3+V_{14}s_4\\V_{21}s_1+V_{22}s_2+V_{23}s_3+V_{24}s_4\\V_{31}s_1+V_{32}s_2+V_{33}s_3+V_{34}s_4\end{bmatrix}=
\begin{bmatrix}o^*_1\\o^*_2\\o^*_3\end{bmatrix}\tag{3}
$$
$$
\begin{aligned}
\cfrac{\partial L}{\partial V_{11}}=\cfrac{\partial L}{\partial o^*_1}\cfrac{\partial o^*_1}{\partial V_{11}}&=\cfrac{\partial L}{\partial o^*_1}s_1\\
\cfrac{\partial L}{\partial V_{12}}=\cfrac{\partial L}{\partial o^*_1}\cfrac{\partial o^*_1}{\partial V_{12}}&=\cfrac{\partial L}{\partial o^*_1}s_2\\
&\vdots\\
\cfrac{\partial L}{\partial V_{34}}=\cfrac{\partial L}{\partial o^*_3}\cfrac{\partial o^*_3}{\partial V_{34}}&=\cfrac{\partial L}{\partial o^*_3}s_4
\end{aligned}
$$
$$
\cfrac{\partial L}{\partial V}=\begin{bmatrix}\cfrac{\partial L}{\partial o^*_1}\\\cfrac{\partial L}{\partial o^*_2}\\\cfrac{\partial L}{\partial o^*_3}\end{bmatrix}\times\begin{bmatrix}s_1&s_2&s_3&s_4\end{bmatrix}
$$
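The same pattern can be written once in index notation. This is only a compact restatement of the element-by-element list above; here $V$ is assumed to be $m\times n$ (with $m=3$, $n=4$ in the example), and the symbols $m$, $n$, $i$, $j$, $k$ are introduced just for this summary. Each $V_{ij}$ appears only in $o^*_i=\sum_k V_{ik}s_k$, so

$$
\cfrac{\partial L}{\partial V_{ij}}=\cfrac{\partial L}{\partial o^*_i}\cfrac{\partial o^*_i}{\partial V_{ij}}=\cfrac{\partial L}{\partial o^*_i}\,s_j,\qquad i=1,\dots,m,\quad j=1,\dots,n
$$

which is exactly entry $(i,j)$ of a column vector times a row vector.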
So Eq. 2 should be written as
$$
\cfrac{\partial L_t}{\partial V_t}=\cfrac{\partial L_t}{\partial o_t^*}\times\cfrac{\partial o_t^*}{\partial V}=\cfrac{\partial L_t}{\partial o_t^*}\times s_t^T\tag{4}
$$
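As a sanity check on Eq. 4, the sketch below builds the outer product $\cfrac{\partial L}{\partial o^*}\times s^T$ and compares it with a finite-difference gradient taken entry by entry over $V$. To keep it short it assumes an identity output activation and a squared-error loss, so that $\cfrac{\partial L}{\partial o^*}=o^*-y$; these choices and all helper names are illustrative, not from the original post.

```python
def matvec(M, v):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def outer(a, b):
    """Column vector a times row vector b."""
    return [[ai * bj for bj in b] for ai in a]

def loss_from_V(V, s, y):
    """Assumed setup: o = o* = V s (identity activation), squared-error loss."""
    o = matvec(V, s)
    return 0.5 * sum((oi - yi) ** 2 for oi, yi in zip(o, y))

def numeric_grad_V(V, s, y, eps=1e-6):
    """Differentiate the scalar loss with respect to every entry of V."""
    grad = [[0.0] * len(V[0]) for _ in V]
    for i in range(len(V)):
        for j in range(len(V[0])):
            up = [row[:] for row in V]
            down = [row[:] for row in V]
            up[i][j] += eps
            down[i][j] -= eps
            grad[i][j] = (loss_from_V(up, s, y) - loss_from_V(down, s, y)) / (2 * eps)
    return grad

V = [[0.1, 0.2, -0.3, 0.4], [0.5, -0.6, 0.7, 0.8], [-0.9, 1.0, 0.2, -0.1]]
s = [0.5, -1.0, 0.25, 2.0]
y = [1.0, 0.0, -1.0]

delta = [oi - yi for oi, yi in zip(matvec(V, s), y)]  # dL/do* under the assumptions above
grad_V = outer(delta, s)                              # Eq. 4
numeric = numeric_grad_V(V, s, y)
max_err = max(abs(grad_V[i][j] - numeric[i][j]) for i in range(3) for j in range(4))
```

Note that `grad_V` has 3 rows and 4 columns, the same shape as `V`, as argued earlier.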
Next we find the derivative of $L_t$ with respect to $s_t$, again referring to Eq. 3.
Image source: https://wenku.baidu.com/view/0c28ff2249d7c1c708a1284ac850ad02de8007c1.html
$$
\begin{aligned}
\cfrac{\partial L}{\partial s_1}&=\begin{bmatrix}\cfrac{\partial L}{\partial o_1^*}\cfrac{\partial o_1^*}{\partial s_1} + \cfrac{\partial L}{\partial o_2^*}\cfrac{\partial o_2^*}{\partial s_1} + \cfrac{\partial L}{\partial o_3^*}\cfrac{\partial o_3^*}{\partial s_1}\end{bmatrix}\\
&\vdots\\
\cfrac{\partial L}{\partial s_4}&=\begin{bmatrix}\cfrac{\partial L}{\partial o_1^*}\cfrac{\partial o_1^*}{\partial s_4} + \cfrac{\partial L}{\partial o_2^*}\cfrac{\partial o_2^*}{\partial s_4} + \cfrac{\partial L}{\partial o_3^*}\cfrac{\partial o_3^*}{\partial s_4}\end{bmatrix}
\end{aligned}
$$
$$
\begin{aligned}
\cfrac{\partial L}{\partial s}&=\begin{bmatrix}\cfrac{\partial L}{\partial o_1^*}\cfrac{\partial o_1^*}{\partial s_1} + \cfrac{\partial L}{\partial o_2^*}\cfrac{\partial o_2^*}{\partial s_1} + \cfrac{\partial L}{\partial o_3^*}\cfrac{\partial o_3^*}{\partial s_1}\\\cfrac{\partial L}{\partial o_1^*}\cfrac{\partial o_1^*}{\partial s_2} + \cfrac{\partial L}{\partial o_2^*}\cfrac{\partial o_2^*}{\partial s_2} + \cfrac{\partial L}{\partial o_3^*}\cfrac{\partial o_3^*}{\partial s_2}\\\cfrac{\partial L}{\partial o_1^*}\cfrac{\partial o_1^*}{\partial s_3} + \cfrac{\partial L}{\partial o_2^*}\cfrac{\partial o_2^*}{\partial s_3} + \cfrac{\partial L}{\partial o_3^*}\cfrac{\partial o_3^*}{\partial s_3}\\\cfrac{\partial L}{\partial o_1^*}\cfrac{\partial o_1^*}{\partial s_4} + \cfrac{\partial L}{\partial o_2^*}\cfrac{\partial o_2^*}{\partial s_4} + \cfrac{\partial L}{\partial o_3^*}\cfrac{\partial o_3^*}{\partial s_4}\end{bmatrix}\\
&=\begin{bmatrix}\cfrac{\partial o_1^*}{\partial s_1}&\cfrac{\partial o_2^*}{\partial s_1}&\cfrac{\partial o_3^*}{\partial s_1}\\\cfrac{\partial o_1^*}{\partial s_2}&\cfrac{\partial o_2^*}{\partial s_2}&\cfrac{\partial o_3^*}{\partial s_2}\\\cfrac{\partial o_1^*}{\partial s_3}&\cfrac{\partial o_2^*}{\partial s_3}&\cfrac{\partial o_3^*}{\partial s_3}\\\cfrac{\partial o_1^*}{\partial s_4}&\cfrac{\partial o_2^*}{\partial s_4}&\cfrac{\partial o_3^*}{\partial s_4}\end{bmatrix}\times\begin{bmatrix}\cfrac{\partial L}{\partial o_1^*}\\\cfrac{\partial L}{\partial o_2^*}\\\cfrac{\partial L}{\partial o_3^*}\end{bmatrix}\\
&=?\times\cfrac{\partial L}{\partial o_t^*}
\end{aligned}\tag{5}
$$
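To resolve the last step of Eq. 5, we first have to settle how to differentiate one vector with respect to another.
Reference: https://zhuanlan.zhihu.com/p/36448789
That article contains the following sentence:
"However, for convenience in practice, we usually differentiate as if $y$ were a row vector, even when it is actually a column vector."
From this sentence we can conclude that, in general, we are differentiating a row vector with respect to a column vector.
Differentiating a row vector $X$ with respect to a column vector $Y$ produces a matrix whose width is the length of $X$ and whose height is the length of $Y$.
So the question-mark matrix in Eq. 5 should be the derivative of the row vector $o_t^*$ with respect to the column vector $s$.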
$$
\cfrac{\partial L}{\partial s_t}=\cfrac{\partial o_t^*}{\partial s_t}\times\cfrac{\partial L}{\partial o_t^*}\tag{6}
$$
The factor $\cfrac{\partial o_t^*}{\partial s_t}$ in Eq. 6 can be evaluated further.
$$
\begin{aligned}
\cfrac{\partial o_t^*}{\partial s_t}&=\begin{bmatrix}\cfrac{\partial o_1^*}{\partial s_1}&\cfrac{\partial o_2^*}{\partial s_1}&\cfrac{\partial o_3^*}{\partial s_1}\\\cfrac{\partial o_1^*}{\partial s_2}&\cfrac{\partial o_2^*}{\partial s_2}&\cfrac{\partial o_3^*}{\partial s_2}\\\cfrac{\partial o_1^*}{\partial s_3}&\cfrac{\partial o_2^*}{\partial s_3}&\cfrac{\partial o_3^*}{\partial s_3}\\\cfrac{\partial o_1^*}{\partial s_4}&\cfrac{\partial o_2^*}{\partial s_4}&\cfrac{\partial o_3^*}{\partial s_4}\end{bmatrix}\\
&=\begin{bmatrix}V_{11}&V_{21}&V_{31}\\V_{12}&V_{22}&V_{32}\\V_{13}&V_{23}&V_{33}\\V_{14}&V_{24}&V_{34}\end{bmatrix}\\
&=V^T
\end{aligned}
$$
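The layout convention is easy to get backwards, so here is a small numerical check. It builds the Jacobian of $o^*=Vs$ with one row per component of $s$ and one column per component of $o^*$ (the "row vector by column vector" layout described above) and compares it with $V^T$. The function name `jacobian_rows_by_input` and the sample values are made up for this sketch.

```python
def matvec(M, v):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def jacobian_rows_by_input(f, s, eps=1e-6):
    """Return J with J[j][i] = d f_i / d s_j (height len(s), width len(f(s)))."""
    rows = []
    for j in range(len(s)):
        up = s[:j] + [s[j] + eps] + s[j + 1:]
        down = s[:j] + [s[j] - eps] + s[j + 1:]
        f_up, f_down = f(up), f(down)
        rows.append([(a - b) / (2 * eps) for a, b in zip(f_up, f_down)])
    return rows

V = [[0.1, 0.2, -0.3, 0.4], [0.5, -0.6, 0.7, 0.8], [-0.9, 1.0, 0.2, -0.1]]
s = [0.5, -1.0, 0.25, 2.0]

J = jacobian_rows_by_input(lambda v: matvec(V, v), s)
V_T = [list(col) for col in zip(*V)]   # transpose of V
max_err = max(abs(J[j][i] - V_T[j][i]) for j in range(4) for i in range(3))
```

With a 3x4 matrix `V`, `J` comes out 4x3: its height is the length of `s` and its width is the length of `o*`, and it agrees with `V_T` up to rounding error.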
Substituting this result into Eq. 6 gives
$$
\cfrac{\partial L}{\partial s_t}=\cfrac{\partial o_t^*}{\partial s_t}\times\cfrac{\partial L}{\partial o_t^*}=V^T\times\cfrac{\partial L}{\partial o_t^*}\tag{7}
$$
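Eq. 7 can be verified the same way as the earlier results. This sketch again assumes an identity output activation and a squared-error loss (so $\cfrac{\partial L}{\partial o^*}=o^*-y$), which are illustrative choices rather than anything stated in the original post, and compares $V^T\times\cfrac{\partial L}{\partial o^*}$ with a finite-difference gradient over $s$.

```python
def matvec(M, v):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def loss_from_s(V, s, y):
    """Assumed setup: o = o* = V s (identity activation), squared-error loss."""
    o = matvec(V, s)
    return 0.5 * sum((oi - yi) ** 2 for oi, yi in zip(o, y))

def numeric_grad(f, v, eps=1e-6):
    """Central finite-difference gradient of a scalar function f at v."""
    grad = []
    for i in range(len(v)):
        up = v[:i] + [v[i] + eps] + v[i + 1:]
        down = v[:i] + [v[i] - eps] + v[i + 1:]
        grad.append((f(up) - f(down)) / (2 * eps))
    return grad

V = [[0.1, 0.2, -0.3, 0.4], [0.5, -0.6, 0.7, 0.8], [-0.9, 1.0, 0.2, -0.1]]
s = [0.5, -1.0, 0.25, 2.0]
y = [1.0, 0.0, -1.0]

delta = [oi - yi for oi, yi in zip(matvec(V, s), y)]  # dL/do* under the assumptions above
V_T = [list(col) for col in zip(*V)]
grad_s = matvec(V_T, delta)                           # Eq. 7
numeric = numeric_grad(lambda v: loss_from_s(V, v, y), s)
max_err = max(abs(a - n) for a, n in zip(grad_s, numeric))
```

The product has 4 entries, one per component of `s`, which is the shape check that rules out writing `V` instead of `V_T` here.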
At this point, all the techniques involved have been covered. After filling the derivative results into the syntax tree