Skip to content

Commit 0acf4e4

Browse files
committed
Add mean to laplace helper distributions
1 parent 3d8d669 commit 0acf4e4

19 files changed

+243
-399
lines changed

stan/math/mix/prob.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,11 @@
33

44
#include <stan/math/mix/prob/laplace_latent_bernoulli_logit_rng.hpp>
55
#include <stan/math/mix/prob/laplace_latent_poisson_log_rng.hpp>
6-
#include <stan/math/mix/prob/laplace_latent_poisson_log_2_rng.hpp>
76
#include <stan/math/mix/prob/laplace_latent_neg_binomial_2_log_rng.hpp>
87
#include <stan/math/mix/prob/laplace_latent_rng.hpp>
98
#include <stan/math/mix/prob/laplace_marginal.hpp>
109
#include <stan/math/mix/prob/laplace_marginal_neg_binomial_2_log_lpmf.hpp>
1110
#include <stan/math/mix/prob/laplace_marginal_bernoulli_logit_lpmf.hpp>
12-
#include <stan/math/mix/prob/laplace_marginal_poisson_log_2_lpmf.hpp>
1311
#include <stan/math/mix/prob/laplace_marginal_poisson_log_lpmf.hpp>
1412

1513
#endif

stan/math/mix/prob/laplace_latent_bernoulli_logit_rng.hpp

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,30 +19,34 @@ namespace math {
1919
* where the likelihood is a Bernoulli with logit link.
2020
* @tparam ThetaVec A type inheriting from `Eigen::EigenBase`
2121
* with dynamic sized rows and 1 column.
22+
* @tparam Mean type of the mean of the latent normal distribution
2223
* \laplace_common_template_args
2324
* @tparam RNG A valid boost rng type
2425
* @param[in] y Vector Vector of total number of trials with a positive outcome.
2526
* @param[in] n_samples Vector of number of trials.
27+
* @param[in] mean the mean of the latent normal variable.
2628
* \laplace_common_args
2729
* \laplace_options
2830
* \rng_arg
2931
* \msg_arg
3032
*/
31-
template <typename ThetaVec, typename CovarFun, typename CovarArgs,
32-
typename RNG, require_eigen_t<ThetaVec>* = nullptr>
33+
template <typename ThetaVec, typename Mean, typename CovarFun,
34+
typename CovarArgs, typename RNG,
35+
require_all_eigen_vector_t<ThetaVec>* = nullptr>
3336
inline Eigen::VectorXd laplace_latent_tol_bernoulli_logit_rng(
34-
const std::vector<int>& y, const std::vector<int>& n_samples,
37+
const std::vector<int>& y, const std::vector<int>& n_samples, Mean&& mean,
3538
CovarFun&& covariance_function, CovarArgs&& covar_args, ThetaVec&& theta_0,
3639
const double tolerance, const int max_num_steps,
3740
const int hessian_block_size, const int solver,
3841
const int max_steps_line_search, RNG& rng, std::ostream* msgs) {
3942
laplace_options_user_supplied ops{hessian_block_size, solver,
4043
max_steps_line_search, tolerance,
4144
max_num_steps, value_of(theta_0)};
42-
return laplace_base_rng(bernoulli_logit_likelihood{},
43-
std::forward_as_tuple(to_vector(y), n_samples),
44-
std::forward<CovarFun>(covariance_function),
45-
std::forward<CovarArgs>(covar_args), ops, rng, msgs);
45+
return laplace_base_rng(
46+
bernoulli_logit_likelihood{},
47+
std::forward_as_tuple(to_vector(y), n_samples, std::forward<Mean>(mean)),
48+
std::forward<CovarFun>(covariance_function),
49+
std::forward<CovarArgs>(covar_args), ops, rng, msgs);
4650
}
4751

4852
/**
@@ -54,24 +58,28 @@ inline Eigen::VectorXd laplace_latent_tol_bernoulli_logit_rng(
5458
* return a multivariate normal random variate sampled
5559
* from the gaussian approximation of p(theta | y, phi),
5660
* where the likelihood is a Bernoulli with logit link.
61+
* @tparam Mean type of the mean of the latent normal distribution
5762
* \laplace_common_template_args
5863
* @tparam RNG A valid boost rng type
5964
* @param[in] y Vector Vector of total number of trials with a positive outcome.
6065
* @param[in] n_samples Vector of number of trials.
66+
* @param[in] mean the mean of the latent normal variable.
6167
* \laplace_common_args
6268
* \rng_arg
6369
* \msg_arg
6470
*/
65-
template <typename CovarFun, typename CovarArgs, typename RNG>
71+
template <typename Mean, typename CovarFun, typename CovarArgs, typename RNG,
72+
require_eigen_vector_t<Mean>* = nullptr>
6673
inline Eigen::VectorXd laplace_latent_bernoulli_logit_rng(
67-
const std::vector<int>& y, const std::vector<int>& n_samples,
74+
const std::vector<int>& y, const std::vector<int>& n_samples, Mean&& mean,
6875
CovarFun&& covariance_function, CovarArgs&& covar_args, RNG& rng,
6976
std::ostream* msgs) {
70-
return laplace_base_rng(bernoulli_logit_likelihood{},
71-
std::forward_as_tuple(to_vector(y), n_samples),
72-
std::forward<CovarFun>(covariance_function),
73-
std::forward<CovarArgs>(covar_args),
74-
laplace_options_default{}, rng, msgs);
77+
return laplace_base_rng(
78+
bernoulli_logit_likelihood{},
79+
std::forward_as_tuple(to_vector(y), n_samples, std::forward<Mean>(mean)),
80+
std::forward<CovarFun>(covariance_function),
81+
std::forward<CovarArgs>(covar_args), laplace_options_default{}, rng,
82+
msgs);
7583
}
7684

7785
} // namespace math

stan/math/mix/prob/laplace_latent_neg_binomial_2_log_rng.hpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,31 +23,34 @@ namespace math {
2323
* @tparam Eta A type for the overdispersion parameter.
2424
* @tparam ThetaVec A type inheriting from `Eigen::EigenBase`
2525
* with dynamic sized rows and 1 column.
26+
* @tparam Mean type of the mean of the latent normal distribution
2627
* \laplace_common_template_args
2728
* @tparam RNG A valid boost rng type
2829
* @param[in] y Observed counts.
2930
* @param[in] y_index Index indicating which group each observation belongs to.
3031
* @param[in] eta Overdisperison parameter.
32+
* @param[in] mean The mean of the latent normal variable.
3133
* \laplace_common_args
3234
* \laplace_options
3335
* \rng_arg
3436
* \msg_arg
3537
*/
36-
template <typename Eta, typename ThetaVec, typename CovarFun,
38+
template <typename Eta, typename ThetaVec, typename Mean, typename CovarFun,
3739
typename CovarArgs, typename RNG,
38-
require_eigen_t<ThetaVec>* = nullptr>
40+
require_eigen_vector_t<ThetaVec>* = nullptr>
3941
inline Eigen::VectorXd laplace_latent_tol_neg_binomial_2_log_rng(
4042
const std::vector<int>& y, const std::vector<int>& y_index, Eta&& eta,
41-
CovarFun&& covariance_function, CovarArgs&& covar_args, ThetaVec&& theta_0,
42-
const double tolerance, const int max_num_steps,
43+
Mean&& mean, CovarFun&& covariance_function, CovarArgs&& covar_args,
44+
ThetaVec&& theta_0, const double tolerance, const int max_num_steps,
4345
const int hessian_block_size, const int solver,
4446
const int max_steps_line_search, RNG& rng, std::ostream* msgs) {
4547
laplace_options_user_supplied ops{hessian_block_size, solver,
4648
max_steps_line_search, tolerance,
4749
max_num_steps, value_of(theta_0)};
4850
return laplace_base_rng(
4951
neg_binomial_2_log_likelihood{},
50-
std::forward_as_tuple(std::forward<Eta>(eta), y, y_index),
52+
std::forward_as_tuple(std::forward<Eta>(eta), y, y_index,
53+
std::forward<Mean>(mean)),
5154
std::forward<CovarFun>(covariance_function),
5255
std::forward<CovarArgs>(covar_args), ops, rng, msgs);
5356
}
@@ -65,23 +68,27 @@ inline Eigen::VectorXd laplace_latent_tol_neg_binomial_2_log_rng(
6568
* parameterization of the Negative Binomial.
6669
*
6770
* @tparam Eta A type for the overdispersion parameter.
71+
* @tparam Mean type of the mean of the latent normal distribution
6872
* \laplace_common_template_args
6973
* @tparam RNG A valid boost rng type
7074
* @param[in] y Observed counts.
7175
* @param[in] y_index Index indicating which group each observation belongs to.
7276
* @param[in] eta Overdisperison parameter.
77+
* @param[in] mean The mean of the latent normal variable.
7378
* \laplace_common_args
7479
* \rng_arg
7580
* \msg_arg
7681
*/
77-
template <typename Eta, typename CovarFun, typename CovarArgs, typename RNG>
82+
template <typename Eta, typename Mean, typename CovarFun, typename CovarArgs,
83+
typename RNG>
7884
inline Eigen::VectorXd laplace_latent_neg_binomial_2_log_rng(
7985
const std::vector<int>& y, const std::vector<int>& y_index, Eta&& eta,
80-
CovarFun&& covariance_function, CovarArgs&& covar_args, RNG& rng,
81-
std::ostream* msgs) {
86+
Mean&& mean, CovarFun&& covariance_function, CovarArgs&& covar_args,
87+
RNG& rng, std::ostream* msgs) {
8288
return laplace_base_rng(
8389
neg_binomial_2_log_likelihood{},
84-
std::forward_as_tuple(std::forward<Eta>(eta), y, y_index),
90+
std::forward_as_tuple(std::forward<Eta>(eta), y, y_index,
91+
std::forward<Mean>(mean)),
8592
std::forward<CovarFun>(covariance_function),
8693
std::forward<CovarArgs>(covar_args), laplace_options_default{}, rng,
8794
msgs);

stan/math/mix/prob/laplace_latent_poisson_log_2_rng.hpp

Lines changed: 0 additions & 85 deletions
This file was deleted.

stan/math/mix/prob/laplace_latent_poisson_log_rng.hpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,28 +19,31 @@ namespace math {
1919
* In this specialized function, the likelihood p(y|theta) is a
2020
* @tparam ThetaVec A type inheriting from `Eigen::EigenBase`
2121
* with dynamic sized rows and 1 column.
22+
* @tparam Mean type of the mean of the latent normal distribution
2223
* \laplace_common_template_args
2324
* @tparam RNG A valid boost rng type
2425
* @param[in] y Observed counts.
2526
* @param[in] y_index Index indicating which group each observation belongs to.
27+
* @param[in] mean The mean of the latent normal variable.
2628
* \laplace_common_args
2729
* \laplace_options
2830
* \rng_arg
2931
* \msg_arg
3032
*/
31-
template <typename ThetaVec, typename CovarFun, typename CovarArgs,
32-
typename RNG, require_eigen_t<ThetaVec>* = nullptr>
33+
template <typename ThetaVec, typename Mean, typename CovarFun,
34+
typename CovarArgs, typename RNG,
35+
require_eigen_vector_t<ThetaVec>* = nullptr>
3336
inline Eigen::VectorXd laplace_latent_tol_poisson_log_rng(
3437
const std::vector<int>& y, const std::vector<int>& y_index,
35-
CovarFun&& covariance_function, CovarArgs&& covar_args, ThetaVec&& theta_0,
36-
const double tolerance, const int max_num_steps,
38+
const Mean& mean, CovarFun&& covariance_function, CovarArgs&& covar_args,
39+
ThetaVec&& theta_0, const double tolerance, const int max_num_steps,
3740
const int hessian_block_size, const int solver,
3841
const int max_steps_line_search, RNG& rng, std::ostream* msgs) {
3942
laplace_options_user_supplied ops{hessian_block_size, solver,
4043
max_steps_line_search, tolerance,
4144
max_num_steps, value_of(theta_0)};
4245
return laplace_base_rng(poisson_log_likelihood{},
43-
std::forward_as_tuple(y, y_index),
46+
std::forward_as_tuple(y, y_index, mean),
4447
std::forward<CovarFun>(covariance_function),
4548
std::forward<CovarArgs>(covar_args), ops, rng, msgs);
4649
}
@@ -55,21 +58,23 @@ inline Eigen::VectorXd laplace_latent_tol_poisson_log_rng(
5558
* The Laplace approximation is computed using a Newton solver.
5659
* In this specialized function, the likelihood p(y|theta) is a
5760
* Poisson with a log link.
61+
* @tparam Mean type of the mean of the latent normal distribution
5862
* \laplace_common_template_args
5963
* @tparam RNG A valid boost rng type
6064
* @param[in] y Observed counts.
6165
* @param[in] y_index Index indicating which group each observation belongs to.
66+
* @param[in] mean The mean of the latent normal variable.
6267
* \laplace_common_args
6368
* \rng_arg
6469
* \msg_arg
6570
*/
66-
template <typename CovarFun, typename CovarArgs, typename RNG>
71+
template <typename CovarFun, typename CovarArgs, typename RNG, typename Mean>
6772
inline Eigen::VectorXd laplace_latent_poisson_log_rng(
6873
const std::vector<int>& y, const std::vector<int>& y_index,
69-
CovarFun&& covariance_function, CovarArgs&& covar_args, RNG& rng,
70-
std::ostream* msgs) {
74+
const Mean& mean, CovarFun&& covariance_function, CovarArgs&& covar_args,
75+
RNG& rng, std::ostream* msgs) {
7176
return laplace_base_rng(poisson_log_likelihood{},
72-
std::forward_as_tuple(y, y_index),
77+
std::forward_as_tuple(y, y_index, mean),
7378
std::forward<CovarFun>(covariance_function),
7479
std::forward<CovarArgs>(covar_args),
7580
laplace_options_default{}, rng, msgs);

stan/math/mix/prob/laplace_marginal_bernoulli_logit_lpmf.hpp

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@ namespace stan {
2020
namespace math {
2121

2222
struct bernoulli_logit_likelihood {
23-
template <typename T_theta, typename YVec>
24-
inline auto operator()(const T_theta& theta, const YVec& y,
25-
const std::vector<int>& delta_int,
23+
template <typename ThetaVec, typename YVec, typename Mean>
24+
inline auto operator()(const ThetaVec& theta, const YVec& y,
25+
const std::vector<int>& delta_int, const Mean& mean,
2626
std::ostream* pstream) const {
27-
return sum(elt_multiply(theta, y)
28-
- elt_multiply(to_vector(delta_int), log(add(1.0, exp(theta)))));
27+
auto theta_offset = to_ref(add(theta, mean));
28+
return sum(
29+
elt_multiply(theta_offset, y)
30+
- elt_multiply(to_vector(delta_int), log(add(1.0, exp(theta_offset)))));
2931
}
3032
};
3133

@@ -39,19 +41,22 @@ struct bernoulli_logit_likelihood {
3941
* @tparam propto boolean ignored
4042
* @tparam ThetaVec A type inheriting from `Eigen::EigenBase`
4143
* with dynamic sized rows and 1 column.
44+
* @tparam Mean type of the mean of the latent normal distribution
4245
* \laplace_common_template_args
4346
* @param[in] y total counts per group. Second sufficient statistics.
4447
* @param[in] n_samples number of samples per group. First sufficient
45-
* statistics.
48+
* statistics.
49+
* @param[in] mean the mean of the latent normal variable.
4650
* \laplace_common_args
4751
* \laplace_options
4852
* \msg_arg
4953
*/
50-
template <bool propto = false, typename ThetaVec, typename CovarFun,
51-
typename CovarArgs, require_eigen_t<ThetaVec>* = nullptr>
54+
template <bool propto = false, typename ThetaVec, typename Mean,
55+
typename CovarFun, typename CovarArgs,
56+
require_eigen_vector_t<ThetaVec>* = nullptr>
5257
inline auto laplace_marginal_tol_bernoulli_logit_lpmf(
5358
const std::vector<int>& y, const std::vector<int>& n_samples,
54-
CovarFun&& covariance_function, CovarArgs&& covar_args,
59+
const Mean& mean, CovarFun&& covariance_function, CovarArgs&& covar_args,
5560
const ThetaVec& theta_0, double tolerance, int max_num_steps,
5661
const int hessian_block_size, const int solver,
5762
const int max_steps_line_search, std::ostream* msgs) {
@@ -60,7 +65,7 @@ inline auto laplace_marginal_tol_bernoulli_logit_lpmf(
6065
max_num_steps, value_of(theta_0)};
6166
return laplace_marginal_density(
6267
bernoulli_logit_likelihood{},
63-
std::forward_as_tuple(to_vector(y), n_samples),
68+
std::forward_as_tuple(to_vector(y), n_samples, mean),
6469
std::forward<CovarFun>(covariance_function),
6570
std::forward<CovarArgs>(covar_args), ops, msgs);
6671
}
@@ -73,21 +78,24 @@ inline auto laplace_marginal_tol_bernoulli_logit_lpmf(
7378
* for more details.
7479
*
7580
* @tparam propto boolean ignored
81+
* @tparam Mean type of the mean of the latent normal distribution
7682
* \laplace_common_template_args
7783
* @param[in] y total counts per group. Second sufficient statistics.
7884
* @param[in] n_samples number of samples per group. First sufficient
79-
* statistics.
85+
* statistics.
86+
* @param[in] mean the mean of the latent normal variable.
8087
* \laplace_common_args
8188
* \msg_arg
8289
*/
83-
template <bool propto = false, typename CovarFun, typename CovarArgs>
90+
template <bool propto = false, typename Mean, typename CovarFun,
91+
typename CovarArgs>
8492
inline auto laplace_marginal_bernoulli_logit_lpmf(
8593
const std::vector<int>& y, const std::vector<int>& n_samples,
86-
CovarFun&& covariance_function, CovarArgs&& covar_args,
94+
const Mean& mean, CovarFun&& covariance_function, CovarArgs&& covar_args,
8795
std::ostream* msgs) {
8896
return laplace_marginal_density(
8997
bernoulli_logit_likelihood{},
90-
std::forward_as_tuple(to_vector(y), n_samples),
98+
std::forward_as_tuple(to_vector(y), n_samples, mean),
9199
std::forward<CovarFun>(covariance_function),
92100
std::forward<CovarArgs>(covar_args), laplace_options_default{}, msgs);
93101
}

0 commit comments

Comments
 (0)