diff --git a/Cargo.toml b/Cargo.toml index 2b133a61..7cec0438 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ rand = "0.8" nalgebra = { version = "0.32", features = ["rand"] } approx = "0.5.0" num-traits = "0.2.14" +thiserror = "1.0.63" [dev-dependencies] criterion = "0.3.3" diff --git a/examples/error_trials.rs b/examples/error_trials.rs new file mode 100644 index 00000000..5a986984 --- /dev/null +++ b/examples/error_trials.rs @@ -0,0 +1,96 @@ +extern crate statrs; + +use anyhow::{anyhow, Context, Result as AnyhowResult}; +use statrs::distribution::{ + Continuous, Discrete, Gamma, GammaError, NegativeBinomial, Normal, ParametrizationError, +}; + +pub fn main() -> AnyhowResult<()> { + gamma_pdf(1.0, 1.0, 1.0).map(|x| println!("val = {}", x))?; + // val = 0.36787944117144233 + + normal_pdf(1.0, 1.0, 1.0).map(|x| println!("val = {}", x))?; + // val = 0.39894228040143265 + + gamma_pdf_with_negative_shape_correction(-0.5, 1.0, 1.0).map(|x| println!("val = {}", x))?; + // without shape correction would emit, the below + // Error: failed creating gamma(-0.5,1) + // + // Caused by: + // 0: shape must be finite, positive, and not nan + // 1: expected positive, got -0.5 + // after re-attempt, output is + // Error: gamma provided invalid shape + // attempting to correct shape to 0.5 + // val = 0.2075537487102974 + + nb_pmf(1, 1.0, 1).map(|x| println!("val = {}", x))?; + // Error: failed creating nb(1,1) + // + // Caused by: + // mean of 0 is degenerate + + nb_pmf(1, 0., 1).map(|x| println!("val = {}", x))?; + // Error: failed creating nb(1,0) + // + // Caused by: + // mean of inf is degenerate + + normal_pdf(1.0, f64::INFINITY, 1.0).map(|x| println!("val = {}", x))?; + // Error: failed creating normal(1, inf) + // + // Caused by: + // variance of inf is degenrate + + normal_pdf(1.0, 0.0, 1.0).map(|x| println!("val = {}", x))?; + // Error: failed creating normal(1, 0) + // + // Caused by: + // variance of 0 is degenerate + + Ok(()) +} + +pub fn gamma_pdf(shape: f64, rate: f64, x: f64) -> AnyhowResult { + Ok(Gamma::new(shape, rate) + .context(format!("failed creating gamma({},{})", shape, rate))? + .pdf(x)) +} + +pub fn gamma_pdf_with_negative_shape_correction( + shape: f64, + rate: f64, + x: f64, +) -> AnyhowResult { + match gamma_pdf(shape, rate, x) { + Ok(x) => Ok(x), + Err(ee) => { + if let GammaError::InvalidShape(e) = ee.downcast::()? { + eprintln!("Error: gamma provided invalid shape"); + if let ParametrizationError::ExpectedPositive(shape) = e { + eprintln!("\tattempting to correct shape to {}", shape.abs()); + // fails again for 0 and INF + gamma_pdf(shape.abs(), rate, x) + } else { + Err(anyhow!("cannot recover valid shape from this error")) + } + } else { + Err(anyhow!( + "cannot recover both valid shape and rate from this error" + )) + } + } + } +} + +pub fn nb_pmf(r: u64, p: f64, x: u64) -> AnyhowResult { + Ok(NegativeBinomial::new(r, p) + .context(format!("failed creating nb({},{})", r, p))? + .pmf(x)) +} + +pub fn normal_pdf(location: f64, scale: f64, x: f64) -> AnyhowResult { + Ok(Normal::new(location, scale) + .context(format!("failed creating normal({}, {})", location, scale))? + .pdf(x)) +} diff --git a/src/distribution/chi_squared.rs b/src/distribution/chi_squared.rs index 1c6b42b0..20fe223c 100644 --- a/src/distribution/chi_squared.rs +++ b/src/distribution/chi_squared.rs @@ -1,6 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF, Gamma}; use crate::statistics::*; -use crate::Result; use rand::Rng; use std::f64; @@ -48,7 +47,7 @@ impl ChiSquared { /// result = ChiSquared::new(0.0); /// assert!(result.is_err()); /// ``` - pub fn new(freedom: f64) -> Result { + pub fn new(freedom: f64) -> Result { Gamma::new(freedom / 2.0, 0.5).map(|g| ChiSquared { freedom, g }) } diff --git a/src/distribution/erlang.rs b/src/distribution/erlang.rs index 1213baef..3195a056 100644 --- a/src/distribution/erlang.rs +++ b/src/distribution/erlang.rs @@ -1,6 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF, Gamma}; use crate::statistics::*; -use crate::Result; use rand::Rng; /// Implements the [Erlang](https://en.wikipedia.org/wiki/Erlang_distribution) @@ -45,7 +44,7 @@ impl Erlang { /// result = Erlang::new(0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(shape: u64, rate: f64) -> Result { + pub fn new(shape: u64, rate: f64) -> Result { Gamma::new(shape as f64, rate).map(|g| Erlang { g }) } @@ -304,7 +303,6 @@ mod tests { create_case(1, 1.0); create_case(10, 10.0); create_case(10, 1.0); - create_case(10, f64::INFINITY); } #[test] diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index 166ebb72..91206ea7 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -1,9 +1,78 @@ -use crate::distribution::{Continuous, ContinuousCDF}; +use crate::distribution::{Continuous, ContinuousCDF, ParametrizationError as ParamError}; use crate::function::gamma; use crate::prec; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; +use thiserror::Error; + +#[derive(Clone, PartialEq, Debug, Error)] +pub enum GammaError { + #[error("shape must be finite, positive, and not nan")] + InvalidShape(#[source] ParamError), + #[error("shape must be finite, positive, and not nan")] + InvalidRate(#[source] ParamError), + #[error("rate of {0} is degenerate")] + DegenerateRate(f64), + #[error("shape of {0} is degenerate")] + DegenerateShape(f64), +} + +impl From for GammaError { + fn from(value: super::negative_binomial::NegativeBinomialError) -> Self { + use super::negative_binomial::NegativeBinomialError::*; + match value { + InvalidMean(e) => Self::InvalidShape(e), + InvalidProbability(p) => { + if p.is_nan() { + Self::InvalidRate(ParamError::ExpectedNotNan) + } else { + Self::InvalidRate(ParamError::ExpectedPositive(p / (1.0 - p))) + } + } + InvalidSuccessCount(e) => Self::InvalidRate(e.into()), + DegenerateMean(m) => Self::DegenerateShape(m), + DegenerateProbability(p) => Self::DegenerateRate(p / (1.0 - p)), + DegenerateSuccessCount => Self::DegenerateRate(0.0), + } + } +} + +/// holds a valid parametrization of the gamma distribution in shape and rate. +pub struct Parameters { + shape: f64, + rate: f64, +} + +impl Parameters { + pub fn new(shape: f64, rate: f64) -> Result { + if shape.is_nan() { + Err(GammaError::InvalidShape(ParamError::ExpectedNotNan)) + } else if rate.is_nan() { + Err(GammaError::InvalidRate(ParamError::ExpectedNotNan)) + } else if rate <= 0.0 { + Err(GammaError::InvalidRate(ParamError::ExpectedPositive(rate))) + } else if shape <= 0.0 { + Err(GammaError::InvalidShape(ParamError::ExpectedPositive( + shape, + ))) + } else if rate.is_infinite() { + Err(GammaError::DegenerateRate(rate)) + } else if shape.is_infinite() { + Err(GammaError::DegenerateShape(shape)) + } else { + Ok(Self { shape, rate }) + } + } +} + +impl From for Gamma { + fn from(value: Parameters) -> Self { + Gamma { + shape: value.shape, + rate: value.rate, + } + } +} /// Implements the [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution) /// distribution @@ -45,16 +114,8 @@ impl Gamma { /// result = Gamma::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(shape: f64, rate: f64) -> Result { - if shape.is_nan() - || rate.is_nan() - || shape.is_infinite() && rate.is_infinite() - || shape <= 0.0 - || rate <= 0.0 - { - return Err(StatsError::BadParams); - } - Ok(Gamma { shape, rate }) + pub fn new(shape: f64, rate: f64) -> Result { + Ok(Parameters::new(shape, rate)?.into()) } /// Returns the shape (α) of the gamma distribution @@ -414,13 +475,7 @@ mod tests { #[test] fn test_create() { - let valid = [ - (1.0, 0.1), - (1.0, 1.0), - (10.0, 10.0), - (10.0, 1.0), - (10.0, f64::INFINITY), - ]; + let valid = [(1.0, 0.1), (1.0, 1.0), (10.0, 10.0), (10.0, 1.0)]; for (s, r) in valid { try_create(s, r); @@ -450,7 +505,6 @@ mod tests { ((1.0, 1.0), 1.0), ((10.0, 10.0), 1.0), ((10.0, 1.0), 10.0), - ((10.0, f64::INFINITY), 0.0), ]; for ((s, r), res) in test { test_case(s, r, res, f); @@ -465,7 +519,6 @@ mod tests { ((1.0, 1.0), 1.0), ((10.0, 10.0), 0.1), ((10.0, 1.0), 10.0), - ((10.0, f64::INFINITY), 0.0), ]; for ((s, r), res) in test { test_case(s, r, res, f); @@ -480,7 +533,6 @@ mod tests { ((1.0, 1.0), 1.0), ((10.0, 10.0), 0.2334690854869339583626209), ((10.0, 1.0), 2.53605417848097964238061239), - ((10.0, f64::INFINITY), f64::NEG_INFINITY), ]; for ((s, r), res) in test { test_case(s, r, res, f); @@ -495,7 +547,6 @@ mod tests { ((1.0, 1.0), 2.0), ((10.0, 10.0), 0.6324555320336758663997787), ((10.0, 1.0), 0.63245553203367586639977870), - ((10.0, f64::INFINITY), 0.6324555320336758), ]; for ((s, r), res) in test { test_case(s, r, res, f); @@ -509,11 +560,7 @@ mod tests { for &((s, r), res) in test.iter() { test_case_special(s, r, res, 10e-6, f); } - let test = [ - ((10.0, 10.0), 0.9), - ((10.0, 1.0), 9.0), - ((10.0, f64::INFINITY), 0.0), - ]; + let test = [((10.0, 10.0), 0.9), ((10.0, 1.0), 9.0)]; for ((s, r), res) in test { test_case(s, r, res, f); } @@ -527,7 +574,6 @@ mod tests { ((1.0, 1.0), 0.0), ((10.0, 10.0), 0.0), ((10.0, 1.0), 0.0), - ((10.0, f64::INFINITY), 0.0), ]; for ((s, r), res) in test { test_case(s, r, res, f); @@ -538,7 +584,6 @@ mod tests { ((1.0, 1.0), f64::INFINITY), ((10.0, 10.0), f64::INFINITY), ((10.0, 1.0), f64::INFINITY), - ((10.0, f64::INFINITY), f64::INFINITY), ]; for ((s, r), res) in test { test_case(s, r, res, f); @@ -585,7 +630,6 @@ mod tests { ((10.0, 10.0), 10.0, -69.0527107131946016148658), ((10.0, 1.0), 1.0, -13.8018274800814696112077), ((10.0, 1.0), 10.0, -2.07856164313505845504579), - ((10.0, f64::INFINITY), f64::INFINITY, f64::NEG_INFINITY), ]; for ((s, r), x, res) in test { test_case(s, r, res, f(x)); @@ -606,8 +650,6 @@ mod tests { ((10.0, 10.0), 10.0, 0.999999999999999999999999), ((10.0, 1.0), 1.0, 0.000000111425478338720677), ((10.0, 1.0), 10.0, 0.542070285528147791685835), - ((10.0, f64::INFINITY), 1.0, 0.0), - ((10.0, f64::INFINITY), 10.0, 1.0), ]; for ((s, r), x, res) in test { test_case(s, r, res, f(x)); @@ -657,8 +699,6 @@ mod tests { ((10.0, 10.0), 10.0, 1.1253473960842808e-31), ((10.0, 1.0), 1.0, 0.9999998885745217), ((10.0, 1.0), 10.0, 0.4579297144718528), - ((10.0, f64::INFINITY), 1.0, 1.0), - ((10.0, f64::INFINITY), 10.0, 0.0), ]; for ((s, r), x, res) in test { test_case(s, r, res, f(x)); diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index 56deb09a..4bf0b02c 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -3,7 +3,8 @@ //! concrete implementations for a variety of distributions. use super::statistics::{Max, Min}; use ::num_traits::{Bounded, Float, Num}; -use num_traits::{NumAssign, NumAssignOps, NumAssignRef}; +use num_traits::{AsPrimitive, NumAssign, NumAssignOps, NumAssignRef}; +use thiserror::Error; pub use self::bernoulli::Bernoulli; pub use self::beta::Beta; @@ -19,7 +20,7 @@ pub use self::empirical::Empirical; pub use self::erlang::Erlang; pub use self::exponential::Exp; pub use self::fisher_snedecor::FisherSnedecor; -pub use self::gamma::Gamma; +pub use self::gamma::{Gamma, GammaError}; pub use self::geometric::Geometric; pub use self::hypergeometric::Hypergeometric; pub use self::inverse_gamma::InverseGamma; @@ -27,8 +28,8 @@ pub use self::laplace::Laplace; pub use self::log_normal::LogNormal; pub use self::multinomial::Multinomial; pub use self::multivariate_normal::MultivariateNormal; -pub use self::negative_binomial::NegativeBinomial; -pub use self::normal::Normal; +pub use self::negative_binomial::{NegativeBinomial, NegativeBinomialError}; +pub use self::normal::{Normal, NormalError}; pub use self::pareto::Pareto; pub use self::poisson::Poisson; pub use self::students_t::StudentsT; @@ -71,6 +72,35 @@ mod weibull; mod ziggurat; mod ziggurat_tables; +#[derive(Clone, PartialEq, Debug, Error)] +pub enum ParametrizationError { + #[error("expected positive, got {0}")] + ExpectedPositive(N), + #[error("expected non-negative, got {0}")] + ExpectedNotNegative(N), + #[error("expected finite, {0}")] + ExpectedFinite(N), + #[error("expected not-NAN")] + ExpectedNotNan, +} + +impl From> for ParametrizationError { + fn from(val: ParametrizationError) -> Self { + match val { + ParametrizationError::ExpectedPositive(x) => { + ParametrizationError::ExpectedPositive(x.as_()) + } + ParametrizationError::ExpectedNotNegative(x) => { + ParametrizationError::ExpectedNotNegative(x.as_()) + } + ParametrizationError::ExpectedFinite(x) => { + ParametrizationError::ExpectedFinite(x.as_()) + } + ParametrizationError::ExpectedNotNan => ParametrizationError::ExpectedNotNan, + } + } +} + use crate::Result; /// The `ContinuousCDF` trait is used to specify an interface for univariate diff --git a/src/distribution/negative_binomial.rs b/src/distribution/negative_binomial.rs index 065c2239..0e37be51 100644 --- a/src/distribution/negative_binomial.rs +++ b/src/distribution/negative_binomial.rs @@ -1,9 +1,56 @@ -use crate::distribution::{self, poisson, Discrete, DiscreteCDF}; +use crate::distribution::{ + self, poisson, Discrete, DiscreteCDF, ParametrizationError as ParamError, +}; use crate::function::{beta, gamma}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; +use thiserror::Error; + +#[derive(Clone, PartialEq, Debug, Error)] +pub enum NegativeBinomialError { + #[error("mean must be finite, positive, and not nan")] + InvalidMean(#[source] ParamError), + #[error("success_count, `r`, must be positive")] + InvalidSuccessCount(#[source] ParamError), + #[error("sample probability, `p`, must represent a probability")] + InvalidProbability(f64), + #[error("mean of {0} is degenerate")] + DegenerateMean(f64), + #[error("success_count, `r`, of 0 is degenerate")] + DegenerateSuccessCount, + #[error("probability of {0} is degenerate")] + DegenerateProbability(f64), +} + +/// holds a valid parametrization of the gamma distribution in shape and rate. +pub struct Parameters { + r: u64, + p: f64, +} + +impl Parameters { + pub fn new(p: f64, r: u64) -> Result { + if p.is_nan() || !((0.0..=1.0).contains(&p)) { + Err(NegativeBinomialError::InvalidProbability(p)) + } else if p == 0.0 || p == 1.0 { + Err(NegativeBinomialError::DegenerateProbability(p)) + } else if r == 0 { + Err(NegativeBinomialError::DegenerateSuccessCount) + } else { + Ok(Self { p, r }) + } + } +} + +impl From for NegativeBinomial { + fn from(value: Parameters) -> Self { + Self { + r: value.r as f64, + p: value.p, + } + } +} /// Implements the /// [negative binomial](http://en.wikipedia.org/wiki/Negative_binomial_distribution) @@ -30,7 +77,7 @@ use std::f64; /// use statrs::statistics::DiscreteDistribution; /// use statrs::prec::almost_eq; /// -/// let r = NegativeBinomial::new(4.0, 0.5).unwrap(); +/// let r = NegativeBinomial::new(4, 0.5).unwrap(); /// assert_eq!(r.mean().unwrap(), 4.0); /// assert!(almost_eq(r.pmf(0), 0.0625, 1e-8)); /// assert!(almost_eq(r.pmf(3), 0.15625, 1e-8)); @@ -58,17 +105,14 @@ impl NegativeBinomial { /// ``` /// use statrs::distribution::NegativeBinomial; /// - /// let mut result = NegativeBinomial::new(4.0, 0.5); + /// let mut result = NegativeBinomial::new(4, 0.5); /// assert!(result.is_ok()); - /// - /// result = NegativeBinomial::new(-0.5, 5.0); - /// assert!(result.is_err()); /// ``` - pub fn new(r: f64, p: f64) -> Result { - if p.is_nan() || !(0.0..=1.0).contains(&p) || r.is_nan() || r < 0.0 { - Err(StatsError::BadParams) + pub fn new(r: u64, p: f64) -> Result { + if p.is_nan() || !(0.0..=1.0).contains(&p) { + Err(NegativeBinomialError::InvalidProbability(p)) } else { - Ok(NegativeBinomial { r, p }) + Ok(Parameters::new(p, r)?.into()) } } @@ -81,7 +125,7 @@ impl NegativeBinomial { /// ``` /// use statrs::distribution::NegativeBinomial; /// - /// let r = NegativeBinomial::new(5.0, 0.5).unwrap(); + /// let r = NegativeBinomial::new(5, 0.5).unwrap(); /// assert_eq!(r.p(), 0.5); /// ``` pub fn p(&self) -> f64 { @@ -96,11 +140,11 @@ impl NegativeBinomial { /// ``` /// use statrs::distribution::NegativeBinomial; /// - /// let r = NegativeBinomial::new(5.0, 0.5).unwrap(); - /// assert_eq!(r.r(), 5.0); + /// let r = NegativeBinomial::new(5, 0.5).unwrap(); + /// assert_eq!(r.r(), 5); /// ``` - pub fn r(&self) -> f64 { - self.r + pub fn r(&self) -> u64 { + self.r as u64 } } @@ -296,24 +340,24 @@ mod tests { use crate::distribution::{DiscreteCDF, Discrete, NegativeBinomial}; use crate::distribution::internal::test; - fn try_create(r: f64, p: f64) -> NegativeBinomial { + fn try_create(r: u64, p: f64) -> NegativeBinomial { let r = NegativeBinomial::new(r, p); - assert!(r.is_ok()); + assert!(r.is_ok(), "r = {:?}", r); r.unwrap() } - fn create_case(r: f64, p: f64) { + fn create_case(r: u64, p: f64) { let dist = try_create(r, p); assert_eq!(p, dist.p()); assert_eq!(r, dist.r()); } - fn bad_create_case(r: f64, p: f64) { + fn bad_create_case(r: u64, p: f64) { let r = NegativeBinomial::new(r, p); - assert!(r.is_err()); + assert!(r.is_err(), "r = {:?}", r); } - fn get_value(r: f64, p: f64, eval: F) -> T + fn get_value(r: u64, p: f64, eval: F) -> T where T: PartialEq + Debug, F: Fn(NegativeBinomial) -> T { @@ -321,7 +365,7 @@ mod tests { eval(r) } - fn test_case(r: f64, p: f64, expected: T, eval: F) + fn test_case(r: u64, p: f64, expected: T, eval: F) where T: PartialEq + Debug, F: Fn(NegativeBinomial) -> T { @@ -330,7 +374,7 @@ mod tests { } - fn test_case_or_nan(r: f64, p: f64, expected: f64, eval: F) + fn test_case_or_nan(r: u64, p: f64, expected: f64, eval: F) where F: Fn(NegativeBinomial) -> f64 { let x = get_value(r, p, eval); @@ -341,7 +385,7 @@ mod tests { assert_eq!(expected, x); } } - fn test_almost(r: f64, p: f64, expected: f64, acc: f64, eval: F) + fn test_almost(r: u64, p: f64, expected: f64, acc: f64, eval: F) where F: Fn(NegativeBinomial) -> f64 { let x = get_value(r, p, eval); @@ -350,163 +394,164 @@ mod tests { #[test] fn test_create() { - create_case(0.0, 0.0); - create_case(0.3, 0.4); - create_case(1.0, 0.3); + // create_case(0.3, 0.4); // one should instead use Gamma directly + create_case(1, 0.3); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN, 1.0); - bad_create_case(0.0, f64::NAN); - bad_create_case(-1.0, 1.0); - bad_create_case(2.0, 2.0); + bad_create_case(0, f64::NAN); // reject NAN + bad_create_case(2, -1.0); // not a probability + bad_create_case(2, 2.0); + bad_create_case(1, 0.0); // degenerate + bad_create_case(1, 1.0); + bad_create_case(0, 0.5); } #[test] fn test_mean() { let mean = |x: NegativeBinomial| x.mean().unwrap(); - test_case(4.0, 0.0, f64::INFINITY, mean); - test_almost(3.0, 0.3, 7.0, 1e-15 , mean); - test_case(2.0, 1.0, 0.0, mean); + // test_case(4, 0.0, f64::INFINITY, mean); + test_almost(3, 0.3, 7.0, 1e-15 , mean); + // test_case(2, 1.0, 0.0, mean); } #[test] fn test_variance() { let variance = |x: NegativeBinomial| x.variance().unwrap(); - test_case(4.0, 0.0, f64::INFINITY, variance); - test_almost(3.0, 0.3, 23.333333333333, 1e-12, variance); - test_case(2.0, 1.0, 0.0, variance); + // test_case(4, 0.0, f64::INFINITY, variance); + test_almost(3, 0.3, 23.333333333333, 1e-12, variance); + // test_case(2, 1.0, 0.0, variance); } #[test] fn test_skewness() { let skewness = |x: NegativeBinomial| x.skewness().unwrap(); - test_case(0.0, 0.0, f64::INFINITY, skewness); - test_almost(0.1, 0.3, 6.425396041, 1e-09, skewness); - test_case(1.0, 1.0, f64::INFINITY, skewness); + // test_case(0, 0.0, f64::INFINITY, skewness); + test_almost(1, 0.3, 2.0318886359, 1e-09, skewness); + // test_case(1, 1.0, f64::INFINITY, skewness); } #[test] fn test_mode() { let mode = |x: NegativeBinomial| x.mode().unwrap(); - test_case(0.0, 0.0, 0.0, mode); - test_case(0.3, 0.0, 0.0, mode); - test_case(1.0, 1.0, 0.0, mode); - test_case(10.0, 0.01, 891.0, mode); + // test_case(0, 0.0, 0.0, mode); + // test_case(0, 0.0, 0.0, mode); + // test_case(1, 1.0, 0.0, mode); + // test_case(10, 0.01, 891.0, mode); } #[test] fn test_min_max() { let min = |x: NegativeBinomial| x.min(); let max = |x: NegativeBinomial| x.max(); - test_case(1.0, 0.5, 0, min); - test_case(1.0, 0.3, u64::MAX, max); + test_case(1, 0.5, 0, min); + test_case(1, 0.3, u64::MAX, max); } #[test] fn test_pmf() { let pmf = |arg: u64| move |x: NegativeBinomial| x.pmf(arg); - test_almost(4.0, 0.5, 0.0625, 1e-8, pmf(0)); - test_almost(4.0, 0.5, 0.15625, 1e-8, pmf(3)); - test_case(1.0, 0.0, 0.0, pmf(0)); - test_case(1.0, 0.0, 0.0, pmf(1)); - test_almost(3.0, 0.2, 0.008, 1e-15, pmf(0)); - test_almost(3.0, 0.2, 0.0192, 1e-15, pmf(1)); - test_almost(3.0, 0.2, 0.04096, 1e-15, pmf(3)); - test_almost(10.0, 0.2, 1.024e-07, 1e-07, pmf(0)); - test_almost(10.0, 0.2, 8.192e-07, 1e-07, pmf(1)); - test_almost(10.0, 0.2, 0.001015706852, 1e-07, pmf(10)); - test_almost(1.0, 0.3, 0.3, 1e-15, pmf(0)); - test_almost(1.0, 0.3, 0.21, 1e-15, pmf(1)); - test_almost(3.0, 0.3, 0.027, 1e-15, pmf(0)); - test_case(0.3, 1.0, 0.0, pmf(1)); - test_case(0.3, 1.0, 0.0, pmf(3)); - test_case_or_nan(0.3, 1.0, f64::NAN, pmf(0)); - test_case(0.3, 1.0, 0.0, pmf(1)); - test_case(0.3, 1.0, 0.0, pmf(10)); - test_case_or_nan(1.0, 1.0, f64::NAN, pmf(0)); - test_case(1.0, 1.0, 0.0, pmf(1)); - test_case_or_nan(3.0, 1.0, f64::NAN, pmf(0)); - test_case(3.0, 1.0, 0.0, pmf(1)); - test_case(3.0, 1.0, 0.0, pmf(3)); - test_case_or_nan(10.0, 1.0, f64::NAN, pmf(0)); - test_case(10.0, 1.0, 0.0, pmf(1)); - test_case(10.0, 1.0, 0.0, pmf(10)); + test_almost(4, 0.5, 0.0625, 1e-8, pmf(0)); + test_almost(4, 0.5, 0.15625, 1e-8, pmf(3)); + // test_case(1, 0.0, 0.0, pmf(0)); + // test_case(1, 0.0, 0.0, pmf(1)); + test_almost(3, 0.2, 0.008, 1e-15, pmf(0)); + test_almost(3, 0.2, 0.0192, 1e-15, pmf(1)); + test_almost(3, 0.2, 0.04096, 1e-15, pmf(3)); + test_almost(10, 0.2, 1.024e-07, 1e-07, pmf(0)); + test_almost(10, 0.2, 8.192e-07, 1e-07, pmf(1)); + test_almost(10, 0.2, 0.001015706852, 1e-07, pmf(10)); + test_almost(1, 0.3, 0.3, 1e-15, pmf(0)); + test_almost(1, 0.3, 0.21, 1e-15, pmf(1)); + test_almost(3, 0.3, 0.027, 1e-15, pmf(0)); + // test_case(0, 1.0, 0.0, pmf(1)); + // test_case(0, 1.0, 0.0, pmf(3)); + // test_case_or_nan(0, 1.0, f64::NAN, pmf(0)); + // test_case(0, 1.0, 0.0, pmf(1)); + // test_case(0, 1.0, 0.0, pmf(10)); + // test_case_or_nan(1, 1.0, f64::NAN, pmf(0)); + // test_case(1, 1.0, 0.0, pmf(1)); + // test_case_or_nan(3, 1.0, f64::NAN, pmf(0)); + // test_case(3, 1.0, 0.0, pmf(1)); + // test_case(3, 1.0, 0.0, pmf(3)); + // test_case_or_nan(10, 1.0, f64::NAN, pmf(0)); + // test_case(10, 1.0, 0.0, pmf(1)); + // test_case(10, 1.0, 0.0, pmf(10)); } #[test] fn test_ln_pmf() { let ln_pmf = |arg: u64| move |x: NegativeBinomial| x.ln_pmf(arg); - test_case(1.0, 0.0, f64::NEG_INFINITY, ln_pmf(0)); - test_case(1.0, 0.0, f64::NEG_INFINITY, ln_pmf(1)); - test_almost(3.0, 0.2, -4.828313737, 1e-08, ln_pmf(0)); - test_almost(3.0, 0.2, -3.952845, 1e-08, ln_pmf(1)); - test_almost(3.0, 0.2, -3.195159298, 1e-08, ln_pmf(3)); - test_almost(10.0, 0.2, -16.09437912, 1e-08, ln_pmf(0)); - test_almost(10.0, 0.2, -14.01493758, 1e-08, ln_pmf(1)); - test_almost(10.0, 0.2, -6.892170503, 1e-08, ln_pmf(10)); - test_almost(1.0, 0.3, -1.203972804, 1e-08, ln_pmf(0)); - test_almost(1.0, 0.3, -1.560647748, 1e-08, ln_pmf(1)); - test_almost(3.0, 0.3, -3.611918413, 1e-08, ln_pmf(0)); - test_case(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(1)); - test_case(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(3)); - test_case_or_nan(0.3, 1.0, f64::NAN, ln_pmf(0)); - test_case(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(1)); - test_case(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(10)); - test_case_or_nan(1.0, 1.0, f64::NAN, ln_pmf(0)); - test_case(1.0, 1.0, f64::NEG_INFINITY, ln_pmf(1)); - test_case_or_nan(3.0, 1.0, f64::NAN, ln_pmf(0)); - test_case(3.0, 1.0, f64::NEG_INFINITY, ln_pmf(1)); - test_case(3.0, 1.0, f64::NEG_INFINITY, ln_pmf(3)); - test_case_or_nan(10.0, 1.0, f64::NAN, ln_pmf(0)); - test_case(10.0, 1.0, f64::NEG_INFINITY, ln_pmf(1)); - test_case(10.0, 1.0, f64::NEG_INFINITY, ln_pmf(10)); + // test_case(1, 0.0, f64::NEG_INFINITY, ln_pmf(0)); + // test_case(1, 0.0, f64::NEG_INFINITY, ln_pmf(1)); + test_almost(3, 0.2, -4.828313737, 1e-08, ln_pmf(0)); + test_almost(3, 0.2, -3.952845, 1e-08, ln_pmf(1)); + test_almost(3, 0.2, -3.195159298, 1e-08, ln_pmf(3)); + test_almost(10, 0.2, -16.09437912, 1e-08, ln_pmf(0)); + test_almost(10, 0.2, -14.01493758, 1e-08, ln_pmf(1)); + test_almost(10, 0.2, -6.892170503, 1e-08, ln_pmf(10)); + test_almost(1, 0.3, -1.203972804, 1e-08, ln_pmf(0)); + test_almost(1, 0.3, -1.560647748, 1e-08, ln_pmf(1)); + test_almost(3, 0.3, -3.611918413, 1e-08, ln_pmf(0)); + // test_case(0, 1.0, f64::NEG_INFINITY, ln_pmf(1)); + // test_case(0, 1.0, f64::NEG_INFINITY, ln_pmf(3)); + // test_case_or_nan(0, 1.0, f64::NAN, ln_pmf(0)); + // test_case(0, 1.0, f64::NEG_INFINITY, ln_pmf(1)); + // test_case(0, 1.0, f64::NEG_INFINITY, ln_pmf(10)); + // test_case_or_nan(1, 1.0, f64::NAN, ln_pmf(0)); + // test_case(1, 1.0, f64::NEG_INFINITY, ln_pmf(1)); + // test_case_or_nan(3, 1.0, f64::NAN, ln_pmf(0)); + // test_case(3, 1.0, f64::NEG_INFINITY, ln_pmf(1)); + // test_case(3, 1.0, f64::NEG_INFINITY, ln_pmf(3)); + // test_case_or_nan(10, 1.0, f64::NAN, ln_pmf(0)); + // test_case(10, 1.0, f64::NEG_INFINITY, ln_pmf(1)); + // test_case(10, 1.0, f64::NEG_INFINITY, ln_pmf(10)); } #[test] fn test_cdf() { let cdf = |arg: u64| move |x: NegativeBinomial| x.cdf(arg); - test_almost(1.0, 0.3, 0.3, 1e-08, cdf(0)); - test_almost(1.0, 0.3, 0.51, 1e-08, cdf(1)); - test_almost(1.0, 0.3, 0.83193, 1e-08, cdf(4)); - test_almost(1.0, 0.3, 0.9802267326, 1e-08, cdf(10)); - test_case(1.0, 1.0, 1.0, cdf(0)); - test_case(1.0, 1.0, 1.0, cdf(1)); - test_almost(10.0, 0.75, 0.05631351471, 1e-08, cdf(0)); - test_almost(10.0, 0.75, 0.1970973015, 1e-08, cdf(1)); - test_almost(10.0, 0.75, 0.9960578583, 1e-08, cdf(10)); + test_almost(1, 0.3, 0.3, 1e-08, cdf(0)); + test_almost(1, 0.3, 0.51, 1e-08, cdf(1)); + test_almost(1, 0.3, 0.83193, 1e-08, cdf(4)); + test_almost(1, 0.3, 0.9802267326, 1e-08, cdf(10)); + // test_case(1, 1.0, 1.0, cdf(0)); + // test_case(1, 1.0, 1.0, cdf(1)); + test_almost(10, 0.75, 0.05631351471, 1e-08, cdf(0)); + test_almost(10, 0.75, 0.1970973015, 1e-08, cdf(1)); + test_almost(10, 0.75, 0.9960578583, 1e-08, cdf(10)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: NegativeBinomial| x.sf(arg); - test_almost(1.0, 0.3, 0.7, 1e-08, sf(0)); - test_almost(1.0, 0.3, 0.49, 1e-08, sf(1)); - test_almost(1.0, 0.3, 0.1680699999999986, 1e-08, sf(4)); - test_almost(1.0, 0.3, 0.019773267430000074, 1e-08, sf(10)); - test_case(1.0, 1.0, 0.0, sf(0)); - test_case(1.0, 1.0, 0.0, sf(1)); - test_almost(10.0, 0.75, 0.9436864852905275, 1e-08, sf(0)); - test_almost(10.0, 0.75, 0.8029026985168456, 1e-08, sf(1)); - test_almost(10.0, 0.75, 0.003942141664083465, 1e-08, sf(10)); + test_almost(1, 0.3, 0.7, 1e-08, sf(0)); + test_almost(1, 0.3, 0.49, 1e-08, sf(1)); + test_almost(1, 0.3, 0.1680699999999986, 1e-08, sf(4)); + test_almost(1, 0.3, 0.019773267430000074, 1e-08, sf(10)); + // test_case(1, 1.0, 0.0, sf(0)); + // test_case(1, 1.0, 0.0, sf(1)); + test_almost(10, 0.75, 0.9436864852905275, 1e-08, sf(0)); + test_almost(10, 0.75, 0.8029026985168456, 1e-08, sf(1)); + test_almost(10, 0.75, 0.003942141664083465, 1e-08, sf(10)); } #[test] fn test_cdf_upper_bound() { let cdf = |arg: u64| move |x: NegativeBinomial| x.cdf(arg); - test_case(3.0, 0.5, 1.0, cdf(100)); + test_case(3, 0.5, 1.0, cdf(100)); } #[test] fn test_discrete() { - test::check_discrete_distribution(&try_create(5.0, 0.3), 35); - test::check_discrete_distribution(&try_create(10.0, 0.7), 21); + test::check_discrete_distribution(&try_create(5, 0.3), 35); + test::check_discrete_distribution(&try_create(10, 0.7), 21); } #[test] fn test_sf_upper_bound() { let sf = |arg: u64| move |x: NegativeBinomial| x.sf(arg); - test_almost(3.0, 0.5, 5.282409836586059e-28, 1e-28, sf(100)); + test_almost(3, 0.5, 5.282409836586059e-28, 1e-28, sf(100)); } } diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index 94e8c6b6..d0795853 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -1,9 +1,22 @@ -use crate::distribution::{ziggurat, Continuous, ContinuousCDF}; +use crate::consts; +use crate::distribution::{ + ziggurat, Continuous, ContinuousCDF, ParametrizationError as ParamError, +}; use crate::function::erf; use crate::statistics::*; -use crate::{consts, Result, StatsError}; use rand::Rng; use std::f64; +use thiserror::Error; + +#[derive(Clone, PartialEq, Debug, Error)] +pub enum NormalError { + #[error("variance must be positive and not nan")] + InvalidStandardDeviation(#[source] ParamError), + #[error("location must be finite and not nan")] + InvalidLocation(#[source] ParamError), + #[error("variance of {0} is degenerate")] + DegenerateVariance(f64), +} /// Implements the [Normal](https://en.wikipedia.org/wiki/Normal_distribution) /// distribution @@ -44,9 +57,23 @@ impl Normal { /// result = Normal::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(mean: f64, std_dev: f64) -> Result { - if mean.is_nan() || std_dev.is_nan() || std_dev <= 0.0 { - Err(StatsError::BadParams) + pub fn new(mean: f64, std_dev: f64) -> Result { + if mean.is_nan() { + Err(NormalError::InvalidLocation(ParamError::ExpectedNotNan)) + } else if std_dev.is_nan() { + Err(NormalError::InvalidStandardDeviation( + ParamError::ExpectedNotNan, + )) + } else if std_dev < 0.0 { + Err(NormalError::InvalidStandardDeviation( + ParamError::ExpectedPositive(std_dev), + )) + } else if std_dev == 0.0 || std_dev.is_infinite() { + Err(NormalError::DegenerateVariance(std_dev)) + } else if mean.is_infinite() { + Err(NormalError::InvalidLocation(ParamError::ExpectedFinite( + mean, + ))) } else { Ok(Normal { mean, std_dev }) } @@ -377,7 +404,6 @@ mod tests { create_case(-5.0, 1.0); create_case(0.0, 10.0); create_case(10.0, 100.0); - create_case(-5.0, f64::INFINITY); } #[test] @@ -395,7 +421,6 @@ mod tests { test_case(0.0, 0.1, 0.1 * 0.1, variance); test_case(0.0, 1.0, 1.0, variance); test_case(0.0, 10.0, 100.0, variance); - test_case(0.0, f64::INFINITY, f64::INFINITY, variance); } #[test] @@ -404,7 +429,6 @@ mod tests { test_almost(0.0, 0.1, -0.8836465597893729422377, 1e-15, entropy); test_case(0.0, 1.0, 1.41893853320467274178, entropy); test_case(0.0, 10.0, 3.721523626198718425798, entropy); - test_case(0.0, f64::INFINITY, f64::INFINITY, entropy); } #[test] @@ -413,7 +437,6 @@ mod tests { test_case(0.0, 0.1, 0.0, skewness); test_case(4.0, 1.0, 0.0, skewness); test_case(0.3, 10.0, 0.0, skewness); - test_case(0.0, f64::INFINITY, 0.0, skewness); } #[test] @@ -424,7 +447,6 @@ mod tests { test_case(0.1, 1.0, 0.1, mode); test_case(1.0, 1.0, 1.0, mode); test_case(-10.0, 1.0, -10.0, mode); - test_case(f64::INFINITY, 1.0, f64::INFINITY, mode); } #[test] @@ -435,7 +457,6 @@ mod tests { test_case(0.1, 1.0, 0.1, median); test_case(1.0, 1.0, 1.0, median); test_case(-0.0, 1.0, -0.0, median); - test_case(f64::INFINITY, 1.0, f64::INFINITY, median); } #[test] @@ -471,9 +492,6 @@ mod tests { test_case(10.0, 100.0, 0.003969525474770117655105, pdf(0.0)); test_almost(10.0, 100.0, 0.002660852498987548218204, 1e-18, pdf(100.0)); test_case(10.0, 100.0, 6.561581477467659126534E-4, pdf(200.0)); - test_case(-5.0, f64::INFINITY, 0.0, pdf(-5.0)); - test_case(-5.0, f64::INFINITY, 0.0, pdf(0.0)); - test_case(-5.0, f64::INFINITY, 0.0, pdf(100.0)); } #[test] @@ -499,9 +517,6 @@ mod tests { test_almost(10.0, 100.0, (0.003969525474770117655105f64).ln(),1e-15, ln_pdf(0.0)); test_almost(10.0, 100.0, (0.002660852498987548218204f64).ln(), 1e-15, ln_pdf(100.0)); test_almost(10.0, 100.0, (6.561581477467659126534E-4f64).ln(), 1e-15, ln_pdf(200.0)); - test_case(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-5.0)); - test_case(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.0)); - test_case(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(100.0)); } #[test] diff --git a/src/error.rs b/src/error.rs index e6f7ca40..9cb9e43d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,7 @@ use std::error::Error; use std::fmt; +#[deprecated(since = "0.18.0", note = "dropping for less general error variants")] /// Enumeration of possible errors thrown within the `statrs` library #[derive(Clone, PartialEq, Debug)] pub enum StatsError {