对于深度学习的基础“梯度下降”和“自动微分”的数学原理网上讲解的博客有很多了,但是目前没看到有讲关于矩阵形式的链式法则的内容,所以写了这篇笔记,供自己学习和复习。
我印象中本科生学习的传统微积分中,当时学的是只有标量多元函数才能求梯度,本科阶段也只介绍了标量链式法则。为了尽可能简洁明了地进行推导,首先复习几个简单的概念:
我本科数学阶段老师上课所讲的多元分析学,研究的函数基本上都在标量多元函数范畴内。
根据国际惯例,本文把标量记作小写字母 x x x , 向量记作粗体小写字母 x \textbf{x} x或者带有箭头上标的小写字母 x ⃗ \vec{x} x ,矩阵记作大写字母 X X X.
下面用一个简单的二元标量函数 y = x 1 2 + x 2 2 y = x_{1}^{2} + x_{2}^{2} y=x12+x22 为例简要介绍以上基本概念:
方向导数
求上述二元标量函数 y = f ( x ⃗ ) y=f(\vec{x}) y=f(x)在点 ( 1 , 1 ) (1,1) (1,1)沿方向 ( − 1 , − 1 ) (-1,-1) (−1,−1)的方向导数:
根据定义:
方向导数 = lim t → 0 f ( x + t ) − f ( x ) t 其中 t 表示沿着指定方向的向量, t 表示方向向量 t 的模长 \textbf{方向导数} = \lim_{t \to 0} \frac{f(\textbf{x}+\textbf{t}) - f(\textbf{x} )}{t} \\ 其中\textbf{t}表示沿着指定方向的向量,t表示方向向量\textbf{t}的模长 方向导数=t→0limtf(x+t)−f(x)其中t表示沿着指定方向的向量,t表示方向向量t的模长
代入数据:
函数 y = x 1 2 + x 2 2 在点 ( 1 , 1 ) 处 : f ( x ) = 2 将方向向量单位化: ( − 1 2 , − 1 2 ) , 则 lim t → 0 t ⃗ = ( − 1 2 t , − 1 2 t ) : 在点 ( 1 − 1 2 t , 1 − 1 2 t ) 处 : f ( x+t ) = t 2 − 2 2 t + 2 lim t → 0 f ( x + t ) − f ( x ) t = t 2 − 2 2 t t = − 2 2 函数y = x_{1}^{2} + x_{2}^{2} \quad 在点(1,1)处:f(\textbf{x})=2 \\ 将方向向量单位化:(-\frac{1}{\sqrt{2}}, -\frac{1}{\sqrt{2}}),则\lim_{t\to0}\vec{t} = (-\frac{1}{\sqrt{2}}t, -\frac{1}{\sqrt{2}}t) : \\ 在点(1-\frac{1}{\sqrt{2}}t,1-\frac{1}{\sqrt{2}}t)处:f(\textbf{x+t}) = t^{2} - 2\sqrt{2}t + 2 \\ \lim_{t \to 0} \frac{f(\textbf{x}+\textbf{t}) - f(\textbf{x} )}{t} = \frac{t^{2}-2\sqrt{2}t}{t} = -2\sqrt{2} 函数y=x12+x22在点(1,1)处:f(x)=2将方向向量单位化:(−21,−21),则t→0limt=(−21t,−21t):在点(1−21t,1−21t)处:f(x+t)=t2−22t+2t→0limtf(x+t)−f(x)=tt2−22t=−22
偏导数(偏导函数)
求上述函数对于 x 1 x_{1} x1的偏导(函)数:
根据定义:
f ′ ( x 1 ) = ∂ y ∂ x 1 = 2 x 1 {f}'(x_{1}) = \frac{\partial y}{\partial x_{1}} = 2x_{1} f′(x1)=∂x1∂y=2x1
梯度
求原函数在点 ( 1 , 1 ) (1,1) (1,1) 处的梯度:
根据定义:
▽ f = ( f ′ ( x 1 ) , f ′ ( x 2 ) , ⋯ , f ′ ( x n ) ) = ( ∂ y ∂ x 1 , ∂ y ∂ x 2 , ⋯ , ∂ y ∂ x n ) \begin{align*} \bigtriangledown f &= (f'(x_{1}), f'(x_{2}), \cdots, f'(x_{n})) \\ & = (\frac{\partial y}{\partial x_{1}}, \frac{\partial y}{\partial x_{2}}, \cdots, \frac{\partial y}{\partial x_{n}} ) \end{align*} ▽f=(f′(x1),f′(x2),⋯,f′(xn))=(∂x1∂y,∂x2∂y,⋯,∂xn∂y)
代入数据得:
▽ f = ( f ′ ( x 1 ) , f ′ ( x 2 ) ) = ( ∂ y ∂ x 1 , ∂ y ∂ x 2 ) = ( 2 x 1 , 2 x 2 ) = ( 2 , 2 ) \begin{align*} \bigtriangledown f &= (f'(x_{1}), f'(x_{2})) \\ & = (\frac{\partial y}{\partial x_{1}}, \frac{\partial y}{\partial x_{2}}) \\ & = (2x_{1}, 2x_{2}) \\ & = (2,2) \end{align*} ▽f=(f′(x1),f′(x2))=(∂x1∂y,∂x2∂y)=(2x1,2x2)=(2,2)
这个是本科工科数学一元分析学的重点,此处不用多做证明,只简单地记录一下:
若 y = f ( u ) , u = g ( x ) ,其中 y , u , x 均为标量,则 : ∂ y ∂ x = ∂ y ∂ u ⋅ ∂ u ∂ x 若y=f(u), u=g(x),其中y,u,x均为标量,则:\\ \frac{\partial y}{\partial x} = \frac{\partial y}{\partial u} \cdot \frac{\partial u}{\partial x} 若y=f(u),u=g(x),其中y,u,x均为标量,则:∂x∂y=∂u∂y⋅∂x∂u
设 y = f ( u ) , u = g ( x ⃗ ) y=f(u), u=g(\vec{x}) y=f(u),u=g(x), 其中 x ⃗ = ( x 1 , x 2 , ⋯ , x n ) \vec{x} = (x_{1},x_{2},\cdots,x_{n}) x=(x1,x2,⋯,xn) , y和u均为标量
则有:
∂ y ∂ x = ∂ y ∂ u ⋅ ∂ u ∂ x ( 1 , n ) = 1 ⋅ ( 1 , n ) \frac{\partial y}{\partial \textbf{x}} = \frac{\partial y}{\partial u} \cdot \frac{\partial u}{\partial \textbf{x}} \\ (1,n) = 1\cdot(1,n) ∂x∂y=∂u∂y⋅∂x∂u(1,n)=1⋅(1,n)
更具体的展开:
∂ y ∂ x = ( ∂ y ∂ x 1 , ∂ y ∂ x 2 , ⋯ , ∂ y ∂ x n ) = ( ∂ y ∂ u ⋅ ∂ u ∂ x 1 , ∂ y ∂ u ⋅ ∂ u ∂ x 2 , ⋯ , ∂ y ∂ u ⋅ ∂ u ∂ x n ) = ( ∂ y ∂ u ) ⋅ ( ∂ u ∂ x 1 , ∂ u ∂ x 2 , ⋯ , ∂ u ∂ x n ) \begin{align*} \frac{\partial y}{\partial \textbf{x}} &= ( \frac{\partial y}{\partial x_{1}}, \frac{\partial y}{\partial x_{2}}, \cdots, \frac{\partial y}{\partial x_{n}}) \\ &= ( \frac{\partial y}{\partial u}\cdot\frac{\partial u}{\partial x_{1}}, \frac{\partial y}{\partial u}\cdot\frac{\partial u}{\partial x_{2}}, \cdots, \frac{\partial y}{\partial u}\cdot\frac{\partial u}{\partial x_{n}}) \\ &= (\frac{\partial y}{\partial u}) \cdot ( \frac{\partial u}{\partial x_{1}}, \frac{\partial u}{\partial x_{2}}, \cdots, \frac{\partial u}{\partial x_{n}}) \end{align*} ∂x∂y=(∂x1∂y,∂x2∂y,⋯,∂xn∂y)=(∂u∂y⋅∂x1∂u,∂u∂y⋅∂x2∂u,⋯,∂u∂y⋅∂xn∂u)=(∂u∂y)⋅(∂x1∂u,∂x2∂u,⋯,∂xn∂u)
这是显然的。
前面提到的例子全部都只涉及到对标量多元函数求偏导数,这是本科的工科数学中就很熟悉的内容。下面介绍的矩阵形式的链式法则均涉及到向量多元函数的偏导数。对向量多元函数求梯度得到的是一个矩阵。
设 y = f ( u ⃗ ) , u ⃗ = g ( x ⃗ ) y=f(\vec{u}), \vec{u}=g(\vec{x}) y=f(u),u=g(x), 其中 u ⃗ = ( u 1 , u 2 , ⋯ , u k ) \vec{u} = (u_{1},u_{2},\cdots,u_{k}) u=(u1,u2,⋯,uk) , x ⃗ = ( x 1 , x 2 , ⋯ , x n ) \vec{x} = (x_{1},x_{2},\cdots,x_{n}) x=(x1,x2,⋯,xn) , y为标量
链式法则的具体展开:
∂ y ∂ x = ( ∂ y ∂ x 1 , ∂ y ∂ x 2 , ⋯ , ∂ y ∂ x n ) = ( ∂ y ∂ u ⋅ ∂ u ∂ x 1 , ∂ y ∂ u ⋅ ∂ u ∂ x 2 , ⋯ , ∂ y ∂ u ⋅ ∂ u ∂ x n ) = ( ∂ y ∂ u ) ⋅ ( ∂ u ∂ x 1 , ∂ u ∂ x 2 , ⋯ , ∂ u ∂ x n ) = ( ∂ y ∂ u 1 , ∂ y ∂ u 2 , ⋯ , ∂ y ∂ u k ) ⋅ [ ∂ u 1 ∂ x ∂ u 2 ∂ x ⋯ ∂ u k ∂ x ] = ( ∂ y ∂ u 1 , ∂ y ∂ u 2 , ⋯ , ∂ y ∂ u k ) ⋅ [ ∂ u 1 ∂ x 1 ∂ u 1 ∂ x 2 ⋯ ∂ u 1 ∂ x n ∂ u 2 ∂ x 1 ∂ u 2 ∂ x 2 ⋯ ∂ u 2 ∂ x n ⋯ ⋯ ⋯ ⋯ ∂ u k ∂ x 1 ∂ u k ∂ x 2 ⋯ ∂ u k ∂ x n ] \begin{align*} \frac{\partial y}{\partial \textbf{x}} &= ( \frac{\partial y}{\partial x_{1}}, \frac{\partial y}{\partial x_{2}}, \cdots, \frac{\partial y}{\partial x_{n}}) \\ &= ( \frac{\partial y}{\partial \textbf{u}}\cdot\frac{\partial \textbf{u}}{\partial x_{1}}, \frac{\partial y}{\partial \textbf{u}}\cdot\frac{\partial \textbf{u}}{\partial x_{2}}, \cdots, \frac{\partial y}{\partial \textbf{u}}\cdot\frac{\partial \textbf{u}}{\partial x_{n}}) \\ &= (\frac{\partial y}{\partial \textbf{u}}) \cdot ( \frac{\partial \textbf{u}}{\partial x_{1}}, \frac{\partial \textbf{u}}{\partial x_{2}}, \cdots, \frac{\partial \textbf{u}}{\partial x_{n}}) \\ &= (\frac{\partial {y}}{\partial u_{1}}, \frac{\partial {y}}{\partial u_{2}}, \cdots, \frac{\partial {y}}{\partial u_{k}}) \cdot \begin{bmatrix} \frac{\partial u_{1}}{\partial \textbf{x}} \\ \frac{\partial u_{2}}{\partial \textbf{x}}\\ \cdots \\ \frac{\partial u_{k}}{\partial \textbf{x}} \end{bmatrix} \\ &= (\frac{\partial {y}}{\partial u_{1}}, \frac{\partial {y}}{\partial u_{2}}, \cdots, \frac{\partial {y}}{\partial u_{k}}) \cdot \begin{bmatrix} \frac{\partial u_{1}}{\partial x_{1}}& \frac{\partial u_{1}}{\partial x_{2}}& \cdots & \frac{\partial u_{1}}{\partial x_{n}}\\ \frac{\partial u_{2}}{\partial x_{1}}& \frac{\partial u_{2}}{\partial x_{2}}& \cdots & \frac{\partial u_{2}}{\partial x_{n}} \\ \cdots &\cdots & \cdots & \cdots \\ \frac{\partial u_{k}}{\partial x_{1}}& \frac{\partial u_{k}}{\partial x_{2}}& \cdots & \frac{\partial u_{k}}{\partial x_{n}} \end{bmatrix} \end{align*} ∂x∂y=(∂x1∂y,∂x2∂y,⋯,∂xn∂y)=(∂u∂y⋅∂x1∂u,∂u∂y⋅∂x2∂u,⋯,∂u∂y⋅∂xn∂u)=(∂u∂y)⋅(∂x1∂u,∂x2∂u,⋯,∂xn∂u)=(∂u1∂y,∂u2∂y,⋯,∂uk∂y)⋅ ∂x∂u1∂x∂u2⋯∂x∂uk =(∂u1∂y,∂u2∂y,⋯,∂uk∂y)⋅ ∂x1∂u1∂x1∂u2⋯∂x1∂uk∂x2∂u1∂x2∂u2⋯∂x2∂uk⋯⋯⋯⋯∂xn∂u1∂xn∂u2⋯∂xn∂uk
即:
∂ y ∂ x = ∂ y ∂ u ⋅ ∂ u ∂ x ( 1 , n ) = ( 1 , k ) ⋅ ( k , n ) \frac{\partial y}{\partial \textbf{x}} = \frac{\partial y}{\partial \textbf{u}} \cdot \frac{\partial \textbf{u}}{\partial \textbf{x}} \\ (1,n) = (1,k)\cdot(k,n) ∂x∂y=∂u∂y⋅∂x∂u(1,n)=(1,k)⋅(k,n)
设 y ⃗ = f ( u ⃗ ) , u ⃗ = g ( x ⃗ ) \vec{y}=f(\vec{u}), \vec{u}=g(\vec{x}) y=f(u),u=g(x), 其中 y ⃗ = ( y 1 , y 2 , ⋯ , y m ) \vec{y} = (y_{1},y_{2},\cdots,y_{m}) y=(y1,y2,⋯,ym) , u ⃗ = ( u 1 , u 2 , ⋯ , u k ) \vec{u} = (u_{1},u_{2},\cdots,u_{k}) u=(u1,u2,⋯,uk) , x ⃗ = ( x 1 , x 2 , ⋯ , x n ) \vec{x} = (x_{1},x_{2},\cdots,x_{n}) x=(x1,x2,⋯,xn)
链式法则的具体展开:
∂ y ∂ x = ( ∂ y ∂ x 1 , ∂ y ∂ x 2 , ⋯ , ∂ y ∂ x n ) = ( ∂ y ∂ u ⋅ ∂ u ∂ x 1 , ∂ y ∂ u ⋅ ∂ u ∂ x 2 , ⋯ , ∂ y ∂ u ⋅ ∂ u ∂ x n ) = ( ∂ y ∂ u ) ⋅ ( ∂ u ∂ x 1 , ∂ u ∂ x 2 , ⋯ , ∂ u ∂ x n ) = [ ∂ y 1 ∂ u 1 ∂ y 1 ∂ u 2 ⋯ ∂ y 1 ∂ u k ∂ y 2 ∂ u 1 ∂ y 2 ∂ u 2 ⋯ ∂ y 2 ∂ u k ⋯ ⋯ ⋯ ⋯ ∂ y m ∂ u 1 ∂ y m ∂ u 2 ⋯ ∂ y m ∂ u k ] ⋅ [ ∂ u 1 ∂ x 1 ∂ u 1 ∂ x 2 ⋯ ∂ u 1 ∂ x n ∂ u 2 ∂ x 1 ∂ u 2 ∂ x 2 ⋯ ∂ u 2 ∂ x n ⋯ ⋯ ⋯ ⋯ ∂ u k ∂ x 1 ∂ u k ∂ x 2 ⋯ ∂ u k ∂ x n ] \begin{align*} \frac{\partial \textbf{y}}{\partial \textbf{x}} &= ( \frac{\partial \textbf{y}}{\partial x_{1}}, \frac{\partial \textbf{y}}{\partial x_{2}}, \cdots, \frac{\partial \textbf{y}}{\partial x_{n}}) \\ &= ( \frac{\partial \textbf{y}}{\partial \textbf{u}}\cdot\frac{\partial \textbf{u}}{\partial x_{1}}, \frac{\partial \textbf{y}}{\partial \textbf{u}}\cdot\frac{\partial \textbf{u}}{\partial x_{2}}, \cdots, \frac{\partial \textbf{y}}{\partial \textbf{u}}\cdot\frac{\partial \textbf{u}}{\partial x_{n}}) \\ &= (\frac{\partial \textbf{y}}{\partial \textbf{u}}) \cdot ( \frac{\partial \textbf{u}}{\partial x_{1}}, \frac{\partial \textbf{u}}{\partial x_{2}}, \cdots, \frac{\partial \textbf{u}}{\partial x_{n}}) \\ &= \begin{bmatrix} \frac{\partial y_{1}}{\partial u_{1}}& \frac{\partial y_{1}}{\partial u_{2}}& \cdots & \frac{\partial y_{1}}{\partial u_{k}}\\ \frac{\partial y_{2}}{\partial u_{1}}& \frac{\partial y_{2}}{\partial u_{2}}& \cdots & \frac{\partial y_{2}}{\partial u_{k}} \\ \cdots &\cdots & \cdots & \cdots \\ \frac{\partial y_{m}}{\partial u_{1}}& \frac{\partial y_{m}}{\partial u_{2}}& \cdots & \frac{\partial y_{m}}{\partial u_{k}} \end{bmatrix} \cdot \begin{bmatrix} \frac{\partial u_{1}}{\partial x_{1}}& \frac{\partial u_{1}}{\partial x_{2}}& \cdots & \frac{\partial u_{1}}{\partial x_{n}}\\ \frac{\partial u_{2}}{\partial x_{1}}& \frac{\partial u_{2}}{\partial x_{2}}& \cdots & \frac{\partial u_{2}}{\partial x_{n}} \\ \cdots &\cdots & \cdots & \cdots \\ \frac{\partial u_{k}}{\partial x_{1}}& \frac{\partial u_{k}}{\partial x_{2}}& \cdots & \frac{\partial u_{k}}{\partial x_{n}} \end{bmatrix} \end{align*} ∂x∂y=(∂x1∂y,∂x2∂y,⋯,∂xn∂y)=(∂u∂y⋅∂x1∂u,∂u∂y⋅∂x2∂u,⋯,∂u∂y⋅∂xn∂u)=(∂u∂y)⋅(∂x1∂u,∂x2∂u,⋯,∂xn∂u)= ∂u1∂y1∂u1∂y2⋯∂u1∂ym∂u2∂y1∂u2∂y2⋯∂u2∂ym⋯⋯⋯⋯∂uk∂y1∂uk∂y2⋯∂uk∂ym ⋅ ∂x1∂u1∂x1∂u2⋯∂x1∂uk∂x2∂u1∂x2∂u2⋯∂x2∂uk⋯⋯⋯⋯∂xn∂u1∂xn∂u2⋯∂xn∂uk
即:
∂ y ∂ x = ∂ y ∂ u ⋅ ∂ u ∂ x ( m , n ) = ( m , k ) ⋅ ( k , n ) \frac{\partial \textbf{y}}{\partial \textbf{x}} = \frac{\partial \textbf{y}}{\partial \textbf{u}} \cdot \frac{\partial \textbf{u}}{\partial \textbf{x}} \\ (m,n) = (m,k)\cdot(k,n) ∂x∂y=∂u∂y⋅∂x∂u(m,n)=(m,k)⋅(k,n)
如果将此时的链式法则画出计算图,可以清晰地看出,向量函数 y \textbf{y} y 对 向量 x \textbf{x} x 求偏导,实际上就是遍历了从y到x的所有依赖关系。这就是上面这个矩阵相乘的本质。 ]