Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ScalarUDFImpl::invoke_with_args to support passing the return type created for the udf instance #13290

Merged
merged 4 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{
aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs,
};
pub use udf::{scalar_doc_sections, ScalarUDF, ScalarUDFImpl};
pub use udf::{scalar_doc_sections, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
pub use udf_docs::{DocSection, Documentation, DocumentationBuilder};
pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
Expand Down
98 changes: 57 additions & 41 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,7 @@ impl ScalarUDF {
self.inner.simplify(args, info)
}

/// Invoke the function on `args`, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke`] for more details.
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
#[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")]
pub fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
#[allow(deprecated)]
self.inner.invoke(args)
Expand All @@ -216,20 +213,27 @@ impl ScalarUDF {
self.inner.is_nullable(args, schema)
}

/// Invoke the function with `args` and number of rows, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_batch`] for more details.
#[deprecated(since = "43.0.0", note = "Use `invoke_with_args` instead")]
pub fn invoke_batch(
&self,
args: &[ColumnarValue],
number_rows: usize,
) -> Result<ColumnarValue> {
#[allow(deprecated)]
self.inner.invoke_batch(args, number_rows)
}

/// Invoke the function on `args`, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_with_args`] for details.
pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
self.inner.invoke_with_args(args)
}

/// Invoke the function without `args` but number of rows, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_no_args`] for more details.
/// Note: This method is deprecated and will be removed in future releases.
/// User defined functions should implement [`Self::invoke_with_args`] instead.
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
pub fn invoke_no_args(&self, number_rows: usize) -> Result<ColumnarValue> {
#[allow(deprecated)]
Expand Down Expand Up @@ -324,26 +328,37 @@ where
}
}

/// Trait for implementing [`ScalarUDF`].
pub struct ScalarFunctionArgs<'a> {
Copy link
Contributor

@alamb alamb Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this structure also provides a potential place for "preparing" a scalar function (e.g. to pre-compile a regex 🤔 ) - #8051

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you suggest having a opaque block returned from a (new interface method) which is that that value to invoke?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I was thinking something like

trait ScalarUDFImpl { 
  /// prepares to run the function, returning any state (such as 
  /// a precompiled regex). Called once per instance of function in the query
  fn prepare(&self, args: &ScalarFunctionArgs) -> Option<Box<dyn Any>> { None }

  /// `prepared` field in ScalarFunctonArgs has the result of calling `prepare`
  fn invoke_with_args(&self, args: &ScalarFunctionArgs) -> ...
pub struct ScalarFunctionArgs<'a> {
  ...
  /// The result from a call to `prepare` for this function instance
  prepared: Option<Box<dyn Any>>,
}

// The evaluated arguments to the function
pub args: &'a [ColumnarValue],
// The number of rows in record batch being evaluated
pub number_rows: usize,
// The return type of the scalar function returned (from `return_type` or `return_type_from_exprs`)
// when creating the physical expression from the logical expression
pub return_type: &'a DataType,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How did I not see this till now?

Please please please add SessionContext or at least SessionConfig to this. It would allow us to unblock so so many tickets

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should also be easy to add as a follow-up?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, as long as it's done in this release cycle so that we don't churn the api even more between cycles

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes -- I will file a ticket

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tracking with #13519

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(BTW the change in the design to a struct I think allows for much easier non-breaking API additions in the future)

}

/// Trait for implementing user defined scalar functions.
///
/// This trait exposes the full API for implementing user defined functions and
/// can be used to implement any function.
///
/// See [`advanced_udf.rs`] for a full example with complete implementation and
/// [`ScalarUDF`] for other available options.
///
///
/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
///
/// # Basic Example
/// ```
/// # use std::any::Any;
/// # use std::sync::OnceLock;
/// # use arrow::datatypes::DataType;
/// # use datafusion_common::{DataFusionError, plan_err, Result};
/// # use datafusion_expr::{col, ColumnarValue, Documentation, Signature, Volatility};
/// # use datafusion_expr::{col, ColumnarValue, Documentation, ScalarFunctionArgs, Signature, Volatility};
/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
/// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
///
/// /// This struct for a simple UDF that adds one to an int32
/// #[derive(Debug)]
/// struct AddOne {
/// signature: Signature,
Expand All @@ -356,7 +371,7 @@ where
/// }
/// }
/// }
///
///
/// static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
///
/// fn get_doc() -> &'static Documentation {
Expand All @@ -383,7 +398,9 @@ where
/// Ok(DataType::Int32)
/// }
/// // The actual implementation would add one to the argument
/// fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { unimplemented!() }
/// fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
/// unimplemented!()
/// }
/// fn documentation(&self) -> Option<&Documentation> {
/// Some(get_doc())
/// }
Expand Down Expand Up @@ -479,24 +496,9 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {

/// Invoke the function on `args`, returning the appropriate result
///
/// The function will be invoked passed with the slice of [`ColumnarValue`]
/// (either scalar or array).
///
/// If the function does not take any arguments, please use [invoke_no_args]
/// instead and return [not_impl_err] for this function.
///
///
/// # Performance
///
/// For the best performance, the implementations of `invoke` should handle
/// the common case when one or more of their arguments are constant values
/// (aka [`ColumnarValue::Scalar`]).
///
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
/// to arrays, which will likely be simpler code, but be slower.
///
/// [invoke_no_args]: ScalarUDFImpl::invoke_no_args
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
/// Note: This method is deprecated and will be removed in future releases.
/// User defined functions should implement [`Self::invoke_with_args`] instead.
#[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")]
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
not_impl_err!(
"Function {} does not implement invoke but called",
Expand All @@ -507,17 +509,12 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
/// Invoke the function with `args` and the number of rows,
/// returning the appropriate result.
///
/// The function will be invoked with the slice of [`ColumnarValue`]
/// (either scalar or array).
///
/// # Performance
/// Note: See notes on [`Self::invoke_with_args`]
///
/// For the best performance, the implementations should handle the common case
/// when one or more of their arguments are constant values (aka
/// [`ColumnarValue::Scalar`]).
/// Note: This method is deprecated and will be removed in future releases.
/// User defined functions should implement [`Self::invoke_with_args`] instead.
///
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
/// to arrays, which will likely be simpler code, but be slower.
/// See <https://github.com/apache/datafusion/issues/13515> for more details.
fn invoke_batch(
&self,
args: &[ColumnarValue],
Expand All @@ -537,9 +534,27 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
}
}

/// Invoke the function returning the appropriate result.
///
/// # Performance
///
/// For the best performance, the implementations should handle the common case
/// when one or more of their arguments are constant values (aka
/// [`ColumnarValue::Scalar`]).
///
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
/// to arrays, which will likely be simpler code, but be slower.
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
#[allow(deprecated)]
self.invoke_batch(args.args, args.number_rows)
}

/// Invoke the function without `args`, instead the number of rows are provided,
/// returning the appropriate result.
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
///
/// Note: This method is deprecated and will be removed in future releases.
/// User defined functions should implement [`Self::invoke_with_args`] instead.
#[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")]
fn invoke_no_args(&self, _number_rows: usize) -> Result<ColumnarValue> {
not_impl_err!(
"Function {} does not implement invoke_no_args but called",
Expand Down Expand Up @@ -767,6 +782,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
args: &[ColumnarValue],
number_rows: usize,
) -> Result<ColumnarValue> {
#[allow(deprecated)]
self.inner.invoke_batch(args, number_rows)
}

Expand Down
2 changes: 2 additions & 0 deletions datafusion/functions/benches/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("random_1M_rows_batch_8192", |b| {
b.iter(|| {
for _ in 0..iterations {
#[allow(deprecated)] // TODO: migrate to invoke_with_args
black_box(random_func.invoke_batch(&[], 8192).unwrap());
}
})
Expand All @@ -39,6 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("random_1M_rows_batch_128", |b| {
b.iter(|| {
for _ in 0..iterations_128 {
#[allow(deprecated)] // TODO: migrate to invoke_with_args
black_box(random_func.invoke_batch(&[], 128).unwrap());
}
})
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions/src/core/version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ mod test {
#[tokio::test]
async fn test_version_udf() {
let version_udf = ScalarUDF::from(VersionFunc::new());
#[allow(deprecated)] // TODO: migrate to invoke_with_args
let version = version_udf.invoke_batch(&[], 1).unwrap();

if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(version))) = version {
Expand Down
9 changes: 7 additions & 2 deletions datafusion/functions/src/datetime/to_local_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ mod tests {
use arrow::datatypes::{DataType, TimeUnit};
use chrono::NaiveDateTime;
use datafusion_common::ScalarValue;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

use super::{adjust_to_local_time, ToLocalTimeFunc};

Expand Down Expand Up @@ -558,7 +558,11 @@ mod tests {

fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) {
let res = ToLocalTimeFunc::new()
.invoke_batch(&[ColumnarValue::Scalar(input)], 1)
.invoke_with_args(ScalarFunctionArgs {
args: &[ColumnarValue::Scalar(input)],
number_rows: 1,
return_type: &expected.data_type(),
})
.unwrap();
match res {
ColumnarValue::Scalar(res) => {
Expand Down Expand Up @@ -617,6 +621,7 @@ mod tests {
.map(|s| Some(string_to_timestamp_nanos(s).unwrap()))
.collect::<TimestampNanosecondArray>();
let batch_size = input.len();
#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = ToLocalTimeFunc::new()
.invoke_batch(&[ColumnarValue::Array(Arc::new(input))], batch_size)
.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions/src/datetime/to_timestamp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,7 @@ mod tests {
for array in arrays {
let rt = udf.return_type(&[array.data_type()]).unwrap();
assert!(matches!(rt, Timestamp(_, Some(_))));

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let res = udf
.invoke_batch(&[array.clone()], 1)
.expect("that to_timestamp parsed values without error");
Expand Down Expand Up @@ -1051,7 +1051,7 @@ mod tests {
for array in arrays {
let rt = udf.return_type(&[array.data_type()]).unwrap();
assert!(matches!(rt, Timestamp(_, None)));

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let res = udf
.invoke_batch(&[array.clone()], 1)
.expect("that to_timestamp parsed values without error");
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions/src/datetime/to_unixtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ impl ScalarUDFImpl for ToUnixtimeFunc {
DataType::Date64 | DataType::Date32 | DataType::Timestamp(_, None) => args[0]
.cast_to(&DataType::Timestamp(TimeUnit::Second, None), None)?
.cast_to(&DataType::Int64, None),
#[allow(deprecated)] // TODO: migrate to invoke_with_args
DataType::Utf8 => ToTimestampSecondsFunc::new()
.invoke_batch(args, batch_size)?
.cast_to(&DataType::Int64, None),
Expand Down
20 changes: 10 additions & 10 deletions datafusion/functions/src/math/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ mod tests {
]))), // num
ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))),
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let _ = LogFunc::new().invoke_batch(&args, 4);
}

Expand All @@ -286,7 +286,7 @@ mod tests {
let args = [
ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new().invoke_batch(&args, 1);
result.expect_err("expected error");
}
Expand All @@ -296,7 +296,7 @@ mod tests {
let args = [
ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 1)
.expect("failed to initialize function log");
Expand All @@ -320,7 +320,7 @@ mod tests {
let args = [
ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 1)
.expect("failed to initialize function log");
Expand All @@ -345,7 +345,7 @@ mod tests {
ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num
ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 1)
.expect("failed to initialize function log");
Expand All @@ -370,7 +370,7 @@ mod tests {
ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num
ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 1)
.expect("failed to initialize function log");
Expand All @@ -396,7 +396,7 @@ mod tests {
10.0, 100.0, 1000.0, 10000.0,
]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function log");
Expand Down Expand Up @@ -425,7 +425,7 @@ mod tests {
10.0, 100.0, 1000.0, 10000.0,
]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function log");
Expand Down Expand Up @@ -455,7 +455,7 @@ mod tests {
8.0, 4.0, 81.0, 625.0,
]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function log");
Expand Down Expand Up @@ -485,7 +485,7 @@ mod tests {
8.0, 4.0, 81.0, 625.0,
]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function log");
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions/src/math/power.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ mod tests {
ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0]))), // base
ColumnarValue::Array(Arc::new(Float64Array::from(vec![3.0, 2.0, 4.0, 4.0]))), // exponent
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = PowerFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function power");
Expand All @@ -232,7 +232,7 @@ mod tests {
ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base
ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = PowerFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function power");
Expand Down
Loading