Commit 18db957

Added simple PyMC BART comparison script
1 parent 89c5bc7 commit 18db957

1 file changed
Lines changed: 321 additions & 0 deletions
@@ -0,0 +1,321 @@
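# Timing comparison between stochtree's BARTModel and PyMC-BART on simulated data.
# Optional positional CLI arguments (defaults are set in main()):
#   n_iter n p num_gfr num_mcmc dgp_num snr test_set_pct num_threads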
import numpy as np
import pandas as pd
import time
import sys
import os
from typing import Dict
from stochtree import BARTModel
from sklearn.model_selection import train_test_split
import pymc as pm
import pymc_bart as pmb

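# Data-generating processes: each dgp* below simulates uniform covariates, a
# piecewise linear term, and a trigonometric term; dgp2/dgp4 add a leaf basis W
# and dgp3/dgp4 add group-level random effects. Noise is scaled so that
# sd(f) / noise_sd equals the requested SNR.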
def dgp1(n: int, p: int, snr: float) -> Dict:
    rng = np.random.default_rng()

    # Covariates
    X = rng.uniform(0, 1, size=(n, p))

    # Piecewise linear term
    plm_term = np.where(
        (X[:,0] >= 0.0) & (X[:,0] < 0.25), -7.5 * X[:,1],
        np.where(
            (X[:,0] >= 0.25) & (X[:,0] < 0.5), -2.5 * X[:,1],
            np.where(
                (X[:,0] >= 0.5) & (X[:,0] < 0.75), 2.5 * X[:,1],
                7.5 * X[:,1]
            )
        )
    )

    # Trigonometric term
    trig_term = 2 * np.sin(X[:, 2] * 2 * np.pi) - 1.5 * np.cos(X[:, 3] * 2 * np.pi)

    # Outcome
    f_XW = plm_term + trig_term
    noise_sd = np.std(f_XW) / snr
    y = f_XW + rng.normal(0, noise_sd, n)

    return {
        'covariates': X,
        'basis': None,
        'outcome': y,
        'conditional_mean': f_XW,
        'rfx_group_ids': None,
        'rfx_basis': None
    }

def dgp2(n: int, p: int, snr: float) -> Dict:
    rng = np.random.default_rng()

    # Covariates and basis
    X = rng.uniform(0, 1, (n, p))
    W = rng.uniform(0, 1, (n, 2))

    # Piecewise linear term using basis W
    plm_term = np.where(
        (X[:,0] >= 0.0) & (X[:,0] < 0.25), -7.5 * W[:,0],
        np.where(
            (X[:,0] >= 0.25) & (X[:,0] < 0.5), -2.5 * W[:,0],
            np.where(
                (X[:,0] >= 0.5) & (X[:,0] < 0.75), 2.5 * W[:,0],
                7.5 * W[:,0]
            )
        )
    )

    # Trigonometric term
    trig_term = 2 * np.sin(X[:, 2] * 2 * np.pi) - 1.5 * np.cos(X[:, 3] * 2 * np.pi)

    # Outcome
    f_XW = plm_term + trig_term
    noise_sd = np.std(f_XW) / snr
    y = f_XW + rng.normal(0, noise_sd, n)

    return {
        'covariates': X,
        'basis': W,
        'outcome': y,
        'conditional_mean': f_XW,
        'rfx_group_ids': None,
        'rfx_basis': None
    }

def dgp3(n: int, p: int, snr: float) -> Dict:
    rng = np.random.default_rng()

    # Covariates
    X = rng.uniform(0, 1, size=(n, p))

    # Piecewise linear term
    plm_term = np.where(
        (X[:,0] >= 0.0) & (X[:,0] < 0.25), -7.5 * X[:,1],
        np.where(
            (X[:,0] >= 0.25) & (X[:,0] < 0.5), -2.5 * X[:,1],
            np.where(
                (X[:,0] >= 0.5) & (X[:,0] < 0.75), 2.5 * X[:,1],
                7.5 * X[:,1]
            )
        )
    )

    # Trigonometric term
    trig_term = 2 * np.sin(X[:, 2] * 2 * np.pi) - 1.5 * np.cos(X[:, 3] * 2 * np.pi)

    # Random effects
    num_groups = 3
    rfx_group_ids = rng.choice(num_groups, size=n)
    rfx_coefs = np.array([[-5, -3, -1], [5, 3, 1]]).T
    rfx_basis = np.column_stack([np.ones(n), rng.uniform(-1, 1, n)])
    rfx_term = np.sum(rfx_coefs[rfx_group_ids] * rfx_basis, axis=1)

    # Outcome
    f_XW = plm_term + trig_term + rfx_term
    noise_sd = np.std(f_XW) / snr
    y = f_XW + rng.normal(0, noise_sd, n)

    return {
        'covariates': X,
        'basis': None,
        'outcome': y,
        'conditional_mean': f_XW,
        'rfx_group_ids': rfx_group_ids,
        'rfx_basis': rfx_basis
    }

def dgp4(n: int, p: int, snr: float) -> Dict:
    rng = np.random.default_rng()

    # Covariates and basis
    X = rng.uniform(0, 1, (n, p))
    W = rng.uniform(0, 1, (n, 2))

    # Piecewise linear term using basis W
    plm_term = np.where(
        (X[:,0] >= 0.0) & (X[:,0] < 0.25), -7.5 * W[:,0],
        np.where(
            (X[:,0] >= 0.25) & (X[:,0] < 0.5), -2.5 * W[:,0],
            np.where(
                (X[:,0] >= 0.5) & (X[:,0] < 0.75), 2.5 * W[:,0],
                7.5 * W[:,0]
            )
        )
    )

    # Trigonometric term
    trig_term = 2 * np.sin(X[:, 2] * 2 * np.pi) - 1.5 * np.cos(X[:, 3] * 2 * np.pi)

    # Random effects
    num_groups = 3
    rfx_group_ids = rng.choice(num_groups, size=n)
    rfx_coefs = np.array([[-5, -3, -1], [5, 3, 1]]).T
    rfx_basis = np.column_stack([np.ones(n), rng.uniform(-1, 1, n)])
    rfx_term = np.sum(rfx_coefs[rfx_group_ids] * rfx_basis, axis=1)

    # Outcome
    f_XW = plm_term + trig_term + rfx_term
    noise_sd = np.std(f_XW) / snr
    y = f_XW + rng.normal(0, noise_sd, n)

    return {
        'covariates': X,
        'basis': W,
        'outcome': y,
        'conditional_mean': f_XW,
        'rfx_group_ids': rfx_group_ids,
        'rfx_basis': rfx_basis
    }

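# Helpers for splitting the simulated arrays into train / test subsets.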
def compute_test_train_indices(n: int, test_set_pct: float) -> Dict[str, np.ndarray]:
    sample_inds = np.arange(n)
    train_inds, test_inds = train_test_split(sample_inds, test_size=test_set_pct)
    return {'test_inds': test_inds, 'train_inds': train_inds}


def subset_data(data: np.ndarray, subset_inds: np.ndarray) -> np.ndarray:
    if isinstance(data, np.ndarray):
        if data.ndim == 1:
            return data[subset_inds]
        else:
            return data[subset_inds, :]
    else:
        raise ValueError("Data must be a numpy array")

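# Parse CLI arguments, simulate data from the selected DGP, then fit and time
# stochtree BART (including a test-set prediction pass) and PyMC-BART on the
# same training data, printing the wall-clock timings for each iteration.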
def main():
    # Parse command line arguments
    if len(sys.argv) > 1:
        n_iter = int(sys.argv[1])
        n = int(sys.argv[2])
        p = int(sys.argv[3])
        num_gfr = int(sys.argv[4])
        num_mcmc = int(sys.argv[5])
        dgp_num = int(sys.argv[6])
        snr = float(sys.argv[7])
        test_set_pct = float(sys.argv[8])
        num_threads = int(sys.argv[9])
    else:
        # Default arguments
        n_iter = 5
        n = 1000
        p = 5
        num_gfr = 10
        num_mcmc = 100
        dgp_num = 1
        snr = 2.0
        test_set_pct = 0.2
        num_threads = -1

    print(f"n_iter = {n_iter}")
    print(f"n = {n}")
    print(f"p = {p}")
    print(f"num_gfr = {num_gfr}")
    print(f"num_mcmc = {num_mcmc}")
    print(f"dgp_num = {dgp_num}")
    print(f"snr = {snr}")
    print(f"test_set_pct = {test_set_pct}")
    print(f"num_threads = {num_threads}")

    # Run the performance evaluation
    results = np.empty((n_iter, 4), dtype=float)

    for i in range(n_iter):
        print(f"Running iteration {i+1}/{n_iter}")

        # Generate data
        if dgp_num == 1:
            data_dict = dgp1(n=n, p=p, snr=snr)
        elif dgp_num == 2:
            data_dict = dgp2(n=n, p=p, snr=snr)
        elif dgp_num == 3:
            data_dict = dgp3(n=n, p=p, snr=snr)
        elif dgp_num == 4:
            data_dict = dgp4(n=n, p=p, snr=snr)
        else:
            raise ValueError("Invalid DGP input")

        covariates = data_dict['covariates']
        basis = data_dict['basis']
        conditional_mean = data_dict['conditional_mean']
        outcome = data_dict['outcome']
        rfx_group_ids = data_dict['rfx_group_ids']
        rfx_basis = data_dict['rfx_basis']

        # Split into train / test sets
        subset_inds_dict = compute_test_train_indices(n, test_set_pct)
        test_inds = subset_inds_dict['test_inds']
        train_inds = subset_inds_dict['train_inds']
        covariates_train = subset_data(covariates, train_inds)
        covariates_test = subset_data(covariates, test_inds)
        outcome_train = subset_data(outcome, train_inds)
        outcome_test = subset_data(outcome, test_inds)
        conditional_mean_train = subset_data(conditional_mean, train_inds)
        conditional_mean_test = subset_data(conditional_mean, test_inds)
        has_basis = basis is not None
        has_rfx = rfx_group_ids is not None
        if has_basis:
            basis_train = subset_data(basis, train_inds)
            basis_test = subset_data(basis, test_inds)
        else:
            basis_train = None
            basis_test = None
        if has_rfx:
            rfx_group_ids_train = subset_data(rfx_group_ids, train_inds)
            rfx_group_ids_test = subset_data(rfx_group_ids, test_inds)
            rfx_basis_train = subset_data(rfx_basis, train_inds)
            rfx_basis_test = subset_data(rfx_basis, test_inds)
        else:
            rfx_group_ids_train = None
            rfx_group_ids_test = None
            rfx_basis_train = None
            rfx_basis_test = None

        # Run (and time) stochtree BART
        start_time = time.time()

        # Sample BART model
        general_params = {'num_threads': num_threads}
        bart_model = BARTModel()
        bart_model.sample(
            X_train=covariates_train,
            y_train=outcome_train,
            leaf_basis_train=basis_train,
            rfx_group_ids_train=rfx_group_ids_train,
            rfx_basis_train=rfx_basis_train,
            num_gfr=num_gfr,
            num_mcmc=num_mcmc,
            general_params=general_params
        )

        # Predict on the test set
        test_preds = bart_model.predict(
            covariates=covariates_test,
            basis=basis_test,
            rfx_group_ids=rfx_group_ids_test,
            rfx_basis=rfx_basis_test
        )

        bart_timing = time.time() - start_time

        # Run (and time) pymc BART
        start_time = time.time()

        # Sample BART model
        with pm.Model() as model:
            μ = pmb.BART("μ", covariates_train, outcome_train)
            σ = pm.HalfNormal("σ", 5)
            obs = pm.Normal("obs", mu=μ, sigma=σ, observed=outcome_train)
            idata = pm.sample(draws=num_mcmc, tune=0, chains=1, cores=num_threads,
                              compute_convergence_checks=False, random_seed=123,
                              progressbar=False)

        pymc_bart_timing = time.time() - start_time

        print(f"Stochtree BART timing: {bart_timing:.2f} seconds; PyMC BART timing: {pymc_bart_timing:.2f} seconds")


if __name__ == "__main__":
    main()
