Skip to content

Commit

Permalink
use batches instead of accumulator
Browse files Browse the repository at this point in the history
  • Loading branch information
Lordworms committed Oct 22, 2024
1 parent 7922441 commit 959c158
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 38 deletions.
9 changes: 6 additions & 3 deletions datafusion/optimizer/src/join_filter_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

//! [`JoinFilterPushdown`] pushdown join filter to scan dynamically
use datafusion_common::{tree_node::Transformed, DataFusionError};
use datafusion_common::{tree_node::Transformed, DataFusionError, ExprSchema};
use datafusion_expr::{utils::DynamicFilterColumn, Expr, JoinType, LogicalPlan};

use crate::{optimizer::ApplyOrder, OptimizerConfig, OptimizerRule};
Expand Down Expand Up @@ -62,8 +62,11 @@ impl OptimizerRule for JoinFilterPushdown {
for (left, right) in join.on.iter() {
// Only support left to be a column
if let (Expr::Column(l), Expr::Column(r)) = (left, right) {
columns.push(r.clone());
build_side_names.push(l.name().to_owned());
// Todo: currently only support numeric data type
if join.schema.data_type(l)?.is_numeric() {
columns.push(r.clone());
build_side_names.push(l.name().to_owned());
}
}
}

Expand Down
117 changes: 82 additions & 35 deletions datafusion/physical-plan/src/joins/dynamic_filters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,25 @@
// under the License.

use arrow::array::AsArray;
use arrow::array::PrimitiveArray;
use arrow::array::{
Decimal128Array, Decimal256Array, Float16Array, Float32Array, Float64Array,
Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array,
UInt8Array,
};
use arrow::compute::filter_record_batch;
use arrow::compute::kernels::aggregate::{max, min};
use arrow::datatypes::DataType;
use arrow::record_batch::RecordBatch;
use datafusion_common::{exec_err, DataFusionError};
use datafusion_expr::Accumulator;
use arrow_array::ArrowNativeTypeOp;
use arrow_array::{make_array, Array};
use datafusion_common::{exec_err, DataFusionError, ScalarValue};
use datafusion_expr::Operator;
use datafusion_functions_aggregate::min_max::{MaxAccumulator, MinAccumulator};
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
use datafusion_physical_expr::PhysicalExpr;
use hashbrown::HashSet;
use parking_lot::Mutex;
use std::cmp::Ordering;
use std::fmt;
use std::sync::Arc;

Expand All @@ -37,8 +45,7 @@ pub struct DynamicFilterInfo {
}

struct DynamicFilterInfoInner {
max_accumulators: Vec<MaxAccumulator>,
min_accumulators: Vec<MinAccumulator>,
batches: Vec<Vec<Arc<dyn Array>>>,
final_expr: Option<Arc<dyn PhysicalExpr>>,
batch_count: usize,
processed_partitions: HashSet<usize>,
Expand All @@ -61,23 +68,13 @@ impl DynamicFilterInfo {
data_types: Vec<&DataType>,
total_batches: usize,
) -> Result<Self, DataFusionError> {
let (max_accumulators, min_accumulators) = data_types
.into_iter()
.try_fold::<_, _, Result<_, DataFusionError>>(
(Vec::new(), Vec::new()),
|(mut max_acc, mut min_acc), data_type| {
max_acc.push(MaxAccumulator::try_new(data_type)?);
min_acc.push(MinAccumulator::try_new(data_type)?);
Ok((max_acc, min_acc))
},
)?;
let batches = vec![Vec::new(); columns.len()];

Ok(Self {
columns,
build_side_names,
inner: Mutex::new(DynamicFilterInfoInner {
max_accumulators,
min_accumulators,
batches,
final_expr: None,
batch_count: total_batches,
processed_partitions: HashSet::with_capacity(total_batches),
Expand Down Expand Up @@ -119,10 +116,7 @@ impl DynamicFilterInfo {
for (i, _) in self.columns.iter().enumerate() {
let index = schema.index_of(&self.build_side_names[i])?;
let column_data = &columns[index];
inner.max_accumulators[i]
.update_batch(&[Arc::<dyn arrow_array::Array>::clone(column_data)])?;
inner.min_accumulators[i]
.update_batch(&[Arc::<dyn arrow_array::Array>::clone(column_data)])?;
inner.batches[i].push(Arc::<dyn Array>::clone(column_data));
}

if finalize {
Expand All @@ -142,8 +136,9 @@ impl DynamicFilterInfo {
Option<Arc<dyn PhysicalExpr>>,
DataFusionError,
>>(None, |acc, (i, column)| {
let max_value = inner.max_accumulators[i].evaluate()?;
let min_value = inner.min_accumulators[i].evaluate()?;
// Compute min and max from batches[i]
let (min_value, max_value) =
compute_min_max_from_batches(&inner.batches[i])?;

let max_scalar = max_value.clone();
let min_scalar = min_value.clone();
Expand All @@ -153,23 +148,15 @@ impl DynamicFilterInfo {

let range_condition: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
Arc::new(BinaryExpr::new(
Arc::<dyn datafusion_physical_expr::PhysicalExpr>::clone(
&min_expr,
),
Arc::<dyn PhysicalExpr>::clone(&min_expr),
Operator::LtEq,
Arc::<datafusion_physical_expr::expressions::Column>::clone(
column,
),
Arc::<Column>::clone(column),
)),
Operator::And,
Arc::new(BinaryExpr::new(
Arc::<datafusion_physical_expr::expressions::Column>::clone(
column,
),
Arc::<Column>::clone(column),
Operator::LtEq,
Arc::<dyn datafusion_physical_expr::PhysicalExpr>::clone(
&max_expr,
),
Arc::<dyn PhysicalExpr>::clone(&max_expr),
)),
));

Expand All @@ -185,6 +172,7 @@ impl DynamicFilterInfo {
})?;

let filter_expr = filter_expr.expect("Filter expression should be built");
println!("final expr is {:?}", filter_expr);
inner.final_expr = Some(filter_expr);
Ok(())
}
Expand Down Expand Up @@ -243,6 +231,65 @@ impl DynamicFilterInfo {
}
}

/// Fold the per-batch `min`/`max` kernel results for one column into a single
/// running (min, max) pair and wrap them in the matching `ScalarValue` variant.
macro_rules! process_min_max {
    ($ARRAYS:expr, $ARRAY_TYPE:ty, $SCALAR_TY:ident, $NATIVE_TYPE:ty) => {{
        // Running extrema across all batches; `None` until the first
        // non-null value is observed.
        let mut running_min: Option<$NATIVE_TYPE> = None;
        let mut running_max: Option<$NATIVE_TYPE> = None;

        for array in $ARRAYS {
            // Arrays that fail the downcast are silently skipped; every batch
            // of a column shares one DataType, so this is not expected to fire.
            if let Some(typed) = array.as_any().downcast_ref::<$ARRAY_TYPE>() {
                // The aggregate kernels return `None` for empty/all-null batches,
                // in which case the running value is left untouched.
                if let Some(batch_min) = min(typed) {
                    running_min = Some(match running_min {
                        Some(current) if current.is_lt(batch_min) => current,
                        _ => batch_min,
                    });
                }
                if let Some(batch_max) = max(typed) {
                    running_max = Some(match running_max {
                        Some(current) if current.is_gt(batch_max) => current,
                        _ => batch_max,
                    });
                }
            }
        }

        Ok((
            ScalarValue::$SCALAR_TY(running_min),
            ScalarValue::$SCALAR_TY(running_max),
        ))
    }};
}

/// Currently only support numeric data types so generate a range filter
fn compute_min_max_from_batches(
arrays: &[Arc<dyn Array>],
) -> Result<(ScalarValue, ScalarValue), DataFusionError> {
if arrays.is_empty() {
return exec_err!("should not be an empty array");
}

let data_type = arrays[0].data_type();
match data_type {
DataType::Int8 => process_min_max!(arrays, Int8Array, Int8, i8),
DataType::Int16 => process_min_max!(arrays, Int16Array, Int16, i16),
DataType::Int32 => process_min_max!(arrays, Int32Array, Int32, i32),
DataType::Int64 => process_min_max!(arrays, Int64Array, Int64, i64),
DataType::UInt8 => process_min_max!(arrays, UInt8Array, UInt8, u8),
DataType::UInt16 => process_min_max!(arrays, UInt16Array, UInt16, u16),
DataType::UInt32 => process_min_max!(arrays, UInt32Array, UInt32, u32),
DataType::UInt64 => process_min_max!(arrays, UInt64Array, UInt64, u64),
DataType::Float32 => process_min_max!(arrays, Float32Array, Float32, f32),
DataType::Float64 => process_min_max!(arrays, Float64Array, Float64, f64),
_ => Err(DataFusionError::NotImplemented(format!(
"Min/Max not implemented for type {}",
data_type
))),
}
}

/// used in partition mode
pub struct PartitionedDynamicFilterInfo {
partition: usize,
Expand Down

0 comments on commit 959c158

Please sign in to comment.