Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace [Discrete]Distribution trait for traits for moments and entropy separately #304

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,8 @@ pub const TWO_SQRT_E_OVER_PI: f64 = 1.860382734205265717336249247266663112059421
pub const EULER_MASCHERONI: f64 =
0.5772156649015328606065120900824024310421593359399235988057672348849;

/// Constant value for `zeta(3)`
pub const ZETA_3: f64 = 1.2020569031595942853997381615114499907649862923404988817922715553;

/// Targeted accuracy instantiated over `f64`
pub const ACC: f64 = 10e-11;
130 changes: 70 additions & 60 deletions src/distribution/bernoulli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@ use crate::statistics::*;
/// # Examples
///
/// ```
/// use statrs::distribution::{Bernoulli, Discrete};
/// use statrs::statistics::Distribution;
/// use statrs::distribution::{Bernoulli, BinomialError, Discrete};
/// use statrs::statistics::*;
///
/// let n = Bernoulli::new(0.5).unwrap();
/// assert_eq!(n.mean().unwrap(), 0.5);
/// let n = Bernoulli::new(0.5)?;
/// assert_eq!(n.mean(), 0.5);
/// assert_eq!(n.pmf(0), 0.5);
/// assert_eq!(n.pmf(1), 0.5);
/// # Ok::<(), BinomialError>(())
/// ```
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct Bernoulli {
Expand Down Expand Up @@ -53,10 +54,11 @@ impl Bernoulli {
/// # Examples
///
/// ```
/// use statrs::distribution::Bernoulli;
/// use statrs::distribution::{Bernoulli, BinomialError};
///
/// let n = Bernoulli::new(0.5).unwrap();
/// let n = Bernoulli::new(0.5)?;
/// assert_eq!(n.p(), 0.5);
/// # Ok::<(), BinomialError>(())
/// ```
pub fn p(&self) -> f64 {
self.b.p()
Expand All @@ -68,10 +70,11 @@ impl Bernoulli {
/// # Examples
///
/// ```
/// use statrs::distribution::Bernoulli;
/// use statrs::distribution::{Bernoulli, BinomialError};
///
/// let n = Bernoulli::new(0.5).unwrap();
/// let n = Bernoulli::new(0.5)?;
/// assert_eq!(n.n(), 1);
/// # Ok::<(), BinomialError>(())
/// ```
pub fn n(&self) -> u64 {
1
Expand Down Expand Up @@ -160,58 +163,6 @@ impl Max<u64> for Bernoulli {
}
}

impl Distribution<f64> for Bernoulli {
/// Returns the mean of the bernoulli
/// distribution
///
/// # Formula
///
/// ```text
/// p
/// ```
fn mean(&self) -> Option<f64> {
self.b.mean()
}

/// Returns the variance of the bernoulli
/// distribution
///
/// # Formula
///
/// ```text
/// p * (1 - p)
/// ```
fn variance(&self) -> Option<f64> {
self.b.variance()
}

/// Returns the entropy of the bernoulli
/// distribution
///
/// # Formula
///
/// ```text
/// q = (1 - p)
/// -q * ln(q) - p * ln(p)
/// ```
fn entropy(&self) -> Option<f64> {
self.b.entropy()
}

/// Returns the skewness of the bernoulli
/// distribution
///
/// # Formula
///
/// ```text
/// q = (1 - p)
/// (1 - 2p) / sqrt(p * q)
/// ```
fn skewness(&self) -> Option<f64> {
self.b.skewness()
}
}

impl Median<f64> for Bernoulli {
/// Returns the median of the bernoulli
/// distribution
Expand Down Expand Up @@ -270,6 +221,65 @@ impl Discrete<u64, f64> for Bernoulli {
}
}

/// returns the mean of a bernoulli variable
///
/// this is also it's probability parameter

impl Mean for Bernoulli {
type Mu = f64;
fn mean(&self) -> Self::Mu {
self.p()
}
}

/// returns the variance of a bernoulli variable
///
/// # Formula
/// ```text
/// p (1 - p)
/// ```

impl Variance for Bernoulli {
type Var = f64;
fn variance(&self) -> Self::Var {
let p = self.p();
p.mul_add(-p, p)
}
}

/// Returns the skewness of a bernoulli variable
///
/// # Formula
///
/// ```text
/// (1 - 2p) / sqrt(p * (1 - p)))
/// ```

impl Skewness for Bernoulli {
type Skew = f64;
fn skewness(&self) -> Self::Skew {
let d = 0.5 - self.p();
2.0 * d / (0.25 - d * d).sqrt()
}
}

/// Returns the excess kurtosis of a bernoulli variable
///
/// # Formula
///
/// ```text
/// pq^-1 - 6; pq = p (1-p)
/// ```

impl ExcessKurtosis for Bernoulli {
type Kurt = f64;
fn excess_kurtosis(&self) -> Self::Kurt {
let p = self.p();
let pq = p.mul_add(-p, p);
pq.recip() - 6.0
}
}

#[rustfmt::skip]
#[cfg(test)]
mod testing {
Expand Down
142 changes: 79 additions & 63 deletions src/distribution/beta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ use crate::statistics::*;
/// # Examples
///
/// ```
/// use statrs::distribution::{Beta, Continuous};
/// use statrs::distribution::{Beta, Continuous, BetaError};
/// use statrs::statistics::*;
/// use statrs::prec;
///
/// let n = Beta::new(2.0, 2.0).unwrap();
/// assert_eq!(n.mean().unwrap(), 0.5);
/// let n = Beta::new(2.0, 2.0)?;
/// assert_eq!(n.mean(), 0.5);
/// assert!(prec::almost_eq(n.pdf(0.5), 1.5, 1e-14));
/// # Ok::<(), BetaError>(())
/// ```
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct Beta {
Expand Down Expand Up @@ -82,10 +83,11 @@ impl Beta {
/// # Examples
///
/// ```
/// use statrs::distribution::Beta;
/// use statrs::distribution::{Beta, BetaError};
///
/// let n = Beta::new(1.0, 2.0).unwrap();
/// let n = Beta::new(1.0, 2.0)?;
/// assert_eq!(n.shape_a(), 1.0);
/// # Ok::<(), BetaError>(())
/// ```
pub fn shape_a(&self) -> f64 {
self.shape_a
Expand All @@ -96,10 +98,11 @@ impl Beta {
/// # Examples
///
/// ```
/// use statrs::distribution::Beta;
/// use statrs::distribution::{Beta, BetaError};
///
/// let n = Beta::new(1.0, 2.0).unwrap();
/// let n = Beta::new(1.0, 2.0)?;
/// assert_eq!(n.shape_b(), 2.0);
/// # Ok::<(), BetaError>(())
/// ```
pub fn shape_b(&self) -> f64 {
self.shape_b
Expand Down Expand Up @@ -221,38 +224,70 @@ impl Max<f64> for Beta {
}
}

impl Distribution<f64> for Beta {
/// Returns the mean of the beta distribution.
///
/// # Formula
///
/// ```text
/// α / (α + β)
/// ```
///
/// where `α` is shapeA and `β` is shapeB.
fn mean(&self) -> Option<f64> {
Some(self.shape_a / (self.shape_a + self.shape_b))
/// Returns the mean of the beta distribution.
///
/// # Formula
///
/// ```text
/// α / (α + β)
/// ```
///
/// where `α` is shapeA and `β` is shapeB.
impl Mean for Beta {
type Mu = f64;
fn mean(&self) -> Self::Mu {
self.shape_a / (self.shape_a + self.shape_b)
}
}

/// Returns the variance of the beta distribution.
///
/// # Formula
///
/// ```text
/// (α * β) / ((α + β)^2 * (α + β + 1))
/// ```
///
/// where `α` is shapeA and `β` is shapeB.
fn variance(&self) -> Option<f64> {
Some(
self.shape_a * self.shape_b
/ ((self.shape_a + self.shape_b)
* (self.shape_a + self.shape_b)
* (self.shape_a + self.shape_b + 1.0)),
)
/// Returns the variance of the beta distribution.
///
/// # Formula
///
/// ```text
/// (α * β) / ((α + β)^2 * (α + β + 1))
/// ```
///
/// where `α` is shapeA and `β` is shapeB.
impl Variance for Beta {
type Var = f64;
fn variance(&self) -> Self::Var {
self.shape_a * self.shape_b
/ ((self.shape_a + self.shape_b)
* (self.shape_a + self.shape_b)
* (self.shape_a + self.shape_b + 1.0))
}
}

/// Returns the skewness of the Beta distribution.
///
/// # Formula
///
/// ```text
/// 2(β - α) * sqrt(α + β + 1) / ((α + β + 2) * sqrt(αβ))
/// ```
///
/// where `α` is shapeA and `β` is shapeB.
impl Skewness for Beta {
type Skew = f64;
fn skewness(&self) -> Self::Skew {
2.0 * (self.shape_b - self.shape_a) * (self.shape_a + self.shape_b + 1.0).sqrt()
/ ((self.shape_a + self.shape_b + 2.0) * (self.shape_a * self.shape_b).sqrt())
}
}

/// Returns the excess kurtosis of the Beta distribution
impl ExcessKurtosis for Beta {
type Kurt = f64;
fn excess_kurtosis(&self) -> Self::Kurt {
let a = self.shape_a;
let b = self.shape_b;
let numer = 6. * ((a - b).powi(2) * (a + b + 1.) - a * b * (a + b + 2.));
let denom = a * b * (a + b + 2.) * (a + b + 3.);
numer / denom
}
}
impl Entropy<f64> for Beta {
/// Returns the entropy of the beta distribution.
///
/// # Formula
Expand All @@ -262,32 +297,13 @@ impl Distribution<f64> for Beta {
/// ```
///
/// where `α` is shapeA, `β` is shapeB and `ψ` is the digamma function.
fn entropy(&self) -> Option<f64> {
Some(
beta::ln_beta(self.shape_a, self.shape_b)
- (self.shape_a - 1.0) * gamma::digamma(self.shape_a)
- (self.shape_b - 1.0) * gamma::digamma(self.shape_b)
+ (self.shape_a + self.shape_b - 2.0) * gamma::digamma(self.shape_a + self.shape_b),
)
}

/// Returns the skewness of the Beta distribution.
///
/// # Formula
///
/// ```text
/// 2(β - α) * sqrt(α + β + 1) / ((α + β + 2) * sqrt(αβ))
/// ```
///
/// where `α` is shapeA and `β` is shapeB.
fn skewness(&self) -> Option<f64> {
Some(
2.0 * (self.shape_b - self.shape_a) * (self.shape_a + self.shape_b + 1.0).sqrt()
/ ((self.shape_a + self.shape_b + 2.0) * (self.shape_a * self.shape_b).sqrt()),
)
fn entropy(&self) -> f64 {
beta::ln_beta(self.shape_a, self.shape_b)
- (self.shape_a - 1.0) * gamma::digamma(self.shape_a)
- (self.shape_b - 1.0) * gamma::digamma(self.shape_b)
+ (self.shape_a + self.shape_b - 2.0) * gamma::digamma(self.shape_a + self.shape_b)
}
}

impl Mode<Option<f64>> for Beta {
/// Returns the mode of the Beta distribution. Returns `None` if `α <= 1`
/// or `β <= 1`.
Expand Down Expand Up @@ -423,7 +439,7 @@ mod tests {

#[test]
fn test_mean() {
let f = |x: Beta| x.mean().unwrap();
let f = |x: Beta| x.mean();
let test = [
((1.0, 1.0), 0.5),
((9.0, 1.0), 0.9),
Expand All @@ -436,7 +452,7 @@ mod tests {

#[test]
fn test_variance() {
let f = |x: Beta| x.variance().unwrap();
let f = |x: Beta| x.variance();
let test = [
((1.0, 1.0), 1.0 / 12.0),
((9.0, 1.0), 9.0 / 1100.0),
Expand All @@ -449,7 +465,7 @@ mod tests {

#[test]
fn test_entropy() {
let f = |x: Beta| x.entropy().unwrap();
let f = |x: Beta| x.entropy();
let test = [
((9.0, 1.0), -1.3083356884473304939016015),
((5.0, 100.0), -2.52016231876027436794592),
Expand All @@ -462,7 +478,7 @@ mod tests {

#[test]
fn test_skewness() {
let skewness = |x: Beta| x.skewness().unwrap();
let skewness = |x: Beta| x.skewness();
test_relative(1.0, 1.0, 0.0, skewness);
test_relative(9.0, 1.0, -1.4740554623801777107177478829, skewness);
test_relative(5.0, 100.0, 0.817594109275534303545831591, skewness);
Expand Down
Loading
Loading