专注系列化、高质量的R语言教程
(查看推文索引)
mlr3
是一个关于机器学习的工具包,关于它的详细介绍可参见:
网页版:https://mlr3book.mlr-org.com/intro.html
pdf版:https://mlr3book.mlr-org.com/mlr3book.pdf
在本篇推文中,学堂君将介绍机器学习其中的四个环节:数据、模型、训练、预测,但并不涉及具体的模型算法。目录如下:
1 基本框架
2 数据与任务
2.1 任务的类型
2.2 创建任务
2.3 内置任务
2.4 任务的属性和方法
3 模型
4 训练和预测
为了帮助读者快速入门该包的使用方法,本节以mtcars
数据集为例,展示机器学习从创建任务、选择模型到训练、预测四个步骤。
创建任务
任务(task)是一种对象格式,用来储存模型的数据(data)和元数据(meta-data)。简单来说,它就是我们在运行模型前所有准备工作的封装。
这里仅选取mtcars
数据集中的三个变量,并使用mpg
作为响应变量(target)来创建一个回归类型的任务:
library(mlr3)
## 数据集
data <- mtcars[, c("mpg", "cyl", "disp")]
## 创建任务
task <- as_task_regr(x = data, target = "mpg")
选择模型
模型是封装过的算法,称之为learner
。这里选择一个可以适用于回归任务的regr.rpart
:
learner <- lrn("regr.rpart")
learner
## : Regression Tree
## * Model: -
## * Parameters: xval=0
## * Packages: mlr3, rpart
## * Predict Types: [response]
## * Feature Types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, selected_features, weights
训练
learner
是一个R6对象,可以使用它的train
方法来训练任务。这里task
参数使用的是前面创建的任务,训练后的model
属性就是训练结果:
learner$train(task = task)
learner$model
## n= 32
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 32 1126.04700 20.09062
## 2) cyl>=5 21 198.47240 16.64762
## 4) cyl>=7 14 85.20000 15.10000 *
## 5) cyl< 7 7 12.67714 19.74286 *
## 3) cyl< 5 11 203.38550 26.66364 *
因为是基于
R6
工具包的面向对象的编程,因此不需要learner <- learner$train(task = task)
就可以实现learner
对象的更新。
预测
在训练后,我们可以使用“新数据”来进行预测,调用的是predict_newdata
方法:
## 新数据
newdata <- data.frame(cyl = rep(c(4,6,8), 3),
disp = rep(c(100, 150, 200), each = 3))
## 预测
learner$predict_newdata(newdata)
## for 9 observations:
## row_ids truth response
## 1 NA 26.66364
## 2 NA 19.74286
## 3 NA 15.10000
## ---
## 7 NA 26.66364
## 8 NA 19.74286
## 9 NA 15.10000
下面分节详细介绍各步骤。
任务和任务类型是对象和类的关系。查看前面创建的任务task
的类型:
class(task)
## [1] "TaskRegr" "TaskSupervised" "Task" "R6"
常见的任务类型有:
回归任务:响应变量(target
)是数值型,类别为TaskRegr
;
分类任务:响应变量是标签型(包括字符串和因子),类别为TaskClassif
;
生存分析任务:响应变量为时间-事件型(time to an event),类别为拓展包mlr3proba
中的TaskSurv
;
密度分析任务:非监督学习,用来预测密度,类别为拓展包mlr3proba
中的TaskDens
;
聚类分析任务:类别为拓展包mlr3cluster
中的TaskClust
;
空间分析任务:类别为拓展包mlr3spatiotempcv
中的TaskRegrST
或TaskClassifST
;
有序回归任务:响应变量为有序变量,类别为拓展包mlr3ordinal
中的TaskOrdinal
。
上述很多任务类型来自mlr3
的拓展包。实际上,mlr3
工具包只保留了一些比较基础的功能,更复杂的功能则分门别类地由拓展包完成,这构成了以mlr3
为核心的R工具包生态。如下图[1]:
以TaskRegr
为例,它是一个R6类,因此可以使用new
方法来创建一个归属于该类的任务对象(可参见推文R语言与面向对象的编程(3):R6类)。
task0 <- TaskRegr$new(id = "task", backend = data,
target = "mpg")
class(task0)
## [1] "TaskRegr" "TaskSupervised" "Task" "R6"
也可以像前面一样使用快捷函数as_task_regr()
:
task <- as_task_regr(x = data, target = "mpg")
此外,mlr3
工具包中还有如下同类型的函数:
## 创建分类任务
as_task_classif()
## 创建非监督任务
as_task_unsupervised()
mlr3
工具包还内置了一些任务,储存在mlr_tasks
对象中:
mlr_tasks
## with 11 stored values
## Keys: boston_housing, breast_cancer, german_credit, iris, mtcars,
## penguins, pima, sonar, spam, wine, zoo
使用get
方法可以提取某个内置任务:
task_mtcars <- mlr_tasks$get("mtcars")
更简便的方法是使用快捷函数tsk()
task_mtcars <- tsk("mtcars")
任务作为一个R6对象,也有其属性和方法,下面一一进行介绍。
数据
data
方法可以按行、按列查看数据;backend
属性是备份数据,即创建任务时的data
参数:
task_mtcars$data()
## mpg am carb cyl disp drat gear hp qsec vs wt
## 1: 21.0 1 4 6 160.0 3.90 4 110 16.46 0 2.620
## 2: 21.0 1 4 6 160.0 3.90 4 110 17.02 0 2.875
## 3: 22.8 1 1 4 108.0 3.85 4 93 18.61 1 2.320
## 4: 21.4 0 1 6 258.0 3.08 3 110 19.44 1 3.215
## 5: 18.7 0 2 8 360.0 3.15 3 175 17.02 0 3.440
## 6: 18.1 0 1 6 225.0 2.76 3 105 20.22 1 3.460
task_mtcars$data(rows = c(1, 5, 10), cols = "mpg")
## mpg
## 1: 21.0
## 2: 18.7
## 3: 19.2
task_mtcars$backend
## (32x13)
## model mpg cyl disp hp drat wt qsec vs am gear carb ..row_id
## Mazda RX4 21.0 6 160 110 3.90 2.620 16.46 0 1 4 4 1
## Mazda RX4 Wag 21.0 6 160 110 3.90 2.875 17.02 0 1 4 4 2
## Datsun 710 22.8 4 108 93 3.85 2.320 18.61 1 1 4 1 3
## Hornet 4 Drive 21.4 6 258 110 3.08 3.215 19.44 1 0 3 1 4
## Hornet Sportabout 18.7 8 360 175 3.15 3.440 17.02 0 0 3 2 5
## Valiant 18.1 6 225 105 2.76 3.460 20.22 1 0 3 1 6
## [...] (26 rows omitted)
行数和列数
nrow
属性是行数,ncol
属性是列数,row_ids
属性是行编号:
task_mtcars$nrow
## [1] 32
task_mtcars$ncol
## [1] 11
task_mtcars$row_ids
## [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
## [26] 26 27 28 29 30 31 32
响应变量和预测变量
target_names
属性是响应变量,即创建任务时的target
参数;feature_names
是预测变量:
task_mtcars$target_names
## [1] "mpg"
task_mtcars$feature_names
## [1] "am" "carb" "cyl" "disp" "drat" "gear" "hp" "qsec" "vs" "wt"
行、列角色
col_roles
属性是列的角色,对应值为列名;row_roles
属性是行的角色,对应值为行编号:
task_mtcars$col_roles
## $feature
## [1] "am" "carb" "cyl" "disp" "drat" "gear" "hp" "qsec" "vs" "wt"
##
## $target
## [1] "mpg"
##
## $name
## [1] "model"
##
## $order
## character(0)
##
## $stratum
## character(0)
##
## $group
## character(0)
##
## $weight
## character(0)
task_mtcars$row_roles
## $use
## [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
## [26] 26 27 28 29 30 31 32
##
## $holdout
## integer(0)
##
## $early_stopping
## integer(0)
克隆
clone
方法可以用来备份任务对象:
task_mtcars_clone <- task_mtcars$clone()
列选择
select
方法类似dplyr
工具包中的同名函数,用来选择除响应变量以外的列(变量):
task_mtcars$select(c("cyl", "disp"))
task_mtcars$feature_names
## [1] "cyl" "disp"
执行列选择后,被删除的列不再是预测变量。
行筛选
filter
方法类似dplyr
工具包中的同名函数,用来筛选行(样本):
task_mtcars$filter(2:5)
task_mtcars$data()
## mpg cyl disp
## 1: 21.0 6 160
## 2: 22.8 4 108
## 3: 21.4 6 258
## 4: 18.7 8 360
task_mtcars$backend
## (32x13)
## model mpg cyl disp hp drat wt qsec vs am gear carb ..row_id
## Mazda RX4 21.0 6 160 110 3.90 2.620 16.46 0 1 4 4 1
## Mazda RX4 Wag 21.0 6 160 110 3.90 2.875 17.02 0 1 4 4 2
## Datsun 710 22.8 4 108 93 3.85 2.320 18.61 1 1 4 1 3
## Hornet 4 Drive 21.4 6 258 110 3.08 3.215 19.44 1 0 3 1 4
## Hornet Sportabout 18.7 8 360 175 3.15 3.440 17.02 0 0 3 2 5
## Valiant 18.1 6 225 105 2.76 3.460 20.22 1 0 3 1 6
## [...] (26 rows omitted)
data
方法输出结果为当前的数据情况,会随列选择和行筛选等操作变化而变化;
backend
属性为备份数据,是最初创建任务所用的data
参数,不会变化。
需要注意的是,行筛选使用的数字是行编号,而非行序号:
task_mtcars$filter(2)
task_mtcars$data()
## mpg cyl disp
## 1: 21 6 160
前面
task_mtcars$filter(2:5)
所筛选的是行编号为2-5的样本,在筛选后其行编号不会改变,这样编号为2的行的序号就是1了。使用task_mtcars$filter(2)
所筛选的是第1行数据,即行编号为2的数据。
数据合并
rbind
方法可以用来增加行数,cbind
方法可以用来增加列数:
task_mtcars$rbind(
data.frame(mpg = 1, cyl = 2, disp = 3)
)
task_mtcars$cbind(
data.frame(data.frame(newvar = 1:2))
)
task_mtcars$data()
## mpg cyl disp newvar
## 1: 21 6 160 1
## 2: 1 2 3 2
learner
对应着模型算法。mlr3
工具包及其拓展包内置了许多learner
,储存在mlr_learners
对象中。同样,可以使用get
方法提取某个learner
:
mlr_learners
## with 6 stored values
## Keys: classif.debug, classif.featureless, classif.rpart, regr.debug,
## regr.featureless, regr.rpart
learner0 <- mlr_learners$get("regr.rpart")
也可以使用快捷函数lrn()
:
learner = lrn("regr.rpart")
查看模型的参数设置:
learner$param_set
##
## id class lower upper nlevels default value
## 1: cp ParamDbl 0 1 Inf 0.01
## 2: keep_model ParamLgl NA NA 2 FALSE
## 3: maxcompete ParamInt 0 Inf Inf 4
## 4: maxdepth ParamInt 1 30 30 30
## 5: maxsurrogate ParamInt 0 Inf Inf 5
## 6: minbucket ParamInt 1 Inf Inf
## 7: minsplit ParamInt 1 Inf Inf 20
## 8: surrogatestyle ParamInt 0 1 2 0
## 9: usesurrogate ParamInt 0 2 3 2
## 10: xval ParamInt 0 Inf Inf 10 0
如果有需要也可以进行更改:
learner$param_set$values$cp = 0.01
learner$param_set
##
## id class lower upper nlevels default value
## 1: cp ParamDbl 0 1 Inf 0.01 0.01
## 2: keep_model ParamLgl NA NA 2 FALSE
## 3: maxcompete ParamInt 0 Inf Inf 4
## 4: maxdepth ParamInt 1 30 30 30
## 5: maxsurrogate ParamInt 0 Inf Inf 5
## 6: minbucket ParamInt 1 Inf Inf
## 7: minsplit ParamInt 1 Inf Inf 20
## 8: surrogatestyle ParamInt 0 1 2 0
## 9: usesurrogate ParamInt 0 2 3 2
## 10: xval ParamInt 0 Inf Inf 10 0
使用前20行数据进行模型训练:
task_mtcars_clone$filter(1:20)
learner$train(task_mtcars_clone)
learner$model
## n= 20
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 20 792.76200 20.13000
## 2) wt>=3.3275 11 93.85636 15.68182 *
## 3) wt< 3.3275 9 215.24000 25.56667 *
使用21-32行数据进行预测:
learner$predict_newdata(mtcars[21:32,])
## for 12 observations:
## row_ids truth response
## 1 21.5 25.56667
## 2 15.5 15.68182
## 3 15.2 15.68182
## ---
## 10 19.7 25.56667
## 11 15.0 15.68182
## 12 21.4 25.56667
[1]
mlr3的工具包生态: https://mlr3book.mlr-org.com/intro.html#package-ecosystem