diff --git a/Cargo.toml b/Cargo.toml index 15b06e4e..b295a5fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,3 +44,8 @@ anyhow = "1.0" version = "0.32" default-features = false features = ["macros"] + +[lints.rust.unexpected_cfgs] +level = "warn" +# Set by cargo-llvm-cov when running on nightly +check-cfg = ['cfg(coverage_nightly)'] diff --git a/src/distribution/bernoulli.rs b/src/distribution/bernoulli.rs index f82c0d65..d5de981a 100644 --- a/src/distribution/bernoulli.rs +++ b/src/distribution/bernoulli.rs @@ -1,6 +1,5 @@ -use crate::distribution::{Binomial, Discrete, DiscreteCDF}; +use crate::distribution::{Binomial, BinomialError, Discrete, DiscreteCDF}; use crate::statistics::*; -use crate::Result; use rand::Rng; /// Implements the @@ -45,7 +44,7 @@ impl Bernoulli { /// result = Bernoulli::new(-0.5); /// assert!(result.is_err()); /// ``` - pub fn new(p: f64) -> Result { + pub fn new(p: f64) -> Result { Binomial::new(p, 1).map(|b| Bernoulli { b }) } @@ -265,11 +264,10 @@ impl Discrete for Bernoulli { #[rustfmt::skip] #[cfg(test)] mod testing { - use crate::distribution::DiscreteCDF; + use super::*; use crate::testing_boiler; - use super::Bernoulli; - testing_boiler!(p: f64; Bernoulli); + testing_boiler!(p: f64; Bernoulli; BinomialError); #[test] fn test_create() { diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index 4febc3c0..e20ea302 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -1,7 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::{beta, gamma}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; /// Implements the [Beta](https://en.wikipedia.org/wiki/Beta_distribution) @@ -24,6 +23,33 @@ pub struct Beta { shape_b: f64, } +/// Represents the errors that can occur when creating a [`Beta`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum BetaError { + /// Shape A is NaN, zero or negative. + ShapeAInvalid, + + /// Shape B is NaN, zero or negative. + ShapeBInvalid, + + /// Shape A and Shape B are infinite. + BothShapesInfinite, +} + +impl std::fmt::Display for BetaError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + BetaError::ShapeAInvalid => write!(f, "Shape A is NaN, zero or negative"), + BetaError::ShapeBInvalid => write!(f, "Shape B is NaN, zero or negative"), + BetaError::BothShapesInfinite => write!(f, "Shape A and shape B are infinite"), + } + } +} + +impl std::error::Error for BetaError {} + impl Beta { /// Constructs a new beta distribution with shapeA (α) of `shape_a` /// and shapeB (β) of `shape_b` @@ -44,15 +70,19 @@ impl Beta { /// result = Beta::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(shape_a: f64, shape_b: f64) -> Result { - if shape_a.is_nan() - || shape_b.is_nan() - || shape_a.is_infinite() && shape_b.is_infinite() - || shape_a <= 0.0 - || shape_b <= 0.0 - { - return Err(StatsError::BadParams); - }; + pub fn new(shape_a: f64, shape_b: f64) -> Result { + if shape_a.is_nan() || shape_a <= 0.0 { + return Err(BetaError::ShapeAInvalid); + } + + if shape_b.is_nan() || shape_b <= 0.0 { + return Err(BetaError::ShapeBInvalid); + } + + if shape_a.is_infinite() && shape_b.is_infinite() { + return Err(BetaError::BothShapesInfinite); + } + Ok(Beta { shape_a, shape_b }) } @@ -433,7 +463,7 @@ mod tests { use super::super::internal::*; use crate::testing_boiler; - testing_boiler!(a: f64, b: f64; Beta); + testing_boiler!(a: f64, b: f64; Beta; BetaError); #[test] fn test_create() { diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index 8eced1d7..1d86283d 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -1,7 +1,6 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::function::{beta, factorial}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -26,6 +25,25 @@ pub struct Binomial { n: u64, } +/// Represents the errors that can occur when creating a [`Binomial`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum BinomialError { + /// The probability is NaN or not in `[0, 1]`. + ProbabilityInvalid, +} + +impl std::fmt::Display for BinomialError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + BinomialError::ProbabilityInvalid => write!(f, "Probability is NaN or not in [0, 1]"), + } + } +} + +impl std::error::Error for BinomialError {} + impl Binomial { /// Constructs a new binomial distribution /// with a given `p` probability of success of `n` @@ -47,9 +65,9 @@ impl Binomial { /// result = Binomial::new(-0.5, 5); /// assert!(result.is_err()); /// ``` - pub fn new(p: f64, n: u64) -> Result { + pub fn new(p: f64, n: u64) -> Result { if p.is_nan() || !(0.0..=1.0).contains(&p) { - Err(StatsError::BadParams) + Err(BinomialError::ProbabilityInvalid) } else { Ok(Binomial { p, n }) } @@ -328,12 +346,11 @@ impl Discrete for Binomial { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::statistics::*; - use crate::distribution::{DiscreteCDF, Discrete, Binomial}; + use super::*; use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(p: f64, n: u64; Binomial); + testing_boiler!(p: f64, n: u64; Binomial; BinomialError); #[test] fn test_create() { diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index 71e09560..cb3c7ea8 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -1,6 +1,5 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -27,6 +26,36 @@ pub struct Categorical { sf: Vec, } +/// Represents the errors that can occur when creating a [`Categorical`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum CategoricalError { + /// The probability mass is empty. + ProbMassEmpty, + + /// The probabilities sums up to zero. + ProbMassSumZero, + + /// The probability mass contains at least one element which is NaN or less than zero. + ProbMassHasInvalidElements, +} + +impl std::fmt::Display for CategoricalError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + CategoricalError::ProbMassEmpty => write!(f, "Probability mass is empty"), + CategoricalError::ProbMassSumZero => write!(f, "Probabilities sum up to zero"), + CategoricalError::ProbMassHasInvalidElements => write!( + f, + "Probability mass contains at least one element which is NaN or less than zero" + ), + } + } +} + +impl std::error::Error for CategoricalError {} + impl Categorical { /// Constructs a new categorical distribution /// with the probabilities masses defined by `prob_mass` @@ -52,23 +81,36 @@ impl Categorical { /// result = Categorical::new(&[0.0, -1.0, 2.0]); /// assert!(result.is_err()); /// ``` - pub fn new(prob_mass: &[f64]) -> Result { - if !super::internal::is_valid_multinomial(prob_mass, true) { - Err(StatsError::BadParams) - } else { - // extract un-normalized cdf - let cdf = prob_mass_to_cdf(prob_mass); - // extract un-normalized sf - let sf = cdf_to_sf(&cdf); - // extract normalized probability mass - let sum = cdf[cdf.len() - 1]; - let mut norm_pmf = vec![0.0; prob_mass.len()]; - norm_pmf - .iter_mut() - .zip(prob_mass.iter()) - .for_each(|(np, pm)| *np = *pm / sum); - Ok(Categorical { norm_pmf, cdf, sf }) + pub fn new(prob_mass: &[f64]) -> Result { + if prob_mass.is_empty() { + return Err(CategoricalError::ProbMassEmpty); + } + + let mut prob_sum = 0.0; + for &p in prob_mass { + if p.is_nan() || p < 0.0 { + return Err(CategoricalError::ProbMassHasInvalidElements); + } + + prob_sum += p; } + + if prob_sum == 0.0 { + return Err(CategoricalError::ProbMassSumZero); + } + + // extract un-normalized cdf + let cdf = prob_mass_to_cdf(prob_mass); + // extract un-normalized sf + let sf = cdf_to_sf(&cdf); + // extract normalized probability mass + let sum = cdf[cdf.len() - 1]; + let mut norm_pmf = vec![0.0; prob_mass.len()]; + norm_pmf + .iter_mut() + .zip(prob_mass.iter()) + .for_each(|(np, pm)| *np = *pm / sum); + Ok(Categorical { norm_pmf, cdf, sf }) } fn cdf_max(&self) -> f64 { @@ -351,12 +393,11 @@ fn test_binary_index() { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::statistics::*; - use crate::distribution::{Categorical, Discrete, DiscreteCDF}; + use super::*; use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(prob_mass: &[f64]; Categorical); + testing_boiler!(prob_mass: &[f64]; Categorical; CategoricalError); #[test] fn test_create() { @@ -365,8 +406,15 @@ mod tests { #[test] fn test_bad_create() { - create_err(&[-1.0, 1.0]); - create_err(&[0.0, 0.0]); + let invalid: &[(&[f64], CategoricalError)] = &[ + (&[], CategoricalError::ProbMassEmpty), + (&[-1.0, 1.0], CategoricalError::ProbMassHasInvalidElements), + (&[0.0, 0.0, 0.0], CategoricalError::ProbMassSumZero), + ]; + + for &(prob_mass, err) in invalid { + test_create_err(prob_mass, err); + } } #[test] diff --git a/src/distribution/cauchy.rs b/src/distribution/cauchy.rs index eb983847..5ba7f69f 100644 --- a/src/distribution/cauchy.rs +++ b/src/distribution/cauchy.rs @@ -1,6 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -23,6 +22,29 @@ pub struct Cauchy { scale: f64, } +/// Represents the errors that can occur when creating a [`Cauchy`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum CauchyError { + /// The location is NaN. + LocationInvalid, + + /// The scale is NaN, zero or less than zero. + ScaleInvalid, +} + +impl std::fmt::Display for CauchyError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + CauchyError::LocationInvalid => write!(f, "Location is NaN"), + CauchyError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero"), + } + } +} + +impl std::error::Error for CauchyError {} + impl Cauchy { /// Constructs a new cauchy distribution with the given /// location and scale. @@ -42,12 +64,16 @@ impl Cauchy { /// result = Cauchy::new(0.0, -1.0); /// assert!(result.is_err()); /// ``` - pub fn new(location: f64, scale: f64) -> Result { - if location.is_nan() || scale.is_nan() || scale <= 0.0 { - Err(StatsError::BadParams) - } else { - Ok(Cauchy { location, scale }) + pub fn new(location: f64, scale: f64) -> Result { + if location.is_nan() { + return Err(CauchyError::LocationInvalid); + } + + if scale.is_nan() || scale <= 0.0 { + return Err(CauchyError::ScaleInvalid); } + + Ok(Cauchy { location, scale }) } /// Returns the location of the cauchy distribution @@ -252,11 +278,11 @@ impl Continuous for Cauchy { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::{statistics::*, testing_boiler}; - use crate::distribution::{ContinuousCDF, Continuous, Cauchy}; + use super::*; use crate::distribution::internal::*; + use crate::testing_boiler; - testing_boiler!(location: f64, scale: f64; Cauchy); + testing_boiler!(location: f64, scale: f64; Cauchy; CauchyError); #[test] fn test_create() { @@ -270,10 +296,16 @@ mod tests { #[test] fn test_bad_create() { - create_err(f64::NAN, 1.0); - create_err(1.0, f64::NAN); - create_err(f64::NAN, f64::NAN); - create_err(1.0, 0.0); + let invalid = [ + (f64::NAN, 1.0, CauchyError::LocationInvalid), + (1.0, f64::NAN, CauchyError::ScaleInvalid), + (f64::NAN, f64::NAN, CauchyError::LocationInvalid), + (1.0, 0.0, CauchyError::ScaleInvalid), + ]; + + for (location, scale, err) in invalid { + test_create_err(location, scale, err); + } } #[test] diff --git a/src/distribution/chi.rs b/src/distribution/chi.rs index 1bb74295..796fcd23 100644 --- a/src/distribution/chi.rs +++ b/src/distribution/chi.rs @@ -1,7 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -24,6 +23,27 @@ pub struct Chi { freedom: f64, } +/// Represents the errors that can occur when creating a [`Chi`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum ChiError { + /// The degrees of freedom are NaN, zero or less than zero. + FreedomInvalid, +} + +impl std::fmt::Display for ChiError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + ChiError::FreedomInvalid => { + write!(f, "Degrees of freedom are NaN, zero or less than zero") + } + } + } +} + +impl std::error::Error for ChiError {} + impl Chi { /// Constructs a new chi distribution /// with `freedom` degrees of freedom @@ -44,9 +64,9 @@ impl Chi { /// result = Chi::new(0.0); /// assert!(result.is_err()); /// ``` - pub fn new(freedom: f64) -> Result { + pub fn new(freedom: f64) -> Result { if freedom.is_nan() || freedom <= 0.0 { - Err(StatsError::BadParams) + Err(ChiError::FreedomInvalid) } else { Ok(Chi { freedom }) } @@ -325,13 +345,11 @@ impl Continuous for Chi { #[rustfmt::skip] #[cfg(test)] mod tests { - use std::f64; + use super::*; use crate::distribution::internal::*; - use crate::distribution::{Chi, Continuous, ContinuousCDF}; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(freedom: f64; Chi); + testing_boiler!(freedom: f64; Chi; ChiError); #[test] fn test_create() { diff --git a/src/distribution/chi_squared.rs b/src/distribution/chi_squared.rs index afa5df71..a847ac94 100644 --- a/src/distribution/chi_squared.rs +++ b/src/distribution/chi_squared.rs @@ -1,6 +1,5 @@ -use crate::distribution::{Continuous, ContinuousCDF, Gamma}; +use crate::distribution::{Continuous, ContinuousCDF, Gamma, GammaError}; use crate::statistics::*; -use crate::Result; use rand::Rng; use std::f64; @@ -48,7 +47,7 @@ impl ChiSquared { /// result = ChiSquared::new(0.0); /// assert!(result.is_err()); /// ``` - pub fn new(freedom: f64) -> Result { + pub fn new(freedom: f64) -> Result { Gamma::new(freedom / 2.0, 0.5).map(|g| ChiSquared { freedom, g }) } @@ -306,12 +305,11 @@ impl Continuous for ChiSquared { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::statistics::Median; - use crate::distribution::ChiSquared; + use super::*; use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(freedom: f64; ChiSquared); + testing_boiler!(freedom: f64; ChiSquared; GammaError); #[test] fn test_median() { diff --git a/src/distribution/dirac.rs b/src/distribution/dirac.rs index 41ac1d6c..18e70f9b 100644 --- a/src/distribution/dirac.rs +++ b/src/distribution/dirac.rs @@ -1,6 +1,5 @@ use crate::distribution::ContinuousCDF; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; /// Implements the [Dirac Delta](https://en.wikipedia.org/wiki/Dirac_delta_function#As_a_distribution) @@ -18,8 +17,27 @@ use rand::Rng; #[derive(Debug, Copy, Clone, PartialEq)] pub struct Dirac(f64); +/// Represents the errors that can occur when creating a [`Dirac`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum DiracError { + /// The value v is NaN. + ValueInvalid, +} + +impl std::fmt::Display for DiracError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + DiracError::ValueInvalid => write!(f, "Value v is NaN"), + } + } +} + +impl std::error::Error for DiracError {} + impl Dirac { - /// Constructs a new dirac distribution function at value `v`. + /// Constructs a new dirac distribution function at value `v`. /// /// # Errors /// @@ -36,9 +54,9 @@ impl Dirac { /// result = Dirac::new(f64::NAN); /// assert!(result.is_err()); /// ``` - pub fn new(v: f64) -> Result { + pub fn new(v: f64) -> Result { if v.is_nan() { - Err(StatsError::BadParams) + Err(DiracError::ValueInvalid) } else { Ok(Dirac(v)) } @@ -193,11 +211,10 @@ impl Mode> for Dirac { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{ContinuousCDF, Dirac}; - use crate::statistics::*; + use super::*; use crate::testing_boiler; - testing_boiler!(v: f64; Dirac); + testing_boiler!(v: f64; Dirac; DiracError); #[test] fn test_create() { diff --git a/src/distribution/dirichlet.rs b/src/distribution/dirichlet.rs index f058b46d..355476db 100644 --- a/src/distribution/dirichlet.rs +++ b/src/distribution/dirichlet.rs @@ -1,7 +1,7 @@ use crate::distribution::Continuous; use crate::function::gamma; +use crate::prec; use crate::statistics::*; -use crate::{prec, Result, StatsError}; use nalgebra::{Const, Dim, Dyn, OMatrix, OVector}; use rand::Rng; use std::f64; @@ -31,6 +31,32 @@ where alpha: OVector, } +/// Represents the errors that can occur when creating a [`Dirichlet`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum DirichletError { + /// Alpha contains less than two elements. + AlphaTooShort, + + /// Alpha contains an element that is NaN, infinite, zero or less than zero. + AlphaHasInvalidElements, +} + +impl std::fmt::Display for DirichletError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + DirichletError::AlphaTooShort => write!(f, "Alpha contains less than two elements"), + DirichletError::AlphaHasInvalidElements => write!( + f, + "Alpha contains an element that is NaN, infinite, zero or less than zero" + ), + } + } +} + +impl std::error::Error for DirichletError {} + impl Dirichlet { /// Constructs a new dirichlet distribution with the given /// concentration parameters (alpha) @@ -55,7 +81,7 @@ impl Dirichlet { /// result = Dirichlet::new(alpha_err); /// assert!(result.is_err()); /// ``` - pub fn new(alpha: Vec) -> Result { + pub fn new(alpha: Vec) -> Result { Self::new_from_nalgebra(alpha.into()) } @@ -78,7 +104,7 @@ impl Dirichlet { /// result = Dirichlet::new_with_param(0.0, 1); /// assert!(result.is_err()); /// ``` - pub fn new_with_param(alpha: f64, n: usize) -> Result { + pub fn new_with_param(alpha: f64, n: usize) -> Result { Self::new(vec![alpha; n]) } } @@ -95,12 +121,16 @@ where /// /// Returns an error if vector has length less than 2 or if any element /// of alpha is NOT finite positive - pub fn new_from_nalgebra(alpha: OVector) -> Result { - if !is_valid_alpha(alpha.as_slice()) { - Err(StatsError::BadParams) - } else { - Ok(Self { alpha }) + pub fn new_from_nalgebra(alpha: OVector) -> Result { + if alpha.len() < 2 { + return Err(DirichletError::AlphaTooShort); + } + + if alpha.iter().any(|&a_i| !a_i.is_finite() || a_i <= 0.0) { + return Err(DirichletError::AlphaHasInvalidElements); } + + Ok(Self { alpha }) } /// Returns the concentration parameters of @@ -336,25 +366,15 @@ where } } -// determines if `a` is a valid alpha array -// for the Dirichlet distribution -fn is_valid_alpha(a: &[f64]) -> bool { - a.len() >= 2 && a.iter().all(|&a_i| a_i.is_finite() && a_i > 0.0) -} - #[rustfmt::skip] #[cfg(test)] mod tests { + use super::*; + use std::fmt::{Debug, Display}; use nalgebra::{dmatrix, dvector, vector, DimMin, OVector}; - use super::is_valid_alpha; - use crate::{ - distribution::{Continuous, Dirichlet}, - statistics::{MeanN, VarianceN}, - }; - fn try_create(alpha: OVector) -> Dirichlet where D: DimMin, @@ -386,18 +406,9 @@ mod tests { assert_relative_eq!(expected, x, epsilon = acc); } - #[test] - fn test_is_valid_alpha() { - assert!(!is_valid_alpha(&[1.0])); - assert!(!is_valid_alpha(&[1.0, f64::NAN])); - assert!(is_valid_alpha(&[1.0, 2.0])); - assert!(!is_valid_alpha(&[1.0, 0.0])); - assert!(!is_valid_alpha(&[1.0, f64::INFINITY])); - assert!(!is_valid_alpha(&[-1.0, 2.0])); - } - #[test] fn test_create() { + try_create(vector![1.0, 2.0]); try_create(vector![1.0, 2.0, 3.0, 4.0, 5.0]); assert!(Dirichlet::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]).is_ok()); // try_create(vector![0.001, f64::INFINITY, 3756.0]); // moved to bad case as this is degenerate @@ -405,6 +416,10 @@ mod tests { #[test] fn test_bad_create() { + bad_create_case(vector![1.0, f64::NAN]); + bad_create_case(vector![1.0, 0.0]); + bad_create_case(vector![1.0, f64::INFINITY]); + bad_create_case(vector![-1.0, 2.0]); bad_create_case(vector![1.0]); bad_create_case(vector![1.0, 2.0, 0.0, 4.0, 5.0]); bad_create_case(vector![1.0, f64::NAN, 3.0, 4.0, 5.0]); @@ -562,4 +577,10 @@ mod tests { let n = try_create(vector![0.1, 0.3, 0.5, 0.8]); n.ln_pdf(&vector![0.5, 0.25, 0.8, 0.9]); } + + #[test] + fn test_error_is_sync_send() { + fn assert_sync_send() {} + assert_sync_send::(); + } } diff --git a/src/distribution/discrete_uniform.rs b/src/distribution/discrete_uniform.rs index 6871c80a..524bc2a3 100644 --- a/src/distribution/discrete_uniform.rs +++ b/src/distribution/discrete_uniform.rs @@ -1,6 +1,5 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; /// Implements the [Discrete @@ -23,6 +22,25 @@ pub struct DiscreteUniform { max: i64, } +/// Represents the errors that can occur when creating a [`DiscreteUniform`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum DiscreteUniformError { + /// The maximum is less than the minimum. + MinMaxInvalid, +} + +impl std::fmt::Display for DiscreteUniformError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + DiscreteUniformError::MinMaxInvalid => write!(f, "Maximum is less than minimum"), + } + } +} + +impl std::error::Error for DiscreteUniformError {} + impl DiscreteUniform { /// Constructs a new discrete uniform distribution with a minimum value /// of `min` and a maximum value of `max`. @@ -42,9 +60,9 @@ impl DiscreteUniform { /// result = DiscreteUniform::new(5, 0); /// assert!(result.is_err()); /// ``` - pub fn new(min: i64, max: i64) -> Result { + pub fn new(min: i64, max: i64) -> Result { if max < min { - Err(StatsError::BadParams) + Err(DiscreteUniformError::MinMaxInvalid) } else { Ok(DiscreteUniform { min, max }) } @@ -256,11 +274,10 @@ impl Discrete for DiscreteUniform { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{DiscreteCDF, Discrete, DiscreteUniform}; - use crate::statistics::*; + use super::*; use crate::testing_boiler; - testing_boiler!(min: i64, max: i64; DiscreteUniform); + testing_boiler!(min: i64, max: i64; DiscreteUniform; DiscreteUniformError); #[test] fn test_create() { diff --git a/src/distribution/empirical.rs b/src/distribution/empirical.rs index 104169aa..6dc7ec71 100644 --- a/src/distribution/empirical.rs +++ b/src/distribution/empirical.rs @@ -1,6 +1,5 @@ use crate::distribution::{ContinuousCDF, Uniform}; use crate::statistics::*; -use crate::Result; use core::cmp::Ordering; use rand::Rng; use std::collections::BTreeMap; @@ -48,6 +47,8 @@ impl Empirical { /// Constructs a new discrete uniform distribution with a minimum value /// of `min` and a maximum value of `max`. /// + /// Note that this will always succeed and never return the [`Err`][Result::Err] variant. + /// /// # Examples /// /// ``` @@ -56,7 +57,8 @@ impl Empirical { /// let mut result = Empirical::new(); /// assert!(result.is_ok()); /// ``` - pub fn new() -> Result { + #[allow(clippy::result_unit_err)] + pub fn new() -> Result { Ok(Empirical { sum: 0., mean_and_var: None, diff --git a/src/distribution/erlang.rs b/src/distribution/erlang.rs index ce6f68aa..9b7a332c 100644 --- a/src/distribution/erlang.rs +++ b/src/distribution/erlang.rs @@ -1,6 +1,5 @@ -use crate::distribution::{Continuous, ContinuousCDF, Gamma}; +use crate::distribution::{Continuous, ContinuousCDF, Gamma, GammaError}; use crate::statistics::*; -use crate::Result; use rand::Rng; /// Implements the [Erlang](https://en.wikipedia.org/wiki/Erlang_distribution) @@ -45,7 +44,7 @@ impl Erlang { /// result = Erlang::new(0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(shape: u64, rate: f64) -> Result { + pub fn new(shape: u64, rate: f64) -> Result { Gamma::new(shape as f64, rate).map(|g| Erlang { g }) } @@ -293,11 +292,11 @@ impl Continuous for Erlang { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::Erlang; + use super::*; use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(shape: u64, rate: f64; Erlang); + testing_boiler!(shape: u64, rate: f64; Erlang; GammaError); #[test] fn test_create() { @@ -310,10 +309,16 @@ mod tests { #[test] fn test_bad_create() { - create_err(0, 1.0); - create_err(1, 0.0); - create_err(1, f64::NAN); - create_err(1, -1.0); + let invalid = [ + (0, 1.0, GammaError::ShapeInvalid), + (1, 0.0, GammaError::RateInvalid), + (1, f64::NAN, GammaError::RateInvalid), + (1, -1.0, GammaError::RateInvalid), + ]; + + for (s, r, err) in invalid { + test_create_err(s, r, err); + } } #[test] diff --git a/src/distribution/exponential.rs b/src/distribution/exponential.rs index d5a54d56..ec30d1f7 100644 --- a/src/distribution/exponential.rs +++ b/src/distribution/exponential.rs @@ -1,6 +1,5 @@ use crate::distribution::{ziggurat, Continuous, ContinuousCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -25,13 +24,32 @@ pub struct Exp { rate: f64, } +/// Represents the errors that can occur when creating a [`Exp`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum ExpError { + /// The rate is NaN, zero or less than zero. + RateInvalid, +} + +impl std::fmt::Display for ExpError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + ExpError::RateInvalid => write!(f, "Rate is NaN, zero or less than zero"), + } + } +} + +impl std::error::Error for ExpError {} + impl Exp { /// Constructs a new exponential distribution with a /// rate (λ) of `rate`. /// /// # Errors /// - /// Returns an error if rate is `NaN` or `rate <= 0.0` + /// Returns an error if rate is `NaN` or `rate <= 0.0`. /// /// # Examples /// @@ -44,9 +62,9 @@ impl Exp { /// result = Exp::new(-1.0); /// assert!(result.is_err()); /// ``` - pub fn new(rate: f64) -> Result { + pub fn new(rate: f64) -> Result { if rate.is_nan() || rate <= 0.0 { - Err(StatsError::BadParams) + Err(ExpError::RateInvalid) } else { Ok(Exp { rate }) } @@ -279,13 +297,11 @@ impl Continuous for Exp { #[rustfmt::skip] #[cfg(test)] mod tests { - use std::f64; - use crate::distribution::{ContinuousCDF, Continuous, Exp}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(rate: f64; Exp); + testing_boiler!(rate: f64; Exp; ExpError); #[test] fn test_create() { diff --git a/src/distribution/fisher_snedecor.rs b/src/distribution/fisher_snedecor.rs index 9d5ef867..610da130 100644 --- a/src/distribution/fisher_snedecor.rs +++ b/src/distribution/fisher_snedecor.rs @@ -1,7 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::beta; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -26,6 +25,33 @@ pub struct FisherSnedecor { freedom_2: f64, } +/// Represents the errors that can occur when creating a [`FisherSnedecor`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum FisherSnedecorError { + /// `freedom_1` is NaN, infinite, zero or less than zero. + Freedom1Invalid, + + /// `freedom_2` is NaN, infinite, zero or less than zero. + Freedom2Invalid, +} + +impl std::fmt::Display for FisherSnedecorError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + FisherSnedecorError::Freedom1Invalid => { + write!(f, "freedom_1 is NaN, infinite, zero or less than zero.") + } + FisherSnedecorError::Freedom2Invalid => { + write!(f, "freedom_2 is NaN, infinite, zero or less than zero.") + } + } + } +} + +impl std::error::Error for FisherSnedecorError {} + impl FisherSnedecor { /// Constructs a new fisher-snedecor distribution with /// degrees of freedom `freedom_1` and `freedom_2` @@ -46,16 +72,19 @@ impl FisherSnedecor { /// result = FisherSnedecor::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(freedom_1: f64, freedom_2: f64) -> Result { - if !freedom_1.is_finite() || freedom_1 <= 0.0 || !freedom_2.is_finite() || freedom_2 <= 0.0 - { - Err(StatsError::BadParams) - } else { - Ok(FisherSnedecor { - freedom_1, - freedom_2, - }) + pub fn new(freedom_1: f64, freedom_2: f64) -> Result { + if !freedom_1.is_finite() || freedom_1 <= 0.0 { + return Err(FisherSnedecorError::Freedom1Invalid); + } + + if !freedom_2.is_finite() || freedom_2 <= 0.0 { + return Err(FisherSnedecorError::Freedom2Invalid); } + + Ok(FisherSnedecor { + freedom_1, + freedom_2, + }) } /// Returns the first degree of freedom for the @@ -385,12 +414,11 @@ impl Continuous for FisherSnedecor { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{ContinuousCDF, Continuous, FisherSnedecor}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(freedom_1: f64, freedom_2: f64; FisherSnedecor); + testing_boiler!(freedom_1: f64, freedom_2: f64; FisherSnedecor; FisherSnedecorError); #[test] fn test_create() { @@ -404,6 +432,9 @@ mod tests { #[test] fn test_bad_create() { + test_create_err(f64::INFINITY, 0.1, FisherSnedecorError::Freedom1Invalid); + test_create_err(0.1, f64::INFINITY, FisherSnedecorError::Freedom2Invalid); + create_err(f64::NAN, f64::NAN); create_err(0.0, f64::NAN); create_err(-1.0, f64::NAN); @@ -420,8 +451,6 @@ mod tests { create_err(0.0, -10.0); create_err(-1.0, -10.0); create_err(-10.0, -10.0); - create_err(f64::INFINITY, 0.1); - create_err(0.1, f64::INFINITY); create_err(f64::INFINITY, f64::INFINITY); } diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index e986037a..b6055c77 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -2,7 +2,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; use crate::prec; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; /// Implements the [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution) @@ -25,6 +24,33 @@ pub struct Gamma { rate: f64, } +/// Represents the errors that can occur when creating a [`Gamma`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum GammaError { + /// The shape is NaN, zero or less than zero. + ShapeInvalid, + + /// The rate is NaN, zero or less than zero. + RateInvalid, + + /// The shape and rate are both infinite. + ShapeAndRateInfinite, +} + +impl std::fmt::Display for GammaError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + GammaError::ShapeInvalid => write!(f, "Shape is NaN zero, or less than zero."), + GammaError::RateInvalid => write!(f, "Rate is NaN zero, or less than zero."), + GammaError::ShapeAndRateInfinite => write!(f, "Shape and rate are infinite"), + } + } +} + +impl std::error::Error for GammaError {} + impl Gamma { /// Constructs a new gamma distribution with a shape (α) /// of `shape` and a rate (β) of `rate` @@ -45,15 +71,19 @@ impl Gamma { /// result = Gamma::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(shape: f64, rate: f64) -> Result { - if shape.is_nan() - || rate.is_nan() - || shape.is_infinite() && rate.is_infinite() - || shape <= 0.0 - || rate <= 0.0 - { - return Err(StatsError::BadParams); + pub fn new(shape: f64, rate: f64) -> Result { + if shape.is_nan() || shape <= 0.0 { + return Err(GammaError::ShapeInvalid); + } + + if rate.is_nan() || rate <= 0.0 { + return Err(GammaError::RateInvalid); } + + if shape.is_infinite() && rate.is_infinite() { + return Err(GammaError::ShapeAndRateInfinite); + } + Ok(Gamma { shape, rate }) } @@ -406,7 +436,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(shape: f64, rate: f64; Gamma); + testing_boiler!(shape: f64, rate: f64; Gamma; GammaError); #[test] fn test_create() { @@ -426,15 +456,20 @@ mod tests { #[test] fn test_bad_create() { let invalid = [ - (0.0, 0.0), - (1.0, f64::NAN), - (1.0, -1.0), - (-1.0, 1.0), - (-1.0, -1.0), - (-1.0, f64::NAN), + (0.0, 0.0, GammaError::ShapeInvalid), + (1.0, f64::NAN, GammaError::RateInvalid), + (1.0, -1.0, GammaError::RateInvalid), + (-1.0, 1.0, GammaError::ShapeInvalid), + (-1.0, -1.0, GammaError::ShapeInvalid), + (-1.0, f64::NAN, GammaError::ShapeInvalid), + ( + f64::INFINITY, + f64::INFINITY, + GammaError::ShapeAndRateInfinite, + ), ]; - for (s, r) in invalid { - create_err(s, r); + for (s, r, err) in invalid { + test_create_err(s, r, err); } } diff --git a/src/distribution/geometric.rs b/src/distribution/geometric.rs index f584cc0e..41e35f52 100644 --- a/src/distribution/geometric.rs +++ b/src/distribution/geometric.rs @@ -1,6 +1,5 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::distributions::OpenClosed01; use rand::Rng; use std::f64; @@ -25,6 +24,25 @@ pub struct Geometric { p: f64, } +/// Represents the errors that can occur when creating a [`Geometric`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum GeometricError { + /// The probability is NaN or not in `(0, 1]`. + ProbabilityInvalid, +} + +impl std::fmt::Display for GeometricError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + GeometricError::ProbabilityInvalid => write!(f, "Probability is NaN or not in (0, 1]"), + } + } +} + +impl std::error::Error for GeometricError {} + impl Geometric { /// Constructs a new shifted geometric distribution with a probability /// of `p` @@ -44,9 +62,9 @@ impl Geometric { /// result = Geometric::new(0.0); /// assert!(result.is_err()); /// ``` - pub fn new(p: f64) -> Result { + pub fn new(p: f64) -> Result { if p <= 0.0 || p > 1.0 || p.is_nan() { - Err(StatsError::BadParams) + Err(GeometricError::ProbabilityInvalid) } else { Ok(Geometric { p }) } @@ -273,12 +291,11 @@ impl Discrete for Geometric { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{DiscreteCDF, Discrete, Geometric}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(p: f64; Geometric); + testing_boiler!(p: f64; Geometric; GeometricError); #[test] fn test_create() { diff --git a/src/distribution/hypergeometric.rs b/src/distribution/hypergeometric.rs index 43bc30ca..7da6f45a 100644 --- a/src/distribution/hypergeometric.rs +++ b/src/distribution/hypergeometric.rs @@ -1,7 +1,6 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::function::factorial; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::cmp; use std::f64; @@ -17,15 +16,38 @@ pub struct Hypergeometric { draws: u64, } +/// Represents the errors that can occur when creating a [`Hypergeometric`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum HypergeometricError { + /// The number of successes is greater than the population. + TooManySuccesses, + + /// The number of draws is greater than the population. + TooManyDraws, +} + +impl std::fmt::Display for HypergeometricError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + HypergeometricError::TooManySuccesses => write!(f, "successes > population"), + HypergeometricError::TooManyDraws => write!(f, "draws > population"), + } + } +} + +impl std::error::Error for HypergeometricError {} + impl Hypergeometric { /// Constructs a new hypergeometric distribution /// with a population (N) of `population`, number /// of successes (K) of `successes`, and number of draws - /// (n) of `draws` + /// (n) of `draws`. /// /// # Errors /// - /// If `successes > population` or `draws > population` + /// If `successes > population` or `draws > population`. /// /// # Examples /// @@ -38,16 +60,24 @@ impl Hypergeometric { /// result = Hypergeometric::new(2, 3, 2); /// assert!(result.is_err()); /// ``` - pub fn new(population: u64, successes: u64, draws: u64) -> Result { - if successes > population || draws > population { - Err(StatsError::BadParams) - } else { - Ok(Hypergeometric { - population, - successes, - draws, - }) + pub fn new( + population: u64, + successes: u64, + draws: u64, + ) -> Result { + if successes > population { + return Err(HypergeometricError::TooManySuccesses); } + + if draws > population { + return Err(HypergeometricError::TooManyDraws); + } + + Ok(Hypergeometric { + population, + successes, + draws, + }) } /// Returns the population size of the hypergeometric @@ -372,12 +402,11 @@ impl Discrete for Hypergeometric { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{DiscreteCDF, Discrete, Hypergeometric}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(population: u64, successes: u64, draws: u64; Hypergeometric); + testing_boiler!(population: u64, successes: u64, draws: u64; Hypergeometric; HypergeometricError); #[test] fn test_create() { @@ -391,8 +420,8 @@ mod tests { #[test] fn test_bad_create() { - create_err(2, 3, 2); - create_err(10, 5, 20); + test_create_err(2, 3, 2, HypergeometricError::TooManySuccesses); + test_create_err(10, 5, 20, HypergeometricError::TooManyDraws); create_err(0, 1, 1); } diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index ae853637..9e7651b0 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -1,56 +1,5 @@ use num_traits::Num; -/// Returns true if there are no elements in `x` in `arr` -/// such that `x <= 0.0` or `x` is `f64::NAN` and `sum(arr) > 0.0`. -/// IF `incl_zero` is true, it tests for `x < 0.0` instead of `x <= 0.0` -pub fn is_valid_multinomial(arr: &[f64], incl_zero: bool) -> bool { - let mut sum = 0.0; - for &elt in arr { - if incl_zero && elt < 0.0 || !incl_zero && elt <= 0.0 || elt.is_nan() { - return false; - } - sum += elt; - } - sum != 0.0 -} - -#[cfg(feature = "nalgebra")] -use nalgebra::{Dim, OVector}; - -#[cfg(feature = "nalgebra")] -pub fn check_multinomial(arr: &OVector, accept_zeroes: bool) -> crate::Result<()> -where - D: Dim, - nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, -{ - use crate::StatsError; - - if arr.len() < 2 { - return Err(StatsError::BadParams); - } - let mut sum = 0.0; - for &x in arr.iter() { - #[allow(clippy::if_same_then_else)] - if x.is_nan() { - return Err(StatsError::BadParams); - } else if x.is_infinite() { - return Err(StatsError::BadParams); - } else if x < 0.0 { - return Err(StatsError::BadParams); - } else if x == 0.0 && !accept_zeroes { - return Err(StatsError::BadParams); - } else { - sum += x; - } - } - - if sum != 0.0 { - Ok(()) - } else { - Err(StatsError::BadParams) - } -} - /// Implements univariate function bisection searching for criteria /// ```text /// smallest k such that f(k) >= z @@ -100,7 +49,7 @@ pub mod test { #[macro_export] macro_rules! testing_boiler { - ($($arg_name:ident: $arg_ty:ty),+; $dist:ty) => { + ($($arg_name:ident: $arg_ty:ty),+; $dist:ty; $dist_err:ty) => { fn make_param_text($($arg_name: $arg_ty),+) -> String { // "" let mut param_text = String::new(); @@ -140,7 +89,7 @@ pub mod test { /// Returns the error when creating a distribution with the given parameters, /// panicking if `::new` succeeds. #[allow(dead_code)] - fn create_err($($arg_name: $arg_ty),+) -> $crate::StatsError { + fn create_err($($arg_name: $arg_ty),+) -> $dist_err { match <$dist>::new($($arg_name),+) { Err(e) => e, Ok(d) => panic!( @@ -240,6 +189,25 @@ pub mod test { } } + /// Purposely fails creating a distribution with the given + /// parameters and compares the returned error to `expected`. + /// + /// Panics if `::new` succeeds. + #[allow(dead_code)] + fn test_create_err($($arg_name: $arg_ty),+, expected: $dist_err) + { + let err = create_err($($arg_name),+); + if err != expected { + panic!( + "{}::new was expected to fail with error {:?}, but failed with error {:?} for {}", + stringify!($dist), + expected, + err, + make_param_text($($arg_name),+) + ) + } + } + /// Gets a value for the given parameters by calling `create_and_get` /// and asserts that it is [`NAN`]. /// @@ -273,96 +241,112 @@ pub mod test { ) } } + + /// Asserts that associated error type is Send and Sync + #[test] + fn test_error_is_sync_send() { + fn assert_sync_send() {} + assert_sync_send::<$dist_err>(); + } }; } pub mod boiler_tests { - use crate::distribution::Binomial; + use crate::distribution::{Beta, BetaError}; use crate::statistics::*; - use crate::StatsError; - testing_boiler!(p: f64, n: u64; Binomial); + testing_boiler!(shape_a: f64, shape_b: f64; Beta; BetaError); #[test] fn create_ok_success() { - let b = create_ok(0.8, 1200); - assert_eq!(b.p(), 0.8); - assert_eq!(b.n(), 1200); + let b = create_ok(0.8, 1.2); + assert_eq!(b.shape_a(), 0.8); + assert_eq!(b.shape_b(), 1.2); } #[test] #[should_panic] fn create_err_failure() { - create_err(0.8, 1200); + create_err(0.8, 1.2); } #[test] fn create_err_success() { - let err = create_err(-0.5, 1000); - assert_eq!(err, StatsError::BadParams); + let err = create_err(-0.5, 1.2); + assert_eq!(err, BetaError::ShapeAInvalid); } #[test] #[should_panic] fn create_ok_failure() { - create_ok(-0.5, 1000); + create_ok(-0.5, 1.2); } #[test] fn test_exact_success() { - test_exact(0.0, 4, 0.0, |dist| dist.mean().unwrap()); + test_exact(1.5, 1.5, 0.5, |dist| dist.mode().unwrap()); } #[test] #[should_panic] fn test_exact_failure() { - test_exact(0.3, 3, 0.9, |dist| dist.mean().unwrap()); + test_exact(1.2, 1.4, 0.333333333333, |dist| dist.mode().unwrap()); } #[test] fn test_relative_success() { - test_relative(0.3, 3, 0.9, |dist| dist.mean().unwrap()); + test_relative(1.2, 1.4, 0.333333333333, |dist| dist.mode().unwrap()); } #[test] #[should_panic] fn test_relative_failure() { - test_relative(0.3, 3, 0.8, |dist| dist.mean().unwrap()); + test_relative(1.2, 1.4, 0.333, |dist| dist.mode().unwrap()); } #[test] fn test_absolute_success() { - test_absolute(0.3, 3, 0.9, 1e-15, |dist| dist.mean().unwrap()); + test_absolute(1.2, 1.4, 0.333333333333, 1e-12, |dist| dist.mode().unwrap()); } #[test] #[should_panic] fn test_absolute_failure() { - test_absolute(0.3, 3, 0.9, 1e-17, |dist| dist.mean().unwrap()); + test_absolute(1.2, 1.4, 0.333333333333, 1e-15, |dist| dist.mode().unwrap()); + } + + #[test] + fn test_create_err_success() { + test_create_err(0.0, 0.5, BetaError::ShapeAInvalid); + } + + #[test] + #[should_panic] + fn test_create_err_failure() { + test_create_err(0.0, 0.5, BetaError::BothShapesInfinite); } #[test] fn test_is_nan_success() { - // Not sure that any Binomial API can return a NaN, so we force the issue - test_is_nan(0.8, 1200, |_| f64::NAN); + // Not sure that any Beta API can return a NaN, so we force the issue + test_is_nan(0.8, 1.2, |_| f64::NAN); } #[test] #[should_panic] fn test_is_nan_failure() { - test_is_nan(0.8, 1200, |dist| dist.mean().unwrap()); + test_is_nan(0.8, 1.2, |dist| dist.mean().unwrap()); } #[test] fn test_is_none_success() { - // Same as test_is_nan_success, force returning `None` here - test_none(0.8, 1200, |_| Option::::None); + test_none(f64::INFINITY, 1.2, |dist| dist.entropy()); } #[test] #[should_panic] fn test_is_none_failure() { - test_none(0.8, 1200, |dist| dist.mean()); + test_none(0.8, 1.2, |dist| dist.mean()); } } @@ -471,31 +455,6 @@ pub mod test { check_sum_pmf_is_cdf(dist, x_max); } - #[cfg(feature = "nalgebra")] - #[test] - fn test_is_valid_multinomial() { - use std::f64; - - let invalid = [1.0, f64::NAN, 3.0]; - assert!(!is_valid_multinomial(&invalid, true)); - assert!(check_multinomial(&invalid.to_vec().into(), true).is_err()); - let invalid2 = [-2.0, 5.0, 1.0, 6.2]; - assert!(!is_valid_multinomial(&invalid2, true)); - assert!(check_multinomial(&invalid2.to_vec().into(), true).is_err()); - let invalid3 = [0.0, 0.0, 0.0]; - assert!(!is_valid_multinomial(&invalid3, true)); - assert!(check_multinomial(&invalid3.to_vec().into(), true).is_err()); - let valid = [5.2, 0.0, 1e-15, 1000000.12]; - assert!(is_valid_multinomial(&valid, true)); - assert!(check_multinomial(&valid.to_vec().into(), true).is_ok()); - } - - #[test] - fn test_is_valid_multinomial_no_zero() { - let invalid = [5.2, 0.0, 1e-15, 1000000.12]; - assert!(!is_valid_multinomial(&invalid, false)); - } - #[test] fn test_integer_bisection() { fn search(z: usize, data: &[usize]) -> Option { diff --git a/src/distribution/inverse_gamma.rs b/src/distribution/inverse_gamma.rs index 8314a9dc..a36c7e17 100644 --- a/src/distribution/inverse_gamma.rs +++ b/src/distribution/inverse_gamma.rs @@ -1,7 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -26,6 +25,33 @@ pub struct InverseGamma { rate: f64, } +/// Represents the errors that can occur when creating an [`InverseGamma`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum InverseGammaError { + /// The shape is NaN, infinite, zero or less than zero. + ShapeInvalid, + + /// The rate is NaN, infinite, zero or less than zero. + RateInvalid, +} + +impl std::fmt::Display for InverseGammaError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + InverseGammaError::ShapeInvalid => { + write!(f, "Shape is NaN, infinite, zero or less than zero") + } + InverseGammaError::RateInvalid => { + write!(f, "Rate is NaN, infinite, zero or less than zero") + } + } + } +} + +impl std::error::Error for InverseGammaError {} + impl InverseGamma { /// Constructs a new inverse gamma distribution with a shape (α) /// of `shape` and a rate (β) of `rate` @@ -46,16 +72,16 @@ impl InverseGamma { /// result = InverseGamma::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(shape: f64, rate: f64) -> Result { - let is_nan = shape.is_nan() || rate.is_nan(); - match (shape, rate, is_nan) { - (_, _, true) => Err(StatsError::BadParams), - (_, _, false) if shape <= 0.0 || rate <= 0.0 => Err(StatsError::BadParams), - (_, _, false) if shape.is_infinite() || rate.is_infinite() => { - Err(StatsError::BadParams) - } - (_, _, false) => Ok(InverseGamma { shape, rate }), + pub fn new(shape: f64, rate: f64) -> Result { + if shape.is_nan() || shape.is_infinite() || shape <= 0.0 { + return Err(InverseGammaError::ShapeInvalid); } + + if rate.is_nan() || rate.is_infinite() || rate <= 0.0 { + return Err(InverseGammaError::RateInvalid); + } + + Ok(InverseGamma { shape, rate }) } /// Returns the shape (α) of the inverse gamma distribution @@ -313,12 +339,11 @@ impl Continuous for InverseGamma { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{ContinuousCDF, Continuous, InverseGamma}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(shape: f64, rate: f64; InverseGamma); + testing_boiler!(shape: f64, rate: f64; InverseGamma; InverseGammaError); #[test] fn test_create() { @@ -328,13 +353,13 @@ mod tests { #[test] fn test_bad_create() { - create_err(0.0, 1.0); + test_create_err(0.0, 1.0, InverseGammaError::ShapeInvalid); + test_create_err(1.0, -1.0, InverseGammaError::RateInvalid); create_err(-1.0, 1.0); create_err(-100.0, 1.0); create_err(f64::NEG_INFINITY, 1.0); create_err(f64::NAN, 1.0); create_err(1.0, 0.0); - create_err(1.0, -1.0); create_err(1.0, -100.0); create_err(1.0, f64::NEG_INFINITY); create_err(1.0, f64::NAN); diff --git a/src/distribution/laplace.rs b/src/distribution/laplace.rs index 0ccbac78..13f03a55 100644 --- a/src/distribution/laplace.rs +++ b/src/distribution/laplace.rs @@ -1,6 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::{Distribution, Max, Median, Min, Mode}; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -23,6 +22,29 @@ pub struct Laplace { scale: f64, } +/// Represents the errors that can occur when creating a [`Laplace`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum LaplaceError { + /// The location is NaN. + LocationInvalid, + + /// The scale is NaN, zero or less than zero. + ScaleInvalid, +} + +impl std::fmt::Display for LaplaceError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + LaplaceError::LocationInvalid => write!(f, "Location is NaN"), + LaplaceError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero"), + } + } +} + +impl std::error::Error for LaplaceError {} + impl Laplace { /// Constructs a new laplace distribution with the given /// location and scale. @@ -42,12 +64,16 @@ impl Laplace { /// result = Laplace::new(0.0, -1.0); /// assert!(result.is_err()); /// ``` - pub fn new(location: f64, scale: f64) -> Result { - if location.is_nan() || scale.is_nan() || scale <= 0.0 { - Err(StatsError::BadParams) - } else { - Ok(Laplace { location, scale }) + pub fn new(location: f64, scale: f64) -> Result { + if location.is_nan() { + return Err(LaplaceError::LocationInvalid); } + + if scale.is_nan() || scale <= 0.0 { + return Err(LaplaceError::ScaleInvalid); + } + + Ok(Laplace { location, scale }) } /// Returns the location of the laplace distribution @@ -304,7 +330,7 @@ mod tests { use crate::testing_boiler; - testing_boiler!(location: f64, scale: f64; Laplace); + testing_boiler!(location: f64, scale: f64; Laplace; LaplaceError); // A wrapper for the `assert_relative_eq!` macro from the approx crate. // @@ -332,8 +358,8 @@ mod tests { #[test] fn test_bad_create() { - create_err(2.0, -1.0); - create_err(f64::NAN, 1.0); + test_create_err(2.0, -1.0, LaplaceError::ScaleInvalid); + test_create_err(f64::NAN, 1.0, LaplaceError::LocationInvalid); create_err(f64::NAN, -1.0); } diff --git a/src/distribution/log_normal.rs b/src/distribution/log_normal.rs index 88a78996..49380d2b 100644 --- a/src/distribution/log_normal.rs +++ b/src/distribution/log_normal.rs @@ -1,7 +1,7 @@ +use crate::consts; use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::erf; use crate::statistics::*; -use crate::{consts, Result, StatsError}; use rand::Rng; use std::f64; @@ -26,6 +26,29 @@ pub struct LogNormal { scale: f64, } +/// Represents the errors that can occur when creating a [`LogNormal`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum LogNormalError { + /// The location is NaN. + LocationInvalid, + + /// The scale is NaN, zero or less than zero. + ScaleInvalid, +} + +impl std::fmt::Display for LogNormalError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + LogNormalError::LocationInvalid => write!(f, "Location is NaN"), + LogNormalError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero"), + } + } +} + +impl std::error::Error for LogNormalError {} + impl LogNormal { /// Constructs a new log-normal distribution with a location of `location` /// and a scale of `scale` @@ -46,12 +69,16 @@ impl LogNormal { /// result = LogNormal::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(location: f64, scale: f64) -> Result { - if location.is_nan() || scale.is_nan() || scale <= 0.0 { - Err(StatsError::BadParams) - } else { - Ok(LogNormal { location, scale }) + pub fn new(location: f64, scale: f64) -> Result { + if location.is_nan() { + return Err(LogNormalError::LocationInvalid); } + + if scale.is_nan() || scale <= 0.0 { + return Err(LogNormalError::ScaleInvalid); + } + + Ok(LogNormal { location, scale }) } } @@ -305,12 +332,11 @@ impl Continuous for LogNormal { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{ContinuousCDF, Continuous, LogNormal}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(mean: f64, std_dev: f64; LogNormal); + testing_boiler!(location: f64, scale: f64; LogNormal; LogNormalError); #[test] fn test_create() { @@ -323,9 +349,9 @@ mod tests { #[test] fn test_bad_create() { + test_create_err(f64::NAN, 1.0, LogNormalError::LocationInvalid); + test_create_err(1.0, f64::NAN, LogNormalError::ScaleInvalid); create_err(0.0, 0.0); - create_err(f64::NAN, 1.0); - create_err(1.0, f64::NAN); create_err(f64::NAN, f64::NAN); create_err(1.0, -1.0); } diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index 6e43db8e..8955ed63 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -6,40 +6,40 @@ use ::num_traits::{Float, Num}; use num_traits::NumAssignOps; pub use self::bernoulli::Bernoulli; -pub use self::beta::Beta; -pub use self::binomial::Binomial; -pub use self::categorical::Categorical; -pub use self::cauchy::Cauchy; -pub use self::chi::Chi; +pub use self::beta::{Beta, BetaError}; +pub use self::binomial::{Binomial, BinomialError}; +pub use self::categorical::{Categorical, CategoricalError}; +pub use self::cauchy::{Cauchy, CauchyError}; +pub use self::chi::{Chi, ChiError}; pub use self::chi_squared::ChiSquared; -pub use self::dirac::Dirac; +pub use self::dirac::{Dirac, DiracError}; #[cfg(feature = "nalgebra")] -pub use self::dirichlet::Dirichlet; -pub use self::discrete_uniform::DiscreteUniform; +pub use self::dirichlet::{Dirichlet, DirichletError}; +pub use self::discrete_uniform::{DiscreteUniform, DiscreteUniformError}; pub use self::empirical::Empirical; pub use self::erlang::Erlang; -pub use self::exponential::Exp; -pub use self::fisher_snedecor::FisherSnedecor; -pub use self::gamma::Gamma; -pub use self::geometric::Geometric; -pub use self::hypergeometric::Hypergeometric; -pub use self::inverse_gamma::InverseGamma; -pub use self::laplace::Laplace; -pub use self::log_normal::LogNormal; +pub use self::exponential::{Exp, ExpError}; +pub use self::fisher_snedecor::{FisherSnedecor, FisherSnedecorError}; +pub use self::gamma::{Gamma, GammaError}; +pub use self::geometric::{Geometric, GeometricError}; +pub use self::hypergeometric::{Hypergeometric, HypergeometricError}; +pub use self::inverse_gamma::{InverseGamma, InverseGammaError}; +pub use self::laplace::{Laplace, LaplaceError}; +pub use self::log_normal::{LogNormal, LogNormalError}; #[cfg(feature = "nalgebra")] -pub use self::multinomial::Multinomial; +pub use self::multinomial::{Multinomial, MultinomialError}; #[cfg(feature = "nalgebra")] -pub use self::multivariate_normal::MultivariateNormal; +pub use self::multivariate_normal::{MultivariateNormal, MultivariateNormalError}; #[cfg(feature = "nalgebra")] -pub use self::multivariate_students_t::MultivariateStudent; -pub use self::negative_binomial::NegativeBinomial; -pub use self::normal::Normal; -pub use self::pareto::Pareto; -pub use self::poisson::Poisson; -pub use self::students_t::StudentsT; -pub use self::triangular::Triangular; -pub use self::uniform::Uniform; -pub use self::weibull::Weibull; +pub use self::multivariate_students_t::{MultivariateStudent, MultivariateStudentError}; +pub use self::negative_binomial::{NegativeBinomial, NegativeBinomialError}; +pub use self::normal::{Normal, NormalError}; +pub use self::pareto::{Pareto, ParetoError}; +pub use self::poisson::{Poisson, PoissonError}; +pub use self::students_t::{StudentsT, StudentsTError}; +pub use self::triangular::{Triangular, TriangularError}; +pub use self::uniform::{Uniform, UniformError}; +pub use self::weibull::{Weibull, WeibullError}; mod bernoulli; mod beta; diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index dc402050..7d1b408c 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -1,7 +1,6 @@ use crate::distribution::Discrete; use crate::function::factorial; use crate::statistics::*; -use crate::Result; use nalgebra::{Const, DVector, Dim, Dyn, OMatrix, OVector}; use rand::Rng; @@ -33,6 +32,36 @@ where n: u64, } +/// Represents the errors that can occur when creating a [`Multinomial`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum MultinomialError { + /// Fewer than two probabilities. + NotEnoughProbabilities, + + /// The sum of all probabilities is zero. + ProbabilitySumZero, + + /// At least one probability is NaN, infinite or less than zero. + ProbabilityInvalid, +} + +impl std::fmt::Display for MultinomialError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + MultinomialError::NotEnoughProbabilities => write!(f, "Fewer than two probabilities"), + MultinomialError::ProbabilitySumZero => write!(f, "The probabilities sum up to zero"), + MultinomialError::ProbabilityInvalid => write!( + f, + "At least one probability is NaN, infinity or less than zero" + ), + } + } +} + +impl std::error::Error for MultinomialError {} + impl Multinomial { /// Constructs a new multinomial distribution with probabilities `p` /// and `n` number of trials. @@ -57,7 +86,7 @@ impl Multinomial { /// result = Multinomial::new(vec![0.0, -1.0, 2.0], 3); /// assert!(result.is_err()); /// ``` - pub fn new(p: Vec, n: u64) -> Result { + pub fn new(p: Vec, n: u64) -> Result { Self::new_from_nalgebra(p.into(), n) } } @@ -67,14 +96,26 @@ where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { - pub fn new_from_nalgebra(mut p: OVector, n: u64) -> Result { - match super::internal::check_multinomial(&p, true) { - Err(e) => Err(e), - Ok(_) => { - p.unscale_mut(p.lp_norm(1)); - Ok(Self { p, n }) + pub fn new_from_nalgebra(mut p: OVector, n: u64) -> Result { + if p.len() < 2 { + return Err(MultinomialError::NotEnoughProbabilities); + } + + let mut sum = 0.0; + for &val in &p { + if val.is_nan() || val < 0.0 { + return Err(MultinomialError::ProbabilityInvalid); } + + sum += val; + } + + if sum == 0.0 { + return Err(MultinomialError::ProbabilitySumZero); } + + p.unscale_mut(p.lp_norm(1)); + Ok(Self { p, n }) } /// Returns the probabilities of the multinomial @@ -295,7 +336,7 @@ where #[cfg(test)] mod tests { use crate::{ - distribution::{Discrete, Multinomial}, + distribution::{Discrete, Multinomial, MultinomialError}, statistics::{MeanN, VarianceN}, }; use nalgebra::{dmatrix, dvector, vector, DimMin, Dyn, OVector}; @@ -311,7 +352,7 @@ mod tests { mvn.unwrap() } - fn bad_create_case(p: OVector, n: u64) -> crate::StatsError + fn bad_create_case(p: OVector, n: u64) -> MultinomialError where D: DimMin, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, @@ -344,18 +385,23 @@ mod tests { #[test] fn test_bad_create() { + assert_eq!( + bad_create_case(vector![0.5], 4), + MultinomialError::NotEnoughProbabilities, + ); + assert_eq!( bad_create_case(vector![-1.0, 2.0], 4), - crate::StatsError::BadParams + MultinomialError::ProbabilityInvalid, ); assert_eq!( bad_create_case(vector![0.0, 0.0], 4), - crate::StatsError::BadParams + MultinomialError::ProbabilitySumZero, ); assert_eq!( bad_create_case(vector![1.0, f64::NAN], 4), - crate::StatsError::BadParams + MultinomialError::ProbabilityInvalid, ); } @@ -454,6 +500,12 @@ mod tests { ); } + #[test] + fn test_error_is_sync_send() { + fn assert_sync_send() {} + assert_sync_send::(); + } + // #[test] // #[should_panic] // fn test_pmf_x_wrong_length() { diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 61ccfa57..eb86edd3 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -108,15 +108,54 @@ where pdf_const: f64, } +/// Represents the errors that can occur when creating a [`MultivariateNormal`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum MultivariateNormalError { + /// The covariance matrix is asymmetric or contains a NaN. + CovInvalid, + + /// The mean vector contains a NaN. + MeanInvalid, + + /// The amount of rows in the vector of means is not equal to the amount + /// of rows in the covariance matrix. + DimensionMismatch, + + /// After all other validation, computing the Cholesky decomposition failed. + CholeskyFailed, +} + +impl std::fmt::Display for MultivariateNormalError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + MultivariateNormalError::CovInvalid => { + write!(f, "Covariance matrix is asymmetric or contains a NaN") + } + MultivariateNormalError::MeanInvalid => write!(f, "Mean vector contains a NaN"), + MultivariateNormalError::DimensionMismatch => write!( + f, + "Mean vector and covariance matrix do not have the same number of rows" + ), + MultivariateNormalError::CholeskyFailed => { + write!(f, "Computing the Cholesky decomposition failed") + } + } + } +} + +impl std::error::Error for MultivariateNormalError {} + impl MultivariateNormal { - /// Constructs a new multivariate normal distribution with a mean of `mean` + /// Constructs a new multivariate normal distribution with a mean of `mean` /// and covariance matrix `cov` /// /// # Errors /// /// Returns an error if the given covariance matrix is not /// symmetric or positive-definite - pub fn new(mean: Vec, cov: Vec) -> Result { + pub fn new(mean: Vec, cov: Vec) -> Result { let mean = DVector::from_vec(mean); let cov = DMatrix::from_vec(mean.len(), mean.len(), cov); MultivariateNormal::new_from_nalgebra(mean, cov) @@ -141,24 +180,31 @@ where pub fn new_from_nalgebra( mean: OVector, cov: OMatrix, - ) -> Result { - // Check that the provided covariance matrix is symmetric - if cov.lower_triangle() != cov.upper_triangle().transpose() - // Check that mean and covariance do not contain NaN - || mean.iter().any(|f| f.is_nan()) + ) -> Result { + if mean.iter().any(|f| f.is_nan()) { + return Err(MultivariateNormalError::MeanInvalid); + } + + if !cov.is_square() + || cov.lower_triangle() != cov.upper_triangle().transpose() || cov.iter().any(|f| f.is_nan()) - // Check that the dimensions match - || mean.nrows() != cov.nrows() || cov.nrows() != cov.ncols() { - return Err(StatsError::BadParams); + return Err(MultivariateNormalError::CovInvalid); } + + // Compare number of rows + if mean.shape_generic().0 != cov.shape_generic().0 { + return Err(MultivariateNormalError::DimensionMismatch); + } + // Store the Cholesky decomposition of the covariance matrix // for sampling match Cholesky::new(cov.clone()) { - None => Err(StatsError::BadParams), + None => Err(MultivariateNormalError::CholeskyFailed), Some(cholesky_decomp) => { let precision = cholesky_decomp.inverse(); Ok(MultivariateNormal { + // .unwrap() because prerequisites are already checked above pdf_const: density_distribution_pdf_const(&mean, &cov).unwrap(), cov_chol_decomp: cholesky_decomp.unpack(), mu: mean, @@ -343,6 +389,8 @@ mod tests { statistics::{Max, MeanN, Min, Mode, VarianceN}, }; + use super::MultivariateNormalError; + fn try_create(mean: OVector, covariance: OMatrix) -> MultivariateNormal where D: DimMin, @@ -657,4 +705,10 @@ mod tests { let mvn = MultivariateNormal::new(vec![0., 0.], vec![1., 0., 0., 1.,]).unwrap(); mvn.pdf(&vec![1.].into()); // x.size != mu.size } + + #[test] + fn test_error_is_sync_send() { + fn assert_sync_send() {} + assert_sync_send::(); + } } diff --git a/src/distribution/multivariate_students_t.rs b/src/distribution/multivariate_students_t.rs index 3758d328..b75e0a88 100644 --- a/src/distribution/multivariate_students_t.rs +++ b/src/distribution/multivariate_students_t.rs @@ -2,7 +2,6 @@ use crate::distribution::Continuous; use crate::distribution::{ChiSquared, Normal}; use crate::function::gamma; use crate::statistics::{Max, MeanN, Min, Mode, VarianceN}; -use crate::{Result, StatsError}; use nalgebra::{Cholesky, Const, DMatrix, Dim, DimMin, Dyn, OMatrix, OVector}; use rand::Rng; use std::f64::consts::PI; @@ -39,6 +38,54 @@ where ln_pdf_const: f64, } +/// Represents the errors that can occur when creating a [`MultivariateStudent`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum MultivariateStudentError { + /// The scale matrix is asymmetric or contains a NaN. + ScaleInvalid, + + /// The location vector contains a NaN. + LocationInvalid, + + /// The degrees of freedom are NaN, zero or less than zero. + FreedomInvalid, + + /// The amount of rows in the location vector is not equal to the amount + /// of rows in the scale matrix. + DimensionMismatch, + + /// After all other validation, computing the Cholesky decomposition failed. + /// This means that the scale matrix is not definite-positive. + CholeskyFailed, +} + +impl std::fmt::Display for MultivariateStudentError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + MultivariateStudentError::ScaleInvalid => { + write!(f, "Scale matrix is asymmetric or contains a NaN") + } + MultivariateStudentError::LocationInvalid => { + write!(f, "Location vector contains a NaN") + } + MultivariateStudentError::FreedomInvalid => { + write!(f, "Degrees of freedom are NaN, zero or less than zero") + } + MultivariateStudentError::DimensionMismatch => write!( + f, + "Location vector and scale matrix do not have the same number of rows" + ), + MultivariateStudentError::CholeskyFailed => { + write!(f, "Computing the Cholesky decomposition failed") + } + } + } +} + +impl std::error::Error for MultivariateStudentError {} + impl MultivariateStudent { /// Constructs a new multivariate students t distribution with a location of `location`, /// scale matrix `scale` and `freedom` degrees of freedom. @@ -47,7 +94,11 @@ impl MultivariateStudent { /// /// Returns `StatsError::BadParams` if the scale matrix is not symmetric-positive /// definite and `StatsError::ArgMustBePositive` if freedom is non-positive. - pub fn new(location: Vec, scale: Vec, freedom: f64) -> Result { + pub fn new( + location: Vec, + scale: Vec, + freedom: f64, + ) -> Result { let dim = location.len(); Self::new_from_nalgebra(location.into(), DMatrix::from_vec(dim, dim, scale), freedom) } @@ -69,26 +120,26 @@ where location: OVector, scale: OMatrix, freedom: f64, - ) -> Result { + ) -> Result { let dim = location.len(); - // Check that the provided scale matrix is symmetric - if scale.lower_triangle() != scale.upper_triangle().transpose() - // Check that mean and covariance do not contain NaN - || location.iter().any(|f| f.is_nan()) + if location.iter().any(|f| f.is_nan()) { + return Err(MultivariateStudentError::LocationInvalid); + } + + if !scale.is_square() + || scale.lower_triangle() != scale.upper_triangle().transpose() || scale.iter().any(|f| f.is_nan()) - // Check that the dimensions match - || location.nrows() != scale.nrows() || scale.nrows() != scale.ncols() - // Check that the degrees of freedom is not NaN - || freedom.is_nan() { - return Err(StatsError::BadParams); + return Err(MultivariateStudentError::ScaleInvalid); } - // Check that degrees of freedom is positive - if freedom <= 0. { - return Err(StatsError::ArgMustBePositive( - "Degrees of freedom must be positive", - )); + + if freedom.is_nan() || freedom <= 0.0 { + return Err(MultivariateStudentError::FreedomInvalid); + } + + if location.nrows() != scale.nrows() { + return Err(MultivariateStudentError::DimensionMismatch); } let scale_det = scale.determinant(); @@ -98,7 +149,7 @@ where - 0.5 * scale_det.ln(); match Cholesky::new(scale.clone()) { - None => Err(StatsError::BadParams), // Scale matrix is not positive definite + None => Err(MultivariateStudentError::CholeskyFailed), Some(cholesky_decomp) => { let precision = cholesky_decomp.inverse(); Ok(MultivariateStudent { @@ -346,6 +397,8 @@ mod tests { statistics::{Max, MeanN, Min, Mode, VarianceN}, }; + use super::MultivariateStudentError; + fn try_create(location: Vec, scale: Vec, freedom: f64) -> MultivariateStudent { let mvs = MultivariateStudent::new(location, scale, freedom); @@ -563,5 +616,9 @@ mod tests { assert_eq!(mvs.scale_chol_decomp(), &OMatrix::::identity(2, 2)); } - + #[test] + fn test_error_is_sync_send() { + fn assert_sync_send() {} + assert_sync_send::(); + } } diff --git a/src/distribution/negative_binomial.rs b/src/distribution/negative_binomial.rs index d36d0f98..6ed557be 100644 --- a/src/distribution/negative_binomial.rs +++ b/src/distribution/negative_binomial.rs @@ -1,7 +1,6 @@ use crate::distribution::{self, poisson, Discrete, DiscreteCDF}; use crate::function::{beta, gamma}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -41,6 +40,29 @@ pub struct NegativeBinomial { p: f64, } +/// Represents the errors that can occur when creating a [`NegativeBinomial`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum NegativeBinomialError { + /// `r` is NaN or less than zero. + RInvalid, + + /// `p` is NaN or not in `[0, 1]`. + PInvalid, +} + +impl std::fmt::Display for NegativeBinomialError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + NegativeBinomialError::RInvalid => write!(f, "r is NaN or less than zero"), + NegativeBinomialError::PInvalid => write!(f, "p is NaN or not in [0, 1]"), + } + } +} + +impl std::error::Error for NegativeBinomialError {} + impl NegativeBinomial { /// Constructs a new negative binomial distribution with parameters `r` /// and `p`. When `r` is an integer, the negative binomial distribution @@ -64,12 +86,16 @@ impl NegativeBinomial { /// result = NegativeBinomial::new(-0.5, 5.0); /// assert!(result.is_err()); /// ``` - pub fn new(r: f64, p: f64) -> Result { - if p.is_nan() || !(0.0..=1.0).contains(&p) || r.is_nan() || r < 0.0 { - Err(StatsError::BadParams) - } else { - Ok(NegativeBinomial { r, p }) + pub fn new(r: f64, p: f64) -> Result { + if r.is_nan() || r < 0.0 { + return Err(NegativeBinomialError::RInvalid); } + + if p.is_nan() || !(0.0..=1.0).contains(&p) { + return Err(NegativeBinomialError::PInvalid); + } + + Ok(NegativeBinomial { r, p }) } /// Returns the probability of success `p` of a single @@ -291,12 +317,11 @@ impl Discrete for NegativeBinomial { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{DiscreteCDF, Discrete, NegativeBinomial}; + use super::*; use crate::distribution::internal::test; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(r: f64, p: f64; NegativeBinomial); + testing_boiler!(r: f64, p: f64; NegativeBinomial; NegativeBinomialError); #[test] fn test_create() { @@ -307,8 +332,8 @@ mod tests { #[test] fn test_bad_create() { - create_err(f64::NAN, 1.0); - create_err(0.0, f64::NAN); + test_create_err(f64::NAN, 1.0, NegativeBinomialError::RInvalid); + test_create_err(0.0, f64::NAN, NegativeBinomialError::PInvalid); create_err(-1.0, 1.0); create_err(2.0, 2.0); } diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index 65f6ad90..b536c101 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -1,7 +1,7 @@ +use crate::consts; use crate::distribution::{ziggurat, Continuous, ContinuousCDF}; use crate::function::erf; use crate::statistics::*; -use crate::{consts, Result, StatsError}; use rand::Rng; use std::f64; @@ -24,6 +24,31 @@ pub struct Normal { std_dev: f64, } +/// Represents the errors that can occur when creating a [`Normal`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum NormalError { + /// The mean is NaN. + MeanInvalid, + + /// The standard deviation is NaN, zero or less than zero. + StandardDeviationInvalid, +} + +impl std::fmt::Display for NormalError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + NormalError::MeanInvalid => write!(f, "Mean is NaN"), + NormalError::StandardDeviationInvalid => { + write!(f, "Standard deviation is NaN, zero or less than zero") + } + } + } +} + +impl std::error::Error for NormalError {} + impl Normal { /// Constructs a new normal distribution with a mean of `mean` /// and a standard deviation of `std_dev` @@ -44,12 +69,16 @@ impl Normal { /// result = Normal::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(mean: f64, std_dev: f64) -> Result { - if mean.is_nan() || std_dev.is_nan() || std_dev <= 0.0 { - Err(StatsError::BadParams) - } else { - Ok(Normal { mean, std_dev }) + pub fn new(mean: f64, std_dev: f64) -> Result { + if mean.is_nan() { + return Err(NormalError::MeanInvalid); } + + if std_dev.is_nan() || std_dev <= 0.0 { + return Err(NormalError::StandardDeviationInvalid); + } + + Ok(Normal { mean, std_dev }) } /// Constructs a new standard normal distribution with a mean of 0 @@ -334,12 +363,11 @@ impl std::default::Default for Normal { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{ContinuousCDF, Continuous, Normal}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(mean: f64, std_dev: f64; Normal); + testing_boiler!(mean: f64, std_dev: f64; Normal; NormalError); #[test] fn test_create() { @@ -352,9 +380,9 @@ mod tests { #[test] fn test_bad_create() { + test_create_err(f64::NAN, 1.0, NormalError::MeanInvalid); + test_create_err(1.0, f64::NAN, NormalError::StandardDeviationInvalid); create_err(0.0, 0.0); - create_err(f64::NAN, 1.0); - create_err(1.0, f64::NAN); create_err(f64::NAN, f64::NAN); create_err(1.0, -1.0); } diff --git a/src/distribution/pareto.rs b/src/distribution/pareto.rs index 5fb62044..886db43b 100644 --- a/src/distribution/pareto.rs +++ b/src/distribution/pareto.rs @@ -1,6 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::distributions::OpenClosed01; use rand::Rng; use std::f64; @@ -25,6 +24,29 @@ pub struct Pareto { shape: f64, } +/// Represents the errors that can occur when creating a [`Pareto`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum ParetoError { + /// The scale is NaN, zero or less than zero. + ScaleInvalid, + + /// The shape is NaN, zero or less than zero. + ShapeInvalid, +} + +impl std::fmt::Display for ParetoError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + ParetoError::ScaleInvalid => write!(f, "Scale is NaN, zero, or less than zero"), + ParetoError::ShapeInvalid => write!(f, "Shape is NaN, zero, or less than zero"), + } + } +} + +impl std::error::Error for ParetoError {} + impl Pareto { /// Constructs a new Pareto distribution with scale `scale`, and `shape` /// shape. @@ -45,13 +67,16 @@ impl Pareto { /// result = Pareto::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(scale: f64, shape: f64) -> Result { - let is_nan = scale.is_nan() || shape.is_nan(); - if is_nan || scale <= 0.0 || shape <= 0.0 { - Err(StatsError::BadParams) - } else { - Ok(Pareto { scale, shape }) + pub fn new(scale: f64, shape: f64) -> Result { + if scale.is_nan() || scale <= 0.0 { + return Err(ParetoError::ScaleInvalid); } + + if shape.is_nan() || shape <= 0.0 { + return Err(ParetoError::ShapeInvalid); + } + + Ok(Pareto { scale, shape }) } /// Returns the scale of the Pareto distribution @@ -354,12 +379,11 @@ impl Continuous for Pareto { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{ContinuousCDF, Continuous, Pareto}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(scale: f64, shape: f64; Pareto); + testing_boiler!(scale: f64, shape: f64; Pareto; ParetoError); #[test] fn test_create() { @@ -373,9 +397,9 @@ mod tests { #[test] fn test_bad_create() { + test_create_err(1.0, -1.0, ParetoError::ShapeInvalid); + test_create_err(-1.0, 1.0, ParetoError::ScaleInvalid); create_err(0.0, 0.0); - create_err(1.0, -1.0); - create_err(-1.0, 1.0); create_err(-1.0, -1.0); create_err(f64::NAN, 1.0); create_err(1.0, f64::NAN); diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index 41b56e6a..33e8f8a2 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -1,7 +1,6 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::function::{factorial, gamma}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -24,6 +23,25 @@ pub struct Poisson { lambda: f64, } +/// Represents the errors that can occur when creating a [`Poisson`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum PoissonError { + /// The lambda is NaN, zero or less than zero. + LambdaInvalid, +} + +impl std::fmt::Display for PoissonError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + PoissonError::LambdaInvalid => write!(f, "Lambda is NaN, zero or less than zero"), + } + } +} + +impl std::error::Error for PoissonError {} + impl Poisson { /// Constructs a new poisson distribution with a rate (λ) /// of `lambda` @@ -43,9 +61,9 @@ impl Poisson { /// result = Poisson::new(0.0); /// assert!(result.is_err()); /// ``` - pub fn new(lambda: f64) -> Result { + pub fn new(lambda: f64) -> Result { if lambda.is_nan() || lambda <= 0.0 { - Err(StatsError::BadParams) + Err(PoissonError::LambdaInvalid) } else { Ok(Poisson { lambda }) } @@ -304,12 +322,11 @@ pub fn sample_unchecked(rng: &mut R, lambda: f64) -> f64 { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{DiscreteCDF, Discrete, Poisson}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(lambda: f64; Poisson); + testing_boiler!(lambda: f64; Poisson; PoissonError); #[test] fn test_create() { diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index 4bada682..cc88707f 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -1,7 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::{beta, gamma}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -26,15 +25,43 @@ pub struct StudentsT { freedom: f64, } +/// Represents the errors that can occur when creating a [`StudentsT`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum StudentsTError { + /// The location is NaN. + LocationInvalid, + + /// The scale is NaN, zero or less than zero. + ScaleInvalid, + + /// The degrees of freedom are NaN, zero or less than zero. + FreedomInvalid, +} + +impl std::fmt::Display for StudentsTError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + StudentsTError::LocationInvalid => write!(f, "Location is NaN"), + StudentsTError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero"), + StudentsTError::FreedomInvalid => { + write!(f, "Degrees of freedom are NaN, zero or less than zero") + } + } + } +} + +impl std::error::Error for StudentsTError {} + impl StudentsT { /// Constructs a new student's t-distribution with location `location`, - /// scale `scale`, - /// and `freedom` freedom. + /// scale `scale`, and `freedom` freedom. /// /// # Errors /// /// Returns an error if any of `location`, `scale`, or `freedom` are `NaN`. - /// Returns an error if `scale <= 0.0` or `freedom <= 0.0` + /// Returns an error if `scale <= 0.0` or `freedom <= 0.0`. /// /// # Examples /// @@ -47,17 +74,24 @@ impl StudentsT { /// result = StudentsT::new(0.0, 0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(location: f64, scale: f64, freedom: f64) -> Result { - let is_nan = location.is_nan() || scale.is_nan() || freedom.is_nan(); - if is_nan || scale <= 0.0 || freedom <= 0.0 { - Err(StatsError::BadParams) - } else { - Ok(StudentsT { - location, - scale, - freedom, - }) + pub fn new(location: f64, scale: f64, freedom: f64) -> Result { + if location.is_nan() { + return Err(StudentsTError::LocationInvalid); + } + + if scale.is_nan() || scale <= 0.0 { + return Err(StudentsTError::ScaleInvalid); + } + + if freedom.is_nan() || freedom <= 0.0 { + return Err(StudentsTError::FreedomInvalid); } + + Ok(StudentsT { + location, + scale, + freedom, + }) } /// Returns the location of the student's t-distribution @@ -421,14 +455,12 @@ impl Continuous for StudentsT { #[cfg(test)] mod tests { + use super::*; use crate::consts::ACC; use crate::distribution::internal::*; - use crate::distribution::{Continuous, ContinuousCDF, StudentsT}; - use crate::statistics::*; use crate::testing_boiler; - use std::panic; - testing_boiler!(location: f64, scale: f64, freedom: f64; StudentsT); + testing_boiler!(location: f64, scale: f64, freedom: f64; StudentsT; StudentsTError); #[test] fn test_create() { @@ -446,11 +478,17 @@ mod tests { #[test] fn test_bad_create() { - create_err(f64::NAN, 1.0, 1.0); - create_err(0.0, f64::NAN, 1.0); - create_err(0.0, 1.0, f64::NAN); - create_err(0.0, -10.0, 1.0); - create_err(0.0, 10.0, -1.0); + let invalid = [ + (f64::NAN, 1.0, 1.0, StudentsTError::LocationInvalid), + (0.0, f64::NAN, 1.0, StudentsTError::ScaleInvalid), + (0.0, 1.0, f64::NAN, StudentsTError::FreedomInvalid), + (0.0, -10.0, 1.0, StudentsTError::ScaleInvalid), + (0.0, 10.0, -1.0, StudentsTError::FreedomInvalid), + ]; + + for (l, s, f, err) in invalid { + test_create_err(l, s, f, err); + } } #[test] diff --git a/src/distribution/triangular.rs b/src/distribution/triangular.rs index fff9fe72..eb3cb93d 100644 --- a/src/distribution/triangular.rs +++ b/src/distribution/triangular.rs @@ -1,6 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -25,6 +24,43 @@ pub struct Triangular { mode: f64, } +/// Represents the errors that can occur when creating a [`Triangular`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum TriangularError { + /// The minimum is NaN or infinite. + MinInvalid, + + /// The maximum is NaN or infinite. + MaxInvalid, + + /// The mode is NaN or infinite. + ModeInvalid, + + /// The mode is less than the minimum or greater than the maximum. + ModeOutOfRange, + + /// The minimum equals the maximum. + MinEqualsMax, +} + +impl std::fmt::Display for TriangularError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + TriangularError::MinInvalid => write!(f, "Minimum is NaN or infinite."), + TriangularError::MaxInvalid => write!(f, "Maximum is NaN or infinite."), + TriangularError::ModeInvalid => write!(f, "Mode is NaN or infinite."), + TriangularError::ModeOutOfRange => { + write!(f, "Mode is less than minimum or greater than maximum") + } + TriangularError::MinEqualsMax => write!(f, "Minimum equals Maximum"), + } + } +} + +impl std::error::Error for TriangularError {} + impl Triangular { /// Constructs a new triangular distribution with a minimum of `min`, /// maximum of `max`, and a mode of `mode`. @@ -45,16 +81,27 @@ impl Triangular { /// result = Triangular::new(2.5, 1.5, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(min: f64, max: f64, mode: f64) -> Result { - if !min.is_finite() || !max.is_finite() || !mode.is_finite() { - return Err(StatsError::BadParams); + pub fn new(min: f64, max: f64, mode: f64) -> Result { + if !min.is_finite() { + return Err(TriangularError::MinInvalid); + } + + if !max.is_finite() { + return Err(TriangularError::MaxInvalid); + } + + if !mode.is_finite() { + return Err(TriangularError::ModeInvalid); } + if max < mode || mode < min { - return Err(StatsError::BadParams); + return Err(TriangularError::ModeOutOfRange); } - if ulps_eq!(max, min, max_ulps = 0) { - return Err(StatsError::BadParams); + + if min == max { + return Err(TriangularError::MinEqualsMax); } + Ok(Triangular { min, max, mode }) } } @@ -347,12 +394,11 @@ fn sample_unchecked(rng: &mut R, min: f64, max: f64, mode: f64) #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{ContinuousCDF, Continuous, Triangular}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(min: f64, max: f64, mode: f64; Triangular); + testing_boiler!(min: f64, max: f64, mode: f64; Triangular; TriangularError); #[test] fn test_create() { @@ -367,17 +413,23 @@ mod tests { #[test] fn test_bad_create() { - create_err(0.0, 0.0, 0.0); - create_err(0.0, 1.0, -0.1); - create_err(0.0, 1.0, 1.1); - create_err(0.0, -1.0, 0.5); - create_err(2.0, 1.0, 1.5); - create_err(f64::NAN, 1.0, 0.5); - create_err(0.2, f64::NAN, 0.5); - create_err(0.5, 1.0, f64::NAN); - create_err(f64::NAN, f64::NAN, f64::NAN); - create_err(f64::NEG_INFINITY, 1.0, 0.5); - create_err(0.0, f64::INFINITY, 0.5); + let invalid = [ + (0.0, 0.0, 0.0, TriangularError::MinEqualsMax), + (0.0, 1.0, -0.1, TriangularError::ModeOutOfRange), + (0.0, 1.0, 1.1, TriangularError::ModeOutOfRange), + (0.0, -1.0, 0.5, TriangularError::ModeOutOfRange), + (2.0, 1.0, 1.5, TriangularError::ModeOutOfRange), + (f64::NAN, 1.0, 0.5, TriangularError::MinInvalid), + (0.2, f64::NAN, 0.5, TriangularError::MaxInvalid), + (0.5, 1.0, f64::NAN, TriangularError::ModeInvalid), + (f64::NAN, f64::NAN, f64::NAN, TriangularError::MinInvalid), + (f64::NEG_INFINITY, 1.0, 0.5, TriangularError::MinInvalid), + (0.0, f64::INFINITY, 0.5, TriangularError::MaxInvalid), + ]; + + for (min, max, mode, err) in invalid { + test_create_err(min, max, mode, err); + } } #[test] diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index 4186df71..55bd7884 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -1,6 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::distributions::Uniform as RandUniform; use rand::Rng; use std::f64; @@ -26,13 +25,43 @@ pub struct Uniform { max: f64, } +/// Represents the errors that can occur when creating a [`Uniform`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum UniformError { + /// The minimum is NaN or infinite. + MinInvalid, + + /// The maximum is NaN or infinite. + MaxInvalid, + + /// The maximum is not greater than the minimum. + MaxNotGreaterThanMin, +} + +impl std::fmt::Display for UniformError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + UniformError::MinInvalid => write!(f, "Minimum is NaN or infinite"), + UniformError::MaxInvalid => write!(f, "Maximum is NaN or infinite"), + UniformError::MaxNotGreaterThanMin => { + write!(f, "Maximum is not greater than the minimum") + } + } + } +} + +impl std::error::Error for UniformError {} + impl Uniform { /// Constructs a new uniform distribution with a min of `min` and a max - /// of `max` + /// of `max`. /// /// # Errors /// - /// Returns an error if `min` or `max` are `NaN` or unbounded + /// Returns an error if `min` or `max` are `NaN` or infinite. + /// Returns an error if `min >= max`. /// /// # Examples /// @@ -49,17 +78,19 @@ impl Uniform { /// result = Uniform::new(f64::NEG_INFINITY, 1.0); /// assert!(result.is_err()); /// ``` - pub fn new(min: f64, max: f64) -> Result { - if min.is_nan() || max.is_nan() { - return Err(StatsError::BadParams); + pub fn new(min: f64, max: f64) -> Result { + if !min.is_finite() { + return Err(UniformError::MinInvalid); } - match (min.is_finite(), max.is_finite(), min < max) { - (false, false, _) => Err(StatsError::ArgFinite("min and max")), - (false, true, _) => Err(StatsError::ArgFinite("min")), - (true, false, _) => Err(StatsError::ArgFinite("max")), - (true, true, false) => Err(StatsError::ArgLteArg("min", "max")), - (true, true, true) => Ok(Uniform { min, max }), + if !max.is_finite() { + return Err(UniformError::MaxInvalid); + } + + if min < max { + Ok(Uniform { min, max }) + } else { + Err(UniformError::MaxNotGreaterThanMin) } } @@ -284,12 +315,11 @@ impl Continuous for Uniform { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{ContinuousCDF, Continuous, Uniform}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(min: f64, max: f64; Uniform); + testing_boiler!(min: f64, max: f64; Uniform; UniformError); #[test] fn test_create() { @@ -301,12 +331,18 @@ mod tests { #[test] fn test_bad_create() { - create_err(0.0, 0.0); - create_err(f64::NAN, 1.0); - create_err(1.0, f64::NAN); - create_err(f64::NAN, f64::NAN); - create_err(0.0, f64::INFINITY); - create_err(1.0, 0.0); + let invalid = [ + (0.0, 0.0, UniformError::MaxNotGreaterThanMin), + (f64::NAN, 1.0, UniformError::MinInvalid), + (1.0, f64::NAN, UniformError::MaxInvalid), + (f64::NAN, f64::NAN, UniformError::MinInvalid), + (0.0, f64::INFINITY, UniformError::MaxInvalid), + (1.0, 0.0, UniformError::MaxNotGreaterThanMin), + ]; + + for (min, max, err) in invalid { + test_create_err(min, max, err); + } } #[test] diff --git a/src/distribution/weibull.rs b/src/distribution/weibull.rs index 2d3a8a87..71aa30ef 100644 --- a/src/distribution/weibull.rs +++ b/src/distribution/weibull.rs @@ -1,7 +1,7 @@ +use crate::consts; use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; use crate::statistics::*; -use crate::{consts, Result, StatsError}; use rand::Rng; use std::f64; @@ -27,6 +27,29 @@ pub struct Weibull { scale_pow_shape_inv: f64, } +/// Represents the errors that can occur when creating a [`Weibull`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum WeibullError { + /// The shape is NaN, zero or less than zero. + ShapeInvalid, + + /// The scale is NaN, zero or less than zero. + ScaleInvalid, +} + +impl std::fmt::Display for WeibullError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + WeibullError::ShapeInvalid => write!(f, "Shape is NaN, zero or less than zero."), + WeibullError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero."), + } + } +} + +impl std::error::Error for WeibullError {} + impl Weibull { /// Constructs a new weibull distribution with a shape (k) of `shape` /// and a scale (λ) of `scale` @@ -47,17 +70,20 @@ impl Weibull { /// result = Weibull::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(shape: f64, scale: f64) -> Result { - let is_nan = shape.is_nan() || scale.is_nan(); - match (shape, scale, is_nan) { - (_, _, true) => Err(StatsError::BadParams), - (_, _, false) if shape <= 0.0 || scale <= 0.0 => Err(StatsError::BadParams), - (_, _, false) => Ok(Weibull { - shape, - scale, - scale_pow_shape_inv: scale.powf(-shape), - }), + pub fn new(shape: f64, scale: f64) -> Result { + if shape.is_nan() || shape <= 0.0 { + return Err(WeibullError::ShapeInvalid); } + + if scale.is_nan() || scale <= 0.0 { + return Err(WeibullError::ScaleInvalid); + } + + Ok(Weibull { + shape, + scale, + scale_pow_shape_inv: scale.powf(-shape), + }) } /// Returns the shape of the weibull distribution @@ -350,12 +376,11 @@ impl Continuous for Weibull { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{ContinuousCDF, Continuous, Weibull}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(shape: f64, scale: f64; Weibull); + testing_boiler!(shape: f64, scale: f64; Weibull; WeibullError); #[test] fn test_create() { @@ -367,8 +392,8 @@ mod tests { #[test] fn test_bad_create() { - create_err(f64::NAN, 1.0); - create_err(1.0, f64::NAN); + test_create_err(f64::NAN, 1.0, WeibullError::ShapeInvalid); + test_create_err(1.0, f64::NAN, WeibullError::ScaleInvalid); create_err(f64::NAN, f64::NAN); create_err(1.0, -1.0); create_err(-1.0, 1.0); diff --git a/src/lib.rs b/src/lib.rs index 56f9b162..7ca3157c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,6 +48,7 @@ #![allow(clippy::excessive_precision)] #![allow(clippy::many_single_char_names)] #![forbid(unsafe_code)] +#![cfg_attr(coverage_nightly, feature(coverage_attribute))] #[macro_use] extern crate approx; diff --git a/src/stats_tests/fisher.rs b/src/stats_tests/fisher.rs index 69b41d6d..909b4e7b 100644 --- a/src/stats_tests/fisher.rs +++ b/src/stats_tests/fisher.rs @@ -1,6 +1,5 @@ use super::Alternative; -use crate::distribution::{Discrete, DiscreteCDF, Hypergeometric}; -use crate::StatsError; +use crate::distribution::{Discrete, DiscreteCDF, Hypergeometric, HypergeometricError}; const EPSILON: f64 = 1.0 - 1e-4; @@ -97,6 +96,35 @@ fn binary_search( guess } +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum FishersExactTestError { + /// The table does not describe a valid [`Hypergeometric`] distribution. + /// Make sure that the contingency table stores the data in row-major order. + TableInvalidForHypergeometric(HypergeometricError), +} + +impl std::fmt::Display for FishersExactTestError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + FishersExactTestError::TableInvalidForHypergeometric(hg_err) => { + writeln!(f, "Cannot create a Hypergeometric distribution from the data in the contingency table.")?; + writeln!(f, "Is it in row-major order?")?; + write!(f, "Inner error: '{}'", hg_err) + } + } + } +} + +impl std::error::Error for FishersExactTestError {} + +impl From for FishersExactTestError { + fn from(value: HypergeometricError) -> Self { + Self::TableInvalidForHypergeometric(value) + } +} + /// Perform a Fisher exact test on a 2x2 contingency table. /// Based on scipy's fisher test: /// Expects a table in row-major order @@ -112,7 +140,7 @@ fn binary_search( pub fn fishers_exact_with_odds_ratio( table: &[u64; 4], alternative: Alternative, -) -> Result<(f64, f64), StatsError> { +) -> Result<(f64, f64), FishersExactTestError> { // If both values in a row or column are zero, p-value is 1 and odds ratio is NaN. match table { [0, _, 0, _] | [_, 0, _, 0] => return Ok((f64::NAN, 1.0)), // both 0 in a row @@ -144,7 +172,10 @@ pub fn fishers_exact_with_odds_ratio( /// let table = [3, 5, 4, 50]; /// let p_value = fishers_exact(&table, Alternative::Less).unwrap(); /// ``` -pub fn fishers_exact(table: &[u64; 4], alternative: Alternative) -> Result { +pub fn fishers_exact( + table: &[u64; 4], + alternative: Alternative, +) -> Result { // If both values in a row or column are zero, the p-value is 1 and the odds ratio is NaN. match table { [0, _, 0, _] | [_, 0, _, 0] => return Ok(1.0), // both 0 in a row