From f529f83cd49056fc607beb788316235937723f00 Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Tue, 12 Dec 2023 22:39:07 +0100 Subject: [PATCH] fix ci --- .../physical-expr/src/array_expressions.rs | 117 ++++++++---------- 1 file changed, 54 insertions(+), 63 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index db4cea228d652..679dcabe530e9 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -19,6 +19,7 @@ use std::any::type_name; use std::collections::HashSet; +use std::fmt::{Display, Formatter}; use std::sync::Arc; use arrow::array::*; @@ -1535,11 +1536,25 @@ macro_rules! to_string { }}; } -/// general function for array_union and array_intersect -fn general_set_lists( +#[derive(Debug, PartialEq)] +enum SetOp { + Union, + Intersect, +} + +impl Display for SetOp { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + SetOp::Union => write!(f, "array_union"), + SetOp::Intersect => write!(f, "array_intersect"), + } + } +} + +fn generic_set_lists( l: &GenericListArray, r: &GenericListArray, - is_union: bool, + set_op: SetOp, ) -> Result { if matches!(l.value_type(), DataType::Null) { let field = Arc::new(Field::new("item", r.value_type(), true)); @@ -1550,12 +1565,7 @@ fn general_set_lists( } if l.value_type() != r.value_type() { - let operation = if is_union { - "array_union" - } else { - "array_intersect" - }; - return internal_err!("{operation} is not implemented for '{l:?}' and '{r:?}'"); + return internal_err!("{set_op} is not implemented for '{l:?}' and '{r:?}'"); } let dt = l.value_type(); @@ -1571,14 +1581,23 @@ fn general_set_lists( let l_iter = l_values.iter().sorted().dedup(); let values_set: HashSet<_> = l_iter.clone().collect(); - let mut rows = if is_union { + let mut rows = if set_op == SetOp::Union { l_iter.collect::>() } else { vec![] }; for r_val in r_values.iter().sorted().dedup() { - if !values_set.contains(&r_val) == is_union { - rows.push(r_val); + match set_op { + SetOp::Union => { + if !values_set.contains(&r_val) { + rows.push(r_val); + } + } + SetOp::Intersect => { + if values_set.contains(&r_val) { + rows.push(r_val); + } + } } } @@ -1591,12 +1610,7 @@ fn general_set_lists( let array = match arrays.first() { Some(array) => array.clone(), None => { - let operation = if is_union { - "array_union" - } else { - "array_intersect" - }; - return internal_err!("{operation}: failed to get array from rows"); + return internal_err!("{set_op}: failed to get array from rows"); } }; new_arrays.push(array); @@ -1611,15 +1625,13 @@ fn general_set_lists( Ok(Arc::new(arr)) } -/// Array_union SQL function -pub fn array_union(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_union needs two arguments"); - } - let array1 = &args[0]; - let array2 = &args[1]; - +fn general_set_op( + array1: &ArrayRef, + array2: &ArrayRef, + set_op: SetOp, +) -> Result { match (array1.data_type(), array2.data_type()) { + // Null type (DataType::Null, DataType::List(field)) | (DataType::List(field), DataType::Null) => { let array = match array1.data_type() { @@ -1637,24 +1649,36 @@ pub fn array_union(args: &[ArrayRef]) -> Result { general_array_distinct::(array, field) } (DataType::Null, DataType::Null) => Ok(array1.clone()), + (DataType::List(_), DataType::List(_)) => { let array1 = as_list_array(&array1)?; let array2 = as_list_array(&array2)?; - general_set_lists::(array1, array2, true) + generic_set_lists::(array1, array2, set_op) } (DataType::LargeList(_), DataType::LargeList(_)) => { let array1 = as_large_list_array(&array1)?; let array2 = as_large_list_array(&array2)?; - general_set_lists::(array1, array2, true) + generic_set_lists::(array1, array2, set_op) } (data_type1, data_type2) => { internal_err!( - "array_union does not support types '{data_type1:?}' and '{data_type2:?}'" + "{set_op} does not support types '{data_type1:?}' and '{data_type2:?}'" ) } } } +/// Array_union SQL function +pub fn array_union(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_union needs two arguments"); + } + let array1 = &args[0]; + let array2 = &args[1]; + + general_set_op(array1, array2, SetOp::Union) +} + /// array_intersect SQL function pub fn array_intersect(args: &[ArrayRef]) -> Result { if args.len() != 2 { @@ -1664,40 +1688,7 @@ pub fn array_intersect(args: &[ArrayRef]) -> Result { let array1 = &args[0]; let array2 = &args[1]; - match (array1.data_type(), array2.data_type()) { - (DataType::Null, DataType::List(field)) - | (DataType::List(field), DataType::Null) => { - let array = match array1.data_type() { - DataType::Null => as_list_array(&array2)?, - _ => as_list_array(&array1)?, - }; - general_array_distinct::(array, field) - } - (DataType::Null, DataType::LargeList(field)) - | (DataType::LargeList(field), DataType::Null) => { - let array = match array1.data_type() { - DataType::Null => as_large_list_array(&array2)?, - _ => as_large_list_array(&array1)?, - }; - general_array_distinct::(array, field) - } - (DataType::Null, DataType::Null) => Ok(array1.clone()), - (DataType::List(_), DataType::List(_)) => { - let array1 = as_list_array(&array1)?; - let array2 = as_list_array(&array2)?; - general_set_lists::(array1, array2, false) - } - (DataType::LargeList(_), DataType::LargeList(_)) => { - let array1 = as_large_list_array(&array1)?; - let array2 = as_large_list_array(&array2)?; - general_set_lists::(array1, array2, false) - } - (data_type1, data_type2) => { - internal_err!( - "array_intersect does not support types '{data_type1:?}' and '{data_type2:?}'" - ) - } - } + general_set_op(array1, array2, SetOp::Intersect) } /// Array_to_string SQL function