Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support specific GroupsAccumulator for median #13681

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 242 additions & 2 deletions datafusion/functions-aggregate/src/median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ use std::fmt::{Debug, Formatter};
use std::mem::{size_of, size_of_val};
use std::sync::Arc;

use arrow::array::{downcast_integer, ArrowNumericType};
use arrow::array::{
downcast_integer, ArrowNumericType, BooleanArray, ListArray, PrimitiveArray,
PrimitiveBuilder,
};
use arrow::buffer::{OffsetBuffer, ScalarBuffer};
use arrow::{
array::{ArrayRef, AsArray},
datatypes::{
Expand All @@ -33,12 +37,15 @@ use arrow::array::Array;
use arrow::array::ArrowNativeTypeOp;
use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType};

use datafusion_common::{DataFusionError, HashSet, Result, ScalarValue};
use datafusion_common::{internal_err, DataFusionError, HashSet, Result, ScalarValue};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::{
function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
Documentation, Signature, Volatility,
};
use datafusion_expr::{EmitTo, GroupsAccumulator};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
use datafusion_functions_aggregate_common::utils::Hashable;
use datafusion_macros::user_doc;

Expand Down Expand Up @@ -165,6 +172,45 @@ impl AggregateUDFImpl for Median {
}
}

fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
!args.is_distinct
}

fn create_groups_accumulator(
&self,
args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
let num_args = args.exprs.len();
if num_args != 1 {
return internal_err!(
"median should only have 1 arg, but found num args:{}",
args.exprs.len()
);
}

let dt = args.exprs[0].data_type(args.schema)?;

macro_rules! helper {
($t:ty, $dt:expr) => {
Ok(Box::new(MedianGroupsAccumulator::<$t>::new($dt)))
};
}

downcast_integer! {
dt => (helper, dt),
DataType::Float16 => helper!(Float16Type, dt),
DataType::Float32 => helper!(Float32Type, dt),
DataType::Float64 => helper!(Float64Type, dt),
DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
_ => Err(DataFusionError::NotImplemented(format!(
"MedianGroupsAccumulator not supported for {} with {}",
args.name,
dt,
))),
}
}

fn aliases(&self) -> &[String] {
&[]
}
Expand Down Expand Up @@ -230,6 +276,200 @@ impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
}
}

/// The median groups accumulator accumulates the raw input values
///
/// For calculating the accurate medians of groups, we need to store all values
/// of groups before final evaluation.
/// So values in each group will be stored in a `Vec<T>`, so the total group values
/// will be actually organized as a `Vec<Vec<T>>`.
///
#[derive(Debug)]
struct MedianGroupsAccumulator<T: ArrowNumericType + Send> {
data_type: DataType,
group_values: Vec<Vec<T::Native>>,
}

impl<T: ArrowNumericType + Send> MedianGroupsAccumulator<T> {
pub fn new(data_type: DataType) -> Self {
Self {
data_type,
group_values: Vec::new(),
}
}
}

impl<T: ArrowNumericType + Send> GroupsAccumulator for MedianGroupsAccumulator<T> {
fn update_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 1, "single argument to update_batch");
let values = values[0].as_primitive::<T>();

// Push the `not nulls + not filtered` row into its group
self.group_values.resize(total_num_groups, Vec::new());
accumulate(
group_indices,
values,
opt_filter,
|group_index, new_value| {
self.group_values[group_index].push(new_value);
},
);

Ok(())
}

fn merge_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
// Since aggregate filter should be applied in partial stage, in final stage there should be no filter
_opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 1, "one argument to merge_batch");

// The merged values should be organized like as a `ListArray` which is nullable,
// but `values` in it is `non-nullable`(`values` with nulls usually generated
// from `convert_to_state`).
//
// Following is the possible and impossible input `values`:
//
// # Possible values
// ```text
// group 0: [1, 2, 3]
// group 1: null (list array is nullable)
// group 2: [6, 7, 8]
// ...
// group n: [...]
// ```
//
// # Impossible values
// ```text
// group x: [1, 2, null] (values in list array is non-nullable)
// ```
//
let input_group_values = values[0].as_list::<i32>();

// Ensure group values big enough
self.group_values.resize(total_num_groups, Vec::new());

// Extend values to related groups
// TODO: avoid using iterator of the `ListArray`, this will lead to
// many calls of `slice` of its `values` array, and `slice` is not
// so efficient.
group_indices
.iter()
.zip(input_group_values.iter())
.for_each(|(&group_index, values_opt)| {
if let Some(values) = values_opt {
let values = values.as_primitive::<T>();
self.group_values[group_index].extend(values.values().iter());
}
});

Ok(())
}

fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
// Emit values
let emit_group_values = emit_to.take_needed(&mut self.group_values);

// Build offsets
let mut offsets = Vec::with_capacity(self.group_values.len() + 1);
offsets.push(0);
let mut cur_len = 0;
for group_value in &emit_group_values {
cur_len += group_value.len() as i32;
offsets.push(cur_len);
}
let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));

// Build inner array
let flatten_group_values =
emit_group_values.into_iter().flatten().collect::<Vec<_>>();
let group_values_array =
PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None);

// Build the result list array
let result_list_array = ListArray::new(
Arc::new(Field::new_list_field(self.data_type.clone(), false)),
offsets,
Arc::new(group_values_array),
None,
);

Ok(vec![Arc::new(result_list_array)])
}

fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
// Emit values
let emit_group_values = emit_to.take_needed(&mut self.group_values);

// Calculate median for each group
let mut evaluate_result_builder = PrimitiveBuilder::<T>::new();
for values in emit_group_values {
let median = calculate_median::<T>(values);
evaluate_result_builder.append_option(median);
}

Ok(Arc::new(evaluate_result_builder.finish()))
}

fn convert_to_state(
&self,
values: &[ArrayRef],
opt_filter: Option<&BooleanArray>,
) -> Result<Vec<ArrayRef>> {
assert_eq!(values.len(), 1, "one argument to merge_batch");

let input_array = values[0].as_primitive::<T>();

// Directly convert the input array to states, each row will be
// seen as a respective group.
// For detail, the `input_array` will be converted to a `ListArray`.
// And if row is `not null + not filtered`, it will be converted to a list
// with only one element; otherwise, this row in `ListArray` will be set
// to null.

// Reuse values buffer in `input_array` to build `values` in `ListArray`
let values = PrimitiveArray::<T>::new(input_array.values().clone(), None);

// `offsets` in `ListArray`, each row as a list element
let offsets = (0..=input_array.len() as i32)
.into_iter()
.collect::<Vec<_>>();
let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));

// `nulls` for converted `ListArray`
let nulls = filtered_null_mask(opt_filter, input_array);

let converted_list_array = Arc::new(ListArray::new(
Arc::new(Field::new_list_field(self.data_type.clone(), false)),
offsets,
Arc::new(values),
nulls,
));

Ok(vec![converted_list_array])
}

fn supports_convert_to_state(&self) -> bool {
true
}

fn size(&self) -> usize {
self.group_values
.iter()
.map(|values| values.capacity() * size_of::<T>())
.sum::<usize>()
}
}

/// The distinct median accumulator accumulates the raw input values
/// as `ScalarValue`s
///
Expand Down
Loading