Skip to content

Commit

Permalink
Add MultivariateNormal
Browse files Browse the repository at this point in the history
  • Loading branch information
FreezyLemon committed Aug 16, 2024
1 parent 6349fa2 commit 1870451
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/distribution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub use self::inverse_gamma::{InverseGamma, InverseGammaError};
pub use self::laplace::{Laplace, LaplaceError};
pub use self::log_normal::{LogNormal, LogNormalError};
pub use self::multinomial::Multinomial;
pub use self::multivariate_normal::MultivariateNormal;
pub use self::multivariate_normal::{MultivariateNormal, MultivariateNormalError};
pub use self::negative_binomial::{NegativeBinomial, NegativeBinomialError};
pub use self::normal::{Normal, NormalError};
pub use self::pareto::{Pareto, ParetoError};
Expand Down
67 changes: 55 additions & 12 deletions src/distribution/multivariate_normal.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::distribution::Continuous;
use crate::distribution::Normal;
use crate::statistics::{Max, MeanN, Min, Mode, VarianceN};
use crate::{Result, StatsError};
use nalgebra::{Cholesky, Const, DMatrix, DVector, Dim, DimMin, Dyn, OMatrix, OVector};
use rand::Rng;
use std::f64;
Expand Down Expand Up @@ -36,15 +35,51 @@ where
pdf_const: f64,
}

#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub enum MultivariateNormalError {
/// The covariance matrix is asymmetric or contains a NaN.
CovInvalid,

/// The mean vector contains a NaN.
MeanInvalid,

/// The amount of rows in the vector of means is not equal to the amount
/// of rows in the covariance matrix.
DimensionMismatch,

/// After all other validation, computing the Cholesky decomposition failed.
CholeskyFailed,
}

impl std::fmt::Display for MultivariateNormalError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
MultivariateNormalError::CovInvalid => {
write!(f, "Covariance matrix is asymmetric or contains a NaN")
}
MultivariateNormalError::MeanInvalid => write!(f, "Mean vector contains a NaN"),
MultivariateNormalError::DimensionMismatch => write!(
f,
"Mean vector and covariance matrix do not have the same number of rows"
),
MultivariateNormalError::CholeskyFailed => {
write!(f, "Computing the Cholesky decomposition failed")
}
}
}
}

impl std::error::Error for MultivariateNormalError {}

impl MultivariateNormal<Dyn> {
/// Constructs a new multivariate normal distribution with a mean of `mean`
/// Constructs a new multivariate normal distribution with a mean of `mean`
/// and covariance matrix `cov`
///
/// # Errors
///
/// Returns an error if the given covariance matrix is not
/// symmetric or positive-definite
pub fn new(mean: Vec<f64>, cov: Vec<f64>) -> Result<Self> {
pub fn new(mean: Vec<f64>, cov: Vec<f64>) -> Result<Self, MultivariateNormalError> {
let mean = DVector::from_vec(mean);
let cov = DMatrix::from_vec(mean.len(), mean.len(), cov);
MultivariateNormal::new_from_nalgebra(mean, cov)
Expand All @@ -66,25 +101,33 @@ where
///
/// Returns an error if the given covariance matrix is not
/// symmetric or positive-definite
pub fn new_from_nalgebra(mean: OVector<f64, D>, cov: OMatrix<f64, D, D>) -> Result<Self> {
// Check that the provided covariance matrix is symmetric
if cov.lower_triangle() != cov.upper_triangle().transpose()
// Check that mean and covariance do not contain NaN
|| mean.iter().any(|f| f.is_nan())
pub fn new_from_nalgebra(
mean: OVector<f64, D>,
cov: OMatrix<f64, D, D>,
) -> Result<Self, MultivariateNormalError> {
if mean.iter().any(|f| f.is_nan()) {
return Err(MultivariateNormalError::MeanInvalid);
}

if cov.nrows() != cov.ncols()
|| cov.lower_triangle() != cov.upper_triangle().transpose()
|| cov.iter().any(|f| f.is_nan())
// Check that the dimensions match
|| mean.nrows() != cov.nrows() || cov.nrows() != cov.ncols()
{
return Err(StatsError::BadParams);
return Err(MultivariateNormalError::CovInvalid);
}

if mean.nrows() != cov.nrows() {
return Err(MultivariateNormalError::DimensionMismatch);
}

let cov_det = cov.determinant();
let pdf_const = ((2. * PI).powi(mean.nrows() as i32) * cov_det.abs())
.recip()
.sqrt();
// Store the Cholesky decomposition of the covariance matrix
// for sampling
match Cholesky::new(cov.clone()) {
None => Err(StatsError::BadParams),
None => Err(MultivariateNormalError::CholeskyFailed),
Some(cholesky_decomp) => {
let precision = cholesky_decomp.inverse();
Ok(MultivariateNormal {
Expand Down

0 comments on commit 1870451

Please sign in to comment.