diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index 6aedede7..39181eee 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -26,7 +26,7 @@ pub use self::inverse_gamma::{InverseGamma, InverseGammaError}; pub use self::laplace::{Laplace, LaplaceError}; pub use self::log_normal::{LogNormal, LogNormalError}; pub use self::multinomial::Multinomial; -pub use self::multivariate_normal::MultivariateNormal; +pub use self::multivariate_normal::{MultivariateNormal, MultivariateNormalError}; pub use self::negative_binomial::{NegativeBinomial, NegativeBinomialError}; pub use self::normal::{Normal, NormalError}; pub use self::pareto::{Pareto, ParetoError}; diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 9949d676..9674feef 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -1,7 +1,6 @@ use crate::distribution::Continuous; use crate::distribution::Normal; use crate::statistics::{Max, MeanN, Min, Mode, VarianceN}; -use crate::{Result, StatsError}; use nalgebra::{Cholesky, Const, DMatrix, DVector, Dim, DimMin, Dyn, OMatrix, OVector}; use rand::Rng; use std::f64; @@ -36,15 +35,51 @@ where pdf_const: f64, } +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +pub enum MultivariateNormalError { + /// The covariance matrix is asymmetric or contains a NaN. + CovInvalid, + + /// The mean vector contains a NaN. + MeanInvalid, + + /// The amount of rows in the vector of means is not equal to the amount + /// of rows in the covariance matrix. + DimensionMismatch, + + /// After all other validation, computing the Cholesky decomposition failed. + CholeskyFailed, +} + +impl std::fmt::Display for MultivariateNormalError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + MultivariateNormalError::CovInvalid => { + write!(f, "Covariance matrix is asymmetric or contains a NaN") + } + MultivariateNormalError::MeanInvalid => write!(f, "Mean vector contains a NaN"), + MultivariateNormalError::DimensionMismatch => write!( + f, + "Mean vector and covariance matrix do not have the same number of rows" + ), + MultivariateNormalError::CholeskyFailed => { + write!(f, "Computing the Cholesky decomposition failed") + } + } + } +} + +impl std::error::Error for MultivariateNormalError {} + impl MultivariateNormal { - /// Constructs a new multivariate normal distribution with a mean of `mean` + /// Constructs a new multivariate normal distribution with a mean of `mean` /// and covariance matrix `cov` /// /// # Errors /// /// Returns an error if the given covariance matrix is not /// symmetric or positive-definite - pub fn new(mean: Vec, cov: Vec) -> Result { + pub fn new(mean: Vec, cov: Vec) -> Result { let mean = DVector::from_vec(mean); let cov = DMatrix::from_vec(mean.len(), mean.len(), cov); MultivariateNormal::new_from_nalgebra(mean, cov) @@ -66,17 +101,25 @@ where /// /// Returns an error if the given covariance matrix is not /// symmetric or positive-definite - pub fn new_from_nalgebra(mean: OVector, cov: OMatrix) -> Result { - // Check that the provided covariance matrix is symmetric - if cov.lower_triangle() != cov.upper_triangle().transpose() - // Check that mean and covariance do not contain NaN - || mean.iter().any(|f| f.is_nan()) + pub fn new_from_nalgebra( + mean: OVector, + cov: OMatrix, + ) -> Result { + if mean.iter().any(|f| f.is_nan()) { + return Err(MultivariateNormalError::MeanInvalid); + } + + if cov.nrows() != cov.ncols() + || cov.lower_triangle() != cov.upper_triangle().transpose() || cov.iter().any(|f| f.is_nan()) - // Check that the dimensions match - || mean.nrows() != cov.nrows() || cov.nrows() != cov.ncols() { - return Err(StatsError::BadParams); + return Err(MultivariateNormalError::CovInvalid); + } + + if mean.nrows() != cov.nrows() { + return Err(MultivariateNormalError::DimensionMismatch); } + let cov_det = cov.determinant(); let pdf_const = ((2. * PI).powi(mean.nrows() as i32) * cov_det.abs()) .recip() @@ -84,7 +127,7 @@ where // Store the Cholesky decomposition of the covariance matrix // for sampling match Cholesky::new(cov.clone()) { - None => Err(StatsError::BadParams), + None => Err(MultivariateNormalError::CholeskyFailed), Some(cholesky_decomp) => { let precision = cholesky_decomp.inverse(); Ok(MultivariateNormal {