Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Empirical distribution #308

Merged
merged 8 commits into from
Dec 3, 2024
254 changes: 182 additions & 72 deletions src/distribution/empirical.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,45 @@
use crate::distribution::ContinuousCDF;
use crate::statistics::*;
use core::cmp::Ordering;
use std::collections::BTreeMap;
use non_nan::NonNan;
use std::collections::btree_map::{BTreeMap, Entry};
use std::convert::Infallible;
use std::ops::Bound;

#[derive(Clone, PartialEq, Debug)]
struct NonNan<T>(T);
mod non_nan {
use core::cmp::Ordering;

impl<T: PartialEq> Eq for NonNan<T> {}
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct NonNan<T>(T);

impl<T: PartialOrd> PartialOrd for NonNan<T> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
impl<T: Copy> NonNan<T> {
pub fn get(self) -> T {
self.0
}
}
}

impl<T: PartialOrd> Ord for NonNan<T> {
fn cmp(&self, other: &Self) -> Ordering {
self.0.partial_cmp(&other.0).unwrap()
impl NonNan<f64> {
#[inline]
pub fn new(x: f64) -> Option<Self> {
if x.is_nan() {
None
} else {
Some(Self(x))
}
}
}

impl<T: PartialEq> Eq for NonNan<T> {}

impl<T: PartialOrd> PartialOrd for NonNan<T> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}

Check warning on line 36 in src/distribution/empirical.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/empirical.rs#L34-L36

Added lines #L34 - L36 were not covered by tests
}

impl<T: PartialOrd> Ord for NonNan<T> {
fn cmp(&self, other: &Self) -> Ordering {
self.0.partial_cmp(&other.0).unwrap()
}
}
}

Expand All @@ -36,10 +59,15 @@
/// ```
#[derive(Clone, PartialEq, Debug)]
pub struct Empirical {
sum: f64,
mean_and_var: Option<(f64, f64)>,
// keys are data points, values are number of data points with equal value
data: BTreeMap<NonNan<f64>, u64>,

// The following fields are only logically valid if !data.is_empty():
/// Total amount of data points (== sum of all _values_ inside self.data).
/// Must be 0 iff data.is_empty()
sum: u64,
mean: f64,
var: f64,
}

impl Empirical {
Expand All @@ -56,54 +84,62 @@
/// let mut result = Empirical::new();
/// assert!(result.is_ok());
/// ```
#[allow(clippy::result_unit_err)]
pub fn new() -> Result<Empirical, ()> {
pub fn new() -> Result<Empirical, Infallible> {
Ok(Empirical {
sum: 0.,
mean_and_var: None,
data: BTreeMap::new(),
sum: 0,
mean: 0.0,
var: 0.0,
})
}

pub fn add(&mut self, data_point: f64) {
if !data_point.is_nan() {
self.sum += 1.;
match self.mean_and_var {
Some((mean, var)) => {
let sum = self.sum;
let var = var + (sum - 1.) * (data_point - mean) * (data_point - mean) / sum;
let mean = mean + (data_point - mean) / sum;
self.mean_and_var = Some((mean, var));
}
None => {
self.mean_and_var = Some((data_point, 0.));
}
}
*self.data.entry(NonNan(data_point)).or_insert(0) += 1;
}
let map_key = match NonNan::new(data_point) {
Some(valid) => valid,
None => return,
};

self.sum += 1;
let sum = self.sum as f64;
self.var += (sum - 1.) * (data_point - self.mean) * (data_point - self.mean) / sum;
self.mean += (data_point - self.mean) / sum;

self.data
.entry(map_key)
.and_modify(|c| *c += 1)
.or_insert(1);
}

pub fn remove(&mut self, data_point: f64) {
if !data_point.is_nan() {
if let (Some(val), Some((mean, var))) =
(self.data.remove(&NonNan(data_point)), self.mean_and_var)
{
if val == 1 && self.data.is_empty() {
self.mean_and_var = None;
self.sum = 0.;
return;
};
// reset mean and var
let mean = (self.sum * mean - data_point) / (self.sum - 1.);
let var =
var - (self.sum - 1.) * (data_point - mean) * (data_point - mean) / self.sum;
self.sum -= 1.;
if val != 1 {
self.data.insert(NonNan(data_point), val - 1);
};
self.mean_and_var = Some((mean, var));
let map_key = match NonNan::new(data_point) {
Some(valid) => valid,
None => return,
};

let mut entry = match self.data.entry(map_key) {
Entry::Occupied(entry) => entry,
Entry::Vacant(_) => return, // no entry found
};

if *entry.get() == 1 {
entry.remove();
if self.data.is_empty() {
// logically, this should not need special handling.
// FP math can result in mean or var being != 0.0 though.
self.sum = 0;
self.mean = 0.0;
self.var = 0.0;
return;
}
} else {
*entry.get_mut() -= 1;
}

// reset mean and var
let sum = self.sum as f64;
self.mean = (sum * self.mean - data_point) / (sum - 1.);
self.var -= (sum - 1.) * (data_point - self.mean) * (data_point - self.mean) / sum;
self.sum -= 1;
}

// Due to issues with rounding and floating-point accuracy the default
Expand Down Expand Up @@ -148,7 +184,7 @@
let mut enumerated_values = self
.data
.iter()
.flat_map(|(&NonNan(x), &count)| std::iter::repeat(x).take(count as usize));
.flat_map(|(x, &count)| std::iter::repeat(x.get()).take(count as usize));

if let Some(x) = enumerated_values.next() {
write!(f, "Empirical([{x:.3e}")?;
Expand Down Expand Up @@ -190,48 +226,50 @@
/// Panics if number of samples is zero
impl Max<f64> for Empirical {
fn max(&self) -> f64 {
self.data.keys().rev().map(|key| key.0).next().unwrap()
self.data.keys().rev().map(|key| key.get()).next().unwrap()
}
}

/// Panics if number of samples is zero
impl Min<f64> for Empirical {
fn min(&self) -> f64 {
self.data.keys().map(|key| key.0).next().unwrap()
self.data.keys().map(|key| key.get()).next().unwrap()
}
}

impl Distribution<f64> for Empirical {
fn mean(&self) -> Option<f64> {
self.mean_and_var.map(|(mean, _)| mean)
if self.data.is_empty() {
None
} else {
Some(self.mean)
}
}

fn variance(&self) -> Option<f64> {
self.mean_and_var.map(|(_, var)| var / (self.sum - 1.))
if self.data.is_empty() {
None
} else {
Some(self.var / (self.sum as f64 - 1.))
}
}
}

impl ContinuousCDF<f64, f64> for Empirical {
fn cdf(&self, x: f64) -> f64 {
let mut sum = 0;
for (keys, values) in &self.data {
if keys.0 > x {
return sum as f64 / self.sum;
}
sum += values;
}
sum as f64 / self.sum
let start = Bound::Unbounded;
let end = Bound::Included(NonNan::new(x).expect("x must not be NaN"));

let sum: u64 = self.data.range((start, end)).map(|(_, v)| v).sum();
sum as f64 / self.sum as f64
}

fn sf(&self, x: f64) -> f64 {
let mut sum = 0;
for (keys, values) in self.data.iter().rev() {
if keys.0 <= x {
return sum as f64 / self.sum;
}
sum += values;
}
sum as f64 / self.sum
let start = Bound::Excluded(NonNan::new(x).expect("x must not be NaN"));
let end = Bound::Unbounded;

let sum: u64 = self.data.range((start, end)).map(|(_, v)| v).sum();
sum as f64 / self.sum as f64
}

fn inverse_cdf(&self, p: f64) -> f64 {
Expand All @@ -242,6 +280,78 @@
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_add_nan() {
let mut empirical = Empirical::new().unwrap();

// should not panic
empirical.add(f64::NAN);
}

#[test]
fn test_remove_nan() {
let mut empirical = Empirical::new().unwrap();

empirical.add(5.2);
// should not panic
empirical.remove(f64::NAN);
}

#[test]
fn test_remove_nonexisting() {
let mut empirical = Empirical::new().unwrap();

empirical.add(5.2);
// should not panic
empirical.remove(10.0);
}

#[test]
fn test_remove_all() {
let mut empirical = Empirical::new().unwrap();

empirical.add(17.123);
empirical.add(-10.0);
empirical.add(0.0);
empirical.remove(-10.0);
empirical.remove(17.123);
empirical.remove(0.0);

assert!(empirical.mean().is_none());
assert!(empirical.variance().is_none());
}

#[test]
fn test_mean() {
fn test_mean_for_samples(expected_mean: f64, samples: Vec<f64>) {
let dist = Empirical::from_iter(samples);
assert_relative_eq!(dist.mean().unwrap(), expected_mean);
}

let dist = Empirical::from_iter(vec![]);
assert!(dist.mean().is_none());

test_mean_for_samples(4.0, vec![4.0; 100]);
test_mean_for_samples(-0.2, vec![-0.2; 100]);
test_mean_for_samples(28.5, vec![21.3, 38.4, 12.7, 41.6]);
}

#[test]
fn test_var() {
fn test_var_for_samples(expected_var: f64, samples: Vec<f64>) {
let dist = Empirical::from_iter(samples);
assert_relative_eq!(dist.variance().unwrap(), expected_var);
}

let dist = Empirical::from_iter(vec![]);
assert!(dist.variance().is_none());

test_var_for_samples(0.0, vec![4.0; 100]);
test_var_for_samples(0.0, vec![-0.2; 100]);
test_var_for_samples(190.36666666666667, vec![21.3, 38.4, 12.7, 41.6]);
}

#[test]
fn test_cdf() {
let samples = vec![5.0, 10.0];
Expand Down