# Load libraries
library(stochtree)

# Generate data
random_seed <- 1234
set.seed(random_seed)
n <- 500
p <- 50
X <- matrix(runif(n * p), ncol = p)
# fmt: skip
f_XW <- (
  ((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) +
  ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) +
  ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) +
  ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)
)
noise_sd <- 1
y <- f_XW + rnorm(n, 0, noise_sd)
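
# (Optional sketch, not part of the original script: sanity check that the
# piecewise step function only takes its four intended values)
stopifnot(all(f_XW %in% c(-7.5, -2.5, 2.5, 7.5)))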

# Split into train and test sets
test_set_pct <- 0.2
n_test <- round(test_set_pct * n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- X[test_inds, ]
X_train <- X[train_inds, ]
y_test <- y[test_inds]
y_train <- y[train_inds]
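
# (Optional sketch, not part of the original script: verify the split is a
# disjoint partition of 1:n with the expected sizes)
stopifnot(length(intersect(train_inds, test_inds)) == 0)
stopifnot(length(train_inds) == n_train, length(test_inds) == n_test)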

# Run BART model (single thread and fixed seed so the run is reproducible)
general_params <- list(num_threads = 1, random_seed = random_seed)
bart_model <- bart(
  X_train = X_train,
  y_train = y_train,
  X_test = X_test,
  num_gfr = 100,
  num_mcmc = 100,
  general_params = general_params
)
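
# (Optional sketch, assuming bart() returns test-set predictions as a matrix
# with one row per test observation and one column per retained draw: inspect
# the dimensions before comparing against the saved benchmark)
dim(bart_model$y_hat_test)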

# Save results (commented out; uncomment to regenerate the benchmark file)
# write.csv(
#   bart_model$y_hat_test,
#   file = "tools/debug/seed_benchmark_y_hat.csv",
#   row.names = FALSE
# )
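
# (Optional sketch, not part of the original script: stop early with a clear
# message if the benchmark file has not yet been written by the block above)
if (!file.exists("tools/debug/seed_benchmark_y_hat.csv")) {
  stop("Benchmark CSV not found; run the commented-out save step first.")
}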

# Read results and compare to our estimates
y_hat_test_benchmark <- as.matrix(read.csv(
  "tools/debug/seed_benchmark_y_hat.csv"
))
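
# (Optional sketch, not part of the original script: as.matrix() on a data
# frame yields a character matrix if any column is malformed, so confirm the
# benchmark is numeric and conformable before differencing)
stopifnot(is.numeric(y_hat_test_benchmark))
stopifnot(all(dim(y_hat_test_benchmark) == dim(bart_model$y_hat_test)))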

# Compare results: count test predictions that differ from the saved
# benchmark by more than 1e-6 (a reproducible seeded run should give 0)
sum(abs(y_hat_test_benchmark - bart_model$y_hat_test) > 1e-6)
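
# (Optional sketch, not part of the original script: also report the largest
# absolute discrepancy; for a seed-stable, single-threaded run this should be
# at floating-point noise level)
max(abs(y_hat_test_benchmark - bart_model$y_hat_test))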