diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index bcd01b81..11240599 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -51,6 +51,24 @@ impl Normal { Ok(Normal { mean, std_dev }) } } + + /// Constructs a new standard normal distribution with a mean of 0 + /// and a standard deviation of 1. + /// + /// + /// # Examples + /// + /// ``` + /// use statrs::distribution::Normal; + /// + /// let mut result = Normal::standard(); + /// ``` + pub fn standard() -> Normal { + Normal { + mean: 0.0, + std_dev: 1.0, + } + } } impl ::rand::distributions::Distribution for Normal { @@ -288,6 +306,15 @@ pub fn sample_unchecked(rng: &mut R, mean: f64, std_dev: f64) - mean + std_dev * ziggurat::sample_std_normal(rng) } + +impl std::default::Default for Normal { + /// Returns the standard normal distribution with a mean of 0 + /// and a standard deviation of 1. + fn default() -> Self { + Self::standard() + } +} + #[rustfmt::skip] #[cfg(all(test, feature = "nightly"))] mod tests { @@ -508,4 +535,17 @@ mod tests { test_almost(5.0, 2.0, 10.0, 1e-14, inverse_cdf(0.9937903346742238648330218954258077788721022530769078)); test_case(5.0, 2.0, f64::INFINITY, inverse_cdf(1.0)); } + + #[test] + fn test_default() { + let n = Normal::default(); + + let n_mean = n.mean().unwrap(); + let n_std = n.std_dev().unwrap(); + + // Check that the mean of the distribution is close to 0 + assert_almost_eq!(n_mean, 0.0, 1e-15); + // Check that the standard deviation of the distribution is close to 1 + assert_almost_eq!(n_std, 1.0, 1e-15); + } }