Skip to content

Commit

Permalink
Use testing_boiler! for Uniform
Browse files Browse the repository at this point in the history
  • Loading branch information
FreezyLemon committed Sep 4, 2024
1 parent 6b0d71b commit 533a6bf
Showing 1 changed file with 89 additions and 126 deletions.
215 changes: 89 additions & 126 deletions src/distribution/uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,216 +284,179 @@ impl Continuous<f64, f64> for Uniform {
#[rustfmt::skip]
#[cfg(test)]
mod tests {
use crate::statistics::*;
use crate::distribution::{ContinuousCDF, Continuous, Uniform};
use crate::distribution::internal::*;
use crate::statistics::*;
use crate::testing_boiler;

fn try_create(min: f64, max: f64) -> Uniform {
let n = Uniform::new(min, max);
assert!(n.is_ok(), "failed create over interval [{}, {}]", min, max);
n.unwrap()
}

fn create_case(min: f64, max: f64) {
let n = try_create(min, max);
assert_eq!(n.min(), min);
assert_eq!(n.max(), max);
}

fn bad_create_case(min: f64, max: f64) {
let n = Uniform::new(min, max);
assert!(n.is_err());
}

fn get_value<F>(min: f64, max: f64, eval: F) -> f64
where F: Fn(Uniform) -> f64
{
let n = try_create(min, max);
eval(n)
}

fn test_case<F>(min: f64, max: f64, expected: f64, eval: F)
where F: Fn(Uniform) -> f64
{

let x = get_value(min, max, eval);
assert_eq!(expected, x);
}

fn test_almost<F>(min: f64, max: f64, expected: f64, acc: f64, eval: F)
where F: Fn(Uniform) -> f64
{

let x = get_value(min, max, eval);
assert_almost_eq!(expected, x, acc);
}
testing_boiler!(min: f64, max: f64; Uniform);

#[test]
fn test_create() {
create_case(0.0, 0.1);
create_case(0.0, 1.0);
create_case(-5.0, 11.0);
create_case(-5.0, 100.0);
create_ok(0.0, 0.1);
create_ok(0.0, 1.0);
create_ok(-5.0, 11.0);
create_ok(-5.0, 100.0);
}

#[test]
fn test_bad_create() {
bad_create_case(0.0, 0.0);
bad_create_case(f64::NAN, 1.0);
bad_create_case(1.0, f64::NAN);
bad_create_case(f64::NAN, f64::NAN);
bad_create_case(0.0, f64::INFINITY);
bad_create_case(1.0, 0.0);
create_err(0.0, 0.0);
create_err(f64::NAN, 1.0);
create_err(1.0, f64::NAN);
create_err(f64::NAN, f64::NAN);
create_err(0.0, f64::INFINITY);
create_err(1.0, 0.0);
}

#[test]
fn test_variance() {
let variance = |x: Uniform| x.variance().unwrap();
test_case(-0.0, 2.0, 1.0 / 3.0, variance);
test_case(0.0, 2.0, 1.0 / 3.0, variance);
test_almost(0.1, 4.0, 1.2675, 1e-15, variance);
test_case(10.0, 11.0, 1.0 / 12.0, variance);
test_exact(-0.0, 2.0, 1.0 / 3.0, variance);
test_exact(0.0, 2.0, 1.0 / 3.0, variance);
test_absolute(0.1, 4.0, 1.2675, 1e-15, variance);
test_exact(10.0, 11.0, 1.0 / 12.0, variance);
}

#[test]
fn test_entropy() {
let entropy = |x: Uniform| x.entropy().unwrap();
test_case(-0.0, 2.0, 0.6931471805599453094172, entropy);
test_case(0.0, 2.0, 0.6931471805599453094172, entropy);
test_almost(0.1, 4.0, 1.360976553135600743431, 1e-15, entropy);
test_case(1.0, 10.0, 2.19722457733621938279, entropy);
test_case(10.0, 11.0, 0.0, entropy);
test_exact(-0.0, 2.0, 0.6931471805599453094172, entropy);
test_exact(0.0, 2.0, 0.6931471805599453094172, entropy);
test_absolute(0.1, 4.0, 1.360976553135600743431, 1e-15, entropy);
test_exact(1.0, 10.0, 2.19722457733621938279, entropy);
test_exact(10.0, 11.0, 0.0, entropy);
}

#[test]
fn test_skewness() {
let skewness = |x: Uniform| x.skewness().unwrap();
test_case(-0.0, 2.0, 0.0, skewness);
test_case(0.0, 2.0, 0.0, skewness);
test_case(0.1, 4.0, 0.0, skewness);
test_case(1.0, 10.0, 0.0, skewness);
test_case(10.0, 11.0, 0.0, skewness);
test_exact(-0.0, 2.0, 0.0, skewness);
test_exact(0.0, 2.0, 0.0, skewness);
test_exact(0.1, 4.0, 0.0, skewness);
test_exact(1.0, 10.0, 0.0, skewness);
test_exact(10.0, 11.0, 0.0, skewness);
}

#[test]
fn test_mode() {
let mode = |x: Uniform| x.mode().unwrap();
test_case(-0.0, 2.0, 1.0, mode);
test_case(0.0, 2.0, 1.0, mode);
test_case(0.1, 4.0, 2.05, mode);
test_case(1.0, 10.0, 5.5, mode);
test_case(10.0, 11.0, 10.5, mode);
test_exact(-0.0, 2.0, 1.0, mode);
test_exact(0.0, 2.0, 1.0, mode);
test_exact(0.1, 4.0, 2.05, mode);
test_exact(1.0, 10.0, 5.5, mode);
test_exact(10.0, 11.0, 10.5, mode);
}

#[test]
fn test_median() {
let median = |x: Uniform| x.median();
test_case(-0.0, 2.0, 1.0, median);
test_case(0.0, 2.0, 1.0, median);
test_case(0.1, 4.0, 2.05, median);
test_case(1.0, 10.0, 5.5, median);
test_case(10.0, 11.0, 10.5, median);
test_exact(-0.0, 2.0, 1.0, median);
test_exact(0.0, 2.0, 1.0, median);
test_exact(0.1, 4.0, 2.05, median);
test_exact(1.0, 10.0, 5.5, median);
test_exact(10.0, 11.0, 10.5, median);
}

#[test]
fn test_pdf() {
let pdf = |arg: f64| move |x: Uniform| x.pdf(arg);
test_case(0.0, 0.1, 0.0, pdf(-5.0));
test_case(0.0, 0.1, 10.0, pdf(0.05));
test_case(0.0, 0.1, 0.0, pdf(5.0));
test_case(0.0, 1.0, 0.0, pdf(-5.0));
test_case(0.0, 1.0, 1.0, pdf(0.5));
test_case(0.0, 0.1, 0.0, pdf(5.0));
test_case(0.0, 10.0, 0.0, pdf(-5.0));
test_case(0.0, 10.0, 0.1, pdf(1.0));
test_case(0.0, 10.0, 0.1, pdf(5.0));
test_case(0.0, 10.0, 0.0, pdf(11.0));
test_case(-5.0, 100.0, 0.0, pdf(-10.0));
test_case(-5.0, 100.0, 0.009523809523809523809524, pdf(-5.0));
test_case(-5.0, 100.0, 0.009523809523809523809524, pdf(0.0));
test_case(-5.0, 100.0, 0.0, pdf(101.0));
test_exact(0.0, 0.1, 0.0, pdf(-5.0));
test_exact(0.0, 0.1, 10.0, pdf(0.05));
test_exact(0.0, 0.1, 0.0, pdf(5.0));
test_exact(0.0, 1.0, 0.0, pdf(-5.0));
test_exact(0.0, 1.0, 1.0, pdf(0.5));
test_exact(0.0, 0.1, 0.0, pdf(5.0));
test_exact(0.0, 10.0, 0.0, pdf(-5.0));
test_exact(0.0, 10.0, 0.1, pdf(1.0));
test_exact(0.0, 10.0, 0.1, pdf(5.0));
test_exact(0.0, 10.0, 0.0, pdf(11.0));
test_exact(-5.0, 100.0, 0.0, pdf(-10.0));
test_exact(-5.0, 100.0, 0.009523809523809523809524, pdf(-5.0));
test_exact(-5.0, 100.0, 0.009523809523809523809524, pdf(0.0));
test_exact(-5.0, 100.0, 0.0, pdf(101.0));
}

#[test]
fn test_ln_pdf() {
let ln_pdf = |arg: f64| move |x: Uniform| x.ln_pdf(arg);
test_case(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(-5.0));
test_almost(0.0, 0.1, 2.302585092994045684018, 1e-15, ln_pdf(0.05));
test_case(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(5.0));
test_case(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(-5.0));
test_case(0.0, 1.0, 0.0, ln_pdf(0.5));
test_case(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(5.0));
test_case(0.0, 10.0, f64::NEG_INFINITY, ln_pdf(-5.0));
test_case(0.0, 10.0, -2.302585092994045684018, ln_pdf(1.0));
test_case(0.0, 10.0, -2.302585092994045684018, ln_pdf(5.0));
test_case(0.0, 10.0, f64::NEG_INFINITY, ln_pdf(11.0));
test_case(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(-10.0));
test_case(-5.0, 100.0, -4.653960350157523371101, ln_pdf(-5.0));
test_case(-5.0, 100.0, -4.653960350157523371101, ln_pdf(0.0));
test_case(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(101.0));
test_exact(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(-5.0));
test_absolute(0.0, 0.1, 2.302585092994045684018, 1e-15, ln_pdf(0.05));
test_exact(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(5.0));
test_exact(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(-5.0));
test_exact(0.0, 1.0, 0.0, ln_pdf(0.5));
test_exact(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(5.0));
test_exact(0.0, 10.0, f64::NEG_INFINITY, ln_pdf(-5.0));
test_exact(0.0, 10.0, -2.302585092994045684018, ln_pdf(1.0));
test_exact(0.0, 10.0, -2.302585092994045684018, ln_pdf(5.0));
test_exact(0.0, 10.0, f64::NEG_INFINITY, ln_pdf(11.0));
test_exact(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(-10.0));
test_exact(-5.0, 100.0, -4.653960350157523371101, ln_pdf(-5.0));
test_exact(-5.0, 100.0, -4.653960350157523371101, ln_pdf(0.0));
test_exact(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(101.0));
}

#[test]
fn test_cdf() {
let cdf = |arg: f64| move |x: Uniform| x.cdf(arg);
test_case(0.0, 0.1, 0.5, cdf(0.05));
test_case(0.0, 1.0, 0.5, cdf(0.5));
test_case(0.0, 10.0, 0.1, cdf(1.0));
test_case(0.0, 10.0, 0.5, cdf(5.0));
test_case(-5.0, 100.0, 0.0, cdf(-5.0));
test_case(-5.0, 100.0, 0.04761904761904761904762, cdf(0.0));
test_exact(0.0, 0.1, 0.5, cdf(0.05));
test_exact(0.0, 1.0, 0.5, cdf(0.5));
test_exact(0.0, 10.0, 0.1, cdf(1.0));
test_exact(0.0, 10.0, 0.5, cdf(5.0));
test_exact(-5.0, 100.0, 0.0, cdf(-5.0));
test_exact(-5.0, 100.0, 0.04761904761904761904762, cdf(0.0));
}

#[test]
fn test_inverse_cdf() {
let inverse_cdf = |arg: f64| move |x: Uniform| x.inverse_cdf(arg);
test_case(0.0, 0.1, 0.05, inverse_cdf(0.5));
test_case(0.0, 10.0, 5.0, inverse_cdf(0.5));
test_case(1.0, 10.0, 1.0, inverse_cdf(0.0));
test_case(1.0, 10.0, 4.0, inverse_cdf(1.0 / 3.0));
test_case(1.0, 10.0, 10.0, inverse_cdf(1.0));
test_exact(0.0, 0.1, 0.05, inverse_cdf(0.5));
test_exact(0.0, 10.0, 5.0, inverse_cdf(0.5));
test_exact(1.0, 10.0, 1.0, inverse_cdf(0.0));
test_exact(1.0, 10.0, 4.0, inverse_cdf(1.0 / 3.0));
test_exact(1.0, 10.0, 10.0, inverse_cdf(1.0));
}

#[test]
fn test_cdf_lower_bound() {
let cdf = |arg: f64| move |x: Uniform| x.cdf(arg);
test_case(0.0, 3.0, 0.0, cdf(-1.0));
test_exact(0.0, 3.0, 0.0, cdf(-1.0));
}

#[test]
fn test_cdf_upper_bound() {
let cdf = |arg: f64| move |x: Uniform| x.cdf(arg);
test_case(0.0, 3.0, 1.0, cdf(5.0));
test_exact(0.0, 3.0, 1.0, cdf(5.0));
}


#[test]
fn test_sf() {
let sf = |arg: f64| move |x: Uniform| x.sf(arg);
test_case(0.0, 0.1, 0.5, sf(0.05));
test_case(0.0, 1.0, 0.5, sf(0.5));
test_case(0.0, 10.0, 0.9, sf(1.0));
test_case(0.0, 10.0, 0.5, sf(5.0));
test_case(-5.0, 100.0, 1.0, sf(-5.0));
test_case(-5.0, 100.0, 0.9523809523809523, sf(0.0));
test_exact(0.0, 0.1, 0.5, sf(0.05));
test_exact(0.0, 1.0, 0.5, sf(0.5));
test_exact(0.0, 10.0, 0.9, sf(1.0));
test_exact(0.0, 10.0, 0.5, sf(5.0));
test_exact(-5.0, 100.0, 1.0, sf(-5.0));
test_exact(-5.0, 100.0, 0.9523809523809523, sf(0.0));
}

#[test]
fn test_sf_lower_bound() {
let sf = |arg: f64| move |x: Uniform| x.sf(arg);
test_case(0.0, 3.0, 1.0, sf(-1.0));
test_exact(0.0, 3.0, 1.0, sf(-1.0));
}

#[test]
fn test_sf_upper_bound() {
let sf = |arg: f64| move |x: Uniform| x.sf(arg);
test_case(0.0, 3.0, 0.0, sf(5.0));
test_exact(0.0, 3.0, 0.0, sf(5.0));
}

#[test]
fn test_continuous() {
test::check_continuous_distribution(&try_create(0.0, 10.0), 0.0, 10.0);
test::check_continuous_distribution(&try_create(-2.0, 15.0), -2.0, 15.0);
test::check_continuous_distribution(&create_ok(0.0, 10.0), 0.0, 10.0);
test::check_continuous_distribution(&create_ok(-2.0, 15.0), -2.0, 15.0);
}

#[test]
Expand All @@ -511,7 +474,7 @@ mod tests {
let min = -0.5;
let max = 0.5;
let num_trials = 10_000;
let n = try_create(min, max);
let n = create_ok(min, max);

assert!((0..num_trials)
.map(|_| n.sample::<StdRng>(&mut r))
Expand Down

0 comments on commit 533a6bf

Please sign in to comment.