Skip to content

Commit

Permalink
Fix: Sort Merge Join Left Semi crashes. Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
comphead committed May 13, 2024
1 parent 8e9b490 commit c31206e
Showing 1 changed file with 37 additions and 27 deletions.
64 changes: 37 additions & 27 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,12 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use crate::expressions::PhysicalSortExpr;
use crate::joins::utils::{
build_join_schema, check_join_is_valid, estimate_join_statistics,
partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
};
use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use crate::{
execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution,
ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties,
RecordBatchStream, SendableRecordBatchStream, Statistics,
};

use arrow::array::*;
use arrow::compute::{self, concat_batches, take, SortOptions};
use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use arrow::error::ArrowError;
use futures::{Stream, StreamExt};

use datafusion_common::{
internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result,
};
Expand All @@ -54,7 +44,17 @@ use datafusion_execution::TaskContext;
use datafusion_physical_expr::equivalence::join_equivalence_properties;
use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement};

use futures::{Stream, StreamExt};
use crate::expressions::PhysicalSortExpr;
use crate::joins::utils::{
build_join_schema, check_join_is_valid, estimate_join_statistics,
partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
};
use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use crate::{
execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution,
ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties,
RecordBatchStream, SendableRecordBatchStream, Statistics,
};

/// join execution plan executes partitions in parallel and combines them into a set of
/// partitions.
Expand Down Expand Up @@ -1396,36 +1396,45 @@ fn get_buffered_columns(
}

// Calculate join filter bit mask considering join type specifics
// `streamed_indices` - array of streamed datasource JOINED row indices
// `mask` - array of booleans representing the computed join filter expression eval result:
// true = the row index matches the join filter
// false = the row index doesn't match the join filter
// `streamed_indices` has the same length as `mask`
fn get_filtered_join_mask(
join_type: JoinType,
streamed_indices: UInt64Array,
mask: &BooleanArray,
) -> Option<BooleanArray> {
// for LeftSemi Join the filter mask should be calculated in its own way:
// if we find at least one matching row for specific streaming index
// we dont need to check any others for the same index
// we don't need to check any others for the same index
if matches!(join_type, JoinType::LeftSemi) {
// have we seen a filter match for a streaming index before
let mut seen_as_true: bool = false;
let streamed_indices_length = streamed_indices.len();
let mut corrected_mask: Vec<bool> = vec![false; streamed_indices_length];
let mut corrected_mask: BooleanBuilder =
BooleanBuilder::with_capacity(streamed_indices_length);

#[allow(clippy::needless_range_loop)]
for i in 0..streamed_indices_length {
// if this is the first match for a streaming index, set it to true
// LeftSemi respects only the first true value for a specific streaming index;
// any other true values for the same index must be false
if mask.value(i) && !seen_as_true {
seen_as_true = true;
corrected_mask[i] = true;
corrected_mask.append_value(true);
} else {
corrected_mask.append_value(false);
}

// if switched to next streaming index(e.g from 0 to 1, or from 1 to 2), we reset seen_as_true flag
// if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag
if i < streamed_indices_length - 1
&& streamed_indices.value(i) != streamed_indices.value(i + 1)
{
seen_as_true = false;
}
}
Some(BooleanArray::from(corrected_mask))
Some(corrected_mask.finish())
} else {
None
}
Expand Down Expand Up @@ -1670,19 +1679,12 @@ fn is_join_arrays_equal(
mod tests {
use std::sync::Arc;

use crate::expressions::Column;
use crate::joins::utils::JoinOn;
use crate::joins::SortMergeJoinExec;
use crate::memory::MemoryExec;
use crate::test::build_table_i32;
use crate::{common, ExecutionPlan};

use crate::joins::sort_merge_join::get_filtered_join_mask;
use arrow::array::{Date32Array, Date64Array, Int32Array};
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow_array::{BooleanArray, UInt64Array};

use datafusion_common::JoinType::LeftSemi;
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result,
Expand All @@ -1691,6 +1693,14 @@ mod tests {
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion_execution::TaskContext;

use crate::expressions::Column;
use crate::joins::sort_merge_join::get_filtered_join_mask;
use crate::joins::utils::JoinOn;
use crate::joins::SortMergeJoinExec;
use crate::memory::MemoryExec;
use crate::test::build_table_i32;
use crate::{common, ExecutionPlan};

fn build_table(
a: (&str, &Vec<i32>),
b: (&str, &Vec<i32>),
Expand Down

0 comments on commit c31206e

Please sign in to comment.