树回归

原理:

  • 将数据集切分成很多份易建模的数据
  • 利用线性回归技术建模

优点

  • 可以对复杂和非线性的数据建模

缺点

  • 结果不易理解

适用数据类型

  • 数值型和标称型数据

选择最佳特征之后,数据划分方法:

  • ID3: 按最佳特征的所有可能取值来划分
  • CART:二元切分法
import numpy as np

定义加载数据函数: x,y都放在一个dataMat里

def loadDataSet(fileName):
    """Read a whitespace-delimited numeric text file into a list of rows.

    Every non-blank line is split on whitespace and each field is converted
    to ``float``. Features and the target value stay together in one row.

    Args:
        fileName: Path of the text file to read.

    Returns:
        A list of rows, each row being a list of floats.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a field cannot be parsed as a float.
    """
    dataMat = []
    # The context manager guarantees the handle is closed, even on a parse error.
    with open(fileName) as fr:
        for line in fr:
            curLine = line.strip().split()
            if not curLine:
                # Skip blank lines: an empty row would make the data ragged.
                continue
            # map() is lazy in Python 3, so materialise it into a list.
            dataMat.append(list(map(float, curLine)))
    return dataMat

定义二元切分法

# Binary split: rows whose feature is > value, and rows whose feature is <= value
def binSplitDataSet(dataSet, feature, value):
    """Partition the rows of ``dataSet`` on one feature column.

    Args:
        dataSet: 2-D data (an ``np.matrix`` in this notebook).
        feature: Index of the column to compare.
        value: Threshold for the comparison.

    Returns:
        A pair ``(mat0, mat1)``: ``mat0`` holds the rows where the feature is
        strictly greater than ``value``, ``mat1`` the rows where it is less
        than or equal to ``value``.
    """
    column = dataSet[:, feature]
    # np.nonzero returns one index array per axis; [0] keeps the row indices.
    greaterRows = np.nonzero(column > value)[0]
    otherRows = np.nonzero(column <= value)[0]
    return dataSet[greaterRows, :], dataSet[otherRows, :]

定义构建树的函数

# Leaf model: mean of the target variable
def regLeaf(dataSet):
    """Return the mean of the last column of ``dataSet`` (the target values)."""
    targets = dataSet[:, -1]
    return np.mean(targets)

# Error measure: total squared deviation of the target variable
def regErr(dataSet):
    """Return the variance of the last column multiplied by the row count.

    ``np.var`` is the mean squared deviation, so scaling it by the number of
    rows gives the summed squared deviation from the mean.
    """
    targets = dataSet[:, -1]
    rowCount = dataSet.shape[0]
    return np.var(targets) * rowCount
def createTree(dataSet, leafType = regLeaf, errType= regErr, ops=(1,4)):
    """Recursively build a binary tree over ``dataSet``.

    Args:
        dataSet: Data passed through to ``chooseBestSplit`` and
            ``binSplitDataSet``; the last column is the target.
        leafType: Callable producing the value stored at a leaf.
        errType: Callable measuring the error of a data subset.
        ops: Stopping parameters forwarded unchanged to ``chooseBestSplit``.

    Returns:
        The leaf value when ``chooseBestSplit`` reports no split. Otherwise a
        dict with keys ``'spInd'`` (split feature index), ``'spVal'`` (split
        threshold), ``'left'`` and ``'right'`` (subtrees or leaf values).
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # chooseBestSplit signals "make a leaf" with feat=None; index 0 is a valid
    # feature, so use an identity test rather than ==.
    if feat is None:
        return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    # 'left' receives rows with feature > val, 'right' the remaining rows.
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree
测试函数
# 4x4 identity as an np.matrix; np.asmatrix replaces the np.mat alias,
# which was removed in NumPy 2.0.
testMat = np.asmatrix(np.eye(4)) #matrix
# testMat = np.eye(4) #array
testMat
matrix([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])
# Split the identity matrix on column 1 at threshold 0:
# mat0 gets the rows with column 1 > 0, mat1 the remaining rows.
mat0,mat1 = binSplitDataSet(testMat,1,0)
print(mat0)
print(mat1)
[[0. 1. 0. 0.]]
[[1. 0. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]

回归树的切分函数

# #计算目标变量的均值
# def regLeaf(dataSet):
#     return np.mean(dataSet[:,-1])

# #计算目标变量的总方差
# def regErr(dataSet):
#     return np.var(dataSet[:,-1])*dataSet.shape[0]

# Choose the best feature/threshold pair for a binary split
def chooseBestSplit(dataSet, leafType = regLeaf, errType=regErr, ops=(1,4)):
    """Search every feature value for the split with the lowest total error.

    Column slices are flattened with ``.T.tolist()[0]``, which relies on the
    slices being 2-D as ``np.matrix`` produces them.

    Args:
        dataSet: Data whose last column is the target.
        leafType: Callable producing a leaf value from a data subset.
        errType: Callable measuring the error of a data subset.
        ops: ``(tolS, tolN)`` - the minimum error reduction required to split
            and the minimum number of rows allowed on each side.

    Returns:
        ``(None, leafValue)`` when no split should be made, otherwise
        ``(featureIndex, splitValue)``.
    """
    minGain, minSamples = ops[0], ops[1]

    # All targets identical: nothing to gain from splitting, emit a leaf.
    targets = dataSet[:,-1].T.tolist()[0]
    if len(set(targets)) == 1:
        return None, leafType(dataSet)

    featureCount = dataSet.shape[1] - 1  # the last column is the target
    baseErr = errType(dataSet)           # error before any split
    lowestErr = np.inf
    chosenFeat, chosenVal = 0, 0

    for col in range(featureCount):
        candidates = set(dataSet[:,col].T.tolist()[0])
        for threshold in candidates:
            above, below = binSplitDataSet(dataSet, col, threshold)
            # Only consider splits that leave enough rows on both sides.
            if above.shape[0] >= minSamples and below.shape[0] >= minSamples:
                splitErr = errType(above) + errType(below)
                if splitErr < lowestErr:
                    chosenFeat, chosenVal, lowestErr = col, threshold, splitErr

    # The error reduction is too small to justify a split.
    if (baseErr - lowestErr) < minGain:
        return None, leafType(dataSet)

    # Safeguard kept from the original: re-check the size of the chosen split.
    # The loop already skips undersized candidates, so this only matters when
    # the initial (0, 0) choice was never replaced.
    above, below = binSplitDataSet(dataSet, chosenFeat, chosenVal)
    if above.shape[0] < minSamples or below.shape[0] < minSamples:
        return None, leafType(dataSet)
    return chosenFeat, chosenVal
测试数据
# %matplotlib inline
import matplotlib.pyplot as plt
def data2show(data):
    """Scatter-plot the last column of ``data`` against the second-to-last.

    Args:
        data: Matrix whose column slices expose ``.A`` (as ``np.matrix`` does).
    """
    # .A turns the matrix column slices into plain ndarrays for plotting.
    horizontal = data[:,-2].A
    vertical = data[:,-1].A
    figure = plt.figure()
    axes = figure.add_subplot(111)
    plt.scatter(horizontal, vertical, s=5, label='raw data')
    axes.legend(loc='upper left')
    axes.set_xlabel('x1')
    axes.set_ylabel('x2')
    plt.show()
ex00.txt
# Fit a regression tree on ex00.txt with the default ops and plot the raw points.
myDat = loadDataSet('../../Reference Code/Ch09/ex00.txt')
# np.asmatrix replaces the np.mat alias, which was removed in NumPy 2.0.
myMat = np.asmatrix(myDat)
print(createTree(myMat))
data2show(myMat)
{'spInd': 0, 'spVal': 0.48813, 'left': 1.0180967672413792, 'right': -0.04465028571428572}
output_17_1.png
  • 两个叶节点
myMat
matrix([[ 3.609800e-02,  1.550960e-01],
        [ 9.933490e-01,  1.077553e+00],
        [ 5.308970e-01,  8.934620e-01],
        [ 7.123860e-01,  5.648580e-01],
        [ 3.435540e-01, -3.717000e-01],
        [ 9.801600e-02, -3.327600e-01],
        [ 6.911150e-01,  8.343910e-01],
        [ 9.135800e-02,  9.993500e-02],
        [ 7.270980e-01,  1.000567e+00],
        [ 9.519490e-01,  9.452550e-01],
        [ 7.685960e-01,  7.602190e-01],
        [ 5.413140e-01,  8.937480e-01],
        [ 1.463660e-01,  3.428300e-02],
        [ 6.731950e-01,  9.150770e-01],
        [ 1.835100e-01,  1.848430e-01],
        [ 3.395630e-01,  2.067830e-01],
        [ 5.179210e-01,  1.493586e+00],
        [ 7.037550e-01,  1.101678e+00],
        [ 8.307000e-03,  6.997600e-02],
        [ 2.439090e-01, -2.946700e-02],
        [ 3.069640e-01, -1.773210e-01],
        [ 3.649200e-02,  4.081550e-01],
        [ 2.955110e-01,  2.882000e-03],
        [ 8.375220e-01,  1.229373e+00],
        [ 2.020540e-01, -8.774400e-02],
        [ 9.193840e-01,  1.029889e+00],
        [ 3.772010e-01, -2.435500e-01],
        [ 8.148250e-01,  1.095206e+00],
        [ 6.112700e-01,  9.820360e-01],
        [ 7.224300e-02, -4.209830e-01],
        [ 4.102300e-01,  3.317220e-01],
        [ 8.690770e-01,  1.114825e+00],
        [ 6.205990e-01,  1.334421e+00],
        [ 1.011490e-01,  6.883400e-02],
        [ 8.208020e-01,  1.325907e+00],
        [ 5.200440e-01,  9.619830e-01],
        [ 4.881300e-01, -9.779100e-02],
        [ 8.198230e-01,  8.352640e-01],
        [ 9.750220e-01,  6.735790e-01],
        [ 9.531120e-01,  1.064690e+00],
        [ 4.759760e-01, -1.637070e-01],
        [ 2.731470e-01, -4.552190e-01],
        [ 8.045860e-01,  9.240330e-01],
        [ 7.479500e-02, -3.496920e-01],
        [ 6.253360e-01,  6.236960e-01],
        [ 6.562180e-01,  9.585060e-01],
        [ 8.340780e-01,  1.010580e+00],
        [ 7.819300e-01,  1.074488e+00],
        [ 9.849000e-03,  5.659400e-02],
        [ 3.022170e-01, -1.486500e-01],
        [ 6.782870e-01,  9.077270e-01],
        [ 1.805060e-01,  1.036760e-01],
        [ 1.936410e-01, -3.275890e-01],
        [ 3.434790e-01,  1.752640e-01],
        [ 1.458090e-01,  1.369790e-01],
        [ 9.967570e-01,  1.035533e+00],
        [ 5.902100e-01,  1.336661e+00],
        [ 2.380700e-01, -3.584590e-01],
        [ 5.613620e-01,  1.070529e+00],
        [ 3.775970e-01,  8.850500e-02],
        [ 9.914200e-02,  2.528000e-02],
        [ 5.395580e-01,  1.053846e+00],
        [ 7.902400e-01,  5.332140e-01],
        [ 2.422040e-01,  2.093590e-01],
        [ 1.523240e-01,  1.328580e-01],
        [ 2.526490e-01, -5.561300e-02],
        [ 8.959300e-01,  1.077275e+00],
        [ 1.333000e-01, -2.231430e-01],
        [ 5.597630e-01,  1.253151e+00],
        [ 6.436650e-01,  1.024241e+00],
        [ 8.772410e-01,  7.970050e-01],
        [ 6.137650e-01,  1.621091e+00],
        [ 6.457620e-01,  1.026886e+00],
        [ 6.513760e-01,  1.315384e+00],
        [ 6.977180e-01,  1.212434e+00],
        [ 7.425270e-01,  1.087056e+00],
        [ 9.010560e-01,  1.055900e+00],
        [ 3.623140e-01, -5.564640e-01],
        [ 9.482680e-01,  6.318620e-01],
        [ 2.340000e-04,  6.090300e-02],
        [ 7.500780e-01,  9.062910e-01],
        [ 3.254120e-01, -2.192450e-01],
        [ 7.268280e-01,  1.017112e+00],
        [ 3.480130e-01,  4.893900e-02],
        [ 4.581210e-01, -6.145600e-02],
        [ 2.807380e-01, -2.288800e-01],
        [ 5.677040e-01,  9.690580e-01],
        [ 7.509180e-01,  7.481040e-01],
        [ 5.758050e-01,  8.990900e-01],
        [ 5.079400e-01,  1.107265e+00],
        [ 7.176900e-02, -1.109460e-01],
        [ 5.535200e-01,  1.391273e+00],
        [ 4.011520e-01, -1.216400e-01],
        [ 4.066490e-01, -3.663170e-01],
        [ 6.521210e-01,  1.004346e+00],
        [ 3.478370e-01, -1.534050e-01],
        [ 8.193100e-02, -2.697560e-01],
        [ 8.216480e-01,  1.280895e+00],
        [ 4.801400e-02,  6.449600e-02],
        [ 1.309620e-01,  1.842410e-01],
        [ 7.734220e-01,  1.125943e+00],
        [ 7.896250e-01,  5.526140e-01],
        [ 9.699400e-02,  2.271670e-01],
        [ 6.257910e-01,  1.244731e+00],
        [ 5.895750e-01,  1.185812e+00],
        [ 3.231810e-01,  1.808110e-01],
        [ 8.224430e-01,  1.086648e+00],
        [ 3.603230e-01, -2.048300e-01],
        [ 9.501530e-01,  1.022906e+00],
        [ 5.275050e-01,  8.795600e-01],
        [ 8.600490e-01,  7.174900e-01],
        [ 7.044000e-03,  9.415000e-02],
        [ 4.383670e-01,  3.401400e-02],
        [ 5.745730e-01,  1.066130e+00],
        [ 5.366890e-01,  8.672840e-01],
        [ 7.821670e-01,  8.860490e-01],
        [ 9.898880e-01,  7.442070e-01],
        [ 7.614740e-01,  1.058262e+00],
        [ 9.854250e-01,  1.227946e+00],
        [ 1.325430e-01, -3.293720e-01],
        [ 3.469860e-01, -1.503890e-01],
        [ 7.687840e-01,  8.997050e-01],
        [ 8.489210e-01,  1.170959e+00],
        [ 4.492800e-01,  6.909800e-02],
        [ 6.617200e-02,  5.243900e-02],
        [ 8.137190e-01,  7.066010e-01],
        [ 6.619230e-01,  7.670400e-01],
        [ 5.294910e-01,  1.022206e+00],
        [ 8.464550e-01,  7.200300e-01],
        [ 4.486560e-01,  2.697400e-02],
        [ 7.950720e-01,  9.657210e-01],
        [ 1.181560e-01, -7.740900e-02],
        [ 8.424800e-02, -1.954700e-02],
        [ 8.458150e-01,  9.526170e-01],
        [ 5.769460e-01,  1.234129e+00],
        [ 7.720830e-01,  1.299018e+00],
        [ 6.966480e-01,  8.454230e-01],
        [ 5.950120e-01,  1.213435e+00],
        [ 6.486750e-01,  1.287407e+00],
        [ 8.970940e-01,  1.240209e+00],
        [ 5.529900e-01,  1.036158e+00],
        [ 3.329820e-01,  2.100840e-01],
        [ 6.561500e-02, -3.069700e-01],
        [ 2.786610e-01,  2.536280e-01],
        [ 7.731680e-01,  1.140917e+00],
        [ 2.036930e-01, -6.403600e-02],
        [ 3.556880e-01, -1.193990e-01],
        [ 9.888520e-01,  1.069062e+00],
        [ 5.187350e-01,  1.037179e+00],
        [ 5.145630e-01,  1.156648e+00],
        [ 9.764140e-01,  8.629110e-01],
        [ 9.190740e-01,  1.123413e+00],
        [ 6.977770e-01,  8.278050e-01],
        [ 9.280970e-01,  8.832250e-01],
        [ 9.002720e-01,  9.968710e-01],
        [ 3.441020e-01, -6.153900e-02],
        [ 1.480490e-01,  2.042980e-01],
        [ 1.300520e-01, -2.616700e-02],
        [ 3.020010e-01,  3.171350e-01],
        [ 3.371000e-01,  2.633200e-02],
        [ 3.149240e-01, -1.952000e-03],
        [ 2.696810e-01, -1.659710e-01],
        [ 1.960050e-01, -4.884700e-02],
        [ 1.290610e-01,  3.051070e-01],
        [ 9.367830e-01,  1.026258e+00],
        [ 3.055400e-01, -1.159910e-01],
        [ 6.839210e-01,  1.414382e+00],
        [ 6.223980e-01,  7.663300e-01],
        [ 9.025320e-01,  8.616010e-01],
        [ 7.125030e-01,  9.334900e-01],
        [ 5.900620e-01,  7.055310e-01],
        [ 7.231200e-01,  1.307248e+00],
        [ 1.882180e-01,  1.136850e-01],
        [ 6.436010e-01,  7.825520e-01],
        [ 5.202070e-01,  1.209557e+00],
        [ 2.331150e-01, -3.481470e-01],
        [ 4.656250e-01, -1.529400e-01],
        [ 8.845120e-01,  1.117833e+00],
        [ 6.632000e-01,  7.016340e-01],
        [ 2.688570e-01,  7.344700e-02],
        [ 7.292340e-01,  9.319560e-01],
        [ 4.296640e-01, -1.886590e-01],
        [ 7.371890e-01,  1.200781e+00],
        [ 3.785950e-01, -2.960940e-01],
        [ 9.301730e-01,  1.035645e+00],
        [ 7.743010e-01,  8.367630e-01],
        [ 2.739400e-01, -8.571300e-02],
        [ 8.244420e-01,  1.082153e+00],
        [ 6.260110e-01,  8.405440e-01],
        [ 6.793900e-01,  1.307217e+00],
        [ 5.782520e-01,  9.218850e-01],
        [ 7.855410e-01,  1.165296e+00],
        [ 5.974090e-01,  9.747700e-01],
        [ 1.408300e-02, -1.325250e-01],
        [ 6.638700e-01,  1.187129e+00],
        [ 5.523810e-01,  1.369630e+00],
        [ 6.838860e-01,  9.999850e-01],
        [ 2.103340e-01, -6.899000e-03],
        [ 6.045290e-01,  1.212685e+00],
        [ 2.507440e-01,  4.629700e-02]])
ex0.txt
# Fit a regression tree on ex0.txt with the default ops and plot the raw points.
myDat = loadDataSet('../../Reference Code/Ch09/ex0.txt')
# np.asmatrix replaces the np.mat alias, which was removed in NumPy 2.0.
myMat = np.asmatrix(myDat)
print(createTree(myMat))
data2show(myMat)
{'spInd': 1, 'spVal': 0.39435, 'left': {'spInd': 1, 'spVal': 0.582002, 'left': {'spInd': 1, 'spVal': 0.797583, 'left': 3.9871632, 'right': 2.9836209534883724}, 'right': 1.980035071428571}, 'right': {'spInd': 1, 'spVal': 0.197834, 'left': 1.0289583666666666, 'right': -0.023838155555555553}}
output_21_1.png
  • 五个叶节点,即五个预测值
  • 总方差越小,划分的数据集越集中。
myMat
matrix([[ 1.000000e+00,  4.091750e-01,  1.883180e+00],
        [ 1.000000e+00,  1.826030e-01,  6.390800e-02],
        [ 1.000000e+00,  6.636870e-01,  3.042257e+00],
        [ 1.000000e+00,  5.173950e-01,  2.305004e+00],
        [ 1.000000e+00,  1.364300e-02, -6.769800e-02],
        [ 1.000000e+00,  4.696430e-01,  1.662809e+00],
        [ 1.000000e+00,  7.254260e-01,  3.275749e+00],
        [ 1.000000e+00,  3.943500e-01,  1.118077e+00],
        [ 1.000000e+00,  5.077600e-01,  2.095059e+00],
        [ 1.000000e+00,  2.373950e-01,  1.181912e+00],
        [ 1.000000e+00,  5.753400e-02,  2.216630e-01],
        [ 1.000000e+00,  3.698200e-01,  9.384530e-01],
        [ 1.000000e+00,  9.768190e-01,  4.149409e+00],
        [ 1.000000e+00,  6.160510e-01,  3.105444e+00],
        [ 1.000000e+00,  4.137000e-01,  1.896278e+00],
        [ 1.000000e+00,  1.052790e-01, -1.213450e-01],
        [ 1.000000e+00,  6.702730e-01,  3.161652e+00],
        [ 1.000000e+00,  9.527580e-01,  4.135358e+00],
        [ 1.000000e+00,  2.723160e-01,  8.590630e-01],
        [ 1.000000e+00,  3.036970e-01,  1.170272e+00],
        [ 1.000000e+00,  4.866980e-01,  1.687960e+00],
        [ 1.000000e+00,  5.118100e-01,  1.979745e+00],
        [ 1.000000e+00,  1.958650e-01,  6.869000e-02],
        [ 1.000000e+00,  9.867690e-01,  4.052137e+00],
        [ 1.000000e+00,  7.856230e-01,  3.156316e+00],
        [ 1.000000e+00,  7.975830e-01,  2.950630e+00],
        [ 1.000000e+00,  8.130600e-02,  6.893500e-02],
        [ 1.000000e+00,  6.597530e-01,  2.854020e+00],
        [ 1.000000e+00,  3.752700e-01,  9.997430e-01],
        [ 1.000000e+00,  8.191360e-01,  4.048082e+00],
        [ 1.000000e+00,  1.424320e-01,  2.309230e-01],
        [ 1.000000e+00,  2.151120e-01,  8.166930e-01],
        [ 1.000000e+00,  4.127000e-02,  1.307130e-01],
        [ 1.000000e+00,  4.413600e-02, -5.377060e-01],
        [ 1.000000e+00,  1.313370e-01, -3.391090e-01],
        [ 1.000000e+00,  4.634440e-01,  2.124538e+00],
        [ 1.000000e+00,  6.719050e-01,  2.708292e+00],
        [ 1.000000e+00,  9.465590e-01,  4.017390e+00],
        [ 1.000000e+00,  9.041760e-01,  4.004021e+00],
        [ 1.000000e+00,  3.066740e-01,  1.022555e+00],
        [ 1.000000e+00,  8.190060e-01,  3.657442e+00],
        [ 1.000000e+00,  8.454720e-01,  4.073619e+00],
        [ 1.000000e+00,  1.562580e-01,  1.199400e-02],
        [ 1.000000e+00,  8.571850e-01,  3.640429e+00],
        [ 1.000000e+00,  4.001580e-01,  1.808497e+00],
        [ 1.000000e+00,  3.753950e-01,  1.431404e+00],
        [ 1.000000e+00,  8.858070e-01,  3.935544e+00],
        [ 1.000000e+00,  2.399600e-01,  1.162152e+00],
        [ 1.000000e+00,  1.486400e-01, -2.273300e-01],
        [ 1.000000e+00,  1.431430e-01, -6.872800e-02],
        [ 1.000000e+00,  3.215820e-01,  8.250510e-01],
        [ 1.000000e+00,  5.093930e-01,  2.008645e+00],
        [ 1.000000e+00,  3.558910e-01,  6.645660e-01],
        [ 1.000000e+00,  9.386330e-01,  4.180202e+00],
        [ 1.000000e+00,  3.480570e-01,  8.648450e-01],
        [ 1.000000e+00,  4.388980e-01,  1.851174e+00],
        [ 1.000000e+00,  7.814190e-01,  2.761993e+00],
        [ 1.000000e+00,  9.113330e-01,  4.075914e+00],
        [ 1.000000e+00,  3.246900e-02,  1.102290e-01],
        [ 1.000000e+00,  4.999850e-01,  2.181987e+00],
        [ 1.000000e+00,  7.716630e-01,  3.152528e+00],
        [ 1.000000e+00,  6.703610e-01,  3.046564e+00],
        [ 1.000000e+00,  1.762020e-01,  1.289540e-01],
        [ 1.000000e+00,  3.921700e-01,  1.062726e+00],
        [ 1.000000e+00,  9.111880e-01,  3.651742e+00],
        [ 1.000000e+00,  8.722880e-01,  4.401950e+00],
        [ 1.000000e+00,  7.331070e-01,  3.022888e+00],
        [ 1.000000e+00,  6.102390e-01,  2.874917e+00],
        [ 1.000000e+00,  7.327390e-01,  2.946801e+00],
        [ 1.000000e+00,  7.148250e-01,  2.893644e+00],
        [ 1.000000e+00,  7.638600e-02,  7.213100e-02],
        [ 1.000000e+00,  5.590090e-01,  1.748275e+00],
        [ 1.000000e+00,  4.272580e-01,  1.912047e+00],
        [ 1.000000e+00,  8.418750e-01,  3.710686e+00],
        [ 1.000000e+00,  5.589180e-01,  1.719148e+00],
        [ 1.000000e+00,  5.332410e-01,  2.174090e+00],
        [ 1.000000e+00,  9.566650e-01,  3.656357e+00],
        [ 1.000000e+00,  6.203930e-01,  3.522504e+00],
        [ 1.000000e+00,  5.661200e-01,  2.234126e+00],
        [ 1.000000e+00,  5.232580e-01,  1.859772e+00],
        [ 1.000000e+00,  4.768840e-01,  2.097017e+00],
        [ 1.000000e+00,  1.764080e-01,  1.794000e-03],
        [ 1.000000e+00,  3.030940e-01,  1.231928e+00],
        [ 1.000000e+00,  6.097310e-01,  2.953862e+00],
        [ 1.000000e+00,  1.777400e-02, -1.168030e-01],
        [ 1.000000e+00,  6.226160e-01,  2.638864e+00],
        [ 1.000000e+00,  8.865390e-01,  3.943428e+00],
        [ 1.000000e+00,  1.486540e-01, -3.285130e-01],
        [ 1.000000e+00,  1.043500e-01, -9.986600e-02],
        [ 1.000000e+00,  1.168680e-01, -3.083600e-02],
        [ 1.000000e+00,  5.165140e-01,  2.359786e+00],
        [ 1.000000e+00,  6.648960e-01,  3.212581e+00],
        [ 1.000000e+00,  4.327000e-03,  1.889750e-01],
        [ 1.000000e+00,  4.255590e-01,  1.904109e+00],
        [ 1.000000e+00,  7.436710e-01,  3.007114e+00],
        [ 1.000000e+00,  9.351850e-01,  3.845834e+00],
        [ 1.000000e+00,  6.973000e-01,  3.079411e+00],
        [ 1.000000e+00,  4.445510e-01,  1.939739e+00],
        [ 1.000000e+00,  6.837530e-01,  2.880078e+00],
        [ 1.000000e+00,  7.559930e-01,  3.063577e+00],
        [ 1.000000e+00,  9.026900e-01,  4.116296e+00],
        [ 1.000000e+00,  9.449100e-02, -2.409630e-01],
        [ 1.000000e+00,  8.738310e-01,  4.066299e+00],
        [ 1.000000e+00,  9.918100e-01,  4.011834e+00],
        [ 1.000000e+00,  1.856110e-01,  7.771000e-02],
        [ 1.000000e+00,  6.945510e-01,  3.103069e+00],
        [ 1.000000e+00,  6.572750e-01,  2.811897e+00],
        [ 1.000000e+00,  1.187460e-01, -1.046300e-01],
        [ 1.000000e+00,  8.430200e-02,  2.521600e-02],
        [ 1.000000e+00,  9.453410e-01,  4.330063e+00],
        [ 1.000000e+00,  7.858270e-01,  3.087091e+00],
        [ 1.000000e+00,  5.309330e-01,  2.269988e+00],
        [ 1.000000e+00,  8.795940e-01,  4.010701e+00],
        [ 1.000000e+00,  6.527700e-01,  3.119542e+00],
        [ 1.000000e+00,  8.793380e-01,  3.723411e+00],
        [ 1.000000e+00,  7.647390e-01,  2.792078e+00],
        [ 1.000000e+00,  5.048840e-01,  2.192787e+00],
        [ 1.000000e+00,  5.542030e-01,  2.081305e+00],
        [ 1.000000e+00,  4.932090e-01,  1.714463e+00],
        [ 1.000000e+00,  3.637830e-01,  8.858540e-01],
        [ 1.000000e+00,  3.164650e-01,  1.028187e+00],
        [ 1.000000e+00,  5.802830e-01,  1.951497e+00],
        [ 1.000000e+00,  5.428980e-01,  1.709427e+00],
        [ 1.000000e+00,  1.126610e-01,  1.440680e-01],
        [ 1.000000e+00,  8.167420e-01,  3.880240e+00],
        [ 1.000000e+00,  2.341750e-01,  9.218760e-01],
        [ 1.000000e+00,  4.028040e-01,  1.979316e+00],
        [ 1.000000e+00,  7.094230e-01,  3.085768e+00],
        [ 1.000000e+00,  8.672980e-01,  3.476122e+00],
        [ 1.000000e+00,  9.933920e-01,  3.993679e+00],
        [ 1.000000e+00,  7.115800e-01,  3.077880e+00],
        [ 1.000000e+00,  1.336430e-01, -1.053650e-01],
        [ 1.000000e+00,  5.203100e-02, -1.647030e-01],
        [ 1.000000e+00,  3.668060e-01,  1.096814e+00],
        [ 1.000000e+00,  6.975210e-01,  3.092879e+00],
        [ 1.000000e+00,  7.872620e-01,  2.987926e+00],
        [ 1.000000e+00,  4.767100e-01,  2.061264e+00],
        [ 1.000000e+00,  7.214170e-01,  2.746854e+00],
        [ 1.000000e+00,  2.303760e-01,  7.167100e-01],
        [ 1.000000e+00,  1.043970e-01,  1.038310e-01],
        [ 1.000000e+00,  1.978340e-01,  2.377600e-02],
        [ 1.000000e+00,  1.292910e-01, -3.329900e-02],
        [ 1.000000e+00,  5.285280e-01,  1.942286e+00],
        [ 1.000000e+00,  9.493000e-03, -6.338000e-03],
        [ 1.000000e+00,  9.985330e-01,  3.808753e+00],
        [ 1.000000e+00,  3.635220e-01,  6.527990e-01],
        [ 1.000000e+00,  9.013860e-01,  4.053747e+00],
        [ 1.000000e+00,  8.326930e-01,  4.569290e+00],
        [ 1.000000e+00,  1.190020e-01, -3.277300e-02],
        [ 1.000000e+00,  4.876380e-01,  2.066236e+00],
        [ 1.000000e+00,  1.536670e-01,  2.227850e-01],
        [ 1.000000e+00,  2.386190e-01,  1.089268e+00],
        [ 1.000000e+00,  2.081970e-01,  1.487788e+00],
        [ 1.000000e+00,  7.509210e-01,  2.852033e+00],
        [ 1.000000e+00,  1.834030e-01,  2.448600e-02],
        [ 1.000000e+00,  9.956080e-01,  3.737750e+00],
        [ 1.000000e+00,  1.513110e-01,  4.501700e-02],
        [ 1.000000e+00,  1.268040e-01,  1.238000e-03],
        [ 1.000000e+00,  9.831530e-01,  3.892763e+00],
        [ 1.000000e+00,  7.724950e-01,  2.819376e+00],
        [ 1.000000e+00,  7.841330e-01,  2.830665e+00],
        [ 1.000000e+00,  5.693400e-02,  2.346330e-01],
        [ 1.000000e+00,  4.255840e-01,  1.810782e+00],
        [ 1.000000e+00,  9.987090e-01,  4.237235e+00],
        [ 1.000000e+00,  7.078150e-01,  3.034768e+00],
        [ 1.000000e+00,  4.138160e-01,  1.742106e+00],
        [ 1.000000e+00,  2.171520e-01,  1.169250e+00],
        [ 1.000000e+00,  3.605030e-01,  8.311650e-01],
        [ 1.000000e+00,  9.779890e-01,  3.729376e+00],
        [ 1.000000e+00,  5.079530e-01,  1.823205e+00],
        [ 1.000000e+00,  9.207710e-01,  4.021970e+00],
        [ 1.000000e+00,  2.105420e-01,  1.262939e+00],
        [ 1.000000e+00,  9.286110e-01,  4.159518e+00],
        [ 1.000000e+00,  5.803730e-01,  2.039114e+00],
        [ 1.000000e+00,  8.413900e-01,  4.101837e+00],
        [ 1.000000e+00,  6.815300e-01,  2.778672e+00],
        [ 1.000000e+00,  2.927950e-01,  1.228284e+00],
        [ 1.000000e+00,  4.569180e-01,  1.736620e+00],
        [ 1.000000e+00,  1.341280e-01, -1.950460e-01],
        [ 1.000000e+00,  1.624100e-02, -6.321500e-02],
        [ 1.000000e+00,  6.912140e-01,  3.305268e+00],
        [ 1.000000e+00,  5.820020e-01,  2.063627e+00],
        [ 1.000000e+00,  3.031020e-01,  8.988400e-01],
        [ 1.000000e+00,  6.225980e-01,  2.701692e+00],
        [ 1.000000e+00,  5.250240e-01,  1.992909e+00],
        [ 1.000000e+00,  9.967750e-01,  3.811393e+00],
        [ 1.000000e+00,  8.810250e-01,  4.353857e+00],
        [ 1.000000e+00,  7.234570e-01,  2.635641e+00],
        [ 1.000000e+00,  6.763460e-01,  2.856311e+00],
        [ 1.000000e+00,  2.546250e-01,  1.352682e+00],
        [ 1.000000e+00,  4.886320e-01,  2.336459e+00],
        [ 1.000000e+00,  5.198750e-01,  2.111651e+00],
        [ 1.000000e+00,  1.601760e-01,  1.217260e-01],
        [ 1.000000e+00,  6.094830e-01,  3.264605e+00],
        [ 1.000000e+00,  5.318810e-01,  2.103446e+00],
        [ 1.000000e+00,  3.216320e-01,  8.968550e-01],
        [ 1.000000e+00,  8.451480e-01,  4.220850e+00],
        [ 1.000000e+00,  1.200300e-02, -2.172830e-01],
        [ 1.000000e+00,  1.888300e-02, -3.005770e-01],
        [ 1.000000e+00,  7.147600e-02,  6.014000e-03]])

预剪枝

  • tolS和tolN其实是预剪枝操作

  • 调节tolS为0,tolN为1,容许误差下降为0,最少划分样本为1,则分得很多叶节点

# ops=(0,1): no minimum error reduction and a minimum of one row per side.
createTree(myMat,ops=(0,1))
{'spInd': 1,
 'spVal': 0.39435,
 'left': {'spInd': 1,
  'spVal': 0.582002,
  'left': {'spInd': 1,
   'spVal': 0.797583,
   'left': {'spInd': 1,
    'spVal': 0.819006,
    'left': {'spInd': 1,
     'spVal': 0.832693,
     'left': {'spInd': 1,
      'spVal': 0.867298,
      'left': {'spInd': 1,
       'spVal': 0.872288,
       'left': {'spInd': 1,
        'spVal': 0.952758,
        'left': {'spInd': 1,
         'spVal': 0.998533,
         'left': 4.237235,
         'right': {'spInd': 1,
          'spVal': 0.956665,
          'left': {'spInd': 1,
           'spVal': 0.993392,
           'left': {'spInd': 1,
            'spVal': 0.995608,
        ...
# Fit a regression tree on ex2.txt with the default ops and plot the raw points.
myDat = loadDataSet('../../Reference Code/Ch09/ex2.txt')
# np.asmatrix replaces the np.mat alias, which was removed in NumPy 2.0.
myMat = np.asmatrix(myDat)
print(createTree(myMat))
data2show(myMat)
{'spInd': 0, 'spVal': 0.499171, 'left': {'spInd': 0, 'spVal': 0.729397, 'left': {'spInd': 0, 'spVal': 0.952833, 'left': {'spInd': 0, 'spVal': 0.958512, 'left': 105.24862350000001, 'right': 112.42895575000001}, 'right': {'spInd': 0, 'spVal': 0.759504, 'left': {'spInd': 0, 'spVal': 0.790312, 'left': {'spInd': 0, 'spVal': 0.833026, 'left': {'spInd': 0, 'spVal': 0.944221, 'left': 87.3103875, 'right': {'spInd': 0, 'spVal': 0.85497, 'left': {'spInd': 0, 'spVal': 0.910975, 'left': 96.452867, 'right': {'spInd': 0, 'spVal': 0.892999, 'left': 104.825409, 'right': {'spInd': 0, 'spVal': 0.872883, 'left': 95.181793, 'right': 102.25234449999999}}}, 'right': 95.27584316666666}}, 'right': {'spInd': 0, 'spVal': 0.811602, 'left': 81.110152, 'right': 88.78449880000001}}, 'right': 102.35780185714285}, 'right': 78.08564325}}, 'right': {'spInd': 0, 'spVal': 0.640515, 'left': {'spInd': 0, 'spVal': 0.666452, 'left': {'spInd': 0, 'spVal': 0.706961, 'left': 114.554706, 'right': {'spInd': 0, 'spVal': 0.698472, 'left': 104.82495374999999, 'right': 108.92921799999999}}, 'right': 114.1516242857143}, 'right': {'spInd': 0, 'spVal': 0.613004, 'left': 93.67344971428572, 'right': {'spInd': 0, 'spVal': 0.582311, 'left': 123.2101316, 'right': {'spInd': 0, 'spVal': 0.553797, 'left': 97.20018024999999, 'right': {'spInd': 0, 'spVal': 0.51915, 'left': {'spInd': 0, 'spVal': 0.543843, 'left': 109.38961049999999, 'right': 110.979946}, 'right': 101.73699325000001}}}}}}, 'right': {'spInd': 0, 'spVal': 0.457563, 'left': {'spInd': 0, 'spVal': 0.467383, 'left': 12.50675925, 'right': 3.4331330000000007}, 'right': {'spInd': 0, 'spVal': 0.126833, 'left': {'spInd': 0, 'spVal': 0.373501, 'left': {'spInd': 0, 'spVal': 0.437652, 'left': -12.558604833333334, 'right': {'spInd': 0, 'spVal': 0.412516, 'left': 14.38417875, 'right': {'spInd': 0, 'spVal': 0.385021, 'left': -0.8923554999999995, 'right': 3.6584772500000016}}}, 'right': {'spInd': 0, 'spVal': 0.335182, 'left': {'spInd': 0, 'spVal': 0.350725, 'left': -15.08511175, 
'right': -22.693879600000002}, 'right': {'spInd': 0, 'spVal': 0.324274, 'left': 15.05929075, 'right': {'spInd': 0, 'spVal': 0.297107, 'left': -19.9941552, 'right': {'spInd': 0, 'spVal': 0.166765, 'left': {'spInd': 0, 'spVal': 0.202161, 'left': {'spInd': 0, 'spVal': 0.217214, 'left': {'spInd': 0, 'spVal': 0.228473, 'left': {'spInd': 0, 'spVal': 0.25807, 'left': 0.40377471428571476, 'right': -13.070501}, 'right': 6.770429}, 'right': -11.822278500000001}, 'right': 3.4496025}, 'right': {'spInd': 0, 'spVal': 0.156067, 'left': -12.1079725, 'right': -6.247900000000001}}}}}}, 'right': {'spInd': 0, 'spVal': 0.084661, 'left': 6.509843285714284, 'right': {'spInd': 0, 'spVal': 0.044737, 'left': -2.544392714285715, 'right': 4.091626}}}}}
output_27_1.png
  • 这个数据集的x2取值范围很大,所以目标变量的总方差会更大(regErr更大),所以当tolS=1时候,容忍的误差下降值对于ex2.txt数据集来说太小了,所以会划分到很多叶节点,效果差
# Same data set, but with a much larger required error reduction (tolS=10000).
myDat = loadDataSet('../../Reference Code/Ch09/ex2.txt')
# np.asmatrix replaces the np.mat alias, which was removed in NumPy 2.0.
myMat2 = np.asmatrix(myDat)
print(createTree(myMat2,ops=(10000,4)))
data2show(myMat2)
{'spInd': 0, 'spVal': 0.499171, 'left': 101.35815937735848, 'right': -2.637719329787234}
output_29_1.png
  • 增大tolS后,构建的树只有两个叶节点,效果好

后剪枝

1. 若是子树,返回True
def isTree(obj):
    """Return True when ``obj`` is a subtree (a dict) rather than a leaf value.

    ``isinstance`` is used instead of comparing the class name string, so dict
    subclasses count as trees and unrelated classes that merely happen to be
    named ``dict`` do not.
    """
    return isinstance(obj, dict)
2. 合并叶结点,返回平均值
def getMean(tree):
    """Collapse ``tree`` into a single number.

    Each branch that is itself a subtree is first collapsed recursively and
    overwritten in place with the result. The return value is the plain
    average of the two branch values (not weighted by sample counts).
    """
    # The right branch is processed before the left one.
    for side in ('right', 'left'):
        branch = tree[side]
        if isTree(branch):
            tree[side] = getMean(branch)
    return (tree['right'] + tree['left']) / 2.0
3. 剪枝,判断合并后的误差是否比合并前更小
def prune(tree,testData):
    """Post-prune a regression tree against held-out test data.

    Works bottom-up: subtrees are pruned first, and whenever both children of
    a node are leaves, the two leaves are merged into their average if that
    lowers the squared error on the test rows reaching this node.

    Parameters:
        tree: internal node dict ('spInd', 'spVal', 'left', 'right'); it is
            modified in place.
        testData: matrix whose last column is the target value.

    Returns:
        The (possibly pruned) tree, or a single number when the node was
        collapsed or merged.
    """
    # No test data reaches this node: collapse the whole subtree to its mean.
    if testData.shape[0] == 0:
        return getMean(tree)

    # Route the test rows the same way the tree does: the first result holds
    # rows with feature > spVal ('left'), the second the rest ('right').
    lSet, rSet = binSplitDataSet(testData,tree['spInd'],tree['spVal'])

    if isTree(tree['left']):
        tree['left'] = prune(tree['left'],lSet)
    if isTree(tree['right']):
        # Bug fix: the right subtree must be evaluated on the right-hand
        # split (rSet); it was previously pruned with lSet.
        tree['right'] = prune(tree['right'],rSet)

    # Merging is only considered once both children are leaves.
    if isTree(tree['left']) or isTree(tree['right']):
        return tree

    # Squared error on the test rows when the two leaves are kept separate.
    errorNoMerge = (np.sum(np.power(lSet[:,-1] - tree['left'],2))
                    + np.sum(np.power(rSet[:,-1] - tree['right'],2)))
    # Squared error when both leaves are replaced by their average.
    treeMean = (tree['right']+tree['left'])/2.0
    errorMerge = np.sum(np.power(testData[:,-1] - treeMean,2))

    if errorMerge < errorNoMerge:
        print('merging')
        return treeMean
    return tree
测试
# Grow a deliberately over-fitted tree on ex2.txt (ops=(0,1)) as the input
# for post-pruning.
myDat = loadDataSet('../../Reference Code/Ch09/ex2.txt')
# np.asmatrix replaces the np.mat alias, which was removed in NumPy 2.0.
myMat2 = np.asmatrix(myDat)
myTree = createTree(myMat2,ops=(0,1))
print(myTree)
# data2show is a helper defined earlier in the file (presumably plots the data).
data2show(myMat2)
{'spInd': 0, 'spVal': 0.499171, 'left': {'spInd': 0, 'spVal': 0.729397, 'left': {'spInd': 0, 'spVal': 0.952833, 'left': {'spInd': 0, 'spVal': 0.965969, 'left': {'spInd': 0, 'spVal': 0.968621, 'left': 86.399637, 'right': 98.648346}, 'right': {'spInd': 0, 'spVal': 0.956951, 'left': {'spInd': 0, 'spVal': 0.958512, 'left': {'spInd': 0, 'spVal': 0.960398, 'left': 112.386764, 'right': 123.559747}, 'right': 135.837013}, 'right': {'spInd': 0, 'spVal': 0.953902, 'left': {'spInd': 0, 'spVal': 0.954711, 'left': 82.016541, 'right': 100.935789}, 'right': 130.92648}}}, 'right': {'spInd': 0, 'spVal': 0.759504, 'left': {'spInd': 0, 'spVal': 0.763328, 'left': {'spInd': 0, 'spVal': 0.769043, 'left': {'spInd': 0, 'spVal': 0.790312, 'left': {'spInd': 0, 'spVal': 0.806158, 'left': {'spInd': 0, 'spVal': 0.815215, 'left': {'spInd': 0, 'spVal': 0.833026, 'left': {'spInd': 0, 'spVal': 0.841547, 'left': {'spInd': 0, 'spVal': 0.841625, 'left': {'spInd': 0, 'spVal': 0.944221, 'left': {'spInd': 0, 'spVal': 0.948822, 'left': {'spInd': 0, 'spVal': 0.949198, 'left': {'spInd': 0, 'spVal': 0.952377, 'left': 100.649591, 'right': 73.520802}, 'right': 105.752508}, 'right': 69.318649}, 'right': {'spInd': 0, 'spVal': 0.85497, 'left': {'spInd': 0, 'spVal': 0.936524, 'left': {'spInd': 0, 'spVal': 0.937766, 'left': 100.120253, 'right': 119.949824}, 'right': {'spInd': 0, 'spVal': 0.934853, 'left': 65.548418, 'right': {'spInd': 0, 'spVal': 0.925782, 'left': 115.753994, 'right': {'spInd': 0, 'spVal': 0.910975, 'left': {'spInd': 0, 'spVal': 0.912161, 'left': {'spInd': 0, 'spVal': 0.915263, 'left': 92.074619, 'right': 96.71761}, 'right': 85.005351}, 'right': {'spInd': 0, 'spVal': 0.901444, 'left': {'spInd': 0, 'spVal': 0.908629, 'left': 106.814667, 'right': 118.513475}, 'right': {'spInd': 0, 'spVal': 0.901421, 'left': 87.300625, 'right': {'spInd': 0, 'spVal': 0.892999, 'left': {'spInd': 0, 'spVal': 0.900699, 'left': 100.133819, 'right': {'spInd': 0, 'spVal': 0.896683, 'left': 109.188248, 'right': 107.00162}}, 
'right': {'spInd': 0, 'spVal': 0.888426, 'left': 82.436686, 'right': {'spInd': 0, 'spVal': 0.872199, 'left': {'spInd': 0, 'spVal': 0.883615, 'left': {'spInd': 0, 'spVal': 0.885676, 'left': 94.896354, 'right': 108.045948}, 'right': {'spInd': 0, 'spVal': 0.872883, 'left': 95.348184, 'right': 95.887712}}, 'right': {'spInd': 0, 'spVal': 0.866451, 'left': 111.552716, 'right': {'spInd': 0, 'spVal': 0.856421, 'left': 94.402102, 'right': 107.166848}}}}}}}}}}}, 'right': {'spInd': 0, 'spVal': 0.84294, 'left': {'spInd': 0, 'spVal': 0.847219, 'left': 89.20993, 'right': 76.240984}, 'right': 95.893131}}}, 'right': 60.552308}, 'right': {'spInd': 0, 'spVal': 0.838587, 'left': 115.669032, 'right': 134.089674}}, 'right': {'spInd': 0, 'spVal': 0.823848, 'left': 76.723835, 'right': {'spInd': 0, 'spVal': 0.819722, 'left': 59.342323, 'right': 70.054508}}}, 'right': {'spInd': 0, 'spVal': 0.811602, 'left': 118.319942, 'right': {'spInd': 0, 'spVal': 0.811363, 'left': 99.841379, 'right': 112.981216}}}, 'right': {'spInd': 0, 'spVal': 0.799873, 'left': 62.877698, 'right': {'spInd': 0, 'spVal': 0.798198, 'left': 91.368473, 'right': 76.853728}}}, 'right': {'spInd': 0, 'spVal': 0.786865, 'left': {'spInd': 0, 'spVal': 0.787755, 'left': 110.15973, 'right': 118.642009}, 'right': {'spInd': 0, 'spVal': 0.785574, 'left': 100.598825, 'right': {'spInd': 0, 'spVal': 0.777582, 'left': 107.024467, 'right': 100.838446}}}}, 'right': 64.041941}, 'right': 115.199195}, 'right': {'spInd': 0, 'spVal': 0.740859, 'left': {'spInd': 0, 'spVal': 0.757527, 'left': 81.106762, 'right': 63.549854}, 'right': {'spInd': 0, 'spVal': 0.731636, 'left': 93.773929, 'right': 73.912028}}}}, 'right': {'spInd': 0, 'spVal': 0.640515, 'left': {'spInd': 0, 'spVal': 0.642373, 'left': {'spInd': 0, 'spVal': 0.642707, 'left': {'spInd': 0, 'spVal': 0.665329, 'left': {'spInd': 0, 'spVal': 0.706961, 'left': {'spInd': 0, 'spVal': 0.70889, 'left': {'spInd': 0, 'spVal': 0.716211, 'left': 110.90283, 'right': {'spInd': 0, 'spVal': 0.710234, 'left': 
103.345308, 'right': 108.553919}}, 'right': 135.416767}, 'right': {'spInd': 0, 'spVal': 0.698472, 'left': {'spInd': 0, 'spVal': 0.69892, 'left': {'spInd': 0, 'spVal': 0.699873, 'left': {'spInd': 0, 'spVal': 0.70639, 'left': 106.180427, 'right': 105.062147}, 'right': 115.586605}, 'right': 92.470636}, 'right': {'spInd': 0, 'spVal': 0.689099, 'left': 120.521925, 'right': {'spInd': 0, 'spVal': 0.666452, 'left': {'spInd': 0, 'spVal': 0.667851, 'left': {'spInd': 0, 'spVal': 0.680486, 'left': 112.378209, 'right': 110.367074}, 'right': 92.449664}, 'right': {'spInd': 0, 'spVal': 0.665652, 'left': 120.014736, 'right': 105.547997}}}}}, 'right': {'spInd': 0, 'spVal': 0.661073, 'left': 121.980607, 'right': {'spInd': 0, 'spVal': 0.652462, 'left': 115.687524, 'right': 112.715799}}}, 'right': 82.500766}, 'right': 140.613941}, 'right': {'spInd': 0, 'spVal': 0.613004, 'left': {'spInd': 0, 'spVal': 0.623909, 'left': {'spInd': 0, 'spVal': 0.628061, 'left': {'spInd': 0, 'spVal': 0.637999, 'left': 82.713621, 'right': {'spInd': 0, 'spVal': 0.632691, 'left': 91.656617, 'right': 93.645293}}, 'right': {'spInd': 0, 'spVal': 0.624827, 'left': 117.628346, 'right': 105.970743}}, 'right': {'spInd': 0, 'spVal': 0.618868, 'left': 87.181863, 'right': 76.917665}}, 'right': {'spInd': 0, 'spVal': 0.606417, 'left': 168.180746, 'right': {'spInd': 0, 'spVal': 0.513332, 'left': {'spInd': 0, 'spVal': 0.533511, 'left': {'spInd': 0, 'spVal': 0.548539, 'left': {'spInd': 0, 'spVal': 0.553797, 'left': {'spInd': 0, 'spVal': 0.560301, 'left': {'spInd': 0, 'spVal': 0.599142, 'left': 93.521396, 'right': {'spInd': 0, 'spVal': 0.589806, 'left': 130.378529, 'right': {'spInd': 0, 'spVal': 0.582311, 'left': {'spInd': 0, 'spVal': 0.585413, 'left': 98.674874, 'right': 125.295113}, 'right': {'spInd': 0, 'spVal': 0.571214, 'left': 82.589328, 'right': {'spInd': 0, 'spVal': 0.569327, 'left': 114.872056, 'right': 108.435392}}}}}, 'right': 82.903945}, 'right': {'spInd': 0, 'spVal': 0.549814, 'left': 120.857321, 'right': 
137.267576}}, 'right': {'spInd': 0, 'spVal': 0.546601, 'left': 83.114502, 'right': {'spInd': 0, 'spVal': 0.537834, 'left': {'spInd': 0, 'spVal': 0.543843, 'left': 96.319043, 'right': 98.36201}, 'right': 90.995536}}}, 'right': {'spInd': 0, 'spVal': 0.51915, 'left': {'spInd': 0, 'spVal': 0.531944, 'left': 129.766743, 'right': 124.795495}, 'right': 116.176162}}, 'right': {'spInd': 0, 'spVal': 0.508548, 'left': 101.075609, 'right': {'spInd': 0, 'spVal': 0.508542, 'left': 93.292829, 'right': 96.403373}}}}}}}, 'right': {'spInd': 0, 'spVal': 0.457563, 'left': {'spInd': 0, 'spVal': 0.465561, 'left': {'spInd': 0, 'spVal': 0.467383, 'left': {'spInd': 0, 'spVal': 0.483803, 'left': {'spInd': 0, 'spVal': 0.487381, 'left': {'spInd': 0, 'spVal': 0.487537, 'left': 11.924204, 'right': 5.149336}, 'right': 27.729263}, 'right': 5.224234}, 'right': {'spInd': 0, 'spVal': 0.46568, 'left': -9.712925, 'right': -23.777531}}, 'right': {'spInd': 0, 'spVal': 0.463241, 'left': 30.051931, 'right': 17.171057}}, 'right': {'spInd': 0, 'spVal': 0.455761, 'left': -34.044555, 'right': {'spInd': 0, 'spVal': 0.126833, 'left': {'spInd': 0, 'spVal': 0.130626, 'left': {'spInd': 0, 'spVal': 0.382037, 'left': {'spInd': 0, 'spVal': 0.388789, 'left': {'spInd': 0, 'spVal': 0.437652, 'left': {'spInd': 0, 'spVal': 0.454312, 'left': {'spInd': 0, 'spVal': 0.454375, 'left': 9.841938, 'right': 3.043912}, 'right': {'spInd': 0, 'spVal': 0.446196, 'left': {'spInd': 0, 'spVal': 0.451087, 'left': -20.360067, 'right': -28.724685}, 'right': -5.108172}}, 'right': {'spInd': 0, 'spVal': 0.412516, 'left': {'spInd': 0, 'spVal': 0.418943, 'left': {'spInd': 0, 'spVal': 0.426711, 'left': {'spInd': 0, 'spVal': 0.428582, 'left': 19.745224, 'right': 15.224266}, 'right': -21.594268}, 'right': 44.161493}, 'right': {'spInd': 0, 'spVal': 0.403228, 'left': -26.419289, 'right': {'spInd': 0, 'spVal': 0.391609, 'left': -1.729244, 'right': 3.001104}}}}, 'right': {'spInd': 0, 'spVal': 0.385021, 'left': 21.578007, 'right': 24.816941}}, 'right': 
{'spInd': 0, 'spVal': 0.335182, 'left': {'spInd': 0, 'spVal': 0.370042, 'left': {'spInd': 0, 'spVal': 0.378965, 'left': -29.007783, 'right': {'spInd': 0, 'spVal': 0.373501, 'left': {'spInd': 0, 'spVal': 0.377383, 'left': 13.583555, 'right': 5.241196}, 'right': -8.228297}}, 'right': {'spInd': 0, 'spVal': 0.35679, 'left': -32.124495, 'right': {'spInd': 0, 'spVal': 0.350725, 'left': {'spInd': 0, 'spVal': 0.351478, 'left': -19.526539, 'right': -0.461116}, 'right': {'spInd': 0, 'spVal': 0.350065, 'left': -40.086564, 'right': {'spInd': 0, 'spVal': 0.342761, 'left': -1.319852, 'right': {'spInd': 0, 'spVal': 0.342155, 'left': -31.584855, 'right': {'spInd': 0, 'spVal': 0.3417, 'left': -16.930416, 'right': -23.547711}}}}}}}, 'right': {'spInd': 0, 'spVal': 0.324274, 'left': {'spInd': 0, 'spVal': 0.32889, 'left': {'spInd': 0, 'spVal': 0.331364, 'left': {'spInd': 0, 'spVal': 0.3349, 'left': 2.768225, 'right': 18.97665}, 'right': -1.290825}, 'right': 39.783113}, 'right': {'spInd': 0, 'spVal': 0.309133, 'left': {'spInd': 0, 'spVal': 0.310956, 'left': {'spInd': 0, 'spVal': 0.318309, 'left': -13.189243, 'right': -27.605424}, 'right': -49.939516}, 'right': {'spInd': 0, 'spVal': 0.131833, 'left': {'spInd': 0, 'spVal': 0.138619, 'left': {'spInd': 0, 'spVal': 0.156067, 'left': {'spInd': 0, 'spVal': 0.166765, 'left': {'spInd': 0, 'spVal': 0.193282, 'left': {'spInd': 0, 'spVal': 0.211633, 'left': {'spInd': 0, 'spVal': 0.228473, 'left': {'spInd': 0, 'spVal': 0.25807, 'left': {'spInd': 0, 'spVal': 0.284794, 'left': {'spInd': 0, 'spVal': 0.300318, 'left': 8.814725, 'right': {'spInd': 0, 'spVal': 0.297107, 'left': -18.051318, 'right': {'spInd': 0, 'spVal': 0.295993, 'left': -1.798377, 'right': {'spInd': 0, 'spVal': 0.290749, 'left': -14.988279, 'right': -14.391613}}}}, 'right': {'spInd': 0, 'spVal': 0.273863, 'left': 35.623746, 'right': {'spInd': 0, 'spVal': 0.264926, 'left': -9.457556, 'right': {'spInd': 0, 'spVal': 0.264639, 'left': 5.280579, 'right': 2.557923}}}}, 'right': {'spInd': 0, 
'spVal': 0.228628, 'left': {'spInd': 0, 'spVal': 0.228751, 'left': {'spInd': 0, 'spVal': 0.232802, 'left': -20.425137, 'right': 1.222318}, 'right': -30.812912}, 'right': -2.266273}}, 'right': {'spInd': 0, 'spVal': 0.222271, 'left': {'spInd': 0, 'spVal': 0.2232, 'left': 19.425158, 'right': 15.501642}, 'right': {'spInd': 0, 'spVal': 0.218321, 'left': -9.255852, 'right': {'spInd': 0, 'spVal': 0.217214, 'left': 1.410768, 'right': -3.958752}}}}, 'right': {'spInd': 0, 'spVal': 0.202161, 'left': {'spInd': 0, 'spVal': 0.203993, 'left': {'spInd': 0, 'spVal': 0.206207, 'left': -8.332207, 'right': -12.619036}, 'right': -22.379119}, 'right': {'spInd': 0, 'spVal': 0.199903, 'left': -1.983889, 'right': -3.372472}}}, 'right': {'spInd': 0, 'spVal': 0.176523, 'left': 18.208423, 'right': 0.946348}}, 'right': {'spInd': 0, 'spVal': 0.156273, 'left': {'spInd': 0, 'spVal': 0.164134, 'left': {'spInd': 0, 'spVal': 0.166431, 'left': -14.740059, 'right': -6.512506}, 'right': -27.405211}, 'right': 0.225886}}, 'right': {'spInd': 0, 'spVal': 0.13988, 'left': 7.557349, 'right': 7.336784}}, 'right': -29.087463}, 'right': 22.478291}}}}}, 'right': -39.524461}, 'right': {'spInd': 0, 'spVal': 0.124723, 'left': 22.891675, 'right': {'spInd': 0, 'spVal': 0.085111, 'left': {'spInd': 0, 'spVal': 0.108801, 'left': {'spInd': 0, 'spVal': 0.11515, 'left': -1.402796, 'right': 13.795828}, 'right': {'spInd': 0, 'spVal': 0.10796, 'left': -16.106164, 'right': {'spInd': 0, 'spVal': 0.085873, 'left': -1.293195, 'right': -10.137104}}}, 'right': {'spInd': 0, 'spVal': 0.084661, 'left': 37.820659, 'right': {'spInd': 0, 'spVal': 0.080061, 'left': -24.132226, 'right': {'spInd': 0, 'spVal': 0.068373, 'left': {'spInd': 0, 'spVal': 0.079632, 'left': 2.229873, 'right': 29.420068}, 'right': {'spInd': 0, 'spVal': 0.061219, 'left': -15.160836, 'right': {'spInd': 0, 'spVal': 0.044737, 'left': {'spInd': 0, 'spVal': 0.053764, 'left': {'spInd': 0, 'spVal': 0.055862, 'left': 6.695567, 'right': -3.131497}, 'right': -13.731698}, 
'right': {'spInd': 0, 'spVal': 0.028546, 'left': {'spInd': 0, 'spVal': 0.039914, 'left': 3.855393, 'right': 11.220099}, 'right': {'spInd': 0, 'spVal': 0.000256, 'left': -8.377094, 'right': 9.668106}}}}}}}}}}}}}
output_39_1.png
  • 叶节点太多,使用后剪枝
# Prune the over-fitted tree against the held-out ex2test.txt data.
myDatTest = loadDataSet('../../Reference Code/Ch09/ex2test.txt')
# np.asmatrix replaces the np.mat alias, which was removed in NumPy 2.0.
myDatTest = np.asmatrix(myDatTest)
# prune modifies myTree in place and also returns the pruned tree.
prune(myTree,myDatTest)
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging





{'spInd': 0,
 'spVal': 0.499171,
 'left': {'spInd': 0,
  'spVal': 0.729397,
  'left': {'spInd': 0,
   'spVal': 0.952833,
   'left': {'spInd': 0,
    'spVal': 0.965969,
    'left': 92.5239915,
    'right': {'spInd': 0,
     'spVal': 0.956951,
     'left': {'spInd': 0,
      'spVal': 0.958512,
     ...

模型树

  • 把叶节点设定为分段线性函数
# Ordinary least-squares fit (normal equations); no ridge penalty is applied
def linearSolve(dataSet):
    """Fit y = w0 + w1*x1 + ... by ordinary least squares.

    Parameters:
        dataSet: m x n matrix (or array-like convertible to one); the first
            n-1 columns are features and the last column is the target.

    Returns:
        (ws, X, Y): the weight column vector, the design matrix with a
        leading column of ones for the intercept, and the target column.

    Raises:
        NameError: when X.T*X has a zero determinant and cannot be inverted.
    """
    # np.asmatrix replaces the np.mat alias removed in NumPy 2.0; it returns
    # matrix input unchanged and lets plain arrays/lists be passed as well.
    dataSet = np.asmatrix(dataSet)
    m, n = dataSet.shape
    # Column 0 stays all ones (intercept); columns 1.. hold the features.
    X = np.asmatrix(np.ones((m, n)))
    X[:, 1:n] = dataSet[:, 0:n-1]
    Y = dataSet[:, -1]
    xTx = X.T * X
    if np.linalg.det(xTx) == 0:
        raise NameError('This matrix is singular,cannot do inverse,\n try increasing the scond value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y

# Build a model-tree leaf
def modelLeaf(dataSet):
    """Return the linear-regression weights fitted to *dataSet*.

    The weights are the first item returned by linearSolve; the design
    matrix and target column it also returns are discarded.
    """
    return linearSolve(dataSet)[0]

# Error measure for a model-tree node
def modelErr(dataSet):
    """Return the sum of squared residuals of the linear fit on *dataSet*.

    The fit comes from linearSolve; the squared residuals are added up with
    the builtin sum over the rows of the residual column.
    """
    weights, design, target = linearSolve(dataSet)
    residual = target - design * weights
    return sum(np.power(residual, 2))
# Build a model tree (linear model in every leaf) on exp2.txt.
myDat2 = loadDataSet('../../Reference Code/Ch09/exp2.txt')
# np.asmatrix replaces the np.mat alias, which was removed in NumPy 2.0.
myDat2 = np.asmatrix(myDat2)
myTree = createTree(myDat2,leafType = modelLeaf, errType=modelErr,ops=(1,10))
print(myTree)
# data2show is a helper defined earlier in the file (presumably plots the data).
data2show(myDat2)

{'spInd': 0, 'spVal': 0.285477, 'left': matrix([[1.69855694e-03],
        [1.19647739e+01]]), 'right': matrix([[3.46877936],
        [1.18521743]])}
output_44_1.png
  • 可以看到在'spVal': 0.285477创建了两个模型,在图上也可以看到样本分布在x1=0.28分段
  • x ≤ 0.285477 时(右子树): y = 3.468 + 1.1852x
  • x > 0.285477 时(左子树): y = 0.0016985 + 11.965x

树回归和标准回归的比较

# Leaf prediction for a regression tree: the leaf itself is a constant (the mean)
def regTreeEval(model,inDat):
    """Return the leaf value *model* as a float.

    *inDat* is accepted only so the signature matches modelTreeEval; it is
    not read.
    """
    prediction = float(model)
    return prediction

# Leaf prediction for a model tree: the leaf is a column vector of weights
def modelTreeEval(model,inDat):
    """Evaluate the linear leaf model *model* on one sample.

    Parameters:
        model: (n+1) x 1 weight column; row 0 is the intercept.
        inDat: 1 x n matrix holding the sample's feature values.

    Returns:
        The prediction as a Python float.
    """
    n = inDat.shape[1]
    # np.asmatrix replaces the np.mat alias removed in NumPy 2.0.
    X = np.asmatrix(np.ones((1, n + 1)))
    # Column 0 stays 1 (intercept term); the features fill the rest.
    X[:, 1:n + 1] = inDat
    # X*model is a 1x1 matrix. Pull the element out explicitly: calling
    # float() on an array with ndim > 0 is deprecated since NumPy 1.25.
    return float((X * model)[0, 0])

# Walk the tree from the root until a leaf is reached
def treeForeCast(tree,inData,modelEval = regTreeEval):
    """Predict the target for one sample by descending *tree*.

    Parameters:
        tree: internal node dict or a leaf.
        inData: one sample's feature values (1 x n matrix, 1-D array or
            sequence).
        modelEval: callable (leaf, inData) -> prediction; regTreeEval for
            regression trees, modelTreeEval for model trees.

    Returns:
        Whatever modelEval returns for the leaf the sample falls into.
    """
    # A leaf: evaluate it directly.
    if not isTree(tree):
        return modelEval(tree, inData)

    # Bug fix: inData is a 1 x n row, so inData[spInd] indexed *rows* and
    # only worked for a single feature with spInd == 0. Flatten the sample
    # and pick the split feature by position instead.
    featVal = np.asarray(inData).ravel()[tree['spInd']]

    # Values greater than the split value go left, the rest go right (the
    # same convention binSplitDataSet uses when the tree is built).
    branch = tree['left'] if featVal > tree['spVal'] else tree['right']
    # The recursive call evaluates the branch itself when it is a leaf.
    return treeForeCast(branch, inData, modelEval)
        
# Predict every test sample and return the predictions as a column
def createForeCast(tree,testData,modelEval = regTreeEval):
    """Run treeForeCast on each entry of *testData*.

    Parameters:
        tree: tree produced by createTree.
        testData: indexable collection of samples (one sample per index).
        modelEval: leaf evaluator forwarded to treeForeCast.

    Returns:
        An m x 1 matrix of predictions, m = len(testData).
    """
    m = len(testData)
    # np.asmatrix replaces the np.mat alias removed in NumPy 2.0.
    yHat = np.asmatrix(np.zeros((m, 1)))
    for i in range(m):
        # Each sample is handed over as a matrix row.
        yHat[i, 0] = treeForeCast(tree, np.asmatrix(testData[i]), modelEval)
    return yHat

# Load the train/test splits; np.asmatrix replaces the np.mat alias that was
# removed in NumPy 2.0.
trainMat = np.asmatrix(loadDataSet('../../Reference Code/Ch09/bikeSpeedVsIq_train.txt'))
testMat = np.asmatrix(loadDataSet('../../Reference Code/Ch09/bikeSpeedVsIq_test.txt'))

# Show the first ten training rows.
trainMat[0:10,:]
matrix([[  3.      ,  46.852122],
        [ 23.      , 178.676107],
        [  0.      ,  86.154024],
        [  6.      ,  68.707614],
        [ 15.      , 139.737693],
        [ 17.      , 141.988903],
        [ 12.      ,  94.477135],
        [  8.      ,  86.083788],
        [  9.      ,  97.265824],
        [  7.      ,  80.400027]])
testMat[0:10,:]
matrix([[ 12.      , 121.010516],
        [ 19.      , 157.337044],
        [ 12.      , 116.031825],
        [ 15.      , 132.124872],
        [  2.      ,  52.719612],
        [  6.      ,  39.058368],
        [  3.      ,  50.757763],
        [ 20.      , 166.740333],
        [ 11.      , 115.808227],
        [ 21.      , 165.582995]])

1. 构建回归树

# Fit a regression tree (constant leaves) on the training set with ops=(1,20).
myTree = createTree(trainMat,leafType = regLeaf, errType=regErr,ops=(1,20))
print(myTree)
# data2show is a helper defined earlier in the file (presumably plots the data).
data2show(trainMat)
data2show(testMat)
{'spInd': 0, 'spVal': 10.0, 'left': {'spInd': 0, 'spVal': 17.0, 'left': {'spInd': 0, 'spVal': 20.0, 'left': 168.34161286956524, 'right': 157.0484078846154}, 'right': {'spInd': 0, 'spVal': 14.0, 'left': 141.06067981481482, 'right': 122.90893026923078}}, 'right': {'spInd': 0, 'spVal': 7.0, 'left': 94.7066578125, 'right': {'spInd': 0, 'spVal': 5.0, 'left': 69.02117757692308, 'right': 50.94683665}}}
output_52_1.png
output_52_2.png
# x and y are not referenced again in this section.
x=testMat[:,0]
y=testMat[:,1]

# Predict from the feature column only, then correlate with the true targets.
yHat = createForeCast(myTree,testMat[:,0]) 
np.corrcoef(yHat,testMat[:,1],rowvar=0)[0,1] # rowvar=0: each column is one variable
0.9640852318222141

2. 构建模型树

# Fit a model tree (linear leaves) with the same ops and score it the same way.
myTree = createTree(trainMat,leafType = modelLeaf, errType=modelErr,ops=(1,20))
print(myTree)
yHat = createForeCast(myTree,testMat[:,0],modelEval=modelTreeEval)
np.corrcoef(yHat,testMat[:,1],rowvar=0)[0,1] # rowvar=0: each column is one variable
{'spInd': 0, 'spVal': 4.0, 'left': {'spInd': 0, 'spVal': 12.0, 'left': {'spInd': 0, 'spVal': 16.0, 'left': {'spInd': 0, 'spVal': 20.0, 'left': matrix([[47.58621512],
        [ 5.51066299]]), 'right': matrix([[37.54851927],
        [ 6.23298637]])}, 'right': matrix([[43.41251481],
        [ 6.37966738]])}, 'right': {'spInd': 0, 'spVal': 9.0, 'left': matrix([[-2.87684083],
        [10.20804482]]), 'right': {'spInd': 0, 'spVal': 6.0, 'left': matrix([[-11.84548851],
        [ 12.12382261]]), 'right': matrix([[-17.21714265],
        [ 13.72153115]])}}}, 'right': matrix([[ 68.87014372],
        [-11.78556471]])}





0.9760412191380593

3. 普通线性回归

# Single global linear fit on the whole training set, as the baseline.
ws,X,Y=linearSolve(trainMat)
ws
matrix([[37.58916794],
        [ 6.18978355]])
# Predict with the global fit: yHat = ws[0] + ws[1]*x, overwriting the yHat
# matrix left over from the previous cell in place.
m = testMat.shape[0]
for i in range(m):
    yHat[i] = testMat[i,0]*ws[1,0] + ws[0,0]
np.corrcoef(yHat,testMat[:,1],rowvar=0)[0,1] # rowvar=0: each column is one variable
0.9434684235674763
  • 测试集上的相关系数:模型树(0.976)> 回归树(0.964)> 普通线性回归(0.943;linearSolve 实现的是普通最小二乘,而不是岭回归)

4. 集成Matplotlib和Tkinter

# Run under Windows (per the original note); minimal Tkinter smoke test
import tkinter as tk
root = tk.Tk()  # create the main window
myLabel = tk.Label(root, text='Hello,World')
myLabel.grid()  # place the label with the grid geometry manager
root.mainloop()  # enter the Tk event loop
from numpy import *
from tkinter import *

#import matplotlib
#matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure

def reDraw(tolS,tolN):
    """Rebuild the tree with ops=(tolS, tolN) and redraw the figure.

    Reads the module-level checkbox variable chkBtnVar to choose between a
    model tree and a regression tree, and uses the data, figure and canvas
    stored as attributes on this function (reDraw.rawDat, reDraw.testDat,
    reDraw.f, reDraw.canvas).
    """
    reDraw.f.clf()        # clear the figure
    reDraw.a = reDraw.f.add_subplot(111)
    if chkBtnVar.get():
        # Model tree: tolN is clamped to at least 2 (presumably because a
        # linear fit needs at least two points).
        if tolN < 2: tolN = 2
        myTree = createTree(reDraw.rawDat, modelLeaf, modelErr, (tolS,tolN))
        yHat = createForeCast(myTree, reDraw.testDat, modelTreeEval)
    else:
        myTree=createTree(reDraw.rawDat, ops=(tolS,tolN))
        yHat = createForeCast(myTree, reDraw.testDat)
    reDraw.a.scatter(reDraw.rawDat[:,0].flatten().A[0], reDraw.rawDat[:,1].flatten().A[0], s=5) #use scatter for data set
    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0) #use plot for yHat
    # FigureCanvasTkAgg.show() was deprecated in Matplotlib 2.2 and removed
    # in later releases; draw() is the replacement.
    reDraw.canvas.draw()
    
def getInputs():
    """Read tolN and tolS from the two Entry widgets.

    A value that cannot be parsed is replaced by its default (10 for tolN,
    1.0 for tolS); a hint is printed and the Entry text is reset.

    Returns:
        (tolN, tolS) as (int, float).
    """
    # Only a failed conversion is handled here. A bare except would also
    # swallow KeyboardInterrupt and real programming errors.
    try:
        tolN = int(tolNentry.get())
    except ValueError:
        tolN = 10
        print("enter Integer for tolN")
        tolNentry.delete(0, END)
        tolNentry.insert(0,'10')
    try:
        tolS = float(tolSentry.get())
    except ValueError:
        tolS = 1.0
        print("enter Float for tolS")
        tolSentry.delete(0, END)
        tolSentry.insert(0,'1.0')
    return tolN,tolS

def drawNewTree():
    """Button callback: read the Entry values and redraw with them."""
    # getInputs returns (tolN, tolS), while reDraw expects (tolS, tolN).
    nTol, sTol = getInputs()
    reDraw(sTol, nTol)
    
root=Tk()

reDraw.f = Figure(figsize=(5,4), dpi=100) #create canvas
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
# FigureCanvasTkAgg.show() was deprecated in Matplotlib 2.2 and removed in
# later releases; draw() is the replacement.
reDraw.canvas.draw()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)

# Entry boxes for the two stopping parameters, pre-filled with defaults.
Label(root, text="tolN").grid(row=1, column=0)
tolNentry = Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0,'10')
Label(root, text="tolS").grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0,'1.0')
Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3)
# Checkbox state read by reDraw to switch between regression and model tree.
chkBtnVar = IntVar()
chkBtn = Checkbutton(root, text="Model Tree", variable = chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)

# asmatrix replaces the mat alias, which was removed in NumPy 2.0.
reDraw.rawDat = asmatrix(loadDataSet('sine.txt'))
# Use the matrix's own min()/max() so plain scalars are passed to arange.
# The builtin min/max iterate the column row by row and yield 1x1 matrices.
reDraw.testDat = arange(reDraw.rawDat[:,0].min(),reDraw.rawDat[:,0].max(),0.01)
reDraw(1.0, 10)

root.mainloop()
回归树.png

模型树.png

你可能感兴趣的:(树回归)