Skip to content

Commit db67f2b

Browse files
committed
update parsnip add-in code and db
1 parent b34bdea commit db67f2b

File tree

2 files changed

+59
-20
lines changed

2 files changed

+59
-20
lines changed

data/model_db.rda

557 Bytes
Binary file not shown.

inst/add-in/parsnip_model_db.R

Lines changed: 59 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,66 @@ library(usethis)
99
# also requires installation of:
1010
packages <- c(
1111
"parsnip",
12-
"discrim",
13-
"plsmod",
14-
"rules",
15-
"baguette",
16-
"poissonreg",
17-
"multilevelmod",
18-
"modeltime",
19-
"modeltime.gluonts"
12+
parsnip:::extensions(),
13+
"modeltime"
14+
# "modeltime.gluonts" # required python packages to create spec
2015
)
2116

17+
loaded <- map(packages, library, character.only = TRUE)
18+
2219
# ------------------------------------------------------------------------------
2320

24-
# Detects model specifications via their print methods
25-
print_methods <- function(x) {
26-
require(x, character.only = TRUE)
27-
ns <- asNamespace(ns = x)
28-
mthds <- ls(envir = ns, pattern = "^print\\.")
29-
mthds <- gsub("^print\\.", "", mthds)
30-
purrr::map(mthds, get_engines) |>
21+
get_model <- function(x) {
22+
res <- get_from_env(x)
23+
if (!is.null(res)) {
24+
res <- dplyr::mutate(res, model = x)
25+
}
26+
res
27+
}
28+
29+
get_packages <- function(x) {
30+
res <- get_from_env(paste0(x, "_pkgs"))
31+
if (is.null(res)) {
32+
return(res)
33+
}
34+
res <-
35+
res |>
36+
tidyr::unnest(pkg) |>
37+
dplyr::mutate(
38+
model = x
39+
)
40+
41+
res
42+
}
43+
44+
get_models <- function() {
45+
res <- ls(envir = get_model_env(), pattern = "_fit$")
46+
models <- gsub("_fit$", "", res)
47+
models <-
48+
purrr::map(models, get_model) |>
49+
purrr::list_rbind()
50+
51+
# get source package
52+
pkgs <- gsub("_fit$", "_pkgs", res)
53+
pkgs <-
54+
unique(models$model) |>
55+
purrr::map(get_packages) |>
3156
purrr::list_rbind() |>
32-
dplyr::mutate(package = x)
57+
dplyr::filter(pkg %in% packages)
58+
dplyr::left_join(models, pkgs, by = dplyr::join_by(engine, mode, model)) |>
59+
dplyr::rename(package = pkg) |>
60+
dplyr::mutate(
61+
package = dplyr::if_else(is.na(package), "parsnip", package),
62+
call_from_parsnip = package %in% parsnip:::extensions(),
63+
caller_package = dplyr::if_else(
64+
call_from_parsnip,
65+
"parsnip",
66+
package
67+
)
68+
)
3369
}
70+
71+
3472
get_engines <- function(x) {
3573
eng <- try(parsnip::show_engines(x), silent = TRUE)
3674
if (inherits(eng, "try-error")) {
@@ -77,8 +115,8 @@ get_tunable_param <- function(mode, package, model, engine) {
77115
# ------------------------------------------------------------------------------
78116

79117
model_db <-
80-
purrr::map(packages, print_methods) |>
81-
purrr::list_rbind() |>
118+
get_models() |>
119+
dplyr::filter(mode %in% c("regression", "classification")) |>
82120
dplyr::filter(engine != "liquidSVM") |>
83121
dplyr::filter(model != "surv_reg") |>
84122
dplyr::filter(engine != "spark") |>
@@ -98,9 +136,10 @@ model_db <-
98136
dplyr::left_join(model_db, num_modes, by = c("package", "model", "engine")) |>
99137
dplyr::mutate(
100138
parameters = purrr::pmap(
101-
list(mode, package, model, engine),
139+
list(mode, caller_package, model, engine),
102140
get_tunable_param
103141
)
104-
)
142+
) |>
143+
dplyr::select(-call_from_parsnip, -caller_package)
105144

106145
usethis::use_data(model_db, overwrite = TRUE)

0 commit comments

Comments
 (0)