[minor]: Update median implementation #13554

Merged
merged 7 commits into from
Nov 29, 2024
Changes from 2 commits
22 changes: 19 additions & 3 deletions datafusion/functions-aggregate/src/median.rs
@@ -30,7 +30,7 @@ use arrow::{

use arrow::array::Array;
use arrow::array::ArrowNativeTypeOp;
use arrow::datatypes::ArrowNativeType;
use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType};

use datafusion_common::{DataFusionError, HashSet, Result, ScalarValue};
use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL;
@@ -310,6 +310,18 @@ impl<T: ArrowNumericType> Accumulator for DistinctMedianAccumulator<T> {
}
}

/// Get the maximum entry in the slice.
fn slice_max<T>(array: &[T::Native]) -> Option<T::Native>
where
    T: ArrowPrimitiveType,
    T::Native: PartialOrd, // Ensure the type supports PartialOrd for comparison
{
Contributor:

Since `unwrap` is used on the result, the code could be simplified to work only on slices of length >= 1.

Contributor Author:

That makes sense to me. I updated the implementation in 0303c9 so that the function no longer returns an `Option`.

    array.iter().fold(None, |acc, val| match acc {
        Some(max) if max > *val => Some(max),
        _ => Some(*val),
    }) // Fold to find the maximum
}
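
The simplification discussed in the review comments could look roughly like the sketch below. This is not the code from commit 0303c9, which is not shown here. The name `slice_max_nonempty` and the plain generic bound `N: Copy + PartialOrd` (standing in for `T::Native`) are assumptions made for illustration.

```rust
/// Hypothetical non-`Option` variant of `slice_max`.
/// Assumption: the caller guarantees `values` is non-empty; an empty slice panics.
fn slice_max_nonempty<N: Copy + PartialOrd>(values: &[N]) -> N {
    // Seed the fold with the first element instead of `None`.
    let (first, rest) = values
        .split_first()
        .expect("slice_max_nonempty requires at least one element");
    // Same comparison as the fold in the diff: keep the running maximum
    // only when it is strictly greater than the next value.
    rest.iter()
        .fold(*first, |max, val| if max > *val { max } else { *val })
}
```

With this shape the `unwrap()` at the call site is no longer needed, at the cost of moving the non-empty requirement into the function's contract.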

fn calculate_median<T: ArrowNumericType>(
    mut values: Vec<T::Native>,
) -> Option<T::Native> {
@@ -320,8 +332,12 @@ fn calculate_median<T: ArrowNumericType>(
        None
    } else if len % 2 == 0 {
        let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp);
        let (_, low, _) = low.select_nth_unstable_by(low.len() - 1, cmp);
        let median = low.add_wrapping(*high).div_wrapping(T::Native::usize_as(2));
        // Get the maximum of `low` (the left side after bi-partitioning).
        // Since `low` contains at least one entry, this shouldn't return `None`.
        let left_max = slice_max::<T>(low).unwrap();
        let median = left_max
            .add_wrapping(*high)
            .div_wrapping(T::Native::usize_as(2));
        Some(median)
    } else {
        let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp);
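
To see the even-length branch in isolation, here is a minimal standalone sketch of the same idea on plain `i64` values. It assumes `wrapping_add` and integer `/` as stand-ins for `add_wrapping` and `div_wrapping`, and uses the standard library's `Iterator::max` in place of `slice_max`. The function name `median_i64` is hypothetical.

```rust
use std::cmp::Ordering;

/// Hypothetical `i64`-only median that mirrors the structure in the diff.
fn median_i64(mut values: Vec<i64>) -> Option<i64> {
    let cmp = |a: &i64, b: &i64| -> Ordering { a.cmp(b) };
    let len = values.len();
    if len == 0 {
        None
    } else if len % 2 == 0 {
        // After this call, `high` is the element that belongs at index len / 2
        // and every element of `low` is <= `high`.
        let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp);
        // `low` holds len / 2 >= 1 elements, so a linear scan for its maximum
        // cannot come back empty; no second selection pass is needed.
        let left_max = *low
            .iter()
            .max()
            .expect("low is non-empty when len is even and >= 2");
        Some(left_max.wrapping_add(*high) / 2)
    } else {
        let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp);
        Some(*median)
    }
}
```

For example, `median_i64(vec![4, 1, 3, 2])` partitions around the element 3, finds 2 as the maximum of the left side, and returns `Some(2)` because integer division truncates (2 + 3) / 2.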