这几天看书的时候突然注意到了这个经典的优化方法,于是重新推导了一遍,为以后应用做参考。
最小二乘法应该是我接触的最早的优化方法,也是求解线性回归的一种方法。线性回归的主要作用是用拟合的方式,求解两组变量之间的线性关系(当然也可以不是线性的,那就是另外的回归方法了)。也就是把一个系统的输出写成输入的线性组合的形式。而这组线性关系的参数求解方法,就是最小二乘法。
我们从最简单的线性回归开始,即输入和输出都是1维的。此时,最小二乘法也是最简单的。
假设有输入信号 x = { x 0 , x 1 , . . . , x t } x = \{x_0, x_1, ..., x_t\} x={x0,x1,...,xt},同时输出信号为 y = { y 0 , y 1 , . . . , y t } y = \{y_0, y_1, ..., y_t\} y={y0,y1,...,yt},我们假设输入信号 x x x和输出信号 y y y之间的关系可以写成如下形式:
y p r e = a x + b (1) y_{pre} = ax+b \tag{1} ypre=ax+b(1)
我们需要求解最优的 a a a和 b b b,这里最优的含义就是,预测的最准确,也就是预测值和真实值的误差最小,即:
a r g m i n a , b ∑ i = 0 t ( y i − a x i − b ) 2 (2) arg\, min_{a, b}{\sum_{i=0}^{t}{(y_i-ax_i-b)^2}} \tag{2} argmina,bi=0∑t(yi−axi−b)2(2)
我们假设误差函数为:
e r r = ∑ i = 0 t ( y i − a x i − b ) 2 (3) err = \sum_{i=0}^{t}{(y_i-ax_i-b)^2} \tag{3} err=i=0∑t(yi−axi−b)2(3)
e r r err err对 a a a和 b b b分别求偏导:
∂ e r r ∂ a = ∑ i = 0 t 2 ( a x i + b − y i ) ∗ x i (4) \frac{\partial{err}}{\partial{a}} = \sum_{i=0}^{t}{2(ax_i+b-y_i)*x_i} \tag{4} ∂a∂err=i=0∑t2(axi+b−yi)∗xi(4)
∂ e r r ∂ b = ∑ i = 0 t 2 ( a x i + b − y i ) (5) \frac{\partial{err}}{\partial{b}} = \sum_{i=0}^{t}{2(ax_i+b-y_i)} \tag{5} ∂b∂err=i=0∑t2(axi+b−yi)(5)
根据极值定理,有 ∂ e r r ∂ a = 0 \frac{\partial{err}}{\partial{a}}=0 ∂a∂err=0,且 ∂ e r r ∂ b = 0 \frac{\partial{err}}{\partial{b}}=0 ∂b∂err=0,所以有:
∑ i = 0 t 2 ( a x i + b − y i ) = 0 (6) \sum_{i=0}^{t}{2(ax_i+b-y_i)} = 0 \tag{6} i=0∑t2(axi+b−yi)=0(6)
∑ i = 0 t ( y i − a x i ) = ∑ i = 0 t b (7) \sum_{i=0}^{t}(y_i - ax_i) = \sum_{i=0}^{t}{b} \tag{7} i=0∑t(yi−axi)=i=0∑tb(7)
∑ i = 0 t y i − a ∗ ∑ i = 0 t x i = ( t + 1 ) ∗ b (8) \sum_{i=0}^{t}{y_i} - a * \sum_{i=0}^{t}{x_i} = (t+1)*b \tag{8} i=0∑tyi−a∗i=0∑txi=(t+1)∗b(8)
b = y ˉ − a x ˉ (9) b = \bar{y} - a\bar{x} \tag{9} b=yˉ−axˉ(9)
其中, y ˉ \bar{y} yˉ表示 y y y的均值, x ˉ \bar{x} xˉ表示 x x x的均值。将Eq(9)代入Eq(4),有:
∑ i = 0 t 2 ( a x i + b − y i ) ∗ x i = 0 (10) \sum_{i=0}^{t}{2(ax_i+b-y_i)*x_i} = 0 \tag{10} i=0∑t2(axi+b−yi)∗xi=0(10)
∑ i = 0 t a x i 2 + ∑ i = 0 t b x i = ∑ i = 0 t y i x i (11) \sum_{i=0}^{t}{ax_i^2} + \sum_{i=0}^{t}bx_i = \sum_{i=0}^{t}{y_ix_i} \tag{11} i=0∑taxi2+i=0∑tbxi=i=0∑tyixi(11)
a ∑ i = 0 t x i 2 + x ˉ ( y ˉ − a x ˉ ) = ∑ i = 0 t x i y i (12) a\sum_{i=0}^{t}x_i^2 + \bar{x}(\bar{y}-a\bar{x}) = \sum_{i=0}^{t}{x_iy_i} \tag{12} ai=0∑txi2+xˉ(yˉ−axˉ)=i=0∑txiyi(12)
a ( ∑ i = 0 t x i 2 − x ˉ 2 ) = ∑ i = 0 t x i y i − x ˉ y ˉ (13) a(\sum_{i=0}^{t}{x_i^2 - \bar{x}^2}) = \sum_{i=0}^{t}{x_iy_i}-\bar{x}\bar{y} \tag{13} a(i=0∑txi2−xˉ2)=i=0∑txiyi−xˉyˉ(13)
a = ∑ i = 0 t x i y i − x ˉ y ˉ ∑ i = 0 t x i 2 − x ˉ 2 (14) a = \frac{\sum_{i=0}^{t}{x_iy_i}-\bar{x}\bar{y}}{\sum_{i=0}^{t}{x_i^2 - \bar{x}^2}} \tag{14} a=∑i=0txi2−xˉ2∑i=0txiyi−xˉyˉ(14)
所以Eq(14)和Eq(9)就是最简单的最小二乘法的计算方法。
然后我们进一步考虑,如果输入和输出是多维数据,要如何计算。
假设输入信号为 X ∈ R m ∗ t X \in R^{m*t} X∈Rm∗t, 输出信号为 Y ∈ R n ∗ t Y \in R^{n*t} Y∈Rn∗t,那么有:
Y = W 0 X + B = W X 1 (15) Y = W_0X+B = WX_1 \tag{15} Y=W0X+B=WX1(15)
其中 W 0 ∈ R n ∗ m W_0 \in R^{n*m} W0∈Rn∗m是回归矩阵的系数, B ∈ R 1 ∗ t B \in R^{1*t} B∈R1∗t表示常数项,这里可以直接写到 W W W矩阵中。 W ∈ R n ∗ ( m + 1 ) W \in R^{n*(m+1)} W∈Rn∗(m+1), X 1 ∈ R ( m + 1 ) ∗ t X_1 \in R^{(m+1)*t} X1∈R(m+1)∗t
X 1 = [ x 11 x 12 . . . x 1 t x 11 x 12 . . . x 1 t ⋮ ⋮ . . . ⋮ x m 1 x m 2 . . . x m t 1 1 . . . 1 ] (16) X_1 = \begin{bmatrix} x_{11} &x_{12} & ... &x_{1t}\\ x_{11} &x_{12} & ... &x_{1t}\\ {\vdots} &{\vdots} &... &{\vdots}\\ x_{m1} &x_{m2} &... &x_{mt}\\ 1 &1 &... &1\\ \end{bmatrix} \tag{16} X1=⎣⎢⎢⎢⎢⎢⎡x11x11⋮xm11x12x12⋮xm21...............x1tx1t⋮xmt1⎦⎥⎥⎥⎥⎥⎤(16)
所以有:
arg m i n W ( Y − W X 1 ) (17) \arg min_{W}({Y-WX_1}) \tag{17} argminW(Y−WX1)(17)
假设误差函数为 E E E,则有:
E = ( Y − W X 1 ) ( Y − W X 1 ) T = Y Y T − W X 1 Y T − Y X 1 T W T + W X 1 X 1 T W T (18) E = (Y-WX_1)(Y-WX_1)^T = YY^T - WX_1Y^T-YX_1^TW^T+WX_1X_1^TW^T \tag{18} E=(Y−WX1)(Y−WX1)T=YYT−WX1YT−YX1TWT+WX1X1TWT(18)
计算 E E E对 W W W的偏导,则该偏导等于0:
∂ E ∂ W = − X 1 Y T − X 1 Y T + 2 W X X T = 0 (19) \frac{\partial{E}}{\partial{W}} = -X_1Y^T-X_1Y^T + 2WXX^T = 0 \tag{19} ∂W∂E=−X1YT−X1YT+2WXXT=0(19)
所以有:
W = ( X 1 X 1 T ) − 1 X 1 Y T (20) W = (X_1X_1^T)^{-1}X_1Y^T \tag{20} W=(X1X1T)−1X1YT(20)
至此矩阵形式的最小二乘法(多元线性回归的参数解法)推导完成。注意这里的 X 1 X_1 X1和 Y Y Y中的数据排列方式为:每一行是一个维度的数据,每一列表示一个时间点。如果不是这么记录的话,那么公式需要加上转置。
后续会附上代码链接