cart,分类和回归树算法。
cart既可以用来构建分类决策树,也可以用来构建回归树、模型树。
用树对数据建模,把叶子节点简单设定为常数值,构成回归树。如果把叶子节点设定为分段线性函数,即构成模型树。
cart使用二元切分法来处理连续变量。所以可以固定树的节点,每个节点由4个固定属性:待切分的feature、待切分的feature value、右子树、左子树。
注:createTree考虑类别>=3时候的代码可参考博客:
http://blog.csdn.net/wzmsltw/article/details/51057311
中createTree函数。
*找到最佳的待切分feature、value:
如果该节点不能再分,将该节点存为叶节点
执行二元切分
在右子树执行createTree方法
在左子树执行createTree方法*
def createTree(dataSet, leafType, errType, cond=(1,4)):
''' 创建回归树/模型树。
:param dataSet:
:param leafType:
:param errType:
:param cond: 预剪枝条件
:return:
'''
feature, value = chooseBestSplit(dataSet, leafType, errType, cond)
if feature == None:
return value
retTree = {}
retTree['spInd'] = feature
retTree['spVal'] = value
lSet,rSet = binSplitDataSet(dataSet, feature, value)
retTree['left'] = createTree(lSet, leafType, errType, cond)
retTree['right'] = createTree(rSet, leafType, errType, cond)
return retTree
def binSplitDataSet(dataSet, feature, value):
''' 根据属性feature的特定value划分数据集
:param dataSet:
:param feature:
:param value:
:return:
'''
mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:]
mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]
return mat0,mat1
对数据的复杂关系建模,我们已决定借用树结构帮助切分数据,那么如何实现数据的切分呢?怎么才知道是否已经切分充分了呢?
这些问题的答案取决于叶节点的建模方式。
回归树假设叶节点是常数值,这种策略认为数据中的复杂关系可以用树结构来概括。
为成功构建以分段常数为叶节点的树,需要度量出数据的一致性。决策树使用树进行分类,会在给定节点时计算数据的混乱度。那么如何计算连续性数值的混乱度?事实上,在数据集上计算混乱度很简单,使用平方误差的总值即总方差。总方差可以通过方差乘以数据中样本点的个数得到。
选择最佳分裂属性的伪代码:
对每个特征:
对每个特征值:
将数据集切分成2份
计算切分的误差
如果当前误差小于最小误差,将当前切分设定为最佳切分并更新最小误差
返回最佳切分的特征和阈值
chooseBestSplit的代码:
def chooseBestSplit(dataSet, leafType, errType, cond=(1,4)):
''' 选择最佳待切分feature以及value
:param dataSet:
:param leafType:叶子节点的构建方法
:param errType: 总均方差计算方法
:param cond:
:return:
'''
tolS = cond[0]
tolN = cond[1]
# dataSet中都属于同一类别,直接返回
if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
return None, leafType(dataSet)
m,n = shape(dataSet)
S = errType(dataSet)
bestS = inf; bestIndex = 0; bestValue = 0
for featIndex in range(n-1):
for splitVal in set(dataSet[:,featIndex].T.A[0]):
mat0,mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
continue;
newS = errType(mat0) + errType(mat1)
if newS < bestS:
bestIndex = featIndex
bestValue = splitVal
bestS = newS
if (S - bestS) < tolS:
return None, leafType(dataSet)
return bestIndex, bestValue
对于回归树,leafType和errType的方法为:
#取均值
def regLeaf(dataSet):
return mean(dataSet[:, -1])
#总方差和
def regErr(dataSet):
return var(dataSet[:, -1]) * shape(dataSet)[0]
对于回归树,叶子节点要是分段线性函数,为了找到最佳切分,对于给定的数据集,应该先用线性模型来进行拟合,然后计算真实的目标值与模型预测值间的差值,然后求这些差值的平方和即得到计算误差。
leafType和errType的方法为:
def linearSolve(dataSet):
''' 对dataSet进行线性拟合
:param dataSet:
:return:
'''
m,n = shape(dataSet)
X = mat(ones((m,n))); Y = mat(ones((m,1)))
X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:, -1]
xTx = X.T * X
if linalg.det(xTx) == 0.0:
raise NameError('This matrix is singular, cannot do inverse,\n\
try increasing the second value of ops')
ws = xTx.I * X.T * Y
return ws,X,Y
def modelLeaf(dataSet):
ws,X,Y = linearSolve(dataSet)
return ws
def modelErr(dataSet):
ws, X, Y = linearSolve(dataSet)
yHat = X * ws
return sum(power(Y - yHat, 2))
剪枝有预剪枝和后剪枝。
chooseBestSplit函数中cond=(1,4)即为预剪枝的条件:
第一个元素表示划分前后方差和之差的阈值,第二个是基于某个特征和value划分后的数据集中样本个数阈值。
def isTree(obj):
return (type(obj).__name__ == 'dict')
def getMean(tree):
if isTree(tree['right']):
tree['right'] = getMean(tree['right'])
if isTree(tree['left']):
tree['left'] = getMean(tree['left'])
return (tree['left'] + tree['right'])/2.0
def prune(tree,testData):
''' 降低错误率剪枝
:param tree:
:param testData:
:return:
'''
if shape(testData)[0] == 0:
return getMean(tree)
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
if (isTree(tree['left'])):
tree['left'] = prune(tree['left'], lSet)
if (isTree(tree['right'])):
tree['right'] = prune(tree['right'], rSet)
# 如果当前节点的left和right节点都是叶子节点
if (not isTree(tree['right'])) and (not isTree(tree['left'])):
errorNoMerge = sum(power(lSet[:,-1] - tree['left'], 2)) + sum(power(rSet[:, -1] - tree['right'], 2))
treeMean = (tree['left'] + tree['right'])/2.0
errorMerge = sum(power(testData[:, -1] - treeMean, 2))
# 如果合并后的误差比不合并的误差小,则返回合并后的叶子节点
if errorMerge < errorNoMerge:
print("merging")
return treeMean
else:
return tree
else:
return tree
回归树和模型树构造完成后,如何使用树模型进行数据预测呢?
思路如下:
从树根节点开始,根据属性feature和feature value判断沿左子树还是右子树向下走。直到叶子节点,如果是回归树直接叶子节点的值即为预测值,如果是回归树,使用叶子节点的线性模型计算得到预测值。
代码如下:
# 回归树预测值函数
def regTreeEval(model, inData):
return float(model)
# 模型树预测值函数
def modelTreeEval(model, inData):
n = shape(inData)[1]
X = mat(ones((1, n+1)))
X[:,1:n+1] = inData
return float(X * model)
def treeForecast(tree, inData, modelEval=regTreeEval):
''' 树预测函数
:param tree:
:param inData:
:param modelEval:
:return:
'''
if not isTree(tree):
return modelEval(tree, inData)
if inData[tree['spInd']] > tree['spVal']:
if isTree(tree['left']):
return treeForecast(tree['left'], inData, modelEval)
else:
return modelEval(tree['left'], inData)
else:
if isTree(tree['right']):
return treeForecast(tree['right'], inData, modelEval)
else:
return modelEval(tree['right'], inData)
# 预测值得一个工具函数
def createForecast(tree, testData, modelEval=regTreeEval):
m = len(testData)
yHat = mat(zeros((m,1)))
for i in range(m):
yHat[i,0] = treeForecast(tree, mat(testData[i]), modelEval)
return yHat
参考:
http://blog.csdn.net/wzmsltw/article/details/51057311
完整代码见github:
https://github.com/zhanggw/algorithm/tree/master/machine-learning/CART/cart.py