Skip to content

Commit

Permalink
feat!(errors): propose error hierarchy
Browse files Browse the repository at this point in the history
  • Loading branch information
YeungOnion committed Jun 23, 2024
1 parent 0d3c90c commit 2ee9a3a
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 5 deletions.
36 changes: 35 additions & 1 deletion src/distribution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//! and provides
//! concrete implementations for a variety of distributions.
use super::statistics::{Max, Min};
use crate::StatsError;
use ::num_traits::{Bounded, Float, Num};
use num_traits::{NumAssign, NumAssignOps, NumAssignRef};

Expand Down Expand Up @@ -71,7 +72,40 @@ mod weibull;
mod ziggurat;
mod ziggurat_tables;

use crate::Result;
type Result<T> = std::result::Result<T, DistributionError>;

#[derive(Copy, Clone, PartialEq, Debug)]
pub enum DistributionError {
InvalidConstruction(StatsError),
DegenerateConstruction(f64),
ExpectedProbability(f64),
}

impl std::fmt::Display for DistributionError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::InvalidConstruction(_) => {
write!(f, "provided value does not specify valid distribution")
}
Self::DegenerateConstruction(_) => write!(
f,
"provided value represents degenerate distribution, see statrs-dev/statrs#102"
),
Self::ExpectedProbability(p) => write!(f, "expected probability, got {p:.3e}"),
}
}
}

impl std::error::Error for DistributionError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
use core::ops::Bound::Included;
match self {
Self::InvalidConstruction(e) => Some(e),
Self::DegenerateConstruction(_) => None,
Self::ExpectedProbability(_) => None,
}
}
}

/// The `ContinuousCDF` trait is used to specify an interface for univariate
/// distributions for which cdf float arguments are sensible.
Expand Down
74 changes: 71 additions & 3 deletions src/distribution/normal.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use super::DistributionError as DistrError;
use crate::distribution::{ziggurat, Continuous, ContinuousCDF};
use crate::function::erf;
use crate::statistics::*;
use crate::{consts, Result, StatsError};
use crate::{consts, StatsError};
use rand::Rng;
use std::f64;
use std::ops::Bound;

/// Implements the [Normal](https://en.wikipedia.org/wiki/Normal_distribution)
/// distribution
Expand All @@ -24,6 +26,36 @@ pub struct Normal {
std_dev: f64,
}

#[derive(Copy, Clone, PartialEq, Debug)]
pub enum Error {
InvalidMean(DistrError),
InvalidStdDev(DistrError),
}

impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::InvalidMean(_) => write!(f, "expected finite and not nan mean"),
Self::InvalidStdDev(_) => write!(
f,
"expected finite, positive, and not nan standard deviation"
),
}
}
}

impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
Some(match self {
Self::InvalidMean(e) => e,
Self::InvalidStdDev(e) => e,
})
}
}

type Result<T> = std::result::Result<T, Error>;
const POSITIVE_RANGE: (Bound<f64>, Bound<f64>) = (Bound::Excluded(0.0), Bound::Unbounded);

impl Normal {
/// Constructs a new normal distribution with a mean of `mean`
/// and a standard deviation of `std_dev`
Expand All @@ -45,8 +77,22 @@ impl Normal {
/// assert!(result.is_err());
/// ```
pub fn new(mean: f64, std_dev: f64) -> Result<Normal> {
if mean.is_nan() || std_dev.is_nan() || std_dev <= 0.0 {
Err(StatsError::BadParams)
if mean.is_nan() {
Err(Error::InvalidMean(DistrError::InvalidConstruction(
StatsError::Finite(mean),
)))
} else if std_dev.is_nan() {
Err(Error::InvalidStdDev(DistrError::InvalidConstruction(
StatsError::NotNan,
)))
} else if std_dev == 0.0 {
Err(Error::InvalidStdDev(DistrError::DegenerateConstruction(
0.0,
)))
} else if std_dev < 0.0 {
Err(Error::InvalidStdDev(DistrError::InvalidConstruction(
StatsError::Bounded(POSITIVE_RANGE, std_dev),
)))
} else {
Ok(Normal { mean, std_dev })
}
Expand Down Expand Up @@ -334,6 +380,7 @@ impl std::default::Default for Normal {
#[rustfmt::skip]
#[cfg(test)]
mod tests {
use super::*;
use crate::statistics::*;
use crate::distribution::{ContinuousCDF, Continuous, Normal};
use crate::distribution::internal::*;
Expand Down Expand Up @@ -563,4 +610,25 @@ mod tests {
// Check that the standard deviation of the distribution is close to 1
assert_almost_eq!(n_std, 1.0, 1e-15);
}

#[test]
fn test_errors() {
let n = Normal::new(f64::NAN, f64::INFINITY);
assert!(matches!(n.err().unwrap(),
Error::InvalidMean(
DistrError::InvalidConstruction(
StatsError::NotNan | StatsError::Finite(_)
)
)
));

let n = Normal::new(0.0, 0.0);
assert!(matches!(n.err().unwrap(),
Error::InvalidStdDev(
DistrError::InvalidConstruction(
StatsError::FiniteNonNegative(_)
)
)
));
}
}
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::fmt;
use std::ops::Bound;

/// Enumeration of possible errors thrown within the `statrs` library
#[derive(Clone, PartialEq, Debug)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub enum StatsError {
/// Generic bad input parameter error
BadParams,
Expand Down

0 comments on commit 2ee9a3a

Please sign in to comment.