-
-
Notifications
You must be signed in to change notification settings - Fork 10
MultiDistribution #18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
4d7387d
0048240
5c96617
6a5bd80
c9df79c
e61da69
4a78dbc
cb4c824
7572ce3
b1e663d
b53aeda
2435d3f
c4acb6d
860897c
3219cd6
e74b4b4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,7 @@ | |
//! The dirichlet distribution `Dirichlet(α₁, α₂, ..., αₙ)`. | ||
|
||
#![cfg(feature = "alloc")] | ||
use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal}; | ||
use crate::{multi::MultiDistribution, Beta, Distribution, Exp1, Gamma, Open01, StandardNormal}; | ||
use core::fmt; | ||
use num_traits::{Float, NumCast}; | ||
use rand::Rng; | ||
|
@@ -68,26 +68,27 @@ where | |
} | ||
} | ||
|
||
impl<F, const N: usize> Distribution<[F; N]> for DirichletFromGamma<F, N> | ||
impl<F, const N: usize> MultiDistribution<F> for DirichletFromGamma<F, N> | ||
where | ||
F: Float, | ||
StandardNormal: Distribution<F>, | ||
Exp1: Distribution<F>, | ||
Open01: Distribution<F>, | ||
{ | ||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] { | ||
let mut samples = [F::zero(); N]; | ||
fn sample_len(&self) -> usize { | ||
N | ||
} | ||
fn sample_to_buf<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) { | ||
let mut sum = F::zero(); | ||
|
||
for (s, g) in samples.iter_mut().zip(self.samplers.iter()) { | ||
for (s, g) in output.iter_mut().zip(self.samplers.iter()) { | ||
*s = g.sample(rng); | ||
sum = sum + *s; | ||
} | ||
let invacc = F::one() / sum; | ||
for s in samples.iter_mut() { | ||
for s in output.iter_mut() { | ||
*s = *s * invacc; | ||
} | ||
samples | ||
} | ||
Comment on lines
+81
to
92
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since the length of |
||
} | ||
|
||
|
@@ -149,24 +150,25 @@ where | |
} | ||
} | ||
|
||
impl<F, const N: usize> Distribution<[F; N]> for DirichletFromBeta<F, N> | ||
impl<F, const N: usize> MultiDistribution<F> for DirichletFromBeta<F, N> | ||
where | ||
F: Float, | ||
StandardNormal: Distribution<F>, | ||
Exp1: Distribution<F>, | ||
Open01: Distribution<F>, | ||
{ | ||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] { | ||
let mut samples = [F::zero(); N]; | ||
fn sample_len(&self) -> usize { | ||
N | ||
} | ||
fn sample_to_buf<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) { | ||
let mut acc = F::one(); | ||
|
||
for (s, beta) in samples.iter_mut().zip(self.samplers.iter()) { | ||
for (s, beta) in output.iter_mut().zip(self.samplers.iter()) { | ||
let beta_sample = beta.sample(rng); | ||
*s = acc * beta_sample; | ||
acc = acc * (F::one() - beta_sample); | ||
} | ||
samples[N - 1] = acc; | ||
samples | ||
output[N - 1] = acc; | ||
} | ||
Comment on lines
+163
to
172
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This one too. |
||
} | ||
|
||
|
@@ -208,7 +210,8 @@ where | |
/// | ||
/// ``` | ||
/// use rand::prelude::*; | ||
/// use rand_distr::Dirichlet; | ||
/// use rand_distr::multi::Dirichlet; | ||
/// use rand_distr::multi::MultiDistribution; | ||
/// | ||
/// let dirichlet = Dirichlet::new([1.0, 2.0, 3.0]).unwrap(); | ||
/// let samples = dirichlet.sample(&mut rand::rng()); | ||
|
@@ -259,7 +262,7 @@ impl fmt::Display for Error { | |
"failed to create required Gamma distribution for Dirichlet distribution" | ||
} | ||
Error::FailedToCreateBeta => { | ||
"failed to create required Beta distribition for Dirichlet distribution" | ||
"failed to create required Beta distribution for Dirichlet distribution" | ||
} | ||
}) | ||
} | ||
|
@@ -315,17 +318,20 @@ where | |
} | ||
} | ||
|
||
impl<F, const N: usize> Distribution<[F; N]> for Dirichlet<F, N> | ||
impl<F, const N: usize> MultiDistribution<F> for Dirichlet<F, N> | ||
where | ||
F: Float, | ||
StandardNormal: Distribution<F>, | ||
Exp1: Distribution<F>, | ||
Open01: Distribution<F>, | ||
{ | ||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] { | ||
fn sample_len(&self) -> usize { | ||
N | ||
} | ||
fn sample_to_buf<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) { | ||
match &self.repr { | ||
DirichletRepr::FromGamma(dirichlet) => dirichlet.sample(rng), | ||
DirichletRepr::FromBeta(dirichlet) => dirichlet.sample(rng), | ||
DirichletRepr::FromGamma(dirichlet) => dirichlet.sample_to_buf(rng, output), | ||
DirichletRepr::FromBeta(dirichlet) => dirichlet.sample_to_buf(rng, output), | ||
} | ||
} | ||
} | ||
|
@@ -403,7 +409,7 @@ mod test { | |
let alpha_sum: f64 = alpha.iter().sum(); | ||
let expected_mean = alpha.map(|x| x / alpha_sum); | ||
for i in 0..N { | ||
assert_almost_eq!(sample_mean[i], expected_mean[i], rtol); | ||
average::assert_almost_eq!(sample_mean[i], expected_mean[i], rtol); | ||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
//! Contains Multi-dimensional distributions. | ||
benjamin-lieser marked this conversation as resolved.
Show resolved
Hide resolved
|
||
//! | ||
//! We provide a trait `MultiDistribution` which allows to sample from a multi-dimensional distribution without extra allocations. | ||
//! All multi-dimensional distributions implement `MultiDistribution` instead of the `Distribution` trait. | ||
Comment on lines
+3
to
+4
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We generally wrap comments at 80 chars width (sometimes up to 100 if the line already has a large indent). The wording could be a little better, e.g.
|
||
|
||
use alloc::vec::Vec; | ||
use rand::Rng; | ||
|
||
/// This trait allows to sample from a multi-dimensional distribution without extra allocations. | ||
/// For convenience it also provides a `sample` method which returns the result as a `Vec`. | ||
pub trait MultiDistribution<T> { | ||
Comment on lines
+9
to
+11
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Items have a short one-line description, with additional details in new paragraphs. |
||
/// returns the length of one sample (dimension of the distribution) | ||
fn sample_len(&self) -> usize; | ||
benjamin-lieser marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/// samples from the distribution and writes the result to `buf` | ||
fn sample_to_buf<R: Rng + ?Sized>(&self, rng: &mut R, buf: &mut [T]); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This method should be called |
||
/// samples from the distribution and returns the result as a `Vec`, to avoid extra allocations use `sample_to_buf` | ||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<T> | ||
where | ||
T: Default, | ||
{ | ||
let mut buf = Vec::new(); | ||
buf.resize_with(self.sample_len(), || T::default()); | ||
self.sample_to_buf(rng, &mut buf); | ||
buf | ||
} | ||
} | ||
|
||
pub use dirichlet::Dirichlet; | ||
|
||
mod dirichlet; |
Uh oh!
There was an error while loading. Please reload this page.