Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hierarchy of errors #247

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ rand = "0.8"
nalgebra = { version = "0.32", features = ["rand"] }
approx = "0.5.0"
num-traits = "0.2.14"
thiserror = "1.0.63"

[dev-dependencies]
criterion = "0.3.3"
Expand Down
96 changes: 96 additions & 0 deletions examples/error_trials.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
extern crate statrs;

use anyhow::{anyhow, Context, Result as AnyhowResult};
use statrs::distribution::{
Continuous, Discrete, Gamma, GammaError, NegativeBinomial, Normal, ParametrizationError,
};

pub fn main() -> AnyhowResult<()> {
gamma_pdf(1.0, 1.0, 1.0).map(|x| println!("val = {}", x))?;
// val = 0.36787944117144233

normal_pdf(1.0, 1.0, 1.0).map(|x| println!("val = {}", x))?;
// val = 0.39894228040143265

gamma_pdf_with_negative_shape_correction(-0.5, 1.0, 1.0).map(|x| println!("val = {}", x))?;
// without shape correction would emit, the below
// Error: failed creating gamma(-0.5,1)
//
// Caused by:
// 0: shape must be finite, positive, and not nan
// 1: expected positive, got -0.5
// after re-attempt, output is
// Error: gamma provided invalid shape
// attempting to correct shape to 0.5
// val = 0.2075537487102974

nb_pmf(1, 1.0, 1).map(|x| println!("val = {}", x))?;
// Error: failed creating nb(1,1)
//
// Caused by:
// mean of 0 is degenerate

nb_pmf(1, 0., 1).map(|x| println!("val = {}", x))?;
// Error: failed creating nb(1,0)
//
// Caused by:
// mean of inf is degenerate

normal_pdf(1.0, f64::INFINITY, 1.0).map(|x| println!("val = {}", x))?;
// Error: failed creating normal(1, inf)
//
// Caused by:
// variance of inf is degenrate

normal_pdf(1.0, 0.0, 1.0).map(|x| println!("val = {}", x))?;
// Error: failed creating normal(1, 0)
//
// Caused by:
// variance of 0 is degenerate

Ok(())
}

pub fn gamma_pdf(shape: f64, rate: f64, x: f64) -> AnyhowResult<f64> {
Ok(Gamma::new(shape, rate)
.context(format!("failed creating gamma({},{})", shape, rate))?
.pdf(x))
}

pub fn gamma_pdf_with_negative_shape_correction(
shape: f64,
rate: f64,
x: f64,
) -> AnyhowResult<f64> {
match gamma_pdf(shape, rate, x) {
Ok(x) => Ok(x),
Err(ee) => {
if let GammaError::InvalidShape(e) = ee.downcast::<GammaError>()? {
eprintln!("Error: gamma provided invalid shape");
if let ParametrizationError::ExpectedPositive(shape) = e {
eprintln!("\tattempting to correct shape to {}", shape.abs());
// fails again for 0 and INF
gamma_pdf(shape.abs(), rate, x)
} else {
Err(anyhow!("cannot recover valid shape from this error"))
}
} else {
Err(anyhow!(
"cannot recover both valid shape and rate from this error"
))
}
}
}
}

pub fn nb_pmf(r: u64, p: f64, x: u64) -> AnyhowResult<f64> {
Ok(NegativeBinomial::new(r, p)
.context(format!("failed creating nb({},{})", r, p))?
.pmf(x))
}

pub fn normal_pdf(location: f64, scale: f64, x: f64) -> AnyhowResult<f64> {
Ok(Normal::new(location, scale)
.context(format!("failed creating normal({}, {})", location, scale))?
.pdf(x))
}
3 changes: 1 addition & 2 deletions src/distribution/chi_squared.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::distribution::{Continuous, ContinuousCDF, Gamma};
use crate::statistics::*;
use crate::Result;
use rand::Rng;
use std::f64;

Expand Down Expand Up @@ -48,7 +47,7 @@ impl ChiSquared {
/// result = ChiSquared::new(0.0);
/// assert!(result.is_err());
/// ```
pub fn new(freedom: f64) -> Result<ChiSquared> {
pub fn new(freedom: f64) -> Result<ChiSquared, super::gamma::GammaError> {
Gamma::new(freedom / 2.0, 0.5).map(|g| ChiSquared { freedom, g })
}

Expand Down
4 changes: 1 addition & 3 deletions src/distribution/erlang.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::distribution::{Continuous, ContinuousCDF, Gamma};
use crate::statistics::*;
use crate::Result;
use rand::Rng;

/// Implements the [Erlang](https://en.wikipedia.org/wiki/Erlang_distribution)
Expand Down Expand Up @@ -45,7 +44,7 @@ impl Erlang {
/// result = Erlang::new(0, 0.0);
/// assert!(result.is_err());
/// ```
pub fn new(shape: u64, rate: f64) -> Result<Erlang> {
pub fn new(shape: u64, rate: f64) -> Result<Erlang, super::gamma::GammaError> {
Gamma::new(shape as f64, rate).map(|g| Erlang { g })
}

Expand Down Expand Up @@ -304,7 +303,6 @@ mod tests {
create_case(1, 1.0);
create_case(10, 10.0);
create_case(10, 1.0);
create_case(10, f64::INFINITY);
}

#[test]
Expand Down
110 changes: 75 additions & 35 deletions src/distribution/gamma.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,78 @@
use crate::distribution::{Continuous, ContinuousCDF};
use crate::distribution::{Continuous, ContinuousCDF, ParametrizationError as ParamError};
use crate::function::gamma;
use crate::prec;
use crate::statistics::*;
use crate::{Result, StatsError};
use rand::Rng;
use thiserror::Error;

#[derive(Clone, PartialEq, Debug, Error)]
pub enum GammaError {
#[error("shape must be finite, positive, and not nan")]
InvalidShape(#[source] ParamError<f64>),
#[error("shape must be finite, positive, and not nan")]
InvalidRate(#[source] ParamError<f64>),
#[error("rate of {0} is degenerate")]
DegenerateRate(f64),
#[error("shape of {0} is degenerate")]
DegenerateShape(f64),
}

impl From<super::negative_binomial::NegativeBinomialError> for GammaError {
fn from(value: super::negative_binomial::NegativeBinomialError) -> Self {
use super::negative_binomial::NegativeBinomialError::*;
match value {
InvalidMean(e) => Self::InvalidShape(e),
InvalidProbability(p) => {
if p.is_nan() {
Self::InvalidRate(ParamError::ExpectedNotNan)
} else {
Self::InvalidRate(ParamError::ExpectedPositive(p / (1.0 - p)))
}
}
InvalidSuccessCount(e) => Self::InvalidRate(e.into()),
DegenerateMean(m) => Self::DegenerateShape(m),
DegenerateProbability(p) => Self::DegenerateRate(p / (1.0 - p)),
DegenerateSuccessCount => Self::DegenerateRate(0.0),
}
}
}

/// holds a valid parametrization of the gamma distribution in shape and rate.
pub struct Parameters {
shape: f64,
rate: f64,
}

Comment on lines +41 to +44
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain why you added this abstraction? You could also check the parameters inside Gamma::new and return a suitable error from there, right?

Copy link
Contributor Author

@YeungOnion YeungOnion Jul 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly just playing around with having input checking in another type to support the use of infallible new methods.

But more interestingly perhaps it could be a way to get alternative parametrizations #198 as well as new_from_moments to do something like,

let m = Gamma::new(shape, rate).into_moments();
let p = Poisson::new_from_moments(m);

But I've lots of thoughts floating around about that none motivated by something I've actually needed thus far, so happy to drop it until specific need arises. As a solution to #198 could be adding just one more method.

impl Parameters {
pub fn new(shape: f64, rate: f64) -> Result<Self, GammaError> {
if shape.is_nan() {
Err(GammaError::InvalidShape(ParamError::ExpectedNotNan))
} else if rate.is_nan() {
Err(GammaError::InvalidRate(ParamError::ExpectedNotNan))
} else if rate <= 0.0 {
Err(GammaError::InvalidRate(ParamError::ExpectedPositive(rate)))
} else if shape <= 0.0 {
Err(GammaError::InvalidShape(ParamError::ExpectedPositive(
shape,
)))
} else if rate.is_infinite() {
Err(GammaError::DegenerateRate(rate))
} else if shape.is_infinite() {
Err(GammaError::DegenerateShape(shape))
} else {
Ok(Self { shape, rate })
}
}
}

impl From<Parameters> for Gamma {
fn from(value: Parameters) -> Self {
Gamma {
shape: value.shape,
rate: value.rate,
}
}
}

/// Implements the [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution)
/// distribution
Expand Down Expand Up @@ -45,16 +114,8 @@ impl Gamma {
/// result = Gamma::new(0.0, 0.0);
/// assert!(result.is_err());
/// ```
pub fn new(shape: f64, rate: f64) -> Result<Gamma> {
if shape.is_nan()
|| rate.is_nan()
|| shape.is_infinite() && rate.is_infinite()
|| shape <= 0.0
|| rate <= 0.0
{
return Err(StatsError::BadParams);
}
Ok(Gamma { shape, rate })
pub fn new(shape: f64, rate: f64) -> Result<Gamma, GammaError> {
Ok(Parameters::new(shape, rate)?.into())
}

/// Returns the shape (α) of the gamma distribution
Expand Down Expand Up @@ -414,13 +475,7 @@ mod tests {

#[test]
fn test_create() {
let valid = [
(1.0, 0.1),
(1.0, 1.0),
(10.0, 10.0),
(10.0, 1.0),
(10.0, f64::INFINITY),
];
let valid = [(1.0, 0.1), (1.0, 1.0), (10.0, 10.0), (10.0, 1.0)];

for (s, r) in valid {
try_create(s, r);
Expand Down Expand Up @@ -450,7 +505,6 @@ mod tests {
((1.0, 1.0), 1.0),
((10.0, 10.0), 1.0),
((10.0, 1.0), 10.0),
((10.0, f64::INFINITY), 0.0),
];
for ((s, r), res) in test {
test_case(s, r, res, f);
Expand All @@ -465,7 +519,6 @@ mod tests {
((1.0, 1.0), 1.0),
((10.0, 10.0), 0.1),
((10.0, 1.0), 10.0),
((10.0, f64::INFINITY), 0.0),
];
for ((s, r), res) in test {
test_case(s, r, res, f);
Expand All @@ -480,7 +533,6 @@ mod tests {
((1.0, 1.0), 1.0),
((10.0, 10.0), 0.2334690854869339583626209),
((10.0, 1.0), 2.53605417848097964238061239),
((10.0, f64::INFINITY), f64::NEG_INFINITY),
];
for ((s, r), res) in test {
test_case(s, r, res, f);
Expand All @@ -495,7 +547,6 @@ mod tests {
((1.0, 1.0), 2.0),
((10.0, 10.0), 0.6324555320336758663997787),
((10.0, 1.0), 0.63245553203367586639977870),
((10.0, f64::INFINITY), 0.6324555320336758),
];
for ((s, r), res) in test {
test_case(s, r, res, f);
Expand All @@ -509,11 +560,7 @@ mod tests {
for &((s, r), res) in test.iter() {
test_case_special(s, r, res, 10e-6, f);
}
let test = [
((10.0, 10.0), 0.9),
((10.0, 1.0), 9.0),
((10.0, f64::INFINITY), 0.0),
];
let test = [((10.0, 10.0), 0.9), ((10.0, 1.0), 9.0)];
for ((s, r), res) in test {
test_case(s, r, res, f);
}
Expand All @@ -527,7 +574,6 @@ mod tests {
((1.0, 1.0), 0.0),
((10.0, 10.0), 0.0),
((10.0, 1.0), 0.0),
((10.0, f64::INFINITY), 0.0),
];
for ((s, r), res) in test {
test_case(s, r, res, f);
Expand All @@ -538,7 +584,6 @@ mod tests {
((1.0, 1.0), f64::INFINITY),
((10.0, 10.0), f64::INFINITY),
((10.0, 1.0), f64::INFINITY),
((10.0, f64::INFINITY), f64::INFINITY),
];
for ((s, r), res) in test {
test_case(s, r, res, f);
Expand Down Expand Up @@ -585,7 +630,6 @@ mod tests {
((10.0, 10.0), 10.0, -69.0527107131946016148658),
((10.0, 1.0), 1.0, -13.8018274800814696112077),
((10.0, 1.0), 10.0, -2.07856164313505845504579),
((10.0, f64::INFINITY), f64::INFINITY, f64::NEG_INFINITY),
];
for ((s, r), x, res) in test {
test_case(s, r, res, f(x));
Expand All @@ -606,8 +650,6 @@ mod tests {
((10.0, 10.0), 10.0, 0.999999999999999999999999),
((10.0, 1.0), 1.0, 0.000000111425478338720677),
((10.0, 1.0), 10.0, 0.542070285528147791685835),
((10.0, f64::INFINITY), 1.0, 0.0),
((10.0, f64::INFINITY), 10.0, 1.0),
];
for ((s, r), x, res) in test {
test_case(s, r, res, f(x));
Expand Down Expand Up @@ -657,8 +699,6 @@ mod tests {
((10.0, 10.0), 10.0, 1.1253473960842808e-31),
((10.0, 1.0), 1.0, 0.9999998885745217),
((10.0, 1.0), 10.0, 0.4579297144718528),
((10.0, f64::INFINITY), 1.0, 1.0),
((10.0, f64::INFINITY), 10.0, 0.0),
];
for ((s, r), x, res) in test {
test_case(s, r, res, f(x));
Expand Down
Loading
Loading