机器学习入门~正规方程Normal equation

正规方程Normal equation

  • 对于某些机器学习问题,正规方程会给我们更好的方法来求解假设函数中的θ参数的最优值。
    机器学习入门~正规方程Normal equation_第1张图片
  • 梯度下降给出了一种通过不断迭代的方式,通过代价函数寻找θ最优值的解法。
  • 而正规方程给出了求解θ的解析解法,即不必运行迭代函数,而是直接一次性求解θ的最优值。
    解决方法:
    机器学习入门~正规方程Normal equation_第2张图片
    代价函数求导,并令导数为零,求解出使得导数为零的参数θx
    但实际问题中,参数θ是一个n+1维的向量,也就是θ0到θm的函数。
    机器学习入门~正规方程Normal equation_第3张图片
    解决的方法是对每一个参数θ求偏导,然后把它们全部置零,解出θ0到θn
    例:
    假设有一个m = 4的训练样本。
    机器学习入门~正规方程Normal equation_第4张图片
    为了实现正规方程的解法,加上一个特征量x0 = 1,并将所有特征量写入一个矩阵当中,如下所示。对训练样本对应的结果y也构建相同的矩阵。
    机器学习入门~正规方程Normal equation_第5张图片
    机器学习入门~正规方程Normal equation_第6张图片
    所以,X (X也被称为设计矩阵Design Matrix) 是一个m * (n+1)维的矩阵,y是一个m维向量,其中m是训练样本的数量,n是特征变量数,其实是 n+1 (算上恒等于一的x0)。
    之后,只需执行这一步运算,就可以得到使得代价函数最小化的θ。(θ是由θ0 ~ θn+1 组成的向量)
    在这里插入图片描述
  • 如何构建矩阵X
    机器学习入门~正规方程Normal equation_第7张图片
    取训练集中的一个样本,将它的所有特征量写成一个向量,经转置按顺序加入到矩阵X当中。
    对于y,只需将所有训练样本的结果加入到向量中构成一个m维向量即可。(注:m表示训练样本的个数。
    机器学习入门~正规方程Normal equation_第8张图片
  • 通过Octave计算正规方程
    机器学习入门~正规方程Normal equation_第9张图片
    pinv():将传入的参数(一个矩阵)求逆。
  • 是否需要特征缩放
    如果使用正规方程法求解参数θ,则不需要进行特征缩放。
  • 两种方法的适用情况

机器学习入门~正规方程Normal equation_第10张图片
梯度下降法的缺点:
·梯度下降法需要选择学习率α。
·梯度下降需要多次迭代。
梯度下降法的优点:
·当n( 特征量 xi 的个数 )很大时,也能运行的很好。
正规方程的缺点:
·n很大时,运行的速度很慢(对于大多数算法,求解矩阵的逆时的时间复杂度时O(n3))。因此,当n很大时,倾向于使用梯度下降法 ( n = 10000时,可以开始考虑使用梯度下降法 ) 。
·正规方程并不适用于更复杂的算法,届时仍需使用梯度下降法。

矩阵方程及不可逆性

**问题:**倘若在使用正规方程法时,矩阵XT * X是不可逆的,该如何解决? ( 即这个矩阵的奇异/退化矩阵 )
机器学习入门~正规方程Normal equation_第11张图片
尽管矩阵可能会不可逆,但是当在Octave中运行正规方程的公式时,大概率会得到正确的结果。
机器学习入门~正规方程Normal equation_第12张图片
如果发现XT * X时奇异矩阵:
①看一下特征量中是否有冗余的特征。(如平方英尺和平方米是相同意义的特征,保留一个即可)
②检查特征量是否过多,过多的情况下,如果少一些不影响,可以删除一些特征,或使用正规化的方法。

Python 实现正规方程

①首先需要引入numpy库,以方便进行矩阵运算。对于IDLE使用者,需要在windows的cmd指令窗口使用pip install numpy,而对于Pycharm使用者,需要按顺序操作:File->Settings->Project:XXX(你存储Pycharm编译文件的文件夹)->Python Interpretor->"+"->输入numpy导入,耐心等待安装(通常很快,不会超过二十秒)。
②之后引入numpy,即import numpy as np
求矩阵的逆与转置:numpy.linalg.inv(a)numpy.linalg.pinv(a),inv()为求逆,适用于传入的矩阵不是奇异矩阵的情况,pinv()为求伪逆,适用于任何情况。inv()的速度必pinv()快。求转置以矩阵X为例,进行操作X.T即可。
矩阵的点乘: Python中用 “@” 实现点乘。
⑤完整代码:
注:参考自csdn博主 koala_cola,原博文地址。

 def normalEqn(X, y):
   theta = np.linalg.pinv(X.T@X)@X.T@y 
   return theta

你可能感兴趣的:(机器学习,机器学习,python)