Skip to content

Commit d6f5d93

Browse files
author
‘topepo’
committed
some predict call routing
1 parent 448a74d commit d6f5d93

File tree

8 files changed

+21
-12
lines changed

8 files changed

+21
-12
lines changed

R/mlp.R

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,10 @@ check_args.mlp <- function(object, call = rlang::caller_env()) {
146146

147147
# keras wrapper for feed-forward nnet
148148

149-
class2ind <- function (x, drop2nd = FALSE) {
150-
if (!is.factor(x))
151-
cli::cli_abort(c("x" = "{.arg x} should be a factor."))
149+
class2ind <- function (x, drop2nd = FALSE, call = rlang::caller_env()) {
150+
if (!is.factor(x)) {
151+
cli::cli_abort(c("x" = "{.arg x} should be a {cls factor} not {.obj_type_friendly {x}."))
152+
}
152153
y <- model.matrix( ~ x - 1)
153154
colnames(y) <- gsub("^x", "", colnames(y))
154155
attributes(y)$assign <- NULL

R/predict_class.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010
#' @export
1111
predict_class.model_fit <- function(object, new_data, ...) {
1212
if (object$spec$mode != "classification") {
13-
cli::cli_abort("{.fun predict.model_fit} is for predicting factor outcomes.")
13+
cli::cli_abort("{.fun predict.model_fit} is for predicting factor outcomes.",
14+
call = rlang::call2("predict"))
1415
}
1516

1617
check_spec_pred_type(object, "class")
1718

1819
if (inherits(object$fit, "try-error")) {
19-
cli::cli_warn("Model fit failed; cannot make predictions.")
20+
cli::cli_warn("Model fit failed; cannot make predictions.",
21+
call = rlang::call2("predict"))
2022
return(NULL)
2123
}
2224

R/predict_classprob.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
#' @export
77
predict_classprob.model_fit <- function(object, new_data, ...) {
88
if (object$spec$mode != "classification") {
9-
cli::cli_abort("{.fun predict.model_fit()} is for predicting factor outcomes.")
9+
cli::cli_abort("{.fun predict.model_fit()} is for predicting factor outcomes.",
10+
call = rlang::call2("predict"))
1011
}
1112

1213
check_spec_pred_type(object, "prob", call = caller_env())
@@ -36,7 +37,8 @@ predict_classprob.model_fit <- function(object, new_data, ...) {
3637

3738
# check and sort names
3839
if (!is.data.frame(res) & !inherits(res, "tbl_spark")) {
39-
cli::cli_abort("The was a problem with the probability predictions.")
40+
cli::cli_abort("The was a problem with the probability predictions.",
41+
call = rlang::call2("predict"))
4042
}
4143

4244
if (!is_tibble(res) & !inherits(res, "tbl_spark")) {

R/predict_numeric.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ predict_numeric.model_fit <- function(object, new_data, ...) {
1111
"{.fun predict_numeric} is for predicting numeric outcomes.",
1212
"i" = "Use {.fun predict_class} or {.fun predict_classprob} for
1313
classification models."
14-
)
14+
),
15+
call = rlang::call2("predict")
1516
)
1617
}
1718

R/predict_quantile.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ predict_quantile.model_fit <- function(object,
3737
if (object$spec$mode == "quantile regression") {
3838
if (!is.null(quantile_levels)) {
3939
cli::cli_abort("When the mode is {.val quantile regression},
40-
{.arg quantile_levels} are specified by {.fn set_mode}.")
40+
{.arg quantile_levels} are specified by {.fn set_mode}.",
41+
call = rlang::call2("predict"))
4142
}
4243
} else {
4344
if (is.null(quantile_levels)) {

R/predict_time.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ predict_time.model_fit <- function(object, new_data, ...) {
1111
"{.fun predict_time} is for predicting time outcomes.",
1212
"i" = "Use {.fun predict_class} or {.fun predict_classprob} for
1313
classification models."
14-
)
14+
),
15+
call = rlang::call2("predict")
1516
)
1617
}
1718

R/rand_forest_data.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ ranger_confint <- function(object, new_data, ...) {
6666
} else {
6767
cli::cli_abort(
6868
"Cannot compute confidence intervals for a ranger forest
69-
of type {.val {object$fit$forest$treetype}}."
69+
of type {.val {object$fit$forest$treetype}}.",
70+
call = rlang::call2("predict")
7071
)
7172
}
7273
}

tests/testthat/_snaps/linear_reg_quantreg.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44
ten_quant_pred <- predict(ten_quant, new_data = sac_test, quantile_levels = (0:
55
9) / 9)
66
Condition
7-
Error in `predict_quantile()`:
7+
Error in `predict()`:
88
! When the mode is "quantile regression", `quantile_levels` are specified by `set_mode()`.
99

0 commit comments

Comments
 (0)