**决策树是一种贪心算法,它在给定时间内做出最佳选择,但并不关心能否达到全局最优**。
**CART算法做二元切分,使用一部字典来存储树的数据结构,包含4个元素**:
待切分的特征
待切分的特征值
右子树。当不需要切分时,也可以是单个值
左子树。与右子树类似。
"""
createTree()伪代码:
找到最佳的待切分特征:
如果该节点不能再分,将该节点存为叶节点
执行二元切分
在右子树调用createTree()方法
在左子树调用createtree()方法
"""
import numpy as np
def loadDataSet(filename):
dataMat = []
fr = open(filename)
for line in fr.readlines():
curline = line.strip().split('\t')
fltline = list(map(float,curline)) #将每行映射成浮点数
print(fltline)
dataMat.append(fltline)
return dataMat
def binSplitDataSet(dataSet,feature,value):
"""根据给定值进行数据的划分
dataSet:数据集合feature:待切分的特征value:该特征的某个值
返回划分好的两个分支"""
mat0 = dataSet[np.nonzero(dataSet[:,feature] > value)[0],:]
mat1 = dataSet[np.nonzero(dataSet[:,feature] <= value)[0],:]
return mat0,mat1
#回归树的切分函数
def regLeaf(dataSet):
"""负责生成叶节点。当chooseBestSplit()函数确定不再对数据进行划分时,将调用该regLeaf()函数来得到叶节点的模型。在回归树中,该模型其实就是目标变量的均值。"""
return np.mean(dataSet[:,-1])
def regErr(dataSet):
"""该函数在给定数据上计算目标变量的平方误差。因为需要返回的是总方差,所以要用均方差乘以数据集中样本的个数"""
return np.var(dataSet[:,-1])*np.shape(dataSet)[0]
def chooseBestSplit(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
"""该函数的目标是找到数据集划分的最佳位置。它遍历所有的特征及其可能的取值来找到使误差最小化的切分阈值。该函数的伪代码:
对每个特征:
对每个特征值:
将数据集切分成两份
计算切分的误差
如果当前误差小于当前最小误差,那么将当前切分设定为最佳切分并更新最小误差
返回最佳切分的特征和阈值
return bestIndex,bestValue
"""
tolS = ops[0] #容许的误差下降值
tolN = ops[1] #切分的最少样本数
#统计不同剩余特征值的数目,如果所有值相等则退出
if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
return None,leafType(dataSet)
m,n = np.shape(dataSet)
S = errType(dataSet)
bestS = np.inf
bestIndex = 0
bestValue = 0
for featIndex in range(n-1):
# print(type(dataSet[:,featIndex].tolist()))
# print(type(dataSet[:,featIndex].T.A.tolist()))
# print(dataSet[:,featIndex].tolist()[0]) 结果是一个只包含一个元素的列表
for splitVal in set(dataSet[:,featIndex].T.tolist()[0]):
mat0,mat1 = binSplitDataSet(dataSet,featIndex,splitVal)
if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
continue
newS = errType(mat0) + errType(mat1)
if newS < bestS:
bestIndex = featIndex
bestValue = splitVal
bestS = newS
#如果误差减少不大则退出
if S - bestS < tolS:
print(leafType(dataSet)) #这是什么用法?leafType=regLeaf,传递了regLeaf()函数的引用,故可通过leafType()方式调用regLeaf()
return None,leafType(dataSet)
mat0,mat1 = binSplitDataSet(dataSet,bestIndex,bestValue)
#如果切分出的数据集很小则退出
if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
return None,leafType(dataSet)
return bestIndex,bestValue
def createTree(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
"""创建CART树
dataSet:数据集
leafType=regLeaf:建立叶节点的函数
errType=regErr:误差计算函数
ops=(1,4):包含树构建所需其他参数的元组\n
return retTree"""
feat,val = chooseBestSplit(dataSet,leafType,errType,ops)
if feat == None:
return val
retTree = {}
retTree['spInd'] = feat
retTree['spVal'] = val
lSet,rSet = binSplitDataSet(dataSet,feat,val)
retTree['left'] = createTree(lSet,leafType,errType,ops)
retTree['right'] = createTree(rSet,leafType,errType,ops)
return retTree
if __name__ == '__main__':
myData = loadDataSet('./M_09_树回归/ex0.txt')
myMat = np.mat(myData)
RegTree = createTree(myMat,ops=(0,1)) #降低阈值,创建所有可能中最大的树
print(RegTree)
#导入测试数据
myDataTest = loadDataSet('./M_09_树回归/ex2test.txt')
myMatTest = np.mat(myDataTest)
treePruned = prune(RegTree,myMatTest)
print(treePruned)
#从结果可以看到,此例中后剪枝的效果没有预剪枝有效。一般地,为了寻求最佳模型可以同时使用两种剪枝技术
为成功构建以分段常数为叶节点的树,需要度量出数据的一致性。怎样计算连续型数值的混乱度呢?
首先计算所有数据的均值,然后计算每条数据的值到均值的差值。为了对正负差值同等对待,一般使用绝对值或平方值代替上述差值。
一棵树如果节点过多,表明该模型可能对数据进行了"过拟合"。那么,如何判断是否发生了过拟合?可以通过交叉验证发现。
通过降低决策树的复杂度来避免过拟合的过程称为剪枝(pruning)。在chooseBestSplit()中的提前终止条件,实际上是在进行一种所谓的预剪枝(prepruning)操纵。另一种形式的剪枝需要使用测试集和训练集,称为后剪枝。
树构建算法其实对输入的参数tolS和tolN非常敏感,通过调整不同的值可以得到不同的模型结果
使用后剪枝方法需要将数据集分为测试集和训练集。首先指定参数,使得构建的树足够大、足够复杂,便于剪枝。接下来从上而下找到叶节点,用测试集判断将这些叶节点合并是否能降低测试误差,如果是的话就合并。
#回归树剪枝函数
def isTree(obj):
"""测试输入变量是否是一棵树,返回布尔类型(用于判断是否属于叶子节点)"""
return (type(obj).__name__ == 'dict')
def getMean(tree):
"""递归函数:从上往下遍历,如果找到两个叶节点则计算它们的平均值"""
if isTree(tree['right']):
tree['right'] = getMean(tree['right'])
if isTree(tree['left']):
tree['left'] = getMean(tree['left'])
return (tree['left'] + tree['right'])/2.0
def prune(tree,testData):
"""tree:待剪枝的树
testData:剪枝所需的测试数据
基于已有的树切分测试数据:
如果存在任一子集是一棵树,则在该子集递归剪枝过程
计算将当前两个叶节点合并后的误差
计算不合并的误差
如果合并会降低误差的话,就将叶节点合并"""
#没有测试数据则对树进行塌陷处理
if np.shape(testData)[0] == 0 :
return getMean(tree)
#一旦测试集非空,则反复递归调用函数对数据进行切分
if (isTree(tree['right']) or isTree(tree['left'])):
lSet,rSet = binSplitDataSet(testData,tree['spInd'],tree['spVal'])
if isTree(tree['left']):
tree['left'] = prune(tree['left'],lSet)
if isTree(tree['right']):
tree['right'] = prune(tree['right'],rSet)
#如果两个分支不再是子树,那么就可以进行合并
if not isTree(tree['left']) and not isTree(tree['right']):
lSet,rSet = binSplitDataSet(testData,tree['spInd'],tree['spVal'])
errorNoMerge = sum(np.power(lSet[:,-1] - tree['left'],2)) + sum(np.power(rSet[:,-1] - tree['right'],2)) #power(A,B),求A的B次方
treeMean = (tree['left'] + tree['right'])/2.0
errorMerge = sum(np.power(testData[:,-1]-treeMean,2))
if errorMerge < errorNoMerge:
print('Merging')
return treeMean
else:
return tree
else:
return tree
用树来对数据建模,除了把叶节点简单地设定为常数值之外,还有一种方法是把节点设定为分段线性函数。这里所谓的分段线性是指模型由多个线性片段组成。
为了找到最佳切分,对于给定数据集,应该先用线性的模型来对它进行拟合,然后计算真实的目标值与模型预测值间的差值。最后将这些差值的平方求和就得到了所需的误差。
#模型树的叶节点生成函数
#模型书的叶节点是线性回归模型
def linearSolve(dataSet):
"""将数据集格式化成目标向量Y和自变量X,X和Y用于执行简单的线性回归"""
m,n = np.shape(dataSet)
X = np.mat(np.ones((m,n)))
Y = np.mat(np.ones((m,1)))
X[:,1:n] = dataSet[:,0:n-1]
Y = dataSet[:,-1]
xTx = X.T*X
if np.linalg.det(xTx) == 0.0:
raise NameError('This matrix is singular,cannotdo inverse,try increasing the second value of ops')
ws = xTx.T*(X.T*Y)
return ws,X,Y
def modelLeaf(dataSet):
"""当数据不再需要切分的时候它负责生成叶节点的模型"""
ws,X,Y = linearSolve(dataSet)
return ws
def modelErr(dataSet):
"""在给定的数据集上计算误差,返回yHat和Y之间的平方误差"""
ws,X,Y = linearSolve(dataSet)
yHat = X*ws
return sum(np.power(Y-yHat,2))
#用树回归进行预测的代码
def regTreeEval(model,inDat):
return float(model)
def modelTreeEval(model,inDat):
n = np.shape(inDat)[1]
X = np.mat(np.ones((1,n+1)))
X[:,1:n+1] = inDat
return float(X*model)
def treeForeCast(tree,inData,modelEval=regTreeEval):
"""tree:训练好的树模型
inData:要进行预测的数据
modelEval:是对叶节点数据进行预测的函数的引用,主要对数据格式进行调整以适应不同的树模型"""
if not isTree(tree):
return modelEval(tree,inData)
if inData[tree['spInd']] > tree['spVal']:
if isTree(tree['left']):
return treeForeCast(tree['left'],inData,modelEval)
else:
return modelEval(tree['left'],inData)
else:
if isTree(tree['right']):
return treeForeCast(tree['right'],inData,modelEval)
else:
return modelEval(tree['right'],inData)
def createForeCast(tree,testData,modelEval=regTreeEval):
m = len(testData)
yHat = np.mat(np.zeros((m,1)))
for i in range(m):
yHat[i,0] = treeForeCast(tree,np.mat(testData[i]),modelEval)
return yHat
使用corrcoef()分析哪个模型是最优的
Matplotlib的构建程序包含一个前端,也就是面向用户的一些代码,如plot()和scatter()方法等。事实上,它同时创建了一个后端,用于实现绘图和不同应用之间接口。通过改变后端可以将图像绘制在PNG、PDF等格式的文件上。
下面将设置后端为TkAgg(Agg是一个C++的库,可以从图像创建位图)。TkAgg可以在所选GUI框架上调用Agg,把Agg呈现在画布上。
import numpy as np
from tkinter import *
import M_0902_CART as regTree
import matplotlib
matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
def reDraw(tolS,tolN):
"""绘制树"""
reDraw.f.clf()
reDraw.a = reDraw.f.add_subplot(111)
if chkBtnVar.get():#检查复选框是否选中,选择不同的模型进行训练
if tolN < 2:
tolN = 2
myTree = regTree.createTree(reDraw.rawDat,regTree.modelLeaf,regTree.modelErr,(tolS,tolN))
yHat = regTree.createForeCast(myTree,reDraw.testDat,regTree.modelTreeEval)
else:
myTree = regTree.createTree(reDraw.rawDat,ops=(tolS,tolN))
yHat = regTree.createForeCast(myTree,reDraw.testDat)
print(type(reDraw.rawDat[:,0]))
print(reDraw.rawDat[:,0])
reDraw.a.scatter(np.array(reDraw.rawDat[:,0]),np.array(reDraw.rawDat[:,1]),s=5)
reDraw.a.plot(reDraw.testDat,yHat,linewidth=2.0)
reDraw.canvas.draw()
def getInputs():
try:
tolN = int(tolNentry.get())
except:
tolN =10
print('enter Integer for tolN')
tolNentry.delete(0,END)
tolNentry.insert(0,'10')
try:
tolS = float(tolSentry.get())
except:
tolS = 1.0
print('enter float for tolS')
tolSentry.delete(0,END)
tolSentry.insert(0,'1.0')
return tolN,tolS
def drawNewTree():
tolN,tolS = getInputs()
reDraw(tolS,tolN)
root = Tk()
reDraw.f = Figure(figsize=(5,4), dpi=100)
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
reDraw.canvas.draw()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)
Label(root,text='tolN').grid(row=1,column=0)
tolNentry = Entry(root)
tolNentry.grid(row=1,column=1)
tolNentry.insert(0,'10')
Label(root,text='tolS').grid(row=2,column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2,column=1)
tolSentry.insert(0,'1.0')
Button(root,text='Redraw',command=drawNewTree).grid(row=1,column=2,rowspan=3)
chkBtnVar = IntVar() #读取checkbutton的状态需要创建一个变量
chkBtn = Checkbutton(root,text='Model Tree',variable=chkBtnVar)
chkBtn.grid(row=3,column=0,columnspan=2)
#初始化一些与reDraw()关联的全局变量
reDraw.rawDat = np.mat(regTree.loadDataSet('./M_09_树回归/sine.txt'))
reDraw.testDat = np.arange(min(reDraw.rawDat[:,0]),max(reDraw.rawDat[:,0]),0.01)
reDraw(1.0,10)
root.mainloop()
若叶节点使用的模型是分段常数则称为回归树,若叶节点使用的模型是线性回归方程则称为模型树。
CART算法可以用于构建二元树并处理离散型或连续型数据的切分。若使用不同的误差准则,就可以通过CART算法构建模型树和回归树。