Linear Regression: Principles and Practice (Newton's Method)

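Classification algorithms were covered in the previous posts; this post turns to regression algorithms.

Regression and classification are both supervised learning. They differ in the labels: in classification the labels are discrete values that stand for different classes, whereas in regression the labels are continuous values, and the algorithm has to learn a mapping from the sample features to those continuous labels. Linear regression is an important class of regression problems in which the target value depends linearly on the features.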

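I. Linear Regression

Basic form: given an example x=(x_{1};x_{2};...;x_{d}) described by d attributes, a linear model tries to learn a function that makes predictions through a linear combination of the attributes:

    f(x)=w_{1}x_{1}+w_{2}x_{2}+...+w_{d}x_{d}+b

In vector form:

    f(x)=w^{T}x+b

where w=(w_{1};w_{2};...;w_{d}). Once w and b have been learned, the model is determined.

Linear regression tries to learn a linear model that predicts the real-valued output label as accurately as possible.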

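1. With a single input attribute, linear regression tries to learn:

    f(x_{i})=wx_{i}+b , such that f(x_{i})\simeq y_{i}

    How are w and b determined? The key is to measure the difference between f(x) and y. Mean squared error is the most commonly used performance measure for regression tasks, so w and b can be chosen to minimize it. Solving a model by minimizing the mean squared error is called the "least squares method". In linear regression, least squares finds the line for which the sum of squared distances from the samples to the line, measured along the y-axis, is smallest. That is, with m samples:

    (w^{*},b^{*})=\underset{(w,b)}{\arg\min}\sum_{i=1}^{m}(f(x_{i})-y_{i})^{2}=\underset{(w,b)}{\arg\min}\sum_{i=1}^{m}(y_{i}-wx_{i}-b)^{2}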
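
For the single-attribute case, setting the derivatives of the squared error with respect to w and b to zero gives a closed-form solution. The following is a minimal sketch of that solution, not code from the original program; the function name `fit_line` and the demo points are made up for illustration.

```python
import numpy as np

def fit_line(x, y):
    """Closed-form least squares for f(x) = w * x + b with one input attribute."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x_mean, y_mean = x.mean(), y.mean()
    # w = sum((x_i - x_mean) * (y_i - y_mean)) / sum((x_i - x_mean)^2)
    w = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
    # b follows from requiring the line to pass through the mean point
    b = y_mean - w * x_mean
    return w, b

# Made-up points that lie exactly on y = 2x + 1
x_demo = [0.0, 1.0, 2.0, 3.0]
y_demo = [1.0, 3.0, 5.0, 7.0]
w_demo, b_demo = fit_line(x_demo, y_demo)
print(w_demo, b_demo)
```

Because the demo points are noise-free, the fitted slope and intercept recover the line they were generated from.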

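2. With d input attributes, linear regression tries to learn:

    f(x_{i})=w^{T}x_{i}+b , such that f(x_{i})\simeq y_{i}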
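
With several attributes, the bias b is usually absorbed into the weight vector by adding a constant feature equal to 1. A hedged sketch of this using NumPy's `np.linalg.lstsq` is shown here; `fit_linear` and the demo data are hypothetical names and values chosen only for illustration.

```python
import numpy as np

def fit_linear(X, y):
    """Least-squares fit of f(x) = w^T x + b for samples stored as rows of X."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    # Prepend a column of ones so that b is learned as the first coefficient.
    X_aug = np.hstack([np.ones((X.shape[0], 1)), X])
    # lstsq minimizes ||X_aug @ theta - y||^2 without forming an explicit inverse.
    theta, *_ = np.linalg.lstsq(X_aug, y, rcond=None)
    return theta[1:], theta[0]  # (w, b)

# Made-up noise-free data generated from w = (3, -2), b = 0.5
X_demo = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]])
y_demo = 3.0 * X_demo[:, 0] - 2.0 * X_demo[:, 1] + 0.5
w_demo, b_demo = fit_linear(X_demo, y_demo)
print(w_demo, b_demo)
```

Using a least-squares solver instead of an explicit matrix inverse is a design choice of this sketch; it tends to behave better when the attributes are nearly collinear.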

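II. Newton's Method

Besides gradient descent, Newton's method is another optimization algorithm widely used in machine learning. Its basic idea is to use the first derivative (gradient) and the second derivative (Hessian matrix) at the current iterate x_{k} to approximate the objective function by a quadratic function, take the minimizer of that quadratic model as the next iterate, and repeat the process until an approximate minimizer of the required accuracy is found. The main strength of Newton's method is its fast convergence: it converges quadratically in a neighborhood of the minimizer. It comes in two variants, the basic Newton's method and the global Newton's method.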

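1. Basic Newton's method

The basic Newton's method is a derivative-based algorithm. When the second derivative is positive (in higher dimensions, when the Hessian is positive definite), each iteration moves in a direction along which the function value decreases at the current point. In the one-dimensional case, finding an extremum of the objective f(x) reduces to solving f'(x)=0. Expanding f(x) in a Taylor series around x_{k} up to second order gives:

    f(x)\approx f(x_{k})+{f}'(x_{k})(x-x_{k})+\frac{1}{2}{f}''(x_{k})(x-x_{k})^{2}

Differentiating the right-hand side with respect to x and setting the result to 0:

    {f}'(x_{k})+{f}''(x_{k})(x-x_{k})=0

Solving for x gives the next iterate:

    x_{k+1}=x_{k}-\frac{{f}'(x_{k})}{{f}''(x_{k})}

This is the update formula of Newton's method.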
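
To make the one-dimensional update concrete, here is a small sketch that applies x_{k+1} = x_{k} - f'(x_{k}) / f''(x_{k}) repeatedly. The helper `newton_1d`, its tolerance, and the test function f(x) = x^4/4 - x are assumptions chosen for illustration, not part of the original program.

```python
def newton_1d(df, d2f, x0, tol=1e-10, max_iter=50):
    """Iterate x_{k+1} = x_k - f'(x_k) / f''(x_k) until the step is below tol."""
    x = x0
    for _ in range(max_iter):
        step = df(x) / d2f(x)  # assumes f''(x) != 0 along the way
        x = x - step
        if abs(step) < tol:
            break
    return x

# Illustrative objective f(x) = x**4 / 4 - x, so f'(x) = x**3 - 1 and f''(x) = 3 * x**2
x_min = newton_1d(lambda x: x**3 - 1.0, lambda x: 3.0 * x**2, x0=2.0)
print(x_min)
```

Starting from x0 = 2, the iterates approach the minimizer x = 1, where f'(x) = 0.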

Procedure of the basic Newton's method:

[Image 2: procedure of the basic Newton's method]
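
In several dimensions the derivative becomes the gradient g and the second derivative becomes the Hessian G, and the Newton direction d solves G d = -g. The sketch below follows the usual textbook steps (check the gradient norm, solve for d, update); it is an assumption about what the procedure looks like, and `basic_newton`, `eps`, and the example function are hypothetical.

```python
import numpy as np

def basic_newton(grad, hess, x0, eps=1e-8, max_iter=100):
    """Basic Newton's method with a full step x_{k+1} = x_k + d_k."""
    x = np.asarray(x0, dtype=float)
    for _ in range(max_iter):
        g = grad(x)
        if np.linalg.norm(g) <= eps:  # stop once the gradient is small enough
            break
        d = np.linalg.solve(hess(x), -g)  # solve G d = -g for the Newton direction
        x = x + d
    return x

# Illustrative objective f(x) = exp(x0) - x0 + (x1 - 2)^2, minimized at (0, 2)
def grad_demo(x):
    return np.array([np.exp(x[0]) - 1.0, 2.0 * (x[1] - 2.0)])

def hess_demo(x):
    return np.array([[np.exp(x[0]), 0.0], [0.0, 2.0]])

x_star = basic_newton(grad_demo, hess_demo, [1.0, 0.0])
print(x_star)
```

Solving the linear system with `np.linalg.solve` avoids computing the inverse Hessian explicitly, which is a common choice but not the only one.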

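2. Global Newton's method

The basic Newton's method needs a starting point that is close enough to the minimizer; otherwise the algorithm may fail to converge. The global Newton's method was introduced to address this.

Procedure of the global Newton's method:

[Image 3: procedure of the global Newton's method]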

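3. Armijo search

The global Newton's method chooses its step length with an Armijo search, that is, the step must satisfy the Armijo criterion:

Given \beta \in (0,1) and \sigma \in (0,0.5), let the step size be \alpha _{k}=\beta ^{m_{k}} , where m_{k} is the smallest nonnegative integer m that satisfies the following inequality:

    f(x_{k}+\beta ^{m}d_{k})\leq f(x_{k})+\sigma \beta ^{m}g_{k}^{T}d_{k}

Here g_{k} is the gradient at x_{k} and d_{k} is the search direction. In the program in Section III, the parameter `sigma` plays the role of \beta and `delta` plays the role of \sigma.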
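
The Armijo condition f(x_k + β^m d_k) ≤ f(x_k) + σ β^m g_k^T d_k can be sketched as a short backtracking loop and combined with the Newton direction. The following is an illustrative sketch under stated assumptions: the names `armijo_step` and `global_newton`, the defaults β = 0.5 and σ = 0.25, the cap `max_m`, and the example function are all hypothetical.

```python
import numpy as np

def armijo_step(f, x, d, g, beta=0.5, sigma=0.25, max_m=50):
    """Return alpha = beta**m for the smallest m >= 0 meeting the Armijo condition."""
    fx = f(x)
    slope = float(g @ d)  # g_k^T d_k, negative for a descent direction
    for m in range(max_m):
        alpha = beta ** m
        if f(x + alpha * d) <= fx + sigma * alpha * slope:
            return alpha
    return beta ** max_m  # safety cap, an assumption of this sketch

def global_newton(f, grad, hess, x0, eps=1e-8, max_iter=100):
    """Newton direction combined with an Armijo step length."""
    x = np.asarray(x0, dtype=float)
    for _ in range(max_iter):
        g = grad(x)
        if np.linalg.norm(g) <= eps:
            break
        d = np.linalg.solve(hess(x), -g)
        x = x + armijo_step(f, x, d, g) * d
    return x

# f(x) = sqrt(1 + x^2): the full Newton step maps x to -x^3,
# so the basic method diverges from any start with |x| > 1.
def f_demo(x):
    return float(np.sqrt(1.0 + x[0] ** 2))

def grad_demo(x):
    return np.array([x[0] / np.sqrt(1.0 + x[0] ** 2)])

def hess_demo(x):
    return np.array([[(1.0 + x[0] ** 2) ** -1.5]])

x_star = global_newton(f_demo, grad_demo, hess_demo, [2.0])
print(x_star)
```

From the start x = 2 the first full step would jump to x = -8; the backtracking loop shortens it until the function value decreases enough, after which the iterates approach the minimizer x = 0.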

III. Prediction with Linear Regression

# -*- coding: utf-8 -*-
"""
Created on Thu Mar 21 20:43:26 2019

@author: 2018061801
"""

import numpy as np
from math import pow

def load_data(file_path):
    '''Load the data set.
    input:  file_path(string): path to the training data
                               (tab-separated, label in the last column)
    output: feature(mat): features, with a leading 1 in every row
            label(mat): labels
    '''
    feature = []
    label = []
    with open(file_path) as f:  # the file is closed even if parsing fails
        for line in f:
            lines = line.strip().split("\t")
            feature_tmp = [1.0]  # x0 = 1, so that w0 acts as the bias b
            for i in range(len(lines) - 1):
                feature_tmp.append(float(lines[i]))
            feature.append(feature_tmp)
            label.append(float(lines[-1]))
    return np.mat(feature), np.mat(label).T

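def least_square(feature, label):
    '''Least squares (normal equation).
    input:  feature(mat): features
            label(mat): labels
    output: w(mat): regression coefficients
    '''
    # w = (X^T X)^{-1} X^T y; assumes X^T X is invertible
    w = (feature.T * feature).I * feature.T * label
    return w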

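def first_derivativ(feature, label, w):
    '''Compute the first derivative (gradient) of the loss.
    input:  feature(mat): features
            label(mat): labels
            w(mat): current regression coefficients
    output: g(mat): gradient, equal to -X^T (y - Xw)
    '''
    m, n = np.shape(feature)
    g = np.mat(np.zeros((n, 1)))
    for i in range(m):
        err = label[i, 0] - feature[i, ] * w  # residual of sample i
        for j in range(n):
            g[j, ] -= err * feature[i, j]
    return g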

def second_derivative(feature):
    '''Compute the second derivative (Hessian) of the loss.
    input:  feature(mat): features
    output: G(mat): Hessian, equal to X^T X
    '''
    m, n = np.shape(feature)
    G = np.mat(np.zeros((n, n)))
    for i in range(m):
        x_left = feature[i, ].T
        x_right = feature[i, ]
        G += x_left * x_right  # accumulate the outer product x_i x_i^T
    return G

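def get_error(feature, label, w):
    '''Compute the loss.
    input:  feature(mat): features
            label(mat): labels
            w(mat): parameters of the linear regression model
    output: value of the loss function, (y - Xw)^T (y - Xw) / 2, as a 1x1 matrix
    '''
    return (label - feature * w).T * (label - feature * w) / 2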

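def get_min_m(feature, label, sigma, delta, d, w, g):
    '''Find the smallest m that satisfies the Armijo condition.
    input:  feature(mat): features
            label(mat): labels
            sigma(float), delta(float): parameters of the global Newton's method
                (sigma is the base of the step size, delta the Armijo constant)
            d(mat): Newton direction, i.e. minus the inverse Hessian times the gradient
            w(mat): current regression coefficients
            g(mat): gradient
    output: m(int): smallest m
    '''
    m = 0
    while True:
        w_new = w + pow(sigma, m) * d
        left = get_error(feature, label, w_new)
        right = get_error(feature, label, w) + delta * pow(sigma, m) * g.T * d
        if left <= right:
            break
        else:
            m += 1
    return m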

def newton(feature, label, iterMax, sigma, delta):
    '''Global Newton's method.
    input:  feature(mat): features
            label(mat): labels
            iterMax(int): maximum number of iterations
            sigma(float), delta(float): parameters of the Armijo search
    output: w(mat): regression coefficients
    '''
    n = np.shape(feature)[1]
    w = np.mat(np.zeros((n, 1)))
    it = 0
    while it <= iterMax:
        g = first_derivativ(feature, label, w)  # first derivative (gradient)
        G = second_derivative(feature)  # second derivative (Hessian)
        d = -G.I * g  # Newton direction
        m = get_min_m(feature, label, sigma, delta, d, w, g)  # smallest m
        w = w + pow(sigma, m) * d
        if it % 10 == 0:
            print ("\t---- itration: ", it, " , error: ", get_error(feature, label , w)[0, 0])
        it += 1
    return w

def save_model(file_name, w):
    '''Save the final model.
    input:  file_name(string): name of the file to write
            w(mat): trained linear regression model
    '''
    m, n = np.shape(w)
    with open(file_name, "w") as f_result:
        for i in range(m):
            w_tmp = []
            for j in range(n):
                w_tmp.append(str(w[i, j]))
            f_result.write("\t".join(w_tmp) + "\n")
    

if __name__ == "__main__":
    # 1. Load the data set
    print ("----------- 1.load data ----------")
    feature, label = load_data("D:/anaconda4.3/spyder_work/data2.txt")
    # 2.1 Solve with least squares
    print ("----------- 2.training ----------")
    print ("\t ---------- least_square ----------")
    w_ls = least_square(feature, label)
    # 2.2 Newton's method
    #print ("\t ---------- newton ----------")
    #w_newton = newton(feature, label, 50, 0.1, 0.5)
    # 3. Save the final result
    print ("----------- 3.save result ----------")
    save_model("weights", w_ls)
    

Output when using the global Newton's method:

----------- 1.load data ----------
----------- 2.training ----------
         ---------- newton ----------
        ---- itration:  0  , error:  12.346444091730936
        ---- itration:  10  , error:  0.07017065415130548
        ---- itration:  20  , error:  0.07017065415130548
        ---- itration:  30  , error:  0.07017065415130548
        ---- itration:  40  , error:  0.07017065415130548
        ---- itration:  50  , error:  0.07017065415130548
----------- 3.save result ----------

Output when using least squares:

----------- 1.load data ----------
----------- 2.training ----------
         ---------- least_square ----------
----------- 3.save result ----------

With both least squares and the global Newton's method, the linear regression model ends up with the same parameter values:

w0=0.0031049944337919275

w1=0.9945024703102509
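
That the two methods agree is expected: the loss (y - Xw)^T (y - Xw) / 2 is quadratic in w with constant Hessian X^T X, so a single full Newton step from any starting point lands on the normal-equation solution. The sketch below checks this on synthetic data; the random data, seed, and variable names are made up and are not the training set listed below.

```python
import numpy as np

# Synthetic data loosely resembling a noisy line y = x (an assumption for illustration)
rng = np.random.default_rng(0)
x = rng.uniform(0.0, 0.7, size=200)
y = x + rng.normal(0.0, 0.02, size=200)
X = np.column_stack([np.ones_like(x), x])  # leading 1 for the bias term

w0 = np.zeros(2)
g = -X.T @ (y - X @ w0)  # gradient at w0
G = X.T @ X              # Hessian, independent of w
w_newton = w0 + np.linalg.solve(G, -g)  # one full Newton step

w_ls = np.linalg.solve(X.T @ X, X.T @ y)  # normal-equation solution
print(w_newton, w_ls)
```

One side note, offered as a hedged observation: for a quadratic loss the full step satisfies the Armijo condition with equality when the constant is exactly 0.5, so with that setting rounding error can decide whether the full step or a shorter one is accepted.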

Training data (tab-separated; the first column is the feature, the second column is the label):

0.422285418967358	0.429005468089679
0.548811087562498	0.532492334219154
0.0239703698338769	0.0548126922054728
0.611366447087570	0.585248758640251
0.252719613340340	0.285278442421827
0.265421160090244	0.294045385350157
0.219072519567866	0.236511959989685
0.401861387493983	0.375684767186615
0.208115403157731	0.205472741349725
0.239844960876992	0.272723041971107
0.229903140619694	0.258764519520563
0.430828843674200	0.451484834079148
0.448723772214619	0.481495140915047
0.679206578570127	0.709728292693527
0.647192437871312	0.674913329155585
0.230590652199338	0.259896354413699
0.148270899244940	0.121926223582234
0.636281127313104	0.651125466408602
0.205732203664697	0.198436666800684
0.399610991227472	0.371592207288394
0.636183808150379	0.650870007189679
0.0472145428294716	0.0788533881645130
0.193143008904441	0.167631814179908
0.291582698587355	0.306284442203354
0.411417436503224	0.397395169553808
0.628991936629693	0.630338016215000
0.169284716438347	0.136231147351670
0.660361352093842	0.693483286168938
0.298684102512952	0.308889826735502
0.561017462906972	0.537733929537329
0.384320533934917	0.351006848533773
0.221496695817379	0.242290898136858
0.296223865312843	0.308007240154126
0.492684728153993	0.512377513273133
0.634998145179060	0.647698764185059
0.625657915955206	0.620381266360076
0.0891660562602302	0.0994224254674202
0.0742086102687974	0.0936692472873702
0.626865780781351	0.623967608736031
0.441957883339557	0.472361875114246
0.0157191849402231	0.0406701224083588
0.368674561676324	0.337927787300071
0.627217178624091	0.625017672954424
0.253846304247748	0.286173545973494
0.206928529957397	0.201944099043841
0.180152118526450	0.147499079435076
0.155349951472829	0.125977748839280
0.249693546239582	0.282734942340199
0.417614851667815	0.415091347480293
0.197917834658727	0.177822530291359
0.427647487969975	0.443772210784627
0.578657203647726	0.547737030187470
0.0591439520498619	0.0864158098819885
0.504071412195014	0.516865018298665
0.439522870310511	0.468503499870856
0.340869575847922	0.323646896848778
0.641780686860910	0.664296267194680
0.0525360513031822	0.0824850997758311
0.419126940809084	0.419622744902513
0.179854209370586	0.147136383531176
0.0121642783764464	0.0330953951133111
0.412227568135505	0.399575401149883
0.604841264477251	0.574236148799604
0.402583176912226	0.377072606458986
0.603523505984253	0.572312679990502
0.158762606652888	0.128170576482325
0.427182367147573	0.442574474390365
0.129992272434756	0.113636969780914
0.206035553827662	0.199320192463950
0.00663776019535191	0.0192663492512525
0.477238063323136	0.504856882892435
0.651289847573691	0.681638072890728
0.524660387409372	0.523952681654443
0.532747289113760	0.526671056162354
0.394888276902191	0.364036428720525
0.203313142052076	0.191574879978562
0.186561941455151	0.156262844191082
0.204876351181546	0.195969153772640
0.657841775455899	0.690546398866122
0.178007951093262	0.144965532795643
0.392104389907160	0.360145588903358
0.158809561899624	0.128201991205053
0.298746011795488	0.308911800576876
0.0969730534561077	0.102123708000668
0.246251203302465	0.279566081649905
0.569115068981751	0.541869808520661
0.585272752016451	0.552669159180588
0.630492187715448	0.634815081327384
0.393905947217334	0.362620080614826
0.589081572083947	0.555927065039807
0.483579152818206	0.508204150647193
0.0171145855736783	0.0433778070676218
0.100884656221289	0.103439716366314
0.179709022044687	0.146960905786852
0.187562784860364	0.157818728729414
0.126211524490225	0.112184315667860
0.180790786560970	0.148288782503365
0.200791269966831	0.184854236366562
0.651021486883853	0.681227789754590
0.228673106482575	0.256666270032256
0.208782263150292	0.201734022274195
0.210503575395043	0.228067485736071
0.235895301376771	0.245240124609668
0.425622581490282	0.429764168993690
0.237837661606853	0.244385096841278
0.616955289459391	0.594545102911351
0.0376932276156365	0.0952635834401339
0.203176291567000	0.198030612099198
0.447104292379001	0.461498302305307
0.655886611466805	0.698000466195644
0.394438273213089	0.385519634230031
0.402654548839040	0.348495792061340
0.348759401363818	0.345472506033022
0.511054635769362	0.523696740564166
0.00316332449308267	0.0225122980416280
0.314818282272595	0.323955952269789
0.529065739377704	0.543710285781855
0.0108301689715363	0.0467766821290459
0.407258439427563	0.361599256201253
0.190151926658210	0.140125899091758
0.406026174866894	0.386696192598053
0.0317051763169514	0.0940907024508717
0.189882277107736	0.150579718983628
0.209949766870384	0.183492303743452
0.386129592852040	0.338293568165873
0.273665516040774	0.325379231947505
0.636167150788095	0.628100827832926
0.416613293692747	0.403688605299293
0.494069542440626	0.514373228515084
0.194609675237531	0.170155775899871
0.611898884546681	0.599735753459407
0.153748283237577	0.104443291773365
0.242711804416565	0.297690243571762
0.463156223019373	0.503563329092201
0.0960441400637809	0.0765666023019307
0.620558726023471	0.601270482591644
0.353506511758655	0.326842067996216
0.377410664344231	0.330857148517901
0.178583437819129	0.169600380065018
0.632238411096283	0.646156417339950
0.239280256965016	0.277966249392183
0.451355840070360	0.508200010155423
0.0450423759253215	0.0795679785032915
0.293007592738525	0.313851527339232
0.403511713579458	0.362465555452099
0.0596389824166379	0.0572870188772311
0.0256322708904266	0.0465414970455044
0.555865901075772	0.525529923690178
0.426530196385284	0.470439004166721
0.439215848521695	0.470378673630777
0.179018740264693	0.160527847630221
0.625290218639049	0.640887522129697
0.418410724661046	0.426759508002257
0.634549153440050	0.642461422657716
0.206114431420305	0.207463564791586
0.208116012762849	0.212796289939528
0.0192654880536059	0.0769589775293750
0.160271632281092	0.135640013867957
0.229165208106023	0.270581471939234
0.0103522921134037	0.00284882679711455
0.639277385492713	0.633878717364796
0.221977964704073	0.244950875702800
0.187530818135169	0.160342122941687
0.557219810661823	0.537308123068441
0.641357769311235	0.642198856948541
0.281351013720136	0.288938540729609
0.418774620556121	0.390754314844030
0.0283657133768362	0.0501683412053051
0.0542000361133447	0.0637325840016054
0.243177527351406	0.248837488278470
0.484352451963199	0.517065364596454
0.130032441190971	0.0873818795740706
0.271763748539046	0.295343215432390
0.156834016045304	0.149082996569873
0.418718636951629	0.389778692023851
0.628958675610216	0.627415103272291
0.0249749905201242	0.0326378191743664
0.684659242241295	0.735115810087691
0.424882915843142	0.408021868456478
0.217925167868618	0.231433096036625
0.0547546652403919	0.0881287495056403
0.219256756230378	0.233997915839033
0.411463504258259	0.372008467362171
0.0245806793943106	0.0437564438317034
0.381858587724410	0.352263974674522
0.184685085405949	0.178254802473518
0.594648982458728	0.558051748155136
0.671119229155300	0.708464413032835
0.180395208041911	0.166433113457345
0.316757849993797	0.300024130486124
0.653185585082381	0.662985980247932
0.375015494409235	0.370888752538540
0.439380980609484	0.475148149367430
0.335822265332407	0.326355097208739
0.633184990422054	0.635226744623378
0.176732026339695	0.138950147503156
0.200052880581140	0.179655779303571
0.429693006890599	0.450906533983589
0.253621749753599	0.308404280155269
0.209841851399314	0.204648160393728

 

 

 


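References: Zhao Zhiyong, 《Python 机器学习算法》 (Python Machine Learning Algorithms), source of the program

Zhou Zhihua, 《机器学习》 (Machine Learning)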
 

                                                    

 

 

 

 

 

 

 
