Skip to content

Commit

Permalink
Implement Eq, PartialEq, Hash for dyn PhysicalExpr (#13005)
Browse files Browse the repository at this point in the history
* Implement Eq, PartialEq, Hash for PhysicalExpr

* Manually implement PartialEq and Hash for BinaryExpr

* Port more

* Complete manual derivations

* fmt

* add and fix docs

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
peter-toth and alamb authored Nov 5, 2024
1 parent 7c6f891 commit cc43766
Show file tree
Hide file tree
Showing 22 changed files with 231 additions and 368 deletions.
3 changes: 1 addition & 2 deletions datafusion/core/tests/sql/path_partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ use bytes::Bytes;
use chrono::{TimeZone, Utc};
use datafusion_expr::{col, lit, Expr, Operator};
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
use datafusion_physical_expr::PhysicalExpr;
use futures::stream::{self, BoxStream};
use object_store::{
path::Path, GetOptions, GetResult, GetResultPayload, ListResult, ObjectMeta,
Expand Down Expand Up @@ -97,7 +96,7 @@ async fn parquet_partition_pruning_filter() -> Result<()> {
assert!(pred.as_any().is::<BinaryExpr>());
let pred = pred.as_any().downcast_ref::<BinaryExpr>().unwrap();

assert_eq!(pred, expected.as_any());
assert_eq!(pred, expected.as_ref());

Ok(())
}
Expand Down
84 changes: 37 additions & 47 deletions datafusion/physical-expr-common/src/physical_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ use datafusion_expr_common::sort_properties::ExprProperties;
/// [`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html
/// [`create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html
/// [`Column`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/expressions/struct.Column.html
pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq<dyn Any> {
pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash {
/// Returns the physical expression as [`Any`] so that it can be
/// downcast to a specific implementation.
fn as_any(&self) -> &dyn Any;
Expand Down Expand Up @@ -141,38 +141,6 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq<dyn Any> {
Ok(Some(vec![]))
}

/// Update the hash `state` with this expression requirements from
/// [`Hash`].
///
/// This method is required to support hashing [`PhysicalExpr`]s. To
/// implement it, typically the type implementing
/// [`PhysicalExpr`] implements [`Hash`] and
/// then the following boiler plate is used:
///
/// # Example:
/// ```
/// // User defined expression that derives Hash
/// #[derive(Hash, Debug, PartialEq, Eq)]
/// struct MyExpr {
/// val: u64
/// }
///
/// // impl PhysicalExpr {
/// // ...
/// # impl MyExpr {
/// // Boiler plate to call the derived Hash impl
/// fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
/// use std::hash::Hash;
/// let mut s = state;
/// self.hash(&mut s);
/// }
/// // }
/// # }
/// ```
/// Note: [`PhysicalExpr`] is not constrained by [`Hash`]
/// directly because it must remain object safe.
fn dyn_hash(&self, _state: &mut dyn Hasher);

/// Calculates the properties of this [`PhysicalExpr`] based on its
/// children's properties (i.e. order and range), recursively aggregating
/// the information from its children. In cases where the [`PhysicalExpr`]
Expand All @@ -183,6 +151,42 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq<dyn Any> {
}
}

/// [`PhysicalExpr`] can't be constrained by [`Eq`] directly because it must remain object
/// safe. To ease implementation blanket implementation is provided for [`Eq`] types.
pub trait DynEq {
fn dyn_eq(&self, other: &dyn Any) -> bool;
}

impl<T: Eq + Any> DynEq for T {
fn dyn_eq(&self, other: &dyn Any) -> bool {
other
.downcast_ref::<Self>()
.map_or(false, |other| other == self)
}
}

impl PartialEq for dyn PhysicalExpr {
fn eq(&self, other: &Self) -> bool {
self.dyn_eq(other.as_any())
}
}

impl Eq for dyn PhysicalExpr {}

/// [`PhysicalExpr`] can't be constrained by [`Hash`] directly because it must remain
/// object safe. To ease implementation blanket implementation is provided for [`Hash`]
/// types.
pub trait DynHash {
fn dyn_hash(&self, _state: &mut dyn Hasher);
}

impl<T: Hash + Any> DynHash for T {
fn dyn_hash(&self, mut state: &mut dyn Hasher) {
self.type_id().hash(&mut state);
self.hash(&mut state)
}
}

impl Hash for dyn PhysicalExpr {
fn hash<H: Hasher>(&self, state: &mut H) {
self.dyn_hash(state);
Expand Down Expand Up @@ -210,20 +214,6 @@ pub fn with_new_children_if_necessary(
}
}

pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any {
if any.is::<Arc<dyn PhysicalExpr>>() {
any.downcast_ref::<Arc<dyn PhysicalExpr>>()
.unwrap()
.as_any()
} else if any.is::<Box<dyn PhysicalExpr>>() {
any.downcast_ref::<Box<dyn PhysicalExpr>>()
.unwrap()
.as_any()
} else {
any
}
}

/// Returns [`Display`] able a list of [`PhysicalExpr`]
///
/// Example output: `[a + 1, b]`
Expand Down
4 changes: 0 additions & 4 deletions datafusion/physical-expr-common/src/sort_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,10 @@ use itertools::Itertools;
/// # fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {todo!() }
/// # fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {todo!()}
/// # fn with_new_children(self: Arc<Self>, children: Vec<Arc<dyn PhysicalExpr>>) -> Result<Arc<dyn PhysicalExpr>> {todo!()}
/// # fn dyn_hash(&self, _state: &mut dyn Hasher) {todo!()}
/// # }
/// # impl Display for MyPhysicalExpr {
/// # fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "a") }
/// # }
/// # impl PartialEq<dyn Any> for MyPhysicalExpr {
/// # fn eq(&self, _other: &dyn Any) -> bool { true }
/// # }
/// # fn col(name: &str) -> Arc<dyn PhysicalExpr> { Arc::new(MyPhysicalExpr) }
/// // Sort by a ASC
/// let options = SortOptions::default();
Expand Down
7 changes: 3 additions & 4 deletions datafusion/physical-expr/src/equivalence/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ pub struct ConstExpr {

impl PartialEq for ConstExpr {
fn eq(&self, other: &Self) -> bool {
self.across_partitions == other.across_partitions
&& self.expr.eq(other.expr.as_any())
self.across_partitions == other.across_partitions && self.expr.eq(&other.expr)
}
}

Expand Down Expand Up @@ -120,7 +119,7 @@ impl ConstExpr {

/// Returns true if this constant expression is equal to the given expression
pub fn eq_expr(&self, other: impl AsRef<dyn PhysicalExpr>) -> bool {
self.expr.eq(other.as_ref().as_any())
self.expr.as_ref() == other.as_ref()
}

/// Returns a [`Display`]able list of `ConstExpr`.
Expand Down Expand Up @@ -556,7 +555,7 @@ impl EquivalenceGroup {
new_classes.push((source, vec![Arc::clone(target)]));
}
if let Some((_, values)) =
new_classes.iter_mut().find(|(key, _)| key.eq(source))
new_classes.iter_mut().find(|(key, _)| *key == source)
{
if !physical_exprs_contains(values, target) {
values.push(Arc::clone(target));
Expand Down
42 changes: 20 additions & 22 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@

mod kernels;

use std::hash::{Hash, Hasher};
use std::hash::Hash;
use std::{any::Any, sync::Arc};

use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
use crate::physical_expr::down_cast_any_ref;
use crate::PhysicalExpr;

use arrow::array::*;
Expand All @@ -48,7 +47,7 @@ use kernels::{
};

/// Binary expression
#[derive(Debug, Hash, Clone)]
#[derive(Debug, Clone, Eq)]
pub struct BinaryExpr {
left: Arc<dyn PhysicalExpr>,
op: Operator,
Expand All @@ -57,6 +56,24 @@ pub struct BinaryExpr {
fail_on_overflow: bool,
}

// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808
impl PartialEq for BinaryExpr {
fn eq(&self, other: &Self) -> bool {
self.left.eq(&other.left)
&& self.op.eq(&other.op)
&& self.right.eq(&other.right)
&& self.fail_on_overflow.eq(&other.fail_on_overflow)
}
}
impl Hash for BinaryExpr {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.left.hash(state);
self.op.hash(state);
self.right.hash(state);
self.fail_on_overflow.hash(state);
}
}

impl BinaryExpr {
/// Create new binary expression
pub fn new(
Expand Down Expand Up @@ -477,11 +494,6 @@ impl PhysicalExpr for BinaryExpr {
}
}

fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.hash(&mut s);
}

/// For each operator, [`BinaryExpr`] has distinct rules.
/// TODO: There may be rules specific to some data types and expression ranges.
fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
Expand Down Expand Up @@ -525,20 +537,6 @@ impl PhysicalExpr for BinaryExpr {
}
}

impl PartialEq<dyn Any> for BinaryExpr {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
self.left.eq(&x.left)
&& self.op == x.op
&& self.right.eq(&x.right)
&& self.fail_on_overflow.eq(&x.fail_on_overflow)
})
.unwrap_or(false)
}
}

/// Casts dictionary array to result type for binary numerical operators. Such operators
/// between array and scalar produce a dictionary array other than primitive array of the
/// same operators between array and array. This leads to inconsistent result types causing
Expand Down
40 changes: 3 additions & 37 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@
// under the License.

use std::borrow::Cow;
use std::hash::{Hash, Hasher};
use std::hash::Hash;
use std::{any::Any, sync::Arc};

use crate::expressions::try_cast;
use crate::physical_expr::down_cast_any_ref;
use crate::PhysicalExpr;

use arrow::array::*;
Expand All @@ -37,7 +36,7 @@ use itertools::Itertools;

type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);

#[derive(Debug, Hash)]
#[derive(Debug, Hash, PartialEq, Eq)]
enum EvalMethod {
/// CASE WHEN condition THEN result
/// [WHEN ...]
Expand Down Expand Up @@ -80,7 +79,7 @@ enum EvalMethod {
/// [WHEN ...]
/// [ELSE result]
/// END
#[derive(Debug, Hash)]
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct CaseExpr {
/// Optional base expression that can be compared to literal values in the "when" expressions
expr: Option<Arc<dyn PhysicalExpr>>,
Expand Down Expand Up @@ -506,39 +505,6 @@ impl PhysicalExpr for CaseExpr {
)?))
}
}

fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.hash(&mut s);
}
}

impl PartialEq<dyn Any> for CaseExpr {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
let expr_eq = match (&self.expr, &x.expr) {
(Some(expr1), Some(expr2)) => expr1.eq(expr2),
(None, None) => true,
_ => false,
};
let else_expr_eq = match (&self.else_expr, &x.else_expr) {
(Some(expr1), Some(expr2)) => expr1.eq(expr2),
(None, None) => true,
_ => false,
};
expr_eq
&& else_expr_eq
&& self.when_then_expr.len() == x.when_then_expr.len()
&& self.when_then_expr.iter().zip(x.when_then_expr.iter()).all(
|((when1, then1), (when2, then2))| {
when1.eq(when2) && then1.eq(then2)
},
)
})
.unwrap_or(false)
}
}

/// Create a CASE expression
Expand Down
Loading

0 comments on commit cc43766

Please sign in to comment.