1
+ import numpy as np
2
+ import pandas as pd
3
+ import time
4
+ import sys
5
+ import os
6
+ from typing import Dict
7
+ from stochtree import BARTModel
8
+ from sklearn .model_selection import train_test_split
9
+ import pymc as pm
10
+ import pymc_bart as pmb
11
+
12
def dgp1(n: int, p: int, snr: float) -> Dict:
    """Simulate a regression dataset with a piecewise-linear plus trigonometric mean.

    Covariates are drawn uniformly on [0, 1). The conditional mean is
    ``slope(X[:, 0]) * X[:, 1]`` plus a sine term in ``X[:, 2]`` and a cosine
    term in ``X[:, 3]``, so ``p`` must be at least 4. Gaussian noise is added
    with standard deviation ``std(conditional_mean) / snr``.

    Args:
        n: Number of observations.
        p: Number of covariate columns (columns 0-3 are used by the mean).
        snr: Ratio of the conditional mean's standard deviation to the noise
            standard deviation.

    Returns:
        Dict with keys ``covariates`` (n x p array), ``basis`` (None),
        ``outcome`` (length-n array), ``conditional_mean`` (length-n array),
        ``rfx_group_ids`` (None) and ``rfx_basis`` (None).
    """
    generator = np.random.default_rng()

    # Covariates
    X = generator.uniform(0, 1, size=(n, p))
    x0, x1, x2, x3 = X[:, 0], X[:, 1], X[:, 2], X[:, 3]

    # Piecewise linear term: the slope applied to X[:, 1] depends on which
    # quarter of [0, 1) X[:, 0] falls in; anything not matched gets 7.5.
    quarter_masks = [
        (x0 >= 0.0) & (x0 < 0.25),
        (x0 >= 0.25) & (x0 < 0.5),
        (x0 >= 0.5) & (x0 < 0.75),
    ]
    slope = np.select(quarter_masks, [-7.5, -2.5, 2.5], default=7.5)
    plm_term = slope * x1

    # Trigonometric term
    two_pi = 2 * np.pi
    trig_term = 2 * np.sin(x2 * two_pi) - 1.5 * np.cos(x3 * two_pi)

    # Outcome: noise scale is tied to the spread of the signal through snr
    f_XW = plm_term + trig_term
    noise_sd = np.std(f_XW) / snr
    y = f_XW + generator.normal(0, noise_sd, n)

    simulated = {
        'covariates': X,
        'basis': None,
        'outcome': y,
        'conditional_mean': f_XW,
        'rfx_group_ids': None,
        'rfx_basis': None,
    }
    return simulated
46
+
47
+
48
def dgp2(n: int, p: int, snr: float) -> Dict:
    """Simulate a regression dataset whose piecewise-linear term acts on a leaf basis.

    Covariates ``X`` (n x p) and a two-column basis ``W`` are drawn uniformly
    on [0, 1). The conditional mean is ``slope(X[:, 0]) * W[:, 0]`` plus a sine
    term in ``X[:, 2]`` and a cosine term in ``X[:, 3]``, so ``p`` must be at
    least 4. Gaussian noise is added with standard deviation
    ``std(conditional_mean) / snr``.

    Args:
        n: Number of observations.
        p: Number of covariate columns (columns 0, 2 and 3 are used by the mean).
        snr: Ratio of the conditional mean's standard deviation to the noise
            standard deviation.

    Returns:
        Dict with keys ``covariates`` (n x p array), ``basis`` (n x 2 array),
        ``outcome`` (length-n array), ``conditional_mean`` (length-n array),
        ``rfx_group_ids`` (None) and ``rfx_basis`` (None).
    """
    generator = np.random.default_rng()

    # Covariates and basis
    X = generator.uniform(0, 1, (n, p))
    W = generator.uniform(0, 1, (n, 2))
    x0 = X[:, 0]
    w0 = W[:, 0]

    # Piecewise linear term using basis W: the slope applied to W[:, 0]
    # depends on which quarter of [0, 1) X[:, 0] falls in (7.5 otherwise).
    quarter_masks = [
        (x0 >= 0.0) & (x0 < 0.25),
        (x0 >= 0.25) & (x0 < 0.5),
        (x0 >= 0.5) & (x0 < 0.75),
    ]
    slope = np.select(quarter_masks, [-7.5, -2.5, 2.5], default=7.5)
    plm_term = slope * w0

    # Trigonometric term
    two_pi = 2 * np.pi
    trig_term = 2 * np.sin(X[:, 2] * two_pi) - 1.5 * np.cos(X[:, 3] * two_pi)

    # Outcome: noise scale is tied to the spread of the signal through snr
    f_XW = plm_term + trig_term
    noise_sd = np.std(f_XW) / snr
    y = f_XW + generator.normal(0, noise_sd, n)

    simulated = {
        'covariates': X,
        'basis': W,
        'outcome': y,
        'conditional_mean': f_XW,
        'rfx_group_ids': None,
        'rfx_basis': None,
    }
    return simulated
83
+
84
+
85
def dgp3(n: int, p: int, snr: float) -> Dict:
    """Simulate a regression dataset with a group random-effects term.

    Covariates are drawn uniformly on [0, 1). The conditional mean is
    ``slope(X[:, 0]) * X[:, 1]`` plus a sine term in ``X[:, 2]``, a cosine term
    in ``X[:, 3]`` (so ``p`` must be at least 4) and a random-effects term
    built from three groups, each with its own intercept and slope on a
    Uniform(-1, 1) regressor. Gaussian noise is added with standard deviation
    ``std(conditional_mean) / snr``.

    All random draws come from the function's own ``Generator``; the global
    ``np.random`` state is neither read nor advanced.

    Args:
        n: Number of observations.
        p: Number of covariate columns (columns 0-3 are used by the mean).
        snr: Ratio of the conditional mean's standard deviation to the noise
            standard deviation.

    Returns:
        Dict with keys ``covariates`` (n x p array), ``basis`` (None),
        ``outcome`` (length-n array), ``conditional_mean`` (length-n array),
        ``rfx_group_ids`` (length-n integer array with values in {0, 1, 2})
        and ``rfx_basis`` (n x 2 array whose first column is all ones).
    """
    rng = np.random.default_rng()

    # Covariates
    X = rng.uniform(0, 1, size=(n, p))
    x0 = X[:, 0]

    # Piecewise linear term: the slope applied to X[:, 1] depends on which
    # quarter of [0, 1) X[:, 0] falls in; anything not matched gets 7.5.
    quarter_masks = [
        (x0 >= 0.0) & (x0 < 0.25),
        (x0 >= 0.25) & (x0 < 0.5),
        (x0 >= 0.5) & (x0 < 0.75),
    ]
    plm_term = np.select(quarter_masks, [-7.5, -2.5, 2.5], default=7.5) * X[:, 1]

    # Trigonometric term
    trig_term = 2 * np.sin(X[:, 2] * 2 * np.pi) - 1.5 * np.cos(X[:, 3] * 2 * np.pi)

    # Random effects: row g of rfx_coefs holds (intercept, slope) for group g
    num_groups = 3
    rfx_group_ids = rng.choice(num_groups, size=n)
    rfx_coefs = np.array([[-5, -3, -1], [5, 3, 1]]).T
    # Draw from the local generator (not the global np.random state) so every
    # random quantity in this function comes from the same source.
    rfx_basis = np.column_stack([np.ones(n), rng.uniform(-1, 1, n)])
    rfx_term = np.sum(rfx_coefs[rfx_group_ids] * rfx_basis, axis=1)

    # Outcome: noise scale is tied to the spread of the signal through snr
    f_XW = plm_term + trig_term + rfx_term
    noise_sd = np.std(f_XW) / snr
    y = f_XW + rng.normal(0, noise_sd, n)

    return {
        'covariates': X,
        'basis': None,
        'outcome': y,
        'conditional_mean': f_XW,
        'rfx_group_ids': rfx_group_ids,
        'rfx_basis': rfx_basis,
    }
126
+
127
+
128
def dgp4(n: int, p: int, snr: float) -> Dict:
    """Simulate a regression dataset with both a leaf basis and random effects.

    Covariates ``X`` (n x p) and a two-column basis ``W`` are drawn uniformly
    on [0, 1). The conditional mean is ``slope(X[:, 0]) * W[:, 0]`` plus a sine
    term in ``X[:, 2]``, a cosine term in ``X[:, 3]`` (so ``p`` must be at
    least 4) and a random-effects term built from three groups, each with its
    own intercept and slope on a Uniform(-1, 1) regressor. Gaussian noise is
    added with standard deviation ``std(conditional_mean) / snr``.

    All random draws come from the function's own ``Generator``; the global
    ``np.random`` state is neither read nor advanced.

    Args:
        n: Number of observations.
        p: Number of covariate columns (columns 0, 2 and 3 are used by the mean).
        snr: Ratio of the conditional mean's standard deviation to the noise
            standard deviation.

    Returns:
        Dict with keys ``covariates`` (n x p array), ``basis`` (n x 2 array),
        ``outcome`` (length-n array), ``conditional_mean`` (length-n array),
        ``rfx_group_ids`` (length-n integer array with values in {0, 1, 2})
        and ``rfx_basis`` (n x 2 array whose first column is all ones).
    """
    rng = np.random.default_rng()

    # Covariates and basis
    X = rng.uniform(0, 1, (n, p))
    W = rng.uniform(0, 1, (n, 2))
    x0 = X[:, 0]

    # Piecewise linear term using basis W: the slope applied to W[:, 0]
    # depends on which quarter of [0, 1) X[:, 0] falls in (7.5 otherwise).
    quarter_masks = [
        (x0 >= 0.0) & (x0 < 0.25),
        (x0 >= 0.25) & (x0 < 0.5),
        (x0 >= 0.5) & (x0 < 0.75),
    ]
    plm_term = np.select(quarter_masks, [-7.5, -2.5, 2.5], default=7.5) * W[:, 0]

    # Trigonometric term
    trig_term = 2 * np.sin(X[:, 2] * 2 * np.pi) - 1.5 * np.cos(X[:, 3] * 2 * np.pi)

    # Random effects: row g of rfx_coefs holds (intercept, slope) for group g
    num_groups = 3
    rfx_group_ids = rng.choice(num_groups, size=n)
    rfx_coefs = np.array([[-5, -3, -1], [5, 3, 1]]).T
    # Draw from the local generator (not the global np.random state) so every
    # random quantity in this function comes from the same source.
    rfx_basis = np.column_stack([np.ones(n), rng.uniform(-1, 1, n)])
    rfx_term = np.sum(rfx_coefs[rfx_group_ids] * rfx_basis, axis=1)

    # Outcome: noise is also drawn from the local generator
    f_XW = plm_term + trig_term + rfx_term
    noise_sd = np.std(f_XW) / snr
    y = f_XW + rng.normal(0, noise_sd, n)

    return {
        'covariates': X,
        'basis': W,
        'outcome': y,
        'conditional_mean': f_XW,
        'rfx_group_ids': rfx_group_ids,
        'rfx_basis': rfx_basis,
    }
170
+
171
+
172
def compute_test_train_indices(n: int, test_set_pct: float) -> Dict[str, np.ndarray]:
    """Split the row indices ``0..n-1`` into train and test index arrays.

    The split is delegated to ``train_test_split`` with ``test_size`` set to
    ``test_set_pct``; no ``random_state`` is passed.

    Args:
        n: Number of rows to index.
        test_set_pct: Value forwarded as ``test_size``.

    Returns:
        Dict with keys ``test_inds`` and ``train_inds``.
    """
    train_inds, test_inds = train_test_split(np.arange(n), test_size=test_set_pct)
    return dict(test_inds=test_inds, train_inds=train_inds)
176
+
177
+
178
def subset_data(data: np.ndarray, subset_inds: np.ndarray) -> np.ndarray:
    """Select the rows of ``data`` given by ``subset_inds``.

    Args:
        data: NumPy array; one-dimensional arrays are indexed directly, all
            others are indexed along their first axis.
        subset_inds: Indices of the entries / rows to keep.

    Returns:
        The selected entries (1-D input) or rows (otherwise).

    Raises:
        ValueError: If ``data`` is not a NumPy array.
    """
    # Guard first so the happy path below stays flat
    if not isinstance(data, np.ndarray):
        raise ValueError("Data must be a numpy array")
    return data[subset_inds] if data.ndim == 1 else data[subset_inds, :]
186
+
187
+
188
def _parse_benchmark_args(argv) -> Dict:
    """Turn ``argv`` into the benchmark settings, falling back to defaults.

    With no arguments after the program name the defaults are returned. When
    arguments are given, all nine must be present, in the order
    ``n_iter n p num_gfr num_mcmc dgp_num snr test_set_pct num_threads``;
    anything beyond the ninth is ignored. Supplying fewer than nine exits
    with a usage message instead of failing with an ``IndexError``.
    """
    if len(argv) <= 1:
        # Default arguments
        return {
            'n_iter': 5,
            'n': 1000,
            'p': 5,
            'num_gfr': 10,
            'num_mcmc': 100,
            'dgp_num': 1,
            'snr': 2.0,
            'test_set_pct': 0.2,
            'num_threads': -1,
        }

    converters = (
        ('n_iter', int),
        ('n', int),
        ('p', int),
        ('num_gfr', int),
        ('num_mcmc', int),
        ('dgp_num', int),
        ('snr', float),
        ('test_set_pct', float),
        ('num_threads', int),
    )
    positional = argv[1:]
    if len(positional) < len(converters):
        names = " ".join(name for name, _ in converters)
        sys.exit(
            f"usage: {argv[0]} {names}\n"
            f"expected {len(converters)} arguments, got {len(positional)}"
        )
    # zip stops after the ninth value, so surplus arguments are ignored
    return {name: convert(raw) for (name, convert), raw in zip(converters, positional)}


def main():
    """Time stochtree BART against PyMC-BART on repeatedly simulated datasets.

    Settings come from the command line (see ``_parse_benchmark_args``). Each
    iteration simulates a dataset with the selected DGP, splits it into train
    and test sets, then times (a) stochtree sampling plus test-set prediction
    and (b) PyMC-BART model construction plus sampling, and prints both
    timings. Note the two timed regions are not identical in scope: only the
    stochtree one includes prediction.

    Raises:
        ValueError: If ``dgp_num`` is not one of 1, 2, 3, 4.
    """
    # Parse command line arguments
    settings = _parse_benchmark_args(sys.argv)
    for name, value in settings.items():
        print(f"{name} = {value}")

    n_iter = settings['n_iter']
    n = settings['n']
    p = settings['p']
    num_gfr = settings['num_gfr']
    num_mcmc = settings['num_mcmc']
    dgp_num = settings['dgp_num']
    snr = settings['snr']
    test_set_pct = settings['test_set_pct']
    num_threads = settings['num_threads']

    # Run the performance evaluation
    # NOTE(review): `results` is allocated here but nothing below fills or
    # reads it; kept in case it is meant to collect per-iteration metrics.
    results = np.empty((n_iter, 4), dtype=float)

    # Resolve the data generating process once, before any sampling work
    dgp_by_number = {1: dgp1, 2: dgp2, 3: dgp3, 4: dgp4}
    if dgp_num not in dgp_by_number:
        raise ValueError("Invalid DGP input")
    simulate = dgp_by_number[dgp_num]

    for i in range(n_iter):
        print(f"Running iteration {i + 1}/{n_iter}")

        # Generate data
        data_dict = simulate(n=n, p=p, snr=snr)
        covariates = data_dict['covariates']
        basis = data_dict['basis']
        conditional_mean = data_dict['conditional_mean']
        outcome = data_dict['outcome']
        rfx_group_ids = data_dict['rfx_group_ids']
        rfx_basis = data_dict['rfx_basis']

        # Split into train / test sets
        subset_inds_dict = compute_test_train_indices(n, test_set_pct)
        test_inds = subset_inds_dict['test_inds']
        train_inds = subset_inds_dict['train_inds']
        covariates_train = subset_data(covariates, train_inds)
        covariates_test = subset_data(covariates, test_inds)
        outcome_train = subset_data(outcome, train_inds)
        outcome_test = subset_data(outcome, test_inds)
        conditional_mean_train = subset_data(conditional_mean, train_inds)
        conditional_mean_test = subset_data(conditional_mean, test_inds)

        # The leaf basis and random-effects inputs exist only for some DGPs
        has_basis = basis is not None
        has_rfx = rfx_group_ids is not None
        if has_basis:
            basis_train = subset_data(basis, train_inds)
            basis_test = subset_data(basis, test_inds)
        else:
            basis_train = None
            basis_test = None
        if has_rfx:
            rfx_group_ids_train = subset_data(rfx_group_ids, train_inds)
            rfx_group_ids_test = subset_data(rfx_group_ids, test_inds)
            rfx_basis_train = subset_data(rfx_basis, train_inds)
            rfx_basis_test = subset_data(rfx_basis, test_inds)
        else:
            rfx_group_ids_train = None
            rfx_group_ids_test = None
            rfx_basis_train = None
            rfx_basis_test = None

        # Run (and time) stochtree BART.
        # perf_counter is a monotonic, high-resolution clock meant for
        # measuring intervals; time.time() can jump if the wall clock is
        # adjusted mid-run.
        start_time = time.perf_counter()

        # Sample BART model
        general_params = {'num_threads': num_threads}
        bart_model = BARTModel()
        bart_model.sample(
            X_train=covariates_train,
            y_train=outcome_train,
            leaf_basis_train=basis_train,
            rfx_group_ids_train=rfx_group_ids_train,
            rfx_basis_train=rfx_basis_train,
            num_gfr=num_gfr,
            num_mcmc=num_mcmc,
            general_params=general_params,
        )

        # Predict on the test set (part of the timed region)
        test_preds = bart_model.predict(
            covariates=covariates_test,
            basis=basis_test,
            rfx_group_ids=rfx_group_ids_test,
            rfx_basis=rfx_basis_test,
        )

        bart_timing = time.perf_counter() - start_time

        # Run (and time) pymc BART
        start_time = time.perf_counter()

        # Sample BART model; only covariates and outcome are passed, so any
        # basis / random-effects structure in the data is not modelled here.
        # NOTE(review): num_threads defaults to -1 and is forwarded unchanged
        # as `cores` -- confirm PyMC treats a negative value as intended.
        with pm.Model() as model:
            mu = pmb.BART("μ", covariates_train, outcome_train)
            sigma = pm.HalfNormal("σ", 5)
            obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=outcome_train)
            idata = pm.sample(
                draws=num_mcmc,
                tune=0,
                chains=1,
                cores=num_threads,
                compute_convergence_checks=False,
                random_seed=123,
                progressbar=False,
            )

        pymc_bart_timing = time.perf_counter() - start_time

        print(
            f"Stochtree BART timing: {bart_timing:.2f} seconds; "
            f"PyMC BART timing: {pymc_bart_timing:.2f} seconds"
        )


if __name__ == "__main__":
    main()
0 commit comments