Skip to content

Commit

Permalink
Add MultivariateStudentError
Browse files Browse the repository at this point in the history
  • Loading branch information
FreezyLemon committed Sep 5, 2024
1 parent f32b5a1 commit a5e7b08
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/distribution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub use self::multinomial::Multinomial;
#[cfg(feature = "nalgebra")]
pub use self::multivariate_normal::{MultivariateNormal, MultivariateNormalError};
#[cfg(feature = "nalgebra")]
pub use self::multivariate_students_t::MultivariateStudent;
pub use self::multivariate_students_t::{MultivariateStudent, MultivariateStudentError};
pub use self::negative_binomial::{NegativeBinomial, NegativeBinomialError};
pub use self::normal::{Normal, NormalError};
pub use self::pareto::{Pareto, ParetoError};
Expand Down
84 changes: 66 additions & 18 deletions src/distribution/multivariate_students_t.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::distribution::Continuous;
use crate::distribution::{ChiSquared, Normal};
use crate::function::gamma;
use crate::statistics::{Max, MeanN, Min, Mode, VarianceN};
use crate::{Result, StatsError};
use nalgebra::{Cholesky, Const, DMatrix, Dim, DimMin, Dyn, OMatrix, OVector};
use rand::Rng;
use std::f64::consts::PI;
Expand Down Expand Up @@ -39,6 +38,51 @@ where
ln_pdf_const: f64,
}

#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub enum MultivariateStudentError {
/// The scale matrix is asymmetric or contains a NaN.
ScaleInvalid,

/// The location vector contains a NaN.
LocationInvalid,

/// The degrees of freedom are NaN, zero, or less than zero.
FreedomInvalid,

/// The amount of rows in the location vector is not equal to the amount
/// of rows in the scale matrix.
DimensionMismatch,

/// After all other validation, computing the Cholesky decomposition failed.
/// This means that the scale matrix is not definite-positive.
CholeskyFailed,
}

impl std::fmt::Display for MultivariateStudentError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
MultivariateStudentError::ScaleInvalid => {
write!(f, "Scale matrix is asymmetric or contains a NaN")
}
MultivariateStudentError::LocationInvalid => {
write!(f, "Location vector contains a NaN")
}
MultivariateStudentError::FreedomInvalid => {
write!(f, "Degrees of freedom are NaN, zero, or less than zero")
}
MultivariateStudentError::DimensionMismatch => write!(
f,
"Location vector and scale matrix do not have the same number of rows"
),
MultivariateStudentError::CholeskyFailed => {
write!(f, "Computing the Cholesky decomposition failed")
}
}
}
}

impl std::error::Error for MultivariateStudentError {}

impl MultivariateStudent<Dyn> {
/// Constructs a new multivariate students t distribution with a location of `location`,
/// scale matrix `scale` and `freedom` degrees of freedom.
Expand All @@ -47,7 +91,11 @@ impl MultivariateStudent<Dyn> {
///
/// Returns `StatsError::BadParams` if the scale matrix is not symmetric-positive
/// definite and `StatsError::ArgMustBePositive` if freedom is non-positive.
pub fn new(location: Vec<f64>, scale: Vec<f64>, freedom: f64) -> Result<Self> {
pub fn new(
location: Vec<f64>,
scale: Vec<f64>,
freedom: f64,
) -> Result<Self, MultivariateStudentError> {
let dim = location.len();
Self::new_from_nalgebra(location.into(), DMatrix::from_vec(dim, dim, scale), freedom)
}
Expand All @@ -69,26 +117,26 @@ where
location: OVector<f64, D>,
scale: OMatrix<f64, D, D>,
freedom: f64,
) -> Result<Self> {
) -> Result<Self, MultivariateStudentError> {
let dim = location.len();

// Check that the provided scale matrix is symmetric
if scale.lower_triangle() != scale.upper_triangle().transpose()
// Check that mean and covariance do not contain NaN
|| location.iter().any(|f| f.is_nan())
if location.iter().any(|f| f.is_nan()) {
return Err(MultivariateStudentError::LocationInvalid);
}

if !scale.is_square()
|| scale.lower_triangle() != scale.upper_triangle().transpose()
|| scale.iter().any(|f| f.is_nan())
// Check that the dimensions match
|| location.nrows() != scale.nrows() || scale.nrows() != scale.ncols()
// Check that the degrees of freedom is not NaN
|| freedom.is_nan()
{
return Err(StatsError::BadParams);
return Err(MultivariateStudentError::ScaleInvalid);
}
// Check that degrees of freedom is positive
if freedom <= 0. {
return Err(StatsError::ArgMustBePositive(
"Degrees of freedom must be positive",
));

if freedom.is_nan() || freedom <= 0.0 {
return Err(MultivariateStudentError::FreedomInvalid);
}

if location.nrows() != scale.nrows() {
return Err(MultivariateStudentError::DimensionMismatch);
}

let scale_det = scale.determinant();
Expand All @@ -98,7 +146,7 @@ where
- 0.5 * scale_det.ln();

match Cholesky::new(scale.clone()) {
None => Err(StatsError::BadParams), // Scale matrix is not positive definite
None => Err(MultivariateStudentError::CholeskyFailed),
Some(cholesky_decomp) => {
let precision = cholesky_decomp.inverse();
Ok(MultivariateStudent {
Expand Down

0 comments on commit a5e7b08

Please sign in to comment.