Skip to content

Commit 92cf76d

Browse files
authored
Merge pull request #18 from StochasticTree/adding-new-model-workflow
Added instructions on adding a new model to stochtree
2 parents 3d96722 + 5d459fe commit 92cf76d

File tree

3 files changed

+273
-0
lines changed

3 files changed

+273
-0
lines changed

docs/development/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33
`stochtree` is in active development. Here, we detail some aspects of the development process
44

55
* [Contributing](contributing.md): how to get involved with stochtree, by contributing code, documentation, or helpful feedback
6+
* [Adding New Models](new-models.md): how to add a new outcome model in C++ and make it available through the R and Python frontends
67
* [Roadmap](roadmap.md): timelines for new feature development and releases

docs/development/new-models.md

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
# Adding New Models to stochtree
2+
3+
While the process of working with `stochtree`'s codebase to add
4+
functionality or fix bugs is covered in the [contributing](contributing.md)
5+
page, this page discusses a specific type of contribution in detail:
6+
contributing new models (i.e. likelihoods and leaf parameter priors).
7+
8+
Our C++ core is designed to support any conditionally-conjugate model, but this flexibility requires some explanation in order to be easily modified.
9+
10+
## Overview
11+
12+
The key components of `stochtree`'s models are:
13+
14+
1. A **SuffStat** class that stores and accumulates sufficient statistics
15+
2. A **LeafModel** class that computes marginal likelihoods / posterior parameters and samples leaf node parameters
16+
17+
Each model implements a different version of these two classes. For example, the "classic"
18+
BART model with constant Gaussian leaves and a Gaussian likelihood is represented by the
19+
`GaussianConstantSuffStat` and `GaussianConstantLeafModel` classes.
20+
21+
Each class implements a common API, and we use a [factory pattern](https://en.wikipedia.org/wiki/Factory_(object-oriented_programming)) and the C++17
22+
[std::variant](https://www.cppreference.com/w/cpp/utility/variant.html)
23+
feature to dispatch the correct model at runtime.
24+
Finally, R and Python wrappers expose this flexibility through the BART / BCF interfaces.
25+
26+
Adding a new leaf model thus requires implementing new `SuffStat` and `LeafModel`
27+
classes, then updating the factory functions and R / Python logic.
28+
29+
## SuffStat Class
30+
31+
As a pattern, sufficient statistic classes end in `*SuffStat` and implement several methods:
32+
33+
* `IncrementSuffStat`: Increment a model's sufficient statistics by one data observation
34+
* `ResetSuffStat`: Reset a model's sufficient statistics to zero / empty
35+
* `AddSuffStat`: Combine two sufficient statistics, storing their sum in the sufficient statistic object that calls this method (without modifying the supplied `SuffStat` objects)
36+
* `SubtractSuffStat`: Same as above but subtracting the second `SuffStat` argument from the first, rather than adding
37+
* `SampleGreaterThan`: Checks whether the current sample size of a `SuffStat` object is greater than some threshold
38+
* `SampleGreaterThanEqual`: Checks whether the current sample size of a `SuffStat` object is greater than or equal to some threshold
39+
* `SampleSize`: Returns the current sample size of a `SuffStat` object
40+
41+
For the sake of illustration, imagine we are adding a model called `OurNewModel`. The new sufficient statistic class should look something like:
42+
43+
```cpp
44+
class OurNewModelSuffStat {
45+
public:
46+
data_size_t n;
47+
// Custom sufficient statistics for `OurNewModel`
48+
double stat1;
49+
double stat2;
50+
51+
OurNewModelSuffStat() {
52+
n = 0;
53+
stat1 = 0.0;
54+
stat2 = 0.0;
55+
}
56+
57+
void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome,
58+
ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
59+
n += 1;
60+
stat1 += /* accumulate from outcome, dataset, or tracker as needed */;
61+
stat2 += /* accumulate from outcome, dataset, or tracker as needed */;
62+
}
63+
64+
void ResetSuffStat() {
65+
n = 0;
66+
stat1 = 0.0;
67+
stat2 = 0.0;
68+
}
69+
70+
void AddSuffStat(OurNewModelSuffStat& lhs, OurNewModelSuffStat& rhs) {
71+
n = lhs.n + rhs.n;
72+
stat1 = lhs.stat1 + rhs.stat1;
73+
stat2 = lhs.stat2 + rhs.stat2;
74+
}
75+
76+
void SubtractSuffStat(OurNewModelSuffStat& lhs, OurNewModelSuffStat& rhs) {
77+
n = lhs.n - rhs.n;
78+
stat1 = lhs.stat1 - rhs.stat1;
79+
stat2 = lhs.stat2 - rhs.stat2;
80+
}
81+
82+
bool SampleGreaterThan(data_size_t threshold) { return n > threshold; }
83+
bool SampleGreaterThanEqual(data_size_t threshold) { return n >= threshold; }
84+
data_size_t SampleSize() { return n; }
85+
};
86+
```
87+
88+
## LeafModel Class
89+
90+
Leaf model classes end in `*LeafModel` and implement several methods:
91+
92+
* `SplitLogMarginalLikelihood`: the log marginal likelihood of a potential split, as a function of the sufficient statistics for the newly proposed left and right node (i.e. ignoring data points unaffected by a split)
93+
* `NoSplitLogMarginalLikelihood`: the log marginal likelihood of a node without splitting, as a function of the sufficient statistics for that node
94+
* `SampleLeafParameters`: Sample the leaf node parameters for every leaf in a provided tree, according to this model's conditionally conjugate leaf node posterior
95+
* `RequiresBasis`: Whether or not a model requires regressing on "basis functions" in the leaves
96+
97+
As above, imagine that we are implementing a new model called `OurNewModel`. The new leaf model class should look something like:
98+
99+
```cpp
100+
class OurNewModelLeafModel {
101+
public:
102+
OurNewModelLeafModel(/* model parameters */) {
103+
// Set model parameters
104+
}
105+
106+
double SplitLogMarginalLikelihood(OurNewModelSuffStat& left_stat,
107+
OurNewModelSuffStat& right_stat,
108+
double global_variance) {
109+
double left_log_ml = /* calculate left node log ML */;
110+
double right_log_ml = /* calculate right node log ML */;
111+
return left_log_ml + right_log_ml;
112+
}
113+
114+
double NoSplitLogMarginalLikelihood(OurNewModelSuffStat& suff_stat,
115+
double global_variance) {
116+
double log_ml = /* calculate node log ML */;
117+
return log_ml;
118+
}
119+
120+
void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker,
121+
ColumnVector& residual, Tree* tree, int tree_num,
122+
double global_variance, std::mt19937& gen) {
123+
// Sample parameters for every leaf in a tree, update `tree` directly
124+
}
125+
126+
inline bool RequiresBasis() { return /* true/false based on your model */; }
127+
128+
// Helper methods below for `SampleLeafParameters`, which depend on the
129+
// nature of the leaf model (i.e. location-scale, shape-scale, etc...)
130+
131+
double PosteriorParameterMean(OurNewModelSuffStat& suff_stat,
132+
double global_variance) {
133+
return /* calculate posterior mean */;
134+
}
135+
136+
double PosteriorParameterVariance(OurNewModelSuffStat& suff_stat,
137+
double global_variance) {
138+
return /* calculate posterior variance */;
139+
}
140+
141+
private:
142+
// Leaf model parameters
143+
double param1_;
144+
double param2_;
145+
};
146+
```
147+
148+
## Factory Functions
149+
150+
Updating the factory pattern to be able to dispatch `OurNewModel` has several steps.
151+
152+
First, we add our model to the `ModelType` enum in `include/stochtree/leaf_model.h`:
153+
154+
```cpp
155+
enum ModelType {
156+
kConstantLeafGaussian,
157+
kUnivariateRegressionLeafGaussian,
158+
kMultivariateRegressionLeafGaussian,
159+
kLogLinearVariance,
160+
kOurNewModel // New model
161+
};
162+
```
163+
164+
Next, we add the `OurNewModelSuffStat` and `OurNewModelLeafModel` classes to the `std::variant` unions in `include/stochtree/leaf_model.h`:
165+
166+
```cpp
167+
using SuffStatVariant = std::variant<GaussianConstantSuffStat,
168+
GaussianUnivariateRegressionSuffStat,
169+
GaussianMultivariateRegressionSuffStat,
170+
LogLinearVarianceSuffStat,
171+
OurNewModelSuffStat>; // New model
172+
173+
using LeafModelVariant = std::variant<GaussianConstantLeafModel,
174+
GaussianUnivariateRegressionLeafModel,
175+
GaussianMultivariateRegressionLeafModel,
176+
LogLinearVarianceLeafModel,
177+
OurNewModelLeafModel>; // New model
178+
```
179+
180+
Finally, we update the factory functions to dispatch the correct class from the union based on the `ModelType` integer code
181+
182+
```cpp
183+
static inline SuffStatVariant suffStatFactory(ModelType model_type, int basis_dim = 0) {
184+
if (model_type == kConstantLeafGaussian) {
185+
return createSuffStat<GaussianConstantSuffStat>();
186+
} else if (model_type == kUnivariateRegressionLeafGaussian) {
187+
return createSuffStat<GaussianUnivariateRegressionSuffStat>();
188+
} else if (model_type == kMultivariateRegressionLeafGaussian) {
189+
return createSuffStat<GaussianMultivariateRegressionSuffStat, int>(basis_dim);
190+
} else if (model_type == kLogLinearVariance) {
191+
return createSuffStat<LogLinearVarianceSuffStat>();
192+
} else if (model_type == kOurNewModel) { // New model
193+
return createSuffStat<OurNewModelSuffStat>();
194+
} else {
195+
Log::Fatal("Incompatible model type provided to suff stat factory");
196+
}
197+
}
198+
199+
static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau,
200+
Eigen::MatrixXd& Sigma0, double a, double b) {
201+
if (model_type == kConstantLeafGaussian) {
202+
return createLeafModel<GaussianConstantLeafModel, double>(tau);
203+
} else if (model_type == kUnivariateRegressionLeafGaussian) {
204+
return createLeafModel<GaussianUnivariateRegressionLeafModel, double>(tau);
205+
} else if (model_type == kMultivariateRegressionLeafGaussian) {
206+
return createLeafModel<GaussianMultivariateRegressionLeafModel, Eigen::MatrixXd>(Sigma0);
207+
} else if (model_type == kLogLinearVariance) {
208+
return createLeafModel<LogLinearVarianceLeafModel, double, double>(a, b);
209+
} else if (model_type == kOurNewModel) { // New model
210+
return createLeafModel<OurNewModelLeafModel, /* initializer types */>(/* initializer values */);
211+
} else {
212+
Log::Fatal("Incompatible model type provided to leaf model factory");
213+
}
214+
}
215+
```
216+
217+
## R Wrapper
218+
219+
To reflect this change through to the R interface, we first add the new model to the logic in the `sample_gfr_one_iteration_cpp`
220+
and `sample_mcmc_one_iteration_cpp` functions in the `src/sampler.cpp` file
221+
222+
```cpp
223+
// Convert leaf model type to enum
224+
StochTree::ModelType model_type;
225+
if (leaf_model_int == 0) model_type = StochTree::ModelType::kConstantLeafGaussian;
226+
else if (leaf_model_int == 1) model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian;
227+
else if (leaf_model_int == 2) model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian;
228+
else if (leaf_model_int == 3) model_type = StochTree::ModelType::kLogLinearVariance;
229+
else if (leaf_model_int == 4) model_type = StochTree::ModelType::kOurNewModel; // New model
230+
else StochTree::Log::Fatal("Invalid model type");
231+
```
232+
233+
Then we add the integer code for `OurNewModel` to the `leaf_model_type` field signature in `R/config.R`
234+
235+
```r
236+
#' @field leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression, 4 = your new model)
237+
leaf_model_type = NULL,
238+
```
239+
240+
## Python Wrapper
241+
242+
Python's C++ wrapper code contains similar logic to that of the `src/sampler.cpp` file in the R interface.
243+
Add the new model to the `SampleOneIteration` method of the `ForestSamplerCpp` class in the `src/py_stochtree.cpp` file.
244+
245+
```cpp
246+
// Convert leaf model type to enum
247+
StochTree::ModelType model_type;
248+
if (leaf_model_int == 0) model_type = StochTree::ModelType::kConstantLeafGaussian;
249+
else if (leaf_model_int == 1) model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian;
250+
else if (leaf_model_int == 2) model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian;
251+
else if (leaf_model_int == 3) model_type = StochTree::ModelType::kLogLinearVariance;
252+
else if (leaf_model_int == 4) model_type = StochTree::ModelType::kOurNewModel; // New model
253+
else StochTree::Log::Fatal("Invalid model type");
254+
```
255+
256+
And then add the integer code for your new model to the `leaf_model_type` documentation in `stochtree/config.py`
257+
258+
## Additional Considerations
259+
260+
Some of the `SuffStat` and `LeafModel` classes currently supported by stochtree require extra initialization parameters.
261+
We support this via [variadic templates](https://en.cppreference.com/w/cpp/language/parameter_pack.html) in C++
262+
263+
```cpp
264+
template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
265+
static inline void GFRSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
266+
ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
267+
std::vector<int>& sweep_update_indices, double global_variance, std::vector<FeatureType>& feature_types, int cutpoint_grid_size,
268+
bool keep_forest, bool pre_initialized, bool backfitting, int num_features_subsample, LeafSuffStatConstructorArgs&... leaf_suff_stat_args)
269+
```
270+
271+
If your new classes take any initialization arguments, these are provided in the factory functions, so you might also need to edit the signature of the factory functions.

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ nav:
8080
- 'Development':
8181
- 'Development': development/index.md
8282
- 'Contributing': development/contributing.md
83+
- 'Adding New Models': development/new-models.md
8384
- 'Roadmap': development/roadmap.md
8485
extra:
8586
social:

0 commit comments

Comments
 (0)