@@ -157,6 +157,10 @@ grf_conf_int <- function(
157157 res
158158}
159159
160+ qrf_quantile_convert <- function (x , object ) {
161+ matrix_to_quantile_pred(x $ predictions , object )
162+ }
163+
160164# ------------------------------------------------------------------------------
161165
162166set_new_model(" rand_forest" )
@@ -849,3 +853,48 @@ set_pred(
849853 )
850854 )
851855)
856+
857+ set_fit(
858+ model = " rand_forest" ,
859+ eng = " grf" ,
860+ mode = " quantile regression" ,
861+ value = list (
862+ interface = " data.frame" ,
863+ data = c(x = " X" , y = " Y" , weights = " case.weights" ),
864+ protect = c(" x" , " y" , " weights" ),
865+ func = c(pkg = " grf" , fun = " quantile_forest" ),
866+ defaults = list (
867+ num.threads = 1 ,
868+ quantiles = quote(quantile_levels )
869+ )
870+ )
871+ )
872+
873+ set_encoding(
874+ model = " rand_forest" ,
875+ eng = " grf" ,
876+ mode = " quantile regression" ,
877+ options = list (
878+ predictor_indicators = " one_hot" ,
879+ compute_intercept = FALSE ,
880+ remove_intercept = TRUE ,
881+ allow_sparse_x = FALSE
882+ )
883+ )
884+
885+ set_pred(
886+ model = " rand_forest" ,
887+ eng = " grf" ,
888+ mode = " quantile regression" ,
889+ type = " quantile" ,
890+ value = list (
891+ pre = NULL ,
892+ post = qrf_quantile_convert ,
893+ func = c(fun = " predict" ),
894+ args = list (
895+ object = expr(object $ fit ),
896+ newdata = expr(new_data ),
897+ quantiles = NULL
898+ )
899+ )
900+ )
0 commit comments