Skip to content

Commit

Permalink
feat!: Concrete errors for <Distribution>::new
Browse files Browse the repository at this point in the history
Includes notable changes:
- Add internal `test_create_err` function
- Use `Beta` in unit tests for testing_boiler
- Validate Dirichlet params inside `new` and
  remove `is_valid_alpha` function
- Use `Result<Empirical, ()>` for Empirical
  (infallible `::new` function)
- Validate Multinomial params inside `new`
  and remove `check_multinomial` function
- Add a concrete error type to fisher's exact
  test, too (it is dependent on Hypergeometric,
  which is why it's included in this change)
  • Loading branch information
FreezyLemon committed Sep 8, 2024
1 parent a0716b3 commit 7349dff
Show file tree
Hide file tree
Showing 34 changed files with 1,233 additions and 466 deletions.
8 changes: 3 additions & 5 deletions src/distribution/bernoulli.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::distribution::{Binomial, Discrete, DiscreteCDF};
use crate::distribution::{Binomial, BinomialError, Discrete, DiscreteCDF};
use crate::statistics::*;
use crate::Result;
use rand::Rng;

/// Implements the
Expand Down Expand Up @@ -45,7 +44,7 @@ impl Bernoulli {
/// result = Bernoulli::new(-0.5);
/// assert!(result.is_err());
/// ```
pub fn new(p: f64) -> Result<Bernoulli> {
pub fn new(p: f64) -> Result<Bernoulli, BinomialError> {
Binomial::new(p, 1).map(|b| Bernoulli { b })
}

Expand Down Expand Up @@ -266,10 +265,9 @@ impl Discrete<u64, f64> for Bernoulli {
#[cfg(test)]
mod testing {
use super::*;
use crate::StatsError;
use crate::testing_boiler;

testing_boiler!(p: f64; Bernoulli; StatsError);
testing_boiler!(p: f64; Bernoulli; BinomialError);

#[test]
fn test_create() {
Expand Down
51 changes: 40 additions & 11 deletions src/distribution/beta.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::distribution::{Continuous, ContinuousCDF};
use crate::function::{beta, gamma};
use crate::statistics::*;
use crate::{Result, StatsError};
use rand::Rng;

/// Implements the [Beta](https://en.wikipedia.org/wiki/Beta_distribution)
Expand All @@ -24,6 +23,32 @@ pub struct Beta {
shape_b: f64,
}

/// Represents the errors that can occur when creating a [`Beta`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum BetaError {
/// Shape A is NaN, zero or negative.
ShapeAInvalid,

/// Shape B is NaN, zero or negative.
ShapeBInvalid,

/// Shape A and Shape B are infinite.
BothShapesInfinite,
}

impl std::fmt::Display for BetaError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
BetaError::ShapeAInvalid => write!(f, "Shape A is NaN, zero or negative"),
BetaError::ShapeBInvalid => write!(f, "Shape B is NaN, zero or negative"),
BetaError::BothShapesInfinite => write!(f, "Shape A and shape B are infinite"),
}
}
}

impl std::error::Error for BetaError {}

impl Beta {
/// Constructs a new beta distribution with shapeA (α) of `shape_a`
/// and shapeB (β) of `shape_b`
Expand All @@ -44,15 +69,19 @@ impl Beta {
/// result = Beta::new(0.0, 0.0);
/// assert!(result.is_err());
/// ```
pub fn new(shape_a: f64, shape_b: f64) -> Result<Beta> {
if shape_a.is_nan()
|| shape_b.is_nan()
|| shape_a.is_infinite() && shape_b.is_infinite()
|| shape_a <= 0.0
|| shape_b <= 0.0
{
return Err(StatsError::BadParams);
};
pub fn new(shape_a: f64, shape_b: f64) -> Result<Beta, BetaError> {
if shape_a.is_nan() || shape_a <= 0.0 {
return Err(BetaError::ShapeAInvalid);
}

if shape_b.is_nan() || shape_b <= 0.0 {
return Err(BetaError::ShapeBInvalid);
}

if shape_a.is_infinite() && shape_b.is_infinite() {
return Err(BetaError::BothShapesInfinite);
}

Ok(Beta { shape_a, shape_b })
}

Expand Down Expand Up @@ -433,7 +462,7 @@ mod tests {
use super::super::internal::*;
use crate::testing_boiler;

testing_boiler!(a: f64, b: f64; Beta; StatsError);
testing_boiler!(a: f64, b: f64; Beta; BetaError);

#[test]
fn test_create() {
Expand Down
25 changes: 21 additions & 4 deletions src/distribution/binomial.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::distribution::{Discrete, DiscreteCDF};
use crate::function::{beta, factorial};
use crate::statistics::*;
use crate::{Result, StatsError};
use rand::Rng;
use std::f64;

Expand All @@ -26,6 +25,24 @@ pub struct Binomial {
n: u64,
}

/// Represents the errors that can occur when creating a [`Binomial`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum BinomialError {
/// The probability is NaN or not in `[0, 1]`.
ProbabilityInvalid,
}

impl std::fmt::Display for BinomialError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
BinomialError::ProbabilityInvalid => write!(f, "Probability is NaN or not in [0, 1]"),
}
}
}

impl std::error::Error for BinomialError {}

impl Binomial {
/// Constructs a new binomial distribution
/// with a given `p` probability of success of `n`
Expand All @@ -47,9 +64,9 @@ impl Binomial {
/// result = Binomial::new(-0.5, 5);
/// assert!(result.is_err());
/// ```
pub fn new(p: f64, n: u64) -> Result<Binomial> {
pub fn new(p: f64, n: u64) -> Result<Binomial, BinomialError> {
if p.is_nan() || !(0.0..=1.0).contains(&p) {
Err(StatsError::BadParams)
Err(BinomialError::ProbabilityInvalid)
} else {
Ok(Binomial { p, n })
}
Expand Down Expand Up @@ -332,7 +349,7 @@ mod tests {
use crate::distribution::internal::*;
use crate::testing_boiler;

testing_boiler!(p: f64, n: u64; Binomial; StatsError);
testing_boiler!(p: f64, n: u64; Binomial; BinomialError);

#[test]
fn test_create() {
Expand Down
88 changes: 68 additions & 20 deletions src/distribution/categorical.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::distribution::{Discrete, DiscreteCDF};
use crate::statistics::*;
use crate::{Result, StatsError};
use rand::Rng;
use std::f64;

Expand All @@ -27,6 +26,35 @@ pub struct Categorical {
sf: Vec<f64>,
}

/// Represents the errors that can occur when creating a [`Categorical`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum CategoricalError {
/// The probability mass is empty.
ProbMassEmpty,

/// The probabilities sums up to zero.
ProbMassSumZero,

/// The probability mass contains at least one element which is NaN or less than zero.
ProbMassHasInvalidElements,
}

impl std::fmt::Display for CategoricalError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
CategoricalError::ProbMassEmpty => write!(f, "Probability mass is empty"),
CategoricalError::ProbMassSumZero => write!(f, "Probabilities sum up to zero"),
CategoricalError::ProbMassHasInvalidElements => write!(
f,
"Probability mass contains at least one element which is NaN or less than zero"
),
}
}
}

impl std::error::Error for CategoricalError {}

impl Categorical {
/// Constructs a new categorical distribution
/// with the probabilities masses defined by `prob_mass`
Expand All @@ -52,23 +80,36 @@ impl Categorical {
/// result = Categorical::new(&[0.0, -1.0, 2.0]);
/// assert!(result.is_err());
/// ```
pub fn new(prob_mass: &[f64]) -> Result<Categorical> {
if !super::internal::is_valid_multinomial(prob_mass, true) {
Err(StatsError::BadParams)
} else {
// extract un-normalized cdf
let cdf = prob_mass_to_cdf(prob_mass);
// extract un-normalized sf
let sf = cdf_to_sf(&cdf);
// extract normalized probability mass
let sum = cdf[cdf.len() - 1];
let mut norm_pmf = vec![0.0; prob_mass.len()];
norm_pmf
.iter_mut()
.zip(prob_mass.iter())
.for_each(|(np, pm)| *np = *pm / sum);
Ok(Categorical { norm_pmf, cdf, sf })
pub fn new(prob_mass: &[f64]) -> Result<Categorical, CategoricalError> {
if prob_mass.is_empty() {
return Err(CategoricalError::ProbMassEmpty);
}

let mut prob_sum = 0.0;
for &p in prob_mass {
if p.is_nan() || p < 0.0 {
return Err(CategoricalError::ProbMassHasInvalidElements);
}

prob_sum += p;
}

if prob_sum == 0.0 {
return Err(CategoricalError::ProbMassSumZero);
}

// extract un-normalized cdf
let cdf = prob_mass_to_cdf(prob_mass);
// extract un-normalized sf
let sf = cdf_to_sf(&cdf);
// extract normalized probability mass
let sum = cdf[cdf.len() - 1];
let mut norm_pmf = vec![0.0; prob_mass.len()];
norm_pmf
.iter_mut()
.zip(prob_mass.iter())
.for_each(|(np, pm)| *np = *pm / sum);
Ok(Categorical { norm_pmf, cdf, sf })
}

fn cdf_max(&self) -> f64 {
Expand Down Expand Up @@ -355,7 +396,7 @@ mod tests {
use crate::distribution::internal::*;
use crate::testing_boiler;

testing_boiler!(prob_mass: &[f64]; Categorical; StatsError);
testing_boiler!(prob_mass: &[f64]; Categorical; CategoricalError);

#[test]
fn test_create() {
Expand All @@ -364,8 +405,15 @@ mod tests {

#[test]
fn test_bad_create() {
create_err(&[-1.0, 1.0]);
create_err(&[0.0, 0.0]);
let invalid: &[(&[f64], CategoricalError)] = &[
(&[], CategoricalError::ProbMassEmpty),
(&[-1.0, 1.0], CategoricalError::ProbMassHasInvalidElements),
(&[0.0, 0.0, 0.0], CategoricalError::ProbMassSumZero),
];

for &(prob_mass, err) in invalid {
test_create_err(prob_mass, err);
}
}

#[test]
Expand Down
53 changes: 42 additions & 11 deletions src/distribution/cauchy.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::distribution::{Continuous, ContinuousCDF};
use crate::statistics::*;
use crate::{Result, StatsError};
use rand::Rng;
use std::f64;

Expand All @@ -23,6 +22,28 @@ pub struct Cauchy {
scale: f64,
}

/// Represents the errors that can occur when creating a [`Cauchy`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum CauchyError {
/// The location is NaN.
LocationInvalid,

/// The scale is NaN, zero or less than zero.
ScaleInvalid,
}

impl std::fmt::Display for CauchyError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
CauchyError::LocationInvalid => write!(f, "Location is NaN"),
CauchyError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero"),
}
}
}

impl std::error::Error for CauchyError {}

impl Cauchy {
/// Constructs a new cauchy distribution with the given
/// location and scale.
Expand All @@ -42,12 +63,16 @@ impl Cauchy {
/// result = Cauchy::new(0.0, -1.0);
/// assert!(result.is_err());
/// ```
pub fn new(location: f64, scale: f64) -> Result<Cauchy> {
if location.is_nan() || scale.is_nan() || scale <= 0.0 {
Err(StatsError::BadParams)
} else {
Ok(Cauchy { location, scale })
pub fn new(location: f64, scale: f64) -> Result<Cauchy, CauchyError> {
if location.is_nan() {
return Err(CauchyError::LocationInvalid);
}

if scale.is_nan() || scale <= 0.0 {
return Err(CauchyError::ScaleInvalid);
}

Ok(Cauchy { location, scale })
}

/// Returns the location of the cauchy distribution
Expand Down Expand Up @@ -256,7 +281,7 @@ mod tests {
use crate::distribution::internal::*;
use crate::testing_boiler;

testing_boiler!(location: f64, scale: f64; Cauchy; StatsError);
testing_boiler!(location: f64, scale: f64; Cauchy; CauchyError);

#[test]
fn test_create() {
Expand All @@ -270,10 +295,16 @@ mod tests {

#[test]
fn test_bad_create() {
create_err(f64::NAN, 1.0);
create_err(1.0, f64::NAN);
create_err(f64::NAN, f64::NAN);
create_err(1.0, 0.0);
let invalid = [
(f64::NAN, 1.0, CauchyError::LocationInvalid),
(1.0, f64::NAN, CauchyError::ScaleInvalid),
(f64::NAN, f64::NAN, CauchyError::LocationInvalid),
(1.0, 0.0, CauchyError::ScaleInvalid),
];

for (location, scale, err) in invalid {
test_create_err(location, scale, err);
}
}

#[test]
Expand Down
Loading

0 comments on commit 7349dff

Please sign in to comment.