Skip to content

Commit

Permalink
fix ci
Browse files Browse the repository at this point in the history
  • Loading branch information
Weijun-H committed Dec 12, 2023
1 parent acef85f commit f529f83
Showing 1 changed file with 54 additions and 63 deletions.
117 changes: 54 additions & 63 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
use std::any::type_name;
use std::collections::HashSet;
use std::fmt::{Display, Formatter};
use std::sync::Arc;

use arrow::array::*;
Expand Down Expand Up @@ -1535,11 +1536,25 @@ macro_rules! to_string {
}};
}

/// general function for array_union and array_intersect
fn general_set_lists<OffsetSize: OffsetSizeTrait>(
#[derive(Debug, PartialEq)]
enum SetOp {
Union,
Intersect,
}

impl Display for SetOp {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
SetOp::Union => write!(f, "array_union"),
SetOp::Intersect => write!(f, "array_intersect"),
}
}
}

fn generic_set_lists<OffsetSize: OffsetSizeTrait>(
l: &GenericListArray<OffsetSize>,
r: &GenericListArray<OffsetSize>,
is_union: bool,
set_op: SetOp,
) -> Result<ArrayRef> {
if matches!(l.value_type(), DataType::Null) {
let field = Arc::new(Field::new("item", r.value_type(), true));
Expand All @@ -1550,12 +1565,7 @@ fn general_set_lists<OffsetSize: OffsetSizeTrait>(
}

if l.value_type() != r.value_type() {
let operation = if is_union {
"array_union"
} else {
"array_intersect"
};
return internal_err!("{operation} is not implemented for '{l:?}' and '{r:?}'");
return internal_err!("{set_op} is not implemented for '{l:?}' and '{r:?}'");
}

let dt = l.value_type();
Expand All @@ -1571,14 +1581,23 @@ fn general_set_lists<OffsetSize: OffsetSizeTrait>(

let l_iter = l_values.iter().sorted().dedup();
let values_set: HashSet<_> = l_iter.clone().collect();
let mut rows = if is_union {
let mut rows = if set_op == SetOp::Union {
l_iter.collect::<Vec<_>>()
} else {
vec![]
};
for r_val in r_values.iter().sorted().dedup() {
if !values_set.contains(&r_val) == is_union {
rows.push(r_val);
match set_op {
SetOp::Union => {
if !values_set.contains(&r_val) {
rows.push(r_val);
}
}
SetOp::Intersect => {
if values_set.contains(&r_val) {
rows.push(r_val);
}
}
}
}

Expand All @@ -1591,12 +1610,7 @@ fn general_set_lists<OffsetSize: OffsetSizeTrait>(
let array = match arrays.first() {
Some(array) => array.clone(),
None => {
let operation = if is_union {
"array_union"
} else {
"array_intersect"
};
return internal_err!("{operation}: failed to get array from rows");
return internal_err!("{set_op}: failed to get array from rows");
}
};
new_arrays.push(array);
Expand All @@ -1611,15 +1625,13 @@ fn general_set_lists<OffsetSize: OffsetSizeTrait>(
Ok(Arc::new(arr))
}

/// Array_union SQL function
pub fn array_union(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_union needs two arguments");
}
let array1 = &args[0];
let array2 = &args[1];

fn general_set_op(
array1: &ArrayRef,
array2: &ArrayRef,
set_op: SetOp,
) -> Result<ArrayRef> {
match (array1.data_type(), array2.data_type()) {
// Null type
(DataType::Null, DataType::List(field))
| (DataType::List(field), DataType::Null) => {
let array = match array1.data_type() {
Expand All @@ -1637,24 +1649,36 @@ pub fn array_union(args: &[ArrayRef]) -> Result<ArrayRef> {
general_array_distinct::<i64>(array, field)
}
(DataType::Null, DataType::Null) => Ok(array1.clone()),

(DataType::List(_), DataType::List(_)) => {
let array1 = as_list_array(&array1)?;
let array2 = as_list_array(&array2)?;
general_set_lists::<i32>(array1, array2, true)
generic_set_lists::<i32>(array1, array2, set_op)
}
(DataType::LargeList(_), DataType::LargeList(_)) => {
let array1 = as_large_list_array(&array1)?;
let array2 = as_large_list_array(&array2)?;
general_set_lists::<i64>(array1, array2, true)
generic_set_lists::<i64>(array1, array2, set_op)
}
(data_type1, data_type2) => {
internal_err!(
"array_union does not support types '{data_type1:?}' and '{data_type2:?}'"
"{set_op} does not support types '{data_type1:?}' and '{data_type2:?}'"
)
}
}
}

/// Array_union SQL function
pub fn array_union(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_union needs two arguments");
}
let array1 = &args[0];
let array2 = &args[1];

general_set_op(array1, array2, SetOp::Union)
}

/// array_intersect SQL function
pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
Expand All @@ -1664,40 +1688,7 @@ pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
let array1 = &args[0];
let array2 = &args[1];

match (array1.data_type(), array2.data_type()) {
(DataType::Null, DataType::List(field))
| (DataType::List(field), DataType::Null) => {
let array = match array1.data_type() {
DataType::Null => as_list_array(&array2)?,
_ => as_list_array(&array1)?,
};
general_array_distinct::<i32>(array, field)
}
(DataType::Null, DataType::LargeList(field))
| (DataType::LargeList(field), DataType::Null) => {
let array = match array1.data_type() {
DataType::Null => as_large_list_array(&array2)?,
_ => as_large_list_array(&array1)?,
};
general_array_distinct::<i64>(array, field)
}
(DataType::Null, DataType::Null) => Ok(array1.clone()),
(DataType::List(_), DataType::List(_)) => {
let array1 = as_list_array(&array1)?;
let array2 = as_list_array(&array2)?;
general_set_lists::<i32>(array1, array2, false)
}
(DataType::LargeList(_), DataType::LargeList(_)) => {
let array1 = as_large_list_array(&array1)?;
let array2 = as_large_list_array(&array2)?;
general_set_lists::<i64>(array1, array2, false)
}
(data_type1, data_type2) => {
internal_err!(
"array_intersect does not support types '{data_type1:?}' and '{data_type2:?}'"
)
}
}
general_set_op(array1, array2, SetOp::Intersect)
}

/// Array_to_string SQL function
Expand Down

0 comments on commit f529f83

Please sign in to comment.