Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .lintr
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ linters: linters_with_defaults(
# the following setup changes/removes certain linters
assignment_linter = NULL, # do not force using <- for assignments
object_name_linter = object_name_linter(c("snake_case", "CamelCase")), # only allow snake case and camel case object names
cyclocomp_linter = NULL, # do not check function complexity
commented_code_linter = NULL, # allow code in comments
line_length_linter = line_length_linter(180L)
line_length_linter = line_length_linter(180L),
indentation_linter = indentation_linter(indent = 2, hanging_indent_style = "never")
)

4 changes: 4 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ Suggests:
htmlwidgets,
ranger,
themis
Remotes:
mlr-org/mlr3misc@common_baseclass,
mlr-org/mlr3learners@common_baseclass,
mlr-org/mlr3@common_baseclass
ByteCompile: true
Encoding: UTF-8
Config/testthat/edition: 3
Expand Down
31 changes: 8 additions & 23 deletions R/GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#' Identifier of the resulting [`Learner`][mlr3::Learner].
#' * `param_vals` :: named `list`\cr
#' List of hyperparameter settings, overwriting the hyperparameter settings of the [`Graph`]. Default `list()`.
#' Deprecated, will be removed in the future.
#' * `task_type` :: `character(1)`\cr
#' What `task_type` the `GraphLearner` should have; usually automatically inferred for [`Graph`]s that are simple enough.
#' * `predict_type` :: `character(1)`\cr
Expand Down Expand Up @@ -180,11 +181,14 @@
GraphLearner = R6Class("GraphLearner", inherit = Learner,
public = list(
impute_selected_features = FALSE,
initialize = function(graph, id = NULL, param_vals = list(), task_type = NULL, predict_type = NULL, clone_graph = TRUE) {
initialize = function(graph, id = NULL, task_type = NULL, predict_type = NULL, clone_graph = TRUE) {
graph = as_graph(graph, clone = assert_flag(clone_graph))
graph$state = NULL

id = assert_string(id, null.ok = TRUE) %??% paste(graph$ids(sorted = TRUE), collapse = ".")
if (".has_id" %in% names(private)) {
private$.has_id = TRUE
}
self$id = id # init early so 'base_learner()' can use it in error messages
private$.graph = graph

Expand Down Expand Up @@ -217,17 +221,15 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
)

super$initialize(id = id, task_type = task_type,
param_set = alist(private$.graph$param_set),
feature_types = mlr_reflections$task_feature_types,
predict_types = names(mlr_reflections$learner_predict_types[[task_type]]),
packages = graph$packages,
properties = properties,
man = "mlr3pipelines::GraphLearner"
properties = properties
)

if (length(param_vals)) {
private$.graph$param_set$values = insert_named(private$.graph$param_set$values, param_vals)
}
if (!is.null(predict_type)) self$predict_type = predict_type

},
base_learner = function(recursive = Inf, return_po = FALSE, return_all = FALSE, resolve_branching = TRUE) {
assert(check_numeric(recursive, lower = Inf), check_int(recursive))
Expand Down Expand Up @@ -393,12 +395,6 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
}
private$.graph$edges
},
param_set = function(rhs) {
if (!missing(rhs) && !identical(rhs, self$graph$param_set)) {
stop("param_set is read-only.")
}
self$graph$param_set
},
pipeops_param_set = function(rhs) {
value = map(self$graph$pipeops, "param_set")
if (!missing(rhs) && !identical(value, rhs)) {
Expand Down Expand Up @@ -434,16 +430,6 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
if (!length(ivs)) return(named_list())
ivs
},
deep_clone = function(name, value) {
# FIXME this repairs the mlr3::Learner deep_clone() method which is broken.
if (is.environment(value) && !is.null(value[[".__enclos_env__"]])) {
return(value$clone(deep = TRUE))
}
if (name == "state") {
value$log = copy(value$log)
}
value
},

.train = function(task) {
if (!is.null(get0("validate", self))) {
Expand Down Expand Up @@ -604,7 +590,6 @@ as_learner.PipeOp = function(x, clone = FALSE, ...) {
as_learner(as_graph(x, clone = FALSE, ...), clone = clone)
}


infer_task_type = function(graph) {
output = graph$output
# check the high level input and output
Expand Down
6 changes: 2 additions & 4 deletions R/LearnerAvg.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ LearnerClassifAvg = R6Class("LearnerClassifAvg", inherit = LearnerClassif,
param_set = ps,
predict_types = c("response", "prob"),
feature_types = c("integer", "numeric", "factor"),
properties = c("twoclass", "multiclass"),
man = "mlr3pipelines::LearnerClassifAvg"
properties = c("twoclass", "multiclass")
)
},
prepare_data = function(task) {
Expand Down Expand Up @@ -145,8 +144,7 @@ LearnerRegrAvg = R6Class("LearnerRegrAvg", inherit = LearnerRegr,
id = id,
param_set = ps,
predict_types = "response",
feature_types = c("integer", "numeric"),
man = "mlr3pipelines::LearnerRegrAvg"
feature_types = c("integer", "numeric")
)
},
prepare_data = function(task) {
Expand Down
146 changes: 40 additions & 106 deletions R/PipeOp.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#'
#' * `id` :: `character(1)`\cr
#' Identifier of resulting object. See `$id` slot.
#' Deprecated, will be removed in the future.
#' * `param_set` :: [`ParamSet`][paradox::ParamSet] | `list` of `expression`\cr
#' Parameter space description. This should be created by the subclass and given to `super$initialize()`.
#' If this is a [`ParamSet`][paradox::ParamSet], it is used as the `PipeOp`'s [`ParamSet`][paradox::ParamSet]
Expand All @@ -50,6 +51,7 @@
#' * `param_vals` :: named `list`\cr
#' List of hyperparameter settings, overwriting the hyperparameter settings given in `param_set`. The
#' subclass should have its own `param_vals` parameter and pass it on to `super$initialize()`. Default `list()`.
#' Deprecated, will be removed in the future. Use the [po()] syntax to set hyperparameters on construction.
#' * `input` :: [`data.table`][data.table::data.table] with columns `name` (`character`), `train` (`character`), `predict` (`character`)\cr
#' Sets the `$input` slot of the resulting object; see description there.
#' * `output` :: [`data.table`][data.table::data.table] with columns `name` (`character`), `train` (`character`), `predict` (`character`)\cr
Expand Down Expand Up @@ -249,30 +251,52 @@
#' @template seealso_pipeopslist
#' @export
PipeOp = R6Class("PipeOp",
inherit = Mlr3Component,
public = list(
packages = NULL,
state = NULL,
input = NULL,
output = NULL,
.result = NULL,
tags = NULL,
properties = NULL,

initialize = function(id, param_set = ps(), param_vals = list(), input, output, packages = character(0), tags = "abstract", properties = character(0)) {
if (inherits(param_set, "ParamSet")) {
private$.param_set = assert_param_set(param_set)
private$.param_set_source = NULL
} else {
lapply(param_set, function(x) assert_param_set(eval(x)))
private$.param_set_source = param_set
initialize = function(id, param_set = ps(), param_vals = list(), input, output, packages = character(0), tags = "abstract", properties = character(0), dict_entry = id) {


## ------ deprecating id and param_vals
sc = sys.calls()
found = 0
# exclude the last (i.e., current) and second-to-last frames; these are PipeOp$new() directly.
for (i in rev(seq_along(sc))[-c(1, 2)]) {
if (identical(sc[[i]], quote(initialize(...)))) {
found = i
break
}
}
self$id = assert_string(id, min.chars = 1)
if (found > 0) {
sf = sys.frames()[[found - 1]]
if (identical(class(self), sf$classes)) {
newcall = match.call(sys.function(found), sc[[found]], envir = sf)
passes_param_vals = !is.null(newcall$param_vals)
dots = (function(...) evalq(list(...), envir = sf))() # function(...) is here to pacify R CMD check static checks
unnamed_dots = dots[is.na(names2(dots))]
passes_id = length(unnamed_dots) && !is.null(newcall$id) && identical(newcall$id, unnamed_dots[[1]])
if (passes_param_vals || passes_id) {
mlr3component_deprecation_msg("passing param_vals, and unnamed id, for PipeOp construction directly is deprecated and will be removed in the future.
Use the po()-syntax to set these, instead:
po(\"pipeop\", \"newid\", param_vals = list(a = 1)) --> po(\"pipeop\", id = \"newid\", a = 1)")
}
}
}
## ------

super$initialize(id = id, dict_entry = dict_entry, dict_shortaccess = "po",
param_set = param_set, packages = packages, properties = properties
)

self$properties = assert_subset(properties, mlr_reflections$pipeops$properties)
assert_subset(properties, mlr_reflections$pipeops$properties)
self$param_set$values = insert_named(self$param_set$values, param_vals)
self$input = assert_connection_table(input)
self$output = assert_connection_table(output)
self$packages = union("mlr3pipelines", assert_character(packages, any.missing = FALSE, min.chars = 1L))
self$tags = assert_subset(tags, mlr_reflections$pipeops$valid_tags)
},

Expand Down Expand Up @@ -377,111 +401,21 @@ PipeOp = R6Class("PipeOp",
),

active = list(
id = function(val) {
if (!missing(val)) {
private$.id = val
}
private$.id
},
param_set = function(val) {
if (is.null(private$.param_set)) {
sourcelist = lapply(private$.param_set_source, function(x) eval(x))
if (length(sourcelist) > 1) {
private$.param_set = ParamSetCollection$new(sourcelist)
} else {
private$.param_set = sourcelist[[1]]
}
}
if (!missing(val) && !identical(val, private$.param_set)) {
stop("param_set is read-only.")
}
private$.param_set
},
predict_type = function(val) {
if (!missing(val)) {
stop("$predict_type is read-only.")
}
return(NULL)
NULL
},
innum = function() nrow(self$input),
outnum = function() nrow(self$output),
is_trained = function() !is.null(self$state),
hash = function() {
digest(list(class(self), self$id, lapply(self$param_set$values, function(val) {
# ideally we would just want to hash `param_set$values`, but one of the values
# could be an R6 object with a `$hash` slot as well, in which case we take that
# slot's value. This is to avoid different hashes from essentially the same
# objects.
# In the following we also avoid accessing `val$hash` twice, because it could
# potentially be an expensive active binding.
if (is.environment(val) && !is.null({vhash = get0("hash", val, mode = "any", inherits = FALSE, ifnotfound = NULL)})) {
vhash
} else {
val
}
}), private$.additional_phash_input()), algo = "xxhash64")
},
phash = function() {
digest(list(class(self), self$id, private$.additional_phash_input()), algo = "xxhash64")
},
man = function(x) {
if (!missing(x)) stop("man is read-only")
paste0(topenv(self$.__enclos_env__)$.__NAMESPACE__.$spec[["name"]], "::", class(self)[[1]])
},
label = function(x) {
if (!missing(x)) stop("label is read-only")
if (is.null(private$.label)) {
helpinfo = self$help()
helpcontent = NULL
if (inherits(helpinfo, "help_files_with_topic") && length(helpinfo)) {
ghf = get(".getHelpFile", mode = "function", envir = getNamespace("utils"))
helpcontent = ghf(helpinfo)
} else if (inherits(helpinfo, "dev_topic")) {
helpcontent = tools::parse_Rd(helpinfo$path)
}
if (is.null(helpcontent)) {
private$.label = "LABEL COULD NOT BE RETRIEVED"
} else {
private$.label = Filter(function(x) identical(attr(x, "Rd_tag"), "\\title"), helpcontent)[[1]][[1]][1]
}
}
private$.label
}
is_trained = function() !is.null(self$state)
),

private = list(
.state_class = NULL,
deep_clone = function(name, value) {
if (!is.null(private$.param_set_source)) {
private$.param_set = NULL # required to keep clone identical to original, otherwise tests get really ugly
if (name == ".param_set_source") {
value = lapply(value, function(x) {
if (inherits(x, "R6")) x$clone(deep = TRUE) else x
})
}
}
if (is.environment(value) && !is.null(value[[".__enclos_env__"]])) {
return(value$clone(deep = TRUE))
}
value
},
.train = function(input) stop("abstract"),
.predict = function(input) stop("abstract"),
.additional_phash_input = function() {
if (is.null(self$initialize)) return(NULL)
initformals <- names(formals(args(self$initialize)))
if (!test_subset(initformals, c("id", "param_vals"))) {
warningf("PipeOp %s has construction arguments besides 'id' and 'param_vals' but does not overload the private '.additional_phash_input()' function.

The hash and phash of a PipeOp must differ when it represents a different operation; since %s has construction arguments that could change the operation that is performed by it, it is necessary for the $hash and $phash to reflect this. `.additional_phash_input()` should return all the information (e.g. hashes of encapsulated items) that should additionally be hashed; read the help of ?PipeOp for more information.

This warning will become an error in the future.", class(self)[[1]], class(self)[[1]])
}
},
.param_set = NULL,
.param_set_source = NULL,
.label = NULL,
.id = NULL
.predict = function(input) stop("abstract")
)
)

Expand Down Expand Up @@ -511,7 +445,7 @@ check_types = function(self, data, direction, operation) {
description = sprintf("%s of PipeOp %s's $%s()", direction, self$id, operation)
if (direction == "input" && "..." %in% typetable$name) {
assert_list(data, min.len = nrow(typetable) - 1, .var.name = description)
typetable = typetable[rep(1:.N, ifelse(get("name") == "...", length(data) - nrow(typetable) + 1, 1))]
typetable = typetable[rep(seq_len(.N), ifelse(get("name") == "...", length(data) - nrow(typetable) + 1, 1))]
} else {
assert_list(data, len = nrow(typetable), .var.name = description)
}
Expand Down
4 changes: 3 additions & 1 deletion R/PipeOpADAS.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
#'
#' * `id` :: `character(1)`\cr
#' Identifier of resulting object, default `"adas"`.
#' Deprecated, will be removed in the future. Use the [po()] syntax to set a custom ID on construction.
#' * `param_vals` :: named `list`\cr
#' List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction. Default `list()`.
#' Deprecated, will be removed in the future. Use the [po()] syntax to set hyperparameters on construction.
#'
#' @section Input and Output Channels:
#' Input and output channels are inherited from [`PipeOpTaskPreproc`]. Instead of a [`Task`][mlr3::Task], a
Expand Down Expand Up @@ -81,7 +83,7 @@ PipeOpADAS = R6Class("PipeOpADAS",
K = p_int(lower = 1, default = 5, tags = c("train", "adas"))
)
super$initialize(id, param_set = ps, param_vals = param_vals, can_subset_cols = FALSE,
packages = "smotefamily", task_type = "TaskClassif", tags = "imbalanced data")
packages = "smotefamily", task_type = "TaskClassif", tags = "imbalanced data", dict_entry = "adas")
}
),
private = list(
Expand Down
4 changes: 3 additions & 1 deletion R/PipeOpBLSmote.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
#'
#' * `id` :: `character(1)`\cr
#' Identifier of resulting object, default `"blsmote"`.
#' Deprecated, will be removed in the future. Use the [po()] syntax to set a custom ID on construction.
#' * `param_vals` :: named `list`\cr
#' List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction. Default `list()`.
#' Deprecated, will be removed in the future. Use the [po()] syntax to set hyperparameters on construction.
#'
#' @section Input and Output Channels:
#' Input and output channels are inherited from [`PipeOpTaskPreproc`]. Instead of a [`Task`][mlr3::Task], a
Expand Down Expand Up @@ -93,7 +95,7 @@ PipeOpBLSmote = R6Class("PipeOpBLSmote",
)
ps$values = list(quiet = TRUE)
super$initialize(id, param_set = ps, param_vals = param_vals, can_subset_cols = FALSE,
packages = "smotefamily", task_type = "TaskClassif", tags = "imbalanced data")
packages = "smotefamily", task_type = "TaskClassif", tags = "imbalanced data", dict_entry = "blsmote")
}
),
private = list(
Expand Down
4 changes: 3 additions & 1 deletion R/PipeOpBoxCox.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
#'
#' * `id` :: `character(1)`\cr
#' Identifier of resulting object, default `"boxcox"`.
#' Deprecated, will be removed in the future. Use the [po()] syntax to set a custom ID on construction.
#' * `param_vals` :: named `list`\cr
#' List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction. Default `list()`.
#' Deprecated, will be removed in the future. Use the [po()] syntax to set hyperparameters on construction.
#'
#' @section Input and Output Channels:
#' Input and output channels are inherited from [`PipeOpTaskPreproc`].
Expand Down Expand Up @@ -77,7 +79,7 @@ PipeOpBoxCox = R6Class("PipeOpBoxCox",
upper = p_dbl(tags = c("train", "boxcox"))
)
super$initialize(id, param_set = ps, param_vals = param_vals,
packages = "bestNormalize", feature_types = c("numeric", "integer"))
packages = "bestNormalize", feature_types = c("numeric", "integer"), dict_entry = "boxcox")
}
),
private = list(
Expand Down
Loading
Loading