Skip to content

Commit

Permalink
add filter
Browse files Browse the repository at this point in the history
  • Loading branch information
Lordworms committed Nov 7, 2024
1 parent 092d2bf commit 85b26d8
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 23 deletions.
14 changes: 6 additions & 8 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1143,14 +1143,12 @@ impl DefaultPhysicalPlanner {
.iter()
.map(|dynamic_column| {
let column = dynamic_column.column();
let index = join.schema().index_of(column.name())?;
let physical_column = Arc::new(
datafusion_physical_expr::expressions::Column::new(
&column.name,
index,
),
);
let build_side_name = dynamic_column.build_name().to_owned();
let index =
join.schema().index_of(column.name())?;
let physical_column =
Arc::new(Column::new(&column.name, index));
let build_side_name =
dynamic_column.build_name().to_owned();
Ok((physical_column, build_side_name))
})
.collect::<Result<_, DataFusionError>>()?;
Expand Down
46 changes: 41 additions & 5 deletions datafusion/physical-plan/src/joins/dynamic_filters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::AsArray;
use arrow::array::{AsArray, BooleanBuilder};
use arrow::array::{
Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
UInt16Array, UInt32Array, UInt64Array, UInt8Array,
Expand All @@ -24,10 +23,10 @@ use arrow::compute::filter_record_batch;
use arrow::compute::kernels::aggregate::{max, max_string, min, min_string};
use arrow::datatypes::DataType;
use arrow::record_batch::RecordBatch;
use arrow_array::Array;
use arrow_array::ArrowNativeTypeOp;
use arrow_array::StringArray;
use datafusion_common::{exec_err, DataFusionError, ScalarValue};
use arrow_array::{Array, ArrayRef};
use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, ScalarValue};
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
use datafusion_physical_expr::PhysicalExpr;
Expand All @@ -36,6 +35,8 @@ use parking_lot::Mutex;
use std::fmt;
use std::sync::Arc;

use super::utils::JoinHashMap;

pub struct DynamicFilterInfo {
columns: Vec<Arc<Column>>,
build_side_names: Vec<String>,
Expand Down Expand Up @@ -184,7 +185,7 @@ impl DynamicFilterInfo {
records: &RecordBatch,
) -> Result<RecordBatch, DataFusionError> {
let filter_expr = match self.inner.lock().final_expr.as_ref() {
Some(expr) => Arc::<dyn datafusion_physical_expr::PhysicalExpr>::clone(expr),
Some(expr) => Arc::<dyn PhysicalExpr>::clone(expr),
None => {
return exec_err!(
"Filter expression should have been created before calling filter_batch"
Expand Down Expand Up @@ -354,4 +355,39 @@ impl PartitionedDynamicFilterInfo {
self.dynamic_filter_info
.merge_batch_and_check_finalized(records, self.partition)
}

pub fn filter_probe_batch(
&self,
batch: &RecordBatch,
hashes: &[u64],
hash_map: &JoinHashMap,
) -> Result<(RecordBatch, Vec<u64>), DataFusionError> {
let left_hash_set = hash_map.extract_unique_keys();

let mut mask_builder = BooleanBuilder::new();
for hash in hashes.iter() {
mask_builder.append_value(left_hash_set.contains(hash));
}
let mask = mask_builder.finish();

let filtered_columns = batch
.columns()
.iter()
.map(|col| {
arrow::compute::filter(col, &mask).map_err(|e| arrow_datafusion_err!(e))
})
.collect::<Result<Vec<ArrayRef>, DataFusionError>>()?;

let filtered_batch = RecordBatch::try_new(batch.schema(), filtered_columns)?;

let filtered_hashes = hashes
.iter()
.zip(mask.iter())
.filter_map(|(hash, keep)| {
keep.and_then(|k| if k { Some(*hash) } else { None })
})
.collect();

Ok((filtered_batch, filtered_hashes))
}
}
25 changes: 17 additions & 8 deletions datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,6 @@ use crate::{
Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream,
Statistics,
};
use std::fmt;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::Poll;
use std::{any::Any, vec};

use arrow::array::{
Array, ArrayRef, BooleanArray, BooleanBufferBuilder, UInt32Array, UInt64Array,
Expand Down Expand Up @@ -1395,7 +1390,9 @@ impl HashJoinStream {
self.state = HashJoinStreamState::ExhaustedProbeSide;
}
Some(Ok(batch)) => {
// Precalculate hash values for fetched batch
let left_data = Arc::<JoinLeftData>::clone(
&self.build_side.try_as_ready()?.left_data,
);
let keys_values = self
.on_right
.iter()
Expand All @@ -1406,12 +1403,24 @@ impl HashJoinStream {
self.hashes_buffer.resize(batch.num_rows(), 0);
create_hashes(&keys_values, &self.random_state, &mut self.hashes_buffer)?;

let (filtered_batch, filtered_hashes) =
if let Some(dynamic_filter) = &self.dynamic_filter_info {
dynamic_filter.filter_probe_batch(
&batch,
&self.hashes_buffer,
&left_data.hash_map,
)?
} else {
(batch, std::mem::take(&mut self.hashes_buffer))
};

self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());
self.join_metrics.input_rows.add(filtered_batch.num_rows());

self.hashes_buffer = filtered_hashes;
self.state =
HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState {
batch,
batch: filtered_batch,
offset: (0, None),
joined_probe_idx: None,
});
Expand Down
56 changes: 54 additions & 2 deletions datafusion/physical-plan/src/joins/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,18 @@ impl JoinHashMap {
next: vec![0; capacity],
}
}

/// Extracts all unique join-key hashes stored in this hash map.
///
/// Returns one `u64` per occupied bucket, i.e. every distinct key hash
/// currently present on the build side.
pub fn extract_unique_keys(&self) -> HashSet<u64> {
    // One entry per bucket, so the map length is an exact upper bound on
    // the result size — preallocate to avoid rehash-and-grow.
    let mut unique_keys = HashSet::with_capacity(self.map.len());
    // SAFETY: the buckets yielded by `iter()` stay valid for the lifetime
    // of `&self`, and no entry is inserted, mutated, or removed while we
    // iterate, so dereferencing each bucket via `as_ref` is sound.
    unsafe {
        for entry in self.map.iter() {
            let (hash, _) = entry.as_ref();
            // `u64` is `Copy`; a plain dereference is enough.
            unique_keys.insert(*hash);
        }
    }
    unique_keys
}
}

// Type of offsets for obtaining indices from JoinHashMap.
Expand Down Expand Up @@ -371,8 +383,48 @@ impl JoinHashMapType for JoinHashMap {
}

impl Debug for JoinHashMap {
fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result {
Ok(())
/// Renders the hash map for debugging: first a table of (hash, chain-head)
/// buckets sorted by hash, each followed by the full chain of build-side row
/// indices it reaches; then the raw `next` chain array with its positions.
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!(f, "JoinHashMap {{")?;
writeln!(f, " map:")?;
writeln!(f, " ----------")?;

// SAFETY: buckets from `iter()` are valid for the lifetime of `&self`
// and the map is not mutated while formatting, so `as_ref` is sound
// (here and in the sort key / writeln below).
let mut entries: Vec<_> = unsafe { self.map.iter().collect() };
// Sort by hash so the debug output is deterministic across runs.
entries.sort_by_key(|bucket| unsafe { bucket.as_ref().0 });

for bucket in entries {
// Walk the chain for this bucket. Stored indices are 1-based
// (`0` is the end-of-chain sentinel), so subtract 1 to recover
// the actual row index.
let mut indices = Vec::new();
let mut curr_idx = unsafe { bucket.as_ref().1 };

while curr_idx > 0 {
indices.push(curr_idx - 1);
curr_idx = self.next[(curr_idx - 1) as usize];
}

// Chains are linked newest-first; reverse to print insertion order.
indices.reverse();

writeln!(
f,
" | {:3} | {} | -> {:?}",
unsafe { bucket.as_ref().0 },
unsafe { bucket.as_ref().1 },
indices
)?;
}

writeln!(f, " ----------")?;
writeln!(f, "\n next:")?;
writeln!(f, " ---------------------")?;
// First row: the raw next-pointer values.
write!(f, " |")?;
for &next_idx in self.next.iter() {
write!(f, " {:2} |", next_idx)?;
}
writeln!(f)?;
// Second row: the corresponding positions, for easy visual alignment.
write!(f, " |")?;
for i in 0..self.next.len() {
write!(f, " {:2} |", i)?;
}
writeln!(f, "\n ---------------------")?;
writeln!(f, "}}")
}
}

Expand Down

0 comments on commit 85b26d8

Please sign in to comment.