diff --git a/NAMESPACE b/NAMESPACE index b7c6b6b24..8584428b1 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -221,6 +221,7 @@ export(discrim_flexible) export(discrim_linear) export(discrim_quad) export(discrim_regularized) +export(ensure_parsnip_format) export(eval_args) export(extract_fit_engine) export(extract_fit_time) diff --git a/R/predict.R b/R/predict.R index c8a5fed59..5088232c3 100644 --- a/R/predict.R +++ b/R/predict.R @@ -271,6 +271,8 @@ check_pred_type <- function(object, type, ..., call = rlang::caller_env()) { #' tibbles. #' #' @param x A data frame or vector (depending on the context and function). +#' @param col_name A string for a prediction column name. +#' @param overwrite A logical for whether to overwrite the column name. #' @return A tibble #' @keywords internal #' @name format-internals @@ -336,6 +338,9 @@ format_hazard <- function(x) { ensure_parsnip_format(x, ".pred") } +#' @export +#' @rdname format-internals +#' @keywords internal ensure_parsnip_format <- function(x, col_name, overwrite = TRUE) { if (isTRUE(ncol(x) > 1) | is.data.frame(x)) { x <- tibble::new_tibble(x) diff --git a/man/format-internals.Rd b/man/format-internals.Rd index a9c0c771a..ed330a3db 100644 --- a/man/format-internals.Rd +++ b/man/format-internals.Rd @@ -9,6 +9,7 @@ \alias{format_survival} \alias{format_linear_pred} \alias{format_hazard} +\alias{ensure_parsnip_format} \title{Internal functions that format predictions} \usage{ format_num(x) @@ -24,9 +25,15 @@ format_survival(x) format_linear_pred(x) format_hazard(x) + +ensure_parsnip_format(x, col_name, overwrite = TRUE) } \arguments{ \item{x}{A data frame or vector (depending on the context and function).} + +\item{col_name}{A string for a prediction column name.} + +\item{overwrite}{A logical for whether to overwrite the column name.} } \value{ A tibble