最小二乘法---线性回归的求解方法

这几天看书的时候突然注意到了这个经典的优化方法,于是重新推导了一遍,为以后应用做参考。

背景

最小二乘法应该是我接触的最早的优化方法,也是求解线性回归的一种方法。线性回归的主要作用是用拟合的方式,求解两组变量之间的线性关系(当然也可以不是线性的,那就是另外的回归方法了)。也就是把一个系统的输出写成输入的线性组合的形式。而这组线性关系的参数求解方法,就是最小二乘法。

我们从最简单的线性回归开始,即输入和输出都是1维的。此时,最小二乘法也是最简单的。

假设有输入信号 x = { x 0 , x 1 , . . . , x t } x = \{x_0, x_1, ..., x_t\} x={x0,x1,...,xt},同时输出信号为 y = { y 0 , y 1 , . . . , y t } y = \{y_0, y_1, ..., y_t\} y={y0,y1,...,yt},我们假设输入信号 x x x和输出信号 y y y之间的关系可以写成如下形式:

y p r e = a x + b (1) y_{pre} = ax+b \tag{1} ypre=ax+b(1)

我们需要求解最优的 a a a b b b,这里最优的含义就是,预测的最准确,也就是预测值和真实值的误差最小,即:

a r g   m i n a , b ∑ i = 0 t ( y i − a x i − b ) 2 (2) arg\, min_{a, b}{\sum_{i=0}^{t}{(y_i-ax_i-b)^2}} \tag{2} argmina,bi=0t(yiaxib)2(2)

我们假设误差函数为:

e r r = ∑ i = 0 t ( y i − a x i − b ) 2 (3) err = \sum_{i=0}^{t}{(y_i-ax_i-b)^2} \tag{3} err=i=0t(yiaxib)2(3)

e r r err err a a a b b b分别求偏导:

∂ e r r ∂ a = ∑ i = 0 t 2 ( a x i + b − y i ) ∗ x i (4) \frac{\partial{err}}{\partial{a}} = \sum_{i=0}^{t}{2(ax_i+b-y_i)*x_i} \tag{4} aerr=i=0t2(axi+byi)xi(4)

∂ e r r ∂ b = ∑ i = 0 t 2 ( a x i + b − y i ) (5) \frac{\partial{err}}{\partial{b}} = \sum_{i=0}^{t}{2(ax_i+b-y_i)} \tag{5} berr=i=0t2(axi+byi)(5)

根据极值定理,有 ∂ e r r ∂ a = 0 \frac{\partial{err}}{\partial{a}}=0 aerr=0,且 ∂ e r r ∂ b = 0 \frac{\partial{err}}{\partial{b}}=0 berr=0,所以有:

∑ i = 0 t 2 ( a x i + b − y i ) = 0 (6) \sum_{i=0}^{t}{2(ax_i+b-y_i)} = 0 \tag{6} i=0t2(axi+byi)=0(6)

∑ i = 0 t ( y i − a x i ) = ∑ i = 0 t b (7) \sum_{i=0}^{t}(y_i - ax_i) = \sum_{i=0}^{t}{b} \tag{7} i=0t(yiaxi)=i=0tb(7)

∑ i = 0 t y i − a ∗ ∑ i = 0 t x i = ( t + 1 ) ∗ b (8) \sum_{i=0}^{t}{y_i} - a * \sum_{i=0}^{t}{x_i} = (t+1)*b \tag{8} i=0tyiai=0txi=(t+1)b(8)

b = y ˉ − a x ˉ (9) b = \bar{y} - a\bar{x} \tag{9} b=yˉaxˉ(9)

其中, y ˉ \bar{y} yˉ表示 y y y的均值, x ˉ \bar{x} xˉ表示 x x x的均值。将Eq(9)代入Eq(4),有:

∑ i = 0 t 2 ( a x i + b − y i ) ∗ x i = 0 (10) \sum_{i=0}^{t}{2(ax_i+b-y_i)*x_i} = 0 \tag{10} i=0t2(axi+byi)xi=0(10)

∑ i = 0 t a x i 2 + ∑ i = 0 t b x i = ∑ i = 0 t y i x i (11) \sum_{i=0}^{t}{ax_i^2} + \sum_{i=0}^{t}bx_i = \sum_{i=0}^{t}{y_ix_i} \tag{11} i=0taxi2+i=0tbxi=i=0tyixi(11)

a ∑ i = 0 t x i 2 + x ˉ ( y ˉ − a x ˉ ) = ∑ i = 0 t x i y i (12) a\sum_{i=0}^{t}x_i^2 + \bar{x}(\bar{y}-a\bar{x}) = \sum_{i=0}^{t}{x_iy_i} \tag{12} ai=0txi2+xˉ(yˉaxˉ)=i=0txiyi(12)

a ( ∑ i = 0 t x i 2 − x ˉ 2 ) = ∑ i = 0 t x i y i − x ˉ y ˉ (13) a(\sum_{i=0}^{t}{x_i^2 - \bar{x}^2}) = \sum_{i=0}^{t}{x_iy_i}-\bar{x}\bar{y} \tag{13} a(i=0txi2xˉ2)=i=0txiyixˉyˉ(13)

a = ∑ i = 0 t x i y i − x ˉ y ˉ ∑ i = 0 t x i 2 − x ˉ 2 (14) a = \frac{\sum_{i=0}^{t}{x_iy_i}-\bar{x}\bar{y}}{\sum_{i=0}^{t}{x_i^2 - \bar{x}^2}} \tag{14} a=i=0txi2xˉ2i=0txiyixˉyˉ(14)

所以Eq(14)和Eq(9)就是最简单的最小二乘法的计算方法。

然后我们进一步考虑,如果输入和输出是多维数据,要如何计算。

假设输入信号为 X ∈ R m ∗ t X \in R^{m*t} XRmt, 输出信号为 Y ∈ R n ∗ t Y \in R^{n*t} YRnt,那么有:

Y = W 0 X + B = W X 1 (15) Y = W_0X+B = WX_1 \tag{15} Y=W0X+B=WX1(15)

其中 W 0 ∈ R n ∗ m W_0 \in R^{n*m} W0Rnm是回归矩阵的系数, B ∈ R 1 ∗ t B \in R^{1*t} BR1t表示常数项,这里可以直接写到 W W W矩阵中。 W ∈ R n ∗ ( m + 1 ) W \in R^{n*(m+1)} WRn(m+1) X 1 ∈ R ( m + 1 ) ∗ t X_1 \in R^{(m+1)*t} X1R(m+1)t
X 1 = [ x 11 x 12 . . . x 1 t x 11 x 12 . . . x 1 t ⋮ ⋮ . . . ⋮ x m 1 x m 2 . . . x m t 1 1 . . . 1 ] (16) X_1 = \begin{bmatrix} x_{11} &x_{12} & ... &x_{1t}\\ x_{11} &x_{12} & ... &x_{1t}\\ {\vdots} &{\vdots} &... &{\vdots}\\ x_{m1} &x_{m2} &... &x_{mt}\\ 1 &1 &... &1\\ \end{bmatrix} \tag{16} X1=x11x11xm11x12x12xm21...............x1tx1txmt1(16)

所以有:

arg ⁡ m i n W ( Y − W X 1 ) (17) \arg min_{W}({Y-WX_1}) \tag{17} argminW(YWX1)(17)

假设误差函数为 E E E,则有:

E = ( Y − W X 1 ) ( Y − W X 1 ) T = Y Y T − W X 1 Y T − Y X 1 T W T + W X 1 X 1 T W T (18) E = (Y-WX_1)(Y-WX_1)^T = YY^T - WX_1Y^T-YX_1^TW^T+WX_1X_1^TW^T \tag{18} E=(YWX1)(YWX1)T=YYTWX1YTYX1TWT+WX1X1TWT(18)

计算 E E E W W W的偏导,则该偏导等于0:

∂ E ∂ W = − X 1 Y T − X 1 Y T + 2 W X X T = 0 (19) \frac{\partial{E}}{\partial{W}} = -X_1Y^T-X_1Y^T + 2WXX^T = 0 \tag{19} WE=X1YTX1YT+2WXXT=0(19)

所以有:

W = ( X 1 X 1 T ) − 1 X 1 Y T (20) W = (X_1X_1^T)^{-1}X_1Y^T \tag{20} W=(X1X1T)1X1YT(20)

至此矩阵形式的最小二乘法(多元线性回归的参数解法)推导完成。注意这里的 X 1 X_1 X1 Y Y Y中的数据排列方式为:每一行是一个维度的数据,每一列表示一个时间点。如果不是这么记录的话,那么公式需要加上转置。

后续会附上代码链接

你可能感兴趣的:(数学)