From 230c68c02bf0c3d5b7d50d24145eb50604420d4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marko=20Milenkovi=C4=87?= Date: Mon, 13 May 2024 12:53:56 +0100 Subject: [PATCH 01/11] Add `simplify` method to aggregate function (#10354) * add simplify method for aggregate function * simplify returns closure --- .../examples/simplify_udaf_expression.rs | 180 ++++++++++++++++++ datafusion/expr/src/function.rs | 13 ++ datafusion/expr/src/udaf.rs | 33 +++- .../simplify_expressions/expr_simplifier.rs | 105 +++++++++- 4 files changed, 328 insertions(+), 3 deletions(-) create mode 100644 datafusion-examples/examples/simplify_udaf_expression.rs diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs b/datafusion-examples/examples/simplify_udaf_expression.rs new file mode 100644 index 0000000000000..92deb20272e41 --- /dev/null +++ b/datafusion-examples/examples/simplify_udaf_expression.rs @@ -0,0 +1,180 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::{Field, Schema}; +use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; +use datafusion_expr::function::AggregateFunctionSimplification; +use datafusion_expr::simplify::SimplifyInfo; + +use std::{any::Any, sync::Arc}; + +use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch}; +use datafusion::error::Result; +use datafusion::{assert_batches_eq, prelude::*}; +use datafusion_common::cast::as_float64_array; +use datafusion_expr::{ + expr::{AggregateFunction, AggregateFunctionDefinition}, + function::AccumulatorArgs, + Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature, +}; + +/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user +/// defined aggregate function with a different expression which is defined in the `simplify` method. + +#[derive(Debug, Clone)] +struct BetterAvgUdaf { + signature: Signature, +} + +impl BetterAvgUdaf { + /// Create a new instance of the GeoMeanUdaf struct + fn new() -> Self { + Self { + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for BetterAvgUdaf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "better_avg" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + unimplemented!("should not be invoked") + } + + fn state_fields( + &self, + _name: &str, + _value_type: DataType, + _ordering_fields: Vec, + ) -> Result> { + unimplemented!("should not be invoked") + } + + fn groups_accumulator_supported(&self) -> bool { + true + } + + fn create_groups_accumulator(&self) -> Result> { + unimplemented!("should not get here"); + } + // we override method, to return new expression which would substitute + // user defined function call + fn simplify(&self) -> Option { + // as an example for this functionality we replace UDF function + // with build-in aggregate function to illustrate the use + let simplify = |aggregate_function: datafusion_expr::expr::AggregateFunction, + _: &dyn SimplifyInfo| { + Ok(Expr::AggregateFunction(AggregateFunction { + func_def: AggregateFunctionDefinition::BuiltIn( + // yes it is the same Avg, `BetterAvgUdaf` was just a + // marketing pitch :) + datafusion_expr::aggregate_function::AggregateFunction::Avg, + ), + args: aggregate_function.args, + distinct: aggregate_function.distinct, + filter: aggregate_function.filter, + order_by: aggregate_function.order_by, + null_treatment: aggregate_function.null_treatment, + })) + }; + + Some(Box::new(simplify)) + } +} + +// create local session context with an in-memory table +fn create_context() -> Result { + use datafusion::datasource::MemTable; + // define a schema. + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, false), + Field::new("b", DataType::Float32, false), + ])); + + // define data in two partitions + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])), + Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])), + ], + )?; + let batch2 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Float32Array::from(vec![16.0])), + Arc::new(Float32Array::from(vec![2.0])), + ], + )?; + + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; + ctx.register_table("t", Arc::new(provider))?; + Ok(ctx) +} + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context()?; + + let better_avg = AggregateUDF::from(BetterAvgUdaf::new()); + ctx.register_udaf(better_avg.clone()); + + let result = ctx + .sql("SELECT better_avg(a) FROM t group by b") + .await? + .collect() + .await?; + + let expected = [ + "+-----------------+", + "| better_avg(t.a) |", + "+-----------------+", + "| 7.5 |", + "+-----------------+", + ]; + + assert_batches_eq!(expected, &result); + + let df = ctx.table("t").await?; + let df = df.aggregate(vec![], vec![better_avg.call(vec![col("a")])])?; + + let results = df.collect().await?; + let result = as_float64_array(results[0].column(0))?; + + assert!((result.value(0) - 7.5).abs() < f64::EPSILON); + println!("The average of [2,4,8,16] is {}", result.value(0)); + + Ok(()) +} diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 7a92a50ae15df..4e4d77924a9d9 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -97,3 +97,16 @@ pub type PartitionEvaluatorFactory = /// its state, given its return datatype. pub type StateTypeFunction = Arc Result>> + Send + Sync>; + +/// [crate::udaf::AggregateUDFImpl::simplify] simplifier closure +/// A closure with two arguments: +/// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked +/// * 'info': [crate::simplify::SimplifyInfo] +/// +/// closure returns simplified [Expr] or an error. +pub type AggregateFunctionSimplification = Box< + dyn Fn( + crate::expr::AggregateFunction, + &dyn crate::simplify::SimplifyInfo, + ) -> Result, +>; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index e5a47ddcd8b6a..95121d78e7aa6 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -17,7 +17,7 @@ //! [`AggregateUDF`]: User Defined Aggregate Functions -use crate::function::AccumulatorArgs; +use crate::function::{AccumulatorArgs, AggregateFunctionSimplification}; use crate::groups_accumulator::GroupsAccumulator; use crate::utils::format_state_name; use crate::{Accumulator, Expr}; @@ -199,6 +199,12 @@ impl AggregateUDF { pub fn coerce_types(&self, _args: &[DataType]) -> Result> { not_impl_err!("coerce_types not implemented for {:?} yet", self.name()) } + /// Do the function rewrite + /// + /// See [`AggregateUDFImpl::simplify`] for more details. + pub fn simplify(&self) -> Option { + self.inner.simplify() + } } impl From for AggregateUDF @@ -358,6 +364,31 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn aliases(&self) -> &[String] { &[] } + + /// Optionally apply per-UDaF simplification / rewrite rules. + /// + /// This can be used to apply function specific simplification rules during + /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default + /// implementation does nothing. + /// + /// Note that DataFusion handles simplifying arguments and "constant + /// folding" (replacing a function call with constant arguments such as + /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such + /// optimizations manually for specific UDFs. + /// + /// # Returns + /// + /// [None] if simplify is not defined or, + /// + /// Or, a closure with two arguments: + /// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked + /// * 'info': [crate::simplify::SimplifyInfo] + /// + /// closure returns simplified [Expr] or an error. + /// + fn simplify(&self) -> Option { + None + } } /// AggregateUDF that adds an alias to the underlying function. It is better to diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 5122de4f09a7a..55052542a8bf9 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -32,7 +32,7 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; -use datafusion_expr::expr::{InList, InSubquery}; +use datafusion_expr::expr::{AggregateFunctionDefinition, InList, InSubquery}; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, @@ -1382,6 +1382,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction { + func_def: AggregateFunctionDefinition::UDF(ref udaf), + .. + }) => match (udaf.simplify(), expr) { + (Some(simplify_function), Expr::AggregateFunction(af)) => { + Transformed::yes(simplify_function(af, info)?) + } + (_, expr) => Transformed::no(expr), + }, + // // Rules for Between // @@ -1748,7 +1758,9 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result { #[cfg(test)] mod tests { use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; - use datafusion_expr::{interval_arithmetic::Interval, *}; + use datafusion_expr::{ + function::AggregateFunctionSimplification, interval_arithmetic::Interval, *, + }; use std::{ collections::HashMap, ops::{BitAnd, BitOr, BitXor}, @@ -3698,4 +3710,93 @@ mod tests { assert_eq!(expr, expected); assert_eq!(num_iter, 2); } + #[test] + fn test_simplify_udaf() { + let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify()); + let aggregate_function_expr = + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + udaf.into(), + vec![], + false, + None, + None, + None, + )); + + let expected = col("result_column"); + assert_eq!(simplify(aggregate_function_expr), expected); + + let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_without_simplify()); + let aggregate_function_expr = + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + udaf.into(), + vec![], + false, + None, + None, + None, + )); + + let expected = aggregate_function_expr.clone(); + assert_eq!(simplify(aggregate_function_expr), expected); + } + + /// A Mock UDAF which defines `simplify` to be used in tests + /// related to UDAF simplification + #[derive(Debug, Clone)] + struct SimplifyMockUdaf { + simplify: bool, + } + + impl SimplifyMockUdaf { + /// make simplify method return new expression + fn new_with_simplify() -> Self { + Self { simplify: true } + } + /// make simplify method return no change + fn new_without_simplify() -> Self { + Self { simplify: false } + } + } + + impl AggregateUDFImpl for SimplifyMockUdaf { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "mock_simplify" + } + + fn signature(&self) -> &Signature { + unimplemented!() + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + unimplemented!("not needed for tests") + } + + fn accumulator( + &self, + _acc_args: function::AccumulatorArgs, + ) -> Result> { + unimplemented!("not needed for tests") + } + + fn groups_accumulator_supported(&self) -> bool { + unimplemented!("not needed for testing") + } + + fn create_groups_accumulator(&self) -> Result> { + unimplemented!("not needed for testing") + } + + fn simplify(&self) -> Option { + if self.simplify { + Some(Box::new(|_, _| Ok(col("result_column")))) + } else { + None + } + } + } } From 5fac581efbaffd0e6a9edf931182517524526afd Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 13 May 2024 05:14:19 -0700 Subject: [PATCH 02/11] Add cast array test to sqllogictest (#10474) --- datafusion/sqllogictest/test_files/cast.slt | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/datafusion/sqllogictest/test_files/cast.slt b/datafusion/sqllogictest/test_files/cast.slt index 73862be60d9b9..4554c9292b6e6 100644 --- a/datafusion/sqllogictest/test_files/cast.slt +++ b/datafusion/sqllogictest/test_files/cast.slt @@ -56,3 +56,16 @@ query I SELECT 10::bigint unsigned ---- 10 + +# cast array +query ? +SELECT CAST(MAKE_ARRAY(1, 2, 3) AS VARCHAR[]) +---- +[1, 2, 3] + + +# cast empty array +query ? +SELECT CAST(MAKE_ARRAY() AS VARCHAR[]) +---- +[] From 53de994423fc85f655da232db7e807c2a38276ea Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 13 May 2024 09:10:13 -0400 Subject: [PATCH 03/11] Add Expr::try_as_col, deprecate Expr::try_into_col (#10448) --- datafusion/expr/src/expr.rs | 23 +++++++++++++++++++ datafusion/expr/src/expr_rewriter/mod.rs | 2 +- datafusion/expr/src/logical_plan/builder.rs | 10 +++++--- datafusion/expr/src/logical_plan/plan.rs | 14 +++++++++-- datafusion/optimizer/src/push_down_filter.rs | 4 ++-- .../simplify_expressions/inlist_simplifier.rs | 2 +- datafusion/proto/src/logical_plan/mod.rs | 12 +++++++--- 7 files changed, 55 insertions(+), 12 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index f0f41a4c55c5d..660a45c27a296 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1264,6 +1264,7 @@ impl Expr { }) } + #[deprecated(since = "39.0.0", note = "use try_as_col instead")] pub fn try_into_col(&self) -> Result { match self { Expr::Column(it) => Ok(it.clone()), @@ -1271,6 +1272,28 @@ impl Expr { } } + /// Return a reference to the inner `Column` if any + /// + /// returns `None` if the expression is not a `Column` + /// + /// Example + /// ``` + /// # use datafusion_common::Column; + /// use datafusion_expr::{col, Expr}; + /// let expr = col("foo"); + /// assert_eq!(expr.try_as_col(), Some(&Column::from("foo"))); + /// + /// let expr = col("foo").alias("bar"); + /// assert_eq!(expr.try_as_col(), None); + /// ``` + pub fn try_as_col(&self) -> Option<&Column> { + if let Expr::Column(it) = self { + Some(it) + } else { + None + } + } + /// Return all referenced columns of this expression. pub fn to_columns(&self) -> Result> { let mut using_columns = HashSet::new(); diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 700dd560ec0b4..1441374bdba3d 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -221,7 +221,7 @@ pub fn coerce_plan_expr_for_schema( let exprs: Vec = plan.schema().iter().map(Expr::from).collect(); let new_exprs = coerce_exprs_for_schema(exprs, plan.schema(), schema)?; - let add_project = new_exprs.iter().any(|expr| expr.try_into_col().is_err()); + let add_project = new_exprs.iter().any(|expr| expr.try_as_col().is_none()); if add_project { let projection = Projection::try_new(new_exprs, Arc::new(plan.clone()))?; Ok(LogicalPlan::Projection(projection)) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 3f15b84784f16..2c6cfd8f9d204 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1489,7 +1489,7 @@ pub fn wrap_projection_for_join_if_necessary( let mut projection = expand_wildcard(input_schema, &input, None)?; let join_key_items = alias_join_keys .iter() - .flat_map(|expr| expr.try_into_col().is_err().then_some(expr)) + .flat_map(|expr| expr.try_as_col().is_none().then_some(expr)) .cloned() .collect::>(); projection.extend(join_key_items); @@ -1504,8 +1504,12 @@ pub fn wrap_projection_for_join_if_necessary( let join_on = alias_join_keys .into_iter() .map(|key| { - key.try_into_col() - .or_else(|_| Ok(Column::from_name(key.display_name()?))) + if let Some(col) = key.try_as_col() { + Ok(col.clone()) + } else { + let name = key.display_name()?; + Ok(Column::from_name(name)) + } }) .collect::>>()?; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9832b69f841a9..266e7abc341a1 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -369,8 +369,18 @@ impl LogicalPlan { // The join keys in using-join must be columns. let columns = on.iter().try_fold(HashSet::new(), |mut accumu, (l, r)| { - accumu.insert(l.try_into_col()?); - accumu.insert(r.try_into_col()?); + let Some(l) = l.try_as_col().cloned() else { + return internal_err!( + "Invalid join key. Expected column, found {l:?}" + ); + }; + let Some(r) = r.try_as_col().cloned() else { + return internal_err!( + "Invalid join key. Expected column, found {r:?}" + ); + }; + accumu.insert(l); + accumu.insert(r); Result::<_, DataFusionError>::Ok(accumu) })?; using_columns.push(columns); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 9ce135b0d6464..57b38bd0d0fd0 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -535,8 +535,8 @@ fn push_down_join( .on .iter() .filter_map(|(l, r)| { - let left_col = l.try_into_col().ok()?; - let right_col = r.try_into_col().ok()?; + let left_col = l.try_as_col().cloned()?; + let right_col = r.try_as_col().cloned()?; Some((left_col, right_col)) }) .collect::>(); diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index 9dcb8ed15563a..c8638eb723955 100644 --- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -52,7 +52,7 @@ impl TreeNodeRewriter for ShortenInListSimplifier { // expressions list.len() == 1 || list.len() <= THRESHOLD_INLINE_INLIST - && expr.try_into_col().is_ok() + && expr.try_as_col().is_some() ) { let first_val = list[0].clone(); diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index a6352bcefc3eb..83e58c3a22ccc 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -45,8 +45,9 @@ use datafusion::{ prelude::SessionContext, }; use datafusion_common::{ - context, internal_err, not_impl_err, parsers::CompressionTypeVariant, - plan_datafusion_err, DataFusionError, Result, TableReference, + context, internal_datafusion_err, internal_err, not_impl_err, + parsers::CompressionTypeVariant, plan_datafusion_err, DataFusionError, Result, + TableReference, }; use datafusion_expr::{ dml, @@ -695,7 +696,12 @@ impl AsLogicalPlan for LogicalPlanNode { // The equijoin keys in using-join must be column. let using_keys = left_keys .into_iter() - .map(|key| key.try_into_col()) + .map(|key| { + key.try_as_col().cloned() + .ok_or_else(|| internal_datafusion_err!( + "Using join keys must be column references, got: {key:?}" + )) + }) .collect::, _>>()?; builder.join_using( into_logical_plan!(join.right, ctx, extension_codec)?, From 3491f6bd5003624dc064db410eeaa41ef3f86acf Mon Sep 17 00:00:00 2001 From: Abrar Khan Date: Mon, 13 May 2024 18:53:59 +0530 Subject: [PATCH 04/11] Implement `From>` for `LogicalPlanBuilder` (#10466) * implement From> for LogicalPlanBuilder * make fmt happy * added test case and doc comment --- datafusion/expr/src/logical_plan/builder.rs | 47 ++++++++++++++++++++- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 2c6cfd8f9d204..6055537ac5118 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -42,8 +42,9 @@ use crate::utils::{ expand_wildcard, find_valid_equijoin_key_pair, group_window_expr_by_sort_keys, }; use crate::{ - and, binary_expr, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery, - TableProviderFilterPushDown, TableSource, WriteOp, + and, binary_expr, logical_plan::tree_node::unwrap_arc, DmlStatement, Expr, + ExprSchemable, Operator, RecursiveQuery, TableProviderFilterPushDown, TableSource, + WriteOp, }; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; @@ -1138,6 +1139,31 @@ impl LogicalPlanBuilder { )?)) } } + +/// Converts a `Arc` into `LogicalPlanBuilder` +/// fn employee_schema() -> Schema { +/// Schema::new(vec![ +/// Field::new("id", DataType::Int32, false), +/// Field::new("first_name", DataType::Utf8, false), +/// Field::new("last_name", DataType::Utf8, false), +/// Field::new("state", DataType::Utf8, false), +/// Field::new("salary", DataType::Int32, false), +/// ]) +/// } +/// let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))? +/// .sort(vec![ +/// Expr::Sort(expr::Sort::new(Box::new(col("state")), true, true)), +/// Expr::Sort(expr::Sort::new(Box::new(col("salary")), false, false)), +/// ])? +/// .build()?; +/// let plan_builder: LogicalPlanBuilder = Arc::new(plan).into(); + +impl From> for LogicalPlanBuilder { + fn from(plan: Arc) -> Self { + LogicalPlanBuilder::from(unwrap_arc(plan)) + } +} + pub fn change_redundant_column(fields: &Fields) -> Vec { let mut name_map = HashMap::new(); fields @@ -2144,4 +2170,21 @@ mod tests { ); Ok(()) } + + #[test] + fn plan_builder_from_logical_plan() -> Result<()> { + let plan = + table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))? + .sort(vec![ + Expr::Sort(expr::Sort::new(Box::new(col("state")), true, true)), + Expr::Sort(expr::Sort::new(Box::new(col("salary")), false, false)), + ])? + .build()?; + + let plan_expected = format!("{plan:?}"); + let plan_builder: LogicalPlanBuilder = Arc::new(plan).into(); + assert_eq!(plan_expected, format!("{:?}", plan_builder.plan)); + + Ok(()) + } } From c7dbfeb79a0f41b6098184de33499546697ef631 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 13 May 2024 10:49:27 -0400 Subject: [PATCH 05/11] Minor: Improve documentation for `catalog.has_header` config option (#10452) * Minor: document catalog.has_header better * update docs * update test --- datafusion/common/src/config.rs | 3 ++- datafusion/sqllogictest/test_files/information_schema.slt | 2 +- docs/source/user-guide/configs.md | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index c60f843393f89..0f1d9b8f02644 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -181,7 +181,8 @@ config_namespace! { /// Type of `TableProvider` to use when loading `default` schema pub format: Option, default = None - /// If the file has a header + /// Default value for `format.has_header` for `CREATE EXTERNAL TABLE` + /// if not specified explicitly in the statement. pub has_header: bool, default = false } } diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index de00cf9d05473..6f31973fdb6fb 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -246,7 +246,7 @@ datafusion.catalog.create_default_catalog_and_schema true Whether the default ca datafusion.catalog.default_catalog datafusion The default catalog name - this impacts what SQL queries use if not specified datafusion.catalog.default_schema public The default schema name - this impacts what SQL queries use if not specified datafusion.catalog.format NULL Type of `TableProvider` to use when loading `default` schema -datafusion.catalog.has_header false If the file has a header +datafusion.catalog.has_header false Default value for `format.has_header` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. datafusion.catalog.information_schema true Should DataFusion provide access to `information_schema` virtual tables for displaying schema information datafusion.catalog.location NULL Location scanned to load tables for `default` schema datafusion.execution.aggregate.scalar_update_factor 10 Specifies the threshold for using `ScalarValue`s to update accumulators during high-cardinality aggregations for each input batch. The aggregation is considered high-cardinality if the number of affected groups is greater than or equal to `batch_size / scalar_update_factor`. In such cases, `ScalarValue`s are utilized for updating accumulators, rather than the default batch-slice approach. This can lead to performance improvements. By adjusting the `scalar_update_factor`, you can balance the trade-off between more efficient accumulator updates and the number of groups affected. diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index ef2a2a4119e33..0cfd81eff75a8 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -43,7 +43,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.catalog.information_schema | false | Should DataFusion provide access to `information_schema` virtual tables for displaying schema information | | datafusion.catalog.location | NULL | Location scanned to load tables for `default` schema | | datafusion.catalog.format | NULL | Type of `TableProvider` to use when loading `default` schema | -| datafusion.catalog.has_header | false | If the file has a header | +| datafusion.catalog.has_header | false | Default value for `format.has_header` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. | | datafusion.execution.batch_size | 8192 | Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption | | datafusion.execution.coalesce_batches | true | When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting | | datafusion.execution.collect_statistics | false | Should DataFusion collect statistics after listing files | From 9cc981b06115ee40b53384c287689ce0e07950bc Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 13 May 2024 10:49:51 -0400 Subject: [PATCH 06/11] Minor: Simplify conjunction and disjunction, improve docs (#10446) --- datafusion/expr/src/utils.rs | 37 ++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 0c1084674d8e0..43e8ff7b23d64 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1107,7 +1107,7 @@ fn split_binary_impl<'a>( /// assert_eq!(conjunction(split), Some(expr)); /// ``` pub fn conjunction(filters: impl IntoIterator) -> Option { - filters.into_iter().reduce(|accum, expr| accum.and(expr)) + filters.into_iter().reduce(Expr::and) } /// Combines an array of filter expressions into a single filter @@ -1115,12 +1115,41 @@ pub fn conjunction(filters: impl IntoIterator) -> Option { /// logical OR. /// /// Returns None if the filters array is empty. +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit}; +/// # use datafusion_expr::utils::disjunction; +/// // a=1 OR b=2 +/// let expr = col("a").eq(lit(1)).or(col("b").eq(lit(2))); +/// +/// // [a=1, b=2] +/// let split = vec![ +/// col("a").eq(lit(1)), +/// col("b").eq(lit(2)), +/// ]; +/// +/// // use disjuncton to join them together with `OR` +/// assert_eq!(disjunction(split), Some(expr)); +/// ``` pub fn disjunction(filters: impl IntoIterator) -> Option { - filters.into_iter().reduce(|accum, expr| accum.or(expr)) + filters.into_iter().reduce(Expr::or) } -/// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with -/// its predicate be all `predicates` ANDed. +/// Returns a new [LogicalPlan] that filters the output of `plan` with a +/// [LogicalPlan::Filter] with all `predicates` ANDed. +/// +/// # Example +/// Before: +/// ```text +/// plan +/// ``` +/// +/// After: +/// ```text +/// Filter(predicate) +/// plan +/// ``` pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result { // reduce filters to a single filter with an AND let predicate = predicates From a2eca291ad9d586222f042ab4c068feeb055526b Mon Sep 17 00:00:00 2001 From: ClSlaid Date: Tue, 14 May 2024 00:03:25 +0800 Subject: [PATCH 07/11] Stop copying LogicalPlan and Exprs in `ReplaceDistinctWithAggregate` (#10460) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * patch: implement rewrite for RDWA Signed-off-by: cailue * refactor: rewrite replace_distinct_aggregate Signed-off-by: 蔡略 * patch: recorrect aggr_expr Signed-off-by: 蔡略 * Update datafusion/optimizer/src/replace_distinct_aggregate.rs --------- Signed-off-by: cailue Signed-off-by: 蔡略 Co-authored-by: Andrew Lamb --- .../src/replace_distinct_aggregate.rs | 73 +++++++++++++------ 1 file changed, 49 insertions(+), 24 deletions(-) diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 4f68e2623f403..404f054cb9fa9 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -19,7 +19,9 @@ use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp}; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{Column, Result}; +use datafusion_common::tree_node::Transformed; +use datafusion_common::{internal_err, Column, Result}; +use datafusion_expr::expr_rewriter::normalize_cols; use datafusion_expr::utils::expand_wildcard; use datafusion_expr::{ aggregate_function::AggregateFunction as AggregateFunctionFunc, col, @@ -66,20 +68,24 @@ impl ReplaceDistinctWithAggregate { } impl OptimizerRule for ReplaceDistinctWithAggregate { - fn try_optimize( + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( &self, - plan: &LogicalPlan, + plan: LogicalPlan, _config: &dyn OptimizerConfig, - ) -> Result> { + ) -> Result> { match plan { LogicalPlan::Distinct(Distinct::All(input)) => { - let group_expr = expand_wildcard(input.schema(), input, None)?; - let aggregate = LogicalPlan::Aggregate(Aggregate::try_new( - input.clone(), + let group_expr = expand_wildcard(input.schema(), &input, None)?; + let aggr_plan = LogicalPlan::Aggregate(Aggregate::try_new( + input, group_expr, vec![], )?); - Ok(Some(aggregate)) + Ok(Transformed::yes(aggr_plan)) } LogicalPlan::Distinct(Distinct::On(DistinctOn { select_expr, @@ -88,13 +94,15 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { input, schema, })) => { + let expr_cnt = on_expr.len(); + // Construct the aggregation expression to be used to fetch the selected expressions. let aggr_expr = select_expr - .iter() + .into_iter() .map(|e| { Expr::AggregateFunction(AggregateFunction::new( AggregateFunctionFunc::FirstValue, - vec![e.clone()], + vec![e], false, None, sort_expr.clone(), @@ -103,45 +111,62 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { }) .collect::>(); + let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?; + let group_expr = normalize_cols(on_expr, input.as_ref())?; + // Build the aggregation plan - let plan = LogicalPlanBuilder::from(input.as_ref().clone()) - .aggregate(on_expr.clone(), aggr_expr.to_vec())? - .build()?; + let plan = LogicalPlan::Aggregate(Aggregate::try_new( + input, group_expr, aggr_expr, + )?); + // TODO use LogicalPlanBuilder directly rather than recreating the Aggregate + // when https://github.com/apache/datafusion/issues/10485 is available + let lpb = LogicalPlanBuilder::from(plan); - let plan = if let Some(sort_expr) = sort_expr { + let plan = if let Some(mut sort_expr) = sort_expr { // While sort expressions were used in the `FIRST_VALUE` aggregation itself above, // this on it's own isn't enough to guarantee the proper output order of the grouping // (`ON`) expression, so we need to sort those as well. - LogicalPlanBuilder::from(plan) - .sort(sort_expr[..on_expr.len()].to_vec())? - .build()? + + // truncate the sort_expr to the length of on_expr + sort_expr.truncate(expr_cnt); + + lpb.sort(sort_expr)?.build()? } else { - plan + lpb.build()? }; // Whereas the aggregation plan by default outputs both the grouping and the aggregation // expressions, for `DISTINCT ON` we only need to emit the original selection expressions. + let project_exprs = plan .schema() .iter() - .skip(on_expr.len()) + .skip(expr_cnt) .zip(schema.iter()) .map(|((new_qualifier, new_field), (old_qualifier, old_field))| { - Ok(col(Column::from((new_qualifier, new_field))) - .alias_qualified(old_qualifier.cloned(), old_field.name())) + col(Column::from((new_qualifier, new_field))) + .alias_qualified(old_qualifier.cloned(), old_field.name()) }) - .collect::>>()?; + .collect::>(); let plan = LogicalPlanBuilder::from(plan) .project(project_exprs)? .build()?; - Ok(Some(plan)) + Ok(Transformed::yes(plan)) } - _ => Ok(None), + _ => Ok(Transformed::no(plan)), } } + fn try_optimize( + &self, + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + internal_err!("Should have called ReplaceDistinctWithAggregate::rewrite") + } + fn name(&self) -> &str { "replace_distinct_aggregate" } From adf0bfc757d2f9ba48c45d368578d07806858b89 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 13 May 2024 14:00:35 -0400 Subject: [PATCH 08/11] Stop copying LogicalPlan and Exprs in `EliminateCrossJoin` (4% faster planning) (#10431) * Stop copying LogicalPlan and Exprs in `EliminateCrossJoin` * Clarify when can_flatten_join_inputs runs * Use a single `map` --- .../optimizer/src/eliminate_cross_join.rs | 298 +++++++++++------- datafusion/optimizer/src/join_key_set.rs | 73 ++++- 2 files changed, 254 insertions(+), 117 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 923be75748037..9d871c50ad996 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -18,11 +18,13 @@ //! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available. use std::sync::Arc; -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::{OptimizerConfig, OptimizerRule}; use crate::join_key_set::JoinKeySet; -use datafusion_common::{plan_err, Result}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{internal_err, Result}; use datafusion_expr::expr::{BinaryExpr, Expr}; +use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::logical_plan::{ CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, }; @@ -39,65 +41,109 @@ impl EliminateCrossJoin { } } -/// Attempt to reorder join to eliminate cross joins to inner joins. -/// for queries: -/// 'select ... from a, b where a.x = b.y and b.xx = 100;' -/// 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);' -/// 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z) -/// or (a.x = b.y and b.xx = 200 and a.z=c.z);' -/// 'select ... from a, b where a.x > b.y' +/// Eliminate cross joins by rewriting them to inner joins when possible. +/// +/// # Example +/// The initial plan for this query: +/// ```sql +/// select ... from a, b where a.x = b.y and b.xx = 100; +/// ``` +/// +/// Looks like this: +/// ```text +/// Filter(a.x = b.y AND b.xx = 100) +/// CrossJoin +/// TableScan a +/// TableScan b +/// ``` +/// +/// After the rule is applied, the plan will look like this: +/// ```text +/// Filter(b.xx = 100) +/// InnerJoin(a.x = b.y) +/// TableScan a +/// TableScan b +/// ``` +/// +/// # Other Examples +/// * 'select ... from a, b where a.x = b.y and b.xx = 100;' +/// * 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);' +/// * 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z) +/// * or (a.x = b.y and b.xx = 200 and a.z=c.z);' +/// * 'select ... from a, b where a.x > b.y' +/// /// For above queries, the join predicate is available in filters and they are moved to /// join nodes appropriately +/// /// This fix helps to improve the performance of TPCH Q19. issue#78 impl OptimizerRule for EliminateCrossJoin { fn try_optimize( &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, ) -> Result> { + internal_err!("Should have called EliminateCrossJoin::rewrite") + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + let plan_schema = plan.schema().clone(); let mut possible_join_keys = JoinKeySet::new(); let mut all_inputs: Vec = vec![]; - let parent_predicate = match plan { - LogicalPlan::Filter(filter) => { - let input = filter.input.as_ref(); - match input { - LogicalPlan::Join(Join { - join_type: JoinType::Inner, - .. - }) - | LogicalPlan::CrossJoin(_) => { - if !try_flatten_join_inputs( - input, - &mut possible_join_keys, - &mut all_inputs, - )? { - return Ok(None); - } - extract_possible_join_keys( - &filter.predicate, - &mut possible_join_keys, - ); - Some(&filter.predicate) - } - _ => { - return utils::optimize_children(self, plan, config); - } - } + + let parent_predicate = if let LogicalPlan::Filter(filter) = plan { + // if input isn't a join that can potentially be rewritten + // avoid unwrapping the input + let rewriteable = matches!( + filter.input.as_ref(), + LogicalPlan::Join(Join { + join_type: JoinType::Inner, + .. + }) | LogicalPlan::CrossJoin(_) + ); + + if !rewriteable { + // recursively try to rewrite children + return rewrite_children(self, LogicalPlan::Filter(filter), config); } + + if !can_flatten_join_inputs(&filter.input) { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } + + let Filter { + input, predicate, .. + } = filter; + flatten_join_inputs( + unwrap_arc(input), + &mut possible_join_keys, + &mut all_inputs, + )?; + + extract_possible_join_keys(&predicate, &mut possible_join_keys); + Some(predicate) + } else if matches!( + plan, LogicalPlan::Join(Join { join_type: JoinType::Inner, .. - }) => { - if !try_flatten_join_inputs( - plan, - &mut possible_join_keys, - &mut all_inputs, - )? { - return Ok(None); - } - None + }) + ) { + if !can_flatten_join_inputs(&plan) { + return Ok(Transformed::no(plan)); } - _ => return utils::optimize_children(self, plan, config), + flatten_join_inputs(plan, &mut possible_join_keys, &mut all_inputs)?; + None + } else { + // recursively try to rewrite children + return rewrite_children(self, plan, config); }; // Join keys are handled locally: @@ -105,36 +151,36 @@ impl OptimizerRule for EliminateCrossJoin { let mut left = all_inputs.remove(0); while !all_inputs.is_empty() { left = find_inner_join( - &left, + left, &mut all_inputs, &possible_join_keys, &mut all_join_keys, )?; } - left = utils::optimize_children(self, &left, config)?.unwrap_or(left); + left = rewrite_children(self, left, config)?.data; - if plan.schema() != left.schema() { + if &plan_schema != left.schema() { left = LogicalPlan::Projection(Projection::new_from_schema( Arc::new(left), - plan.schema().clone(), + plan_schema.clone(), )); } let Some(predicate) = parent_predicate else { - return Ok(Some(left)); + return Ok(Transformed::yes(left)); }; // If there are no join keys then do nothing: if all_join_keys.is_empty() { - Filter::try_new(predicate.clone(), Arc::new(left)) - .map(|f| Some(LogicalPlan::Filter(f))) + Filter::try_new(predicate, Arc::new(left)) + .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))) } else { // Remove join expressions from filter: - match remove_join_expressions(predicate.clone(), &all_join_keys) { + match remove_join_expressions(predicate, &all_join_keys) { Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left)) - .map(|f| Some(LogicalPlan::Filter(f))), - _ => Ok(Some(left)), + .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))), + _ => Ok(Transformed::yes(left)), } } } @@ -144,49 +190,89 @@ impl OptimizerRule for EliminateCrossJoin { } } +fn rewrite_children( + optimizer: &impl OptimizerRule, + plan: LogicalPlan, + config: &dyn OptimizerConfig, +) -> Result> { + let transformed_plan = plan.map_children(|input| optimizer.rewrite(input, config))?; + + // recompute schema if the plan was transformed + if transformed_plan.transformed { + transformed_plan.map_data(|plan| plan.recompute_schema()) + } else { + Ok(transformed_plan) + } +} + /// Recursively accumulate possible_join_keys and inputs from inner joins /// (including cross joins). /// -/// Returns a boolean indicating whether the flattening was successful. -fn try_flatten_join_inputs( - plan: &LogicalPlan, +/// Assumes can_flatten_join_inputs has returned true and thus the plan can be +/// flattened. Adds all leaf inputs to `all_inputs` and join_keys to +/// possible_join_keys +fn flatten_join_inputs( + plan: LogicalPlan, possible_join_keys: &mut JoinKeySet, all_inputs: &mut Vec, -) -> Result { - let children = match plan { +) -> Result<()> { + match plan { LogicalPlan::Join(join) if join.join_type == JoinType::Inner => { + // checked in can_flatten_join_inputs if join.filter.is_some() { - // The filter of inner join will lost, skip this rule. - // issue: https://github.com/apache/datafusion/issues/4844 - return Ok(false); + return internal_err!( + "should not have filter in inner join in flatten_join_inputs" + ); } - possible_join_keys.insert_all(join.on.iter()); - vec![&join.left, &join.right] + possible_join_keys.insert_all_owned(join.on); + flatten_join_inputs(unwrap_arc(join.left), possible_join_keys, all_inputs)?; + flatten_join_inputs(unwrap_arc(join.right), possible_join_keys, all_inputs)?; } LogicalPlan::CrossJoin(join) => { - vec![&join.left, &join.right] + flatten_join_inputs(unwrap_arc(join.left), possible_join_keys, all_inputs)?; + flatten_join_inputs(unwrap_arc(join.right), possible_join_keys, all_inputs)?; } _ => { - return plan_err!("flatten_join_inputs just can call join/cross_join"); + all_inputs.push(plan); } }; + Ok(()) +} - for child in children.iter() { - let child = child.as_ref(); +/// Returns true if the plan is a Join or Cross join could be flattened with +/// `flatten_join_inputs` +/// +/// Must stay in sync with `flatten_join_inputs` +fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool { + // can only flatten inner / cross joins + match plan { + LogicalPlan::Join(join) if join.join_type == JoinType::Inner => { + // The filter of inner join will lost, skip this rule. + // issue: https://github.com/apache/datafusion/issues/4844 + if join.filter.is_some() { + return false; + } + } + LogicalPlan::CrossJoin(_) => {} + _ => return false, + }; + + for child in plan.inputs() { match child { LogicalPlan::Join(Join { join_type: JoinType::Inner, .. }) | LogicalPlan::CrossJoin(_) => { - if !try_flatten_join_inputs(child, possible_join_keys, all_inputs)? { - return Ok(false); + if !can_flatten_join_inputs(child) { + return false; } } - _ => all_inputs.push(child.clone()), + // the child is not a join/cross join + _ => (), } } - Ok(true) + true } /// Finds the next to join with the left input plan, @@ -202,7 +288,7 @@ fn try_flatten_join_inputs( /// 1. Removes the first plan from `rights` /// 2. Returns `left_input CROSS JOIN right`. fn find_inner_join( - left_input: &LogicalPlan, + left_input: LogicalPlan, rights: &mut Vec, possible_join_keys: &JoinKeySet, all_join_keys: &mut JoinKeySet, @@ -237,7 +323,7 @@ fn find_inner_join( )?); return Ok(LogicalPlan::Join(Join { - left: Arc::new(left_input.clone()), + left: Arc::new(left_input), right: Arc::new(right_input), join_type: JoinType::Inner, join_constraint: JoinConstraint::On, @@ -259,7 +345,7 @@ fn find_inner_join( )?); Ok(LogicalPlan::CrossJoin(CrossJoin { - left: Arc::new(left_input.clone()), + left: Arc::new(left_input), right: Arc::new(right), schema: join_schema, })) @@ -341,12 +427,12 @@ mod tests { Operator::{And, Or}, }; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: Vec<&str>) { + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: Vec<&str>) { + let starting_schema = plan.schema().clone(); let rule = EliminateCrossJoin::new(); - let optimized_plan = rule - .try_optimize(plan, &OptimizerContext::new()) - .unwrap() - .expect("failed to optimize plan"); + let transformed_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); + assert!(transformed_plan.transformed, "failed to optimize plan"); + let optimized_plan = transformed_plan.data; let formatted = optimized_plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -355,13 +441,13 @@ mod tests { "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" ); - assert_eq!(plan.schema(), optimized_plan.schema()) + assert_eq!(&starting_schema, optimized_plan.schema()) } - fn assert_optimization_rule_fails(plan: &LogicalPlan) { + fn assert_optimization_rule_fails(plan: LogicalPlan) { let rule = EliminateCrossJoin::new(); - let optimized_plan = rule.try_optimize(plan, &OptimizerContext::new()).unwrap(); - assert!(optimized_plan.is_none()); + let transformed_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); + assert!(!transformed_plan.transformed) } #[test] @@ -386,7 +472,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -414,7 +500,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -441,7 +527,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -471,7 +557,7 @@ mod tests { " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -501,7 +587,7 @@ mod tests { " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -527,7 +613,7 @@ mod tests { " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -551,7 +637,7 @@ mod tests { .filter(col("t1.a").gt(lit(15u32)))? .build()?; - assert_optimization_rule_fails(&plan); + assert_optimization_rule_fails(plan); Ok(()) } @@ -598,7 +684,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -675,7 +761,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -750,7 +836,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -825,7 +911,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -904,7 +990,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -987,7 +1073,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1074,7 +1160,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1100,7 +1186,7 @@ mod tests { " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1128,7 +1214,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1156,7 +1242,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1184,7 +1270,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1224,7 +1310,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } diff --git a/datafusion/optimizer/src/join_key_set.rs b/datafusion/optimizer/src/join_key_set.rs index c47afa012c174..cd8ed382f0690 100644 --- a/datafusion/optimizer/src/join_key_set.rs +++ b/datafusion/optimizer/src/join_key_set.rs @@ -66,20 +66,46 @@ impl JoinKeySet { } } + /// Same as [`Self::insert`] but avoids cloning expression if they + /// are owned + pub fn insert_owned(&mut self, left: Expr, right: Expr) -> bool { + if self.contains(&left, &right) { + false + } else { + self.inner.insert((left, right)); + true + } + } + /// Inserts potentially many join keys into the set, copying only when necessary /// /// returns true if any of the pairs were inserted pub fn insert_all<'a>( &mut self, - iter: impl Iterator, + iter: impl IntoIterator, ) -> bool { let mut inserted = false; - for (left, right) in iter { + for (left, right) in iter.into_iter() { inserted |= self.insert(left, right); } inserted } + /// Same as [`Self::insert_all`] but avoids cloning expressions if they are + /// already owned + /// + /// returns true if any of the pairs were inserted + pub fn insert_all_owned( + &mut self, + iter: impl IntoIterator, + ) -> bool { + let mut inserted = false; + for (left, right) in iter.into_iter() { + inserted |= self.insert_owned(left, right); + } + inserted + } + /// Inserts any join keys that are common to both `s1` and `s2` into self pub fn insert_intersection(&mut self, s1: JoinKeySet, s2: JoinKeySet) { // note can't use inner.intersection as we need to consider both (l, r) @@ -156,6 +182,15 @@ mod test { assert_eq!(set.len(), 2); } + #[test] + fn test_insert_owned() { + let mut set = JoinKeySet::new(); + assert!(set.insert_owned(col("a"), col("b"))); + assert!(set.contains(&col("a"), &col("b"))); + assert!(set.contains(&col("b"), &col("a"))); + assert!(!set.contains(&col("a"), &col("c"))); + } + #[test] fn test_contains() { let mut set = JoinKeySet::new(); @@ -217,18 +252,34 @@ mod test { } #[test] - fn test_insert_many() { + fn test_insert_all() { let mut set = JoinKeySet::new(); // insert (a=b), (b=c), (b=a) - set.insert_all( - vec![ - &(col("a"), col("b")), - &(col("b"), col("c")), - &(col("b"), col("a")), - ] - .into_iter(), - ); + set.insert_all(vec![ + &(col("a"), col("b")), + &(col("b"), col("c")), + &(col("b"), col("a")), + ]); + assert_eq!(set.len(), 2); + assert!(set.contains(&col("a"), &col("b"))); + assert!(set.contains(&col("b"), &col("c"))); + assert!(set.contains(&col("b"), &col("a"))); + + // should not contain (a=c) + assert!(!set.contains(&col("a"), &col("c"))); + } + + #[test] + fn test_insert_all_owned() { + let mut set = JoinKeySet::new(); + + // insert (a=b), (b=c), (b=a) + set.insert_all_owned(vec![ + (col("a"), col("b")), + (col("b"), col("c")), + (col("b"), col("a")), + ]); assert_eq!(set.len(), 2); assert!(set.contains(&col("a"), &col("b"))); assert!(set.contains(&col("b"), &col("c"))); From 5b74c2d1f8923b8f4f7cf7a660459a80bd947790 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Mon, 13 May 2024 21:18:29 +0300 Subject: [PATCH 09/11] Improved ergonomy for `CREATE EXTERNAL TABLE OPTIONS`: Don't require quotations for simple namespaced keys like `foo.bar` (#10483) * Don't require quotations for simple namespaced keys like foo.bar * Add comments clarifying parse error cases for unquoted namespaced keys --- datafusion/common/src/config.rs | 65 ++++++++----------- datafusion/core/src/execution/context/mod.rs | 24 ++++--- .../tests/cases/roundtrip_logical_plan.rs | 18 ++--- datafusion/sql/src/parser.rs | 24 +++++-- .../test_files/create_external_table.slt | 21 ++++-- .../test_files/tpch/create_tables.slt.part | 2 +- 6 files changed, 84 insertions(+), 70 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 0f1d9b8f02644..a4f937b6e2a3b 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -130,9 +130,9 @@ macro_rules! config_namespace { $( stringify!($field_name) => self.$field_name.set(rem, value), )* - _ => return Err(DataFusionError::Configuration(format!( + _ => return _config_err!( "Config value \"{}\" not found on {}", key, stringify!($struct_name) - ))) + ) } } @@ -676,22 +676,17 @@ impl ConfigOptions { /// Set a configuration option pub fn set(&mut self, key: &str, value: &str) -> Result<()> { - let (prefix, key) = key.split_once('.').ok_or_else(|| { - DataFusionError::Configuration(format!( - "could not find config namespace for key \"{key}\"", - )) - })?; + let Some((prefix, key)) = key.split_once('.') else { + return _config_err!("could not find config namespace for key \"{key}\""); + }; if prefix == "datafusion" { return ConfigField::set(self, key, value); } - let e = self.extensions.0.get_mut(prefix); - let e = e.ok_or_else(|| { - DataFusionError::Configuration(format!( - "Could not find config namespace \"{prefix}\"" - )) - })?; + let Some(e) = self.extensions.0.get_mut(prefix) else { + return _config_err!("Could not find config namespace \"{prefix}\""); + }; e.0.set(key, value) } @@ -1279,22 +1274,17 @@ impl TableOptions { /// /// A result indicating success or failure in setting the configuration option. pub fn set(&mut self, key: &str, value: &str) -> Result<()> { - let (prefix, _) = key.split_once('.').ok_or_else(|| { - DataFusionError::Configuration(format!( - "could not find config namespace for key \"{key}\"" - )) - })?; + let Some((prefix, _)) = key.split_once('.') else { + return _config_err!("could not find config namespace for key \"{key}\""); + }; if prefix == "format" { return ConfigField::set(self, key, value); } - let e = self.extensions.0.get_mut(prefix); - let e = e.ok_or_else(|| { - DataFusionError::Configuration(format!( - "Could not find config namespace \"{prefix}\"" - )) - })?; + let Some(e) = self.extensions.0.get_mut(prefix) else { + return _config_err!("Could not find config namespace \"{prefix}\""); + }; e.0.set(key, value) } @@ -1413,19 +1403,19 @@ impl ConfigField for TableParquetOptions { fn set(&mut self, key: &str, value: &str) -> Result<()> { // Determine if the key is a global, metadata, or column-specific setting if key.starts_with("metadata::") { - let k = - match key.split("::").collect::>()[..] { - [_meta] | [_meta, ""] => return Err(DataFusionError::Configuration( + let k = match key.split("::").collect::>()[..] { + [_meta] | [_meta, ""] => { + return _config_err!( "Invalid metadata key provided, missing key in metadata::" - .to_string(), - )), - [_meta, k] => k.into(), - _ => { - return Err(DataFusionError::Configuration(format!( + ) + } + [_meta, k] => k.into(), + _ => { + return _config_err!( "Invalid metadata key provided, found too many '::' in \"{key}\"" - ))) - } - }; + ) + } + }; self.key_value_metadata.insert(k, Some(value.into())); Ok(()) } else if key.contains("::") { @@ -1498,10 +1488,7 @@ macro_rules! config_namespace_with_hashmap { inner_value.set(inner_key, value) } - _ => Err(DataFusionError::Configuration(format!( - "Unrecognized key '{}'.", - key - ))), + _ => _config_err!("Unrecognized key '{key}'."), } } diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index e69a249410b1f..2fc1a19c3386f 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -23,6 +23,8 @@ use std::ops::ControlFlow; use std::sync::{Arc, Weak}; use super::options::ReadOptions; +#[cfg(feature = "array_expressions")] +use crate::functions_array; use crate::{ catalog::information_schema::{InformationSchemaProvider, INFORMATION_SCHEMA}, catalog::listing_schema::ListingSchemaProvider, @@ -53,53 +55,49 @@ use crate::{ }, optimizer::analyzer::{Analyzer, AnalyzerRule}, optimizer::optimizer::{Optimizer, OptimizerConfig, OptimizerRule}, + physical_expr::{create_physical_expr, PhysicalExpr}, physical_optimizer::optimizer::{PhysicalOptimizer, PhysicalOptimizerRule}, physical_plan::ExecutionPlan, physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}, variable::{VarProvider, VarType}, }; - -#[cfg(feature = "array_expressions")] -use crate::functions_array; use crate::{functions, functions_aggregate}; use arrow::datatypes::{DataType, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow_schema::Schema; -use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use datafusion_common::tree_node::TreeNode; use datafusion_common::{ alias::AliasGenerator, config::{ConfigExtension, TableOptions}, exec_err, not_impl_err, plan_datafusion_err, plan_err, - tree_node::{TreeNodeRecursion, TreeNodeVisitor}, + tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}, DFSchema, SchemaReference, TableReference, }; use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ + expr_rewriter::FunctionRewrite, logical_plan::{DdlStatement, Statement}, + simplify::SimplifyInfo, var_provider::is_system_variables, Expr, ExprSchemable, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, }; +use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_sql::{ parser::{CopyToSource, CopyToStatement, DFParser}, planner::{object_name_to_table_reference, ContextProvider, ParserOptions, SqlToRel}, ResolvedTableReference, }; -use parking_lot::RwLock; use sqlparser::dialect::dialect_from_str; + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use parking_lot::RwLock; use url::Url; use uuid::Uuid; -use crate::physical_expr::PhysicalExpr; pub use datafusion_execution::config::SessionConfig; pub use datafusion_execution::TaskContext; pub use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::expr_rewriter::FunctionRewrite; -use datafusion_expr::simplify::SimplifyInfo; -use datafusion_optimizer::simplify_expressions::ExprSimplifier; -use datafusion_physical_expr::create_physical_expr; mod avro; mod csv; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index e5e57c0bc8938..2927fd01d1b3a 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -15,6 +15,12 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; +use std::collections::HashMap; +use std::fmt::{self, Debug, Formatter}; +use std::sync::Arc; +use std::vec; + use arrow::array::{ArrayRef, FixedSizeListArray}; use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, @@ -24,6 +30,7 @@ use datafusion::datasource::provider::TableProviderFactory; use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::covariance::{covar_pop, covar_samp}; use datafusion::functions_aggregate::expr_fn::first_value; use datafusion::prelude::*; @@ -51,16 +58,11 @@ use datafusion_proto::bytes::{ logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, }; use datafusion_proto::logical_plan::to_proto::serialize_expr; -use datafusion_proto::logical_plan::LogicalExtensionCodec; -use datafusion_proto::logical_plan::{from_proto, DefaultLogicalExtensionCodec}; +use datafusion_proto::logical_plan::{ + from_proto, DefaultLogicalExtensionCodec, LogicalExtensionCodec, +}; use datafusion_proto::protobuf; -use std::any::Any; -use std::collections::HashMap; -use std::fmt::{self, Debug, Formatter}; -use std::sync::Arc; -use std::vec; -use datafusion::execution::FunctionRegistry; use prost::Message; #[cfg(feature = "json")] diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index f61c9cda63453..d09317271d23f 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -462,7 +462,21 @@ impl<'a> DFParser<'a> { pub fn parse_option_key(&mut self) -> Result { let next_token = self.parser.next_token(); match next_token.token { - Token::Word(Word { value, .. }) => Ok(value), + Token::Word(Word { value, .. }) => { + let mut parts = vec![value]; + while self.parser.consume_token(&Token::Period) { + let next_token = self.parser.next_token(); + if let Token::Word(Word { value, .. }) = next_token.token { + parts.push(value); + } else { + // Unquoted namespaced keys have to conform to the syntax + // "[\.]*". If we have a key that breaks this + // pattern, error out: + return self.parser.expected("key name", next_token); + } + } + Ok(parts.join(".")) + } Token::SingleQuotedString(s) => Ok(s), Token::DoubleQuotedString(s) => Ok(s), Token::EscapedStringLiteral(s) => Ok(s), @@ -712,15 +726,15 @@ impl<'a> DFParser<'a> { } else { self.parser.expect_keyword(Keyword::HEADER)?; self.parser.expect_keyword(Keyword::ROW)?; - return parser_err!("WITH HEADER ROW clause is no longer in use. Please use the OPTIONS clause with 'format.has_header' set appropriately, e.g., OPTIONS ('format.has_header' 'true')"); + return parser_err!("WITH HEADER ROW clause is no longer in use. Please use the OPTIONS clause with 'format.has_header' set appropriately, e.g., OPTIONS (format.has_header true)"); } } Keyword::DELIMITER => { - return parser_err!("DELIMITER clause is no longer in use. Please use the OPTIONS clause with 'format.delimiter' set appropriately, e.g., OPTIONS ('format.delimiter' ',')"); + return parser_err!("DELIMITER clause is no longer in use. Please use the OPTIONS clause with 'format.delimiter' set appropriately, e.g., OPTIONS (format.delimiter ',')"); } Keyword::COMPRESSION => { self.parser.expect_keyword(Keyword::TYPE)?; - return parser_err!("COMPRESSION TYPE clause is no longer in use. Please use the OPTIONS clause with 'format.compression' set appropriately, e.g., OPTIONS ('format.compression' 'gzip')"); + return parser_err!("COMPRESSION TYPE clause is no longer in use. Please use the OPTIONS clause with 'format.compression' set appropriately, e.g., OPTIONS (format.compression gzip)"); } Keyword::PARTITIONED => { self.parser.expect_keyword(Keyword::BY)?; @@ -933,7 +947,7 @@ mod tests { expect_parse_ok(sql, expected)?; // positive case with delimiter - let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' OPTIONS ('format.delimiter' '|')"; + let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' OPTIONS (format.delimiter '|')"; let display = None; let expected = Statement::CreateExternalTable(CreateExternalTable { name: "t".into(), diff --git a/datafusion/sqllogictest/test_files/create_external_table.slt b/datafusion/sqllogictest/test_files/create_external_table.slt index fca177bb61f0e..607c909fd63d5 100644 --- a/datafusion/sqllogictest/test_files/create_external_table.slt +++ b/datafusion/sqllogictest/test_files/create_external_table.slt @@ -190,8 +190,8 @@ LOCATION 'test_files/scratch/create_external_table/manual_partitioning/'; statement error DataFusion error: Error during planning: Option format.delimiter is specified multiple times CREATE EXTERNAL TABLE t STORED AS CSV OPTIONS ( 'format.delimiter' '*', - 'format.has_header' 'true', - 'format.delimiter' '|') + 'format.has_header' 'true', + 'format.delimiter' '|') LOCATION 'foo.csv'; # If a config does not belong to any namespace, we assume it is a 'format' option and apply the 'format' prefix for backwards compatibility. @@ -201,7 +201,20 @@ CREATE EXTERNAL TABLE IF NOT EXISTS region ( r_name VARCHAR, r_comment VARCHAR, r_rev VARCHAR, -) STORED AS CSV LOCATION 'test_files/tpch/data/region.tbl' +) STORED AS CSV LOCATION 'test_files/tpch/data/region.tbl' OPTIONS ( 'format.delimiter' '|', - 'has_header' 'false'); \ No newline at end of file + 'has_header' 'false'); + +# Verify that we do not need quotations for simple namespaced keys. +statement ok +CREATE EXTERNAL TABLE IF NOT EXISTS region ( + r_regionkey BIGINT, + r_name VARCHAR, + r_comment VARCHAR, + r_rev VARCHAR, +) STORED AS CSV LOCATION 'test_files/tpch/data/region.tbl' +OPTIONS ( + format.delimiter '|', + has_header false, + compression gzip); diff --git a/datafusion/sqllogictest/test_files/tpch/create_tables.slt.part b/datafusion/sqllogictest/test_files/tpch/create_tables.slt.part index 111d24773055f..75bcbc198bef8 100644 --- a/datafusion/sqllogictest/test_files/tpch/create_tables.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/create_tables.slt.part @@ -121,4 +121,4 @@ CREATE EXTERNAL TABLE IF NOT EXISTS region ( r_name VARCHAR, r_comment VARCHAR, r_rev VARCHAR, -) STORED AS CSV LOCATION 'test_files/tpch/data/region.tbl' OPTIONS ('format.delimiter' '|'); \ No newline at end of file +) STORED AS CSV LOCATION 'test_files/tpch/data/region.tbl' OPTIONS ('format.delimiter' '|'); From 18fc37629250d22faa6ead109725ebf94a4fa532 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Tue, 14 May 2024 07:32:29 +0800 Subject: [PATCH 10/11] feat: allow `array_slice` to take an optional stride parameter (#10469) * feat: allow array_slice to take an optional stride parameter * Use ScalarUDF::call * Use create_function and add test * format * fix cargo doc --- datafusion/functions-array/src/array_has.rs | 6 +-- datafusion/functions-array/src/cardinality.rs | 2 +- datafusion/functions-array/src/concat.rs | 6 +-- datafusion/functions-array/src/dimension.rs | 4 +- datafusion/functions-array/src/empty.rs | 2 +- datafusion/functions-array/src/except.rs | 2 +- datafusion/functions-array/src/extract.rs | 23 +++++----- datafusion/functions-array/src/flatten.rs | 2 +- datafusion/functions-array/src/length.rs | 2 +- datafusion/functions-array/src/macros.rs | 44 +++++++++---------- datafusion/functions-array/src/make_array.rs | 2 +- datafusion/functions-array/src/position.rs | 4 +- datafusion/functions-array/src/range.rs | 4 +- datafusion/functions-array/src/remove.rs | 6 +-- datafusion/functions-array/src/repeat.rs | 2 +- datafusion/functions-array/src/replace.rs | 6 +-- datafusion/functions-array/src/resize.rs | 2 +- datafusion/functions-array/src/reverse.rs | 2 +- datafusion/functions-array/src/rewrite.rs | 2 +- datafusion/functions-array/src/set_ops.rs | 6 +-- datafusion/functions-array/src/sort.rs | 2 +- datafusion/functions-array/src/string.rs | 4 +- .../tests/cases/roundtrip_logical_plan.rs | 6 +++ 23 files changed, 74 insertions(+), 67 deletions(-) diff --git a/datafusion/functions-array/src/array_has.rs b/datafusion/functions-array/src/array_has.rs index e5e8add95fbed..43d6046f4f828 100644 --- a/datafusion/functions-array/src/array_has.rs +++ b/datafusion/functions-array/src/array_has.rs @@ -34,19 +34,19 @@ use std::any::Any; use std::sync::Arc; // Create static instances of ScalarUDFs for each function -make_udf_function!(ArrayHas, +make_udf_expr_and_func!(ArrayHas, array_has, first_array second_array, // arg name "returns true, if the element appears in the first array, otherwise false.", // doc array_has_udf // internal function name ); -make_udf_function!(ArrayHasAll, +make_udf_expr_and_func!(ArrayHasAll, array_has_all, first_array second_array, // arg name "returns true if each element of the second array appears in the first array; otherwise, it returns false.", // doc array_has_all_udf // internal function name ); -make_udf_function!(ArrayHasAny, +make_udf_expr_and_func!(ArrayHasAny, array_has_any, first_array second_array, // arg name "returns true if at least one element of the second array appears in the first array; otherwise, it returns false.", // doc diff --git a/datafusion/functions-array/src/cardinality.rs b/datafusion/functions-array/src/cardinality.rs index ed9f8d01f9732..d6f2456313bc5 100644 --- a/datafusion/functions-array/src/cardinality.rs +++ b/datafusion/functions-array/src/cardinality.rs @@ -29,7 +29,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( Cardinality, cardinality, array, diff --git a/datafusion/functions-array/src/concat.rs b/datafusion/functions-array/src/concat.rs index f9d9bf4356ff1..a6fed84fa765f 100644 --- a/datafusion/functions-array/src/concat.rs +++ b/datafusion/functions-array/src/concat.rs @@ -36,7 +36,7 @@ use datafusion_expr::{ use crate::utils::{align_array_dimensions, check_datatypes, make_scalar_function}; -make_udf_function!( +make_udf_expr_and_func!( ArrayAppend, array_append, array element, // arg name @@ -96,7 +96,7 @@ impl ScalarUDFImpl for ArrayAppend { } } -make_udf_function!( +make_udf_expr_and_func!( ArrayPrepend, array_prepend, element array, @@ -156,7 +156,7 @@ impl ScalarUDFImpl for ArrayPrepend { } } -make_udf_function!( +make_udf_expr_and_func!( ArrayConcat, array_concat, "Concatenates arrays.", diff --git a/datafusion/functions-array/src/dimension.rs b/datafusion/functions-array/src/dimension.rs index 569eff66f7f45..1dc6520f1bc74 100644 --- a/datafusion/functions-array/src/dimension.rs +++ b/datafusion/functions-array/src/dimension.rs @@ -33,7 +33,7 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayDims, array_dims, array, @@ -88,7 +88,7 @@ impl ScalarUDFImpl for ArrayDims { } } -make_udf_function!( +make_udf_expr_and_func!( ArrayNdims, array_ndims, array, diff --git a/datafusion/functions-array/src/empty.rs b/datafusion/functions-array/src/empty.rs index d5fa174eee5ff..9fe2c870496bc 100644 --- a/datafusion/functions-array/src/empty.rs +++ b/datafusion/functions-array/src/empty.rs @@ -28,7 +28,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayEmpty, array_empty, array, diff --git a/datafusion/functions-array/src/except.rs b/datafusion/functions-array/src/except.rs index 444c7c7587714..a56bab1e06116 100644 --- a/datafusion/functions-array/src/except.rs +++ b/datafusion/functions-array/src/except.rs @@ -31,7 +31,7 @@ use std::any::Any; use std::collections::HashSet; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayExcept, array_except, first_array second_array, diff --git a/datafusion/functions-array/src/extract.rs b/datafusion/functions-array/src/extract.rs index 0dbd106b6f189..842f4ec1b8398 100644 --- a/datafusion/functions-array/src/extract.rs +++ b/datafusion/functions-array/src/extract.rs @@ -44,7 +44,7 @@ use std::sync::Arc; use crate::utils::make_scalar_function; // Create static instances of ScalarUDFs for each function -make_udf_function!( +make_udf_expr_and_func!( ArrayElement, array_element, array element, @@ -52,15 +52,9 @@ make_udf_function!( array_element_udf ); -make_udf_function!( - ArraySlice, - array_slice, - array begin end stride, - "returns a slice of the array.", - array_slice_udf -); +create_func!(ArraySlice, array_slice_udf); -make_udf_function!( +make_udf_expr_and_func!( ArrayPopFront, array_pop_front, array, @@ -68,7 +62,7 @@ make_udf_function!( array_pop_front_udf ); -make_udf_function!( +make_udf_expr_and_func!( ArrayPopBack, array_pop_back, array, @@ -224,6 +218,15 @@ where Ok(arrow::array::make_array(data)) } +#[doc = "returns a slice of the array."] +pub fn array_slice(array: Expr, begin: Expr, end: Expr, stride: Option) -> Expr { + let args = match stride { + Some(stride) => vec![array, begin, end, stride], + None => vec![array, begin, end], + }; + array_slice_udf().call(args) +} + #[derive(Debug)] pub(super) struct ArraySlice { signature: Signature, diff --git a/datafusion/functions-array/src/flatten.rs b/datafusion/functions-array/src/flatten.rs index e2b50c6c02cc2..294d41ada7c34 100644 --- a/datafusion/functions-array/src/flatten.rs +++ b/datafusion/functions-array/src/flatten.rs @@ -31,7 +31,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( Flatten, flatten, array, diff --git a/datafusion/functions-array/src/length.rs b/datafusion/functions-array/src/length.rs index 9bbd11950d217..9cdcaddf8dff2 100644 --- a/datafusion/functions-array/src/length.rs +++ b/datafusion/functions-array/src/length.rs @@ -32,7 +32,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayLength, array_length, array, diff --git a/datafusion/functions-array/src/macros.rs b/datafusion/functions-array/src/macros.rs index c49f5830b8d5f..4e00aa39bd84f 100644 --- a/datafusion/functions-array/src/macros.rs +++ b/datafusion/functions-array/src/macros.rs @@ -19,8 +19,8 @@ /// /// 1. Single `ScalarUDF` instance /// -/// Creates a singleton `ScalarUDF` of the `$UDF` function named `$GNAME` and a -/// function named `$NAME` which returns that function named $NAME. +/// Creates a singleton `ScalarUDF` of the `$UDF` function named `STATIC_$(UDF)` and a +/// function named `$SCALAR_UDF_FUNC` which returns that function named `STATIC_$(UDF)`. /// /// This is used to ensure creating the list of `ScalarUDF` only happens once. /// @@ -41,10 +41,9 @@ /// * `arg`: 0 or more named arguments for the function /// * `DOC`: documentation string for the function /// * `SCALAR_UDF_FUNC`: name of the function to create (just) the `ScalarUDF` -/// * `GNAME`: name for the single static instance of the `ScalarUDF` /// /// [`ScalarUDFImpl`]: datafusion_expr::ScalarUDFImpl -macro_rules! make_udf_function { +macro_rules! make_udf_expr_and_func { ($UDF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr , $SCALAR_UDF_FN:ident) => { paste::paste! { // "fluent expr_fn" style function @@ -55,25 +54,7 @@ macro_rules! make_udf_function { vec![$($arg),*], )) } - - /// Singleton instance of [`$UDF`], ensures the UDF is only created once - /// named STATIC_$(UDF). For example `STATIC_ArrayToString` - #[allow(non_upper_case_globals)] - static [< STATIC_ $UDF >]: std::sync::OnceLock> = - std::sync::OnceLock::new(); - - /// ScalarFunction that returns a [`ScalarUDF`] for [`$UDF`] - /// - /// [`ScalarUDF`]: datafusion_expr::ScalarUDF - pub fn $SCALAR_UDF_FN() -> std::sync::Arc { - [< STATIC_ $UDF >] - .get_or_init(|| { - std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl( - <$UDF>::new(), - )) - }) - .clone() - } + create_func!($UDF, $SCALAR_UDF_FN); } }; ($UDF:ty, $EXPR_FN:ident, $DOC:expr , $SCALAR_UDF_FN:ident) => { @@ -86,7 +67,24 @@ macro_rules! make_udf_function { arg, )) } + create_func!($UDF, $SCALAR_UDF_FN); + } + }; +} +/// Creates a singleton `ScalarUDF` of the `$UDF` function named `STATIC_$(UDF)` and a +/// function named `$SCALAR_UDF_FUNC` which returns that function named `STATIC_$(UDF)`. +/// +/// This is used to ensure creating the list of `ScalarUDF` only happens once. +/// +/// # Arguments +/// * `UDF`: name of the [`ScalarUDFImpl`] +/// * `SCALAR_UDF_FUNC`: name of the function to create (just) the `ScalarUDF` +/// +/// [`ScalarUDFImpl`]: datafusion_expr::ScalarUDFImpl +macro_rules! create_func { + ($UDF:ty, $SCALAR_UDF_FN:ident) => { + paste::paste! { /// Singleton instance of [`$UDF`], ensures the UDF is only created once /// named STATIC_$(UDF). For example `STATIC_ArrayToString` #[allow(non_upper_case_globals)] diff --git a/datafusion/functions-array/src/make_array.rs b/datafusion/functions-array/src/make_array.rs index 4f7dda933f427..4723464dfaf29 100644 --- a/datafusion/functions-array/src/make_array.rs +++ b/datafusion/functions-array/src/make_array.rs @@ -35,7 +35,7 @@ use datafusion_expr::{Expr, TypeSignature}; use crate::utils::make_scalar_function; -make_udf_function!( +make_udf_expr_and_func!( MakeArray, make_array, "Returns an Arrow array using the specified input expressions.", diff --git a/datafusion/functions-array/src/position.rs b/datafusion/functions-array/src/position.rs index a5a7a7405aa98..efdb7dff0ce6e 100644 --- a/datafusion/functions-array/src/position.rs +++ b/datafusion/functions-array/src/position.rs @@ -37,7 +37,7 @@ use itertools::Itertools; use crate::utils::{compare_element_to_list, make_scalar_function}; -make_udf_function!( +make_udf_expr_and_func!( ArrayPosition, array_position, array element index, @@ -168,7 +168,7 @@ fn generic_position( Ok(Arc::new(UInt64Array::from(data))) } -make_udf_function!( +make_udf_expr_and_func!( ArrayPositions, array_positions, array element, // arg name diff --git a/datafusion/functions-array/src/range.rs b/datafusion/functions-array/src/range.rs index 150fe59602660..9a9829f961001 100644 --- a/datafusion/functions-array/src/range.rs +++ b/datafusion/functions-array/src/range.rs @@ -35,7 +35,7 @@ use datafusion_expr::{ use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( Range, range, start stop step, @@ -106,7 +106,7 @@ impl ScalarUDFImpl for Range { } } -make_udf_function!( +make_udf_expr_and_func!( GenSeries, gen_series, start stop step, diff --git a/datafusion/functions-array/src/remove.rs b/datafusion/functions-array/src/remove.rs index 21e373081054b..7645c1a57573a 100644 --- a/datafusion/functions-array/src/remove.rs +++ b/datafusion/functions-array/src/remove.rs @@ -32,7 +32,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayRemove, array_remove, array element, @@ -81,7 +81,7 @@ impl ScalarUDFImpl for ArrayRemove { } } -make_udf_function!( +make_udf_expr_and_func!( ArrayRemoveN, array_remove_n, array element max, @@ -130,7 +130,7 @@ impl ScalarUDFImpl for ArrayRemoveN { } } -make_udf_function!( +make_udf_expr_and_func!( ArrayRemoveAll, array_remove_all, array element, diff --git a/datafusion/functions-array/src/repeat.rs b/datafusion/functions-array/src/repeat.rs index 89b766bdcdfc1..df623c114818c 100644 --- a/datafusion/functions-array/src/repeat.rs +++ b/datafusion/functions-array/src/repeat.rs @@ -34,7 +34,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayRepeat, array_repeat, element count, // arg name diff --git a/datafusion/functions-array/src/replace.rs b/datafusion/functions-array/src/replace.rs index c32305bb454b8..7cea4945836eb 100644 --- a/datafusion/functions-array/src/replace.rs +++ b/datafusion/functions-array/src/replace.rs @@ -38,19 +38,19 @@ use std::any::Any; use std::sync::Arc; // Create static instances of ScalarUDFs for each function -make_udf_function!(ArrayReplace, +make_udf_expr_and_func!(ArrayReplace, array_replace, array from to, "replaces the first occurrence of the specified element with another specified element.", array_replace_udf ); -make_udf_function!(ArrayReplaceN, +make_udf_expr_and_func!(ArrayReplaceN, array_replace_n, array from to max, "replaces the first `max` occurrences of the specified element with another specified element.", array_replace_n_udf ); -make_udf_function!(ArrayReplaceAll, +make_udf_expr_and_func!(ArrayReplaceAll, array_replace_all, array from to, "replaces all occurrences of the specified element with another specified element.", diff --git a/datafusion/functions-array/src/resize.rs b/datafusion/functions-array/src/resize.rs index 561e98e8b76f2..63f28c9afa77c 100644 --- a/datafusion/functions-array/src/resize.rs +++ b/datafusion/functions-array/src/resize.rs @@ -30,7 +30,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayResize, array_resize, array size value, diff --git a/datafusion/functions-array/src/reverse.rs b/datafusion/functions-array/src/reverse.rs index 9be6405657033..3076013899ef5 100644 --- a/datafusion/functions-array/src/reverse.rs +++ b/datafusion/functions-array/src/reverse.rs @@ -30,7 +30,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayReverse, array_reverse, array, diff --git a/datafusion/functions-array/src/rewrite.rs b/datafusion/functions-array/src/rewrite.rs index 416e79cbc0792..5280355a8224d 100644 --- a/datafusion/functions-array/src/rewrite.rs +++ b/datafusion/functions-array/src/rewrite.rs @@ -171,7 +171,7 @@ impl FunctionRewrite for ArrayFunctionRewriter { stop, stride, }, - }) => Transformed::yes(array_slice(*expr, *start, *stop, *stride)), + }) => Transformed::yes(array_slice(*expr, *start, *stop, Some(*stride))), _ => Transformed::no(expr), }; diff --git a/datafusion/functions-array/src/set_ops.rs b/datafusion/functions-array/src/set_ops.rs index 5f3087fafd6f2..40676b7cdcb88 100644 --- a/datafusion/functions-array/src/set_ops.rs +++ b/datafusion/functions-array/src/set_ops.rs @@ -37,7 +37,7 @@ use std::fmt::{Display, Formatter}; use std::sync::Arc; // Create static instances of ScalarUDFs for each function -make_udf_function!( +make_udf_expr_and_func!( ArrayUnion, array_union, array1 array2, @@ -45,7 +45,7 @@ make_udf_function!( array_union_udf ); -make_udf_function!( +make_udf_expr_and_func!( ArrayIntersect, array_intersect, first_array second_array, @@ -53,7 +53,7 @@ make_udf_function!( array_intersect_udf ); -make_udf_function!( +make_udf_expr_and_func!( ArrayDistinct, array_distinct, array, diff --git a/datafusion/functions-array/src/sort.rs b/datafusion/functions-array/src/sort.rs index af78712065fc8..16f271ef10ff5 100644 --- a/datafusion/functions-array/src/sort.rs +++ b/datafusion/functions-array/src/sort.rs @@ -30,7 +30,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArraySort, array_sort, array desc null_first, diff --git a/datafusion/functions-array/src/string.rs b/datafusion/functions-array/src/string.rs index 38059035005bd..4122ddbd45eb9 100644 --- a/datafusion/functions-array/src/string.rs +++ b/datafusion/functions-array/src/string.rs @@ -102,7 +102,7 @@ macro_rules! call_array_function { } // Create static instances of ScalarUDFs for each function -make_udf_function!( +make_udf_expr_and_func!( ArrayToString, array_to_string, array delimiter, // arg name @@ -160,7 +160,7 @@ impl ScalarUDFImpl for ArrayToString { } } -make_udf_function!( +make_udf_expr_and_func!( StringToArray, string_to_array, string delimiter null_string, // arg name diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 2927fd01d1b3a..ec215937dca82 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -582,7 +582,13 @@ async fn roundtrip_expr_api() -> Result<()> { make_array(vec![lit(1), lit(2), lit(3)]), lit(1), lit(2), + Some(lit(1)), + ), + array_slice( + make_array(vec![lit(1), lit(2), lit(3)]), lit(1), + lit(2), + None, ), array_pop_front(make_array(vec![lit(1), lit(2), lit(3)])), array_pop_back(make_array(vec![lit(1), lit(2), lit(3)])), From b8fab5cdf418e1fba5e6012b815a5bc40c7771cc Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Tue, 14 May 2024 10:22:28 +0800 Subject: [PATCH 11/11] Replace `GetFieldAccess` with indexing function in `SqlToRel ` (#10375) * use func in parser Signed-off-by: jayzhan211 * add tests Signed-off-by: jayzhan211 * add test Signed-off-by: jayzhan211 * rm test1 Signed-off-by: jayzhan211 * parser done Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix exprapi test Signed-off-by: jayzhan211 * fix test Signed-off-by: jayzhan211 * fix conflicts Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/core/tests/expr_api/mod.rs | 14 +-- datafusion/functions-array/src/rewrite.rs | 29 +---- datafusion/sql/src/expr/identifier.rs | 17 ++- datafusion/sql/src/expr/mod.rs | 48 ++++++++- datafusion/sqllogictest/test_files/expr.slt | 114 +++++++++++++++++++- 5 files changed, 172 insertions(+), 50 deletions(-) diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index 0dde7604cce24..d7e839824b3be 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -60,9 +60,8 @@ fn test_eq_with_coercion() { #[test] fn test_get_field() { - // field access Expr::field() requires a rewrite to work evaluate_expr_test( - col("props").field("a"), + get_field(col("props"), lit("a")), vec![ "+------------+", "| expr |", @@ -77,11 +76,8 @@ fn test_get_field() { #[test] fn test_nested_get_field() { - // field access Expr::field() requires a rewrite to work, test when it is - // not the root expression evaluate_expr_test( - col("props") - .field("a") + get_field(col("props"), lit("a")) .eq(lit("2021-02-02")) .or(col("id").eq(lit(1))), vec![ @@ -98,9 +94,8 @@ fn test_nested_get_field() { #[test] fn test_list() { - // list access also requires a rewrite to work evaluate_expr_test( - col("list").index(lit(1i64)), + array_element(col("list"), lit(1i64)), vec![ "+------+", "| expr |", "+------+", "| one |", "| two |", "| five |", "+------+", @@ -110,9 +105,8 @@ fn test_list() { #[test] fn test_list_range() { - // range access also requires a rewrite to work evaluate_expr_test( - col("list").range(lit(1i64), lit(2i64)), + array_slice(col("list"), lit(1i64), lit(2i64), None), vec![ "+--------------+", "| expr |", diff --git a/datafusion/functions-array/src/rewrite.rs b/datafusion/functions-array/src/rewrite.rs index 5280355a8224d..a7aba78c1dbe1 100644 --- a/datafusion/functions-array/src/rewrite.rs +++ b/datafusion/functions-array/src/rewrite.rs @@ -19,7 +19,6 @@ use crate::array_has::array_has_all; use crate::concat::{array_append, array_concat, array_prepend}; -use crate::extract::{array_element, array_slice}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::Transformed; use datafusion_common::utils::list_ndims; @@ -27,8 +26,7 @@ use datafusion_common::Result; use datafusion_common::{Column, DFSchema}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr_rewriter::FunctionRewrite; -use datafusion_expr::{BinaryExpr, Expr, GetFieldAccess, GetIndexedField, Operator}; -use datafusion_functions::expr_fn::get_field; +use datafusion_expr::{BinaryExpr, Expr, Operator}; /// Rewrites expressions into function calls to array functions pub(crate) struct ArrayFunctionRewriter {} @@ -148,31 +146,6 @@ impl FunctionRewrite for ArrayFunctionRewriter { Transformed::yes(array_prepend(*left, *right)) } - Expr::GetIndexedField(GetIndexedField { - expr, - field: GetFieldAccess::NamedStructField { name }, - }) => { - let name = Expr::Literal(name); - Transformed::yes(get_field(*expr, name)) - } - - // expr[idx] ==> array_element(expr, idx) - Expr::GetIndexedField(GetIndexedField { - expr, - field: GetFieldAccess::ListIndex { key }, - }) => Transformed::yes(array_element(*expr, *key)), - - // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride) - Expr::GetIndexedField(GetIndexedField { - expr, - field: - GetFieldAccess::ListRange { - start, - stop, - stride, - }, - }) => Transformed::yes(array_slice(*expr, *start, *stop, Some(*stride))), - _ => Transformed::no(expr), }; Ok(transformed) diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 713ad6f72c24a..d297b2e4df5b3 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -19,9 +19,9 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow_schema::Field; use datafusion_common::{ internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, Result, - TableReference, + ScalarValue, TableReference, }; -use datafusion_expr::{Case, Expr}; +use datafusion_expr::{expr::ScalarFunction, lit, Case, Expr}; use sqlparser::ast::{Expr as SQLExpr, Ident}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -133,7 +133,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ); } let nested_name = nested_names[0].to_string(); - Ok(Expr::Column(Column::from((qualifier, field))).field(nested_name)) + + let col = Expr::Column(Column::from((qualifier, field))); + if let Some(udf) = + self.context_provider.get_function_meta("get_field") + { + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + udf, + vec![col, lit(ScalarValue::from(nested_name))], + ))) + } else { + internal_err!("get_field not found") + } } // found matching field with no spare identifier(s) Some((field, qualifier, _nested_names)) => { diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index ed5421edfbb01..6445c3f7a885d 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -29,7 +29,7 @@ use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ col, expr, lit, AggregateFunction, Between, BinaryExpr, Cast, Expr, ExprSchemable, - GetFieldAccess, GetIndexedField, Like, Literal, Operator, TryCast, + GetFieldAccess, Like, Literal, Operator, TryCast, }; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; @@ -1019,10 +1019,48 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { expr }; - Ok(Expr::GetIndexedField(GetIndexedField::new( - Box::new(expr), - self.plan_indices(indices, schema, planner_context)?, - ))) + let field = self.plan_indices(indices, schema, planner_context)?; + match field { + GetFieldAccess::NamedStructField { name } => { + if let Some(udf) = self.context_provider.get_function_meta("get_field") { + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + udf, + vec![expr, lit(name)], + ))) + } else { + internal_err!("get_field not found") + } + } + // expr[idx] ==> array_element(expr, idx) + GetFieldAccess::ListIndex { key } => { + if let Some(udf) = + self.context_provider.get_function_meta("array_element") + { + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + udf, + vec![expr, *key], + ))) + } else { + internal_err!("get_field not found") + } + } + // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride) + GetFieldAccess::ListRange { + start, + stop, + stride, + } => { + if let Some(udf) = self.context_provider.get_function_meta("array_slice") + { + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + udf, + vec![expr, *start, *stop, *stride], + ))) + } else { + internal_err!("array_slice not found") + } + } + } } } diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 4b5f4d770a036..2dc00cbc50017 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -2324,28 +2324,134 @@ host3 3.3 # can have an aggregate function with an inner CASE WHEN query TR -select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +select + t2.server_host as host, + sum(( + case when t2.server_host is not null + then t2.server_load2 + end + )) + from ( + select + struct(time,load1,load2,host)['c2'] as server_load2, + struct(time,load1,load2,host)['c3'] as server_host + from t1 + ) t2 + where server_host IS NOT NULL + group by server_host order by host; ---- host1 101 host2 202 host3 303 +# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364 +query error +select + t2.server['c3'] as host, + sum(( + case when t2.server['c3'] is not null + then t2.server['c2'] + end + )) + from ( + select + struct(time,load1,load2,host) as server + from t1 + ) t2 + where t2.server['c3'] IS NOT NULL + group by t2.server['c3'] order by host; + # can have 2 projections with aggr(short_circuited), with different short-circuited expr query TRR -select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +select + t2.server_host as host, + sum(coalesce(server_load1)), + sum(( + case when t2.server_host is not null + then t2.server_load2 + end + )) + from ( + select + struct(time,load1,load2,host)['c1'] as server_load1, + struct(time,load1,load2,host)['c2'] as server_load2, + struct(time,load1,load2,host)['c3'] as server_host + from t1 + ) t2 + where server_host IS NOT NULL + group by server_host order by host; ---- host1 1.1 101 host2 2.2 202 host3 3.3 303 -# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. CASE WHEN) +# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364 +query error +select + t2.server['c3'] as host, + sum(coalesce(server['c1'])), + sum(( + case when t2.server['c3'] is not null + then t2.server['c2'] + end + )) + from ( + select + struct(time,load1,load2,host) as server, + from t1 + ) t2 + where server_host IS NOT NULL + group by server_host order by host; + query TRR -select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +select + t2.server_host as host, + sum(( + case when t2.server_host is not null + then server_load1 + end + )), + sum(( + case when server_host is not null + then server_load2 + end + )) + from ( + select + struct(time,load1,load2,host)['c1'] as server_load1, + struct(time,load1,load2,host)['c2'] as server_load2, + struct(time,load1,load2,host)['c3'] as server_host + from t1 + ) t2 + where server_host IS NOT NULL + group by server_host order by host; ---- host1 1.1 101 host2 2.2 202 host3 3.3 303 +# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364 +query error +select + t2.server['c3'] as host, + sum(( + case when t2.server['c3'] is not null + then t2.server['c1'] + end + )), + sum(( + case when t2.server['c3'] is not null + then t2.server['c2'] + end + )) + from ( + select + struct(time,load1,load2,host) as server + from t1 + ) t2 + where t2.server['c3'] IS NOT NULL + group by t2.server['c3'] order by host; + # can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. coalesce) query TRR select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;