Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure Uniform constructed on bounded interval #218

Merged
merged 3 commits into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 17 additions & 40 deletions src/distribution/uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl Uniform {
///
/// # Errors
///
/// Returns an error if `min` or `max` are `NaN`
/// Returns an error if `min` or `max` are `NaN` or unbounded
///
/// # Examples
///
Expand All @@ -44,12 +44,21 @@ impl Uniform {
///
/// result = Uniform::new(f64::NAN, f64::NAN);
/// assert!(result.is_err());
///
/// result = Uniform::new(f64::NEG_INFINITY, 1.0);
/// assert!(result.is_err());
/// ```
pub fn new(min: f64, max: f64) -> Result<Uniform> {
if min > max || min.is_nan() || max.is_nan() {
Err(StatsError::BadParams)
} else {
Ok(Uniform { min, max })
if min.is_nan() || max.is_nan() {
return Err(StatsError::BadParams);
}

match (min.is_finite(), max.is_finite(), min < max) {
(false, false, _) => Err(StatsError::ArgFinite("min and max")),
(false, true, _) => Err(StatsError::ArgFinite("min")),
(true, false, _) => Err(StatsError::ArgFinite("max")),
(true, true, false) => Err(StatsError::ArgLteArg("min", "max")),
(true, true, true) => Ok(Uniform { min, max }),
}
}
}
Expand Down Expand Up @@ -94,10 +103,6 @@ impl ContinuousCDF<f64, f64> for Uniform {
1.0
} else if x >= self.max {
0.0
} else if x.is_infinite() && self.max.is_infinite() {
0.0
} else if self.max.is_infinite() {
1.0
} else {
(self.max - x) / (self.max - self.min)
}
Expand Down Expand Up @@ -256,7 +261,7 @@ mod tests {

fn try_create(min: f64, max: f64) -> Uniform {
let n = Uniform::new(min, max);
assert!(n.is_ok());
assert!(n.is_ok(), "failed create over interval [{}, {}]", min, max);
n.unwrap()
}

Expand Down Expand Up @@ -296,19 +301,19 @@ mod tests {

#[test]
fn test_create() {
create_case(0.0, 0.0);
create_case(0.0, 0.1);
create_case(0.0, 1.0);
create_case(10.0, 10.0);
create_case(-5.0, 11.0);
create_case(-5.0, 100.0);
}

#[test]
fn test_bad_create() {
bad_create_case(0.0, 0.0);
bad_create_case(f64::NAN, 1.0);
bad_create_case(1.0, f64::NAN);
bad_create_case(f64::NAN, f64::NAN);
bad_create_case(0.0, f64::INFINITY);
bad_create_case(1.0, 0.0);
}

Expand All @@ -319,7 +324,6 @@ mod tests {
test_case(0.0, 2.0, 1.0 / 3.0, variance);
test_almost(0.1, 4.0, 1.2675, 1e-15, variance);
test_case(10.0, 11.0, 1.0 / 12.0, variance);
test_case(0.0, f64::INFINITY, f64::INFINITY, variance);
}

#[test]
Expand All @@ -330,7 +334,6 @@ mod tests {
test_almost(0.1, 4.0, 1.360976553135600743431, 1e-15, entropy);
test_case(1.0, 10.0, 2.19722457733621938279, entropy);
test_case(10.0, 11.0, 0.0, entropy);
test_case(0.0, f64::INFINITY, f64::INFINITY, entropy);
}

#[test]
Expand All @@ -341,7 +344,6 @@ mod tests {
test_case(0.1, 4.0, 0.0, skewness);
test_case(1.0, 10.0, 0.0, skewness);
test_case(10.0, 11.0, 0.0, skewness);
test_case(0.0, f64::INFINITY, 0.0, skewness);
}

#[test]
Expand All @@ -352,7 +354,6 @@ mod tests {
test_case(0.1, 4.0, 2.05, mode);
test_case(1.0, 10.0, 5.5, mode);
test_case(10.0, 11.0, 10.5, mode);
test_case(0.0, f64::INFINITY, f64::INFINITY, mode);
}

#[test]
Expand All @@ -363,15 +364,11 @@ mod tests {
test_case(0.1, 4.0, 2.05, median);
test_case(1.0, 10.0, 5.5, median);
test_case(10.0, 11.0, 10.5, median);
test_case(0.0, f64::INFINITY, f64::INFINITY, median);
}

#[test]
fn test_pdf() {
let pdf = |arg: f64| move |x: Uniform| x.pdf(arg);
test_case(0.0, 0.0, 0.0, pdf(-5.0));
test_case(0.0, 0.0, f64::INFINITY, pdf(0.0));
test_case(0.0, 0.0, 0.0, pdf(5.0));
test_case(0.0, 0.1, 0.0, pdf(-5.0));
test_case(0.0, 0.1, 10.0, pdf(0.05));
test_case(0.0, 0.1, 0.0, pdf(5.0));
Expand All @@ -386,17 +383,11 @@ mod tests {
test_case(-5.0, 100.0, 0.009523809523809523809524, pdf(-5.0));
test_case(-5.0, 100.0, 0.009523809523809523809524, pdf(0.0));
test_case(-5.0, 100.0, 0.0, pdf(101.0));
test_case(0.0, f64::INFINITY, 0.0, pdf(-5.0));
test_case(0.0, f64::INFINITY, 0.0, pdf(10.0));
test_case(0.0, f64::INFINITY, 0.0, pdf(f64::INFINITY));
}

#[test]
fn test_ln_pdf() {
let ln_pdf = |arg: f64| move |x: Uniform| x.ln_pdf(arg);
test_case(0.0, 0.0, f64::NEG_INFINITY, ln_pdf(-5.0));
test_case(0.0, 0.0, f64::INFINITY, ln_pdf(0.0));
test_case(0.0, 0.0, f64::NEG_INFINITY, ln_pdf(5.0));
test_case(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(-5.0));
test_almost(0.0, 0.1, 2.302585092994045684018, 1e-15, ln_pdf(0.05));
test_case(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(5.0));
Expand All @@ -411,38 +402,27 @@ mod tests {
test_case(-5.0, 100.0, -4.653960350157523371101, ln_pdf(-5.0));
test_case(-5.0, 100.0, -4.653960350157523371101, ln_pdf(0.0));
test_case(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(101.0));
test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-5.0));
test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(10.0));
test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(f64::INFINITY));
}

#[test]
fn test_cdf() {
let cdf = |arg: f64| move |x: Uniform| x.cdf(arg);
test_case(0.0, 0.0, 0.0, cdf(0.0));
test_case(0.0, 0.1, 0.5, cdf(0.05));
test_case(0.0, 1.0, 0.5, cdf(0.5));
test_case(0.0, 10.0, 0.1, cdf(1.0));
test_case(0.0, 10.0, 0.5, cdf(5.0));
test_case(-5.0, 100.0, 0.0, cdf(-5.0));
test_case(-5.0, 100.0, 0.04761904761904761904762, cdf(0.0));
test_case(0.0, f64::INFINITY, 0.0, cdf(10.0));
test_case(0.0, f64::INFINITY, 1.0, cdf(f64::INFINITY));
}

#[test]
fn test_inverse_cdf() {
let inverse_cdf = |arg: f64| move |x: Uniform| x.inverse_cdf(arg);
test_case(0.0, 0.0, 0.0, inverse_cdf(0.0));
test_case(0.0, 0.0, 0.0, inverse_cdf(1.0));
test_case(0.0, 0.1, 0.05, inverse_cdf(0.5));
test_case(0.0, 10.0, 5.0, inverse_cdf(0.5));
test_case(1.0, 10.0, 1.0, inverse_cdf(0.0));
test_case(1.0, 10.0, 4.0, inverse_cdf(1.0 / 3.0));
test_case(1.0, 10.0, 10.0, inverse_cdf(1.0));
test_case(f64::NEG_INFINITY, f64::INFINITY, f64::NEG_INFINITY, inverse_cdf(0.0));
test_case(0.0, f64::INFINITY, 0.0, inverse_cdf(0.0));
test_case(0.0, f64::INFINITY, f64::INFINITY, inverse_cdf(1.0));
}

#[test]
Expand All @@ -461,15 +441,12 @@ mod tests {
#[test]
fn test_sf() {
let sf = |arg: f64| move |x: Uniform| x.sf(arg);
test_case(0.0, 0.0, 1.0, sf(0.0));
test_case(0.0, 0.1, 0.5, sf(0.05));
test_case(0.0, 1.0, 0.5, sf(0.5));
test_case(0.0, 10.0, 0.9, sf(1.0));
test_case(0.0, 10.0, 0.5, sf(5.0));
test_case(-5.0, 100.0, 1.0, sf(-5.0));
test_case(-5.0, 100.0, 0.9523809523809523, sf(0.0));
test_case(0.0, f64::INFINITY, 1.0, sf(10.0));
test_case(0.0, f64::INFINITY, 0.0, sf(f64::INFINITY));
}

#[test]
Expand Down
3 changes: 3 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use std::fmt;
pub enum StatsError {
/// Generic bad input parameter error
BadParams,
/// An argument must be finite
ArgFinite(&'static str),
/// An argument should have been positive and was not
ArgMustBePositive(&'static str),
/// An argument should have been non-negative and was not
Expand Down Expand Up @@ -58,6 +60,7 @@ impl fmt::Display for StatsError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
StatsError::BadParams => write!(f, "Bad distribution parameters"),
StatsError::ArgFinite(s) => write!(f, "Argument {} must be finite", s),
StatsError::ArgMustBePositive(s) => write!(f, "Argument {} must be positive", s),
StatsError::ArgNotNegative(s) => write!(f, "Argument {} must be non-negative", s),
StatsError::ArgIntervalIncl(s, min, max) => {
Expand Down
Loading