From 85b26d8c631082cc8ed71d285111805e20c409ab Mon Sep 17 00:00:00 2001 From: Lordworms Date: Wed, 6 Nov 2024 17:53:34 -0800 Subject: [PATCH] add filter --- datafusion/core/src/physical_planner.rs | 14 ++--- .../src/joins/dynamic_filters.rs | 46 +++++++++++++-- .../physical-plan/src/joins/hash_join.rs | 25 ++++++--- datafusion/physical-plan/src/joins/utils.rs | 56 ++++++++++++++++++- 4 files changed, 118 insertions(+), 23 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index ec14d48a093f9..79531b9c5fc5f 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1143,14 +1143,12 @@ impl DefaultPhysicalPlanner { .iter() .map(|dynamic_column| { let column = dynamic_column.column(); - let index = join.schema().index_of(column.name())?; - let physical_column = Arc::new( - datafusion_physical_expr::expressions::Column::new( - &column.name, - index, - ), - ); - let build_side_name = dynamic_column.build_name().to_owned(); + let index = + join.schema().index_of(column.name())?; + let physical_column = + Arc::new(Column::new(&column.name, index)); + let build_side_name = + dynamic_column.build_name().to_owned(); Ok((physical_column, build_side_name)) }) .collect::>()?; diff --git a/datafusion/physical-plan/src/joins/dynamic_filters.rs b/datafusion/physical-plan/src/joins/dynamic_filters.rs index 3e45c14308c77..8d360fe3bfd88 100644 --- a/datafusion/physical-plan/src/joins/dynamic_filters.rs +++ b/datafusion/physical-plan/src/joins/dynamic_filters.rs @@ -14,8 +14,7 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. - -use arrow::array::AsArray; +use arrow::array::{AsArray, BooleanBuilder}; use arrow::array::{ Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, @@ -24,10 +23,10 @@ use arrow::compute::filter_record_batch; use arrow::compute::kernels::aggregate::{max, max_string, min, min_string}; use arrow::datatypes::DataType; use arrow::record_batch::RecordBatch; -use arrow_array::Array; use arrow_array::ArrowNativeTypeOp; use arrow_array::StringArray; -use datafusion_common::{exec_err, DataFusionError, ScalarValue}; +use arrow_array::{Array, ArrayRef}; +use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, ScalarValue}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; use datafusion_physical_expr::PhysicalExpr; @@ -36,6 +35,8 @@ use parking_lot::Mutex; use std::fmt; use std::sync::Arc; +use super::utils::JoinHashMap; + pub struct DynamicFilterInfo { columns: Vec>, build_side_names: Vec, @@ -184,7 +185,7 @@ impl DynamicFilterInfo { records: &RecordBatch, ) -> Result { let filter_expr = match self.inner.lock().final_expr.as_ref() { - Some(expr) => Arc::::clone(expr), + Some(expr) => Arc::::clone(expr), None => { return exec_err!( "Filter expression should have been created before calling filter_batch" @@ -354,4 +355,39 @@ impl PartitionedDynamicFilterInfo { self.dynamic_filter_info .merge_batch_and_check_finalized(records, self.partition) } + + pub fn filter_probe_batch( + &self, + batch: &RecordBatch, + hashes: &[u64], + hash_map: &JoinHashMap, + ) -> Result<(RecordBatch, Vec), DataFusionError> { + let left_hash_set = hash_map.extract_unique_keys(); + + let mut mask_builder = BooleanBuilder::new(); + for hash in hashes.iter() { + mask_builder.append_value(left_hash_set.contains(hash)); + } + let mask = mask_builder.finish(); + + let filtered_columns = batch + .columns() + .iter() + .map(|col| { + arrow::compute::filter(col, &mask).map_err(|e| arrow_datafusion_err!(e)) + }) + .collect::, DataFusionError>>()?; + + let filtered_batch = RecordBatch::try_new(batch.schema(), filtered_columns)?; + + let filtered_hashes = hashes + .iter() + .zip(mask.iter()) + .filter_map(|(hash, keep)| { + keep.and_then(|k| if k { Some(*hash) } else { None }) + }) + .collect(); + + Ok((filtered_batch, filtered_hashes)) + } } diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 58f22be9b6b07..bf82cf57006cb 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -49,11 +49,6 @@ use crate::{ Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use std::fmt; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; -use std::task::Poll; -use std::{any::Any, vec}; use arrow::array::{ Array, ArrayRef, BooleanArray, BooleanBufferBuilder, UInt32Array, UInt64Array, @@ -1395,7 +1390,9 @@ impl HashJoinStream { self.state = HashJoinStreamState::ExhaustedProbeSide; } Some(Ok(batch)) => { - // Precalculate hash values for fetched batch + let left_data = Arc::::clone( + &self.build_side.try_as_ready()?.left_data, + ); let keys_values = self .on_right .iter() @@ -1406,12 +1403,24 @@ impl HashJoinStream { self.hashes_buffer.resize(batch.num_rows(), 0); create_hashes(&keys_values, &self.random_state, &mut self.hashes_buffer)?; + let (filtered_batch, filtered_hashes) = + if let Some(dynamic_filter) = &self.dynamic_filter_info { + dynamic_filter.filter_probe_batch( + &batch, + &self.hashes_buffer, + &left_data.hash_map, + )? + } else { + (batch, std::mem::take(&mut self.hashes_buffer)) + }; + self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); + self.join_metrics.input_rows.add(filtered_batch.num_rows()); + self.hashes_buffer = filtered_hashes; self.state = HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState { - batch, + batch: filtered_batch, offset: (0, None), joined_probe_idx: None, }); diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 0366c9fa5e462..0b777949b554c 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -140,6 +140,18 @@ impl JoinHashMap { next: vec![0; capacity], } } + + /// extract all unique keys of this join hash map + pub fn extract_unique_keys(&self) -> HashSet { + let mut unique_keys = HashSet::new(); + unsafe { + self.map.iter().for_each(|entry| { + let (hash, _) = entry.as_ref(); + unique_keys.insert(hash.to_owned()); + }) + }; + unique_keys + } } // Type of offsets for obtaining indices from JoinHashMap. @@ -371,8 +383,48 @@ impl JoinHashMapType for JoinHashMap { } impl Debug for JoinHashMap { - fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { - Ok(()) + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "JoinHashMap {{")?; + writeln!(f, " map:")?; + writeln!(f, " ----------")?; + + let mut entries: Vec<_> = unsafe { self.map.iter().collect() }; + entries.sort_by_key(|bucket| unsafe { bucket.as_ref().0 }); + + for bucket in entries { + let mut indices = Vec::new(); + let mut curr_idx = unsafe { bucket.as_ref().1 }; + + while curr_idx > 0 { + indices.push(curr_idx - 1); + curr_idx = self.next[(curr_idx - 1) as usize]; + } + + indices.reverse(); + + writeln!( + f, + " | {:3} | {} | -> {:?}", + unsafe { bucket.as_ref().0 }, + unsafe { bucket.as_ref().1 }, + indices + )?; + } + + writeln!(f, " ----------")?; + writeln!(f, "\n next:")?; + writeln!(f, " ---------------------")?; + write!(f, " |")?; + for &next_idx in self.next.iter() { + write!(f, " {:2} |", next_idx)?; + } + writeln!(f)?; + write!(f, " |")?; + for i in 0..self.next.len() { + write!(f, " {:2} |", i)?; + } + writeln!(f, "\n ---------------------")?; + writeln!(f, "}}") } }