Skip to content

Commit 73bd635

Browse files
committed
Refactored predict method for BART and BCF to include labels for prediction components
1 parent 46e49f3 commit 73bd635

File tree

15 files changed

+191
-114
lines changed

15 files changed

+191
-114
lines changed

R/bart.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,15 +1276,23 @@ predict.bartmodel <- function(object, X, leaf_basis = NULL, rfx_group_ids = NULL
12761276
result <- list()
12771277
if ((object$model_params$has_rfx) || (object$model_params$include_mean_forest)) {
12781278
result[["y_hat"]] = y_hat
1279+
} else {
1280+
result[["y_hat"]] <- NULL
12791281
}
12801282
if (object$model_params$include_mean_forest) {
12811283
result[["mean_forest_predictions"]] = mean_forest_predictions
1284+
} else {
1285+
result[["mean_forest_predictions"]] <- NULL
12821286
}
12831287
if (object$model_params$has_rfx) {
12841288
result[["rfx_predictions"]] = rfx_predictions
1289+
} else {
1290+
result[["rfx_predictions"]] <- NULL
12851291
}
12861292
if (object$model_params$include_variance_forest) {
12871293
result[["variance_forest_predictions"]] = variance_forest_predictions
1294+
} else {
1295+
result[["variance_forest_predictions"]] <- NULL
12881296
}
12891297
return(result)
12901298
}

R/bcf.R

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1800,10 +1800,14 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU
18001800
"y_hat" = y_hat
18011801
)
18021802
if (object$model_params$has_rfx) {
1803-
result[["rfx_predictions"]] = rfx_predictions
1803+
result[["rfx_predictions"]] <- rfx_predictions
1804+
} else {
1805+
result[["rfx_predictions"]] <- NULL
18041806
}
18051807
if (object$model_params$include_variance_forest) {
1806-
result[["variance_forest_predictions"]] = variance_forest_predictions
1808+
result[["variance_forest_predictions"]] <- variance_forest_predictions
1809+
} else {
1810+
result[["variance_forest_predictions"]] <- NULL
18071811
}
18081812
return(result)
18091813
}

demo/debug/multi_chain.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,15 @@ def outcome_mean(X, W):
118118
)
119119

120120
# Inspect the model outputs
121-
y_hat_mcmc_2 = bart_model_2.predict(X_test, basis_test)
121+
bart_preds_2 = bart_model_2.predict(X_test, basis_test)
122+
y_hat_mcmc_2 = bart_preds_2['y_hat']
122123
y_avg_mcmc_2 = np.squeeze(y_hat_mcmc_2).mean(axis=1, keepdims=True)
123-
y_hat_mcmc_3 = bart_model_3.predict(X_test, basis_test)
124+
y_avg_mcmc_2 = np.squeeze(y_hat_mcmc_2).mean(axis=1, keepdims=True)
125+
bart_preds_3 = bart_model_3.predict(X_test, basis_test)
126+
y_hat_mcmc_3 = bart_preds_3['y_hat']
124127
y_avg_mcmc_3 = np.squeeze(y_hat_mcmc_3).mean(axis=1, keepdims=True)
125-
y_hat_mcmc_4 = bart_model_4.predict(X_test, basis_test)
128+
bart_preds_4 = bart_model_4.predict(X_test, basis_test)
129+
y_hat_mcmc_4 = bart_preds_4['y_hat']
126130
y_avg_mcmc_4 = np.squeeze(y_hat_mcmc_4).mean(axis=1, keepdims=True)
127131
y_df = pd.DataFrame(
128132
np.concatenate(

demo/debug/parallel_multi_chain.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ def outcome_mean(X, W):
145145
)
146146

147147
# Inspect the model outputs
148-
y_hat_mcmc = combined_bart.predict(X_test, basis_test)
148+
bart_preds = combined_bart.predict(X_test, basis_test)
149+
y_hat_mcmc = bart_preds['y_hat']
149150
y_avg_mcmc = np.squeeze(y_hat_mcmc).mean(axis=1, keepdims=True)
150151
y_df = pd.DataFrame(
151152
np.concatenate((y_avg_mcmc, np.expand_dims(y_test, axis=1)), axis=1),

demo/debug/random_effects.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ def outcome_mean(group_labels, basis):
7575
rfx_model.sample(rfx_dataset, outcome, rfx_tracker, rfx_container, True, 1.0, cpp_rng)
7676

7777
# Inspect the samples
78-
rfx_preds = rfx_container.predict(group_labels, basis) * y_std + y_bar
78+
bart_preds = rfx_container.predict(group_labels, basis)
79+
rfx_preds = bart_preds['y_hat'] * y_std + y_bar
7980
rfx_comparison_df = pd.DataFrame(
8081
np.concatenate((rfx_preds, np.expand_dims(rfx_term, axis=1)), axis=1),
8182
columns=["Predicted", "Actual"],

demo/debug/rfx_serialization.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,13 @@ def rfx_mean(group_labels, basis):
6060
rfx_basis_train=basis, num_gfr=10, num_mcmc=10)
6161

6262
# Extract predictions from the sampler
63-
y_hat_orig = bart_orig.predict(X, W, group_labels, basis)
63+
bart_preds_orig = bart_orig.predict(X, W, group_labels, basis)
64+
y_hat_orig = bart_preds_orig['y_hat']
6465

6566
# "Round-trip" the model to JSON string and back and check that the predictions agree
6667
bart_json_string = bart_orig.to_json()
6768
bart_reloaded = BARTModel()
6869
bart_reloaded.from_json(bart_json_string)
69-
y_hat_reloaded = bart_reloaded.predict(X, W, group_labels, basis)
70+
bart_preds_reloaded = bart_reloaded.predict(X, W, group_labels, basis)
71+
y_hat_reloaded = bart_preds_reloaded['y_hat']
7072
np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded)

demo/debug/serialization.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,13 @@ def outcome_mean(X, W):
9898
global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global)
9999

100100
# Extract predictions from the sampler
101-
y_hat_orig = forest_container.predict(dataset)
101+
bart_preds_orig = forest_container.predict(dataset)
102+
y_hat_orig = bart_preds_orig['y_hat']
102103

103104
# "Round-trip" the forest to JSON string and back and check that the predictions agree
104105
forest_json_string = forest_container.dump_json_string()
105106
forest_container_reloaded = ForestContainer(num_trees, W.shape[1], False, False)
106107
forest_container_reloaded.load_from_json_string(forest_json_string)
107-
y_hat_reloaded = forest_container_reloaded.predict(dataset)
108+
bart_preds_reloaded = forest_container_reloaded.predict(dataset)
109+
y_hat_reloaded = bart_preds_reloaded['y_hat']
108110
np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded)

demo/notebooks/prototype_interface.ipynb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,8 @@
341341
"outputs": [],
342342
"source": [
343343
"# Forest predictions\n",
344-
"forest_preds = forest_container.predict(dataset) * y_std + y_bar\n",
344+
"bart_preds = forest_container.predict(dataset)\n",
345+
"forest_preds = bart_preds['y_hat'] * y_std + y_bar\n",
345346
"forest_preds_gfr = forest_preds[:, :num_warmstart]\n",
346347
"forest_preds_mcmc = forest_preds[:, num_warmstart:num_samples]\n",
347348
"\n",
@@ -1101,7 +1102,7 @@
11011102
],
11021103
"metadata": {
11031104
"kernelspec": {
1104-
"display_name": "venv",
1105+
"display_name": "venv (3.12.9)",
11051106
"language": "python",
11061107
"name": "python3"
11071108
},

demo/notebooks/serialization.ipynb

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,8 @@
241241
"metadata": {},
242242
"outputs": [],
243243
"source": [
244-
"y_hat_deserialized = bart_model_deserialized.predict(X_test, basis_test)\n",
244+
"bart_preds_deserialized = bart_model_deserialized.predict(X_test, basis_test)\n",
245+
"y_hat_deserialized = bart_preds_deserialized['y_hat']\n",
245246
"y_avg_mcmc_deserialized = np.squeeze(y_hat_deserialized).mean(axis=1, keepdims=True)\n",
246247
"y_df = pd.DataFrame(\n",
247248
" np.concatenate((y_avg_mcmc, y_avg_mcmc_deserialized), axis=1),\n",
@@ -325,7 +326,8 @@
325326
"metadata": {},
326327
"outputs": [],
327328
"source": [
328-
"y_hat_file_deserialized = bart_model_file_deserialized.predict(X_test, basis_test)\n",
329+
"bart_preds_file_deserialized = bart_model_file_deserialized.predict(X_test, basis_test)\n",
330+
"y_hat_file_deserialized = bart_preds_file_deserialized['y_hat']\n",
329331
"y_avg_mcmc_file_deserialized = np.squeeze(y_hat_file_deserialized).mean(\n",
330332
" axis=1, keepdims=True\n",
331333
")\n",
@@ -381,7 +383,7 @@
381383
],
382384
"metadata": {
383385
"kernelspec": {
384-
"display_name": "venv",
386+
"display_name": "venv (3.12.9)",
385387
"language": "python",
386388
"name": "python3"
387389
},

stochtree/bart.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1691,11 +1691,11 @@ def predict(
16911691

16921692
has_mean_predictions = self.include_mean_forest or self.has_rfx
16931693
if has_mean_predictions and self.include_variance_forest:
1694-
return (mean_pred, variance_pred)
1694+
return {"y_hat": mean_pred, "variance_forest_predictions": variance_pred}
16951695
elif has_mean_predictions and not self.include_variance_forest:
1696-
return mean_pred
1696+
return {"y_hat": mean_pred, "variance_forest_predictions": None}
16971697
elif not has_mean_predictions and self.include_variance_forest:
1698-
return variance_pred
1698+
return {"y_hat": None, "variance_forest_predictions": variance_pred}
16991699

17001700
def predict_mean(
17011701
self,

0 commit comments

Comments
 (0)