机器学习实战代码详解(九)树回归

regCART.py

#coding=utf-8
from numpy import *

def loadDataSet(fileName):
    """Parse a tab-delimited text file into a list of float records.

    Each line is split on tabs and every field is converted to float.
    Returns a list of row lists, ready to be wrapped in a numpy matrix.
    """
    dataMat = []
    # 'with' guarantees the file handle is closed (original leaked it);
    # list(map(...)) keeps each row a concrete list on Python 2 AND 3
    # (bare map() is a lazy iterator on Python 3, which breaks numpy.mat).
    with open(fileName) as fr:
        for line in fr.readlines():
            curLine = line.strip().split('\t')
            fltLine = list(map(float, curLine))
            dataMat.append(fltLine)
    return dataMat

def binSplitDataSet(dataSet, feature, value):
    """Binary-split the rows of dataSet on one feature.

    Rows where dataSet[:, feature] > value form the first returned
    matrix; the remaining rows (<= value) form the second.
    """
    greaterRows = nonzero(dataSet[:, feature] > value)[0]
    lessEqRows = nonzero(dataSet[:, feature] <= value)[0]
    return dataSet[greaterRows, :], dataSet[lessEqRows, :]

#---------回归树的切分函数-----------
#叶节点生成函数,当chooseBestSplit()函数确定不再对数据进行切分时,将调用该函数来得到叶节点模型
def regLeaf(dataSet):
    """Leaf-model generator for regression trees.

    The leaf model is simply the mean of the target (last) column.
    Called by chooseBestSplit() when it decides not to split further.
    """
    targetCol = dataSet[:, -1]
    return mean(targetCol)
#误差估计函数
def regErr(dataSet):
    """Total squared error of the target column: variance x sample count."""
    numRows = shape(dataSet)[0]
    return var(dataSet[:, -1]) * numRows
#找到数据的最佳二元切分方式,如果找不到一个好的二元切分,该函数返回None并同时调用createTree()方法来产生叶节点
def chooseBestSplit(dataSet, leafType = regLeaf, errType = regErr, ops = (1, 4)):
    """Find the best binary split of dataSet (pre-pruned CART search).

    ops[0] (tolS) is the minimum error reduction a split must achieve;
    ops[1] (tolN) is the minimum number of samples in each child.
    Returns (featureIndex, splitValue) for the best split, or
    (None, leafType(dataSet)) when no acceptable split exists.
    """
    tolS, tolN = ops
    # If every target value is identical there is nothing to split.
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m, n = shape(dataSet)
    S = errType(dataSet)                      # error of the unsplit set
    bestS, bestIndex, bestValue = inf, 0, 0
    for featIndex in range(n - 1):            # last column is the target
        observedVals = set((dataSet[:, featIndex].T.A.tolist())[0])
        for splitVal in observedVals:         # try every observed value
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            # Reject splits that leave a child with too few samples.
            if shape(mat0)[0] < tolN or shape(mat1)[0] < tolN:
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex, bestValue, bestS = featIndex, splitVal, newS
    # Give up if the best split barely reduces the error...
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # ...or if it would still produce an undersized child.
    if shape(mat0)[0] < tolN or shape(mat1)[0] < tolN:
        return None, leafType(dataSet)
    return bestIndex, bestValue

def createTree(dataSet, leafType = regLeaf, errType = regErr, ops = (1, 4)):
    """Recursively grow a CART tree.

    Returns either a leaf model (when no useful split is found) or a
    dict node with keys 'spInd', 'spVal', 'left', 'right'.
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None:                  # no acceptable split -> leaf node
        return val
    leftSet, rightSet = binSplitDataSet(dataSet, feat, val)
    return {
        'spInd': feat,
        'spVal': val,
        'left': createTree(leftSet, leafType, errType, ops),
        'right': createTree(rightSet, leafType, errType, ops),
    }
#测试输入是否为一棵树,返回布尔类型的结果,判断该节点是否为叶节点
def isTree(obj):
    """Return True when obj is an internal tree node, False for a leaf.

    Internal nodes are plain dicts; isinstance() is the idiomatic type
    test (the original compared type(obj).__name__ against 'dict').
    """
    return isinstance(obj, dict)
#从上往下遍历树直到叶节点为止,如果找到两个叶节点则计算它们的平均值
def getMean(tree):
    """Collapse a (sub)tree into a single leaf value.

    Walks down until both children are leaves, replacing each internal
    node by the average of its collapsed children (used by prune() when
    a whole subtree must be merged).
    """
    # Bug fix: the original wrote isTree[tree['left']] — subscripting the
    # function instead of calling it, a TypeError at runtime.  The dict
    # test is inlined here so the fixed function is self-contained.
    if isinstance(tree['right'], dict):
        tree['right'] = getMean(tree['right'])
    if isinstance(tree['left'], dict):
        tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0
#tree待剪枝的树,testData剪枝所需要的测试数据
def prune(tree, testData):
    """Post-prune tree against testData (matrix whose last column is y).

    Children are pruned recursively first; then, when both children of
    this node are leaves, they are merged into their mean whenever the
    merged leaf has a lower squared error on testData than the split.
    Returns the (possibly collapsed) tree.
    """
    # No test samples reach this node: collapse the whole subtree.
    if shape(testData)[0] == 0:
        return getMean(tree)
    # Route the test data down the same split the tree was built with.
    if isTree(tree['right']) or isTree(tree['left']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)
    # Both children are (now) leaves: compare split vs. merged error.
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNomerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
                       sum(power(rSet[:, -1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = sum(power(testData[:, -1] - treeMean, 2))
        if errorMerge < errorNomerge:   # merging lowers the test error
            # print() form runs on both Python 2 and Python 3
            # (the original 'print "merging"' is a Py3 SyntaxError).
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree

#模型树叶节点生成函数
def linearSolve(dataSet):
    """Fit Y = X * ws by ordinary least squares.

    The last column of dataSet is the target Y; the other columns are
    copied into X after a leading column of ones (the intercept term).
    Returns (ws, X, Y).  Raises NameError when X.T*X is singular.
    """
    m, n = shape(dataSet)
    X = mat(ones((m, n)))
    X[:, 1:n] = dataSet[:, 0:n-1]        # column 0 stays all ones
    Y = dataSet[:, -1]
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        raise NameError('this matrix is singular, cannot do inverse, \n try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y
def modelLeaf(dataSet):
    """Leaf-model generator for model trees: the fitted OLS coefficients."""
    coeffs, _, _ = linearSolve(dataSet)
    return coeffs

def modelErr(dataSet):
    """Squared error of the linear model fitted on dataSet (model-tree errType)."""
    ws, X, Y = linearSolve(dataSet)
    residual = Y - X * ws
    return sum(power(residual, 2))

#用树回归进行预测代码
def regTreeEval(model, inDat):
    """Evaluate a regression-tree leaf.

    The leaf model is the constant prediction itself; inDat is ignored
    but kept so all leaf evaluators share the same signature.
    """
    prediction = float(model)
    return prediction

def modelTreeEval(model, inDat):
    """Evaluate a model-tree leaf on one input row.

    Prepends the intercept column of ones to inDat and applies the
    stored linear coefficients, returning a float prediction.
    """
    n = shape(inDat)[1]
    row = mat(ones((1, n + 1)))      # row[0, 0] stays 1.0 for the intercept
    row[:, 1:n+1] = inDat
    return float(row * model)

#输入单个数据点或者行向量,函数treeForeCast()会返回一个浮点值
def treeForeCast(tree, inData, modelEval=regTreeEval):
    """Predict a single float for the row inData by walking the tree.

    modelEval maps (leaf model, input row) -> float: regTreeEval for
    regression trees, modelTreeEval for model trees.
    """
    if not isTree(tree):
        return modelEval(tree, inData)
    # Same convention as binSplitDataSet: values > spVal go left.
    branch = 'left' if inData[tree['spInd']] > tree['spVal'] else 'right'
    subtree = tree[branch]
    if isTree(subtree):
        return treeForeCast(subtree, inData, modelEval)
    return modelEval(subtree, inData)
#自顶向下遍历整棵树,直到命中叶节点为止,一旦到达叶节点,它就会在输入数据上调用modelEval函数
def createForeCast(tree, testData, modelEval=regTreeEval):
    """Predict every row of testData; returns an m x 1 matrix of floats."""
    m = len(testData)
    yHat = mat(zeros((m, 1)))
    for rowIdx in range(m):
        yHat[rowIdx, 0] = treeForeCast(tree, mat(testData[rowIdx]), modelEval)
    return yHat

treeExplore.py

#coding=utf-8

from numpy import *
from Tkinter import *
import regCART
import matplotlib
matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure

#用tolS,tolN生成图
def reDraw(tolS, tolN):
    """Rebuild the tree with (tolS, tolN) and redraw the scatter + fit line."""
    reDraw.f.clf()                        # start from an empty figure
    reDraw.a = reDraw.f.add_subplot(111)
    useModelTree = chkBtnVar.get()        # checkbox: model tree vs. regression tree
    if useModelTree:
        if tolN < 2:
            tolN = 2                      # a model-tree leaf needs >= 2 points
        tree = regCART.createTree(reDraw.rawDat, regCART.modelLeaf,
                                  regCART.modelErr, (tolS, tolN))
        yHat = regCART.createForeCast(tree, reDraw.testDat,
                                      regCART.modelTreeEval)
    else:
        tree = regCART.createTree(reDraw.rawDat, ops=(tolS, tolN))
        yHat = regCART.createForeCast(tree, reDraw.testDat)
    reDraw.a.scatter(reDraw.rawDat[:, 0], reDraw.rawDat[:, 1], s=5)  # raw data
    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)               # fitted curve
    reDraw.canvas.show()

#获取输入框的值
def getInputs():
    try: tolN = int(tolNentry.get())     #获取tolN输入
    except:                             #防止输入错误,并初始化成原始值
        tolN = 10
        print "enter Integer for tolN"
        tolNentry.delete(0, END)
        tolNentry.insert(0, '10')
    try: tolS = float(tolSentry.get())  #获取tolS输入
    except:                             #防止输入错误,并初始化成原始值
        tolS = 1.0
        print "enter Float for tolS"
        tolSentry.delete(0, END)
        tolSentry.insert(0, '1.0')
    return tolN, tolS

def drawNewTree():
    """Button callback: pull the current settings and redraw the plot."""
    tolNVal, tolSVal = getInputs()
    reDraw(tolSVal, tolNVal)


# --- Build the Tk GUI and start the event loop -----------------------------
root = Tk()

# Embed a matplotlib figure inside the Tk window.
reDraw.f = Figure(figsize=(5, 4), dpi=100)
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)        # drawing widget that renders the figure inside the Tk window
reDraw.canvas.show()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)

# Entry boxes for the two pre-pruning parameters, pre-filled with defaults.
Label(root, text = 'Plot Place Holder').grid(row = 0, columnspan=3)
Label(root, text = 'tolN').grid(row = 1, column=0)
tolNentry = Entry(root)
tolNentry.grid(row = 1, column = 1)
tolNentry.insert(0, '10')
Label(root, text = 'tolS').grid(row = 2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row = 2, column = 1)
tolSentry.insert(0, '1.0')
Button(root, text = 'reDraw', command = drawNewTree).grid(row=1, column=2,rowspan=3)  # clicking reDraw invokes drawNewTree

# Checkbox toggling between regression tree (off) and model tree (on).
chkBtnVar = IntVar()
chkBtn = Checkbutton(root, text = 'Model Tree', variable = chkBtnVar)
chkBtn.grid(row = 3, column = 0, columnspan = 2)

# Load the training data and an evenly spaced test grid over its x-range,
# then draw the initial plot with the default parameters.
reDraw.rawDat = mat(regCART.loadDataSet('sine.txt'))
reDraw.testDat = arange(min(reDraw.rawDat[:, 0]), max(reDraw.rawDat[:, 0]), 0.01)
reDraw(1.0, 10)
root.mainloop()

总结

  1. 数据集中经常包含一些复杂的相互关系,使得输入数据和目标变量之间呈现非线性关系。对这些复杂关系建模,一种可行的方式是使用树对预测值分段,包括分段常数或分段直线。一般采用树结构来对这种数据建模。相应地,若叶节点使用的模型是分段常数,则称为回归树,若叶节点使用的模型是线性回归方程则称为模型树。
  2. CART算法可以用于构建二元树并处理离散型或者连续型数据的切分。若使用不同的误差准则,就可以通过CART算法构建模型树和回归树。该算法构建出的树会倾向于对数据过拟合。一棵过拟合的树常常十分复杂,剪枝技术的出现就是为了解决这个问题。两种剪枝方法分别是预剪枝和后剪枝,预剪枝更有效但需要用户定义一些参数。

你可能感兴趣的:(机器学习实战)