From 89def2c6e5f32f15dd80dd19d11b3087e1647310 Mon Sep 17 00:00:00 2001 From: jcsherin Date: Thu, 20 Jun 2024 16:29:59 +0530 Subject: [PATCH 01/11] Convert `bool_and` & `bool_or` to UDAF (#11009) * Port `bool_and` and `bool_or` to `AggregateUDFImpl` * Remove trait methods with default implementation * Add `bool_or_udaf` * Register `bool_and` and `bool_or` * Remove from `physical-expr` * Add expressions to logical plan roundtrip test * minor: remove methods with default implementation * Removes redundant tests * Removes hard-coded function names --- datafusion/expr/src/aggregate_function.rs | 15 - .../expr/src/type_coercion/aggregates.rs | 16 - .../functions-aggregate/src/bool_and_or.rs | 343 +++++++++++++++ datafusion/functions-aggregate/src/lib.rs | 6 +- .../src/aggregate/bool_and_or.rs | 394 ------------------ .../physical-expr/src/aggregate/build_in.rs | 56 +-- .../src/aggregate/groups_accumulator/mod.rs | 3 - datafusion/physical-expr/src/aggregate/mod.rs | 1 - .../physical-expr/src/expressions/mod.rs | 1 - datafusion/proto/proto/datafusion.proto | 4 +- datafusion/proto/src/generated/pbjson.rs | 6 - datafusion/proto/src/generated/prost.rs | 8 +- .../proto/src/logical_plan/from_proto.rs | 2 - datafusion/proto/src/logical_plan/to_proto.rs | 4 - .../proto/src/physical_plan/to_proto.rs | 12 +- .../tests/cases/roundtrip_logical_plan.rs | 6 +- 16 files changed, 362 insertions(+), 515 deletions(-) create mode 100644 datafusion/functions-aggregate/src/bool_and_or.rs delete mode 100644 datafusion/physical-expr/src/aggregate/bool_and_or.rs diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 1cde1c5050a8..967ccc0b0866 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -47,10 +47,6 @@ pub enum AggregateFunction { Correlation, /// Grouping Grouping, - /// Bool And - BoolAnd, - /// Bool Or - BoolOr, } impl AggregateFunction { @@ -64,8 +60,6 @@ impl AggregateFunction { NthValue => "NTH_VALUE", Correlation => "CORR", Grouping => "GROUPING", - BoolAnd => "BOOL_AND", - BoolOr => "BOOL_OR", } } } @@ -82,8 +76,6 @@ impl FromStr for AggregateFunction { Ok(match name { // general "avg" => AggregateFunction::Avg, - "bool_and" => AggregateFunction::BoolAnd, - "bool_or" => AggregateFunction::BoolOr, "max" => AggregateFunction::Max, "mean" => AggregateFunction::Avg, "min" => AggregateFunction::Min, @@ -128,9 +120,6 @@ impl AggregateFunction { // The coerced_data_types is same with input_types. 
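                // (i.e. these aggregates return exactly the type of their coerced input)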
Ok(coerced_data_types[0].clone()) } - AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { - Ok(DataType::Boolean) - } AggregateFunction::Correlation => { correlation_return_type(&coerced_data_types[0]) } @@ -179,10 +168,6 @@ impl AggregateFunction { .collect::>(); Signature::uniform(1, valid, Volatility::Immutable) } - AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { - Signature::uniform(1, vec![DataType::Boolean], Volatility::Immutable) - } - AggregateFunction::Avg => { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index abe6d8b1823d..428fc99070d2 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -121,18 +121,6 @@ pub fn coerce_types( }; Ok(vec![v]) } - AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { - // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc - // smallint, int, bigint, real, double precision, decimal, or interval. - if !is_bool_and_or_support_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - Ok(input_types.to_vec()) - } AggregateFunction::Correlation => { if !is_correlation_support_arg_type(&input_types[0]) { return plan_err!( @@ -319,10 +307,6 @@ pub fn avg_sum_type(arg_type: &DataType) -> Result { } } -pub fn is_bool_and_or_support_arg_type(arg_type: &DataType) -> bool { - matches!(arg_type, DataType::Boolean) -} - pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool { match arg_type { DataType::Dictionary(_, dict_value_type) => { diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs new file mode 100644 index 000000000000..d0028672743e --- /dev/null +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -0,0 +1,343 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; + +use arrow::array::ArrayRef; +use arrow::array::BooleanArray; +use arrow::compute::bool_and as compute_bool_and; +use arrow::compute::bool_or as compute_bool_or; +use arrow::datatypes::DataType; +use arrow::datatypes::Field; + +use datafusion_common::internal_err; +use datafusion_common::{downcast_value, not_impl_err}; +use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, GroupsAccumulator, ReversedUDAF, Signature, Volatility, +}; + +use datafusion_physical_expr_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; + +// returns the new value after bool_and/bool_or with the new values, taking nullability into account +macro_rules! typed_bool_and_or_batch { + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ + let array = downcast_value!($VALUES, $ARRAYTYPE); + let delta = $OP(array); + Ok(ScalarValue::$SCALAR(delta)) + }}; +} + +// bool_and/bool_or the array and returns a ScalarValue of its corresponding type. +macro_rules! bool_and_or_batch { + ($VALUES:expr, $OP:ident) => {{ + match $VALUES.data_type() { + DataType::Boolean => { + typed_bool_and_or_batch!($VALUES, BooleanArray, Boolean, $OP) + } + e => { + return internal_err!( + "Bool and/Bool or is not expected to receive the type {e:?}" + ); + } + } + }}; +} + +/// dynamically-typed bool_and(array) -> ScalarValue +fn bool_and_batch(values: &ArrayRef) -> Result { + bool_and_or_batch!(values, compute_bool_and) +} + +/// dynamically-typed bool_or(array) -> ScalarValue +fn bool_or_batch(values: &ArrayRef) -> Result { + bool_and_or_batch!(values, compute_bool_or) +} + +make_udaf_expr_and_func!( + BoolAnd, + bool_and, + expression, + "The values to combine with `AND`", + bool_and_udaf +); + +make_udaf_expr_and_func!( + BoolOr, + bool_or, + expression, + "The values to combine with `OR`", + bool_or_udaf +); + +/// BOOL_AND aggregate expression +#[derive(Debug)] +pub struct BoolAnd { + signature: Signature, +} + +impl BoolAnd { + fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Boolean], + Volatility::Immutable, + ), + } + } +} + +impl Default for BoolAnd { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for BoolAnd { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "bool_and" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + + fn accumulator(&self, _: AccumulatorArgs) -> Result> { + Ok(Box::::default()) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + Ok(vec![Field::new( + format_state_name(args.name, self.name()), + DataType::Boolean, + true, + )]) + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + match args.data_type { + DataType::Boolean => { + Ok(Box::new(BooleanGroupsAccumulator::new(|x, y| x && y))) + } + _ => not_impl_err!( + "GroupsAccumulator not supported for {} with {}", + args.name, + args.data_type + ), + } + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn create_sliding_accumulator( + &self, + _: AccumulatorArgs, + ) -> Result> { + 
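+        // `bool_and` is declared order-insensitive below (`order_sensitivity`),
+        // and the plain accumulator is reused for sliding windows.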
Ok(Box::::default()) + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + AggregateOrderSensitivity::Insensitive + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } +} + +#[derive(Debug, Default)] +struct BoolAndAccumulator { + acc: Option, +} + +impl Accumulator for BoolAndAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + self.acc = match (self.acc, bool_and_batch(values)?) { + (None, ScalarValue::Boolean(v)) => v, + (Some(v), ScalarValue::Boolean(None)) => Some(v), + (Some(a), ScalarValue::Boolean(Some(b))) => Some(a && b), + _ => unreachable!(), + }; + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Boolean(self.acc)) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> Result> { + Ok(vec![ScalarValue::Boolean(self.acc)]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } +} + +/// BOOL_OR aggregate expression +#[derive(Debug, Clone)] +pub struct BoolOr { + signature: Signature, +} + +impl BoolOr { + fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Boolean], + Volatility::Immutable, + ), + } + } +} + +impl Default for BoolOr { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for BoolOr { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "bool_or" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + + fn accumulator(&self, _: AccumulatorArgs) -> Result> { + Ok(Box::::default()) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + Ok(vec![Field::new( + format_state_name(args.name, self.name()), + DataType::Boolean, + true, + )]) + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + match args.data_type { + DataType::Boolean => { + Ok(Box::new(BooleanGroupsAccumulator::new(|x, y| x || y))) + } + _ => not_impl_err!( + "GroupsAccumulator not supported for {} with {}", + args.name, + args.data_type + ), + } + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn create_sliding_accumulator( + &self, + _: AccumulatorArgs, + ) -> Result> { + Ok(Box::::default()) + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + AggregateOrderSensitivity::Insensitive + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } +} + +#[derive(Debug, Default)] +struct BoolOrAccumulator { + acc: Option, +} + +impl Accumulator for BoolOrAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + self.acc = match (self.acc, bool_or_batch(values)?) 
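+        // SQL semantics: an all-NULL batch result leaves the running state
+        // unchanged; otherwise the new value is OR-ed into the accumulator.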
{ + (None, ScalarValue::Boolean(v)) => v, + (Some(v), ScalarValue::Boolean(None)) => Some(v), + (Some(a), ScalarValue::Boolean(Some(b))) => Some(a || b), + _ => unreachable!(), + }; + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Boolean(self.acc)) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> Result> { + Ok(vec![ScalarValue::Boolean(self.acc)]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } +} diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 20a8d2c15926..260d6dab31b9 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -70,8 +70,8 @@ pub mod approx_median; pub mod approx_percentile_cont; pub mod approx_percentile_cont_with_weight; pub mod bit_and_or_xor; +pub mod bool_and_or; pub mod string_agg; - use crate::approx_percentile_cont::approx_percentile_cont_udaf; use crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf; use datafusion_common::Result; @@ -89,6 +89,8 @@ pub mod expr_fn { pub use super::bit_and_or_xor::bit_and; pub use super::bit_and_or_xor::bit_or; pub use super::bit_and_or_xor::bit_xor; + pub use super::bool_and_or::bool_and; + pub use super::bool_and_or::bool_or; pub use super::count::count; pub use super::count::count_distinct; pub use super::covariance::covar_pop; @@ -143,6 +145,8 @@ pub fn all_default_aggregate_functions() -> Vec> { bit_and_or_xor::bit_and_udaf(), bit_and_or_xor::bit_or_udaf(), bit_and_or_xor::bit_xor_udaf(), + bool_and_or::bool_and_udaf(), + bool_and_or::bool_or_udaf(), ] } diff --git a/datafusion/physical-expr/src/aggregate/bool_and_or.rs b/datafusion/physical-expr/src/aggregate/bool_and_or.rs deleted file mode 100644 index 341932bd77a4..000000000000 --- a/datafusion/physical-expr/src/aggregate/bool_and_or.rs +++ /dev/null @@ -1,394 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! 
Defines physical expressions that can evaluated at runtime during query execution - -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::datatypes::DataType; -use arrow::{ - array::{ArrayRef, BooleanArray}, - datatypes::Field, -}; -use datafusion_common::{ - downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue, -}; -use datafusion_expr::{Accumulator, GroupsAccumulator}; -use std::any::Any; -use std::sync::Arc; - -use crate::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use arrow::array::Array; -use arrow::compute::{bool_and, bool_or}; - -// returns the new value after bool_and/bool_or with the new values, taking nullability into account -macro_rules! typed_bool_and_or_batch { - ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ - let array = downcast_value!($VALUES, $ARRAYTYPE); - let delta = $OP(array); - Ok(ScalarValue::$SCALAR(delta)) - }}; -} - -// bool_and/bool_or the array and returns a ScalarValue of its corresponding type. -macro_rules! bool_and_or_batch { - ($VALUES:expr, $OP:ident) => {{ - match $VALUES.data_type() { - DataType::Boolean => { - typed_bool_and_or_batch!($VALUES, BooleanArray, Boolean, $OP) - } - e => { - return internal_err!( - "Bool and/Bool or is not expected to receive the type {e:?}" - ); - } - } - }}; -} - -/// dynamically-typed bool_and(array) -> ScalarValue -fn bool_and_batch(values: &ArrayRef) -> Result { - bool_and_or_batch!(values, bool_and) -} - -/// dynamically-typed bool_or(array) -> ScalarValue -fn bool_or_batch(values: &ArrayRef) -> Result { - bool_and_or_batch!(values, bool_or) -} - -/// BOOL_AND aggregate expression -#[derive(Debug, Clone)] -pub struct BoolAnd { - name: String, - pub data_type: DataType, - expr: Arc, - nullable: bool, -} - -impl BoolAnd { - /// Create a new BOOL_AND aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - nullable: true, - } - } -} - -impl AggregateExpr for BoolAnd { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::::default()) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "bool_and"), - self.data_type.clone(), - self.nullable, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - true - } - - fn create_groups_accumulator(&self) -> Result> { - match self.data_type { - DataType::Boolean => { - Ok(Box::new(BooleanGroupsAccumulator::new(|x, y| x && y))) - } - _ => not_impl_err!( - "GroupsAccumulator not supported for {} with {}", - self.name(), - self.data_type - ), - } - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) - } - - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::::default()) - } -} - -impl PartialEq for BoolAnd { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -#[derive(Debug, Default)] -struct BoolAndAccumulator { - 
acc: Option, -} - -impl Accumulator for BoolAndAccumulator { - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - self.acc = match (self.acc, bool_and_batch(values)?) { - (None, ScalarValue::Boolean(v)) => v, - (Some(v), ScalarValue::Boolean(None)) => Some(v), - (Some(a), ScalarValue::Boolean(Some(b))) => Some(a && b), - _ => unreachable!(), - }; - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.update_batch(states) - } - - fn state(&mut self) -> Result> { - Ok(vec![ScalarValue::Boolean(self.acc)]) - } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::Boolean(self.acc)) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} - -/// BOOL_OR aggregate expression -#[derive(Debug, Clone)] -pub struct BoolOr { - name: String, - pub data_type: DataType, - expr: Arc, - nullable: bool, -} - -impl BoolOr { - /// Create a new BOOL_OR aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - nullable: true, - } - } -} - -impl AggregateExpr for BoolOr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::::default()) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "bool_or"), - self.data_type.clone(), - self.nullable, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - true - } - - fn create_groups_accumulator(&self) -> Result> { - match self.data_type { - DataType::Boolean => { - Ok(Box::new(BooleanGroupsAccumulator::new(|x, y| x || y))) - } - _ => not_impl_err!( - "GroupsAccumulator not supported for {} with {}", - self.name(), - self.data_type - ), - } - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) - } - - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::::default()) - } -} - -impl PartialEq for BoolOr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -#[derive(Debug, Default)] -struct BoolOrAccumulator { - acc: Option, -} - -impl Accumulator for BoolOrAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ScalarValue::Boolean(self.acc)]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - self.acc = match (self.acc, bool_or_batch(values)?) 
{ - (None, ScalarValue::Boolean(v)) => v, - (Some(v), ScalarValue::Boolean(None)) => Some(v), - (Some(a), ScalarValue::Boolean(Some(b))) => Some(a || b), - _ => unreachable!(), - }; - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.update_batch(states) - } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::Boolean(self.acc)) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::col; - use crate::expressions::tests::aggregate; - use crate::generic_test_op; - use arrow::datatypes::*; - use arrow::record_batch::RecordBatch; - - #[test] - fn test_bool_and() -> Result<()> { - let a: ArrayRef = Arc::new(BooleanArray::from(vec![true, true, false])); - generic_test_op!(a, DataType::Boolean, BoolAnd, ScalarValue::from(false)) - } - - #[test] - fn bool_and_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(BooleanArray::from(vec![ - Some(true), - None, - Some(true), - Some(true), - ])); - generic_test_op!(a, DataType::Boolean, BoolAnd, ScalarValue::from(true)) - } - - #[test] - fn bool_and_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(BooleanArray::from(vec![None, None])); - generic_test_op!(a, DataType::Boolean, BoolAnd, ScalarValue::Boolean(None)) - } - - #[test] - fn test_bool_or() -> Result<()> { - let a: ArrayRef = Arc::new(BooleanArray::from(vec![true, true, false])); - generic_test_op!(a, DataType::Boolean, BoolOr, ScalarValue::from(true)) - } - - #[test] - fn bool_or_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(BooleanArray::from(vec![ - Some(false), - None, - Some(false), - Some(false), - ])); - generic_test_op!(a, DataType::Boolean, BoolOr, ScalarValue::from(false)) - } - - #[test] - fn bool_or_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(BooleanArray::from(vec![None, None])); - generic_test_op!(a, DataType::Boolean, BoolOr, ScalarValue::Boolean(None)) - } -} diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 1dfe9ffd6905..53cfcfb033a1 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -66,16 +66,6 @@ pub fn create_aggregate_expr( name, data_type, )), - (AggregateFunction::BoolAnd, _) => Arc::new(expressions::BoolAnd::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::BoolOr, _) => Arc::new(expressions::BoolOr::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), (AggregateFunction::ArrayAgg, false) => { let expr = input_phy_exprs[0].clone(); let nullable = expr.nullable(input_schema)?; @@ -165,9 +155,7 @@ mod tests { use datafusion_common::plan_err; use datafusion_expr::{type_coercion, Signature}; - use crate::expressions::{ - try_cast, ArrayAgg, Avg, BoolAnd, BoolOr, DistinctArrayAgg, Max, Min, - }; + use crate::expressions::{try_cast, ArrayAgg, Avg, DistinctArrayAgg, Max, Min}; use super::*; #[test] @@ -281,48 +269,6 @@ mod tests { Ok(()) } - #[test] - fn test_bool_and_or_expr() -> Result<()> { - let funcs = vec![AggregateFunction::BoolAnd, AggregateFunction::BoolOr]; - let data_types = vec![DataType::Boolean]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - 
&fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - match fun { - AggregateFunction::BoolAnd => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - AggregateFunction::BoolOr => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} - }; - } - } - Ok(()) - } - #[test] fn test_sum_avg_expr() -> Result<()> { let funcs = vec![AggregateFunction::Avg]; diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs index a6946e739c97..73d810ec056d 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs @@ -25,9 +25,6 @@ pub(crate) mod accumulate { pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState; -pub(crate) mod bool_op { - pub use datafusion_physical_expr_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; -} pub(crate) mod prim_op { pub use datafusion_physical_expr_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; } diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 87c7deccc2cd..f64c5b1fb260 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -21,7 +21,6 @@ pub(crate) mod array_agg; pub(crate) mod array_agg_distinct; pub(crate) mod array_agg_ordered; pub(crate) mod average; -pub(crate) mod bool_and_or; pub(crate) mod correlation; pub(crate) mod covariance; pub(crate) mod grouping; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 322610404074..0020aa5f55b2 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -40,7 +40,6 @@ pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg; pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg; pub use crate::aggregate::average::Avg; pub use crate::aggregate::average::AvgAccumulator; -pub use crate::aggregate::bool_and_or::{BoolAnd, BoolOr}; pub use crate::aggregate::build_in::create_aggregate_expr; pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::grouping::Grouping; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 6375df721ae6..50356d5b6052 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -494,8 +494,8 @@ enum AggregateFunction { // BIT_AND = 19; // BIT_OR = 20; // BIT_XOR = 21; - BOOL_AND = 22; - BOOL_OR = 23; +// BOOL_AND = 22; +// BOOL_OR = 23; // REGR_SLOPE = 26; // REGR_INTERCEPT = 27; // REGR_COUNT = 28; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 5c483f70d150..8cca0fe4a876 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -538,8 +538,6 @@ impl serde::Serialize for AggregateFunction { Self::ArrayAgg => "ARRAY_AGG", Self::Correlation => "CORRELATION", Self::Grouping => "GROUPING", - Self::BoolAnd => "BOOL_AND", - Self::BoolOr => 
"BOOL_OR", Self::NthValueAgg => "NTH_VALUE_AGG", }; serializer.serialize_str(variant) @@ -558,8 +556,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "ARRAY_AGG", "CORRELATION", "GROUPING", - "BOOL_AND", - "BOOL_OR", "NTH_VALUE_AGG", ]; @@ -607,8 +603,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), "CORRELATION" => Ok(AggregateFunction::Correlation), "GROUPING" => Ok(AggregateFunction::Grouping), - "BOOL_AND" => Ok(AggregateFunction::BoolAnd), - "BOOL_OR" => Ok(AggregateFunction::BoolOr), "NTH_VALUE_AGG" => Ok(AggregateFunction::NthValueAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index bc5b6be2ad87..56f14982923d 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1948,8 +1948,8 @@ pub enum AggregateFunction { /// BIT_AND = 19; /// BIT_OR = 20; /// BIT_XOR = 21; - BoolAnd = 22, - BoolOr = 23, + /// BOOL_AND = 22; + /// BOOL_OR = 23; /// REGR_SLOPE = 26; /// REGR_INTERCEPT = 27; /// REGR_COUNT = 28; @@ -1975,8 +1975,6 @@ impl AggregateFunction { AggregateFunction::ArrayAgg => "ARRAY_AGG", AggregateFunction::Correlation => "CORRELATION", AggregateFunction::Grouping => "GROUPING", - AggregateFunction::BoolAnd => "BOOL_AND", - AggregateFunction::BoolOr => "BOOL_OR", AggregateFunction::NthValueAgg => "NTH_VALUE_AGG", } } @@ -1989,8 +1987,6 @@ impl AggregateFunction { "ARRAY_AGG" => Some(Self::ArrayAgg), "CORRELATION" => Some(Self::Correlation), "GROUPING" => Some(Self::Grouping), - "BOOL_AND" => Some(Self::BoolAnd), - "BOOL_OR" => Some(Self::BoolOr), "NTH_VALUE_AGG" => Some(Self::NthValueAgg), _ => None, } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 5bec655bb1ff..ba0e708218cf 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -140,8 +140,6 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Min => Self::Min, protobuf::AggregateFunction::Max => Self::Max, protobuf::AggregateFunction::Avg => Self::Avg, - protobuf::AggregateFunction::BoolAnd => Self::BoolAnd, - protobuf::AggregateFunction::BoolOr => Self::BoolOr, protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, protobuf::AggregateFunction::Correlation => Self::Correlation, protobuf::AggregateFunction::Grouping => Self::Grouping, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 66b7c77799ea..08999effa4b1 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -111,8 +111,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Min => Self::Min, AggregateFunction::Max => Self::Max, AggregateFunction::Avg => Self::Avg, - AggregateFunction::BoolAnd => Self::BoolAnd, - AggregateFunction::BoolOr => Self::BoolOr, AggregateFunction::ArrayAgg => Self::ArrayAgg, AggregateFunction::Correlation => Self::Correlation, AggregateFunction::Grouping => Self::Grouping, @@ -376,8 +374,6 @@ pub fn serialize_expr( AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, AggregateFunction::Min => protobuf::AggregateFunction::Min, AggregateFunction::Max => protobuf::AggregateFunction::Max, - AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, - AggregateFunction::BoolOr => 
protobuf::AggregateFunction::BoolOr, AggregateFunction::Avg => protobuf::AggregateFunction::Avg, AggregateFunction::Correlation => { protobuf::AggregateFunction::Correlation diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index ed966509b842..a9d3736dee08 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,10 +23,10 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - ArrayAgg, Avg, BinaryExpr, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, - CumeDist, DistinctArrayAgg, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, - Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, - OrderSensitiveArrayAgg, Rank, RankType, RowNumber, TryCastExpr, WindowShift, + ArrayAgg, Avg, BinaryExpr, CaseExpr, CastExpr, Column, Correlation, CumeDist, + DistinctArrayAgg, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, + NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, + RankType, RowNumber, TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -240,10 +240,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { let inner = if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Grouping - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::BoolAnd - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::BoolOr } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::ArrayAgg } else if aggr_expr.downcast_ref::().is_some() { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 61764394ee74..b3966c3f0204 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -59,7 +59,9 @@ use datafusion_expr::{ TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; -use datafusion_functions_aggregate::expr_fn::{bit_and, bit_or, bit_xor}; +use datafusion_functions_aggregate::expr_fn::{ + bit_and, bit_or, bit_xor, bool_and, bool_or, +}; use datafusion_functions_aggregate::string_agg::string_agg; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, @@ -671,6 +673,8 @@ async fn roundtrip_expr_api() -> Result<()> { bit_or(lit(2)), bit_xor(lit(2)), string_agg(col("a").cast_to(&DataType::Utf8, &schema)?, lit("|")), + bool_and(lit(true)), + bool_or(lit(true)), ]; // ensure expressions created with the expr api can be round tripped From 58d23c5c050f43aa7b867d4f0be7298d8d6cad83 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Thu, 20 Jun 2024 05:59:56 -0700 Subject: [PATCH 02/11] feat: support uint data page extraction (#11018) --- .../physical_plan/parquet/statistics.rs | 42 +++++++++++++++++++ .../core/tests/parquet/arrow_statistics.rs | 8 ++-- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs 
b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index 2ca47de990ec..3be060ce6180 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -613,6 +613,48 @@ macro_rules! get_data_page_statistics { ($stat_type_prefix: ident, $data_type: ident, $iterator: ident) => { paste! { match $data_type { + Some(DataType::UInt8) => Ok(Arc::new( + UInt8Array::from_iter( + [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) + .map(|x| { + x.into_iter().filter_map(|x| { + x.and_then(|x| u8::try_from(x).ok()) + }) + }) + .flatten() + ) + )), + Some(DataType::UInt16) => Ok(Arc::new( + UInt16Array::from_iter( + [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) + .map(|x| { + x.into_iter().filter_map(|x| { + x.and_then(|x| u16::try_from(x).ok()) + }) + }) + .flatten() + ) + )), + Some(DataType::UInt32) => Ok(Arc::new( + UInt32Array::from_iter( + [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) + .map(|x| { + x.into_iter().filter_map(|x| { + x.and_then(|x| u32::try_from(x).ok()) + }) + }) + .flatten() + ))), + Some(DataType::UInt64) => Ok(Arc::new( + UInt64Array::from_iter( + [<$stat_type_prefix Int64DataPageStatsIterator>]::new($iterator) + .map(|x| { + x.into_iter().filter_map(|x| { + x.and_then(|x| u64::try_from(x).ok()) + }) + }) + .flatten() + ))), Some(DataType::Int8) => Ok(Arc::new( Int8Array::from_iter( [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs index bdae9f47867e..e48c11fcbaab 100644 --- a/datafusion/core/tests/parquet/arrow_statistics.rs +++ b/datafusion/core/tests/parquet/arrow_statistics.rs @@ -1352,7 +1352,7 @@ async fn test_uint() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0, 0]), expected_row_counts: Some(UInt64Array::from(vec![4, 4, 4, 4, 4])), column_name: "u8", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1363,7 +1363,7 @@ async fn test_uint() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0, 0]), expected_row_counts: Some(UInt64Array::from(vec![4, 4, 4, 4, 4])), column_name: "u16", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1374,7 +1374,7 @@ async fn test_uint() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0, 0]), expected_row_counts: Some(UInt64Array::from(vec![4, 4, 4, 4, 4])), column_name: "u32", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1385,7 +1385,7 @@ async fn test_uint() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0, 0]), expected_row_counts: Some(UInt64Array::from(vec![4, 4, 4, 4, 4])), column_name: "u64", - check: Check::RowGroup, + check: Check::Both, } .run(); } From 5316278ceac6743fde8dfc94880eb63d0c4c3d2e Mon Sep 17 00:00:00 2001 From: Lorrens Pantelis <100197010+LorrensP-2158466@users.noreply.github.com> Date: Thu, 20 Jun 2024 17:39:49 +0200 Subject: [PATCH 03/11] propagate error instead of panicking on out of bounds in physical-expr/src/analysis.rs (#10992) * propogate error instead of panicking * use macro for creating internal df error --- datafusion/physical-expr/src/analysis.rs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs index ca25bfd647b6..e7b199af3743 100644 --- a/datafusion/physical-expr/src/analysis.rs +++ b/datafusion/physical-expr/src/analysis.rs @@ -27,7 
+27,9 @@ use crate::PhysicalExpr; use arrow::datatypes::Schema; use datafusion_common::stats::Precision; -use datafusion_common::{internal_err, ColumnStatistics, Result, ScalarValue}; +use datafusion_common::{ + internal_datafusion_err, internal_err, ColumnStatistics, Result, ScalarValue, +}; use datafusion_expr::interval_arithmetic::{cardinality_ratio, Interval}; /// The shared context used during the analysis of an expression. Includes @@ -92,7 +94,13 @@ impl ExprBoundaries { col_stats: &ColumnStatistics, col_index: usize, ) -> Result { - let field = &schema.fields()[col_index]; + let field = schema.fields().get(col_index).ok_or_else(|| { + internal_datafusion_err!( + "Could not create `ExprBoundaries`: in `try_from_column` `col_index` + has gone out of bounds with a value of {col_index}, the schema has {} columns.", + schema.fields.len() + ) + })?; let empty_field = ScalarValue::try_from(field.data_type()).unwrap_or(ScalarValue::Null); let interval = Interval::try_new( From 1155b0b15e6ce3a8d5d28e5ecaebf4706448c548 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 20 Jun 2024 11:44:56 -0400 Subject: [PATCH 04/11] Minor: Add more docs and examples for `Transformed` and `TransformedResult` (#11003) --- datafusion/common/src/tree_node.rs | 83 +++++++++++++++++++++++++++++- 1 file changed, 81 insertions(+), 2 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 276a1cc4c59c..bb54f4e13af9 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -582,7 +582,11 @@ impl TreeNodeRecursion { /// Result of tree walk / transformation APIs /// -/// API users control the transformation by returning: +/// `Transformed` is a wrapper around the tree node data (e.g. `Expr` or +/// `LogicalPlan`). It is used to indicate whether the node was transformed +/// and how the recursion should proceed. +/// +/// [`TreeNode`] API users control the transformation by returning: /// - The resulting (possibly transformed) node, /// - `transformed`: flag indicating whether any change was made to the node /// - `tnr`: [`TreeNodeRecursion`] specifying how to proceed with the recursion. @@ -592,7 +596,66 @@ impl TreeNodeRecursion { /// - `transformed`: flag indicating whether any change was made to the node /// - `tnr`: [`TreeNodeRecursion`] specifying how the recursion ended. 
/// -/// Example APIs: +/// See also +/// * [`Transformed::update_data`] to modify the node without changing the `transformed` flag +/// * [`Transformed::map_data`] for fallable operation that return the same type +/// * [`Transformed::transform_data`] to chain fallable transformations +/// * [`TransformedResult`] for working with `Result>` +/// +/// # Examples +/// +/// Use [`Transformed::yes`] and [`Transformed::no`] to signal that a node was +/// rewritten and the recursion should continue: +/// +/// ``` +/// # use datafusion_common::tree_node::Transformed; +/// # // note use i64 instead of Expr as Expr is not in datafusion-common +/// # fn orig_expr() -> i64 { 1 } +/// # fn make_new_expr(i: i64) -> i64 { 2 } +/// let expr = orig_expr(); +/// +/// // Create a new `Transformed` object signaling the node was not rewritten +/// let ret = Transformed::no(expr.clone()); +/// assert!(!ret.transformed); +/// +/// // Create a new `Transformed` object signaling the node was rewritten +/// let ret = Transformed::yes(expr); +/// assert!(ret.transformed) +/// ``` +/// +/// Access the node within the `Transformed` object: +/// ``` +/// # use datafusion_common::tree_node::Transformed; +/// # // note use i64 instead of Expr as Expr is not in datafusion-common +/// # fn orig_expr() -> i64 { 1 } +/// # fn make_new_expr(i: i64) -> i64 { 2 } +/// let expr = orig_expr(); +/// +/// // `Transformed` object signaling the node was not rewritten +/// let ret = Transformed::no(expr.clone()); +/// // Access the inner object using .data +/// assert_eq!(expr, ret.data); +/// ``` +/// +/// Transform the node within the `Transformed` object. +/// +/// ``` +/// # use datafusion_common::tree_node::Transformed; +/// # // note use i64 instead of Expr as Expr is not in datafusion-common +/// # fn orig_expr() -> i64 { 1 } +/// # fn make_new_expr(i: i64) -> i64 { 2 } +/// let expr = orig_expr(); +/// let ret = Transformed::no(expr.clone()) +/// .transform_data(|expr| { +/// // closure returns a result and potentially transforms the node +/// // in this example, it does transform the node +/// let new_expr = make_new_expr(expr); +/// Ok(Transformed::yes(new_expr)) +/// }).unwrap(); +/// // transformed flag is the union of the original ans closure's transformed flag +/// assert!(ret.transformed); +/// ``` +/// # Example APIs that use `TreeNode` /// - [`TreeNode`], /// - [`TreeNode::rewrite`], /// - [`TreeNode::transform_down`], @@ -833,6 +896,22 @@ macro_rules! map_until_stop_and_collect { } /// Transformation helper to access [`Transformed`] fields in a [`Result`] easily. 
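+/// This is a convenience for turning a `Result<Transformed<T>>` into a
+/// `Result<T>` without unwrapping the `Transformed` layer by hand.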
+/// +/// # Example +/// Access the internal data of a `Result>` +/// as a `Result` using the `data` method: +/// ``` +/// # use datafusion_common::Result; +/// # use datafusion_common::tree_node::{Transformed, TransformedResult}; +/// # // note use i64 instead of Expr as Expr is not in datafusion-common +/// # fn update_expr() -> i64 { 1 } +/// # fn main() -> Result<()> { +/// let transformed: Result> = Ok(Transformed::yes(update_expr())); +/// // access the internal data of the transformed result, or return the error +/// let transformed_expr = transformed.data()?; +/// # Ok(()) +/// # } +/// ``` pub trait TransformedResult { fn data(self) -> Result; From 18042fd69138e19613844580408a71a200ea6caa Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Thu, 20 Jun 2024 09:58:01 -0700 Subject: [PATCH 05/11] feat: propagate EmptyRelation for more join types (#10963) * feat: propagate empty for more join types * feat: update subquery de-correlation test * tests: simplify tests * refactor: better name * style: clippy * refactor: update tests * refactor: rename * refactor: fix spellings * add slt tests --- .../core/tests/parquet/arrow_statistics.rs | 8 +- .../optimizer/src/eliminate_one_union.rs | 3 +- .../optimizer/src/propagate_empty_relation.rs | 177 +++++++++++++++--- datafusion/optimizer/src/test/mod.rs | 39 +++- datafusion/sqllogictest/test_files/joins.slt | 96 ++++++++++ .../sqllogictest/test_files/subquery.slt | 5 +- 6 files changed, 290 insertions(+), 38 deletions(-) diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs index e48c11fcbaab..ddb39fce4076 100644 --- a/datafusion/core/tests/parquet/arrow_statistics.rs +++ b/datafusion/core/tests/parquet/arrow_statistics.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! This file contains an end to end test of extracting statitics from parquet files. +//! This file contains an end to end test of extracting statistics from parquet files. //! It writes data into a parquet file, reads statistics and verifies they are correct use std::default::Default; @@ -716,8 +716,8 @@ async fn test_timestamp() { // "seconds_timezoned" --> TimestampSecondArray // "names" --> StringArray // - // The file is created by 4 record batches, each has 5 rowws. - // Since the row group isze is set to 5, those 4 batches will go into 4 row groups + // The file is created by 4 record batches, each has 5 rows. 
+ // Since the row group size is set to 5, those 4 batches will go into 4 row groups // This creates a parquet files of 4 columns named "nanos", "nanos_timezoned", "micros", "micros_timezoned", "millis", "millis_timezoned", "seconds", "seconds_timezoned" let reader = TestReader { scenario: Scenario::Timestamps, @@ -2039,7 +2039,7 @@ async fn test_missing_statistics() { expected_min: Arc::new(Int64Array::from(vec![None])), expected_max: Arc::new(Int64Array::from(vec![None])), expected_null_counts: UInt64Array::from(vec![None]), - expected_row_counts: Some(UInt64Array::from(vec![3])), // stil has row count statistics + expected_row_counts: Some(UInt64Array::from(vec![3])), // still has row count statistics column_name: "i64", check: Check::RowGroup, } diff --git a/datafusion/optimizer/src/eliminate_one_union.rs b/datafusion/optimizer/src/eliminate_one_union.rs index 68d0ddba8b20..7763e7d3b796 100644 --- a/datafusion/optimizer/src/eliminate_one_union.rs +++ b/datafusion/optimizer/src/eliminate_one_union.rs @@ -88,10 +88,11 @@ mod tests { } fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq_with_rules( + assert_optimized_plan_with_rules( vec![Arc::new(EliminateOneUnion::new())], plan, expected, + true, ) } diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index d08820c58a05..dfcfc313efcc 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use datafusion_common::tree_node::Transformed; -use datafusion_common::JoinType::Inner; +use datafusion_common::JoinType; use datafusion_common::{internal_err, plan_err, Result}; use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::logical_plan::LogicalPlan; @@ -94,13 +94,9 @@ impl OptimizerRule for PropagateEmptyRelation { Ok(Transformed::no(LogicalPlan::CrossJoin(join.clone()))) } - LogicalPlan::Join(ref join) if join.join_type == Inner => { + LogicalPlan::Join(ref join) => { // TODO: For Join, more join type need to be careful: - // For LeftOuter/LeftSemi/LeftAnti Join, only the left side is empty, the Join result is empty. - // For LeftSemi Join, if the right side is empty, the Join result is empty. // For LeftAnti Join, if the right side is empty, the Join result is left side(should exclude null ??). - // For RightOuter/RightSemi/RightAnti Join, only the right side is empty, the Join result is empty. - // For RightSemi Join, if the left side is empty, the Join result is empty. // For RightAnti Join, if the left side is empty, the Join result is right side(should exclude null ??). // For Full Join, only both sides are empty, the Join result is empty. // For LeftOut/Full Join, if the right side is empty, the Join can be eliminated with a Projection with left side @@ -108,15 +104,52 @@ impl OptimizerRule for PropagateEmptyRelation { // For RightOut/Full Join, if the left side is empty, the Join can be eliminated with a Projection with right side // columns + left side columns replaced with null values. 
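+                // Each arm below handles a case in which an empty input
+                // guarantees an empty join output.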
let (left_empty, right_empty) = binary_plan_children_is_empty(&plan)?; - if left_empty || right_empty { - return Ok(Transformed::yes(LogicalPlan::EmptyRelation( - EmptyRelation { + + match join.join_type { + JoinType::Inner if left_empty || right_empty => Ok(Transformed::yes( + LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema: join.schema.clone(), - }, - ))); + }), + )), + JoinType::Left if left_empty => Ok(Transformed::yes( + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: join.schema.clone(), + }), + )), + JoinType::Right if right_empty => Ok(Transformed::yes( + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: join.schema.clone(), + }), + )), + JoinType::LeftSemi if left_empty || right_empty => Ok( + Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: join.schema.clone(), + })), + ), + JoinType::RightSemi if left_empty || right_empty => Ok( + Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: join.schema.clone(), + })), + ), + JoinType::LeftAnti if left_empty => Ok(Transformed::yes( + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: join.schema.clone(), + }), + )), + JoinType::RightAnti if right_empty => Ok(Transformed::yes( + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: join.schema.clone(), + }), + )), + _ => Ok(Transformed::no(LogicalPlan::Join(join.clone()))), } - Ok(Transformed::no(LogicalPlan::Join(join.clone()))) } LogicalPlan::Aggregate(ref agg) => { if !agg.group_expr.is_empty() { @@ -222,7 +255,7 @@ mod tests { use crate::eliminate_filter::EliminateFilter; use crate::eliminate_nested_union::EliminateNestedUnion; use crate::test::{ - assert_optimized_plan_eq, assert_optimized_plan_eq_with_rules, test_table_scan, + assert_optimized_plan_eq, assert_optimized_plan_with_rules, test_table_scan, test_table_scan_fields, test_table_scan_with_name, }; @@ -232,11 +265,12 @@ mod tests { assert_optimized_plan_eq(Arc::new(PropagateEmptyRelation::new()), plan, expected) } - fn assert_together_optimized_plan_eq( + fn assert_together_optimized_plan( plan: LogicalPlan, expected: &str, + eq: bool, ) -> Result<()> { - assert_optimized_plan_eq_with_rules( + assert_optimized_plan_with_rules( vec![ Arc::new(EliminateFilter::new()), Arc::new(EliminateNestedUnion::new()), @@ -244,6 +278,7 @@ mod tests { ], plan, expected, + eq, ) } @@ -279,7 +314,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(plan, expected) + assert_together_optimized_plan(plan, expected, true) } #[test] @@ -292,7 +327,7 @@ mod tests { let plan = LogicalPlanBuilder::from(left).union(right)?.build()?; let expected = "TableScan: test"; - assert_together_optimized_plan_eq(plan, expected) + assert_together_optimized_plan(plan, expected, true) } #[test] @@ -317,7 +352,7 @@ mod tests { let expected = "Union\ \n TableScan: test1\ \n TableScan: test4"; - assert_together_optimized_plan_eq(plan, expected) + assert_together_optimized_plan(plan, expected, true) } #[test] @@ -342,7 +377,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(plan, expected) + assert_together_optimized_plan(plan, expected, true) } #[test] @@ -369,7 +404,7 @@ mod tests { let expected = "Union\ \n TableScan: test2\ \n TableScan: test3"; - assert_together_optimized_plan_eq(plan, expected) + assert_together_optimized_plan(plan, expected, true) } 
#[test] @@ -382,7 +417,7 @@ mod tests { let plan = LogicalPlanBuilder::from(left).union(right)?.build()?; let expected = "TableScan: test"; - assert_together_optimized_plan_eq(plan, expected) + assert_together_optimized_plan(plan, expected, true) } #[test] @@ -397,7 +432,103 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(plan, expected) + assert_together_optimized_plan(plan, expected, true) + } + + fn assert_empty_left_empty_right_lp( + left_empty: bool, + right_empty: bool, + join_type: JoinType, + eq: bool, + ) -> Result<()> { + let left_lp = if left_empty { + let left_table_scan = test_table_scan()?; + + LogicalPlanBuilder::from(left_table_scan) + .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .build() + } else { + let scan = test_table_scan_with_name("left").unwrap(); + LogicalPlanBuilder::from(scan).build() + }?; + + let right_lp = if right_empty { + let right_table_scan = test_table_scan_with_name("right")?; + + LogicalPlanBuilder::from(right_table_scan) + .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .build() + } else { + let scan = test_table_scan_with_name("right").unwrap(); + LogicalPlanBuilder::from(scan).build() + }?; + + let plan = LogicalPlanBuilder::from(left_lp) + .join_using( + right_lp, + join_type, + vec![Column::from_name("a".to_string())], + )? + .build()?; + + let expected = "EmptyRelation"; + assert_together_optimized_plan(plan, expected, eq) + } + + #[test] + fn test_join_empty_propagation_rules() -> Result<()> { + // test left join with empty left + assert_empty_left_empty_right_lp(true, false, JoinType::Left, true)?; + + // test right join with empty right + assert_empty_left_empty_right_lp(false, true, JoinType::Right, true)?; + + // test left semi join with empty left + assert_empty_left_empty_right_lp(true, false, JoinType::LeftSemi, true)?; + + // test left semi join with empty right + assert_empty_left_empty_right_lp(false, true, JoinType::LeftSemi, true)?; + + // test right semi join with empty left + assert_empty_left_empty_right_lp(true, false, JoinType::RightSemi, true)?; + + // test right semi join with empty right + assert_empty_left_empty_right_lp(false, true, JoinType::RightSemi, true)?; + + // test left anti join empty left + assert_empty_left_empty_right_lp(true, false, JoinType::LeftAnti, true)?; + + // test right anti join empty right + assert_empty_left_empty_right_lp(false, true, JoinType::RightAnti, true) + } + + #[test] + fn test_join_empty_propagation_rules_noop() -> Result<()> { + // these cases should not result in an empty relation + + // test left join with empty right + assert_empty_left_empty_right_lp(false, true, JoinType::Left, false)?; + + // test right join with empty left + assert_empty_left_empty_right_lp(true, false, JoinType::Right, false)?; + + // test left semi with non-empty left and right + assert_empty_left_empty_right_lp(false, false, JoinType::LeftSemi, false)?; + + // test right semi with non-empty left and right + assert_empty_left_empty_right_lp(false, false, JoinType::RightSemi, false)?; + + // test left anti join with non-empty left and right + assert_empty_left_empty_right_lp(false, false, JoinType::LeftAnti, false)?; + + // test left anti with non-empty left and empty right + assert_empty_left_empty_right_lp(false, true, JoinType::LeftAnti, false)?; + + // test right anti join with non-empty left and right + assert_empty_left_empty_right_lp(false, false, JoinType::RightAnti, false)?; + + // test right anti with empty left and non-empty 
right + assert_empty_left_empty_right_lp(true, false, JoinType::RightAnti, false) } #[test] @@ -430,6 +561,6 @@ mod tests { let expected = "Projection: a, b, c\ \n TableScan: test"; - assert_together_optimized_plan_eq(plan, expected) + assert_together_optimized_plan(plan, expected, true) } } diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 98d19956df3c..2c7e8644026e 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -121,6 +121,21 @@ pub fn assert_analyzed_plan_eq( Ok(()) } + +pub fn assert_analyzed_plan_ne( + rule: Arc, + plan: LogicalPlan, + expected: &str, +) -> Result<()> { + let options = ConfigOptions::default(); + let analyzed_plan = + Analyzer::with_rules(vec![rule]).execute_and_check(plan, &options, |_, _| {})?; + let formatted_plan = format!("{analyzed_plan:?}"); + assert_ne!(formatted_plan, expected); + + Ok(()) +} + pub fn assert_analyzed_plan_eq_display_indent( rule: Arc, plan: LogicalPlan, @@ -169,21 +184,33 @@ pub fn assert_optimized_plan_eq( Ok(()) } -pub fn assert_optimized_plan_eq_with_rules( +fn generate_optimized_plan_with_rules( rules: Vec>, plan: LogicalPlan, - expected: &str, -) -> Result<()> { +) -> LogicalPlan { fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} let config = &mut OptimizerContext::new() .with_max_passes(1) .with_skip_failing_rules(false); let optimizer = Optimizer::with_rules(rules); - let optimized_plan = optimizer + optimizer .optimize(plan, config, observe) - .expect("failed to optimize plan"); + .expect("failed to optimize plan") +} + +pub fn assert_optimized_plan_with_rules( + rules: Vec>, + plan: LogicalPlan, + expected: &str, + eq: bool, +) -> Result<()> { + let optimized_plan = generate_optimized_plan_with_rules(rules, plan); let formatted_plan = format!("{optimized_plan:?}"); - assert_eq!(formatted_plan, expected); + if eq { + assert_eq!(formatted_plan, expected); + } else { + assert_ne!(formatted_plan, expected); + } Ok(()) } diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index ea7103255697..0cecd7935a74 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -3686,3 +3686,99 @@ set datafusion.explain.logical_plan_only = true; statement ok set datafusion.execution.target_partitions = 2; + +# Inner join with empty left table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 as a WHERE 1=0 +) AS a INNER JOIN (SELECT 1 as a) AS b ON a.a=b.a; +---- +logical_plan EmptyRelation + +# Inner join with empty right table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 AS a +) AS a INNER JOIN (SELECT 1 AS a WHERE 1=0) AS b ON a.a=b.a; +---- +logical_plan EmptyRelation + +# Left join with empty left table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 as a WHERE 1=0 +) AS a LEFT JOIN (SELECT 1 as a) AS b ON a.a=b.a; +---- +logical_plan EmptyRelation + +# Left join with empty left and empty right table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 as a WHERE 1=0 +) AS a LEFT JOIN (SELECT 1 as a WHERE 1=0) AS b ON a.a=b.a; +---- +logical_plan EmptyRelation + +# Right join with empty right table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 AS a +) AS a RIGHT JOIN (SELECT 1 AS a WHERE 1=0) AS b ON a.a=b.a; +---- +logical_plan EmptyRelation + +# Right join with empty right and empty left table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 as a WHERE 1=0 +) AS a RIGHT JOIN (SELECT 1 as a WHERE 1=0) AS b ON a.a=b.a; +---- +logical_plan EmptyRelation + 
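+# The SEMI and ANTI variants below likewise collapse to EmptyRelation when an
+# empty input guarantees that no rows can be produced.
+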
+# Left SEMI join with empty right table
+query TT
+EXPLAIN SELECT * FROM (
+    SELECT 1 AS a
+) AS a LEFT SEMI JOIN (SELECT 1 AS a WHERE 1=0) AS b ON a.a=b.a;
+----
+logical_plan EmptyRelation
+
+# Left SEMI join with empty left table
+query TT
+EXPLAIN SELECT * FROM (
+    SELECT 1 AS a WHERE 1=0
+) AS a LEFT SEMI JOIN (SELECT 1 AS a) AS b ON a.a=b.a;
+----
+logical_plan EmptyRelation
+
+# Right SEMI join with empty left table
+query TT
+EXPLAIN SELECT * FROM (
+    SELECT 1 AS a WHERE 1=0
+) AS a RIGHT SEMI JOIN (SELECT 1 AS a) AS b ON a.a=b.a;
+----
+logical_plan EmptyRelation
+
+# Right SEMI join with empty right table
+query TT
+EXPLAIN SELECT * FROM (
+    SELECT 1 AS a
+) AS a RIGHT SEMI JOIN (SELECT 1 AS a WHERE 1=0) AS b ON a.a=b.a;
+----
+logical_plan EmptyRelation
+
+# Left ANTI join with empty left table
+query TT
+EXPLAIN SELECT * FROM (
+    SELECT 1 AS a WHERE 1=0
+) AS a LEFT ANTI JOIN (SELECT 1 AS a) AS b ON a.a=b.a;
+----
+logical_plan EmptyRelation
+
+# Right ANTI join with empty right table
+query TT
+EXPLAIN SELECT * FROM (
+    SELECT 1 AS a
+) AS a RIGHT ANTI JOIN (SELECT 1 AS a WHERE 1=0) AS b ON a.a=b.a;
+----
+logical_plan EmptyRelation
diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt
index f325d5567603..dbdb7fc76b8b 100644
--- a/datafusion/sqllogictest/test_files/subquery.slt
+++ b/datafusion/sqllogictest/test_files/subquery.slt
@@ -626,10 +626,7 @@ SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id
 query TT
 explain SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 0)
 ----
-logical_plan
-01)LeftSemi Join: t1.t1_id = __correlated_sq_1.t2_id
-02)--TableScan: t1 projection=[t1_id, t1_name]
-03)--EmptyRelation
+logical_plan EmptyRelation

 query IT rowsort
 SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 0)

From 1f3ba116a428c932fc34b7cbfbecd1b8fa1d60ab Mon Sep 17 00:00:00 2001
From: Alex Huang
Date: Sat, 22 Jun 2024 00:13:53 +0800
Subject: [PATCH 06/11] doc: Update links in the documentation (#11044)

---
 datafusion/optimizer/src/analyzer/mod.rs | 2 +-
 datafusion/optimizer/src/optimizer.rs    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs
index 121e46cc950f..32bb2bc70452 100644
--- a/datafusion/optimizer/src/analyzer/mod.rs
+++ b/datafusion/optimizer/src/analyzer/mod.rs
@@ -57,7 +57,7 @@ pub mod type_coercion;
 /// Use [`SessionState::add_analyzer_rule`] to register additional
 /// `AnalyzerRule`s.
 ///
-/// [`SessionState::add_analyzer_rule`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionState.html#method.add_analyzer_rule
+/// [`SessionState::add_analyzer_rule`]: https://docs.rs/datafusion/latest/datafusion/execution/session_state/struct.SessionState.html#method.add_analyzer_rule
 pub trait AnalyzerRule {
     /// Rewrite `plan`
     fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan>;
diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs
index 998eeb7167ee..75ebc92ece33 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -68,7 +68,7 @@ use crate::utils::log_plan;
 /// `OptimizerRule`s.
 ///
 /// [`AnalyzerRule`]: crate::analyzer::AnalyzerRule
-/// [`SessionState::add_optimizer_rule`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionState.html#method.add_optimizer_rule
+/// [`SessionState::add_optimizer_rule`]: https://docs.rs/datafusion/latest/datafusion/execution/session_state/struct.SessionState.html#method.add_optimizer_rule
 pub trait OptimizerRule {
     /// Try and rewrite `plan` to an optimized form, returning None if the plan

From 5498a02853021fc07a57743990d207e25ada27f4 Mon Sep 17 00:00:00 2001
From: Bruce Ritchie
Date: Fri, 21 Jun 2024 12:14:57 -0400
Subject: [PATCH 07/11] Add drop_columns to dataframe api (#11010)

* Add drop_columns to dataframe api #11007

* Prettier cleanup

* Added additional drop_columns tests and fixed issue with nonexistent columns.

---
 datafusion/core/src/dataframe/mod.rs | 169 +++++++++++++++++++++++++++
 docs/source/user-guide/dataframe.md  |   1 +
 2 files changed, 170 insertions(+)

diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index b5c58eff577c..43ac3992be78 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -244,6 +244,42 @@ impl DataFrame {
         })
     }

+    /// Returns a new DataFrame containing all columns except the specified columns.
+    ///
+    /// ```
+    /// # use datafusion::prelude::*;
+    /// # use datafusion::error::Result;
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<()> {
+    /// let ctx = SessionContext::new();
+    /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?;
+    /// let df = df.drop_columns(&["a"])?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn drop_columns(self, columns: &[&str]) -> Result<DataFrame> {
+        let fields_to_drop = columns
+            .iter()
+            .map(|name| {
+                self.plan
+                    .schema()
+                    .qualified_field_with_unqualified_name(name)
+            })
+            .filter(|r| r.is_ok())
+            .collect::<Result<Vec<_>>>()?;
+        let expr: Vec<Expr> = self
+            .plan
+            .schema()
+            .fields()
+            .into_iter()
+            .enumerate()
+            .map(|(idx, _)| self.plan.schema().qualified_field(idx))
+            .filter(|(qualifier, f)| !fields_to_drop.contains(&(*qualifier, f)))
+            .map(|(qualifier, field)| Expr::Column(Column::from((qualifier, field))))
+            .collect();
+        self.select(expr)
+    }
+
     /// Expand each list element of a column to multiple rows.
#[deprecated(since = "37.0.0", note = "use unnest_columns instead")] pub fn unnest_column(self, column: &str) -> Result { @@ -1799,6 +1835,139 @@ mod tests { Ok(()) } + #[tokio::test] + async fn drop_columns() -> Result<()> { + // build plan using Table API + let t = test_table().await?; + let t2 = t.drop_columns(&["c2", "c11"])?; + let plan = t2.plan.clone(); + + // build query using SQL + let sql_plan = create_plan( + "SELECT c1,c3,c4,c5,c6,c7,c8,c9,c10,c12,c13 FROM aggregate_test_100", + ) + .await?; + + // the two plans should be identical + assert_same_plan(&plan, &sql_plan); + + Ok(()) + } + + #[tokio::test] + async fn drop_columns_with_duplicates() -> Result<()> { + // build plan using Table API + let t = test_table().await?; + let t2 = t.drop_columns(&["c2", "c11", "c2", "c2"])?; + let plan = t2.plan.clone(); + + // build query using SQL + let sql_plan = create_plan( + "SELECT c1,c3,c4,c5,c6,c7,c8,c9,c10,c12,c13 FROM aggregate_test_100", + ) + .await?; + + // the two plans should be identical + assert_same_plan(&plan, &sql_plan); + + Ok(()) + } + + #[tokio::test] + async fn drop_columns_with_nonexistent_columns() -> Result<()> { + // build plan using Table API + let t = test_table().await?; + let t2 = t.drop_columns(&["canada", "c2", "rocks"])?; + let plan = t2.plan.clone(); + + // build query using SQL + let sql_plan = create_plan( + "SELECT c1,c3,c4,c5,c6,c7,c8,c9,c10,c11,c12,c13 FROM aggregate_test_100", + ) + .await?; + + // the two plans should be identical + assert_same_plan(&plan, &sql_plan); + + Ok(()) + } + + #[tokio::test] + async fn drop_columns_with_empty_array() -> Result<()> { + // build plan using Table API + let t = test_table().await?; + let t2 = t.drop_columns(&[])?; + let plan = t2.plan.clone(); + + // build query using SQL + let sql_plan = create_plan( + "SELECT c1,c2,c3,c4,c5,c6,c7,c8,c9,c10,c11,c12,c13 FROM aggregate_test_100", + ) + .await?; + + // the two plans should be identical + assert_same_plan(&plan, &sql_plan); + + Ok(()) + } + + #[tokio::test] + async fn drop_with_quotes() -> Result<()> { + // define data with a column name that has a "." in it: + let array1: Int32Array = [1, 10].into_iter().collect(); + let array2: Int32Array = [2, 11].into_iter().collect(); + let batch = RecordBatch::try_from_iter(vec![ + ("f\"c1", Arc::new(array1) as _), + ("f\"c2", Arc::new(array2) as _), + ])?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + + let df = ctx.table("t").await?.drop_columns(&["f\"c1"])?; + + let df_results = df.collect().await?; + + assert_batches_sorted_eq!( + [ + "+------+", + "| f\"c2 |", + "+------+", + "| 2 |", + "| 11 |", + "+------+" + ], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn drop_with_periods() -> Result<()> { + // define data with a column name that has a "." 
+        let array1: Int32Array = [1, 10].into_iter().collect();
+        let array2: Int32Array = [2, 11].into_iter().collect();
+        let batch = RecordBatch::try_from_iter(vec![
+            ("f.c1", Arc::new(array1) as _),
+            ("f.c2", Arc::new(array2) as _),
+        ])?;
+
+        let ctx = SessionContext::new();
+        ctx.register_batch("t", batch)?;
+
+        let df = ctx.table("t").await?.drop_columns(&["f.c1"])?;
+
+        let df_results = df.collect().await?;
+
+        assert_batches_sorted_eq!(
+            ["+------+", "| f.c2 |", "+------+", "| 2    |", "| 11   |", "+------+"],
+            &df_results
+        );
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn aggregate() -> Result<()> {
         // build plan using DataFrame API
diff --git a/docs/source/user-guide/dataframe.md b/docs/source/user-guide/dataframe.md
index c0210200a246..744a719e77be 100644
--- a/docs/source/user-guide/dataframe.md
+++ b/docs/source/user-guide/dataframe.md
@@ -64,6 +64,7 @@ execution. The plan is evaluated (executed) when an action method is invoked, su
 | ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
 | aggregate           | Perform an aggregate query with optional grouping expressions.                                  |
 | distinct            | Filter out duplicate rows.                                                                      |
+| drop_columns        | Create a projection with all but the provided column names.                                     |
 | except              | Calculate the exception of two DataFrames. The two DataFrames must have exactly the same schema |
 | filter              | Filter a DataFrame to only include rows that match the specified filter expression.             |
 | intersect           | Calculate the intersection of two DataFrames. The two DataFrames must have exactly the same schema |

From fd5a68f8021b878d1f43e1451feb78154d9378d9 Mon Sep 17 00:00:00 2001
From: Jay Zhan
Date: Sat, 22 Jun 2024 00:17:48 +0800
Subject: [PATCH 08/11] Push down filter plan for non-unnest column (#11019)

* push down non-unnest only

Signed-off-by: jayzhan211

* add doc

Signed-off-by: jayzhan211

* add doc

Signed-off-by: jayzhan211

* cleanup

Signed-off-by: jayzhan211

* rewrite unnest push down filter

Signed-off-by: jayzhan211

* remove comment

Signed-off-by: jayzhan211

* avoid double recursive

Signed-off-by: jayzhan211

---------

Signed-off-by: jayzhan211

---
 datafusion/expr/src/logical_plan/plan.rs     |  9 ++
 datafusion/optimizer/src/push_down_filter.rs | 84 +++++++++++--------
 .../test_files/push_down_filter.slt          | 23 ++++-
 3 files changed, 80 insertions(+), 36 deletions(-)

diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 85958223ac97..6e7efaf39e3e 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -281,6 +281,15 @@ pub enum LogicalPlan {
     RecursiveQuery(RecursiveQuery),
 }

+impl Default for LogicalPlan {
+    fn default() -> Self {
+        LogicalPlan::EmptyRelation(EmptyRelation {
+            produce_one_row: false,
+            schema: Arc::new(DFSchema::empty()),
+        })
+    }
+}
+
 impl LogicalPlan {
     /// Get a reference to the logical plan's schema
     pub fn schema(&self) -> &DFSchemaRef {
diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index 295039af2f19..68339a84649d 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -693,8 +693,9 @@ impl OptimizerRule for PushDownFilter {
                 insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
             }
             LogicalPlan::Projection(projection) => {
+                let predicates = split_conjunction_owned(filter.predicate.clone());
                 let (new_projection, keep_predicate) =
-                    rewrite_projection(filter.predicate.clone(), projection)?;
+                    rewrite_projection(predicates, projection)?;
                 if new_projection.transformed {
                     match keep_predicate {
                         None => Ok(new_projection),
@@ -709,41 +710,54 @@ impl OptimizerRule for PushDownFilter {
                 }
             }
             LogicalPlan::Unnest(mut unnest) => {
-                // collect all the Expr::Column in predicate recursively
-                let mut accum: HashSet<Column> = HashSet::new();
-                expr_to_columns(&filter.predicate, &mut accum)?;
+                let predicates = split_conjunction_owned(filter.predicate.clone());
+                let mut non_unnest_predicates = vec![];
+                let mut unnest_predicates = vec![];
+                for predicate in predicates {
+                    // collect all the Expr::Column in predicate recursively
+                    let mut accum: HashSet<Column> = HashSet::new();
+                    expr_to_columns(&predicate, &mut accum)?;
+
+                    if unnest.exec_columns.iter().any(|c| accum.contains(c)) {
+                        unnest_predicates.push(predicate);
+                    } else {
+                        non_unnest_predicates.push(predicate);
+                    }
+                }

-                if unnest.exec_columns.iter().any(|c| accum.contains(c)) {
+                // Unnest predicates should not be pushed down.
+                // If no non-unnest predicates exist, return early
+                if non_unnest_predicates.is_empty() {
                     filter.input = Arc::new(LogicalPlan::Unnest(unnest));
                     return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                 }

-                // Unnest is built above Projection, so we only take Projection into consideration
-                match unwrap_arc(unnest.input) {
-                    LogicalPlan::Projection(projection) => {
-                        let (new_projection, keep_predicate) =
-                            rewrite_projection(filter.predicate.clone(), projection)?;
-                        unnest.input = Arc::new(new_projection.data);
-
-                        if new_projection.transformed {
-                            match keep_predicate {
-                                None => Ok(Transformed::yes(LogicalPlan::Unnest(unnest))),
-                                Some(keep_predicate) => Ok(Transformed::yes(
-                                    LogicalPlan::Filter(Filter::try_new(
-                                        keep_predicate,
-                                        Arc::new(LogicalPlan::Unnest(unnest)),
-                                    )?),
-                                )),
-                            }
-                        } else {
-                            filter.input = Arc::new(LogicalPlan::Unnest(unnest));
-                            Ok(Transformed::no(LogicalPlan::Filter(filter)))
-                        }
-                    }
-                    child => {
-                        filter.input = Arc::new(child);
-                        Ok(Transformed::no(LogicalPlan::Filter(filter)))
-                    }
+                // Push down non-unnest filter predicate
+                // Unnest
+                //     Unnest Input (Projection)
+                // -> rewritten to
+                // Unnest
+                //     Filter
+                //         Unnest Input (Projection)
+
+                let unnest_input = std::mem::take(&mut unnest.input);
+
+                let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
+                    conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty.
+                    unnest_input,
+                )?);
+
+                // Directly assign the new filter plan as the new unnest's input.
+                // The new filter plan will go through another rewrite pass since the rule itself
+                // is applied recursively to all the children from top to bottom
+                let unnest_plan =
+                    insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;
+
+                match conjunction(unnest_predicates) {
+                    None => Ok(unnest_plan),
+                    Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
+                        Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
+                    ))),
                 }
             }
             LogicalPlan::Union(ref union) => {
@@ -958,6 +972,10 @@ impl OptimizerRule for PushDownFilter {
 /// `plan` is a LogicalPlan for `projection` with possibly a new FilterExec below it.
 /// `remaining_predicate` is any part of the predicate that could not be pushed down
 ///
+/// # Args
+/// - predicates: Split predicates like `[foo=5, bar=6]`
+/// - projection: The target projection plan to push down the predicates
+///
 /// # Example
 ///
 /// Pushing a predicate like `foo=5 AND bar=6` with an input plan like this:
 ///
 /// ```
 /// ...
 /// ```
 fn rewrite_projection(
-    predicate: Expr,
+    predicates: Vec<Expr>,
     projection: Projection,
 ) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
     // A projection is filter-commutable if it does not contain volatile predicates or contain volatile
@@ -994,7 +1012,7 @@
     let mut push_predicates = vec![];
     let mut keep_predicates = vec![];
-    for expr in split_conjunction_owned(predicate) {
+    for expr in predicates {
         if contain(&expr, &volatile_map) {
             keep_predicates.push(expr);
         } else {
diff --git a/datafusion/sqllogictest/test_files/push_down_filter.slt b/datafusion/sqllogictest/test_files/push_down_filter.slt
index 5029ab170a18..3ca187ddee84 100644
--- a/datafusion/sqllogictest/test_files/push_down_filter.slt
+++ b/datafusion/sqllogictest/test_files/push_down_filter.slt
@@ -67,17 +67,34 @@ select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where
 5 2

 # Could push the filter (column1 = 2) down below unnest
-# https://github.com/apache/datafusion/issues/11016
 query TT
 explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 AND column1 = 2;
 ----
 logical_plan
 01)Projection: unnest(v.column2) AS uc2, v.column1
-02)--Filter: unnest(v.column2) > Int64(3) AND v.column1 = Int64(2)
+02)--Filter: unnest(v.column2) > Int64(3)
 03)----Unnest: lists[unnest(v.column2)] structs[]
 04)------Projection: v.column2 AS unnest(v.column2), v.column1
-05)--------TableScan: v projection=[column1, column2]
+05)--------Filter: v.column1 = Int64(2)
+06)----------TableScan: v projection=[column1, column2]
+
+query II
+select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 OR column1 = 2;
+----
+3 2
+4 2
+5 2

+# only non-unnest filter in AND clause could be pushed down
+query TT
+explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 OR column1 = 2;
+----
+logical_plan
+01)Projection: unnest(v.column2) AS uc2, v.column1
+02)--Filter: unnest(v.column2) > Int64(3) OR v.column1 = Int64(2)
+03)----Unnest: lists[unnest(v.column2)] structs[]
+04)------Projection: v.column2 AS unnest(v.column2), v.column1
+05)--------TableScan: v projection=[column1, column2]

 statement ok
 drop table v;

From 4a0c7f35a02ef71fcea518a3566edc9b56866cae Mon Sep 17 00:00:00 2001
From: Marvin Lanhenke <62298609+marvinlanhenke@users.noreply.github.com>
Date: Fri, 21 Jun 2024 18:18:14 +0200
Subject: [PATCH 09/11] Consider timezones with `UTC` and `+00:00` to be the
 same (#10960)

* feat: add temporal_coercion check

* fix: add return stmt

* chore: add slts

* fix: remove println

* Update datafusion/expr/src/type_coercion/binary.rs

---------

Co-authored-by: Andrew Lamb

---
 datafusion/expr/src/type_coercion/binary.rs | 14 +++++++----
 .../sqllogictest/test_files/timestamps.slt  | 23 +++++++++++++++++++
 2 files changed, 32 insertions(+), 5 deletions(-)

diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs
index 615bb3ac568c..ea9d0c2fe72e 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -1050,12 +1050,16 @@ fn
temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { let tz = match (lhs_tz, rhs_tz) { - // can't cast across timezones (Some(lhs_tz), Some(rhs_tz)) => { - if lhs_tz != rhs_tz { - return None; - } else { - Some(lhs_tz.clone()) + match (lhs_tz.as_ref(), rhs_tz.as_ref()) { + // UTC and "+00:00" are the same by definition. Most other timezones + // do not have a 1-1 mapping between timezone and an offset from UTC + ("UTC", "+00:00") | ("+00:00", "UTC") => Some(lhs_tz.clone()), + (lhs, rhs) if lhs == rhs => Some(lhs_tz.clone()), + // can't cast across timezones + _ => { + return None; + } } } (Some(lhs_tz), None) => Some(lhs_tz.clone()), diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index 7d5d601bbfdd..96d846d449e1 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -2801,3 +2801,26 @@ query B select current_time = current_time; ---- true + +# Test temporal coercion for UTC +query ? +select arrow_cast('2024-06-17T11:00:00', 'Timestamp(Nanosecond, Some("UTC"))') - arrow_cast('2024-06-17T12:00:00', 'Timestamp(Microsecond, Some("UTC"))'); +---- +0 days -1 hours 0 mins 0.000000 secs + +query ? +select arrow_cast('2024-06-17T13:00:00', 'Timestamp(Nanosecond, Some("+00:00"))') - arrow_cast('2024-06-17T12:00:00', 'Timestamp(Microsecond, Some("UTC"))'); +---- +0 days 1 hours 0 mins 0.000000 secs + +query ? +select arrow_cast('2024-06-17T13:00:00', 'Timestamp(Nanosecond, Some("UTC"))') - arrow_cast('2024-06-17T12:00:00', 'Timestamp(Microsecond, Some("+00:00"))'); +---- +0 days 1 hours 0 mins 0.000000 secs + +# not supported: coercion across timezones +query error +select arrow_cast('2024-06-17T13:00:00', 'Timestamp(Nanosecond, Some("UTC"))') - arrow_cast('2024-06-17T12:00:00', 'Timestamp(Microsecond, Some("+01:00"))'); + +query error +select arrow_cast('2024-06-17T13:00:00', 'Timestamp(Nanosecond, Some("+00:00"))') - arrow_cast('2024-06-17T12:00:00', 'Timestamp(Microsecond, Some("+01:00"))'); From 6dffc53e76fb9a785dafd55106dd4033e744bf8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Sat, 22 Jun 2024 00:20:21 +0800 Subject: [PATCH 10/11] Deprecate `OptimizerRule::try_optimize` (#11022) * Deprecate OptimizerRule::try_optimize * optimize_children * Apply review suggestions * Fix clippy lint --- datafusion-examples/examples/rewrite_expr.rs | 48 ++++++++++--------- .../tests/user_defined/user_defined_plan.rs | 44 +++++++++++------ datafusion/optimizer/src/lib.rs | 1 + datafusion/optimizer/src/optimizer.rs | 21 +++++++- datafusion/optimizer/src/utils.rs | 17 +++++-- 5 files changed, 89 insertions(+), 42 deletions(-) diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index d8965888eab6..556687a46ab4 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -18,13 +18,13 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{plan_err, Result, ScalarValue}; +use datafusion_common::{plan_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource, WindowUDF, }; use datafusion_optimizer::analyzer::{Analyzer, AnalyzerRule}; -use datafusion_optimizer::optimizer::Optimizer; -use 
datafusion_optimizer::{utils, OptimizerConfig, OptimizerContext, OptimizerRule};
+use datafusion_optimizer::optimizer::{ApplyOrder, Optimizer};
+use datafusion_optimizer::{OptimizerConfig, OptimizerContext, OptimizerRule};
 use datafusion_sql::planner::{ContextProvider, SqlToRel};
 use datafusion_sql::sqlparser::dialect::PostgreSqlDialect;
 use datafusion_sql::sqlparser::parser::Parser;
@@ -133,30 +133,34 @@ impl OptimizerRule for MyOptimizerRule {

     fn try_optimize(
         &self,
-        plan: &LogicalPlan,
-        config: &dyn OptimizerConfig,
+        _plan: &LogicalPlan,
+        _config: &dyn OptimizerConfig,
     ) -> Result<Option<LogicalPlan>> {
-        // recurse down and optimize children first
-        let optimized_plan = utils::optimize_children(self, plan, config)?;
-        match optimized_plan {
-            Some(LogicalPlan::Filter(filter)) => {
+        unreachable!()
+    }
+
+    fn apply_order(&self) -> Option<ApplyOrder> {
+        Some(ApplyOrder::BottomUp)
+    }
+
+    fn supports_rewrite(&self) -> bool {
+        true
+    }
+
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        _config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>, DataFusionError> {
+        match plan {
+            LogicalPlan::Filter(filter) => {
                 let predicate = my_rewrite(filter.predicate.clone())?;
-                Ok(Some(LogicalPlan::Filter(Filter::try_new(
+                Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new(
                     predicate,
-                    filter.input,
+                    filter.input.clone(),
                 )?)))
             }
-            Some(optimized_plan) => Ok(Some(optimized_plan)),
-            None => match plan {
-                LogicalPlan::Filter(filter) => {
-                    let predicate = my_rewrite(filter.predicate.clone())?;
-                    Ok(Some(LogicalPlan::Filter(Filter::try_new(
-                        predicate,
-                        filter.input.clone(),
-                    )?)))
-                }
-                _ => Ok(None),
-            },
+            _ => Ok(Transformed::no(plan)),
         }
     }
 }
diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs
index 07622e48afaf..ebf907c5e2c0 100644
--- a/datafusion/core/tests/user_defined/user_defined_plan.rs
+++ b/datafusion/core/tests/user_defined/user_defined_plan.rs
@@ -80,7 +80,7 @@ use datafusion::{
         Expr, Extension, Limit, LogicalPlan, Sort, UserDefinedLogicalNode,
         UserDefinedLogicalNodeCore,
     },
-    optimizer::{optimize_children, OptimizerConfig, OptimizerRule},
+    optimizer::{OptimizerConfig, OptimizerRule},
     physical_expr::EquivalenceProperties,
     physical_plan::{
         DisplayAs, DisplayFormatType, Distribution, ExecutionMode, ExecutionPlan,
@@ -92,6 +92,8 @@ use datafusion::{
 };

 use async_trait::async_trait;
+use datafusion_common::tree_node::Transformed;
+use datafusion_optimizer::optimizer::ApplyOrder;
 use futures::{Stream, StreamExt};

 /// Execute the specified sql and return the resulting record batches
@@ -282,9 +284,29 @@ impl OptimizerRule for TopKOptimizerRule {
     // Example rewrite pass to insert a user defined LogicalPlanNode
     fn try_optimize(
         &self,
-        plan: &LogicalPlan,
-        config: &dyn OptimizerConfig,
+        _plan: &LogicalPlan,
+        _config: &dyn OptimizerConfig,
     ) -> Result<Option<LogicalPlan>> {
+        unreachable!()
+    }
+
+    fn name(&self) -> &str {
+        "topk"
+    }
+
+    fn apply_order(&self) -> Option<ApplyOrder> {
+        Some(ApplyOrder::TopDown)
+    }
+
+    fn supports_rewrite(&self) -> bool {
+        true
+    }
+
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        _config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>, DataFusionError> {
         // Note: this code simply looks for the pattern of a Limit followed by a
         // Sort and replaces it by a TopK node. It does not handle many
         // edge cases (e.g multiple sort columns, sort ASC / DESC), etc.
@@ -292,7 +314,7 @@ impl OptimizerRule for TopKOptimizerRule {
             fetch: Some(fetch),
             input,
             ..
-        }) = plan
+        }) = &plan
         {
             if let LogicalPlan::Sort(Sort {
                 ref expr,
                 ref input,
                 ..
             }) = **input
             {
                 if expr.len() == 1 {
                     // we found a sort with a single sort expr, replace with a TopK
-                    return Ok(Some(LogicalPlan::Extension(Extension {
+                    return Ok(Transformed::yes(LogicalPlan::Extension(Extension {
                         node: Arc::new(TopKPlanNode {
                             k: *fetch,
-                            input: self
-                                .try_optimize(input.as_ref(), config)?
-                                .unwrap_or_else(|| input.as_ref().clone()),
+                            input: input.as_ref().clone(),
                             expr: expr[0].clone(),
                         }),
                     })));
@@ -315,13 +335,7 @@ impl OptimizerRule for TopKOptimizerRule {
             }
         }

-        // If we didn't find the Limit/Sort combination, recurse as
-        // normal and build the result.
-        optimize_children(self, plan, config)
-    }
-
-    fn name(&self) -> &str {
-        "topk"
+        Ok(Transformed::no(plan))
     }
 }
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index c172d5979756..a6a9e5cf26ea 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -61,6 +61,7 @@ pub mod test;

 pub use analyzer::{Analyzer, AnalyzerRule};
 pub use optimizer::{Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule};
+#[allow(deprecated)]
 pub use utils::optimize_children;

 pub(crate) mod join_key_set;
diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs
index 75ebc92ece33..7a3ea6ed4cc1 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -77,6 +77,10 @@ pub trait OptimizerRule {
     /// Note this API will be deprecated in the future as it requires `clone`ing
     /// the input plan, which can be expensive. OptimizerRules should implement
     /// [`Self::rewrite`] instead.
+    #[deprecated(
+        since = "40.0.0",
+        note = "please implement supports_rewrite and rewrite instead"
+    )]
     fn try_optimize(
         &self,
         plan: &LogicalPlan,
@@ -332,6 +336,7 @@ fn optimize_plan_node(
         return rule.rewrite(plan, config);
     }

+    #[allow(deprecated)]
     rule.try_optimize(&plan, config).map(|maybe_plan| {
         match maybe_plan {
             Some(new_plan) => {
@@ -483,7 +488,7 @@ mod tests {
     use std::sync::{Arc, Mutex};

     use datafusion_common::tree_node::Transformed;
-    use datafusion_common::{plan_err, DFSchema, DFSchemaRef, Result};
+    use datafusion_common::{plan_err, DFSchema, DFSchemaRef, DataFusionError, Result};
     use datafusion_expr::logical_plan::EmptyRelation;
     use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, Projection};
@@ -667,12 +672,24 @@
         _: &LogicalPlan,
         _: &dyn OptimizerConfig,
     ) -> Result<Option<LogicalPlan>> {
-        plan_err!("rule failed")
+        unreachable!()
     }

     fn name(&self) -> &str {
         "bad rule"
     }
+
+    fn supports_rewrite(&self) -> bool {
+        true
+    }
+
+    fn rewrite(
+        &self,
+        _plan: LogicalPlan,
+        _config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>, DataFusionError> {
+        plan_err!("rule failed")
+    }
 }

 /// Replaces whatever plan with a single table scan
diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs
index fd47cb23b108..6218140409b5 100644
--- a/datafusion/optimizer/src/utils.rs
+++ b/datafusion/optimizer/src/utils.rs
@@ -35,6 +35,10 @@ use log::{debug, trace};
 /// This also handles the case when the `plan` is a [`LogicalPlan::Explain`].
 ///
 /// Returning `Ok(None)` indicates that the plan can't be optimized by the `optimizer`.
+#[deprecated( + since = "40.0.0", + note = "please use OptimizerRule::apply_order with ApplyOrder::BottomUp instead" +)] pub fn optimize_children( optimizer: &impl OptimizerRule, plan: &LogicalPlan, @@ -43,9 +47,16 @@ pub fn optimize_children( let mut new_inputs = Vec::with_capacity(plan.inputs().len()); let mut plan_is_changed = false; for input in plan.inputs() { - let new_input = optimizer.try_optimize(input, config)?; - plan_is_changed = plan_is_changed || new_input.is_some(); - new_inputs.push(new_input.unwrap_or_else(|| input.clone())) + if optimizer.supports_rewrite() { + let new_input = optimizer.rewrite(input.clone(), config)?; + plan_is_changed = plan_is_changed || new_input.transformed; + new_inputs.push(new_input.data); + } else { + #[allow(deprecated)] + let new_input = optimizer.try_optimize(input, config)?; + plan_is_changed = plan_is_changed || new_input.is_some(); + new_inputs.push(new_input.unwrap_or_else(|| input.clone())) + } } if plan_is_changed { let exprs = plan.expressions(); From 098ba30ce5a950f51f56d43a4cdcc192af2bf9e1 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Fri, 21 Jun 2024 19:25:40 +0300 Subject: [PATCH 11/11] Relax combine partial final rule (#10913) * Minor changes * Minor changes * Re-introduce group by expression check --- .../combine_partial_final_agg.rs | 66 +++++------------ .../sqllogictest/test_files/group_by.slt | 71 +++++++++++++++++++ datafusion/sqllogictest/test_files/joins.slt | 23 +++--- 3 files changed, 99 insertions(+), 61 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index b57f36f728d7..940b09131a77 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -27,8 +27,7 @@ use crate::physical_plan::ExecutionPlan; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::{AggregateExpr, PhysicalExpr}; +use datafusion_physical_expr::{physical_exprs_equal, AggregateExpr, PhysicalExpr}; /// CombinePartialFinalAggregate optimizer rule combines the adjacent Partial and Final AggregateExecs /// into a Single AggregateExec if their grouping exprs and aggregate exprs equal. @@ -132,19 +131,23 @@ type GroupExprsRef<'a> = ( &'a [Option>], ); -type GroupExprs = ( - PhysicalGroupBy, - Vec>, - Vec>>, -); - fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool { - let (final_group_by, final_aggr_expr, final_filter_expr) = - normalize_group_exprs(final_agg); - let (input_group_by, input_aggr_expr, input_filter_expr) = - normalize_group_exprs(partial_agg); - - final_group_by.eq(&input_group_by) + let (final_group_by, final_aggr_expr, final_filter_expr) = final_agg; + let (input_group_by, input_aggr_expr, input_filter_expr) = partial_agg; + + // Compare output expressions of the partial, and input expressions of the final operator. 
+ physical_exprs_equal( + &input_group_by.output_exprs(), + &final_group_by.input_exprs(), + ) && input_group_by.groups() == final_group_by.groups() + && input_group_by.null_expr().len() == final_group_by.null_expr().len() + && input_group_by + .null_expr() + .iter() + .zip(final_group_by.null_expr().iter()) + .all(|((lhs_expr, lhs_str), (rhs_expr, rhs_str))| { + lhs_expr.eq(rhs_expr) && lhs_str == rhs_str + }) && final_aggr_expr.len() == input_aggr_expr.len() && final_aggr_expr .iter() @@ -160,41 +163,6 @@ fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool { ) } -// To compare the group expressions between the final and partial aggregations, need to discard all the column indexes and compare -fn normalize_group_exprs(group_exprs: GroupExprsRef) -> GroupExprs { - let (group, agg, filter) = group_exprs; - let new_group_expr = group - .expr() - .iter() - .map(|(expr, name)| (discard_column_index(expr.clone()), name.clone())) - .collect::>(); - let new_group = PhysicalGroupBy::new( - new_group_expr, - group.null_expr().to_vec(), - group.groups().to_vec(), - ); - (new_group, agg.to_vec(), filter.to_vec()) -} - -fn discard_column_index(group_expr: Arc) -> Arc { - group_expr - .clone() - .transform(|expr| { - let normalized_form: Option> = - match expr.as_any().downcast_ref::() { - Some(column) => Some(Arc::new(Column::new(column.name(), 0))), - None => None, - }; - Ok(if let Some(normalized_form) = normalized_form { - Transformed::yes(normalized_form) - } else { - Transformed::no(expr) - }) - }) - .data() - .unwrap_or(group_expr) -} - #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 8ccf3ae85345..fff3977fe1e6 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -5064,3 +5064,74 @@ statement error DataFusion error: Error during planning: Cannot find column with SELECT a, b, COUNT(1) FROM multiple_ordered_table GROUP BY 1, 2, 4, 5, 6; + +statement ok +set datafusion.execution.target_partitions = 1; + +# Create a table that contains various keywords, with their corresponding timestamps +statement ok +CREATE TABLE keywords_stream ( + ts TIMESTAMP, + sn INTEGER PRIMARY KEY, + keyword VARCHAR NOT NULL +); + +statement ok +INSERT INTO keywords_stream(ts, sn, keyword) VALUES +('2024-01-01T00:00:00Z', '0', 'Drug'), +('2024-01-01T00:00:05Z', '1', 'Bomb'), +('2024-01-01T00:00:10Z', '2', 'Theft'), +('2024-01-01T00:00:15Z', '3', 'Gun'), +('2024-01-01T00:00:20Z', '4', 'Calm'); + +# Create a table that contains alert keywords +statement ok +CREATE TABLE ALERT_KEYWORDS(keyword VARCHAR NOT NULL); + +statement ok +INSERT INTO ALERT_KEYWORDS VALUES +('Drug'), +('Bomb'), +('Theft'), +('Gun'), +('Knife'), +('Fire'); + +query TT +explain SELECT + DATE_BIN(INTERVAL '2' MINUTE, ts, '2000-01-01') AS ts_chunk, + COUNT(keyword) AS alert_keyword_count +FROM + keywords_stream +WHERE + keywords_stream.keyword IN (SELECT keyword FROM ALERT_KEYWORDS) +GROUP BY + ts_chunk; +---- +logical_plan +01)Projection: date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01")) AS ts_chunk, COUNT(keywords_stream.keyword) AS alert_keyword_count +02)--Aggregate: groupBy=[[date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"), keywords_stream.ts, TimestampNanosecond(946684800000000000, None)) AS 
date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01"))]], aggr=[[COUNT(keywords_stream.keyword)]] +03)----LeftSemi Join: keywords_stream.keyword = __correlated_sq_1.keyword +04)------TableScan: keywords_stream projection=[ts, keyword] +05)------SubqueryAlias: __correlated_sq_1 +06)--------TableScan: alert_keywords projection=[keyword] +physical_plan +01)ProjectionExec: expr=[date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01"))@0 as ts_chunk, COUNT(keywords_stream.keyword)@1 as alert_keyword_count] +02)--AggregateExec: mode=Single, gby=[date_bin(IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }, ts@0, 946684800000000000) as date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01"))], aggr=[COUNT(keywords_stream.keyword)] +03)----CoalesceBatchesExec: target_batch_size=2 +04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(keyword@0, keyword@1)] +05)--------MemoryExec: partitions=1, partition_sizes=[1] +06)--------MemoryExec: partitions=1, partition_sizes=[1] + +query PI +SELECT + DATE_BIN(INTERVAL '2' MINUTE, ts, '2000-01-01') AS ts_chunk, + COUNT(keyword) AS alert_keyword_count +FROM + keywords_stream +WHERE + keywords_stream.keyword IN (SELECT keyword FROM ALERT_KEYWORDS) +GROUP BY + ts_chunk; +---- +2024-01-01T00:00:00 4 diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 0cecd7935a74..4b62f2561260 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -1382,18 +1382,17 @@ physical_plan 02)--AggregateExec: mode=Final, gby=[], aggr=[COUNT(alias1)] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[COUNT(alias1)] -05)--------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[] -06)----------AggregateExec: mode=Partial, gby=[t1_id@0 as alias1], aggr=[] -07)------------CoalesceBatchesExec: target_batch_size=2 -08)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0] -09)----------------CoalesceBatchesExec: target_batch_size=2 -10)------------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 -11)--------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -12)----------------------MemoryExec: partitions=1, partition_sizes=[1] -13)----------------CoalesceBatchesExec: target_batch_size=2 -14)------------------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -15)--------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -16)----------------------MemoryExec: partitions=1, partition_sizes=[1] +05)--------AggregateExec: mode=SinglePartitioned, gby=[t1_id@0 as alias1], aggr=[] +06)----------CoalesceBatchesExec: target_batch_size=2 +07)------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0] +08)--------------CoalesceBatchesExec: target_batch_size=2 +09)----------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +10)------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +11)--------------------MemoryExec: partitions=1, partition_sizes=[1] +12)--------------CoalesceBatchesExec: 
target_batch_size=2 +13)----------------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +14)------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +15)--------------------MemoryExec: partitions=1, partition_sizes=[1] statement ok set datafusion.explain.logical_plan_only = true;
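
As a quick end-to-end illustration of the `drop_columns` API added in patch 07 above, a minimal sketch; the CSV path and column names are placeholders, and duplicate or nonexistent names are silently ignored, matching the tests in that patch:

use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Placeholder path; any registered table or data source works the same way.
    let df = ctx
        .read_csv("tests/data/example.csv", CsvReadOptions::new())
        .await?;
    // Keep every column except `a` and `b`; duplicates and unknown names are skipped.
    let df = df.drop_columns(&["a", "b", "b", "no_such_column"])?;
    df.show().await?;
    Ok(())
}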
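Similarly, the `OptimizerRule` changes in patch 10 reduce every migrated rule to one shape. A minimal sketch of a no-op rule under the post-#11022 API; the struct and rule name are illustrative, and `try_optimize` is assumed to still be a required (deprecated) method at this point, hence the `unreachable!()` stub mirroring the migrated rules above:

use datafusion_common::tree_node::Transformed;
use datafusion_common::Result;
use datafusion_expr::LogicalPlan;
use datafusion_optimizer::optimizer::ApplyOrder;
use datafusion_optimizer::{OptimizerConfig, OptimizerRule};

#[derive(Default)]
struct NoopRule;

impl OptimizerRule for NoopRule {
    fn name(&self) -> &str {
        "noop_rule"
    }

    #[allow(deprecated)]
    fn try_optimize(
        &self,
        _plan: &LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Option<LogicalPlan>> {
        unreachable!()
    }

    fn supports_rewrite(&self) -> bool {
        true
    }

    // BottomUp recursion replaces the manual `optimize_children` calls
    // that old `try_optimize` implementations needed.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::BottomUp)
    }

    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        // Return Transformed::yes(new_plan) when the rule changes the plan.
        Ok(Transformed::no(plan))
    }
}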