Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve ordering equivalencies on with_reorder #13770

Merged
211 changes: 210 additions & 1 deletion datafusion/physical-expr/src/equivalence/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,59 @@ impl EquivalenceGroup {
JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(),
}
}

/// Checks if two expressions are equal either directly or through equivalence classes.
/// For complex expressions (e.g. a + b), checks that the expression trees are structurally
/// identical and their leaf nodes are equivalent either directly or through equivalence classes.
pub fn exprs_equal(
gokselk marked this conversation as resolved.
Show resolved Hide resolved
&self,
left: &Arc<dyn PhysicalExpr>,
right: &Arc<dyn PhysicalExpr>,
) -> bool {
// Direct equality check
if left.eq(right) {
return true;
}

// Check if expressions are equivalent through equivalence classes
// We need to check both directions since expressions might be in different classes
if let Some(left_class) = self.get_equivalence_class(left) {
if left_class.contains(right) {
return true;
}
}
if let Some(right_class) = self.get_equivalence_class(right) {
if right_class.contains(left) {
return true;
}
}

// For non-leaf nodes, check structural equality
let left_children = left.children();
let right_children = right.children();

// If either expression is a leaf node and we haven't found equality yet,
// they must be different
if left_children.is_empty() || right_children.is_empty() {
return false;
}

// Type equality check through reflection
if left.as_any().type_id() != right.as_any().type_id() {
return false;
}

// Check if the number of children is the same
if left_children.len() != right_children.len() {
return false;
}

// Check if all children are equal
left_children
.into_iter()
.zip(right_children)
.all(|(left_child, right_child)| self.exprs_equal(left_child, right_child))
}
}

impl Display for EquivalenceGroup {
Expand All @@ -647,9 +700,10 @@ mod tests {

use super::*;
use crate::equivalence::tests::create_test_params;
use crate::expressions::{lit, Literal};
use crate::expressions::{lit, BinaryExpr, Literal};

use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Operator;

#[test]
fn test_bridge_groups() -> Result<()> {
Expand Down Expand Up @@ -777,4 +831,159 @@ mod tests {
assert!(!cls1.contains_any(&cls3));
assert!(!cls2.contains_any(&cls3));
}

#[test]
fn test_exprs_equal() -> Result<()> {
struct TestCase {
left: Arc<dyn PhysicalExpr>,
right: Arc<dyn PhysicalExpr>,
expected: bool,
description: &'static str,
}

// Create test columns
let col_a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
let col_b = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;
let col_x = Arc::new(Column::new("x", 2)) as Arc<dyn PhysicalExpr>;
let col_y = Arc::new(Column::new("y", 3)) as Arc<dyn PhysicalExpr>;

// Create test literals
let lit_1 =
Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
let lit_2 =
Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc<dyn PhysicalExpr>;

// Create equivalence group with classes (a = x) and (b = y)
let eq_group = EquivalenceGroup::new(vec![
EquivalenceClass::new(vec![Arc::clone(&col_a), Arc::clone(&col_x)]),
EquivalenceClass::new(vec![Arc::clone(&col_b), Arc::clone(&col_y)]),
]);

let test_cases = vec![
// Basic equality tests
TestCase {
left: Arc::clone(&col_a),
right: Arc::clone(&col_a),
expected: true,
description: "Same column should be equal",
},
// Equivalence class tests
TestCase {
left: Arc::clone(&col_a),
right: Arc::clone(&col_x),
expected: true,
description: "Columns in same equivalence class should be equal",
},
TestCase {
left: Arc::clone(&col_b),
right: Arc::clone(&col_y),
expected: true,
description: "Columns in same equivalence class should be equal",
},
TestCase {
left: Arc::clone(&col_a),
right: Arc::clone(&col_b),
expected: false,
description:
"Columns in different equivalence classes should not be equal",
},
// Literal tests
TestCase {
left: Arc::clone(&lit_1),
right: Arc::clone(&lit_1),
expected: true,
description: "Same literal should be equal",
},
TestCase {
left: Arc::clone(&lit_1),
right: Arc::clone(&lit_2),
expected: false,
description: "Different literals should not be equal",
},
// Complex expression tests
TestCase {
left: Arc::new(BinaryExpr::new(
Arc::clone(&col_a),
Operator::Plus,
Arc::clone(&col_b),
)) as Arc<dyn PhysicalExpr>,
right: Arc::new(BinaryExpr::new(
Arc::clone(&col_x),
Operator::Plus,
Arc::clone(&col_y),
)) as Arc<dyn PhysicalExpr>,
expected: true,
description:
"Binary expressions with equivalent operands should be equal",
},
TestCase {
left: Arc::new(BinaryExpr::new(
Arc::clone(&col_a),
Operator::Plus,
Arc::clone(&col_b),
)) as Arc<dyn PhysicalExpr>,
right: Arc::new(BinaryExpr::new(
Arc::clone(&col_x),
Operator::Plus,
Arc::clone(&col_a),
)) as Arc<dyn PhysicalExpr>,
expected: false,
description:
"Binary expressions with non-equivalent operands should not be equal",
},
TestCase {
left: Arc::new(BinaryExpr::new(
Arc::clone(&col_a),
Operator::Plus,
Arc::clone(&lit_1),
)) as Arc<dyn PhysicalExpr>,
right: Arc::new(BinaryExpr::new(
Arc::clone(&col_x),
Operator::Plus,
Arc::clone(&lit_1),
)) as Arc<dyn PhysicalExpr>,
expected: true,
description: "Binary expressions with equivalent column and same literal should be equal",
},
TestCase {
left: Arc::new(BinaryExpr::new(
Arc::new(BinaryExpr::new(
Arc::clone(&col_a),
Operator::Plus,
Arc::clone(&col_b),
)),
Operator::Multiply,
Arc::clone(&lit_1),
)) as Arc<dyn PhysicalExpr>,
right: Arc::new(BinaryExpr::new(
Arc::new(BinaryExpr::new(
Arc::clone(&col_x),
Operator::Plus,
Arc::clone(&col_y),
)),
Operator::Multiply,
Arc::clone(&lit_1),
)) as Arc<dyn PhysicalExpr>,
expected: true,
description: "Nested binary expressions with equivalent operands should be equal",
},
];

for TestCase {
left,
right,
expected,
description,
} in test_cases
{
let actual = eq_group.exprs_equal(&left, &right);
assert_eq!(
actual, expected,
"{}: Failed comparing {:?} and {:?}, expected {}, got {}",
description, left, right, expected, actual
);
}

Ok(())
}
}
Loading
Loading