在实际机器学习工作中,最常用的就是实值函数 y 对向量 x或矩阵 X 求导,比如最简单的线性回归问题中由目标函数 dJ(w) d J ( w ) 求解最佳参数向量 w w 。
矩阵/向量求导问题中要明确是什么量对什么量求导,得到的是什么形式的量
本文以线性回归问题中由目标函数 dJ(w) d J ( w ) 求解最佳参数向量 w w 问题为例子,介绍个人总结的一点机器学习矩阵求导的的技巧和方法,其中包括:
1. 全微分与偏导数关系
2. 迹技巧
3. 常用的矩阵求导公式
由上面的两个公式,若我们可以把标量函数f的全微分形式写出来,那么,对于实值函数对向量求导的类型,只需把全微分中dX前面的项求转置便可得到 ∂f∂X ∂ f ∂ X 。
例子: 线性回归
d J(w) J ( w )
=d(Xw−Y)T(Xw−Y)+(Xw−Y)Td(Xw−Y) = d ( X w − Y ) T ( X w − Y ) + ( X w − Y ) T d ( X w − Y )
=2(Xw−Y)TXdw = 2 ( X w − Y ) T X d w `
=(2∗XTXw−2∗XTY)Tdw = ( 2 ∗ X T X w − 2 ∗ X T Y ) T d w
因此, ▽wJ(w)=2∗XTXw−2∗XTY ▽ w J ( w ) = 2 ∗ X T X w − 2 ∗ X T Y
性质1 tra=a,tr(aA)=a∗trA t r a = a , t r ( a A ) = a ∗ t r A ,a为标量
性质2 tr(A+B)=trA+trB t r ( A + B ) = t r A + t r B
性质3 trAB=trBA,trABC=trCAB=trBCA t r A B = t r B A , t r A B C = t r C A B = t r B C A
性质4 trA=trAT t r A = t r A T
性质5 ▽Atr(AB)=BT ▽ A t r ( A B ) = B T
性质6 ▽Atr(ABATC)=CAB+CTABT ▽ A t r ( A B A T C ) = C A B + C T A B T
实例计算:使用迹的技巧求解线性回归的最佳参数。
▽wJ(w)=▽wtrJ(w) ▽ w J ( w ) = ▽ w t r J ( w )
=▽wtr(Xw−Y)T(Xw−Y) = ▽ w t r ( X w − Y ) T ( X w − Y )
=▽wtr(wTXTXw−YTXw−wTXTY+YTY) = ▽ w t r ( w T X T X w − Y T X w − w T X T Y + Y T Y )
▽wJ(w)=▽wtrJ(w) ▽ w J ( w ) = ▽ w t r J ( w )
=▽wtr(Xw−Y)T(Xw−Y) = ▽ w t r ( X w − Y ) T ( X w − Y )
=▽wtr(wTXTXw−YTXw−wTXTY+YTY) = ▽ w t r ( w T X T X w − Y T X w − w T X T Y + Y T Y )
=▽wtr(wTXTXw)−▽wtr(YTXw)−▽wtr(wTXTY) = ▽ w t r ( w T X T X w ) − ▽ w t r ( Y T X w ) − ▽ w t r ( w T X T Y )
注:
tr(A+B)=trA+trB t r ( A + B ) = t r A + t r B
▽wJ(w)=▽wtrJ(w) ▽ w J ( w ) = ▽ w t r J ( w )
=▽wtr(Xw−Y)T(Xw−Y) = ▽ w t r ( X w − Y ) T ( X w − Y )
=▽wtr(wTXTXw−YTXw−wTXTY+YTY) = ▽ w t r ( w T X T X w − Y T X w − w T X T Y + Y T Y )
=▽wtr(wTXTXw)−▽wtr(YTXw)−▽wtr(wTXTY) = ▽ w t r ( w T X T X w ) − ▽ w t r ( Y T X w ) − ▽ w t r ( w T X T Y )
=▽wtr(wwTXTX)−▽wtr(YTXw)−▽wtr(wTXTY) = ▽ w t r ( w w T X T X ) − ▽ w t r ( Y T X w ) − ▽ w t r ( w T X T Y )
=▽wtr(wwTXTX)−2∗▽wtr(YTXw) = ▽ w t r ( w w T X T X ) − 2 ∗ ▽ w t r ( Y T X w )
▽wJ(w)=▽wtrJ(w) ▽ w J ( w ) = ▽ w t r J ( w )
=▽wtr(Xw−Y)T(Xw−Y) = ▽ w t r ( X w − Y ) T ( X w − Y )
=▽wtr(wTXTXw−YTXw−wTXTY+YTY) = ▽ w t r ( w T X T X w − Y T X w − w T X T Y + Y T Y )
=▽wtr(wTXTXw)−▽wtr(YTXw)−▽wtr(wTXTY) = ▽ w t r ( w T X T X w ) − ▽ w t r ( Y T X w ) − ▽ w t r ( w T X T Y )
=▽wtr(wwTXTX)−▽wtr(YTXw)−▽wtr(wTXTY) = ▽ w t r ( w w T X T X ) − ▽ w t r ( Y T X w ) − ▽ w t r ( w T X T Y )
=▽wtr(wwTXTX)−2∗▽wtr(YTXw) = ▽ w t r ( w w T X T X ) − 2 ∗ ▽ w t r ( Y T X w )
=▽wtr(wIwTXTX)−2∗▽wtr(YTXw) = ▽ w t r ( w I w T X T X ) − 2 ∗ ▽ w t r ( Y T X w )
=(XTXwI+XTXIw)−2∗▽wtr(YTXw) = ( X T X w I + X T X I w ) − 2 ∗ ▽ w t r ( Y T X w )
▽wJ(w)=▽wtrJ(w) ▽ w J ( w ) = ▽ w t r J ( w )
=▽wtr(Xw−Y)T(Xw−Y) = ▽ w t r ( X w − Y ) T ( X w − Y )
=▽wtr(wTXTXw−YTXw−wTXTY+YTY) = ▽ w t r ( w T X T X w − Y T X w − w T X T Y + Y T Y )
=▽wtr(wTXTXw)−▽wtr(YTXw)−▽wtr(wTXTY) = ▽ w t r ( w T X T X w ) − ▽ w t r ( Y T X w ) − ▽ w t r ( w T X T Y )
=▽wtr(wwTXTX)−▽wtr(YTXw)−▽wtr(wTXTY) = ▽ w t r ( w w T X T X ) − ▽ w t r ( Y T X w ) − ▽ w t r ( w T X T Y )
=▽wtr(wwTXTX)−2∗▽wtr(YTXw) = ▽ w t r ( w w T X T X ) − 2 ∗ ▽ w t r ( Y T X w )
=▽wtr(wIwTXTX)−2∗▽wtr(YTXw) = ▽ w t r ( w I w T X T X ) − 2 ∗ ▽ w t r ( Y T X w )
=(XTXwI+XTXIw)−2∗▽wtr(YTXw) = ( X T X w I + X T X I w ) − 2 ∗ ▽ w t r ( Y T X w )
=2∗XTXw−2∗XTY = 2 ∗ X T X w − 2 ∗ X T Y
矩阵/向量求导问题中要明确是什么量对什么量求导,得到的是什么形式的量
重要的矩阵求导公式:公式证明可以用微分分解加迹技巧证明。
证明第一条公式:
d(xTAx)=d(xT)Ax+xTd(Ax) d ( x T A x ) = d ( x T ) A x + x T d ( A x )
=(Ax)Tdx+xT(AT)Tdx = ( A x ) T d x + x T ( A T ) T d x
=(xTAT+xTA)dx = ( x T A T + x T A ) d x
则:
▽wJ(w) ▽ w J ( w )
=▽w(Xw−Y)T(Xw−Y) = ▽ w ( X w − Y ) T ( X w − Y )
=▽w(wTXTXw−YTXw−wTXTY+YTY) = ▽ w ( w T X T X w − Y T X w − w T X T Y + Y T Y )
=▽w(wTXTXw)−▽w(YTXw)−▽w(wTXTY) = ▽ w ( w T X T X w ) − ▽ w ( Y T X w ) − ▽ w ( w T X T Y )
=2∗XTXw−XTY−XTY = 2 ∗ X T X w − X T Y − X T Y
=2∗XTXw−2∗XTY = 2 ∗ X T X w − 2 ∗ X T Y
- 注:求导公式忘了可以用微分转换和迹技巧推导。
参考文章:矩阵求导公式
参考文章:矩阵求导术
参考视频:吴恩达机器学习课程
参考文章:重要矩阵求导的公式