Skip to content

Commit

Permalink
Add MultinomialError
Browse files Browse the repository at this point in the history
Move check from separate fn into `Multinomial::new`
  • Loading branch information
FreezyLemon committed Sep 5, 2024
1 parent a5e7b08 commit 1e62fa7
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 66 deletions.
52 changes: 0 additions & 52 deletions src/distribution/internal.rs
Original file line number Diff line number Diff line change
@@ -1,42 +1,5 @@
use num_traits::Num;

#[cfg(feature = "nalgebra")]
use nalgebra::{Dim, OVector};

#[cfg(feature = "nalgebra")]
pub fn check_multinomial<D>(arr: &OVector<f64, D>, accept_zeroes: bool) -> crate::Result<()>
where
D: Dim,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>,
{
use crate::StatsError;

if arr.len() < 2 {
return Err(StatsError::BadParams);
}
let mut sum = 0.0;
for &x in arr.iter() {
#[allow(clippy::if_same_then_else)]
if x.is_nan() {
return Err(StatsError::BadParams);
} else if x.is_infinite() {
return Err(StatsError::BadParams);
} else if x < 0.0 {
return Err(StatsError::BadParams);
} else if x == 0.0 && !accept_zeroes {
return Err(StatsError::BadParams);
} else {
sum += x;
}
}

if sum != 0.0 {
Ok(())
} else {
Err(StatsError::BadParams)
}
}

/// Implements univariate function bisection searching for criteria
/// ```text
/// smallest k such that f(k) >= z
Expand Down Expand Up @@ -485,21 +448,6 @@ pub mod test {
check_sum_pmf_is_cdf(dist, x_max);
}

#[cfg(feature = "nalgebra")]
#[test]
fn test_is_valid_multinomial() {
use std::f64;

let invalid = [1.0, f64::NAN, 3.0];
assert!(check_multinomial(&invalid.to_vec().into(), true).is_err());
let invalid2 = [-2.0, 5.0, 1.0, 6.2];
assert!(check_multinomial(&invalid2.to_vec().into(), true).is_err());
let invalid3 = [0.0, 0.0, 0.0];
assert!(check_multinomial(&invalid3.to_vec().into(), true).is_err());
let valid = [5.2, 0.0, 1e-15, 1000000.12];
assert!(check_multinomial(&valid.to_vec().into(), true).is_ok());
}

#[test]
fn test_integer_bisection() {
fn search(z: usize, data: &[usize]) -> Option<usize> {
Expand Down
2 changes: 1 addition & 1 deletion src/distribution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub use self::inverse_gamma::{InverseGamma, InverseGammaError};
pub use self::laplace::{Laplace, LaplaceError};
pub use self::log_normal::{LogNormal, LogNormalError};
#[cfg(feature = "nalgebra")]
pub use self::multinomial::Multinomial;
pub use self::multinomial::{Multinomial, MultinomialError};
#[cfg(feature = "nalgebra")]
pub use self::multivariate_normal::{MultivariateNormal, MultivariateNormalError};
#[cfg(feature = "nalgebra")]
Expand Down
69 changes: 56 additions & 13 deletions src/distribution/multinomial.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::distribution::Discrete;
use crate::function::factorial;
use crate::statistics::*;
use crate::Result;
use nalgebra::{Const, DVector, Dim, Dyn, OMatrix, OVector};
use rand::Rng;

Expand Down Expand Up @@ -33,6 +32,33 @@ where
n: u64,
}

#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub enum MultinomialError {
/// Fewer than two probabilities.
NotEnoughProbabilities,

/// The sum of all probabilities is zero.
ProbabilitySumZero,

/// At least one probability is NaN, infinite, or less than zero.
ProbabilityInvalid,
}

impl std::fmt::Display for MultinomialError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
MultinomialError::NotEnoughProbabilities => write!(f, "Fewer than two probabilities"),
MultinomialError::ProbabilitySumZero => write!(f, "The probabilities sum up to zero"),
MultinomialError::ProbabilityInvalid => write!(
f,
"The probabilities contain at least one NaN, infinity, or value less than zero"
),

Check warning on line 55 in src/distribution/multinomial.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multinomial.rs#L48-L55

Added lines #L48 - L55 were not covered by tests
}
}

Check warning on line 57 in src/distribution/multinomial.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multinomial.rs#L57

Added line #L57 was not covered by tests
}

impl std::error::Error for MultinomialError {}

impl Multinomial<Dyn> {
/// Constructs a new multinomial distribution with probabilities `p`
/// and `n` number of trials.
Expand All @@ -57,7 +83,7 @@ impl Multinomial<Dyn> {
/// result = Multinomial::new(vec![0.0, -1.0, 2.0], 3);
/// assert!(result.is_err());
/// ```
pub fn new(p: Vec<f64>, n: u64) -> Result<Self> {
pub fn new(p: Vec<f64>, n: u64) -> Result<Self, MultinomialError> {
Self::new_from_nalgebra(p.into(), n)
}
}
Expand All @@ -67,14 +93,26 @@ where
D: Dim,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>,
{
pub fn new_from_nalgebra(mut p: OVector<f64, D>, n: u64) -> Result<Self> {
match super::internal::check_multinomial(&p, true) {
Err(e) => Err(e),
Ok(_) => {
p.unscale_mut(p.lp_norm(1));
Ok(Self { p, n })
pub fn new_from_nalgebra(mut p: OVector<f64, D>, n: u64) -> Result<Self, MultinomialError> {
if p.len() < 2 {
return Err(MultinomialError::NotEnoughProbabilities);
}

let mut sum = 0.0;
for &val in &p {
if val.is_nan() || val < 0.0 {
return Err(MultinomialError::ProbabilityInvalid);
}

sum += val;
}

if sum == 0.0 {
return Err(MultinomialError::ProbabilitySumZero);
}

p.unscale_mut(p.lp_norm(1));
Ok(Self { p, n })
}

/// Returns the probabilities of the multinomial
Expand Down Expand Up @@ -295,7 +333,7 @@ where
#[cfg(test)]
mod tests {
use crate::{
distribution::{Discrete, Multinomial},
distribution::{Discrete, Multinomial, MultinomialError},
statistics::{MeanN, VarianceN},
};
use nalgebra::{dmatrix, dvector, vector, DimMin, Dyn, OVector};
Expand All @@ -311,7 +349,7 @@ mod tests {
mvn.unwrap()
}

fn bad_create_case<D>(p: OVector<f64, D>, n: u64) -> crate::StatsError
fn bad_create_case<D>(p: OVector<f64, D>, n: u64) -> MultinomialError
where
D: DimMin<D, Output = D>,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>,
Expand Down Expand Up @@ -344,18 +382,23 @@ mod tests {

#[test]
fn test_bad_create() {
assert_eq!(
bad_create_case(vector![0.5], 4),
MultinomialError::NotEnoughProbabilities,
);

assert_eq!(
bad_create_case(vector![-1.0, 2.0], 4),
crate::StatsError::BadParams
MultinomialError::ProbabilityInvalid,
);

assert_eq!(
bad_create_case(vector![0.0, 0.0], 4),
crate::StatsError::BadParams
MultinomialError::ProbabilitySumZero,
);
assert_eq!(
bad_create_case(vector![1.0, f64::NAN], 4),
crate::StatsError::BadParams
MultinomialError::ProbabilityInvalid,
);
}

Expand Down

0 comments on commit 1e62fa7

Please sign in to comment.