In practice, datasets often contain complex interactions that make the relationship between the input data and the target variable nonlinear. One feasible way to model such relationships is to use a tree to make piecewise predictions, either piecewise constant or piecewise linear: the tree structure partitions the data, and at each leaf node we either take the mean of the samples in that leaf to build a regression tree, or fit a linear model to them to build a model tree.
Below, we refer to CART-based regression trees and model trees collectively as tree regression.
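To make the two leaf types concrete, here is a minimal sketch on made-up numbers (for illustration only; it is not the code used later): a regression-tree leaf stores the mean of its targets, while a model-tree leaf stores a least-squares fit.

import numpy as np

# Made-up data falling into one leaf: a single feature x and a target y
x = np.array([1.0, 2.0, 3.0, 4.0])
y = np.array([1.1, 1.9, 3.2, 3.9])

# Regression-tree leaf: predict the mean of the targets in the leaf
reg_leaf_prediction = y.mean()

# Model-tree leaf: fit y = w0 + w1 * x by least squares and predict with it
X = np.column_stack([np.ones_like(x), x])
w, _, _, _ = np.linalg.lstsq(X, y, rcond=None)
model_leaf_predictions = X @ w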
1.1 Compared with the ID3 decision tree introduced earlier, tree regression based on binary splits does not split the data too quickly (each split only cuts the data in two), and it can handle continuous features.
1.2 Advantage: it can model complex, nonlinear data.
1.3 Disadvantage: the results are not as easy to interpret as linear regression.
1.4 Model trees are more interpretable than regression trees; comparatively, model trees also tend to achieve higher prediction accuracy.
For model trees, regression trees, and the linear regression discussed earlier, a fairly objective way to compare them is to compute the correlation coefficient between the predicted and actual values, whose square is the R^2 value.
This only requires a call to NumPy's corrcoef(yHat, y, rowvar=0), where yHat holds the model's predictions and y the actual values of the target variable.
The closer the R^2 value is to 1.0, the better the predictive performance.
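As a minimal sketch of this comparison (the prediction and target values below are made up for illustration):

from numpy import corrcoef, mat

# Made-up predictions and actual target values, stored as column vectors
yHat = mat([[1.1], [1.9], [3.2], [3.9]])
y = mat([[1.0], [2.0], [3.0], [4.0]])

# With rowvar=0 each column is treated as one variable, so the off-diagonal
# entry of the correlation matrix is the correlation between yHat and y;
# squaring it gives the R^2 value
R = corrcoef(yHat, y, rowvar=0)[0, 1]
print(R ** 2)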
Pseudocode for chooseBestSplit():

If the dataset contains only one distinct value of the target variable:
    Stop splitting and turn the dataset into a leaf node
For every feature:
    For every value of that feature:
        Split the dataset into two subsets
        Compute the total error of the two subsets after the split
        If this total error is smaller than the current minimum error:
            Record this split as the best split and update the minimum error
If the drop from the dataset's error to the minimum error does not reach the tolerated error reduction:
    Stop splitting and turn the dataset into a leaf node
If either subset produced by the split has fewer samples than the allowed minimum:
    Stop splitting and turn the dataset into a leaf node
Return the feature and split value of the recorded best split
Pseudocode for createTree():

Call chooseBestSplit() to find the best feature to split on
If the node cannot be split further, i.e. there is no feature left to split on:
    Store the node as a leaf node
Perform the binary split
Call createTree() on the right subtree
Call createTree() on the left subtree
Pseudocode for prune():

Split the test data according to the tree built earlier:
    If either subset is not a leaf node but a subtree:
        Call prune() on that subset
    Compute the error of keeping the current two-way split:
        i.e. the sum of the errors on the two leaf nodes
    Compute the error after merging the two leaf nodes:
        i.e. the error of the single node obtained by replacing the split node with the mean of the two leaf values
    If merging lowers the error, merge the two leaf nodes
from numpy import *

def loadDataSet(fileName):
    # load a tab-separated file into a list of float rows;
    # callers wrap the result with mat() to get a matrix
    dataMat = []
    with open(fileName) as fr:
        for line in fr.readlines():
            currLine = line.strip().split('\t')
            fltLine = list(map(float, currLine))
            dataMat.append(fltLine)
    return dataMat

### Preparing for creating the tree
# regLeaf/modelLeaf build the leaf nodes; regErr/modelErr measure the error
# (the "chaos" or non-uniformity) of the data falling into a node
def regLeaf(dataSet):
    # regression-tree leaf: the mean of the target values
    return mean(dataSet[:, -1])

def regErr(dataSet):
    # total squared error = variance * number of samples
    return shape(dataSet)[0] * var(dataSet[:, -1])

def linearSolve(dataSet):
    # ordinary least squares on the node's data; the first column of X
    # is all ones and serves as the intercept term
    N, n = shape(dataSet)
    X = mat(ones((N, n)))
    X[:, 1:n] = dataSet[:, 0:n-1]
    Y = dataSet[:, -1]
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\n'
                        'try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y

def modelLeaf(dataSet):
    # model-tree leaf: the fitted linear-model weights
    ws, X, Y = linearSolve(dataSet)
    return ws

def modelErr(dataSet):
    # squared error of the linear model on the node's data
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat, 2))

### Creating the tree
def binSplitDataSet(dataSet, feature, value):
    # binary split on one feature: rows with feature > value on one side,
    # the rest on the other
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1

def chooseBestSplit(dataSet, leafType, errType, ops):
    tolS = ops[0]   # tolerated drop in error
    tolN = ops[1]   # minimum number of samples in a split
    # all target values are equal: make a leaf
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    N, n = shape(dataSet)
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n - 1):
        for splitVal in set(dataSet[:, featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # error reduction too small: make a leaf
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # either split too small: make a leaf
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue

def createTree(dataSet, leafType, errType, ops):
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None:
        return val   # no further split: val is the leaf value
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

### Post-pruning
def isTree(obj):
    return (type(obj).__name__ == 'dict')

def getMean(tree):
    # collapse a subtree into the mean of its leaves
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    return (tree['right'] + tree['left']) / 2.0

def prune(tree, testData):
    # no test data falls into this subtree: collapse it
    if shape(testData)[0] == 0:
        return getMean(tree)
    if isTree(tree['right']) or isTree(tree['left']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)
    # both children are leaves: merge them if that lowers the test error
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
                     sum(power(rSet[:, -1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0
        errMerge = sum(power(testData[:, -1] - treeMean, 2))
        if errMerge < errNoMerge:
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree

### Predicting
def regTreeEval(model, inData):
    # a regression-tree leaf stores a single float
    return float(model)

def modelTreeEval(model, inData):
    # a model-tree leaf stores the weights; prepend a 1 for the intercept
    n = shape(inData)[1]
    X = mat(ones((1, n + 1)))
    X[:, 1:n+1] = inData
    return float(X * model)

def treeForecast(tree, inData, treeEval):
    if not isTree(tree):
        return treeEval(tree, inData)
    # inData is a 1 x n row matrix; follow the split on the stored feature
    if inData[0, tree['spInd']] > tree['spVal']:
        if isTree(tree['left']):
            return treeForecast(tree['left'], inData, treeEval)
        else:
            return treeEval(tree['left'], inData)
    else:
        if isTree(tree['right']):
            return treeForecast(tree['right'], inData, treeEval)
        else:
            return treeEval(tree['right'], inData)

def createForecast(tree, testData, treeEval):
    M = len(testData)
    yHat = mat(zeros((M, 1)))
    for ii in range(M):
        yHat[ii, 0] = treeForecast(tree, mat(testData[ii]), treeEval)
    return yHat
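Putting the pieces together, here is a minimal usage sketch that assumes the code above has been run; the file names 'trainSet.txt' and 'testSet.txt' and the ops value (1, 20) are placeholders, and both files are expected to be tab-separated with the target variable in the last column.

# Hypothetical file names; the last column of each file is the target variable
trainMat = mat(loadDataSet('trainSet.txt'))
testMat = mat(loadDataSet('testSet.txt'))

# Regression tree: mean-valued leaves
regTree = createTree(trainMat, regLeaf, regErr, (1, 20))
yHatReg = createForecast(regTree, testMat[:, 0:-1], regTreeEval)

# Model tree: linear-model leaves
mdlTree = createTree(trainMat, modelLeaf, modelErr, (1, 20))
yHatMdl = createForecast(mdlTree, testMat[:, 0:-1], modelTreeEval)

# Compare the two with the correlation coefficient discussed above
print(corrcoef(yHatReg, testMat[:, -1], rowvar=0)[0, 1])
print(corrcoef(yHatMdl, testMat[:, -1], rowvar=0)[0, 1])

# Optionally post-prune the regression tree on the held-out data
# prunedTree = prune(regTree, testMat)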