作者:顾全,浙江大学软件工程硕士,现任桃树科技算法工程师
地址:
https://github.com/ZJUguquan/OnlineRandomForest
参与:Cynthia
翻译:本文为天善智能编译,未经容许,禁止转载
Online Random Forest(ORF) 是由Amir Saffari等人最先提出。之后,Arthur Lui使用Python实现了算法。非常感谢他们的工作。在论文内容和Lui的算法的基础上,我通过R和R6包重构了代码。此外,ORF在此包中的实现,与randomForest结合,使它同时支持增量学习和批量学习,例如:在ORF的基础上构建树,然后通过ORF进行更新。通过这种方法,它将比以前快得多。
if(!require(devtools)) install.packages("devtools")
devtools::install_github("ZJUguquan/OnlineRandomForest")
最小举例:增量学习
library(OnlineRandomForest)
param <- list('minSamples'= 1, 'minGain'= 0.1, 'numClasses'= 3, 'x.rng'= dataRange(iris[1:4]))
orf <- ORF$new(param, numTrees = 10)
for (i in 1:150) orf$update(iris[i, 1:4], as.integer(iris[i, 5]))
cat("Mean depth of trees in the forest is:", orf$meanTreeDepth(), "\n")
orf$forest[[2]]$draw()
## Mean depth of trees in the forest is: 3
## Root X4 < 1.21
## |----L: X3 < 2.38
## |----L: Leaf 1
## |----R: Leaf 2
## |----R: X4 < 2.15
## |----L: X1 < 4.92
## |----L: Leaf 3
## |----R: Leaf 3
## |----R: Leaf 3
分类举例
library(OnlineRandomForest)
# data preparation
dat <- iris; dat[,5] <- as.integer(dat[,5])
x.rng <- dataRange(dat[1:4])
param <- list('minSamples'= 2, 'minGain'= 0.2, 'numClasses'= 3, 'x.rng'= x.rng)
ind.gen <- sample(1:150,30) # for generate ORF
ind.updt <- sample(setdiff(1:150, ind.gen), 100) # for uodate ORF
ind.test <- setdiff(setdiff(1:150, ind.gen), ind.updt) # for test
# construct ORF and update
rf <- randomForest::randomForest(factor(Species) ~ ., data = dat[ind.gen, ], maxnodes = 2, ntree = 100)
orf <- ORF$new(param)
orf$generateForest(rf, df.train = dat[ind.gen, ], y.col = "Species")
cat("Mean size of trees in the forest is:", orf$meanTreeSize(), "\n")
## Mean size of trees in the forest is: 3
for (i in ind.updt) {
orf$update(dat[i, 1:4], dat[i, 5])
}
cat("After update, mean size of trees in the forest is:", orf$meanTreeSize(), "\n")
## After update, mean size of trees in the forest is: 11.9
# predict
orf$confusionMatrix(dat[ind.test, 1:4], dat[ind.test, 5], pretty = T)
##
##
## Cell Contents
## |-------------------------|
## | N |
## | N / Row Total |
## | N / Col Total |
## |-------------------------|
##
##
## Total Observations in Table: 20
##
##
## | actual
## prediction | 1 | 2 | 3 | Row Total |
## -------------|-----------|-----------|-----------|-----------|
## 1 | 4 | 0 | 0 | 4 |
## | 1.000 | 0.000 | 0.000 | 0.200 |
## | 1.000 | 0.000 | 0.000 | |
## -------------|-----------|-----------|-----------|-----------|
## 2 | 0 | 9 | 2 | 11 |
## | 0.000 | 0.818 | 0.182 | 0.550 |
## | 0.000 | 1.000 | 0.286 | |
## -------------|-----------|-----------|-----------|-----------|
## 3 | 0 | 0 | 5 | 5 |
## | 0.000 | 0.000 | 1.000 | 0.250 |
## | 0.000 | 0.000 | 0.714 | |
## -------------|-----------|-----------|-----------|-----------|
## Column Total | 4 | 9 | 7 | 20 |
## | 0.200 | 0.450 | 0.350 | |
## -------------|-----------|-----------|-----------|-----------|
##
##
# compare
table(predict(rf, newdata = dat[ind.test,]) == dat[ind.test, 5])
## FALSE TRUE
## 9 11
table(orf$predicts(X = dat[ind.test,]) == dat[ind.test, 5])
## FALSE TRUE
## 2 18
回归举例
# data preparation
if(!require(ggplot2)) install.packages("ggplot2")
data("diamonds", package = "ggplot2")
dat <- as.data.frame(diamonds[sample(1:53000,1000), c(1:6,8:10,7)])
for (col in c("cut","color","clarity")) dat[[col]] <- as.integer(dat[[col]]) # Don't forget this
x.rng <- dataRange(dat[1:9])
param <- list('minSamples'= 10, 'minGain'= 1, 'maxDepth' = 10, 'x.rng'= x.rng)
ind.gen <- sample(1:1000, 800)
ind.updt <- sample(setdiff(1:1000, ind.gen), 100)
ind.test <- setdiff(setdiff(1:1000, ind.gen), ind.updt)
# construct ORF
rf <- randomForest::randomForest(price ~ ., data = dat[ind.gen, ], maxnodes = 20, ntree = 100)
orf <- ORF$new(param)
orf$generateForest(rf, df.train = dat[ind.gen, ], y.col = "price")
orf$meanTreeSize()
## [1] 39
# and update
for (i in ind.updt) {
orf$update(dat[i, 1:9], dat[i, 10])}
orf$meanTreeSize()
## [1] 105.7
# predict and compare
if(!require(Metrics)) install.packages("Metrics")
preds.rf <- predict(rf, newdata = dat[ind.test,])
Metrics::rmse(preds.rf, dat$price[ind.test])
## [1] 988.8055
preds <- orf$predicts(dat[ind.test, 1:9])
Metrics::rmse(preds, dat$price[ind.test]) # make progress
## [1] 869.9613
ta <- Tree$new("abc", NULL, NULL)
tb <- Tree$new(1, Tree$new(36), Tree$new(3))
tc <- Tree$new(89, tb, ta)
tc$draw()# update tc
tc$right$updateChildren(Tree$new("666"), Tree$new(999))
tc$right$right$updateChildren(Tree$new("666"), Tree$new(999))
tc$draw()
通过random Forest包配置一个Online random Tree,并升级
# data preparation
library(randomForest)
dat1 <- iris; dat1[,5] <- as.integer(dat1[,5])
rf <- randomForest(factor(Species) ~ ., data = dat1, maxnodes = 3)
treemat1 <- getTree(rf, 1, labelVar=F)
treemat1 <- cbind(treemat1, node.ind = 1:nrow(treemat1))
x.rng1 <- dataRange(dat1[1:4])
param1 <- list('minSamples'= 5, 'minGain'= 0.1, 'numClasses'= 3, 'x.rng'= x.rng1)
ind.gen <- sample(1:150,50) # for generate ORT
ind.updt <- setdiff(1:150, ind.gen) # for update ORT
# origin
ort2 <- ORT$new(param1)
ort2$draw()
## Root 1
## Leaf 1
# generate a tree
ort2$generateTree(treemat1, df.node = dat1[ind.gen,])
ort2$draw()
## Root X3 < 2.45
## |----L: Leaf 1
## |----R: X3 < 4.75
## |----L: Leaf 2
## |----R: Leaf 3
# update this tree
for(i in ind.updt) {
ort2$update(dat1[i,1:4], dat1[i,5])
}
ort2$draw()
## Root X3 < 2.45
## |----L: Leaf 1
## |----R: X3 < 4.75
## |----L: Leaf 2
## |----R: X4 < 2.19
## |----L: X2 < 3.68
## |----L: X1 < 7.12
## |----L: X3 < 4.06
## |----L: Leaf 1
## |----R: Leaf 3
## |----R: Leaf 3
## |----R: Leaf 1
## |----R: Leaf 3
大家都在看
2017年R语言发展报告(国内)
R语言中文社区历史文章整理(作者篇)
R语言中文社区历史文章整理(类型篇)
公众号后台回复关键字即可学习
回复 R R语言快速入门及数据挖掘
回复 Kaggle案例 Kaggle十大案例精讲(连载中)
回复 文本挖掘 手把手教你做文本挖掘
回复 可视化 R语言可视化在商务场景中的应用
回复 大数据 大数据系列免费视频教程
回复 量化投资 张丹教你如何用R语言量化投资
回复 用户画像 京东大数据,揭秘用户画像
回复 数据挖掘 常用数据挖掘算法原理解释与应用
回复 机器学习 人工智能系列之机器学习与实践
回复 爬虫 R语言爬虫实战案例分享