diff --git a/rand_distr/CHANGELOG.md b/rand_distr/CHANGELOG.md index 640b8626cb9..25def45ef7f 100644 --- a/rand_distr/CHANGELOG.md +++ b/rand_distr/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased - All error types now implement `std::error::Error` (#919) - Re-exported `rand::distributions::BernoulliError` (#919) +- Add case `lambda = 0` in the parametrixation of `Exp` (#972) ## [0.2.2] - 2019-09-10 - Fix version requirement on rand lib (#847) diff --git a/rand_distr/src/exponential.rs b/rand_distr/src/exponential.rs index 87fb29f5a8e..3fe8e22fd09 100644 --- a/rand_distr/src/exponential.rs +++ b/rand_distr/src/exponential.rs @@ -76,7 +76,7 @@ impl Distribution for Exp1 { /// The exponential distribution `Exp(lambda)`. /// /// This distribution has density function: `f(x) = lambda * exp(-lambda * x)` -/// for `x > 0`. +/// for `x > 0`, when `lambda > 0`. For `lambda = 0`, all samples yield infinity. /// /// Note that [`Exp1`](crate::Exp1) is an optimised implementation for `lambda = 1`. /// @@ -98,14 +98,14 @@ pub struct Exp { /// Error type returned from `Exp::new`. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { - /// `lambda <= 0` or `nan`. + /// `lambda < 0` or `nan`. LambdaTooSmall, } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match self { - Error::LambdaTooSmall => "lambda is not positive in exponential distribution", + Error::LambdaTooSmall => "lambda is negative or NaN in exponential distribution", }) } } @@ -117,9 +117,16 @@ where Exp1: Distribution { /// Construct a new `Exp` with the given shape parameter /// `lambda`. + /// + /// # Remarks + /// + /// For custom types `N` implementing the [`Float`](crate::Float) trait, + /// the case `lambda = 0` is handled as follows: each sample corresponds + /// to a sample from an `Exp1` multiplied by `1 / 0`. Primitive types + /// yield infinity, since `1 / 0 = infinity`. #[inline] pub fn new(lambda: N) -> Result, Error> { - if !(lambda > N::from(0.0)) { + if !(lambda >= N::from(0.0)) { return Err(Error::LambdaTooSmall); } Ok(Exp { @@ -149,15 +156,20 @@ mod test { } } #[test] - #[should_panic] - fn test_exp_invalid_lambda_zero() { - Exp::new(0.0).unwrap(); + fn test_zero() { + let d = Exp::new(0.0).unwrap(); + assert_eq!(d.sample(&mut crate::test::rng(21)), std::f64::INFINITY); } #[test] #[should_panic] fn test_exp_invalid_lambda_neg() { Exp::new(-10.0).unwrap(); } + #[test] + #[should_panic] + fn test_exp_invalid_lambda_nan() { + Exp::new(std::f64::NAN).unwrap(); + } #[test] fn value_stability() {