Skip to content

Commit

Permalink
[Minor] Refactor approx_percentile (#11769)
Browse files Browse the repository at this point in the history
* Refactor approx_percentile

* Refactor approx_percentile

* Types

* Types

* Types
  • Loading branch information
Dandandan authored Aug 2, 2024
1 parent f044bc8 commit a0ad376
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 31 deletions.
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/approx_median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ impl AggregateUDFImpl for ApproxMedian {
Ok(vec![
Field::new(format_state_name(args.name, "max_size"), UInt64, false),
Field::new(format_state_name(args.name, "sum"), Float64, false),
Field::new(format_state_name(args.name, "count"), Float64, false),
Field::new(format_state_name(args.name, "count"), UInt64, false),
Field::new(format_state_name(args.name, "max"), Float64, false),
Field::new(format_state_name(args.name, "min"), Float64, false),
Field::new_list(
Expand Down
8 changes: 4 additions & 4 deletions datafusion/functions-aggregate/src/approx_percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ impl AggregateUDFImpl for ApproxPercentileCont {
),
Field::new(
format_state_name(args.name, "count"),
DataType::Float64,
DataType::UInt64,
false,
),
Field::new(
Expand Down Expand Up @@ -406,7 +406,7 @@ impl Accumulator for ApproxPercentileAccumulator {
}

fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
if self.digest.count() == 0.0 {
if self.digest.count() == 0 {
return ScalarValue::try_from(self.return_type.clone());
}
let q = self.digest.estimate_quantile(self.percentile);
Expand Down Expand Up @@ -487,8 +487,8 @@ mod tests {
ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100);

accumulator.merge_digests(&[t1]);
assert_eq!(accumulator.digest.count(), 50_000.0);
assert_eq!(accumulator.digest.count(), 50_000);
accumulator.merge_digests(&[t2]);
assert_eq!(accumulator.digest.count(), 100_000.0);
assert_eq!(accumulator.digest.count(), 100_000);
}
}
62 changes: 36 additions & 26 deletions datafusion/physical-expr-common/src/aggregate/tdigest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ macro_rules! cast_scalar_f64 {
};
}

// Cast a non-null [`ScalarValue::UInt64`] to a [`u64`], or
// panic.
macro_rules! cast_scalar_u64 {
($value:expr ) => {
match &$value {
ScalarValue::UInt64(Some(v)) => *v,
v => panic!("invalid type {:?}", v),
}
};
}

/// This trait is implemented for each type a [`TDigest`] can operate on,
/// allowing it to support both numerical rust types (obtained from
/// `PrimitiveArray` instances), and [`ScalarValue`] instances.
Expand Down Expand Up @@ -142,7 +153,7 @@ pub struct TDigest {
centroids: Vec<Centroid>,
max_size: usize,
sum: f64,
count: f64,
count: u64,
max: f64,
min: f64,
}
Expand All @@ -153,7 +164,7 @@ impl TDigest {
centroids: Vec::new(),
max_size,
sum: 0_f64,
count: 0_f64,
count: 0,
max: f64::NAN,
min: f64::NAN,
}
Expand All @@ -164,14 +175,14 @@ impl TDigest {
centroids: vec![centroid.clone()],
max_size,
sum: centroid.mean * centroid.weight,
count: 1_f64,
count: 1,
max: centroid.mean,
min: centroid.mean,
}
}

#[inline]
pub fn count(&self) -> f64 {
pub fn count(&self) -> u64 {
self.count
}

Expand Down Expand Up @@ -203,16 +214,16 @@ impl Default for TDigest {
centroids: Vec::new(),
max_size: 100,
sum: 0_f64,
count: 0_f64,
count: 0,
max: f64::NAN,
min: f64::NAN,
}
}
}

impl TDigest {
fn k_to_q(k: f64, d: f64) -> f64 {
let k_div_d = k / d;
fn k_to_q(k: u64, d: usize) -> f64 {
let k_div_d = k as f64 / d as f64;
if k_div_d >= 0.5 {
let base = 1.0 - k_div_d;
1.0 - 2.0 * base * base
Expand Down Expand Up @@ -244,12 +255,12 @@ impl TDigest {
}

let mut result = TDigest::new(self.max_size());
result.count = self.count() + (sorted_values.len() as f64);
result.count = self.count() + sorted_values.len() as u64;

let maybe_min = *sorted_values.first().unwrap();
let maybe_max = *sorted_values.last().unwrap();

if self.count() > 0.0 {
if self.count() > 0 {
result.min = self.min.min(maybe_min);
result.max = self.max.max(maybe_max);
} else {
Expand All @@ -259,10 +270,10 @@ impl TDigest {

let mut compressed: Vec<Centroid> = Vec::with_capacity(self.max_size);

let mut k_limit: f64 = 1.0;
let mut k_limit: u64 = 1;
let mut q_limit_times_count =
Self::k_to_q(k_limit, self.max_size as f64) * result.count();
k_limit += 1.0;
Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
k_limit += 1;

let mut iter_centroids = self.centroids.iter().peekable();
let mut iter_sorted_values = sorted_values.iter().peekable();
Expand Down Expand Up @@ -309,8 +320,8 @@ impl TDigest {

compressed.push(curr.clone());
q_limit_times_count =
Self::k_to_q(k_limit, self.max_size as f64) * result.count();
k_limit += 1.0;
Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
k_limit += 1;
curr = next;
}
}
Expand Down Expand Up @@ -381,16 +392,16 @@ impl TDigest {
let mut centroids: Vec<Centroid> = Vec::with_capacity(n_centroids);
let mut starts: Vec<usize> = Vec::with_capacity(digests.len());

let mut count: f64 = 0.0;
let mut count = 0;
let mut min = f64::INFINITY;
let mut max = f64::NEG_INFINITY;

let mut start: usize = 0;
for digest in digests.iter() {
starts.push(start);

let curr_count: f64 = digest.count();
if curr_count > 0.0 {
let curr_count = digest.count();
if curr_count > 0 {
min = min.min(digest.min);
max = max.max(digest.max);
count += curr_count;
Expand Down Expand Up @@ -424,8 +435,8 @@ impl TDigest {
let mut result = TDigest::new(max_size);
let mut compressed: Vec<Centroid> = Vec::with_capacity(max_size);

let mut k_limit: f64 = 1.0;
let mut q_limit_times_count = Self::k_to_q(k_limit, max_size as f64) * (count);
let mut k_limit = 1;
let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64;

let mut iter_centroids = centroids.iter_mut();
let mut curr = iter_centroids.next().unwrap();
Expand All @@ -444,8 +455,8 @@ impl TDigest {
sums_to_merge = 0_f64;
weights_to_merge = 0_f64;
compressed.push(curr.clone());
q_limit_times_count = Self::k_to_q(k_limit, max_size as f64) * (count);
k_limit += 1.0;
q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64;
k_limit += 1;
curr = centroid;
}
}
Expand All @@ -468,8 +479,7 @@ impl TDigest {
return 0.0;
}

let count_ = self.count;
let rank = q * count_;
let rank = q * self.count as f64;

let mut pos: usize;
let mut t;
Expand All @@ -479,7 +489,7 @@ impl TDigest {
}

pos = 0;
t = count_;
t = self.count as f64;

for (k, centroid) in self.centroids.iter().enumerate().rev() {
t -= centroid.weight();
Expand Down Expand Up @@ -581,7 +591,7 @@ impl TDigest {
vec![
ScalarValue::UInt64(Some(self.max_size as u64)),
ScalarValue::Float64(Some(self.sum)),
ScalarValue::Float64(Some(self.count)),
ScalarValue::UInt64(Some(self.count)),
ScalarValue::Float64(Some(self.max)),
ScalarValue::Float64(Some(self.min)),
ScalarValue::List(arr),
Expand Down Expand Up @@ -627,7 +637,7 @@ impl TDigest {
Self {
max_size,
sum: cast_scalar_f64!(state[1]),
count: cast_scalar_f64!(&state[2]),
count: cast_scalar_u64!(&state[2]),
max,
min,
centroids,
Expand Down

0 comments on commit a0ad376

Please sign in to comment.