Skip to content

Commit

Permalink
Add ScalarUDFImpl::invoke_with_args to support passing the return type created for the udf instance (#13290)
Browse files Browse the repository at this point in the history

* Added support for `ScalarUDFImpl::invoke_with_return_type` where the invoke is passed the return type created for the udf instance

* Do not yet deprecate invoke_batch, add docs to invoke_with_args

* add ticket reference

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
joseph-isaacs and alamb authored Nov 21, 2024
1 parent e7d9504 commit edbd93a
Show file tree
Hide file tree
Showing 13 changed files with 107 additions and 75 deletions.
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{
aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs,
};
pub use udf::{scalar_doc_sections, ScalarUDF, ScalarUDFImpl};
pub use udf::{scalar_doc_sections, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
pub use udf_docs::{DocSection, Documentation, DocumentationBuilder};
pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
Expand Down
98 changes: 57 additions & 41 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,7 @@ impl ScalarUDF {
self.inner.simplify(args, info)
}

/// Invoke the function on `args`, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke`] for more details.
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
#[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")]
pub fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
#[allow(deprecated)]
self.inner.invoke(args)
Expand All @@ -216,20 +213,27 @@ impl ScalarUDF {
self.inner.is_nullable(args, schema)
}

/// Invoke the function with `args` and number of rows, returning the appropriate result.
///
/// This is a thin wrapper: `args` and `number_rows` are forwarded unchanged
/// to the wrapped implementation's `invoke_batch`.
///
/// See [`ScalarUDFImpl::invoke_batch`] for more details.
#[deprecated(since = "43.0.0", note = "Use `invoke_with_args` instead")]
pub fn invoke_batch(
&self,
args: &[ColumnarValue],
number_rows: usize,
) -> Result<ColumnarValue> {
// Forwarding to the batch API is intentional here, so the deprecation
// lint is silenced for this single call.
#[allow(deprecated)]
self.inner.invoke_batch(args, number_rows)
}

/// Invoke the function on `args`, returning the appropriate result.
///
/// `args` is a [`ScalarFunctionArgs`] and is passed through unchanged to the
/// wrapped implementation.
///
/// See [`ScalarUDFImpl::invoke_with_args`] for details.
pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
self.inner.invoke_with_args(args)
}

/// Invoke the function without `args` but number of rows, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_no_args`] for more details.
/// Note: This method is deprecated and will be removed in future releases.
/// User defined functions should implement [`Self::invoke_with_args`] instead.
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
pub fn invoke_no_args(&self, number_rows: usize) -> Result<ColumnarValue> {
#[allow(deprecated)]
Expand Down Expand Up @@ -324,26 +328,37 @@ where
}
}

/// Trait for implementing [`ScalarUDF`].
/// Arguments passed to a scalar function invocation, bundled in one struct.
///
/// The borrowed fields share the lifetime `'a`.
pub struct ScalarFunctionArgs<'a> {
/// The evaluated arguments to the function.
pub args: &'a [ColumnarValue],
/// The number of rows in the record batch being evaluated.
pub number_rows: usize,
/// The return type of the scalar function, as obtained from `return_type`
/// or `return_type_from_exprs` when the physical expression was created
/// from the logical expression.
pub return_type: &'a DataType,
}

/// Trait for implementing user defined scalar functions.
///
/// This trait exposes the full API for implementing user defined functions and
/// can be used to implement any function.
///
/// See [`advanced_udf.rs`] for a full example with complete implementation and
/// [`ScalarUDF`] for other available options.
///
///
/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
///
/// # Basic Example
/// ```
/// # use std::any::Any;
/// # use std::sync::OnceLock;
/// # use arrow::datatypes::DataType;
/// # use datafusion_common::{DataFusionError, plan_err, Result};
/// # use datafusion_expr::{col, ColumnarValue, Documentation, Signature, Volatility};
/// # use datafusion_expr::{col, ColumnarValue, Documentation, ScalarFunctionArgs, Signature, Volatility};
/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
/// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
///
/// /// This struct for a simple UDF that adds one to an int32
/// #[derive(Debug)]
/// struct AddOne {
/// signature: Signature,
Expand All @@ -356,7 +371,7 @@ where
/// }
/// }
/// }
///
///
/// static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
///
/// fn get_doc() -> &'static Documentation {
Expand All @@ -383,7 +398,9 @@ where
/// Ok(DataType::Int32)
/// }
/// // The actual implementation would add one to the argument
/// fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { unimplemented!() }
/// fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
/// unimplemented!()
/// }
/// fn documentation(&self) -> Option<&Documentation> {
/// Some(get_doc())
/// }
Expand Down Expand Up @@ -479,24 +496,9 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {

/// Invoke the function on `args`, returning the appropriate result
///
/// The function will be invoked passed with the slice of [`ColumnarValue`]
/// (either scalar or array).
///
/// If the function does not take any arguments, please use [invoke_no_args]
/// instead and return [not_impl_err] for this function.
///
///
/// # Performance
///
/// For the best performance, the implementations of `invoke` should handle
/// the common case when one or more of their arguments are constant values
/// (aka [`ColumnarValue::Scalar`]).
///
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
/// to arrays, which will likely be simpler code, but be slower.
///
/// [invoke_no_args]: ScalarUDFImpl::invoke_no_args
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
/// Note: This method is deprecated and will be removed in future releases.
/// User defined functions should implement [`Self::invoke_with_args`] instead.
#[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")]
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
not_impl_err!(
"Function {} does not implement invoke but called",
Expand All @@ -507,17 +509,12 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
/// Invoke the function with `args` and the number of rows,
/// returning the appropriate result.
///
/// The function will be invoked with the slice of [`ColumnarValue`]
/// (either scalar or array).
///
/// # Performance
/// Note: See notes on [`Self::invoke_with_args`]
///
/// For the best performance, the implementations should handle the common case
/// when one or more of their arguments are constant values (aka
/// [`ColumnarValue::Scalar`]).
/// Note: This method is deprecated and will be removed in future releases.
/// User defined functions should implement [`Self::invoke_with_args`] instead.
///
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
/// to arrays, which will likely be simpler code, but be slower.
/// See <https://github.com/apache/datafusion/issues/13515> for more details.
fn invoke_batch(
&self,
args: &[ColumnarValue],
Expand All @@ -537,9 +534,27 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
}
}

/// Invoke the function returning the appropriate result.
///
/// The default implementation forwards `args.args` and `args.number_rows` to
/// `invoke_batch`; `args.return_type` is not used by the default.
///
/// # Performance
///
/// For the best performance, the implementations should handle the common case
/// when one or more of their arguments are constant values (aka
/// [`ColumnarValue::Scalar`]).
///
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
/// to arrays, which will likely be simpler code, but be slower.
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
// Falling back to `invoke_batch` is the intended default, so the
// deprecation lint is silenced for this single call.
#[allow(deprecated)]
self.invoke_batch(args.args, args.number_rows)
}

/// Invoke the function without `args`, instead the number of rows are provided,
/// returning the appropriate result.
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
///
/// Note: This method is deprecated and will be removed in future releases.
/// User defined functions should implement [`Self::invoke_with_args`] instead.
#[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")]
fn invoke_no_args(&self, _number_rows: usize) -> Result<ColumnarValue> {
not_impl_err!(
"Function {} does not implement invoke_no_args but called",
Expand Down Expand Up @@ -767,6 +782,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
args: &[ColumnarValue],
number_rows: usize,
) -> Result<ColumnarValue> {
#[allow(deprecated)]
self.inner.invoke_batch(args, number_rows)
}

Expand Down
2 changes: 2 additions & 0 deletions datafusion/functions/benches/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("random_1M_rows_batch_8192", |b| {
b.iter(|| {
for _ in 0..iterations {
#[allow(deprecated)] // TODO: migrate to invoke_with_args
black_box(random_func.invoke_batch(&[], 8192).unwrap());
}
})
Expand All @@ -39,6 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("random_1M_rows_batch_128", |b| {
b.iter(|| {
for _ in 0..iterations_128 {
#[allow(deprecated)] // TODO: migrate to invoke_with_args
black_box(random_func.invoke_batch(&[], 128).unwrap());
}
})
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions/src/core/version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ mod test {
#[tokio::test]
async fn test_version_udf() {
let version_udf = ScalarUDF::from(VersionFunc::new());
#[allow(deprecated)] // TODO: migrate to invoke_with_args
let version = version_udf.invoke_batch(&[], 1).unwrap();

if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(version))) = version {
Expand Down
9 changes: 7 additions & 2 deletions datafusion/functions/src/datetime/to_local_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ mod tests {
use arrow::datatypes::{DataType, TimeUnit};
use chrono::NaiveDateTime;
use datafusion_common::ScalarValue;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

use super::{adjust_to_local_time, ToLocalTimeFunc};

Expand Down Expand Up @@ -558,7 +558,11 @@ mod tests {

fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) {
let res = ToLocalTimeFunc::new()
.invoke_batch(&[ColumnarValue::Scalar(input)], 1)
.invoke_with_args(ScalarFunctionArgs {
args: &[ColumnarValue::Scalar(input)],
number_rows: 1,
return_type: &expected.data_type(),
})
.unwrap();
match res {
ColumnarValue::Scalar(res) => {
Expand Down Expand Up @@ -617,6 +621,7 @@ mod tests {
.map(|s| Some(string_to_timestamp_nanos(s).unwrap()))
.collect::<TimestampNanosecondArray>();
let batch_size = input.len();
#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = ToLocalTimeFunc::new()
.invoke_batch(&[ColumnarValue::Array(Arc::new(input))], batch_size)
.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions/src/datetime/to_timestamp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,7 @@ mod tests {
for array in arrays {
let rt = udf.return_type(&[array.data_type()]).unwrap();
assert!(matches!(rt, Timestamp(_, Some(_))));

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let res = udf
.invoke_batch(&[array.clone()], 1)
.expect("that to_timestamp parsed values without error");
Expand Down Expand Up @@ -1051,7 +1051,7 @@ mod tests {
for array in arrays {
let rt = udf.return_type(&[array.data_type()]).unwrap();
assert!(matches!(rt, Timestamp(_, None)));

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let res = udf
.invoke_batch(&[array.clone()], 1)
.expect("that to_timestamp parsed values without error");
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions/src/datetime/to_unixtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ impl ScalarUDFImpl for ToUnixtimeFunc {
DataType::Date64 | DataType::Date32 | DataType::Timestamp(_, None) => args[0]
.cast_to(&DataType::Timestamp(TimeUnit::Second, None), None)?
.cast_to(&DataType::Int64, None),
#[allow(deprecated)] // TODO: migrate to invoke_with_args
DataType::Utf8 => ToTimestampSecondsFunc::new()
.invoke_batch(args, batch_size)?
.cast_to(&DataType::Int64, None),
Expand Down
20 changes: 10 additions & 10 deletions datafusion/functions/src/math/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ mod tests {
]))), // num
ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))),
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let _ = LogFunc::new().invoke_batch(&args, 4);
}

Expand All @@ -286,7 +286,7 @@ mod tests {
let args = [
ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new().invoke_batch(&args, 1);
result.expect_err("expected error");
}
Expand All @@ -296,7 +296,7 @@ mod tests {
let args = [
ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 1)
.expect("failed to initialize function log");
Expand All @@ -320,7 +320,7 @@ mod tests {
let args = [
ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 1)
.expect("failed to initialize function log");
Expand All @@ -345,7 +345,7 @@ mod tests {
ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num
ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 1)
.expect("failed to initialize function log");
Expand All @@ -370,7 +370,7 @@ mod tests {
ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num
ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 1)
.expect("failed to initialize function log");
Expand All @@ -396,7 +396,7 @@ mod tests {
10.0, 100.0, 1000.0, 10000.0,
]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function log");
Expand Down Expand Up @@ -425,7 +425,7 @@ mod tests {
10.0, 100.0, 1000.0, 10000.0,
]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function log");
Expand Down Expand Up @@ -455,7 +455,7 @@ mod tests {
8.0, 4.0, 81.0, 625.0,
]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function log");
Expand Down Expand Up @@ -485,7 +485,7 @@ mod tests {
8.0, 4.0, 81.0, 625.0,
]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function log");
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions/src/math/power.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ mod tests {
ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0]))), // base
ColumnarValue::Array(Arc::new(Float64Array::from(vec![3.0, 2.0, 4.0, 4.0]))), // exponent
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = PowerFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function power");
Expand All @@ -232,7 +232,7 @@ mod tests {
ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base
ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = PowerFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function power");
Expand Down
Loading

0 comments on commit edbd93a

Please sign in to comment.