diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 5b8792298533..bcffb19b7559 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::cmp::Ordering; use std::fmt::{Debug, Formatter}; use std::mem::{size_of, size_of_val}; use std::sync::{Arc, OnceLock}; @@ -30,7 +31,7 @@ use arrow::{ use arrow::array::Array; use arrow::array::ArrowNativeTypeOp; -use arrow::datatypes::ArrowNativeType; +use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType}; use datafusion_common::{DataFusionError, HashSet, Result, ScalarValue}; use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; @@ -310,6 +311,21 @@ impl Accumulator for DistinctMedianAccumulator { } } +/// Get maximum entry in the slice, +fn slice_max(array: &[T::Native]) -> T::Native +where + T: ArrowPrimitiveType, + T::Native: PartialOrd, // Ensure the type supports PartialOrd for comparison +{ + // Make sure that, array is not empty. + debug_assert!(!array.is_empty()); + // `.unwrap()` is safe here as the array is supposed to be non-empty + *array + .iter() + .max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Less)) + .unwrap() +} + fn calculate_median( mut values: Vec, ) -> Option { @@ -320,8 +336,11 @@ fn calculate_median( None } else if len % 2 == 0 { let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp); - let (_, low, _) = low.select_nth_unstable_by(low.len() - 1, cmp); - let median = low.add_wrapping(*high).div_wrapping(T::Native::usize_as(2)); + // Get the maximum of the low (left side after bi-partitioning) + let left_max = slice_max::(low); + let median = left_max + .add_wrapping(*high) + .div_wrapping(T::Native::usize_as(2)); Some(median) } else { let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp);