Skip to content

Commit

Permalink
fix: use multivariate normal pdf when freedom = inf for multivariate …
Browse files Browse the repository at this point in the history
…student
  • Loading branch information
henryjac committed Dec 25, 2022
1 parent 9ab1a61 commit 280fa7e
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions src/distribution/multivariate_students_t.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl ::rand::distributions::Distribution<DVector<f64>> for MultivariateStudent {
///
/// W * L * Z + μ
///
/// where `W` has √(ν/Sν) distribution, Sν has Chi-squared
/// where `W` has √(ν/Sν) distribution, Sν has Chi-squared
/// distribution with ν degrees of freedom,
/// `L` is the Cholesky decomposition of the scale matrix,
/// `Z` is a vector of normally distributed random variables, and
Expand Down Expand Up @@ -167,8 +167,8 @@ impl MeanN<DVector<f64>> for MultivariateStudent {
///
/// # Remarks
///
/// This is the same mean used to construct the distribution if
/// the degrees of freedom is larger than 1.
/// This is the same mean used to construct the distribution if
/// the degrees of freedom is larger than 1
fn mean(&self) -> Option<DVector<f64>> {
if self.freedom > 1. {
let mut vec = vec![];
Expand All @@ -184,6 +184,14 @@ impl MeanN<DVector<f64>> for MultivariateStudent {

impl VarianceN<DMatrix<f64>> for MultivariateStudent {
/// Returns the covariance matrix of the multivariate student distribution
///
/// # Formula
/// ```ignore
/// Σ ⋅ ν / (ν - 2)
/// ```
///
/// where `Σ` is the scale matrix and `ν` is the degrees of freedom.
/// Only defined if freedom is larger than 2
fn variance(&self) -> Option<DMatrix<f64>> {
if self.freedom > 2. {
Some(self.scale.clone() * self.freedom / (self.freedom - 2.))
Expand Down Expand Up @@ -219,19 +227,14 @@ impl<'a> Continuous<&'a DVector<f64>, f64> for MultivariateStudent {
/// ```
///
/// where `ν` is the degrees of freedom, `μ` is the mean, `Gamma`
/// is the Gamma function, `inv(Σ)`
/// is the Gamma function, `inv(Σ)`
/// is the precision matrix, `det(Σ)` is the determinant
/// of the scale matrix, and `k` is the dimension of the distribution
///
/// TODO: Make this converge for large degrees of freedom
/// Current commented code beneath fails since `MultivariateNormal::new` accepts Vec<f64> and
/// not DVector or DMatrix. Should implement that instead of changing back to Vec<f64>, or
/// even have a constructor `MultivariateNormal::from_student`.
fn pdf(&self, x: &'a DVector<f64>) -> f64 {
// if self.freedom == f64::INFINITY {
// let mvn = MultivariateNormal::new(self.location, self.scale).unwrap();
// return mvn.pdf(x);
// }
if self.freedom == f64::INFINITY {
let mvn = MultivariateNormal::from_students(self.clone()).unwrap();
return mvn.pdf(x);
}
let dv = x - &self.location;
let base_term = 1.
+ 1. / self.freedom
Expand All @@ -244,6 +247,10 @@ impl<'a> Continuous<&'a DVector<f64>, f64> for MultivariateStudent {
/// Calculates the log probability density function for the multivariate
/// student distribution at `x`. Equivalent to pdf(x).ln().
fn ln_pdf(&self, x: &'a DVector<f64>) -> f64 {
if self.freedom == f64::INFINITY {
let mvn = MultivariateNormal::from_students(self.clone()).unwrap();
return mvn.ln_pdf(x);
}
let dv = x - &self.location;
let base_term = 1.
+ 1. / self.freedom
Expand Down Expand Up @@ -427,12 +434,19 @@ mod tests {
test_case(vec![-1., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 10., f64::NEG_INFINITY, pdf(dvec![10., 10.]));
}

// TODO: These tests fail because inf degrees of freedom give NaN
#[test]
fn test_pdf_freedom_large() {
let pdf_mvs = |mv: MultivariateStudent, arg: DVector<f64>| mv.pdf(&arg);
let pdf_mvn = |mv: MultivariateNormal, arg: DVector<f64>| mv.pdf(&arg);
test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e5, 1e-6, dvec![1., 1.], pdf_mvs, pdf_mvn);
test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e10, 1e-7, dvec![1., 1.], pdf_mvs, pdf_mvn);
test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-15, dvec![1., 1.], pdf_mvs, pdf_mvn);
}
#[test]
fn test_ln_pdf_freedom_large() {
let pdf_mvs = |mv: MultivariateStudent, arg: DVector<f64>| mv.ln_pdf(&arg);
let pdf_mvn = |mv: MultivariateNormal, arg: DVector<f64>| mv.ln_pdf(&arg);
test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e10, 1e-5, dvec![1., 1.], pdf_mvs, pdf_mvn);
test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-50, dvec![1., 1.], pdf_mvs, pdf_mvn);
}
}

0 comments on commit 280fa7e

Please sign in to comment.