diff --git a/arrow-string/src/like.rs b/arrow-string/src/like.rs index 1223280e3769..9131f7c2f35a 100644 --- a/arrow-string/src/like.rs +++ b/arrow-string/src/like.rs @@ -235,6 +235,103 @@ dict_function!("STARTSWITH(left, right)", starts_with_dict, starts_with); dict_function!("ENDSWITH(left, right)", ends_with_dict, ends_with); dict_function!("CONTAINS(left, right)", contains_dict, contains); +macro_rules! datum_function { + ($fn_name:ident, $fn_array:ident, $fn_scalar:ident) => { + pub fn $fn_name( + left: &dyn Datum, + right: &dyn Datum, + ) -> Result { + let (left_array, left_scalar) = left.get(); + let (right_array, right_scalar) = right.get(); + let left_type = left_array.data_type(); + let right_type = right_array.data_type(); + + // TODO(alexandreyc): check if PartialEq for DataType is deep or shallow + // i.e. does it check nested subtypes for equality? + if left_type != right_type { + return Err(ArrowError::ComputeError( + "Arrays must have the same data type".to_string(), + )); + } + + if left_scalar && !right_scalar { + return Err(ArrowError::ComputeError( + "Left cannot be scalar when right is not".to_string(), + )); + } + + match left_type { + DataType::Utf8 => { + let left_array = left_array.as_string::(); + let right_array = right_array.as_string::(); + if right_scalar { + $fn_scalar(left_array, right_array.value(0)) + } else { + $fn_array(left_array, right_array) + } + } + DataType::LargeUtf8 => { + let left_array = left_array.as_string::(); + let right_array = right_array.as_string::(); + if right_scalar { + $fn_scalar(left_array, right_array.value(0)) + } else { + $fn_array(left_array, right_array) + } + } + DataType::Dictionary(_, value_type) => match **value_type { + DataType::Utf8 => { + downcast_dictionary_array!( + left_array => { + let right_array = as_dictionary_array(right_array); + let right_array = right_array.downcast_dict::>().unwrap(); + let left_array = left_array.downcast_dict::>().unwrap(); + + if right_scalar { + 
$fn_scalar(left_array, right_array.value(0)) + } else { + $fn_array(left_array, right_array) + } + } + t => Err(ArrowError::ComputeError(format!( + "Should be DictionaryArray but got: {}", t + ))) + ) + } + DataType::LargeUtf8 => { + downcast_dictionary_array!( + left_array => { + let left_array = left_array.downcast_dict::>().unwrap(); + let right_array = as_dictionary_array(right_array); + let right_array = right_array.downcast_dict::>().unwrap(); + + if right_scalar { + $fn_scalar(left_array, right_array.value(0)) + } else { + $fn_array(left_array, right_array) + } + } + t => Err(ArrowError::ComputeError(format!( + "Should be DictionaryArray but got: {}", t + ))) + ) + } + _ => Err(ArrowError::ComputeError(format!( + "Unsupported dictionnary value type: {}", + value_type + ))), + }, + _ => Err(ArrowError::ComputeError(format!( + "Unsupported data type: {}", + left_type + ))), + } + } + }; +} + +datum_function!(like_datum, like, like_scalar); + /// Perform SQL `left LIKE right` operation on [`StringArray`] / [`LargeStringArray`]. 
/// /// There are two wildcards supported with the LIKE operator: @@ -807,6 +904,7 @@ mod tests { }; } + // OK test_utf8!( test_utf8_array_like, vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrows", "arrow", "arrow"], @@ -815,6 +913,7 @@ mod tests { vec![true, true, true, false, false, true, false, false] ); + // TODO test_dict_utf8!( test_utf8_array_like_dict, vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrows", "arrow", "arrow"], @@ -823,6 +922,7 @@ mod tests { vec![true, true, true, false, false, true, false, false] ); + // OK test_utf8_scalar!( test_utf8_array_like_scalar_escape_testing, test_utf8_array_like_scalar_dyn_escape_testing, @@ -833,6 +933,7 @@ mod tests { vec![true, true, false, false] ); + // OK test_utf8_scalar!( test_utf8_array_like_scalar_escape_regex, test_utf8_array_like_scalar_dyn_escape_regex, @@ -843,6 +944,7 @@ mod tests { vec![true, false, false] ); + // OK test_utf8_scalar!( test_utf8_array_like_scalar_escape_regex_dot, test_utf8_array_like_scalar_dyn_escape_regex_dot, @@ -853,6 +955,7 @@ mod tests { vec![true, false, false] ); + // OK test_utf8_scalar!( test_utf8_array_like_scalar, test_utf8_array_like_scalar_dyn, @@ -863,6 +966,7 @@ mod tests { vec![true, true, false, false] ); + // OK test_utf8_scalar!( test_utf8_array_like_scalar_start, test_utf8_array_like_scalar_dyn_start, @@ -885,6 +989,7 @@ mod tests { vec![true, false, true, false] ); + // OK test_utf8_scalar!( test_utf8_array_like_scalar_end, test_utf8_array_like_scalar_dyn_end, @@ -907,6 +1012,7 @@ mod tests { vec![true, true, false, false] ); + // OK test_utf8_scalar!( test_utf8_array_like_scalar_equals, test_utf8_array_like_scalar_dyn_equals, @@ -917,6 +1023,7 @@ mod tests { vec![true, false, false, false] ); + // OK test_utf8_scalar!( test_utf8_array_like_scalar_one, test_utf8_array_like_scalar_dyn_one, @@ -927,6 +1034,7 @@ mod tests { vec![false, true, false, false] ); + // OK test_utf8_scalar!( test_utf8_scalar_like_escape, test_utf8_scalar_like_dyn_escape, @@ 
// Diff context preserved from the collapsed patch (tail of the `mod tests` hunks):
// -937,6 +1045,7 @@ mod tests { vec![true, false] ); + // OK test_utf8_scalar!( test_utf8_scalar_like_escape_contains, test_utf8_scalar_like_dyn_escape_contains,
// @@ -1939,3 +2048,270 @@ mod tests { ); } }

/// Tests for the `Datum`-based kernels generated by `datum_function!`.
///
/// Every test macro runs the same scenario against three physical layouts:
/// `StringArray`, `LargeStringArray` and `DictionaryArray<Int8Type>`.
#[cfg(test)]
mod tests_datum {
    use super::*;
    use arrow_array::types::Int8Type;
    use std::sync::Arc;

    /// Asserts that `actual` has the same length as `expected` and holds the
    /// same boolean at every index.
    fn assert_bools(actual: &BooleanArray, expected: &[bool]) {
        assert_eq!(actual.len(), expected.len());
        for (i, want) in expected.iter().enumerate() {
            assert_eq!(actual.value(i), *want);
        }
    }

    /// Array on the left, array on the right.
    macro_rules! test_array_array {
        ($test_name:ident, $op:ident, $left:expr, $right:expr, $expected:expr) => {
            #[test]
            fn $test_name() {
                let expected = $expected;

                // StringArray
                let left = StringArray::from($left);
                let right = StringArray::from($right);
                let res = $op(&left, &right).unwrap();
                assert_bools(&res, &expected);

                // LargeStringArray
                let left = LargeStringArray::from($left);
                let right = LargeStringArray::from($right);
                let res = $op(&left, &right).unwrap();
                assert_bools(&res, &expected);

                // DictionaryArray
                let left: DictionaryArray<Int8Type> = $left.into_iter().collect();
                let right: DictionaryArray<Int8Type> = $right.into_iter().collect();
                let res = $op(&left, &right).unwrap();
                assert_bools(&res, &expected);
            }
        };
    }

    /// Array on the left, single-value scalar on the right.
    macro_rules! test_array_scalar {
        ($test_name:ident, $op:ident, $left:expr, $right:expr, $expected:expr) => {
            #[test]
            fn $test_name() {
                let expected = $expected;

                // StringArray
                let left = StringArray::from($left);
                let right = StringArray::from(vec![$right]);
                let right = Scalar::new(&right);
                let res = $op(&left, &right).unwrap();
                assert_bools(&res, &expected);

                // LargeStringArray
                let left = LargeStringArray::from($left);
                let right = LargeStringArray::from(vec![$right]);
                let right = Scalar::new(&right);
                let res = $op(&left, &right).unwrap();
                assert_bools(&res, &expected);

                // DictionaryArray
                let left: DictionaryArray<Int8Type> = $left.into_iter().collect();
                let right: DictionaryArray<Int8Type> = [$right].into_iter().collect();
                let right = Scalar::new(&right);
                let res = $op(&left, &right).unwrap();
                assert_bools(&res, &expected);
            }
        };
    }

    /// Scalar on both sides; the result must be a single boolean.
    macro_rules! test_scalar_scalar {
        ($test_name:ident, $op:ident, $left:expr, $right:expr, $expected:expr) => {
            #[test]
            fn $test_name() {
                let expected = $expected;

                // StringArray
                let left = StringArray::from(vec![$left]);
                let left = Scalar::new(&left);
                let right = StringArray::from(vec![$right]);
                let right = Scalar::new(&right);
                let res = $op(&left, &right).unwrap();
                assert_eq!(res.len(), 1);
                assert_eq!(res.value(0), expected);

                // LargeStringArray
                let left = LargeStringArray::from(vec![$left]);
                // Wrap the left side too: without this the case exercises
                // array-vs-scalar instead of scalar-vs-scalar.
                let left = Scalar::new(&left);
                let right = LargeStringArray::from(vec![$right]);
                let right = Scalar::new(&right);
                let res = $op(&left, &right).unwrap();
                assert_eq!(res.len(), 1);
                assert_eq!(res.value(0), expected);

                // DictionaryArray
                let left: DictionaryArray<Int8Type> = [$left].into_iter().collect();
                let left = Scalar::new(&left);
                let right: DictionaryArray<Int8Type> = [$right].into_iter().collect();
                let right = Scalar::new(&right);
                let res = $op(&left, &right).unwrap();
                assert_eq!(res.len(), 1);
                assert_eq!(res.value(0), expected);
            }
        };
    }

    /// Checks the error message produced for each rejected input combination.
    macro_rules! test_errors {
        ($test_name:ident, $op:ident) => {
            #[test]
            fn $test_name() {
                // Scalar on the left with an array on the right.
                let left = StringArray::from(vec!["a"]);
                let left = Scalar::new(&left);
                let right = StringArray::from(vec!["a", "b", "c", "d"]);
                let res = $op(&left, &right).unwrap_err();
                assert_eq!(
                    res.to_string(),
                    "Compute error: Left cannot be scalar when right is not"
                );

                // Two arrays of different lengths (the right side is NOT a scalar).
                let left = StringArray::from(vec!["a", "b", "c", "d"]);
                let right = StringArray::from(vec!["a"]);
                let res = $op(&left, &right).unwrap_err();
                assert_eq!(
                    res.to_string(),
                    "Compute error: Cannot perform comparison operation on arrays of different length"
                );

                // Mismatched data types.
                let left = StringArray::from(vec!["a", "b", "c", "d"]);
                let right = LargeStringArray::from(vec!["a", "b", "c", "d"]);
                let res = $op(&left, &right).unwrap_err();
                assert_eq!(
                    res.to_string(),
                    "Compute error: Arrays must have the same data type"
                );

                // Matching but unsupported (non-string) data type.
                let left = Int32Array::from(vec![1, 2, 3, 4]);
                let right = Int32Array::from(vec![1, 2, 3, 4]);
                let res = $op(&left, &right).unwrap_err();
                assert_eq!(
                    res.to_string(),
                    format!("Compute error: Unsupported data type: {}", left.data_type())
                );

                // Dictionary whose values are not strings.
                let values = Arc::new(BinaryArray::from_iter_values(["a", "b", "c"]));
                let keys = Int8Array::from_iter_values([0, 0, 1, 2]);
                let left =
                    DictionaryArray::<Int8Type>::try_new(keys.clone(), values.clone()).unwrap();
                let right = DictionaryArray::<Int8Type>::try_new(keys, values).unwrap();
                let res = $op(&left, &right).unwrap_err();
                assert_eq!(
                    res.to_string(),
                    format!(
                        "Compute error: Unsupported dictionnary value type: {}",
                        left.value_type()
                    )
                );
            }
        };
    }

    // LIKE

    test_array_array!(
        test_like_array_array,
        like_datum,
        vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrows", "arrow", "arrow"],
        vec!["arrow", "ar%", "%ro%", "foo", "arr", "arrow_", "arrow_", ".*"],
        vec![true, true, true, false, false, true, false, false]
    );

    test_array_scalar!(
        test_like_array_scalar_start,
        like_datum,
        vec!["arrow", "parrow", "arrows", "arr"],
        "arrow%",
        vec![true, false, true, false]
    );

    test_array_scalar!(
        test_like_array_scalar_end,
        like_datum,
        vec!["arrow", "parrow", "arrows", "arr"],
        "%arrow",
        vec![true, true, false, false]
    );

    test_array_scalar!(
        test_like_array_scalar_start_end,
        like_datum,
        vec!["arrow", "parquet", "datafusion", "flight"],
        "%ar%",
        vec![true, true, false, false]
    );

    test_array_scalar!(
        test_like_array_scalar_escape_testing,
        like_datum,
        vec!["varchar(255)", "int(255)", "varchar", "int"],
        "%(%)%",
        vec![true, true, false, false]
    );

    test_array_scalar!(
        test_like_array_scalar_escape_regex,
        like_datum,
        vec![".*", "a", "*"],
        ".*",
        vec![true, false, false]
    );

    test_array_scalar!(
        test_like_array_scalar_escape_regex_dot,
        like_datum,
        vec![".", "a", "*"],
        ".",
        vec![true, false, false]
    );

    test_array_scalar!(
        test_like_array_scalar_equals,
        like_datum,
        vec!["arrow", "parrow", "arrows", "arr"],
        "arrow",
        vec![true, false, false, false]
    );

    test_array_scalar!(
        test_like_array_scalar_one,
        like_datum,
        vec!["arrow", "arrows", "parrow", "arr"],
        "arrow_",
        vec![false, true, false, false]
    );

    test_array_scalar!(
        test_like_array_scalar_escape,
        like_datum,
        vec!["a%", "a\\x"],
        "a\\%",
        vec![true, false]
    );

    test_array_scalar!(
        test_like_array_scalar_escape_contains,
        like_datum,
        vec!["ba%", "ba\\x"],
        "%a\\%",
        vec![true, false]
    );

    test_scalar_scalar!(test_like_scalar_scalar, like_datum, "arrow", "%rr%", true);

    test_errors!(test_like_errors, like_datum);
}