Skip to content

Commit 2b74ba8

Browse files
committed
Propagating multi-threading updates to python, temporarily removing enumeration unit tests after C++ refactor
1 parent ee156fb commit 2b74ba8

File tree

16 files changed

+590
-356
lines changed

16 files changed

+590
-356
lines changed

CMakeLists.txt

Lines changed: 66 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Build options
2-
option(USE_DEBUG "Set to ON for Debug mode" OFF)
2+
option(USE_DEBUG "Build with debug symbols and without optimization" OFF)
33
option(USE_SANITIZER "Use santizer flags" OFF)
4+
option(USE_OPENMP "Use openMP" ON)
5+
option(USE_HOMEBREW_FALLBACK "(macOS-only) also look in 'brew --prefix' for libraries (e.g. OpenMP)" ON)
46
option(BUILD_TEST "Build C++ tests with Google Test" OFF)
57
option(BUILD_DEBUG_TARGETS "Build Standalone C++ Programs for Debugging" ON)
68
option(BUILD_PYTHON "Build Shared Library for Python Package" OFF)
@@ -9,8 +11,8 @@ option(BUILD_PYTHON "Build Shared Library for Python Package" OFF)
911
set(CMAKE_CXX_STANDARD 17)
1012
set(CMAKE_CXX_STANDARD_REQUIRED ON)
1113

12-
# Default to CMake 3.16
13-
cmake_minimum_required(VERSION 3.16)
14+
# Require CMake 3.20 or newer
15+
cmake_minimum_required(VERSION 3.20)
1416

1517
# Define the project
1618
project(stochtree LANGUAGES C CXX)
@@ -34,6 +36,13 @@ if(USE_DEBUG)
3436
add_definitions(-DDEBUG)
3537
endif()
3638

39+
# Linker flags (empty by default, updated if using openmp)
40+
set(
41+
STOCHTREE_LINK_FLAGS
42+
""
43+
)
44+
45+
# Unix / MinGW compiler flags
3746
if(UNIX OR MINGW OR CYGWIN)
3847
set(
3948
CMAKE_CXX_FLAGS
@@ -42,11 +51,12 @@ if(UNIX OR MINGW OR CYGWIN)
4251
if(USE_DEBUG)
4352
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0")
4453
else()
45-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
54+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O3")
4655
endif()
4756
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unknown-pragmas -Wno-unused-private-field")
4857
endif()
4958

59+
# MSVC compiler flags
5060
if(MSVC)
5161
set(
5262
variables
@@ -72,6 +82,30 @@ else()
7282
endif()
7383
endif()
7484

85+
# OpenMP
86+
if(USE_OPENMP)
87+
add_definitions(-DSTOCHTREE_OPENMP_AVAILABLE)
88+
if(APPLE)
89+
find_package(OpenMP)
90+
if(NOT OpenMP_FOUND)
91+
if(USE_HOMEBREW_FALLBACK)
92+
execute_process(COMMAND brew --prefix libomp
93+
OUTPUT_VARIABLE HOMEBREW_LIBOMP_PREFIX
94+
OUTPUT_STRIP_TRAILING_WHITESPACE)
95+
set(OpenMP_CXX_FLAGS "-Xclang -fopenmp")
96+
set(OpenMP_CXX_INCLUDE_DIR "-I${HOMEBREW_LIBOMP_PREFIX}/include")
97+
set(OpenMP_CXX_LIB_NAMES omp)
98+
set(OpenMP_libomp_LIBRARY ${HOMEBREW_LIBOMP_PREFIX}/lib/libomp.dylib)
99+
endif()
100+
find_package(OpenMP REQUIRED)
101+
endif()
102+
else()
103+
find_package(OpenMP REQUIRED)
104+
endif()
105+
# Update flags with openmp
106+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
107+
endif()
108+
75109
# Header file directory
76110
set(StochTree_HEADER_DIR ${PROJECT_SOURCE_DIR}/include)
77111

@@ -80,6 +114,8 @@ set(BOOSTMATH_HEADER_DIR ${PROJECT_SOURCE_DIR}/deps/boost_math/include)
80114

81115
# Eigen header file directory
82116
set(EIGEN_HEADER_DIR ${PROJECT_SOURCE_DIR}/deps/eigen)
117+
add_definitions(-DEIGEN_MPL2_ONLY)
118+
add_definitions(-DEIGEN_DONT_PARALLELIZE)
83119

84120
# fast_double_parser header file directory
85121
set(FAST_DOUBLE_PARSER_HEADER_DIR ${PROJECT_SOURCE_DIR}/deps/fast_double_parser/include)
@@ -109,10 +145,11 @@ file(
109145
add_library(stochtree_objs OBJECT ${SOURCES})
110146

111147
# Include the headers in the source library
112-
target_include_directories(stochtree_objs PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
113-
114-
if(APPLE)
115-
set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
148+
if(USE_OPENMP)
149+
target_include_directories(stochtree_objs PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR})
150+
target_link_libraries(stochtree_objs PRIVATE ${OpenMP_libomp_LIBRARY})
151+
else()
152+
target_include_directories(stochtree_objs PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
116153
endif()
117154

118155
# Python shared library
@@ -122,8 +159,13 @@ if (BUILD_PYTHON)
122159
pybind11_add_module(stochtree_cpp src/py_stochtree.cpp)
123160

124161
# Link to C++ source and headers
125-
target_include_directories(stochtree_cpp PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
126-
target_link_libraries(stochtree_cpp PRIVATE stochtree_objs)
162+
if(USE_OPENMP)
163+
target_include_directories(stochtree_cpp PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR})
164+
target_link_libraries(stochtree_cpp PRIVATE stochtree_objs ${OpenMP_libomp_LIBRARY})
165+
else()
166+
target_include_directories(stochtree_cpp PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
167+
target_link_libraries(stochtree_cpp PRIVATE stochtree_objs)
168+
endif()
127169

128170
# EXAMPLE_VERSION_INFO is defined by setup.py and passed into the C++ code as a
129171
# define (VERSION_INFO) here.
@@ -154,8 +196,13 @@ if(BUILD_TEST)
154196
file(GLOB CPP_TEST_SOURCES test/cpp/*.cpp)
155197
add_executable(teststochtree ${CPP_TEST_SOURCES})
156198
set(STOCHTREE_TEST_HEADER_DIR ${PROJECT_SOURCE_DIR}/test/cpp)
157-
target_include_directories(teststochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${STOCHTREE_TEST_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
158-
target_link_libraries(teststochtree PRIVATE stochtree_objs GTest::gtest_main)
199+
if(USE_OPENMP)
200+
target_include_directories(teststochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${STOCHTREE_TEST_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR})
201+
target_link_libraries(teststochtree PRIVATE stochtree_objs GTest::gtest_main ${OpenMP_libomp_LIBRARY})
202+
else()
203+
target_include_directories(teststochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${STOCHTREE_TEST_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
204+
target_link_libraries(teststochtree PRIVATE stochtree_objs GTest::gtest_main)
205+
endif()
159206
gtest_discover_tests(teststochtree)
160207
endif()
161208

@@ -164,7 +211,12 @@ if(BUILD_DEBUG_TARGETS)
164211
# Build the standalone debug program
165212
add_executable(debugstochtree debug/api_debug.cpp)
166213
set(StochTree_DEBUG_HEADER_DIR ${PROJECT_SOURCE_DIR}/debug)
167-
target_include_directories(debugstochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
168-
target_link_libraries(debugstochtree PRIVATE stochtree_objs)
214+
if(USE_OPENMP)
215+
target_include_directories(debugstochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR})
216+
target_link_libraries(debugstochtree PRIVATE stochtree_objs ${OpenMP_libomp_LIBRARY})
217+
else()
218+
target_include_directories(debugstochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
219+
target_link_libraries(debugstochtree PRIVATE stochtree_objs)
220+
endif()
169221
endif()
170222

R/bcf.R

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
#' - `rfx_group_parameter_prior_cov` Prior covariance matrix for the random effects "group parameters." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
5454
#' - `rfx_variance_prior_shape` Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
5555
#' - `rfx_variance_prior_scale` Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
56+
#' - `num_threads` Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's system, this will default to `1`, otherwise to the maximum number of available threads.
5657
#'
5758
#' @param prognostic_forest_params (Optional) A list of prognostic forest model parameters, each of which has a default value processed internally, so this argument list is optional.
5859
#'
@@ -174,7 +175,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
174175
rfx_working_parameter_prior_cov = NULL,
175176
rfx_group_parameter_prior_cov = NULL,
176177
rfx_variance_prior_shape = 1,
177-
rfx_variance_prior_scale = 1
178+
rfx_variance_prior_scale = 1,
179+
num_threads = -1
178180
)
179181
general_params_updated <- preprocessParams(
180182
general_params_default, general_params
@@ -248,6 +250,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
248250
rfx_group_parameter_prior_cov <- general_params_updated$rfx_group_parameter_prior_cov
249251
rfx_variance_prior_shape <- general_params_updated$rfx_variance_prior_shape
250252
rfx_variance_prior_scale <- general_params_updated$rfx_variance_prior_scale
253+
num_threads <- general_params_updated$num_threads
251254

252255
# 2. Mu forest parameters
253256
num_trees_mu <- prognostic_forest_params_updated$num_trees
@@ -1029,7 +1032,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
10291032
forest_model_mu$sample_one_iteration(
10301033
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mu,
10311034
active_forest = active_forest_mu, rng = rng, forest_model_config = forest_model_config_mu,
1032-
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
1035+
global_model_config = global_model_config, num_threads = num_threads, keep_forest = keep_sample, gfr = TRUE
10331036
)
10341037

10351038
# Cache train set predictions since they are already computed during sampling
@@ -1053,7 +1056,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
10531056
forest_model_tau$sample_one_iteration(
10541057
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_tau,
10551058
active_forest = active_forest_tau, rng = rng, forest_model_config = forest_model_config_tau,
1056-
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
1059+
global_model_config = global_model_config, num_threads = num_threads, keep_forest = keep_sample, gfr = TRUE
10571060
)
10581061

10591062
# Cannot cache train set predictions for tau because the cached predictions in the
@@ -1102,7 +1105,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
11021105
forest_model_variance$sample_one_iteration(
11031106
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance,
11041107
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
1105-
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
1108+
global_model_config = global_model_config, num_threads = num_threads, keep_forest = keep_sample, gfr = TRUE
11061109
)
11071110

11081111
# Cache train set predictions since they are already computed during sampling
@@ -1309,7 +1312,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
13091312
forest_model_mu$sample_one_iteration(
13101313
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mu,
13111314
active_forest = active_forest_mu, rng = rng, forest_model_config = forest_model_config_mu,
1312-
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
1315+
global_model_config = global_model_config, num_threads = num_threads, keep_forest = keep_sample, gfr = FALSE
13131316
)
13141317

13151318
# Cache train set predictions since they are already computed during sampling
@@ -1333,7 +1336,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
13331336
forest_model_tau$sample_one_iteration(
13341337
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_tau,
13351338
active_forest = active_forest_tau, rng = rng, forest_model_config = forest_model_config_tau,
1336-
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
1339+
global_model_config = global_model_config, num_threads = num_threads, keep_forest = keep_sample, gfr = FALSE
13371340
)
13381341

13391342
# Cannot cache train set predictions for tau because the cached predictions in the
@@ -1382,7 +1385,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
13821385
forest_model_variance$sample_one_iteration(
13831386
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance,
13841387
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
1385-
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
1388+
global_model_config = global_model_config, num_threads = num_threads, keep_forest = keep_sample, gfr = FALSE
13861389
)
13871390

13881391
# Cache train set predictions since they are already computed during sampling

R/model.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ ForestModel <- R6::R6Class(
6767
#' @param rng Wrapper around C++ random number generator
6868
#' @param forest_model_config ForestModelConfig object containing forest model parameters and settings
6969
#' @param global_model_config GlobalModelConfig object containing global model parameters and settings
70-
#' @param num_threads Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's setup, this will default to `1`, otherwise to the maximum number of available threads.
70+
#' @param num_threads Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's system, this will default to `1`, otherwise to the maximum number of available threads.
7171
#' @param keep_forest (Optional) Whether the updated forest sample should be saved to `forest_samples`. Default: `TRUE`.
7272
#' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: `TRUE`.
7373
sample_one_iteration = function(forest_dataset, residual, forest_samples, active_forest,

debug/README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ This subdirectory contains a debug program for the C++ codebase.
44
The program takes several command line arguments (in order):
55

66
1. Which data-generating process (DGP) to run (integer-coded, see below for a detailed description)
7-
1. Which leaf model to sample (integer-coded, see below for a detailed description)
7+
2. Which leaf model to sample (integer-coded, see below for a detailed description)
88
3. Whether or not to include random effects (0 = no, 1 = yes)
99
4. Number of grow-from-root (GFR) samples
1010
5. Number of MCMC samples
@@ -13,6 +13,7 @@ The program takes several command line arguments (in order):
1313
8. [Optional] index of outcome column in data file (leave this blank as `0`)
1414
9. [Optional] comma-delimited string of column indices of covariates (leave this blank as `""`)
1515
10. [Optional] comma-delimited string of column indices of leaf regression bases (leave this blank as `""`)
16+
11. [Optional] number of threads to use in the GFR sampler (leave this blank as `-1`)
1617

1718
The DGPs are numbered as follows:
1819

@@ -30,6 +31,6 @@ The models are numbered as follows:
3031

3132
For an example of how to run this program for DGP 0, leaf model 1, no random effects, 10 GFR samples, 100 MCMC samples and a default seed (`-1`), run
3233

33-
`./build/debugstochtree 0 1 0 10 100 -1 "" 0 "" ""`
34+
`./build/debugstochtree 0 1 0 10 100 -1 "" 0 "" "" -1`
3435

3536
from the main `stochtree` project directory after building with `BUILD_DEBUG_TARGETS` set to `ON`.

0 commit comments

Comments
 (0)