Skip to content

Commit

Permalink
refactor(errors): use thiserror instead of manual impls
Browse files Browse the repository at this point in the history
  • Loading branch information
YeungOnion committed Jun 23, 2024
1 parent df9c2f9 commit 1bc312a
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 100 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ rand = "0.8"
nalgebra = { version = "0.32", features = ["rand"] }
approx = "0.5.0"
num-traits = "0.2.14"
thiserror = "1.0.61"

[dev-dependencies]
criterion = "0.3.3"
33 changes: 5 additions & 28 deletions src/distribution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,39 +72,16 @@ mod weibull;
mod ziggurat;
mod ziggurat_tables;

#[derive(Copy, Clone, PartialEq, Debug)]
#[derive(Copy, Clone, PartialEq, Debug, thiserror::Error)]
pub enum DistributionError {
InvalidConstruction(StatsError),
#[error("provided value does not specify valid distribution")]
InvalidConstruction(#[source] StatsError),
#[error("provided value represents degenerate distribution, see statrs-dev/statrs#102")]
DegenerateConstruction(f64),
#[error("expected probability, got {:.3e}", .0)]
ExpectedProbability(f64),
}

impl std::fmt::Display for DistributionError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::InvalidConstruction(_) => {
write!(f, "provided value does not specify valid distribution")
}
Self::DegenerateConstruction(_) => write!(
f,
"provided value represents degenerate distribution, see statrs-dev/statrs#102"
),
Self::ExpectedProbability(p) => write!(f, "expected probability, got {p:.3e}"),
}
}
}

impl std::error::Error for DistributionError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
use core::ops::Bound::Included;
match self {
Self::InvalidConstruction(e) => Some(e),
Self::DegenerateConstruction(_) => None,
Self::ExpectedProbability(_) => None,
}
}
}

/// The `ContinuousCDF` trait is used to specify an interface for univariate
/// distributions for which cdf float arguments are sensible.
pub trait ContinuousCDF<K: Float, T: Float>: Min<K> + Max<K> {
Expand Down
29 changes: 5 additions & 24 deletions src/distribution/normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,12 @@ pub struct Normal {
std_dev: f64,
}

#[derive(Copy, Clone, PartialEq, Debug)]
#[derive(Copy, Clone, PartialEq, Debug, thiserror::Error)]
pub enum Error {
InvalidMean(DistrError),
InvalidStdDev(DistrError),
}

impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::InvalidMean(_) => write!(f, "expected finite and not nan mean"),
Self::InvalidStdDev(_) => write!(
f,
"expected finite, positive, and not nan standard deviation"
),
}
}
}

impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
Some(match self {
Self::InvalidMean(e) => e,
Self::InvalidStdDev(e) => e,
})
}
#[error("expected finite and not nan mean")]
InvalidMean(#[source] DistrError),
#[error("expected finite, positive, and not nan standard deviation")]
InvalidStdDev(#[source] DistrError),
}

type Result<T> = std::result::Result<T, Error>;
Expand Down
61 changes: 13 additions & 48 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,69 +3,34 @@ use std::fmt;
use std::ops::Bound;

/// Enumeration of possible errors thrown within the `statrs` library
#[derive(Copy, Clone, PartialEq, Debug)]
#[derive(Copy, Clone, PartialEq, Debug, thiserror::Error)]
pub enum StatsError {
/// Generic bad input parameter error
#[error("Bad parameters, unspecified")]
BadParams,
/// value must not be NAN
#[error("value must not be NAN")]
NotNan,
/// value must be finite and must not be NAN
#[error("given `{}`, but must be finite and not NAN", .0)]
Finite(f64),
/// value must be finite, non negative and must not be NAN
#[error("given `{}`, but must be finite, non-negative and not NAN", .0)]
FiniteNonNegative(f64),
/// value must be within specified bounds
#[error("given `{}`, but must be on interval {:?}", .1, .0)]
Bounded((Bound<f64>, Bound<f64>), f64),
/// first value must be within bounds defined by second value
#[error("given `{}`, but another value {} requires it be on {:?}", .1, .2, .0)]
ParametrizedBounded((Bound<f64>, Bound<f64>), f64, f64),
/// Expected one iterator to not exhaust before another
#[error("Iterator exhausted earlier than expected")]
IteratorExhaustedEarly,
/// Containers of the same length were expected
#[error("Expected containers of same length, found one len=`{}`", .0)]
ContainersMustBeSameLength(usize),
/// Computation failed to converge,
#[error("Computation failed to converge, last iteration reached `{}` but stepped relative prec `{}`", .0, .1)]
FailedConvergence(f64, f64),
/// Elements in a container were expected to sum to a value but didn't
#[error("sum found to be {}, expected {}", .0, .1)]
ContainerExpectedSum(f64, f64),
/// Elements in a container were expected to sum to a variable but didn't
#[error("sum found to be {}, but other value specifies should be {}", .0, .1)]
ContainerExpectedSumVar(f64, f64),
/// Special case exception
#[error("{}", .0)]
SpecialCase(&'static str),
}

impl Error for StatsError {}

impl fmt::Display for StatsError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
StatsError::BadParams => write!(f, "Bad parameters, unspecified"),
StatsError::NotNan => write!(f, "value must not be NAN"),
StatsError::Finite(x) => write!(f, "given `{}`, but must be finite and not NAN", x),
StatsError::FiniteNonNegative(x) => write!(f, "given `{}`, but must be finite, non-negative and not NAN", x),
StatsError::Bounded(bound, x) => {
write!(f, "given `{}`, but must be on interval {:?}", x, bound)
}
StatsError::ParametrizedBounded(bound, x, y) => write!(
f,
"given `{}`, but another value {} requires it be on {:?}",
x, y, bound
),
StatsError::ContainersMustBeSameLength(size) => write!(
f,
"Expected containers of same length, found only one of size `{}`",
size
),
StatsError::FailedConvergence(x,prec) => write!(f, "Computation failed to converge, last iteration reached `{}` but stepped relative prec `{}`", x, prec),
StatsError::IteratorExhaustedEarly => write!(f, "Iterator exhausted earlier than expected"),
StatsError::ContainerExpectedSum(s, sum) => {
write!(f, "sum found to be {}, expected {}", s, sum)
}
StatsError::ContainerExpectedSumVar(s, sum) => {
write!(f, "sum found to be {}, but other value specifies should be {}", s, sum)
}
StatsError::SpecialCase(s) => write!(f, "{}", s),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit 1bc312a

Please sign in to comment.