Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace all instances of StatsError with custom error types or Option<T> #284

Merged
merged 8 commits into from
Sep 12, 2024
47 changes: 20 additions & 27 deletions src/distribution/multivariate_normal.rs
Original file line number Diff line number Diff line change
@@ -1,42 +1,37 @@
use crate::distribution::Continuous;
use crate::statistics::{Max, MeanN, Min, Mode, VarianceN};
use crate::StatsError;
use nalgebra::{Cholesky, Const, DMatrix, DVector, Dim, DimMin, Dyn, OMatrix, OVector};
use std::f64;
use std::f64::consts::{E, PI};

/// computes both the normalization and exponential argument in the normal distribution
/// # Errors
/// will error on dimension mismatch
/// Computes both the normalization and exponential argument in the normal
/// distribution, returning `None` on dimension mismatch.
pub(super) fn density_normalization_and_exponential<D>(
mu: &OVector<f64, D>,
cov: &OMatrix<f64, D, D>,
precision: &OMatrix<f64, D, D>,
x: &OVector<f64, D>,
) -> std::result::Result<(f64, f64), StatsError>
) -> Option<(f64, f64)>
where
D: DimMin<D, Output = D>,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>
+ nalgebra::allocator::Allocator<f64, D, D>
+ nalgebra::allocator::Allocator<(usize, usize), D>,
{
Ok((
Some((
density_distribution_pdf_const(mu, cov)?,
density_distribution_exponential(mu, precision, x)?,
))
}

/// computes the argument of the exponential term in the normal distribution
/// ```text
/// ```
/// # Errors
/// will error on dimension mismatch
/// Computes the argument of the exponential term in the normal distribution,
/// returning `None` on dimension mismatch.
#[inline]
pub(super) fn density_distribution_exponential<D>(
fn density_distribution_exponential<D>(
mu: &OVector<f64, D>,
precision: &OMatrix<f64, D, D>,
x: &OVector<f64, D>,
) -> std::result::Result<f64, StatsError>
) -> Option<f64>
where
D: Dim,
nalgebra::DefaultAllocator:
Expand All @@ -46,35 +41,33 @@
|| x.shape_generic().0 != mu.shape_generic().0
|| !precision.is_square()
{
return Err(StatsError::ContainersMustBeSameLength);
return None;
}

let dv = x - mu;
let exp_term: f64 = -0.5 * (precision * &dv).dot(&dv);
Ok(exp_term)
// TODO update to dimension mismatch error
Some(exp_term)
}

/// computes the argument of the normalization term in the normal distribution
/// # Errors
/// will error on dimension mismatch
/// Computes the argument of the normalization term in the normal distribution,
/// returning `None` on dimension mismatch.
#[inline]
pub(super) fn density_distribution_pdf_const<D>(
mu: &OVector<f64, D>,
cov: &OMatrix<f64, D, D>,
) -> std::result::Result<f64, StatsError>
fn density_distribution_pdf_const<D>(mu: &OVector<f64, D>, cov: &OMatrix<f64, D, D>) -> Option<f64>
where
D: DimMin<D, Output = D>,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>
+ nalgebra::allocator::Allocator<f64, D, D>
+ nalgebra::allocator::Allocator<(usize, usize), D>,
{
if cov.shape_generic().0 != mu.shape_generic().0 || !cov.is_square() {
return Err(StatsError::ContainersMustBeSameLength);
return None;

Check warning on line 63 in src/distribution/multivariate_normal.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_normal.rs#L63

Added line #L63 was not covered by tests
}
let cov_det = cov.determinant();
Ok(((2. * PI).powi(mu.nrows() as i32) * cov_det.abs())
.recip()
.sqrt())
Some(
((2. * PI).powi(mu.nrows() as i32) * cov_det.abs())
.recip()
.sqrt(),
)
}

/// Implements the [Multivariate Normal](https://en.wikipedia.org/wiki/Multivariate_normal_distribution)
Expand Down
11 changes: 7 additions & 4 deletions src/distribution/multivariate_students_t.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,13 @@ where
/// [Γ(ν+p)/2] / [Γ(ν/2) ((ν * π)^p det(Σ))^(1 / 2)] * [1 + 1/ν (x - μ)ᵀ inv(Σ) (x - μ)]^(-(ν+p)/2)
/// ```
///
/// where `ν` is the degrees of freedom, `μ` is the mean, `Γ`
/// is the Gamma function, `inv(Σ)`
/// is the precision matrix, `det(Σ)` is the determinant
/// of the scale matrix, and `k` is the dimension of the distribution.
/// where
/// - `ν` is the degrees of freedom,
/// - `μ` is the mean,
/// - `Γ` is the Gamma function,
/// - `inv(Σ)` is the precision matrix,
/// - `det(Σ)` is the determinant of the scale matrix, and
/// - `k` is the dimension of the distribution.
fn pdf(&self, x: &'a OVector<f64, D>) -> f64 {
if self.freedom.is_infinite() {
use super::multivariate_normal::density_normalization_and_exponential;
Expand Down
199 changes: 118 additions & 81 deletions src/function/beta.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,37 @@
//! Provides the [beta](https://en.wikipedia.org/wiki/Beta_function) and related
//! function

use crate::error::StatsError;
use crate::function::gamma;
use crate::prec;
use crate::Result;
use std::f64;

/// Represents the errors that can occur when computing the natural logarithm
/// of the beta function or the regularized lower incomplete beta function.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum BetaFuncError {
/// `a` is zero or less than zero.
ANotGreaterThanZero,

/// `b` is zero or less than zero.
BNotGreaterThanZero,

/// `x` is not in `[0, 1]`.
XOutOfRange,
}

impl std::fmt::Display for BetaFuncError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
BetaFuncError::ANotGreaterThanZero => write!(f, "a is zero or less than zero"),
BetaFuncError::BNotGreaterThanZero => write!(f, "b is zero or less than zero"),
BetaFuncError::XOutOfRange => write!(f, "x is not in [0, 1]"),

Check warning on line 28 in src/function/beta.rs

View check run for this annotation

Codecov / codecov/patch

src/function/beta.rs#L24-L28

Added lines #L24 - L28 were not covered by tests
}
}

Check warning on line 30 in src/function/beta.rs

View check run for this annotation

Codecov / codecov/patch

src/function/beta.rs#L30

Added line #L30 was not covered by tests
}

impl std::error::Error for BetaFuncError {}

/// Computes the natural logarithm
/// of the beta function
/// where `a` is the first beta parameter
Expand All @@ -29,11 +54,11 @@
/// # Errors
///
/// if `a <= 0.0` or `b <= 0.0`
pub fn checked_ln_beta(a: f64, b: f64) -> Result<f64> {
pub fn checked_ln_beta(a: f64, b: f64) -> Result<f64, BetaFuncError> {
if a <= 0.0 {
Err(StatsError::ArgMustBePositive("a"))
Err(BetaFuncError::ANotGreaterThanZero)
} else if b <= 0.0 {
Err(StatsError::ArgMustBePositive("b"))
Err(BetaFuncError::BNotGreaterThanZero)
} else {
Ok(gamma::ln_gamma(a) + gamma::ln_gamma(b) - gamma::ln_gamma(a + b))
}
Expand All @@ -59,7 +84,7 @@
/// # Errors
///
/// if `a <= 0.0` or `b <= 0.0`
pub fn checked_beta(a: f64, b: f64) -> Result<f64> {
pub fn checked_beta(a: f64, b: f64) -> Result<f64, BetaFuncError> {
checked_ln_beta(a, b).map(|x| x.exp())
}

Expand All @@ -83,7 +108,7 @@
/// # Errors
///
/// If `a <= 0.0`, `b <= 0.0`, `x < 0.0`, or `x > 1.0`
pub fn checked_beta_inc(a: f64, b: f64, x: f64) -> Result<f64> {
pub fn checked_beta_inc(a: f64, b: f64, x: f64) -> Result<f64, BetaFuncError> {
checked_beta_reg(a, b, x).and_then(|x| checked_beta(a, b).map(|y| x * y))
}

Expand All @@ -109,96 +134,100 @@
/// # Errors
///
/// if `a <= 0.0`, `b <= 0.0`, `x < 0.0`, or `x > 1.0`
pub fn checked_beta_reg(a: f64, b: f64, x: f64) -> Result<f64> {
pub fn checked_beta_reg(a: f64, b: f64, x: f64) -> Result<f64, BetaFuncError> {
if a <= 0.0 {
Err(StatsError::ArgMustBePositive("a"))
} else if b <= 0.0 {
Err(StatsError::ArgMustBePositive("b"))
} else if !(0.0..=1.0).contains(&x) {
Err(StatsError::ArgIntervalIncl("x", 0.0, 1.0))
return Err(BetaFuncError::ANotGreaterThanZero);
}

if b <= 0.0 {
return Err(BetaFuncError::BNotGreaterThanZero);
}

if !(0.0..=1.0).contains(&x) {
return Err(BetaFuncError::XOutOfRange);
}

let bt = if x == 0.0 || ulps_eq!(x, 1.0) {
0.0
} else {
let bt = if x == 0.0 || ulps_eq!(x, 1.0) {
0.0
} else {
(gamma::ln_gamma(a + b) - gamma::ln_gamma(a) - gamma::ln_gamma(b)
+ a * x.ln()
+ b * (1.0 - x).ln())
.exp()
};
let symm_transform = x >= (a + 1.0) / (a + b + 2.0);
let eps = prec::F64_PREC;
let fpmin = f64::MIN_POSITIVE / eps;

let mut a = a;
let mut b = b;
let mut x = x;
if symm_transform {
let swap = a;
x = 1.0 - x;
a = b;
b = swap;
}
(gamma::ln_gamma(a + b) - gamma::ln_gamma(a) - gamma::ln_gamma(b)
+ a * x.ln()
+ b * (1.0 - x).ln())
.exp()
};
let symm_transform = x >= (a + 1.0) / (a + b + 2.0);
let eps = prec::F64_PREC;
let fpmin = f64::MIN_POSITIVE / eps;

let mut a = a;
let mut b = b;
let mut x = x;
if symm_transform {
let swap = a;
x = 1.0 - x;
a = b;
b = swap;
}

let qab = a + b;
let qap = a + 1.0;
let qam = a - 1.0;
let mut c = 1.0;
let mut d = 1.0 - qab * x / qap;

if d.abs() < fpmin {
d = fpmin;

Check warning on line 179 in src/function/beta.rs

View check run for this annotation

Codecov / codecov/patch

src/function/beta.rs#L179

Added line #L179 was not covered by tests
}
d = 1.0 / d;
let mut h = d;

let qab = a + b;
let qap = a + 1.0;
let qam = a - 1.0;
let mut c = 1.0;
let mut d = 1.0 - qab * x / qap;
for m in 1..141 {
let m = f64::from(m);
let m2 = m * 2.0;
let mut aa = m * (b - m) * x / ((qam + m2) * (a + m2));
d = 1.0 + aa * d;

if d.abs() < fpmin {
d = fpmin;
}
d = 1.0 / d;
let mut h = d;

for m in 1..141 {
let m = f64::from(m);
let m2 = m * 2.0;
let mut aa = m * (b - m) * x / ((qam + m2) * (a + m2));
d = 1.0 + aa * d;

if d.abs() < fpmin {
d = fpmin;
}

c = 1.0 + aa / c;
if c.abs() < fpmin {
c = fpmin;
}
c = 1.0 + aa / c;
if c.abs() < fpmin {
c = fpmin;

Check warning on line 196 in src/function/beta.rs

View check run for this annotation

Codecov / codecov/patch

src/function/beta.rs#L196

Added line #L196 was not covered by tests
}

d = 1.0 / d;
h = h * d * c;
aa = -(a + m) * (qab + m) * x / ((a + m2) * (qap + m2));
d = 1.0 + aa * d;
d = 1.0 / d;
h = h * d * c;
aa = -(a + m) * (qab + m) * x / ((a + m2) * (qap + m2));
d = 1.0 + aa * d;

if d.abs() < fpmin {
d = fpmin;
}
if d.abs() < fpmin {
d = fpmin;

Check warning on line 205 in src/function/beta.rs

View check run for this annotation

Codecov / codecov/patch

src/function/beta.rs#L205

Added line #L205 was not covered by tests
}

c = 1.0 + aa / c;
c = 1.0 + aa / c;

if c.abs() < fpmin {
c = fpmin;
}
if c.abs() < fpmin {
c = fpmin;

Check warning on line 211 in src/function/beta.rs

View check run for this annotation

Codecov / codecov/patch

src/function/beta.rs#L211

Added line #L211 was not covered by tests
}

d = 1.0 / d;
let del = d * c;
h *= del;
d = 1.0 / d;
let del = d * c;
h *= del;

if (del - 1.0).abs() <= eps {
return if symm_transform {
Ok(1.0 - bt * h / a)
} else {
Ok(bt * h / a)
};
}
if (del - 1.0).abs() <= eps {
return if symm_transform {
Ok(1.0 - bt * h / a)
} else {
Ok(bt * h / a)
};
}
}

if symm_transform {
Ok(1.0 - bt * h / a)
} else {
Ok(bt * h / a)
}
if symm_transform {
Ok(1.0 - bt * h / a)

Check warning on line 228 in src/function/beta.rs

View check run for this annotation

Codecov / codecov/patch

src/function/beta.rs#L227-L228

Added lines #L227 - L228 were not covered by tests
} else {
Ok(bt * h / a)

Check warning on line 230 in src/function/beta.rs

View check run for this annotation

Codecov / codecov/patch

src/function/beta.rs#L230

Added line #L230 was not covered by tests
}
}

Expand Down Expand Up @@ -396,6 +425,8 @@
#[rustfmt::skip]
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_ln_beta() {
assert_almost_eq!(super::ln_beta(0.5, 0.5), 1.144729885849400174144, 1e-15);
Expand Down Expand Up @@ -597,4 +628,10 @@
fn test_checked_beta_reg_x_gt_1() {
assert!(super::checked_beta_reg(1.0, 1.0, 2.0).is_err());
}

#[test]
fn test_error_is_sync_send() {
fn assert_sync_send<T: Sync + Send>() {}
assert_sync_send::<BetaFuncError>();
}
}
Loading
Loading