Skip to content

Commit bd92eed

Browse files
committed
enable quantile regression
1 parent b9cc657 commit bd92eed

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

R/rand_forest_data.R

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ grf_conf_int <- function(
157157
res
158158
}
159159

160+
qrf_quantile_convert <- function(x, object) {
161+
matrix_to_quantile_pred(x$predictions, object)
162+
}
163+
160164
# ------------------------------------------------------------------------------
161165

162166
set_new_model("rand_forest")
@@ -849,3 +853,48 @@ set_pred(
849853
)
850854
)
851855
)
856+
857+
set_fit(
858+
model = "rand_forest",
859+
eng = "grf",
860+
mode = "quantile regression",
861+
value = list(
862+
interface = "data.frame",
863+
data = c(x = "X", y = "Y", weights = "case.weights"),
864+
protect = c("x", "y", "weights"),
865+
func = c(pkg = "grf", fun = "quantile_forest"),
866+
defaults = list(
867+
num.threads = 1,
868+
quantiles = quote(quantile_levels)
869+
)
870+
)
871+
)
872+
873+
set_encoding(
874+
model = "rand_forest",
875+
eng = "grf",
876+
mode = "quantile regression",
877+
options = list(
878+
predictor_indicators = "one_hot",
879+
compute_intercept = FALSE,
880+
remove_intercept = TRUE,
881+
allow_sparse_x = FALSE
882+
)
883+
)
884+
885+
set_pred(
886+
model = "rand_forest",
887+
eng = "grf",
888+
mode = "quantile regression",
889+
type = "quantile",
890+
value = list(
891+
pre = NULL,
892+
post = qrf_quantile_convert,
893+
func = c(fun = "predict"),
894+
args = list(
895+
object = expr(object$fit),
896+
newdata = expr(new_data),
897+
quantiles = NULL
898+
)
899+
)
900+
)

0 commit comments

Comments
 (0)