基于mlr3工具包的机器学习(1)——数据、模型、训练、预测


专注系列化、高质量的R语言教程

(查看推文索引


mlr3是一个关于机器学习的工具包,关于它的详细介绍可参见:

  • 网页版:https://mlr3book.mlr-org.com/intro.html

  • pdf版:https://mlr3book.mlr-org.com/mlr3book.pdf

在本篇推文中,学堂君将介绍机器学习其中的四个环节:数据、模型、训练、预测,但并不涉及具体的模型算法。目录如下:

  • 1 基本框架

  • 2 数据与任务

    • 2.1 任务的类型

    • 2.2 创建任务

    • 2.3 内置任务

    • 2.4 任务的属性和方法

  • 3 模型

  • 4 训练和预测

1 基本框架

为了帮助读者快速入门该包的使用方法,本节以mtcars数据集为例,展示机器学习从创建任务、选择模型到训练、预测四个步骤。

创建任务

任务(task)是一种对象格式,用来储存模型的数据(data)和元数据(meta-data)。简单来说,它就是我们在运行模型前所有准备工作的封装。

这里仅选取mtcars数据集中的三个变量,并使用mpg作为响应变量(target)来创建一个回归类型的任务:

library(mlr3)
## 数据集
data <- mtcars[, c("mpg", "cyl", "disp")] 
## 创建任务
task <- as_task_regr(x = data, target = "mpg")

选择模型

模型是封装过的算法,称之为learner。这里选择一个可以适用于回归任务的regr.rpart

learner <- lrn("regr.rpart")
learner
## : Regression Tree
## * Model: -
## * Parameters: xval=0
## * Packages: mlr3, rpart
## * Predict Types:  [response]
## * Feature Types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, selected_features, weights

训练

learner是一个R6对象,可以使用它的train方法来训练任务。这里task参数使用的是前面创建的任务,训练后的model属性就是训练结果:

learner$train(task = task) 
learner$model
## n= 32 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
## 1) root 32 1126.04700 20.09062  
##   2) cyl>=5 21  198.47240 16.64762  
##     4) cyl>=7 14   85.20000 15.10000 *
##     5) cyl< 7 7   12.67714 19.74286 *
##   3) cyl< 5 11  203.38550 26.66364 *

因为是基于R6工具包的面向对象的编程,因此不需要learner <- learner$train(task = task)就可以实现learner对象的更新。

预测

在训练后,我们可以使用“新数据”来进行预测,调用的是predict_newdata方法:

## 新数据
newdata <- data.frame(cyl = rep(c(4,6,8), 3),
                      disp = rep(c(100, 150, 200), each = 3)) 
## 预测
learner$predict_newdata(newdata)
##  for 9 observations:
##     row_ids truth response
##           1    NA 26.66364
##           2    NA 19.74286
##           3    NA 15.10000
## ---                       
##           7    NA 26.66364
##           8    NA 19.74286
##           9    NA 15.10000

下面分节详细介绍各步骤。

2 数据与任务

2.1 任务的类型

任务和任务类型是对象和类的关系。查看前面创建的任务task的类型:

class(task)
## [1] "TaskRegr"       "TaskSupervised" "Task"           "R6"

常见的任务类型有:

  • 回归任务:响应变量(target)是数值型,类别为TaskRegr

  • 分类任务:响应变量是标签型(包括字符串和因子),类别为TaskClassif

  • 生存分析任务:响应变量为时间-事件型(time to an event),类别为拓展包mlr3proba 中的TaskSurv

  • 密度分析任务:非监督学习,用来预测密度,类别为拓展包mlr3proba 中的TaskDens

  • 聚类分析任务:类别为拓展包mlr3cluster中的TaskClust

  • 空间分析任务:类别为拓展包mlr3spatiotempcv中的TaskRegrSTTaskClassifST

  • 有序回归任务:响应变量为有序变量,类别为拓展包mlr3ordinal中的TaskOrdinal

上述很多任务类型来自mlr3的拓展包。实际上,mlr3工具包只保留了一些比较基础的功能,更复杂的功能则分门别类地由拓展包完成,这构成了以mlr3为核心的R工具包生态。如下图[1]

基于mlr3工具包的机器学习(1)——数据、模型、训练、预测_第1张图片

2.2 创建任务

TaskRegr为例,它是一个R6类,因此可以使用new方法来创建一个归属于该类的任务对象(可参见推文R语言与面向对象的编程(3):R6类)。

task0 <- TaskRegr$new(id = "task", backend = data,
                      target = "mpg")
class(task0)
## [1] "TaskRegr"       "TaskSupervised" "Task"           "R6"

也可以像前面一样使用快捷函数as_task_regr()

task <- as_task_regr(x = data, target = "mpg")

此外,mlr3工具包中还有如下同类型的函数:

## 创建分类任务
as_task_classif() 
## 创建非监督任务
as_task_unsupervised()

2.3 内置任务

mlr3工具包还内置了一些任务,储存在mlr_tasks对象中:

mlr_tasks 
##  with 11 stored values
## Keys: boston_housing, breast_cancer, german_credit, iris, mtcars,
##   penguins, pima, sonar, spam, wine, zoo

使用get方法可以提取某个内置任务:

task_mtcars <- mlr_tasks$get("mtcars")

更简便的方法是使用快捷函数tsk()

task_mtcars <- tsk("mtcars")

2.4 任务的属性和方法

任务作为一个R6对象,也有其属性和方法,下面一一进行介绍。

数据

data方法可以按行、按列查看数据;backend属性是备份数据,即创建任务时的data参数:

task_mtcars$data()
##      mpg am carb cyl  disp drat gear  hp  qsec vs    wt
##  1: 21.0  1    4   6 160.0 3.90    4 110 16.46  0 2.620
##  2: 21.0  1    4   6 160.0 3.90    4 110 17.02  0 2.875
##  3: 22.8  1    1   4 108.0 3.85    4  93 18.61  1 2.320
##  4: 21.4  0    1   6 258.0 3.08    3 110 19.44  1 3.215
##  5: 18.7  0    2   8 360.0 3.15    3 175 17.02  0 3.440
##  6: 18.1  0    1   6 225.0 2.76    3 105 20.22  1 3.460
task_mtcars$data(rows = c(1, 5, 10), cols = "mpg")
##     mpg
## 1: 21.0
## 2: 18.7
## 3: 19.2
task_mtcars$backend
##  (32x13)
##              model  mpg cyl disp  hp drat    wt  qsec vs am gear carb ..row_id
##          Mazda RX4 21.0   6  160 110 3.90 2.620 16.46  0  1    4    4        1
##      Mazda RX4 Wag 21.0   6  160 110 3.90 2.875 17.02  0  1    4    4        2
##         Datsun 710 22.8   4  108  93 3.85 2.320 18.61  1  1    4    1        3
##     Hornet 4 Drive 21.4   6  258 110 3.08 3.215 19.44  1  0    3    1        4
##  Hornet Sportabout 18.7   8  360 175 3.15 3.440 17.02  0  0    3    2        5
##            Valiant 18.1   6  225 105 2.76 3.460 20.22  1  0    3    1        6
## [...] (26 rows omitted)

行数和列数

nrow属性是行数,ncol属性是列数,row_ids属性是行编号:

task_mtcars$nrow
## [1] 32
task_mtcars$ncol
## [1] 11
task_mtcars$row_ids
##  [1]  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
## [26] 26 27 28 29 30 31 32

响应变量和预测变量

target_names属性是响应变量,即创建任务时的target参数;feature_names是预测变量:

task_mtcars$target_names
## [1] "mpg"
task_mtcars$feature_names
##  [1] "am"   "carb" "cyl"  "disp" "drat" "gear" "hp"   "qsec" "vs"   "wt"

行、列角色

col_roles属性是列的角色,对应值为列名;row_roles属性是行的角色,对应值为行编号:

task_mtcars$col_roles
## $feature
##  [1] "am"   "carb" "cyl"  "disp" "drat" "gear" "hp"   "qsec" "vs"   "wt"  
## 
## $target
## [1] "mpg"
## 
## $name
## [1] "model"
## 
## $order
## character(0)
## 
## $stratum
## character(0)
## 
## $group
## character(0)
## 
## $weight
## character(0)
task_mtcars$row_roles
## $use
##  [1]  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
## [26] 26 27 28 29 30 31 32
## 
## $holdout
## integer(0)
## 
## $early_stopping
## integer(0)

克隆

clone方法可以用来备份任务对象:

task_mtcars_clone <- task_mtcars$clone()

列选择

select方法类似dplyr工具包中的同名函数,用来选择除响应变量以外的列(变量):

task_mtcars$select(c("cyl", "disp")) 

task_mtcars$feature_names
## [1] "cyl"  "disp"

执行列选择后,被删除的列不再是预测变量。

行筛选

filter方法类似dplyr工具包中的同名函数,用来筛选行(样本):

task_mtcars$filter(2:5) 

task_mtcars$data()
##     mpg cyl disp
## 1: 21.0   6  160
## 2: 22.8   4  108
## 3: 21.4   6  258
## 4: 18.7   8  360
task_mtcars$backend 
##  (32x13)
##              model  mpg cyl disp  hp drat    wt  qsec vs am gear carb ..row_id
##          Mazda RX4 21.0   6  160 110 3.90 2.620 16.46  0  1    4    4        1
##      Mazda RX4 Wag 21.0   6  160 110 3.90 2.875 17.02  0  1    4    4        2
##         Datsun 710 22.8   4  108  93 3.85 2.320 18.61  1  1    4    1        3
##     Hornet 4 Drive 21.4   6  258 110 3.08 3.215 19.44  1  0    3    1        4
##  Hornet Sportabout 18.7   8  360 175 3.15 3.440 17.02  0  0    3    2        5
##            Valiant 18.1   6  225 105 2.76 3.460 20.22  1  0    3    1        6
## [...] (26 rows omitted)
  • data方法输出结果为当前的数据情况,会随列选择和行筛选等操作变化而变化;

  • backend属性为备份数据,是最初创建任务所用的data参数,不会变化。

需要注意的是,行筛选使用的数字是行编号,而非行序号:

task_mtcars$filter(2) 

task_mtcars$data()
##    mpg cyl disp
## 1:  21   6  160

前面task_mtcars$filter(2:5)所筛选的是行编号为2-5的样本,在筛选后其行编号不会改变,这样编号为2的行的序号就是1了。使用task_mtcars$filter(2)所筛选的是第1行数据,即行编号为2的数据。

数据合并

rbind方法可以用来增加行数,cbind方法可以用来增加列数:

task_mtcars$rbind(
  data.frame(mpg = 1, cyl = 2, disp = 3)
) 

task_mtcars$cbind(
  data.frame(data.frame(newvar = 1:2))
)

task_mtcars$data()
##    mpg cyl disp newvar
## 1:  21   6  160      1
## 2:   1   2    3      2

3 模型

learner对应着模型算法。mlr3工具包及其拓展包内置了许多learner,储存在mlr_learners对象中。同样,可以使用get方法提取某个learner

mlr_learners
##  with 6 stored values
## Keys: classif.debug, classif.featureless, classif.rpart, regr.debug,
##   regr.featureless, regr.rpart

learner0 <- mlr_learners$get("regr.rpart")

也可以使用快捷函数lrn()

learner = lrn("regr.rpart")

查看模型的参数设置:

learner$param_set
## 
##                 id    class lower upper nlevels        default value
##  1:             cp ParamDbl     0     1     Inf           0.01      
##  2:     keep_model ParamLgl    NA    NA       2          FALSE      
##  3:     maxcompete ParamInt     0   Inf     Inf              4      
##  4:       maxdepth ParamInt     1    30      30             30      
##  5:   maxsurrogate ParamInt     0   Inf     Inf              5      
##  6:      minbucket ParamInt     1   Inf     Inf       
##  7:       minsplit ParamInt     1   Inf     Inf             20      
##  8: surrogatestyle ParamInt     0     1       2              0      
##  9:   usesurrogate ParamInt     0     2       3              2      
## 10:           xval ParamInt     0   Inf     Inf             10     0

如果有需要也可以进行更改:

learner$param_set$values$cp = 0.01
learner$param_set
## 
##                 id    class lower upper nlevels        default value
##  1:             cp ParamDbl     0     1     Inf           0.01  0.01
##  2:     keep_model ParamLgl    NA    NA       2          FALSE      
##  3:     maxcompete ParamInt     0   Inf     Inf              4      
##  4:       maxdepth ParamInt     1    30      30             30      
##  5:   maxsurrogate ParamInt     0   Inf     Inf              5      
##  6:      minbucket ParamInt     1   Inf     Inf       
##  7:       minsplit ParamInt     1   Inf     Inf             20      
##  8: surrogatestyle ParamInt     0     1       2              0      
##  9:   usesurrogate ParamInt     0     2       3              2      
## 10:           xval ParamInt     0   Inf     Inf             10     0

4 训练和预测

使用前20行数据进行模型训练:

task_mtcars_clone$filter(1:20)
learner$train(task_mtcars_clone) 

learner$model 
## n= 20 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
## 1) root 20 792.76200 20.13000  
##   2) wt>=3.3275 11  93.85636 15.68182 *
##   3) wt< 3.3275 9 215.24000 25.56667 *

使用21-32行数据进行预测:

learner$predict_newdata(mtcars[21:32,])
##  for 12 observations:
##     row_ids truth response
##           1  21.5 25.56667
##           2  15.5 15.68182
##           3  15.2 15.68182
## ---                       
##          10  19.7 25.56667
##          11  15.0 15.68182
##          12  21.4 25.56667

参考资料

[1]

mlr3的工具包生态: https://mlr3book.mlr-org.com/intro.html#package-ecosystem

你可能感兴趣的:(基于mlr3工具包的机器学习(1)——数据、模型、训练、预测)