diff --git a/datafusion/optimizer/src/join_filter_pushdown.rs b/datafusion/optimizer/src/join_filter_pushdown.rs index 6fa3daf456336..7f6bcd0e0521d 100644 --- a/datafusion/optimizer/src/join_filter_pushdown.rs +++ b/datafusion/optimizer/src/join_filter_pushdown.rs @@ -17,7 +17,7 @@ //! [`JoinFilterPushdown`] pushdown join filter to scan dynamically -use datafusion_common::{tree_node::Transformed, DataFusionError}; +use datafusion_common::{tree_node::Transformed, DataFusionError, ExprSchema}; use datafusion_expr::{utils::DynamicFilterColumn, Expr, JoinType, LogicalPlan}; use crate::{optimizer::ApplyOrder, OptimizerConfig, OptimizerRule}; @@ -62,8 +62,11 @@ impl OptimizerRule for JoinFilterPushdown { for (left, right) in join.on.iter() { // Only support left to be a column if let (Expr::Column(l), Expr::Column(r)) = (left, right) { - columns.push(r.clone()); - build_side_names.push(l.name().to_owned()); + // Todo: currently only support numeric data type + if join.schema.data_type(l)?.is_numeric() { + columns.push(r.clone()); + build_side_names.push(l.name().to_owned()); + } } } diff --git a/datafusion/physical-plan/src/joins/dynamic_filters.rs b/datafusion/physical-plan/src/joins/dynamic_filters.rs index eaa1936b221e4..667c7fc7f60d1 100644 --- a/datafusion/physical-plan/src/joins/dynamic_filters.rs +++ b/datafusion/physical-plan/src/joins/dynamic_filters.rs @@ -16,17 +16,25 @@ // under the License. use arrow::array::AsArray; +use arrow::array::PrimitiveArray; +use arrow::array::{ + Decimal128Array, Decimal256Array, Float16Array, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, + UInt8Array, +}; use arrow::compute::filter_record_batch; +use arrow::compute::kernels::aggregate::{max, min}; use arrow::datatypes::DataType; use arrow::record_batch::RecordBatch; -use datafusion_common::{exec_err, DataFusionError}; -use datafusion_expr::Accumulator; +use arrow_array::ArrowNativeTypeOp; +use arrow_array::{make_array, Array}; +use datafusion_common::{exec_err, DataFusionError, ScalarValue}; use datafusion_expr::Operator; -use datafusion_functions_aggregate::min_max::{MaxAccumulator, MinAccumulator}; use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; use datafusion_physical_expr::PhysicalExpr; use hashbrown::HashSet; use parking_lot::Mutex; +use std::cmp::Ordering; use std::fmt; use std::sync::Arc; @@ -37,8 +45,7 @@ pub struct DynamicFilterInfo { } struct DynamicFilterInfoInner { - max_accumulators: Vec, - min_accumulators: Vec, + batches: Vec>>, final_expr: Option>, batch_count: usize, processed_partitions: HashSet, @@ -61,23 +68,13 @@ impl DynamicFilterInfo { data_types: Vec<&DataType>, total_batches: usize, ) -> Result { - let (max_accumulators, min_accumulators) = data_types - .into_iter() - .try_fold::<_, _, Result<_, DataFusionError>>( - (Vec::new(), Vec::new()), - |(mut max_acc, mut min_acc), data_type| { - max_acc.push(MaxAccumulator::try_new(data_type)?); - min_acc.push(MinAccumulator::try_new(data_type)?); - Ok((max_acc, min_acc)) - }, - )?; + let batches = vec![Vec::new(); columns.len()]; Ok(Self { columns, build_side_names, inner: Mutex::new(DynamicFilterInfoInner { - max_accumulators, - min_accumulators, + batches, final_expr: None, batch_count: total_batches, processed_partitions: HashSet::with_capacity(total_batches), @@ -119,10 +116,7 @@ impl DynamicFilterInfo { for (i, _) in self.columns.iter().enumerate() { let index = schema.index_of(&self.build_side_names[i])?; let column_data = &columns[index]; - inner.max_accumulators[i] - .update_batch(&[Arc::::clone(column_data)])?; - inner.min_accumulators[i] - .update_batch(&[Arc::::clone(column_data)])?; + inner.batches[i].push(Arc::::clone(column_data)); } if finalize { @@ -142,8 +136,9 @@ impl DynamicFilterInfo { Option>, DataFusionError, >>(None, |acc, (i, column)| { - let max_value = inner.max_accumulators[i].evaluate()?; - let min_value = inner.min_accumulators[i].evaluate()?; + // Compute min and max from batches[i] + let (min_value, max_value) = + compute_min_max_from_batches(&inner.batches[i])?; let max_scalar = max_value.clone(); let min_scalar = min_value.clone(); @@ -153,23 +148,15 @@ impl DynamicFilterInfo { let range_condition: Arc = Arc::new(BinaryExpr::new( Arc::new(BinaryExpr::new( - Arc::::clone( - &min_expr, - ), + Arc::::clone(&min_expr), Operator::LtEq, - Arc::::clone( - column, - ), + Arc::::clone(column), )), Operator::And, Arc::new(BinaryExpr::new( - Arc::::clone( - column, - ), + Arc::::clone(column), Operator::LtEq, - Arc::::clone( - &max_expr, - ), + Arc::::clone(&max_expr), )), )); @@ -185,6 +172,7 @@ impl DynamicFilterInfo { })?; let filter_expr = filter_expr.expect("Filter expression should be built"); + println!("final expr is {:?}", filter_expr); inner.final_expr = Some(filter_expr); Ok(()) } @@ -243,6 +231,65 @@ impl DynamicFilterInfo { } } +macro_rules! process_min_max { + ($ARRAYS:expr, $ARRAY_TYPE:ty, $SCALAR_TY:ident, $NATIVE_TYPE:ty) => {{ + let mut min_val: Option<$NATIVE_TYPE> = None; + let mut max_val: Option<$NATIVE_TYPE> = None; + + for array in $ARRAYS { + if let Some(primitive_array) = array.as_any().downcast_ref::<$ARRAY_TYPE>() { + let batch_min = min(primitive_array); + let batch_max = max(primitive_array); + + min_val = match (min_val, batch_min) { + (Some(a), Some(b)) => Some(if a.is_lt(b) { a } else { b }), + (None, Some(b)) => Some(b), + (Some(a), None) => Some(a), + (None, None) => None, + }; + + max_val = match (max_val, batch_max) { + (Some(a), Some(b)) => Some(if a.is_gt(b) { a } else { b }), + (None, Some(b)) => Some(b), + (Some(a), None) => Some(a), + (None, None) => None, + }; + } + } + Ok(( + ScalarValue::$SCALAR_TY(min_val), + ScalarValue::$SCALAR_TY(max_val), + )) + }}; +} + +/// Currently only support numeric data types so generate a range filter +fn compute_min_max_from_batches( + arrays: &[Arc], +) -> Result<(ScalarValue, ScalarValue), DataFusionError> { + if arrays.is_empty() { + return exec_err!("should not be an empty array"); + } + + let data_type = arrays[0].data_type(); + match data_type { + DataType::Int8 => process_min_max!(arrays, Int8Array, Int8, i8), + DataType::Int16 => process_min_max!(arrays, Int16Array, Int16, i16), + DataType::Int32 => process_min_max!(arrays, Int32Array, Int32, i32), + DataType::Int64 => process_min_max!(arrays, Int64Array, Int64, i64), + DataType::UInt8 => process_min_max!(arrays, UInt8Array, UInt8, u8), + DataType::UInt16 => process_min_max!(arrays, UInt16Array, UInt16, u16), + DataType::UInt32 => process_min_max!(arrays, UInt32Array, UInt32, u32), + DataType::UInt64 => process_min_max!(arrays, UInt64Array, UInt64, u64), + DataType::Float32 => process_min_max!(arrays, Float32Array, Float32, f32), + DataType::Float64 => process_min_max!(arrays, Float64Array, Float64, f64), + _ => Err(DataFusionError::NotImplemented(format!( + "Min/Max not implemented for type {}", + data_type + ))), + } +} + /// used in partition mode pub struct PartitionedDynamicFilterInfo { partition: usize,