Skip to content

Commit 7423c76

Browse files
committed
Added methods to query data from a C++ object back to python as a numpy array
1 parent 69d60aa commit 7423c76

File tree

4 files changed

+154
-0
lines changed

4 files changed

+154
-0
lines changed

src/py_stochtree.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,48 @@ class ForestDatasetCpp {
7575
dataset_->UpdateVarWeights(data_ptr, num_row, exponentiate);
7676
}
7777

78+
py::array_t<double> GetCovariates() {
79+
// Initialize n x p numpy array to store the covariates
80+
data_size_t n = dataset_->NumObservations();
81+
int num_covariates = dataset_->NumCovariates();
82+
auto result = py::array_t<double>(py::detail::any_container<py::ssize_t>({n, num_covariates}));
83+
auto accessor = result.mutable_unchecked<2>();
84+
for (size_t i = 0; i < n; i++) {
85+
for (int j = 0; j < num_covariates; j++) {
86+
accessor(i,j) = dataset_->CovariateValue(i,j);
87+
}
88+
}
89+
90+
return result;
91+
}
92+
93+
py::array_t<double> GetBasis() {
94+
// Initialize n x k numpy array to store the basis
95+
data_size_t n = dataset_->NumObservations();
96+
int num_basis = dataset_->NumBasis();
97+
auto result = py::array_t<double>(py::detail::any_container<py::ssize_t>({n, num_basis}));
98+
auto accessor = result.mutable_unchecked<2>();
99+
for (size_t i = 0; i < n; i++) {
100+
for (int j = 0; j < num_basis; j++) {
101+
accessor(i,j) = dataset_->BasisValue(i,j);
102+
}
103+
}
104+
105+
return result;
106+
}
107+
108+
py::array_t<double> GetVarianceWeights() {
109+
// Initialize n x 1 numpy array to store the variance weights
110+
data_size_t n = dataset_->NumObservations();
111+
auto result = py::array_t<double>(py::detail::any_container<py::ssize_t>({n}));
112+
auto accessor = result.mutable_unchecked<1>();
113+
for (size_t i = 0; i < n; i++) {
114+
accessor(i) = dataset_->VarWeightValue(i);
115+
}
116+
117+
return result;
118+
}
119+
78120
data_size_t NumRows() {
79121
return dataset_->NumObservations();
80122
}
@@ -1313,6 +1355,36 @@ class RandomEffectsDatasetCpp {
13131355
}
13141356
rfx_dataset_->UpdateGroupLabels(group_labels_vec, num_row);
13151357
}
1358+
py::array_t<double> GetBasis() {
1359+
int num_row = rfx_dataset_->NumObservations();
1360+
int num_col = rfx_dataset_->NumBases();
1361+
auto result = py::array_t<double>(py::detail::any_container<py::ssize_t>({num_row, num_col}));
1362+
auto accessor = result.mutable_unchecked<2>();
1363+
for (py::ssize_t i = 0; i < num_row; i++) {
1364+
for (int j = 0; j < num_col; j++) {
1365+
accessor(i,j) = rfx_dataset_->BasisValue(i,j);
1366+
}
1367+
}
1368+
return result;
1369+
}
1370+
py::array_t<double> GetVarianceWeights() {
1371+
int num_row = rfx_dataset_->NumObservations();
1372+
auto result = py::array_t<double>(py::detail::any_container<py::ssize_t>({num_row}));
1373+
auto accessor = result.mutable_unchecked<1>();
1374+
for (py::ssize_t i = 0; i < num_row; i++) {
1375+
accessor(i) = rfx_dataset_->VarWeightValue(i);
1376+
}
1377+
return result;
1378+
}
1379+
py::array_t<int> GetGroupLabels() {
1380+
int num_row = rfx_dataset_->NumObservations();
1381+
auto result = py::array_t<int>(py::detail::any_container<py::ssize_t>({num_row}));
1382+
auto accessor = result.mutable_unchecked<1>();
1383+
for (py::ssize_t i = 0; i < num_row; i++) {
1384+
accessor(i) = rfx_dataset_->GroupId(i);
1385+
}
1386+
return result;
1387+
}
13161388
bool HasGroupLabels() {return rfx_dataset_->HasGroupLabels();}
13171389
bool HasBasis() {return rfx_dataset_->HasBasis();}
13181390
bool HasVarianceWeights() {return rfx_dataset_->HasVarWeights();}
@@ -2071,6 +2143,9 @@ PYBIND11_MODULE(stochtree_cpp, m) {
20712143
.def("NumRows", &ForestDatasetCpp::NumRows)
20722144
.def("NumCovariates", &ForestDatasetCpp::NumCovariates)
20732145
.def("NumBasis", &ForestDatasetCpp::NumBasis)
2146+
.def("GetCovariates", &ForestDatasetCpp::GetCovariates)
2147+
.def("GetBasis", &ForestDatasetCpp::GetBasis)
2148+
.def("GetVarianceWeights", &ForestDatasetCpp::GetVarianceWeights)
20742149
.def("HasBasis", &ForestDatasetCpp::HasBasis)
20752150
.def("HasVarianceWeights", &ForestDatasetCpp::HasVarianceWeights);
20762151

@@ -2204,6 +2279,9 @@ PYBIND11_MODULE(stochtree_cpp, m) {
22042279
.def("UpdateGroupLabels", &RandomEffectsDatasetCpp::UpdateGroupLabels)
22052280
.def("UpdateBasis", &RandomEffectsDatasetCpp::UpdateBasis)
22062281
.def("UpdateVarianceWeights", &RandomEffectsDatasetCpp::UpdateVarianceWeights)
2282+
.def("GetGroupLabels", &RandomEffectsDatasetCpp::GetGroupLabels)
2283+
.def("GetBasis", &RandomEffectsDatasetCpp::GetBasis)
2284+
.def("GetVarianceWeights", &RandomEffectsDatasetCpp::GetVarianceWeights)
22072285
.def("HasGroupLabels", &RandomEffectsDatasetCpp::HasGroupLabels)
22082286
.def("HasBasis", &RandomEffectsDatasetCpp::HasBasis)
22092287
.def("HasVarianceWeights", &RandomEffectsDatasetCpp::HasVarianceWeights);

stochtree/data.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,39 @@ def num_basis(self) -> int:
150150
Dimension of the basis vector in the dataset, returning 0 if the dataset does not have a basis
151151
"""
152152
return self.dataset_cpp.NumBasis()
153+
154+
def get_covariates(self) -> np.array:
155+
"""
156+
Return the covariates in a Dataset as a numpy array
157+
158+
Returns
159+
-------
160+
np.array
161+
Covariate data
162+
"""
163+
return self.dataset_cpp.GetCovariates()
164+
165+
def get_basis(self) -> np.array:
166+
"""
167+
Return the bases in a Dataset as a numpy array
168+
169+
Returns
170+
-------
171+
np.array
172+
Basis data
173+
"""
174+
return self.dataset_cpp.GetBasis()
175+
176+
def get_variance_weights(self) -> np.array:
177+
"""
178+
Return the variance weights in a Dataset as a numpy array
179+
180+
Returns
181+
-------
182+
np.array
183+
Variance weights data
184+
"""
185+
return self.dataset_cpp.GetVarianceWeights()
153186

154187
def has_basis(self) -> bool:
155188
"""

stochtree/random_effects.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,39 @@ def update_variance_weights(self, variance_weights: np.array, exponentiate: bool
136136
if self.num_observations() != n:
137137
raise ValueError(f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()}).")
138138
self.rfx_dataset_cpp.UpdateVarianceWeights(variance_weights, n, exponentiate)
139+
140+
def get_group_labels(self) -> np.array:
141+
"""
142+
Return the group labels in a RandomEffectsDataset as a numpy array
143+
144+
Returns
145+
-------
146+
np.array
147+
One-dimensional numpy array of group labels.
148+
"""
149+
return self.rfx_dataset_cpp.GetGroupLabels()
150+
151+
def get_basis(self) -> np.array:
152+
"""
153+
Return the bases in a RandomEffectsDataset as a numpy array
154+
155+
Returns
156+
-------
157+
np.array
158+
Two-dimensional numpy array of basis vectors.
159+
"""
160+
return self.rfx_dataset_cpp.GetBasis()
161+
162+
def get_variance_weights(self) -> np.array:
163+
"""
164+
Return the variance weights in a RandomEffectsDataset as a numpy array
165+
166+
Returns
167+
-------
168+
np.array
169+
One-dimensional numpy array of variance weights.
170+
"""
171+
return self.rfx_dataset_cpp.GetVarianceWeights()
139172

140173
def num_observations(self) -> int:
141174
"""

test/python/test_data.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ def test_dataset_update(self):
2929
with np.testing.assert_no_warnings():
3030
forest_dataset.update_basis(new_basis)
3131
forest_dataset.update_variance_weights(new_variance_weights)
32+
33+
# Check that we recover the correct data through get_covariates, get_basis, and get_variance_weights
34+
np.testing.assert_array_equal(forest_dataset.get_covariates(), covariates)
35+
np.testing.assert_array_equal(forest_dataset.get_basis(), new_basis)
36+
np.testing.assert_array_equal(forest_dataset.get_variance_weights(), new_variance_weights)
3237

3338
class TestRFXDataset:
3439
def test_rfx_dataset_update(self):
@@ -59,4 +64,9 @@ def test_rfx_dataset_update(self):
5964
with np.testing.assert_no_warnings():
6065
rfx_dataset.update_basis(new_basis)
6166
rfx_dataset.update_variance_weights(new_variance_weights)
67+
68+
# Check that we recover the correct data through get_group_labels, get_basis, and get_variance_weights
69+
np.testing.assert_array_equal(rfx_dataset.get_group_labels(), group_labels)
70+
np.testing.assert_array_equal(rfx_dataset.get_basis(), new_basis)
71+
np.testing.assert_array_equal(rfx_dataset.get_variance_weights(), new_variance_weights)
6272

0 commit comments

Comments
 (0)