diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index 56deb09a..b3189276 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -2,6 +2,7 @@ //! and provides //! concrete implementations for a variety of distributions. use super::statistics::{Max, Min}; +use crate::StatsError; use ::num_traits::{Bounded, Float, Num}; use num_traits::{NumAssign, NumAssignOps, NumAssignRef}; @@ -71,7 +72,40 @@ mod weibull; mod ziggurat; mod ziggurat_tables; -use crate::Result; +type Result = std::result::Result; + +#[derive(Copy, Clone, PartialEq, Debug)] +pub enum DistributionError { + InvalidConstruction(StatsError), + DegenerateConstruction(f64), + ExpectedProbability(f64), +} + +impl std::fmt::Display for DistributionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::InvalidConstruction(_) => { + write!(f, "provided value does not specify valid distribution") + } + Self::DegenerateConstruction(_) => write!( + f, + "provided value represents degenerate distribution, see statrs-dev/statrs#102" + ), + Self::ExpectedProbability(p) => write!(f, "expected probability, got {p:.3e}"), + } + } +} + +impl std::error::Error for DistributionError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + use core::ops::Bound::Included; + match self { + Self::InvalidConstruction(e) => Some(e), + Self::DegenerateConstruction(_) => None, + Self::ExpectedProbability(_) => None, + } + } +} /// The `ContinuousCDF` trait is used to specify an interface for univariate /// distributions for which cdf float arguments are sensible. diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index 94e8c6b6..44e3338f 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -1,9 +1,11 @@ +use super::DistributionError as DistrError; use crate::distribution::{ziggurat, Continuous, ContinuousCDF}; use crate::function::erf; use crate::statistics::*; -use crate::{consts, Result, StatsError}; +use crate::{consts, StatsError}; use rand::Rng; use std::f64; +use std::ops::Bound; /// Implements the [Normal](https://en.wikipedia.org/wiki/Normal_distribution) /// distribution @@ -24,6 +26,36 @@ pub struct Normal { std_dev: f64, } +#[derive(Copy, Clone, PartialEq, Debug)] +pub enum Error { + InvalidMean(DistrError), + InvalidStdDev(DistrError), +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::InvalidMean(_) => write!(f, "expected finite and not nan mean"), + Self::InvalidStdDev(_) => write!( + f, + "expected finite, positive, and not nan standard deviation" + ), + } + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(match self { + Self::InvalidMean(e) => e, + Self::InvalidStdDev(e) => e, + }) + } +} + +type Result = std::result::Result; +const POSITIVE_RANGE: (Bound, Bound) = (Bound::Excluded(0.0), Bound::Unbounded); + impl Normal { /// Constructs a new normal distribution with a mean of `mean` /// and a standard deviation of `std_dev` @@ -45,8 +77,22 @@ impl Normal { /// assert!(result.is_err()); /// ``` pub fn new(mean: f64, std_dev: f64) -> Result { - if mean.is_nan() || std_dev.is_nan() || std_dev <= 0.0 { - Err(StatsError::BadParams) + if mean.is_nan() { + Err(Error::InvalidMean(DistrError::InvalidConstruction( + StatsError::Finite(mean), + ))) + } else if std_dev.is_nan() { + Err(Error::InvalidStdDev(DistrError::InvalidConstruction( + StatsError::NotNan, + ))) + } else if std_dev == 0.0 { + Err(Error::InvalidStdDev(DistrError::DegenerateConstruction( + 0.0, + ))) + } else if std_dev < 0.0 { + Err(Error::InvalidStdDev(DistrError::InvalidConstruction( + StatsError::Bounded(POSITIVE_RANGE, std_dev), + ))) } else { Ok(Normal { mean, std_dev }) } @@ -334,6 +380,7 @@ impl std::default::Default for Normal { #[rustfmt::skip] #[cfg(test)] mod tests { + use super::*; use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Normal}; use crate::distribution::internal::*; @@ -563,4 +610,25 @@ mod tests { // Check that the standard deviation of the distribution is close to 1 assert_almost_eq!(n_std, 1.0, 1e-15); } + + #[test] + fn test_errors() { + let n = Normal::new(f64::NAN, f64::INFINITY); + assert!(matches!(n.err().unwrap(), + Error::InvalidMean( + DistrError::InvalidConstruction( + StatsError::NotNan | StatsError::Finite(_) + ) + ) + )); + + let n = Normal::new(0.0, 0.0); + assert!(matches!(n.err().unwrap(), + Error::InvalidStdDev( + DistrError::InvalidConstruction( + StatsError::FiniteNonNegative(_) + ) + ) + )); + } } diff --git a/src/error.rs b/src/error.rs index a18aebbd..08fda6c8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -3,7 +3,7 @@ use std::fmt; use std::ops::Bound; /// Enumeration of possible errors thrown within the `statrs` library -#[derive(Clone, PartialEq, Debug)] +#[derive(Copy, Clone, PartialEq, Debug)] pub enum StatsError { /// Generic bad input parameter error BadParams,