diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index ab168020..75902a4b 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -202,6 +202,28 @@ impl<'a> Continuous<&'a DVector, f64> for MultivariateNormal { } } +impl Continuous, f64> for MultivariateNormal { + /// Calculates the probability density function for the multivariate + /// normal distribution at `x` + /// + /// # Formula + /// + /// ```ignore + /// (2 * π) ^ (-k / 2) * det(Σ) ^ (1 / 2) * e ^ ( -(1 / 2) * transpose(x - μ) * inv(Σ) * (x - μ)) + /// ``` + /// + /// where `μ` is the mean, `inv(Σ)` is the precision matrix, `det(Σ)` is the determinant + /// of the covariance matrix, and `k` is the dimension of the distribution + fn pdf(&self, x: Vec) -> f64 { + self.pdf(&DVector::from(x)) + } + /// Calculates the log probability density function for the multivariate + /// normal distribution at `x`. Equivalent to pdf(x).ln(). + fn ln_pdf(&self, x: Vec) -> f64 { + self.pdf(&DVector::from(x)) + } +} + #[rustfmt::skip] #[cfg(all(test, feature = "nightly"))] mod tests {