@@ -504,14 +504,13 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
504504 if (! is.null(X_test )) X_test <- preprocessPredictionData(X_test , X_train_metadata )
505505
506506 # Convert all input data to matrices if not already converted
507- if ((is.null(dim(Z_train ))) && (! is.null(Z_train ))) {
508- Z_train <- as.matrix(as.numeric(Z_train ))
509- }
507+ Z_col <- ifelse(is.null(dim(Z_train )), 1 , ncol(Z_train ))
508+ Z_train <- matrix (as.numeric(Z_train ), ncol = Z_col )
510509 if ((is.null(dim(propensity_train ))) && (! is.null(propensity_train ))) {
511510 propensity_train <- as.matrix(propensity_train )
512511 }
513- if ((is.null(dim( Z_test ))) && ( ! is.null(Z_test ) )) {
514- Z_test <- as. matrix(as.numeric(Z_test ))
512+ if (! is.null(Z_test )) {
513+ Z_test <- matrix (as.numeric(Z_test ), ncol = Z_col )
515514 }
516515 if ((is.null(dim(propensity_test ))) && (! is.null(propensity_test ))) {
517516 propensity_test <- as.matrix(propensity_test )
@@ -580,9 +579,30 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
580579 }
581580 }
582581
583- # Stop if multivariate treatment is provided
584- if (ncol(Z_train ) > 1 ) stop(" Multivariate treatments are not currently supported" )
585-
582+ # # Stop if multivariate treatment is provided
583+ # if (ncol(Z_train) > 1) stop("Multivariate treatments are not currently supported")
584+
585+ # Handle multivariate treatment
586+ has_multivariate_treatment <- ncol(Z_train ) > 1
587+ if (has_multivariate_treatment ) {
588+ # Disable adaptive coding, internal propensity model, and
589+ # leaf scale sampling if treatment is multivariate
590+ if (adaptive_coding ) {
591+ warning(" Adaptive coding is incompatible with multivariate treatment and will be ignored" )
592+ adaptive_coding <- FALSE
593+ }
594+ if (is.null(propensity_train )) {
595+ if (propensity_covariate != " none" ) {
596+ warning(" No propensities were provided for the multivariate treatment; an internal propensity model will not be fitted to the multivariate treatment and propensity_covariate will be set to 'none'" )
597+ propensity_covariate <- " none"
598+ }
599+ }
600+ if (sample_sigma2_leaf_tau ) {
601+ warning(" Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model." )
602+ sample_sigma2_leaf_tau <- FALSE
603+ }
604+ }
605+
586606 # Random effects covariance prior
587607 if (has_rfx ) {
588608 if (is.null(rfx_prior_var )) {
@@ -835,18 +855,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
835855 current_sigma2 <- sigma2_init
836856 }
837857
838- # Switch off leaf scale sampling for multivariate treatments
839- if (ncol(Z_train ) > 1 ) {
840- if (sample_sigma2_leaf_tau ) {
841- warning(" Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model." )
842- sample_sigma2_leaf_tau <- FALSE
843- }
844- }
845-
846858 # Set mu and tau leaf models / dimensions
847859 leaf_model_mu_forest <- 0
848860 leaf_dimension_mu_forest <- 1
849- if (ncol( Z_train ) > 1 ) {
861+ if (has_multivariate_treatment ) {
850862 leaf_model_tau_forest <- 2
851863 leaf_dimension_tau_forest <- ncol(Z_train )
852864 } else {
@@ -973,21 +985,21 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
973985
974986 # Container of forest samples
975987 forest_samples_mu <- createForestSamples(num_trees_mu , 1 , TRUE )
976- forest_samples_tau <- createForestSamples(num_trees_tau , 1 , FALSE )
988+ forest_samples_tau <- createForestSamples(num_trees_tau , ncol( Z_train ) , FALSE )
977989 active_forest_mu <- createForest(num_trees_mu , 1 , TRUE )
978- active_forest_tau <- createForest(num_trees_tau , 1 , FALSE )
990+ active_forest_tau <- createForest(num_trees_tau , ncol( Z_train ) , FALSE )
979991 if (include_variance_forest ) {
980992 forest_samples_variance <- createForestSamples(num_trees_variance , 1 , TRUE , TRUE )
981993 active_forest_variance <- createForest(num_trees_variance , 1 , TRUE , TRUE )
982994 }
983995
984996 # Initialize the leaves of each tree in the prognostic forest
985- active_forest_mu $ prepare_for_sampler(forest_dataset_train , outcome_train , forest_model_mu , 0 , init_mu )
997+ active_forest_mu $ prepare_for_sampler(forest_dataset_train , outcome_train , forest_model_mu , leaf_model_mu_forest , init_mu )
986998 active_forest_mu $ adjust_residual(forest_dataset_train , outcome_train , forest_model_mu , FALSE , FALSE )
987999
9881000 # Initialize the leaves of each tree in the treatment effect forest
989- init_tau <- 0 .
990- active_forest_tau $ prepare_for_sampler(forest_dataset_train , outcome_train , forest_model_tau , 1 , init_tau )
1001+ init_tau <- rep( 0 . , ncol( Z_train ))
1002+ active_forest_tau $ prepare_for_sampler(forest_dataset_train , outcome_train , forest_model_tau , leaf_model_tau_forest , init_tau )
9911003 active_forest_tau $ adjust_residual(forest_dataset_train , outcome_train , forest_model_tau , TRUE , FALSE )
9921004
9931005 # Initialize the leaves of each tree in the variance forest
@@ -1450,7 +1462,18 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
14501462 } else {
14511463 tau_hat_train <- forest_samples_tau $ predict_raw(forest_dataset_train )* y_std_train
14521464 }
1453- y_hat_train <- mu_hat_train + tau_hat_train * as.numeric(Z_train )
1465+ if (has_multivariate_treatment ) {
1466+ tau_train_dim <- dim(tau_hat_train )
1467+ tau_num_obs <- tau_train_dim [1 ]
1468+ tau_num_samples <- tau_train_dim [3 ]
1469+ treatment_term_train <- matrix (NA_real_ , nrow = tau_num_obs , tau_num_samples )
1470+ for (i in 1 : nrow(Z_train )) {
1471+ treatment_term_train [i ,] <- colSums(tau_hat_train [i ,,] * Z_train [i ,])
1472+ }
1473+ } else {
1474+ treatment_term_train <- tau_hat_train * as.numeric(Z_train )
1475+ }
1476+ y_hat_train <- mu_hat_train + treatment_term_train
14541477 if (has_test ) {
14551478 mu_hat_test <- forest_samples_mu $ predict(forest_dataset_test )* y_std_train + y_bar_train
14561479 if (adaptive_coding ) {
@@ -1459,7 +1482,18 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
14591482 } else {
14601483 tau_hat_test <- forest_samples_tau $ predict_raw(forest_dataset_test )* y_std_train
14611484 }
1462- y_hat_test <- mu_hat_test + tau_hat_test * as.numeric(Z_test )
1485+ if (has_multivariate_treatment ) {
1486+ tau_test_dim <- dim(tau_hat_test )
1487+ tau_num_obs <- tau_test_dim [1 ]
1488+ tau_num_samples <- tau_test_dim [3 ]
1489+ treatment_term_test <- matrix (NA_real_ , nrow = tau_num_obs , tau_num_samples )
1490+ for (i in 1 : nrow(Z_test )) {
1491+ treatment_term_test [i ,] <- colSums(tau_hat_test [i ,,] * Z_test [i ,])
1492+ }
1493+ } else {
1494+ treatment_term_test <- tau_hat_test * as.numeric(Z_test )
1495+ }
1496+ y_hat_test <- mu_hat_test + treatment_term_test
14631497 }
14641498 if (include_variance_forest ) {
14651499 sigma2_x_hat_train <- exp(sigma2_x_train_raw )
@@ -1526,6 +1560,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
15261560 " treatment_dim" = ncol(Z_train ),
15271561 " propensity_covariate" = propensity_covariate ,
15281562 " binary_treatment" = binary_treatment ,
1563+ " multivariate_treatment" = has_multivariate_treatment ,
15291564 " adaptive_coding" = adaptive_coding ,
15301565 " internal_propensity_model" = internal_propensity_model ,
15311566 " num_samples" = num_retained_samples ,
@@ -1722,6 +1757,17 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU
17221757 } else {
17231758 tau_hat <- object $ forests_tau $ predict_raw(forest_dataset_pred )* y_std
17241759 }
1760+ if (object $ model_params $ multivariate_treatment ) {
1761+ tau_dim <- dim(tau_hat )
1762+ tau_num_obs <- tau_dim [1 ]
1763+ tau_num_samples <- tau_dim [3 ]
1764+ treatment_term <- matrix (NA_real_ , nrow = tau_num_obs , tau_num_samples )
1765+ for (i in 1 : nrow(Z_train )) {
1766+ treatment_term [i ,] <- colSums(tau_hat [i ,,] * Z [i ,])
1767+ }
1768+ } else {
1769+ treatment_term <- tau_hat * as.numeric(Z )
1770+ }
17251771 if (object $ model_params $ include_variance_forest ) {
17261772 s_x_raw <- object $ forests_variance $ predict(forest_dataset_pred )
17271773 }
@@ -1732,7 +1778,7 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU
17321778 }
17331779
17341780 # Compute overall "y_hat" predictions
1735- y_hat <- mu_hat + tau_hat * as.numeric( Z )
1781+ y_hat <- mu_hat + treatment_term
17361782 if (object $ model_params $ has_rfx ) y_hat <- y_hat + rfx_predictions
17371783
17381784 # Scale variance forest predictions
0 commit comments