From ad5219e3ca0953e381a74257c89fbb7529588ceb Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sun, 21 Apr 2024 10:08:46 -0500 Subject: [PATCH 1/3] feat: extend StatsError for finiteness --- src/error.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/error.rs b/src/error.rs index c76d8b32..ce0bb1a4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -6,6 +6,8 @@ use std::fmt; pub enum StatsError { /// Generic bad input parameter error BadParams, + /// An argument must be finite + ArgFinite(&'static str), /// An argument should have been positive and was not ArgMustBePositive(&'static str), /// An argument should have been non-negative and was not @@ -58,6 +60,7 @@ impl fmt::Display for StatsError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { StatsError::BadParams => write!(f, "Bad distribution parameters"), + StatsError::ArgFinite(s) => write!(f, "Argument {} must be finite", s), StatsError::ArgMustBePositive(s) => write!(f, "Argument {} must be positive", s), StatsError::ArgNotNegative(s) => write!(f, "Argument {} must be non-negative", s), StatsError::ArgIntervalIncl(s, min, max) => { From 726316c6fb25418b32bf65d57720361e5d1d57e0 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sun, 21 Apr 2024 10:09:05 -0500 Subject: [PATCH 2/3] feat: reject constructing Uniform of infinite support additionally removes logic handling infinite support --- src/distribution/uniform.rs | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index c4abc985..339ebaec 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -31,7 +31,7 @@ impl Uniform { /// /// # Errors /// - /// Returns an error if `min` or `max` are `NaN` + /// Returns an error if `min` or `max` are `NaN` or unbounded /// /// # Examples /// @@ -44,12 +44,21 @@ impl Uniform { /// /// result = Uniform::new(f64::NAN, f64::NAN); /// assert!(result.is_err()); + /// + /// result = Uniform::new(f64::NEG_INFINITY, 1.0); + /// assert!(result.is_err()); /// ``` pub fn new(min: f64, max: f64) -> Result { - if min > max || min.is_nan() || max.is_nan() { - Err(StatsError::BadParams) - } else { - Ok(Uniform { min, max }) + if min.is_nan() || max.is_nan() { + return Err(StatsError::BadParams); + } + + match (min.is_finite(), max.is_finite(), min < max) { + (false, false, _) => Err(StatsError::ArgFinite("min and max")), + (false, true, _) => Err(StatsError::ArgFinite("min")), + (true, false, _) => Err(StatsError::ArgFinite("max")), + (true, true, false) => Err(StatsError::ArgLteArg("min", "max")), + (true, true, true) => Ok(Uniform { min, max }), } } } @@ -94,10 +103,6 @@ impl ContinuousCDF for Uniform { 1.0 } else if x >= self.max { 0.0 - } else if x.is_infinite() && self.max.is_infinite() { - 0.0 - } else if self.max.is_infinite() { - 1.0 } else { (self.max - x) / (self.max - self.min) } From 07c1f8e854b1ef3dbe5dcf20bacef6ceb34c8902 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Sun, 21 Apr 2024 17:46:17 -0500 Subject: [PATCH 3/3] test: ensure test suite matches behavior for restricting bounded Uniform --- src/distribution/uniform.rs | 34 +++------------------------------- 1 file changed, 3 insertions(+), 31 deletions(-) diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index 339ebaec..774db333 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -261,7 +261,7 @@ mod tests { fn try_create(min: f64, max: f64) -> Uniform { let n = Uniform::new(min, max); - assert!(n.is_ok()); + assert!(n.is_ok(), "failed create over interval [{}, {}]", min, max); n.unwrap() } @@ -301,19 +301,19 @@ mod tests { #[test] fn test_create() { - create_case(0.0, 0.0); create_case(0.0, 0.1); create_case(0.0, 1.0); - create_case(10.0, 10.0); create_case(-5.0, 11.0); create_case(-5.0, 100.0); } #[test] fn test_bad_create() { + bad_create_case(0.0, 0.0); bad_create_case(f64::NAN, 1.0); bad_create_case(1.0, f64::NAN); bad_create_case(f64::NAN, f64::NAN); + bad_create_case(0.0, f64::INFINITY); bad_create_case(1.0, 0.0); } @@ -324,7 +324,6 @@ mod tests { test_case(0.0, 2.0, 1.0 / 3.0, variance); test_almost(0.1, 4.0, 1.2675, 1e-15, variance); test_case(10.0, 11.0, 1.0 / 12.0, variance); - test_case(0.0, f64::INFINITY, f64::INFINITY, variance); } #[test] @@ -335,7 +334,6 @@ mod tests { test_almost(0.1, 4.0, 1.360976553135600743431, 1e-15, entropy); test_case(1.0, 10.0, 2.19722457733621938279, entropy); test_case(10.0, 11.0, 0.0, entropy); - test_case(0.0, f64::INFINITY, f64::INFINITY, entropy); } #[test] @@ -346,7 +344,6 @@ mod tests { test_case(0.1, 4.0, 0.0, skewness); test_case(1.0, 10.0, 0.0, skewness); test_case(10.0, 11.0, 0.0, skewness); - test_case(0.0, f64::INFINITY, 0.0, skewness); } #[test] @@ -357,7 +354,6 @@ mod tests { test_case(0.1, 4.0, 2.05, mode); test_case(1.0, 10.0, 5.5, mode); test_case(10.0, 11.0, 10.5, mode); - test_case(0.0, f64::INFINITY, f64::INFINITY, mode); } #[test] @@ -368,15 +364,11 @@ mod tests { test_case(0.1, 4.0, 2.05, median); test_case(1.0, 10.0, 5.5, median); test_case(10.0, 11.0, 10.5, median); - test_case(0.0, f64::INFINITY, f64::INFINITY, median); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Uniform| x.pdf(arg); - test_case(0.0, 0.0, 0.0, pdf(-5.0)); - test_case(0.0, 0.0, f64::INFINITY, pdf(0.0)); - test_case(0.0, 0.0, 0.0, pdf(5.0)); test_case(0.0, 0.1, 0.0, pdf(-5.0)); test_case(0.0, 0.1, 10.0, pdf(0.05)); test_case(0.0, 0.1, 0.0, pdf(5.0)); @@ -391,17 +383,11 @@ mod tests { test_case(-5.0, 100.0, 0.009523809523809523809524, pdf(-5.0)); test_case(-5.0, 100.0, 0.009523809523809523809524, pdf(0.0)); test_case(-5.0, 100.0, 0.0, pdf(101.0)); - test_case(0.0, f64::INFINITY, 0.0, pdf(-5.0)); - test_case(0.0, f64::INFINITY, 0.0, pdf(10.0)); - test_case(0.0, f64::INFINITY, 0.0, pdf(f64::INFINITY)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Uniform| x.ln_pdf(arg); - test_case(0.0, 0.0, f64::NEG_INFINITY, ln_pdf(-5.0)); - test_case(0.0, 0.0, f64::INFINITY, ln_pdf(0.0)); - test_case(0.0, 0.0, f64::NEG_INFINITY, ln_pdf(5.0)); test_case(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(-5.0)); test_almost(0.0, 0.1, 2.302585092994045684018, 1e-15, ln_pdf(0.05)); test_case(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(5.0)); @@ -416,38 +402,27 @@ mod tests { test_case(-5.0, 100.0, -4.653960350157523371101, ln_pdf(-5.0)); test_case(-5.0, 100.0, -4.653960350157523371101, ln_pdf(0.0)); test_case(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(101.0)); - test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-5.0)); - test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(10.0)); - test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Uniform| x.cdf(arg); - test_case(0.0, 0.0, 0.0, cdf(0.0)); test_case(0.0, 0.1, 0.5, cdf(0.05)); test_case(0.0, 1.0, 0.5, cdf(0.5)); test_case(0.0, 10.0, 0.1, cdf(1.0)); test_case(0.0, 10.0, 0.5, cdf(5.0)); test_case(-5.0, 100.0, 0.0, cdf(-5.0)); test_case(-5.0, 100.0, 0.04761904761904761904762, cdf(0.0)); - test_case(0.0, f64::INFINITY, 0.0, cdf(10.0)); - test_case(0.0, f64::INFINITY, 1.0, cdf(f64::INFINITY)); } #[test] fn test_inverse_cdf() { let inverse_cdf = |arg: f64| move |x: Uniform| x.inverse_cdf(arg); - test_case(0.0, 0.0, 0.0, inverse_cdf(0.0)); - test_case(0.0, 0.0, 0.0, inverse_cdf(1.0)); test_case(0.0, 0.1, 0.05, inverse_cdf(0.5)); test_case(0.0, 10.0, 5.0, inverse_cdf(0.5)); test_case(1.0, 10.0, 1.0, inverse_cdf(0.0)); test_case(1.0, 10.0, 4.0, inverse_cdf(1.0 / 3.0)); test_case(1.0, 10.0, 10.0, inverse_cdf(1.0)); - test_case(f64::NEG_INFINITY, f64::INFINITY, f64::NEG_INFINITY, inverse_cdf(0.0)); - test_case(0.0, f64::INFINITY, 0.0, inverse_cdf(0.0)); - test_case(0.0, f64::INFINITY, f64::INFINITY, inverse_cdf(1.0)); } #[test] @@ -466,15 +441,12 @@ mod tests { #[test] fn test_sf() { let sf = |arg: f64| move |x: Uniform| x.sf(arg); - test_case(0.0, 0.0, 1.0, sf(0.0)); test_case(0.0, 0.1, 0.5, sf(0.05)); test_case(0.0, 1.0, 0.5, sf(0.5)); test_case(0.0, 10.0, 0.9, sf(1.0)); test_case(0.0, 10.0, 0.5, sf(5.0)); test_case(-5.0, 100.0, 1.0, sf(-5.0)); test_case(-5.0, 100.0, 0.9523809523809523, sf(0.0)); - test_case(0.0, f64::INFINITY, 1.0, sf(10.0)); - test_case(0.0, f64::INFINITY, 0.0, sf(f64::INFINITY)); } #[test]