Skip to content

Commit 533f2a1

Browse files
committed
Added demo / debug scripts for new python functionality
1 parent 19302ac commit 533f2a1

File tree

4 files changed

+663
-0
lines changed

4 files changed

+663
-0
lines changed

demo/debug/bart_contrast_debug.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# Demo of contrast computation function for BART
2+
3+
# Load libraries
4+
from stochtree import BARTModel
5+
from sklearn.model_selection import train_test_split
6+
import numpy as np
7+
# Generate data
# n observations of p uniform covariates plus a single standard-normal
# leaf basis column W; the generator is seeded so the demo is reproducible.
n = 500
p = 5
rng = np.random.default_rng(1234)
X = rng.uniform(low=0.0, high=1.0, size=(n, p))
W = rng.normal(loc=0.0, scale=1.0, size=(n, 1))

# The conditional mean is linear in W with a slope that steps with X[:, 0].
# np.select takes the first matching condition, which mirrors the original
# nested np.where chain; rows matching no condition get the last slope (7.5).
leaf_slope = np.select(
    [
        (0 <= X[:, 0]) & (X[:, 0] < 0.25),
        (0.25 <= X[:, 0]) & (X[:, 0] < 0.5),
        (0.5 <= X[:, 0]) & (X[:, 0] < 0.75),
    ],
    [-7.5, -2.5, 2.5],
    default=7.5,
)
f_XW = leaf_slope * W[:, 0]
E_Y = f_XW

# Noise standard deviation is std(E_Y) / snr.
snr = 2
y = E_Y + rng.normal(loc=0.0, scale=1.0, size=(n,)) * (np.std(E_Y) / snr)
30+
# Train-test split
test_set_pct = 0.2
# Split the row indices rather than the arrays themselves so that X, W and y
# are all subset with the same train / test rows.
train_inds, test_inds = train_test_split(
    np.arange(n), test_size=test_set_pct, random_state=1234
)
X_train = X[train_inds, :]
X_test = X[test_inds, :]
W_train = W[train_inds, :]
W_test = W[test_inds, :]
y_train = y[train_inds]
y_test = y[test_inds]
n_test = len(test_inds)
n_train = len(train_inds)
44+
# Fit BART model
# W_train is supplied as the leaf basis; the sampler is run with 10 GFR
# iterations, no burn-in, and 1000 MCMC iterations.
bart_model = BARTModel()
bart_model.sample(
    X_train=X_train,
    leaf_basis_train=W_train,
    y_train=y_train,
    num_gfr=10,
    num_burnin=0,
    num_mcmc=1000,
)
55+
# Compute contrast posterior
# Both arms use the same test covariates; only the leaf basis differs
# (all zeros for arm 0, all ones for arm 1).
contrast_posterior_test = bart_model.compute_contrast(
    covariates_0=X_test,
    covariates_1=X_test,
    basis_0=np.zeros((n_test, 1)),
    basis_1=np.ones((n_test, 1)),
    type="posterior",
    scale="linear",
)
65+
# Compute the same quantity via two predict calls
# Arm 0: test covariates with an all-zero leaf basis.
y_hat_posterior_test_0 = bart_model.predict(
    covariates=X_test,
    basis=np.zeros((n_test, 1)),
    type="posterior",
    terms="y_hat",
    scale="linear",
)
# Arm 1: identical covariates with an all-one leaf basis.
y_hat_posterior_test_1 = bart_model.predict(
    covariates=X_test,
    basis=np.ones((n_test, 1)),
    type="posterior",
    terms="y_hat",
    scale="linear",
)
# Manual contrast: arm 1 predictions minus arm 0 predictions.
contrast_posterior_test_comparison = y_hat_posterior_test_1 - y_hat_posterior_test_0
82+
# Compare results
contrast_diff = contrast_posterior_test_comparison - contrast_posterior_test
# A bare np.allclose(...) expression discards its result when run as a script,
# so the check could never fail. assert_allclose raises with a mismatch report.
# The desired value is 0, so the relative tolerance term contributes nothing
# and the check is |contrast_diff| <= atol.
np.testing.assert_allclose(
    contrast_diff,
    0,
    atol=0.001,
    err_msg="compute_contrast disagrees with the difference of two predict calls",
)
print("BART contrast matches the difference of two predict calls")
86+
# Generate data for a BART model with random effects
# Fresh covariates and leaf basis are drawn from the already-seeded generator.
X = rng.uniform(low=0.0, high=1.0, size=(n, p))
W = rng.normal(loc=0.0, scale=1.0, size=(n, 1))

# The forest term is linear in W with a slope that steps with X[:, 0].
# np.select takes the first matching condition, which mirrors the original
# nested np.where chain; rows matching no condition get the last slope (7.5).
leaf_slope = np.select(
    [
        (0 <= X[:, 0]) & (X[:, 0] < 0.25),
        (0.25 <= X[:, 0]) & (X[:, 0] < 0.5),
        (0.5 <= X[:, 0]) & (X[:, 0] < 0.75),
    ],
    [-7.5, -2.5, 2.5],
    default=7.5,
)
f_XW = leaf_slope * W[:, 0]

# Random effects: each row gets one of three group labels and a two-column
# basis (a constant 1 and a uniform draw); the group's coefficient row is
# multiplied elementwise with the basis and summed across columns.
num_rfx_groups = 3
group_labels = rng.choice(num_rfx_groups, size=n)
basis = np.empty((n, 2))
basis[:, 0] = 1.0
basis[:, 1] = rng.uniform(0, 1, (n,))
rfx_coefs = np.array([[-2, -2], [0, 0], [2, 2]])
rfx_term = np.sum(rfx_coefs[group_labels, :] * basis, axis=1)

E_Y = f_XW + rfx_term
# Noise standard deviation is std(E_Y) / snr.
snr = 2
y = E_Y + rng.normal(loc=0.0, scale=1.0, size=(n,)) * (np.std(E_Y) / snr)
113+
# Train-test split
# Split the row indices so that covariates, leaf basis, outcome, group labels
# and random effects basis are all subset with the same train / test rows.
train_inds, test_inds = train_test_split(
    np.arange(n), test_size=test_set_pct, random_state=1234
)
X_train = X[train_inds, :]
X_test = X[test_inds, :]
W_train = W[train_inds, :]
W_test = W[test_inds, :]
y_train = y[train_inds]
y_test = y[test_inds]
group_ids_train = group_labels[train_inds]
group_ids_test = group_labels[test_inds]
rfx_basis_train = basis[train_inds, :]
rfx_basis_test = basis[test_inds, :]
n_test = len(test_inds)
n_train = len(train_inds)
130+
# Fit BART model
# Same sampler settings as the first fit (10 GFR, no burn-in, 1000 MCMC),
# now also passing the random effects group ids and basis.
bart_model = BARTModel()
bart_model.sample(
    X_train=X_train,
    leaf_basis_train=W_train,
    y_train=y_train,
    rfx_group_ids_train=group_ids_train,
    rfx_basis_train=rfx_basis_train,
    num_gfr=10,
    num_burnin=0,
    num_mcmc=1000,
)
143+
# Compute contrast posterior
# Covariates, group ids and random effects basis are identical in both arms;
# only the leaf basis differs (all zeros for arm 0, all ones for arm 1).
contrast_posterior_test = bart_model.compute_contrast(
    covariates_0=X_test,
    covariates_1=X_test,
    basis_0=np.zeros((n_test, 1)),
    basis_1=np.ones((n_test, 1)),
    rfx_group_ids_0=group_ids_test,
    rfx_group_ids_1=group_ids_test,
    rfx_basis_0=rfx_basis_test,
    rfx_basis_1=rfx_basis_test,
    type="posterior",
    scale="linear",
)
157+
# Compute the same quantity via two predict calls
# Arm 0: all-zero leaf basis, with the test-set random effects inputs.
y_hat_posterior_test_0 = bart_model.predict(
    covariates=X_test,
    basis=np.zeros((n_test, 1)),
    rfx_group_ids=group_ids_test,
    rfx_basis=rfx_basis_test,
    type="posterior",
    terms="y_hat",
    scale="linear",
)
# Arm 1: all-one leaf basis, with the same random effects inputs as arm 0.
y_hat_posterior_test_1 = bart_model.predict(
    covariates=X_test,
    basis=np.ones((n_test, 1)),
    rfx_group_ids=group_ids_test,
    rfx_basis=rfx_basis_test,
    type="posterior",
    terms="y_hat",
    scale="linear",
)
# Manual contrast: arm 1 predictions minus arm 0 predictions.
contrast_posterior_test_comparison = y_hat_posterior_test_1 - y_hat_posterior_test_0
178+
# Compare results
contrast_diff = contrast_posterior_test_comparison - contrast_posterior_test
# A bare np.allclose(...) expression discards its result when run as a script,
# so the check could never fail. assert_allclose raises with a mismatch report.
# The desired value is 0, so the relative tolerance term contributes nothing
# and the check is |contrast_diff| <= atol.
np.testing.assert_allclose(
    contrast_diff,
    0,
    atol=0.001,
    err_msg=(
        "compute_contrast disagrees with the difference of two predict calls "
        "(random effects model)"
    ),
)
print("BART (random effects) contrast matches the difference of two predict calls")

0 commit comments

Comments
 (0)