From 417c8fedefb81675988c3a8974cd4f07eaa21b18 Mon Sep 17 00:00:00 2001 From: Yanxin Xiang Date: Mon, 26 Feb 2024 19:27:11 -0600 Subject: [PATCH] port range function and change gen_series logic --- datafusion/expr/src/built_in_function.rs | 15 --- datafusion/expr/src/expr_fn.rs | 6 - datafusion/functions-array/src/kernels.rs | 69 ++++++++++- datafusion/functions-array/src/lib.rs | 6 +- datafusion/functions-array/src/udf.rs | 114 +++++++++++++++++- .../physical-expr/src/array_expressions.rs | 1 - datafusion/physical-expr/src/functions.rs | 3 - datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 3 - datafusion/proto/src/generated/prost.rs | 4 +- .../proto/src/logical_plan/from_proto.rs | 11 +- datafusion/proto/src/logical_plan/to_proto.rs | 1 - .../test_files/range_and_gen_series.slt | 48 ++++++++ 13 files changed, 236 insertions(+), 47 deletions(-) create mode 100644 datafusion/sqllogictest/test_files/range_and_gen_series.slt diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 8b4e65121c79f..08547b5af4280 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -186,8 +186,6 @@ pub enum BuiltinScalarFunction { MakeArray, /// Flatten Flatten, - /// Range - Range, // struct functions /// struct @@ -431,7 +429,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable, BuiltinScalarFunction::ArrayUnion => Volatility::Immutable, BuiltinScalarFunction::ArrayResize => Volatility::Immutable, - BuiltinScalarFunction::Range => Volatility::Immutable, BuiltinScalarFunction::Cardinality => Volatility::Immutable, BuiltinScalarFunction::MakeArray => Volatility::Immutable, BuiltinScalarFunction::Ascii => Volatility::Immutable, @@ -643,9 +640,6 @@ impl BuiltinScalarFunction { (dt, _) => Ok(dt), } } - BuiltinScalarFunction::Range => { - Ok(List(Arc::new(Field::new("item", Int64, true)))) - } BuiltinScalarFunction::ArrayExcept => { match (input_expr_types[0].clone(), input_expr_types[1].clone()) { (DataType::Null, _) | (_, DataType::Null) => { @@ -987,14 +981,6 @@ impl BuiltinScalarFunction { Signature::variadic_any(self.volatility()) } - BuiltinScalarFunction::Range => Signature::one_of( - vec![ - Exact(vec![Int64]), - Exact(vec![Int64, Int64]), - Exact(vec![Int64, Int64, Int64]), - ], - self.volatility(), - ), BuiltinScalarFunction::Struct => Signature::variadic_any(self.volatility()), BuiltinScalarFunction::Concat | BuiltinScalarFunction::ConcatWithSeparator => { @@ -1628,7 +1614,6 @@ impl BuiltinScalarFunction { &["array_intersect", "list_intersect"] } BuiltinScalarFunction::OverLay => &["overlay"], - BuiltinScalarFunction::Range => &["range", "generate_series"], // struct functions BuiltinScalarFunction::Struct => &["struct"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 4aa270e6dde6f..431af50922d99 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -767,12 +767,6 @@ scalar_expr!( "Returns an array of the elements in the intersection of array1 and array2." ); -nary_scalar_expr!( - Range, - gen_range, - "Returns a list of values in the range between start and stop with step." -); - // string functions scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character"); scalar_expr!( diff --git a/datafusion/functions-array/src/kernels.rs b/datafusion/functions-array/src/kernels.rs index 1b96e01d8b9a9..64c99c2417ad3 100644 --- a/datafusion/functions-array/src/kernels.rs +++ b/datafusion/functions-array/src/kernels.rs @@ -23,11 +23,12 @@ use arrow::array::{ StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; use arrow::datatypes::DataType; -use datafusion_common::cast::{as_large_list_array, as_list_array, as_string_array}; +use datafusion_common::cast::{ + as_int64_array, as_large_list_array, as_list_array, as_string_array, +}; use datafusion_common::{exec_err, DataFusionError}; use std::any::type_name; use std::sync::Arc; - macro_rules! downcast_arg { ($ARG:expr, $ARRAY_TYPE:ident) => {{ $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { @@ -252,3 +253,67 @@ pub(super) fn array_to_string(args: &[ArrayRef]) -> datafusion_common::Result` representing the resulting ListArray after the operation. +/// +/// # Arguments +/// +/// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero.) values. +/// +/// # Examples +/// +/// gen_range(3) => [0, 1, 2] +/// gen_range(1, 4) => [1, 2, 3] +/// gen_range(1, 7, 2) => [1, 3, 5] +pub fn gen_range( + args: &[ArrayRef], + include_upper: i64, +) -> datafusion_common::Result { + let (start_array, stop_array, step_array) = match args.len() { + 1 => (None, as_int64_array(&args[0])?, None), + 2 => ( + Some(as_int64_array(&args[0])?), + as_int64_array(&args[1])?, + None, + ), + 3 => ( + Some(as_int64_array(&args[0])?), + as_int64_array(&args[1])?, + Some(as_int64_array(&args[2])?), + ), + _ => return exec_err!("gen_range expects 1 to 3 arguments"), + }; + + let mut values = vec![]; + let mut offsets = vec![0]; + for (idx, stop) in stop_array.iter().enumerate() { + let stop = stop.unwrap_or(0) + include_upper; + let start = start_array.as_ref().map(|arr| arr.value(idx)).unwrap_or(0); + let step = step_array.as_ref().map(|arr| arr.value(idx)).unwrap_or(1); + if step == 0 { + return exec_err!("step can't be 0 for function range(start [, stop, step]"); + } + if step < 0 { + // Decreasing range + values.extend((stop + 1..start + 1).rev().step_by((-step) as usize)); + } else { + // Increasing range + values.extend((start..stop).step_by(step as usize)); + } + + offsets.push(values.len() as i32); + } + let arr = Arc::new(ListArray::try_new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(Int64Array::from(values)), + None, + )?); + Ok(arr) +} diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs index 84997ed10e323..52ee35211888f 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-array/src/lib.rs @@ -44,7 +44,11 @@ pub mod expr_fn { /// Registers all enabled packages with a [`FunctionRegistry`] pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { - let functions: Vec> = vec![udf::array_to_string_udf()]; + let functions: Vec> = vec![ + udf::array_to_string_udf(), + udf::range_udf(), + udf::gen_series_udf(), + ]; functions.into_iter().try_for_each(|udf| { let existing_udf = registry.register_udf(udf)?; if let Some(existing_udf) = existing_udf { diff --git a/datafusion/functions-array/src/udf.rs b/datafusion/functions-array/src/udf.rs index b7f9d2497fb70..7bce44950b04b 100644 --- a/datafusion/functions-array/src/udf.rs +++ b/datafusion/functions-array/src/udf.rs @@ -18,12 +18,14 @@ //! [`ScalarUDFImpl`] definitions for array functions. use arrow::datatypes::DataType; +use arrow::datatypes::Field; use datafusion_common::{plan_err, DataFusionError}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::Expr; +use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; - +use std::sync::Arc; // Create static instances of ScalarUDFs for each function make_udf_function!(ArrayToString, array_to_string, @@ -31,7 +33,6 @@ make_udf_function!(ArrayToString, "converts each element to its text representation.", // doc array_to_string_udf // internal function name ); - #[derive(Debug)] pub(super) struct ArrayToString { signature: Signature, @@ -83,3 +84,112 @@ impl ScalarUDFImpl for ArrayToString { &self.aliases } } + +make_udf_function!( + Range, + range, + input diamilter, + "create a list of values in the range between start and stop", + range_udf +); +#[derive(Debug)] +pub(super) struct Range { + signature: Signature, + aliases: Vec, +} +impl Range { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Int64]), + Exact(vec![Int64, Int64]), + Exact(vec![Int64, Int64, Int64]), + ], + Volatility::Immutable, + ), + aliases: vec![String::from("range")], + } + } +} +impl ScalarUDFImpl for Range { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "range" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + use DataType::*; + Ok(List(Arc::new(Field::new("item", Int64, true)))) + } + + fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + let args = ColumnarValue::values_to_arrays(args)?; + crate::kernels::gen_range(&args, 0).map(ColumnarValue::Array) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} +make_udf_function!( + GenSeries, + gen_series, + input diamilter, + "create a list of values in the range between start and stop, include upper bound", + gen_series_udf +); +#[derive(Debug)] +pub(super) struct GenSeries { + signature: Signature, + aliases: Vec, +} +impl GenSeries { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Int64]), + Exact(vec![Int64, Int64]), + Exact(vec![Int64, Int64, Int64]), + ], + Volatility::Immutable, + ), + aliases: vec![String::from("generate_series")], + } + } +} +impl ScalarUDFImpl for GenSeries { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "generate_series" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + use DataType::*; + Ok(List(Arc::new(Field::new("item", Int64, true)))) + } + + fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + let args = ColumnarValue::values_to_arrays(args)?; + crate::kernels::gen_range(&args, 1).map(ColumnarValue::Array) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 38a4359b4f4b6..d6cd03ced40aa 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -39,7 +39,6 @@ use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, }; - use itertools::Itertools; macro_rules! downcast_arg { diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 8446a65d72c8c..6c4f87b5b3a03 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -410,9 +410,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayIntersect => Arc::new(|args| { make_scalar_function_inner(array_expressions::array_intersect)(args) }), - BuiltinScalarFunction::Range => Arc::new(|args| { - make_scalar_function_inner(array_expressions::gen_range)(args) - }), BuiltinScalarFunction::Cardinality => Arc::new(|args| { make_scalar_function_inner(array_expressions::cardinality)(args) }), diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 7673ce86ae1db..483d047291fd7 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -665,7 +665,7 @@ enum ScalarFunction { ArrayIntersect = 119; ArrayUnion = 120; OverLay = 121; - Range = 122; + /// 122 is Range ArrayExcept = 123; ArrayPopFront = 124; Levenshtein = 125; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 65483f9ac4678..8ca0a7ba08966 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22438,7 +22438,6 @@ impl serde::Serialize for ScalarFunction { Self::ArrayIntersect => "ArrayIntersect", Self::ArrayUnion => "ArrayUnion", Self::OverLay => "OverLay", - Self::Range => "Range", Self::ArrayExcept => "ArrayExcept", Self::ArrayPopFront => "ArrayPopFront", Self::Levenshtein => "Levenshtein", @@ -22581,7 +22580,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayIntersect", "ArrayUnion", "OverLay", - "Range", "ArrayExcept", "ArrayPopFront", "Levenshtein", @@ -22753,7 +22751,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayIntersect" => Ok(ScalarFunction::ArrayIntersect), "ArrayUnion" => Ok(ScalarFunction::ArrayUnion), "OverLay" => Ok(ScalarFunction::OverLay), - "Range" => Ok(ScalarFunction::Range), "ArrayExcept" => Ok(ScalarFunction::ArrayExcept), "ArrayPopFront" => Ok(ScalarFunction::ArrayPopFront), "Levenshtein" => Ok(ScalarFunction::Levenshtein), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index a567269e33568..026f477cce71d 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2753,7 +2753,7 @@ pub enum ScalarFunction { ArrayIntersect = 119, ArrayUnion = 120, OverLay = 121, - Range = 122, + /// / 122 is Range ArrayExcept = 123, ArrayPopFront = 124, Levenshtein = 125, @@ -2893,7 +2893,6 @@ impl ScalarFunction { ScalarFunction::ArrayIntersect => "ArrayIntersect", ScalarFunction::ArrayUnion => "ArrayUnion", ScalarFunction::OverLay => "OverLay", - ScalarFunction::Range => "Range", ScalarFunction::ArrayExcept => "ArrayExcept", ScalarFunction::ArrayPopFront => "ArrayPopFront", ScalarFunction::Levenshtein => "Levenshtein", @@ -3030,7 +3029,6 @@ impl ScalarFunction { "ArrayIntersect" => Some(Self::ArrayIntersect), "ArrayUnion" => Some(Self::ArrayUnion), "OverLay" => Some(Self::OverLay), - "Range" => Some(Self::Range), "ArrayExcept" => Some(Self::ArrayExcept), "ArrayPopFront" => Some(Self::ArrayPopFront), "Levenshtein" => Some(Self::Levenshtein), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 2554018a9273a..fb8c7b28236f0 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -57,8 +57,8 @@ use datafusion_expr::{ chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part, date_trunc, degrees, digest, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, initcap, - instr, iszero, lcm, left, levenshtein, ln, log, log10, log2, + factorial, find_in_set, flatten, floor, from_unixtime, gcd, initcap, instr, iszero, + lcm, left, levenshtein, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, lower, lpad, ltrim, md5, nanvl, now, octet_length, overlay, pi, power, radians, random, regexp_like, regexp_match, regexp_replace, repeat, replace, reverse, right, @@ -509,7 +509,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrayIntersect => Self::ArrayIntersect, ScalarFunction::ArrayUnion => Self::ArrayUnion, ScalarFunction::ArrayResize => Self::ArrayResize, - ScalarFunction::Range => Self::Range, ScalarFunction::Cardinality => Self::Cardinality, ScalarFunction::Array => Self::MakeArray, ScalarFunction::DatePart => Self::DatePart, @@ -1466,12 +1465,6 @@ pub fn parse_expr( parse_expr(&args[2], registry)?, parse_expr(&args[3], registry)?, )), - ScalarFunction::Range => Ok(gen_range( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry)) - .collect::, _>>()?, - )), ScalarFunction::Cardinality => { Ok(cardinality(parse_expr(&args[0], registry)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index ccadbb217a581..9793a04eea592 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1491,7 +1491,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArraySlice => Self::ArraySlice, BuiltinScalarFunction::ArrayIntersect => Self::ArrayIntersect, BuiltinScalarFunction::ArrayUnion => Self::ArrayUnion, - BuiltinScalarFunction::Range => Self::Range, BuiltinScalarFunction::Cardinality => Self::Cardinality, BuiltinScalarFunction::MakeArray => Self::Array, BuiltinScalarFunction::DatePart => Self::DatePart, diff --git a/datafusion/sqllogictest/test_files/range_and_gen_series.slt b/datafusion/sqllogictest/test_files/range_and_gen_series.slt new file mode 100644 index 0000000000000..1e385b8ad44c8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/range_and_gen_series.slt @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query ? +SELECT range(5); +---- +[0, 1, 2, 3, 4] + + +query ? +SELECT range(2, 5); +---- +[2, 3, 4] + + +query ? +SELECT range(2, 5, 3); +---- +[2] + +query ? +SELECT generate_series(5); +---- +[0, 1, 2, 3, 4, 5] + +query ? +SELECT generate_series(2, 5); +---- +[2, 3, 4, 5] + +query ? +SELECT generate_series(2, 5, 3); +---- +[2, 5]