Skip to content

Commit c2d1b5d

Browse files
committed
Updated regression test suite
1 parent f33c099 commit c2d1b5d

File tree

3 files changed

+48
-15
lines changed

3 files changed

+48
-15
lines changed

.github/workflows/regression-test.yml

Lines changed: 44 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -31,29 +31,68 @@ jobs:
3131
with:
3232
extra-packages: any::testthat, any::decor, local::stochtree_cran
3333

34-
- name: Create output directory for BART regression test results
34+
- name: Create output directory for R regression test results
3535
run: |
3636
mkdir -p tools/regression/bart/stochtree_bart_r_results
3737
mkdir -p tools/regression/bcf/stochtree_bcf_r_results
3838
39-
- name: Run the BART regression test benchmark suite
39+
- name: Run the R regression test benchmark suite
4040
run: |
4141
Rscript tools/regression/bart/regression_test_dispatch_bart.R
4242
Rscript tools/regression/bcf/regression_test_dispatch_bcf.R
4343
44-
- name: Collate and analyze regression test results
44+
- name: Collate and analyze R regression test results
4545
run: |
4646
Rscript tools/regression/bart/regression_test_analysis_bart.R
4747
Rscript tools/regression/bcf/regression_test_analysis_bcf.R
48+
49+
- name: Setup Python 3.10
50+
uses: actions/setup-python@v5
51+
with:
52+
python-version: "3.10"
53+
cache: "pip"
54+
55+
- name: Install Package with Relevant Dependencies
56+
run: |
57+
pip install --upgrade pip
58+
pip install -r requirements.txt
59+
pip install .
60+
61+
- name: Create output directory for python regression test results
62+
run: |
63+
mkdir -p tools/regression/bart/stochtree_bart_python_results
64+
mkdir -p tools/regression/bcf/stochtree_bcf_python_results
65+
66+
- name: Run the python regression test benchmark suite
67+
run: |
68+
python tools/regression/bart/regression_test_dispatch_bart.py
69+
python tools/regression/bcf/regression_test_dispatch_bcf.py
4870
49-
- name: Store BART benchmark test results as an artifact of the run
71+
- name: Collate and analyze python regression test results
72+
run: |
73+
python tools/regression/bart/regression_test_analysis_bart.py
74+
python tools/regression/bcf/regression_test_analysis_bcf.py
75+
76+
- name: Store R BART benchmark test results as an artifact of the run
5077
uses: actions/upload-artifact@v4
5178
with:
5279
name: stochtree-r-bart-summary
5380
path: tools/regression/bart/stochtree_bart_r_results/stochtree_bart_r_summary.csv
5481

55-
- name: Store BCF benchmark test results as an artifact of the run
82+
- name: Store R BCF benchmark test results as an artifact of the run
5683
uses: actions/upload-artifact@v4
5784
with:
5885
name: stochtree-r-bcf-summary
5986
path: tools/regression/bcf/stochtree_bcf_r_results/stochtree_bcf_r_summary.csv
87+
88+
- name: Store python BART benchmark test results as an artifact of the run
89+
uses: actions/upload-artifact@v4
90+
with:
91+
name: stochtree-python-bart-summary
92+
path: tools/regression/bart/stochtree_bart_python_results/stochtree_bart_python_summary.csv
93+
94+
- name: Store python BCF benchmark test results as an artifact of the run
95+
uses: actions/upload-artifact@v4
96+
with:
97+
name: stochtree-python-bcf-summary
98+
path: tools/regression/bcf/stochtree_bcf_python_results/stochtree_bcf_python_summary.csv

tools/regression/bart/individual_regression_test_bart.R

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -231,5 +231,5 @@ filename <- paste(
231231
"dgp_num", dgp_num, "snr", snr_rounded, "test_set_pct", test_set_pct_rounded,
232232
"num_threads", num_threads_clean, sep = "_"
233233
)
234-
filename_full <- paste0("tools/regression/stochtree_bart_r_results/", filename, ".csv")
234+
filename_full <- paste0("tools/regression/bart/stochtree_bart_r_results/", filename, ".csv")
235235
write.csv(x = results_df, file = filename_full, row.names = F)

tools/regression/bcf/individual_regression_test_bcf.py

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -364,12 +364,8 @@ def main():
364364

365365
y_hat_posterior_mean = np.mean(y_hat_posterior, axis=1)
366366
if has_multivariate_treatment:
367-
# For multivariate treatment, tau_hat_posterior has shape (n_test, n_samples, n_treatments)
368-
# We want to average over the samples (axis 1) to get (n_test, n_treatments)
369-
tau_hat_posterior_mean = np.mean(tau_hat_posterior, axis=1)
367+
tau_hat_posterior_mean = np.mean(tau_hat_posterior, axis=2)
370368
else:
371-
# For univariate treatment, tau_hat_posterior has shape (n_test, n_samples)
372-
# We want to average over the samples (axis 1) to get (n_test,)
373369
tau_hat_posterior_mean = np.mean(tau_hat_posterior, axis=1)
374370

375371
# Outcome RMSE and coverage
@@ -387,15 +383,13 @@ def main():
387383
tau_hat_rmse_test = np.sqrt(np.mean((tau_hat_posterior_mean - treatment_effect_test) ** 2))
388384

389385
if has_multivariate_treatment:
390-
# For multivariate treatment, compute percentiles over samples (axis 1)
391-
tau_hat_posterior_quantile_025 = np.percentile(tau_hat_posterior, 2.5, axis=1)
392-
tau_hat_posterior_quantile_975 = np.percentile(tau_hat_posterior, 97.5, axis=1)
386+
tau_hat_posterior_quantile_025 = np.percentile(tau_hat_posterior, 2.5, axis=2)
387+
tau_hat_posterior_quantile_975 = np.percentile(tau_hat_posterior, 97.5, axis=2)
393388
tau_hat_covered = np.logical_and(
394389
treatment_effect_test >= tau_hat_posterior_quantile_025,
395390
treatment_effect_test <= tau_hat_posterior_quantile_975
396391
)
397392
else:
398-
# For univariate treatment, compute percentiles over samples (axis 1)
399393
tau_hat_posterior_quantile_025 = np.percentile(tau_hat_posterior, 2.5, axis=1)
400394
tau_hat_posterior_quantile_975 = np.percentile(tau_hat_posterior, 97.5, axis=1)
401395
tau_hat_covered = np.logical_and(

0 commit comments

Comments
 (0)