Skip to content

Commit

Permalink
feat: possibility for creating multivariate normal dist from nalgebra…
Browse files Browse the repository at this point in the history
… DVector/DMatrix
  • Loading branch information
henryjac authored and YeungOnion committed Apr 23, 2024
1 parent d29d020 commit ce40e3a
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion src/distribution/multivariate_normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub struct MultivariateNormal {
}

impl MultivariateNormal {
/// Constructs a new multivariate normal distribution with a mean of `mean`
/// Constructs a new multivariate normal distribution with a mean of `mean`
/// and covariance matrix `cov`
///
/// # Errors
Expand All @@ -47,6 +47,18 @@ impl MultivariateNormal {
pub fn new(mean: Vec<f64>, cov: Vec<f64>) -> Result<Self> {
let mean = DVector::from_vec(mean);
let cov = DMatrix::from_vec(mean.len(), mean.len(), cov);
return MultivariateNormal::new_from_nalgebra(mean, cov)
}

/// Constructs a new multivariate normal distribution with a mean of `mean`
/// and covariance matrix `cov`, but with explicitly using nalgebras
/// DVector and DMatrix instead of Vec<f64>
///
/// # Errors
///
/// Returns an error if the given covariance matrix is not
/// symmetric or positive-definite
pub fn new_from_nalgebra(mean: DVector<f64>, cov: DMatrix<f64>) -> Result<Self> {
let dim = mean.len();
// Check that the provided covariance matrix is symmetric
if cov.lower_triangle() != cov.upper_triangle().transpose()
Expand Down Expand Up @@ -79,6 +91,7 @@ impl MultivariateNormal {
}
}
}

/// Returns the entropy of the multivariate normal distribution
///
/// # Formula
Expand Down

0 comments on commit ce40e3a

Please sign in to comment.