From 18fc37629250d22faa6ead109725ebf94a4fa532 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Tue, 14 May 2024 07:32:29 +0800 Subject: [PATCH] feat: allow `array_slice` to take an optional stride parameter (#10469) * feat: allow array_slice to take an optional stride parameter * Use ScalarUDF::call * Use create_function and add test * format * fix cargo doc --- datafusion/functions-array/src/array_has.rs | 6 +-- datafusion/functions-array/src/cardinality.rs | 2 +- datafusion/functions-array/src/concat.rs | 6 +-- datafusion/functions-array/src/dimension.rs | 4 +- datafusion/functions-array/src/empty.rs | 2 +- datafusion/functions-array/src/except.rs | 2 +- datafusion/functions-array/src/extract.rs | 23 +++++----- datafusion/functions-array/src/flatten.rs | 2 +- datafusion/functions-array/src/length.rs | 2 +- datafusion/functions-array/src/macros.rs | 44 +++++++++---------- datafusion/functions-array/src/make_array.rs | 2 +- datafusion/functions-array/src/position.rs | 4 +- datafusion/functions-array/src/range.rs | 4 +- datafusion/functions-array/src/remove.rs | 6 +-- datafusion/functions-array/src/repeat.rs | 2 +- datafusion/functions-array/src/replace.rs | 6 +-- datafusion/functions-array/src/resize.rs | 2 +- datafusion/functions-array/src/reverse.rs | 2 +- datafusion/functions-array/src/rewrite.rs | 2 +- datafusion/functions-array/src/set_ops.rs | 6 +-- datafusion/functions-array/src/sort.rs | 2 +- datafusion/functions-array/src/string.rs | 4 +- .../tests/cases/roundtrip_logical_plan.rs | 6 +++ 23 files changed, 74 insertions(+), 67 deletions(-) diff --git a/datafusion/functions-array/src/array_has.rs b/datafusion/functions-array/src/array_has.rs index e5e8add95fbe..43d6046f4f82 100644 --- a/datafusion/functions-array/src/array_has.rs +++ b/datafusion/functions-array/src/array_has.rs @@ -34,19 +34,19 @@ use std::any::Any; use std::sync::Arc; // Create static instances of ScalarUDFs for each function -make_udf_function!(ArrayHas, +make_udf_expr_and_func!(ArrayHas, 
array_has, first_array second_array, // arg name "returns true, if the element appears in the first array, otherwise false.", // doc array_has_udf // internal function name ); -make_udf_function!(ArrayHasAll, +make_udf_expr_and_func!(ArrayHasAll, array_has_all, first_array second_array, // arg name "returns true if each element of the second array appears in the first array; otherwise, it returns false.", // doc array_has_all_udf // internal function name ); -make_udf_function!(ArrayHasAny, +make_udf_expr_and_func!(ArrayHasAny, array_has_any, first_array second_array, // arg name "returns true if at least one element of the second array appears in the first array; otherwise, it returns false.", // doc diff --git a/datafusion/functions-array/src/cardinality.rs b/datafusion/functions-array/src/cardinality.rs index ed9f8d01f973..d6f2456313bc 100644 --- a/datafusion/functions-array/src/cardinality.rs +++ b/datafusion/functions-array/src/cardinality.rs @@ -29,7 +29,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( Cardinality, cardinality, array, diff --git a/datafusion/functions-array/src/concat.rs b/datafusion/functions-array/src/concat.rs index f9d9bf4356ff..a6fed84fa765 100644 --- a/datafusion/functions-array/src/concat.rs +++ b/datafusion/functions-array/src/concat.rs @@ -36,7 +36,7 @@ use datafusion_expr::{ use crate::utils::{align_array_dimensions, check_datatypes, make_scalar_function}; -make_udf_function!( +make_udf_expr_and_func!( ArrayAppend, array_append, array element, // arg name @@ -96,7 +96,7 @@ impl ScalarUDFImpl for ArrayAppend { } } -make_udf_function!( +make_udf_expr_and_func!( ArrayPrepend, array_prepend, element array, @@ -156,7 +156,7 @@ impl ScalarUDFImpl for ArrayPrepend { } } -make_udf_function!( +make_udf_expr_and_func!( ArrayConcat, array_concat, "Concatenates arrays.", diff --git 
a/datafusion/functions-array/src/dimension.rs b/datafusion/functions-array/src/dimension.rs index 569eff66f7f4..1dc6520f1bc7 100644 --- a/datafusion/functions-array/src/dimension.rs +++ b/datafusion/functions-array/src/dimension.rs @@ -33,7 +33,7 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayDims, array_dims, array, @@ -88,7 +88,7 @@ impl ScalarUDFImpl for ArrayDims { } } -make_udf_function!( +make_udf_expr_and_func!( ArrayNdims, array_ndims, array, diff --git a/datafusion/functions-array/src/empty.rs b/datafusion/functions-array/src/empty.rs index d5fa174eee5f..9fe2c870496b 100644 --- a/datafusion/functions-array/src/empty.rs +++ b/datafusion/functions-array/src/empty.rs @@ -28,7 +28,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayEmpty, array_empty, array, diff --git a/datafusion/functions-array/src/except.rs b/datafusion/functions-array/src/except.rs index 444c7c758771..a56bab1e0611 100644 --- a/datafusion/functions-array/src/except.rs +++ b/datafusion/functions-array/src/except.rs @@ -31,7 +31,7 @@ use std::any::Any; use std::collections::HashSet; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayExcept, array_except, first_array second_array, diff --git a/datafusion/functions-array/src/extract.rs b/datafusion/functions-array/src/extract.rs index 0dbd106b6f18..842f4ec1b839 100644 --- a/datafusion/functions-array/src/extract.rs +++ b/datafusion/functions-array/src/extract.rs @@ -44,7 +44,7 @@ use std::sync::Arc; use crate::utils::make_scalar_function; // Create static instances of ScalarUDFs for each function -make_udf_function!( +make_udf_expr_and_func!( ArrayElement, array_element, array element, @@ -52,15 +52,9 @@ make_udf_function!( array_element_udf ); 
-make_udf_function!( - ArraySlice, - array_slice, - array begin end stride, - "returns a slice of the array.", - array_slice_udf -); +create_func!(ArraySlice, array_slice_udf); -make_udf_function!( +make_udf_expr_and_func!( ArrayPopFront, array_pop_front, array, @@ -68,7 +62,7 @@ make_udf_function!( array_pop_front_udf ); -make_udf_function!( +make_udf_expr_and_func!( ArrayPopBack, array_pop_back, array, @@ -224,6 +218,15 @@ where Ok(arrow::array::make_array(data)) } +#[doc = "returns a slice of the array."] +pub fn array_slice(array: Expr, begin: Expr, end: Expr, stride: Option<Expr>) -> Expr { + let args = match stride { + Some(stride) => vec![array, begin, end, stride], + None => vec![array, begin, end], + }; + array_slice_udf().call(args) +} + #[derive(Debug)] pub(super) struct ArraySlice { signature: Signature, diff --git a/datafusion/functions-array/src/flatten.rs b/datafusion/functions-array/src/flatten.rs index e2b50c6c02cc..294d41ada7c3 100644 --- a/datafusion/functions-array/src/flatten.rs +++ b/datafusion/functions-array/src/flatten.rs @@ -31,7 +31,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( Flatten, flatten, array, diff --git a/datafusion/functions-array/src/length.rs b/datafusion/functions-array/src/length.rs index 9bbd11950d21..9cdcaddf8dff 100644 --- a/datafusion/functions-array/src/length.rs +++ b/datafusion/functions-array/src/length.rs @@ -32,7 +32,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayLength, array_length, array, diff --git a/datafusion/functions-array/src/macros.rs b/datafusion/functions-array/src/macros.rs index c49f5830b8d5..4e00aa39bd84 100644 --- a/datafusion/functions-array/src/macros.rs +++ b/datafusion/functions-array/src/macros.rs @@ -19,8 +19,8 @@ /// /// 1.
Single `ScalarUDF` instance /// -/// Creates a singleton `ScalarUDF` of the `$UDF` function named `$GNAME` and a -/// function named `$NAME` which returns that function named $NAME. +/// Creates a singleton `ScalarUDF` of the `$UDF` function named `STATIC_$(UDF)` and a +/// function named `$SCALAR_UDF_FUNC` which returns that function named `STATIC_$(UDF)`. /// /// This is used to ensure creating the list of `ScalarUDF` only happens once. /// @@ -41,10 +41,9 @@ /// * `arg`: 0 or more named arguments for the function /// * `DOC`: documentation string for the function /// * `SCALAR_UDF_FUNC`: name of the function to create (just) the `ScalarUDF` -/// * `GNAME`: name for the single static instance of the `ScalarUDF` /// /// [`ScalarUDFImpl`]: datafusion_expr::ScalarUDFImpl -macro_rules! make_udf_function { +macro_rules! make_udf_expr_and_func { ($UDF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr , $SCALAR_UDF_FN:ident) => { paste::paste! { // "fluent expr_fn" style function @@ -55,25 +54,7 @@ macro_rules! make_udf_function { vec![$($arg),*], )) } - - /// Singleton instance of [`$UDF`], ensures the UDF is only created once - /// named STATIC_$(UDF). For example `STATIC_ArrayToString` - #[allow(non_upper_case_globals)] - static [< STATIC_ $UDF >]: std::sync::OnceLock<std::sync::Arc<datafusion_expr::ScalarUDF>> = - std::sync::OnceLock::new(); - - /// ScalarFunction that returns a [`ScalarUDF`] for [`$UDF`] - /// - /// [`ScalarUDF`]: datafusion_expr::ScalarUDF - pub fn $SCALAR_UDF_FN() -> std::sync::Arc<datafusion_expr::ScalarUDF> { - [< STATIC_ $UDF >] - .get_or_init(|| { - std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl( - <$UDF>::new(), - )) - }) - .clone() - } + create_func!($UDF, $SCALAR_UDF_FN); } }; ($UDF:ty, $EXPR_FN:ident, $DOC:expr , $SCALAR_UDF_FN:ident) => { @@ -86,7 +67,24 @@ macro_rules!
make_udf_function { arg, )) } + create_func!($UDF, $SCALAR_UDF_FN); + } + }; +} +/// Creates a singleton `ScalarUDF` of the `$UDF` function named `STATIC_$(UDF)` and a +/// function named `$SCALAR_UDF_FUNC` which returns that function named `STATIC_$(UDF)`. +/// +/// This is used to ensure creating the list of `ScalarUDF` only happens once. +/// +/// # Arguments +/// * `UDF`: name of the [`ScalarUDFImpl`] +/// * `SCALAR_UDF_FUNC`: name of the function to create (just) the `ScalarUDF` +/// +/// [`ScalarUDFImpl`]: datafusion_expr::ScalarUDFImpl +macro_rules! create_func { + ($UDF:ty, $SCALAR_UDF_FN:ident) => { + paste::paste! { /// Singleton instance of [`$UDF`], ensures the UDF is only created once /// named STATIC_$(UDF). For example `STATIC_ArrayToString` #[allow(non_upper_case_globals)] diff --git a/datafusion/functions-array/src/make_array.rs b/datafusion/functions-array/src/make_array.rs index 4f7dda933f42..4723464dfaf2 100644 --- a/datafusion/functions-array/src/make_array.rs +++ b/datafusion/functions-array/src/make_array.rs @@ -35,7 +35,7 @@ use datafusion_expr::{Expr, TypeSignature}; use crate::utils::make_scalar_function; -make_udf_function!( +make_udf_expr_and_func!( MakeArray, make_array, "Returns an Arrow array using the specified input expressions.", diff --git a/datafusion/functions-array/src/position.rs b/datafusion/functions-array/src/position.rs index a5a7a7405aa9..efdb7dff0ce6 100644 --- a/datafusion/functions-array/src/position.rs +++ b/datafusion/functions-array/src/position.rs @@ -37,7 +37,7 @@ use itertools::Itertools; use crate::utils::{compare_element_to_list, make_scalar_function}; -make_udf_function!( +make_udf_expr_and_func!( ArrayPosition, array_position, array element index, @@ -168,7 +168,7 @@ fn generic_position( Ok(Arc::new(UInt64Array::from(data))) } -make_udf_function!( +make_udf_expr_and_func!( ArrayPositions, array_positions, array element, // arg name diff --git a/datafusion/functions-array/src/range.rs 
b/datafusion/functions-array/src/range.rs index 150fe5960266..9a9829f96100 100644 --- a/datafusion/functions-array/src/range.rs +++ b/datafusion/functions-array/src/range.rs @@ -35,7 +35,7 @@ use datafusion_expr::{ use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( Range, range, start stop step, @@ -106,7 +106,7 @@ impl ScalarUDFImpl for Range { } } -make_udf_function!( +make_udf_expr_and_func!( GenSeries, gen_series, start stop step, diff --git a/datafusion/functions-array/src/remove.rs b/datafusion/functions-array/src/remove.rs index 21e373081054..7645c1a57573 100644 --- a/datafusion/functions-array/src/remove.rs +++ b/datafusion/functions-array/src/remove.rs @@ -32,7 +32,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayRemove, array_remove, array element, @@ -81,7 +81,7 @@ impl ScalarUDFImpl for ArrayRemove { } } -make_udf_function!( +make_udf_expr_and_func!( ArrayRemoveN, array_remove_n, array element max, @@ -130,7 +130,7 @@ impl ScalarUDFImpl for ArrayRemoveN { } } -make_udf_function!( +make_udf_expr_and_func!( ArrayRemoveAll, array_remove_all, array element, diff --git a/datafusion/functions-array/src/repeat.rs b/datafusion/functions-array/src/repeat.rs index 89b766bdcdfc..df623c114818 100644 --- a/datafusion/functions-array/src/repeat.rs +++ b/datafusion/functions-array/src/repeat.rs @@ -34,7 +34,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayRepeat, array_repeat, element count, // arg name diff --git a/datafusion/functions-array/src/replace.rs b/datafusion/functions-array/src/replace.rs index c32305bb454b..7cea4945836e 100644 --- a/datafusion/functions-array/src/replace.rs +++ b/datafusion/functions-array/src/replace.rs @@ -38,19 +38,19 @@ use std::any::Any; use 
std::sync::Arc; // Create static instances of ScalarUDFs for each function -make_udf_function!(ArrayReplace, +make_udf_expr_and_func!(ArrayReplace, array_replace, array from to, "replaces the first occurrence of the specified element with another specified element.", array_replace_udf ); -make_udf_function!(ArrayReplaceN, +make_udf_expr_and_func!(ArrayReplaceN, array_replace_n, array from to max, "replaces the first `max` occurrences of the specified element with another specified element.", array_replace_n_udf ); -make_udf_function!(ArrayReplaceAll, +make_udf_expr_and_func!(ArrayReplaceAll, array_replace_all, array from to, "replaces all occurrences of the specified element with another specified element.", diff --git a/datafusion/functions-array/src/resize.rs b/datafusion/functions-array/src/resize.rs index 561e98e8b76f..63f28c9afa77 100644 --- a/datafusion/functions-array/src/resize.rs +++ b/datafusion/functions-array/src/resize.rs @@ -30,7 +30,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayResize, array_resize, array size value, diff --git a/datafusion/functions-array/src/reverse.rs b/datafusion/functions-array/src/reverse.rs index 9be640565703..3076013899ef 100644 --- a/datafusion/functions-array/src/reverse.rs +++ b/datafusion/functions-array/src/reverse.rs @@ -30,7 +30,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayReverse, array_reverse, array, diff --git a/datafusion/functions-array/src/rewrite.rs b/datafusion/functions-array/src/rewrite.rs index 416e79cbc079..5280355a8224 100644 --- a/datafusion/functions-array/src/rewrite.rs +++ b/datafusion/functions-array/src/rewrite.rs @@ -171,7 +171,7 @@ impl FunctionRewrite for ArrayFunctionRewriter { stop, stride, }, - }) => Transformed::yes(array_slice(*expr, 
*start, *stop, *stride)), + }) => Transformed::yes(array_slice(*expr, *start, *stop, Some(*stride))), _ => Transformed::no(expr), }; diff --git a/datafusion/functions-array/src/set_ops.rs b/datafusion/functions-array/src/set_ops.rs index 5f3087fafd6f..40676b7cdcb8 100644 --- a/datafusion/functions-array/src/set_ops.rs +++ b/datafusion/functions-array/src/set_ops.rs @@ -37,7 +37,7 @@ use std::fmt::{Display, Formatter}; use std::sync::Arc; // Create static instances of ScalarUDFs for each function -make_udf_function!( +make_udf_expr_and_func!( ArrayUnion, array_union, array1 array2, @@ -45,7 +45,7 @@ make_udf_function!( array_union_udf ); -make_udf_function!( +make_udf_expr_and_func!( ArrayIntersect, array_intersect, first_array second_array, @@ -53,7 +53,7 @@ make_udf_function!( array_intersect_udf ); -make_udf_function!( +make_udf_expr_and_func!( ArrayDistinct, array_distinct, array, diff --git a/datafusion/functions-array/src/sort.rs b/datafusion/functions-array/src/sort.rs index af78712065fc..16f271ef10ff 100644 --- a/datafusion/functions-array/src/sort.rs +++ b/datafusion/functions-array/src/sort.rs @@ -30,7 +30,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArraySort, array_sort, array desc null_first, diff --git a/datafusion/functions-array/src/string.rs b/datafusion/functions-array/src/string.rs index 38059035005b..4122ddbd45eb 100644 --- a/datafusion/functions-array/src/string.rs +++ b/datafusion/functions-array/src/string.rs @@ -102,7 +102,7 @@ macro_rules! 
call_array_function { } // Create static instances of ScalarUDFs for each function -make_udf_function!( +make_udf_expr_and_func!( ArrayToString, array_to_string, array delimiter, // arg name @@ -160,7 +160,7 @@ impl ScalarUDFImpl for ArrayToString { } } -make_udf_function!( +make_udf_expr_and_func!( StringToArray, string_to_array, string delimiter null_string, // arg name diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 2927fd01d1b3..ec215937dca8 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -582,7 +582,13 @@ async fn roundtrip_expr_api() -> Result<()> { make_array(vec![lit(1), lit(2), lit(3)]), lit(1), lit(2), + Some(lit(1)), + ), + array_slice( + make_array(vec![lit(1), lit(2), lit(3)]), lit(1), + lit(2), + None, ), array_pop_front(make_array(vec![lit(1), lit(2), lit(3)])), array_pop_back(make_array(vec![lit(1), lit(2), lit(3)])),