Skip to content

Commit

Permalink
test(errors): test proposed error hierarchy on Normal
Browse files Browse the repository at this point in the history
also includes fmt on edited file
  • Loading branch information
YeungOnion committed Jun 23, 2024
1 parent fa17fbd commit df9c2f9
Showing 1 changed file with 36 additions and 5 deletions.
41 changes: 36 additions & 5 deletions src/distribution/normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,11 +378,15 @@ impl std::default::Default for Normal {
}

#[rustfmt::skip]
#[allow(clippy::excessive_precision)]
#[cfg(test)]
mod tests {
use crate::statistics::*;
use crate::distribution::{ContinuousCDF, Continuous, Normal};
use super::*;
use crate::distribution::internal::*;
use crate::distribution::{Continuous, ContinuousCDF, Normal};
use crate::statistics::*;

use std::error::Error as _;

fn try_create(mean: f64, std_dev: f64) -> Normal {
let n = Normal::new(mean, std_dev);
Expand All @@ -402,15 +406,17 @@ mod tests {
}

fn test_case<F>(mean: f64, std_dev: f64, expected: f64, eval: F)
where F: Fn(Normal) -> f64
where
F: Fn(Normal) -> f64,
{
let n = try_create(mean, std_dev);
let x = eval(n);
assert_eq!(expected, x);
}

fn test_almost<F>(mean: f64, std_dev: f64, expected: f64, acc: f64, eval: F)
where F: Fn(Normal) -> f64
where
F: Fn(Normal) -> f64,
{
let n = try_create(mean, std_dev);
let x = eval(n);
Expand Down Expand Up @@ -602,11 +608,36 @@ mod tests {
let n = Normal::default();

let n_mean = n.mean().unwrap();
let n_std = n.std_dev().unwrap();
let n_std = n.std_dev().unwrap();

// Check that the mean of the distribution is close to 0
assert_almost_eq!(n_mean, 0.0, 1e-15);
// Check that the standard deviation of the distribution is close to 1
assert_almost_eq!(n_std, 1.0, 1e-15);
}

#[test]
fn test_errors() {
let n = Normal::new(f64::NAN, f64::INFINITY);
assert!(
matches!(
n.err().unwrap(),
Error::InvalidMean(DistrError::InvalidConstruction(
StatsError::NotNan | StatsError::Finite(_)
))
),
"n = {}",
n.err().unwrap()
);

let n = Normal::new(0.0, 0.0);
assert!(
matches!(
n.err().unwrap(),
Error::InvalidStdDev(DistrError::DegenerateConstruction(0.0))
),
"n = {:?}",
n.err().unwrap()
);
}
}

0 comments on commit df9c2f9

Please sign in to comment.