# Decision tree with a categorical response (classification tree) ----

# Install each dependency only when it is missing, so that re-running the
# script does not re-download the packages (or fail without network access).
if (!requireNamespace("rpart", quietly = TRUE)) {
  install.packages("rpart")
}
if (!requireNamespace("rpart.plot", quietly = TRUE)) {
  install.packages("rpart.plot")
}
library(rpart)
library(rpart.plot)

# method = "class" requests a classification tree; Species is modelled on all
# remaining columns of the built-in iris data set.
fit <- rpart(Species ~ ., data = iris, method = "class")
print(fit) # explicit print() so the tree is also shown under source()

# Base-graphics rendering: plot() draws the tree, text() adds the labels.
plot(fit)
text(fit)

# The base plot is hard to read; rpart.plot() draws a nicer version of the
# same tree, annotated with the class probabilities.
rpart.plot(fit)

# Complexity-parameter (cp) table.
printcp(fit)

# Predict on the predictor columns only (column 5 is Species).
# NOTE: these are the same rows the tree was fitted on, so the confusion
# matrix below measures training (resubstitution) accuracy, not
# generalization to unseen data.
test1 <- iris[-5]
pre <- predict(fit, test1, type = "class")

# Confusion matrix: predicted class (rows) vs. actual species (columns).
print(table(pre, iris$Species))
# Decision tree with a numeric response (regression tree) ----

# airquality ships with R; work on a copy with incomplete rows removed.
data(airquality)
air <- na.omit(airquality)
print(head(air)) # first 6 rows; explicit print() so it shows under source()

# method = "anova" requests a regression tree for the numeric response Ozone.
fit1 <- rpart(Ozone ~ ., data = air, method = "anova")

# Each leaf shows the fitted value of its node, which for an anova tree is the
# MEAN of Ozone over the observations in that node (not the median).
rpart.plot(fit1)

# Complexity-parameter (cp) table; use it to choose the pruning threshold.
printcp(fit1)

# Prune with a fixed cp to limit overfitting and improve generalization.
# The value 0.018 is hard-coded, presumably read off a printcp() table from an
# earlier run; re-check it against the table above if the data change.
pfit <- prune(fit1, cp = 0.018)

# Tree after pruning.
rpart.plot(pfit)
软件操作结果(以下为在 R 控制台中逐行运行上述脚本得到的输出记录,并非可执行代码)
> install.packages("rpart")
trying URL 'https://cran.rstudio.com/bin/windows/contrib/3.5/rpart_4.1-13.zip'
Content type 'application/zip' length 950799 bytes (928 KB)
downloaded 928 KB
package ‘rpart’ successfully unpacked and MD5 sums checked
The downloaded binary packages are in
C:\Users\L3M309NJSJ\AppData\Local\Temp\Rtmp8aeHh4\downloaded_packages
> library(rpart)
Warning message:
程辑包‘rpart’是用R版本3.5.3 来建造的
> fit=rpart(Species~.,method = "class",iris)
> fit
n= 150
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)
  2) Petal.Length< 2.45 50 0 setosa (1.00000000 0.00000000 0.00000000) *
  3) Petal.Length>=2.45 100 50 versicolor (0.00000000 0.50000000 0.50000000)
    6) Petal.Width< 1.75 54 5 versicolor (0.00000000 0.90740741 0.09259259) *
    7) Petal.Width>=1.75 46 1 virginica (0.00000000 0.02173913 0.97826087) *
> plot(fit)
> text(fit)
> install.packages("rpart.plot")
trying URL 'https://cran.rstudio.com/bin/windows/contrib/3.5/rpart.plot_3.0.6.zip'
Content type 'application/zip' length 1064991 bytes (1.0 MB)
downloaded 1.0 MB
package ‘rpart.plot’ successfully unpacked and MD5 sums checked
The downloaded binary packages are in
C:\Users\L3M309NJSJ\AppData\Local\Temp\Rtmp8aeHh4\downloaded_packages
> library(rpart.plot)
Warning message:
程辑包‘rpart.plot’是用R版本3.5.3 来建造的
> rpart.plot(fit)
> printcp(fit)
Classification tree:
rpart(formula = Species ~ ., data = iris, method = "class")
Variables actually used in tree construction:
[1] Petal.Length Petal.Width
Root node error: 100/150 = 0.66667
n= 150
CP nsplit rel error xerror xstd
1 0.50 0 1.00 1.19 0.049592
2 0.44 1 0.50 0.69 0.061041
3 0.01 2 0.06 0.08 0.027520
> test1=data.frame(iris[-5])
> pre=predict(fit,test1,type = "class")
> table(pre,iris$Species)
pre          setosa versicolor virginica
  setosa         50          0         0
  versicolor      0         49         5
  virginica       0          1        45
> #决策树 数值型因变量 回归树
> data(airquality) #R语言自带数据
> air=airquality
> air=na.omit(air) #排除缺失值
> head(air)
Ozone Solar.R Wind Temp Month Day
1 41 190 7.4 67 5 1
2 36 118 8.0 72 5 2
3 12 149 12.6 74 5 3
4 18 313 11.5 62 5 4
7 23 299 8.6 65 5 7
8 19 99 13.8 59 5 8
> fit1=rpart(Ozone~.,method="anova",air)
> rpart.plot(fit1)
> printcp(fit1)
Regression tree:
rpart(formula = Ozone ~ ., data = air, method = "anova")
Variables actually used in tree construction:
[1] Solar.R Temp Wind
Root node error: 121802/111 = 1097.3
n= 111
CP nsplit rel error xerror xstd
1 0.484386 0 1.00000 1.01597 0.17257
2 0.097983 1 0.51561 0.55327 0.18254
3 0.057062 2 0.41763 0.60577 0.18874
4 0.020210 3 0.36057 0.50724 0.15448
5 0.018716 4 0.34036 0.50462 0.14075
6 0.017460 5 0.32164 0.48447 0.14038
7 0.010000 6 0.30418 0.48333 0.14044
> pfit=prune(fit1,cp=0.018)
> rpart.plot(pfit)