Skip to content

Commit

Permalink
Implement median colmputation for Gaussian-Gaussian pair
Browse files Browse the repository at this point in the history
  • Loading branch information
Fly-Style committed Jan 31, 2025
1 parent b8068b8 commit 140fb5e
Showing 1 changed file with 47 additions and 10 deletions.
57 changes: 47 additions & 10 deletions datafusion/physical-expr/src/utils/stats_v2_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use petgraph::prelude::Bfs;
use petgraph::stable_graph::{NodeIndex, StableGraph};
use petgraph::visit::DfsPostOrder;
use petgraph::Outgoing;
use StatisticsV2::Gaussian;

#[derive(Clone, Debug)]
pub struct ExprStatisticGraphNode {
Expand Down Expand Up @@ -289,21 +290,24 @@ pub fn compute_median(
right_stat: &StatisticsV2,
) -> Result<Option<ScalarValue>> {
match (left_stat, right_stat) {
// TODO[sasha]: handle Gaussian!
(Uniform { .. }, Uniform { .. }) => {
if let (Some(l_median), Some(r_median)) =
(left_stat.median()?, right_stat.median()?)
{
match op {
Operator::Plus => Ok(Some(l_median.add_checked(r_median)?)),
Operator::Minus => Ok(Some(l_median.sub_checked(r_median)?)),
Operator::Multiply | Operator::Divide | Operator::Modulo => Ok(None),
_ => Ok(None),
}
} else {
Ok(None)
}
}
(Gaussian { mean: l_mean, .. }, Gaussian { mean: r_mean, .. }) => match op {
Operator::Plus => Ok(Some(l_mean.add_checked(r_mean)?)),
Operator::Minus => Ok(Some(l_mean.sub_checked(r_mean)?)),
_ => Ok(None),
},
// Any
_ => Ok(None),
}
Expand Down Expand Up @@ -472,14 +476,9 @@ mod tests {
// Expected test results were calculated in Wolfram Mathematica, by using
// *METHOD_NAME*[TransformedDistribution[x op y, {x ~ *DISTRIBUTION_X*[..], y ~ *DISTRIBUTION_Y*[..]}]]
#[test]
fn test_unknown_properties_uniform_uniform() -> Result<()> {
let stat_a = Uniform {
interval: Interval::make(Some(0.), Some(12.0))?,
};

let stat_b = Uniform {
interval: Interval::make(Some(12.0), Some(36.0))?,
};
fn test_calculate_unknown_properties_uniform_uniform() -> Result<()> {
let stat_a = StatisticsV2::new_uniform(Interval::make(Some(0.), Some(12.))?)?;
let stat_b = StatisticsV2::new_uniform(Interval::make(Some(12.), Some(36.))?)?;

let test_data: Vec<(Actual, Expected)> = vec![
// mean
Expand Down Expand Up @@ -524,6 +523,44 @@ mod tests {
Ok(())
}

#[test]
fn test_calculate_unknown_properties_gauss_gauss() -> Result<()> {
let stat_a = StatisticsV2::new_gaussian(
ScalarValue::from(Some(10.)),
ScalarValue::from(Some(0.0)),
)?;
let stat_b = StatisticsV2::new_gaussian(
ScalarValue::from(Some(20.)),
ScalarValue::from(Some(0.0)),
)?;

let test_data: Vec<(Actual, Expected)> = vec![
// mean
(
compute_mean(&Plus, &stat_a, &stat_b)?,
Some(Float64(Some(30.))),
),
(
compute_mean(&Minus, &stat_a, &stat_b)?,
Some(Float64(Some(-10.))),
),
// median
(
compute_median(&Plus, &stat_a, &stat_b)?,
Some(Float64(Some(30.))),
),
(
compute_median(&Minus, &stat_a, &stat_b)?,
Some(Float64(Some(-10.))),
),
];
for (actual, expected) in test_data {
assert_eq!(actual, expected);
}

Ok(())
}

/// Test for Uniform-Uniform, Uniform-Unknown, Unknown-Uniform, Unknown-Unknown pairs,
/// where range is always present.
#[test]
Expand Down

0 comments on commit 140fb5e

Please sign in to comment.