Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement concrete Error types for each distribution's new function #265

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,8 @@ anyhow = "1.0"
version = "0.32"
default-features = false
features = ["macros"]

[lints.rust.unexpected_cfgs]
level = "warn"
# Set by cargo-llvm-cov when running on nightly
check-cfg = ['cfg(coverage_nightly)']
10 changes: 4 additions & 6 deletions src/distribution/bernoulli.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::distribution::{Binomial, Discrete, DiscreteCDF};
use crate::distribution::{Binomial, BinomialError, Discrete, DiscreteCDF};
use crate::statistics::*;
use crate::Result;
use rand::Rng;

/// Implements the
Expand Down Expand Up @@ -45,7 +44,7 @@ impl Bernoulli {
/// result = Bernoulli::new(-0.5);
/// assert!(result.is_err());
/// ```
pub fn new(p: f64) -> Result<Bernoulli> {
pub fn new(p: f64) -> Result<Bernoulli, BinomialError> {
Binomial::new(p, 1).map(|b| Bernoulli { b })
}

Expand Down Expand Up @@ -265,11 +264,10 @@ impl Discrete<u64, f64> for Bernoulli {
#[rustfmt::skip]
#[cfg(test)]
mod testing {
use crate::distribution::DiscreteCDF;
use super::*;
use crate::testing_boiler;
use super::Bernoulli;

testing_boiler!(p: f64; Bernoulli);
testing_boiler!(p: f64; Bernoulli; BinomialError);

#[test]
fn test_create() {
Expand Down
52 changes: 41 additions & 11 deletions src/distribution/beta.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::distribution::{Continuous, ContinuousCDF};
use crate::function::{beta, gamma};
use crate::statistics::*;
use crate::{Result, StatsError};
use rand::Rng;

/// Implements the [Beta](https://en.wikipedia.org/wiki/Beta_distribution)
Expand All @@ -24,6 +23,33 @@ pub struct Beta {
shape_b: f64,
}

/// Represents the errors that can occur when creating a [`Beta`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum BetaError {
/// Shape A is NaN, zero or negative.
ShapeAInvalid,

/// Shape B is NaN, zero or negative.
ShapeBInvalid,

/// Shape A and Shape B are infinite.
BothShapesInfinite,
}

impl std::fmt::Display for BetaError {
#[cfg_attr(coverage_nightly, coverage(off))]
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
BetaError::ShapeAInvalid => write!(f, "Shape A is NaN, zero or negative"),
BetaError::ShapeBInvalid => write!(f, "Shape B is NaN, zero or negative"),
BetaError::BothShapesInfinite => write!(f, "Shape A and shape B are infinite"),
}
}
}

impl std::error::Error for BetaError {}

impl Beta {
/// Constructs a new beta distribution with shapeA (α) of `shape_a`
/// and shapeB (β) of `shape_b`
Expand All @@ -44,15 +70,19 @@ impl Beta {
/// result = Beta::new(0.0, 0.0);
/// assert!(result.is_err());
/// ```
pub fn new(shape_a: f64, shape_b: f64) -> Result<Beta> {
if shape_a.is_nan()
|| shape_b.is_nan()
|| shape_a.is_infinite() && shape_b.is_infinite()
|| shape_a <= 0.0
|| shape_b <= 0.0
{
return Err(StatsError::BadParams);
};
pub fn new(shape_a: f64, shape_b: f64) -> Result<Beta, BetaError> {
if shape_a.is_nan() || shape_a <= 0.0 {
return Err(BetaError::ShapeAInvalid);
}

if shape_b.is_nan() || shape_b <= 0.0 {
return Err(BetaError::ShapeBInvalid);
}

if shape_a.is_infinite() && shape_b.is_infinite() {
return Err(BetaError::BothShapesInfinite);
}

Ok(Beta { shape_a, shape_b })
}

Expand Down Expand Up @@ -433,7 +463,7 @@ mod tests {
use super::super::internal::*;
use crate::testing_boiler;

testing_boiler!(a: f64, b: f64; Beta);
testing_boiler!(a: f64, b: f64; Beta; BetaError);

#[test]
fn test_create() {
Expand Down
29 changes: 23 additions & 6 deletions src/distribution/binomial.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::distribution::{Discrete, DiscreteCDF};
use crate::function::{beta, factorial};
use crate::statistics::*;
use crate::{Result, StatsError};
use rand::Rng;
use std::f64;

Expand All @@ -26,6 +25,25 @@ pub struct Binomial {
n: u64,
}

/// Represents the errors that can occur when creating a [`Binomial`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum BinomialError {
/// The probability is NaN or not in `[0, 1]`.
ProbabilityInvalid,
}

impl std::fmt::Display for BinomialError {
#[cfg_attr(coverage_nightly, coverage(off))]
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
BinomialError::ProbabilityInvalid => write!(f, "Probability is NaN or not in [0, 1]"),
}
}
}

impl std::error::Error for BinomialError {}

impl Binomial {
/// Constructs a new binomial distribution
/// with a given `p` probability of success of `n`
Expand All @@ -47,9 +65,9 @@ impl Binomial {
/// result = Binomial::new(-0.5, 5);
/// assert!(result.is_err());
/// ```
pub fn new(p: f64, n: u64) -> Result<Binomial> {
pub fn new(p: f64, n: u64) -> Result<Binomial, BinomialError> {
if p.is_nan() || !(0.0..=1.0).contains(&p) {
Err(StatsError::BadParams)
Err(BinomialError::ProbabilityInvalid)
} else {
Ok(Binomial { p, n })
}
Expand Down Expand Up @@ -328,12 +346,11 @@ impl Discrete<u64, f64> for Binomial {
#[rustfmt::skip]
#[cfg(test)]
mod tests {
use crate::statistics::*;
use crate::distribution::{DiscreteCDF, Discrete, Binomial};
use super::*;
use crate::distribution::internal::*;
use crate::testing_boiler;

testing_boiler!(p: f64, n: u64; Binomial);
testing_boiler!(p: f64, n: u64; Binomial; BinomialError);

#[test]
fn test_create() {
Expand Down
92 changes: 70 additions & 22 deletions src/distribution/categorical.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::distribution::{Discrete, DiscreteCDF};
use crate::statistics::*;
use crate::{Result, StatsError};
use rand::Rng;
use std::f64;

Expand All @@ -27,6 +26,36 @@ pub struct Categorical {
sf: Vec<f64>,
}

/// Represents the errors that can occur when creating a [`Categorical`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum CategoricalError {
/// The probability mass is empty.
ProbMassEmpty,

/// The probabilities sums up to zero.
ProbMassSumZero,

/// The probability mass contains at least one element which is NaN or less than zero.
ProbMassHasInvalidElements,
}

impl std::fmt::Display for CategoricalError {
#[cfg_attr(coverage_nightly, coverage(off))]
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
CategoricalError::ProbMassEmpty => write!(f, "Probability mass is empty"),
CategoricalError::ProbMassSumZero => write!(f, "Probabilities sum up to zero"),
CategoricalError::ProbMassHasInvalidElements => write!(
f,
"Probability mass contains at least one element which is NaN or less than zero"
),
}
}
}

impl std::error::Error for CategoricalError {}

impl Categorical {
/// Constructs a new categorical distribution
/// with the probabilities masses defined by `prob_mass`
Expand All @@ -52,23 +81,36 @@ impl Categorical {
/// result = Categorical::new(&[0.0, -1.0, 2.0]);
/// assert!(result.is_err());
/// ```
pub fn new(prob_mass: &[f64]) -> Result<Categorical> {
if !super::internal::is_valid_multinomial(prob_mass, true) {
Err(StatsError::BadParams)
} else {
// extract un-normalized cdf
let cdf = prob_mass_to_cdf(prob_mass);
// extract un-normalized sf
let sf = cdf_to_sf(&cdf);
// extract normalized probability mass
let sum = cdf[cdf.len() - 1];
let mut norm_pmf = vec![0.0; prob_mass.len()];
norm_pmf
.iter_mut()
.zip(prob_mass.iter())
.for_each(|(np, pm)| *np = *pm / sum);
Ok(Categorical { norm_pmf, cdf, sf })
pub fn new(prob_mass: &[f64]) -> Result<Categorical, CategoricalError> {
if prob_mass.is_empty() {
return Err(CategoricalError::ProbMassEmpty);
}

let mut prob_sum = 0.0;
for &p in prob_mass {
if p.is_nan() || p < 0.0 {
return Err(CategoricalError::ProbMassHasInvalidElements);
}

prob_sum += p;
}

if prob_sum == 0.0 {
return Err(CategoricalError::ProbMassSumZero);
}

// extract un-normalized cdf
let cdf = prob_mass_to_cdf(prob_mass);
// extract un-normalized sf
let sf = cdf_to_sf(&cdf);
// extract normalized probability mass
let sum = cdf[cdf.len() - 1];
let mut norm_pmf = vec![0.0; prob_mass.len()];
norm_pmf
.iter_mut()
.zip(prob_mass.iter())
.for_each(|(np, pm)| *np = *pm / sum);
Ok(Categorical { norm_pmf, cdf, sf })
}

fn cdf_max(&self) -> f64 {
Expand Down Expand Up @@ -351,12 +393,11 @@ fn test_binary_index() {
#[rustfmt::skip]
#[cfg(test)]
mod tests {
use crate::statistics::*;
use crate::distribution::{Categorical, Discrete, DiscreteCDF};
use super::*;
use crate::distribution::internal::*;
use crate::testing_boiler;

testing_boiler!(prob_mass: &[f64]; Categorical);
testing_boiler!(prob_mass: &[f64]; Categorical; CategoricalError);

#[test]
fn test_create() {
Expand All @@ -365,8 +406,15 @@ mod tests {

#[test]
fn test_bad_create() {
create_err(&[-1.0, 1.0]);
create_err(&[0.0, 0.0]);
let invalid: &[(&[f64], CategoricalError)] = &[
(&[], CategoricalError::ProbMassEmpty),
(&[-1.0, 1.0], CategoricalError::ProbMassHasInvalidElements),
(&[0.0, 0.0, 0.0], CategoricalError::ProbMassSumZero),
];

for &(prob_mass, err) in invalid {
test_create_err(prob_mass, err);
}
}

#[test]
Expand Down
Loading