Skip to content

Commit

Permalink
Hypergeo fix (#1510)
Browse files Browse the repository at this point in the history
  • Loading branch information
benjamin-lieser authored Nov 13, 2024
1 parent ad67294 commit 7fe350c
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 4 deletions.
1 change: 1 addition & 0 deletions rand_distr/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Mark `WeightError`, `PoissonError`, `BinomialError` as `#[non_exhaustive]` (#1480).
- Remove support for generating `isize` and `usize` values with `Standard`, `Uniform` and `Fill` and usage as a `WeightedAliasIndex` weight (#1487)
- Limit the maximal acceptable lambda for `Poisson` to solve (#1312) (#1498)
- Fix bug in `Hypergeometric`, this is a Value-breaking change (#1510)
- Change parameter type of `Zipf::new`: `n` is now floating-point (#1518)

### Added
Expand Down
21 changes: 19 additions & 2 deletions rand_distr/src/hypergeometric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,17 @@ fn fraction_of_products_of_factorials(numerator: (u64, u64), denominator: (u64,
result
}

const LOGSQRT2PI: f64 = 0.91893853320467274178; // log(sqrt(2*pi))

fn ln_of_factorial(v: f64) -> f64 {
// the paper calls for ln(v!), but also wants to pass in fractions,
// so we need to use Stirling's approximation to fill in the gaps:
v * v.ln() - v

// shift v by 3, because Stirling is bad for small values
let v_3 = v + 3.0;
let ln_fac = (v_3 + 0.5) * v_3.ln() - v_3 + LOGSQRT2PI + 1.0 / (12.0 * v_3);
// make the correction for the shift
ln_fac - ((v + 3.0) * (v + 2.0) * (v + 1.0)).ln()
}

impl Hypergeometric {
Expand Down Expand Up @@ -359,7 +366,7 @@ impl Distribution<u64> for Hypergeometric {
} else {
for i in (y as u64 + 1)..=(m as u64) {
f *= i as f64 * (n2 - k + i) as f64;
f /= (n1 - i) as f64 * (k - i) as f64;
f /= (n1 - i + 1) as f64 * (k - i + 1) as f64;
}
}

Expand Down Expand Up @@ -441,6 +448,7 @@ impl Distribution<u64> for Hypergeometric {

#[cfg(test)]
mod test {

use super::*;

#[test]
Expand Down Expand Up @@ -494,4 +502,13 @@ mod test {
fn hypergeometric_distributions_can_be_compared() {
assert_eq!(Hypergeometric::new(1, 2, 3), Hypergeometric::new(1, 2, 3));
}

#[test]
fn stirling() {
let test = [0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
for &v in test.iter() {
let ln_fac = ln_of_factorial(v);
assert!((special::Gamma::ln_gamma(v + 1.0).0 - ln_fac).abs() < 1e-4);
}
}
}
2 changes: 1 addition & 1 deletion rand_distr/tests/cdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ fn hypergeometric() {
(60, 10, 7),
(70, 20, 50),
(100, 50, 10),
// (100, 50, 49), // Fail case
(100, 50, 49),
];

for (seed, (n, k, n_)) in parameters.into_iter().enumerate() {
Expand Down
2 changes: 1 addition & 1 deletion rand_distr/tests/value_stability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ fn hypergeometric_stability() {
test_samples(
7221,
Hypergeometric::new(100, 50, 50).unwrap(),
&[23, 27, 26, 27, 22, 24, 31, 22],
&[23, 27, 26, 27, 22, 25, 31, 25],
); // Algorithm H2PE
}

Expand Down

0 comments on commit 7fe350c

Please sign in to comment.