File tree Expand file tree Collapse file tree 2 files changed +10
-7
lines changed Expand file tree Collapse file tree 2 files changed +10
-7
lines changed Original file line number Diff line number Diff line change @@ -3061,8 +3061,10 @@ predict.bcfmodel <- function(
30613061 t(tau_hat_raw ) * (object $ b_1_samples - object $ b_0_samples )
30623062 ) *
30633063 y_std
3064- control_adj <- t(t(tau_hat_raw ) * object $ b_0_samples ) * y_std
3065- mu_hat_forest <- mu_hat_forest + control_adj
3064+ if (predict_mu_forest || predict_mu_forest_intermediate ) {
3065+ control_adj <- t(t(tau_hat_raw ) * object $ b_0_samples ) * y_std
3066+ mu_hat_forest <- mu_hat_forest + control_adj
3067+ }
30663068 } else {
30673069 tau_hat_forest <- object $ forests_tau $ predict_raw(forest_dataset_pred ) *
30683070 y_std
Original file line number Diff line number Diff line change @@ -2604,12 +2604,13 @@ def predict(
26042604 adaptive_coding_weights = np .expand_dims (
26052605 self .b1_samples - self .b0_samples , axis = (0 , 2 )
26062606 )
2607- b0_weights = np .expand_dims (
2608- self .b0_samples , axis = (0 , 2 )
2609- )
2610- control_adj = tau_raw * b0_weights * self .y_std
2607+ if predict_mu_forest or predict_mu_forest_intermediate :
2608+ b0_weights = np .expand_dims (
2609+ self .b0_samples , axis = (0 , 2 )
2610+ )
2611+ control_adj = tau_raw * b0_weights * self .y_std
2612+ mu_x_forest = mu_x_forest + np .squeeze (control_adj )
26112613 tau_raw = tau_raw * adaptive_coding_weights
2612- mu_x_forest = mu_x_forest + np .squeeze (control_adj )
26132614 tau_x_forest = np .squeeze (tau_raw * self .y_std )
26142615 if Z .shape [1 ] > 1 :
26152616 treatment_term = np .multiply (
You can’t perform that action at this time.
0 commit comments