diff --git a/.Rbuildignore b/.Rbuildignore
index ff5d83a..18a6d7e 100644
--- a/.Rbuildignore
+++ b/.Rbuildignore
@@ -23,3 +23,4 @@
.vscode
^kaggle$
^oml_cache$
+^vignettes$
diff --git a/.gitignore b/.gitignore
index 7c9632c..8a8a349 100644
--- a/.gitignore
+++ b/.gitignore
@@ -98,3 +98,4 @@ rsconnect/
.Rprofile
kaggle/
README.html
+inst/doc
diff --git a/DESCRIPTION b/DESCRIPTION
index eccd887..bf5cf3e 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -2,14 +2,14 @@ Package: mlr3automl
Title: AutoML extention for 'mlr3'
Version: 0.0.1
Authors@R: c(
- person("Damir", "Pulatov", , "damirpolat@protonmail.com", role = c("cre", "aut")),
- person("Marc", "Becker", , "marcbecker@posteo.de", role = "aut",
- comment = c(ORCID = "0000-0002-8115-0400")),
+ person("Marc", "Becker", , "marcbecker@posteo.de", role = c("cre", "aut"),
+ comment = c(ORCID = "0000-0002-8115-0400")),
+ person("Damir", "Pulatov", , "damirpolat@protonmail.com", role = "aut"),
person("Baisu", "Zhou", , "baisu.zhou@outlook.com", role = "aut")
)
Description: Flexible AutoML system for the 'mlr3' ecosystem.
License: LGPL-3
-URL: https://github.com/mlr-org/mlr3automl
+URL: https://github.com/mlr-org/mlr3automl https://mlr3automl.mlr-org.com
BugReports: https://github.com/mlr-org/mlr3automl/issues
Depends:
mlr3 (>= 1.2.0.9000),
@@ -35,6 +35,7 @@ Suggests:
fastai,
glmnet,
kknn,
+ knitr,
lgr,
lightgbm,
MASS,
@@ -42,22 +43,24 @@ Suggests:
mlr3extralearners,
mlr3torch (>= 0.3.0.9000),
mlr3viz,
+ quarto,
ranger,
redux,
reticulate,
+ rmarkdown,
rpart,
testthat (>= 3.0.0),
xgboost
Remotes:
catboost/catboost/catboost/R-package,
eagerai/fastai,
+ mlr-org/bbotk,
mlr-org/mlr3,
mlr-org/mlr3extralearners,
mlr-org/mlr3learners,
mlr-org/mlr3mbo@so_config_6,
mlr-org/mlr3torch,
mlr-org/mlr3tuning,
- mlr-org/bbotk,
mlr-org/rush
Config/testthat/edition: 3
Config/testthat/parallel: false
diff --git a/R/Auto.R b/R/Auto.R
index d2866b5..3724b8c 100644
--- a/R/Auto.R
+++ b/R/Auto.R
@@ -3,7 +3,7 @@
#' @description
#' This class is the base class for all autos.
#'
-#' @include mlr_auto.R
+#' @include mlr_auto.R Auto.R
#'
#' @template param_id
#' @template param_task
@@ -14,6 +14,8 @@
#' @template param_large_data_set
#' @template param_size
#' @template param_devices
+#' @template param_pv
+#' @template param_graph
#'
#' @export
Auto = R6Class("Auto",
@@ -80,7 +82,7 @@ Auto = R6Class("Auto",
#' @description
#' Create the graph for the auto.
- graph = function(task, measure, n_threads, timeout, devices) {
+ graph = function(task, measure, n_threads, timeout, devices, pv) {
stop("Abstract")
},
@@ -154,6 +156,12 @@ Auto = R6Class("Auto",
#' Get the search space for the learner.
search_space = function(task) {
private$.search_space
+ },
+
+ #' @description
+ #' Modify the graph for the final model.
+ final_graph = function(graph, task, pv) {
+ graph
}
),
diff --git a/R/AutoCatboost.R b/R/AutoCatboost.R
index 2e65e6a..f8db213 100644
--- a/R/AutoCatboost.R
+++ b/R/AutoCatboost.R
@@ -1,6 +1,6 @@
#' @title Catboost Auto
#'
-#' @include mlr_auto.R
+#' @include mlr_auto.R Auto.R Auto.R
#'
#' @description
#' Catboost auto.
@@ -10,7 +10,12 @@
#' @template param_measure
#' @template param_n_threads
#' @template param_timeout
+#' @template param_memory_limit
+#' @template param_large_data_set
+#' @template param_size
#' @template param_devices
+#' @template param_pv
+#' @template param_graph
#'
#' @export
AutoCatboost = R6Class("AutoCatboost",
@@ -31,7 +36,7 @@ AutoCatboost = R6Class("AutoCatboost",
#' @description
#' Create the graph for the auto.
- graph = function(task, measure, n_threads, timeout, devices) {
+ graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
@@ -52,10 +57,17 @@ AutoCatboost = R6Class("AutoCatboost",
task_type = task_type)
set_threads(learner, n_threads)
- po("removeconstants", id = "catboost_removeconstants") %>>%
+
+ graph = po("removeconstants", id = "catboost_removeconstants") %>>%
po("colapply", id = "catboost_colapply", applicator = as.numeric, affect_columns = selector_type("integer")) %>>%
po("removeconstants", id = "catboost_post_removeconstants") %>>%
learner
+
+ if (task$nrow * task$ncol > pv$large_data_size) {
+ graph = po("subsample", frac = 0.25, stratify = inherits(task, "TaskClassif"), use_groups = FALSE, id = "catboost_subsample") %>>% graph
+ }
+
+ graph
},
#' @description
@@ -115,6 +127,14 @@ AutoCatboost = R6Class("AutoCatboost",
"merror" # default
)
}
+ },
+
+ #' @description
+ #' Modify the graph for the final model.
+ final_graph = function(graph, task, pv) {
+ if (task$nrow * task$ncol > pv$large_data_size) {
+ graph$param_set$set_values(catboost_subsample.frac = 1)
+ }
}
),
diff --git a/R/AutoExtraTrees.R b/R/AutoExtraTrees.R
index fc93819..fe21d7b 100644
--- a/R/AutoExtraTrees.R
+++ b/R/AutoExtraTrees.R
@@ -1,17 +1,21 @@
#' @title Extra Trees Auto
#'
-#' @include mlr_auto.R
+#' @include mlr_auto.R Auto.R
#'
#' @description
#' Extra Trees auto.
#'
#' @template param_id
-#' @template param_n_threads
-#' @template param_timeout
#' @template param_task
#' @template param_measure
+#' @template param_n_threads
+#' @template param_timeout
+#' @template param_memory_limit
+#' @template param_large_data_set
#' @template param_size
#' @template param_devices
+#' @template param_pv
+#' @template param_graph
#'
#' @export
AutoExtraTrees = R6Class("AutoExtraTrees",
@@ -37,7 +41,7 @@ AutoExtraTrees = R6Class("AutoExtraTrees",
#' @param measure ([mlr3::Measure]).
#' @param n_threads (`numeric(1)`).
#' @param timeout (`numeric(1)`).
- graph = function(task, measure, n_threads, timeout, devices) {
+ graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
@@ -53,13 +57,19 @@ AutoExtraTrees = R6Class("AutoExtraTrees",
sample.fraction = 1)
set_threads(learner, n_threads)
- po("removeconstants", id = "extra_trees_removeconstants") %>>%
+ graph = po("removeconstants", id = "extra_trees_removeconstants") %>>%
po("imputeoor", id = "extra_trees_imputeoor") %>>%
po("fixfactors", id = "extra_trees_fixfactors") %>>%
po("imputesample", affect_columns = selector_type(c("factor", "ordered")), id = "extra_trees_imputesample") %>>%
po("collapsefactors", target_level_count = 40, id = "extra_trees_collapse") %>>%
po("removeconstants", id = "extra_trees_post_removeconstants") %>>%
learner
+
+ if (task$nrow * task$ncol > pv$large_data_size) {
+ graph = po("subsample", frac = 0.25, stratify = inherits(task, "TaskClassif"), use_groups = FALSE, id = "extra_trees_subsample") %>>% graph
+ }
+
+ graph
},
#' @description
@@ -70,6 +80,14 @@ AutoExtraTrees = R6Class("AutoExtraTrees",
num_trees = 100
tree_size_bytes = task$nrow / 60000 * 1e6
ceiling((tree_size_bytes * num_trees) / 1e6)
+ },
+
+ #' @description
+ #' Modify the graph for the final model.
+ final_graph = function(graph, task, pv) {
+ if (task$nrow * task$ncol > pv$large_data_size) {
+ graph$param_set$set_values(extra_trees_subsample.frac = 1)
+ }
}
)
)
diff --git a/R/AutoFTTransformer.R b/R/AutoFTTransformer.R
index c8cc74c..7da8a39 100644
--- a/R/AutoFTTransformer.R
+++ b/R/AutoFTTransformer.R
@@ -1,6 +1,6 @@
#' @title FTTransformer Auto
#'
-#' @include mlr_auto.R
+#' @include mlr_auto.R Auto.R
#'
#' @description
#' FTTransformer auto.
@@ -11,6 +11,7 @@
#' @template param_n_threads
#' @template param_timeout
#' @template param_devices
+#' @template param_pv
#'
#' @export
AutoFTTransformer = R6Class("AutoFTTransformer",
@@ -31,7 +32,7 @@ AutoFTTransformer = R6Class("AutoFTTransformer",
#' @description
#' Create the graph for the auto.
- graph = function(task, measure, n_threads, timeout, devices) {
+ graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
diff --git a/R/AutoFastai.R b/R/AutoFastai.R
index b80585a..5087539 100644
--- a/R/AutoFastai.R
+++ b/R/AutoFastai.R
@@ -1,6 +1,6 @@
#' @title Fastai Auto
#'
-#' @include mlr_auto.R
+#' @include mlr_auto.R Auto.R
#'
#' @description
#' Fastai auto.
@@ -13,6 +13,7 @@
#' @template param_memory_limit
#' @template param_large_data_set
#' @template param_devices
+#' @template param_pv
#'
#' @export
AutoFastai = R6Class("AutoFastai",
@@ -49,7 +50,7 @@ AutoFastai = R6Class("AutoFastai",
#' @description
#' Create the graph for the auto.
- graph = function(task, measure, n_threads, timeout, devices) {
+ graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
diff --git a/R/AutoGlmnet.R b/R/AutoGlmnet.R
index be600b1..ac68196 100644
--- a/R/AutoGlmnet.R
+++ b/R/AutoGlmnet.R
@@ -1,6 +1,6 @@
#' @title Glmnet Auto
#'
-#' @include mlr_auto.R
+#' @include mlr_auto.R Auto.R
#'
#' @description
#' Glmnet auto.
@@ -11,6 +11,7 @@
#' @template param_n_threads
#' @template param_timeout
#' @template param_devices
+#' @template param_pv
#'
#' @export
AutoGlmnet = R6Class("AutoGlmnet",
@@ -30,7 +31,7 @@ AutoGlmnet = R6Class("AutoGlmnet",
#' @description
#' Create the graph for the auto.
- graph = function(task, measure, n_threads, timeout, devices) {
+ graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
diff --git a/R/AutoKknn.R b/R/AutoKknn.R
index a4275b9..91ed867 100644
--- a/R/AutoKknn.R
+++ b/R/AutoKknn.R
@@ -1,6 +1,6 @@
#' @title Kknn Auto
#'
-#' @include mlr_auto.R
+#' @include mlr_auto.R Auto.R
#'
#' @description
#' Kknn auto.
@@ -11,6 +11,7 @@
#' @template param_n_threads
#' @template param_timeout
#' @template param_devices
+#' @template param_pv
#'
#' @export
AutoKknn = R6Class("AutoKknn",
@@ -30,7 +31,7 @@ AutoKknn = R6Class("AutoKknn",
#' @description
#' Create the graph for the auto.
- graph = function(task, measure, n_threads, timeout, devices) {
+ graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
diff --git a/R/AutoLda.R b/R/AutoLda.R
index 9d66116..e0bec1a 100644
--- a/R/AutoLda.R
+++ b/R/AutoLda.R
@@ -1,6 +1,6 @@
#' @title Lda Auto
#'
-#' @include mlr_auto.R
+#' @include mlr_auto.R Auto.R
#'
#' @description
#' Lda auto.
@@ -14,6 +14,7 @@
#' @template param_n_threads
#' @template param_timeout
#' @template param_devices
+#' @template param_pv
#'
#' @export
AutoLda = R6Class("AutoLda",
@@ -33,7 +34,7 @@ AutoLda = R6Class("AutoLda",
#' @description
#' Create the graph for the auto.
- graph = function(task, measure, n_threads, timeout, devices) {
+ graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
diff --git a/R/AutoLightgbm.R b/R/AutoLightgbm.R
index e79b984..2c60eef 100644
--- a/R/AutoLightgbm.R
+++ b/R/AutoLightgbm.R
@@ -1,6 +1,6 @@
#' @title Lightgbm Auto
#'
-#' @include mlr_auto.R
+#' @include mlr_auto.R Auto.R
#'
#' @description
#' Lightgbm auto.
@@ -10,7 +10,12 @@
#' @template param_measure
#' @template param_n_threads
#' @template param_timeout
+#' @template param_memory_limit
+#' @template param_large_data_set
+#' @template param_size
#' @template param_devices
+#' @template param_pv
+#' @template param_graph
#'
#' @export
AutoLightgbm = R6Class("AutoLightgbm",
@@ -30,7 +35,7 @@ AutoLightgbm = R6Class("AutoLightgbm",
#' @description
#' Create the graph for the auto.
- graph = function(task, measure, n_threads, timeout, devices) {
+ graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
@@ -41,15 +46,19 @@ AutoLightgbm = R6Class("AutoLightgbm",
device_type = if ("cuda" %in% devices) "gpu" else "cpu"
- learner = lrn(sprintf("%s.lightgbm", task$task_type),
+ graph = lrn(sprintf("%s.lightgbm", task$task_type),
id = "lightgbm",
early_stopping_rounds = self$early_stopping_rounds(task),
callbacks = list(cb_timeout_lightgbm(timeout * 0.9)),
eval = self$internal_measure(measure, task),
device_type = device_type)
- set_threads(learner, n_threads)
+ set_threads(graph, n_threads)
- learner
+ if (task$nrow * task$ncol > pv$large_data_size) {
+ graph = po("subsample", frac = 0.25, stratify = inherits(task, "TaskClassif"), use_groups = FALSE, id = "lightgbm_subsample") %>>% graph
+ }
+
+ graph
},
#' @description
@@ -97,6 +106,15 @@ AutoLightgbm = R6Class("AutoLightgbm",
"classif.logloss" = "multi_logloss",
"multi_error") # default
}
+ },
+
+ #' @description
+ #' Modify the graph for the final model.
+ final_graph = function(graph, task, pv) {
+ if (task$nrow * task$ncol > pv$large_data_size) {
+ graph$param_set$set_values(lightgbm_subsample.frac = 1)
+ }
+ graph$param_set$set_values(lightgbm.callbacks = NULL)
}
),
diff --git a/R/AutoMlp.R b/R/AutoMlp.R
index 8108e99..36407d4 100644
--- a/R/AutoMlp.R
+++ b/R/AutoMlp.R
@@ -1,6 +1,6 @@
#' @title Mlp Auto
#'
-#' @include mlr_auto.R
+#' @include mlr_auto.R Auto.R
#'
#' @description
#' Mlp auto.
@@ -11,6 +11,7 @@
#' @template param_n_threads
#' @template param_timeout
#' @template param_devices
+#' @template param_pv
#'
#' @export
AutoMlp = R6Class("AutoMlp",
@@ -31,7 +32,7 @@ AutoMlp = R6Class("AutoMlp",
#' @description
#' Create the graph for the auto.
- graph = function(task, measure, n_threads, timeout, devices) {
+ graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
diff --git a/R/AutoRanger.R b/R/AutoRanger.R
index b019823..8195bfc 100644
--- a/R/AutoRanger.R
+++ b/R/AutoRanger.R
@@ -1,6 +1,6 @@
#' @title Ranger Auto
#'
-#' @include mlr_auto.R
+#' @include mlr_auto.R Auto.R
#'
#' @description
#' Ranger auto.
@@ -10,8 +10,12 @@
#' @template param_measure
#' @template param_n_threads
#' @template param_timeout
+#' @template param_memory_limit
+#' @template param_large_data_set
+#' @template param_size
#' @template param_devices
-#'
+#' @template param_pv
+#' @template param_graph
#'
#' @export
AutoRanger = R6Class("AutoRanger",
@@ -31,7 +35,7 @@ AutoRanger = R6Class("AutoRanger",
#' @description
#' Create the graph for the auto.
- graph = function(task, measure, n_threads, timeout, devices) {
+ graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
@@ -42,13 +46,19 @@ AutoRanger = R6Class("AutoRanger",
learner = lrn(sprintf("%s.ranger", task$task_type), id = "ranger")
set_threads(learner, n_threads)
- po("removeconstants", id = "ranger_removeconstants") %>>%
+ graph = po("removeconstants", id = "ranger_removeconstants") %>>%
po("imputeoor", id = "ranger_imputeoor") %>>%
po("fixfactors", id = "ranger_fixfactors") %>>%
po("imputesample", affect_columns = selector_type(c("factor", "ordered")), id = "ranger_imputesample") %>>%
po("collapsefactors", target_level_count = 100, id = "ranger_collapse") %>>%
po("removeconstants", id = "ranger_post_removeconstants") %>>%
learner
+
+ if (task$nrow * task$ncol > pv$large_data_size) {
+ graph = po("subsample", frac = 0.25, stratify = inherits(task, "TaskClassif"), use_groups = FALSE, id = "ranger_subsample") %>>% graph
+ }
+
+ graph
},
#' @description
@@ -62,6 +72,14 @@ AutoRanger = R6Class("AutoRanger",
memory_size = (tree_size * num_trees) / 1e6
lg$info("Ranger memory size: %s MB", round(memory_size))
ceiling(memory_size)
+ },
+
+ #' @description
+ #' Modify the graph for the final model.
+ final_graph = function(graph, task, pv) {
+ if (task$nrow * task$ncol > pv$large_data_size) {
+ graph$param_set$set_values(ranger_subsample.frac = 1)
+ }
}
),
diff --git a/R/AutoResNet.R b/R/AutoResNet.R
index e3264a7..96cc98d 100644
--- a/R/AutoResNet.R
+++ b/R/AutoResNet.R
@@ -1,6 +1,6 @@
#' @title ResNet Auto
#'
-#' @include mlr_auto.R
+#' @include mlr_auto.R Auto.R
#'
#' @description
#' ResNet auto.
@@ -11,6 +11,7 @@
#' @template param_n_threads
#' @template param_timeout
#' @template param_devices
+#' @template param_pv
#'
#' @export
AutoResNet = R6Class("AutoResNet",
@@ -30,7 +31,7 @@ AutoResNet = R6Class("AutoResNet",
#' @description
#' Create the graph for the auto.
- graph = function(task, measure, n_threads, timeout, devices) {
+ graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
diff --git a/R/AutoSvm.R b/R/AutoSvm.R
index 5b49027..aa6289a 100644
--- a/R/AutoSvm.R
+++ b/R/AutoSvm.R
@@ -1,6 +1,6 @@
#' @title Svm Auto
#'
-#' @include mlr_auto.R
+#' @include mlr_auto.R Auto.R
#'
#' @description
#' Svm auto.
@@ -11,6 +11,7 @@
#' @template param_n_threads
#' @template param_timeout
#' @template param_devices
+#' @template param_pv
#'
#' @export
AutoSvm = R6Class("AutoSvm",
@@ -30,7 +31,7 @@ AutoSvm = R6Class("AutoSvm",
#' @description
#' Create the graph for the auto.
- graph = function(task, measure, n_threads, timeout, devices) {
+ graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
diff --git a/R/AutoTabpfn.R b/R/AutoTabpfn.R
index 6a3beef..dda9d5e 100644
--- a/R/AutoTabpfn.R
+++ b/R/AutoTabpfn.R
@@ -1,6 +1,6 @@
#' @title Tabpfn Auto
#'
-#' @include mlr_auto.R
+#' @include mlr_auto.R Auto.R
#'
#' @description
#' Tabpfn auto.
@@ -13,6 +13,7 @@
#' @template param_memory_limit
#' @template param_large_data_set
#' @template param_devices
+#' @template param_pv
#'
#' @export
AutoTabpfn = R6Class("AutoTabpfn",
@@ -52,7 +53,7 @@ AutoTabpfn = R6Class("AutoTabpfn",
#' @description
#' Create the graph for the auto.
- graph = function(task, measure, n_threads, timeout, devices) {
+ graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
diff --git a/R/AutoXgboost.R b/R/AutoXgboost.R
index e786ca2..a66770d 100644
--- a/R/AutoXgboost.R
+++ b/R/AutoXgboost.R
@@ -1,6 +1,6 @@
#' @title Xgboost Auto
#'
-#' @include mlr_auto.R
+#' @include mlr_auto.R Auto.R
#'
#' @description
#' Xgboost auto.
@@ -10,7 +10,12 @@
#' @template param_measure
#' @template param_n_threads
#' @template param_timeout
+#' @template param_memory_limit
+#' @template param_large_data_set
+#' @template param_size
#' @template param_devices
+#' @template param_pv
+#' @template param_graph
#'
#' @export
AutoXgboost = R6Class("AutoXgboost",
@@ -30,7 +35,7 @@ AutoXgboost = R6Class("AutoXgboost",
#' @description
#' Create the graph for the auto.
- graph = function(task, measure, n_threads, timeout, devices) {
+ graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
@@ -52,13 +57,19 @@ AutoXgboost = R6Class("AutoXgboost",
learner$set_values(device = "cuda")
}
- po("removeconstants", id = "xgboost_removeconstants") %>>%
+ graph = po("removeconstants", id = "xgboost_removeconstants") %>>%
po("imputeoor", id = "xgboost_imputeoor") %>>%
po("fixfactors", id = "xgboost_fixfactors") %>>%
po("imputesample", affect_columns = selector_type(c("factor", "ordered")), id = "xgboost_imputesample") %>>%
po("encodeimpact", id = "xgboost_encode") %>>%
po("removeconstants", id = "xgboost_post_removeconstants") %>>%
learner
+
+ if (task$nrow * task$ncol > pv$large_data_size) {
+ graph = po("subsample", frac = 0.25, stratify = inherits(task, "TaskClassif"), use_groups = FALSE, id = "xgboost_subsample") %>>% graph
+ }
+
+ graph
},
#' @description
@@ -109,6 +120,15 @@ AutoXgboost = R6Class("AutoXgboost",
"merror" # default
)
}
+ },
+
+ #' @description
+ #' Modify the graph for the final model.
+ final_graph = function(graph, task, pv) {
+ if (task$nrow * task$ncol > pv$large_data_size) {
+ graph$param_set$set_values(xgboost_subsample.frac = 1)
+ }
+ graph$param_set$set_values(xgboost.callbacks = NULL)
}
),
diff --git a/R/train_auto.R b/R/train_auto.R
index 5cf5587..702391c 100644
--- a/R/train_auto.R
+++ b/R/train_auto.R
@@ -42,7 +42,7 @@ train_auto = function(self, private, task) {
error_config("All learners have no hyperparameters to tune. Combine with other learners.")
}
- branches = map(autos, function(auto) auto$graph(task, pv$measure, n_threads, pv$learner_timeout, pv$devices))
+ branches = map(autos, function(auto) auto$graph(task, pv$measure, n_threads, pv$learner_timeout, pv$devices, pv))
graph_learner = as_learner(po("branch", options = names(branches)) %>>%
gunion(unname(branches)) %>>%
po("unbranch", options = names(branches)), clone = TRUE)
@@ -139,15 +139,13 @@ train_auto = function(self, private, task) {
# fit final model
lg$info("Learner '%s' fits final model", self$id)
-
if (length(learners_with_validation)) {
set_validate(graph_learner, NULL, ids = learners_with_validation)
- # FIXME: remove this once we have a better way to handle this
- graph_learner$param_set$values$xgboost.callbacks = NULL
- graph_learner$param_set$values$lightgbm.callbacks = NULL
}
+
graph_learner$param_set$set_values(.values = self$instance$result_learner_param_vals, .insert = FALSE)
graph_learner$timeout = c(train = Inf, predict = Inf)
+ walk(autos, function(auto) auto$final_graph(graph_learner, task, pv))
graph_learner$train(task)
list(graph_learner = graph_learner, instance = self$instance)
diff --git a/man-roxygen/param_graph.R b/man-roxygen/param_graph.R
new file mode 100644
index 0000000..fd2812e
--- /dev/null
+++ b/man-roxygen/param_graph.R
@@ -0,0 +1,3 @@
+#' @param graph (`mlr3pipelines::GraphLearner`)\cr
+#' Graph learner.
+
diff --git a/man-roxygen/param_pv.R b/man-roxygen/param_pv.R
new file mode 100644
index 0000000..73a7570
--- /dev/null
+++ b/man-roxygen/param_pv.R
@@ -0,0 +1,3 @@
+#' @param pv (`list`)\cr
+#' Parameter values.
+
diff --git a/man/Auto.Rd b/man/Auto.Rd
index 9876eb5..0c66cd4 100644
--- a/man/Auto.Rd
+++ b/man/Auto.Rd
@@ -32,6 +32,7 @@ This class is the base class for all autos.
\item \href{#method-Auto-design_default}{\code{Auto$design_default()}}
\item \href{#method-Auto-design_set}{\code{Auto$design_set()}}
\item \href{#method-Auto-search_space}{\code{Auto$search_space()}}
+\item \href{#method-Auto-final_graph}{\code{Auto$final_graph()}}
\item \href{#method-Auto-clone}{\code{Auto$clone()}}
}
}
@@ -72,7 +73,7 @@ Creates a new instance of this \link[R6:R6Class]{R6} class.
\subsection{Method \code{check()}}{
Check if the auto is compatible with the task.
\subsection{Usage}{
-\if{html}{\out{
}}\preformatted{Auto$check(task, memory_limit = Inf, large_data_set = FALSE, devices)}\if{html}{\out{
}}
+\if{html}{\out{}}\preformatted{Auto$check(task, memory_limit = Inf, large_data_set = FALSE, devices, pv)}\if{html}{\out{
}}
}
\subsection{Arguments}{
@@ -88,6 +89,9 @@ Check if the auto is compatible with the task.
Devices to use.
Allowed values are \code{"cpu"} and \code{"cuda"}.
Default is "cpu".}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
}
\if{html}{\out{}}
}
@@ -98,7 +102,7 @@ Default is "cpu".}
\subsection{Method \code{graph()}}{
Create the graph for the auto.
\subsection{Usage}{
-\if{html}{\out{}}\preformatted{Auto$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}}
+\if{html}{\out{}}\preformatted{Auto$graph(task, measure, n_threads, timeout, devices, pv)}\if{html}{\out{
}}
}
\subsection{Arguments}{
@@ -116,6 +120,9 @@ Create the graph for the auto.
Devices to use.
Allowed values are \code{"cpu"} and \code{"cuda"}.
Default is "cpu".}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
}
\if{html}{\out{}}
}
@@ -210,6 +217,29 @@ Get the search space for the learner.
}
}
\if{html}{\out{ }}
+\if{html}{\out{ }}
+\if{latex}{\out{\hypertarget{method-Auto-final_graph}{}}}
+\subsection{Method \code{final_graph()}}{
+Modify the graph for the final model.
+\subsection{Usage}{
+\if{html}{\out{}}\preformatted{Auto$final_graph(graph, task, pv)}\if{html}{\out{
}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{}}
+\describe{
+\item{\code{graph}}{(\code{mlr3pipelines::GraphLearner})\cr
+Graph learner.}
+
+\item{\code{task}}{(\link[mlr3:Task]{mlr3::Task}).}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
+}
+\if{html}{\out{
}}
+}
+}
+\if{html}{\out{ }}
\if{html}{\out{ }}
\if{latex}{\out{\hypertarget{method-Auto-clone}{}}}
\subsection{Method \code{clone()}}{
diff --git a/man/AutoCatboost.Rd b/man/AutoCatboost.Rd
index 209448f..78bdb90 100644
--- a/man/AutoCatboost.Rd
+++ b/man/AutoCatboost.Rd
@@ -16,6 +16,7 @@ Catboost auto.
\item \href{#method-AutoCatboost-graph}{\code{AutoCatboost$graph()}}
\item \href{#method-AutoCatboost-estimate_memory}{\code{AutoCatboost$estimate_memory()}}
\item \href{#method-AutoCatboost-internal_measure}{\code{AutoCatboost$internal_measure()}}
+\item \href{#method-AutoCatboost-final_graph}{\code{AutoCatboost$final_graph()}}
\item \href{#method-AutoCatboost-clone}{\code{AutoCatboost$clone()}}
}
}
@@ -54,7 +55,7 @@ Identifier for the new instance.}
\subsection{Method \code{graph()}}{
Create the graph for the auto.
\subsection{Usage}{
-\if{html}{\out{}}\preformatted{AutoCatboost$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}}
+\if{html}{\out{}}\preformatted{AutoCatboost$graph(task, measure, n_threads, timeout, devices, pv)}\if{html}{\out{
}}
}
\subsection{Arguments}{
@@ -72,6 +73,9 @@ Create the graph for the auto.
Devices to use.
Allowed values are \code{"cpu"} and \code{"cuda"}.
Default is "cpu".}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
}
\if{html}{\out{}}
}
@@ -113,6 +117,29 @@ Get the internal measure for the auto.
}
}
\if{html}{\out{ }}
+\if{html}{\out{ }}
+\if{latex}{\out{\hypertarget{method-AutoCatboost-final_graph}{}}}
+\subsection{Method \code{final_graph()}}{
+Modify the graph for the final model.
+\subsection{Usage}{
+\if{html}{\out{}}\preformatted{AutoCatboost$final_graph(graph, task, pv)}\if{html}{\out{
}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{}}
+\describe{
+\item{\code{graph}}{(\code{mlr3pipelines::GraphLearner})\cr
+Graph learner.}
+
+\item{\code{task}}{(\link[mlr3:Task]{mlr3::Task}).}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
+}
+\if{html}{\out{
}}
+}
+}
+\if{html}{\out{ }}
\if{html}{\out{ }}
\if{latex}{\out{\hypertarget{method-AutoCatboost-clone}{}}}
\subsection{Method \code{clone()}}{
diff --git a/man/AutoExtraTrees.Rd b/man/AutoExtraTrees.Rd
index 453930f..d0fa575 100644
--- a/man/AutoExtraTrees.Rd
+++ b/man/AutoExtraTrees.Rd
@@ -15,6 +15,7 @@ Extra Trees auto.
\item \href{#method-AutoExtraTrees-new}{\code{AutoExtraTrees$new()}}
\item \href{#method-AutoExtraTrees-graph}{\code{AutoExtraTrees$graph()}}
\item \href{#method-AutoExtraTrees-estimate_memory}{\code{AutoExtraTrees$estimate_memory()}}
+\item \href{#method-AutoExtraTrees-final_graph}{\code{AutoExtraTrees$final_graph()}}
\item \href{#method-AutoExtraTrees-clone}{\code{AutoExtraTrees$clone()}}
}
}
@@ -53,7 +54,7 @@ Identifier for the new instance.}
\subsection{Method \code{graph()}}{
Create the graph for the auto.
\subsection{Usage}{
-\if{html}{\out{}}\preformatted{AutoExtraTrees$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}}
+\if{html}{\out{}}\preformatted{AutoExtraTrees$graph(task, measure, n_threads, timeout, devices, pv)}\if{html}{\out{
}}
}
\subsection{Arguments}{
@@ -71,6 +72,9 @@ Create the graph for the auto.
Devices to use.
Allowed values are \code{"cpu"} and \code{"cuda"}.
Default is "cpu".}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
}
\if{html}{\out{}}
}
@@ -93,6 +97,29 @@ Estimate the memory for the auto.
}
}
\if{html}{\out{ }}
+\if{html}{\out{}}
+\if{latex}{\out{\hypertarget{method-AutoExtraTrees-final_graph}{}}}
+\subsection{Method \code{final_graph()}}{
+Modify the graph for the final model.
+\subsection{Usage}{
+\if{html}{\out{}}\preformatted{AutoExtraTrees$final_graph(graph, task, pv)}\if{html}{\out{
}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{}}
+\describe{
+\item{\code{graph}}{(\code{mlr3pipelines::GraphLearner})\cr
+Graph learner.}
+
+\item{\code{task}}{(\link[mlr3:Task]{mlr3::Task}).}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
+}
+\if{html}{\out{
}}
+}
+}
+\if{html}{\out{ }}
\if{html}{\out{}}
\if{latex}{\out{\hypertarget{method-AutoExtraTrees-clone}{}}}
\subsection{Method \code{clone()}}{
diff --git a/man/AutoFTTransformer.Rd b/man/AutoFTTransformer.Rd
index 89617f7..e8413f5 100644
--- a/man/AutoFTTransformer.Rd
+++ b/man/AutoFTTransformer.Rd
@@ -19,12 +19,13 @@ FTTransformer auto.
}
}
\if{html}{\out{
-Inherited methods
+Inherited methods
@@ -53,7 +54,7 @@ Identifier for the new instance.}
\subsection{Method \code{graph()}}{
Create the graph for the auto.
\subsection{Usage}{
-\if{html}{\out{}}\preformatted{AutoFTTransformer$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}}
+\if{html}{\out{}}\preformatted{AutoFTTransformer$graph(task, measure, n_threads, timeout, devices, pv)}\if{html}{\out{
}}
}
\subsection{Arguments}{
@@ -71,6 +72,9 @@ Create the graph for the auto.
Devices to use.
Allowed values are \code{"cpu"} and \code{"cuda"}.
Default is "cpu".}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
}
\if{html}{\out{}}
}
diff --git a/man/AutoFastai.Rd b/man/AutoFastai.Rd
index 0ad0f9c..546dc90 100644
--- a/man/AutoFastai.Rd
+++ b/man/AutoFastai.Rd
@@ -26,6 +26,7 @@ Fastai auto.
mlr3automl::Auto$design_default()
mlr3automl::Auto$design_set()
mlr3automl::Auto$early_stopping_rounds()
+mlr3automl::Auto$final_graph()
mlr3automl::Auto$search_space()
@@ -85,7 +86,7 @@ Default is "cpu".}
\subsection{Method \code{graph()}}{
Create the graph for the auto.
\subsection{Usage}{
-\if{html}{\out{}}\preformatted{AutoFastai$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}}
+\if{html}{\out{}}\preformatted{AutoFastai$graph(task, measure, n_threads, timeout, devices, pv)}\if{html}{\out{
}}
}
\subsection{Arguments}{
@@ -103,6 +104,9 @@ Create the graph for the auto.
Devices to use.
Allowed values are \code{"cpu"} and \code{"cuda"}.
Default is "cpu".}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
}
\if{html}{\out{}}
}
diff --git a/man/AutoGlmnet.Rd b/man/AutoGlmnet.Rd
index f5188f9..ba900e2 100644
--- a/man/AutoGlmnet.Rd
+++ b/man/AutoGlmnet.Rd
@@ -25,6 +25,7 @@ Glmnet auto.
mlr3automl::Auto$design_set()
mlr3automl::Auto$early_stopping_rounds()
mlr3automl::Auto$estimate_memory()
+mlr3automl::Auto$final_graph()
mlr3automl::Auto$search_space()
@@ -53,7 +54,7 @@ Identifier for the new instance.}
\subsection{Method \code{graph()}}{
Create the graph for the auto.
\subsection{Usage}{
-\if{html}{\out{}}\preformatted{AutoGlmnet$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}}
+\if{html}{\out{}}\preformatted{AutoGlmnet$graph(task, measure, n_threads, timeout, devices, pv)}\if{html}{\out{
}}
}
\subsection{Arguments}{
@@ -71,6 +72,9 @@ Create the graph for the auto.
Devices to use.
Allowed values are \code{"cpu"} and \code{"cuda"}.
Default is "cpu".}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
}
\if{html}{\out{}}
}
diff --git a/man/AutoKknn.Rd b/man/AutoKknn.Rd
index c0c69ab..a78f2ee 100644
--- a/man/AutoKknn.Rd
+++ b/man/AutoKknn.Rd
@@ -19,13 +19,14 @@ Kknn auto.
}
}
\if{html}{\out{
-Inherited methods
+Inherited methods
}}
@@ -53,7 +54,7 @@ Identifier for the new instance.}
\subsection{Method \code{graph()}}{
Create the graph for the auto.
\subsection{Usage}{
-\if{html}{\out{}}\preformatted{AutoKknn$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}}
+\if{html}{\out{}}\preformatted{AutoKknn$graph(task, measure, n_threads, timeout, devices, pv)}\if{html}{\out{
}}
}
\subsection{Arguments}{
@@ -71,6 +72,9 @@ Create the graph for the auto.
Devices to use.
Allowed values are \code{"cpu"} and \code{"cuda"}.
Default is "cpu".}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
}
\if{html}{\out{}}
}
diff --git a/man/AutoLda.Rd b/man/AutoLda.Rd
index 3bf46b5..f394ffe 100644
--- a/man/AutoLda.Rd
+++ b/man/AutoLda.Rd
@@ -25,6 +25,7 @@ Lda auto.
mlr3automl::Auto$design_set()
mlr3automl::Auto$early_stopping_rounds()
mlr3automl::Auto$estimate_memory()
+mlr3automl::Auto$final_graph()
mlr3automl::Auto$search_space()
@@ -53,7 +54,7 @@ Identifier for the new instance.}
\subsection{Method \code{graph()}}{
Create the graph for the auto.
\subsection{Usage}{
-\if{html}{\out{}}\preformatted{AutoLda$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}}
+\if{html}{\out{}}\preformatted{AutoLda$graph(task, measure, n_threads, timeout, devices, pv)}\if{html}{\out{
}}
}
\subsection{Arguments}{
@@ -75,6 +76,9 @@ Create the graph for the auto.
Devices to use.
Allowed values are \code{"cpu"} and \code{"cuda"}.
Default is "cpu".}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
}
\if{html}{\out{}}
}
diff --git a/man/AutoLightgbm.Rd b/man/AutoLightgbm.Rd
index 55b31bf..1fa4fa7 100644
--- a/man/AutoLightgbm.Rd
+++ b/man/AutoLightgbm.Rd
@@ -16,6 +16,7 @@ Lightgbm auto.
\item \href{#method-AutoLightgbm-graph}{\code{AutoLightgbm$graph()}}
\item \href{#method-AutoLightgbm-estimate_memory}{\code{AutoLightgbm$estimate_memory()}}
\item \href{#method-AutoLightgbm-internal_measure}{\code{AutoLightgbm$internal_measure()}}
+\item \href{#method-AutoLightgbm-final_graph}{\code{AutoLightgbm$final_graph()}}
\item \href{#method-AutoLightgbm-clone}{\code{AutoLightgbm$clone()}}
}
}
@@ -54,7 +55,7 @@ Identifier for the new instance.}
\subsection{Method \code{graph()}}{
Create the graph for the auto.
\subsection{Usage}{
-\if{html}{\out{}}\preformatted{AutoLightgbm$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}}
+\if{html}{\out{}}\preformatted{AutoLightgbm$graph(task, measure, n_threads, timeout, devices, pv)}\if{html}{\out{
}}
}
\subsection{Arguments}{
@@ -72,6 +73,9 @@ Create the graph for the auto.
Devices to use.
Allowed values are \code{"cpu"} and \code{"cuda"}.
Default is "cpu".}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
}
\if{html}{\out{}}
}
@@ -113,6 +117,29 @@ Get the internal measure for the auto.
}
}
\if{html}{\out{ }}
+\if{html}{\out{ }}
+\if{latex}{\out{\hypertarget{method-AutoLightgbm-final_graph}{}}}
+\subsection{Method \code{final_graph()}}{
+Modify the graph for the final model.
+\subsection{Usage}{
+\if{html}{\out{}}\preformatted{AutoLightgbm$final_graph(graph, task, pv)}\if{html}{\out{
}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{}}
+\describe{
+\item{\code{graph}}{(\code{mlr3pipelines::GraphLearner})\cr
+Graph learner.}
+
+\item{\code{task}}{(\link[mlr3:Task]{mlr3::Task}).}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
+}
+\if{html}{\out{
}}
+}
+}
+\if{html}{\out{ }}
\if{html}{\out{ }}
\if{latex}{\out{\hypertarget{method-AutoLightgbm-clone}{}}}
\subsection{Method \code{clone()}}{
diff --git a/man/AutoMlp.Rd b/man/AutoMlp.Rd
index ec2db5e..98ff7c5 100644
--- a/man/AutoMlp.Rd
+++ b/man/AutoMlp.Rd
@@ -19,12 +19,13 @@ Mlp auto.
}
}
\if{html}{\out{
-Inherited methods
+Inherited methods
@@ -53,7 +54,7 @@ Identifier for the new instance.}
\subsection{Method \code{graph()}}{
Create the graph for the auto.
\subsection{Usage}{
-\if{html}{\out{}}\preformatted{AutoMlp$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}}
+\if{html}{\out{}}\preformatted{AutoMlp$graph(task, measure, n_threads, timeout, devices, pv)}\if{html}{\out{
}}
}
\subsection{Arguments}{
@@ -71,6 +72,9 @@ Create the graph for the auto.
Devices to use.
Allowed values are \code{"cpu"} and \code{"cuda"}.
Default is "cpu".}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
}
\if{html}{\out{}}
}
diff --git a/man/AutoRanger.Rd b/man/AutoRanger.Rd
index e64aa9b..f017d43 100644
--- a/man/AutoRanger.Rd
+++ b/man/AutoRanger.Rd
@@ -15,6 +15,7 @@ Ranger auto.
\item \href{#method-AutoRanger-new}{\code{AutoRanger$new()}}
\item \href{#method-AutoRanger-graph}{\code{AutoRanger$graph()}}
\item \href{#method-AutoRanger-estimate_memory}{\code{AutoRanger$estimate_memory()}}
+\item \href{#method-AutoRanger-final_graph}{\code{AutoRanger$final_graph()}}
\item \href{#method-AutoRanger-clone}{\code{AutoRanger$clone()}}
}
}
@@ -53,7 +54,7 @@ Identifier for the new instance.}
\subsection{Method \code{graph()}}{
Create the graph for the auto.
\subsection{Usage}{
-\if{html}{\out{}}\preformatted{AutoRanger$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}}
+\if{html}{\out{}}\preformatted{AutoRanger$graph(task, measure, n_threads, timeout, devices, pv)}\if{html}{\out{
}}
}
\subsection{Arguments}{
@@ -71,6 +72,9 @@ Create the graph for the auto.
Devices to use.
Allowed values are \code{"cpu"} and \code{"cuda"}.
Default is "cpu".}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
}
\if{html}{\out{}}
}
@@ -93,6 +97,29 @@ Estimate the memory for the auto.
}
}
\if{html}{\out{ }}
+\if{html}{\out{ }}
+\if{latex}{\out{\hypertarget{method-AutoRanger-final_graph}{}}}
+\subsection{Method \code{final_graph()}}{
+Modify the graph for the final model.
+\subsection{Usage}{
+\if{html}{\out{}}\preformatted{AutoRanger$final_graph(graph, task, pv)}\if{html}{\out{
}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{}}
+\describe{
+\item{\code{graph}}{(\code{mlr3pipelines::GraphLearner})\cr
+Graph learner.}
+
+\item{\code{task}}{(\link[mlr3:Task]{mlr3::Task}).}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
+}
+\if{html}{\out{
}}
+}
+}
+\if{html}{\out{ }}
\if{html}{\out{ }}
\if{latex}{\out{\hypertarget{method-AutoRanger-clone}{}}}
\subsection{Method \code{clone()}}{
diff --git a/man/AutoResNet.Rd b/man/AutoResNet.Rd
index c54bfc2..720d7b7 100644
--- a/man/AutoResNet.Rd
+++ b/man/AutoResNet.Rd
@@ -19,12 +19,13 @@ ResNet auto.
}
}
\if{html}{\out{
-Inherited methods
+Inherited methods
@@ -53,7 +54,7 @@ Identifier for the new instance.}
\subsection{Method \code{graph()}}{
Create the graph for the auto.
\subsection{Usage}{
-\if{html}{\out{}}\preformatted{AutoResNet$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}}
+\if{html}{\out{}}\preformatted{AutoResNet$graph(task, measure, n_threads, timeout, devices, pv)}\if{html}{\out{
}}
}
\subsection{Arguments}{
@@ -71,6 +72,9 @@ Create the graph for the auto.
Devices to use.
Allowed values are \code{"cpu"} and \code{"cuda"}.
Default is "cpu".}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
}
\if{html}{\out{}}
}
diff --git a/man/AutoSvm.Rd b/man/AutoSvm.Rd
index e2191f9..750dc80 100644
--- a/man/AutoSvm.Rd
+++ b/man/AutoSvm.Rd
@@ -26,6 +26,7 @@ Svm auto.
mlr3automl::Auto$design_set()
mlr3automl::Auto$early_stopping_rounds()
mlr3automl::Auto$estimate_memory()
+mlr3automl::Auto$final_graph()
mlr3automl::Auto$search_space()
@@ -54,7 +55,7 @@ Identifier for the new instance.}
\subsection{Method \code{graph()}}{
Create the graph for the auto.
\subsection{Usage}{
-\if{html}{\out{}}\preformatted{AutoSvm$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}}
+\if{html}{\out{}}\preformatted{AutoSvm$graph(task, measure, n_threads, timeout, devices, pv)}\if{html}{\out{
}}
}
\subsection{Arguments}{
@@ -72,6 +73,9 @@ Create the graph for the auto.
Devices to use.
Allowed values are \code{"cpu"} and \code{"cuda"}.
Default is "cpu".}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
}
\if{html}{\out{}}
}
diff --git a/man/AutoTabpfn.Rd b/man/AutoTabpfn.Rd
index 98f3244..1736600 100644
--- a/man/AutoTabpfn.Rd
+++ b/man/AutoTabpfn.Rd
@@ -26,6 +26,7 @@ Tabpfn auto.
mlr3automl::Auto$design_default()
mlr3automl::Auto$design_set()
mlr3automl::Auto$early_stopping_rounds()
+mlr3automl::Auto$final_graph()
}}
@@ -84,7 +85,7 @@ Default is "cpu".}
\subsection{Method \code{graph()}}{
Create the graph for the auto.
\subsection{Usage}{
-\if{html}{\out{}}\preformatted{AutoTabpfn$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}}
+\if{html}{\out{}}\preformatted{AutoTabpfn$graph(task, measure, n_threads, timeout, devices, pv)}\if{html}{\out{
}}
}
\subsection{Arguments}{
@@ -102,6 +103,9 @@ Create the graph for the auto.
Devices to use.
Allowed values are \code{"cpu"} and \code{"cuda"}.
Default is "cpu".}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
}
\if{html}{\out{}}
}
diff --git a/man/AutoXgboost.Rd b/man/AutoXgboost.Rd
index aa246a1..6a6b65b 100644
--- a/man/AutoXgboost.Rd
+++ b/man/AutoXgboost.Rd
@@ -16,6 +16,7 @@ Xgboost auto.
\item \href{#method-AutoXgboost-graph}{\code{AutoXgboost$graph()}}
\item \href{#method-AutoXgboost-estimate_memory}{\code{AutoXgboost$estimate_memory()}}
\item \href{#method-AutoXgboost-internal_measure}{\code{AutoXgboost$internal_measure()}}
+\item \href{#method-AutoXgboost-final_graph}{\code{AutoXgboost$final_graph()}}
\item \href{#method-AutoXgboost-clone}{\code{AutoXgboost$clone()}}
}
}
@@ -54,7 +55,7 @@ Identifier for the new instance.}
\subsection{Method \code{graph()}}{
Create the graph for the auto.
\subsection{Usage}{
-\if{html}{\out{}}\preformatted{AutoXgboost$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}}
+\if{html}{\out{}}\preformatted{AutoXgboost$graph(task, measure, n_threads, timeout, devices, pv)}\if{html}{\out{
}}
}
\subsection{Arguments}{
@@ -72,6 +73,9 @@ Create the graph for the auto.
Devices to use.
Allowed values are \code{"cpu"} and \code{"cuda"}.
Default is "cpu".}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
}
\if{html}{\out{}}
}
@@ -113,6 +117,29 @@ Get the internal measure for the auto.
}
}
\if{html}{\out{ }}
+\if{html}{\out{ }}
+\if{latex}{\out{\hypertarget{method-AutoXgboost-final_graph}{}}}
+\subsection{Method \code{final_graph()}}{
+Modify the graph for the final model.
+\subsection{Usage}{
+\if{html}{\out{}}\preformatted{AutoXgboost$final_graph(graph, task, pv)}\if{html}{\out{
}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{}}
+\describe{
+\item{\code{graph}}{(\code{mlr3pipelines::GraphLearner})\cr
+Graph learner.}
+
+\item{\code{task}}{(\link[mlr3:Task]{mlr3::Task}).}
+
+\item{\code{pv}}{(\code{list})\cr
+Parameter values.}
+}
+\if{html}{\out{
}}
+}
+}
+\if{html}{\out{ }}
\if{html}{\out{ }}
\if{latex}{\out{\hypertarget{method-AutoXgboost-clone}{}}}
\subsection{Method \code{clone()}}{
diff --git a/man/mlr3automl-package.Rd b/man/mlr3automl-package.Rd
index b6d272c..8f40b68 100644
--- a/man/mlr3automl-package.Rd
+++ b/man/mlr3automl-package.Rd
@@ -13,17 +13,17 @@ Flexible AutoML system for the 'mlr3' ecosystem.
\seealso{
Useful links:
\itemize{
- \item \url{https://github.com/mlr-org/mlr3automl}
+ \item \url{https://github.com/mlr-org/mlr3automl https://mlr3automl.mlr-org.com}
\item Report bugs at \url{https://github.com/mlr-org/mlr3automl/issues}
}
}
\author{
-\strong{Maintainer}: Damir Pulatov \email{damirpolat@protonmail.com}
+\strong{Maintainer}: Marc Becker \email{marcbecker@posteo.de} (\href{https://orcid.org/0000-0002-8115-0400}{ORCID})
Authors:
\itemize{
- \item Marc Becker \email{marcbecker@posteo.de} (\href{https://orcid.org/0000-0002-8115-0400}{ORCID})
+ \item Damir Pulatov \email{damirpolat@protonmail.com}
\item Baisu Zhou \email{baisu.zhou@outlook.com}
}
diff --git a/pkgdown/_pkgdown.yml b/pkgdown/_pkgdown.yml
new file mode 100644
index 0000000..652706b
--- /dev/null
+++ b/pkgdown/_pkgdown.yml
@@ -0,0 +1,96 @@
+url: https://mlr3automl.mlr-org.com
+
+template:
+ bootstrap: 5
+ light-switch: true
+ math-rendering: mathjax
+ package: mlr3pkgdowntemplate
+
+development:
+ mode: auto
+ version_label: default
+ version_tooltip: "Version"
+
+toc:
+ depth: 3
+
+navbar:
+ structure:
+ left: [intro, reference, news, book, articles]
+ right: [search, github, mattermost, stackoverflow, rss, lightswitch]
+ components:
+ home: ~
+ reference:
+ icon: fa fa-file-alt
+ text: Reference
+ href: reference/index.html
+ mattermost:
+ icon: fa fa-comments
+ href: https://lmmisld-lmu-stats-slds.srv.mwn.de/mlr_invite/
+ book:
+ text: mlr3book
+ icon: fa fa-link
+ href: https://mlr3book.mlr-org.com
+ stackoverflow:
+ icon: fab fa-stack-overflow
+ href: https://stackoverflow.com/questions/tagged/mlr3
+ rss:
+ icon: fa-rss
+ href: https://mlr-org.com/
+
+reference:
+ - title: Classification Learner
+ contents:
+ - LearnerClassifAuto
+ - LearnerClassifAutoCatboost
+ - LearnerClassifAutoFTTransformer
+ - LearnerClassifAutoFastai
+ - LearnerClassifAutoGlmnet
+ - LearnerClassifAutoKKNN
+ - LearnerClassifAutoLightGBM
+ - LearnerClassifAutoMLP
+ - LearnerClassifAutoRanger
+ - LearnerClassifAutoResNet
+ - LearnerClassifAutoSVM
+ - LearnerClassifAutoTabPFN
+ - LearnerClassifAutoXgboost
+ - title: Regression Learner
+ contents:
+ - LearnerRegrAuto
+ - LearnerRegrAutoCatboost
+ - LearnerRegrAutoExtraTrees
+ - LearnerRegrAutoFTTransformer
+ - LearnerRegrAutoFastai
+ - LearnerRegrAutoGlmnet
+ - LearnerRegrAutoKKNN
+ - LearnerRegrAutoLightGBM
+ - LearnerRegrAutoMLP
+ - LearnerRegrAutoRanger
+ - LearnerRegrAutoResNet
+ - LearnerRegrAutoSVM
+ - LearnerRegrAutoTabPFN
+ - LearnerRegrAutoXgboost
+ - title: Auto
+ contents:
+ - Auto
+ - AutoCatboost
+ - AutoExtraTrees
+ - AutoFTTransformer
+ - AutoFastai
+ - AutoGlmnet
+ - AutoKknn
+ - AutoLda
+ - AutoLightgbm
+ - AutoMlp
+ - AutoRanger
+ - AutoResNet
+ - AutoSvm
+ - AutoTabpfn
+ - AutoXgboost
+ - mlr_auto
+ - title: Callbacks
+ contents:
+ - mlr3automl.initial_design_runtime
+ - title: Package
+ contents:
+ - mlr3automl-package
diff --git a/pkgdown/favicon/apple-touch-icon.png b/pkgdown/favicon/apple-touch-icon.png
new file mode 100644
index 0000000..6231a01
Binary files /dev/null and b/pkgdown/favicon/apple-touch-icon.png differ
diff --git a/pkgdown/favicon/favicon-96x96.png b/pkgdown/favicon/favicon-96x96.png
new file mode 100644
index 0000000..ef4c666
Binary files /dev/null and b/pkgdown/favicon/favicon-96x96.png differ
diff --git a/pkgdown/favicon/favicon.ico b/pkgdown/favicon/favicon.ico
new file mode 100644
index 0000000..f51e833
Binary files /dev/null and b/pkgdown/favicon/favicon.ico differ
diff --git a/pkgdown/favicon/favicon.svg b/pkgdown/favicon/favicon.svg
new file mode 100644
index 0000000..ecc31f2
--- /dev/null
+++ b/pkgdown/favicon/favicon.svg
@@ -0,0 +1,3 @@
+
\ No newline at end of file
diff --git a/pkgdown/favicon/site.webmanifest b/pkgdown/favicon/site.webmanifest
new file mode 100644
index 0000000..4ebda26
--- /dev/null
+++ b/pkgdown/favicon/site.webmanifest
@@ -0,0 +1,21 @@
+{
+ "name": "",
+ "short_name": "",
+ "icons": [
+ {
+ "src": "/web-app-manifest-192x192.png",
+ "sizes": "192x192",
+ "type": "image/png",
+ "purpose": "maskable"
+ },
+ {
+ "src": "/web-app-manifest-512x512.png",
+ "sizes": "512x512",
+ "type": "image/png",
+ "purpose": "maskable"
+ }
+ ],
+ "theme_color": "#ffffff",
+ "background_color": "#ffffff",
+ "display": "standalone"
+}
\ No newline at end of file
diff --git a/pkgdown/favicon/web-app-manifest-192x192.png b/pkgdown/favicon/web-app-manifest-192x192.png
new file mode 100644
index 0000000..07fa141
Binary files /dev/null and b/pkgdown/favicon/web-app-manifest-192x192.png differ
diff --git a/pkgdown/favicon/web-app-manifest-512x512.png b/pkgdown/favicon/web-app-manifest-512x512.png
new file mode 100644
index 0000000..f889a12
Binary files /dev/null and b/pkgdown/favicon/web-app-manifest-512x512.png differ
diff --git a/tests/testthat/test_LearnerClassifAuto.R b/tests/testthat/test_LearnerClassifAuto.R
index 29aeba9..08e3ed6 100644
--- a/tests/testthat/test_LearnerClassifAuto.R
+++ b/tests/testthat/test_LearnerClassifAuto.R
@@ -59,6 +59,8 @@ test_that("all learner on cpu work", {
expect_class(learner$train(task), "LearnerClassifAuto")
expect_set_equal(learner$model$instance$archive$data$branch.selection, c("catboost", "glmnet", "kknn", "lightgbm", "ranger", "svm", "xgboost", "lda", "extra_trees"))
+ expect_null(learner$model$graph_learner$param_set$values$xgboost.callbacks)
+ expect_null(learner$model$graph_learner$param_set$values$lightgbm.callbacks)
})
test_that("memory limit works", {
@@ -123,6 +125,8 @@ test_that("large data set switch works", {
rush_plan(n_workers = 2, worker_type = "remote")
mirai::daemons(2)
+ options(bbotk.debug = TRUE)
+
task = tsk("penguins")
learner = lrn("classif.auto",
learner_ids = c("catboost", "glmnet", "kknn", "lightgbm", "ranger", "svm", "xgboost", "lda", "extra_trees"),
@@ -140,6 +144,16 @@ test_that("large data set switch works", {
expect_class(learner$train(task), "LearnerClassifAuto")
expect_set_equal(learner$model$instance$archive$data$branch.selection, c("ranger", "xgboost", "catboost", "lightgbm", "extra_trees"))
+ expect_equal(learner$model$graph_learner$param_set$values$ranger_subsample.frac, 1)
+ expect_equal(learner$model$graph_learner$param_set$values$xgboost_subsample.frac, 1)
+ expect_equal(learner$model$graph_learner$param_set$values$catboost_subsample.frac, 1)
+ expect_equal(learner$model$graph_learner$param_set$values$lightgbm_subsample.frac, 1)
+ expect_equal(learner$model$graph_learner$param_set$values$extra_trees_subsample.frac, 1)
+ expect_true(learner$model$graph_learner$param_set$values$ranger_subsample.stratify)
+ expect_true(learner$model$graph_learner$param_set$values$xgboost_subsample.stratify)
+ expect_true(learner$model$graph_learner$param_set$values$catboost_subsample.stratify)
+ expect_true(learner$model$graph_learner$param_set$values$lightgbm_subsample.stratify)
+ expect_true(learner$model$graph_learner$param_set$values$extra_trees_subsample.stratify)
})
test_that("resample works", {
diff --git a/vignettes/.gitignore b/vignettes/.gitignore
new file mode 100644
index 0000000..9e2bd63
--- /dev/null
+++ b/vignettes/.gitignore
@@ -0,0 +1,4 @@
+*.html
+*.R
+
+/.quarto/
diff --git a/vignettes/binary.csv b/vignettes/binary.csv
new file mode 100644
index 0000000..1fcfc04
--- /dev/null
+++ b/vignettes/binary.csv
@@ -0,0 +1,22 @@
+task_id,instances,features,n,error_rate,auc_automl,auc_autogluon,auc_autosklearn,auc_ranger,auc_nsrb_catboost,auc_nsrb_xgboost,auc_nsrb_lightgbm,se_autogluon,se_automl,se_autosklearn,se_nsrb_catboost,se_nsrb_lightgbm,se_nsrb_xgboost,se_ranger
+Australian,690,15,15985,0,0.927139291570531,0.9403798,0.9401302,0.939441905010666,0.95631890934947,0.956301713755025,0.956988449204069,0.00625951688551121,0.00974758131746723,0.00645153811255718,,,,0.0062217771635561
+blood-transfusion-service-center,748,5,16795,0,0.675452929709896,0.7545867,0.7552632,0.72151129457631,0.784070060772847,0.792113863089095,0.783840729274166,0.0137883955935981,0.0266506873423692,0.0125092638079856,,,,0.0182506668299243
+credit-g,1000,21,14950,0,0.790666666666667,0.7914286,0.7954762,0.797714285714286,0.821380952380953,0.831190476190476,0.819,0.0124515724042298,0.01463130485926,0.0120870631253235,,,,0.0110423724085986
+kc1,2109,22,10778,0,0.828256027889212,0.8394143,0.8388319,0.83301100341815,0.845781043736435,0.844540375825292,0.850140003623581,0.0103067323424492,0.0106013179665369,0.0115374944963993,,,,0.0101397054967363
+jasmine,2984,145,5143,0,0.88434325781121,0.8866436,0.8868102,0.885832349894149,0.894148041379517,0.89654228788493,0.894693962734411,0.00559981955701353,0.00486630530584565,0.00544310501215375,,,,0.00499365391013639
+kr-vs-kp,3196,37,12280,0,0.999633361467846,0.9999059,0.9998938,0.999057242546749,0.999952829119882,0.999892155429527,0.999954850361197,4.90671761748926e-05,0.000139230849479588,5.75538009170546e-05,,,,0.000319843783503223
+sylvine,5124,21,4156,0,0.989789833632888,0.9910253,0.9895615,0.982162980254058,0.991085975736032,0.990626157768042,0.990948771317181,0.000912019615407965,0.00128259306146027,0.000729145683507612,,,,0.0013369986296593
+phoneme,5404,6,8785,0,0.966881616584246,0.9723063,0.9703522,0.960735811750647,0.970141543405407,0.965436968346144,0.968326216905136,0.00243765833924463,0.00265533747788129,0.0027189979682065,,,,0.00311925761866985
+christine,5418,1637,531,0,0.823549087274391,0.828612,0.8178186,0.804034843792472,,0.827972819703597,0.830315793932847,0.00410316305901787,0.00412538075662619,0.0040968347516372,,,,0.00497127144893378
+guillermo,20000,4297,762,0.099737532808399,0.918016039652204,0.9297315,0.9068246,0.899920657376192,,0.920684627886105,0.920749947743724,0.00201515118401683,0.00263784046155174,0.00269879505705787,,,,0.00310538640679411
+riccardo,20000,4297,983,0.113936927772126,0.9997828,0.9998226,0.9997564,0.998583066666667,,0.9998708,0.9998952,4.29211680487202e-05,5.30773300285917e-05,6.70161174643844e-05,,,,0.000167324456013549
+Amazon_employee_access,32769,10,10502,0.0499904780041897,0.896561143661484,0.8948862,0.8782126,0.812647362179038,0.907755346022958,0.86988339978743,0.871084976336888,0.00385602698757268,0.00721915663673388,0.00323081820459015,,,,0.00396292903499147
+nomao,34465,119,8675,0,0.996953656781572,0.9967804,0.9965543,0.994531714464582,0.99658345023132,0.996874121610083,0.996894092945198,0.000177412400919442,0.000214864175613767,0.00021046572959352,,,,0.000235809076857645
+bank-marketing,45211,17,9320,0.0410944206008584,0.937457045512735,0.9415669,0.9391333,0.933101189069023,0.939300290764459,0.938817397246257,0.939851080171622,0.00183242660880775,0.00195781654217951,0.00209943206103196,,,,0.0022766907661668
+adult,48842,15,13004,0.03237465395263,0.931297677712053,0.9317201,0.9308382,0.91701232984862,0.931232944657938,0.931651617146471,0.931908271153075,0.00111491287601817,0.00132130085901969,0.00130615325627925,,,,0.00129102187097203
+KDDCup09_appetency,50000,231,7279,0,0.835180253455325,0.8462116,0.8422286,0.674596354434783,0.85383283113579,0.757159689667086,0.836303848960943,0.00410940322025798,0.00559423030898553,0.00504041454688445,,,,0.00707760782427283
+APSFailure,76000,171,10266,0,0.991328966276702,0.9922027,0.9918997,0.990571966373177,0.994609943610863,0.994614328701291,0.994300775109902,0.000738122107483282,0.000960802341479511,0.000819317494422439,,,,0.000699258290806024
+numerai28.6,96320,22,13187,0,0.529656183979438,0.5239074,0.5306631,0.51944323409202,0.533024426064101,0.532654822877499,0.532978909285434,0.00163424674595566,0.00151656230607815,0.00116353999549268,,,,0.00178015116664133
+higgs,98050,29,8940,0,0.813944999568845,0.8431651,0.8419924,0.801837780350896,0.817769644381571,0.816669351412767,0.816786016241555,0.00025055846290504,0.00207614141742915,0.000812688558633218,,,,0.00207426139122784
+MiniBooNE,130064,51,3096,0.0155038759689922,0.987133776160861,0.9891384,0.9879684,0.981362207753962,,,,0.000287950234897939,0.000323888888810125,0.000303766364899813,,,,0.000368143047104423
+airlines,539383,8,6490,0,0.730441935427352,0.7301134,0.7273721,0.694990665592163,,,,0.000603091778532364,0.000648809705217837,0.000583604231383636,,,,0.000466112563536697
diff --git a/vignettes/mlr3automl.qmd b/vignettes/mlr3automl.qmd
new file mode 100644
index 0000000..22b579d
--- /dev/null
+++ b/vignettes/mlr3automl.qmd
@@ -0,0 +1,217 @@
+---
+title: "mlr3automl"
+format: html
+---
+
+# Introduction
+
+*mlr3automl* is the Automated Machine Learning (AutoML) package of the [mlr3](https://mlr-org.com/) ecosystem.
+
+
+# Benchmark
+
+## Binary Classification
+
+```{r}
+#| echo: false
+library(gt)
+library(data.table)
+
+tab = fread("/home/marc/repositories/mlr3automl/vignettes/binary.csv")
+
+gt(tab) %>%
+ fmt_number(
+ columns = c("auc_automl", "auc_autogluon", "auc_autosklearn", "auc_ranger", "auc_nsrb_catboost", "auc_nsrb_xgboost", "auc_nsrb_lightgbm"),
+ decimals = 2
+ ) %>%
+ fmt_percent(
+ columns = "error_rate",
+ decimals = 0
+ ) %>%
+ fmt_integer(
+ columns = c("n", "instances", "features")
+ ) %>%
+ tab_style(
+ style = list(
+ cell_fill(color = "#90ee90")
+ ),
+ locations = cells_body(
+ columns = auc_autogluon,
+ rows = auc_autogluon > auc_automl & auc_autogluon > auc_autosklearn & auc_autogluon > auc_ranger
+ )
+ ) %>%
+ tab_style(
+ style = list(
+ cell_fill(color = "#90ee90")
+ ),
+ locations = cells_body(
+ columns = auc_autosklearn,
+ rows = auc_autosklearn > auc_automl & auc_autosklearn > auc_autogluon & auc_autosklearn > auc_ranger
+ )
+ ) %>%
+ tab_style(
+ style = list(
+ cell_fill(color = "#90ee90")
+ ),
+ locations = cells_body(
+ columns = auc_automl,
+ rows = auc_automl > auc_autogluon & auc_automl > auc_autosklearn & auc_automl > auc_ranger
+ )
+ ) %>%
+ tab_style(
+ style = list(
+ cell_fill(color = "#90ee90")
+ ),
+ locations = cells_body(
+ columns = auc_ranger,
+ rows = auc_ranger > auc_automl & auc_ranger > auc_autogluon & auc_ranger > auc_autosklearn
+ )
+ ) %>%
+ tab_style(
+ style = list(
+ cell_fill(color = "#B069DB")
+ ),
+ locations = cells_body(
+ columns = auc_nsrb_catboost,
+ rows = auc_nsrb_catboost > auc_automl & auc_nsrb_catboost > auc_autogluon & auc_nsrb_catboost > auc_autosklearn & auc_nsrb_catboost > auc_ranger
+ )
+ ) %>%
+ tab_style(
+ style = list(
+ cell_fill(color = "#B069DB")
+ ),
+ locations = cells_body(
+ columns = auc_nsrb_xgboost,
+ rows = auc_nsrb_xgboost > auc_automl & auc_nsrb_xgboost > auc_autogluon & auc_nsrb_xgboost > auc_autosklearn & auc_nsrb_xgboost > auc_ranger
+ )
+ ) %>%
+ tab_style(
+ style = list(
+ cell_fill(color = "#B069DB")
+ ),
+ locations = cells_body(
+ columns = auc_nsrb_lightgbm,
+ rows = auc_nsrb_lightgbm > auc_automl & auc_nsrb_lightgbm > auc_autogluon & auc_nsrb_lightgbm > auc_autosklearn & auc_nsrb_lightgbm > auc_ranger
+ )
+ ) %>%
+ cols_label(
+ task_id = "Task",
+ auc_automl = "mlr3automl",
+ auc_autogluon = "AutoGluon",
+ auc_autosklearn = "auto-sklearn",
+ auc_ranger = "ranger",
+ auc_nsrb_catboost = "nsrb catboost",
+ auc_nsrb_xgboost = "nsrb xgboost",
+ auc_nsrb_lightgbm = "nsrb lightgbm",
+ n = "N Evaluations",
+ error_rate = "Error Rate",
+ instances = "Instances",
+ features = "Features",
+ ) %>%
+ cols_hide(
+ columns = c("se_automl", "se_autogluon", "se_autosklearn", "se_ranger", "se_nsrb_catboost", "se_nsrb_xgboost", "se_nsrb_lightgbm")
+ ) %>%
+ tab_spanner(
+ label = md('**nsrb**'),
+ columns = c("auc_nsrb_catboost", "auc_nsrb_xgboost", "auc_nsrb_lightgbm")
+ ) %>%
+ tab_spanner(
+ label = md('**automl**'),
+ columns = c("auc_automl", "auc_autogluon", "auc_autosklearn")
+ ) %>%
+ tab_spanner(
+ label = md('**baseline**'),
+ columns = c("auc_ranger")
+ )
+```
+
+
+## Multiclass Classification
+
+```{r}
+# | echo: false
+library(gt)
+library(data.table)
+
+tab = fread("/home/marc/repositories/mlr3automl/vignettes/multiclass.csv")
+
+gt(tab) %>%
+ fmt_number(
+ columns = c("logloss_automl", "logloss_autogluon", "logloss_autosklearn", "logloss_ranger", "logloss_featureless", "logloss_nsrb_catboost", "logloss_nsrb_xgboost", "logloss_nsrb_lightgbm"),
+ decimals = 2
+ ) %>%
+ fmt_percent(
+ columns = "error_rate",
+ decimals = 0
+ ) %>%
+ fmt_integer(
+ columns = c("n", "instances", "features", "classes")
+ ) %>%
+ tab_style(
+ style = list(
+ cell_fill(color = "#90ee90")
+ ),
+ locations = cells_body(
+ columns = logloss_autogluon,
+ rows = logloss_autogluon < logloss_automl & logloss_autogluon < logloss_autosklearn & logloss_autogluon < logloss_ranger
+ )
+ ) %>%
+ tab_style(
+ style = list(
+ cell_fill(color = "#90ee90")
+ ),
+ locations = cells_body(
+ columns = logloss_autosklearn,
+ rows = logloss_autosklearn < logloss_automl & logloss_autosklearn < logloss_autogluon & logloss_autosklearn < logloss_ranger
+ )
+ ) %>%
+ tab_style(
+ style = list(
+ cell_fill(color = "#90ee90")
+ ),
+ locations = cells_body(
+ columns = logloss_automl,
+ rows = logloss_automl < logloss_autogluon & logloss_automl < logloss_autosklearn & logloss_automl < logloss_ranger
+ )
+ ) %>%
+ tab_style(
+ style = list(
+ cell_fill(color = "#90ee90")
+ ),
+ locations = cells_body(
+ columns = logloss_ranger,
+ rows = logloss_ranger < logloss_automl & logloss_ranger < logloss_autogluon & logloss_ranger < logloss_autosklearn
+ )
+ ) %>%
+ cols_label(
+ task_id = "Task",
+ logloss_automl = "mlr3automl",
+ logloss_autogluon = "AutoGluon",
+ logloss_autosklearn = "auto-sklearn",
+ logloss_ranger = "ranger",
+ logloss_nsrb_catboost = "nsrb catboost",
+ logloss_nsrb_xgboost = "nsrb xgboost",
+ logloss_nsrb_lightgbm = "nsrb lightgbm",
+ logloss_featureless = "featureless",
+ n = "N Evaluations",
+ error_rate = "Error Rate",
+ instances = "Instances",
+ features = "Features",
+ classes = "Classes",
+ ) %>%
+ cols_hide(
+ columns = c("se_automl", "se_autogluon", "se_autosklearn", "se_ranger", "se_featureless", "se_nsrb_catboost", "se_nsrb_xgboost", "se_nsrb_lightgbm")
+ ) %>%
+ tab_spanner(
+ label = md('**nsrb**'),
+ columns = c("logloss_nsrb_catboost", "logloss_nsrb_xgboost", "logloss_nsrb_lightgbm")
+ ) %>%
+ tab_spanner(
+ label = md('**automl**'),
+ columns = c("logloss_automl", "logloss_autogluon", "logloss_autosklearn")
+ ) %>%
+ tab_spanner(
+ label = md('**baseline**'),
+ columns = c("logloss_ranger", "logloss_featureless")
+ )
+```
diff --git a/vignettes/multiclass.csv b/vignettes/multiclass.csv
new file mode 100644
index 0000000..d7290bf
--- /dev/null
+++ b/vignettes/multiclass.csv
@@ -0,0 +1,18 @@
+task_id,instances,features,classes,n,error_rate,logloss_automl,logloss_autogluon,logloss_autosklearn,logloss_ranger,logloss_featureless,logloss_nsrb_catboost,logloss_nsrb_xgboost,logloss_nsrb_lightgbm,se_autogluon,se_automl,se_autosklearn,se_featureless,se_nsrb_catboost,se_nsrb_lightgbm,se_nsrb_xgboost,se_ranger
+vehicle,846,19,4,4489,0.00467810202717755,0.246133163650464,0.2975602,0.3287371,0.510561282581669,1.38568463508668,0.417346228374785,0.403397127819529,0.400828178937227,,0.00698817843953814,0.00954073621600206,8.75196107532944e-05,,,,0.0100907914175518
+cnae-9,1080,857,9,1925,0.0592207792207792,0.1978110074977,0.13670169,0.1433449,0.383605591195359,2.19722457733622,0.183215789910806,0.193544966806869,0.136017340635148,,0.0224507384704561,0.0136468275592616,0,,,,0.0103694601320938
+car,1728,7,4,5673,0.00176273576590869,0.00336325567640855,0.00445799303,0.00165220439,0.106281337377552,0.835786293921271,0.000142381808763811,0.00867447124562228,0.00903699561744691,,0.00111708308672078,0.00112451410025991,0.00193086739498989,,,,0.00318904353062951
+mfeat-factors,2000,217,10,4634,0.0336642209753992,0.104556663597505,0.06739112,0.07379011,0.254825952589775,2.30258509299405,0.0764777469290919,0.0868363708318213,0.0750898112199451,,0.0156109055218984,0.00955443793852946,0,,,,0.0068126625016406
+segment,2310,20,7,4471,0.0123014985461865,0.164644724021799,0.05380231,0.06224911,0.185014501231642,1.94591014905531,0.144107199170834,0.156203856960625,0.148784204207072,,0.0124020486003542,0.00818762151706872,0,,,,0.00894840175189936
+fabert,8237,801,7,3137,0,0.779805238351973,0.6823213,0.732713,0.81957240919859,1.87470714450971,0.735900488762994,0.758233912235071,0.736252296217865,,0.0141495553725495,0.00829106589582358,0.000191327211443389,,,,0.00730733440520438
+robert,10000,7201,10,510,0.17843137254902,1.29986025246257,1.265303,1.38254,1.703142824254,2.30224253913693,,1.25048427998283,1.23802438049629,,0.0095081263259262,0.00989761284350929,1.01225183054029e-05,,,,0.00456315597417938
+dilbert,10000,2001,5,1019,0.0716388616290481,0.0212365465041528,0.011531044,0.02948814,0.345697306635468,1.60913038836964,,0.0262178686924159,0.0173229095620056,,0.00379598606715973,0.0024238032368894,6.88809046860954e-06,,,,0.0029826048870314
+jungle_chess_2pcs_raw_endgame_complete,44819,7,3,10897,0.0100945214279159,0.220940363696335,0.006359055,0.202931,0.365947425082544,0.935131176802189,0.225096653055923,0.216816828439813,0.21023144507842,,0.00228145184669591,0.00663809773956365,5.64241219502468e-05,,,,0.00200385029090546
+shuttle,58000,10,7,10191,0.120007850063782,0.000433851922532271,0.0004103703,0.00019398261,0.000999513993301768,0.665633632977827,0.000172939847837588,0.000338603285323662,0.000304703105793594,,0.000105959044951895,6.52572911754825e-05,0.000197431059510338,,,,7.7661666167531e-05
+volkert,58310,181,10,1087,0.0965961361545538,0.798647525057554,0.6736234,0.8331933,0.993226353940388,2.05260271776454,,0.772378708579234,0.722045503644117,,0.00455993559043237,0.0119893109708885,6.62318081584586e-05,,,,0.00275743097891916
+helena,65196,28,100,478,0.50418410041841,2.62423067028888,2.469964,2.48524,2.86122418928546,4.14341318196952,,,,,0.0133154206797103,0.00973240360856452,0.000229507436082131,,,,0.0107391517203801
+connect-4,67557,43,3,3973,0.0072992700729927,0.319769844402902,0.29332,0.3424652,0.455782398418385,0.84456127352411,0.307335718368707,0.304755426273849,0.307800546798303,,0.00238211608853492,0.00396902146014746,2.55615445901264e-05,,,,0.00162967435011445
+Fashion-MNIST,70000,785,10,588,0.185374149659864,0.252519395641236,0.2166764,0.2470057,0.368519094425655,2.30258509299405,,0.235421502589961,,,0.00489295048550023,0.00388432476056722,0,,,,0.00187728299184983
+jannis,83733,55,4,1838,0.0201305767138194,0.673021499200483,0.6470854,0.6723501,0.732399719936977,1.10866842074873,0.666277640906523,0.667610014709809,0.667920427068077,,0.00202470355108686,0.0014239732086275,5.79329889946298e-05,,,,0.00159798786642093
+dionis,416188,61,355,337,0.670623145400593,2.44416591642388,0.2478194,0.5234148,5.85739174873652,5.85739174873652,,,,,0.76552476216906,0.0456489105143936,1.34855287420444e-05,,,,1.34855287420444e-05
+covertype,581012,55,7,536,0.0111940298507463,0.0883230133679485,0.05741523,0.0946388,0.60940715506443,1.20516053044485,,,,,0.00133009751186704,0.00323271539517821,3.58581462169515e-06,,,,0.000486609053327916