diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 26b30029f63d9..079e7d42e93e2 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -56,6 +56,7 @@ itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } paste = "^1.0" petgraph = "0.6.2" +regex = { workspace = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } diff --git a/datafusion/physical-expr/src/expressions/scalar_regex_match.rs b/datafusion/physical-expr/src/expressions/scalar_regex_match.rs index badb00659576c..cc446f3328d54 100644 --- a/datafusion/physical-expr/src/expressions/scalar_regex_match.rs +++ b/datafusion/physical-expr/src/expressions/scalar_regex_match.rs @@ -22,11 +22,16 @@ use arrow_array::{ }; use arrow_buffer::BooleanBufferBuilder; use arrow_schema::{DataType, Schema}; -use datafusion_common::ScalarValue; +use datafusion_common::{Result as DFResult, ScalarValue}; use datafusion_expr::ColumnarValue; -use datafusion_physical_expr_common::physical_expr::{down_cast_any_ref, PhysicalExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use regex::Regex; -use std::{any::Any, hash::Hash, sync::Arc}; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter, Result as FmtResult}, + hash::Hash, + sync::Arc, +}; /// ScalarRegexMatchExpr /// Only used when evaluating regexp matching with literal pattern. @@ -133,9 +138,7 @@ impl ScalarRegexMatchExpr { (true, true) => "NOT IMATCH", } } -} -impl ScalarRegexMatchExpr { /// Evaluate the scalar regex match expression match array value fn evaluate_array( &self, @@ -200,16 +203,9 @@ impl ScalarRegexMatchExpr { } } -impl std::hash::Hash for ScalarRegexMatchExpr { - fn hash(&self, state: &mut H) { - self.negated.hash(state); - self.case_insensitive.hash(state); - self.expr.hash(state); - self.pattern.hash(state); - } -} +impl Eq for ScalarRegexMatchExpr {} -impl std::cmp::PartialEq for ScalarRegexMatchExpr { +impl PartialEq for ScalarRegexMatchExpr { fn eq(&self, other: &Self) -> bool { self.negated.eq(&other.negated) && self.case_insensitive.eq(&self.case_insensitive) @@ -218,8 +214,17 @@ impl std::cmp::PartialEq for ScalarRegexMatchExpr { } } -impl std::fmt::Debug for ScalarRegexMatchExpr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl Hash for ScalarRegexMatchExpr { + fn hash(&self, state: &mut H) { + self.negated.hash(state); + self.case_insensitive.hash(state); + self.expr.hash(state); + self.pattern.hash(state); + } +} + +impl Debug for ScalarRegexMatchExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { f.debug_struct("ScalarRegexMatchExpr") .field("negated", &self.negated) .field("case_insensitive", &self.case_insensitive) @@ -229,35 +234,26 @@ impl std::fmt::Debug for ScalarRegexMatchExpr { } } -impl std::fmt::Display for ScalarRegexMatchExpr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +impl Display for ScalarRegexMatchExpr { + fn fmt(&self, f: &mut Formatter) -> FmtResult { write!(f, "{} {} {}", self.expr, self.op_name(), self.pattern) } } impl PhysicalExpr for ScalarRegexMatchExpr { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } - fn data_type( - &self, - _: &arrow_schema::Schema, - ) -> datafusion_common::Result { + fn data_type(&self, _: &Schema) -> DFResult { Ok(DataType::Boolean) } - fn nullable( - &self, - input_schema: &arrow_schema::Schema, - ) -> datafusion_common::Result { + fn nullable(&self, input_schema: &Schema) -> DFResult { Ok(self.expr.nullable(input_schema)? || self.pattern.nullable(input_schema)?) } - fn evaluate( - &self, - batch: &arrow_array::RecordBatch, - ) -> datafusion_common::Result { + fn evaluate(&self, batch: &arrow_array::RecordBatch) -> DFResult { self.expr .evaluate(batch) .and_then(|lhs| { @@ -274,14 +270,14 @@ impl PhysicalExpr for ScalarRegexMatchExpr { .map(ColumnarValue::Array) } - fn children(&self) -> Vec<&std::sync::Arc> { + fn children(&self) -> Vec<&Arc> { vec![&self.expr, &self.pattern] } fn with_new_children( - self: std::sync::Arc, - children: Vec>, - ) -> datafusion_common::Result> { + self: Arc, + children: Vec>, + ) -> DFResult> { Ok(Arc::new(ScalarRegexMatchExpr::new( self.negated, self.case_insensitive, @@ -290,18 +286,24 @@ impl PhysicalExpr for ScalarRegexMatchExpr { ))) } - fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) { - let mut s = state; - self.hash(&mut s); - } -} - -impl PartialEq for ScalarRegexMatchExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self == x) - .unwrap_or(false) + fn evaluate_selection( + &self, + batch: &arrow_array::RecordBatch, + selection: &BooleanArray, + ) -> DFResult { + let tmp_batch = arrow::compute::filter_record_batch(batch, selection)?; + + let tmp_result = self.evaluate(&tmp_batch)?; + + if batch.num_rows() == tmp_batch.num_rows() { + // All values from the `selection` filter are true. + Ok(tmp_result) + } else if let ColumnarValue::Array(a) = tmp_result { + datafusion_physical_expr_common::utils::scatter(selection, a.as_ref()) + .map(ColumnarValue::Array) + } else { + Ok(tmp_result) + } } } @@ -310,7 +312,7 @@ fn array_regexp_match( array: &dyn ArrayAccessor, regex: &Regex, negated: bool, -) -> datafusion_common::Result { +) -> DFResult { let null_bit_buffer = array.nulls().map(|x| x.inner().sliced()); let mut buffer_builder = BooleanBufferBuilder::new(array.len()); @@ -359,7 +361,7 @@ pub fn scalar_regex_match( expr: Arc, pattern: Arc, input_schema: &Schema, -) -> datafusion_common::Result> { +) -> DFResult> { let valid_data_type = |data_type: &DataType| { if !matches!( data_type, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d9c13d837491b..1fefce16c789c 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -841,7 +841,7 @@ message PhysicalExprNode { PhysicalLikeExprNode like_expr = 18; PhysicalExtensionExprNode extension = 19; - + PhysicalScalarRegexMatchExprNode scalar_regex_match_expr = 20; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index fe871b3873672..894ce5c5a525c 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1404,7 +1404,6 @@ pub struct PhysicalScalarRegexMatchExprNode { #[prost(message, optional, boxed, tag = "4")] pub pattern: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FilterExecNode { #[prost(message, optional, boxed, tag = "1")]