diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index cc23e99e8cba..19af21ec910b 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -30,7 +30,7 @@ jobs: - name: Checkout uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Audit licenses diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 14b2038e8794..ab6a615ab60b 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -24,7 +24,7 @@ jobs: path: asf-site - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 485d179571e3..099aab061435 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -348,7 +348,7 @@ jobs: - uses: actions/checkout@v4 with: submodules: true - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: python-version: "3.8" - name: Install PyArrow diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index f88c907b052f..76be04d5ef67 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1161,6 +1161,7 @@ dependencies = [ "mimalloc", "object_store", "parking_lot", + "parquet", "predicates", "regex", "rstest", diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index fd2dfd76c20e..5ce318aea3ac 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -40,6 +40,7 @@ env_logger = "0.9" mimalloc = { version = "0.1", default-features = false } object_store = { version = "0.8.0", features = ["aws", "gcp"] } parking_lot = { version = "0.12" } +parquet = { version = "49.0.0", default-features = false } regex = "1.8" rustyline = "11.0" tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index eeebe713d716..24f3399ee2be 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -16,12 +16,26 @@ // under the License. //! 
Functions that are query-able and searchable via the `\h` command -use arrow::array::StringArray; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::array::{Int64Array, StringArray}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; +use async_trait::async_trait; +use datafusion::common::DataFusionError; +use datafusion::common::{plan_err, Column}; +use datafusion::datasource::function::TableFunctionImpl; +use datafusion::datasource::TableProvider; use datafusion::error::Result; +use datafusion::execution::context::SessionState; +use datafusion::logical_expr::Expr; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::scalar::ScalarValue; +use parquet::file::reader::FileReader; +use parquet::file::serialized_reader::SerializedFileReader; +use parquet::file::statistics::Statistics; use std::fmt; +use std::fs::File; use std::str::FromStr; use std::sync::Arc; @@ -196,3 +210,208 @@ pub fn display_all_functions() -> Result<()> { println!("{}", pretty_format_batches(&[batch]).unwrap()); Ok(()) } + +/// PARQUET_META table function +struct ParquetMetadataTable { + schema: SchemaRef, + batch: RecordBatch, +} + +#[async_trait] +impl TableProvider for ParquetMetadataTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> arrow::datatypes::SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> datafusion::logical_expr::TableType { + datafusion::logical_expr::TableType::Base + } + + async fn scan( + &self, + _state: &SessionState, + projection: Option<&Vec<usize>>, + _filters: &[Expr], + _limit: Option<usize>, + ) -> Result<Arc<dyn ExecutionPlan>> { + Ok(Arc::new(MemoryExec::try_new( + &[vec![self.batch.clone()]], + TableProvider::schema(self), + projection.cloned(), + )?)) + } +} + +pub struct ParquetMetadataFunc {} + +impl TableFunctionImpl for ParquetMetadataFunc { + fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> { + let filename = match exprs.get(0) { + Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) => s, // single quote: parquet_metadata('x.parquet') + Some(Expr::Column(Column { name, ..
})) => name, // double quote: parquet_metadata("x.parquet") + _ => { + return plan_err!( + "parquet_metadata requires string argument as its input" + ); + } + }; + + let file = File::open(filename.clone())?; + let reader = SerializedFileReader::new(file)?; + let metadata = reader.metadata(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("filename", DataType::Utf8, true), + Field::new("row_group_id", DataType::Int64, true), + Field::new("row_group_num_rows", DataType::Int64, true), + Field::new("row_group_num_columns", DataType::Int64, true), + Field::new("row_group_bytes", DataType::Int64, true), + Field::new("column_id", DataType::Int64, true), + Field::new("file_offset", DataType::Int64, true), + Field::new("num_values", DataType::Int64, true), + Field::new("path_in_schema", DataType::Utf8, true), + Field::new("type", DataType::Utf8, true), + Field::new("stats_min", DataType::Utf8, true), + Field::new("stats_max", DataType::Utf8, true), + Field::new("stats_null_count", DataType::Int64, true), + Field::new("stats_distinct_count", DataType::Int64, true), + Field::new("stats_min_value", DataType::Utf8, true), + Field::new("stats_max_value", DataType::Utf8, true), + Field::new("compression", DataType::Utf8, true), + Field::new("encodings", DataType::Utf8, true), + Field::new("index_page_offset", DataType::Int64, true), + Field::new("dictionary_page_offset", DataType::Int64, true), + Field::new("data_page_offset", DataType::Int64, true), + Field::new("total_compressed_size", DataType::Int64, true), + Field::new("total_uncompressed_size", DataType::Int64, true), + ])); + + // construct recordbatch from metadata + let mut filename_arr = vec![]; + let mut row_group_id_arr = vec![]; + let mut row_group_num_rows_arr = vec![]; + let mut row_group_num_columns_arr = vec![]; + let mut row_group_bytes_arr = vec![]; + let mut column_id_arr = vec![]; + let mut file_offset_arr = vec![]; + let mut num_values_arr = vec![]; + let mut path_in_schema_arr = vec![]; + let mut type_arr = vec![]; + let mut stats_min_arr = vec![]; + let mut stats_max_arr = vec![]; + let mut stats_null_count_arr = vec![]; + let mut stats_distinct_count_arr = vec![]; + let mut stats_min_value_arr = vec![]; + let mut stats_max_value_arr = vec![]; + let mut compression_arr = vec![]; + let mut encodings_arr = vec![]; + let mut index_page_offset_arr = vec![]; + let mut dictionary_page_offset_arr = vec![]; + let mut data_page_offset_arr = vec![]; + let mut total_compressed_size_arr = vec![]; + let mut total_uncompressed_size_arr = vec![]; + for (rg_idx, row_group) in metadata.row_groups().iter().enumerate() { + for (col_idx, column) in row_group.columns().iter().enumerate() { + filename_arr.push(filename.clone()); + row_group_id_arr.push(rg_idx as i64); + row_group_num_rows_arr.push(row_group.num_rows()); + row_group_num_columns_arr.push(row_group.num_columns() as i64); + row_group_bytes_arr.push(row_group.total_byte_size()); + column_id_arr.push(col_idx as i64); + file_offset_arr.push(column.file_offset()); + num_values_arr.push(column.num_values()); + path_in_schema_arr.push(column.column_path().to_string()); + type_arr.push(column.column_type().to_string()); + if let Some(s) = column.statistics() { + let (min_val, max_val) = if s.has_min_max_set() { + let (min_val, max_val) = match s { + Statistics::Boolean(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::Int32(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::Int64(val) => { + (val.min().to_string(), 
val.max().to_string()) + } + Statistics::Int96(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::Float(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::Double(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::ByteArray(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::FixedLenByteArray(val) => { + (val.min().to_string(), val.max().to_string()) + } + }; + (Some(min_val), Some(max_val)) + } else { + (None, None) + }; + stats_min_arr.push(min_val.clone()); + stats_max_arr.push(max_val.clone()); + stats_null_count_arr.push(Some(s.null_count() as i64)); + stats_distinct_count_arr.push(s.distinct_count().map(|c| c as i64)); + stats_min_value_arr.push(min_val); + stats_max_value_arr.push(max_val); + } else { + stats_min_arr.push(None); + stats_max_arr.push(None); + stats_null_count_arr.push(None); + stats_distinct_count_arr.push(None); + stats_min_value_arr.push(None); + stats_max_value_arr.push(None); + }; + compression_arr.push(format!("{:?}", column.compression())); + encodings_arr.push(format!("{:?}", column.encodings())); + index_page_offset_arr.push(column.index_page_offset()); + dictionary_page_offset_arr.push(column.dictionary_page_offset()); + data_page_offset_arr.push(column.data_page_offset()); + total_compressed_size_arr.push(column.compressed_size()); + total_uncompressed_size_arr.push(column.uncompressed_size()); + } + } + + let rb = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(filename_arr)), + Arc::new(Int64Array::from(row_group_id_arr)), + Arc::new(Int64Array::from(row_group_num_rows_arr)), + Arc::new(Int64Array::from(row_group_num_columns_arr)), + Arc::new(Int64Array::from(row_group_bytes_arr)), + Arc::new(Int64Array::from(column_id_arr)), + Arc::new(Int64Array::from(file_offset_arr)), + Arc::new(Int64Array::from(num_values_arr)), + Arc::new(StringArray::from(path_in_schema_arr)), + Arc::new(StringArray::from(type_arr)), + Arc::new(StringArray::from(stats_min_arr)), + Arc::new(StringArray::from(stats_max_arr)), + Arc::new(Int64Array::from(stats_null_count_arr)), + Arc::new(Int64Array::from(stats_distinct_count_arr)), + Arc::new(StringArray::from(stats_min_value_arr)), + Arc::new(StringArray::from(stats_max_value_arr)), + Arc::new(StringArray::from(compression_arr)), + Arc::new(StringArray::from(encodings_arr)), + Arc::new(Int64Array::from(index_page_offset_arr)), + Arc::new(Int64Array::from(dictionary_page_offset_arr)), + Arc::new(Int64Array::from(data_page_offset_arr)), + Arc::new(Int64Array::from(total_compressed_size_arr)), + Arc::new(Int64Array::from(total_uncompressed_size_arr)), + ], + )?; + + let parquet_metadata = ParquetMetadataTable { schema, batch: rb }; + Ok(Arc::new(parquet_metadata)) + } +} diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index c069f458f196..8b1a9816afc0 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -22,6 +22,7 @@ use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool}; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::prelude::SessionContext; use datafusion_cli::catalog::DynamicFileCatalog; +use datafusion_cli::functions::ParquetMetadataFunc; use datafusion_cli::{ exec, print_format::PrintFormat, @@ -185,6 +186,8 @@ pub async fn main() -> Result<()> { ctx.state().catalog_list(), ctx.state_weak_ref(), ))); + // register `parquet_metadata` table function to get metadata from parquet files + 
ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); let mut print_options = PrintOptions { format: args.format, @@ -328,6 +331,8 @@ fn extract_memory_pool_size(size: &str) -> Result { #[cfg(test)] mod tests { + use datafusion::assert_batches_eq; + use super::*; fn assert_conversion(input: &str, expected: Result) { @@ -385,4 +390,34 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_parquet_metadata_works() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); + + // input with single quote + let sql = + "SELECT * FROM parquet_metadata('../datafusion/core/tests/data/fixed_size_list_array.parquet')"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + + let excepted = [ + "+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", + "| filename | row_group_id | row_group_num_rows | row_group_num_columns | row_group_bytes | column_id | file_offset | num_values | path_in_schema | type | stats_min | stats_max | stats_null_count | stats_distinct_count | stats_min_value | stats_max_value | compression | encodings | index_page_offset | dictionary_page_offset | data_page_offset | total_compressed_size | total_uncompressed_size |", + "+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", + "| ../datafusion/core/tests/data/fixed_size_list_array.parquet | 0 | 2 | 1 | 123 | 0 | 125 | 4 | \"f0.list.item\" | INT64 | 1 | 4 | 0 | | 1 | 4 | SNAPPY | [RLE_DICTIONARY, PLAIN, RLE] | | 4 | 46 | 121 | 123 |", + "+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", + ]; + assert_batches_eq!(excepted, &rbs); + + // input with double quote + let sql = + "SELECT * FROM parquet_metadata(\"../datafusion/core/tests/data/fixed_size_list_array.parquet\")"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + assert_batches_eq!(excepted, &rbs); + + Ok(()) + } } diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs index 9f25a0b2fa47..69f9c9530e87 100644 --- a/datafusion-examples/examples/custom_datasource.rs +++ b/datafusion-examples/examples/custom_datasource.rs @@ -80,7 +80,7 @@ async fn search_accounts( timeout(Duration::from_secs(10), async move { let result = dataframe.collect().await.unwrap(); - 
let record_batch = result.get(0).unwrap(); + let record_batch = result.first().unwrap(); assert_eq!(expected_result_length, record_batch.column(1).len()); dbg!(record_batch.columns()); diff --git a/datafusion-examples/examples/memtable.rs b/datafusion-examples/examples/memtable.rs index bef8f3e5bb8f..5cce578039e7 100644 --- a/datafusion-examples/examples/memtable.rs +++ b/datafusion-examples/examples/memtable.rs @@ -40,7 +40,7 @@ async fn main() -> Result<()> { timeout(Duration::from_secs(10), async move { let result = dataframe.collect().await.unwrap(); - let record_batch = result.get(0).unwrap(); + let record_batch = result.first().unwrap(); assert_eq!(1, record_batch.column(0).len()); dbg!(record_batch.columns()); diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs index bce633765281..e120c5e7bf8e 100644 --- a/datafusion-examples/examples/simple_udtf.rs +++ b/datafusion-examples/examples/simple_udtf.rs @@ -129,7 +129,7 @@ struct LocalCsvTableFunc {} impl TableFunctionImpl for LocalCsvTableFunc { fn call(&self, exprs: &[Expr]) -> Result> { - let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.get(0) else { + let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.first() else { return plan_err!("read_csv requires at least one string argument"); }; diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 52cd85675824..e06f947ad5e7 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -199,9 +199,16 @@ impl DFSchema { pub fn with_functional_dependencies( mut self, functional_dependencies: FunctionalDependencies, - ) -> Self { - self.functional_dependencies = functional_dependencies; - self + ) -> Result { + if functional_dependencies.is_valid(self.fields.len()) { + self.functional_dependencies = functional_dependencies; + Ok(self) + } else { + _plan_err!( + "Invalid functional dependency: {:?}", + functional_dependencies + ) + } } /// Create a new schema that contains the fields from this schema followed by the fields @@ -1476,8 +1483,8 @@ mod tests { DFSchema::new_with_metadata([a, b].to_vec(), HashMap::new()).unwrap(), ); let schema: Schema = df_schema.as_ref().clone().into(); - let a_df = df_schema.fields.get(0).unwrap().field(); - let a_arrow = schema.fields.get(0).unwrap(); + let a_df = df_schema.fields.first().unwrap().field(); + let a_arrow = schema.fields.first().unwrap(); assert_eq!(a_df.metadata(), a_arrow.metadata()) } diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 9114c669ab8b..4ae30ae86cdd 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -564,18 +564,16 @@ mod test { assert_eq!( err.split(DataFusionError::BACK_TRACE_SEP) .collect::>() - .get(0) + .first() .unwrap(), &"Error during planning: Err" ); - assert!( - err.split(DataFusionError::BACK_TRACE_SEP) - .collect::>() - .get(1) - .unwrap() - .len() - > 0 - ); + assert!(!err + .split(DataFusionError::BACK_TRACE_SEP) + .collect::>() + .get(1) + .unwrap() + .is_empty()); } #[cfg(not(feature = "backtrace"))] diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index 4587677e7726..1cb1751d713e 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -24,6 +24,7 @@ use std::ops::Deref; use std::vec::IntoIter; use crate::error::_plan_err; +use crate::utils::{merge_and_order_indices, set_difference}; 
use crate::{DFSchema, DFSchemaRef, DataFusionError, JoinType, Result}; use sqlparser::ast::TableConstraint; @@ -271,6 +272,29 @@ impl FunctionalDependencies { self.deps.extend(other.deps); } + /// Sanity checks if functional dependencies are valid. For example, if + /// there are 10 fields, we cannot receive any index further than 9. + pub fn is_valid(&self, n_field: usize) -> bool { + self.deps.iter().all( + |FunctionalDependence { + source_indices, + target_indices, + .. + }| { + source_indices + .iter() + .max() + .map(|&max_index| max_index < n_field) + .unwrap_or(true) + && target_indices + .iter() + .max() + .map(|&max_index| max_index < n_field) + .unwrap_or(true) + }, + ) + } + /// Adds the `offset` value to `source_indices` and `target_indices` for /// each functional dependency. pub fn add_offset(&mut self, offset: usize) { @@ -442,44 +466,56 @@ pub fn aggregate_functional_dependencies( } in &func_dependencies.deps { // Keep source indices in a `HashSet` to prevent duplicate entries: - let mut new_source_indices = HashSet::new(); + let mut new_source_indices = vec![]; + let mut new_source_field_names = vec![]; let source_field_names = source_indices .iter() .map(|&idx| aggr_input_fields[idx].qualified_name()) .collect::>(); + for (idx, group_by_expr_name) in group_by_expr_names.iter().enumerate() { // When one of the input determinant expressions matches with // the GROUP BY expression, add the index of the GROUP BY // expression as a new determinant key: if source_field_names.contains(group_by_expr_name) { - new_source_indices.insert(idx); + new_source_indices.push(idx); + new_source_field_names.push(group_by_expr_name.clone()); } } + let existing_target_indices = + get_target_functional_dependencies(aggr_input_schema, group_by_expr_names); + let new_target_indices = get_target_functional_dependencies( + aggr_input_schema, + &new_source_field_names, + ); + let mode = if existing_target_indices == new_target_indices + && new_target_indices.is_some() + { + // If dependency covers all GROUP BY expressions, mode will be `Single`: + Dependency::Single + } else { + // Otherwise, existing mode is preserved: + *mode + }; // All of the composite indices occur in the GROUP BY expression: if new_source_indices.len() == source_indices.len() { aggregate_func_dependencies.push( FunctionalDependence::new( - new_source_indices.into_iter().collect(), + new_source_indices, target_indices.clone(), *nullable, ) - // input uniqueness stays the same when GROUP BY matches with input functional dependence determinants - .with_mode(*mode), + .with_mode(mode), ); } } + // If we have a single GROUP BY key, we can guarantee uniqueness after // aggregation: if group_by_expr_names.len() == 1 { // If `source_indices` contain 0, delete this functional dependency // as it will be added anyway with mode `Dependency::Single`: - if let Some(idx) = aggregate_func_dependencies - .iter() - .position(|item| item.source_indices.contains(&0)) - { - // Delete the functional dependency that contains zeroth idx: - aggregate_func_dependencies.remove(idx); - } + aggregate_func_dependencies.retain(|item| !item.source_indices.contains(&0)); // Add a new functional dependency associated with the whole table: aggregate_func_dependencies.push( // Use nullable property of the group by expression @@ -527,8 +563,61 @@ pub fn get_target_functional_dependencies( combined_target_indices.extend(target_indices.iter()); } } - (!combined_target_indices.is_empty()) - .then_some(combined_target_indices.iter().cloned().collect::>()) + 
(!combined_target_indices.is_empty()).then_some({ + let mut result = combined_target_indices.into_iter().collect::>(); + result.sort(); + result + }) +} + +/// Returns indices for the minimal subset of GROUP BY expressions that are +/// functionally equivalent to the original set of GROUP BY expressions. +pub fn get_required_group_by_exprs_indices( + schema: &DFSchema, + group_by_expr_names: &[String], +) -> Option> { + let dependencies = schema.functional_dependencies(); + let field_names = schema + .fields() + .iter() + .map(|item| item.qualified_name()) + .collect::>(); + let mut groupby_expr_indices = group_by_expr_names + .iter() + .map(|group_by_expr_name| { + field_names + .iter() + .position(|field_name| field_name == group_by_expr_name) + }) + .collect::>>()?; + + groupby_expr_indices.sort(); + for FunctionalDependence { + source_indices, + target_indices, + .. + } in &dependencies.deps + { + if source_indices + .iter() + .all(|source_idx| groupby_expr_indices.contains(source_idx)) + { + // If all source indices are among GROUP BY expression indices, we + // can remove target indices from GROUP BY expression indices and + // use source indices instead. + groupby_expr_indices = set_difference(&groupby_expr_indices, target_indices); + groupby_expr_indices = + merge_and_order_indices(groupby_expr_indices, source_indices); + } + } + groupby_expr_indices + .iter() + .map(|idx| { + group_by_expr_names + .iter() + .position(|name| &field_names[*idx] == name) + }) + .collect() } /// Updates entries inside the `entries` vector with their corresponding diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 6df89624fc51..ed547782e4a5 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -56,8 +56,9 @@ pub use file_options::file_type::{ }; pub use file_options::FileTypeWriterOptions; pub use functional_dependencies::{ - aggregate_functional_dependencies, get_target_functional_dependencies, Constraint, - Constraints, Dependency, FunctionalDependence, FunctionalDependencies, + aggregate_functional_dependencies, get_required_group_by_exprs_indices, + get_target_functional_dependencies, Constraint, Constraints, Dependency, + FunctionalDependence, FunctionalDependencies, }; pub use join_type::{JoinConstraint, JoinSide, JoinType}; pub use param_value::ParamValues; diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 1f302c750916..d730fbf89b72 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -46,6 +46,7 @@ use arrow::{ }, }; use arrow_array::cast::as_list_array; +use arrow_array::types::ArrowTimestampType; use arrow_array::{ArrowNativeTypeOp, Scalar}; /// A dynamically typed, nullable single value, (the single-valued counter-part @@ -358,69 +359,47 @@ impl PartialOrd for ScalarValue { (FixedSizeBinary(_, _), _) => None, (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2), (LargeBinary(_), _) => None, - (List(arr1), List(arr2)) | (FixedSizeList(arr1), FixedSizeList(arr2)) => { - if arr1.data_type() == arr2.data_type() { - let list_arr1 = as_list_array(arr1); - let list_arr2 = as_list_array(arr2); - if list_arr1.len() != list_arr2.len() { - return None; - } - for i in 0..list_arr1.len() { - let arr1 = list_arr1.value(i); - let arr2 = list_arr2.value(i); - - let lt_res = - arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; - let eq_res = - arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; - - for j in 0..lt_res.len() { - if lt_res.is_valid(j) && lt_res.value(j) { - return 
Some(Ordering::Less); - } - if eq_res.is_valid(j) && !eq_res.value(j) { - return Some(Ordering::Greater); - } - } + (List(arr1), List(arr2)) + | (FixedSizeList(arr1), FixedSizeList(arr2)) + | (LargeList(arr1), LargeList(arr2)) => { + // ScalarValue::List / ScalarValue::FixedSizeList / ScalarValue::LargeList are ensure to have length 1 + assert_eq!(arr1.len(), 1); + assert_eq!(arr2.len(), 1); + + if arr1.data_type() != arr2.data_type() { + return None; + } + + fn first_array_for_list(arr: &ArrayRef) -> ArrayRef { + if let Some(arr) = arr.as_list_opt::() { + arr.value(0) + } else if let Some(arr) = arr.as_list_opt::() { + arr.value(0) + } else if let Some(arr) = arr.as_fixed_size_list_opt() { + arr.value(0) + } else { + unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen") } - Some(Ordering::Equal) - } else { - None } - } - (LargeList(arr1), LargeList(arr2)) => { - if arr1.data_type() == arr2.data_type() { - let list_arr1 = as_large_list_array(arr1); - let list_arr2 = as_large_list_array(arr2); - if list_arr1.len() != list_arr2.len() { - return None; + + let arr1 = first_array_for_list(arr1); + let arr2 = first_array_for_list(arr2); + + let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; + let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; + + for j in 0..lt_res.len() { + if lt_res.is_valid(j) && lt_res.value(j) { + return Some(Ordering::Less); } - for i in 0..list_arr1.len() { - let arr1 = list_arr1.value(i); - let arr2 = list_arr2.value(i); - - let lt_res = - arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; - let eq_res = - arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; - - for j in 0..lt_res.len() { - if lt_res.is_valid(j) && lt_res.value(j) { - return Some(Ordering::Less); - } - if eq_res.is_valid(j) && !eq_res.value(j) { - return Some(Ordering::Greater); - } - } + if eq_res.is_valid(j) && !eq_res.value(j) { + return Some(Ordering::Greater); } - Some(Ordering::Equal) - } else { - None } + + Some(Ordering::Equal) } - (List(_), _) => None, - (LargeList(_), _) => None, - (FixedSizeList(_), _) => None, + (List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None, (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), (Date32(_), _) => None, (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), @@ -796,6 +775,20 @@ impl ScalarValue { ScalarValue::IntervalMonthDayNano(Some(val)) } + /// Returns a [`ScalarValue`] representing + /// `value` and `tz_opt` timezone + pub fn new_timestamp( + value: Option, + tz_opt: Option>, + ) -> Self { + match T::UNIT { + TimeUnit::Second => ScalarValue::TimestampSecond(value, tz_opt), + TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(value, tz_opt), + TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(value, tz_opt), + TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(value, tz_opt), + } + } + /// Create a zero value in the given type. 
pub fn new_zero(datatype: &DataType) -> Result { assert!(datatype.is_primitive()); @@ -3644,24 +3637,6 @@ mod tests { ])]), )); assert_eq!(a.partial_cmp(&b), Some(Ordering::Less)); - - let a = - ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![ - Some(vec![Some(10), Some(2), Some(3)]), - None, - Some(vec![Some(10), Some(2), Some(3)]), - ]), - )); - let b = - ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![ - Some(vec![Some(10), Some(2), Some(3)]), - None, - Some(vec![Some(10), Some(2), Some(3)]), - ]), - )); - assert_eq!(a.partial_cmp(&b), Some(Ordering::Equal)); } #[test] diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index 7f2dc61c07bf..fecab8835e50 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -25,7 +25,7 @@ use arrow::compute; use arrow::compute::{partition, SortColumn, SortOptions}; use arrow::datatypes::{Field, SchemaRef, UInt32Type}; use arrow::record_batch::RecordBatch; -use arrow_array::{Array, LargeListArray, ListArray}; +use arrow_array::{Array, LargeListArray, ListArray, RecordBatchOptions}; use arrow_schema::DataType; use sqlparser::ast::Ident; use sqlparser::dialect::GenericDialect; @@ -90,8 +90,12 @@ pub fn get_record_batch_at_indices( indices: &PrimitiveArray, ) -> Result { let new_columns = get_arrayref_at_indices(record_batch.columns(), indices)?; - RecordBatch::try_new(record_batch.schema(), new_columns) - .map_err(DataFusionError::ArrowError) + RecordBatch::try_new_with_options( + record_batch.schema(), + new_columns, + &RecordBatchOptions::new().with_row_count(Some(indices.len())), + ) + .map_err(DataFusionError::ArrowError) } /// This function compares two tuples depending on the given sort options. @@ -135,7 +139,7 @@ pub fn bisect( ) -> Result { let low: usize = 0; let high: usize = item_columns - .get(0) + .first() .ok_or_else(|| { DataFusionError::Internal("Column array shouldn't be empty".to_string()) })? @@ -186,7 +190,7 @@ pub fn linear_search( ) -> Result { let low: usize = 0; let high: usize = item_columns - .get(0) + .first() .ok_or_else(|| { DataFusionError::Internal("Column array shouldn't be empty".to_string()) })? 
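A note on the `get_record_batch_at_indices` change above: `RecordBatch::try_new` cannot infer a row count when the projected batch ends up with zero columns, so passing an explicit row count via `RecordBatchOptions` keeps the batch size correct. A minimal standalone sketch of that arrow-rs API (the `zero_column_batch` helper is hypothetical, not part of this patch):

use std::sync::Arc;

use arrow_array::{RecordBatch, RecordBatchOptions};
use arrow_schema::Schema;

// Build a batch with no columns but a known row count, e.g. after a
// projection that keeps zero columns. `try_new` cannot infer the count
// from an empty column list, so the options carry it explicitly.
fn zero_column_batch(num_rows: usize) -> RecordBatch {
    RecordBatch::try_new_with_options(
        Arc::new(Schema::empty()),
        vec![],
        &RecordBatchOptions::new().with_row_count(Some(num_rows)),
    )
    .expect("row count is explicit, so construction succeeds")
}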
diff --git a/datafusion/core/benches/sort_limit_query_sql.rs b/datafusion/core/benches/sort_limit_query_sql.rs index efed5a04e7a5..cfd4b8bc4bba 100644 --- a/datafusion/core/benches/sort_limit_query_sql.rs +++ b/datafusion/core/benches/sort_limit_query_sql.rs @@ -99,7 +99,7 @@ fn create_context() -> Arc> { ctx_holder.lock().push(Arc::new(Mutex::new(ctx))) }); - let ctx = ctx_holder.lock().get(0).unwrap().clone(); + let ctx = ctx_holder.lock().first().unwrap().clone(); ctx } diff --git a/datafusion/core/benches/sql_query_with_io.rs b/datafusion/core/benches/sql_query_with_io.rs index 1f9b4dc6ccf7..c7a838385bd6 100644 --- a/datafusion/core/benches/sql_query_with_io.rs +++ b/datafusion/core/benches/sql_query_with_io.rs @@ -93,10 +93,9 @@ async fn setup_files(store: Arc) { for partition in 0..TABLE_PARTITIONS { for file in 0..PARTITION_FILES { let data = create_parquet_file(&mut rng, file * FILE_ROWS); - let location = Path::try_from(format!( + let location = Path::from(format!( "{table_name}/partition={partition}/{file}.parquet" - )) - .unwrap(); + )); store.put(&location, data).await.unwrap(); } } diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 52b5157b7313..c40dd522a457 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -23,44 +23,43 @@ mod parquet; use std::any::Any; use std::sync::Arc; +use crate::arrow::datatypes::{Schema, SchemaRef}; +use crate::arrow::record_batch::RecordBatch; +use crate::arrow::util::pretty; +use crate::datasource::{provider_as_source, MemTable, TableProvider}; +use crate::error::Result; +use crate::execution::{ + context::{SessionState, TaskContext}, + FunctionRegistry, +}; +use crate::logical_expr::utils::find_window_exprs; +use crate::logical_expr::{ + col, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Partitioning, TableType, +}; +use crate::physical_plan::{ + collect, collect_partitioned, execute_stream, execute_stream_partitioned, + ExecutionPlan, SendableRecordBatchStream, +}; +use crate::prelude::SessionContext; + use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; use arrow::compute::{cast, concat}; use arrow::csv::WriterBuilder; use arrow::datatypes::{DataType, Field}; -use async_trait::async_trait; use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ - DataFusionError, FileType, FileTypeWriterOptions, ParamValues, SchemaError, - UnnestOptions, + Column, DFSchema, DataFusionError, FileType, FileTypeWriterOptions, ParamValues, + SchemaError, UnnestOptions, }; use datafusion_expr::dml::CopyOptions; - -use datafusion_common::{Column, DFSchema}; use datafusion_expr::{ avg, count, is_null, max, median, min, stddev, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, }; -use crate::arrow::datatypes::Schema; -use crate::arrow::datatypes::SchemaRef; -use crate::arrow::record_batch::RecordBatch; -use crate::arrow::util::pretty; -use crate::datasource::{provider_as_source, MemTable, TableProvider}; -use crate::error::Result; -use crate::execution::{ - context::{SessionState, TaskContext}, - FunctionRegistry, -}; -use crate::logical_expr::{ - col, utils::find_window_exprs, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, - Partitioning, TableType, -}; -use crate::physical_plan::SendableRecordBatchStream; -use crate::physical_plan::{collect, collect_partitioned}; -use 
crate::physical_plan::{execute_stream, execute_stream_partitioned, ExecutionPlan}; -use crate::prelude::SessionContext; +use async_trait::async_trait; /// Contains options that control how data is /// written out from a DataFrame @@ -1343,24 +1342,43 @@ impl TableProvider for DataFrameTableProvider { mod tests { use std::vec; - use arrow::array::Int32Array; - use arrow::datatypes::DataType; + use super::*; + use crate::execution::context::SessionConfig; + use crate::physical_plan::{ColumnarValue, Partitioning, PhysicalExpr}; + use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name}; + use crate::{assert_batches_sorted_eq, execution::context::SessionContext}; + use arrow::array::{self, Int32Array}; + use arrow::datatypes::DataType; + use datafusion_common::{Constraint, Constraints, ScalarValue}; use datafusion_expr::{ avg, cast, count, count_distinct, create_udf, expr, lit, max, min, sum, - BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame, - WindowFunction, + BinaryExpr, BuiltInWindowFunction, Operator, ScalarFunctionImplementation, + Volatility, WindowFrame, WindowFunction, }; use datafusion_physical_expr::expressions::Column; - - use crate::execution::context::SessionConfig; - use crate::physical_plan::ColumnarValue; - use crate::physical_plan::Partitioning; - use crate::physical_plan::PhysicalExpr; - use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name}; - use crate::{assert_batches_sorted_eq, execution::context::SessionContext}; - - use super::*; + use datafusion_physical_plan::get_plan_string; + + pub fn table_with_constraints() -> Arc { + let dual_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + ])); + let batch = RecordBatch::try_new( + dual_schema.clone(), + vec![ + Arc::new(array::Int32Array::from(vec![1])), + Arc::new(array::StringArray::from(vec!["a"])), + ], + ) + .unwrap(); + let provider = MemTable::try_new(dual_schema, vec![vec![batch]]) + .unwrap() + .with_constraints(Constraints::new_unverified(vec![Constraint::PrimaryKey( + vec![0], + )])); + Arc::new(provider) + } async fn assert_logical_expr_schema_eq_physical_expr_schema( df: DataFrame, @@ -1557,6 +1575,262 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_aggregate_with_pk() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let table1 = table_with_constraints(); + let df = ctx.read_table(table1)?; + let col_id = Expr::Column(datafusion_common::Column { + relation: None, + name: "id".to_string(), + }); + let col_name = Expr::Column(datafusion_common::Column { + relation: None, + name: "name".to_string(), + }); + + // group by contains id column + let group_expr = vec![col_id.clone()]; + let aggr_expr = vec![]; + let df = df.aggregate(group_expr, aggr_expr)?; + + // expr list contains id, name + let expr_list = vec![col_id, col_name]; + let df = df.select(expr_list)?; + let physical_plan = df.clone().create_physical_plan().await?; + let expected = vec![ + "AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ]; + // Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + // Since id and name are functionally dependant, 
we can use name among expression + // even if it is not part of the group by expression. + let df_results = collect(physical_plan, ctx.task_ctx()).await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + ["+----+------+", + "| id | name |", + "+----+------+", + "| 1 | a |", + "+----+------+",], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_with_pk2() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let table1 = table_with_constraints(); + let df = ctx.read_table(table1)?; + let col_id = Expr::Column(datafusion_common::Column { + relation: None, + name: "id".to_string(), + }); + let col_name = Expr::Column(datafusion_common::Column { + relation: None, + name: "name".to_string(), + }); + + // group by contains id column + let group_expr = vec![col_id.clone()]; + let aggr_expr = vec![]; + let df = df.aggregate(group_expr, aggr_expr)?; + + let condition1 = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col_id.clone()), + Operator::Eq, + Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))), + )); + let condition2 = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col_name), + Operator::Eq, + Box::new(Expr::Literal(ScalarValue::Utf8(Some("a".to_string())))), + )); + // Predicate refers to id, and name fields + let predicate = Expr::BinaryExpr(BinaryExpr::new( + Box::new(condition1), + Operator::And, + Box::new(condition2), + )); + let df = df.filter(predicate)?; + let physical_plan = df.clone().create_physical_plan().await?; + + let expected = vec![ + "CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: id@0 = 1 AND name@1 = a", + " AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ]; + // Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + // Since id and name are functionally dependant, we can use name among expression + // even if it is not part of the group by expression. + let df_results = collect(physical_plan, ctx.task_ctx()).await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + ["+----+------+", + "| id | name |", + "+----+------+", + "| 1 | a |", + "+----+------+",], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_with_pk3() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let table1 = table_with_constraints(); + let df = ctx.read_table(table1)?; + let col_id = Expr::Column(datafusion_common::Column { + relation: None, + name: "id".to_string(), + }); + let col_name = Expr::Column(datafusion_common::Column { + relation: None, + name: "name".to_string(), + }); + + // group by contains id column + let group_expr = vec![col_id.clone()]; + let aggr_expr = vec![]; + // group by id, + let df = df.aggregate(group_expr, aggr_expr)?; + + let condition1 = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col_id.clone()), + Operator::Eq, + Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))), + )); + // Predicate refers to id field + let predicate = condition1; + // id=0 + let df = df.filter(predicate)?; + // Select expression refers to id, and name columns. 
+ // id, name + let df = df.select(vec![col_id.clone(), col_name.clone()])?; + let physical_plan = df.clone().create_physical_plan().await?; + + let expected = vec![ + "CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: id@0 = 1", + " AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ]; + // Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + // Since id and name are functionally dependant, we can use name among expression + // even if it is not part of the group by expression. + let df_results = collect(physical_plan, ctx.task_ctx()).await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + ["+----+------+", + "| id | name |", + "+----+------+", + "| 1 | a |", + "+----+------+",], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_with_pk4() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let table1 = table_with_constraints(); + let df = ctx.read_table(table1)?; + let col_id = Expr::Column(datafusion_common::Column { + relation: None, + name: "id".to_string(), + }); + + // group by contains id column + let group_expr = vec![col_id.clone()]; + let aggr_expr = vec![]; + // group by id, + let df = df.aggregate(group_expr, aggr_expr)?; + + let condition1 = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col_id.clone()), + Operator::Eq, + Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))), + )); + // Predicate refers to id field + let predicate = condition1; + // id=1 + let df = df.filter(predicate)?; + // Select expression refers to id column. + // id + let df = df.select(vec![col_id.clone()])?; + let physical_plan = df.clone().create_physical_plan().await?; + + // In this case aggregate shouldn't be expanded, since these + // columns are not used. + let expected = vec![ + "CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: id@0 = 1", + " AggregateExec: mode=Single, gby=[id@0 as id], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ]; + // Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + // Since id and name are functionally dependant, we can use name among expression + // even if it is not part of the group by expression. + let df_results = collect(physical_plan, ctx.task_ctx()).await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + [ "+----+", + "| id |", + "+----+", + "| 1 |", + "+----+",], + &df_results + ); + + Ok(()) + } + #[tokio::test] async fn test_distinct() -> Result<()> { let t = test_table().await?; diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index cf6b87408107..09e54558f12e 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -1803,8 +1803,8 @@ mod tests { // there is only one row group in one file. 
assert_eq!(page_index.len(), 1); assert_eq!(offset_index.len(), 1); - let page_index = page_index.get(0).unwrap(); - let offset_index = offset_index.get(0).unwrap(); + let page_index = page_index.first().unwrap(); + let offset_index = offset_index.first().unwrap(); // 13 col in one row group assert_eq!(page_index.len(), 13); diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/core/src/datasource/file_format/write/demux.rs index 27c65dd459ec..fa4ed8437015 100644 --- a/datafusion/core/src/datasource/file_format/write/demux.rs +++ b/datafusion/core/src/datasource/file_format/write/demux.rs @@ -264,12 +264,9 @@ async fn hive_style_partitions_demuxer( // TODO: upstream RecordBatch::take to arrow-rs let take_indices = builder.finish(); let struct_array: StructArray = rb.clone().into(); - let parted_batch = RecordBatch::try_from( + let parted_batch = RecordBatch::from( arrow::compute::take(&struct_array, &take_indices, None)?.as_struct(), - ) - .map_err(|_| { - DataFusionError::Internal("Unexpected error partitioning batch!".into()) - })?; + ); // Get or create channel for this batch let part_tx = match value_map.get_mut(&part_key) { diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index effeacc4804f..a7f69a1d3cc8 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -157,7 +157,7 @@ impl ListingTableConfig { /// Infer `ListingOptions` based on `table_path` suffix. pub async fn infer_options(self, state: &SessionState) -> Result { - let store = if let Some(url) = self.table_paths.get(0) { + let store = if let Some(url) = self.table_paths.first() { state.runtime_env().object_store(url)? } else { return Ok(self); @@ -165,7 +165,7 @@ impl ListingTableConfig { let file = self .table_paths - .get(0) + .first() .unwrap() .list_all_files(state, store.as_ref(), "") .await? @@ -191,7 +191,7 @@ impl ListingTableConfig { pub async fn infer_schema(self, state: &SessionState) -> Result { match self.options { Some(options) => { - let schema = if let Some(url) = self.table_paths.get(0) { + let schema = if let Some(url) = self.table_paths.first() { options.infer_schema(state, url).await? } else { Arc::new(Schema::empty()) @@ -710,7 +710,7 @@ impl TableProvider for ListingTable { None }; - let object_store_url = if let Some(url) = self.table_paths.get(0) { + let object_store_url = if let Some(url) = self.table_paths.first() { url.object_store() } else { return Ok(Arc::new(EmptyExec::new(false, Arc::new(Schema::empty())))); @@ -835,7 +835,7 @@ impl TableProvider for ListingTable { // Multiple sort orders in outer vec are equivalent, so we pass only the first one let ordering = self .try_create_output_ordering()? - .get(0) + .first() .ok_or(DataFusionError::Internal( "Expected ListingTable to have a sort order, but none found!".into(), ))? @@ -872,7 +872,7 @@ impl ListingTable { filters: &'a [Expr], limit: Option, ) -> Result<(Vec>, Statistics)> { - let store = if let Some(url) = self.table_paths.get(0) { + let store = if let Some(url) = self.table_paths.first() { ctx.runtime_env().object_store(url)? 
} else { return Ok((vec![], Statistics::new_unknown(&self.file_schema))); diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 73dcb32ac81f..9c3b523a652c 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -357,9 +357,9 @@ mod tests { ) .unwrap(); let meta = file_groups - .get(0) + .first() .unwrap() - .get(0) + .first() .unwrap() .clone() .object_meta; @@ -391,9 +391,9 @@ mod tests { ) .unwrap(); let path = file_groups - .get(0) + .first() .unwrap() - .get(0) + .first() .unwrap() .object_meta .location diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 4cf115d03a9b..14e550eab1d5 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -135,7 +135,7 @@ impl DisplayAs for FileScanConfig { write!(f, ", infinite_source=true")?; } - if let Some(ordering) = orderings.get(0) { + if let Some(ordering) = orderings.first() { if !ordering.is_empty() { let start = if orderings.len() == 1 { ", output_ordering=" diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 0ab2046097c4..65414f5619a5 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -350,6 +350,7 @@ mod tests { use arrow::datatypes::Schema; use arrow::datatypes::{DataType, Field}; use datafusion_common::{config::ConfigOptions, TableReference, ToDFSchema}; + use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{ builder::LogicalTableSource, cast, col, lit, AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF, @@ -994,6 +995,26 @@ mod tests { create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() } + // Note the values in the `String` column are: + // ❯ select * from './parquet-testing/data/data_index_bloom_encoding_stats.parquet'; + // +-----------+ + // | String | + // +-----------+ + // | Hello | + // | This is | + // | a | + // | test | + // | How | + // | are you | + // | doing | + // | today | + // | the quick | + // | brown fox | + // | jumps | + // | over | + // | the lazy | + // | dog | + // +-----------+ #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_simple_expr() { // load parquet file @@ -1002,7 +1023,7 @@ mod tests { let path = format!("{testdata}/{file_name}"); let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - // generate pruning predicate + // generate pruning predicate `(String = "Hello_Not_exists")` let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); let expr = col(r#""String""#).eq(lit("Hello_Not_Exists")); let expr = logical2physical(&expr, &schema); @@ -1029,7 +1050,7 @@ mod tests { let path = format!("{testdata}/{file_name}"); let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - // generate pruning predicate + // generate pruning predicate `(String = "Hello_Not_exists" OR String = "Hello_Not_exists2")` let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); let expr = lit("1").eq(lit("1")).and( col(r#""String""#) @@ -1091,7 +1112,7 @@ mod tests { let path = format!("{testdata}/{file_name}"); let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - // generate pruning predicate + // generate pruning 
predicate `(String = "Hello")` let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); let expr = col(r#""String""#).eq(lit("Hello")); let expr = logical2physical(&expr, &schema); @@ -1110,6 +1131,94 @@ mod tests { assert_eq!(pruned_row_groups, row_groups); } + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_exists_2_values() { + // load parquet file + let testdata = datafusion_common::test_util::parquet_test_data(); + let file_name = "data_index_bloom_encoding_stats.parquet"; + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + // generate pruning predicate `(String = "Hello") OR (String = "the quick")` + let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); + let expr = col(r#""String""#) + .eq(lit("Hello")) + .or(col(r#""String""#).eq(lit("the quick"))); + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let row_groups = vec![0]; + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert_eq!(pruned_row_groups, row_groups); + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_exists_3_values() { + // load parquet file + let testdata = datafusion_common::test_util::parquet_test_data(); + let file_name = "data_index_bloom_encoding_stats.parquet"; + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + // generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")` + let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); + let expr = col(r#""String""#) + .eq(lit("Hello")) + .or(col(r#""String""#).eq(lit("the quick"))) + .or(col(r#""String""#).eq(lit("are you"))); + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let row_groups = vec![0]; + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert_eq!(pruned_row_groups, row_groups); + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_or_not_eq() { + // load parquet file + let testdata = datafusion_common::test_util::parquet_test_data(); + let file_name = "data_index_bloom_encoding_stats.parquet"; + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + // generate pruning predicate `(String = "foo") OR (String != "bar")` + let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); + let expr = col(r#""String""#) + .not_eq(lit("foo")) + .or(col(r#""String""#).not_eq(lit("bar"))); + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let row_groups = vec![0]; + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert_eq!(pruned_row_groups, row_groups); + } + #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_without_bloom_filter() { // load parquet file @@ -1118,7 +1227,7 @@ mod tests { let path = format!("{testdata}/{file_name}"); let data = 
bytes::Bytes::from(std::fs::read(path).unwrap()); - // generate pruning predicate + // generate pruning predicate on a column without a bloom filter let schema = Schema::new(vec![Field::new("string_col", DataType::Utf8, false)]); let expr = col(r#""string_col""#).eq(lit("0")); let expr = logical2physical(&expr, &schema); diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index f6c94edd8ca3..67a2eaf0d9b3 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -20,6 +20,7 @@ //! projections one by one if the operator below is amenable to this. If a //! projection reaches a source, it can even disappear from the plan entirely. +use std::collections::HashMap; use std::sync::Arc; use super::output_requirements::OutputRequirementExec; @@ -42,9 +43,9 @@ use crate::physical_plan::{Distribution, ExecutionPlan}; use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_common::JoinSide; -use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::{ Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, }; @@ -245,12 +246,36 @@ fn try_swapping_with_streaming_table( } /// Unifies `projection` with its input (which is also a [`ProjectionExec`]). -/// Two consecutive projections can always merge into a single projection. fn try_unifying_projections( projection: &ProjectionExec, child: &ProjectionExec, ) -> Result<Option<Arc<dyn ExecutionPlan>>> { let mut projected_exprs = vec![]; + let mut column_ref_map: HashMap<Column, usize> = HashMap::new(); + + // Collect the column reference counts in the outer projection. + projection.expr().iter().for_each(|(expr, _)| { + expr.apply(&mut |expr| { + Ok({ + if let Some(column) = expr.as_any().downcast_ref::<Column>() { + *column_ref_map.entry(column.clone()).or_default() += 1; + } + VisitRecursion::Continue + }) + }) + .unwrap(); + }); + + // Merging these projections is not beneficial if a non-trivial expression is + // referred to more than once: keeping the projections separate acts as a + // caching mechanism for such computations. + // See discussion in: https://github.com/apache/arrow-datafusion/issues/8296 + if column_ref_map.iter().any(|(column, count)| { + *count > 1 && !is_expr_trivial(&child.expr()[column.index()].0.clone()) + }) { + return Ok(None); + } + for (expr, alias) in projection.expr() { // If there is no match in the input projection, we cannot unify these // projections. This case will arise if the projection expression contains @@ -265,6 +290,13 @@ fn try_unifying_projections( .map(|e| Some(Arc::new(e) as _)) } +/// Checks if the given expression is trivial. +/// An expression is considered trivial if it is either a `Column` or a `Literal`. +fn is_expr_trivial(expr: &Arc<dyn PhysicalExpr>) -> bool { + expr.as_any().downcast_ref::<Column>().is_some() + || expr.as_any().downcast_ref::<Literal>().is_some() +} + /// Tries to swap `projection` with its input (`output_req`). If possible, /// performs the swap and returns [`OutputRequirementExec`] as the top plan. /// Otherwise, returns `None`.
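The new guard in `try_unifying_projections` is essentially a per-column reference count: when the outer projection refers to a non-trivial child expression more than once, the two projections are left unmerged so the child keeps acting as a cache. A toy model of that rule (simplified expression type, not DataFusion's actual API):

use std::collections::HashMap;

// Simplified stand-in for a projection expression: either a plain column
// reference into the child projection's output (trivial) or a computation.
enum Expr {
    Column(usize),
    Computed(String),
}

fn is_trivial(e: &Expr) -> bool {
    matches!(e, Expr::Column(_))
}

// Decide whether an outer projection should be merged into its child.
// Merging is skipped when a non-trivial child expression would be
// re-evaluated, i.e. it is referenced more than once by the outer projection.
fn should_merge(outer: &[Expr], child: &[Expr]) -> bool {
    let mut counts: HashMap<usize, usize> = HashMap::new();
    for e in outer {
        if let Expr::Column(idx) = e {
            *counts.entry(*idx).or_default() += 1;
        }
    }
    !counts
        .iter()
        .any(|(idx, count)| *count > 1 && !is_trivial(&child[*idx]))
}

fn main() {
    // The child computes one expensive expression; the outer projection uses it twice.
    let child = vec![Expr::Computed("a + b".to_string())];
    let outer = vec![Expr::Column(0), Expr::Column(0)];
    assert!(!should_merge(&outer, &child)); // keep the child projection as a cache
}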
diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 65a2e4e0a4f3..38532002a634 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2012,7 +2012,7 @@ impl DefaultPhysicalPlanner { let mut column_names = StringBuilder::new(); let mut data_types = StringBuilder::new(); let mut is_nullables = StringBuilder::new(); - for (_, field) in table_schema.fields().iter().enumerate() { + for field in table_schema.fields() { column_names.append_value(field.name()); // "System supplied type" --> Use debug format of the datatype diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 10f4574020bf..c6b8e0e01b4f 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -1323,6 +1323,91 @@ async fn unnest_array_agg() -> Result<()> { Ok(()) } +#[tokio::test] +async fn unnest_with_redundant_columns() -> Result<()> { + let mut shape_id_builder = UInt32Builder::new(); + let mut tag_id_builder = UInt32Builder::new(); + + for shape_id in 1..=3 { + for tag_id in 1..=3 { + shape_id_builder.append_value(shape_id as u32); + tag_id_builder.append_value((shape_id * 10 + tag_id) as u32); + } + } + + let batch = RecordBatch::try_from_iter(vec![ + ("shape_id", Arc::new(shape_id_builder.finish()) as ArrayRef), + ("tag_id", Arc::new(tag_id_builder.finish()) as ArrayRef), + ])?; + + let ctx = SessionContext::new(); + ctx.register_batch("shapes", batch)?; + let df = ctx.table("shapes").await?; + + let results = df.clone().collect().await?; + let expected = vec![ + "+----------+--------+", + "| shape_id | tag_id |", + "+----------+--------+", + "| 1 | 11 |", + "| 1 | 12 |", + "| 1 | 13 |", + "| 2 | 21 |", + "| 2 | 22 |", + "| 2 | 23 |", + "| 3 | 31 |", + "| 3 | 32 |", + "| 3 | 33 |", + "+----------+--------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + // Doing an `array_agg` by `shape_id` produces: + let df = df + .clone() + .aggregate( + vec![col("shape_id")], + vec![array_agg(col("shape_id")).alias("shape_id2")], + )? + .unnest_column("shape_id2")? 
+ .select(vec![col("shape_id")])?; + + let optimized_plan = df.clone().into_optimized_plan()?; + let expected = vec![ + "Projection: shapes.shape_id [shape_id:UInt32]", + " Unnest: shape_id2 [shape_id:UInt32, shape_id2:UInt32;N]", + " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]", + " TableScan: shapes projection=[shape_id] [shape_id:UInt32]", + ]; + + let formatted = optimized_plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let results = df.collect().await?; + let expected = [ + "+----------+", + "| shape_id |", + "+----------+", + "| 1 |", + "| 1 |", + "| 1 |", + "| 2 |", + "| 2 |", + "| 2 |", + "| 3 |", + "| 3 |", + "| 3 |", + "+----------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + async fn create_test_table(name: &str) -> Result { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), diff --git a/datafusion/core/tests/parquet/file_statistics.rs b/datafusion/core/tests/parquet/file_statistics.rs index 1ea154303d69..9f94a59a3e59 100644 --- a/datafusion/core/tests/parquet/file_statistics.rs +++ b/datafusion/core/tests/parquet/file_statistics.rs @@ -133,7 +133,7 @@ async fn list_files_with_session_level_cache() { assert_eq!(get_list_file_cache_size(&state1), 1); let fg = &parquet1.base_config().file_groups; assert_eq!(fg.len(), 1); - assert_eq!(fg.get(0).unwrap().len(), 1); + assert_eq!(fg.first().unwrap().len(), 1); //Session 2 first time list files //check session 1 cache result not show in session 2 @@ -144,7 +144,7 @@ async fn list_files_with_session_level_cache() { assert_eq!(get_list_file_cache_size(&state2), 1); let fg2 = &parquet2.base_config().file_groups; assert_eq!(fg2.len(), 1); - assert_eq!(fg2.get(0).unwrap().len(), 1); + assert_eq!(fg2.first().unwrap().len(), 1); //Session 1 second time list files //check session 1 cache result not show in session 2 @@ -155,7 +155,7 @@ async fn list_files_with_session_level_cache() { assert_eq!(get_list_file_cache_size(&state1), 1); let fg = &parquet3.base_config().file_groups; assert_eq!(fg.len(), 1); - assert_eq!(fg.get(0).unwrap().len(), 1); + assert_eq!(fg.first().unwrap().len(), 1); // List same file no increase assert_eq!(get_list_file_cache_size(&state1), 1); } diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 47de6ec857da..94fc8015a78a 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-use std::convert::TryFrom; use std::sync::Arc; use arrow::{ diff --git a/datafusion/core/tests/sql/parquet.rs b/datafusion/core/tests/sql/parquet.rs index c2844a2b762a..8f810a929df3 100644 --- a/datafusion/core/tests/sql/parquet.rs +++ b/datafusion/core/tests/sql/parquet.rs @@ -263,7 +263,7 @@ async fn parquet_list_columns() { assert_eq!( as_string_array(&utf8_list_array.value(0)).unwrap(), - &StringArray::try_from(vec![Some("abc"), Some("efg"), Some("hij"),]).unwrap() + &StringArray::from(vec![Some("abc"), Some("efg"), Some("hij"),]) ); assert_eq!( diff --git a/datafusion/execution/src/cache/cache_unit.rs b/datafusion/execution/src/cache/cache_unit.rs index c54839061c8a..25f9b9fa4d68 100644 --- a/datafusion/execution/src/cache/cache_unit.rs +++ b/datafusion/execution/src/cache/cache_unit.rs @@ -228,7 +228,7 @@ mod tests { cache.put(&meta.location, vec![meta.clone()].into()); assert_eq!( - cache.get(&meta.location).unwrap().get(0).unwrap().clone(), + cache.get(&meta.location).unwrap().first().unwrap().clone(), meta.clone() ); } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index c4ff9fe95435..be2c45b901fa 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -50,9 +50,9 @@ use crate::{ use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::{ - plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, - DataFusionError, FileType, OwnedTableReference, Result, ScalarValue, TableReference, - ToDFSchema, UnnestOptions, + get_target_functional_dependencies, plan_datafusion_err, plan_err, Column, DFField, + DFSchema, DFSchemaRef, DataFusionError, FileType, OwnedTableReference, Result, + ScalarValue, TableReference, ToDFSchema, UnnestOptions, }; /// Default table name for unnamed table @@ -904,8 +904,27 @@ impl LogicalPlanBuilder { group_expr: impl IntoIterator>, aggr_expr: impl IntoIterator>, ) -> Result { - let group_expr = normalize_cols(group_expr, &self.plan)?; + let mut group_expr = normalize_cols(group_expr, &self.plan)?; let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; + + // Rewrite groupby exprs according to functional dependencies + let group_by_expr_names = group_expr + .iter() + .map(|group_by_expr| group_by_expr.display_name()) + .collect::>>()?; + let schema = self.plan.schema(); + if let Some(target_indices) = + get_target_functional_dependencies(schema, &group_by_expr_names) + { + for idx in target_indices { + let field = schema.field(idx); + let expr = + Expr::Column(Column::new(field.qualifier().cloned(), field.name())); + if !group_expr.contains(&expr) { + group_expr.push(expr); + } + } + } Aggregate::try_new(Arc::new(self.plan), group_expr, aggr_expr) .map(LogicalPlan::Aggregate) .map(Self::from) @@ -1166,8 +1185,8 @@ pub fn build_join_schema( ); let mut metadata = left.metadata().clone(); metadata.extend(right.metadata().clone()); - DFSchema::new_with_metadata(fields, metadata) - .map(|schema| schema.with_functional_dependencies(func_dependencies)) + let schema = DFSchema::new_with_metadata(fields, metadata)?; + schema.with_functional_dependencies(func_dependencies) } /// Errors if one or more expressions have equal names. 
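The `aggregate` builder change above extends the GROUP BY list with columns that are functionally determined by the existing grouping expressions, using `get_target_functional_dependencies` against the input schema. The following is a rough, name-level sketch under assumed helper names (`Dependency`, `expand_group_by`); the real implementation works with schema field indices and `FunctionalDependencies` derived from table constraints rather than column-name sets.

```rust
use std::collections::BTreeSet;

/// A toy functional dependency: if all `determinants` appear in the GROUP BY,
/// then every column in `dependents` is uniquely determined per group.
struct Dependency {
    determinants: BTreeSet<String>,
    dependents: Vec<String>,
}

/// Expand a GROUP BY list with columns functionally determined by it,
/// mirroring the idea behind the builder change above.
fn expand_group_by(group_by: &[&str], deps: &[Dependency]) -> Vec<String> {
    let mut expanded: Vec<String> = group_by.iter().map(|s| s.to_string()).collect();
    let present: BTreeSet<String> = expanded.iter().cloned().collect();
    for dep in deps {
        if dep.determinants.is_subset(&present) {
            for col in &dep.dependents {
                // Only add columns that are not already grouped on.
                if !expanded.contains(col) {
                    expanded.push(col.clone());
                }
            }
        }
    }
    expanded
}

fn main() {
    // `id` acts as a key, so it determines `name` and `city`.
    let deps = vec![Dependency {
        determinants: ["id".to_string()].into_iter().collect(),
        dependents: vec!["name".to_string(), "city".to_string()],
    }];

    // GROUP BY id  ->  GROUP BY id, name, city
    assert_eq!(expand_group_by(&["id"], &deps), vec!["id", "name", "city"]);

    // Grouping on a non-key column adds nothing.
    assert_eq!(expand_group_by(&["city"], &deps), vec!["city"]);
}
```

The practical effect suggested by this change is that a query such as `SELECT id, name FROM t GROUP BY id` can plan when `id` is known to be unique: `name` is pulled into the GROUP BY automatically instead of being rejected as a non-aggregated column.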
@@ -1491,7 +1510,7 @@ pub fn unnest_with_options( let df_schema = DFSchema::new_with_metadata(fields, metadata)?; // We can use the existing functional dependencies: let deps = input_schema.functional_dependencies().clone(); - let schema = Arc::new(df_schema.with_functional_dependencies(deps)); + let schema = Arc::new(df_schema.with_functional_dependencies(deps)?); Ok(LogicalPlan::Unnest(Unnest { input: Arc::new(input), diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index d85e0b5b0a40..dfd4fbf65d8e 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -946,7 +946,7 @@ impl LogicalPlan { // We can use the existing functional dependencies as is: .with_functional_dependencies( input.schema().functional_dependencies().clone(), - ), + )?, ); Ok(LogicalPlan::Unnest(Unnest { @@ -1834,8 +1834,9 @@ pub fn projection_schema(input: &LogicalPlan, exprs: &[Expr]) -> Result Schema { Schema::new(vec![ @@ -3164,15 +3165,20 @@ digraph { ) .unwrap(); assert!(!filter.is_scalar()); - let unique_schema = - Arc::new(schema.as_ref().clone().with_functional_dependencies( - FunctionalDependencies::new_from_constraints( - Some(&Constraints::new_unverified(vec![Constraint::Unique( - vec![0], - )])), - 1, - ), - )); + let unique_schema = Arc::new( + schema + .as_ref() + .clone() + .with_functional_dependencies( + FunctionalDependencies::new_from_constraints( + Some(&Constraints::new_unverified(vec![Constraint::Unique( + vec![0], + )])), + 1, + ), + ) + .unwrap(), + ); let scan = Arc::new(LogicalPlan::TableScan(TableScan { table_name: TableReference::bare("tab"), source, diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 1027e97d061a..dd9449198796 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -116,7 +116,7 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result }) } AtArrow | ArrowAt => { - // ArrowAt and AtArrow check for whether one array ic contained in another. + // ArrowAt and AtArrow check for whether one array is contained in another. // The result type is boolean. Signature::comparison defines this signature. // Operation has nothing to do with comparison array_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 7d126a0f3373..abdd7f5f57f6 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -17,6 +17,10 @@ //! 
Expression utilities +use std::cmp::Ordering; +use std::collections::HashSet; +use std::sync::Arc; + use crate::expr::{Alias, Sort, WindowFunction}; use crate::expr_rewriter::strip_outer_reference; use crate::logical_plan::Aggregate; @@ -25,16 +29,15 @@ use crate::{ and, BinaryExpr, Cast, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, TryCast, }; + use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::{ internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, }; + use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem, WildcardAdditionalOptions}; -use std::cmp::Ordering; -use std::collections::HashSet; -use std::sync::Arc; /// The value to which `COUNT(*)` is expanded to in /// `COUNT()` expressions @@ -433,7 +436,7 @@ pub fn expand_qualified_wildcard( let qualified_schema = DFSchema::new_with_metadata(qualified_fields, schema.metadata().clone())? // We can use the functional dependencies as is, since it only stores indices: - .with_functional_dependencies(schema.functional_dependencies().clone()); + .with_functional_dependencies(schema.functional_dependencies().clone())?; let excluded_columns = if let Some(WildcardAdditionalOptions { opt_exclude, opt_except, @@ -501,7 +504,6 @@ pub fn generate_sort_key( let res = final_sort_keys .into_iter() .zip(is_partition_flag) - .map(|(lhs, rhs)| (lhs, rhs)) .collect::>(); Ok(res) } @@ -731,11 +733,7 @@ fn agg_cols(agg: &Aggregate) -> Vec { .collect() } -fn exprlist_to_fields_aggregate( - exprs: &[Expr], - plan: &LogicalPlan, - agg: &Aggregate, -) -> Result> { +fn exprlist_to_fields_aggregate(exprs: &[Expr], agg: &Aggregate) -> Result> { let agg_cols = agg_cols(agg); let mut fields = vec![]; for expr in exprs { @@ -744,7 +742,7 @@ fn exprlist_to_fields_aggregate( // resolve against schema of input to aggregate fields.push(expr.to_field(agg.input.schema())?); } - _ => fields.push(expr.to_field(plan.schema())?), + _ => fields.push(expr.to_field(&agg.schema)?), } } Ok(fields) @@ -761,15 +759,7 @@ pub fn exprlist_to_fields<'a>( // `GROUPING(person.state)` so in order to resolve `person.state` in this case we need to // look at the input to the aggregate instead. let fields = match plan { - LogicalPlan::Aggregate(agg) => { - Some(exprlist_to_fields_aggregate(&exprs, plan, agg)) - } - LogicalPlan::Window(window) => match window.input.as_ref() { - LogicalPlan::Aggregate(agg) => { - Some(exprlist_to_fields_aggregate(&exprs, plan, agg)) - } - _ => None, - }, + LogicalPlan::Aggregate(agg) => Some(exprlist_to_fields_aggregate(&exprs, agg)), _ => None, }; if let Some(fields) = fields { @@ -1241,10 +1231,9 @@ pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { #[cfg(test)] mod tests { use super::*; - use crate::expr_vec_fmt; use crate::{ - col, cube, expr, grouping_set, lit, rollup, AggregateFunction, WindowFrame, - WindowFunction, + col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, AggregateFunction, + WindowFrame, WindowFunction, }; #[test] diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 2a64f21b856b..2701ca1ecf3b 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -23,6 +23,8 @@ //! - An ending frame boundary, //! - An EXCLUDE clause. 
+use crate::expr::Sort; +use crate::Expr; use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue}; use sqlparser::ast; use sqlparser::parser::ParserError::ParserError; @@ -142,41 +144,57 @@ impl WindowFrame { } } -/// Construct equivalent explicit window frames for implicit corner cases. -/// With this processing, we may assume in downstream code that RANGE/GROUPS -/// frames contain an appropriate ORDER BY clause. -pub fn regularize(mut frame: WindowFrame, order_bys: usize) -> Result { - if frame.units == WindowFrameUnits::Range && order_bys != 1 { +/// Regularizes ORDER BY clause for window definition for implicit corner cases. +pub fn regularize_window_order_by( + frame: &WindowFrame, + order_by: &mut Vec, +) -> Result<()> { + if frame.units == WindowFrameUnits::Range && order_by.len() != 1 { // Normally, RANGE frames require an ORDER BY clause with exactly one // column. However, an ORDER BY clause may be absent or present but with // more than one column in two edge cases: // 1. start bound is UNBOUNDED or CURRENT ROW // 2. end bound is CURRENT ROW or UNBOUNDED. - // In these cases, we regularize the RANGE frame to be equivalent to a ROWS - // frame with the UNBOUNDED bounds. - // Note that this follows Postgres behavior. + // In these cases, we regularize the ORDER BY clause if the ORDER BY clause + // is absent. If an ORDER BY clause is present but has more than one column, + // the ORDER BY clause is unchanged. Note that this follows Postgres behavior. if (frame.start_bound.is_unbounded() || frame.start_bound == WindowFrameBound::CurrentRow) && (frame.end_bound == WindowFrameBound::CurrentRow || frame.end_bound.is_unbounded()) { - // If an ORDER BY clause is absent, the frame is equivalent to a ROWS - // frame with the UNBOUNDED bounds. - // If an ORDER BY clause is present but has more than one column, the - // frame is unchanged. - if order_bys == 0 { - frame.units = WindowFrameUnits::Rows; - frame.start_bound = - WindowFrameBound::Preceding(ScalarValue::UInt64(None)); - frame.end_bound = WindowFrameBound::Following(ScalarValue::UInt64(None)); + // If an ORDER BY clause is absent, it is equivalent to a ORDER BY clause + // with constant value as sort key. + // If an ORDER BY clause is present but has more than one column, it is + // unchanged. + if order_by.is_empty() { + order_by.push(Expr::Sort(Sort::new( + Box::new(Expr::Literal(ScalarValue::UInt64(Some(1)))), + true, + false, + ))); } - } else { + } + } + Ok(()) +} + +/// Checks if given window frame is valid. In particular, if the frame is RANGE +/// with offset PRECEDING/FOLLOWING, it must have exactly one ORDER BY column. +pub fn check_window_frame(frame: &WindowFrame, order_bys: usize) -> Result<()> { + if frame.units == WindowFrameUnits::Range && order_bys != 1 { + // See `regularize_window_order_by`. + if !(frame.start_bound.is_unbounded() + || frame.start_bound == WindowFrameBound::CurrentRow) + || !(frame.end_bound == WindowFrameBound::CurrentRow + || frame.end_bound.is_unbounded()) + { plan_err!("RANGE requires exactly one ORDER BY column")? } } else if frame.units == WindowFrameUnits::Groups && order_bys == 0 { plan_err!("GROUPS requires an ORDER BY clause")? 
}; - Ok(frame) + Ok(()) } /// There are five ways to describe starting and ending frame boundaries: diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index e3b86f5db78f..91611251d9dd 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -503,7 +503,10 @@ fn coerce_window_frame( let target_type = match window_frame.units { WindowFrameUnits::Range => { if let Some(col_type) = current_types.first() { - if col_type.is_numeric() || is_utf8_or_large_utf8(col_type) { + if col_type.is_numeric() + || is_utf8_or_large_utf8(col_type) + || matches!(col_type, DataType::Null) + { col_type } else if is_datetime(col_type) { &DataType::Interval(IntervalUnit::MonthDayNano) diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 7844ca7909fc..4386253740aa 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -97,7 +97,7 @@ mod tests { let optimizer = Optimizer::with_rules(vec![Arc::new(EliminateLimit::new())]); let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index bbf704a83c55..7ae9f7edf5e5 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -15,33 +15,42 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to prune unnecessary Columns from the intermediate schemas inside the [LogicalPlan]. -//! This rule -//! - Removes unnecessary columns that are not showed at the output, and that are not used during computation. -//! - Adds projection to decrease table column size before operators that benefits from less memory at its input. -//! - Removes unnecessary [LogicalPlan::Projection] from the [LogicalPlan]. +//! Optimizer rule to prune unnecessary columns from intermediate schemas +//! inside the [`LogicalPlan`]. This rule: +//! - Removes unnecessary columns that do not appear at the output and/or are +//! not used during any computation step. +//! - Adds projections to decrease table column size before operators that +//! benefit from a smaller memory footprint at its input. +//! - Removes unnecessary [`LogicalPlan::Projection`]s from the [`LogicalPlan`]. + +use std::collections::HashSet; +use std::sync::Arc; + use crate::optimizer::ApplyOrder; -use datafusion_common::{Column, DFSchema, DFSchemaRef, JoinType, Result}; -use datafusion_expr::expr::{Alias, ScalarFunction}; +use crate::{OptimizerConfig, OptimizerRule}; + +use arrow::datatypes::SchemaRef; +use datafusion_common::{ + get_required_group_by_exprs_indices, Column, DFSchema, DFSchemaRef, JoinType, Result, +}; +use datafusion_expr::expr::{Alias, ScalarFunction, ScalarFunctionDefinition}; use datafusion_expr::{ logical_plan::LogicalPlan, projection_schema, Aggregate, BinaryExpr, Cast, Distinct, - Expr, Projection, ScalarFunctionDefinition, TableScan, Window, + Expr, GroupingSet, Projection, TableScan, Window, }; + use hashbrown::HashMap; use itertools::{izip, Itertools}; -use std::collections::HashSet; -use std::sync::Arc; - -use crate::{OptimizerConfig, OptimizerRule}; -/// A rule for optimizing logical plans by removing unused Columns/Fields. 
+/// A rule for optimizing logical plans by removing unused columns/fields. /// -/// `OptimizeProjections` is an optimizer rule that identifies and eliminates columns from a logical plan -/// that are not used in any downstream operations. This can improve query performance and reduce unnecessary -/// data processing. +/// `OptimizeProjections` is an optimizer rule that identifies and eliminates +/// columns from a logical plan that are not used by downstream operations. +/// This can improve query performance and reduce unnecessary data processing. /// -/// The rule analyzes the input logical plan, determines the necessary column indices, and then removes any -/// unnecessary columns. Additionally, it eliminates any unnecessary projections in the plan. +/// The rule analyzes the input logical plan, determines the necessary column +/// indices, and then removes any unnecessary columns. It also removes any +/// unnecessary projections from the plan tree. #[derive(Default)] pub struct OptimizeProjections {} @@ -58,8 +67,8 @@ impl OptimizerRule for OptimizeProjections { plan: &LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - // All of the fields at the output are necessary. - let indices = require_all_indices(plan); + // All output fields are necessary: + let indices = (0..plan.schema().fields().len()).collect::>(); optimize_projections(plan, config, &indices) } @@ -72,30 +81,35 @@ impl OptimizerRule for OptimizeProjections { } } -/// Removes unnecessary columns (e.g Columns that are not referred at the output schema and -/// Columns that are not used during any computation, expression evaluation) from the logical plan and its inputs. +/// Removes unnecessary columns (e.g. columns that do not appear in the output +/// schema and/or are not used during any computation step such as expression +/// evaluation) from the logical plan and its inputs. /// -/// # Arguments +/// # Parameters /// -/// - `plan`: A reference to the input `LogicalPlan` to be optimized. -/// - `_config`: A reference to the optimizer configuration (not currently used). -/// - `indices`: A slice of column indices that represent the necessary column indices for downstream operations. +/// - `plan`: A reference to the input `LogicalPlan` to optimize. +/// - `config`: A reference to the optimizer configuration. +/// - `indices`: A slice of column indices that represent the necessary column +/// indices for downstream operations. /// /// # Returns /// -/// - `Ok(Some(LogicalPlan))`: An optimized `LogicalPlan` with unnecessary columns removed. -/// - `Ok(None)`: If the optimization process results in a logical plan that doesn't require further propagation. -/// - `Err(error)`: If an error occurs during the optimization process. +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(LogicalPlan))`: An optimized `LogicalPlan` without unnecessary +/// columns. +/// - `Ok(None)`: Signal that the given logical plan did not require any change. +/// - `Err(error)`: An error occured during the optimization process. fn optimize_projections( plan: &LogicalPlan, - _config: &dyn OptimizerConfig, + config: &dyn OptimizerConfig, indices: &[usize], ) -> Result> { // `child_required_indices` stores // - indices of the columns required for each child // - a flag indicating whether putting a projection above children is beneficial for the parent. // As an example LogicalPlan::Filter benefits from small tables. Hence for filter child this flag would be `true`. 
- let child_required_indices: Option, bool)>> = match plan { + let child_required_indices: Vec<(Vec, bool)> = match plan { LogicalPlan::Sort(_) | LogicalPlan::Filter(_) | LogicalPlan::Repartition(_) @@ -103,36 +117,32 @@ fn optimize_projections( | LogicalPlan::Union(_) | LogicalPlan::SubqueryAlias(_) | LogicalPlan::Distinct(Distinct::On(_)) => { - // Re-route required indices from the parent + column indices referred by expressions in the plan - // to the child. - // All of these operators benefits from small tables at their inputs. Hence projection_beneficial flag is `true`. + // Pass index requirements from the parent as well as column indices + // that appear in this plan's expressions to its child. All these + // operators benefit from "small" inputs, so the projection_beneficial + // flag is `true`. let exprs = plan.expressions(); - let child_req_indices = plan - .inputs() + plan.inputs() .into_iter() .map(|input| { - let required_indices = - get_all_required_indices(indices, input, exprs.iter())?; - Ok((required_indices, true)) + get_all_required_indices(indices, input, exprs.iter()) + .map(|idxs| (idxs, true)) }) - .collect::>>()?; - Some(child_req_indices) + .collect::>()? } LogicalPlan::Limit(_) | LogicalPlan::Prepare(_) => { - // Re-route required indices from the parent + column indices referred by expressions in the plan - // to the child. - // Limit, Prepare doesn't benefit from small column numbers. Hence projection_beneficial flag is `false`. + // Pass index requirements from the parent as well as column indices + // that appear in this plan's expressions to its child. These operators + // do not benefit from "small" inputs, so the projection_beneficial + // flag is `false`. let exprs = plan.expressions(); - let child_req_indices = plan - .inputs() + plan.inputs() .into_iter() .map(|input| { - let required_indices = - get_all_required_indices(indices, input, exprs.iter())?; - Ok((required_indices, false)) + get_all_required_indices(indices, input, exprs.iter()) + .map(|idxs| (idxs, false)) }) - .collect::>>()?; - Some(child_req_indices) + .collect::>()? } LogicalPlan::Copy(_) | LogicalPlan::Ddl(_) @@ -141,79 +151,108 @@ fn optimize_projections( | LogicalPlan::Analyze(_) | LogicalPlan::Subquery(_) | LogicalPlan::Distinct(Distinct::All(_)) => { - // Require all of the fields of the Dml, Ddl, Copy, Explain, Analyze, Subquery, Distinct::All input(s). - // Their child plan can be treated as final plan. Otherwise expected schema may not match. - // TODO: For some subquery variants we may not need to require all indices for its input. - // such as Exists. - let child_requirements = plan - .inputs() + // These plans require all their fields, and their children should + // be treated as final plans -- otherwise, we may have schema a + // mismatch. + // TODO: For some subquery variants (e.g. a subquery arising from an + // EXISTS expression), we may not need to require all indices. + plan.inputs() .iter() - .map(|input| { - // Require all of the fields for each input. 
- // No projection since all of the fields at the child is required - (require_all_indices(input), false) - }) - .collect::>(); - Some(child_requirements) + .map(|input| ((0..input.schema().fields().len()).collect_vec(), false)) + .collect::>() } LogicalPlan::EmptyRelation(_) | LogicalPlan::Statement(_) | LogicalPlan::Values(_) | LogicalPlan::Extension(_) | LogicalPlan::DescribeTable(_) => { - // EmptyRelation, Values, DescribeTable, Statement has no inputs stop iteration - - // TODO: Add support for extension - // It is not known how to direct requirements to children for LogicalPlan::Extension. - // Safest behaviour is to stop propagation. - None + // These operators have no inputs, so stop the optimization process. + // TODO: Add support for `LogicalPlan::Extension`. + return Ok(None); } LogicalPlan::Projection(proj) => { return if let Some(proj) = merge_consecutive_projections(proj)? { - rewrite_projection_given_requirements(&proj, _config, indices)? - .map(|res| Ok(Some(res))) - // Even if projection cannot be optimized, return merged version - .unwrap_or_else(|| Ok(Some(LogicalPlan::Projection(proj)))) + Ok(Some( + rewrite_projection_given_requirements(&proj, config, indices)? + // Even if we cannot optimize the projection, merge if possible: + .unwrap_or_else(|| LogicalPlan::Projection(proj)), + )) } else { - rewrite_projection_given_requirements(proj, _config, indices) + rewrite_projection_given_requirements(proj, config, indices) }; } LogicalPlan::Aggregate(aggregate) => { - // Split parent requirements to group by and aggregate sections - let group_expr_len = aggregate.group_expr_len()?; - let (_group_by_reqs, mut aggregate_reqs): (Vec, Vec) = - indices.iter().partition(|&&idx| idx < group_expr_len); - // Offset aggregate indices so that they point to valid indices at the `aggregate.aggr_expr` - aggregate_reqs - .iter_mut() - .for_each(|idx| *idx -= group_expr_len); - - // Group by expressions are same - let new_group_bys = aggregate.group_expr.clone(); - - // Only use absolutely necessary aggregate expressions required by parent. - let new_aggr_expr = get_at_indices(&aggregate.aggr_expr, &aggregate_reqs); + // Split parent requirements to GROUP BY and aggregate sections: + let n_group_exprs = aggregate.group_expr_len()?; + let (group_by_reqs, mut aggregate_reqs): (Vec, Vec) = + indices.iter().partition(|&&idx| idx < n_group_exprs); + // Offset aggregate indices so that they point to valid indices at + // `aggregate.aggr_expr`: + for idx in aggregate_reqs.iter_mut() { + *idx -= n_group_exprs; + } + + // Get absolutely necessary GROUP BY fields: + let group_by_expr_existing = aggregate + .group_expr + .iter() + .map(|group_by_expr| group_by_expr.display_name()) + .collect::>>()?; + let new_group_bys = if let Some(simplest_groupby_indices) = + get_required_group_by_exprs_indices( + aggregate.input.schema(), + &group_by_expr_existing, + ) { + // Some of the fields in the GROUP BY may be required by the + // parent even if these fields are unnecessary in terms of + // functional dependency. 
+ let required_indices = + merge_slices(&simplest_groupby_indices, &group_by_reqs); + get_at_indices(&aggregate.group_expr, &required_indices) + } else { + aggregate.group_expr.clone() + }; + + // Only use the absolutely necessary aggregate expressions required + // by the parent: + let mut new_aggr_expr = get_at_indices(&aggregate.aggr_expr, &aggregate_reqs); let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter()); - let necessary_indices = - indices_referred_by_exprs(&aggregate.input, all_exprs_iter)?; + let schema = aggregate.input.schema(); + let necessary_indices = indices_referred_by_exprs(schema, all_exprs_iter)?; let aggregate_input = if let Some(input) = - optimize_projections(&aggregate.input, _config, &necessary_indices)? + optimize_projections(&aggregate.input, config, &necessary_indices)? { input } else { aggregate.input.as_ref().clone() }; - // Simplify input of the aggregation by adding a projection so that its input only contains - // absolutely necessary columns for the aggregate expressions. Please no that we use aggregate.input.schema() - // because necessary_indices refers to fields in this schema. - let necessary_exprs = - get_required_exprs(aggregate.input.schema(), &necessary_indices); - let (aggregate_input, _is_added) = - add_projection_on_top_if_helpful(aggregate_input, necessary_exprs, true)?; + // Simplify the input of the aggregation by adding a projection so + // that its input only contains absolutely necessary columns for + // the aggregate expressions. Note that necessary_indices refer to + // fields in `aggregate.input.schema()`. + let necessary_exprs = get_required_exprs(schema, &necessary_indices); + let (aggregate_input, _) = + add_projection_on_top_if_helpful(aggregate_input, necessary_exprs)?; + + // Aggregations always need at least one aggregate expression. + // With a nested count, we don't require any column as input, but + // still need to create a correct aggregate, which may be optimized + // out later. As an example, consider the following query: + // + // SELECT COUNT(*) FROM (SELECT COUNT(*) FROM [...]) + // + // which always returns 1. + if new_aggr_expr.is_empty() + && new_group_bys.is_empty() + && !aggregate.aggr_expr.is_empty() + { + new_aggr_expr = vec![aggregate.aggr_expr[0].clone()]; + } - // Create new aggregate plan with updated input, and absolutely necessary fields. + // Create a new aggregate plan with the updated input and only the + // absolutely necessary fields: return Aggregate::try_new( Arc::new(aggregate_input), new_group_bys, @@ -222,43 +261,48 @@ fn optimize_projections( .map(|aggregate| Some(LogicalPlan::Aggregate(aggregate))); } LogicalPlan::Window(window) => { - // Split parent requirements to child and window expression sections. + // Split parent requirements to child and window expression sections: let n_input_fields = window.input.schema().fields().len(); let (child_reqs, mut window_reqs): (Vec, Vec) = indices.iter().partition(|&&idx| idx < n_input_fields); - // Offset window expr indices so that they point to valid indices at the `window.window_expr` - window_reqs - .iter_mut() - .for_each(|idx| *idx -= n_input_fields); + // Offset window expression indices so that they point to valid + // indices at `window.window_expr`: + for idx in window_reqs.iter_mut() { + *idx -= n_input_fields; + } - // Only use window expressions that are absolutely necessary by parent requirements. 
+ // Only use window expressions that are absolutely necessary according + // to parent requirements: let new_window_expr = get_at_indices(&window.window_expr, &window_reqs); - // All of the required column indices at the input of the window by parent, and window expression requirements. + // Get all the required column indices at the input, either by the + // parent or window expression requirements. let required_indices = get_all_required_indices( &child_reqs, &window.input, new_window_expr.iter(), )?; let window_child = if let Some(new_window_child) = - optimize_projections(&window.input, _config, &required_indices)? + optimize_projections(&window.input, config, &required_indices)? { new_window_child } else { window.input.as_ref().clone() }; - // When no window expression is necessary, just use window input. (Remove window operator) + return if new_window_expr.is_empty() { + // When no window expression is necessary, use the input directly: Ok(Some(window_child)) } else { // Calculate required expressions at the input of the window. - // Please note that we use `old_child`, because `required_indices` refers to `old_child`. + // Please note that we use `old_child`, because `required_indices` + // refers to `old_child`. let required_exprs = get_required_exprs(window.input.schema(), &required_indices); - let (window_child, _is_added) = - add_projection_on_top_if_helpful(window_child, required_exprs, true)?; - let window = Window::try_new(new_window_expr, Arc::new(window_child))?; - Ok(Some(LogicalPlan::Window(window))) + let (window_child, _) = + add_projection_on_top_if_helpful(window_child, required_exprs)?; + Window::try_new(new_window_expr, Arc::new(window_child)) + .map(|window| Some(LogicalPlan::Window(window))) }; } LogicalPlan::Join(join) => { @@ -270,323 +314,402 @@ fn optimize_projections( get_all_required_indices(&left_req_indices, &join.left, exprs.iter())?; let right_indices = get_all_required_indices(&right_req_indices, &join.right, exprs.iter())?; - // Join benefits from small columns numbers at its input (decreases memory usage) - // Hence each child benefits from projection. - Some(vec![(left_indices, true), (right_indices, true)]) + // Joins benefit from "small" input tables (lower memory usage). + // Therefore, each child benefits from projection: + vec![(left_indices, true), (right_indices, true)] } LogicalPlan::CrossJoin(cross_join) => { let left_len = cross_join.left.schema().fields().len(); let (left_child_indices, right_child_indices) = split_join_requirements(left_len, indices, &JoinType::Inner); - // Join benefits from small columns numbers at its input (decreases memory usage) - // Hence each child benefits from projection. - Some(vec![ - (left_child_indices, true), - (right_child_indices, true), - ]) + // Joins benefit from "small" input tables (lower memory usage). + // Therefore, each child benefits from projection: + vec![(left_child_indices, true), (right_child_indices, true)] } LogicalPlan::TableScan(table_scan) => { - let projection_fields = table_scan.projected_schema.fields(); let schema = table_scan.source.schema(); - // We expect to find all of the required indices of the projected schema fields. - // among original schema. If at least one of them cannot be found. Use all of the fields in the file. 
- // (No projection at the source) - let projection = indices - .iter() - .map(|&idx| { - schema.fields().iter().position(|field_source| { - projection_fields[idx].field() == field_source - }) - }) - .collect::>>(); + // Get indices referred to in the original (schema with all fields) + // given projected indices. + let projection = with_indices(&table_scan.projection, schema, |map| { + indices.iter().map(|&idx| map[idx]).collect() + }); - return Ok(Some(LogicalPlan::TableScan(TableScan::try_new( + return TableScan::try_new( table_scan.table_name.clone(), table_scan.source.clone(), - projection, + Some(projection), table_scan.filters.clone(), table_scan.fetch, - )?))); + ) + .map(|table| Some(LogicalPlan::TableScan(table))); } }; - let child_required_indices = - if let Some(child_required_indices) = child_required_indices { - child_required_indices - } else { - // Stop iteration, cannot propagate requirement down below this operator. - return Ok(None); - }; - let new_inputs = izip!(child_required_indices, plan.inputs().into_iter()) .map(|((required_indices, projection_beneficial), child)| { - let (input, mut is_changed) = if let Some(new_input) = - optimize_projections(child, _config, &required_indices)? + let (input, is_changed) = if let Some(new_input) = + optimize_projections(child, config, &required_indices)? { (new_input, true) } else { (child.clone(), false) }; let project_exprs = get_required_exprs(child.schema(), &required_indices); - let (input, is_projection_added) = add_projection_on_top_if_helpful( - input, - project_exprs, - projection_beneficial, - )?; - is_changed |= is_projection_added; - Ok(is_changed.then_some(input)) + let (input, proj_added) = if projection_beneficial { + add_projection_on_top_if_helpful(input, project_exprs)? + } else { + (input, false) + }; + Ok((is_changed || proj_added).then_some(input)) }) - .collect::>>>()?; - // All of the children are same in this case, no need to change plan + .collect::>>()?; if new_inputs.iter().all(|child| child.is_none()) { + // All children are the same in this case, no need to change the plan: Ok(None) } else { - // At least one of the children is changed. + // At least one of the children is changed: let new_inputs = izip!(new_inputs, plan.inputs()) - // If new_input is `None`, this means child is not changed. Hence use `old_child` during construction. + // If new_input is `None`, this means child is not changed, so use + // `old_child` during construction: .map(|(new_input, old_child)| new_input.unwrap_or_else(|| old_child.clone())) .collect::>(); - let res = plan.with_new_inputs(&new_inputs)?; - Ok(Some(res)) + plan.with_new_inputs(&new_inputs).map(Some) } } -/// Merge Consecutive Projections +/// This function applies the given function `f` to the projection indices +/// `proj_indices` if they exist. Otherwise, applies `f` to a default set +/// of indices according to `schema`. +fn with_indices( + proj_indices: &Option>, + schema: SchemaRef, + mut f: F, +) -> Vec +where + F: FnMut(&[usize]) -> Vec, +{ + match proj_indices { + Some(indices) => f(indices.as_slice()), + None => { + let range: Vec = (0..schema.fields.len()).collect(); + f(range.as_slice()) + } + } +} + +/// Merges consecutive projections. /// /// Given a projection `proj`, this function attempts to merge it with a previous -/// projection if it exists and if the merging is beneficial. Merging is considered -/// beneficial when expressions in the current projection are non-trivial and referred to -/// more than once in its input fields. 
This can act as a caching mechanism for non-trivial -/// computations. +/// projection if it exists and if merging is beneficial. Merging is considered +/// beneficial when expressions in the current projection are non-trivial and +/// appear more than once in its input fields. This can act as a caching mechanism +/// for non-trivial computations. /// -/// # Arguments +/// # Parameters /// /// * `proj` - A reference to the `Projection` to be merged. /// /// # Returns /// -/// A `Result` containing an `Option` of the merged `Projection`. If merging is not beneficial -/// it returns `Ok(None)`. +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(Projection))`: Merge was beneficial and successful. Contains the +/// merged projection. +/// - `Ok(None)`: Signals that merge is not beneficial (and has not taken place). +/// - `Err(error)`: An error occured during the function call. fn merge_consecutive_projections(proj: &Projection) -> Result> { - let prev_projection = if let LogicalPlan::Projection(prev) = proj.input.as_ref() { - prev - } else { + let LogicalPlan::Projection(prev_projection) = proj.input.as_ref() else { return Ok(None); }; - // Count usages (referral counts) of each projection expression in its input fields - let column_referral_map: HashMap = proj - .expr - .iter() - .flat_map(|expr| expr.to_columns()) - .fold(HashMap::new(), |mut map, cols| { - cols.into_iter() - .for_each(|col| *map.entry(col).or_default() += 1); - map - }); - - // Merging these projections is not beneficial, e.g - // If an expression is not trivial and it is referred more than 1, consecutive projections will be - // beneficial as caching mechanism for non-trivial computations. - // See discussion in: https://github.com/apache/arrow-datafusion/issues/8296 - if column_referral_map.iter().any(|(col, usage)| { - *usage > 1 + // Count usages (referrals) of each projection expression in its input fields: + let mut column_referral_map = HashMap::::new(); + for columns in proj.expr.iter().flat_map(|expr| expr.to_columns()) { + for col in columns.into_iter() { + *column_referral_map.entry(col.clone()).or_default() += 1; + } + } + + // If an expression is non-trivial and appears more than once, consecutive + // projections will benefit from a compute-once approach. For details, see: + // https://github.com/apache/arrow-datafusion/issues/8296 + if column_referral_map.into_iter().any(|(col, usage)| { + usage > 1 && !is_expr_trivial( &prev_projection.expr - [prev_projection.schema.index_of_column(col).unwrap()], + [prev_projection.schema.index_of_column(&col).unwrap()], ) }) { return Ok(None); } - // If all of the expression of the top projection can be rewritten. Rewrite expressions and create a new projection + // If all the expression of the top projection can be rewritten, do so and + // create a new projection: let new_exprs = proj .expr .iter() .map(|expr| rewrite_expr(expr, prev_projection)) .collect::>>>()?; - new_exprs - .map(|exprs| Projection::try_new(exprs, prev_projection.input.clone())) - .transpose() + if let Some(new_exprs) = new_exprs { + let new_exprs = new_exprs + .into_iter() + .zip(proj.expr.iter()) + .map(|(new_expr, old_expr)| { + new_expr.alias_if_changed(old_expr.name_for_alias()?) + }) + .collect::>>()?; + Projection::try_new(new_exprs, prev_projection.input.clone()).map(Some) + } else { + Ok(None) + } } -/// Trim Expression -/// -/// Trim the given expression by removing any unnecessary layers of abstraction. 
+/// Trim the given expression by removing any unnecessary layers of aliasing. /// If the expression is an alias, the function returns the underlying expression. -/// Otherwise, it returns the original expression unchanged. -/// -/// # Arguments +/// Otherwise, it returns the given expression as is. /// -/// * `expr` - The input expression to be trimmed. +/// Without trimming, we can end up with unnecessary indirections inside expressions +/// during projection merges. /// -/// # Returns -/// -/// The trimmed expression. If the input is an alias, the underlying expression is returned. -/// -/// Without trimming, during projection merge we can end up unnecessary indirections inside the expressions. /// Consider: /// -/// Projection (a1 + b1 as sum1) -/// --Projection (a as a1, b as b1) -/// ----Source (a, b) +/// ```text +/// Projection(a1 + b1 as sum1) +/// --Projection(a as a1, b as b1) +/// ----Source(a, b) +/// ``` /// -/// After merge we want to produce +/// After merge, we want to produce: /// -/// Projection (a + b as sum1) +/// ```text +/// Projection(a + b as sum1) /// --Source(a, b) +/// ``` /// -/// Without trimming we would end up +/// Without trimming, we would end up with: /// -/// Projection (a as a1 + b as b1 as sum1) +/// ```text +/// Projection((a as a1 + b as b1) as sum1) /// --Source(a, b) +/// ``` fn trim_expr(expr: Expr) -> Expr { match expr { - Expr::Alias(alias) => *alias.expr, + Expr::Alias(alias) => trim_expr(*alias.expr), _ => expr, } } -// Check whether expression is trivial (e.g it doesn't include computation.) +// Check whether `expr` is trivial; i.e. it doesn't imply any computation. fn is_expr_trivial(expr: &Expr) -> bool { matches!(expr, Expr::Column(_) | Expr::Literal(_)) } -// Exit early when None is seen. +// Exit early when there is no rewrite to do. macro_rules! rewrite_expr_with_check { ($expr:expr, $input:expr) => { - if let Some(val) = rewrite_expr($expr, $input)? { - val + if let Some(value) = rewrite_expr($expr, $input)? { + value } else { return Ok(None); } }; } -// Rewrites expression using its input projection (Merges consecutive projection expressions). -/// Rewrites an projections expression using its input projection -/// (Helper during merging consecutive projection expressions). +/// Rewrites a projection expression using the projection before it (i.e. its input) +/// This is a subroutine to the `merge_consecutive_projections` function. /// -/// # Arguments +/// # Parameters /// -/// * `expr` - A reference to the expression to be rewritten. -/// * `input` - A reference to the input (itself a projection) of the projection expression. +/// * `expr` - A reference to the expression to rewrite. +/// * `input` - A reference to the input of the projection expression (itself +/// a projection). /// /// # Returns /// -/// A `Result` containing an `Option` of the rewritten expression. If the rewrite is successful, -/// it returns `Ok(Some)` with the modified expression. If the expression cannot be rewritten -/// it returns `Ok(None)`. +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(Expr))`: Rewrite was successful. Contains the rewritten result. +/// - `Ok(None)`: Signals that `expr` can not be rewritten. +/// - `Err(error)`: An error occured during the function call. 
fn rewrite_expr(expr: &Expr, input: &Projection) -> Result> { - Ok(match expr { + let result = match expr { Expr::Column(col) => { - // Find index of column + // Find index of column: let idx = input.schema.index_of_column(col)?; - Some(input.expr[idx].clone()) - } - Expr::BinaryExpr(binary) => { - let lhs = trim_expr(rewrite_expr_with_check!(&binary.left, input)); - let rhs = trim_expr(rewrite_expr_with_check!(&binary.right, input)); - Some(Expr::BinaryExpr(BinaryExpr::new( - Box::new(lhs), - binary.op, - Box::new(rhs), - ))) + input.expr[idx].clone() } - Expr::Alias(alias) => { - let new_expr = trim_expr(rewrite_expr_with_check!(&alias.expr, input)); - Some(Expr::Alias(Alias::new( - new_expr, - alias.relation.clone(), - alias.name.clone(), - ))) - } - Expr::Literal(_val) => Some(expr.clone()), + Expr::BinaryExpr(binary) => Expr::BinaryExpr(BinaryExpr::new( + Box::new(trim_expr(rewrite_expr_with_check!(&binary.left, input))), + binary.op, + Box::new(trim_expr(rewrite_expr_with_check!(&binary.right, input))), + )), + Expr::Alias(alias) => Expr::Alias(Alias::new( + trim_expr(rewrite_expr_with_check!(&alias.expr, input)), + alias.relation.clone(), + alias.name.clone(), + )), + Expr::Literal(_) => expr.clone(), Expr::Cast(cast) => { let new_expr = rewrite_expr_with_check!(&cast.expr, input); - Some(Expr::Cast(Cast::new( - Box::new(new_expr), - cast.data_type.clone(), - ))) + Expr::Cast(Cast::new(Box::new(new_expr), cast.data_type.clone())) } Expr::ScalarFunction(scalar_fn) => { - let fun = if let ScalarFunctionDefinition::BuiltIn(fun) = scalar_fn.func_def { - fun - } else { + // TODO: Support UDFs. + let ScalarFunctionDefinition::BuiltIn(fun) = scalar_fn.func_def else { return Ok(None); }; - scalar_fn + return Ok(scalar_fn .args .iter() .map(|expr| rewrite_expr(expr, input)) - .collect::>>>()? - .map(|new_args| Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) + .collect::>>()? + .map(|new_args| { + Expr::ScalarFunction(ScalarFunction::new(fun, new_args)) + })); } - _ => { - // Unsupported type to merge in consecutive projections - None - } - }) + // Unsupported type for consecutive projection merge analysis. + _ => return Ok(None), + }; + Ok(Some(result)) } -/// Retrieves a set of outer-referenced columns from an expression. -/// Please note that `expr.to_columns()` API doesn't return these columns. +/// Retrieves a set of outer-referenced columns by the given expression, `expr`. +/// Note that the `Expr::to_columns()` function doesn't return these columns. /// -/// # Arguments +/// # Parameters /// -/// * `expr` - The expression to be analyzed for outer-referenced columns. +/// * `expr` - The expression to analyze for outer-referenced columns. /// /// # Returns /// -/// A `HashSet` containing columns that are referenced by the expression. -fn outer_columns(expr: &Expr) -> HashSet { +/// If the function can safely infer all outer-referenced columns, returns a +/// `Some(HashSet)` containing these columns. Otherwise, returns `None`. +fn outer_columns(expr: &Expr) -> Option> { let mut columns = HashSet::new(); - outer_columns_helper(expr, &mut columns); - columns + outer_columns_helper(expr, &mut columns).then_some(columns) } -/// Helper function to accumulate outer-referenced columns referred by the `expr`. +/// A recursive subroutine that accumulates outer-referenced columns by the +/// given expression, `expr`. /// -/// # Arguments +/// # Parameters +/// +/// * `expr` - The expression to analyze for outer-referenced columns. 
+/// * `columns` - A mutable reference to a `HashSet` where detected +/// columns are collected. /// -/// * `expr` - The expression to be analyzed for outer-referenced columns. -/// * `columns` - A mutable reference to a `HashSet` where the detected columns are collected. -fn outer_columns_helper(expr: &Expr, columns: &mut HashSet) { +/// Returns `true` if it can safely collect all outer-referenced columns. +/// Otherwise, returns `false`. +fn outer_columns_helper(expr: &Expr, columns: &mut HashSet) -> bool { match expr { Expr::OuterReferenceColumn(_, col) => { columns.insert(col.clone()); + true } Expr::BinaryExpr(binary_expr) => { - outer_columns_helper(&binary_expr.left, columns); - outer_columns_helper(&binary_expr.right, columns); + outer_columns_helper(&binary_expr.left, columns) + && outer_columns_helper(&binary_expr.right, columns) } Expr::ScalarSubquery(subquery) => { - for expr in &subquery.outer_ref_columns { - outer_columns_helper(expr, columns); - } + let exprs = subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns) } Expr::Exists(exists) => { - for expr in &exists.subquery.outer_ref_columns { - outer_columns_helper(expr, columns); + let exprs = exists.subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns) + } + Expr::Alias(alias) => outer_columns_helper(&alias.expr, columns), + Expr::InSubquery(insubquery) => { + let exprs = insubquery.subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns) + } + Expr::IsNotNull(expr) | Expr::IsNull(expr) => outer_columns_helper(expr, columns), + Expr::Cast(cast) => outer_columns_helper(&cast.expr, columns), + Expr::Sort(sort) => outer_columns_helper(&sort.expr, columns), + Expr::AggregateFunction(aggregate_fn) => { + outer_columns_helper_multi(aggregate_fn.args.iter(), columns) + && aggregate_fn + .order_by + .as_ref() + .map_or(true, |obs| outer_columns_helper_multi(obs.iter(), columns)) + && aggregate_fn + .filter + .as_ref() + .map_or(true, |filter| outer_columns_helper(filter, columns)) + } + Expr::WindowFunction(window_fn) => { + outer_columns_helper_multi(window_fn.args.iter(), columns) + && outer_columns_helper_multi(window_fn.order_by.iter(), columns) + && outer_columns_helper_multi(window_fn.partition_by.iter(), columns) + } + Expr::GroupingSet(groupingset) => match groupingset { + GroupingSet::GroupingSets(multi_exprs) => multi_exprs + .iter() + .all(|e| outer_columns_helper_multi(e.iter(), columns)), + GroupingSet::Cube(exprs) | GroupingSet::Rollup(exprs) => { + outer_columns_helper_multi(exprs.iter(), columns) } + }, + Expr::ScalarFunction(scalar_fn) => { + outer_columns_helper_multi(scalar_fn.args.iter(), columns) + } + Expr::Like(like) => { + outer_columns_helper(&like.expr, columns) + && outer_columns_helper(&like.pattern, columns) } - Expr::Alias(alias) => { - outer_columns_helper(&alias.expr, columns); + Expr::InList(in_list) => { + outer_columns_helper(&in_list.expr, columns) + && outer_columns_helper_multi(in_list.list.iter(), columns) } - _ => {} + Expr::Case(case) => { + let when_then_exprs = case + .when_then_expr + .iter() + .flat_map(|(first, second)| [first.as_ref(), second.as_ref()]); + outer_columns_helper_multi(when_then_exprs, columns) + && case + .expr + .as_ref() + .map_or(true, |expr| outer_columns_helper(expr, columns)) + && case + .else_expr + .as_ref() + .map_or(true, |expr| outer_columns_helper(expr, columns)) + } + Expr::Column(_) | Expr::Literal(_) | Expr::Wildcard { .. 
} => true, + _ => false, } } -/// Generates the required expressions(Column) that resides at `indices` of the `input_schema`. +/// A recursive subroutine that accumulates outer-referenced columns by the +/// given expressions (`exprs`). +/// +/// # Parameters +/// +/// * `exprs` - The expressions to analyze for outer-referenced columns. +/// * `columns` - A mutable reference to a `HashSet` where detected +/// columns are collected. +/// +/// Returns `true` if it can safely collect all outer-referenced columns. +/// Otherwise, returns `false`. +fn outer_columns_helper_multi<'a>( + mut exprs: impl Iterator, + columns: &mut HashSet, +) -> bool { + exprs.all(|e| outer_columns_helper(e, columns)) +} + +/// Generates the required expressions (columns) that reside at `indices` of +/// the given `input_schema`. /// /// # Arguments /// /// * `input_schema` - A reference to the input schema. -/// * `indices` - A slice of `usize` indices specifying which columns are required. +/// * `indices` - A slice of `usize` indices specifying required columns. /// /// # Returns /// -/// A vector of `Expr::Column` expressions, that sits at `indices` of the `input_schema`. +/// A vector of `Expr::Column` expressions residing at `indices` of the `input_schema`. fn get_required_exprs(input_schema: &Arc, indices: &[usize]) -> Vec { let fields = input_schema.fields(); indices @@ -595,58 +718,70 @@ fn get_required_exprs(input_schema: &Arc, indices: &[usize]) -> Vec>( - input: &LogicalPlan, - exprs: I, +/// A [`Result`] object containing the indices of all required fields in +/// `input_schema` to calculate all `exprs` successfully. +fn indices_referred_by_exprs<'a>( + input_schema: &DFSchemaRef, + exprs: impl Iterator, ) -> Result> { - let new_indices = exprs - .flat_map(|expr| indices_referred_by_expr(input.schema(), expr)) + let indices = exprs + .map(|expr| indices_referred_by_expr(input_schema, expr)) + .collect::>>()?; + Ok(indices + .into_iter() .flatten() - // Make sure no duplicate entries exists and indices are ordered. + // Make sure no duplicate entries exist and indices are ordered: .sorted() .dedup() - .collect::>(); - Ok(new_indices) + .collect()) } -/// Get indices of the necessary fields referred by the `expr` among input schema. +/// Get indices of the fields referred to by the given expression `expr` within +/// the given schema (`input_schema`). /// -/// # Arguments +/// # Parameters /// -/// * `input_schema`: The input schema to search for indices referred by expr. -/// * `expr`: An expression for which we want to find necessary field indices at the input schema. +/// * `input_schema`: The input schema to analyze for index requirements. +/// * `expr`: An expression for which we want to find necessary field indices. /// /// # Returns /// -/// A [Result] object that contains the required field indices of the `input_schema`, to be able to calculate -/// the `expr` successfully. +/// A [`Result`] object containing the indices of all required fields in +/// `input_schema` to calculate `expr` successfully. fn indices_referred_by_expr( input_schema: &DFSchemaRef, expr: &Expr, ) -> Result> { let mut cols = expr.to_columns()?; - // Get outer referenced columns (expr.to_columns() doesn't return these columns). 
- cols.extend(outer_columns(expr)); - cols.iter() - .filter(|&col| input_schema.has_column(col)) - .map(|col| input_schema.index_of_column(col)) - .collect::>>() + // Get outer-referenced columns: + if let Some(outer_cols) = outer_columns(expr) { + cols.extend(outer_cols); + } else { + // Expression is not known to contain outer columns or not. Hence, do + // not assume anything and require all the schema indices at the input: + return Ok((0..input_schema.fields().len()).collect()); + } + Ok(cols + .iter() + .flat_map(|col| input_schema.index_of_column(col)) + .collect()) } -/// Get all required indices for the input (indices required by parent + indices referred by `exprs`) +/// Gets all required indices for the input; i.e. those required by the parent +/// and those referred to by `exprs`. /// -/// # Arguments +/// # Parameters /// /// * `parent_required_indices` - A slice of indices required by the parent plan. /// * `input` - The input logical plan to analyze for index requirements. @@ -654,30 +789,28 @@ fn indices_referred_by_expr( /// /// # Returns /// -/// A `Result` containing a vector of `usize` indices containing all required indices. -fn get_all_required_indices<'a, I: Iterator>( +/// A `Result` containing a vector of `usize` indices containing all the required +/// indices. +fn get_all_required_indices<'a>( parent_required_indices: &[usize], input: &LogicalPlan, - exprs: I, + exprs: impl Iterator, ) -> Result> { - let referred_indices = indices_referred_by_exprs(input, exprs)?; - Ok(merge_vectors(parent_required_indices, &referred_indices)) + indices_referred_by_exprs(input.schema(), exprs) + .map(|indices| merge_slices(parent_required_indices, &indices)) } -/// Retrieves a list of expressions at specified indices from a slice of expressions. +/// Retrieves the expressions at specified indices within the given slice. Ignores +/// any invalid indices. /// -/// This function takes a slice of expressions `exprs` and a slice of `usize` indices `indices`. -/// It returns a new vector containing the expressions from `exprs` that correspond to the provided indices (with bound check). -/// -/// # Arguments +/// # Parameters /// -/// * `exprs` - A slice of expressions from which expressions are to be retrieved. -/// * `indices` - A slice of `usize` indices specifying the positions of the expressions to be retrieved. +/// * `exprs` - A slice of expressions to index into. +/// * `indices` - A slice of indices specifying the positions of expressions sought. /// /// # Returns /// -/// A vector of expressions that correspond to the specified indices. If any index is out of bounds, -/// the associated expression is skipped in the result. +/// A vector of expressions corresponding to specified indices. fn get_at_indices(exprs: &[Expr], indices: &[usize]) -> Vec { indices .iter() @@ -686,158 +819,148 @@ fn get_at_indices(exprs: &[Expr], indices: &[usize]) -> Vec { .collect() } -/// Merges two slices of `usize` values into a single vector with sorted (ascending) and deduplicated elements. -/// -/// # Arguments -/// -/// * `lhs` - The first slice of `usize` values to be merged. -/// * `rhs` - The second slice of `usize` values to be merged. -/// -/// # Returns -/// -/// A vector of `usize` values containing the merged, sorted, and deduplicated elements from `lhs` and `rhs`. 
-/// As an example merge of [3, 2, 4] and [3, 6, 1] will produce [1, 2, 3, 6] -fn merge_vectors(lhs: &[usize], rhs: &[usize]) -> Vec { - let mut merged = lhs.to_vec(); - merged.extend(rhs); - // Make sure to run sort before dedup. - // Dedup removes consecutive same entries - // If sort is run before it, all duplicates are removed. - merged.sort(); - merged.dedup(); - merged +/// Merges two slices into a single vector with sorted (ascending) and +/// deduplicated elements. For example, merging `[3, 2, 4]` and `[3, 6, 1]` +/// will produce `[1, 2, 3, 6]`. +fn merge_slices(left: &[T], right: &[T]) -> Vec { + // Make sure to sort before deduping, which removes the duplicates: + left.iter() + .cloned() + .chain(right.iter().cloned()) + .sorted() + .dedup() + .collect() } -/// Splits requirement indices for a join into left and right children based on the join type. +/// Splits requirement indices for a join into left and right children based on +/// the join type. /// -/// This function takes the length of the left child, a slice of requirement indices, and the type -/// of join (e.g., INNER, LEFT, RIGHT, etc.) as arguments. Depending on the join type, it divides -/// the requirement indices into those that apply to the left child and those that apply to the right child. +/// This function takes the length of the left child, a slice of requirement +/// indices, and the type of join (e.g. `INNER`, `LEFT`, `RIGHT`) as arguments. +/// Depending on the join type, it divides the requirement indices into those +/// that apply to the left child and those that apply to the right child. /// -/// - For INNER, LEFT, RIGHT, and FULL joins, the requirements are split between left and right children. -/// The right child indices are adjusted to point to valid positions in the right child by subtracting -/// the length of the left child. +/// - For `INNER`, `LEFT`, `RIGHT` and `FULL` joins, the requirements are split +/// between left and right children. The right child indices are adjusted to +/// point to valid positions within the right child by subtracting the length +/// of the left child. /// -/// - For LEFT ANTI, LEFT SEMI, RIGHT SEMI, and RIGHT ANTI joins, all requirements are re-routed to either -/// the left child or the right child directly, depending on the join type. +/// - For `LEFT ANTI`, `LEFT SEMI`, `RIGHT SEMI` and `RIGHT ANTI` joins, all +/// requirements are re-routed to either the left child or the right child +/// directly, depending on the join type. /// -/// # Arguments +/// # Parameters /// /// * `left_len` - The length of the left child. /// * `indices` - A slice of requirement indices. -/// * `join_type` - The type of join (e.g., INNER, LEFT, RIGHT, etc.). +/// * `join_type` - The type of join (e.g. `INNER`, `LEFT`, `RIGHT`). /// /// # Returns /// -/// A tuple containing two vectors of `usize` indices: the first vector represents the requirements for -/// the left child, and the second vector represents the requirements for the right child. The indices -/// are appropriately split and adjusted based on the join type. +/// A tuple containing two vectors of `usize` indices: The first vector represents +/// the requirements for the left child, and the second vector represents the +/// requirements for the right child. The indices are appropriately split and +/// adjusted based on the join type. fn split_join_requirements( left_len: usize, indices: &[usize], join_type: &JoinType, ) -> (Vec, Vec) { match join_type { - // In these cases requirements split to left and right child. 
+ // In these cases requirements are split between left/right children: JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { - let (left_child_reqs, mut right_child_reqs): (Vec, Vec) = + let (left_reqs, mut right_reqs): (Vec, Vec) = indices.iter().partition(|&&idx| idx < left_len); - // Decrease right side index by `left_len` so that they point to valid positions in the right child. - right_child_reqs.iter_mut().for_each(|idx| *idx -= left_len); - (left_child_reqs, right_child_reqs) + // Decrease right side indices by `left_len` so that they point to valid + // positions within the right child: + for idx in right_reqs.iter_mut() { + *idx -= left_len; + } + (left_reqs, right_reqs) } // All requirements can be re-routed to left child directly. JoinType::LeftAnti | JoinType::LeftSemi => (indices.to_vec(), vec![]), - // All requirements can be re-routed to right side directly. (No need to change index, join schema is right child schema.) + // All requirements can be re-routed to right side directly. + // No need to change index, join schema is right child schema. JoinType::RightSemi | JoinType::RightAnti => (vec![], indices.to_vec()), } } -/// Adds a projection on top of a logical plan if it is beneficial and reduces the number of columns for the parent operator. +/// Adds a projection on top of a logical plan if doing so reduces the number +/// of columns for the parent operator. /// -/// This function takes a `LogicalPlan`, a list of projection expressions, and a flag indicating whether -/// the projection is beneficial. If the projection is beneficial and reduces the number of columns in -/// the plan, a new `LogicalPlan` with the projection is created and returned, along with a `true` flag. -/// If the projection is unnecessary or doesn't reduce the number of columns, the original plan is returned -/// with a `false` flag. +/// This function takes a `LogicalPlan` and a list of projection expressions. +/// If the projection is beneficial (it reduces the number of columns in the +/// plan) a new `LogicalPlan` with the projection is created and returned, along +/// with a `true` flag. If the projection doesn't reduce the number of columns, +/// the original plan is returned with a `false` flag. /// -/// # Arguments +/// # Parameters /// /// * `plan` - The input `LogicalPlan` to potentially add a projection to. /// * `project_exprs` - A list of expressions for the projection. -/// * `projection_beneficial` - A flag indicating whether the projection is beneficial. /// /// # Returns /// -/// A `Result` containing a tuple with two values: the resulting `LogicalPlan` (with or without -/// the added projection) and a `bool` flag indicating whether the projection was added (`true`) or not (`false`). +/// A `Result` containing a tuple with two values: The resulting `LogicalPlan` +/// (with or without the added projection) and a `bool` flag indicating if a +/// projection was added (`true`) or not (`false`). fn add_projection_on_top_if_helpful( plan: LogicalPlan, project_exprs: Vec, - projection_beneficial: bool, ) -> Result<(LogicalPlan, bool)> { - // Make sure projection decreases table column size, otherwise it is unnecessary. - if !projection_beneficial || project_exprs.len() >= plan.schema().fields().len() { + // Make sure projection decreases the number of columns, otherwise it is unnecessary. 
+ if project_exprs.len() >= plan.schema().fields().len() { Ok((plan, false)) } else { - let new_plan = Projection::try_new(project_exprs, Arc::new(plan)) - .map(LogicalPlan::Projection)?; - Ok((new_plan, true)) + Projection::try_new(project_exprs, Arc::new(plan)) + .map(|proj| (LogicalPlan::Projection(proj), true)) } } -/// Collects and returns a vector of all indices of the fields in the schema of a logical plan. +/// Rewrite the given projection according to the fields required by its +/// ancestors. /// -/// # Arguments +/// # Parameters /// -/// * `plan` - A reference to the `LogicalPlan` for which indices are required. +/// * `proj` - A reference to the original projection to rewrite. +/// * `config` - A reference to the optimizer configuration. +/// * `indices` - A slice of indices representing the columns required by the +/// ancestors of the given projection. /// /// # Returns /// -/// A vector of `usize` indices representing all fields in the schema of the provided logical plan. -fn require_all_indices(plan: &LogicalPlan) -> Vec { - (0..plan.schema().fields().len()).collect() -} - -/// Rewrite Projection Given Required fields by its parent(s). -/// -/// # Arguments +/// A `Result` object with the following semantics: /// -/// * `proj` - A reference to the original projection to be rewritten. -/// * `_config` - A reference to the optimizer configuration (unused in the function). -/// * `indices` - A slice of indices representing the required columns by the parent(s) of projection. -/// -/// # Returns -/// -/// A `Result` containing an `Option` of the rewritten logical plan. If the -/// rewrite is successful, it returns `Some` with the optimized logical plan. -/// If the logical plan remains unchanged it returns `Ok(None)`. +/// - `Ok(Some(LogicalPlan))`: Contains the rewritten projection. +/// - `Ok(None)`: No rewrite necessary. +/// - `Err(error)`: An error occurred during the function call. fn rewrite_projection_given_requirements( proj: &Projection, - _config: &dyn OptimizerConfig, + config: &dyn OptimizerConfig, indices: &[usize], ) -> Result> { let exprs_used = get_at_indices(&proj.expr, indices); - let required_indices = indices_referred_by_exprs(&proj.input, exprs_used.iter())?; + let required_indices = + indices_referred_by_exprs(proj.input.schema(), exprs_used.iter())?; return if let Some(input) = - optimize_projections(&proj.input, _config, &required_indices)? + optimize_projections(&proj.input, config, &required_indices)? { if &projection_schema(&input, &exprs_used)? == input.schema() { Ok(Some(input)) } else { - let new_proj = Projection::try_new(exprs_used, Arc::new(input))?; - let new_proj = LogicalPlan::Projection(new_proj); - Ok(Some(new_proj)) + Projection::try_new(exprs_used, Arc::new(input)) + .map(|proj| Some(LogicalPlan::Projection(proj))) } } else if exprs_used.len() < proj.expr.len() { - // Projection expression used is different than the existing projection. - // In this case, even if child doesn't change we should update projection to use less columns. + // Projection expression used is different than the existing projection. + // In this case, even if the child doesn't change, we should update the + // projection to use fewer columns: if &projection_schema(&proj.input, &exprs_used)?
== proj.input.schema() { Ok(Some(proj.input.as_ref().clone())) } else { - let new_proj = Projection::try_new(exprs_used, proj.input.clone())?; - let new_proj = LogicalPlan::Projection(new_proj); - Ok(Some(new_proj)) + Projection::try_new(exprs_used, proj.input.clone()) + .map(|proj| Some(LogicalPlan::Projection(proj))) } } else { // Projection doesn't change. @@ -847,15 +970,16 @@ fn rewrite_projection_given_requirements( #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::optimize_projections::OptimizeProjections; - use datafusion_common::Result; + use crate::test::{assert_optimized_plan_eq, test_table_scan}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{Result, TableReference}; use datafusion_expr::{ - binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, LogicalPlan, - Operator, + binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder, + table_scan, Expr, LogicalPlan, Operator, }; - use std::sync::Arc; - - use crate::test::*; fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) @@ -900,4 +1024,39 @@ mod tests { \n TableScan: test projection=[a]"; assert_optimized_plan_equal(&plan, expected) } + + #[test] + fn merge_nested_alias() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").alias("alias1").alias("alias2")])? + .project(vec![col("alias2").alias("alias")])? + .build()?; + + let expected = "Projection: test.a AS alias\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_nested_count() -> Result<()> { + let schema = Schema::new(vec![Field::new("foo", DataType::Int32, false)]); + + let groups: Vec = vec![]; + + let plan = table_scan(TableReference::none(), &schema, None) + .unwrap() + .aggregate(groups.clone(), vec![count(lit(1))]) + .unwrap() + .aggregate(groups, vec![count(lit(1))]) + .unwrap() + .build() + .unwrap(); + + let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ + \n Projection: \ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ + \n TableScan: ?table? projection=[]"; + assert_optimized_plan_equal(&plan, expected) + } } diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 7af46ed70adf..0dc34cb809eb 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -17,6 +17,10 @@ //! 
Query optimizer traits +use std::collections::HashSet; +use std::sync::Arc; +use std::time::Instant; + use crate::common_subexpr_eliminate::CommonSubexprEliminate; use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery; use crate::eliminate_cross_join::EliminateCrossJoin; @@ -41,15 +45,14 @@ use crate::simplify_expressions::SimplifyExpressions; use crate::single_distinct_to_groupby::SingleDistinctToGroupBy; use crate::unwrap_cast_in_comparison::UnwrapCastInComparison; use crate::utils::log_plan; -use chrono::{DateTime, Utc}; + use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::LogicalPlan; +use datafusion_expr::logical_plan::LogicalPlan; + +use chrono::{DateTime, Utc}; use log::{debug, warn}; -use std::collections::HashSet; -use std::sync::Arc; -use std::time::Instant; /// `OptimizerRule` transforms one [`LogicalPlan`] into another which /// computes the same results, but in a potentially more efficient @@ -447,17 +450,18 @@ pub(crate) fn assert_schema_is_the_same( #[cfg(test)] mod tests { + use std::sync::{Arc, Mutex}; + + use super::ApplyOrder; use crate::optimizer::Optimizer; use crate::test::test_table_scan; use crate::{OptimizerConfig, OptimizerContext, OptimizerRule}; + use datafusion_common::{ plan_err, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, }; use datafusion_expr::logical_plan::EmptyRelation; use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, Projection}; - use std::sync::{Arc, Mutex}; - - use super::ApplyOrder; #[test] fn skip_failing_rule() { diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index e8f116d89466..c090fb849a82 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1062,7 +1062,7 @@ mod tests { ]); let mut optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index bdd66347631c..10cc1879aeeb 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -625,7 +625,7 @@ mod tests { let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]); let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 917ddc565c9e..e691fe9a5351 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -158,7 +158,7 @@ pub fn assert_optimized_plan_eq( let optimizer = Optimizer::with_rules(vec![rule.clone()]); let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? 
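The `.get(0)` -> `.first()` substitution above recurs throughout these test helpers and in several physical-expr call sites below. As a minimal, hypothetical sketch (not part of the patch; the `rules` vector here is made up), the two forms are interchangeable: `slice::first()` simply states the intent and satisfies Clippy's `get_first` lint.

// Both calls return Option<&T> and agree on empty slices.
fn main() {
    let rules = vec!["optimize_projections", "push_down_filter"];
    assert_eq!(rules.first(), rules.get(0));

    let empty: Vec<&str> = vec![];
    assert_eq!(empty.first(), None);
    assert_eq!(empty.get(0), None);
}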
@@ -199,7 +199,7 @@ pub fn assert_optimized_plan_eq_display_indent( let optimizer = Optimizer::with_rules(vec![rule]); let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), ) @@ -233,7 +233,7 @@ pub fn assert_optimizer_err( ) { let optimizer = Optimizer::with_rules(vec![rule]); let res = optimizer.optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), ); @@ -255,7 +255,7 @@ pub fn assert_optimization_skipped( let optimizer = Optimizer::with_rules(vec![rule]); let new_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 4172881c0aad..d857c6154ea9 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -15,8 +15,11 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; +use std::collections::HashMap; +use std::sync::Arc; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; -use chrono::{DateTime, NaiveDateTime, Utc}; use datafusion_common::config::ConfigOptions; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; @@ -28,9 +31,8 @@ use datafusion_sql::sqlparser::ast::Statement; use datafusion_sql::sqlparser::dialect::GenericDialect; use datafusion_sql::sqlparser::parser::Parser; use datafusion_sql::TableReference; -use std::any::Any; -use std::collections::HashMap; -use std::sync::Arc; + +use chrono::{DateTime, NaiveDateTime, Utc}; #[cfg(test)] #[ctor::ctor] diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs index dcc8c37e7484..cf980f4c3f16 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs @@ -309,7 +309,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { // double check each array has the same length (aka the // accumulator was implemented correctly - if let Some(first_col) = arrays.get(0) { + if let Some(first_col) = arrays.first() { for arr in &arrays { assert_eq!(arr.len(), first_col.len()) } diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 8cea300f0e6f..ae048694583b 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -537,7 +537,7 @@ fn general_except( dedup.clear(); } - if let Some(values) = converter.convert_rows(rows)?.get(0) { + if let Some(values) = converter.convert_rows(rows)?.first() { Ok(GenericListArray::::new( field.to_owned(), OffsetBuffer::new(offsets.into()), @@ -2088,7 +2088,7 @@ pub fn array_intersect(args: &[ArrayRef]) -> Result { }; offsets.push(last_offset + rows.len() as i32); let arrays = converter.convert_rows(rows)?; - let array = match arrays.get(0) { + let array = match arrays.first() { Some(array) => array.clone(), None => { return internal_err!( diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index 04cfec29ea8a..d634b4d01918 
100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -36,6 +36,7 @@ use arrow::{ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }, }; +use arrow_array::types::ArrowTimestampType; use arrow_array::{ timezone::Tz, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampSecondArray, @@ -43,7 +44,7 @@ use arrow_array::{ use chrono::prelude::*; use chrono::{Duration, Months, NaiveDate}; use datafusion_common::cast::{ - as_date32_array, as_date64_array, as_generic_string_array, + as_date32_array, as_date64_array, as_generic_string_array, as_primitive_array, as_timestamp_microsecond_array, as_timestamp_millisecond_array, as_timestamp_nanosecond_array, as_timestamp_second_array, }; @@ -335,7 +336,7 @@ fn date_trunc_coarse(granularity: &str, value: i64, tz: Option) -> Result, tz: Option, @@ -403,123 +404,61 @@ pub fn date_trunc(args: &[ColumnarValue]) -> Result { return exec_err!("Granularity of `date_trunc` must be non-null scalar Utf8"); }; + fn process_array( + array: &dyn Array, + granularity: String, + tz_opt: &Option>, + ) -> Result { + let parsed_tz = parse_tz(tz_opt)?; + let array = as_primitive_array::(array)?; + let array = array + .iter() + .map(|x| general_date_trunc(T::UNIT, &x, parsed_tz, granularity.as_str())) + .collect::>>()? + .with_timezone_opt(tz_opt.clone()); + Ok(ColumnarValue::Array(Arc::new(array))) + } + + fn process_scalr( + v: &Option, + granularity: String, + tz_opt: &Option>, + ) -> Result { + let parsed_tz = parse_tz(tz_opt)?; + let value = general_date_trunc(T::UNIT, v, parsed_tz, granularity.as_str())?; + let value = ScalarValue::new_timestamp::(value, tz_opt.clone()); + Ok(ColumnarValue::Scalar(value)) + } + Ok(match array { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => { - let parsed_tz = parse_tz(tz_opt)?; - let value = - _date_trunc(TimeUnit::Nanosecond, v, parsed_tz, granularity.as_str())?; - let value = ScalarValue::TimestampNanosecond(value, tz_opt.clone()); - ColumnarValue::Scalar(value) + process_scalr::(v, granularity, tz_opt)? } ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(v, tz_opt)) => { - let parsed_tz = parse_tz(tz_opt)?; - let value = - _date_trunc(TimeUnit::Microsecond, v, parsed_tz, granularity.as_str())?; - let value = ScalarValue::TimestampMicrosecond(value, tz_opt.clone()); - ColumnarValue::Scalar(value) + process_scalr::(v, granularity, tz_opt)? } ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, tz_opt)) => { - let parsed_tz = parse_tz(tz_opt)?; - let value = - _date_trunc(TimeUnit::Millisecond, v, parsed_tz, granularity.as_str())?; - let value = ScalarValue::TimestampMillisecond(value, tz_opt.clone()); - ColumnarValue::Scalar(value) + process_scalr::(v, granularity, tz_opt)? } ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => { - let parsed_tz = parse_tz(tz_opt)?; - let value = - _date_trunc(TimeUnit::Second, v, parsed_tz, granularity.as_str())?; - let value = ScalarValue::TimestampSecond(value, tz_opt.clone()); - ColumnarValue::Scalar(value) + process_scalr::(v, granularity, tz_opt)? } ColumnarValue::Array(array) => { let array_type = array.data_type(); match array_type { DataType::Timestamp(TimeUnit::Second, tz_opt) => { - let parsed_tz = parse_tz(tz_opt)?; - let array = as_timestamp_second_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Second, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()? 
- .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) + process_array::(array, granularity, tz_opt)? } DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { - let parsed_tz = parse_tz(tz_opt)?; - let array = as_timestamp_millisecond_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Millisecond, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()? - .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) + process_array::(array, granularity, tz_opt)? } DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { - let parsed_tz = parse_tz(tz_opt)?; - let array = as_timestamp_microsecond_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Microsecond, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()? - .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) + process_array::(array, granularity, tz_opt)? } DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { - let parsed_tz = parse_tz(tz_opt)?; - let array = as_timestamp_nanosecond_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Nanosecond, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()? - .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) - } - _ => { - let parsed_tz = None; - let array = as_timestamp_nanosecond_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Nanosecond, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()?; - - ColumnarValue::Array(Arc::new(array)) + process_array::(array, granularity, tz_opt)? } + _ => process_array::(array, granularity, &None)?, } } _ => { diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index b3ca95292a37..0c4ed3c12549 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -140,8 +140,7 @@ impl PhysicalExpr for CastExpr { let mut s = state; self.expr.hash(&mut s); self.cast_type.hash(&mut s); - // Add `self.cast_options` when hash is available - // https://github.com/apache/arrow-rs/pull/4395 + self.cast_options.hash(&mut s); } /// A [`CastExpr`] preserves the ordering of its child. 
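The two `CastExpr` hunks here share one idea: now that `CastOptions` implements `Hash` and `PartialEq` upstream in arrow-rs (the removed TODO comments pointed at those tracking issues), both the hash above and the equality check below can cover the full cast options instead of only the `safe` flag. A small, hypothetical sketch of why that matters; `MiniCastExpr` and `fingerprint` are stand-ins, not DataFusion API:

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Stand-in for an expression that hashes and compares all of its state.
#[derive(Debug, Hash, PartialEq, Eq)]
struct MiniCastExpr {
    cast_type: String,
    safe: bool, // stand-in for the full CastOptions
}

fn fingerprint<T: Hash>(value: &T) -> u64 {
    let mut hasher = DefaultHasher::new();
    value.hash(&mut hasher);
    hasher.finish()
}

fn main() {
    let lenient = MiniCastExpr { cast_type: "Utf8".to_string(), safe: true };
    let strict = MiniCastExpr { cast_type: "Utf8".to_string(), safe: false };
    // Casts that differ only in their options are no longer considered equal,
    // and the differing option also feeds into the hash.
    assert_ne!(lenient, strict);
    println!("{} vs {}", fingerprint(&lenient), fingerprint(&strict));
}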
@@ -157,8 +156,7 @@ impl PartialEq for CastExpr { .map(|x| { self.expr.eq(&x.expr) && self.cast_type == x.cast_type - // TODO: Use https://github.com/apache/arrow-rs/issues/2966 when available - && self.cast_options.safe == x.cast_options.safe + && self.cast_options == x.cast_options }) .unwrap_or(false) } diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 39cd47452eff..7de474fda11c 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -55,7 +55,7 @@ impl fmt::Debug for MemoryExec { write!(f, "partitions: [...]")?; write!(f, "schema: {:?}", self.projected_schema)?; write!(f, "projection: {:?}", self.projection)?; - if let Some(sort_info) = &self.sort_information.get(0) { + if let Some(sort_info) = &self.sort_information.first() { write!(f, ", output_ordering: {:?}", sort_info)?; } Ok(()) diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs index 71e6cba6741e..1f6ee1f117aa 100644 --- a/datafusion/physical-plan/src/test/exec.rs +++ b/datafusion/physical-plan/src/test/exec.rs @@ -790,7 +790,7 @@ impl Stream for PanicStream { } else { self.ready = true; // get called again - cx.waker().clone().wake(); + cx.waker().wake_by_ref(); return Poll::Pending; } } diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index f988b28cce0d..431a43bc6055 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -40,7 +40,7 @@ use crate::{ }; use arrow::{ - array::{Array, ArrayRef, UInt32Builder}, + array::{Array, ArrayRef, RecordBatchOptions, UInt32Builder}, compute::{concat, concat_batches, sort_to_indices}, datatypes::{Schema, SchemaBuilder, SchemaRef}, record_batch::RecordBatch, @@ -1026,8 +1026,11 @@ impl BoundedWindowAggStream { .iter() .map(|elem| elem.slice(n_out, n_to_keep)) .collect::>(); - self.input_buffer = - RecordBatch::try_new(self.input_buffer.schema(), batch_to_keep)?; + self.input_buffer = RecordBatch::try_new_with_options( + self.input_buffer.schema(), + batch_to_keep, + &RecordBatchOptions::new().with_row_count(Some(n_to_keep)), + )?; Ok(()) } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 697c94959fbb..13a54f2a5659 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -841,6 +841,8 @@ message Field { // for complex data types like structs, unions repeated Field children = 4; map metadata = 5; + int64 dict_id = 6; + bool dict_ordered = 7; } message FixedSizeBinary{ @@ -1162,6 +1164,7 @@ message PhysicalPlanNode { AnalyzeExecNode analyze = 23; JsonSinkExecNode json_sink = 24; SymmetricHashJoinExecNode symmetric_hash_join = 25; + InterleaveExecNode interleave = 26; } } @@ -1455,6 +1458,10 @@ message SymmetricHashJoinExecNode { JoinFilter filter = 8; } +message InterleaveExecNode { + repeated PhysicalPlanNode inputs = 1; +} + message UnionExecNode { repeated PhysicalPlanNode inputs = 1; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 35bded71d619..0d013c72d37f 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -6910,6 +6910,12 @@ impl serde::Serialize for Field { if !self.metadata.is_empty() { len += 1; } + if self.dict_id != 0 { + len += 1; + } + if self.dict_ordered 
{ + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.Field", len)?; if !self.name.is_empty() { struct_ser.serialize_field("name", &self.name)?; @@ -6926,6 +6932,13 @@ impl serde::Serialize for Field { if !self.metadata.is_empty() { struct_ser.serialize_field("metadata", &self.metadata)?; } + if self.dict_id != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("dictId", ToString::to_string(&self.dict_id).as_str())?; + } + if self.dict_ordered { + struct_ser.serialize_field("dictOrdered", &self.dict_ordered)?; + } struct_ser.end() } } @@ -6942,6 +6955,10 @@ impl<'de> serde::Deserialize<'de> for Field { "nullable", "children", "metadata", + "dict_id", + "dictId", + "dict_ordered", + "dictOrdered", ]; #[allow(clippy::enum_variant_names)] @@ -6951,6 +6968,8 @@ impl<'de> serde::Deserialize<'de> for Field { Nullable, Children, Metadata, + DictId, + DictOrdered, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6977,6 +6996,8 @@ impl<'de> serde::Deserialize<'de> for Field { "nullable" => Ok(GeneratedField::Nullable), "children" => Ok(GeneratedField::Children), "metadata" => Ok(GeneratedField::Metadata), + "dictId" | "dict_id" => Ok(GeneratedField::DictId), + "dictOrdered" | "dict_ordered" => Ok(GeneratedField::DictOrdered), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7001,6 +7022,8 @@ impl<'de> serde::Deserialize<'de> for Field { let mut nullable__ = None; let mut children__ = None; let mut metadata__ = None; + let mut dict_id__ = None; + let mut dict_ordered__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { @@ -7035,6 +7058,20 @@ impl<'de> serde::Deserialize<'de> for Field { map_.next_value::>()? 
); } + GeneratedField::DictId => { + if dict_id__.is_some() { + return Err(serde::de::Error::duplicate_field("dictId")); + } + dict_id__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::DictOrdered => { + if dict_ordered__.is_some() { + return Err(serde::de::Error::duplicate_field("dictOrdered")); + } + dict_ordered__ = Some(map_.next_value()?); + } } } Ok(Field { @@ -7043,6 +7080,8 @@ impl<'de> serde::Deserialize<'de> for Field { nullable: nullable__.unwrap_or_default(), children: children__.unwrap_or_default(), metadata: metadata__.unwrap_or_default(), + dict_id: dict_id__.unwrap_or_default(), + dict_ordered: dict_ordered__.unwrap_or_default(), }) } } @@ -9205,6 +9244,97 @@ impl<'de> serde::Deserialize<'de> for InListNode { deserializer.deserialize_struct("datafusion.InListNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for InterleaveExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.inputs.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.InterleaveExecNode", len)?; + if !self.inputs.is_empty() { + struct_ser.serialize_field("inputs", &self.inputs)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for InterleaveExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "inputs", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Inputs, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "inputs" => Ok(GeneratedField::Inputs), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = InterleaveExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.InterleaveExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut inputs__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Inputs => { + if inputs__.is_some() { + return Err(serde::de::Error::duplicate_field("inputs")); + } + inputs__ = Some(map_.next_value()?); + } + } + } + Ok(InterleaveExecNode { + inputs: inputs__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.InterleaveExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for IntervalMonthDayNanoValue { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -17887,6 +18017,9 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::SymmetricHashJoin(v) => { struct_ser.serialize_field("symmetricHashJoin", v)?; } + physical_plan_node::PhysicalPlanType::Interleave(v) => { + struct_ser.serialize_field("interleave", v)?; + } } } struct_ser.end() @@ -17935,6 +18068,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "jsonSink", "symmetric_hash_join", "symmetricHashJoin", + "interleave", ]; #[allow(clippy::enum_variant_names)] @@ -17963,6 +18097,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { Analyze, JsonSink, SymmetricHashJoin, + Interleave, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18008,6 +18143,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "analyze" => Ok(GeneratedField::Analyze), "jsonSink" | "json_sink" => Ok(GeneratedField::JsonSink), "symmetricHashJoin" | "symmetric_hash_join" => Ok(GeneratedField::SymmetricHashJoin), + "interleave" => Ok(GeneratedField::Interleave), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18196,6 +18332,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("symmetricHashJoin")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SymmetricHashJoin) +; + } + GeneratedField::Interleave => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("interleave")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Interleave) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 97843319b116..d4b62d4b3fd8 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1027,6 +1027,10 @@ pub struct Field { ::prost::alloc::string::String, ::prost::alloc::string::String, >, + #[prost(int64, tag = "6")] + pub dict_id: i64, + #[prost(bool, tag = "7")] + pub dict_ordered: bool, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -1521,7 +1525,7 @@ pub mod owned_table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26" )] pub physical_plan_type: ::core::option::Option, } @@ -1580,6 +1584,8 @@ pub mod physical_plan_node { JsonSink(::prost::alloc::boxed::Box), #[prost(message, tag = "25")] SymmetricHashJoin(::prost::alloc::boxed::Box), + #[prost(message, tag = "26")] + Interleave(super::InterleaveExecNode), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -2038,6 +2044,12 @@ pub struct SymmetricHashJoinExecNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, 
::prost::Message)] +pub struct InterleaveExecNode { + #[prost(message, repeated, tag = "1")] + pub inputs: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionExecNode { #[prost(message, repeated, tag = "1")] pub inputs: ::prost::alloc::vec::Vec, diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 1b7e3d483818..193e0947d6d9 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -39,6 +39,7 @@ use datafusion_common::{ internal_err, plan_datafusion_err, Column, Constraint, Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result, ScalarValue, }; +use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ abs, acos, acosh, array, array_append, array_concat, array_dims, array_distinct, array_element, array_except, array_has, array_has_all, array_has_any, @@ -59,7 +60,6 @@ use datafusion_expr::{ sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substr_index, substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos, to_timestamp_seconds, translate, trim, trunc, upper, uuid, - window_frame::regularize, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, @@ -377,8 +377,20 @@ impl TryFrom<&protobuf::Field> for Field { type Error = Error; fn try_from(field: &protobuf::Field) -> Result { let datatype = field.arrow_type.as_deref().required("arrow_type")?; - Ok(Self::new(field.name.as_str(), datatype, field.nullable) - .with_metadata(field.metadata.clone())) + let field = if field.dict_id != 0 { + Self::new_dict( + field.name.as_str(), + datatype, + field.nullable, + field.dict_id, + field.dict_ordered, + ) + .with_metadata(field.metadata.clone()) + } else { + Self::new(field.name.as_str(), datatype, field.nullable) + .with_metadata(field.metadata.clone()) + }; + Ok(field) } } @@ -1073,7 +1085,7 @@ pub fn parse_expr( .iter() .map(|e| parse_expr(e, registry)) .collect::, _>>()?; - let order_by = expr + let mut order_by = expr .order_by .iter() .map(|e| parse_expr(e, registry)) @@ -1083,7 +1095,8 @@ pub fn parse_expr( .as_ref() .map::, _>(|window_frame| { let window_frame = window_frame.clone().try_into()?; - regularize(window_frame, order_by.len()) + check_window_frame(&window_frame, order_by.len()) + .map(|_| window_frame) }) .transpose()? 
.ok_or_else(|| { @@ -1091,6 +1104,7 @@ pub fn parse_expr( "missing window frame during deserialization".to_string(), ) })?; + regularize_window_order_by(&window_frame, &mut order_by)?; match window_function { window_expr_node::WindowFunction::AggrFunction(i) => { diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 646cde274c0d..2997d147424d 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -108,6 +108,8 @@ impl TryFrom<&Field> for protobuf::Field { nullable: field.is_nullable(), children: Vec::new(), metadata: field.metadata().clone(), + dict_id: field.dict_id().unwrap_or(0), + dict_ordered: field.dict_is_ordered().unwrap_or(false), }) } } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 74c8ec894ff2..878a5bcb7f69 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -48,7 +48,7 @@ use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use datafusion::physical_plan::union::UnionExec; +use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion::physical_plan::{ udaf, AggregateExpr, ExecutionPlan, InputOrderMode, Partitioning, PhysicalExpr, @@ -545,7 +545,7 @@ impl AsExecutionPlan for PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - registry, &schema + registry, &schema, )?; let column_indices = f.column_indices .iter() @@ -556,7 +556,7 @@ impl AsExecutionPlan for PhysicalPlanNode { i.side)) )?; - Ok(ColumnIndex{ + Ok(ColumnIndex { index: i.index as usize, side: side.into(), }) @@ -634,7 +634,7 @@ impl AsExecutionPlan for PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - registry, &schema + registry, &schema, )?; let column_indices = f.column_indices .iter() @@ -645,7 +645,7 @@ impl AsExecutionPlan for PhysicalPlanNode { i.side)) )?; - Ok(ColumnIndex{ + Ok(ColumnIndex { index: i.index as usize, side: side.into(), }) @@ -693,6 +693,17 @@ impl AsExecutionPlan for PhysicalPlanNode { } Ok(Arc::new(UnionExec::new(inputs))) } + PhysicalPlanType::Interleave(interleave) => { + let mut inputs: Vec> = vec![]; + for input in &interleave.inputs { + inputs.push(input.try_into_physical_plan( + registry, + runtime, + extension_codec, + )?); + } + Ok(Arc::new(InterleaveExec::try_new(inputs)?)) + } PhysicalPlanType::CrossJoin(crossjoin) => { let left: Arc = into_physical_plan( &crossjoin.left, @@ -735,7 +746,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr,registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -782,7 +793,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? 
.as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr,registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -845,7 +856,7 @@ impl AsExecutionPlan for PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - registry, &schema + registry, &schema, )?; let column_indices = f.column_indices .iter() @@ -856,7 +867,7 @@ impl AsExecutionPlan for PhysicalPlanNode { i.side)) )?; - Ok(ColumnIndex{ + Ok(ColumnIndex { index: i.index as usize, side: side.into(), }) @@ -1463,6 +1474,21 @@ impl AsExecutionPlan for PhysicalPlanNode { }); } + if let Some(interleave) = plan.downcast_ref::() { + let mut inputs: Vec = vec![]; + for input in interleave.inputs() { + inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( + input.to_owned(), + extension_codec, + )?); + } + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Interleave( + protobuf::InterleaveExecNode { inputs }, + )), + }); + } + if let Some(exec) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 5e36a838f311..8e15b5d0d480 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -972,6 +972,45 @@ fn round_trip_datatype() { } } +#[test] +fn roundtrip_dict_id() -> Result<()> { + let dict_id = 42; + let field = Field::new( + "keys", + DataType::List(Arc::new(Field::new_dict( + "item", + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), + true, + dict_id, + false, + ))), + false, + ); + let schema = Arc::new(Schema::new(vec![field])); + + // encode + let mut buf: Vec = vec![]; + let schema_proto: datafusion_proto::generated::datafusion::Schema = + schema.try_into().unwrap(); + schema_proto.encode(&mut buf).unwrap(); + + // decode + let schema_proto = + datafusion_proto::generated::datafusion::Schema::decode(buf.as_slice()).unwrap(); + let decoded: Schema = (&schema_proto).try_into()?; + + // assert + let keys = decoded.fields().iter().last().unwrap(); + match keys.data_type() { + DataType::List(field) => { + assert_eq!(field.dict_id(), Some(dict_id), "dict_id should be retained"); + } + _ => panic!("Invalid type"), + } + + Ok(()) +} + #[test] fn roundtrip_null_scalar_values() { let test_types = vec![ diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 287207bae5f6..f46a29447dd6 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -50,12 +50,14 @@ use datafusion::physical_plan::joins::{ }; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::projection::ProjectionExec; +use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::windows::{ BuiltInWindowExpr, PlainAggregateWindowExpr, WindowAggExec, }; use datafusion::physical_plan::{ - functions, udaf, AggregateExpr, ExecutionPlan, PhysicalExpr, Statistics, + functions, udaf, AggregateExpr, 
ExecutionPlan, Partitioning, PhysicalExpr, Statistics, }; use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; @@ -231,21 +233,21 @@ fn roundtrip_window() -> Result<()> { }; let builtin_window_expr = Arc::new(BuiltInWindowExpr::new( - Arc::new(NthValue::first( - "FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", - col("a", &schema)?, - DataType::Int64, - )), - &[col("b", &schema)?], - &[PhysicalSortExpr { - expr: col("a", &schema)?, - options: SortOptions { - descending: false, - nulls_first: false, - }, - }], - Arc::new(window_frame), - )); + Arc::new(NthValue::first( + "FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", + col("a", &schema)?, + DataType::Int64, + )), + &[col("b", &schema)?], + &[PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }], + Arc::new(window_frame), + )); let plain_aggr_window_expr = Arc::new(PlainAggregateWindowExpr::new( Arc::new(Avg::new( @@ -798,3 +800,34 @@ fn roundtrip_sym_hash_join() -> Result<()> { } Ok(()) } + +#[test] +fn roundtrip_union() -> Result<()> { + let field_a = Field::new("col", DataType::Int64, false); + let schema_left = Schema::new(vec![field_a.clone()]); + let schema_right = Schema::new(vec![field_a]); + let left = EmptyExec::new(false, Arc::new(schema_left)); + let right = EmptyExec::new(false, Arc::new(schema_right)); + let inputs: Vec> = vec![Arc::new(left), Arc::new(right)]; + let union = UnionExec::new(inputs); + roundtrip_test(Arc::new(union)) +} + +#[test] +fn roundtrip_interleave() -> Result<()> { + let field_a = Field::new("col", DataType::Int64, false); + let schema_left = Schema::new(vec![field_a.clone()]); + let schema_right = Schema::new(vec![field_a]); + let partition = Partitioning::Hash(vec![], 3); + let left = RepartitionExec::try_new( + Arc::new(EmptyExec::new(false, Arc::new(schema_left))), + partition.clone(), + )?; + let right = RepartitionExec::try_new( + Arc::new(EmptyExec::new(false, Arc::new(schema_right))), + partition.clone(), + )?; + let inputs: Vec> = vec![Arc::new(left), Arc::new(right)]; + let interleave = InterleaveExec::try_new(inputs)?; + roundtrip_test(Arc::new(interleave)) +} diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 14ea20c3fa5f..73de4fa43907 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -21,7 +21,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::function::suggest_valid_function; -use datafusion_expr::window_frame::regularize; +use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, WindowFunction, @@ -92,7 +92,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .into_iter() .map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)) .collect::>>()?; - let order_by = self.order_by_to_sort_expr( + let mut order_by = self.order_by_to_sort_expr( &window.order_by, schema, planner_context, @@ -104,14 +104,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .as_ref() .map(|window_frame| { let window_frame = window_frame.clone().try_into()?; - regularize(window_frame, order_by.len()) + check_window_frame(&window_frame, order_by.len()) + .map(|_| window_frame) }) .transpose()?; + let window_frame = if let 
Some(window_frame) = window_frame { + regularize_window_order_by(&window_frame, &mut order_by)?; window_frame } else { WindowFrame::new(!order_by.is_empty()) }; + if let Ok(fun) = self.find_window_func(&name) { let expr = match fun { WindowFunction::AggregateFunction(aggregate_fun) => { diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 6fc7e9601243..b233f47a058f 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -36,7 +36,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { name, alias, args, .. } => { if let Some(func_args) = args { - let tbl_func_name = name.0.get(0).unwrap().value.to_string(); + let tbl_func_name = name.0.first().unwrap().value.to_string(); let args = func_args .into_iter() .flat_map(|arg| { diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 15f720d75652..a0819e4aaf8e 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -25,10 +25,7 @@ use crate::utils::{ }; use datafusion_common::Column; -use datafusion_common::{ - get_target_functional_dependencies, not_impl_err, plan_err, DFSchemaRef, - DataFusionError, Result, -}; +use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_expr::expr::Alias; use datafusion_expr::expr_rewriter::{ normalize_col, normalize_col_with_schemas_and_ambiguity_check, @@ -534,14 +531,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { group_by_exprs: &[Expr], aggr_exprs: &[Expr], ) -> Result<(LogicalPlan, Vec, Option)> { - let group_by_exprs = - get_updated_group_by_exprs(group_by_exprs, select_exprs, input.schema())?; - // create the aggregate plan let plan = LogicalPlanBuilder::from(input.clone()) - .aggregate(group_by_exprs.clone(), aggr_exprs.to_vec())? + .aggregate(group_by_exprs.to_vec(), aggr_exprs.to_vec())? .build()?; + let group_by_exprs = if let LogicalPlan::Aggregate(agg) = &plan { + &agg.group_expr + } else { + unreachable!(); + }; + // in this next section of code we are re-writing the projection to refer to columns // output by the aggregate plan. For example, if the projection contains the expression // `SUM(a)` then we replace that with a reference to a column `SUM(a)` produced by @@ -550,7 +550,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // combine the original grouping and aggregate expressions into one list (note that // we do not add the "having" expression since that is not part of the projection) let mut aggr_projection_exprs = vec![]; - for expr in &group_by_exprs { + for expr in group_by_exprs { match expr { Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { aggr_projection_exprs.extend_from_slice(exprs) @@ -659,61 +659,3 @@ fn match_window_definitions( } Ok(()) } - -/// Update group by exprs, according to functional dependencies -/// The query below -/// -/// SELECT sn, amount -/// FROM sales_global -/// GROUP BY sn -/// -/// cannot be calculated, because it has a column(`amount`) which is not -/// part of group by expression. -/// However, if we know that, `sn` is determinant of `amount`. We can -/// safely, determine value of `amount` for each distinct `sn`. For these cases -/// we rewrite the query above as -/// -/// SELECT sn, amount -/// FROM sales_global -/// GROUP BY sn, amount -/// -/// Both queries, are functionally same. \[Because, (`sn`, `amount`) and (`sn`) -/// defines the identical groups. 
\] -/// This function updates group by expressions such that select expressions that are -/// not in group by expression, are added to the group by expressions if they are dependent -/// of the sub-set of group by expressions. -fn get_updated_group_by_exprs( - group_by_exprs: &[Expr], - select_exprs: &[Expr], - schema: &DFSchemaRef, -) -> Result> { - let mut new_group_by_exprs = group_by_exprs.to_vec(); - let fields = schema.fields(); - let group_by_expr_names = group_by_exprs - .iter() - .map(|group_by_expr| group_by_expr.display_name()) - .collect::>>()?; - // Get targets that can be used in a select, even if they do not occur in aggregation: - if let Some(target_indices) = - get_target_functional_dependencies(schema, &group_by_expr_names) - { - // Calculate dependent fields names with determinant GROUP BY expression: - let associated_field_names = target_indices - .iter() - .map(|idx| fields[*idx].qualified_name()) - .collect::>(); - // Expand GROUP BY expressions with select expressions: If a GROUP - // BY expression is a determinant key, we can use its dependent - // columns in select statements also. - for expr in select_exprs { - let expr_name = format!("{}", expr); - if !new_group_by_exprs.contains(expr) - && associated_field_names.contains(&expr_name) - { - new_group_by_exprs.push(expr.clone()); - } - } - } - - Ok(new_group_by_exprs) -} diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index e4718035a58d..7cfc9c707d43 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3199,3 +3199,16 @@ FROM my_data GROUP BY dummy ---- text1, text1, text1 + + +# Queries with nested count(*) + +query I +select count(*) from (select count(*) from (select 1)); +---- +1 + +query I +select count(*) from (select count(*) a, count(*) b from (select 1)); +---- +1 \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index 1d6d7dc671fa..5248ac8c8531 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -3211,6 +3211,21 @@ SELECT s.sn, s.amount, 2*s.sn 3 200 6 4 100 8 +# we should be able to re-write group by expression +# using functional dependencies for complex expressions also. +# In this case, we use 2*s.amount instead of s.amount. 
diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt
index 1d6d7dc671fa..5248ac8c8531 100644
--- a/datafusion/sqllogictest/test_files/groupby.slt
+++ b/datafusion/sqllogictest/test_files/groupby.slt
@@ -3211,6 +3211,21 @@ SELECT s.sn, s.amount, 2*s.sn
 3 200 6
 4 100 8
 
+# We should be able to rewrite GROUP BY expressions using functional
+# dependencies for complex expressions as well.
+# In this case, we use 2*s.amount instead of s.amount.
+query IRI
+SELECT s.sn, 2*s.amount, 2*s.sn
+  FROM sales_global_with_pk AS s
+  GROUP BY sn
+  ORDER BY sn
+----
+0 60 0
+1 100 2
+2 150 4
+3 400 6
+4 200 8
+
 query IRI
 SELECT s.sn, s.amount, 2*s.sn
   FROM sales_global_with_pk_alternate AS s
@@ -3364,7 +3379,7 @@ SELECT column1, COUNT(*) as column2 FROM (VALUES (['a', 'b'], 1), (['c', 'd', 'e
 
 
 # primary key should be aware from which columns it is associated
-statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression r.sn could not be resolved from available columns: l.sn, SUM\(l.amount\)
+statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression r.sn could not be resolved from available columns: l.sn, l.zip_code, l.country, l.ts, l.currency, l.amount, SUM\(l.amount\)
 SELECT l.sn, r.sn, SUM(l.amount), r.amount
 FROM sales_global_with_pk AS l
 JOIN sales_global_with_pk AS r
@@ -3456,7 +3471,7 @@ ORDER BY r.sn
 4 100 2022-01-03T10:00:00
 
 
 # after join, new window expressions shouldn't be associated with primary keys
-statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression rn1 could not be resolved from available columns: r.sn, SUM\(r.amount\)
+statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression rn1 could not be resolved from available columns: r.sn, r.ts, r.amount, SUM\(r.amount\)
 SELECT r.sn, SUM(r.amount), rn1
 FROM (SELECT r.ts, r.sn, r.amount,
@@ -3784,6 +3799,192 @@ AggregateExec: mode=FinalPartitioned, gby=[c@0 as c, b@1 as b], aggr=[SUM(multip
 ----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1
 ------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true
 
+statement ok
+set datafusion.execution.target_partitions = 1;
+
+query TT
+EXPLAIN SELECT c, sum1
+  FROM
+    (SELECT c, b, a, SUM(d) as sum1
+     FROM multiple_ordered_table_with_pk
+     GROUP BY c)
+GROUP BY c;
+----
+logical_plan
+Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, sum1]], aggr=[[]]
+--Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1
+----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]]
+------TableScan: multiple_ordered_table_with_pk projection=[c, d]
+physical_plan
+AggregateExec: mode=Single, gby=[c@0 as c, sum1@1 as sum1], aggr=[], ordering_mode=PartiallySorted([0])
+--ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1]
+----AggregateExec: mode=Single, gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted
+------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true
+
+query TT
+EXPLAIN SELECT c, sum1, SUM(b) OVER() as sumb
+  FROM
+    (SELECT c, b, a, SUM(d) as sum1
+     FROM multiple_ordered_table_with_pk
+     GROUP BY c);
+----
+logical_plan
+Projection: multiple_ordered_table_with_pk.c, sum1, SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS sumb
+--WindowAggr: windowExpr=[[SUM(CAST(multiple_ordered_table_with_pk.b AS Int64)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
+----Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1
+------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]]
+--------TableScan: multiple_ordered_table_with_pk projection=[b, c, d]
+physical_plan
+ProjectionExec: expr=[c@0 as c, sum1@2 as sum1, SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@3 as sumb]
+--WindowAggExec: wdw=[SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }]
+----ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1]
+------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0])
+--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true
+
+query TT
+EXPLAIN SELECT lhs.c, rhs.c, lhs.sum1, rhs.sum1
+  FROM
+    (SELECT c, b, a, SUM(d) as sum1
+     FROM multiple_ordered_table_with_pk
+     GROUP BY c) as lhs
+  JOIN
+    (SELECT c, b, a, SUM(d) as sum1
+     FROM multiple_ordered_table_with_pk
+     GROUP BY c) as rhs
+  ON lhs.b=rhs.b;
+----
+logical_plan
+Projection: lhs.c, rhs.c, lhs.sum1, rhs.sum1
+--Inner Join: lhs.b = rhs.b
+----SubqueryAlias: lhs
+------Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1
+--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]]
+----------TableScan: multiple_ordered_table_with_pk projection=[b, c, d]
+----SubqueryAlias: rhs
+------Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1
+--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]]
+----------TableScan: multiple_ordered_table_with_pk projection=[b, c, d]
+physical_plan
+ProjectionExec: expr=[c@0 as c, c@3 as c, sum1@2 as sum1, sum1@5 as sum1]
+--CoalesceBatchesExec: target_batch_size=2
+----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(b@1, b@1)]
+------ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1]
+--------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0])
+----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true
+------ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1]
+--------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0])
+----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true
+
+query TT
+EXPLAIN SELECT lhs.c, rhs.c, lhs.sum1, rhs.sum1
+  FROM
+    (SELECT c, b, a, SUM(d) as sum1
+     FROM multiple_ordered_table_with_pk
+     GROUP BY c) as lhs
+  CROSS JOIN
+    (SELECT c, b, a, SUM(d) as sum1
+     FROM multiple_ordered_table_with_pk
+     GROUP BY c) as rhs;
+----
+logical_plan
+Projection: lhs.c, rhs.c, lhs.sum1, rhs.sum1
+--CrossJoin:
+----SubqueryAlias: lhs
+------Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1
+--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]]
+----------TableScan: multiple_ordered_table_with_pk projection=[c, d]
+----SubqueryAlias: rhs
+------Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1
+--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]]
+----------TableScan: multiple_ordered_table_with_pk projection=[c, d]
+physical_plan
+ProjectionExec: expr=[c@0 as c, c@2 as c, sum1@1 as sum1, sum1@3 as sum1]
+--CrossJoinExec
+----ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1]
+------AggregateExec: mode=Single, gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted
+--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true
+----ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1]
+------AggregateExec: mode=Single, gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted
+--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true
+
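The groupby.slt additions above all lean on the same property: `c` is the primary key of `multiple_ordered_table_with_pk`, so once `c` appears in the GROUP BY list every other column of that table is functionally determined and may appear in the projection without its own aggregate. The sketch below illustrates that check using only the standard library; it is a conceptual illustration, not DataFusion's implementation, and the helper name is invented.

```rust
use std::collections::HashSet;

/// Conceptual check: a projected column is valid if it is a grouping key, or if
/// the grouping keys contain the table's primary key (which determines all columns).
fn projection_is_valid<'a>(
    group_keys: &HashSet<&'a str>,
    primary_key: &[&'a str],
    table_columns: &[&'a str],
    projected: &[&'a str],
) -> bool {
    let mut determined: HashSet<&'a str> = group_keys.clone();
    if primary_key.iter().all(|k| group_keys.contains(k)) {
        determined.extend(table_columns.iter().copied());
    }
    projected.iter().all(|c| determined.contains(c))
}

fn main() {
    // GROUP BY c, where c is the primary key: a, b and d stay addressable.
    let group_keys: HashSet<&str> = ["c"].into_iter().collect();
    assert!(projection_is_valid(
        &group_keys,
        &["c"],
        &["a", "b", "c", "d"],
        &["c", "b", "a"],
    ));
}
```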
+# We do not generate a physical plan for Repartition yet (e.g. DISTRIBUTE BY queries).
+query TT
+EXPLAIN SELECT a, b, sum1
+FROM (SELECT c, b, a, SUM(d) as sum1
+      FROM multiple_ordered_table_with_pk
+      GROUP BY c)
+DISTRIBUTE BY a
+----
+logical_plan
+Repartition: DistributeBy(a)
+--Projection: multiple_ordered_table_with_pk.a, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1
+----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]]
+------TableScan: multiple_ordered_table_with_pk projection=[a, b, c, d]
+
+# union with aggregate
+query TT
+EXPLAIN SELECT c, a, SUM(d) as sum1
+  FROM multiple_ordered_table_with_pk
+  GROUP BY c
+UNION ALL
+  SELECT c, a, SUM(d) as sum1
+  FROM multiple_ordered_table_with_pk
+  GROUP BY c
+----
+logical_plan
+Union
+--Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1
+----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]]
+------TableScan: multiple_ordered_table_with_pk projection=[a, c, d]
+--Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1
+----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]]
+------TableScan: multiple_ordered_table_with_pk projection=[a, c, d]
+physical_plan
+UnionExec
+--ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1]
+----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted
+------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true
+--ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1]
+----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted
+------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true
+
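The EXPLAIN expectations in this block are only stable because it runs with `datafusion.execution.target_partitions = 1` (set above and reset to 8 further down); with more partitions, RepartitionExec nodes would appear and the plan text would change. A small sketch of the equivalent programmatic setting, assuming the `SessionConfig::with_target_partitions` builder re-exported from `datafusion::prelude` (illustration only, not part of this patch):

```rust
use datafusion::prelude::{SessionConfig, SessionContext};

/// Build a context whose physical plans target a single partition, mirroring
/// `set datafusion.execution.target_partitions = 1` in the tests above.
fn single_partition_context() -> SessionContext {
    let config = SessionConfig::new().with_target_partitions(1);
    SessionContext::new_with_config(config)
}

fn main() {
    // With one target partition, round-robin repartitioning is not inserted.
    let _ctx = single_partition_context();
}
```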
+# table scan should be simplified.
+query TT
+EXPLAIN SELECT c, a, SUM(d) as sum1
+  FROM multiple_ordered_table_with_pk
+  GROUP BY c
+----
+logical_plan
+Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1
+--Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]]
+----TableScan: multiple_ordered_table_with_pk projection=[a, c, d]
+physical_plan
+ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1]
+--AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted
+----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true
+
+# limit should be simplified
+query TT
+EXPLAIN SELECT *
+  FROM (SELECT c, a, SUM(d) as sum1
+        FROM multiple_ordered_table_with_pk
+        GROUP BY c
+        LIMIT 5)
+----
+logical_plan
+Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1
+--Limit: skip=0, fetch=5
+----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]]
+------TableScan: multiple_ordered_table_with_pk projection=[a, c, d]
+physical_plan
+ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1]
+--GlobalLimitExec: skip=0, fetch=5
+----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted
+------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true
+
+statement ok
+set datafusion.execution.target_partitions = 8;
+
 # Tests for single distinct to group by optimization rule
 statement ok
 CREATE TABLE t(x int) AS VALUES (1), (2), (1);
diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt
index 3f3befd85a59..ea570b99d4dd 100644
--- a/datafusion/sqllogictest/test_files/select.slt
+++ b/datafusion/sqllogictest/test_files/select.slt
@@ -1056,3 +1056,51 @@ drop table annotated_data_finite2;
 
 statement ok
 drop table t;
+
+statement ok
+create table t(x bigint, y bigint) as values (1,2), (1,3);
+
+query II
+select z+1, y from (select x+1 as z, y from t) where y > 1;
+----
+3 2
+3 3
+
+query TT
+EXPLAIN SELECT x/2, x/2+1 FROM t;
+----
+logical_plan
+Projection: t.x / Int64(2)Int64(2)t.x AS t.x / Int64(2), t.x / Int64(2)Int64(2)t.x AS t.x / Int64(2) + Int64(1)
+--Projection: t.x / Int64(2) AS t.x / Int64(2)Int64(2)t.x
+----TableScan: t projection=[x]
+physical_plan
+ProjectionExec: expr=[t.x / Int64(2)Int64(2)t.x@0 as t.x / Int64(2), t.x / Int64(2)Int64(2)t.x@0 + 1 as t.x / Int64(2) + Int64(1)]
+--ProjectionExec: expr=[x@0 / 2 as t.x / Int64(2)Int64(2)t.x]
+----MemoryExec: partitions=1, partition_sizes=[1]
+
+query II
+SELECT x/2, x/2+1 FROM t;
+----
+0 1
+0 1
+
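The EXPLAIN output above shows the optimizer computing `t.x / Int64(2)` once in an inner projection (under a generated alias) and referencing it from both output expressions; the abs() test that follows exercises the same rewrite. The snippet below is only a toy illustration of that dedup-by-display-text idea; the real rule lives in DataFusion's common-subexpression-elimination optimizer, not in this code.

```rust
use std::collections::HashMap;

/// Deduplicate expressions by their display text, returning the unique
/// expressions plus, for each input, the index of the shared computation.
fn dedup_exprs(exprs: &[String]) -> (Vec<String>, Vec<usize>) {
    let mut unique = Vec::new();
    let mut index_of = HashMap::new();
    let indices = exprs
        .iter()
        .map(|e| {
            *index_of.entry(e.clone()).or_insert_with(|| {
                unique.push(e.clone());
                unique.len() - 1
            })
        })
        .collect();
    (unique, indices)
}

fn main() {
    // The two occurrences of `x / 2` collapse into one shared computation.
    let (unique, indices) = dedup_exprs(&["x / 2".to_string(), "x / 2".to_string()]);
    assert_eq!(unique.len(), 1);
    assert_eq!(indices, vec![0, 0]);
}
```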
+query TT
+EXPLAIN SELECT abs(x), abs(x) + abs(y) FROM t;
+----
+logical_plan
+Projection: abs(t.x)t.x AS abs(t.x), abs(t.x)t.x AS abs(t.x) + abs(t.y)
+--Projection: abs(t.x) AS abs(t.x)t.x, t.y
+----TableScan: t projection=[x, y]
+physical_plan
+ProjectionExec: expr=[abs(t.x)t.x@0 as abs(t.x), abs(t.x)t.x@0 + abs(y@1) as abs(t.x) + abs(t.y)]
+--ProjectionExec: expr=[abs(x@0) as abs(t.x)t.x, y@1 as y]
+----MemoryExec: partitions=1, partition_sizes=[1]
+
+query II
+SELECT abs(x), abs(x) + abs(y) FROM t;
+----
+1 3
+1 4
+
+statement ok
+DROP TABLE t;
diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt
index 0179431ac8ad..7846bb001a91 100644
--- a/datafusion/sqllogictest/test_files/window.slt
+++ b/datafusion/sqllogictest/test_files/window.slt
@@ -3763,20 +3763,20 @@ select a,
 1 1
 2 2
 
-# TODO: this is different to Postgres which returns [1, 1] for `rnk`.
-# Comment it because it is flaky now as it depends on the order of the `a` column.
-# query II
-# select a,
-#   rank() over (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk
-# from (select 1 a union select 2 a) q ORDER BY rnk
-# ----
-# 1 1
-# 2 2
-
-# TODO: this works in Postgres which returns [1, 1].
-query error DataFusion error: Arrow error: Invalid argument error: must either specify a row count or at least one column
+query II
+select a,
+  rank() over (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk
+  from (select 1 a union select 2 a) q ORDER BY a
+----
+1 1
+2 1
+
+query I
 select rank() over (RANGE between UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk
 from (select 1 a union select 2 a) q;
+----
+1
+1
 
 query II
 select a,
@@ -3785,3 +3785,11 @@ select a,
 ----
 1 1
 2 1
+
+query II
+select a,
+  rank() over (order by null RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk
+  from (select 1 a union select 2 a) q ORDER BY a
+----
+1 1
+2 1
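The window.slt changes encode the frame semantics behind these expectations: when rank() has no ORDER BY (or ORDER BY NULL), every row is a peer of every other row, so RANK() is 1 for all rows, matching Postgres. The standalone sketch below models that rule for rows already sorted by their ORDER BY key, with None standing in for "no ordering key"; it is an illustration, not DataFusion's window implementation.

```rust
/// RANK() over rows already sorted by `sorted_keys`: a row's rank is 1 plus the
/// number of preceding rows that are not its peers (i.e. have a different key).
fn rank_over<T: PartialEq>(sorted_keys: &[Option<T>]) -> Vec<usize> {
    let mut ranks = Vec::with_capacity(sorted_keys.len());
    for (i, key) in sorted_keys.iter().enumerate() {
        let non_peers = sorted_keys[..i].iter().filter(|prev| *prev != key).count();
        ranks.push(1 + non_peers);
    }
    ranks
}

fn main() {
    // With an ordering key, distinct values get ranks 1 and 2.
    assert_eq!(rank_over(&[Some(1), Some(2)]), vec![1, 2]);
    // With no ordering key (ORDER BY NULL), all rows are peers: every rank is 1.
    assert_eq!(rank_over::<i32>(&[None, None]), vec![1, 1]);
}
```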