R语言机器学习mlr3:特征选择和hyperband调参

获取更多R语言和生信知识,请关注公众号:医学和生信笔记。
公众号后台回复R语言,即可获得海量学习资料!

目录

    • Hyperband调参
    • 特征选择
      • filters
      • 计算分数
      • 计算变量重要性
      • 组合方法(wrapper methods)
      • 自动选择

Hyperband调参

Hyperband调参可看做是一种特殊的随机搜索方式,俗话说:“鱼与熊掌不可兼得”,Hyperband就是取其一种,感兴趣的小伙伴可以自己学习一下。

在这里举一个简单的小例子说明:
假如你有8匹马,每匹马需要4个单位的食物才能发挥最好,但是你现在只有32个单位的食物,所以你需要制定一个策略,充分利用32个单位的食物(也就是你的计算资源)来找到最好的马。
两种策略,第一种:直接放弃4匹马,把所有的食物用在另外4匹马上,这样到最后你就能挑选出4匹马中最好的一匹。但是这样的问题就是你不知道被你舍弃的那4匹马会不会有更好的。
第2种策略:在最开始时每匹马给1个单位食物,然后看它们表现,把表现好的4匹留下,表现不好的就舍弃,给予剩下4匹马更多的食物,然后再把表现好的2匹留下,如此循环,最好把剩下的食物给最后1匹马。

我们主要介绍通过mlr3hyperband包实现这一方法。

library(mlr3verse)

set.seed(123)

ll = po("subsample") %>>% lrn("classif.rpart") # mlr3自带的管道符,先进行预处理

search_space = ps(
  classif.rpart.cp = p_dbl(lower = 0.001, upper = 0.1),
  classif.rpart.minsplit = p_int(lower = 1, upper = 10),
  subsample.frac = p_dbl(lower = 0.1, upper = 1, tags = "budget")
) # tags标记

instance = TuningInstanceSingleCrit$new(
  task = tsk("iris"),
  learner = ll,
  resampling = rsmp("holdout"),
  measure = msr("classif.ce"),
  terminator = trm("none"), # hyperband terminates itself
  search_space = search_space
)

接下来进行hyperband调参:

library(mlr3hyperband)

tuner <- tnr("hyperband", eta = 3)

lgr::get_logger("bbotk")$set_threshold("warn")

tuner$optimize(instance)
## INFO  [20:51:38.099] [mlr3] Running benchmark with 9 resampling iterations 
## INFO  [20:51:38.103] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:38.143] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:38.181] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:38.226] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:38.264] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:38.301] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:38.366] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:38.404] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:38.441] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:38.655] [mlr3] Finished benchmark 
## INFO  [20:51:38.804] [mlr3] Running benchmark with 8 resampling iterations 
## INFO  [20:51:38.807] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:38.846] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:38.883] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:38.924] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:38.961] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:38.998] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:39.035] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:39.073] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:39.111] [mlr3] Finished benchmark 
## INFO  [20:51:39.230] [mlr3] Running benchmark with 5 resampling iterations 
## INFO  [20:51:39.233] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:39.271] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:39.309] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:39.346] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:39.387] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [20:51:39.424] [mlr3] Finished benchmark
##    classif.rpart.cp classif.rpart.minsplit subsample.frac learner_param_vals
## 1:       0.07348139                      5      0.1111111          
##     x_domain classif.ce
## 1:        0.02

查看结果:

instance$result
##    classif.rpart.cp classif.rpart.minsplit subsample.frac learner_param_vals
## 1:       0.07348139                      5      0.1111111          
##     x_domain classif.ce
## 1:        0.02
instance$result_learner_param_vals
## $subsample.frac
## [1] 0.1111111
## 
## $subsample.stratify
## [1] FALSE
## 
## $subsample.replace
## [1] FALSE
## 
## $classif.rpart.xval
## [1] 0
## 
## $classif.rpart.cp
## [1] 0.07348139
## 
## $classif.rpart.minsplit
## [1] 5
instance$result_y
## classif.ce 
##       0.02

特征选择

特征选择也是一门艺术,当我们拿到一份数据时,有很多信息是冗余的,是无效的,对于建模是没有帮助的。这样的变量用于建模只会增加噪声,降低模型表现。把冗余信息去除,挑选最合适的变量的过程被称为特征选择

filters

这种方法首先把所有预测变量计算一个分数,然后按照分数进行排名,这样我们就可以根据分数挑选合适的预测变量了。

查看支持的计算分数的方法:

mlr_filters
##  with 20 stored values
## Keys: anova, auc, carscore, cmim, correlation, disr, find_correlation,
##   importance, information_gain, jmi, jmim, kruskal_test, mim, mrmr,
##   njmim, performance, permutation, relief, selected_features, variance

特征工程是很复杂的,想要详细了解的可阅读相关书籍。

计算分数

目前只支持分类和回归。

filter <- flt("jmim")

task <- tsk("iris")
filter$calculate(task)

filter
## 
## Task Types: classif, regr
## Task Properties: -
## Packages: mlr3filters, praznik
## Feature types: integer, numeric, factor, ordered
##         feature     score
## 1:  Petal.Width 1.0000000
## 2: Sepal.Length 0.6666667
## 3: Petal.Length 0.3333333
## 4:  Sepal.Width 0.0000000

可以看到每个变量都计算出来一个分数。

# 根据相关性挑选变量
filter_cor <- flt("correlation")

# 支持更改参数,默认是pearson
filter_cor$param_set
## 
##        id    class lower upper nlevels    default value
## 1:    use ParamFct    NA    NA       5 everything      
## 2: method ParamFct    NA    NA       3    pearson
# 可以更改为spearman
filter_cor$param_set$values <- list(method = "spearman")
filter_cor$param_set
## 
##        id    class lower upper nlevels    default    value
## 1:    use ParamFct    NA    NA       5 everything         
## 2: method ParamFct    NA    NA       3    pearson spearman

计算变量重要性

所有支持importance参数的learner都支持这种方法。

比如:

lrn <- lrn("classif.ranger", importance = "impurity")

task <- tsk("iris")
filter <- flt("importance", learner = lrn)
filter$calculate(task)
filter
## 
## Task Types: classif
## Task Properties: -
## Packages: mlr3filters, mlr3, mlr3learners, ranger
## Feature types: logical, integer, numeric, character, factor, ordered
##         feature     score
## 1: Petal.Length 44.420716
## 2:  Petal.Width 43.235616
## 3: Sepal.Length  9.470614
## 4:  Sepal.Width  2.180197

组合方法(wrapper methods)

和超参数调优很相似,mlr3fselect包提供支持。

library(mlr3fselect)

task <- tsk("pima")
learner <- lrn("classif.rpart")
hout <- rsmp("holdout")
measure <- msr("classif.ce")

evals20 <- trm("evals", n_evals = 20) # 设置何时停止

# 构建实例
instance <- FSelectInstanceSingleCrit$new(
  task = task,
  learner = learner,
  resampling = hout,
  measure = measure,
  terminator = evals20
)
instance
## 
## * State:  Not optimized
## * Objective: 
## * Search Space:
## 
##          id    class lower upper nlevels        default value
## 1:      age ParamLgl    NA    NA       2       
## 2:  glucose ParamLgl    NA    NA       2       
## 3:  insulin ParamLgl    NA    NA       2       
## 4:     mass ParamLgl    NA    NA       2       
## 5: pedigree ParamLgl    NA    NA       2       
## 6: pregnant ParamLgl    NA    NA       2       
## 7: pressure ParamLgl    NA    NA       2       
## 8:  triceps ParamLgl    NA    NA       2       
## * Terminator: 
## * Terminated: FALSE
## * Archive:
## 
## Null data.table (0 rows and 0 cols)

目前mlr3fselect支持以下方法:

  • Random Search(FSelectRandomSearch)
  • Exhaustive Search (FSelectorExhaustiveSearch)
  • Sequential Search (FSelectorSequential)
  • Recursive Feature Elimination (FSelectorRFE)
  • Design Points (FSelectorDesignPoints)

我们挑选一个随机搜索:

fselector <- fs("random_search")

开始运行:

lgr::get_logger("bbotk")$set_threshold("warn")

fselector$optimize(instance)
## INFO  [20:51:39.787] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [20:51:41.955] [mlr3] Finished benchmark
##      age glucose insulin  mass pedigree pregnant pressure triceps features
## 1: FALSE    TRUE   FALSE FALSE    FALSE    FALSE    FALSE   FALSE  glucose
##    classif.ce
## 1:     0.1875

查看选中的变量:

instance$result_feature_set
## [1] "glucose"

查看结果:

instance$result_y
## classif.ce 
##     0.1875
as.data.table(instance$archive)
##       age glucose insulin  mass pedigree pregnant pressure triceps classif.ce
##  1: FALSE   FALSE   FALSE FALSE    FALSE    FALSE     TRUE   FALSE  0.3828125
##  2:  TRUE   FALSE    TRUE FALSE     TRUE     TRUE     TRUE    TRUE  0.3593750
##  3: FALSE   FALSE    TRUE FALSE    FALSE    FALSE    FALSE   FALSE  0.2890625
##  4: FALSE    TRUE    TRUE  TRUE    FALSE    FALSE     TRUE   FALSE  0.2343750
##  5: FALSE    TRUE   FALSE  TRUE    FALSE     TRUE    FALSE   FALSE  0.2226562
##  6: FALSE    TRUE   FALSE FALSE    FALSE    FALSE    FALSE   FALSE  0.1875000
##  7: FALSE    TRUE   FALSE  TRUE    FALSE    FALSE    FALSE   FALSE  0.2226562
##  8: FALSE   FALSE    TRUE FALSE    FALSE     TRUE    FALSE   FALSE  0.2812500
##  9:  TRUE    TRUE    TRUE  TRUE     TRUE     TRUE     TRUE    TRUE  0.2265625
## 10:  TRUE   FALSE   FALSE FALSE    FALSE     TRUE     TRUE   FALSE  0.3085938
## 11:  TRUE    TRUE   FALSE FALSE    FALSE    FALSE    FALSE    TRUE  0.2343750
## 12: FALSE    TRUE   FALSE FALSE     TRUE    FALSE    FALSE    TRUE  0.2460938
## 13:  TRUE    TRUE    TRUE  TRUE    FALSE     TRUE     TRUE    TRUE  0.2539062
## 14: FALSE    TRUE   FALSE FALSE     TRUE    FALSE     TRUE    TRUE  0.2148438
## 15: FALSE    TRUE    TRUE  TRUE     TRUE     TRUE    FALSE    TRUE  0.2226562
## 16: FALSE   FALSE    TRUE FALSE     TRUE     TRUE    FALSE   FALSE  0.2968750
## 17: FALSE    TRUE   FALSE FALSE    FALSE    FALSE    FALSE   FALSE  0.1875000
## 18: FALSE   FALSE    TRUE  TRUE     TRUE     TRUE     TRUE   FALSE  0.3750000
## 19: FALSE    TRUE    TRUE  TRUE     TRUE     TRUE     TRUE    TRUE  0.2343750
## 20:  TRUE   FALSE    TRUE FALSE     TRUE     TRUE     TRUE    TRUE  0.3593750
##     runtime_learners           timestamp batch_nr      resample_result
##  1:             0.03 2022-02-27 20:51:39        1 
##  2:             0.05 2022-02-27 20:51:39        2 
##  3:             0.03 2022-02-27 20:51:40        3 
##  4:             0.03 2022-02-27 20:51:40        4 
##  5:             0.03 2022-02-27 20:51:40        5 
##  6:             0.03 2022-02-27 20:51:40        6 
##  7:             0.04 2022-02-27 20:51:40        7 
##  8:             0.04 2022-02-27 20:51:40        8 
##  9:             0.03 2022-02-27 20:51:40        9 
## 10:             0.03 2022-02-27 20:51:40       10 
## 11:             0.03 2022-02-27 20:51:40       11 
## 12:             0.03 2022-02-27 20:51:41       12 
## 13:             0.05 2022-02-27 20:51:41       13 
## 14:             0.05 2022-02-27 20:51:41       14 
## 15:             0.03 2022-02-27 20:51:41       15 
## 16:             0.03 2022-02-27 20:51:41       16 
## 17:             0.04 2022-02-27 20:51:41       17 
## 18:             0.05 2022-02-27 20:51:41       18 
## 19:             0.03 2022-02-27 20:51:41       19 
## 20:             0.04 2022-02-27 20:51:41       20 
instance$archive$benchmark_result$data
## 
##   Public:
##     as_data_table: function (view = NULL, reassemble_learners = TRUE, convert_predictions = TRUE, 
##     clone: function (deep = FALSE) 
##     combine: function (rdata) 
##     data: list
##     discard: function (backends = FALSE, models = FALSE) 
##     initialize: function (data = NULL, store_backends = TRUE) 
##     iterations: function (view = NULL) 
##     learner_states: function (view = NULL) 
##     learners: function (view = NULL, states = TRUE, reassemble = TRUE) 
##     logs: function (view = NULL, condition) 
##     prediction: function (view = NULL, predict_sets = "test") 
##     predictions: function (view = NULL, predict_sets = "test") 
##     resamplings: function (view = NULL) 
##     sweep: function () 
##     task_type: active binding
##     tasks: function (view = NULL) 
##     uhashes: function (view = NULL) 
##   Private:
##     deep_clone: function (name, value) 
##     get_view_index: function (view)

应用于模型,训练任务:

task$select(instance$result_feature_set) # 只使用选中的变量
learner$train(task)

自动选择

learner = lrn("classif.rpart")
terminator = trm("evals", n_evals = 10)
fselector = fs("random_search")

at = AutoFSelector$new(
  learner = learner,
  resampling = rsmp("holdout"),
  measure = msr("classif.ce"),
  terminator = terminator,
  fselector = fselector
)
at
## 
## * Model: -
## * Parameters: xval=0
## * Packages: mlr3, mlr3fselect, rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
##   twoclass, weights

比较不同的子集得到的模型表现:

grid = benchmark_grid(
  task = tsk("pima"),
  learner = list(at, lrn("classif.rpart")),
  resampling = rsmp("cv", folds = 3)
)

bmr = benchmark(grid, store_models = TRUE)
## INFO  [20:51:42.111] [mlr3] Running benchmark with 6 resampling iterations 
## INFO  [20:51:45.672] [mlr3] Finished benchmark
bmr$aggregate(msrs(c("classif.ce", "time_train")))
##    nr      resample_result task_id              learner_id resampling_id iters
## 1:  1     pima classif.rpart.fselector            cv     3
## 2:  2     pima           classif.rpart            cv     3
##    classif.ce time_train
## 1:  0.2539062          0
## 2:  0.2539062          0

获取更多R语言和生信知识,请关注公众号:医学和生信笔记。
公众号后台回复R语言,即可获得海量学习资料!

你可能感兴趣的:(R语言机器学习,r语言,机器学习,数据挖掘)