Skip to content

Commit 729a24d

Browse files
committed
Updated posterior interval R functions to handle mu and tau separately from prognostic function and cate
1 parent a18f053 commit 729a24d

File tree

2 files changed

+180
-21
lines changed

2 files changed

+180
-21
lines changed

R/posterior_transformation.R

Lines changed: 64 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,7 @@ posterior_predictive_heuristic_multiplier <- function(
832832
#'
833833
#' This function computes posterior credible intervals for specified terms from a fitted BCF model. It supports intervals for prognostic forests, CATE forests, variance forests, random effects, and overall mean outcome predictions.
834834
#' @param model_object A fitted BCF model object of class `bcfmodel`.
835-
#' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"cate"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`.
835+
#' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. Note that `"mu"` is only different from `"prognostic_function"` if random effects are included with a model spec of `"intercept_only"` or `"intercept_plus_treatment"` and `"tau"` is only different from `"cate"` if random effects are included with a model spec of `"intercept_plus_treatment"`.
836836
#' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval).
837837
#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
838838
#' @param covariates (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions).
@@ -895,6 +895,29 @@ compute_bcf_posterior_interval <- function(
895895
}
896896

897897
# Check that all the necessary inputs were provided for interval computation
898+
for (term in terms) {
899+
if (
900+
!(term %in%
901+
c(
902+
"prognostic_function",
903+
"mu",
904+
"cate",
905+
"tau",
906+
"variance_forest",
907+
"rfx",
908+
"y_hat",
909+
"all"
910+
))
911+
) {
912+
stop(
913+
paste0(
914+
"Term '",
915+
term,
916+
"' was requested. Valid terms are 'prognostic_function', 'mu', 'cate', 'tau', 'variance_forest', 'rfx', 'y_hat', and 'all'."
917+
)
918+
)
919+
}
920+
}
898921
needs_covariates_intermediate <- ((("y_hat" %in% terms) ||
899922
("all" %in% terms)))
900923
needs_covariates <- (("prognostic_function" %in% terms) ||
@@ -975,16 +998,22 @@ compute_bcf_posterior_interval <- function(
975998
"'rfx_group_ids' must have the same length as the number of rows in 'covariates'"
976999
)
9771000
}
978-
if (is.null(rfx_basis)) {
979-
stop(
980-
"'rfx_basis' must be provided in order to compute the requested intervals"
981-
)
982-
}
983-
if (!is.matrix(rfx_basis)) {
984-
stop("'rfx_basis' must be a matrix")
1001+
1002+
if (model_object$model_params$rfx_model_spec == "custom") {
1003+
if (is.null(rfx_basis)) {
1004+
stop(
1005+
"A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
1006+
)
1007+
}
9851008
}
986-
if (nrow(rfx_basis) != nrow(covariates)) {
987-
stop("'rfx_basis' must have the same number of rows as 'covariates'")
1009+
1010+
if (!is.null(rfx_basis)) {
1011+
if (!is.matrix(rfx_basis)) {
1012+
stop("'rfx_basis' must be a matrix")
1013+
}
1014+
if (nrow(rfx_basis) != nrow(covariates)) {
1015+
stop("'rfx_basis' must have the same number of rows as 'covariates'")
1016+
}
9881017
}
9891018
}
9901019

@@ -1006,11 +1035,15 @@ compute_bcf_posterior_interval <- function(
10061035
if (has_multiple_terms) {
10071036
result <- list()
10081037
for (term_name in names(predictions)) {
1009-
result[[term_name]] <- summarize_interval(
1010-
predictions[[term_name]],
1011-
sample_dim = 2,
1012-
level = level
1013-
)
1038+
if (!is.null(predictions[[term_name]])) {
1039+
result[[term_name]] <- summarize_interval(
1040+
predictions[[term_name]],
1041+
sample_dim = 2,
1042+
level = level
1043+
)
1044+
} else {
1045+
result[[term_name]] <- NULL
1046+
}
10141047
}
10151048
return(result)
10161049
} else {
@@ -1161,11 +1194,15 @@ compute_bart_posterior_interval <- function(
11611194
if (has_multiple_terms) {
11621195
result <- list()
11631196
for (term_name in names(predictions)) {
1164-
result[[term_name]] <- summarize_interval(
1165-
predictions[[term_name]],
1166-
sample_dim = 2,
1167-
level = level
1168-
)
1197+
if (!is.null(predictions[[term_name]])) {
1198+
result[[term_name]] <- summarize_interval(
1199+
predictions[[term_name]],
1200+
sample_dim = 2,
1201+
level = level
1202+
)
1203+
} else {
1204+
result[[term_name]] <- NULL
1205+
}
11691206
}
11701207
return(result)
11711208
} else {
@@ -1253,8 +1290,12 @@ bart_model_has_term <- function(model_object, term) {
12531290
bcf_model_has_term <- function(model_object, term) {
12541291
if (term == "prognostic_function") {
12551292
return(TRUE)
1293+
} else if (term == "mu") {
1294+
return(TRUE)
12561295
} else if (term == "cate") {
12571296
return(TRUE)
1297+
} else if (term == "tau") {
1298+
return(TRUE)
12581299
} else if (term == "variance_forest") {
12591300
return(model_object$model_params$include_variance_forest)
12601301
} else if (term == "rfx") {
@@ -1280,15 +1321,17 @@ validate_bart_term <- function(term) {
12801321
validate_bcf_term <- function(term) {
12811322
model_terms <- c(
12821323
"prognostic_function",
1324+
"mu",
12831325
"cate",
1326+
"tau",
12841327
"variance_forest",
12851328
"rfx",
12861329
"y_hat",
12871330
"all"
12881331
)
12891332
if (!(term %in% model_terms)) {
12901333
stop(
1291-
"'term' must be one of 'prognostic_function', 'cate', 'variance_forest', 'rfx', 'y_hat', or 'all' for bcfmodel objects"
1334+
"'term' must be one of 'prognostic_function', 'mu', 'cate', 'tau', 'variance_forest', 'rfx', 'y_hat', or 'all' for bcfmodel objects"
12921335
)
12931336
}
12941337
}

tools/debug/bcf_predict_debug.R

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,3 +354,119 @@ tau_hat_test <- predict(
354354

355355
# Compare to prognostic function returned from the larger prediction
356356
all(abs(tau_hat_test - posterior_preds_test$tau_hat) < 0.0001)
357+
358+
# Compute intervals for all of the model terms
359+
posterior_intervals_test <- compute_bcf_posterior_interval(
360+
model_object = bcf_model,
361+
scale = "linear",
362+
terms = "all",
363+
covariates = X_test,
364+
treatment = Z_test,
365+
propensity = pi_test,
366+
rfx_group_ids = rfx_group_ids_test,
367+
level = 0.95
368+
)
369+
370+
# Compute intervals for just the prognostic term
371+
prog_intervals_test <- compute_bcf_posterior_interval(
372+
model_object = bcf_model,
373+
scale = "linear",
374+
terms = "prognostic_function",
375+
covariates = X_test,
376+
treatment = Z_test,
377+
propensity = pi_test,
378+
rfx_group_ids = rfx_group_ids_test,
379+
level = 0.95
380+
)
381+
382+
# Compute intervals for just the CATE term
383+
cate_intervals_test <- compute_bcf_posterior_interval(
384+
model_object = bcf_model,
385+
scale = "linear",
386+
terms = "cate",
387+
covariates = X_test,
388+
treatment = Z_test,
389+
propensity = pi_test,
390+
rfx_group_ids = rfx_group_ids_test,
391+
level = 0.95
392+
)
393+
394+
# Check that they match the corresponding terms from the full interval list
395+
all(
396+
abs(
397+
posterior_intervals_test$prognostic_function$lower -
398+
prog_intervals_test$lower
399+
) <
400+
0.0001
401+
)
402+
all(
403+
abs(
404+
posterior_intervals_test$prognostic_function$upper -
405+
prog_intervals_test$upper
406+
) <
407+
0.0001
408+
)
409+
all(
410+
abs(
411+
posterior_intervals_test$cate$lower -
412+
cate_intervals_test$lower
413+
) <
414+
0.0001
415+
)
416+
all(
417+
abs(
418+
posterior_intervals_test$cate$upper -
419+
cate_intervals_test$upper
420+
) <
421+
0.0001
422+
)
423+
424+
# Check that the prog and CATE intervals are different from the mu and tau intervals
425+
mu_intervals_test <- compute_bcf_posterior_interval(
426+
model_object = bcf_model,
427+
scale = "linear",
428+
terms = "mu",
429+
covariates = X_test,
430+
treatment = Z_test,
431+
propensity = pi_test,
432+
rfx_group_ids = rfx_group_ids_test,
433+
level = 0.95
434+
)
435+
tau_intervals_test <- compute_bcf_posterior_interval(
436+
model_object = bcf_model,
437+
scale = "linear",
438+
terms = "tau",
439+
covariates = X_test,
440+
treatment = Z_test,
441+
propensity = pi_test,
442+
rfx_group_ids = rfx_group_ids_test,
443+
level = 0.95
444+
)
445+
all(
446+
abs(
447+
mu_intervals_test$lower -
448+
prog_intervals_test$lower
449+
) >
450+
0.0001
451+
)
452+
all(
453+
abs(
454+
mu_intervals_test$upper -
455+
prog_intervals_test$upper
456+
) >
457+
0.0001
458+
)
459+
all(
460+
abs(
461+
tau_intervals_test$lower -
462+
cate_intervals_test$lower
463+
) >
464+
0.0001
465+
)
466+
all(
467+
abs(
468+
tau_intervals_test$upper -
469+
cate_intervals_test$upper
470+
) >
471+
0.0001
472+
)

0 commit comments

Comments
 (0)