Skip to content

Commit 69d60aa

Browse files
committed
Updated interfaces and added unit tests
1 parent 8cf4bed commit 69d60aa

File tree

7 files changed

+115
-7
lines changed

7 files changed

+115
-7
lines changed

R/cpp11.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ rfx_dataset_update_var_weights_cpp <- function(dataset_ptr, weights, exponentiat
8080
invisible(.Call(`_stochtree_rfx_dataset_update_var_weights_cpp`, dataset_ptr, weights, exponentiate))
8181
}
8282

83+
rfx_dataset_update_group_labels_cpp <- function(dataset_ptr, group_labels) {
84+
invisible(.Call(`_stochtree_rfx_dataset_update_group_labels_cpp`, dataset_ptr, group_labels))
85+
}
86+
8387
rfx_dataset_num_basis_cpp <- function(dataset) {
8488
.Call(`_stochtree_rfx_dataset_num_basis_cpp`, dataset)
8589
}

src/cpp11.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,14 @@ extern "C" SEXP _stochtree_rfx_dataset_update_var_weights_cpp(SEXP dataset_ptr,
157157
END_CPP11
158158
}
159159
// R_data.cpp
160+
void rfx_dataset_update_group_labels_cpp(cpp11::external_pointer<StochTree::RandomEffectsDataset> dataset_ptr, cpp11::integers group_labels);
161+
extern "C" SEXP _stochtree_rfx_dataset_update_group_labels_cpp(SEXP dataset_ptr, SEXP group_labels) {
162+
BEGIN_CPP11
163+
rfx_dataset_update_group_labels_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::RandomEffectsDataset>>>(dataset_ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(group_labels));
164+
return R_NilValue;
165+
END_CPP11
166+
}
167+
// R_data.cpp
160168
int rfx_dataset_num_basis_cpp(cpp11::external_pointer<StochTree::RandomEffectsDataset> dataset);
161169
extern "C" SEXP _stochtree_rfx_dataset_num_basis_cpp(SEXP dataset) {
162170
BEGIN_CPP11
@@ -1697,6 +1705,7 @@ static const R_CallMethodDef CallEntries[] = {
16971705
{"_stochtree_rfx_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_basis_cpp, 1},
16981706
{"_stochtree_rfx_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_rows_cpp, 1},
16991707
{"_stochtree_rfx_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_basis_cpp, 2},
1708+
{"_stochtree_rfx_dataset_update_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_group_labels_cpp, 2},
17001709
{"_stochtree_rfx_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_var_weights_cpp, 3},
17011710
{"_stochtree_rfx_group_ids_from_json_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_cpp, 2},
17021711
{"_stochtree_rfx_group_ids_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_string_cpp, 2},

src/py_stochtree.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,12 @@ class ForestDatasetCpp {
6767
dataset_->AddVarianceWeights(data_ptr, num_row);
6868
}
6969

70-
void UpdateVarianceWeights(py::array_t<double> weight_vector, data_size_t num_row) {
70+
void UpdateVarianceWeights(py::array_t<double> weight_vector, data_size_t num_row, bool exponentiate) {
7171
// Extract pointer to contiguous block of memory
7272
double* data_ptr = static_cast<double*>(weight_vector.mutable_data());
7373

7474
// Load covariates
75-
dataset_->UpdateVarWeights(data_ptr, num_row);
75+
dataset_->UpdateVarWeights(data_ptr, num_row, exponentiate);
7676
}
7777

7878
data_size_t NumRows() {
@@ -2067,6 +2067,7 @@ PYBIND11_MODULE(stochtree_cpp, m) {
20672067
.def("AddBasis", &ForestDatasetCpp::AddBasis)
20682068
.def("UpdateBasis", &ForestDatasetCpp::UpdateBasis)
20692069
.def("AddVarianceWeights", &ForestDatasetCpp::AddVarianceWeights)
2070+
.def("UpdateVarianceWeights", &ForestDatasetCpp::UpdateVarianceWeights)
20702071
.def("NumRows", &ForestDatasetCpp::NumRows)
20712072
.def("NumCovariates", &ForestDatasetCpp::NumCovariates)
20722073
.def("NumBasis", &ForestDatasetCpp::NumBasis)
@@ -2200,6 +2201,9 @@ PYBIND11_MODULE(stochtree_cpp, m) {
22002201
.def("AddGroupLabels", &RandomEffectsDatasetCpp::AddGroupLabels)
22012202
.def("AddBasis", &RandomEffectsDatasetCpp::AddBasis)
22022203
.def("AddVarianceWeights", &RandomEffectsDatasetCpp::AddVarianceWeights)
2204+
.def("UpdateGroupLabels", &RandomEffectsDatasetCpp::UpdateGroupLabels)
2205+
.def("UpdateBasis", &RandomEffectsDatasetCpp::UpdateBasis)
2206+
.def("UpdateVarianceWeights", &RandomEffectsDatasetCpp::UpdateVarianceWeights)
22032207
.def("HasGroupLabels", &RandomEffectsDatasetCpp::HasGroupLabels)
22042208
.def("HasBasis", &RandomEffectsDatasetCpp::HasBasis)
22052209
.def("HasVarianceWeights", &RandomEffectsDatasetCpp::HasVarianceWeights);

stochtree/data.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def add_variance_weights(self, variance_weights: np.array):
9494

9595
self.dataset_cpp.AddVarianceWeights(variance_weights_, n)
9696

97-
def update_variance_weights(self, variance_weights: np.array):
97+
def update_variance_weights(self, variance_weights: np.array, exponentiate: bool = False):
9898
"""
9999
Update variance weights in a dataset. Allows users to build an ensemble that depends on
100100
variance weights that are updated throughout the sampler.
@@ -103,6 +103,8 @@ def update_variance_weights(self, variance_weights: np.array):
103103
----------
104104
variance_weights : np.array
105105
Univariate numpy array of variance weights.
106+
exponentiate : bool
107+
Whether to exponentiate the variance weights before storing them in the dataset.
106108
"""
107109
if not self.has_variance_weights():
108110
raise ValueError("This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector.")
@@ -114,7 +116,7 @@ def update_variance_weights(self, variance_weights: np.array):
114116
raise ValueError("variance_weights must be a 1-dimensional numpy array.")
115117
if self.num_observations() != n:
116118
raise ValueError(f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()}).")
117-
self.dataset_cpp.UpdateVarianceWeights(variance_weights_, n)
119+
self.dataset_cpp.UpdateVarianceWeights(variance_weights_, n, exponentiate)
118120

119121
def num_observations(self) -> int:
120122
"""

stochtree/random_effects.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def add_variance_weights(self, variance_weights: np.array):
111111
n = variance_weights_.shape[0]
112112
self.rfx_dataset_cpp.AddVarianceWeights(variance_weights_, n)
113113

114-
def update_variance_weights(self, variance_weights: np.array):
114+
def update_variance_weights(self, variance_weights: np.array, exponentiate: bool = False):
115115
"""
116116
Update variance weights in a dataset. Allows users to build an ensemble that depends on
117117
variance weights that are updated throughout the sampler.
@@ -120,6 +120,8 @@ def update_variance_weights(self, variance_weights: np.array):
120120
----------
121121
variance_weights : np.array
122122
Univariate numpy array of variance weights.
123+
exponentiate : bool
124+
Whether to exponentiate the variance weights before storing them in the dataset.
123125
"""
124126
if not self.has_variance_weights():
125127
raise ValueError("This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector.")
@@ -133,7 +135,7 @@ def update_variance_weights(self, variance_weights: np.array):
133135
n = variance_weights_.shape[0]
134136
if self.num_observations() != n:
135137
raise ValueError(f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()}).")
136-
self.rfx_dataset_cpp.UpdateVarianceWeights(variance_weights, n)
138+
self.rfx_dataset_cpp.UpdateVarianceWeights(variance_weights, n, exponentiate)
137139

138140
def num_observations(self) -> int:
139141
"""

test/R/testthat/test-dataset.R

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
test_that("ForestDataset can be constructed and updated", {
2+
# Generate data
3+
n <- 20
4+
num_covariates <- 10
5+
num_basis <- 5
6+
covariates <- matrix(runif(n * num_covariates), ncol = num_covariates)
7+
basis <- matrix(runif(n * num_basis), ncol = num_basis)
8+
variance_weights <- runif(n)
9+
10+
# Copy data to a ForestDataset object
11+
forest_dataset <- createForestDataset(covariates, basis, variance_weights)
12+
13+
# Run first round of expectations
14+
expect_equal(forest_dataset$num_observations(), n)
15+
expect_equal(forest_dataset$num_covariates(), num_covariates)
16+
expect_equal(forest_dataset$num_basis(), num_basis)
17+
expect_equal(forest_dataset$has_variance_weights(), T)
18+
19+
# Update data
20+
new_basis <- matrix(runif(n * num_basis), ncol = num_basis)
21+
new_variance_weights <- runif(n)
22+
expect_no_error(
23+
forest_dataset$update_basis(new_basis)
24+
)
25+
expect_no_error(
26+
forest_dataset$update_variance_weights(new_variance_weights)
27+
)
28+
})
29+
30+
test_that("RandomEffectsDataset can be constructed and updated", {
31+
# Generate data
32+
n <- 20
33+
num_groups <- 4
34+
num_basis <- 5
35+
group_ids <- sample(as.integer(1:num_groups), size = n, replace = T)
36+
rfx_basis <- cbind(1, matrix(runif(n*(num_basis-1)), ncol = (num_basis-1)))
37+
variance_weights <- runif(n)
38+
39+
# Copy data to a RandomEffectsDataset object
40+
rfx_dataset <- createRandomEffectsDataset(group_ids, rfx_basis, variance_weights)
41+
42+
# Run first round of expectations
43+
expect_equal(rfx_dataset$num_observations(), n)
44+
expect_equal(rfx_dataset$num_basis(), num_basis)
45+
expect_equal(rfx_dataset$has_variance_weights(), T)
46+
47+
# Update data
48+
new_rfx_basis <- matrix(runif(n * num_basis), ncol = num_basis)
49+
new_variance_weights <- runif(n)
50+
expect_no_error(
51+
rfx_dataset$update_basis(new_basis)
52+
)
53+
expect_no_error(
54+
rfx_dataset$update_variance_weights(new_variance_weights)
55+
)
56+
})

test/python/test_data.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
from stochtree import Dataset
3+
from stochtree import Dataset, RandomEffectsDataset
44

55
class TestDataset:
66
def test_dataset_update(self):
@@ -29,3 +29,34 @@ def test_dataset_update(self):
2929
with np.testing.assert_no_warnings():
3030
forest_dataset.update_basis(new_basis)
3131
forest_dataset.update_variance_weights(new_variance_weights)
32+
33+
class TestRFXDataset:
34+
def test_rfx_dataset_update(self):
35+
# Generate data
36+
n = 20
37+
num_groups = 4
38+
num_basis = 5
39+
rng = np.random.default_rng()
40+
group_labels = rng.choice(num_groups, size=n)
41+
basis = np.empty((n, num_basis))
42+
basis[:, 0] = 1.0
43+
if num_basis > 1:
44+
basis[:, 1:] = rng.uniform(-1, 1, (n, num_basis - 1))
45+
variance_weights = rng.uniform(0, 1, size=n)
46+
47+
# Construct dataset
48+
rfx_dataset = RandomEffectsDataset()
49+
rfx_dataset.add_group_labels(group_labels)
50+
rfx_dataset.add_basis(basis)
51+
rfx_dataset.add_variance_weights(variance_weights)
52+
assert rfx_dataset.num_observations() == n
53+
assert rfx_dataset.num_basis() == num_basis
54+
assert rfx_dataset.has_variance_weights()
55+
56+
# Update dataset
57+
new_basis = rng.uniform(0, 1, size=(n, num_basis))
58+
new_variance_weights = rng.uniform(0, 1, size=n)
59+
with np.testing.assert_no_warnings():
60+
rfx_dataset.update_basis(new_basis)
61+
rfx_dataset.update_variance_weights(new_variance_weights)
62+

0 commit comments

Comments
 (0)