diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs index 491fac272c2c..dda6ba62e0af 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/simple_udf.rs @@ -28,6 +28,7 @@ use datafusion::error::Result; use datafusion::prelude::*; use datafusion_common::cast::as_float64_array; use datafusion_expr::ColumnarValue; +use datafusion_physical_expr::functions::columnar_values_to_array; use std::sync::Arc; /// create local execution context with an in-memory table: @@ -70,22 +71,11 @@ async fn main() -> Result<()> { // this is guaranteed by DataFusion based on the function's signature. assert_eq!(args.len(), 2); - // Try to obtain row number - let len = args - .iter() - .fold(Option::::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Array(a) => Some(a.len()), - }); - - let inferred_length = len.unwrap_or(1); - - let arg0 = args[0].clone().into_array(inferred_length)?; - let arg1 = args[1].clone().into_array(inferred_length)?; + let args = columnar_values_to_array(args)?; // 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics! - let base = as_float64_array(&arg0).expect("cast failed"); - let exponent = as_float64_array(&arg1).expect("cast failed"); + let base = as_float64_array(&args[0]).expect("cast failed"); + let exponent = as_float64_array(&args[1]).expect("cast failed"); // this is guaranteed by DataFusion. We place it just to make it obvious. assert_eq!(exponent.len(), base.len()); diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 561fe1d12d92..1c1228949171 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1376,6 +1376,7 @@ mod tests { use datafusion_physical_expr::execution_props::ExecutionProps; use chrono::{DateTime, TimeZone, Utc}; + use datafusion_physical_expr::functions::columnar_values_to_array; // ------------------------------ // --- ExprSimplifier tests ----- @@ -1489,30 +1490,10 @@ mod tests { let return_type = Arc::new(DataType::Int32); let fun = Arc::new(|args: &[ColumnarValue]| { - let len = args - .iter() - .fold(Option::::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Array(a) => Some(a.len()), - }); - - let inferred_length = len.unwrap_or(1); - - let arg0 = match &args[0] { - ColumnarValue::Array(array) => array.clone(), - ColumnarValue::Scalar(scalar) => { - scalar.to_array_of_size(inferred_length).unwrap() - } - }; - let arg1 = match &args[1] { - ColumnarValue::Array(array) => array.clone(), - ColumnarValue::Scalar(scalar) => { - scalar.to_array_of_size(inferred_length).unwrap() - } - }; + let args = columnar_values_to_array(args)?; - let arg0 = as_int32_array(&arg0)?; - let arg1 = as_int32_array(&arg1)?; + let arg0 = as_int32_array(&args[0])?; + let arg1 = as_int32_array(&args[1])?; // 2. perform the computation let array = arg0 diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index ac959dec6e89..2bfdf499123b 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -42,6 +42,7 @@ use arrow::{ compute::kernels::length::{bit_length, length}, datatypes::{DataType, Int32Type, Int64Type, Schema}, }; +use arrow_array::Array; use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; pub use datafusion_expr::FuncMonotonicity; use datafusion_expr::{ @@ -191,6 +192,51 @@ pub(crate) enum Hint { AcceptsSingular, } +/// A helper function used to infer the length of arguments of Scalar functions and convert +/// [`ColumnarValue`]s to [`ArrayRef`]s with the inferred length. Note that this function +/// only works for functions that accept either that all arguments are scalars or all arguments +/// are arrays with same length. Otherwise, it will return an error. +pub fn columnar_values_to_array(args: &[ColumnarValue]) -> Result> { + if args.is_empty() { + return Ok(vec![]); + } + + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) if acc.is_none() => Some(1), + ColumnarValue::Scalar(_) => { + if let Some(1) = acc { + acc + } else { + None + } + } + ColumnarValue::Array(a) => { + if let Some(l) = acc { + if l == a.len() { + acc + } else { + None + } + } else { + Some(a.len()) + } + } + }); + + let inferred_length = len.ok_or(DataFusionError::Internal( + "Arguments has mixed length".to_string(), + ))?; + + let args = args + .iter() + .map(|arg| arg.clone().into_array(inferred_length)) + .collect::>>()?; + + Ok(args) +} + /// Decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function /// and vice-versa after evaluation. /// Note that this function makes a scalar function with no arguments or all scalar inputs return a scalar. diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index 64dc25411deb..1824b23f9f9b 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -41,12 +41,12 @@ use std::sync::Arc; use datafusion::arrow::array::{ArrayRef, Int64Array}; use datafusion::common::Result; - use datafusion::common::cast::as_int64_array; +use datafusion::physical_plan::functions::columnar_values_to_array; -pub fn add_one(args: &[ArrayRef]) -> Result { +pub fn add_one(args: &[ColumnarValue]) -> Result { // Error handling omitted for brevity - + let args = columnar_values_to_array(args)?; let i64s = as_int64_array(&args[0])?; let new_array = i64s @@ -82,7 +82,6 @@ There is a lower level API with more functionality but is more complex, that is ```rust use datafusion::logical_expr::{Volatility, create_udf}; -use datafusion::physical_plan::functions::make_scalar_function; use datafusion::arrow::datatypes::DataType; use std::sync::Arc; @@ -91,13 +90,13 @@ let udf = create_udf( vec![DataType::Int64], Arc::new(DataType::Int64), Volatility::Immutable, - make_scalar_function(add_one), + Arc::new(add_one), ); ``` [`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html [`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html -[`make_scalar_function`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.make_scalar_function.html +[`process_scalar_func_inputs`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.process_scalar_func_inputs.html [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs A few things to note: