diff --git a/DESCRIPTION b/DESCRIPTION index fded77f48..bbc018a48 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: parsnip Title: A Common API to Modeling and Analysis Functions -Version: 1.2.1.9003 +Version: 1.2.1.9004 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")), person("Davis", "Vaughan", , "davis@posit.co", role = "aut"), diff --git a/NAMESPACE b/NAMESPACE index 0782892d6..ea8dce360 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -185,6 +185,7 @@ export(.dat) export(.extract_surv_status) export(.extract_surv_time) export(.facts) +export(.get_prediction_column_names) export(.lvls) export(.model_param_name_key) export(.obs) diff --git a/R/misc.R b/R/misc.R index 3582d8ca2..263a9b0f2 100644 --- a/R/misc.R +++ b/R/misc.R @@ -575,3 +575,75 @@ is_cran_check <- function() { } # nocov end +# ------------------------------------------------------------------------------ + +#' Obtain names of prediction columns for a fitted model or workflow +#' +#' [.get_prediction_column_names()] returns a list that has the names of the +#' columns for the primary prediction types for a model. +#' @param x A fitted parsnip model (class `"model_fit"`) or a fitted workflow. +#' @param syms Should the column names be converted to symbols? Defaults to `FALSE`. +#' @return A list with elements `"estimate"` and `"probabilities"`. +#' @examplesIf !parsnip:::is_cran_check() +#' library(dplyr) +#' library(modeldata) +#' data("two_class_dat") +#' +#' levels(two_class_dat$Class) +#' lr_fit <- logistic_reg() %>% fit(Class ~ ., data = two_class_dat) +#' +#' .get_prediction_column_names(lr_fit) +#' .get_prediction_column_names(lr_fit, syms = TRUE) +#' @export +.get_prediction_column_names <- function(x, syms = FALSE) { + if (!inherits(x, c("model_fit", "workflow"))) { + cli::cli_abort("{.arg x} should be an object with class {.cls model_fit} or + {.cls workflow}, not {.obj_type_friendly {x}}.") + } + + if (inherits(x, "workflow")) { + x <- x %>% extract_fit_parsnip(x) + } + model_spec <- extract_spec_parsnip(x) + model_engine <- model_spec$engine + model_mode <- model_spec$mode + model_type <- class(model_spec)[1] + + # appropriate populate the model db + inst_res <- purrr::map(required_pkgs(x), rlang::check_installed) + predict_types <- + get_from_env(paste0(model_type, "_predict")) %>% + dplyr::filter(engine == model_engine & mode == model_mode) %>% + purrr::pluck("type") + + if (length(predict_types) == 0) { + cli::cli_abort("Prediction information could not be found for this + {.fn {model_type}} with engine {.val {model_engine}} and mode + {.val {model_mode}}. Does a parsnip extension package need to + be loaded?") + } + + res <- list(estimate = character(0), probabilities = character(0)) + + if (model_mode == "regression") { + res$estimate <- ".pred" + } else if (model_mode == "classification") { + res$estimate <- ".pred_class" + if (any(predict_types == "prob")) { + res$probabilities <- paste0(".pred_", x$lvl) + } + } else if (model_mode == "censored regression") { + res$estimate <- ".pred_time" + if (any(predict_types %in% c("survival"))) { + res$probabilities <- ".pred" + } + } else { + # Should be unreachable + cli::cli_abort("Unsupported model mode {model_mode}.") + } + + if (syms) { + res <- purrr::map(res, rlang::syms) + } + res +} diff --git a/_pkgdown.yml b/_pkgdown.yml index c79ecca06..3596da3f3 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -111,3 +111,4 @@ reference: - .extract_surv_status - .extract_surv_time - .model_param_name_key + - .get_prediction_column_names diff --git a/man/dot-get_prediction_column_names.Rd b/man/dot-get_prediction_column_names.Rd new file mode 100644 index 000000000..36fba2c08 --- /dev/null +++ b/man/dot-get_prediction_column_names.Rd @@ -0,0 +1,33 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/misc.R +\name{.get_prediction_column_names} +\alias{.get_prediction_column_names} +\title{Obtain names of prediction columns for a fitted model or workflow} +\usage{ +.get_prediction_column_names(x, syms = FALSE) +} +\arguments{ +\item{x}{A fitted model (class \code{"model_fit"}) or a fitted workflow.} + +\item{syms}{Should the column names be converted to symbols?} +} +\value{ +A list with elements \code{"estimate"} and \code{"probabilities"}. +} +\description{ +\code{\link[=.get_prediction_column_names]{.get_prediction_column_names()}} returns a list that has the names of the +columns for the primary prediction types for a model. +} +\examples{ +\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +library(dplyr) +library(modeldata) +data("two_class_dat") + +levels(two_class_dat$Class) +lr_fit <- logistic_reg() \%>\% fit(Class ~ ., data = two_class_dat) + +.get_prediction_column_names(lr_fit) +.get_prediction_column_names(lr_fit, syms = TRUE) +\dontshow{\}) # examplesIf} +} diff --git a/tests/testthat/_snaps/misc.md b/tests/testthat/_snaps/misc.md index b6b1f918c..b221b1dde 100644 --- a/tests/testthat/_snaps/misc.md +++ b/tests/testthat/_snaps/misc.md @@ -227,3 +227,19 @@ Error in `check_outcome()`: ! For a censored regression model, the outcome should be a object, not an integer vector. +# obtaining prediction columns + + Code + .get_prediction_column_names(1) + Condition + Error in `.get_prediction_column_names()`: + ! `x` should be an object with class or , not a number. + +--- + + Code + .get_prediction_column_names(unk_fit) + Condition + Error in `.get_prediction_column_names()`: + ! Prediction information could not be found for this `linear_reg()` with engine "lm" and mode "Depeche". Does a parsnip extension package need to be loaded? + diff --git a/tests/testthat/test-misc.R b/tests/testthat/test-misc.R index e689bcbcf..901c92748 100644 --- a/tests/testthat/test-misc.R +++ b/tests/testthat/test-misc.R @@ -249,3 +249,53 @@ test_that('check_outcome works as expected', { check_outcome(1:2, cens_spec) ) }) + +# ------------------------------------------------------------------------------ + +test_that('obtaining prediction columns', { + skip_if_not_installed("modeldata") + data(two_class_dat, package = "modeldata") + + ### classification + lr_fit <- logistic_reg() %>% fit(Class ~ ., data = two_class_dat) + expect_equal( + .get_prediction_column_names(lr_fit), + list(estimate = ".pred_class", + probabilities = c(".pred_Class1", ".pred_Class2")) + ) + expect_equal( + .get_prediction_column_names(lr_fit, syms = TRUE), + list(estimate = list(quote(.pred_class)), + probabilities = list(quote(.pred_Class1), quote(.pred_Class2))) + ) + + ### regression + ols_fit <- linear_reg() %>% fit(mpg ~ ., data = mtcars) + expect_equal( + .get_prediction_column_names(ols_fit), + list(estimate = ".pred", + probabilities = character(0)) + ) + expect_equal( + .get_prediction_column_names(ols_fit, syms = TRUE), + list(estimate = list(quote(.pred)), + probabilities = list()) + ) + + ### censored regression + # in extratests + + ### bad input + expect_snapshot( + .get_prediction_column_names(1), + error = TRUE + ) + + unk_fit <- ols_fit + unk_fit$spec$mode <- "Depeche" + expect_snapshot( + .get_prediction_column_names(unk_fit), + error = TRUE + ) + +})