From 8db30e25d6fe65a9779d237cf48aea9aee297502 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Tue, 3 Sep 2024 07:57:40 +0800 Subject: [PATCH 1/6] Introduce `Signature::Coercible` (#12275) * introduce signature float Signed-off-by: jayzhan211 * fix test Signed-off-by: jayzhan211 * change float to coercible Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * rm test Signed-off-by: jayzhan211 * add comment Signed-off-by: jayzhan211 * typo Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/core/src/dataframe/mod.rs | 6 ++-- datafusion/expr-common/src/signature.rs | 15 ++++++++- .../src/type_coercion/aggregates.rs | 6 ++-- .../expr/src/type_coercion/functions.rs | 33 ++++++++++++++++++- datafusion/functions-aggregate/src/stddev.rs | 11 +++---- .../functions-aggregate/src/variance.rs | 11 +++---- 6 files changed, 64 insertions(+), 18 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index b8c0bd9d744c..2138bd1294b4 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -2428,7 +2428,8 @@ mod tests { let df: Vec = df.select(aggr_expr)?.collect().await?; assert_batches_sorted_eq!( - ["+-------------+----------+-----------------+---------------+--------+-----+------+----+------+", + [ + "+-------------+----------+-----------------+---------------+--------+-----+------+----+------+", "| first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 |", "+-------------+----------+-----------------+---------------+--------+-----+------+----+------+", "| | | | | | | | 1 | -85 |", @@ -2452,7 +2453,8 @@ mod tests { "| -85 | 45 | 8 | -34 | 45 | 83 | -85 | 3 | -72 |", "| -85 | 65 | 17 | -17 | 65 | 83 | -101 | 5 | -101 |", "| -85 | 83 | 5 | -25 | 83 | 83 | -85 | 2 | -48 |", - "+-------------+----------+-----------------+---------------+--------+-----+------+----+------+"], + "+-------------+----------+-----------------+---------------+--------+-----+------+----+------+", + ], &df ); diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 4dcfa423e371..2043757a49fb 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -105,6 +105,11 @@ pub enum TypeSignature { Uniform(usize, Vec), /// Exact number of arguments of an exact type Exact(Vec), + /// The number of arguments that can be coerced to in order + /// For example, `Coercible(vec![DataType::Float64])` accepts + /// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]` + /// since i32 and f32 can be casted to f64 + Coercible(Vec), /// Fixed number of arguments of arbitrary types /// If a function takes 0 argument, its `TypeSignature` should be `Any(0)` Any(usize), @@ -188,7 +193,7 @@ impl TypeSignature { TypeSignature::Numeric(num) => { vec![format!("Numeric({})", num)] } - TypeSignature::Exact(types) => { + TypeSignature::Exact(types) | TypeSignature::Coercible(types) => { vec![Self::join_types(types, ", ")] } TypeSignature::Any(arg_count) => { @@ -300,6 +305,14 @@ impl Signature { volatility, } } + /// Target coerce types in order + pub fn coercible(target_types: Vec, volatility: Volatility) -> Self { + Self { + type_signature: TypeSignature::Coercible(target_types), + volatility, + } + } + /// A specified number of arguments of any type pub fn any(arg_count: usize, volatility: Volatility) -> Self { Signature { diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs b/datafusion/expr-common/src/type_coercion/aggregates.rs index 40ee596eee05..2add9e7c1867 100644 --- a/datafusion/expr-common/src/type_coercion/aggregates.rs +++ b/datafusion/expr-common/src/type_coercion/aggregates.rs @@ -128,9 +128,11 @@ pub fn check_arg_count( ); } } - TypeSignature::UserDefined | TypeSignature::Numeric(_) => { + TypeSignature::UserDefined + | TypeSignature::Numeric(_) + | TypeSignature::Coercible(_) => { // User-defined signature is validated in `coerce_types` - // Numreic signature is validated in `get_valid_types` + // Numeric and Coercible signature is validated in `get_valid_types` } _ => { return internal_err!( diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index b0b14a1a4e6e..d30d202df050 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -175,7 +175,14 @@ fn try_coerce_types( let mut valid_types = valid_types; // Well-supported signature that returns exact valid types. - if !valid_types.is_empty() && matches!(type_signature, TypeSignature::UserDefined) { + if !valid_types.is_empty() + && matches!( + type_signature, + TypeSignature::UserDefined + | TypeSignature::Numeric(_) + | TypeSignature::Coercible(_) + ) + { // exact valid types assert_eq!(valid_types.len(), 1); let valid_types = valid_types.swap_remove(0); @@ -397,6 +404,30 @@ fn get_valid_types( vec![vec![valid_type; *number]] } + TypeSignature::Coercible(target_types) => { + if target_types.is_empty() { + return plan_err!( + "The signature expected at least one argument but received {}", + current_types.len() + ); + } + if target_types.len() != current_types.len() { + return plan_err!( + "The signature expected {} arguments but received {}", + target_types.len(), + current_types.len() + ); + } + + for (data_type, target_type) in current_types.iter().zip(target_types.iter()) + { + if !can_cast_types(data_type, target_type) { + return plan_err!("{data_type} is not coercible to {target_type}"); + } + } + + vec![target_types.to_owned()] + } TypeSignature::Uniform(number, valid_types) => valid_types .iter() .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 3534fb5b4d26..a25ab5e31991 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -68,7 +68,10 @@ impl Stddev { /// Create a new STDDEV aggregate function pub fn new() -> Self { Self { - signature: Signature::numeric(1, Volatility::Immutable), + signature: Signature::coercible( + vec![DataType::Float64], + Volatility::Immutable, + ), alias: vec!["stddev_samp".to_string()], } } @@ -88,11 +91,7 @@ impl AggregateUDFImpl for Stddev { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("Stddev requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index f5f2d06e3837..367a8669ab7d 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -79,7 +79,10 @@ impl VarianceSample { pub fn new() -> Self { Self { aliases: vec![String::from("var_sample"), String::from("var_samp")], - signature: Signature::numeric(1, Volatility::Immutable), + signature: Signature::coercible( + vec![DataType::Float64], + Volatility::Immutable, + ), } } } @@ -97,11 +100,7 @@ impl AggregateUDFImpl for VarianceSample { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("Variance requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } From 4e1b6de5549cc7ed399766dd22cafd450701db16 Mon Sep 17 00:00:00 2001 From: iamthinh Date: Tue, 3 Sep 2024 00:35:56 -0700 Subject: [PATCH 2/6] fix partition-by panic (#12297) --- datafusion/physical-plan/src/coalesce/mod.rs | 15 ++++++++++++-- .../physical-plan/src/repartition/mod.rs | 12 ++++++++--- .../sqllogictest/test_files/repartition.slt | 20 +++++++++++++++++++ 3 files changed, 42 insertions(+), 5 deletions(-) diff --git a/datafusion/physical-plan/src/coalesce/mod.rs b/datafusion/physical-plan/src/coalesce/mod.rs index 5befa5ecda99..ce5a1e53abfd 100644 --- a/datafusion/physical-plan/src/coalesce/mod.rs +++ b/datafusion/physical-plan/src/coalesce/mod.rs @@ -18,7 +18,7 @@ use arrow::compute::concat_batches; use arrow_array::builder::StringViewBuilder; use arrow_array::cast::AsArray; -use arrow_array::{Array, ArrayRef, RecordBatch}; +use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions}; use arrow_schema::SchemaRef; use std::sync::Arc; @@ -265,7 +265,9 @@ fn gc_string_view_batch(batch: &RecordBatch) -> RecordBatch { } }) .collect(); - RecordBatch::try_new(batch.schema(), new_columns) + let mut options = RecordBatchOptions::new(); + options = options.with_row_count(Some(batch.num_rows())); + RecordBatch::try_new_with_options(batch.schema(), new_columns, &options) .expect("Failed to re-create the gc'ed record batch") } @@ -501,6 +503,15 @@ mod tests { assert_eq!(array.data_buffers().len(), gc_array.data_buffers().len()); // no compaction } + #[test] + fn test_gc_string_view_test_batch_empty() { + let schema = Schema::empty(); + let batch = RecordBatch::new_empty(schema.into()); + let output_batch = gc_string_view_batch(&batch); + assert_eq!(batch.num_columns(), output_batch.num_columns()); + assert_eq!(batch.num_rows(), output_batch.num_rows()); + } + #[test] fn test_gc_string_view_batch_large_no_compact() { // view with large strings (has buffers) but full --> no need to compact diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 47e5192c237e..093803e3c8b3 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -41,7 +41,7 @@ use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Stat use arrow::array::ArrayRef; use arrow::datatypes::{SchemaRef, UInt64Type}; use arrow::record_batch::RecordBatch; -use arrow_array::PrimitiveArray; +use arrow_array::{PrimitiveArray, RecordBatchOptions}; use datafusion_common::utils::transpose; use datafusion_common::{arrow_datafusion_err, not_impl_err, DataFusionError, Result}; use datafusion_common_runtime::SpawnedTask; @@ -309,8 +309,14 @@ impl BatchPartitioner { }) .collect::>>()?; - let batch = - RecordBatch::try_new(batch.schema(), columns).unwrap(); + let mut options = RecordBatchOptions::new(); + options = options.with_row_count(Some(indices.len())); + let batch = RecordBatch::try_new_with_options( + batch.schema(), + columns, + &options, + ) + .unwrap(); Ok((partition, batch)) }); diff --git a/datafusion/sqllogictest/test_files/repartition.slt b/datafusion/sqllogictest/test_files/repartition.slt index e3c204a4f4ce..2d59ad2b5b0e 100644 --- a/datafusion/sqllogictest/test_files/repartition.slt +++ b/datafusion/sqllogictest/test_files/repartition.slt @@ -127,3 +127,23 @@ physical_plan 04)------FilterExec: c3@2 > 0 05)--------RepartitionExec: partitioning=RoundRobinBatch(3), input_partitions=1 06)----------StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + +# Start repratition on empty column test. +# See https://github.com/apache/datafusion/issues/12057 + +statement ok +CREATE TABLE t1(v1 int); + +statement ok +INSERT INTO t1 values(42); + +query I +SELECT sum(1) OVER (PARTITION BY false=false) +FROM t1 WHERE ((false > (v1 = v1)) IS DISTINCT FROM true); +---- +1 + +statement ok +DROP TABLE t1; + +# End repartition on empty columns test \ No newline at end of file From 4a227c524f9130eb9b911af17a25c23b25aa41fa Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 3 Sep 2024 09:51:19 +0200 Subject: [PATCH 3/6] Remove unsafe Send impl from PriorityMap (#12289) It's not necessary to use unsafe Send impl. It's enough to require the referenced trait objects as Send. --- .../physical-plan/src/aggregates/topk/hash_table.rs | 5 ++++- datafusion/physical-plan/src/aggregates/topk/heap.rs | 6 +++++- .../physical-plan/src/aggregates/topk/priority_map.rs | 9 ++------- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs index 2b02fff1f573..232b87de3231 100644 --- a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs +++ b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs @@ -367,7 +367,10 @@ has_integer!(u8, u16, u32, u64); has_integer!(IntervalDayTime, IntervalMonthDayNano); hash_float!(f16, f32, f64); -pub fn new_hash_table(limit: usize, kt: DataType) -> Result> { +pub fn new_hash_table( + limit: usize, + kt: DataType, +) -> Result> { macro_rules! downcast_helper { ($kt:ty, $d:ident) => { return Ok(Box::new(PrimitiveHashTable::<$kt>::new(limit))) diff --git a/datafusion/physical-plan/src/aggregates/topk/heap.rs b/datafusion/physical-plan/src/aggregates/topk/heap.rs index 81eadbc018b3..e694422e443d 100644 --- a/datafusion/physical-plan/src/aggregates/topk/heap.rs +++ b/datafusion/physical-plan/src/aggregates/topk/heap.rs @@ -459,7 +459,11 @@ compare_integer!(u8, u16, u32, u64); compare_integer!(IntervalDayTime, IntervalMonthDayNano); compare_float!(f16, f32, f64); -pub fn new_heap(limit: usize, desc: bool, vt: DataType) -> Result> { +pub fn new_heap( + limit: usize, + desc: bool, + vt: DataType, +) -> Result> { macro_rules! downcast_helper { ($vt:ty, $d:ident) => { return Ok(Box::new(PrimitiveHeap::<$vt>::new(limit, desc, vt))) diff --git a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs index 668018b9c24c..ed41d22e935b 100644 --- a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs +++ b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs @@ -25,17 +25,12 @@ use datafusion_common::Result; /// A `Map` / `PriorityQueue` combo that evicts the worst values after reaching `capacity` pub struct PriorityMap { - map: Box, - heap: Box, + map: Box, + heap: Box, capacity: usize, mapper: Vec<(usize, usize)>, } -// JUSTIFICATION -// Benefit: ~15% speedup + required to index into RawTable from binary heap -// Soundness: it is only accessed by one thread at a time, and indexes are kept up to date -unsafe impl Send for PriorityMap {} - impl PriorityMap { pub fn new( key_type: DataType, From 6bceeae6c622485c8cfaae34eb60152f5c7cbe01 Mon Sep 17 00:00:00 2001 From: Eduard Karacharov Date: Tue, 3 Sep 2024 20:03:24 +0300 Subject: [PATCH 4/6] test: check record count and types in parquet window test (#12277) * test: check record count and types in paqruet window test * Review suggestions Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb --- .../sqllogictest/test_files/parquet.slt | 32 ++++++++----------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/datafusion/sqllogictest/test_files/parquet.slt b/datafusion/sqllogictest/test_files/parquet.slt index 34d4ed6ff284..9a7b085312bb 100644 --- a/datafusion/sqllogictest/test_files/parquet.slt +++ b/datafusion/sqllogictest/test_files/parquet.slt @@ -251,25 +251,21 @@ SELECT COUNT(*) FROM timestamp_with_tz; ---- 131072 -# Perform the query: -query IPT -SELECT - count, - LAG(timestamp, 1) OVER (ORDER BY timestamp), - arrow_typeof(LAG(timestamp, 1) OVER (ORDER BY timestamp)) -FROM timestamp_with_tz -LIMIT 10; +# Ensure that output timestamp columns preserve the timezone from the input +# and output record count will match input file record count +query TPI +SELECT arrow_typeof(lag_timestamp), + MIN(lag_timestamp), + COUNT(1) +FROM ( + SELECT + count, + LAG(timestamp, 1) OVER (ORDER BY timestamp) AS lag_timestamp + FROM timestamp_with_tz +) t +GROUP BY 1 ---- -0 NULL Timestamp(Millisecond, Some("UTC")) -0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) -0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) -4 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) -0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) -0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) -0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) -14 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) -0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) -0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +Timestamp(Millisecond, Some("UTC")) 2014-08-27T14:00:00Z 131072 # Test config listing_table_ignore_subdirectory: From bf6c82fc18a8243380b8302f9b266344b6955656 Mon Sep 17 00:00:00 2001 From: Alexander Rafferty Date: Wed, 4 Sep 2024 06:12:04 +1000 Subject: [PATCH 5/6] Optimize `struct` and `named_struct` functions (#11688) * Remove unnecessary heap allocations in implementation of `named_struct_expr` caused by zipping then unzipping fields and values. * Change implementation of `array_struct` to reduce number of allocations * Remove tests already covered by `struct.slt` --- datafusion/functions/src/core/named_struct.rs | 21 +++--- datafusion/functions/src/core/struct.rs | 68 ++++--------------- 2 files changed, 21 insertions(+), 68 deletions(-) diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index f71b1b00f0fe..85c332745355 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -70,20 +70,17 @@ fn named_struct_expr(args: &[ColumnarValue]) -> Result { } } - let arrays = ColumnarValue::values_to_arrays(&values)?; - - let fields = names + let fields: Fields = names .into_iter() - .zip(arrays) - .map(|(name, value)| { - ( - Arc::new(Field::new(name, value.data_type().clone(), true)), - value, - ) - }) - .collect::>(); + .zip(&values) + .map(|(name, value)| Arc::new(Field::new(name, value.data_type().clone(), true))) + .collect::>() + .into(); + + let arrays = ColumnarValue::values_to_arrays(&values)?; - Ok(ColumnarValue::Array(Arc::new(StructArray::from(fields)))) + let struct_array = StructArray::new(fields, arrays, None); + Ok(ColumnarValue::Array(Arc::new(struct_array))) } #[derive(Debug)] diff --git a/datafusion/functions/src/core/struct.rs b/datafusion/functions/src/core/struct.rs index c3dee8b1ccb4..bdddbb81beab 100644 --- a/datafusion/functions/src/core/struct.rs +++ b/datafusion/functions/src/core/struct.rs @@ -29,23 +29,23 @@ fn array_struct(args: &[ArrayRef]) -> Result { return exec_err!("struct requires at least one argument"); } - let vec: Vec<_> = args + let fields = args .iter() .enumerate() .map(|(i, arg)| { let field_name = format!("c{i}"); - Ok(( - Arc::new(Field::new( - field_name.as_str(), - arg.data_type().clone(), - true, - )), - Arc::clone(arg), - )) + Ok(Arc::new(Field::new( + field_name.as_str(), + arg.data_type().clone(), + true, + ))) }) - .collect::>>()?; + .collect::>>()? + .into(); - Ok(Arc::new(StructArray::from(vec))) + let arrays = args.to_vec(); + + Ok(Arc::new(StructArray::new(fields, arrays, None))) } /// put values in a struct array. @@ -53,6 +53,7 @@ fn struct_expr(args: &[ColumnarValue]) -> Result { let arrays = ColumnarValue::values_to_arrays(args)?; Ok(ColumnarValue::Array(array_struct(arrays.as_slice())?)) } + #[derive(Debug)] pub struct StructFunc { signature: Signature, @@ -97,48 +98,3 @@ impl ScalarUDFImpl for StructFunc { struct_expr(args) } } - -#[cfg(test)] -mod tests { - use super::*; - use arrow::array::Int64Array; - use datafusion_common::cast::as_struct_array; - use datafusion_common::ScalarValue; - - #[test] - fn test_struct() { - // struct(1, 2, 3) = {"c0": 1, "c1": 2, "c2": 3} - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ]; - let struc = struct_expr(&args) - .expect("failed to initialize function struct") - .into_array(1) - .expect("Failed to convert to array"); - let result = - as_struct_array(&struc).expect("failed to initialize function struct"); - assert_eq!( - &Int64Array::from(vec![1]), - Arc::clone(result.column_by_name("c0").unwrap()) - .as_any() - .downcast_ref::() - .unwrap() - ); - assert_eq!( - &Int64Array::from(vec![2]), - Arc::clone(result.column_by_name("c1").unwrap()) - .as_any() - .downcast_ref::() - .unwrap() - ); - assert_eq!( - &Int64Array::from(vec![3]), - Arc::clone(result.column_by_name("c2").unwrap()) - .as_any() - .downcast_ref::() - .unwrap() - ); - } -} From e4a94243b502da2ad07a358b4401052651952eea Mon Sep 17 00:00:00 2001 From: WeblWabl Date: Tue, 3 Sep 2024 15:12:47 -0500 Subject: [PATCH 6/6] Update the CONCAT scalar function to support Utf8View (#12224) * wip * feat: Update the CONCAT scalar function to support Utf8View * fmt * fmt and add default return type for concat * fix clippy lint Signed-off-by: Devan * fmt Signed-off-by: Devan * add more tests for sqllogic Signed-off-by: Devan * make sure no casting with LargeUtf8 * fixing utf8large * fix large utf8 Signed-off-by: Devan * fix large utf8 Signed-off-by: Devan * add test Signed-off-by: Devan * fmt Signed-off-by: Devan * make it so Utf8View just returns Utf8 Signed-off-by: Devan * wip -- trying to build a stringview with columnar refs Signed-off-by: Devan * built stringview builder but it does allocate a new String each iter :( Signed-off-by: Devan * add some testing Signed-off-by: Devan * clippy Signed-off-by: Devan --------- Signed-off-by: Devan --- datafusion/functions/src/string/common.rs | 195 +++++++++++++++++- datafusion/functions/src/string/concat.rs | 184 ++++++++++++++--- .../sqllogictest/test_files/string_view.slt | 71 ++++++- 3 files changed, 416 insertions(+), 34 deletions(-) diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 9738cb812f90..6ebcc4ee6cd3 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -22,12 +22,11 @@ use std::sync::Arc; use arrow::array::{ new_null_array, Array, ArrayAccessor, ArrayDataBuilder, ArrayIter, ArrayRef, - GenericStringArray, GenericStringBuilder, OffsetSizeTrait, StringArray, - StringBuilder, StringViewArray, + GenericStringArray, GenericStringBuilder, LargeStringArray, OffsetSizeTrait, + StringArray, StringBuilder, StringViewArray, StringViewBuilder, }; use arrow::buffer::{Buffer, MutableBuffer, NullBuffer}; use arrow::datatypes::DataType; - use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::Result; use datafusion_common::{exec_err, ScalarValue}; @@ -249,26 +248,41 @@ where } } +#[derive(Debug)] pub(crate) enum ColumnarValueRef<'a> { Scalar(&'a [u8]), NullableArray(&'a StringArray), NonNullableArray(&'a StringArray), + NullableLargeStringArray(&'a LargeStringArray), + NonNullableLargeStringArray(&'a LargeStringArray), + NullableStringViewArray(&'a StringViewArray), + NonNullableStringViewArray(&'a StringViewArray), } impl<'a> ColumnarValueRef<'a> { #[inline] pub fn is_valid(&self, i: usize) -> bool { match &self { - Self::Scalar(_) | Self::NonNullableArray(_) => true, + Self::Scalar(_) + | Self::NonNullableArray(_) + | Self::NonNullableLargeStringArray(_) + | Self::NonNullableStringViewArray(_) => true, Self::NullableArray(array) => array.is_valid(i), + Self::NullableStringViewArray(array) => array.is_valid(i), + Self::NullableLargeStringArray(array) => array.is_valid(i), } } #[inline] pub fn nulls(&self) -> Option { match &self { - Self::Scalar(_) | Self::NonNullableArray(_) => None, + Self::Scalar(_) + | Self::NonNullableArray(_) + | Self::NonNullableStringViewArray(_) + | Self::NonNullableLargeStringArray(_) => None, Self::NullableArray(array) => array.nulls().cloned(), + Self::NullableStringViewArray(array) => array.nulls().cloned(), + Self::NullableLargeStringArray(array) => array.nulls().cloned(), } } } @@ -387,10 +401,30 @@ impl StringArrayBuilder { .extend_from_slice(array.value(i).as_bytes()); } } + ColumnarValueRef::NullableLargeStringArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + ColumnarValueRef::NullableStringViewArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } ColumnarValueRef::NonNullableArray(array) => { self.value_buffer .extend_from_slice(array.value(i).as_bytes()); } + ColumnarValueRef::NonNullableLargeStringArray(array) => { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + ColumnarValueRef::NonNullableStringViewArray(array) => { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } } } @@ -416,6 +450,157 @@ impl StringArrayBuilder { } } +pub(crate) struct StringViewArrayBuilder { + builder: StringViewBuilder, + block: String, +} + +impl StringViewArrayBuilder { + pub fn with_capacity(_item_capacity: usize, data_capacity: usize) -> Self { + let builder = StringViewBuilder::with_capacity(data_capacity); + Self { + builder, + block: String::new(), + } + } + + pub fn write( + &mut self, + column: &ColumnarValueRef, + i: usize, + ) { + match column { + ColumnarValueRef::Scalar(s) => { + self.block.push_str(std::str::from_utf8(s).unwrap()); + } + ColumnarValueRef::NullableArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.block.push_str( + std::str::from_utf8(array.value(i).as_bytes()).unwrap(), + ); + } + } + ColumnarValueRef::NullableLargeStringArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.block.push_str( + std::str::from_utf8(array.value(i).as_bytes()).unwrap(), + ); + } + } + ColumnarValueRef::NullableStringViewArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.block.push_str( + std::str::from_utf8(array.value(i).as_bytes()).unwrap(), + ); + } + } + ColumnarValueRef::NonNullableArray(array) => { + self.block + .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); + } + ColumnarValueRef::NonNullableLargeStringArray(array) => { + self.block + .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); + } + ColumnarValueRef::NonNullableStringViewArray(array) => { + self.block + .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); + } + } + } + + pub fn append_offset(&mut self) { + self.builder.append_value(&self.block); + self.block = String::new(); + } + + pub fn finish(mut self) -> StringViewArray { + self.builder.finish() + } +} + +pub(crate) struct LargeStringArrayBuilder { + offsets_buffer: MutableBuffer, + value_buffer: MutableBuffer, +} + +impl LargeStringArrayBuilder { + pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { + let mut offsets_buffer = MutableBuffer::with_capacity( + (item_capacity + 1) * std::mem::size_of::(), + ); + // SAFETY: the first offset value is definitely not going to exceed the bounds. + unsafe { offsets_buffer.push_unchecked(0_i64) }; + Self { + offsets_buffer, + value_buffer: MutableBuffer::with_capacity(data_capacity), + } + } + + pub fn write( + &mut self, + column: &ColumnarValueRef, + i: usize, + ) { + match column { + ColumnarValueRef::Scalar(s) => { + self.value_buffer.extend_from_slice(s); + } + ColumnarValueRef::NullableArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + ColumnarValueRef::NullableLargeStringArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + ColumnarValueRef::NullableStringViewArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + ColumnarValueRef::NonNullableArray(array) => { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + ColumnarValueRef::NonNullableLargeStringArray(array) => { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + ColumnarValueRef::NonNullableStringViewArray(array) => { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + } + + pub fn append_offset(&mut self) { + let next_offset: i64 = self + .value_buffer + .len() + .try_into() + .expect("byte array offset overflow"); + unsafe { self.offsets_buffer.push_unchecked(next_offset) }; + } + + pub fn finish(self, null_buffer: Option) -> LargeStringArray { + let array_builder = ArrayDataBuilder::new(DataType::LargeUtf8) + .len(self.offsets_buffer.len() / std::mem::size_of::() - 1) + .add_buffer(self.offsets_buffer.into()) + .add_buffer(self.value_buffer.into()) + .nulls(null_buffer); + // SAFETY: all data that was appended was valid Large UTF8 and the values + // and offsets were created correctly + let array_data = unsafe { array_builder.build_unchecked() }; + LargeStringArray::from(array_data) + } +} + fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result where O: OffsetSizeTrait, diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 6d15e2206721..00fe69b0bd33 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -15,14 +15,13 @@ // specific language governing permissions and limitations // under the License. +use arrow::array::{as_largestring_array, Array}; +use arrow::datatypes::DataType; use std::any::Any; use std::sync::Arc; -use arrow::datatypes::DataType; -use arrow::datatypes::DataType::Utf8; - -use datafusion_common::cast::as_string_array; -use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_common::cast::{as_string_array, as_string_view_array}; +use datafusion_common::{internal_err, plan_err, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{lit, ColumnarValue, Expr, Volatility}; @@ -46,7 +45,10 @@ impl ConcatFunc { pub fn new() -> Self { use DataType::*; Self { - signature: Signature::variadic(vec![Utf8], Volatility::Immutable), + signature: Signature::variadic( + vec![Utf8, Utf8View, LargeUtf8], + Volatility::Immutable, + ), } } } @@ -64,13 +66,36 @@ impl ScalarUDFImpl for ConcatFunc { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Utf8) + fn return_type(&self, arg_types: &[DataType]) -> Result { + use DataType::*; + let mut dt = &Utf8; + arg_types.iter().for_each(|data_type| { + if data_type == &Utf8View { + dt = data_type; + } + if data_type == &LargeUtf8 && dt != &Utf8View { + dt = data_type; + } + }); + + Ok(dt.to_owned()) } /// Concatenates the text representations of all the arguments. NULL arguments are ignored. /// concat('abcde', 2, NULL, 22) = 'abcde222' fn invoke(&self, args: &[ColumnarValue]) -> Result { + let mut return_datatype = DataType::Utf8; + args.iter().for_each(|col| { + if col.data_type() == DataType::Utf8View { + return_datatype = col.data_type(); + } + if col.data_type() == DataType::LargeUtf8 + && return_datatype != DataType::Utf8View + { + return_datatype = col.data_type(); + } + }); + let array_len = args .iter() .filter_map(|x| match x { @@ -87,7 +112,21 @@ impl ScalarUDFImpl for ConcatFunc { result.push_str(v); } } - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))); + + return match return_datatype { + DataType::Utf8View => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result)))) + } + DataType::Utf8 => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))) + } + DataType::LargeUtf8 => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result)))) + } + other => { + plan_err!("Concat function does not support datatype of {other}") + } + }; } // Array @@ -103,28 +142,95 @@ impl ScalarUDFImpl for ConcatFunc { columns.push(ColumnarValueRef::Scalar(s.as_bytes())); } } + ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => { + if let Some(s) = maybe_value { + data_size += s.len() * len; + columns.push(ColumnarValueRef::Scalar(s.as_bytes())); + } + } ColumnarValue::Array(array) => { - let string_array = as_string_array(array)?; - data_size += string_array.values().len(); - let column = if array.is_nullable() { - ColumnarValueRef::NullableArray(string_array) - } else { - ColumnarValueRef::NonNullableArray(string_array) + match array.data_type() { + DataType::Utf8 => { + let string_array = as_string_array(array)?; + + data_size += string_array.values().len(); + let column = if array.is_nullable() { + ColumnarValueRef::NullableArray(string_array) + } else { + ColumnarValueRef::NonNullableArray(string_array) + }; + columns.push(column); + }, + DataType::LargeUtf8 => { + let string_array = as_largestring_array(array); + + data_size += string_array.values().len(); + let column = if array.is_nullable() { + ColumnarValueRef::NullableLargeStringArray(string_array) + } else { + ColumnarValueRef::NonNullableLargeStringArray(string_array) + }; + columns.push(column); + }, + DataType::Utf8View => { + let string_array = as_string_view_array(array)?; + + data_size += string_array.len(); + let column = if array.is_nullable() { + ColumnarValueRef::NullableStringViewArray(string_array) + } else { + ColumnarValueRef::NonNullableStringViewArray(string_array) + }; + columns.push(column); + }, + other => { + return plan_err!("Input was {other} which is not a supported datatype for concat function") + } }; - columns.push(column); } _ => unreachable!(), } } - let mut builder = StringArrayBuilder::with_capacity(len, data_size); - for i in 0..len { - columns - .iter() - .for_each(|column| builder.write::(column, i)); - builder.append_offset(); + match return_datatype { + DataType::Utf8 => { + let mut builder = StringArrayBuilder::with_capacity(len, data_size); + for i in 0..len { + columns + .iter() + .for_each(|column| builder.write::(column, i)); + builder.append_offset(); + } + + let string_array = builder.finish(None); + Ok(ColumnarValue::Array(Arc::new(string_array))) + } + DataType::Utf8View => { + let mut builder = StringViewArrayBuilder::with_capacity(len, data_size); + for i in 0..len { + columns + .iter() + .for_each(|column| builder.write::(column, i)); + builder.append_offset(); + } + + let string_array = builder.finish(); + Ok(ColumnarValue::Array(Arc::new(string_array))) + } + DataType::LargeUtf8 => { + let mut builder = LargeStringArrayBuilder::with_capacity(len, data_size); + for i in 0..len { + columns + .iter() + .for_each(|column| builder.write::(column, i)); + builder.append_offset(); + } + + let string_array = builder.finish(None); + Ok(ColumnarValue::Array(Arc::new(string_array))) + } + _ => unreachable!(), } - Ok(ColumnarValue::Array(Arc::new(builder.finish(None)))) } /// Simplify the `concat` function by @@ -151,11 +257,11 @@ pub fn simplify_concat(args: Vec) -> Result { for arg in args.clone() { match arg { // filter out `null` args - Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {} + Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None)) => {} // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. // Concatenate it with the `contiguous_scalar`. Expr::Literal( - ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)), + ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)), ) => contiguous_scalar += &v, Expr::Literal(x) => { return internal_err!( @@ -195,8 +301,9 @@ pub fn simplify_concat(args: Vec) -> Result { mod tests { use super::*; use crate::utils::test::test_function; - use arrow::array::Array; + use arrow::array::{Array, LargeStringArray, StringViewArray}; use arrow::array::{ArrayRef, StringArray}; + use DataType::*; #[test] fn test_functions() -> Result<()> { @@ -232,6 +339,31 @@ mod tests { Utf8, StringArray ); + test_function!( + ConcatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("aa")), + ColumnarValue::Scalar(ScalarValue::Utf8View(None)), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)), + ColumnarValue::Scalar(ScalarValue::from("cc")), + ], + Ok(Some("aacc")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + ConcatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("aa")), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)), + ColumnarValue::Scalar(ScalarValue::from("cc")), + ], + Ok(Some("aacc")), + &str, + LargeUtf8, + LargeStringArray + ); Ok(()) } diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index 83c75b8df38c..eb625e530b66 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -768,17 +768,26 @@ logical_plan 01)Projection: character_length(test.column1_utf8view) AS l 02)--TableScan: test projection=[column1_utf8view] -## Ensure no casts for CONCAT -## TODO https://github.com/apache/datafusion/issues/11836 +## Ensure no casts for CONCAT Utf8View query TT EXPLAIN SELECT concat(column1_utf8view, column2_utf8view) as c FROM test; ---- logical_plan -01)Projection: concat(CAST(test.column1_utf8view AS Utf8), CAST(test.column2_utf8view AS Utf8)) AS c +01)Projection: concat(test.column1_utf8view, test.column2_utf8view) AS c 02)--TableScan: test projection=[column1_utf8view, column2_utf8view] +## Ensure no casts for CONCAT LargeUtf8 +query TT +EXPLAIN SELECT + concat(column1_large_utf8, column2_large_utf8) as c +FROM test; +---- +logical_plan +01)Projection: concat(test.column1_large_utf8, test.column2_large_utf8) AS c +02)--TableScan: test projection=[column1_large_utf8, column2_large_utf8] + ## Ensure no casts for CONCAT_WS ## TODO https://github.com/apache/datafusion/issues/11837 query TT @@ -863,6 +872,61 @@ XIANGPENG RAPHAEL NULL +## Should run CONCAT successfully with utf8view +query T +SELECT + concat(column1_utf8view, column2_utf8view) as c +FROM test; +---- +AndrewX +XiangpengXiangpeng +RaphaelR +R + +## Should run CONCAT successfully with utf8 +query T +SELECT + concat(column1_utf8, column2_utf8) as c +FROM test; +---- +AndrewX +XiangpengXiangpeng +RaphaelR +R + +## Should run CONCAT successfully with utf8 and utf8view +query T +SELECT + concat(column1_utf8view, column2_utf8) as c +FROM test; +---- +AndrewX +XiangpengXiangpeng +RaphaelR +R + +## Should run CONCAT successfully with utf8 utf8view and largeutf8 +query T +SELECT + concat(column1_utf8view, column2_utf8, column2_large_utf8) as c +FROM test; +---- +AndrewXX +XiangpengXiangpengXiangpeng +RaphaelRR +RR + +## Should run CONCAT successfully with utf8large +query T +SELECT + concat(column1_large_utf8, column2_large_utf8) as c +FROM test; +---- +AndrewX +XiangpengXiangpeng +RaphaelR +R + ## Ensure no casts for LPAD query TT EXPLAIN SELECT @@ -1307,3 +1371,4 @@ select column2|| ' ' ||column3 from temp; ---- rust fast datafusion cool +