From 44a08b78c1000c1f186aaa72d861fa2ac0fa8000 Mon Sep 17 00:00:00 2001
From: Emil Hvitfeldt
Date: Wed, 29 Jan 2025 11:54:26 -0800
Subject: [PATCH 1/3] toggle mlp activation by engine

---
 R/tunable.R | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/R/tunable.R b/R/tunable.R
index 9f44b60f1..8d6fde55b 100644
--- a/R/tunable.R
+++ b/R/tunable.R
@@ -355,9 +355,18 @@ tunable.mlp <- function(x, ...) {
     list(list(pkg = "dials", fun = "learn_rate", range = c(-3, -1/2)))
   res$call_info[res$name == "epochs"] <-
     list(list(pkg = "dials", fun = "epochs", range = c(5L, 500L)))
+    activation_values <- rlang::eval_tidy(
+      rlang::call2("brulee_activations", .ns = "brulee")
+    )
+    res$call_info[res$name == "activation"] <-
+      list(list(pkg = "dials", fun = "activation", values = activation_values))
+  } else if (x$engine == "keras") {
+    activation_values <- parsnip::keras_activations()
+    res$call_info[res$name == "activation"] <-
+      list(list(pkg = "dials", fun = "activation", values = activation_values))
   }
   res
-}
+ }
 
 #' @export
 tunable.survival_reg <- function(x, ...) {

From 7bf3754774ba28d1b4941692d70b2d1d39d3abe9 Mon Sep 17 00:00:00 2001
From: Emil Hvitfeldt
Date: Wed, 29 Jan 2025 14:22:47 -0800
Subject: [PATCH 2/3] normalize hard sigmoid activation name with brulee

---
 R/mlp.R                            | 10 +++++++++-
 tests/testthat/_snaps/mlp_keras.md | 10 ++++++++++
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/R/mlp.R b/R/mlp.R
index d04bca2ef..ec6a8ed5c 100644
--- a/R/mlp.R
+++ b/R/mlp.R
@@ -200,6 +200,7 @@ keras_mlp <-
       {.val {activation}}."
     )
   }
+  activation <- get_activation_fn(activation)
 
   if (penalty > 0 & dropout > 0) {
     cli::cli_abort("Please use either dropout or weight decay.", call = NULL)
@@ -351,7 +352,7 @@ mlp_num_weights <- function(p, hidden_units, classes) {
 }
 
 allowed_keras_activation <-
-  c("elu", "exponential", "gelu", "hard_sigmoid", "linear", "relu", "selu",
+  c("elu", "exponential", "gelu", "hardsigmoid", "linear", "relu", "selu",
     "sigmoid", "softmax", "softplus", "softsign", "swish", "tanh")
 
 #' Activation functions for neural networks in keras
@@ -363,6 +364,13 @@ keras_activations <- function() {
   allowed_keras_activation
 }
 
+get_activation_fn <- function(arg, ...) {
+  if (arg == "hardsigmoid") {
+    arg <- "hard_sigmoid"
+  }
+  arg
+}
+
 ## -----------------------------------------------------------------------------
 
 #' @importFrom purrr map
diff --git a/tests/testthat/_snaps/mlp_keras.md b/tests/testthat/_snaps/mlp_keras.md
index f38dbea23..bf9652e57 100644
--- a/tests/testthat/_snaps/mlp_keras.md
+++ b/tests/testthat/_snaps/mlp_keras.md
@@ -6,3 +6,13 @@
       Error:
       ! object 'novar' not found
 
+# all keras activation functions
+
+    Code
+      mlp(mode = "classification", hidden_units = 2, penalty = 0.01, epochs = 2,
+        activation = "invalid") %>% set_engine("keras", verbose = 0) %>% parsnip::fit(
+        Class ~ A + B, data = modeldata::two_class_dat)
+    Condition
+      Error in `parsnip::keras_mlp()`:
+      ! `activation` should be one of: elu, exponential, gelu, hardsigmoid, linear, relu, selu, sigmoid, softmax, softplus, softsign, swish, and tanh, not "invalid".
+
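A minimal sketch of how the engine dispatch in PATCH 1/3 surfaces to users, assuming a parsnip build that includes these patches; the comments describe the expected `tunable()` output rather than captured console results:

```r
library(parsnip)  # assumes a build that includes PATCH 1/3

# With the keras engine, the `activation` row of tunable() should now
# carry values from parsnip::keras_activations().
spec_keras <- mlp(activation = tune()) |>
  set_engine("keras")
tunable(spec_keras)

# With the brulee engine, the same row should instead carry
# brulee::brulee_activations() values (brulee must be installed,
# since the values are evaluated from that namespace).
spec_brulee <- mlp(activation = tune()) |>
  set_engine("brulee")
tunable(spec_brulee)
```
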
From e6ea127680a7cb229c1b19726459c93868e2e9d3 Mon Sep 17 00:00:00 2001
From: Emil Hvitfeldt
Date: Wed, 29 Jan 2025 14:39:34 -0800
Subject: [PATCH 3/3] add skip_if_not_installed

---
 tests/testthat/test-tunable.R | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/testthat/test-tunable.R b/tests/testthat/test-tunable.R
index ced23c200..b7e24d108 100644
--- a/tests/testthat/test-tunable.R
+++ b/tests/testthat/test-tunable.R
@@ -1,4 +1,5 @@
 test_that('brulee has mixture object', {
+  skip_if_not_installed("brulee")
   # for issue 1236
   mlp_spec <- mlp(
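For context, testthat's `skip_if_not_installed()` marks a test as skipped rather than failed when the named package is unavailable, which is why PATCH 3/3 guards the brulee-dependent test above. A hypothetical test in the same style (the test name and expectation are illustrative, not part of the patch):

```r
library(testthat)
library(parsnip)

test_that("activation is tunable with the brulee engine", {
  # Skip cleanly on machines without the suggested engine package,
  # mirroring the guard added in PATCH 3/3.
  skip_if_not_installed("brulee")

  spec <- mlp(activation = tune()) |> set_engine("brulee")
  expect_true("activation" %in% tunable(spec)$name)
})
```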