机器学习中的矩阵向量求导(三) 矩阵向量求导之微分法

    在机器学习中的矩阵向量求导(二) 矩阵向量求导之定义法中,我们讨论了定义法求解矩阵向量求导的方法,但是这个方法对于比较复杂的求导式子,中间运算会很复杂,同时排列求导出的结果也很麻烦。因此我们需要其他的一些求导方法。本文我们讨论使用微分法来求解标量对向量的求导,以及标量对矩阵的求导。

    本文的标量对向量的求导,以及标量对矩阵的求导使用分母布局。如果遇到其他资料求导结果不同,请先确认布局是否一样。

1. 矩阵微分

    在高数里面我们学习过标量的导数和微分,他们之间有这样的关系:$df =f'(x)dx$。如果是多变量的情况,则微分可以写成:$$df=\sum\limits_{i=1}^n\frac{\partial f}{\partial x_i}dx_i = (\frac{\partial f}{\partial \mathbf{x}})^Td\mathbf{x}$$

    从上次我们可以发现标量对向量的求导和它的向量微分有一个转置的关系。

    现在我们再推广到矩阵。对于矩阵微分,我们的定义为:$$df=\sum\limits_{i=1}^m\sum\limits_{j=1}^n\frac{\partial f}{\partial X_{ij}}dX_{ij} = tr((\frac{\partial f}{\partial \mathbf{X}})^Td\mathbf{X})$$

    其中第二步使用了矩阵迹的性质,即迹函数等于主对角线的和。即$$tr(A^TB) = \sum\limits_{i,j}A_{ij}B_{ij}$$

    从上面矩阵微分的式子,我们可以看到矩阵微分和它的导数也有一个转置的关系,不过在外面套了一个迹函数而已。由于标量的迹函数就是它本身,那么矩阵微分和向量微分可以统一表示,即:$$df= tr((\frac{\partial f}{\partial \mathbf{X}})^Td\mathbf{X})\;\; \;df= tr((\frac{\partial f}{\partial \mathbf{x}})^Td\mathbf{x})$$

2. 矩阵微分的性质

    我们在讨论如何使用矩阵微分来求导前,先看看矩阵微分的性质:

    1)微分加减法:$d(X+Y) =dX+dY, d(X-Y) =dX-dY$

    2)  微分乘法:$d(XY) =(dX)Y + X(dY)$

    3)  微分转置:$d(X^T) =(dX)^T$

    4)  微分的迹:$dtr(X) =tr(dX)$

    5)  微分哈达马乘积: $d(X \odot Y) = X\odot dY + dX \odot Y$

    6) 逐元素求导:$d \sigma(X) =\sigma'(X) \odot dX$

    7) 逆矩阵微分:$d X^{-1}= -X^{-1}dXX^{-1}$

    8) 行列式微分:$d |X|= |X|tr(X^{-1}dX)$ 

    有了这些性质,我们再来看看如何由矩阵微分来求导数。

3. 使用微分法求解矩阵向量求导

    由于第一节我们已经得到了矩阵微分和导数关系,现在我们就来使用微分法求解矩阵向量求导。

    若标量函数$f$是矩阵$X$经加减乘法、逆、行列式、逐元素函数等运算构成,则使用相应的运算法则对$f$求微分,再使用迹函数技巧给$df$套上迹并将其它项交换至$dX$左侧,那么对于迹函数里面在$dX$左边的部分,我们只需要加一个转置就可以得到导数了。

    这里需要用到的迹函数的技巧主要有这么几个:

    1)  标量的迹等于自己:$tr(x) =x$

    2)  转置不变:$tr(A^T) =tr(A)$

    3)  交换率:$tr(AB) =tr(BA)$,需要满足$A,B^T$同维度。

    4)  加减法:$tr(X+Y) =tr(X)+tr(Y), tr(X-Y) =tr(X)-tr(Y)$

    5) 矩阵乘法和迹交换:$tr((A\odot B)^TC)= tr(A^T(B \odot C))$,需要满足$A,B,C$同维度。

    我们先看第一个例子,我们使用上一篇定义法中的一个求导问题:$$y=\mathbf{a}^T\mathbf{X}\mathbf{b}, \frac{\partial y}{\partial \mathbf{X}}$$

    首先,我们使用微分乘法的性质对$f$求微分,得到:$$dy = d\mathbf{a}^T\mathbf{X}\mathbf{b} + \mathbf{a}^Td\mathbf{X}\mathbf{b} + \mathbf{a}^T\mathbf{X}d\mathbf{b} = \mathbf{a}^Td\mathbf{X}\mathbf{b}$$

    第二步,就是两边套上迹函数,即:$$dy =tr(dy) = tr(\mathbf{a}^Td\mathbf{X}\mathbf{b}) = tr(\mathbf{b}\mathbf{a}^Td\mathbf{X})$$

    其中第一到第二步使用了上面迹函数性质1,第三步到第四步用到了上面迹函数的性质3.

    根据我们矩阵导数和微分的定义,迹函数里面在$dX$左边的部分$\mathbf{b}\mathbf{a}^T$,加上一个转置即为我们要求的导数,即:$$\frac{\partial f}{\partial \mathbf{X}} = (\mathbf{b}\mathbf{a}^T)^T =ab^T$$

    以上就是微分法的基本流程,先求微分再做迹函数变换,最后得到求导结果。比起定义法,我们现在不需要去对矩阵中的单个标量进行求导了。

    再来看看$$y=\mathbf{a}^Texp(\mathbf{X}\mathbf{b}), \frac{\partial y}{\partial \mathbf{X}}$$

    $$dy =tr(dy) = tr(\mathbf{a}^Tdexp(\mathbf{X}\mathbf{b})) = tr(\mathbf{a}^T (exp(\mathbf{X}\mathbf{b}) \odot d(\mathbf{X}\mathbf{b}))) = tr((\mathbf{a}  \odot exp(\mathbf{X}\mathbf{b}) )^T d\mathbf{X}\mathbf{b}) =  tr(\mathbf{b}(\mathbf{a}  \odot exp(\mathbf{X}\mathbf{b}) )^T d\mathbf{X})  $$

    其中第三步到第4步使用了上面迹函数的性质5. 这样我们的求导结果为:$$\frac{\partial y}{\partial \mathbf{X}} =(\mathbf{a}  \odot exp(\mathbf{X}\mathbf{b}) )b^T $$

    以上就是微分法的基本思路。

4. 迹函数对向量矩阵求导

    由于微分法使用了迹函数的技巧,那么迹函数对对向量矩阵求导这一大类问题,使用微分法是最简单直接的。下面给出一些常见的迹函数的求导过程,也顺便给大家熟练掌握微分法的技巧。

    首先是$\frac{\partial tr(AB)}{\partial A} = B^T, \frac{\partial tr(AB)}{\partial B} =A^T$,这个直接根据矩阵微分的定义即可得到。

    再来看看$\frac{\partial tr(W^TAW)}{\partial W}$:$$d(tr(W^TAW)) = tr(dW^TAW +W^TAdW) = tr(dW^TAW)+tr(W^TAdW) = tr((dW)^TAW) + tr(W^TAdW) = tr(W^TA^TdW) +  tr(W^TAdW) = tr(W^T(A+A^T)dW) $$

    因此可以得到:$$\frac{\partial tr(W^TAW)}{\partial W} = (A+A^T)W$$

    最后来个更加复杂的迹函数求导:$\frac{\partial tr(B^TX^TCXB)}{\partial X} $: $$d(tr(B^TX^TCXB)) = tr(B^TdX^TCXB) + tr(B^TX^TCdXB) = tr((dX)^TCXBB^T) + tr(BB^TX^TCdX) = tr(BB^TX^TC^TdX) + tr(BB^TX^TCdX) = tr((BB^TX^TC^T + BB^TX^TC)dX)$$

    因此可以得到:$$\frac{\partial tr(B^TX^TCXB)}{\partial X}= (C+C^T)XBB^T$$

5. 微分法求导小结

    使用矩阵微分,可以在不对向量或矩阵中的某一元素单独求导再拼接,因此会比较方便,当然熟练使用的前提是对上面矩阵微分的性质,以及迹函数的性质熟练运用。

    还有一些场景,求导的自变量和因变量直接有复杂的多层链式求导的关系,此时微分法使用起来也有些麻烦。如果我们可以利用一些常用的简单求导结果,再使用链式求导法则,则会非常的方便。因此下一篇我们讨论向量矩阵求导的链式法则。

 

(欢迎转载,转载请注明出处。欢迎沟通交流: [email protected]) 

你可能感兴趣的:(机器学习中的矩阵向量求导(三) 矩阵向量求导之微分法)