Skip to content

Commit

Permalink
Deprecate make_scalar_function (#8878)
Browse files Browse the repository at this point in the history
* Make make_scalar_function private

* More

* More

* Fix

* More

* Update datafusion/physical-expr/src/functions.rs

Co-authored-by: Andrew Lamb <[email protected]>

* For review

* For review

* Fix

* Update deprecated since tag

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
viirya and alamb authored Jan 22, 2024
1 parent 795c71f commit 903ef94
Show file tree
Hide file tree
Showing 11 changed files with 315 additions and 232 deletions.
29 changes: 20 additions & 9 deletions datafusion-examples/examples/simple_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ use datafusion::{
logical_expr::Volatility,
};

use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion::{error::Result, physical_plan::functions::make_scalar_function};
use datafusion_common::cast::as_float64_array;
use datafusion_expr::ColumnarValue;
use std::sync::Arc;

/// create local execution context with an in-memory table:
Expand Down Expand Up @@ -61,17 +62,30 @@ async fn main() -> Result<()> {
let ctx = create_context()?;

// First, declare the actual implementation of the calculation
let pow = |args: &[ArrayRef]| {
let pow = Arc::new(|args: &[ColumnarValue]| {
// in DataFusion, all `args` and output are dynamically-typed arrays, which means that we need to:
// 1. cast the values to the type we want
// 2. perform the computation for every element in the array (using a loop or SIMD) and construct the result

// this is guaranteed by DataFusion based on the function's signature.
assert_eq!(args.len(), 2);

// Try to obtain row number
let len = args
.iter()
.fold(Option::<usize>::None, |acc, arg| match arg {
ColumnarValue::Scalar(_) => acc,
ColumnarValue::Array(a) => Some(a.len()),
});

let inferred_length = len.unwrap_or(1);

let arg0 = args[0].clone().into_array(inferred_length)?;
let arg1 = args[1].clone().into_array(inferred_length)?;

// 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics!
let base = as_float64_array(&args[0]).expect("cast failed");
let exponent = as_float64_array(&args[1]).expect("cast failed");
let base = as_float64_array(&arg0).expect("cast failed");
let exponent = as_float64_array(&arg1).expect("cast failed");

// this is guaranteed by DataFusion. We place it just to make it obvious.
assert_eq!(exponent.len(), base.len());
Expand All @@ -92,11 +106,8 @@ async fn main() -> Result<()> {

// `Ok` because no error occurred during the calculation (we should add one if exponent was [0, 1[ and the base < 0 because that panics!)
// `Arc` because arrays are immutable, thread-safe, trait objects.
Ok(Arc::new(array) as ArrayRef)
};
// the function above expects an `ArrayRef`, but DataFusion may pass a scalar to a UDF.
// thus, we use `make_scalar_function` to decorare the closure so that it can handle both Arrays and Scalar values.
let pow = make_scalar_function(pow);
Ok(ColumnarValue::from(Arc::new(array) as ArrayRef))
});

// Next:
// * give it a name so that it shows nicely when the plan is printed
Expand Down
52 changes: 33 additions & 19 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ use arrow::compute::kernels::numeric::add;
use arrow_array::{ArrayRef, Float64Array, Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use datafusion::prelude::*;
use datafusion::{
execution::registry::FunctionRegistry,
physical_plan::functions::make_scalar_function, test_util,
};
use datafusion::{execution::registry::FunctionRegistry, test_util};
use datafusion_common::cast::as_float64_array;
use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, ScalarValue};
use datafusion_expr::{
Expand Down Expand Up @@ -87,12 +84,18 @@ async fn scalar_udf() -> Result<()> {

ctx.register_batch("t", batch)?;

let myfunc = |args: &[ArrayRef]| {
let l = as_int32_array(&args[0])?;
let r = as_int32_array(&args[1])?;
Ok(Arc::new(add(l, r)?) as ArrayRef)
};
let myfunc = make_scalar_function(myfunc);
let myfunc = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(l) = &args[0] else {
panic!("should be array")
};
let ColumnarValue::Array(r) = &args[1] else {
panic!("should be array")
};

let l = as_int32_array(l)?;
let r = as_int32_array(r)?;
Ok(ColumnarValue::from(Arc::new(add(l, r)?) as ArrayRef))
});

ctx.register_udf(create_udf(
"my_add",
Expand Down Expand Up @@ -163,11 +166,14 @@ async fn scalar_udf_zero_params() -> Result<()> {

ctx.register_batch("t", batch)?;
// create function just returns 100 regardless of inp
let myfunc = |args: &[ArrayRef]| {
let num_rows = args[0].len();
Ok(Arc::new((0..num_rows).map(|_| 100).collect::<Int32Array>()) as ArrayRef)
};
let myfunc = make_scalar_function(myfunc);
let myfunc = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Scalar(_) = &args[0] else {
panic!("expect scalar")
};
Ok(ColumnarValue::Array(
Arc::new((0..1).map(|_| 100).collect::<Int32Array>()) as ArrayRef,
))
});

ctx.register_udf(create_udf(
"get_100",
Expand Down Expand Up @@ -307,8 +313,12 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> {
let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?;
ctx.register_batch("t", batch).unwrap();

let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
let myfunc = make_scalar_function(myfunc);
let myfunc = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(array) = &args[0] else {
panic!("should be array")
};
Ok(ColumnarValue::from(Arc::clone(array)))
});

ctx.register_udf(create_udf(
"MY_FUNC",
Expand Down Expand Up @@ -348,8 +358,12 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?;
ctx.register_batch("t", batch).unwrap();

let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
let myfunc = make_scalar_function(myfunc);
let myfunc = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(array) = &args[0] else {
panic!("should be array")
};
Ok(ColumnarValue::from(Arc::clone(array)))
});

let udf = create_udf(
"dummy",
Expand Down
12 changes: 12 additions & 0 deletions datafusion/expr/src/columnar_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@ pub enum ColumnarValue {
Scalar(ScalarValue),
}

impl From<ArrayRef> for ColumnarValue {
fn from(value: ArrayRef) -> Self {
ColumnarValue::Array(value)
}
}

impl From<ScalarValue> for ColumnarValue {
fn from(value: ScalarValue) -> Self {
ColumnarValue::Scalar(value)
}
}

impl ColumnarValue {
pub fn data_type(&self) -> DataType {
match self {
Expand Down
37 changes: 28 additions & 9 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1321,9 +1321,7 @@ mod tests {
assert_contains, cast::as_int32_array, plan_datafusion_err, DFField, ToDFSchema,
};
use datafusion_expr::{interval_arithmetic::Interval, *};
use datafusion_physical_expr::{
execution_props::ExecutionProps, functions::make_scalar_function,
};
use datafusion_physical_expr::execution_props::ExecutionProps;

use chrono::{DateTime, TimeZone, Utc};

Expand Down Expand Up @@ -1438,9 +1436,31 @@ mod tests {
let input_types = vec![DataType::Int32, DataType::Int32];
let return_type = Arc::new(DataType::Int32);

let fun = |args: &[ArrayRef]| {
let arg0 = as_int32_array(&args[0])?;
let arg1 = as_int32_array(&args[1])?;
let fun = Arc::new(|args: &[ColumnarValue]| {
let len = args
.iter()
.fold(Option::<usize>::None, |acc, arg| match arg {
ColumnarValue::Scalar(_) => acc,
ColumnarValue::Array(a) => Some(a.len()),
});

let inferred_length = len.unwrap_or(1);

let arg0 = match &args[0] {
ColumnarValue::Array(array) => array.clone(),
ColumnarValue::Scalar(scalar) => {
scalar.to_array_of_size(inferred_length).unwrap()
}
};
let arg1 = match &args[1] {
ColumnarValue::Array(array) => array.clone(),
ColumnarValue::Scalar(scalar) => {
scalar.to_array_of_size(inferred_length).unwrap()
}
};

let arg0 = as_int32_array(&arg0)?;
let arg1 = as_int32_array(&arg1)?;

// 2. perform the computation
let array = arg0
Expand All @@ -1456,10 +1476,9 @@ mod tests {
})
.collect::<Int32Array>();

Ok(Arc::new(array) as ArrayRef)
};
Ok(ColumnarValue::from(Arc::new(array) as ArrayRef))
});

let fun = make_scalar_function(fun);
Arc::new(create_udf(
"udf_add",
input_types,
Expand Down
Loading

0 comments on commit 903ef94

Please sign in to comment.