From 3887087eb66a90b29869f2e2b57f28a36756afe9 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 8 Jan 2025 06:32:04 -0500 Subject: [PATCH 01/19] Encapsulate fields of `EquivalenceProperties` (#14040) --- .../core/src/physical_optimizer/sanity_checker.rs | 2 +- .../core/tests/fuzz_cases/equivalence/ordering.rs | 8 ++++---- .../tests/fuzz_cases/equivalence/projection.rs | 8 ++++---- .../tests/fuzz_cases/equivalence/properties.rs | 4 ++-- .../core/tests/fuzz_cases/equivalence/utils.rs | 6 +++--- .../physical-expr/src/equivalence/ordering.rs | 14 ++++++++------ .../physical-expr/src/equivalence/properties.rs | 14 +++++++------- datafusion/physical-plan/src/memory.rs | 4 +++- datafusion/physical-plan/src/union.rs | 4 ++-- 9 files changed, 34 insertions(+), 30 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/sanity_checker.rs b/datafusion/core/src/physical_optimizer/sanity_checker.rs index f4b0f7c6069b..8e8787aec96b 100644 --- a/datafusion/core/src/physical_optimizer/sanity_checker.rs +++ b/datafusion/core/src/physical_optimizer/sanity_checker.rs @@ -144,7 +144,7 @@ pub fn check_plan_sanity( plan_str, format_physical_sort_requirement_list(&sort_req), idx, - child_eq_props.oeq_class + child_eq_props.oeq_class() ); } } diff --git a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs index 525baadd14a5..ecf267185bae 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs @@ -68,8 +68,8 @@ fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { table_data_with_properties.clone(), )?; let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + "Error in test case requirement:{:?}, expected: {:?}, eq_properties {}", + requirement, expected, eq_properties ); // Check whether ordering_satisfy API result and // experimental result matches. @@ -141,8 +141,8 @@ fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { table_data_with_properties.clone(), )?; let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + "Error in test case requirement:{:?}, expected: {:?}, eq_properties: {}", + requirement, expected, eq_properties, ); // Check whether ordering_satisfy API result and // experimental result matches. diff --git a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs index 3df3e0348e42..f71df50fce2f 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs @@ -82,8 +82,8 @@ fn project_orderings_random() -> Result<()> { // Make sure each ordering after projection is valid. 
for ordering in projected_eq.oeq_class().iter() { let err_msg = format!( - "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, proj_exprs: {:?}", - ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, proj_exprs + "Error in test case ordering:{:?}, eq_properties {}, proj_exprs: {:?}", + ordering, eq_properties, proj_exprs, ); // Since ordered section satisfies schema, we expect // that result will be same after sort (e.g sort was unnecessary). @@ -179,8 +179,8 @@ fn ordering_satisfy_after_projection_random() -> Result<()> { projected_batch.clone(), )?; let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, projected_eq.oeq_class: {:?}, projected_eq.eq_group: {:?}, projected_eq.constants: {:?}, projection_mapping: {:?}", - requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, projected_eq.oeq_class, projected_eq.eq_group, projected_eq.constants, projection_mapping + "Error in test case requirement:{:?}, expected: {:?}, eq_properties: {}, projected_eq: {}, projection_mapping: {:?}", + requirement, expected, eq_properties, projected_eq, projection_mapping ); // Check whether ordering_satisfy API result and // experimental result matches. diff --git a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs index 82586bd79eda..fc21c620a711 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs @@ -83,8 +83,8 @@ fn test_find_longest_permutation_random() -> Result<()> { ); let err_msg = format!( - "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + "Error in test case ordering:{:?}, eq_properties: {}", + ordering, eq_properties ); assert_eq!(ordering.len(), indices.len(), "{}", err_msg); // Since ordered section satisfies schema, we expect diff --git a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs index e18dab35fc91..f8d0ea8200c3 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs @@ -373,7 +373,7 @@ pub fn generate_table_for_eq_properties( }; // Fill constant columns - for constant in &eq_properties.constants { + for constant in eq_properties.constants() { let col = constant.expr().as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); let arr = @@ -382,7 +382,7 @@ pub fn generate_table_for_eq_properties( } // Fill columns based on ordering equivalences - for ordering in eq_properties.oeq_class.iter() { + for ordering in eq_properties.oeq_class().iter() { let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering .iter() .map(|PhysicalSortExpr { expr, options }| { @@ -406,7 +406,7 @@ pub fn generate_table_for_eq_properties( } // Fill columns based on equivalence groups - for eq_group in eq_properties.eq_group.iter() { + for eq_group in eq_properties.eq_group().iter() { let representative_array = get_representative_arr(eq_group, &schema_vec, Arc::clone(schema)) .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); diff --git 
a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index 24e2fc7dbaf5..0ae5f4af8f08 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -291,15 +291,17 @@ mod tests { }, ]); // finer ordering satisfies, crude ordering should return true - let mut eq_properties_finer = - EquivalenceProperties::new(Arc::clone(&input_schema)); - eq_properties_finer.oeq_class.push(finer.clone()); + let eq_properties_finer = EquivalenceProperties::new_with_orderings( + Arc::clone(&input_schema), + &[finer.clone()], + ); assert!(eq_properties_finer.ordering_satisfy(crude.as_ref())); // Crude ordering doesn't satisfy finer ordering. should return false - let mut eq_properties_crude = - EquivalenceProperties::new(Arc::clone(&input_schema)); - eq_properties_crude.oeq_class.push(crude); + let eq_properties_crude = EquivalenceProperties::new_with_orderings( + Arc::clone(&input_schema), + &[crude.clone()], + ); assert!(!eq_properties_crude.ordering_satisfy(finer.as_ref())); Ok(()) } diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index c3d458103285..d2eeccda2cae 100755 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -124,15 +124,15 @@ use itertools::Itertools; /// ``` #[derive(Debug, Clone)] pub struct EquivalenceProperties { - /// Collection of equivalence classes that store expressions with the same - /// value. - pub eq_group: EquivalenceGroup, - /// Equivalent sort expressions for this table. - pub oeq_class: OrderingEquivalenceClass, - /// Expressions whose values are constant throughout the table. + /// Distinct equivalence classes (exprs known to have the same expressions) + eq_group: EquivalenceGroup, + /// Equivalent sort expressions + oeq_class: OrderingEquivalenceClass, + /// Expressions whose values are constant + /// /// TODO: We do not need to track constants separately, they can be tracked /// inside `eq_groups` as `Literal` expressions. - pub constants: Vec, + constants: Vec, /// Schema associated with this object. schema: SchemaRef, } diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 521008ce9b02..c61a1f0ae5da 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -260,7 +260,9 @@ impl MemoryExec { ProjectionMapping::try_new(&proj_exprs, &self.original_schema())?; sort_information = base_eqp .project(&projection_mapping, self.schema()) - .oeq_class + .oeq_class() + // TODO add a take / into to avoid the clone + .clone() .orderings; } diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 6e768a3d87bc..cfa919425c54 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -843,9 +843,9 @@ mod tests { ) { // Check whether orderings are same. 
let lhs_orderings = lhs.oeq_class(); - let rhs_orderings = &rhs.oeq_class.orderings; + let rhs_orderings = rhs.oeq_class(); assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{}", err_msg); - for rhs_ordering in rhs_orderings { + for rhs_ordering in rhs_orderings.iter() { assert!(lhs_orderings.contains(rhs_ordering), "{}", err_msg); } } From b0bd899297c9d2c3a0f77bf8cc83f3e2c45995fd Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 8 Jan 2025 09:05:30 -0500 Subject: [PATCH 02/19] Encapsulate fields of `EquivalenceGroup` (#14039) --- .../tests/fuzz_cases/equivalence/utils.rs | 6 ++--- .../physical-expr/src/equivalence/class.rs | 25 ++++++++++++++----- .../physical-expr/src/equivalence/mod.rs | 6 ++--- .../src/equivalence/properties.rs | 5 ++-- 4 files changed, 27 insertions(+), 15 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs index f8d0ea8200c3..4d599879df67 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs @@ -179,7 +179,7 @@ fn add_equal_conditions_test() -> Result<()> { // This new entry is redundant, size shouldn't increase eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; + let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 2); assert!(eq_groups.contains(&col_a_expr)); assert!(eq_groups.contains(&col_b_expr)); @@ -188,7 +188,7 @@ fn add_equal_conditions_test() -> Result<()> { // however there shouldn't be any new equivalence class eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; + let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 3); assert!(eq_groups.contains(&col_a_expr)); assert!(eq_groups.contains(&col_b_expr)); @@ -202,7 +202,7 @@ fn add_equal_conditions_test() -> Result<()> { // Hence equivalent class count should decrease from 2 to 1. eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; + let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 5); assert!(eq_groups.contains(&col_a_expr)); assert!(eq_groups.contains(&col_b_expr)); diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 9e535a94eb6e..cb11409479a8 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -20,12 +20,12 @@ use crate::{ expressions::Column, LexOrdering, LexRequirement, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement, }; -use std::fmt::Display; -use std::sync::Arc; - use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{JoinType, ScalarValue}; use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; +use std::fmt::Display; +use std::sync::Arc; +use std::vec::IntoIter; use indexmap::{IndexMap, IndexSet}; @@ -323,11 +323,10 @@ impl Display for EquivalenceClass { } } -/// An `EquivalenceGroup` is a collection of `EquivalenceClass`es where each -/// class represents a distinct equivalence class in a relation. 
+/// A collection of distinct `EquivalenceClass`es #[derive(Debug, Clone)] pub struct EquivalenceGroup { - pub classes: Vec, + classes: Vec, } impl EquivalenceGroup { @@ -717,6 +716,20 @@ impl EquivalenceGroup { .zip(right_children) .all(|(left_child, right_child)| self.exprs_equal(left_child, right_child)) } + + /// Return the inner classes of this equivalence group. + pub fn into_inner(self) -> Vec { + self.classes + } +} + +impl IntoIterator for EquivalenceGroup { + type Item = EquivalenceClass; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.classes.into_iter() + } } impl Display for EquivalenceGroup { diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index d4c14f7bc8ff..b50633d777f7 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -254,7 +254,7 @@ mod tests { // This new entry is redundant, size shouldn't increase eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; + let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 2); assert!(eq_groups.contains(&col_a_expr)); assert!(eq_groups.contains(&col_b_expr)); @@ -263,7 +263,7 @@ mod tests { // however there shouldn't be any new equivalence class eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; + let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 3); assert!(eq_groups.contains(&col_a_expr)); assert!(eq_groups.contains(&col_b_expr)); @@ -277,7 +277,7 @@ mod tests { // Hence equivalent class count should decrease from 2 to 1. 
eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; + let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 5); assert!(eq_groups.contains(&col_a_expr)); assert!(eq_groups.contains(&col_b_expr)); diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index d2eeccda2cae..8c4d77a44c2d 100755 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -338,7 +338,6 @@ impl EquivalenceProperties { let normalized_expr = self.eq_group().normalize_expr(Arc::clone(expr)); let eq_class = self .eq_group - .classes .iter() .find_map(|class| { class @@ -1234,7 +1233,7 @@ impl EquivalenceProperties { // Rewrite equivalence classes according to the new schema: let mut eq_classes = vec![]; - for eq_class in self.eq_group.classes { + for eq_class in self.eq_group { let new_eq_exprs = eq_class .into_vec() .into_iter() @@ -2315,7 +2314,7 @@ mod tests { // At the output a1=a2=a3=a4 assert_eq!(out_properties.eq_group().len(), 1); - let eq_class = &out_properties.eq_group().classes[0]; + let eq_class = out_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_class.len(), 4); assert!(eq_class.contains(col_a1)); assert!(eq_class.contains(col_a2)); From 05b6f93d89f831008202c16c5c5b83c37cc88d8a Mon Sep 17 00:00:00 2001 From: cht42 <42912042+cht42@users.noreply.github.com> Date: Wed, 8 Jan 2025 20:11:49 +0400 Subject: [PATCH 03/19] Fix error on `array_distinct` when input is empty #13810 (#14034) * fix * add test * oops --------- Co-authored-by: Cyprien Huet --- datafusion/functions-nested/src/set_ops.rs | 5 ++++- datafusion/sqllogictest/test_files/array.slt | 5 +++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/datafusion/functions-nested/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs index 70d69170fc1e..079e0e3ed214 100644 --- a/datafusion/functions-nested/src/set_ops.rs +++ b/datafusion/functions-nested/src/set_ops.rs @@ -513,7 +513,7 @@ fn general_array_distinct( array: &GenericListArray, field: &FieldRef, ) -> Result { - if array.len() == 0 { + if array.is_empty() { return Ok(Arc::new(array.clone()) as ArrayRef); } let dt = array.value_type(); @@ -542,6 +542,9 @@ fn general_array_distinct( }; new_arrays.push(array); } + if new_arrays.is_empty() { + return Ok(Arc::new(array.clone()) as ArrayRef); + } let offsets = OffsetBuffer::new(offsets.into()); let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); let values = compute::concat(&new_arrays_ref)?; diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 90003b28572a..83f228b90115 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -5681,6 +5681,11 @@ select array_distinct(a) from values ([1, 2, 3]), (null), ([1, 3, 1]) as X(a); NULL [1, 3] +query ? +select array_distinct(arrow_cast(null, 'LargeList(Int64)')); +---- +NULL + query ? 
select array_distinct([]); ---- From 7af6aa9e51fde2bbf7a12c4096e999041ef6ffc6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 8 Jan 2025 11:44:22 -0500 Subject: [PATCH 04/19] Update petgraph requirement from 0.6.2 to 0.7.1 (#14045) * Update petgraph requirement from 0.6.2 to 0.7.1 Updates the requirements on [petgraph](https://github.com/petgraph/petgraph) to permit the latest version. - [Changelog](https://github.com/petgraph/petgraph/blob/master/RELEASES.rst) - [Commits](https://github.com/petgraph/petgraph/compare/petgraph@v0.6.2...petgraph@v0.7.1) --- updated-dependencies: - dependency-name: petgraph dependency-type: direct:production ... Signed-off-by: dependabot[bot] * Update datafusion-cli/Cargo.lock --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Andrew Lamb --- datafusion-cli/Cargo.lock | 8 ++++---- datafusion/physical-expr/Cargo.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index e8b6e99d2dc4..dfc75f15b03b 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1789,9 +1789,9 @@ dependencies = [ [[package]] name = "fixedbitset" -version = "0.4.2" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" [[package]] name = "flatbuffers" @@ -2954,9 +2954,9 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "petgraph" -version = "0.6.5" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" +checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ "fixedbitset", "indexmap", diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index d1de63a1e8fc..5e0832673697 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -52,7 +52,7 @@ indexmap = { workspace = true } itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } paste = "^1.0" -petgraph = "0.6.2" +petgraph = "0.7.1" [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } From c7e8858ac619890bff6f57ccca5fb4248a61cbdc Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 8 Jan 2025 13:43:08 -0500 Subject: [PATCH 05/19] Encapsulate fields of `OrderingEquivalenceClass` (make field non pub) (#14037) * Complete encapsulatug `OrderingEquivalenceClass` (make fields non pub) * fix doc --- .../physical-expr/src/equivalence/ordering.rs | 32 +++++-- .../src/equivalence/projection.rs | 6 +- .../src/equivalence/properties.rs | 92 +++++++++---------- datafusion/physical-plan/src/memory.rs | 6 +- 4 files changed, 72 insertions(+), 64 deletions(-) diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index 0ae5f4af8f08..5dfa1b08f366 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -39,7 +39,7 @@ use arrow_schema::SortOptions; /// ordering. In this case, we say that these orderings are equivalent. 
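///
/// For example (an illustrative sketch, not taken from the upstream docs; the
/// `col_a` / `col_b` expressions are assumed `Arc<dyn PhysicalExpr>` columns),
/// two such alternative orderings can be registered together and read back
/// through the accessor methods rather than the former public field:
///
/// ```text
/// let a_asc = PhysicalSortExpr::new_default(col_a);
/// let b_asc = PhysicalSortExpr::new_default(col_b);
/// let oeq = OrderingEquivalenceClass::new(vec![
///     LexOrdering::from(vec![a_asc]),
///     LexOrdering::from(vec![b_asc]),
/// ]);
/// // Both orderings describe the same table and are reported by `iter()`.
/// for ordering in oeq.iter() {
///     println!("{ordering:?}");
/// }
/// ```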
#[derive(Debug, Clone, Eq, PartialEq, Hash, Default)] pub struct OrderingEquivalenceClass { - pub orderings: Vec, + orderings: Vec, } impl OrderingEquivalenceClass { @@ -53,13 +53,20 @@ impl OrderingEquivalenceClass { self.orderings.clear(); } - /// Creates new ordering equivalence class from the given orderings. + /// Creates new ordering equivalence class from the given orderings + /// + /// Any redundant entries are removed pub fn new(orderings: Vec) -> Self { let mut result = Self { orderings }; result.remove_redundant_entries(); result } + /// Converts this OrderingEquivalenceClass to a vector of orderings. + pub fn into_inner(self) -> Vec { + self.orderings + } + /// Checks whether `ordering` is a member of this equivalence class. pub fn contains(&self, ordering: &LexOrdering) -> bool { self.orderings.contains(ordering) @@ -67,10 +74,12 @@ impl OrderingEquivalenceClass { /// Adds `ordering` to this equivalence class. #[allow(dead_code)] + #[deprecated( + since = "45.0.0", + note = "use OrderingEquivalenceClass::add_new_ordering instead" + )] fn push(&mut self, ordering: LexOrdering) { - self.orderings.push(ordering); - // Make sure that there are no redundant orderings: - self.remove_redundant_entries(); + self.add_new_ordering(ordering) } /// Checks whether this ordering equivalence class is empty. @@ -79,6 +88,9 @@ impl OrderingEquivalenceClass { } /// Returns an iterator over the equivalent orderings in this class. + /// + /// Note this class also implements [`IntoIterator`] to return an iterator + /// over owned [`LexOrdering`]s. pub fn iter(&self) -> impl Iterator { self.orderings.iter() } @@ -95,7 +107,7 @@ impl OrderingEquivalenceClass { self.remove_redundant_entries(); } - /// Adds new orderings into this ordering equivalence class. + /// Adds new orderings into this ordering equivalence class pub fn add_new_orderings( &mut self, orderings: impl IntoIterator, @@ -110,9 +122,10 @@ impl OrderingEquivalenceClass { self.add_new_orderings([ordering]); } - /// Removes redundant orderings from this equivalence class. For instance, - /// if we already have the ordering `[a ASC, b ASC, c DESC]`, then there is - /// no need to keep ordering `[a ASC, b ASC]` in the state. + /// Removes redundant orderings from this equivalence class. + /// + /// For instance, if we already have the ordering `[a ASC, b ASC, c DESC]`, + /// then there is no need to keep ordering `[a ASC, b ASC]` in the state. 
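///
/// A sketch of this behaviour through the public API (illustrative only; the
/// sort expressions `a_asc`, `b_asc` and `c_desc` are assumed to already exist):
///
/// ```text
/// // `new` invokes this method, so the redundant prefix ordering is dropped:
/// let oeq = OrderingEquivalenceClass::new(vec![
///     LexOrdering::from(vec![a_asc.clone(), b_asc.clone(), c_desc]),
///     LexOrdering::from(vec![a_asc, b_asc]),
/// ]);
/// assert_eq!(oeq.len(), 1);
/// ```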
fn remove_redundant_entries(&mut self) { let mut work = true; while work { @@ -198,6 +211,7 @@ impl OrderingEquivalenceClass { } } +/// Convert the `OrderingEquivalenceClass` into an iterator of LexOrderings impl IntoIterator for OrderingEquivalenceClass { type Item = LexOrdering; type IntoIter = IntoIter; diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index 25a05a2a5918..681484fd6bff 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -636,7 +636,7 @@ mod tests { let err_msg = format!( "test_idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", - idx, orderings.orderings, expected, projection_mapping + idx, orderings, expected, projection_mapping ); assert_eq!(orderings.len(), expected.len(), "{}", err_msg); @@ -825,7 +825,7 @@ mod tests { let err_msg = format!( "test idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", - idx, orderings.orderings, expected, projection_mapping + idx, orderings, expected, projection_mapping ); assert_eq!(orderings.len(), expected.len(), "{}", err_msg); @@ -971,7 +971,7 @@ mod tests { let err_msg = format!( "actual: {:?}, expected: {:?}, projection_mapping: {:?}", - orderings.orderings, expected, projection_mapping + orderings, expected, projection_mapping ); assert_eq!(orderings.len(), expected.len(), "{}", err_msg); diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 8c4d77a44c2d..4f440416c457 100755 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -131,7 +131,7 @@ pub struct EquivalenceProperties { /// Expressions whose values are constant /// /// TODO: We do not need to track constants separately, they can be tracked - /// inside `eq_groups` as `Literal` expressions. + /// inside `eq_group` as `Literal` expressions. constants: Vec, /// Schema associated with this object. schema: SchemaRef, @@ -168,6 +168,11 @@ impl EquivalenceProperties { &self.oeq_class } + /// Return the inner OrderingEquivalenceClass, consuming self + pub fn into_oeq_class(self) -> OrderingEquivalenceClass { + self.oeq_class + } + /// Returns a reference to the equivalence group within. 
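///
/// For example (an illustrative sketch; `eq_properties` is assumed to be an
/// existing `EquivalenceProperties` instance), callers now inspect the group
/// through this accessor instead of reading the former public field:
///
/// ```text
/// for class in eq_properties.eq_group().iter() {
///     println!("equivalence class of size {}", class.len());
/// }
/// ```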
pub fn eq_group(&self) -> &EquivalenceGroup { &self.eq_group @@ -430,8 +435,8 @@ impl EquivalenceProperties { let mut new_orderings = vec![filtered_exprs.clone()]; // Preserve valid suffixes from existing orderings - let orderings = mem::take(&mut self.oeq_class.orderings); - for existing in orderings { + let oeq_class = mem::take(&mut self.oeq_class); + for existing in oeq_class { if self.is_prefix_of(&filtered_exprs, &existing) { let mut extended = filtered_exprs.clone(); extended.extend(existing.into_iter().skip(filtered_exprs.len())); @@ -710,8 +715,8 @@ impl EquivalenceProperties { /// Since it would cause bug in dependency constructions, we should substitute the input order in order to get correct /// dependency map, happen in issue 8838: pub fn substitute_oeq_class(&mut self, mapping: &ProjectionMapping) -> Result<()> { - let orderings = &self.oeq_class.orderings; - let new_order = orderings + let new_order = self + .oeq_class .iter() .map(|order| self.substitute_ordering_component(mapping, order)) .collect::>>()?; @@ -1219,7 +1224,7 @@ impl EquivalenceProperties { // Rewrite orderings according to new schema: let mut new_orderings = vec![]; - for ordering in self.oeq_class.orderings { + for ordering in self.oeq_class { let new_ordering = ordering .inner .into_iter() @@ -2008,16 +2013,8 @@ fn calculate_union_binary( // Next, calculate valid orderings for the union by searching for prefixes // in both sides. let mut orderings = UnionEquivalentOrderingBuilder::new(); - orderings.add_satisfied_orderings( - lhs.normalized_oeq_class().orderings, - lhs.constants(), - &rhs, - ); - orderings.add_satisfied_orderings( - rhs.normalized_oeq_class().orderings, - rhs.constants(), - &lhs, - ); + orderings.add_satisfied_orderings(lhs.normalized_oeq_class(), lhs.constants(), &rhs); + orderings.add_satisfied_orderings(rhs.normalized_oeq_class(), rhs.constants(), &lhs); let orderings = orderings.build(); let mut eq_properties = @@ -2156,7 +2153,7 @@ impl UnionEquivalentOrderingBuilder { // for each equivalent ordering in properties, try and augment // `ordering` it with the constants to match - for existing_ordering in &properties.oeq_class.orderings { + for existing_ordering in properties.oeq_class.iter() { if let Some(augmented_ordering) = self.augment_ordering( ordering, constants, @@ -2437,17 +2434,12 @@ mod tests { Some(JoinSide::Left), &[], ); - let orderings = &join_eq.oeq_class.orderings; - let err_msg = format!("expected: {:?}, actual:{:?}", expected, orderings); - assert_eq!( - join_eq.oeq_class.orderings.len(), - expected.len(), - "{}", - err_msg - ); - for ordering in orderings { + let err_msg = + format!("expected: {:?}, actual:{:?}", expected, &join_eq.oeq_class); + assert_eq!(join_eq.oeq_class.len(), expected.len(), "{}", err_msg); + for ordering in join_eq.oeq_class { assert!( - expected.contains(ordering), + expected.contains(&ordering), "{}, ordering: {:?}", err_msg, ordering @@ -3766,8 +3758,8 @@ mod tests { // Check whether orderings are same. 
let lhs_orderings = lhs.oeq_class(); - let rhs_orderings = &rhs.oeq_class.orderings; - for rhs_ordering in rhs_orderings { + let rhs_orderings = rhs.oeq_class(); + for rhs_ordering in rhs_orderings.iter() { assert!( lhs_orderings.contains(rhs_ordering), "{err_msg}\nlhs: {lhs}\nrhs: {rhs}" @@ -3843,7 +3835,7 @@ mod tests { // Add equality condition c = concat(a, b) eq_properties.add_equal_conditions(&col_c, &a_concat_b)?; - let orderings = eq_properties.oeq_class().orderings.clone(); + let orderings = eq_properties.oeq_class(); let expected_ordering1 = LexOrdering::from(vec![ @@ -3894,7 +3886,7 @@ mod tests { // Add equality condition c = a * b eq_properties.add_equal_conditions(&col_c, &a_times_b)?; - let orderings = eq_properties.oeq_class().orderings.clone(); + let orderings = eq_properties.oeq_class(); // The ordering should remain unchanged since multiplication is not lex-monotonic assert_eq!(orderings.len(), 1); @@ -3934,7 +3926,7 @@ mod tests { // Add equality condition c = concat(a, b) eq_properties.add_equal_conditions(&col_c, &a_concat_b)?; - let orderings = eq_properties.oeq_class().orderings.clone(); + let orderings = eq_properties.oeq_class(); let expected_ordering1 = LexOrdering::from(vec![PhysicalSortExpr::new_default( Arc::clone(&a_concat_b), @@ -3978,8 +3970,9 @@ mod tests { // Should only contain b since a is constant assert_eq!(result.oeq_class().len(), 1); - assert_eq!(result.oeq_class().orderings[0].len(), 1); - assert!(result.oeq_class().orderings[0][0].expr.eq(&col_b)); + let ordering = result.oeq_class().iter().next().unwrap(); + assert_eq!(ordering.len(), 1); + assert!(ordering[0].expr.eq(&col_b)); Ok(()) } @@ -4025,13 +4018,14 @@ mod tests { // Should only contain [a ASC, b DESC, c ASC] assert_eq!(result.oeq_class().len(), 1); - assert_eq!(result.oeq_class().orderings[0].len(), 3); - assert!(result.oeq_class().orderings[0][0].expr.eq(&col_a)); - assert!(result.oeq_class().orderings[0][0].options.eq(&asc)); - assert!(result.oeq_class().orderings[0][1].expr.eq(&col_b)); - assert!(result.oeq_class().orderings[0][1].options.eq(&desc)); - assert!(result.oeq_class().orderings[0][2].expr.eq(&col_c)); - assert!(result.oeq_class().orderings[0][2].options.eq(&asc)); + let ordering = result.oeq_class().iter().next().unwrap(); + assert_eq!(ordering.len(), 3); + assert!(ordering[0].expr.eq(&col_a)); + assert!(ordering[0].options.eq(&asc)); + assert!(ordering[1].expr.eq(&col_b)); + assert!(ordering[1].options.eq(&desc)); + assert!(ordering[2].expr.eq(&col_c)); + assert!(ordering[2].options.eq(&asc)); Ok(()) } @@ -4074,11 +4068,12 @@ mod tests { assert_eq!(result.oeq_class().len(), 1); // Verify orderings - assert_eq!(result.oeq_class().orderings[0].len(), 2); - assert!(result.oeq_class().orderings[0][0].expr.eq(&col_b)); - assert!(result.oeq_class().orderings[0][0].options.eq(&asc)); - assert!(result.oeq_class().orderings[0][1].expr.eq(&col_c)); - assert!(result.oeq_class().orderings[0][1].options.eq(&asc)); + let ordering = result.oeq_class().iter().next().unwrap(); + assert_eq!(ordering.len(), 2); + assert!(ordering[0].expr.eq(&col_b)); + assert!(ordering[0].options.eq(&asc)); + assert!(ordering[1].expr.eq(&col_c)); + assert!(ordering[1].options.eq(&asc)); Ok(()) } @@ -4119,7 +4114,8 @@ mod tests { // Should only contain the new ordering since options don't match assert_eq!(result.oeq_class().len(), 1); - assert_eq!(result.oeq_class().orderings[0], new_order); + let ordering = result.oeq_class().iter().next().unwrap(); + assert_eq!(ordering, &new_order); Ok(()) } @@ -4177,7 
+4173,7 @@ mod tests { // Should preserve the original [d ASC, a ASC] ordering assert_eq!(result.oeq_class().len(), 1); - let ordering = &result.oeq_class().orderings[0]; + let ordering = result.oeq_class().iter().next().unwrap(); assert_eq!(ordering.len(), 2); // First expression should be either b or d (they're equivalent) diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index c61a1f0ae5da..67ecd44ff317 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -260,10 +260,8 @@ impl MemoryExec { ProjectionMapping::try_new(&proj_exprs, &self.original_schema())?; sort_information = base_eqp .project(&projection_mapping, self.schema()) - .oeq_class() - // TODO add a take / into to avoid the clone - .clone() - .orderings; + .into_oeq_class() + .into_inner(); } self.sort_information = sort_information; From da4208b28842a3482bc340cb79dae79c5b94eb1d Mon Sep 17 00:00:00 2001 From: Tim Van Wassenhove Date: Wed, 8 Jan 2025 20:32:29 +0100 Subject: [PATCH 06/19] Fix: ensure that compression type is also taken into consideration during ListingTableConfig infer_options (#14021) * chore: add test to verify that schema is inferred as expected * chore: add comment to method as suggested * chore: restructure to avoid need to clone * chore: fix flaw in rewrite --- .../core/src/datasource/listing/table.rs | 59 ++++++++++++++++--- 1 file changed, 50 insertions(+), 9 deletions(-) diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 791b15704d09..7470597ef72c 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -114,19 +114,26 @@ impl ListingTableConfig { } } - fn infer_file_extension(path: &str) -> Result { + ///Returns a tupe of (file_extension, optional compression_extension) + /// + /// For example a path ending with blah.test.csv.gz returns `("csv", Some("gz"))` + /// For example a path ending with blah.test.csv returns `("csv", None)` + fn infer_file_extension_and_compression_type( + path: &str, + ) -> Result<(String, Option)> { let mut exts = path.rsplit('.'); - let mut splitted = exts.next().unwrap_or(""); + let splitted = exts.next().unwrap_or(""); let file_compression_type = FileCompressionType::from_str(splitted) .unwrap_or(FileCompressionType::UNCOMPRESSED); if file_compression_type.is_compressed() { - splitted = exts.next().unwrap_or(""); + let splitted2 = exts.next().unwrap_or(""); + Ok((splitted2.to_string(), Some(splitted.to_string()))) + } else { + Ok((splitted.to_string(), None)) } - - Ok(splitted.to_string()) } /// Infer `ListingOptions` based on `table_path` suffix. @@ -147,18 +154,33 @@ impl ListingTableConfig { .await .ok_or_else(|| DataFusionError::Internal("No files for table".into()))??; - let file_extension = - ListingTableConfig::infer_file_extension(file.location.as_ref())?; + let (file_extension, maybe_compression_type) = + ListingTableConfig::infer_file_extension_and_compression_type( + file.location.as_ref(), + )?; + + let mut format_options = HashMap::new(); + if let Some(ref compression_type) = maybe_compression_type { + format_options + .insert("format.compression".to_string(), compression_type.clone()); + } let file_format = state .get_file_format_factory(&file_extension) .ok_or(config_datafusion_err!( "No file_format found with extension {file_extension}" ))? 
- .create(state, &HashMap::new())?; + .create(state, &format_options)?; + + let listing_file_extension = + if let Some(compression_type) = maybe_compression_type { + format!("{}.{}", &file_extension, &compression_type) + } else { + file_extension + }; let listing_options = ListingOptions::new(file_format) - .with_file_extension(file_extension) + .with_file_extension(listing_file_extension) .with_target_partitions(state.config().target_partitions()); Ok(Self { @@ -2194,4 +2216,23 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_infer_options_compressed_csv() -> Result<()> { + let testdata = crate::test_util::arrow_test_data(); + let filename = format!("{}/csv/aggregate_test_100.csv.gz", testdata); + let table_path = ListingTableUrl::parse(filename).unwrap(); + + let ctx = SessionContext::new(); + + let config = ListingTableConfig::new(table_path); + let config_with_opts = config.infer_options(&ctx.state()).await?; + let config_with_schema = config_with_opts.infer_schema(&ctx.state()).await?; + + let schema = config_with_schema.file_schema.unwrap(); + + assert_eq!(schema.fields.len(), 13); + + Ok(()) + } } From ad5a04f47f67090b6acb2ddd46b5f15200e85912 Mon Sep 17 00:00:00 2001 From: nuno-faria Date: Wed, 8 Jan 2025 19:34:52 +0000 Subject: [PATCH 07/19] feat(optimizer): Enable filter pushdown on window functions (#14026) * feat(optimizer): Enable filter pushdown on window functions Ensures selections can be pushed past window functions similarly to what is already done with aggregations, when possible. * fix: Add missing dependency * minor(optimizer): Use 'datafusion-functions-window' as a dev dependency * docs(optimizer): Add example to filter pushdown on LogicalPlan::Window --- datafusion/optimizer/Cargo.toml | 1 + datafusion/optimizer/src/push_down_filter.rs | 310 +++++++++++++++++- datafusion/sqllogictest/test_files/window.slt | 293 +++++++++++++++++ 3 files changed, 600 insertions(+), 4 deletions(-) diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index ba0dedc57675..3f5ec9b0da03 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -55,6 +55,7 @@ regex-syntax = "0.8.0" async-trait = { workspace = true } ctor = { workspace = true } datafusion-functions-aggregate = { workspace = true } +datafusion-functions-window = { workspace = true } datafusion-functions-window-common = { workspace = true } datafusion-sql = { workspace = true } env_logger = { workspace = true } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 0b328ad39f55..534209c4e11b 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -985,6 +985,87 @@ impl OptimizerRule for PushDownFilter { } }) } + // Tries to push filters based on the partition key(s) of the window function(s) used. + // Example: + // Before: + // Filter: (a > 1) and (b > 1) and (c > 1) + // Window: func() PARTITION BY [a] ... + // --- + // After: + // Filter: (b > 1) and (c > 1) + // Window: func() PARTITION BY [a] ... + // Filter: (a > 1) + LogicalPlan::Window(window) => { + // Retrieve the set of potential partition keys where we can push filters by. + // Unlike aggregations, where there is only one statement per SELECT, there can be + // multiple window functions, each with potentially different partition keys. + // Therefore, we need to ensure that any potential partition key returned is used in + // ALL window functions. 
Otherwise, filters cannot be pushed by through that column. + let potential_partition_keys = window + .window_expr + .iter() + .map(|e| { + if let Expr::WindowFunction(window_expression) = e { + window_expression + .partition_by + .iter() + .map(|c| { + Column::from_qualified_name( + c.schema_name().to_string(), + ) + }) + .collect::>() + } else { + // window functions expressions are only Expr::WindowFunction + unreachable!() + } + }) + // performs the set intersection of the partition keys of all window functions, + // returning only the common ones + .reduce(|a, b| &a & &b) + .unwrap_or_default(); + + let predicates = split_conjunction_owned(filter.predicate); + let mut keep_predicates = vec![]; + let mut push_predicates = vec![]; + for expr in predicates { + let cols = expr.column_refs(); + if cols.iter().all(|c| potential_partition_keys.contains(c)) { + push_predicates.push(expr); + } else { + keep_predicates.push(expr); + } + } + + // Unlike with aggregations, there are no cases where we have to replace, e.g., + // `a+b` with Column(a)+Column(b). This is because partition expressions are not + // available as standalone columns to the user. For example, while an aggregation on + // `a+b` becomes Column(a + b), in a window partition it becomes + // `func() PARTITION BY [a + b] ...`. Thus, filters on expressions always remain in + // place, so we can use `push_predicates` directly. This is consistent with other + // optimizers, such as the one used by Postgres. + + let window_input = Arc::clone(&window.input); + Transformed::yes(LogicalPlan::Window(window)) + .transform_data(|new_plan| { + // If we have a filter to push, we push it down to the input of the window + if let Some(predicate) = conjunction(push_predicates) { + let new_filter = make_filter(predicate, window_input)?; + insert_below(new_plan, new_filter) + } else { + Ok(Transformed::no(new_plan)) + } + })? 
+ .map_data(|child_plan| { + // if there are any remaining predicates we can't push, add them + // back as a filter + if let Some(predicate) = conjunction(keep_predicates) { + make_filter(predicate, Arc::new(child_plan)) + } else { + Ok(child_plan) + } + }) + } LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)), LogicalPlan::TableScan(scan) => { let filter_predicates = split_conjunction(&filter.predicate); @@ -1289,12 +1370,12 @@ mod tests { use async_trait::async_trait; use datafusion_common::{DFSchemaRef, ScalarValue}; - use datafusion_expr::expr::ScalarFunction; + use datafusion_expr::expr::{ScalarFunction, WindowFunction}; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ - col, in_list, in_subquery, lit, ColumnarValue, Extension, LogicalPlanBuilder, - ScalarUDF, ScalarUDFImpl, Signature, TableSource, TableType, - UserDefinedLogicalNodeCore, Volatility, + col, in_list, in_subquery, lit, ColumnarValue, ExprFunctionExt, Extension, + LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, TableSource, TableType, + UserDefinedLogicalNodeCore, Volatility, WindowFunctionDefinition, }; use crate::optimizer::Optimizer; @@ -1442,6 +1523,227 @@ mod tests { assert_optimized_plan_eq(plan, expected) } + /// verifies that when partitioning by 'a' and 'b', and filtering by 'b', 'b' is pushed + #[test] + fn filter_move_window() -> Result<()> { + let table_scan = test_table_scan()?; + + let window = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::rank::rank_udwf(), + ), + vec![], + )) + .partition_by(vec![col("a"), col("b")]) + .order_by(vec![col("c").sort(true, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(table_scan) + .window(vec![window])? + .filter(col("b").gt(lit(10i64)))? + .build()?; + + let expected = "\ + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n TableScan: test, full_filters=[test.b > Int64(10)]"; + assert_optimized_plan_eq(plan, expected) + } + + /// verifies that when partitioning by 'a' and 'b', and filtering by 'a' and 'b', both 'a' and + /// 'b' are pushed + #[test] + fn filter_move_complex_window() -> Result<()> { + let table_scan = test_table_scan()?; + + let window = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::rank::rank_udwf(), + ), + vec![], + )) + .partition_by(vec![col("a"), col("b")]) + .order_by(vec![col("c").sort(true, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(table_scan) + .window(vec![window])? + .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))? 
+ .build()?; + + let expected = "\ + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]"; + assert_optimized_plan_eq(plan, expected) + } + + /// verifies that when partitioning by 'a' and filtering by 'a' and 'b', only 'a' is pushed + #[test] + fn filter_move_partial_window() -> Result<()> { + let table_scan = test_table_scan()?; + + let window = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::rank::rank_udwf(), + ), + vec![], + )) + .partition_by(vec![col("a")]) + .order_by(vec![col("c").sort(true, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(table_scan) + .window(vec![window])? + .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))? + .build()?; + + let expected = "\ + Filter: test.b = Int64(1)\ + \n WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n TableScan: test, full_filters=[test.a > Int64(10)]"; + assert_optimized_plan_eq(plan, expected) + } + + /// verifies that filters on partition expressions are not pushed, as the single expression + /// column is not available to the user, unlike with aggregations + #[test] + fn filter_expression_keep_window() -> Result<()> { + let table_scan = test_table_scan()?; + + let window = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::rank::rank_udwf(), + ), + vec![], + )) + .partition_by(vec![add(col("a"), col("b"))]) // PARTITION BY a + b + .order_by(vec![col("c").sort(true, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(table_scan) + .window(vec![window])? + // unlike with aggregations, single partition column "test.a + test.b" is not available + // to the plan, so we use multiple columns when filtering + .filter(add(col("a"), col("b")).gt(lit(10i64)))? + .build()?; + + let expected = "\ + Filter: test.a + test.b > Int64(10)\ + \n WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n TableScan: test"; + assert_optimized_plan_eq(plan, expected) + } + + /// verifies that filters are not pushed on order by columns (that are not used in partitioning) + #[test] + fn filter_order_keep_window() -> Result<()> { + let table_scan = test_table_scan()?; + + let window = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::rank::rank_udwf(), + ), + vec![], + )) + .partition_by(vec![col("a")]) + .order_by(vec![col("c").sort(true, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(table_scan) + .window(vec![window])? + .filter(col("c").gt(lit(10i64)))? 
+ .build()?; + + let expected = "\ + Filter: test.c > Int64(10)\ + \n WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n TableScan: test"; + assert_optimized_plan_eq(plan, expected) + } + + /// verifies that when we use multiple window functions with a common partition key, the filter + /// on that key is pushed + #[test] + fn filter_multiple_windows_common_partitions() -> Result<()> { + let table_scan = test_table_scan()?; + + let window1 = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::rank::rank_udwf(), + ), + vec![], + )) + .partition_by(vec![col("a")]) + .order_by(vec![col("c").sort(true, true)]) + .build() + .unwrap(); + + let window2 = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::rank::rank_udwf(), + ), + vec![], + )) + .partition_by(vec![col("b"), col("a")]) + .order_by(vec![col("c").sort(true, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(table_scan) + .window(vec![window1, window2])? + .filter(col("a").gt(lit(10i64)))? // a appears in both window functions + .build()?; + + let expected = "\ + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n TableScan: test, full_filters=[test.a > Int64(10)]"; + assert_optimized_plan_eq(plan, expected) + } + + /// verifies that when we use multiple window functions with different partitions keys, the + /// filter cannot be pushed + #[test] + fn filter_multiple_windows_disjoint_partitions() -> Result<()> { + let table_scan = test_table_scan()?; + + let window1 = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::rank::rank_udwf(), + ), + vec![], + )) + .partition_by(vec![col("a")]) + .order_by(vec![col("c").sort(true, true)]) + .build() + .unwrap(); + + let window2 = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::rank::rank_udwf(), + ), + vec![], + )) + .partition_by(vec![col("b"), col("a")]) + .order_by(vec![col("c").sort(true, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(table_scan) + .window(vec![window1, window2])? + .filter(col("b").gt(lit(10i64)))? 
// b only appears in one window function + .build()?; + + let expected = "\ + Filter: test.b > Int64(10)\ + \n WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n TableScan: test"; + assert_optimized_plan_eq(plan, expected) + } + /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written #[test] fn alias() -> Result<()> { diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 56f088dfd10f..3c6f0f6deba1 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -5127,3 +5127,296 @@ order by id; statement ok drop table t1; + +# test filters that will be pushed into windows +statement count 0 +create table t1 (c1 int, c2 int) as values + (1, 1), + (1, 10), + (2, 1), + (2, 10), + (3, 1), + (3, 10); + +# base query +query III +select c1, c2, rank() over (partition by c1 order by c2) as rank +from t1 +order by c1; +---- +1 1 1 +1 10 2 +2 1 1 +2 10 2 +3 1 1 +3 10 2 + +# push filter since it uses a partition column +query TT +explain select c1, c2, rank +from ( + select c1, c2, rank() over (partition by c1 order by c2) as rank + from t1 +) +where c1 = 2 or c1 = 3 +order by c1, c2, rank; +---- +logical_plan +01)Sort: t1.c1 ASC NULLS LAST, t1.c2 ASC NULLS LAST, rank ASC NULLS LAST +02)--Projection: t1.c1, t1.c2, rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rank +03)----WindowAggr: windowExpr=[[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------Filter: t1.c1 = Int32(2) OR t1.c1 = Int32(3) +05)--------TableScan: t1 projection=[c1, c2] +physical_plan +01)SortPreservingMergeExec: [c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST, rank@2 ASC NULLS LAST] +02)--ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rank] +03)----BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[true] +05)--------CoalesceBatchesExec: target_batch_size=1 +06)----------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 +07)------------CoalesceBatchesExec: target_batch_size=1 +08)--------------FilterExec: c1@0 = 2 OR c1@0 = 3 +09)----------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +10)------------------MemoryExec: partitions=1, partition_sizes=[1] + +query III +select c1, c2, rank +from ( + select c1, c2, rank() over (partition by c1 order by c2) as rank + from t1 +) +where c1 = 2 or c1 = 3 +order by c1, c2, rank; +---- +2 1 1 +2 10 2 +3 1 1 +3 10 2 + +# do not push filter since it does not use a partition column +query TT +explain select c1, c2, rank +from ( + 
select c1, c2, rank() over (partition by c1 order by c2) as rank + from t1 +) +where c2 >= 10 +order by c1, c2, rank; +---- +logical_plan +01)Sort: t1.c1 ASC NULLS LAST, t1.c2 ASC NULLS LAST, rank ASC NULLS LAST +02)--Projection: t1.c1, t1.c2, rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rank +03)----Filter: t1.c2 >= Int32(10) +04)------WindowAggr: windowExpr=[[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +05)--------TableScan: t1 projection=[c1, c2] +physical_plan +01)SortPreservingMergeExec: [c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST, rank@2 ASC NULLS LAST] +02)--ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rank] +03)----CoalesceBatchesExec: target_batch_size=1 +04)------FilterExec: c2@1 >= 10 +05)--------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +06)----------SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[true] +07)------------CoalesceBatchesExec: target_batch_size=1 +08)--------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 +09)----------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +10)------------------MemoryExec: partitions=1, partition_sizes=[1] + +query III +select c1, c2, rank +from ( + select c1, c2, rank() over (partition by c1 order by c2) as rank + from t1 +) +where c2 >= 10 +order by c1, c2, rank; +---- +1 10 2 +2 10 2 +3 10 2 + +# push partial multi column filter (conjunction) +query TT +explain select c1, c2, rank +from ( + select c1, c2, rank() over (partition by c1 order by c2) as rank + from t1 +) +where c1 = 1 and c2 = 10 +order by c1, c2, rank; +---- +logical_plan +01)Sort: t1.c1 ASC NULLS LAST, t1.c2 ASC NULLS LAST, rank ASC NULLS LAST +02)--Projection: t1.c1, t1.c2, rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rank +03)----Filter: t1.c2 = Int32(10) +04)------WindowAggr: windowExpr=[[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +05)--------Filter: t1.c1 = Int32(1) +06)----------TableScan: t1 projection=[c1, c2] +physical_plan +01)SortPreservingMergeExec: [c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST, rank@2 ASC NULLS LAST] +02)--ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rank] +03)----CoalesceBatchesExec: target_batch_size=1 +04)------FilterExec: c2@1 = 10 +05)--------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, 
start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +06)----------SortExec: expr=[c2@1 ASC NULLS LAST], preserve_partitioning=[true] +07)------------CoalesceBatchesExec: target_batch_size=1 +08)--------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 +09)----------------CoalesceBatchesExec: target_batch_size=1 +10)------------------FilterExec: c1@0 = 1 +11)--------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +12)----------------------MemoryExec: partitions=1, partition_sizes=[1] + +query III +select c1, c2, rank +from ( + select c1, c2, rank() over (partition by c1 order by c2) as rank + from t1 +) +where c1 = 1 and c2 = 10 +order by c1, c2, rank; +---- +1 10 2 + +# do not push multi column filter (disjunction) +query TT +explain select c1, c2, rank +from ( + select c1, c2, rank() over (partition by c1 order by c2) as rank + from t1 +) +where c1 = 1 or c2 = 10 +order by c1, c2, rank; +---- +logical_plan +01)Sort: t1.c1 ASC NULLS LAST, t1.c2 ASC NULLS LAST, rank ASC NULLS LAST +02)--Projection: t1.c1, t1.c2, rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rank +03)----Filter: t1.c1 = Int32(1) OR t1.c2 = Int32(10) +04)------WindowAggr: windowExpr=[[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +05)--------TableScan: t1 projection=[c1, c2] +physical_plan +01)SortPreservingMergeExec: [c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST, rank@2 ASC NULLS LAST] +02)--ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rank] +03)----CoalesceBatchesExec: target_batch_size=1 +04)------FilterExec: c1@0 = 1 OR c2@1 = 10 +05)--------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +06)----------SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[true] +07)------------CoalesceBatchesExec: target_batch_size=1 +08)--------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 +09)----------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +10)------------------MemoryExec: partitions=1, partition_sizes=[1] + +query III +select c1, c2, rank +from ( + select c1, c2, rank() over (partition by c1 order by c2) as rank + from t1 +) +where c1 = 1 or c2 = 10 +order by c1, c2, rank; +---- +1 1 1 +1 10 2 +2 10 2 +3 10 2 + +# multiple window functions with a filter in a common key -> filter will be pushed past both windows +query TT +explain select c1, c2, rank1, rank2 +from ( + select c1, c2, rank() over (partition by c1 order by c2) as rank1, + rank() over (partition by c2, c1 order by c1) as rank2 + from t1 +) +where c1 > 1 +order by c1, c2, rank1, rank2; +---- +logical_plan +01)Sort: t1.c1 ASC NULLS LAST, t1.c2 ASC NULLS LAST, rank1 ASC NULLS LAST, rank2 ASC NULLS LAST +02)--Projection: t1.c1, t1.c2, rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED 
PRECEDING AND CURRENT ROW AS rank1, rank() PARTITION BY [t1.c2, t1.c1] ORDER BY [t1.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rank2 +03)----WindowAggr: windowExpr=[[rank() PARTITION BY [t1.c2, t1.c1] ORDER BY [t1.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------WindowAggr: windowExpr=[[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +05)--------Filter: t1.c1 > Int32(1) +06)----------TableScan: t1 projection=[c1, c2] +physical_plan +01)SortPreservingMergeExec: [c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST, rank1@2 ASC NULLS LAST, rank2@3 ASC NULLS LAST] +02)--SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST, rank1@2 ASC NULLS LAST, rank2@3 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rank1, rank() PARTITION BY [t1.c2, t1.c1] ORDER BY [t1.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as rank2] +04)------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c2, t1.c1] ORDER BY [t1.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() PARTITION BY [t1.c2, t1.c1] ORDER BY [t1.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +05)--------SortExec: expr=[c2@1 ASC NULLS LAST, c1@0 ASC NULLS LAST], preserve_partitioning=[true] +06)----------CoalesceBatchesExec: target_batch_size=1 +07)------------RepartitionExec: partitioning=Hash([c2@1, c1@0], 2), input_partitions=2 +08)--------------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +09)----------------SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[true] +10)------------------CoalesceBatchesExec: target_batch_size=1 +11)--------------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 +12)----------------------CoalesceBatchesExec: target_batch_size=1 +13)------------------------FilterExec: c1@0 > 1 +14)--------------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +15)----------------------------MemoryExec: partitions=1, partition_sizes=[1] + +query IIII +select c1, c2, rank1, rank2 +from ( + select c1, c2, rank() over (partition by c1 order by c2) as rank1, + rank() over (partition by c2, c1 order by c1) as rank2 + from t1 +) +where c1 > 1 +order by c1, c2, rank1, rank2; +---- +2 1 1 1 +2 10 2 1 +3 1 1 1 +3 10 2 1 + +# multiple window functions with a filter whose key only appears in one window -> filter will be +# pushed past only one function (in this case, since that window appears above in the plan) +query TT +explain select c1, c2, rank1, rank2 +from ( + select c1, c2, rank() over (partition by c1 order by c2) as rank1, + rank() over 
(partition by c2, c1 order by c1) as rank2 + from t1 +) +where c2 > 1 +order by c1, c2, rank1, rank2; +---- +logical_plan +01)Sort: t1.c1 ASC NULLS LAST, t1.c2 ASC NULLS LAST, rank1 ASC NULLS LAST, rank2 ASC NULLS LAST +02)--Projection: t1.c1, t1.c2, rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rank1, rank() PARTITION BY [t1.c2, t1.c1] ORDER BY [t1.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rank2 +03)----WindowAggr: windowExpr=[[rank() PARTITION BY [t1.c2, t1.c1] ORDER BY [t1.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------Filter: t1.c2 > Int32(1) +05)--------WindowAggr: windowExpr=[[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +06)----------TableScan: t1 projection=[c1, c2] +physical_plan +01)SortPreservingMergeExec: [c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST, rank1@2 ASC NULLS LAST, rank2@3 ASC NULLS LAST] +02)--SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST, rank1@2 ASC NULLS LAST, rank2@3 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rank1, rank() PARTITION BY [t1.c2, t1.c1] ORDER BY [t1.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as rank2] +04)------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c2, t1.c1] ORDER BY [t1.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() PARTITION BY [t1.c2, t1.c1] ORDER BY [t1.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +05)--------SortExec: expr=[c2@1 ASC NULLS LAST, c1@0 ASC NULLS LAST], preserve_partitioning=[true] +06)----------CoalesceBatchesExec: target_batch_size=1 +07)------------RepartitionExec: partitioning=Hash([c2@1, c1@0], 2), input_partitions=2 +08)--------------CoalesceBatchesExec: target_batch_size=1 +09)----------------FilterExec: c2@1 > 1 +10)------------------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +11)--------------------SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[true] +12)----------------------CoalesceBatchesExec: target_batch_size=1 +13)------------------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 +14)--------------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +15)----------------------------MemoryExec: partitions=1, partition_sizes=[1] + +query IIII +select c1, c2, rank1, rank2 +from ( + select c1, c2, rank() over (partition by c1 order by c2) as rank1, + rank() over (partition by c2, c1 order by c1) as rank2 + from t1 +) +where c2 > 1 +order by c1, c2, rank1, rank2; +---- +1 10 2 1 +2 10 2 1 +3 10 2 1 + 
+statement ok +drop table t1; From 5955860e1b12b5fc0042b88341d9154f2eec8cfc Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen <83442793+MohamedAbdeen21@users.noreply.github.com> Date: Thu, 9 Jan 2025 01:01:00 +0000 Subject: [PATCH 08/19] Unparsing optimized (> 2 inputs) unions (#14031) * tests and optimizer in testing queries * unparse optimized unions * format Cargo.toml * format Cargo.toml * revert test * rewrite test to avoid cyclic dep * remove old test * cleanup * comments and error handling * handle union with lt 2 inputs --- datafusion/sql/src/unparser/plan.rs | 29 ++++++------ datafusion/sql/tests/cases/plan_to_sql.rs | 56 +++++++++++++++++++++-- 2 files changed, 68 insertions(+), 17 deletions(-) diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 6f30845eb810..2bad683dc1bc 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -706,13 +706,6 @@ impl Unparser<'_> { Ok(()) } LogicalPlan::Union(union) => { - if union.inputs.len() != 2 { - return not_impl_err!( - "UNION ALL expected 2 inputs, but found {}", - union.inputs.len() - ); - } - // Covers cases where the UNION is a subquery and the projection is at the top level if select.already_projected() { return self.derive_with_dialect_alias( @@ -729,12 +722,22 @@ impl Unparser<'_> { .map(|input| self.select_to_sql_expr(input, query)) .collect::>>()?; - let union_expr = SetExpr::SetOperation { - op: ast::SetOperator::Union, - set_quantifier: ast::SetQuantifier::All, - left: Box::new(input_exprs[0].clone()), - right: Box::new(input_exprs[1].clone()), - }; + if input_exprs.len() < 2 { + return internal_err!("UNION operator requires at least 2 inputs"); + } + + // Build the union expression tree bottom-up by reversing the order + // note that we are also swapping left and right inputs because of the rev + let union_expr = input_exprs + .into_iter() + .rev() + .reduce(|a, b| SetExpr::SetOperation { + op: ast::SetOperator::Union, + set_quantifier: ast::SetQuantifier::All, + left: Box::new(b), + right: Box::new(a), + }) + .unwrap(); let Some(query) = query.as_mut() else { return internal_err!( diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 24ec7f03deb0..94b4df59ef00 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -15,14 +15,14 @@ // specific language governing permissions and limitations // under the License. 
-use arrow_schema::*; +use arrow_schema::{DataType, Field, Schema}; use datafusion_common::{assert_contains, DFSchema, DFSchemaRef, Result, TableReference}; use datafusion_expr::test::function_stub::{ count_udaf, max_udaf, min_udaf, sum, sum_udaf, }; use datafusion_expr::{ - col, lit, table_scan, wildcard, Expr, Extension, LogicalPlan, LogicalPlanBuilder, - UserDefinedLogicalNode, UserDefinedLogicalNodeCore, + col, lit, table_scan, wildcard, EmptyRelation, Expr, Extension, LogicalPlan, + LogicalPlanBuilder, Union, UserDefinedLogicalNode, UserDefinedLogicalNodeCore, }; use datafusion_functions::unicode; use datafusion_functions_aggregate::grouping::grouping_udaf; @@ -42,7 +42,7 @@ use std::{fmt, vec}; use crate::common::{MockContextProvider, MockSessionState}; use datafusion_expr::builder::{ - table_scan_with_filter_and_fetch, table_scan_with_filters, + project, table_scan_with_filter_and_fetch, table_scan_with_filters, }; use datafusion_functions::core::planner::CoreFunctionPlanner; use datafusion_functions_nested::extract::array_element_udf; @@ -1615,3 +1615,51 @@ fn test_unparse_extension_to_sql() -> Result<()> { } Ok(()) } + +#[test] +fn test_unparse_optimized_multi_union() -> Result<()> { + let unparser = Unparser::default(); + + let schema = Schema::new(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, false), + ]); + + let dfschema = Arc::new(DFSchema::try_from(schema)?); + + let empty = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema: dfschema.clone(), + }); + + let plan = LogicalPlan::Union(Union { + inputs: vec![ + project(empty.clone(), vec![lit(1).alias("x"), lit("a").alias("y")])?.into(), + project(empty.clone(), vec![lit(1).alias("x"), lit("b").alias("y")])?.into(), + project(empty.clone(), vec![lit(2).alias("x"), lit("a").alias("y")])?.into(), + project(empty.clone(), vec![lit(2).alias("x"), lit("c").alias("y")])?.into(), + ], + schema: dfschema.clone(), + }); + + let sql = "SELECT 1 AS x, 'a' AS y UNION ALL SELECT 1 AS x, 'b' AS y UNION ALL SELECT 2 AS x, 'a' AS y UNION ALL SELECT 2 AS x, 'c' AS y"; + + assert_eq!(unparser.plan_to_sql(&plan)?.to_string(), sql); + + let plan = LogicalPlan::Union(Union { + inputs: vec![project( + empty.clone(), + vec![lit(1).alias("x"), lit("a").alias("y")], + )? + .into()], + schema: dfschema.clone(), + }); + + if let Some(err) = plan_to_sql(&plan).err() { + assert_contains!(err.to_string(), "UNION operator requires at least 2 inputs"); + } else { + panic!("Expected error") + } + + Ok(()) +} From 61afb0d486b4617ddc79cb01e8765fedd3ae247b Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 8 Jan 2025 21:42:52 -0500 Subject: [PATCH 09/19] Minor: Document output schema of LogicalPlan::Aggregate and LogicalPlan::Window (#14047) --- datafusion/expr/src/logical_plan/plan.rs | 27 ++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 005bf4b98451..24fb0609b0fe 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -215,10 +215,14 @@ pub enum LogicalPlan { /// Windows input based on a set of window spec and window /// function (e.g. SUM or RANK). This is used to implement SQL /// window functions, and the `OVER` clause. + /// + /// See [`Window`] for more details Window(Window), /// Aggregates its input based on a set of grouping and aggregate /// expressions (e.g. SUM). This is used to implement SQL aggregates /// and `GROUP BY`. 
+ /// + /// See [`Aggregate`] for more details Aggregate(Aggregate), /// Sorts its input according to a list of sort expressions. This /// is used to implement SQL `ORDER BY` @@ -2365,6 +2369,19 @@ impl Filter { } /// Window its input based on a set of window spec and window function (e.g. SUM or RANK) +/// +/// # Output Schema +/// +/// The output schema is the input schema followed by the window function +/// expressions, in order. +/// +/// For example, given the input schema `"A", "B", "C"` and the window function +/// `SUM(A) OVER (PARTITION BY B+1 ORDER BY C)`, the output schema will be `"A", +/// "B", "C", "SUM(A) OVER ..."` where `"SUM(A) OVER ..."` is the name of the +/// output column. +/// +/// Note that the `PARTITION BY` expression "B+1" is not produced in the output +/// schema. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Window { /// The incoming logical plan @@ -2968,6 +2985,16 @@ impl PartialOrd for DistinctOn { /// Aggregates its input based on a set of grouping and aggregate /// expressions (e.g. SUM). +/// +/// # Output Schema +/// +/// The output schema is the group expressions followed by the aggregate +/// expressions in order. +/// +/// For example, given the input schema `"A", "B", "C"` and the aggregate +/// `SUM(A) GROUP BY C+B`, the output schema will be `"C+B", "SUM(A)"` where +/// "C+B" and "SUM(A)" are the names of the output columns. Note that "C+B" is a +/// single new column #[derive(Debug, Clone, PartialEq, Eq, Hash)] // mark non_exhaustive to encourage use of try_new/new() #[non_exhaustive] From ac9584e8a65ce3f9ad4fbca59c0e12399a622f9c Mon Sep 17 00:00:00 2001 From: Sergey Zhukov <62326549+cj-zhukov@users.noreply.github.com> Date: Fri, 10 Jan 2025 00:29:36 +0300 Subject: [PATCH 10/19] Simplify error handling in case.rs (#13990) (#14033) * Simplify error handling in case.rs (#13990) * Fix issues causing GitHub checks to fail * Update datafusion/physical-expr/src/expressions/case.rs Co-authored-by: Andrew Lamb --------- Co-authored-by: Sergey Zhukov Co-authored-by: Andrew Lamb --- .../physical-expr/src/expressions/case.rs | 36 ++++++++----------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index ee19a8c9ddb1..16b97c203c30 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -27,7 +27,9 @@ use arrow::compute::kernels::zip::zip; use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; use arrow::datatypes::{DataType, Schema}; use datafusion_common::cast::as_boolean_array; -use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::ColumnarValue; use super::{Column, Literal}; @@ -249,10 +251,9 @@ impl CaseExpr { remainder = and_not(&remainder, &when_match)?; } - if let Some(e) = &self.else_expr { + if let Some(e) = self.else_expr() { // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) - .unwrap_or_else(|_| Arc::clone(e)); + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; // null and unmatched tuples should be assigned else value remainder = or(&base_nulls, &remainder)?; let else_ = expr @@ -282,11 +283,8 @@ impl CaseExpr { .0 .evaluate_selection(batch, &remainder)?; 
let when_value = when_value.into_array(batch.num_rows())?; - let when_value = as_boolean_array(&when_value).map_err(|e| { - DataFusionError::Context( - "WHEN expression did not return a BooleanArray".to_string(), - Box::new(e), - ) + let when_value = as_boolean_array(&when_value).map_err(|_| { + internal_datafusion_err!("WHEN expression did not return a BooleanArray") })?; // Treat 'NULL' as false value let when_value = match when_value.null_count() { @@ -322,10 +320,9 @@ impl CaseExpr { remainder = and_not(&remainder, &when_value)?; } - if let Some(e) = &self.else_expr { + if let Some(e) = self.else_expr() { // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) - .unwrap_or_else(|_| Arc::clone(e)); + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; let else_ = expr .evaluate_selection(batch, &remainder)? .into_array(batch.num_rows())?; @@ -376,11 +373,8 @@ impl CaseExpr { // evaluate when expression let when_value = self.when_then_expr[0].0.evaluate(batch)?; let when_value = when_value.into_array(batch.num_rows())?; - let when_value = as_boolean_array(&when_value).map_err(|e| { - DataFusionError::Context( - "WHEN expression did not return a BooleanArray".to_string(), - Box::new(e), - ) + let when_value = as_boolean_array(&when_value).map_err(|_| { + internal_datafusion_err!("WHEN expression did not return a BooleanArray") })?; // Treat 'NULL' as false value @@ -393,12 +387,12 @@ impl CaseExpr { let then_value = self.when_then_expr[0].1.evaluate(batch)?; let then_value = Scalar::new(then_value.into_array(1)?); + let Some(e) = self.else_expr() else { + return internal_err!("expression did not evaluate to an array"); + }; // keep `else_expr`'s data type and return type consistent - let e = self.else_expr.as_ref().unwrap(); - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type) - .unwrap_or_else(|_| Arc::clone(e)); + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?; let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?); - Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?)) } From f379719e58100d1af7f93c138516f4684985c06a Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Thu, 9 Jan 2025 13:39:56 -0800 Subject: [PATCH 11/19] feat: add `AsyncCatalogProvider` helpers for asynchronous catalogs (#13800) * Add asynchronous catalog traits to help users that have asynchronous catalogs * Apply clippy suggestions * Address PR reviews * Remove allow_unused exceptions * Update remote catalog example to demonstrate new helper structs * Move schema_name / catalog_name parameters into resolve function and out of trait --- .../examples/remote_catalog.rs | 194 ++--- datafusion/catalog/Cargo.toml | 3 + datafusion/catalog/src/async.rs | 747 ++++++++++++++++++ datafusion/catalog/src/lib.rs | 2 + 4 files changed, 796 insertions(+), 150 deletions(-) create mode 100644 datafusion/catalog/src/async.rs diff --git a/datafusion-examples/examples/remote_catalog.rs b/datafusion-examples/examples/remote_catalog.rs index 206b7ba9c4be..38629328d71c 100644 --- a/datafusion-examples/examples/remote_catalog.rs +++ b/datafusion-examples/examples/remote_catalog.rs @@ -32,22 +32,19 @@ use arrow::array::record_batch; use arrow_schema::{Field, Fields, Schema, SchemaRef}; use async_trait::async_trait; -use datafusion::catalog::{SchemaProvider, TableProvider}; -use datafusion::common::DataFusionError; +use datafusion::catalog::TableProvider; use datafusion::common::Result; use 
datafusion::execution::SendableRecordBatchStream; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::{DataFrame, SessionContext}; -use datafusion_catalog::Session; -use datafusion_common::{ - assert_batches_eq, internal_datafusion_err, plan_err, HashMap, TableReference, -}; +use datafusion_catalog::{AsyncSchemaProvider, Session}; +use datafusion_common::{assert_batches_eq, internal_datafusion_err, plan_err}; use datafusion_expr::{Expr, TableType}; use futures::TryStreamExt; use std::any::Any; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; #[tokio::main] async fn main() -> Result<()> { @@ -55,23 +52,18 @@ async fn main() -> Result<()> { let ctx = SessionContext::new(); // Make a connection to the remote catalog, asynchronously, and configure it - let remote_catalog_interface = RemoteCatalogInterface::connect().await?; + let remote_catalog_interface = Arc::new(RemoteCatalogInterface::connect().await?); - // Register a SchemaProvider for tables in a schema named "remote_schema". - // - // This will let DataFusion query tables such as - // `datafusion.remote_schema.remote_table` - let remote_schema: Arc = - Arc::new(RemoteSchema::new(Arc::new(remote_catalog_interface))); - ctx.catalog("datafusion") - .ok_or_else(|| internal_datafusion_err!("default catalog was not installed"))? - .register_schema("remote_schema", Arc::clone(&remote_schema))?; + // Create an adapter to provide the AsyncSchemaProvider interface to DataFusion + // based on our remote catalog interface + let remote_catalog_adapter = RemoteCatalogDatafusionAdapter(remote_catalog_interface); // Here is a query that selects data from a table in the remote catalog. let sql = "SELECT * from remote_schema.remote_table"; // The `SessionContext::sql` interface is async, but it does not - // support asynchronous access to catalogs, so the following query errors. + // support asynchronous access to catalogs, so we cannot register our schema provider + // directly and the following query fails to find our table. let results = ctx.sql(sql).await; assert_eq!( results.unwrap_err().to_string(), @@ -91,27 +83,26 @@ async fn main() -> Result<()> { // `remote_schema.remote_table`) let references = state.resolve_table_references(&statement)?; - // Call `load_tables` to load information from the remote catalog for each - // of the referenced tables. 
Best practice is to fetch the the information - // for all tables required by the query once (rather than one per table) to - // minimize network overhead - let table_names = references.iter().filter_map(|r| { - if refers_to_schema("datafusion", "remote_schema", r) { - Some(r.table()) - } else { - None - } - }); - remote_schema - .as_any() - .downcast_ref::() - .expect("correct types") - .load_tables(table_names) + // Now we can asynchronously resolve the table references to get a cached catalog + // that we can use for our query + let resolved_catalog = remote_catalog_adapter + .resolve(&references, state.config(), "datafusion", "remote_schema") .await?; - // Now continue planing the query after having fetched the remote table and - // it can run as normal - let plan = state.statement_to_plan(statement).await?; + // This resolved catalog only makes sense for this query and so we create a clone + // of the session context with the resolved catalog + let query_ctx = ctx.clone(); + + query_ctx + .catalog("datafusion") + .ok_or_else(|| internal_datafusion_err!("default catalog was not installed"))? + .register_schema("remote_schema", resolved_catalog)?; + + // We can now continue planning the query with this new query-specific context that + // contains our cached catalog + let query_state = query_ctx.state(); + + let plan = query_state.statement_to_plan(statement).await?; let results = DataFrame::new(state, plan).collect().await?; assert_batches_eq!( [ @@ -145,9 +136,9 @@ impl RemoteCatalogInterface { } /// Fetches information for a specific table - pub async fn table_info(&self, name: &str) -> Result { + pub async fn table_info(&self, name: &str) -> Result> { if name != "remote_table" { - return plan_err!("Remote table not found: {}", name); + return Ok(None); } // In this example, we'll model a remote table with columns "id" and @@ -159,7 +150,7 @@ impl RemoteCatalogInterface { Field::new("id", arrow::datatypes::DataType::Int32, false), Field::new("name", arrow::datatypes::DataType::Utf8, false), ])); - Ok(Arc::new(schema)) + Ok(Some(Arc::new(schema))) } /// Fetches data for a table from a remote data source @@ -186,95 +177,22 @@ impl RemoteCatalogInterface { } } -/// Implements the DataFusion Catalog API interface for tables +/// Implements an async version of the DataFusion SchemaProvider API for tables /// stored in a remote catalog. -#[derive(Debug)] -struct RemoteSchema { - /// Connection with the remote catalog - remote_catalog_interface: Arc, - /// Local cache of tables that have been preloaded from the remote - /// catalog - tables: Mutex>>, -} - -impl RemoteSchema { - /// Create a new RemoteSchema - pub fn new(remote_catalog_interface: Arc) -> Self { - Self { - remote_catalog_interface, - tables: Mutex::new(HashMap::new()), - } - } - - /// Load information for the specified tables from the remote source into - /// the local cached copy. - pub async fn load_tables( - &self, - references: impl IntoIterator, - ) -> Result<()> { - for table_name in references { - if !self.table_exist(table_name) { - // Fetch information about the table from the remote catalog - // - // Note that a real remote catalog interface could return more - // information, but at the minimum, DataFusion requires the - // table's schema for planing. 
- let schema = self.remote_catalog_interface.table_info(table_name).await?; - let remote_table = RemoteTable::new( - Arc::clone(&self.remote_catalog_interface), - table_name, - schema, - ); - - // Add the table to our local cached list - self.tables - .lock() - .expect("mutex invalid") - .insert(table_name.to_string(), Arc::new(remote_table)); - }; - } - Ok(()) - } -} +struct RemoteCatalogDatafusionAdapter(Arc); -/// Implement the DataFusion Catalog API for [`RemoteSchema`] #[async_trait] -impl SchemaProvider for RemoteSchema { - fn as_any(&self) -> &dyn Any { - self - } - - fn table_names(&self) -> Vec { - // Note this API is not async so we can't directly call the RemoteCatalogInterface - // instead we use the cached list of loaded tables - self.tables - .lock() - .expect("mutex valid") - .keys() - .cloned() - .collect() - } - - // While this API is actually `async` and thus could consult a remote - // catalog directly it is more efficient to use a local cached copy instead, - // which is what we model in this example - async fn table( - &self, - name: &str, - ) -> Result>, DataFusionError> { - // Look for any pre-loaded tables - let table = self - .tables - .lock() - .expect("mutex valid") - .get(name) - .map(Arc::clone); - Ok(table) - } - - fn table_exist(&self, name: &str) -> bool { - // Look for any pre-loaded tables, note this function is also `async` - self.tables.lock().expect("mutex valid").contains_key(name) +impl AsyncSchemaProvider for RemoteCatalogDatafusionAdapter { + async fn table(&self, name: &str) -> Result>> { + // Fetch information about the table from the remote catalog + // + // Note that a real remote catalog interface could return more + // information, but at the minimum, DataFusion requires the + // table's schema for planing. + Ok(self.0.table_info(name).await?.map(|schema| { + Arc::new(RemoteTable::new(Arc::clone(&self.0), name, schema)) + as Arc + })) } } @@ -343,27 +261,3 @@ impl TableProvider for RemoteTable { )?)) } } - -/// Return true if this `table_reference` might be for a table in the specified -/// catalog and schema. -fn refers_to_schema( - catalog_name: &str, - schema_name: &str, - table_reference: &TableReference, -) -> bool { - // Check the references are in the correct catalog and schema - // references like foo.bar.baz - if let Some(catalog) = table_reference.catalog() { - if catalog != catalog_name { - return false; - } - } - // references like bar.baz - if let Some(schema) = table_reference.schema() { - if schema != schema_name { - return false; - } - } - - true -} diff --git a/datafusion/catalog/Cargo.toml b/datafusion/catalog/Cargo.toml index f9801352087d..32a87cc7611c 100644 --- a/datafusion/catalog/Cargo.toml +++ b/datafusion/catalog/Cargo.toml @@ -36,5 +36,8 @@ datafusion-expr = { workspace = true } datafusion-physical-plan = { workspace = true } parking_lot = { workspace = true } +[dev-dependencies] +tokio = { workspace = true } + [lints] workspace = true diff --git a/datafusion/catalog/src/async.rs b/datafusion/catalog/src/async.rs new file mode 100644 index 000000000000..504f20ff9543 --- /dev/null +++ b/datafusion/catalog/src/async.rs @@ -0,0 +1,747 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion_common::{error::Result, not_impl_err, HashMap, TableReference}; +use datafusion_execution::config::SessionConfig; + +use crate::{CatalogProvider, CatalogProviderList, SchemaProvider, TableProvider}; + +/// A schema provider that looks up tables in a cache +/// +/// Instances are created by the [`AsyncSchemaProvider::resolve`] method +#[derive(Debug)] +struct ResolvedSchemaProvider { + owner_name: Option, + cached_tables: HashMap>, +} +#[async_trait] +impl SchemaProvider for ResolvedSchemaProvider { + fn owner_name(&self) -> Option<&str> { + self.owner_name.as_deref() + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn table_names(&self) -> Vec { + self.cached_tables.keys().cloned().collect() + } + + async fn table(&self, name: &str) -> Result>> { + Ok(self.cached_tables.get(name).cloned()) + } + + fn register_table( + &self, + name: String, + _table: Arc, + ) -> Result>> { + not_impl_err!( + "Attempt to register table '{name}' with ResolvedSchemaProvider which is not supported" + ) + } + + fn deregister_table(&self, name: &str) -> Result>> { + not_impl_err!("Attempt to deregister table '{name}' with ResolvedSchemaProvider which is not supported") + } + + fn table_exist(&self, name: &str) -> bool { + self.cached_tables.contains_key(name) + } +} + +/// Helper class for building a [`ResolvedSchemaProvider`] +struct ResolvedSchemaProviderBuilder { + owner_name: String, + async_provider: Arc, + cached_tables: HashMap>>, +} +impl ResolvedSchemaProviderBuilder { + fn new(owner_name: String, async_provider: Arc) -> Self { + Self { + owner_name, + async_provider, + cached_tables: HashMap::new(), + } + } + + async fn resolve_table(&mut self, table_name: &str) -> Result<()> { + if !self.cached_tables.contains_key(table_name) { + let resolved_table = self.async_provider.table(table_name).await?; + self.cached_tables + .insert(table_name.to_string(), resolved_table); + } + Ok(()) + } + + fn finish(self) -> Arc { + let cached_tables = self + .cached_tables + .into_iter() + .filter_map(|(key, maybe_value)| maybe_value.map(|value| (key, value))) + .collect(); + Arc::new(ResolvedSchemaProvider { + owner_name: Some(self.owner_name), + cached_tables, + }) + } +} + +/// A catalog provider that looks up schemas in a cache +/// +/// Instances are created by the [`AsyncCatalogProvider::resolve`] method +#[derive(Debug)] +struct ResolvedCatalogProvider { + cached_schemas: HashMap>, +} +impl CatalogProvider for ResolvedCatalogProvider { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema_names(&self) -> Vec { + self.cached_schemas.keys().cloned().collect() + } + + fn schema(&self, name: &str) -> Option> { + self.cached_schemas.get(name).cloned() + } +} + +/// Helper class for building a [`ResolvedCatalogProvider`] +struct ResolvedCatalogProviderBuilder { + cached_schemas: HashMap>, + async_provider: Arc, +} +impl ResolvedCatalogProviderBuilder { + fn new(async_provider: Arc) -> Self { + Self { + cached_schemas: HashMap::new(), + async_provider, + } + } + fn finish(self) -> Arc { + let 
cached_schemas = self + .cached_schemas + .into_iter() + .filter_map(|(key, maybe_value)| { + maybe_value.map(|value| (key, value.finish())) + }) + .collect(); + Arc::new(ResolvedCatalogProvider { cached_schemas }) + } +} + +/// A catalog provider list that looks up catalogs in a cache +/// +/// Instances are created by the [`AsyncCatalogProviderList::resolve`] method +#[derive(Debug)] +struct ResolvedCatalogProviderList { + cached_catalogs: HashMap>, +} +impl CatalogProviderList for ResolvedCatalogProviderList { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn register_catalog( + &self, + _name: String, + _catalog: Arc, + ) -> Option> { + unimplemented!("resolved providers cannot handle registration APIs") + } + + fn catalog_names(&self) -> Vec { + self.cached_catalogs.keys().cloned().collect() + } + + fn catalog(&self, name: &str) -> Option> { + self.cached_catalogs.get(name).cloned() + } +} + +/// A trait for schema providers that must resolve tables asynchronously +/// +/// The [`SchemaProvider::table`] method _is_ asynchronous. However, this is primarily for convenience and +/// it is not a good idea for this method to be slow as this will cause poor planning performance. +/// +/// It is a better idea to resolve the tables once and cache them in memory for the duration of +/// planning. This trait helps implement that pattern. +/// +/// After implementing this trait you can call the [`AsyncSchemaProvider::resolve`] method to get an +/// `Arc` that contains a cached copy of the referenced tables. The `resolve` +/// method can be slow and asynchronous as it is only called once, before planning. +#[async_trait] +pub trait AsyncSchemaProvider: Send + Sync { + /// Lookup a table in the schema provider + async fn table(&self, name: &str) -> Result>>; + /// Creates a cached provider that can be used to execute a query containing given references + /// + /// This method will walk through the references and look them up once, creating a cache of table + /// providers. This cache will be returned as a synchronous TableProvider that can be used to plan + /// and execute a query containing the given references. + /// + /// This cache is intended to be short-lived for the execution of a single query. There is no mechanism + /// for refresh or eviction of stale entries. 
+ async fn resolve( + &self, + references: &[TableReference], + config: &SessionConfig, + catalog_name: &str, + schema_name: &str, + ) -> Result> { + let mut cached_tables = HashMap::>>::new(); + + for reference in references { + let ref_catalog_name = reference + .catalog() + .unwrap_or(&config.options().catalog.default_catalog); + + // Maybe this is a reference to some other catalog provided in another way + if ref_catalog_name != catalog_name { + continue; + } + + let ref_schema_name = reference + .schema() + .unwrap_or(&config.options().catalog.default_schema); + + if ref_schema_name != schema_name { + continue; + } + + if !cached_tables.contains_key(reference.table()) { + let resolved_table = self.table(reference.table()).await?; + cached_tables.insert(reference.table().to_string(), resolved_table); + } + } + + let cached_tables = cached_tables + .into_iter() + .filter_map(|(key, maybe_value)| maybe_value.map(|value| (key, value))) + .collect(); + + Ok(Arc::new(ResolvedSchemaProvider { + cached_tables, + owner_name: Some(catalog_name.to_string()), + })) + } +} + +/// A trait for catalog providers that must resolve schemas asynchronously +/// +/// The [`CatalogProvider::schema`] method is synchronous because asynchronous operations should +/// not be used during planning. This trait makes it easy to lookup schema references once and cache +/// them for future planning use. See [`AsyncSchemaProvider`] for more details on motivation. + +#[async_trait] +pub trait AsyncCatalogProvider: Send + Sync { + /// Lookup a schema in the provider + async fn schema(&self, name: &str) -> Result>>; + + /// Creates a cached provider that can be used to execute a query containing given references + /// + /// This method will walk through the references and look them up once, creating a cache of schema + /// providers (each with their own cache of table providers). This cache will be returned as a + /// synchronous CatalogProvider that can be used to plan and execute a query containing the given + /// references. + /// + /// This cache is intended to be short-lived for the execution of a single query. There is no mechanism + /// for refresh or eviction of stale entries. 
+ async fn resolve( + &self, + references: &[TableReference], + config: &SessionConfig, + catalog_name: &str, + ) -> Result> { + let mut cached_schemas = + HashMap::>::new(); + + for reference in references { + let ref_catalog_name = reference + .catalog() + .unwrap_or(&config.options().catalog.default_catalog); + + // Maybe this is a reference to some other catalog provided in another way + if ref_catalog_name != catalog_name { + continue; + } + + let schema_name = reference + .schema() + .unwrap_or(&config.options().catalog.default_schema); + + let schema = if let Some(schema) = cached_schemas.get_mut(schema_name) { + schema + } else { + let resolved_schema = self.schema(schema_name).await?; + let resolved_schema = resolved_schema.map(|resolved_schema| { + ResolvedSchemaProviderBuilder::new( + catalog_name.to_string(), + resolved_schema, + ) + }); + cached_schemas.insert(schema_name.to_string(), resolved_schema); + cached_schemas.get_mut(schema_name).unwrap() + }; + + // If we can't find the catalog don't bother checking the table + let Some(schema) = schema else { continue }; + + schema.resolve_table(reference.table()).await?; + } + + let cached_schemas = cached_schemas + .into_iter() + .filter_map(|(key, maybe_builder)| { + maybe_builder.map(|schema_builder| (key, schema_builder.finish())) + }) + .collect::>(); + + Ok(Arc::new(ResolvedCatalogProvider { cached_schemas })) + } +} + +/// A trait for catalog provider lists that must resolve catalogs asynchronously +/// +/// The [`CatalogProviderList::catalog`] method is synchronous because asynchronous operations should +/// not be used during planning. This trait makes it easy to lookup catalog references once and cache +/// them for future planning use. See [`AsyncSchemaProvider`] for more details on motivation. +#[async_trait] +pub trait AsyncCatalogProviderList: Send + Sync { + /// Lookup a catalog in the provider + async fn catalog(&self, name: &str) -> Result>>; + + /// Creates a cached provider that can be used to execute a query containing given references + /// + /// This method will walk through the references and look them up once, creating a cache of catalog + /// providers, schema providers, and table providers. This cache will be returned as a + /// synchronous CatalogProvider that can be used to plan and execute a query containing the given + /// references. + /// + /// This cache is intended to be short-lived for the execution of a single query. There is no mechanism + /// for refresh or eviction of stale entries. + async fn resolve( + &self, + references: &[TableReference], + config: &SessionConfig, + ) -> Result> { + let mut cached_catalogs = + HashMap::>::new(); + + for reference in references { + let catalog_name = reference + .catalog() + .unwrap_or(&config.options().catalog.default_catalog); + + // We will do three lookups here, one for the catalog, one for the schema, and one for the table + // We cache the result (both found results and not-found results) to speed up future lookups + // + // Note that a cache-miss is not an error at this point. We allow for the possibility that + // other providers may supply the reference. + // + // If this is the only provider then a not-found error will be raised during planning when it can't + // find the reference in the cache. 
+ + let catalog = if let Some(catalog) = cached_catalogs.get_mut(catalog_name) { + catalog + } else { + let resolved_catalog = self.catalog(catalog_name).await?; + let resolved_catalog = + resolved_catalog.map(ResolvedCatalogProviderBuilder::new); + cached_catalogs.insert(catalog_name.to_string(), resolved_catalog); + cached_catalogs.get_mut(catalog_name).unwrap() + }; + + // If we can't find the catalog don't bother checking the schema / table + let Some(catalog) = catalog else { continue }; + + let schema_name = reference + .schema() + .unwrap_or(&config.options().catalog.default_schema); + + let schema = if let Some(schema) = catalog.cached_schemas.get_mut(schema_name) + { + schema + } else { + let resolved_schema = catalog.async_provider.schema(schema_name).await?; + let resolved_schema = resolved_schema.map(|async_schema| { + ResolvedSchemaProviderBuilder::new( + catalog_name.to_string(), + async_schema, + ) + }); + catalog + .cached_schemas + .insert(schema_name.to_string(), resolved_schema); + catalog.cached_schemas.get_mut(schema_name).unwrap() + }; + + // If we can't find the catalog don't bother checking the table + let Some(schema) = schema else { continue }; + + schema.resolve_table(reference.table()).await?; + } + + // Build the cached catalog provider list + let cached_catalogs = cached_catalogs + .into_iter() + .filter_map(|(key, maybe_builder)| { + maybe_builder.map(|catalog_builder| (key, catalog_builder.finish())) + }) + .collect::>(); + + Ok(Arc::new(ResolvedCatalogProviderList { cached_catalogs })) + } +} + +#[cfg(test)] +mod tests { + use std::{ + any::Any, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, + }; + + use arrow_schema::SchemaRef; + use async_trait::async_trait; + use datafusion_common::{error::Result, Statistics, TableReference}; + use datafusion_execution::config::SessionConfig; + use datafusion_expr::{Expr, TableType}; + use datafusion_physical_plan::ExecutionPlan; + + use crate::{Session, TableProvider}; + + use super::{AsyncCatalogProvider, AsyncCatalogProviderList, AsyncSchemaProvider}; + + #[derive(Debug)] + struct MockTableProvider {} + #[async_trait] + impl TableProvider for MockTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + /// Get a reference to the schema for this table + fn schema(&self) -> SchemaRef { + unimplemented!() + } + + fn table_type(&self) -> TableType { + unimplemented!() + } + + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + unimplemented!() + } + + fn statistics(&self) -> Option { + unimplemented!() + } + } + + #[derive(Default)] + struct MockAsyncSchemaProvider { + lookup_count: AtomicU32, + } + + const MOCK_CATALOG: &str = "mock_catalog"; + const MOCK_SCHEMA: &str = "mock_schema"; + const MOCK_TABLE: &str = "mock_table"; + + #[async_trait] + impl AsyncSchemaProvider for MockAsyncSchemaProvider { + async fn table(&self, name: &str) -> Result>> { + self.lookup_count.fetch_add(1, Ordering::Release); + if name == MOCK_TABLE { + Ok(Some(Arc::new(MockTableProvider {}))) + } else { + Ok(None) + } + } + } + + fn test_config() -> SessionConfig { + let mut config = SessionConfig::default(); + config.options_mut().catalog.default_catalog = MOCK_CATALOG.to_string(); + config.options_mut().catalog.default_schema = MOCK_SCHEMA.to_string(); + config + } + + #[tokio::test] + async fn test_async_schema_provider_resolve() { + async fn check( + refs: Vec, + expected_lookup_count: u32, + found_tables: &[&str], + not_found_tables: &[&str], + 
) { + let async_provider = MockAsyncSchemaProvider::default(); + let cached_provider = async_provider + .resolve(&refs, &test_config(), MOCK_CATALOG, MOCK_SCHEMA) + .await + .unwrap(); + + assert_eq!( + async_provider.lookup_count.load(Ordering::Acquire), + expected_lookup_count + ); + + for table_ref in found_tables { + let table = cached_provider.table(table_ref).await.unwrap(); + assert!(table.is_some()); + } + + for table_ref in not_found_tables { + assert!(cached_provider.table(table_ref).await.unwrap().is_none()); + } + } + + // Basic full lookups + check( + vec![ + TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, MOCK_TABLE), + TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "not_exists"), + ], + 2, + &[MOCK_TABLE], + &["not_exists"], + ) + .await; + + // Catalog / schema mismatch doesn't even search + check( + vec![ + TableReference::full(MOCK_CATALOG, "foo", MOCK_TABLE), + TableReference::full("foo", MOCK_SCHEMA, MOCK_TABLE), + ], + 0, + &[], + &[MOCK_TABLE], + ) + .await; + + // Both hits and misses cached + check( + vec![ + TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, MOCK_TABLE), + TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, MOCK_TABLE), + TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "not_exists"), + TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "not_exists"), + ], + 2, + &[MOCK_TABLE], + &["not_exists"], + ) + .await; + } + + #[derive(Default)] + struct MockAsyncCatalogProvider { + lookup_count: AtomicU32, + } + + #[async_trait] + impl AsyncCatalogProvider for MockAsyncCatalogProvider { + async fn schema( + &self, + name: &str, + ) -> Result>> { + self.lookup_count.fetch_add(1, Ordering::Release); + if name == MOCK_SCHEMA { + Ok(Some(Arc::new(MockAsyncSchemaProvider::default()))) + } else { + Ok(None) + } + } + } + + #[tokio::test] + async fn test_async_catalog_provider_resolve() { + async fn check( + refs: Vec, + expected_lookup_count: u32, + found_schemas: &[&str], + not_found_schemas: &[&str], + ) { + let async_provider = MockAsyncCatalogProvider::default(); + let cached_provider = async_provider + .resolve(&refs, &test_config(), MOCK_CATALOG) + .await + .unwrap(); + + assert_eq!( + async_provider.lookup_count.load(Ordering::Acquire), + expected_lookup_count + ); + + for schema_ref in found_schemas { + let schema = cached_provider.schema(schema_ref); + assert!(schema.is_some()); + } + + for schema_ref in not_found_schemas { + assert!(cached_provider.schema(schema_ref).is_none()); + } + } + + // Basic full lookups + check( + vec![ + TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "x"), + TableReference::full(MOCK_CATALOG, "not_exists", "x"), + ], + 2, + &[MOCK_SCHEMA], + &["not_exists"], + ) + .await; + + // Catalog mismatch doesn't even search + check( + vec![TableReference::full("foo", MOCK_SCHEMA, "x")], + 0, + &[], + &[MOCK_SCHEMA], + ) + .await; + + // Both hits and misses cached + check( + vec![ + TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "x"), + TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "x"), + TableReference::full(MOCK_CATALOG, "not_exists", "x"), + TableReference::full(MOCK_CATALOG, "not_exists", "x"), + ], + 2, + &[MOCK_SCHEMA], + &["not_exists"], + ) + .await; + } + + #[derive(Default)] + struct MockAsyncCatalogProviderList { + lookup_count: AtomicU32, + } + + #[async_trait] + impl AsyncCatalogProviderList for MockAsyncCatalogProviderList { + async fn catalog( + &self, + name: &str, + ) -> Result>> { + self.lookup_count.fetch_add(1, Ordering::Release); + if name == MOCK_CATALOG { + Ok(Some(Arc::new(MockAsyncCatalogProvider::default()))) + } 
else { + Ok(None) + } + } + } + + #[tokio::test] + async fn test_async_catalog_provider_list_resolve() { + async fn check( + refs: Vec, + expected_lookup_count: u32, + found_catalogs: &[&str], + not_found_catalogs: &[&str], + ) { + let async_provider = MockAsyncCatalogProviderList::default(); + let cached_provider = + async_provider.resolve(&refs, &test_config()).await.unwrap(); + + assert_eq!( + async_provider.lookup_count.load(Ordering::Acquire), + expected_lookup_count + ); + + for catalog_ref in found_catalogs { + let catalog = cached_provider.catalog(catalog_ref); + assert!(catalog.is_some()); + } + + for catalog_ref in not_found_catalogs { + assert!(cached_provider.catalog(catalog_ref).is_none()); + } + } + + // Basic full lookups + check( + vec![ + TableReference::full(MOCK_CATALOG, "x", "x"), + TableReference::full("not_exists", "x", "x"), + ], + 2, + &[MOCK_CATALOG], + &["not_exists"], + ) + .await; + + // Both hits and misses cached + check( + vec![ + TableReference::full(MOCK_CATALOG, "x", "x"), + TableReference::full(MOCK_CATALOG, "x", "x"), + TableReference::full("not_exists", "x", "x"), + TableReference::full("not_exists", "x", "x"), + ], + 2, + &[MOCK_CATALOG], + &["not_exists"], + ) + .await; + } + + #[tokio::test] + async fn test_defaults() { + for table_ref in &[ + TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, MOCK_TABLE), + TableReference::partial(MOCK_SCHEMA, MOCK_TABLE), + TableReference::bare(MOCK_TABLE), + ] { + let async_provider = MockAsyncCatalogProviderList::default(); + let cached_provider = async_provider + .resolve(&[table_ref.clone()], &test_config()) + .await + .unwrap(); + + let catalog = cached_provider + .catalog(table_ref.catalog().unwrap_or(MOCK_CATALOG)) + .unwrap(); + let schema = catalog + .schema(table_ref.schema().unwrap_or(MOCK_SCHEMA)) + .unwrap(); + assert!(schema.table(table_ref.table()).await.unwrap().is_some()); + } + } +} diff --git a/datafusion/catalog/src/lib.rs b/datafusion/catalog/src/lib.rs index 21630f267d2c..3cf2a3b3cd33 100644 --- a/datafusion/catalog/src/lib.rs +++ b/datafusion/catalog/src/lib.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +mod r#async; mod catalog; mod dynamic_file; mod schema; @@ -23,6 +24,7 @@ mod table; pub use catalog::*; pub use dynamic_file::catalog::*; +pub use r#async::*; pub use schema::*; pub use session::*; pub use table::*; From 80c828b631fb40a882542f1e07a2e7dad9f41eca Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Fri, 10 Jan 2025 00:44:22 +0300 Subject: [PATCH 12/19] Custom scalar to sql overrides support for DuckDB Unparser dialect (#13915) * Allow adding custom scalar to sql overrides for DuckDB (#68) * Add unit test: custom_scalar_overrides_duckdb * Move `with_custom_scalar_overrides` definition on `Dialect` trait level --- datafusion/sql/src/unparser/dialect.rs | 46 ++++++++++++++++++++++++-- datafusion/sql/src/unparser/expr.rs | 26 ++++++++++++++- 2 files changed, 69 insertions(+), 3 deletions(-) diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index 3a44d7f0ec48..5c318a96ef6c 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; use arrow_schema::TimeUnit; use datafusion_common::Result; @@ -29,6 +29,9 @@ use sqlparser::{ use super::{utils::character_length_to_sql, utils::date_part_to_sql, Unparser}; +pub type ScalarFnToSqlHandler = + Box Result> + Send + Sync>; + /// `Dialect` to use for Unparsing /// /// The default dialect tries to avoid quoting identifiers unless necessary (e.g. `a` instead of `"a"`) @@ -150,6 +153,18 @@ pub trait Dialect: Send + Sync { Ok(None) } + /// Extends the dialect's default rules for unparsing scalar functions. + /// This is useful for supporting application-specific UDFs or custom engine extensions. + fn with_custom_scalar_overrides( + self, + _handlers: Vec<(&str, ScalarFnToSqlHandler)>, + ) -> Self + where + Self: Sized, + { + unimplemented!("Custom scalar overrides are not supported by this dialect yet"); + } + /// Allow to unparse a qualified column with a full qualified name /// (e.g. catalog_name.schema_name.table_name.column_name) /// Otherwise, the column will be unparsed with only the table name and column name @@ -305,7 +320,19 @@ impl PostgreSqlDialect { } } -pub struct DuckDBDialect {} +#[derive(Default)] +pub struct DuckDBDialect { + custom_scalar_fn_overrides: HashMap, +} + +impl DuckDBDialect { + #[must_use] + pub fn new() -> Self { + Self { + custom_scalar_fn_overrides: HashMap::new(), + } + } +} impl Dialect for DuckDBDialect { fn identifier_quote_style(&self, _: &str) -> Option { @@ -320,12 +347,27 @@ impl Dialect for DuckDBDialect { BinaryOperator::DuckIntegerDivide } + fn with_custom_scalar_overrides( + mut self, + handlers: Vec<(&str, ScalarFnToSqlHandler)>, + ) -> Self { + for (func_name, handler) in handlers { + self.custom_scalar_fn_overrides + .insert(func_name.to_string(), handler); + } + self + } + fn scalar_function_to_sql_overrides( &self, unparser: &Unparser, func_name: &str, args: &[Expr], ) -> Result> { + if let Some(handler) = self.custom_scalar_fn_overrides.get(func_name) { + return handler(unparser, args); + } + if func_name == "character_length" { return character_length_to_sql( unparser, diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index d012d3437720..7a110fd0785c 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1636,7 +1636,7 @@ mod tests { use crate::unparser::dialect::{ CharacterLengthStyle, CustomDialect, CustomDialectBuilder, DateFieldExtractStyle, - Dialect, PostgreSqlDialect, + Dialect, DuckDBDialect, PostgreSqlDialect, ScalarFnToSqlHandler, }; use super::*; @@ -2722,4 +2722,28 @@ mod tests { Ok(()) } + + #[test] + fn test_custom_scalar_overrides_duckdb() -> Result<()> { + let duckdb_default = DuckDBDialect::new(); + let duckdb_extended = DuckDBDialect::new().with_custom_scalar_overrides(vec![( + "dummy_udf", + Box::new(|unparser: &Unparser, args: &[Expr]| { + unparser.scalar_function_to_sql("smart_udf", args).map(Some) + }) as ScalarFnToSqlHandler, + )]); + + for (dialect, expected) in [ + (duckdb_default, r#"dummy_udf("a", "b")"#), + (duckdb_extended, r#"smart_udf("a", "b")"#), + ] { + let unparser = Unparser::new(&dialect); + let expr = + ScalarUDF::new_from_impl(DummyUDF::new()).call(vec![col("a"), col("b")]); + let actual = format!("{}", unparser.expr_to_sql(&expr)?); + assert_eq!(actual, expected); + } + + Ok(()) + } } From f9d3133c4fd39e8a14329bbb98bb726283bccffd Mon Sep 17 00:00:00 2001 From: Tai Le Manh Date: Fri, 10 Jan 2025 04:45:57 +0700 Subject: [PATCH 13/19] Improve perfomance 
of `reverse` function (#14025) * Improve perfomance of 'reverse' function Signed-off-by: Tai Le Manh * Apply sugestion change * Fix typo --------- Signed-off-by: Tai Le Manh --- datafusion/functions/Cargo.toml | 5 ++ datafusion/functions/benches/reverse.rs | 90 +++++++++++++++++++++ datafusion/functions/src/unicode/reverse.rs | 25 +++--- 3 files changed, 111 insertions(+), 9 deletions(-) create mode 100644 datafusion/functions/benches/reverse.rs diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index fd986c4be41c..c8025fb2d895 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -204,6 +204,11 @@ harness = false name = "strpos" required-features = ["unicode_expressions"] +[[bench]] +harness = false +name = "reverse" +required-features = ["unicode_expressions"] + [[bench]] harness = false name = "trunc" diff --git a/datafusion/functions/benches/reverse.rs b/datafusion/functions/benches/reverse.rs new file mode 100644 index 000000000000..c7c1ef8a8220 --- /dev/null +++ b/datafusion/functions/benches/reverse.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
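The core idea of the `reverse` speedup in this patch, shown in full in the `unicode/reverse.rs` hunk further below, is to stop collecting a freshly allocated `String` per row and instead reuse one scratch buffer while appending into an arrow `GenericStringBuilder`. A minimal standalone sketch of the same technique follows; it uses only arrow's public builder API, and the function name and capacity value are illustrative, not part of the patch.

    use std::sync::Arc;
    use arrow::array::{Array, ArrayRef, StringArray, StringBuilder};

    /// Reverse every UTF-8 string in `input`, preserving nulls.
    fn reverse_utf8(input: &StringArray) -> ArrayRef {
        // Pre-size the offset buffer; the value buffer grows on demand.
        let mut builder = StringBuilder::with_capacity(input.len(), 1024);
        // One scratch buffer reused across rows avoids a heap allocation per value.
        let mut scratch = String::new();
        for value in input.iter() {
            match value {
                Some(s) => {
                    scratch.extend(s.chars().rev());
                    builder.append_value(&scratch);
                    scratch.clear();
                }
                None => builder.append_null(),
            }
        }
        Arc::new(builder.finish()) as ArrayRef
    }

The new `benches/reverse.rs` (continuing just below) exercises this path for both `StringArray` and `StringViewArray` inputs at a few sizes and string lengths.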
+ +extern crate criterion; + +use arrow::array::OffsetSizeTrait; +use arrow::util::bench_util::{ + create_string_array_with_len, create_string_view_array_with_len, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::unicode; +use std::sync::Arc; + +fn create_args( + size: usize, + str_len: usize, + force_view_types: bool, +) -> Vec { + if force_view_types { + let string_array = + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); + + vec![ColumnarValue::Array(string_array)] + } else { + let string_array = + Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + vec![ColumnarValue::Array(string_array)] + } +} + +fn criterion_benchmark(c: &mut Criterion) { + let reverse = unicode::reverse(); + for size in [1024, 4096] { + let str_len = 8; + + let args = create_args::(size, str_len, true); + c.bench_function( + format!("reverse_string_view [size={}, str_len={}]", size, str_len).as_str(), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(reverse.invoke_batch(&args, str_len)) + }) + }, + ); + + let str_len = 32; + + let args = create_args::(size, str_len, true); + c.bench_function( + format!("reverse_string_view [size={}, str_len={}]", size, str_len).as_str(), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(reverse.invoke_batch(&args, str_len)) + }) + }, + ); + + let args = create_args::(size, str_len, false); + c.bench_function( + format!("reverse_string [size={}, str_len={}]", size, str_len).as_str(), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(reverse.invoke_batch(&args, str_len)) + }) + }, + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/unicode/reverse.rs b/datafusion/functions/src/unicode/reverse.rs index 5ad347ed96c0..f07deda70e52 100644 --- a/datafusion/functions/src/unicode/reverse.rs +++ b/datafusion/functions/src/unicode/reverse.rs @@ -20,8 +20,7 @@ use std::sync::Arc; use crate::utils::{make_scalar_function, utf8_to_str_type}; use arrow::array::{ - Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, - OffsetSizeTrait, + Array, ArrayRef, AsArray, GenericStringBuilder, OffsetSizeTrait, StringArrayType, }; use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result}; @@ -105,8 +104,7 @@ impl ScalarUDFImpl for ReverseFunc { } } -/// Reverses the order of the characters in the string. -/// reverse('abcde') = 'edcba' +/// Reverses the order of the characters in the string `reverse('abcde') = 'edcba'`. 
/// The implementation uses UTF-8 code points as characters pub fn reverse(args: &[ArrayRef]) -> Result { if args[0].data_type() == &Utf8View { @@ -116,14 +114,23 @@ pub fn reverse(args: &[ArrayRef]) -> Result { } } -fn reverse_impl<'a, T: OffsetSizeTrait, V: ArrayAccessor>( +fn reverse_impl<'a, T: OffsetSizeTrait, V: StringArrayType<'a>>( string_array: V, ) -> Result { - let result = ArrayIter::new(string_array) - .map(|string| string.map(|string: &str| string.chars().rev().collect::())) - .collect::>(); + let mut builder = GenericStringBuilder::::with_capacity(string_array.len(), 1024); + + let mut reversed = String::new(); + for string in string_array.iter() { + if let Some(s) = string { + reversed.extend(s.chars().rev()); + builder.append_value(&reversed); + reversed.clear(); + } else { + builder.append_null(); + } + } - Ok(Arc::new(result) as ArrayRef) + Ok(Arc::new(builder.finish()) as ArrayRef) } #[cfg(test)] From 40f1ddd8dcf47be42d8331d862c923388471bee3 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Thu, 9 Jan 2025 16:46:54 -0500 Subject: [PATCH 14/19] docs(ci): use up-to-date protoc with docs.rs (#14048) --- datafusion/substrait/Cargo.toml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 60e35cc966d2..a60b70a963e9 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -54,3 +54,10 @@ tokio = { workspace = true } default = ["physical"] physical = ["datafusion/parquet"] protoc = ["substrait/protoc"] + +[package.metadata.docs.rs] +# Use default features ("physical") for docs, plus "protoc". "protoc" is needed +# to get a consistent version of the protobuf compiler in the docs build; +# without that, an outdated protobuf compiler may fail to compile the protobuf +# files as it did in versions 42.0.0 through 44.0.0. 
+all-features = true From e113ec13bcaa3084763307c37509df62733b923c Mon Sep 17 00:00:00 2001 From: cht42 <42912042+cht42@users.noreply.github.com> Date: Fri, 10 Jan 2025 01:49:58 +0400 Subject: [PATCH 15/19] fix (#14042) Co-authored-by: Cyprien Huet --- datafusion/functions-window/src/nth_value.rs | 5 ++- datafusion/sqllogictest/test_files/window.slt | 38 +++++++++++++++++-- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/datafusion/functions-window/src/nth_value.rs b/datafusion/functions-window/src/nth_value.rs index d15e76718b02..e5d866940c05 100644 --- a/datafusion/functions-window/src/nth_value.rs +++ b/datafusion/functions-window/src/nth_value.rs @@ -360,9 +360,10 @@ impl PartitionEvaluator for NthValueEvaluator { }) .unwrap_or_default(); if valid_indices.is_empty() { - return ScalarValue::try_from(arr.data_type()); + None + } else { + Some(valid_indices) } - Some(valid_indices) } else { None }; diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 3c6f0f6deba1..2c82df969f1f 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -1767,7 +1767,7 @@ logical_plan 01)Projection: count(*) AS global_count 02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 03)----SubqueryAlias: a -04)------Projection: +04)------Projection: 05)--------Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[]] 06)----------Projection: aggregate_test_100.c1 07)------------Filter: aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434") @@ -4625,6 +4625,38 @@ NULL 1 statement ok DROP TABLE t; + +# Test for ignore nulls in nth_VALUE without null values +statement ok +CREATE TABLE t AS VALUES (3, 3), (4, 4), (5, 5), (6, 6); + +query I +SELECT column1 FROM t ORDER BY column2; +---- +3 +4 +5 +6 + +query I +SELECT nth_VALUE(column1, 1) OVER(ORDER BY column2) FROM t; +---- +3 +3 +3 +3 + +query I +SELECT nth_VALUE(column1, 1) IGNORE NULLS OVER(ORDER BY column2) FROM t; +---- +3 +3 +3 +3 + +statement ok +DROP TABLE t; + # Test for ignore nulls with ORDER BY in nth_VALUE statement ok CREATE TABLE t AS VALUES (3, 3), (4, 4), (null::bigint, 1), (null::bigint, 2), (5, 5), (6, 6); @@ -5055,7 +5087,7 @@ select b, row_number() over (order by a) from (select TRUE as a, 1 as b); # test window functions on boolean columns statement count 0 -create table t1 (id int, bool_col boolean) as values +create table t1 (id int, bool_col boolean) as values (1, true), (2, false), (3, true), @@ -5110,7 +5142,7 @@ select ntile(2) over (order by bool_col) from t1; 2 query IIIRRI -select +select row_number() over (order by bool_col) as row_num, rank() over (order by bool_col) as rank, dense_rank() over (order by bool_col) as dense_rank, From bf28e9bed2d872abcfb2f54092f342e3d950bf41 Mon Sep 17 00:00:00 2001 From: niebayes Date: Fri, 10 Jan 2025 05:58:29 +0800 Subject: [PATCH 16/19] Re-export TypeSignatureClass from the datafusion-expr package (#14051) --- datafusion/expr/src/lib.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index c5016a18d443..a57fd80c48e1 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -71,7 +71,8 @@ pub use datafusion_expr_common::columnar_value::ColumnarValue; pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; pub use datafusion_expr_common::operator::Operator; pub use datafusion_expr_common::signature::{ - ArrayFunctionSignature, 
Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD, + ArrayFunctionSignature, Signature, TypeSignature, TypeSignatureClass, Volatility, + TIMEZONE_WILDCARD, }; pub use datafusion_expr_common::type_coercion::binary; pub use expr::{ From 300c2afb076caa118c734f0d998595958f360ac3 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Fri, 10 Jan 2025 16:33:20 +0800 Subject: [PATCH 17/19] Fix clippy for Rust 1.84 (#14065) --- datafusion/common/src/pyarrow.rs | 3 +++ datafusion/core/src/datasource/listing/table.rs | 4 ++-- datafusion/core/src/execution/context/mod.rs | 9 +++------ datafusion/physical-expr-common/src/physical_expr.rs | 4 +--- datafusion/physical-expr-common/src/sort_expr.rs | 6 +++--- 5 files changed, 12 insertions(+), 14 deletions(-) diff --git a/datafusion/common/src/pyarrow.rs b/datafusion/common/src/pyarrow.rs index bdcf831c7884..29869c8da561 100644 --- a/datafusion/common/src/pyarrow.rs +++ b/datafusion/common/src/pyarrow.rs @@ -138,6 +138,9 @@ mod tests { fn test_py_scalar() { init_python(); + // TODO: remove this attribute when bumping pyo3 to v0.23.0 + // See: + #[allow(unexpected_cfgs)] Python::with_gil(|py| { let scalar_float = ScalarValue::Float64(Some(12.34)); let py_float = scalar_float.into_py(py).call_method0(py, "as_py").unwrap(); diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 7470597ef72c..06b94f804268 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -1148,8 +1148,8 @@ impl ListingTable { /// This method first checks if the statistics for the given file are already cached. /// If they are, it returns the cached statistics. /// If they are not, it infers the statistics from the file and stores them in the cache. - async fn do_collect_statistics<'a>( - &'a self, + async fn do_collect_statistics( + &self, ctx: &SessionState, store: &Arc, part_file: &PartitionedFile, diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index f20475df150b..e5da49ad7b8b 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1047,7 +1047,7 @@ impl SessionContext { Ok(table) } - async fn find_and_deregister<'a>( + async fn find_and_deregister( &self, table_ref: impl Into, table_type: TableType, @@ -1481,10 +1481,7 @@ impl SessionContext { /// provided reference. /// /// [`register_table`]: SessionContext::register_table - pub async fn table<'a>( - &self, - table_ref: impl Into, - ) -> Result { + pub async fn table(&self, table_ref: impl Into) -> Result { let table_ref: TableReference = table_ref.into(); let provider = self.table_provider(table_ref.clone()).await?; let plan = LogicalPlanBuilder::scan( @@ -1511,7 +1508,7 @@ impl SessionContext { } /// Return a [`TableProvider`] for the specified table. 
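The Rust 1.84 clippy fixes in this commit (the `listing/table.rs` and `SessionContext` hunks above, and `table_provider` continuing just below) all drop lifetime parameters that are either unused or already implied by elision. A small self-contained illustration of the pattern, with hypothetical names that are not DataFusion APIs:

    struct Catalog {
        name: String,
    }

    impl Catalog {
        // Older spelling: `'a` just restates what elision infers for `&self`
        // methods, so newer clippy flags it (e.g. `clippy::needless_lifetimes`):
        //
        //     fn describe<'a>(&'a self, _prefix: &str) -> &'a str { &self.name }

        // Elided form: identical meaning, no named lifetime.
        fn describe(&self, _prefix: &str) -> &str {
            &self.name
        }
    }

    fn main() {
        let catalog = Catalog { name: "memory".to_string() };
        assert_eq!(catalog.describe("cat:"), "memory");
    }

The `map_or(false, ..)` to `== Some(..)` rewrite in `physical_expr.rs` further below comes from the same toolchain bump; it simply leans on `Option`'s `PartialEq` instead of a closure.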
- pub async fn table_provider<'a>( + pub async fn table_provider( &self, table_ref: impl Into, ) -> Result> { diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index c2e892d63da0..e90f9c32ee87 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -159,9 +159,7 @@ pub trait DynEq { impl DynEq for T { fn dyn_eq(&self, other: &dyn Any) -> bool { - other - .downcast_ref::() - .map_or(false, |other| other == self) + other.downcast_ref::() == Some(self) } } diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 0d7501610662..8395d3e5263d 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -291,9 +291,9 @@ impl PhysicalSortRequirement { /// Returns whether this requirement is equal or more specific than `other`. pub fn compatible(&self, other: &PhysicalSortRequirement) -> bool { self.expr.eq(&other.expr) - && other.options.map_or(true, |other_opts| { - self.options.map_or(false, |opts| opts == other_opts) - }) + && other + .options + .map_or(true, |other_opts| self.options == Some(other_opts)) } #[deprecated(since = "43.0.0", note = "use LexRequirement::from_lex_ordering")] From 268df42a600e85ab80c6934d4ffb54062649dac2 Mon Sep 17 00:00:00 2001 From: niebayes Date: Fri, 10 Jan 2025 17:14:16 +0800 Subject: [PATCH 18/19] fix: incorrect error message of function_length_check (#14056) * minor fix * add ut * remove check for 0 arg --- .../expr/src/type_coercion/functions.rs | 33 ++++++++++++++----- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 96bb5c4b2d8f..5294cc526d38 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -438,18 +438,11 @@ fn get_valid_types( } fn function_length_check(length: usize, expected_length: usize) -> Result<()> { - if length < 1 { - return plan_err!( - "The signature expected at least one argument but received {expected_length}" - ); - } - if length != expected_length { return plan_err!( - "The signature expected {length} arguments but received {expected_length}" + "The signature expected {expected_length} arguments but received {length}" ); } - Ok(()) } @@ -939,6 +932,7 @@ mod tests { use super::*; use arrow::datatypes::Field; + use datafusion_common::assert_contains; #[test] fn test_string_conversion() { @@ -1027,6 +1021,29 @@ mod tests { Ok(()) } + #[test] + fn test_get_valid_types_length_check() -> Result<()> { + let signature = TypeSignature::Numeric(1); + + let err = get_valid_types(&signature, &[]).unwrap_err(); + assert_contains!( + err.to_string(), + "The signature expected 1 arguments but received 0" + ); + + let err = get_valid_types( + &signature, + &[DataType::Int32, DataType::Int32, DataType::Int32], + ) + .unwrap_err(); + assert_contains!( + err.to_string(), + "The signature expected 1 arguments but received 3" + ); + + Ok(()) + } + #[test] fn test_fixed_list_wildcard_coerce() -> Result<()> { let inner = Arc::new(Field::new_list_field(DataType::Int32, false)); From 295ffb41f562f2c58b22ce36928df3f85f7fe09a Mon Sep 17 00:00:00 2001 From: Aleksey Kirilishin <54231417+avkirilishin@users.noreply.github.com> Date: Fri, 10 Jan 2025 13:19:23 +0300 Subject: [PATCH 19/19] test: Add plan execution during tests for bounded 
source (#14013) --- .../replace_with_order_preserving_variants.rs | 251 +++++++++++------- 1 file changed, 151 insertions(+), 100 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 96b2454fa330..9f5afc7abc2e 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -285,9 +285,7 @@ pub(crate) fn replace_with_order_preserving_variants( mod tests { use super::*; - use crate::datasource::file_format::file_compression_type::FileCompressionType; - use crate::datasource::listing::PartitionedFile; - use crate::datasource::physical_plan::{CsvExec, FileScanConfig}; + use crate::execution::TaskContext; use crate::physical_optimizer::test_utils::check_integrity; use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; use crate::physical_plan::filter::FilterExec; @@ -296,18 +294,24 @@ mod tests { use crate::physical_plan::{ displayable, get_plan_string, ExecutionPlan, Partitioning, }; - use crate::prelude::SessionConfig; + use crate::prelude::{SessionConfig, SessionContext}; use crate::test::TestStreamPartition; + use arrow::array::{ArrayRef, Int32Array}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use arrow::record_batch::RecordBatch; use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::Result; - use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::expressions::{self, col, Column}; use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_plan::collect; + use datafusion_physical_plan::memory::MemoryExec; use datafusion_physical_plan::streaming::StreamingTableExec; + use object_store::memory::InMemory; + use object_store::ObjectStore; + use url::Url; use rstest::rstest; @@ -328,20 +332,24 @@ mod tests { /// * `$PLAN`: The plan to optimize. /// * `$SOURCE_UNBOUNDED`: Whether the given plan contains an unbounded source. macro_rules! assert_optimized_in_all_boundedness_situations { - ($EXPECTED_UNBOUNDED_PLAN_LINES: expr, $EXPECTED_BOUNDED_PLAN_LINES: expr, $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $SOURCE_UNBOUNDED: expr) => { + ($EXPECTED_UNBOUNDED_PLAN_LINES: expr, $EXPECTED_BOUNDED_PLAN_LINES: expr, $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $SOURCE_UNBOUNDED: expr, $PREFER_EXISTING_SORT: expr) => { if $SOURCE_UNBOUNDED { assert_optimized_prefer_sort_on_off!( $EXPECTED_UNBOUNDED_PLAN_LINES, $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES, $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES, - $PLAN + $PLAN, + $PREFER_EXISTING_SORT, + $SOURCE_UNBOUNDED ); } else { assert_optimized_prefer_sort_on_off!( $EXPECTED_BOUNDED_PLAN_LINES, $EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES, $EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES, - $PLAN + $PLAN, + $PREFER_EXISTING_SORT, + $SOURCE_UNBOUNDED ); } }; @@ -359,19 +367,24 @@ mod tests { /// the flag `prefer_existing_sort` is `true`. /// * `$PLAN`: The plan to optimize. macro_rules! 
assert_optimized_prefer_sort_on_off { - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr) => { - assert_optimized!( - $EXPECTED_PLAN_LINES, - $EXPECTED_OPTIMIZED_PLAN_LINES, - $PLAN.clone(), - false - ); - assert_optimized!( - $EXPECTED_PLAN_LINES, - $EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES, - $PLAN, - true - ); + ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $PREFER_EXISTING_SORT: expr, $SOURCE_UNBOUNDED: expr) => { + if $PREFER_EXISTING_SORT { + assert_optimized!( + $EXPECTED_PLAN_LINES, + $EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES, + $PLAN, + $PREFER_EXISTING_SORT, + $SOURCE_UNBOUNDED + ); + } else { + assert_optimized!( + $EXPECTED_PLAN_LINES, + $EXPECTED_OPTIMIZED_PLAN_LINES, + $PLAN, + $PREFER_EXISTING_SORT, + $SOURCE_UNBOUNDED + ); + } }; } @@ -385,7 +398,7 @@ mod tests { /// * `$PLAN`: The plan to optimize. /// * `$PREFER_EXISTING_SORT`: Value of the `prefer_existing_sort` flag. macro_rules! assert_optimized { - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $PREFER_EXISTING_SORT: expr) => { + ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $PREFER_EXISTING_SORT: expr, $SOURCE_UNBOUNDED: expr) => { let physical_plan = $PLAN; let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -412,6 +425,19 @@ mod tests { expected_optimized_lines, actual, "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected_optimized_lines:#?}\nactual:\n\n{actual:#?}\n\n" ); + + if !$SOURCE_UNBOUNDED { + let ctx = SessionContext::new(); + let object_store = InMemory::new(); + object_store.put(&object_store::path::Path::from("file_path"), bytes::Bytes::from("").into()).await?; + ctx.register_object_store(&Url::parse("test://").unwrap(), Arc::new(object_store)); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let res = collect(optimized_physical_plan, task_ctx).await; + assert!( + res.is_ok(), + "Some errors occurred while executing the optimized physical plan: {:?}", res.unwrap_err() + ); + } }; } @@ -420,13 +446,14 @@ mod tests { // Searches for a simple sort and a repartition just after it, the second repartition with 1 input partition should not be affected async fn test_replace_multiple_input_repartition_1( #[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; let source = if source_unbounded { stream_exec_ordered(&schema, sort_exprs) } else { - csv_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs) }; let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); @@ -447,7 +474,7 @@ mod tests { " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; // Expected unbounded result (same for with and without flag) @@ -464,13 +491,13 @@ mod tests { " 
SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; let expected_optimized_bounded_sort_preserve = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; assert_optimized_in_all_boundedness_situations!( expected_input_unbounded, @@ -479,7 +506,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -488,13 +516,14 @@ mod tests { #[tokio::test] async fn test_with_inter_children_change_only( #[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr_default("a", &schema)]; let source = if source_unbounded { stream_exec_ordered(&schema, sort_exprs) } else { - csv_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); @@ -538,7 +567,7 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC", ]; // Expected unbounded result (same for with and without flag) @@ -564,7 +593,7 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC", ]; let expected_optimized_bounded_sort_preserve = [ "SortPreservingMergeExec: [a@0 ASC]", @@ -574,7 +603,7 @@ mod tests { " SortPreservingMergeExec: [a@0 ASC]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC", ]; assert_optimized_in_all_boundedness_situations!( expected_input_unbounded, @@ -583,7 +612,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -592,13 +622,14 @@ mod tests { #[tokio::test] async fn 
test_replace_multiple_input_repartition_2( #[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; let source = if source_unbounded { stream_exec_ordered(&schema, sort_exprs) } else { - csv_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs) }; let repartition_rr = repartition_exec_round_robin(source); let filter = filter_exec(repartition_rr); @@ -623,7 +654,7 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; // Expected unbounded result (same for with and without flag) @@ -642,14 +673,14 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; let expected_optimized_bounded_sort_preserve = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; assert_optimized_in_all_boundedness_situations!( expected_input_unbounded, @@ -658,7 +689,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -667,13 +699,14 @@ mod tests { #[tokio::test] async fn test_replace_multiple_input_repartition_with_extra_steps( #[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; let source = if source_unbounded { stream_exec_ordered(&schema, sort_exprs) } else { - csv_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); @@ -701,7 +734,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; // Expected unbounded result (same for with and without flag) @@ -722,7 +755,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, 
projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; let expected_optimized_bounded_sort_preserve = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", @@ -730,7 +763,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; assert_optimized_in_all_boundedness_situations!( expected_input_unbounded, @@ -739,7 +772,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -748,13 +782,14 @@ mod tests { #[tokio::test] async fn test_replace_multiple_input_repartition_with_extra_steps_2( #[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; let source = if source_unbounded { stream_exec_ordered(&schema, sort_exprs) } else { - csv_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs) }; let repartition_rr = repartition_exec_round_robin(source); let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr); @@ -786,7 +821,7 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; // Expected unbounded result (same for with and without flag) @@ -809,7 +844,7 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; let expected_optimized_bounded_sort_preserve = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", @@ -818,7 +853,7 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; assert_optimized_in_all_boundedness_situations!( expected_input_unbounded, @@ -827,7 +862,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -836,13 +872,14 @@ mod tests { #[tokio::test] async fn test_not_replacing_when_no_need_to_preserve_sorting( #[values(false, 
true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; let source = if source_unbounded { stream_exec_ordered(&schema, sort_exprs) } else { - csv_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); @@ -867,7 +904,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; // Expected unbounded result (same for with and without flag) @@ -887,7 +924,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; @@ -898,7 +935,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -907,13 +945,14 @@ mod tests { #[tokio::test] async fn test_with_multiple_replacable_repartitions( #[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; let source = if source_unbounded { stream_exec_ordered(&schema, sort_exprs) } else { - csv_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); @@ -944,7 +983,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; // Expected unbounded result (same for with and without flag) @@ -967,7 +1006,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; let expected_optimized_bounded_sort_preserve = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", @@ -976,7 +1015,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], 
output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; assert_optimized_in_all_boundedness_situations!( expected_input_unbounded, @@ -985,7 +1024,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -994,13 +1034,14 @@ mod tests { #[tokio::test] async fn test_not_replace_with_different_orderings( #[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; let source = if source_unbounded { stream_exec_ordered(&schema, sort_exprs) } else { - csv_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); @@ -1028,7 +1069,7 @@ mod tests { " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; // Expected unbounded result (same for with and without flag) @@ -1046,7 +1087,7 @@ mod tests { " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; @@ -1057,7 +1098,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -1066,13 +1108,14 @@ mod tests { #[tokio::test] async fn test_with_lost_ordering( #[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; let source = if source_unbounded { stream_exec_ordered(&schema, sort_exprs) } else { - csv_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); @@ -1093,7 +1136,7 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; // Expected unbounded result (same for with and without flag) @@ -1110,13 +1153,13 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: 
[[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; let expected_optimized_bounded_sort_preserve = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; assert_optimized_in_all_boundedness_situations!( expected_input_unbounded, @@ -1125,7 +1168,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -1134,13 +1178,14 @@ mod tests { #[tokio::test] async fn test_with_lost_and_kept_ordering( #[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; let source = if source_unbounded { stream_exec_ordered(&schema, sort_exprs) } else { - csv_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); @@ -1184,7 +1229,7 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; // Expected unbounded result (same for with and without flag) @@ -1211,7 +1256,7 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; let expected_optimized_bounded_sort_preserve = [ "SortPreservingMergeExec: [c@1 ASC]", @@ -1222,7 +1267,7 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; assert_optimized_in_all_boundedness_situations!( expected_input_unbounded, @@ -1231,7 +1276,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -1240,6 +1286,7 @@ mod tests { #[tokio::test] async fn test_with_multiple_child_trees( #[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; @@ -1247,7 +1294,7 @@ mod tests { let left_source = if source_unbounded { 
stream_exec_ordered(&schema, left_sort_exprs) } else { - csv_exec_sorted(&schema, left_sort_exprs) + memory_exec_sorted(&schema, left_sort_exprs) }; let left_repartition_rr = repartition_exec_round_robin(left_source); let left_repartition_hash = repartition_exec_hash(left_repartition_rr); @@ -1258,7 +1305,7 @@ mod tests { let right_source = if source_unbounded { stream_exec_ordered(&schema, right_sort_exprs) } else { - csv_exec_sorted(&schema, right_sort_exprs) + memory_exec_sorted(&schema, right_sort_exprs) }; let right_repartition_rr = repartition_exec_round_robin(right_source); let right_repartition_hash = repartition_exec_hash(right_repartition_rr); @@ -1299,11 +1346,11 @@ mod tests { " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; // Expected unbounded result (same for with and without flag) @@ -1330,11 +1377,11 @@ mod tests { " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; @@ -1345,7 +1392,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -1492,33 +1540,36 @@ mod tests { ) } - // creates a csv exec source for the test purposes - // projection and has_header parameters are given static due to testing needs - fn csv_exec_sorted( + // creates a memory exec source for the test purposes + // projection parameter is given static due to testing needs + fn memory_exec_sorted( schema: &SchemaRef, sort_exprs: impl IntoIterator, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - let projection: Vec = vec![0, 2, 3]; + pub fn make_partition(schema: &SchemaRef, sz: i32) -> RecordBatch { + let values = (0..sz).collect::>(); + let arr = Arc::new(Int32Array::from(values)); + let arr = arr as ArrayRef; - Arc::new( - CsvExec::builder( - FileScanConfig::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema.clone(), - ) - 
.with_file(PartitionedFile::new("file_path".to_string(), 100)) - .with_projection(Some(projection)) - .with_output_ordering(vec![sort_exprs]), + RecordBatch::try_new( + schema.clone(), + vec![arr.clone(), arr.clone(), arr.clone(), arr], ) - .with_has_header(true) - .with_delimeter(0) - .with_quote(b'"') - .with_escape(None) - .with_comment(None) - .with_newlines_in_values(false) - .with_file_compression_type(FileCompressionType::UNCOMPRESSED) - .build(), - ) + .unwrap() + } + + let rows = 5; + let partitions = 1; + let sort_exprs = sort_exprs.into_iter().collect(); + Arc::new({ + let data: Vec> = (0..partitions) + .map(|_| vec![make_partition(schema, rows)]) + .collect(); + let projection: Vec = vec![0, 2, 3]; + MemoryExec::try_new(&data, schema.clone(), Some(projection)) + .unwrap() + .try_with_sort_information(vec![sort_exprs]) + .unwrap() + }) } }