Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@
.vscode
^kaggle$
^oml_cache$
^vignettes$
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,4 @@ rsconnect/
.Rprofile
kaggle/
README.html
inst/doc
13 changes: 8 additions & 5 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ Package: mlr3automl
Title: AutoML extention for 'mlr3'
Version: 0.0.1
Authors@R: c(
person("Damir", "Pulatov", , "[email protected]", role = c("cre", "aut")),
person("Marc", "Becker", , "[email protected]", role = "aut",
comment = c(ORCID = "0000-0002-8115-0400")),
person("Marc", "Becker", , "[email protected]", role = c("cre", "aut"),
comment = c(ORCID = "0000-0002-8115-0400")),
person("Damir", "Pulatov", , "[email protected]", role = "aut"),
person("Baisu", "Zhou", , "[email protected]", role = "aut")
)
Description: Flexible AutoML system for the 'mlr3' ecosystem.
License: LGPL-3
URL: https://github.com/mlr-org/mlr3automl
URL: https://github.com/mlr-org/mlr3automl https://mlr3automl.mlr-org.com
BugReports: https://github.com/mlr-org/mlr3automl/issues
Depends:
mlr3 (>= 1.2.0.9000),
Expand All @@ -35,29 +35,32 @@ Suggests:
fastai,
glmnet,
kknn,
knitr,
lgr,
lightgbm,
MASS,
mirai,
mlr3extralearners,
mlr3torch (>= 0.3.0.9000),
mlr3viz,
quarto,
ranger,
redux,
reticulate,
rmarkdown,
rpart,
testthat (>= 3.0.0),
xgboost
Remotes:
catboost/catboost/catboost/R-package,
eagerai/fastai,
mlr-org/bbotk,
mlr-org/mlr3,
mlr-org/mlr3extralearners,
mlr-org/mlr3learners,
mlr-org/mlr3mbo@so_config_6,
mlr-org/mlr3torch,
mlr-org/mlr3tuning,
mlr-org/bbotk,
mlr-org/rush
Config/testthat/edition: 3
Config/testthat/parallel: false
Expand Down
12 changes: 10 additions & 2 deletions R/Auto.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#' @description
#' This class is the base class for all autos.
#'
#' @include mlr_auto.R
#' @include mlr_auto.R Auto.R
#'
#' @template param_id
#' @template param_task
Expand All @@ -14,6 +14,8 @@
#' @template param_large_data_set
#' @template param_size
#' @template param_devices
#' @template param_pv
#' @template param_graph
#'
#' @export
Auto = R6Class("Auto",
Expand Down Expand Up @@ -80,7 +82,7 @@ Auto = R6Class("Auto",

#' @description
#' Create the graph for the auto.
graph = function(task, measure, n_threads, timeout, devices) {
graph = function(task, measure, n_threads, timeout, devices, pv) {
stop("Abstract")
},

Expand Down Expand Up @@ -154,6 +156,12 @@ Auto = R6Class("Auto",
#' Get the search space for the learner.
search_space = function(task) {
private$.search_space
},

#' @description
#' Modify the graph for the final model.
final_graph = function(graph, task, pv) {
graph
}
),

Expand Down
26 changes: 23 additions & 3 deletions R/AutoCatboost.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#' @title Catboost Auto
#'
#' @include mlr_auto.R
#' @include mlr_auto.R Auto.R Auto.R
#'
#' @description
#' Catboost auto.
Expand All @@ -10,7 +10,12 @@
#' @template param_measure
#' @template param_n_threads
#' @template param_timeout
#' @template param_memory_limit
#' @template param_large_data_set
#' @template param_size
#' @template param_devices
#' @template param_pv
#' @template param_graph
#'
#' @export
AutoCatboost = R6Class("AutoCatboost",
Expand All @@ -31,7 +36,7 @@ AutoCatboost = R6Class("AutoCatboost",

#' @description
#' Create the graph for the auto.
graph = function(task, measure, n_threads, timeout, devices) {
graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
Expand All @@ -52,10 +57,17 @@ AutoCatboost = R6Class("AutoCatboost",
task_type = task_type)
set_threads(learner, n_threads)

po("removeconstants", id = "catboost_removeconstants") %>>%

graph = po("removeconstants", id = "catboost_removeconstants") %>>%
po("colapply", id = "catboost_colapply", applicator = as.numeric, affect_columns = selector_type("integer")) %>>%
po("removeconstants", id = "catboost_post_removeconstants") %>>%
learner

if (task$nrow * task$ncol > pv$large_data_size) {
graph = po("subsample", frac = 0.25, stratify = inherits(task, "TaskClassif"), use_groups = FALSE, id = "catboost_subsample") %>>% graph
}

graph
},

#' @description
Expand Down Expand Up @@ -115,6 +127,14 @@ AutoCatboost = R6Class("AutoCatboost",
"merror" # default
)
}
},

#' @description
#' Modify the graph for the final model.
final_graph = function(graph, task, pv) {
if (task$nrow * task$ncol > pv$large_data_size) {
graph$param_set$set_values(catboost_subsample.frac = 1)
}
}
),

Expand Down
28 changes: 23 additions & 5 deletions R/AutoExtraTrees.R
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
#' @title Extra Trees Auto
#'
#' @include mlr_auto.R
#' @include mlr_auto.R Auto.R
#'
#' @description
#' Extra Trees auto.
#'
#' @template param_id
#' @template param_n_threads
#' @template param_timeout
#' @template param_task
#' @template param_measure
#' @template param_n_threads
#' @template param_timeout
#' @template param_memory_limit
#' @template param_large_data_set
#' @template param_size
#' @template param_devices
#' @template param_pv
#' @template param_graph
#'
#' @export
AutoExtraTrees = R6Class("AutoExtraTrees",
Expand All @@ -37,7 +41,7 @@ AutoExtraTrees = R6Class("AutoExtraTrees",
#' @param measure ([mlr3::Measure]).
#' @param n_threads (`numeric(1)`).
#' @param timeout (`numeric(1)`).
graph = function(task, measure, n_threads, timeout, devices) {
graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
Expand All @@ -53,13 +57,19 @@ AutoExtraTrees = R6Class("AutoExtraTrees",
sample.fraction = 1)
set_threads(learner, n_threads)

po("removeconstants", id = "extra_trees_removeconstants") %>>%
graph = po("removeconstants", id = "extra_trees_removeconstants") %>>%
po("imputeoor", id = "extra_trees_imputeoor") %>>%
po("fixfactors", id = "extra_trees_fixfactors") %>>%
po("imputesample", affect_columns = selector_type(c("factor", "ordered")), id = "extra_trees_imputesample") %>>%
po("collapsefactors", target_level_count = 40, id = "extra_trees_collapse") %>>%
po("removeconstants", id = "extra_trees_post_removeconstants") %>>%
learner

if (task$nrow * task$ncol > pv$large_data_size) {
graph = po("subsample", frac = 0.25, stratify = inherits(task, "TaskClassif"), use_groups = FALSE, id = "extra_trees_subsample") %>>% graph
}

graph
},

#' @description
Expand All @@ -70,6 +80,14 @@ AutoExtraTrees = R6Class("AutoExtraTrees",
num_trees = 100
tree_size_bytes = task$nrow / 60000 * 1e6
ceiling((tree_size_bytes * num_trees) / 1e6)
},

#' @description
#' Modify the graph for the final model.
final_graph = function(graph, task, pv) {
if (task$nrow * task$ncol > pv$large_data_size) {
graph$param_set$set_values(extra_trees_subsample.frac = 1)
}
}
)
)
Expand Down
5 changes: 3 additions & 2 deletions R/AutoFTTransformer.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#' @title FTTransformer Auto
#'
#' @include mlr_auto.R
#' @include mlr_auto.R Auto.R
#'
#' @description
#' FTTransformer auto.
Expand All @@ -11,6 +11,7 @@
#' @template param_n_threads
#' @template param_timeout
#' @template param_devices
#' @template param_pv
#'
#' @export
AutoFTTransformer = R6Class("AutoFTTransformer",
Expand All @@ -31,7 +32,7 @@ AutoFTTransformer = R6Class("AutoFTTransformer",

#' @description
#' Create the graph for the auto.
graph = function(task, measure, n_threads, timeout, devices) {
graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
Expand Down
5 changes: 3 additions & 2 deletions R/AutoFastai.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#' @title Fastai Auto
#'
#' @include mlr_auto.R
#' @include mlr_auto.R Auto.R
#'
#' @description
#' Fastai auto.
Expand All @@ -13,6 +13,7 @@
#' @template param_memory_limit
#' @template param_large_data_set
#' @template param_devices
#' @template param_pv
#'
#' @export
AutoFastai = R6Class("AutoFastai",
Expand Down Expand Up @@ -49,7 +50,7 @@ AutoFastai = R6Class("AutoFastai",

#' @description
#' Create the graph for the auto.
graph = function(task, measure, n_threads, timeout, devices) {
graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
Expand Down
5 changes: 3 additions & 2 deletions R/AutoGlmnet.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#' @title Glmnet Auto
#'
#' @include mlr_auto.R
#' @include mlr_auto.R Auto.R
#'
#' @description
#' Glmnet auto.
Expand All @@ -11,6 +11,7 @@
#' @template param_n_threads
#' @template param_timeout
#' @template param_devices
#' @template param_pv
#'
#' @export
AutoGlmnet = R6Class("AutoGlmnet",
Expand All @@ -30,7 +31,7 @@ AutoGlmnet = R6Class("AutoGlmnet",

#' @description
#' Create the graph for the auto.
graph = function(task, measure, n_threads, timeout, devices) {
graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
Expand Down
5 changes: 3 additions & 2 deletions R/AutoKknn.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#' @title Kknn Auto
#'
#' @include mlr_auto.R
#' @include mlr_auto.R Auto.R
#'
#' @description
#' Kknn auto.
Expand All @@ -11,6 +11,7 @@
#' @template param_n_threads
#' @template param_timeout
#' @template param_devices
#' @template param_pv
#'
#' @export
AutoKknn = R6Class("AutoKknn",
Expand All @@ -30,7 +31,7 @@ AutoKknn = R6Class("AutoKknn",

#' @description
#' Create the graph for the auto.
graph = function(task, measure, n_threads, timeout, devices) {
graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
Expand Down
5 changes: 3 additions & 2 deletions R/AutoLda.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#' @title Lda Auto
#'
#' @include mlr_auto.R
#' @include mlr_auto.R Auto.R
#'
#' @description
#' Lda auto.
Expand All @@ -14,6 +14,7 @@
#' @template param_n_threads
#' @template param_timeout
#' @template param_devices
#' @template param_pv
#'
#' @export
AutoLda = R6Class("AutoLda",
Expand All @@ -33,7 +34,7 @@ AutoLda = R6Class("AutoLda",

#' @description
#' Create the graph for the auto.
graph = function(task, measure, n_threads, timeout, devices) {
graph = function(task, measure, n_threads, timeout, devices, pv) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
Expand Down
Loading
Loading