From 0396fc41846c6fd1f8aa7be2f329d54dbfc96c2c Mon Sep 17 00:00:00 2001 From: kamille Date: Wed, 4 Sep 2024 23:39:00 +0800 Subject: [PATCH] use filter nulls to impl quick filter for some arrays. --- .../expr-common/src/groups_accumulator.rs | 2 +- .../src/aggregate/groups_accumulator.rs | 104 ++++++++++++------ .../aggregate/groups_accumulator/prim_op.rs | 2 +- datafusion/functions-aggregate/src/count.rs | 2 +- .../physical-plan/src/aggregates/row_hash.rs | 2 +- 5 files changed, 77 insertions(+), 35 deletions(-) diff --git a/datafusion/expr-common/src/groups_accumulator.rs b/datafusion/expr-common/src/groups_accumulator.rs index 834a8f8a91a3..e66b27d073d1 100644 --- a/datafusion/expr-common/src/groups_accumulator.rs +++ b/datafusion/expr-common/src/groups_accumulator.rs @@ -199,7 +199,7 @@ pub trait GroupsAccumulator: Send { /// /// [`Accumulator::state`]: crate::accumulator::Accumulator::state fn convert_to_state( - &mut self, + &self, _values: &[ArrayRef], _opt_filter: Option<&BooleanArray>, ) -> Result> { diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index a74f0a33ac76..1b92a0b87376 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -22,12 +22,16 @@ pub mod accumulate; pub mod bool_op; pub mod nulls; pub mod prim_op; -use std::mem; use arrow::{ - array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray}, + array::{ + ArrayRef, AsArray, BooleanArray, PrimitiveArray, + }, compute, - datatypes::UInt32Type, + datatypes::{ + DataType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, + Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + }, }; use datafusion_common::{ arrow_datafusion_err, utils::get_arrayref_at_indices, DataFusionError, Result, @@ -36,6 +40,8 @@ use datafusion_common::{ use 
datafusion_expr_common::accumulator::Accumulator; use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; +use crate::aggregate::groups_accumulator::nulls::{filtered_null_mask, set_nulls}; + /// An adapter that implements [`GroupsAccumulator`] for any [`Accumulator`] /// /// While [`Accumulator`] are simpler to implement and can support @@ -236,8 +242,11 @@ impl GroupsAccumulatorAdapter { let state = &mut self.states[group_idx]; sizes_pre += state.size(); - let values_to_accumulate = - slice_and_maybe_filter(&values, opt_filter.as_ref().map(|f| f.as_boolean()), offsets)?; + let values_to_accumulate = slice_and_maybe_filter( + &values, + opt_filter.as_ref().map(|f| f.as_boolean()), + offsets, + )?; (f)(state.accumulator.as_mut(), &values_to_accumulate)?; // clear out the state so they are empty for next @@ -318,7 +327,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { result } - + // filtered_null_mask(opt_filter, &values); fn state(&mut self, emit_to: EmitTo) -> Result> { let vec_size_pre = self.states.allocated_size(); let states = emit_to.take_needed(&mut self.states); @@ -379,34 +388,29 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { } fn convert_to_state( - &mut self, + &self, values: &[ArrayRef], opt_filter: Option<&BooleanArray>, ) -> Result> { let num_rows = values[0].len(); - // Make the buffer large enough. - self.ensure_convert_buffer_large_enough(num_rows)?; - - // Each row has its respective group. - let mut results = vec![Vec::with_capacity(num_rows); values.len()]; + // Each row has its respective group + let mut results = vec![]; for row_idx in 0..num_rows { - // Take the empty to update, and replace with the new empty. - let new_accumulator = (self.factory)()?; - let mut convert_state = - mem::replace(&mut self.convert_state_buffer[row_idx], new_accumulator); - - // Convert row to state by applying it to the empty accumulator. 
- let values_to_accumulate = slice_and_maybe_filter( - &values, - opt_filter, - &[row_idx, row_idx + 1], - )?; - convert_state.update_batch(&values_to_accumulate)?; - let row_states = convert_state.state()?; + // Create a fresh, empty accumulator for this row + let mut converted_accumulator = (self.factory)()?; - for (col_idx, col_state) in row_states.into_iter().enumerate() { - results[col_idx].push(col_state); + // Convert row to state by applying it to the empty accumulator. + let values_to_accumulate = + slice_and_maybe_filter(&values, opt_filter, &[row_idx, row_idx + 1])?; + converted_accumulator.update_batch(&values_to_accumulate)?; + let states = converted_accumulator.state()?; + + // Resize results to have enough columns according to the converted states + results.resize_with(states.len(), || Vec::with_capacity(num_rows)); + // Add the states to results + for (idx, state_val) in states.into_iter().enumerate() { + results[idx].push(state_val); } } @@ -467,15 +471,53 @@ pub(crate) fn slice_and_maybe_filter( .collect(); if let Some(f) = filter_opt { - let filter_array = f.slice(offset, length); + let filter = f.slice(offset, length); sliced_arrays .iter() - .map(|array| { - compute::filter(array, &filter_array).map_err(|e| arrow_datafusion_err!(e)) - }) + .map(|array| filter_array(array, &filter)) .collect() } else { Ok(sliced_arrays) } } + +macro_rules!
filter_primitive_array { + ($Values:ident, $PType:ident, $Filter:expr) => {{ + let array = $Values.as_primitive::<$PType>().clone(); + let nulls = filtered_null_mask($Filter, &array); + Ok(std::sync::Arc::new(set_nulls(array, nulls))) + }}; +} + +// TODO: document. NOTE(review): the boolean/primitive arms appear to null out filtered rows (length preserved) while the fallback `compute::filter` drops them — confirm callers tolerate both. +fn filter_array( + values: &ArrayRef, + filter: &BooleanArray, +) -> Result { + match values.data_type() { + DataType::Boolean => { + let array = values.as_boolean().clone(); + let array_null_buffer_filtered = filtered_null_mask(Some(filter), &array); + let (array_values_buf, _) = array.into_parts(); + Ok(std::sync::Arc::new(BooleanArray::new( + array_values_buf, + array_null_buffer_filtered, + ))) + } + DataType::Int8 => filter_primitive_array!(values, Int8Type, Some(filter)), + DataType::Int16 => filter_primitive_array!(values, Int16Type, Some(filter)), + DataType::Int32 => filter_primitive_array!(values, Int32Type, Some(filter)), + DataType::Int64 => filter_primitive_array!(values, Int64Type, Some(filter)), + DataType::UInt8 => filter_primitive_array!(values, UInt8Type, Some(filter)), + DataType::UInt16 => filter_primitive_array!(values, UInt16Type, Some(filter)), + DataType::UInt32 => filter_primitive_array!(values, UInt32Type, Some(filter)), + DataType::UInt64 => filter_primitive_array!(values, UInt64Type, Some(filter)), + DataType::Float16 => filter_primitive_array!(values, Float16Type, Some(filter)), + DataType::Float32 => filter_primitive_array!(values, Float32Type, Some(filter)), + DataType::Float64 => filter_primitive_array!(values, Float64Type, Some(filter)), + _ => { + compute::filter(values, filter).map_err(|e| arrow_datafusion_err!(e)) + } + } +} diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs index f7143d02dff4..8bbcf756c37c 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs +++
b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs @@ -143,7 +143,7 @@ where /// - null otherwise /// fn convert_to_state( - &mut self, + &self, values: &[ArrayRef], opt_filter: Option<&BooleanArray>, ) -> Result> { diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 53d83a209421..417e28e72a71 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -453,7 +453,7 @@ impl GroupsAccumulator for CountGroupsAccumulator { /// * `1` (for non-null, non filtered values) /// * `0` (for null values) fn convert_to_state( - &mut self, + &self, values: &[ArrayRef], opt_filter: Option<&BooleanArray>, ) -> Result> { diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 1fc33acf7ff5..d022bb007d9b 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -1065,7 +1065,7 @@ impl GroupedHashAggregateStream { let iter = self .accumulators - .iter_mut() + .iter() .zip(input_values.iter()) .zip(filter_values.iter());