Skip to content

Commit

Permalink
refactor: Separate Empirical::mean_and_var
Browse files Browse the repository at this point in the history
  • Loading branch information
FreezyLemon committed Nov 10, 2024
1 parent dd03d33 commit e2c54cd
Showing 1 changed file with 69 additions and 32 deletions.
101 changes: 69 additions & 32 deletions src/distribution/empirical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,15 @@ mod non_nan {
/// ```
#[derive(Clone, PartialEq, Debug)]
pub struct Empirical {
mean_and_var: Option<(f64, f64)>,
// keys are data points, values are number of data points with equal value
data: BTreeMap<NonNan<f64>, u64>,

// The following fields are only logically valid if !data.is_empty():
/// Total number of data points (== sum of all _values_ inside self.data).
/// Must be 0 iff data.is_empty()
sum: u64,
/// Running mean of the stored data points, updated incrementally whenever
/// a data point is inserted or removed.
mean: f64,
/// Un-normalised running accumulator for the variance; `variance()`
/// divides it by `sum - 1` before returning it.
var: f64,
}

impl Empirical {
Expand All @@ -84,9 +85,10 @@ impl Empirical {
#[allow(clippy::result_unit_err)]
pub fn new() -> Result<Empirical, ()> {
// Only the `Ok` variant is ever produced here: the result is an empty
// distribution whose count and running statistics all start at zero.
Ok(Empirical {
sum: 0,
mean_and_var: None,
data: BTreeMap::new(),
sum: 0,
mean: 0.0,
var: 0.0,
})
}

Expand All @@ -97,17 +99,10 @@ impl Empirical {
};

self.sum += 1;
match self.mean_and_var {
Some((mean, var)) => {
let sum = self.sum as f64;
let var = var + (sum - 1.) * (data_point - mean) * (data_point - mean) / sum;
let mean = mean + (data_point - mean) / sum;
self.mean_and_var = Some((mean, var));
}
None => {
self.mean_and_var = Some((data_point, 0.));
}
}
let sum = self.sum as f64;
self.var += (sum - 1.) * (data_point - self.mean) * (data_point - self.mean) / sum;
self.mean += (data_point - self.mean) / sum;

*self.data.entry(map_key).or_insert(0) += 1;
}

Expand All @@ -117,21 +112,25 @@ impl Empirical {
None => return,
};

if let (Some(val), Some((mean, var))) = (self.data.remove(&map_key), self.mean_and_var) {
if val == 1 && self.data.is_empty() {
self.mean_and_var = None;
self.sum = 0;
return;
};
// reset mean and var
let sum = self.sum as f64;
let mean = (sum * mean - data_point) / (sum - 1.);
let var = var - (sum - 1.) * (data_point - mean) * (data_point - mean) / sum;
self.sum -= 1;
if val != 1 {
self.data.insert(map_key, val - 1);
};
self.mean_and_var = Some((mean, var));
let val = match self.data.remove(&map_key) {
Some(v) => v,
None => return,
};

if val == 1 && self.data.is_empty() {
self.sum = 0;
self.mean = 0.0;
self.var = 0.0;
return;
};

// reset mean and var
let sum = self.sum as f64;
self.mean = (sum * self.mean - data_point) / (sum - 1.);
self.var -= (sum - 1.) * (data_point - self.mean) * (data_point - self.mean) / sum;
self.sum -= 1;
if val != 1 {
self.data.insert(map_key, val - 1);
}
}

Expand Down Expand Up @@ -232,12 +231,19 @@ impl Min<f64> for Empirical {

// NOTE(review): each method body below shows two implementations back to back
// (one based on `mean_and_var`, one based on the separate `mean`/`var` fields);
// this looks like a diff view without +/- markers — confirm against the real file.
impl Distribution<f64> for Empirical {
/// Returns the stored running mean, or `None` when no data points are held.
fn mean(&self) -> Option<f64> {
self.mean_and_var.map(|(mean, _)| mean)
if self.data.is_empty() {
None
} else {
Some(self.mean)
}
}

/// Returns the accumulated `var` divided by `sum - 1`, or `None` when no
/// data points are held.
fn variance(&self) -> Option<f64> {
self.mean_and_var
.map(|(_, var)| var / (self.sum as f64 - 1.))
if self.data.is_empty() {
None
} else {
// The divisor is `sum - 1`, so with exactly one data point it is zero
// and the result is presumably NaN rather than a usable variance.
Some(self.var / (self.sum as f64 - 1.))
}
}
}

Expand Down Expand Up @@ -272,6 +278,37 @@ impl ContinuousCDF<f64, f64> for Empirical {
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_mean() {
    // With no samples there is nothing to average, so no mean is reported.
    let empty = Empirical::from_iter(vec![]);
    assert!(empty.mean().is_none());

    // (expected mean, samples) pairs, checked in the order listed.
    let cases: Vec<(f64, Vec<f64>)> = vec![
        (4.0, vec![4.0; 100]),
        (-0.2, vec![-0.2; 100]),
        (28.5, vec![21.3, 38.4, 12.7, 41.6]),
    ];

    for (expected_mean, samples) in cases {
        let dist = Empirical::from_iter(samples);
        assert_relative_eq!(dist.mean().unwrap(), expected_mean);
    }
}

#[test]
fn test_var() {
    // With no samples there is no spread to measure, so no variance is reported.
    let empty = Empirical::from_iter(vec![]);
    assert!(empty.variance().is_none());

    // (expected variance, samples) pairs, checked in the order listed.
    let cases: Vec<(f64, Vec<f64>)> = vec![
        (0.0, vec![4.0; 100]),
        (0.0, vec![-0.2; 100]),
        (190.36666666666667, vec![21.3, 38.4, 12.7, 41.6]),
    ];

    for (expected_var, samples) in cases {
        let dist = Empirical::from_iter(samples);
        assert_relative_eq!(dist.variance().unwrap(), expected_var);
    }
}

#[test]
fn test_cdf() {
let samples = vec![5.0, 10.0];
Expand Down

0 comments on commit e2c54cd

Please sign in to comment.