近日受小C同学的影响,开始慢慢培养写博客的习惯,“开坑之作”(之前那篇请无视~~)就打算谈谈最近研究的CART。鄙人不才,为了写这篇博客参考了不少资料,若写的有不正确的地方,还请各位大牛指正。
套话说完了,正式开始吧。CART全名为分类与回归树,意指该模型可以同时处理分类与回归问题。对于给定的训练数据集,CART通过最小化数据集的GINI系数(分类树)或者基于最小二乘准则最小化输入与输出的总均方误差(回归树)实现机器学习任务,本文首先介绍CART在回归问题中的应用。回归树的生成可分为两步—树的生成和剪枝。
对于给定的训练数据集 T={(x1,y1),(x2,y2),...(xN,yN)} ,回归树希望按照某几个特征对数据集进行递归式划分以形成二叉树,使得划分后的数据集叶子结点的输出尽可能接近训练样本的y值。这个过程主要涉及到分裂数据集的特征选择和树的递归生成。就特征选择而言,若设选择数据集T的j号特征某个分量s作为分割的阈值,将数据集分为 R1={x|xj≤s} , R2={x|xj>s} 两部分,则分割后的数据集与实际y值的均方误差可表示为:
如果生成的CART树枝条太多,容易把数据集中的一些噪声也拟合进去,这时候就需要减去一些枝条,防止CART树出现过拟合。剪枝又分为预剪枝和后剪枝。后剪枝需要一定数据,因此,实际使用CART树时,常常将训练数据分为训练集和剪枝的数据集。
预剪枝通过调整树停止生长的策略,如提前终止树生长(通过调整均方误差下降的最小值实现)等可实现。这种方法不需要给定数据集,但是受到建模者所给参数的影响太大,有较大弊病。
后剪枝的具体做法是,将数据根据训练好的树模型将数据集递归地分割到叶子结点,然后考虑减去叶子和不剪去叶子两种情况下数据集的均方误差值,如果剪枝使得该值变小,则剪之,否则放弃。遍历所有结点,剪去所有冗余的枝条,就实现了后剪枝。
预剪枝和后剪枝在实际实现CART算法时常常结合使用,最大可能地避免树的过拟合。
基于Python的CART回归树模型实现参考了《机器学习实战》一书。使用numpy和pygraphviz绘图包。pygraphviz安装并不是直接pip install就能搞定的,具体安装步骤参见http://www.cnblogs.com/AimeeKing/p/5021675.html
主要包含了树的创建函数createTree、根据特征的分量将数据集分裂的函数splitData、选择最优分裂点的函数chooseSplit、剪枝函数cutBranches、画树函数drawTree等。
splitData函数根据选定的特征的目标分量thres将数据集分为该特征数值大于thres和小于thres的两部分数据。
def splitData(data, feature, thres):
mat0 = data[np.nonzero(data[:, feature] > thres)[0], :]
mat1 = data[np.nonzero(data[:, feature] <= thres)[0], :]
return mat0, mat1
chooseSplit函数遍历所有特征的所有分量,寻找最合适的数据分裂方案
def chooseSplit(data, ops=1):
feature = data[:, 0:-1]
# 获得特征数目以及样本数
minFunc = []
oriVal = float(np.var(data[:, -1]) * len(data[:, -1]))
sampleNum, featureNum = map(int, np.shape(feature))
if len(np.unique(np.array(feature))) == 1: # 剩余样本都一样
return None, np.mean(data[:, -1]) # 既然都一样,拿哪个特征无所谓
for i in range(featureNum):
tempFeature = feature[:, i] # 获得切分数据集
FeatureFunc = []
for thres in tempFeature:
# 拆分标签集
subMat1, subMat2 = splitData(data, i, thres)
# 获得标签便于计算目标函数
y1, y2 = [subMat1[:, -1], subMat2[:, -1]]
# 处理一下空集的情况
if len(y2) == 0:
FeatureFunc.append(np.var(y1) * len(y1))
elif len(y1) == 0:
FeatureFunc.append(np.var(y2) * len(y2))
else:
FeatureFunc.append(np.var(y1) * len(y1) + np.var(y2) * len(y2))
minFunc.append(FeatureFunc)
# 寻找最优分割特征与数值
locFeature, locVal = np.where(minFunc == np.amin(minFunc))
# 下降值小于ops,不再生长树
if (oriVal - np.amin(minFunc)) < ops:
return None, float(np.var(data[:, -1]) * len(data[:, -1]))
spVal = float(data[locVal, locFeature[0]]) # 用于分割数据的特征
return locFeature[0], spVal
splitData函数用于树的创建,首先选择当前数据集的最优分裂方案,如果只有一个数据点就返回,否则创建节点,并分裂数据集,将两个数据集递归地传给节点的左右子树继续分裂。最后返回生成的树。
# 创建树,生成的树叶子结点没有左右子树!!!!
def createTree(data, ops=1):
# 选中最优的分割
name, val = chooseSplit(data)
if name is None:
return val # 直接返回
tree = {'node': name, 'val': val}
# 递归建树
ldata, rdata = splitData(data, name, val)
tree['rchild'] = createTree(rdata, ops)
tree['lchild'] = createTree(ldata, ops)
return tree
cutBranches用于树的后剪枝,这里需要一个判断节点是否为叶子结点的函数,由于Python是借助字典实现树结构的,所以可以判断当前结点类型是否为dict来实现。如果左枝条或者右枝条不为叶子结点,则按照给定的树模型分割数据集,递归剪枝过程。当左右节点为叶子结点时,就先按照叶子结点的要求分割一次数据集计算y的均方误差,再计算不剪枝时的均方误差,比较判断是否有必要剪枝。
# 后剪枝
def cutBranches(tree, testData=[]):
if len(testData) == 0: # 没有测试数据
print "没有测试数据,不能剪枝!"
# 左只或右枝不为叶子,则进行数据分割
if not (isLeaf(tree['lchild'])) or not (isLeaf(tree['rchild'])):
ldata, rdata = splitData(testData, tree['node'], tree['val'])
# 左枝不为叶子
if not (isLeaf(tree['lchild'])):
tree['lchild'] = cutBranches(tree['lchild'], ldata)
# 右枝不为叶子
if not isLeaf(tree['rchild']):
tree['rchild'] = cutBranches(tree['rchild'], rdata)
# 两边都是树叶,开始判断要不要减支
if isLeaf(tree['lchild']) and isLeaf(tree['rchild']):
ldata, rdata = splitData(testData, tree['node'], tree['val'])
# 不进行剪枝的目标函数值
noMerge = sum(np.power(ldata[:, -1] - tree['lchild'], 2)) + \
sum(np.power(rdata[:, -1] - tree['rchild'], 2))
treeMean = 0.5 * (tree['lchild'] + tree['rchild'])
Merge = sum(np.power(ldata[:, -1] - treeMean, 2))
# 判断是否要剪枝
if Merge < noMerge:
print "merging..."
return treeMean # 返回左右子树平均值实现合并
else:
return tree
else:
return tree
下图为剪枝之前CART生成的树,可以看到,生成的树臃肿,带有大量的叶子结点。如果不设置ops=1,则该树将会变得更加庞大臃肿,它甚至可能为每一个样本生成一个节点。
下图为经过后剪枝之后的CART树,多余的枝条被减去,整棵树较之前显得更为小巧。说明在一定数据集的支撑下,后剪枝能够起到一定效果。
画树函数drawTree则判断节点是否有lchild属性或者rchild属性,如果有,则使用pygraphviz添加从节点到左右子树的节点然后递归调用画树函数,否则只添加有向边。
# 画树
def drawTree(graph, tree):
# 递归画树
if tree.has_key('lchild'):
if not isLeaf(tree['lchild']):
graph.add_edge(tree['val'], tree['lchild']['val'])
drawTree(graph, tree['lchild'])
else:
graph.add_edge(tree['val'], tree['lchild'])
if tree.has_key('rchild'):
if not isLeaf(tree['rchild']):
graph.add_edge(tree['val'], tree['rchild']['val'])
drawTree(graph, tree['rchild'])
else:
graph.add_edge(tree['val'], tree['rchild'])
drawTree(newGraph, tree)
newGraph.layout(prog='dot')
newGraph.draw('treeP.jpg')
代码和相关数据我传了一份供交流参考。地址:https://github.com/FlyingRoastDuck/CART_REG
[1] 李航. 统计学习方法 [M]. 北京:清华大学出版社, 2012: 65-70.
[2] Peter H. 机器学习实战 [M]. 北京:人民邮电出版社,2013: 161-170.