diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index 297fc9f4..efc3c5f3 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -31,7 +31,7 @@ pub use self::multinomial::Multinomial; #[cfg(feature = "nalgebra")] pub use self::multivariate_normal::{MultivariateNormal, MultivariateNormalError}; #[cfg(feature = "nalgebra")] -pub use self::multivariate_students_t::MultivariateStudent; +pub use self::multivariate_students_t::{MultivariateStudent, MultivariateStudentError}; pub use self::negative_binomial::{NegativeBinomial, NegativeBinomialError}; pub use self::normal::{Normal, NormalError}; pub use self::pareto::{Pareto, ParetoError}; diff --git a/src/distribution/multivariate_students_t.rs b/src/distribution/multivariate_students_t.rs index 3758d328..3a69d02d 100644 --- a/src/distribution/multivariate_students_t.rs +++ b/src/distribution/multivariate_students_t.rs @@ -2,7 +2,6 @@ use crate::distribution::Continuous; use crate::distribution::{ChiSquared, Normal}; use crate::function::gamma; use crate::statistics::{Max, MeanN, Min, Mode, VarianceN}; -use crate::{Result, StatsError}; use nalgebra::{Cholesky, Const, DMatrix, Dim, DimMin, Dyn, OMatrix, OVector}; use rand::Rng; use std::f64::consts::PI; @@ -39,6 +38,51 @@ where ln_pdf_const: f64, } +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +pub enum MultivariateStudentError { + /// The scale matrix is asymmetric or contains a NaN. + ScaleInvalid, + + /// The location vector contains a NaN. + LocationInvalid, + + /// The degrees of freedom are NaN, zero, or less than zero. + FreedomInvalid, + + /// The amount of rows in the location vector is not equal to the amount + /// of rows in the scale matrix. + DimensionMismatch, + + /// After all other validation, computing the Cholesky decomposition failed. + /// This means that the scale matrix is not definite-positive. + CholeskyFailed, +} + +impl std::fmt::Display for MultivariateStudentError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + MultivariateStudentError::ScaleInvalid => { + write!(f, "Scale matrix is asymmetric or contains a NaN") + } + MultivariateStudentError::LocationInvalid => { + write!(f, "Location vector contains a NaN") + } + MultivariateStudentError::FreedomInvalid => { + write!(f, "Degrees of freedom are NaN, zero, or less than zero") + } + MultivariateStudentError::DimensionMismatch => write!( + f, + "Location vector and scale matrix do not have the same number of rows" + ), + MultivariateStudentError::CholeskyFailed => { + write!(f, "Computing the Cholesky decomposition failed") + } + } + } +} + +impl std::error::Error for MultivariateStudentError {} + impl MultivariateStudent { /// Constructs a new multivariate students t distribution with a location of `location`, /// scale matrix `scale` and `freedom` degrees of freedom. @@ -47,7 +91,11 @@ impl MultivariateStudent { /// /// Returns `StatsError::BadParams` if the scale matrix is not symmetric-positive /// definite and `StatsError::ArgMustBePositive` if freedom is non-positive. - pub fn new(location: Vec, scale: Vec, freedom: f64) -> Result { + pub fn new( + location: Vec, + scale: Vec, + freedom: f64, + ) -> Result { let dim = location.len(); Self::new_from_nalgebra(location.into(), DMatrix::from_vec(dim, dim, scale), freedom) } @@ -69,26 +117,26 @@ where location: OVector, scale: OMatrix, freedom: f64, - ) -> Result { + ) -> Result { let dim = location.len(); - // Check that the provided scale matrix is symmetric - if scale.lower_triangle() != scale.upper_triangle().transpose() - // Check that mean and covariance do not contain NaN - || location.iter().any(|f| f.is_nan()) + if location.iter().any(|f| f.is_nan()) { + return Err(MultivariateStudentError::LocationInvalid); + } + + if !scale.is_square() + || scale.lower_triangle() != scale.upper_triangle().transpose() || scale.iter().any(|f| f.is_nan()) - // Check that the dimensions match - || location.nrows() != scale.nrows() || scale.nrows() != scale.ncols() - // Check that the degrees of freedom is not NaN - || freedom.is_nan() { - return Err(StatsError::BadParams); + return Err(MultivariateStudentError::ScaleInvalid); } - // Check that degrees of freedom is positive - if freedom <= 0. { - return Err(StatsError::ArgMustBePositive( - "Degrees of freedom must be positive", - )); + + if freedom.is_nan() || freedom <= 0.0 { + return Err(MultivariateStudentError::FreedomInvalid); + } + + if location.nrows() != scale.nrows() { + return Err(MultivariateStudentError::DimensionMismatch); } let scale_det = scale.determinant(); @@ -98,7 +146,7 @@ where - 0.5 * scale_det.ln(); match Cholesky::new(scale.clone()) { - None => Err(StatsError::BadParams), // Scale matrix is not positive definite + None => Err(MultivariateStudentError::CholeskyFailed), Some(cholesky_decomp) => { let precision = cholesky_decomp.inverse(); Ok(MultivariateStudent {