Skip to content

Commit 9e9de2a

Browse files
committed
Added data retrieval methods and unit tests in R
1 parent 7423c76 commit 9e9de2a

File tree

7 files changed

+283
-0
lines changed

7 files changed

+283
-0
lines changed

R/cpp11.R

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ forest_dataset_add_weights_cpp <- function(dataset_ptr, weights) {
4444
invisible(.Call(`_stochtree_forest_dataset_add_weights_cpp`, dataset_ptr, weights))
4545
}
4646

47+
forest_dataset_get_covariates_cpp <- function(dataset_ptr) {
48+
.Call(`_stochtree_forest_dataset_get_covariates_cpp`, dataset_ptr)
49+
}
50+
51+
forest_dataset_get_basis_cpp <- function(dataset_ptr) {
52+
.Call(`_stochtree_forest_dataset_get_basis_cpp`, dataset_ptr)
53+
}
54+
55+
forest_dataset_get_variance_weights_cpp <- function(dataset_ptr) {
56+
.Call(`_stochtree_forest_dataset_get_variance_weights_cpp`, dataset_ptr)
57+
}
58+
4759
create_column_vector_cpp <- function(outcome) {
4860
.Call(`_stochtree_create_column_vector_cpp`, outcome)
4961
}
@@ -116,6 +128,18 @@ rfx_dataset_add_weights_cpp <- function(dataset_ptr, weights) {
116128
invisible(.Call(`_stochtree_rfx_dataset_add_weights_cpp`, dataset_ptr, weights))
117129
}
118130

131+
rfx_dataset_get_group_labels_cpp <- function(dataset_ptr) {
132+
.Call(`_stochtree_rfx_dataset_get_group_labels_cpp`, dataset_ptr)
133+
}
134+
135+
rfx_dataset_get_basis_cpp <- function(dataset_ptr) {
136+
.Call(`_stochtree_rfx_dataset_get_basis_cpp`, dataset_ptr)
137+
}
138+
139+
rfx_dataset_get_variance_weights_cpp <- function(dataset_ptr) {
140+
.Call(`_stochtree_rfx_dataset_get_variance_weights_cpp`, dataset_ptr)
141+
}
142+
119143
rfx_container_cpp <- function(num_components, num_groups) {
120144
.Call(`_stochtree_rfx_container_cpp`, num_components, num_groups)
121145
}

R/data.R

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,27 @@ ForestDataset <- R6::R6Class(
6868
return(dataset_num_basis_cpp(self$data_ptr))
6969
},
7070

71+
#' @description
72+
#' Return covariates as an R matrix
73+
#' @return Covariate data
74+
get_covariates = function() {
75+
return(forest_dataset_get_covariates_cpp(self$data_ptr))
76+
},
77+
78+
#' @description
79+
#' Return bases as an R matrix
80+
#' @return Basis data
81+
get_basis = function() {
82+
return(forest_dataset_get_basis_cpp(self$data_ptr))
83+
},
84+
85+
#' @description
86+
#' Return variance weights as an R vector
87+
#' @return Variance weight data
88+
get_variance_weights = function() {
89+
return(forest_dataset_get_variance_weights_cpp(self$data_ptr))
90+
},
91+
7192
#' @description
7293
#' Whether or not a dataset has a basis matrix
7394
#' @return True if basis matrix is loaded, false otherwise
@@ -230,6 +251,27 @@ RandomEffectsDataset <- R6::R6Class(
230251
return(rfx_dataset_num_basis_cpp(self$data_ptr))
231252
},
232253

254+
#' @description
255+
#' Return group labels as an R vector
256+
#' @return Group label data
257+
get_group_labels = function() {
258+
return(rfx_dataset_get_group_labels_cpp(self$data_ptr))
259+
},
260+
261+
#' @description
262+
#' Return bases as an R matrix
263+
#' @return Basis data
264+
get_basis = function() {
265+
return(rfx_dataset_get_basis_cpp(self$data_ptr))
266+
},
267+
268+
#' @description
269+
#' Return variance weights as an R vector
270+
#' @return Variance weight data
271+
get_variance_weights = function() {
272+
return(rfx_dataset_get_variance_weights_cpp(self$data_ptr))
273+
},
274+
233275
#' @description
234276
#' Whether or not a dataset has group label indices
235277
#' @return True if group label vector is loaded, false otherwise

man/ForestDataset.Rd

Lines changed: 42 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/RandomEffectsDataset.Rd

Lines changed: 42 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/R_data.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "cpp11/integers.hpp"
12
#include <cpp11.hpp>
23
#include <stochtree/container.h>
34
#include <stochtree/data.h>
@@ -106,6 +107,47 @@ void forest_dataset_add_weights_cpp(cpp11::external_pointer<StochTree::ForestDat
106107
UNPROTECT(1);
107108
}
108109

110+
[[cpp11::register]]
111+
cpp11::writable::doubles_matrix<> forest_dataset_get_covariates_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr) {
112+
// Initialize output matrix
113+
int num_row = dataset_ptr->NumObservations();
114+
int num_col = dataset_ptr->NumCovariates();
115+
cpp11::writable::doubles_matrix<> output(num_row, num_col);
116+
117+
for (int i = 0; i < num_row; i++) {
118+
for (int j = 0; j < num_col; j++) {
119+
output(i, j) = dataset_ptr->CovariateValue(i, j);
120+
}
121+
}
122+
123+
return output;
124+
}
125+
126+
[[cpp11::register]]
127+
cpp11::writable::doubles_matrix<> forest_dataset_get_basis_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr) {
128+
// Initialize output matrix
129+
int num_row = dataset_ptr->NumObservations();
130+
int num_col = dataset_ptr->NumBasis();
131+
cpp11::writable::doubles_matrix<> output(num_row, num_col);
132+
for (int i = 0; i < num_row; i++) {
133+
for (int j = 0; j < num_col; j++) {
134+
output(i, j) = dataset_ptr->BasisValue(i, j);
135+
}
136+
}
137+
return output;
138+
}
139+
140+
[[cpp11::register]]
141+
cpp11::writable::doubles forest_dataset_get_variance_weights_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr) {
142+
// Initialize output vector
143+
int num_row = dataset_ptr->NumObservations();
144+
cpp11::writable::doubles output(num_row);
145+
for (int i = 0; i < num_row; i++) {
146+
output.at(i) = dataset_ptr->VarWeightValue(i);
147+
}
148+
return output;
149+
}
150+
109151
[[cpp11::register]]
110152
cpp11::external_pointer<StochTree::ColumnVector> create_column_vector_cpp(cpp11::doubles outcome) {
111153
// Unpack pointers to data and dimensions
@@ -282,3 +324,36 @@ void rfx_dataset_add_weights_cpp(cpp11::external_pointer<StochTree::RandomEffect
282324
// Unprotect pointers to R data
283325
UNPROTECT(1);
284326
}
327+
328+
[[cpp11::register]]
329+
cpp11::writable::integers rfx_dataset_get_group_labels_cpp(cpp11::external_pointer<StochTree::RandomEffectsDataset> dataset_ptr) {
330+
int num_row = dataset_ptr->NumObservations();
331+
cpp11::writable::integers output(num_row);
332+
for (int i = 0; i < num_row; i++) {
333+
output.at(i) = dataset_ptr->GroupId(i);
334+
}
335+
return output;
336+
}
337+
338+
[[cpp11::register]]
339+
cpp11::writable::doubles_matrix<> rfx_dataset_get_basis_cpp(cpp11::external_pointer<StochTree::RandomEffectsDataset> dataset_ptr) {
340+
int num_row = dataset_ptr->NumObservations();
341+
int num_col = dataset_ptr->NumBases();
342+
cpp11::writable::doubles_matrix<> output(num_row, num_col);
343+
for (int i = 0; i < num_row; i++) {
344+
for (int j = 0; j < num_col; j++) {
345+
output(i, j) = dataset_ptr->BasisValue(i, j);
346+
}
347+
}
348+
return output;
349+
}
350+
351+
[[cpp11::register]]
352+
cpp11::writable::doubles rfx_dataset_get_variance_weights_cpp(cpp11::external_pointer<StochTree::RandomEffectsDataset> dataset_ptr) {
353+
int num_row = dataset_ptr->NumObservations();
354+
cpp11::writable::doubles output(num_row);
355+
for (int i = 0; i < num_row; i++) {
356+
output.at(i) = dataset_ptr->VarWeightValue(i);
357+
}
358+
return output;
359+
}

src/cpp11.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,27 @@ extern "C" SEXP _stochtree_forest_dataset_add_weights_cpp(SEXP dataset_ptr, SEXP
8888
END_CPP11
8989
}
9090
// R_data.cpp
91+
cpp11::writable::doubles_matrix<> forest_dataset_get_covariates_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr);
92+
extern "C" SEXP _stochtree_forest_dataset_get_covariates_cpp(SEXP dataset_ptr) {
93+
BEGIN_CPP11
94+
return cpp11::as_sexp(forest_dataset_get_covariates_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestDataset>>>(dataset_ptr)));
95+
END_CPP11
96+
}
97+
// R_data.cpp
98+
cpp11::writable::doubles_matrix<> forest_dataset_get_basis_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr);
99+
extern "C" SEXP _stochtree_forest_dataset_get_basis_cpp(SEXP dataset_ptr) {
100+
BEGIN_CPP11
101+
return cpp11::as_sexp(forest_dataset_get_basis_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestDataset>>>(dataset_ptr)));
102+
END_CPP11
103+
}
104+
// R_data.cpp
105+
cpp11::writable::doubles forest_dataset_get_variance_weights_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr);
106+
extern "C" SEXP _stochtree_forest_dataset_get_variance_weights_cpp(SEXP dataset_ptr) {
107+
BEGIN_CPP11
108+
return cpp11::as_sexp(forest_dataset_get_variance_weights_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestDataset>>>(dataset_ptr)));
109+
END_CPP11
110+
}
111+
// R_data.cpp
91112
cpp11::external_pointer<StochTree::ColumnVector> create_column_vector_cpp(cpp11::doubles outcome);
92113
extern "C" SEXP _stochtree_create_column_vector_cpp(SEXP outcome) {
93114
BEGIN_CPP11
@@ -223,6 +244,27 @@ extern "C" SEXP _stochtree_rfx_dataset_add_weights_cpp(SEXP dataset_ptr, SEXP we
223244
return R_NilValue;
224245
END_CPP11
225246
}
247+
// R_data.cpp
248+
cpp11::writable::integers rfx_dataset_get_group_labels_cpp(cpp11::external_pointer<StochTree::RandomEffectsDataset> dataset_ptr);
249+
extern "C" SEXP _stochtree_rfx_dataset_get_group_labels_cpp(SEXP dataset_ptr) {
250+
BEGIN_CPP11
251+
return cpp11::as_sexp(rfx_dataset_get_group_labels_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::RandomEffectsDataset>>>(dataset_ptr)));
252+
END_CPP11
253+
}
254+
// R_data.cpp
255+
cpp11::writable::doubles_matrix<> rfx_dataset_get_basis_cpp(cpp11::external_pointer<StochTree::RandomEffectsDataset> dataset_ptr);
256+
extern "C" SEXP _stochtree_rfx_dataset_get_basis_cpp(SEXP dataset_ptr) {
257+
BEGIN_CPP11
258+
return cpp11::as_sexp(rfx_dataset_get_basis_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::RandomEffectsDataset>>>(dataset_ptr)));
259+
END_CPP11
260+
}
261+
// R_data.cpp
262+
cpp11::writable::doubles rfx_dataset_get_variance_weights_cpp(cpp11::external_pointer<StochTree::RandomEffectsDataset> dataset_ptr);
263+
extern "C" SEXP _stochtree_rfx_dataset_get_variance_weights_cpp(SEXP dataset_ptr) {
264+
BEGIN_CPP11
265+
return cpp11::as_sexp(rfx_dataset_get_variance_weights_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::RandomEffectsDataset>>>(dataset_ptr)));
266+
END_CPP11
267+
}
226268
// R_random_effects.cpp
227269
cpp11::external_pointer<StochTree::RandomEffectsContainer> rfx_container_cpp(int num_components, int num_groups);
228270
extern "C" SEXP _stochtree_rfx_container_cpp(SEXP num_components, SEXP num_groups) {
@@ -1579,6 +1621,9 @@ static const R_CallMethodDef CallEntries[] = {
15791621
{"_stochtree_forest_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_basis_cpp, 2},
15801622
{"_stochtree_forest_dataset_add_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_covariates_cpp, 2},
15811623
{"_stochtree_forest_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_weights_cpp, 2},
1624+
{"_stochtree_forest_dataset_get_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_basis_cpp, 1},
1625+
{"_stochtree_forest_dataset_get_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_covariates_cpp, 1},
1626+
{"_stochtree_forest_dataset_get_variance_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_variance_weights_cpp, 1},
15821627
{"_stochtree_forest_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_basis_cpp, 2},
15831628
{"_stochtree_forest_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_var_weights_cpp, 3},
15841629
{"_stochtree_forest_merge_cpp", (DL_FUNC) &_stochtree_forest_merge_cpp, 2},
@@ -1699,6 +1744,9 @@ static const R_CallMethodDef CallEntries[] = {
16991744
{"_stochtree_rfx_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_basis_cpp, 2},
17001745
{"_stochtree_rfx_dataset_add_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_group_labels_cpp, 2},
17011746
{"_stochtree_rfx_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_weights_cpp, 2},
1747+
{"_stochtree_rfx_dataset_get_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_basis_cpp, 1},
1748+
{"_stochtree_rfx_dataset_get_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_group_labels_cpp, 1},
1749+
{"_stochtree_rfx_dataset_get_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_variance_weights_cpp, 1},
17021750
{"_stochtree_rfx_dataset_has_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_basis_cpp, 1},
17031751
{"_stochtree_rfx_dataset_has_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_group_labels_cpp, 1},
17041752
{"_stochtree_rfx_dataset_has_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_variance_weights_cpp, 1},

test/R/testthat/test-dataset.R

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ test_that("ForestDataset can be constructed and updated", {
2525
expect_no_error(
2626
forest_dataset$update_variance_weights(new_variance_weights)
2727
)
28+
29+
# Check that we recover the correct data through get_covariates, get_basis, and get_variance_weights
30+
expect_equal(covariates, forest_dataset$get_covariates())
31+
expect_equal(new_basis, forest_dataset$get_basis())
32+
expect_equal(new_variance_weights, forest_dataset$get_variance_weights())
2833
})
2934

3035
test_that("RandomEffectsDataset can be constructed and updated", {
@@ -53,4 +58,9 @@ test_that("RandomEffectsDataset can be constructed and updated", {
5358
expect_no_error(
5459
rfx_dataset$update_variance_weights(new_variance_weights)
5560
)
61+
62+
# Check that we recover the correct data through get_group_labels, get_basis, and get_variance_weights
63+
expect_equal(group_ids, rfx_dataset$get_group_labels())
64+
expect_equal(new_basis, rfx_dataset$get_basis())
65+
expect_equal(new_variance_weights, rfx_dataset$get_variance_weights())
5666
})

0 commit comments

Comments
 (0)