From 8e80e7802b874d044ab1da2c84068d77527e38b8 Mon Sep 17 00:00:00 2001 From: kamille Date: Wed, 4 Sep 2024 17:05:15 +0800 Subject: [PATCH] tmp --- .../src/aggregate/groups_accumulator.rs | 59 +++++++++++++++---- .../aggregate/groups_accumulator/bool_op.rs | 2 +- 2 files changed, 49 insertions(+), 12 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index a74f0a33ac76e..be8f71178ad41 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -25,9 +25,9 @@ pub mod prim_op; use std::mem; use arrow::{ - array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray}, + array::{ArrayRef, ArrowPrimitiveType, AsArray, BooleanArray, PrimitiveArray, TimestampSecondArray}, compute, - datatypes::UInt32Type, + datatypes::{DataType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type}, }; use datafusion_common::{ arrow_datafusion_err, utils::get_arrayref_at_indices, DataFusionError, Result, @@ -36,6 +36,8 @@ use datafusion_common::{ use datafusion_expr_common::accumulator::Accumulator; use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; +use crate::aggregate::groups_accumulator::nulls::{filtered_null_mask, set_nulls}; + /// An adapter that implements [`GroupsAccumulator`] for any [`Accumulator`] /// /// While [`Accumulator`] are simpler to implement and can support @@ -236,8 +238,11 @@ impl GroupsAccumulatorAdapter { let state = &mut self.states[group_idx]; sizes_pre += state.size(); - let values_to_accumulate = - slice_and_maybe_filter(&values, opt_filter.as_ref().map(|f| f.as_boolean()), offsets)?; + let values_to_accumulate = slice_and_maybe_filter( + &values, + opt_filter.as_ref().map(|f| f.as_boolean()), + offsets, + )?; (f)(state.accumulator.as_mut(), &values_to_accumulate)?; // clear out the state so they are empty for next @@ -318,7 +323,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { result } - + // filtered_null_mask(opt_filter, &values); fn state(&mut self, emit_to: EmitTo) -> Result> { let vec_size_pre = self.states.allocated_size(); let states = emit_to.take_needed(&mut self.states); @@ -397,11 +402,8 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { mem::replace(&mut self.convert_state_buffer[row_idx], new_accumulator); // Convert row to state by applying it to the empty accumulator. - let values_to_accumulate = slice_and_maybe_filter( - &values, - opt_filter, - &[row_idx, row_idx + 1], - )?; + let values_to_accumulate = + slice_and_maybe_filter(&values, opt_filter, &[row_idx, row_idx + 1])?; convert_state.update_batch(&values_to_accumulate)?; let row_states = convert_state.state()?; @@ -472,10 +474,45 @@ pub(crate) fn slice_and_maybe_filter( sliced_arrays .iter() .map(|array| { - compute::filter(array, &filter_array).map_err(|e| arrow_datafusion_err!(e)) + compute::filter(array, &filter_array) + .map_err(|e| arrow_datafusion_err!(e)) }) .collect() } else { Ok(sliced_arrays) } } + + +macro_rules! filter_primitive_array { + ($Values:ident, $PType:ident, $Filter:expr) => {{ + let array = $Values.as_primitive::<$PType>().clone(); + let nulls = filtered_null_mask($Filter, &array); + Ok(std::sync::Arc::new(set_nulls(array, nulls))) + }} +} + +fn filter_array( + values: &ArrayRef, + filter: &BooleanArray, +) -> Result { + match values.data_type() { + DataType::Boolean => { + let values_null_buffer_filtered = filtered_null_mask(Some(filter), &values); + let (values_buf, _) = values.into_parts(); + Ok(Arc::new(BooleanArray::new(values_buf, values_null_buffer_filtered))) + }, + DataType::Int8 => filter_primitive_array!(values, Int8Type, Some(filter)), + DataType::Int16 => filter_primitive_array!(values, Int16Type, Some(filter)), + DataType::Int32 => filter_primitive_array!(values, Int32Type, Some(filter)), + DataType::Int64 => filter_primitive_array!(values, Int64Type, Some(filter)), + DataType::UInt8 => filter_primitive_array!(values, UInt8Type, Some(filter)), + DataType::UInt16 => filter_primitive_array!(values, UInt16Type, Some(filter)), + DataType::UInt32 => filter_primitive_array!(values, UInt32Type, Some(filter)), + DataType::UInt64 => filter_primitive_array!(values, UInt64Type, Some(filter)), + DataType::Float16 => filter_primitive_array!(values, Float16Type, Some(filter)), + DataType::Float32 => filter_primitive_array!(values, Float32Type, Some(filter)), + DataType::Float64 => filter_primitive_array!(values, Float64Type, Some(filter)), + other_type => compute::filter(values, predicate), + } +} diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs index 149312e5a9c0f..05a16c70b65c3 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs @@ -145,7 +145,7 @@ where } fn convert_to_state( - &self, + &mut self, values: &[ArrayRef], opt_filter: Option<&BooleanArray>, ) -> Result> {