10 Regression Algorithms - Applying Gradient Descent to Linear Regression

=== Gradient Descent Theory ===

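Concept: imagine you are standing on a mountain with many downhill paths of different steepness, and you want to reach the foot of the mountain as quickly as possible. What is the fastest way down?
First, find the steepest downhill direction around the spot where you stand and take one step. Then, from the new spot, look for the steepest direction again and take another step. Repeating this, you descend quickly, because each step makes the most downhill progress available from where you are.

Mapped to our problem: at which value of θ does the objective function J(θ) reach its minimum?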

[Figure 1]

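1. Initialize θ (randomly, or simply to 0).
2. Iterate along the negative gradient direction, so that each updated θ makes J(θ) smaller:
$\theta := \theta - \alpha \frac{\partial J(\theta)}{\partial \theta}$
where α is the learning rate (step size).
3. Keep iterating while J(θ) keeps decreasing. Updating θ here means updating every component (θ0, θ1, θ2, ..., θn) of the vector.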
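The three steps above can be sketched as a small generic loop. This is a minimal illustration, not the author's code. The function name `gradient_descent`, the fixed iteration count and the toy objective J(θ) = (θ0 − 1)² + (θ1 + 2)² are all assumptions. The toy objective was chosen so that its minimum, θ = (1, −2), is known in advance.

```python
import numpy as np

def gradient_descent(grad_J, theta0, alpha=0.1, n_iters=200):
    """Minimal sketch: repeat theta := theta - alpha * grad_J(theta)."""
    theta = np.asarray(theta0, dtype=float)   # step 1: initial value
    for _ in range(n_iters):
        # step 2: move against the gradient; every component of theta
        # is updated at once from the same old theta
        theta = theta - alpha * grad_J(theta)
    return theta

# Hypothetical toy objective J(theta) = (theta0 - 1)^2 + (theta1 + 2)^2
target = np.array([1.0, -2.0])

def grad_J(theta):
    # gradient of the toy objective: 2 * (theta - target)
    return 2 * (theta - target)

theta_min = gradient_descent(grad_J, theta0=[0.0, 0.0])
print(theta_min)   # close to [1, -2]
```

With α = 0.1, each iteration shrinks the distance to the minimum by a factor of 0.8, so 200 iterations are more than enough here.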

Updating each θj:

[Figure 2: derivation of ∂J(θ)/∂θj]

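The first step is the chain rule and is not repeated here. Assuming the single-sample loss $J(\theta) = \frac{1}{2}(h_\theta(x) - y)^2$, it gives $\frac{\partial J(\theta)}{\partial \theta_j} = (h_\theta(x) - y)\,\frac{\partial}{\partial \theta_j}(h_\theta(x) - y)$.
The last step works because $h_\theta(x) = \sum_{i=0}^{n} \theta_i x_i$ is linear in θ, so only the term $\theta_j x_j$ depends on $\theta_j$:

$\frac{\partial}{\partial \theta_j}\left(h_\theta(x) - y\right) = \frac{\partial}{\partial \theta_j}\left(\sum_{i=0}^{n} \theta_i x_i - y\right) = x_j$

Therefore $\frac{\partial J(\theta)}{\partial \theta_j} = (h_\theta(x) - y)\,x_j$.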
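A quick numerical check of this derivative compares the analytic gradient (hθ(x) − y)xj with central finite differences. It assumes the linear hypothesis hθ(x) = θ·x and the single-sample loss J(θ) = ½(hθ(x) − y)². The ½ is a common convention so that the 2 from the square cancels. The sample values are made up for illustration.

```python
import numpy as np

def J(theta, x, y):
    # single-sample squared loss 1/2 * (h_theta(x) - y)^2, with h_theta(x) = theta . x
    return 0.5 * (theta @ x - y) ** 2

def analytic_grad(theta, x, y):
    # result of the derivation: dJ/dtheta_j = (h_theta(x) - y) * x_j
    return (theta @ x - y) * x

theta = np.array([0.5, -1.0, 2.0])
x = np.array([1.0, 3.0, -2.0])   # x_0 = 1 plays the role of the intercept feature
y = 4.0

# central finite differences, perturbing one component theta_j at a time
eps = 1e-6
numeric_grad = np.array([
    (J(theta + eps * e, x, y) - J(theta - eps * e, x, y)) / (2 * eps)
    for e in np.eye(len(theta))
])
print(analytic_grad(theta, x, y))
print(numeric_grad)
```

Here hθ(x) − y = −10.5, so both lines should show approximately (−10.5, −31.5, 21).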


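Now consider the following question.
Step 1: initialize θ (randomly, or to 0), e.g.
(θ0, θ1, θ2, ..., θn) = (0, 0, ..., 0)

In the first iteration, we first update θ1:
$\theta_1 := \theta_1 - \alpha\,(h_\theta(x) - y)\,x_1$
and then θ2:
$\theta_2 := \theta_2 - \alpha\,(h_\theta(x) - y)\,x_2$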

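The question is this.
Before any update, (θ0, θ1, θ2, ..., θn) = (0, 0, ..., 0).
After θ1 has been updated and we move on to θ2, which value of θ1 should be used inside hθ(x)? Is it the old value 0, or the value θ1 has just received in this iteration?

The answer is the old value, 0.
In other words, when computing the update for one component, we ignore the new values that the other components receive in the same iteration. Every component is computed from the θ of the previous iteration.
Only after all components have been updated does the θ inside hθ(x) change. This is called a simultaneous update.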
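The difference can be made concrete with one update step on a single sample. The function names and values below are hypothetical.

- `simultaneous_step` computes hθ(x) once from the old θ and updates every component from it, as described above.
- `sequential_step` recomputes hθ(x) after each component changes. This is a different scheme, closer to coordinate-wise updates, and it generally gives a different result.

```python
import numpy as np

def simultaneous_step(theta, x, y, alpha):
    error = theta @ x - y                  # h_theta(x) - y from the OLD theta
    return theta - alpha * error * x       # all theta_j updated together

def sequential_step(theta, x, y, alpha):
    theta = theta.copy()                   # do not modify the caller's array
    for j in range(len(theta)):
        error = theta @ x - y              # recomputed with the already-updated components
        theta[j] = theta[j] - alpha * error * x[j]
    return theta

theta0 = np.zeros(3)                       # (theta_0, theta_1, theta_2) = (0, 0, 0)
x = np.array([1.0, 2.0, 3.0])
y = 1.0
alpha = 0.1

print(simultaneous_step(theta0, x, y, alpha))   # [0.1 0.2 0.3]
print(sequential_step(theta0, x, y, alpha))
```

In the sequential version, θ2 sees the error −0.9 instead of −1 because θ1 has already moved. That is exactly what the simultaneous update avoids.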


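Finding the minimum of J(θ) means finding the value of θ at that minimum, that is, which θ vector makes the loss function smallest.

Sometimes the function is not strictly convex. As shown in the figure below, it can have many local minima, while what we want is the global minimum.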

[Figure 3: a non-convex J(θ) with several local minima]

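With different initial values, gradient descent may end at different minima. Gradient descent with a suitable learning rate will find a local minimum, but not necessarily the global one. How can this be addressed?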

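Stochastic gradient descent (SGD) will be discussed later. A simple remedy is to run gradient descent from several different initial values and compare the minima that the runs converge to.
A simple example: suppose one run ends at θ = 3 with J(θ) = 2.5 and another ends at θ = 5 with J(θ) = 3.3. We keep θ = 3, since its loss of 2.5 is the smaller one. It is the best of the minima found, though still not guaranteed to be the global minimum.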
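The multi-start idea can be sketched on the hypothetical non-convex function f(x) = x⁴ − 3x² + x. It has two local minima, one near x ≈ −1.30 and one near x ≈ 1.13. The function, the helper `descend`, the step size, the iteration count and the starting points are all assumptions chosen for illustration.

```python
def f(x):
    # hypothetical non-convex function with two local minima
    return x ** 4 - 3 * x ** 2 + x

def df(x):
    # derivative of f
    return 4 * x ** 3 - 6 * x + 1

def descend(x, step=0.01, n_iters=2000):
    # plain gradient descent from one starting point
    for _ in range(n_iters):
        x = x - step * df(x)
    return x

starts = [-2.0, -0.5, 0.5, 2.0]
results = [(descend(x0), x0) for x0 in starts]
for x_end, x0 in results:
    print(f"start={x0:+.1f} -> x={x_end:.4f}, f(x)={f(x_end):.4f}")

# keep the run with the smallest loss (best of the minima found, not a guarantee of the global one)
best_x = min((x_end for x_end, _ in results), key=f)
print("best:", round(best_x, 4))
```

Starts to the right of the local maximum (around x ≈ 0.17) end in the right-hand minimum, and starts to the left end in the left-hand minimum. Comparing f at the end points picks the left one.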

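Note: gradient descent gives an approximate numerical solution, not an exact analytical one. However, it is fast, and with enough iterations and a suitable learning rate it can get arbitrarily close to the true minimum.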
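To tie this back to linear regression, the sketch below runs batch gradient descent on synthetic data. It uses the gradient (hθ(x) − y)xj averaged over all samples. The result is then compared with the exact least-squares solution from `np.linalg.lstsq`. The data-generating model y = 3 + 2x1 − x2, the learning rate and the iteration count are assumptions. With these settings the two answers should agree to many decimal places, illustrating the "approximate but arbitrarily close" point above.

```python
import numpy as np

# Synthetic data (assumption): y = 3 + 2*x1 - 1*x2 + small noise
rng = np.random.default_rng(0)
m = 200
X_raw = rng.uniform(-1, 1, size=(m, 2))
y = 3 + 2 * X_raw[:, 0] - 1 * X_raw[:, 1] + 0.01 * rng.standard_normal(m)

# Prepend a column of ones so theta[0] acts as the intercept (x0 = 1)
X = np.hstack([np.ones((m, 1)), X_raw])

def batch_gradient_descent(X, y, alpha=0.1, n_iters=5000):
    theta = np.zeros(X.shape[1])          # step 1: initialize theta to 0
    for _ in range(n_iters):
        residual = X @ theta - y          # h_theta(x) - y for every sample
        grad = X.T @ residual / len(y)    # average of (h_theta(x) - y) * x_j
        theta = theta - alpha * grad      # simultaneous update of all theta_j
    return theta

theta_gd = batch_gradient_descent(X, y)
theta_exact, *_ = np.linalg.lstsq(X, y, rcond=None)  # exact least-squares solution
print("gradient descent:", theta_gd)
print("least squares:   ", theta_exact)
```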

=== Examples ===

Import the required libraries

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import warnings
import sklearn
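from sklearn.linear_model import LinearRegression, Ridge, LassoCV, RidgeCV, ElasticNetCV
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import Pipeline
# ConvergenceWarning lives in sklearn.exceptions; the old path
# sklearn.linear_model.coordinate_descent was removed in newer scikit-learn releases
from sklearn.exceptions import ConvergenceWarning
# inline: show figures inside the notebook, right below the cell that calls plt.show()
%matplotlib inline

# tk: show figures in a separate GUI window instead, but the window may close immediately
# workaround: plt.ion(); plt.pause(10); plt.close()
#%matplotlib tk

## Use a font that can render Chinese characters so they are not garbled,
## and draw the minus sign correctly with that font
mpl.rcParams['font.sans-serif'] = ['SimHei']
mpl.rcParams['axes.unicode_minus'] = False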

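Gradient descent
Example 1: use gradient descent to find the minimum of y = f(x) = x²
Steps:
1. To find the x at which f(x) is smallest, first differentiate f(x): y' = 2x.
2. Choose an initial value of x and keep updating it. When f(x) stops decreasing, look at the value x has reached.
3. Choose a step size step, which controls how far x moves in each update. If step is too small, the computation takes a long time. If step is too large, the updates can overshoot the minimum.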

## Original function
def f(x):
    return x ** 2

## Its derivative, y' = 2x
def h(x):
    return 2 * x

X = []
Y = []

x = 2        # initial value
step = 0.8   # step size (learning rate)

f_change = f(x)
f_current = f(x)
X.append(x)
Y.append(f_current)
while f_change > 1e-10:                 # stop once f(x) barely changes
    x = x - step * h(x)                 # move against the derivative
    tmp = f(x)
    f_change = np.abs(f_current - tmp)
    f_current = tmp
    X.append(x)
    Y.append(f_current)
    print('x=', x)
    print('f_change:', f_change, 'f_current=', f_current)
print('Final result:', (x, f_current))

x= -1.2000000000000002
f_change: 2.56 f_current= 1.4400000000000004
x= 0.7200000000000002
f_change: 0.9216 f_current= 0.5184000000000003
x= -0.43200000000000016
f_change: 0.331776 f_current= 0.18662400000000015
x= 0.2592000000000001
f_change: 0.11943936 f_current= 0.06718464000000005
x= -0.1555200000000001
f_change: 0.0429981696 f_current= 0.02418647040000003
x= 0.09331200000000006
f_change: 0.015479341056 f_current= 0.008707129344000012
x= -0.05598720000000004
f_change: 0.00557256278016 f_current= 0.0031345665638400047
x= 0.03359232000000004
f_change: 0.00200612260086 f_current= 0.0011284439629824024
x= -0.020155392000000022
f_change: 0.000722204136309 f_current= 0.0004062398266736649
x= 0.012093235200000017
f_change: 0.000259993489071 f_current= 0.00014624633760251945
x= -0.007255941120000012
f_change: 9.35976560656e-05 f_current= 5.2648681536907024e-05
x= 0.0043535646720000085
f_change: 3.36951561836e-05 f_current= 1.8953525353286542e-05
x= -0.0026121388032000056
f_change: 1.21302562261e-05 f_current= 6.823269127183157e-06
x= 0.0015672832819200039
f_change: 4.3668922414e-06 f_current= 2.4563768857859384e-06
x= -0.0009403699691520025
f_change: 1.5720812069e-06 f_current= 8.842956788829382e-07
x= 0.0005642219814912016
f_change: 5.65949234485e-07 f_current= 3.183464443978579e-07
x= -0.00033853318889472107
f_change: 2.03741724415e-07 f_current= 1.146047199832289e-07
x= 0.00020311991333683268
f_change: 7.33470207893e-08 f_current= 4.1257699193962417e-08
x= -0.00012187194800209963
f_change: 2.64049274841e-08 f_current= 1.4852771709826476e-08
x= 7.312316880125978e-05
f_change: 9.50577389429e-09 f_current= 5.3469978155375315e-09
x= -4.387390128075587e-05
f_change: 3.42207860194e-09 f_current= 1.9249192135935116e-09
x= 2.6324340768453522e-05
f_change: 1.2319482967e-09 f_current= 6.929709168936641e-10
x= -1.5794604461072116e-05
f_change: 4.43501386812e-10 f_current= 2.494695300817192e-10
x= 9.476762676643271e-06
f_change: 1.59660499252e-10 f_current= 8.980903082941893e-11
x= -5.686057605985963e-06
f_change: 5.74777797308e-11 f_current= 3.233125109859082e-11
Final result: (-5.686057605985963e-06, 3.233125109859082e-11)
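Step 3 above says that a step that is too small is slow and a step that is too large overshoots. The small sketch below makes this concrete for f(x) = x², using the same start x = 2 and the same stopping rule as the loop above. The helper name `run` and the list of step sizes are hypothetical.

Each update multiplies x by (1 − 2·step). The iterates therefore:
- shrink smoothly for step = 0.1,
- flip sign on every update for step = 0.8, as in the output above,
- grow without bound for step > 1.

```python
def run(step, x0=2.0, tol=1e-10, max_iter=1000):
    """Return (number of updates, final x), or (None, x) if tol is never reached."""
    x = x0
    f_current = x ** 2
    for i in range(1, max_iter + 1):
        x = x - step * (2 * x)            # gradient of x^2 is 2x
        f_new = x ** 2
        if abs(f_current - f_new) <= tol: # same stopping rule as the loop above
            return i, x
        f_current = f_new
    return None, x

for step in [0.01, 0.1, 0.8, 1.1]:
    n, x_end = run(step)
    status = f"{n} updates" if n is not None else "did not converge"
    print(f"step={step}: {status}, final x={x_end:.3g}")
```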

Plotting:

fig = plt.figure()
X2 = np.arange(-2.1,2.15,0.05)
Y2 = X2 ** 2
print('X=',X)
print('Y=',Y)
plt.plot(X2,Y2,'-',color = '#666666',linewidth = 2)
plt.plot(X,Y,'bo--')
plt.title('Minimum of $y=x^2$: x=%f, y=%f' % (x, f_current))
plt.show()

X= [2, -1.2000000000000002, 0.7200000000000002, -0.43200000000000016, 0.2592000000000001, -0.1555200000000001, 0.09331200000000006, -0.05598720000000004, 0.03359232000000004, -0.020155392000000022, 0.012093235200000017, -0.007255941120000012, 0.0043535646720000085, -0.0026121388032000056, 0.0015672832819200039, -0.0009403699691520025, 0.0005642219814912016, -0.00033853318889472107, 0.00020311991333683268, -0.00012187194800209963, 7.312316880125978e-05, -4.387390128075587e-05, 2.6324340768453522e-05, -1.5794604461072116e-05, 9.476762676643271e-06, -5.686057605985963e-06]
Y= [4, 1.4400000000000004, 0.5184000000000003, 0.18662400000000015, 0.06718464000000005, 0.02418647040000003, 0.008707129344000012, 0.0031345665638400047, 0.0011284439629824024, 0.0004062398266736649, 0.00014624633760251945, 5.2648681536907024e-05, 1.8953525353286542e-05, 6.823269127183157e-06, 2.4563768857859384e-06, 8.842956788829382e-07, 3.183464443978579e-07, 1.146047199832289e-07, 4.1257699193962417e-08, 1.4852771709826476e-08, 5.3469978155375315e-09, 1.9249192135935116e-09, 6.929709168936641e-10, 2.494695300817192e-10, 8.980903082941893e-11, 3.233125109859082e-11]

[Figure 4: the curve y = x² with the gradient-descent iterates]

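Example 2: use gradient descent to find the minimum of z = f(x, y) = x² + 2y²
Steps:
1. To find the x and y at which f(x, y) is smallest, first take the partial derivatives with respect to x and y: ∂z/∂x = 2x, ∂z/∂y = 4y.
2. Choose initial values of x and y and keep updating them. When f(x, y) stops decreasing, look at the values x and y have reached.
3. Choose a step size step, which controls how far x and y move in each update. If step is too small, the computation takes a long time. If step is too large, the updates can overshoot the minimum.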

## Original function
def f(x, y):
    return x ** 2 + 2 * y ** 2

## Partial derivatives: dz/dx = 2x, dz/dy = 4y
def hx(t):
    return 2 * t

def hy(t):
    return 4 * t

X = []
Y = []
Z = []

x = 2
y = 2

f_change = f(x, y)
f_current = f(x, y)
step = 0.1

X.append(x)
Y.append(y)
Z.append(f_current)

while f_change > 1e-10:
    # move x and y against their partial derivatives
    x = x - step * hx(x)
    y = y - step * hy(y)
    tmp = f(x, y)
    f_change = np.abs(f_current - tmp)   # abs() so an increase is not mistaken for convergence
    f_current = tmp
    X.append(x)
    Y.append(y)
    Z.append(f_current)
    print('x,y:', (x, y))
    print('f_change:', f_change, 'f_current=', f_current)

print('Final result:', (x, y))

x,y: (1.6, 1.2)
f_change: 6.56 f_current= 5.44
x,y: (1.28, 0.72)
f_change: 2.7648 f_current= 2.6752000000000002
x,y: (1.024, 0.432)
f_change: 1.2533760000000003 f_current= 1.421824
x,y: (0.8192, 0.2592)
f_change: 0.6163660799999999 f_current= 0.80545792
x,y: (0.65536, 0.15552)
f_change: 0.3275882496 f_current= 0.47786967040000006
x,y: (0.5242880000000001, 0.09331199999999999)
f_change: 0.18557750476799995 f_current= 0.2922921656320001
x,y: (0.4194304000000001, 0.055987199999999994)
f_change: 0.11010117206016004 f_current= 0.18219099357184007
x,y: (0.33554432000000006, 0.033592319999999995)
f_change: 0.06734411496161283 f_current= 0.11484687861022724
x,y: (0.26843545600000007, 0.020155391999999994)
f_change: 0.04197680491895195 f_current= 0.0728700736912753
x,y: (0.21474836480000006, 0.012093235199999997)
f_change: 0.026460720831796354 f_current= 0.04640935285947894
x,y: (0.17179869184000005, 0.007255941119999998)
f_change: 0.016789264978469828 f_current= 0.029620087881009113
x,y: (0.13743895347200002, 0.004353564671999998)
f_change: 0.010692714898823952 f_current= 0.01892737298218516
x,y: (0.10995116277760002, 0.0026121388031999987)
f_change: 0.0068244682477845 f_current= 0.012102904734400661
x,y: (0.08796093022208001, 0.0015672832819199991)
f_change: 0.00436086673509546 f_current= 0.007742037999305201
x,y: (0.070368744177664, 0.0009403699691519995)
f_change: 0.002788509250805914 f_current= 0.004953528748499287
x,y: (0.056294995342131206, 0.0005642219814911997)
f_change: 0.0017837655550399173 f_current= 0.00316976319345937
x,y: (0.04503599627370496, 0.00033853318889471976)
f_change: 0.0011412930236542364 f_current= 0.0020284701698051336
x,y: (0.03602879701896397, 0.00020311991333683184)
f_change: 0.0007303134397730385 f_current= 0.001298156730032095
x,y: (0.028823037615171177, 0.0001218719480020991)
f_change: 0.00046735952712310285 f_current= 0.0008307972029089922
x,y: (0.02305843009213694, 7.312316880125945e-05)
f_change: 0.0002990953105993947 f_current= 0.0005317018923095975
x,y: (0.018446744073709553, 4.387390128075567e-05)
f_change: 0.0001914156755502318 f_current= 0.0003402862167593657
x,y: (0.014757395258967642, 2.63243407684534e-05)
f_change: 0.00012250411598813127 f_current= 0.00021778210077123443
x,y: (0.011805916207174114, 1.579460446107204e-05)
f_change: 7.840194434135785e-05 f_current= 0.00013938015642987658
x,y: (0.009444732965739291, 9.476762676643223e-06)
f_change: 5.01769960176924e-05 f_current= 8.920316041218418e-05
x,y: (0.0075557863725914335, 5.686057605985934e-06)
f_change: 3.2113188041443564e-05 f_current= 5.7089972370740616e-05
x,y: (0.006044629098073147, 3.4116345635915604e-06)
f_change: 2.055240815896724e-05 f_current= 3.653756421177338e-05
x,y: (0.004835703278458518, 2.046980738154936e-06)
f_change: 1.3153529634218635e-05 f_current= 2.3384034577554743e-05
x,y: (0.003868562622766814, 1.2281884428929618e-06)
f_change: 8.418254794392588e-06 f_current= 1.4965779783162154e-05
x,y: (0.0030948500982134514, 7.36913065735777e-07)
f_change: 5.387681566668612e-06 f_current= 9.578098216493543e-06
x,y: (0.002475880078570761, 4.421478394414662e-07)
f_change: 3.4481156620405597e-06 f_current= 6.129982554452983e-06
x,y: (0.001980704062856609, 2.652887036648797e-07)
f_change: 2.2067938290801117e-06 f_current= 3.923188725372871e-06
x,y: (0.0015845632502852873, 1.591732221989278e-07)
f_change: 1.4123479805459678e-06 f_current= 2.5108407448269034e-06
x,y: (0.00126765060022823, 9.550393331935669e-08)
f_change: 9.039026823259094e-07 f_current= 1.606938062500994e-06
x,y: (0.0010141204801825839, 5.730235999161401e-08)
f_change: 5.784977076081185e-07 f_current= 1.0284403548928755e-06
x,y: (0.0008112963841460671, 3.438141599496841e-08)
f_change: 3.70238529600229e-07 f_current= 6.582018252926464e-07
x,y: (0.0006490371073168537, 2.0628849596981046e-08)
f_change: 2.3695265776731848e-07 f_current= 4.2124916752532797e-07
x,y: (0.000519229685853483, 1.2377309758188627e-08)
f_change: 1.5164970054742573e-07 f_current= 2.6959946697790224e-07
x,y: (0.0004153837486827864, 7.426385854913176e-09)
f_change: 9.705580819783557e-08 f_current= 1.7254365878006667e-07
x,y: (0.0003323069989462291, 4.455831512947906e-09)
f_change: 6.211571719170869e-08 f_current= 1.1042794158835798e-07
x,y: (0.00026584559915698326, 2.6734989077687434e-09)
f_change: 3.9754058982927365e-08 f_current= 7.067388260543061e-08
x,y: (0.00021267647932558662, 1.604099344661246e-09)
f_change: 2.544259774195767e-08 f_current= 4.5231284863472945e-08
x,y: (0.0001701411834604693, 9.624596067967475e-10)
f_change: 1.628326255229122e-08 f_current= 2.8948022311181726e-08
x,y: (0.00013611294676837542, 5.774757640780485e-10)
f_change: 1.0421288032544167e-08 f_current= 1.852673427863756e-08
x,y: (0.00010889035741470034, 3.4648545844682907e-10)
f_change: 6.6696243404962685e-09 f_current= 1.185710993814129e-08
x,y: (8.711228593176027e-05, 2.0789127506809745e-10)
f_change: 4.268559577798094e-09 f_current= 7.588550360343196e-09
x,y: (6.968982874540821e-05, 1.2473476504085845e-10)
f_change: 2.7318781297477534e-09 f_current= 4.856672230595443e-09
x,y: (5.575186299632657e-05, 7.484085902451507e-11)
f_change: 1.7484020030230723e-09 f_current= 3.1082702275723705e-09
x,y: (4.460149039706126e-05, 4.490451541470904e-11)
f_change: 1.11897728192919e-09 f_current= 1.9892929456431805e-09
x,y: (3.568119231764901e-05, 2.694270924882542e-11)
f_change: 7.161454604326741e-10 f_current= 1.2731474852105065e-09
x,y: (2.8544953854119208e-05, 1.6165625549295253e-11)
f_change: 4.583330946761887e-10 f_current= 8.148143905343177e-10
x,y: (2.2835963083295365e-05, 9.699375329577152e-12)
f_change: 2.9333318059250085e-10 f_current= 5.214812099418169e-10
x,y: (1.8268770466636293e-05, 5.8196251977462915e-12)
f_change: 1.8773323557910674e-10 f_current= 3.3374797436271016e-10
x,y: (1.4615016373309035e-05, 3.4917751186477746e-12)
f_change: 1.201492707705946e-10 f_current= 2.1359870359211555e-10
x,y: (1.1692013098647227e-05, 2.0950650711886646e-12)
f_change: 7.689553329316844e-11 f_current= 1.3670317029894712e-10
Final result: (1.1692013098647227e-05, 2.0950650711886646e-12)
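The same computation can be written with a NumPy vector θ = (x, y) and a gradient vector, which is how gradient descent is usually implemented for many parameters. This is a sketch under that assumption, with hypothetical names `f`, `grad` and `n_updates`.

It also explains why y ends up far closer to 0 than x. Each update multiplies x by 1 − 2·step = 0.8 but y by 1 − 4·step = 0.6, so y shrinks much faster.

```python
import numpy as np

def f(theta):
    x, y = theta
    return x ** 2 + 2 * y ** 2

def grad(theta):
    x, y = theta
    return np.array([2 * x, 4 * y])      # (dz/dx, dz/dy)

theta = np.array([2.0, 2.0])
step = 0.1
f_current = f(theta)
n_updates = 0
while True:
    theta = theta - step * grad(theta)   # update x and y together
    n_updates += 1
    f_new = f(theta)
    if abs(f_current - f_new) <= 1e-10:  # same stopping rule as the loop above
        break
    f_current = f_new

print(n_updates, theta)
```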

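Plotting:
Import the 3D plotting toolkit and show the figure in a GUI window, so that the 3D surface can be rotated and viewed from different angles.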

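from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions
# tk: show the figure in a separate GUI window, but the window may close immediately
# workaround: plt.ion(); plt.pause(10); plt.close()
%matplotlib tk
fig = plt.figure()
# Axes3D(fig) no longer adds itself to the figure in recent Matplotlib; create the 3D axes explicitly
ax = fig.add_subplot(111, projection='3d')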
X2 = np.arange(-2,2,0.2)
Y2 = np.arange(-2,2,0.2)
X2,Y2=np.meshgrid(X2,Y2)
Z2 = X2**2 +2*Y2**2
ax.plot_surface(X2,Y2,Z2,rstride=1,cstride=1,cmap='rainbow')
ax.plot(X,Y,Z,'ro--')
plt.show()
print('X=',X)
print('Y=',Y)
print('Z=',Z)
[Figure 5: view angle 1]
[Figure 6: view angle 2]
[Figure 7: view angle 3]
[Figure 8: view angle 4]

X= [2, 1.6, 1.28, 1.024, 0.8192, 0.65536, 0.5242880000000001, 0.4194304000000001, 0.33554432000000006, 0.26843545600000007, 0.21474836480000006, 0.17179869184000005, 0.13743895347200002, 0.10995116277760002, 0.08796093022208001, 0.070368744177664, 0.056294995342131206, 0.04503599627370496, 0.03602879701896397, 0.028823037615171177, 0.02305843009213694, 0.018446744073709553, 0.014757395258967642, 0.011805916207174114, 0.009444732965739291, 0.0075557863725914335, 0.006044629098073147, 0.004835703278458518, 0.003868562622766814, 0.0030948500982134514, 0.002475880078570761, 0.001980704062856609, 0.0015845632502852873, 0.00126765060022823, 0.0010141204801825839, 0.0008112963841460671, 0.0006490371073168537, 0.000519229685853483, 0.0004153837486827864, 0.0003323069989462291, 0.00026584559915698326, 0.00021267647932558662, 0.0001701411834604693, 0.00013611294676837542, 0.00010889035741470034, 8.711228593176027e-05, 6.968982874540821e-05, 5.575186299632657e-05, 4.460149039706126e-05, 3.568119231764901e-05, 2.8544953854119208e-05, 2.2835963083295365e-05, 1.8268770466636293e-05, 1.4615016373309035e-05, 1.1692013098647227e-05]
Y= [2, 1.2, 0.72, 0.432, 0.2592, 0.15552, 0.09331199999999999, 0.055987199999999994, 0.033592319999999995, 0.020155391999999994, 0.012093235199999997, 0.007255941119999998, 0.004353564671999998, 0.0026121388031999987, 0.0015672832819199991, 0.0009403699691519995, 0.0005642219814911997, 0.00033853318889471976, 0.00020311991333683184, 0.0001218719480020991, 7.312316880125945e-05, 4.387390128075567e-05, 2.63243407684534e-05, 1.579460446107204e-05, 9.476762676643223e-06, 5.686057605985934e-06, 3.4116345635915604e-06, 2.046980738154936e-06, 1.2281884428929618e-06, 7.36913065735777e-07, 4.421478394414662e-07, 2.652887036648797e-07, 1.591732221989278e-07, 9.550393331935669e-08, 5.730235999161401e-08, 3.438141599496841e-08, 2.0628849596981046e-08, 1.2377309758188627e-08, 7.426385854913176e-09, 4.455831512947906e-09, 2.6734989077687434e-09, 1.604099344661246e-09, 9.624596067967475e-10, 5.774757640780485e-10, 3.4648545844682907e-10, 2.0789127506809745e-10, 1.2473476504085845e-10, 7.484085902451507e-11, 4.490451541470904e-11, 2.694270924882542e-11, 1.6165625549295253e-11, 9.699375329577152e-12, 5.8196251977462915e-12, 3.4917751186477746e-12, 2.0950650711886646e-12]
Z= [12, 5.44, 2.6752000000000002, 1.421824, 0.80545792, 0.47786967040000006, 0.2922921656320001, 0.18219099357184007, 0.11484687861022724, 0.0728700736912753, 0.04640935285947894, 0.029620087881009113, 0.01892737298218516, 0.012102904734400661, 0.007742037999305201, 0.004953528748499287, 0.00316976319345937, 0.0020284701698051336, 0.001298156730032095, 0.0008307972029089922, 0.0005317018923095975, 0.0003402862167593657, 0.00021778210077123443, 0.00013938015642987658, 8.920316041218418e-05, 5.7089972370740616e-05, 3.653756421177338e-05, 2.3384034577554743e-05, 1.4965779783162154e-05, 9.578098216493543e-06, 6.129982554452983e-06, 3.923188725372871e-06, 2.5108407448269034e-06, 1.606938062500994e-06, 1.0284403548928755e-06, 6.582018252926464e-07, 4.2124916752532797e-07, 2.6959946697790224e-07, 1.7254365878006667e-07, 1.1042794158835798e-07, 7.067388260543061e-08, 4.5231284863472945e-08, 2.8948022311181726e-08, 1.852673427863756e-08, 1.185710993814129e-08, 7.588550360343196e-09, 4.856672230595443e-09, 3.1082702275723705e-09, 1.9892929456431805e-09, 1.2731474852105065e-09, 8.148143905343177e-10, 5.214812099418169e-10, 3.3374797436271016e-10, 2.1359870359211555e-10, 1.3670317029894712e-10]
