Skip to content

Commit 2383983

Browse files
authored
Merge pull request #233 from StochasticTree/predict-python-hotfix
Hotfix for BCF predictions when only tau(X) is requested
2 parents 76d47c5 + daab8e5 commit 2383983

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

R/bcf.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3061,8 +3061,10 @@ predict.bcfmodel <- function(
30613061
t(tau_hat_raw) * (object$b_1_samples - object$b_0_samples)
30623062
) *
30633063
y_std
3064-
control_adj <- t(t(tau_hat_raw) * object$b_0_samples) * y_std
3065-
mu_hat_forest <- mu_hat_forest + control_adj
3064+
if (predict_mu_forest || predict_mu_forest_intermediate) {
3065+
control_adj <- t(t(tau_hat_raw) * object$b_0_samples) * y_std
3066+
mu_hat_forest <- mu_hat_forest + control_adj
3067+
}
30663068
} else {
30673069
tau_hat_forest <- object$forests_tau$predict_raw(forest_dataset_pred) *
30683070
y_std

stochtree/bcf.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2604,12 +2604,13 @@ def predict(
26042604
adaptive_coding_weights = np.expand_dims(
26052605
self.b1_samples - self.b0_samples, axis=(0, 2)
26062606
)
2607-
b0_weights = np.expand_dims(
2608-
self.b0_samples, axis=(0, 2)
2609-
)
2610-
control_adj = tau_raw * b0_weights * self.y_std
2607+
if predict_mu_forest or predict_mu_forest_intermediate:
2608+
b0_weights = np.expand_dims(
2609+
self.b0_samples, axis=(0, 2)
2610+
)
2611+
control_adj = tau_raw * b0_weights * self.y_std
2612+
mu_x_forest = mu_x_forest + np.squeeze(control_adj)
26112613
tau_raw = tau_raw * adaptive_coding_weights
2612-
mu_x_forest = mu_x_forest + np.squeeze(control_adj)
26132614
tau_x_forest = np.squeeze(tau_raw * self.y_std)
26142615
if Z.shape[1] > 1:
26152616
treatment_term = np.multiply(

0 commit comments

Comments
 (0)