本文转载于“ 一个拉风的名字”的“Regression Tree 回归树”
1. 引言
AI时代,机器学习算法成为了研究、应用的热点。当前,最火的两类算法莫过于神经网络算法(CNN、RNN、LSTM等)与树形算法(随机森林、GBDT、XGBoost等),树形算法的基础就是决策树。决策树因其易理解、易构建、速度快的特性,被广泛应用于统计学、数据挖掘、机器学习领域。因此,对决策树的学习,是机器学习之路必不可少的一步。
根据处理数据类型的不同,决策树又分为两类:分类决策树与回归决策树,前者可用于处理离散型数据,后者可用于处理连续型数据,下面的英文引用自维基百科。
Classification tree analysis is when the predicted outcome is the class to which the data belongs.
Regression tree analysis is when the predicted outcome can be considered a real number (e.g. the price of a house, or a patient’s length of stay in a hospital).
网络上有关于分类决策树的介绍可谓数不胜数,但是对回归决策树(回归树)的介绍却少之又少。李航教授的统计学习方法 对回归树有一个简单介绍,可惜篇幅较短,没有给出一个具体实例;Google搜索回归树,有一篇介绍回归树的博客(点击),该博客所举的实例有误,其过程事实上是基于残差的GBDT。
基于以上原因,本文简单介绍了回归树(Regression Tree),简单描述了CART算法,给出了回归树的算法描述,辅以简单实例以加深理解,最后是总结部分。
2. 回归树
决策树实际上是将空间用超平面进行划分的一种方法,每次分割的时候,都将当前的空间一分为二, 这样使得每一个叶子节点都是在空间中的一个不相交的区域,在进行决策的时候,会根据输入样本每一维feature的值,一步一步往下,最后使得样本落入N个区域中的一个(假设有N个叶子节点),如下图所示。
三种比较常见的分类决策树分支划分方式包括:ID3, C4.5, CART。
分类与回归树(classificationandregressiontree, CART)模型由Breiman等人在1984年提出,是应用广泛的决策树学习方法。CART同样由特征选择、树的生成及剪枝组成,既可以用于分类也可以用于回归。下面的英文引用自维基百科
The term Classification And Regression Tree (CART) analysis is an umbrella term used to refer to both of the above procedures, first introduced by Breiman et al. Trees used for regression and trees used for classification have some similarities - but also some differences, such as the procedure used to determine where to split.
下面介绍回归树。
2.1 原理概述
既然是决策树,那么必然会存在以下两个核心问题:如何选择划分点?如何决定叶节点的输出值?
一个回归树对应着输入空间(即特征空间)的一个划分以及在划分单元上的输出值。分类树中,我们采用信息论中的方法,通过计算选择最佳划分点。而在回归树中,采用的是启发式的方法。假如我们有n个特征,每个特征有
s i ( i ∈ ( 1 , n ) ) s_i(i \in (1,n)) si(i∈(1,n)) 个取值,那我们遍历所有特征,尝试该特征所有取值,对空间进行划分,直到取到特征j的取值s,使得损失函数最小,这样就得到了一个划分点。描述该过程的公式如下:(如果看不到图请点击永久地址)
假设将输入空间划分为M个单元: R 1 , R 2 , . . . , R m R_1,R_2,...,R_m R1,R2,...,Rm那么每个区域的输出值就是: c m = a v e ( y i ∣ x i ∈ R m ) c_m=ave(y_i|x_i \in R_m) cm=ave(yi∣xi∈Rm)也就是该区域内所有点y值的平均数。
举个例子。如下图所示,假如我们想要对楼内居民的年龄进行回归,将楼划分为3个区域 R 1 , R 2 , R 3 R_1, R_2, R_3 R1,R2,R3(红线),那么 R 1 R_1 R1的输出就是第一列四个居民年龄的平均值, R 2 R_2 R2的输出就是第二列四个居民年龄的平均值, R 3 R_3 R3的输出就是第三、四列八个居民年龄的平均值。
2.2 算法描述
为了便于理解,下面举一个简单实例。训练数据见下表,目标是得到一棵最小二乘回归树。
x | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
---|---|---|---|---|---|---|---|---|---|---|
y | 5.56 | 5.7 | 5.91 | 6.4 | 6.8 | 7.05 | 8.9 | 8.7 | 9 | 9.05 |
1.选择最优切分变量j与最优切分点s
在本数据集中,只有一个变量,因此最优切分变量自然是x。
接下来我们考虑9个切分点 [ 1.5 , 2.5 , 3.5 , 4.5 , 5.5 , 6.5 , 7.5 , 8.5 , 9.5 ] [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5] [1.5,2.5,3.5,4.5,5.5,6.5,7.5,8.5,9.5],[1.5,2.5,3.5,4.5,5.5,6.5,7.5,8.5,9.5] 你可能会问,为什么会带小数点呢?类比于篮球比赛的博彩,倘若两队比分是96:95,而盘口是“让1分 A队胜B队”,那A队让1分之后,到底是A队赢还是B队赢了?所以我们经常可以看到“让0.5分 A队胜B队”这样的盘口。在这个实例中,也是这个道理。
损失函数定义为平方损失函数 L o s s ( y , f ( x ) ) = ( f ( x ) − y ) 2 Loss(y, f(x))=(f(x)-y)^2 Loss(y,f(x))=(f(x)−y)2,将上述9个切分点一依此代入下面的公式,其中 c m = a v e ( y i ∣ x i ∈ R m ) c_m=ave(y_i|x_i \in R_m) cm=ave(yi∣xi∈Rm)(如果看不到图请点击永久地址)
例如,取 s = 1.5 s=1.5 s=1.5。此时 R 1 = { 1 } , R 2 = { 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 } R_1=\{1\} , R_2=\{2,3,4,5,6,7,8,9,10\} R1={ 1},R2={ 2,3,4,5,6,7,8,9,10},这两个区域的输出值分别为: c 1 = 5.56 , c 2 = 1 9 ( 5.7 + 5.91 + 6.4 + 6.8 + 7.05 + 8.9 + 8.7 + 9 + 9.05 ) = 7.50 c_1=5.56, c_2= \frac{1}{9}(5.7+5.91+6.4+6.8+7.05+8.9+8.7+9+9.05)=7.50 c1=5.56,c2=91(5.7+5.91+6.4+6.8+7.05+8.9+8.7+9+9.05)=7.50,得到下表:
s | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 | 6.5 | 7.5 | 8.5 | 9.5 |
---|---|---|---|---|---|---|---|---|---|
c1 | 5.56 | 5.63 | 5.72 | 5.89 | 6.07 | 6.24 | 6.62 | 6.88 | 7.11 |
c2 | 7.5 | 7.73 | 7.99 | 8.25 | 8.5 | 4 8.91 | 8.92 | 9.03 | 9.05 |
把 c 1 c_1 c1, c 2 c_2 c2的值代入到上式,如: m ( 1.5 ) = 0 + 15.72 = 15.72 m(1.5)=0+15.72=15.72 m(1.5)=0+15.72=15.72。同理,可获得下表:
s | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 | 6.5 | 7.5 | 8.5 | 9.5 |
---|---|---|---|---|---|---|---|---|---|
m(s) | 15.72 | 12.07 | 8.36 | 5.78 | 3.91 | 1.93 | 8.01 | 11.73 | 15.74 |
显然取 s = 6.5 s=6.5 s=6.5时, m ( s ) m(s) m(s)最小。因此,第一个划分变量 j = x j=x j=x, s = 6.5 s=6.5 s=6.5
x | 1 | 2 | 3 | 4 | 5 | 6 |
---|---|---|---|---|---|---|
y | 5.56 | 5.7 | 5.91 | 6.4 | 6.8 | 7.05 |
取切分点 [ 1.5 , 2.5 , 3.5 , 4.5 , 5.5 ] [1.5,2.5,3.5,4.5,5.5] [1.5,2.5,3.5,4.5,5.5],则各区域的输出值 c c c如下表
s | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 |
---|---|---|---|---|---|
c1 | 5.56 | 5.63 | 5.72 | 5.89 | 6.07 |
c2 | 6.37 | 6.54 | 6.75 | 6.93 | 7.05 |
计算m(s):
s | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 |
---|---|---|---|---|---|
m(s) | 1.3087 | 0.754 | 0.2771 | 0.4368 | 1.0644 |
s=3.5时m(s)最小。
之后的过程不再赘述。
2.4 回归树VS线性回归
不多说了,直接看图甩代码
3. 总结
实际上,回归树总体流程类似于分类树,分枝时穷举每一个特征的每一个阈值,来寻找最优切分特征j和最优切分点s,衡量的方法是平方误差最小化。分枝直到达到预设的终止条件(如叶子个数上限)就停止。
当然,处理具体问题时,单一的回归树肯定是不够用的。可以利用集成学习中的boosting框架,对回归树进行改良升级,得到的新模型就是提升树(Boosting Decision Tree),在进一步,可以得到梯度提升树(Gradient Boosting Decision Tree,GBDT),再进一步可以升级到XGBoost。