Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ sudo: false

matrix:
include:
- rust: 1.32.0
name: "Linux, 1.32.0"
- rust: 1.36.0
name: "Linux, 1.36.0"
env: ALLOC=0
os: linux

Expand Down
10 changes: 9 additions & 1 deletion rand_distr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,17 @@ travis-ci = { repository = "rust-random/rand" }
appveyor = { repository = "rust-random/rand" }

[dependencies]
rand = { path = "..", version = "0.7" }
rand = { path = "..", version = "0.7", default-features = false }
num-traits = { version = "0.2", default-features = false, features = ["libm"] }

[features]
default = ["std"]
std = ["rand/std", "num-traits/std", "alloc"]
alloc = ["rand/alloc"]

[dev-dependencies]
rand_pcg = { version = "0.2", path = "../rand_pcg" }
# For inline examples
rand = { path = "..", version = "0.7", default-features = false, features = ["std_rng", "std"] }
# Histogram implementation for testing uniformity
average = "0.10.3"
27 changes: 5 additions & 22 deletions rand_distr/src/binomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

use crate::{Distribution, Uniform};
use rand::Rng;
use std::{error, fmt};
use core::fmt;

/// The binomial distribution `Binomial(n, p)`.
///
Expand Down Expand Up @@ -53,7 +53,8 @@ impl fmt::Display for Error {
}
}

impl error::Error for Error {}
#[cfg(feature = "std")]
impl std::error::Error for Error {}

impl Binomial {
/// Construct a new `Binomial` with the given shape parameters `n` (number
Expand All @@ -72,7 +73,7 @@ impl Binomial {
/// Convert a `f64` to an `i64`, panicing on overflow.
// In the future (Rust 1.34), this might be replaced with `TryFrom`.
fn f64_to_i64(x: f64) -> i64 {
assert!(x < (::std::i64::MAX as f64));
assert!(x < (core::i64::MAX as f64));
x as i64
}

Expand Down Expand Up @@ -106,7 +107,7 @@ impl Distribution<u64> for Binomial {
// Ranlib uses 30, and GSL uses 14.
const BINV_THRESHOLD: f64 = 10.;

if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (::std::i32::MAX as u64) {
if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (core::i32::MAX as u64) {
// Use the BINV algorithm.
let s = p / q;
let a = ((self.n + 1) as f64) * s;
Expand Down Expand Up @@ -338,22 +339,4 @@ mod test {
fn test_binomial_invalid_lambda_neg() {
Binomial::new(20, -10.0).unwrap();
}

#[test]
fn value_stability() {
fn test_samples(n: u64, p: f64, expected: &[u64]) {
let distr = Binomial::new(n, p).unwrap();
let mut rng = crate::test::rng(353);
let mut buf = [0; 4];
for x in &mut buf {
*x = rng.sample(&distr);
}
assert_eq!(buf, expected);
}

// We have multiple code paths: np < 10, p > 0.5
test_samples(2, 0.7, &[1, 1, 2, 1]);
test_samples(20, 0.3, &[7, 7, 5, 7]);
test_samples(2000, 0.6, &[1194, 1208, 1192, 1210]);
}
}
43 changes: 25 additions & 18 deletions rand_distr/src/cauchy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

//! The Cauchy distribution.

use crate::utils::Float;
use num_traits::{Float, FloatConst};
use crate::{Distribution, Standard};
use rand::Rng;
use std::{error, fmt};
use core::fmt;

/// The Cauchy distribution `Cauchy(median, scale)`.
///
Expand All @@ -32,9 +32,11 @@ use std::{error, fmt};
/// println!("{} is from a Cauchy(2, 5) distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Cauchy<N> {
median: N,
scale: N,
pub struct Cauchy<F>
where F: Float + FloatConst, Standard: Distribution<F>
{
median: F,
scale: F,
}

/// Error type returned from `Cauchy::new`.
Expand All @@ -52,30 +54,31 @@ impl fmt::Display for Error {
}
}

impl error::Error for Error {}
#[cfg(feature = "std")]
impl std::error::Error for Error {}

impl<N: Float> Cauchy<N>
where Standard: Distribution<N>
impl<F> Cauchy<F>
where F: Float + FloatConst, Standard: Distribution<F>
{
/// Construct a new `Cauchy` with the given shape parameters
/// `median` the peak location and `scale` the scale factor.
pub fn new(median: N, scale: N) -> Result<Cauchy<N>, Error> {
if !(scale > N::from(0.0)) {
pub fn new(median: F, scale: F) -> Result<Cauchy<F>, Error> {
if !(scale > F::zero()) {
return Err(Error::ScaleTooSmall);
}
Ok(Cauchy { median, scale })
}
}

impl<N: Float> Distribution<N> for Cauchy<N>
where Standard: Distribution<N>
impl<F> Distribution<F> for Cauchy<F>
where F: Float + FloatConst, Standard: Distribution<F>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
// sample from [0, 1)
let x = Standard.sample(rng);
// get standard cauchy random number
// note that π/2 is not exactly representable, even if x=0.5 the result is finite
let comp_dev = (N::pi() * x).tan();
let comp_dev = (F::PI() * x).tan();
// shift and scale according to parameters
self.median + self.scale * comp_dev
}
Expand Down Expand Up @@ -108,10 +111,12 @@ mod test {
sum += numbers[i];
}
let median = median(&mut numbers);
println!("Cauchy median: {}", median);
#[cfg(feature = "std")]
std::println!("Cauchy median: {}", median);
assert!((median - 10.0).abs() < 0.4); // not 100% certain, but probable enough
let mean = sum / 1000.0;
println!("Cauchy mean: {}", mean);
#[cfg(feature = "std")]
std::println!("Cauchy mean: {}", mean);
// for a Cauchy distribution the mean should not converge
assert!((mean - 10.0).abs() > 0.4); // not 100% certain, but probable enough
}
Expand All @@ -130,8 +135,10 @@ mod test {

#[test]
fn value_stability() {
fn gen_samples<N: Float + core::fmt::Debug>(m: N, s: N, buf: &mut [N])
where Standard: Distribution<N> {


fn gen_samples<F: Float + FloatConst + core::fmt::Debug>(m: F, s: F, buf: &mut [F])
where Standard: Distribution<F> {
let distr = Cauchy::new(m, s).unwrap();
let mut rng = crate::test::rng(353);
for x in buf {
Expand Down
76 changes: 35 additions & 41 deletions rand_distr/src/dirichlet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
// except according to those terms.

//! The dirichlet distribution.

use crate::utils::Float;
#![cfg(feature = "alloc")]
use num_traits::Float;
use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal};
use rand::Rng;
use std::{error, fmt};
use core::fmt;
use alloc::{vec, vec::Vec};

/// The Dirichlet distribution `Dirichlet(alpha)`.
///
Expand All @@ -31,9 +32,15 @@ use std::{error, fmt};
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
/// ```
#[derive(Clone, Debug)]
pub struct Dirichlet<N> {
pub struct Dirichlet<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Concentration parameters (alpha)
alpha: Vec<N>,
alpha: Vec<F>,
}

/// Error type returned from `Dirchlet::new`.
Expand All @@ -58,25 +65,27 @@ impl fmt::Display for Error {
}
}

impl error::Error for Error {}
#[cfg(feature = "std")]
impl std::error::Error for Error {}

impl<N: Float> Dirichlet<N>
impl<F> Dirichlet<F>
where
StandardNormal: Distribution<N>,
Exp1: Distribution<N>,
Open01: Distribution<N>,
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Construct a new `Dirichlet` with the given alpha parameter `alpha`.
///
/// Requires `alpha.len() >= 2`.
#[inline]
pub fn new<V: Into<Vec<N>>>(alpha: V) -> Result<Dirichlet<N>, Error> {
pub fn new<V: Into<Vec<F>>>(alpha: V) -> Result<Dirichlet<F>, Error> {
let a = alpha.into();
if a.len() < 2 {
return Err(Error::AlphaTooShort);
}
for &ai in &a {
if !(ai > N::from(0.0)) {
if !(ai > F::zero()) {
return Err(Error::AlphaTooSmall);
}
}
Expand All @@ -88,8 +97,8 @@ where
///
/// Requires `size >= 2`.
#[inline]
pub fn new_with_size(alpha: N, size: usize) -> Result<Dirichlet<N>, Error> {
if !(alpha > N::from(0.0)) {
pub fn new_with_size(alpha: F, size: usize) -> Result<Dirichlet<F>, Error> {
if !(alpha > F::zero()) {
return Err(Error::AlphaTooSmall);
}
if size < 2 {
Expand All @@ -101,25 +110,26 @@ where
}
}

impl<N: Float> Distribution<Vec<N>> for Dirichlet<N>
impl<F> Distribution<Vec<F>> for Dirichlet<F>
where
StandardNormal: Distribution<N>,
Exp1: Distribution<N>,
Open01: Distribution<N>,
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<N> {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<F> {
let n = self.alpha.len();
let mut samples = vec![N::from(0.0); n];
let mut sum = N::from(0.0);
let mut samples = vec![F::zero(); n];
let mut sum = F::zero();

for (s, &a) in samples.iter_mut().zip(self.alpha.iter()) {
let g = Gamma::new(a, N::from(1.0)).unwrap();
let g = Gamma::new(a, F::one()).unwrap();
*s = g.sample(rng);
sum += *s;
sum = sum + (*s);
}
let invacc = N::from(1.0) / sum;
let invacc = F::one() / sum;
for s in samples.iter_mut() {
*s *= invacc;
*s = (*s)*invacc;
}
samples
}
Expand Down Expand Up @@ -170,20 +180,4 @@ mod test {
fn test_dirichlet_invalid_alpha() {
Dirichlet::new_with_size(0.0f64, 2).unwrap();
}

#[test]
fn value_stability() {
let mut rng = crate::test::rng(223);
assert_eq!(
rng.sample(Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap()),
vec![0.12941567177708177, 0.4702121891675036, 0.4003721390554146]
);
assert_eq!(rng.sample(Dirichlet::new_with_size(8.0, 5).unwrap()), vec![
0.17684200044809556,
0.29915953935953055,
0.1832858056608014,
0.1425623503573967,
0.19815030417417595
]);
}
}
Loading