Skip to content

Commit f621350

Browse files
authored
Merge pull request #198 from StochasticTree/rng_seed_debug
Add AIR format file and initial debug script for stochtree's propagation of random seed
2 parents 2d59a31 + ce57b76 commit f621350

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

air.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[format]
2+
line-width = 80  # target maximum line width, in characters
3+
indent-width = 2  # width of one indentation level
4+
indent-style = "space"  # indent with spaces rather than tabs
5+
line-ending = "auto"  # presumably detected per file - confirm in Air docs
6+
persistent-line-breaks = true  # NOTE(review): likely keeps hand-placed breaks
7+
exclude = []  # empty list: no extra exclusions configured here
8+
default-exclude = true  # presumably keeps Air's built-in exclusions - confirm
9+
skip = []  # empty list: nothing configured to be skipped

tools/debug/bart_random_seed.R

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Load libraries
2+
library(stochtree)
3+
4+
# Generate synthetic data (seeded so every run draws the same sample)
5+
random_seed <- 1234  # also passed to bart() through general_params below
6+
set.seed(random_seed)
7+
n <- 500  # number of observations (rows of X)
8+
p <- 50  # number of covariates (columns of X)
9+
X <- matrix(runif(n * p), ncol = p)  # n x p matrix of Uniform(0, 1) draws
10+
# fmt: skip
11+
f_XW <- (  # step function of X[, 1]: one level per quarter of [0, 1)
12+
((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) +
13+
((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) +
14+
((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) +
15+
((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)
16+
)
17+
noise_sd <- 1  # standard deviation of the additive Gaussian noise
18+
y <- f_XW + rnorm(n, 0, noise_sd)  # outcome = step-function mean + noise
19+
20+
# Split into train and test sets (test rows are sampled without replacement)
21+
test_set_pct <- 0.2  # fraction of rows held out for testing
22+
n_test <- round(test_set_pct * n)
23+
n_train <- n - n_test
24+
test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
25+
train_inds <- setdiff(seq_len(n), test_inds)  # every row not held out
26+
X_test <- X[test_inds, , drop = FALSE]  # drop = FALSE: always stays a matrix
27+
X_train <- X[train_inds, , drop = FALSE]
28+
y_test <- y[test_inds]
29+
y_train <- y[train_inds]
30+
31+
# Run BART model on one thread, passing the data-generation seed along
32+
general_params <- list(num_threads = 1, random_seed = random_seed)
33+
bart_model <- bart(
34+
X_train = X_train,
35+
y_train = y_train,
36+
X_test = X_test,  # NOTE(review): presumably the source of y_hat_test - confirm
37+
num_gfr = 100,  # presumably grow-from-root warm-start iterations - confirm
38+
num_mcmc = 100,  # presumably retained MCMC iterations - confirm
39+
general_params = general_params
40+
)
41+
42+
# # Save results (uncomment to regenerate the benchmark CSV read below)
43+
# write.csv(
44+
# bart_model$y_hat_test,
45+
# file = "tools/debug/seed_benchmark_y_hat.csv",
46+
# row.names = FALSE
47+
# )
48+
49+
# Read saved benchmark predictions (path relative to working directory)
50+
y_hat_test_benchmark <- as.matrix(read.csv(
51+
"tools/debug/seed_benchmark_y_hat.csv"
52+
))
53+
54+
# Count entries differing from the benchmark by more than 1e-6 (0 = match)
55+
sum(abs(y_hat_test_benchmark - bart_model$y_hat_test) > 1e-6)

0 commit comments

Comments
 (0)