diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index c5f7d776..c4abc985 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -102,6 +102,19 @@ impl ContinuousCDF for Uniform { (self.max - x) / (self.max - self.min) } } + + /// Finds the value of `x` where `F(p) = x` + fn inverse_cdf(&self, p: f64) -> f64 { + if !(0.0..=1.0).contains(&p) { + panic!("p must be in [0, 1], was {}", p); + } else if p == 0.0 { + self.min + } else if p == 1.0 { + self.max + } else { + (self.max - self.min) * p + self.min + } + } } impl Min for Uniform { @@ -417,6 +430,21 @@ mod tests { test_case(0.0, f64::INFINITY, 1.0, cdf(f64::INFINITY)); } + #[test] + fn test_inverse_cdf() { + let inverse_cdf = |arg: f64| move |x: Uniform| x.inverse_cdf(arg); + test_case(0.0, 0.0, 0.0, inverse_cdf(0.0)); + test_case(0.0, 0.0, 0.0, inverse_cdf(1.0)); + test_case(0.0, 0.1, 0.05, inverse_cdf(0.5)); + test_case(0.0, 10.0, 5.0, inverse_cdf(0.5)); + test_case(1.0, 10.0, 1.0, inverse_cdf(0.0)); + test_case(1.0, 10.0, 4.0, inverse_cdf(1.0 / 3.0)); + test_case(1.0, 10.0, 10.0, inverse_cdf(1.0)); + test_case(f64::NEG_INFINITY, f64::INFINITY, f64::NEG_INFINITY, inverse_cdf(0.0)); + test_case(0.0, f64::INFINITY, 0.0, inverse_cdf(0.0)); + test_case(0.0, f64::INFINITY, f64::INFINITY, inverse_cdf(1.0)); + } + #[test] fn test_cdf_lower_bound() { let cdf = |arg: f64| move |x: Uniform| x.cdf(arg);