Commit 80d8a23

Merge pull request #208 from StochasticTree/ensure_seed_determinism
Update bart and bcf to respect user-provided random seed
2 parents: 78f30a5 + 1f9ff12 · commit 80d8a23

File tree: 4 files changed, +232 / -0 lines

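In short, both bart() and bcf() now honor the random_seed entry of general_params: when it is non-negative, the sampler is seeded deterministically and the caller's global RNG state is restored when the function returns. A minimal usage sketch (argument names follow the debug scripts added in this commit; the data and settings below are made up for illustration, not taken from the commit):

library(stochtree)

# Simulated data (illustrative only)
set.seed(101)
n <- 200
p <- 10
X <- matrix(runif(n * p), ncol = p)
y <- X[, 1] + rnorm(n)

# Two fits with the same user-provided seed are expected to match exactly
general_params <- list(num_threads = 1, random_seed = 1234)
fit_1 <- bart(X_train = X, y_train = y, X_test = X,
              num_gfr = 10, num_mcmc = 10, general_params = general_params)
fit_2 <- bart(X_train = X, y_train = y, X_test = X,
              num_gfr = 10, num_mcmc = 10, general_params = general_params)
all.equal(fit_1$y_hat_test, fit_2$y_hat_test)  # expected: TRUE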

R/bart.R

Lines changed: 15 additions & 0 deletions
@@ -250,6 +250,16 @@ bart <- function(
     drop_vars_variance <- variance_forest_params_updated$drop_vars
     num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample

+    # Set a function-scoped RNG if user provided a random seed
+    custom_rng <- random_seed >= 0
+    if (custom_rng) {
+        # Store original global environment RNG state
+        original_global_seed <- .Random.seed
+        # Set new seed and store associated RNG state
+        set.seed(random_seed)
+        function_scoped_seed <- .Random.seed
+    }
+
     # Check if there are enough GFR samples to seed num_chains samplers
     if (num_gfr > 0) {
         if (num_chains > num_gfr) {
@@ -1758,6 +1768,11 @@ bart <- function(
     }
     rm(outcome_train)
     rm(rng)
+
+    # Restore global RNG state if user provided a random seed
+    if (custom_rng) {
+        .Random.seed <- original_global_seed
+    }

     return(result)
 }
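The save/seed/restore idiom in the hunks above can be isolated into a small helper. The sketch below is not the package's code; it assumes only base R and illustrates why the original .Random.seed is captured before set.seed() and written back afterwards (here via assign() into the global environment, wrapped in on.exit() so the restore also runs on error).

# Minimal sketch of a "function-scoped" RNG (illustrative, not stochtree's API)
with_local_seed <- function(seed, expr) {
    if (seed >= 0) {
        # .Random.seed only exists after the RNG has been used at least once
        if (!exists(".Random.seed", envir = globalenv())) set.seed(NULL)
        original_global_seed <- get(".Random.seed", envir = globalenv())
        # Restore the caller's RNG state on exit, even if expr throws an error
        on.exit(assign(".Random.seed", original_global_seed, envir = globalenv()),
                add = TRUE)
        # Seed the RNG for the duration of this call only
        set.seed(seed)
    }
    force(expr)
}

# Same seed, same draws; the caller's global RNG stream is untouched afterwards
set.seed(42)
a <- with_local_seed(1234, rnorm(3))
b <- with_local_seed(1234, rnorm(3))
identical(a, b)  # TRUE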

R/bcf.R

Lines changed: 15 additions & 0 deletions
@@ -340,6 +340,16 @@ bcf <- function(
     keep_vars_variance <- variance_forest_params_updated$keep_vars
     drop_vars_variance <- variance_forest_params_updated$drop_vars
     num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample
+
+    # Set a function-scoped RNG if user provided a random seed
+    custom_rng <- random_seed >= 0
+    if (custom_rng) {
+        # Store original global environment RNG state
+        original_global_seed <- .Random.seed
+        # Set new seed and store associated RNG state
+        set.seed(random_seed)
+        function_scoped_seed <- .Random.seed
+    }

     # Check if there are enough GFR samples to seed num_chains samplers
     if (num_gfr > 0) {
@@ -2544,6 +2554,11 @@ bcf <- function(
         result[["bart_propensity_model"]] = bart_model_propensity
     }
     class(result) <- "bcfmodel"
+
+    # Restore global RNG state if user provided a random seed
+    if (custom_rng) {
+        .Random.seed <- original_global_seed
+    }

     return(result)
 }
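Because bcf() restores the caller's global RNG state before returning, code surrounding a seeded fit should see the same random stream whether or not the fit ran in between. A rough check along these lines (data and settings are made up for illustration and are not part of the commit):

library(stochtree)

# Illustrative data
set.seed(7)
n <- 200
p <- 5
X <- matrix(runif(n * p), ncol = p)
Z <- rbinom(n, 1, 0.5)
pi_X <- rep(0.5, n)
y <- X[, 1] + 0.5 * Z + rnorm(n)

# Draws taken with no intervening model fit
set.seed(42)
draws_reference <- rnorm(3)

# Same global seed, but with a seeded bcf() fit in between
set.seed(42)
bcf_fit <- bcf(X_train = X, Z_train = Z, propensity_train = pi_X, y_train = y,
               X_test = X, Z_test = Z, propensity_test = pi_X,
               num_gfr = 10, num_mcmc = 10,
               general_params = list(num_threads = 1, random_seed = 1234))
draws_after_fit <- rnorm(3)

# Expected to be TRUE if the global RNG state is fully restored
identical(draws_reference, draws_after_fit)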

tools/debug/bart_random_seed.R

Lines changed: 60 additions & 0 deletions
@@ -53,3 +53,63 @@ y_hat_test_benchmark <- as.matrix(read.csv(

 # Compare results
 sum(abs(y_hat_test_benchmark - bart_model$y_hat_test) > 1e-6)
+
+# Generate probit data
+random_seed <- 1234
+set.seed(random_seed)
+n <- 500
+p <- 50
+X <- matrix(runif(n * p), ncol = p)
+# fmt: skip
+f_XW <- (
+    ((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) +
+    ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) +
+    ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) +
+    ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)
+)
+noise_sd <- 1
+W <- f_XW + rnorm(n, 0, noise_sd)
+y <- (W > 0) * 1
+
+# Split into train and test sets
+test_set_pct <- 0.2
+n_test <- round(test_set_pct * n)
+n_train <- n - n_test
+test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+train_inds <- (1:n)[!((1:n) %in% test_inds)]
+X_test <- X[test_inds, ]
+X_train <- X[train_inds, ]
+W_test <- W[test_inds]
+W_train <- W[train_inds]
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+
+# Set a different global seed as a test
+set.seed(9812384)
+
+# Run BART model
+general_params <- list(num_threads = 1, random_seed = random_seed,
+                       probit_outcome_model = T)
+bart_model <- bart(
+    X_train = X_train,
+    y_train = y_train,
+    X_test = X_test,
+    num_gfr = 100,
+    num_mcmc = 100,
+    general_params = general_params
+)
+
+# # Save results
+# write.csv(
+#     bart_model$y_hat_test,
+#     file = "tools/debug/seed_benchmark_probit_y_hat.csv",
+#     row.names = FALSE
+# )
+
+# Read results and compare to our estimates
+y_hat_test_benchmark <- as.matrix(read.csv(
+    "tools/debug/seed_benchmark_probit_y_hat.csv"
+))
+
+# Compare results
+sum(abs(y_hat_test_benchmark - bart_model$y_hat_test) > 1e-6)

tools/debug/bcf_random_seed.R

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
+# Load libraries
+library(stochtree)
+
+# Generate data
+random_seed <- 1234
+set.seed(random_seed)
+n <- 500
+p <- 50
+X <- matrix(runif(n * p), ncol = p)
+# fmt: skip
+f_X <- (
+    ((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) +
+    ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) +
+    ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) +
+    ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)
+)
+mu_X <- f_X
+pi_X <- pnorm(f_X * 0.25)
+tau_X <- 0.5 * X[,2]
+Z <- rbinom(n, 1, pi_X)
+E_XZ <- mu_X + Z * tau_X
+y <- E_XZ + rnorm(n, 0, 1)
+
+# Split into train and test sets
+test_set_pct <- 0.2
+n_test <- round(test_set_pct * n)
+n_train <- n - n_test
+test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+train_inds <- (1:n)[!((1:n) %in% test_inds)]
+X_test <- X[test_inds, ]
+X_train <- X[train_inds, ]
+Z_test <- Z[test_inds]
+Z_train <- Z[train_inds]
+pi_test <- pi_X[test_inds]
+pi_train <- pi_X[train_inds]
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+
+# Set a different global seed as a test
+set.seed(837475)
+
+# Run BCF model
+general_params <- list(num_threads = 1, random_seed = random_seed)
+bcf_model <- bcf(
+    X_train = X_train,
+    Z_train = Z_train,
+    propensity_train = pi_train,
+    y_train = y_train,
+    X_test = X_test,
+    Z_test = Z_test,
+    propensity_test = pi_test,
+    num_gfr = 100,
+    num_mcmc = 100,
+    general_params = general_params
+)
+
+# # Save results
+# write.csv(
+#     bcf_model$y_hat_test,
+#     file = "tools/debug/seed_benchmark_bcf_y_hat.csv",
+#     row.names = FALSE
+# )
+
+# Read results and compare to our estimates
+y_hat_test_benchmark <- as.matrix(read.csv(
+    "tools/debug/seed_benchmark_bcf_y_hat.csv"
+))
+
+# Compare results
+sum(abs(y_hat_test_benchmark - bcf_model$y_hat_test) > 1e-6)
+
+# Generate probit data
+random_seed <- 1234
+set.seed(random_seed)
+n <- 500
+p <- 50
+X <- matrix(runif(n * p), ncol = p)
+# fmt: skip
+f_X <- (
+    ((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) +
+    ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) +
+    ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) +
+    ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)
+)
+mu_X <- f_X
+pi_X <- pnorm(f_X * 0.25)
+tau_X <- 0.5 * X[,2]
+Z <- rbinom(n, 1, pi_X)
+E_XZ <- mu_X + Z * tau_X
+W <- E_XZ + rnorm(n, 0, 1)
+y <- (W > 0) * 1
+
+# Split into train and test sets
+test_set_pct <- 0.2
+n_test <- round(test_set_pct * n)
+n_train <- n - n_test
+test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+train_inds <- (1:n)[!((1:n) %in% test_inds)]
+X_test <- X[test_inds, ]
+X_train <- X[train_inds, ]
+Z_test <- Z[test_inds]
+Z_train <- Z[train_inds]
+pi_test <- pi_X[test_inds]
+pi_train <- pi_X[train_inds]
+W_test <- W[test_inds]
+W_train <- W[train_inds]
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+
+# Set a different global seed as a test
+set.seed(23446345)
+
+# Run BCF model
+general_params <- list(num_threads = 1, random_seed = random_seed,
+                       probit_outcome_model = T)
+bcf_model <- bcf(
+    X_train = X_train,
+    Z_train = Z_train,
+    propensity_train = pi_train,
+    y_train = y_train,
+    X_test = X_test,
+    Z_test = Z_test,
+    propensity_test = pi_test,
+    num_gfr = 100,
+    num_mcmc = 100,
+    general_params = general_params
+)
+
+# # Save results
+# write.csv(
+#     bcf_model$y_hat_test,
+#     file = "tools/debug/seed_benchmark_bcf_probit_y_hat.csv",
+#     row.names = FALSE
+# )
+
+# Read results and compare to our estimates
+y_hat_test_benchmark <- as.matrix(read.csv(
+    "tools/debug/seed_benchmark_bcf_probit_y_hat.csv"
+))
+
+# Compare results
+sum(abs(y_hat_test_benchmark - bcf_model$y_hat_test) > 1e-6)
