|
| 1 | +# Adding New Models to stochtree |
| 2 | + |
| 3 | +While the process of working with `stochtree`'s codebase to add |
| 4 | +functionality or fix bugs is covered in the [contributing](contributing.md) |
| 5 | +page, this page discusses a specific type of contribution in detail: |
| 6 | +contributing new models (i.e. likelihoods and leaf parameter priors). |
| 7 | + |
| 8 | +Our C++ core is designed to support any conditionally conjugate model, but extending it requires some familiarity with how models are structured and dispatched. |
| 9 | + |
| 10 | +## Overview |
| 11 | + |
| 12 | +The key components of `stochtree`'s models are: |
| 13 | + |
| 14 | +1. A **SuffStat** class that stores and accumulates sufficient statistics |
| 15 | +2. A **LeafModel** class that computes marginal likelihoods / posterior parameters and samples leaf node parameters |
| 16 | + |
| 17 | +Each model implements a different version of these two classes. For example, the "classic" |
| 18 | +BART model with constant Gaussian leaves and a Gaussian likelihood is represented by the |
| 19 | +`GaussianConstantSuffStat` and `GaussianConstantLeafModel` classes. |
| 20 | + |
| 21 | +Each class implements a common API, and we use a [factory pattern](https://en.wikipedia.org/wiki/Factory_(object-oriented_programming)) and the C++17 |
| 22 | +[std::variant](https://www.cppreference.com/w/cpp/utility/variant.html) |
| 23 | +feature to dispatch the correct model at runtime. |
| 24 | +Finally, R and Python wrappers expose this flexibility through the BART / BCF interfaces. |
| 25 | + |
| 26 | +Adding a new leaf model thus requires implementing new `SuffStat` and `LeafModel` |
| 27 | +classes, then updating the factory functions and R / Python logic. |
| 28 | + |
| 29 | +## SuffStat Class |
| 30 | + |
| 31 | +As a pattern, sufficient statistic classes end in `*SuffStat` and implement several methods: |
| 32 | + |
| 33 | +* `IncrementSuffStat`: Increment a model's sufficient statistics by one data observation |
| 34 | +* `ResetSuffStat`: Reset a model's sufficient statistics to zero / empty |
| 35 | +* `AddSuffStat`: Combine two sufficient statistics, storing their sum in the sufficient statistic object that calls this method (without modifying the supplied `SuffStat` objects) |
| 36 | +* `SubtractSuffStat`: Same as above but subtracting the second `SuffStat` argument from the first, rather than adding |
| 37 | +* `SampleGreaterThan`: Checks whether the current sample size of a `SuffStat` object is greater than some threshold |
| 38 | +* `SampleGreaterThanEqual`: Checks whether the current sample size of a `SuffStat` object is greater than or equal to some threshold |
| 39 | +* `SampleSize`: Returns the current sample size of a `SuffStat` object |
| 40 | + |
| 41 | +For the sake of illustration, imagine we are adding a model called `OurNewModel`. The new sufficient statistic class should look something like: |
| 42 | + |
| 43 | +```cpp |
| 44 | +class OurNewModelSuffStat { |
| 45 | + public: |
| 46 | + data_size_t n; |
| 47 | + // Custom sufficient statistics for `OurNewModel` |
| 48 | + double stat1; |
| 49 | + double stat2; |
| 50 | + |
| 51 | + OurNewModelSuffStat() { |
| 52 | + n = 0; |
| 53 | + stat1 = 0.0; |
| 54 | + stat2 = 0.0; |
| 55 | + } |
| 56 | + |
| 57 | + void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, |
| 58 | + ForestTracker& tracker, data_size_t row_idx, int tree_idx) { |
| 59 | + n += 1; |
| 60 | + stat1 += /* accumulate from outcome, dataset, or tracker as needed */; |
| 61 | + stat2 += /* accumulate from outcome, dataset, or tracker as needed */; |
| 62 | + } |
| 63 | + |
| 64 | + void ResetSuffStat() { |
| 65 | + n = 0; |
| 66 | + stat1 = 0.0; |
| 67 | + stat2 = 0.0; |
| 68 | + } |
| 69 | + |
| 70 | + void AddSuffStat(OurNewModelSuffStat& lhs, OurNewModelSuffStat& rhs) { |
| 71 | + n = lhs.n + rhs.n; |
| 72 | + stat1 = lhs.stat1 + rhs.stat1; |
| 73 | + stat2 = lhs.stat2 + rhs.stat2; |
| 74 | + } |
| 75 | + |
| 76 | + void SubtractSuffStat(OurNewModelSuffStat& lhs, OurNewModelSuffStat& rhs) { |
| 77 | + n = lhs.n - rhs.n; |
| 78 | + stat1 = lhs.stat1 - rhs.stat1; |
| 79 | + stat2 = lhs.stat2 - rhs.stat2; |
| 80 | + } |
| 81 | + |
| 82 | + bool SampleGreaterThan(data_size_t threshold) { return n > threshold; } |
| 83 | + bool SampleGreaterThanEqual(data_size_t threshold) { return n >= threshold; } |
| 84 | + data_size_t SampleSize() { return n; } |
| 85 | +}; |
| 86 | +``` |
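To make the interface above concrete, here is a minimal sketch of a sufficient statistic for a Gaussian model with a constant leaf mean. It is a toy under stated assumptions: the name `ToyGaussianSuffStat`, the helper `accumulateRange`, and the simplified `IncrementSuffStat(double)` signature are all hypothetical and do not reproduce stochtree's own classes, which read from `ForestDataset` and `ForestTracker`.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical toy sufficient statistic (not a stochtree class).
// Tracks the sample size, the sum of outcomes and the sum of squared outcomes.
struct ToyGaussianSuffStat {
  std::size_t n = 0;
  double sum_y = 0.0;
  double sum_y_squared = 0.0;

  // Simplified signature: takes the outcome directly instead of a dataset row.
  void IncrementSuffStat(double y) {
    n += 1;
    sum_y += y;
    sum_y_squared += y * y;
  }

  void ResetSuffStat() {
    n = 0;
    sum_y = 0.0;
    sum_y_squared = 0.0;
  }

  // Store lhs + rhs in this object; the arguments are left untouched.
  void AddSuffStat(const ToyGaussianSuffStat& lhs, const ToyGaussianSuffStat& rhs) {
    n = lhs.n + rhs.n;
    sum_y = lhs.sum_y + rhs.sum_y;
    sum_y_squared = lhs.sum_y_squared + rhs.sum_y_squared;
  }

  // Store lhs - rhs in this object; the arguments are left untouched.
  void SubtractSuffStat(const ToyGaussianSuffStat& lhs, const ToyGaussianSuffStat& rhs) {
    n = lhs.n - rhs.n;
    sum_y = lhs.sum_y - rhs.sum_y;
    sum_y_squared = lhs.sum_y_squared - rhs.sum_y_squared;
  }

  bool SampleGreaterThan(std::size_t threshold) const { return n > threshold; }
  bool SampleGreaterThanEqual(std::size_t threshold) const { return n >= threshold; }
  std::size_t SampleSize() const { return n; }
};

// Hypothetical helper: accumulate statistics over outcomes y[begin], ..., y[end - 1].
inline ToyGaussianSuffStat accumulateRange(const std::vector<double>& y,
                                           std::size_t begin, std::size_t end) {
  ToyGaussianSuffStat stat;
  for (std::size_t i = begin; i < end; ++i) stat.IncrementSuffStat(y[i]);
  return stat;
}
```

One reason to provide `SubtractSuffStat` is that, when a split is evaluated, the statistics of one child can be obtained as the parent's statistics minus the other child's, without a second pass over the data.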
| 87 | + |
| 88 | +## LeafModel Class |
| 89 | + |
| 90 | +Leaf model classes end in `*LeafModel` and implement several methods: |
| 91 | + |
| 92 | +* `SplitLogMarginalLikelihood`: The log marginal likelihood of a potential split, as a function of the sufficient statistics for the newly proposed left and right nodes (i.e. ignoring data points unaffected by the split) |
| 93 | +* `NoSplitLogMarginalLikelihood`: The log marginal likelihood of a node without splitting, as a function of the sufficient statistics for that node |
| 94 | +* `SampleLeafParameters`: Sample the leaf node parameters for every leaf in a provided tree, according to this model's conditionally conjugate leaf node posterior |
| 95 | +* `RequiresBasis`: Whether or not a model requires regressing on "basis functions" in the leaves |
| 96 | + |
| 97 | +As above, imagine that we are implementing a new model called `OurNewModel`. The new leaf model class should look something like: |
| 98 | + |
| 99 | +```cpp |
| 100 | +class OurNewModelLeafModel { |
| 101 | + public: |
| 102 | + OurNewModelLeafModel(/* model parameters */) { |
| 103 | + // Set model parameters |
| 104 | + } |
| 105 | + |
| 106 | + double SplitLogMarginalLikelihood(OurNewModelSuffStat& left_stat, |
| 107 | + OurNewModelSuffStat& right_stat, |
| 108 | + double global_variance) { |
| 109 | + double left_log_ml = /* calculate left node log ML */; |
| 110 | + double right_log_ml = /* calculate right node log ML */; |
| 111 | + return left_log_ml + right_log_ml; |
| 112 | + } |
| 113 | + |
| 114 | + double NoSplitLogMarginalLikelihood(OurNewModelSuffStat& suff_stat, |
| 115 | + double global_variance) { |
| 116 | + double log_ml = /* calculate node log ML */; |
| 117 | + return log_ml; |
| 118 | + } |
| 119 | + |
| 120 | + void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, |
| 121 | + ColumnVector& residual, Tree* tree, int tree_num, |
| 122 | + double global_variance, std::mt19937& gen) { |
| 123 | + // Sample parameters for every leaf in a tree, update `tree` directly |
| 124 | + } |
| 125 | + |
| 126 | + inline bool RequiresBasis() { return /* true/false based on your model */; } |
| 127 | +
|
| 128 | + // Helper methods below for `SampleLeafParameters`, which depend on the |
| 129 | + // nature of the leaf model (i.e. location-scale, shape-scale, etc...) |
| 130 | + |
| 131 | + double PosteriorParameterMean(OurNewModelSuffStat& suff_stat, |
| 132 | + double global_variance) { |
| 133 | + return /* calculate posterior mean */; |
| 134 | + } |
| 135 | + |
| 136 | + double PosteriorParameterVariance(OurNewModelSuffStat& suff_stat, |
| 137 | + double global_variance) { |
| 138 | + return /* calculate posterior variance */; |
| 139 | + } |
| 140 | + |
| 141 | + private: |
| 142 | + // Leaf model parameters |
| 143 | + double param1_; |
| 144 | + double param2_; |
| 145 | +}; |
| 146 | +``` |
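As a worked instance of the placeholders above, the sketch below fills in the marginal likelihood and posterior helpers for one specific conjugate model: outcomes `y_i ~ N(mu, sigma^2)` in a leaf, with prior `mu ~ N(0, tau)`. The class and struct names are hypothetical, the model is an assumption chosen for illustration, and the formulas keep every normalizing constant, so they may differ from stochtree's own implementation in constants or parameterization.

```cpp
#include <cmath>
#include <cstddef>

// Hypothetical sufficient statistics for the toy model (not a stochtree class).
struct ToySuffStat {
  std::size_t n;
  double sum_y;
  double sum_y_squared;
};

// Toy leaf model: y_i ~ N(mu, global_variance), mu ~ N(0, tau).
class ToyGaussianLeafModel {
 public:
  explicit ToyGaussianLeafModel(double tau) : tau_(tau) {}

  // Log marginal likelihood of one node, with mu integrated out.
  double NoSplitLogMarginalLikelihood(const ToySuffStat& s, double global_variance) const {
    const double kPi = 3.14159265358979323846;
    const double n = static_cast<double>(s.n);
    const double denom = global_variance + n * tau_;
    return -0.5 * n * std::log(2.0 * kPi * global_variance)
           - 0.5 * std::log(denom / global_variance)
           - s.sum_y_squared / (2.0 * global_variance)
           + tau_ * s.sum_y * s.sum_y / (2.0 * global_variance * denom);
  }

  // Leaves are conditionally independent, so the split log ML is the sum over children.
  double SplitLogMarginalLikelihood(const ToySuffStat& left, const ToySuffStat& right,
                                    double global_variance) const {
    return NoSplitLogMarginalLikelihood(left, global_variance) +
           NoSplitLogMarginalLikelihood(right, global_variance);
  }

  // Posterior mean of mu given the node's data.
  double PosteriorParameterMean(const ToySuffStat& s, double global_variance) const {
    return tau_ * s.sum_y / (global_variance + static_cast<double>(s.n) * tau_);
  }

  // Posterior variance of mu given the node's data.
  double PosteriorParameterVariance(const ToySuffStat& s, double global_variance) const {
    return tau_ * global_variance / (global_variance + static_cast<double>(s.n) * tau_);
  }

 private:
  double tau_;  // Prior variance of the leaf mean
};
```

A quick sanity check for formulas like these is the single-observation case: with `n = 1` the marginal distribution of `y` is `N(0, global_variance + tau)`, so the log marginal likelihood should match that density exactly.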
| 147 | + |
| 148 | +## Factory Functions |
| 149 | + |
| 150 | +Updating the factory pattern to be able to dispatch `OurNewModel` has several steps. |
| 151 | + |
| 152 | +First, we add our model to the `ModelType` enum in `include/stochtree/leaf_model.h`: |
| 153 | + |
| 154 | +```cpp |
| 155 | +enum ModelType { |
| 156 | + kConstantLeafGaussian, |
| 157 | + kUnivariateRegressionLeafGaussian, |
| 158 | + kMultivariateRegressionLeafGaussian, |
| 159 | + kLogLinearVariance, |
| 160 | + kOurNewModel // New model |
| 161 | +}; |
| 162 | +``` |
| 163 | + |
| 164 | +Next, we add the `OurNewModelSuffStat` and `OurNewModelLeafModel` classes to the `std::variant` unions in `include/stochtree/leaf_model.h`: |
| 165 | + |
| 166 | +```cpp |
| 167 | +using SuffStatVariant = std::variant<GaussianConstantSuffStat, |
| 168 | + GaussianUnivariateRegressionSuffStat, |
| 169 | + GaussianMultivariateRegressionSuffStat, |
| 170 | + LogLinearVarianceSuffStat, |
| 171 | + OurNewModelSuffStat>; // New model |
| 172 | + |
| 173 | +using LeafModelVariant = std::variant<GaussianConstantLeafModel, |
| 174 | + GaussianUnivariateRegressionLeafModel, |
| 175 | + GaussianMultivariateRegressionLeafModel, |
| 176 | + LogLinearVarianceLeafModel, |
| 177 | + OurNewModelLeafModel>; // New model |
| 178 | +``` |
| 179 | + |
| 180 | +Finally, we update the factory functions to dispatch the correct class from the union based on the `ModelType` value: |
| 181 | + |
| 182 | +```cpp |
| 183 | +static inline SuffStatVariant suffStatFactory(ModelType model_type, int basis_dim = 0) { |
| 184 | + if (model_type == kConstantLeafGaussian) { |
| 185 | + return createSuffStat<GaussianConstantSuffStat>(); |
| 186 | + } else if (model_type == kUnivariateRegressionLeafGaussian) { |
| 187 | + return createSuffStat<GaussianUnivariateRegressionSuffStat>(); |
| 188 | + } else if (model_type == kMultivariateRegressionLeafGaussian) { |
| 189 | + return createSuffStat<GaussianMultivariateRegressionSuffStat, int>(basis_dim); |
| 190 | + } else if (model_type == kLogLinearVariance) { |
| 191 | + return createSuffStat<LogLinearVarianceSuffStat>(); |
| 192 | + } else if (model_type == kOurNewModel) { // New model |
| 193 | + return createSuffStat<OurNewModelSuffStat>(); |
| 194 | + } else { |
| 195 | + Log::Fatal("Incompatible model type provided to suff stat factory"); |
| 196 | + } |
| 197 | +} |
| 198 | + |
| 199 | +static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, |
| 200 | + Eigen::MatrixXd& Sigma0, double a, double b) { |
| 201 | + if (model_type == kConstantLeafGaussian) { |
| 202 | + return createLeafModel<GaussianConstantLeafModel, double>(tau); |
| 203 | + } else if (model_type == kUnivariateRegressionLeafGaussian) { |
| 204 | + return createLeafModel<GaussianUnivariateRegressionLeafModel, double>(tau); |
| 205 | + } else if (model_type == kMultivariateRegressionLeafGaussian) { |
| 206 | + return createLeafModel<GaussianMultivariateRegressionLeafModel, Eigen::MatrixXd>(Sigma0); |
| 207 | + } else if (model_type == kLogLinearVariance) { |
| 208 | + return createLeafModel<LogLinearVarianceLeafModel, double, double>(a, b); |
| 209 | + } else if (model_type == kOurNewModel) { // New model |
| 210 | + return createLeafModel<OurNewModelLeafModel, /* initializer types */>(/* initializer values */); |
| 211 | + } else { |
| 212 | + Log::Fatal("Incompatible model type provided to leaf model factory"); |
| 213 | + } |
| 214 | +} |
| 215 | +``` |
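The factories above return a `std::variant`, and code that receives one typically recovers the concrete class with `std::visit`. The following self-contained sketch shows that round trip with toy types. Every name in it (`ToyConstantLeafModel`, `toyLeafModelFactory`, `requiresBasis`, and so on) is hypothetical, and it throws an exception where stochtree calls `Log::Fatal`; it illustrates the pattern rather than stochtree's actual dispatch code.

```cpp
#include <stdexcept>
#include <variant>

// Toy stand-ins for leaf model classes (hypothetical, not stochtree's types).
struct ToyConstantLeafModel {
  double tau;
  bool RequiresBasis() const { return false; }
};

struct ToyRegressionLeafModel {
  double tau;
  bool RequiresBasis() const { return true; }
};

enum ToyModelType { kToyConstant, kToyRegression };

using ToyLeafModelVariant = std::variant<ToyConstantLeafModel, ToyRegressionLeafModel>;

// Factory: map a runtime enum value to one alternative of the variant.
inline ToyLeafModelVariant toyLeafModelFactory(ToyModelType model_type, double tau) {
  switch (model_type) {
    case kToyConstant:
      return ToyConstantLeafModel{tau};
    case kToyRegression:
      return ToyRegressionLeafModel{tau};
  }
  throw std::invalid_argument("Incompatible model type provided to leaf model factory");
}

// Dispatch: std::visit invokes the generic lambda with whichever concrete
// type the variant currently holds, so the call resolves without virtual functions.
inline bool requiresBasis(const ToyLeafModelVariant& model) {
  return std::visit([](const auto& m) { return m.RequiresBasis(); }, model);
}
```

Because the generic lambda is instantiated for every alternative in the variant, a class that is added to the variant but lacks one of the common methods produces a compile-time error rather than a runtime failure.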
| 216 | + |
| 217 | +## R Wrapper |
| 218 | + |
| 219 | +To reflect this change through to the R interface, we first add the new model to the logic in the `sample_gfr_one_iteration_cpp` |
| 220 | +and `sample_mcmc_one_iteration_cpp` functions in the `src/sampler.cpp` file: |
| 221 | + |
| 222 | +```cpp |
| 223 | +// Convert leaf model type to enum |
| 224 | +StochTree::ModelType model_type; |
| 225 | +if (leaf_model_int == 0) model_type = StochTree::ModelType::kConstantLeafGaussian; |
| 226 | +else if (leaf_model_int == 1) model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian; |
| 227 | +else if (leaf_model_int == 2) model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian; |
| 228 | +else if (leaf_model_int == 3) model_type = StochTree::ModelType::kLogLinearVariance; |
| 229 | +else if (leaf_model_int == 4) model_type = StochTree::ModelType::kOurNewModel; // New model |
| 230 | +else StochTree::Log::Fatal("Invalid model type"); |
| 231 | +``` |
| 232 | + |
| 233 | +Then we add the integer code for `OurNewModel` to the `leaf_model_type` field documentation in `R/config.R`: |
| 234 | + |
| 235 | +```r |
| 236 | +#' @field leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression, 3 = log-linear variance, 4 = your new model) |
| 237 | +leaf_model_type = NULL, |
| 238 | +``` |
| 239 | + |
| 240 | +## Python Wrapper |
| 241 | + |
| 242 | +Python's C++ wrapper code contains similar logic to that of the `src/sampler.cpp` file in the R interface. |
| 243 | +Add the new model to the `SampleOneIteration` method of the `ForestSamplerCpp` class in the `src/py_stochtree.cpp` file. |
| 244 | + |
| 245 | +```cpp |
| 246 | +// Convert leaf model type to enum |
| 247 | +StochTree::ModelType model_type; |
| 248 | +if (leaf_model_int == 0) model_type = StochTree::ModelType::kConstantLeafGaussian; |
| 249 | +else if (leaf_model_int == 1) model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian; |
| 250 | +else if (leaf_model_int == 2) model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian; |
| 251 | +else if (leaf_model_int == 3) model_type = StochTree::ModelType::kLogLinearVariance; |
| 252 | +else if (leaf_model_int == 4) model_type = StochTree::ModelType::kOurNewModel; // New model |
| 253 | +else StochTree::Log::Fatal("Invalid model type"); |
| 254 | +``` |
| 255 | + |
| 256 | +Then add the integer code for your new model to the `leaf_model_type` documentation in `stochtree/config.py`. |
| 257 | + |
| 258 | +## Additional Considerations |
| 259 | + |
| 260 | +Some of the `SuffStat` and `LeafModel` classes currently supported by stochtree require extra initialization parameters. |
| 261 | +We support this via [variadic templates](https://en.cppreference.com/w/cpp/language/parameter_pack.html) in C++, as in the signature of `GFRSampleOneIter`: |
| 262 | + |
| 263 | +```cpp |
| 264 | +template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs> |
| 265 | +static inline void GFRSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, |
| 266 | + ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights, |
| 267 | + std::vector<int>& sweep_update_indices, double global_variance, std::vector<FeatureType>& feature_types, int cutpoint_grid_size, |
| 268 | + bool keep_forest, bool pre_initialized, bool backfitting, int num_features_subsample, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) |
| 269 | +``` |
| 270 | + |
| 271 | +If your new classes take any initialization arguments, these are provided in the factory functions, so you might also need to edit the signature of the factory functions. |
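To illustrate how a parameter pack lets one helper construct classes with different constructor signatures, here is a minimal sketch. The helper `makeToySuffStat` and both toy classes are hypothetical; stochtree's own `createSuffStat` and `createLeafModel` helpers may be written differently.

```cpp
#include <utility>
#include <variant>
#include <vector>

// Toy sufficient statistic that needs no constructor arguments (hypothetical).
struct ToyScalarSuffStat {
  double sum = 0.0;
};

// Toy sufficient statistic that needs a basis dimension (hypothetical).
struct ToyVectorSuffStat {
  explicit ToyVectorSuffStat(int basis_dim) : sums(basis_dim, 0.0) {}
  std::vector<double> sums;
};

using ToySuffStatVariant = std::variant<ToyScalarSuffStat, ToyVectorSuffStat>;

// Variadic helper: forwards however many arguments the chosen class's
// constructor needs, then wraps the result in the variant.
template <typename SuffStatType, typename... ConstructorArgs>
ToySuffStatVariant makeToySuffStat(ConstructorArgs&&... args) {
  return SuffStatType(std::forward<ConstructorArgs>(args)...);
}
```

With a helper of this shape, `makeToySuffStat<ToyScalarSuffStat>()` and `makeToySuffStat<ToyVectorSuffStat>(3)` go through the same template, so supporting a class with new constructor arguments only requires passing those arguments at the call site.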