Many machine learning algorithms require solving for an extremum; least squares, for example, minimizes the distance from the samples to the fitted curve. Extrema are usually found by differentiation, and since machine learning typically handles data as matrices, matrix derivatives come up constantly. Matrix calculus is something like the combination of linear algebra and calculus, but once the two are combined the rules change somewhat.
For the basic rules of matrix differentiation, see the video series 【手推机器学习】矩阵求导–合集. The single most important thing to track is the shape of the matrix or vector being differentiated and the shape of the result.
Following the convention in that series, suppose $x$ is a column vector; then $x^T x$ is a scalar, and $\frac{d x^T x}{dx} = 2x^T$. The result is always a row vector, which is why it is the transpose of $x$ rather than $x$ itself.
A second formula is $\frac{d x^T a}{dx} = a^T$ (where $a$ may be a matrix or a column vector). If $a$ is a column vector, the result is a row vector, so the result necessarily carries a transpose.
Take care to distinguish $\frac{d x^T a}{dx} = a^T$ from $\frac{d x^T x}{dx} = 2x^T$. Neither derivation is difficult; both are quite elementary.
Let $x = [x_1, x_2, \ldots, x_n]^T$ and $a = [a_1, a_2, \ldots, a_n]^T$. Then $x^T a = x_1 a_1 + x_2 a_2 + \ldots + x_n a_n$, so:
$$\frac{d(x^T a)}{dx} = \left[\frac{d(x^T a)}{dx_1}, \frac{d(x^T a)}{dx_2}, \ldots, \frac{d(x^T a)}{dx_n}\right]$$
Since each term contributes $\frac{d(x^T a)}{dx_i} = a_i$, the expression above becomes $\frac{d(x^T a)}{dx} = [a_1, a_2, \ldots, a_n] = a^T$. The derivation of $\frac{d x^T x}{dx} = 2x^T$ is entirely analogous (each $x_i^2$ term differentiates to $2x_i$).
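To make the shapes concrete, here is a quick numerical sanity check of both identities using central finite differences. The dimension, seed, and step size are arbitrary choices for the demo; note that NumPy 1-D arrays carry no row/column orientation, so the row-vector results are compared simply as arrays.

```python
import numpy as np

# Central-difference check of d(x^T a)/dx = a^T and d(x^T x)/dx = 2 x^T.
rng = np.random.default_rng(0)
n = 5
x = rng.standard_normal(n)
a = rng.standard_normal(n)

def num_grad(f, x, eps=1e-6):
    """Central-difference gradient of a scalar-valued f at the point x."""
    g = np.zeros_like(x)
    for i in range(len(x)):
        d = np.zeros_like(x)
        d[i] = eps
        g[i] = (f(x + d) - f(x - d)) / (2 * eps)
    return g

print(np.allclose(num_grad(lambda v: v @ a, x), a))      # matches a^T
print(np.allclose(num_grad(lambda v: v @ v, x), 2 * x))  # matches 2 x^T
```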
Matrix differentiation also has a chain rule. Below is an example; note how the shapes of the factors fit together:
$$\frac{df}{d(s,t)} = \frac{\partial f}{\partial x}\frac{\partial x}{\partial(s,t)} = \underbrace{\begin{bmatrix} \frac{\partial f}{\partial x_1} & \frac{\partial f}{\partial x_2} \end{bmatrix}}_{=\frac{\partial f}{\partial x}} \underbrace{\begin{bmatrix} \frac{\partial x_1}{\partial s} & \frac{\partial x_1}{\partial t} \\ \frac{\partial x_2}{\partial s} & \frac{\partial x_2}{\partial t} \end{bmatrix}}_{=\frac{\partial x}{\partial(s,t)}}, \qquad \text{where } x = \begin{bmatrix} x_1 = x_1(s,t) \\ x_2 = x_2(s,t) \end{bmatrix}$$
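The block structure can again be checked numerically. Below is a minimal sketch on a hypothetical pair of coordinate functions, $x_1 = st$, $x_2 = s + t$, with $f(x) = x_1^2 + x_2$; all of these are chosen only for illustration.

```python
import numpy as np

# Verifying the vector chain rule df/d(s,t) = (df/dx)(dx/d(s,t)) on a
# hypothetical example: x1 = s*t, x2 = s + t, f(x) = x1^2 + x2.
s, t, eps = 1.3, -0.7, 1e-6

def f_of_st(s, t):
    x1, x2 = s * t, s + t
    return x1**2 + x2

x1 = s * t
df_dx = np.array([[2 * x1, 1.0]])              # 1x2 row vector df/dx
dx_dst = np.array([[t, s],                     # 2x2 Jacobian dx/d(s,t)
                   [1.0, 1.0]])
analytic = df_dx @ dx_dst                      # 1x2 row: [df/ds, df/dt]

numeric = np.array([[(f_of_st(s + eps, t) - f_of_st(s - eps, t)) / (2 * eps),
                     (f_of_st(s, t + eps) - f_of_st(s, t - eps)) / (2 * eps)]])
print(np.allclose(analytic, numeric))          # True
```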
Much of the time, to keep results compact, the derivative itself is expressed in matrix form, which leads to a set of commonly used derivative formulas.
The most elementary way to derive such a formula is to substitute concrete matrices and grind through the computation step by step in the manner shown in the video. The process is tedious, but the result can be very compact, as the formulas below illustrate.
To build a better understanding of these formulas, the derivations of a few common ones follow. Note that these results are written with the derivative stacked column-wise (denominator layout), i.e. the transpose of the row-vector convention used in the examples above.
The first formula is $\frac{\partial x^T A}{\partial x} = \frac{\partial A^T x}{\partial x} = A$. Writing $A_i$ for the $i$-th row of $A$, so that $x^T A = \sum_{i=1}^m A_i x_i$, the derivation goes as follows:
$$\frac{\partial x^T A}{\partial x} = \frac{\partial A^T x}{\partial x} = \begin{bmatrix} \frac{\partial \sum_{i=1}^{m} A_i x_i}{\partial x_1} \\ \frac{\partial \sum_{i=1}^{m} A_i x_i}{\partial x_2} \\ \vdots \\ \frac{\partial \sum_{i=1}^{m} A_i x_i}{\partial x_m} \end{bmatrix} = \begin{bmatrix} A_1 \\ A_2 \\ \vdots \\ A_m \end{bmatrix} = A$$
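A small numerical check of this formula: entry $(i, j)$ of $\partial(x^T A)/\partial x$ should equal $A_{ij}$. The shapes and random values below are arbitrary test inputs.

```python
import numpy as np

# Entry (i, j) of d(x^T A)/dx should be d(x^T A)_j / d(x_i) = A_ij.
rng = np.random.default_rng(1)
m, n, eps = 4, 3, 1e-6
A = rng.standard_normal((m, n))
x = rng.standard_normal(m)

J = np.zeros((m, n))
for i in range(m):
    d = np.zeros(m)
    d[i] = eps
    J[i] = ((x + d) @ A - (x - d) @ A) / (2 * eps)  # row i of the derivative
print(np.allclose(J, A))  # True
```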
The second formula is $\frac{\partial x^T A x}{\partial x} = (A^T + A)x$, derived as follows:
$$\begin{aligned}
\frac{\partial x^T A x}{\partial x}
&= \begin{bmatrix}
\frac{\partial \sum_{i=1}^{m} \sum_{j=1}^{m} A_{ij} x_i x_j}{\partial x_1} \\
\frac{\partial \sum_{i=1}^{m} \sum_{j=1}^{m} A_{ij} x_i x_j}{\partial x_2} \\
\vdots \\
\frac{\partial \sum_{i=1}^{m} \sum_{j=1}^{m} A_{ij} x_i x_j}{\partial x_m}
\end{bmatrix}
= \begin{bmatrix}
\sum_{i=1}^{m} A_{i1} x_i + \sum_{j=1}^{m} A_{1j} x_j \\
\sum_{i=1}^{m} A_{i2} x_i + \sum_{j=1}^{m} A_{2j} x_j \\
\vdots \\
\sum_{i=1}^{m} A_{im} x_i + \sum_{j=1}^{m} A_{mj} x_j
\end{bmatrix} \\
&= \begin{bmatrix}
\sum_{i=1}^{m} A_{i1} x_i \\
\sum_{i=1}^{m} A_{i2} x_i \\
\vdots \\
\sum_{i=1}^{m} A_{im} x_i
\end{bmatrix}
+ \begin{bmatrix}
\sum_{j=1}^{m} A_{1j} x_j \\
\sum_{j=1}^{m} A_{2j} x_j \\
\vdots \\
\sum_{j=1}^{m} A_{mj} x_j
\end{bmatrix} \\
&= \begin{bmatrix}
A_{11} & A_{21} & \cdots & A_{m1} \\
A_{12} & A_{22} & \cdots & A_{m2} \\
\vdots & \vdots & \ddots & \vdots \\
A_{1m} & A_{2m} & \cdots & A_{mm}
\end{bmatrix}
\begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_m \end{bmatrix}
+ \begin{bmatrix}
A_{11} & A_{12} & \cdots & A_{1m} \\
A_{21} & A_{22} & \cdots & A_{2m} \\
\vdots & \vdots & \ddots & \vdots \\
A_{m1} & A_{m2} & \cdots & A_{mm}
\end{bmatrix}
\begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_m \end{bmatrix} \\
&= A^T x + A x = \left(A^T + A\right)x
\end{aligned}$$
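The same kind of finite-difference check confirms the $(A^T + A)x$ result, and in particular that it holds for a general, non-symmetric $A$:

```python
import numpy as np

# d(x^T A x)/dx should equal (A^T + A) x even for non-symmetric A.
rng = np.random.default_rng(2)
m, eps = 5, 1e-6
A = rng.standard_normal((m, m))
x = rng.standard_normal(m)

g = np.zeros(m)
for i in range(m):
    d = np.zeros(m)
    d[i] = eps
    g[i] = ((x + d) @ A @ (x + d) - (x - d) @ A @ (x - d)) / (2 * eps)
print(np.allclose(g, (A.T + A) @ x))  # True
```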
On top of these common formulas, the distributive law of matrix multiplication can shorten a derivative computation. A good example is differentiating the least-squares loss with respect to the parameters; two ways of doing it are shown here:
Method one:
$$\frac{dL(w)}{dw} = \frac{d\left[(y - Xw)^T(y - Xw)\right]}{dw} = \frac{d\left(y^T y - w^T X^T y - y^T X w + w^T X^T X w\right)}{dw}$$
Here $w^T X^T y$ and $y^T X w$ are both scalars and are each other's transpose, hence equal; and since $y^T y$ does not depend on $w$, its derivative vanishes, leaving:
$$\text{the above} = -\frac{d\left(2y^T X w\right)}{dw} + \frac{d\left(w^T X^T X w\right)}{dw}$$
Combining the two formulas above, $\frac{\partial x^T a}{\partial x} = a$ (rewriting $2y^T X w$ as $w^T (2X^T y)$) and $\frac{\partial x^T A x}{\partial x} = (A^T + A)x = 2Ax$ when $A$ is symmetric (as $X^T X$ is), we get:
$$\text{the above} = -2X^T y + 2X^T X w$$
Setting this to zero yields $w = (X^T X)^{-1} X^T y$.
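As a sanity check, the closed-form solution can be compared against NumPy's least-squares solver on random data; the shapes and values below are arbitrary demo inputs.

```python
import numpy as np

# Comparing the closed form w = (X^T X)^{-1} X^T y with NumPy's solver.
rng = np.random.default_rng(3)
X = rng.standard_normal((50, 4))   # arbitrary demo shapes
y = rng.standard_normal(50)

w_normal = np.linalg.inv(X.T @ X) @ X.T @ y
w_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)
print(np.allclose(w_normal, w_lstsq))  # True

# Numerically, np.linalg.solve(X.T @ X, X.T @ y) or lstsq is preferred
# over forming the explicit inverse.
```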
Method two:
Let $e(w) = y - Xw$, so that $L(e) = e^T e$. By the chain rule:
$$\frac{dL}{dw} = \frac{dL}{de}\frac{de}{dw} = 2e^T(-X) = -2(y - Xw)^T X = -2y^T X + 2w^T X^T X$$
Setting this to zero likewise gives $w^T = y^T X (X^T X)^{-1}$. Since $X^T X$ is symmetric, $[(X^T X)^{-1}]^T = (X^T X)^{-1}$, and transposing recovers $w = (X^T X)^{-1} X^T y$.
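A final check for method two: the row-vector gradient $-2y^T X + 2w^T X^T X$ should vanish (up to rounding) at the recovered $w$. The data below is again random demo input.

```python
import numpy as np

# The row-vector gradient -2 y^T X + 2 w^T X^T X should vanish at the
# least-squares solution w = (X^T X)^{-1} X^T y (up to rounding).
rng = np.random.default_rng(4)
X = rng.standard_normal((50, 4))
y = rng.standard_normal(50)

w = np.linalg.solve(X.T @ X, X.T @ y)
grad = -2 * y @ X + 2 * w @ (X.T @ X)
print(np.allclose(grad, 0.0))  # True
```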