From d3dd51729b45ab31a0c31227b6a9db121c3fab28 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Fri, 2 Aug 2024 12:39:09 -0500 Subject: [PATCH 1/6] test: expand tests for Dirichlet distribution --- src/distribution/dirichlet.rs | 81 ++++++++++++++++++++++++----------- 1 file changed, 56 insertions(+), 25 deletions(-) diff --git a/src/distribution/dirichlet.rs b/src/distribution/dirichlet.rs index eb4d7958..f058b46d 100644 --- a/src/distribution/dirichlet.rs +++ b/src/distribution/dirichlet.rs @@ -345,10 +345,15 @@ fn is_valid_alpha(a: &[f64]) -> bool { #[rustfmt::skip] #[cfg(test)] mod tests { - use nalgebra::{dvector, vector, DimMin, OVector}; + use std::fmt::{Debug, Display}; - use super::*; - use crate::distribution::Continuous; + use nalgebra::{dmatrix, dvector, vector, DimMin, OVector}; + + use super::is_valid_alpha; + use crate::{ + distribution::{Continuous, Dirichlet}, + statistics::{MeanN, VarianceN}, + }; fn try_create(alpha: OVector) -> Dirichlet where @@ -369,15 +374,16 @@ mod tests { assert!(dd.is_err()); } - fn test_almost(alpha: OVector, expected: f64, acc: f64, eval: F) + fn test_almost(alpha: OVector, expected: T, acc: f64, eval: F) where - F: FnOnce(Dirichlet) -> f64, + T: Debug + Display + approx::RelativeEq, + F: FnOnce(Dirichlet) -> T, D: DimMin, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { let dd = try_create(alpha); let x = eval(dd); - assert_almost_eq!(expected, x, acc); + assert_relative_eq!(expected, x, epsilon = acc); } #[test] @@ -406,26 +412,51 @@ mod tests { bad_create_case(vector![0.001, f64::INFINITY, 3756.0]); // moved to bad case as this is degenerate } - // #[test] - // fn test_mean() { - // let n = Dirichlet::new_with_param(0.3, 5).unwrap(); - // let res = n.mean(); - // for x in res { - // assert_eq!(x, 0.3 / 1.5); - // } - // } + #[test] + fn test_mean() { + let mean = |dd: Dirichlet<_>| dd.mean().unwrap(); - // #[test] - // fn test_variance() { - // let alpha = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; - // let sum = alpha.iter().fold(0.0, |acc, x| acc + x); - // let n = Dirichlet::new(&alpha).unwrap(); - // let res = n.variance(); - // for i in 1..11 { - // let f = i as f64; - // assert_almost_eq!(res[i-1], f * (sum - f) / (sum * sum * (sum + 1.0)), 1e-15); - // } - // } + test_almost(vec![0.5; 5].into(), vec![1.0 / 5.0; 5].into(), 1e-15, mean); + + test_almost( + dvector![0.1, 0.2, 0.3, 0.4], + dvector![0.1, 0.2, 0.3, 0.4], + 1e-15, + mean, + ); + + test_almost( + dvector![1.0, 2.0, 3.0, 4.0], + dvector![0.1, 0.2, 0.3, 0.4], + 1e-15, + mean, + ); + } + + #[test] + fn test_variance() { + let variance = |dd: Dirichlet<_>| dd.variance().unwrap(); + + test_almost( + dvector![1.0, 2.0], + dmatrix![0.055555555555555, -0.055555555555555; + -0.055555555555555, 0.055555555555555; + ], + 1e-15, + variance, + ); + + test_almost( + dvector![0.1, 0.2, 0.3, 0.4], + dmatrix![0.045, -0.010, -0.015, -0.020; + -0.010, 0.080, -0.030, -0.040; + -0.015, -0.030, 0.105, -0.060; + -0.020, -0.040, -0.060, 0.120; + ], + 1e-15, + variance, + ); + } // #[test] // fn test_std_dev() { From 444cdf39151d1c8c1c45178c9de36f81b8b0dcfa Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Fri, 2 Aug 2024 14:45:39 -0500 Subject: [PATCH 2/6] feat(multivariate)!: migrate Multinomial to generic dimension API --- src/distribution/internal.rs | 37 +++++++++++++++++ src/distribution/multinomial.rs | 73 ++++++++++++++++++++++++--------- 2 files changed, 91 insertions(+), 19 deletions(-) diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index b6e71136..b45b4bf7 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -1,5 +1,8 @@ +use nalgebra::{Dim, OVector}; use num_traits::Num; +use crate::StatsError; + /// Returns true if there are no elements in `x` in `arr` /// such that `x <= 0.0` or `x` is `f64::NAN` and `sum(arr) > 0.0`. /// IF `incl_zero` is true, it tests for `x < 0.0` instead of `x <= 0.0` @@ -14,6 +17,36 @@ pub fn is_valid_multinomial(arr: &[f64], incl_zero: bool) -> bool { sum != 0.0 } +pub fn check_multinomial(arr: &OVector, accept_zeroes: bool) -> crate::Result<()> +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + if arr.len() < 2 { + return Err(StatsError::BadParams); + } + let mut sum = 0.0; + for &x in arr.iter() { + if x.is_nan() { + return Err(StatsError::BadParams); + } else if x.is_infinite() { + return Err(StatsError::BadParams); + } else if x < 0.0 { + return Err(StatsError::BadParams); + } else if x == 0.0 && !accept_zeroes { + return Err(StatsError::BadParams); + } else { + sum += x; + } + } + + if sum != 0.0 { + Ok(()) + } else { + Err(StatsError::BadParams) + } +} + /// Implements univariate function bisection searching for criteria /// ```text /// smallest k such that f(k) >= z @@ -225,12 +258,16 @@ pub mod test { let invalid = [1.0, f64::NAN, 3.0]; assert!(!is_valid_multinomial(&invalid, true)); + assert!(check_multinomial(&invalid.to_vec().into(), true).is_err()); let invalid2 = [-2.0, 5.0, 1.0, 6.2]; assert!(!is_valid_multinomial(&invalid2, true)); + assert!(check_multinomial(&invalid2.to_vec().into(), true).is_err()); let invalid3 = [0.0, 0.0, 0.0]; assert!(!is_valid_multinomial(&invalid3, true)); + assert!(check_multinomial(&invalid3.to_vec().into(), true).is_err()); let valid = [5.2, 0.0, 1e-15, 1000000.12]; assert!(is_valid_multinomial(&valid, true)); + assert!(check_multinomial(&valid.to_vec().into(), true).is_ok()); } #[test] diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index dd17d2f0..b3e95cb2 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -1,8 +1,8 @@ use crate::distribution::Discrete; use crate::function::factorial; use crate::statistics::*; -use crate::{Result, StatsError}; -use ::nalgebra::{DMatrix, DVector}; +use crate::Result; +use nalgebra::{Const, DMatrix, DVector, Dim, Dyn, OVector}; use rand::Rng; /// Implements the @@ -22,12 +22,18 @@ use rand::Rng; /// assert_eq!(n.mean().unwrap(), DVector::from_vec(vec![1.5, 3.5])); /// ``` #[derive(Debug, Clone, PartialEq)] -pub struct Multinomial { - p: Vec, +pub struct Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + /// normalized probabilities for each species + p: OVector, + /// count of trials n: u64, } -impl Multinomial { +impl Multinomial { /// Constructs a new multinomial distribution with probabilities `p` /// and `n` number of trials. /// @@ -51,11 +57,20 @@ impl Multinomial { /// result = Multinomial::new(&[0.0, -1.0, 2.0], 3); /// assert!(result.is_err()); /// ``` - pub fn new(p: &[f64], n: u64) -> Result { - if !super::internal::is_valid_multinomial(p, true) { - Err(StatsError::BadParams) - } else { - Ok(Multinomial { p: p.to_vec(), n }) + pub fn new(p: &[f64], n: u64) -> Result { + Self::new_from_nalgebra(p.to_vec().into(), n) + } +} + +impl Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + pub fn new_from_nalgebra(mut p: OVector, n: u64) -> Result { + match super::internal::check_multinomial(&p, true) { + Err(e) => Err(e), + Ok(_) => Ok(Self { p, n }), } } @@ -70,7 +85,7 @@ impl Multinomial { /// let n = Multinomial::new(&[0.0, 1.0, 2.0], 3).unwrap(); /// assert_eq!(n.p(), [0.0, 1.0, 2.0]); /// ``` - pub fn p(&self) -> &[f64] { + pub fn p(&self) -> &OVector { &self.p } @@ -90,16 +105,24 @@ impl Multinomial { } } -impl std::fmt::Display for Multinomial { +impl std::fmt::Display for Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Multinom({:#?},{})", self.p, self.n) } } -impl ::rand::distributions::Distribution> for Multinomial { - fn sample(&self, rng: &mut R) -> Vec { - let p_cdf = super::categorical::prob_mass_to_cdf(self.p()); - let mut res = vec![0.0; self.p.len()]; +impl ::rand::distributions::Distribution> for Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + fn sample(&self, rng: &mut R) -> OVector { + let p_cdf = super::categorical::prob_mass_to_cdf(self.p().as_slice()); + let mut res = OVector::zeros_generic(self.p.shape_generic().0, Const::<1>); for _ in 0..self.n { let i = super::categorical::sample_unchecked(rng, &p_cdf); let el = res.get_mut(i as usize).unwrap(); @@ -109,7 +132,11 @@ impl ::rand::distributions::Distribution> for Multinomial { } } -impl MeanN> for Multinomial { +impl MeanN> for Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ /// Returns the mean of the multinomial distribution /// /// # Formula @@ -127,7 +154,11 @@ impl MeanN> for Multinomial { } } -impl VarianceN> for Multinomial { +impl VarianceN> for Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ /// Returns the variance of the multinomial distribution /// /// # Formula @@ -169,7 +200,11 @@ impl VarianceN> for Multinomial { // } // } -impl<'a> Discrete<&'a [u64], f64> for Multinomial { +impl<'a, D> Discrete<&'a [u64], f64> for Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ /// Calculates the probability mass function for the multinomial /// distribution /// with the given `x`'s corresponding to the probabilities for this From b769979beaf29970ef96da8c22299b5820efc863 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Fri, 2 Aug 2024 15:45:39 -0500 Subject: [PATCH 3/6] refactor: Multinomial stores normalized probability --- src/distribution/multinomial.rs | 61 +++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index b3e95cb2..610d746b 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -2,7 +2,7 @@ use crate::distribution::Discrete; use crate::function::factorial; use crate::statistics::*; use crate::Result; -use nalgebra::{Const, DMatrix, DVector, Dim, Dyn, OVector}; +use nalgebra::{Const, DVector, Dim, Dyn, OMatrix, OVector}; use rand::Rng; /// Implements the @@ -51,14 +51,14 @@ impl Multinomial { /// ``` /// use statrs::distribution::Multinomial; /// - /// let mut result = Multinomial::new(&[0.0, 1.0, 2.0], 3); + /// let mut result = Multinomial::new(vec![0.0, 1.0, 2.0], 3); /// assert!(result.is_ok()); /// - /// result = Multinomial::new(&[0.0, -1.0, 2.0], 3); + /// result = Multinomial::new(vec![0.0, -1.0, 2.0], 3); /// assert!(result.is_err()); /// ``` - pub fn new(p: &[f64], n: u64) -> Result { - Self::new_from_nalgebra(p.to_vec().into(), n) + pub fn new(p: Vec, n: u64) -> Result { + Self::new_from_nalgebra(p.into(), n) } } @@ -70,7 +70,10 @@ where pub fn new_from_nalgebra(mut p: OVector, n: u64) -> Result { match super::internal::check_multinomial(&p, true) { Err(e) => Err(e), - Ok(_) => Ok(Self { p, n }), + Ok(_) => { + p.unscale_mut(p.lp_norm(1)); + Ok(Self { p, n }) + } } } @@ -154,10 +157,11 @@ where } } -impl VarianceN> for Multinomial +impl VarianceN> for Multinomial where D: Dim, - nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Returns the variance of the multinomial distribution /// @@ -169,13 +173,21 @@ where /// /// where `n` is the number of trials, `p_i` is the `i`th probability, /// and `k` is the total number of probabilities - fn variance(&self) -> Option> { - let cov: Vec<_> = self - .p - .iter() - .map(|x| x * self.n as f64 * (1.0 - x)) - .collect(); - Some(DMatrix::from_diagonal(&DVector::from_vec(cov))) + fn variance(&self) -> Option> { + let mut cov = OMatrix::from_diagonal(&self.p.map(|x| x * (1.0 - x))); + let mut offdiag = |x: usize, y: usize| { + let elt = -self.p[x] * self.p[y]; + // cov[(x, y)] = elt; + cov[(y, x)] = elt; + }; + + for i in 0..self.p.len() { + for j in 0..i { + offdiag(i, j); + } + } + cov.fill_lower_triangle_with_upper_triangle(); + Some(cov.scale(self.n as f64)) } } @@ -200,10 +212,11 @@ where // } // } -impl<'a, D> Discrete<&'a [u64], f64> for Multinomial +impl<'a, D> Discrete<&'a OVector, f64> for Multinomial where D: Dim, - nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Calculates the probability mass function for the multinomial /// distribution @@ -212,8 +225,7 @@ where /// /// # Panics /// - /// If the elements in `x` do not sum to `n` or if the length of `x` is not - /// equivalent to the length of `p` + /// If length of `x` is not equal to length of `p` /// /// # Formula /// @@ -224,14 +236,14 @@ where /// where `n` is the number of trials, `p_i` is the `i`th probability, /// `x_i` is the `i`th `x` value, and `k` is the total number of /// probabilities - fn pmf(&self, x: &[u64]) -> f64 { + fn pmf(&self, x: &OVector) -> f64 { if self.p.len() != x.len() { panic!("Expected x and p to have equal lengths."); } if x.iter().sum::() != self.n { return 0.0; } - let coeff = factorial::multinomial(self.n, x); + let coeff = factorial::multinomial(self.n, x.as_slice()); let val = coeff * self .p @@ -248,8 +260,7 @@ where /// /// # Panics /// - /// If the elements in `x` do not sum to `n` or if the length of `x` is not - /// equivalent to the length of `p` + /// If length of `x` is not equal to length of `p` /// /// # Formula /// @@ -260,14 +271,14 @@ where /// where `n` is the number of trials, `p_i` is the `i`th probability, /// `x_i` is the `i`th `x` value, and `k` is the total number of /// probabilities - fn ln_pmf(&self, x: &[u64]) -> f64 { + fn ln_pmf(&self, x: &OVector) -> f64 { if self.p.len() != x.len() { panic!("Expected x and p to have equal lengths."); } if x.iter().sum::() != self.n { return f64::NEG_INFINITY; } - let coeff = factorial::multinomial(self.n, x).ln(); + let coeff = factorial::multinomial(self.n, x.as_slice()).ln(); let val = coeff + self .p From 94ee7d13633f6b658a5bb6f2cc28e88b4099b4ec Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Fri, 2 Aug 2024 15:48:21 -0500 Subject: [PATCH 4/6] test: reintroduce tests for Multinomial --- src/distribution/multinomial.rs | 316 +++++++++++++++++++------------- 1 file changed, 192 insertions(+), 124 deletions(-) diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index 610d746b..fbd8d594 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -290,142 +290,210 @@ where } } -// TODO: fix tests -// #[rustfmt::skip] -// #[cfg(test)] -// mod tests { -// use crate::statistics::*; -// use crate::distribution::{Discrete, Multinomial}; +#[rustfmt::skip] +#[cfg(test)] +mod tests { + use crate::{ + distribution::{Discrete, Multinomial}, + statistics::{MeanN, VarianceN}, + }; + use nalgebra::{dmatrix, dvector, vector, DimMin, Dyn, OVector}; + use std::fmt::{Debug, Display}; -// fn try_create(p: &[f64], n: u64) -> Multinomial { -// let dist = Multinomial::new(p, n); -// assert!(dist.is_ok()); -// dist.unwrap() -// } - -// fn create_case(p: &[f64], n: u64) { -// let dist = try_create(p, n); -// assert_eq!(dist.p(), p); -// assert_eq!(dist.n(), n); -// } - -// fn bad_create_case(p: &[f64], n: u64) { -// let dist = Multinomial::new(p, n); -// assert!(dist.is_err()); -// } + fn try_create(p: OVector, n: u64) -> Multinomial + where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, + { + let mvn = Multinomial::new_from_nalgebra(p, n); + assert!(mvn.is_ok()); + mvn.unwrap() + } -// fn test_case(p: &[f64], n: u64, expected: &[f64], eval: F) -// where F: Fn(Multinomial) -> Vec -// { -// let dist = try_create(p, n); -// let x = eval(dist); -// assert_eq!(*expected, *x); -// } + fn bad_create_case(p: OVector, n: u64) -> crate::StatsError + where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, + { + let dd = Multinomial::new_from_nalgebra(p, n); + assert!(dd.is_err()); + dd.unwrap_err() + } -// fn test_almost(p: &[f64], n: u64, expected: &[f64], acc: f64, eval: F) -// where F: Fn(Multinomial) -> Vec -// { -// let dist = try_create(p, n); -// let x = eval(dist); -// assert_eq!(expected.len(), x.len()); -// for i in 0..expected.len() { -// assert_almost_eq!(expected[i], x[i], acc); -// } -// } + fn test_almost(p: OVector, n: u64, expected: T, acc: f64, eval: F) + where + T: Debug + Display + approx::RelativeEq, + F: FnOnce(Multinomial) -> T, + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, + { + let dd = try_create(p, n); + let x = eval(dd); + assert_relative_eq!(expected, x, epsilon = acc); + } -// fn test_almost_sr(p: &[f64], n: u64, expected: f64, acc:f64, eval: F) -// where F: Fn(Multinomial) -> f64 -// { -// let dist = try_create(p, n); -// let x = eval(dist); -// assert_almost_eq!(expected, x, acc); -// } + #[test] + fn test_create() { + assert_relative_eq!( + *try_create(vector![1.0, 1.0, 1.0], 4).p(), + vector![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0] + ); + try_create(dvector![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], 4); + } -// #[test] -// fn test_create() { -// create_case(&[1.0, 1.0, 1.0], 4); -// create_case(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], 4); -// } + #[test] + fn test_bad_create() { + assert_eq!( + bad_create_case(vector![-1.0, 2.0], 4), + crate::StatsError::BadParams + ); -// #[test] -// fn test_bad_create() { -// bad_create_case(&[-1.0, 1.0], 4); -// bad_create_case(&[0.0, 0.0], 4); -// } + assert_eq!( + bad_create_case(vector![0.0, 0.0], 4), + crate::StatsError::BadParams + ); + assert_eq!( + bad_create_case(vector![1.0, f64::NAN], 4), + crate::StatsError::BadParams + ); + } -// #[test] -// fn test_mean() { -// let mean = |x: Multinomial| x.mean().unwrap(); -// test_case(&[0.3, 0.7], 5, &[1.5, 3.5], mean); -// test_case(&[0.1, 0.3, 0.6], 10, &[1.0, 3.0, 6.0], mean); -// test_case(&[0.15, 0.35, 0.3, 0.2], 20, &[3.0, 7.0, 6.0, 4.0], mean); -// } + #[test] + fn test_mean() { + let mean = |x: Multinomial<_>| x.mean().unwrap(); + test_almost(dvector![0.3, 0.7], 5, dvector![1.5, 3.5], 1e-12, mean); + test_almost( + dvector![0.1, 0.3, 0.6], + 10, + dvector![1.0, 3.0, 6.0], + 1e-12, + mean, + ); + test_almost( + dvector![1.0, 3.0, 6.0], + 10, + dvector![1.0, 3.0, 6.0], + 1e-12, + mean, + ); + test_almost( + dvector![0.15, 0.35, 0.3, 0.2], + 20, + dvector![3.0, 7.0, 6.0, 4.0], + 1e-12, + mean, + ); + } -// #[test] -// fn test_variance() { -// let variance = |x: Multinomial| x.variance().unwrap(); -// test_almost(&[0.3, 0.7], 5, &[1.05, 1.05], 1e-15, variance); -// test_almost(&[0.1, 0.3, 0.6], 10, &[0.9, 2.1, 2.4], 1e-15, variance); -// test_almost(&[0.15, 0.35, 0.3, 0.2], 20, &[2.55, 4.55, 4.2, 3.2], 1e-15, variance); -// } + #[test] + fn test_variance() { + let variance = |x: Multinomial<_>| x.variance().unwrap(); + test_almost( + dvector![0.3, 0.7], + 5, + dmatrix![1.05, -1.05; + -1.05, 1.05], + 1e-15, + variance, + ); + test_almost( + dvector![0.1, 0.3, 0.6], + 10, + dmatrix![0.9, -0.3, -0.6; + -0.3, 2.1, -1.8; + -0.6, -1.8, 2.4; + ], + 1e-15, + variance, + ); + test_almost( + dvector![0.15, 0.35, 0.3, 0.2], + 20, + dmatrix![2.55, -1.05, -0.90, -0.60; + -1.05, 4.55, -2.10, -1.40; + -0.90, -2.10, 4.20, -1.20; + -0.60, -1.40, -1.20, 3.20; + ], + 1e-15, + variance, + ); + } -// // #[test] -// // fn test_skewness() { -// // let skewness = |x: Multinomial| x.skewness().unwrap(); -// // test_almost(&[0.3, 0.7], 5, &[0.390360029179413, -0.390360029179413], 1e-15, skewness); -// // test_almost(&[0.1, 0.3, 0.6], 10, &[0.843274042711568, 0.276026223736942, -0.12909944487358], 1e-15, skewness); -// // test_almost(&[0.15, 0.35, 0.3, 0.2], 20, &[0.438357003759605, 0.140642169281549, 0.195180014589707, 0.335410196624968], 1e-15, skewness); -// // } + // // #[test] + // // fn test_skewness() { + // // let skewness = |x: Multinomial| x.skewness().unwrap(); + // // test_almost(&[0.3, 0.7], 5, &[0.390360029179413, -0.390360029179413], 1e-15, skewness); + // // test_almost(&[0.1, 0.3, 0.6], 10, &[0.843274042711568, 0.276026223736942, -0.12909944487358], 1e-15, skewness); + // // test_almost(&[0.15, 0.35, 0.3, 0.2], 20, &[0.438357003759605, 0.140642169281549, 0.195180014589707, 0.335410196624968], 1e-15, skewness); + // // } -// #[test] -// fn test_pmf() { -// let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg); -// test_almost_sr(&[0.3, 0.7], 10, 0.121060821, 1e-15, pmf(&[1, 9])); -// test_almost_sr(&[0.1, 0.3, 0.6], 10, 0.105815808, 1e-15, pmf(&[1, 3, 6])); -// test_almost_sr(&[0.15, 0.35, 0.3, 0.2], 10, 0.000145152, 1e-15, pmf(&[1, 1, 1, 7])); -// } + #[test] + fn test_pmf() { + let pmf = |arg: OVector| move |x: Multinomial<_>| x.pmf(&arg); + test_almost( + dvector![0.3, 0.7], + 10, + 0.121060821, + 1e-15, + pmf(dvector![1, 9]), + ); + test_almost( + dvector![0.1, 0.3, 0.6], + 10, + 0.105815808, + 1e-15, + pmf(dvector![1, 3, 6]), + ); + test_almost( + dvector![0.15, 0.35, 0.3, 0.2], + 10, + 0.000145152, + 1e-15, + pmf(dvector![1, 1, 1, 7]), + ); + } -// #[test] -// #[should_panic] -// fn test_pmf_x_wrong_length() { -// let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg); -// let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); -// n.pmf(&[1]); -// } + // #[test] + // #[should_panic] + // fn test_pmf_x_wrong_length() { + // let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg); + // let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); + // n.pmf(&[1]); + // } -// #[test] -// #[should_panic] -// fn test_pmf_x_wrong_sum() { -// let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg); -// let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); -// n.pmf(&[1, 3]); -// } + // #[test] + // #[should_panic] + // fn test_pmf_x_wrong_sum() { + // let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg); + // let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); + // n.pmf(&[1, 3]); + // } -// #[test] -// fn test_ln_pmf() { -// let large_p = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; -// let n = Multinomial::new(large_p, 45).unwrap(); -// let x = &[1, 2, 3, 4, 5, 6, 7, 8, 9]; -// assert_almost_eq!(n.pmf(x).ln(), n.ln_pmf(x), 1e-13); -// let n2 = Multinomial::new(large_p, 18).unwrap(); -// let x2 = &[1, 1, 1, 2, 2, 2, 3, 3, 3]; -// assert_almost_eq!(n2.pmf(x2).ln(), n2.ln_pmf(x2), 1e-13); -// let n3 = Multinomial::new(large_p, 51).unwrap(); -// let x3 = &[5, 6, 7, 8, 7, 6, 5, 4, 3]; -// assert_almost_eq!(n3.pmf(x3).ln(), n3.ln_pmf(x3), 1e-13); -// } + // #[test] + // fn test_ln_pmf() { + // let large_p = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; + // let n = Multinomial::new(large_p, 45).unwrap(); + // let x = &[1, 2, 3, 4, 5, 6, 7, 8, 9]; + // assert_almost_eq!(n.pmf(x).ln(), n.ln_pmf(x), 1e-13); + // let n2 = Multinomial::new(large_p, 18).unwrap(); + // let x2 = &[1, 1, 1, 2, 2, 2, 3, 3, 3]; + // assert_almost_eq!(n2.pmf(x2).ln(), n2.ln_pmf(x2), 1e-13); + // let n3 = Multinomial::new(large_p, 51).unwrap(); + // let x3 = &[5, 6, 7, 8, 7, 6, 5, 4, 3]; + // assert_almost_eq!(n3.pmf(x3).ln(), n3.ln_pmf(x3), 1e-13); + // } -// #[test] -// #[should_panic] -// fn test_ln_pmf_x_wrong_length() { -// let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); -// n.ln_pmf(&[1]); -// } + // #[test] + // #[should_panic] + // fn test_ln_pmf_x_wrong_length() { + // let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); + // n.ln_pmf(&[1]); + // } -// #[test] -// #[should_panic] -// fn test_ln_pmf_x_wrong_sum() { -// let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); -// n.ln_pmf(&[1, 3]); -// } -// } + // #[test] + // #[should_panic] + // fn test_ln_pmf_x_wrong_sum() { + // let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); + // n.ln_pmf(&[1, 3]); + // } +} From 6ad3b1524d4c9146a1c3692c0443e7691f845fb0 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Fri, 2 Aug 2024 17:09:06 -0500 Subject: [PATCH 5/6] test(docs): update Multinomial doc tests --- src/distribution/multinomial.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index fbd8d594..dc402050 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -16,10 +16,10 @@ use rand::Rng; /// ``` /// use statrs::distribution::Multinomial; /// use statrs::statistics::MeanN; -/// use nalgebra::DVector; +/// use nalgebra::vector; /// -/// let n = Multinomial::new(&[0.3, 0.7], 5).unwrap(); -/// assert_eq!(n.mean().unwrap(), DVector::from_vec(vec![1.5, 3.5])); +/// let n = Multinomial::new_from_nalgebra(vector![0.3, 0.7], 5).unwrap(); +/// assert_eq!(n.mean().unwrap(), (vector![1.5, 3.5])); /// ``` #[derive(Debug, Clone, PartialEq)] pub struct Multinomial @@ -84,9 +84,10 @@ where /// /// ``` /// use statrs::distribution::Multinomial; + /// use nalgebra::dvector; /// - /// let n = Multinomial::new(&[0.0, 1.0, 2.0], 3).unwrap(); - /// assert_eq!(n.p(), [0.0, 1.0, 2.0]); + /// let n = Multinomial::new(vec![0.0, 1.0, 2.0], 3).unwrap(); + /// assert_eq!(*n.p(), dvector![0.0, 1.0/3.0, 2.0/3.0]); /// ``` pub fn p(&self) -> &OVector { &self.p @@ -100,7 +101,7 @@ where /// ``` /// use statrs::distribution::Multinomial; /// - /// let n = Multinomial::new(&[0.0, 1.0, 2.0], 3).unwrap(); + /// let n = Multinomial::new(vec![0.0, 1.0, 2.0], 3).unwrap(); /// assert_eq!(n.n(), 3); /// ``` pub fn n(&self) -> u64 { From 2ae5e79e81e512186188b10ef95225c311f4e623 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sat, 3 Aug 2024 09:44:12 -0500 Subject: [PATCH 6/6] chore: allow clippy lint error api will be more specific later --- src/distribution/internal.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index b45b4bf7..0837dc40 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -27,6 +27,7 @@ where } let mut sum = 0.0; for &x in arr.iter() { + #[allow(clippy::if_same_then_else)] if x.is_nan() { return Err(StatsError::BadParams); } else if x.is_infinite() {