From 80c1fd45057eb9decfb780177bae24ab885113d6 Mon Sep 17 00:00:00 2001 From: Iago-lito Date: Mon, 7 Aug 2023 13:54:18 +0200 Subject: [PATCH 001/185] Fix invalid `k` letter in Poisson formula docs. For some reason, these formulae used a mixture of `k` and `x` for the same parameter. I haven't checked whether this problem is spread more widely within the docs though. --- src/distribution/poisson.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index 0c1b2379..9879b6a8 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -229,7 +229,7 @@ impl Discrete for Poisson { /// # Formula /// /// ```ignore - /// (λ^k * e^(-λ)) / x! + /// (λ^x * e^(-λ)) / x! /// ``` /// /// where `λ` is the rate @@ -244,7 +244,7 @@ impl Discrete for Poisson { /// # Formula /// /// ```ignore - /// ln((λ^k * e^(-λ)) / x!) + /// ln((λ^x * e^(-λ)) / x!) /// ``` /// /// where `λ` is the rate From 73865d89b56986f7c77c47cf0ad127218e2ad232 Mon Sep 17 00:00:00 2001 From: Iago-lito Date: Sun, 10 Mar 2024 22:53:09 +0100 Subject: [PATCH 002/185] Fix invalid k letter in hypergeometric formula docs. Done according to [this comment](https://github.com/statrs-dev/statrs/pull/192#issuecomment-1986526127) I.. think? --- src/distribution/hypergeometric.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/distribution/hypergeometric.rs b/src/distribution/hypergeometric.rs index 0ac8e750..2aefe1da 100644 --- a/src/distribution/hypergeometric.rs +++ b/src/distribution/hypergeometric.rs @@ -140,8 +140,8 @@ impl DiscreteCDF for Hypergeometric { /// # Formula /// /// ```ignore - /// 1 - ((n choose k+1) * (N-n choose K-k-1)) / (N choose K) * 3_F_2(1, - /// k+1-K, k+1-n; k+2, N+k+2-K-n; 1) + /// 1 - ((n choose x+1) * (N-n choose K-x-1)) / (N choose K) * 3_F_2(1, + /// x+1-K, x+1-n; k+2, N+x+2-K-n; 1) /// ``` /// /// where `N` is population, `K` is successes, `n` is draws, @@ -150,7 +150,7 @@ impl DiscreteCDF for Hypergeometric { /// org/wiki/Generalized_hypergeometric_function) /// /// Calculated as a discrete integral over the probability mass - /// function evaluated from 0..k+1 + /// function evaluated from 0..x+1 fn cdf(&self, x: u64) -> f64 { if x < self.min() { 0.0 @@ -174,8 +174,8 @@ impl DiscreteCDF for Hypergeometric { /// # Formula /// /// ```ignore - /// 1 - ((n choose k+1) * (N-n choose K-k-1)) / (N choose K) * 3_F_2(1, - /// k+1-K, k+1-n; k+2, N+k+2-K-n; 1) + /// 1 - ((n choose x+1) * (N-n choose K-x-1)) / (N choose K) * 3_F_2(1, + /// x+1-K, x+1-n; x+2, N+x+2-K-n; 1) /// ``` /// /// where `N` is population, `K` is successes, `n` is draws, @@ -184,7 +184,7 @@ impl DiscreteCDF for Hypergeometric { /// org/wiki/Generalized_hypergeometric_function) /// /// Calculated as a discrete integral over the probability mass - /// function evaluated from (k+1)..max + /// function evaluated from (x+1)..max fn sf(&self, x: u64) -> f64 { if x < self.min() { 1.0 From f5e238f76a9cad2e1c71736a0fbc9f0049caa68e Mon Sep 17 00:00:00 2001 From: Raimundo Saona <37874270+saona-raimundo@users.noreply.github.com> Date: Mon, 29 Jan 2024 22:00:51 -0300 Subject: [PATCH 003/185] feature: Implement Continuous> for MultivariateNormal --- src/distribution/multivariate_normal.rs | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index ab168020..75902a4b 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -202,6 +202,28 @@ impl<'a> Continuous<&'a DVector, f64> for MultivariateNormal { } } +impl Continuous, f64> for MultivariateNormal { + /// Calculates the probability density function for the multivariate + /// normal distribution at `x` + /// + /// # Formula + /// + /// ```ignore + /// (2 * π) ^ (-k / 2) * det(Σ) ^ (1 / 2) * e ^ ( -(1 / 2) * transpose(x - μ) * inv(Σ) * (x - μ)) + /// ``` + /// + /// where `μ` is the mean, `inv(Σ)` is the precision matrix, `det(Σ)` is the determinant + /// of the covariance matrix, and `k` is the dimension of the distribution + fn pdf(&self, x: Vec) -> f64 { + self.pdf(&DVector::from(x)) + } + /// Calculates the log probability density function for the multivariate + /// normal distribution at `x`. Equivalent to pdf(x).ln(). + fn ln_pdf(&self, x: Vec) -> f64 { + self.pdf(&DVector::from(x)) + } +} + #[rustfmt::skip] #[cfg(all(test, feature = "nightly"))] mod tests { From f25de9a348a5785466ceb4955926147cd4f3e5a9 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 17 Apr 2024 00:12:24 +0200 Subject: [PATCH 004/185] Use `include` to reduce amount of files in crates.io package This also removes the non-free NIST dataset from the package --- Cargo.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index d8e6eadf..32ae5552 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,8 @@ homepage = "https://github.com/boxtown/statrs" repository = "https://github.com/boxtown/statrs" edition = "2018" +include = ["CHANGELOG.md", "LICENSE.md", "src/"] + [lib] name = "statrs" path = "src/lib.rs" From 018bb70e68a56cb1d14336a7cc177534c36ca1c7 Mon Sep 17 00:00:00 2001 From: Henry Jacobson Date: Sat, 31 Dec 2022 11:58:38 +0100 Subject: [PATCH 005/185] feat: default implementation of surival function with generics --- src/distribution/mod.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index 23ae7ec5..a9718ee6 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -101,7 +101,9 @@ pub trait ContinuousCDF: Min + Max { /// let n = Uniform::new(0.0, 1.0).unwrap(); /// assert_eq!(0.5, n.sf(0.5)); /// ``` - fn sf(&self, x: K) -> T; + fn sf(&self, x: K) -> T { + T::one() - self.cdf(x) + } /// Due to issues with rounding and floating-point accuracy the default /// implementation may be ill-behaved. @@ -167,7 +169,9 @@ pub trait DiscreteCDF: Min + Max { /// let n = DiscreteUniform::new(1, 10).unwrap(); /// assert_eq!(0.4, n.sf(6)); /// ``` - fn sf(&self, x: K) -> T; + fn sf(&self, x: K) -> T { + T::one() - self.cdf(x) + } /// Due to issues with rounding and floating-point accuracy the default implementation may be ill-behaved /// Specialized inverse cdfs should be used whenever possible. From f97a4c8f29c4aa9a6a941a9baad1d5065d4ac76d Mon Sep 17 00:00:00 2001 From: Kalev Lember Date: Tue, 28 Mar 2023 11:07:57 +0200 Subject: [PATCH 006/185] Update nalgebra to 0.32 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 32ae5552..ac7f09ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ nightly = [] [dependencies] rand = "0.8" -nalgebra = { version = "0.29", features = ["rand"] } +nalgebra = { version = "0.32", features = ["rand"] } approx = "0.5.0" num-traits = "0.2.14" lazy_static = "1.4.0" From 7dc37e459440b76efc30f8a1576e5ac9aec7eb47 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Thu, 11 Apr 2024 15:50:47 -0500 Subject: [PATCH 007/185] fix: correct repository link --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ac7f09ca..c7ad536d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,8 +8,8 @@ license = "MIT" keywords = ["probability", "statistics", "stats", "distribution", "math"] categories = ["science"] documentation = "https://docs.rs/statrs/0.15.0/statrs/" -homepage = "https://github.com/boxtown/statrs" -repository = "https://github.com/boxtown/statrs" +homepage = "https://github.com/statrs-dev/statrs" +repository = "https://github.com/statrs-dev/statrs" edition = "2018" include = ["CHANGELOG.md", "LICENSE.md", "src/"] From 90b4109ce4401b61eb798352495d43c10fc60077 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sun, 24 Mar 2024 16:37:34 -0500 Subject: [PATCH 008/185] doc: fix 0.16 changelog to include version dependency change rely on docs.rs to have correct versino in url for docs --- CHANGELOG.md | 4 ++++ Cargo.toml | 1 - 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a65a3c75..9b8b7ffc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,13 @@ +Unreleased + + v0.16.0 - Adds an `sf` method to the `ContinuousCDF` and `DiscreteCDF` traits - Calculates the survival function (CDF complement) for the distribution. - Survival function implemented for all distributions implementing `ContinuousCDF` and `DiscreteCDF` - See [PR description](https://github.com/statrs-dev/statrs/pull/172) for in-depth changes +- update `nalgebra` to `0.29` v0.15.0 diff --git a/Cargo.toml b/Cargo.toml index c7ad536d..34549647 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,6 @@ description = "Statistical computing library for Rust" license = "MIT" keywords = ["probability", "statistics", "stats", "distribution", "math"] categories = ["science"] -documentation = "https://docs.rs/statrs/0.15.0/statrs/" homepage = "https://github.com/statrs-dev/statrs" repository = "https://github.com/statrs-dev/statrs" edition = "2018" From b8a0ec248d1ed25b268411897464befe46f5e385 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Thu, 14 Mar 2024 18:54:54 -0500 Subject: [PATCH 009/185] doc: docstrings with math now use text instead of ignore --- src/distribution/bernoulli.rs | 24 ++++++++++----------- src/distribution/beta.rs | 22 +++++++++---------- src/distribution/binomial.rs | 24 ++++++++++----------- src/distribution/categorical.rs | 20 +++++++++--------- src/distribution/cauchy.rs | 18 ++++++++-------- src/distribution/chi.rs | 22 +++++++++---------- src/distribution/chi_squared.rs | 24 ++++++++++----------- src/distribution/dirac.rs | 14 ++++++------- src/distribution/dirichlet.rs | 16 +++++++------- src/distribution/discrete_uniform.rs | 18 ++++++++-------- src/distribution/erlang.rs | 22 +++++++++---------- src/distribution/exponential.rs | 24 ++++++++++----------- src/distribution/fisher_snedecor.rs | 20 +++++++++--------- src/distribution/gamma.rs | 22 +++++++++---------- src/distribution/geometric.rs | 24 ++++++++++----------- src/distribution/hypergeometric.rs | 22 +++++++++---------- src/distribution/inverse_gamma.rs | 22 +++++++++---------- src/distribution/laplace.rs | 28 ++++++++++++------------- src/distribution/log_normal.rs | 28 ++++++++++++------------- src/distribution/multinomial.rs | 10 ++++----- src/distribution/multivariate_normal.rs | 10 +++++---- src/distribution/negative_binomial.rs | 24 ++++++++++----------- src/distribution/normal.rs | 28 ++++++++++++------------- src/distribution/pareto.rs | 24 ++++++++++----------- src/distribution/poisson.rs | 24 ++++++++++----------- src/distribution/students_t.rs | 24 ++++++++++----------- src/distribution/triangular.rs | 20 +++++++++--------- src/distribution/uniform.rs | 20 +++++++++--------- src/distribution/weibull.rs | 24 ++++++++++----------- 29 files changed, 312 insertions(+), 310 deletions(-) diff --git a/src/distribution/bernoulli.rs b/src/distribution/bernoulli.rs index e31f9c5f..46648a7a 100644 --- a/src/distribution/bernoulli.rs +++ b/src/distribution/bernoulli.rs @@ -92,7 +92,7 @@ impl DiscreteCDF for Bernoulli { /// /// # Formula /// - /// ```ignore + /// ```text /// if x < 0 { 0 } /// else if x >= 1 { 1 } /// else { 1 - p } @@ -106,7 +106,7 @@ impl DiscreteCDF for Bernoulli { /// /// # Formula /// - /// ```ignore + /// ```text /// if x < 0 { 1 } /// else if x >= 1 { 0 } /// else { p } @@ -123,7 +123,7 @@ impl Min for Bernoulli { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> u64 { @@ -138,7 +138,7 @@ impl Max for Bernoulli { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 /// ``` fn max(&self) -> u64 { @@ -152,7 +152,7 @@ impl Distribution for Bernoulli { /// /// # Formula /// - /// ```ignore + /// ```text /// p /// ``` fn mean(&self) -> Option { @@ -163,7 +163,7 @@ impl Distribution for Bernoulli { /// /// # Formula /// - /// ```ignore + /// ```text /// p * (1 - p) /// ``` fn variance(&self) -> Option { @@ -174,7 +174,7 @@ impl Distribution for Bernoulli { /// /// # Formula /// - /// ```ignore + /// ```text /// q = (1 - p) /// -q * ln(q) - p * ln(p) /// ``` @@ -186,7 +186,7 @@ impl Distribution for Bernoulli { /// /// # Formula /// - /// ```ignore + /// ```text /// q = (1 - p) /// (1 - 2p) / sqrt(p * q) /// ``` @@ -201,7 +201,7 @@ impl Median for Bernoulli { /// /// # Formula /// - /// ```ignore + /// ```text /// if p < 0.5 { 0 } /// else if p > 0.5 { 1 } /// else { 0.5 } @@ -216,7 +216,7 @@ impl Mode> for Bernoulli { /// /// # Formula /// - /// ```ignore + /// ```text /// if p < 0.5 { 0 } /// else { 1 } /// ``` @@ -231,7 +231,7 @@ impl Discrete for Bernoulli { /// /// # Formula /// - /// ```ignore + /// ```text /// if x == 0 { 1 - p } /// else { p } /// ``` @@ -244,7 +244,7 @@ impl Discrete for Bernoulli { /// /// # Formula /// - /// ```ignore + /// ```text /// else if x == 0 { ln(1 - p) } /// else { ln(p) } /// ``` diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index 6dd5adc3..ba071dfa 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -103,7 +103,7 @@ impl ContinuousCDF for Beta { /// /// # Formula /// - /// ```ignore + /// ```text /// I_x(α, β) /// ``` /// @@ -134,7 +134,7 @@ impl ContinuousCDF for Beta { /// /// # Formula /// - /// ```ignore + /// ```text /// I_(1-x)(β, α) /// ``` /// @@ -168,7 +168,7 @@ impl Min for Beta { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> f64 { @@ -183,7 +183,7 @@ impl Max for Beta { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 /// ``` fn max(&self) -> f64 { @@ -196,7 +196,7 @@ impl Distribution for Beta { /// /// # Formula /// - /// ```ignore + /// ```text /// α / (α + β) /// ``` /// @@ -215,7 +215,7 @@ impl Distribution for Beta { /// /// # Formula /// - /// ```ignore + /// ```text /// (α * β) / ((α + β)^2 * (α + β + 1)) /// ``` /// @@ -235,7 +235,7 @@ impl Distribution for Beta { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(B(α, β)) - (α - 1)ψ(α) - (β - 1)ψ(β) + (α + β - 2)ψ(α + β) /// ``` /// @@ -256,7 +256,7 @@ impl Distribution for Beta { /// /// # Formula /// - /// ```ignore + /// ```text /// 2(β - α) * sqrt(α + β + 1) / ((α + β + 2) * sqrt(αβ)) /// ``` /// @@ -290,7 +290,7 @@ impl Mode> for Beta { /// /// # Formula /// - /// ```ignore + /// ```text /// (α - 1) / (α + β - 2) /// ``` /// @@ -314,7 +314,7 @@ impl Continuous for Beta { /// /// # Formula /// - /// ```ignore + /// ```text /// let B(α, β) = Γ(α)Γ(β)/Γ(α + β) /// /// x^(α - 1) * (1 - x)^(β - 1) / B(α, β) @@ -352,7 +352,7 @@ impl Continuous for Beta { /// /// # Formula /// - /// ```ignore + /// ```text /// let B(α, β) = Γ(α)Γ(β)/Γ(α + β) /// /// ln(x^(α - 1) * (1 - x)^(β - 1) / B(α, β)) diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index 85ddecaa..7c5c9622 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -106,7 +106,7 @@ impl DiscreteCDF for Binomial { /// /// # Formula /// - /// ```ignore + /// ```text /// I_(1 - p)(n - x, 1 + x) /// ``` /// @@ -125,7 +125,7 @@ impl DiscreteCDF for Binomial { /// /// # Formula /// - /// ```ignore + /// ```text /// I_(p)(x + 1, n - x) /// ``` /// @@ -147,7 +147,7 @@ impl Min for Binomial { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> u64 { @@ -162,7 +162,7 @@ impl Max for Binomial { /// /// # Formula /// - /// ```ignore + /// ```text /// n /// ``` fn max(&self) -> u64 { @@ -175,7 +175,7 @@ impl Distribution for Binomial { /// /// # Formula /// - /// ```ignore + /// ```text /// p * n /// ``` fn mean(&self) -> Option { @@ -185,7 +185,7 @@ impl Distribution for Binomial { /// /// # Formula /// - /// ```ignore + /// ```text /// n * p * (1 - p) /// ``` fn variance(&self) -> Option { @@ -195,7 +195,7 @@ impl Distribution for Binomial { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / 2) * ln (2 * π * e * n * p * (1 - p)) /// ``` fn entropy(&self) -> Option { @@ -213,7 +213,7 @@ impl Distribution for Binomial { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 - 2p) / sqrt(n * p * (1 - p))) /// ``` fn skewness(&self) -> Option { @@ -226,7 +226,7 @@ impl Median for Binomial { /// /// # Formula /// - /// ```ignore + /// ```text /// floor(n * p) /// ``` fn median(&self) -> f64 { @@ -239,7 +239,7 @@ impl Mode> for Binomial { /// /// # Formula /// - /// ```ignore + /// ```text /// floor((n + 1) * p) /// ``` fn mode(&self) -> Option { @@ -260,7 +260,7 @@ impl Discrete for Binomial { /// /// # Formula /// - /// ```ignore + /// ```text /// (n choose k) * p^k * (1 - p)^(n - k) /// ``` fn pmf(&self, x: u64) -> f64 { @@ -291,7 +291,7 @@ impl Discrete for Binomial { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((n choose k) * p^k * (1 - p)^(n - k)) /// ``` fn ln_pmf(&self, x: u64) -> f64 { diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index f489d653..ba3c39de 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -89,7 +89,7 @@ impl DiscreteCDF for Categorical { /// /// # Formula /// - /// ```ignore + /// ```text /// sum(p_j) from 0..x /// ``` /// @@ -107,7 +107,7 @@ impl DiscreteCDF for Categorical { /// /// # Formula /// - /// ```ignore + /// ```text /// [ sum(p_j) from x..end ] /// ``` fn sf(&self, x: u64) -> f64 { @@ -128,7 +128,7 @@ impl DiscreteCDF for Categorical { /// /// # Formula /// - /// ```ignore + /// ```text /// i /// ``` /// @@ -151,7 +151,7 @@ impl Min for Categorical { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> u64 { @@ -166,7 +166,7 @@ impl Max for Categorical { /// /// # Formula /// - /// ```ignore + /// ```text /// n /// ``` fn max(&self) -> u64 { @@ -179,7 +179,7 @@ impl Distribution for Categorical { /// /// # Formula /// - /// ```ignore + /// ```text /// Σ(j * p_j) /// ``` /// @@ -198,7 +198,7 @@ impl Distribution for Categorical { /// /// # Formula /// - /// ```ignore + /// ```text /// Σ(p_j * (j - μ)^2) /// ``` /// @@ -221,7 +221,7 @@ impl Distribution for Categorical { /// /// # Formula /// - /// ```ignore + /// ```text /// -Σ(p_j * ln(p_j)) /// ``` /// @@ -243,7 +243,7 @@ impl Median for Categorical { /// /// # Formula /// - /// ```ignore + /// ```text /// CDF^-1(0.5) /// ``` fn median(&self) -> f64 { @@ -257,7 +257,7 @@ impl Discrete for Categorical { /// /// # Formula /// - /// ```ignore + /// ```text /// p_x /// ``` fn pmf(&self, x: u64) -> f64 { diff --git a/src/distribution/cauchy.rs b/src/distribution/cauchy.rs index e42919ea..dcd81af5 100644 --- a/src/distribution/cauchy.rs +++ b/src/distribution/cauchy.rs @@ -91,7 +91,7 @@ impl ContinuousCDF for Cauchy { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / π) * arctan((x - x_0) / γ) + 0.5 /// ``` /// @@ -105,7 +105,7 @@ impl ContinuousCDF for Cauchy { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / π) * arctan(-(x - x_0) / γ) + 0.5 /// ``` /// @@ -123,7 +123,7 @@ impl Min for Cauchy { /// /// # Formula /// - /// ```ignore + /// ```text /// NEG_INF /// ``` fn min(&self) -> f64 { @@ -137,7 +137,7 @@ impl Max for Cauchy { /// /// # Formula /// - /// ```ignore + /// ```text /// INF /// ``` fn max(&self) -> f64 { @@ -150,7 +150,7 @@ impl Distribution for Cauchy { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(γ) + ln(4π) /// ``` /// @@ -165,7 +165,7 @@ impl Median for Cauchy { /// /// # Formula /// - /// ```ignore + /// ```text /// x_0 /// ``` /// @@ -180,7 +180,7 @@ impl Mode> for Cauchy { /// /// # Formula /// - /// ```ignore + /// ```text /// x_0 /// ``` /// @@ -196,7 +196,7 @@ impl Continuous for Cauchy { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 / (πγ * (1 + ((x - x_0) / γ)^2)) /// ``` /// @@ -212,7 +212,7 @@ impl Continuous for Cauchy { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(1 / (πγ * (1 + ((x - x_0) / γ)^2))) /// ``` /// diff --git a/src/distribution/chi.rs b/src/distribution/chi.rs index 7fcccf5e..0a65a8e6 100644 --- a/src/distribution/chi.rs +++ b/src/distribution/chi.rs @@ -84,7 +84,7 @@ impl ContinuousCDF for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// P(k / 2, x^2 / 2) /// ``` /// @@ -105,7 +105,7 @@ impl ContinuousCDF for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// P(k / 2, x^2 / 2) /// ``` /// @@ -128,7 +128,7 @@ impl Min for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> f64 { @@ -142,7 +142,7 @@ impl Max for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// INF /// ``` fn max(&self) -> f64 { @@ -159,7 +159,7 @@ impl Distribution for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// sqrt2 * Γ((k + 1) / 2) / Γ(k / 2) /// ``` /// @@ -193,7 +193,7 @@ impl Distribution for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// k - μ^2 /// ``` /// @@ -211,7 +211,7 @@ impl Distribution for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(Γ(k / 2)) + 0.5 * (k - ln2 - (k - 1) * ψ(k / 2)) /// ``` /// @@ -236,7 +236,7 @@ impl Distribution for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// (μ / σ^3) * (1 - 2σ^2) /// ``` /// where `μ` is the mean and `σ` the standard deviation @@ -257,7 +257,7 @@ impl Mode> for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// sqrt(k - 1) /// ``` /// @@ -276,7 +276,7 @@ impl Continuous for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// (2^(1 - (k / 2)) * x^(k - 1) * e^(-x^2 / 2)) / Γ(k / 2) /// ``` /// @@ -299,7 +299,7 @@ impl Continuous for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((2^(1 - (k / 2)) * x^(k - 1) * e^(-x^2 / 2)) / Γ(k / 2)) /// ``` fn ln_pdf(&self, x: f64) -> f64 { diff --git a/src/distribution/chi_squared.rs b/src/distribution/chi_squared.rs index 05ad63ba..5551b55f 100644 --- a/src/distribution/chi_squared.rs +++ b/src/distribution/chi_squared.rs @@ -108,7 +108,7 @@ impl ContinuousCDF for ChiSquared { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / Γ(k / 2)) * γ(k / 2, x / 2) /// ``` /// @@ -123,7 +123,7 @@ impl ContinuousCDF for ChiSquared { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / Γ(k / 2)) * γ(k / 2, x / 2) /// ``` /// @@ -141,7 +141,7 @@ impl Min for ChiSquared { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> f64 { @@ -156,7 +156,7 @@ impl Max for ChiSquared { /// /// # Formula /// - /// ```ignore + /// ```text /// INF /// ``` fn max(&self) -> f64 { @@ -169,7 +169,7 @@ impl Distribution for ChiSquared { /// /// # Formula /// - /// ```ignore + /// ```text /// k /// ``` /// @@ -181,7 +181,7 @@ impl Distribution for ChiSquared { /// /// # Formula /// - /// ```ignore + /// ```text /// 2k /// ``` /// @@ -193,7 +193,7 @@ impl Distribution for ChiSquared { /// /// # Formula /// - /// ```ignore + /// ```text /// (k / 2) + ln(2 * Γ(k / 2)) + (1 - (k / 2)) * ψ(k / 2) /// ``` /// @@ -206,7 +206,7 @@ impl Distribution for ChiSquared { /// /// # Formula /// - /// ```ignore + /// ```text /// sqrt(8 / k) /// ``` /// @@ -221,7 +221,7 @@ impl Median for ChiSquared { /// /// # Formula /// - /// ```ignore + /// ```text /// k * (1 - (2 / 9k))^3 /// ``` fn median(&self) -> f64 { @@ -241,7 +241,7 @@ impl Mode> for ChiSquared { /// /// # Formula /// - /// ```ignore + /// ```text /// k - 2 /// ``` /// @@ -257,7 +257,7 @@ impl Continuous for ChiSquared { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 / (2^(k / 2) * Γ(k / 2)) * x^((k / 2) - 1) * e^(-x / 2) /// ``` /// @@ -271,7 +271,7 @@ impl Continuous for ChiSquared { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(1 / (2^(k / 2) * Γ(k / 2)) * x^((k / 2) - 1) * e^(-x / 2)) /// ``` fn ln_pdf(&self, x: f64) -> f64 { diff --git a/src/distribution/dirac.rs b/src/distribution/dirac.rs index b58b676b..daa081a2 100644 --- a/src/distribution/dirac.rs +++ b/src/distribution/dirac.rs @@ -85,7 +85,7 @@ impl Min for Dirac { /// /// # Formula /// - /// ```ignore + /// ```text /// v /// ``` fn min(&self) -> f64 { @@ -99,7 +99,7 @@ impl Max for Dirac { /// /// # Formula /// - /// ```ignore + /// ```text /// v /// ``` fn max(&self) -> f64 { @@ -121,7 +121,7 @@ impl Distribution for Dirac { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` /// @@ -133,7 +133,7 @@ impl Distribution for Dirac { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` /// @@ -145,7 +145,7 @@ impl Distribution for Dirac { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn skewness(&self) -> Option { @@ -158,7 +158,7 @@ impl Median for Dirac { /// /// # Formula /// - /// ```ignore + /// ```text /// v /// ``` /// @@ -173,7 +173,7 @@ impl Mode> for Dirac { /// /// # Formula /// - /// ```ignore + /// ```text /// v /// ``` /// diff --git a/src/distribution/dirichlet.rs b/src/distribution/dirichlet.rs index 104a5981..a08b3175 100644 --- a/src/distribution/dirichlet.rs +++ b/src/distribution/dirichlet.rs @@ -111,13 +111,13 @@ impl Dirichlet { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(B(α)) - (K - α_0)ψ(α_0) - Σ((α_i - 1)ψ(α_i)) /// ``` /// /// where /// - /// ```ignore + /// ```text /// B(α) = Π(Γ(α_i)) / Γ(Σ(α_i)) /// ``` /// @@ -158,7 +158,7 @@ impl MeanN> for Dirichlet { /// /// # Formula /// - /// ```ignore + /// ```text /// α_i / α_0 /// ``` /// @@ -175,7 +175,7 @@ impl VarianceN> for Dirichlet { /// /// # Formula /// - /// ```ignore + /// ```text /// (α_i * (α_0 - α_i)) / (α_0^2 * (α_0 + 1)) /// ``` /// @@ -215,13 +215,13 @@ impl<'a> Continuous<&'a DVector, f64> for Dirichlet { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / B(α)) * Π(x_i^(α_i - 1)) /// ``` /// /// where /// - /// ```ignore + /// ```text /// B(α) = Π(Γ(α_i)) / Γ(Σ(α_i)) /// ``` /// @@ -249,13 +249,13 @@ impl<'a> Continuous<&'a DVector, f64> for Dirichlet { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((1 / B(α)) * Π(x_i^(α_i - 1))) /// ``` /// /// where /// - /// ```ignore + /// ```text /// B(α) = Π(Γ(α_i)) / Γ(Σ(α_i)) /// ``` /// diff --git a/src/distribution/discrete_uniform.rs b/src/distribution/discrete_uniform.rs index c151318f..926b1cf3 100644 --- a/src/distribution/discrete_uniform.rs +++ b/src/distribution/discrete_uniform.rs @@ -63,7 +63,7 @@ impl DiscreteCDF for DiscreteUniform { /// /// # Formula /// - /// ```ignore + /// ```text /// (floor(x) - min + 1) / (max - min + 1) /// ``` fn cdf(&self, x: i64) -> f64 { @@ -131,7 +131,7 @@ impl Distribution for DiscreteUniform { /// /// # Formula /// - /// ```ignore + /// ```text /// (min + max) / 2 /// ``` fn mean(&self) -> Option { @@ -141,7 +141,7 @@ impl Distribution for DiscreteUniform { /// /// # Formula /// - /// ```ignore + /// ```text /// ((max - min + 1)^2 - 1) / 12 /// ``` fn variance(&self) -> Option { @@ -152,7 +152,7 @@ impl Distribution for DiscreteUniform { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(max - min + 1) /// ``` fn entropy(&self) -> Option { @@ -163,7 +163,7 @@ impl Distribution for DiscreteUniform { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn skewness(&self) -> Option { @@ -176,7 +176,7 @@ impl Median for DiscreteUniform { /// /// # Formula /// - /// ```ignore + /// ```text /// (max + min) / 2 /// ``` fn median(&self) -> f64 { @@ -194,7 +194,7 @@ impl Mode> for DiscreteUniform { /// /// # Formula /// - /// ```ignore + /// ```text /// N/A // (max + min) / 2 for the middle element /// ``` fn mode(&self) -> Option { @@ -212,7 +212,7 @@ impl Discrete for DiscreteUniform { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 / (max - min + 1) /// ``` fn pmf(&self, x: i64) -> f64 { @@ -232,7 +232,7 @@ impl Discrete for DiscreteUniform { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(1 / (max - min + 1)) /// ``` fn ln_pmf(&self, x: i64) -> f64 { diff --git a/src/distribution/erlang.rs b/src/distribution/erlang.rs index 619ba698..e07dff6b 100644 --- a/src/distribution/erlang.rs +++ b/src/distribution/erlang.rs @@ -91,7 +91,7 @@ impl ContinuousCDF for Erlang { /// /// # Formula /// - /// ```ignore + /// ```text /// γ(k, λx) (k - 1)! /// ``` /// @@ -107,7 +107,7 @@ impl ContinuousCDF for Erlang { /// /// # Formula /// - /// ```ignore + /// ```text /// γ(k, λx) (k - 1)! /// ``` /// @@ -125,7 +125,7 @@ impl Min for Erlang { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> f64 { @@ -140,7 +140,7 @@ impl Max for Erlang { /// /// # Formula /// - /// ```ignore + /// ```text /// INF /// ``` fn max(&self) -> f64 { @@ -158,7 +158,7 @@ impl Distribution for Erlang { /// /// # Formula /// - /// ```ignore + /// ```text /// k / λ /// ``` /// @@ -170,7 +170,7 @@ impl Distribution for Erlang { /// /// # Formula /// - /// ```ignore + /// ```text /// k / λ^2 /// ``` /// @@ -182,7 +182,7 @@ impl Distribution for Erlang { /// /// # Formula /// - /// ```ignore + /// ```text /// k - ln(λ) + ln(Γ(k)) + (1 - k) * ψ(k) /// ``` /// @@ -195,7 +195,7 @@ impl Distribution for Erlang { /// /// # Formula /// - /// ```ignore + /// ```text /// 2 / sqrt(k) /// ``` /// @@ -215,7 +215,7 @@ impl Mode> for Erlang { /// /// # Formula /// - /// ```ignore + /// ```text /// (k - 1) / λ /// ``` /// @@ -236,7 +236,7 @@ impl Continuous for Erlang { /// /// # Formula /// - /// ```ignore + /// ```text /// (λ^k / Γ(k)) * x^(k - 1) * e^(-λ * x) /// ``` /// @@ -256,7 +256,7 @@ impl Continuous for Erlang { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((λ^k / Γ(k)) * x^(k - 1) * e ^(-λ * x)) /// ``` /// diff --git a/src/distribution/exponential.rs b/src/distribution/exponential.rs index 890592d8..f374b9b2 100644 --- a/src/distribution/exponential.rs +++ b/src/distribution/exponential.rs @@ -79,7 +79,7 @@ impl ContinuousCDF for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 - e^(-λ * x) /// ``` /// @@ -97,7 +97,7 @@ impl ContinuousCDF for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// e^(-λ * x) /// ``` /// @@ -117,7 +117,7 @@ impl Min for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> f64 { @@ -131,7 +131,7 @@ impl Max for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// INF /// ``` fn max(&self) -> f64 { @@ -144,7 +144,7 @@ impl Distribution for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 / λ /// ``` /// @@ -156,7 +156,7 @@ impl Distribution for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 / λ^2 /// ``` /// @@ -168,7 +168,7 @@ impl Distribution for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 - ln(λ) /// ``` /// @@ -180,7 +180,7 @@ impl Distribution for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// 2 /// ``` fn skewness(&self) -> Option { @@ -193,7 +193,7 @@ impl Median for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / λ) * ln2 /// ``` /// @@ -208,7 +208,7 @@ impl Mode> for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn mode(&self) -> Option { @@ -222,7 +222,7 @@ impl Continuous for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// λ * e^(-λ * x) /// ``` /// @@ -240,7 +240,7 @@ impl Continuous for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(λ * e^(-λ * x)) /// ``` /// diff --git a/src/distribution/fisher_snedecor.rs b/src/distribution/fisher_snedecor.rs index 4b8782be..ee70c30e 100644 --- a/src/distribution/fisher_snedecor.rs +++ b/src/distribution/fisher_snedecor.rs @@ -103,7 +103,7 @@ impl ContinuousCDF for FisherSnedecor { /// /// # Formula /// - /// ```ignore + /// ```text /// I_((d1 * x) / (d1 * x + d2))(d1 / 2, d2 / 2) /// ``` /// @@ -129,7 +129,7 @@ impl ContinuousCDF for FisherSnedecor { /// /// # Formula /// - /// ```ignore + /// ```text /// I_(1 - ((d1 * x) / (d1 * x + d2))(d2 / 2, d1 / 2) /// ``` /// @@ -158,7 +158,7 @@ impl Min for FisherSnedecor { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> f64 { @@ -173,7 +173,7 @@ impl Max for FisherSnedecor { /// /// # Formula /// - /// ```ignore + /// ```text /// INF /// ``` fn max(&self) -> f64 { @@ -194,7 +194,7 @@ impl Distribution for FisherSnedecor { /// /// # Formula /// - /// ```ignore + /// ```text /// d2 / (d2 - 2) /// ``` /// @@ -218,7 +218,7 @@ impl Distribution for FisherSnedecor { /// /// # Formula /// - /// ```ignore + /// ```text /// (2 * d2^2 * (d1 + d2 - 2)) / (d1 * (d2 - 2)^2 * (d2 - 4)) /// ``` /// @@ -249,7 +249,7 @@ impl Distribution for FisherSnedecor { /// /// # Formula /// - /// ```ignore + /// ```text /// ((2d1 + d2 - 2) * sqrt(8 * (d2 - 4))) / ((d2 - 6) * sqrt(d1 * (d1 + d2 /// - 2))) /// ``` @@ -282,7 +282,7 @@ impl Mode> for FisherSnedecor { /// /// # Formula /// - /// ```ignore + /// ```text /// ((d1 - 2) / d1) * (d2 / (d2 + 2)) /// ``` /// @@ -311,7 +311,7 @@ impl Continuous for FisherSnedecor { /// /// # Formula /// - /// ```ignore + /// ```text /// sqrt(((d1 * x) ^ d1 * d2 ^ d2) / (d1 * x + d2) ^ (d1 + d2)) / (x * β(d1 /// / 2, d2 / 2)) /// ``` @@ -340,7 +340,7 @@ impl Continuous for FisherSnedecor { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(sqrt(((d1 * x) ^ d1 * d2 ^ d2) / (d1 * x + d2) ^ (d1 + d2)) / (x * /// β(d1 / 2, d2 / 2))) /// ``` diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index 7a36a30f..fd993ba5 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -99,7 +99,7 @@ impl ContinuousCDF for Gamma { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / Γ(α)) * γ(α, β * x) /// ``` /// @@ -124,7 +124,7 @@ impl ContinuousCDF for Gamma { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / Γ(α)) * γ(α, β * x) /// ``` /// @@ -156,7 +156,7 @@ impl Min for Gamma { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> f64 { @@ -171,7 +171,7 @@ impl Max for Gamma { /// /// # Formula /// - /// ```ignore + /// ```text /// INF /// ``` fn max(&self) -> f64 { @@ -184,7 +184,7 @@ impl Distribution for Gamma { /// /// # Formula /// - /// ```ignore + /// ```text /// α / β /// ``` /// @@ -196,7 +196,7 @@ impl Distribution for Gamma { /// /// # Formula /// - /// ```ignore + /// ```text /// α / β^2 /// ``` /// @@ -208,7 +208,7 @@ impl Distribution for Gamma { /// /// # Formula /// - /// ```ignore + /// ```text /// α - ln(β) + ln(Γ(α)) + (1 - α) * ψ(α) /// ``` /// @@ -224,7 +224,7 @@ impl Distribution for Gamma { /// /// # Formula /// - /// ```ignore + /// ```text /// 2 / sqrt(α) /// ``` /// @@ -239,7 +239,7 @@ impl Mode> for Gamma { /// /// # Formula /// - /// ```ignore + /// ```text /// (α - 1) / β /// ``` /// @@ -260,7 +260,7 @@ impl Continuous for Gamma { /// /// # Formula /// - /// ```ignore + /// ```text /// (β^α / Γ(α)) * x^(α - 1) * e^(-β * x) /// ``` /// @@ -291,7 +291,7 @@ impl Continuous for Gamma { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((β^α / Γ(α)) * x^(α - 1) * e ^(-β * x)) /// ``` /// diff --git a/src/distribution/geometric.rs b/src/distribution/geometric.rs index a6e390d7..c7e801c1 100644 --- a/src/distribution/geometric.rs +++ b/src/distribution/geometric.rs @@ -85,7 +85,7 @@ impl DiscreteCDF for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 - (1 - p) ^ x /// ``` fn cdf(&self, x: u64) -> f64 { @@ -104,7 +104,7 @@ impl DiscreteCDF for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 - p) ^ x /// ``` fn sf(&self, x: u64) -> f64 { @@ -125,7 +125,7 @@ impl Min for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 /// ``` fn min(&self) -> u64 { @@ -140,7 +140,7 @@ impl Max for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// 2^63 - 1 /// ``` fn max(&self) -> u64 { @@ -153,7 +153,7 @@ impl Distribution for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 / p /// ``` fn mean(&self) -> Option { @@ -163,7 +163,7 @@ impl Distribution for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 - p) / p^2 /// ``` fn variance(&self) -> Option { @@ -173,7 +173,7 @@ impl Distribution for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// (-(1 - p) * log_2(1 - p) - p * log_2(p)) / p /// ``` fn entropy(&self) -> Option { @@ -184,7 +184,7 @@ impl Distribution for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// (2 - p) / sqrt(1 - p) /// ``` fn skewness(&self) -> Option { @@ -200,7 +200,7 @@ impl Mode> for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 /// ``` fn mode(&self) -> Option { @@ -215,7 +215,7 @@ impl Median for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// ceil(-1 / log_2(1 - p)) /// ``` fn median(&self) -> f64 { @@ -229,7 +229,7 @@ impl Discrete for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 - p)^(x - 1) * p /// ``` fn pmf(&self, x: u64) -> f64 { @@ -245,7 +245,7 @@ impl Discrete for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((1 - p)^(x - 1) * p) /// ``` fn ln_pmf(&self, x: u64) -> f64 { diff --git a/src/distribution/hypergeometric.rs b/src/distribution/hypergeometric.rs index 2aefe1da..95f44d18 100644 --- a/src/distribution/hypergeometric.rs +++ b/src/distribution/hypergeometric.rs @@ -139,7 +139,7 @@ impl DiscreteCDF for Hypergeometric { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 - ((n choose x+1) * (N-n choose K-x-1)) / (N choose K) * 3_F_2(1, /// x+1-K, x+1-n; k+2, N+x+2-K-n; 1) /// ``` @@ -173,7 +173,7 @@ impl DiscreteCDF for Hypergeometric { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 - ((n choose x+1) * (N-n choose K-x-1)) / (N choose K) * 3_F_2(1, /// x+1-K, x+1-n; x+2, N+x+2-K-n; 1) /// ``` @@ -193,7 +193,7 @@ impl DiscreteCDF for Hypergeometric { } else { let k = x; let ln_denom = factorial::ln_binomial(self.population, self.draws); - (k + 1 .. self.max() + 1).fold(0.0, |acc, i| { + (k + 1..self.max() + 1).fold(0.0, |acc, i| { acc + (factorial::ln_binomial(self.successes, i) + factorial::ln_binomial(self.population - self.successes, self.draws - i) - ln_denom) @@ -210,7 +210,7 @@ impl Min for Hypergeometric { /// /// # Formula /// - /// ```ignore + /// ```text /// max(0, n + K - N) /// ``` /// @@ -227,7 +227,7 @@ impl Max for Hypergeometric { /// /// # Formula /// - /// ```ignore + /// ```text /// min(K, n) /// ``` /// @@ -246,7 +246,7 @@ impl Distribution for Hypergeometric { /// /// # Formula /// - /// ```ignore + /// ```text /// K * n / N /// ``` /// @@ -266,7 +266,7 @@ impl Distribution for Hypergeometric { /// /// # Formula /// - /// ```ignore + /// ```text /// n * (K / N) * ((N - K) / N) * ((N - n) / (N - 1)) /// ``` /// @@ -289,7 +289,7 @@ impl Distribution for Hypergeometric { /// /// # Formula /// - /// ```ignore + /// ```text /// ((N - 2K) * (N - 1)^(1 / 2) * (N - 2n)) / ([n * K * (N - K) * (N - /// n)]^(1 / 2) * (N - 2)) /// ``` @@ -315,7 +315,7 @@ impl Mode> for Hypergeometric { /// /// # Formula /// - /// ```ignore + /// ```text /// floor((n + 1) * (k + 1) / (N + 2)) /// ``` /// @@ -331,7 +331,7 @@ impl Discrete for Hypergeometric { /// /// # Formula /// - /// ```ignore + /// ```text /// (K choose x) * (N-K choose n-x) / (N choose n) /// ``` /// @@ -351,7 +351,7 @@ impl Discrete for Hypergeometric { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((K choose x) * (N-K choose n-x) / (N choose n)) /// ``` /// diff --git a/src/distribution/inverse_gamma.rs b/src/distribution/inverse_gamma.rs index 1cf69fa9..e439be45 100644 --- a/src/distribution/inverse_gamma.rs +++ b/src/distribution/inverse_gamma.rs @@ -99,7 +99,7 @@ impl ContinuousCDF for InverseGamma { /// /// # Formula /// - /// ```ignore + /// ```text /// Γ(α, β / x) / Γ(α) /// ``` /// @@ -121,7 +121,7 @@ impl ContinuousCDF for InverseGamma { /// /// # Formula /// - /// ```ignore + /// ```text /// Γ(α, β / x) / Γ(α) /// ``` /// @@ -146,7 +146,7 @@ impl Min for InverseGamma { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> f64 { @@ -161,7 +161,7 @@ impl Max for InverseGamma { /// /// # Formula /// - /// ```ignore + /// ```text /// INF /// ``` fn max(&self) -> f64 { @@ -178,7 +178,7 @@ impl Distribution for InverseGamma { /// /// # Formula /// - /// ```ignore + /// ```text /// β / (α - 1) /// ``` /// @@ -198,7 +198,7 @@ impl Distribution for InverseGamma { /// /// # Formula /// - /// ```ignore + /// ```text /// β^2 / ((α - 1)^2 * (α - 2)) /// ``` /// @@ -216,7 +216,7 @@ impl Distribution for InverseGamma { /// /// # Formula /// - /// ```ignore + /// ```text /// α + ln(β * Γ(α)) - (1 + α) * ψ(α) /// ``` /// @@ -235,7 +235,7 @@ impl Distribution for InverseGamma { /// /// # Formula /// - /// ```ignore + /// ```text /// 4 * sqrt(α - 2) / (α - 3) /// ``` /// @@ -254,7 +254,7 @@ impl Mode> for InverseGamma { /// /// # Formula /// - /// ```ignore + /// ```text /// β / (α + 1) /// ``` /// @@ -270,7 +270,7 @@ impl Continuous for InverseGamma { /// /// # Formula /// - /// ```ignore + /// ```text /// (β^α / Γ(α)) * x^(-α - 1) * e^(-β / x) /// ``` /// @@ -291,7 +291,7 @@ impl Continuous for InverseGamma { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((β^α / Γ(α)) * x^(-α - 1) * e^(-β / x)) /// ``` /// diff --git a/src/distribution/laplace.rs b/src/distribution/laplace.rs index 2d3d5590..4f73b278 100644 --- a/src/distribution/laplace.rs +++ b/src/distribution/laplace.rs @@ -92,7 +92,7 @@ impl ContinuousCDF for Laplace { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / 2) * (1 + signum(x - μ)) - signum(x - μ) * exp(-|x - μ| / b) /// ``` /// @@ -111,7 +111,7 @@ impl ContinuousCDF for Laplace { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 - [(1 / 2) * (1 + signum(x - μ)) - signum(x - μ) * exp(-|x - μ| / b)] /// ``` /// @@ -131,11 +131,11 @@ impl ContinuousCDF for Laplace { /// # Formula /// /// if p <= 1/2 - /// ```ignore + /// ```text /// μ + b * ln(2p) /// ``` /// if p >= 1/2 - /// ```ignore + /// ```text /// μ - b * ln(2 - 2p) /// ``` /// @@ -158,7 +158,7 @@ impl Min for Laplace { /// /// # Formula /// - /// ```ignore + /// ```text /// NEG_INF /// ``` fn min(&self) -> f64 { @@ -172,7 +172,7 @@ impl Max for Laplace { /// /// # Formula /// - /// ```ignore + /// ```text /// INF /// ``` fn max(&self) -> f64 { @@ -185,7 +185,7 @@ impl Distribution for Laplace { /// /// # Formula /// - /// ```ignore + /// ```text /// μ /// ``` /// @@ -197,7 +197,7 @@ impl Distribution for Laplace { /// /// # Formula /// - /// ```ignore + /// ```text /// 2*b^2 /// ``` /// @@ -209,7 +209,7 @@ impl Distribution for Laplace { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(2be) /// ``` /// @@ -221,7 +221,7 @@ impl Distribution for Laplace { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn skewness(&self) -> Option { @@ -234,7 +234,7 @@ impl Median for Laplace { /// /// # Formula /// - /// ```ignore + /// ```text /// μ /// ``` /// @@ -249,7 +249,7 @@ impl Mode> for Laplace { /// /// # Formula /// - /// ```ignore + /// ```text /// μ /// ``` /// @@ -265,7 +265,7 @@ impl Continuous for Laplace { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / 2b) * exp(-|x - μ| / b) /// ``` /// where `μ` is the location and `b` is the scale @@ -278,7 +278,7 @@ impl Continuous for Laplace { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((1 / 2b) * exp(-|x - μ| / b)) /// ``` /// diff --git a/src/distribution/log_normal.rs b/src/distribution/log_normal.rs index 13854f6f..6698dd2a 100644 --- a/src/distribution/log_normal.rs +++ b/src/distribution/log_normal.rs @@ -68,7 +68,7 @@ impl ContinuousCDF for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / 2) + (1 / 2) * erf((ln(x) - μ) / sqrt(2) * σ) /// ``` /// @@ -89,7 +89,7 @@ impl ContinuousCDF for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / 2) + (1 / 2) * erf(-(ln(x) - μ) / sqrt(2) * σ) /// ``` /// @@ -100,9 +100,9 @@ impl ContinuousCDF for LogNormal { /// the sign of the argument error function with respect to the cdf. /// /// the normal cdf Φ (and internal error function) as the following property: - /// ```ignore + /// ```text /// Φ(-x) + Φ(x) = 1 - /// Φ(-x) = 1 - Φ(x) + /// Φ(-x) = 1 - Φ(x) /// ``` fn sf(&self, x: f64) -> f64 { if x <= 0.0 { @@ -121,7 +121,7 @@ impl Min for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> f64 { @@ -135,7 +135,7 @@ impl Max for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// INF /// ``` fn max(&self) -> f64 { @@ -148,7 +148,7 @@ impl Distribution for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// e^(μ + σ^2 / 2) /// ``` /// @@ -160,7 +160,7 @@ impl Distribution for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// (e^(σ^2) - 1) * e^(2μ + σ^2) /// ``` /// @@ -173,7 +173,7 @@ impl Distribution for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(σe^(μ + 1 / 2) * sqrt(2π)) /// ``` /// @@ -185,7 +185,7 @@ impl Distribution for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// (e^(σ^2) + 2) * sqrt(e^(σ^2) - 1) /// ``` /// @@ -201,7 +201,7 @@ impl Median for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// e^μ /// ``` /// @@ -216,7 +216,7 @@ impl Mode> for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// e^(μ - σ^2) /// ``` /// @@ -232,7 +232,7 @@ impl Continuous for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / xσ * sqrt(2π)) * e^(-((ln(x) - μ)^2) / 2σ^2) /// ``` /// @@ -251,7 +251,7 @@ impl Continuous for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((1 / xσ * sqrt(2π)) * e^(-((ln(x) - μ)^2) / 2σ^2)) /// ``` /// diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index a4f0524d..31de9e57 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -108,7 +108,7 @@ impl MeanN> for Multinomial { /// /// # Formula /// - /// ```ignore + /// ```text /// n * p_i for i in 1...k /// ``` /// @@ -126,7 +126,7 @@ impl VarianceN> for Multinomial { /// /// # Formula /// - /// ```ignore + /// ```text /// n * p_i * (1 - p_i) for i in 1...k /// ``` /// @@ -147,7 +147,7 @@ impl VarianceN> for Multinomial { // /// // /// # Formula // /// -// /// ```ignore +// /// ```text // /// (1 - 2 * p_i) / (n * p_i * (1 - p_i)) for i in 1...k // /// ``` // /// @@ -176,7 +176,7 @@ impl<'a> Discrete<&'a [u64], f64> for Multinomial { /// /// # Formula /// - /// ```ignore + /// ```text /// (n! / x_1!...x_k!) * p_i^x_i for i in 1...k /// ``` /// @@ -212,7 +212,7 @@ impl<'a> Discrete<&'a [u64], f64> for Multinomial { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((n! / x_1!...x_k!) * p_i^x_i) for i in 1...k /// ``` /// diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 75902a4b..b41a09e1 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -83,7 +83,7 @@ impl MultivariateNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / 2) * ln(det(2 * π * e * Σ)) /// ``` /// @@ -104,7 +104,9 @@ impl ::rand::distributions::Distribution> for MultivariateNormal { /// Samples from the multivariate normal distribution /// /// # Formula + /// ```text /// L * Z + μ + /// ``` /// /// where `L` is the Cholesky decomposition of the covariance matrix, /// `Z` is a vector of normally distributed random variables, and @@ -160,7 +162,7 @@ impl Mode> for MultivariateNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// μ /// ``` /// @@ -176,7 +178,7 @@ impl<'a> Continuous<&'a DVector, f64> for MultivariateNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// (2 * π) ^ (-k / 2) * det(Σ) ^ (1 / 2) * e ^ ( -(1 / 2) * transpose(x - μ) * inv(Σ) * (x - μ)) /// ``` /// @@ -208,7 +210,7 @@ impl Continuous, f64> for MultivariateNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// (2 * π) ^ (-k / 2) * det(Σ) ^ (1 / 2) * e ^ ( -(1 / 2) * transpose(x - μ) * inv(Σ) * (x - μ)) /// ``` /// diff --git a/src/distribution/negative_binomial.rs b/src/distribution/negative_binomial.rs index a9ed077a..4c23079c 100644 --- a/src/distribution/negative_binomial.rs +++ b/src/distribution/negative_binomial.rs @@ -117,7 +117,7 @@ impl DiscreteCDF for NegativeBinomial { /// /// # Formula /// - /// ```ignore + /// ```text /// I_(p)(r, x+1) /// ``` /// @@ -137,7 +137,7 @@ impl DiscreteCDF for NegativeBinomial { /// /// # Formula /// - /// ```ignore + /// ```text /// I_(1-p)(x+1, r) /// ``` /// @@ -154,7 +154,7 @@ impl Min for NegativeBinomial { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> u64 { @@ -169,7 +169,7 @@ impl Max for NegativeBinomial { /// /// # Formula /// - /// ```ignore + /// ```text /// u64::MAX /// ``` fn max(&self) -> u64 { @@ -182,7 +182,7 @@ impl DiscreteDistribution for NegativeBinomial { /// /// # Formula /// - /// ```ignore + /// ```text /// r * (1-p) / p /// ``` fn mean(&self) -> Option { @@ -192,7 +192,7 @@ impl DiscreteDistribution for NegativeBinomial { /// /// # Formula /// - /// ```ignore + /// ```text /// r * (1-p) / p^2 /// ``` fn variance(&self) -> Option { @@ -202,7 +202,7 @@ impl DiscreteDistribution for NegativeBinomial { /// /// # Formula /// - /// ```ignore + /// ```text /// (2-p) / sqrt(r * (1-p)) /// ``` fn skewness(&self) -> Option { @@ -215,7 +215,7 @@ impl Mode> for NegativeBinomial { /// /// # Formula /// - /// ```ignore + /// ```text /// if r > 1 then /// floor((r - 1) * (1-p / p)) /// else @@ -239,13 +239,13 @@ impl Discrete for NegativeBinomial { /// /// When `r` is an integer, the formula is: /// - /// ```ignore + /// ```text /// (x + r - 1 choose x) * (1 - p)^x * p^r /// ``` /// /// The general formula for real `r` is: /// - /// ```ignore + /// ```text /// Γ(r + x)/(Γ(r) * Γ(x + 1)) * (1 - p)^x * p^r /// ``` /// @@ -261,13 +261,13 @@ impl Discrete for NegativeBinomial { /// /// When `r` is an integer, the formula is: /// - /// ```ignore + /// ```text /// ln((x + r - 1 choose x) * (1 - p)^x * p^r) /// ``` /// /// The general formula for real `r` is: /// - /// ```ignore + /// ```text /// ln(Γ(r + x)/(Γ(r) * Γ(x + 1)) * (1 - p)^x * p^r) /// ``` /// diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index dabb7915..d1cc91f2 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -65,7 +65,7 @@ impl ContinuousCDF for Normal { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / 2) * (1 + erf((x - μ) / (σ * sqrt(2)))) /// ``` /// @@ -80,7 +80,7 @@ impl ContinuousCDF for Normal { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / 2) * (1 + erf(-(x - μ) / (σ * sqrt(2)))) /// ``` /// @@ -91,9 +91,9 @@ impl ContinuousCDF for Normal { /// the sign of the argument error function with respect to the cdf. /// /// the normal cdf Φ (and internal error function) as the following property: - /// ```ignore + /// ```text /// Φ(-x) + Φ(x) = 1 - /// Φ(-x) = 1 - Φ(x) + /// Φ(-x) = 1 - Φ(x) /// ``` fn sf(&self, x: f64) -> f64 { sf_unchecked(x, self.mean, self.std_dev) @@ -108,7 +108,7 @@ impl ContinuousCDF for Normal { /// /// # Formula /// - /// ```ignore + /// ```text /// μ - sqrt(2) * σ * erfc_inv(2x) /// ``` /// @@ -129,7 +129,7 @@ impl Min for Normal { /// /// # Formula /// - /// ```ignore + /// ```text /// -INF /// ``` fn min(&self) -> f64 { @@ -143,7 +143,7 @@ impl Max for Normal { /// /// # Formula /// - /// ```ignore + /// ```text /// INF /// ``` fn max(&self) -> f64 { @@ -164,7 +164,7 @@ impl Distribution for Normal { /// /// # Formula /// - /// ```ignore + /// ```text /// σ^2 /// ``` /// @@ -176,7 +176,7 @@ impl Distribution for Normal { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / 2) * ln(2σ^2 * π * e) /// ``` /// @@ -188,7 +188,7 @@ impl Distribution for Normal { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn skewness(&self) -> Option { @@ -201,7 +201,7 @@ impl Median for Normal { /// /// # Formula /// - /// ```ignore + /// ```text /// μ /// ``` /// @@ -216,7 +216,7 @@ impl Mode> for Normal { /// /// # Formula /// - /// ```ignore + /// ```text /// μ /// ``` /// @@ -232,7 +232,7 @@ impl Continuous for Normal { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / sqrt(2σ^2 * π)) * e^(-(x - μ)^2 / 2σ^2) /// ``` /// @@ -247,7 +247,7 @@ impl Continuous for Normal { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((1 / sqrt(2σ^2 * π)) * e^(-(x - μ)^2 / 2σ^2)) /// ``` /// diff --git a/src/distribution/pareto.rs b/src/distribution/pareto.rs index e59595a2..eca17f82 100644 --- a/src/distribution/pareto.rs +++ b/src/distribution/pareto.rs @@ -97,7 +97,7 @@ impl ContinuousCDF for Pareto { /// /// # Formula /// - /// ```ignore + /// ```text /// if x < x_m { /// 0 /// } else { @@ -119,7 +119,7 @@ impl ContinuousCDF for Pareto { /// /// # Formula /// - /// ```ignore + /// ```text /// if x < x_m { /// 1 /// } else { @@ -143,7 +143,7 @@ impl Min for Pareto { /// /// # Formula /// - /// ```ignore + /// ```text /// x_m /// ``` /// @@ -159,7 +159,7 @@ impl Max for Pareto { /// /// # Formula /// - /// ```ignore + /// ```text /// INF /// ``` fn max(&self) -> f64 { @@ -172,7 +172,7 @@ impl Distribution for Pareto { /// /// # Formula /// - /// ```ignore + /// ```text /// if α <= 1 { /// INF /// } else { @@ -192,7 +192,7 @@ impl Distribution for Pareto { /// /// # Formula /// - /// ```ignore + /// ```text /// if α <= 2 { /// INF /// } else { @@ -213,7 +213,7 @@ impl Distribution for Pareto { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(α/x_m) - 1/α - 1 /// ``` /// @@ -231,7 +231,7 @@ impl Distribution for Pareto { /// /// # Formula /// - /// ```ignore + /// ```text /// (2*(α + 1)/(α - 3))*sqrt((α - 2)/α) /// ``` /// @@ -253,7 +253,7 @@ impl Median for Pareto { /// /// # Formula /// - /// ```ignore + /// ```text /// x_m*2^(1/α) /// ``` /// @@ -268,7 +268,7 @@ impl Mode> for Pareto { /// /// # Formula /// - /// ```ignore + /// ```text /// x_m /// ``` /// @@ -284,7 +284,7 @@ impl Continuous for Pareto { /// /// # Formula /// - /// ```ignore + /// ```text /// if x < x_m { /// 0 /// } else { @@ -306,7 +306,7 @@ impl Continuous for Pareto { /// /// # Formula /// - /// ```ignore + /// ```text /// if x < x_m { /// -INF /// } else { diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index 9879b6a8..47d614ce 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -84,7 +84,7 @@ impl DiscreteCDF for Poisson { /// /// # Formula /// - /// ```ignore + /// ```text /// P(x + 1, λ) /// ``` /// @@ -98,7 +98,7 @@ impl DiscreteCDF for Poisson { /// /// # Formula /// - /// ```ignore + /// ```text /// P(x + 1, λ) /// ``` /// @@ -114,7 +114,7 @@ impl Min for Poisson { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> u64 { @@ -128,7 +128,7 @@ impl Max for Poisson { /// /// # Formula /// - /// ```ignore + /// ```text /// 2^63 - 1 /// ``` fn max(&self) -> u64 { @@ -141,7 +141,7 @@ impl Distribution for Poisson { /// /// # Formula /// - /// ```ignore + /// ```text /// λ /// ``` /// @@ -153,7 +153,7 @@ impl Distribution for Poisson { /// /// # Formula /// - /// ```ignore + /// ```text /// λ /// ``` /// @@ -165,7 +165,7 @@ impl Distribution for Poisson { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / 2) * ln(2πeλ) - 1 / (12λ) - 1 / (24λ^2) - 19 / (360λ^3) /// ``` /// @@ -182,7 +182,7 @@ impl Distribution for Poisson { /// /// # Formula /// - /// ```ignore + /// ```text /// λ^(-1/2) /// ``` /// @@ -197,7 +197,7 @@ impl Median for Poisson { /// /// # Formula /// - /// ```ignore + /// ```text /// floor(λ + 1 / 3 - 0.02 / λ) /// ``` /// @@ -212,7 +212,7 @@ impl Mode> for Poisson { /// /// # Formula /// - /// ```ignore + /// ```text /// floor(λ) /// ``` /// @@ -228,7 +228,7 @@ impl Discrete for Poisson { /// /// # Formula /// - /// ```ignore + /// ```text /// (λ^x * e^(-λ)) / x! /// ``` /// @@ -243,7 +243,7 @@ impl Discrete for Poisson { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((λ^x * e^(-λ)) / x!) /// ``` /// diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index 277647bf..21981513 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -124,7 +124,7 @@ impl ContinuousCDF for StudentsT { /// /// # Formula /// - /// ```ignore + /// ```text /// if x < μ { /// (1 / 2) * I(t, v / 2, 1 / 2) /// } else { @@ -156,7 +156,7 @@ impl ContinuousCDF for StudentsT { /// /// # Formula /// - /// ```ignore + /// ```text /// if x < μ { /// 1 - (1 / 2) * I(t, v / 2, 1 / 2) /// } else { @@ -209,7 +209,7 @@ impl Min for StudentsT { /// /// # Formula /// - /// ```ignore + /// ```text /// -INF /// ``` fn min(&self) -> f64 { @@ -223,7 +223,7 @@ impl Max for StudentsT { /// /// # Formula /// - /// ```ignore + /// ```text /// INF /// ``` fn max(&self) -> f64 { @@ -240,7 +240,7 @@ impl Distribution for StudentsT { /// /// # Formula /// - /// ```ignore + /// ```text /// μ /// ``` /// @@ -260,7 +260,7 @@ impl Distribution for StudentsT { /// /// # Formula /// - /// ```ignore + /// ```text /// if v == INF { /// Some(σ^2) /// } else if freedom > 2.0 { @@ -284,7 +284,7 @@ impl Distribution for StudentsT { /// /// # Formula /// - /// ```ignore + /// ```text /// - ln(σ) + (v + 1) / 2 * (ψ((v + 1) / 2) - ψ(v / 2)) + ln(sqrt(v) * B(v / 2, 1 / /// 2)) /// ``` @@ -309,7 +309,7 @@ impl Distribution for StudentsT { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn skewness(&self) -> Option { @@ -326,7 +326,7 @@ impl Median for StudentsT { /// /// # Formula /// - /// ```ignore + /// ```text /// μ /// ``` /// @@ -341,7 +341,7 @@ impl Mode> for StudentsT { /// /// # Formula /// - /// ```ignore + /// ```text /// μ /// ``` /// @@ -358,7 +358,7 @@ impl Continuous for StudentsT { /// /// # Formula /// - /// ```ignore + /// ```text /// Γ((v + 1) / 2) / (sqrt(vπ) * Γ(v / 2) * σ) * (1 + k^2 / v)^(-1 / 2 * (v /// + 1)) /// ``` @@ -387,7 +387,7 @@ impl Continuous for StudentsT { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(Γ((v + 1) / 2) / (sqrt(vπ) * Γ(v / 2) * σ) * (1 + k^2 / v)^(-1 / 2 * /// (v + 1))) /// ``` diff --git a/src/distribution/triangular.rs b/src/distribution/triangular.rs index 068f6c9a..a9cb98a9 100644 --- a/src/distribution/triangular.rs +++ b/src/distribution/triangular.rs @@ -72,7 +72,7 @@ impl ContinuousCDF for Triangular { /// /// # Formula /// - /// ```ignore + /// ```text /// if x == min { /// 0 /// } if min < x <= mode { @@ -103,7 +103,7 @@ impl ContinuousCDF for Triangular { /// /// # Formula /// - /// ```ignore + /// ```text /// if x == min { /// 1 /// } if min < x <= mode { @@ -159,7 +159,7 @@ impl Distribution for Triangular { /// /// # Formula /// - /// ```ignore + /// ```text /// (min + max + mode) / 3 /// ``` fn mean(&self) -> Option { @@ -169,7 +169,7 @@ impl Distribution for Triangular { /// /// # Formula /// - /// ```ignore + /// ```text /// (min^2 + max^2 + mode^2 - min * max - min * mode - max * mode) / 18 /// ``` fn variance(&self) -> Option { @@ -182,7 +182,7 @@ impl Distribution for Triangular { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 / 2 + ln((max - min) / 2) /// ``` fn entropy(&self) -> Option { @@ -192,7 +192,7 @@ impl Distribution for Triangular { /// /// # Formula /// - /// ```ignore + /// ```text /// (sqrt(2) * (min + max - 2 * mode) * (2 * min - max - mode) * (min - 2 * /// max + mode)) / /// ( 5 * (min^2 + max^2 + mode^2 - min * max - min * mode - max * mode)^(3 @@ -213,7 +213,7 @@ impl Median for Triangular { /// /// # Formula /// - /// ```ignore + /// ```text /// if mode >= (min + max) / 2 { /// min + sqrt((max - min) * (mode - min) / 2) /// } else { @@ -237,7 +237,7 @@ impl Mode> for Triangular { /// /// # Formula /// - /// ```ignore + /// ```text /// mode /// ``` fn mode(&self) -> Option { @@ -252,7 +252,7 @@ impl Continuous for Triangular { /// /// # Formula /// - /// ```ignore + /// ```text /// if x < min { /// 0 /// } else if min <= x <= mode { @@ -282,7 +282,7 @@ impl Continuous for Triangular { /// /// # Formula /// - /// ```ignore + /// ```text /// ln( if x < min { /// 0 /// } else if min <= x <= mode { diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index bca6590b..c5f7d776 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -68,7 +68,7 @@ impl ContinuousCDF for Uniform { /// /// # Formula /// - /// ```ignore + /// ```text /// (x - min) / (max - min) /// ``` fn cdf(&self, x: f64) -> f64 { @@ -86,7 +86,7 @@ impl ContinuousCDF for Uniform { /// /// # Formula /// - /// ```ignore + /// ```text /// (max - x) / (max - min) /// ``` fn sf(&self, x: f64) -> f64 { @@ -121,7 +121,7 @@ impl Distribution for Uniform { /// /// # Formula /// - /// ```ignore + /// ```text /// (min + max) / 2 /// ``` fn mean(&self) -> Option { @@ -131,7 +131,7 @@ impl Distribution for Uniform { /// /// # Formula /// - /// ```ignore + /// ```text /// (max - min)^2 / 12 /// ``` fn variance(&self) -> Option { @@ -141,7 +141,7 @@ impl Distribution for Uniform { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(max - min) /// ``` fn entropy(&self) -> Option { @@ -151,7 +151,7 @@ impl Distribution for Uniform { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn skewness(&self) -> Option { @@ -164,7 +164,7 @@ impl Median for Uniform { /// /// # Formula /// - /// ```ignore + /// ```text /// (min + max) / 2 /// ``` fn median(&self) -> f64 { @@ -182,7 +182,7 @@ impl Mode> for Uniform { /// /// # Formula /// - /// ```ignore + /// ```text /// N/A // (max + min) / 2 for the middle element /// ``` fn mode(&self) -> Option { @@ -200,7 +200,7 @@ impl Continuous for Uniform { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 / (max - min) /// ``` fn pdf(&self, x: f64) -> f64 { @@ -221,7 +221,7 @@ impl Continuous for Uniform { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(1 / (max - min)) /// ``` fn ln_pdf(&self, x: f64) -> f64 { diff --git a/src/distribution/weibull.rs b/src/distribution/weibull.rs index eab7d942..c414c20c 100644 --- a/src/distribution/weibull.rs +++ b/src/distribution/weibull.rs @@ -103,7 +103,7 @@ impl ContinuousCDF for Weibull { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 - e^-((x/λ)^k) /// ``` /// @@ -121,7 +121,7 @@ impl ContinuousCDF for Weibull { /// /// # Formula /// - /// ```ignore + /// ```text /// e^-((x/λ)^k) /// ``` /// @@ -141,7 +141,7 @@ impl Min for Weibull { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> f64 { @@ -155,7 +155,7 @@ impl Max for Weibull { /// /// # Formula /// - /// ```ignore + /// ```text /// INF /// ``` fn max(&self) -> f64 { @@ -168,7 +168,7 @@ impl Distribution for Weibull { /// /// # Formula /// - /// ```ignore + /// ```text /// λΓ(1 + 1 / k) /// ``` /// @@ -181,7 +181,7 @@ impl Distribution for Weibull { /// /// # Formula /// - /// ```ignore + /// ```text /// λ^2 * (Γ(1 + 2 / k) - Γ(1 + 1 / k)^2) /// ``` /// @@ -195,7 +195,7 @@ impl Distribution for Weibull { /// /// # Formula /// - /// ```ignore + /// ```text /// γ(1 - 1 / k) + ln(λ / k) + 1 /// ``` /// @@ -211,7 +211,7 @@ impl Distribution for Weibull { /// /// # Formula /// - /// ```ignore + /// ```text /// (Γ(1 + 3 / k) * λ^3 - 3μσ^2 - μ^3) / σ^3 /// ``` /// @@ -236,7 +236,7 @@ impl Median for Weibull { /// /// # Formula /// - /// ```ignore + /// ```text /// λ(ln(2))^(1 / k) /// ``` /// @@ -251,7 +251,7 @@ impl Mode> for Weibull { /// /// # Formula /// - /// ```ignore + /// ```text /// if k == 1 { /// 0 /// } else { @@ -276,7 +276,7 @@ impl Continuous for Weibull { /// /// # Formula /// - /// ```ignore + /// ```text /// (k / λ) * (x / λ)^(k - 1) * e^(-(x / λ)^k) /// ``` /// @@ -301,7 +301,7 @@ impl Continuous for Weibull { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((k / λ) * (x / λ)^(k - 1) * e^(-(x / λ)^k)) /// ``` /// From eb0745e46d74a077fce9345ee772598aa6183671 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Mon, 18 Mar 2024 19:32:42 -0500 Subject: [PATCH 010/185] doc: alias `inverse_cdf` as "quantile function" in docs --- src/distribution/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index a9718ee6..562eb2bc 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -111,6 +111,8 @@ pub trait ContinuousCDF: Min + Max { /// Performs a binary search on the domain of `cdf` to obtain an approximation /// of `F^-1(p) := inf { x | F(x) >= p }`. Needless to say, performance may /// may be lacking. + #[doc(alias = "quantile function")] + #[doc(alias = "quantile")] fn inverse_cdf(&self, p: T) -> K { if p == T::zero() { return self.min(); From feb156cf6fa25ccb109ed776558e31d27d920c34 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Thu, 14 Mar 2024 10:52:23 -0500 Subject: [PATCH 011/185] chore: update rustfmt to make changes --- rustfmt.toml | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/rustfmt.toml b/rustfmt.toml index b42e764f..d1c82741 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,4 +1,33 @@ -# Run using `rustfmt --write-mode overwrite *.rs` in the -# root of the src directory. You may still get some -# formatting errors (whitespace etc) which should be -# fixed manually before committing. +# This rustfmt file is added for configuration, but in practice much of our +# code is hand-formatted, frequently with more readable results. +# taken from rust-random/rand + +# Comments: +normalize_comments = true +wrap_comments = false +comment_width = 90 # small excess is okay but prefer 80 + +# Arguments: +use_small_heuristics = "Default" +# TODO: single line functions only where short, please? +# https://github.com/rust-lang/rustfmt/issues/3358 +fn_single_line = false +fn_params_layout = "Compressed" +overflow_delimited_expr = true +where_single_line = true + +# enum_discrim_align_threshold = 20 +# struct_field_align_threshold = 20 + +# Compatibility: +edition = "2021" + +# Misc: +inline_attribute_width = 80 +blank_lines_upper_bound = 2 +reorder_impl_items = true +# report_todo = "Unnumbered" +# report_fixme = "Unnumbered" + +# Ignored files: +ignore = [] From 6a28743942a52f77dc088f8107b4fb2b7caf6ffa Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Wed, 3 Apr 2024 10:32:52 -0500 Subject: [PATCH 012/185] chore: format and add contributing content to README format markdown sentences to entire lines. add test status badge add usage of `cargo fmt` to contributing section --- README.md | 49 ++++++++++++++++++++++++++++--------------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 4521bb9b..c6e59dbb 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # statrs -[![Build Status](https://travis-ci.org/boxtown/statrs.svg?branch=master)](https://travis-ci.org/boxtown/statrs) +![tests](https://github.com/statrs-dev/statrs/actions/workflows/test.yml/badge.svg) [![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE.md) [![Crates.io](https://img.shields.io/crates/v/statrs.svg)](https://crates.io/crates/statrs) @@ -13,18 +13,15 @@ Should work for both nightly and stable Rust. ## Description Statrs provides a host of statistical utilities for Rust scientific computing. -Included are a number of common distributions that can be sampled (i.e. Normal, Exponential, -Student's T, Gamma, Uniform, etc.) plus common statistical functions like the gamma function, -beta function, and error function. +Included are a number of common distributions that can be sampled (i.e. Normal, Exponential, Student's T, Gamma, Uniform, etc.) plus common statistical functions like the gamma function, beta function, and error function. -This library is a work-in-progress port of the statistical capabilities -in the C# Math.NET library. All unit tests in the library borrowed from Math.NET when possible -and filled-in when not. +This library is a work-in-progress port of the statistical capabilities in the C# Math.NET library. +All unit tests in the library borrowed from Math.NET when possible and filled-in when not. -This library is a work-in-progress and not complete. Planned for future releases are continued implementations -of distributions as well as porting over more statistical utilities +This library is a work-in-progress and not complete. +Planned for future releases are continued implementations of distributions as well as porting over more statistical utilities. -Please check out the documentation [here](https://docs.rs/statrs/*/statrs/) +Please check out the documentation [here](https://docs.rs/statrs/*/statrs/). ## Usage @@ -38,7 +35,7 @@ statrs = "0.16" ## Examples Statrs comes with a number of commonly used distributions including Normal, Gamma, Student's T, Exponential, Weibull, etc. -The common use case is to set up the distributions and sample from them which depends on the `Rand` crate for random number generation +The common use case is to set up the distributions and sample from them which depends on the `Rand` crate for random number generation. ```Rust use statrs::distribution::Exp; @@ -49,7 +46,7 @@ let n = Exp::new(0.5).unwrap(); print!("{}", n.sample(&mut r)); ``` -Statrs also comes with a number of useful utility traits for more detailed introspection of distributions +Statrs also comes with a number of useful utility traits for more detailed introspection of distributions. ```Rust use statrs::distribution::{Exp, Continuous, ContinuousCDF}; @@ -76,7 +73,8 @@ assert!(n.variance().is_none()); ## Contributing -Want to contribute? Check out some of the issues marked [help wanted](https://github.com/statrs-dev/statrs/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22) +Want to contribute? +Check out some of the issues marked [help wanted](https://github.com/statrs-dev/statrs/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22) ### How to contribute @@ -92,24 +90,28 @@ Create a feature branch: git checkout -b master ``` -After commiting your code: +Write your code and docs, then ensure it is formatted: + +The below sample modify in-place, use `--check` flag to view diff without making file changes. +Not using `fmt` from +nightly may result in some warnings and different formatting. +Our CI will `fmt`, but less chores in commit history are appreciated. ``` -git push -u origin +cargo +nightly fmt ``` -Then submit a PR, preferably referencing the relevant issue. +After commiting your code: -### Style +``` +git push -u origin +``` -This repo makes use of `rustfmt` with the configuration specified in `rustfmt.toml`. -See https://github.com/rust-lang-nursery/rustfmt for instructions on installation -and usage and run the formatter using `rustfmt --write-mode overwrite *.rs` in -the `src` directory before committing. +Then submit a PR, preferably referencing the relevant issue, if it exists. ### Commit messages Please be explicit and and purposeful with commit messages. +[Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/#summary) encouraged. #### Bad @@ -122,3 +124,8 @@ Modify test code ``` test: Update statrs::distribution::Normal test_cdf ``` + +### Communication Expectations + +Please allow at least one week before pinging issues/pr's. + From f3683365fb1b7115cb6dca366bc4d3b2bfe0f473 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Fri, 12 Apr 2024 17:43:59 -0500 Subject: [PATCH 013/185] chore: update badges and move examples to docs --- README.md | 45 +++++---------------------------------------- src/lib.rs | 46 +++++++++++++++++++++++++++++++++------------- 2 files changed, 38 insertions(+), 53 deletions(-) diff --git a/README.md b/README.md index c6e59dbb..fa43b7bb 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,10 @@ # statrs ![tests](https://github.com/statrs-dev/statrs/actions/workflows/test.yml/badge.svg) -[![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE.md) -[![Crates.io](https://img.shields.io/crates/v/statrs.svg)](https://crates.io/crates/statrs) +[![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE.md)] +[![Crate](https://img.shields.io/crates/v/statrs.svg)](https://crates.io/crates/statrs) +![docs.rs](https://img.shields.io/docsrs/statrs?style=for-the-badge). +[![codecov](https://codecov.io/gh/statrs-dev/statrs/graph/badge.svg?token=XtMSMYXvIf)](https://codecov.io/gh/statrs-dev/statrs) ## Current Version: v0.16.0 @@ -32,44 +34,7 @@ Add the most recent release to your `Cargo.toml` statrs = "0.16" ``` -## Examples - -Statrs comes with a number of commonly used distributions including Normal, Gamma, Student's T, Exponential, Weibull, etc. -The common use case is to set up the distributions and sample from them which depends on the `Rand` crate for random number generation. - -```Rust -use statrs::distribution::Exp; -use rand::distributions::Distribution; - -let mut r = rand::rngs::OsRng; -let n = Exp::new(0.5).unwrap(); -print!("{}", n.sample(&mut r)); -``` - -Statrs also comes with a number of useful utility traits for more detailed introspection of distributions. - -```Rust -use statrs::distribution::{Exp, Continuous, ContinuousCDF}; -use statrs::statistics::Distribution; - -let n = Exp::new(1.0).unwrap(); -assert_eq!(n.mean(), Some(1.0)); -assert_eq!(n.variance(), Some(1.0)); -assert_eq!(n.entropy(), Some(1.0)); -assert_eq!(n.skewness(), Some(2.0)); -assert_eq!(n.cdf(1.0), 0.6321205588285576784045); -assert_eq!(n.pdf(1.0), 0.3678794411714423215955); -``` - -as well as utility functions including `erf`, `gamma`, `ln_gamma`, `beta`, etc. - -```Rust -use statrs::statistics::Distribution; -use statrs::distribution::FisherSnedecor; - -let n = FisherSnedecor::new(1.0, 1.0).unwrap(); -assert!(n.variance().is_none()); -``` +For examples, view the docs hosted on ![docs.rs](https://img.shields.io/docsrs/statrs?style=for-the-badge). ## Contributing diff --git a/src/lib.rs b/src/lib.rs index 9a9d0a70..939400fb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,23 +5,43 @@ //! Math.NET in so far as they are used in the computation of distribution //! values. This crate depends on the `rand` crate to provide RNG. //! -//! # Example -//! The following example samples from a standard normal distribution -//! +//! # Sampling +//! The common use case is to set up the distributions and sample from them which depends on the `Rand` crate for random number generation. //! ``` -//! # extern crate rand; -//! # extern crate statrs; +//! use statrs::distribution::Exp; //! use rand::distributions::Distribution; -//! use statrs::distribution::Normal; +//! let mut r = rand::rngs::OsRng; +//! let n = Exp::new(0.5).unwrap(); +//! print!("{}", n.sample(&mut r)); +//! ``` +//! +//! # Introspecting distributions +//! Statrs also comes with a number of useful utility traits for more detailed introspection of distributions. +//! ``` +//! use statrs::distribution::{Exp, Continuous, ContinuousCDF}; // `cdf` and `pdf` +//! use statrs::statistics::Distribution; // statistical moments and entropy +//! +//! let n = Exp::new(1.0).unwrap(); +//! assert_eq!(n.mean(), Some(1.0)); +//! assert_eq!(n.variance(), Some(1.0)); +//! assert_eq!(n.entropy(), Some(1.0)); +//! assert_eq!(n.skewness(), Some(2.0)); +//! assert_eq!(n.cdf(1.0), 0.6321205588285576784045); +//! assert_eq!(n.pdf(1.0), 0.3678794411714423215955); +//! ``` +//! +//! # Utility functions +//! as well as utility functions including `erf`, `gamma`, `ln_gamma`, `beta`, etc. +//! +//! ``` +//! use statrs::distribution::FisherSnedecor; +//! use statrs::statistics::Distribution; //! -//! # fn main() { -//! let mut r = rand::thread_rng(); -//! let n = Normal::new(0.0, 1.0).unwrap(); -//! for _ in 0..10 { -//! print!("{}", n.sample(&mut r)); -//! } -//! # } +//! let n = FisherSnedecor::new(1.0, 1.0).unwrap(); +//! assert!(n.variance().is_none()); //! ``` +//! ## Distributions implemented +//! Statrs comes with a number of commonly used distributions including Normal, Gamma, Student's T, Exponential, Weibull, etc. view all implemented in `distributions` module. #![crate_type = "lib"] #![crate_name = "statrs"] From dc82d3c739ba0ca82f594d8d7f011a95cf25103b Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sun, 21 Apr 2024 10:28:01 -0500 Subject: [PATCH 014/185] chore: consistent linefeed --- README.md | 190 +++++++++++++++++++++++++++--------------------------- 1 file changed, 95 insertions(+), 95 deletions(-) diff --git a/README.md b/README.md index fa43b7bb..90362c54 100644 --- a/README.md +++ b/README.md @@ -1,96 +1,96 @@ -# statrs - -![tests](https://github.com/statrs-dev/statrs/actions/workflows/test.yml/badge.svg) -[![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE.md)] -[![Crate](https://img.shields.io/crates/v/statrs.svg)](https://crates.io/crates/statrs) -![docs.rs](https://img.shields.io/docsrs/statrs?style=for-the-badge). -[![codecov](https://codecov.io/gh/statrs-dev/statrs/graph/badge.svg?token=XtMSMYXvIf)](https://codecov.io/gh/statrs-dev/statrs) - -## Current Version: v0.16.0 - -Should work for both nightly and stable Rust. - -**NOTE:** While I will try to maintain backwards compatibility as much as possible, since this is still a 0.x.x project the API is not considered stable and thus subject to possible breaking changes up until v1.0.0 - -## Description - -Statrs provides a host of statistical utilities for Rust scientific computing. -Included are a number of common distributions that can be sampled (i.e. Normal, Exponential, Student's T, Gamma, Uniform, etc.) plus common statistical functions like the gamma function, beta function, and error function. - -This library is a work-in-progress port of the statistical capabilities in the C# Math.NET library. -All unit tests in the library borrowed from Math.NET when possible and filled-in when not. - -This library is a work-in-progress and not complete. -Planned for future releases are continued implementations of distributions as well as porting over more statistical utilities. - -Please check out the documentation [here](https://docs.rs/statrs/*/statrs/). - -## Usage - -Add the most recent release to your `Cargo.toml` - -```Rust -[dependencies] -statrs = "0.16" -``` - +# statrs + +![tests](https://github.com/statrs-dev/statrs/actions/workflows/test.yml/badge.svg) +[![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE.md)] +[![Crate](https://img.shields.io/crates/v/statrs.svg)](https://crates.io/crates/statrs) +![docs.rs](https://img.shields.io/docsrs/statrs?style=for-the-badge). +[![codecov](https://codecov.io/gh/statrs-dev/statrs/graph/badge.svg?token=XtMSMYXvIf)](https://codecov.io/gh/statrs-dev/statrs) + +## Current Version: v0.16.0 + +Should work for both nightly and stable Rust. + +**NOTE:** While I will try to maintain backwards compatibility as much as possible, since this is still a 0.x.x project the API is not considered stable and thus subject to possible breaking changes up until v1.0.0 + +## Description + +Statrs provides a host of statistical utilities for Rust scientific computing. +Included are a number of common distributions that can be sampled (i.e. Normal, Exponential, Student's T, Gamma, Uniform, etc.) plus common statistical functions like the gamma function, beta function, and error function. + +This library is a work-in-progress port of the statistical capabilities in the C# Math.NET library. +All unit tests in the library borrowed from Math.NET when possible and filled-in when not. + +This library is a work-in-progress and not complete. +Planned for future releases are continued implementations of distributions as well as porting over more statistical utilities. + +Please check out the documentation [here](https://docs.rs/statrs/*/statrs/). + +## Usage + +Add the most recent release to your `Cargo.toml` + +```Rust +[dependencies] +statrs = "0.16" +``` + For examples, view the docs hosted on ![docs.rs](https://img.shields.io/docsrs/statrs?style=for-the-badge). - -## Contributing - -Want to contribute? -Check out some of the issues marked [help wanted](https://github.com/statrs-dev/statrs/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22) - -### How to contribute - -Clone the repo: - -``` -git clone https://github.com/statrs-dev/statrs -``` - -Create a feature branch: - -``` -git checkout -b master -``` - -Write your code and docs, then ensure it is formatted: - -The below sample modify in-place, use `--check` flag to view diff without making file changes. -Not using `fmt` from +nightly may result in some warnings and different formatting. -Our CI will `fmt`, but less chores in commit history are appreciated. - -``` -cargo +nightly fmt -``` - -After commiting your code: - -``` -git push -u origin -``` - -Then submit a PR, preferably referencing the relevant issue, if it exists. - -### Commit messages - -Please be explicit and and purposeful with commit messages. -[Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/#summary) encouraged. - -#### Bad - -``` -Modify test code -``` - -#### Good - -``` -test: Update statrs::distribution::Normal test_cdf -``` - -### Communication Expectations - -Please allow at least one week before pinging issues/pr's. - + +## Contributing + +Want to contribute? +Check out some of the issues marked [help wanted](https://github.com/statrs-dev/statrs/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22) + +### How to contribute + +Clone the repo: + +``` +git clone https://github.com/statrs-dev/statrs +``` + +Create a feature branch: + +``` +git checkout -b master +``` + +Write your code and docs, then ensure it is formatted: + +The below sample modify in-place, use `--check` flag to view diff without making file changes. +Not using `fmt` from +nightly may result in some warnings and different formatting. +Our CI will `fmt`, but less chores in commit history are appreciated. + +``` +cargo +nightly fmt +``` + +After commiting your code: + +``` +git push -u origin +``` + +Then submit a PR, preferably referencing the relevant issue, if it exists. + +### Commit messages + +Please be explicit and and purposeful with commit messages. +[Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/#summary) encouraged. + +#### Bad + +``` +Modify test code +``` + +#### Good + +``` +test: Update statrs::distribution::Normal test_cdf +``` + +### Communication Expectations + +Please allow at least one week before pinging issues/pr's. + From b149b06d94f67492aae84af8e4cc1f678593b505 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 17 Apr 2024 13:04:18 +0200 Subject: [PATCH 015/185] Remove unnecessary type casts --- src/distribution/binomial.rs | 4 ++-- src/distribution/poisson.rs | 4 ++-- src/generate.rs | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index 7c5c9622..72de6b84 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -279,7 +279,7 @@ impl Discrete for Binomial { 0.0 } } else { - (factorial::ln_binomial(self.n as u64, x as u64) + (factorial::ln_binomial(self.n, x) + x as f64 * self.p.ln() + (self.n - x) as f64 * (1.0 - self.p).ln()) .exp() @@ -310,7 +310,7 @@ impl Discrete for Binomial { f64::NEG_INFINITY } } else { - factorial::ln_binomial(self.n as u64, x as u64) + factorial::ln_binomial(self.n, x) + x as f64 * self.p.ln() + (self.n - x) as f64 * (1.0 - self.p).ln() } diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index 47d614ce..e295780d 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -234,7 +234,7 @@ impl Discrete for Poisson { /// /// where `λ` is the rate fn pmf(&self, x: u64) -> f64 { - (-self.lambda + x as f64 * self.lambda.ln() - factorial::ln_factorial(x as u64)).exp() + (-self.lambda + x as f64 * self.lambda.ln() - factorial::ln_factorial(x)).exp() } /// Calculates the log probability mass function for the poisson @@ -249,7 +249,7 @@ impl Discrete for Poisson { /// /// where `λ` is the rate fn ln_pmf(&self, x: u64) -> f64 { - -self.lambda + x as f64 * self.lambda.ln() - factorial::ln_factorial(x as u64) + -self.lambda + x as f64 * self.lambda.ln() - factorial::ln_factorial(x) } } /// Generates one sample from the Poisson distribution either by diff --git a/src/generate.rs b/src/generate.rs index 1f6102d2..59783598 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -323,7 +323,7 @@ impl InfiniteSawtooth { 0.0, delay, ), - low_value: low_value as f64, + low_value, } } } From 1f8c97509ed7b6e27e69f8a43b674d25190631c7 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 17 Apr 2024 13:06:20 +0200 Subject: [PATCH 016/185] Use Range::contains instead of manual comparisons --- src/distribution/binomial.rs | 2 +- src/distribution/negative_binomial.rs | 2 +- src/function/beta.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index 72de6b84..f3bfc76d 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -49,7 +49,7 @@ impl Binomial { /// assert!(result.is_err()); /// ``` pub fn new(p: f64, n: u64) -> Result { - if p.is_nan() || p < 0.0 || p > 1.0 { + if p.is_nan() || !(0.0..=1.0).contains(&p) { Err(StatsError::BadParams) } else { Ok(Binomial { p, n }) diff --git a/src/distribution/negative_binomial.rs b/src/distribution/negative_binomial.rs index 4c23079c..e455aa47 100644 --- a/src/distribution/negative_binomial.rs +++ b/src/distribution/negative_binomial.rs @@ -65,7 +65,7 @@ impl NegativeBinomial { /// assert!(result.is_err()); /// ``` pub fn new(r: f64, p: f64) -> Result { - if p.is_nan() || p < 0.0 || p > 1.0 || r.is_nan() || r < 0.0 { + if p.is_nan() || !(0.0..=1.0).contains(&p) || r.is_nan() || r < 0.0 { Err(StatsError::BadParams) } else { Ok(NegativeBinomial { r, p }) diff --git a/src/function/beta.rs b/src/function/beta.rs index 23fb3430..f217d8d6 100644 --- a/src/function/beta.rs +++ b/src/function/beta.rs @@ -365,7 +365,7 @@ pub fn inv_beta_reg(mut a: f64, mut b: f64, mut x: f64) -> f64 { if sq < prev { pnext = p - adj; - if 0.0 <= pnext && pnext <= 1.0 { + if (0.0..=1.0).contains(&pnext) { break; } } From af35d44127b8e6f5a759038c15e219805972c718 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 17 Apr 2024 13:09:37 +0200 Subject: [PATCH 017/185] Use BTreeMap::keys instead of iter --- src/distribution/empirical.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distribution/empirical.rs b/src/distribution/empirical.rs index 5804f7c1..21e70522 100644 --- a/src/distribution/empirical.rs +++ b/src/distribution/empirical.rs @@ -158,14 +158,14 @@ impl ::rand::distributions::Distribution for Empirical { /// Panics if number of samples is zero impl Max for Empirical { fn max(&self) -> f64 { - self.data.iter().rev().map(|(key, _)| key.0).next().unwrap() + self.data.keys().rev().map(|key| key.0) .next().unwrap() } } /// Panics if number of samples is zero impl Min for Empirical { fn min(&self) -> f64 { - self.data.iter().map(|(key, _)| key.0).next().unwrap() + self.data.keys().map(|key| key.0).next().unwrap() } } From 2a0424efca0889782b2510aeddace82d06df8cc3 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Sun, 21 Apr 2024 22:57:46 +0200 Subject: [PATCH 018/185] Remove deprecated f64 associated constants --- src/distribution/beta.rs | 105 ++++++++++++++-------------- src/distribution/cauchy.rs | 2 +- src/distribution/chi.rs | 2 +- src/distribution/chi_squared.rs | 2 +- src/distribution/erlang.rs | 2 +- src/distribution/exponential.rs | 2 +- src/distribution/fisher_snedecor.rs | 2 +- src/distribution/gamma.rs | 53 +++++++------- src/distribution/inverse_gamma.rs | 2 +- src/distribution/laplace.rs | 83 +++++++++++----------- src/distribution/log_normal.rs | 2 +- src/distribution/normal.rs | 4 +- src/distribution/pareto.rs | 8 +-- src/distribution/students_t.rs | 6 +- src/distribution/weibull.rs | 2 +- src/function/factorial.rs | 7 +- 16 files changed, 140 insertions(+), 144 deletions(-) diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index ba071dfa..4682948c 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -3,7 +3,6 @@ use crate::function::{beta, gamma}; use crate::is_zero; use crate::statistics::*; use crate::{Result, StatsError}; -use core::f64::INFINITY as INF; use rand::Rng; /// Implements the [Beta](https://en.wikipedia.org/wiki/Beta_distribution) @@ -326,13 +325,13 @@ impl Continuous for Beta { 0.0 } else if self.shape_a.is_infinite() { if ulps_eq!(x, 1.0) { - INF + f64::INFINITY } else { 0.0 } } else if self.shape_b.is_infinite() { if is_zero(x) { - INF + f64::INFINITY } else { 0.0 } @@ -361,18 +360,18 @@ impl Continuous for Beta { /// where `α` is shapeA, `β` is shapeB, and `Γ` is the gamma function fn ln_pdf(&self, x: f64) -> f64 { if !(0.0..=1.0).contains(&x) { - -INF + f64::NEG_INFINITY } else if self.shape_a.is_infinite() { if ulps_eq!(x, 1.0) { - INF + f64::INFINITY } else { - -INF + f64::NEG_INFINITY } } else if self.shape_b.is_infinite() { if is_zero(x) { - INF + f64::INFINITY } else { - -INF + f64::NEG_INFINITY } } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) { 0.0 @@ -383,14 +382,14 @@ impl Continuous for Beta { let bb = if ulps_eq!(self.shape_a, 1.0) && is_zero(x) { 0.0 } else if is_zero(x) { - -INF + f64::NEG_INFINITY } else { (self.shape_a - 1.0) * x.ln() }; let cc = if ulps_eq!(self.shape_b, 1.0) && ulps_eq!(x, 1.0) { 0.0 } else if ulps_eq!(x, 1.0) { - -INF + f64::NEG_INFINITY } else { (self.shape_b - 1.0) * (1.0 - x).ln() }; @@ -412,7 +411,7 @@ mod tests { #[test] fn test_create() { - let valid = [(1.0, 1.0), (9.0, 1.0), (5.0, 100.0), (1.0, INF), (INF, 1.0)]; + let valid = [(1.0, 1.0), (9.0, 1.0), (5.0, 100.0), (1.0, f64::INFINITY), (f64::INFINITY, 1.0)]; for &arg in valid.iter() { try_create(arg); } @@ -424,15 +423,15 @@ mod tests { (0.0, 0.0), (0.0, 0.1), (1.0, 0.0), - (0.0, INF), - (INF, 0.0), + (0.0, f64::INFINITY), + (f64::INFINITY, 0.0), (f64::NAN, 1.0), (1.0, f64::NAN), (f64::NAN, f64::NAN), (1.0, -1.0), (-1.0, 1.0), (-1.0, -1.0), - (INF, INF), + (f64::INFINITY, f64::INFINITY), ]; for &arg in invalid.iter() { bad_create_case(arg); @@ -446,8 +445,8 @@ mod tests { ((1.0, 1.0), 0.5), ((9.0, 1.0), 0.9), ((5.0, 100.0), 0.047619047619047619047616), - ((1.0, INF), 0.0), - ((INF, 1.0), 1.0), + ((1.0, f64::INFINITY), 0.0), + ((f64::INFINITY, 1.0), 1.0), ]; for &(arg, res) in test.iter() { test_case(arg, res, f); @@ -461,8 +460,8 @@ mod tests { ((1.0, 1.0), 1.0 / 12.0), ((9.0, 1.0), 9.0 / 1100.0), ((5.0, 100.0), 500.0 / 1168650.0), - ((1.0, INF), 0.0), - ((INF, 1.0), 0.0), + ((1.0, f64::INFINITY), 0.0), + ((f64::INFINITY, 1.0), 0.0), ]; for &(arg, res) in test.iter() { test_case(arg, res, f); @@ -481,8 +480,8 @@ mod tests { } test_case_special((1.0, 1.0), 0.0, 1e-14, f); let entropy = |x: Beta| x.entropy(); - test_none((1.0, INF), entropy); - test_none((INF, 1.0), entropy); + test_none((1.0, f64::INFINITY), entropy); + test_none((f64::INFINITY, 1.0), entropy); } #[test] @@ -491,16 +490,16 @@ mod tests { test_case((1.0, 1.0), 0.0, skewness); test_case((9.0, 1.0), -1.4740554623801777107177478829, skewness); test_case((5.0, 100.0), 0.817594109275534303545831591, skewness); - test_case((1.0, INF), 2.0, skewness); - test_case((INF, 1.0), -2.0, skewness); + test_case((1.0, f64::INFINITY), 2.0, skewness); + test_case((f64::INFINITY, 1.0), -2.0, skewness); } #[test] fn test_mode() { let mode = |x: Beta| x.mode().unwrap(); test_case((5.0, 100.0), 0.038834951456310676243255386, mode); - test_case((92.0, INF), 0.0, mode); - test_case((INF, 2.0), 1.0, mode); + test_case((92.0, f64::INFINITY), 0.0, mode); + test_case((f64::INFINITY, 2.0), 1.0, mode); } #[test] @@ -539,12 +538,12 @@ mod tests { ((5.0, 100.0), 0.5, 4.534102298350337661e-23), ((5.0, 100.0), 1.0, 0.0), ((5.0, 100.0), 1.0, 0.0), - ((1.0, INF), 0.0, INF), - ((1.0, INF), 0.5, 0.0), - ((1.0, INF), 1.0, 0.0), - ((INF, 1.0), 0.0, 0.0), - ((INF, 1.0), 0.5, 0.0), - ((INF, 1.0), 1.0, INF), + ((1.0, f64::INFINITY), 0.0, f64::INFINITY), + ((1.0, f64::INFINITY), 0.5, 0.0), + ((1.0, f64::INFINITY), 1.0, 0.0), + ((f64::INFINITY, 1.0), 0.0, 0.0), + ((f64::INFINITY, 1.0), 0.5, 0.0), + ((f64::INFINITY, 1.0), 1.0, f64::INFINITY), ]; for &(arg, x, expect) in test.iter() { test_case(arg, expect, f(x)); @@ -570,18 +569,18 @@ mod tests { ((1.0, 1.0), 0.0, 0.0), ((1.0, 1.0), 0.5, 0.0), ((1.0, 1.0), 1.0, 0.0), - ((9.0, 1.0), 0.0, -INF), + ((9.0, 1.0), 0.0, f64::NEG_INFINITY), ((9.0, 1.0), 0.5, -3.347952867143343092547366497), ((9.0, 1.0), 1.0, 2.1972245773362193827904904738), - ((5.0, 100.0), 0.0, -INF), + ((5.0, 100.0), 0.0, f64::NEG_INFINITY), ((5.0, 100.0), 0.5, -51.447830024537682154565870), - ((5.0, 100.0), 1.0, -INF), - ((1.0, INF), 0.0, INF), - ((1.0, INF), 0.5, -INF), - ((1.0, INF), 1.0, -INF), - ((INF, 1.0), 0.0, -INF), - ((INF, 1.0), 0.5, -INF), - ((INF, 1.0), 1.0, INF), + ((5.0, 100.0), 1.0, f64::NEG_INFINITY), + ((1.0, f64::INFINITY), 0.0, f64::INFINITY), + ((1.0, f64::INFINITY), 0.5, f64::NEG_INFINITY), + ((1.0, f64::INFINITY), 1.0, f64::NEG_INFINITY), + ((f64::INFINITY, 1.0), 0.0, f64::NEG_INFINITY), + ((f64::INFINITY, 1.0), 0.5, f64::NEG_INFINITY), + ((f64::INFINITY, 1.0), 1.0, f64::INFINITY), ]; for &(arg, x, expect) in test.iter() { test_case(arg, expect, f(x)); @@ -591,13 +590,13 @@ mod tests { #[test] fn test_ln_pdf_input_lt_0() { let ln_pdf = |arg: f64| move |x: Beta| x.ln_pdf(arg); - test_case((1.0, 1.0), -INF, ln_pdf(-1.0)); + test_case((1.0, 1.0), f64::NEG_INFINITY, ln_pdf(-1.0)); } #[test] fn test_ln_pdf_input_gt_1() { let ln_pdf = |arg: f64| move |x: Beta| x.ln_pdf(arg); - test_case((1.0, 1.0), -INF, ln_pdf(2.0)); + test_case((1.0, 1.0), f64::NEG_INFINITY, ln_pdf(2.0)); } #[test] @@ -613,12 +612,12 @@ mod tests { ((5.0, 100.0), 0.0, 0.0), ((5.0, 100.0), 0.5, 1.0), ((5.0, 100.0), 1.0, 1.0), - ((1.0, INF), 0.0, 1.0), - ((1.0, INF), 0.5, 1.0), - ((1.0, INF), 1.0, 1.0), - ((INF, 1.0), 0.0, 0.0), - ((INF, 1.0), 0.5, 0.0), - ((INF, 1.0), 1.0, 1.0), + ((1.0, f64::INFINITY), 0.0, 1.0), + ((1.0, f64::INFINITY), 0.5, 1.0), + ((1.0, f64::INFINITY), 1.0, 1.0), + ((f64::INFINITY, 1.0), 0.0, 0.0), + ((f64::INFINITY, 1.0), 0.5, 0.0), + ((f64::INFINITY, 1.0), 1.0, 1.0), ]; for &(arg, x, expect) in test.iter() { test_case(arg, expect, cdf(x)); @@ -638,12 +637,12 @@ mod tests { ((5.0, 100.0), 0.0, 1.0), ((5.0, 100.0), 0.5, 0.0), ((5.0, 100.0), 1.0, 0.0), - ((1.0, INF), 0.0, 0.0), - ((1.0, INF), 0.5, 0.0), - ((1.0, INF), 1.0, 0.0), - ((INF, 1.0), 0.0, 1.0), - ((INF, 1.0), 0.5, 1.0), - ((INF, 1.0), 1.0, 0.0), + ((1.0, f64::INFINITY), 0.0, 0.0), + ((1.0, f64::INFINITY), 0.5, 0.0), + ((1.0, f64::INFINITY), 1.0, 0.0), + ((f64::INFINITY, 1.0), 0.0, 1.0), + ((f64::INFINITY, 1.0), 0.5, 1.0), + ((f64::INFINITY, 1.0), 1.0, 0.0), ]; for &(arg, x, expect) in test.iter() { test_case(arg, expect, sf(x)); diff --git a/src/distribution/cauchy.rs b/src/distribution/cauchy.rs index dcd81af5..d66a9747 100644 --- a/src/distribution/cauchy.rs +++ b/src/distribution/cauchy.rs @@ -138,7 +138,7 @@ impl Max for Cauchy { /// # Formula /// /// ```text - /// INF + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY diff --git a/src/distribution/chi.rs b/src/distribution/chi.rs index 0a65a8e6..f72376bb 100644 --- a/src/distribution/chi.rs +++ b/src/distribution/chi.rs @@ -143,7 +143,7 @@ impl Max for Chi { /// # Formula /// /// ```text - /// INF + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY diff --git a/src/distribution/chi_squared.rs b/src/distribution/chi_squared.rs index 5551b55f..ab8dc398 100644 --- a/src/distribution/chi_squared.rs +++ b/src/distribution/chi_squared.rs @@ -157,7 +157,7 @@ impl Max for ChiSquared { /// # Formula /// /// ```text - /// INF + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY diff --git a/src/distribution/erlang.rs b/src/distribution/erlang.rs index e07dff6b..e0721f24 100644 --- a/src/distribution/erlang.rs +++ b/src/distribution/erlang.rs @@ -141,7 +141,7 @@ impl Max for Erlang { /// # Formula /// /// ```text - /// INF + /// f64::INFINITY /// ``` fn max(&self) -> f64 { self.g.max() diff --git a/src/distribution/exponential.rs b/src/distribution/exponential.rs index f374b9b2..0856989c 100644 --- a/src/distribution/exponential.rs +++ b/src/distribution/exponential.rs @@ -132,7 +132,7 @@ impl Max for Exp { /// # Formula /// /// ```text - /// INF + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY diff --git a/src/distribution/fisher_snedecor.rs b/src/distribution/fisher_snedecor.rs index ee70c30e..da8d9570 100644 --- a/src/distribution/fisher_snedecor.rs +++ b/src/distribution/fisher_snedecor.rs @@ -174,7 +174,7 @@ impl Max for FisherSnedecor { /// # Formula /// /// ```text - /// INF + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index fd993ba5..02b55190 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -2,7 +2,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; use crate::statistics::*; use crate::{Result, StatsError}; -use core::f64::INFINITY as INF; use rand::Rng; /// Implements the [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution) @@ -172,10 +171,10 @@ impl Max for Gamma { /// # Formula /// /// ```text - /// INF + /// f64::INFINITY /// ``` fn max(&self) -> f64 { - INF + f64::INFINITY } } @@ -255,8 +254,8 @@ impl Continuous for Gamma { /// /// # Remarks /// - /// Returns `NAN` if any of `shape` or `rate` are `INF` - /// or if `x` is `INF` + /// Returns `NAN` if any of `shape` or `rate` are `f64::INFINITY` + /// or if `x` is `f64::INFINITY` /// /// # Formula /// @@ -286,8 +285,8 @@ impl Continuous for Gamma { /// /// # Remarks /// - /// Returns `NAN` if any of `shape` or `rate` are `INF` - /// or if `x` is `INF` + /// Returns `NAN` if any of `shape` or `rate` are `f64::INFINITY` + /// or if `x` is `f64::INFINITY` /// /// # Formula /// @@ -367,7 +366,7 @@ mod tests { (1.0, 1.0), (10.0, 10.0), (10.0, 1.0), - (10.0, INF), + (10.0, f64::INFINITY), ]; for &arg in valid.iter() { @@ -398,7 +397,7 @@ mod tests { ((1.0, 1.0), 1.0), ((10.0, 10.0), 1.0), ((10.0, 1.0), 10.0), - ((10.0, INF), 0.0), + ((10.0, f64::INFINITY), 0.0), ]; for &(arg, res) in test.iter() { test_case(arg, res, f); @@ -413,7 +412,7 @@ mod tests { ((1.0, 1.0), 1.0), ((10.0, 10.0), 0.1), ((10.0, 1.0), 10.0), - ((10.0, INF), 0.0), + ((10.0, f64::INFINITY), 0.0), ]; for &(arg, res) in test.iter() { test_case(arg, res, f); @@ -428,7 +427,7 @@ mod tests { ((1.0, 1.0), 1.0), ((10.0, 10.0), 0.2334690854869339583626209), ((10.0, 1.0), 2.53605417848097964238061239), - ((10.0, INF), f64::NEG_INFINITY), + ((10.0, f64::INFINITY), f64::NEG_INFINITY), ]; for &(arg, res) in test.iter() { test_case(arg, res, f); @@ -443,7 +442,7 @@ mod tests { ((1.0, 1.0), 2.0), ((10.0, 10.0), 0.6324555320336758663997787), ((10.0, 1.0), 0.63245553203367586639977870), - ((10.0, INF), 0.6324555320336758), + ((10.0, f64::INFINITY), 0.6324555320336758), ]; for &(arg, res) in test.iter() { test_case(arg, res, f); @@ -457,7 +456,7 @@ mod tests { for &(arg, res) in test.iter() { test_case_special(arg, res, 10e-6, f); } - let test = [((10.0, 10.0), 0.9), ((10.0, 1.0), 9.0), ((10.0, INF), 0.0)]; + let test = [((10.0, 10.0), 0.9), ((10.0, 1.0), 9.0), ((10.0, f64::INFINITY), 0.0)]; for &(arg, res) in test.iter() { test_case(arg, res, f); } @@ -471,18 +470,18 @@ mod tests { ((1.0, 1.0), 0.0), ((10.0, 10.0), 0.0), ((10.0, 1.0), 0.0), - ((10.0, INF), 0.0), + ((10.0, f64::INFINITY), 0.0), ]; for &(arg, res) in test.iter() { test_case(arg, res, f); } let f = |x: Gamma| x.max(); let test = [ - ((1.0, 0.1), INF), - ((1.0, 1.0), INF), - ((10.0, 10.0), INF), - ((10.0, 1.0), INF), - ((10.0, INF), INF), + ((1.0, 0.1), f64::INFINITY), + ((1.0, 1.0), f64::INFINITY), + ((10.0, 10.0), f64::INFINITY), + ((10.0, 1.0), f64::INFINITY), + ((10.0, f64::INFINITY), f64::INFINITY), ]; for &(arg, res) in test.iter() { test_case(arg, res, f); @@ -506,9 +505,9 @@ mod tests { test_case(arg, res, f(x)); } //TODO: test special - // test_is_nan((10.0, INF), pdf(1.0)); // is this really the behavior we want? + // test_is_nan((10.0, f64::INFINITY), pdf(1.0)); // is this really the behavior we want? //TODO: test special - // (10.0, INF, INF, 0.0, pdf(INF)),]; + // (10.0, f64::INFINITY, f64::INFINITY, 0.0, pdf(f64::INFINITY)),]; } #[test] @@ -529,13 +528,13 @@ mod tests { ((10.0, 10.0), 10.0, -69.0527107131946016148658), ((10.0, 1.0), 1.0, -13.8018274800814696112077), ((10.0, 1.0), 10.0, -2.07856164313505845504579), - ((10.0, INF), INF, f64::NEG_INFINITY), + ((10.0, f64::INFINITY), f64::INFINITY, f64::NEG_INFINITY), ]; for &(arg, x, res) in test.iter() { test_case(arg, res, f(x)); } // TODO: test special - // test_is_nan((10.0, INF), f(1.0)); // is this really the behavior we want? + // test_is_nan((10.0, f64::INFINITY), f(1.0)); // is this really the behavior we want? } #[test] @@ -550,8 +549,8 @@ mod tests { ((10.0, 10.0), 10.0, 0.999999999999999999999999), ((10.0, 1.0), 1.0, 0.000000111425478338720677), ((10.0, 1.0), 10.0, 0.542070285528147791685835), - ((10.0, INF), 1.0, 0.0), - ((10.0, INF), 10.0, 1.0), + ((10.0, f64::INFINITY), 1.0, 0.0), + ((10.0, f64::INFINITY), 10.0, 1.0), ]; for &(arg, x, res) in test.iter() { test_case(arg, res, f(x)); @@ -575,8 +574,8 @@ mod tests { ((10.0, 10.0), 10.0, 1.1253473960842808e-31), ((10.0, 1.0), 1.0, 0.9999998885745217), ((10.0, 1.0), 10.0, 0.4579297144718528), - ((10.0, INF), 1.0, 1.0), - ((10.0, INF), 10.0, 0.0), + ((10.0, f64::INFINITY), 1.0, 1.0), + ((10.0, f64::INFINITY), 10.0, 0.0), ]; for &(arg, x, res) in test.iter() { test_case(arg, res, f(x)); diff --git a/src/distribution/inverse_gamma.rs b/src/distribution/inverse_gamma.rs index e439be45..31b1d4f6 100644 --- a/src/distribution/inverse_gamma.rs +++ b/src/distribution/inverse_gamma.rs @@ -162,7 +162,7 @@ impl Max for InverseGamma { /// # Formula /// /// ```text - /// INF + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY diff --git a/src/distribution/laplace.rs b/src/distribution/laplace.rs index 4f73b278..66893b46 100644 --- a/src/distribution/laplace.rs +++ b/src/distribution/laplace.rs @@ -173,7 +173,7 @@ impl Max for Laplace { /// # Formula /// /// ```text - /// INF + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY @@ -291,7 +291,6 @@ impl Continuous for Laplace { #[cfg(all(test, feature = "nightly"))] mod tests { use super::*; - use core::f64::INFINITY as INF; use rand::thread_rng; fn try_create(location: f64, scale: f64) -> Laplace { @@ -349,12 +348,12 @@ mod tests { #[test] fn test_create() { try_create(1.0, 2.0); - try_create(-INF, 0.1); + try_create(f64::NEG_INFINITY, 0.1); try_create(-5.0 - 1.0, 1.0); try_create(0.0, 5.0); try_create(1.0, 7.0); try_create(5.0, 10.0); - try_create(INF, INF); + try_create(f64::INFINITY, f64::INFINITY); } #[test] @@ -367,71 +366,71 @@ mod tests { #[test] fn test_mean() { let mean = |x: Laplace| x.mean().unwrap(); - test_case(-INF, 0.1, -INF, mean); + test_case(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, mean); test_case(-5.0 - 1.0, 1.0, -6.0, mean); test_case(0.0, 5.0, 0.0, mean); test_case(1.0, 10.0, 1.0, mean); - test_case(INF, INF, INF, mean); + test_case(f64::INFINITY, f64::INFINITY, f64::INFINITY, mean); } #[test] fn test_variance() { let variance = |x: Laplace| x.variance().unwrap(); - test_almost(-INF, 0.1, 0.02, 1E-12, variance); + test_almost(f64::NEG_INFINITY, 0.1, 0.02, 1E-12, variance); test_almost(-5.0 - 1.0, 1.0, 2.0, 1E-12, variance); test_almost(0.0, 5.0, 50.0, 1E-12, variance); test_almost(1.0, 7.0, 98.0, 1E-12, variance); test_almost(5.0, 10.0, 200.0, 1E-12, variance); - test_almost(INF, INF, INF, 1E-12, variance); + test_almost(f64::INFINITY, f64::INFINITY, f64::INFINITY, 1E-12, variance); } #[test] fn test_entropy() { let entropy = |x: Laplace| x.entropy().unwrap(); - test_almost(-INF, 0.1, (2.0 * f64::consts::E * 0.1).ln(), 1E-12, entropy); + test_almost(f64::NEG_INFINITY, 0.1, (2.0 * f64::consts::E * 0.1).ln(), 1E-12, entropy); test_almost(-6.0, 1.0, (2.0 * f64::consts::E).ln(), 1E-12, entropy); test_almost(1.0, 7.0, (2.0 * f64::consts::E * 7.0).ln(), 1E-12, entropy); test_almost(5., 10., (2. * f64::consts::E * 10.).ln(), 1E-12, entropy); - test_almost(INF, INF, INF, 1E-12, entropy); + test_almost(f64::INFINITY, f64::INFINITY, f64::INFINITY, 1E-12, entropy); } #[test] fn test_skewness() { let skewness = |x: Laplace| x.skewness().unwrap(); - test_case(-INF, 0.1, 0.0, skewness); + test_case(f64::NEG_INFINITY, 0.1, 0.0, skewness); test_case(-6.0, 1.0, 0.0, skewness); test_case(1.0, 7.0, 0.0, skewness); test_case(5.0, 10.0, 0.0, skewness); - test_case(INF, INF, 0.0, skewness); + test_case(f64::INFINITY, f64::INFINITY, 0.0, skewness); } #[test] fn test_mode() { let mode = |x: Laplace| x.mode().unwrap(); - test_case(-INF, 0.1, -INF, mode); + test_case(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, mode); test_case(-6.0, 1.0, -6.0, mode); test_case(1.0, 7.0, 1.0, mode); test_case(5.0, 10.0, 5.0, mode); - test_case(INF, INF, INF, mode); + test_case(f64::INFINITY, f64::INFINITY, f64::INFINITY, mode); } #[test] fn test_median() { let median = |x: Laplace| x.median(); - test_case(-INF, 0.1, -INF, median); + test_case(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, median); test_case(-6.0, 1.0, -6.0, median); test_case(1.0, 7.0, 1.0, median); test_case(5.0, 10.0, 5.0, median); - test_case(INF, INF, INF, median); + test_case(f64::INFINITY, f64::INFINITY, f64::INFINITY, median); } #[test] fn test_min() { - test_case(0.0, 1.0, -INF, |l| l.min()); + test_case(0.0, 1.0, f64::NEG_INFINITY, |l| l.min()); } #[test] fn test_max() { - test_case(0.0, 1.0, INF, |l| l.max()); + test_case(0.0, 1.0, f64::INFINITY, |l| l.max()); } #[test] @@ -442,22 +441,22 @@ mod tests { test_almost(-1.0, 0.1, 3.8905661205668983e-19, 1E-12, pdf(-5.4)); test_almost(5.0, 0.1, 5.056107463052243e-43, 1E-12, pdf(-4.9)); test_almost(-5.0, 0.1, 1.9877248679543235e-30, 1E-12, pdf(2.0)); - test_almost(INF, 0.1, 0.0, 1E-12, pdf(5.5)); - test_almost(-INF, 0.1, 0.0, 1E-12, pdf(-0.0)); - test_almost(0.0, 1.0, 0.0, 1E-12, pdf(INF)); + test_almost(f64::INFINITY, 0.1, 0.0, 1E-12, pdf(5.5)); + test_almost(f64::NEG_INFINITY, 0.1, 0.0, 1E-12, pdf(-0.0)); + test_almost(0.0, 1.0, 0.0, 1E-12, pdf(f64::INFINITY)); test_almost(1.0, 1.0, 0.00915781944436709, 1E-12, pdf(5.0)); test_almost(-1.0, 1.0, 0.5, 1E-12, pdf(-1.0)); test_almost(5.0, 1.0, 0.0012393760883331792, 1E-12, pdf(-1.0)); test_almost(-5.0, 1.0, 0.0002765421850739168, 1E-12, pdf(2.5)); - test_almost(INF, 0.1, 0.0, 1E-12, pdf(2.0)); - test_almost(-INF, 0.1, 0.0, 1E-12, pdf(15.0)); - test_almost(0.0, INF, 0.0, 1E-12, pdf(89.3)); - test_almost(1.0, INF, 0.0, 1E-12, pdf(-0.1)); - test_almost(-1.0, INF, 0.0, 1E-12, pdf(0.1)); - test_almost(5.0, INF, 0.0, 1E-12, pdf(-6.1)); - test_almost(-5.0, INF, 0.0, 1E-12, pdf(-10.0)); - test_is_nan(INF, INF, pdf(2.0)); - test_is_nan(-INF, INF, pdf(-5.1)); + test_almost(f64::INFINITY, 0.1, 0.0, 1E-12, pdf(2.0)); + test_almost(f64::NEG_INFINITY, 0.1, 0.0, 1E-12, pdf(15.0)); + test_almost(0.0, f64::INFINITY, 0.0, 1E-12, pdf(89.3)); + test_almost(1.0, f64::INFINITY, 0.0, 1E-12, pdf(-0.1)); + test_almost(-1.0, f64::INFINITY, 0.0, 1E-12, pdf(0.1)); + test_almost(5.0, f64::INFINITY, 0.0, 1E-12, pdf(-6.1)); + test_almost(-5.0, f64::INFINITY, 0.0, 1E-12, pdf(-10.0)); + test_is_nan(f64::INFINITY, f64::INFINITY, pdf(2.0)); + test_is_nan(f64::NEG_INFINITY, f64::INFINITY, pdf(-5.1)); } #[test] @@ -468,22 +467,22 @@ mod tests { test_almost(-1.0, 0.1, -42.39056208756591, 1E-12, ln_pdf(-5.4)); test_almost(5.0, 0.1, -97.3905620875659, 1E-12, ln_pdf(-4.9)); test_almost(-5.0, 0.1, -68.3905620875659, 1E-12, ln_pdf(2.0)); - test_case(INF, 0.1, -INF, ln_pdf(5.5)); - test_case(-INF, 0.1, -INF, ln_pdf(-0.0)); - test_case(0.0, 1.0, -INF, ln_pdf(INF)); + test_case(f64::INFINITY, 0.1, f64::NEG_INFINITY, ln_pdf(5.5)); + test_case(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, ln_pdf(-0.0)); + test_case(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); test_almost(1.0, 1.0, -4.693147180559945, 1E-12, ln_pdf(5.0)); test_almost(-1.0, 1.0, -f64::consts::LN_2, 1E-12, ln_pdf(-1.0)); test_almost(5.0, 1.0, -6.693147180559945, 1E-12, ln_pdf(-1.0)); test_almost(-5.0, 1.0, -8.193147180559945, 1E-12, ln_pdf(2.5)); - test_case(INF, 0.1, -INF, ln_pdf(2.0)); - test_case(-INF, 0.1, -INF, ln_pdf(15.0)); - test_case(0.0, INF, -INF, ln_pdf(89.3)); - test_case(1.0, INF, -INF, ln_pdf(-0.1)); - test_case(-1.0, INF, -INF, ln_pdf(0.1)); - test_case(5.0, INF, -INF, ln_pdf(-6.1)); - test_case(-5.0, INF, -INF, ln_pdf(-10.0)); - test_is_nan(INF, INF, ln_pdf(2.0)); - test_is_nan(-INF, INF, ln_pdf(-5.1)); + test_case(f64::INFINITY, 0.1, f64::NEG_INFINITY, ln_pdf(2.0)); + test_case(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, ln_pdf(15.0)); + test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(89.3)); + test_case(1.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-0.1)); + test_case(-1.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.1)); + test_case(5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-6.1)); + test_case(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-10.0)); + test_is_nan(f64::INFINITY, f64::INFINITY, ln_pdf(2.0)); + test_is_nan(f64::NEG_INFINITY, f64::INFINITY, ln_pdf(-5.1)); } #[test] diff --git a/src/distribution/log_normal.rs b/src/distribution/log_normal.rs index 6698dd2a..028869e0 100644 --- a/src/distribution/log_normal.rs +++ b/src/distribution/log_normal.rs @@ -136,7 +136,7 @@ impl Max for LogNormal { /// # Formula /// /// ```text - /// INF + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index d1cc91f2..bcd01b81 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -130,7 +130,7 @@ impl Min for Normal { /// # Formula /// /// ```text - /// -INF + /// f64::NEG_INFINITY /// ``` fn min(&self) -> f64 { f64::NEG_INFINITY @@ -144,7 +144,7 @@ impl Max for Normal { /// # Formula /// /// ```text - /// INF + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY diff --git a/src/distribution/pareto.rs b/src/distribution/pareto.rs index eca17f82..031205eb 100644 --- a/src/distribution/pareto.rs +++ b/src/distribution/pareto.rs @@ -160,7 +160,7 @@ impl Max for Pareto { /// # Formula /// /// ```text - /// INF + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY @@ -174,7 +174,7 @@ impl Distribution for Pareto { /// /// ```text /// if α <= 1 { - /// INF + /// f64::INFINITY /// } else { /// (α * x_m)/(α - 1) /// } @@ -194,7 +194,7 @@ impl Distribution for Pareto { /// /// ```text /// if α <= 2 { - /// INF + /// f64::INFINITY /// } else { /// (x_m/(α - 1))^2 * (α/(α - 2)) /// } @@ -308,7 +308,7 @@ impl Continuous for Pareto { /// /// ```text /// if x < x_m { - /// -INF + /// f64::NEG_INFINITY /// } else { /// ln(α) + α*ln(x_m) - (α + 1)*ln(x) /// } diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index 21981513..4f84c489 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -210,7 +210,7 @@ impl Min for StudentsT { /// # Formula /// /// ```text - /// -INF + /// f64::NEG_INFINITY /// ``` fn min(&self) -> f64 { f64::NEG_INFINITY @@ -224,7 +224,7 @@ impl Max for StudentsT { /// # Formula /// /// ```text - /// INF + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY @@ -261,7 +261,7 @@ impl Distribution for StudentsT { /// # Formula /// /// ```text - /// if v == INF { + /// if v == f64::INFINITY { /// Some(σ^2) /// } else if freedom > 2.0 { /// Some(v * σ^2 / (v - 2)) diff --git a/src/distribution/weibull.rs b/src/distribution/weibull.rs index c414c20c..4f04403d 100644 --- a/src/distribution/weibull.rs +++ b/src/distribution/weibull.rs @@ -156,7 +156,7 @@ impl Max for Weibull { /// # Formula /// /// ```text - /// INF + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY diff --git a/src/function/factorial.rs b/src/function/factorial.rs index eac59bcf..e06e9bcf 100644 --- a/src/function/factorial.rs +++ b/src/function/factorial.rs @@ -4,7 +4,6 @@ use crate::error::StatsError; use crate::function::gamma; use crate::Result; -use core::f64::INFINITY as INF; /// The maximum factorial representable /// by a 64-bit floating point without @@ -20,7 +19,7 @@ pub const MAX_FACTORIAL: usize = 170; /// Returns `f64::INFINITY` if `x > 170` pub fn factorial(x: u64) -> f64 { let x = x as usize; - FCACHE.get(x).map_or(INF, |&fac| fac) + FCACHE.get(x).map_or(f64::INFINITY, |&fac| fac) } /// Computes the logarithmic factorial function `x -> ln(x!)` @@ -124,8 +123,8 @@ mod tests { #[test] fn test_factorial_overflow() { - assert_eq!(factorial(172), INF); - assert_eq!(factorial(u64::MAX), INF); + assert_eq!(factorial(172), f64::INFINITY); + assert_eq!(factorial(u64::MAX), f64::INFINITY); } #[test] From d8705cf3e879032188aeb935fed4ae73c830786f Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Sun, 21 Apr 2024 23:00:41 +0200 Subject: [PATCH 019/185] Remove imports of std::u64 This indirectly removes references to deprecated constants --- src/distribution/geometric.rs | 2 +- src/distribution/negative_binomial.rs | 2 +- src/distribution/poisson.rs | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/distribution/geometric.rs b/src/distribution/geometric.rs index c7e801c1..f87f5ee0 100644 --- a/src/distribution/geometric.rs +++ b/src/distribution/geometric.rs @@ -3,7 +3,7 @@ use crate::statistics::*; use crate::{Result, StatsError}; use rand::distributions::OpenClosed01; use rand::Rng; -use std::{f64, u64}; +use std::f64; /// Implements the /// [Geometric](https://en.wikipedia.org/wiki/Geometric_distribution) diff --git a/src/distribution/negative_binomial.rs b/src/distribution/negative_binomial.rs index e455aa47..4c69a869 100644 --- a/src/distribution/negative_binomial.rs +++ b/src/distribution/negative_binomial.rs @@ -173,7 +173,7 @@ impl Max for NegativeBinomial { /// u64::MAX /// ``` fn max(&self) -> u64 { - std::u64::MAX + u64::MAX } } diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index e295780d..e8f98bef 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -4,7 +4,6 @@ use crate::statistics::*; use crate::{Result, StatsError}; use rand::Rng; use std::f64; -use std::u64; /// Implements the [Poisson](https://en.wikipedia.org/wiki/Poisson_distribution) /// distribution From 27e28ecaeb9e3be5286a46400e976c580add56fc Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Sun, 21 Apr 2024 23:02:58 +0200 Subject: [PATCH 020/185] Replace manual f64::clamp implementation --- src/function/beta.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/function/beta.rs b/src/function/beta.rs index f217d8d6..fec184f2 100644 --- a/src/function/beta.rs +++ b/src/function/beta.rs @@ -327,11 +327,7 @@ pub fn inv_beta_reg(mut a: f64, mut b: f64, mut x: f64) -> f64 { } } - if p < 0.0001 { - p = 0.0001; - } else if 0.9999 < p { - p = 0.9999; - } + p = p.clamp(0.0001, 0.9999); // Remark AS R83 // http://www.jstor.org/stable/2347779 From 594bced4855aa2fa67e2a8a8e4b3487c610d44dd Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Sun, 21 Apr 2024 23:05:16 +0200 Subject: [PATCH 021/185] Rewrite some `*=` for clarity clippy complained about this and thought it was unintended --- src/distribution/gamma.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index 02b55190..b8e7df44 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -341,8 +341,8 @@ pub fn sample_unchecked(rng: &mut R, shape: f64, rate: f64) -> }; } - v *= v * v; - x *= x; + v = v * v * v; + x = x * x; let u: f64 = rng.gen(); if u < 1.0 - 0.0331 * x * x || u.ln() < 0.5 * x + d * (1.0 - v + v.ln()) { return afix * d * v / rate; From d29d020f896444e1dea84c44f89df5d48b37d528 Mon Sep 17 00:00:00 2001 From: Jonas Marcello Date: Mon, 8 Apr 2024 21:52:00 +0200 Subject: [PATCH 022/185] Remove lazy-static and make FCACHE a proper const --- Cargo.toml | 1 - src/function/factorial.rs | 38 +++++++++++++++++++++++--------------- src/lib.rs | 3 --- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 34549647..a8b6e4ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,6 @@ rand = "0.8" nalgebra = { version = "0.32", features = ["rand"] } approx = "0.5.0" num-traits = "0.2.14" -lazy_static = "1.4.0" [dev-dependencies] criterion = "0.3.3" diff --git a/src/function/factorial.rs b/src/function/factorial.rs index e06e9bcf..77bcaa5b 100644 --- a/src/function/factorial.rs +++ b/src/function/factorial.rs @@ -90,26 +90,34 @@ pub fn checked_multinomial(n: u64, ni: &[u64]) -> Result { // Initialization for pre-computed cache of 171 factorial // values 0!...170! -lazy_static! { - static ref FCACHE: [f64; MAX_FACTORIAL + 1] = { - let mut fcache = [1.0; MAX_FACTORIAL + 1]; - fcache - .iter_mut() - .enumerate() - .skip(1) - .fold(1.0, |acc, (i, elt)| { - let fac = acc * i as f64; - *elt = fac; - fac - }); - fcache - }; -} +const FCACHE: [f64; MAX_FACTORIAL + 1] = { + let mut fcache = [1.0; MAX_FACTORIAL + 1]; + + // `const` only allow while loops + let mut i = 1; + while i < MAX_FACTORIAL + 1 { + fcache[i] = fcache[i - 1] * i as f64; + i += 1; + } + + fcache +}; #[cfg(test)] mod tests { use super::*; + #[test] + fn test_fcache() { + assert!((FCACHE[0] - 1.0).abs() < f64::EPSILON); + assert!((FCACHE[1] - 1.0).abs() < f64::EPSILON); + assert!((FCACHE[2] - 2.0).abs() < f64::EPSILON); + assert!((FCACHE[3] - 6.0).abs() < f64::EPSILON); + assert!((FCACHE[4] - 24.0).abs() < f64::EPSILON); + assert!((FCACHE[70] - 1197857166996989e85).abs() < f64::EPSILON); + assert!((FCACHE[170] - 7257415615307994e291).abs() < f64::EPSILON); + } + #[test] fn test_factorial_and_ln_factorial() { let mut fac = 1.0; diff --git a/src/lib.rs b/src/lib.rs index 939400fb..ad234627 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -55,9 +55,6 @@ #[macro_use] extern crate approx; -#[macro_use] -extern crate lazy_static; - #[macro_export] macro_rules! assert_almost_eq { ($a:expr, $b:expr, $prec:expr) => { From ce40e3a62dd3e680797be5853822dc13e7483d68 Mon Sep 17 00:00:00 2001 From: Henry Jacobson Date: Sun, 25 Dec 2022 14:16:31 +0100 Subject: [PATCH 023/185] feat: possibility for creating multivariate normal dist from nalgebra DVector/DMatrix --- src/distribution/multivariate_normal.rs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index b41a09e1..d6f400ec 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -37,7 +37,7 @@ pub struct MultivariateNormal { } impl MultivariateNormal { - /// Constructs a new multivariate normal distribution with a mean of `mean` + /// Constructs a new multivariate normal distribution with a mean of `mean` /// and covariance matrix `cov` /// /// # Errors @@ -47,6 +47,18 @@ impl MultivariateNormal { pub fn new(mean: Vec, cov: Vec) -> Result { let mean = DVector::from_vec(mean); let cov = DMatrix::from_vec(mean.len(), mean.len(), cov); + return MultivariateNormal::new_from_nalgebra(mean, cov) + } + + /// Constructs a new multivariate normal distribution with a mean of `mean` + /// and covariance matrix `cov`, but with explicitly using nalgebras + /// DVector and DMatrix instead of Vec + /// + /// # Errors + /// + /// Returns an error if the given covariance matrix is not + /// symmetric or positive-definite + pub fn new_from_nalgebra(mean: DVector, cov: DMatrix) -> Result { let dim = mean.len(); // Check that the provided covariance matrix is symmetric if cov.lower_triangle() != cov.upper_triangle().transpose() @@ -79,6 +91,7 @@ impl MultivariateNormal { } } } + /// Returns the entropy of the multivariate normal distribution /// /// # Formula From 9e63ee4ceff609059cfed27efce3ec8756aba574 Mon Sep 17 00:00:00 2001 From: Tony Rippy Date: Mon, 2 May 2022 10:05:04 -0400 Subject: [PATCH 024/185] Adds an `inverse_cdf()` specialization for Uniform --- src/distribution/uniform.rs | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index c5f7d776..c4abc985 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -102,6 +102,19 @@ impl ContinuousCDF for Uniform { (self.max - x) / (self.max - self.min) } } + + /// Finds the value of `x` where `F(p) = x` + fn inverse_cdf(&self, p: f64) -> f64 { + if !(0.0..=1.0).contains(&p) { + panic!("p must be in [0, 1], was {}", p); + } else if p == 0.0 { + self.min + } else if p == 1.0 { + self.max + } else { + (self.max - self.min) * p + self.min + } + } } impl Min for Uniform { @@ -417,6 +430,21 @@ mod tests { test_case(0.0, f64::INFINITY, 1.0, cdf(f64::INFINITY)); } + #[test] + fn test_inverse_cdf() { + let inverse_cdf = |arg: f64| move |x: Uniform| x.inverse_cdf(arg); + test_case(0.0, 0.0, 0.0, inverse_cdf(0.0)); + test_case(0.0, 0.0, 0.0, inverse_cdf(1.0)); + test_case(0.0, 0.1, 0.05, inverse_cdf(0.5)); + test_case(0.0, 10.0, 5.0, inverse_cdf(0.5)); + test_case(1.0, 10.0, 1.0, inverse_cdf(0.0)); + test_case(1.0, 10.0, 4.0, inverse_cdf(1.0 / 3.0)); + test_case(1.0, 10.0, 10.0, inverse_cdf(1.0)); + test_case(f64::NEG_INFINITY, f64::INFINITY, f64::NEG_INFINITY, inverse_cdf(0.0)); + test_case(0.0, f64::INFINITY, 0.0, inverse_cdf(0.0)); + test_case(0.0, f64::INFINITY, f64::INFINITY, inverse_cdf(1.0)); + } + #[test] fn test_cdf_lower_bound() { let cdf = |arg: f64| move |x: Uniform| x.cdf(arg); From cb9c2d67be8371e852dc862738656550bca47c45 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Sun, 28 Apr 2024 18:26:59 +0200 Subject: [PATCH 025/185] Make PartialOrd impl canonical for `Empirical` --- src/distribution/empirical.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distribution/empirical.rs b/src/distribution/empirical.rs index 21e70522..24c327c9 100644 --- a/src/distribution/empirical.rs +++ b/src/distribution/empirical.rs @@ -13,13 +13,13 @@ impl Eq for NonNAN {} impl PartialOrd for NonNAN { fn partial_cmp(&self, other: &Self) -> Option { - self.0.partial_cmp(&other.0) + Some(self.cmp(other)) } } impl Ord for NonNAN { fn cmp(&self, other: &Self) -> Ordering { - self.partial_cmp(other).unwrap() + self.0.partial_cmp(&other.0).unwrap() } } From ccf8ec8c48ddd879c4c85dab7a2ca6883478a936 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Mon, 29 Apr 2024 10:21:03 +0200 Subject: [PATCH 026/185] Rename NonNAN to NonNan --- src/distribution/empirical.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/distribution/empirical.rs b/src/distribution/empirical.rs index 24c327c9..43588819 100644 --- a/src/distribution/empirical.rs +++ b/src/distribution/empirical.rs @@ -7,17 +7,17 @@ use rand::Rng; use std::collections::BTreeMap; #[derive(Clone, Debug, PartialEq)] -struct NonNAN(T); +struct NonNan(T); -impl Eq for NonNAN {} +impl Eq for NonNan {} -impl PartialOrd for NonNAN { +impl PartialOrd for NonNan { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl Ord for NonNAN { +impl Ord for NonNan { fn cmp(&self, other: &Self) -> Ordering { self.0.partial_cmp(&other.0).unwrap() } @@ -42,7 +42,7 @@ pub struct Empirical { sum: f64, mean_and_var: Option<(f64, f64)>, // keys are data points, values are number of data points with equal value - data: BTreeMap, u64>, + data: BTreeMap, u64>, } impl Empirical { @@ -86,13 +86,13 @@ impl Empirical { self.mean_and_var = Some((data_point, 0.)); } } - *self.data.entry(NonNAN(data_point)).or_insert(0) += 1; + *self.data.entry(NonNan(data_point)).or_insert(0) += 1; } } pub fn remove(&mut self, data_point: f64) { if !data_point.is_nan() { if let (Some(val), Some((mean, var))) = - (self.data.remove(&NonNAN(data_point)), self.mean_and_var) + (self.data.remove(&NonNan(data_point)), self.mean_and_var) { if val == 1 && self.data.is_empty() { self.mean_and_var = None; @@ -105,7 +105,7 @@ impl Empirical { var - (self.sum - 1.) * (data_point - mean) * (data_point - mean) / self.sum; self.sum -= 1.; if val != 1 { - self.data.insert(NonNAN(data_point), val - 1); + self.data.insert(NonNan(data_point), val - 1); }; self.mean_and_var = Some((mean, var)); } From 4d9451398987a28aed09ea4303dd5380fcfbbac6 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Mon, 29 Apr 2024 12:33:31 +0200 Subject: [PATCH 027/185] Add test asserting that `StatsError` is Sync & Send --- src/error.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/error.rs b/src/error.rs index f13ecfc6..c76d8b32 100644 --- a/src/error.rs +++ b/src/error.rs @@ -104,3 +104,18 @@ impl fmt::Display for StatsError { } } } + +#[cfg(test)] +mod tests { + use super::*; + + fn assert_sync() {} + fn assert_send() {} + + #[test] + fn test_sync_send() { + // Error types should implement Sync and Send + let _ = assert_sync::(); + let _ = assert_send::(); + } +} From c8f219e948a73d3dbc2168a19074ad95ce8d7f99 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Sun, 28 Apr 2024 18:40:42 +0200 Subject: [PATCH 028/185] Ignore warning about nested module with same name --- src/statistics/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/statistics/mod.rs b/src/statistics/mod.rs index b156c6ec..6a2d3342 100644 --- a/src/statistics/mod.rs +++ b/src/statistics/mod.rs @@ -10,5 +10,6 @@ mod iter_statistics; mod order_statistics; // TODO: fix later mod slice_statistics; +#[allow(clippy::module_inception)] mod statistics; mod traits; From cfdf90597d19f18e59af1c2be90c394ef14dd5f1 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Sun, 28 Apr 2024 18:41:37 +0200 Subject: [PATCH 029/185] Remove needless `return` statement --- src/distribution/multivariate_normal.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index d6f400ec..ff4ec6bc 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -47,7 +47,7 @@ impl MultivariateNormal { pub fn new(mean: Vec, cov: Vec) -> Result { let mean = DVector::from_vec(mean); let cov = DMatrix::from_vec(mean.len(), mean.len(), cov); - return MultivariateNormal::new_from_nalgebra(mean, cov) + MultivariateNormal::new_from_nalgebra(mean, cov) } /// Constructs a new multivariate normal distribution with a mean of `mean` From 3e42a6662b41471de25b851cafdcb0c8990c21ea Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+YeungOnion@users.noreply.github.com> Date: Wed, 15 May 2024 17:34:24 -0500 Subject: [PATCH 030/185] chore: update README with formatting and adding to "contributing" - cleanup typos in badges and make reference-style links - remove text near to header to align with common practices - remove version number in README to reduce chance of mis-versioning in docs - edit "contributing" verbiage to be more welcoming and specific in direction --- README.md | 50 ++++++++++++++++++++++++++++---------------------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 90362c54..8b35a01e 100644 --- a/README.md +++ b/README.md @@ -1,45 +1,51 @@ # statrs -![tests](https://github.com/statrs-dev/statrs/actions/workflows/test.yml/badge.svg) -[![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE.md)] -[![Crate](https://img.shields.io/crates/v/statrs.svg)](https://crates.io/crates/statrs) -![docs.rs](https://img.shields.io/docsrs/statrs?style=for-the-badge). -[![codecov](https://codecov.io/gh/statrs-dev/statrs/graph/badge.svg?token=XtMSMYXvIf)](https://codecov.io/gh/statrs-dev/statrs) - -## Current Version: v0.16.0 - -Should work for both nightly and stable Rust. - -**NOTE:** While I will try to maintain backwards compatibility as much as possible, since this is still a 0.x.x project the API is not considered stable and thus subject to possible breaking changes up until v1.0.0 - -## Description +![tests][actions-test-badge] +[![MIT licensed][license-badge]](./LICENSE.md) +[![Crate][crates-badge]][crates-url] +[![docs.rs](https://img.shields.io/docsrs/statrs)][docs-url] +[![codecov][codecov-badge]][codecov-url] + +[actions-test-badge]: https://github.com/statrs-dev/statrs/actions/workflows/test.yml/badge.svg +[crates-badge]: https://img.shields.io/crates/v/statrs.svg +[crates-url]: https://crates.io/crates/statrs +[license-badge]: https://img.shields.io/badge/license-MIT-blue.svg +[docsrs-badge]: https://img.shields.io/docsrs/statrs +[docs-url]: https://docs.rs/statrs/*/statrs +[codecov-badge]: https://codecov.io/gh/statrs-dev/statrs/graph/badge.svg?token=XtMSMYXvIf +[codecov-url]: https://codecov.io/gh/statrs-dev/statrs Statrs provides a host of statistical utilities for Rust scientific computing. + Included are a number of common distributions that can be sampled (i.e. Normal, Exponential, Student's T, Gamma, Uniform, etc.) plus common statistical functions like the gamma function, beta function, and error function. -This library is a work-in-progress port of the statistical capabilities in the C# Math.NET library. +This library began as port of the statistical capabilities in the C# Math.NET library. All unit tests in the library borrowed from Math.NET when possible and filled-in when not. - -This library is a work-in-progress and not complete. Planned for future releases are continued implementations of distributions as well as porting over more statistical utilities. -Please check out the documentation [here](https://docs.rs/statrs/*/statrs/). +Please check out the documentation [here][docs-url]. ## Usage Add the most recent release to your `Cargo.toml` -```Rust +```toml [dependencies] -statrs = "0.16" +statrs = "*" # replace * by the latest version of the crate. ``` -For examples, view the docs hosted on ![docs.rs](https://img.shields.io/docsrs/statrs?style=for-the-badge). +For examples, view [the docs](https://docs.rs/statrs/*/statrs/). ## Contributing -Want to contribute? -Check out some of the issues marked [help wanted](https://github.com/statrs-dev/statrs/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22) +Thanks for your help to improve the project! +**No contribution is too small and all contributions are valued.** + +Suggestions if you don't know where to start, +- documentation is a great place to start, as you'll be able to identify the value of existing documentation better than its authors. +- tests are valuable in demonstrating correct behavior, you can review test coverage on the [CodeCov Report][codecov-url]*, not live until [#229](https://github.com/statrs-dev/statrs/pull/229) merged. +- check out some of the issues marked [help wanted](https://github.com/statrs-dev/statrs/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). +- look at what's not included from Math.NET's [Distributions](https://github.com/mathnet/mathnet-numerics/tree/master/src/Numerics/Distributions), [Statistics](https://github.com/mathnet/mathnet-numerics/tree/master/src/Numerics/Statistics), or related. ### How to contribute From 9f2aa4f2c7db6d58ddecbfcbc1487ed323182b93 Mon Sep 17 00:00:00 2001 From: tessob Date: Wed, 7 Jun 2023 09:16:14 +0200 Subject: [PATCH 031/185] Intellij IDEA & macOS compatibility --- .gitignore | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index dea682e4..3bb686bd 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,8 @@ #editor specific /.vscode/ -.idea/ \ No newline at end of file +.idea/ +*.iml + +# macOS +.DS_Store From 2db6bdb1264a495ea130c99486b5ab64e7fb2d4c Mon Sep 17 00:00:00 2001 From: tessob Date: Wed, 7 Jun 2023 09:17:29 +0200 Subject: [PATCH 032/185] Implementation & Tests --- src/distribution/exponential.rs | 34 +++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/src/distribution/exponential.rs b/src/distribution/exponential.rs index 0856989c..c335b895 100644 --- a/src/distribution/exponential.rs +++ b/src/distribution/exponential.rs @@ -109,6 +109,19 @@ impl ContinuousCDF for Exp { (-self.rate * x).exp() } } + + /// Calculates the inverse cumulative distribution function. + /// + /// # Formula + /// + /// ```ignore + /// -ln(1 - p) / λ + /// ``` + /// + /// where `p` is the probability and `λ` is the rate + fn inverse_cdf(&self, p: f64) -> f64 { + -(1.0 - p).ln() / self.rate + } } impl Min for Exp { @@ -457,6 +470,27 @@ mod tests { test_case(f64::INFINITY, 1.0, cdf(f64::INFINITY)); } + #[test] + fn test_inverse_cdf() { + let distribution = Exp::new(0.42).unwrap(); + assert_eq!(distribution.median(), distribution.inverse_cdf(0.5)); + + let distribution = Exp::new(0.042).unwrap(); + assert_eq!(distribution.median(), distribution.inverse_cdf(0.5)); + + let distribution = Exp::new(0.0042).unwrap(); + assert_eq!(distribution.median(), distribution.inverse_cdf(0.5)); + + let distribution = Exp::new(0.33).unwrap(); + assert_eq!(distribution.median(), distribution.inverse_cdf(0.5)); + + let distribution = Exp::new(0.033).unwrap(); + assert_eq!(distribution.median(), distribution.inverse_cdf(0.5)); + + let distribution = Exp::new(0.0033).unwrap(); + assert_eq!(distribution.median(), distribution.inverse_cdf(0.5)); + } + #[test] fn test_sf() { let sf = |arg: f64| move |x: Exp| x.sf(arg); From fc2981407fed308efeb465414af47e8a099542a3 Mon Sep 17 00:00:00 2001 From: tessob Date: Thu, 8 Jun 2023 09:33:36 +0200 Subject: [PATCH 033/185] Update src/distribution/exponential.rs Co-authored-by: Warren Weckesser --- src/distribution/exponential.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distribution/exponential.rs b/src/distribution/exponential.rs index c335b895..b52e1899 100644 --- a/src/distribution/exponential.rs +++ b/src/distribution/exponential.rs @@ -120,7 +120,7 @@ impl ContinuousCDF for Exp { /// /// where `p` is the probability and `λ` is the rate fn inverse_cdf(&self, p: f64) -> f64 { - -(1.0 - p).ln() / self.rate + -(-p).ln_1p() / self.rate } } From 4f9c9c037ca6033b04841adb84894c59dac19e20 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Fri, 26 Apr 2024 21:21:38 -0500 Subject: [PATCH 034/185] fix: default inverse_cdf method incorrectly implements generic integer bisection additional tests --- src/distribution/binomial.rs | 25 +++++++++++++++++++++++++ src/distribution/mod.rs | 7 +++++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index f3bfc76d..d9c1bf3c 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -565,6 +565,31 @@ mod tests { test_case(0.5, 3, 0.0, sf(5)); } + #[test] + fn test_inverse_cdf() { + let invcdf = |arg: f64| move |x: Binomial| x.inverse_cdf(arg); + test_case(0.4, 5, 2, invcdf(0.3456)); + test_case(0.5, 6, 4, invcdf(0.75)); + } + + #[test] + fn test_inverse_cdf_is_infimum() { + let invcdf = |arg: f64| move |x: Binomial| x.inverse_cdf(arg); + let a = 0.2592; + let b = 0.3456; + test_case(0.4, 5, 1, invcdf(a)); + test_case(0.4, 5, 1, invcdf((a+b)/2.0)); + test_case(0.4, 5, 2, invcdf(b)); + } + + #[test] + fn test_cdf_inverse_cdf() { + let cdf_invcdf = |arg: u64| move |x: Binomial| x.inverse_cdf(x.cdf(arg)); + test_case(0.3, 10, 3, cdf_invcdf(3)); + test_case(0.3, 10, 4, cdf_invcdf(4)); + test_case(0.5, 6, 4, cdf_invcdf(4)); + } + #[test] fn test_discrete() { test::check_discrete_distribution(&try_create(0.3, 5), 5); diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index 562eb2bc..93806e5a 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -187,11 +187,14 @@ pub trait DiscreteCDF: Min + Max { }; let two = K::one() + K::one(); let mut high = two.clone(); - let mut low = K::min_value(); + let mut low = self.min(); while self.cdf(high.clone()) < p { high = high.clone() + high.clone(); } - while high != low { + while self.cdf(low.clone()) > p { + low = low.clone() / two.clone(); + } + while high != low.clone() + K::one() { let mid = (high.clone() + low.clone()) / two.clone(); if self.cdf(mid.clone()) >= p { high = mid; From 3acfb35066046f4f0c1c6ad91bf0393e60b6c93e Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sun, 28 Apr 2024 16:31:39 -0500 Subject: [PATCH 035/185] feat: tighten trait bounds for DiscreteCDF --- src/distribution/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index 93806e5a..4351393e 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -2,7 +2,7 @@ //! and provides //! concrete implementations for a variety of distributions. use super::statistics::{Max, Min}; -use ::num_traits::{float::Float, Bounded, Num}; +use ::num_traits::{Float, PrimInt}; pub use self::bernoulli::Bernoulli; pub use self::beta::Beta; @@ -145,7 +145,7 @@ pub trait ContinuousCDF: Min + Max { /// The `DiscreteCDF` trait is used to specify an interface for univariate /// discrete distributions. -pub trait DiscreteCDF: Min + Max { +pub trait DiscreteCDF: Min + Max { /// Returns the cumulative distribution function calculated /// at `x` for a given distribution. May panic depending /// on the implementor. From 6c6edfe1a0fcf525736c00b9a0a97c42135e7624 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sun, 28 Apr 2024 16:35:25 -0500 Subject: [PATCH 036/185] feat: move integral function bisection out of ::inverse_cdf seemed generic enough to put elsewhere --- src/distribution/internal.rs | 35 +++++++++++++++++++++++++++++++++++ src/distribution/mod.rs | 34 +++++++++++++++------------------- 2 files changed, 50 insertions(+), 19 deletions(-) diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index 95c93872..8c0fe87e 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -1,3 +1,5 @@ +use num_traits::{Float, PrimInt}; + /// Returns true if there are no elements in `x` in `arr` /// such that `x <= 0.0` or `x` is `f64::NAN` and `sum(arr) > 0.0`. /// IF `incl_zero` is true, it tests for `x < 0.0` instead of `x <= 0.0` @@ -12,6 +14,39 @@ pub fn is_valid_multinomial(arr: &[f64], incl_zero: bool) -> bool { sum != 0.0 } +/// implements univariate function bisection search with infimum +/// if `None`, either the function was found not semi-monotone on the interval +/// or the provided bounds did not map to a range containing `z` +/// if `Some(k)`, then the condition below is met +/// ```text +/// smallest k such that f(k) >= z +/// ``` +pub fn integral_bisection_search( + f: impl Fn(K) -> T, z: T, lb: K, ub: K, +) -> Option { + if lb > ub || !(f(lb)..=f(ub)).contains(&z) { + return None; + } + let two = K::one() + K::one(); + let mut lb = lb; + let mut ub = ub; + loop { + let mid = (lb + ub) / two; + if !(f(lb)..=f(ub)).contains(&f(mid)) { + // if f found to not be monotone on the interval + return None; + } else if (lb..=lb + K::one()).contains(&ub) { + // if ub \in [lb, lb+1] + return Some(ub); + } else if f(mid) >= z { + // implies mid >= z + ub = mid; + } else { + lb = mid; + } + } +} + #[macro_use] #[cfg(all(test, feature = "nightly"))] pub mod test { diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index 4351393e..384dea64 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -177,32 +177,28 @@ pub trait DiscreteCDF: Min + Max { /// Due to issues with rounding and floating-point accuracy the default implementation may be ill-behaved /// Specialized inverse cdfs should be used whenever possible. + /// + /// # Panics, for the default impl + /// + /// If `x <= 0.0` or `x >= 1.0` fn inverse_cdf(&self, p: T) -> K { // TODO: fix integer implementation if p == T::zero() { return self.min(); - }; - if p == T::one() { + } else if p == T::one() { return self.max(); - }; - let two = K::one() + K::one(); - let mut high = two.clone(); - let mut low = self.min(); - while self.cdf(high.clone()) < p { - high = high.clone() + high.clone(); + } else if !(T::zero()..=T::one()).contains(&p) { + panic!("p must be in [0, 1]") } - while self.cdf(low.clone()) > p { - low = low.clone() / two.clone(); - } - while high != low.clone() + K::one() { - let mid = (high.clone() + low.clone()) / two.clone(); - if self.cdf(mid.clone()) >= p { - high = mid; - } else { - low = mid; - } + + let two = K::one() + K::one(); + let mut high = two; + let low = self.min(); + while self.cdf(high) < p { + high = two * high; } - high + + internal::integral_bisection_search(|p| self.cdf(p), p, low, high).unwrap() } } From 24eb78542f50b03645851adf36f6fcc098ffb051 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Thu, 23 May 2024 17:55:28 -0500 Subject: [PATCH 037/185] fix: bisection did not handle lower bound landing exactly on needle also changed trait bounds to be less stringent --- src/distribution/internal.rs | 34 +++++++++++++++++++--------------- src/distribution/mod.rs | 25 +++++++++++++------------ 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index 8c0fe87e..a045c9f9 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -1,4 +1,4 @@ -use num_traits::{Float, PrimInt}; +use num_traits::{Bounded, Float, Num}; /// Returns true if there are no elements in `x` in `arr` /// such that `x <= 0.0` or `x` is `f64::NAN` and `sum(arr) > 0.0`. @@ -14,32 +14,36 @@ pub fn is_valid_multinomial(arr: &[f64], incl_zero: bool) -> bool { sum != 0.0 } -/// implements univariate function bisection search with infimum -/// if `None`, either the function was found not semi-monotone on the interval -/// or the provided bounds did not map to a range containing `z` -/// if `Some(k)`, then the condition below is met +/// Implements univariate function bisection searching for criteria /// ```text /// smallest k such that f(k) >= z /// ``` -pub fn integral_bisection_search( - f: impl Fn(K) -> T, z: T, lb: K, ub: K, +/// Evaluates to `None` if +/// - provided interval has lower bound greater than upper bound +/// - function found not semi-monotone on the provided interval containing `z` +/// Evaluates to `Some(k)`, where `k` satisfies the search criteria +pub fn integral_bisection_search( + f: impl Fn(&K) -> T, z: T, lb: K, ub: K, ) -> Option { - if lb > ub || !(f(lb)..=f(ub)).contains(&z) { + if !(f(&lb)..=f(&ub)).contains(&z) { return None; } let two = K::one() + K::one(); let mut lb = lb; let mut ub = ub; loop { - let mid = (lb + ub) / two; - if !(f(lb)..=f(ub)).contains(&f(mid)) { - // if f found to not be monotone on the interval + let mid = (lb.clone() + ub.clone()) / two.clone(); + if !(f(&lb)..=f(&ub)).contains(&f(&mid)) { + // if f found not monotone on the interval return None; - } else if (lb..=lb + K::one()).contains(&ub) { - // if ub \in [lb, lb+1] + } else if f(&lb) == z { + return Some(lb); + } else if f(&ub) == z { return Some(ub); - } else if f(mid) >= z { - // implies mid >= z + } else if (lb.clone() + K::one()) == ub { + // no more elements to search + return Some(ub); + } else if f(&mid) >= z { ub = mid; } else { lb = mid; diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index 384dea64..56deb09a 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -2,7 +2,8 @@ //! and provides //! concrete implementations for a variety of distributions. use super::statistics::{Max, Min}; -use ::num_traits::{Float, PrimInt}; +use ::num_traits::{Bounded, Float, Num}; +use num_traits::{NumAssign, NumAssignOps, NumAssignRef}; pub use self::bernoulli::Bernoulli; pub use self::beta::Beta; @@ -145,7 +146,9 @@ pub trait ContinuousCDF: Min + Max { /// The `DiscreteCDF` trait is used to specify an interface for univariate /// discrete distributions. -pub trait DiscreteCDF: Min + Max { +pub trait DiscreteCDF: + Min + Max +{ /// Returns the cumulative distribution function calculated /// at `x` for a given distribution. May panic depending /// on the implementor. @@ -178,27 +181,25 @@ pub trait DiscreteCDF: Min + Max { /// Due to issues with rounding and floating-point accuracy the default implementation may be ill-behaved /// Specialized inverse cdfs should be used whenever possible. /// - /// # Panics, for the default impl - /// - /// If `x <= 0.0` or `x >= 1.0` + /// # Panics + /// this default impl panics if provided `p` not on interval [0.0, 1.0] fn inverse_cdf(&self, p: T) -> K { - // TODO: fix integer implementation if p == T::zero() { return self.min(); } else if p == T::one() { return self.max(); } else if !(T::zero()..=T::one()).contains(&p) { - panic!("p must be in [0, 1]") + panic!("p must be on [0, 1]") } let two = K::one() + K::one(); - let mut high = two; - let low = self.min(); - while self.cdf(high) < p { - high = two * high; + let mut ub = two.clone(); + let lb = self.min(); + while self.cdf(ub.clone()) < p { + ub *= two.clone(); } - internal::integral_bisection_search(|p| self.cdf(p), p, low, high).unwrap() + internal::integral_bisection_search(|p| self.cdf(p.clone()), p, lb, ub).unwrap() } } From 30b60c36f4428df57391ec913c7a67912fd5837a Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Thu, 23 May 2024 18:01:06 -0500 Subject: [PATCH 038/185] test: integer_bisection_search is tested generally instead of within binomial --- src/distribution/binomial.rs | 10 ---------- src/distribution/internal.rs | 24 +++++++++++++++++++++++- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index d9c1bf3c..38499617 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -572,16 +572,6 @@ mod tests { test_case(0.5, 6, 4, invcdf(0.75)); } - #[test] - fn test_inverse_cdf_is_infimum() { - let invcdf = |arg: f64| move |x: Binomial| x.inverse_cdf(arg); - let a = 0.2592; - let b = 0.3456; - test_case(0.4, 5, 1, invcdf(a)); - test_case(0.4, 5, 1, invcdf((a+b)/2.0)); - test_case(0.4, 5, 2, invcdf(b)); - } - #[test] fn test_cdf_inverse_cdf() { let cdf_invcdf = |arg: u64| move |x: Binomial| x.inverse_cdf(x.cdf(arg)); diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index a045c9f9..f88c5f62 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -54,7 +54,7 @@ pub fn integral_bisection_search( #[macro_use] #[cfg(all(test, feature = "nightly"))] pub mod test { - use super::is_valid_multinomial; + use super::*; use crate::consts::ACC; use crate::distribution::{Continuous, ContinuousCDF, Discrete, DiscreteCDF}; @@ -235,4 +235,26 @@ pub mod test { let invalid = [5.2, 0.0, 1e-15, 1000000.12]; assert!(!is_valid_multinomial(&invalid, false)); } + + #[test] + fn test_integer_bisection() { + fn search(z: usize, data: &Vec) -> Option { + integral_bisection_search(|idx: &usize| data[*idx], z, 0, data.len() - 1) + } + + let needle = 3; + let data = (0..5) + .map(|n| if n >= needle { n + 1 } else { n }) + .collect::>(); + + for i in 0..(data.len()) { + assert_eq!(search(data[i], &data), Some(i),) + } + { + let infimum = search(needle, &data); + let found_element = search(needle + 1, &data); // 4 > needle && member of range + assert_eq!(found_element, Some(needle)); + assert_eq!(infimum, found_element) + } + } } From 6a25feac74ad5049d3137a427c5f8d241b7f06ec Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Thu, 23 May 2024 18:12:02 -0500 Subject: [PATCH 039/185] test: test cases in issue #185 for Binomial inverse_cdf --- src/distribution/binomial.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index 38499617..2b4a3ea3 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -569,6 +569,9 @@ mod tests { fn test_inverse_cdf() { let invcdf = |arg: f64| move |x: Binomial| x.inverse_cdf(arg); test_case(0.4, 5, 2, invcdf(0.3456)); + + // cases in issue #185 + test_case(0.018, 465, 1, invcdf(3.472e-4)); test_case(0.5, 6, 4, invcdf(0.75)); } From f845bb2de9aba6e6f677e9d2426c6447a0b7f0a8 Mon Sep 17 00:00:00 2001 From: Joey Sacchini Date: Fri, 22 Jan 2021 13:56:04 -0500 Subject: [PATCH 040/185] implement and unit-test specialized inverse_cdf for LogNormal distribution --- src/distribution/log_normal.rs | 142 ++++++++++++++++++++++++++------- 1 file changed, 114 insertions(+), 28 deletions(-) diff --git a/src/distribution/log_normal.rs b/src/distribution/log_normal.rs index 028869e0..f2a3de9a 100644 --- a/src/distribution/log_normal.rs +++ b/src/distribution/log_normal.rs @@ -113,6 +113,32 @@ impl ContinuousCDF for LogNormal { 0.5 * erf::erfc((x.ln() - self.location) / (self.scale * f64::consts::SQRT_2)) } } + /// Calculates the inverse cumulative distribution function for the + /// log-normal distribution at `p` + /// + /// # Panics + /// + /// If `p < 0.0` or `p > 1.0` + /// + /// # Formula + /// + /// ```ignore + /// μ - σ * sqrt(2) * erfc_inv(2p) + /// ``` + /// + /// where `μ` is the location, `σ` is the scale and `erfc_inv` is + /// the inverse of the complementary error function + fn inverse_cdf(&self, p: f64) -> f64 { + if p == 0.0 { + 0.0 + } else if p < 1.0 { + (self.location - (self.scale * f64::consts::SQRT_2 * erf::erfc_inv(2.0 * p))).exp() + } else if p == 1.0 { + f64::INFINITY + } else { + panic!("p must be within [0.0, 1.0]"); + } + } } impl Min for LogNormal { @@ -590,34 +616,94 @@ mod tests { #[test] fn test_cdf() { - let cdf = |arg: f64| move |x: LogNormal| x.cdf(arg); - test_almost(-0.1, 0.1, 0.0, 1e-107, cdf(0.1)); - test_almost(-0.1, 0.1, 0.0000000015011556178148777579869633555518882664666520593658, 1e-19, cdf(0.5)); - test_almost(-0.1, 0.1, 0.10908001076375810900224507908874442583171381706127, 1e-11, cdf(0.8)); - test_almost(-0.1, 1.5, 0.070999149762464508991968731574953594549291668468349, 1e-11, cdf(0.1)); - test_case(-0.1, 1.5, 0.34626224992888089297789445771047690175505847991946, cdf(0.5)); - test_case(-0.1, 1.5, 0.46728530589487698517090261668589508746353129242404, cdf(0.8)); - test_almost(-0.1, 2.5, 0.18914969879695093477606645992572208111152994999076, 1e-10, cdf(0.1)); - test_case(-0.1, 2.5, 0.40622798321378106125020505907901206714868922279347, cdf(0.5)); - test_case(-0.1, 2.5, 0.48035707589956665425068652807400957345208517749893, cdf(0.8)); - test_almost(1.5, 0.1, 0.0, 1e-315, cdf(0.1)); - test_almost(1.5, 0.1, 0.0, 1e-106, cdf(0.5)); - test_almost(1.5, 0.1, 0.0, 1e-66, cdf(0.8)); - test_almost(1.5, 1.5, 0.005621455876973168709588070988239748831823850202953, 1e-12, cdf(0.1)); - test_almost(1.5, 1.5, 0.07185716187918271235246980951571040808235628115265, 1e-11, cdf(0.5)); - test_almost(1.5, 1.5, 0.12532699044614938400496547188720940854423187977236, 1e-11, cdf(0.8)); - test_almost(1.5, 2.5, 0.064125647996943514411570834861724406903677144126117, 1e-11, cdf(0.1)); - test_almost(1.5, 2.5, 0.19017302281590810871719754032332631806011441356498, 1e-10, cdf(0.5)); - test_almost(1.5, 2.5, 0.24533064397555500690927047163085419096928289095201, 1e-16, cdf(0.8)); - test_case(2.5, 0.1, 0.0, cdf(0.1)); - test_almost(2.5, 0.1, 0.0, 1e-223, cdf(0.5)); - test_almost(2.5, 0.1, 0.0, 1e-162, cdf(0.8)); - test_almost(2.5, 1.5, 0.00068304052220788502001572635016579586444611070077399, 1e-13, cdf(0.1)); - test_almost(2.5, 1.5, 0.016636862816580533038130583128179878924863968664206, 1e-12, cdf(0.5)); - test_almost(2.5, 1.5, 0.034729001282904174941366974418836262996834852343018, 1e-11, cdf(0.8)); - test_almost(2.5, 2.5, 0.027363708266690978870139978537188410215717307180775, 1e-11, cdf(0.1)); - test_almost(2.5, 2.5, 0.10075543423327634536450625420610429181921642201567, 1e-11, cdf(0.5)); - test_almost(2.5, 2.5, 0.13802019192453118732001307556787218421918336849121, 1e-11, cdf(0.8)); + cdf_tests(false); + } + + #[test] + fn test_inverse_cdf() { + cdf_tests(true) + } + + // we can reuse the (input, output) pairs from the CDF unit test + // and verify that passing an 'output' to .inverse_cdf gives 'input', + // except in cases where output would be 0.0 (the inverse_cdf is defined to + // always give 0.0 in this case). + fn cdf_tests(inverse: bool) { + let f = |arg: f64| move |x: LogNormal| if inverse { + x.inverse_cdf(arg) + } else { + x.cdf(arg) + }; + + // given some cdf_input and cdf_output, returns a tuple (input, output) where + // input is what we will provide to cdf/inverse_cdf, and output is expected return + // value + let arrange_input_output = |cdf_input: f64, cdf_output: f64| { + if inverse { + (cdf_output, cdf_input) + } else { + (cdf_input, cdf_output) + } + }; + + // calls test_almost after re-arranging the input/output arguments and calling f with input + let almost = |mean: f64, std_dev: f64, cdf_input: f64, cdf_output: f64, acc: f64| { + let (input, output) = arrange_input_output(cdf_input, cdf_output); + test_almost(mean, std_dev, output, acc, f(input)); + }; + + // calls test_case after re-arranging the input/output arguments and calling f with input + let case = |mean: f64, std_dev: f64, cdf_input: f64, cdf_output: f64| { + let (input, output) = arrange_input_output(cdf_input, cdf_output); + test_case(mean, std_dev, output, f(input)); + }; + + // we skip cases where the CDF outputs 0.0 when testing the inverse CDF because + // there are multiple inputs to the CDF which give an answer of 0.0, therefore testing whether + // inputting 0.0 to the inverse cdf will give the same answer is not a valid test + // the inverse cdf for log-normal is defined to give answer 0.0 for input 0.0 + if inverse { + case(-0.1, 0.1, 0.0, 0.0); + } + + if !inverse { + almost(-0.1, 0.1, 0.1, 0.0, 1e-107); + } + almost(-0.1, 0.1, 0.5, 0.0000000015011556178148777579869633555518882664666520593658, 1e-16); + almost(-0.1, 0.1, 0.8, 0.10908001076375810900224507908874442583171381706127, 1e-11); + almost(-0.1, 1.5, 0.1, 0.070999149762464508991968731574953594549291668468349, 1e-11); + case(-0.1, 1.5, 0.5, 0.34626224992888089297789445771047690175505847991946); + case(-0.1, 1.5, 0.8, 0.46728530589487698517090261668589508746353129242404); + almost(-0.1, 2.5, 0.1, 0.18914969879695093477606645992572208111152994999076, 1e-10); + case(-0.1, 2.5, 0.5, 0.40622798321378106125020505907901206714868922279347); + case(-0.1, 2.5, 0.8, 0.48035707589956665425068652807400957345208517749893); + + // input to inverse would be 0.0 + if !inverse { + almost(1.5, 0.1, 0.1, 0.0, 1e-315); + almost(1.5, 0.1, 0.5, 0.0, 1e-106); + almost(1.5, 0.1, 0.8, 0.0, 1e-66); + } + + almost(1.5, 1.5, 0.1, 0.005621455876973168709588070988239748831823850202953, 1e-12); + almost(1.5, 1.5, 0.8, 0.12532699044614938400496547188720940854423187977236, 1e-11); + almost(1.5, 2.5, 0.1, 0.064125647996943514411570834861724406903677144126117, 1e-11); + almost(1.5, 2.5, 0.5, 0.19017302281590810871719754032332631806011441356498, 1e-10); + almost(1.5, 2.5, 0.8, 0.24533064397555500690927047163085419096928289095201, 1e-16); + + // input to inverse would be 0.0 + if !inverse { + case(2.5, 0.1, 0.1, 0.0); + almost(2.5, 0.1, 0.5, 0.0, 1e-223); + almost(2.5, 0.1, 0.8, 0.0, 1e-162); + } + + almost(2.5, 1.5, 0.1, 0.00068304052220788502001572635016579586444611070077399, 1e-13); + almost(2.5, 1.5, 0.5, 0.016636862816580533038130583128179878924863968664206, 1e-12); + almost(2.5, 1.5, 0.8, 0.034729001282904174941366974418836262996834852343018, 1e-11); + almost(2.5, 2.5, 0.1, 0.027363708266690978870139978537188410215717307180775, 1e-11); + almost(2.5, 2.5, 0.5, 0.10075543423327634536450625420610429181921642201567, 1e-11); + almost(2.5, 2.5, 0.8, 0.13802019192453118732001307556787218421918336849121, 1e-11); } #[test] From 7fa7d7ea8d95d62328c6cb58472134b2dcd8c122 Mon Sep 17 00:00:00 2001 From: Rik Huijzer Date: Mon, 24 Oct 2022 12:05:47 +0200 Subject: [PATCH 041/185] Mention quantile function in docstring --- src/distribution/normal.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index bcd01b81..ba7a408c 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -100,7 +100,8 @@ impl ContinuousCDF for Normal { } /// Calculates the inverse cumulative distribution function for the - /// normal distribution at `x` + /// normal distribution at `x`. + /// In other languages, such as R, this is known as the the quantile function. /// /// # Panics /// From a8d29e743d945194600f5dc382fdc2164ec321d5 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Thu, 2 May 2024 17:57:15 -0500 Subject: [PATCH 042/185] feat: implement newton raphson search for ::inverse_cdf uses bisection prior to NR search --- src/distribution/gamma.rs | 51 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index b8e7df44..cd9a8afa 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -146,6 +146,57 @@ impl ContinuousCDF for Gamma { gamma::gamma_ur(self.shape, x * self.rate) } } + + fn inverse_cdf(&self, p: f64) -> f64 { + fn convergence(x: &mut f64, x_new: f64) -> bool { + let out = approx::relative_eq!(*x, x_new, max_relative = crate::consts::ACC); + *x = x_new; + out + } + + const MAX_ITERS: (u16, u16) = (8, 4); + if !(0.0..=1.0).contains(&p) { + panic!("default inverse_cdf implementation should be provided probability on [0,1]") + } + if p == 0.0 { + return self.min(); + }; + if p == 1.0 { + return self.max(); + }; + + // Bisection search for MAX_ITERS.0 iterations + let mut high = 2.0; + let mut low = 1.0; + while self.cdf(low) > p { + low /= 2.0; + } + while self.cdf(high) < p { + high *= 2.0; + } + let mut x_0 = (high + low) / 2.0; + + for _ in 0..MAX_ITERS.0 { + if self.cdf(x_0) >= p { + high = x_0; + } else { + low = x_0; + } + if convergence(&mut x_0, (high + low) / 2.0) { + break; + } + } + + // NR method, guarantee at least one step + for _ in 0..MAX_ITERS.1 { + let x_next = x_0 - (self.cdf(x_0) - p) / self.pdf(x_0); + if convergence(&mut x_0, x_next) { + break; + } + } + + x_0 + } } impl Min for Gamma { From 0a7b6a1f4f314523b0b7f18f061e60d0a41d2859 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sat, 4 May 2024 10:43:11 -0500 Subject: [PATCH 043/185] chore: move helper for convergence into `prec` --- src/distribution/gamma.rs | 11 +++-------- src/prec.rs | 9 +++++++++ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index cd9a8afa..863d41cf 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -1,5 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; +use crate::prec; use crate::statistics::*; use crate::{Result, StatsError}; use rand::Rng; @@ -148,12 +149,6 @@ impl ContinuousCDF for Gamma { } fn inverse_cdf(&self, p: f64) -> f64 { - fn convergence(x: &mut f64, x_new: f64) -> bool { - let out = approx::relative_eq!(*x, x_new, max_relative = crate::consts::ACC); - *x = x_new; - out - } - const MAX_ITERS: (u16, u16) = (8, 4); if !(0.0..=1.0).contains(&p) { panic!("default inverse_cdf implementation should be provided probability on [0,1]") @@ -182,7 +177,7 @@ impl ContinuousCDF for Gamma { } else { low = x_0; } - if convergence(&mut x_0, (high + low) / 2.0) { + if prec::convergence(&mut x_0, (high + low) / 2.0) { break; } } @@ -190,7 +185,7 @@ impl ContinuousCDF for Gamma { // NR method, guarantee at least one step for _ in 0..MAX_ITERS.1 { let x_next = x_0 - (self.cdf(x_0) - p) / self.pdf(x_0); - if convergence(&mut x_0, x_next) { + if prec::convergence(&mut x_0, x_next) { break; } } diff --git a/src/prec.rs b/src/prec.rs index 59ad3714..042c8b22 100644 --- a/src/prec.rs +++ b/src/prec.rs @@ -25,3 +25,12 @@ pub fn almost_eq(a: f64, b: f64, acc: f64) -> bool { (a - b).abs() < acc } + +/// Compares if two floats are close via `approx::relative_eq!` +/// and `crate::consts::ACC` relative precision. +/// Updates first argument to value of second argument +pub fn convergence(x: &mut f64, x_new: f64) -> bool { + let res = approx::relative_eq!(*x, x_new, max_relative = crate::consts::ACC); + *x = x_new; + res +} From daeea6c1891b3f85965b197b03b2166729bb8c99 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sat, 4 May 2024 10:44:46 -0500 Subject: [PATCH 044/185] test: cdf(inverse_cdf(p)) ~ p for `::inverse_cdf` --- src/distribution/gamma.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index 863d41cf..19eb4780 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -608,6 +608,25 @@ mod tests { test_case((1.0, 0.1), 0.0, |x| x.cdf(0.0)); } + #[test] + fn test_cdf_inverse_identity() { + let f = |p: f64| move |g: Gamma| g.cdf(g.inverse_cdf(p)); + let params = [ + (1.0, 0.1), + (1.0, 1.0), + (10.0, 10.0), + (10.0, 1.0), + (100.0, 200.0), + ]; + + for param in params { + for n in -5..0 { + let p = 10.0f64.powi(n); + test_case(param, p, f(p)); + } + } + } + #[test] fn test_sf() { let f = |arg: f64| move |x: Gamma| x.sf(arg); From e46d123f7f3de3bf70196c7868793c5a03108e10 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Thu, 23 May 2024 19:48:03 -0500 Subject: [PATCH 045/185] test: add test case from #200 --- src/distribution/gamma.rs | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index 19eb4780..d1abf0c2 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -133,17 +133,13 @@ impl ContinuousCDF for Gamma { fn sf(&self, x: f64) -> f64 { if x <= 0.0 { 1.0 - } - else if ulps_eq!(x, self.shape) && self.rate.is_infinite() { + } else if ulps_eq!(x, self.shape) && self.rate.is_infinite() { 0.0 - } - else if self.rate.is_infinite() { + } else if self.rate.is_infinite() { 1.0 - } - else if x.is_infinite() { + } else if x.is_infinite() { 0.0 - } - else { + } else { gamma::gamma_ur(self.shape, x * self.rate) } } @@ -502,7 +498,11 @@ mod tests { for &(arg, res) in test.iter() { test_case_special(arg, res, 10e-6, f); } - let test = [((10.0, 10.0), 0.9), ((10.0, 1.0), 9.0), ((10.0, f64::INFINITY), 0.0)]; + let test = [ + ((10.0, 10.0), 0.9), + ((10.0, 1.0), 9.0), + ((10.0, f64::INFINITY), 0.0), + ]; for &(arg, res) in test.iter() { test_case(arg, res, f); } @@ -625,6 +625,13 @@ mod tests { test_case(param, p, f(p)); } } + + // test case from issue #200 + { + let x = 20.5567; + let f = |x: f64| move |g: Gamma| g.inverse_cdf(g.cdf(x)); + test_case((3.0, 0.5), x, f(x)) + } } #[test] From 6699215aac56a5ce2530ca7520fc0a2bce854448 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Thu, 23 May 2024 20:03:12 -0500 Subject: [PATCH 046/185] style: use literals with comments instead of const --- src/distribution/gamma.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index d1abf0c2..3fa926b0 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -145,7 +145,6 @@ impl ContinuousCDF for Gamma { } fn inverse_cdf(&self, p: f64) -> f64 { - const MAX_ITERS: (u16, u16) = (8, 4); if !(0.0..=1.0).contains(&p) { panic!("default inverse_cdf implementation should be provided probability on [0,1]") } @@ -167,7 +166,7 @@ impl ContinuousCDF for Gamma { } let mut x_0 = (high + low) / 2.0; - for _ in 0..MAX_ITERS.0 { + for _ in 0..8 { if self.cdf(x_0) >= p { high = x_0; } else { @@ -178,8 +177,8 @@ impl ContinuousCDF for Gamma { } } - // NR method, guarantee at least one step - for _ in 0..MAX_ITERS.1 { + // Newton Raphson, for at least one step + for _ in 0..4 { let x_next = x_0 - (self.cdf(x_0) - p) / self.pdf(x_0); if prec::convergence(&mut x_0, x_next) { break; From 5281589331f041a658180fcd73875512fbd569c8 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Mon, 22 Apr 2024 20:28:19 +0200 Subject: [PATCH 047/185] Apply whitespace fixes --- src/distribution/bernoulli.rs | 5 ++++- src/distribution/beta.rs | 5 ++++- src/distribution/binomial.rs | 3 +++ src/distribution/categorical.rs | 8 +++++--- src/distribution/chi.rs | 3 +++ src/distribution/chi_squared.rs | 3 +++ src/distribution/dirac.rs | 5 +++-- src/distribution/dirichlet.rs | 1 + src/distribution/discrete_uniform.rs | 5 ++++- src/distribution/empirical.rs | 12 ++++++++---- src/distribution/erlang.rs | 3 +++ src/distribution/exponential.rs | 3 +++ src/distribution/fisher_snedecor.rs | 6 ++++-- src/distribution/gamma.rs | 7 +++++-- src/distribution/geometric.rs | 3 +++ src/distribution/hypergeometric.rs | 2 ++ src/distribution/inverse_gamma.rs | 3 +++ src/distribution/laplace.rs | 3 +++ src/distribution/log_normal.rs | 3 +++ src/distribution/multivariate_normal.rs | 2 ++ src/distribution/negative_binomial.rs | 2 ++ src/distribution/normal.rs | 3 +++ src/distribution/pareto.rs | 3 +++ src/distribution/poisson.rs | 3 +++ src/distribution/students_t.rs | 4 +++- src/distribution/triangular.rs | 3 +++ src/distribution/uniform.rs | 3 +++ src/distribution/weibull.rs | 3 +++ src/function/beta.rs | 1 - src/statistics/slice_statistics.rs | 7 +++++++ 30 files changed, 99 insertions(+), 18 deletions(-) diff --git a/src/distribution/bernoulli.rs b/src/distribution/bernoulli.rs index 46648a7a..0af2f101 100644 --- a/src/distribution/bernoulli.rs +++ b/src/distribution/bernoulli.rs @@ -101,7 +101,7 @@ impl DiscreteCDF for Bernoulli { self.b.cdf(x) } - /// Calculates the survival function for the + /// Calculates the survival function for the /// bernoulli distribution at `x`. /// /// # Formula @@ -158,6 +158,7 @@ impl Distribution for Bernoulli { fn mean(&self) -> Option { self.b.mean() } + /// Returns the variance of the bernoulli /// distribution /// @@ -169,6 +170,7 @@ impl Distribution for Bernoulli { fn variance(&self) -> Option { self.b.variance() } + /// Returns the entropy of the bernoulli /// distribution /// @@ -181,6 +183,7 @@ impl Distribution for Bernoulli { fn entropy(&self) -> Option { self.b.entropy() } + /// Returns the skewness of the bernoulli /// distribution /// diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index 4682948c..39d03598 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -155,7 +155,7 @@ impl ContinuousCDF for Beta { } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) { 1. - x } else { - beta::beta_reg(self.shape_b, self.shape_a, 1.0 - x) + beta::beta_reg(self.shape_b, self.shape_a, 1.0 - x) } } } @@ -208,6 +208,7 @@ impl Distribution for Beta { }; Some(mean) } + /// Returns the variance of the beta distribution /// /// # Remarks @@ -230,6 +231,7 @@ impl Distribution for Beta { }; Some(var) } + /// Returns the entropy of the beta distribution /// /// # Formula @@ -251,6 +253,7 @@ impl Distribution for Beta { }; Some(entr) } + /// Returns the skewness of the Beta distribution /// /// # Formula diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index 2b4a3ea3..f65bf246 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -181,6 +181,7 @@ impl Distribution for Binomial { fn mean(&self) -> Option { Some(self.p * self.n as f64) } + /// Returns the variance of the binomial distribution /// /// # Formula @@ -191,6 +192,7 @@ impl Distribution for Binomial { fn variance(&self) -> Option { Some(self.p * (1.0 - self.p) * self.n as f64) } + /// Returns the entropy of the binomial distribution /// /// # Formula @@ -209,6 +211,7 @@ impl Distribution for Binomial { }; Some(entr) } + /// Returns the skewness of the binomial distribution /// /// # Formula diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index ba3c39de..dfd8a4ca 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -12,7 +12,7 @@ use std::f64; /// # Examples /// /// ``` -/// +/// /// use statrs::distribution::{Categorical, Discrete}; /// use statrs::statistics::Distribution; /// use statrs::prec; @@ -25,7 +25,7 @@ use std::f64; pub struct Categorical { norm_pmf: Vec, cdf: Vec, - sf: Vec + sf: Vec, } impl Categorical { @@ -194,6 +194,7 @@ impl Distribution for Categorical { .fold(0.0, |acc, (idx, &val)| acc + idx as f64 * val), ) } + /// Returns the variance of the categorical distribution /// /// # Formula @@ -217,6 +218,7 @@ impl Distribution for Categorical { }); Some(var) } + /// Returns the entropy of the categorical distribution /// /// # Formula @@ -294,7 +296,7 @@ pub fn prob_mass_to_cdf(prob_mass: &[f64]) -> Vec { cdf } -/// Computes the sf from the given cumulative densities. +/// Computes the sf from the given cumulative densities. /// Performs no parameter or bounds checking. pub fn cdf_to_sf(cdf: &[f64]) -> Vec { let max = *cdf.last().unwrap(); diff --git a/src/distribution/chi.rs b/src/distribution/chi.rs index f72376bb..2ca32518 100644 --- a/src/distribution/chi.rs +++ b/src/distribution/chi.rs @@ -185,6 +185,7 @@ impl Distribution for Chi { Some(mean) } } + /// Returns the variance of the chi distribution /// /// # Remarks @@ -203,6 +204,7 @@ impl Distribution for Chi { let mean = self.mean()?; Some(self.freedom - mean * mean) } + /// Returns the entropy of the chi distribution /// /// # Remarks @@ -228,6 +230,7 @@ impl Distribution for Chi { / 2.0; Some(entr) } + /// Returns the skewness of the chi distribution /// /// # Remarks diff --git a/src/distribution/chi_squared.rs b/src/distribution/chi_squared.rs index ab8dc398..cf07f4cf 100644 --- a/src/distribution/chi_squared.rs +++ b/src/distribution/chi_squared.rs @@ -177,6 +177,7 @@ impl Distribution for ChiSquared { fn mean(&self) -> Option { self.g.mean() } + /// Returns the variance of the chi-squared distribution /// /// # Formula @@ -189,6 +190,7 @@ impl Distribution for ChiSquared { fn variance(&self) -> Option { self.g.variance() } + /// Returns the entropy of the chi-squared distribution /// /// # Formula @@ -202,6 +204,7 @@ impl Distribution for ChiSquared { fn entropy(&self) -> Option { self.g.entropy() } + /// Returns the skewness of the chi-squared distribution /// /// # Formula diff --git a/src/distribution/dirac.rs b/src/distribution/dirac.rs index daa081a2..9a66e5e0 100644 --- a/src/distribution/dirac.rs +++ b/src/distribution/dirac.rs @@ -56,7 +56,6 @@ impl ContinuousCDF for Dirac { /// dirac distribution at `x` /// /// Where the value is 1 if x > `v`, 0 otherwise. - /// fn cdf(&self, x: f64) -> f64 { if x < self.0 { 0.0 @@ -69,7 +68,6 @@ impl ContinuousCDF for Dirac { /// dirac distribution at `x` /// /// Where the value is 0 if x > `v`, 1 otherwise. - /// fn sf(&self, x: f64) -> f64 { if x < self.0 { 1.0 @@ -117,6 +115,7 @@ impl Distribution for Dirac { fn mean(&self) -> Option { Some(self.0) } + /// Returns the variance of the dirac distribution /// /// # Formula @@ -129,6 +128,7 @@ impl Distribution for Dirac { fn variance(&self) -> Option { Some(0.0) } + /// Returns the entropy of the dirac distribution /// /// # Formula @@ -141,6 +141,7 @@ impl Distribution for Dirac { fn entropy(&self) -> Option { Some(0.0) } + /// Returns the skewness of the dirac distribution /// /// # Formula diff --git a/src/distribution/dirichlet.rs b/src/distribution/dirichlet.rs index a08b3175..0f703056 100644 --- a/src/distribution/dirichlet.rs +++ b/src/distribution/dirichlet.rs @@ -107,6 +107,7 @@ impl Dirichlet { fn alpha_sum(&self) -> f64 { self.alpha.fold(0.0, |acc, x| acc + x) } + /// Returns the entropy of the dirichlet distribution /// /// # Formula diff --git a/src/distribution/discrete_uniform.rs b/src/distribution/discrete_uniform.rs index 926b1cf3..59b0da4e 100644 --- a/src/distribution/discrete_uniform.rs +++ b/src/distribution/discrete_uniform.rs @@ -84,7 +84,7 @@ impl DiscreteCDF for DiscreteUniform { } fn sf(&self, x: i64) -> f64 { - //1. - self.cdf(x) + // 1. - self.cdf(x) if x < self.min { 1.0 } else if x >= self.max { @@ -137,6 +137,7 @@ impl Distribution for DiscreteUniform { fn mean(&self) -> Option { Some((self.min + self.max) as f64 / 2.0) } + /// Returns the variance of the discrete uniform distribution /// /// # Formula @@ -148,6 +149,7 @@ impl Distribution for DiscreteUniform { let diff = (self.max - self.min) as f64; Some(((diff + 1.0) * (diff + 1.0) - 1.0) / 12.0) } + /// Returns the entropy of the discrete uniform distribution /// /// # Formula @@ -159,6 +161,7 @@ impl Distribution for DiscreteUniform { let diff = (self.max - self.min) as f64; Some((diff + 1.0).ln()) } + /// Returns the skewness of the discrete uniform distribution /// /// # Formula diff --git a/src/distribution/empirical.rs b/src/distribution/empirical.rs index 43588819..0e8c964a 100644 --- a/src/distribution/empirical.rs +++ b/src/distribution/empirical.rs @@ -56,7 +56,6 @@ impl Empirical { /// /// let mut result = Empirical::new(); /// assert!(result.is_ok()); - /// /// ``` pub fn new() -> Result { Ok(Empirical { @@ -65,6 +64,7 @@ impl Empirical { data: BTreeMap::new(), }) } + pub fn from_vec(src: Vec) -> Empirical { let mut empirical = Empirical::new().unwrap(); for elt in src.into_iter() { @@ -72,6 +72,7 @@ impl Empirical { } empirical } + pub fn add(&mut self, data_point: f64) { if !data_point.is_nan() { self.sum += 1.; @@ -89,6 +90,7 @@ impl Empirical { *self.data.entry(NonNan(data_point)).or_insert(0) += 1; } } + pub fn remove(&mut self, data_point: f64) { if !data_point.is_nan() { if let (Some(val), Some((mean, var))) = @@ -111,6 +113,7 @@ impl Empirical { } } } + // Due to issues with rounding and floating-point accuracy the default // implementation may be ill-behaved. // Specialized inverse cdfs should be used whenever possible. @@ -158,7 +161,7 @@ impl ::rand::distributions::Distribution for Empirical { /// Panics if number of samples is zero impl Max for Empirical { fn max(&self) -> f64 { - self.data.keys().rev().map(|key| key.0) .next().unwrap() + self.data.keys().rev().map(|key| key.0).next().unwrap() } } @@ -173,6 +176,7 @@ impl Distribution for Empirical { fn mean(&self) -> Option { self.mean_and_var.map(|(mean, _)| mean) } + fn variance(&self) -> Option { self.mean_and_var.map(|(_, var)| var / (self.sum - 1.)) } @@ -256,8 +260,8 @@ mod tests { let unchanged = empirical.clone(); empirical.add(2.0); empirical.remove(2.0); - //because of rounding errors, this doesn't hold in general - //due to the mean and variance being calculated in a streaming way + // because of rounding errors, this doesn't hold in general + // due to the mean and variance being calculated in a streaming way assert_eq!(unchanged, empirical); } } diff --git a/src/distribution/erlang.rs b/src/distribution/erlang.rs index e0721f24..c959e122 100644 --- a/src/distribution/erlang.rs +++ b/src/distribution/erlang.rs @@ -166,6 +166,7 @@ impl Distribution for Erlang { fn mean(&self) -> Option { self.g.mean() } + /// Returns the variance of the erlang distribution /// /// # Formula @@ -178,6 +179,7 @@ impl Distribution for Erlang { fn variance(&self) -> Option { self.g.variance() } + /// Returns the entropy of the erlang distribution /// /// # Formula @@ -191,6 +193,7 @@ impl Distribution for Erlang { fn entropy(&self) -> Option { self.g.entropy() } + /// Returns the skewness of the erlang distribution /// /// # Formula diff --git a/src/distribution/exponential.rs b/src/distribution/exponential.rs index b52e1899..e0fe74dc 100644 --- a/src/distribution/exponential.rs +++ b/src/distribution/exponential.rs @@ -165,6 +165,7 @@ impl Distribution for Exp { fn mean(&self) -> Option { Some(1.0 / self.rate) } + /// Returns the variance of the exponential distribution /// /// # Formula @@ -177,6 +178,7 @@ impl Distribution for Exp { fn variance(&self) -> Option { Some(1.0 / (self.rate * self.rate)) } + /// Returns the entropy of the exponential distribution /// /// # Formula @@ -189,6 +191,7 @@ impl Distribution for Exp { fn entropy(&self) -> Option { Some(1.0 - self.rate.ln()) } + /// Returns the skewness of the exponential distribution /// /// # Formula diff --git a/src/distribution/fisher_snedecor.rs b/src/distribution/fisher_snedecor.rs index da8d9570..c5e2463a 100644 --- a/src/distribution/fisher_snedecor.rs +++ b/src/distribution/fisher_snedecor.rs @@ -144,8 +144,8 @@ impl ContinuousCDF for FisherSnedecor { } else { beta::beta_reg( self.freedom_2 / 2.0, - self.freedom_1 / 2.0, - 1. - ((self.freedom_1 * x) / (self.freedom_1 * x + self.freedom_2)) + self.freedom_1 / 2.0, + 1. - ((self.freedom_1 * x) / (self.freedom_1 * x + self.freedom_2)), ) } } @@ -206,6 +206,7 @@ impl Distribution for FisherSnedecor { Some(self.freedom_2 / (self.freedom_2 - 2.0)) } } + /// Returns the variance of the fisher-snedecor distribution /// /// # Panics @@ -237,6 +238,7 @@ impl Distribution for FisherSnedecor { Some(val) } } + /// Returns the skewness of the fisher-snedecor distribution /// /// # Panics diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index 3fa926b0..ee592a49 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -232,6 +232,7 @@ impl Distribution for Gamma { fn mean(&self) -> Option { Some(self.shape / self.rate) } + /// Returns the variance of the gamma distribution /// /// # Formula @@ -244,6 +245,7 @@ impl Distribution for Gamma { fn variance(&self) -> Option { Some(self.shape / (self.rate * self.rate)) } + /// Returns the entropy of the gamma distribution /// /// # Formula @@ -260,6 +262,7 @@ impl Distribution for Gamma { + (1.0 - self.shape) * gamma::digamma(self.shape); Some(entr) } + /// Returns the skewness of the gamma distribution /// /// # Formula @@ -549,9 +552,9 @@ mod tests { for &(arg, x, res) in test.iter() { test_case(arg, res, f(x)); } - //TODO: test special + // TODO: test special // test_is_nan((10.0, f64::INFINITY), pdf(1.0)); // is this really the behavior we want? - //TODO: test special + // TODO: test special // (10.0, f64::INFINITY, f64::INFINITY, 0.0, pdf(f64::INFINITY)),]; } diff --git a/src/distribution/geometric.rs b/src/distribution/geometric.rs index f87f5ee0..d61b7fc8 100644 --- a/src/distribution/geometric.rs +++ b/src/distribution/geometric.rs @@ -159,6 +159,7 @@ impl Distribution for Geometric { fn mean(&self) -> Option { Some(1.0 / self.p) } + /// Returns the standard deviation of the geometric distribution /// /// # Formula @@ -169,6 +170,7 @@ impl Distribution for Geometric { fn variance(&self) -> Option { Some((1.0 - self.p) / (self.p * self.p)) } + /// Returns the entropy of the geometric distribution /// /// # Formula @@ -180,6 +182,7 @@ impl Distribution for Geometric { let inv = 1.0 / self.p; Some(-inv * (1. - self.p).log(2.0) + (inv - 1.).log(2.0)) } + /// Returns the skewness of the geometric distribution /// /// # Formula diff --git a/src/distribution/hypergeometric.rs b/src/distribution/hypergeometric.rs index 95f44d18..1116ac7a 100644 --- a/src/distribution/hypergeometric.rs +++ b/src/distribution/hypergeometric.rs @@ -258,6 +258,7 @@ impl Distribution for Hypergeometric { Some(self.successes as f64 * self.draws as f64 / self.population as f64) } } + /// Returns the variance of the hypergeometric distribution /// /// # None @@ -281,6 +282,7 @@ impl Distribution for Hypergeometric { Some(val) } } + /// Returns the skewness of the hypergeometric distribution /// /// # None diff --git a/src/distribution/inverse_gamma.rs b/src/distribution/inverse_gamma.rs index 31b1d4f6..b55afd64 100644 --- a/src/distribution/inverse_gamma.rs +++ b/src/distribution/inverse_gamma.rs @@ -190,6 +190,7 @@ impl Distribution for InverseGamma { Some(self.rate / (self.shape - 1.0)) } } + /// Returns the variance of the inverse gamma distribution /// /// # None @@ -212,6 +213,7 @@ impl Distribution for InverseGamma { Some(val) } } + /// Returns the entropy of the inverse gamma distribution /// /// # Formula @@ -227,6 +229,7 @@ impl Distribution for InverseGamma { - (1.0 + self.shape) * gamma::digamma(self.shape); Some(entr) } + /// Returns the skewness of the inverse gamma distribution /// /// # None diff --git a/src/distribution/laplace.rs b/src/distribution/laplace.rs index 66893b46..d1ccc6da 100644 --- a/src/distribution/laplace.rs +++ b/src/distribution/laplace.rs @@ -193,6 +193,7 @@ impl Distribution for Laplace { fn mean(&self) -> Option { Some(self.location) } + /// Returns the variance of the laplace distribution /// /// # Formula @@ -205,6 +206,7 @@ impl Distribution for Laplace { fn variance(&self) -> Option { Some(2. * self.scale * self.scale) } + /// Returns the entropy of the laplace distribution /// /// # Formula @@ -217,6 +219,7 @@ impl Distribution for Laplace { fn entropy(&self) -> Option { Some((2. * self.scale).ln() + 1.) } + /// Returns the skewness of the laplace distribution /// /// # Formula diff --git a/src/distribution/log_normal.rs b/src/distribution/log_normal.rs index f2a3de9a..d991c020 100644 --- a/src/distribution/log_normal.rs +++ b/src/distribution/log_normal.rs @@ -182,6 +182,7 @@ impl Distribution for LogNormal { fn mean(&self) -> Option { Some((self.location + self.scale * self.scale / 2.0).exp()) } + /// Returns the variance of the log-normal distribution /// /// # Formula @@ -195,6 +196,7 @@ impl Distribution for LogNormal { let sigma2 = self.scale * self.scale; Some((sigma2.exp() - 1.0) * (self.location + self.location + sigma2).exp()) } + /// Returns the entropy of the log-normal distribution /// /// # Formula @@ -207,6 +209,7 @@ impl Distribution for LogNormal { fn entropy(&self) -> Option { Some(0.5 + self.scale.ln() + self.location + consts::LN_SQRT_2PI) } + /// Returns the skewness of the log-normal distribution /// /// # Formula diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index ff4ec6bc..368da0ec 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -205,6 +205,7 @@ impl<'a> Continuous<&'a DVector, f64> for MultivariateNormal { .unwrap(); self.pdf_const * exp_term.exp() } + /// Calculates the log probability density function for the multivariate /// normal distribution at `x`. Equivalent to pdf(x).ln(). fn ln_pdf(&self, x: &'a DVector) -> f64 { @@ -232,6 +233,7 @@ impl Continuous, f64> for MultivariateNormal { fn pdf(&self, x: Vec) -> f64 { self.pdf(&DVector::from(x)) } + /// Calculates the log probability density function for the multivariate /// normal distribution at `x`. Equivalent to pdf(x).ln(). fn ln_pdf(&self, x: Vec) -> f64 { diff --git a/src/distribution/negative_binomial.rs b/src/distribution/negative_binomial.rs index 4c69a869..a924ee8d 100644 --- a/src/distribution/negative_binomial.rs +++ b/src/distribution/negative_binomial.rs @@ -188,6 +188,7 @@ impl DiscreteDistribution for NegativeBinomial { fn mean(&self) -> Option { Some(self.r * (1.0 - self.p) / self.p) } + /// Returns the variance of the negative binomial distribution. /// /// # Formula @@ -198,6 +199,7 @@ impl DiscreteDistribution for NegativeBinomial { fn variance(&self) -> Option { Some(self.r * (1.0 - self.p) / (self.p * self.p)) } + /// Returns the skewness of the negative binomial distribution. /// /// # Formula diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index ba7a408c..5624a5c1 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -161,6 +161,7 @@ impl Distribution for Normal { fn mean(&self) -> Option { Some(self.mean) } + /// Returns the variance of the normal distribution /// /// # Formula @@ -173,6 +174,7 @@ impl Distribution for Normal { fn variance(&self) -> Option { Some(self.std_dev * self.std_dev) } + /// Returns the entropy of the normal distribution /// /// # Formula @@ -185,6 +187,7 @@ impl Distribution for Normal { fn entropy(&self) -> Option { Some(self.std_dev.ln() + consts::LN_SQRT_2PIE) } + /// Returns the skewness of the normal distribution /// /// # Formula diff --git a/src/distribution/pareto.rs b/src/distribution/pareto.rs index 031205eb..5a18c03f 100644 --- a/src/distribution/pareto.rs +++ b/src/distribution/pareto.rs @@ -188,6 +188,7 @@ impl Distribution for Pareto { Some((self.shape * self.scale) / (self.shape - 1.0)) } } + /// Returns the variance of the Pareto distribution /// /// # Formula @@ -209,6 +210,7 @@ impl Distribution for Pareto { Some(a * a * self.shape / (self.shape - 2.0)) } } + /// Returns the entropy for the Pareto distribution /// /// # Formula @@ -221,6 +223,7 @@ impl Distribution for Pareto { fn entropy(&self) -> Option { Some(self.shape.ln() - self.scale.ln() - (1.0 / self.shape) - 1.0) } + /// Returns the skewness of the Pareto distribution /// /// # Panics diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index e8f98bef..ce07ce96 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -148,6 +148,7 @@ impl Distribution for Poisson { fn mean(&self) -> Option { Some(self.lambda) } + /// Returns the variance of the poisson distribution /// /// # Formula @@ -160,6 +161,7 @@ impl Distribution for Poisson { fn variance(&self) -> Option { Some(self.lambda) } + /// Returns the entropy of the poisson distribution /// /// # Formula @@ -177,6 +179,7 @@ impl Distribution for Poisson { - 19.0 / (360.0 * self.lambda * self.lambda * self.lambda), ) } + /// Returns the skewness of the poisson distribution /// /// # Formula diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index 4f84c489..02dd092b 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -252,6 +252,7 @@ impl Distribution for StudentsT { Some(self.location) } } + /// Returns the variance of the student's t-distribution /// /// # None @@ -280,6 +281,7 @@ impl Distribution for StudentsT { None } } + /// Returns the entropy for the student's t-distribution /// /// # Formula @@ -301,6 +303,7 @@ impl Distribution for StudentsT { + (self.freedom.sqrt() * beta::beta(self.freedom / 2.0, 0.5)).ln(); Some(result + shift) } + /// Returns the skewness of the student's t-distribution /// /// # None @@ -598,7 +601,6 @@ mod tests { test_case((0.0, 1.0, f64::INFINITY), 0.977249868051821, cdf(2.0)); } - #[test] fn test_sf() { let sf = |arg: f64| move |x: StudentsT| x.sf(arg); diff --git a/src/distribution/triangular.rs b/src/distribution/triangular.rs index a9cb98a9..a94bb0bb 100644 --- a/src/distribution/triangular.rs +++ b/src/distribution/triangular.rs @@ -165,6 +165,7 @@ impl Distribution for Triangular { fn mean(&self) -> Option { Some((self.min + self.max + self.mode) / 3.0) } + /// Returns the variance of the triangular distribution /// /// # Formula @@ -178,6 +179,7 @@ impl Distribution for Triangular { let c = self.mode; Some((a * a + b * b + c * c - a * b - a * c - b * c) / 18.0) } + /// Returns the entropy of the triangular distribution /// /// # Formula @@ -188,6 +190,7 @@ impl Distribution for Triangular { fn entropy(&self) -> Option { Some(0.5 + ((self.max - self.min) / 2.0).ln()) } + /// Returns the skewness of the triangular distribution /// /// # Formula diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index c4abc985..9414222a 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -140,6 +140,7 @@ impl Distribution for Uniform { fn mean(&self) -> Option { Some((self.min + self.max) / 2.0) } + /// Returns the variance for the continuous uniform distribution /// /// # Formula @@ -150,6 +151,7 @@ impl Distribution for Uniform { fn variance(&self) -> Option { Some((self.max - self.min) * (self.max - self.min) / 12.0) } + /// Returns the entropy for the continuous uniform distribution /// /// # Formula @@ -160,6 +162,7 @@ impl Distribution for Uniform { fn entropy(&self) -> Option { Some((self.max - self.min).ln()) } + /// Returns the skewness for the continuous uniform distribution /// /// # Formula diff --git a/src/distribution/weibull.rs b/src/distribution/weibull.rs index 4f04403d..4b928aa7 100644 --- a/src/distribution/weibull.rs +++ b/src/distribution/weibull.rs @@ -177,6 +177,7 @@ impl Distribution for Weibull { fn mean(&self) -> Option { Some(self.scale * gamma::gamma(1.0 + 1.0 / self.shape)) } + /// Returns the variance of the weibull distribution /// /// # Formula @@ -191,6 +192,7 @@ impl Distribution for Weibull { let mean = self.mean()?; Some(self.scale * self.scale * gamma::gamma(1.0 + 2.0 / self.shape) - mean * mean) } + /// Returns the entropy of the weibull distribution /// /// # Formula @@ -207,6 +209,7 @@ impl Distribution for Weibull { + 1.0; Some(entr) } + /// Returns the skewness of the weibull distribution /// /// # Formula diff --git a/src/function/beta.rs b/src/function/beta.rs index fec184f2..128406c7 100644 --- a/src/function/beta.rs +++ b/src/function/beta.rs @@ -204,7 +204,6 @@ pub fn checked_beta_reg(a: f64, b: f64, x: f64) -> Result { } /// Computes the inverse of the regularized incomplete beta function -// // This code is based on the implementation in the ["special"][1] crate, // which in turn is based on a [C implementation][2] by John Burkardt. The // original algorithm was published in Applied Statistics and is known as diff --git a/src/statistics/slice_statistics.rs b/src/statistics/slice_statistics.rs index 1d1b79cc..a9cbfdde 100644 --- a/src/statistics/slice_statistics.rs +++ b/src/statistics/slice_statistics.rs @@ -7,6 +7,7 @@ pub struct Data(D); impl> Index for Data { type Output = f64; + fn index(&self, i: usize) -> &f64 { &self.0.as_ref()[i] } @@ -22,18 +23,23 @@ impl + AsRef<[f64]>> Data { pub fn new(data: D) -> Self { Data(data) } + pub fn swap(&mut self, i: usize, j: usize) { self.0.as_mut().swap(i, j) } + pub fn len(&self) -> usize { self.0.as_ref().len() } + pub fn is_empty(&self) -> bool { self.0.as_ref().len() == 0 } + pub fn iter(&self) -> core::slice::Iter<'_, f64> { self.0.as_ref().iter() } + // Selection algorithm from Numerical Recipes // See: https://en.wikipedia.org/wiki/Selection_algorithm fn select_inplace(&mut self, rank: usize) -> f64 { @@ -299,6 +305,7 @@ impl + AsRef<[f64]>> Distribution for Data { fn mean(&self) -> Option { Some(Statistics::mean(self.iter())) } + /// Estimates the unbiased population variance from the provided samples /// /// # Remarks From 142db1f17dc3d9ab39a0a70ae4114c865e5d901f Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Mon, 22 Apr 2024 20:36:12 +0200 Subject: [PATCH 048/185] Skip formatting on hand-formatted data --- src/distribution/students_t.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index 02dd092b..e8c183f3 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -1098,7 +1098,7 @@ mod tests { // for p in ps: // q = t.invcdf(p, df) // print(f"({p:5.3f}, {df:5.1f}, {float(q)}),") - // + #[rustfmt::skip] let invcdf_data = [ // p df inverse_cdf(p, df) (0.001, 1.0, -318.30883898555044), From 6f784813be68d97389e2eb3b68a2c62801783847 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Mon, 22 Apr 2024 20:38:47 +0200 Subject: [PATCH 049/185] Remove some unusual formatting rules --- rustfmt.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/rustfmt.toml b/rustfmt.toml index d1c82741..2f399e9c 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -12,9 +12,6 @@ use_small_heuristics = "Default" # TODO: single line functions only where short, please? # https://github.com/rust-lang/rustfmt/issues/3358 fn_single_line = false -fn_params_layout = "Compressed" -overflow_delimited_expr = true -where_single_line = true # enum_discrim_align_threshold = 20 # struct_field_align_threshold = 20 @@ -23,7 +20,6 @@ where_single_line = true edition = "2021" # Misc: -inline_attribute_width = 80 blank_lines_upper_bound = 2 reorder_impl_items = true # report_todo = "Unnumbered" From 37bd70db014d929a1e24766472d46997dbacbc9f Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Mon, 22 Apr 2024 20:39:09 +0200 Subject: [PATCH 050/185] Break long lines --- src/distribution/laplace.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/distribution/laplace.rs b/src/distribution/laplace.rs index d1ccc6da..f04306f1 100644 --- a/src/distribution/laplace.rs +++ b/src/distribution/laplace.rs @@ -389,7 +389,13 @@ mod tests { #[test] fn test_entropy() { let entropy = |x: Laplace| x.entropy().unwrap(); - test_almost(f64::NEG_INFINITY, 0.1, (2.0 * f64::consts::E * 0.1).ln(), 1E-12, entropy); + test_almost( + f64::NEG_INFINITY, + 0.1, + (2.0 * f64::consts::E * 0.1).ln(), + 1E-12, + entropy, + ); test_almost(-6.0, 1.0, (2.0 * f64::consts::E).ln(), 1E-12, entropy); test_almost(1.0, 7.0, (2.0 * f64::consts::E * 7.0).ln(), 1E-12, entropy); test_almost(5., 10., (2. * f64::consts::E * 10.).ln(), 1E-12, entropy); From 33d1699fb74d6c499140da1320b6aa509efdcb07 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sun, 21 Apr 2024 10:08:46 -0500 Subject: [PATCH 051/185] feat: extend StatsError for finiteness --- src/error.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/error.rs b/src/error.rs index c76d8b32..ce0bb1a4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -6,6 +6,8 @@ use std::fmt; pub enum StatsError { /// Generic bad input parameter error BadParams, + /// An argument must be finite + ArgFinite(&'static str), /// An argument should have been positive and was not ArgMustBePositive(&'static str), /// An argument should have been non-negative and was not @@ -58,6 +60,7 @@ impl fmt::Display for StatsError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { StatsError::BadParams => write!(f, "Bad distribution parameters"), + StatsError::ArgFinite(s) => write!(f, "Argument {} must be finite", s), StatsError::ArgMustBePositive(s) => write!(f, "Argument {} must be positive", s), StatsError::ArgNotNegative(s) => write!(f, "Argument {} must be non-negative", s), StatsError::ArgIntervalIncl(s, min, max) => { From 73590e63b77d3785541a3546f45ec7ba3a833f04 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sun, 21 Apr 2024 10:09:05 -0500 Subject: [PATCH 052/185] feat: reject constructing Uniform of infinite support additionally removes logic handling infinite support --- src/distribution/uniform.rs | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index 9414222a..1b672f03 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -31,7 +31,7 @@ impl Uniform { /// /// # Errors /// - /// Returns an error if `min` or `max` are `NaN` + /// Returns an error if `min` or `max` are `NaN` or unbounded /// /// # Examples /// @@ -44,12 +44,21 @@ impl Uniform { /// /// result = Uniform::new(f64::NAN, f64::NAN); /// assert!(result.is_err()); + /// + /// result = Uniform::new(f64::NEG_INFINITY, 1.0); + /// assert!(result.is_err()); /// ``` pub fn new(min: f64, max: f64) -> Result { - if min > max || min.is_nan() || max.is_nan() { - Err(StatsError::BadParams) - } else { - Ok(Uniform { min, max }) + if min.is_nan() || max.is_nan() { + return Err(StatsError::BadParams); + } + + match (min.is_finite(), max.is_finite(), min < max) { + (false, false, _) => Err(StatsError::ArgFinite("min and max")), + (false, true, _) => Err(StatsError::ArgFinite("min")), + (true, false, _) => Err(StatsError::ArgFinite("max")), + (true, true, false) => Err(StatsError::ArgLteArg("min", "max")), + (true, true, true) => Ok(Uniform { min, max }), } } } @@ -94,10 +103,6 @@ impl ContinuousCDF for Uniform { 1.0 } else if x >= self.max { 0.0 - } else if x.is_infinite() && self.max.is_infinite() { - 0.0 - } else if self.max.is_infinite() { - 1.0 } else { (self.max - x) / (self.max - self.min) } From 36b2145ff7db300a2d270b9684d339bf646ad664 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sun, 21 Apr 2024 17:46:17 -0500 Subject: [PATCH 053/185] test: ensure test suite matches behavior for restricting bounded Uniform --- src/distribution/uniform.rs | 34 +++------------------------------- 1 file changed, 3 insertions(+), 31 deletions(-) diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index 1b672f03..30b42bc3 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -264,7 +264,7 @@ mod tests { fn try_create(min: f64, max: f64) -> Uniform { let n = Uniform::new(min, max); - assert!(n.is_ok()); + assert!(n.is_ok(), "failed create over interval [{}, {}]", min, max); n.unwrap() } @@ -304,19 +304,19 @@ mod tests { #[test] fn test_create() { - create_case(0.0, 0.0); create_case(0.0, 0.1); create_case(0.0, 1.0); - create_case(10.0, 10.0); create_case(-5.0, 11.0); create_case(-5.0, 100.0); } #[test] fn test_bad_create() { + bad_create_case(0.0, 0.0); bad_create_case(f64::NAN, 1.0); bad_create_case(1.0, f64::NAN); bad_create_case(f64::NAN, f64::NAN); + bad_create_case(0.0, f64::INFINITY); bad_create_case(1.0, 0.0); } @@ -327,7 +327,6 @@ mod tests { test_case(0.0, 2.0, 1.0 / 3.0, variance); test_almost(0.1, 4.0, 1.2675, 1e-15, variance); test_case(10.0, 11.0, 1.0 / 12.0, variance); - test_case(0.0, f64::INFINITY, f64::INFINITY, variance); } #[test] @@ -338,7 +337,6 @@ mod tests { test_almost(0.1, 4.0, 1.360976553135600743431, 1e-15, entropy); test_case(1.0, 10.0, 2.19722457733621938279, entropy); test_case(10.0, 11.0, 0.0, entropy); - test_case(0.0, f64::INFINITY, f64::INFINITY, entropy); } #[test] @@ -349,7 +347,6 @@ mod tests { test_case(0.1, 4.0, 0.0, skewness); test_case(1.0, 10.0, 0.0, skewness); test_case(10.0, 11.0, 0.0, skewness); - test_case(0.0, f64::INFINITY, 0.0, skewness); } #[test] @@ -360,7 +357,6 @@ mod tests { test_case(0.1, 4.0, 2.05, mode); test_case(1.0, 10.0, 5.5, mode); test_case(10.0, 11.0, 10.5, mode); - test_case(0.0, f64::INFINITY, f64::INFINITY, mode); } #[test] @@ -371,15 +367,11 @@ mod tests { test_case(0.1, 4.0, 2.05, median); test_case(1.0, 10.0, 5.5, median); test_case(10.0, 11.0, 10.5, median); - test_case(0.0, f64::INFINITY, f64::INFINITY, median); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Uniform| x.pdf(arg); - test_case(0.0, 0.0, 0.0, pdf(-5.0)); - test_case(0.0, 0.0, f64::INFINITY, pdf(0.0)); - test_case(0.0, 0.0, 0.0, pdf(5.0)); test_case(0.0, 0.1, 0.0, pdf(-5.0)); test_case(0.0, 0.1, 10.0, pdf(0.05)); test_case(0.0, 0.1, 0.0, pdf(5.0)); @@ -394,17 +386,11 @@ mod tests { test_case(-5.0, 100.0, 0.009523809523809523809524, pdf(-5.0)); test_case(-5.0, 100.0, 0.009523809523809523809524, pdf(0.0)); test_case(-5.0, 100.0, 0.0, pdf(101.0)); - test_case(0.0, f64::INFINITY, 0.0, pdf(-5.0)); - test_case(0.0, f64::INFINITY, 0.0, pdf(10.0)); - test_case(0.0, f64::INFINITY, 0.0, pdf(f64::INFINITY)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Uniform| x.ln_pdf(arg); - test_case(0.0, 0.0, f64::NEG_INFINITY, ln_pdf(-5.0)); - test_case(0.0, 0.0, f64::INFINITY, ln_pdf(0.0)); - test_case(0.0, 0.0, f64::NEG_INFINITY, ln_pdf(5.0)); test_case(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(-5.0)); test_almost(0.0, 0.1, 2.302585092994045684018, 1e-15, ln_pdf(0.05)); test_case(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(5.0)); @@ -419,38 +405,27 @@ mod tests { test_case(-5.0, 100.0, -4.653960350157523371101, ln_pdf(-5.0)); test_case(-5.0, 100.0, -4.653960350157523371101, ln_pdf(0.0)); test_case(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(101.0)); - test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-5.0)); - test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(10.0)); - test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Uniform| x.cdf(arg); - test_case(0.0, 0.0, 0.0, cdf(0.0)); test_case(0.0, 0.1, 0.5, cdf(0.05)); test_case(0.0, 1.0, 0.5, cdf(0.5)); test_case(0.0, 10.0, 0.1, cdf(1.0)); test_case(0.0, 10.0, 0.5, cdf(5.0)); test_case(-5.0, 100.0, 0.0, cdf(-5.0)); test_case(-5.0, 100.0, 0.04761904761904761904762, cdf(0.0)); - test_case(0.0, f64::INFINITY, 0.0, cdf(10.0)); - test_case(0.0, f64::INFINITY, 1.0, cdf(f64::INFINITY)); } #[test] fn test_inverse_cdf() { let inverse_cdf = |arg: f64| move |x: Uniform| x.inverse_cdf(arg); - test_case(0.0, 0.0, 0.0, inverse_cdf(0.0)); - test_case(0.0, 0.0, 0.0, inverse_cdf(1.0)); test_case(0.0, 0.1, 0.05, inverse_cdf(0.5)); test_case(0.0, 10.0, 5.0, inverse_cdf(0.5)); test_case(1.0, 10.0, 1.0, inverse_cdf(0.0)); test_case(1.0, 10.0, 4.0, inverse_cdf(1.0 / 3.0)); test_case(1.0, 10.0, 10.0, inverse_cdf(1.0)); - test_case(f64::NEG_INFINITY, f64::INFINITY, f64::NEG_INFINITY, inverse_cdf(0.0)); - test_case(0.0, f64::INFINITY, 0.0, inverse_cdf(0.0)); - test_case(0.0, f64::INFINITY, f64::INFINITY, inverse_cdf(1.0)); } #[test] @@ -469,15 +444,12 @@ mod tests { #[test] fn test_sf() { let sf = |arg: f64| move |x: Uniform| x.sf(arg); - test_case(0.0, 0.0, 1.0, sf(0.0)); test_case(0.0, 0.1, 0.5, sf(0.05)); test_case(0.0, 1.0, 0.5, sf(0.5)); test_case(0.0, 10.0, 0.9, sf(1.0)); test_case(0.0, 10.0, 0.5, sf(5.0)); test_case(-5.0, 100.0, 1.0, sf(-5.0)); test_case(-5.0, 100.0, 0.9523809523809523, sf(0.0)); - test_case(0.0, f64::INFINITY, 1.0, sf(10.0)); - test_case(0.0, f64::INFINITY, 0.0, sf(f64::INFINITY)); } #[test] From 8091e613c9a39f77313e0d6115b0022b0553957a Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Fri, 24 May 2024 19:11:27 -0500 Subject: [PATCH 054/185] chore: run fmt before introducing fmt into CI --- src/distribution/internal.rs | 5 ++++- src/distribution/log_normal.rs | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index f88c5f62..d9d75546 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -23,7 +23,10 @@ pub fn is_valid_multinomial(arr: &[f64], incl_zero: bool) -> bool { /// - function found not semi-monotone on the provided interval containing `z` /// Evaluates to `Some(k)`, where `k` satisfies the search criteria pub fn integral_bisection_search( - f: impl Fn(&K) -> T, z: T, lb: K, ub: K, + f: impl Fn(&K) -> T, + z: T, + lb: K, + ub: K, ) -> Option { if !(f(&lb)..=f(&ub)).contains(&z) { return None; diff --git a/src/distribution/log_normal.rs b/src/distribution/log_normal.rs index d991c020..b74986f9 100644 --- a/src/distribution/log_normal.rs +++ b/src/distribution/log_normal.rs @@ -113,6 +113,7 @@ impl ContinuousCDF for LogNormal { 0.5 * erf::erfc((x.ln() - self.location) / (self.scale * f64::consts::SQRT_2)) } } + /// Calculates the inverse cumulative distribution function for the /// log-normal distribution at `p` /// From fc6190eaa008c5fff6bd7006f1c7030085ec5479 Mon Sep 17 00:00:00 2001 From: Ashwin Narayan Date: Mon, 6 May 2024 13:42:55 +0800 Subject: [PATCH 055/185] Add a a way to get a standard normal distribution easily. --- src/distribution/normal.rs | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index 5624a5c1..b7aed9d8 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -51,6 +51,26 @@ impl Normal { Ok(Normal { mean, std_dev }) } } + + /// Constructs a new standard normal distribution with a mean of 0 + /// and a standard deviation of 1. + /// + /// + /// # Examples + /// + /// ``` + /// use statrs::distribution::Normal; + /// + /// let mut result = Normal::standard(); + /// ``` + pub fn new() -> Normal { + let mean: f64 = 0.0; + let std_dev: f64 = 1.0; + Normal { + mean, + std_dev + } + } } impl ::rand::distributions::Distribution for Normal { From c54aa190d637ba35fda1666114fa8d17d588f7ba Mon Sep 17 00:00:00 2001 From: Ashwin Narayan Date: Mon, 6 May 2024 13:43:43 +0800 Subject: [PATCH 056/185] Update normal.rs --- src/distribution/normal.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index b7aed9d8..5ac00b94 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -63,7 +63,7 @@ impl Normal { /// /// let mut result = Normal::standard(); /// ``` - pub fn new() -> Normal { + pub fn standard() -> Normal { let mean: f64 = 0.0; let std_dev: f64 = 1.0; Normal { From 0080262d5acd7c1e04ee8ecff4b0d22fd246d594 Mon Sep 17 00:00:00 2001 From: Ashwin Narayan Date: Thu, 23 May 2024 16:35:23 +0800 Subject: [PATCH 057/185] Implement standard normal using Default trait. --- src/distribution/normal.rs | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index 5ac00b94..62d9ccd6 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -52,7 +52,7 @@ impl Normal { } } - /// Constructs a new standard normal distribution with a mean of 0 + /// Constructs a new standard normal distribution with a mean of 0 /// and a standard deviation of 1. /// /// @@ -63,12 +63,10 @@ impl Normal { /// /// let mut result = Normal::standard(); /// ``` - pub fn standard() -> Normal { - let mean: f64 = 0.0; - let std_dev: f64 = 1.0; + fn standard() -> Normal { Normal { - mean, - std_dev + mean: 0.0, + std_dev: 1.0, } } } @@ -312,6 +310,15 @@ pub fn sample_unchecked(rng: &mut R, mean: f64, std_dev: f64) - mean + std_dev * ziggurat::sample_std_normal(rng) } + +impl std::default::Default for Normal { + /// Returns the standard normal distribution with a mean of 0 + /// and a standard deviation of 1. + fn default() -> Self { + Self::standard() + } +} + #[rustfmt::skip] #[cfg(all(test, feature = "nightly"))] mod tests { From 20298f0c607e968d53de3aae2afa4597c56f7bda Mon Sep 17 00:00:00 2001 From: Ashwin Narayan Date: Thu, 23 May 2024 16:42:27 +0800 Subject: [PATCH 058/185] Add new test case for standard normal. --- src/distribution/normal.rs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index 62d9ccd6..c38a5731 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -63,7 +63,7 @@ impl Normal { /// /// let mut result = Normal::standard(); /// ``` - fn standard() -> Normal { + pub fn standard() -> Normal { Normal { mean: 0.0, std_dev: 1.0, @@ -539,4 +539,17 @@ mod tests { test_almost(5.0, 2.0, 10.0, 1e-14, inverse_cdf(0.9937903346742238648330218954258077788721022530769078)); test_case(5.0, 2.0, f64::INFINITY, inverse_cdf(1.0)); } + + #[test] + fn test_default() { + let n = Normal::default(); + + n_mean = n.mean().unwrap(); + n_std = n.std_dev().unwrap(); + + // Check that the mean of the distribution is close to 0 + assert_almost_eq!(n_mean, 0.0, 1e-15); + // Check that the standard deviation of the distribution is close to 1 + assert_almost_eq!(n_std, 1.0, 1e-15); + } } From 358aafb58bd25038a2379aefa3249ecb62037228 Mon Sep 17 00:00:00 2001 From: Ashwin Narayan Date: Fri, 24 May 2024 12:44:28 +0800 Subject: [PATCH 059/185] Ooops fix let --- src/distribution/normal.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index c38a5731..4d540eee 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -544,8 +544,8 @@ mod tests { fn test_default() { let n = Normal::default(); - n_mean = n.mean().unwrap(); - n_std = n.std_dev().unwrap(); + let n_mean = n.mean().unwrap(); + let n_std = n.std_dev().unwrap(); // Check that the mean of the distribution is close to 0 assert_almost_eq!(n_mean, 0.0, 1e-15); From cd4f2d9f5a16fc8489b1ae52cb16ea7f39e1e765 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Sat, 25 May 2024 13:00:44 +0200 Subject: [PATCH 060/185] Remove `feature = "nightly"` gate where unneeded --- src/distribution/bernoulli.rs | 2 +- src/distribution/dirac.rs | 2 +- src/distribution/dirichlet.rs | 2 +- src/distribution/discrete_uniform.rs | 2 +- src/distribution/empirical.rs | 2 +- src/distribution/laplace.rs | 2 +- src/distribution/multivariate_normal.rs | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/distribution/bernoulli.rs b/src/distribution/bernoulli.rs index 0af2f101..d0b6e219 100644 --- a/src/distribution/bernoulli.rs +++ b/src/distribution/bernoulli.rs @@ -257,7 +257,7 @@ impl Discrete for Bernoulli { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod testing { use std::fmt::Debug; use crate::distribution::DiscreteCDF; diff --git a/src/distribution/dirac.rs b/src/distribution/dirac.rs index 9a66e5e0..c781fa72 100644 --- a/src/distribution/dirac.rs +++ b/src/distribution/dirac.rs @@ -185,7 +185,7 @@ impl Mode> for Dirac { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Dirac}; diff --git a/src/distribution/dirichlet.rs b/src/distribution/dirichlet.rs index 0f703056..f47a0cfd 100644 --- a/src/distribution/dirichlet.rs +++ b/src/distribution/dirichlet.rs @@ -301,7 +301,7 @@ fn is_valid_alpha(a: &[f64]) -> bool { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use super::*; use nalgebra::{DVector}; diff --git a/src/distribution/discrete_uniform.rs b/src/distribution/discrete_uniform.rs index 59b0da4e..a128d9d3 100644 --- a/src/distribution/discrete_uniform.rs +++ b/src/distribution/discrete_uniform.rs @@ -248,7 +248,7 @@ impl Discrete for DiscreteUniform { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use std::fmt::Debug; use crate::statistics::*; diff --git a/src/distribution/empirical.rs b/src/distribution/empirical.rs index 0e8c964a..b22b78be 100644 --- a/src/distribution/empirical.rs +++ b/src/distribution/empirical.rs @@ -206,7 +206,7 @@ impl ContinuousCDF for Empirical { } } -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use super::*; #[test] diff --git a/src/distribution/laplace.rs b/src/distribution/laplace.rs index f04306f1..bcaaae08 100644 --- a/src/distribution/laplace.rs +++ b/src/distribution/laplace.rs @@ -291,7 +291,7 @@ impl Continuous for Laplace { } } -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use super::*; use rand::thread_rng; diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 368da0ec..e50ac3ac 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -242,7 +242,7 @@ impl Continuous, f64> for MultivariateNormal { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use crate::distribution::{Continuous, MultivariateNormal}; use crate::statistics::*; From db5d28ef5d0a5f97569c18d6ac1fe477fd968fcf Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 17 Apr 2024 10:40:03 +0200 Subject: [PATCH 061/185] Update CI workflow actions --- .github/workflows/test.yml | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4f9fd9cd..e9ebd8cc 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,14 +21,11 @@ jobs: target: x86_64-unknown-linux-gnu toolchain: nightly steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Install toolchain - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@master with: - profile: minimal - target: ${{ matrix.target }} toolchain: ${{ matrix.toolchain }} - override: true - name: Test nightly feature (if possible) if: ${{ matrix.toolchain == 'nightly' }} run: | @@ -36,4 +33,5 @@ jobs: cargo test --target ${{ matrix.target }} --benches --features=nightly - name: Test default features run: | - cargo test --target ${{ matrix.target }} \ No newline at end of file + cargo test --target ${{ matrix.target }} + From 7ac914854bc54d53d5d605037bbd615f4088b392 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 17 Apr 2024 10:47:08 +0200 Subject: [PATCH 062/185] Add clippy job to CI - Disable incremental compilation (useless on CI) - Run clippy first before test jobs - Add `-Dwarnings` to fail on compiler warnings --- .github/workflows/test.yml | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e9ebd8cc..3b4bba8b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,9 +6,26 @@ on: pull_request: branches: [ master ] +env: + CARGO_INCREMENTAL: 0 + RUSTFLAGS: "-Dwarnings" + jobs: + clippy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install Rust stable with clippy + uses: dtolnay/rust-toolchain@stable + with: + components: clippy + + - name: Run cargo clippy + run: cargo clippy --all-targets + test: name: Test + needs: clippy runs-on: ${{ matrix.os }} strategy: fail-fast: false From 8ebf43953f2d26a0a8d2b2bc17e8a139e9a3a5a5 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 17 Apr 2024 10:53:28 +0200 Subject: [PATCH 063/185] Expand CI test job - Run on macos, linux and windows - Remove explicit build target (host arch is fine) --- .github/workflows/test.yml | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3b4bba8b..f82eedc6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,31 +24,25 @@ jobs: run: cargo clippy --all-targets test: - name: Test needs: clippy runs-on: ${{ matrix.os }} strategy: - fail-fast: false matrix: - include: - - os: ubuntu-latest - target: x86_64-unknown-linux-gnu - toolchain: stable - - os: ubuntu-latest - target: x86_64-unknown-linux-gnu - toolchain: nightly + os: [ubuntu-latest, macos-latest, windows-latest] + toolchain: [stable, nightly] + steps: - uses: actions/checkout@v4 - - name: Install toolchain + - name: Install Rust ${{ matrix.toolchain }} uses: dtolnay/rust-toolchain@master with: toolchain: ${{ matrix.toolchain }} - - name: Test nightly feature (if possible) - if: ${{ matrix.toolchain == 'nightly' }} - run: | - cargo test --target ${{ matrix.target }} --features=nightly - cargo test --target ${{ matrix.target }} --benches --features=nightly + + - name: Test nightly feature + if: matrix.toolchain == 'nightly' + run: cargo test --all-targets --features=nightly + - name: Test default features - run: | - cargo test --target ${{ matrix.target }} + if: matrix.toolchain != 'nightly' + run: cargo test --all-targets From 7656a57cb6ba819be637c377d64d91bd9cd3c7e1 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Mon, 22 Apr 2024 17:35:18 +0200 Subject: [PATCH 064/185] Check formatting in CI via rustfmt --- .github/workflows/test.yml | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f82eedc6..42099b38 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -23,8 +23,20 @@ jobs: - name: Run cargo clippy run: cargo clippy --all-targets + fmt: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install Rust nightly with rustfmt + uses: dtolnay/rust-toolchain@nightly + with: + components: rustfmt + + - name: Run rustfmt --check + run: cargo fmt -- --check + test: - needs: clippy + needs: [clippy, fmt] runs-on: ${{ matrix.os }} strategy: matrix: From e82f685690ed4ae23ed3d237f7c2660c47b331c4 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 1 May 2024 22:26:42 +0200 Subject: [PATCH 065/185] Fix clippy error introduced in #226 --- src/error.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/error.rs b/src/error.rs index ce0bb1a4..faa74fad 100644 --- a/src/error.rs +++ b/src/error.rs @@ -118,7 +118,7 @@ mod tests { #[test] fn test_sync_send() { // Error types should implement Sync and Send - let _ = assert_sync::(); - let _ = assert_send::(); + assert_sync::(); + assert_send::(); } } From 2e3c453ad71b611acff9c50b3fb52106c2e7e6c7 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Sat, 25 May 2024 11:56:40 +0200 Subject: [PATCH 066/185] Allow some imprecision in specific test case This case caused a test failure on macOS --- src/distribution/cauchy.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distribution/cauchy.rs b/src/distribution/cauchy.rs index d66a9747..3af0ca63 100644 --- a/src/distribution/cauchy.rs +++ b/src/distribution/cauchy.rs @@ -432,7 +432,7 @@ mod tests { test_almost(0.0, 0.1, 0.9936346508990272, 1e-16, sf(-5.0)); test_almost(0.0, 0.1, 0.9682744825694465, 1e-16, sf(-1.0)); test_case(0.0, 0.1, 0.5, sf(0.0)); - test_case(0.0, 0.1, 0.03172551743055352, sf(1.0)); + test_almost(0.0, 0.1, 0.03172551743055352, 1e-16, sf(1.0)); test_case(0.0, 0.1, 0.006365349100972806, sf(5.0)); test_almost(0.0, 1.0, 0.9371670418109989, 1e-16, sf(-5.0)); test_case(0.0, 1.0, 0.75, sf(-1.0)); From d1fd362d92e245764f7475d978932cbcb3dd8879 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sun, 26 May 2024 22:02:48 -0500 Subject: [PATCH 067/185] fix: define no mode for gamma with shape<1 --- src/distribution/gamma.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index ee592a49..b7574fd1 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -283,12 +283,16 @@ impl Mode> for Gamma { /// # Formula /// /// ```text - /// (α - 1) / β + /// (α - 1) / β, where α≥1 /// ``` /// /// where `α` is the shape and `β` is the rate fn mode(&self) -> Option { - Some((self.shape - 1.0) / self.rate) + if self.shape < 1.0 { + None + } else { + Some((self.shape - 1.0) / self.rate) + } } } From c0b85c3db80aef1fe4efe7ea5cfac2dac03b3eda Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Sun, 26 May 2024 11:09:45 +0200 Subject: [PATCH 068/185] Update testing_boiler to not need nightly features Also update dependent testing modules --- src/distribution/beta.rs | 88 ++++++------- src/distribution/gamma.rs | 76 +++++------ src/distribution/internal.rs | 26 ++-- src/distribution/students_t.rs | 230 ++++++++++++++++----------------- 4 files changed, 210 insertions(+), 210 deletions(-) diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index 39d03598..307f1010 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -410,13 +410,13 @@ mod tests { use crate::statistics::*; use crate::testing_boiler; - testing_boiler!((f64, f64), Beta); + testing_boiler!(a: f64, b: f64; Beta); #[test] fn test_create() { let valid = [(1.0, 1.0), (9.0, 1.0), (5.0, 100.0), (1.0, f64::INFINITY), (f64::INFINITY, 1.0)]; - for &arg in valid.iter() { - try_create(arg); + for (a, b) in valid { + try_create(a, b); } } @@ -436,8 +436,8 @@ mod tests { (-1.0, -1.0), (f64::INFINITY, f64::INFINITY), ]; - for &arg in invalid.iter() { - bad_create_case(arg); + for (a, b) in invalid { + bad_create_case(a, b); } } @@ -451,8 +451,8 @@ mod tests { ((1.0, f64::INFINITY), 0.0), ((f64::INFINITY, 1.0), 1.0), ]; - for &(arg, res) in test.iter() { - test_case(arg, res, f); + for ((a, b), res) in test { + test_case(a, b, res, f); } } @@ -466,8 +466,8 @@ mod tests { ((1.0, f64::INFINITY), 0.0), ((f64::INFINITY, 1.0), 0.0), ]; - for &(arg, res) in test.iter() { - test_case(arg, res, f); + for ((a, b), res) in test { + test_case(a, b, res, f); } } @@ -478,53 +478,53 @@ mod tests { ((9.0, 1.0), -1.3083356884473304939016015), ((5.0, 100.0), -2.52016231876027436794592), ]; - for &(arg, res) in test.iter() { - test_case(arg, res, f); + for ((a, b), res) in test { + test_case(a, b, res, f); } - test_case_special((1.0, 1.0), 0.0, 1e-14, f); + test_case_special(1.0, 1.0, 0.0, 1e-14, f); let entropy = |x: Beta| x.entropy(); - test_none((1.0, f64::INFINITY), entropy); - test_none((f64::INFINITY, 1.0), entropy); + test_none(1.0, f64::INFINITY, entropy); + test_none(f64::INFINITY, 1.0, entropy); } #[test] fn test_skewness() { let skewness = |x: Beta| x.skewness().unwrap(); - test_case((1.0, 1.0), 0.0, skewness); - test_case((9.0, 1.0), -1.4740554623801777107177478829, skewness); - test_case((5.0, 100.0), 0.817594109275534303545831591, skewness); - test_case((1.0, f64::INFINITY), 2.0, skewness); - test_case((f64::INFINITY, 1.0), -2.0, skewness); + test_case(1.0, 1.0, 0.0, skewness); + test_case(9.0, 1.0, -1.4740554623801777107177478829, skewness); + test_case(5.0, 100.0, 0.817594109275534303545831591, skewness); + test_case(1.0, f64::INFINITY, 2.0, skewness); + test_case(f64::INFINITY, 1.0, -2.0, skewness); } #[test] fn test_mode() { let mode = |x: Beta| x.mode().unwrap(); - test_case((5.0, 100.0), 0.038834951456310676243255386, mode); - test_case((92.0, f64::INFINITY), 0.0, mode); - test_case((f64::INFINITY, 2.0), 1.0, mode); + test_case(5.0, 100.0, 0.038834951456310676243255386, mode); + test_case(92.0, f64::INFINITY, 0.0, mode); + test_case(f64::INFINITY, 2.0, 1.0, mode); } #[test] #[should_panic] fn test_mode_shape_a_lte_1() { let mode = |x: Beta| x.mode().unwrap(); - get_value((1.0, 5.0), mode); + get_value(1.0, 5.0, mode); } #[test] #[should_panic] fn test_mode_shape_b_lte_1() { let mode = |x: Beta| x.mode().unwrap(); - get_value((5.0, 1.0), mode); + get_value(5.0, 1.0, mode); } #[test] fn test_min_max() { let min = |x: Beta| x.min(); let max = |x: Beta| x.max(); - test_case((1.0, 1.0), 0.0, min); - test_case((1.0, 1.0), 1.0, max); + test_case(1.0, 1.0, 0.0, min); + test_case(1.0, 1.0, 1.0, max); } #[test] @@ -548,21 +548,21 @@ mod tests { ((f64::INFINITY, 1.0), 0.5, 0.0), ((f64::INFINITY, 1.0), 1.0, f64::INFINITY), ]; - for &(arg, x, expect) in test.iter() { - test_case(arg, expect, f(x)); + for ((a, b), x, expect) in test { + test_case(a, b, expect, f(x)); } } #[test] fn test_pdf_input_lt_0() { let pdf = |arg: f64| move |x: Beta| x.pdf(arg); - test_case((1.0, 1.0), 0.0, pdf(-1.0)); + test_case(1.0, 1.0, 0.0, pdf(-1.0)); } #[test] fn test_pdf_input_gt_0() { let pdf = |arg: f64| move |x: Beta| x.pdf(arg); - test_case((1.0, 1.0), 0.0, pdf(2.0)); + test_case(1.0, 1.0, 0.0, pdf(2.0)); } #[test] @@ -585,21 +585,21 @@ mod tests { ((f64::INFINITY, 1.0), 0.5, f64::NEG_INFINITY), ((f64::INFINITY, 1.0), 1.0, f64::INFINITY), ]; - for &(arg, x, expect) in test.iter() { - test_case(arg, expect, f(x)); + for ((a, b), x, expect) in test { + test_case(a, b, expect, f(x)); } } #[test] fn test_ln_pdf_input_lt_0() { let ln_pdf = |arg: f64| move |x: Beta| x.ln_pdf(arg); - test_case((1.0, 1.0), f64::NEG_INFINITY, ln_pdf(-1.0)); + test_case(1.0, 1.0, f64::NEG_INFINITY, ln_pdf(-1.0)); } #[test] fn test_ln_pdf_input_gt_1() { let ln_pdf = |arg: f64| move |x: Beta| x.ln_pdf(arg); - test_case((1.0, 1.0), f64::NEG_INFINITY, ln_pdf(2.0)); + test_case(1.0, 1.0, f64::NEG_INFINITY, ln_pdf(2.0)); } #[test] @@ -622,8 +622,8 @@ mod tests { ((f64::INFINITY, 1.0), 0.5, 0.0), ((f64::INFINITY, 1.0), 1.0, 1.0), ]; - for &(arg, x, expect) in test.iter() { - test_case(arg, expect, cdf(x)); + for ((a, b), x, expect) in test { + test_case(a, b, expect, cdf(x)); } } @@ -647,38 +647,38 @@ mod tests { ((f64::INFINITY, 1.0), 0.5, 1.0), ((f64::INFINITY, 1.0), 1.0, 0.0), ]; - for &(arg, x, expect) in test.iter() { - test_case(arg, expect, sf(x)); + for ((a, b), x, expect) in test { + test_case(a, b, expect, sf(x)); } } #[test] fn test_cdf_input_lt_0() { let cdf = |arg: f64| move |x: Beta| x.cdf(arg); - test_case((1.0, 1.0), 0.0, cdf(-1.0)); + test_case(1.0, 1.0, 0.0, cdf(-1.0)); } #[test] fn test_cdf_input_gt_1() { let cdf = |arg: f64| move |x: Beta| x.cdf(arg); - test_case((1.0, 1.0), 1.0, cdf(2.0)); + test_case(1.0, 1.0, 1.0, cdf(2.0)); } #[test] fn test_sf_input_lt_0() { let sf = |arg: f64| move |x: Beta| x.sf(arg); - test_case((1.0, 1.0), 1.0, sf(-1.0)); + test_case(1.0, 1.0, 1.0, sf(-1.0)); } #[test] fn test_sf_input_gt_1() { let sf = |arg: f64| move |x: Beta| x.sf(arg); - test_case((1.0, 1.0), 0.0, sf(2.0)); + test_case(1.0, 1.0, 0.0, sf(2.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create((1.2, 3.4)), 0.0, 1.0); - test::check_continuous_distribution(&try_create((4.5, 6.7)), 0.0, 1.0); + test::check_continuous_distribution(&try_create(1.2, 3.4), 0.0, 1.0); + test::check_continuous_distribution(&try_create(4.5, 6.7), 0.0, 1.0); } } diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index b7574fd1..6e8def6c 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -405,7 +405,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!((f64, f64), Gamma); + testing_boiler!(shape: f64, rate: f64; Gamma); #[test] fn test_create() { @@ -417,8 +417,8 @@ mod tests { (10.0, f64::INFINITY), ]; - for &arg in valid.iter() { - try_create(arg); + for (s, r) in valid { + try_create(s, r); } } @@ -432,8 +432,8 @@ mod tests { (-1.0, -1.0), (-1.0, f64::NAN), ]; - for &arg in invalid.iter() { - bad_create_case(arg); + for (s, r) in invalid { + bad_create_case(s, r); } } @@ -447,8 +447,8 @@ mod tests { ((10.0, 1.0), 10.0), ((10.0, f64::INFINITY), 0.0), ]; - for &(arg, res) in test.iter() { - test_case(arg, res, f); + for ((s, r), res) in test { + test_case(s, r, res, f); } } @@ -462,8 +462,8 @@ mod tests { ((10.0, 1.0), 10.0), ((10.0, f64::INFINITY), 0.0), ]; - for &(arg, res) in test.iter() { - test_case(arg, res, f); + for ((s, r), res) in test { + test_case(s, r, res, f); } } @@ -477,8 +477,8 @@ mod tests { ((10.0, 1.0), 2.53605417848097964238061239), ((10.0, f64::INFINITY), f64::NEG_INFINITY), ]; - for &(arg, res) in test.iter() { - test_case(arg, res, f); + for ((s, r), res) in test { + test_case(s, r, res, f); } } @@ -492,8 +492,8 @@ mod tests { ((10.0, 1.0), 0.63245553203367586639977870), ((10.0, f64::INFINITY), 0.6324555320336758), ]; - for &(arg, res) in test.iter() { - test_case(arg, res, f); + for ((s, r), res) in test { + test_case(s, r, res, f); } } @@ -501,16 +501,16 @@ mod tests { fn test_mode() { let f = |x: Gamma| x.mode().unwrap(); let test = [((1.0, 0.1), 0.0), ((1.0, 1.0), 0.0)]; - for &(arg, res) in test.iter() { - test_case_special(arg, res, 10e-6, f); + for &((s, r), res) in test.iter() { + test_case_special(s, r, res, 10e-6, f); } let test = [ ((10.0, 10.0), 0.9), ((10.0, 1.0), 9.0), ((10.0, f64::INFINITY), 0.0), ]; - for &(arg, res) in test.iter() { - test_case(arg, res, f); + for ((s, r), res) in test { + test_case(s, r, res, f); } } @@ -524,8 +524,8 @@ mod tests { ((10.0, 1.0), 0.0), ((10.0, f64::INFINITY), 0.0), ]; - for &(arg, res) in test.iter() { - test_case(arg, res, f); + for ((s, r), res) in test { + test_case(s, r, res, f); } let f = |x: Gamma| x.max(); let test = [ @@ -535,8 +535,8 @@ mod tests { ((10.0, 1.0), f64::INFINITY), ((10.0, f64::INFINITY), f64::INFINITY), ]; - for &(arg, res) in test.iter() { - test_case(arg, res, f); + for ((s, r), res) in test { + test_case(s, r, res, f); } } @@ -553,8 +553,8 @@ mod tests { ((10.0, 1.0), 1.0, 0.000001013777119630297402), ((10.0, 1.0), 10.0, 0.125110035721133298984764), ]; - for &(arg, x, res) in test.iter() { - test_case(arg, res, f(x)); + for ((s, r), x, res) in test { + test_case(s, r, res, f(x)); } // TODO: test special // test_is_nan((10.0, f64::INFINITY), pdf(1.0)); // is this really the behavior we want? @@ -564,8 +564,8 @@ mod tests { #[test] fn test_pdf_at_zero() { - test_case((1.0, 0.1), 0.1, |x| x.pdf(0.0)); - test_case((1.0, 0.1), 0.1f64.ln(), |x| x.ln_pdf(0.0)); + test_case(1.0, 0.1, 0.1, |x| x.pdf(0.0)); + test_case(1.0, 0.1, 0.1f64.ln(), |x| x.ln_pdf(0.0)); } #[test] @@ -582,8 +582,8 @@ mod tests { ((10.0, 1.0), 10.0, -2.07856164313505845504579), ((10.0, f64::INFINITY), f64::INFINITY, f64::NEG_INFINITY), ]; - for &(arg, x, res) in test.iter() { - test_case(arg, res, f(x)); + for ((s, r), x, res) in test { + test_case(s, r, res, f(x)); } // TODO: test special // test_is_nan((10.0, f64::INFINITY), f(1.0)); // is this really the behavior we want? @@ -604,14 +604,14 @@ mod tests { ((10.0, f64::INFINITY), 1.0, 0.0), ((10.0, f64::INFINITY), 10.0, 1.0), ]; - for &(arg, x, res) in test.iter() { - test_case(arg, res, f(x)); + for ((s, r), x, res) in test { + test_case(s, r, res, f(x)); } } #[test] fn test_cdf_at_zero() { - test_case((1.0, 0.1), 0.0, |x| x.cdf(0.0)); + test_case(1.0, 0.1, 0.0, |x| x.cdf(0.0)); } #[test] @@ -625,10 +625,10 @@ mod tests { (100.0, 200.0), ]; - for param in params { + for (s, r) in params { for n in -5..0 { let p = 10.0f64.powi(n); - test_case(param, p, f(p)); + test_case(s, r, p, f(p)); } } @@ -636,7 +636,7 @@ mod tests { { let x = 20.5567; let f = |x: f64| move |g: Gamma| g.inverse_cdf(g.cdf(x)); - test_case((3.0, 0.5), x, f(x)) + test_case(3.0, 0.5, x, f(x)) } } @@ -655,19 +655,19 @@ mod tests { ((10.0, f64::INFINITY), 1.0, 1.0), ((10.0, f64::INFINITY), 10.0, 0.0), ]; - for &(arg, x, res) in test.iter() { - test_case(arg, res, f(x)); + for ((s, r), x, res) in test { + test_case(s, r, res, f(x)); } } #[test] fn test_sf_at_zero() { - test_case((1.0, 0.1), 1.0, |x| x.sf(0.0)); + test_case(1.0, 0.1, 1.0, |x| x.sf(0.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create((1.0, 0.5)), 0.0, 20.0); - test::check_continuous_distribution(&try_create((9.0, 2.0)), 0.0, 20.0); + test::check_continuous_distribution(&try_create(1.0, 0.5), 0.0, 20.0); + test::check_continuous_distribution(&try_create(9.0, 2.0), 0.0, 20.0); } } diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index d9d75546..07b04df9 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -63,52 +63,52 @@ pub mod test { #[macro_export] macro_rules! testing_boiler { - ($arg:ty, $dist:ty) => { - fn try_create(arg: $arg) -> $dist { - let n = <$dist>::new.call_once(arg); + ($($arg_name:ident: $arg_ty:ty),+; $dist:ty) => { + fn try_create($($arg_name: $arg_ty),+) -> $dist { + let n = <$dist>::new($($arg_name),+); assert!(n.is_ok()); n.unwrap() } - fn bad_create_case(arg: $arg) { - let n = <$dist>::new.call(arg); + fn bad_create_case($($arg_name: $arg_ty),+) { + let n = <$dist>::new($($arg_name),+); assert!(n.is_err()); } - fn get_value(arg: $arg, eval: F) -> T + fn get_value($($arg_name: $arg_ty),+, eval: F) -> T where F: Fn($dist) -> T, { - let n = try_create(arg); + let n = try_create($($arg_name),+); eval(n) } - fn test_case(arg: $arg, expected: T, eval: F) + fn test_case($($arg_name: $arg_ty),+, expected: T, eval: F) where F: Fn($dist) -> T, T: ::core::fmt::Debug + ::approx::RelativeEq, { - let x = get_value(arg, eval); + let x = get_value($($arg_name),+, eval); assert_relative_eq!(expected, x, max_relative = ACC); } #[allow(dead_code)] // This is not used by all distributions. - fn test_case_special(arg: $arg, expected: T, acc: f64, eval: F) + fn test_case_special($($arg_name: $arg_ty),+, expected: T, acc: f64, eval: F) where F: Fn($dist) -> T, T: ::core::fmt::Debug + ::approx::AbsDiffEq, { - let x = get_value(arg, eval); + let x = get_value($($arg_name),+, eval); assert_abs_diff_eq!(expected, x, epsilon = acc); } #[allow(dead_code)] // This is not used by all distributions. - fn test_none(arg: $arg, eval: F) + fn test_none($($arg_name: $arg_ty),+, eval: F) where F: Fn($dist) -> Option, T: ::core::cmp::PartialEq + ::core::fmt::Debug, { - let x = get_value(arg, eval); + let x = get_value($($arg_name),+, eval); assert_eq!(None, x); } }; diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index e8c183f3..f7448380 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -423,14 +423,14 @@ mod tests { use crate::testing_boiler; use std::panic; - testing_boiler!((f64, f64, f64), StudentsT); + testing_boiler!(location: f64, scale: f64, freedom: f64; StudentsT); #[test] fn test_create() { - try_create((0.0, 0.1, 1.0)); - try_create((0.0, 1.0, 1.0)); - try_create((-5.0, 1.0, 3.0)); - try_create((10.0, 10.0, f64::INFINITY)); + try_create(0.0, 0.1, 1.0); + try_create(0.0, 1.0, 1.0); + try_create(-5.0, 1.0, 3.0); + try_create(10.0, 10.0, f64::INFINITY); } // #[test] @@ -441,56 +441,56 @@ mod tests { #[test] fn test_bad_create() { - bad_create_case((f64::NAN, 1.0, 1.0)); - bad_create_case((0.0, f64::NAN, 1.0)); - bad_create_case((0.0, 1.0, f64::NAN)); - bad_create_case((0.0, -10.0, 1.0)); - bad_create_case((0.0, 10.0, -1.0)); + bad_create_case(f64::NAN, 1.0, 1.0); + bad_create_case(0.0, f64::NAN, 1.0); + bad_create_case(0.0, 1.0, f64::NAN); + bad_create_case(0.0, -10.0, 1.0); + bad_create_case(0.0, 10.0, -1.0); } #[test] fn test_mean() { let mean = |x: StudentsT| x.mean().unwrap(); - test_case((0.0, 1.0, 3.0), 0.0, mean); - test_case((0.0, 10.0, 2.0), 0.0, mean); - test_case((0.0, 10.0, f64::INFINITY), 0.0, mean); - test_case((-5.0, 100.0, 1.5), -5.0, mean); + test_case(0.0, 1.0, 3.0, 0.0, mean); + test_case(0.0, 10.0, 2.0, 0.0, mean); + test_case(0.0, 10.0, f64::INFINITY, 0.0, mean); + test_case(-5.0, 100.0, 1.5, -5.0, mean); let mean = |x: StudentsT| x.mean(); - test_none((0.0, 1.0, 1.0), mean); - test_none((0.0, 0.1, 1.0), mean); - test_none((0.0, 10.0, 1.0), mean); - test_none((10.0, 1.0, 1.0), mean); - test_none((0.0, f64::INFINITY, 1.0), mean); + test_none(0.0, 1.0, 1.0, mean); + test_none(0.0, 0.1, 1.0, mean); + test_none(0.0, 10.0, 1.0, mean); + test_none(10.0, 1.0, 1.0, mean); + test_none(0.0, f64::INFINITY, 1.0, mean); } #[test] #[should_panic] fn test_mean_freedom_lte_1() { let mean = |x: StudentsT| x.mean().unwrap(); - get_value((1.0, 1.0, 0.5), mean); + get_value(1.0, 1.0, 0.5, mean); } #[test] fn test_variance() { let variance = |x: StudentsT| x.variance().unwrap(); - test_case((0.0, 1.0, 3.0), 3.0, variance); - test_case((0.0, 10.0, 2.5), 500.0, variance); - test_case((10.0, 1.0, 2.5), 5.0, variance); + test_case(0.0, 1.0, 3.0, 3.0, variance); + test_case(0.0, 10.0, 2.5, 500.0, variance); + test_case(10.0, 1.0, 2.5, 5.0, variance); let variance = |x: StudentsT| x.variance(); - test_none((0.0, 10.0, 2.0), variance); - test_none((0.0, 1.0, 1.0), variance); - test_none((0.0, 0.1, 1.0), variance); - test_none((0.0, 10.0, 1.0), variance); - test_none((10.0, 1.0, 1.0), variance); - test_none((-5.0, 100.0, 1.5), variance); - test_none((0.0, f64::INFINITY, 1.0), variance); + test_none(0.0, 10.0, 2.0, variance); + test_none(0.0, 1.0, 1.0, variance); + test_none(0.0, 0.1, 1.0, variance); + test_none(0.0, 10.0, 1.0, variance); + test_none(10.0, 1.0, 1.0, variance); + test_none(-5.0, 100.0, 1.5, variance); + test_none(0.0, f64::INFINITY, 1.0, variance); } #[test] #[should_panic] fn test_variance_freedom_lte1() { let variance = |x: StudentsT| x.variance().unwrap(); - get_value((1.0, 1.0, 0.5), variance); + get_value(1.0, 1.0, 0.5, variance); } // TODO: valid skewness tests @@ -498,134 +498,134 @@ mod tests { #[should_panic] fn test_skewness_freedom_lte_3() { let skewness = |x: StudentsT| x.skewness().unwrap(); - get_value((1.0, 1.0, 1.0), skewness); + get_value(1.0, 1.0, 1.0, skewness); } #[test] fn test_mode() { let mode = |x: StudentsT| x.mode().unwrap(); - test_case((0.0, 1.0, 1.0), 0.0, mode); - test_case((0.0, 0.1, 1.0), 0.0, mode); - test_case((0.0, 1.0, 3.0), 0.0, mode); - test_case((0.0, 10.0, 1.0), 0.0, mode); - test_case((0.0, 10.0, 2.0), 0.0, mode); - test_case((0.0, 10.0, 2.5), 0.0, mode); - test_case((0.0, 10.0, f64::INFINITY), 0.0, mode); - test_case((10.0, 1.0, 1.0), 10.0, mode); - test_case((10.0, 1.0, 2.5), 10.0, mode); - test_case((-5.0, 100.0, 1.5), -5.0, mode); - test_case((0.0, f64::INFINITY, 1.0), 0.0, mode); + test_case(0.0, 1.0, 1.0, 0.0, mode); + test_case(0.0, 0.1, 1.0, 0.0, mode); + test_case(0.0, 1.0, 3.0, 0.0, mode); + test_case(0.0, 10.0, 1.0, 0.0, mode); + test_case(0.0, 10.0, 2.0, 0.0, mode); + test_case(0.0, 10.0, 2.5, 0.0, mode); + test_case(0.0, 10.0, f64::INFINITY, 0.0, mode); + test_case(10.0, 1.0, 1.0, 10.0, mode); + test_case(10.0, 1.0, 2.5, 10.0, mode); + test_case(-5.0, 100.0, 1.5, -5.0, mode); + test_case(0.0, f64::INFINITY, 1.0, 0.0, mode); } #[test] fn test_median() { let median = |x: StudentsT| x.median(); - test_case((0.0, 1.0, 1.0), 0.0, median); - test_case((0.0, 0.1, 1.0), 0.0, median); - test_case((0.0, 1.0, 3.0), 0.0, median); - test_case((0.0, 10.0, 1.0), 0.0, median); - test_case((0.0, 10.0, 2.0), 0.0, median); - test_case((0.0, 10.0, 2.5), 0.0, median); - test_case((0.0, 10.0, f64::INFINITY), 0.0, median); - test_case((10.0, 1.0, 1.0), 10.0, median); - test_case((10.0, 1.0, 2.5), 10.0, median); - test_case((-5.0, 100.0, 1.5), -5.0, median); - test_case((0.0, f64::INFINITY, 1.0), 0.0, median); + test_case(0.0, 1.0, 1.0, 0.0, median); + test_case(0.0, 0.1, 1.0, 0.0, median); + test_case(0.0, 1.0, 3.0, 0.0, median); + test_case(0.0, 10.0, 1.0, 0.0, median); + test_case(0.0, 10.0, 2.0, 0.0, median); + test_case(0.0, 10.0, 2.5, 0.0, median); + test_case(0.0, 10.0, f64::INFINITY, 0.0, median); + test_case(10.0, 1.0, 1.0, 10.0, median); + test_case(10.0, 1.0, 2.5, 10.0, median); + test_case(-5.0, 100.0, 1.5, -5.0, median); + test_case(0.0, f64::INFINITY, 1.0, 0.0, median); } #[test] fn test_min_max() { let min = |x: StudentsT| x.min(); let max = |x: StudentsT| x.max(); - test_case((0.0, 1.0, 1.0), f64::NEG_INFINITY, min); - test_case((2.5, 100.0, 1.5), f64::NEG_INFINITY, min); - test_case((10.0, f64::INFINITY, 3.5), f64::NEG_INFINITY, min); - test_case((0.0, 1.0, 1.0), f64::INFINITY, max); - test_case((2.5, 100.0, 1.5), f64::INFINITY, max); - test_case((10.0, f64::INFINITY, 5.5), f64::INFINITY, max); + test_case(0.0, 1.0, 1.0, f64::NEG_INFINITY, min); + test_case(2.5, 100.0, 1.5, f64::NEG_INFINITY, min); + test_case(10.0, f64::INFINITY, 3.5, f64::NEG_INFINITY, min); + test_case(0.0, 1.0, 1.0, f64::INFINITY, max); + test_case(2.5, 100.0, 1.5, f64::INFINITY, max); + test_case(10.0, f64::INFINITY, 5.5, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: StudentsT| x.pdf(arg); - test_case((0.0, 1.0, 1.0), 0.318309886183791, pdf(0.0)); - test_case((0.0, 1.0, 1.0), 0.159154943091895, pdf(1.0)); - test_case((0.0, 1.0, 1.0), 0.159154943091895, pdf(-1.0)); - test_case((0.0, 1.0, 1.0), 0.063661977236758, pdf(2.0)); - test_case((0.0, 1.0, 1.0), 0.063661977236758, pdf(-2.0)); - test_case((0.0, 1.0, 2.0), 0.353553390593274, pdf(0.0)); - test_case((0.0, 1.0, 2.0), 0.192450089729875, pdf(1.0)); - test_case((0.0, 1.0, 2.0), 0.192450089729875, pdf(-1.0)); - test_case((0.0, 1.0, 2.0), 0.068041381743977, pdf(2.0)); - test_case((0.0, 1.0, 2.0), 0.068041381743977, pdf(-2.0)); - test_case((0.0, 1.0, f64::INFINITY), 0.398942280401433, pdf(0.0)); - test_case((0.0, 1.0, f64::INFINITY), 0.241970724519143, pdf(1.0)); - test_case((0.0, 1.0, f64::INFINITY), 0.053990966513188, pdf(2.0)); + test_case(0.0, 1.0, 1.0, 0.318309886183791, pdf(0.0)); + test_case(0.0, 1.0, 1.0, 0.159154943091895, pdf(1.0)); + test_case(0.0, 1.0, 1.0, 0.159154943091895, pdf(-1.0)); + test_case(0.0, 1.0, 1.0, 0.063661977236758, pdf(2.0)); + test_case(0.0, 1.0, 1.0, 0.063661977236758, pdf(-2.0)); + test_case(0.0, 1.0, 2.0, 0.353553390593274, pdf(0.0)); + test_case(0.0, 1.0, 2.0, 0.192450089729875, pdf(1.0)); + test_case(0.0, 1.0, 2.0, 0.192450089729875, pdf(-1.0)); + test_case(0.0, 1.0, 2.0, 0.068041381743977, pdf(2.0)); + test_case(0.0, 1.0, 2.0, 0.068041381743977, pdf(-2.0)); + test_case(0.0, 1.0, f64::INFINITY, 0.398942280401433, pdf(0.0)); + test_case(0.0, 1.0, f64::INFINITY, 0.241970724519143, pdf(1.0)); + test_case(0.0, 1.0, f64::INFINITY, 0.053990966513188, pdf(2.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: StudentsT| x.ln_pdf(arg); - test_case((0.0, 1.0, 1.0), -1.144729885849399, ln_pdf(0.0)); - test_case((0.0, 1.0, 1.0), -1.837877066409348, ln_pdf(1.0)); - test_case((0.0, 1.0, 1.0), -1.837877066409348, ln_pdf(-1.0)); - test_case((0.0, 1.0, 1.0), -2.754167798283503, ln_pdf(2.0)); - test_case((0.0, 1.0, 1.0), -2.754167798283503, ln_pdf(-2.0)); - test_case((0.0, 1.0, 2.0), -1.039720770839917, ln_pdf(0.0)); - test_case((0.0, 1.0, 2.0), -1.647918433002166, ln_pdf(1.0)); - test_case((0.0, 1.0, 2.0), -1.647918433002166, ln_pdf(-1.0)); - test_case((0.0, 1.0, 2.0), -2.687639203842085, ln_pdf(2.0)); - test_case((0.0, 1.0, 2.0), -2.687639203842085, ln_pdf(-2.0)); - test_case((0.0, 1.0, f64::INFINITY), -0.918938533204672, ln_pdf(0.0)); - test_case((0.0, 1.0, f64::INFINITY), -1.418938533204674, ln_pdf(1.0)); - test_case((0.0, 1.0, f64::INFINITY), -2.918938533204674, ln_pdf(2.0)); + test_case(0.0, 1.0, 1.0, -1.144729885849399, ln_pdf(0.0)); + test_case(0.0, 1.0, 1.0, -1.837877066409348, ln_pdf(1.0)); + test_case(0.0, 1.0, 1.0, -1.837877066409348, ln_pdf(-1.0)); + test_case(0.0, 1.0, 1.0, -2.754167798283503, ln_pdf(2.0)); + test_case(0.0, 1.0, 1.0, -2.754167798283503, ln_pdf(-2.0)); + test_case(0.0, 1.0, 2.0, -1.039720770839917, ln_pdf(0.0)); + test_case(0.0, 1.0, 2.0, -1.647918433002166, ln_pdf(1.0)); + test_case(0.0, 1.0, 2.0, -1.647918433002166, ln_pdf(-1.0)); + test_case(0.0, 1.0, 2.0, -2.687639203842085, ln_pdf(2.0)); + test_case(0.0, 1.0, 2.0, -2.687639203842085, ln_pdf(-2.0)); + test_case(0.0, 1.0, f64::INFINITY, -0.918938533204672, ln_pdf(0.0)); + test_case(0.0, 1.0, f64::INFINITY, -1.418938533204674, ln_pdf(1.0)); + test_case(0.0, 1.0, f64::INFINITY, -2.918938533204674, ln_pdf(2.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: StudentsT| x.cdf(arg); - test_case((0.0, 1.0, 1.0), 0.5, cdf(0.0)); - test_case((0.0, 1.0, 1.0), 0.75, cdf(1.0)); - test_case((0.0, 1.0, 1.0), 0.25, cdf(-1.0)); - test_case((0.0, 1.0, 1.0), 0.852416382349567, cdf(2.0)); - test_case((0.0, 1.0, 1.0), 0.147583617650433, cdf(-2.0)); - test_case((0.0, 1.0, 2.0), 0.5, cdf(0.0)); - test_case((0.0, 1.0, 2.0), 0.788675134594813, cdf(1.0)); - test_case((0.0, 1.0, 2.0), 0.211324865405187, cdf(-1.0)); - test_case((0.0, 1.0, 2.0), 0.908248290463863, cdf(2.0)); - test_case((0.0, 1.0, 2.0), 0.091751709536137, cdf(-2.0)); - test_case((0.0, 1.0, f64::INFINITY), 0.5, cdf(0.0)); + test_case(0.0, 1.0, 1.0, 0.5, cdf(0.0)); + test_case(0.0, 1.0, 1.0, 0.75, cdf(1.0)); + test_case(0.0, 1.0, 1.0, 0.25, cdf(-1.0)); + test_case(0.0, 1.0, 1.0, 0.852416382349567, cdf(2.0)); + test_case(0.0, 1.0, 1.0, 0.147583617650433, cdf(-2.0)); + test_case(0.0, 1.0, 2.0, 0.5, cdf(0.0)); + test_case(0.0, 1.0, 2.0, 0.788675134594813, cdf(1.0)); + test_case(0.0, 1.0, 2.0, 0.211324865405187, cdf(-1.0)); + test_case(0.0, 1.0, 2.0, 0.908248290463863, cdf(2.0)); + test_case(0.0, 1.0, 2.0, 0.091751709536137, cdf(-2.0)); + test_case(0.0, 1.0, f64::INFINITY, 0.5, cdf(0.0)); // TODO: these are curiously low accuracy and should be re-examined - test_case((0.0, 1.0, f64::INFINITY), 0.841344746068543, cdf(1.0)); - test_case((0.0, 1.0, f64::INFINITY), 0.977249868051821, cdf(2.0)); + test_case(0.0, 1.0, f64::INFINITY, 0.841344746068543, cdf(1.0)); + test_case(0.0, 1.0, f64::INFINITY, 0.977249868051821, cdf(2.0)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: StudentsT| x.sf(arg); - test_case((0.0, 1.0, 1.0), 0.5, sf(0.0)); - test_case((0.0, 1.0, 1.0), 0.25, sf(1.0)); - test_case((0.0, 1.0, 1.0), 0.75, sf(-1.0)); - test_case((0.0, 1.0, 1.0), 0.147583617650433, sf(2.0)); - test_case((0.0, 1.0, 1.0), 0.852416382349566, sf(-2.0)); - test_case((0.0, 1.0, 2.0), 0.5, sf(0.0)); - test_case((0.0, 1.0, 2.0), 0.211324865405186, sf(1.0)); - test_case((0.0, 1.0, 2.0), 0.788675134594813, sf(-1.0)); - test_case((0.0, 1.0, 2.0), 0.091751709536137, sf(2.0)); - test_case((0.0, 1.0, 2.0), 0.908248290463862, sf(-2.0)); - test_case((0.0, 1.0, f64::INFINITY), 0.5, sf(0.0)); + test_case(0.0, 1.0, 1.0, 0.5, sf(0.0)); + test_case(0.0, 1.0, 1.0, 0.25, sf(1.0)); + test_case(0.0, 1.0, 1.0, 0.75, sf(-1.0)); + test_case(0.0, 1.0, 1.0, 0.147583617650433, sf(2.0)); + test_case(0.0, 1.0, 1.0, 0.852416382349566, sf(-2.0)); + test_case(0.0, 1.0, 2.0, 0.5, sf(0.0)); + test_case(0.0, 1.0, 2.0, 0.211324865405186, sf(1.0)); + test_case(0.0, 1.0, 2.0, 0.788675134594813, sf(-1.0)); + test_case(0.0, 1.0, 2.0, 0.091751709536137, sf(2.0)); + test_case(0.0, 1.0, 2.0, 0.908248290463862, sf(-2.0)); + test_case(0.0, 1.0, f64::INFINITY, 0.5, sf(0.0)); // TODO: these are curiously low accuracy and should be re-examined - test_case((0.0, 1.0, f64::INFINITY), 0.158655253945057, sf(1.0)); - test_case((0.0, 1.0, f64::INFINITY), 0.022750131947162, sf(2.0)); + test_case(0.0, 1.0, f64::INFINITY, 0.158655253945057, sf(1.0)); + test_case(0.0, 1.0, f64::INFINITY, 0.022750131947162, sf(2.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create((0.0, 1.0, 3.0)), -30.0, 30.0); - test::check_continuous_distribution(&try_create((0.0, 1.0, 10.0)), -10.0, 10.0); - test::check_continuous_distribution(&try_create((20.0, 0.5, 10.0)), 10.0, 30.0); + test::check_continuous_distribution(&try_create(0.0, 1.0, 3.0), -30.0, 30.0); + test::check_continuous_distribution(&try_create(0.0, 1.0, 10.0), -10.0, 10.0); + test::check_continuous_distribution(&try_create(20.0, 0.5, 10.0), 10.0, 30.0); } #[test] From a3d7d3b78681d288b1587b81e695b6d70f2d4ab7 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Sun, 26 May 2024 11:23:41 +0200 Subject: [PATCH 069/185] Remove "nightly" feature --- Cargo.toml | 3 --- src/distribution/beta.rs | 2 +- src/distribution/binomial.rs | 2 +- src/distribution/categorical.rs | 2 +- src/distribution/cauchy.rs | 2 +- src/distribution/chi.rs | 2 +- src/distribution/chi_squared.rs | 2 +- src/distribution/erlang.rs | 2 +- src/distribution/exponential.rs | 2 +- src/distribution/fisher_snedecor.rs | 2 +- src/distribution/gamma.rs | 2 +- src/distribution/geometric.rs | 2 +- src/distribution/hypergeometric.rs | 2 +- src/distribution/internal.rs | 2 +- src/distribution/inverse_gamma.rs | 2 +- src/distribution/log_normal.rs | 2 +- src/distribution/negative_binomial.rs | 2 +- src/distribution/normal.rs | 2 +- src/distribution/pareto.rs | 2 +- src/distribution/poisson.rs | 2 +- src/distribution/students_t.rs | 2 +- src/distribution/triangular.rs | 2 +- src/distribution/uniform.rs | 2 +- src/distribution/weibull.rs | 2 +- src/lib.rs | 2 -- 25 files changed, 23 insertions(+), 28 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a8b6e4ca..7bd5c687 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,9 +17,6 @@ include = ["CHANGELOG.md", "LICENSE.md", "src/"] name = "statrs" path = "src/lib.rs" -[features] -nightly = [] - [dependencies] rand = "0.8" nalgebra = { version = "0.32", features = ["rand"] } diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index 307f1010..53edb161 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -402,7 +402,7 @@ impl Continuous for Beta { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use super::*; use crate::consts::ACC; diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index f65bf246..8aeab96d 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -321,7 +321,7 @@ impl Discrete for Binomial { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use std::fmt::Debug; use crate::statistics::*; diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index dfd8a4ca..7b2d489e 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -344,7 +344,7 @@ fn test_binary_index() { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use std::fmt::Debug; use crate::statistics::*; diff --git a/src/distribution/cauchy.rs b/src/distribution/cauchy.rs index 3af0ca63..6da74f14 100644 --- a/src/distribution/cauchy.rs +++ b/src/distribution/cauchy.rs @@ -226,7 +226,7 @@ impl Continuous for Cauchy { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Cauchy}; diff --git a/src/distribution/chi.rs b/src/distribution/chi.rs index 2ca32518..1d520567 100644 --- a/src/distribution/chi.rs +++ b/src/distribution/chi.rs @@ -317,7 +317,7 @@ impl Continuous for Chi { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use std::f64; use crate::distribution::internal::*; diff --git a/src/distribution/chi_squared.rs b/src/distribution/chi_squared.rs index cf07f4cf..65346e8e 100644 --- a/src/distribution/chi_squared.rs +++ b/src/distribution/chi_squared.rs @@ -283,7 +283,7 @@ impl Continuous for ChiSquared { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use crate::statistics::Median; use crate::distribution::ChiSquared; diff --git a/src/distribution/erlang.rs b/src/distribution/erlang.rs index c959e122..8affd534 100644 --- a/src/distribution/erlang.rs +++ b/src/distribution/erlang.rs @@ -270,7 +270,7 @@ impl Continuous for Erlang { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use crate::distribution::Erlang; use crate::distribution::internal::*; diff --git a/src/distribution/exponential.rs b/src/distribution/exponential.rs index e0fe74dc..86e87b0e 100644 --- a/src/distribution/exponential.rs +++ b/src/distribution/exponential.rs @@ -271,7 +271,7 @@ impl Continuous for Exp { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use std::f64; use crate::statistics::*; diff --git a/src/distribution/fisher_snedecor.rs b/src/distribution/fisher_snedecor.rs index c5e2463a..367df2f9 100644 --- a/src/distribution/fisher_snedecor.rs +++ b/src/distribution/fisher_snedecor.rs @@ -355,7 +355,7 @@ impl Continuous for FisherSnedecor { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, FisherSnedecor}; diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index 6e8def6c..a9cba88e 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -398,7 +398,7 @@ pub fn sample_unchecked(rng: &mut R, shape: f64, rate: f64) -> } } -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use super::*; use crate::consts::ACC; diff --git a/src/distribution/geometric.rs b/src/distribution/geometric.rs index d61b7fc8..955534c4 100644 --- a/src/distribution/geometric.rs +++ b/src/distribution/geometric.rs @@ -265,7 +265,7 @@ impl Discrete for Geometric { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use std::fmt::Debug; use crate::statistics::*; diff --git a/src/distribution/hypergeometric.rs b/src/distribution/hypergeometric.rs index 1116ac7a..722629fd 100644 --- a/src/distribution/hypergeometric.rs +++ b/src/distribution/hypergeometric.rs @@ -366,7 +366,7 @@ impl Discrete for Hypergeometric { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use std::fmt::Debug; use crate::statistics::*; diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index 07b04df9..2bda5291 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -55,7 +55,7 @@ pub fn integral_bisection_search( } #[macro_use] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] pub mod test { use super::*; use crate::consts::ACC; diff --git a/src/distribution/inverse_gamma.rs b/src/distribution/inverse_gamma.rs index b55afd64..edf928f7 100644 --- a/src/distribution/inverse_gamma.rs +++ b/src/distribution/inverse_gamma.rs @@ -305,7 +305,7 @@ impl Continuous for InverseGamma { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, InverseGamma}; diff --git a/src/distribution/log_normal.rs b/src/distribution/log_normal.rs index b74986f9..12d08e81 100644 --- a/src/distribution/log_normal.rs +++ b/src/distribution/log_normal.rs @@ -297,7 +297,7 @@ impl Continuous for LogNormal { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, LogNormal}; diff --git a/src/distribution/negative_binomial.rs b/src/distribution/negative_binomial.rs index a924ee8d..327bafd7 100644 --- a/src/distribution/negative_binomial.rs +++ b/src/distribution/negative_binomial.rs @@ -283,7 +283,7 @@ impl Discrete for NegativeBinomial { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use std::fmt::Debug; use crate::statistics::*; diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index 4d540eee..a3bbdbe2 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -320,7 +320,7 @@ impl std::default::Default for Normal { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Normal}; diff --git a/src/distribution/pareto.rs b/src/distribution/pareto.rs index 5a18c03f..768c6944 100644 --- a/src/distribution/pareto.rs +++ b/src/distribution/pareto.rs @@ -328,7 +328,7 @@ impl Continuous for Pareto { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Pareto}; diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index ce07ce96..417dd285 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -296,7 +296,7 @@ pub fn sample_unchecked(rng: &mut R, lambda: f64) -> f64 { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use std::fmt::Debug; use crate::statistics::*; diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index f7448380..70c90c05 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -414,7 +414,7 @@ impl Continuous for StudentsT { } } -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use crate::consts::ACC; use crate::distribution::internal::*; diff --git a/src/distribution/triangular.rs b/src/distribution/triangular.rs index a94bb0bb..22450b3e 100644 --- a/src/distribution/triangular.rs +++ b/src/distribution/triangular.rs @@ -311,7 +311,7 @@ fn sample_unchecked(rng: &mut R, min: f64, max: f64, mode: f64) } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use std::fmt::Debug; use crate::statistics::*; diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index 30b42bc3..6e36a16f 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -255,7 +255,7 @@ impl Continuous for Uniform { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Uniform}; diff --git a/src/distribution/weibull.rs b/src/distribution/weibull.rs index 4b928aa7..10113613 100644 --- a/src/distribution/weibull.rs +++ b/src/distribution/weibull.rs @@ -325,7 +325,7 @@ impl Continuous for Weibull { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Weibull}; diff --git a/src/lib.rs b/src/lib.rs index ad234627..e8b0e7b7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,8 +49,6 @@ #![allow(clippy::many_single_char_names)] #![allow(unused_imports)] #![forbid(unsafe_code)] -#![cfg_attr(all(test, feature = "nightly"), feature(unboxed_closures))] -#![cfg_attr(all(test, feature = "nightly"), feature(fn_traits))] #[macro_use] extern crate approx; From 2b0e5e56bc783435c85565f6894f100a6939cebb Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Sun, 26 May 2024 18:03:56 +0200 Subject: [PATCH 070/185] Make `consts::ACC` import not required for testing_boiler --- src/distribution/beta.rs | 1 - src/distribution/binomial.rs | 1 - src/distribution/categorical.rs | 1 - src/distribution/cauchy.rs | 1 - src/distribution/chi.rs | 1 - src/distribution/chi_squared.rs | 1 - src/distribution/dirac.rs | 1 - src/distribution/dirichlet.rs | 1 - src/distribution/discrete_uniform.rs | 1 - src/distribution/erlang.rs | 1 - src/distribution/exponential.rs | 1 - src/distribution/fisher_snedecor.rs | 1 - src/distribution/gamma.rs | 1 - src/distribution/geometric.rs | 1 - src/distribution/hypergeometric.rs | 1 - src/distribution/internal.rs | 3 +-- src/distribution/inverse_gamma.rs | 1 - src/distribution/log_normal.rs | 1 - src/distribution/multinomial.rs | 1 - src/distribution/multivariate_normal.rs | 1 - src/distribution/negative_binomial.rs | 1 - src/distribution/normal.rs | 1 - src/distribution/pareto.rs | 1 - src/distribution/poisson.rs | 1 - src/distribution/triangular.rs | 1 - src/distribution/uniform.rs | 1 - src/distribution/weibull.rs | 1 - 27 files changed, 1 insertion(+), 28 deletions(-) diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index 53edb161..35be0ee6 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -405,7 +405,6 @@ impl Continuous for Beta { #[cfg(test)] mod tests { use super::*; - use crate::consts::ACC; use super::super::internal::*; use crate::statistics::*; use crate::testing_boiler; diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index 8aeab96d..bd61c8fe 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -327,7 +327,6 @@ mod tests { use crate::statistics::*; use crate::distribution::{DiscreteCDF, Discrete, Binomial}; use crate::distribution::internal::*; - use crate::consts::ACC; fn try_create(p: f64, n: u64) -> Binomial { let n = Binomial::new(p, n); diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index 7b2d489e..606fdd13 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -350,7 +350,6 @@ mod tests { use crate::statistics::*; use crate::distribution::{Categorical, Discrete, DiscreteCDF}; use crate::distribution::internal::*; - use crate::consts::ACC; fn try_create(prob_mass: &[f64]) -> Categorical { let n = Categorical::new(prob_mass); diff --git a/src/distribution/cauchy.rs b/src/distribution/cauchy.rs index 6da74f14..186b7efc 100644 --- a/src/distribution/cauchy.rs +++ b/src/distribution/cauchy.rs @@ -231,7 +231,6 @@ mod tests { use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Cauchy}; use crate::distribution::internal::*; - use crate::consts::ACC; fn try_create(location: f64, scale: f64) -> Cauchy { let n = Cauchy::new(location, scale); diff --git a/src/distribution/chi.rs b/src/distribution/chi.rs index 1d520567..272cb954 100644 --- a/src/distribution/chi.rs +++ b/src/distribution/chi.rs @@ -323,7 +323,6 @@ mod tests { use crate::distribution::internal::*; use crate::distribution::{Chi, Continuous, ContinuousCDF}; use crate::statistics::*; - use crate::consts::ACC; fn try_create(freedom: f64) -> Chi { let n = Chi::new(freedom); diff --git a/src/distribution/chi_squared.rs b/src/distribution/chi_squared.rs index 65346e8e..ddc082ad 100644 --- a/src/distribution/chi_squared.rs +++ b/src/distribution/chi_squared.rs @@ -288,7 +288,6 @@ mod tests { use crate::statistics::Median; use crate::distribution::ChiSquared; use crate::distribution::internal::*; - use crate::consts::ACC; fn try_create(freedom: f64) -> ChiSquared { let n = ChiSquared::new(freedom); diff --git a/src/distribution/dirac.rs b/src/distribution/dirac.rs index c781fa72..4e9b99de 100644 --- a/src/distribution/dirac.rs +++ b/src/distribution/dirac.rs @@ -189,7 +189,6 @@ impl Mode> for Dirac { mod tests { use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Dirac}; - use crate::consts::ACC; fn try_create(v: f64) -> Dirac { let d = Dirac::new(v); diff --git a/src/distribution/dirichlet.rs b/src/distribution/dirichlet.rs index f47a0cfd..d71d168b 100644 --- a/src/distribution/dirichlet.rs +++ b/src/distribution/dirichlet.rs @@ -308,7 +308,6 @@ mod tests { use crate::function::gamma; use crate::statistics::*; use crate::distribution::{Continuous, Dirichlet}; - use crate::consts::ACC; #[test] fn test_is_valid_alpha() { diff --git a/src/distribution/discrete_uniform.rs b/src/distribution/discrete_uniform.rs index a128d9d3..a36929a4 100644 --- a/src/distribution/discrete_uniform.rs +++ b/src/distribution/discrete_uniform.rs @@ -253,7 +253,6 @@ mod tests { use std::fmt::Debug; use crate::statistics::*; use crate::distribution::{DiscreteCDF, Discrete, DiscreteUniform}; - use crate::consts::ACC; fn try_create(min: i64, max: i64) -> DiscreteUniform { let n = DiscreteUniform::new(min, max); diff --git a/src/distribution/erlang.rs b/src/distribution/erlang.rs index 8affd534..4e88c969 100644 --- a/src/distribution/erlang.rs +++ b/src/distribution/erlang.rs @@ -274,7 +274,6 @@ impl Continuous for Erlang { mod tests { use crate::distribution::Erlang; use crate::distribution::internal::*; - use crate::consts::ACC; fn try_create(shape: u64, rate: f64) -> Erlang { let n = Erlang::new(shape, rate); diff --git a/src/distribution/exponential.rs b/src/distribution/exponential.rs index 86e87b0e..b509e058 100644 --- a/src/distribution/exponential.rs +++ b/src/distribution/exponential.rs @@ -277,7 +277,6 @@ mod tests { use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Exp}; use crate::distribution::internal::*; - use crate::consts::ACC; fn try_create(rate: f64) -> Exp { let n = Exp::new(rate); diff --git a/src/distribution/fisher_snedecor.rs b/src/distribution/fisher_snedecor.rs index 367df2f9..74815824 100644 --- a/src/distribution/fisher_snedecor.rs +++ b/src/distribution/fisher_snedecor.rs @@ -360,7 +360,6 @@ mod tests { use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, FisherSnedecor}; use crate::distribution::internal::*; - use crate::consts::ACC; fn try_create(freedom_1: f64, freedom_2: f64) -> FisherSnedecor { let n = FisherSnedecor::new(freedom_1, freedom_2); diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index a9cba88e..e2500174 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -401,7 +401,6 @@ pub fn sample_unchecked(rng: &mut R, shape: f64, rate: f64) -> #[cfg(test)] mod tests { use super::*; - use crate::consts::ACC; use crate::distribution::internal::*; use crate::testing_boiler; diff --git a/src/distribution/geometric.rs b/src/distribution/geometric.rs index 955534c4..0dc7fad8 100644 --- a/src/distribution/geometric.rs +++ b/src/distribution/geometric.rs @@ -271,7 +271,6 @@ mod tests { use crate::statistics::*; use crate::distribution::{DiscreteCDF, Discrete, Geometric}; use crate::distribution::internal::*; - use crate::consts::ACC; fn try_create(p: f64) -> Geometric { let n = Geometric::new(p); diff --git a/src/distribution/hypergeometric.rs b/src/distribution/hypergeometric.rs index 722629fd..21f94326 100644 --- a/src/distribution/hypergeometric.rs +++ b/src/distribution/hypergeometric.rs @@ -372,7 +372,6 @@ mod tests { use crate::statistics::*; use crate::distribution::{DiscreteCDF, Discrete, Hypergeometric}; use crate::distribution::internal::*; - use crate::consts::ACC; fn try_create(population: u64, successes: u64, draws: u64) -> Hypergeometric { let n = Hypergeometric::new(population, successes, draws); diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index 2bda5291..391b7f35 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -58,7 +58,6 @@ pub fn integral_bisection_search( #[cfg(test)] pub mod test { use super::*; - use crate::consts::ACC; use crate::distribution::{Continuous, ContinuousCDF, Discrete, DiscreteCDF}; #[macro_export] @@ -89,7 +88,7 @@ pub mod test { T: ::core::fmt::Debug + ::approx::RelativeEq, { let x = get_value($($arg_name),+, eval); - assert_relative_eq!(expected, x, max_relative = ACC); + assert_relative_eq!(expected, x, max_relative = $crate::consts::ACC); } #[allow(dead_code)] // This is not used by all distributions. diff --git a/src/distribution/inverse_gamma.rs b/src/distribution/inverse_gamma.rs index edf928f7..5dccd8ce 100644 --- a/src/distribution/inverse_gamma.rs +++ b/src/distribution/inverse_gamma.rs @@ -310,7 +310,6 @@ mod tests { use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, InverseGamma}; use crate::distribution::internal::*; - use crate::consts::ACC; fn try_create(shape: f64, rate: f64) -> InverseGamma { let n = InverseGamma::new(shape, rate); diff --git a/src/distribution/log_normal.rs b/src/distribution/log_normal.rs index 12d08e81..46cd1c8a 100644 --- a/src/distribution/log_normal.rs +++ b/src/distribution/log_normal.rs @@ -302,7 +302,6 @@ mod tests { use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, LogNormal}; use crate::distribution::internal::*; - use crate::consts::ACC; fn try_create(mean: f64, std_dev: f64) -> LogNormal { let n = LogNormal::new(mean, std_dev); diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index 31de9e57..f1ac0a8e 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -244,7 +244,6 @@ impl<'a> Discrete<&'a [u64], f64> for Multinomial { // mod tests { // use crate::statistics::*; // use crate::distribution::{Discrete, Multinomial}; -// use crate::consts::ACC; // fn try_create(p: &[f64], n: u64) -> Multinomial { // let dist = Multinomial::new(p, n); diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index e50ac3ac..8e238003 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -246,7 +246,6 @@ impl Continuous, f64> for MultivariateNormal { mod tests { use crate::distribution::{Continuous, MultivariateNormal}; use crate::statistics::*; - use crate::consts::ACC; use core::fmt::Debug; use nalgebra::base::allocator::Allocator; use nalgebra::{ diff --git a/src/distribution/negative_binomial.rs b/src/distribution/negative_binomial.rs index 327bafd7..cca59889 100644 --- a/src/distribution/negative_binomial.rs +++ b/src/distribution/negative_binomial.rs @@ -289,7 +289,6 @@ mod tests { use crate::statistics::*; use crate::distribution::{DiscreteCDF, Discrete, NegativeBinomial}; use crate::distribution::internal::test; - use crate::consts::ACC; fn try_create(r: f64, p: f64) -> NegativeBinomial { let r = NegativeBinomial::new(r, p); diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index a3bbdbe2..8e7b1e1a 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -325,7 +325,6 @@ mod tests { use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Normal}; use crate::distribution::internal::*; - use crate::consts::ACC; fn try_create(mean: f64, std_dev: f64) -> Normal { let n = Normal::new(mean, std_dev); diff --git a/src/distribution/pareto.rs b/src/distribution/pareto.rs index 768c6944..7d6cd1e0 100644 --- a/src/distribution/pareto.rs +++ b/src/distribution/pareto.rs @@ -333,7 +333,6 @@ mod tests { use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Pareto}; use crate::distribution::internal::*; - use crate::consts::ACC; fn try_create(scale: f64, shape: f64) -> Pareto { let p = Pareto::new(scale, shape); diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index 417dd285..6ec943eb 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -302,7 +302,6 @@ mod tests { use crate::statistics::*; use crate::distribution::{DiscreteCDF, Discrete, Poisson}; use crate::distribution::internal::*; - use crate::consts::ACC; fn try_create(lambda: f64) -> Poisson { let n = Poisson::new(lambda); diff --git a/src/distribution/triangular.rs b/src/distribution/triangular.rs index 22450b3e..848c2dbd 100644 --- a/src/distribution/triangular.rs +++ b/src/distribution/triangular.rs @@ -317,7 +317,6 @@ mod tests { use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Triangular}; use crate::distribution::internal::*; - use crate::consts::ACC; fn try_create(min: f64, max: f64, mode: f64) -> Triangular { let n = Triangular::new(min, max, mode); diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index 6e36a16f..0fcc90ea 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -260,7 +260,6 @@ mod tests { use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Uniform}; use crate::distribution::internal::*; - use crate::consts::ACC; fn try_create(min: f64, max: f64) -> Uniform { let n = Uniform::new(min, max); diff --git a/src/distribution/weibull.rs b/src/distribution/weibull.rs index 10113613..3507a062 100644 --- a/src/distribution/weibull.rs +++ b/src/distribution/weibull.rs @@ -330,7 +330,6 @@ mod tests { use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Weibull}; use crate::distribution::internal::*; - use crate::consts::ACC; fn try_create(shape: f64, scale: f64) -> Weibull { let n = Weibull::new(shape, scale); From 0192393dd700962b612f868490c46a20ce334c9e Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Sun, 26 May 2024 11:35:20 +0200 Subject: [PATCH 071/185] Fix some clippy warnings & errors - slice instead of `&Vec<_>` - remove legacy constants - allow number that is coincidentally close to E - use `consts::FRAC_1_PI` instead of literal value --- src/distribution/internal.rs | 2 +- src/distribution/negative_binomial.rs | 2 +- src/distribution/students_t.rs | 8 +++++--- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index 391b7f35..301d4a9e 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -240,7 +240,7 @@ pub mod test { #[test] fn test_integer_bisection() { - fn search(z: usize, data: &Vec) -> Option { + fn search(z: usize, data: &[usize]) -> Option { integral_bisection_search(|idx: &usize| data[*idx], z, 0, data.len() - 1) } diff --git a/src/distribution/negative_binomial.rs b/src/distribution/negative_binomial.rs index cca59889..c8c1e725 100644 --- a/src/distribution/negative_binomial.rs +++ b/src/distribution/negative_binomial.rs @@ -395,7 +395,7 @@ mod tests { let min = |x: NegativeBinomial| x.min(); let max = |x: NegativeBinomial| x.max(); test_case(1.0, 0.5, 0, min); - test_case(1.0, 0.3, std::u64::MAX, max); + test_case(1.0, 0.3, u64::MAX, max); } #[test] diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index 70c90c05..5bbb628b 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -548,7 +548,7 @@ mod tests { #[test] fn test_pdf() { let pdf = |arg: f64| move |x: StudentsT| x.pdf(arg); - test_case(0.0, 1.0, 1.0, 0.318309886183791, pdf(0.0)); + test_case(0.0, 1.0, 1.0, std::f64::consts::FRAC_1_PI, pdf(0.0)); test_case(0.0, 1.0, 1.0, 0.159154943091895, pdf(1.0)); test_case(0.0, 1.0, 1.0, 0.159154943091895, pdf(-1.0)); test_case(0.0, 1.0, 1.0, 0.063661977236758, pdf(2.0)); @@ -768,6 +768,8 @@ mod tests { test(0.9, 011.0, 1.363); test(0.95, 011.0, 1.796); test(0.975, 011.0, 2.201); + // 2.718 is roughly equal to E + #[allow(clippy::approx_constant)] test(0.99, 011.0, 2.718); test(0.995, 011.0, 3.106); test(0.9975, 011.0, 3.497); @@ -1152,12 +1154,12 @@ mod tests { #[test] fn test_inv_cdf_p0() { let d = StudentsT::new(0.0, 1.0, 12.0).unwrap(); - assert_eq!(d.inverse_cdf(0.0), std::f64::NEG_INFINITY); + assert_eq!(d.inverse_cdf(0.0), f64::NEG_INFINITY); } #[test] fn test_inv_cdf_p1() { let d = StudentsT::new(0.0, 1.0, 12.0).unwrap(); - assert_eq!(d.inverse_cdf(1.0), std::f64::INFINITY); + assert_eq!(d.inverse_cdf(1.0), f64::INFINITY); } } From 12f7bf46bf67450c5a408cf59a8667537c3a49ea Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Mon, 27 May 2024 10:09:18 +0200 Subject: [PATCH 072/185] Remove nightly tests from CI --- .github/workflows/test.yml | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 42099b38..5c031691 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -41,20 +41,12 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - toolchain: [stable, nightly] steps: - uses: actions/checkout@v4 - - name: Install Rust ${{ matrix.toolchain }} - uses: dtolnay/rust-toolchain@master - with: - toolchain: ${{ matrix.toolchain }} - - - name: Test nightly feature - if: matrix.toolchain == 'nightly' - run: cargo test --all-targets --features=nightly + - name: Install Rust stable + uses: dtolnay/rust-toolchain@stable - name: Test default features - if: matrix.toolchain != 'nightly' run: cargo test --all-targets From 3e7970e2a2118c5e52276fbbe732ac00d867805f Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+YeungOnion@users.noreply.github.com> Date: Mon, 3 Jun 2024 17:11:41 -0500 Subject: [PATCH 073/185] Create dependabot.yml --- .github/dependabot.yml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 .github/dependabot.yml diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..70f86105 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,16 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file + +version: 2 +updates: + - package-ecosystem: "cargo" + directory: "/" + open-pull-requests-limit: 10 + schedule: + interval: "monthly" + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" From 7fa83365b83b034a26e1970345be4e14dd89fa37 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 5 Jun 2024 11:17:48 +0200 Subject: [PATCH 074/185] Make `prec::almost_eq` a wrapper for `abs_diff_eq` --- src/prec.rs | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/src/prec.rs b/src/prec.rs index 042c8b22..9de39707 100644 --- a/src/prec.rs +++ b/src/prec.rs @@ -1,5 +1,7 @@ //! Provides utility functions for working with floating point precision +use approx::AbsDiffEq; + /// Standard epsilon, maximum relative precision of IEEE 754 double-precision /// floating point numbers (64 bit) e.g. `2^-53` pub const F64_PREC: f64 = 0.00000000000000011102230246251565; @@ -7,23 +9,10 @@ pub const F64_PREC: f64 = 0.00000000000000011102230246251565; /// Default accuracy for `f64`, equivalent to `0.0 * F64_PREC` pub const DEFAULT_F64_ACC: f64 = 0.0000000000000011102230246251565; -/// Returns true if `a` and `b `are within `acc` of each other. -/// If `a` or `b` are infinite, returns `true` only if both are -/// infinite and similarly signed. Always returns `false` if -/// either number is a `NAN`. +/// Compares if two floats are close via `approx::abs_diff_eq` +/// using a maximum absolute difference (epsilon) of `acc`. pub fn almost_eq(a: f64, b: f64, acc: f64) -> bool { - // only true if a and b are infinite with same - // sign - if a.is_infinite() || b.is_infinite() { - return a == b; - } - - // NANs are never equal - if a.is_nan() && b.is_nan() { - return false; - } - - (a - b).abs() < acc + a.abs_diff_eq(&b, acc) } /// Compares if two floats are close via `approx::relative_eq!` From e8e9c61b860241c70f9c71f2e07fbd6dde2cf44f Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Thu, 6 Jun 2024 14:32:05 -0500 Subject: [PATCH 075/185] fix: handle assertions for inf in `prec::almost_eq` --- src/prec.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/prec.rs b/src/prec.rs index 9de39707..a6dd7861 100644 --- a/src/prec.rs +++ b/src/prec.rs @@ -12,6 +12,9 @@ pub const DEFAULT_F64_ACC: f64 = 0.0000000000000011102230246251565; /// Compares if two floats are close via `approx::abs_diff_eq` /// using a maximum absolute difference (epsilon) of `acc`. pub fn almost_eq(a: f64, b: f64, acc: f64) -> bool { + if a.is_infinite() && b.is_infinite() { + return true; + } a.abs_diff_eq(&b, acc) } From af8e2382e6c358333f463ac3f7c0339dce9b49fb Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Fri, 24 May 2024 19:46:40 -0500 Subject: [PATCH 076/185] chore: docstring math should `text` instead of `ignore` --- src/distribution/exponential.rs | 2 +- src/distribution/log_normal.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distribution/exponential.rs b/src/distribution/exponential.rs index b509e058..105cd83e 100644 --- a/src/distribution/exponential.rs +++ b/src/distribution/exponential.rs @@ -114,7 +114,7 @@ impl ContinuousCDF for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// -ln(1 - p) / λ /// ``` /// diff --git a/src/distribution/log_normal.rs b/src/distribution/log_normal.rs index 46cd1c8a..356a6ab3 100644 --- a/src/distribution/log_normal.rs +++ b/src/distribution/log_normal.rs @@ -123,7 +123,7 @@ impl ContinuousCDF for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// μ - σ * sqrt(2) * erfc_inv(2p) /// ``` /// From 673e3cce8a6c8e827d95a3fec52affbbdda5f8b8 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Wed, 29 May 2024 08:51:15 -0500 Subject: [PATCH 077/185] doc: update changelog for merged PRs --- CHANGELOG.md | 84 +++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 63 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b8b7ffc..2992f26f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,49 @@ -Unreleased - - -v0.16.0 +# Changelog +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [0.17.0](https://github.com/statrs-dev/statrs/compare/v0.16.0...v0.17.0) - 2024-05-30 + +### Added +- specializes `inverse_cdf()` for Uniform (#166) +- Add way to get standard normal distribution easily. (#228) +- reject constructing Uniform of infinite support (#218) +- extend `StatsError` for finiteness (#218) +- default implementation of survival function with generics (#179) +- update `MultivariateNormal` API + - construct from nalgebra with `MultivariateNormal::new_from_nalgebra` (#177) + - support `std::vec` vector input in addition to `nalgebra` vectors (#199) + +### Fixed +- Update nalgebra to 0.32 (#187) +- for Gamma with shape<1 there is no mode, returns `None` instead of some negative number (#212) +- fix precision of ::inverse_cdf with some newton raphson steps (#227) + - adds test case from #200 +- fix integer bisection for default implementation of `::inverse_cdf` (#220) + - also add tests from (#185) + +### Other +- Remove "nightly" feature and drop testing requirement for `nightly` (#234) +- Allow some imprecision in specific test case (#215) +- Update CI (#215) + - Check formatting in CI via rustfmt + - Expand CI test job + - Add clippy job to CI +- update README with formatting and adding to "Contributing" (#213) +- Add test asserting that `StatsError` is Sync & Send (#226) +- Rename private struct NonNAN to NonNan (#222) +- Remove `lazy-static` dependency and make FCACHE a proper const (#211) +- crate examples shall be in docstrings instead of README (#213) +- alias `inverse_cdf` as "quantile function" in docs (#213) +- docstrings with math shall be `text` instead of `ignore` (#213) + + + +## [0.16.0] - Adds an `sf` method to the `ContinuousCDF` and `DiscreteCDF` traits - Calculates the survival function (CDF complement) for the distribution. @@ -9,11 +51,11 @@ v0.16.0 - See [PR description](https://github.com/statrs-dev/statrs/pull/172) for in-depth changes - update `nalgebra` to `0.29` -v0.15.0 +## [v0.15.0](https://www.github.com/statrs-dev/statrs/compare/v0.15.0...v0.16.0) - upgrade `nalgebra` to `0.27.1` to avoid RUSTSEC-2021-0070 -v0.14.0 +## [v0.14.0](https://www.github.com/statrs-dev/statrs/compare/v0.14.0...v0.15.0) - upgrade `rand` dependency to `0.8` - fix inaccurate sampling of `Gamma` @@ -27,28 +69,28 @@ v0.14.0 - Moved to dynamic vectors in the MultivariateNormal distribution - Reduced a number of distribution-specific traits into the Distribution and DiscreteDistribution traits -v0.13.0 +## [v0.13.0](https://www.github.com/statrs-dev/statrs/compare/v0.12.0...v0.13.0) - Implemented `MultivariateNormal` distribution (depends on `nalgebra 0.19`) - Implemented `Dirac` distribution - Implemented `Negative Binomial` distribution -v0.12.0 +## [v0.12.0](https://www.github.com/statrs-dev/statrs/compare/v0.11.0...v0.12.0) - upgrade `rand` dependency to `0.7` -v0.11.0 +## [v0.11.0](https://www.github.com/statrs-dev/statrs/compare/v0.10.0...v0.11.0) - upgrade `rand` dependency to `0.6` - Implement `CheckedInverseCDF` and `InverseCDF` for `Normal` distribution -v0.10.0 +## [v0.10.0](https://www.github.com/statrs-dev/statrs/compare/v0.9.0...v0.10.0) - upgrade `rand` dependency to `0.5` - Removes the `Distribution` trait in favor of the `rand::distributions::Distribution` trait - Removed functions deprecated in `0.8.0` (`periodic`, `periodic_custom`, `sinusoidal`, `sinusoidal_custom`) -v0.9.0 +## [v0.9.0](https://www.github.com/statrs-dev/statrs/compare/v0.16.0...v0.17.0) - implemented infinite sequence generator for periodic sequence - implemented infinite sequence generator for sinusoidal sequence @@ -60,7 +102,7 @@ v0.9.0 - Implemented `Entropy` trait for the `Categorical` distribution - Add a `checked_` interface to all distribution methods and functions that may panic -v0.8.0 +## [v0.8.0](https://www.github.com/statrs-dev/statrs/compare/v0.16.0...v0.17.0) - `cdf(x)`, `pdf(x)` and `pmf(x)` now return the correct value instead of panicking when `x` is outside the range of values that the distribution can attain. - Fixed a bug in the `Uniform` distribution implementation where samples were drawn from range `[min, max + 1)` instead of `[min, max]`. The samples are now drawn correctly from the range `[min, max]`. @@ -97,14 +139,14 @@ assert!(x.min().is_nan()); Since the regression affects a very slim edge-case and the fix is very simple, no breaking changes to the `Statistics` API was deemed necessary -v0.7.0 +## [v0.7.0](https://www.github.com/statrs-dev/statrs/compare/v0.6.0...v0.7.0) - Implemented `Categorical` distribution - Implemented `Erlang` distribution - Implemented `Multinomial` distribution - New `InverseCDF` trait for distributions that implement the inverse cdf function -v0.6.0 +## [v0.6.0](https://www.github.com/statrs-dev/statrs/compare/v0.16.0...v0.17.0) - `gamma::gamma_ur`, `gamma::gamma_ui`, `gamma::gamma_lr`, and `gamma::gamma_li` now follow strict gamma function domain, panicking if `a` or `x` are not in `(0, +inf)` - `beta::beta_reg` no longer allows `0.0` for `a` or `b` arguments @@ -135,11 +177,11 @@ v0.6.0 - `Hypergeometric` now implements `Discrete` rather than `Discrete` - `Poisson` now implements `Discrete` rather than `Discrete` -v0.5.1 +## [v0.5.1](https://www.github.com/statrs-dev/statrs/compare/v0.5.0...v0.5.1) - Fixed critical bug in `normal::sample_unchecked` where it was returning `NaN` -v0.5.0 +## [v0.5.0](https://www.github.com/statrs-dev/statrs/compare/v0.4.0...v0.5.0) - Implemented the `logistic::logistic` special function - Implemented the `logistic::logit` special function @@ -154,22 +196,22 @@ v0.5.0 - `Binomial::pdf` and `Binomial::ln_pdf` now panic if `x > n` or `x < 0` - `Bernoulli::pdf` and `Bernoulli::ln_pdf` now panic if `x > 1` or `x < 0` -v0.4.0 +## [v0.4.0] - Implemented the `exponential::integral` special function - Implemented the `Cauchy` (otherwise known as the `Lorenz`) distribution - Implemented the `Dirichlet` distribution - `Continuous` and `Discrete` traits no longer dependent on `Distribution` trait -v0.3.2 +## [v0.3.2] - Implemented the `FisherSnedecor` (F) distribution -v0.3.1 +## [v0.3.1] - Removed print statements from `ln_pdf` method in `Beta` distribution -v0.3.0 +## [v0.3.0] - Moved methods `min` and `max` out of trait `Univariate` into their own respective traits `Min` and `Max` - Traits `Min`, `Max`, `Mean`, `Variance`, `Entropy`, `Skewness`, `Median`, and `Mode` moved from `distribution` module to `statistics` module @@ -184,7 +226,7 @@ v0.3.0 - `InplaceStatistics` renamed to `OrderStatistics`, all methods in `InplaceStatistics` have `_inplace` trimmed from method name. - Inverse DiGamma function implemented with signature `gamma::inv_digamma(x: f64) -> f64` -v0.2.0 +## [v0.2.0] - Created `statistics` module and `Statistics` trait - `Statistics` trait implementation for `[f64]` From 09939ce21b1d40b7d3d1224020ef5dfe33f6c60c Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Wed, 29 May 2024 09:13:23 -0500 Subject: [PATCH 078/185] release: 0.17.0 doc: remove codecov badge before release to crates.io badge is pending its correctness, see PR #229 for progress --- Cargo.toml | 2 +- README.md | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7bd5c687..4c99ae1e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "statrs" -version = "0.16.0" +version = "0.17.0" authors = ["Michael Ma"] description = "Statistical computing library for Rust" license = "MIT" diff --git a/README.md b/README.md index 8b35a01e..4cd36567 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,6 @@ [![MIT licensed][license-badge]](./LICENSE.md) [![Crate][crates-badge]][crates-url] [![docs.rs](https://img.shields.io/docsrs/statrs)][docs-url] -[![codecov][codecov-badge]][codecov-url] [actions-test-badge]: https://github.com/statrs-dev/statrs/actions/workflows/test.yml/badge.svg [crates-badge]: https://img.shields.io/crates/v/statrs.svg From 885a0986bad9d09c7429d9b461f0f666327a82fd Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Fri, 7 Jun 2024 18:41:39 -0500 Subject: [PATCH 079/185] fix: handle signed infinty in `prec::almost_eq` --- src/prec.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/prec.rs b/src/prec.rs index a6dd7861..fc9a5836 100644 --- a/src/prec.rs +++ b/src/prec.rs @@ -13,7 +13,7 @@ pub const DEFAULT_F64_ACC: f64 = 0.0000000000000011102230246251565; /// using a maximum absolute difference (epsilon) of `acc`. pub fn almost_eq(a: f64, b: f64, acc: f64) -> bool { if a.is_infinite() && b.is_infinite() { - return true; + return a == b; } a.abs_diff_eq(&b, acc) } From 4523efc20720da4f7b13c051372970da5c69f588 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Fri, 7 Jun 2024 23:49:50 -0500 Subject: [PATCH 080/185] release(bench): removing benches from package requires remove benches from manifest If we find a consistent MSRV approach for dev vs lib depends, then we will certainly distribute the benches with the crate. --- Cargo.toml | 7 ------- 1 file changed, 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4c99ae1e..16e13b1b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,10 +22,3 @@ rand = "0.8" nalgebra = { version = "0.32", features = ["rand"] } approx = "0.5.0" num-traits = "0.2.14" - -[dev-dependencies] -criterion = "0.3.3" - -[[bench]] -name = "order_statistics" -harness = false From 74edc70a20f42d5b5c540d9cfadfa235d9474a89 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sat, 8 Jun 2024 00:00:06 -0500 Subject: [PATCH 081/185] fix(bench): code in benches still needs criterion --- Cargo.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index 16e13b1b..94277e09 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,3 +22,6 @@ rand = "0.8" nalgebra = { version = "0.32", features = ["rand"] } approx = "0.5.0" num-traits = "0.2.14" + +[dev-dependencies] +criterion = "0.3.3" From 41555a1201db895483b18796b110b8087b29f2e0 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sat, 8 Jun 2024 00:03:11 -0500 Subject: [PATCH 082/185] chore: Release statrs version 0.17.1 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 94277e09..9823299f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "statrs" -version = "0.17.0" +version = "0.17.1" authors = ["Michael Ma"] description = "Statistical computing library for Rust" license = "MIT" From c0b7644defc622f57757ab93a1ac6bccee893a89 Mon Sep 17 00:00:00 2001 From: alimf17 Date: Sun, 9 Jun 2024 12:31:48 -0400 Subject: [PATCH 083/185] feat: Added in the std_dev method from the Distribution trait explicitly It simply wastes computation to square and then square root a value which is simply inherently part of the distribution definition. --- src/distribution/normal.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index 8e7b1e1a..48436955 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -193,6 +193,13 @@ impl Distribution for Normal { Some(self.std_dev * self.std_dev) } + /// Returns the standard deviation of the normal distribution + /// # Remarks + /// This is the same standard deviation used to construct the distribution + fn std_dev(&self) -> Option { + Some(self.std_dev) + } + /// Returns the entropy of the normal distribution /// /// # Formula From 5693f70eae262c9b57f15b040a415943300ded96 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Wed, 29 May 2024 12:54:40 -0500 Subject: [PATCH 084/185] feat(c-common-traits): derive common traits for public structs feat: implement `Display` for `distribution` mod feat: implement `Display` for generated sequences feat: implement `Display` for `statistics::Data` --- src/distribution/bernoulli.rs | 8 +++- src/distribution/beta.rs | 8 +++- src/distribution/binomial.rs | 8 +++- src/distribution/categorical.rs | 8 +++- src/distribution/cauchy.rs | 8 +++- src/distribution/chi.rs | 8 +++- src/distribution/chi_squared.rs | 8 +++- src/distribution/dirac.rs | 6 +++ src/distribution/dirichlet.rs | 8 +++- src/distribution/discrete_uniform.rs | 8 +++- src/distribution/empirical.rs | 59 ++++++++++++++++++++++++- src/distribution/erlang.rs | 8 +++- src/distribution/exponential.rs | 8 +++- src/distribution/fisher_snedecor.rs | 6 +++ src/distribution/gamma.rs | 8 +++- src/distribution/geometric.rs | 8 +++- src/distribution/hypergeometric.rs | 14 +++++- src/distribution/inverse_gamma.rs | 8 +++- src/distribution/laplace.rs | 8 +++- src/distribution/log_normal.rs | 8 +++- src/distribution/multinomial.rs | 6 +++ src/distribution/multivariate_normal.rs | 8 +++- src/distribution/negative_binomial.rs | 8 +++- src/distribution/normal.rs | 9 +++- src/distribution/pareto.rs | 8 +++- src/distribution/poisson.rs | 8 +++- src/distribution/students_t.rs | 8 +++- src/distribution/triangular.rs | 8 +++- src/distribution/uniform.rs | 6 +++ src/distribution/weibull.rs | 8 +++- src/generate.rs | 35 +++++++++++++++ src/statistics/slice_statistics.rs | 27 ++++++++++- src/statistics/statistics.rs | 2 +- 33 files changed, 323 insertions(+), 31 deletions(-) diff --git a/src/distribution/bernoulli.rs b/src/distribution/bernoulli.rs index d0b6e219..61499ebd 100644 --- a/src/distribution/bernoulli.rs +++ b/src/distribution/bernoulli.rs @@ -20,7 +20,7 @@ use rand::Rng; /// assert_eq!(n.pmf(0), 0.5); /// assert_eq!(n.pmf(1), 0.5); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Bernoulli { b: Binomial, } @@ -80,6 +80,12 @@ impl Bernoulli { } } +impl std::fmt::Display for Bernoulli { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Bernoulli({})", self.p()) + } +} + impl ::rand::distributions::Distribution for Bernoulli { fn sample(&self, rng: &mut R) -> f64 { rng.gen_bool(self.p()) as u8 as f64 diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index 35be0ee6..b3cfdcf4 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -19,7 +19,7 @@ use rand::Rng; /// assert_eq!(n.mean().unwrap(), 0.5); /// assert!(prec::almost_eq(n.pdf(0.5), 1.5, 1e-14)); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Beta { shape_a: f64, shape_b: f64, @@ -86,6 +86,12 @@ impl Beta { } } +impl std::fmt::Display for Beta { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Beta(a={}, b={})", self.shape_a, self.shape_b) + } +} + impl ::rand::distributions::Distribution for Beta { fn sample(&self, rng: &mut R) -> f64 { // Generated by sampling two gamma distributions and normalizing. diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index bd61c8fe..2b56e6fc 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -21,7 +21,7 @@ use std::f64; /// assert_eq!(n.pmf(0), 0.03125); /// assert_eq!(n.pmf(3), 0.3125); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Binomial { p: f64, n: u64, @@ -87,6 +87,12 @@ impl Binomial { } } +impl std::fmt::Display for Binomial { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Bin({},{})", self.p, self.n) + } +} + impl ::rand::distributions::Distribution for Binomial { fn sample(&self, rng: &mut R) -> f64 { (0..self.n).fold(0.0, |acc, _| { diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index 606fdd13..31bccf8b 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -21,7 +21,7 @@ use std::f64; /// assert!(prec::almost_eq(n.mean().unwrap(), 5.0 / 3.0, 1e-15)); /// assert_eq!(n.pmf(1), 1.0 / 3.0); /// ``` -#[derive(Debug, Clone, PartialEq)] +#[derive(Clone, PartialEq, Debug)] pub struct Categorical { norm_pmf: Vec, cdf: Vec, @@ -77,6 +77,12 @@ impl Categorical { } } +impl std::fmt::Display for Categorical { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Cat({:#?})", self.norm_pmf) + } +} + impl ::rand::distributions::Distribution for Categorical { fn sample(&self, rng: &mut R) -> f64 { sample_unchecked(rng, &self.cdf) diff --git a/src/distribution/cauchy.rs b/src/distribution/cauchy.rs index 186b7efc..b3f40bda 100644 --- a/src/distribution/cauchy.rs +++ b/src/distribution/cauchy.rs @@ -17,7 +17,7 @@ use std::f64; /// assert_eq!(n.mode().unwrap(), 0.0); /// assert_eq!(n.pdf(1.0), 0.1591549430918953357689); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Cauchy { location: f64, scale: f64, @@ -79,6 +79,12 @@ impl Cauchy { } } +impl std::fmt::Display for Cauchy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Cauchy({}, {})", self.location, self.scale) + } +} + impl ::rand::distributions::Distribution for Cauchy { fn sample(&self, r: &mut R) -> f64 { self.location + self.scale * (f64::consts::PI * (r.gen::() - 0.5)).tan() diff --git a/src/distribution/chi.rs b/src/distribution/chi.rs index 272cb954..205cca11 100644 --- a/src/distribution/chi.rs +++ b/src/distribution/chi.rs @@ -19,7 +19,7 @@ use std::f64; /// assert!(prec::almost_eq(n.mean().unwrap(), 1.25331413731550025121, 1e-14)); /// assert!(prec::almost_eq(n.pdf(1.0), 0.60653065971263342360, 1e-15)); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Chi { freedom: f64, } @@ -68,6 +68,12 @@ impl Chi { } } +impl std::fmt::Display for Chi { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "χ_{}", self.freedom) + } +} + impl ::rand::distributions::Distribution for Chi { fn sample(&self, rng: &mut R) -> f64 { (0..self.freedom as i64) diff --git a/src/distribution/chi_squared.rs b/src/distribution/chi_squared.rs index ddc082ad..1c6b42b0 100644 --- a/src/distribution/chi_squared.rs +++ b/src/distribution/chi_squared.rs @@ -21,7 +21,7 @@ use std::f64; /// assert_eq!(n.mean().unwrap(), 3.0); /// assert!(prec::almost_eq(n.pdf(4.0), 0.107981933026376103901, 1e-15)); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct ChiSquared { freedom: f64, g: Gamma, @@ -96,6 +96,12 @@ impl ChiSquared { } } +impl std::fmt::Display for ChiSquared { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "χ^2_{}", self.freedom) + } +} + impl ::rand::distributions::Distribution for ChiSquared { fn sample(&self, r: &mut R) -> f64 { ::rand::distributions::Distribution::sample(&self.g, r) diff --git a/src/distribution/dirac.rs b/src/distribution/dirac.rs index 4e9b99de..4fa6a390 100644 --- a/src/distribution/dirac.rs +++ b/src/distribution/dirac.rs @@ -45,6 +45,12 @@ impl Dirac { } } +impl std::fmt::Display for Dirac { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "δ_{}", self.0) + } +} + impl ::rand::distributions::Distribution for Dirac { fn sample(&self, _: &mut R) -> f64 { self.0 diff --git a/src/distribution/dirichlet.rs b/src/distribution/dirichlet.rs index d71d168b..55795a7a 100644 --- a/src/distribution/dirichlet.rs +++ b/src/distribution/dirichlet.rs @@ -26,7 +26,7 @@ use std::f64; /// assert_eq!(n.mean().unwrap(), DVector::from_vec(vec![1.0 / 6.0, 1.0 / 3.0, 0.5])); /// assert_eq!(n.pdf(&DVector::from_vec(vec![0.33333, 0.33333, 0.33333])), 2.222155556222205); /// ``` -#[derive(Debug, Clone, PartialEq)] +#[derive(Clone, PartialEq, Debug)] pub struct Dirichlet { alpha: DVector, } @@ -137,6 +137,12 @@ impl Dirichlet { } } +impl std::fmt::Display for Dirichlet { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Dir({}, {})", self.alpha.len(), &self.alpha) + } +} + impl ::rand::distributions::Distribution> for Dirichlet { fn sample(&self, rng: &mut R) -> DVector { let mut sum = 0.0; diff --git a/src/distribution/discrete_uniform.rs b/src/distribution/discrete_uniform.rs index a36929a4..361cadd8 100644 --- a/src/distribution/discrete_uniform.rs +++ b/src/distribution/discrete_uniform.rs @@ -17,7 +17,7 @@ use rand::Rng; /// assert_eq!(n.mean().unwrap(), 2.5); /// assert_eq!(n.pmf(3), 1.0 / 6.0); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct DiscreteUniform { min: i64, max: i64, @@ -51,6 +51,12 @@ impl DiscreteUniform { } } +impl std::fmt::Display for DiscreteUniform { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Uni([{}, {}])", self.min, self.max) + } +} + impl ::rand::distributions::Distribution for DiscreteUniform { fn sample(&self, rng: &mut R) -> f64 { rng.gen_range(self.min..=self.max) as f64 diff --git a/src/distribution/empirical.rs b/src/distribution/empirical.rs index b22b78be..9afd7022 100644 --- a/src/distribution/empirical.rs +++ b/src/distribution/empirical.rs @@ -6,7 +6,7 @@ use core::cmp::Ordering; use rand::Rng; use std::collections::BTreeMap; -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, PartialEq, Debug)] struct NonNan(T); impl Eq for NonNan {} @@ -37,7 +37,7 @@ impl Ord for NonNan { /// let empirical = Empirical::from_vec(samples); /// assert_eq!(empirical.mean().unwrap(), 5.0); /// ``` -#[derive(Debug, Clone, PartialEq)] +#[derive(Clone, PartialEq, Debug)] pub struct Empirical { sum: f64, mean_and_var: Option<(f64, f64)>, @@ -151,6 +151,30 @@ impl Empirical { } } +impl std::fmt::Display for Empirical { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some((&NonNan(x), _)) = self.data.first_key_value() { + write!(f, "Empirical([{:.3e}", x)?; + } else { + return write!(f, "Empirical(∅)"); + } + + let mut enumerated_values = self + .data + .iter() + .flat_map(|(&NonNan(x), &count)| std::iter::repeat(x).take(count as usize)) + .skip(1); + + for val in enumerated_values.by_ref().take(4) { + write!(f, ", {:.3e}", val)?; + } + if enumerated_values.next().is_some() { + write!(f, ", ...")?; + } + write!(f, "])") + } +} + impl ::rand::distributions::Distribution for Empirical { fn sample(&self, rng: &mut R) -> f64 { let uniform = Uniform::new(0.0, 1.0).unwrap(); @@ -204,6 +228,10 @@ impl ContinuousCDF for Empirical { } sum as f64 / self.sum } + + fn inverse_cdf(&self, p: f64) -> f64 { + self.__inverse_cdf(p) + } } #[cfg(test)] @@ -264,4 +292,31 @@ mod tests { // due to the mean and variance being calculated in a streaming way assert_eq!(unchanged, empirical); } + + #[test] + fn test_display() { + let mut e = Empirical::new().unwrap(); + assert_eq!(e.to_string(), "Empirical(∅)"); + e.add(1.0); + assert_eq!(e.to_string(), "Empirical([1.000e0])"); + e.add(1.0); + assert_eq!(e.to_string(), "Empirical([1.000e0, 1.000e0])"); + e.add(2.0); + assert_eq!(e.to_string(), "Empirical([1.000e0, 1.000e0, 2.000e0])"); + e.add(2.0); + assert_eq!( + e.to_string(), + "Empirical([1.000e0, 1.000e0, 2.000e0, 2.000e0])" + ); + e.add(5.0); + assert_eq!( + e.to_string(), + "Empirical([1.000e0, 1.000e0, 2.000e0, 2.000e0, 5.000e0])" + ); + e.add(5.0); + assert_eq!( + e.to_string(), + "Empirical([1.000e0, 1.000e0, 2.000e0, 2.000e0, 5.000e0, ...])" + ); + } } diff --git a/src/distribution/erlang.rs b/src/distribution/erlang.rs index 4e88c969..1213baef 100644 --- a/src/distribution/erlang.rs +++ b/src/distribution/erlang.rs @@ -20,7 +20,7 @@ use rand::Rng; /// assert_eq!(n.mean().unwrap(), 3.0); /// assert!(prec::almost_eq(n.pdf(2.0), 0.270670566473225383788, 1e-15)); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Erlang { g: Gamma, } @@ -78,6 +78,12 @@ impl Erlang { } } +impl std::fmt::Display for Erlang { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "E({}, {})", self.rate(), self.shape()) + } +} + impl ::rand::distributions::Distribution for Erlang { fn sample(&self, rng: &mut R) -> f64 { ::rand::distributions::Distribution::sample(&self.g, rng) diff --git a/src/distribution/exponential.rs b/src/distribution/exponential.rs index 105cd83e..978ae638 100644 --- a/src/distribution/exponential.rs +++ b/src/distribution/exponential.rs @@ -20,7 +20,7 @@ use std::f64; /// assert_eq!(n.mean().unwrap(), 1.0); /// assert_eq!(n.pdf(1.0), 0.3678794411714423215955); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Exp { rate: f64, } @@ -67,6 +67,12 @@ impl Exp { } } +impl std::fmt::Display for Exp { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Exp({})", self.rate) + } +} + impl ::rand::distributions::Distribution for Exp { fn sample(&self, r: &mut R) -> f64 { ziggurat::sample_exp_1(r) / self.rate diff --git a/src/distribution/fisher_snedecor.rs b/src/distribution/fisher_snedecor.rs index 74815824..d54a1bef 100644 --- a/src/distribution/fisher_snedecor.rs +++ b/src/distribution/fisher_snedecor.rs @@ -89,6 +89,12 @@ impl FisherSnedecor { } } +impl std::fmt::Display for FisherSnedecor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "F({},{})", self.freedom_1, self.freedom_2) + } +} + impl ::rand::distributions::Distribution for FisherSnedecor { fn sample(&self, rng: &mut R) -> f64 { (super::gamma::sample_unchecked(rng, self.freedom_1 / 2.0, 0.5) * self.freedom_2) diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index e2500174..166ebb72 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -19,7 +19,7 @@ use rand::Rng; /// assert_eq!(n.mean().unwrap(), 3.0); /// assert!(prec::almost_eq(n.pdf(2.0), 0.270670566473225383788, 1e-15)); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Gamma { shape: f64, rate: f64, @@ -86,6 +86,12 @@ impl Gamma { } } +impl std::fmt::Display for Gamma { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Γ({}, {})", self.shape, self.rate) + } +} + impl ::rand::distributions::Distribution for Gamma { fn sample(&self, rng: &mut R) -> f64 { sample_unchecked(rng, self.shape, self.rate) diff --git a/src/distribution/geometric.rs b/src/distribution/geometric.rs index 0dc7fad8..4df623ed 100644 --- a/src/distribution/geometric.rs +++ b/src/distribution/geometric.rs @@ -20,7 +20,7 @@ use std::f64; /// assert_eq!(n.pmf(1), 0.3); /// assert_eq!(n.pmf(2), 0.21); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Geometric { p: f64, } @@ -68,6 +68,12 @@ impl Geometric { } } +impl std::fmt::Display for Geometric { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Geom({})", self.p) + } +} + impl ::rand::distributions::Distribution for Geometric { fn sample(&self, r: &mut R) -> f64 { if ulps_eq!(self.p, 1.0) { diff --git a/src/distribution/hypergeometric.rs b/src/distribution/hypergeometric.rs index 21f94326..8b6d8500 100644 --- a/src/distribution/hypergeometric.rs +++ b/src/distribution/hypergeometric.rs @@ -14,7 +14,7 @@ use std::f64; /// /// ``` /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Eq, Debug)] pub struct Hypergeometric { population: u64, successes: u64, @@ -110,6 +110,16 @@ impl Hypergeometric { } } +impl std::fmt::Display for Hypergeometric { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Hypergeometric({},{},{})", + self.population, self.successes, self.draws + ) + } +} + impl ::rand::distributions::Distribution for Hypergeometric { fn sample(&self, rng: &mut R) -> f64 { let mut population = self.population as f64; @@ -193,7 +203,7 @@ impl DiscreteCDF for Hypergeometric { } else { let k = x; let ln_denom = factorial::ln_binomial(self.population, self.draws); - (k + 1..self.max() + 1).fold(0.0, |acc, i| { + (k + 1..=self.max()).fold(0.0, |acc, i| { acc + (factorial::ln_binomial(self.successes, i) + factorial::ln_binomial(self.population - self.successes, self.draws - i) - ln_denom) diff --git a/src/distribution/inverse_gamma.rs b/src/distribution/inverse_gamma.rs index 5dccd8ce..d22d2239 100644 --- a/src/distribution/inverse_gamma.rs +++ b/src/distribution/inverse_gamma.rs @@ -20,7 +20,7 @@ use std::f64; /// assert!(prec::almost_eq(n.mean().unwrap(), 1.0, 1e-14)); /// assert_eq!(n.pdf(1.0), 0.07554920138253064); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct InverseGamma { shape: f64, rate: f64, @@ -87,6 +87,12 @@ impl InverseGamma { } } +impl std::fmt::Display for InverseGamma { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Inv-Gamma({}, {})", self.shape, self.rate) + } +} + impl ::rand::distributions::Distribution for InverseGamma { fn sample(&self, r: &mut R) -> f64 { 1.0 / super::gamma::sample_unchecked(r, self.shape, self.rate) diff --git a/src/distribution/laplace.rs b/src/distribution/laplace.rs index bcaaae08..1ed74132 100644 --- a/src/distribution/laplace.rs +++ b/src/distribution/laplace.rs @@ -17,7 +17,7 @@ use std::f64; /// assert_eq!(n.mode().unwrap(), 0.0); /// assert_eq!(n.pdf(1.0), 0.18393972058572117); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Laplace { location: f64, scale: f64, @@ -79,6 +79,12 @@ impl Laplace { } } +impl std::fmt::Display for Laplace { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Laplace({}, {})", self.location, self.scale) + } +} + impl ::rand::distributions::Distribution for Laplace { fn sample(&self, rng: &mut R) -> f64 { let x: f64 = rng.gen_range(-0.5..0.5); diff --git a/src/distribution/log_normal.rs b/src/distribution/log_normal.rs index 356a6ab3..b6dbff6f 100644 --- a/src/distribution/log_normal.rs +++ b/src/distribution/log_normal.rs @@ -20,7 +20,7 @@ use std::f64; /// assert_eq!(n.mean().unwrap(), (0.5f64).exp()); /// assert!(prec::almost_eq(n.pdf(1.0), 0.3989422804014326779399, 1e-16)); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct LogNormal { location: f64, scale: f64, @@ -55,6 +55,12 @@ impl LogNormal { } } +impl std::fmt::Display for LogNormal { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "LogNormal({}, {}^2)", self.location, self.scale) + } +} + impl ::rand::distributions::Distribution for LogNormal { fn sample(&self, rng: &mut R) -> f64 { super::normal::sample_unchecked(rng, self.location, self.scale).exp() diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index f1ac0a8e..dd17d2f0 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -90,6 +90,12 @@ impl Multinomial { } } +impl std::fmt::Display for Multinomial { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Multinom({:#?},{})", self.p, self.n) + } +} + impl ::rand::distributions::Distribution> for Multinomial { fn sample(&self, rng: &mut R) -> Vec { let p_cdf = super::categorical::prob_mass_to_cdf(self.p()); diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 8e238003..0f16b639 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -26,7 +26,7 @@ use std::f64::consts::{E, PI}; /// assert_eq!(mvn.variance().unwrap(), DMatrix::from_vec(2, 2, vec![1., 0., 0., 1.])); /// assert_eq!(mvn.pdf(&DVector::from_vec(vec![1., 1.])), 0.05854983152431917); /// ``` -#[derive(Debug, Clone, PartialEq)] +#[derive(Clone, PartialEq, Debug)] pub struct MultivariateNormal { dim: usize, cov_chol_decomp: DMatrix, @@ -113,6 +113,12 @@ impl MultivariateNormal { } } +impl std::fmt::Display for MultivariateNormal { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "N({}, {})", &self.mu, &self.cov) + } +} + impl ::rand::distributions::Distribution> for MultivariateNormal { /// Samples from the multivariate normal distribution /// diff --git a/src/distribution/negative_binomial.rs b/src/distribution/negative_binomial.rs index c8c1e725..065c2239 100644 --- a/src/distribution/negative_binomial.rs +++ b/src/distribution/negative_binomial.rs @@ -35,7 +35,7 @@ use std::f64; /// assert!(almost_eq(r.pmf(0), 0.0625, 1e-8)); /// assert!(almost_eq(r.pmf(3), 0.15625, 1e-8)); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct NegativeBinomial { r: f64, p: f64, @@ -104,6 +104,12 @@ impl NegativeBinomial { } } +impl std::fmt::Display for NegativeBinomial { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "NB({},{})", self.r, self.p) + } +} + impl ::rand::distributions::Distribution for NegativeBinomial { fn sample(&self, r: &mut R) -> u64 { let lambda = distribution::gamma::sample_unchecked(r, self.r, (1.0 - self.p) / self.p); diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index 48436955..94e8c6b6 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -18,7 +18,7 @@ use std::f64; /// assert_eq!(n.mean().unwrap(), 0.0); /// assert_eq!(n.pdf(1.0), 0.2419707245191433497978); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Normal { mean: f64, std_dev: f64, @@ -71,6 +71,12 @@ impl Normal { } } +impl std::fmt::Display for Normal { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "N({},{})", self.mean, self.std_dev) + } +} + impl ::rand::distributions::Distribution for Normal { fn sample(&self, rng: &mut R) -> f64 { sample_unchecked(rng, self.mean, self.std_dev) @@ -317,7 +323,6 @@ pub fn sample_unchecked(rng: &mut R, mean: f64, std_dev: f64) - mean + std_dev * ziggurat::sample_std_normal(rng) } - impl std::default::Default for Normal { /// Returns the standard normal distribution with a mean of 0 /// and a standard deviation of 1. diff --git a/src/distribution/pareto.rs b/src/distribution/pareto.rs index 7d6cd1e0..55df13a5 100644 --- a/src/distribution/pareto.rs +++ b/src/distribution/pareto.rs @@ -19,7 +19,7 @@ use std::f64; /// assert_eq!(p.mean().unwrap(), 2.0); /// assert!(prec::almost_eq(p.pdf(2.0), 0.25, 1e-15)); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Pareto { scale: f64, shape: f64, @@ -83,6 +83,12 @@ impl Pareto { } } +impl std::fmt::Display for Pareto { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Pareto({},{})", self.scale, self.shape) + } +} + impl ::rand::distributions::Distribution for Pareto { fn sample(&self, rng: &mut R) -> f64 { // Inverse transform sampling diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index 6ec943eb..7653ed20 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -19,7 +19,7 @@ use std::f64; /// assert_eq!(n.mean().unwrap(), 1.0); /// assert!(prec::almost_eq(n.pmf(1), 0.367879441171442, 1e-15)); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Poisson { lambda: f64, } @@ -66,6 +66,12 @@ impl Poisson { } } +impl std::fmt::Display for Poisson { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Pois({})", self.lambda) + } +} + impl ::rand::distributions::Distribution for Poisson { /// Generates one sample from the Poisson distribution either by /// Knuth's method if lambda < 30.0 or Rejection method PA by diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index 5bbb628b..af7d7356 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -20,7 +20,7 @@ use std::f64; /// assert_eq!(n.mean().unwrap(), 0.0); /// assert!(prec::almost_eq(n.pdf(0.0), 0.353553390593274, 1e-15)); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct StudentsT { location: f64, scale: f64, @@ -104,6 +104,12 @@ impl StudentsT { } } +impl std::fmt::Display for StudentsT { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "t_{}({},{})", self.freedom, self.location, self.scale) + } +} + impl ::rand::distributions::Distribution for StudentsT { fn sample(&self, r: &mut R) -> f64 { // based on method 2, section 5 in chapter 9 of L. Devroye's diff --git a/src/distribution/triangular.rs b/src/distribution/triangular.rs index 848c2dbd..5fa89b70 100644 --- a/src/distribution/triangular.rs +++ b/src/distribution/triangular.rs @@ -18,7 +18,7 @@ use std::f64; /// assert_eq!(n.mean().unwrap(), 7.5 / 3.0); /// assert_eq!(n.pdf(2.5), 5.0 / 12.5); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Triangular { min: f64, max: f64, @@ -59,6 +59,12 @@ impl Triangular { } } +impl std::fmt::Display for Triangular { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Triangular([{},{}], {})", self.min, self.max, self.mode) + } +} + impl ::rand::distributions::Distribution for Triangular { fn sample(&self, rng: &mut R) -> f64 { sample_unchecked(rng, self.min, self.max, self.mode) diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index 0fcc90ea..9a3478bb 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -63,6 +63,12 @@ impl Uniform { } } +impl std::fmt::Display for Uniform { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Uni([{},{}])", self.min, self.max) + } +} + impl ::rand::distributions::Distribution for Uniform { fn sample(&self, rng: &mut R) -> f64 { let d = RandUniform::new_inclusive(self.min, self.max); diff --git a/src/distribution/weibull.rs b/src/distribution/weibull.rs index 3507a062..49dbc4d8 100644 --- a/src/distribution/weibull.rs +++ b/src/distribution/weibull.rs @@ -21,7 +21,7 @@ use std::f64; /// 0.95135076986687318362924871772654021925505786260884, 1e-15)); /// assert_eq!(n.pdf(1.0), 3.6787944117144232159552377016146086744581113103177); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Weibull { shape: f64, scale: f64, @@ -90,6 +90,12 @@ impl Weibull { } } +impl std::fmt::Display for Weibull { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Weibull({},{})", self.scale, self.shape) + } +} + impl ::rand::distributions::Distribution for Weibull { fn sample(&self, rng: &mut R) -> f64 { let x: f64 = rng.gen(); diff --git a/src/generate.rs b/src/generate.rs index 59783598..e834c6c5 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -29,6 +29,7 @@ pub fn log_spaced(length: usize, start_exp: f64, stop_exp: f64) -> Vec { } /// Infinite iterator returning floats that form a periodic wave +#[derive(Clone, Copy, PartialEq, Debug)] pub struct InfinitePeriodic { amplitude: f64, step: f64, @@ -80,6 +81,12 @@ impl InfinitePeriodic { } } +impl std::fmt::Display for InfinitePeriodic { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:#?}", self) + } +} + impl Iterator for InfinitePeriodic { type Item = f64; @@ -96,6 +103,7 @@ impl Iterator for InfinitePeriodic { } /// Infinite iterator returning floats that form a sinusoidal wave +#[derive(Debug, Clone, Copy, PartialEq)] pub struct InfiniteSinusoidal { amplitude: f64, mean: f64, @@ -159,6 +167,12 @@ impl InfiniteSinusoidal { } } +impl std::fmt::Display for InfiniteSinusoidal { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:#?}", &self) + } +} + impl Iterator for InfiniteSinusoidal { type Item = f64; @@ -175,6 +189,7 @@ impl Iterator for InfiniteSinusoidal { /// Infinite iterator returning floats forming a square wave starting /// with the high phase +#[derive(Debug, Clone, Copy, PartialEq)] pub struct InfiniteSquare { periodic: InfinitePeriodic, high_duration: f64, @@ -212,6 +227,12 @@ impl InfiniteSquare { } } +impl std::fmt::Display for InfiniteSquare { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:#?}", &self) + } +} + impl Iterator for InfiniteSquare { type Item = f64; @@ -228,6 +249,7 @@ impl Iterator for InfiniteSquare { /// Infinite iterator returning floats forming a triangle wave starting with /// the raise phase from the lowest sample +#[derive(Debug, Clone, Copy, PartialEq)] pub struct InfiniteTriangle { periodic: InfinitePeriodic, raise_duration: f64, @@ -278,6 +300,12 @@ impl InfiniteTriangle { } } +impl std::fmt::Display for InfiniteTriangle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:#?}", &self) + } +} + impl Iterator for InfiniteTriangle { type Item = f64; @@ -294,6 +322,7 @@ impl Iterator for InfiniteTriangle { /// Infinite iterator returning floats forming a sawtooth wave /// starting with the lowest sample +#[derive(Debug, Clone, Copy, PartialEq)] pub struct InfiniteSawtooth { periodic: InfinitePeriodic, low_value: f64, @@ -328,6 +357,12 @@ impl InfiniteSawtooth { } } +impl std::fmt::Display for InfiniteSawtooth { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:#?}", &self) + } +} + impl Iterator for InfiniteSawtooth { type Item = f64; diff --git a/src/statistics/slice_statistics.rs b/src/statistics/slice_statistics.rs index a9cbfdde..ea2f3096 100644 --- a/src/statistics/slice_statistics.rs +++ b/src/statistics/slice_statistics.rs @@ -2,9 +2,34 @@ use crate::statistics::*; use core::ops::{Index, IndexMut}; use rand::prelude::SliceRandom; -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)] pub struct Data(D); +impl std::fmt::Display for Data +where + D: Clone + IntoIterator, + I: Clone + std::fmt::Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut tee = self.0.clone().into_iter(); + write!(f, "Data([")?; + + if let Some(v) = tee.next() { + write!(f, "{}", v)?; + } + for _ in 1..5 { + if let Some(v) = tee.next() { + write!(f, ", {}", v)?; + } + } + if tee.next().is_some() { + write!(f, "...")?; + } + + write!(f, "])") + } +} + impl> Index for Data { type Output = f64; diff --git a/src/statistics/statistics.rs b/src/statistics/statistics.rs index 40b91c72..3081791b 100644 --- a/src/statistics/statistics.rs +++ b/src/statistics/statistics.rs @@ -1,6 +1,6 @@ /// Enumeration of possible tie-breaking strategies /// when computing ranks -#[derive(Debug, Copy, Clone)] +#[derive(Copy, Clone, Debug)] pub enum RankTieBreaker { /// Replaces ties with their mean Average, From 1be04dd3703eb4b46408c98afa97fbb54bcc362c Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sun, 16 Jun 2024 11:46:23 -0500 Subject: [PATCH 085/185] chore(errors): do not implement `description` for Error --- src/error.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/error.rs b/src/error.rs index faa74fad..e6f7ca40 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,7 +2,7 @@ use std::error::Error; use std::fmt; /// Enumeration of possible errors thrown within the `statrs` library -#[derive(Debug)] +#[derive(Clone, PartialEq, Debug)] pub enum StatsError { /// Generic bad input parameter error BadParams, @@ -50,11 +50,7 @@ pub enum StatsError { SpecialCase(&'static str), } -impl Error for StatsError { - fn description(&self) -> &str { - "Error performing statistical calculation" - } -} +impl Error for StatsError {} impl fmt::Display for StatsError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { From a33fca3816fad3a0fdeb5584f88362c209e5da5f Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Fri, 24 May 2024 17:41:59 -0500 Subject: [PATCH 086/185] test: treat tests on NIST data as integration tests also modify assert_almost_equal macro to support option single trailing comma --- src/lib.rs | 2 +- src/statistics/iter_statistics.rs | 120 ----------------------- tests/nist_tests.rs | 153 ++++++++++++++++++++++++++++++ 3 files changed, 154 insertions(+), 121 deletions(-) create mode 100644 tests/nist_tests.rs diff --git a/src/lib.rs b/src/lib.rs index e8b0e7b7..bdb5fc51 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -55,7 +55,7 @@ extern crate approx; #[macro_export] macro_rules! assert_almost_eq { - ($a:expr, $b:expr, $prec:expr) => { + ($a:expr, $b:expr, $prec:expr $(,)?) => { if !$crate::prec::almost_eq($a, $b, $prec) { panic!( "assertion failed: `abs(left - right) < {:e}`, (left: `{}`, right: `{}`)", diff --git a/src/statistics/iter_statistics.rs b/src/statistics/iter_statistics.rs index 3a53d835..06941ef6 100644 --- a/src/statistics/iter_statistics.rs +++ b/src/statistics/iter_statistics.rs @@ -252,126 +252,6 @@ mod tests { use crate::generate::{InfinitePeriodic, InfiniteSinusoidal}; use crate::testing; - #[test] - fn test_mean() { - let mut data = testing::load_data("nist/lottery.txt"); - assert_almost_eq!((&data).mean(), 518.958715596330, 1e-12); - - data = testing::load_data("nist/lew.txt"); - assert_almost_eq!((&data).mean(), -177.435000000000, 1e-13); - - data = testing::load_data("nist/mavro.txt"); - assert_almost_eq!((&data).mean(), 2.00185600000000, 1e-15); - - data = testing::load_data("nist/michaelso.txt"); - assert_almost_eq!((&data).mean(), 299.852400000000, 1e-13); - - data = testing::load_data("nist/numacc1.txt"); - assert_eq!((&data).mean(), 10000002.0); - - data = testing::load_data("nist/numacc2.txt"); - assert_almost_eq!((&data).mean(), 1.2, 1e-15); - - data = testing::load_data("nist/numacc3.txt"); - assert_eq!((&data).mean(), 1000000.2); - - data = testing::load_data("nist/numacc4.txt"); - assert_almost_eq!((&data).mean(), 10000000.2, 1e-8); - } - - #[test] - fn test_std_dev() { - let mut data = testing::load_data("nist/lottery.txt"); - assert_almost_eq!((&data).std_dev(), 291.699727470969, 1e-13); - - data = testing::load_data("nist/lew.txt"); - assert_almost_eq!((&data).std_dev(), 277.332168044316, 1e-12); - - data = testing::load_data("nist/mavro.txt"); - assert_almost_eq!((&data).std_dev(), 0.000429123454003053, 1e-15); - - data = testing::load_data("nist/michaelso.txt"); - assert_almost_eq!((&data).std_dev(), 0.0790105478190518, 1e-13); - - data = testing::load_data("nist/numacc1.txt"); - assert_eq!((&data).std_dev(), 1.0); - - data = testing::load_data("nist/numacc2.txt"); - assert_almost_eq!((&data).std_dev(), 0.1, 1e-16); - - data = testing::load_data("nist/numacc3.txt"); - assert_almost_eq!((&data).std_dev(), 0.1, 1e-10); - - data = testing::load_data("nist/numacc4.txt"); - assert_almost_eq!((&data).std_dev(), 0.1, 1e-9); - } - - #[test] - fn test_min_max_short() { - let data = [-1.0, 5.0, 0.0, -3.0, 10.0, -0.5, 4.0]; - assert_eq!(data.min(), -3.0); - assert_eq!(data.max(), 10.0); - } - - #[test] - fn test_mean_variance_stability() { - let seed = [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, - 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 - ]; - let mut rng: StdRng = SeedableRng::from_seed(seed); - let normal = Normal::new(1e9, 2.0).unwrap(); - let samples = (0..10000).map(|_| normal.sample::(&mut rng)).collect::>(); - assert_almost_eq!((&samples).mean(), 1e9, 10.0); - assert_almost_eq!((&samples).variance(), 4.0, 0.1); - assert_almost_eq!((&samples).std_dev(), 2.0, 0.01); - assert_almost_eq!((&samples).quadratic_mean(), 1e9, 10.0); - } - - #[test] - fn test_covariance_consistent_with_variance() { - let mut data = testing::load_data("nist/lottery.txt"); - assert_almost_eq!((&data).variance(), (&data).covariance(&data), 1e-10); - - data = testing::load_data("nist/lew.txt"); - assert_almost_eq!((&data).variance(), (&data).covariance(&data), 1e-10); - - data = testing::load_data("nist/mavro.txt"); - assert_almost_eq!((&data).variance(), (&data).covariance(&data), 1e-10); - - data = testing::load_data("nist/michaelso.txt"); - assert_almost_eq!((&data).variance(), (&data).covariance(&data), 1e-10); - - data = testing::load_data("nist/numacc1.txt"); - assert_almost_eq!((&data).variance(), (&data).covariance(&data), 1e-10); - } - - #[test] - fn test_pop_covar_consistent_with_pop_var() { - let mut data = testing::load_data("nist/lottery.txt"); - assert_almost_eq!((&data).population_variance(), (&data).population_covariance(&data), 1e-10); - - data = testing::load_data("nist/lew.txt"); - assert_almost_eq!((&data).population_variance(), (&data).population_covariance(&data), 1e-10); - - data = testing::load_data("nist/mavro.txt"); - assert_almost_eq!((&data).population_variance(), (&data).population_covariance(&data), 1e-10); - - data = testing::load_data("nist/michaelso.txt"); - assert_almost_eq!((&data).population_variance(), (&data).population_covariance(&data), 1e-10); - - data = testing::load_data("nist/numacc1.txt"); - assert_almost_eq!((&data).population_variance(), (&data).population_covariance(&data), 1e-10); - } - - #[test] - fn test_covariance_is_symmetric() { - let data_a = &testing::load_data("nist/lottery.txt")[0..200]; - let data_b = &testing::load_data("nist/lew.txt")[0..200]; - assert_almost_eq!(data_a.covariance(data_b), data_b.covariance(data_a), 1e-10); - assert_almost_eq!(data_a.population_covariance(data_b), data_b.population_covariance(data_a), 1e-11); - } - #[test] fn test_empty_data_returns_nan() { let data = [0.0; 0]; diff --git a/tests/nist_tests.rs b/tests/nist_tests.rs new file mode 100644 index 00000000..d14f4cde --- /dev/null +++ b/tests/nist_tests.rs @@ -0,0 +1,153 @@ +// #![cfg(test)] +use statrs::assert_almost_eq; +use statrs::statistics::Statistics; +use std::io::{BufRead, BufReader}; +use std::{env, fs}; + +#[cfg(test)] +const NIST_DATA_DIR_ENV: &str = "STATRS_NIST_DATA_DIR"; + +fn load_data(pathname: String) -> Vec { + let f = fs::File::open(pathname).unwrap(); + let mut reader = BufReader::new(f); + + let mut buf = String::new(); + let mut data: Vec = vec![]; + while reader.read_line(&mut buf).unwrap() > 0 { + data.push(buf.trim().parse::().unwrap()); + buf.clear(); + } + data +} + +#[test] +#[ignore = "NIST tests should not run from typical `cargo test` calls"] +fn nist_test_mean() { + let path_dir = env::var(NIST_DATA_DIR_ENV).unwrap(); + let mut data = load_data(dbg!(path_dir.clone() + "lottery.txt")); + assert_almost_eq!((&data).mean(), 518.958715596330, 1e-12); + + data = load_data(dbg!(path_dir.clone() + "lew.txt")); + assert_almost_eq!((&data).mean(), -177.435000000000, 1e-13); + + data = load_data(dbg!(path_dir.clone() + "mavro.txt")); + assert_almost_eq!((&data).mean(), 2.00185600000000, 1e-15); + + data = load_data(dbg!(path_dir.clone() + "michaelso.txt")); + assert_almost_eq!((&data).mean(), 299.852400000000, 1e-13); + + data = load_data(dbg!(path_dir.clone() + "numacc1.txt")); + assert_eq!((&data).mean(), 10000002.0); + + data = load_data(dbg!(path_dir.clone() + "numacc2.txt")); + assert_almost_eq!((&data).mean(), 1.2, 1e-15); + + data = load_data(dbg!(path_dir.clone() + "numacc3.txt")); + assert_eq!((&data).mean(), 1000000.2); + + data = load_data(dbg!(path_dir.clone() + "numacc4.txt")); + assert_almost_eq!((&data).mean(), 10000000.2, 1e-8); +} + +#[test] +#[ignore = "NIST tests should not run from typical `cargo test` calls"] +fn nist_test_std_dev() { + let path_dir = env::var(NIST_DATA_DIR_ENV).unwrap(); + let mut data = load_data(dbg!(path_dir.clone() + "lottery.txt")); + assert_almost_eq!((&data).std_dev(), 291.699727470969, 1e-13); + + data = load_data(dbg!(path_dir.clone() + "lew.txt")); + assert_almost_eq!((&data).std_dev(), 277.332168044316, 1e-12); + + data = load_data(dbg!(path_dir.clone() + "mavro.txt")); + assert_almost_eq!((&data).std_dev(), 0.000429123454003053, 1e-15); + + data = load_data(dbg!(path_dir.clone() + "michaelso.txt")); + assert_almost_eq!((&data).std_dev(), 0.0790105478190518, 1e-13); + + data = load_data(dbg!(path_dir.clone() + "numacc1.txt")); + assert_eq!((&data).std_dev(), 1.0); + + data = load_data(dbg!(path_dir.clone() + "numacc2.txt")); + assert_almost_eq!((&data).std_dev(), 0.1, 1e-16); + + data = load_data(dbg!(path_dir.clone() + "numacc3.txt")); + assert_almost_eq!((&data).std_dev(), 0.1, 1e-10); + + data = load_data(dbg!(path_dir.clone() + "numacc4.txt")); + assert_almost_eq!((&data).std_dev(), 0.1, 1e-9); +} + +#[test] +#[ignore = "NIST tests should not run from typical `cargo test` calls"] +fn nist_test_covariance_consistent_with_variance() { + let path_dir = env::var(NIST_DATA_DIR_ENV).unwrap(); + let mut data = load_data(dbg!(path_dir.clone() + "lottery.txt")); + assert_almost_eq!((&data).variance(), (&data).covariance(&data), 1e-10); + + data = load_data(dbg!(path_dir.clone() + "lew.txt")); + assert_almost_eq!((&data).variance(), (&data).covariance(&data), 1e-10); + + data = load_data(dbg!(path_dir.clone() + "mavro.txt")); + assert_almost_eq!((&data).variance(), (&data).covariance(&data), 1e-10); + + data = load_data(dbg!(path_dir.clone() + "michaelso.txt")); + assert_almost_eq!((&data).variance(), (&data).covariance(&data), 1e-10); + + data = load_data(dbg!(path_dir.clone() + "numacc1.txt")); + assert_almost_eq!((&data).variance(), (&data).covariance(&data), 1e-10); +} + +#[test] +#[ignore = "NIST tests should not run from typical `cargo test` calls"] +fn nist_test_pop_covar_consistent_with_pop_var() { + let path_dir = env::var(NIST_DATA_DIR_ENV).unwrap(); + let mut data = load_data(dbg!(path_dir.clone() + "lottery.txt")); + assert_almost_eq!( + (&data).population_variance(), + (&data).population_covariance(&data), + 1e-10, + ); + + data = load_data(dbg!(path_dir.clone() + "lew.txt")); + assert_almost_eq!( + (&data).population_variance(), + (&data).population_covariance(&data), + 1e-10, + ); + + data = load_data(dbg!(path_dir.clone() + "mavro.txt")); + assert_almost_eq!( + (&data).population_variance(), + (&data).population_covariance(&data), + 1e-10, + ); + + data = load_data(dbg!(path_dir.clone() + "michaelso.txt")); + assert_almost_eq!( + (&data).population_variance(), + (&data).population_covariance(&data), + 1e-10, + ); + + data = load_data(dbg!(path_dir.clone() + "numacc1.txt")); + assert_almost_eq!( + (&data).population_variance(), + (&data).population_covariance(&data), + 1e-10, + ); +} + +#[test] +#[ignore = "NIST tests should not run from typical `cargo test` calls"] +fn nist_test_covariance_is_symmetric() { + let path_dir = env::var(NIST_DATA_DIR_ENV).unwrap(); + let data_a = &load_data(dbg!(path_dir.clone() + "lottery.txt"))[0..200]; + let data_b = &load_data(dbg!(path_dir.clone() + "lew.txt"))[0..200]; + assert_almost_eq!(data_a.covariance(data_b), data_b.covariance(data_a), 1e-10); + assert_almost_eq!( + data_a.population_covariance(data_b), + data_b.population_covariance(data_a), + 1e-11, + ); +} From cca1c834a6a5fd0ba20724998cdacb5280672991 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Fri, 24 May 2024 17:58:27 -0500 Subject: [PATCH 087/185] test: remove preformatted NIST data TODO modify parsing to work with test data from NIST source --- data/nist/lew.txt | 200 -------- data/nist/lottery.txt | 218 --------- data/nist/mavro.txt | 50 -- data/nist/michaelso.txt | 100 ---- data/nist/numacc1.txt | 3 - data/nist/numacc2.txt | 1001 --------------------------------------- data/nist/numacc3.txt | 1001 --------------------------------------- data/nist/numacc4.txt | 1001 --------------------------------------- 8 files changed, 3574 deletions(-) delete mode 100644 data/nist/lew.txt delete mode 100644 data/nist/lottery.txt delete mode 100644 data/nist/mavro.txt delete mode 100644 data/nist/michaelso.txt delete mode 100644 data/nist/numacc1.txt delete mode 100644 data/nist/numacc2.txt delete mode 100644 data/nist/numacc3.txt delete mode 100644 data/nist/numacc4.txt diff --git a/data/nist/lew.txt b/data/nist/lew.txt deleted file mode 100644 index 9e38a720..00000000 --- a/data/nist/lew.txt +++ /dev/null @@ -1,200 +0,0 @@ --213 --564 --35 --15 -141 -115 --420 --360 -203 --338 --431 -194 --220 --513 -154 --125 --559 -92 --21 --579 --52 -99 --543 --175 -162 --457 --346 -204 --300 --474 -164 --107 --572 --8 -83 --541 --224 -180 --420 --374 -201 --236 --531 -83 -27 --564 --112 -131 --507 --254 -199 --311 --495 -143 --46 --579 --90 -136 --472 --338 -202 --287 --477 -169 --124 --568 -17 -48 --568 --135 -162 --430 --422 -172 --74 --577 --13 -92 --534 --243 -194 --355 --465 -156 --81 --578 --64 -139 --449 --384 -193 --198 --538 -110 --44 --577 --6 -66 --552 --164 -161 --460 --344 -205 --281 --504 -134 --28 --576 --118 -156 --437 --381 -200 --220 --540 -83 -11 --568 --160 -172 --414 --408 -188 --125 --572 --32 -139 --492 --321 -205 --262 --504 -142 --83 --574 -0 -48 --571 --106 -137 --501 --266 -190 --391 --406 -194 --186 --553 -83 --13 --577 --49 -103 --515 --280 -201 -300 --506 -131 --45 --578 --80 -138 --462 --361 -201 --211 --554 -32 -74 --533 --235 -187 --372 --442 -182 --147 --566 -25 -68 --535 --244 -194 --351 --463 -174 --125 --570 -15 -72 --550 --190 -172 --424 --385 -198 --218 --536 -96 \ No newline at end of file diff --git a/data/nist/lottery.txt b/data/nist/lottery.txt deleted file mode 100644 index a1880747..00000000 --- a/data/nist/lottery.txt +++ /dev/null @@ -1,218 +0,0 @@ -162 -671 -933 -414 -788 -730 -817 -33 -536 -875 -670 -236 -473 -167 -877 -980 -316 -950 -456 -92 -517 -557 -956 -954 -104 -178 -794 -278 -147 -773 -437 -435 -502 -610 -582 -780 -689 -562 -964 -791 -28 -97 -848 -281 -858 -538 -660 -972 -671 -613 -867 -448 -738 -966 -139 -636 -847 -659 -754 -243 -122 -455 -195 -968 -793 -59 -730 -361 -574 -522 -97 -762 -431 -158 -429 -414 -22 -629 -788 -999 -187 -215 -810 -782 -47 -34 -108 -986 -25 -644 -829 -630 -315 -567 -919 -331 -207 -412 -242 -607 -668 -944 -749 -168 -864 -442 -533 -805 -372 -63 -458 -777 -416 -340 -436 -140 -919 -350 -510 -572 -905 -900 -85 -389 -473 -758 -444 -169 -625 -692 -140 -897 -672 -288 -312 -860 -724 -226 -884 -508 -976 -741 -476 -417 -831 -15 -318 -432 -241 -114 -799 -955 -833 -358 -935 -146 -630 -830 -440 -642 -356 -373 -271 -715 -367 -393 -190 -669 -8 -861 -108 -795 -269 -590 -326 -866 -64 -523 -862 -840 -219 -382 -998 -4 -628 -305 -747 -247 -34 -747 -729 -645 -856 -974 -24 -568 -24 -694 -608 -480 -410 -729 -947 -293 -53 -930 -223 -203 -677 -227 -62 -455 -387 -318 -562 -242 -428 -968 \ No newline at end of file diff --git a/data/nist/mavro.txt b/data/nist/mavro.txt deleted file mode 100644 index b904e6aa..00000000 --- a/data/nist/mavro.txt +++ /dev/null @@ -1,50 +0,0 @@ -2.00180 -2.00170 -2.00180 -2.00190 -2.00180 -2.00170 -2.00150 -2.00140 -2.00150 -2.00150 -2.00170 -2.00180 -2.00180 -2.00190 -2.00190 -2.00210 -2.00200 -2.00160 -2.00140 -2.00130 -2.00130 -2.00150 -2.00150 -2.00160 -2.00150 -2.00140 -2.00130 -2.00140 -2.00150 -2.00140 -2.00150 -2.00160 -2.00150 -2.00160 -2.00190 -2.00200 -2.00200 -2.00210 -2.00220 -2.00230 -2.00240 -2.00250 -2.00270 -2.00260 -2.00260 -2.00260 -2.00270 -2.00260 -2.00250 -2.00240 \ No newline at end of file diff --git a/data/nist/michaelso.txt b/data/nist/michaelso.txt deleted file mode 100644 index 2e436816..00000000 --- a/data/nist/michaelso.txt +++ /dev/null @@ -1,100 +0,0 @@ -299.85 -299.74 -299.90 -300.07 -299.93 -299.85 -299.95 -299.98 -299.98 -299.88 -300.00 -299.98 -299.93 -299.65 -299.76 -299.81 -300.00 -300.00 -299.96 -299.96 -299.96 -299.94 -299.96 -299.94 -299.88 -299.80 -299.85 -299.88 -299.90 -299.84 -299.83 -299.79 -299.81 -299.88 -299.88 -299.83 -299.80 -299.79 -299.76 -299.80 -299.88 -299.88 -299.88 -299.86 -299.72 -299.72 -299.62 -299.86 -299.97 -299.95 -299.88 -299.91 -299.85 -299.87 -299.84 -299.84 -299.85 -299.84 -299.84 -299.84 -299.89 -299.81 -299.81 -299.82 -299.80 -299.77 -299.76 -299.74 -299.75 -299.76 -299.91 -299.92 -299.89 -299.86 -299.88 -299.72 -299.84 -299.85 -299.85 -299.78 -299.89 -299.84 -299.78 -299.81 -299.76 -299.81 -299.79 -299.81 -299.82 -299.85 -299.87 -299.87 -299.81 -299.74 -299.81 -299.94 -299.95 -299.80 -299.81 -299.87 \ No newline at end of file diff --git a/data/nist/numacc1.txt b/data/nist/numacc1.txt deleted file mode 100644 index 79dec5da..00000000 --- a/data/nist/numacc1.txt +++ /dev/null @@ -1,3 +0,0 @@ -10000001 -10000003 -10000002 \ No newline at end of file diff --git a/data/nist/numacc2.txt b/data/nist/numacc2.txt deleted file mode 100644 index 8a345dad..00000000 --- a/data/nist/numacc2.txt +++ /dev/null @@ -1,1001 +0,0 @@ -1.2 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 \ No newline at end of file diff --git a/data/nist/numacc3.txt b/data/nist/numacc3.txt deleted file mode 100644 index c7313205..00000000 --- a/data/nist/numacc3.txt +++ /dev/null @@ -1,1001 +0,0 @@ -1000000.2 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 \ No newline at end of file diff --git a/data/nist/numacc4.txt b/data/nist/numacc4.txt deleted file mode 100644 index 63647051..00000000 --- a/data/nist/numacc4.txt +++ /dev/null @@ -1,1001 +0,0 @@ -10000000.2 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 \ No newline at end of file From 49035ec461d4990b83a7b58ea1f8eebfbc86d4d2 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sun, 26 May 2024 21:32:59 -0500 Subject: [PATCH 088/185] test: drop unused testing method --- src/testing/mod.rs | 31 ------------------------------- 1 file changed, 31 deletions(-) diff --git a/src/testing/mod.rs b/src/testing/mod.rs index 45eb3cab..338bdd2e 100644 --- a/src/testing/mod.rs +++ b/src/testing/mod.rs @@ -1,32 +1 @@ //! Provides testing helpers and utilities - -use std::fs::File; -use std::io::{BufRead, BufReader}; -use std::str; - -/// Loads a test data file into a vector of `f64`'s. -/// Path is relative to /data. -/// -/// # Panics -/// -/// Panics if the file does not exist or could not be opened, or -/// there was an error reading the file. -#[cfg(test)] -pub fn load_data(path: &str) -> Vec { - // note: the copious use of unwrap is because this is a test helper and - // if reading the data file fails, we want to panic immediately - - let path_prefix = "./data/".to_string(); - let true_path = path_prefix + path.trim().trim_start_matches('/'); - - let f = File::open(true_path).unwrap(); - let mut reader = BufReader::new(f); - - let mut buf = String::new(); - let mut data: Vec = vec![]; - while reader.read_line(&mut buf).unwrap() > 0 { - data.push(buf.trim().parse::().unwrap()); - buf.clear(); - } - data -} From 0a963f09538fd338d7d437a291e31ae9d6b068ff Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sat, 22 Jun 2024 18:20:30 -0500 Subject: [PATCH 089/185] test: first pass NIST strd univariate test shell script needs curl, grep, and sed --- Cargo.toml | 1 + prep_data.sh | 33 +++++++ tests/nist_tests.rs | 222 +++++++++++++++++++------------------------- 3 files changed, 129 insertions(+), 127 deletions(-) create mode 100755 prep_data.sh diff --git a/Cargo.toml b/Cargo.toml index 9823299f..4901c1f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,4 @@ num-traits = "0.2.14" [dev-dependencies] criterion = "0.3.3" +anyhow = "1.0" diff --git a/prep_data.sh b/prep_data.sh new file mode 100755 index 00000000..93ba2842 --- /dev/null +++ b/prep_data.sh @@ -0,0 +1,33 @@ +#! /bin/bash + +process_file() { + # Define input and output file names + SOURCE=$1 + FILENAME=$2 + curl -fsSL ${SOURCE}/$FILENAME > $FILENAME + + # Extract line numbers for Certified Values and Data from the header + INFO=$(grep "Certified Values:" $FILENAME) + CERTIFIED_VALUES_START=$(echo $INFO | awk '{print $4}') + CERTIFIED_VALUES_END=$(echo $INFO | awk '{print $6}') + + INFO=$(grep "Data :" $FILENAME) + DATA_START=$(echo $INFO | awk '{print $4}') + DATA_END=$(echo $INFO | awk '{print $6}') + + # Extract and reformat sections + # Certified values + sed -n -i \ + -e "${CERTIFIED_VALUES_START},${CERTIFIED_VALUES_END}p" \ + -e "${DATA_START},${DATA_END}p" \ + $FILENAME + # sed -n -i -e "${CERTIFIED_VALUES_START},${CERTIFIED_VALUES_END}s/\(exact\)//p" $FILENAME + +} + +URL='https://www.itl.nist.gov/div898/strd/univ/data' +for file in Lottery.dat Lew.dat Mavro.dat Michelso.dat NumAcc1.dat NumAcc2.dat NumAcc3.dat +do + process_file $URL $file +done + diff --git a/tests/nist_tests.rs b/tests/nist_tests.rs index d14f4cde..674b657e 100644 --- a/tests/nist_tests.rs +++ b/tests/nist_tests.rs @@ -1,153 +1,121 @@ // #![cfg(test)] -use statrs::assert_almost_eq; +use anyhow::Result; +use approx::assert_relative_eq; use statrs::statistics::Statistics; + use std::io::{BufRead, BufReader}; +use std::path::PathBuf; use std::{env, fs}; -#[cfg(test)] -const NIST_DATA_DIR_ENV: &str = "STATRS_NIST_DATA_DIR"; - -fn load_data(pathname: String) -> Vec { - let f = fs::File::open(pathname).unwrap(); - let mut reader = BufReader::new(f); +struct TestCase { + certified: CertifiedValues, + values: Vec, +} - let mut buf = String::new(); - let mut data: Vec = vec![]; - while reader.read_line(&mut buf).unwrap() > 0 { - data.push(buf.trim().parse::().unwrap()); - buf.clear(); +impl std::fmt::Debug for TestCase { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "TestCase({:?}, [...]", self.certified) } - data } -#[test] -#[ignore = "NIST tests should not run from typical `cargo test` calls"] -fn nist_test_mean() { - let path_dir = env::var(NIST_DATA_DIR_ENV).unwrap(); - let mut data = load_data(dbg!(path_dir.clone() + "lottery.txt")); - assert_almost_eq!((&data).mean(), 518.958715596330, 1e-12); - - data = load_data(dbg!(path_dir.clone() + "lew.txt")); - assert_almost_eq!((&data).mean(), -177.435000000000, 1e-13); - - data = load_data(dbg!(path_dir.clone() + "mavro.txt")); - assert_almost_eq!((&data).mean(), 2.00185600000000, 1e-15); - - data = load_data(dbg!(path_dir.clone() + "michaelso.txt")); - assert_almost_eq!((&data).mean(), 299.852400000000, 1e-13); - - data = load_data(dbg!(path_dir.clone() + "numacc1.txt")); - assert_eq!((&data).mean(), 10000002.0); - - data = load_data(dbg!(path_dir.clone() + "numacc2.txt")); - assert_almost_eq!((&data).mean(), 1.2, 1e-15); - - data = load_data(dbg!(path_dir.clone() + "numacc3.txt")); - assert_eq!((&data).mean(), 1000000.2); +#[derive(Debug)] +struct CertifiedValues { + mean: f64, + std_dev: f64, + corr: f64, +} - data = load_data(dbg!(path_dir.clone() + "numacc4.txt")); - assert_almost_eq!((&data).mean(), 10000000.2, 1e-8); +impl std::fmt::Display for CertifiedValues { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "μ={:.3e}, σ={:.3e}, r={:.3e}", + self.mean, self.std_dev, self.corr + ) + } } +#[cfg(test)] +const NIST_DATA_DIR_ENV: &str = "STATRS_NIST_DATA_DIR"; +#[cfg(test)] +const FILENAMES: [&str; 7] = [ + "Lottery.dat", + "Lew.dat", + "Mavro.dat", + "Michelso.dat", + "NumAcc1.dat", + "NumAcc2.dat", + "NumAcc3.dat", +]; + #[test] #[ignore = "NIST tests should not run from typical `cargo test` calls"] -fn nist_test_std_dev() { - let path_dir = env::var(NIST_DATA_DIR_ENV).unwrap(); - let mut data = load_data(dbg!(path_dir.clone() + "lottery.txt")); - assert_almost_eq!((&data).std_dev(), 291.699727470969, 1e-13); - - data = load_data(dbg!(path_dir.clone() + "lew.txt")); - assert_almost_eq!((&data).std_dev(), 277.332168044316, 1e-12); - - data = load_data(dbg!(path_dir.clone() + "mavro.txt")); - assert_almost_eq!((&data).std_dev(), 0.000429123454003053, 1e-15); - - data = load_data(dbg!(path_dir.clone() + "michaelso.txt")); - assert_almost_eq!((&data).std_dev(), 0.0790105478190518, 1e-13); - - data = load_data(dbg!(path_dir.clone() + "numacc1.txt")); - assert_eq!((&data).std_dev(), 1.0); - - data = load_data(dbg!(path_dir.clone() + "numacc2.txt")); - assert_almost_eq!((&data).std_dev(), 0.1, 1e-16); - - data = load_data(dbg!(path_dir.clone() + "numacc3.txt")); - assert_almost_eq!((&data).std_dev(), 0.1, 1e-10); - - data = load_data(dbg!(path_dir.clone() + "numacc4.txt")); - assert_almost_eq!((&data).std_dev(), 0.1, 1e-9); +fn nist_strd_univariate_mean() { + let path_prefix = env::var(NIST_DATA_DIR_ENV).unwrap_or_else(|e| panic!("{}", e)); + + for fname in FILENAMES { + let case = parse_file([&path_prefix, fname].iter().collect::()) + .unwrap_or_else(|e| panic!("failed parsing file {} with {:?}", fname, e)); + assert_relative_eq!( + case.values.iter().mean(), + case.certified.mean, + epsilon = 1e-12 + ); + } } #[test] -#[ignore = "NIST tests should not run from typical `cargo test` calls"] -fn nist_test_covariance_consistent_with_variance() { - let path_dir = env::var(NIST_DATA_DIR_ENV).unwrap(); - let mut data = load_data(dbg!(path_dir.clone() + "lottery.txt")); - assert_almost_eq!((&data).variance(), (&data).covariance(&data), 1e-10); - - data = load_data(dbg!(path_dir.clone() + "lew.txt")); - assert_almost_eq!((&data).variance(), (&data).covariance(&data), 1e-10); - - data = load_data(dbg!(path_dir.clone() + "mavro.txt")); - assert_almost_eq!((&data).variance(), (&data).covariance(&data), 1e-10); +#[ignore] +fn nist_strd_univariate_std_dev() { + let path_prefix = env::var(NIST_DATA_DIR_ENV).unwrap_or_else(|e| panic!("{}", e)); + + for fname in FILENAMES { + let case = parse_file([&path_prefix, fname].iter().collect::()) + .unwrap_or_else(|e| panic!("failed parsing file {} with {:?}", fname, e)); + assert_relative_eq!( + case.values.iter().std_dev(), + case.certified.std_dev, + epsilon = 1e-10 + ); + } +} - data = load_data(dbg!(path_dir.clone() + "michaelso.txt")); - assert_almost_eq!((&data).variance(), (&data).covariance(&data), 1e-10); +fn parse_certified_value(line: String) -> Result { + line.chars() + .skip_while(|&c| c != ':') + .skip(1) // skip through ':' delimiter + .skip_while(|&c| c.is_whitespace()) // effectively `String` trim + .take_while(|&c| matches!(c, '0'..='9' | '-' | '.')) + .collect::() + .parse::() + .map_err(|e| e.into()) +} - data = load_data(dbg!(path_dir.clone() + "numacc1.txt")); - assert_almost_eq!((&data).variance(), (&data).covariance(&data), 1e-10); +fn parse_file(path: impl AsRef) -> anyhow::Result { + let f = fs::File::open(path)?; + let reader = BufReader::new(f); + let mut lines = reader.lines(); + + let mean = parse_certified_value(lines.next().expect("file should not be exhausted")?)?; + let std_dev = parse_certified_value(lines.next().expect("file should not be exhausted")?)?; + let corr = parse_certified_value(lines.next().expect("file should not be exhausted")?)?; + + Ok(TestCase { + certified: CertifiedValues { + mean, + std_dev, + corr, + }, + values: lines + .map_while(|line| line.ok()?.trim().parse().ok()) + .collect(), + }) } #[test] #[ignore = "NIST tests should not run from typical `cargo test` calls"] -fn nist_test_pop_covar_consistent_with_pop_var() { - let path_dir = env::var(NIST_DATA_DIR_ENV).unwrap(); - let mut data = load_data(dbg!(path_dir.clone() + "lottery.txt")); - assert_almost_eq!( - (&data).population_variance(), - (&data).population_covariance(&data), - 1e-10, - ); - - data = load_data(dbg!(path_dir.clone() + "lew.txt")); - assert_almost_eq!( - (&data).population_variance(), - (&data).population_covariance(&data), - 1e-10, - ); - - data = load_data(dbg!(path_dir.clone() + "mavro.txt")); - assert_almost_eq!( - (&data).population_variance(), - (&data).population_covariance(&data), - 1e-10, - ); - - data = load_data(dbg!(path_dir.clone() + "michaelso.txt")); - assert_almost_eq!( - (&data).population_variance(), - (&data).population_covariance(&data), - 1e-10, - ); - - data = load_data(dbg!(path_dir.clone() + "numacc1.txt")); - assert_almost_eq!( - (&data).population_variance(), - (&data).population_covariance(&data), - 1e-10, - ); -} +fn nist_test_covariance_consistent_with_variance() {} #[test] #[ignore = "NIST tests should not run from typical `cargo test` calls"] -fn nist_test_covariance_is_symmetric() { - let path_dir = env::var(NIST_DATA_DIR_ENV).unwrap(); - let data_a = &load_data(dbg!(path_dir.clone() + "lottery.txt"))[0..200]; - let data_b = &load_data(dbg!(path_dir.clone() + "lew.txt"))[0..200]; - assert_almost_eq!(data_a.covariance(data_b), data_b.covariance(data_a), 1e-10); - assert_almost_eq!( - data_a.population_covariance(data_b), - data_b.population_covariance(data_a), - 1e-11, - ); -} +fn nist_test_covariance_is_symmetric() {} From f6baeb9863731e64331bf529fa8bcaf96e968fa9 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sat, 22 Jun 2024 19:36:51 -0500 Subject: [PATCH 090/185] test: mv shell script to test also drop usage of cfg(test) in tests/ directory --- prep_data.sh => tests/gather_nist_data.sh | 1 - tests/nist_tests.rs | 3 --- 2 files changed, 4 deletions(-) rename prep_data.sh => tests/gather_nist_data.sh (90%) diff --git a/prep_data.sh b/tests/gather_nist_data.sh similarity index 90% rename from prep_data.sh rename to tests/gather_nist_data.sh index 93ba2842..9381ff08 100755 --- a/prep_data.sh +++ b/tests/gather_nist_data.sh @@ -21,7 +21,6 @@ process_file() { -e "${CERTIFIED_VALUES_START},${CERTIFIED_VALUES_END}p" \ -e "${DATA_START},${DATA_END}p" \ $FILENAME - # sed -n -i -e "${CERTIFIED_VALUES_START},${CERTIFIED_VALUES_END}s/\(exact\)//p" $FILENAME } diff --git a/tests/nist_tests.rs b/tests/nist_tests.rs index 674b657e..9efc94bf 100644 --- a/tests/nist_tests.rs +++ b/tests/nist_tests.rs @@ -1,4 +1,3 @@ -// #![cfg(test)] use anyhow::Result; use approx::assert_relative_eq; use statrs::statistics::Statistics; @@ -35,9 +34,7 @@ impl std::fmt::Display for CertifiedValues { } } -#[cfg(test)] const NIST_DATA_DIR_ENV: &str = "STATRS_NIST_DATA_DIR"; -#[cfg(test)] const FILENAMES: [&str; 7] = [ "Lottery.dat", "Lew.dat", From 475292fdbbf302af2758748890402987ed93f5e9 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sun, 23 Jun 2024 12:00:25 -0500 Subject: [PATCH 091/185] chore: distribute nist test and wrapper as part of package --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 4901c1f2..2b133a61 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ homepage = "https://github.com/statrs-dev/statrs" repository = "https://github.com/statrs-dev/statrs" edition = "2018" -include = ["CHANGELOG.md", "LICENSE.md", "src/"] +include = ["CHANGELOG.md", "LICENSE.md", "src/", "tests/"] [lib] name = "statrs" From e29b69c10c7ce96e90dddabc063ea220731a574d Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sun, 23 Jun 2024 12:19:16 -0500 Subject: [PATCH 092/185] chore: drop empty module `testing` --- src/lib.rs | 3 --- src/statistics/iter_statistics.rs | 3 +-- src/testing/mod.rs | 1 - 3 files changed, 1 insertion(+), 6 deletions(-) delete mode 100644 src/testing/mod.rs diff --git a/src/lib.rs b/src/lib.rs index bdb5fc51..568294b5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -82,9 +82,6 @@ pub(crate) fn is_zero(x: f64) -> bool { ulps_eq!(x, 0.0, max_ulps = 0) } -// #[cfg(test)] -mod testing; - pub use crate::error::StatsError; /// Result type for the statrs library package that returns diff --git a/src/statistics/iter_statistics.rs b/src/statistics/iter_statistics.rs index 06941ef6..e568e531 100644 --- a/src/statistics/iter_statistics.rs +++ b/src/statistics/iter_statistics.rs @@ -245,12 +245,11 @@ where mod tests { use std::f64::consts; use rand::rngs::StdRng; - use rand::{SeedableRng}; + use rand::SeedableRng; use rand::distributions::Distribution; use crate::distribution::Normal; use crate::statistics::Statistics; use crate::generate::{InfinitePeriodic, InfiniteSinusoidal}; - use crate::testing; #[test] fn test_empty_data_returns_nan() { diff --git a/src/testing/mod.rs b/src/testing/mod.rs deleted file mode 100644 index 338bdd2e..00000000 --- a/src/testing/mod.rs +++ /dev/null @@ -1 +0,0 @@ -//! Provides testing helpers and utilities From 222539a86471c10eebfe753af4b43a1479f9f0ad Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Mon, 24 Jun 2024 19:27:51 -0500 Subject: [PATCH 093/185] test: improve ergonomics for test --- .gitignore | 3 +++ tests/gather_nist_data.sh | 16 ++++++++++------ tests/nist_tests.rs | 30 ++++++++++++++++-------------- 3 files changed, 29 insertions(+), 20 deletions(-) diff --git a/.gitignore b/.gitignore index 3bb686bd..fe9e65ec 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,9 @@ # Executables *.exe +# Test data for integration tests +tests/*.dat + # Generated by Cargo /target/ *.lock diff --git a/tests/gather_nist_data.sh b/tests/gather_nist_data.sh index 9381ff08..2b663734 100755 --- a/tests/gather_nist_data.sh +++ b/tests/gather_nist_data.sh @@ -1,27 +1,31 @@ #! /bin/bash +# this script is to download and preprocess datafiles for the nist_tests.rs +# integration test for statrs downloads data to directory specified by env +# var STATRS_NIST_DATA_DIR process_file() { # Define input and output file names SOURCE=$1 FILENAME=$2 - curl -fsSL ${SOURCE}/$FILENAME > $FILENAME + TARGET=${STATRS_NIST_DATA_DIR-tests}/${FILENAME} + echo -e ${FILENAME} '\n\tDownloading...' + curl -fsSL ${SOURCE}/$FILENAME > ${TARGET} # Extract line numbers for Certified Values and Data from the header - INFO=$(grep "Certified Values:" $FILENAME) + INFO=$(grep "Certified Values:" $TARGET) CERTIFIED_VALUES_START=$(echo $INFO | awk '{print $4}') CERTIFIED_VALUES_END=$(echo $INFO | awk '{print $6}') - INFO=$(grep "Data :" $FILENAME) + INFO=$(grep "Data :" $TARGET) DATA_START=$(echo $INFO | awk '{print $4}') DATA_END=$(echo $INFO | awk '{print $6}') + echo -e '\tFormatting...' # Extract and reformat sections - # Certified values sed -n -i \ -e "${CERTIFIED_VALUES_START},${CERTIFIED_VALUES_END}p" \ -e "${DATA_START},${DATA_END}p" \ - $FILENAME - + $TARGET } URL='https://www.itl.nist.gov/div898/strd/univ/data' diff --git a/tests/nist_tests.rs b/tests/nist_tests.rs index 9efc94bf..a985820f 100644 --- a/tests/nist_tests.rs +++ b/tests/nist_tests.rs @@ -45,32 +45,34 @@ const FILENAMES: [&str; 7] = [ "NumAcc3.dat", ]; +fn get_path(fname: &str, prefix: Option<&str>) -> PathBuf { + if let Some(prefix) = prefix { + [prefix, fname].iter().collect() + } else { + ["tests", fname].iter().collect() + } +} + #[test] #[ignore = "NIST tests should not run from typical `cargo test` calls"] fn nist_strd_univariate_mean() { - let path_prefix = env::var(NIST_DATA_DIR_ENV).unwrap_or_else(|e| panic!("{}", e)); - for fname in FILENAMES { - let case = parse_file([&path_prefix, fname].iter().collect::()) - .unwrap_or_else(|e| panic!("failed parsing file {} with {:?}", fname, e)); - assert_relative_eq!( - case.values.iter().mean(), - case.certified.mean, - epsilon = 1e-12 - ); + let filepath = get_path(fname, env::var(NIST_DATA_DIR_ENV).ok().as_deref()); + let case = parse_file(filepath) + .unwrap_or_else(|e| panic!("failed parsing file {} with `{:?}`", fname, e)); + assert_relative_eq!(case.values.mean(), case.certified.mean, epsilon = 1e-12); } } #[test] #[ignore] fn nist_strd_univariate_std_dev() { - let path_prefix = env::var(NIST_DATA_DIR_ENV).unwrap_or_else(|e| panic!("{}", e)); - for fname in FILENAMES { - let case = parse_file([&path_prefix, fname].iter().collect::()) - .unwrap_or_else(|e| panic!("failed parsing file {} with {:?}", fname, e)); + let filepath = get_path(fname, env::var(NIST_DATA_DIR_ENV).ok().as_deref()); + let case = parse_file(filepath) + .unwrap_or_else(|e| panic!("failed parsing file {} with `{:?}`", fname, e)); assert_relative_eq!( - case.values.iter().std_dev(), + case.values.std_dev(), case.certified.std_dev, epsilon = 1e-10 ); From a57db2fda0fb379c2667f8e8a110c9bafc8591a2 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Mon, 24 Jun 2024 19:38:41 -0500 Subject: [PATCH 094/185] docs: describe how to run tests including NIST --- README.md | 13 +++++++++++++ tests/nist_tests.rs | 17 +++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/README.md b/README.md index 4cd36567..740e5dfe 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,19 @@ statrs = "*" # replace * by the latest version of the crate. For examples, view [the docs](https://docs.rs/statrs/*/statrs/). +### Running tests + +If you'd like to run all suggested tests, you'll need to download some data from +NIST, we have a script for this and formatting the data in the `tests/` folder. + +```sh +cargo test +./tests/gather_nist_data.sh && cargo test -- --include-ignored nist_ +``` + +If you'd like to modify where the data is downloaded, you can use the environment variable, +`STATRS_NIST_DATA_DIR` for running the script and the tests. + ## Contributing Thanks for your help to improve the project! diff --git a/tests/nist_tests.rs b/tests/nist_tests.rs index a985820f..0f731067 100644 --- a/tests/nist_tests.rs +++ b/tests/nist_tests.rs @@ -1,3 +1,20 @@ +//! This test relies on data that is reusable but not distributable by statrs as +//! such, the data will need to be downloaded from the relevant NIST StRD dataset +//! the parsing for testing assumes data to be of form, +//! ```text +//! sample mean : +//! sample std_dev : +//! sample correlation: +//! [zero or more blank lines] +//! data0 +//! data1 +//! data2 +//! ... +//! ``` +//! This test can be run on it's own from the shell from this folder as +//! ```sh +//! ./gather_nist_data.sh && cargo test -- --ignored nist_ +//! ``` use anyhow::Result; use approx::assert_relative_eq; use statrs::statistics::Statistics; From ddff53d265b8a9267bb2699e3ee17c9c32b778a3 Mon Sep 17 00:00:00 2001 From: Maxime Jacques Date: Fri, 6 May 2022 09:03:42 -0400 Subject: [PATCH 095/185] Add fisher statistical tests --- src/lib.rs | 1 + src/stats_tests/fisher.rs | 341 ++++++++++++++++++++++++++++++++++++++ src/stats_tests/mod.rs | 1 + 3 files changed, 343 insertions(+) create mode 100644 src/stats_tests/fisher.rs create mode 100644 src/stats_tests/mod.rs diff --git a/src/lib.rs b/src/lib.rs index 568294b5..a87f8ef9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -73,6 +73,7 @@ pub mod function; pub mod generate; pub mod prec; pub mod statistics; +pub mod stats_tests; mod error; diff --git a/src/stats_tests/fisher.rs b/src/stats_tests/fisher.rs new file mode 100644 index 00000000..ca2bad73 --- /dev/null +++ b/src/stats_tests/fisher.rs @@ -0,0 +1,341 @@ +// Perform a Fisher exact test on a 2x2 contingency table. +// Based on scipy's fisher test: https://github.com/scipy/scipy/blob/v1.7.0/scipy/stats/stats.py#L40757 + +use crate::distribution::{Discrete, DiscreteCDF, Hypergeometric}; +use crate::StatsError; + +#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] +pub enum Alternative { + TwoSided, + Less, + Greater, +} + +const EPSILON: f64 = 1.0 - 1e-4; + +fn binary_search( + n: u64, + n1: u64, + n2: u64, + mode: u64, + p_exact: f64, + epsilon: f64, + upper: bool, +) -> u64 { + // Binary search in two-sided test with starting bound as argument + let (mut min_val, mut max_val) = { + if upper { + (mode, n) + } else { + (0, mode) + } + }; + + let population = n1 + n2; + let successes = n1; + let draws = n; + let dist = Hypergeometric::new(population, successes, draws).unwrap(); + + let mut guess = 0; + loop { + if max_val - min_val <= 1 { + break; + } + guess = { + if max_val == min_val + 1 && guess == min_val { + max_val + } else { + (max_val + min_val) / 2 + } + }; + + let ng = { + if upper { + guess - 1 + } else { + guess + 1 + } + }; + + let pmf_comp = dist.pmf(ng); + let p_guess = dist.pmf(guess); + if p_guess <= p_exact && p_exact < pmf_comp { + break; + } + if p_guess < p_exact { + max_val = guess + } else { + min_val = guess + } + } + + if guess == 0 { + guess = min_val + } + if upper { + loop { + if guess > 0 && dist.pmf(guess) < p_exact * epsilon { + guess -= 1; + } else { + break; + } + } + loop { + if dist.pmf(guess) > p_exact / epsilon { + guess += 1; + } else { + break; + } + } + } else { + loop { + if dist.pmf(guess) < p_exact * epsilon { + guess += 1; + } else { + break; + } + } + loop { + if guess > 0 && dist.pmf(guess) > p_exact / epsilon { + guess -= 1; + } else { + break; + } + } + } + guess +} + +pub fn fishers_exact_with_odds_ratio( + table: &[u64; 4], + alternative: Alternative, +) -> Result<(f64, f64), StatsError> { + // Calculate fisher's exact test with the odds ratio + if (table[0] == 0 && table[2] == 0) || (table[1] == 0 && table[3] == 0) { + // If both values in a row or column are zero, p-value is 1 and odds ratio is NaN. + return Ok((f64::NAN, 1.0)); + } + + let odds_ratio = { + if table[1] > 0 && table[2] > 0 { + (table[0] * table[3]) as f64 / (table[1] * table[2]) as f64 + } else { + f64::INFINITY + } + }; + + let p_value = fishers_exact(table, alternative)?; + Ok((odds_ratio, p_value)) +} + +pub fn fishers_exact(table: &[u64; 4], alternative: Alternative) -> Result { + // Rewrite of the scipy's Fisher exact test + + // If both values in a row or column are zero, the p-value is 1 and + // the odds ratio is NaN. + if (table[0] == 0 && table[2] == 0) || (table[1] == 0 && table[3] == 0) { + return Ok(1.0); + } + + let n1 = table[0] + table[1]; + let n2 = table[2] + table[3]; + let n = table[0] + table[2]; + + let p_value = { + let population = n1 + n2; + let successes = n1; + + match alternative { + Alternative::Less => { + let draws = n; + let dist = Hypergeometric::new(population, successes, draws)?; + dist.cdf(table[0]) + } + Alternative::Greater => { + let draws = table[1] + table[3]; + let dist = Hypergeometric::new(population, successes, draws)?; + dist.cdf(table[1]) + } + Alternative::TwoSided => { + let draws = n; + let dist = Hypergeometric::new(population, successes, draws)?; + + let p_exact = dist.pmf(table[0]); + let mode = ((n + 1) * (n1 + 1)) / (n1 + n2 + 2) as u64; // todo: check floor? + let p_mode = dist.pmf(mode); + + if (p_exact - p_mode).abs() / p_exact.max(p_mode) <= 1.0 - EPSILON { + return Ok(1.0); + } + + if table[0] < mode { + let p_lower = dist.cdf(table[0]); + if dist.pmf(n) > p_exact / EPSILON { + return Ok(p_lower); + } + let guess = binary_search(n, n1, n2, mode, p_exact, EPSILON, true); + return Ok(p_lower + 1.0 - dist.cdf(guess - 1)); + } + + let p_upper = 1.0 - dist.cdf(table[0] - 1); + if dist.pmf(0) > p_exact / EPSILON { + return Ok(p_upper); + } + + let guess = binary_search(n, n1, n2, mode, p_exact, EPSILON, false); + p_upper + dist.cdf(guess) + } + } + }; + + Ok(p_value.min(1.0)) +} + +#[cfg(test)] +mod tests { + use super::fishers_exact; + use crate::prec; + use crate::stats_tests::fisher::{fishers_exact_with_odds_ratio, Alternative}; + + #[test] + fn test_fishers_exact() { + let cases = [ + ( + [3, 5, 4, 50], + 0.9963034765672599, + 0.03970749246529277, + 0.03970749246529276, + ), + ( + [61, 118, 2, 1], + 0.27535061623455315, + 0.9598172545684959, + 0.27535061623455315, + ), + ( + [172, 46, 90, 127], + 1.0, + 6.662405187351769e-16, + 9.041009036528785e-16, + ), + ( + [127, 38, 112, 43], + 0.8637599357870167, + 0.20040942958644145, + 0.3687862842650179, + ), + ( + [186, 177, 111, 154], + 0.9918518696328176, + 0.012550663906725129, + 0.023439141644624434, + ), + ( + [137, 49, 135, 183], + 0.999999999998533, + 5.6517533666400615e-12, + 8.870999836202932e-12, + ), + ( + [37, 115, 37, 152], + 0.8834621182590621, + 0.17638403366123565, + 0.29400927608021704, + ), + ( + [124, 117, 119, 175], + 0.9956704915461392, + 0.007134712391455461, + 0.011588218284387445, + ), + ( + [70, 114, 41, 118], + 0.9945558498544903, + 0.010384865876586255, + 0.020438291037108678, + ), + ( + [173, 21, 89, 7], + 0.2303739114068352, + 0.8808002774812677, + 0.4027047267306024, + ), + ( + [18, 147, 123, 58], + 4.077820702304103e-29, + 0.9999999999999817, + 0.0, + ), + ( + [116, 20, 92, 186], + 0.9999999999998267, + 6.598118571034892e-25, + 8.164831402188242e-25, + ), + ( + [9, 22, 44, 38], + 0.01584272038710196, + 0.9951463496539362, + 0.021581786662999272, + ), + ( + [9, 101, 135, 7], + 3.3336213533847776e-50, + 1.0, + 3.3336213533847776e-50, + ), + ( + [153, 27, 191, 144], + 0.9999999999950817, + 2.473736787266208e-11, + 3.185816623300107e-11, + ), + ( + [111, 195, 189, 69], + 6.665245982898848e-19, + 0.9999999999994574, + 1.0735744915712542e-18, + ), + ( + [125, 21, 31, 131], + 0.99999999999974, + 9.720661317939016e-34, + 1.0352129312860277e-33, + ), + ( + [201, 192, 69, 179], + 0.9999999988714893, + 3.1477232259550017e-09, + 4.761075937088169e-09, + ), + ( + [124, 138, 159, 160], + 0.30153826772785475, + 0.7538974235759873, + 0.5601766196310243, + ), + ]; + + for (table, less_expected, greater_expected, two_sided_expected) in cases.iter() { + for (alternative, expected) in [ + Alternative::Less, + Alternative::Greater, + Alternative::TwoSided, + ] + .iter() + .zip(vec![less_expected, greater_expected, two_sided_expected]) + { + let p_value = fishers_exact(&table, *alternative).unwrap(); + assert!(prec::almost_eq(p_value, *expected, 1e-12)); + } + } + } + #[test] + fn test_fishers_exact_with_odds() { + let table = [3, 5, 4, 50]; + let (odds_ratio, p_value) = + fishers_exact_with_odds_ratio(&table, Alternative::Less).unwrap(); + assert!(prec::almost_eq(p_value, 0.9963034765672599, 1e-12)); + assert!(prec::almost_eq(odds_ratio, 7.5, 1e-1)); + } +} diff --git a/src/stats_tests/mod.rs b/src/stats_tests/mod.rs new file mode 100644 index 00000000..322094ed --- /dev/null +++ b/src/stats_tests/mod.rs @@ -0,0 +1 @@ +mod fisher; From 737d5dbee2222b417efd55d16b05288d3aaa4deb Mon Sep 17 00:00:00 2001 From: Maxime Jacques Date: Fri, 6 May 2022 10:22:51 -0400 Subject: [PATCH 096/185] Add docstrings --- src/stats_tests/fisher.rs | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/src/stats_tests/fisher.rs b/src/stats_tests/fisher.rs index ca2bad73..7b3446fd 100644 --- a/src/stats_tests/fisher.rs +++ b/src/stats_tests/fisher.rs @@ -1,10 +1,7 @@ -// Perform a Fisher exact test on a 2x2 contingency table. -// Based on scipy's fisher test: https://github.com/scipy/scipy/blob/v1.7.0/scipy/stats/stats.py#L40757 - use crate::distribution::{Discrete, DiscreteCDF, Hypergeometric}; use crate::StatsError; -#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] +#[derive(Debug, Copy, Clone)] pub enum Alternative { TwoSided, Less, @@ -13,6 +10,7 @@ pub enum Alternative { const EPSILON: f64 = 1.0 - 1e-4; +/// Binary search in two-sided test with starting bound as argument fn binary_search( n: u64, n1: u64, @@ -22,7 +20,6 @@ fn binary_search( epsilon: f64, upper: bool, ) -> u64 { - // Binary search in two-sided test with starting bound as argument let (mut min_val, mut max_val) = { if upper { (mode, n) @@ -106,6 +103,17 @@ fn binary_search( guess } +/// Perform a Fisher exact test on a 2x2 contingency table. +/// Based on scipy's fisher test: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.fisher_exact.html#scipy-stats-fisher-exact +/// Returns the odds ratio and p_value +/// # Examples +/// +/// ``` +/// use statrs::statis_tests::fishers_exact; +/// use statrs::statis_tests::Alternative; +/// let table = [3, 5, 4, 50]; +/// let (odds_ratio, p_value) = fishers_exact_with_odds_ratio(&table, Alternative::Less).unwrap(); +/// ``` pub fn fishers_exact_with_odds_ratio( table: &[u64; 4], alternative: Alternative, @@ -128,11 +136,19 @@ pub fn fishers_exact_with_odds_ratio( Ok((odds_ratio, p_value)) } +/// Perform a Fisher exact test on a 2x2 contingency table. +/// Based on scipy's fisher test: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.fisher_exact.html#scipy-stats-fisher-exact +/// Returns only the p_value +/// # Examples +/// +/// ``` +/// use statrs::statis_tests::fishers_exact; +/// use statrs::statis_tests::Alternative; +/// let table = [3, 5, 4, 50]; +/// let p_value = fishers_exact(&table, Alternative::Less).unwrap(); +/// ``` pub fn fishers_exact(table: &[u64; 4], alternative: Alternative) -> Result { - // Rewrite of the scipy's Fisher exact test - - // If both values in a row or column are zero, the p-value is 1 and - // the odds ratio is NaN. + // If both values in a row or column are zero, the p-value is 1 and the odds ratio is NaN. if (table[0] == 0 && table[2] == 0) || (table[1] == 0 && table[3] == 0) { return Ok(1.0); } @@ -161,7 +177,7 @@ pub fn fishers_exact(table: &[u64; 4], alternative: Alternative) -> Result Date: Sun, 28 Apr 2024 11:00:44 -0500 Subject: [PATCH 097/185] make `stats_test::fisher` public module --- src/stats_tests/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stats_tests/mod.rs b/src/stats_tests/mod.rs index 322094ed..7d84620d 100644 --- a/src/stats_tests/mod.rs +++ b/src/stats_tests/mod.rs @@ -1 +1 @@ -mod fisher; +pub mod fisher; From 5e3133171d9ed717847c28cd9db092b24d34eeba Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sun, 23 Jun 2024 13:58:54 -0500 Subject: [PATCH 098/185] chore: run clippy lint --- src/stats_tests/fisher.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stats_tests/fisher.rs b/src/stats_tests/fisher.rs index 7b3446fd..7b3075f5 100644 --- a/src/stats_tests/fisher.rs +++ b/src/stats_tests/fisher.rs @@ -177,7 +177,7 @@ pub fn fishers_exact(table: &[u64; 4], alternative: Alternative) -> Result Date: Thu, 18 Jul 2024 11:13:37 -0500 Subject: [PATCH 099/185] fix: typo in `stats_tests::fisher` doctests --- src/stats_tests/fisher.rs | 11 +++++------ src/stats_tests/mod.rs | 3 +++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/stats_tests/fisher.rs b/src/stats_tests/fisher.rs index 7b3075f5..8456196d 100644 --- a/src/stats_tests/fisher.rs +++ b/src/stats_tests/fisher.rs @@ -109,8 +109,8 @@ fn binary_search( /// # Examples /// /// ``` -/// use statrs::statis_tests::fishers_exact; -/// use statrs::statis_tests::Alternative; +/// use statrs::stats_tests::fishers_exact_with_odds_ratio; +/// use statrs::stats_tests::Alternative; /// let table = [3, 5, 4, 50]; /// let (odds_ratio, p_value) = fishers_exact_with_odds_ratio(&table, Alternative::Less).unwrap(); /// ``` @@ -142,8 +142,8 @@ pub fn fishers_exact_with_odds_ratio( /// # Examples /// /// ``` -/// use statrs::statis_tests::fishers_exact; -/// use statrs::statis_tests::Alternative; +/// use statrs::stats_tests::fishers_exact; +/// use statrs::stats_tests::Alternative; /// let table = [3, 5, 4, 50]; /// let p_value = fishers_exact(&table, Alternative::Less).unwrap(); /// ``` @@ -209,9 +209,8 @@ pub fn fishers_exact(table: &[u64; 4], alternative: Alternative) -> Result Date: Thu, 18 Jul 2024 11:15:19 -0500 Subject: [PATCH 100/185] refactor: re-exports for `stats_tests` --- src/stats_tests/fisher.rs | 8 +------- src/stats_tests/mod.rs | 8 +++++++- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/stats_tests/fisher.rs b/src/stats_tests/fisher.rs index 8456196d..94826d00 100644 --- a/src/stats_tests/fisher.rs +++ b/src/stats_tests/fisher.rs @@ -1,13 +1,7 @@ +use super::Alternative; use crate::distribution::{Discrete, DiscreteCDF, Hypergeometric}; use crate::StatsError; -#[derive(Debug, Copy, Clone)] -pub enum Alternative { - TwoSided, - Less, - Greater, -} - const EPSILON: f64 = 1.0 - 1e-4; /// Binary search in two-sided test with starting bound as argument diff --git a/src/stats_tests/mod.rs b/src/stats_tests/mod.rs index 1e292518..4435fb3a 100644 --- a/src/stats_tests/mod.rs +++ b/src/stats_tests/mod.rs @@ -1,4 +1,10 @@ pub mod fisher; -pub use fisher::Alternative; +#[derive(Debug, Copy, Clone)] +pub enum Alternative { + TwoSided, + Less, + Greater, +} + pub use fisher::{fishers_exact, fishers_exact_with_odds_ratio}; From f57aab529ee0143d62e66063c2439257db8f635b Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Thu, 18 Jul 2024 12:03:59 -0500 Subject: [PATCH 101/185] doc(fisher test): basic docs --- src/stats_tests/fisher.rs | 4 +++- src/stats_tests/mod.rs | 7 +++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/stats_tests/fisher.rs b/src/stats_tests/fisher.rs index 94826d00..0af5fa11 100644 --- a/src/stats_tests/fisher.rs +++ b/src/stats_tests/fisher.rs @@ -99,7 +99,8 @@ fn binary_search( /// Perform a Fisher exact test on a 2x2 contingency table. /// Based on scipy's fisher test: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.fisher_exact.html#scipy-stats-fisher-exact -/// Returns the odds ratio and p_value +/// Expects a table in row-major order +/// Returns the [odds ratio](https://en.wikipedia.org/wiki/Odds_ratio) and p_value /// # Examples /// /// ``` @@ -132,6 +133,7 @@ pub fn fishers_exact_with_odds_ratio( /// Perform a Fisher exact test on a 2x2 contingency table. /// Based on scipy's fisher test: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.fisher_exact.html#scipy-stats-fisher-exact +/// Expects a table in row-major order /// Returns only the p_value /// # Examples /// diff --git a/src/stats_tests/mod.rs b/src/stats_tests/mod.rs index 4435fb3a..84a01fc7 100644 --- a/src/stats_tests/mod.rs +++ b/src/stats_tests/mod.rs @@ -1,9 +1,16 @@ pub mod fisher; +/// Specifies an [alternative hypothesis](https://en.wikipedia.org/wiki/Alternative_hypothesis) #[derive(Debug, Copy, Clone)] pub enum Alternative { + #[doc(alias = "two-tailed")] + #[doc(alias = "two tailed")] TwoSided, + #[doc(alias = "one-tailed")] + #[doc(alias = "one tailed")] Less, + #[doc(alias = "one-tailed")] + #[doc(alias = "one tailed")] Greater, } From 5b123da2655b2806aa5db9a4cbed6059216d8ab8 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Thu, 18 Jul 2024 12:05:44 -0500 Subject: [PATCH 102/185] test(fisher test): test short circuit evaluation of all 0 row/columns --- src/stats_tests/fisher.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/stats_tests/fisher.rs b/src/stats_tests/fisher.rs index 0af5fa11..96b6ac9d 100644 --- a/src/stats_tests/fisher.rs +++ b/src/stats_tests/fisher.rs @@ -342,6 +342,16 @@ mod tests { } } } + + #[test] + fn test_fishers_exact_for_trivial() { + let cases = [[0, 0, 1, 2], [1, 2, 0, 0], [1, 0, 2, 0], [0, 1, 0, 2]]; + + for table in cases.iter() { + assert_eq!(fishers_exact(table, Alternative::Less).unwrap(), 1.0) + } + } + #[test] fn test_fishers_exact_with_odds() { let table = [3, 5, 4, 50]; From d436f7c21cbdad12e6690781cab76c263e437d58 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Thu, 18 Jul 2024 11:45:09 -0500 Subject: [PATCH 103/185] fix(fisher test): check rows for all 0's --- src/stats_tests/fisher.rs | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/stats_tests/fisher.rs b/src/stats_tests/fisher.rs index 96b6ac9d..a18c821d 100644 --- a/src/stats_tests/fisher.rs +++ b/src/stats_tests/fisher.rs @@ -113,10 +113,11 @@ pub fn fishers_exact_with_odds_ratio( table: &[u64; 4], alternative: Alternative, ) -> Result<(f64, f64), StatsError> { - // Calculate fisher's exact test with the odds ratio - if (table[0] == 0 && table[2] == 0) || (table[1] == 0 && table[3] == 0) { - // If both values in a row or column are zero, p-value is 1 and odds ratio is NaN. - return Ok((f64::NAN, 1.0)); + // If both values in a row or column are zero, p-value is 1 and odds ratio is NaN. + match table { + [0, _, 0, _] | [_, 0, _, 0] => return Ok((f64::NAN, 1.0)), // both 0 in a row + [0, 0, _, _] | [_, _, 0, 0] => return Ok((f64::NAN, 1.0)), // both 0 in a column + _ => (), // continue } let odds_ratio = { @@ -145,8 +146,10 @@ pub fn fishers_exact_with_odds_ratio( /// ``` pub fn fishers_exact(table: &[u64; 4], alternative: Alternative) -> Result { // If both values in a row or column are zero, the p-value is 1 and the odds ratio is NaN. - if (table[0] == 0 && table[2] == 0) || (table[1] == 0 && table[3] == 0) { - return Ok(1.0); + match table { + [0, _, 0, _] | [_, 0, _, 0] => return Ok(1.0), // both 0 in a row + [0, 0, _, _] | [_, _, 0, 0] => return Ok(1.0), // both 0 in a column + _ => (), // continue } let n1 = table[0] + table[1]; From 3f140e468791b0df4c278d633bdbaa79a3efecf9 Mon Sep 17 00:00:00 2001 From: riverbl <94326797+riverbl@users.noreply.github.com> Date: Thu, 14 Mar 2024 20:59:01 +0000 Subject: [PATCH 104/185] Make `MultivariateNormal` generic over dimension Allow `MultivariateNormal` to have either a runtime known dynamic dimension (as per current) or a compile time known constant dimension. `MultivariateNormal::new_from_nalgebra` has been changed to take the mean and covariance as `OVector` and `OMatrix` rather than `DVector` and `DMatrix`. --- src/distribution/multivariate_normal.rs | 471 +++++++++++++++++------- 1 file changed, 343 insertions(+), 128 deletions(-) diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 0f16b639..993bb9ba 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -3,10 +3,9 @@ use crate::distribution::Normal; use crate::statistics::{Max, MeanN, Min, Mode, VarianceN}; use crate::{Result, StatsError}; use nalgebra::{ - base::allocator::Allocator, base::dimension::DimName, Cholesky, DefaultAllocator, Dim, DimMin, - LU, U1, + base::allocator::Allocator, Cholesky, Const, DMatrix, DVector, DefaultAllocator, Dim, DimMin, + Dyn, OMatrix, OVector, }; -use nalgebra::{DMatrix, DVector}; use rand::Rng; use std::f64; use std::f64::consts::{E, PI}; @@ -18,26 +17,30 @@ use std::f64::consts::{E, PI}; /// /// ``` /// use statrs::distribution::{MultivariateNormal, Continuous}; -/// use nalgebra::{DVector, DMatrix}; +/// use nalgebra::{matrix, vector}; /// use statrs::statistics::{MeanN, VarianceN}; /// -/// let mvn = MultivariateNormal::new(vec![0., 0.], vec![1., 0., 0., 1.]).unwrap(); -/// assert_eq!(mvn.mean().unwrap(), DVector::from_vec(vec![0., 0.])); -/// assert_eq!(mvn.variance().unwrap(), DMatrix::from_vec(2, 2, vec![1., 0., 0., 1.])); -/// assert_eq!(mvn.pdf(&DVector::from_vec(vec![1., 1.])), 0.05854983152431917); +/// let mvn = MultivariateNormal::new_from_nalgebra(vector![0., 0.], matrix![1., 0.; 0., 1.]).unwrap(); +/// assert_eq!(mvn.mean().unwrap(), vector![0., 0.]); +/// assert_eq!(mvn.variance().unwrap(), matrix![1., 0.; 0., 1.]); +/// assert_eq!(mvn.pdf(&vector![1., 1.]), 0.05854983152431917); /// ``` #[derive(Clone, PartialEq, Debug)] -pub struct MultivariateNormal { - dim: usize, - cov_chol_decomp: DMatrix, - mu: DVector, - cov: DMatrix, - precision: DMatrix, +pub struct MultivariateNormal +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ + cov_chol_decomp: OMatrix, + mu: OVector, + cov: OMatrix, + precision: OMatrix, pdf_const: f64, } -impl MultivariateNormal { - /// Constructs a new multivariate normal distribution with a mean of `mean` +impl MultivariateNormal { + /// Constructs a new multivariate normal distribution with a mean of `mean` /// and covariance matrix `cov` /// /// # Errors @@ -49,17 +52,24 @@ impl MultivariateNormal { let cov = DMatrix::from_vec(mean.len(), mean.len(), cov); MultivariateNormal::new_from_nalgebra(mean, cov) } +} +impl MultivariateNormal +where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator<(usize, usize), D>, +{ /// Constructs a new multivariate normal distribution with a mean of `mean` - /// and covariance matrix `cov`, but with explicitly using nalgebras - /// DVector and DMatrix instead of Vec + /// and covariance matrix `cov` using `nalgebra` `OVector` and `OMatrix` + /// instead of `Vec` /// /// # Errors /// /// Returns an error if the given covariance matrix is not /// symmetric or positive-definite - pub fn new_from_nalgebra(mean: DVector, cov: DMatrix) -> Result { - let dim = mean.len(); + pub fn new_from_nalgebra(mean: OVector, cov: OMatrix) -> Result { // Check that the provided covariance matrix is symmetric if cov.lower_triangle() != cov.upper_triangle().transpose() // Check that mean and covariance do not contain NaN @@ -81,7 +91,6 @@ impl MultivariateNormal { Some(cholesky_decomp) => { let precision = cholesky_decomp.inverse(); Ok(MultivariateNormal { - dim, cov_chol_decomp: cholesky_decomp.unpack(), mu: mean, cov, @@ -113,13 +122,23 @@ impl MultivariateNormal { } } -impl std::fmt::Display for MultivariateNormal { +impl std::fmt::Display for MultivariateNormal +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "N({}, {})", &self.mu, &self.cov) } } -impl ::rand::distributions::Distribution> for MultivariateNormal { +impl ::rand::distributions::Distribution> for MultivariateNormal +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ /// Samples from the multivariate normal distribution /// /// # Formula @@ -131,52 +150,73 @@ impl ::rand::distributions::Distribution> for MultivariateNormal { /// `Z` is a vector of normally distributed random variables, and /// `μ` is the mean vector - fn sample(&self, rng: &mut R) -> DVector { + fn sample(&self, rng: &mut R) -> OVector { let d = Normal::new(0., 1.).unwrap(); - let z = DVector::::from_distribution(self.dim, &d, rng); + let z = OVector::from_distribution_generic(self.mu.shape_generic().0, Const::<1>, &d, rng); (&self.cov_chol_decomp * z) + &self.mu } } -impl Min> for MultivariateNormal { +impl Min> for MultivariateNormal +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ /// Returns the minimum value in the domain of the /// multivariate normal distribution represented by a real vector - fn min(&self) -> DVector { - DVector::from_vec(vec![f64::NEG_INFINITY; self.dim]) + fn min(&self) -> OVector { + OMatrix::repeat_generic(self.mu.shape_generic().0, Const::<1>, f64::NEG_INFINITY) } } -impl Max> for MultivariateNormal { +impl Max> for MultivariateNormal +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ /// Returns the maximum value in the domain of the /// multivariate normal distribution represented by a real vector - fn max(&self) -> DVector { - DVector::from_vec(vec![f64::INFINITY; self.dim]) + fn max(&self) -> OVector { + OMatrix::repeat_generic(self.mu.shape_generic().0, Const::<1>, f64::INFINITY) } } -impl MeanN> for MultivariateNormal { +impl MeanN> for MultivariateNormal +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ /// Returns the mean of the normal distribution /// /// # Remarks /// /// This is the same mean used to construct the distribution - fn mean(&self) -> Option> { - let mut vec = vec![]; - for elt in self.mu.clone().into_iter() { - vec.push(*elt); - } - Some(DVector::from_vec(vec)) + fn mean(&self) -> Option> { + Some(self.mu.clone()) } } -impl VarianceN> for MultivariateNormal { +impl VarianceN> for MultivariateNormal +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ /// Returns the covariance matrix of the multivariate normal distribution - fn variance(&self) -> Option> { + fn variance(&self) -> Option> { Some(self.cov.clone()) } } -impl Mode> for MultivariateNormal { +impl Mode> for MultivariateNormal +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ /// Returns the mode of the multivariate normal distribution /// /// # Formula @@ -186,12 +226,18 @@ impl Mode> for MultivariateNormal { /// ``` /// /// where `μ` is the mean - fn mode(&self) -> DVector { + fn mode(&self) -> OVector { self.mu.clone() } } -impl<'a> Continuous<&'a DVector, f64> for MultivariateNormal { +impl<'a, D> Continuous<&'a OVector, f64> for MultivariateNormal +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator, D>, +{ /// Calculates the probability density function for the multivariate /// normal distribution at `x` /// @@ -203,7 +249,7 @@ impl<'a> Continuous<&'a DVector, f64> for MultivariateNormal { /// /// where `μ` is the mean, `inv(Σ)` is the precision matrix, `det(Σ)` is the determinant /// of the covariance matrix, and `k` is the dimension of the distribution - fn pdf(&self, x: &'a DVector) -> f64 { + fn pdf(&self, x: &'a OVector) -> f64 { let dv = x - &self.mu; let exp_term = -0.5 * *(&dv.transpose() * &self.precision * &dv) @@ -214,7 +260,7 @@ impl<'a> Continuous<&'a DVector, f64> for MultivariateNormal { /// Calculates the log probability density function for the multivariate /// normal distribution at `x`. Equivalent to pdf(x).ln(). - fn ln_pdf(&self, x: &'a DVector) -> f64 { + fn ln_pdf(&self, x: &'a OVector) -> f64 { let dv = x - &self.mu; let exp_term = -0.5 * *(&dv.transpose() * &self.precision * &dv) @@ -224,7 +270,7 @@ impl<'a> Continuous<&'a DVector, f64> for MultivariateNormal { } } -impl Continuous, f64> for MultivariateNormal { +impl Continuous, f64> for MultivariateNormal { /// Calculates the probability density function for the multivariate /// normal distribution at `x` /// @@ -250,151 +296,320 @@ impl Continuous, f64> for MultivariateNormal { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{Continuous, MultivariateNormal}; - use crate::statistics::*; use core::fmt::Debug; - use nalgebra::base::allocator::Allocator; - use nalgebra::{ - DefaultAllocator, Dim, DimMin, DimName, Matrix2, Matrix3, Vector2, Vector3, - U1, U2, + + use nalgebra::{dmatrix, dvector, matrix, vector, DimMin, OMatrix, OVector}; + + use crate::{ + distribution::{Continuous, MultivariateNormal}, + statistics::{Max, MeanN, Min, Mode, VarianceN}, }; - fn try_create(mean: Vec, covariance: Vec) -> MultivariateNormal + fn try_create(mean: OVector, covariance: OMatrix) -> MultivariateNormal + where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator<(usize, usize), D>, { - let mvn = MultivariateNormal::new(mean, covariance); + let mvn = MultivariateNormal::new_from_nalgebra(mean, covariance); assert!(mvn.is_ok()); mvn.unwrap() } - fn create_case(mean: Vec, covariance: Vec) + fn create_case(mean: OVector, covariance: OMatrix) + where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator<(usize, usize), D>, { let mvn = try_create(mean.clone(), covariance.clone()); - assert_eq!(DVector::from_vec(mean.clone()), mvn.mean().unwrap()); - assert_eq!(DMatrix::from_vec(mean.len(), mean.len(), covariance), mvn.variance().unwrap()); + assert_eq!(mean, mvn.mean().unwrap()); + assert_eq!(covariance, mvn.variance().unwrap()); } - fn bad_create_case(mean: Vec, covariance: Vec) + fn bad_create_case(mean: OVector, covariance: OMatrix) + where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator<(usize, usize), D>, { - let mvn = MultivariateNormal::new(mean, covariance); + let mvn = MultivariateNormal::new_from_nalgebra(mean, covariance); assert!(mvn.is_err()); } - fn test_case(mean: Vec, covariance: Vec, expected: T, eval: F) - where + fn test_case( + mean: OVector, covariance: OMatrix, expected: T, eval: F, + ) where T: Debug + PartialEq, - F: FnOnce(MultivariateNormal) -> T, + F: FnOnce(MultivariateNormal) -> T, + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator<(usize, usize), D>, { let mvn = try_create(mean, covariance); let x = eval(mvn); assert_eq!(expected, x); } - fn test_almost( - mean: Vec, - covariance: Vec, - expected: f64, - acc: f64, - eval: F, + fn test_almost( + mean: OVector, covariance: OMatrix, expected: f64, acc: f64, eval: F, ) where - F: FnOnce(MultivariateNormal) -> f64, + F: FnOnce(MultivariateNormal) -> f64, + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator<(usize, usize), D>, { let mvn = try_create(mean, covariance); let x = eval(mvn); assert_almost_eq!(expected, x, acc); } - use super::*; - - macro_rules! dvec { - ($($x:expr),*) => (DVector::from_vec(vec![$($x),*])); - } - - macro_rules! mat2 { - ($x11:expr, $x12:expr, $x21:expr, $x22:expr) => (DMatrix::from_vec(2,2,vec![$x11, $x12, $x21, $x22])); - } - - // macro_rules! mat3 { - // ($x11:expr, $x12:expr, $x13:expr, $x21:expr, $x22:expr, $x23:expr, $x31:expr, $x32:expr, $x33:expr) => (DMatrix::from_vec(3,3,vec![$x11, $x12, $x13, $x21, $x22, $x23, $x31, $x32, $x33])); - // } - #[test] fn test_create() { - create_case(vec![0., 0.], vec![1., 0., 0., 1.]); - create_case(vec![10., 5.], vec![2., 1., 1., 2.]); - create_case(vec![4., 5., 6.], vec![2., 1., 0., 1., 2., 1., 0., 1., 2.]); - create_case(vec![0., f64::INFINITY], vec![1., 0., 0., 1.]); - create_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY]); + create_case(vector![0., 0.], matrix![1., 0.; 0., 1.]); + create_case(vector![10., 5.], matrix![2., 1.; 1., 2.]); + create_case( + vector![4., 5., 6.], + matrix![2., 1., 0.; 1., 2., 1.; 0., 1., 2.], + ); + create_case(dvector![0., f64::INFINITY], dmatrix![1., 0.; 0., 1.]); + create_case( + dvector![0., 0.], + dmatrix![f64::INFINITY, 0.; 0., f64::INFINITY], + ); } #[test] fn test_bad_create() { // Covariance not symmetric - bad_create_case(vec![0., 0.], vec![1., 1., 0., 1.]); + bad_create_case(vector![0., 0.], matrix![1., 1.; 0., 1.]); // Covariance not positive-definite - bad_create_case(vec![0., 0.], vec![1., 2., 2., 1.]); + bad_create_case(vector![0., 0.], matrix![1., 2.; 2., 1.]); // NaN in mean - bad_create_case(vec![0., f64::NAN], vec![1., 0., 0., 1.]); + bad_create_case(dvector![0., f64::NAN], dmatrix![1., 0.; 0., 1.]); // NaN in Covariance Matrix - bad_create_case(vec![0., 0.], vec![1., 0., 0., f64::NAN]); + bad_create_case(dvector![0., 0.], dmatrix![1., 0.; 0., f64::NAN]); } #[test] fn test_variance() { - let variance = |x: MultivariateNormal| x.variance().unwrap(); - test_case(vec![0., 0.], vec![1., 0., 0., 1.], mat2![1., 0., 0., 1.], variance); - test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], mat2![f64::INFINITY, 0., 0., f64::INFINITY], variance); + let variance = |x: MultivariateNormal<_>| x.variance().unwrap(); + test_case( + vector![0., 0.], + matrix![1., 0.; 0., 1.], + matrix![1., 0.; 0., 1.], + variance, + ); + test_case( + vector![0., 0.], + matrix![f64::INFINITY, 0.; 0., f64::INFINITY], + matrix![f64::INFINITY, 0.; 0., f64::INFINITY], + variance, + ); } #[test] fn test_entropy() { - let entropy = |x: MultivariateNormal| x.entropy().unwrap(); - test_case(vec![0., 0.], vec![1., 0., 0., 1.], 2.8378770664093453, entropy); - test_case(vec![0., 0.], vec![1., 0.5, 0.5, 1.], 2.694036030183455, entropy); - test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], f64::INFINITY, entropy); + let entropy = |x: MultivariateNormal<_>| x.entropy().unwrap(); + test_case( + dvector![0., 0.], + dmatrix![1., 0.; 0., 1.], + 2.8378770664093453, + entropy, + ); + test_case( + dvector![0., 0.], + dmatrix![1., 0.5; 0.5, 1.], + 2.694036030183455, + entropy, + ); + test_case( + dvector![0., 0.], + dmatrix![f64::INFINITY, 0.; 0., f64::INFINITY], + f64::INFINITY, + entropy, + ); } #[test] fn test_mode() { - let mode = |x: MultivariateNormal| x.mode(); - test_case(vec![0., 0.], vec![1., 0., 0., 1.], dvec![0., 0.], mode); - test_case(vec![f64::INFINITY, f64::INFINITY], vec![1., 0., 0., 1.], dvec![f64::INFINITY, f64::INFINITY], mode); + let mode = |x: MultivariateNormal<_>| x.mode(); + test_case( + vector![0., 0.], + matrix![1., 0.; 0., 1.], + vector![0., 0.], + mode, + ); + test_case( + vector![f64::INFINITY, f64::INFINITY], + matrix![1., 0.; 0., 1.], + vector![f64::INFINITY, f64::INFINITY], + mode, + ); } #[test] fn test_min_max() { - let min = |x: MultivariateNormal| x.min(); - let max = |x: MultivariateNormal| x.max(); - test_case(vec![0., 0.], vec![1., 0., 0., 1.], dvec![f64::NEG_INFINITY, f64::NEG_INFINITY], min); - test_case(vec![0., 0.], vec![1., 0., 0., 1.], dvec![f64::INFINITY, f64::INFINITY], max); - test_case(vec![10., 1.], vec![1., 0., 0., 1.], dvec![f64::NEG_INFINITY, f64::NEG_INFINITY], min); - test_case(vec![-3., 5.], vec![1., 0., 0., 1.], dvec![f64::INFINITY, f64::INFINITY], max); + let min = |x: MultivariateNormal<_>| x.min(); + let max = |x: MultivariateNormal<_>| x.max(); + test_case( + dvector![0., 0.], + dmatrix![1., 0.; 0., 1.], + dvector![f64::NEG_INFINITY, f64::NEG_INFINITY], + min, + ); + test_case( + dvector![0., 0.], + dmatrix![1., 0.; 0., 1.], + dvector![f64::INFINITY, f64::INFINITY], + max, + ); + test_case( + dvector![10., 1.], + dmatrix![1., 0.; 0., 1.], + dvector![f64::NEG_INFINITY, f64::NEG_INFINITY], + min, + ); + test_case( + dvector![-3., 5.], + dmatrix![1., 0.; 0., 1.], + dvector![f64::INFINITY, f64::INFINITY], + max, + ); } #[test] fn test_pdf() { - let pdf = |arg: DVector| move |x: MultivariateNormal| x.pdf(&arg); - test_case(vec![0., 0.], vec![1., 0., 0., 1.], 0.05854983152431917, pdf(dvec![1., 1.])); - test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 0.013064233284684921, 1e-15, pdf(dvec![1., 2.])); - test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 1.8618676045881531e-23, 1e-35, pdf(dvec![1., 10.])); - test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 5.920684802611216e-45, 1e-58, pdf(dvec![10., 10.])); - test_almost(vec![0., 0.], vec![1., 0.9, 0.9, 1.], 1.6576716577547003e-05, 1e-18, pdf(dvec![1., -1.])); - test_almost(vec![0., 0.], vec![1., 0.99, 0.99, 1.], 4.1970621773477824e-44, 1e-54, pdf(dvec![1., -1.])); - test_almost(vec![0.5, -0.2], vec![2.0, 0.3, 0.3, 0.5], 0.0013075203140666656, 1e-15, pdf(dvec![2., 2.])); - test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 0.0, pdf(dvec![10., 10.])); - test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 0.0, pdf(dvec![100., 100.])); + let pdf = |arg| move |x: MultivariateNormal<_>| x.pdf(&arg); + test_case( + vector![0., 0.], + matrix![1., 0.; 0., 1.], + 0.05854983152431917, + pdf(vector![1., 1.]), + ); + test_almost( + vector![0., 0.], + matrix![1., 0.; 0., 1.], + 0.013064233284684921, + 1e-15, + pdf(vector![1., 2.]), + ); + test_almost( + vector![0., 0.], + matrix![1., 0.; 0., 1.], + 1.8618676045881531e-23, + 1e-35, + pdf(vector![1., 10.]), + ); + test_almost( + vector![0., 0.], + matrix![1., 0.; 0., 1.], + 5.920684802611216e-45, + 1e-58, + pdf(vector![10., 10.]), + ); + test_almost( + vector![0., 0.], + matrix![1., 0.9; 0.9, 1.], + 1.6576716577547003e-05, + 1e-18, + pdf(vector![1., -1.]), + ); + test_almost( + vector![0., 0.], + matrix![1., 0.99; 0.99, 1.], + 4.1970621773477824e-44, + 1e-54, + pdf(vector![1., -1.]), + ); + test_almost( + vector![0.5, -0.2], + matrix![2.0, 0.3; 0.3, 0.5], + 0.0013075203140666656, + 1e-15, + pdf(vector![2., 2.]), + ); + test_case( + vector![0., 0.], + matrix![f64::INFINITY, 0.; 0., f64::INFINITY], + 0.0, + pdf(vector![10., 10.]), + ); + test_case( + vector![0., 0.], + matrix![f64::INFINITY, 0.; 0., f64::INFINITY], + 0.0, + pdf(vector![100., 100.]), + ); } #[test] fn test_ln_pdf() { - let ln_pdf = |arg: DVector<_>| move |x: MultivariateNormal| x.ln_pdf(&arg); - test_case(vec![0., 0.], vec![1., 0., 0., 1.], (0.05854983152431917f64).ln(), ln_pdf(dvec![1., 1.])); - test_almost(vec![0., 0.], vec![1., 0., 0., 1.], (0.013064233284684921f64).ln(), 1e-15, ln_pdf(dvec![1., 2.])); - test_almost(vec![0., 0.], vec![1., 0., 0., 1.], (1.8618676045881531e-23f64).ln(), 1e-15, ln_pdf(dvec![1., 10.])); - test_almost(vec![0., 0.], vec![1., 0., 0., 1.], (5.920684802611216e-45f64).ln(), 1e-15, ln_pdf(dvec![10., 10.])); - test_almost(vec![0., 0.], vec![1., 0.9, 0.9, 1.], (1.6576716577547003e-05f64).ln(), 1e-14, ln_pdf(dvec![1., -1.])); - test_almost(vec![0., 0.], vec![1., 0.99, 0.99, 1.], (4.1970621773477824e-44f64).ln(), 1e-12, ln_pdf(dvec![1., -1.])); - test_almost(vec![0.5, -0.2], vec![2.0, 0.3, 0.3, 0.5], (0.0013075203140666656f64).ln(), 1e-15, ln_pdf(dvec![2., 2.])); - test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], f64::NEG_INFINITY, ln_pdf(dvec![10., 10.])); - test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], f64::NEG_INFINITY, ln_pdf(dvec![100., 100.])); + let ln_pdf = |arg| move |x: MultivariateNormal<_>| x.ln_pdf(&arg); + test_case( + dvector![0., 0.], + dmatrix![1., 0.; 0., 1.], + (0.05854983152431917f64).ln(), + ln_pdf(dvector![1., 1.]), + ); + test_almost( + dvector![0., 0.], + dmatrix![1., 0.; 0., 1.], + (0.013064233284684921f64).ln(), + 1e-15, + ln_pdf(dvector![1., 2.]), + ); + test_almost( + dvector![0., 0.], + dmatrix![1., 0.; 0., 1.], + (1.8618676045881531e-23f64).ln(), + 1e-15, + ln_pdf(dvector![1., 10.]), + ); + test_almost( + dvector![0., 0.], + dmatrix![1., 0.; 0., 1.], + (5.920684802611216e-45f64).ln(), + 1e-15, + ln_pdf(dvector![10., 10.]), + ); + test_almost( + dvector![0., 0.], + dmatrix![1., 0.9; 0.9, 1.], + (1.6576716577547003e-05f64).ln(), + 1e-14, + ln_pdf(dvector![1., -1.]), + ); + test_almost( + dvector![0., 0.], + dmatrix![1., 0.99; 0.99, 1.], + (4.1970621773477824e-44f64).ln(), + 1e-12, + ln_pdf(dvector![1., -1.]), + ); + test_almost( + dvector![0.5, -0.2], + dmatrix![2.0, 0.3; 0.3, 0.5], + (0.0013075203140666656f64).ln(), + 1e-15, + ln_pdf(dvector![2., 2.]), + ); + test_case( + dvector![0., 0.], + dmatrix![f64::INFINITY, 0.; 0., f64::INFINITY], + f64::NEG_INFINITY, + ln_pdf(dvector![10., 10.]), + ); + test_case( + dvector![0., 0.], + dmatrix![f64::INFINITY, 0.; 0., f64::INFINITY], + f64::NEG_INFINITY, + ln_pdf(dvector![100., 100.]), + ); } } From 31791409773ec849858e9c5e77af75fe52c3fc30 Mon Sep 17 00:00:00 2001 From: riverbl <94326797+riverbl@users.noreply.github.com> Date: Sun, 28 Apr 2024 15:49:51 +0100 Subject: [PATCH 105/185] Remove dependency on `nalgebra` `macros` feature Remove `nalgebra` `macros` feature from dependencies and add it to dev dependencies. Bump edition to 2021. --- Cargo.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2b133a61..fb2cc4c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ keywords = ["probability", "statistics", "stats", "distribution", "math"] categories = ["science"] homepage = "https://github.com/statrs-dev/statrs" repository = "https://github.com/statrs-dev/statrs" -edition = "2018" +edition = "2021" include = ["CHANGELOG.md", "LICENSE.md", "src/", "tests/"] @@ -19,10 +19,11 @@ path = "src/lib.rs" [dependencies] rand = "0.8" -nalgebra = { version = "0.32", features = ["rand"] } +nalgebra = { version = "0.32", default_features = false, features = ["rand", "std"] } approx = "0.5.0" num-traits = "0.2.14" [dev-dependencies] criterion = "0.3.3" anyhow = "1.0" +nalgebra = { version = "0.32", default_features = false, features = ["macros"] } From b37b081c43322e47b8725a435cffe61ac22a730a Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Wed, 31 Jul 2024 12:02:54 -0500 Subject: [PATCH 106/185] chore: update nalgebra feature flag name also handle a clippy error in docstring --- Cargo.toml | 4 ++-- src/distribution/internal.rs | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fb2cc4c6..ab92cb19 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,11 +19,11 @@ path = "src/lib.rs" [dependencies] rand = "0.8" -nalgebra = { version = "0.32", default_features = false, features = ["rand", "std"] } +nalgebra = { version = "0.32", default-features = false, features = ["rand", "std"] } approx = "0.5.0" num-traits = "0.2.14" [dev-dependencies] criterion = "0.3.3" anyhow = "1.0" -nalgebra = { version = "0.32", default_features = false, features = ["macros"] } +nalgebra = { version = "0.32", default-features = false, features = ["macros"] } diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index 301d4a9e..15c24c10 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -21,6 +21,7 @@ pub fn is_valid_multinomial(arr: &[f64], incl_zero: bool) -> bool { /// Evaluates to `None` if /// - provided interval has lower bound greater than upper bound /// - function found not semi-monotone on the provided interval containing `z` +/// /// Evaluates to `Some(k)`, where `k` satisfies the search criteria pub fn integral_bisection_search( f: impl Fn(&K) -> T, From ff63f93b06797879c16bda43325f24ee6f157c0c Mon Sep 17 00:00:00 2001 From: Amando <63152017+avhz@users.noreply.github.com> Date: Sun, 21 Jul 2024 00:08:05 +0200 Subject: [PATCH 107/185] Update uniform.rs Add a `standard()` method to Uniform Distribution. --- src/distribution/uniform.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index 9a3478bb..bcc6acce 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -61,6 +61,23 @@ impl Uniform { (true, true, true) => Ok(Uniform { min, max }), } } + + /// Constructs a new standard uniform distribution with + /// a lower bound 0 and an upper bound of 1. + /// + /// # Examples + /// + /// ``` + /// use statrs::distribution::Uniform; + /// + /// let uniform = Uniform::standard(); + /// ``` + pub fn standard() -> Self { + Self { + min: 0.0, + max: 1.0, + } + } } impl std::fmt::Display for Uniform { From 53fc9d973f42888a6063556550e8c3a85a2fbd66 Mon Sep 17 00:00:00 2001 From: avhz Date: Sun, 4 Aug 2024 22:50:44 +0200 Subject: [PATCH 108/185] add `Default` for `Uniform`, plus unit test --- src/distribution/uniform.rs | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index bcc6acce..04578a58 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -4,6 +4,7 @@ use crate::{Result, StatsError}; use rand::distributions::Uniform as RandUniform; use rand::Rng; use std::f64; +use std::fmt::Debug; /// Implements the [Continuous /// Uniform](https://en.wikipedia.org/wiki/Uniform_distribution_(continuous)) @@ -62,7 +63,7 @@ impl Uniform { } } - /// Constructs a new standard uniform distribution with + /// Constructs a new standard uniform distribution with /// a lower bound 0 and an upper bound of 1. /// /// # Examples @@ -73,10 +74,13 @@ impl Uniform { /// let uniform = Uniform::standard(); /// ``` pub fn standard() -> Self { - Self { - min: 0.0, - max: 1.0, - } + Self { min: 0.0, max: 1.0 } + } +} + +impl Default for Uniform { + fn default() -> Self { + Self::standard() } } @@ -514,4 +518,17 @@ mod tests { .all(|v| (min <= v) && (v < max)) ); } + + #[test] + fn test_default() { + let n = Uniform::default(); + + let n_mean = n.mean().unwrap(); + let n_std = n.std_dev().unwrap(); + + // Check that the mean of the distribution is close to 1 / 2 + assert_almost_eq!(n_mean, 0.5, 1e-15); + // Check that the standard deviation of the distribution is close to 1 / sqrt(12) + assert_almost_eq!(n_std, 0.288_675_134_594_812_9, 1e-15); + } } From 44ccdb97bbc58dbb5bc6fac89ac8a9f88d86e74e Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Tue, 28 May 2024 16:08:00 -0500 Subject: [PATCH 109/185] coverage: introduce coverage with llvm-cov --- .github/workflows/coverage.yml | 38 ++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 .github/workflows/coverage.yml diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml new file mode 100644 index 00000000..cb069b7e --- /dev/null +++ b/.github/workflows/coverage.yml @@ -0,0 +1,38 @@ +name: Coverage + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] +jobs: + coverage: + name: Coverage + runs-on: ubuntu-latest + env: + RUSTFLAGS: -D warnings + CARGO_TERM_COLOR: always + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + components: llvm-tools-preview + + - uses: taiki-e/install-action@v2 + with: + tool: nextest + - uses: taiki-e/install-action@v2 + with: + tool: cargo-llvm-cov + + - name: Collect coverage + run: | + cargo llvm-cov --no-report nextest + cargo llvm-cov report --lcov --output-path lcov.info + + - name: Upload to codecov.io + uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{secrets.CODECOV_TOKEN}} + fail_ci_if_error: false From f1b27fb6daeaf416e0f1246a33455f498265512d Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 9 Aug 2024 04:13:47 -0500 Subject: [PATCH 110/185] FEAT: add some more inverse CDFs --- src/distribution/beta.rs | 20 ++++++++++++++++++++ src/distribution/cauchy.rs | 14 ++++++++++++++ src/distribution/chi_squared.rs | 4 ++++ src/distribution/erlang.rs | 4 ++++ src/distribution/fisher_snedecor.rs | 13 +++++++++++++ src/distribution/pareto.rs | 8 ++++++++ src/distribution/weibull.rs | 8 ++++++++ 7 files changed, 71 insertions(+) diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index b3cfdcf4..e17ac6f3 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -164,6 +164,14 @@ impl ContinuousCDF for Beta { beta::beta_reg(self.shape_b, self.shape_a, 1.0 - x) } } + + fn inverse_cdf(&self, x: f64) -> f64 { + if !(0.0..=1.0).contains(&x) { + panic!("x must be in [0, 1]"); + } else { + beta::inv_beta_reg(self.shape_a, self.shape_b, x) + } + } } impl Min for Beta { @@ -657,6 +665,18 @@ mod tests { } } + #[test] + fn test_inverse_cdf() { + // let inverse_cdf = |arg: f64| move |x: Beta| x.inverse_cdf(arg); + let cdf = |arg: f64| move |x: Beta| x.inverse_cdf(x.cdf(arg)); + [1.0, 2.0, 1.0, 0.6].iter() + .zip([1.0, 1.0, 5.0, 0.9].iter()) + .zip([0.0, 0.1, 0.9, 1.0].iter()) + .for_each(|((&a, &b), &val)| { + test_case(a, b, val, cdf(val)); + }); + } + #[test] fn test_cdf_input_lt_0() { let cdf = |arg: f64| move |x: Beta| x.cdf(arg); diff --git a/src/distribution/cauchy.rs b/src/distribution/cauchy.rs index b3f40bda..3cde692a 100644 --- a/src/distribution/cauchy.rs +++ b/src/distribution/cauchy.rs @@ -121,6 +121,14 @@ impl ContinuousCDF for Cauchy { fn sf(&self, x: f64) -> f64 { (1.0 / f64::consts::PI) * ((self.location - x) / self.scale).atan() + 0.5 } + + fn inverse_cdf(&self, x: f64) -> f64 { + if !(0.0..=1.0).contains(&x) { + panic!("x must be in [0, 1]"); + } else { + self.location + self.scale * (f64::consts::PI * (x - 0.5)).tan() + } + } } impl Min for Cauchy { @@ -466,6 +474,12 @@ mod tests { test_case(f64::INFINITY, 1.0, 1.0, sf(5.0)); } + #[test] + fn test_inverse_cdf() { + let icdf = |arg: f64| move |x: Cauchy| x.inverse_cdf(arg); + test_case(0.0, 1.0, -3.077683537175253, icdf(0.1)); + } + #[test] fn test_continuous() { test::check_continuous_distribution(&try_create(-1.2, 3.4), -1500.0, 1500.0); diff --git a/src/distribution/chi_squared.rs b/src/distribution/chi_squared.rs index 1c6b42b0..dfef0918 100644 --- a/src/distribution/chi_squared.rs +++ b/src/distribution/chi_squared.rs @@ -138,6 +138,10 @@ impl ContinuousCDF for ChiSquared { fn sf(&self, x: f64) -> f64 { self.g.sf(x) } + + fn inverse_cdf(&self, p: f64) -> f64 { + self.g.inverse_cdf(p) + } } impl Min for ChiSquared { diff --git a/src/distribution/erlang.rs b/src/distribution/erlang.rs index 1213baef..a3b330db 100644 --- a/src/distribution/erlang.rs +++ b/src/distribution/erlang.rs @@ -122,6 +122,10 @@ impl ContinuousCDF for Erlang { fn sf(&self, x: f64) -> f64 { self.g.sf(x) } + + fn inverse_cdf(&self, p: f64) -> f64 { + self.g.inverse_cdf(p) + } } impl Min for Erlang { diff --git a/src/distribution/fisher_snedecor.rs b/src/distribution/fisher_snedecor.rs index d54a1bef..5d9c8a11 100644 --- a/src/distribution/fisher_snedecor.rs +++ b/src/distribution/fisher_snedecor.rs @@ -155,6 +155,19 @@ impl ContinuousCDF for FisherSnedecor { ) } } + + fn inverse_cdf(&self, x: f64) -> f64 { + if !(0.0..=1.0).contains(&x) { + panic!("x must be in [0, 1]"); + } else { + let z = beta::inv_beta_reg( + self.freedom_1 / 2.0, + self.freedom_2 / 2.0, + x, + ); + self.freedom_2 / (self.freedom_1 * (1.0 / z - 1.0)) + } + } } impl Min for FisherSnedecor { diff --git a/src/distribution/pareto.rs b/src/distribution/pareto.rs index 55df13a5..9a638b54 100644 --- a/src/distribution/pareto.rs +++ b/src/distribution/pareto.rs @@ -141,6 +141,14 @@ impl ContinuousCDF for Pareto { (self.scale / x).powf(self.shape) } } + + fn inverse_cdf(&self, p: f64) -> f64 { + if !(0.0..=1.0).contains(&p) { + panic!("x must be in [0, 1]"); + } else { + self.scale / (1.0 - p).powf(1.0 / self.shape) + } + } } impl Min for Pareto { diff --git a/src/distribution/weibull.rs b/src/distribution/weibull.rs index 49dbc4d8..036fc393 100644 --- a/src/distribution/weibull.rs +++ b/src/distribution/weibull.rs @@ -139,6 +139,14 @@ impl ContinuousCDF for Weibull { (-x.powf(self.shape) * self.scale_pow_shape_inv).exp() } } + + fn inverse_cdf(&self, p: f64) -> f64 { + if !(0.0..=1.0).contains(&p) { + panic!("x must be in [0, 1]"); + } else { + ((-p).ln_1p() / self.scale_pow_shape_inv).powf(1.0 / self.shape) + } + } } impl Min for Weibull { From 1bb2140cc8d34ca93e784a9b579bb175cce740f4 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 9 Aug 2024 04:16:22 -0500 Subject: [PATCH 111/185] FEAT: add FullContinuous and FullDiscrete convenience options --- src/distribution/bernoulli.rs | 4 ++++ src/distribution/beta.rs | 4 ++++ src/distribution/binomial.rs | 4 ++++ src/distribution/categorical.rs | 4 ++++ src/distribution/cauchy.rs | 4 ++++ src/distribution/chi.rs | 4 ++++ src/distribution/chi_squared.rs | 4 ++++ src/distribution/dirac.rs | 2 ++ src/distribution/discrete_uniform.rs | 4 ++++ src/distribution/empirical.rs | 2 ++ src/distribution/erlang.rs | 4 ++++ src/distribution/exponential.rs | 4 ++++ src/distribution/fisher_snedecor.rs | 4 ++++ src/distribution/gamma.rs | 5 +++++ src/distribution/geometric.rs | 4 ++++ src/distribution/hypergeometric.rs | 4 ++++ src/distribution/inverse_gamma.rs | 4 ++++ src/distribution/laplace.rs | 4 ++++ src/distribution/log_normal.rs | 4 ++++ src/distribution/mod.rs | 4 ++++ src/distribution/negative_binomial.rs | 4 ++++ src/distribution/normal.rs | 4 ++++ src/distribution/pareto.rs | 4 ++++ src/distribution/poisson.rs | 5 +++++ src/distribution/students_t.rs | 4 ++++ src/distribution/triangular.rs | 4 ++++ src/distribution/uniform.rs | 4 ++++ src/distribution/weibull.rs | 4 ++++ 28 files changed, 110 insertions(+) diff --git a/src/distribution/bernoulli.rs b/src/distribution/bernoulli.rs index 61499ebd..c0f84de0 100644 --- a/src/distribution/bernoulli.rs +++ b/src/distribution/bernoulli.rs @@ -3,6 +3,8 @@ use crate::statistics::*; use crate::Result; use rand::Rng; +use super::FullDiscrete; + /// Implements the /// [Bernoulli](https://en.wikipedia.org/wiki/Bernoulli_distribution) /// distribution which is a special case of the @@ -262,6 +264,8 @@ impl Discrete for Bernoulli { } } +impl FullDiscrete for Bernoulli {} + #[rustfmt::skip] #[cfg(test)] mod testing { diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index e17ac6f3..f103a795 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -5,6 +5,8 @@ use crate::statistics::*; use crate::{Result, StatsError}; use rand::Rng; +use super::FullContinuous; + /// Implements the [Beta](https://en.wikipedia.org/wiki/Beta_distribution) /// distribution /// @@ -415,6 +417,8 @@ impl Continuous for Beta { } } +impl FullContinuous for Beta {} + #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index 2b56e6fc..d60cb58f 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -6,6 +6,8 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; +use super::FullDiscrete; + /// Implements the /// [Binomial](https://en.wikipedia.org/wiki/Binomial_distribution) /// distribution @@ -326,6 +328,8 @@ impl Discrete for Binomial { } } +impl FullDiscrete for Binomial {} + #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index 31bccf8b..77f0378b 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -4,6 +4,8 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; +use super::FullDiscrete; + /// Implements the /// [Categorical](https://en.wikipedia.org/wiki/Categorical_distribution) /// distribution, also known as the generalized Bernoulli or discrete @@ -333,6 +335,8 @@ fn binary_index(search: &[f64], val: f64) -> usize { cmp::min(search.len(), cmp::max(low, 0) as usize) } +impl FullDiscrete for Categorical {} + #[test] fn test_prob_mass_to_cdf() { let arr = [0.0, 0.5, 0.5, 3.0, 1.1]; diff --git a/src/distribution/cauchy.rs b/src/distribution/cauchy.rs index 3cde692a..e5db33de 100644 --- a/src/distribution/cauchy.rs +++ b/src/distribution/cauchy.rs @@ -4,6 +4,8 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; +use super::FullContinuous; + /// Implements the [Cauchy](https://en.wikipedia.org/wiki/Cauchy_distribution) /// distribution, also known as the Lorentz distribution. /// @@ -239,6 +241,8 @@ impl Continuous for Cauchy { } } +impl FullContinuous for Cauchy {} + #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/chi.rs b/src/distribution/chi.rs index 205cca11..4222787b 100644 --- a/src/distribution/chi.rs +++ b/src/distribution/chi.rs @@ -5,6 +5,8 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; +use super::FullContinuous; + /// Implements the [Chi](https://en.wikipedia.org/wiki/Chi_distribution) /// distribution /// @@ -322,6 +324,8 @@ impl Continuous for Chi { } } +impl FullContinuous for Chi {} + #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/chi_squared.rs b/src/distribution/chi_squared.rs index dfef0918..862043cf 100644 --- a/src/distribution/chi_squared.rs +++ b/src/distribution/chi_squared.rs @@ -4,6 +4,8 @@ use crate::Result; use rand::Rng; use std::f64; +use super::FullContinuous; + /// Implements the /// [Chi-squared](https://en.wikipedia.org/wiki/Chi-squared_distribution) /// distribution which is a special case of the @@ -292,6 +294,8 @@ impl Continuous for ChiSquared { } } +impl FullContinuous for ChiSquared {} + #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/dirac.rs b/src/distribution/dirac.rs index 4fa6a390..70f86f82 100644 --- a/src/distribution/dirac.rs +++ b/src/distribution/dirac.rs @@ -3,6 +3,8 @@ use crate::statistics::*; use crate::{Result, StatsError}; use rand::Rng; +use super::FullContinuous; + /// Implements the [Dirac Delta](https://en.wikipedia.org/wiki/Dirac_delta_function#As_a_distribution) /// distribution /// diff --git a/src/distribution/discrete_uniform.rs b/src/distribution/discrete_uniform.rs index 361cadd8..2f3832b1 100644 --- a/src/distribution/discrete_uniform.rs +++ b/src/distribution/discrete_uniform.rs @@ -3,6 +3,8 @@ use crate::statistics::*; use crate::{Result, StatsError}; use rand::Rng; +use super::FullDiscrete; + /// Implements the [Discrete /// Uniform](https://en.wikipedia.org/wiki/Discrete_uniform_distribution) /// distribution @@ -253,6 +255,8 @@ impl Discrete for DiscreteUniform { } } +impl FullDiscrete for DiscreteUniform {} + #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/empirical.rs b/src/distribution/empirical.rs index 9afd7022..b566681c 100644 --- a/src/distribution/empirical.rs +++ b/src/distribution/empirical.rs @@ -6,6 +6,8 @@ use core::cmp::Ordering; use rand::Rng; use std::collections::BTreeMap; +use super::FullContinuous; + #[derive(Clone, PartialEq, Debug)] struct NonNan(T); diff --git a/src/distribution/erlang.rs b/src/distribution/erlang.rs index a3b330db..f8025507 100644 --- a/src/distribution/erlang.rs +++ b/src/distribution/erlang.rs @@ -3,6 +3,8 @@ use crate::statistics::*; use crate::Result; use rand::Rng; +use super::FullContinuous; + /// Implements the [Erlang](https://en.wikipedia.org/wiki/Erlang_distribution) /// distribution /// which is a special case of the @@ -279,6 +281,8 @@ impl Continuous for Erlang { } } +impl FullContinuous for Erlang {} + #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/exponential.rs b/src/distribution/exponential.rs index 978ae638..1c2c729b 100644 --- a/src/distribution/exponential.rs +++ b/src/distribution/exponential.rs @@ -4,6 +4,8 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; +use super::FullContinuous; + /// Implements the /// [Exp](https://en.wikipedia.org/wiki/Exp_distribution) /// distribution and is a special case of the @@ -276,6 +278,8 @@ impl Continuous for Exp { } } +impl FullContinuous for Exp {} + #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/fisher_snedecor.rs b/src/distribution/fisher_snedecor.rs index 5d9c8a11..785f1f81 100644 --- a/src/distribution/fisher_snedecor.rs +++ b/src/distribution/fisher_snedecor.rs @@ -5,6 +5,8 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; +use super::FullContinuous; + /// Implements the /// [Fisher-Snedecor](https://en.wikipedia.org/wiki/F-distribution) distribution /// also commonly known as the F-distribution @@ -373,6 +375,8 @@ impl Continuous for FisherSnedecor { } } +impl FullContinuous for FisherSnedecor {} + #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index 166ebb72..4634b3a5 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -5,6 +5,8 @@ use crate::statistics::*; use crate::{Result, StatsError}; use rand::Rng; +use super::FullContinuous; + /// Implements the [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution) /// distribution /// @@ -363,6 +365,9 @@ impl Continuous for Gamma { } } } + +impl FullContinuous for Gamma {} + /// Samples from a gamma distribution with a shape of `shape` and a /// rate of `rate` using `rng` as the source of randomness. Implementation from: ///
diff --git a/src/distribution/geometric.rs b/src/distribution/geometric.rs index 4df623ed..adadba02 100644 --- a/src/distribution/geometric.rs +++ b/src/distribution/geometric.rs @@ -5,6 +5,8 @@ use rand::distributions::OpenClosed01; use rand::Rng; use std::f64; +use super::FullDiscrete; + /// Implements the /// [Geometric](https://en.wikipedia.org/wiki/Geometric_distribution) /// distribution @@ -270,6 +272,8 @@ impl Discrete for Geometric { } } +impl FullDiscrete for Geometric {} + #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/hypergeometric.rs b/src/distribution/hypergeometric.rs index 8b6d8500..ded21304 100644 --- a/src/distribution/hypergeometric.rs +++ b/src/distribution/hypergeometric.rs @@ -6,6 +6,8 @@ use rand::Rng; use std::cmp; use std::f64; +use super::FullDiscrete; + /// Implements the /// [Hypergeometric](http://en.wikipedia.org/wiki/Hypergeometric_distribution) /// distribution @@ -375,6 +377,8 @@ impl Discrete for Hypergeometric { } } +impl FullDiscrete for Hypergeometric {} + #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/inverse_gamma.rs b/src/distribution/inverse_gamma.rs index d22d2239..f9622469 100644 --- a/src/distribution/inverse_gamma.rs +++ b/src/distribution/inverse_gamma.rs @@ -5,6 +5,8 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; +use super::FullContinuous; + /// Implements the [Inverse /// Gamma](https://en.wikipedia.org/wiki/Inverse-gamma_distribution) /// distribution @@ -310,6 +312,8 @@ impl Continuous for InverseGamma { } } +impl FullContinuous for InverseGamma {} + #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/laplace.rs b/src/distribution/laplace.rs index 1ed74132..370c3a87 100644 --- a/src/distribution/laplace.rs +++ b/src/distribution/laplace.rs @@ -4,6 +4,8 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; +use super::FullContinuous; + /// Implements the [Laplace](https://en.wikipedia.org/wiki/Laplace_distribution) /// distribution. /// @@ -297,6 +299,8 @@ impl Continuous for Laplace { } } +impl FullContinuous for Laplace {} + #[cfg(test)] mod tests { use super::*; diff --git a/src/distribution/log_normal.rs b/src/distribution/log_normal.rs index b6dbff6f..903bc24c 100644 --- a/src/distribution/log_normal.rs +++ b/src/distribution/log_normal.rs @@ -5,6 +5,8 @@ use crate::{consts, Result, StatsError}; use rand::Rng; use std::f64; +use super::FullContinuous; + /// Implements the /// [Log-normal](https://en.wikipedia.org/wiki/Log-normal_distribution) /// distribution @@ -302,6 +304,8 @@ impl Continuous for LogNormal { } } +impl FullContinuous for LogNormal {} + #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index 56deb09a..16e3b64f 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -280,3 +280,7 @@ pub trait Discrete { /// ``` fn ln_pmf(&self, x: K) -> T; } + +pub trait FullContinuous: ContinuousCDF + Continuous {} + +pub trait FullDiscrete: DiscreteCDF + Discrete {} \ No newline at end of file diff --git a/src/distribution/negative_binomial.rs b/src/distribution/negative_binomial.rs index 065c2239..e8bff754 100644 --- a/src/distribution/negative_binomial.rs +++ b/src/distribution/negative_binomial.rs @@ -5,6 +5,8 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; +use super::FullDiscrete; + /// Implements the /// [negative binomial](http://en.wikipedia.org/wiki/Negative_binomial_distribution) /// distribution. @@ -288,6 +290,8 @@ impl Discrete for NegativeBinomial { } } +impl FullDiscrete for NegativeBinomial {} + #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index 94e8c6b6..d1fabd6b 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -5,6 +5,8 @@ use crate::{consts, Result, StatsError}; use rand::Rng; use std::f64; +use super::FullContinuous; + /// Implements the [Normal](https://en.wikipedia.org/wiki/Normal_distribution) /// distribution /// @@ -292,6 +294,8 @@ impl Continuous for Normal { } } +impl FullContinuous for Normal {} + /// performs an unchecked cdf calculation for a normal distribution /// with the given mean and standard deviation at x pub fn cdf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 { diff --git a/src/distribution/pareto.rs b/src/distribution/pareto.rs index 9a638b54..8f4edc9e 100644 --- a/src/distribution/pareto.rs +++ b/src/distribution/pareto.rs @@ -5,6 +5,8 @@ use rand::distributions::OpenClosed01; use rand::Rng; use std::f64; +use super::FullContinuous; + /// Implements the [Pareto](https://en.wikipedia.org/wiki/Pareto_distribution) /// distribution /// @@ -341,6 +343,8 @@ impl Continuous for Pareto { } } +impl FullContinuous for Pareto {} + #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index 7653ed20..29e0b23e 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -5,6 +5,8 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; +use super::FullDiscrete; + /// Implements the [Poisson](https://en.wikipedia.org/wiki/Poisson_distribution) /// distribution /// @@ -260,6 +262,9 @@ impl Discrete for Poisson { -self.lambda + x as f64 * self.lambda.ln() - factorial::ln_factorial(x) } } + +impl FullDiscrete for Poisson {} + /// Generates one sample from the Poisson distribution either by /// Knuth's method if lambda < 30.0 or Rejection method PA by /// A. C. Atkinson from the Journal of the Royal Statistical Society diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index af7d7356..14701b3b 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -6,6 +6,8 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; +use super::FullContinuous; + /// Implements the [Student's /// T](https://en.wikipedia.org/wiki/Student%27s_t-distribution) distribution /// @@ -420,6 +422,8 @@ impl Continuous for StudentsT { } } +impl FullContinuous for StudentsT {} + #[cfg(test)] mod tests { use crate::consts::ACC; diff --git a/src/distribution/triangular.rs b/src/distribution/triangular.rs index 5fa89b70..ecf1a704 100644 --- a/src/distribution/triangular.rs +++ b/src/distribution/triangular.rs @@ -4,6 +4,8 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; +use super::FullContinuous; + /// Implements the /// [Triangular](https://en.wikipedia.org/wiki/Triangular_distribution) /// distribution @@ -307,6 +309,8 @@ impl Continuous for Triangular { } } +impl FullContinuous for Triangular {} + fn sample_unchecked(rng: &mut R, min: f64, max: f64, mode: f64) -> f64 { let f: f64 = rng.gen(); if f < (mode - min) / (max - min) { diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index 04578a58..e78c2469 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -6,6 +6,8 @@ use rand::Rng; use std::f64; use std::fmt::Debug; +use super::FullContinuous; + /// Implements the [Continuous /// Uniform](https://en.wikipedia.org/wiki/Uniform_distribution_(continuous)) /// distribution @@ -281,6 +283,8 @@ impl Continuous for Uniform { } } +impl FullContinuous for Uniform {} + #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/weibull.rs b/src/distribution/weibull.rs index 036fc393..c5c9ec91 100644 --- a/src/distribution/weibull.rs +++ b/src/distribution/weibull.rs @@ -6,6 +6,8 @@ use crate::{consts, Result, StatsError}; use rand::Rng; use std::f64; +use super::FullContinuous; + /// Implements the [Weibull](https://en.wikipedia.org/wiki/Weibull_distribution) /// distribution /// @@ -338,6 +340,8 @@ impl Continuous for Weibull { } } +impl FullContinuous for Weibull {} + #[rustfmt::skip] #[cfg(test)] mod tests { From ee54da3d1d8d4b58296f012f1ac5619b96ff4a3b Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Sat, 10 Aug 2024 07:30:38 -0500 Subject: [PATCH 112/185] Revert "FEAT: add FullContinuous and FullDiscrete convenience options" This reverts commit 7e85082a314304fe5e82598e04afad72bee7be28. --- src/distribution/bernoulli.rs | 4 ---- src/distribution/beta.rs | 4 ---- src/distribution/binomial.rs | 4 ---- src/distribution/categorical.rs | 4 ---- src/distribution/cauchy.rs | 4 ---- src/distribution/chi.rs | 4 ---- src/distribution/chi_squared.rs | 4 ---- src/distribution/dirac.rs | 2 -- src/distribution/discrete_uniform.rs | 4 ---- src/distribution/empirical.rs | 2 -- src/distribution/erlang.rs | 4 ---- src/distribution/exponential.rs | 4 ---- src/distribution/fisher_snedecor.rs | 4 ---- src/distribution/gamma.rs | 5 ----- src/distribution/geometric.rs | 4 ---- src/distribution/hypergeometric.rs | 4 ---- src/distribution/inverse_gamma.rs | 4 ---- src/distribution/laplace.rs | 4 ---- src/distribution/log_normal.rs | 4 ---- src/distribution/mod.rs | 4 ---- src/distribution/negative_binomial.rs | 4 ---- src/distribution/normal.rs | 4 ---- src/distribution/pareto.rs | 4 ---- src/distribution/poisson.rs | 5 ----- src/distribution/students_t.rs | 4 ---- src/distribution/triangular.rs | 4 ---- src/distribution/uniform.rs | 4 ---- src/distribution/weibull.rs | 4 ---- 28 files changed, 110 deletions(-) diff --git a/src/distribution/bernoulli.rs b/src/distribution/bernoulli.rs index c0f84de0..61499ebd 100644 --- a/src/distribution/bernoulli.rs +++ b/src/distribution/bernoulli.rs @@ -3,8 +3,6 @@ use crate::statistics::*; use crate::Result; use rand::Rng; -use super::FullDiscrete; - /// Implements the /// [Bernoulli](https://en.wikipedia.org/wiki/Bernoulli_distribution) /// distribution which is a special case of the @@ -264,8 +262,6 @@ impl Discrete for Bernoulli { } } -impl FullDiscrete for Bernoulli {} - #[rustfmt::skip] #[cfg(test)] mod testing { diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index f103a795..e17ac6f3 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -5,8 +5,6 @@ use crate::statistics::*; use crate::{Result, StatsError}; use rand::Rng; -use super::FullContinuous; - /// Implements the [Beta](https://en.wikipedia.org/wiki/Beta_distribution) /// distribution /// @@ -417,8 +415,6 @@ impl Continuous for Beta { } } -impl FullContinuous for Beta {} - #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index d60cb58f..2b56e6fc 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -6,8 +6,6 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; -use super::FullDiscrete; - /// Implements the /// [Binomial](https://en.wikipedia.org/wiki/Binomial_distribution) /// distribution @@ -328,8 +326,6 @@ impl Discrete for Binomial { } } -impl FullDiscrete for Binomial {} - #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index 77f0378b..31bccf8b 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -4,8 +4,6 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; -use super::FullDiscrete; - /// Implements the /// [Categorical](https://en.wikipedia.org/wiki/Categorical_distribution) /// distribution, also known as the generalized Bernoulli or discrete @@ -335,8 +333,6 @@ fn binary_index(search: &[f64], val: f64) -> usize { cmp::min(search.len(), cmp::max(low, 0) as usize) } -impl FullDiscrete for Categorical {} - #[test] fn test_prob_mass_to_cdf() { let arr = [0.0, 0.5, 0.5, 3.0, 1.1]; diff --git a/src/distribution/cauchy.rs b/src/distribution/cauchy.rs index e5db33de..3cde692a 100644 --- a/src/distribution/cauchy.rs +++ b/src/distribution/cauchy.rs @@ -4,8 +4,6 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; -use super::FullContinuous; - /// Implements the [Cauchy](https://en.wikipedia.org/wiki/Cauchy_distribution) /// distribution, also known as the Lorentz distribution. /// @@ -241,8 +239,6 @@ impl Continuous for Cauchy { } } -impl FullContinuous for Cauchy {} - #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/chi.rs b/src/distribution/chi.rs index 4222787b..205cca11 100644 --- a/src/distribution/chi.rs +++ b/src/distribution/chi.rs @@ -5,8 +5,6 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; -use super::FullContinuous; - /// Implements the [Chi](https://en.wikipedia.org/wiki/Chi_distribution) /// distribution /// @@ -324,8 +322,6 @@ impl Continuous for Chi { } } -impl FullContinuous for Chi {} - #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/chi_squared.rs b/src/distribution/chi_squared.rs index 862043cf..dfef0918 100644 --- a/src/distribution/chi_squared.rs +++ b/src/distribution/chi_squared.rs @@ -4,8 +4,6 @@ use crate::Result; use rand::Rng; use std::f64; -use super::FullContinuous; - /// Implements the /// [Chi-squared](https://en.wikipedia.org/wiki/Chi-squared_distribution) /// distribution which is a special case of the @@ -294,8 +292,6 @@ impl Continuous for ChiSquared { } } -impl FullContinuous for ChiSquared {} - #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/dirac.rs b/src/distribution/dirac.rs index 70f86f82..4fa6a390 100644 --- a/src/distribution/dirac.rs +++ b/src/distribution/dirac.rs @@ -3,8 +3,6 @@ use crate::statistics::*; use crate::{Result, StatsError}; use rand::Rng; -use super::FullContinuous; - /// Implements the [Dirac Delta](https://en.wikipedia.org/wiki/Dirac_delta_function#As_a_distribution) /// distribution /// diff --git a/src/distribution/discrete_uniform.rs b/src/distribution/discrete_uniform.rs index 2f3832b1..361cadd8 100644 --- a/src/distribution/discrete_uniform.rs +++ b/src/distribution/discrete_uniform.rs @@ -3,8 +3,6 @@ use crate::statistics::*; use crate::{Result, StatsError}; use rand::Rng; -use super::FullDiscrete; - /// Implements the [Discrete /// Uniform](https://en.wikipedia.org/wiki/Discrete_uniform_distribution) /// distribution @@ -255,8 +253,6 @@ impl Discrete for DiscreteUniform { } } -impl FullDiscrete for DiscreteUniform {} - #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/empirical.rs b/src/distribution/empirical.rs index b566681c..9afd7022 100644 --- a/src/distribution/empirical.rs +++ b/src/distribution/empirical.rs @@ -6,8 +6,6 @@ use core::cmp::Ordering; use rand::Rng; use std::collections::BTreeMap; -use super::FullContinuous; - #[derive(Clone, PartialEq, Debug)] struct NonNan(T); diff --git a/src/distribution/erlang.rs b/src/distribution/erlang.rs index f8025507..a3b330db 100644 --- a/src/distribution/erlang.rs +++ b/src/distribution/erlang.rs @@ -3,8 +3,6 @@ use crate::statistics::*; use crate::Result; use rand::Rng; -use super::FullContinuous; - /// Implements the [Erlang](https://en.wikipedia.org/wiki/Erlang_distribution) /// distribution /// which is a special case of the @@ -281,8 +279,6 @@ impl Continuous for Erlang { } } -impl FullContinuous for Erlang {} - #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/exponential.rs b/src/distribution/exponential.rs index 1c2c729b..978ae638 100644 --- a/src/distribution/exponential.rs +++ b/src/distribution/exponential.rs @@ -4,8 +4,6 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; -use super::FullContinuous; - /// Implements the /// [Exp](https://en.wikipedia.org/wiki/Exp_distribution) /// distribution and is a special case of the @@ -278,8 +276,6 @@ impl Continuous for Exp { } } -impl FullContinuous for Exp {} - #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/fisher_snedecor.rs b/src/distribution/fisher_snedecor.rs index 785f1f81..5d9c8a11 100644 --- a/src/distribution/fisher_snedecor.rs +++ b/src/distribution/fisher_snedecor.rs @@ -5,8 +5,6 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; -use super::FullContinuous; - /// Implements the /// [Fisher-Snedecor](https://en.wikipedia.org/wiki/F-distribution) distribution /// also commonly known as the F-distribution @@ -375,8 +373,6 @@ impl Continuous for FisherSnedecor { } } -impl FullContinuous for FisherSnedecor {} - #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index 4634b3a5..166ebb72 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -5,8 +5,6 @@ use crate::statistics::*; use crate::{Result, StatsError}; use rand::Rng; -use super::FullContinuous; - /// Implements the [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution) /// distribution /// @@ -365,9 +363,6 @@ impl Continuous for Gamma { } } } - -impl FullContinuous for Gamma {} - /// Samples from a gamma distribution with a shape of `shape` and a /// rate of `rate` using `rng` as the source of randomness. Implementation from: ///
diff --git a/src/distribution/geometric.rs b/src/distribution/geometric.rs index adadba02..4df623ed 100644 --- a/src/distribution/geometric.rs +++ b/src/distribution/geometric.rs @@ -5,8 +5,6 @@ use rand::distributions::OpenClosed01; use rand::Rng; use std::f64; -use super::FullDiscrete; - /// Implements the /// [Geometric](https://en.wikipedia.org/wiki/Geometric_distribution) /// distribution @@ -272,8 +270,6 @@ impl Discrete for Geometric { } } -impl FullDiscrete for Geometric {} - #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/hypergeometric.rs b/src/distribution/hypergeometric.rs index ded21304..8b6d8500 100644 --- a/src/distribution/hypergeometric.rs +++ b/src/distribution/hypergeometric.rs @@ -6,8 +6,6 @@ use rand::Rng; use std::cmp; use std::f64; -use super::FullDiscrete; - /// Implements the /// [Hypergeometric](http://en.wikipedia.org/wiki/Hypergeometric_distribution) /// distribution @@ -377,8 +375,6 @@ impl Discrete for Hypergeometric { } } -impl FullDiscrete for Hypergeometric {} - #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/inverse_gamma.rs b/src/distribution/inverse_gamma.rs index f9622469..d22d2239 100644 --- a/src/distribution/inverse_gamma.rs +++ b/src/distribution/inverse_gamma.rs @@ -5,8 +5,6 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; -use super::FullContinuous; - /// Implements the [Inverse /// Gamma](https://en.wikipedia.org/wiki/Inverse-gamma_distribution) /// distribution @@ -312,8 +310,6 @@ impl Continuous for InverseGamma { } } -impl FullContinuous for InverseGamma {} - #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/laplace.rs b/src/distribution/laplace.rs index 370c3a87..1ed74132 100644 --- a/src/distribution/laplace.rs +++ b/src/distribution/laplace.rs @@ -4,8 +4,6 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; -use super::FullContinuous; - /// Implements the [Laplace](https://en.wikipedia.org/wiki/Laplace_distribution) /// distribution. /// @@ -299,8 +297,6 @@ impl Continuous for Laplace { } } -impl FullContinuous for Laplace {} - #[cfg(test)] mod tests { use super::*; diff --git a/src/distribution/log_normal.rs b/src/distribution/log_normal.rs index 903bc24c..b6dbff6f 100644 --- a/src/distribution/log_normal.rs +++ b/src/distribution/log_normal.rs @@ -5,8 +5,6 @@ use crate::{consts, Result, StatsError}; use rand::Rng; use std::f64; -use super::FullContinuous; - /// Implements the /// [Log-normal](https://en.wikipedia.org/wiki/Log-normal_distribution) /// distribution @@ -304,8 +302,6 @@ impl Continuous for LogNormal { } } -impl FullContinuous for LogNormal {} - #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index 16e3b64f..56deb09a 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -280,7 +280,3 @@ pub trait Discrete { /// ``` fn ln_pmf(&self, x: K) -> T; } - -pub trait FullContinuous: ContinuousCDF + Continuous {} - -pub trait FullDiscrete: DiscreteCDF + Discrete {} \ No newline at end of file diff --git a/src/distribution/negative_binomial.rs b/src/distribution/negative_binomial.rs index e8bff754..065c2239 100644 --- a/src/distribution/negative_binomial.rs +++ b/src/distribution/negative_binomial.rs @@ -5,8 +5,6 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; -use super::FullDiscrete; - /// Implements the /// [negative binomial](http://en.wikipedia.org/wiki/Negative_binomial_distribution) /// distribution. @@ -290,8 +288,6 @@ impl Discrete for NegativeBinomial { } } -impl FullDiscrete for NegativeBinomial {} - #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index d1fabd6b..94e8c6b6 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -5,8 +5,6 @@ use crate::{consts, Result, StatsError}; use rand::Rng; use std::f64; -use super::FullContinuous; - /// Implements the [Normal](https://en.wikipedia.org/wiki/Normal_distribution) /// distribution /// @@ -294,8 +292,6 @@ impl Continuous for Normal { } } -impl FullContinuous for Normal {} - /// performs an unchecked cdf calculation for a normal distribution /// with the given mean and standard deviation at x pub fn cdf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 { diff --git a/src/distribution/pareto.rs b/src/distribution/pareto.rs index 8f4edc9e..9a638b54 100644 --- a/src/distribution/pareto.rs +++ b/src/distribution/pareto.rs @@ -5,8 +5,6 @@ use rand::distributions::OpenClosed01; use rand::Rng; use std::f64; -use super::FullContinuous; - /// Implements the [Pareto](https://en.wikipedia.org/wiki/Pareto_distribution) /// distribution /// @@ -343,8 +341,6 @@ impl Continuous for Pareto { } } -impl FullContinuous for Pareto {} - #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index 29e0b23e..7653ed20 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -5,8 +5,6 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; -use super::FullDiscrete; - /// Implements the [Poisson](https://en.wikipedia.org/wiki/Poisson_distribution) /// distribution /// @@ -262,9 +260,6 @@ impl Discrete for Poisson { -self.lambda + x as f64 * self.lambda.ln() - factorial::ln_factorial(x) } } - -impl FullDiscrete for Poisson {} - /// Generates one sample from the Poisson distribution either by /// Knuth's method if lambda < 30.0 or Rejection method PA by /// A. C. Atkinson from the Journal of the Royal Statistical Society diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index 14701b3b..af7d7356 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -6,8 +6,6 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; -use super::FullContinuous; - /// Implements the [Student's /// T](https://en.wikipedia.org/wiki/Student%27s_t-distribution) distribution /// @@ -422,8 +420,6 @@ impl Continuous for StudentsT { } } -impl FullContinuous for StudentsT {} - #[cfg(test)] mod tests { use crate::consts::ACC; diff --git a/src/distribution/triangular.rs b/src/distribution/triangular.rs index ecf1a704..5fa89b70 100644 --- a/src/distribution/triangular.rs +++ b/src/distribution/triangular.rs @@ -4,8 +4,6 @@ use crate::{Result, StatsError}; use rand::Rng; use std::f64; -use super::FullContinuous; - /// Implements the /// [Triangular](https://en.wikipedia.org/wiki/Triangular_distribution) /// distribution @@ -309,8 +307,6 @@ impl Continuous for Triangular { } } -impl FullContinuous for Triangular {} - fn sample_unchecked(rng: &mut R, min: f64, max: f64, mode: f64) -> f64 { let f: f64 = rng.gen(); if f < (mode - min) / (max - min) { diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index e78c2469..04578a58 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -6,8 +6,6 @@ use rand::Rng; use std::f64; use std::fmt::Debug; -use super::FullContinuous; - /// Implements the [Continuous /// Uniform](https://en.wikipedia.org/wiki/Uniform_distribution_(continuous)) /// distribution @@ -283,8 +281,6 @@ impl Continuous for Uniform { } } -impl FullContinuous for Uniform {} - #[rustfmt::skip] #[cfg(test)] mod tests { diff --git a/src/distribution/weibull.rs b/src/distribution/weibull.rs index c5c9ec91..036fc393 100644 --- a/src/distribution/weibull.rs +++ b/src/distribution/weibull.rs @@ -6,8 +6,6 @@ use crate::{consts, Result, StatsError}; use rand::Rng; use std::f64; -use super::FullContinuous; - /// Implements the [Weibull](https://en.wikipedia.org/wiki/Weibull_distribution) /// distribution /// @@ -340,8 +338,6 @@ impl Continuous for Weibull { } } -impl FullContinuous for Weibull {} - #[rustfmt::skip] #[cfg(test)] mod tests { From 25f476a6d75bc9aea3b4ecc77939066134b46fc0 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Mon, 12 Aug 2024 15:09:35 -0500 Subject: [PATCH 113/185] chore: fmt on #260 --- src/distribution/fisher_snedecor.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/distribution/fisher_snedecor.rs b/src/distribution/fisher_snedecor.rs index 5d9c8a11..f543f4e5 100644 --- a/src/distribution/fisher_snedecor.rs +++ b/src/distribution/fisher_snedecor.rs @@ -160,11 +160,7 @@ impl ContinuousCDF for FisherSnedecor { if !(0.0..=1.0).contains(&x) { panic!("x must be in [0, 1]"); } else { - let z = beta::inv_beta_reg( - self.freedom_1 / 2.0, - self.freedom_2 / 2.0, - x, - ); + let z = beta::inv_beta_reg(self.freedom_1 / 2.0, self.freedom_2 / 2.0, x); self.freedom_2 / (self.freedom_1 * (1.0 / z - 1.0)) } } From e8200b8815d5312484e9ea3fd7e025799647015b Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 12 Aug 2024 21:51:53 +0000 Subject: [PATCH 114/185] TEST: add a bunch of iCDF tests --- src/distribution/beta.rs | 35 ++++++++++++++++++----- src/distribution/categorical.rs | 2 +- src/distribution/cauchy.rs | 33 ++++++++++++++++++++-- src/distribution/chi_squared.rs | 11 ++++++++ src/distribution/fisher_snedecor.rs | 29 +++++++++++++++++++ src/distribution/pareto.rs | 25 ++++++++++++++-- src/distribution/triangular.rs | 44 ++++++++++++++++++++++++++++- src/distribution/weibull.rs | 29 +++++++++++++++++-- 8 files changed, 193 insertions(+), 15 deletions(-) diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index e17ac6f3..6f204601 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -165,6 +165,18 @@ impl ContinuousCDF for Beta { } } + /// Calculates the inverse cumulative distribution function for the beta + /// distribution + /// at `x` + /// + /// # Formula + /// + /// ```text + /// I^{-1}_x(α, β) + /// ``` + /// + /// where `α` is shapeA, `β` is shapeB, and `I_x` is the inverse of the + /// regularized lower incomplete beta function fn inverse_cdf(&self, x: f64) -> f64 { if !(0.0..=1.0).contains(&x) { panic!("x must be in [0, 1]"); @@ -668,13 +680,22 @@ mod tests { #[test] fn test_inverse_cdf() { // let inverse_cdf = |arg: f64| move |x: Beta| x.inverse_cdf(arg); - let cdf = |arg: f64| move |x: Beta| x.inverse_cdf(x.cdf(arg)); - [1.0, 2.0, 1.0, 0.6].iter() - .zip([1.0, 1.0, 5.0, 0.9].iter()) - .zip([0.0, 0.1, 0.9, 1.0].iter()) - .for_each(|((&a, &b), &val)| { - test_case(a, b, val, cdf(val)); - }); + let func = |arg: f64| move |x: Beta| x.inverse_cdf(x.cdf(arg)); + let test = [ + ((1.0, 1.0), 0.0, 0.0), + ((1.0, 1.0), 0.5, 0.5), + ((1.0, 1.0), 1.0, 1.0), + ((9.0, 1.0), 0.0, 0.0), + ((9.0, 1.0), 0.001953125, 0.001953125), + ((9.0, 1.0), 0.5, 0.5), + ((9.0, 1.0), 1.0, 1.0), + ((5.0, 100.0), 0.0, 0.0), + ((5.0, 100.0), 0.01, 0.01), + ((5.0, 100.0), 1.0, 1.0), + ]; + for ((a, b), x, expect) in test { + test_case(a, b, expect, func(x)); + }; } #[test] diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index 31bccf8b..2bf5dc47 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -12,7 +12,7 @@ use std::f64; /// # Examples /// /// ``` -/// +/// /// use statrs::distribution::{Categorical, Discrete}; /// use statrs::statistics::Distribution; /// use statrs::prec; diff --git a/src/distribution/cauchy.rs b/src/distribution/cauchy.rs index 3cde692a..078f7635 100644 --- a/src/distribution/cauchy.rs +++ b/src/distribution/cauchy.rs @@ -122,6 +122,16 @@ impl ContinuousCDF for Cauchy { (1.0 / f64::consts::PI) * ((self.location - x) / self.scale).atan() + 0.5 } + /// Calculates the inverse cumulative distribution function for the + /// cauchy distribution at `x` + /// + /// # Formula + /// + /// ```text + /// x_0 + γ tan((x - 0.5) π) + /// ``` + /// + /// where `x_0` is the location and `γ` is the scale fn inverse_cdf(&self, x: f64) -> f64 { if !(0.0..=1.0).contains(&x) { panic!("x must be in [0, 1]"); @@ -476,8 +486,27 @@ mod tests { #[test] fn test_inverse_cdf() { - let icdf = |arg: f64| move |x: Cauchy| x.inverse_cdf(arg); - test_case(0.0, 1.0, -3.077683537175253, icdf(0.1)); + let func = |arg: f64| move |x: Cauchy| x.inverse_cdf(x.cdf(arg)); + test_almost(0.0, 0.1, -5.0, 1e-10, func(-5.0)); + test_almost(0.0, 0.1, -1.0, 1e-14, func(-1.0)); + test_case(0.0, 0.1, 0.0, func(0.0)); + test_almost(0.0, 0.1, 1.0, 1e-14, func(1.0)); + test_almost(0.0, 0.1, 5.0, 1e-10, func(5.0)); + test_almost(0.0, 1.0, -5.0, 1e-14, func(-5.0)); + test_almost(0.0, 1.0, -1.0, 1e-15, func(-1.0)); + test_case(0.0, 1.0, 0.0, func(0.0)); + test_almost(0.0, 1.0, 1.0, 1e-15, func(1.0)); + test_almost(0.0, 1.0, 5.0, 1e-14, func(5.0)); + test_almost(0.0, 10.0, -5.0, 1e-14, func(-5.0)); + test_almost(0.0, 10.0, -1.0, 1e-14, func(-1.0)); + test_case(0.0, 10.0, 0.0, func(0.0)); + test_almost(0.0, 10.0, 1.0, 1e-14, func(1.0)); + test_almost(0.0, 10.0, 5.0, 1e-14, func(5.0)); + test_case(-5.0, 100.0, -5.0, func(-5.0)); + test_almost(-5.0, 100.0, -1.0, 1e-10, func(-1.0)); + test_almost(-5.0, 100.0, 0.0, 1e-14, func(0.0)); + test_almost(-5.0, 100.0, 1.0, 1e-14, func(1.0)); + test_almost(-5.0, 100.0, 5.0, 1e-10, func(5.0)); } #[test] diff --git a/src/distribution/chi_squared.rs b/src/distribution/chi_squared.rs index dfef0918..911ae8c8 100644 --- a/src/distribution/chi_squared.rs +++ b/src/distribution/chi_squared.rs @@ -139,6 +139,17 @@ impl ContinuousCDF for ChiSquared { self.g.sf(x) } + /// Calculates the inverse cumulative distribution function for the + /// chi-squared distribution at `x` + /// + /// # Formula + /// + /// ```text + /// (1 / Γ(k / 2)) * γ(k / 2, x / 2) + /// ``` + /// + /// where `k` is the degrees of freedom, `Γ` is the gamma function, + /// and `γ` is the lower incomplete gamma function fn inverse_cdf(&self, p: f64) -> f64 { self.g.inverse_cdf(p) } diff --git a/src/distribution/fisher_snedecor.rs b/src/distribution/fisher_snedecor.rs index f543f4e5..cb38ed19 100644 --- a/src/distribution/fisher_snedecor.rs +++ b/src/distribution/fisher_snedecor.rs @@ -156,6 +156,18 @@ impl ContinuousCDF for FisherSnedecor { } } + /// Calculates the inverse cumulative distribution function for the + /// fisher-snedecor distribution at `x` + /// + /// # Formula + /// + /// ```text + /// I_((d1 * x) / (d1 * x + d2))(d1 / 2, d2 / 2) + /// ``` + /// + /// where `d1` is the first degree of freedom, `d2` is + /// the second degree of freedom, and `I` is the regularized incomplete + /// beta function fn inverse_cdf(&self, x: f64) -> f64 { if !(0.0..=1.0).contains(&x) { panic!("x must be in [0, 1]"); @@ -595,6 +607,23 @@ mod tests { test_almost(10.0, 1.0, 0.65910686769794, 1e-12, sf(1.0)); } + #[test] + fn test_inverse_cdf() { + let func = |arg: f64| move |x: FisherSnedecor| x.inverse_cdf(x.cdf(arg)); + test_almost(0.1, 0.1, 0.1, 1e-12, func(0.1)); + test_almost(1.0, 0.1, 0.1, 1e-12, func(0.1)); + test_almost(10.0, 0.1, 0.1, 1e-12, func(0.1)); + test_almost(0.1, 1.0, 0.1, 1e-12, func(0.1)); + test_almost(1.0, 1.0, 0.1, 1e-12, func(0.1)); + test_almost(10.0, 1.0, 0.1, 1e-12, func(0.1)); + test_almost(0.1, 0.1, 1.0, 1e-13, func(1.0)); + test_almost(1.0, 0.1, 1.0, 1e-12, func(1.0)); + test_almost(10.0, 0.1, 1.0, 1e-12, func(1.0)); + test_almost(0.1, 1.0, 1.0, 1e-12, func(1.0)); + test_almost(1.0, 1.0, 1.0, 1e-12, func(1.0)); + test_almost(10.0, 1.0, 1.0, 1e-12, func(1.0)); + } + #[test] fn test_sf_lower_bound() { let sf = |arg: f64| move |x: FisherSnedecor| x.sf(arg); diff --git a/src/distribution/pareto.rs b/src/distribution/pareto.rs index 9a638b54..5a22e63e 100644 --- a/src/distribution/pareto.rs +++ b/src/distribution/pareto.rs @@ -142,11 +142,21 @@ impl ContinuousCDF for Pareto { } } + /// Calculates the inverse cumulative distribution function for the Pareto + /// distribution at `x` + /// + /// # Formula + /// + /// ```text + /// x_m / (1 - x)^(1 / α) + /// ``` + /// + /// where `x_m` is the scale and `α` is the shape fn inverse_cdf(&self, p: f64) -> f64 { if !(0.0..=1.0).contains(&p) { panic!("x must be in [0, 1]"); } else { - self.scale / (1.0 - p).powf(1.0 / self.shape) + self.scale * (1.0 - p).powf(-1.0 / self.shape) } } } @@ -532,12 +542,23 @@ mod tests { test_case(1.0, 1.0, 1.0, sf(1.0)); test_case(5.0, 5.0, 1.0, sf(2.0)); test_almost(7.0, 7.0, 0.08235429999999999, 1e-14, sf(10.0)); - test_almost(10.0, 10.0, 0.16150558288984573, 1e14, sf(12.0)); + test_almost(10.0, 10.0, 0.16150558288984573, 1e-14, sf(12.0)); test_case(5.0, 1.0, 0.5, sf(10.0)); test_almost(3.0, 10.0, 0.0009765625, 1e-14, sf(6.0)); test_case(1.0, 1.0, 0.0, sf(f64::INFINITY)); } + #[test] + fn test_inverse_cdf() { + let func = |arg: f64| move |x: Pareto| x.inverse_cdf(x.cdf(arg)); + test_case(0.1, 0.1, 0.1, func(0.1)); + test_case(1.0, 1.0, 1.0, func(1.0)); + test_case(7.0, 7.0, 10.0, func(10.0)); + test_case(10.0, 10.0, 12.0, func(12.0)); + test_case(5.0, 1.0, 10.0, func(10.0)); + test_case(3.0, 10.0, 6.0, func(6.0)); + } + #[test] fn test_continuous() { test::check_continuous_distribution(&try_create(1.0, 10.0), 1.0, 10.0); diff --git a/src/distribution/triangular.rs b/src/distribution/triangular.rs index 5fa89b70..6b917a7a 100644 --- a/src/distribution/triangular.rs +++ b/src/distribution/triangular.rs @@ -84,7 +84,7 @@ impl ContinuousCDF for Triangular { /// } if min < x <= mode { /// (x - min)^2 / ((max - min) * (mode - min)) /// } else if mode < x < max { - /// 1 - (max - min)^2 / ((max - min) * (max - mode)) + /// 1 - (max - x)^2 / ((max - min) * (max - mode)) /// } else { /// 1 /// } @@ -134,6 +134,34 @@ impl ContinuousCDF for Triangular { 0.0 } } + + /// Calculates the inverse cumulative distribution function for the triangular + /// distribution + /// at `x` + /// + /// # Formula + /// + /// ```text + /// if x < (mode - min) / (max - min) { + /// min + ((max - min) * (mode - min) * x)^(1 / 2) + /// } else { + /// max - (1 - (max - min) * (max - mode) * x)^(1 / 2) + /// } + /// ``` + fn inverse_cdf(&self, p: f64) -> f64 { + let a = self.min; + let b = self.max; + let c = self.mode; + if !(0.0..=1.0).contains(&p) { + panic!("x must be in [0, 1]"); + } + + if p < (c - a) / (b - a) { + a + ((c - a) * (b - a) * p).powf(0.5) + } else { + b - ((b - a) * (b - c) * (1.0 - p)).powf(0.5) + } + } } impl Min for Triangular { @@ -538,6 +566,20 @@ mod tests { test_case(0.0, 3.0, 1.5, 0.0, sf(5.0)); } + #[test] + fn test_inverse_cdf() { + let func = |arg: f64| move |x: Triangular| x.inverse_cdf(x.cdf(arg)); + test_almost(0.0, 1.0, 0.5, 0.25, 1e-15, func(0.25)); + test_almost(0.0, 1.0, 0.5, 0.5, 1e-15, func(0.5)); + test_almost(0.0, 1.0, 0.5, 0.75, 1e-15, func(0.75)); + test_almost(-5.0, 8.0, -3.5, -4.0, 1e-15, func(-4.0)); + test_almost(-5.0, 8.0, -3.5, -3.5, 1e-15, func(-3.5)); + test_almost(-5.0, 8.0, -3.5, 4.0, 1e-15, func(4.0)); + test_almost(-5.0, -3.0, -4.0, -4.5, 1e-15, func(-4.5)); + test_almost(-5.0, -3.0, -4.0, -4.0, 1e-15, func(-4.0)); + test_almost(-5.0, -3.0, -4.0, -3.5, 1e-15, func(-3.5)); + } + #[test] fn test_continuous() { test::check_continuous_distribution(&try_create(-5.0, 5.0, 0.0), -5.0, 5.0); diff --git a/src/distribution/weibull.rs b/src/distribution/weibull.rs index 036fc393..8d6369fc 100644 --- a/src/distribution/weibull.rs +++ b/src/distribution/weibull.rs @@ -140,12 +140,22 @@ impl ContinuousCDF for Weibull { } } + /// Calculates the inverse cumulative distribution function for the weibull + /// distribution at `x` + /// + /// # Formula + /// + /// ```text + /// λ (-ln(1 - x))^(1 / k) + /// ``` + /// + /// where `k` is the shape and `λ` is the scale fn inverse_cdf(&self, p: f64) -> f64 { if !(0.0..=1.0).contains(&p) { panic!("x must be in [0, 1]"); - } else { - ((-p).ln_1p() / self.scale_pow_shape_inv).powf(1.0 / self.shape) } + + (-((-p).ln_1p() / self.scale_pow_shape_inv)).powf(1.0 / self.shape) } } @@ -534,6 +544,21 @@ mod tests { test_case(10.0, 1.0, 0.0, sf(10.0)); } + #[test] + fn test_inverse_cdf() { + let func = |arg: f64| move |x: Weibull| x.inverse_cdf(x.cdf(arg)); + test_case(1.0, 0.1, 0.0, func(0.0)); + test_almost(1.0, 0.1, 1.0, 1e-13, func(1.0)); + test_case(1.0, 1.0, 0.0, func(0.0)); + test_case(1.0, 1.0, 1.0, func(1.0)); + test_almost(1.0, 1.0, 10.0, 1e-10, func(10.0)); + test_case(10.0, 10.0, 0.0, func(0.0)); + test_almost(10.0, 10.0, 1.0, 1e-5, func(1.0)); + test_almost(10.0, 10.0, 10.0, 1e-10, func(10.0)); + test_case(10.0, 1.0, 0.0, func(0.0)); + test_case(10.0, 1.0, 1.0, func(1.0)); + } + #[test] fn test_continuous() { test::check_continuous_distribution(&try_create(1.0, 0.2), 0.0, 10.0); From 438a2a2f13913abcf4d5cbfb6a7f7f945a3e7647 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 13 Aug 2024 15:23:06 +0000 Subject: [PATCH 115/185] DOC: update inverse CDF docstrings --- src/distribution/chi_squared.rs | 2 +- src/distribution/erlang.rs | 11 +++++++++++ src/distribution/fisher_snedecor.rs | 3 ++- src/distribution/triangular.rs | 6 +++--- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/distribution/chi_squared.rs b/src/distribution/chi_squared.rs index 911ae8c8..8ead2314 100644 --- a/src/distribution/chi_squared.rs +++ b/src/distribution/chi_squared.rs @@ -145,7 +145,7 @@ impl ContinuousCDF for ChiSquared { /// # Formula /// /// ```text - /// (1 / Γ(k / 2)) * γ(k / 2, x / 2) + /// γ^{-1}(k / 2, x * Γ(k / 2) / 2) /// ``` /// /// where `k` is the degrees of freedom, `Γ` is the gamma function, diff --git a/src/distribution/erlang.rs b/src/distribution/erlang.rs index a3b330db..6fb6a098 100644 --- a/src/distribution/erlang.rs +++ b/src/distribution/erlang.rs @@ -123,6 +123,17 @@ impl ContinuousCDF for Erlang { self.g.sf(x) } + /// Calculates the inverse cumulative distribution function for the erlang + /// distribution at `x` + /// + /// # Formula + /// + /// ```text + /// γ^{-1}(k, (k - 1)! x) / λ + /// ``` + /// + /// where `k` is the shape, `λ` is the rate, and `γ` is the upper + /// incomplete gamma function fn inverse_cdf(&self, p: f64) -> f64 { self.g.inverse_cdf(p) } diff --git a/src/distribution/fisher_snedecor.rs b/src/distribution/fisher_snedecor.rs index cb38ed19..59ca595a 100644 --- a/src/distribution/fisher_snedecor.rs +++ b/src/distribution/fisher_snedecor.rs @@ -162,7 +162,8 @@ impl ContinuousCDF for FisherSnedecor { /// # Formula /// /// ```text - /// I_((d1 * x) / (d1 * x + d2))(d1 / 2, d2 / 2) + /// z = I^{-1}_(x)(d1 / 2, d2 / 2) + /// d2 / (d1 (1 / z - 1)) /// ``` /// /// where `d1` is the first degree of freedom, `d2` is diff --git a/src/distribution/triangular.rs b/src/distribution/triangular.rs index 6b917a7a..8cb48dfa 100644 --- a/src/distribution/triangular.rs +++ b/src/distribution/triangular.rs @@ -145,7 +145,7 @@ impl ContinuousCDF for Triangular { /// if x < (mode - min) / (max - min) { /// min + ((max - min) * (mode - min) * x)^(1 / 2) /// } else { - /// max - (1 - (max - min) * (max - mode) * x)^(1 / 2) + /// max - ((max - min) * (max - mode) * (1 - x))^(1 / 2) /// } /// ``` fn inverse_cdf(&self, p: f64) -> f64 { @@ -157,9 +157,9 @@ impl ContinuousCDF for Triangular { } if p < (c - a) / (b - a) { - a + ((c - a) * (b - a) * p).powf(0.5) + a + ((c - a) * (b - a) * p).sqrt() } else { - b - ((b - a) * (b - c) * (1.0 - p)).powf(0.5) + b - ((b - a) * (b - c) * (1.0 - p)).sqrt() } } } From 039c54509c35444b7a785d7760ebf852240231c4 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 13 Aug 2024 15:42:46 +0000 Subject: [PATCH 116/185] FMT: fix formatting issue --- src/distribution/categorical.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index 2bf5dc47..31bccf8b 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -12,7 +12,7 @@ use std::f64; /// # Examples /// /// ``` -/// +/// /// use statrs::distribution::{Categorical, Discrete}; /// use statrs::statistics::Distribution; /// use statrs::prec; From 258a3f914a1af63ba098685747e73864c68c7e04 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 7 Aug 2024 12:06:54 +0200 Subject: [PATCH 117/185] Replace `is_zero(x)` with `x == 0.0` The function itself only seems to exist as a workaround to a clippy warn/err, but with max_ulps = 0 it is almost a normal equality operation --- src/distribution/beta.rs | 9 ++++----- src/distribution/binomial.rs | 9 ++++----- src/distribution/students_t.rs | 1 - src/distribution/weibull.rs | 5 ++--- src/function/beta.rs | 3 +-- src/function/erf.rs | 3 +-- src/function/gamma.rs | 3 +-- src/lib.rs | 6 ------ 8 files changed, 13 insertions(+), 26 deletions(-) diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index 6f204601..cfe1bb0f 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -1,6 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::{beta, gamma}; -use crate::is_zero; use crate::statistics::*; use crate::{Result, StatsError}; use rand::Rng; @@ -359,7 +358,7 @@ impl Continuous for Beta { 0.0 } } else if self.shape_b.is_infinite() { - if is_zero(x) { + if x == 0.0 { f64::INFINITY } else { 0.0 @@ -397,7 +396,7 @@ impl Continuous for Beta { f64::NEG_INFINITY } } else if self.shape_b.is_infinite() { - if is_zero(x) { + if x == 0.0 { f64::INFINITY } else { f64::NEG_INFINITY @@ -408,9 +407,9 @@ impl Continuous for Beta { let aa = gamma::ln_gamma(self.shape_a + self.shape_b) - gamma::ln_gamma(self.shape_a) - gamma::ln_gamma(self.shape_b); - let bb = if ulps_eq!(self.shape_a, 1.0) && is_zero(x) { + let bb = if ulps_eq!(self.shape_a, 1.0) && x == 0.0 { 0.0 - } else if is_zero(x) { + } else if x == 0.0 { f64::NEG_INFINITY } else { (self.shape_a - 1.0) * x.ln() diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index 2b56e6fc..07060bfc 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -1,6 +1,5 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::function::{beta, factorial}; -use crate::is_zero; use crate::statistics::*; use crate::{Result, StatsError}; use rand::Rng; @@ -207,7 +206,7 @@ impl Distribution for Binomial { /// (1 / 2) * ln (2 * π * e * n * p * (1 - p)) /// ``` fn entropy(&self) -> Option { - let entr = if is_zero(self.p) || ulps_eq!(self.p, 1.0) { + let entr = if self.p == 0.0 || ulps_eq!(self.p, 1.0) { 0.0 } else { (0..self.n + 1).fold(0.0, |acc, x| { @@ -252,7 +251,7 @@ impl Mode> for Binomial { /// floor((n + 1) * p) /// ``` fn mode(&self) -> Option { - let mode = if is_zero(self.p) { + let mode = if self.p == 0.0 { 0 } else if ulps_eq!(self.p, 1.0) { self.n @@ -275,7 +274,7 @@ impl Discrete for Binomial { fn pmf(&self, x: u64) -> f64 { if x > self.n { 0.0 - } else if is_zero(self.p) { + } else if self.p == 0.0 { if x == 0 { 1.0 } else { @@ -306,7 +305,7 @@ impl Discrete for Binomial { fn ln_pmf(&self, x: u64) -> f64 { if x > self.n { f64::NEG_INFINITY - } else if is_zero(self.p) { + } else if self.p == 0.0 { if x == 0 { 0.0 } else { diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index af7d7356..08d4afe9 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -1,6 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::{beta, gamma}; -use crate::is_zero; use crate::statistics::*; use crate::{Result, StatsError}; use rand::Rng; diff --git a/src/distribution/weibull.rs b/src/distribution/weibull.rs index 8d6369fc..1382998c 100644 --- a/src/distribution/weibull.rs +++ b/src/distribution/weibull.rs @@ -1,6 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; -use crate::is_zero; use crate::statistics::*; use crate::{consts, Result, StatsError}; use rand::Rng; @@ -311,7 +310,7 @@ impl Continuous for Weibull { fn pdf(&self, x: f64) -> f64 { if x < 0.0 { 0.0 - } else if is_zero(x) && ulps_eq!(self.shape, 1.0) { + } else if x == 0.0 && ulps_eq!(self.shape, 1.0) { 1.0 / self.scale } else if x.is_infinite() { 0.0 @@ -336,7 +335,7 @@ impl Continuous for Weibull { fn ln_pdf(&self, x: f64) -> f64 { if x < 0.0 { f64::NEG_INFINITY - } else if is_zero(x) && ulps_eq!(self.shape, 1.0) { + } else if x == 0.0 && ulps_eq!(self.shape, 1.0) { 0.0 - self.scale.ln() } else if x.is_infinite() { f64::NEG_INFINITY diff --git a/src/function/beta.rs b/src/function/beta.rs index 128406c7..794a8504 100644 --- a/src/function/beta.rs +++ b/src/function/beta.rs @@ -3,7 +3,6 @@ use crate::error::StatsError; use crate::function::gamma; -use crate::is_zero; use crate::prec; use crate::Result; use std::f64; @@ -118,7 +117,7 @@ pub fn checked_beta_reg(a: f64, b: f64, x: f64) -> Result { } else if !(0.0..=1.0).contains(&x) { Err(StatsError::ArgIntervalIncl("x", 0.0, 1.0)) } else { - let bt = if is_zero(x) || ulps_eq!(x, 1.0) { + let bt = if x == 0.0 || ulps_eq!(x, 1.0) { 0.0 } else { (gamma::ln_gamma(a + b) - gamma::ln_gamma(a) - gamma::ln_gamma(b) diff --git a/src/function/erf.rs b/src/function/erf.rs index e42fd584..92b68858 100644 --- a/src/function/erf.rs +++ b/src/function/erf.rs @@ -2,7 +2,6 @@ //! related functions use crate::function::evaluate; -use crate::is_zero; use std::f64; /// `erf` calculates the error function at `x`. @@ -13,7 +12,7 @@ pub fn erf(x: f64) -> f64 { 1.0 } else if x <= 0.0 && x.is_infinite() { -1.0 - } else if is_zero(x) { + } else if x == 0.0 { 0.0 } else { erf_impl(x, false) diff --git a/src/function/gamma.rs b/src/function/gamma.rs index 9d5124f9..ced63871 100644 --- a/src/function/gamma.rs +++ b/src/function/gamma.rs @@ -3,7 +3,6 @@ use crate::consts; use crate::error::StatsError; -use crate::is_zero; use crate::prec; use crate::Result; use std::f64; @@ -216,7 +215,7 @@ pub fn checked_gamma_ur(a: f64, x: f64) -> Result { qkm1 *= big_inv; } - if !is_zero(qk) { + if qk != 0.0 { let r = pk / qk; let t = ((ans - r) / r).abs(); ans = r; diff --git a/src/lib.rs b/src/lib.rs index a87f8ef9..df65187c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -77,12 +77,6 @@ pub mod stats_tests; mod error; -// function to silence clippy on the special case when comparing to zero. -#[inline(always)] -pub(crate) fn is_zero(x: f64) -> bool { - ulps_eq!(x, 0.0, max_ulps = 0) -} - pub use crate::error::StatsError; /// Result type for the statrs library package that returns From 4f362c63710fe6fced8fc441608527970c0ba2d7 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 7 Aug 2024 12:31:07 +0200 Subject: [PATCH 118/185] Remove unused imports --- src/distribution/beta.rs | 1 - src/distribution/dirac.rs | 4 ++-- src/distribution/dirichlet.rs | 6 ------ src/distribution/empirical.rs | 5 ++--- src/distribution/internal.rs | 2 +- src/distribution/mod.rs | 6 ++---- src/distribution/multivariate_normal.rs | 5 +---- src/lib.rs | 1 - src/statistics/iter_statistics.rs | 4 ---- src/statistics/mod.rs | 1 - src/statistics/traits.rs | 3 --- 11 files changed, 8 insertions(+), 30 deletions(-) diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index cfe1bb0f..c6df005b 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -431,7 +431,6 @@ impl Continuous for Beta { mod tests { use super::*; use super::super::internal::*; - use crate::statistics::*; use crate::testing_boiler; testing_boiler!(a: f64, b: f64; Beta); diff --git a/src/distribution/dirac.rs b/src/distribution/dirac.rs index 4fa6a390..c664d8db 100644 --- a/src/distribution/dirac.rs +++ b/src/distribution/dirac.rs @@ -1,4 +1,4 @@ -use crate::distribution::{Continuous, ContinuousCDF}; +use crate::distribution::ContinuousCDF; use crate::statistics::*; use crate::{Result, StatsError}; use rand::Rng; @@ -194,7 +194,7 @@ impl Mode> for Dirac { #[cfg(test)] mod tests { use crate::statistics::*; - use crate::distribution::{ContinuousCDF, Continuous, Dirac}; + use crate::distribution::{ContinuousCDF, Dirac}; fn try_create(v: f64) -> Dirac { let d = Dirac::new(v); diff --git a/src/distribution/dirichlet.rs b/src/distribution/dirichlet.rs index 55795a7a..cc96e102 100644 --- a/src/distribution/dirichlet.rs +++ b/src/distribution/dirichlet.rs @@ -4,9 +4,6 @@ use crate::statistics::*; use crate::{prec, Result, StatsError}; use nalgebra::DMatrix; use nalgebra::DVector; -use nalgebra::{ - base::allocator::Allocator, base::dimension::DimName, DefaultAllocator, Dim, DimMin, U1, -}; use rand::Rng; use std::f64; @@ -310,9 +307,6 @@ fn is_valid_alpha(a: &[f64]) -> bool { #[cfg(test)] mod tests { use super::*; - use nalgebra::{DVector}; - use crate::function::gamma; - use crate::statistics::*; use crate::distribution::{Continuous, Dirichlet}; #[test] diff --git a/src/distribution/empirical.rs b/src/distribution/empirical.rs index 9afd7022..104169aa 100644 --- a/src/distribution/empirical.rs +++ b/src/distribution/empirical.rs @@ -1,7 +1,6 @@ -use crate::distribution::{Continuous, ContinuousCDF, Uniform}; +use crate::distribution::{ContinuousCDF, Uniform}; use crate::statistics::*; -use crate::{Result, StatsError}; -use ::num_traits::float::Float; +use crate::Result; use core::cmp::Ordering; use rand::Rng; use std::collections::BTreeMap; diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index 15c24c10..b6e71136 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -1,4 +1,4 @@ -use num_traits::{Bounded, Float, Num}; +use num_traits::Num; /// Returns true if there are no elements in `x` in `arr` /// such that `x <= 0.0` or `x` is `f64::NAN` and `sum(arr) > 0.0`. diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index 56deb09a..dea1f9af 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -2,8 +2,8 @@ //! and provides //! concrete implementations for a variety of distributions. use super::statistics::{Max, Min}; -use ::num_traits::{Bounded, Float, Num}; -use num_traits::{NumAssign, NumAssignOps, NumAssignRef}; +use ::num_traits::{Float, Num}; +use num_traits::NumAssignOps; pub use self::bernoulli::Bernoulli; pub use self::beta::Beta; @@ -71,8 +71,6 @@ mod weibull; mod ziggurat; mod ziggurat_tables; -use crate::Result; - /// The `ContinuousCDF` trait is used to specify an interface for univariate /// distributions for which cdf float arguments are sensible. pub trait ContinuousCDF: Min + Max { diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 993bb9ba..9949d676 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -2,10 +2,7 @@ use crate::distribution::Continuous; use crate::distribution::Normal; use crate::statistics::{Max, MeanN, Min, Mode, VarianceN}; use crate::{Result, StatsError}; -use nalgebra::{ - base::allocator::Allocator, Cholesky, Const, DMatrix, DVector, DefaultAllocator, Dim, DimMin, - Dyn, OMatrix, OVector, -}; +use nalgebra::{Cholesky, Const, DMatrix, DVector, Dim, DimMin, Dyn, OMatrix, OVector}; use rand::Rng; use std::f64; use std::f64::consts::{E, PI}; diff --git a/src/lib.rs b/src/lib.rs index df65187c..56f9b162 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,7 +47,6 @@ #![crate_name = "statrs"] #![allow(clippy::excessive_precision)] #![allow(clippy::many_single_char_names)] -#![allow(unused_imports)] #![forbid(unsafe_code)] #[macro_use] diff --git a/src/statistics/iter_statistics.rs b/src/statistics/iter_statistics.rs index e568e531..97b7fd17 100644 --- a/src/statistics/iter_statistics.rs +++ b/src/statistics/iter_statistics.rs @@ -244,10 +244,6 @@ where #[cfg(test)] mod tests { use std::f64::consts; - use rand::rngs::StdRng; - use rand::SeedableRng; - use rand::distributions::Distribution; - use crate::distribution::Normal; use crate::statistics::Statistics; use crate::generate::{InfinitePeriodic, InfiniteSinusoidal}; diff --git a/src/statistics/mod.rs b/src/statistics/mod.rs index 6a2d3342..272091a4 100644 --- a/src/statistics/mod.rs +++ b/src/statistics/mod.rs @@ -1,6 +1,5 @@ //! Provides traits for statistical computation -pub use self::iter_statistics::*; pub use self::order_statistics::*; pub use self::slice_statistics::*; pub use self::statistics::*; diff --git a/src/statistics/traits.rs b/src/statistics/traits.rs index d264d719..e5211075 100644 --- a/src/statistics/traits.rs +++ b/src/statistics/traits.rs @@ -1,6 +1,3 @@ -use ::nalgebra::{ - base::allocator::Allocator, base::dimension::DimName, DefaultAllocator, Dim, DimMin, U1, -}; use ::num_traits::float::Float; const STEPS: usize = 1_000; From a239ebe25006c985115bd1d127804a42413111d4 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Fri, 2 Aug 2024 00:53:46 -0500 Subject: [PATCH 119/185] feat(multivariate)!: migrate Dirichlet API to generic dimensions --- src/distribution/dirichlet.rs | 257 ++++++++++++++++++++++------------ 1 file changed, 168 insertions(+), 89 deletions(-) diff --git a/src/distribution/dirichlet.rs b/src/distribution/dirichlet.rs index cc96e102..eb4d7958 100644 --- a/src/distribution/dirichlet.rs +++ b/src/distribution/dirichlet.rs @@ -2,8 +2,7 @@ use crate::distribution::Continuous; use crate::function::gamma; use crate::statistics::*; use crate::{prec, Result, StatsError}; -use nalgebra::DMatrix; -use nalgebra::DVector; +use nalgebra::{Const, Dim, Dyn, OMatrix, OVector}; use rand::Rng; use std::f64; @@ -24,10 +23,15 @@ use std::f64; /// assert_eq!(n.pdf(&DVector::from_vec(vec![0.33333, 0.33333, 0.33333])), 2.222155556222205); /// ``` #[derive(Clone, PartialEq, Debug)] -pub struct Dirichlet { - alpha: DVector, +pub struct Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + alpha: OVector, } -impl Dirichlet { + +impl Dirichlet { /// Constructs a new dirichlet distribution with the given /// concentration parameters (alpha) /// @@ -51,15 +55,8 @@ impl Dirichlet { /// result = Dirichlet::new(alpha_err); /// assert!(result.is_err()); /// ``` - pub fn new(alpha: Vec) -> Result { - if !is_valid_alpha(&alpha) { - Err(StatsError::BadParams) - } else { - // let vec = alpha.to_vec(); - Ok(Dirichlet { - alpha: DVector::from_vec(alpha.to_vec()), - }) - } + pub fn new(alpha: Vec) -> Result { + Self::new_from_nalgebra(alpha.into()) } /// Constructs a new dirichlet distribution with the given @@ -81,9 +78,30 @@ impl Dirichlet { /// result = Dirichlet::new_with_param(0.0, 1); /// assert!(result.is_err()); /// ``` - pub fn new_with_param(alpha: f64, n: usize) -> Result { + pub fn new_with_param(alpha: f64, n: usize) -> Result { Self::new(vec![alpha; n]) } +} + +impl Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + /// Constructs a new distribution with the given vector for `alpha` + /// Does not clone the vector it takes ownership of + /// + /// # Error + /// + /// Returns an error if vector has length less than 2 or if any element + /// of alpha is NOT finite positive + pub fn new_from_nalgebra(alpha: OVector) -> Result { + if !is_valid_alpha(alpha.as_slice()) { + Err(StatsError::BadParams) + } else { + Ok(Self { alpha }) + } + } /// Returns the concentration parameters of /// the dirichlet distribution as a slice @@ -97,12 +115,12 @@ impl Dirichlet { /// let n = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap(); /// assert_eq!(n.alpha(), &DVector::from_vec(vec![1.0, 2.0, 3.0])); /// ``` - pub fn alpha(&self) -> &DVector { + pub fn alpha(&self) -> &nalgebra::OVector { &self.alpha } fn alpha_sum(&self) -> f64 { - self.alpha.fold(0.0, |acc, x| acc + x) + self.alpha.sum() } /// Returns the entropy of the dirichlet distribution @@ -134,30 +152,40 @@ impl Dirichlet { } } -impl std::fmt::Display for Dirichlet { +impl std::fmt::Display for Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Dir({}, {})", self.alpha.len(), &self.alpha) } } -impl ::rand::distributions::Distribution> for Dirichlet { - fn sample(&self, rng: &mut R) -> DVector { +impl ::rand::distributions::Distribution> for Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + fn sample(&self, rng: &mut R) -> OVector { let mut sum = 0.0; - let mut samples: Vec<_> = self - .alpha - .iter() - .map(|&a| { + OVector::from_iterator_generic( + self.alpha.shape_generic().0, + Const::<1>, + self.alpha.iter().map(|&a| { let sample = super::gamma::sample_unchecked(rng, a, 1.0); sum += sample; sample - }) - .collect(); - for _ in samples.iter_mut().map(|x| *x /= sum) {} - DVector::from_vec(samples) + }), + ) } } -impl MeanN> for Dirichlet { +impl MeanN> for Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ /// Returns the means of the dirichlet distribution /// /// # Formula @@ -168,13 +196,18 @@ impl MeanN> for Dirichlet { /// /// for the `i`th element where `α_i` is the `i`th concentration parameter /// and `α_0` is the sum of all concentration parameters - fn mean(&self) -> Option> { + fn mean(&self) -> Option> { let sum = self.alpha_sum(); Some(self.alpha.map(|x| x / sum)) } } -impl VarianceN> for Dirichlet { +impl VarianceN> for Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ /// Returns the variances of the dirichlet distribution /// /// # Formula @@ -185,10 +218,10 @@ impl VarianceN> for Dirichlet { /// /// for the `i`th element where `α_i` is the `i`th concentration parameter /// and `α_0` is the sum of all concentration parameters - fn variance(&self) -> Option> { + fn variance(&self) -> Option> { let sum = self.alpha_sum(); let normalizing = sum * sum * (sum + 1.0); - let mut cov = DMatrix::from_diagonal(&self.alpha.map(|x| x * (sum - x) / normalizing)); + let mut cov = OMatrix::from_diagonal(&self.alpha.map(|x| x * (sum - x) / normalizing)); let mut offdiag = |x: usize, y: usize| { let elt = -self.alpha[x] * self.alpha[y] / normalizing; cov[(x, y)] = elt; @@ -203,7 +236,13 @@ impl VarianceN> for Dirichlet { } } -impl<'a> Continuous<&'a DVector, f64> for Dirichlet { +impl<'a, D> Continuous<&'a OVector, f64> for Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator, D>, +{ /// Calculates the probabiliy density function for the dirichlet /// distribution /// with given `x`'s corresponding to the concentration parameters for this @@ -234,7 +273,7 @@ impl<'a> Continuous<&'a DVector, f64> for Dirichlet { /// the `i`th concentration parameter, `Γ` is the gamma function, /// `Π` is the product from `1` to `K`, `Σ` is the sum from `1` to `K`, /// and `K` is the number of concentration parameters - fn pdf(&self, x: &DVector) -> f64 { + fn pdf(&self, x: &OVector) -> f64 { self.ln_pdf(x).exp() } @@ -268,7 +307,7 @@ impl<'a> Continuous<&'a DVector, f64> for Dirichlet { /// the `i`th concentration parameter, `Γ` is the gamma function, /// `Π` is the product from `1` to `K`, `Σ` is the sum from `1` to `K`, /// and `K` is the number of concentration parameters - fn ln_pdf(&self, x: &DVector) -> f64 { + fn ln_pdf(&self, x: &OVector) -> f64 { // TODO: would it be clearer here to just do a for loop instead // of using iterators? if self.alpha.len() != x.len() { @@ -300,55 +339,71 @@ impl<'a> Continuous<&'a DVector, f64> for Dirichlet { // determines if `a` is a valid alpha array // for the Dirichlet distribution fn is_valid_alpha(a: &[f64]) -> bool { - a.len() >= 2 && super::internal::is_valid_multinomial(a, false) + a.len() >= 2 && a.iter().all(|&a_i| a_i.is_finite() && a_i > 0.0) } #[rustfmt::skip] #[cfg(test)] mod tests { + use nalgebra::{dvector, vector, DimMin, OVector}; + use super::*; - use crate::distribution::{Continuous, Dirichlet}; + use crate::distribution::Continuous; - #[test] - fn test_is_valid_alpha() { - let invalid = [1.0]; - assert!(!is_valid_alpha(&invalid)); + fn try_create(alpha: OVector) -> Dirichlet + where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, + { + let mvn = Dirichlet::new_from_nalgebra(alpha); + assert!(mvn.is_ok()); + mvn.unwrap() } - fn try_create(alpha: &[f64]) -> Dirichlet + fn bad_create_case(alpha: OVector) + where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { - let n = Dirichlet::new(alpha.to_vec()); - assert!(n.is_ok()); - n.unwrap() + let dd = Dirichlet::new_from_nalgebra(alpha); + assert!(dd.is_err()); } - fn create_case(alpha: &[f64]) + fn test_almost(alpha: OVector, expected: f64, acc: f64, eval: F) + where + F: FnOnce(Dirichlet) -> f64, + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { - let n = try_create(alpha); - let a2 = n.alpha(); - for i in 0..alpha.len() { - assert_eq!(alpha[i], a2[i]); - } + let dd = try_create(alpha); + let x = eval(dd); + assert_almost_eq!(expected, x, acc); } - fn bad_create_case(alpha: &[f64]) - { - let n = Dirichlet::new(alpha.to_vec()); - assert!(n.is_err()); + #[test] + fn test_is_valid_alpha() { + assert!(!is_valid_alpha(&[1.0])); + assert!(!is_valid_alpha(&[1.0, f64::NAN])); + assert!(is_valid_alpha(&[1.0, 2.0])); + assert!(!is_valid_alpha(&[1.0, 0.0])); + assert!(!is_valid_alpha(&[1.0, f64::INFINITY])); + assert!(!is_valid_alpha(&[-1.0, 2.0])); } #[test] fn test_create() { - create_case(&[1.0, 2.0, 3.0, 4.0, 5.0]); - create_case(&[0.001, f64::INFINITY, 3756.0]); + try_create(vector![1.0, 2.0, 3.0, 4.0, 5.0]); + assert!(Dirichlet::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]).is_ok()); + // try_create(vector![0.001, f64::INFINITY, 3756.0]); // moved to bad case as this is degenerate } #[test] fn test_bad_create() { - bad_create_case(&[1.0]); - bad_create_case(&[1.0, 2.0, 0.0, 4.0, 5.0]); - bad_create_case(&[1.0, f64::NAN, 3.0, 4.0, 5.0]); - bad_create_case(&[0.0, 0.0, 0.0]); + bad_create_case(vector![1.0]); + bad_create_case(vector![1.0, 2.0, 0.0, 4.0, 5.0]); + bad_create_case(vector![1.0, f64::NAN, 3.0, 4.0, 5.0]); + bad_create_case(vector![0.0, 0.0, 0.0]); + bad_create_case(vector![0.001, f64::INFINITY, 3756.0]); // moved to bad case as this is degenerate } // #[test] @@ -386,70 +441,94 @@ mod tests { #[test] fn test_entropy() { - let mut n = try_create(&[0.1, 0.3, 0.5, 0.8]); - assert_eq!(n.entropy().unwrap(), -17.46469081094079); - - n = try_create(&[0.1, 0.2, 0.3, 0.4]); - assert_eq!(n.entropy().unwrap(), -21.53881433791513); - } - - macro_rules! dvec { - ($($x:expr),*) => (DVector::from_vec(vec![$($x),*])); + let entropy = |x: Dirichlet<_>| x.entropy().unwrap(); + test_almost( + vector![0.1, 0.3, 0.5, 0.8], + -17.46469081094079, + 1e-30, + entropy, + ); + test_almost( + vector![0.1, 0.2, 0.3, 0.4], + -21.53881433791513, + 1e-30, + entropy, + ); } #[test] fn test_pdf() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - assert_almost_eq!(n.pdf(&dvec![0.01, 0.03, 0.5, 0.46]), 18.77225681167061, 1e-12); - assert_almost_eq!(n.pdf(&dvec![0.1,0.2,0.3,0.4]), 0.8314656481199253, 1e-14); + let pdf = |arg| move |x: Dirichlet<_>| x.pdf(&arg); + test_almost( + vector![0.1, 0.3, 0.5, 0.8], + 18.77225681167061, + 1e-12, + pdf([0.01, 0.03, 0.5, 0.46].into()), + ); + test_almost( + vector![0.1, 0.3, 0.5, 0.8], + 0.8314656481199253, + 1e-14, + pdf([0.1, 0.2, 0.3, 0.4].into()), + ); } #[test] fn test_ln_pdf() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - assert_almost_eq!(n.ln_pdf(&dvec![0.01, 0.03, 0.5, 0.46]), 18.77225681167061f64.ln(), 1e-12); - assert_almost_eq!(n.ln_pdf(&dvec![0.1,0.2,0.3,0.4]), 0.8314656481199253f64.ln(), 1e-14); + let ln_pdf = |arg| move |x: Dirichlet<_>| x.ln_pdf(&arg); + test_almost( + vector![0.1, 0.3, 0.5, 0.8], + 18.77225681167061_f64.ln(), + 1e-12, + ln_pdf([0.01, 0.03, 0.5, 0.46].into()), + ); + test_almost( + vector![0.1, 0.3, 0.5, 0.8], + 0.8314656481199253_f64.ln(), + 1e-14, + ln_pdf([0.1, 0.2, 0.3, 0.4].into()), + ); } #[test] #[should_panic] fn test_pdf_bad_input_length() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - n.pdf(&dvec![0.5]); + let n = try_create(dvector![0.1, 0.3, 0.5, 0.8]); + n.pdf(&dvector![0.5]); } #[test] #[should_panic] fn test_pdf_bad_input_range() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - n.pdf(&dvec![1.5, 0.0, 0.0, 0.0]); + let n = try_create(vector![0.1, 0.3, 0.5, 0.8]); + n.pdf(&vector![1.5, 0.0, 0.0, 0.0]); } #[test] #[should_panic] fn test_pdf_bad_input_sum() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - n.pdf(&dvec![0.5, 0.25, 0.8, 0.9]); + let n = try_create(vector![0.1, 0.3, 0.5, 0.8]); + n.pdf(&vector![0.5, 0.25, 0.8, 0.9]); } #[test] #[should_panic] fn test_ln_pdf_bad_input_length() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - n.ln_pdf(&dvec![0.5]); + let n = try_create(dvector![0.1, 0.3, 0.5, 0.8]); + n.ln_pdf(&dvector![0.5]); } #[test] #[should_panic] fn test_ln_pdf_bad_input_range() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - n.ln_pdf(&dvec![1.5, 0.0, 0.0, 0.0]); + let n = try_create(vector![0.1, 0.3, 0.5, 0.8]); + n.ln_pdf(&vector![1.5, 0.0, 0.0, 0.0]); } #[test] #[should_panic] fn test_ln_pdf_bad_input_sum() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - n.ln_pdf(&dvec![0.5, 0.25, 0.8, 0.9]); + let n = try_create(vector![0.1, 0.3, 0.5, 0.8]); + n.ln_pdf(&vector![0.5, 0.25, 0.8, 0.9]); } } From d4c109772056963f4f1e357d19a0fb2ccdd43007 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Fri, 2 Aug 2024 12:39:09 -0500 Subject: [PATCH 120/185] test: expand tests for Dirichlet distribution --- src/distribution/dirichlet.rs | 81 ++++++++++++++++++++++++----------- 1 file changed, 56 insertions(+), 25 deletions(-) diff --git a/src/distribution/dirichlet.rs b/src/distribution/dirichlet.rs index eb4d7958..f058b46d 100644 --- a/src/distribution/dirichlet.rs +++ b/src/distribution/dirichlet.rs @@ -345,10 +345,15 @@ fn is_valid_alpha(a: &[f64]) -> bool { #[rustfmt::skip] #[cfg(test)] mod tests { - use nalgebra::{dvector, vector, DimMin, OVector}; + use std::fmt::{Debug, Display}; - use super::*; - use crate::distribution::Continuous; + use nalgebra::{dmatrix, dvector, vector, DimMin, OVector}; + + use super::is_valid_alpha; + use crate::{ + distribution::{Continuous, Dirichlet}, + statistics::{MeanN, VarianceN}, + }; fn try_create(alpha: OVector) -> Dirichlet where @@ -369,15 +374,16 @@ mod tests { assert!(dd.is_err()); } - fn test_almost(alpha: OVector, expected: f64, acc: f64, eval: F) + fn test_almost(alpha: OVector, expected: T, acc: f64, eval: F) where - F: FnOnce(Dirichlet) -> f64, + T: Debug + Display + approx::RelativeEq, + F: FnOnce(Dirichlet) -> T, D: DimMin, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { let dd = try_create(alpha); let x = eval(dd); - assert_almost_eq!(expected, x, acc); + assert_relative_eq!(expected, x, epsilon = acc); } #[test] @@ -406,26 +412,51 @@ mod tests { bad_create_case(vector![0.001, f64::INFINITY, 3756.0]); // moved to bad case as this is degenerate } - // #[test] - // fn test_mean() { - // let n = Dirichlet::new_with_param(0.3, 5).unwrap(); - // let res = n.mean(); - // for x in res { - // assert_eq!(x, 0.3 / 1.5); - // } - // } + #[test] + fn test_mean() { + let mean = |dd: Dirichlet<_>| dd.mean().unwrap(); - // #[test] - // fn test_variance() { - // let alpha = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; - // let sum = alpha.iter().fold(0.0, |acc, x| acc + x); - // let n = Dirichlet::new(&alpha).unwrap(); - // let res = n.variance(); - // for i in 1..11 { - // let f = i as f64; - // assert_almost_eq!(res[i-1], f * (sum - f) / (sum * sum * (sum + 1.0)), 1e-15); - // } - // } + test_almost(vec![0.5; 5].into(), vec![1.0 / 5.0; 5].into(), 1e-15, mean); + + test_almost( + dvector![0.1, 0.2, 0.3, 0.4], + dvector![0.1, 0.2, 0.3, 0.4], + 1e-15, + mean, + ); + + test_almost( + dvector![1.0, 2.0, 3.0, 4.0], + dvector![0.1, 0.2, 0.3, 0.4], + 1e-15, + mean, + ); + } + + #[test] + fn test_variance() { + let variance = |dd: Dirichlet<_>| dd.variance().unwrap(); + + test_almost( + dvector![1.0, 2.0], + dmatrix![0.055555555555555, -0.055555555555555; + -0.055555555555555, 0.055555555555555; + ], + 1e-15, + variance, + ); + + test_almost( + dvector![0.1, 0.2, 0.3, 0.4], + dmatrix![0.045, -0.010, -0.015, -0.020; + -0.010, 0.080, -0.030, -0.040; + -0.015, -0.030, 0.105, -0.060; + -0.020, -0.040, -0.060, 0.120; + ], + 1e-15, + variance, + ); + } // #[test] // fn test_std_dev() { From d0cc77330d2fc0fab2d08abcd09f386843e99218 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Fri, 2 Aug 2024 14:45:39 -0500 Subject: [PATCH 121/185] feat(multivariate)!: migrate Multinomial to generic dimension API --- src/distribution/internal.rs | 37 +++++++++++++++++ src/distribution/multinomial.rs | 73 ++++++++++++++++++++++++--------- 2 files changed, 91 insertions(+), 19 deletions(-) diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index b6e71136..b45b4bf7 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -1,5 +1,8 @@ +use nalgebra::{Dim, OVector}; use num_traits::Num; +use crate::StatsError; + /// Returns true if there are no elements in `x` in `arr` /// such that `x <= 0.0` or `x` is `f64::NAN` and `sum(arr) > 0.0`. /// IF `incl_zero` is true, it tests for `x < 0.0` instead of `x <= 0.0` @@ -14,6 +17,36 @@ pub fn is_valid_multinomial(arr: &[f64], incl_zero: bool) -> bool { sum != 0.0 } +pub fn check_multinomial(arr: &OVector, accept_zeroes: bool) -> crate::Result<()> +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + if arr.len() < 2 { + return Err(StatsError::BadParams); + } + let mut sum = 0.0; + for &x in arr.iter() { + if x.is_nan() { + return Err(StatsError::BadParams); + } else if x.is_infinite() { + return Err(StatsError::BadParams); + } else if x < 0.0 { + return Err(StatsError::BadParams); + } else if x == 0.0 && !accept_zeroes { + return Err(StatsError::BadParams); + } else { + sum += x; + } + } + + if sum != 0.0 { + Ok(()) + } else { + Err(StatsError::BadParams) + } +} + /// Implements univariate function bisection searching for criteria /// ```text /// smallest k such that f(k) >= z @@ -225,12 +258,16 @@ pub mod test { let invalid = [1.0, f64::NAN, 3.0]; assert!(!is_valid_multinomial(&invalid, true)); + assert!(check_multinomial(&invalid.to_vec().into(), true).is_err()); let invalid2 = [-2.0, 5.0, 1.0, 6.2]; assert!(!is_valid_multinomial(&invalid2, true)); + assert!(check_multinomial(&invalid2.to_vec().into(), true).is_err()); let invalid3 = [0.0, 0.0, 0.0]; assert!(!is_valid_multinomial(&invalid3, true)); + assert!(check_multinomial(&invalid3.to_vec().into(), true).is_err()); let valid = [5.2, 0.0, 1e-15, 1000000.12]; assert!(is_valid_multinomial(&valid, true)); + assert!(check_multinomial(&valid.to_vec().into(), true).is_ok()); } #[test] diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index dd17d2f0..b3e95cb2 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -1,8 +1,8 @@ use crate::distribution::Discrete; use crate::function::factorial; use crate::statistics::*; -use crate::{Result, StatsError}; -use ::nalgebra::{DMatrix, DVector}; +use crate::Result; +use nalgebra::{Const, DMatrix, DVector, Dim, Dyn, OVector}; use rand::Rng; /// Implements the @@ -22,12 +22,18 @@ use rand::Rng; /// assert_eq!(n.mean().unwrap(), DVector::from_vec(vec![1.5, 3.5])); /// ``` #[derive(Debug, Clone, PartialEq)] -pub struct Multinomial { - p: Vec, +pub struct Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + /// normalized probabilities for each species + p: OVector, + /// count of trials n: u64, } -impl Multinomial { +impl Multinomial { /// Constructs a new multinomial distribution with probabilities `p` /// and `n` number of trials. /// @@ -51,11 +57,20 @@ impl Multinomial { /// result = Multinomial::new(&[0.0, -1.0, 2.0], 3); /// assert!(result.is_err()); /// ``` - pub fn new(p: &[f64], n: u64) -> Result { - if !super::internal::is_valid_multinomial(p, true) { - Err(StatsError::BadParams) - } else { - Ok(Multinomial { p: p.to_vec(), n }) + pub fn new(p: &[f64], n: u64) -> Result { + Self::new_from_nalgebra(p.to_vec().into(), n) + } +} + +impl Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + pub fn new_from_nalgebra(mut p: OVector, n: u64) -> Result { + match super::internal::check_multinomial(&p, true) { + Err(e) => Err(e), + Ok(_) => Ok(Self { p, n }), } } @@ -70,7 +85,7 @@ impl Multinomial { /// let n = Multinomial::new(&[0.0, 1.0, 2.0], 3).unwrap(); /// assert_eq!(n.p(), [0.0, 1.0, 2.0]); /// ``` - pub fn p(&self) -> &[f64] { + pub fn p(&self) -> &OVector { &self.p } @@ -90,16 +105,24 @@ impl Multinomial { } } -impl std::fmt::Display for Multinomial { +impl std::fmt::Display for Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Multinom({:#?},{})", self.p, self.n) } } -impl ::rand::distributions::Distribution> for Multinomial { - fn sample(&self, rng: &mut R) -> Vec { - let p_cdf = super::categorical::prob_mass_to_cdf(self.p()); - let mut res = vec![0.0; self.p.len()]; +impl ::rand::distributions::Distribution> for Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + fn sample(&self, rng: &mut R) -> OVector { + let p_cdf = super::categorical::prob_mass_to_cdf(self.p().as_slice()); + let mut res = OVector::zeros_generic(self.p.shape_generic().0, Const::<1>); for _ in 0..self.n { let i = super::categorical::sample_unchecked(rng, &p_cdf); let el = res.get_mut(i as usize).unwrap(); @@ -109,7 +132,11 @@ impl ::rand::distributions::Distribution> for Multinomial { } } -impl MeanN> for Multinomial { +impl MeanN> for Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ /// Returns the mean of the multinomial distribution /// /// # Formula @@ -127,7 +154,11 @@ impl MeanN> for Multinomial { } } -impl VarianceN> for Multinomial { +impl VarianceN> for Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ /// Returns the variance of the multinomial distribution /// /// # Formula @@ -169,7 +200,11 @@ impl VarianceN> for Multinomial { // } // } -impl<'a> Discrete<&'a [u64], f64> for Multinomial { +impl<'a, D> Discrete<&'a [u64], f64> for Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ /// Calculates the probability mass function for the multinomial /// distribution /// with the given `x`'s corresponding to the probabilities for this From b6d866f5706cf59ac88ceb750b865b4a0eb469a9 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Fri, 2 Aug 2024 15:45:39 -0500 Subject: [PATCH 122/185] refactor: Multinomial stores normalized probability --- src/distribution/multinomial.rs | 61 +++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index b3e95cb2..610d746b 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -2,7 +2,7 @@ use crate::distribution::Discrete; use crate::function::factorial; use crate::statistics::*; use crate::Result; -use nalgebra::{Const, DMatrix, DVector, Dim, Dyn, OVector}; +use nalgebra::{Const, DVector, Dim, Dyn, OMatrix, OVector}; use rand::Rng; /// Implements the @@ -51,14 +51,14 @@ impl Multinomial { /// ``` /// use statrs::distribution::Multinomial; /// - /// let mut result = Multinomial::new(&[0.0, 1.0, 2.0], 3); + /// let mut result = Multinomial::new(vec![0.0, 1.0, 2.0], 3); /// assert!(result.is_ok()); /// - /// result = Multinomial::new(&[0.0, -1.0, 2.0], 3); + /// result = Multinomial::new(vec![0.0, -1.0, 2.0], 3); /// assert!(result.is_err()); /// ``` - pub fn new(p: &[f64], n: u64) -> Result { - Self::new_from_nalgebra(p.to_vec().into(), n) + pub fn new(p: Vec, n: u64) -> Result { + Self::new_from_nalgebra(p.into(), n) } } @@ -70,7 +70,10 @@ where pub fn new_from_nalgebra(mut p: OVector, n: u64) -> Result { match super::internal::check_multinomial(&p, true) { Err(e) => Err(e), - Ok(_) => Ok(Self { p, n }), + Ok(_) => { + p.unscale_mut(p.lp_norm(1)); + Ok(Self { p, n }) + } } } @@ -154,10 +157,11 @@ where } } -impl VarianceN> for Multinomial +impl VarianceN> for Multinomial where D: Dim, - nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Returns the variance of the multinomial distribution /// @@ -169,13 +173,21 @@ where /// /// where `n` is the number of trials, `p_i` is the `i`th probability, /// and `k` is the total number of probabilities - fn variance(&self) -> Option> { - let cov: Vec<_> = self - .p - .iter() - .map(|x| x * self.n as f64 * (1.0 - x)) - .collect(); - Some(DMatrix::from_diagonal(&DVector::from_vec(cov))) + fn variance(&self) -> Option> { + let mut cov = OMatrix::from_diagonal(&self.p.map(|x| x * (1.0 - x))); + let mut offdiag = |x: usize, y: usize| { + let elt = -self.p[x] * self.p[y]; + // cov[(x, y)] = elt; + cov[(y, x)] = elt; + }; + + for i in 0..self.p.len() { + for j in 0..i { + offdiag(i, j); + } + } + cov.fill_lower_triangle_with_upper_triangle(); + Some(cov.scale(self.n as f64)) } } @@ -200,10 +212,11 @@ where // } // } -impl<'a, D> Discrete<&'a [u64], f64> for Multinomial +impl<'a, D> Discrete<&'a OVector, f64> for Multinomial where D: Dim, - nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Calculates the probability mass function for the multinomial /// distribution @@ -212,8 +225,7 @@ where /// /// # Panics /// - /// If the elements in `x` do not sum to `n` or if the length of `x` is not - /// equivalent to the length of `p` + /// If length of `x` is not equal to length of `p` /// /// # Formula /// @@ -224,14 +236,14 @@ where /// where `n` is the number of trials, `p_i` is the `i`th probability, /// `x_i` is the `i`th `x` value, and `k` is the total number of /// probabilities - fn pmf(&self, x: &[u64]) -> f64 { + fn pmf(&self, x: &OVector) -> f64 { if self.p.len() != x.len() { panic!("Expected x and p to have equal lengths."); } if x.iter().sum::() != self.n { return 0.0; } - let coeff = factorial::multinomial(self.n, x); + let coeff = factorial::multinomial(self.n, x.as_slice()); let val = coeff * self .p @@ -248,8 +260,7 @@ where /// /// # Panics /// - /// If the elements in `x` do not sum to `n` or if the length of `x` is not - /// equivalent to the length of `p` + /// If length of `x` is not equal to length of `p` /// /// # Formula /// @@ -260,14 +271,14 @@ where /// where `n` is the number of trials, `p_i` is the `i`th probability, /// `x_i` is the `i`th `x` value, and `k` is the total number of /// probabilities - fn ln_pmf(&self, x: &[u64]) -> f64 { + fn ln_pmf(&self, x: &OVector) -> f64 { if self.p.len() != x.len() { panic!("Expected x and p to have equal lengths."); } if x.iter().sum::() != self.n { return f64::NEG_INFINITY; } - let coeff = factorial::multinomial(self.n, x).ln(); + let coeff = factorial::multinomial(self.n, x.as_slice()).ln(); let val = coeff + self .p From 9ba0d0213f49e7fc9bc0c538c9c6eab838b08ff0 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Fri, 2 Aug 2024 15:48:21 -0500 Subject: [PATCH 123/185] test: reintroduce tests for Multinomial --- src/distribution/multinomial.rs | 316 +++++++++++++++++++------------- 1 file changed, 192 insertions(+), 124 deletions(-) diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index 610d746b..fbd8d594 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -290,142 +290,210 @@ where } } -// TODO: fix tests -// #[rustfmt::skip] -// #[cfg(test)] -// mod tests { -// use crate::statistics::*; -// use crate::distribution::{Discrete, Multinomial}; +#[rustfmt::skip] +#[cfg(test)] +mod tests { + use crate::{ + distribution::{Discrete, Multinomial}, + statistics::{MeanN, VarianceN}, + }; + use nalgebra::{dmatrix, dvector, vector, DimMin, Dyn, OVector}; + use std::fmt::{Debug, Display}; -// fn try_create(p: &[f64], n: u64) -> Multinomial { -// let dist = Multinomial::new(p, n); -// assert!(dist.is_ok()); -// dist.unwrap() -// } - -// fn create_case(p: &[f64], n: u64) { -// let dist = try_create(p, n); -// assert_eq!(dist.p(), p); -// assert_eq!(dist.n(), n); -// } - -// fn bad_create_case(p: &[f64], n: u64) { -// let dist = Multinomial::new(p, n); -// assert!(dist.is_err()); -// } + fn try_create(p: OVector, n: u64) -> Multinomial + where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, + { + let mvn = Multinomial::new_from_nalgebra(p, n); + assert!(mvn.is_ok()); + mvn.unwrap() + } -// fn test_case(p: &[f64], n: u64, expected: &[f64], eval: F) -// where F: Fn(Multinomial) -> Vec -// { -// let dist = try_create(p, n); -// let x = eval(dist); -// assert_eq!(*expected, *x); -// } + fn bad_create_case(p: OVector, n: u64) -> crate::StatsError + where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, + { + let dd = Multinomial::new_from_nalgebra(p, n); + assert!(dd.is_err()); + dd.unwrap_err() + } -// fn test_almost(p: &[f64], n: u64, expected: &[f64], acc: f64, eval: F) -// where F: Fn(Multinomial) -> Vec -// { -// let dist = try_create(p, n); -// let x = eval(dist); -// assert_eq!(expected.len(), x.len()); -// for i in 0..expected.len() { -// assert_almost_eq!(expected[i], x[i], acc); -// } -// } + fn test_almost(p: OVector, n: u64, expected: T, acc: f64, eval: F) + where + T: Debug + Display + approx::RelativeEq, + F: FnOnce(Multinomial) -> T, + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, + { + let dd = try_create(p, n); + let x = eval(dd); + assert_relative_eq!(expected, x, epsilon = acc); + } -// fn test_almost_sr(p: &[f64], n: u64, expected: f64, acc:f64, eval: F) -// where F: Fn(Multinomial) -> f64 -// { -// let dist = try_create(p, n); -// let x = eval(dist); -// assert_almost_eq!(expected, x, acc); -// } + #[test] + fn test_create() { + assert_relative_eq!( + *try_create(vector![1.0, 1.0, 1.0], 4).p(), + vector![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0] + ); + try_create(dvector![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], 4); + } -// #[test] -// fn test_create() { -// create_case(&[1.0, 1.0, 1.0], 4); -// create_case(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], 4); -// } + #[test] + fn test_bad_create() { + assert_eq!( + bad_create_case(vector![-1.0, 2.0], 4), + crate::StatsError::BadParams + ); -// #[test] -// fn test_bad_create() { -// bad_create_case(&[-1.0, 1.0], 4); -// bad_create_case(&[0.0, 0.0], 4); -// } + assert_eq!( + bad_create_case(vector![0.0, 0.0], 4), + crate::StatsError::BadParams + ); + assert_eq!( + bad_create_case(vector![1.0, f64::NAN], 4), + crate::StatsError::BadParams + ); + } -// #[test] -// fn test_mean() { -// let mean = |x: Multinomial| x.mean().unwrap(); -// test_case(&[0.3, 0.7], 5, &[1.5, 3.5], mean); -// test_case(&[0.1, 0.3, 0.6], 10, &[1.0, 3.0, 6.0], mean); -// test_case(&[0.15, 0.35, 0.3, 0.2], 20, &[3.0, 7.0, 6.0, 4.0], mean); -// } + #[test] + fn test_mean() { + let mean = |x: Multinomial<_>| x.mean().unwrap(); + test_almost(dvector![0.3, 0.7], 5, dvector![1.5, 3.5], 1e-12, mean); + test_almost( + dvector![0.1, 0.3, 0.6], + 10, + dvector![1.0, 3.0, 6.0], + 1e-12, + mean, + ); + test_almost( + dvector![1.0, 3.0, 6.0], + 10, + dvector![1.0, 3.0, 6.0], + 1e-12, + mean, + ); + test_almost( + dvector![0.15, 0.35, 0.3, 0.2], + 20, + dvector![3.0, 7.0, 6.0, 4.0], + 1e-12, + mean, + ); + } -// #[test] -// fn test_variance() { -// let variance = |x: Multinomial| x.variance().unwrap(); -// test_almost(&[0.3, 0.7], 5, &[1.05, 1.05], 1e-15, variance); -// test_almost(&[0.1, 0.3, 0.6], 10, &[0.9, 2.1, 2.4], 1e-15, variance); -// test_almost(&[0.15, 0.35, 0.3, 0.2], 20, &[2.55, 4.55, 4.2, 3.2], 1e-15, variance); -// } + #[test] + fn test_variance() { + let variance = |x: Multinomial<_>| x.variance().unwrap(); + test_almost( + dvector![0.3, 0.7], + 5, + dmatrix![1.05, -1.05; + -1.05, 1.05], + 1e-15, + variance, + ); + test_almost( + dvector![0.1, 0.3, 0.6], + 10, + dmatrix![0.9, -0.3, -0.6; + -0.3, 2.1, -1.8; + -0.6, -1.8, 2.4; + ], + 1e-15, + variance, + ); + test_almost( + dvector![0.15, 0.35, 0.3, 0.2], + 20, + dmatrix![2.55, -1.05, -0.90, -0.60; + -1.05, 4.55, -2.10, -1.40; + -0.90, -2.10, 4.20, -1.20; + -0.60, -1.40, -1.20, 3.20; + ], + 1e-15, + variance, + ); + } -// // #[test] -// // fn test_skewness() { -// // let skewness = |x: Multinomial| x.skewness().unwrap(); -// // test_almost(&[0.3, 0.7], 5, &[0.390360029179413, -0.390360029179413], 1e-15, skewness); -// // test_almost(&[0.1, 0.3, 0.6], 10, &[0.843274042711568, 0.276026223736942, -0.12909944487358], 1e-15, skewness); -// // test_almost(&[0.15, 0.35, 0.3, 0.2], 20, &[0.438357003759605, 0.140642169281549, 0.195180014589707, 0.335410196624968], 1e-15, skewness); -// // } + // // #[test] + // // fn test_skewness() { + // // let skewness = |x: Multinomial| x.skewness().unwrap(); + // // test_almost(&[0.3, 0.7], 5, &[0.390360029179413, -0.390360029179413], 1e-15, skewness); + // // test_almost(&[0.1, 0.3, 0.6], 10, &[0.843274042711568, 0.276026223736942, -0.12909944487358], 1e-15, skewness); + // // test_almost(&[0.15, 0.35, 0.3, 0.2], 20, &[0.438357003759605, 0.140642169281549, 0.195180014589707, 0.335410196624968], 1e-15, skewness); + // // } -// #[test] -// fn test_pmf() { -// let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg); -// test_almost_sr(&[0.3, 0.7], 10, 0.121060821, 1e-15, pmf(&[1, 9])); -// test_almost_sr(&[0.1, 0.3, 0.6], 10, 0.105815808, 1e-15, pmf(&[1, 3, 6])); -// test_almost_sr(&[0.15, 0.35, 0.3, 0.2], 10, 0.000145152, 1e-15, pmf(&[1, 1, 1, 7])); -// } + #[test] + fn test_pmf() { + let pmf = |arg: OVector| move |x: Multinomial<_>| x.pmf(&arg); + test_almost( + dvector![0.3, 0.7], + 10, + 0.121060821, + 1e-15, + pmf(dvector![1, 9]), + ); + test_almost( + dvector![0.1, 0.3, 0.6], + 10, + 0.105815808, + 1e-15, + pmf(dvector![1, 3, 6]), + ); + test_almost( + dvector![0.15, 0.35, 0.3, 0.2], + 10, + 0.000145152, + 1e-15, + pmf(dvector![1, 1, 1, 7]), + ); + } -// #[test] -// #[should_panic] -// fn test_pmf_x_wrong_length() { -// let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg); -// let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); -// n.pmf(&[1]); -// } + // #[test] + // #[should_panic] + // fn test_pmf_x_wrong_length() { + // let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg); + // let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); + // n.pmf(&[1]); + // } -// #[test] -// #[should_panic] -// fn test_pmf_x_wrong_sum() { -// let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg); -// let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); -// n.pmf(&[1, 3]); -// } + // #[test] + // #[should_panic] + // fn test_pmf_x_wrong_sum() { + // let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg); + // let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); + // n.pmf(&[1, 3]); + // } -// #[test] -// fn test_ln_pmf() { -// let large_p = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; -// let n = Multinomial::new(large_p, 45).unwrap(); -// let x = &[1, 2, 3, 4, 5, 6, 7, 8, 9]; -// assert_almost_eq!(n.pmf(x).ln(), n.ln_pmf(x), 1e-13); -// let n2 = Multinomial::new(large_p, 18).unwrap(); -// let x2 = &[1, 1, 1, 2, 2, 2, 3, 3, 3]; -// assert_almost_eq!(n2.pmf(x2).ln(), n2.ln_pmf(x2), 1e-13); -// let n3 = Multinomial::new(large_p, 51).unwrap(); -// let x3 = &[5, 6, 7, 8, 7, 6, 5, 4, 3]; -// assert_almost_eq!(n3.pmf(x3).ln(), n3.ln_pmf(x3), 1e-13); -// } + // #[test] + // fn test_ln_pmf() { + // let large_p = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; + // let n = Multinomial::new(large_p, 45).unwrap(); + // let x = &[1, 2, 3, 4, 5, 6, 7, 8, 9]; + // assert_almost_eq!(n.pmf(x).ln(), n.ln_pmf(x), 1e-13); + // let n2 = Multinomial::new(large_p, 18).unwrap(); + // let x2 = &[1, 1, 1, 2, 2, 2, 3, 3, 3]; + // assert_almost_eq!(n2.pmf(x2).ln(), n2.ln_pmf(x2), 1e-13); + // let n3 = Multinomial::new(large_p, 51).unwrap(); + // let x3 = &[5, 6, 7, 8, 7, 6, 5, 4, 3]; + // assert_almost_eq!(n3.pmf(x3).ln(), n3.ln_pmf(x3), 1e-13); + // } -// #[test] -// #[should_panic] -// fn test_ln_pmf_x_wrong_length() { -// let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); -// n.ln_pmf(&[1]); -// } + // #[test] + // #[should_panic] + // fn test_ln_pmf_x_wrong_length() { + // let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); + // n.ln_pmf(&[1]); + // } -// #[test] -// #[should_panic] -// fn test_ln_pmf_x_wrong_sum() { -// let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); -// n.ln_pmf(&[1, 3]); -// } -// } + // #[test] + // #[should_panic] + // fn test_ln_pmf_x_wrong_sum() { + // let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); + // n.ln_pmf(&[1, 3]); + // } +} From 31c5b6070d8cea0118b7736bfc9c2326868d1e4f Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Fri, 2 Aug 2024 17:09:06 -0500 Subject: [PATCH 124/185] test(docs): update Multinomial doc tests --- src/distribution/multinomial.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index fbd8d594..dc402050 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -16,10 +16,10 @@ use rand::Rng; /// ``` /// use statrs::distribution::Multinomial; /// use statrs::statistics::MeanN; -/// use nalgebra::DVector; +/// use nalgebra::vector; /// -/// let n = Multinomial::new(&[0.3, 0.7], 5).unwrap(); -/// assert_eq!(n.mean().unwrap(), DVector::from_vec(vec![1.5, 3.5])); +/// let n = Multinomial::new_from_nalgebra(vector![0.3, 0.7], 5).unwrap(); +/// assert_eq!(n.mean().unwrap(), (vector![1.5, 3.5])); /// ``` #[derive(Debug, Clone, PartialEq)] pub struct Multinomial @@ -84,9 +84,10 @@ where /// /// ``` /// use statrs::distribution::Multinomial; + /// use nalgebra::dvector; /// - /// let n = Multinomial::new(&[0.0, 1.0, 2.0], 3).unwrap(); - /// assert_eq!(n.p(), [0.0, 1.0, 2.0]); + /// let n = Multinomial::new(vec![0.0, 1.0, 2.0], 3).unwrap(); + /// assert_eq!(*n.p(), dvector![0.0, 1.0/3.0, 2.0/3.0]); /// ``` pub fn p(&self) -> &OVector { &self.p @@ -100,7 +101,7 @@ where /// ``` /// use statrs::distribution::Multinomial; /// - /// let n = Multinomial::new(&[0.0, 1.0, 2.0], 3).unwrap(); + /// let n = Multinomial::new(vec![0.0, 1.0, 2.0], 3).unwrap(); /// assert_eq!(n.n(), 3); /// ``` pub fn n(&self) -> u64 { From 5c2258c224d68be4c7dacfce66aba324d0a9f14f Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sat, 3 Aug 2024 09:44:12 -0500 Subject: [PATCH 125/185] chore: allow clippy lint error api will be more specific later --- src/distribution/internal.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index b45b4bf7..0837dc40 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -27,6 +27,7 @@ where } let mut sum = 0.0; for &x in arr.iter() { + #[allow(clippy::if_same_then_else)] if x.is_nan() { return Err(StatsError::BadParams); } else if x.is_infinite() { From 3398bd4ac6bdd9ea81fc47035a3f3bf91fd634e5 Mon Sep 17 00:00:00 2001 From: Henry Jacobson Date: Sat, 24 Dec 2022 00:31:15 +0100 Subject: [PATCH 126/185] feat: implement multivariate students t distribution fix: Use ln_pdf_const instead of pdf_const feat: creation of multivariate normal distribution from same variables as multivariate students (for when freedom = inf) fix: use multivariate normal pdf when freedom = inf for multivariate student test: panic test for invalid pdf argument fix: tests in documentation fix: clearer function name in test fix: add documentation tests: 3d matrices test cases for pdf. Also improves documentation for multivariate t minorly. test: modify test case in multivariate_t the float chosen happens to approximate f64::LOG10_2 this leads to a linting error instead of supressing, just choosing a different value and testing against scipy also run rustfmt feat: update multivariate student to dimension generic API refactor: adds exposed Normal density for reuse in infinite DOF student distribution chore: run linting --- src/distribution/mod.rs | 2 + src/distribution/multivariate_normal.rs | 144 ++++-- src/distribution/multivariate_students_t.rs | 543 ++++++++++++++++++++ 3 files changed, 641 insertions(+), 48 deletions(-) create mode 100644 src/distribution/multivariate_students_t.rs diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index dea1f9af..b146f9d4 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -27,6 +27,7 @@ pub use self::laplace::Laplace; pub use self::log_normal::LogNormal; pub use self::multinomial::Multinomial; pub use self::multivariate_normal::MultivariateNormal; +pub use self::multivariate_students_t::MultivariateStudent; pub use self::negative_binomial::NegativeBinomial; pub use self::normal::Normal; pub use self::pareto::Pareto; @@ -60,6 +61,7 @@ mod laplace; mod log_normal; mod multinomial; mod multivariate_normal; +mod multivariate_students_t; mod negative_binomial; mod normal; mod pareto; diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 9949d676..61ccfa57 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -1,12 +1,84 @@ use crate::distribution::Continuous; use crate::distribution::Normal; use crate::statistics::{Max, MeanN, Min, Mode, VarianceN}; -use crate::{Result, StatsError}; +use crate::StatsError; use nalgebra::{Cholesky, Const, DMatrix, DVector, Dim, DimMin, Dyn, OMatrix, OVector}; use rand::Rng; use std::f64; use std::f64::consts::{E, PI}; +/// computes both the normalization and exponential argument in the normal distribution +/// # Errors +/// will error on dimension mismatch +pub(super) fn density_normalization_and_exponential( + mu: &OVector, + cov: &OMatrix, + precision: &OMatrix, + x: &OVector, +) -> std::result::Result<(f64, f64), StatsError> +where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator<(usize, usize), D>, +{ + Ok(( + density_distribution_pdf_const(mu, cov)?, + density_distribution_exponential(mu, precision, x)?, + )) +} + +/// computes the argument of the exponential term in the normal distribution +/// ```text +/// ``` +/// # Errors +/// will error on dimension mismatch +#[inline] +pub(super) fn density_distribution_exponential( + mu: &OVector, + precision: &OMatrix, + x: &OVector, +) -> std::result::Result +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ + if x.shape_generic().0 != precision.shape_generic().0 + || x.shape_generic().0 != mu.shape_generic().0 + || !precision.is_square() + { + return Err(StatsError::ContainersMustBeSameLength); + } + let dv = x - mu; + let exp_term: f64 = -0.5 * (precision * &dv).dot(&dv); + Ok(exp_term) + // TODO update to dimension mismatch error +} + +/// computes the argument of the normalization term in the normal distribution +/// # Errors +/// will error on dimension mismatch +#[inline] +pub(super) fn density_distribution_pdf_const( + mu: &OVector, + cov: &OMatrix, +) -> std::result::Result +where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator<(usize, usize), D>, +{ + if cov.shape_generic().0 != mu.shape_generic().0 || !cov.is_square() { + return Err(StatsError::ContainersMustBeSameLength); + } + let cov_det = cov.determinant(); + Ok(((2. * PI).powi(mu.nrows() as i32) * cov_det.abs()) + .recip() + .sqrt()) +} + /// Implements the [Multivariate Normal](https://en.wikipedia.org/wiki/Multivariate_normal_distribution) /// distribution using the "nalgebra" crate for matrix operations /// @@ -44,7 +116,7 @@ impl MultivariateNormal { /// /// Returns an error if the given covariance matrix is not /// symmetric or positive-definite - pub fn new(mean: Vec, cov: Vec) -> Result { + pub fn new(mean: Vec, cov: Vec) -> Result { let mean = DVector::from_vec(mean); let cov = DMatrix::from_vec(mean.len(), mean.len(), cov); MultivariateNormal::new_from_nalgebra(mean, cov) @@ -66,7 +138,10 @@ where /// /// Returns an error if the given covariance matrix is not /// symmetric or positive-definite - pub fn new_from_nalgebra(mean: OVector, cov: OMatrix) -> Result { + pub fn new_from_nalgebra( + mean: OVector, + cov: OMatrix, + ) -> Result { // Check that the provided covariance matrix is symmetric if cov.lower_triangle() != cov.upper_triangle().transpose() // Check that mean and covariance do not contain NaN @@ -77,10 +152,6 @@ where { return Err(StatsError::BadParams); } - let cov_det = cov.determinant(); - let pdf_const = ((2. * PI).powi(mean.nrows() as i32) * cov_det.abs()) - .recip() - .sqrt(); // Store the Cholesky decomposition of the covariance matrix // for sampling match Cholesky::new(cov.clone()) { @@ -88,11 +159,11 @@ where Some(cholesky_decomp) => { let precision = cholesky_decomp.inverse(); Ok(MultivariateNormal { + pdf_const: density_distribution_pdf_const(&mean, &cov).unwrap(), cov_chol_decomp: cholesky_decomp.unpack(), mu: mean, cov, precision, - pdf_const, }) } } @@ -231,9 +302,8 @@ where impl<'a, D> Continuous<&'a OVector, f64> for MultivariateNormal where D: Dim, - nalgebra::DefaultAllocator: nalgebra::allocator::Allocator - + nalgebra::allocator::Allocator - + nalgebra::allocator::Allocator, D>, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Calculates the probability density function for the multivariate /// normal distribution at `x` @@ -246,47 +316,18 @@ where /// /// where `μ` is the mean, `inv(Σ)` is the precision matrix, `det(Σ)` is the determinant /// of the covariance matrix, and `k` is the dimension of the distribution - fn pdf(&self, x: &'a OVector) -> f64 { - let dv = x - &self.mu; - let exp_term = -0.5 - * *(&dv.transpose() * &self.precision * &dv) - .get((0, 0)) - .unwrap(); - self.pdf_const * exp_term.exp() - } - - /// Calculates the log probability density function for the multivariate - /// normal distribution at `x`. Equivalent to pdf(x).ln(). - fn ln_pdf(&self, x: &'a OVector) -> f64 { - let dv = x - &self.mu; - let exp_term = -0.5 - * *(&dv.transpose() * &self.precision * &dv) - .get((0, 0)) - .unwrap(); - self.pdf_const.ln() + exp_term - } -} - -impl Continuous, f64> for MultivariateNormal { - /// Calculates the probability density function for the multivariate - /// normal distribution at `x` - /// - /// # Formula - /// - /// ```text - /// (2 * π) ^ (-k / 2) * det(Σ) ^ (1 / 2) * e ^ ( -(1 / 2) * transpose(x - μ) * inv(Σ) * (x - μ)) - /// ``` - /// - /// where `μ` is the mean, `inv(Σ)` is the precision matrix, `det(Σ)` is the determinant - /// of the covariance matrix, and `k` is the dimension of the distribution - fn pdf(&self, x: Vec) -> f64 { - self.pdf(&DVector::from(x)) + fn pdf(&self, x: &OVector) -> f64 { + self.pdf_const + * density_distribution_exponential(&self.mu, &self.precision, x) + .unwrap() + .exp() } /// Calculates the log probability density function for the multivariate /// normal distribution at `x`. Equivalent to pdf(x).ln(). - fn ln_pdf(&self, x: Vec) -> f64 { - self.pdf(&DVector::from(x)) + fn ln_pdf(&self, x: &OVector) -> f64 { + self.pdf_const.ln() + + density_distribution_exponential(&self.mu, &self.precision, x).unwrap() } } @@ -609,4 +650,11 @@ mod tests { ln_pdf(dvector![100., 100.]), ); } + + #[test] + #[should_panic] + fn test_pdf_mismatched_arg_size() { + let mvn = MultivariateNormal::new(vec![0., 0.], vec![1., 0., 0., 1.,]).unwrap(); + mvn.pdf(&vec![1.].into()); // x.size != mu.size + } } diff --git a/src/distribution/multivariate_students_t.rs b/src/distribution/multivariate_students_t.rs new file mode 100644 index 00000000..4a3cd38f --- /dev/null +++ b/src/distribution/multivariate_students_t.rs @@ -0,0 +1,543 @@ +use crate::distribution::Continuous; +use crate::distribution::{ChiSquared, Normal}; +use crate::function::gamma; +use crate::statistics::{Max, MeanN, Min, Mode, VarianceN}; +use crate::{Result, StatsError}; +use nalgebra::{Cholesky, Const, DMatrix, Dim, DimMin, Dyn, OMatrix, OVector}; +use rand::Rng; +use std::f64::consts::PI; + +/// Implements the [Multivariate Student's t-distribution](https://en.wikipedia.org/wiki/Multivariate_t-distribution) +/// distribution using the "nalgebra" crate for matrix operations. +/// +/// Assumes all the marginal distributions have the same degree of freedom, ν. +/// +/// # Examples +/// +/// ``` +/// use statrs::distribution::{MultivariateStudent, Continuous}; +/// use nalgebra::{DVector, DMatrix}; +/// use statrs::statistics::{MeanN, VarianceN}; +/// +/// let mvs = MultivariateStudent::new(vec![0., 0.], vec![1., 0., 0., 1.], 4.).unwrap(); +/// assert_eq!(mvs.mean().unwrap(), DVector::from_vec(vec![0., 0.])); +/// assert_eq!(mvs.variance().unwrap(), DMatrix::from_vec(2, 2, vec![2., 0., 0., 2.])); +/// assert_eq!(mvs.pdf(&DVector::from_vec(vec![1., 1.])), 0.04715702017537655); +/// ``` +#[derive(Debug, Clone, PartialEq)] +pub struct MultivariateStudent +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ + scale_chol_decomp: OMatrix, + location: OVector, + scale: OMatrix, + freedom: f64, + precision: OMatrix, + ln_pdf_const: f64, +} + +impl MultivariateStudent { + /// Constructs a new multivariate students t distribution with a location of `location`, + /// scale matrix `scale` and `freedom` degrees of freedom. + /// + /// # Errors + /// + /// Returns `StatsError::BadParams` if the scale matrix is not symmetric-positive + /// definite and `StatsError::ArgMustBePositive` if freedom is non-positive. + pub fn new(location: Vec, scale: Vec, freedom: f64) -> Result { + let dim = location.len(); + Self::new_from_nalgebra(location.into(), DMatrix::from_vec(dim, dim, scale), freedom) + } + + /// Returns the dimension of the distribution. + pub fn dim(&self) -> usize { + self.location.len() + } +} + +impl MultivariateStudent +where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator<(usize, usize), D>, +{ + pub fn new_from_nalgebra( + location: OVector, + scale: OMatrix, + freedom: f64, + ) -> Result { + let dim = location.len(); + + // Check that the provided scale matrix is symmetric + if scale.lower_triangle() != scale.upper_triangle().transpose() + // Check that mean and covariance do not contain NaN + || location.iter().any(|f| f.is_nan()) + || scale.iter().any(|f| f.is_nan()) + // Check that the dimensions match + || location.nrows() != scale.nrows() || scale.nrows() != scale.ncols() + // Check that the degrees of freedom is not NaN + || freedom.is_nan() + { + return Err(StatsError::BadParams); + } + // Check that degrees of freedom is positive + if freedom <= 0. { + return Err(StatsError::ArgMustBePositive( + "Degrees of freedom must be positive", + )); + } + + let scale_det = scale.determinant(); + let ln_pdf_const = gamma::ln_gamma(0.5 * (freedom + dim as f64)) + - gamma::ln_gamma(0.5 * freedom) + - 0.5 * (dim as f64) * (freedom * PI).ln() + - 0.5 * scale_det.ln(); + + match Cholesky::new(scale.clone()) { + None => Err(StatsError::BadParams), // Scale matrix is not positive definite + Some(cholesky_decomp) => { + let precision = cholesky_decomp.inverse(); + Ok(MultivariateStudent { + scale_chol_decomp: cholesky_decomp.unpack(), + location, + scale, + freedom, + precision, + ln_pdf_const, + }) + } + } + } + + /// Returns the cholesky decomposiiton matrix of the scale matrix. + /// + /// Returns A where Σ = AAᵀ. + pub fn scale_chol_decomp(&self) -> &OMatrix { + &self.scale_chol_decomp + } + + /// Returns the location of the distribution. + pub fn location(&self) -> &OVector { + &self.location + } + + /// Returns the scale matrix of the distribution. + pub fn scale(&self) -> &OMatrix { + &self.scale + } + + /// Returns the degrees of freedom of the distribution. + pub fn freedom(&self) -> f64 { + self.freedom + } + + /// Returns the inverse of the cholesky decomposition matrix. + pub fn precision(&self) -> &OMatrix { + &self.precision + } + + /// Returns the logarithmed constant part of the probability + /// distribution function. + pub fn ln_pdf_const(&self) -> f64 { + self.ln_pdf_const + } +} + +impl ::rand::distributions::Distribution> for MultivariateStudent +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ + /// Samples from the multivariate student distribution + /// + /// # Formula + /// + /// ```math + /// W ⋅ L ⋅ Z + μ + /// ``` + /// + /// where `W` has √(ν/Sν) distribution, Sν has Chi-squared + /// distribution with ν degrees of freedom, + /// `L` is the Cholesky decomposition of the scale matrix, + /// `Z` is a vector of normally distributed random variables, and + /// `μ` is the location vector + fn sample(&self, rng: &mut R) -> OVector { + let d = Normal::new(0., 1.).unwrap(); + let s = ChiSquared::new(self.freedom).unwrap(); + let w = (self.freedom / s.sample(rng)).sqrt(); + let (r, c) = self.location.shape_generic(); + let z = OVector::::from_distribution_generic(r, c, &d, rng); + (w * &self.scale_chol_decomp * z) + &self.location + } +} + +impl Min> for MultivariateStudent +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ + /// Returns the minimum value in the domain of the + /// multivariate normal distribution represented by a real vector + fn min(&self) -> OVector { + OMatrix::repeat_generic( + self.location.shape_generic().0, + Const::<1>, + f64::NEG_INFINITY, + ) + } +} + +impl Max> for MultivariateStudent +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ + /// Returns the minimum value in the domain of the + /// multivariate normal distribution represented by a real vector + fn max(&self) -> OVector { + OMatrix::repeat_generic(self.location.shape_generic().0, Const::<1>, f64::INFINITY) + } +} + +impl MeanN> for MultivariateStudent +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ + /// Returns the mean of the student distribution. + /// + /// # Remarks + /// + /// This is the same mean used to construct the distribution if + /// the degrees of freedom is larger than 1. + fn mean(&self) -> Option> { + if self.freedom > 1. { + Some(self.location.clone()) + } else { + None + } + } +} + +impl VarianceN> for MultivariateStudent +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ + /// Returns the covariance matrix of the multivariate student distribution. + /// + /// # Formula + /// + /// ```math + /// Σ ⋅ ν / (ν - 2) + /// ``` + /// + /// where `Σ` is the scale matrix and `ν` is the degrees of freedom. + /// Only defined if freedom is larger than 2. + fn variance(&self) -> Option> { + if self.freedom > 2. { + Some(self.scale.clone() * self.freedom / (self.freedom - 2.)) + } else { + None + } + } +} + +impl Mode> for MultivariateStudent +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ + /// Returns the mode of the multivariate student distribution. + /// + /// # Formula + /// + /// ```math + /// μ + /// ``` + /// + /// where `μ` is the location. + fn mode(&self) -> OVector { + self.location.clone() + } +} + +impl<'a, D> Continuous<&'a OVector, f64> for MultivariateStudent +where + D: Dim + DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator<(usize, usize), D>, +{ + /// Calculates the probability density function for the multivariate. + /// student distribution at `x`. + /// + /// # Formula + /// + /// ```math + /// [Γ(ν+p)/2] / [Γ(ν/2) ((ν * π)^p det(Σ))^(1 / 2)] * [1 + 1/ν (x - μ)ᵀ inv(Σ) (x - μ)]^(-(ν+p)/2) + /// ``` + /// + /// where `ν` is the degrees of freedom, `μ` is the mean, `Γ` + /// is the Gamma function, `inv(Σ)` + /// is the precision matrix, `det(Σ)` is the determinant + /// of the scale matrix, and `k` is the dimension of the distribution. + fn pdf(&self, x: &'a OVector) -> f64 { + if self.freedom.is_infinite() { + use super::multivariate_normal::density_normalization_and_exponential; + let (pdf_const, exp_arg) = density_normalization_and_exponential( + &self.location, + &self.scale, + &self.precision, + x, + ) + .unwrap(); + return pdf_const * exp_arg.exp(); + } + + let dv = x - &self.location; + let exp_arg: f64 = (&self.precision * &dv).dot(&dv); + let base_term = 1. + exp_arg / self.freedom; + self.ln_pdf_const.exp() * base_term.powf(-(self.freedom + self.location.len() as f64) / 2.) + } + + /// Calculates the log probability density function for the multivariate + /// student distribution at `x`. Equivalent to pdf(x).ln(). + fn ln_pdf(&self, x: &'a OVector) -> f64 { + if self.freedom.is_infinite() { + use super::multivariate_normal::density_normalization_and_exponential; + let (pdf_const, exp_arg) = density_normalization_and_exponential( + &self.location, + &self.scale, + &self.precision, + x, + ) + .unwrap(); + return pdf_const.ln() + exp_arg; + } + + let dv = x - &self.location; + let exp_arg: f64 = (&self.precision * &dv).dot(&dv); + let base_term = 1. + exp_arg / self.freedom; + self.ln_pdf_const - (self.freedom + self.location.len() as f64) / 2. * base_term.ln() + } +} + +#[rustfmt::skip] +#[cfg(test)] +mod tests { + use core::fmt::Debug; + + use approx::RelativeEq; + use nalgebra::{Dyn, DMatrix, DVector}; + + use crate::{ + distribution::{Continuous, MultivariateStudent, MultivariateNormal}, + statistics::{Max, MeanN, Min, Mode, VarianceN}, + }; + + fn try_create(location: Vec, scale: Vec, freedom: f64) -> MultivariateStudent + { + let mvs = MultivariateStudent::new(location, scale, freedom); + assert!(mvs.is_ok()); + mvs.unwrap() + } + + fn create_case(location: Vec, scale: Vec, freedom: f64) + { + let mvs = try_create(location.clone(), scale.clone(), freedom); + assert_eq!(DMatrix::from_vec(location.len(), location.len(), scale), mvs.scale); + assert_eq!(DVector::from_vec(location), mvs.location); + } + + fn bad_create_case(location: Vec, scale: Vec, freedom: f64) + { + let mvs = MultivariateStudent::new(location, scale, freedom); + assert!(mvs.is_err()); + } + + fn test_case(location: Vec, scale: Vec, freedom: f64, expected: T, eval: F) + where + T: Debug + PartialEq, + F: FnOnce(MultivariateStudent) -> T, + { + let mvs = try_create(location, scale, freedom); + let x = eval(mvs); + assert_eq!(expected, x); + } + + fn test_almost( + location: Vec, + scale: Vec, + freedom: f64, + expected: f64, + acc: f64, + eval: F, + ) where + F: FnOnce(MultivariateStudent) -> f64, + { + let mvs = try_create(location, scale, freedom); + let x = eval(mvs); + assert_almost_eq!(expected, x, acc); + } + + fn test_almost_multivariate_normal( + location: Vec, + scale: Vec, + freedom: f64, + acc: f64, + x: DVector, + eval_mvs: F1, + eval_mvn: F2, + ) where + F1: FnOnce(MultivariateStudent, DVector) -> f64, + F2: FnOnce(MultivariateNormal, DVector) -> f64, + { + let mvs = try_create(location.clone(), scale.clone(), freedom); + let mvn0 = MultivariateNormal::new(location, scale); + assert!(mvn0.is_ok()); + let mvn = mvn0.unwrap(); + let mvs_x = eval_mvs(mvs, x.clone()); + let mvn_x = eval_mvn(mvn, x.clone()); + assert!(mvs_x.relative_eq(&mvn_x, acc, acc), "mvn: {mvn_x} =/=\nmvs: {mvs_x}"); + // assert_relative_eq!(mvs_x, mvn_x, acc); + } + + + macro_rules! dvec { + ($($x:expr),*) => (DVector::from_vec(vec![$($x),*])); + } + + macro_rules! mat2 { + ($x11:expr, $x12:expr, $x21:expr, $x22:expr) => (DMatrix::from_vec(2,2,vec![$x11, $x12, $x21, $x22])); + } + + // macro_rules! mat3 { + // ($x11:expr, $x12:expr, $x13:expr, $x21:expr, $x22:expr, $x23:expr, $x31:expr, $x32:expr, $x33:expr) => (DMatrix::from_vec(3,3,vec![$x11, $x12, $x13, $x21, $x22, $x23, $x31, $x32, $x33])); + // } + + #[test] + fn test_create() { + create_case(vec![0., 0.], vec![1., 0., 0., 1.], 1.); + create_case(vec![10., 5.], vec![2., 1., 1., 2.], 3.); + create_case(vec![4., 5., 6.], vec![2., 1., 0., 1., 2., 1., 0., 1., 2.], 14.); + create_case(vec![0., f64::INFINITY], vec![1., 0., 0., 1.], f64::INFINITY); + create_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 0.1); + } + + #[test] + fn test_bad_create() { + // scale not symmetric. + bad_create_case(vec![0., 0.], vec![1., 1., 0., 1.], 1.); + // scale not positive-definite. + bad_create_case(vec![0., 0.], vec![1., 2., 2., 1.], 1.); + // NaN in location. + bad_create_case(vec![0., f64::NAN], vec![1., 0., 0., 1.], 1.); + // NaN in scale Matrix. + bad_create_case(vec![0., 0.], vec![1., 0., 0., f64::NAN], 1.); + // NaN in freedom. + bad_create_case(vec![0., 0.], vec![1., 0., 0., 1.], f64::NAN); + // Non-positive freedom. + bad_create_case(vec![0., 0.], vec![1., 0., 0., 1.], 0.); + bad_create_case(vec![0., 0.], vec![1., 0., 0., 1.], -1.); + } + + #[test] + fn test_variance() { + let variance = |x: MultivariateStudent| x.variance().unwrap(); + test_case(vec![0., 0.], vec![1., 0., 0., 1.], 3., 3. * mat2![1., 0., 0., 1.], variance); + test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 3., mat2![f64::INFINITY, 0., 0., f64::INFINITY], variance); + } + + // Variance is only defined for freedom > 2. + #[test] + fn test_bad_variance() { + let variance = |x: MultivariateStudent| x.variance(); + test_case(vec![0., 0.], vec![1., 0., 0., 1.], 2., None, variance); + } + + #[test] + fn test_mode() { + let mode = |x: MultivariateStudent| x.mode(); + test_case(vec![0., 0.], vec![1., 0., 0., 1.], 1., dvec![0., 0.], mode); + test_case(vec![f64::INFINITY, f64::INFINITY], vec![1., 0., 0., 1.], 1., dvec![f64::INFINITY, f64::INFINITY], mode); + } + + #[test] + fn test_mean() { + let mean = |x: MultivariateStudent| x.mean().unwrap(); + test_case(vec![0., 0.], vec![1., 0., 0., 1.], 2., dvec![0., 0.], mean); + test_case(vec![-1., 1., 3.], vec![1., 0., 0.5, 0., 2.0, 0., 0.5, 0., 3.0], 2., dvec![-1., 1., 3.], mean); + } + + // Mean is only defined if freedom > 1. + #[test] + fn test_bad_mean() { + let mean = |x: MultivariateStudent| x.mean(); + test_case(vec![0., 0.], vec![1., 0., 0., 1.], 1., None, mean); + } + + #[test] + fn test_min_max() { + let min = |x: MultivariateStudent| x.min(); + let max = |x: MultivariateStudent| x.max(); + test_case(vec![0., 0.], vec![1., 0., 0., 1.], 1., dvec![f64::NEG_INFINITY, f64::NEG_INFINITY], min); + test_case(vec![0., 0.], vec![1., 0., 0., 1.], 1., dvec![f64::INFINITY, f64::INFINITY], max); + test_case(vec![10., 1.], vec![1., 0., 0., 1.], 1., dvec![f64::NEG_INFINITY, f64::NEG_INFINITY], min); + test_case(vec![-3., 5.], vec![1., 0., 0., 1.], 1., dvec![f64::INFINITY, f64::INFINITY], max); + } + + #[test] + fn test_pdf() { + let pdf = |arg: DVector| move |x: MultivariateStudent| x.pdf(&arg); + test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 4., 0.047157020175376416, 1e-15, pdf(dvec![1., 1.])); + test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 4., 0.013972450422333741737457302178882, 1e-15, pdf(dvec![1., 2.])); + test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 2., 0.012992240252399619, 1e-17, pdf(dvec![1., 2.])); + test_almost(vec![2., 1.], vec![5., 0., 0., 1.], 2.5, 2.639780816598878e-5, 1e-19, pdf(dvec![1., 10.])); + test_almost(vec![-1., 0.], vec![2., 1., 1., 6.], 1.5, 6.438051574348526e-5, 1e-19, pdf(dvec![10., 10.])); + // These three are crossed checked against both python's scipy.multivariate_t.pdf and octave's mvtpdf. + test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8., 6.960998836915657e-16, 1e-30, pdf(dvec![0.9718, 0.1298, 0.8134])); + test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8., 7.369987979187023e-16, 1e-30, pdf(dvec![0.4922, 0.5522, 0.7185])); + test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8.,6.951631724511314e-16, 1e-30, pdf(dvec![0.3020, 0.1491, 0.5008])); + test_case(vec![-1., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 10., 0., pdf(dvec![10., 10.])); + } + + #[test] + fn test_ln_pdf() { + let ln_pdf = |arg: DVector| move |x: MultivariateStudent| x.ln_pdf(&arg); + test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 4., -3.0542723907338383, 1e-14, ln_pdf(dvec![1., 1.])); + test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 2., -4.3434030034000815, 1e-14, ln_pdf(dvec![1., 2.])); + test_almost(vec![2., 1.], vec![5., 0., 0., 1.], 2.5, -10.542229575274265, 1e-14, ln_pdf(dvec![1., 10.])); + test_almost(vec![-1., 0.], vec![2., 1., 1., 6.], 1.5, -9.650699521198622, 1e-14, ln_pdf(dvec![10., 10.])); + // test_case(vec![-1., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 10., f64::NEG_INFINITY, ln_pdf(dvec![10., 10.])); + } + + #[test] + fn test_pdf_freedom_large() { + let pdf_mvs = |mv: MultivariateStudent, arg: DVector| mv.pdf(&arg); + let pdf_mvn = |mv: MultivariateNormal, arg: DVector| mv.pdf(&arg); + test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e5, 1e-6, dvec![1., 1.], pdf_mvs, pdf_mvn); + test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e10, 1e-7, dvec![1., 1.], pdf_mvs, pdf_mvn); + test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn); + test_almost_multivariate_normal(vec![5., -1.,], vec![1., 0.99, 0.99, 1.], f64::INFINITY, 1e-300, dvec![5., 1.], pdf_mvs, pdf_mvn); + } + #[test] + fn test_ln_pdf_freedom_large() { + let pdf_mvs = |mv: MultivariateStudent, arg: DVector| mv.ln_pdf(&arg); + let pdf_mvn = |mv: MultivariateNormal, arg: DVector| mv.ln_pdf(&arg); + test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e5, 1e-5, dvec![1., 1.], pdf_mvs, pdf_mvn); + test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e10, 5e-6, dvec![1., 1.], pdf_mvs, pdf_mvn); + test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn); + test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0.99, 0.99, 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn); + } +} From aa276c8ef00b39937d56aba0f6e7ca950befe316 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Thu, 22 Aug 2024 11:17:50 -0500 Subject: [PATCH 127/185] test: ensure field immutable access is as expected --- src/distribution/multivariate_students_t.rs | 26 ++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/src/distribution/multivariate_students_t.rs b/src/distribution/multivariate_students_t.rs index 4a3cd38f..3758d328 100644 --- a/src/distribution/multivariate_students_t.rs +++ b/src/distribution/multivariate_students_t.rs @@ -339,7 +339,7 @@ mod tests { use core::fmt::Debug; use approx::RelativeEq; - use nalgebra::{Dyn, DMatrix, DVector}; + use nalgebra::{DMatrix, DVector, Dyn, OMatrix, OVector, U1, U2}; use crate::{ distribution::{Continuous, MultivariateStudent, MultivariateNormal}, @@ -540,4 +540,28 @@ mod tests { test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn); test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0.99, 0.99, 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn); } + + #[test] + fn test_immut_field_access() { + // init as Dyn + let mvs = MultivariateStudent::new(vec![1., 1.], vec![1., 0., 0., 1.], 2.) + .expect("hard coded valid construction"); + assert_eq!(mvs.freedom(), 2.); + assert_relative_eq!(mvs.ln_pdf_const(), std::f64::consts::TAU.recip().ln(), epsilon = 1e-15); + + // compare to static + assert_eq!(mvs.dim(), 2); + assert!(mvs.location().eq(&OVector::::new(1., 1.))); + assert!(mvs.scale().eq(&OMatrix::::identity())); + assert!(mvs.precision().eq(&OMatrix::::identity())); + assert!(mvs.scale_chol_decomp().eq(&OMatrix::::identity())); + + // compare to Dyn + assert_eq!(mvs.location(),&OVector::::from_element_generic(Dyn(2), U1, 1.)); + assert_eq!(mvs.scale(), &OMatrix::::identity(2, 2)); + assert_eq!(mvs.precision(), &OMatrix::::identity(2, 2)); + assert_eq!(mvs.scale_chol_decomp(), &OMatrix::::identity(2, 2)); + } + + } From 31490210579fb5f4599a78c8dc5b2537e84ce3a5 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 7 Aug 2024 17:05:02 +0200 Subject: [PATCH 128/185] Collect coverage from doctests too Signed-off-by: FreezyLemon --- .github/workflows/coverage.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index cb069b7e..1485889c 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -14,7 +14,7 @@ jobs: CARGO_TERM_COLOR: always steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable + - uses: dtolnay/rust-toolchain@nightly with: components: llvm-tools-preview @@ -28,7 +28,8 @@ jobs: - name: Collect coverage run: | cargo llvm-cov --no-report nextest - cargo llvm-cov report --lcov --output-path lcov.info + cargo llvm-cov --no-report --doc + cargo llvm-cov report --doctests --lcov --output-path lcov.info - name: Upload to codecov.io uses: codecov/codecov-action@v4 From cbf1495d75d45ae7b9cfd013c017a590c0546760 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 22:59:17 +0200 Subject: [PATCH 129/185] build: Add "nalgebra" feature gate The feature is enabled by default to avoid downstream breakage. Users can disable the default features to remove the nalgebra dependency at the cost of some distributions which rely on it. The distributions requiring nalgebra are (currently): - Dirichlet - Multinomial - MultivariateNormal - MultivariateStudent --- Cargo.toml | 17 +++++++++++++++-- src/distribution/internal.rs | 10 +++++++--- src/distribution/mod.rs | 8 ++++++++ 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ab92cb19..b6277e90 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,13 +17,26 @@ include = ["CHANGELOG.md", "LICENSE.md", "src/", "tests/"] name = "statrs" path = "src/lib.rs" +[features] +default = ["nalgebra"] +nalgebra = ["dep:nalgebra"] + [dependencies] rand = "0.8" -nalgebra = { version = "0.32", default-features = false, features = ["rand", "std"] } approx = "0.5.0" num-traits = "0.2.14" +[dependencies.nalgebra] +version = "0.32" +optional = true +default-features = false +features = ["rand", "std"] + [dev-dependencies] criterion = "0.3.3" anyhow = "1.0" -nalgebra = { version = "0.32", default-features = false, features = ["macros"] } + +[dev-dependencies.nalgebra] +version = "0.32" +default-features = false +features = ["macros"] diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index 0837dc40..72c86e28 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -1,8 +1,5 @@ -use nalgebra::{Dim, OVector}; use num_traits::Num; -use crate::StatsError; - /// Returns true if there are no elements in `x` in `arr` /// such that `x <= 0.0` or `x` is `f64::NAN` and `sum(arr) > 0.0`. /// IF `incl_zero` is true, it tests for `x < 0.0` instead of `x <= 0.0` @@ -17,11 +14,17 @@ pub fn is_valid_multinomial(arr: &[f64], incl_zero: bool) -> bool { sum != 0.0 } +#[cfg(feature = "nalgebra")] +use nalgebra::{Dim, OVector}; + +#[cfg(feature = "nalgebra")] pub fn check_multinomial(arr: &OVector, accept_zeroes: bool) -> crate::Result<()> where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { + use crate::StatsError; + if arr.len() < 2 { return Err(StatsError::BadParams); } @@ -253,6 +256,7 @@ pub mod test { check_sum_pmf_is_cdf(dist, x_max); } + #[cfg(feature = "nalgebra")] #[test] fn test_is_valid_multinomial() { use std::f64; diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index b146f9d4..6e43db8e 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -13,6 +13,7 @@ pub use self::cauchy::Cauchy; pub use self::chi::Chi; pub use self::chi_squared::ChiSquared; pub use self::dirac::Dirac; +#[cfg(feature = "nalgebra")] pub use self::dirichlet::Dirichlet; pub use self::discrete_uniform::DiscreteUniform; pub use self::empirical::Empirical; @@ -25,8 +26,11 @@ pub use self::hypergeometric::Hypergeometric; pub use self::inverse_gamma::InverseGamma; pub use self::laplace::Laplace; pub use self::log_normal::LogNormal; +#[cfg(feature = "nalgebra")] pub use self::multinomial::Multinomial; +#[cfg(feature = "nalgebra")] pub use self::multivariate_normal::MultivariateNormal; +#[cfg(feature = "nalgebra")] pub use self::multivariate_students_t::MultivariateStudent; pub use self::negative_binomial::NegativeBinomial; pub use self::normal::Normal; @@ -45,6 +49,7 @@ mod cauchy; mod chi; mod chi_squared; mod dirac; +#[cfg(feature = "nalgebra")] mod dirichlet; mod discrete_uniform; mod empirical; @@ -59,8 +64,11 @@ mod internal; mod inverse_gamma; mod laplace; mod log_normal; +#[cfg(feature = "nalgebra")] mod multinomial; +#[cfg(feature = "nalgebra")] mod multivariate_normal; +#[cfg(feature = "nalgebra")] mod multivariate_students_t; mod negative_binomial; mod normal; From f7f2960d2c32861e531ef1d35906ef43e0376162 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 23:25:08 +0200 Subject: [PATCH 130/185] ci: Run clippy with and without default features --- .github/workflows/test.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5c031691..2b39f41e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -23,6 +23,9 @@ jobs: - name: Run cargo clippy run: cargo clippy --all-targets + - name: Run cargo clippy without default features + run: cargo clippy --no-default-features --all-targets + fmt: runs-on: ubuntu-latest steps: @@ -48,5 +51,5 @@ jobs: uses: dtolnay/rust-toolchain@stable - name: Test default features - run: cargo test --all-targets + run: cargo test From c29f4f8831a195311e1afb4f600b190eb5fc4157 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Thu, 5 Sep 2024 12:47:48 +0200 Subject: [PATCH 131/185] chore: update criterion to 0.5 --- Cargo.toml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index b6277e90..15b06e4e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,10 @@ include = ["CHANGELOG.md", "LICENSE.md", "src/", "tests/"] name = "statrs" path = "src/lib.rs" +[[bench]] +name = "order_statistics" +harness = false + [features] default = ["nalgebra"] nalgebra = ["dep:nalgebra"] @@ -33,7 +37,7 @@ default-features = false features = ["rand", "std"] [dev-dependencies] -criterion = "0.3.3" +criterion = "0.5" anyhow = "1.0" [dev-dependencies.nalgebra] From 8935ee3f1e2f552ea7e7a1452d737c9936cc899c Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 16:02:37 +0200 Subject: [PATCH 132/185] Replace try_create with create_ok Example message: Gamma::new was expected to succeed, but failed for shape=10.0, rate=NaN with error: 'Bad distribution parameters' --- src/distribution/beta.rs | 6 +++--- src/distribution/gamma.rs | 6 +++--- src/distribution/internal.rs | 38 +++++++++++++++++++++++++++++----- src/distribution/students_t.rs | 14 ++++++------- 4 files changed, 46 insertions(+), 18 deletions(-) diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index c6df005b..8555ed29 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -439,7 +439,7 @@ mod tests { fn test_create() { let valid = [(1.0, 1.0), (9.0, 1.0), (5.0, 100.0), (1.0, f64::INFINITY), (f64::INFINITY, 1.0)]; for (a, b) in valid { - try_create(a, b); + create_ok(a, b); } } @@ -722,7 +722,7 @@ mod tests { #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(1.2, 3.4), 0.0, 1.0); - test::check_continuous_distribution(&try_create(4.5, 6.7), 0.0, 1.0); + test::check_continuous_distribution(&create_ok(1.2, 3.4), 0.0, 1.0); + test::check_continuous_distribution(&create_ok(4.5, 6.7), 0.0, 1.0); } } diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index 166ebb72..ef320529 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -423,7 +423,7 @@ mod tests { ]; for (s, r) in valid { - try_create(s, r); + create_ok(s, r); } } @@ -672,7 +672,7 @@ mod tests { #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(1.0, 0.5), 0.0, 20.0); - test::check_continuous_distribution(&try_create(9.0, 2.0), 0.0, 20.0); + test::check_continuous_distribution(&create_ok(1.0, 0.5), 0.0, 20.0); + test::check_continuous_distribution(&create_ok(9.0, 2.0), 0.0, 20.0); } } diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index 72c86e28..1b6d3791 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -101,10 +101,38 @@ pub mod test { #[macro_export] macro_rules! testing_boiler { ($($arg_name:ident: $arg_ty:ty),+; $dist:ty) => { - fn try_create($($arg_name: $arg_ty),+) -> $dist { - let n = <$dist>::new($($arg_name),+); - assert!(n.is_ok()); - n.unwrap() + fn make_param_text($($arg_name: $arg_ty),+) -> String { + // "" + let mut param_text = String::new(); + + // "shape=10.0, rate=NaN, " + $( + param_text.push_str( + &format!( + "{}={:?}, ", + stringify!($arg_name), + $arg_name, + ) + ); + )+ + + // "shape=10.0, rate=NaN" (removes trailing comma and whitespace) + param_text.pop(); + param_text.pop(); + + param_text + } + + fn create_ok($($arg_name: $arg_ty),+) -> $dist { + match <$dist>::new($($arg_name),+) { + Ok(d) => d, + Err(e) => panic!( + "{}::new was expected to succeed, but failed for {} with error: '{}'", + stringify!($dist), + make_param_text($($arg_name),+), + e + ) + } } fn bad_create_case($($arg_name: $arg_ty),+) { @@ -116,7 +144,7 @@ pub mod test { where F: Fn($dist) -> T, { - let n = try_create($($arg_name),+); + let n = create_ok($($arg_name),+); eval(n) } diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index 08d4afe9..f21e7b62 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -432,10 +432,10 @@ mod tests { #[test] fn test_create() { - try_create(0.0, 0.1, 1.0); - try_create(0.0, 1.0, 1.0); - try_create(-5.0, 1.0, 3.0); - try_create(10.0, 10.0, f64::INFINITY); + create_ok(0.0, 0.1, 1.0); + create_ok(0.0, 1.0, 1.0); + create_ok(-5.0, 1.0, 3.0); + create_ok(10.0, 10.0, f64::INFINITY); } // #[test] @@ -628,9 +628,9 @@ mod tests { #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(0.0, 1.0, 3.0), -30.0, 30.0); - test::check_continuous_distribution(&try_create(0.0, 1.0, 10.0), -10.0, 10.0); - test::check_continuous_distribution(&try_create(20.0, 0.5, 10.0), 10.0, 30.0); + test::check_continuous_distribution(&create_ok(0.0, 1.0, 3.0), -30.0, 30.0); + test::check_continuous_distribution(&create_ok(0.0, 1.0, 10.0), -10.0, 10.0); + test::check_continuous_distribution(&create_ok(20.0, 0.5, 10.0), 10.0, 30.0); } #[test] From fb99ffc61cca8c2b212385571add8ff98ef14ca2 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 16:21:56 +0200 Subject: [PATCH 133/185] Replace bad_create_case with create_err Example message: StudentsT::new was expected to fail, but succeeded for location=0.0, scale=10.0, freedom=1.0 with result: StudentsT { location: 0.0, scale: 10.0, freedom: 1.0 } --- src/distribution/beta.rs | 2 +- src/distribution/gamma.rs | 2 +- src/distribution/internal.rs | 13 ++++++++++--- src/distribution/students_t.rs | 10 +++++----- 4 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index 8555ed29..3ae05252 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -460,7 +460,7 @@ mod tests { (f64::INFINITY, f64::INFINITY), ]; for (a, b) in invalid { - bad_create_case(a, b); + create_err(a, b); } } diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index ef320529..804fa3f3 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -438,7 +438,7 @@ mod tests { (-1.0, f64::NAN), ]; for (s, r) in invalid { - bad_create_case(s, r); + create_err(s, r); } } diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index 1b6d3791..d00d5994 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -135,9 +135,16 @@ pub mod test { } } - fn bad_create_case($($arg_name: $arg_ty),+) { - let n = <$dist>::new($($arg_name),+); - assert!(n.is_err()); + fn create_err($($arg_name: $arg_ty),+) -> $crate::StatsError { + match <$dist>::new($($arg_name),+) { + Err(e) => e, + Ok(d) => panic!( + "{}::new was expected to fail, but succeeded for {} with result: {:?}", + stringify!($dist), + make_param_text($($arg_name),+), + d + ) + } } fn get_value($($arg_name: $arg_ty),+, eval: F) -> T diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index f21e7b62..b908b0f6 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -446,11 +446,11 @@ mod tests { #[test] fn test_bad_create() { - bad_create_case(f64::NAN, 1.0, 1.0); - bad_create_case(0.0, f64::NAN, 1.0); - bad_create_case(0.0, 1.0, f64::NAN); - bad_create_case(0.0, -10.0, 1.0); - bad_create_case(0.0, 10.0, -1.0); + create_err(f64::NAN, 1.0, 1.0); + create_err(0.0, f64::NAN, 1.0); + create_err(0.0, 1.0, f64::NAN); + create_err(0.0, -10.0, 1.0); + create_err(0.0, 10.0, -1.0); } #[test] From 3bfbc0b41fa2a50475af0ab258f8bd6af8b5161a Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 16:48:15 +0200 Subject: [PATCH 134/185] Add docs to testing_boiler functions --- src/distribution/internal.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index d00d5994..d0ea881e 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -123,6 +123,8 @@ pub mod test { param_text } + /// Creates and returns a distribution with the given parameters, + /// panicking if `::new` fails. fn create_ok($($arg_name: $arg_ty),+) -> $dist { match <$dist>::new($($arg_name),+) { Ok(d) => d, @@ -135,6 +137,8 @@ pub mod test { } } + /// Returns the error when creating a distribution with the given parameters, + /// panicking if `::new` succeeds. fn create_err($($arg_name: $arg_ty),+) -> $crate::StatsError { match <$dist>::new($($arg_name),+) { Err(e) => e, @@ -147,6 +151,10 @@ pub mod test { } } + /// Creates a distribution with the given parameters, calls the `eval` + /// function with the new distribution and returns the result of `eval`. + /// + /// Panics if `::new` fails. fn get_value($($arg_name: $arg_ty),+, eval: F) -> T where F: Fn($dist) -> T, @@ -155,6 +163,12 @@ pub mod test { eval(n) } + /// Gets a value for the given parameters and `eval` by calling `get_value` + /// and compares it to `expected`. + /// + /// Allows relative error of up to [`crate::consts::ACC`]. + /// + /// Panics if `::new` fails. fn test_case($($arg_name: $arg_ty),+, expected: T, eval: F) where F: Fn($dist) -> T, @@ -164,6 +178,12 @@ pub mod test { assert_relative_eq!(expected, x, max_relative = $crate::consts::ACC); } + /// Gets a value for the given parameters and `eval` by calling `get_value` + /// and compares it to `expected`. + /// + /// Allows absolute error of up to `acc`. + /// + /// Panics if `::new` fails. #[allow(dead_code)] // This is not used by all distributions. fn test_case_special($($arg_name: $arg_ty),+, expected: T, acc: f64, eval: F) where @@ -174,6 +194,10 @@ pub mod test { assert_abs_diff_eq!(expected, x, epsilon = acc); } + /// Gets a value for the given parameters and `eval` by calling `get_value` + /// and asserts that it is [`None`]. + /// + /// Panics if `::new` fails. #[allow(dead_code)] // This is not used by all distributions. fn test_none($($arg_name: $arg_ty),+, eval: F) where From dd7ba8aa9652001a24ecfa12c4a85c650bf43b66 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 20:09:53 +0200 Subject: [PATCH 135/185] Use test_none for Beta --- src/distribution/beta.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index 3ae05252..98627504 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -529,17 +529,13 @@ mod tests { } #[test] - #[should_panic] fn test_mode_shape_a_lte_1() { - let mode = |x: Beta| x.mode().unwrap(); - get_value(1.0, 5.0, mode); + test_none(1.0, 5.0, |dist| dist.mode()); } #[test] - #[should_panic] fn test_mode_shape_b_lte_1() { - let mode = |x: Beta| x.mode().unwrap(); - get_value(5.0, 1.0, mode); + test_none(5.0, 1.0, |dist| dist.mode()); } #[test] From e75b876dc4a85f9630acf7405121a63262720745 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 20:12:27 +0200 Subject: [PATCH 136/185] Use test_none for StudentsT --- src/distribution/students_t.rs | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index b908b0f6..37f77aa5 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -469,10 +469,8 @@ mod tests { } #[test] - #[should_panic] fn test_mean_freedom_lte_1() { - let mean = |x: StudentsT| x.mean().unwrap(); - get_value(1.0, 1.0, 0.5, mean); + test_none(1.0, 1.0, 0.5, |dist| dist.mean()); } #[test] @@ -492,18 +490,14 @@ mod tests { } #[test] - #[should_panic] fn test_variance_freedom_lte1() { - let variance = |x: StudentsT| x.variance().unwrap(); - get_value(1.0, 1.0, 0.5, variance); + test_none(1.0, 1.0, 0.5, |dist| dist.variance()); } // TODO: valid skewness tests #[test] - #[should_panic] fn test_skewness_freedom_lte_3() { - let skewness = |x: StudentsT| x.skewness().unwrap(); - get_value(1.0, 1.0, 1.0, skewness); + test_none(1.0, 1.0, 1.0, |dist| dist.skewness()); } #[test] From f203c27ef6cad35f24358c40dc4e21dce882e943 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 17:12:56 +0200 Subject: [PATCH 137/185] Rename `get_value` to `create_and_get` --- src/distribution/internal.rs | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index d0ea881e..ae43abc1 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -151,60 +151,60 @@ pub mod test { } } - /// Creates a distribution with the given parameters, calls the `eval` - /// function with the new distribution and returns the result of `eval`. + /// Creates a distribution with the given parameters, calls the `get_fn` + /// function with the new distribution and returns the result of `get_fn`. /// /// Panics if `::new` fails. - fn get_value($($arg_name: $arg_ty),+, eval: F) -> T + fn create_and_get($($arg_name: $arg_ty),+, get_fn: F) -> T where F: Fn($dist) -> T, { let n = create_ok($($arg_name),+); - eval(n) + get_fn(n) } - /// Gets a value for the given parameters and `eval` by calling `get_value` + /// Gets a value for the given parameters by calling `create_and_get` /// and compares it to `expected`. /// /// Allows relative error of up to [`crate::consts::ACC`]. /// /// Panics if `::new` fails. - fn test_case($($arg_name: $arg_ty),+, expected: T, eval: F) + fn test_case($($arg_name: $arg_ty),+, expected: T, get_fn: F) where F: Fn($dist) -> T, T: ::core::fmt::Debug + ::approx::RelativeEq, { - let x = get_value($($arg_name),+, eval); + let x = create_and_get($($arg_name),+, get_fn); assert_relative_eq!(expected, x, max_relative = $crate::consts::ACC); } - /// Gets a value for the given parameters and `eval` by calling `get_value` + /// Gets a value for the given parameters by calling `create_and_get` /// and compares it to `expected`. /// /// Allows absolute error of up to `acc`. /// /// Panics if `::new` fails. #[allow(dead_code)] // This is not used by all distributions. - fn test_case_special($($arg_name: $arg_ty),+, expected: T, acc: f64, eval: F) + fn test_case_special($($arg_name: $arg_ty),+, expected: T, acc: f64, get_fn: F) where F: Fn($dist) -> T, T: ::core::fmt::Debug + ::approx::AbsDiffEq, { - let x = get_value($($arg_name),+, eval); + let x = create_and_get($($arg_name),+, get_fn); assert_abs_diff_eq!(expected, x, epsilon = acc); } - /// Gets a value for the given parameters and `eval` by calling `get_value` + /// Gets a value for the given parameters by calling `create_and_get` /// and asserts that it is [`None`]. /// /// Panics if `::new` fails. #[allow(dead_code)] // This is not used by all distributions. - fn test_none($($arg_name: $arg_ty),+, eval: F) + fn test_none($($arg_name: $arg_ty),+, get_fn: F) where F: Fn($dist) -> Option, T: ::core::cmp::PartialEq + ::core::fmt::Debug, { - let x = get_value($($arg_name),+, eval); + let x = create_and_get($($arg_name),+, get_fn); assert_eq!(None, x); } }; From ef1880e96f3524503319a652b3b6cc4046afd2e9 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 17:32:19 +0200 Subject: [PATCH 138/185] Rename testing functions & simplify trait bounds `test_case` -> `test_relative` `test_case_special` -> `test_absolute` Also improved the error messages --- src/distribution/beta.rs | 54 +++++----- src/distribution/gamma.rs | 36 +++---- src/distribution/internal.rs | 49 ++++++++-- src/distribution/students_t.rs | 174 ++++++++++++++++----------------- 4 files changed, 171 insertions(+), 142 deletions(-) diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index 98627504..4febc3c0 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -475,7 +475,7 @@ mod tests { ((f64::INFINITY, 1.0), 1.0), ]; for ((a, b), res) in test { - test_case(a, b, res, f); + test_relative(a, b, res, f); } } @@ -490,7 +490,7 @@ mod tests { ((f64::INFINITY, 1.0), 0.0), ]; for ((a, b), res) in test { - test_case(a, b, res, f); + test_relative(a, b, res, f); } } @@ -502,9 +502,9 @@ mod tests { ((5.0, 100.0), -2.52016231876027436794592), ]; for ((a, b), res) in test { - test_case(a, b, res, f); + test_relative(a, b, res, f); } - test_case_special(1.0, 1.0, 0.0, 1e-14, f); + test_absolute(1.0, 1.0, 0.0, 1e-14, f); let entropy = |x: Beta| x.entropy(); test_none(1.0, f64::INFINITY, entropy); test_none(f64::INFINITY, 1.0, entropy); @@ -513,19 +513,19 @@ mod tests { #[test] fn test_skewness() { let skewness = |x: Beta| x.skewness().unwrap(); - test_case(1.0, 1.0, 0.0, skewness); - test_case(9.0, 1.0, -1.4740554623801777107177478829, skewness); - test_case(5.0, 100.0, 0.817594109275534303545831591, skewness); - test_case(1.0, f64::INFINITY, 2.0, skewness); - test_case(f64::INFINITY, 1.0, -2.0, skewness); + test_relative(1.0, 1.0, 0.0, skewness); + test_relative(9.0, 1.0, -1.4740554623801777107177478829, skewness); + test_relative(5.0, 100.0, 0.817594109275534303545831591, skewness); + test_relative(1.0, f64::INFINITY, 2.0, skewness); + test_relative(f64::INFINITY, 1.0, -2.0, skewness); } #[test] fn test_mode() { let mode = |x: Beta| x.mode().unwrap(); - test_case(5.0, 100.0, 0.038834951456310676243255386, mode); - test_case(92.0, f64::INFINITY, 0.0, mode); - test_case(f64::INFINITY, 2.0, 1.0, mode); + test_relative(5.0, 100.0, 0.038834951456310676243255386, mode); + test_relative(92.0, f64::INFINITY, 0.0, mode); + test_relative(f64::INFINITY, 2.0, 1.0, mode); } #[test] @@ -542,8 +542,8 @@ mod tests { fn test_min_max() { let min = |x: Beta| x.min(); let max = |x: Beta| x.max(); - test_case(1.0, 1.0, 0.0, min); - test_case(1.0, 1.0, 1.0, max); + test_relative(1.0, 1.0, 0.0, min); + test_relative(1.0, 1.0, 1.0, max); } #[test] @@ -568,20 +568,20 @@ mod tests { ((f64::INFINITY, 1.0), 1.0, f64::INFINITY), ]; for ((a, b), x, expect) in test { - test_case(a, b, expect, f(x)); + test_relative(a, b, expect, f(x)); } } #[test] fn test_pdf_input_lt_0() { let pdf = |arg: f64| move |x: Beta| x.pdf(arg); - test_case(1.0, 1.0, 0.0, pdf(-1.0)); + test_relative(1.0, 1.0, 0.0, pdf(-1.0)); } #[test] fn test_pdf_input_gt_0() { let pdf = |arg: f64| move |x: Beta| x.pdf(arg); - test_case(1.0, 1.0, 0.0, pdf(2.0)); + test_relative(1.0, 1.0, 0.0, pdf(2.0)); } #[test] @@ -605,20 +605,20 @@ mod tests { ((f64::INFINITY, 1.0), 1.0, f64::INFINITY), ]; for ((a, b), x, expect) in test { - test_case(a, b, expect, f(x)); + test_relative(a, b, expect, f(x)); } } #[test] fn test_ln_pdf_input_lt_0() { let ln_pdf = |arg: f64| move |x: Beta| x.ln_pdf(arg); - test_case(1.0, 1.0, f64::NEG_INFINITY, ln_pdf(-1.0)); + test_relative(1.0, 1.0, f64::NEG_INFINITY, ln_pdf(-1.0)); } #[test] fn test_ln_pdf_input_gt_1() { let ln_pdf = |arg: f64| move |x: Beta| x.ln_pdf(arg); - test_case(1.0, 1.0, f64::NEG_INFINITY, ln_pdf(2.0)); + test_relative(1.0, 1.0, f64::NEG_INFINITY, ln_pdf(2.0)); } #[test] @@ -642,7 +642,7 @@ mod tests { ((f64::INFINITY, 1.0), 1.0, 1.0), ]; for ((a, b), x, expect) in test { - test_case(a, b, expect, cdf(x)); + test_relative(a, b, expect, cdf(x)); } } @@ -667,7 +667,7 @@ mod tests { ((f64::INFINITY, 1.0), 1.0, 0.0), ]; for ((a, b), x, expect) in test { - test_case(a, b, expect, sf(x)); + test_relative(a, b, expect, sf(x)); } } @@ -688,32 +688,32 @@ mod tests { ((5.0, 100.0), 1.0, 1.0), ]; for ((a, b), x, expect) in test { - test_case(a, b, expect, func(x)); + test_relative(a, b, expect, func(x)); }; } #[test] fn test_cdf_input_lt_0() { let cdf = |arg: f64| move |x: Beta| x.cdf(arg); - test_case(1.0, 1.0, 0.0, cdf(-1.0)); + test_relative(1.0, 1.0, 0.0, cdf(-1.0)); } #[test] fn test_cdf_input_gt_1() { let cdf = |arg: f64| move |x: Beta| x.cdf(arg); - test_case(1.0, 1.0, 1.0, cdf(2.0)); + test_relative(1.0, 1.0, 1.0, cdf(2.0)); } #[test] fn test_sf_input_lt_0() { let sf = |arg: f64| move |x: Beta| x.sf(arg); - test_case(1.0, 1.0, 1.0, sf(-1.0)); + test_relative(1.0, 1.0, 1.0, sf(-1.0)); } #[test] fn test_sf_input_gt_1() { let sf = |arg: f64| move |x: Beta| x.sf(arg); - test_case(1.0, 1.0, 0.0, sf(2.0)); + test_relative(1.0, 1.0, 0.0, sf(2.0)); } #[test] diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index 804fa3f3..e5e1a282 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -453,7 +453,7 @@ mod tests { ((10.0, f64::INFINITY), 0.0), ]; for ((s, r), res) in test { - test_case(s, r, res, f); + test_relative(s, r, res, f); } } @@ -468,7 +468,7 @@ mod tests { ((10.0, f64::INFINITY), 0.0), ]; for ((s, r), res) in test { - test_case(s, r, res, f); + test_relative(s, r, res, f); } } @@ -483,7 +483,7 @@ mod tests { ((10.0, f64::INFINITY), f64::NEG_INFINITY), ]; for ((s, r), res) in test { - test_case(s, r, res, f); + test_relative(s, r, res, f); } } @@ -498,7 +498,7 @@ mod tests { ((10.0, f64::INFINITY), 0.6324555320336758), ]; for ((s, r), res) in test { - test_case(s, r, res, f); + test_relative(s, r, res, f); } } @@ -507,7 +507,7 @@ mod tests { let f = |x: Gamma| x.mode().unwrap(); let test = [((1.0, 0.1), 0.0), ((1.0, 1.0), 0.0)]; for &((s, r), res) in test.iter() { - test_case_special(s, r, res, 10e-6, f); + test_absolute(s, r, res, 10e-6, f); } let test = [ ((10.0, 10.0), 0.9), @@ -515,7 +515,7 @@ mod tests { ((10.0, f64::INFINITY), 0.0), ]; for ((s, r), res) in test { - test_case(s, r, res, f); + test_relative(s, r, res, f); } } @@ -530,7 +530,7 @@ mod tests { ((10.0, f64::INFINITY), 0.0), ]; for ((s, r), res) in test { - test_case(s, r, res, f); + test_relative(s, r, res, f); } let f = |x: Gamma| x.max(); let test = [ @@ -541,7 +541,7 @@ mod tests { ((10.0, f64::INFINITY), f64::INFINITY), ]; for ((s, r), res) in test { - test_case(s, r, res, f); + test_relative(s, r, res, f); } } @@ -559,7 +559,7 @@ mod tests { ((10.0, 1.0), 10.0, 0.125110035721133298984764), ]; for ((s, r), x, res) in test { - test_case(s, r, res, f(x)); + test_relative(s, r, res, f(x)); } // TODO: test special // test_is_nan((10.0, f64::INFINITY), pdf(1.0)); // is this really the behavior we want? @@ -569,8 +569,8 @@ mod tests { #[test] fn test_pdf_at_zero() { - test_case(1.0, 0.1, 0.1, |x| x.pdf(0.0)); - test_case(1.0, 0.1, 0.1f64.ln(), |x| x.ln_pdf(0.0)); + test_relative(1.0, 0.1, 0.1, |x| x.pdf(0.0)); + test_relative(1.0, 0.1, 0.1f64.ln(), |x| x.ln_pdf(0.0)); } #[test] @@ -588,7 +588,7 @@ mod tests { ((10.0, f64::INFINITY), f64::INFINITY, f64::NEG_INFINITY), ]; for ((s, r), x, res) in test { - test_case(s, r, res, f(x)); + test_relative(s, r, res, f(x)); } // TODO: test special // test_is_nan((10.0, f64::INFINITY), f(1.0)); // is this really the behavior we want? @@ -610,13 +610,13 @@ mod tests { ((10.0, f64::INFINITY), 10.0, 1.0), ]; for ((s, r), x, res) in test { - test_case(s, r, res, f(x)); + test_relative(s, r, res, f(x)); } } #[test] fn test_cdf_at_zero() { - test_case(1.0, 0.1, 0.0, |x| x.cdf(0.0)); + test_relative(1.0, 0.1, 0.0, |x| x.cdf(0.0)); } #[test] @@ -633,7 +633,7 @@ mod tests { for (s, r) in params { for n in -5..0 { let p = 10.0f64.powi(n); - test_case(s, r, p, f(p)); + test_relative(s, r, p, f(p)); } } @@ -641,7 +641,7 @@ mod tests { { let x = 20.5567; let f = |x: f64| move |g: Gamma| g.inverse_cdf(g.cdf(x)); - test_case(3.0, 0.5, x, f(x)) + test_relative(3.0, 0.5, x, f(x)) } } @@ -661,13 +661,13 @@ mod tests { ((10.0, f64::INFINITY), 10.0, 0.0), ]; for ((s, r), x, res) in test { - test_case(s, r, res, f(x)); + test_relative(s, r, res, f(x)); } } #[test] fn test_sf_at_zero() { - test_case(1.0, 0.1, 1.0, |x| x.sf(0.0)); + test_relative(1.0, 0.1, 1.0, |x| x.sf(0.0)); } #[test] diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index ae43abc1..ecc8435d 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -169,13 +169,22 @@ pub mod test { /// Allows relative error of up to [`crate::consts::ACC`]. /// /// Panics if `::new` fails. - fn test_case($($arg_name: $arg_ty),+, expected: T, get_fn: F) + fn test_relative($($arg_name: $arg_ty),+, expected: f64, get_fn: F) where - F: Fn($dist) -> T, - T: ::core::fmt::Debug + ::approx::RelativeEq, + F: Fn($dist) -> f64, { let x = create_and_get($($arg_name),+, get_fn); - assert_relative_eq!(expected, x, max_relative = $crate::consts::ACC); + let max_relative = $crate::consts::ACC; + + if !::approx::relative_eq!(expected, x, max_relative = max_relative) { + panic!( + "Expected {:?} to be almost equal to {:?} (max. relative error of {:?}), but wasn't for {}", + x, + expected, + max_relative, + make_param_text($($arg_name),+) + ); + } } /// Gets a value for the given parameters by calling `create_and_get` @@ -185,13 +194,26 @@ pub mod test { /// /// Panics if `::new` fails. #[allow(dead_code)] // This is not used by all distributions. - fn test_case_special($($arg_name: $arg_ty),+, expected: T, acc: f64, get_fn: F) + fn test_absolute($($arg_name: $arg_ty),+, expected: f64, acc: f64, get_fn: F) where - F: Fn($dist) -> T, - T: ::core::fmt::Debug + ::approx::AbsDiffEq, + F: Fn($dist) -> f64, { let x = create_and_get($($arg_name),+, get_fn); - assert_abs_diff_eq!(expected, x, epsilon = acc); + + // abs_diff_eq! cannot handle infinities, so we manually accept them here + if expected.is_infinite() && x == expected { + return; + } + + if !::approx::abs_diff_eq!(expected, x, epsilon = acc) { + panic!( + "Expected {:?} to be almost equal to {:?} (max. absolute error of {:?}), but wasn't for {}", + x, + expected, + acc, + make_param_text($($arg_name),+) + ); + } } /// Gets a value for the given parameters by calling `create_and_get` @@ -202,10 +224,17 @@ pub mod test { fn test_none($($arg_name: $arg_ty),+, get_fn: F) where F: Fn($dist) -> Option, - T: ::core::cmp::PartialEq + ::core::fmt::Debug, + T: ::core::fmt::Debug, { let x = create_and_get($($arg_name),+, get_fn); - assert_eq!(None, x); + + if let Some(inner) = x { + panic!( + "Expected None, got {:?} for {}", + inner, + make_param_text($($arg_name),+) + ) + } } }; } diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index 37f77aa5..4bada682 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -456,10 +456,10 @@ mod tests { #[test] fn test_mean() { let mean = |x: StudentsT| x.mean().unwrap(); - test_case(0.0, 1.0, 3.0, 0.0, mean); - test_case(0.0, 10.0, 2.0, 0.0, mean); - test_case(0.0, 10.0, f64::INFINITY, 0.0, mean); - test_case(-5.0, 100.0, 1.5, -5.0, mean); + test_relative(0.0, 1.0, 3.0, 0.0, mean); + test_relative(0.0, 10.0, 2.0, 0.0, mean); + test_relative(0.0, 10.0, f64::INFINITY, 0.0, mean); + test_relative(-5.0, 100.0, 1.5, -5.0, mean); let mean = |x: StudentsT| x.mean(); test_none(0.0, 1.0, 1.0, mean); test_none(0.0, 0.1, 1.0, mean); @@ -476,9 +476,9 @@ mod tests { #[test] fn test_variance() { let variance = |x: StudentsT| x.variance().unwrap(); - test_case(0.0, 1.0, 3.0, 3.0, variance); - test_case(0.0, 10.0, 2.5, 500.0, variance); - test_case(10.0, 1.0, 2.5, 5.0, variance); + test_relative(0.0, 1.0, 3.0, 3.0, variance); + test_relative(0.0, 10.0, 2.5, 500.0, variance); + test_relative(10.0, 1.0, 2.5, 5.0, variance); let variance = |x: StudentsT| x.variance(); test_none(0.0, 10.0, 2.0, variance); test_none(0.0, 1.0, 1.0, variance); @@ -503,121 +503,121 @@ mod tests { #[test] fn test_mode() { let mode = |x: StudentsT| x.mode().unwrap(); - test_case(0.0, 1.0, 1.0, 0.0, mode); - test_case(0.0, 0.1, 1.0, 0.0, mode); - test_case(0.0, 1.0, 3.0, 0.0, mode); - test_case(0.0, 10.0, 1.0, 0.0, mode); - test_case(0.0, 10.0, 2.0, 0.0, mode); - test_case(0.0, 10.0, 2.5, 0.0, mode); - test_case(0.0, 10.0, f64::INFINITY, 0.0, mode); - test_case(10.0, 1.0, 1.0, 10.0, mode); - test_case(10.0, 1.0, 2.5, 10.0, mode); - test_case(-5.0, 100.0, 1.5, -5.0, mode); - test_case(0.0, f64::INFINITY, 1.0, 0.0, mode); + test_relative(0.0, 1.0, 1.0, 0.0, mode); + test_relative(0.0, 0.1, 1.0, 0.0, mode); + test_relative(0.0, 1.0, 3.0, 0.0, mode); + test_relative(0.0, 10.0, 1.0, 0.0, mode); + test_relative(0.0, 10.0, 2.0, 0.0, mode); + test_relative(0.0, 10.0, 2.5, 0.0, mode); + test_relative(0.0, 10.0, f64::INFINITY, 0.0, mode); + test_relative(10.0, 1.0, 1.0, 10.0, mode); + test_relative(10.0, 1.0, 2.5, 10.0, mode); + test_relative(-5.0, 100.0, 1.5, -5.0, mode); + test_relative(0.0, f64::INFINITY, 1.0, 0.0, mode); } #[test] fn test_median() { let median = |x: StudentsT| x.median(); - test_case(0.0, 1.0, 1.0, 0.0, median); - test_case(0.0, 0.1, 1.0, 0.0, median); - test_case(0.0, 1.0, 3.0, 0.0, median); - test_case(0.0, 10.0, 1.0, 0.0, median); - test_case(0.0, 10.0, 2.0, 0.0, median); - test_case(0.0, 10.0, 2.5, 0.0, median); - test_case(0.0, 10.0, f64::INFINITY, 0.0, median); - test_case(10.0, 1.0, 1.0, 10.0, median); - test_case(10.0, 1.0, 2.5, 10.0, median); - test_case(-5.0, 100.0, 1.5, -5.0, median); - test_case(0.0, f64::INFINITY, 1.0, 0.0, median); + test_relative(0.0, 1.0, 1.0, 0.0, median); + test_relative(0.0, 0.1, 1.0, 0.0, median); + test_relative(0.0, 1.0, 3.0, 0.0, median); + test_relative(0.0, 10.0, 1.0, 0.0, median); + test_relative(0.0, 10.0, 2.0, 0.0, median); + test_relative(0.0, 10.0, 2.5, 0.0, median); + test_relative(0.0, 10.0, f64::INFINITY, 0.0, median); + test_relative(10.0, 1.0, 1.0, 10.0, median); + test_relative(10.0, 1.0, 2.5, 10.0, median); + test_relative(-5.0, 100.0, 1.5, -5.0, median); + test_relative(0.0, f64::INFINITY, 1.0, 0.0, median); } #[test] fn test_min_max() { let min = |x: StudentsT| x.min(); let max = |x: StudentsT| x.max(); - test_case(0.0, 1.0, 1.0, f64::NEG_INFINITY, min); - test_case(2.5, 100.0, 1.5, f64::NEG_INFINITY, min); - test_case(10.0, f64::INFINITY, 3.5, f64::NEG_INFINITY, min); - test_case(0.0, 1.0, 1.0, f64::INFINITY, max); - test_case(2.5, 100.0, 1.5, f64::INFINITY, max); - test_case(10.0, f64::INFINITY, 5.5, f64::INFINITY, max); + test_relative(0.0, 1.0, 1.0, f64::NEG_INFINITY, min); + test_relative(2.5, 100.0, 1.5, f64::NEG_INFINITY, min); + test_relative(10.0, f64::INFINITY, 3.5, f64::NEG_INFINITY, min); + test_relative(0.0, 1.0, 1.0, f64::INFINITY, max); + test_relative(2.5, 100.0, 1.5, f64::INFINITY, max); + test_relative(10.0, f64::INFINITY, 5.5, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: StudentsT| x.pdf(arg); - test_case(0.0, 1.0, 1.0, std::f64::consts::FRAC_1_PI, pdf(0.0)); - test_case(0.0, 1.0, 1.0, 0.159154943091895, pdf(1.0)); - test_case(0.0, 1.0, 1.0, 0.159154943091895, pdf(-1.0)); - test_case(0.0, 1.0, 1.0, 0.063661977236758, pdf(2.0)); - test_case(0.0, 1.0, 1.0, 0.063661977236758, pdf(-2.0)); - test_case(0.0, 1.0, 2.0, 0.353553390593274, pdf(0.0)); - test_case(0.0, 1.0, 2.0, 0.192450089729875, pdf(1.0)); - test_case(0.0, 1.0, 2.0, 0.192450089729875, pdf(-1.0)); - test_case(0.0, 1.0, 2.0, 0.068041381743977, pdf(2.0)); - test_case(0.0, 1.0, 2.0, 0.068041381743977, pdf(-2.0)); - test_case(0.0, 1.0, f64::INFINITY, 0.398942280401433, pdf(0.0)); - test_case(0.0, 1.0, f64::INFINITY, 0.241970724519143, pdf(1.0)); - test_case(0.0, 1.0, f64::INFINITY, 0.053990966513188, pdf(2.0)); + test_relative(0.0, 1.0, 1.0, std::f64::consts::FRAC_1_PI, pdf(0.0)); + test_relative(0.0, 1.0, 1.0, 0.159154943091895, pdf(1.0)); + test_relative(0.0, 1.0, 1.0, 0.159154943091895, pdf(-1.0)); + test_relative(0.0, 1.0, 1.0, 0.063661977236758, pdf(2.0)); + test_relative(0.0, 1.0, 1.0, 0.063661977236758, pdf(-2.0)); + test_relative(0.0, 1.0, 2.0, 0.353553390593274, pdf(0.0)); + test_relative(0.0, 1.0, 2.0, 0.192450089729875, pdf(1.0)); + test_relative(0.0, 1.0, 2.0, 0.192450089729875, pdf(-1.0)); + test_relative(0.0, 1.0, 2.0, 0.068041381743977, pdf(2.0)); + test_relative(0.0, 1.0, 2.0, 0.068041381743977, pdf(-2.0)); + test_relative(0.0, 1.0, f64::INFINITY, 0.398942280401433, pdf(0.0)); + test_relative(0.0, 1.0, f64::INFINITY, 0.241970724519143, pdf(1.0)); + test_relative(0.0, 1.0, f64::INFINITY, 0.053990966513188, pdf(2.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: StudentsT| x.ln_pdf(arg); - test_case(0.0, 1.0, 1.0, -1.144729885849399, ln_pdf(0.0)); - test_case(0.0, 1.0, 1.0, -1.837877066409348, ln_pdf(1.0)); - test_case(0.0, 1.0, 1.0, -1.837877066409348, ln_pdf(-1.0)); - test_case(0.0, 1.0, 1.0, -2.754167798283503, ln_pdf(2.0)); - test_case(0.0, 1.0, 1.0, -2.754167798283503, ln_pdf(-2.0)); - test_case(0.0, 1.0, 2.0, -1.039720770839917, ln_pdf(0.0)); - test_case(0.0, 1.0, 2.0, -1.647918433002166, ln_pdf(1.0)); - test_case(0.0, 1.0, 2.0, -1.647918433002166, ln_pdf(-1.0)); - test_case(0.0, 1.0, 2.0, -2.687639203842085, ln_pdf(2.0)); - test_case(0.0, 1.0, 2.0, -2.687639203842085, ln_pdf(-2.0)); - test_case(0.0, 1.0, f64::INFINITY, -0.918938533204672, ln_pdf(0.0)); - test_case(0.0, 1.0, f64::INFINITY, -1.418938533204674, ln_pdf(1.0)); - test_case(0.0, 1.0, f64::INFINITY, -2.918938533204674, ln_pdf(2.0)); + test_relative(0.0, 1.0, 1.0, -1.144729885849399, ln_pdf(0.0)); + test_relative(0.0, 1.0, 1.0, -1.837877066409348, ln_pdf(1.0)); + test_relative(0.0, 1.0, 1.0, -1.837877066409348, ln_pdf(-1.0)); + test_relative(0.0, 1.0, 1.0, -2.754167798283503, ln_pdf(2.0)); + test_relative(0.0, 1.0, 1.0, -2.754167798283503, ln_pdf(-2.0)); + test_relative(0.0, 1.0, 2.0, -1.039720770839917, ln_pdf(0.0)); + test_relative(0.0, 1.0, 2.0, -1.647918433002166, ln_pdf(1.0)); + test_relative(0.0, 1.0, 2.0, -1.647918433002166, ln_pdf(-1.0)); + test_relative(0.0, 1.0, 2.0, -2.687639203842085, ln_pdf(2.0)); + test_relative(0.0, 1.0, 2.0, -2.687639203842085, ln_pdf(-2.0)); + test_relative(0.0, 1.0, f64::INFINITY, -0.918938533204672, ln_pdf(0.0)); + test_relative(0.0, 1.0, f64::INFINITY, -1.418938533204674, ln_pdf(1.0)); + test_relative(0.0, 1.0, f64::INFINITY, -2.918938533204674, ln_pdf(2.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: StudentsT| x.cdf(arg); - test_case(0.0, 1.0, 1.0, 0.5, cdf(0.0)); - test_case(0.0, 1.0, 1.0, 0.75, cdf(1.0)); - test_case(0.0, 1.0, 1.0, 0.25, cdf(-1.0)); - test_case(0.0, 1.0, 1.0, 0.852416382349567, cdf(2.0)); - test_case(0.0, 1.0, 1.0, 0.147583617650433, cdf(-2.0)); - test_case(0.0, 1.0, 2.0, 0.5, cdf(0.0)); - test_case(0.0, 1.0, 2.0, 0.788675134594813, cdf(1.0)); - test_case(0.0, 1.0, 2.0, 0.211324865405187, cdf(-1.0)); - test_case(0.0, 1.0, 2.0, 0.908248290463863, cdf(2.0)); - test_case(0.0, 1.0, 2.0, 0.091751709536137, cdf(-2.0)); - test_case(0.0, 1.0, f64::INFINITY, 0.5, cdf(0.0)); + test_relative(0.0, 1.0, 1.0, 0.5, cdf(0.0)); + test_relative(0.0, 1.0, 1.0, 0.75, cdf(1.0)); + test_relative(0.0, 1.0, 1.0, 0.25, cdf(-1.0)); + test_relative(0.0, 1.0, 1.0, 0.852416382349567, cdf(2.0)); + test_relative(0.0, 1.0, 1.0, 0.147583617650433, cdf(-2.0)); + test_relative(0.0, 1.0, 2.0, 0.5, cdf(0.0)); + test_relative(0.0, 1.0, 2.0, 0.788675134594813, cdf(1.0)); + test_relative(0.0, 1.0, 2.0, 0.211324865405187, cdf(-1.0)); + test_relative(0.0, 1.0, 2.0, 0.908248290463863, cdf(2.0)); + test_relative(0.0, 1.0, 2.0, 0.091751709536137, cdf(-2.0)); + test_relative(0.0, 1.0, f64::INFINITY, 0.5, cdf(0.0)); // TODO: these are curiously low accuracy and should be re-examined - test_case(0.0, 1.0, f64::INFINITY, 0.841344746068543, cdf(1.0)); - test_case(0.0, 1.0, f64::INFINITY, 0.977249868051821, cdf(2.0)); + test_relative(0.0, 1.0, f64::INFINITY, 0.841344746068543, cdf(1.0)); + test_relative(0.0, 1.0, f64::INFINITY, 0.977249868051821, cdf(2.0)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: StudentsT| x.sf(arg); - test_case(0.0, 1.0, 1.0, 0.5, sf(0.0)); - test_case(0.0, 1.0, 1.0, 0.25, sf(1.0)); - test_case(0.0, 1.0, 1.0, 0.75, sf(-1.0)); - test_case(0.0, 1.0, 1.0, 0.147583617650433, sf(2.0)); - test_case(0.0, 1.0, 1.0, 0.852416382349566, sf(-2.0)); - test_case(0.0, 1.0, 2.0, 0.5, sf(0.0)); - test_case(0.0, 1.0, 2.0, 0.211324865405186, sf(1.0)); - test_case(0.0, 1.0, 2.0, 0.788675134594813, sf(-1.0)); - test_case(0.0, 1.0, 2.0, 0.091751709536137, sf(2.0)); - test_case(0.0, 1.0, 2.0, 0.908248290463862, sf(-2.0)); - test_case(0.0, 1.0, f64::INFINITY, 0.5, sf(0.0)); + test_relative(0.0, 1.0, 1.0, 0.5, sf(0.0)); + test_relative(0.0, 1.0, 1.0, 0.25, sf(1.0)); + test_relative(0.0, 1.0, 1.0, 0.75, sf(-1.0)); + test_relative(0.0, 1.0, 1.0, 0.147583617650433, sf(2.0)); + test_relative(0.0, 1.0, 1.0, 0.852416382349566, sf(-2.0)); + test_relative(0.0, 1.0, 2.0, 0.5, sf(0.0)); + test_relative(0.0, 1.0, 2.0, 0.211324865405186, sf(1.0)); + test_relative(0.0, 1.0, 2.0, 0.788675134594813, sf(-1.0)); + test_relative(0.0, 1.0, 2.0, 0.091751709536137, sf(2.0)); + test_relative(0.0, 1.0, 2.0, 0.908248290463862, sf(-2.0)); + test_relative(0.0, 1.0, f64::INFINITY, 0.5, sf(0.0)); // TODO: these are curiously low accuracy and should be re-examined - test_case(0.0, 1.0, f64::INFINITY, 0.158655253945057, sf(1.0)); - test_case(0.0, 1.0, f64::INFINITY, 0.022750131947162, sf(2.0)); + test_relative(0.0, 1.0, f64::INFINITY, 0.158655253945057, sf(1.0)); + test_relative(0.0, 1.0, f64::INFINITY, 0.022750131947162, sf(2.0)); } #[test] From 0189769f25edf2f17041694084fdf0cc445b0d26 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 17:43:08 +0200 Subject: [PATCH 139/185] Add test_exact function --- src/distribution/internal.rs | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index ecc8435d..2b5ede9e 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -163,6 +163,28 @@ pub mod test { get_fn(n) } + /// Creates a distribution with the given parameters, calls the `get_fn` + /// function with the new distribution and compares the result of `get_fn` + /// to `expected` exactly. + /// + /// Panics if `::new` fails. + #[allow(dead_code)] // This is not used by all distributions. + fn test_exact($($arg_name: $arg_ty),+, expected: T, get_fn: F) + where + F: Fn($dist) -> T, + T: ::core::cmp::PartialEq + ::core::fmt::Debug + { + let x = create_and_get($($arg_name),+, get_fn); + if x != expected { + panic!( + "Expected {:?}, got {:?} for {}", + expected, + x, + make_param_text($($arg_name),+) + ); + } + } + /// Gets a value for the given parameters by calling `create_and_get` /// and compares it to `expected`. /// From 93b927f704a168cd1104fca22e42f42c296ac686 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 17:38:16 +0200 Subject: [PATCH 140/185] Use testing_boiler! for Bernoulli --- src/distribution/bernoulli.rs | 73 +++++++++-------------------------- 1 file changed, 18 insertions(+), 55 deletions(-) diff --git a/src/distribution/bernoulli.rs b/src/distribution/bernoulli.rs index 61499ebd..f82c0d65 100644 --- a/src/distribution/bernoulli.rs +++ b/src/distribution/bernoulli.rs @@ -265,90 +265,53 @@ impl Discrete for Bernoulli { #[rustfmt::skip] #[cfg(test)] mod testing { - use std::fmt::Debug; use crate::distribution::DiscreteCDF; + use crate::testing_boiler; use super::Bernoulli; - fn try_create(p: f64) -> Bernoulli { - let n = Bernoulli::new(p); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(p: f64) { - let dist = try_create(p); - assert_eq!(p, dist.p()); - } - - fn bad_create_case(p: f64) { - let n = Bernoulli::new(p); - assert!(n.is_err()); - } - - fn get_value(p: f64, eval: F) -> T - where T: PartialEq + Debug, - F: Fn(Bernoulli) -> T - { - let n = try_create(p); - eval(n) - } - - fn test_case(p: f64, expected: T, eval: F) - where T: PartialEq + Debug, - F: Fn(Bernoulli) -> T - { - let x = get_value(p, eval); - assert_eq!(expected, x); - } - - fn test_almost(p: f64, expected: f64, acc: f64, eval: F) - where F: Fn(Bernoulli) -> f64 - { - let x = get_value(p, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(p: f64; Bernoulli); #[test] fn test_create() { - create_case(0.0); - create_case(0.3); - create_case(1.0); + create_ok(0.0); + create_ok(0.3); + create_ok(1.0); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN); - bad_create_case(-1.0); - bad_create_case(2.0); + create_err(f64::NAN); + create_err(-1.0); + create_err(2.0); } #[test] fn test_cdf_upper_bound() { let cdf = |arg: u64| move |x: Bernoulli| x.cdf(arg); - test_case(0.3, 1., cdf(1)); + test_relative(0.3, 1., cdf(1)); } #[test] fn test_sf_upper_bound() { let sf = |arg: u64| move |x: Bernoulli| x.sf(arg); - test_case(0.3, 0., sf(1)); + test_relative(0.3, 0., sf(1)); } #[test] fn test_cdf() { let cdf = |arg: u64| move |x: Bernoulli| x.cdf(arg); - test_case(0.0, 1.0, cdf(0)); - test_case(0.0, 1.0, cdf(1)); - test_almost(0.3, 0.7, 1e-15, cdf(0)); - test_almost(0.7, 0.3, 1e-15, cdf(0)); + test_relative(0.0, 1.0, cdf(0)); + test_relative(0.0, 1.0, cdf(1)); + test_absolute(0.3, 0.7, 1e-15, cdf(0)); + test_absolute(0.7, 0.3, 1e-15, cdf(0)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: Bernoulli| x.sf(arg); - test_case(0.0, 0.0, sf(0)); - test_case(0.0, 0.0, sf(1)); - test_almost(0.3, 0.3, 1e-15, sf(0)); - test_almost(0.7, 0.7, 1e-15, sf(0)); + test_relative(0.0, 0.0, sf(0)); + test_relative(0.0, 0.0, sf(1)); + test_absolute(0.3, 0.3, 1e-15, sf(0)); + test_absolute(0.7, 0.7, 1e-15, sf(0)); } } From 88366fd28278c3293ed86d1cad08ebdf3d8b635c Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 17:52:18 +0200 Subject: [PATCH 141/185] Use testing_boiler! for Binomial --- src/distribution/binomial.rs | 307 +++++++++++++++-------------------- src/distribution/internal.rs | 1 + 2 files changed, 135 insertions(+), 173 deletions(-) diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index 07060bfc..8eced1d7 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -328,271 +328,232 @@ impl Discrete for Binomial { #[rustfmt::skip] #[cfg(test)] mod tests { - use std::fmt::Debug; use crate::statistics::*; use crate::distribution::{DiscreteCDF, Discrete, Binomial}; use crate::distribution::internal::*; + use crate::testing_boiler; - fn try_create(p: f64, n: u64) -> Binomial { - let n = Binomial::new(p, n); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(p: f64, n: u64) { - let dist = try_create(p, n); - assert_eq!(p, dist.p()); - assert_eq!(n, dist.n()); - } - - fn bad_create_case(p: f64, n: u64) { - let n = Binomial::new(p, n); - assert!(n.is_err()); - } - - fn get_value(p: f64, n: u64, eval: F) -> T - where T: PartialEq + Debug, - F: Fn(Binomial) -> T - { - let n = try_create(p, n); - eval(n) - } - - fn test_case(p: f64, n: u64, expected: T, eval: F) - where T: PartialEq + Debug, - F: Fn(Binomial) -> T - { - let x = get_value(p, n, eval); - println!("{} {} {:?}", p, n, expected); - assert_eq!(expected, x); - } - - fn test_almost(p: f64, n: u64, expected: f64, acc: f64, eval: F) - where F: Fn(Binomial) -> f64 - { - let x = get_value(p, n, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(p: f64, n: u64; Binomial); #[test] fn test_create() { - create_case(0.0, 4); - create_case(0.3, 3); - create_case(1.0, 2); + create_ok(0.0, 4); + create_ok(0.3, 3); + create_ok(1.0, 2); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN, 1); - bad_create_case(-1.0, 1); - bad_create_case(2.0, 1); + create_err(f64::NAN, 1); + create_err(-1.0, 1); + create_err(2.0, 1); } #[test] fn test_mean() { let mean = |x: Binomial| x.mean().unwrap(); - test_case(0.0, 4, 0.0, mean); - test_almost(0.3, 3, 0.9, 1e-15, mean); - test_case(1.0, 2, 2.0, mean); + test_exact(0.0, 4, 0.0, mean); + test_absolute(0.3, 3, 0.9, 1e-15, mean); + test_exact(1.0, 2, 2.0, mean); } #[test] fn test_variance() { let variance = |x: Binomial| x.variance().unwrap(); - test_case(0.0, 4, 0.0, variance); - test_case(0.3, 3, 0.63, variance); - test_case(1.0, 2, 0.0, variance); + test_exact(0.0, 4, 0.0, variance); + test_exact(0.3, 3, 0.63, variance); + test_exact(1.0, 2, 0.0, variance); } #[test] fn test_entropy() { let entropy = |x: Binomial| x.entropy().unwrap(); - test_case(0.0, 4, 0.0, entropy); - test_almost(0.3, 3, 1.1404671643037712668976423399228972051669206536461, 1e-15, entropy); - test_case(1.0, 2, 0.0, entropy); + test_exact(0.0, 4, 0.0, entropy); + test_absolute(0.3, 3, 1.1404671643037712668976423399228972051669206536461, 1e-15, entropy); + test_exact(1.0, 2, 0.0, entropy); } #[test] fn test_skewness() { let skewness = |x: Binomial| x.skewness().unwrap(); - test_case(0.0, 4, f64::INFINITY, skewness); - test_case(0.3, 3, 0.503952630678969636286, skewness); - test_case(1.0, 2, f64::NEG_INFINITY, skewness); + test_exact(0.0, 4, f64::INFINITY, skewness); + test_exact(0.3, 3, 0.503952630678969636286, skewness); + test_exact(1.0, 2, f64::NEG_INFINITY, skewness); } #[test] fn test_median() { let median = |x: Binomial| x.median(); - test_case(0.0, 4, 0.0, median); - test_case(0.3, 3, 0.0, median); - test_case(1.0, 2, 2.0, median); + test_exact(0.0, 4, 0.0, median); + test_exact(0.3, 3, 0.0, median); + test_exact(1.0, 2, 2.0, median); } #[test] fn test_mode() { let mode = |x: Binomial| x.mode().unwrap(); - test_case(0.0, 4, 0, mode); - test_case(0.3, 3, 1, mode); - test_case(1.0, 2, 2, mode); + test_exact(0.0, 4, 0, mode); + test_exact(0.3, 3, 1, mode); + test_exact(1.0, 2, 2, mode); } #[test] fn test_min_max() { let min = |x: Binomial| x.min(); let max = |x: Binomial| x.max(); - test_case(0.3, 10, 0, min); - test_case(0.3, 10, 10, max); + test_exact(0.3, 10, 0, min); + test_exact(0.3, 10, 10, max); } #[test] fn test_pmf() { let pmf = |arg: u64| move |x: Binomial| x.pmf(arg); - test_case(0.0, 1, 1.0, pmf(0)); - test_case(0.0, 1, 0.0, pmf(1)); - test_case(0.0, 3, 1.0, pmf(0)); - test_case(0.0, 3, 0.0, pmf(1)); - test_case(0.0, 3, 0.0, pmf(3)); - test_case(0.0, 10, 1.0, pmf(0)); - test_case(0.0, 10, 0.0, pmf(1)); - test_case(0.0, 10, 0.0, pmf(10)); - test_case(0.3, 1, 0.69999999999999995559107901499373838305473327636719, pmf(0)); - test_case(0.3, 1, 0.2999999999999999888977697537484345957636833190918, pmf(1)); - test_case(0.3, 3, 0.34299999999999993471888615204079956461021032657166, pmf(0)); - test_almost(0.3, 3, 0.44099999999999992772448109690231306411849135972008, 1e-15, pmf(1)); - test_almost(0.3, 3, 0.026999999999999997002397833512077451789759292859569, 1e-16, pmf(3)); - test_almost(0.3, 10, 0.02824752489999998207939855277004937778546385011091, 1e-17, pmf(0)); - test_almost(0.3, 10, 0.12106082099999992639752977030555903089040470780077, 1e-15, pmf(1)); - test_almost(0.3, 10, 0.0000059048999999999978147480206303047454017251032868501, 1e-20, pmf(10)); - test_case(1.0, 1, 0.0, pmf(0)); - test_case(1.0, 1, 1.0, pmf(1)); - test_case(1.0, 3, 0.0, pmf(0)); - test_case(1.0, 3, 0.0, pmf(1)); - test_case(1.0, 3, 1.0, pmf(3)); - test_case(1.0, 10, 0.0, pmf(0)); - test_case(1.0, 10, 0.0, pmf(1)); - test_case(1.0, 10, 1.0, pmf(10)); + test_exact(0.0, 1, 1.0, pmf(0)); + test_exact(0.0, 1, 0.0, pmf(1)); + test_exact(0.0, 3, 1.0, pmf(0)); + test_exact(0.0, 3, 0.0, pmf(1)); + test_exact(0.0, 3, 0.0, pmf(3)); + test_exact(0.0, 10, 1.0, pmf(0)); + test_exact(0.0, 10, 0.0, pmf(1)); + test_exact(0.0, 10, 0.0, pmf(10)); + test_exact(0.3, 1, 0.69999999999999995559107901499373838305473327636719, pmf(0)); + test_exact(0.3, 1, 0.2999999999999999888977697537484345957636833190918, pmf(1)); + test_exact(0.3, 3, 0.34299999999999993471888615204079956461021032657166, pmf(0)); + test_absolute(0.3, 3, 0.44099999999999992772448109690231306411849135972008, 1e-15, pmf(1)); + test_absolute(0.3, 3, 0.026999999999999997002397833512077451789759292859569, 1e-16, pmf(3)); + test_absolute(0.3, 10, 0.02824752489999998207939855277004937778546385011091, 1e-17, pmf(0)); + test_absolute(0.3, 10, 0.12106082099999992639752977030555903089040470780077, 1e-15, pmf(1)); + test_absolute(0.3, 10, 0.0000059048999999999978147480206303047454017251032868501, 1e-20, pmf(10)); + test_exact(1.0, 1, 0.0, pmf(0)); + test_exact(1.0, 1, 1.0, pmf(1)); + test_exact(1.0, 3, 0.0, pmf(0)); + test_exact(1.0, 3, 0.0, pmf(1)); + test_exact(1.0, 3, 1.0, pmf(3)); + test_exact(1.0, 10, 0.0, pmf(0)); + test_exact(1.0, 10, 0.0, pmf(1)); + test_exact(1.0, 10, 1.0, pmf(10)); } #[test] fn test_ln_pmf() { let ln_pmf = |arg: u64| move |x: Binomial| x.ln_pmf(arg); - test_case(0.0, 1, 0.0, ln_pmf(0)); - test_case(0.0, 1, f64::NEG_INFINITY, ln_pmf(1)); - test_case(0.0, 3, 0.0, ln_pmf(0)); - test_case(0.0, 3, f64::NEG_INFINITY, ln_pmf(1)); - test_case(0.0, 3, f64::NEG_INFINITY, ln_pmf(3)); - test_case(0.0, 10, 0.0, ln_pmf(0)); - test_case(0.0, 10, f64::NEG_INFINITY, ln_pmf(1)); - test_case(0.0, 10, f64::NEG_INFINITY, ln_pmf(10)); - test_case(0.3, 1, -0.3566749439387324423539544041072745145718090708995, ln_pmf(0)); - test_case(0.3, 1, -1.2039728043259360296301803719337238685164245381839, ln_pmf(1)); - test_case(0.3, 3, -1.0700248318161973270618632123218235437154272126985, ln_pmf(0)); - test_almost(0.3, 3, -0.81871040353529122294284394322574719301255212216016, 1e-15, ln_pmf(1)); - test_almost(0.3, 3, -3.6119184129778080888905411158011716055492736145517, 1e-15, ln_pmf(3)); - test_case(0.3, 10, -3.566749439387324423539544041072745145718090708995, ln_pmf(0)); - test_almost(0.3, 10, -2.1114622067804823267977785542148302920616046876506, 1e-14, ln_pmf(1)); - test_case(0.3, 10, -12.039728043259360296301803719337238685164245381839, ln_pmf(10)); - test_case(1.0, 1, f64::NEG_INFINITY, ln_pmf(0)); - test_case(1.0, 1, 0.0, ln_pmf(1)); - test_case(1.0, 3, f64::NEG_INFINITY, ln_pmf(0)); - test_case(1.0, 3, f64::NEG_INFINITY, ln_pmf(1)); - test_case(1.0, 3, 0.0, ln_pmf(3)); - test_case(1.0, 10, f64::NEG_INFINITY, ln_pmf(0)); - test_case(1.0, 10, f64::NEG_INFINITY, ln_pmf(1)); - test_case(1.0, 10, 0.0, ln_pmf(10)); + test_exact(0.0, 1, 0.0, ln_pmf(0)); + test_exact(0.0, 1, f64::NEG_INFINITY, ln_pmf(1)); + test_exact(0.0, 3, 0.0, ln_pmf(0)); + test_exact(0.0, 3, f64::NEG_INFINITY, ln_pmf(1)); + test_exact(0.0, 3, f64::NEG_INFINITY, ln_pmf(3)); + test_exact(0.0, 10, 0.0, ln_pmf(0)); + test_exact(0.0, 10, f64::NEG_INFINITY, ln_pmf(1)); + test_exact(0.0, 10, f64::NEG_INFINITY, ln_pmf(10)); + test_exact(0.3, 1, -0.3566749439387324423539544041072745145718090708995, ln_pmf(0)); + test_exact(0.3, 1, -1.2039728043259360296301803719337238685164245381839, ln_pmf(1)); + test_exact(0.3, 3, -1.0700248318161973270618632123218235437154272126985, ln_pmf(0)); + test_absolute(0.3, 3, -0.81871040353529122294284394322574719301255212216016, 1e-15, ln_pmf(1)); + test_absolute(0.3, 3, -3.6119184129778080888905411158011716055492736145517, 1e-15, ln_pmf(3)); + test_exact(0.3, 10, -3.566749439387324423539544041072745145718090708995, ln_pmf(0)); + test_absolute(0.3, 10, -2.1114622067804823267977785542148302920616046876506, 1e-14, ln_pmf(1)); + test_exact(0.3, 10, -12.039728043259360296301803719337238685164245381839, ln_pmf(10)); + test_exact(1.0, 1, f64::NEG_INFINITY, ln_pmf(0)); + test_exact(1.0, 1, 0.0, ln_pmf(1)); + test_exact(1.0, 3, f64::NEG_INFINITY, ln_pmf(0)); + test_exact(1.0, 3, f64::NEG_INFINITY, ln_pmf(1)); + test_exact(1.0, 3, 0.0, ln_pmf(3)); + test_exact(1.0, 10, f64::NEG_INFINITY, ln_pmf(0)); + test_exact(1.0, 10, f64::NEG_INFINITY, ln_pmf(1)); + test_exact(1.0, 10, 0.0, ln_pmf(10)); } #[test] fn test_cdf() { let cdf = |arg: u64| move |x: Binomial| x.cdf(arg); - test_case(0.0, 1, 1.0, cdf(0)); - test_case(0.0, 1, 1.0, cdf(1)); - test_case(0.0, 3, 1.0, cdf(0)); - test_case(0.0, 3, 1.0, cdf(1)); - test_case(0.0, 3, 1.0, cdf(3)); - test_case(0.0, 10, 1.0, cdf(0)); - test_case(0.0, 10, 1.0, cdf(1)); - test_case(0.0, 10, 1.0, cdf(10)); - test_almost(0.3, 1, 0.7, 1e-15, cdf(0)); - test_case(0.3, 1, 1.0, cdf(1)); - test_almost(0.3, 3, 0.343, 1e-14, cdf(0)); - test_almost(0.3, 3, 0.784, 1e-15, cdf(1)); - test_case(0.3, 3, 1.0, cdf(3)); - test_almost(0.3, 10, 0.0282475249, 1e-16, cdf(0)); - test_almost(0.3, 10, 0.1493083459, 1e-14, cdf(1)); - test_case(0.3, 10, 1.0, cdf(10)); - test_case(1.0, 1, 0.0, cdf(0)); - test_case(1.0, 1, 1.0, cdf(1)); - test_case(1.0, 3, 0.0, cdf(0)); - test_case(1.0, 3, 0.0, cdf(1)); - test_case(1.0, 3, 1.0, cdf(3)); - test_case(1.0, 10, 0.0, cdf(0)); - test_case(1.0, 10, 0.0, cdf(1)); - test_case(1.0, 10, 1.0, cdf(10)); + test_exact(0.0, 1, 1.0, cdf(0)); + test_exact(0.0, 1, 1.0, cdf(1)); + test_exact(0.0, 3, 1.0, cdf(0)); + test_exact(0.0, 3, 1.0, cdf(1)); + test_exact(0.0, 3, 1.0, cdf(3)); + test_exact(0.0, 10, 1.0, cdf(0)); + test_exact(0.0, 10, 1.0, cdf(1)); + test_exact(0.0, 10, 1.0, cdf(10)); + test_absolute(0.3, 1, 0.7, 1e-15, cdf(0)); + test_exact(0.3, 1, 1.0, cdf(1)); + test_absolute(0.3, 3, 0.343, 1e-14, cdf(0)); + test_absolute(0.3, 3, 0.784, 1e-15, cdf(1)); + test_exact(0.3, 3, 1.0, cdf(3)); + test_absolute(0.3, 10, 0.0282475249, 1e-16, cdf(0)); + test_absolute(0.3, 10, 0.1493083459, 1e-14, cdf(1)); + test_exact(0.3, 10, 1.0, cdf(10)); + test_exact(1.0, 1, 0.0, cdf(0)); + test_exact(1.0, 1, 1.0, cdf(1)); + test_exact(1.0, 3, 0.0, cdf(0)); + test_exact(1.0, 3, 0.0, cdf(1)); + test_exact(1.0, 3, 1.0, cdf(3)); + test_exact(1.0, 10, 0.0, cdf(0)); + test_exact(1.0, 10, 0.0, cdf(1)); + test_exact(1.0, 10, 1.0, cdf(10)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: Binomial| x.sf(arg); - test_case(0.0, 1, 0.0, sf(0)); - test_case(0.0, 1, 0.0, sf(1)); - test_case(0.0, 3, 0.0, sf(0)); - test_case(0.0, 3, 0.0, sf(1)); - test_case(0.0, 3, 0.0, sf(3)); - test_case(0.0, 10, 0.0, sf(0)); - test_case(0.0, 10, 0.0, sf(1)); - test_case(0.0, 10, 0.0, sf(10)); - test_almost(0.3, 1, 0.3, 1e-15, sf(0)); - test_case(0.3, 1, 0.0, sf(1)); - test_almost(0.3, 3, 0.657, 1e-14, sf(0)); - test_almost(0.3, 3, 0.216, 1e-15, sf(1)); - test_case(0.3, 3, 0.0, sf(3)); - test_almost(0.3, 10, 0.9717524751000001, 1e-16, sf(0)); - test_almost(0.3, 10, 0.850691654100002, 1e-14, sf(1)); - test_case(0.3, 10, 0.0, sf(10)); - test_case(1.0, 1, 1.0, sf(0)); - test_case(1.0, 1, 0.0, sf(1)); - test_case(1.0, 3, 1.0, sf(0)); - test_case(1.0, 3, 1.0, sf(1)); - test_case(1.0, 3, 0.0, sf(3)); - test_case(1.0, 10, 1.0, sf(0)); - test_case(1.0, 10, 1.0, sf(1)); - test_case(1.0, 10, 0.0, sf(10)); + test_exact(0.0, 1, 0.0, sf(0)); + test_exact(0.0, 1, 0.0, sf(1)); + test_exact(0.0, 3, 0.0, sf(0)); + test_exact(0.0, 3, 0.0, sf(1)); + test_exact(0.0, 3, 0.0, sf(3)); + test_exact(0.0, 10, 0.0, sf(0)); + test_exact(0.0, 10, 0.0, sf(1)); + test_exact(0.0, 10, 0.0, sf(10)); + test_absolute(0.3, 1, 0.3, 1e-15, sf(0)); + test_exact(0.3, 1, 0.0, sf(1)); + test_absolute(0.3, 3, 0.657, 1e-14, sf(0)); + test_absolute(0.3, 3, 0.216, 1e-15, sf(1)); + test_exact(0.3, 3, 0.0, sf(3)); + test_absolute(0.3, 10, 0.9717524751000001, 1e-16, sf(0)); + test_absolute(0.3, 10, 0.850691654100002, 1e-14, sf(1)); + test_exact(0.3, 10, 0.0, sf(10)); + test_exact(1.0, 1, 1.0, sf(0)); + test_exact(1.0, 1, 0.0, sf(1)); + test_exact(1.0, 3, 1.0, sf(0)); + test_exact(1.0, 3, 1.0, sf(1)); + test_exact(1.0, 3, 0.0, sf(3)); + test_exact(1.0, 10, 1.0, sf(0)); + test_exact(1.0, 10, 1.0, sf(1)); + test_exact(1.0, 10, 0.0, sf(10)); } #[test] fn test_cdf_upper_bound() { let cdf = |arg: u64| move |x: Binomial| x.cdf(arg); - test_case(0.5, 3, 1.0, cdf(5)); + test_exact(0.5, 3, 1.0, cdf(5)); } #[test] fn test_sf_upper_bound() { let sf = |arg: u64| move |x: Binomial| x.sf(arg); - test_case(0.5, 3, 0.0, sf(5)); + test_exact(0.5, 3, 0.0, sf(5)); } #[test] fn test_inverse_cdf() { let invcdf = |arg: f64| move |x: Binomial| x.inverse_cdf(arg); - test_case(0.4, 5, 2, invcdf(0.3456)); + test_exact(0.4, 5, 2, invcdf(0.3456)); // cases in issue #185 - test_case(0.018, 465, 1, invcdf(3.472e-4)); - test_case(0.5, 6, 4, invcdf(0.75)); + test_exact(0.018, 465, 1, invcdf(3.472e-4)); + test_exact(0.5, 6, 4, invcdf(0.75)); } #[test] fn test_cdf_inverse_cdf() { let cdf_invcdf = |arg: u64| move |x: Binomial| x.inverse_cdf(x.cdf(arg)); - test_case(0.3, 10, 3, cdf_invcdf(3)); - test_case(0.3, 10, 4, cdf_invcdf(4)); - test_case(0.5, 6, 4, cdf_invcdf(4)); + test_exact(0.3, 10, 3, cdf_invcdf(3)); + test_exact(0.3, 10, 4, cdf_invcdf(4)); + test_exact(0.5, 6, 4, cdf_invcdf(4)); } #[test] fn test_discrete() { - test::check_discrete_distribution(&try_create(0.3, 5), 5); - test::check_discrete_distribution(&try_create(0.7, 10), 10); + test::check_discrete_distribution(&create_ok(0.3, 5), 5); + test::check_discrete_distribution(&create_ok(0.7, 10), 10); } } diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index 2b5ede9e..79acbe0d 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -191,6 +191,7 @@ pub mod test { /// Allows relative error of up to [`crate::consts::ACC`]. /// /// Panics if `::new` fails. + #[allow(dead_code)] // This is not used by all distributions. fn test_relative($($arg_name: $arg_ty),+, expected: f64, get_fn: F) where F: Fn($dist) -> f64, From 4d043b2e467564c7324c312fbc7db0c1f3ed2f8a Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 19:22:43 +0200 Subject: [PATCH 142/185] Use testing_boiler! for Categorical --- src/distribution/categorical.rs | 148 ++++++++++++-------------------- 1 file changed, 56 insertions(+), 92 deletions(-) diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index 31bccf8b..5d26e7b5 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -352,163 +352,127 @@ fn test_binary_index() { #[rustfmt::skip] #[cfg(test)] mod tests { - use std::fmt::Debug; use crate::statistics::*; use crate::distribution::{Categorical, Discrete, DiscreteCDF}; use crate::distribution::internal::*; + use crate::testing_boiler; - fn try_create(prob_mass: &[f64]) -> Categorical { - let n = Categorical::new(prob_mass); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(prob_mass: &[f64]) { - try_create(prob_mass); - } - - fn bad_create_case(prob_mass: &[f64]) { - let n = Categorical::new(prob_mass); - assert!(n.is_err()); - } - - fn get_value(prob_mass: &[f64], eval: F) -> T - where T: PartialEq + Debug, - F: Fn(Categorical) -> T - { - let n = try_create(prob_mass); - eval(n) - } - - fn test_case(prob_mass: &[f64], expected: T, eval: F) - where T: PartialEq + Debug, - F: Fn(Categorical) -> T - { - let x = get_value(prob_mass, eval); - assert_eq!(expected, x); - } - - fn test_almost(prob_mass: &[f64], expected: f64, acc: f64, eval: F) - where F: Fn(Categorical) -> f64 - { - let x = get_value(prob_mass, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(prob_mass: &[f64]; Categorical); #[test] fn test_create() { - create_case(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]); + create_ok(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]); } #[test] fn test_bad_create() { - bad_create_case(&[-1.0, 1.0]); - bad_create_case(&[0.0, 0.0]); + create_err(&[-1.0, 1.0]); + create_err(&[0.0, 0.0]); } #[test] fn test_mean() { let mean = |x: Categorical| x.mean().unwrap(); - test_case(&[0.0, 0.25, 0.5, 0.25], 2.0, mean); - test_case(&[0.0, 1.0, 2.0, 1.0], 2.0, mean); - test_case(&[0.0, 0.5, 0.5], 1.5, mean); - test_case(&[0.75, 0.25], 0.25, mean); - test_case(&[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 5.0, mean); + test_exact(&[0.0, 0.25, 0.5, 0.25], 2.0, mean); + test_exact(&[0.0, 1.0, 2.0, 1.0], 2.0, mean); + test_exact(&[0.0, 0.5, 0.5], 1.5, mean); + test_exact(&[0.75, 0.25], 0.25, mean); + test_exact(&[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 5.0, mean); } #[test] fn test_variance() { let variance = |x: Categorical| x.variance().unwrap(); - test_case(&[0.0, 0.25, 0.5, 0.25], 0.5, variance); - test_case(&[0.0, 1.0, 2.0, 1.0], 0.5, variance); - test_case(&[0.0, 0.5, 0.5], 0.25, variance); - test_case(&[0.75, 0.25], 0.1875, variance); - test_case(&[1.0, 0.0, 1.0], 1.0, variance); + test_exact(&[0.0, 0.25, 0.5, 0.25], 0.5, variance); + test_exact(&[0.0, 1.0, 2.0, 1.0], 0.5, variance); + test_exact(&[0.0, 0.5, 0.5], 0.25, variance); + test_exact(&[0.75, 0.25], 0.1875, variance); + test_exact(&[1.0, 0.0, 1.0], 1.0, variance); } #[test] fn test_entropy() { let entropy = |x: Categorical| x.entropy().unwrap(); - test_case(&[0.0, 1.0], 0.0, entropy); - test_almost(&[0.0, 1.0, 1.0], 2f64.ln(), 1e-15, entropy); - test_almost(&[1.0, 1.0, 1.0], 3f64.ln(), 1e-15, entropy); - test_almost(&vec![1.0; 100], 100f64.ln(), 1e-14, entropy); - test_almost(&[0.0, 0.25, 0.5, 0.25], 1.0397207708399179, 1e-15, entropy); + test_exact(&[0.0, 1.0], 0.0, entropy); + test_absolute(&[0.0, 1.0, 1.0], 2f64.ln(), 1e-15, entropy); + test_absolute(&[1.0, 1.0, 1.0], 3f64.ln(), 1e-15, entropy); + test_absolute(&vec![1.0; 100], 100f64.ln(), 1e-14, entropy); + test_absolute(&[0.0, 0.25, 0.5, 0.25], 1.0397207708399179, 1e-15, entropy); } #[test] fn test_median() { let median = |x: Categorical| x.median(); - test_case(&[0.0, 3.0, 1.0, 1.0], 1.0, median); - test_case(&[4.0, 2.5, 2.5, 1.0], 1.0, median); + test_exact(&[0.0, 3.0, 1.0, 1.0], 1.0, median); + test_exact(&[4.0, 2.5, 2.5, 1.0], 1.0, median); } #[test] fn test_min_max() { let min = |x: Categorical| x.min(); let max = |x: Categorical| x.max(); - test_case(&[4.0, 2.5, 2.5, 1.0], 0, min); - test_case(&[4.0, 2.5, 2.5, 1.0], 3, max); + test_exact(&[4.0, 2.5, 2.5, 1.0], 0, min); + test_exact(&[4.0, 2.5, 2.5, 1.0], 3, max); } #[test] fn test_pmf() { let pmf = |arg: u64| move |x: Categorical| x.pmf(arg); - test_case(&[0.0, 0.25, 0.5, 0.25], 0.0, pmf(0)); - test_case(&[0.0, 0.25, 0.5, 0.25], 0.25, pmf(1)); - test_case(&[0.0, 0.25, 0.5, 0.25], 0.25, pmf(3)); + test_exact(&[0.0, 0.25, 0.5, 0.25], 0.0, pmf(0)); + test_exact(&[0.0, 0.25, 0.5, 0.25], 0.25, pmf(1)); + test_exact(&[0.0, 0.25, 0.5, 0.25], 0.25, pmf(3)); } #[test] fn test_pmf_x_too_high() { let pmf = |arg: u64| move |x: Categorical| x.pmf(arg); - test_case(&[4.0, 2.5, 2.5, 1.0], 0.0, pmf(4)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 0.0, pmf(4)); } #[test] fn test_ln_pmf() { let ln_pmf = |arg: u64| move |x: Categorical| x.ln_pmf(arg); - test_case(&[0.0, 0.25, 0.5, 0.25], 0f64.ln(), ln_pmf(0)); - test_case(&[0.0, 0.25, 0.5, 0.25], 0.25f64.ln(), ln_pmf(1)); - test_case(&[0.0, 0.25, 0.5, 0.25], 0.25f64.ln(), ln_pmf(3)); + test_exact(&[0.0, 0.25, 0.5, 0.25], 0f64.ln(), ln_pmf(0)); + test_exact(&[0.0, 0.25, 0.5, 0.25], 0.25f64.ln(), ln_pmf(1)); + test_exact(&[0.0, 0.25, 0.5, 0.25], 0.25f64.ln(), ln_pmf(3)); } #[test] fn test_ln_pmf_x_too_high() { let ln_pmf = |arg: u64| move |x: Categorical| x.ln_pmf(arg); - test_case(&[4.0, 2.5, 2.5, 1.0], f64::NEG_INFINITY, ln_pmf(4)); + test_exact(&[4.0, 2.5, 2.5, 1.0], f64::NEG_INFINITY, ln_pmf(4)); } #[test] fn test_cdf() { let cdf = |arg: u64| move |x: Categorical| x.cdf(arg); - test_case(&[0.0, 3.0, 1.0, 1.0], 3.0 / 5.0, cdf(1)); - test_case(&[1.0, 1.0, 1.0, 1.0], 0.25, cdf(0)); - test_case(&[4.0, 2.5, 2.5, 1.0], 0.4, cdf(0)); - test_case(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(3)); - test_case(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(4)); + test_exact(&[0.0, 3.0, 1.0, 1.0], 3.0 / 5.0, cdf(1)); + test_exact(&[1.0, 1.0, 1.0, 1.0], 0.25, cdf(0)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 0.4, cdf(0)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(3)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(4)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: Categorical| x.sf(arg); - test_case(&[0.0, 3.0, 1.0, 1.0], 2.0 / 5.0, sf(1)); - test_case(&[1.0, 1.0, 1.0, 1.0], 0.75, sf(0)); - test_case(&[4.0, 2.5, 2.5, 1.0], 0.6, sf(0)); - test_case(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(3)); - test_case(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(4)); + test_exact(&[0.0, 3.0, 1.0, 1.0], 2.0 / 5.0, sf(1)); + test_exact(&[1.0, 1.0, 1.0, 1.0], 0.75, sf(0)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 0.6, sf(0)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(3)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(4)); } #[test] fn test_cdf_input_high() { let cdf = |arg: u64| move |x: Categorical| x.cdf(arg); - test_case(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(4)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(4)); } #[test] fn test_sf_input_high() { let sf = |arg: u64| move |x: Categorical| x.sf(arg); - test_case(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(4)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(4)); } #[test] @@ -524,31 +488,31 @@ mod tests { #[test] fn test_inverse_cdf() { let inverse_cdf = |arg: f64| move |x: Categorical| x.inverse_cdf(arg); - test_case(&[0.0, 3.0, 1.0, 1.0], 1, inverse_cdf(0.2)); - test_case(&[0.0, 3.0, 1.0, 1.0], 1, inverse_cdf(0.5)); - test_case(&[0.0, 3.0, 1.0, 1.0], 3, inverse_cdf(0.95)); - test_case(&[4.0, 2.5, 2.5, 1.0], 0, inverse_cdf(0.2)); - test_case(&[4.0, 2.5, 2.5, 1.0], 1, inverse_cdf(0.5)); - test_case(&[4.0, 2.5, 2.5, 1.0], 3, inverse_cdf(0.95)); + test_exact(&[0.0, 3.0, 1.0, 1.0], 1, inverse_cdf(0.2)); + test_exact(&[0.0, 3.0, 1.0, 1.0], 1, inverse_cdf(0.5)); + test_exact(&[0.0, 3.0, 1.0, 1.0], 3, inverse_cdf(0.95)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 0, inverse_cdf(0.2)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 1, inverse_cdf(0.5)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 3, inverse_cdf(0.95)); } #[test] #[should_panic] fn test_inverse_cdf_input_low() { - let inverse_cdf = |arg: f64| move |x: Categorical| x.inverse_cdf(arg); - get_value(&[4.0, 2.5, 2.5, 1.0], inverse_cdf(0.0)); + let dist = create_ok(&[4.0, 2.5, 2.5, 1.0]); + dist.inverse_cdf(0.0); } #[test] #[should_panic] fn test_inverse_cdf_input_high() { - let inverse_cdf = |arg: f64| move |x: Categorical| x.inverse_cdf(arg); - get_value(&[4.0, 2.5, 2.5, 1.0], inverse_cdf(1.0)); + let dist = create_ok(&[4.0, 2.5, 2.5, 1.0]); + dist.inverse_cdf(1.0); } #[test] fn test_discrete() { - test::check_discrete_distribution(&try_create(&[1.0, 2.0, 3.0, 4.0]), 4); - test::check_discrete_distribution(&try_create(&[0.0, 1.0, 2.0, 3.0, 4.0]), 5); + test::check_discrete_distribution(&create_ok(&[1.0, 2.0, 3.0, 4.0]), 4); + test::check_discrete_distribution(&create_ok(&[0.0, 1.0, 2.0, 3.0, 4.0]), 5); } } From 606e670d67523a673a3c8b75fd92d9d9b1d50e6a Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 19:24:45 +0200 Subject: [PATCH 143/185] Use testing_boiler! for Cauchy --- src/distribution/cauchy.rs | 371 +++++++++++++++++-------------------- 1 file changed, 170 insertions(+), 201 deletions(-) diff --git a/src/distribution/cauchy.rs b/src/distribution/cauchy.rs index 078f7635..eb983847 100644 --- a/src/distribution/cauchy.rs +++ b/src/distribution/cauchy.rs @@ -252,266 +252,235 @@ impl Continuous for Cauchy { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::statistics::*; + use crate::{statistics::*, testing_boiler}; use crate::distribution::{ContinuousCDF, Continuous, Cauchy}; use crate::distribution::internal::*; - fn try_create(location: f64, scale: f64) -> Cauchy { - let n = Cauchy::new(location, scale); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(location: f64, scale: f64) { - let n = try_create(location, scale); - assert_eq!(location, n.location()); - assert_eq!(scale, n.scale()); - } - - fn bad_create_case(location: f64, scale: f64) { - let n = Cauchy::new(location, scale); - assert!(n.is_err()); - } - - fn test_case(location: f64, scale: f64, expected: f64, eval: F) - where F: Fn(Cauchy) -> f64 - { - let n = try_create(location, scale); - let x = eval(n); - assert_eq!(expected, x); - } - - fn test_almost(location: f64, scale: f64, expected: f64, acc: f64, eval: F) - where F: Fn(Cauchy) -> f64 - { - let n = try_create(location, scale); - let x = eval(n); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(location: f64, scale: f64; Cauchy); #[test] fn test_create() { - create_case(0.0, 0.1); - create_case(0.0, 1.0); - create_case(0.0, 10.0); - create_case(10.0, 11.0); - create_case(-5.0, 100.0); - create_case(0.0, f64::INFINITY); + create_ok(0.0, 0.1); + create_ok(0.0, 1.0); + create_ok(0.0, 10.0); + create_ok(10.0, 11.0); + create_ok(-5.0, 100.0); + create_ok(0.0, f64::INFINITY); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN, 1.0); - bad_create_case(1.0, f64::NAN); - bad_create_case(f64::NAN, f64::NAN); - bad_create_case(1.0, 0.0); + create_err(f64::NAN, 1.0); + create_err(1.0, f64::NAN); + create_err(f64::NAN, f64::NAN); + create_err(1.0, 0.0); } #[test] fn test_entropy() { let entropy = |x: Cauchy| x.entropy().unwrap(); - test_case(0.0, 2.0, 3.224171427529236102395, entropy); - test_case(0.1, 4.0, 3.917318608089181411812, entropy); - test_case(1.0, 10.0, 4.833609339963336476996, entropy); - test_case(10.0, 11.0, 4.92891951976766133704, entropy); + test_exact(0.0, 2.0, 3.224171427529236102395, entropy); + test_exact(0.1, 4.0, 3.917318608089181411812, entropy); + test_exact(1.0, 10.0, 4.833609339963336476996, entropy); + test_exact(10.0, 11.0, 4.92891951976766133704, entropy); } #[test] fn test_mode() { let mode = |x: Cauchy| x.mode().unwrap(); - test_case(0.0, 2.0, 0.0, mode); - test_case(0.1, 4.0, 0.1, mode); - test_case(1.0, 10.0, 1.0, mode); - test_case(10.0, 11.0, 10.0, mode); - test_case(0.0, f64::INFINITY, 0.0, mode); + test_exact(0.0, 2.0, 0.0, mode); + test_exact(0.1, 4.0, 0.1, mode); + test_exact(1.0, 10.0, 1.0, mode); + test_exact(10.0, 11.0, 10.0, mode); + test_exact(0.0, f64::INFINITY, 0.0, mode); } #[test] fn test_median() { let median = |x: Cauchy| x.median(); - test_case(0.0, 2.0, 0.0, median); - test_case(0.1, 4.0, 0.1, median); - test_case(1.0, 10.0, 1.0, median); - test_case(10.0, 11.0, 10.0, median); - test_case(0.0, f64::INFINITY, 0.0, median); + test_exact(0.0, 2.0, 0.0, median); + test_exact(0.1, 4.0, 0.1, median); + test_exact(1.0, 10.0, 1.0, median); + test_exact(10.0, 11.0, 10.0, median); + test_exact(0.0, f64::INFINITY, 0.0, median); } #[test] fn test_min_max() { let min = |x: Cauchy| x.min(); let max = |x: Cauchy| x.max(); - test_case(0.0, 1.0, f64::NEG_INFINITY, min); - test_case(0.0, 1.0, f64::INFINITY, max); + test_exact(0.0, 1.0, f64::NEG_INFINITY, min); + test_exact(0.0, 1.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Cauchy| x.pdf(arg); - test_case(0.0, 0.1, 0.001272730452554141029739, pdf(-5.0)); - test_case(0.0, 0.1, 0.03151583031522679916216, pdf(-1.0)); - test_almost(0.0, 0.1, 3.183098861837906715378, 1e-14, pdf(0.0)); - test_case(0.0, 0.1, 0.03151583031522679916216, pdf(1.0)); - test_case(0.0, 0.1, 0.001272730452554141029739, pdf(5.0)); - test_almost(0.0, 1.0, 0.01224268793014579505914, 1e-17, pdf(-5.0)); - test_case(0.0, 1.0, 0.1591549430918953357689, pdf(-1.0)); - test_case(0.0, 1.0, 0.3183098861837906715378, pdf(0.0)); - test_case(0.0, 1.0, 0.1591549430918953357689, pdf(1.0)); - test_almost(0.0, 1.0, 0.01224268793014579505914, 1e-17, pdf(5.0)); - test_case(0.0, 10.0, 0.02546479089470325372302, pdf(-5.0)); - test_case(0.0, 10.0, 0.03151583031522679916216, pdf(-1.0)); - test_case(0.0, 10.0, 0.03183098861837906715378, pdf(0.0)); - test_case(0.0, 10.0, 0.03151583031522679916216, pdf(1.0)); - test_case(0.0, 10.0, 0.02546479089470325372302, pdf(5.0)); - test_case(-5.0, 100.0, 0.003183098861837906715378, pdf(-5.0)); - test_almost(-5.0, 100.0, 0.003178014039374906864395, 1e-17, pdf(-1.0)); - test_case(-5.0, 100.0, 0.003175160959439308444267, pdf(0.0)); - test_case(-5.0, 100.0, 0.003171680810918599756255, pdf(1.0)); - test_almost(-5.0, 100.0, 0.003151583031522679916216, 1e-17, pdf(5.0)); - test_case(0.0, f64::INFINITY, 0.0, pdf(-5.0)); - test_case(0.0, f64::INFINITY, 0.0, pdf(-1.0)); - test_case(0.0, f64::INFINITY, 0.0, pdf(0.0)); - test_case(0.0, f64::INFINITY, 0.0, pdf(1.0)); - test_case(0.0, f64::INFINITY, 0.0, pdf(5.0)); - test_case(f64::INFINITY, 1.0, 0.0, pdf(-5.0)); - test_case(f64::INFINITY, 1.0, 0.0, pdf(-1.0)); - test_case(f64::INFINITY, 1.0, 0.0, pdf(0.0)); - test_case(f64::INFINITY, 1.0, 0.0, pdf(1.0)); - test_case(f64::INFINITY, 1.0, 0.0, pdf(5.0)); + test_exact(0.0, 0.1, 0.001272730452554141029739, pdf(-5.0)); + test_exact(0.0, 0.1, 0.03151583031522679916216, pdf(-1.0)); + test_absolute(0.0, 0.1, 3.183098861837906715378, 1e-14, pdf(0.0)); + test_exact(0.0, 0.1, 0.03151583031522679916216, pdf(1.0)); + test_exact(0.0, 0.1, 0.001272730452554141029739, pdf(5.0)); + test_absolute(0.0, 1.0, 0.01224268793014579505914, 1e-17, pdf(-5.0)); + test_exact(0.0, 1.0, 0.1591549430918953357689, pdf(-1.0)); + test_exact(0.0, 1.0, 0.3183098861837906715378, pdf(0.0)); + test_exact(0.0, 1.0, 0.1591549430918953357689, pdf(1.0)); + test_absolute(0.0, 1.0, 0.01224268793014579505914, 1e-17, pdf(5.0)); + test_exact(0.0, 10.0, 0.02546479089470325372302, pdf(-5.0)); + test_exact(0.0, 10.0, 0.03151583031522679916216, pdf(-1.0)); + test_exact(0.0, 10.0, 0.03183098861837906715378, pdf(0.0)); + test_exact(0.0, 10.0, 0.03151583031522679916216, pdf(1.0)); + test_exact(0.0, 10.0, 0.02546479089470325372302, pdf(5.0)); + test_exact(-5.0, 100.0, 0.003183098861837906715378, pdf(-5.0)); + test_absolute(-5.0, 100.0, 0.003178014039374906864395, 1e-17, pdf(-1.0)); + test_exact(-5.0, 100.0, 0.003175160959439308444267, pdf(0.0)); + test_exact(-5.0, 100.0, 0.003171680810918599756255, pdf(1.0)); + test_absolute(-5.0, 100.0, 0.003151583031522679916216, 1e-17, pdf(5.0)); + test_exact(0.0, f64::INFINITY, 0.0, pdf(-5.0)); + test_exact(0.0, f64::INFINITY, 0.0, pdf(-1.0)); + test_exact(0.0, f64::INFINITY, 0.0, pdf(0.0)); + test_exact(0.0, f64::INFINITY, 0.0, pdf(1.0)); + test_exact(0.0, f64::INFINITY, 0.0, pdf(5.0)); + test_exact(f64::INFINITY, 1.0, 0.0, pdf(-5.0)); + test_exact(f64::INFINITY, 1.0, 0.0, pdf(-1.0)); + test_exact(f64::INFINITY, 1.0, 0.0, pdf(0.0)); + test_exact(f64::INFINITY, 1.0, 0.0, pdf(1.0)); + test_exact(f64::INFINITY, 1.0, 0.0, pdf(5.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Cauchy| x.ln_pdf(arg); - test_case(0.0, 0.1, -6.666590723732973542744, ln_pdf(-5.0)); - test_almost(0.0, 0.1, -3.457265309696613941009, 1e-14, ln_pdf(-1.0)); - test_case(0.0, 0.1, 1.157855207144645509875, ln_pdf(0.0)); - test_almost(0.0, 0.1, -3.457265309696613941009, 1e-14, ln_pdf(1.0)); - test_case(0.0, 0.1, -6.666590723732973542744, ln_pdf(5.0)); - test_case(0.0, 1.0, -4.402826423870882219615, ln_pdf(-5.0)); - test_almost(0.0, 1.0, -1.837877066409345483561, 1e-15, ln_pdf(-1.0)); - test_case(0.0, 1.0, -1.144729885849400174143, ln_pdf(0.0)); - test_almost(0.0, 1.0, -1.837877066409345483561, 1e-15, ln_pdf(1.0)); - test_case(0.0, 1.0, -4.402826423870882219615, ln_pdf(5.0)); - test_case(0.0, 10.0, -3.670458530157655613928, ln_pdf(-5.0)); - test_almost(0.0, 10.0, -3.457265309696613941009, 1e-14, ln_pdf(-1.0)); - test_case(0.0, 10.0, -3.447314978843445858161, ln_pdf(0.0)); - test_almost(0.0, 10.0, -3.457265309696613941009, 1e-14, ln_pdf(1.0)); - test_case(0.0, 10.0, -3.670458530157655613928, ln_pdf(5.0)); - test_case(-5.0, 100.0, -5.749900071837491542179, ln_pdf(-5.0)); - test_case(-5.0, 100.0, -5.751498793201188569872, ln_pdf(-1.0)); - test_case(-5.0, 100.0, -5.75239695203607874116, ln_pdf(0.0)); - test_case(-5.0, 100.0, -5.75349360734762171285, ln_pdf(1.0)); - test_case(-5.0, 100.0, -5.759850402690659625027, ln_pdf(5.0)); - test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-5.0)); - test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-1.0)); - test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.0)); - test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(1.0)); - test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(5.0)); - test_case(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(-5.0)); - test_case(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(-1.0)); - test_case(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(0.0)); - test_case(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(1.0)); - test_case(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(5.0)); + test_exact(0.0, 0.1, -6.666590723732973542744, ln_pdf(-5.0)); + test_absolute(0.0, 0.1, -3.457265309696613941009, 1e-14, ln_pdf(-1.0)); + test_exact(0.0, 0.1, 1.157855207144645509875, ln_pdf(0.0)); + test_absolute(0.0, 0.1, -3.457265309696613941009, 1e-14, ln_pdf(1.0)); + test_exact(0.0, 0.1, -6.666590723732973542744, ln_pdf(5.0)); + test_exact(0.0, 1.0, -4.402826423870882219615, ln_pdf(-5.0)); + test_absolute(0.0, 1.0, -1.837877066409345483561, 1e-15, ln_pdf(-1.0)); + test_exact(0.0, 1.0, -1.144729885849400174143, ln_pdf(0.0)); + test_absolute(0.0, 1.0, -1.837877066409345483561, 1e-15, ln_pdf(1.0)); + test_exact(0.0, 1.0, -4.402826423870882219615, ln_pdf(5.0)); + test_exact(0.0, 10.0, -3.670458530157655613928, ln_pdf(-5.0)); + test_absolute(0.0, 10.0, -3.457265309696613941009, 1e-14, ln_pdf(-1.0)); + test_exact(0.0, 10.0, -3.447314978843445858161, ln_pdf(0.0)); + test_absolute(0.0, 10.0, -3.457265309696613941009, 1e-14, ln_pdf(1.0)); + test_exact(0.0, 10.0, -3.670458530157655613928, ln_pdf(5.0)); + test_exact(-5.0, 100.0, -5.749900071837491542179, ln_pdf(-5.0)); + test_exact(-5.0, 100.0, -5.751498793201188569872, ln_pdf(-1.0)); + test_exact(-5.0, 100.0, -5.75239695203607874116, ln_pdf(0.0)); + test_exact(-5.0, 100.0, -5.75349360734762171285, ln_pdf(1.0)); + test_exact(-5.0, 100.0, -5.759850402690659625027, ln_pdf(5.0)); + test_exact(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-5.0)); + test_exact(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-1.0)); + test_exact(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.0)); + test_exact(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(1.0)); + test_exact(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(5.0)); + test_exact(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(-5.0)); + test_exact(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(-1.0)); + test_exact(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(0.0)); + test_exact(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(1.0)); + test_exact(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(5.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Cauchy| x.cdf(arg); - test_almost(0.0, 0.1, 0.006365349100972796679298, 1e-16, cdf(-5.0)); - test_almost(0.0, 0.1, 0.03172551743055356951498, 1e-16, cdf(-1.0)); - test_case(0.0, 0.1, 0.5, cdf(0.0)); - test_case(0.0, 0.1, 0.968274482569446430485, cdf(1.0)); - test_case(0.0, 0.1, 0.9936346508990272033207, cdf(5.0)); - test_almost(0.0, 1.0, 0.06283295818900118381375, 1e-16, cdf(-5.0)); - test_case(0.0, 1.0, 0.25, cdf(-1.0)); - test_case(0.0, 1.0, 0.5, cdf(0.0)); - test_case(0.0, 1.0, 0.75, cdf(1.0)); - test_case(0.0, 1.0, 0.9371670418109988161863, cdf(5.0)); - test_case(0.0, 10.0, 0.3524163823495667258246, cdf(-5.0)); - test_case(0.0, 10.0, 0.468274482569446430485, cdf(-1.0)); - test_case(0.0, 10.0, 0.5, cdf(0.0)); - test_case(0.0, 10.0, 0.531725517430553569515, cdf(1.0)); - test_case(0.0, 10.0, 0.6475836176504332741754, cdf(5.0)); - test_case(-5.0, 100.0, 0.5, cdf(-5.0)); - test_case(-5.0, 100.0, 0.5127256113479918307809, cdf(-1.0)); - test_case(-5.0, 100.0, 0.5159022512561763751816, cdf(0.0)); - test_case(-5.0, 100.0, 0.5190757242358362337495, cdf(1.0)); - test_case(-5.0, 100.0, 0.531725517430553569515, cdf(5.0)); - test_case(0.0, f64::INFINITY, 0.5, cdf(-5.0)); - test_case(0.0, f64::INFINITY, 0.5, cdf(-1.0)); - test_case(0.0, f64::INFINITY, 0.5, cdf(0.0)); - test_case(0.0, f64::INFINITY, 0.5, cdf(1.0)); - test_case(0.0, f64::INFINITY, 0.5, cdf(5.0)); - test_case(f64::INFINITY, 1.0, 0.0, cdf(-5.0)); - test_case(f64::INFINITY, 1.0, 0.0, cdf(-1.0)); - test_case(f64::INFINITY, 1.0, 0.0, cdf(0.0)); - test_case(f64::INFINITY, 1.0, 0.0, cdf(1.0)); - test_case(f64::INFINITY, 1.0, 0.0, cdf(5.0)); + test_absolute(0.0, 0.1, 0.006365349100972796679298, 1e-16, cdf(-5.0)); + test_absolute(0.0, 0.1, 0.03172551743055356951498, 1e-16, cdf(-1.0)); + test_exact(0.0, 0.1, 0.5, cdf(0.0)); + test_exact(0.0, 0.1, 0.968274482569446430485, cdf(1.0)); + test_exact(0.0, 0.1, 0.9936346508990272033207, cdf(5.0)); + test_absolute(0.0, 1.0, 0.06283295818900118381375, 1e-16, cdf(-5.0)); + test_exact(0.0, 1.0, 0.25, cdf(-1.0)); + test_exact(0.0, 1.0, 0.5, cdf(0.0)); + test_exact(0.0, 1.0, 0.75, cdf(1.0)); + test_exact(0.0, 1.0, 0.9371670418109988161863, cdf(5.0)); + test_exact(0.0, 10.0, 0.3524163823495667258246, cdf(-5.0)); + test_exact(0.0, 10.0, 0.468274482569446430485, cdf(-1.0)); + test_exact(0.0, 10.0, 0.5, cdf(0.0)); + test_exact(0.0, 10.0, 0.531725517430553569515, cdf(1.0)); + test_exact(0.0, 10.0, 0.6475836176504332741754, cdf(5.0)); + test_exact(-5.0, 100.0, 0.5, cdf(-5.0)); + test_exact(-5.0, 100.0, 0.5127256113479918307809, cdf(-1.0)); + test_exact(-5.0, 100.0, 0.5159022512561763751816, cdf(0.0)); + test_exact(-5.0, 100.0, 0.5190757242358362337495, cdf(1.0)); + test_exact(-5.0, 100.0, 0.531725517430553569515, cdf(5.0)); + test_exact(0.0, f64::INFINITY, 0.5, cdf(-5.0)); + test_exact(0.0, f64::INFINITY, 0.5, cdf(-1.0)); + test_exact(0.0, f64::INFINITY, 0.5, cdf(0.0)); + test_exact(0.0, f64::INFINITY, 0.5, cdf(1.0)); + test_exact(0.0, f64::INFINITY, 0.5, cdf(5.0)); + test_exact(f64::INFINITY, 1.0, 0.0, cdf(-5.0)); + test_exact(f64::INFINITY, 1.0, 0.0, cdf(-1.0)); + test_exact(f64::INFINITY, 1.0, 0.0, cdf(0.0)); + test_exact(f64::INFINITY, 1.0, 0.0, cdf(1.0)); + test_exact(f64::INFINITY, 1.0, 0.0, cdf(5.0)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Cauchy| x.sf(arg); - test_almost(0.0, 0.1, 0.9936346508990272, 1e-16, sf(-5.0)); - test_almost(0.0, 0.1, 0.9682744825694465, 1e-16, sf(-1.0)); - test_case(0.0, 0.1, 0.5, sf(0.0)); - test_almost(0.0, 0.1, 0.03172551743055352, 1e-16, sf(1.0)); - test_case(0.0, 0.1, 0.006365349100972806, sf(5.0)); - test_almost(0.0, 1.0, 0.9371670418109989, 1e-16, sf(-5.0)); - test_case(0.0, 1.0, 0.75, sf(-1.0)); - test_case(0.0, 1.0, 0.5, sf(0.0)); - test_case(0.0, 1.0, 0.25, sf(1.0)); - test_case(0.0, 1.0, 0.06283295818900114, sf(5.0)); - test_case(0.0, 10.0, 0.6475836176504333, sf(-5.0)); - test_case(0.0, 10.0, 0.5317255174305535, sf(-1.0)); - test_case(0.0, 10.0, 0.5, sf(0.0)); - test_case(0.0, 10.0, 0.4682744825694464, sf(1.0)); - test_case(0.0, 10.0, 0.35241638234956674, sf(5.0)); - test_case(-5.0, 100.0, 0.5, sf(-5.0)); - test_case(-5.0, 100.0, 0.4872743886520082, sf(-1.0)); - test_case(-5.0, 100.0, 0.4840977487438236, sf(0.0)); - test_case(-5.0, 100.0, 0.48092427576416374, sf(1.0)); - test_case(-5.0, 100.0, 0.4682744825694464, sf(5.0)); - test_case(0.0, f64::INFINITY, 0.5, sf(-5.0)); - test_case(0.0, f64::INFINITY, 0.5, sf(-1.0)); - test_case(0.0, f64::INFINITY, 0.5, sf(0.0)); - test_case(0.0, f64::INFINITY, 0.5, sf(1.0)); - test_case(0.0, f64::INFINITY, 0.5, sf(5.0)); - test_case(f64::INFINITY, 1.0, 1.0, sf(-5.0)); - test_case(f64::INFINITY, 1.0, 1.0, sf(-1.0)); - test_case(f64::INFINITY, 1.0, 1.0, sf(0.0)); - test_case(f64::INFINITY, 1.0, 1.0, sf(1.0)); - test_case(f64::INFINITY, 1.0, 1.0, sf(5.0)); + test_absolute(0.0, 0.1, 0.9936346508990272, 1e-16, sf(-5.0)); + test_absolute(0.0, 0.1, 0.9682744825694465, 1e-16, sf(-1.0)); + test_exact(0.0, 0.1, 0.5, sf(0.0)); + test_absolute(0.0, 0.1, 0.03172551743055352, 1e-16, sf(1.0)); + test_exact(0.0, 0.1, 0.006365349100972806, sf(5.0)); + test_absolute(0.0, 1.0, 0.9371670418109989, 1e-16, sf(-5.0)); + test_exact(0.0, 1.0, 0.75, sf(-1.0)); + test_exact(0.0, 1.0, 0.5, sf(0.0)); + test_exact(0.0, 1.0, 0.25, sf(1.0)); + test_exact(0.0, 1.0, 0.06283295818900114, sf(5.0)); + test_exact(0.0, 10.0, 0.6475836176504333, sf(-5.0)); + test_exact(0.0, 10.0, 0.5317255174305535, sf(-1.0)); + test_exact(0.0, 10.0, 0.5, sf(0.0)); + test_exact(0.0, 10.0, 0.4682744825694464, sf(1.0)); + test_exact(0.0, 10.0, 0.35241638234956674, sf(5.0)); + test_exact(-5.0, 100.0, 0.5, sf(-5.0)); + test_exact(-5.0, 100.0, 0.4872743886520082, sf(-1.0)); + test_exact(-5.0, 100.0, 0.4840977487438236, sf(0.0)); + test_exact(-5.0, 100.0, 0.48092427576416374, sf(1.0)); + test_exact(-5.0, 100.0, 0.4682744825694464, sf(5.0)); + test_exact(0.0, f64::INFINITY, 0.5, sf(-5.0)); + test_exact(0.0, f64::INFINITY, 0.5, sf(-1.0)); + test_exact(0.0, f64::INFINITY, 0.5, sf(0.0)); + test_exact(0.0, f64::INFINITY, 0.5, sf(1.0)); + test_exact(0.0, f64::INFINITY, 0.5, sf(5.0)); + test_exact(f64::INFINITY, 1.0, 1.0, sf(-5.0)); + test_exact(f64::INFINITY, 1.0, 1.0, sf(-1.0)); + test_exact(f64::INFINITY, 1.0, 1.0, sf(0.0)); + test_exact(f64::INFINITY, 1.0, 1.0, sf(1.0)); + test_exact(f64::INFINITY, 1.0, 1.0, sf(5.0)); } #[test] fn test_inverse_cdf() { let func = |arg: f64| move |x: Cauchy| x.inverse_cdf(x.cdf(arg)); - test_almost(0.0, 0.1, -5.0, 1e-10, func(-5.0)); - test_almost(0.0, 0.1, -1.0, 1e-14, func(-1.0)); - test_case(0.0, 0.1, 0.0, func(0.0)); - test_almost(0.0, 0.1, 1.0, 1e-14, func(1.0)); - test_almost(0.0, 0.1, 5.0, 1e-10, func(5.0)); - test_almost(0.0, 1.0, -5.0, 1e-14, func(-5.0)); - test_almost(0.0, 1.0, -1.0, 1e-15, func(-1.0)); - test_case(0.0, 1.0, 0.0, func(0.0)); - test_almost(0.0, 1.0, 1.0, 1e-15, func(1.0)); - test_almost(0.0, 1.0, 5.0, 1e-14, func(5.0)); - test_almost(0.0, 10.0, -5.0, 1e-14, func(-5.0)); - test_almost(0.0, 10.0, -1.0, 1e-14, func(-1.0)); - test_case(0.0, 10.0, 0.0, func(0.0)); - test_almost(0.0, 10.0, 1.0, 1e-14, func(1.0)); - test_almost(0.0, 10.0, 5.0, 1e-14, func(5.0)); - test_case(-5.0, 100.0, -5.0, func(-5.0)); - test_almost(-5.0, 100.0, -1.0, 1e-10, func(-1.0)); - test_almost(-5.0, 100.0, 0.0, 1e-14, func(0.0)); - test_almost(-5.0, 100.0, 1.0, 1e-14, func(1.0)); - test_almost(-5.0, 100.0, 5.0, 1e-10, func(5.0)); + test_absolute(0.0, 0.1, -5.0, 1e-10, func(-5.0)); + test_absolute(0.0, 0.1, -1.0, 1e-14, func(-1.0)); + test_exact(0.0, 0.1, 0.0, func(0.0)); + test_absolute(0.0, 0.1, 1.0, 1e-14, func(1.0)); + test_absolute(0.0, 0.1, 5.0, 1e-10, func(5.0)); + test_absolute(0.0, 1.0, -5.0, 1e-14, func(-5.0)); + test_absolute(0.0, 1.0, -1.0, 1e-15, func(-1.0)); + test_exact(0.0, 1.0, 0.0, func(0.0)); + test_absolute(0.0, 1.0, 1.0, 1e-15, func(1.0)); + test_absolute(0.0, 1.0, 5.0, 1e-14, func(5.0)); + test_absolute(0.0, 10.0, -5.0, 1e-14, func(-5.0)); + test_absolute(0.0, 10.0, -1.0, 1e-14, func(-1.0)); + test_exact(0.0, 10.0, 0.0, func(0.0)); + test_absolute(0.0, 10.0, 1.0, 1e-14, func(1.0)); + test_absolute(0.0, 10.0, 5.0, 1e-14, func(5.0)); + test_exact(-5.0, 100.0, -5.0, func(-5.0)); + test_absolute(-5.0, 100.0, -1.0, 1e-10, func(-1.0)); + test_absolute(-5.0, 100.0, 0.0, 1e-14, func(0.0)); + test_absolute(-5.0, 100.0, 1.0, 1e-14, func(1.0)); + test_absolute(-5.0, 100.0, 5.0, 1e-10, func(5.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(-1.2, 3.4), -1500.0, 1500.0); - test::check_continuous_distribution(&try_create(-4.5, 6.7), -5000.0, 5000.0); + test::check_continuous_distribution(&create_ok(-1.2, 3.4), -1500.0, 1500.0); + test::check_continuous_distribution(&create_ok(-4.5, 6.7), -5000.0, 5000.0); } } From d42a04c352c9c68014b9a9d1e83ef09281d86a54 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 19:26:59 +0200 Subject: [PATCH 144/185] Use testing_boiler! for ChiSquared --- src/distribution/chi_squared.rs | 39 +++++++++------------------------ src/distribution/internal.rs | 9 ++++---- 2 files changed, 15 insertions(+), 33 deletions(-) diff --git a/src/distribution/chi_squared.rs b/src/distribution/chi_squared.rs index 8ead2314..afa5df71 100644 --- a/src/distribution/chi_squared.rs +++ b/src/distribution/chi_squared.rs @@ -309,44 +309,25 @@ mod tests { use crate::statistics::Median; use crate::distribution::ChiSquared; use crate::distribution::internal::*; + use crate::testing_boiler; - fn try_create(freedom: f64) -> ChiSquared { - let n = ChiSquared::new(freedom); - assert!(n.is_ok()); - n.unwrap() - } - - fn test_case(freedom: f64, expected: f64, eval: F) - where F: Fn(ChiSquared) -> f64 - { - let n = try_create(freedom); - let x = eval(n); - assert_eq!(expected, x); - } - - fn test_almost(freedom: f64, expected: f64, acc: f64, eval: F) - where F: Fn(ChiSquared) -> f64 - { - let n = try_create(freedom); - let x = eval(n); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(freedom: f64; ChiSquared); #[test] fn test_median() { let median = |x: ChiSquared| x.median(); - test_almost(0.5, 0.0857338820301783264746, 1e-16, median); - test_case(1.0, 1.0 - 2.0 / 3.0, median); - test_case(2.0, 2.0 - 2.0 / 3.0, median); - test_case(2.5, 2.5 - 2.0 / 3.0, median); - test_case(3.0, 3.0 - 2.0 / 3.0, median); + test_absolute(0.5, 0.0857338820301783264746, 1e-16, median); + test_exact(1.0, 1.0 - 2.0 / 3.0, median); + test_exact(2.0, 2.0 - 2.0 / 3.0, median); + test_exact(2.5, 2.5 - 2.0 / 3.0, median); + test_exact(3.0, 3.0 - 2.0 / 3.0, median); } #[test] fn test_continuous() { // TODO: figure out why this test fails: - //test::check_continuous_distribution(&try_create(1.0), 0.0, 10.0); - test::check_continuous_distribution(&try_create(2.0), 0.0, 10.0); - test::check_continuous_distribution(&try_create(5.0), 0.0, 50.0); + //test::check_continuous_distribution(&create_ok(1.0), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(2.0), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(5.0), 0.0, 50.0); } } diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index 79acbe0d..a21e9330 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -139,6 +139,7 @@ pub mod test { /// Returns the error when creating a distribution with the given parameters, /// panicking if `::new` succeeds. + #[allow(dead_code)] fn create_err($($arg_name: $arg_ty),+) -> $crate::StatsError { match <$dist>::new($($arg_name),+) { Err(e) => e, @@ -168,7 +169,7 @@ pub mod test { /// to `expected` exactly. /// /// Panics if `::new` fails. - #[allow(dead_code)] // This is not used by all distributions. + #[allow(dead_code)] fn test_exact($($arg_name: $arg_ty),+, expected: T, get_fn: F) where F: Fn($dist) -> T, @@ -191,7 +192,7 @@ pub mod test { /// Allows relative error of up to [`crate::consts::ACC`]. /// /// Panics if `::new` fails. - #[allow(dead_code)] // This is not used by all distributions. + #[allow(dead_code)] fn test_relative($($arg_name: $arg_ty),+, expected: f64, get_fn: F) where F: Fn($dist) -> f64, @@ -216,7 +217,7 @@ pub mod test { /// Allows absolute error of up to `acc`. /// /// Panics if `::new` fails. - #[allow(dead_code)] // This is not used by all distributions. + #[allow(dead_code)] fn test_absolute($($arg_name: $arg_ty),+, expected: f64, acc: f64, get_fn: F) where F: Fn($dist) -> f64, @@ -243,7 +244,7 @@ pub mod test { /// and asserts that it is [`None`]. /// /// Panics if `::new` fails. - #[allow(dead_code)] // This is not used by all distributions. + #[allow(dead_code)] fn test_none($($arg_name: $arg_ty),+, get_fn: F) where F: Fn($dist) -> Option, From 96a65114d33d60e85de2f9fdf7bbba74a5251508 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 19:31:44 +0200 Subject: [PATCH 145/185] Use testing_boiler! for Chi --- src/distribution/chi.rs | 319 +++++++++++++++++----------------------- 1 file changed, 136 insertions(+), 183 deletions(-) diff --git a/src/distribution/chi.rs b/src/distribution/chi.rs index 205cca11..1bb74295 100644 --- a/src/distribution/chi.rs +++ b/src/distribution/chi.rs @@ -329,71 +329,34 @@ mod tests { use crate::distribution::internal::*; use crate::distribution::{Chi, Continuous, ContinuousCDF}; use crate::statistics::*; + use crate::testing_boiler; - fn try_create(freedom: f64) -> Chi { - let n = Chi::new(freedom); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(freedom: f64) { - let n = try_create(freedom); - assert_eq!(freedom, n.freedom()); - } - - fn bad_create_case(freedom: f64) { - let n = Chi::new(freedom); - assert!(n.is_err()); - } - - fn get_value(freedom: f64, eval: F) -> f64 - where - F: Fn(Chi) -> f64, - { - let n = try_create(freedom); - eval(n) - } - - fn test_case(freedom: f64, expected: f64, eval: F) - where - F: Fn(Chi) -> f64, - { - let x = get_value(freedom, eval); - assert_eq!(expected, x); - } - - fn test_almost(freedom: f64, expected: f64, acc: f64, eval: F) - where - F: Fn(Chi) -> f64, - { - let x = get_value(freedom, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(freedom: f64; Chi); #[test] fn test_create() { - create_case(1.0); - create_case(3.0); - create_case(f64::INFINITY); + create_ok(1.0); + create_ok(3.0); + create_ok(f64::INFINITY); } #[test] fn test_bad_create() { - bad_create_case(0.0); - bad_create_case(-1.0); - bad_create_case(-100.0); - bad_create_case(f64::NEG_INFINITY); - bad_create_case(f64::NAN); + create_err(0.0); + create_err(-1.0); + create_err(-100.0); + create_err(f64::NEG_INFINITY); + create_err(f64::NAN); } #[test] fn test_mean() { let mean = |x: Chi| x.mean().unwrap(); - test_almost(1.0, 0.7978845608028653558799, 1e-15, mean); - test_almost(2.0, 1.25331413731550025121, 1e-14, mean); - test_almost(2.5, 1.43396639245837498609, 1e-14, mean); - test_almost(5.0, 2.12769216214097428235, 1e-14, mean); - test_almost(336.0, 18.31666925443713, 1e-12, mean); + test_absolute(1.0, 0.7978845608028653558799, 1e-15, mean); + test_absolute(2.0, 1.25331413731550025121, 1e-14, mean); + test_absolute(2.5, 1.43396639245837498609, 1e-14, mean); + test_absolute(5.0, 2.12769216214097428235, 1e-14, mean); + test_absolute(336.0, 18.31666925443713, 1e-12, mean); } #[test] @@ -405,223 +368,213 @@ mod tests { } #[test] - #[should_panic] fn test_mean_degen() { - let mean = |x: Chi| x.mean().unwrap(); - get_value(f64::INFINITY, mean); + test_none(f64::INFINITY, |dist| dist.mean()); } #[test] fn test_variance() { let variance = |x: Chi| x.variance().unwrap(); - test_almost(1.0, 0.3633802276324186569245, 1e-15, variance); - test_almost(2.0, 0.42920367320510338077, 1e-14, variance); - test_almost(2.5, 0.44374038529991368581, 1e-13, variance); - test_almost(3.0, 0.4535209105296746277, 1e-14, variance); + test_absolute(1.0, 0.3633802276324186569245, 1e-15, variance); + test_absolute(2.0, 0.42920367320510338077, 1e-14, variance); + test_absolute(2.5, 0.44374038529991368581, 1e-13, variance); + test_absolute(3.0, 0.4535209105296746277, 1e-14, variance); } #[test] - #[should_panic] fn test_variance_degen() { - let variance = |x: Chi| x.variance().unwrap(); - get_value(f64::INFINITY, variance); + test_none(f64::INFINITY, |dist| dist.variance()); } #[test] fn test_entropy() { let entropy = |x: Chi| x.entropy().unwrap(); - test_almost(1.0, 0.7257913526447274323631, 1e-15, entropy); - test_almost(2.0, 0.9420342421707937755946, 1e-15, entropy); - test_almost(2.5, 0.97574472333041323989, 1e-14, entropy); - test_almost(3.0, 0.99615419810620560239, 1e-14, entropy); + test_absolute(1.0, 0.7257913526447274323631, 1e-15, entropy); + test_absolute(2.0, 0.9420342421707937755946, 1e-15, entropy); + test_absolute(2.5, 0.97574472333041323989, 1e-14, entropy); + test_absolute(3.0, 0.99615419810620560239, 1e-14, entropy); } #[test] - #[should_panic] fn test_entropy_degen() { - let entropy = |x: Chi| x.entropy().unwrap(); - get_value(f64::INFINITY, entropy); + test_none(f64::INFINITY, |dist| dist.entropy()); } #[test] fn test_skewness() { let skewness = |x: Chi| x.skewness().unwrap(); - test_almost(1.0, 0.995271746431156042444, 1e-14, skewness); - test_almost(2.0, 0.6311106578189371382, 1e-13, skewness); - test_almost(2.5, 0.5458487096285153216, 1e-12, skewness); - test_almost(3.0, 0.485692828049590809, 1e-12, skewness); + test_absolute(1.0, 0.995271746431156042444, 1e-14, skewness); + test_absolute(2.0, 0.6311106578189371382, 1e-13, skewness); + test_absolute(2.5, 0.5458487096285153216, 1e-12, skewness); + test_absolute(3.0, 0.485692828049590809, 1e-12, skewness); } #[test] - #[should_panic] fn test_skewness_degen() { - let skewness = |x: Chi| x.skewness().unwrap(); - get_value(f64::INFINITY, skewness); + test_none(f64::INFINITY, |dist| dist.skewness()); } #[test] fn test_mode() { let mode = |x: Chi| x.mode().unwrap(); - test_case(1.0, 0.0, mode); - test_case(2.0, 1.0, mode); - test_case(2.5, 1.224744871391589049099, mode); - test_case(3.0, f64::consts::SQRT_2, mode); - test_case(f64::INFINITY, f64::INFINITY, mode); + test_exact(1.0, 0.0, mode); + test_exact(2.0, 1.0, mode); + test_exact(2.5, 1.224744871391589049099, mode); + test_exact(3.0, f64::consts::SQRT_2, mode); + test_exact(f64::INFINITY, f64::INFINITY, mode); } #[test] - #[should_panic] fn test_mode_freedom_lt_1() { - let mode = |x: Chi| x.mode().unwrap(); - get_value(0.5, mode); + test_none(0.5, |dist| dist.mode()); } #[test] fn test_min_max() { let min = |x: Chi| x.min(); let max = |x: Chi| x.max(); - test_case(1.0, 0.0, min); - test_case(2.0, 0.0, min); - test_case(2.5, 0.0, min); - test_case(3.0, 0.0, min); - test_case(f64::INFINITY, 0.0, min); - test_case(1.0, f64::INFINITY, max); - test_case(2.0, f64::INFINITY, max); - test_case(2.5, f64::INFINITY, max); - test_case(3.0, f64::INFINITY, max); - test_case(f64::INFINITY, f64::INFINITY, max); + test_exact(1.0, 0.0, min); + test_exact(2.0, 0.0, min); + test_exact(2.5, 0.0, min); + test_exact(3.0, 0.0, min); + test_exact(f64::INFINITY, 0.0, min); + test_exact(1.0, f64::INFINITY, max); + test_exact(2.0, f64::INFINITY, max); + test_exact(2.5, f64::INFINITY, max); + test_exact(3.0, f64::INFINITY, max); + test_exact(f64::INFINITY, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Chi| x.pdf(arg); - test_case(1.0, 0.0, pdf(0.0)); - test_almost(1.0, 0.79390509495402353102, 1e-15, pdf(0.1)); - test_almost(1.0, 0.48394144903828669960, 1e-15, pdf(1.0)); - test_almost(1.0, 2.1539520085086552718e-7, 1e-22, pdf(5.5)); - test_case(1.0, 0.0, pdf(f64::INFINITY)); - test_case(2.0, 0.0, pdf(0.0)); - test_almost(2.0, 0.099501247919268231335, 1e-16, pdf(0.1)); - test_almost(2.0, 0.60653065971263342360, 1e-15, pdf(1.0)); - test_almost(2.0, 1.4847681768496578863e-6, 1e-21, pdf(5.5)); - test_case(2.0, 0.0, pdf(f64::INFINITY)); - test_case(2.5, 0.0, pdf(0.0)); - test_almost(2.5, 0.029191065334961657461, 1e-16, pdf(0.1)); - test_almost(2.5, 0.56269645152636456261, 1e-15, pdf(1.0)); - test_almost(2.5, 3.2304380188895211768e-6, 1e-20, pdf(5.5)); - test_case(2.5, 0.0, pdf(f64::INFINITY)); - test_case(f64::INFINITY, 0.0, pdf(0.0)); - test_case(f64::INFINITY, 0.0, pdf(0.1)); - test_case(f64::INFINITY, 0.0, pdf(1.0)); - test_case(f64::INFINITY, 0.0, pdf(5.5)); - test_case(f64::INFINITY, 0.0, pdf(f64::INFINITY)); - test_almost(170.0, 0.5644678498668440878, 1e-13, pdf(13.0)); + test_exact(1.0, 0.0, pdf(0.0)); + test_absolute(1.0, 0.79390509495402353102, 1e-15, pdf(0.1)); + test_absolute(1.0, 0.48394144903828669960, 1e-15, pdf(1.0)); + test_absolute(1.0, 2.1539520085086552718e-7, 1e-22, pdf(5.5)); + test_exact(1.0, 0.0, pdf(f64::INFINITY)); + test_exact(2.0, 0.0, pdf(0.0)); + test_absolute(2.0, 0.099501247919268231335, 1e-16, pdf(0.1)); + test_absolute(2.0, 0.60653065971263342360, 1e-15, pdf(1.0)); + test_absolute(2.0, 1.4847681768496578863e-6, 1e-21, pdf(5.5)); + test_exact(2.0, 0.0, pdf(f64::INFINITY)); + test_exact(2.5, 0.0, pdf(0.0)); + test_absolute(2.5, 0.029191065334961657461, 1e-16, pdf(0.1)); + test_absolute(2.5, 0.56269645152636456261, 1e-15, pdf(1.0)); + test_absolute(2.5, 3.2304380188895211768e-6, 1e-20, pdf(5.5)); + test_exact(2.5, 0.0, pdf(f64::INFINITY)); + test_exact(f64::INFINITY, 0.0, pdf(0.0)); + test_exact(f64::INFINITY, 0.0, pdf(0.1)); + test_exact(f64::INFINITY, 0.0, pdf(1.0)); + test_exact(f64::INFINITY, 0.0, pdf(5.5)); + test_exact(f64::INFINITY, 0.0, pdf(f64::INFINITY)); + test_absolute(170.0, 0.5644678498668440878, 1e-13, pdf(13.0)); } #[test] fn test_neg_pdf() { let pdf = |arg: f64| move |x: Chi| x.pdf(arg); - test_case(1.0, 0.0, pdf(-1.0)); + test_exact(1.0, 0.0, pdf(-1.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Chi| x.ln_pdf(arg); - test_case(1.0, f64::NEG_INFINITY, ln_pdf(0.0)); - test_almost(1.0, -0.23079135264472743236, 1e-15, ln_pdf(0.1)); - test_almost(1.0, -0.72579135264472743236, 1e-15, ln_pdf(1.0)); - test_almost(1.0, -15.350791352644727432, 1e-14, ln_pdf(5.5)); - test_case(1.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); - test_case(2.0, f64::NEG_INFINITY, ln_pdf(0.0)); - test_almost(2.0, -2.3075850929940456840, 1e-15, ln_pdf(0.1)); - test_almost(2.0, -0.5, 1e-15, ln_pdf(1.0)); - test_almost(2.0, -13.420251907761574765, 1e-15, ln_pdf(5.5)); - test_case(2.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); - test_case(2.5, f64::NEG_INFINITY, ln_pdf(0.0)); - test_almost(2.5, -3.5338925982092416919, 1e-15, ln_pdf(0.1)); - test_almost(2.5, -0.57501495871817316589, 1e-15, ln_pdf(1.0)); - test_almost(2.5, -12.642892820360535314, 1e-16, ln_pdf(5.5)); - test_case(2.5, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); - test_case(f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.0)); - test_case(f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.1)); - test_case(f64::INFINITY, f64::NEG_INFINITY, ln_pdf(1.0)); - test_case(f64::INFINITY, f64::NEG_INFINITY, ln_pdf(5.5)); - test_case(f64::INFINITY, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); - test_almost(170.0, -0.57187185030600516424237, 1e-13, ln_pdf(13.0)); + test_exact(1.0, f64::NEG_INFINITY, ln_pdf(0.0)); + test_absolute(1.0, -0.23079135264472743236, 1e-15, ln_pdf(0.1)); + test_absolute(1.0, -0.72579135264472743236, 1e-15, ln_pdf(1.0)); + test_absolute(1.0, -15.350791352644727432, 1e-14, ln_pdf(5.5)); + test_exact(1.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); + test_exact(2.0, f64::NEG_INFINITY, ln_pdf(0.0)); + test_absolute(2.0, -2.3075850929940456840, 1e-15, ln_pdf(0.1)); + test_absolute(2.0, -0.5, 1e-15, ln_pdf(1.0)); + test_absolute(2.0, -13.420251907761574765, 1e-15, ln_pdf(5.5)); + test_exact(2.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); + test_exact(2.5, f64::NEG_INFINITY, ln_pdf(0.0)); + test_absolute(2.5, -3.5338925982092416919, 1e-15, ln_pdf(0.1)); + test_absolute(2.5, -0.57501495871817316589, 1e-15, ln_pdf(1.0)); + test_absolute(2.5, -12.642892820360535314, 1e-16, ln_pdf(5.5)); + test_exact(2.5, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); + test_exact(f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.0)); + test_exact(f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.1)); + test_exact(f64::INFINITY, f64::NEG_INFINITY, ln_pdf(1.0)); + test_exact(f64::INFINITY, f64::NEG_INFINITY, ln_pdf(5.5)); + test_exact(f64::INFINITY, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); + test_absolute(170.0, -0.57187185030600516424237, 1e-13, ln_pdf(13.0)); } #[test] fn test_neg_ln_pdf() { let ln_pdf = |arg: f64| move |x: Chi| x.ln_pdf(arg); - test_case(1.0, f64::NEG_INFINITY, ln_pdf(-1.0)); + test_exact(1.0, f64::NEG_INFINITY, ln_pdf(-1.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Chi| x.cdf(arg); - test_case(1.0, 0.0, cdf(0.0)); - test_almost(1.0, 0.079655674554057962931, 1e-16, cdf(0.1)); - test_almost(1.0, 0.68268949213708589717, 1e-15, cdf(1.0)); - test_case(1.0, 0.99999996202087506822, cdf(5.5)); - test_case(1.0, 1.0, cdf(f64::INFINITY)); - test_case(2.0, 0.0, cdf(0.0)); - test_almost(2.0, 0.0049875208073176866474, 1e-17, cdf(0.1)); - test_almost(2.0, 0.39346934028736657640, 1e-15, cdf(1.0)); - test_case(2.0, 0.99999973004214966370, cdf(5.5)); - test_case(2.0, 1.0, cdf(f64::INFINITY)); - test_case(2.5, 0.0, cdf(0.0)); - test_almost(2.5, 0.0011702413714030096290, 1e-18, cdf(0.1)); - test_almost(2.5, 0.28378995266531297417, 1e-16, cdf(1.0)); - test_case(2.5, 0.99999940337322804750, cdf(5.5)); - test_case(2.5, 1.0, cdf(f64::INFINITY)); - test_case(f64::INFINITY, 1.0, cdf(0.0)); - test_case(f64::INFINITY, 1.0, cdf(0.1)); - test_case(f64::INFINITY, 1.0, cdf(1.0)); - test_case(f64::INFINITY, 1.0, cdf(5.5)); - test_case(f64::INFINITY, 1.0, cdf(f64::INFINITY)); + test_exact(1.0, 0.0, cdf(0.0)); + test_absolute(1.0, 0.079655674554057962931, 1e-16, cdf(0.1)); + test_absolute(1.0, 0.68268949213708589717, 1e-15, cdf(1.0)); + test_exact(1.0, 0.99999996202087506822, cdf(5.5)); + test_exact(1.0, 1.0, cdf(f64::INFINITY)); + test_exact(2.0, 0.0, cdf(0.0)); + test_absolute(2.0, 0.0049875208073176866474, 1e-17, cdf(0.1)); + test_absolute(2.0, 0.39346934028736657640, 1e-15, cdf(1.0)); + test_exact(2.0, 0.99999973004214966370, cdf(5.5)); + test_exact(2.0, 1.0, cdf(f64::INFINITY)); + test_exact(2.5, 0.0, cdf(0.0)); + test_absolute(2.5, 0.0011702413714030096290, 1e-18, cdf(0.1)); + test_absolute(2.5, 0.28378995266531297417, 1e-16, cdf(1.0)); + test_exact(2.5, 0.99999940337322804750, cdf(5.5)); + test_exact(2.5, 1.0, cdf(f64::INFINITY)); + test_exact(f64::INFINITY, 1.0, cdf(0.0)); + test_exact(f64::INFINITY, 1.0, cdf(0.1)); + test_exact(f64::INFINITY, 1.0, cdf(1.0)); + test_exact(f64::INFINITY, 1.0, cdf(5.5)); + test_exact(f64::INFINITY, 1.0, cdf(f64::INFINITY)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Chi| x.sf(arg); - test_case(1.0, 1.0, sf(0.0)); - test_almost(1.0, 0.920344325445942, 1e-16, sf(0.1)); - test_almost(1.0, 0.31731050786291404, 1e-15, sf(1.0)); - test_almost(1.0, 3.797912493177544e-8, 1e-15, sf(5.5)); - test_case(1.0, 0.0, sf(f64::INFINITY)); - test_case(2.0, 1.0, sf(0.0)); - test_almost(2.0, 0.9950124791926823, 1e-17, sf(0.1)); - test_almost(2.0, 0.6065306597126333, 1e-15, sf(1.0)); - test_almost(2.0, 2.699578503363014e-7, 1e-15, sf(5.5)); - test_case(2.0, 0.0, sf(f64::INFINITY)); - test_case(2.5, 1.0, sf(0.0)); - test_almost(2.5, 0.998829758628597, 1e-18, sf(0.1)); - test_almost(2.5, 0.716210047334687, 1e-16, sf(1.0)); - test_almost(2.5, 5.966267719870189e-7, 1e-15, sf(5.5)); - test_case(2.5, 0.0, sf(f64::INFINITY)); - test_case(f64::INFINITY, 0.0, sf(0.0)); - test_case(f64::INFINITY, 0.0, sf(0.1)); - test_case(f64::INFINITY, 0.0, sf(1.0)); - test_case(f64::INFINITY, 0.0, sf(5.5)); - test_case(f64::INFINITY, 0.0, sf(f64::INFINITY)); + test_exact(1.0, 1.0, sf(0.0)); + test_absolute(1.0, 0.920344325445942, 1e-16, sf(0.1)); + test_absolute(1.0, 0.31731050786291404, 1e-15, sf(1.0)); + test_absolute(1.0, 3.797912493177544e-8, 1e-15, sf(5.5)); + test_exact(1.0, 0.0, sf(f64::INFINITY)); + test_exact(2.0, 1.0, sf(0.0)); + test_absolute(2.0, 0.9950124791926823, 1e-17, sf(0.1)); + test_absolute(2.0, 0.6065306597126333, 1e-15, sf(1.0)); + test_absolute(2.0, 2.699578503363014e-7, 1e-15, sf(5.5)); + test_exact(2.0, 0.0, sf(f64::INFINITY)); + test_exact(2.5, 1.0, sf(0.0)); + test_absolute(2.5, 0.998829758628597, 1e-18, sf(0.1)); + test_absolute(2.5, 0.716210047334687, 1e-16, sf(1.0)); + test_absolute(2.5, 5.966267719870189e-7, 1e-15, sf(5.5)); + test_exact(2.5, 0.0, sf(f64::INFINITY)); + test_exact(f64::INFINITY, 0.0, sf(0.0)); + test_exact(f64::INFINITY, 0.0, sf(0.1)); + test_exact(f64::INFINITY, 0.0, sf(1.0)); + test_exact(f64::INFINITY, 0.0, sf(5.5)); + test_exact(f64::INFINITY, 0.0, sf(f64::INFINITY)); } #[test] fn test_neg_cdf() { let cdf = |arg: f64| move |x: Chi| x.cdf(arg); - test_case(1.0, 0.0, cdf(-1.0)); + test_exact(1.0, 0.0, cdf(-1.0)); } #[test] fn test_neg_sf() { let sf = |arg: f64| move |x: Chi| x.sf(arg); - test_case(1.0, 1.0, sf(-1.0)); + test_exact(1.0, 1.0, sf(-1.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(1.0), 0.0, 10.0); - test::check_continuous_distribution(&try_create(2.0), 0.0, 10.0); - test::check_continuous_distribution(&try_create(5.0), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(1.0), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(2.0), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(5.0), 0.0, 10.0); } } From 2358d62adae3c2f949fdd94f2920c4b0f3d01aca Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 19:33:07 +0200 Subject: [PATCH 146/185] Use testing_boiler! for Dirac --- src/distribution/dirac.rs | 96 ++++++++++++++++----------------------- 1 file changed, 38 insertions(+), 58 deletions(-) diff --git a/src/distribution/dirac.rs b/src/distribution/dirac.rs index c664d8db..41ac1d6c 100644 --- a/src/distribution/dirac.rs +++ b/src/distribution/dirac.rs @@ -193,114 +193,94 @@ impl Mode> for Dirac { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::statistics::*; use crate::distribution::{ContinuousCDF, Dirac}; + use crate::statistics::*; + use crate::testing_boiler; - fn try_create(v: f64) -> Dirac { - let d = Dirac::new(v); - assert!(d.is_ok()); - d.unwrap() - } - - fn create_case(v: f64) { - let d = try_create(v); - assert_eq!(v, d.mean().unwrap()); - } - - fn bad_create_case(v: f64) { - let d = Dirac::new(v); - assert!(d.is_err()); - } - - fn test_case(v: f64, expected: f64, eval: F) - where F: Fn(Dirac) -> f64 - { - let x = eval(try_create(v)); - assert_eq!(expected, x); - } + testing_boiler!(v: f64; Dirac); #[test] fn test_create() { - create_case(10.0); - create_case(-5.0); - create_case(10.0); - create_case(100.0); - create_case(f64::INFINITY); + create_ok(10.0); + create_ok(-5.0); + create_ok(10.0); + create_ok(100.0); + create_ok(f64::INFINITY); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN); + create_err(f64::NAN); } #[test] fn test_variance() { let variance = |x: Dirac| x.variance().unwrap(); - test_case(0.0, 0.0, variance); - test_case(-5.0, 0.0, variance); - test_case(f64::INFINITY, 0.0, variance); + test_exact(0.0, 0.0, variance); + test_exact(-5.0, 0.0, variance); + test_exact(f64::INFINITY, 0.0, variance); } #[test] fn test_entropy() { let entropy = |x: Dirac| x.entropy().unwrap(); - test_case(0.0, 0.0, entropy); - test_case(f64::INFINITY, 0.0, entropy); + test_exact(0.0, 0.0, entropy); + test_exact(f64::INFINITY, 0.0, entropy); } #[test] fn test_skewness() { let skewness = |x: Dirac| x.skewness().unwrap(); - test_case(0.0, 0.0, skewness); - test_case(4.0, 0.0, skewness); - test_case(0.3, 0.0, skewness); - test_case(f64::INFINITY, 0.0, skewness); + test_exact(0.0, 0.0, skewness); + test_exact(4.0, 0.0, skewness); + test_exact(0.3, 0.0, skewness); + test_exact(f64::INFINITY, 0.0, skewness); } #[test] fn test_mode() { let mode = |x: Dirac| x.mode().unwrap(); - test_case(0.0, 0.0, mode); - test_case(3.0, 3.0, mode); - test_case(f64::INFINITY, f64::INFINITY, mode); + test_exact(0.0, 0.0, mode); + test_exact(3.0, 3.0, mode); + test_exact(f64::INFINITY, f64::INFINITY, mode); } #[test] fn test_median() { let median = |x: Dirac| x.median(); - test_case(0.0, 0.0, median); - test_case(3.0, 3.0, median); - test_case(f64::INFINITY, f64::INFINITY, median); + test_exact(0.0, 0.0, median); + test_exact(3.0, 3.0, median); + test_exact(f64::INFINITY, f64::INFINITY, median); } #[test] fn test_min_max() { let min = |x: Dirac| x.min(); let max = |x: Dirac| x.max(); - test_case(0.0, 0.0, min); - test_case(3.0, 3.0, min); - test_case(f64::INFINITY, f64::INFINITY, min); + test_exact(0.0, 0.0, min); + test_exact(3.0, 3.0, min); + test_exact(f64::INFINITY, f64::INFINITY, min); - test_case(0.0, 0.0, max); - test_case(3.0, 3.0, max); - test_case(f64::NEG_INFINITY, f64::NEG_INFINITY, max); + test_exact(0.0, 0.0, max); + test_exact(3.0, 3.0, max); + test_exact(f64::NEG_INFINITY, f64::NEG_INFINITY, max); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Dirac| x.cdf(arg); - test_case(0.0, 1.0, cdf(0.0)); - test_case(3.0, 1.0, cdf(3.0)); - test_case(f64::INFINITY, 0.0, cdf(1.0)); - test_case(f64::INFINITY, 1.0, cdf(f64::INFINITY)); + test_exact(0.0, 1.0, cdf(0.0)); + test_exact(3.0, 1.0, cdf(3.0)); + test_exact(f64::INFINITY, 0.0, cdf(1.0)); + test_exact(f64::INFINITY, 1.0, cdf(f64::INFINITY)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Dirac| x.sf(arg); - test_case(0.0, 0.0, sf(0.0)); - test_case(3.0, 0.0, sf(3.0)); - test_case(f64::INFINITY, 1.0, sf(1.0)); - test_case(f64::INFINITY, 0.0, sf(f64::INFINITY)); + test_exact(0.0, 0.0, sf(0.0)); + test_exact(3.0, 0.0, sf(3.0)); + test_exact(f64::INFINITY, 1.0, sf(1.0)); + test_exact(f64::INFINITY, 0.0, sf(f64::INFINITY)); } } From f29464638221f106f2f9501a1f4f85fa5556f412 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 19:34:32 +0200 Subject: [PATCH 147/185] Use testing_boiler! for DiscreteUniform --- src/distribution/discrete_uniform.rs | 139 +++++++++++---------------- 1 file changed, 54 insertions(+), 85 deletions(-) diff --git a/src/distribution/discrete_uniform.rs b/src/distribution/discrete_uniform.rs index 361cadd8..6871c80a 100644 --- a/src/distribution/discrete_uniform.rs +++ b/src/distribution/discrete_uniform.rs @@ -256,164 +256,133 @@ impl Discrete for DiscreteUniform { #[rustfmt::skip] #[cfg(test)] mod tests { - use std::fmt::Debug; - use crate::statistics::*; use crate::distribution::{DiscreteCDF, Discrete, DiscreteUniform}; + use crate::statistics::*; + use crate::testing_boiler; - fn try_create(min: i64, max: i64) -> DiscreteUniform { - let n = DiscreteUniform::new(min, max); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(min: i64, max: i64) { - let n = try_create(min, max); - assert_eq!(min, n.min()); - assert_eq!(max, n.max()); - } - - fn bad_create_case(min: i64, max: i64) { - let n = DiscreteUniform::new(min, max); - assert!(n.is_err()); - } - - fn get_value(min: i64, max: i64, eval: F) -> T - where T: PartialEq + Debug, - F: Fn(DiscreteUniform) -> T - { - let n = try_create(min, max); - eval(n) - } - - fn test_case(min: i64, max: i64, expected: T, eval: F) - where T: PartialEq + Debug, - F: Fn(DiscreteUniform) -> T - { - let x = get_value(min, max, eval); - assert_eq!(expected, x); - } + testing_boiler!(min: i64, max: i64; DiscreteUniform); #[test] fn test_create() { - create_case(-10, 10); - create_case(0, 4); - create_case(10, 20); - create_case(20, 20); + create_ok(-10, 10); + create_ok(0, 4); + create_ok(10, 20); + create_ok(20, 20); } #[test] fn test_bad_create() { - bad_create_case(-1, -2); - bad_create_case(6, 5); + create_err(-1, -2); + create_err(6, 5); } #[test] fn test_mean() { let mean = |x: DiscreteUniform| x.mean().unwrap(); - test_case(-10, 10, 0.0, mean); - test_case(0, 4, 2.0, mean); - test_case(10, 20, 15.0, mean); - test_case(20, 20, 20.0, mean); + test_exact(-10, 10, 0.0, mean); + test_exact(0, 4, 2.0, mean); + test_exact(10, 20, 15.0, mean); + test_exact(20, 20, 20.0, mean); } #[test] fn test_variance() { let variance = |x: DiscreteUniform| x.variance().unwrap(); - test_case(-10, 10, 36.66666666666666666667, variance); - test_case(0, 4, 2.0, variance); - test_case(10, 20, 10.0, variance); - test_case(20, 20, 0.0, variance); + test_exact(-10, 10, 36.66666666666666666667, variance); + test_exact(0, 4, 2.0, variance); + test_exact(10, 20, 10.0, variance); + test_exact(20, 20, 0.0, variance); } #[test] fn test_entropy() { let entropy = |x: DiscreteUniform| x.entropy().unwrap(); - test_case(-10, 10, 3.0445224377234229965005979803657054342845752874046093, entropy); - test_case(0, 4, 1.6094379124341003746007593332261876395256013542685181, entropy); - test_case(10, 20, 2.3978952727983705440619435779651292998217068539374197, entropy); - test_case(20, 20, 0.0, entropy); + test_exact(-10, 10, 3.0445224377234229965005979803657054342845752874046093, entropy); + test_exact(0, 4, 1.6094379124341003746007593332261876395256013542685181, entropy); + test_exact(10, 20, 2.3978952727983705440619435779651292998217068539374197, entropy); + test_exact(20, 20, 0.0, entropy); } #[test] fn test_skewness() { let skewness = |x: DiscreteUniform| x.skewness().unwrap(); - test_case(-10, 10, 0.0, skewness); - test_case(0, 4, 0.0, skewness); - test_case(10, 20, 0.0, skewness); - test_case(20, 20, 0.0, skewness); + test_exact(-10, 10, 0.0, skewness); + test_exact(0, 4, 0.0, skewness); + test_exact(10, 20, 0.0, skewness); + test_exact(20, 20, 0.0, skewness); } #[test] fn test_median() { let median = |x: DiscreteUniform| x.median(); - test_case(-10, 10, 0.0, median); - test_case(0, 4, 2.0, median); - test_case(10, 20, 15.0, median); - test_case(20, 20, 20.0, median); + test_exact(-10, 10, 0.0, median); + test_exact(0, 4, 2.0, median); + test_exact(10, 20, 15.0, median); + test_exact(20, 20, 20.0, median); } #[test] fn test_mode() { let mode = |x: DiscreteUniform| x.mode().unwrap(); - test_case(-10, 10, 0, mode); - test_case(0, 4, 2, mode); - test_case(10, 20, 15, mode); - test_case(20, 20, 20, mode); + test_exact(-10, 10, 0, mode); + test_exact(0, 4, 2, mode); + test_exact(10, 20, 15, mode); + test_exact(20, 20, 20, mode); } #[test] fn test_pmf() { let pmf = |arg: i64| move |x: DiscreteUniform| x.pmf(arg); - test_case(-10, 10, 0.04761904761904761904762, pmf(-5)); - test_case(-10, 10, 0.04761904761904761904762, pmf(1)); - test_case(-10, 10, 0.04761904761904761904762, pmf(10)); - test_case(-10, -10, 0.0, pmf(0)); - test_case(-10, -10, 1.0, pmf(-10)); + test_exact(-10, 10, 0.04761904761904761904762, pmf(-5)); + test_exact(-10, 10, 0.04761904761904761904762, pmf(1)); + test_exact(-10, 10, 0.04761904761904761904762, pmf(10)); + test_exact(-10, -10, 0.0, pmf(0)); + test_exact(-10, -10, 1.0, pmf(-10)); } #[test] fn test_ln_pmf() { let ln_pmf = |arg: i64| move |x: DiscreteUniform| x.ln_pmf(arg); - test_case(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(-5)); - test_case(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(1)); - test_case(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(10)); - test_case(-10, -10, f64::NEG_INFINITY, ln_pmf(0)); - test_case(-10, -10, 0.0, ln_pmf(-10)); + test_exact(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(-5)); + test_exact(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(1)); + test_exact(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(10)); + test_exact(-10, -10, f64::NEG_INFINITY, ln_pmf(0)); + test_exact(-10, -10, 0.0, ln_pmf(-10)); } #[test] fn test_cdf() { let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg); - test_case(-10, 10, 0.2857142857142857142857, cdf(-5)); - test_case(-10, 10, 0.5714285714285714285714, cdf(1)); - test_case(-10, 10, 1.0, cdf(10)); - test_case(-10, -10, 1.0, cdf(-10)); + test_exact(-10, 10, 0.2857142857142857142857, cdf(-5)); + test_exact(-10, 10, 0.5714285714285714285714, cdf(1)); + test_exact(-10, 10, 1.0, cdf(10)); + test_exact(-10, -10, 1.0, cdf(-10)); } #[test] fn test_sf() { let sf = |arg: i64| move |x: DiscreteUniform| x.sf(arg); - test_case(-10, 10, 0.7142857142857142857143, sf(-5)); - test_case(-10, 10, 0.42857142857142855, sf(1)); - test_case(-10, 10, 0.0, sf(10)); - test_case(-10, -10, 0.0, sf(-10)); + test_exact(-10, 10, 0.7142857142857142857143, sf(-5)); + test_exact(-10, 10, 0.42857142857142855, sf(1)); + test_exact(-10, 10, 0.0, sf(10)); + test_exact(-10, -10, 0.0, sf(-10)); } #[test] fn test_cdf_lower_bound() { let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg); - test_case(0, 3, 0.0, cdf(-1)); + test_exact(0, 3, 0.0, cdf(-1)); } #[test] fn test_sf_lower_bound() { let sf = |arg: i64| move |x: DiscreteUniform| x.sf(arg); - test_case(0, 3, 1.0, sf(-1)); + test_exact(0, 3, 1.0, sf(-1)); } #[test] fn test_cdf_upper_bound() { let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg); - test_case(0, 3, 1.0, cdf(5)); + test_exact(0, 3, 1.0, cdf(5)); } } From 8d13b4207d00ee8a843e7bfba3f5e357fc6b8208 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 19:35:34 +0200 Subject: [PATCH 148/185] Use testing_boiler! for Erlang --- src/distribution/erlang.rs | 42 +++++++++++++------------------------- 1 file changed, 14 insertions(+), 28 deletions(-) diff --git a/src/distribution/erlang.rs b/src/distribution/erlang.rs index 6fb6a098..ce6f68aa 100644 --- a/src/distribution/erlang.rs +++ b/src/distribution/erlang.rs @@ -295,45 +295,31 @@ impl Continuous for Erlang { mod tests { use crate::distribution::Erlang; use crate::distribution::internal::*; + use crate::testing_boiler; - fn try_create(shape: u64, rate: f64) -> Erlang { - let n = Erlang::new(shape, rate); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(shape: u64, rate: f64) { - let n = try_create(shape, rate); - assert_eq!(shape, n.shape()); - assert_eq!(rate, n.rate()); - } - - fn bad_create_case(shape: u64, rate: f64) { - let n = Erlang::new(shape, rate); - assert!(n.is_err()); - } + testing_boiler!(shape: u64, rate: f64; Erlang); #[test] fn test_create() { - create_case(1, 0.1); - create_case(1, 1.0); - create_case(10, 10.0); - create_case(10, 1.0); - create_case(10, f64::INFINITY); + create_ok(1, 0.1); + create_ok(1, 1.0); + create_ok(10, 10.0); + create_ok(10, 1.0); + create_ok(10, f64::INFINITY); } #[test] fn test_bad_create() { - bad_create_case(0, 1.0); - bad_create_case(1, 0.0); - bad_create_case(1, f64::NAN); - bad_create_case(1, -1.0); + create_err(0, 1.0); + create_err(1, 0.0); + create_err(1, f64::NAN); + create_err(1, -1.0); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(1, 2.5), 0.0, 20.0); - test::check_continuous_distribution(&try_create(2, 1.5), 0.0, 20.0); - test::check_continuous_distribution(&try_create(3, 0.5), 0.0, 20.0); + test::check_continuous_distribution(&create_ok(1, 2.5), 0.0, 20.0); + test::check_continuous_distribution(&create_ok(2, 1.5), 0.0, 20.0); + test::check_continuous_distribution(&create_ok(3, 0.5), 0.0, 20.0); } } From 15949c480c959e2f879591f4cc92ac7a3a814207 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 19:39:11 +0200 Subject: [PATCH 149/185] Use testing_boiler! for Exp --- src/distribution/exponential.rs | 215 +++++++++++++------------------- src/distribution/internal.rs | 13 ++ 2 files changed, 100 insertions(+), 128 deletions(-) diff --git a/src/distribution/exponential.rs b/src/distribution/exponential.rs index 978ae638..d5a54d56 100644 --- a/src/distribution/exponential.rs +++ b/src/distribution/exponential.rs @@ -280,202 +280,161 @@ impl Continuous for Exp { #[cfg(test)] mod tests { use std::f64; - use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Exp}; use crate::distribution::internal::*; + use crate::statistics::*; + use crate::testing_boiler; - fn try_create(rate: f64) -> Exp { - let n = Exp::new(rate); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(rate: f64) { - let n = try_create(rate); - assert_eq!(rate, n.rate()); - } - - fn bad_create_case(rate: f64) { - let n = Exp::new(rate); - assert!(n.is_err()); - } - - fn get_value(rate: f64, eval: F) -> f64 - where F: Fn(Exp) -> f64 - { - let n = try_create(rate); - eval(n) - } - - fn test_case(rate: f64, expected: f64, eval: F) - where F: Fn(Exp) -> f64 - { - let x = get_value(rate, eval); - assert_eq!(expected, x); - } - - fn test_almost(rate: f64, expected: f64, acc: f64, eval: F) - where F: Fn(Exp) -> f64 - { - let x = get_value(rate, eval); - assert_almost_eq!(expected, x, acc); - } - - fn test_is_nan(rate: f64, eval: F) - where F : Fn(Exp) -> f64 - { - let x = get_value(rate, eval); - assert!(x.is_nan()); - } + testing_boiler!(rate: f64; Exp); #[test] fn test_create() { - create_case(0.1); - create_case(1.0); - create_case(10.0); + create_ok(0.1); + create_ok(1.0); + create_ok(10.0); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN); - bad_create_case(0.0); - bad_create_case(-1.0); - bad_create_case(-10.0); + create_err(f64::NAN); + create_err(0.0); + create_err(-1.0); + create_err(-10.0); } #[test] fn test_mean() { let mean = |x: Exp| x.mean().unwrap(); - test_case(0.1, 10.0, mean); - test_case(1.0, 1.0, mean); - test_case(10.0, 0.1, mean); + test_exact(0.1, 10.0, mean); + test_exact(1.0, 1.0, mean); + test_exact(10.0, 0.1, mean); } #[test] fn test_variance() { let variance = |x: Exp| x.variance().unwrap(); - test_almost(0.1, 100.0, 1e-13, variance); - test_case(1.0, 1.0, variance); - test_case(10.0, 0.01, variance); + test_absolute(0.1, 100.0, 1e-13, variance); + test_exact(1.0, 1.0, variance); + test_exact(10.0, 0.01, variance); } #[test] fn test_entropy() { let entropy = |x: Exp| x.entropy().unwrap(); - test_almost(0.1, 3.302585092994045684018, 1e-15, entropy); - test_case(1.0, 1.0, entropy); - test_almost(10.0, -1.302585092994045684018, 1e-15, entropy); + test_absolute(0.1, 3.302585092994045684018, 1e-15, entropy); + test_exact(1.0, 1.0, entropy); + test_absolute(10.0, -1.302585092994045684018, 1e-15, entropy); } #[test] fn test_skewness() { let skewness = |x: Exp| x.skewness().unwrap(); - test_case(0.1, 2.0, skewness); - test_case(1.0, 2.0, skewness); - test_case(10.0, 2.0, skewness); + test_exact(0.1, 2.0, skewness); + test_exact(1.0, 2.0, skewness); + test_exact(10.0, 2.0, skewness); } #[test] fn test_median() { let median = |x: Exp| x.median(); - test_almost(0.1, 6.931471805599453094172, 1e-15, median); - test_case(1.0, f64::consts::LN_2, median); - test_case(10.0, 0.06931471805599453094172, median); + test_absolute(0.1, 6.931471805599453094172, 1e-15, median); + test_exact(1.0, f64::consts::LN_2, median); + test_exact(10.0, 0.06931471805599453094172, median); } #[test] fn test_mode() { let mode = |x: Exp| x.mode().unwrap(); - test_case(0.1, 0.0, mode); - test_case(1.0, 0.0, mode); - test_case(10.0, 0.0, mode); + test_exact(0.1, 0.0, mode); + test_exact(1.0, 0.0, mode); + test_exact(10.0, 0.0, mode); } #[test] fn test_min_max() { let min = |x: Exp| x.min(); let max = |x: Exp| x.max(); - test_case(0.1, 0.0, min); - test_case(1.0, 0.0, min); - test_case(10.0, 0.0, min); - test_case(0.1, f64::INFINITY, max); - test_case(1.0, f64::INFINITY, max); - test_case(10.0, f64::INFINITY, max); + test_exact(0.1, 0.0, min); + test_exact(1.0, 0.0, min); + test_exact(10.0, 0.0, min); + test_exact(0.1, f64::INFINITY, max); + test_exact(1.0, f64::INFINITY, max); + test_exact(10.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Exp| x.pdf(arg); - test_case(0.1, 0.1, pdf(0.0)); - test_case(1.0, 1.0, pdf(0.0)); - test_case(10.0, 10.0, pdf(0.0)); + test_exact(0.1, 0.1, pdf(0.0)); + test_exact(1.0, 1.0, pdf(0.0)); + test_exact(10.0, 10.0, pdf(0.0)); test_is_nan(f64::INFINITY, pdf(0.0)); - test_case(0.1, 0.09900498337491680535739, pdf(0.1)); - test_almost(1.0, 0.9048374180359595731642, 1e-15, pdf(0.1)); - test_case(10.0, 3.678794411714423215955, pdf(0.1)); + test_exact(0.1, 0.09900498337491680535739, pdf(0.1)); + test_absolute(1.0, 0.9048374180359595731642, 1e-15, pdf(0.1)); + test_exact(10.0, 3.678794411714423215955, pdf(0.1)); test_is_nan(f64::INFINITY, pdf(0.1)); - test_case(0.1, 0.09048374180359595731642, pdf(1.0)); - test_case(1.0, 0.3678794411714423215955, pdf(1.0)); - test_almost(10.0, 4.539992976248485153559e-4, 1e-19, pdf(1.0)); + test_exact(0.1, 0.09048374180359595731642, pdf(1.0)); + test_exact(1.0, 0.3678794411714423215955, pdf(1.0)); + test_absolute(10.0, 4.539992976248485153559e-4, 1e-19, pdf(1.0)); test_is_nan(f64::INFINITY, pdf(1.0)); - test_case(0.1, 0.0, pdf(f64::INFINITY)); - test_case(1.0, 0.0, pdf(f64::INFINITY)); - test_case(10.0, 0.0, pdf(f64::INFINITY)); + test_exact(0.1, 0.0, pdf(f64::INFINITY)); + test_exact(1.0, 0.0, pdf(f64::INFINITY)); + test_exact(10.0, 0.0, pdf(f64::INFINITY)); test_is_nan(f64::INFINITY, pdf(f64::INFINITY)); } #[test] fn test_neg_pdf() { let pdf = |arg: f64| move |x: Exp| x.pdf(arg); - test_case(0.1, 0.0, pdf(-1.0)); + test_exact(0.1, 0.0, pdf(-1.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Exp| x.ln_pdf(arg); - test_almost(0.1, -2.302585092994045684018, 1e-15, ln_pdf(0.0)); - test_case(1.0, 0.0, ln_pdf(0.0)); - test_case(10.0, 2.302585092994045684018, ln_pdf(0.0)); + test_absolute(0.1, -2.302585092994045684018, 1e-15, ln_pdf(0.0)); + test_exact(1.0, 0.0, ln_pdf(0.0)); + test_exact(10.0, 2.302585092994045684018, ln_pdf(0.0)); test_is_nan(f64::INFINITY, ln_pdf(0.0)); - test_almost(0.1, -2.312585092994045684018, 1e-15, ln_pdf(0.1)); - test_case(1.0, -0.1, ln_pdf(0.1)); - test_almost(10.0, 1.302585092994045684018, 1e-15, ln_pdf(0.1)); + test_absolute(0.1, -2.312585092994045684018, 1e-15, ln_pdf(0.1)); + test_exact(1.0, -0.1, ln_pdf(0.1)); + test_absolute(10.0, 1.302585092994045684018, 1e-15, ln_pdf(0.1)); test_is_nan(f64::INFINITY, ln_pdf(0.1)); - test_case(0.1, -2.402585092994045684018, ln_pdf(1.0)); - test_case(1.0, -1.0, ln_pdf(1.0)); - test_case(10.0, -7.697414907005954315982, ln_pdf(1.0)); + test_exact(0.1, -2.402585092994045684018, ln_pdf(1.0)); + test_exact(1.0, -1.0, ln_pdf(1.0)); + test_exact(10.0, -7.697414907005954315982, ln_pdf(1.0)); test_is_nan(f64::INFINITY, ln_pdf(1.0)); - test_case(0.1, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); - test_case(1.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); - test_case(10.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); + test_exact(0.1, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); + test_exact(1.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); + test_exact(10.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); test_is_nan(f64::INFINITY, ln_pdf(f64::INFINITY)); } #[test] fn test_neg_ln_pdf() { let ln_pdf = |arg: f64| move |x: Exp| x.ln_pdf(arg); - test_case(0.1, f64::NEG_INFINITY, ln_pdf(-1.0)); + test_exact(0.1, f64::NEG_INFINITY, ln_pdf(-1.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Exp| x.cdf(arg); - test_case(0.1, 0.0, cdf(0.0)); - test_case(1.0, 0.0, cdf(0.0)); - test_case(10.0, 0.0, cdf(0.0)); + test_exact(0.1, 0.0, cdf(0.0)); + test_exact(1.0, 0.0, cdf(0.0)); + test_exact(10.0, 0.0, cdf(0.0)); test_is_nan(f64::INFINITY, cdf(0.0)); - test_almost(0.1, 0.009950166250831946426094, 1e-16, cdf(0.1)); - test_almost(1.0, 0.0951625819640404268358, 1e-16, cdf(0.1)); - test_case(10.0, 0.6321205588285576784045, cdf(0.1)); - test_case(f64::INFINITY, 1.0, cdf(0.1)); - test_almost(0.1, 0.0951625819640404268358, 1e-16, cdf(1.0)); - test_case(1.0, 0.6321205588285576784045, cdf(1.0)); - test_case(10.0, 0.9999546000702375151485, cdf(1.0)); - test_case(f64::INFINITY, 1.0, cdf(1.0)); - test_case(0.1, 1.0, cdf(f64::INFINITY)); - test_case(1.0, 1.0, cdf(f64::INFINITY)); - test_case(10.0, 1.0, cdf(f64::INFINITY)); - test_case(f64::INFINITY, 1.0, cdf(f64::INFINITY)); + test_absolute(0.1, 0.009950166250831946426094, 1e-16, cdf(0.1)); + test_absolute(1.0, 0.0951625819640404268358, 1e-16, cdf(0.1)); + test_exact(10.0, 0.6321205588285576784045, cdf(0.1)); + test_exact(f64::INFINITY, 1.0, cdf(0.1)); + test_absolute(0.1, 0.0951625819640404268358, 1e-16, cdf(1.0)); + test_exact(1.0, 0.6321205588285576784045, cdf(1.0)); + test_exact(10.0, 0.9999546000702375151485, cdf(1.0)); + test_exact(f64::INFINITY, 1.0, cdf(1.0)); + test_exact(0.1, 1.0, cdf(f64::INFINITY)); + test_exact(1.0, 1.0, cdf(f64::INFINITY)); + test_exact(10.0, 1.0, cdf(f64::INFINITY)); + test_exact(f64::INFINITY, 1.0, cdf(f64::INFINITY)); } #[test] @@ -502,32 +461,32 @@ mod tests { #[test] fn test_sf() { let sf = |arg: f64| move |x: Exp| x.sf(arg); - test_case(0.1, 1.0, sf(0.0)); - test_case(1.0, 1.0, sf(0.0)); - test_case(10.0, 1.0, sf(0.0)); + test_exact(0.1, 1.0, sf(0.0)); + test_exact(1.0, 1.0, sf(0.0)); + test_exact(10.0, 1.0, sf(0.0)); test_is_nan(f64::INFINITY, sf(0.0)); - test_almost(0.1, 0.9900498337491681, 1e-16, sf(0.1)); - test_almost(1.0, 0.9048374180359595, 1e-16, sf(0.1)); - test_almost(10.0, 0.36787944117144233, 1e-15, sf(0.1)); - test_case(f64::INFINITY, 0.0, sf(0.1)); + test_absolute(0.1, 0.9900498337491681, 1e-16, sf(0.1)); + test_absolute(1.0, 0.9048374180359595, 1e-16, sf(0.1)); + test_absolute(10.0, 0.36787944117144233, 1e-15, sf(0.1)); + test_exact(f64::INFINITY, 0.0, sf(0.1)); } #[test] fn test_neg_cdf() { let cdf = |arg: f64| move |x: Exp| x.cdf(arg); - test_case(0.1, 0.0, cdf(-1.0)); + test_exact(0.1, 0.0, cdf(-1.0)); } #[test] fn test_neg_sf() { let sf = |arg: f64| move |x: Exp| x.sf(arg); - test_case(0.1, 1.0, sf(-1.0)); + test_exact(0.1, 1.0, sf(-1.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(0.5), 0.0, 10.0); - test::check_continuous_distribution(&try_create(1.5), 0.0, 20.0); - test::check_continuous_distribution(&try_create(2.5), 0.0, 50.0); + test::check_continuous_distribution(&create_ok(0.5), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(1.5), 0.0, 20.0); + test::check_continuous_distribution(&create_ok(2.5), 0.0, 50.0); } } diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index a21e9330..cf69f1cd 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -240,6 +240,19 @@ pub mod test { } } + /// Gets a value for the given parameters by calling `create_and_get` + /// and asserts that it is [`NAN`]. + /// + /// Panics if `::new` fails. + #[allow(dead_code)] + fn test_is_nan($($arg_name: $arg_ty),+, get_fn: F) + where + F: Fn($dist) -> f64 + { + let x = create_and_get($($arg_name),+, get_fn); + assert!(x.is_nan()); + } + /// Gets a value for the given parameters by calling `create_and_get` /// and asserts that it is [`None`]. /// From af369b5fb0afde3279b08a2d72636489717df084 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 19:43:20 +0200 Subject: [PATCH 150/185] Use testing_boiler! for FisherSnedecor --- src/distribution/fisher_snedecor.rs | 273 ++++++++++++---------------- 1 file changed, 115 insertions(+), 158 deletions(-) diff --git a/src/distribution/fisher_snedecor.rs b/src/distribution/fisher_snedecor.rs index 59ca595a..9d5ef867 100644 --- a/src/distribution/fisher_snedecor.rs +++ b/src/distribution/fisher_snedecor.rs @@ -385,254 +385,211 @@ impl Continuous for FisherSnedecor { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, FisherSnedecor}; use crate::distribution::internal::*; + use crate::statistics::*; + use crate::testing_boiler; - fn try_create(freedom_1: f64, freedom_2: f64) -> FisherSnedecor { - let n = FisherSnedecor::new(freedom_1, freedom_2); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(freedom_1: f64, freedom_2: f64) { - let n = try_create(freedom_1, freedom_2); - assert_eq!(freedom_1, n.freedom_1()); - assert_eq!(freedom_2, n.freedom_2()); - } - - fn bad_create_case(freedom_1: f64, freedom_2: f64) { - let n = FisherSnedecor::new(freedom_1, freedom_2); - assert!(n.is_err()); - } - - fn get_value(freedom_1: f64, freedom_2: f64, eval: F) -> f64 - where F: Fn(FisherSnedecor) -> f64 - { - let n = try_create(freedom_1, freedom_2); - eval(n) - } - - fn test_case(freedom_1: f64, freedom_2: f64, expected: f64, eval: F) - where F: Fn(FisherSnedecor) -> f64 - { - let x = get_value(freedom_1, freedom_2, eval); - assert_eq!(expected, x); - } - - fn test_almost(freedom_1: f64, freedom_2: f64, expected: f64, acc: f64, eval: F) - where F: Fn(FisherSnedecor) -> f64 - { - let x = get_value(freedom_1, freedom_2, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(freedom_1: f64, freedom_2: f64; FisherSnedecor); #[test] fn test_create() { - create_case(0.1, 0.1); - create_case(1.0, 0.1); - create_case(10.0, 0.1); - create_case(0.1, 1.0); - create_case(1.0, 1.0); - create_case(10.0, 1.0); + create_ok(0.1, 0.1); + create_ok(1.0, 0.1); + create_ok(10.0, 0.1); + create_ok(0.1, 1.0); + create_ok(1.0, 1.0); + create_ok(10.0, 1.0); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN, f64::NAN); - bad_create_case(0.0, f64::NAN); - bad_create_case(-1.0, f64::NAN); - bad_create_case(-10.0, f64::NAN); - bad_create_case(f64::NAN, 0.0); - bad_create_case(0.0, 0.0); - bad_create_case(-1.0, 0.0); - bad_create_case(-10.0, 0.0); - bad_create_case(f64::NAN, -1.0); - bad_create_case(0.0, -1.0); - bad_create_case(-1.0, -1.0); - bad_create_case(-10.0, -1.0); - bad_create_case(f64::NAN, -10.0); - bad_create_case(0.0, -10.0); - bad_create_case(-1.0, -10.0); - bad_create_case(-10.0, -10.0); - bad_create_case(f64::INFINITY, 0.1); - bad_create_case(0.1, f64::INFINITY); - bad_create_case(f64::INFINITY, f64::INFINITY); + create_err(f64::NAN, f64::NAN); + create_err(0.0, f64::NAN); + create_err(-1.0, f64::NAN); + create_err(-10.0, f64::NAN); + create_err(f64::NAN, 0.0); + create_err(0.0, 0.0); + create_err(-1.0, 0.0); + create_err(-10.0, 0.0); + create_err(f64::NAN, -1.0); + create_err(0.0, -1.0); + create_err(-1.0, -1.0); + create_err(-10.0, -1.0); + create_err(f64::NAN, -10.0); + create_err(0.0, -10.0); + create_err(-1.0, -10.0); + create_err(-10.0, -10.0); + create_err(f64::INFINITY, 0.1); + create_err(0.1, f64::INFINITY); + create_err(f64::INFINITY, f64::INFINITY); } #[test] fn test_mean() { let mean = |x: FisherSnedecor| x.mean().unwrap(); - test_case(0.1, 10.0, 1.25, mean); - test_case(1.0, 10.0, 1.25, mean); - test_case(10.0, 10.0, 1.25, mean); + test_exact(0.1, 10.0, 1.25, mean); + test_exact(1.0, 10.0, 1.25, mean); + test_exact(10.0, 10.0, 1.25, mean); } #[test] - #[should_panic] fn test_mean_with_low_d2() { - let mean = |x: FisherSnedecor| x.mean().unwrap(); - get_value(0.1, 0.1, mean); + test_none(0.1, 0.1, |dist| dist.mean()); } #[test] fn test_variance() { let variance = |x: FisherSnedecor| x.variance().unwrap(); - test_almost(0.1, 10.0, 42.1875, 1e-14, variance); - test_case(1.0, 10.0, 4.6875, variance); - test_case(10.0, 10.0, 0.9375, variance); + test_absolute(0.1, 10.0, 42.1875, 1e-14, variance); + test_exact(1.0, 10.0, 4.6875, variance); + test_exact(10.0, 10.0, 0.9375, variance); } #[test] - #[should_panic] fn test_variance_with_low_d2() { - let variance = |x: FisherSnedecor| x.variance().unwrap(); - get_value(0.1, 0.1, variance); + test_none(0.1, 0.1, |dist| dist.variance()); } #[test] fn test_skewness() { let skewness = |x: FisherSnedecor| x.skewness().unwrap(); - test_almost(0.1, 10.0, 15.78090735784977089658, 1e-14, skewness); - test_case(1.0, 10.0, 5.773502691896257645091, skewness); - test_case(10.0, 10.0, 3.614784456460255759501, skewness); + test_absolute(0.1, 10.0, 15.78090735784977089658, 1e-14, skewness); + test_exact(1.0, 10.0, 5.773502691896257645091, skewness); + test_exact(10.0, 10.0, 3.614784456460255759501, skewness); } #[test] - #[should_panic] fn test_skewness_with_low_d2() { - let skewness = |x: FisherSnedecor| x.skewness().unwrap(); - get_value(0.1, 0.1, skewness); + test_none(0.1, 0.1, |dist| dist.skewness()); } #[test] fn test_mode() { let mode = |x: FisherSnedecor| x.mode().unwrap(); - test_case(10.0, 0.1, 0.0380952380952380952381, mode); - test_case(10.0, 1.0, 4.0 / 15.0, mode); - test_case(10.0, 10.0, 2.0 / 3.0, mode); + test_exact(10.0, 0.1, 0.0380952380952380952381, mode); + test_exact(10.0, 1.0, 4.0 / 15.0, mode); + test_exact(10.0, 10.0, 2.0 / 3.0, mode); } #[test] - #[should_panic] fn test_mode_with_low_d1() { - let mode = |x: FisherSnedecor| x.mode().unwrap(); - get_value(0.1, 0.1, mode); + test_none(0.1, 0.1, |dist| dist.mode()); } #[test] fn test_min_max() { let min = |x: FisherSnedecor| x.min(); let max = |x: FisherSnedecor| x.max(); - test_case(1.0, 1.0, 0.0, min); - test_case(1.0, 1.0, f64::INFINITY, max); + test_exact(1.0, 1.0, 0.0, min); + test_exact(1.0, 1.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: FisherSnedecor| x.pdf(arg); - test_almost(0.1, 0.1, 0.0234154207226588982471, 1e-16, pdf(1.0)); - test_almost(1.0, 0.1, 0.0396064560910663979961, 1e-16, pdf(1.0)); - test_almost(10.0, 0.1, 0.0418440630400545297349, 1e-14, pdf(1.0)); - test_almost(0.1, 1.0, 0.0396064560910663979961, 1e-16, pdf(1.0)); - test_almost(1.0, 1.0, 0.1591549430918953357689, 1e-16, pdf(1.0)); - test_almost(10.0, 1.0, 0.230361989229138647108, 1e-16, pdf(1.0)); - test_almost(0.1, 0.1, 0.00221546909694001013517, 1e-18, pdf(10.0)); - test_almost(1.0, 0.1, 0.00369960370387922619592, 1e-17, pdf(10.0)); - test_almost(10.0, 0.1, 0.00390179721174142927402, 1e-15, pdf(10.0)); - test_almost(0.1, 1.0, 0.00319864073359931548273, 1e-17, pdf(10.0)); - test_almost(1.0, 1.0, 0.009150765837179460915678, 1e-17, pdf(10.0)); - test_almost(10.0, 1.0, 0.0116493859171442148446, 1e-17, pdf(10.0)); - test_almost(0.1, 10.0, 0.00305087016058573989694, 1e-15, pdf(10.0)); - test_almost(1.0, 10.0, 0.00271897749113479577864, 1e-17, pdf(10.0)); - test_almost(10.0, 10.0, 2.4289227234060500084E-4, 1e-18, pdf(10.0)); + test_absolute(0.1, 0.1, 0.0234154207226588982471, 1e-16, pdf(1.0)); + test_absolute(1.0, 0.1, 0.0396064560910663979961, 1e-16, pdf(1.0)); + test_absolute(10.0, 0.1, 0.0418440630400545297349, 1e-14, pdf(1.0)); + test_absolute(0.1, 1.0, 0.0396064560910663979961, 1e-16, pdf(1.0)); + test_absolute(1.0, 1.0, 0.1591549430918953357689, 1e-16, pdf(1.0)); + test_absolute(10.0, 1.0, 0.230361989229138647108, 1e-16, pdf(1.0)); + test_absolute(0.1, 0.1, 0.00221546909694001013517, 1e-18, pdf(10.0)); + test_absolute(1.0, 0.1, 0.00369960370387922619592, 1e-17, pdf(10.0)); + test_absolute(10.0, 0.1, 0.00390179721174142927402, 1e-15, pdf(10.0)); + test_absolute(0.1, 1.0, 0.00319864073359931548273, 1e-17, pdf(10.0)); + test_absolute(1.0, 1.0, 0.009150765837179460915678, 1e-17, pdf(10.0)); + test_absolute(10.0, 1.0, 0.0116493859171442148446, 1e-17, pdf(10.0)); + test_absolute(0.1, 10.0, 0.00305087016058573989694, 1e-15, pdf(10.0)); + test_absolute(1.0, 10.0, 0.00271897749113479577864, 1e-17, pdf(10.0)); + test_absolute(10.0, 10.0, 2.4289227234060500084E-4, 1e-18, pdf(10.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: FisherSnedecor| x.ln_pdf(arg); - test_almost(0.1, 0.1, 0.0234154207226588982471f64.ln(), 1e-15, ln_pdf(1.0)); - test_almost(1.0, 0.1, 0.0396064560910663979961f64.ln(), 1e-15, ln_pdf(1.0)); - test_almost(10.0, 0.1, 0.0418440630400545297349f64.ln(), 1e-13, ln_pdf(1.0)); - test_almost(0.1, 1.0, 0.0396064560910663979961f64.ln(), 1e-15, ln_pdf(1.0)); - test_almost(1.0, 1.0, 0.1591549430918953357689f64.ln(), 1e-15, ln_pdf(1.0)); - test_almost(10.0, 1.0, 0.230361989229138647108f64.ln(), 1e-15, ln_pdf(1.0)); - test_case(0.1, 0.1, 0.00221546909694001013517f64.ln(), ln_pdf(10.0)); - test_almost(1.0, 0.1, 0.00369960370387922619592f64.ln(), 1e-15, ln_pdf(10.0)); - test_almost(10.0, 0.1, 0.00390179721174142927402f64.ln(), 1e-13, ln_pdf(10.0)); - test_almost(0.1, 1.0, 0.00319864073359931548273f64.ln(), 1e-15, ln_pdf(10.0)); - test_almost(1.0, 1.0, 0.009150765837179460915678f64.ln(), 1e-15, ln_pdf(10.0)); - test_case(10.0, 1.0, 0.0116493859171442148446f64.ln(), ln_pdf(10.0)); - test_almost(0.1, 10.0, 0.00305087016058573989694f64.ln(), 1e-13, ln_pdf(10.0)); - test_case(1.0, 10.0, 0.00271897749113479577864f64.ln(), ln_pdf(10.0)); - test_almost(10.0, 10.0, 2.4289227234060500084E-4f64.ln(), 1e-14, ln_pdf(10.0)); + test_absolute(0.1, 0.1, 0.0234154207226588982471f64.ln(), 1e-15, ln_pdf(1.0)); + test_absolute(1.0, 0.1, 0.0396064560910663979961f64.ln(), 1e-15, ln_pdf(1.0)); + test_absolute(10.0, 0.1, 0.0418440630400545297349f64.ln(), 1e-13, ln_pdf(1.0)); + test_absolute(0.1, 1.0, 0.0396064560910663979961f64.ln(), 1e-15, ln_pdf(1.0)); + test_absolute(1.0, 1.0, 0.1591549430918953357689f64.ln(), 1e-15, ln_pdf(1.0)); + test_absolute(10.0, 1.0, 0.230361989229138647108f64.ln(), 1e-15, ln_pdf(1.0)); + test_exact(0.1, 0.1, 0.00221546909694001013517f64.ln(), ln_pdf(10.0)); + test_absolute(1.0, 0.1, 0.00369960370387922619592f64.ln(), 1e-15, ln_pdf(10.0)); + test_absolute(10.0, 0.1, 0.00390179721174142927402f64.ln(), 1e-13, ln_pdf(10.0)); + test_absolute(0.1, 1.0, 0.00319864073359931548273f64.ln(), 1e-15, ln_pdf(10.0)); + test_absolute(1.0, 1.0, 0.009150765837179460915678f64.ln(), 1e-15, ln_pdf(10.0)); + test_exact(10.0, 1.0, 0.0116493859171442148446f64.ln(), ln_pdf(10.0)); + test_absolute(0.1, 10.0, 0.00305087016058573989694f64.ln(), 1e-13, ln_pdf(10.0)); + test_exact(1.0, 10.0, 0.00271897749113479577864f64.ln(), ln_pdf(10.0)); + test_absolute(10.0, 10.0, 2.4289227234060500084E-4f64.ln(), 1e-14, ln_pdf(10.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: FisherSnedecor| x.cdf(arg); - test_almost(0.1, 0.1, 0.44712986033425140335, 1e-15, cdf(0.1)); - test_almost(1.0, 0.1, 0.08156522095104674015, 1e-15, cdf(0.1)); - test_almost(10.0, 0.1, 0.033184005716276536322, 1e-13, cdf(0.1)); - test_almost(0.1, 1.0, 0.74378710917986379989, 1e-15, cdf(0.1)); - test_almost(1.0, 1.0, 0.1949822290421366451595, 1e-16, cdf(0.1)); - test_almost(10.0, 1.0, 0.0101195597354337146205, 1e-17, cdf(0.1)); - test_almost(0.1, 0.1, 0.5, 1e-15, cdf(1.0)); - test_almost(1.0, 0.1, 0.16734351500944271141, 1e-14, cdf(1.0)); - test_almost(10.0, 0.1, 0.12207560664741704938, 1e-13, cdf(1.0)); - test_almost(0.1, 1.0, 0.83265648499055728859, 1e-15, cdf(1.0)); - test_almost(1.0, 1.0, 0.5, 1e-15, cdf(1.0)); - test_almost(10.0, 1.0, 0.340893132302059872675, 1e-15, cdf(1.0)); + test_absolute(0.1, 0.1, 0.44712986033425140335, 1e-15, cdf(0.1)); + test_absolute(1.0, 0.1, 0.08156522095104674015, 1e-15, cdf(0.1)); + test_absolute(10.0, 0.1, 0.033184005716276536322, 1e-13, cdf(0.1)); + test_absolute(0.1, 1.0, 0.74378710917986379989, 1e-15, cdf(0.1)); + test_absolute(1.0, 1.0, 0.1949822290421366451595, 1e-16, cdf(0.1)); + test_absolute(10.0, 1.0, 0.0101195597354337146205, 1e-17, cdf(0.1)); + test_absolute(0.1, 0.1, 0.5, 1e-15, cdf(1.0)); + test_absolute(1.0, 0.1, 0.16734351500944271141, 1e-14, cdf(1.0)); + test_absolute(10.0, 0.1, 0.12207560664741704938, 1e-13, cdf(1.0)); + test_absolute(0.1, 1.0, 0.83265648499055728859, 1e-15, cdf(1.0)); + test_absolute(1.0, 1.0, 0.5, 1e-15, cdf(1.0)); + test_absolute(10.0, 1.0, 0.340893132302059872675, 1e-15, cdf(1.0)); } #[test] fn test_cdf_lower_bound() { let cdf = |arg: f64| move |x: FisherSnedecor| x.cdf(arg); - test_case(0.1, 0.1, 0.0, cdf(-1.0)); + test_exact(0.1, 0.1, 0.0, cdf(-1.0)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: FisherSnedecor| x.sf(arg); - test_almost(0.1, 0.1, 0.5528701396657489, 1e-12, sf(0.1)); - test_almost(1.0, 0.1, 0.9184347790489533, 1e-12, sf(0.1)); - test_almost(10.0, 0.1, 0.9668159942836896, 1e-12, sf(0.1)); - test_almost(0.1, 1.0, 0.25621289082013654, 1e-12, sf(0.1)); - test_almost(1.0, 1.0, 0.8050177709578634, 1e-12, sf(0.1)); - test_almost(10.0, 1.0, 0.9898804402645662, 1e-12, sf(0.1)); - test_almost(0.1, 0.1, 0.5, 1e-15, sf(1.0)); - test_almost(1.0, 0.1, 0.8326564849905562, 1e-12, sf(1.0)); - test_almost(10.0, 0.1, 0.8779243933525519, 1e-12, sf(1.0)); - test_almost(0.1, 1.0, 0.16734351500944344, 1e-12, sf(1.0)); - test_almost(1.0, 1.0, 0.5, 1e-12, sf(1.0)); - test_almost(10.0, 1.0, 0.65910686769794, 1e-12, sf(1.0)); + test_absolute(0.1, 0.1, 0.5528701396657489, 1e-12, sf(0.1)); + test_absolute(1.0, 0.1, 0.9184347790489533, 1e-12, sf(0.1)); + test_absolute(10.0, 0.1, 0.9668159942836896, 1e-12, sf(0.1)); + test_absolute(0.1, 1.0, 0.25621289082013654, 1e-12, sf(0.1)); + test_absolute(1.0, 1.0, 0.8050177709578634, 1e-12, sf(0.1)); + test_absolute(10.0, 1.0, 0.9898804402645662, 1e-12, sf(0.1)); + test_absolute(0.1, 0.1, 0.5, 1e-15, sf(1.0)); + test_absolute(1.0, 0.1, 0.8326564849905562, 1e-12, sf(1.0)); + test_absolute(10.0, 0.1, 0.8779243933525519, 1e-12, sf(1.0)); + test_absolute(0.1, 1.0, 0.16734351500944344, 1e-12, sf(1.0)); + test_absolute(1.0, 1.0, 0.5, 1e-12, sf(1.0)); + test_absolute(10.0, 1.0, 0.65910686769794, 1e-12, sf(1.0)); } #[test] fn test_inverse_cdf() { let func = |arg: f64| move |x: FisherSnedecor| x.inverse_cdf(x.cdf(arg)); - test_almost(0.1, 0.1, 0.1, 1e-12, func(0.1)); - test_almost(1.0, 0.1, 0.1, 1e-12, func(0.1)); - test_almost(10.0, 0.1, 0.1, 1e-12, func(0.1)); - test_almost(0.1, 1.0, 0.1, 1e-12, func(0.1)); - test_almost(1.0, 1.0, 0.1, 1e-12, func(0.1)); - test_almost(10.0, 1.0, 0.1, 1e-12, func(0.1)); - test_almost(0.1, 0.1, 1.0, 1e-13, func(1.0)); - test_almost(1.0, 0.1, 1.0, 1e-12, func(1.0)); - test_almost(10.0, 0.1, 1.0, 1e-12, func(1.0)); - test_almost(0.1, 1.0, 1.0, 1e-12, func(1.0)); - test_almost(1.0, 1.0, 1.0, 1e-12, func(1.0)); - test_almost(10.0, 1.0, 1.0, 1e-12, func(1.0)); + test_absolute(0.1, 0.1, 0.1, 1e-12, func(0.1)); + test_absolute(1.0, 0.1, 0.1, 1e-12, func(0.1)); + test_absolute(10.0, 0.1, 0.1, 1e-12, func(0.1)); + test_absolute(0.1, 1.0, 0.1, 1e-12, func(0.1)); + test_absolute(1.0, 1.0, 0.1, 1e-12, func(0.1)); + test_absolute(10.0, 1.0, 0.1, 1e-12, func(0.1)); + test_absolute(0.1, 0.1, 1.0, 1e-13, func(1.0)); + test_absolute(1.0, 0.1, 1.0, 1e-12, func(1.0)); + test_absolute(10.0, 0.1, 1.0, 1e-12, func(1.0)); + test_absolute(0.1, 1.0, 1.0, 1e-12, func(1.0)); + test_absolute(1.0, 1.0, 1.0, 1e-12, func(1.0)); + test_absolute(10.0, 1.0, 1.0, 1e-12, func(1.0)); } #[test] fn test_sf_lower_bound() { let sf = |arg: f64| move |x: FisherSnedecor| x.sf(arg); - test_case(0.1, 0.1, 1.0, sf(-1.0)); + test_exact(0.1, 0.1, 1.0, sf(-1.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(10.0, 10.0), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(10.0, 10.0), 0.0, 10.0); } } From 8d603c08b17f0fb4c9c18e06b9073912659f07ed Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 19:44:57 +0200 Subject: [PATCH 151/185] Use testing_boiler! for Geometric --- src/distribution/geometric.rs | 146 ++++++++++++---------------------- 1 file changed, 51 insertions(+), 95 deletions(-) diff --git a/src/distribution/geometric.rs b/src/distribution/geometric.rs index 4df623ed..f584cc0e 100644 --- a/src/distribution/geometric.rs +++ b/src/distribution/geometric.rs @@ -273,173 +273,129 @@ impl Discrete for Geometric { #[rustfmt::skip] #[cfg(test)] mod tests { - use std::fmt::Debug; - use crate::statistics::*; use crate::distribution::{DiscreteCDF, Discrete, Geometric}; use crate::distribution::internal::*; + use crate::statistics::*; + use crate::testing_boiler; - fn try_create(p: f64) -> Geometric { - let n = Geometric::new(p); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(p: f64) { - let n = try_create(p); - assert_eq!(p, n.p()); - } - - fn bad_create_case(p: f64) { - let n = Geometric::new(p); - assert!(n.is_err()); - } - - fn get_value(p: f64, eval: F) -> T - where T: PartialEq + Debug, - F: Fn(Geometric) -> T - { - let n = try_create(p); - eval(n) - } - - fn test_case(p: f64, expected: T, eval: F) - where T: PartialEq + Debug, - F: Fn(Geometric) -> T - { - let x = get_value(p, eval); - assert_eq!(expected, x); - } - - fn test_almost(p: f64, expected: f64, acc: f64, eval: F) - where F: Fn(Geometric) -> f64 - { - let x = get_value(p, eval); - assert_almost_eq!(expected, x, acc); - } - - fn test_is_nan(p: f64, eval: F) - where F: Fn(Geometric) -> f64 - { - let x = get_value(p, eval); - assert!(x.is_nan()); - } + testing_boiler!(p: f64; Geometric); #[test] fn test_create() { - create_case(0.3); - create_case(1.0); + create_ok(0.3); + create_ok(1.0); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN); - bad_create_case(0.0); - bad_create_case(-1.0); - bad_create_case(2.0); + create_err(f64::NAN); + create_err(0.0); + create_err(-1.0); + create_err(2.0); } #[test] fn test_mean() { let mean = |x: Geometric| x.mean().unwrap(); - test_case(0.3, 1.0 / 0.3, mean); - test_case(1.0, 1.0, mean); + test_exact(0.3, 1.0 / 0.3, mean); + test_exact(1.0, 1.0, mean); } #[test] fn test_variance() { let variance = |x: Geometric| x.variance().unwrap(); - test_case(0.3, 0.7 / (0.3 * 0.3), variance); - test_case(1.0, 0.0, variance); + test_exact(0.3, 0.7 / (0.3 * 0.3), variance); + test_exact(1.0, 0.0, variance); } #[test] fn test_entropy() { let entropy = |x: Geometric| x.entropy().unwrap(); - test_almost(0.3, 2.937636330768973333333, 1e-14, entropy); + test_absolute(0.3, 2.937636330768973333333, 1e-14, entropy); test_is_nan(1.0, entropy); } #[test] fn test_skewness() { let skewness = |x: Geometric| x.skewness().unwrap(); - test_almost(0.3, 2.031888635868469187947, 1e-15, skewness); - test_case(1.0, f64::INFINITY, skewness); + test_absolute(0.3, 2.031888635868469187947, 1e-15, skewness); + test_exact(1.0, f64::INFINITY, skewness); } #[test] fn test_median() { let median = |x: Geometric| x.median(); - test_case(0.0001, 6932.0, median); - test_case(0.1, 7.0, median); - test_case(0.3, 2.0, median); - test_case(0.9, 1.0, median); - // test_case(0.99, 1.0, median); - test_case(1.0, 0.0, median); + test_exact(0.0001, 6932.0, median); + test_exact(0.1, 7.0, median); + test_exact(0.3, 2.0, median); + test_exact(0.9, 1.0, median); + // test_exact(0.99, 1.0, median); + test_exact(1.0, 0.0, median); } #[test] fn test_mode() { let mode = |x: Geometric| x.mode().unwrap(); - test_case(0.3, 1, mode); - test_case(1.0, 1, mode); + test_exact(0.3, 1, mode); + test_exact(1.0, 1, mode); } #[test] fn test_min_max() { let min = |x: Geometric| x.min(); let max = |x: Geometric| x.max(); - test_case(0.3, 1, min); - test_case(0.3, u64::MAX, max); + test_exact(0.3, 1, min); + test_exact(0.3, u64::MAX, max); } #[test] fn test_pmf() { let pmf = |arg: u64| move |x: Geometric| x.pmf(arg); - test_case(0.3, 0.3, pmf(1)); - test_case(0.3, 0.21, pmf(2)); - test_case(1.0, 1.0, pmf(1)); - test_case(1.0, 0.0, pmf(2)); - test_almost(0.5, 0.5, 1e-10, pmf(1)); - test_almost(0.5, 0.25, 1e-10, pmf(2)); + test_exact(0.3, 0.3, pmf(1)); + test_exact(0.3, 0.21, pmf(2)); + test_exact(1.0, 1.0, pmf(1)); + test_exact(1.0, 0.0, pmf(2)); + test_absolute(0.5, 0.5, 1e-10, pmf(1)); + test_absolute(0.5, 0.25, 1e-10, pmf(2)); } #[test] fn test_pmf_lower_bound() { let pmf = |arg: u64| move |x: Geometric| x.pmf(arg); - test_case(0.3, 0.0, pmf(0)); + test_exact(0.3, 0.0, pmf(0)); } #[test] fn test_ln_pmf() { let ln_pmf = |arg: u64| move |x: Geometric| x.ln_pmf(arg); - test_almost(0.3, -1.203972804325935992623, 1e-15, ln_pmf(1)); - test_almost(0.3, -1.560647748264668371535, 1e-15, ln_pmf(2)); - test_case(1.0, 0.0, ln_pmf(1)); - test_case(1.0, f64::NEG_INFINITY, ln_pmf(2)); + test_absolute(0.3, -1.203972804325935992623, 1e-15, ln_pmf(1)); + test_absolute(0.3, -1.560647748264668371535, 1e-15, ln_pmf(2)); + test_exact(1.0, 0.0, ln_pmf(1)); + test_exact(1.0, f64::NEG_INFINITY, ln_pmf(2)); } #[test] fn test_ln_pmf_lower_bound() { let ln_pmf = |arg: u64| move |x: Geometric| x.ln_pmf(arg); - test_case(0.3, f64::NEG_INFINITY, ln_pmf(0)); + test_exact(0.3, f64::NEG_INFINITY, ln_pmf(0)); } #[test] fn test_cdf() { let cdf = |arg: u64| move |x: Geometric| x.cdf(arg); - test_case(1.0, 1.0, cdf(1)); - test_case(1.0, 1.0, cdf(2)); - test_almost(0.5, 0.5, 1e-15, cdf(1)); - test_almost(0.5, 0.75, 1e-15, cdf(2)); + test_exact(1.0, 1.0, cdf(1)); + test_exact(1.0, 1.0, cdf(2)); + test_absolute(0.5, 0.5, 1e-15, cdf(1)); + test_absolute(0.5, 0.75, 1e-15, cdf(2)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: Geometric| x.sf(arg); - test_case(1.0, 0.0, sf(1)); - test_case(1.0, 0.0, sf(2)); - test_almost(0.5, 0.5, 1e-15, sf(1)); - test_almost(0.5, 0.25, 1e-15, sf(2)); + test_exact(1.0, 0.0, sf(1)); + test_exact(1.0, 0.0, sf(2)); + test_absolute(0.5, 0.5, 1e-15, sf(1)); + test_absolute(0.5, 0.25, 1e-15, sf(2)); } #[test] @@ -511,19 +467,19 @@ mod tests { #[test] fn test_cdf_lower_bound() { let cdf = |arg: u64| move |x: Geometric| x.cdf(arg); - test_case(0.3, 0.0, cdf(0)); + test_exact(0.3, 0.0, cdf(0)); } #[test] fn test_sf_lower_bound() { let sf = |arg: u64| move |x: Geometric| x.sf(arg); - test_case(0.3, 1.0, sf(0)); + test_exact(0.3, 1.0, sf(0)); } #[test] fn test_discrete() { - test::check_discrete_distribution(&try_create(0.3), 100); - test::check_discrete_distribution(&try_create(0.6), 100); - test::check_discrete_distribution(&try_create(1.0), 1); + test::check_discrete_distribution(&create_ok(0.3), 100); + test::check_discrete_distribution(&create_ok(0.6), 100); + test::check_discrete_distribution(&create_ok(1.0), 1); } } From 94cbf56e1ffe3ee6e9fe28bdabd3c72d208f3c82 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 19:51:19 +0200 Subject: [PATCH 152/185] Use testing_boiler! for HyperGeometric --- src/distribution/hypergeometric.rs | 205 +++++++++++------------------ 1 file changed, 80 insertions(+), 125 deletions(-) diff --git a/src/distribution/hypergeometric.rs b/src/distribution/hypergeometric.rs index 8b6d8500..800e03fe 100644 --- a/src/distribution/hypergeometric.rs +++ b/src/distribution/hypergeometric.rs @@ -378,226 +378,181 @@ impl Discrete for Hypergeometric { #[rustfmt::skip] #[cfg(test)] mod tests { - use std::fmt::Debug; - use crate::statistics::*; use crate::distribution::{DiscreteCDF, Discrete, Hypergeometric}; use crate::distribution::internal::*; + use crate::statistics::*; + use crate::testing_boiler; - fn try_create(population: u64, successes: u64, draws: u64) -> Hypergeometric { - let n = Hypergeometric::new(population, successes, draws); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(population: u64, successes: u64, draws: u64) { - let n = try_create(population, successes, draws); - assert_eq!(population, n.population()); - assert_eq!(successes, n.successes()); - assert_eq!(draws, n.draws()); - } - - fn bad_create_case(population: u64, successes: u64, draws: u64) { - let n = Hypergeometric::new(population, successes, draws); - assert!(n.is_err()); - } - - fn get_value(population: u64, successes: u64, draws: u64, eval: F) -> T - where T: PartialEq + Debug, - F: Fn(Hypergeometric) -> T - { - let n = try_create(population, successes, draws); - eval(n) - } - - fn test_case(population: u64, successes: u64, draws: u64, expected: T, eval: F) - where T: PartialEq + Debug, - F: Fn(Hypergeometric) -> T - { - let x = get_value(population, successes, draws, eval); - assert_eq!(expected, x); - } - - fn test_almost(population: u64, successes: u64, draws: u64, expected: f64, acc: f64, eval: F) - where F: Fn(Hypergeometric) -> f64 - { - let x = get_value(population, successes, draws, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(population: u64, successes: u64, draws: u64; Hypergeometric); #[test] fn test_create() { - create_case(0, 0, 0); - create_case(1, 1, 1,); - create_case(2, 1, 1); - create_case(2, 2, 2); - create_case(10, 1, 1); - create_case(10, 5, 3); + create_ok(0, 0, 0); + create_ok(1, 1, 1,); + create_ok(2, 1, 1); + create_ok(2, 2, 2); + create_ok(10, 1, 1); + create_ok(10, 5, 3); } #[test] fn test_bad_create() { - bad_create_case(2, 3, 2); - bad_create_case(10, 5, 20); - bad_create_case(0, 1, 1); + create_err(2, 3, 2); + create_err(10, 5, 20); + create_err(0, 1, 1); } #[test] fn test_mean() { let mean = |x: Hypergeometric| x.mean().unwrap(); - test_case(1, 1, 1, 1.0, mean); - test_case(2, 1, 1, 0.5, mean); - test_case(2, 2, 2, 2.0, mean); - test_case(10, 1, 1, 0.1, mean); - test_case(10, 5, 3, 15.0 / 10.0, mean); + test_exact(1, 1, 1, 1.0, mean); + test_exact(2, 1, 1, 0.5, mean); + test_exact(2, 2, 2, 2.0, mean); + test_exact(10, 1, 1, 0.1, mean); + test_exact(10, 5, 3, 15.0 / 10.0, mean); } #[test] - #[should_panic] fn test_mean_with_population_0() { - let mean = |x: Hypergeometric| x.mean().unwrap(); - get_value(0, 0, 0, mean); + test_none(0, 0, 0, |dist| dist.mean()); } #[test] fn test_variance() { let variance = |x: Hypergeometric| x.variance().unwrap(); - test_case(2, 1, 1, 0.25, variance); - test_case(2, 2, 2, 0.0, variance); - test_case(10, 1, 1, 81.0 / 900.0, variance); - test_case(10, 5, 3, 525.0 / 900.0, variance); + test_exact(2, 1, 1, 0.25, variance); + test_exact(2, 2, 2, 0.0, variance); + test_exact(10, 1, 1, 81.0 / 900.0, variance); + test_exact(10, 5, 3, 525.0 / 900.0, variance); } #[test] - #[should_panic] fn test_variance_with_pop_lte_1() { - let variance = |x: Hypergeometric| x.variance().unwrap(); - get_value(1, 1, 1, variance); + test_none(1, 1, 1, |dist| dist.variance()); } #[test] fn test_skewness() { let skewness = |x: Hypergeometric| x.skewness().unwrap(); - test_case(10, 1, 1, 8.0 / 3.0, skewness); - test_case(10, 5, 3, 0.0, skewness); + test_exact(10, 1, 1, 8.0 / 3.0, skewness); + test_exact(10, 5, 3, 0.0, skewness); } #[test] - #[should_panic] fn test_skewness_with_pop_lte_2() { - let skewness = |x: Hypergeometric| x.skewness().unwrap(); - get_value(2, 2, 2, skewness); + test_none(2, 2, 2, |dist| dist.skewness()); } #[test] fn test_mode() { let mode = |x: Hypergeometric| x.mode().unwrap(); - test_case(0, 0, 0, 0, mode); - test_case(1, 1, 1, 1, mode); - test_case(2, 1, 1, 1, mode); - test_case(2, 2, 2, 2, mode); - test_case(10, 1, 1, 0, mode); - test_case(10, 5, 3, 2, mode); + test_exact(0, 0, 0, 0, mode); + test_exact(1, 1, 1, 1, mode); + test_exact(2, 1, 1, 1, mode); + test_exact(2, 2, 2, 2, mode); + test_exact(10, 1, 1, 0, mode); + test_exact(10, 5, 3, 2, mode); } #[test] fn test_min() { let min = |x: Hypergeometric| x.min(); - test_case(0, 0, 0, 0, min); - test_case(1, 1, 1, 1, min); - test_case(2, 1, 1, 0, min); - test_case(2, 2, 2, 2, min); - test_case(10, 1, 1, 0, min); - test_case(10, 5, 3, 0, min); + test_exact(0, 0, 0, 0, min); + test_exact(1, 1, 1, 1, min); + test_exact(2, 1, 1, 0, min); + test_exact(2, 2, 2, 2, min); + test_exact(10, 1, 1, 0, min); + test_exact(10, 5, 3, 0, min); } #[test] fn test_max() { let max = |x: Hypergeometric| x.max(); - test_case(0, 0, 0, 0, max); - test_case(1, 1, 1, 1, max); - test_case(2, 1, 1, 1, max); - test_case(2, 2, 2, 2, max); - test_case(10, 1, 1, 1, max); - test_case(10, 5, 3, 3, max); + test_exact(0, 0, 0, 0, max); + test_exact(1, 1, 1, 1, max); + test_exact(2, 1, 1, 1, max); + test_exact(2, 2, 2, 2, max); + test_exact(10, 1, 1, 1, max); + test_exact(10, 5, 3, 3, max); } #[test] fn test_pmf() { let pmf = |arg: u64| move |x: Hypergeometric| x.pmf(arg); - test_case(0, 0, 0, 1.0, pmf(0)); - test_case(1, 1, 1, 1.0, pmf(1)); - test_case(2, 1, 1, 0.5, pmf(0)); - test_case(2, 1, 1, 0.5, pmf(1)); - test_case(2, 2, 2, 1.0, pmf(2)); - test_case(10, 1, 1, 0.9, pmf(0)); - test_case(10, 1, 1, 0.1, pmf(1)); - test_case(10, 5, 3, 0.41666666666666666667, pmf(1)); - test_case(10, 5, 3, 0.083333333333333333333, pmf(3)); + test_exact(0, 0, 0, 1.0, pmf(0)); + test_exact(1, 1, 1, 1.0, pmf(1)); + test_exact(2, 1, 1, 0.5, pmf(0)); + test_exact(2, 1, 1, 0.5, pmf(1)); + test_exact(2, 2, 2, 1.0, pmf(2)); + test_exact(10, 1, 1, 0.9, pmf(0)); + test_exact(10, 1, 1, 0.1, pmf(1)); + test_exact(10, 5, 3, 0.41666666666666666667, pmf(1)); + test_exact(10, 5, 3, 0.083333333333333333333, pmf(3)); } #[test] fn test_ln_pmf() { let ln_pmf = |arg: u64| move |x: Hypergeometric| x.ln_pmf(arg); - test_case(0, 0, 0, 0.0, ln_pmf(0)); - test_case(1, 1, 1, 0.0, ln_pmf(1)); - test_case(2, 1, 1, -0.6931471805599453094172, ln_pmf(0)); - test_case(2, 1, 1, -0.6931471805599453094172, ln_pmf(1)); - test_case(2, 2, 2, 0.0, ln_pmf(2)); - test_almost(10, 1, 1, -0.1053605156578263012275, 1e-14, ln_pmf(0)); - test_almost(10, 1, 1, -2.302585092994045684018, 1e-14, ln_pmf(1)); - test_almost(10, 5, 3, -0.875468737353899935621, 1e-14, ln_pmf(1)); - test_almost(10, 5, 3, -2.484906649788000310234, 1e-14, ln_pmf(3)); + test_exact(0, 0, 0, 0.0, ln_pmf(0)); + test_exact(1, 1, 1, 0.0, ln_pmf(1)); + test_exact(2, 1, 1, -0.6931471805599453094172, ln_pmf(0)); + test_exact(2, 1, 1, -0.6931471805599453094172, ln_pmf(1)); + test_exact(2, 2, 2, 0.0, ln_pmf(2)); + test_absolute(10, 1, 1, -0.1053605156578263012275, 1e-14, ln_pmf(0)); + test_absolute(10, 1, 1, -2.302585092994045684018, 1e-14, ln_pmf(1)); + test_absolute(10, 5, 3, -0.875468737353899935621, 1e-14, ln_pmf(1)); + test_absolute(10, 5, 3, -2.484906649788000310234, 1e-14, ln_pmf(3)); } #[test] fn test_cdf() { let cdf = |arg: u64| move |x: Hypergeometric| x.cdf(arg); - test_case(2, 1, 1, 0.5, cdf(0)); - test_almost(10, 1, 1, 0.9, 1e-14, cdf(0)); - test_almost(10, 5, 3, 0.5, 1e-15, cdf(1)); - test_almost(10, 5, 3, 11.0 / 12.0, 1e-14, cdf(2)); - test_almost(10000, 2, 9800, 199.0 / 499950.0, 1e-14, cdf(0)); - test_almost(10000, 2, 9800, 19799.0 / 499950.0, 1e-12, cdf(1)); + test_exact(2, 1, 1, 0.5, cdf(0)); + test_absolute(10, 1, 1, 0.9, 1e-14, cdf(0)); + test_absolute(10, 5, 3, 0.5, 1e-15, cdf(1)); + test_absolute(10, 5, 3, 11.0 / 12.0, 1e-14, cdf(2)); + test_absolute(10000, 2, 9800, 199.0 / 499950.0, 1e-14, cdf(0)); + test_absolute(10000, 2, 9800, 19799.0 / 499950.0, 1e-12, cdf(1)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: Hypergeometric| x.sf(arg); - test_case(2, 1, 1, 0.5, sf(0)); - test_almost(10, 1, 1, 0.1, 1e-14, sf(0)); - test_almost(10, 5, 3, 0.5, 1e-15, sf(1)); - test_almost(10, 5, 3, 1.0 / 12.0, 1e-14, sf(2)); - test_almost(10000, 2, 9800, 499751. / 499950.0, 1e-10, sf(0)); - test_almost(10000, 2, 9800, 480151. / 499950.0, 1e-10, sf(1)); + test_exact(2, 1, 1, 0.5, sf(0)); + test_absolute(10, 1, 1, 0.1, 1e-14, sf(0)); + test_absolute(10, 5, 3, 0.5, 1e-15, sf(1)); + test_absolute(10, 5, 3, 1.0 / 12.0, 1e-14, sf(2)); + test_absolute(10000, 2, 9800, 499751. / 499950.0, 1e-10, sf(0)); + test_absolute(10000, 2, 9800, 480151. / 499950.0, 1e-10, sf(1)); } #[test] fn test_cdf_arg_too_big() { let cdf = |arg: u64| move |x: Hypergeometric| x.cdf(arg); - test_case(0, 0, 0, 1.0, cdf(0)); + test_exact(0, 0, 0, 1.0, cdf(0)); } #[test] fn test_cdf_arg_too_small() { let cdf = |arg: u64| move |x: Hypergeometric| x.cdf(arg); - test_case(2, 2, 2, 0.0, cdf(0)); + test_exact(2, 2, 2, 0.0, cdf(0)); } #[test] fn test_sf_arg_too_big() { let sf = |arg: u64| move |x: Hypergeometric| x.sf(arg); - test_case(0, 0, 0, 0.0, sf(0)); + test_exact(0, 0, 0, 0.0, sf(0)); } #[test] fn test_sf_arg_too_small() { let sf = |arg: u64| move |x: Hypergeometric| x.sf(arg); - test_case(2, 2, 2, 1.0, sf(0)); + test_exact(2, 2, 2, 1.0, sf(0)); } #[test] fn test_discrete() { - test::check_discrete_distribution(&try_create(5, 4, 3), 4); - test::check_discrete_distribution(&try_create(3, 2, 1), 2); + test::check_discrete_distribution(&create_ok(5, 4, 3), 4); + test::check_discrete_distribution(&create_ok(3, 2, 1), 2); } } From ce1b1a48cba16f857fe5e3dcc1a8ad042e1434ff Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 19:53:16 +0200 Subject: [PATCH 153/185] Use testing_boiler! for InverseGamma --- src/distribution/inverse_gamma.rs | 145 +++++++++++------------------- 1 file changed, 52 insertions(+), 93 deletions(-) diff --git a/src/distribution/inverse_gamma.rs b/src/distribution/inverse_gamma.rs index d22d2239..8314a9dc 100644 --- a/src/distribution/inverse_gamma.rs +++ b/src/distribution/inverse_gamma.rs @@ -313,176 +313,135 @@ impl Continuous for InverseGamma { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, InverseGamma}; use crate::distribution::internal::*; + use crate::statistics::*; + use crate::testing_boiler; - fn try_create(shape: f64, rate: f64) -> InverseGamma { - let n = InverseGamma::new(shape, rate); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(shape: f64, rate: f64) { - let n = try_create(shape, rate); - assert_eq!(shape, n.shape()); - assert_eq!(rate, n.rate()); - } - - fn bad_create_case(shape: f64, rate: f64) { - let n = InverseGamma::new(shape, rate); - assert!(n.is_err()); - } - - fn get_value(shape: f64, rate: f64, eval: F) -> f64 - where F: Fn(InverseGamma) -> f64 - { - let n = try_create(shape, rate); - eval(n) - } - - fn test_case(shape: f64, rate: f64, expected: f64, eval: F) - where F: Fn(InverseGamma) -> f64 - { - let x = get_value(shape, rate, eval); - assert_eq!(expected, x); - } - - fn test_almost(shape: f64, rate: f64, expected: f64, acc: f64, eval: F) - where F: Fn(InverseGamma) -> f64 - { - let x = get_value(shape, rate, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(shape: f64, rate: f64; InverseGamma); #[test] fn test_create() { - create_case(0.1, 0.1); - create_case(1.0, 1.0); + create_ok(0.1, 0.1); + create_ok(1.0, 1.0); } #[test] fn test_bad_create() { - bad_create_case(0.0, 1.0); - bad_create_case(-1.0, 1.0); - bad_create_case(-100.0, 1.0); - bad_create_case(f64::NEG_INFINITY, 1.0); - bad_create_case(f64::NAN, 1.0); - bad_create_case(1.0, 0.0); - bad_create_case(1.0, -1.0); - bad_create_case(1.0, -100.0); - bad_create_case(1.0, f64::NEG_INFINITY); - bad_create_case(1.0, f64::NAN); - bad_create_case(f64::INFINITY, 1.0); - bad_create_case(1.0, f64::INFINITY); - bad_create_case(f64::INFINITY, f64::INFINITY); + create_err(0.0, 1.0); + create_err(-1.0, 1.0); + create_err(-100.0, 1.0); + create_err(f64::NEG_INFINITY, 1.0); + create_err(f64::NAN, 1.0); + create_err(1.0, 0.0); + create_err(1.0, -1.0); + create_err(1.0, -100.0); + create_err(1.0, f64::NEG_INFINITY); + create_err(1.0, f64::NAN); + create_err(f64::INFINITY, 1.0); + create_err(1.0, f64::INFINITY); + create_err(f64::INFINITY, f64::INFINITY); } #[test] fn test_mean() { let mean = |x: InverseGamma| x.mean().unwrap(); - test_almost(1.1, 0.1, 1.0, 1e-14, mean); - test_almost(1.1, 1.0, 10.0, 1e-14, mean); + test_absolute(1.1, 0.1, 1.0, 1e-14, mean); + test_absolute(1.1, 1.0, 10.0, 1e-14, mean); } #[test] - #[should_panic] fn test_mean_with_shape_lte_1() { - let mean = |x: InverseGamma| x.mean().unwrap(); - get_value(0.1, 0.1, mean); + test_none(0.1, 0.1, |dist| dist.mean()); } #[test] fn test_variance() { let variance = |x: InverseGamma| x.variance().unwrap(); - test_almost(2.1, 0.1, 0.08264462809917355371901, 1e-15, variance); - test_almost(2.1, 1.0, 8.264462809917355371901, 1e-13, variance); + test_absolute(2.1, 0.1, 0.08264462809917355371901, 1e-15, variance); + test_absolute(2.1, 1.0, 8.264462809917355371901, 1e-13, variance); } #[test] - #[should_panic] fn test_variance_with_shape_lte_2() { - let variance = |x: InverseGamma| x.variance().unwrap(); - get_value(0.1, 0.1, variance); + test_none(0.1, 0.1, |dist| dist.variance()); } #[test] fn test_entropy() { let entropy = |x: InverseGamma| x.entropy().unwrap(); - test_almost(0.1, 0.1, 11.51625799319234475054, 1e-14, entropy); - test_almost(1.0, 1.0, 2.154431329803065721213, 1e-14, entropy); + test_absolute(0.1, 0.1, 11.51625799319234475054, 1e-14, entropy); + test_absolute(1.0, 1.0, 2.154431329803065721213, 1e-14, entropy); } #[test] fn test_skewness() { let skewness = |x: InverseGamma| x.skewness().unwrap(); - test_almost(3.1, 0.1, 41.95235392680606187966, 1e-13, skewness); - test_almost(3.1, 1.0, 41.95235392680606187966, 1e-13, skewness); - test_case(5.0, 0.1, 3.464101615137754587055, skewness); + test_absolute(3.1, 0.1, 41.95235392680606187966, 1e-13, skewness); + test_absolute(3.1, 1.0, 41.95235392680606187966, 1e-13, skewness); + test_exact(5.0, 0.1, 3.464101615137754587055, skewness); } #[test] - #[should_panic] fn test_skewness_with_shape_lte_3() { - let skewness = |x: InverseGamma| x.skewness().unwrap(); - get_value(0.1, 0.1, skewness); + test_none(0.1, 0.1, |dist| dist.skewness()); } #[test] fn test_mode() { let mode = |x: InverseGamma| x.mode().unwrap(); - test_case(0.1, 0.1, 0.09090909090909090909091, mode); - test_case(1.0, 1.0, 0.5, mode); + test_exact(0.1, 0.1, 0.09090909090909090909091, mode); + test_exact(1.0, 1.0, 0.5, mode); } #[test] fn test_min_max() { let min = |x: InverseGamma| x.min(); let max = |x: InverseGamma| x.max(); - test_case(1.0, 1.0, 0.0, min); - test_case(1.0, 1.0, f64::INFINITY, max); + test_exact(1.0, 1.0, 0.0, min); + test_exact(1.0, 1.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: InverseGamma| x.pdf(arg); - test_almost(0.1, 0.1, 0.0628591853882328004197, 1e-15, pdf(1.2)); - test_almost(0.1, 1.0, 0.0297426109178248997426, 1e-15, pdf(2.0)); - test_case(1.0, 0.1, 0.04157808822362745501024, pdf(1.5)); - test_case(1.0, 1.0, 0.3018043114632487660842, pdf(1.2)); + test_absolute(0.1, 0.1, 0.0628591853882328004197, 1e-15, pdf(1.2)); + test_absolute(0.1, 1.0, 0.0297426109178248997426, 1e-15, pdf(2.0)); + test_exact(1.0, 0.1, 0.04157808822362745501024, pdf(1.5)); + test_exact(1.0, 1.0, 0.3018043114632487660842, pdf(1.2)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: InverseGamma| x.ln_pdf(arg); - test_almost(0.1, 0.1, 0.0628591853882328004197f64.ln(), 1e-15, ln_pdf(1.2)); - test_almost(0.1, 1.0, 0.0297426109178248997426f64.ln(), 1e-15, ln_pdf(2.0)); - test_case(1.0, 0.1, 0.04157808822362745501024f64.ln(), ln_pdf(1.5)); - test_case(1.0, 1.0, 0.3018043114632487660842f64.ln(), ln_pdf(1.2)); + test_absolute(0.1, 0.1, 0.0628591853882328004197f64.ln(), 1e-15, ln_pdf(1.2)); + test_absolute(0.1, 1.0, 0.0297426109178248997426f64.ln(), 1e-15, ln_pdf(2.0)); + test_exact(1.0, 0.1, 0.04157808822362745501024f64.ln(), ln_pdf(1.5)); + test_exact(1.0, 1.0, 0.3018043114632487660842f64.ln(), ln_pdf(1.2)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: InverseGamma| x.cdf(arg); - test_almost(0.1, 0.1, 0.1862151961946054271994, 1e-14, cdf(1.2)); - test_almost(0.1, 1.0, 0.05859755410986647796141, 1e-14, cdf(2.0)); - test_case(1.0, 0.1, 0.9355069850316177377304, cdf(1.5)); - test_almost(1.0, 1.0, 0.4345982085070782231613, 1e-14, cdf(1.2)); + test_absolute(0.1, 0.1, 0.1862151961946054271994, 1e-14, cdf(1.2)); + test_absolute(0.1, 1.0, 0.05859755410986647796141, 1e-14, cdf(2.0)); + test_exact(1.0, 0.1, 0.9355069850316177377304, cdf(1.5)); + test_absolute(1.0, 1.0, 0.4345982085070782231613, 1e-14, cdf(1.2)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: InverseGamma| x.sf(arg); - test_almost(0.1, 0.1, 0.8137848038053936, 1e-14, sf(1.2)); - test_almost(0.1, 1.0, 0.9414024458901327, 1e-14, sf(2.0)); - test_almost(1.0, 0.1, 0.0644930149683822, 1e-14, sf(1.5)); - test_almost(1.0, 1.0, 0.565401791492922, 1e-14, sf(1.2)); + test_absolute(0.1, 0.1, 0.8137848038053936, 1e-14, sf(1.2)); + test_absolute(0.1, 1.0, 0.9414024458901327, 1e-14, sf(2.0)); + test_absolute(1.0, 0.1, 0.0644930149683822, 1e-14, sf(1.5)); + test_absolute(1.0, 1.0, 0.565401791492922, 1e-14, sf(1.2)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(1.0, 0.5), 0.0, 100.0); - test::check_continuous_distribution(&try_create(9.0, 2.0), 0.0, 100.0); + test::check_continuous_distribution(&create_ok(1.0, 0.5), 0.0, 100.0); + test::check_continuous_distribution(&create_ok(9.0, 2.0), 0.0, 100.0); } } From ea4b83b0eaa0877c5402a5c3c1de20ad1247f10e Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 19:55:16 +0200 Subject: [PATCH 154/185] Use testing_boiler! for Laplace --- src/distribution/laplace.rs | 209 +++++++++++++++--------------------- 1 file changed, 87 insertions(+), 122 deletions(-) diff --git a/src/distribution/laplace.rs b/src/distribution/laplace.rs index 1ed74132..0ccbac78 100644 --- a/src/distribution/laplace.rs +++ b/src/distribution/laplace.rs @@ -302,174 +302,139 @@ mod tests { use super::*; use rand::thread_rng; - fn try_create(location: f64, scale: f64) -> Laplace { - let n = Laplace::new(location, scale); - assert!(n.is_ok()); - n.unwrap() - } - - fn bad_create_case(location: f64, scale: f64) { - let n = Laplace::new(location, scale); - assert!(n.is_err()); - } + use crate::testing_boiler; - fn test_case(location: f64, scale: f64, expected: f64, eval: F) - where - F: Fn(Laplace) -> f64, - { - let n = try_create(location, scale); - let x = eval(n); - assert_eq!(expected, x); - } - - fn test_is_nan(location: f64, scale: f64, eval: F) - where - F: Fn(Laplace) -> f64, - { - let n = try_create(location, scale); - let x = eval(n); - assert!(x.is_nan()); - } - - fn test_almost(location: f64, scale: f64, expected: f64, acc: f64, eval: F) - where - F: Fn(Laplace) -> f64, - { - let n = try_create(location, scale); - let x = eval(n); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(location: f64, scale: f64; Laplace); // A wrapper for the `assert_relative_eq!` macro from the approx crate. // // `rtol` is the accepable relative error. This function is for testing // relative tolerance *only*. It should not be used with `expected = 0`. // - fn test_rel_close(location: f64, scale: f64, expected: f64, rtol: f64, eval: F) + fn test_rel_close(location: f64, scale: f64, expected: f64, rtol: f64, get_fn: F) where F: Fn(Laplace) -> f64, { - let n = try_create(location, scale); - let x = eval(n); + let x = create_and_get(location, scale, get_fn); assert_relative_eq!(expected, x, epsilon = 0.0, max_relative = rtol); } #[test] fn test_create() { - try_create(1.0, 2.0); - try_create(f64::NEG_INFINITY, 0.1); - try_create(-5.0 - 1.0, 1.0); - try_create(0.0, 5.0); - try_create(1.0, 7.0); - try_create(5.0, 10.0); - try_create(f64::INFINITY, f64::INFINITY); + create_ok(1.0, 2.0); + create_ok(f64::NEG_INFINITY, 0.1); + create_ok(-5.0 - 1.0, 1.0); + create_ok(0.0, 5.0); + create_ok(1.0, 7.0); + create_ok(5.0, 10.0); + create_ok(f64::INFINITY, f64::INFINITY); } #[test] fn test_bad_create() { - bad_create_case(2.0, -1.0); - bad_create_case(f64::NAN, 1.0); - bad_create_case(f64::NAN, -1.0); + create_err(2.0, -1.0); + create_err(f64::NAN, 1.0); + create_err(f64::NAN, -1.0); } #[test] fn test_mean() { let mean = |x: Laplace| x.mean().unwrap(); - test_case(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, mean); - test_case(-5.0 - 1.0, 1.0, -6.0, mean); - test_case(0.0, 5.0, 0.0, mean); - test_case(1.0, 10.0, 1.0, mean); - test_case(f64::INFINITY, f64::INFINITY, f64::INFINITY, mean); + test_exact(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, mean); + test_exact(-5.0 - 1.0, 1.0, -6.0, mean); + test_exact(0.0, 5.0, 0.0, mean); + test_exact(1.0, 10.0, 1.0, mean); + test_exact(f64::INFINITY, f64::INFINITY, f64::INFINITY, mean); } #[test] fn test_variance() { let variance = |x: Laplace| x.variance().unwrap(); - test_almost(f64::NEG_INFINITY, 0.1, 0.02, 1E-12, variance); - test_almost(-5.0 - 1.0, 1.0, 2.0, 1E-12, variance); - test_almost(0.0, 5.0, 50.0, 1E-12, variance); - test_almost(1.0, 7.0, 98.0, 1E-12, variance); - test_almost(5.0, 10.0, 200.0, 1E-12, variance); - test_almost(f64::INFINITY, f64::INFINITY, f64::INFINITY, 1E-12, variance); + test_absolute(f64::NEG_INFINITY, 0.1, 0.02, 1E-12, variance); + test_absolute(-5.0 - 1.0, 1.0, 2.0, 1E-12, variance); + test_absolute(0.0, 5.0, 50.0, 1E-12, variance); + test_absolute(1.0, 7.0, 98.0, 1E-12, variance); + test_absolute(5.0, 10.0, 200.0, 1E-12, variance); + test_absolute(f64::INFINITY, f64::INFINITY, f64::INFINITY, 1E-12, variance); } #[test] fn test_entropy() { let entropy = |x: Laplace| x.entropy().unwrap(); - test_almost( + test_absolute( f64::NEG_INFINITY, 0.1, (2.0 * f64::consts::E * 0.1).ln(), 1E-12, entropy, ); - test_almost(-6.0, 1.0, (2.0 * f64::consts::E).ln(), 1E-12, entropy); - test_almost(1.0, 7.0, (2.0 * f64::consts::E * 7.0).ln(), 1E-12, entropy); - test_almost(5., 10., (2. * f64::consts::E * 10.).ln(), 1E-12, entropy); - test_almost(f64::INFINITY, f64::INFINITY, f64::INFINITY, 1E-12, entropy); + test_absolute(-6.0, 1.0, (2.0 * f64::consts::E).ln(), 1E-12, entropy); + test_absolute(1.0, 7.0, (2.0 * f64::consts::E * 7.0).ln(), 1E-12, entropy); + test_absolute(5., 10., (2. * f64::consts::E * 10.).ln(), 1E-12, entropy); + test_absolute(f64::INFINITY, f64::INFINITY, f64::INFINITY, 1E-12, entropy); } #[test] fn test_skewness() { let skewness = |x: Laplace| x.skewness().unwrap(); - test_case(f64::NEG_INFINITY, 0.1, 0.0, skewness); - test_case(-6.0, 1.0, 0.0, skewness); - test_case(1.0, 7.0, 0.0, skewness); - test_case(5.0, 10.0, 0.0, skewness); - test_case(f64::INFINITY, f64::INFINITY, 0.0, skewness); + test_exact(f64::NEG_INFINITY, 0.1, 0.0, skewness); + test_exact(-6.0, 1.0, 0.0, skewness); + test_exact(1.0, 7.0, 0.0, skewness); + test_exact(5.0, 10.0, 0.0, skewness); + test_exact(f64::INFINITY, f64::INFINITY, 0.0, skewness); } #[test] fn test_mode() { let mode = |x: Laplace| x.mode().unwrap(); - test_case(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, mode); - test_case(-6.0, 1.0, -6.0, mode); - test_case(1.0, 7.0, 1.0, mode); - test_case(5.0, 10.0, 5.0, mode); - test_case(f64::INFINITY, f64::INFINITY, f64::INFINITY, mode); + test_exact(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, mode); + test_exact(-6.0, 1.0, -6.0, mode); + test_exact(1.0, 7.0, 1.0, mode); + test_exact(5.0, 10.0, 5.0, mode); + test_exact(f64::INFINITY, f64::INFINITY, f64::INFINITY, mode); } #[test] fn test_median() { let median = |x: Laplace| x.median(); - test_case(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, median); - test_case(-6.0, 1.0, -6.0, median); - test_case(1.0, 7.0, 1.0, median); - test_case(5.0, 10.0, 5.0, median); - test_case(f64::INFINITY, f64::INFINITY, f64::INFINITY, median); + test_exact(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, median); + test_exact(-6.0, 1.0, -6.0, median); + test_exact(1.0, 7.0, 1.0, median); + test_exact(5.0, 10.0, 5.0, median); + test_exact(f64::INFINITY, f64::INFINITY, f64::INFINITY, median); } #[test] fn test_min() { - test_case(0.0, 1.0, f64::NEG_INFINITY, |l| l.min()); + test_exact(0.0, 1.0, f64::NEG_INFINITY, |l| l.min()); } #[test] fn test_max() { - test_case(0.0, 1.0, f64::INFINITY, |l| l.max()); + test_exact(0.0, 1.0, f64::INFINITY, |l| l.max()); } #[test] fn test_density() { let pdf = |arg: f64| move |x: Laplace| x.pdf(arg); - test_almost(0.0, 0.1, 1.529511602509129e-06, 1E-12, pdf(1.5)); - test_almost(1.0, 0.1, 7.614989872356341e-08, 1E-12, pdf(2.8)); - test_almost(-1.0, 0.1, 3.8905661205668983e-19, 1E-12, pdf(-5.4)); - test_almost(5.0, 0.1, 5.056107463052243e-43, 1E-12, pdf(-4.9)); - test_almost(-5.0, 0.1, 1.9877248679543235e-30, 1E-12, pdf(2.0)); - test_almost(f64::INFINITY, 0.1, 0.0, 1E-12, pdf(5.5)); - test_almost(f64::NEG_INFINITY, 0.1, 0.0, 1E-12, pdf(-0.0)); - test_almost(0.0, 1.0, 0.0, 1E-12, pdf(f64::INFINITY)); - test_almost(1.0, 1.0, 0.00915781944436709, 1E-12, pdf(5.0)); - test_almost(-1.0, 1.0, 0.5, 1E-12, pdf(-1.0)); - test_almost(5.0, 1.0, 0.0012393760883331792, 1E-12, pdf(-1.0)); - test_almost(-5.0, 1.0, 0.0002765421850739168, 1E-12, pdf(2.5)); - test_almost(f64::INFINITY, 0.1, 0.0, 1E-12, pdf(2.0)); - test_almost(f64::NEG_INFINITY, 0.1, 0.0, 1E-12, pdf(15.0)); - test_almost(0.0, f64::INFINITY, 0.0, 1E-12, pdf(89.3)); - test_almost(1.0, f64::INFINITY, 0.0, 1E-12, pdf(-0.1)); - test_almost(-1.0, f64::INFINITY, 0.0, 1E-12, pdf(0.1)); - test_almost(5.0, f64::INFINITY, 0.0, 1E-12, pdf(-6.1)); - test_almost(-5.0, f64::INFINITY, 0.0, 1E-12, pdf(-10.0)); + test_absolute(0.0, 0.1, 1.529511602509129e-06, 1E-12, pdf(1.5)); + test_absolute(1.0, 0.1, 7.614989872356341e-08, 1E-12, pdf(2.8)); + test_absolute(-1.0, 0.1, 3.8905661205668983e-19, 1E-12, pdf(-5.4)); + test_absolute(5.0, 0.1, 5.056107463052243e-43, 1E-12, pdf(-4.9)); + test_absolute(-5.0, 0.1, 1.9877248679543235e-30, 1E-12, pdf(2.0)); + test_absolute(f64::INFINITY, 0.1, 0.0, 1E-12, pdf(5.5)); + test_absolute(f64::NEG_INFINITY, 0.1, 0.0, 1E-12, pdf(-0.0)); + test_absolute(0.0, 1.0, 0.0, 1E-12, pdf(f64::INFINITY)); + test_absolute(1.0, 1.0, 0.00915781944436709, 1E-12, pdf(5.0)); + test_absolute(-1.0, 1.0, 0.5, 1E-12, pdf(-1.0)); + test_absolute(5.0, 1.0, 0.0012393760883331792, 1E-12, pdf(-1.0)); + test_absolute(-5.0, 1.0, 0.0002765421850739168, 1E-12, pdf(2.5)); + test_absolute(f64::INFINITY, 0.1, 0.0, 1E-12, pdf(2.0)); + test_absolute(f64::NEG_INFINITY, 0.1, 0.0, 1E-12, pdf(15.0)); + test_absolute(0.0, f64::INFINITY, 0.0, 1E-12, pdf(89.3)); + test_absolute(1.0, f64::INFINITY, 0.0, 1E-12, pdf(-0.1)); + test_absolute(-1.0, f64::INFINITY, 0.0, 1E-12, pdf(0.1)); + test_absolute(5.0, f64::INFINITY, 0.0, 1E-12, pdf(-6.1)); + test_absolute(-5.0, f64::INFINITY, 0.0, 1E-12, pdf(-10.0)); test_is_nan(f64::INFINITY, f64::INFINITY, pdf(2.0)); test_is_nan(f64::NEG_INFINITY, f64::INFINITY, pdf(-5.1)); } @@ -477,25 +442,25 @@ mod tests { #[test] fn test_ln_density() { let ln_pdf = |arg: f64| move |x: Laplace| x.ln_pdf(arg); - test_almost(0.0, 0.1, -13.3905620875659, 1E-12, ln_pdf(1.5)); - test_almost(1.0, 0.1, -16.390562087565897, 1E-12, ln_pdf(2.8)); - test_almost(-1.0, 0.1, -42.39056208756591, 1E-12, ln_pdf(-5.4)); - test_almost(5.0, 0.1, -97.3905620875659, 1E-12, ln_pdf(-4.9)); - test_almost(-5.0, 0.1, -68.3905620875659, 1E-12, ln_pdf(2.0)); - test_case(f64::INFINITY, 0.1, f64::NEG_INFINITY, ln_pdf(5.5)); - test_case(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, ln_pdf(-0.0)); - test_case(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); - test_almost(1.0, 1.0, -4.693147180559945, 1E-12, ln_pdf(5.0)); - test_almost(-1.0, 1.0, -f64::consts::LN_2, 1E-12, ln_pdf(-1.0)); - test_almost(5.0, 1.0, -6.693147180559945, 1E-12, ln_pdf(-1.0)); - test_almost(-5.0, 1.0, -8.193147180559945, 1E-12, ln_pdf(2.5)); - test_case(f64::INFINITY, 0.1, f64::NEG_INFINITY, ln_pdf(2.0)); - test_case(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, ln_pdf(15.0)); - test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(89.3)); - test_case(1.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-0.1)); - test_case(-1.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.1)); - test_case(5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-6.1)); - test_case(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-10.0)); + test_absolute(0.0, 0.1, -13.3905620875659, 1E-12, ln_pdf(1.5)); + test_absolute(1.0, 0.1, -16.390562087565897, 1E-12, ln_pdf(2.8)); + test_absolute(-1.0, 0.1, -42.39056208756591, 1E-12, ln_pdf(-5.4)); + test_absolute(5.0, 0.1, -97.3905620875659, 1E-12, ln_pdf(-4.9)); + test_absolute(-5.0, 0.1, -68.3905620875659, 1E-12, ln_pdf(2.0)); + test_exact(f64::INFINITY, 0.1, f64::NEG_INFINITY, ln_pdf(5.5)); + test_exact(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, ln_pdf(-0.0)); + test_exact(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); + test_absolute(1.0, 1.0, -4.693147180559945, 1E-12, ln_pdf(5.0)); + test_absolute(-1.0, 1.0, -f64::consts::LN_2, 1E-12, ln_pdf(-1.0)); + test_absolute(5.0, 1.0, -6.693147180559945, 1E-12, ln_pdf(-1.0)); + test_absolute(-5.0, 1.0, -8.193147180559945, 1E-12, ln_pdf(2.5)); + test_exact(f64::INFINITY, 0.1, f64::NEG_INFINITY, ln_pdf(2.0)); + test_exact(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, ln_pdf(15.0)); + test_exact(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(89.3)); + test_exact(1.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-0.1)); + test_exact(-1.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.1)); + test_exact(5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-6.1)); + test_exact(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-10.0)); test_is_nan(f64::INFINITY, f64::INFINITY, ln_pdf(2.0)); test_is_nan(f64::NEG_INFINITY, f64::INFINITY, ln_pdf(-5.1)); } @@ -563,7 +528,7 @@ mod tests { #[test] fn test_sample() { use ::rand::distributions::Distribution; - let l = try_create(0.1, 0.5); + let l = create_ok(0.1, 0.5); l.sample(&mut thread_rng()); } @@ -576,7 +541,7 @@ mod tests { // sanity check sampling let location = 0.0; let scale = 1.0; - let n = try_create(location, scale); + let n = create_ok(location, scale); let trials = 10_000; let tolerance = 250; From 63c251d659fe38167ffb5ea0c7bfc586fb7df00d Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 19:56:18 +0200 Subject: [PATCH 155/185] Use testing_boiler! for LogNormal --- src/distribution/log_normal.rs | 483 ++++++++++++++++----------------- 1 file changed, 227 insertions(+), 256 deletions(-) diff --git a/src/distribution/log_normal.rs b/src/distribution/log_normal.rs index b6dbff6f..88a78996 100644 --- a/src/distribution/log_normal.rs +++ b/src/distribution/log_normal.rs @@ -305,322 +305,293 @@ impl Continuous for LogNormal { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, LogNormal}; use crate::distribution::internal::*; + use crate::statistics::*; + use crate::testing_boiler; - fn try_create(mean: f64, std_dev: f64) -> LogNormal { - let n = LogNormal::new(mean, std_dev); - assert!(n.is_ok()); - n.unwrap() - } - - fn bad_create_case(mean: f64, std_dev: f64) { - let n = LogNormal::new(mean, std_dev); - assert!(n.is_err()); - } - - fn get_value(mean: f64, std_dev: f64, eval: F) -> f64 - where F: Fn(LogNormal) -> f64 - { - let n = try_create(mean, std_dev); - eval(n) - } - - fn test_case(mean: f64, std_dev: f64, expected: f64, eval: F) - where F: Fn(LogNormal) -> f64 - { - let x = get_value(mean, std_dev, eval); - assert_eq!(expected, x); - } - - fn test_almost(mean: f64, std_dev: f64, expected: f64, acc: f64, eval: F) - where F: Fn(LogNormal) -> f64 - { - let x = get_value(mean, std_dev, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(mean: f64, std_dev: f64; LogNormal); #[test] fn test_create() { - try_create(10.0, 0.1); - try_create(-5.0, 1.0); - try_create(0.0, 10.0); - try_create(10.0, 100.0); - try_create(-5.0, f64::INFINITY); + create_ok(10.0, 0.1); + create_ok(-5.0, 1.0); + create_ok(0.0, 10.0); + create_ok(10.0, 100.0); + create_ok(-5.0, f64::INFINITY); } #[test] fn test_bad_create() { - bad_create_case(0.0, 0.0); - bad_create_case(f64::NAN, 1.0); - bad_create_case(1.0, f64::NAN); - bad_create_case(f64::NAN, f64::NAN); - bad_create_case(1.0, -1.0); + create_err(0.0, 0.0); + create_err(f64::NAN, 1.0); + create_err(1.0, f64::NAN); + create_err(f64::NAN, f64::NAN); + create_err(1.0, -1.0); } #[test] fn test_mean() { let mean = |x: LogNormal| x.mean().unwrap(); - test_case(-1.0, 0.1, 0.369723444544058982601, mean); - test_case(-1.0, 1.5, 1.133148453066826316829, mean); - test_case(-1.0, 2.5, 8.372897488127264663205, mean); - test_case(-1.0, 5.5, 1362729.18425285481771, mean); - test_case(-0.1, 0.1, 0.9093729344682314204933, mean); - test_case(-0.1, 1.5, 2.787095460565850768514, mean); - test_case(-0.1, 2.5, 20.59400471119602917533, mean); - test_almost(-0.1, 5.5, 3351772.941252693807591, 1e-9, mean); - test_case(0.1, 0.1, 1.110710610355705232259, mean); - test_case(0.1, 1.5, 3.40416608279081898632, mean); - test_almost(0.1, 2.5, 25.15357415581836182776, 1e-14, mean); - test_almost(0.1, 5.5, 4093864.715172665106863, 1e-8, mean); - test_almost(1.5, 0.1, 4.50415363028848413209, 1e-15, mean); - test_case(1.5, 1.5, 13.80457418606709491926, mean); - test_case(1.5, 2.5, 102.0027730826996844534, mean); - test_case(1.5, 5.5, 16601440.05723477471392, mean); - test_almost(2.5, 0.1, 12.24355896580102707724, 1e-14, mean); - test_almost(2.5, 1.5, 37.52472315960099891407, 1e-11, mean); - test_case(2.5, 2.5, 277.2722845231339804081, mean); - test_case(2.5, 5.5, 45127392.83383337999291, mean); - test_almost(5.5, 0.1, 245.9184556788219446833, 1e-13, mean); - test_case(5.5, 1.5, 753.7042125545612656606, mean); - test_case(5.5, 2.5, 5569.162708566004074422, mean); - test_case(5.5, 5.5, 906407915.0111549133446, mean); + test_exact(-1.0, 0.1, 0.369723444544058982601, mean); + test_exact(-1.0, 1.5, 1.133148453066826316829, mean); + test_exact(-1.0, 2.5, 8.372897488127264663205, mean); + test_exact(-1.0, 5.5, 1362729.18425285481771, mean); + test_exact(-0.1, 0.1, 0.9093729344682314204933, mean); + test_exact(-0.1, 1.5, 2.787095460565850768514, mean); + test_exact(-0.1, 2.5, 20.59400471119602917533, mean); + test_absolute(-0.1, 5.5, 3351772.941252693807591, 1e-9, mean); + test_exact(0.1, 0.1, 1.110710610355705232259, mean); + test_exact(0.1, 1.5, 3.40416608279081898632, mean); + test_absolute(0.1, 2.5, 25.15357415581836182776, 1e-14, mean); + test_absolute(0.1, 5.5, 4093864.715172665106863, 1e-8, mean); + test_absolute(1.5, 0.1, 4.50415363028848413209, 1e-15, mean); + test_exact(1.5, 1.5, 13.80457418606709491926, mean); + test_exact(1.5, 2.5, 102.0027730826996844534, mean); + test_exact(1.5, 5.5, 16601440.05723477471392, mean); + test_absolute(2.5, 0.1, 12.24355896580102707724, 1e-14, mean); + test_absolute(2.5, 1.5, 37.52472315960099891407, 1e-11, mean); + test_exact(2.5, 2.5, 277.2722845231339804081, mean); + test_exact(2.5, 5.5, 45127392.83383337999291, mean); + test_absolute(5.5, 0.1, 245.9184556788219446833, 1e-13, mean); + test_exact(5.5, 1.5, 753.7042125545612656606, mean); + test_exact(5.5, 2.5, 5569.162708566004074422, mean); + test_exact(5.5, 5.5, 906407915.0111549133446, mean); } #[test] fn test_variance() { let variance = |x: LogNormal| x.variance().unwrap(); - test_almost(-1.0, 0.1, 0.001373811865368952608715, 1e-16, variance); - test_case(-1.0, 1.5, 10.898468544015731954, variance); - test_case(-1.0, 2.5, 36245.39726189994988081, variance); - test_almost(-1.0, 5.5, 2.5481629178024539E+25, 1e10, variance); - test_almost(-0.1, 0.1, 0.008311077467909703803238, 1e-16, variance); - test_case(-0.1, 1.5, 65.93189259328902509552, variance); - test_almost(-0.1, 2.5, 219271.8756420929704707, 1e-10, variance); - test_almost(-0.1, 5.5, 1.541548733459471E+26, 1e12, variance); - test_almost(0.1, 0.1, 0.01239867063063756838894, 1e-15, variance); - test_almost(0.1, 1.5, 98.35882573290010981464, 1e-13, variance); - test_almost(0.1, 2.5, 327115.1995809995715014, 1e-10, variance); - test_almost(0.1, 5.5, 2.299720473192458E+26, 1e12, variance); - test_almost(1.5, 0.1, 0.2038917589520099120699, 1e-14, variance); - test_almost(1.5, 1.5, 1617.476145997433210727, 1e-12, variance); - test_almost(1.5, 2.5, 5379293.910566451644527, 1e-9, variance); - test_almost(1.5, 5.5, 3.7818090853910142E+27, 1e12, variance); - test_almost(2.5, 0.1, 1.506567645006046841936, 1e-13, variance); - test_almost(2.5, 1.5, 11951.62198145717670088, 1e-11, variance); - test_case(2.5, 2.5, 39747904.47781154725843, variance); - test_almost(2.5, 5.5, 2.7943999487399818E+28, 1e13, variance); - test_almost(5.5, 0.1, 607.7927673399807484235, 1e-11, variance); - test_case(5.5, 1.5, 4821628.436260521100027, variance); - test_case(5.5, 2.5, 16035449147.34799637823, variance); - test_case(5.5, 5.5, 1.127341399856331737823E+31, variance); + test_absolute(-1.0, 0.1, 0.001373811865368952608715, 1e-16, variance); + test_exact(-1.0, 1.5, 10.898468544015731954, variance); + test_exact(-1.0, 2.5, 36245.39726189994988081, variance); + test_absolute(-1.0, 5.5, 2.5481629178024539E+25, 1e10, variance); + test_absolute(-0.1, 0.1, 0.008311077467909703803238, 1e-16, variance); + test_exact(-0.1, 1.5, 65.93189259328902509552, variance); + test_absolute(-0.1, 2.5, 219271.8756420929704707, 1e-10, variance); + test_absolute(-0.1, 5.5, 1.541548733459471E+26, 1e12, variance); + test_absolute(0.1, 0.1, 0.01239867063063756838894, 1e-15, variance); + test_absolute(0.1, 1.5, 98.35882573290010981464, 1e-13, variance); + test_absolute(0.1, 2.5, 327115.1995809995715014, 1e-10, variance); + test_absolute(0.1, 5.5, 2.299720473192458E+26, 1e12, variance); + test_absolute(1.5, 0.1, 0.2038917589520099120699, 1e-14, variance); + test_absolute(1.5, 1.5, 1617.476145997433210727, 1e-12, variance); + test_absolute(1.5, 2.5, 5379293.910566451644527, 1e-9, variance); + test_absolute(1.5, 5.5, 3.7818090853910142E+27, 1e12, variance); + test_absolute(2.5, 0.1, 1.506567645006046841936, 1e-13, variance); + test_absolute(2.5, 1.5, 11951.62198145717670088, 1e-11, variance); + test_exact(2.5, 2.5, 39747904.47781154725843, variance); + test_absolute(2.5, 5.5, 2.7943999487399818E+28, 1e13, variance); + test_absolute(5.5, 0.1, 607.7927673399807484235, 1e-11, variance); + test_exact(5.5, 1.5, 4821628.436260521100027, variance); + test_exact(5.5, 2.5, 16035449147.34799637823, variance); + test_exact(5.5, 5.5, 1.127341399856331737823E+31, variance); } #[test] fn test_entropy() { let entropy = |x: LogNormal| x.entropy().unwrap(); - test_case(-1.0, 0.1, -1.8836465597893728867265104870209210873020761202386, entropy); - test_case(-1.0, 1.5, 0.82440364131283712375834285186996677643338789710028, entropy); - test_case(-1.0, 2.5, 1.335229265078827806963856948173628711311498693546, entropy); - test_case(-1.0, 5.5, 2.1236866254430979764250411929125703716076041932149, entropy); - test_almost(-0.1, 0.1, -0.9836465597893728922776256101467037894202344606927, 1e-15, entropy); - test_case(-0.1, 1.5, 1.7244036413128371182072277287441840743152295566462, entropy); - test_case(-0.1, 2.5, 2.2352292650788278014127418250478460091933403530919, entropy); - test_case(-0.1, 5.5, 3.0236866254430979708739260697867876694894458527608, entropy); - test_almost(0.1, 0.1, -0.7836465597893728811753953638951383851839177797845, 1e-15, entropy); - test_almost(0.1, 1.5, 1.9244036413128371293094579749957494785515462375544, 1e-15, entropy); - test_case(0.1, 2.5, 2.4352292650788278125149720712994114134296570340001, entropy); - test_case(0.1, 5.5, 3.223686625443097981976156316038353073725762533669, entropy); - test_almost(1.5, 0.1, 0.6163534402106271132734895129790789126979238797614, 1e-15, entropy); - test_case(1.5, 1.5, 3.3244036413128371237583428518699667764333878971003, entropy); - test_case(1.5, 2.5, 3.835229265078827806963856948173628711311498693546, entropy); - test_case(1.5, 5.5, 4.6236866254430979764250411929125703716076041932149, entropy); - test_case(2.5, 0.1, 1.6163534402106271132734895129790789126979238797614, entropy); - test_almost(2.5, 1.5, 4.3244036413128371237583428518699667764333878971003, 1e-15, entropy); - test_case(2.5, 2.5, 4.835229265078827806963856948173628711311498693546, entropy); - test_case(2.5, 5.5, 5.6236866254430979764250411929125703716076041932149, entropy); - test_case(5.5, 0.1, 4.6163534402106271132734895129790789126979238797614, entropy); - test_almost(5.5, 1.5, 7.3244036413128371237583428518699667764333878971003, 1e-15, entropy); - test_case(5.5, 2.5, 7.835229265078827806963856948173628711311498693546, entropy); - test_case(5.5, 5.5, 8.6236866254430979764250411929125703716076041932149, entropy); + test_exact(-1.0, 0.1, -1.8836465597893728867265104870209210873020761202386, entropy); + test_exact(-1.0, 1.5, 0.82440364131283712375834285186996677643338789710028, entropy); + test_exact(-1.0, 2.5, 1.335229265078827806963856948173628711311498693546, entropy); + test_exact(-1.0, 5.5, 2.1236866254430979764250411929125703716076041932149, entropy); + test_absolute(-0.1, 0.1, -0.9836465597893728922776256101467037894202344606927, 1e-15, entropy); + test_exact(-0.1, 1.5, 1.7244036413128371182072277287441840743152295566462, entropy); + test_exact(-0.1, 2.5, 2.2352292650788278014127418250478460091933403530919, entropy); + test_exact(-0.1, 5.5, 3.0236866254430979708739260697867876694894458527608, entropy); + test_absolute(0.1, 0.1, -0.7836465597893728811753953638951383851839177797845, 1e-15, entropy); + test_absolute(0.1, 1.5, 1.9244036413128371293094579749957494785515462375544, 1e-15, entropy); + test_exact(0.1, 2.5, 2.4352292650788278125149720712994114134296570340001, entropy); + test_exact(0.1, 5.5, 3.223686625443097981976156316038353073725762533669, entropy); + test_absolute(1.5, 0.1, 0.6163534402106271132734895129790789126979238797614, 1e-15, entropy); + test_exact(1.5, 1.5, 3.3244036413128371237583428518699667764333878971003, entropy); + test_exact(1.5, 2.5, 3.835229265078827806963856948173628711311498693546, entropy); + test_exact(1.5, 5.5, 4.6236866254430979764250411929125703716076041932149, entropy); + test_exact(2.5, 0.1, 1.6163534402106271132734895129790789126979238797614, entropy); + test_absolute(2.5, 1.5, 4.3244036413128371237583428518699667764333878971003, 1e-15, entropy); + test_exact(2.5, 2.5, 4.835229265078827806963856948173628711311498693546, entropy); + test_exact(2.5, 5.5, 5.6236866254430979764250411929125703716076041932149, entropy); + test_exact(5.5, 0.1, 4.6163534402106271132734895129790789126979238797614, entropy); + test_absolute(5.5, 1.5, 7.3244036413128371237583428518699667764333878971003, 1e-15, entropy); + test_exact(5.5, 2.5, 7.835229265078827806963856948173628711311498693546, entropy); + test_exact(5.5, 5.5, 8.6236866254430979764250411929125703716076041932149, entropy); } #[test] fn test_skewness() { let skewness = |x: LogNormal| x.skewness().unwrap(); - test_almost(-1.0, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); - test_case(-1.0, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); - test_almost(-1.0, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); - test_almost(-1.0, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); - test_almost(-0.1, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); - test_case(-0.1, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); - test_almost(-0.1, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); - test_almost(-0.1, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); - test_almost(0.1, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); - test_case(0.1, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); - test_almost(0.1, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); - test_almost(0.1, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); - test_almost(1.5, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); - test_case(1.5, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); - test_almost(1.5, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); - test_almost(1.5, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); - test_almost(2.5, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); - test_case(2.5, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); - test_almost(2.5, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); - test_almost(2.5, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); - test_almost(5.5, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); - test_case(5.5, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); - test_almost(5.5, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); - test_almost(5.5, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); + test_absolute(-1.0, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); + test_exact(-1.0, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); + test_absolute(-1.0, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); + test_absolute(-1.0, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); + test_absolute(-0.1, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); + test_exact(-0.1, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); + test_absolute(-0.1, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); + test_absolute(-0.1, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); + test_absolute(0.1, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); + test_exact(0.1, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); + test_absolute(0.1, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); + test_absolute(0.1, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); + test_absolute(1.5, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); + test_exact(1.5, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); + test_absolute(1.5, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); + test_absolute(1.5, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); + test_absolute(2.5, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); + test_exact(2.5, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); + test_absolute(2.5, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); + test_absolute(2.5, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); + test_absolute(5.5, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); + test_exact(5.5, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); + test_absolute(5.5, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); + test_absolute(5.5, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); } #[test] fn test_mode() { let mode = |x: LogNormal| x.mode().unwrap(); - test_case(-1.0, 0.1, 0.36421897957152331652213191863106773137983085909534, mode); - test_case(-1.0, 1.5, 0.03877420783172200988689983526759614326014406193602, mode); - test_case(-1.0, 2.5, 0.0007101743888425490635846003705775444086763023873619, mode); - test_case(-1.0, 5.5, 0.000000000000026810038677818032221548731163905979029274677187036, mode); - test_case(-0.1, 0.1, 0.89583413529652823774737070060865897390995185639633, mode); - test_case(-0.1, 1.5, 0.095369162215549610417813418326627245539514227574881, mode); - test_case(-0.1, 2.5, 0.0017467471362611196181003627521060283221112106850165, mode); - test_case(-0.1, 5.5, 0.00000000000006594205454219929159167575814655534255162059017114, mode); - test_case(0.1, 0.1, 1.0941742837052103542285651753780976842292770841345, mode); - test_case(0.1, 1.5, 0.11648415777349696821514223131929465848700730137808, mode); - test_case(0.1, 2.5, 0.0021334817700377079925027678518795817076296484352472, mode); - test_case(0.1, 5.5, 0.000000000000080541807296590798973741710866097756565304960216803, mode); - test_case(1.5, 0.1, 4.4370955190036645692996309927420381428715912422597, mode); - test_case(1.5, 1.5, 0.47236655274101470713804655094326791297020357913648, mode); - test_case(1.5, 2.5, 0.008651695203120634177071503957250390848166331197708, mode); - test_case(1.5, 5.5, 0.00000000000032661313427874471360158184468030186601222739665225, mode); - test_case(2.5, 0.1, 12.061276120444720299113038763305617245808510584994, mode); - test_case(2.5, 1.5, 1.2840254166877414840734205680624364583362808652815, mode); - test_case(2.5, 2.5, 0.023517745856009108236151185100432939470067655273072, mode); - test_case(2.5, 5.5, 0.00000000000088782654784596584473099190326928541185172970391855, mode); - test_case(5.5, 0.1, 242.2572068579541371904816252345031593584721473492, mode); - test_case(5.5, 1.5, 25.790339917193062089080107669377221876655268848954, mode); - test_case(5.5, 2.5, 0.47236655274101470713804655094326791297020357913648, mode); - test_case(5.5, 5.5, 0.000000000017832472908146389493511850431527026413424899198327, mode); + test_exact(-1.0, 0.1, 0.36421897957152331652213191863106773137983085909534, mode); + test_exact(-1.0, 1.5, 0.03877420783172200988689983526759614326014406193602, mode); + test_exact(-1.0, 2.5, 0.0007101743888425490635846003705775444086763023873619, mode); + test_exact(-1.0, 5.5, 0.000000000000026810038677818032221548731163905979029274677187036, mode); + test_exact(-0.1, 0.1, 0.89583413529652823774737070060865897390995185639633, mode); + test_exact(-0.1, 1.5, 0.095369162215549610417813418326627245539514227574881, mode); + test_exact(-0.1, 2.5, 0.0017467471362611196181003627521060283221112106850165, mode); + test_exact(-0.1, 5.5, 0.00000000000006594205454219929159167575814655534255162059017114, mode); + test_exact(0.1, 0.1, 1.0941742837052103542285651753780976842292770841345, mode); + test_exact(0.1, 1.5, 0.11648415777349696821514223131929465848700730137808, mode); + test_exact(0.1, 2.5, 0.0021334817700377079925027678518795817076296484352472, mode); + test_exact(0.1, 5.5, 0.000000000000080541807296590798973741710866097756565304960216803, mode); + test_exact(1.5, 0.1, 4.4370955190036645692996309927420381428715912422597, mode); + test_exact(1.5, 1.5, 0.47236655274101470713804655094326791297020357913648, mode); + test_exact(1.5, 2.5, 0.008651695203120634177071503957250390848166331197708, mode); + test_exact(1.5, 5.5, 0.00000000000032661313427874471360158184468030186601222739665225, mode); + test_exact(2.5, 0.1, 12.061276120444720299113038763305617245808510584994, mode); + test_exact(2.5, 1.5, 1.2840254166877414840734205680624364583362808652815, mode); + test_exact(2.5, 2.5, 0.023517745856009108236151185100432939470067655273072, mode); + test_exact(2.5, 5.5, 0.00000000000088782654784596584473099190326928541185172970391855, mode); + test_exact(5.5, 0.1, 242.2572068579541371904816252345031593584721473492, mode); + test_exact(5.5, 1.5, 25.790339917193062089080107669377221876655268848954, mode); + test_exact(5.5, 2.5, 0.47236655274101470713804655094326791297020357913648, mode); + test_exact(5.5, 5.5, 0.000000000017832472908146389493511850431527026413424899198327, mode); } #[test] fn test_median() { let median = |x: LogNormal| x.median(); - test_case(-1.0, 0.1, 0.36787944117144232159552377016146086744581113103177, median); - test_case(-1.0, 1.5, 0.36787944117144232159552377016146086744581113103177, median); - test_case(-1.0, 2.5, 0.36787944117144232159552377016146086744581113103177, median); - test_case(-1.0, 5.5, 0.36787944117144232159552377016146086744581113103177, median); - test_case(-0.1, 0.1, 0.90483741803595956814139238421693559530906465375738, median); - test_case(-0.1, 1.5, 0.90483741803595956814139238421693559530906465375738, median); - test_case(-0.1, 2.5, 0.90483741803595956814139238421693559530906465375738, median); - test_case(-0.1, 5.5, 0.90483741803595956814139238421693559530906465375738, median); - test_case(0.1, 0.1, 1.1051709180756476309466388234587796577416634163742, median); - test_case(0.1, 1.5, 1.1051709180756476309466388234587796577416634163742, median); - test_case(0.1, 2.5, 1.1051709180756476309466388234587796577416634163742, median); - test_case(0.1, 5.5, 1.1051709180756476309466388234587796577416634163742, median); - test_case(1.5, 0.1, 4.4816890703380648226020554601192758190057498683697, median); - test_case(1.5, 1.5, 4.4816890703380648226020554601192758190057498683697, median); - test_case(1.5, 2.5, 4.4816890703380648226020554601192758190057498683697, median); - test_case(1.5, 5.5, 4.4816890703380648226020554601192758190057498683697, median); - test_case(2.5, 0.1, 12.182493960703473438070175951167966183182767790063, median); - test_case(2.5, 1.5, 12.182493960703473438070175951167966183182767790063, median); - test_case(2.5, 2.5, 12.182493960703473438070175951167966183182767790063, median); - test_case(2.5, 5.5, 12.182493960703473438070175951167966183182767790063, median); - test_case(5.5, 0.1, 244.6919322642203879151889495118393501842287101075, median); - test_case(5.5, 1.5, 244.6919322642203879151889495118393501842287101075, median); - test_case(5.5, 2.5, 244.6919322642203879151889495118393501842287101075, median); - test_case(5.5, 5.5, 244.6919322642203879151889495118393501842287101075, median); + test_exact(-1.0, 0.1, 0.36787944117144232159552377016146086744581113103177, median); + test_exact(-1.0, 1.5, 0.36787944117144232159552377016146086744581113103177, median); + test_exact(-1.0, 2.5, 0.36787944117144232159552377016146086744581113103177, median); + test_exact(-1.0, 5.5, 0.36787944117144232159552377016146086744581113103177, median); + test_exact(-0.1, 0.1, 0.90483741803595956814139238421693559530906465375738, median); + test_exact(-0.1, 1.5, 0.90483741803595956814139238421693559530906465375738, median); + test_exact(-0.1, 2.5, 0.90483741803595956814139238421693559530906465375738, median); + test_exact(-0.1, 5.5, 0.90483741803595956814139238421693559530906465375738, median); + test_exact(0.1, 0.1, 1.1051709180756476309466388234587796577416634163742, median); + test_exact(0.1, 1.5, 1.1051709180756476309466388234587796577416634163742, median); + test_exact(0.1, 2.5, 1.1051709180756476309466388234587796577416634163742, median); + test_exact(0.1, 5.5, 1.1051709180756476309466388234587796577416634163742, median); + test_exact(1.5, 0.1, 4.4816890703380648226020554601192758190057498683697, median); + test_exact(1.5, 1.5, 4.4816890703380648226020554601192758190057498683697, median); + test_exact(1.5, 2.5, 4.4816890703380648226020554601192758190057498683697, median); + test_exact(1.5, 5.5, 4.4816890703380648226020554601192758190057498683697, median); + test_exact(2.5, 0.1, 12.182493960703473438070175951167966183182767790063, median); + test_exact(2.5, 1.5, 12.182493960703473438070175951167966183182767790063, median); + test_exact(2.5, 2.5, 12.182493960703473438070175951167966183182767790063, median); + test_exact(2.5, 5.5, 12.182493960703473438070175951167966183182767790063, median); + test_exact(5.5, 0.1, 244.6919322642203879151889495118393501842287101075, median); + test_exact(5.5, 1.5, 244.6919322642203879151889495118393501842287101075, median); + test_exact(5.5, 2.5, 244.6919322642203879151889495118393501842287101075, median); + test_exact(5.5, 5.5, 244.6919322642203879151889495118393501842287101075, median); } #[test] fn test_min_max() { let min = |x: LogNormal| x.min(); let max = |x: LogNormal| x.max(); - test_case(0.0, 0.1, 0.0, min); - test_case(-3.0, 10.0, 0.0, min); - test_case(0.0, 0.1, f64::INFINITY, max); - test_case(-3.0, 10.0, f64::INFINITY, max); + test_exact(0.0, 0.1, 0.0, min); + test_exact(-3.0, 10.0, 0.0, min); + test_exact(0.0, 0.1, f64::INFINITY, max); + test_exact(-3.0, 10.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: LogNormal| x.pdf(arg); - test_almost(-0.1, 0.1, 1.7968349035073582236359415565799753846986440127816e-104, 1e-118, pdf(0.1)); - test_almost(-0.1, 0.1, 0.00000018288923328441197822391757965928083462391836798722, 1e-21, pdf(0.5)); - test_case(-0.1, 0.1, 2.3363114904470413709866234247494393485647978367885, pdf(0.8)); - test_almost(-0.1, 1.5, 0.90492497850024368541682348133921492204585092983646, 1e-15, pdf(0.1)); - test_almost(-0.1, 1.5, 0.49191985207660942803818797602364034466489243416574, 1e-16, pdf(0.5)); - test_case(-0.1, 1.5, 0.33133347214343229148978298237579567194870525187207, pdf(0.8)); - test_case(-0.1, 2.5, 1.0824698632626565182080576574958317806389057196768, pdf(0.1)); - test_almost(-0.1, 2.5, 0.31029619474753883558901295436486123689563749784867, 1e-16, pdf(0.5)); - test_almost(-0.1, 2.5, 0.19922929916156673799861939824205622734205083805245, 1e-16, pdf(0.8)); + test_absolute(-0.1, 0.1, 1.7968349035073582236359415565799753846986440127816e-104, 1e-118, pdf(0.1)); + test_absolute(-0.1, 0.1, 0.00000018288923328441197822391757965928083462391836798722, 1e-21, pdf(0.5)); + test_exact(-0.1, 0.1, 2.3363114904470413709866234247494393485647978367885, pdf(0.8)); + test_absolute(-0.1, 1.5, 0.90492497850024368541682348133921492204585092983646, 1e-15, pdf(0.1)); + test_absolute(-0.1, 1.5, 0.49191985207660942803818797602364034466489243416574, 1e-16, pdf(0.5)); + test_exact(-0.1, 1.5, 0.33133347214343229148978298237579567194870525187207, pdf(0.8)); + test_exact(-0.1, 2.5, 1.0824698632626565182080576574958317806389057196768, pdf(0.1)); + test_absolute(-0.1, 2.5, 0.31029619474753883558901295436486123689563749784867, 1e-16, pdf(0.5)); + test_absolute(-0.1, 2.5, 0.19922929916156673799861939824205622734205083805245, 1e-16, pdf(0.8)); // Test removed because it was causing compiler issues (see issue 31407 for rust) -// test_almost(1.5, 0.1, 4.1070141770545881694056265342787422035256248474059e-313, 1e-322, pdf(0.1)); +// test_absolute(1.5, 0.1, 4.1070141770545881694056265342787422035256248474059e-313, 1e-322, pdf(0.1)); // - test_almost(1.5, 0.1, 2.8602688726477103843476657332784045661507239533567e-104, 1e-116, pdf(0.5)); - test_case(1.5, 0.1, 1.6670425710002183246335601541889400558525870482613e-64, pdf(0.8)); - test_almost(1.5, 1.5, 0.10698412103361841220076392503406214751353235895732, 1e-16, pdf(0.1)); - test_almost(1.5, 1.5, 0.18266125308224685664142384493330155315630876975024, 1e-16, pdf(0.5)); - test_almost(1.5, 1.5, 0.17185785323404088913982425377565512294017306418953, 1e-16, pdf(0.8)); - test_almost(1.5, 2.5, 0.50186885259059181992025035649158160252576845315332, 1e-15, pdf(0.1)); - test_almost(1.5, 2.5, 0.21721369314437986034957451699565540205404697589349, 1e-16, pdf(0.5)); - test_case(1.5, 2.5, 0.15729636000661278918949298391170443742675565300598, pdf(0.8)); - test_case(2.5, 0.1, 5.6836826548848916385760779034504046896805825555997e-500, pdf(0.1)); - test_almost(2.5, 0.1, 3.1225608678589488061206338085285607881363155340377e-221, 1e-233, pdf(0.5)); - test_almost(2.5, 0.1, 4.6994713794671660918554320071312374073172560048297e-161, 1e-173, pdf(0.8)); - test_almost(2.5, 1.5, 0.015806486291412916772431170442330946677601577502353, 1e-16, pdf(0.1)); - test_almost(2.5, 1.5, 0.055184331257528847223852028950484131834529030116388, 1e-16, pdf(0.5)); - test_case(2.5, 1.5, 0.063982134749859504449658286955049840393511776984362, pdf(0.8)); - test_almost(2.5, 2.5, 0.25212505662402617595900822552548977822542300480086, 1e-15, pdf(0.1)); - test_almost(2.5, 2.5, 0.14117186955911792460646517002386088579088567275401, 1e-16, pdf(0.5)); - test_almost(2.5, 2.5, 0.11021452580363707866161369621432656293405065561317, 1e-16, pdf(0.8)); + test_absolute(1.5, 0.1, 2.8602688726477103843476657332784045661507239533567e-104, 1e-116, pdf(0.5)); + test_exact(1.5, 0.1, 1.6670425710002183246335601541889400558525870482613e-64, pdf(0.8)); + test_absolute(1.5, 1.5, 0.10698412103361841220076392503406214751353235895732, 1e-16, pdf(0.1)); + test_absolute(1.5, 1.5, 0.18266125308224685664142384493330155315630876975024, 1e-16, pdf(0.5)); + test_absolute(1.5, 1.5, 0.17185785323404088913982425377565512294017306418953, 1e-16, pdf(0.8)); + test_absolute(1.5, 2.5, 0.50186885259059181992025035649158160252576845315332, 1e-15, pdf(0.1)); + test_absolute(1.5, 2.5, 0.21721369314437986034957451699565540205404697589349, 1e-16, pdf(0.5)); + test_exact(1.5, 2.5, 0.15729636000661278918949298391170443742675565300598, pdf(0.8)); + test_exact(2.5, 0.1, 5.6836826548848916385760779034504046896805825555997e-500, pdf(0.1)); + test_absolute(2.5, 0.1, 3.1225608678589488061206338085285607881363155340377e-221, 1e-233, pdf(0.5)); + test_absolute(2.5, 0.1, 4.6994713794671660918554320071312374073172560048297e-161, 1e-173, pdf(0.8)); + test_absolute(2.5, 1.5, 0.015806486291412916772431170442330946677601577502353, 1e-16, pdf(0.1)); + test_absolute(2.5, 1.5, 0.055184331257528847223852028950484131834529030116388, 1e-16, pdf(0.5)); + test_exact(2.5, 1.5, 0.063982134749859504449658286955049840393511776984362, pdf(0.8)); + test_absolute(2.5, 2.5, 0.25212505662402617595900822552548977822542300480086, 1e-15, pdf(0.1)); + test_absolute(2.5, 2.5, 0.14117186955911792460646517002386088579088567275401, 1e-16, pdf(0.5)); + test_absolute(2.5, 2.5, 0.11021452580363707866161369621432656293405065561317, 1e-16, pdf(0.8)); } #[test] fn test_neg_pdf() { let pdf = |arg: f64| move |x: LogNormal| x.pdf(arg); - test_case(0.0, 1.0, 0.0, pdf(0.0)); + test_exact(0.0, 1.0, 0.0, pdf(0.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: LogNormal| x.ln_pdf(arg); - test_case(-0.1, 0.1, -238.88282294119596467794686179588610665317241097599, ln_pdf(0.1)); - test_almost(-0.1, 0.1, -15.514385149961296196003163062199569075052113039686, 1e-14, ln_pdf(0.5)); - test_case(-0.1, 0.1, 0.84857339958981283964373051826407417105725729082041, ln_pdf(0.8)); - test_almost(-0.1, 1.5, -0.099903235403144611051953094864849327288457482212211, 1e-15, ln_pdf(0.1)); - test_almost(-0.1, 1.5, -0.70943947804316122682964396008813828577195771418027, 1e-15, ln_pdf(0.5)); - test_almost(-0.1, 1.5, -1.1046299420497998262946038709903250420774183529995, 1e-15, ln_pdf(0.8)); - test_almost(-0.1, 2.5, 0.07924534056485078867266307735371665927517517183681, 1e-16, ln_pdf(0.1)); - test_case(-0.1, 2.5, -1.1702279707433794860424967893989374511050637417043, ln_pdf(0.5)); - test_case(-0.1, 2.5, -1.6132988605030400828957768752511536087538109996183, ln_pdf(0.8)); - test_case(1.5, 0.1, -719.29643782024317312262673764204041218720576249741, ln_pdf(0.1)); - test_almost(1.5, 0.1, -238.41793403955250272430898754048547661932857086122, 1e-13, ln_pdf(0.5)); - test_case(1.5, 0.1, -146.85439481068371057247137024006716189469284256628, ln_pdf(0.8)); - test_almost(1.5, 1.5, -2.2350748570877992856465076624973458117562108140674, 1e-15, ln_pdf(0.1)); - test_almost(1.5, 1.5, -1.7001219175524556705452882616787223585705662860012, 1e-15, ln_pdf(0.5)); - test_almost(1.5, 1.5, -1.7610875785399045023354101841009649273236721172008, 1e-15, ln_pdf(0.8)); - test_almost(1.5, 2.5, -0.68941644324162489418137656699398207513321602763104, 1e-15, ln_pdf(0.1)); - test_case(1.5, 2.5, -1.5268736489667254857801287379715477173125628275598, ln_pdf(0.5)); - test_case(1.5, 2.5, -1.8496236096394777662704671479709839674424623547308, ln_pdf(0.8)); - test_almost(2.5, 0.1, -1149.5549471196476523788026360929146688367845019398, 1e-12, ln_pdf(0.1)); - test_almost(2.5, 0.1, -507.73265209554698134113704985174959301922196605736, 1e-12, ln_pdf(0.5)); - test_almost(2.5, 0.1, -369.16874994210463740474549611573497379941224077335, 1e-13, ln_pdf(0.8)); - test_almost(2.5, 1.5, -4.1473348984184862316495477617980296904955324113457, 1e-15, ln_pdf(0.1)); - test_almost(2.5, 1.5, -2.8970762200235424747307247601045786110485663457169, 1e-15, ln_pdf(0.5)); - test_case(2.5, 1.5, -2.7491513791239977024488074547907467152956602019989, ln_pdf(0.8)); - test_almost(2.5, 2.5, -1.3778300581206721947424710027422282714793718026513, 1e-15, ln_pdf(0.1)); - test_case(2.5, 2.5, -1.9577771978563167352868858774048559682046428490575, ln_pdf(0.5)); - test_case(2.5, 2.5, -2.2053265778497513183112901654193054111123780652581, ln_pdf(0.8)); + test_exact(-0.1, 0.1, -238.88282294119596467794686179588610665317241097599, ln_pdf(0.1)); + test_absolute(-0.1, 0.1, -15.514385149961296196003163062199569075052113039686, 1e-14, ln_pdf(0.5)); + test_exact(-0.1, 0.1, 0.84857339958981283964373051826407417105725729082041, ln_pdf(0.8)); + test_absolute(-0.1, 1.5, -0.099903235403144611051953094864849327288457482212211, 1e-15, ln_pdf(0.1)); + test_absolute(-0.1, 1.5, -0.70943947804316122682964396008813828577195771418027, 1e-15, ln_pdf(0.5)); + test_absolute(-0.1, 1.5, -1.1046299420497998262946038709903250420774183529995, 1e-15, ln_pdf(0.8)); + test_absolute(-0.1, 2.5, 0.07924534056485078867266307735371665927517517183681, 1e-16, ln_pdf(0.1)); + test_exact(-0.1, 2.5, -1.1702279707433794860424967893989374511050637417043, ln_pdf(0.5)); + test_exact(-0.1, 2.5, -1.6132988605030400828957768752511536087538109996183, ln_pdf(0.8)); + test_exact(1.5, 0.1, -719.29643782024317312262673764204041218720576249741, ln_pdf(0.1)); + test_absolute(1.5, 0.1, -238.41793403955250272430898754048547661932857086122, 1e-13, ln_pdf(0.5)); + test_exact(1.5, 0.1, -146.85439481068371057247137024006716189469284256628, ln_pdf(0.8)); + test_absolute(1.5, 1.5, -2.2350748570877992856465076624973458117562108140674, 1e-15, ln_pdf(0.1)); + test_absolute(1.5, 1.5, -1.7001219175524556705452882616787223585705662860012, 1e-15, ln_pdf(0.5)); + test_absolute(1.5, 1.5, -1.7610875785399045023354101841009649273236721172008, 1e-15, ln_pdf(0.8)); + test_absolute(1.5, 2.5, -0.68941644324162489418137656699398207513321602763104, 1e-15, ln_pdf(0.1)); + test_exact(1.5, 2.5, -1.5268736489667254857801287379715477173125628275598, ln_pdf(0.5)); + test_exact(1.5, 2.5, -1.8496236096394777662704671479709839674424623547308, ln_pdf(0.8)); + test_absolute(2.5, 0.1, -1149.5549471196476523788026360929146688367845019398, 1e-12, ln_pdf(0.1)); + test_absolute(2.5, 0.1, -507.73265209554698134113704985174959301922196605736, 1e-12, ln_pdf(0.5)); + test_absolute(2.5, 0.1, -369.16874994210463740474549611573497379941224077335, 1e-13, ln_pdf(0.8)); + test_absolute(2.5, 1.5, -4.1473348984184862316495477617980296904955324113457, 1e-15, ln_pdf(0.1)); + test_absolute(2.5, 1.5, -2.8970762200235424747307247601045786110485663457169, 1e-15, ln_pdf(0.5)); + test_exact(2.5, 1.5, -2.7491513791239977024488074547907467152956602019989, ln_pdf(0.8)); + test_absolute(2.5, 2.5, -1.3778300581206721947424710027422282714793718026513, 1e-15, ln_pdf(0.1)); + test_exact(2.5, 2.5, -1.9577771978563167352868858774048559682046428490575, ln_pdf(0.5)); + test_exact(2.5, 2.5, -2.2053265778497513183112901654193054111123780652581, ln_pdf(0.8)); } #[test] fn test_neg_ln_pdf() { let ln_pdf = |arg: f64| move |x: LogNormal| x.ln_pdf(arg); - test_case(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(0.0)); + test_exact(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(0.0)); } #[test] @@ -658,13 +629,13 @@ mod tests { // calls test_almost after re-arranging the input/output arguments and calling f with input let almost = |mean: f64, std_dev: f64, cdf_input: f64, cdf_output: f64, acc: f64| { let (input, output) = arrange_input_output(cdf_input, cdf_output); - test_almost(mean, std_dev, output, acc, f(input)); + test_absolute(mean, std_dev, output, acc, f(input)); }; // calls test_case after re-arranging the input/output arguments and calling f with input let case = |mean: f64, std_dev: f64, cdf_input: f64, cdf_output: f64| { let (input, output) = arrange_input_output(cdf_input, cdf_output); - test_case(mean, std_dev, output, f(input)); + test_exact(mean, std_dev, output, f(input)); }; // we skip cases where the CDF outputs 0.0 when testing the inverse CDF because @@ -720,34 +691,34 @@ mod tests { let sf = |arg: f64| move |x: LogNormal| x.sf(arg); // Wolfram Alpha:: SurvivalFunction[ LogNormalDistribution(-0.1, 0.1), 0.1] - test_almost(-0.1, 0.1, 1.0, 1e-107, sf(0.1)); + test_absolute(-0.1, 0.1, 1.0, 1e-107, sf(0.1)); // Wolfram Alpha:: SurvivalFunction[ LogNormalDistribution(-0.1, 0.1), 0.8] - test_almost(-0.1, 0.1, 0.890919989231123, 1e-14, sf(0.8)); + test_absolute(-0.1, 0.1, 0.890919989231123, 1e-14, sf(0.8)); // Wolfram Alpha:: SurvivalFunction[LogNormalDistribution[1.5, 1], 0.8] - test_almost(1.5, 1.0, 0.957568715612642, 1e-14, sf(0.8)); + test_absolute(1.5, 1.0, 0.957568715612642, 1e-14, sf(0.8)); // Wolfram Alpha:: SurvivalFunction[ LogNormalDistribution(2.5, 1.5), 0.1] - test_almost(2.5, 1.5, 0.9993169594777358, 1e-14, sf(0.1)); + test_absolute(2.5, 1.5, 0.9993169594777358, 1e-14, sf(0.1)); } #[test] fn test_neg_cdf() { let cdf = |arg: f64| move |x: LogNormal| x.cdf(arg); - test_case(0.0, 1.0, 0.0, cdf(0.0)); + test_exact(0.0, 1.0, 0.0, cdf(0.0)); } #[test] fn test_neg_sf() { let sf = |arg: f64| move |x: LogNormal| x.sf(arg); - test_case(0.0, 1.0, 1.0, sf(0.0)); + test_exact(0.0, 1.0, 1.0, sf(0.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(0.0, 0.25), 0.0, 10.0); - test::check_continuous_distribution(&try_create(0.0, 0.5), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(0.0, 0.25), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(0.0, 0.5), 0.0, 10.0); } } From ed71f36f4b8a74b2ffb562bafe6fbffbbc7ee2b0 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 19:59:28 +0200 Subject: [PATCH 156/185] Use testing_boiler! for NegativeBinomial --- src/distribution/negative_binomial.rs | 244 ++++++++++---------------- 1 file changed, 97 insertions(+), 147 deletions(-) diff --git a/src/distribution/negative_binomial.rs b/src/distribution/negative_binomial.rs index 065c2239..d36d0f98 100644 --- a/src/distribution/negative_binomial.rs +++ b/src/distribution/negative_binomial.rs @@ -291,222 +291,172 @@ impl Discrete for NegativeBinomial { #[rustfmt::skip] #[cfg(test)] mod tests { - use std::fmt::Debug; - use crate::statistics::*; use crate::distribution::{DiscreteCDF, Discrete, NegativeBinomial}; use crate::distribution::internal::test; + use crate::statistics::*; + use crate::testing_boiler; - fn try_create(r: f64, p: f64) -> NegativeBinomial { - let r = NegativeBinomial::new(r, p); - assert!(r.is_ok()); - r.unwrap() - } - - fn create_case(r: f64, p: f64) { - let dist = try_create(r, p); - assert_eq!(p, dist.p()); - assert_eq!(r, dist.r()); - } - - fn bad_create_case(r: f64, p: f64) { - let r = NegativeBinomial::new(r, p); - assert!(r.is_err()); - } - - fn get_value(r: f64, p: f64, eval: F) -> T - where T: PartialEq + Debug, - F: Fn(NegativeBinomial) -> T - { - let r = try_create(r, p); - eval(r) - } - - fn test_case(r: f64, p: f64, expected: T, eval: F) - where T: PartialEq + Debug, - F: Fn(NegativeBinomial) -> T - { - let x = get_value(r, p, eval); - assert_eq!(expected, x); - } - - - fn test_case_or_nan(r: f64, p: f64, expected: f64, eval: F) - where F: Fn(NegativeBinomial) -> f64 - { - let x = get_value(r, p, eval); - if expected.is_nan() { - assert!(x.is_nan()) - } - else { - assert_eq!(expected, x); - } - } - fn test_almost(r: f64, p: f64, expected: f64, acc: f64, eval: F) - where F: Fn(NegativeBinomial) -> f64 - { - let x = get_value(r, p, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(r: f64, p: f64; NegativeBinomial); #[test] fn test_create() { - create_case(0.0, 0.0); - create_case(0.3, 0.4); - create_case(1.0, 0.3); + create_ok(0.0, 0.0); + create_ok(0.3, 0.4); + create_ok(1.0, 0.3); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN, 1.0); - bad_create_case(0.0, f64::NAN); - bad_create_case(-1.0, 1.0); - bad_create_case(2.0, 2.0); + create_err(f64::NAN, 1.0); + create_err(0.0, f64::NAN); + create_err(-1.0, 1.0); + create_err(2.0, 2.0); } #[test] fn test_mean() { let mean = |x: NegativeBinomial| x.mean().unwrap(); - test_case(4.0, 0.0, f64::INFINITY, mean); - test_almost(3.0, 0.3, 7.0, 1e-15 , mean); - test_case(2.0, 1.0, 0.0, mean); + test_exact(4.0, 0.0, f64::INFINITY, mean); + test_absolute(3.0, 0.3, 7.0, 1e-15 , mean); + test_exact(2.0, 1.0, 0.0, mean); } #[test] fn test_variance() { let variance = |x: NegativeBinomial| x.variance().unwrap(); - test_case(4.0, 0.0, f64::INFINITY, variance); - test_almost(3.0, 0.3, 23.333333333333, 1e-12, variance); - test_case(2.0, 1.0, 0.0, variance); + test_exact(4.0, 0.0, f64::INFINITY, variance); + test_absolute(3.0, 0.3, 23.333333333333, 1e-12, variance); + test_exact(2.0, 1.0, 0.0, variance); } #[test] fn test_skewness() { let skewness = |x: NegativeBinomial| x.skewness().unwrap(); - test_case(0.0, 0.0, f64::INFINITY, skewness); - test_almost(0.1, 0.3, 6.425396041, 1e-09, skewness); - test_case(1.0, 1.0, f64::INFINITY, skewness); + test_exact(0.0, 0.0, f64::INFINITY, skewness); + test_absolute(0.1, 0.3, 6.425396041, 1e-09, skewness); + test_exact(1.0, 1.0, f64::INFINITY, skewness); } #[test] fn test_mode() { let mode = |x: NegativeBinomial| x.mode().unwrap(); - test_case(0.0, 0.0, 0.0, mode); - test_case(0.3, 0.0, 0.0, mode); - test_case(1.0, 1.0, 0.0, mode); - test_case(10.0, 0.01, 891.0, mode); + test_exact(0.0, 0.0, 0.0, mode); + test_exact(0.3, 0.0, 0.0, mode); + test_exact(1.0, 1.0, 0.0, mode); + test_exact(10.0, 0.01, 891.0, mode); } #[test] fn test_min_max() { let min = |x: NegativeBinomial| x.min(); let max = |x: NegativeBinomial| x.max(); - test_case(1.0, 0.5, 0, min); - test_case(1.0, 0.3, u64::MAX, max); + test_exact(1.0, 0.5, 0, min); + test_exact(1.0, 0.3, u64::MAX, max); } #[test] fn test_pmf() { let pmf = |arg: u64| move |x: NegativeBinomial| x.pmf(arg); - test_almost(4.0, 0.5, 0.0625, 1e-8, pmf(0)); - test_almost(4.0, 0.5, 0.15625, 1e-8, pmf(3)); - test_case(1.0, 0.0, 0.0, pmf(0)); - test_case(1.0, 0.0, 0.0, pmf(1)); - test_almost(3.0, 0.2, 0.008, 1e-15, pmf(0)); - test_almost(3.0, 0.2, 0.0192, 1e-15, pmf(1)); - test_almost(3.0, 0.2, 0.04096, 1e-15, pmf(3)); - test_almost(10.0, 0.2, 1.024e-07, 1e-07, pmf(0)); - test_almost(10.0, 0.2, 8.192e-07, 1e-07, pmf(1)); - test_almost(10.0, 0.2, 0.001015706852, 1e-07, pmf(10)); - test_almost(1.0, 0.3, 0.3, 1e-15, pmf(0)); - test_almost(1.0, 0.3, 0.21, 1e-15, pmf(1)); - test_almost(3.0, 0.3, 0.027, 1e-15, pmf(0)); - test_case(0.3, 1.0, 0.0, pmf(1)); - test_case(0.3, 1.0, 0.0, pmf(3)); - test_case_or_nan(0.3, 1.0, f64::NAN, pmf(0)); - test_case(0.3, 1.0, 0.0, pmf(1)); - test_case(0.3, 1.0, 0.0, pmf(10)); - test_case_or_nan(1.0, 1.0, f64::NAN, pmf(0)); - test_case(1.0, 1.0, 0.0, pmf(1)); - test_case_or_nan(3.0, 1.0, f64::NAN, pmf(0)); - test_case(3.0, 1.0, 0.0, pmf(1)); - test_case(3.0, 1.0, 0.0, pmf(3)); - test_case_or_nan(10.0, 1.0, f64::NAN, pmf(0)); - test_case(10.0, 1.0, 0.0, pmf(1)); - test_case(10.0, 1.0, 0.0, pmf(10)); + test_absolute(4.0, 0.5, 0.0625, 1e-8, pmf(0)); + test_absolute(4.0, 0.5, 0.15625, 1e-8, pmf(3)); + test_exact(1.0, 0.0, 0.0, pmf(0)); + test_exact(1.0, 0.0, 0.0, pmf(1)); + test_absolute(3.0, 0.2, 0.008, 1e-15, pmf(0)); + test_absolute(3.0, 0.2, 0.0192, 1e-15, pmf(1)); + test_absolute(3.0, 0.2, 0.04096, 1e-15, pmf(3)); + test_absolute(10.0, 0.2, 1.024e-07, 1e-07, pmf(0)); + test_absolute(10.0, 0.2, 8.192e-07, 1e-07, pmf(1)); + test_absolute(10.0, 0.2, 0.001015706852, 1e-07, pmf(10)); + test_absolute(1.0, 0.3, 0.3, 1e-15, pmf(0)); + test_absolute(1.0, 0.3, 0.21, 1e-15, pmf(1)); + test_absolute(3.0, 0.3, 0.027, 1e-15, pmf(0)); + test_exact(0.3, 1.0, 0.0, pmf(1)); + test_exact(0.3, 1.0, 0.0, pmf(3)); + test_is_nan(0.3, 1.0, pmf(0)); + test_exact(0.3, 1.0, 0.0, pmf(1)); + test_exact(0.3, 1.0, 0.0, pmf(10)); + test_is_nan(1.0, 1.0, pmf(0)); + test_exact(1.0, 1.0, 0.0, pmf(1)); + test_is_nan(3.0, 1.0, pmf(0)); + test_exact(3.0, 1.0, 0.0, pmf(1)); + test_exact(3.0, 1.0, 0.0, pmf(3)); + test_is_nan(10.0, 1.0, pmf(0)); + test_exact(10.0, 1.0, 0.0, pmf(1)); + test_exact(10.0, 1.0, 0.0, pmf(10)); } #[test] fn test_ln_pmf() { let ln_pmf = |arg: u64| move |x: NegativeBinomial| x.ln_pmf(arg); - test_case(1.0, 0.0, f64::NEG_INFINITY, ln_pmf(0)); - test_case(1.0, 0.0, f64::NEG_INFINITY, ln_pmf(1)); - test_almost(3.0, 0.2, -4.828313737, 1e-08, ln_pmf(0)); - test_almost(3.0, 0.2, -3.952845, 1e-08, ln_pmf(1)); - test_almost(3.0, 0.2, -3.195159298, 1e-08, ln_pmf(3)); - test_almost(10.0, 0.2, -16.09437912, 1e-08, ln_pmf(0)); - test_almost(10.0, 0.2, -14.01493758, 1e-08, ln_pmf(1)); - test_almost(10.0, 0.2, -6.892170503, 1e-08, ln_pmf(10)); - test_almost(1.0, 0.3, -1.203972804, 1e-08, ln_pmf(0)); - test_almost(1.0, 0.3, -1.560647748, 1e-08, ln_pmf(1)); - test_almost(3.0, 0.3, -3.611918413, 1e-08, ln_pmf(0)); - test_case(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(1)); - test_case(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(3)); - test_case_or_nan(0.3, 1.0, f64::NAN, ln_pmf(0)); - test_case(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(1)); - test_case(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(10)); - test_case_or_nan(1.0, 1.0, f64::NAN, ln_pmf(0)); - test_case(1.0, 1.0, f64::NEG_INFINITY, ln_pmf(1)); - test_case_or_nan(3.0, 1.0, f64::NAN, ln_pmf(0)); - test_case(3.0, 1.0, f64::NEG_INFINITY, ln_pmf(1)); - test_case(3.0, 1.0, f64::NEG_INFINITY, ln_pmf(3)); - test_case_or_nan(10.0, 1.0, f64::NAN, ln_pmf(0)); - test_case(10.0, 1.0, f64::NEG_INFINITY, ln_pmf(1)); - test_case(10.0, 1.0, f64::NEG_INFINITY, ln_pmf(10)); + test_exact(1.0, 0.0, f64::NEG_INFINITY, ln_pmf(0)); + test_exact(1.0, 0.0, f64::NEG_INFINITY, ln_pmf(1)); + test_absolute(3.0, 0.2, -4.828313737, 1e-08, ln_pmf(0)); + test_absolute(3.0, 0.2, -3.952845, 1e-08, ln_pmf(1)); + test_absolute(3.0, 0.2, -3.195159298, 1e-08, ln_pmf(3)); + test_absolute(10.0, 0.2, -16.09437912, 1e-08, ln_pmf(0)); + test_absolute(10.0, 0.2, -14.01493758, 1e-08, ln_pmf(1)); + test_absolute(10.0, 0.2, -6.892170503, 1e-08, ln_pmf(10)); + test_absolute(1.0, 0.3, -1.203972804, 1e-08, ln_pmf(0)); + test_absolute(1.0, 0.3, -1.560647748, 1e-08, ln_pmf(1)); + test_absolute(3.0, 0.3, -3.611918413, 1e-08, ln_pmf(0)); + test_exact(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(1)); + test_exact(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(3)); + test_is_nan(0.3, 1.0, ln_pmf(0)); + test_exact(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(1)); + test_exact(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(10)); + test_is_nan(1.0, 1.0, ln_pmf(0)); + test_exact(1.0, 1.0, f64::NEG_INFINITY, ln_pmf(1)); + test_is_nan(3.0, 1.0, ln_pmf(0)); + test_exact(3.0, 1.0, f64::NEG_INFINITY, ln_pmf(1)); + test_exact(3.0, 1.0, f64::NEG_INFINITY, ln_pmf(3)); + test_is_nan(10.0, 1.0, ln_pmf(0)); + test_exact(10.0, 1.0, f64::NEG_INFINITY, ln_pmf(1)); + test_exact(10.0, 1.0, f64::NEG_INFINITY, ln_pmf(10)); } #[test] fn test_cdf() { let cdf = |arg: u64| move |x: NegativeBinomial| x.cdf(arg); - test_almost(1.0, 0.3, 0.3, 1e-08, cdf(0)); - test_almost(1.0, 0.3, 0.51, 1e-08, cdf(1)); - test_almost(1.0, 0.3, 0.83193, 1e-08, cdf(4)); - test_almost(1.0, 0.3, 0.9802267326, 1e-08, cdf(10)); - test_case(1.0, 1.0, 1.0, cdf(0)); - test_case(1.0, 1.0, 1.0, cdf(1)); - test_almost(10.0, 0.75, 0.05631351471, 1e-08, cdf(0)); - test_almost(10.0, 0.75, 0.1970973015, 1e-08, cdf(1)); - test_almost(10.0, 0.75, 0.9960578583, 1e-08, cdf(10)); + test_absolute(1.0, 0.3, 0.3, 1e-08, cdf(0)); + test_absolute(1.0, 0.3, 0.51, 1e-08, cdf(1)); + test_absolute(1.0, 0.3, 0.83193, 1e-08, cdf(4)); + test_absolute(1.0, 0.3, 0.9802267326, 1e-08, cdf(10)); + test_exact(1.0, 1.0, 1.0, cdf(0)); + test_exact(1.0, 1.0, 1.0, cdf(1)); + test_absolute(10.0, 0.75, 0.05631351471, 1e-08, cdf(0)); + test_absolute(10.0, 0.75, 0.1970973015, 1e-08, cdf(1)); + test_absolute(10.0, 0.75, 0.9960578583, 1e-08, cdf(10)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: NegativeBinomial| x.sf(arg); - test_almost(1.0, 0.3, 0.7, 1e-08, sf(0)); - test_almost(1.0, 0.3, 0.49, 1e-08, sf(1)); - test_almost(1.0, 0.3, 0.1680699999999986, 1e-08, sf(4)); - test_almost(1.0, 0.3, 0.019773267430000074, 1e-08, sf(10)); - test_case(1.0, 1.0, 0.0, sf(0)); - test_case(1.0, 1.0, 0.0, sf(1)); - test_almost(10.0, 0.75, 0.9436864852905275, 1e-08, sf(0)); - test_almost(10.0, 0.75, 0.8029026985168456, 1e-08, sf(1)); - test_almost(10.0, 0.75, 0.003942141664083465, 1e-08, sf(10)); + test_absolute(1.0, 0.3, 0.7, 1e-08, sf(0)); + test_absolute(1.0, 0.3, 0.49, 1e-08, sf(1)); + test_absolute(1.0, 0.3, 0.1680699999999986, 1e-08, sf(4)); + test_absolute(1.0, 0.3, 0.019773267430000074, 1e-08, sf(10)); + test_exact(1.0, 1.0, 0.0, sf(0)); + test_exact(1.0, 1.0, 0.0, sf(1)); + test_absolute(10.0, 0.75, 0.9436864852905275, 1e-08, sf(0)); + test_absolute(10.0, 0.75, 0.8029026985168456, 1e-08, sf(1)); + test_absolute(10.0, 0.75, 0.003942141664083465, 1e-08, sf(10)); } #[test] fn test_cdf_upper_bound() { let cdf = |arg: u64| move |x: NegativeBinomial| x.cdf(arg); - test_case(3.0, 0.5, 1.0, cdf(100)); + test_exact(3.0, 0.5, 1.0, cdf(100)); } #[test] fn test_discrete() { - test::check_discrete_distribution(&try_create(5.0, 0.3), 35); - test::check_discrete_distribution(&try_create(10.0, 0.7), 21); + test::check_discrete_distribution(&create_ok(5.0, 0.3), 35); + test::check_discrete_distribution(&create_ok(10.0, 0.7), 21); } #[test] fn test_sf_upper_bound() { let sf = |arg: u64| move |x: NegativeBinomial| x.sf(arg); - test_almost(3.0, 0.5, 5.282409836586059e-28, 1e-28, sf(100)); + test_absolute(3.0, 0.5, 5.282409836586059e-28, 1e-28, sf(100)); } } From b39d8388e0de024c06955c84ef4c35a51d31c716 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 20:01:28 +0200 Subject: [PATCH 157/185] Use testing_boiler! for Normal --- src/distribution/normal.rs | 260 ++++++++++++++++--------------------- 1 file changed, 115 insertions(+), 145 deletions(-) diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index 94e8c6b6..65f6ad90 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -334,221 +334,191 @@ impl std::default::Default for Normal { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Normal}; use crate::distribution::internal::*; + use crate::statistics::*; + use crate::testing_boiler; - fn try_create(mean: f64, std_dev: f64) -> Normal { - let n = Normal::new(mean, std_dev); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(mean: f64, std_dev: f64) { - let n = try_create(mean, std_dev); - assert_eq!(mean, n.mean().unwrap()); - assert_eq!(std_dev, n.std_dev().unwrap()); - } - - fn bad_create_case(mean: f64, std_dev: f64) { - let n = Normal::new(mean, std_dev); - assert!(n.is_err()); - } - - fn test_case(mean: f64, std_dev: f64, expected: f64, eval: F) - where F: Fn(Normal) -> f64 - { - let n = try_create(mean, std_dev); - let x = eval(n); - assert_eq!(expected, x); - } - - fn test_almost(mean: f64, std_dev: f64, expected: f64, acc: f64, eval: F) - where F: Fn(Normal) -> f64 - { - let n = try_create(mean, std_dev); - let x = eval(n); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(mean: f64, std_dev: f64; Normal); #[test] fn test_create() { - create_case(10.0, 0.1); - create_case(-5.0, 1.0); - create_case(0.0, 10.0); - create_case(10.0, 100.0); - create_case(-5.0, f64::INFINITY); + create_ok(10.0, 0.1); + create_ok(-5.0, 1.0); + create_ok(0.0, 10.0); + create_ok(10.0, 100.0); + create_ok(-5.0, f64::INFINITY); } #[test] fn test_bad_create() { - bad_create_case(0.0, 0.0); - bad_create_case(f64::NAN, 1.0); - bad_create_case(1.0, f64::NAN); - bad_create_case(f64::NAN, f64::NAN); - bad_create_case(1.0, -1.0); + create_err(0.0, 0.0); + create_err(f64::NAN, 1.0); + create_err(1.0, f64::NAN); + create_err(f64::NAN, f64::NAN); + create_err(1.0, -1.0); } #[test] fn test_variance() { let variance = |x: Normal| x.variance().unwrap(); - test_case(0.0, 0.1, 0.1 * 0.1, variance); - test_case(0.0, 1.0, 1.0, variance); - test_case(0.0, 10.0, 100.0, variance); - test_case(0.0, f64::INFINITY, f64::INFINITY, variance); + test_exact(0.0, 0.1, 0.1 * 0.1, variance); + test_exact(0.0, 1.0, 1.0, variance); + test_exact(0.0, 10.0, 100.0, variance); + test_exact(0.0, f64::INFINITY, f64::INFINITY, variance); } #[test] fn test_entropy() { let entropy = |x: Normal| x.entropy().unwrap(); - test_almost(0.0, 0.1, -0.8836465597893729422377, 1e-15, entropy); - test_case(0.0, 1.0, 1.41893853320467274178, entropy); - test_case(0.0, 10.0, 3.721523626198718425798, entropy); - test_case(0.0, f64::INFINITY, f64::INFINITY, entropy); + test_absolute(0.0, 0.1, -0.8836465597893729422377, 1e-15, entropy); + test_exact(0.0, 1.0, 1.41893853320467274178, entropy); + test_exact(0.0, 10.0, 3.721523626198718425798, entropy); + test_exact(0.0, f64::INFINITY, f64::INFINITY, entropy); } #[test] fn test_skewness() { let skewness = |x: Normal| x.skewness().unwrap(); - test_case(0.0, 0.1, 0.0, skewness); - test_case(4.0, 1.0, 0.0, skewness); - test_case(0.3, 10.0, 0.0, skewness); - test_case(0.0, f64::INFINITY, 0.0, skewness); + test_exact(0.0, 0.1, 0.0, skewness); + test_exact(4.0, 1.0, 0.0, skewness); + test_exact(0.3, 10.0, 0.0, skewness); + test_exact(0.0, f64::INFINITY, 0.0, skewness); } #[test] fn test_mode() { let mode = |x: Normal| x.mode().unwrap(); - test_case(-0.0, 1.0, 0.0, mode); - test_case(0.0, 1.0, 0.0, mode); - test_case(0.1, 1.0, 0.1, mode); - test_case(1.0, 1.0, 1.0, mode); - test_case(-10.0, 1.0, -10.0, mode); - test_case(f64::INFINITY, 1.0, f64::INFINITY, mode); + test_exact(-0.0, 1.0, 0.0, mode); + test_exact(0.0, 1.0, 0.0, mode); + test_exact(0.1, 1.0, 0.1, mode); + test_exact(1.0, 1.0, 1.0, mode); + test_exact(-10.0, 1.0, -10.0, mode); + test_exact(f64::INFINITY, 1.0, f64::INFINITY, mode); } #[test] fn test_median() { let median = |x: Normal| x.median(); - test_case(-0.0, 1.0, 0.0, median); - test_case(0.0, 1.0, 0.0, median); - test_case(0.1, 1.0, 0.1, median); - test_case(1.0, 1.0, 1.0, median); - test_case(-0.0, 1.0, -0.0, median); - test_case(f64::INFINITY, 1.0, f64::INFINITY, median); + test_exact(-0.0, 1.0, 0.0, median); + test_exact(0.0, 1.0, 0.0, median); + test_exact(0.1, 1.0, 0.1, median); + test_exact(1.0, 1.0, 1.0, median); + test_exact(-0.0, 1.0, -0.0, median); + test_exact(f64::INFINITY, 1.0, f64::INFINITY, median); } #[test] fn test_min_max() { let min = |x: Normal| x.min(); let max = |x: Normal| x.max(); - test_case(0.0, 0.1, f64::NEG_INFINITY, min); - test_case(-3.0, 10.0, f64::NEG_INFINITY, min); - test_case(0.0, 0.1, f64::INFINITY, max); - test_case(-3.0, 10.0, f64::INFINITY, max); + test_exact(0.0, 0.1, f64::NEG_INFINITY, min); + test_exact(-3.0, 10.0, f64::NEG_INFINITY, min); + test_exact(0.0, 0.1, f64::INFINITY, max); + test_exact(-3.0, 10.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Normal| x.pdf(arg); - test_almost(10.0, 0.1, 5.530709549844416159162E-49, 1e-64, pdf(8.5)); - test_almost(10.0, 0.1, 0.5399096651318805195056, 1e-14, pdf(9.8)); - test_almost(10.0, 0.1, 3.989422804014326779399, 1e-15, pdf(10.0)); - test_almost(10.0, 0.1, 0.5399096651318805195056, 1e-14, pdf(10.2)); - test_almost(10.0, 0.1, 5.530709549844416159162E-49, 1e-64, pdf(11.5)); - test_case(-5.0, 1.0, 1.486719514734297707908E-6, pdf(-10.0)); - test_case(-5.0, 1.0, 0.01752830049356853736216, pdf(-7.5)); - test_almost(-5.0, 1.0, 0.3989422804014326779399, 1e-16, pdf(-5.0)); - test_case(-5.0, 1.0, 0.01752830049356853736216, pdf(-2.5)); - test_case(-5.0, 1.0, 1.486719514734297707908E-6, pdf(0.0)); - test_case(0.0, 10.0, 0.03520653267642994777747, pdf(-5.0)); - test_almost(0.0, 10.0, 0.03866681168028492069412, 1e-17, pdf(-2.5)); - test_almost(0.0, 10.0, 0.03989422804014326779399, 1e-17, pdf(0.0)); - test_almost(0.0, 10.0, 0.03866681168028492069412, 1e-17, pdf(2.5)); - test_case(0.0, 10.0, 0.03520653267642994777747, pdf(5.0)); - test_almost(10.0, 100.0, 4.398359598042719404845E-4, 1e-19, pdf(-200.0)); - test_case(10.0, 100.0, 0.002178521770325505313831, pdf(-100.0)); - test_case(10.0, 100.0, 0.003969525474770117655105, pdf(0.0)); - test_almost(10.0, 100.0, 0.002660852498987548218204, 1e-18, pdf(100.0)); - test_case(10.0, 100.0, 6.561581477467659126534E-4, pdf(200.0)); - test_case(-5.0, f64::INFINITY, 0.0, pdf(-5.0)); - test_case(-5.0, f64::INFINITY, 0.0, pdf(0.0)); - test_case(-5.0, f64::INFINITY, 0.0, pdf(100.0)); + test_absolute(10.0, 0.1, 5.530709549844416159162E-49, 1e-64, pdf(8.5)); + test_absolute(10.0, 0.1, 0.5399096651318805195056, 1e-14, pdf(9.8)); + test_absolute(10.0, 0.1, 3.989422804014326779399, 1e-15, pdf(10.0)); + test_absolute(10.0, 0.1, 0.5399096651318805195056, 1e-14, pdf(10.2)); + test_absolute(10.0, 0.1, 5.530709549844416159162E-49, 1e-64, pdf(11.5)); + test_exact(-5.0, 1.0, 1.486719514734297707908E-6, pdf(-10.0)); + test_exact(-5.0, 1.0, 0.01752830049356853736216, pdf(-7.5)); + test_absolute(-5.0, 1.0, 0.3989422804014326779399, 1e-16, pdf(-5.0)); + test_exact(-5.0, 1.0, 0.01752830049356853736216, pdf(-2.5)); + test_exact(-5.0, 1.0, 1.486719514734297707908E-6, pdf(0.0)); + test_exact(0.0, 10.0, 0.03520653267642994777747, pdf(-5.0)); + test_absolute(0.0, 10.0, 0.03866681168028492069412, 1e-17, pdf(-2.5)); + test_absolute(0.0, 10.0, 0.03989422804014326779399, 1e-17, pdf(0.0)); + test_absolute(0.0, 10.0, 0.03866681168028492069412, 1e-17, pdf(2.5)); + test_exact(0.0, 10.0, 0.03520653267642994777747, pdf(5.0)); + test_absolute(10.0, 100.0, 4.398359598042719404845E-4, 1e-19, pdf(-200.0)); + test_exact(10.0, 100.0, 0.002178521770325505313831, pdf(-100.0)); + test_exact(10.0, 100.0, 0.003969525474770117655105, pdf(0.0)); + test_absolute(10.0, 100.0, 0.002660852498987548218204, 1e-18, pdf(100.0)); + test_exact(10.0, 100.0, 6.561581477467659126534E-4, pdf(200.0)); + test_exact(-5.0, f64::INFINITY, 0.0, pdf(-5.0)); + test_exact(-5.0, f64::INFINITY, 0.0, pdf(0.0)); + test_exact(-5.0, f64::INFINITY, 0.0, pdf(100.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Normal| x.ln_pdf(arg); - test_almost(10.0, 0.1, (5.530709549844416159162E-49f64).ln(), 1e-13, ln_pdf(8.5)); - test_almost(10.0, 0.1, (0.5399096651318805195056f64).ln(), 1e-13, ln_pdf(9.8)); - test_almost(10.0, 0.1, (3.989422804014326779399f64).ln(), 1e-15, ln_pdf(10.0)); - test_almost(10.0, 0.1, (0.5399096651318805195056f64).ln(), 1e-13, ln_pdf(10.2)); - test_almost(10.0, 0.1, (5.530709549844416159162E-49f64).ln(), 1e-13, ln_pdf(11.5)); - test_case(-5.0, 1.0, (1.486719514734297707908E-6f64).ln(), ln_pdf(-10.0)); - test_case(-5.0, 1.0, (0.01752830049356853736216f64).ln(), ln_pdf(-7.5)); - test_almost(-5.0, 1.0, (0.3989422804014326779399f64).ln(), 1e-15, ln_pdf(-5.0)); - test_case(-5.0, 1.0, (0.01752830049356853736216f64).ln(), ln_pdf(-2.5)); - test_case(-5.0, 1.0, (1.486719514734297707908E-6f64).ln(), ln_pdf(0.0)); - test_case(0.0, 10.0, (0.03520653267642994777747f64).ln(), ln_pdf(-5.0)); - test_case(0.0, 10.0, (0.03866681168028492069412f64).ln(), ln_pdf(-2.5)); - test_case(0.0, 10.0, (0.03989422804014326779399f64).ln(), ln_pdf(0.0)); - test_case(0.0, 10.0, (0.03866681168028492069412f64).ln(), ln_pdf(2.5)); - test_case(0.0, 10.0, (0.03520653267642994777747f64).ln(), ln_pdf(5.0)); - test_case(10.0, 100.0, (4.398359598042719404845E-4f64).ln(), ln_pdf(-200.0)); - test_case(10.0, 100.0, (0.002178521770325505313831f64).ln(), ln_pdf(-100.0)); - test_almost(10.0, 100.0, (0.003969525474770117655105f64).ln(),1e-15, ln_pdf(0.0)); - test_almost(10.0, 100.0, (0.002660852498987548218204f64).ln(), 1e-15, ln_pdf(100.0)); - test_almost(10.0, 100.0, (6.561581477467659126534E-4f64).ln(), 1e-15, ln_pdf(200.0)); - test_case(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-5.0)); - test_case(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.0)); - test_case(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(100.0)); + test_absolute(10.0, 0.1, (5.530709549844416159162E-49f64).ln(), 1e-13, ln_pdf(8.5)); + test_absolute(10.0, 0.1, (0.5399096651318805195056f64).ln(), 1e-13, ln_pdf(9.8)); + test_absolute(10.0, 0.1, (3.989422804014326779399f64).ln(), 1e-15, ln_pdf(10.0)); + test_absolute(10.0, 0.1, (0.5399096651318805195056f64).ln(), 1e-13, ln_pdf(10.2)); + test_absolute(10.0, 0.1, (5.530709549844416159162E-49f64).ln(), 1e-13, ln_pdf(11.5)); + test_exact(-5.0, 1.0, (1.486719514734297707908E-6f64).ln(), ln_pdf(-10.0)); + test_exact(-5.0, 1.0, (0.01752830049356853736216f64).ln(), ln_pdf(-7.5)); + test_absolute(-5.0, 1.0, (0.3989422804014326779399f64).ln(), 1e-15, ln_pdf(-5.0)); + test_exact(-5.0, 1.0, (0.01752830049356853736216f64).ln(), ln_pdf(-2.5)); + test_exact(-5.0, 1.0, (1.486719514734297707908E-6f64).ln(), ln_pdf(0.0)); + test_exact(0.0, 10.0, (0.03520653267642994777747f64).ln(), ln_pdf(-5.0)); + test_exact(0.0, 10.0, (0.03866681168028492069412f64).ln(), ln_pdf(-2.5)); + test_exact(0.0, 10.0, (0.03989422804014326779399f64).ln(), ln_pdf(0.0)); + test_exact(0.0, 10.0, (0.03866681168028492069412f64).ln(), ln_pdf(2.5)); + test_exact(0.0, 10.0, (0.03520653267642994777747f64).ln(), ln_pdf(5.0)); + test_exact(10.0, 100.0, (4.398359598042719404845E-4f64).ln(), ln_pdf(-200.0)); + test_exact(10.0, 100.0, (0.002178521770325505313831f64).ln(), ln_pdf(-100.0)); + test_absolute(10.0, 100.0, (0.003969525474770117655105f64).ln(),1e-15, ln_pdf(0.0)); + test_absolute(10.0, 100.0, (0.002660852498987548218204f64).ln(), 1e-15, ln_pdf(100.0)); + test_absolute(10.0, 100.0, (6.561581477467659126534E-4f64).ln(), 1e-15, ln_pdf(200.0)); + test_exact(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-5.0)); + test_exact(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.0)); + test_exact(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(100.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Normal| x.cdf(arg); - test_case(5.0, 2.0, 0.0, cdf(f64::NEG_INFINITY)); - test_almost(5.0, 2.0, 0.0000002866515718, 1e-16, cdf(-5.0)); - test_almost(5.0, 2.0, 0.0002326290790, 1e-13, cdf(-2.0)); - test_almost(5.0, 2.0, 0.006209665325, 1e-12, cdf(0.0)); - test_case(5.0, 2.0, 0.30853753872598689636229538939166226011639782444542207, cdf(4.0)); - test_case(5.0, 2.0, 0.5, cdf(5.0)); - test_case(5.0, 2.0, 0.69146246127401310363770461060833773988360217555457859, cdf(6.0)); - test_almost(5.0, 2.0, 0.993790334674, 1e-12, cdf(10.0)); + test_exact(5.0, 2.0, 0.0, cdf(f64::NEG_INFINITY)); + test_absolute(5.0, 2.0, 0.0000002866515718, 1e-16, cdf(-5.0)); + test_absolute(5.0, 2.0, 0.0002326290790, 1e-13, cdf(-2.0)); + test_absolute(5.0, 2.0, 0.006209665325, 1e-12, cdf(0.0)); + test_exact(5.0, 2.0, 0.30853753872598689636229538939166226011639782444542207, cdf(4.0)); + test_exact(5.0, 2.0, 0.5, cdf(5.0)); + test_exact(5.0, 2.0, 0.69146246127401310363770461060833773988360217555457859, cdf(6.0)); + test_absolute(5.0, 2.0, 0.993790334674, 1e-12, cdf(10.0)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Normal| x.sf(arg); - test_case(5.0, 2.0, 1.0, sf(f64::NEG_INFINITY)); - test_almost(5.0, 2.0, 0.9999997133484281, 1e-16, sf(-5.0)); - test_almost(5.0, 2.0, 0.9997673709209455, 1e-13, sf(-2.0)); - test_almost(5.0, 2.0, 0.9937903346744879, 1e-12, sf(0.0)); - test_case(5.0, 2.0, 0.6914624612740131, sf(4.0)); - test_case(5.0, 2.0, 0.5, sf(5.0)); - test_case(5.0, 2.0, 0.3085375387259869, sf(6.0)); - test_almost(5.0, 2.0, 0.006209665325512148, 1e-12, sf(10.0)); + test_exact(5.0, 2.0, 1.0, sf(f64::NEG_INFINITY)); + test_absolute(5.0, 2.0, 0.9999997133484281, 1e-16, sf(-5.0)); + test_absolute(5.0, 2.0, 0.9997673709209455, 1e-13, sf(-2.0)); + test_absolute(5.0, 2.0, 0.9937903346744879, 1e-12, sf(0.0)); + test_exact(5.0, 2.0, 0.6914624612740131, sf(4.0)); + test_exact(5.0, 2.0, 0.5, sf(5.0)); + test_exact(5.0, 2.0, 0.3085375387259869, sf(6.0)); + test_absolute(5.0, 2.0, 0.006209665325512148, 1e-12, sf(10.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(0.0, 1.0), -10.0, 10.0); - test::check_continuous_distribution(&try_create(20.0, 0.5), 10.0, 30.0); + test::check_continuous_distribution(&create_ok(0.0, 1.0), -10.0, 10.0); + test::check_continuous_distribution(&create_ok(20.0, 0.5), 10.0, 30.0); } #[test] fn test_inverse_cdf() { let inverse_cdf = |arg: f64| move |x: Normal| x.inverse_cdf(arg); - test_case(5.0, 2.0, f64::NEG_INFINITY, inverse_cdf( 0.0)); - test_almost(5.0, 2.0, -5.0, 1e-14, inverse_cdf(0.00000028665157187919391167375233287464535385442301361187883)); - test_almost(5.0, 2.0, -2.0, 1e-14, inverse_cdf(0.0002326290790355250363499258867279847735487493358890356)); - test_almost(5.0, 2.0, -0.0, 1e-14, inverse_cdf(0.0062096653257761351669781045741922211278977469230927036)); - test_almost(5.0, 2.0, 0.0, 1e-14, inverse_cdf(0.0062096653257761351669781045741922211278977469230927036)); - test_almost(5.0, 2.0, 4.0, 1e-14, inverse_cdf(0.30853753872598689636229538939166226011639782444542207)); - test_almost(5.0, 2.0, 5.0, 1e-14, inverse_cdf(0.5)); - test_almost(5.0, 2.0, 6.0, 1e-14, inverse_cdf(0.69146246127401310363770461060833773988360217555457859)); - test_almost(5.0, 2.0, 10.0, 1e-14, inverse_cdf(0.9937903346742238648330218954258077788721022530769078)); - test_case(5.0, 2.0, f64::INFINITY, inverse_cdf(1.0)); + test_exact(5.0, 2.0, f64::NEG_INFINITY, inverse_cdf( 0.0)); + test_absolute(5.0, 2.0, -5.0, 1e-14, inverse_cdf(0.00000028665157187919391167375233287464535385442301361187883)); + test_absolute(5.0, 2.0, -2.0, 1e-14, inverse_cdf(0.0002326290790355250363499258867279847735487493358890356)); + test_absolute(5.0, 2.0, -0.0, 1e-14, inverse_cdf(0.0062096653257761351669781045741922211278977469230927036)); + test_absolute(5.0, 2.0, 0.0, 1e-14, inverse_cdf(0.0062096653257761351669781045741922211278977469230927036)); + test_absolute(5.0, 2.0, 4.0, 1e-14, inverse_cdf(0.30853753872598689636229538939166226011639782444542207)); + test_absolute(5.0, 2.0, 5.0, 1e-14, inverse_cdf(0.5)); + test_absolute(5.0, 2.0, 6.0, 1e-14, inverse_cdf(0.69146246127401310363770461060833773988360217555457859)); + test_absolute(5.0, 2.0, 10.0, 1e-14, inverse_cdf(0.9937903346742238648330218954258077788721022530769078)); + test_exact(5.0, 2.0, f64::INFINITY, inverse_cdf(1.0)); } #[test] From 0a244e3fbf4378e60ed9a20bc80a565a0f991cb9 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 20:02:40 +0200 Subject: [PATCH 158/185] Use testing_boiler! for Pareto --- src/distribution/pareto.rs | 220 +++++++++++++++---------------------- 1 file changed, 91 insertions(+), 129 deletions(-) diff --git a/src/distribution/pareto.rs b/src/distribution/pareto.rs index 5a22e63e..7f91e0a4 100644 --- a/src/distribution/pareto.rs +++ b/src/distribution/pareto.rs @@ -354,214 +354,176 @@ impl Continuous for Pareto { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Pareto}; use crate::distribution::internal::*; + use crate::statistics::*; + use crate::testing_boiler; - fn try_create(scale: f64, shape: f64) -> Pareto { - let p = Pareto::new(scale, shape); - assert!(p.is_ok()); - p.unwrap() - } - - fn create_case(scale: f64, shape: f64) { - let p = try_create(scale, shape); - assert_eq!(scale, p.scale()); - assert_eq!(shape, p.shape()); - } - - fn bad_create_case(scale: f64, shape: f64) { - let p = Pareto::new(scale, shape); - assert!(p.is_err()); - } - - fn get_value(scale: f64, shape: f64, eval: F) -> T - where F: Fn(Pareto) -> T - { - let p = try_create(scale, shape); - eval(p) - } - - fn test_case(scale: f64, shape: f64, expected: f64, eval: F) - where F: Fn(Pareto) -> f64 - { - let x = get_value(scale, shape, eval); - assert_eq!(expected, x); - } - - fn test_almost(scale: f64, shape: f64, expected: f64, acc: f64, eval: F) - where F: Fn(Pareto) -> f64 - { - let p = try_create(scale, shape); - let x = eval(p); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(scale: f64, shape: f64; Pareto); #[test] fn test_create() { - create_case(10.0, 0.1); - create_case(5.0, 1.0); - create_case(0.1, 10.0); - create_case(10.0, 100.0); - create_case(1.0, f64::INFINITY); - create_case(f64::INFINITY, f64::INFINITY); + create_ok(10.0, 0.1); + create_ok(5.0, 1.0); + create_ok(0.1, 10.0); + create_ok(10.0, 100.0); + create_ok(1.0, f64::INFINITY); + create_ok(f64::INFINITY, f64::INFINITY); } #[test] fn test_bad_create() { - bad_create_case(0.0, 0.0); - bad_create_case(1.0, -1.0); - bad_create_case(-1.0, 1.0); - bad_create_case(-1.0, -1.0); - bad_create_case(f64::NAN, 1.0); - bad_create_case(1.0, f64::NAN); - bad_create_case(f64::NAN, f64::NAN); + create_err(0.0, 0.0); + create_err(1.0, -1.0); + create_err(-1.0, 1.0); + create_err(-1.0, -1.0); + create_err(f64::NAN, 1.0); + create_err(1.0, f64::NAN); + create_err(f64::NAN, f64::NAN); } #[test] fn test_variance() { let variance = |x: Pareto| x.variance().unwrap(); - test_case(1.0, 3.0, 0.75, variance); - test_almost(10.0, 10.0, 125.0 / 81.0, 1e-13, variance); + test_exact(1.0, 3.0, 0.75, variance); + test_absolute(10.0, 10.0, 125.0 / 81.0, 1e-13, variance); } #[test] #[should_panic] fn test_variance_degen() { let variance = |x: Pareto| x.variance().unwrap(); - test_case(1.0, 1.0, f64::INFINITY, variance); // shape <= 2.0 + test_exact(1.0, 1.0, f64::INFINITY, variance); // shape <= 2.0 } #[test] fn test_entropy() { let entropy = |x: Pareto| x.entropy().unwrap(); - test_case(0.1, 0.1, -11.0, entropy); - test_case(1.0, 1.0, -2.0, entropy); - test_case(10.0, 10.0, -1.1, entropy); - test_case(3.0, 1.0, -2.0 - 3f64.ln(), entropy); - test_case(1.0, 3.0, -4.0/3.0 + 3f64.ln(), entropy); + test_exact(0.1, 0.1, -11.0, entropy); + test_exact(1.0, 1.0, -2.0, entropy); + test_exact(10.0, 10.0, -1.1, entropy); + test_exact(3.0, 1.0, -2.0 - 3f64.ln(), entropy); + test_exact(1.0, 3.0, -4.0/3.0 + 3f64.ln(), entropy); } #[test] fn test_skewness() { let skewness = |x: Pareto| x.skewness().unwrap(); - test_case(1.0, 4.0, 5.0*2f64.sqrt(), skewness); - test_case(1.0, 100.0, (707.0/485.0)*2f64.sqrt(), skewness); + test_exact(1.0, 4.0, 5.0*2f64.sqrt(), skewness); + test_exact(1.0, 100.0, (707.0/485.0)*2f64.sqrt(), skewness); } #[test] - #[should_panic] fn test_skewness_invalid_shape() { - let skewness = |x: Pareto| x.skewness().unwrap(); - get_value(1.0, 3.0, skewness); + test_none(1.0, 3.0, |dist| dist.skewness()); } #[test] fn test_mode() { let mode = |x: Pareto| x.mode().unwrap(); - test_case(0.1, 1.0, 0.1, mode); - test_case(2.0, 1.0, 2.0, mode); - test_case(10.0, f64::INFINITY, 10.0, mode); - test_case(f64::INFINITY, 1.0, f64::INFINITY, mode); + test_exact(0.1, 1.0, 0.1, mode); + test_exact(2.0, 1.0, 2.0, mode); + test_exact(10.0, f64::INFINITY, 10.0, mode); + test_exact(f64::INFINITY, 1.0, f64::INFINITY, mode); } #[test] fn test_median() { let median = |x: Pareto| x.median(); - test_case(0.1, 0.1, 102.4, median); - test_case(1.0, 1.0, 2.0, median); - test_case(10.0, 10.0, 10.0*2f64.powf(0.1), median); - test_case(3.0, 0.5, 12.0, median); - test_case(10.0, f64::INFINITY, 10.0, median); + test_exact(0.1, 0.1, 102.4, median); + test_exact(1.0, 1.0, 2.0, median); + test_exact(10.0, 10.0, 10.0*2f64.powf(0.1), median); + test_exact(3.0, 0.5, 12.0, median); + test_exact(10.0, f64::INFINITY, 10.0, median); } #[test] fn test_min_max() { let min = |x: Pareto| x.min(); let max = |x: Pareto| x.max(); - test_case(0.2, f64::INFINITY, 0.2, min); - test_case(10.0, f64::INFINITY, 10.0, min); - test_case(f64::INFINITY, 1.0, f64::INFINITY, min); - test_case(1.0, 0.1, f64::INFINITY, max); - test_case(3.0, 10.0, f64::INFINITY, max); + test_exact(0.2, f64::INFINITY, 0.2, min); + test_exact(10.0, f64::INFINITY, 10.0, min); + test_exact(f64::INFINITY, 1.0, f64::INFINITY, min); + test_exact(1.0, 0.1, f64::INFINITY, max); + test_exact(3.0, 10.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Pareto| x.pdf(arg); - test_case(1.0, 1.0, 0.0, pdf(0.1)); - test_case(1.0, 1.0, 1.0, pdf(1.0)); - test_case(1.0, 1.0, 4.0/9.0, pdf(1.5)); - test_case(1.0, 1.0, 1.0/25.0, pdf(5.0)); - test_case(1.0, 1.0, 1.0/2500.0, pdf(50.0)); - test_case(1.0, 4.0, 4.0, pdf(1.0)); - test_case(1.0, 4.0, 128.0/243.0, pdf(1.5)); - test_case(1.0, 4.0, 1.0/78125000.0, pdf(50.0)); - test_case(3.0, 2.0, 2.0/3.0, pdf(3.0)); - test_case(3.0, 2.0, 18.0/125.0, pdf(5.0)); - test_almost(25.0, 100.0, 1.5777218104420236e-30, 1e-50, pdf(50.0)); - test_almost(100.0, 25.0, 6.6003546737276816e-6, 1e-16, pdf(150.0)); - test_case(1.0, 2.0, 0.0, pdf(f64::INFINITY)); + test_exact(1.0, 1.0, 0.0, pdf(0.1)); + test_exact(1.0, 1.0, 1.0, pdf(1.0)); + test_exact(1.0, 1.0, 4.0/9.0, pdf(1.5)); + test_exact(1.0, 1.0, 1.0/25.0, pdf(5.0)); + test_exact(1.0, 1.0, 1.0/2500.0, pdf(50.0)); + test_exact(1.0, 4.0, 4.0, pdf(1.0)); + test_exact(1.0, 4.0, 128.0/243.0, pdf(1.5)); + test_exact(1.0, 4.0, 1.0/78125000.0, pdf(50.0)); + test_exact(3.0, 2.0, 2.0/3.0, pdf(3.0)); + test_exact(3.0, 2.0, 18.0/125.0, pdf(5.0)); + test_absolute(25.0, 100.0, 1.5777218104420236e-30, 1e-50, pdf(50.0)); + test_absolute(100.0, 25.0, 6.6003546737276816e-6, 1e-16, pdf(150.0)); + test_exact(1.0, 2.0, 0.0, pdf(f64::INFINITY)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Pareto| x.ln_pdf(arg); - test_case(1.0, 1.0, f64::NEG_INFINITY, ln_pdf(0.1)); - test_case(1.0, 1.0, 0.0, ln_pdf(1.0)); - test_almost(1.0, 1.0, 4f64.ln() - 9f64.ln(), 1e-14, ln_pdf(1.5)); - test_almost(1.0, 1.0, -(25f64.ln()), 1e-14, ln_pdf(5.0)); - test_almost(1.0, 1.0, -(2500f64.ln()), 1e-14, ln_pdf(50.0)); - test_almost(1.0, 4.0, 4f64.ln(), 1e-14, ln_pdf(1.0)); - test_almost(1.0, 4.0, 128f64.ln() - 243f64.ln(), 1e-14, ln_pdf(1.5)); - test_almost(1.0, 4.0, -(78125000f64.ln()), 1e-14, ln_pdf(50.0)); - test_almost(3.0, 2.0, 2f64.ln() - 3f64.ln(), 1e-14, ln_pdf(3.0)); - test_almost(3.0, 2.0, 18f64.ln() - 125f64.ln(), 1e-14, ln_pdf(5.0)); - test_almost(25.0, 100.0, 1.5777218104420236e-30f64.ln(), 1e-12, ln_pdf(50.0)); - test_almost(100.0, 25.0, 6.6003546737276816e-6f64.ln(), 1e-12, ln_pdf(150.0)); - test_case(1.0, 2.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); + test_exact(1.0, 1.0, f64::NEG_INFINITY, ln_pdf(0.1)); + test_exact(1.0, 1.0, 0.0, ln_pdf(1.0)); + test_absolute(1.0, 1.0, 4f64.ln() - 9f64.ln(), 1e-14, ln_pdf(1.5)); + test_absolute(1.0, 1.0, -(25f64.ln()), 1e-14, ln_pdf(5.0)); + test_absolute(1.0, 1.0, -(2500f64.ln()), 1e-14, ln_pdf(50.0)); + test_absolute(1.0, 4.0, 4f64.ln(), 1e-14, ln_pdf(1.0)); + test_absolute(1.0, 4.0, 128f64.ln() - 243f64.ln(), 1e-14, ln_pdf(1.5)); + test_absolute(1.0, 4.0, -(78125000f64.ln()), 1e-14, ln_pdf(50.0)); + test_absolute(3.0, 2.0, 2f64.ln() - 3f64.ln(), 1e-14, ln_pdf(3.0)); + test_absolute(3.0, 2.0, 18f64.ln() - 125f64.ln(), 1e-14, ln_pdf(5.0)); + test_absolute(25.0, 100.0, 1.5777218104420236e-30f64.ln(), 1e-12, ln_pdf(50.0)); + test_absolute(100.0, 25.0, 6.6003546737276816e-6f64.ln(), 1e-12, ln_pdf(150.0)); + test_exact(1.0, 2.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Pareto| x.cdf(arg); - test_case(0.1, 0.1, 0.0, cdf(0.1)); - test_case(1.0, 1.0, 0.0, cdf(1.0)); - test_case(5.0, 5.0, 0.0, cdf(2.0)); - test_case(7.0, 7.0, 0.9176457, cdf(10.0)); - test_case(10.0, 10.0, 50700551.0/60466176.0, cdf(12.0)); - test_case(5.0, 1.0, 0.5, cdf(10.0)); - test_case(3.0, 10.0, 1023.0/1024.0, cdf(6.0)); - test_case(1.0, 1.0, 1.0, cdf(f64::INFINITY)); + test_exact(0.1, 0.1, 0.0, cdf(0.1)); + test_exact(1.0, 1.0, 0.0, cdf(1.0)); + test_exact(5.0, 5.0, 0.0, cdf(2.0)); + test_exact(7.0, 7.0, 0.9176457, cdf(10.0)); + test_exact(10.0, 10.0, 50700551.0/60466176.0, cdf(12.0)); + test_exact(5.0, 1.0, 0.5, cdf(10.0)); + test_exact(3.0, 10.0, 1023.0/1024.0, cdf(6.0)); + test_exact(1.0, 1.0, 1.0, cdf(f64::INFINITY)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Pareto| x.sf(arg); - test_case(0.1, 0.1, 1.0, sf(0.1)); - test_case(1.0, 1.0, 1.0, sf(1.0)); - test_case(5.0, 5.0, 1.0, sf(2.0)); - test_almost(7.0, 7.0, 0.08235429999999999, 1e-14, sf(10.0)); - test_almost(10.0, 10.0, 0.16150558288984573, 1e-14, sf(12.0)); - test_case(5.0, 1.0, 0.5, sf(10.0)); - test_almost(3.0, 10.0, 0.0009765625, 1e-14, sf(6.0)); - test_case(1.0, 1.0, 0.0, sf(f64::INFINITY)); + test_exact(0.1, 0.1, 1.0, sf(0.1)); + test_exact(1.0, 1.0, 1.0, sf(1.0)); + test_exact(5.0, 5.0, 1.0, sf(2.0)); + test_absolute(7.0, 7.0, 0.08235429999999999, 1e-14, sf(10.0)); + test_absolute(10.0, 10.0, 0.16150558288984573, 1e-14, sf(12.0)); + test_exact(5.0, 1.0, 0.5, sf(10.0)); + test_absolute(3.0, 10.0, 0.0009765625, 1e-14, sf(6.0)); + test_exact(1.0, 1.0, 0.0, sf(f64::INFINITY)); } #[test] fn test_inverse_cdf() { let func = |arg: f64| move |x: Pareto| x.inverse_cdf(x.cdf(arg)); - test_case(0.1, 0.1, 0.1, func(0.1)); - test_case(1.0, 1.0, 1.0, func(1.0)); - test_case(7.0, 7.0, 10.0, func(10.0)); - test_case(10.0, 10.0, 12.0, func(12.0)); - test_case(5.0, 1.0, 10.0, func(10.0)); - test_case(3.0, 10.0, 6.0, func(6.0)); + test_exact(0.1, 0.1, 0.1, func(0.1)); + test_exact(1.0, 1.0, 1.0, func(1.0)); + test_exact(7.0, 7.0, 10.0, func(10.0)); + test_exact(10.0, 10.0, 12.0, func(12.0)); + test_exact(5.0, 1.0, 10.0, func(10.0)); + test_exact(3.0, 10.0, 6.0, func(6.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(1.0, 10.0), 1.0, 10.0); - test::check_continuous_distribution(&try_create(0.1, 2.0), 0.1, 100.0); + test::check_continuous_distribution(&create_ok(1.0, 10.0), 1.0, 10.0); + test::check_continuous_distribution(&create_ok(0.1, 2.0), 0.1, 100.0); } } From 58f719c7beaaad221b996b233ef01a91bd9b1848 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 20:03:44 +0200 Subject: [PATCH 159/185] Use testing_boiler! for Poisson --- src/distribution/poisson.rs | 179 ++++++++++++++---------------------- 1 file changed, 71 insertions(+), 108 deletions(-) diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index 7653ed20..41b56e6a 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -304,183 +304,146 @@ pub fn sample_unchecked(rng: &mut R, lambda: f64) -> f64 { #[rustfmt::skip] #[cfg(test)] mod tests { - use std::fmt::Debug; - use crate::statistics::*; use crate::distribution::{DiscreteCDF, Discrete, Poisson}; use crate::distribution::internal::*; + use crate::statistics::*; + use crate::testing_boiler; - fn try_create(lambda: f64) -> Poisson { - let n = Poisson::new(lambda); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(lambda: f64) { - let n = try_create(lambda); - assert_eq!(lambda, n.lambda()); - } - - fn bad_create_case(lambda: f64) { - let n = Poisson::new(lambda); - assert!(n.is_err()); - } - - fn get_value(lambda: f64, eval: F) -> T - where T: PartialEq + Debug, - F: Fn(Poisson) -> T - { - let n = try_create(lambda); - eval(n) - } - - fn test_case(lambda: f64, expected: T, eval: F) - where T: PartialEq + Debug, - F: Fn(Poisson) -> T - { - let x = get_value(lambda, eval); - assert_eq!(expected, x); - } - - fn test_almost(lambda: f64, expected: f64, acc: f64, eval: F) - where F: Fn(Poisson) -> f64 - { - let x = get_value(lambda, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(lambda: f64; Poisson); #[test] fn test_create() { - create_case(1.5); - create_case(5.4); - create_case(10.8); + create_ok(1.5); + create_ok(5.4); + create_ok(10.8); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN); - bad_create_case(-1.5); - bad_create_case(0.0); + create_err(f64::NAN); + create_err(-1.5); + create_err(0.0); } #[test] fn test_mean() { let mean = |x: Poisson| x.mean().unwrap(); - test_case(1.5, 1.5, mean); - test_case(5.4, 5.4, mean); - test_case(10.8, 10.8, mean); + test_exact(1.5, 1.5, mean); + test_exact(5.4, 5.4, mean); + test_exact(10.8, 10.8, mean); } #[test] fn test_variance() { let variance = |x: Poisson| x.variance().unwrap(); - test_case(1.5, 1.5, variance); - test_case(5.4, 5.4, variance); - test_case(10.8, 10.8, variance); + test_exact(1.5, 1.5, variance); + test_exact(5.4, 5.4, variance); + test_exact(10.8, 10.8, variance); } #[test] fn test_entropy() { let entropy = |x: Poisson| x.entropy().unwrap(); - test_almost(1.5, 1.531959153102376331946, 1e-15, entropy); - test_almost(5.4, 2.244941839577643504608, 1e-15, entropy); - test_case(10.8, 2.600596429676975222694, entropy); + test_absolute(1.5, 1.531959153102376331946, 1e-15, entropy); + test_absolute(5.4, 2.244941839577643504608, 1e-15, entropy); + test_exact(10.8, 2.600596429676975222694, entropy); } #[test] fn test_skewness() { let skewness = |x: Poisson| x.skewness().unwrap(); - test_almost(1.5, 0.8164965809277260327324, 1e-15, skewness); - test_almost(5.4, 0.4303314829119352094644, 1e-16, skewness); - test_almost(10.8, 0.3042903097250922852539, 1e-16, skewness); + test_absolute(1.5, 0.8164965809277260327324, 1e-15, skewness); + test_absolute(5.4, 0.4303314829119352094644, 1e-16, skewness); + test_absolute(10.8, 0.3042903097250922852539, 1e-16, skewness); } #[test] fn test_median() { let median = |x: Poisson| x.median(); - test_case(1.5, 1.0, median); - test_case(5.4, 5.0, median); - test_case(10.8, 11.0, median); + test_exact(1.5, 1.0, median); + test_exact(5.4, 5.0, median); + test_exact(10.8, 11.0, median); } #[test] fn test_mode() { let mode = |x: Poisson| x.mode().unwrap(); - test_case(1.5, 1, mode); - test_case(5.4, 5, mode); - test_case(10.8, 10, mode); + test_exact(1.5, 1, mode); + test_exact(5.4, 5, mode); + test_exact(10.8, 10, mode); } #[test] fn test_min_max() { let min = |x: Poisson| x.min(); let max = |x: Poisson| x.max(); - test_case(1.5, 0, min); - test_case(5.4, 0, min); - test_case(10.8, 0, min); - test_case(1.5, u64::MAX, max); - test_case(5.4, u64::MAX, max); - test_case(10.8, u64::MAX, max); + test_exact(1.5, 0, min); + test_exact(5.4, 0, min); + test_exact(10.8, 0, min); + test_exact(1.5, u64::MAX, max); + test_exact(5.4, u64::MAX, max); + test_exact(10.8, u64::MAX, max); } #[test] fn test_pmf() { let pmf = |arg: u64| move |x: Poisson| x.pmf(arg); - test_almost(1.5, 0.334695240222645000000000000000, 1e-15, pmf(1)); - test_almost(1.5, 0.000003545747740570180000000000, 1e-20, pmf(10)); - test_almost(1.5, 0.000000000000000304971208961018, 1e-30, pmf(20)); - test_almost(5.4, 0.024389537090108400000000000000, 1e-17, pmf(1)); - test_almost(5.4, 0.026241240591792300000000000000, 1e-16, pmf(10)); - test_almost(5.4, 0.000000825202200316548000000000, 1e-20, pmf(20)); - test_almost(10.8, 0.000220314636840657000000000000, 1e-18, pmf(1)); - test_almost(10.8, 0.121365183659420000000000000000, 1e-15, pmf(10)); - test_almost(10.8, 0.003908139778574110000000000000, 1e-16, pmf(20)); + test_absolute(1.5, 0.334695240222645000000000000000, 1e-15, pmf(1)); + test_absolute(1.5, 0.000003545747740570180000000000, 1e-20, pmf(10)); + test_absolute(1.5, 0.000000000000000304971208961018, 1e-30, pmf(20)); + test_absolute(5.4, 0.024389537090108400000000000000, 1e-17, pmf(1)); + test_absolute(5.4, 0.026241240591792300000000000000, 1e-16, pmf(10)); + test_absolute(5.4, 0.000000825202200316548000000000, 1e-20, pmf(20)); + test_absolute(10.8, 0.000220314636840657000000000000, 1e-18, pmf(1)); + test_absolute(10.8, 0.121365183659420000000000000000, 1e-15, pmf(10)); + test_absolute(10.8, 0.003908139778574110000000000000, 1e-16, pmf(20)); } #[test] fn test_ln_pmf() { let ln_pmf = |arg: u64| move |x: Poisson| x.ln_pmf(arg); - test_almost(1.5, -1.09453489189183485135413967177, 1e-15, ln_pmf(1)); - test_almost(1.5, -12.5497614919938728510400000000, 1e-14, ln_pmf(10)); - test_almost(1.5, -35.7263142985901000000000000000, 1e-13, ln_pmf(20)); - test_case(5.4, -3.71360104642977159156055355910, ln_pmf(1)); - test_almost(5.4, -3.64042303737322774736223038530, 1e-15, ln_pmf(10)); - test_almost(5.4, -14.0076373893489089949388000000, 1e-14, ln_pmf(20)); - test_almost(10.8, -8.42045386586982559781714423000, 1e-14, ln_pmf(1)); - test_almost(10.8, -2.10895123177378079525424989992, 1e-14, ln_pmf(10)); - test_almost(10.8, -5.54469377815000936289610059500, 1e-14, ln_pmf(20)); + test_absolute(1.5, -1.09453489189183485135413967177, 1e-15, ln_pmf(1)); + test_absolute(1.5, -12.5497614919938728510400000000, 1e-14, ln_pmf(10)); + test_absolute(1.5, -35.7263142985901000000000000000, 1e-13, ln_pmf(20)); + test_exact(5.4, -3.71360104642977159156055355910, ln_pmf(1)); + test_absolute(5.4, -3.64042303737322774736223038530, 1e-15, ln_pmf(10)); + test_absolute(5.4, -14.0076373893489089949388000000, 1e-14, ln_pmf(20)); + test_absolute(10.8, -8.42045386586982559781714423000, 1e-14, ln_pmf(1)); + test_absolute(10.8, -2.10895123177378079525424989992, 1e-14, ln_pmf(10)); + test_absolute(10.8, -5.54469377815000936289610059500, 1e-14, ln_pmf(20)); } #[test] fn test_cdf() { let cdf = |arg: u64| move |x: Poisson| x.cdf(arg); - test_almost(1.5, 0.5578254003710750000000, 1e-15, cdf(1)); - test_almost(1.5, 0.9999994482467640000000, 1e-15, cdf(10)); - test_case(1.5, 1.0, cdf(20)); - test_almost(5.4, 0.0289061180327211000000, 1e-16, cdf(1)); - test_almost(5.4, 0.9774863006897650000000, 1e-15, cdf(10)); - test_almost(5.4, 0.9999997199928290000000, 1e-15, cdf(20)); - test_almost(10.8, 0.0002407141402518290000, 1e-16, cdf(1)); - test_almost(10.8, 0.4839692359955690000000, 1e-15, cdf(10)); - test_almost(10.8, 0.9961800769608090000000, 1e-15, cdf(20)); + test_absolute(1.5, 0.5578254003710750000000, 1e-15, cdf(1)); + test_absolute(1.5, 0.9999994482467640000000, 1e-15, cdf(10)); + test_exact(1.5, 1.0, cdf(20)); + test_absolute(5.4, 0.0289061180327211000000, 1e-16, cdf(1)); + test_absolute(5.4, 0.9774863006897650000000, 1e-15, cdf(10)); + test_absolute(5.4, 0.9999997199928290000000, 1e-15, cdf(20)); + test_absolute(10.8, 0.0002407141402518290000, 1e-16, cdf(1)); + test_absolute(10.8, 0.4839692359955690000000, 1e-15, cdf(10)); + test_absolute(10.8, 0.9961800769608090000000, 1e-15, cdf(20)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: Poisson| x.sf(arg); - test_almost(1.5, 0.44217459962892536, 1e-15, sf(1)); - test_almost(1.5, 0.0000005517532358246565, 1e-15, sf(10)); - test_almost(1.5, 2.3372210700347092e-17, 1e-15, sf(20)); - test_almost(5.4, 0.971093881967279, 1e-16, sf(1)); - test_almost(5.4, 0.022513699310235582, 1e-15, sf(10)); - test_almost(5.4, 0.0000002800071708975261, 1e-15, sf(20)); - test_almost(10.8, 0.9997592858597482, 1e-16, sf(1)); - test_almost(10.8, 0.5160307640044303, 1e-15, sf(10)); - test_almost(10.8, 0.003819923039191422, 1e-15, sf(20)); + test_absolute(1.5, 0.44217459962892536, 1e-15, sf(1)); + test_absolute(1.5, 0.0000005517532358246565, 1e-15, sf(10)); + test_absolute(1.5, 2.3372210700347092e-17, 1e-15, sf(20)); + test_absolute(5.4, 0.971093881967279, 1e-16, sf(1)); + test_absolute(5.4, 0.022513699310235582, 1e-15, sf(10)); + test_absolute(5.4, 0.0000002800071708975261, 1e-15, sf(20)); + test_absolute(10.8, 0.9997592858597482, 1e-16, sf(1)); + test_absolute(10.8, 0.5160307640044303, 1e-15, sf(10)); + test_absolute(10.8, 0.003819923039191422, 1e-15, sf(20)); } #[test] fn test_discrete() { - test::check_discrete_distribution(&try_create(0.3), 10); - test::check_discrete_distribution(&try_create(4.5), 30); + test::check_discrete_distribution(&create_ok(0.3), 10); + test::check_discrete_distribution(&create_ok(4.5), 30); } } From fa5569a56c4cde92118fc789774b1e209bc0410f Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 20:04:46 +0200 Subject: [PATCH 160/185] Use testing_boiler! for Triangular --- src/distribution/triangular.rs | 266 ++++++++++++++------------------- 1 file changed, 114 insertions(+), 152 deletions(-) diff --git a/src/distribution/triangular.rs b/src/distribution/triangular.rs index 8cb48dfa..fff9fe72 100644 --- a/src/distribution/triangular.rs +++ b/src/distribution/triangular.rs @@ -347,242 +347,204 @@ fn sample_unchecked(rng: &mut R, min: f64, max: f64, mode: f64) #[rustfmt::skip] #[cfg(test)] mod tests { - use std::fmt::Debug; - use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Triangular}; use crate::distribution::internal::*; + use crate::statistics::*; + use crate::testing_boiler; - fn try_create(min: f64, max: f64, mode: f64) -> Triangular { - let n = Triangular::new(min, max, mode); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(min: f64, max: f64, mode: f64) { - let n = try_create(min, max, mode); - assert_eq!(n.min(), min); - assert_eq!(n.max(), max); - assert_eq!(n.mode().unwrap(), mode); - } - - fn bad_create_case(min: f64, max: f64, mode: f64) { - let n = Triangular::new(min, max, mode); - assert!(n.is_err()); - } - - fn get_value(min: f64, max: f64, mode: f64, eval: F) -> T - where T: PartialEq + Debug, - F: Fn(Triangular) -> T - { - let n = try_create(min, max, mode); - eval(n) - } - - fn test_case(min: f64, max: f64, mode: f64, expected: f64, eval: F) - where F: Fn(Triangular) -> f64 - { - let x = get_value(min, max, mode, eval); - assert_eq!(expected, x); - } - - fn test_almost(min: f64, max: f64, mode: f64, expected: f64, acc: f64, eval: F) - where F: Fn(Triangular) -> f64 - { - let x = get_value(min, max, mode, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(min: f64, max: f64, mode: f64; Triangular); #[test] fn test_create() { - create_case(-1.0, 1.0, 0.0); - create_case(1.0, 2.0, 1.0); - create_case(5.0, 25.0, 25.0); - create_case(1.0e-5, 1.0e5, 1.0e-3); - create_case(0.0, 1.0, 0.9); - create_case(-4.0, -0.5, -2.0); - create_case(-13.039, 8.42, 1.17); + create_ok(-1.0, 1.0, 0.0); + create_ok(1.0, 2.0, 1.0); + create_ok(5.0, 25.0, 25.0); + create_ok(1.0e-5, 1.0e5, 1.0e-3); + create_ok(0.0, 1.0, 0.9); + create_ok(-4.0, -0.5, -2.0); + create_ok(-13.039, 8.42, 1.17); } #[test] fn test_bad_create() { - bad_create_case(0.0, 0.0, 0.0); - bad_create_case(0.0, 1.0, -0.1); - bad_create_case(0.0, 1.0, 1.1); - bad_create_case(0.0, -1.0, 0.5); - bad_create_case(2.0, 1.0, 1.5); - bad_create_case(f64::NAN, 1.0, 0.5); - bad_create_case(0.2, f64::NAN, 0.5); - bad_create_case(0.5, 1.0, f64::NAN); - bad_create_case(f64::NAN, f64::NAN, f64::NAN); - bad_create_case(f64::NEG_INFINITY, 1.0, 0.5); - bad_create_case(0.0, f64::INFINITY, 0.5); + create_err(0.0, 0.0, 0.0); + create_err(0.0, 1.0, -0.1); + create_err(0.0, 1.0, 1.1); + create_err(0.0, -1.0, 0.5); + create_err(2.0, 1.0, 1.5); + create_err(f64::NAN, 1.0, 0.5); + create_err(0.2, f64::NAN, 0.5); + create_err(0.5, 1.0, f64::NAN); + create_err(f64::NAN, f64::NAN, f64::NAN); + create_err(f64::NEG_INFINITY, 1.0, 0.5); + create_err(0.0, f64::INFINITY, 0.5); } #[test] fn test_variance() { let variance = |x: Triangular| x.variance().unwrap(); - test_case(0.0, 1.0, 0.5, 0.75 / 18.0, variance); - test_case(0.0, 1.0, 0.75, 0.8125 / 18.0, variance); - test_case(-5.0, 8.0, -3.5, 151.75 / 18.0, variance); - test_case(-5.0, 8.0, 5.0, 139.0 / 18.0, variance); - test_case(-5.0, -3.0, -4.0, 3.0 / 18.0, variance); - test_case(15.0, 134.0, 21.0, 13483.0 / 18.0, variance); + test_exact(0.0, 1.0, 0.5, 0.75 / 18.0, variance); + test_exact(0.0, 1.0, 0.75, 0.8125 / 18.0, variance); + test_exact(-5.0, 8.0, -3.5, 151.75 / 18.0, variance); + test_exact(-5.0, 8.0, 5.0, 139.0 / 18.0, variance); + test_exact(-5.0, -3.0, -4.0, 3.0 / 18.0, variance); + test_exact(15.0, 134.0, 21.0, 13483.0 / 18.0, variance); } #[test] fn test_entropy() { let entropy = |x: Triangular| x.entropy().unwrap(); - test_almost(0.0, 1.0, 0.5, -0.1931471805599453094172, 1e-16, entropy); - test_almost(0.0, 1.0, 0.75, -0.1931471805599453094172, 1e-16, entropy); - test_case(-5.0, 8.0, -3.5, 2.371802176901591426636, entropy); - test_case(-5.0, 8.0, 5.0, 2.371802176901591426636, entropy); - test_case(-5.0, -3.0, -4.0, 0.5, entropy); - test_case(15.0, 134.0, 21.0, 4.585976312551584075938, entropy); + test_absolute(0.0, 1.0, 0.5, -0.1931471805599453094172, 1e-16, entropy); + test_absolute(0.0, 1.0, 0.75, -0.1931471805599453094172, 1e-16, entropy); + test_exact(-5.0, 8.0, -3.5, 2.371802176901591426636, entropy); + test_exact(-5.0, 8.0, 5.0, 2.371802176901591426636, entropy); + test_exact(-5.0, -3.0, -4.0, 0.5, entropy); + test_exact(15.0, 134.0, 21.0, 4.585976312551584075938, entropy); } #[test] fn test_skewness() { let skewness = |x: Triangular| x.skewness().unwrap(); - test_case(0.0, 1.0, 0.5, 0.0, skewness); - test_case(0.0, 1.0, 0.75, -0.4224039833745502226059, skewness); - test_case(-5.0, 8.0, -3.5, 0.5375093589712976359809, skewness); - test_case(-5.0, 8.0, 5.0, -0.4445991743012595633537, skewness); - test_case(-5.0, -3.0, -4.0, 0.0, skewness); - test_case(15.0, 134.0, 21.0, 0.5605920922751860613217, skewness); + test_exact(0.0, 1.0, 0.5, 0.0, skewness); + test_exact(0.0, 1.0, 0.75, -0.4224039833745502226059, skewness); + test_exact(-5.0, 8.0, -3.5, 0.5375093589712976359809, skewness); + test_exact(-5.0, 8.0, 5.0, -0.4445991743012595633537, skewness); + test_exact(-5.0, -3.0, -4.0, 0.0, skewness); + test_exact(15.0, 134.0, 21.0, 0.5605920922751860613217, skewness); } #[test] fn test_mode() { let mode = |x: Triangular| x.mode().unwrap(); - test_case(0.0, 1.0, 0.5, 0.5, mode); - test_case(0.0, 1.0, 0.75, 0.75, mode); - test_case(-5.0, 8.0, -3.5, -3.5, mode); - test_case(-5.0, 8.0, 5.0, 5.0, mode); - test_case(-5.0, -3.0, -4.0, -4.0, mode); - test_case(15.0, 134.0, 21.0, 21.0, mode); + test_exact(0.0, 1.0, 0.5, 0.5, mode); + test_exact(0.0, 1.0, 0.75, 0.75, mode); + test_exact(-5.0, 8.0, -3.5, -3.5, mode); + test_exact(-5.0, 8.0, 5.0, 5.0, mode); + test_exact(-5.0, -3.0, -4.0, -4.0, mode); + test_exact(15.0, 134.0, 21.0, 21.0, mode); } #[test] fn test_median() { let median = |x: Triangular| x.median(); - test_case(0.0, 1.0, 0.5, 0.5, median); - test_case(0.0, 1.0, 0.75, 0.6123724356957945245493, median); - test_almost(-5.0, 8.0, -3.5, -0.6458082328952913226724, 1e-15, median); - test_almost(-5.0, 8.0, 5.0, 3.062257748298549652367, 1e-15, median); - test_case(-5.0, -3.0, -4.0, -4.0, median); - test_almost(15.0, 134.0, 21.0, 52.00304883716712238797, 1e-14, median); + test_exact(0.0, 1.0, 0.5, 0.5, median); + test_exact(0.0, 1.0, 0.75, 0.6123724356957945245493, median); + test_absolute(-5.0, 8.0, -3.5, -0.6458082328952913226724, 1e-15, median); + test_absolute(-5.0, 8.0, 5.0, 3.062257748298549652367, 1e-15, median); + test_exact(-5.0, -3.0, -4.0, -4.0, median); + test_absolute(15.0, 134.0, 21.0, 52.00304883716712238797, 1e-14, median); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Triangular| x.pdf(arg); - test_case(0.0, 1.0, 0.5, 0.0, pdf(-1.0)); - test_case(0.0, 1.0, 0.5, 0.0, pdf(1.1)); - test_case(0.0, 1.0, 0.5, 1.0, pdf(0.25)); - test_case(0.0, 1.0, 0.5, 2.0, pdf(0.5)); - test_case(0.0, 1.0, 0.5, 1.0, pdf(0.75)); - test_case(-5.0, 8.0, -3.5, 0.0, pdf(-5.1)); - test_case(-5.0, 8.0, -3.5, 0.0, pdf(8.1)); - test_case(-5.0, 8.0, -3.5, 0.1025641025641025641026, pdf(-4.0)); - test_case(-5.0, 8.0, -3.5, 0.1538461538461538461538, pdf(-3.5)); - test_case(-5.0, 8.0, -3.5, 0.05351170568561872909699, pdf(4.0)); - test_case(-5.0, -3.0, -4.0, 0.0, pdf(-5.1)); - test_case(-5.0, -3.0, -4.0, 0.0, pdf(-2.9)); - test_case(-5.0, -3.0, -4.0, 0.5, pdf(-4.5)); - test_case(-5.0, -3.0, -4.0, 1.0, pdf(-4.0)); - test_case(-5.0, -3.0, -4.0, 0.5, pdf(-3.5)); + test_exact(0.0, 1.0, 0.5, 0.0, pdf(-1.0)); + test_exact(0.0, 1.0, 0.5, 0.0, pdf(1.1)); + test_exact(0.0, 1.0, 0.5, 1.0, pdf(0.25)); + test_exact(0.0, 1.0, 0.5, 2.0, pdf(0.5)); + test_exact(0.0, 1.0, 0.5, 1.0, pdf(0.75)); + test_exact(-5.0, 8.0, -3.5, 0.0, pdf(-5.1)); + test_exact(-5.0, 8.0, -3.5, 0.0, pdf(8.1)); + test_exact(-5.0, 8.0, -3.5, 0.1025641025641025641026, pdf(-4.0)); + test_exact(-5.0, 8.0, -3.5, 0.1538461538461538461538, pdf(-3.5)); + test_exact(-5.0, 8.0, -3.5, 0.05351170568561872909699, pdf(4.0)); + test_exact(-5.0, -3.0, -4.0, 0.0, pdf(-5.1)); + test_exact(-5.0, -3.0, -4.0, 0.0, pdf(-2.9)); + test_exact(-5.0, -3.0, -4.0, 0.5, pdf(-4.5)); + test_exact(-5.0, -3.0, -4.0, 1.0, pdf(-4.0)); + test_exact(-5.0, -3.0, -4.0, 0.5, pdf(-3.5)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Triangular| x.ln_pdf(arg); - test_case(0.0, 1.0, 0.5, f64::NEG_INFINITY, ln_pdf(-1.0)); - test_case(0.0, 1.0, 0.5, f64::NEG_INFINITY, ln_pdf(1.1)); - test_case(0.0, 1.0, 0.5, 0.0, ln_pdf(0.25)); - test_case(0.0, 1.0, 0.5, 2f64.ln(), ln_pdf(0.5)); - test_case(0.0, 1.0, 0.5, 0.0, ln_pdf(0.75)); - test_case(-5.0, 8.0, -3.5, f64::NEG_INFINITY, ln_pdf(-5.1)); - test_case(-5.0, 8.0, -3.5, f64::NEG_INFINITY, ln_pdf(8.1)); - test_case(-5.0, 8.0, -3.5, 0.1025641025641025641026f64.ln(), ln_pdf(-4.0)); - test_case(-5.0, 8.0, -3.5, 0.1538461538461538461538f64.ln(), ln_pdf(-3.5)); - test_case(-5.0, 8.0, -3.5, 0.05351170568561872909699f64.ln(), ln_pdf(4.0)); - test_case(-5.0, -3.0, -4.0, f64::NEG_INFINITY, ln_pdf(-5.1)); - test_case(-5.0, -3.0, -4.0, f64::NEG_INFINITY, ln_pdf(-2.9)); - test_case(-5.0, -3.0, -4.0, 0.5f64.ln(), ln_pdf(-4.5)); - test_case(-5.0, -3.0, -4.0, 0.0, ln_pdf(-4.0)); - test_case(-5.0, -3.0, -4.0, 0.5f64.ln(), ln_pdf(-3.5)); + test_exact(0.0, 1.0, 0.5, f64::NEG_INFINITY, ln_pdf(-1.0)); + test_exact(0.0, 1.0, 0.5, f64::NEG_INFINITY, ln_pdf(1.1)); + test_exact(0.0, 1.0, 0.5, 0.0, ln_pdf(0.25)); + test_exact(0.0, 1.0, 0.5, 2f64.ln(), ln_pdf(0.5)); + test_exact(0.0, 1.0, 0.5, 0.0, ln_pdf(0.75)); + test_exact(-5.0, 8.0, -3.5, f64::NEG_INFINITY, ln_pdf(-5.1)); + test_exact(-5.0, 8.0, -3.5, f64::NEG_INFINITY, ln_pdf(8.1)); + test_exact(-5.0, 8.0, -3.5, 0.1025641025641025641026f64.ln(), ln_pdf(-4.0)); + test_exact(-5.0, 8.0, -3.5, 0.1538461538461538461538f64.ln(), ln_pdf(-3.5)); + test_exact(-5.0, 8.0, -3.5, 0.05351170568561872909699f64.ln(), ln_pdf(4.0)); + test_exact(-5.0, -3.0, -4.0, f64::NEG_INFINITY, ln_pdf(-5.1)); + test_exact(-5.0, -3.0, -4.0, f64::NEG_INFINITY, ln_pdf(-2.9)); + test_exact(-5.0, -3.0, -4.0, 0.5f64.ln(), ln_pdf(-4.5)); + test_exact(-5.0, -3.0, -4.0, 0.0, ln_pdf(-4.0)); + test_exact(-5.0, -3.0, -4.0, 0.5f64.ln(), ln_pdf(-3.5)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Triangular| x.cdf(arg); - test_case(0.0, 1.0, 0.5, 0.125, cdf(0.25)); - test_case(0.0, 1.0, 0.5, 0.5, cdf(0.5)); - test_case(0.0, 1.0, 0.5, 0.875, cdf(0.75)); - test_case(-5.0, 8.0, -3.5, 0.05128205128205128205128, cdf(-4.0)); - test_case(-5.0, 8.0, -3.5, 0.1153846153846153846154, cdf(-3.5)); - test_case(-5.0, 8.0, -3.5, 0.892976588628762541806, cdf(4.0)); - test_case(-5.0, -3.0, -4.0, 0.125, cdf(-4.5)); - test_case(-5.0, -3.0, -4.0, 0.5, cdf(-4.0)); - test_case(-5.0, -3.0, -4.0, 0.875, cdf(-3.5)); + test_exact(0.0, 1.0, 0.5, 0.125, cdf(0.25)); + test_exact(0.0, 1.0, 0.5, 0.5, cdf(0.5)); + test_exact(0.0, 1.0, 0.5, 0.875, cdf(0.75)); + test_exact(-5.0, 8.0, -3.5, 0.05128205128205128205128, cdf(-4.0)); + test_exact(-5.0, 8.0, -3.5, 0.1153846153846153846154, cdf(-3.5)); + test_exact(-5.0, 8.0, -3.5, 0.892976588628762541806, cdf(4.0)); + test_exact(-5.0, -3.0, -4.0, 0.125, cdf(-4.5)); + test_exact(-5.0, -3.0, -4.0, 0.5, cdf(-4.0)); + test_exact(-5.0, -3.0, -4.0, 0.875, cdf(-3.5)); } #[test] fn test_cdf_lower_bound() { let cdf = |arg: f64| move |x: Triangular| x.cdf(arg); - test_case(0.0, 3.0, 1.5, 0.0, cdf(-1.0)); + test_exact(0.0, 3.0, 1.5, 0.0, cdf(-1.0)); } #[test] fn test_cdf_upper_bound() { let cdf = |arg: f64| move |x: Triangular| x.cdf(arg); - test_case(0.0, 3.0, 1.5, 1.0, cdf(5.0)); + test_exact(0.0, 3.0, 1.5, 1.0, cdf(5.0)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Triangular| x.sf(arg); - test_case(0.0, 1.0, 0.5, 0.875, sf(0.25)); - test_case(0.0, 1.0, 0.5, 0.5, sf(0.5)); - test_case(0.0, 1.0, 0.5, 0.125, sf(0.75)); - test_case(-5.0, 8.0, -3.5, 0.9487179487179487, sf(-4.0)); - test_case(-5.0, 8.0, -3.5, 0.8846153846153846, sf(-3.5)); - test_case(-5.0, 8.0, -3.5, 0.10702341137123746, sf(4.0)); - test_case(-5.0, -3.0, -4.0, 0.875, sf(-4.5)); - test_case(-5.0, -3.0, -4.0, 0.5, sf(-4.0)); - test_case(-5.0, -3.0, -4.0, 0.125, sf(-3.5)); + test_exact(0.0, 1.0, 0.5, 0.875, sf(0.25)); + test_exact(0.0, 1.0, 0.5, 0.5, sf(0.5)); + test_exact(0.0, 1.0, 0.5, 0.125, sf(0.75)); + test_exact(-5.0, 8.0, -3.5, 0.9487179487179487, sf(-4.0)); + test_exact(-5.0, 8.0, -3.5, 0.8846153846153846, sf(-3.5)); + test_exact(-5.0, 8.0, -3.5, 0.10702341137123746, sf(4.0)); + test_exact(-5.0, -3.0, -4.0, 0.875, sf(-4.5)); + test_exact(-5.0, -3.0, -4.0, 0.5, sf(-4.0)); + test_exact(-5.0, -3.0, -4.0, 0.125, sf(-3.5)); } #[test] fn test_sf_lower_bound() { let sf = |arg: f64| move |x: Triangular| x.sf(arg); - test_case(0.0, 3.0, 1.5, 1.0, sf(-1.0)); + test_exact(0.0, 3.0, 1.5, 1.0, sf(-1.0)); } #[test] fn test_sf_upper_bound() { let sf = |arg: f64| move |x: Triangular| x.sf(arg); - test_case(0.0, 3.0, 1.5, 0.0, sf(5.0)); + test_exact(0.0, 3.0, 1.5, 0.0, sf(5.0)); } #[test] fn test_inverse_cdf() { let func = |arg: f64| move |x: Triangular| x.inverse_cdf(x.cdf(arg)); - test_almost(0.0, 1.0, 0.5, 0.25, 1e-15, func(0.25)); - test_almost(0.0, 1.0, 0.5, 0.5, 1e-15, func(0.5)); - test_almost(0.0, 1.0, 0.5, 0.75, 1e-15, func(0.75)); - test_almost(-5.0, 8.0, -3.5, -4.0, 1e-15, func(-4.0)); - test_almost(-5.0, 8.0, -3.5, -3.5, 1e-15, func(-3.5)); - test_almost(-5.0, 8.0, -3.5, 4.0, 1e-15, func(4.0)); - test_almost(-5.0, -3.0, -4.0, -4.5, 1e-15, func(-4.5)); - test_almost(-5.0, -3.0, -4.0, -4.0, 1e-15, func(-4.0)); - test_almost(-5.0, -3.0, -4.0, -3.5, 1e-15, func(-3.5)); + test_absolute(0.0, 1.0, 0.5, 0.25, 1e-15, func(0.25)); + test_absolute(0.0, 1.0, 0.5, 0.5, 1e-15, func(0.5)); + test_absolute(0.0, 1.0, 0.5, 0.75, 1e-15, func(0.75)); + test_absolute(-5.0, 8.0, -3.5, -4.0, 1e-15, func(-4.0)); + test_absolute(-5.0, 8.0, -3.5, -3.5, 1e-15, func(-3.5)); + test_absolute(-5.0, 8.0, -3.5, 4.0, 1e-15, func(4.0)); + test_absolute(-5.0, -3.0, -4.0, -4.5, 1e-15, func(-4.5)); + test_absolute(-5.0, -3.0, -4.0, -4.0, 1e-15, func(-4.0)); + test_absolute(-5.0, -3.0, -4.0, -3.5, 1e-15, func(-3.5)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(-5.0, 5.0, 0.0), -5.0, 5.0); - test::check_continuous_distribution(&try_create(-15.0, -2.0, -3.0), -15.0, -2.0); + test::check_continuous_distribution(&create_ok(-5.0, 5.0, 0.0), -5.0, 5.0); + test::check_continuous_distribution(&create_ok(-15.0, -2.0, -3.0), -15.0, -2.0); } } From e7ae56e29393c44826de198c1b608b6ee1745c79 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 20:05:58 +0200 Subject: [PATCH 161/185] Use testing_boiler! for Uniform --- src/distribution/uniform.rs | 215 +++++++++++++++--------------------- 1 file changed, 89 insertions(+), 126 deletions(-) diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index 04578a58..4186df71 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -284,216 +284,179 @@ impl Continuous for Uniform { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Uniform}; use crate::distribution::internal::*; + use crate::statistics::*; + use crate::testing_boiler; - fn try_create(min: f64, max: f64) -> Uniform { - let n = Uniform::new(min, max); - assert!(n.is_ok(), "failed create over interval [{}, {}]", min, max); - n.unwrap() - } - - fn create_case(min: f64, max: f64) { - let n = try_create(min, max); - assert_eq!(n.min(), min); - assert_eq!(n.max(), max); - } - - fn bad_create_case(min: f64, max: f64) { - let n = Uniform::new(min, max); - assert!(n.is_err()); - } - - fn get_value(min: f64, max: f64, eval: F) -> f64 - where F: Fn(Uniform) -> f64 - { - let n = try_create(min, max); - eval(n) - } - - fn test_case(min: f64, max: f64, expected: f64, eval: F) - where F: Fn(Uniform) -> f64 - { - - let x = get_value(min, max, eval); - assert_eq!(expected, x); - } - - fn test_almost(min: f64, max: f64, expected: f64, acc: f64, eval: F) - where F: Fn(Uniform) -> f64 - { - - let x = get_value(min, max, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(min: f64, max: f64; Uniform); #[test] fn test_create() { - create_case(0.0, 0.1); - create_case(0.0, 1.0); - create_case(-5.0, 11.0); - create_case(-5.0, 100.0); + create_ok(0.0, 0.1); + create_ok(0.0, 1.0); + create_ok(-5.0, 11.0); + create_ok(-5.0, 100.0); } #[test] fn test_bad_create() { - bad_create_case(0.0, 0.0); - bad_create_case(f64::NAN, 1.0); - bad_create_case(1.0, f64::NAN); - bad_create_case(f64::NAN, f64::NAN); - bad_create_case(0.0, f64::INFINITY); - bad_create_case(1.0, 0.0); + create_err(0.0, 0.0); + create_err(f64::NAN, 1.0); + create_err(1.0, f64::NAN); + create_err(f64::NAN, f64::NAN); + create_err(0.0, f64::INFINITY); + create_err(1.0, 0.0); } #[test] fn test_variance() { let variance = |x: Uniform| x.variance().unwrap(); - test_case(-0.0, 2.0, 1.0 / 3.0, variance); - test_case(0.0, 2.0, 1.0 / 3.0, variance); - test_almost(0.1, 4.0, 1.2675, 1e-15, variance); - test_case(10.0, 11.0, 1.0 / 12.0, variance); + test_exact(-0.0, 2.0, 1.0 / 3.0, variance); + test_exact(0.0, 2.0, 1.0 / 3.0, variance); + test_absolute(0.1, 4.0, 1.2675, 1e-15, variance); + test_exact(10.0, 11.0, 1.0 / 12.0, variance); } #[test] fn test_entropy() { let entropy = |x: Uniform| x.entropy().unwrap(); - test_case(-0.0, 2.0, 0.6931471805599453094172, entropy); - test_case(0.0, 2.0, 0.6931471805599453094172, entropy); - test_almost(0.1, 4.0, 1.360976553135600743431, 1e-15, entropy); - test_case(1.0, 10.0, 2.19722457733621938279, entropy); - test_case(10.0, 11.0, 0.0, entropy); + test_exact(-0.0, 2.0, 0.6931471805599453094172, entropy); + test_exact(0.0, 2.0, 0.6931471805599453094172, entropy); + test_absolute(0.1, 4.0, 1.360976553135600743431, 1e-15, entropy); + test_exact(1.0, 10.0, 2.19722457733621938279, entropy); + test_exact(10.0, 11.0, 0.0, entropy); } #[test] fn test_skewness() { let skewness = |x: Uniform| x.skewness().unwrap(); - test_case(-0.0, 2.0, 0.0, skewness); - test_case(0.0, 2.0, 0.0, skewness); - test_case(0.1, 4.0, 0.0, skewness); - test_case(1.0, 10.0, 0.0, skewness); - test_case(10.0, 11.0, 0.0, skewness); + test_exact(-0.0, 2.0, 0.0, skewness); + test_exact(0.0, 2.0, 0.0, skewness); + test_exact(0.1, 4.0, 0.0, skewness); + test_exact(1.0, 10.0, 0.0, skewness); + test_exact(10.0, 11.0, 0.0, skewness); } #[test] fn test_mode() { let mode = |x: Uniform| x.mode().unwrap(); - test_case(-0.0, 2.0, 1.0, mode); - test_case(0.0, 2.0, 1.0, mode); - test_case(0.1, 4.0, 2.05, mode); - test_case(1.0, 10.0, 5.5, mode); - test_case(10.0, 11.0, 10.5, mode); + test_exact(-0.0, 2.0, 1.0, mode); + test_exact(0.0, 2.0, 1.0, mode); + test_exact(0.1, 4.0, 2.05, mode); + test_exact(1.0, 10.0, 5.5, mode); + test_exact(10.0, 11.0, 10.5, mode); } #[test] fn test_median() { let median = |x: Uniform| x.median(); - test_case(-0.0, 2.0, 1.0, median); - test_case(0.0, 2.0, 1.0, median); - test_case(0.1, 4.0, 2.05, median); - test_case(1.0, 10.0, 5.5, median); - test_case(10.0, 11.0, 10.5, median); + test_exact(-0.0, 2.0, 1.0, median); + test_exact(0.0, 2.0, 1.0, median); + test_exact(0.1, 4.0, 2.05, median); + test_exact(1.0, 10.0, 5.5, median); + test_exact(10.0, 11.0, 10.5, median); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Uniform| x.pdf(arg); - test_case(0.0, 0.1, 0.0, pdf(-5.0)); - test_case(0.0, 0.1, 10.0, pdf(0.05)); - test_case(0.0, 0.1, 0.0, pdf(5.0)); - test_case(0.0, 1.0, 0.0, pdf(-5.0)); - test_case(0.0, 1.0, 1.0, pdf(0.5)); - test_case(0.0, 0.1, 0.0, pdf(5.0)); - test_case(0.0, 10.0, 0.0, pdf(-5.0)); - test_case(0.0, 10.0, 0.1, pdf(1.0)); - test_case(0.0, 10.0, 0.1, pdf(5.0)); - test_case(0.0, 10.0, 0.0, pdf(11.0)); - test_case(-5.0, 100.0, 0.0, pdf(-10.0)); - test_case(-5.0, 100.0, 0.009523809523809523809524, pdf(-5.0)); - test_case(-5.0, 100.0, 0.009523809523809523809524, pdf(0.0)); - test_case(-5.0, 100.0, 0.0, pdf(101.0)); + test_exact(0.0, 0.1, 0.0, pdf(-5.0)); + test_exact(0.0, 0.1, 10.0, pdf(0.05)); + test_exact(0.0, 0.1, 0.0, pdf(5.0)); + test_exact(0.0, 1.0, 0.0, pdf(-5.0)); + test_exact(0.0, 1.0, 1.0, pdf(0.5)); + test_exact(0.0, 0.1, 0.0, pdf(5.0)); + test_exact(0.0, 10.0, 0.0, pdf(-5.0)); + test_exact(0.0, 10.0, 0.1, pdf(1.0)); + test_exact(0.0, 10.0, 0.1, pdf(5.0)); + test_exact(0.0, 10.0, 0.0, pdf(11.0)); + test_exact(-5.0, 100.0, 0.0, pdf(-10.0)); + test_exact(-5.0, 100.0, 0.009523809523809523809524, pdf(-5.0)); + test_exact(-5.0, 100.0, 0.009523809523809523809524, pdf(0.0)); + test_exact(-5.0, 100.0, 0.0, pdf(101.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Uniform| x.ln_pdf(arg); - test_case(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(-5.0)); - test_almost(0.0, 0.1, 2.302585092994045684018, 1e-15, ln_pdf(0.05)); - test_case(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(5.0)); - test_case(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(-5.0)); - test_case(0.0, 1.0, 0.0, ln_pdf(0.5)); - test_case(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(5.0)); - test_case(0.0, 10.0, f64::NEG_INFINITY, ln_pdf(-5.0)); - test_case(0.0, 10.0, -2.302585092994045684018, ln_pdf(1.0)); - test_case(0.0, 10.0, -2.302585092994045684018, ln_pdf(5.0)); - test_case(0.0, 10.0, f64::NEG_INFINITY, ln_pdf(11.0)); - test_case(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(-10.0)); - test_case(-5.0, 100.0, -4.653960350157523371101, ln_pdf(-5.0)); - test_case(-5.0, 100.0, -4.653960350157523371101, ln_pdf(0.0)); - test_case(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(101.0)); + test_exact(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(-5.0)); + test_absolute(0.0, 0.1, 2.302585092994045684018, 1e-15, ln_pdf(0.05)); + test_exact(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(5.0)); + test_exact(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(-5.0)); + test_exact(0.0, 1.0, 0.0, ln_pdf(0.5)); + test_exact(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(5.0)); + test_exact(0.0, 10.0, f64::NEG_INFINITY, ln_pdf(-5.0)); + test_exact(0.0, 10.0, -2.302585092994045684018, ln_pdf(1.0)); + test_exact(0.0, 10.0, -2.302585092994045684018, ln_pdf(5.0)); + test_exact(0.0, 10.0, f64::NEG_INFINITY, ln_pdf(11.0)); + test_exact(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(-10.0)); + test_exact(-5.0, 100.0, -4.653960350157523371101, ln_pdf(-5.0)); + test_exact(-5.0, 100.0, -4.653960350157523371101, ln_pdf(0.0)); + test_exact(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(101.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Uniform| x.cdf(arg); - test_case(0.0, 0.1, 0.5, cdf(0.05)); - test_case(0.0, 1.0, 0.5, cdf(0.5)); - test_case(0.0, 10.0, 0.1, cdf(1.0)); - test_case(0.0, 10.0, 0.5, cdf(5.0)); - test_case(-5.0, 100.0, 0.0, cdf(-5.0)); - test_case(-5.0, 100.0, 0.04761904761904761904762, cdf(0.0)); + test_exact(0.0, 0.1, 0.5, cdf(0.05)); + test_exact(0.0, 1.0, 0.5, cdf(0.5)); + test_exact(0.0, 10.0, 0.1, cdf(1.0)); + test_exact(0.0, 10.0, 0.5, cdf(5.0)); + test_exact(-5.0, 100.0, 0.0, cdf(-5.0)); + test_exact(-5.0, 100.0, 0.04761904761904761904762, cdf(0.0)); } #[test] fn test_inverse_cdf() { let inverse_cdf = |arg: f64| move |x: Uniform| x.inverse_cdf(arg); - test_case(0.0, 0.1, 0.05, inverse_cdf(0.5)); - test_case(0.0, 10.0, 5.0, inverse_cdf(0.5)); - test_case(1.0, 10.0, 1.0, inverse_cdf(0.0)); - test_case(1.0, 10.0, 4.0, inverse_cdf(1.0 / 3.0)); - test_case(1.0, 10.0, 10.0, inverse_cdf(1.0)); + test_exact(0.0, 0.1, 0.05, inverse_cdf(0.5)); + test_exact(0.0, 10.0, 5.0, inverse_cdf(0.5)); + test_exact(1.0, 10.0, 1.0, inverse_cdf(0.0)); + test_exact(1.0, 10.0, 4.0, inverse_cdf(1.0 / 3.0)); + test_exact(1.0, 10.0, 10.0, inverse_cdf(1.0)); } #[test] fn test_cdf_lower_bound() { let cdf = |arg: f64| move |x: Uniform| x.cdf(arg); - test_case(0.0, 3.0, 0.0, cdf(-1.0)); + test_exact(0.0, 3.0, 0.0, cdf(-1.0)); } #[test] fn test_cdf_upper_bound() { let cdf = |arg: f64| move |x: Uniform| x.cdf(arg); - test_case(0.0, 3.0, 1.0, cdf(5.0)); + test_exact(0.0, 3.0, 1.0, cdf(5.0)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Uniform| x.sf(arg); - test_case(0.0, 0.1, 0.5, sf(0.05)); - test_case(0.0, 1.0, 0.5, sf(0.5)); - test_case(0.0, 10.0, 0.9, sf(1.0)); - test_case(0.0, 10.0, 0.5, sf(5.0)); - test_case(-5.0, 100.0, 1.0, sf(-5.0)); - test_case(-5.0, 100.0, 0.9523809523809523, sf(0.0)); + test_exact(0.0, 0.1, 0.5, sf(0.05)); + test_exact(0.0, 1.0, 0.5, sf(0.5)); + test_exact(0.0, 10.0, 0.9, sf(1.0)); + test_exact(0.0, 10.0, 0.5, sf(5.0)); + test_exact(-5.0, 100.0, 1.0, sf(-5.0)); + test_exact(-5.0, 100.0, 0.9523809523809523, sf(0.0)); } #[test] fn test_sf_lower_bound() { let sf = |arg: f64| move |x: Uniform| x.sf(arg); - test_case(0.0, 3.0, 1.0, sf(-1.0)); + test_exact(0.0, 3.0, 1.0, sf(-1.0)); } #[test] fn test_sf_upper_bound() { let sf = |arg: f64| move |x: Uniform| x.sf(arg); - test_case(0.0, 3.0, 0.0, sf(5.0)); + test_exact(0.0, 3.0, 0.0, sf(5.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(0.0, 10.0), 0.0, 10.0); - test::check_continuous_distribution(&try_create(-2.0, 15.0), -2.0, 15.0); + test::check_continuous_distribution(&create_ok(0.0, 10.0), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(-2.0, 15.0), -2.0, 15.0); } #[test] @@ -511,7 +474,7 @@ mod tests { let min = -0.5; let max = 0.5; let num_trials = 10_000; - let n = try_create(min, max); + let n = create_ok(min, max); assert!((0..num_trials) .map(|_| n.sample::(&mut r)) From 8c49218deb50ebfcf0226f753c40533857031bb2 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 20:06:55 +0200 Subject: [PATCH 162/185] Use testing_boiler! for Weibull --- src/distribution/weibull.rs | 237 +++++++++++++++--------------------- 1 file changed, 101 insertions(+), 136 deletions(-) diff --git a/src/distribution/weibull.rs b/src/distribution/weibull.rs index 1382998c..2d3a8a87 100644 --- a/src/distribution/weibull.rs +++ b/src/distribution/weibull.rs @@ -350,216 +350,181 @@ impl Continuous for Weibull { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::statistics::*; use crate::distribution::{ContinuousCDF, Continuous, Weibull}; use crate::distribution::internal::*; + use crate::statistics::*; + use crate::testing_boiler; - fn try_create(shape: f64, scale: f64) -> Weibull { - let n = Weibull::new(shape, scale); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(shape: f64, scale: f64) { - let n = try_create(shape, scale); - assert_eq!(shape, n.shape()); - assert_eq!(scale, n.scale()); - } - - fn bad_create_case(shape: f64, scale: f64) { - let n = Weibull::new(shape, scale); - assert!(n.is_err()); - } - - fn get_value(shape: f64, scale: f64, eval: F) -> f64 - where F: Fn(Weibull) -> f64 - { - let n = try_create(shape, scale); - eval(n) - } - - fn test_case(shape: f64, scale: f64, expected: f64, eval: F) - where F: Fn(Weibull) -> f64 - { - let x = get_value(shape, scale, eval); - assert_eq!(expected, x); - } - - fn test_almost(shape: f64, scale: f64, expected: f64, acc: f64, eval: F) - where F: Fn(Weibull) -> f64 - { - let x = get_value(shape, scale, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(shape: f64, scale: f64; Weibull); #[test] fn test_create() { - create_case(1.0, 0.1); - create_case(10.0, 1.0); - create_case(11.0, 10.0); - create_case(12.0, f64::INFINITY); + create_ok(1.0, 0.1); + create_ok(10.0, 1.0); + create_ok(11.0, 10.0); + create_ok(12.0, f64::INFINITY); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN, 1.0); - bad_create_case(1.0, f64::NAN); - bad_create_case(f64::NAN, f64::NAN); - bad_create_case(1.0, -1.0); - bad_create_case(-1.0, 1.0); - bad_create_case(-1.0, -1.0); - bad_create_case(0.0, 0.0); - bad_create_case(0.0, 1.0); - bad_create_case(1.0, 0.0); + create_err(f64::NAN, 1.0); + create_err(1.0, f64::NAN); + create_err(f64::NAN, f64::NAN); + create_err(1.0, -1.0); + create_err(-1.0, 1.0); + create_err(-1.0, -1.0); + create_err(0.0, 0.0); + create_err(0.0, 1.0); + create_err(1.0, 0.0); } #[test] fn test_mean() { let mean = |x: Weibull| x.mean().unwrap(); - test_case(1.0, 0.1, 0.1, mean); - test_case(1.0, 1.0, 1.0, mean); - test_almost(10.0, 10.0, 9.5135076986687318362924871772654021925505786260884, 1e-14, mean); - test_almost(10.0, 1.0, 0.95135076986687318362924871772654021925505786260884, 1e-15, mean); + test_exact(1.0, 0.1, 0.1, mean); + test_exact(1.0, 1.0, 1.0, mean); + test_absolute(10.0, 10.0, 9.5135076986687318362924871772654021925505786260884, 1e-14, mean); + test_absolute(10.0, 1.0, 0.95135076986687318362924871772654021925505786260884, 1e-15, mean); } #[test] fn test_variance() { let variance = |x: Weibull| x.variance().unwrap(); - test_almost(1.0, 0.1, 0.01, 1e-16, variance); - test_almost(1.0, 1.0, 1.0, 1e-14, variance); - test_almost(10.0, 10.0, 1.3100455073468309147154581687505295026863354547057, 1e-12, variance); - test_almost(10.0, 1.0, 0.013100455073468309147154581687505295026863354547057, 1e-14, variance); + test_absolute(1.0, 0.1, 0.01, 1e-16, variance); + test_absolute(1.0, 1.0, 1.0, 1e-14, variance); + test_absolute(10.0, 10.0, 1.3100455073468309147154581687505295026863354547057, 1e-12, variance); + test_absolute(10.0, 1.0, 0.013100455073468309147154581687505295026863354547057, 1e-14, variance); } #[test] fn test_entropy() { let entropy = |x: Weibull| x.entropy().unwrap(); - test_almost(1.0, 0.1, -1.302585092994045684018, 1e-15, entropy); - test_case(1.0, 1.0, 1.0, entropy); - test_case(10.0, 10.0, 1.519494098411379574546, entropy); - test_almost(10.0, 1.0, -0.783090994582666109472, 1e-15, entropy); + test_absolute(1.0, 0.1, -1.302585092994045684018, 1e-15, entropy); + test_exact(1.0, 1.0, 1.0, entropy); + test_exact(10.0, 10.0, 1.519494098411379574546, entropy); + test_absolute(10.0, 1.0, -0.783090994582666109472, 1e-15, entropy); } #[test] fn test_skewnewss() { let skewness = |x: Weibull| x.skewness().unwrap(); - test_almost(1.0, 0.1, 2.0, 1e-13, skewness); - test_almost(1.0, 1.0, 2.0, 1e-13, skewness); - test_almost(10.0, 10.0, -0.63763713390314440916597757156663888653981696212127, 1e-11, skewness); - test_almost(10.0, 1.0, -0.63763713390314440916597757156663888653981696212127, 1e-11, skewness); + test_absolute(1.0, 0.1, 2.0, 1e-13, skewness); + test_absolute(1.0, 1.0, 2.0, 1e-13, skewness); + test_absolute(10.0, 10.0, -0.63763713390314440916597757156663888653981696212127, 1e-11, skewness); + test_absolute(10.0, 1.0, -0.63763713390314440916597757156663888653981696212127, 1e-11, skewness); } #[test] fn test_median() { let median = |x: Weibull| x.median(); - test_case(1.0, 0.1, 0.069314718055994530941723212145817656807550013436026, median); - test_case(1.0, 1.0, 0.69314718055994530941723212145817656807550013436026, median); - test_case(10.0, 10.0, 9.6401223546778973665856033763604752124634905617583, median); - test_case(10.0, 1.0, 0.96401223546778973665856033763604752124634905617583, median); + test_exact(1.0, 0.1, 0.069314718055994530941723212145817656807550013436026, median); + test_exact(1.0, 1.0, 0.69314718055994530941723212145817656807550013436026, median); + test_exact(10.0, 10.0, 9.6401223546778973665856033763604752124634905617583, median); + test_exact(10.0, 1.0, 0.96401223546778973665856033763604752124634905617583, median); } #[test] fn test_mode() { let mode = |x: Weibull| x.mode().unwrap(); - test_case(1.0, 0.1, 0.0, mode); - test_case(1.0, 1.0, 0.0, mode); - test_case(10.0, 10.0, 9.8951925820621439264623017041980483215553841533709, mode); - test_case(10.0, 1.0, 0.98951925820621439264623017041980483215553841533709, mode); + test_exact(1.0, 0.1, 0.0, mode); + test_exact(1.0, 1.0, 0.0, mode); + test_exact(10.0, 10.0, 9.8951925820621439264623017041980483215553841533709, mode); + test_exact(10.0, 1.0, 0.98951925820621439264623017041980483215553841533709, mode); } #[test] fn test_min_max() { let min = |x: Weibull| x.min(); let max = |x: Weibull| x.max(); - test_case(1.0, 1.0, 0.0, min); - test_case(1.0, 1.0, f64::INFINITY, max); + test_exact(1.0, 1.0, 0.0, min); + test_exact(1.0, 1.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Weibull| x.pdf(arg); - test_case(1.0, 0.1, 10.0, pdf(0.0)); - test_case(1.0, 0.1, 0.00045399929762484851535591515560550610237918088866565, pdf(1.0)); - test_case(1.0, 0.1, 3.7200759760208359629596958038631183373588922923768e-43, pdf(10.0)); - test_case(1.0, 1.0, 1.0, pdf(0.0)); - test_case(1.0, 1.0, 0.36787944117144232159552377016146086744581113103177, pdf(1.0)); - test_case(1.0, 1.0, 0.000045399929762484851535591515560550610237918088866565, pdf(10.0)); - test_case(10.0, 10.0, 0.0, pdf(0.0)); - test_almost(10.0, 10.0, 9.9999999990000000000499999999983333333333750000000e-10, 1e-24, pdf(1.0)); - test_case(10.0, 10.0, 0.36787944117144232159552377016146086744581113103177, pdf(10.0)); - test_case(10.0, 1.0, 0.0, pdf(0.0)); - test_case(10.0, 1.0, 3.6787944117144232159552377016146086744581113103177, pdf(1.0)); - test_case(10.0, 1.0, 0.0, pdf(10.0)); + test_exact(1.0, 0.1, 10.0, pdf(0.0)); + test_exact(1.0, 0.1, 0.00045399929762484851535591515560550610237918088866565, pdf(1.0)); + test_exact(1.0, 0.1, 3.7200759760208359629596958038631183373588922923768e-43, pdf(10.0)); + test_exact(1.0, 1.0, 1.0, pdf(0.0)); + test_exact(1.0, 1.0, 0.36787944117144232159552377016146086744581113103177, pdf(1.0)); + test_exact(1.0, 1.0, 0.000045399929762484851535591515560550610237918088866565, pdf(10.0)); + test_exact(10.0, 10.0, 0.0, pdf(0.0)); + test_absolute(10.0, 10.0, 9.9999999990000000000499999999983333333333750000000e-10, 1e-24, pdf(1.0)); + test_exact(10.0, 10.0, 0.36787944117144232159552377016146086744581113103177, pdf(10.0)); + test_exact(10.0, 1.0, 0.0, pdf(0.0)); + test_exact(10.0, 1.0, 3.6787944117144232159552377016146086744581113103177, pdf(1.0)); + test_exact(10.0, 1.0, 0.0, pdf(10.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Weibull| x.ln_pdf(arg); - test_almost(1.0, 0.1, 2.3025850929940456840179914546843642076011014886288, 1e-15, ln_pdf(0.0)); - test_almost(1.0, 0.1, -7.6974149070059543159820085453156357923988985113712, 1e-15, ln_pdf(1.0)); - test_case(1.0, 0.1, -97.697414907005954315982008545315635792398898511371, ln_pdf(10.0)); - test_case(1.0, 1.0, 0.0, ln_pdf(0.0)); - test_case(1.0, 1.0, -1.0, ln_pdf(1.0)); - test_case(1.0, 1.0, -10.0, ln_pdf(10.0)); - test_case(10.0, 10.0, f64::NEG_INFINITY, ln_pdf(0.0)); - test_almost(10.0, 10.0, -20.723265837046411156161923092159277868409913397659, 1e-14, ln_pdf(1.0)); - test_case(10.0, 10.0, -1.0, ln_pdf(10.0)); - test_case(10.0, 1.0, f64::NEG_INFINITY, ln_pdf(0.0)); - test_almost(10.0, 1.0, 1.3025850929940456840179914546843642076011014886288, 1e-15, ln_pdf(1.0)); - test_case(10.0, 1.0, -9.999999976974149070059543159820085453156357923988985113712e9, ln_pdf(10.0)); + test_absolute(1.0, 0.1, 2.3025850929940456840179914546843642076011014886288, 1e-15, ln_pdf(0.0)); + test_absolute(1.0, 0.1, -7.6974149070059543159820085453156357923988985113712, 1e-15, ln_pdf(1.0)); + test_exact(1.0, 0.1, -97.697414907005954315982008545315635792398898511371, ln_pdf(10.0)); + test_exact(1.0, 1.0, 0.0, ln_pdf(0.0)); + test_exact(1.0, 1.0, -1.0, ln_pdf(1.0)); + test_exact(1.0, 1.0, -10.0, ln_pdf(10.0)); + test_exact(10.0, 10.0, f64::NEG_INFINITY, ln_pdf(0.0)); + test_absolute(10.0, 10.0, -20.723265837046411156161923092159277868409913397659, 1e-14, ln_pdf(1.0)); + test_exact(10.0, 10.0, -1.0, ln_pdf(10.0)); + test_exact(10.0, 1.0, f64::NEG_INFINITY, ln_pdf(0.0)); + test_absolute(10.0, 1.0, 1.3025850929940456840179914546843642076011014886288, 1e-15, ln_pdf(1.0)); + test_exact(10.0, 1.0, -9.999999976974149070059543159820085453156357923988985113712e9, ln_pdf(10.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Weibull| x.cdf(arg); - test_case(1.0, 0.1, 0.0, cdf(0.0)); - test_case(1.0, 0.1, 0.99995460007023751514846440848443944938976208191113, cdf(1.0)); - test_case(1.0, 0.1, 0.99999999999999999999999999999999999999999996279924, cdf(10.0)); - test_case(1.0, 1.0, 0.0, cdf(0.0)); - test_case(1.0, 1.0, 0.63212055882855767840447622983853913255418886896823, cdf(1.0)); - test_case(1.0, 1.0, 0.99995460007023751514846440848443944938976208191113, cdf(10.0)); - test_case(10.0, 10.0, 0.0, cdf(0.0)); - test_almost(10.0, 10.0, 9.9999999995000000000166666666662500000000083333333e-11, 1e-25, cdf(1.0)); - test_case(10.0, 10.0, 0.63212055882855767840447622983853913255418886896823, cdf(10.0)); - test_case(10.0, 1.0, 0.0, cdf(0.0)); - test_case(10.0, 1.0, 0.63212055882855767840447622983853913255418886896823, cdf(1.0)); - test_case(10.0, 1.0, 1.0, cdf(10.0)); + test_exact(1.0, 0.1, 0.0, cdf(0.0)); + test_exact(1.0, 0.1, 0.99995460007023751514846440848443944938976208191113, cdf(1.0)); + test_exact(1.0, 0.1, 0.99999999999999999999999999999999999999999996279924, cdf(10.0)); + test_exact(1.0, 1.0, 0.0, cdf(0.0)); + test_exact(1.0, 1.0, 0.63212055882855767840447622983853913255418886896823, cdf(1.0)); + test_exact(1.0, 1.0, 0.99995460007023751514846440848443944938976208191113, cdf(10.0)); + test_exact(10.0, 10.0, 0.0, cdf(0.0)); + test_absolute(10.0, 10.0, 9.9999999995000000000166666666662500000000083333333e-11, 1e-25, cdf(1.0)); + test_exact(10.0, 10.0, 0.63212055882855767840447622983853913255418886896823, cdf(10.0)); + test_exact(10.0, 1.0, 0.0, cdf(0.0)); + test_exact(10.0, 1.0, 0.63212055882855767840447622983853913255418886896823, cdf(1.0)); + test_exact(10.0, 1.0, 1.0, cdf(10.0)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Weibull| x.sf(arg); - test_case(1.0, 0.1, 1.0, sf(0.0)); - test_case(1.0, 0.1, 4.5399929762484854e-5, sf(1.0)); - test_case(1.0, 0.1, 3.720075976020836e-44, sf(10.0)); - test_case(1.0, 1.0, 1.0, sf(0.0)); - test_case(1.0, 1.0, 0.36787944117144233, sf(1.0)); - test_case(1.0, 1.0, 4.5399929762484854e-5, sf(10.0)); - test_case(10.0, 10.0, 1.0, sf(0.0)); - test_almost(10.0, 10.0, 0.9999999999, 1e-25, sf(1.0)); - test_case(10.0, 10.0, 0.36787944117144233, sf(10.0)); - test_case(10.0, 1.0, 1.0, sf(0.0)); - test_case(10.0, 1.0, 0.36787944117144233, sf(1.0)); - test_case(10.0, 1.0, 0.0, sf(10.0)); + test_exact(1.0, 0.1, 1.0, sf(0.0)); + test_exact(1.0, 0.1, 4.5399929762484854e-5, sf(1.0)); + test_exact(1.0, 0.1, 3.720075976020836e-44, sf(10.0)); + test_exact(1.0, 1.0, 1.0, sf(0.0)); + test_exact(1.0, 1.0, 0.36787944117144233, sf(1.0)); + test_exact(1.0, 1.0, 4.5399929762484854e-5, sf(10.0)); + test_exact(10.0, 10.0, 1.0, sf(0.0)); + test_absolute(10.0, 10.0, 0.9999999999, 1e-25, sf(1.0)); + test_exact(10.0, 10.0, 0.36787944117144233, sf(10.0)); + test_exact(10.0, 1.0, 1.0, sf(0.0)); + test_exact(10.0, 1.0, 0.36787944117144233, sf(1.0)); + test_exact(10.0, 1.0, 0.0, sf(10.0)); } #[test] fn test_inverse_cdf() { let func = |arg: f64| move |x: Weibull| x.inverse_cdf(x.cdf(arg)); - test_case(1.0, 0.1, 0.0, func(0.0)); - test_almost(1.0, 0.1, 1.0, 1e-13, func(1.0)); - test_case(1.0, 1.0, 0.0, func(0.0)); - test_case(1.0, 1.0, 1.0, func(1.0)); - test_almost(1.0, 1.0, 10.0, 1e-10, func(10.0)); - test_case(10.0, 10.0, 0.0, func(0.0)); - test_almost(10.0, 10.0, 1.0, 1e-5, func(1.0)); - test_almost(10.0, 10.0, 10.0, 1e-10, func(10.0)); - test_case(10.0, 1.0, 0.0, func(0.0)); - test_case(10.0, 1.0, 1.0, func(1.0)); + test_exact(1.0, 0.1, 0.0, func(0.0)); + test_absolute(1.0, 0.1, 1.0, 1e-13, func(1.0)); + test_exact(1.0, 1.0, 0.0, func(0.0)); + test_exact(1.0, 1.0, 1.0, func(1.0)); + test_absolute(1.0, 1.0, 10.0, 1e-10, func(10.0)); + test_exact(10.0, 10.0, 0.0, func(0.0)); + test_absolute(10.0, 10.0, 1.0, 1e-5, func(1.0)); + test_absolute(10.0, 10.0, 10.0, 1e-10, func(10.0)); + test_exact(10.0, 1.0, 0.0, func(0.0)); + test_exact(10.0, 1.0, 1.0, func(1.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(1.0, 0.2), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(1.0, 0.2), 0.0, 10.0); } } From 46f3e68b2ec04b98d438851cf4804289f3eb11e5 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 20:53:24 +0200 Subject: [PATCH 163/185] Rewrite one `#[should_panic]` test for Pareto --- src/distribution/pareto.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/distribution/pareto.rs b/src/distribution/pareto.rs index 7f91e0a4..5fb62044 100644 --- a/src/distribution/pareto.rs +++ b/src/distribution/pareto.rs @@ -390,10 +390,8 @@ mod tests { } #[test] - #[should_panic] fn test_variance_degen() { - let variance = |x: Pareto| x.variance().unwrap(); - test_exact(1.0, 1.0, f64::INFINITY, variance); // shape <= 2.0 + test_none(1.0, 1.0, |dist| dist.variance()); // shape <= 2.0 } #[test] From c428c751873cee90f2dffa421e97d2ab954ad9fa Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 21:12:30 +0200 Subject: [PATCH 164/185] Add unit tests for testing_boiler --- src/distribution/internal.rs | 90 ++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index cf69f1cd..15f63ea3 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -276,6 +276,96 @@ pub mod test { }; } + pub mod boiler_tests { + use super::*; + use crate::distribution::Binomial; + use crate::statistics::*; + + testing_boiler!(p: f64, n: u64; Binomial); + + #[test] + fn create_ok_success() { + let b = create_ok(0.8, 1200); + assert_eq!(b.p(), 0.8); + assert_eq!(b.n(), 1200); + } + + #[test] + #[should_panic] + fn create_err_failure() { + create_err(0.8, 1200); + } + + #[test] + fn create_err_success() { + let err = create_err(-0.5, 1000); + assert_eq!(err, StatsError::BadParams); + } + + #[test] + #[should_panic] + fn create_ok_failure() { + create_ok(-0.5, 1000); + } + + #[test] + fn test_exact_success() { + test_exact(0.0, 4, 0.0, |dist| dist.mean().unwrap()); + } + + #[test] + #[should_panic] + fn test_exact_failure() { + test_exact(0.3, 3, 0.9, |dist| dist.mean().unwrap()); + } + + #[test] + fn test_relative_success() { + test_relative(0.3, 3, 0.9, |dist| dist.mean().unwrap()); + } + + #[test] + #[should_panic] + fn test_relative_failure() { + test_relative(0.3, 3, 0.8, |dist| dist.mean().unwrap()); + } + + #[test] + fn test_absolute_success() { + test_absolute(0.3, 3, 0.9, 1e-15, |dist| dist.mean().unwrap()); + } + + #[test] + #[should_panic] + fn test_absolute_failure() { + test_absolute(0.3, 3, 0.9, 1e-17, |dist| dist.mean().unwrap()); + } + + #[test] + fn test_is_nan_success() { + // Not sure that any Binomial API can return a NaN, so we force the issue + test_is_nan(0.8, 1200, |_| f64::NAN); + } + + #[test] + #[should_panic] + fn test_is_nan_failure() { + test_is_nan(0.8, 1200, |dist| dist.mean().unwrap()); + } + + #[test] + fn test_is_none_success() { + // Same as test_is_nan_success, force returning `None` here + test_none(0.8, 1200, |_| Option::::None); + } + + #[test] + #[should_panic] + fn test_is_none_failure() { + test_none(0.8, 1200, |dist| dist.mean()); + } + } + /// cdf should be the integral of the pdf fn check_integrate_pdf_is_cdf + Continuous>( dist: &D, From a324820bd4cf7fb575dcda2c79e9ef8c06f4409f Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Thu, 5 Sep 2024 11:08:15 +0200 Subject: [PATCH 165/185] style: remove rustfmt.toml Options set to their default value: - wrap_comments - use_small_heuristics - fn_single_line Unnecessary options: - edition = "2021" (read from Cargo.toml) - comment_width (wrap_comments = false) Actual config changes: - blank_lines_upper_bound: 2 -> 1 - reorder_impl_items: true -> false --- rustfmt.toml | 29 ----------------------------- 1 file changed, 29 deletions(-) delete mode 100644 rustfmt.toml diff --git a/rustfmt.toml b/rustfmt.toml deleted file mode 100644 index 2f399e9c..00000000 --- a/rustfmt.toml +++ /dev/null @@ -1,29 +0,0 @@ -# This rustfmt file is added for configuration, but in practice much of our -# code is hand-formatted, frequently with more readable results. -# taken from rust-random/rand - -# Comments: -normalize_comments = true -wrap_comments = false -comment_width = 90 # small excess is okay but prefer 80 - -# Arguments: -use_small_heuristics = "Default" -# TODO: single line functions only where short, please? -# https://github.com/rust-lang/rustfmt/issues/3358 -fn_single_line = false - -# enum_discrim_align_threshold = 20 -# struct_field_align_threshold = 20 - -# Compatibility: -edition = "2021" - -# Misc: -blank_lines_upper_bound = 2 -reorder_impl_items = true -# report_todo = "Unnumbered" -# report_fixme = "Unnumbered" - -# Ignored files: -ignore = [] From 3e5393e7bba68c3a4972a3dbba15cf5e28582b71 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Thu, 5 Sep 2024 11:09:33 +0200 Subject: [PATCH 166/185] style: Run `cargo fmt` after config change --- src/distribution/categorical.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index 5d26e7b5..71e09560 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -12,7 +12,6 @@ use std::f64; /// # Examples /// /// ``` -/// /// use statrs::distribution::{Categorical, Discrete}; /// use statrs::statistics::Distribution; /// use statrs::prec; From d9fa83b1cd2c53031c1285b6782ba060d688b0f0 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Thu, 5 Sep 2024 11:11:55 +0200 Subject: [PATCH 167/185] ci: use Rust stable for fmt job --- .github/workflows/test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2b39f41e..1b0a2eca 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -30,8 +30,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - name: Install Rust nightly with rustfmt - uses: dtolnay/rust-toolchain@nightly + - name: Install Rust stable with rustfmt + uses: dtolnay/rust-toolchain@stable with: components: rustfmt From 7285bdedf5e0b015cb72854fdb87b65214c04394 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Thu, 5 Sep 2024 11:24:23 +0200 Subject: [PATCH 168/185] docs: No more `nightly` for rustfmt --- README.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 740e5dfe..8d4ff9db 100644 --- a/README.md +++ b/README.md @@ -75,14 +75,13 @@ git checkout -b master Write your code and docs, then ensure it is formatted: -The below sample modify in-place, use `--check` flag to view diff without making file changes. -Not using `fmt` from +nightly may result in some warnings and different formatting. -Our CI will `fmt`, but less chores in commit history are appreciated. - ``` -cargo +nightly fmt +cargo fmt ``` +Add `--check` to view the diff without making file changes. +Our CI will `fmt`, but less chores in commit history are appreciated. + After commiting your code: ``` From f7689fc824aca5d84165797cff0c32d325a3cbd9 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Thu, 5 Sep 2024 18:15:03 +0200 Subject: [PATCH 169/185] fix: missing/moved type import after auto-merge --- src/distribution/internal.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index 15f63ea3..ae853637 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -277,9 +277,9 @@ pub mod test { } pub mod boiler_tests { - use super::*; use crate::distribution::Binomial; use crate::statistics::*; + use crate::StatsError; testing_boiler!(p: f64, n: u64; Binomial); From d0a5b045416828829a655120f5d2111d56c99136 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Wed, 4 Sep 2024 23:56:47 +0200 Subject: [PATCH 170/185] chore: Fix several rustdoc warnings --- src/distribution/gamma.rs | 10 +++------- src/distribution/hypergeometric.rs | 16 +++++----------- src/function/exponential.rs | 26 +++++++++----------------- src/stats_tests/fisher.rs | 4 ++-- 4 files changed, 19 insertions(+), 37 deletions(-) diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index e5e1a282..e986037a 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -365,15 +365,11 @@ impl Continuous for Gamma { } /// Samples from a gamma distribution with a shape of `shape` and a /// rate of `rate` using `rng` as the source of randomness. Implementation from: -///
-///
-/// "A Simple Method for Generating Gamma Variables" - Marsaglia & Tsang -///
-///
+/// +/// _"A Simple Method for Generating Gamma Variables"_ - Marsaglia & Tsang +/// /// ACM Transactions on Mathematical Software, Vol. 26, No. 3, September 2000, /// Pages 363-372 -///
-///
pub fn sample_unchecked(rng: &mut R, shape: f64, rate: f64) -> f64 { let mut a = shape; let mut afix = 1.0; diff --git a/src/distribution/hypergeometric.rs b/src/distribution/hypergeometric.rs index 800e03fe..43bc30ca 100644 --- a/src/distribution/hypergeometric.rs +++ b/src/distribution/hypergeometric.rs @@ -9,11 +9,7 @@ use std::f64; /// Implements the /// [Hypergeometric](http://en.wikipedia.org/wiki/Hypergeometric_distribution) /// distribution -/// -/// # Examples -/// -/// ``` -/// ``` +// TODO: Add examples #[derive(Copy, Clone, PartialEq, Eq, Debug)] pub struct Hypergeometric { population: u64, @@ -155,9 +151,8 @@ impl DiscreteCDF for Hypergeometric { /// ``` /// /// where `N` is population, `K` is successes, `n` is draws, - /// and `p_F_q` is the [generalized hypergeometric - /// function](https://en.wikipedia. - /// org/wiki/Generalized_hypergeometric_function) + /// and `p_F_q` is the + /// [generalized hypergeometric function](https://en.wikipedia.org/wiki/Generalized_hypergeometric_function) /// /// Calculated as a discrete integral over the probability mass /// function evaluated from 0..x+1 @@ -189,9 +184,8 @@ impl DiscreteCDF for Hypergeometric { /// ``` /// /// where `N` is population, `K` is successes, `n` is draws, - /// and `p_F_q` is the [generalized hypergeometric - /// function](https://en.wikipedia. - /// org/wiki/Generalized_hypergeometric_function) + /// and `p_F_q` is the + /// [generalized hypergeometric function](https://en.wikipedia.org/wiki/Generalized_hypergeometric_function) /// /// Calculated as a discrete integral over the probability mass /// function evaluated from (x+1)..max diff --git a/src/function/exponential.rs b/src/function/exponential.rs index 1124c55c..55280d7c 100644 --- a/src/function/exponential.rs +++ b/src/function/exponential.rs @@ -14,26 +14,18 @@ use crate::{consts, Result, StatsError}; /// # Remarks /// /// This implementation follows the derivation in -///
-///
-/// "Handbook of Mathematical Functions, Applied Mathematics Series, Volume -/// 55" - Abramowitz, M., and Stegun, I.A 1964 -///
+/// +/// _"Handbook of Mathematical Functions, Applied Mathematics Series, Volume +/// 55"_ - Abramowitz, M., and Stegun, I.A 1964 +/// /// AND -///
-///
-/// "Advanced mathematical methods for scientists and engineers" - Bender, -/// Carl M.; Steven A. Orszag (1978). page 253 -///
-///
-/// The continued fraction approac is used for `x > 1.0` while the taylor -/// series expansions -/// is used for `0.0 < x <= 1` /// -/// # Examples +/// _"Advanced mathematical methods for scientists and engineers"_ - Bender, +/// Carl M.; Steven A. Orszag (1978). page 253 /// -/// ``` -/// ``` +/// The continued fraction approach is used for `x > 1.0` while the taylor +/// series expansions is used for `0.0 < x <= 1`. +// TODO: Add examples pub fn integral(x: f64, n: u64) -> Result { let eps = 0.00000000000000001; let max_iter = 100; diff --git a/src/stats_tests/fisher.rs b/src/stats_tests/fisher.rs index a18c821d..69b41d6d 100644 --- a/src/stats_tests/fisher.rs +++ b/src/stats_tests/fisher.rs @@ -98,7 +98,7 @@ fn binary_search( } /// Perform a Fisher exact test on a 2x2 contingency table. -/// Based on scipy's fisher test: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.fisher_exact.html#scipy-stats-fisher-exact +/// Based on scipy's fisher test: /// Expects a table in row-major order /// Returns the [odds ratio](https://en.wikipedia.org/wiki/Odds_ratio) and p_value /// # Examples @@ -133,7 +133,7 @@ pub fn fishers_exact_with_odds_ratio( } /// Perform a Fisher exact test on a 2x2 contingency table. -/// Based on scipy's fisher test: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.fisher_exact.html#scipy-stats-fisher-exact +/// Based on scipy's fisher test: /// Expects a table in row-major order /// Returns only the p_value /// # Examples From 350eb961887c265fd6a3b4934056e05d2d9db55e Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Thu, 5 Sep 2024 21:15:52 +0200 Subject: [PATCH 171/185] test: Add error type to testing_boiler macro --- src/distribution/bernoulli.rs | 6 +++--- src/distribution/beta.rs | 2 +- src/distribution/binomial.rs | 5 ++--- src/distribution/categorical.rs | 5 ++--- src/distribution/cauchy.rs | 6 +++--- src/distribution/chi.rs | 6 ++---- src/distribution/chi_squared.rs | 6 +++--- src/distribution/dirac.rs | 5 ++--- src/distribution/discrete_uniform.rs | 5 ++--- src/distribution/erlang.rs | 5 +++-- src/distribution/exponential.rs | 6 ++---- src/distribution/fisher_snedecor.rs | 5 ++--- src/distribution/gamma.rs | 2 +- src/distribution/geometric.rs | 5 ++--- src/distribution/hypergeometric.rs | 5 ++--- src/distribution/internal.rs | 6 +++--- src/distribution/inverse_gamma.rs | 5 ++--- src/distribution/laplace.rs | 2 +- src/distribution/log_normal.rs | 5 ++--- src/distribution/negative_binomial.rs | 5 ++--- src/distribution/normal.rs | 5 ++--- src/distribution/pareto.rs | 5 ++--- src/distribution/poisson.rs | 5 ++--- src/distribution/students_t.rs | 6 ++---- src/distribution/triangular.rs | 5 ++--- src/distribution/uniform.rs | 5 ++--- src/distribution/weibull.rs | 5 ++--- 27 files changed, 56 insertions(+), 77 deletions(-) diff --git a/src/distribution/bernoulli.rs b/src/distribution/bernoulli.rs index f82c0d65..12c0db2f 100644 --- a/src/distribution/bernoulli.rs +++ b/src/distribution/bernoulli.rs @@ -265,11 +265,11 @@ impl Discrete for Bernoulli { #[rustfmt::skip] #[cfg(test)] mod testing { - use crate::distribution::DiscreteCDF; + use super::*; + use crate::StatsError; use crate::testing_boiler; - use super::Bernoulli; - testing_boiler!(p: f64; Bernoulli); + testing_boiler!(p: f64; Bernoulli; StatsError); #[test] fn test_create() { diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index 4febc3c0..748bfcad 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -433,7 +433,7 @@ mod tests { use super::super::internal::*; use crate::testing_boiler; - testing_boiler!(a: f64, b: f64; Beta); + testing_boiler!(a: f64, b: f64; Beta; StatsError); #[test] fn test_create() { diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index 8eced1d7..3536ce8a 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -328,12 +328,11 @@ impl Discrete for Binomial { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::statistics::*; - use crate::distribution::{DiscreteCDF, Discrete, Binomial}; + use super::*; use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(p: f64, n: u64; Binomial); + testing_boiler!(p: f64, n: u64; Binomial; StatsError); #[test] fn test_create() { diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index 71e09560..fd24893b 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -351,12 +351,11 @@ fn test_binary_index() { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::statistics::*; - use crate::distribution::{Categorical, Discrete, DiscreteCDF}; + use super::*; use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(prob_mass: &[f64]; Categorical); + testing_boiler!(prob_mass: &[f64]; Categorical; StatsError); #[test] fn test_create() { diff --git a/src/distribution/cauchy.rs b/src/distribution/cauchy.rs index eb983847..f9e0957d 100644 --- a/src/distribution/cauchy.rs +++ b/src/distribution/cauchy.rs @@ -252,11 +252,11 @@ impl Continuous for Cauchy { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::{statistics::*, testing_boiler}; - use crate::distribution::{ContinuousCDF, Continuous, Cauchy}; + use super::*; use crate::distribution::internal::*; + use crate::testing_boiler; - testing_boiler!(location: f64, scale: f64; Cauchy); + testing_boiler!(location: f64, scale: f64; Cauchy; StatsError); #[test] fn test_create() { diff --git a/src/distribution/chi.rs b/src/distribution/chi.rs index 1bb74295..3c35336a 100644 --- a/src/distribution/chi.rs +++ b/src/distribution/chi.rs @@ -325,13 +325,11 @@ impl Continuous for Chi { #[rustfmt::skip] #[cfg(test)] mod tests { - use std::f64; + use super::*; use crate::distribution::internal::*; - use crate::distribution::{Chi, Continuous, ContinuousCDF}; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(freedom: f64; Chi); + testing_boiler!(freedom: f64; Chi; StatsError); #[test] fn test_create() { diff --git a/src/distribution/chi_squared.rs b/src/distribution/chi_squared.rs index afa5df71..b0b2e2f9 100644 --- a/src/distribution/chi_squared.rs +++ b/src/distribution/chi_squared.rs @@ -306,12 +306,12 @@ impl Continuous for ChiSquared { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::statistics::Median; - use crate::distribution::ChiSquared; + use super::*; use crate::distribution::internal::*; + use crate::StatsError; use crate::testing_boiler; - testing_boiler!(freedom: f64; ChiSquared); + testing_boiler!(freedom: f64; ChiSquared; StatsError); #[test] fn test_median() { diff --git a/src/distribution/dirac.rs b/src/distribution/dirac.rs index 41ac1d6c..5aa6cb9b 100644 --- a/src/distribution/dirac.rs +++ b/src/distribution/dirac.rs @@ -193,11 +193,10 @@ impl Mode> for Dirac { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{ContinuousCDF, Dirac}; - use crate::statistics::*; + use super::*; use crate::testing_boiler; - testing_boiler!(v: f64; Dirac); + testing_boiler!(v: f64; Dirac; StatsError); #[test] fn test_create() { diff --git a/src/distribution/discrete_uniform.rs b/src/distribution/discrete_uniform.rs index 6871c80a..086e4849 100644 --- a/src/distribution/discrete_uniform.rs +++ b/src/distribution/discrete_uniform.rs @@ -256,11 +256,10 @@ impl Discrete for DiscreteUniform { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{DiscreteCDF, Discrete, DiscreteUniform}; - use crate::statistics::*; + use super::*; use crate::testing_boiler; - testing_boiler!(min: i64, max: i64; DiscreteUniform); + testing_boiler!(min: i64, max: i64; DiscreteUniform; StatsError); #[test] fn test_create() { diff --git a/src/distribution/erlang.rs b/src/distribution/erlang.rs index ce6f68aa..68b9675d 100644 --- a/src/distribution/erlang.rs +++ b/src/distribution/erlang.rs @@ -293,11 +293,12 @@ impl Continuous for Erlang { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::Erlang; + use super::*; use crate::distribution::internal::*; + use crate::StatsError; use crate::testing_boiler; - testing_boiler!(shape: u64, rate: f64; Erlang); + testing_boiler!(shape: u64, rate: f64; Erlang; StatsError); #[test] fn test_create() { diff --git a/src/distribution/exponential.rs b/src/distribution/exponential.rs index d5a54d56..9796028d 100644 --- a/src/distribution/exponential.rs +++ b/src/distribution/exponential.rs @@ -279,13 +279,11 @@ impl Continuous for Exp { #[rustfmt::skip] #[cfg(test)] mod tests { - use std::f64; - use crate::distribution::{ContinuousCDF, Continuous, Exp}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(rate: f64; Exp); + testing_boiler!(rate: f64; Exp; StatsError); #[test] fn test_create() { diff --git a/src/distribution/fisher_snedecor.rs b/src/distribution/fisher_snedecor.rs index 9d5ef867..76913651 100644 --- a/src/distribution/fisher_snedecor.rs +++ b/src/distribution/fisher_snedecor.rs @@ -385,12 +385,11 @@ impl Continuous for FisherSnedecor { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{ContinuousCDF, Continuous, FisherSnedecor}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(freedom_1: f64, freedom_2: f64; FisherSnedecor); + testing_boiler!(freedom_1: f64, freedom_2: f64; FisherSnedecor; StatsError); #[test] fn test_create() { diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index e986037a..fae9848a 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -406,7 +406,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(shape: f64, rate: f64; Gamma); + testing_boiler!(shape: f64, rate: f64; Gamma; StatsError); #[test] fn test_create() { diff --git a/src/distribution/geometric.rs b/src/distribution/geometric.rs index f584cc0e..b4bda5a3 100644 --- a/src/distribution/geometric.rs +++ b/src/distribution/geometric.rs @@ -273,12 +273,11 @@ impl Discrete for Geometric { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{DiscreteCDF, Discrete, Geometric}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(p: f64; Geometric); + testing_boiler!(p: f64; Geometric; StatsError); #[test] fn test_create() { diff --git a/src/distribution/hypergeometric.rs b/src/distribution/hypergeometric.rs index 43bc30ca..fabb2532 100644 --- a/src/distribution/hypergeometric.rs +++ b/src/distribution/hypergeometric.rs @@ -372,12 +372,11 @@ impl Discrete for Hypergeometric { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{DiscreteCDF, Discrete, Hypergeometric}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(population: u64, successes: u64, draws: u64; Hypergeometric); + testing_boiler!(population: u64, successes: u64, draws: u64; Hypergeometric; StatsError); #[test] fn test_create() { diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index ae853637..4d328d7d 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -100,7 +100,7 @@ pub mod test { #[macro_export] macro_rules! testing_boiler { - ($($arg_name:ident: $arg_ty:ty),+; $dist:ty) => { + ($($arg_name:ident: $arg_ty:ty),+; $dist:ty; $dist_err:ty) => { fn make_param_text($($arg_name: $arg_ty),+) -> String { // "" let mut param_text = String::new(); @@ -140,7 +140,7 @@ pub mod test { /// Returns the error when creating a distribution with the given parameters, /// panicking if `::new` succeeds. #[allow(dead_code)] - fn create_err($($arg_name: $arg_ty),+) -> $crate::StatsError { + fn create_err($($arg_name: $arg_ty),+) -> $dist_err { match <$dist>::new($($arg_name),+) { Err(e) => e, Ok(d) => panic!( @@ -281,7 +281,7 @@ pub mod test { use crate::statistics::*; use crate::StatsError; - testing_boiler!(p: f64, n: u64; Binomial); + testing_boiler!(p: f64, n: u64; Binomial; StatsError); #[test] fn create_ok_success() { diff --git a/src/distribution/inverse_gamma.rs b/src/distribution/inverse_gamma.rs index 8314a9dc..ec70f6f2 100644 --- a/src/distribution/inverse_gamma.rs +++ b/src/distribution/inverse_gamma.rs @@ -313,12 +313,11 @@ impl Continuous for InverseGamma { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{ContinuousCDF, Continuous, InverseGamma}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(shape: f64, rate: f64; InverseGamma); + testing_boiler!(shape: f64, rate: f64; InverseGamma; StatsError); #[test] fn test_create() { diff --git a/src/distribution/laplace.rs b/src/distribution/laplace.rs index 0ccbac78..5466e89f 100644 --- a/src/distribution/laplace.rs +++ b/src/distribution/laplace.rs @@ -304,7 +304,7 @@ mod tests { use crate::testing_boiler; - testing_boiler!(location: f64, scale: f64; Laplace); + testing_boiler!(location: f64, scale: f64; Laplace; StatsError); // A wrapper for the `assert_relative_eq!` macro from the approx crate. // diff --git a/src/distribution/log_normal.rs b/src/distribution/log_normal.rs index 88a78996..cce0f11d 100644 --- a/src/distribution/log_normal.rs +++ b/src/distribution/log_normal.rs @@ -305,12 +305,11 @@ impl Continuous for LogNormal { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{ContinuousCDF, Continuous, LogNormal}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(mean: f64, std_dev: f64; LogNormal); + testing_boiler!(mean: f64, std_dev: f64; LogNormal; StatsError); #[test] fn test_create() { diff --git a/src/distribution/negative_binomial.rs b/src/distribution/negative_binomial.rs index d36d0f98..d461af6e 100644 --- a/src/distribution/negative_binomial.rs +++ b/src/distribution/negative_binomial.rs @@ -291,12 +291,11 @@ impl Discrete for NegativeBinomial { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{DiscreteCDF, Discrete, NegativeBinomial}; + use super::*; use crate::distribution::internal::test; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(r: f64, p: f64; NegativeBinomial); + testing_boiler!(r: f64, p: f64; NegativeBinomial; StatsError); #[test] fn test_create() { diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index 65f6ad90..74a4fde4 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -334,12 +334,11 @@ impl std::default::Default for Normal { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{ContinuousCDF, Continuous, Normal}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(mean: f64, std_dev: f64; Normal); + testing_boiler!(mean: f64, std_dev: f64; Normal; StatsError); #[test] fn test_create() { diff --git a/src/distribution/pareto.rs b/src/distribution/pareto.rs index 5fb62044..c983289b 100644 --- a/src/distribution/pareto.rs +++ b/src/distribution/pareto.rs @@ -354,12 +354,11 @@ impl Continuous for Pareto { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{ContinuousCDF, Continuous, Pareto}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(scale: f64, shape: f64; Pareto); + testing_boiler!(scale: f64, shape: f64; Pareto; StatsError); #[test] fn test_create() { diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index 41b56e6a..4588d00e 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -304,12 +304,11 @@ pub fn sample_unchecked(rng: &mut R, lambda: f64) -> f64 { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{DiscreteCDF, Discrete, Poisson}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(lambda: f64; Poisson); + testing_boiler!(lambda: f64; Poisson; StatsError); #[test] fn test_create() { diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index 4bada682..a362e071 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -421,14 +421,12 @@ impl Continuous for StudentsT { #[cfg(test)] mod tests { + use super::*; use crate::consts::ACC; use crate::distribution::internal::*; - use crate::distribution::{Continuous, ContinuousCDF, StudentsT}; - use crate::statistics::*; use crate::testing_boiler; - use std::panic; - testing_boiler!(location: f64, scale: f64, freedom: f64; StudentsT); + testing_boiler!(location: f64, scale: f64, freedom: f64; StudentsT; StatsError); #[test] fn test_create() { diff --git a/src/distribution/triangular.rs b/src/distribution/triangular.rs index fff9fe72..2a4d31b2 100644 --- a/src/distribution/triangular.rs +++ b/src/distribution/triangular.rs @@ -347,12 +347,11 @@ fn sample_unchecked(rng: &mut R, min: f64, max: f64, mode: f64) #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{ContinuousCDF, Continuous, Triangular}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(min: f64, max: f64, mode: f64; Triangular); + testing_boiler!(min: f64, max: f64, mode: f64; Triangular; StatsError); #[test] fn test_create() { diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index 4186df71..fdf25498 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -284,12 +284,11 @@ impl Continuous for Uniform { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{ContinuousCDF, Continuous, Uniform}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(min: f64, max: f64; Uniform); + testing_boiler!(min: f64, max: f64; Uniform; StatsError); #[test] fn test_create() { diff --git a/src/distribution/weibull.rs b/src/distribution/weibull.rs index 2d3a8a87..8137f664 100644 --- a/src/distribution/weibull.rs +++ b/src/distribution/weibull.rs @@ -350,12 +350,11 @@ impl Continuous for Weibull { #[rustfmt::skip] #[cfg(test)] mod tests { - use crate::distribution::{ContinuousCDF, Continuous, Weibull}; + use super::*; use crate::distribution::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!(shape: f64, scale: f64; Weibull); + testing_boiler!(shape: f64, scale: f64; Weibull; StatsError); #[test] fn test_create() { From 2ab248bb8009a01f2a7c42e55733ec391cd17304 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Fri, 16 Aug 2024 19:47:48 +0200 Subject: [PATCH 172/185] feat!: Concrete errors for `::new` Includes notable changes: - Add internal `test_create_err` function - Use `Beta` in unit tests for testing_boiler - Validate Dirichlet params inside `new` and remove `is_valid_alpha` function - Use `Result` for Empirical (infallible `::new` function) - Validate Multinomial params inside `new` and remove `check_multinomial` function - Add a concrete error type to fisher's exact test, too (it is dependent on Hypergeometric, which is why it's included in this change) --- src/distribution/bernoulli.rs | 8 +- src/distribution/beta.rs | 51 +++++-- src/distribution/binomial.rs | 25 +++- src/distribution/categorical.rs | 88 +++++++++--- src/distribution/cauchy.rs | 53 +++++-- src/distribution/chi.rs | 27 +++- src/distribution/chi_squared.rs | 8 +- src/distribution/dirac.rs | 27 +++- src/distribution/dirichlet.rs | 67 +++++---- src/distribution/discrete_uniform.rs | 25 +++- src/distribution/empirical.rs | 6 +- src/distribution/erlang.rs | 22 +-- src/distribution/exponential.rs | 27 +++- src/distribution/fisher_snedecor.rs | 55 ++++++-- src/distribution/gamma.rs | 70 ++++++--- src/distribution/geometric.rs | 25 +++- src/distribution/hypergeometric.rs | 59 ++++++-- src/distribution/internal.rs | 148 +++++++------------- src/distribution/inverse_gamma.rs | 51 +++++-- src/distribution/laplace.rs | 43 ++++-- src/distribution/log_normal.rs | 44 ++++-- src/distribution/mod.rs | 54 +++---- src/distribution/multinomial.rs | 71 ++++++++-- src/distribution/multivariate_normal.rs | 67 +++++++-- src/distribution/multivariate_students_t.rs | 86 +++++++++--- src/distribution/negative_binomial.rs | 43 ++++-- src/distribution/normal.rs | 46 ++++-- src/distribution/pareto.rs | 44 ++++-- src/distribution/poisson.rs | 25 +++- src/distribution/students_t.rs | 79 ++++++++--- src/distribution/triangular.rs | 90 +++++++++--- src/distribution/uniform.rs | 74 +++++++--- src/distribution/weibull.rs | 53 +++++-- src/stats_tests/fisher.rs | 38 ++++- 34 files changed, 1233 insertions(+), 466 deletions(-) diff --git a/src/distribution/bernoulli.rs b/src/distribution/bernoulli.rs index 12c0db2f..d5de981a 100644 --- a/src/distribution/bernoulli.rs +++ b/src/distribution/bernoulli.rs @@ -1,6 +1,5 @@ -use crate::distribution::{Binomial, Discrete, DiscreteCDF}; +use crate::distribution::{Binomial, BinomialError, Discrete, DiscreteCDF}; use crate::statistics::*; -use crate::Result; use rand::Rng; /// Implements the @@ -45,7 +44,7 @@ impl Bernoulli { /// result = Bernoulli::new(-0.5); /// assert!(result.is_err()); /// ``` - pub fn new(p: f64) -> Result { + pub fn new(p: f64) -> Result { Binomial::new(p, 1).map(|b| Bernoulli { b }) } @@ -266,10 +265,9 @@ impl Discrete for Bernoulli { #[cfg(test)] mod testing { use super::*; - use crate::StatsError; use crate::testing_boiler; - testing_boiler!(p: f64; Bernoulli; StatsError); + testing_boiler!(p: f64; Bernoulli; BinomialError); #[test] fn test_create() { diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index 748bfcad..763945d6 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -1,7 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::{beta, gamma}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; /// Implements the [Beta](https://en.wikipedia.org/wiki/Beta_distribution) @@ -24,6 +23,32 @@ pub struct Beta { shape_b: f64, } +/// Represents the errors that can occur when creating a [`Beta`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum BetaError { + /// Shape A is NaN, zero or negative. + ShapeAInvalid, + + /// Shape B is NaN, zero or negative. + ShapeBInvalid, + + /// Shape A and Shape B are infinite. + BothShapesInfinite, +} + +impl std::fmt::Display for BetaError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + BetaError::ShapeAInvalid => write!(f, "Shape A is NaN, zero or negative"), + BetaError::ShapeBInvalid => write!(f, "Shape B is NaN, zero or negative"), + BetaError::BothShapesInfinite => write!(f, "Shape A and shape B are infinite"), + } + } +} + +impl std::error::Error for BetaError {} + impl Beta { /// Constructs a new beta distribution with shapeA (α) of `shape_a` /// and shapeB (β) of `shape_b` @@ -44,15 +69,19 @@ impl Beta { /// result = Beta::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(shape_a: f64, shape_b: f64) -> Result { - if shape_a.is_nan() - || shape_b.is_nan() - || shape_a.is_infinite() && shape_b.is_infinite() - || shape_a <= 0.0 - || shape_b <= 0.0 - { - return Err(StatsError::BadParams); - }; + pub fn new(shape_a: f64, shape_b: f64) -> Result { + if shape_a.is_nan() || shape_a <= 0.0 { + return Err(BetaError::ShapeAInvalid); + } + + if shape_b.is_nan() || shape_b <= 0.0 { + return Err(BetaError::ShapeBInvalid); + } + + if shape_a.is_infinite() && shape_b.is_infinite() { + return Err(BetaError::BothShapesInfinite); + } + Ok(Beta { shape_a, shape_b }) } @@ -433,7 +462,7 @@ mod tests { use super::super::internal::*; use crate::testing_boiler; - testing_boiler!(a: f64, b: f64; Beta; StatsError); + testing_boiler!(a: f64, b: f64; Beta; BetaError); #[test] fn test_create() { diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index 3536ce8a..9f5ffc47 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -1,7 +1,6 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::function::{beta, factorial}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -26,6 +25,24 @@ pub struct Binomial { n: u64, } +/// Represents the errors that can occur when creating a [`Binomial`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum BinomialError { + /// The probability is NaN or not in `[0, 1]`. + ProbabilityInvalid, +} + +impl std::fmt::Display for BinomialError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + BinomialError::ProbabilityInvalid => write!(f, "Probability is NaN or not in [0, 1]"), + } + } +} + +impl std::error::Error for BinomialError {} + impl Binomial { /// Constructs a new binomial distribution /// with a given `p` probability of success of `n` @@ -47,9 +64,9 @@ impl Binomial { /// result = Binomial::new(-0.5, 5); /// assert!(result.is_err()); /// ``` - pub fn new(p: f64, n: u64) -> Result { + pub fn new(p: f64, n: u64) -> Result { if p.is_nan() || !(0.0..=1.0).contains(&p) { - Err(StatsError::BadParams) + Err(BinomialError::ProbabilityInvalid) } else { Ok(Binomial { p, n }) } @@ -332,7 +349,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(p: f64, n: u64; Binomial; StatsError); + testing_boiler!(p: f64, n: u64; Binomial; BinomialError); #[test] fn test_create() { diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index fd24893b..9cd95a9d 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -1,6 +1,5 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -27,6 +26,35 @@ pub struct Categorical { sf: Vec, } +/// Represents the errors that can occur when creating a [`Categorical`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum CategoricalError { + /// The probability mass is empty. + ProbMassEmpty, + + /// The probabilities sums up to zero. + ProbMassSumZero, + + /// The probability mass contains at least one element which is NaN or less than zero. + ProbMassHasInvalidElements, +} + +impl std::fmt::Display for CategoricalError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + CategoricalError::ProbMassEmpty => write!(f, "Probability mass is empty"), + CategoricalError::ProbMassSumZero => write!(f, "Probabilities sum up to zero"), + CategoricalError::ProbMassHasInvalidElements => write!( + f, + "Probability mass contains at least one element which is NaN or less than zero" + ), + } + } +} + +impl std::error::Error for CategoricalError {} + impl Categorical { /// Constructs a new categorical distribution /// with the probabilities masses defined by `prob_mass` @@ -52,23 +80,36 @@ impl Categorical { /// result = Categorical::new(&[0.0, -1.0, 2.0]); /// assert!(result.is_err()); /// ``` - pub fn new(prob_mass: &[f64]) -> Result { - if !super::internal::is_valid_multinomial(prob_mass, true) { - Err(StatsError::BadParams) - } else { - // extract un-normalized cdf - let cdf = prob_mass_to_cdf(prob_mass); - // extract un-normalized sf - let sf = cdf_to_sf(&cdf); - // extract normalized probability mass - let sum = cdf[cdf.len() - 1]; - let mut norm_pmf = vec![0.0; prob_mass.len()]; - norm_pmf - .iter_mut() - .zip(prob_mass.iter()) - .for_each(|(np, pm)| *np = *pm / sum); - Ok(Categorical { norm_pmf, cdf, sf }) + pub fn new(prob_mass: &[f64]) -> Result { + if prob_mass.is_empty() { + return Err(CategoricalError::ProbMassEmpty); + } + + let mut prob_sum = 0.0; + for &p in prob_mass { + if p.is_nan() || p < 0.0 { + return Err(CategoricalError::ProbMassHasInvalidElements); + } + + prob_sum += p; } + + if prob_sum == 0.0 { + return Err(CategoricalError::ProbMassSumZero); + } + + // extract un-normalized cdf + let cdf = prob_mass_to_cdf(prob_mass); + // extract un-normalized sf + let sf = cdf_to_sf(&cdf); + // extract normalized probability mass + let sum = cdf[cdf.len() - 1]; + let mut norm_pmf = vec![0.0; prob_mass.len()]; + norm_pmf + .iter_mut() + .zip(prob_mass.iter()) + .for_each(|(np, pm)| *np = *pm / sum); + Ok(Categorical { norm_pmf, cdf, sf }) } fn cdf_max(&self) -> f64 { @@ -355,7 +396,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(prob_mass: &[f64]; Categorical; StatsError); + testing_boiler!(prob_mass: &[f64]; Categorical; CategoricalError); #[test] fn test_create() { @@ -364,8 +405,15 @@ mod tests { #[test] fn test_bad_create() { - create_err(&[-1.0, 1.0]); - create_err(&[0.0, 0.0]); + let invalid: &[(&[f64], CategoricalError)] = &[ + (&[], CategoricalError::ProbMassEmpty), + (&[-1.0, 1.0], CategoricalError::ProbMassHasInvalidElements), + (&[0.0, 0.0, 0.0], CategoricalError::ProbMassSumZero), + ]; + + for &(prob_mass, err) in invalid { + test_create_err(prob_mass, err); + } } #[test] diff --git a/src/distribution/cauchy.rs b/src/distribution/cauchy.rs index f9e0957d..a8815349 100644 --- a/src/distribution/cauchy.rs +++ b/src/distribution/cauchy.rs @@ -1,6 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -23,6 +22,28 @@ pub struct Cauchy { scale: f64, } +/// Represents the errors that can occur when creating a [`Cauchy`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum CauchyError { + /// The location is NaN. + LocationInvalid, + + /// The scale is NaN, zero or less than zero. + ScaleInvalid, +} + +impl std::fmt::Display for CauchyError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + CauchyError::LocationInvalid => write!(f, "Location is NaN"), + CauchyError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero"), + } + } +} + +impl std::error::Error for CauchyError {} + impl Cauchy { /// Constructs a new cauchy distribution with the given /// location and scale. @@ -42,12 +63,16 @@ impl Cauchy { /// result = Cauchy::new(0.0, -1.0); /// assert!(result.is_err()); /// ``` - pub fn new(location: f64, scale: f64) -> Result { - if location.is_nan() || scale.is_nan() || scale <= 0.0 { - Err(StatsError::BadParams) - } else { - Ok(Cauchy { location, scale }) + pub fn new(location: f64, scale: f64) -> Result { + if location.is_nan() { + return Err(CauchyError::LocationInvalid); + } + + if scale.is_nan() || scale <= 0.0 { + return Err(CauchyError::ScaleInvalid); } + + Ok(Cauchy { location, scale }) } /// Returns the location of the cauchy distribution @@ -256,7 +281,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(location: f64, scale: f64; Cauchy; StatsError); + testing_boiler!(location: f64, scale: f64; Cauchy; CauchyError); #[test] fn test_create() { @@ -270,10 +295,16 @@ mod tests { #[test] fn test_bad_create() { - create_err(f64::NAN, 1.0); - create_err(1.0, f64::NAN); - create_err(f64::NAN, f64::NAN); - create_err(1.0, 0.0); + let invalid = [ + (f64::NAN, 1.0, CauchyError::LocationInvalid), + (1.0, f64::NAN, CauchyError::ScaleInvalid), + (f64::NAN, f64::NAN, CauchyError::LocationInvalid), + (1.0, 0.0, CauchyError::ScaleInvalid), + ]; + + for (location, scale, err) in invalid { + test_create_err(location, scale, err); + } } #[test] diff --git a/src/distribution/chi.rs b/src/distribution/chi.rs index 3c35336a..cce8535f 100644 --- a/src/distribution/chi.rs +++ b/src/distribution/chi.rs @@ -1,7 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -24,6 +23,26 @@ pub struct Chi { freedom: f64, } +/// Represents the errors that can occur when creating a [`Chi`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum ChiError { + /// The degrees of freedom are NaN, zero or less than zero. + FreedomInvalid, +} + +impl std::fmt::Display for ChiError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + ChiError::FreedomInvalid => { + write!(f, "Degrees of freedom are NaN, zero or less than zero") + } + } + } +} + +impl std::error::Error for ChiError {} + impl Chi { /// Constructs a new chi distribution /// with `freedom` degrees of freedom @@ -44,9 +63,9 @@ impl Chi { /// result = Chi::new(0.0); /// assert!(result.is_err()); /// ``` - pub fn new(freedom: f64) -> Result { + pub fn new(freedom: f64) -> Result { if freedom.is_nan() || freedom <= 0.0 { - Err(StatsError::BadParams) + Err(ChiError::FreedomInvalid) } else { Ok(Chi { freedom }) } @@ -329,7 +348,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(freedom: f64; Chi; StatsError); + testing_boiler!(freedom: f64; Chi; ChiError); #[test] fn test_create() { diff --git a/src/distribution/chi_squared.rs b/src/distribution/chi_squared.rs index b0b2e2f9..a847ac94 100644 --- a/src/distribution/chi_squared.rs +++ b/src/distribution/chi_squared.rs @@ -1,6 +1,5 @@ -use crate::distribution::{Continuous, ContinuousCDF, Gamma}; +use crate::distribution::{Continuous, ContinuousCDF, Gamma, GammaError}; use crate::statistics::*; -use crate::Result; use rand::Rng; use std::f64; @@ -48,7 +47,7 @@ impl ChiSquared { /// result = ChiSquared::new(0.0); /// assert!(result.is_err()); /// ``` - pub fn new(freedom: f64) -> Result { + pub fn new(freedom: f64) -> Result { Gamma::new(freedom / 2.0, 0.5).map(|g| ChiSquared { freedom, g }) } @@ -308,10 +307,9 @@ impl Continuous for ChiSquared { mod tests { use super::*; use crate::distribution::internal::*; - use crate::StatsError; use crate::testing_boiler; - testing_boiler!(freedom: f64; ChiSquared; StatsError); + testing_boiler!(freedom: f64; ChiSquared; GammaError); #[test] fn test_median() { diff --git a/src/distribution/dirac.rs b/src/distribution/dirac.rs index 5aa6cb9b..142f8c12 100644 --- a/src/distribution/dirac.rs +++ b/src/distribution/dirac.rs @@ -1,6 +1,5 @@ use crate::distribution::ContinuousCDF; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; /// Implements the [Dirac Delta](https://en.wikipedia.org/wiki/Dirac_delta_function#As_a_distribution) @@ -18,8 +17,26 @@ use rand::Rng; #[derive(Debug, Copy, Clone, PartialEq)] pub struct Dirac(f64); +/// Represents the errors that can occur when creating a [`Dirac`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum DiracError { + /// The value v is NaN. + ValueInvalid, +} + +impl std::fmt::Display for DiracError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + DiracError::ValueInvalid => write!(f, "Value v is NaN"), + } + } +} + +impl std::error::Error for DiracError {} + impl Dirac { - /// Constructs a new dirac distribution function at value `v`. + /// Constructs a new dirac distribution function at value `v`. /// /// # Errors /// @@ -36,9 +53,9 @@ impl Dirac { /// result = Dirac::new(f64::NAN); /// assert!(result.is_err()); /// ``` - pub fn new(v: f64) -> Result { + pub fn new(v: f64) -> Result { if v.is_nan() { - Err(StatsError::BadParams) + Err(DiracError::ValueInvalid) } else { Ok(Dirac(v)) } @@ -196,7 +213,7 @@ mod tests { use super::*; use crate::testing_boiler; - testing_boiler!(v: f64; Dirac; StatsError); + testing_boiler!(v: f64; Dirac; DiracError); #[test] fn test_create() { diff --git a/src/distribution/dirichlet.rs b/src/distribution/dirichlet.rs index f058b46d..b8aaad86 100644 --- a/src/distribution/dirichlet.rs +++ b/src/distribution/dirichlet.rs @@ -1,7 +1,7 @@ use crate::distribution::Continuous; use crate::function::gamma; +use crate::prec; use crate::statistics::*; -use crate::{prec, Result, StatsError}; use nalgebra::{Const, Dim, Dyn, OMatrix, OVector}; use rand::Rng; use std::f64; @@ -31,6 +31,31 @@ where alpha: OVector, } +/// Represents the errors that can occur when creating a [`Dirichlet`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum DirichletError { + /// Alpha contains less than two elements. + AlphaTooShort, + + /// Alpha contains an element that is NaN, infinite, zero or less than zero. + AlphaHasInvalidElements, +} + +impl std::fmt::Display for DirichletError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + DirichletError::AlphaTooShort => write!(f, "Alpha contains less than two elements"), + DirichletError::AlphaHasInvalidElements => write!( + f, + "Alpha contains an element that is NaN, infinite, zero or less than zero" + ), + } + } +} + +impl std::error::Error for DirichletError {} + impl Dirichlet { /// Constructs a new dirichlet distribution with the given /// concentration parameters (alpha) @@ -55,7 +80,7 @@ impl Dirichlet { /// result = Dirichlet::new(alpha_err); /// assert!(result.is_err()); /// ``` - pub fn new(alpha: Vec) -> Result { + pub fn new(alpha: Vec) -> Result { Self::new_from_nalgebra(alpha.into()) } @@ -78,7 +103,7 @@ impl Dirichlet { /// result = Dirichlet::new_with_param(0.0, 1); /// assert!(result.is_err()); /// ``` - pub fn new_with_param(alpha: f64, n: usize) -> Result { + pub fn new_with_param(alpha: f64, n: usize) -> Result { Self::new(vec![alpha; n]) } } @@ -95,12 +120,16 @@ where /// /// Returns an error if vector has length less than 2 or if any element /// of alpha is NOT finite positive - pub fn new_from_nalgebra(alpha: OVector) -> Result { - if !is_valid_alpha(alpha.as_slice()) { - Err(StatsError::BadParams) - } else { - Ok(Self { alpha }) + pub fn new_from_nalgebra(alpha: OVector) -> Result { + if alpha.len() < 2 { + return Err(DirichletError::AlphaTooShort); } + + if alpha.iter().any(|&a_i| !a_i.is_finite() || a_i <= 0.0) { + return Err(DirichletError::AlphaHasInvalidElements); + } + + Ok(Self { alpha }) } /// Returns the concentration parameters of @@ -336,12 +365,6 @@ where } } -// determines if `a` is a valid alpha array -// for the Dirichlet distribution -fn is_valid_alpha(a: &[f64]) -> bool { - a.len() >= 2 && a.iter().all(|&a_i| a_i.is_finite() && a_i > 0.0) -} - #[rustfmt::skip] #[cfg(test)] mod tests { @@ -349,7 +372,6 @@ mod tests { use nalgebra::{dmatrix, dvector, vector, DimMin, OVector}; - use super::is_valid_alpha; use crate::{ distribution::{Continuous, Dirichlet}, statistics::{MeanN, VarianceN}, @@ -386,18 +408,9 @@ mod tests { assert_relative_eq!(expected, x, epsilon = acc); } - #[test] - fn test_is_valid_alpha() { - assert!(!is_valid_alpha(&[1.0])); - assert!(!is_valid_alpha(&[1.0, f64::NAN])); - assert!(is_valid_alpha(&[1.0, 2.0])); - assert!(!is_valid_alpha(&[1.0, 0.0])); - assert!(!is_valid_alpha(&[1.0, f64::INFINITY])); - assert!(!is_valid_alpha(&[-1.0, 2.0])); - } - #[test] fn test_create() { + try_create(vector![1.0, 2.0]); try_create(vector![1.0, 2.0, 3.0, 4.0, 5.0]); assert!(Dirichlet::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]).is_ok()); // try_create(vector![0.001, f64::INFINITY, 3756.0]); // moved to bad case as this is degenerate @@ -405,6 +418,10 @@ mod tests { #[test] fn test_bad_create() { + bad_create_case(vector![1.0, f64::NAN]); + bad_create_case(vector![1.0, 0.0]); + bad_create_case(vector![1.0, f64::INFINITY]); + bad_create_case(vector![-1.0, 2.0]); bad_create_case(vector![1.0]); bad_create_case(vector![1.0, 2.0, 0.0, 4.0, 5.0]); bad_create_case(vector![1.0, f64::NAN, 3.0, 4.0, 5.0]); diff --git a/src/distribution/discrete_uniform.rs b/src/distribution/discrete_uniform.rs index 086e4849..be851b62 100644 --- a/src/distribution/discrete_uniform.rs +++ b/src/distribution/discrete_uniform.rs @@ -1,6 +1,5 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; /// Implements the [Discrete @@ -23,6 +22,24 @@ pub struct DiscreteUniform { max: i64, } +/// Represents the errors that can occur when creating a [`DiscreteUniform`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum DiscreteUniformError { + /// The maximum is less than the minimum. + MinMaxInvalid, +} + +impl std::fmt::Display for DiscreteUniformError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + DiscreteUniformError::MinMaxInvalid => write!(f, "Maximum is less than minimum"), + } + } +} + +impl std::error::Error for DiscreteUniformError {} + impl DiscreteUniform { /// Constructs a new discrete uniform distribution with a minimum value /// of `min` and a maximum value of `max`. @@ -42,9 +59,9 @@ impl DiscreteUniform { /// result = DiscreteUniform::new(5, 0); /// assert!(result.is_err()); /// ``` - pub fn new(min: i64, max: i64) -> Result { + pub fn new(min: i64, max: i64) -> Result { if max < min { - Err(StatsError::BadParams) + Err(DiscreteUniformError::MinMaxInvalid) } else { Ok(DiscreteUniform { min, max }) } @@ -259,7 +276,7 @@ mod tests { use super::*; use crate::testing_boiler; - testing_boiler!(min: i64, max: i64; DiscreteUniform; StatsError); + testing_boiler!(min: i64, max: i64; DiscreteUniform; DiscreteUniformError); #[test] fn test_create() { diff --git a/src/distribution/empirical.rs b/src/distribution/empirical.rs index 104169aa..6dc7ec71 100644 --- a/src/distribution/empirical.rs +++ b/src/distribution/empirical.rs @@ -1,6 +1,5 @@ use crate::distribution::{ContinuousCDF, Uniform}; use crate::statistics::*; -use crate::Result; use core::cmp::Ordering; use rand::Rng; use std::collections::BTreeMap; @@ -48,6 +47,8 @@ impl Empirical { /// Constructs a new discrete uniform distribution with a minimum value /// of `min` and a maximum value of `max`. /// + /// Note that this will always succeed and never return the [`Err`][Result::Err] variant. + /// /// # Examples /// /// ``` @@ -56,7 +57,8 @@ impl Empirical { /// let mut result = Empirical::new(); /// assert!(result.is_ok()); /// ``` - pub fn new() -> Result { + #[allow(clippy::result_unit_err)] + pub fn new() -> Result { Ok(Empirical { sum: 0., mean_and_var: None, diff --git a/src/distribution/erlang.rs b/src/distribution/erlang.rs index 68b9675d..9b7a332c 100644 --- a/src/distribution/erlang.rs +++ b/src/distribution/erlang.rs @@ -1,6 +1,5 @@ -use crate::distribution::{Continuous, ContinuousCDF, Gamma}; +use crate::distribution::{Continuous, ContinuousCDF, Gamma, GammaError}; use crate::statistics::*; -use crate::Result; use rand::Rng; /// Implements the [Erlang](https://en.wikipedia.org/wiki/Erlang_distribution) @@ -45,7 +44,7 @@ impl Erlang { /// result = Erlang::new(0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(shape: u64, rate: f64) -> Result { + pub fn new(shape: u64, rate: f64) -> Result { Gamma::new(shape as f64, rate).map(|g| Erlang { g }) } @@ -295,10 +294,9 @@ impl Continuous for Erlang { mod tests { use super::*; use crate::distribution::internal::*; - use crate::StatsError; use crate::testing_boiler; - testing_boiler!(shape: u64, rate: f64; Erlang; StatsError); + testing_boiler!(shape: u64, rate: f64; Erlang; GammaError); #[test] fn test_create() { @@ -311,10 +309,16 @@ mod tests { #[test] fn test_bad_create() { - create_err(0, 1.0); - create_err(1, 0.0); - create_err(1, f64::NAN); - create_err(1, -1.0); + let invalid = [ + (0, 1.0, GammaError::ShapeInvalid), + (1, 0.0, GammaError::RateInvalid), + (1, f64::NAN, GammaError::RateInvalid), + (1, -1.0, GammaError::RateInvalid), + ]; + + for (s, r, err) in invalid { + test_create_err(s, r, err); + } } #[test] diff --git a/src/distribution/exponential.rs b/src/distribution/exponential.rs index 9796028d..1389b1e4 100644 --- a/src/distribution/exponential.rs +++ b/src/distribution/exponential.rs @@ -1,6 +1,5 @@ use crate::distribution::{ziggurat, Continuous, ContinuousCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -25,13 +24,31 @@ pub struct Exp { rate: f64, } +/// Represents the errors that can occur when creating a [`Exp`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum ExpError { + /// The rate is NaN, zero or less than zero. + RateInvalid, +} + +impl std::fmt::Display for ExpError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + ExpError::RateInvalid => write!(f, "Rate is NaN, zero or less than zero"), + } + } +} + +impl std::error::Error for ExpError {} + impl Exp { /// Constructs a new exponential distribution with a /// rate (λ) of `rate`. /// /// # Errors /// - /// Returns an error if rate is `NaN` or `rate <= 0.0` + /// Returns an error if rate is `NaN` or `rate <= 0.0`. /// /// # Examples /// @@ -44,9 +61,9 @@ impl Exp { /// result = Exp::new(-1.0); /// assert!(result.is_err()); /// ``` - pub fn new(rate: f64) -> Result { + pub fn new(rate: f64) -> Result { if rate.is_nan() || rate <= 0.0 { - Err(StatsError::BadParams) + Err(ExpError::RateInvalid) } else { Ok(Exp { rate }) } @@ -283,7 +300,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(rate: f64; Exp; StatsError); + testing_boiler!(rate: f64; Exp; ExpError); #[test] fn test_create() { diff --git a/src/distribution/fisher_snedecor.rs b/src/distribution/fisher_snedecor.rs index 76913651..a50ed42a 100644 --- a/src/distribution/fisher_snedecor.rs +++ b/src/distribution/fisher_snedecor.rs @@ -1,7 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::beta; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -26,6 +25,32 @@ pub struct FisherSnedecor { freedom_2: f64, } +/// Represents the errors that can occur when creating a [`FisherSnedecor`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum FisherSnedecorError { + /// `freedom_1` is NaN, infinite, zero or less than zero. + Freedom1Invalid, + + /// `freedom_2` is NaN, infinite, zero or less than zero. + Freedom2Invalid, +} + +impl std::fmt::Display for FisherSnedecorError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + FisherSnedecorError::Freedom1Invalid => { + write!(f, "freedom_1 is NaN, infinite, zero or less than zero.") + } + FisherSnedecorError::Freedom2Invalid => { + write!(f, "freedom_2 is NaN, infinite, zero or less than zero.") + } + } + } +} + +impl std::error::Error for FisherSnedecorError {} + impl FisherSnedecor { /// Constructs a new fisher-snedecor distribution with /// degrees of freedom `freedom_1` and `freedom_2` @@ -46,16 +71,19 @@ impl FisherSnedecor { /// result = FisherSnedecor::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(freedom_1: f64, freedom_2: f64) -> Result { - if !freedom_1.is_finite() || freedom_1 <= 0.0 || !freedom_2.is_finite() || freedom_2 <= 0.0 - { - Err(StatsError::BadParams) - } else { - Ok(FisherSnedecor { - freedom_1, - freedom_2, - }) + pub fn new(freedom_1: f64, freedom_2: f64) -> Result { + if !freedom_1.is_finite() || freedom_1 <= 0.0 { + return Err(FisherSnedecorError::Freedom1Invalid); + } + + if !freedom_2.is_finite() || freedom_2 <= 0.0 { + return Err(FisherSnedecorError::Freedom2Invalid); } + + Ok(FisherSnedecor { + freedom_1, + freedom_2, + }) } /// Returns the first degree of freedom for the @@ -389,7 +417,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(freedom_1: f64, freedom_2: f64; FisherSnedecor; StatsError); + testing_boiler!(freedom_1: f64, freedom_2: f64; FisherSnedecor; FisherSnedecorError); #[test] fn test_create() { @@ -403,6 +431,9 @@ mod tests { #[test] fn test_bad_create() { + test_create_err(f64::INFINITY, 0.1, FisherSnedecorError::Freedom1Invalid); + test_create_err(0.1, f64::INFINITY, FisherSnedecorError::Freedom2Invalid); + create_err(f64::NAN, f64::NAN); create_err(0.0, f64::NAN); create_err(-1.0, f64::NAN); @@ -419,8 +450,6 @@ mod tests { create_err(0.0, -10.0); create_err(-1.0, -10.0); create_err(-10.0, -10.0); - create_err(f64::INFINITY, 0.1); - create_err(0.1, f64::INFINITY); create_err(f64::INFINITY, f64::INFINITY); } diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index fae9848a..201d0cfd 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -2,7 +2,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; use crate::prec; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; /// Implements the [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution) @@ -25,6 +24,32 @@ pub struct Gamma { rate: f64, } +/// Represents the errors that can occur when creating a [`Gamma`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum GammaError { + /// The shape is NaN, zero or less than zero. + ShapeInvalid, + + /// The rate is NaN, zero or less than zero. + RateInvalid, + + /// The shape and rate are both infinite. + ShapeAndRateInfinite, +} + +impl std::fmt::Display for GammaError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + GammaError::ShapeInvalid => write!(f, "Shape is NaN zero, or less than zero."), + GammaError::RateInvalid => write!(f, "Rate is NaN zero, or less than zero."), + GammaError::ShapeAndRateInfinite => write!(f, "Shape and rate are infinite"), + } + } +} + +impl std::error::Error for GammaError {} + impl Gamma { /// Constructs a new gamma distribution with a shape (α) /// of `shape` and a rate (β) of `rate` @@ -45,15 +70,19 @@ impl Gamma { /// result = Gamma::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(shape: f64, rate: f64) -> Result { - if shape.is_nan() - || rate.is_nan() - || shape.is_infinite() && rate.is_infinite() - || shape <= 0.0 - || rate <= 0.0 - { - return Err(StatsError::BadParams); + pub fn new(shape: f64, rate: f64) -> Result { + if shape.is_nan() || shape <= 0.0 { + return Err(GammaError::ShapeInvalid); + } + + if rate.is_nan() || rate <= 0.0 { + return Err(GammaError::RateInvalid); } + + if shape.is_infinite() && rate.is_infinite() { + return Err(GammaError::ShapeAndRateInfinite); + } + Ok(Gamma { shape, rate }) } @@ -406,7 +435,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(shape: f64, rate: f64; Gamma; StatsError); + testing_boiler!(shape: f64, rate: f64; Gamma; GammaError); #[test] fn test_create() { @@ -426,15 +455,20 @@ mod tests { #[test] fn test_bad_create() { let invalid = [ - (0.0, 0.0), - (1.0, f64::NAN), - (1.0, -1.0), - (-1.0, 1.0), - (-1.0, -1.0), - (-1.0, f64::NAN), + (0.0, 0.0, GammaError::ShapeInvalid), + (1.0, f64::NAN, GammaError::RateInvalid), + (1.0, -1.0, GammaError::RateInvalid), + (-1.0, 1.0, GammaError::ShapeInvalid), + (-1.0, -1.0, GammaError::ShapeInvalid), + (-1.0, f64::NAN, GammaError::ShapeInvalid), + ( + f64::INFINITY, + f64::INFINITY, + GammaError::ShapeAndRateInfinite, + ), ]; - for (s, r) in invalid { - create_err(s, r); + for (s, r, err) in invalid { + test_create_err(s, r, err); } } diff --git a/src/distribution/geometric.rs b/src/distribution/geometric.rs index b4bda5a3..dfb28ef6 100644 --- a/src/distribution/geometric.rs +++ b/src/distribution/geometric.rs @@ -1,6 +1,5 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::distributions::OpenClosed01; use rand::Rng; use std::f64; @@ -25,6 +24,24 @@ pub struct Geometric { p: f64, } +/// Represents the errors that can occur when creating a [`Geometric`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum GeometricError { + /// The probability is NaN or not in `(0, 1]`. + ProbabilityInvalid, +} + +impl std::fmt::Display for GeometricError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + GeometricError::ProbabilityInvalid => write!(f, "Probability is NaN or not in (0, 1]"), + } + } +} + +impl std::error::Error for GeometricError {} + impl Geometric { /// Constructs a new shifted geometric distribution with a probability /// of `p` @@ -44,9 +61,9 @@ impl Geometric { /// result = Geometric::new(0.0); /// assert!(result.is_err()); /// ``` - pub fn new(p: f64) -> Result { + pub fn new(p: f64) -> Result { if p <= 0.0 || p > 1.0 || p.is_nan() { - Err(StatsError::BadParams) + Err(GeometricError::ProbabilityInvalid) } else { Ok(Geometric { p }) } @@ -277,7 +294,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(p: f64; Geometric; StatsError); + testing_boiler!(p: f64; Geometric; GeometricError); #[test] fn test_create() { diff --git a/src/distribution/hypergeometric.rs b/src/distribution/hypergeometric.rs index fabb2532..c3ba5a3a 100644 --- a/src/distribution/hypergeometric.rs +++ b/src/distribution/hypergeometric.rs @@ -1,7 +1,6 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::function::factorial; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::cmp; use std::f64; @@ -17,15 +16,37 @@ pub struct Hypergeometric { draws: u64, } +/// Represents the errors that can occur when creating a [`Hypergeometric`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum HypergeometricError { + /// The number of successes is greater than the population. + TooManySuccesses, + + /// The number of draws is greater than the population. + TooManyDraws, +} + +impl std::fmt::Display for HypergeometricError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + HypergeometricError::TooManySuccesses => write!(f, "successes > population"), + HypergeometricError::TooManyDraws => write!(f, "draws > population"), + } + } +} + +impl std::error::Error for HypergeometricError {} + impl Hypergeometric { /// Constructs a new hypergeometric distribution /// with a population (N) of `population`, number /// of successes (K) of `successes`, and number of draws - /// (n) of `draws` + /// (n) of `draws`. /// /// # Errors /// - /// If `successes > population` or `draws > population` + /// If `successes > population` or `draws > population`. /// /// # Examples /// @@ -38,16 +59,24 @@ impl Hypergeometric { /// result = Hypergeometric::new(2, 3, 2); /// assert!(result.is_err()); /// ``` - pub fn new(population: u64, successes: u64, draws: u64) -> Result { - if successes > population || draws > population { - Err(StatsError::BadParams) - } else { - Ok(Hypergeometric { - population, - successes, - draws, - }) + pub fn new( + population: u64, + successes: u64, + draws: u64, + ) -> Result { + if successes > population { + return Err(HypergeometricError::TooManySuccesses); } + + if draws > population { + return Err(HypergeometricError::TooManyDraws); + } + + Ok(Hypergeometric { + population, + successes, + draws, + }) } /// Returns the population size of the hypergeometric @@ -376,7 +405,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(population: u64, successes: u64, draws: u64; Hypergeometric; StatsError); + testing_boiler!(population: u64, successes: u64, draws: u64; Hypergeometric; HypergeometricError); #[test] fn test_create() { @@ -390,8 +419,8 @@ mod tests { #[test] fn test_bad_create() { - create_err(2, 3, 2); - create_err(10, 5, 20); + test_create_err(2, 3, 2, HypergeometricError::TooManySuccesses); + test_create_err(10, 5, 20, HypergeometricError::TooManyDraws); create_err(0, 1, 1); } diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index 4d328d7d..c3b9f22c 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -1,56 +1,5 @@ use num_traits::Num; -/// Returns true if there are no elements in `x` in `arr` -/// such that `x <= 0.0` or `x` is `f64::NAN` and `sum(arr) > 0.0`. -/// IF `incl_zero` is true, it tests for `x < 0.0` instead of `x <= 0.0` -pub fn is_valid_multinomial(arr: &[f64], incl_zero: bool) -> bool { - let mut sum = 0.0; - for &elt in arr { - if incl_zero && elt < 0.0 || !incl_zero && elt <= 0.0 || elt.is_nan() { - return false; - } - sum += elt; - } - sum != 0.0 -} - -#[cfg(feature = "nalgebra")] -use nalgebra::{Dim, OVector}; - -#[cfg(feature = "nalgebra")] -pub fn check_multinomial(arr: &OVector, accept_zeroes: bool) -> crate::Result<()> -where - D: Dim, - nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, -{ - use crate::StatsError; - - if arr.len() < 2 { - return Err(StatsError::BadParams); - } - let mut sum = 0.0; - for &x in arr.iter() { - #[allow(clippy::if_same_then_else)] - if x.is_nan() { - return Err(StatsError::BadParams); - } else if x.is_infinite() { - return Err(StatsError::BadParams); - } else if x < 0.0 { - return Err(StatsError::BadParams); - } else if x == 0.0 && !accept_zeroes { - return Err(StatsError::BadParams); - } else { - sum += x; - } - } - - if sum != 0.0 { - Ok(()) - } else { - Err(StatsError::BadParams) - } -} - /// Implements univariate function bisection searching for criteria /// ```text /// smallest k such that f(k) >= z @@ -240,6 +189,25 @@ pub mod test { } } + /// Purposely fails creating a distribution with the given + /// parameters and compares the returned error to `expected`. + /// + /// Panics if `::new` succeeds. + #[allow(dead_code)] + fn test_create_err($($arg_name: $arg_ty),+, expected: $dist_err) + { + let err = create_err($($arg_name),+); + if err != expected { + panic!( + "{}::new was expected to fail with error {:?}, but failed with error {:?} for {}", + stringify!($dist), + expected, + err, + make_param_text($($arg_name),+) + ) + } + } + /// Gets a value for the given parameters by calling `create_and_get` /// and asserts that it is [`NAN`]. /// @@ -277,92 +245,101 @@ pub mod test { } pub mod boiler_tests { - use crate::distribution::Binomial; + use crate::distribution::{Beta, BetaError}; use crate::statistics::*; - use crate::StatsError; - testing_boiler!(p: f64, n: u64; Binomial; StatsError); + testing_boiler!(shape_a: f64, shape_b: f64; Beta; BetaError); #[test] fn create_ok_success() { - let b = create_ok(0.8, 1200); - assert_eq!(b.p(), 0.8); - assert_eq!(b.n(), 1200); + let b = create_ok(0.8, 1.2); + assert_eq!(b.shape_a(), 0.8); + assert_eq!(b.shape_b(), 1.2); } #[test] #[should_panic] fn create_err_failure() { - create_err(0.8, 1200); + create_err(0.8, 1.2); } #[test] fn create_err_success() { - let err = create_err(-0.5, 1000); - assert_eq!(err, StatsError::BadParams); + let err = create_err(-0.5, 1.2); + assert_eq!(err, BetaError::ShapeAInvalid); } #[test] #[should_panic] fn create_ok_failure() { - create_ok(-0.5, 1000); + create_ok(-0.5, 1.2); } #[test] fn test_exact_success() { - test_exact(0.0, 4, 0.0, |dist| dist.mean().unwrap()); + test_exact(1.5, 1.5, 0.5, |dist| dist.mode().unwrap()); } #[test] #[should_panic] fn test_exact_failure() { - test_exact(0.3, 3, 0.9, |dist| dist.mean().unwrap()); + test_exact(1.2, 1.4, 0.333333333333, |dist| dist.mode().unwrap()); } #[test] fn test_relative_success() { - test_relative(0.3, 3, 0.9, |dist| dist.mean().unwrap()); + test_relative(1.2, 1.4, 0.333333333333, |dist| dist.mode().unwrap()); } #[test] #[should_panic] fn test_relative_failure() { - test_relative(0.3, 3, 0.8, |dist| dist.mean().unwrap()); + test_relative(1.2, 1.4, 0.333, |dist| dist.mode().unwrap()); } #[test] fn test_absolute_success() { - test_absolute(0.3, 3, 0.9, 1e-15, |dist| dist.mean().unwrap()); + test_absolute(1.2, 1.4, 0.333333333333, 1e-12, |dist| dist.mode().unwrap()); } #[test] #[should_panic] fn test_absolute_failure() { - test_absolute(0.3, 3, 0.9, 1e-17, |dist| dist.mean().unwrap()); + test_absolute(1.2, 1.4, 0.333333333333, 1e-15, |dist| dist.mode().unwrap()); + } + + #[test] + fn test_create_err_success() { + test_create_err(0.0, 0.5, BetaError::ShapeAInvalid); + } + + #[test] + #[should_panic] + fn test_create_err_failure() { + test_create_err(0.0, 0.5, BetaError::BothShapesInfinite); } #[test] fn test_is_nan_success() { - // Not sure that any Binomial API can return a NaN, so we force the issue - test_is_nan(0.8, 1200, |_| f64::NAN); + // Not sure that any Beta API can return a NaN, so we force the issue + test_is_nan(0.8, 1.2, |_| f64::NAN); } #[test] #[should_panic] fn test_is_nan_failure() { - test_is_nan(0.8, 1200, |dist| dist.mean().unwrap()); + test_is_nan(0.8, 1.2, |dist| dist.mean().unwrap()); } #[test] fn test_is_none_success() { - // Same as test_is_nan_success, force returning `None` here - test_none(0.8, 1200, |_| Option::::None); + test_none(f64::INFINITY, 1.2, |dist| dist.entropy()); } #[test] #[should_panic] fn test_is_none_failure() { - test_none(0.8, 1200, |dist| dist.mean()); + test_none(0.8, 1.2, |dist| dist.mean()); } } @@ -471,31 +448,6 @@ pub mod test { check_sum_pmf_is_cdf(dist, x_max); } - #[cfg(feature = "nalgebra")] - #[test] - fn test_is_valid_multinomial() { - use std::f64; - - let invalid = [1.0, f64::NAN, 3.0]; - assert!(!is_valid_multinomial(&invalid, true)); - assert!(check_multinomial(&invalid.to_vec().into(), true).is_err()); - let invalid2 = [-2.0, 5.0, 1.0, 6.2]; - assert!(!is_valid_multinomial(&invalid2, true)); - assert!(check_multinomial(&invalid2.to_vec().into(), true).is_err()); - let invalid3 = [0.0, 0.0, 0.0]; - assert!(!is_valid_multinomial(&invalid3, true)); - assert!(check_multinomial(&invalid3.to_vec().into(), true).is_err()); - let valid = [5.2, 0.0, 1e-15, 1000000.12]; - assert!(is_valid_multinomial(&valid, true)); - assert!(check_multinomial(&valid.to_vec().into(), true).is_ok()); - } - - #[test] - fn test_is_valid_multinomial_no_zero() { - let invalid = [5.2, 0.0, 1e-15, 1000000.12]; - assert!(!is_valid_multinomial(&invalid, false)); - } - #[test] fn test_integer_bisection() { fn search(z: usize, data: &[usize]) -> Option { diff --git a/src/distribution/inverse_gamma.rs b/src/distribution/inverse_gamma.rs index ec70f6f2..1f1bee4b 100644 --- a/src/distribution/inverse_gamma.rs +++ b/src/distribution/inverse_gamma.rs @@ -1,7 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -26,6 +25,32 @@ pub struct InverseGamma { rate: f64, } +/// Represents the errors that can occur when creating an [`InverseGamma`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum InverseGammaError { + /// The shape is NaN, infinite, zero or less than zero. + ShapeInvalid, + + /// The rate is NaN, infinite, zero or less than zero. + RateInvalid, +} + +impl std::fmt::Display for InverseGammaError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + InverseGammaError::ShapeInvalid => { + write!(f, "Shape is NaN, infinite, zero or less than zero") + } + InverseGammaError::RateInvalid => { + write!(f, "Rate is NaN, infinite, zero or less than zero") + } + } + } +} + +impl std::error::Error for InverseGammaError {} + impl InverseGamma { /// Constructs a new inverse gamma distribution with a shape (α) /// of `shape` and a rate (β) of `rate` @@ -46,16 +71,16 @@ impl InverseGamma { /// result = InverseGamma::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(shape: f64, rate: f64) -> Result { - let is_nan = shape.is_nan() || rate.is_nan(); - match (shape, rate, is_nan) { - (_, _, true) => Err(StatsError::BadParams), - (_, _, false) if shape <= 0.0 || rate <= 0.0 => Err(StatsError::BadParams), - (_, _, false) if shape.is_infinite() || rate.is_infinite() => { - Err(StatsError::BadParams) - } - (_, _, false) => Ok(InverseGamma { shape, rate }), + pub fn new(shape: f64, rate: f64) -> Result { + if shape.is_nan() || shape.is_infinite() || shape <= 0.0 { + return Err(InverseGammaError::ShapeInvalid); } + + if rate.is_nan() || rate.is_infinite() || rate <= 0.0 { + return Err(InverseGammaError::RateInvalid); + } + + Ok(InverseGamma { shape, rate }) } /// Returns the shape (α) of the inverse gamma distribution @@ -317,7 +342,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(shape: f64, rate: f64; InverseGamma; StatsError); + testing_boiler!(shape: f64, rate: f64; InverseGamma; InverseGammaError); #[test] fn test_create() { @@ -327,13 +352,13 @@ mod tests { #[test] fn test_bad_create() { - create_err(0.0, 1.0); + test_create_err(0.0, 1.0, InverseGammaError::ShapeInvalid); + test_create_err(1.0, -1.0, InverseGammaError::RateInvalid); create_err(-1.0, 1.0); create_err(-100.0, 1.0); create_err(f64::NEG_INFINITY, 1.0); create_err(f64::NAN, 1.0); create_err(1.0, 0.0); - create_err(1.0, -1.0); create_err(1.0, -100.0); create_err(1.0, f64::NEG_INFINITY); create_err(1.0, f64::NAN); diff --git a/src/distribution/laplace.rs b/src/distribution/laplace.rs index 5466e89f..d4f6fc16 100644 --- a/src/distribution/laplace.rs +++ b/src/distribution/laplace.rs @@ -1,6 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::{Distribution, Max, Median, Min, Mode}; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -23,6 +22,28 @@ pub struct Laplace { scale: f64, } +/// Represents the errors that can occur when creating a [`Laplace`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum LaplaceError { + /// The location is NaN. + LocationInvalid, + + /// The scale is NaN, zero or less than zero. + ScaleInvalid, +} + +impl std::fmt::Display for LaplaceError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + LaplaceError::LocationInvalid => write!(f, "Location is NaN"), + LaplaceError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero"), + } + } +} + +impl std::error::Error for LaplaceError {} + impl Laplace { /// Constructs a new laplace distribution with the given /// location and scale. @@ -42,12 +63,16 @@ impl Laplace { /// result = Laplace::new(0.0, -1.0); /// assert!(result.is_err()); /// ``` - pub fn new(location: f64, scale: f64) -> Result { - if location.is_nan() || scale.is_nan() || scale <= 0.0 { - Err(StatsError::BadParams) - } else { - Ok(Laplace { location, scale }) + pub fn new(location: f64, scale: f64) -> Result { + if location.is_nan() { + return Err(LaplaceError::LocationInvalid); } + + if scale.is_nan() || scale <= 0.0 { + return Err(LaplaceError::ScaleInvalid); + } + + Ok(Laplace { location, scale }) } /// Returns the location of the laplace distribution @@ -304,7 +329,7 @@ mod tests { use crate::testing_boiler; - testing_boiler!(location: f64, scale: f64; Laplace; StatsError); + testing_boiler!(location: f64, scale: f64; Laplace; LaplaceError); // A wrapper for the `assert_relative_eq!` macro from the approx crate. // @@ -332,8 +357,8 @@ mod tests { #[test] fn test_bad_create() { - create_err(2.0, -1.0); - create_err(f64::NAN, 1.0); + test_create_err(2.0, -1.0, LaplaceError::ScaleInvalid); + test_create_err(f64::NAN, 1.0, LaplaceError::LocationInvalid); create_err(f64::NAN, -1.0); } diff --git a/src/distribution/log_normal.rs b/src/distribution/log_normal.rs index cce0f11d..9075bfd2 100644 --- a/src/distribution/log_normal.rs +++ b/src/distribution/log_normal.rs @@ -1,7 +1,7 @@ +use crate::consts; use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::erf; use crate::statistics::*; -use crate::{consts, Result, StatsError}; use rand::Rng; use std::f64; @@ -26,6 +26,28 @@ pub struct LogNormal { scale: f64, } +/// Represents the errors that can occur when creating a [`LogNormal`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum LogNormalError { + /// The location is NaN. + LocationInvalid, + + /// The scale is NaN, zero or less than zero. + ScaleInvalid, +} + +impl std::fmt::Display for LogNormalError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + LogNormalError::LocationInvalid => write!(f, "Location is NaN"), + LogNormalError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero"), + } + } +} + +impl std::error::Error for LogNormalError {} + impl LogNormal { /// Constructs a new log-normal distribution with a location of `location` /// and a scale of `scale` @@ -46,12 +68,16 @@ impl LogNormal { /// result = LogNormal::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(location: f64, scale: f64) -> Result { - if location.is_nan() || scale.is_nan() || scale <= 0.0 { - Err(StatsError::BadParams) - } else { - Ok(LogNormal { location, scale }) + pub fn new(location: f64, scale: f64) -> Result { + if location.is_nan() { + return Err(LogNormalError::LocationInvalid); } + + if scale.is_nan() || scale <= 0.0 { + return Err(LogNormalError::ScaleInvalid); + } + + Ok(LogNormal { location, scale }) } } @@ -309,7 +335,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(mean: f64, std_dev: f64; LogNormal; StatsError); + testing_boiler!(location: f64, scale: f64; LogNormal; LogNormalError); #[test] fn test_create() { @@ -322,9 +348,9 @@ mod tests { #[test] fn test_bad_create() { + test_create_err(f64::NAN, 1.0, LogNormalError::LocationInvalid); + test_create_err(1.0, f64::NAN, LogNormalError::ScaleInvalid); create_err(0.0, 0.0); - create_err(f64::NAN, 1.0); - create_err(1.0, f64::NAN); create_err(f64::NAN, f64::NAN); create_err(1.0, -1.0); } diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index 6e43db8e..8955ed63 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -6,40 +6,40 @@ use ::num_traits::{Float, Num}; use num_traits::NumAssignOps; pub use self::bernoulli::Bernoulli; -pub use self::beta::Beta; -pub use self::binomial::Binomial; -pub use self::categorical::Categorical; -pub use self::cauchy::Cauchy; -pub use self::chi::Chi; +pub use self::beta::{Beta, BetaError}; +pub use self::binomial::{Binomial, BinomialError}; +pub use self::categorical::{Categorical, CategoricalError}; +pub use self::cauchy::{Cauchy, CauchyError}; +pub use self::chi::{Chi, ChiError}; pub use self::chi_squared::ChiSquared; -pub use self::dirac::Dirac; +pub use self::dirac::{Dirac, DiracError}; #[cfg(feature = "nalgebra")] -pub use self::dirichlet::Dirichlet; -pub use self::discrete_uniform::DiscreteUniform; +pub use self::dirichlet::{Dirichlet, DirichletError}; +pub use self::discrete_uniform::{DiscreteUniform, DiscreteUniformError}; pub use self::empirical::Empirical; pub use self::erlang::Erlang; -pub use self::exponential::Exp; -pub use self::fisher_snedecor::FisherSnedecor; -pub use self::gamma::Gamma; -pub use self::geometric::Geometric; -pub use self::hypergeometric::Hypergeometric; -pub use self::inverse_gamma::InverseGamma; -pub use self::laplace::Laplace; -pub use self::log_normal::LogNormal; +pub use self::exponential::{Exp, ExpError}; +pub use self::fisher_snedecor::{FisherSnedecor, FisherSnedecorError}; +pub use self::gamma::{Gamma, GammaError}; +pub use self::geometric::{Geometric, GeometricError}; +pub use self::hypergeometric::{Hypergeometric, HypergeometricError}; +pub use self::inverse_gamma::{InverseGamma, InverseGammaError}; +pub use self::laplace::{Laplace, LaplaceError}; +pub use self::log_normal::{LogNormal, LogNormalError}; #[cfg(feature = "nalgebra")] -pub use self::multinomial::Multinomial; +pub use self::multinomial::{Multinomial, MultinomialError}; #[cfg(feature = "nalgebra")] -pub use self::multivariate_normal::MultivariateNormal; +pub use self::multivariate_normal::{MultivariateNormal, MultivariateNormalError}; #[cfg(feature = "nalgebra")] -pub use self::multivariate_students_t::MultivariateStudent; -pub use self::negative_binomial::NegativeBinomial; -pub use self::normal::Normal; -pub use self::pareto::Pareto; -pub use self::poisson::Poisson; -pub use self::students_t::StudentsT; -pub use self::triangular::Triangular; -pub use self::uniform::Uniform; -pub use self::weibull::Weibull; +pub use self::multivariate_students_t::{MultivariateStudent, MultivariateStudentError}; +pub use self::negative_binomial::{NegativeBinomial, NegativeBinomialError}; +pub use self::normal::{Normal, NormalError}; +pub use self::pareto::{Pareto, ParetoError}; +pub use self::poisson::{Poisson, PoissonError}; +pub use self::students_t::{StudentsT, StudentsTError}; +pub use self::triangular::{Triangular, TriangularError}; +pub use self::uniform::{Uniform, UniformError}; +pub use self::weibull::{Weibull, WeibullError}; mod bernoulli; mod beta; diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index dc402050..4fcaca1a 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -1,7 +1,6 @@ use crate::distribution::Discrete; use crate::function::factorial; use crate::statistics::*; -use crate::Result; use nalgebra::{Const, DVector, Dim, Dyn, OMatrix, OVector}; use rand::Rng; @@ -33,6 +32,35 @@ where n: u64, } +/// Represents the errors that can occur when creating a [`Multinomial`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum MultinomialError { + /// Fewer than two probabilities. + NotEnoughProbabilities, + + /// The sum of all probabilities is zero. + ProbabilitySumZero, + + /// At least one probability is NaN, infinite or less than zero. + ProbabilityInvalid, +} + +impl std::fmt::Display for MultinomialError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + MultinomialError::NotEnoughProbabilities => write!(f, "Fewer than two probabilities"), + MultinomialError::ProbabilitySumZero => write!(f, "The probabilities sum up to zero"), + MultinomialError::ProbabilityInvalid => write!( + f, + "At least one probability is NaN, infinity or less than zero" + ), + } + } +} + +impl std::error::Error for MultinomialError {} + impl Multinomial { /// Constructs a new multinomial distribution with probabilities `p` /// and `n` number of trials. @@ -57,7 +85,7 @@ impl Multinomial { /// result = Multinomial::new(vec![0.0, -1.0, 2.0], 3); /// assert!(result.is_err()); /// ``` - pub fn new(p: Vec, n: u64) -> Result { + pub fn new(p: Vec, n: u64) -> Result { Self::new_from_nalgebra(p.into(), n) } } @@ -67,14 +95,26 @@ where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { - pub fn new_from_nalgebra(mut p: OVector, n: u64) -> Result { - match super::internal::check_multinomial(&p, true) { - Err(e) => Err(e), - Ok(_) => { - p.unscale_mut(p.lp_norm(1)); - Ok(Self { p, n }) + pub fn new_from_nalgebra(mut p: OVector, n: u64) -> Result { + if p.len() < 2 { + return Err(MultinomialError::NotEnoughProbabilities); + } + + let mut sum = 0.0; + for &val in &p { + if val.is_nan() || val < 0.0 { + return Err(MultinomialError::ProbabilityInvalid); } + + sum += val; + } + + if sum == 0.0 { + return Err(MultinomialError::ProbabilitySumZero); } + + p.unscale_mut(p.lp_norm(1)); + Ok(Self { p, n }) } /// Returns the probabilities of the multinomial @@ -295,7 +335,7 @@ where #[cfg(test)] mod tests { use crate::{ - distribution::{Discrete, Multinomial}, + distribution::{Discrete, Multinomial, MultinomialError}, statistics::{MeanN, VarianceN}, }; use nalgebra::{dmatrix, dvector, vector, DimMin, Dyn, OVector}; @@ -311,7 +351,7 @@ mod tests { mvn.unwrap() } - fn bad_create_case(p: OVector, n: u64) -> crate::StatsError + fn bad_create_case(p: OVector, n: u64) -> MultinomialError where D: DimMin, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, @@ -344,18 +384,23 @@ mod tests { #[test] fn test_bad_create() { + assert_eq!( + bad_create_case(vector![0.5], 4), + MultinomialError::NotEnoughProbabilities, + ); + assert_eq!( bad_create_case(vector![-1.0, 2.0], 4), - crate::StatsError::BadParams + MultinomialError::ProbabilityInvalid, ); assert_eq!( bad_create_case(vector![0.0, 0.0], 4), - crate::StatsError::BadParams + MultinomialError::ProbabilitySumZero, ); assert_eq!( bad_create_case(vector![1.0, f64::NAN], 4), - crate::StatsError::BadParams + MultinomialError::ProbabilityInvalid, ); } diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 61ccfa57..0da93cd4 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -108,15 +108,53 @@ where pdf_const: f64, } +/// Represents the errors that can occur when creating a [`MultivariateNormal`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum MultivariateNormalError { + /// The covariance matrix is asymmetric or contains a NaN. + CovInvalid, + + /// The mean vector contains a NaN. + MeanInvalid, + + /// The amount of rows in the vector of means is not equal to the amount + /// of rows in the covariance matrix. + DimensionMismatch, + + /// After all other validation, computing the Cholesky decomposition failed. + CholeskyFailed, +} + +impl std::fmt::Display for MultivariateNormalError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + MultivariateNormalError::CovInvalid => { + write!(f, "Covariance matrix is asymmetric or contains a NaN") + } + MultivariateNormalError::MeanInvalid => write!(f, "Mean vector contains a NaN"), + MultivariateNormalError::DimensionMismatch => write!( + f, + "Mean vector and covariance matrix do not have the same number of rows" + ), + MultivariateNormalError::CholeskyFailed => { + write!(f, "Computing the Cholesky decomposition failed") + } + } + } +} + +impl std::error::Error for MultivariateNormalError {} + impl MultivariateNormal { - /// Constructs a new multivariate normal distribution with a mean of `mean` + /// Constructs a new multivariate normal distribution with a mean of `mean` /// and covariance matrix `cov` /// /// # Errors /// /// Returns an error if the given covariance matrix is not /// symmetric or positive-definite - pub fn new(mean: Vec, cov: Vec) -> Result { + pub fn new(mean: Vec, cov: Vec) -> Result { let mean = DVector::from_vec(mean); let cov = DMatrix::from_vec(mean.len(), mean.len(), cov); MultivariateNormal::new_from_nalgebra(mean, cov) @@ -141,24 +179,31 @@ where pub fn new_from_nalgebra( mean: OVector, cov: OMatrix, - ) -> Result { - // Check that the provided covariance matrix is symmetric - if cov.lower_triangle() != cov.upper_triangle().transpose() - // Check that mean and covariance do not contain NaN - || mean.iter().any(|f| f.is_nan()) + ) -> Result { + if mean.iter().any(|f| f.is_nan()) { + return Err(MultivariateNormalError::MeanInvalid); + } + + if !cov.is_square() + || cov.lower_triangle() != cov.upper_triangle().transpose() || cov.iter().any(|f| f.is_nan()) - // Check that the dimensions match - || mean.nrows() != cov.nrows() || cov.nrows() != cov.ncols() { - return Err(StatsError::BadParams); + return Err(MultivariateNormalError::CovInvalid); } + + // Compare number of rows + if mean.shape_generic().0 != cov.shape_generic().0 { + return Err(MultivariateNormalError::DimensionMismatch); + } + // Store the Cholesky decomposition of the covariance matrix // for sampling match Cholesky::new(cov.clone()) { - None => Err(StatsError::BadParams), + None => Err(MultivariateNormalError::CholeskyFailed), Some(cholesky_decomp) => { let precision = cholesky_decomp.inverse(); Ok(MultivariateNormal { + // .unwrap() because prerequisites are already checked above pdf_const: density_distribution_pdf_const(&mean, &cov).unwrap(), cov_chol_decomp: cholesky_decomp.unpack(), mu: mean, diff --git a/src/distribution/multivariate_students_t.rs b/src/distribution/multivariate_students_t.rs index 3758d328..143237b8 100644 --- a/src/distribution/multivariate_students_t.rs +++ b/src/distribution/multivariate_students_t.rs @@ -2,7 +2,6 @@ use crate::distribution::Continuous; use crate::distribution::{ChiSquared, Normal}; use crate::function::gamma; use crate::statistics::{Max, MeanN, Min, Mode, VarianceN}; -use crate::{Result, StatsError}; use nalgebra::{Cholesky, Const, DMatrix, Dim, DimMin, Dyn, OMatrix, OVector}; use rand::Rng; use std::f64::consts::PI; @@ -39,6 +38,53 @@ where ln_pdf_const: f64, } +/// Represents the errors that can occur when creating a [`MultivariateStudent`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum MultivariateStudentError { + /// The scale matrix is asymmetric or contains a NaN. + ScaleInvalid, + + /// The location vector contains a NaN. + LocationInvalid, + + /// The degrees of freedom are NaN, zero or less than zero. + FreedomInvalid, + + /// The amount of rows in the location vector is not equal to the amount + /// of rows in the scale matrix. + DimensionMismatch, + + /// After all other validation, computing the Cholesky decomposition failed. + /// This means that the scale matrix is not definite-positive. + CholeskyFailed, +} + +impl std::fmt::Display for MultivariateStudentError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + MultivariateStudentError::ScaleInvalid => { + write!(f, "Scale matrix is asymmetric or contains a NaN") + } + MultivariateStudentError::LocationInvalid => { + write!(f, "Location vector contains a NaN") + } + MultivariateStudentError::FreedomInvalid => { + write!(f, "Degrees of freedom are NaN, zero or less than zero") + } + MultivariateStudentError::DimensionMismatch => write!( + f, + "Location vector and scale matrix do not have the same number of rows" + ), + MultivariateStudentError::CholeskyFailed => { + write!(f, "Computing the Cholesky decomposition failed") + } + } + } +} + +impl std::error::Error for MultivariateStudentError {} + impl MultivariateStudent { /// Constructs a new multivariate students t distribution with a location of `location`, /// scale matrix `scale` and `freedom` degrees of freedom. @@ -47,7 +93,11 @@ impl MultivariateStudent { /// /// Returns `StatsError::BadParams` if the scale matrix is not symmetric-positive /// definite and `StatsError::ArgMustBePositive` if freedom is non-positive. - pub fn new(location: Vec, scale: Vec, freedom: f64) -> Result { + pub fn new( + location: Vec, + scale: Vec, + freedom: f64, + ) -> Result { let dim = location.len(); Self::new_from_nalgebra(location.into(), DMatrix::from_vec(dim, dim, scale), freedom) } @@ -69,26 +119,26 @@ where location: OVector, scale: OMatrix, freedom: f64, - ) -> Result { + ) -> Result { let dim = location.len(); - // Check that the provided scale matrix is symmetric - if scale.lower_triangle() != scale.upper_triangle().transpose() - // Check that mean and covariance do not contain NaN - || location.iter().any(|f| f.is_nan()) + if location.iter().any(|f| f.is_nan()) { + return Err(MultivariateStudentError::LocationInvalid); + } + + if !scale.is_square() + || scale.lower_triangle() != scale.upper_triangle().transpose() || scale.iter().any(|f| f.is_nan()) - // Check that the dimensions match - || location.nrows() != scale.nrows() || scale.nrows() != scale.ncols() - // Check that the degrees of freedom is not NaN - || freedom.is_nan() { - return Err(StatsError::BadParams); + return Err(MultivariateStudentError::ScaleInvalid); } - // Check that degrees of freedom is positive - if freedom <= 0. { - return Err(StatsError::ArgMustBePositive( - "Degrees of freedom must be positive", - )); + + if freedom.is_nan() || freedom <= 0.0 { + return Err(MultivariateStudentError::FreedomInvalid); + } + + if location.nrows() != scale.nrows() { + return Err(MultivariateStudentError::DimensionMismatch); } let scale_det = scale.determinant(); @@ -98,7 +148,7 @@ where - 0.5 * scale_det.ln(); match Cholesky::new(scale.clone()) { - None => Err(StatsError::BadParams), // Scale matrix is not positive definite + None => Err(MultivariateStudentError::CholeskyFailed), Some(cholesky_decomp) => { let precision = cholesky_decomp.inverse(); Ok(MultivariateStudent { diff --git a/src/distribution/negative_binomial.rs b/src/distribution/negative_binomial.rs index d461af6e..a3f4ffed 100644 --- a/src/distribution/negative_binomial.rs +++ b/src/distribution/negative_binomial.rs @@ -1,7 +1,6 @@ use crate::distribution::{self, poisson, Discrete, DiscreteCDF}; use crate::function::{beta, gamma}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -41,6 +40,28 @@ pub struct NegativeBinomial { p: f64, } +/// Represents the errors that can occur when creating a [`NegativeBinomial`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum NegativeBinomialError { + /// `r` is NaN or less than zero. + RInvalid, + + /// `p` is NaN or not in `[0, 1]`. + PInvalid, +} + +impl std::fmt::Display for NegativeBinomialError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + NegativeBinomialError::RInvalid => write!(f, "r is NaN or less than zero"), + NegativeBinomialError::PInvalid => write!(f, "p is NaN or not in [0, 1]"), + } + } +} + +impl std::error::Error for NegativeBinomialError {} + impl NegativeBinomial { /// Constructs a new negative binomial distribution with parameters `r` /// and `p`. When `r` is an integer, the negative binomial distribution @@ -64,12 +85,16 @@ impl NegativeBinomial { /// result = NegativeBinomial::new(-0.5, 5.0); /// assert!(result.is_err()); /// ``` - pub fn new(r: f64, p: f64) -> Result { - if p.is_nan() || !(0.0..=1.0).contains(&p) || r.is_nan() || r < 0.0 { - Err(StatsError::BadParams) - } else { - Ok(NegativeBinomial { r, p }) + pub fn new(r: f64, p: f64) -> Result { + if r.is_nan() || r < 0.0 { + return Err(NegativeBinomialError::RInvalid); } + + if p.is_nan() || !(0.0..=1.0).contains(&p) { + return Err(NegativeBinomialError::PInvalid); + } + + Ok(NegativeBinomial { r, p }) } /// Returns the probability of success `p` of a single @@ -295,7 +320,7 @@ mod tests { use crate::distribution::internal::test; use crate::testing_boiler; - testing_boiler!(r: f64, p: f64; NegativeBinomial; StatsError); + testing_boiler!(r: f64, p: f64; NegativeBinomial; NegativeBinomialError); #[test] fn test_create() { @@ -306,8 +331,8 @@ mod tests { #[test] fn test_bad_create() { - create_err(f64::NAN, 1.0); - create_err(0.0, f64::NAN); + test_create_err(f64::NAN, 1.0, NegativeBinomialError::RInvalid); + test_create_err(0.0, f64::NAN, NegativeBinomialError::PInvalid); create_err(-1.0, 1.0); create_err(2.0, 2.0); } diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index 74a4fde4..71f50eda 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -1,7 +1,7 @@ +use crate::consts; use crate::distribution::{ziggurat, Continuous, ContinuousCDF}; use crate::function::erf; use crate::statistics::*; -use crate::{consts, Result, StatsError}; use rand::Rng; use std::f64; @@ -24,6 +24,30 @@ pub struct Normal { std_dev: f64, } +/// Represents the errors that can occur when creating a [`Normal`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum NormalError { + /// The mean is NaN. + MeanInvalid, + + /// The standard deviation is NaN, zero or less than zero. + StandardDeviationInvalid, +} + +impl std::fmt::Display for NormalError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + NormalError::MeanInvalid => write!(f, "Mean is NaN"), + NormalError::StandardDeviationInvalid => { + write!(f, "Standard deviation is NaN, zero or less than zero") + } + } + } +} + +impl std::error::Error for NormalError {} + impl Normal { /// Constructs a new normal distribution with a mean of `mean` /// and a standard deviation of `std_dev` @@ -44,12 +68,16 @@ impl Normal { /// result = Normal::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(mean: f64, std_dev: f64) -> Result { - if mean.is_nan() || std_dev.is_nan() || std_dev <= 0.0 { - Err(StatsError::BadParams) - } else { - Ok(Normal { mean, std_dev }) + pub fn new(mean: f64, std_dev: f64) -> Result { + if mean.is_nan() { + return Err(NormalError::MeanInvalid); } + + if std_dev.is_nan() || std_dev <= 0.0 { + return Err(NormalError::StandardDeviationInvalid); + } + + Ok(Normal { mean, std_dev }) } /// Constructs a new standard normal distribution with a mean of 0 @@ -338,7 +366,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(mean: f64, std_dev: f64; Normal; StatsError); + testing_boiler!(mean: f64, std_dev: f64; Normal; NormalError); #[test] fn test_create() { @@ -351,9 +379,9 @@ mod tests { #[test] fn test_bad_create() { + test_create_err(f64::NAN, 1.0, NormalError::MeanInvalid); + test_create_err(1.0, f64::NAN, NormalError::StandardDeviationInvalid); create_err(0.0, 0.0); - create_err(f64::NAN, 1.0); - create_err(1.0, f64::NAN); create_err(f64::NAN, f64::NAN); create_err(1.0, -1.0); } diff --git a/src/distribution/pareto.rs b/src/distribution/pareto.rs index c983289b..06f9e176 100644 --- a/src/distribution/pareto.rs +++ b/src/distribution/pareto.rs @@ -1,6 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::distributions::OpenClosed01; use rand::Rng; use std::f64; @@ -25,6 +24,28 @@ pub struct Pareto { shape: f64, } +/// Represents the errors that can occur when creating a [`Pareto`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum ParetoError { + /// The scale is NaN, zero or less than zero. + ScaleInvalid, + + /// The shape is NaN, zero or less than zero. + ShapeInvalid, +} + +impl std::fmt::Display for ParetoError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + ParetoError::ScaleInvalid => write!(f, "Scale is NaN, zero, or less than zero"), + ParetoError::ShapeInvalid => write!(f, "Shape is NaN, zero, or less than zero"), + } + } +} + +impl std::error::Error for ParetoError {} + impl Pareto { /// Constructs a new Pareto distribution with scale `scale`, and `shape` /// shape. @@ -45,13 +66,16 @@ impl Pareto { /// result = Pareto::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(scale: f64, shape: f64) -> Result { - let is_nan = scale.is_nan() || shape.is_nan(); - if is_nan || scale <= 0.0 || shape <= 0.0 { - Err(StatsError::BadParams) - } else { - Ok(Pareto { scale, shape }) + pub fn new(scale: f64, shape: f64) -> Result { + if scale.is_nan() || scale <= 0.0 { + return Err(ParetoError::ScaleInvalid); } + + if shape.is_nan() || shape <= 0.0 { + return Err(ParetoError::ShapeInvalid); + } + + Ok(Pareto { scale, shape }) } /// Returns the scale of the Pareto distribution @@ -358,7 +382,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(scale: f64, shape: f64; Pareto; StatsError); + testing_boiler!(scale: f64, shape: f64; Pareto; ParetoError); #[test] fn test_create() { @@ -372,9 +396,9 @@ mod tests { #[test] fn test_bad_create() { + test_create_err(1.0, -1.0, ParetoError::ShapeInvalid); + test_create_err(-1.0, 1.0, ParetoError::ScaleInvalid); create_err(0.0, 0.0); - create_err(1.0, -1.0); - create_err(-1.0, 1.0); create_err(-1.0, -1.0); create_err(f64::NAN, 1.0); create_err(1.0, f64::NAN); diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index 4588d00e..0780cf55 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -1,7 +1,6 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::function::{factorial, gamma}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -24,6 +23,24 @@ pub struct Poisson { lambda: f64, } +/// Represents the errors that can occur when creating a [`Poisson`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum PoissonError { + /// The lambda is NaN, zero or less than zero. + LambdaInvalid, +} + +impl std::fmt::Display for PoissonError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + PoissonError::LambdaInvalid => write!(f, "Lambda is NaN, zero or less than zero"), + } + } +} + +impl std::error::Error for PoissonError {} + impl Poisson { /// Constructs a new poisson distribution with a rate (λ) /// of `lambda` @@ -43,9 +60,9 @@ impl Poisson { /// result = Poisson::new(0.0); /// assert!(result.is_err()); /// ``` - pub fn new(lambda: f64) -> Result { + pub fn new(lambda: f64) -> Result { if lambda.is_nan() || lambda <= 0.0 { - Err(StatsError::BadParams) + Err(PoissonError::LambdaInvalid) } else { Ok(Poisson { lambda }) } @@ -308,7 +325,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(lambda: f64; Poisson; StatsError); + testing_boiler!(lambda: f64; Poisson; PoissonError); #[test] fn test_create() { diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index a362e071..b9543cc1 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -1,7 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::{beta, gamma}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -26,15 +25,42 @@ pub struct StudentsT { freedom: f64, } +/// Represents the errors that can occur when creating a [`StudentsT`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum StudentsTError { + /// The location is NaN. + LocationInvalid, + + /// The scale is NaN, zero or less than zero. + ScaleInvalid, + + /// The degrees of freedom are NaN, zero or less than zero. + FreedomInvalid, +} + +impl std::fmt::Display for StudentsTError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + StudentsTError::LocationInvalid => write!(f, "Location is NaN"), + StudentsTError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero"), + StudentsTError::FreedomInvalid => { + write!(f, "Degrees of freedom are NaN, zero or less than zero") + } + } + } +} + +impl std::error::Error for StudentsTError {} + impl StudentsT { /// Constructs a new student's t-distribution with location `location`, - /// scale `scale`, - /// and `freedom` freedom. + /// scale `scale`, and `freedom` freedom. /// /// # Errors /// /// Returns an error if any of `location`, `scale`, or `freedom` are `NaN`. - /// Returns an error if `scale <= 0.0` or `freedom <= 0.0` + /// Returns an error if `scale <= 0.0` or `freedom <= 0.0`. /// /// # Examples /// @@ -47,17 +73,24 @@ impl StudentsT { /// result = StudentsT::new(0.0, 0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(location: f64, scale: f64, freedom: f64) -> Result { - let is_nan = location.is_nan() || scale.is_nan() || freedom.is_nan(); - if is_nan || scale <= 0.0 || freedom <= 0.0 { - Err(StatsError::BadParams) - } else { - Ok(StudentsT { - location, - scale, - freedom, - }) + pub fn new(location: f64, scale: f64, freedom: f64) -> Result { + if location.is_nan() { + return Err(StudentsTError::LocationInvalid); + } + + if scale.is_nan() || scale <= 0.0 { + return Err(StudentsTError::ScaleInvalid); + } + + if freedom.is_nan() || freedom <= 0.0 { + return Err(StudentsTError::FreedomInvalid); } + + Ok(StudentsT { + location, + scale, + freedom, + }) } /// Returns the location of the student's t-distribution @@ -426,7 +459,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(location: f64, scale: f64, freedom: f64; StudentsT; StatsError); + testing_boiler!(location: f64, scale: f64, freedom: f64; StudentsT; StudentsTError); #[test] fn test_create() { @@ -444,11 +477,17 @@ mod tests { #[test] fn test_bad_create() { - create_err(f64::NAN, 1.0, 1.0); - create_err(0.0, f64::NAN, 1.0); - create_err(0.0, 1.0, f64::NAN); - create_err(0.0, -10.0, 1.0); - create_err(0.0, 10.0, -1.0); + let invalid = [ + (f64::NAN, 1.0, 1.0, StudentsTError::LocationInvalid), + (0.0, f64::NAN, 1.0, StudentsTError::ScaleInvalid), + (0.0, 1.0, f64::NAN, StudentsTError::FreedomInvalid), + (0.0, -10.0, 1.0, StudentsTError::ScaleInvalid), + (0.0, 10.0, -1.0, StudentsTError::FreedomInvalid), + ]; + + for (l, s, f, err) in invalid { + test_create_err(l, s, f, err); + } } #[test] diff --git a/src/distribution/triangular.rs b/src/distribution/triangular.rs index 2a4d31b2..9115e3c4 100644 --- a/src/distribution/triangular.rs +++ b/src/distribution/triangular.rs @@ -1,6 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::Rng; use std::f64; @@ -25,6 +24,42 @@ pub struct Triangular { mode: f64, } +/// Represents the errors that can occur when creating a [`Triangular`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum TriangularError { + /// The minimum is NaN or infinite. + MinInvalid, + + /// The maximum is NaN or infinite. + MaxInvalid, + + /// The mode is NaN or infinite. + ModeInvalid, + + /// The mode is less than the minimum or greater than the maximum. + ModeOutOfRange, + + /// The minimum equals the maximum. + MinEqualsMax, +} + +impl std::fmt::Display for TriangularError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + TriangularError::MinInvalid => write!(f, "Minimum is NaN or infinite."), + TriangularError::MaxInvalid => write!(f, "Maximum is NaN or infinite."), + TriangularError::ModeInvalid => write!(f, "Mode is NaN or infinite."), + TriangularError::ModeOutOfRange => { + write!(f, "Mode is less than minimum or greater than maximum") + } + TriangularError::MinEqualsMax => write!(f, "Minimum equals Maximum"), + } + } +} + +impl std::error::Error for TriangularError {} + impl Triangular { /// Constructs a new triangular distribution with a minimum of `min`, /// maximum of `max`, and a mode of `mode`. @@ -45,16 +80,27 @@ impl Triangular { /// result = Triangular::new(2.5, 1.5, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(min: f64, max: f64, mode: f64) -> Result { - if !min.is_finite() || !max.is_finite() || !mode.is_finite() { - return Err(StatsError::BadParams); + pub fn new(min: f64, max: f64, mode: f64) -> Result { + if !min.is_finite() { + return Err(TriangularError::MinInvalid); + } + + if !max.is_finite() { + return Err(TriangularError::MaxInvalid); + } + + if !mode.is_finite() { + return Err(TriangularError::ModeInvalid); } + if max < mode || mode < min { - return Err(StatsError::BadParams); + return Err(TriangularError::ModeOutOfRange); } - if ulps_eq!(max, min, max_ulps = 0) { - return Err(StatsError::BadParams); + + if min == max { + return Err(TriangularError::MinEqualsMax); } + Ok(Triangular { min, max, mode }) } } @@ -351,7 +397,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(min: f64, max: f64, mode: f64; Triangular; StatsError); + testing_boiler!(min: f64, max: f64, mode: f64; Triangular; TriangularError); #[test] fn test_create() { @@ -366,17 +412,23 @@ mod tests { #[test] fn test_bad_create() { - create_err(0.0, 0.0, 0.0); - create_err(0.0, 1.0, -0.1); - create_err(0.0, 1.0, 1.1); - create_err(0.0, -1.0, 0.5); - create_err(2.0, 1.0, 1.5); - create_err(f64::NAN, 1.0, 0.5); - create_err(0.2, f64::NAN, 0.5); - create_err(0.5, 1.0, f64::NAN); - create_err(f64::NAN, f64::NAN, f64::NAN); - create_err(f64::NEG_INFINITY, 1.0, 0.5); - create_err(0.0, f64::INFINITY, 0.5); + let invalid = [ + (0.0, 0.0, 0.0, TriangularError::MinEqualsMax), + (0.0, 1.0, -0.1, TriangularError::ModeOutOfRange), + (0.0, 1.0, 1.1, TriangularError::ModeOutOfRange), + (0.0, -1.0, 0.5, TriangularError::ModeOutOfRange), + (2.0, 1.0, 1.5, TriangularError::ModeOutOfRange), + (f64::NAN, 1.0, 0.5, TriangularError::MinInvalid), + (0.2, f64::NAN, 0.5, TriangularError::MaxInvalid), + (0.5, 1.0, f64::NAN, TriangularError::ModeInvalid), + (f64::NAN, f64::NAN, f64::NAN, TriangularError::MinInvalid), + (f64::NEG_INFINITY, 1.0, 0.5, TriangularError::MinInvalid), + (0.0, f64::INFINITY, 0.5, TriangularError::MaxInvalid), + ]; + + for (min, max, mode, err) in invalid { + test_create_err(min, max, mode, err); + } } #[test] diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index fdf25498..588413cb 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -1,6 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; use rand::distributions::Uniform as RandUniform; use rand::Rng; use std::f64; @@ -26,13 +25,42 @@ pub struct Uniform { max: f64, } +/// Represents the errors that can occur when creating a [`Uniform`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum UniformError { + /// The minimum is NaN or infinite. + MinInvalid, + + /// The maximum is NaN or infinite. + MaxInvalid, + + /// The maximum is not greater than the minimum. + MaxNotGreaterThanMin, +} + +impl std::fmt::Display for UniformError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + UniformError::MinInvalid => write!(f, "Minimum is NaN or infinite"), + UniformError::MaxInvalid => write!(f, "Maximum is NaN or infinite"), + UniformError::MaxNotGreaterThanMin => { + write!(f, "Maximum is not greater than the minimum") + } + } + } +} + +impl std::error::Error for UniformError {} + impl Uniform { /// Constructs a new uniform distribution with a min of `min` and a max - /// of `max` + /// of `max`. /// /// # Errors /// - /// Returns an error if `min` or `max` are `NaN` or unbounded + /// Returns an error if `min` or `max` are `NaN` or infinite. + /// Returns an error if `min >= max`. /// /// # Examples /// @@ -49,17 +77,19 @@ impl Uniform { /// result = Uniform::new(f64::NEG_INFINITY, 1.0); /// assert!(result.is_err()); /// ``` - pub fn new(min: f64, max: f64) -> Result { - if min.is_nan() || max.is_nan() { - return Err(StatsError::BadParams); + pub fn new(min: f64, max: f64) -> Result { + if !min.is_finite() { + return Err(UniformError::MinInvalid); } - match (min.is_finite(), max.is_finite(), min < max) { - (false, false, _) => Err(StatsError::ArgFinite("min and max")), - (false, true, _) => Err(StatsError::ArgFinite("min")), - (true, false, _) => Err(StatsError::ArgFinite("max")), - (true, true, false) => Err(StatsError::ArgLteArg("min", "max")), - (true, true, true) => Ok(Uniform { min, max }), + if !max.is_finite() { + return Err(UniformError::MaxInvalid); + } + + if min < max { + Ok(Uniform { min, max }) + } else { + Err(UniformError::MaxNotGreaterThanMin) } } @@ -288,7 +318,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(min: f64, max: f64; Uniform; StatsError); + testing_boiler!(min: f64, max: f64; Uniform; UniformError); #[test] fn test_create() { @@ -300,12 +330,18 @@ mod tests { #[test] fn test_bad_create() { - create_err(0.0, 0.0); - create_err(f64::NAN, 1.0); - create_err(1.0, f64::NAN); - create_err(f64::NAN, f64::NAN); - create_err(0.0, f64::INFINITY); - create_err(1.0, 0.0); + let invalid = [ + (0.0, 0.0, UniformError::MaxNotGreaterThanMin), + (f64::NAN, 1.0, UniformError::MinInvalid), + (1.0, f64::NAN, UniformError::MaxInvalid), + (f64::NAN, f64::NAN, UniformError::MinInvalid), + (0.0, f64::INFINITY, UniformError::MaxInvalid), + (1.0, 0.0, UniformError::MaxNotGreaterThanMin), + ]; + + for (min, max, err) in invalid { + test_create_err(min, max, err); + } } #[test] diff --git a/src/distribution/weibull.rs b/src/distribution/weibull.rs index 8137f664..e0d856c5 100644 --- a/src/distribution/weibull.rs +++ b/src/distribution/weibull.rs @@ -1,7 +1,7 @@ +use crate::consts; use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; use crate::statistics::*; -use crate::{consts, Result, StatsError}; use rand::Rng; use std::f64; @@ -27,6 +27,28 @@ pub struct Weibull { scale_pow_shape_inv: f64, } +/// Represents the errors that can occur when creating a [`Weibull`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum WeibullError { + /// The shape is NaN, zero or less than zero. + ShapeInvalid, + + /// The scale is NaN, zero or less than zero. + ScaleInvalid, +} + +impl std::fmt::Display for WeibullError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + WeibullError::ShapeInvalid => write!(f, "Shape is NaN, zero or less than zero."), + WeibullError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero."), + } + } +} + +impl std::error::Error for WeibullError {} + impl Weibull { /// Constructs a new weibull distribution with a shape (k) of `shape` /// and a scale (λ) of `scale` @@ -47,17 +69,20 @@ impl Weibull { /// result = Weibull::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(shape: f64, scale: f64) -> Result { - let is_nan = shape.is_nan() || scale.is_nan(); - match (shape, scale, is_nan) { - (_, _, true) => Err(StatsError::BadParams), - (_, _, false) if shape <= 0.0 || scale <= 0.0 => Err(StatsError::BadParams), - (_, _, false) => Ok(Weibull { - shape, - scale, - scale_pow_shape_inv: scale.powf(-shape), - }), + pub fn new(shape: f64, scale: f64) -> Result { + if shape.is_nan() || shape <= 0.0 { + return Err(WeibullError::ShapeInvalid); } + + if scale.is_nan() || scale <= 0.0 { + return Err(WeibullError::ScaleInvalid); + } + + Ok(Weibull { + shape, + scale, + scale_pow_shape_inv: scale.powf(-shape), + }) } /// Returns the shape of the weibull distribution @@ -354,7 +379,7 @@ mod tests { use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!(shape: f64, scale: f64; Weibull; StatsError); + testing_boiler!(shape: f64, scale: f64; Weibull; WeibullError); #[test] fn test_create() { @@ -366,8 +391,8 @@ mod tests { #[test] fn test_bad_create() { - create_err(f64::NAN, 1.0); - create_err(1.0, f64::NAN); + test_create_err(f64::NAN, 1.0, WeibullError::ShapeInvalid); + test_create_err(1.0, f64::NAN, WeibullError::ScaleInvalid); create_err(f64::NAN, f64::NAN); create_err(1.0, -1.0); create_err(-1.0, 1.0); diff --git a/src/stats_tests/fisher.rs b/src/stats_tests/fisher.rs index 69b41d6d..31d07173 100644 --- a/src/stats_tests/fisher.rs +++ b/src/stats_tests/fisher.rs @@ -1,6 +1,5 @@ use super::Alternative; -use crate::distribution::{Discrete, DiscreteCDF, Hypergeometric}; -use crate::StatsError; +use crate::distribution::{Discrete, DiscreteCDF, Hypergeometric, HypergeometricError}; const EPSILON: f64 = 1.0 - 1e-4; @@ -97,6 +96,34 @@ fn binary_search( guess } +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum FishersExactTestError { + /// The table does not describe a valid [`Hypergeometric`] distribution. + /// Make sure that the contingency table stores the data in row-major order. + TableInvalidForHypergeometric(HypergeometricError), +} + +impl std::fmt::Display for FishersExactTestError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + FishersExactTestError::TableInvalidForHypergeometric(hg_err) => { + writeln!(f, "Cannot create a Hypergeometric distribution from the data in the contingency table.")?; + writeln!(f, "Is it in row-major order?")?; + write!(f, "Inner error: '{}'", hg_err) + } + } + } +} + +impl std::error::Error for FishersExactTestError {} + +impl From for FishersExactTestError { + fn from(value: HypergeometricError) -> Self { + Self::TableInvalidForHypergeometric(value) + } +} + /// Perform a Fisher exact test on a 2x2 contingency table. /// Based on scipy's fisher test: /// Expects a table in row-major order @@ -112,7 +139,7 @@ fn binary_search( pub fn fishers_exact_with_odds_ratio( table: &[u64; 4], alternative: Alternative, -) -> Result<(f64, f64), StatsError> { +) -> Result<(f64, f64), FishersExactTestError> { // If both values in a row or column are zero, p-value is 1 and odds ratio is NaN. match table { [0, _, 0, _] | [_, 0, _, 0] => return Ok((f64::NAN, 1.0)), // both 0 in a row @@ -144,7 +171,10 @@ pub fn fishers_exact_with_odds_ratio( /// let table = [3, 5, 4, 50]; /// let p_value = fishers_exact(&table, Alternative::Less).unwrap(); /// ``` -pub fn fishers_exact(table: &[u64; 4], alternative: Alternative) -> Result { +pub fn fishers_exact( + table: &[u64; 4], + alternative: Alternative, +) -> Result { // If both values in a row or column are zero, the p-value is 1 and the odds ratio is NaN. match table { [0, _, 0, _] | [_, 0, _, 0] => return Ok(1.0), // both 0 in a row From 485eef4a69098c8104aa6aabcc96de5317858945 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Fri, 6 Sep 2024 09:56:14 +0200 Subject: [PATCH 173/185] chore: Exclude Error Display impls from coverage --- Cargo.toml | 5 +++++ src/distribution/beta.rs | 1 + src/distribution/binomial.rs | 1 + src/distribution/categorical.rs | 1 + src/distribution/cauchy.rs | 1 + src/distribution/chi.rs | 1 + src/distribution/dirac.rs | 1 + src/distribution/dirichlet.rs | 1 + src/distribution/discrete_uniform.rs | 1 + src/distribution/exponential.rs | 1 + src/distribution/fisher_snedecor.rs | 1 + src/distribution/gamma.rs | 1 + src/distribution/geometric.rs | 1 + src/distribution/hypergeometric.rs | 1 + src/distribution/inverse_gamma.rs | 1 + src/distribution/laplace.rs | 1 + src/distribution/log_normal.rs | 1 + src/distribution/multinomial.rs | 1 + src/distribution/multivariate_normal.rs | 1 + src/distribution/multivariate_students_t.rs | 1 + src/distribution/negative_binomial.rs | 1 + src/distribution/normal.rs | 1 + src/distribution/pareto.rs | 1 + src/distribution/poisson.rs | 1 + src/distribution/students_t.rs | 1 + src/distribution/triangular.rs | 1 + src/distribution/uniform.rs | 1 + src/distribution/weibull.rs | 1 + src/lib.rs | 1 + src/stats_tests/fisher.rs | 1 + 30 files changed, 34 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index 15b06e4e..b295a5fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,3 +44,8 @@ anyhow = "1.0" version = "0.32" default-features = false features = ["macros"] + +[lints.rust.unexpected_cfgs] +level = "warn" +# Set by cargo-llvm-cov when running on nightly +check-cfg = ['cfg(coverage_nightly)'] diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index 763945d6..e20ea302 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -38,6 +38,7 @@ pub enum BetaError { } impl std::fmt::Display for BetaError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { BetaError::ShapeAInvalid => write!(f, "Shape A is NaN, zero or negative"), diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index 9f5ffc47..1d86283d 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -34,6 +34,7 @@ pub enum BinomialError { } impl std::fmt::Display for BinomialError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { BinomialError::ProbabilityInvalid => write!(f, "Probability is NaN or not in [0, 1]"), diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index 9cd95a9d..cb3c7ea8 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -41,6 +41,7 @@ pub enum CategoricalError { } impl std::fmt::Display for CategoricalError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { CategoricalError::ProbMassEmpty => write!(f, "Probability mass is empty"), diff --git a/src/distribution/cauchy.rs b/src/distribution/cauchy.rs index a8815349..5ba7f69f 100644 --- a/src/distribution/cauchy.rs +++ b/src/distribution/cauchy.rs @@ -34,6 +34,7 @@ pub enum CauchyError { } impl std::fmt::Display for CauchyError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { CauchyError::LocationInvalid => write!(f, "Location is NaN"), diff --git a/src/distribution/chi.rs b/src/distribution/chi.rs index cce8535f..796fcd23 100644 --- a/src/distribution/chi.rs +++ b/src/distribution/chi.rs @@ -32,6 +32,7 @@ pub enum ChiError { } impl std::fmt::Display for ChiError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { ChiError::FreedomInvalid => { diff --git a/src/distribution/dirac.rs b/src/distribution/dirac.rs index 142f8c12..18e70f9b 100644 --- a/src/distribution/dirac.rs +++ b/src/distribution/dirac.rs @@ -26,6 +26,7 @@ pub enum DiracError { } impl std::fmt::Display for DiracError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { DiracError::ValueInvalid => write!(f, "Value v is NaN"), diff --git a/src/distribution/dirichlet.rs b/src/distribution/dirichlet.rs index b8aaad86..2670929a 100644 --- a/src/distribution/dirichlet.rs +++ b/src/distribution/dirichlet.rs @@ -43,6 +43,7 @@ pub enum DirichletError { } impl std::fmt::Display for DirichletError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { DirichletError::AlphaTooShort => write!(f, "Alpha contains less than two elements"), diff --git a/src/distribution/discrete_uniform.rs b/src/distribution/discrete_uniform.rs index be851b62..524bc2a3 100644 --- a/src/distribution/discrete_uniform.rs +++ b/src/distribution/discrete_uniform.rs @@ -31,6 +31,7 @@ pub enum DiscreteUniformError { } impl std::fmt::Display for DiscreteUniformError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { DiscreteUniformError::MinMaxInvalid => write!(f, "Maximum is less than minimum"), diff --git a/src/distribution/exponential.rs b/src/distribution/exponential.rs index 1389b1e4..ec30d1f7 100644 --- a/src/distribution/exponential.rs +++ b/src/distribution/exponential.rs @@ -33,6 +33,7 @@ pub enum ExpError { } impl std::fmt::Display for ExpError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { ExpError::RateInvalid => write!(f, "Rate is NaN, zero or less than zero"), diff --git a/src/distribution/fisher_snedecor.rs b/src/distribution/fisher_snedecor.rs index a50ed42a..610da130 100644 --- a/src/distribution/fisher_snedecor.rs +++ b/src/distribution/fisher_snedecor.rs @@ -37,6 +37,7 @@ pub enum FisherSnedecorError { } impl std::fmt::Display for FisherSnedecorError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { FisherSnedecorError::Freedom1Invalid => { diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index 201d0cfd..b6055c77 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -39,6 +39,7 @@ pub enum GammaError { } impl std::fmt::Display for GammaError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { GammaError::ShapeInvalid => write!(f, "Shape is NaN zero, or less than zero."), diff --git a/src/distribution/geometric.rs b/src/distribution/geometric.rs index dfb28ef6..41e35f52 100644 --- a/src/distribution/geometric.rs +++ b/src/distribution/geometric.rs @@ -33,6 +33,7 @@ pub enum GeometricError { } impl std::fmt::Display for GeometricError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { GeometricError::ProbabilityInvalid => write!(f, "Probability is NaN or not in (0, 1]"), diff --git a/src/distribution/hypergeometric.rs b/src/distribution/hypergeometric.rs index c3ba5a3a..7da6f45a 100644 --- a/src/distribution/hypergeometric.rs +++ b/src/distribution/hypergeometric.rs @@ -28,6 +28,7 @@ pub enum HypergeometricError { } impl std::fmt::Display for HypergeometricError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { HypergeometricError::TooManySuccesses => write!(f, "successes > population"), diff --git a/src/distribution/inverse_gamma.rs b/src/distribution/inverse_gamma.rs index 1f1bee4b..a36c7e17 100644 --- a/src/distribution/inverse_gamma.rs +++ b/src/distribution/inverse_gamma.rs @@ -37,6 +37,7 @@ pub enum InverseGammaError { } impl std::fmt::Display for InverseGammaError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { InverseGammaError::ShapeInvalid => { diff --git a/src/distribution/laplace.rs b/src/distribution/laplace.rs index d4f6fc16..13f03a55 100644 --- a/src/distribution/laplace.rs +++ b/src/distribution/laplace.rs @@ -34,6 +34,7 @@ pub enum LaplaceError { } impl std::fmt::Display for LaplaceError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { LaplaceError::LocationInvalid => write!(f, "Location is NaN"), diff --git a/src/distribution/log_normal.rs b/src/distribution/log_normal.rs index 9075bfd2..49380d2b 100644 --- a/src/distribution/log_normal.rs +++ b/src/distribution/log_normal.rs @@ -38,6 +38,7 @@ pub enum LogNormalError { } impl std::fmt::Display for LogNormalError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { LogNormalError::LocationInvalid => write!(f, "Location is NaN"), diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index 4fcaca1a..aca82a2a 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -47,6 +47,7 @@ pub enum MultinomialError { } impl std::fmt::Display for MultinomialError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { MultinomialError::NotEnoughProbabilities => write!(f, "Fewer than two probabilities"), diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 0da93cd4..68b8d098 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -127,6 +127,7 @@ pub enum MultivariateNormalError { } impl std::fmt::Display for MultivariateNormalError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { MultivariateNormalError::CovInvalid => { diff --git a/src/distribution/multivariate_students_t.rs b/src/distribution/multivariate_students_t.rs index 143237b8..8bcb1d78 100644 --- a/src/distribution/multivariate_students_t.rs +++ b/src/distribution/multivariate_students_t.rs @@ -61,6 +61,7 @@ pub enum MultivariateStudentError { } impl std::fmt::Display for MultivariateStudentError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { MultivariateStudentError::ScaleInvalid => { diff --git a/src/distribution/negative_binomial.rs b/src/distribution/negative_binomial.rs index a3f4ffed..6ed557be 100644 --- a/src/distribution/negative_binomial.rs +++ b/src/distribution/negative_binomial.rs @@ -52,6 +52,7 @@ pub enum NegativeBinomialError { } impl std::fmt::Display for NegativeBinomialError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { NegativeBinomialError::RInvalid => write!(f, "r is NaN or less than zero"), diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index 71f50eda..b536c101 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -36,6 +36,7 @@ pub enum NormalError { } impl std::fmt::Display for NormalError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { NormalError::MeanInvalid => write!(f, "Mean is NaN"), diff --git a/src/distribution/pareto.rs b/src/distribution/pareto.rs index 06f9e176..886db43b 100644 --- a/src/distribution/pareto.rs +++ b/src/distribution/pareto.rs @@ -36,6 +36,7 @@ pub enum ParetoError { } impl std::fmt::Display for ParetoError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { ParetoError::ScaleInvalid => write!(f, "Scale is NaN, zero, or less than zero"), diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index 0780cf55..33e8f8a2 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -32,6 +32,7 @@ pub enum PoissonError { } impl std::fmt::Display for PoissonError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { PoissonError::LambdaInvalid => write!(f, "Lambda is NaN, zero or less than zero"), diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index b9543cc1..cc88707f 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -40,6 +40,7 @@ pub enum StudentsTError { } impl std::fmt::Display for StudentsTError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { StudentsTError::LocationInvalid => write!(f, "Location is NaN"), diff --git a/src/distribution/triangular.rs b/src/distribution/triangular.rs index 9115e3c4..eb3cb93d 100644 --- a/src/distribution/triangular.rs +++ b/src/distribution/triangular.rs @@ -45,6 +45,7 @@ pub enum TriangularError { } impl std::fmt::Display for TriangularError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { TriangularError::MinInvalid => write!(f, "Minimum is NaN or infinite."), diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index 588413cb..55bd7884 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -40,6 +40,7 @@ pub enum UniformError { } impl std::fmt::Display for UniformError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { UniformError::MinInvalid => write!(f, "Minimum is NaN or infinite"), diff --git a/src/distribution/weibull.rs b/src/distribution/weibull.rs index e0d856c5..71aa30ef 100644 --- a/src/distribution/weibull.rs +++ b/src/distribution/weibull.rs @@ -39,6 +39,7 @@ pub enum WeibullError { } impl std::fmt::Display for WeibullError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { WeibullError::ShapeInvalid => write!(f, "Shape is NaN, zero or less than zero."), diff --git a/src/lib.rs b/src/lib.rs index 56f9b162..7ca3157c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,6 +48,7 @@ #![allow(clippy::excessive_precision)] #![allow(clippy::many_single_char_names)] #![forbid(unsafe_code)] +#![cfg_attr(coverage_nightly, feature(coverage_attribute))] #[macro_use] extern crate approx; diff --git a/src/stats_tests/fisher.rs b/src/stats_tests/fisher.rs index 31d07173..909b4e7b 100644 --- a/src/stats_tests/fisher.rs +++ b/src/stats_tests/fisher.rs @@ -105,6 +105,7 @@ pub enum FishersExactTestError { } impl std::fmt::Display for FishersExactTestError { + #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { FishersExactTestError::TableInvalidForHypergeometric(hg_err) => { From 59d0b711e099c32370c11aacede324dd3e0715c7 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sun, 8 Sep 2024 21:26:03 -0500 Subject: [PATCH 174/185] test: assert Send and Sync for error types Co-authored-by: FreezyLemon --- src/distribution/dirichlet.rs | 13 ++++++++----- src/distribution/internal.rs | 7 +++++++ src/distribution/multinomial.rs | 6 ++++++ src/distribution/multivariate_normal.rs | 8 ++++++++ src/distribution/multivariate_students_t.rs | 8 +++++++- 5 files changed, 36 insertions(+), 6 deletions(-) diff --git a/src/distribution/dirichlet.rs b/src/distribution/dirichlet.rs index 2670929a..355476db 100644 --- a/src/distribution/dirichlet.rs +++ b/src/distribution/dirichlet.rs @@ -369,15 +369,12 @@ where #[rustfmt::skip] #[cfg(test)] mod tests { + use super::*; + use std::fmt::{Debug, Display}; use nalgebra::{dmatrix, dvector, vector, DimMin, OVector}; - use crate::{ - distribution::{Continuous, Dirichlet}, - statistics::{MeanN, VarianceN}, - }; - fn try_create(alpha: OVector) -> Dirichlet where D: DimMin, @@ -580,4 +577,10 @@ mod tests { let n = try_create(vector![0.1, 0.3, 0.5, 0.8]); n.ln_pdf(&vector![0.5, 0.25, 0.8, 0.9]); } + + #[test] + fn test_error_is_sync_send() { + fn assert_sync_send() {} + assert_sync_send::(); + } } diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index c3b9f22c..9e7651b0 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -241,6 +241,13 @@ pub mod test { ) } } + + /// Asserts that associated error type is Send and Sync + #[test] + fn test_error_is_sync_send() { + fn assert_sync_send() {} + assert_sync_send::<$dist_err>(); + } }; } diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index aca82a2a..7d1b408c 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -500,6 +500,12 @@ mod tests { ); } + #[test] + fn test_error_is_sync_send() { + fn assert_sync_send() {} + assert_sync_send::(); + } + // #[test] // #[should_panic] // fn test_pmf_x_wrong_length() { diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 68b8d098..eb86edd3 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -389,6 +389,8 @@ mod tests { statistics::{Max, MeanN, Min, Mode, VarianceN}, }; + use super::MultivariateNormalError; + fn try_create(mean: OVector, covariance: OMatrix) -> MultivariateNormal where D: DimMin, @@ -703,4 +705,10 @@ mod tests { let mvn = MultivariateNormal::new(vec![0., 0.], vec![1., 0., 0., 1.,]).unwrap(); mvn.pdf(&vec![1.].into()); // x.size != mu.size } + + #[test] + fn test_error_is_sync_send() { + fn assert_sync_send() {} + assert_sync_send::(); + } } diff --git a/src/distribution/multivariate_students_t.rs b/src/distribution/multivariate_students_t.rs index 8bcb1d78..b75e0a88 100644 --- a/src/distribution/multivariate_students_t.rs +++ b/src/distribution/multivariate_students_t.rs @@ -397,6 +397,8 @@ mod tests { statistics::{Max, MeanN, Min, Mode, VarianceN}, }; + use super::MultivariateStudentError; + fn try_create(location: Vec, scale: Vec, freedom: f64) -> MultivariateStudent { let mvs = MultivariateStudent::new(location, scale, freedom); @@ -614,5 +616,9 @@ mod tests { assert_eq!(mvs.scale_chol_decomp(), &OMatrix::::identity(2, 2)); } - + #[test] + fn test_error_is_sync_send() { + fn assert_sync_send() {} + assert_sync_send::(); + } } From 1ece9341de45d840692ad5f98ced027f4be22b02 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Tue, 27 Aug 2024 11:43:42 +0200 Subject: [PATCH 175/185] build: Add MSRV in Cargo.toml --- Cargo.lock.MSRV | 848 ++++++++++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 4 +- README.md | 4 + 3 files changed, 855 insertions(+), 1 deletion(-) create mode 100644 Cargo.lock.MSRV diff --git a/Cargo.lock.MSRV b/Cargo.lock.MSRV new file mode 100644 index 00000000..3dcbc271 --- /dev/null +++ b/Cargo.lock.MSRV @@ -0,0 +1,848 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + +[[package]] +name = "anstyle" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" + +[[package]] +name = "anyhow" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" + +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + +[[package]] +name = "autocfg" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" + +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + +[[package]] +name = "bytemuck" +version = "1.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773d90827bc3feecfb67fab12e24de0749aad83c74b9504ecde46237b5cd24e2" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + +[[package]] +name = "clap" +version = "4.5.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e5a21b8495e732f1b3c364c9949b201ca7bae518c502c80256c96ad79eaf6ac" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.5.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cf2dd12af7a047ad9d6da2b6b249759a22a7abc0f474c1dae1777afa4b21a73" +dependencies = [ + "anstyle", + "clap_lex", +] + +[[package]] +name = "clap_lex" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" + +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "either" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "half" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "cfg-if", + "crunchy", +] + +[[package]] +name = "hermit-abi" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" + +[[package]] +name = "is-terminal" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "js-sys" +version = "0.3.70" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "libc" +version = "0.2.158" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + +[[package]] +name = "log" +version = "0.4.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" + +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + +[[package]] +name = "nalgebra" +version = "0.32.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4" +dependencies = [ + "approx", + "matrixmultiply", + "nalgebra-macros", + "num-complex", + "num-rational", + "num-traits", + "rand", + "rand_distr", + "simba", + "typenum", +] + +[[package]] +name = "nalgebra-macros" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "254a5372af8fc138e36684761d3c0cdb758a4410e938babcff1c860ce14ddbfc" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "oorandom" +version = "11.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "plotters" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a15b6eccb8484002195a3e44fe65a4ce8e93a625797a063735536fd59cb01cf3" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "414cec62c6634ae900ea1c56128dfe87cf63e7caece0852ec76aba307cebadb7" + +[[package]] +name = "plotters-svg" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81b30686a7d9c3e010b84284bdd26a29f2138574f52f5eb6f794fc0ad924e705" +dependencies = [ + "plotters-backend", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro2" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "regex" +version = "1.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" + +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + +[[package]] +name = "safe_arch" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3460605018fdc9612bce72735cba0d27efbcd9904780d44c7e3a9948f96148a" +dependencies = [ + "bytemuck", +] + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "serde" +version = "1.0.209" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.209" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.128" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + +[[package]] +name = "simba" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae" +dependencies = [ + "approx", + "num-complex", + "num-traits", + "paste", + "wide", +] + +[[package]] +name = "statrs" +version = "0.17.1" +dependencies = [ + "anyhow", + "approx", + "criterion", + "nalgebra", + "num-traits", + "rand", +] + +[[package]] +name = "syn" +version = "2.0.77" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" +dependencies = [ + "cfg-if", + "once_cell", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" + +[[package]] +name = "web-sys" +version = "0.3.70" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "wide" +version = "0.7.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b828f995bf1e9622031f8009f8481a85406ce1f4d4588ff746d872043e855690" +dependencies = [ + "bytemuck", + "safe_arch", +] + +[[package]] +name = "winapi-util" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "byteorder", + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/Cargo.toml b/Cargo.toml index b295a5fe..f9f531df 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,4 @@ [package] - name = "statrs" version = "0.17.1" authors = ["Michael Ma"] @@ -13,6 +12,9 @@ edition = "2021" include = ["CHANGELOG.md", "LICENSE.md", "src/", "tests/"] +# When changing MSRV: Also update the README +rust-version = "1.66.0" + [lib] name = "statrs" path = "src/lib.rs" diff --git a/README.md b/README.md index 8d4ff9db..973f0046 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,10 @@ cargo test If you'd like to modify where the data is downloaded, you can use the environment variable, `STATRS_NIST_DATA_DIR` for running the script and the tests. +## Minimum supported Rust version (MSRV) + +This crate requires a Rust version of 1.66.0 or higher. Increases in MSRV will be considered a semver non-breaking API change and require a version increase (PATCH until 1.0.0, MINOR after 1.0.0). + ## Contributing Thanks for your help to improve the project! From 04ce395f115b8a56178747e1bb96a6c09d099eb7 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Tue, 27 Aug 2024 12:47:39 +0200 Subject: [PATCH 176/185] ci: Check MSRV in CI job --- .github/workflows/test.yml | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1b0a2eca..88b02134 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -38,8 +38,20 @@ jobs: - name: Run rustfmt --check run: cargo fmt -- --check + msrv: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install cargo-hack + uses: taiki-e/install-action@cargo-hack + - uses: Swatinem/rust-cache@v2 + - name: Use predefined lockfile + run: mv -v Cargo.lock.MSRV Cargo.lock + - name: Build (lib only) + run: cargo hack check --rust-version --locked + test: - needs: [clippy, fmt] + needs: [clippy, fmt, msrv] runs-on: ${{ matrix.os }} strategy: matrix: From be45210006d53f2943c77f17be2dfb02955a9249 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Mon, 9 Sep 2024 13:16:41 -0500 Subject: [PATCH 177/185] ci: coverage report running on stable rust --- .github/workflows/coverage.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 1485889c..c0840d8a 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -14,7 +14,7 @@ jobs: CARGO_TERM_COLOR: always steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@nightly + - uses: dtolnay/rust-toolchain@stable with: components: llvm-tools-preview From 7cf8fac288b40efad7449e7f4406576b63855fdc Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Mon, 9 Sep 2024 13:53:12 -0500 Subject: [PATCH 178/185] ci: test coverage on nightly with changes from 275 --- .github/workflows/coverage.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index c0840d8a..1485889c 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -14,7 +14,7 @@ jobs: CARGO_TERM_COLOR: always steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable + - uses: dtolnay/rust-toolchain@nightly with: components: llvm-tools-preview From 53e8c97c6652f23032865bf22dacb91687cdccb7 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Mon, 9 Sep 2024 14:31:11 -0500 Subject: [PATCH 179/185] ci: specify which nightly compiler --- .github/workflows/coverage.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 1485889c..b5af5ec8 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -14,8 +14,9 @@ jobs: CARGO_TERM_COLOR: always steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@nightly + - uses: dtolnay/rust-toolchain@master with: + toolchain: nightly-2024-08-29 components: llvm-tools-preview - uses: taiki-e/install-action@v2 From 30a155834b906048217bc19de13184472fbcb64e Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+YeungOnion@users.noreply.github.com> Date: Mon, 9 Sep 2024 20:25:36 -0500 Subject: [PATCH 180/185] Revert "ci: specify which nightly compiler" This reverts commit 53e8c97c6652f23032865bf22dacb91687cdccb7. --- .github/workflows/coverage.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index b5af5ec8..1485889c 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -14,9 +14,8 @@ jobs: CARGO_TERM_COLOR: always steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@master + - uses: dtolnay/rust-toolchain@nightly with: - toolchain: nightly-2024-08-29 components: llvm-tools-preview - uses: taiki-e/install-action@v2 From cd4900533d6c3e272e262d47f953ec805e8b7827 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+YeungOnion@users.noreply.github.com> Date: Mon, 9 Sep 2024 20:25:36 -0500 Subject: [PATCH 181/185] Revert "ci: test coverage on nightly with changes from 275" This reverts commit 7cf8fac288b40efad7449e7f4406576b63855fdc. --- .github/workflows/coverage.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 1485889c..c0840d8a 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -14,7 +14,7 @@ jobs: CARGO_TERM_COLOR: always steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@nightly + - uses: dtolnay/rust-toolchain@stable with: components: llvm-tools-preview From f58cd4c49f64aabb69a92953823cf4967def6756 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+YeungOnion@users.noreply.github.com> Date: Mon, 9 Sep 2024 20:25:36 -0500 Subject: [PATCH 182/185] Revert "ci: coverage report running on stable rust" This reverts commit be45210006d53f2943c77f17be2dfb02955a9249. --- .github/workflows/coverage.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index c0840d8a..1485889c 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -14,7 +14,7 @@ jobs: CARGO_TERM_COLOR: always steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable + - uses: dtolnay/rust-toolchain@nightly with: components: llvm-tools-preview From bee36d9b7dab7499da73f31508e76c2fa1a58151 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Thu, 5 Sep 2024 13:47:30 +0200 Subject: [PATCH 183/185] refactor!: decouple distribution traits from rand This changes the behavior of Cauchy::mean() and Cauchy::variance() from the old default implementation (based on RNG sampling) to simply returning `None`. Both mean and variance are undefined for the cauchy distribution. This also changes the trait signature of both Distribution and DiscreteDistribution to not be subtraits of rand::distribution::Distribution. Our distributions do still implement this trait iff the `rand` feature is enabled. --- src/statistics/traits.rs | 42 ++++------------------------------------ 1 file changed, 4 insertions(+), 38 deletions(-) diff --git a/src/statistics/traits.rs b/src/statistics/traits.rs index e5211075..9140eab4 100644 --- a/src/statistics/traits.rs +++ b/src/statistics/traits.rs @@ -1,7 +1,5 @@ use ::num_traits::float::Float; -const STEPS: usize = 1_000; - /// The `Min` trait specifies than an object has a minimum value pub trait Min { /// Returns the minimum value in the domain of a given distribution @@ -35,7 +33,7 @@ pub trait Max { /// ``` fn max(&self) -> T; } -pub trait DiscreteDistribution: ::rand::distributions::Distribution { +pub trait DiscreteDistribution { /// Returns the mean, if it exists. fn mean(&self) -> Option { None @@ -58,14 +56,8 @@ pub trait DiscreteDistribution: ::rand::distributions::Distribution: ::rand::distributions::Distribution { +pub trait Distribution { /// Returns the mean, if it exists. - /// The default implementation returns an estimation - /// based on random samples. This is a crude estimate - /// for when no further information is known about the - /// distribution. More accurate statements about the - /// mean can and should be given by overriding the - /// default implementation. /// /// # Examples /// @@ -77,23 +69,9 @@ pub trait Distribution: ::rand::distributions::Distribution { /// assert_eq!(0.5, n.mean().unwrap()); /// ``` fn mean(&self) -> Option { - // TODO: Does not need cryptographic rng - let mut rng = ::rand::rngs::OsRng; - let mut mean = T::zero(); - let mut steps = T::zero(); - for _ in 0..STEPS { - steps = steps + T::one(); - mean = mean + Self::sample(self, &mut rng); - } - Some(mean / steps) + None } /// Returns the variance, if it exists. - /// The default implementation returns an estimation - /// based on random samples. This is a crude estimate - /// for when no further information is known about the - /// distribution. More accurate statements about the - /// variance can and should be given by overriding the - /// default implementation. /// /// # Examples /// @@ -105,19 +83,7 @@ pub trait Distribution: ::rand::distributions::Distribution { /// assert_eq!(1.0 / 12.0, n.variance().unwrap()); /// ``` fn variance(&self) -> Option { - // TODO: Does not need cryptographic rng - let mut rng = ::rand::rngs::OsRng; - let mut mean = T::zero(); - let mut variance = T::zero(); - let mut steps = T::zero(); - for _ in 0..STEPS { - steps = steps + T::one(); - let sample = Self::sample(self, &mut rng); - variance = variance + (steps - T::one()) * (sample - mean) * (sample - mean) / steps; - mean = mean + (sample - mean) / steps; - } - steps = steps - T::one(); - Some(variance / steps) + None } /// Returns the standard deviation, if it exists. /// From d18485c2f10dbb15283e14da8bbd8b416be50c35 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Thu, 5 Sep 2024 14:17:49 +0200 Subject: [PATCH 184/185] build: make `rand` an optional dependency --- Cargo.toml | 11 ++++++++--- benches/order_statistics.rs | 1 - src/distribution/bernoulli.rs | 4 ++-- src/distribution/beta.rs | 4 ++-- src/distribution/binomial.rs | 4 ++-- src/distribution/categorical.rs | 7 ++++--- src/distribution/cauchy.rs | 4 ++-- src/distribution/chi.rs | 4 ++-- src/distribution/chi_squared.rs | 4 ++-- src/distribution/dirac.rs | 4 ++-- src/distribution/dirichlet.rs | 8 ++++---- src/distribution/discrete_uniform.rs | 4 ++-- src/distribution/empirical.rs | 8 +++++--- src/distribution/erlang.rs | 4 ++-- src/distribution/exponential.rs | 8 +++++--- src/distribution/fisher_snedecor.rs | 4 ++-- src/distribution/gamma.rs | 7 ++++--- src/distribution/geometric.rs | 7 ++++--- src/distribution/hypergeometric.rs | 4 ++-- src/distribution/inverse_gamma.rs | 4 ++-- src/distribution/laplace.rs | 11 +++++++---- src/distribution/log_normal.rs | 4 ++-- src/distribution/mod.rs | 2 ++ src/distribution/multinomial.rs | 8 +++++--- src/distribution/multivariate_normal.rs | 7 +++---- src/distribution/multivariate_students_t.rs | 7 ++++--- src/distribution/negative_binomial.rs | 10 ++++++---- src/distribution/normal.rs | 11 +++++++---- src/distribution/pareto.rs | 7 ++++--- src/distribution/poisson.rs | 7 ++++--- src/distribution/students_t.rs | 4 ++-- src/distribution/triangular.rs | 7 ++++--- src/distribution/uniform.rs | 8 ++++---- src/distribution/weibull.rs | 4 ++-- src/statistics/slice_statistics.rs | 4 +++- 35 files changed, 117 insertions(+), 89 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f9f531df..6c80cc70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,21 +22,26 @@ path = "src/lib.rs" [[bench]] name = "order_statistics" harness = false +required-features = ["rand"] [features] -default = ["nalgebra"] +default = ["nalgebra", "rand"] nalgebra = ["dep:nalgebra"] +rand = ["dep:rand", "nalgebra?/rand"] [dependencies] -rand = "0.8" approx = "0.5.0" num-traits = "0.2.14" +[dependencies.rand] +version = "0.8" +optional = true + [dependencies.nalgebra] version = "0.32" optional = true default-features = false -features = ["rand", "std"] +features = ["std"] [dev-dependencies] criterion = "0.5" diff --git a/benches/order_statistics.rs b/benches/order_statistics.rs index fa6fdd26..d94902c9 100644 --- a/benches/order_statistics.rs +++ b/benches/order_statistics.rs @@ -1,4 +1,3 @@ -extern crate rand; extern crate statrs; use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion}; use rand::prelude::*; diff --git a/src/distribution/bernoulli.rs b/src/distribution/bernoulli.rs index d5de981a..28f9d104 100644 --- a/src/distribution/bernoulli.rs +++ b/src/distribution/bernoulli.rs @@ -1,6 +1,5 @@ use crate::distribution::{Binomial, BinomialError, Discrete, DiscreteCDF}; use crate::statistics::*; -use rand::Rng; /// Implements the /// [Bernoulli](https://en.wikipedia.org/wiki/Bernoulli_distribution) @@ -85,8 +84,9 @@ impl std::fmt::Display for Bernoulli { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Bernoulli { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { rng.gen_bool(self.p()) as u8 as f64 } } diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index e20ea302..2741889e 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -1,7 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::{beta, gamma}; use crate::statistics::*; -use rand::Rng; /// Implements the [Beta](https://en.wikipedia.org/wiki/Beta_distribution) /// distribution @@ -121,8 +120,9 @@ impl std::fmt::Display for Beta { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Beta { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { // Generated by sampling two gamma distributions and normalizing. let x = super::gamma::sample_unchecked(rng, self.shape_a, 1.0); let y = super::gamma::sample_unchecked(rng, self.shape_b, 1.0); diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index 1d86283d..c24bf7b5 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -1,7 +1,6 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::function::{beta, factorial}; use crate::statistics::*; -use rand::Rng; use std::f64; /// Implements the @@ -110,8 +109,9 @@ impl std::fmt::Display for Binomial { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Binomial { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { (0..self.n).fold(0.0, |acc, _| { let n: f64 = rng.gen(); if n < self.p { diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index cb3c7ea8..7d3a7c1c 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -1,6 +1,5 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::statistics::*; -use rand::Rng; use std::f64; /// Implements the @@ -124,8 +123,9 @@ impl std::fmt::Display for Categorical { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Categorical { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { sample_unchecked(rng, &self.cdf) } } @@ -322,7 +322,8 @@ impl Discrete for Categorical { /// Draws a sample from the categorical distribution described by `cdf` /// without doing any bounds checking -pub fn sample_unchecked(rng: &mut R, cdf: &[f64]) -> f64 { +#[cfg(feature = "rand")] +pub fn sample_unchecked(rng: &mut R, cdf: &[f64]) -> f64 { let draw = rng.gen::() * cdf.last().unwrap(); cdf.iter() .enumerate() diff --git a/src/distribution/cauchy.rs b/src/distribution/cauchy.rs index 5ba7f69f..c9fc4ae2 100644 --- a/src/distribution/cauchy.rs +++ b/src/distribution/cauchy.rs @@ -1,6 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::*; -use rand::Rng; use std::f64; /// Implements the [Cauchy](https://en.wikipedia.org/wiki/Cauchy_distribution) @@ -111,8 +110,9 @@ impl std::fmt::Display for Cauchy { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Cauchy { - fn sample(&self, r: &mut R) -> f64 { + fn sample(&self, r: &mut R) -> f64 { self.location + self.scale * (f64::consts::PI * (r.gen::() - 0.5)).tan() } } diff --git a/src/distribution/chi.rs b/src/distribution/chi.rs index 796fcd23..bcb98481 100644 --- a/src/distribution/chi.rs +++ b/src/distribution/chi.rs @@ -1,7 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; use crate::statistics::*; -use rand::Rng; use std::f64; /// Implements the [Chi](https://en.wikipedia.org/wiki/Chi_distribution) @@ -94,8 +93,9 @@ impl std::fmt::Display for Chi { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Chi { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { (0..self.freedom as i64) .fold(0.0, |acc, _| { acc + super::normal::sample_unchecked(rng, 0.0, 1.0).powf(2.0) diff --git a/src/distribution/chi_squared.rs b/src/distribution/chi_squared.rs index a847ac94..f61d6e19 100644 --- a/src/distribution/chi_squared.rs +++ b/src/distribution/chi_squared.rs @@ -1,6 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF, Gamma, GammaError}; use crate::statistics::*; -use rand::Rng; use std::f64; /// Implements the @@ -101,8 +100,9 @@ impl std::fmt::Display for ChiSquared { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for ChiSquared { - fn sample(&self, r: &mut R) -> f64 { + fn sample(&self, r: &mut R) -> f64 { ::rand::distributions::Distribution::sample(&self.g, r) } } diff --git a/src/distribution/dirac.rs b/src/distribution/dirac.rs index 18e70f9b..ec833d93 100644 --- a/src/distribution/dirac.rs +++ b/src/distribution/dirac.rs @@ -1,6 +1,5 @@ use crate::distribution::ContinuousCDF; use crate::statistics::*; -use rand::Rng; /// Implements the [Dirac Delta](https://en.wikipedia.org/wiki/Dirac_delta_function#As_a_distribution) /// distribution @@ -69,8 +68,9 @@ impl std::fmt::Display for Dirac { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Dirac { - fn sample(&self, _: &mut R) -> f64 { + fn sample(&self, _: &mut R) -> f64 { self.0 } } diff --git a/src/distribution/dirichlet.rs b/src/distribution/dirichlet.rs index 355476db..7c7c9913 100644 --- a/src/distribution/dirichlet.rs +++ b/src/distribution/dirichlet.rs @@ -2,8 +2,7 @@ use crate::distribution::Continuous; use crate::function::gamma; use crate::prec; use crate::statistics::*; -use nalgebra::{Const, Dim, Dyn, OMatrix, OVector}; -use rand::Rng; +use nalgebra::{Dim, Dyn, OMatrix, OVector}; use std::f64; /// Implements the @@ -192,16 +191,17 @@ where } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution> for Dirichlet where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { - fn sample(&self, rng: &mut R) -> OVector { + fn sample(&self, rng: &mut R) -> OVector { let mut sum = 0.0; OVector::from_iterator_generic( self.alpha.shape_generic().0, - Const::<1>, + nalgebra::Const::<1>, self.alpha.iter().map(|&a| { let sample = super::gamma::sample_unchecked(rng, a, 1.0); sum += sample; diff --git a/src/distribution/discrete_uniform.rs b/src/distribution/discrete_uniform.rs index 524bc2a3..85b26090 100644 --- a/src/distribution/discrete_uniform.rs +++ b/src/distribution/discrete_uniform.rs @@ -1,6 +1,5 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::statistics::*; -use rand::Rng; /// Implements the [Discrete /// Uniform](https://en.wikipedia.org/wiki/Discrete_uniform_distribution) @@ -75,8 +74,9 @@ impl std::fmt::Display for DiscreteUniform { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for DiscreteUniform { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { rng.gen_range(self.min..=self.max) as f64 } } diff --git a/src/distribution/empirical.rs b/src/distribution/empirical.rs index 6dc7ec71..965d8c7f 100644 --- a/src/distribution/empirical.rs +++ b/src/distribution/empirical.rs @@ -1,7 +1,6 @@ -use crate::distribution::{ContinuousCDF, Uniform}; +use crate::distribution::ContinuousCDF; use crate::statistics::*; use core::cmp::Ordering; -use rand::Rng; use std::collections::BTreeMap; #[derive(Clone, PartialEq, Debug)] @@ -176,8 +175,11 @@ impl std::fmt::Display for Empirical { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Empirical { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { + use crate::distribution::Uniform; + let uniform = Uniform::new(0.0, 1.0).unwrap(); self.__inverse_cdf(uniform.sample(rng)) } diff --git a/src/distribution/erlang.rs b/src/distribution/erlang.rs index 9b7a332c..2ad017f3 100644 --- a/src/distribution/erlang.rs +++ b/src/distribution/erlang.rs @@ -1,6 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF, Gamma, GammaError}; use crate::statistics::*; -use rand::Rng; /// Implements the [Erlang](https://en.wikipedia.org/wiki/Erlang_distribution) /// distribution @@ -83,8 +82,9 @@ impl std::fmt::Display for Erlang { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Erlang { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { ::rand::distributions::Distribution::sample(&self.g, rng) } } diff --git a/src/distribution/exponential.rs b/src/distribution/exponential.rs index ec30d1f7..9c6c21fc 100644 --- a/src/distribution/exponential.rs +++ b/src/distribution/exponential.rs @@ -1,6 +1,5 @@ -use crate::distribution::{ziggurat, Continuous, ContinuousCDF}; +use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::*; -use rand::Rng; use std::f64; /// Implements the @@ -91,8 +90,11 @@ impl std::fmt::Display for Exp { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Exp { - fn sample(&self, r: &mut R) -> f64 { + fn sample(&self, r: &mut R) -> f64 { + use crate::distribution::ziggurat; + ziggurat::sample_exp_1(r) / self.rate } } diff --git a/src/distribution/fisher_snedecor.rs b/src/distribution/fisher_snedecor.rs index 610da130..3208f98f 100644 --- a/src/distribution/fisher_snedecor.rs +++ b/src/distribution/fisher_snedecor.rs @@ -1,7 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::beta; use crate::statistics::*; -use rand::Rng; use std::f64; /// Implements the @@ -124,8 +123,9 @@ impl std::fmt::Display for FisherSnedecor { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for FisherSnedecor { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { (super::gamma::sample_unchecked(rng, self.freedom_1 / 2.0, 0.5) * self.freedom_2) / (super::gamma::sample_unchecked(rng, self.freedom_2 / 2.0, 0.5) * self.freedom_1) } diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index b6055c77..89341439 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -2,7 +2,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; use crate::prec; use crate::statistics::*; -use rand::Rng; /// Implements the [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution) /// distribution @@ -122,8 +121,9 @@ impl std::fmt::Display for Gamma { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Gamma { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { sample_unchecked(rng, self.shape, self.rate) } } @@ -400,7 +400,8 @@ impl Continuous for Gamma { /// /// ACM Transactions on Mathematical Software, Vol. 26, No. 3, September 2000, /// Pages 363-372 -pub fn sample_unchecked(rng: &mut R, shape: f64, rate: f64) -> f64 { +#[cfg(feature = "rand")] +pub fn sample_unchecked(rng: &mut R, shape: f64, rate: f64) -> f64 { let mut a = shape; let mut afix = 1.0; if shape < 1.0 { diff --git a/src/distribution/geometric.rs b/src/distribution/geometric.rs index 41e35f52..82af5eef 100644 --- a/src/distribution/geometric.rs +++ b/src/distribution/geometric.rs @@ -1,7 +1,5 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::statistics::*; -use rand::distributions::OpenClosed01; -use rand::Rng; use std::f64; /// Implements the @@ -92,8 +90,11 @@ impl std::fmt::Display for Geometric { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Geometric { - fn sample(&self, r: &mut R) -> f64 { + fn sample(&self, r: &mut R) -> f64 { + use ::rand::distributions::OpenClosed01; + if ulps_eq!(self.p, 1.0) { 1.0 } else { diff --git a/src/distribution/hypergeometric.rs b/src/distribution/hypergeometric.rs index 7da6f45a..ac39917d 100644 --- a/src/distribution/hypergeometric.rs +++ b/src/distribution/hypergeometric.rs @@ -1,7 +1,6 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::function::factorial; use crate::statistics::*; -use rand::Rng; use std::cmp; use std::f64; @@ -146,8 +145,9 @@ impl std::fmt::Display for Hypergeometric { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Hypergeometric { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { let mut population = self.population as f64; let mut successes = self.successes as f64; let mut draws = self.draws; diff --git a/src/distribution/inverse_gamma.rs b/src/distribution/inverse_gamma.rs index a36c7e17..db101fd0 100644 --- a/src/distribution/inverse_gamma.rs +++ b/src/distribution/inverse_gamma.rs @@ -1,7 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; use crate::statistics::*; -use rand::Rng; use std::f64; /// Implements the [Inverse @@ -119,8 +118,9 @@ impl std::fmt::Display for InverseGamma { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for InverseGamma { - fn sample(&self, r: &mut R) -> f64 { + fn sample(&self, r: &mut R) -> f64 { 1.0 / super::gamma::sample_unchecked(r, self.shape, self.rate) } } diff --git a/src/distribution/laplace.rs b/src/distribution/laplace.rs index 13f03a55..b54bbd9f 100644 --- a/src/distribution/laplace.rs +++ b/src/distribution/laplace.rs @@ -1,6 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::{Distribution, Max, Median, Min, Mode}; -use rand::Rng; use std::f64; /// Implements the [Laplace](https://en.wikipedia.org/wiki/Laplace_distribution) @@ -111,8 +110,9 @@ impl std::fmt::Display for Laplace { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Laplace { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { let x: f64 = rng.gen_range(-0.5..0.5); self.location - self.scale * x.signum() * (1. - 2. * x.abs()).ln() } @@ -326,7 +326,6 @@ impl Continuous for Laplace { #[cfg(test)] mod tests { use super::*; - use rand::thread_rng; use crate::testing_boiler; @@ -551,18 +550,22 @@ mod tests { test_rel_close(loc, scale, expected, reltol, inverse_cdf(0.95)); } + #[cfg(feature = "rand")] #[test] fn test_sample() { use ::rand::distributions::Distribution; + use ::rand::thread_rng; + let l = create_ok(0.1, 0.5); l.sample(&mut thread_rng()); } + #[cfg(feature = "rand")] #[test] fn test_sample_distribution() { + use ::rand::distributions::Distribution; use ::rand::rngs::StdRng; use ::rand::SeedableRng; - use rand::distributions::Distribution; // sanity check sampling let location = 0.0; diff --git a/src/distribution/log_normal.rs b/src/distribution/log_normal.rs index 49380d2b..2cf9d7cb 100644 --- a/src/distribution/log_normal.rs +++ b/src/distribution/log_normal.rs @@ -2,7 +2,6 @@ use crate::consts; use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::erf; use crate::statistics::*; -use rand::Rng; use std::f64; /// Implements the @@ -88,8 +87,9 @@ impl std::fmt::Display for LogNormal { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for LogNormal { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { super::normal::sample_unchecked(rng, self.location, self.scale).exp() } } diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index 8955ed63..e8c9574c 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -78,7 +78,9 @@ mod students_t; mod triangular; mod uniform; mod weibull; +#[cfg(feature = "rand")] mod ziggurat; +#[cfg(feature = "rand")] mod ziggurat_tables; /// The `ContinuousCDF` trait is used to specify an interface for univariate diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index 7d1b408c..d6304214 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -1,8 +1,7 @@ use crate::distribution::Discrete; use crate::function::factorial; use crate::statistics::*; -use nalgebra::{Const, DVector, Dim, Dyn, OMatrix, OVector}; -use rand::Rng; +use nalgebra::{DVector, Dim, Dyn, OMatrix, OVector}; /// Implements the /// [Multinomial](https://en.wikipedia.org/wiki/Multinomial_distribution) @@ -160,12 +159,15 @@ where } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution> for Multinomial where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { - fn sample(&self, rng: &mut R) -> OVector { + fn sample(&self, rng: &mut R) -> OVector { + use nalgebra::Const; + let p_cdf = super::categorical::prob_mass_to_cdf(self.p().as_slice()); let mut res = OVector::zeros_generic(self.p.shape_generic().0, Const::<1>); for _ in 0..self.n { diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index eb86edd3..b336d9db 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -1,9 +1,7 @@ use crate::distribution::Continuous; -use crate::distribution::Normal; use crate::statistics::{Max, MeanN, Min, Mode, VarianceN}; use crate::StatsError; use nalgebra::{Cholesky, Const, DMatrix, DVector, Dim, DimMin, Dyn, OMatrix, OVector}; -use rand::Rng; use std::f64; use std::f64::consts::{E, PI}; @@ -247,6 +245,7 @@ where } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution> for MultivariateNormal where D: Dim, @@ -264,8 +263,8 @@ where /// `Z` is a vector of normally distributed random variables, and /// `μ` is the mean vector - fn sample(&self, rng: &mut R) -> OVector { - let d = Normal::new(0., 1.).unwrap(); + fn sample(&self, rng: &mut R) -> OVector { + let d = crate::distribution::Normal::new(0., 1.).unwrap(); let z = OVector::from_distribution_generic(self.mu.shape_generic().0, Const::<1>, &d, rng); (&self.cov_chol_decomp * z) + &self.mu } diff --git a/src/distribution/multivariate_students_t.rs b/src/distribution/multivariate_students_t.rs index b75e0a88..73e8f8f2 100644 --- a/src/distribution/multivariate_students_t.rs +++ b/src/distribution/multivariate_students_t.rs @@ -1,9 +1,7 @@ use crate::distribution::Continuous; -use crate::distribution::{ChiSquared, Normal}; use crate::function::gamma; use crate::statistics::{Max, MeanN, Min, Mode, VarianceN}; use nalgebra::{Cholesky, Const, DMatrix, Dim, DimMin, Dyn, OMatrix, OVector}; -use rand::Rng; use std::f64::consts::PI; /// Implements the [Multivariate Student's t-distribution](https://en.wikipedia.org/wiki/Multivariate_t-distribution) @@ -198,6 +196,7 @@ where } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution> for MultivariateStudent where D: Dim, @@ -217,7 +216,9 @@ where /// `L` is the Cholesky decomposition of the scale matrix, /// `Z` is a vector of normally distributed random variables, and /// `μ` is the location vector - fn sample(&self, rng: &mut R) -> OVector { + fn sample(&self, rng: &mut R) -> OVector { + use crate::distribution::{ChiSquared, Normal}; + let d = Normal::new(0., 1.).unwrap(); let s = ChiSquared::new(self.freedom).unwrap(); let w = (self.freedom / s.sample(rng)).sqrt(); diff --git a/src/distribution/negative_binomial.rs b/src/distribution/negative_binomial.rs index 6ed557be..29e22eee 100644 --- a/src/distribution/negative_binomial.rs +++ b/src/distribution/negative_binomial.rs @@ -1,7 +1,6 @@ -use crate::distribution::{self, poisson, Discrete, DiscreteCDF}; +use crate::distribution::{Discrete, DiscreteCDF}; use crate::function::{beta, gamma}; use crate::statistics::*; -use rand::Rng; use std::f64; /// Implements the @@ -136,9 +135,12 @@ impl std::fmt::Display for NegativeBinomial { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for NegativeBinomial { - fn sample(&self, r: &mut R) -> u64 { - let lambda = distribution::gamma::sample_unchecked(r, self.r, (1.0 - self.p) / self.p); + fn sample(&self, r: &mut R) -> u64 { + use crate::distribution::{gamma, poisson}; + + let lambda = gamma::sample_unchecked(r, self.r, (1.0 - self.p) / self.p); poisson::sample_unchecked(r, lambda).floor() as u64 } } diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index b536c101..a264af50 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -1,8 +1,7 @@ use crate::consts; -use crate::distribution::{ziggurat, Continuous, ContinuousCDF}; +use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::erf; use crate::statistics::*; -use rand::Rng; use std::f64; /// Implements the [Normal](https://en.wikipedia.org/wiki/Normal_distribution) @@ -106,8 +105,9 @@ impl std::fmt::Display for Normal { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Normal { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { sample_unchecked(rng, self.mean, self.std_dev) } } @@ -347,8 +347,11 @@ pub fn ln_pdf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 { (-0.5 * d * d) - consts::LN_SQRT_2PI - std_dev.ln() } +#[cfg(feature = "rand")] /// draws a sample from a normal distribution using the Box-Muller algorithm -pub fn sample_unchecked(rng: &mut R, mean: f64, std_dev: f64) -> f64 { +pub fn sample_unchecked(rng: &mut R, mean: f64, std_dev: f64) -> f64 { + use crate::distribution::ziggurat; + mean + std_dev * ziggurat::sample_std_normal(rng) } diff --git a/src/distribution/pareto.rs b/src/distribution/pareto.rs index 886db43b..1de73d84 100644 --- a/src/distribution/pareto.rs +++ b/src/distribution/pareto.rs @@ -1,7 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::*; -use rand::distributions::OpenClosed01; -use rand::Rng; use std::f64; /// Implements the [Pareto](https://en.wikipedia.org/wiki/Pareto_distribution) @@ -114,8 +112,11 @@ impl std::fmt::Display for Pareto { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Pareto { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { + use rand::distributions::OpenClosed01; + // Inverse transform sampling let u: f64 = rng.sample(OpenClosed01); self.scale * u.powf(-1.0 / self.shape) diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index 33e8f8a2..b3f1ebab 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -1,7 +1,6 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::function::{factorial, gamma}; use crate::statistics::*; -use rand::Rng; use std::f64; /// Implements the [Poisson](https://en.wikipedia.org/wiki/Poisson_distribution) @@ -90,13 +89,14 @@ impl std::fmt::Display for Poisson { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Poisson { /// Generates one sample from the Poisson distribution either by /// Knuth's method if lambda < 30.0 or Rejection method PA by /// A. C. Atkinson from the Journal of the Royal Statistical Society /// Series C (Applied Statistics) Vol. 28 No. 1. (1979) pp. 29 - 35 /// otherwise - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { sample_unchecked(rng, self.lambda) } } @@ -283,7 +283,8 @@ impl Discrete for Poisson { /// A. C. Atkinson from the Journal of the Royal Statistical Society /// Series C (Applied Statistics) Vol. 28 No. 1. (1979) pp. 29 - 35 /// otherwise -pub fn sample_unchecked(rng: &mut R, lambda: f64) -> f64 { +#[cfg(feature = "rand")] +pub fn sample_unchecked(rng: &mut R, lambda: f64) -> f64 { if lambda < 30.0 { let limit = (-lambda).exp(); let mut count = 0.0; diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index cc88707f..117fbc87 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -1,7 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::{beta, gamma}; use crate::statistics::*; -use rand::Rng; use std::f64; /// Implements the [Student's @@ -143,8 +142,9 @@ impl std::fmt::Display for StudentsT { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for StudentsT { - fn sample(&self, r: &mut R) -> f64 { + fn sample(&self, r: &mut R) -> f64 { // based on method 2, section 5 in chapter 9 of L. Devroye's // "Non-Uniform Random Variate Generation" let gamma = super::gamma::sample_unchecked(r, 0.5 * self.freedom, 0.5); diff --git a/src/distribution/triangular.rs b/src/distribution/triangular.rs index eb3cb93d..1a83be2c 100644 --- a/src/distribution/triangular.rs +++ b/src/distribution/triangular.rs @@ -1,6 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::*; -use rand::Rng; use std::f64; /// Implements the @@ -112,8 +111,9 @@ impl std::fmt::Display for Triangular { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Triangular { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { sample_unchecked(rng, self.min, self.max, self.mode) } } @@ -382,7 +382,8 @@ impl Continuous for Triangular { } } -fn sample_unchecked(rng: &mut R, min: f64, max: f64, mode: f64) -> f64 { +#[cfg(feature = "rand")] +fn sample_unchecked(rng: &mut R, min: f64, max: f64, mode: f64) -> f64 { let f: f64 = rng.gen(); if f < (mode - min) / (max - min) { min + (f * (max - min) * (mode - min)).sqrt() diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index 55bd7884..3d637734 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -1,7 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::*; -use rand::distributions::Uniform as RandUniform; -use rand::Rng; use std::f64; use std::fmt::Debug; @@ -121,9 +119,10 @@ impl std::fmt::Display for Uniform { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Uniform { - fn sample(&self, rng: &mut R) -> f64 { - let d = RandUniform::new_inclusive(self.min, self.max); + fn sample(&self, rng: &mut R) -> f64 { + let d = rand::distributions::Uniform::new_inclusive(self.min, self.max); rng.sample(d) } } @@ -495,6 +494,7 @@ mod tests { test::check_continuous_distribution(&create_ok(-2.0, 15.0), -2.0, 15.0); } + #[cfg(feature = "rand")] #[test] fn test_samples_in_range() { use rand::rngs::StdRng; diff --git a/src/distribution/weibull.rs b/src/distribution/weibull.rs index 71aa30ef..aacf7bb6 100644 --- a/src/distribution/weibull.rs +++ b/src/distribution/weibull.rs @@ -2,7 +2,6 @@ use crate::consts; use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; use crate::statistics::*; -use rand::Rng; use std::f64; /// Implements the [Weibull](https://en.wikipedia.org/wiki/Weibull_distribution) @@ -121,8 +120,9 @@ impl std::fmt::Display for Weibull { } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Weibull { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { let x: f64 = rng.gen(); self.scale * (-x.ln()).powf(1.0 / self.shape) } diff --git a/src/statistics/slice_statistics.rs b/src/statistics/slice_statistics.rs index ea2f3096..b0b1d2a7 100644 --- a/src/statistics/slice_statistics.rs +++ b/src/statistics/slice_statistics.rs @@ -1,6 +1,5 @@ use crate::statistics::*; use core::ops::{Index, IndexMut}; -use rand::prelude::SliceRandom; #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)] pub struct Data(D); @@ -133,8 +132,11 @@ impl + AsRef<[f64]>> Data { } } +#[cfg(feature = "rand")] impl> ::rand::distributions::Distribution for Data { fn sample(&self, rng: &mut R) -> f64 { + use rand::prelude::SliceRandom; + *self.0.as_ref().choose(rng).unwrap() } } From 207e27db352354f4d25f5be7ddc51bdeb8db6554 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Thu, 5 Sep 2024 14:30:05 +0200 Subject: [PATCH 185/185] ci: check all feature combinations --- .github/workflows/test.yml | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 88b02134..65d901eb 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,12 +20,9 @@ jobs: with: components: clippy - - name: Run cargo clippy + - name: Run cargo clippy (default features) run: cargo clippy --all-targets - - name: Run cargo clippy without default features - run: cargo clippy --no-default-features --all-targets - fmt: runs-on: ubuntu-latest steps: @@ -65,3 +62,15 @@ jobs: - name: Test default features run: cargo test + features: + needs: [clippy, fmt] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install Rust stable + uses: dtolnay/rust-toolchain@stable + - name: Install cargo-hack + uses: taiki-e/install-action@cargo-hack + - uses: Swatinem/rust-cache@v2 + - name: Check all possible feature sets + run: cargo hack check --feature-powerset --no-dev-deps