diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index d7059e882e55..839f890bf077 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -20,6 +20,7 @@ pub mod expr; pub mod memory; pub mod proxy; +pub mod string_utils; use crate::error::{_internal_datafusion_err, _internal_err}; use crate::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; diff --git a/datafusion/common/src/utils/string_utils.rs b/datafusion/common/src/utils/string_utils.rs new file mode 100644 index 000000000000..a2231e6786a7 --- /dev/null +++ b/datafusion/common/src/utils/string_utils.rs @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for working with strings + +use arrow::{array::AsArray, datatypes::DataType}; +use arrow_array::Array; + +/// Convenient function to convert an Arrow string array to a vector of strings +pub fn string_array_to_vec(array: &dyn Array) -> Vec> { + match array.data_type() { + DataType::Utf8 => array.as_string::().iter().collect(), + DataType::LargeUtf8 => array.as_string::().iter().collect(), + DataType::Utf8View => array.as_string_view().iter().collect(), + _ => unreachable!(), + } +} diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index fe1df2579932..d7c194304887 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -19,14 +19,16 @@ use arrow::array::{Array, ArrayRef, BooleanArray, OffsetSizeTrait}; use arrow::datatypes::DataType; -use arrow::row::{RowConverter, SortField}; +use arrow::row::{RowConverter, Rows, SortField}; +use arrow_array::GenericListArray; use datafusion_common::cast::as_generic_list_array; +use datafusion_common::utils::string_utils::string_array_to_vec; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use itertools::Itertools; -use crate::utils::check_datatypes; +use crate::utils::make_scalar_function; use std::any::Any; use std::sync::Arc; @@ -93,33 +95,52 @@ impl ScalarUDFImpl for ArrayHas { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; + make_scalar_function(array_has_inner)(args) + } - if args.len() != 2 { - return exec_err!("array_has needs two arguments"); - } + fn aliases(&self) -> &[String] { + &self.aliases + } +} - let array_type = args[0].data_type(); - - match array_type { - DataType::List(_) => general_array_has_dispatch::( - &args[0], - &args[1], - ComparisonType::Single, - ) - .map(ColumnarValue::Array), - DataType::LargeList(_) => general_array_has_dispatch::( - &args[0], - &args[1], - ComparisonType::Single, - ) - .map(ColumnarValue::Array), - _ => exec_err!("array_has does not support type '{array_type:?}'."), +fn array_has_inner(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::List(_) => array_has_dispatch::(&args[0], &args[1]), + DataType::LargeList(_) => array_has_dispatch::(&args[0], &args[1]), + _ => exec_err!( + "array_has does not support type '{:?}'.", + args[0].data_type() + ), + } +} + +fn array_has_all_inner(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::List(_) => { + array_has_all_and_any_dispatch::(&args[0], &args[1], ComparisonType::All) + } + DataType::LargeList(_) => { + array_has_all_and_any_dispatch::(&args[0], &args[1], ComparisonType::All) } + _ => exec_err!( + "array_has does not support type '{:?}'.", + args[0].data_type() + ), } +} - fn aliases(&self) -> &[String] { - &self.aliases +fn array_has_any_inner(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::List(_) => { + array_has_all_and_any_dispatch::(&args[0], &args[1], ComparisonType::Any) + } + DataType::LargeList(_) => { + array_has_all_and_any_dispatch::(&args[0], &args[1], ComparisonType::Any) + } + _ => exec_err!( + "array_has does not support type '{:?}'.", + args[0].data_type() + ), } } @@ -161,24 +182,7 @@ impl ScalarUDFImpl for ArrayHasAll { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - if args.len() != 2 { - return exec_err!("array_has_all needs two arguments"); - } - - let array_type = args[0].data_type(); - - match array_type { - DataType::List(_) => { - general_array_has_dispatch::(&args[0], &args[1], ComparisonType::All) - .map(ColumnarValue::Array) - } - DataType::LargeList(_) => { - general_array_has_dispatch::(&args[0], &args[1], ComparisonType::All) - .map(ColumnarValue::Array) - } - _ => exec_err!("array_has_all does not support type '{array_type:?}'."), - } + make_scalar_function(array_has_all_inner)(args) } fn aliases(&self) -> &[String] { @@ -224,25 +228,7 @@ impl ScalarUDFImpl for ArrayHasAny { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - - if args.len() != 2 { - return exec_err!("array_has_any needs two arguments"); - } - - let array_type = args[0].data_type(); - - match array_type { - DataType::List(_) => { - general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Any) - .map(ColumnarValue::Array) - } - DataType::LargeList(_) => { - general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Any) - .map(ColumnarValue::Array) - } - _ => exec_err!("array_has_any does not support type '{array_type:?}'."), - } + make_scalar_function(array_has_any_inner)(args) } fn aliases(&self) -> &[String] { @@ -251,75 +237,176 @@ impl ScalarUDFImpl for ArrayHasAny { } /// Represents the type of comparison for array_has. -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone, Copy)] enum ComparisonType { // array_has_all All, // array_has_any Any, - // array_has - Single, } -fn general_array_has_dispatch( +fn array_has_dispatch( + haystack: &ArrayRef, + needle: &ArrayRef, +) -> Result { + let haystack = as_generic_list_array::(haystack)?; + match needle.data_type() { + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { + array_has_string_internal::(haystack, needle) + } + _ => general_array_has::(haystack, needle), + } +} + +fn array_has_all_and_any_dispatch( haystack: &ArrayRef, needle: &ArrayRef, comparison_type: ComparisonType, ) -> Result { - let array = if comparison_type == ComparisonType::Single { - let arr = as_generic_list_array::(haystack)?; - check_datatypes("array_has", &[arr.values(), needle])?; - arr - } else { - check_datatypes("array_has", &[haystack, needle])?; - as_generic_list_array::(haystack)? - }; + let haystack = as_generic_list_array::(haystack)?; + let needle = as_generic_list_array::(needle)?; + match needle.data_type() { + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { + array_has_all_and_any_string_internal::(haystack, needle, comparison_type) + } + _ => general_array_has_for_all_and_any::(haystack, needle, comparison_type), + } +} - let mut boolean_builder = BooleanArray::builder(array.len()); +fn array_has_string_internal( + haystack: &GenericListArray, + needle: &ArrayRef, +) -> Result { + let mut boolean_builder = BooleanArray::builder(haystack.len()); + for (arr, element) in haystack.iter().zip(string_array_to_vec(needle).into_iter()) { + match (arr, element) { + (Some(arr), Some(element)) => { + boolean_builder.append_value( + string_array_to_vec(&arr) + .into_iter() + .flatten() + .any(|x| x == element), + ); + } + (_, _) => { + boolean_builder.append_null(); + } + } + } + + Ok(Arc::new(boolean_builder.finish())) +} +fn general_array_has( + array: &GenericListArray, + needle: &ArrayRef, +) -> Result { + let mut boolean_builder = BooleanArray::builder(array.len()); let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; + let sub_arr_values = converter.convert_columns(&[Arc::clone(needle)])?; + + for (row_idx, arr) in array.iter().enumerate() { + if let Some(arr) = arr { + let arr_values = converter.convert_columns(&[arr])?; + boolean_builder.append_value( + arr_values + .iter() + .dedup() + .any(|x| x == sub_arr_values.row(row_idx)), + ); + } else { + boolean_builder.append_null(); + } + } - let element = Arc::clone(needle); - let sub_array = if comparison_type != ComparisonType::Single { - as_generic_list_array::(needle)? - } else { - array - }; - for (row_idx, (arr, sub_arr)) in array.iter().zip(sub_array.iter()).enumerate() { + Ok(Arc::new(boolean_builder.finish())) +} + +// String comparison for array_has_all and array_has_any +fn array_has_all_and_any_string_internal( + array: &GenericListArray, + needle: &GenericListArray, + comparison_type: ComparisonType, +) -> Result { + let mut boolean_builder = BooleanArray::builder(array.len()); + for (arr, sub_arr) in array.iter().zip(needle.iter()) { match (arr, sub_arr) { (Some(arr), Some(sub_arr)) => { - let arr_values = converter.convert_columns(&[arr])?; - let sub_arr_values = if comparison_type != ComparisonType::Single { - converter.convert_columns(&[sub_arr])? - } else { - converter.convert_columns(&[Arc::clone(&element)])? - }; - - let mut res = match comparison_type { - ComparisonType::All => sub_arr_values - .iter() - .dedup() - .all(|elem| arr_values.iter().dedup().any(|x| x == elem)), - ComparisonType::Any => sub_arr_values - .iter() - .dedup() - .any(|elem| arr_values.iter().dedup().any(|x| x == elem)), - ComparisonType::Single => arr_values - .iter() - .dedup() - .any(|x| x == sub_arr_values.row(row_idx)), - }; - - if comparison_type == ComparisonType::Any { - res |= res; - } - boolean_builder.append_value(res); + let haystack_array = string_array_to_vec(&arr); + let needle_array = string_array_to_vec(&sub_arr); + boolean_builder.append_value(array_has_string_kernel( + haystack_array, + needle_array, + comparison_type, + )); } - // respect null input (_, _) => { boolean_builder.append_null(); } } } + Ok(Arc::new(boolean_builder.finish())) } + +fn array_has_string_kernel( + haystack: Vec>, + needle: Vec>, + comparison_type: ComparisonType, +) -> bool { + match comparison_type { + ComparisonType::All => needle + .iter() + .dedup() + .all(|x| haystack.iter().dedup().any(|y| y == x)), + ComparisonType::Any => needle + .iter() + .dedup() + .any(|x| haystack.iter().dedup().any(|y| y == x)), + } +} + +// General row comparison for array_has_all and array_has_any +fn general_array_has_for_all_and_any( + haystack: &GenericListArray, + needle: &GenericListArray, + comparison_type: ComparisonType, +) -> Result { + let mut boolean_builder = BooleanArray::builder(haystack.len()); + let converter = RowConverter::new(vec![SortField::new(haystack.value_type())])?; + + for (arr, sub_arr) in haystack.iter().zip(needle.iter()) { + if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { + let arr_values = converter.convert_columns(&[arr])?; + let sub_arr_values = converter.convert_columns(&[sub_arr])?; + boolean_builder.append_value(general_array_has_all_and_any_kernel( + arr_values, + sub_arr_values, + comparison_type, + )); + } else { + boolean_builder.append_null(); + } + } + + Ok(Arc::new(boolean_builder.finish())) +} + +fn general_array_has_all_and_any_kernel( + haystack_rows: Rows, + needle_rows: Rows, + comparison_type: ComparisonType, +) -> bool { + match comparison_type { + ComparisonType::All => needle_rows.iter().all(|needle_row| { + haystack_rows + .iter() + .any(|haystack_row| haystack_row == needle_row) + }), + ComparisonType::Any => needle_rows.iter().any(|needle_row| { + haystack_rows + .iter() + .any(|haystack_row| haystack_row == needle_row) + }), + } +}