From 7fe350c9bac6b11a84687a7e8e33e6fd8e0a8c01 Mon Sep 17 00:00:00 2001
From: Benjamin Lieser
Date: Wed, 13 Nov 2024 12:42:45 +0100
Subject: [PATCH] Hypergeo fix (#1510)

---
 rand_distr/CHANGELOG.md             |  1 +
 rand_distr/src/hypergeometric.rs    | 21 +++++++++++++++++++--
 rand_distr/tests/cdf.rs             |  2 +-
 rand_distr/tests/value_stability.rs |  2 +-
 4 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/rand_distr/CHANGELOG.md b/rand_distr/CHANGELOG.md
index fc597d4776..81b62a1f28 100644
--- a/rand_distr/CHANGELOG.md
+++ b/rand_distr/CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Mark `WeightError`, `PoissonError`, `BinomialError` as `#[non_exhaustive]` (#1480).
 - Remove support for generating `isize` and `usize` values with `Standard`, `Uniform` and `Fill` and usage as a `WeightedAliasIndex` weight (#1487)
 - Limit the maximal acceptable lambda for `Poisson` to solve (#1312) (#1498)
+- Fix bug in `Hypergeometric`, this is a Value-breaking change (#1510)
 - Change parameter type of `Zipf::new`: `n` is now floating-point (#1518)
 
 ### Added
diff --git a/rand_distr/src/hypergeometric.rs b/rand_distr/src/hypergeometric.rs
index c1f1d4ef23..f446357530 100644
--- a/rand_distr/src/hypergeometric.rs
+++ b/rand_distr/src/hypergeometric.rs
@@ -131,10 +131,17 @@ fn fraction_of_products_of_factorials(numerator: (u64, u64), denominator: (u64,
     result
 }
 
+const LOGSQRT2PI: f64 = 0.91893853320467274178; // log(sqrt(2*pi))
+
 fn ln_of_factorial(v: f64) -> f64 {
     // the paper calls for ln(v!), but also wants to pass in fractions,
     // so we need to use Stirling's approximation to fill in the gaps:
-    v * v.ln() - v
+
+    // shift v by 3, because Stirling is bad for small values
+    let v_3 = v + 3.0;
+    let ln_fac = (v_3 + 0.5) * v_3.ln() - v_3 + LOGSQRT2PI + 1.0 / (12.0 * v_3);
+    // make the correction for the shift
+    ln_fac - ((v + 3.0) * (v + 2.0) * (v + 1.0)).ln()
 }
 
 impl Hypergeometric {
@@ -359,7 +366,7 @@ impl Distribution for Hypergeometric {
                         } else {
                             for i in (y as u64 + 1)..=(m as u64) {
                                 f *= i as f64 * (n2 - k + i) as f64;
-                                f /= (n1 - i) as f64 * (k - i) as f64;
+                                f /= (n1 - i + 1) as f64 * (k - i + 1) as f64;
                             }
                         }
 
@@ -441,6 +448,7 @@ impl Distribution for Hypergeometric {
 
 #[cfg(test)]
 mod test {
+
     use super::*;
 
     #[test]
@@ -494,4 +502,13 @@ mod test {
     fn hypergeometric_distributions_can_be_compared() {
         assert_eq!(Hypergeometric::new(1, 2, 3), Hypergeometric::new(1, 2, 3));
     }
+
+    #[test]
+    fn stirling() {
+        let test = [0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
+        for &v in test.iter() {
+            let ln_fac = ln_of_factorial(v);
+            assert!((special::Gamma::ln_gamma(v + 1.0).0 - ln_fac).abs() < 1e-4);
+        }
+    }
 }
diff --git a/rand_distr/tests/cdf.rs b/rand_distr/tests/cdf.rs
index 86e926af00..62286860db 100644
--- a/rand_distr/tests/cdf.rs
+++ b/rand_distr/tests/cdf.rs
@@ -598,7 +598,7 @@ fn hypergeometric() {
         (60, 10, 7),
         (70, 20, 50),
         (100, 50, 10),
-        // (100, 50, 49), // Fail case
+        (100, 50, 49),
     ];
 
     for (seed, (n, k, n_)) in parameters.into_iter().enumerate() {
diff --git a/rand_distr/tests/value_stability.rs b/rand_distr/tests/value_stability.rs
index b142741e77..330119b68f 100644
--- a/rand_distr/tests/value_stability.rs
+++ b/rand_distr/tests/value_stability.rs
@@ -105,7 +105,7 @@ fn hypergeometric_stability() {
     test_samples(
         7221,
         Hypergeometric::new(100, 50, 50).unwrap(),
-        &[23, 27, 26, 27, 22, 24, 31, 22],
+        &[23, 27, 26, 27, 22, 25, 31, 25],
     ); // Algorithm H2PE
 }
 