From c6be00d5012e89ea29611f02b9edf15c806db6c5 Mon Sep 17 00:00:00 2001 From: Dmitry Bugakov Date: Wed, 21 Aug 2024 21:05:03 +0200 Subject: [PATCH] Add Utf8View support to STRPOS function (#12087) * Add Utf8View support to STRPOS function * fix type inconsistency * fix type inconsistency * refactor tests --- datafusion/functions/src/unicode/strpos.rs | 175 ++++++++++++------ .../sqllogictest/test_files/string_view.slt | 5 +- 2 files changed, 121 insertions(+), 59 deletions(-) diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index 702baf6e8fa7..cf10b18ae338 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -19,11 +19,10 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{ - ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, + ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; -use datafusion_common::cast::as_generic_string_array; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; @@ -52,6 +51,9 @@ impl StrposFunc { Exact(vec![Utf8, LargeUtf8]), Exact(vec![LargeUtf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8]), + Exact(vec![Utf8View, Utf8View]), + Exact(vec![Utf8View, Utf8]), + Exact(vec![Utf8View, LargeUtf8]), ], Volatility::Immutable, ), @@ -78,21 +80,7 @@ impl ScalarUDFImpl for StrposFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match (args[0].data_type(), args[1].data_type()) { - (DataType::Utf8, DataType::Utf8) => { - make_scalar_function(strpos::, vec![])(args) - } - (DataType::Utf8, DataType::LargeUtf8) => { - make_scalar_function(strpos::, vec![])(args) - } - (DataType::LargeUtf8, DataType::Utf8) => { - make_scalar_function(strpos::, vec![])(args) - } - (DataType::LargeUtf8, DataType::LargeUtf8) => { - make_scalar_function(strpos::, vec![])(args) - } - other => exec_err!("Unsupported data type {other:?} for function strpos"), - } + make_scalar_function(strpos, vec![])(args) } fn aliases(&self) -> &[String] { @@ -100,30 +88,71 @@ impl ScalarUDFImpl for StrposFunc { } } +fn strpos(args: &[ArrayRef]) -> Result { + match (args[0].data_type(), args[1].data_type()) { + (DataType::Utf8, DataType::Utf8) => { + let string_array = args[0].as_string::(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int32Type>(string_array, substring_array) + } + (DataType::Utf8, DataType::LargeUtf8) => { + let string_array = args[0].as_string::(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int32Type>(string_array, substring_array) + } + (DataType::LargeUtf8, DataType::Utf8) => { + let string_array = args[0].as_string::(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int64Type>(string_array, substring_array) + } + (DataType::LargeUtf8, DataType::LargeUtf8) => { + let string_array = args[0].as_string::(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int64Type>(string_array, substring_array) + } + (DataType::Utf8View, DataType::Utf8View) => { + let string_array = args[0].as_string_view(); + let substring_array = args[1].as_string_view(); + calculate_strpos::<_, _, Int32Type>(string_array, substring_array) + } + (DataType::Utf8View, DataType::Utf8) => { + let string_array = args[0].as_string_view(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int32Type>(string_array, substring_array) + } + (DataType::Utf8View, DataType::LargeUtf8) => { + let string_array = args[0].as_string_view(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int32Type>(string_array, substring_array) + } + + other => { + exec_err!("Unsupported data type combination {other:?} for function strpos") + } + } +} + /// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.) /// strpos('high', 'ig') = 2 /// The implementation uses UTF-8 code points as characters -fn strpos( - args: &[ArrayRef], +fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>( + string_array: V1, + substring_array: V2, ) -> Result where - T0::Native: OffsetSizeTrait, - T1::Native: OffsetSizeTrait, + V1: ArrayAccessor, + V2: ArrayAccessor, { - let string_array: &GenericStringArray = - as_generic_string_array::(&args[0])?; - - let substring_array: &GenericStringArray = - as_generic_string_array::(&args[1])?; + let string_iter = ArrayIter::new(string_array); + let substring_iter = ArrayIter::new(substring_array); - let result = string_array - .iter() - .zip(substring_array.iter()) + let result = string_iter + .zip(substring_iter) .map(|(string, substring)| match (string, substring) { (Some(string), Some(substring)) => { - // the find method returns the byte index of the substring - // Next, we count the number of the chars until that byte - T0::Native::from_usize( + // The `find` method returns the byte index of the substring. + // We count the number of chars up to that byte index. + T::Native::from_usize( string .find(substring) .map(|x| string[..x].chars().count() + 1) @@ -132,20 +161,21 @@ where } _ => None, }) - .collect::>(); + .collect::>(); Ok(Arc::new(result) as ArrayRef) } #[cfg(test)] -mod test { - use super::*; +mod tests { + use arrow::array::{Array, Int32Array, Int64Array}; + use arrow::datatypes::DataType::{Int32, Int64}; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::strpos::StrposFunc; use crate::utils::test::test_function; - use arrow::{ - array::{Array as _, Int32Array, Int64Array}, - datatypes::DataType::{Int32, Int64}, - }; - use datafusion_common::ScalarValue; macro_rules! test_strpos { ($lhs:literal, $rhs:literal -> $result:literal; $t1:ident $t2:ident $t3:ident $t4:ident $t5:ident) => { @@ -164,21 +194,54 @@ mod test { } #[test] - fn strpos() { - test_strpos!("foo", "bar" -> 0; Utf8 Utf8 i32 Int32 Int32Array); - test_strpos!("foobar", "foo" -> 1; Utf8 Utf8 i32 Int32 Int32Array); - test_strpos!("foobar", "bar" -> 4; Utf8 Utf8 i32 Int32 Int32Array); - - test_strpos!("foo", "bar" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); - test_strpos!("foobar", "foo" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); - test_strpos!("foobar", "bar" -> 4; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); - - test_strpos!("foo", "bar" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array); - test_strpos!("foobar", "foo" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array); - test_strpos!("foobar", "bar" -> 4; Utf8 LargeUtf8 i32 Int32 Int32Array); - - test_strpos!("foo", "bar" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array); - test_strpos!("foobar", "foo" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array); - test_strpos!("foobar", "bar" -> 4; LargeUtf8 Utf8 i64 Int64 Int64Array); + fn test_strpos_functions() { + // Utf8 and Utf8 combinations + test_strpos!("alphabet", "ph" -> 3; Utf8 Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "a" -> 1; Utf8 Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "z" -> 0; Utf8 Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array); + test_strpos!("", "a" -> 0; Utf8 Utf8 i32 Int32 Int32Array); + + // LargeUtf8 and LargeUtf8 combinations + test_strpos!("alphabet", "ph" -> 3; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "a" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "z" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + test_strpos!("", "a" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + + // Utf8 and LargeUtf8 combinations + test_strpos!("alphabet", "ph" -> 3; Utf8 LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "a" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "z" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array); + test_strpos!("", "a" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array); + + // LargeUtf8 and Utf8 combinations + test_strpos!("alphabet", "ph" -> 3; LargeUtf8 Utf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "a" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "z" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array); + test_strpos!("", "a" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array); + + // Utf8View and Utf8View combinations + test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8View i32 Int32 Int32Array); + test_strpos!("alphabet", "a" -> 1; Utf8View Utf8View i32 Int32 Int32Array); + test_strpos!("alphabet", "z" -> 0; Utf8View Utf8View i32 Int32 Int32Array); + test_strpos!("alphabet", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array); + test_strpos!("", "a" -> 0; Utf8View Utf8View i32 Int32 Int32Array); + + // Utf8View and Utf8 combinations + test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "a" -> 1; Utf8View Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "z" -> 0; Utf8View Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array); + test_strpos!("", "a" -> 0; Utf8View Utf8 i32 Int32 Int32Array); + + // Utf8View and LargeUtf8 combinations + test_strpos!("alphabet", "ph" -> 3; Utf8View LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "a" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "z" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array); + test_strpos!("", "a" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array); } } diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index 0b441bcbeb8f..4b4eba0522e4 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -1066,9 +1066,8 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: strpos(__common_expr_1, Utf8("f")) AS c, strpos(__common_expr_1, CAST(test.column2_utf8view AS Utf8)) AS c2 -02)--Projection: CAST(test.column1_utf8view AS Utf8) AS __common_expr_1, test.column2_utf8view -03)----TableScan: test projection=[column1_utf8view, column2_utf8view] +01)Projection: strpos(test.column1_utf8view, Utf8("f")) AS c, strpos(test.column1_utf8view, test.column2_utf8view) AS c2 +02)--TableScan: test projection=[column1_utf8view, column2_utf8view] ## Ensure no casts for SUBSTR ## TODO file ticket