6.4 回归树 CART
回归树(也称为分类回归树CART)主要以一种树状结构来表达回归分析模型的回归算法,该类方法不仅可以应用于回归分析(称为回归树),也可以用于分类分析(称为分类树)。
6.4.1 rpart函数
library(rpart)
sol.rpart<-rpart(Sepal.Length~Sepal.Width+Petal.Length+Petal.Width,data=iris)
plot(sol.rpart,uniform=TRUE,compress=TRUE,lty=3,branch=0.7)
text(sol.rpart,all=TRUE,digits=7,use.n=TRUE,cex=0.9,xpd=TRUE)
sol.rpart
n=150
node),split, n, deviance, yval
* denotes terminal node
1) root 150 102.1683000 5.843333
2) Petal.Length< 4.25 73 13.1391800 5.179452
4) Petal.Length< 3.4 53 6.1083020 5.005660
8) Sepal.Width< 3.25 20 1.0855000 4.735000 *
9) Sepal.Width>=3.25 33 2.6696970 5.169697 *
5) Petal.Length>=3.4 20 1.1880000 5.640000 *
3) Petal.Length>=4.25 77 26.3527300 6.472727
6) Petal.Length< 6.05 68 13.4923500 6.326471
12) Petal.Length< 5.15 43 8.2576740 6.165116
24) Sepal.Width< 3.05 33 5.2218180 6.054545 *
25) Sepal.Width>=3.05 10 1.3010000 6.530000 *
13) Petal.Length>=5.15 25 2.1896000 6.604000 *
7) Petal.Length>=6.05 9 0.4155556 7.577778 *
除了使用rpart包自带的plot函数外,还可以使用maptree包的draw.tree函数绘制更为复杂的树形结构图。代码如下:
library(maptree)
draw.tree(sol.rpart)
draw.tree(sol.rpart,nodeinfo=TRUE)
leaf<-which(sol.rpart$frame$var=="
point.col<-""
for(iin 1:nrow(iris)){
point.col[i]<-which(leaf==sol.rpart$where[i])
}
plot(iris$Petal.Length,iris$Petal.Width,pch=16,col=point.col,xlab="Petal.Length",ylab="Petal.Windth")
6.4.2 预测及模型性能衡量
使用predict函数可以预测模型的样本,并使用所有的预测误差的总和来衡量模型的性能
6.4.3 过度拟合和剪枝
一个合理的模型,其测试集误差(代价)和回归树的规模(复杂)都要尽可能地小。
sol.rpart$cptable
CP nsplit rel error xerror xstd
10.61346237 0 1.0000000 1.01249960.09881164
20.12180701 1 0.3865376 0.39519900.04881294
30.05718872 2 0.2647306 0.29629410.03158525
40.02980452 3 0.2075419 0.25411440.03060910
50.02303165 4 0.1777374 0.26360600.03072885
60.01698037 5 0.1547057 0.24934100.02970677
70.01000000 6 0.1377254 0.23798590.02933008
剪枝实际就是找到一个合理的cp值,达到平衡测试集误差和树规模的目的,从而使其数值均很小。
使用plotcp函数可以绘制出cp的波动关系,从而直接观察出合适的cp取值。
plotcp(sol.rpart,minline=TRUE,lty=3,col=1,upper=c("size","splits","none"))
prune(sol.rpart,0.030)$cptable
CP nsplit rel error xerror xstd
10.61346237 0 1.0000000 1.01249960.09881164
20.12180701 1 0.3865376 0.39519900.04881294
30.05718872 2 0.2647306 0.29629410.03158525
40.03000000 3 0.2075419 0.25411440.03060910
r<-prune(sol.rpart,0.030)
plot(r,uniform=T,compress=T,lty=3,branch=0.7)
text(r,all=TRUE,digits=7,use.n=T,cex=0.9,xpd=TRUE)