# File: regCART.py
#coding=utf-8
from numpy import *
def loadDataSet(fileName):
    """Read a tab-separated file of numbers into a list of float rows.

    Args:
        fileName: path of a text file whose lines hold tab-separated
            numeric fields.

    Returns:
        A list with one entry per non-blank line; each entry is a list of
        floats in the order the fields appear on that line.

    Raises:
        ValueError: if a field cannot be parsed by ``float``.
    """
    dataMat = []
    # The context manager guarantees the handle is closed even on a parse error.
    with open(fileName) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # Blank (e.g. trailing) lines carry no sample; skip them.
                continue
            # Build a real list: on Python 3 ``map`` would yield a lazy iterator.
            dataMat.append([float(field) for field in stripped.split('\t')])
    return dataMat
def binSplitDataSet(dataSet, feature, value):
    """Split the rows of ``dataSet`` in two by thresholding one column.

    Args:
        dataSet: 2-D numpy matrix/array of samples (one row per sample).
        feature: column index to compare.
        value: threshold.

    Returns:
        A pair ``(mt0, mt1)``: rows whose ``feature`` entry is strictly
        greater than ``value``, and rows whose entry is less than or equal
        to it.
    """
    column = dataSet[:, feature]
    rows_above = nonzero(column > value)[0]
    rows_not_above = nonzero(column <= value)[0]
    return dataSet[rows_above, :], dataSet[rows_not_above, :]
# --------- Split helpers for regression trees ---------
def regLeaf(dataSet):
    """Build a regression-tree leaf: the mean of the last column.

    createTree() stores this value as the leaf when chooseBestSplit()
    decides that the data should not be split any further.
    """
    targets = dataSet[:, -1]
    return mean(targets)
def regErr(dataSet):
    """Error estimate: variance of the last column times the row count.

    This equals the total squared deviation of the last column from its
    mean.
    """
    targets = dataSet[:, -1]
    sampleCount = shape(dataSet)[0]
    return var(targets) * sampleCount
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Find the best binary split of ``dataSet``.

    Every distinct value of every feature column (all columns but the
    last) is tried as a threshold; the split with the lowest combined
    error wins.

    Args:
        dataSet: numpy matrix whose last column is the target.
        leafType: callable producing a leaf model from a data set.
        errType: callable producing an error measure for a data set.
        ops: ``(tolS, tolN)`` - the minimum error reduction required to
            accept a split, and the minimum number of rows allowed on
            either side of a split.

    Returns:
        ``(featureIndex, splitValue)`` for the chosen split, or
        ``(None, leafModel)`` when no split should be made.
    """
    tolS, tolN = ops[0], ops[1]

    # All target values identical: nothing to gain from splitting.
    distinctTargets = set(dataSet[:, -1].T.tolist()[0])
    if len(distinctTargets) == 1:
        return None, leafType(dataSet)

    _, n = shape(dataSet)
    S = errType(dataSet)  # error of the unsplit data
    bestS, bestIndex, bestValue = inf, 0, 0

    for featIndex in range(n - 1):
        candidates = set((dataSet[:, featIndex].T.A.tolist())[0])
        for splitVal in candidates:
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            # Only consider splits that leave at least tolN rows on each side.
            if shape(mat0)[0] >= tolN and shape(mat1)[0] >= tolN:
                newS = errType(mat0) + errType(mat1)
                if newS < bestS:
                    bestS, bestIndex, bestValue = newS, featIndex, splitVal

    # Error reduction too small to be worth a split.
    if (S - bestS) < tolS:
        return None, leafType(dataSet)

    # Re-check the size constraint on the winning split.
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if shape(mat0)[0] < tolN or shape(mat1)[0] < tolN:
        return None, leafType(dataSet)
    return bestIndex, bestValue
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Recursively build a tree from ``dataSet``.

    Args:
        dataSet: numpy matrix whose last column is the target.
        leafType: callable producing a leaf model from a data set.
        errType: callable producing an error measure for a data set.
        ops: ``(tolS, tolN)`` tuple forwarded to chooseBestSplit().

    Returns:
        A leaf model when chooseBestSplit() reports no split; otherwise a
        dict with keys ``'spInd'`` (split feature index), ``'spVal'``
        (split threshold), ``'left'`` (subtree for rows above the
        threshold) and ``'right'`` (subtree for the remaining rows).
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # Identity test: feature index 0 is a valid split and must not be
    # confused with "no split".
    if feat is None:
        return val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    return {
        'spInd': feat,
        'spVal': val,
        'left': createTree(lSet, leafType, errType, ops),
        'right': createTree(rSet, leafType, errType, ops),
    }
def isTree(obj):
    """Return True when ``obj`` is a tree node rather than a leaf.

    A node is recognised by its type being named ``'dict'``.
    """
    typeName = type(obj).__name__
    return typeName == 'dict'
def getMean(tree):
    """Collapse ``tree`` into a single value by averaging its branches.

    Subtrees are collapsed recursively first, and each one is replaced
    in place by its collapsed value, so ``tree`` is mutated.

    Args:
        tree: a tree node dict with ``'left'`` and ``'right'`` entries.

    Returns:
        The average of the (collapsed) left and right values.
    """
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    # isTree must be *called* here; subscripting the function raises TypeError.
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0
def prune(tree, testData):
    """Post-prune ``tree`` using held-out ``testData``.

    Subtrees are pruned first. When both children of a node are leaves,
    the node is replaced by the average of the two leaves if that lowers
    the squared error on ``testData``.

    Args:
        tree: tree node dict to prune; its subtrees are modified in place.
        testData: numpy matrix of test samples, target in the last column.

    Returns:
        The pruned tree, or a single value when the node was merged or
        when there is no test data to evaluate it.
    """
    # No test data reaches this node: collapse it to its mean.
    if shape(testData)[0] == 0:
        return getMean(tree)

    # Both the recursion and the merge test need the same split of the
    # test data, so compute it once.
    lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)

    # Merging is only considered once both children are leaves.
    if isTree(tree['left']) or isTree(tree['right']):
        return tree

    errorNomerge = (sum(power(lSet[:, -1] - tree['left'], 2)) +
                    sum(power(rSet[:, -1] - tree['right'], 2)))
    treeMean = (tree['left'] + tree['right']) / 2.0
    errorMerge = sum(power(testData[:, -1] - treeMean, 2))
    if errorMerge < errorNomerge:
        # Parenthesised form prints identically on Python 2 and 3.
        print("merging")
        return treeMean
    return tree
def linearSolve(dataSet):
    """Fit an ordinary-least-squares linear model to ``dataSet``.

    The last column is the target; every other column is a feature. A
    column of ones is prepended to the features as the intercept term.

    Args:
        dataSet: 2-D numeric data (numpy matrix, or anything
            ``asmatrix`` accepts).

    Returns:
        ``(ws, X, Y)``: the weight column vector (intercept first), the
        design matrix with its leading column of ones, and the target
        column.

    Raises:
        NameError: if ``X.T * X`` has a determinant of exactly zero and
            so cannot be inverted.
    """
    # asmatrix leaves matrix input untouched and lets plain 2-D arrays
    # take part in the matrix products below.
    dataSet = asmatrix(dataSet)
    m, n = shape(dataSet)
    X = asmatrix(ones((m, n)))         # column 0 stays all ones (intercept)
    X[:, 1:n] = dataSet[:, 0:n - 1]
    Y = dataSet[:, -1]
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        raise NameError('this matrix is singular, cannot do inverse, \n try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y
def modelLeaf(dataSet):
    """Build a model-tree leaf: the linear weights fitted to ``dataSet``."""
    return linearSolve(dataSet)[0]
def modelErr(dataSet):
    """Return the squared error of the linear fit to ``dataSet``.

    The fit comes from linearSolve(); the result is the sum of squared
    differences between the targets and the fitted values.
    """
    ws, X, Y = linearSolve(dataSet)
    residual = Y - X * ws
    return sum(power(residual, 2))
# Prediction helpers for the trees
def regTreeEval(model, inDat):
    """Evaluate a regression-tree leaf: the leaf value as a float.

    ``inDat`` is accepted only so the signature matches modelTreeEval();
    it is not used.
    """
    prediction = float(model)
    return prediction
def modelTreeEval(model, inDat):
    """Evaluate a model-tree leaf on one sample.

    Args:
        model: weight column vector with the intercept first, as returned
            by linearSolve().
        inDat: a single sample as a 1 x n matrix of feature values.

    Returns:
        The linear prediction ``[1, inDat] * model`` as a float.
    """
    n = shape(inDat)[1]
    X = asmatrix(ones((1, n + 1)))     # column 0 stays 1 (intercept)
    X[:, 1:n + 1] = inDat
    # .item() extracts the single element explicitly; calling float() on
    # a 1x1 matrix relies on deprecated array-to-scalar conversion.
    return float((X * model).item())
def treeForeCast(tree, inData, modelEval=regTreeEval):
    """Predict a single value for one sample by walking ``tree``.

    The tree is descended from the top: at each node the sample's split
    feature is compared with the node's threshold, and the walk continues
    left when it is greater, right otherwise. On reaching a leaf,
    ``modelEval(leaf, inData)`` is returned.

    Args:
        tree: a tree node dict, or a bare leaf model.
        inData: one sample - a scalar, a flat sequence, or a 1 x n
            matrix row of feature values.
        modelEval: callable ``(leafModel, inData)`` producing the
            prediction.

    Returns:
        Whatever ``modelEval`` returns for the leaf that was reached.
    """
    if not isTree(tree):
        return modelEval(tree, inData)
    # Flatten so the feature index addresses a column. Indexing a 1 x n
    # matrix row directly selects along axis 0 and fails for index > 0.
    features = asarray(inData).ravel()
    if features[tree['spInd']] > tree['spVal']:
        branch = tree['left']
    else:
        branch = tree['right']
    # The recursive call evaluates the branch itself when it is a leaf.
    return treeForeCast(branch, inData, modelEval)
def createForeCast(tree, testData, modelEval=regTreeEval):
    """Predict a value for every sample in ``testData``.

    Args:
        tree: tree (or leaf) produced by createTree().
        testData: indexable collection of samples; each entry is wrapped
            in a matrix before being passed to treeForeCast().
        modelEval: leaf evaluation callable forwarded to treeForeCast().

    Returns:
        An m x 1 matrix of predictions, one row per sample.
    """
    m = len(testData)
    yHat = asmatrix(zeros((m, 1)))
    for i, sample in enumerate(testData):
        yHat[i, 0] = treeForeCast(tree, asmatrix(sample), modelEval)
    return yHat
# File: treeExplore.py
#coding=utf-8
from numpy import *
from Tkinter import *
import regCART
import matplotlib
matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
def reDraw(tolS, tolN):
    """Rebuild the tree with the given tolerances and redraw the plot.

    Reads the data, figure and canvas stored as attributes on this
    function (``reDraw.rawDat``, ``reDraw.testDat``, ``reDraw.f``,
    ``reDraw.canvas``) and the module-level ``chkBtnVar`` checkbox
    state, which selects a model tree over a regression tree.

    Args:
        tolS: minimum error reduction required for a split.
        tolN: minimum number of samples on each side of a split.
    """
    reDraw.f.clf()  # clear the figure
    reDraw.a = reDraw.f.add_subplot(111)
    if chkBtnVar.get():
        # Model-tree branch: tolN is clamped to at least 2.
        if tolN < 2:
            tolN = 2
        myTree = regCART.createTree(reDraw.rawDat, regCART.modelLeaf,
                                    regCART.modelErr, (tolS, tolN))
        yHat = regCART.createForeCast(myTree, reDraw.testDat,
                                      regCART.modelTreeEval)
    else:
        myTree = regCART.createTree(reDraw.rawDat, ops=(tolS, tolN))
        yHat = regCART.createForeCast(myTree, reDraw.testDat)
    # Flatten the matrix columns to 1-D arrays; newer matplotlib releases
    # reject 2-D matrix input to scatter().
    xs = asarray(reDraw.rawDat[:, 0]).ravel()
    ys = asarray(reDraw.rawDat[:, 1]).ravel()
    reDraw.a.scatter(xs, ys, s=5)  # raw data points
    reDraw.a.plot(reDraw.testDat, asarray(yHat).ravel(), linewidth=2.0)  # fitted curve
    # draw() replaces FigureCanvasTkAgg.show(), which newer matplotlib removed.
    reDraw.canvas.draw()
def getInputs():
    """Read tolN and tolS from the entry widgets.

    If an entry does not parse, a message is printed, the default is
    used (10 for tolN, 1.0 for tolS) and the entry widget is reset to
    show that default.

    Returns:
        ``(tolN, tolS)`` as ``(int, float)``.
    """
    try:
        tolN = int(tolNentry.get())
    except ValueError:
        # Unparseable text: fall back to the default and reset the field.
        tolN = 10
        print("enter Integer for tolN")
        tolNentry.delete(0, END)
        tolNentry.insert(0, '10')
    try:
        tolS = float(tolSentry.get())
    except ValueError:
        tolS = 1.0
        print("enter Float for tolS")
        tolSentry.delete(0, END)
        tolSentry.insert(0, '1.0')
    return tolN, tolS
def drawNewTree():
    """Button callback: read the current inputs and redraw the tree."""
    minSamples, minErrorDrop = getInputs()
    reDraw(tolS=minErrorDrop, tolN=minSamples)
root = Tk()

# Embed a matplotlib figure in the Tk window.
reDraw.f = Figure(figsize=(5, 4), dpi=100)
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
# draw() replaces FigureCanvasTkAgg.show(), which newer matplotlib removed.
reDraw.canvas.draw()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)
# NOTE(review): this placeholder label shares grid row 0 with the canvas
# above - confirm whether it is still wanted.
Label(root, text = 'Plot Place Holder').grid(row = 0, columnspan=3)

# tolN entry (minimum samples per split side).
Label(root, text = 'tolN').grid(row = 1, column=0)
tolNentry = Entry(root)
tolNentry.grid(row = 1, column = 1)
tolNentry.insert(0, '10')

# tolS entry (minimum error reduction).
Label(root, text = 'tolS').grid(row = 2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row = 2, column = 1)
tolSentry.insert(0, '1.0')

# Clicking the button rebuilds the tree via drawNewTree().
Button(root, text = 'reDraw', command = drawNewTree).grid(row=1, column=2,rowspan=3)

# Checkbox switching between regression tree and model tree.
chkBtnVar = IntVar()
chkBtn = Checkbutton(root, text = 'Model Tree', variable = chkBtnVar)
chkBtn.grid(row = 3, column = 0, columnspan = 2)

# Load the data and build an evenly spaced grid of x values to predict on.
reDraw.rawDat = asmatrix(regCART.loadDataSet('sine.txt'))
xColumn = reDraw.rawDat[:, 0]
# The matrix min()/max() methods return plain scalars for arange().
reDraw.testDat = arange(xColumn.min(), xColumn.max(), 0.01)
reDraw(1.0, 10)
root.mainloop()
# Summary