diff --git a/src/distribution/dirichlet.rs b/src/distribution/dirichlet.rs index cc96e102..eb4d7958 100644 --- a/src/distribution/dirichlet.rs +++ b/src/distribution/dirichlet.rs @@ -2,8 +2,7 @@ use crate::distribution::Continuous; use crate::function::gamma; use crate::statistics::*; use crate::{prec, Result, StatsError}; -use nalgebra::DMatrix; -use nalgebra::DVector; +use nalgebra::{Const, Dim, Dyn, OMatrix, OVector}; use rand::Rng; use std::f64; @@ -24,10 +23,15 @@ use std::f64; /// assert_eq!(n.pdf(&DVector::from_vec(vec![0.33333, 0.33333, 0.33333])), 2.222155556222205); /// ``` #[derive(Clone, PartialEq, Debug)] -pub struct Dirichlet { - alpha: DVector, +pub struct Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + alpha: OVector, } -impl Dirichlet { + +impl Dirichlet { /// Constructs a new dirichlet distribution with the given /// concentration parameters (alpha) /// @@ -51,15 +55,8 @@ impl Dirichlet { /// result = Dirichlet::new(alpha_err); /// assert!(result.is_err()); /// ``` - pub fn new(alpha: Vec) -> Result { - if !is_valid_alpha(&alpha) { - Err(StatsError::BadParams) - } else { - // let vec = alpha.to_vec(); - Ok(Dirichlet { - alpha: DVector::from_vec(alpha.to_vec()), - }) - } + pub fn new(alpha: Vec) -> Result { + Self::new_from_nalgebra(alpha.into()) } /// Constructs a new dirichlet distribution with the given @@ -81,9 +78,30 @@ impl Dirichlet { /// result = Dirichlet::new_with_param(0.0, 1); /// assert!(result.is_err()); /// ``` - pub fn new_with_param(alpha: f64, n: usize) -> Result { + pub fn new_with_param(alpha: f64, n: usize) -> Result { Self::new(vec![alpha; n]) } +} + +impl Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + /// Constructs a new distribution with the given vector for `alpha` + /// Does not clone the vector it takes ownership of + /// + /// # Error + /// + /// Returns an error if vector has length less than 2 or if any element + /// of alpha is NOT finite positive + pub fn new_from_nalgebra(alpha: OVector) -> Result { + if !is_valid_alpha(alpha.as_slice()) { + Err(StatsError::BadParams) + } else { + Ok(Self { alpha }) + } + } /// Returns the concentration parameters of /// the dirichlet distribution as a slice @@ -97,12 +115,12 @@ impl Dirichlet { /// let n = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap(); /// assert_eq!(n.alpha(), &DVector::from_vec(vec![1.0, 2.0, 3.0])); /// ``` - pub fn alpha(&self) -> &DVector { + pub fn alpha(&self) -> &nalgebra::OVector { &self.alpha } fn alpha_sum(&self) -> f64 { - self.alpha.fold(0.0, |acc, x| acc + x) + self.alpha.sum() } /// Returns the entropy of the dirichlet distribution @@ -134,30 +152,40 @@ impl Dirichlet { } } -impl std::fmt::Display for Dirichlet { +impl std::fmt::Display for Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Dir({}, {})", self.alpha.len(), &self.alpha) } } -impl ::rand::distributions::Distribution> for Dirichlet { - fn sample(&self, rng: &mut R) -> DVector { +impl ::rand::distributions::Distribution> for Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + fn sample(&self, rng: &mut R) -> OVector { let mut sum = 0.0; - let mut samples: Vec<_> = self - .alpha - .iter() - .map(|&a| { + OVector::from_iterator_generic( + self.alpha.shape_generic().0, + Const::<1>, + self.alpha.iter().map(|&a| { let sample = super::gamma::sample_unchecked(rng, a, 1.0); sum += sample; sample - }) - .collect(); - for _ in samples.iter_mut().map(|x| *x /= sum) {} - DVector::from_vec(samples) + }), + ) } } -impl MeanN> for Dirichlet { +impl MeanN> for Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ /// Returns the means of the dirichlet distribution /// /// # Formula @@ -168,13 +196,18 @@ impl MeanN> for Dirichlet { /// /// for the `i`th element where `α_i` is the `i`th concentration parameter /// and `α_0` is the sum of all concentration parameters - fn mean(&self) -> Option> { + fn mean(&self) -> Option> { let sum = self.alpha_sum(); Some(self.alpha.map(|x| x / sum)) } } -impl VarianceN> for Dirichlet { +impl VarianceN> for Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ /// Returns the variances of the dirichlet distribution /// /// # Formula @@ -185,10 +218,10 @@ impl VarianceN> for Dirichlet { /// /// for the `i`th element where `α_i` is the `i`th concentration parameter /// and `α_0` is the sum of all concentration parameters - fn variance(&self) -> Option> { + fn variance(&self) -> Option> { let sum = self.alpha_sum(); let normalizing = sum * sum * (sum + 1.0); - let mut cov = DMatrix::from_diagonal(&self.alpha.map(|x| x * (sum - x) / normalizing)); + let mut cov = OMatrix::from_diagonal(&self.alpha.map(|x| x * (sum - x) / normalizing)); let mut offdiag = |x: usize, y: usize| { let elt = -self.alpha[x] * self.alpha[y] / normalizing; cov[(x, y)] = elt; @@ -203,7 +236,13 @@ impl VarianceN> for Dirichlet { } } -impl<'a> Continuous<&'a DVector, f64> for Dirichlet { +impl<'a, D> Continuous<&'a OVector, f64> for Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator, D>, +{ /// Calculates the probabiliy density function for the dirichlet /// distribution /// with given `x`'s corresponding to the concentration parameters for this @@ -234,7 +273,7 @@ impl<'a> Continuous<&'a DVector, f64> for Dirichlet { /// the `i`th concentration parameter, `Γ` is the gamma function, /// `Π` is the product from `1` to `K`, `Σ` is the sum from `1` to `K`, /// and `K` is the number of concentration parameters - fn pdf(&self, x: &DVector) -> f64 { + fn pdf(&self, x: &OVector) -> f64 { self.ln_pdf(x).exp() } @@ -268,7 +307,7 @@ impl<'a> Continuous<&'a DVector, f64> for Dirichlet { /// the `i`th concentration parameter, `Γ` is the gamma function, /// `Π` is the product from `1` to `K`, `Σ` is the sum from `1` to `K`, /// and `K` is the number of concentration parameters - fn ln_pdf(&self, x: &DVector) -> f64 { + fn ln_pdf(&self, x: &OVector) -> f64 { // TODO: would it be clearer here to just do a for loop instead // of using iterators? if self.alpha.len() != x.len() { @@ -300,55 +339,71 @@ impl<'a> Continuous<&'a DVector, f64> for Dirichlet { // determines if `a` is a valid alpha array // for the Dirichlet distribution fn is_valid_alpha(a: &[f64]) -> bool { - a.len() >= 2 && super::internal::is_valid_multinomial(a, false) + a.len() >= 2 && a.iter().all(|&a_i| a_i.is_finite() && a_i > 0.0) } #[rustfmt::skip] #[cfg(test)] mod tests { + use nalgebra::{dvector, vector, DimMin, OVector}; + use super::*; - use crate::distribution::{Continuous, Dirichlet}; + use crate::distribution::Continuous; - #[test] - fn test_is_valid_alpha() { - let invalid = [1.0]; - assert!(!is_valid_alpha(&invalid)); + fn try_create(alpha: OVector) -> Dirichlet + where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, + { + let mvn = Dirichlet::new_from_nalgebra(alpha); + assert!(mvn.is_ok()); + mvn.unwrap() } - fn try_create(alpha: &[f64]) -> Dirichlet + fn bad_create_case(alpha: OVector) + where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { - let n = Dirichlet::new(alpha.to_vec()); - assert!(n.is_ok()); - n.unwrap() + let dd = Dirichlet::new_from_nalgebra(alpha); + assert!(dd.is_err()); } - fn create_case(alpha: &[f64]) + fn test_almost(alpha: OVector, expected: f64, acc: f64, eval: F) + where + F: FnOnce(Dirichlet) -> f64, + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { - let n = try_create(alpha); - let a2 = n.alpha(); - for i in 0..alpha.len() { - assert_eq!(alpha[i], a2[i]); - } + let dd = try_create(alpha); + let x = eval(dd); + assert_almost_eq!(expected, x, acc); } - fn bad_create_case(alpha: &[f64]) - { - let n = Dirichlet::new(alpha.to_vec()); - assert!(n.is_err()); + #[test] + fn test_is_valid_alpha() { + assert!(!is_valid_alpha(&[1.0])); + assert!(!is_valid_alpha(&[1.0, f64::NAN])); + assert!(is_valid_alpha(&[1.0, 2.0])); + assert!(!is_valid_alpha(&[1.0, 0.0])); + assert!(!is_valid_alpha(&[1.0, f64::INFINITY])); + assert!(!is_valid_alpha(&[-1.0, 2.0])); } #[test] fn test_create() { - create_case(&[1.0, 2.0, 3.0, 4.0, 5.0]); - create_case(&[0.001, f64::INFINITY, 3756.0]); + try_create(vector![1.0, 2.0, 3.0, 4.0, 5.0]); + assert!(Dirichlet::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]).is_ok()); + // try_create(vector![0.001, f64::INFINITY, 3756.0]); // moved to bad case as this is degenerate } #[test] fn test_bad_create() { - bad_create_case(&[1.0]); - bad_create_case(&[1.0, 2.0, 0.0, 4.0, 5.0]); - bad_create_case(&[1.0, f64::NAN, 3.0, 4.0, 5.0]); - bad_create_case(&[0.0, 0.0, 0.0]); + bad_create_case(vector![1.0]); + bad_create_case(vector![1.0, 2.0, 0.0, 4.0, 5.0]); + bad_create_case(vector![1.0, f64::NAN, 3.0, 4.0, 5.0]); + bad_create_case(vector![0.0, 0.0, 0.0]); + bad_create_case(vector![0.001, f64::INFINITY, 3756.0]); // moved to bad case as this is degenerate } // #[test] @@ -386,70 +441,94 @@ mod tests { #[test] fn test_entropy() { - let mut n = try_create(&[0.1, 0.3, 0.5, 0.8]); - assert_eq!(n.entropy().unwrap(), -17.46469081094079); - - n = try_create(&[0.1, 0.2, 0.3, 0.4]); - assert_eq!(n.entropy().unwrap(), -21.53881433791513); - } - - macro_rules! dvec { - ($($x:expr),*) => (DVector::from_vec(vec![$($x),*])); + let entropy = |x: Dirichlet<_>| x.entropy().unwrap(); + test_almost( + vector![0.1, 0.3, 0.5, 0.8], + -17.46469081094079, + 1e-30, + entropy, + ); + test_almost( + vector![0.1, 0.2, 0.3, 0.4], + -21.53881433791513, + 1e-30, + entropy, + ); } #[test] fn test_pdf() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - assert_almost_eq!(n.pdf(&dvec![0.01, 0.03, 0.5, 0.46]), 18.77225681167061, 1e-12); - assert_almost_eq!(n.pdf(&dvec![0.1,0.2,0.3,0.4]), 0.8314656481199253, 1e-14); + let pdf = |arg| move |x: Dirichlet<_>| x.pdf(&arg); + test_almost( + vector![0.1, 0.3, 0.5, 0.8], + 18.77225681167061, + 1e-12, + pdf([0.01, 0.03, 0.5, 0.46].into()), + ); + test_almost( + vector![0.1, 0.3, 0.5, 0.8], + 0.8314656481199253, + 1e-14, + pdf([0.1, 0.2, 0.3, 0.4].into()), + ); } #[test] fn test_ln_pdf() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - assert_almost_eq!(n.ln_pdf(&dvec![0.01, 0.03, 0.5, 0.46]), 18.77225681167061f64.ln(), 1e-12); - assert_almost_eq!(n.ln_pdf(&dvec![0.1,0.2,0.3,0.4]), 0.8314656481199253f64.ln(), 1e-14); + let ln_pdf = |arg| move |x: Dirichlet<_>| x.ln_pdf(&arg); + test_almost( + vector![0.1, 0.3, 0.5, 0.8], + 18.77225681167061_f64.ln(), + 1e-12, + ln_pdf([0.01, 0.03, 0.5, 0.46].into()), + ); + test_almost( + vector![0.1, 0.3, 0.5, 0.8], + 0.8314656481199253_f64.ln(), + 1e-14, + ln_pdf([0.1, 0.2, 0.3, 0.4].into()), + ); } #[test] #[should_panic] fn test_pdf_bad_input_length() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - n.pdf(&dvec![0.5]); + let n = try_create(dvector![0.1, 0.3, 0.5, 0.8]); + n.pdf(&dvector![0.5]); } #[test] #[should_panic] fn test_pdf_bad_input_range() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - n.pdf(&dvec![1.5, 0.0, 0.0, 0.0]); + let n = try_create(vector![0.1, 0.3, 0.5, 0.8]); + n.pdf(&vector![1.5, 0.0, 0.0, 0.0]); } #[test] #[should_panic] fn test_pdf_bad_input_sum() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - n.pdf(&dvec![0.5, 0.25, 0.8, 0.9]); + let n = try_create(vector![0.1, 0.3, 0.5, 0.8]); + n.pdf(&vector![0.5, 0.25, 0.8, 0.9]); } #[test] #[should_panic] fn test_ln_pdf_bad_input_length() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - n.ln_pdf(&dvec![0.5]); + let n = try_create(dvector![0.1, 0.3, 0.5, 0.8]); + n.ln_pdf(&dvector![0.5]); } #[test] #[should_panic] fn test_ln_pdf_bad_input_range() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - n.ln_pdf(&dvec![1.5, 0.0, 0.0, 0.0]); + let n = try_create(vector![0.1, 0.3, 0.5, 0.8]); + n.ln_pdf(&vector![1.5, 0.0, 0.0, 0.0]); } #[test] #[should_panic] fn test_ln_pdf_bad_input_sum() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - n.ln_pdf(&dvec![0.5, 0.25, 0.8, 0.9]); + let n = try_create(vector![0.1, 0.3, 0.5, 0.8]); + n.ln_pdf(&vector![0.5, 0.25, 0.8, 0.9]); } }