Skip to content

Commit

Permalink
Added support for ScalarUDFImpl::invoke_with_return_type where the …
Browse files Browse the repository at this point in the history
…invoke is passed the return type created for the udf instance
  • Loading branch information
joseph-isaacs committed Nov 19, 2024
1 parent 398d5f6 commit ae73371
Show file tree
Hide file tree
Showing 13 changed files with 89 additions and 41 deletions.
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{
aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs,
};
pub use udf::{scalar_doc_sections, ScalarUDF, ScalarUDFImpl};
pub use udf::{scalar_doc_sections, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
pub use udf_docs::{DocSection, Documentation, DocumentationBuilder};
pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
Expand Down
46 changes: 39 additions & 7 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,6 @@ impl ScalarUDF {
self.inner.simplify(args, info)
}

/// Invoke the function on `args`, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke`] for more details.
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
pub fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
#[allow(deprecated)]
Expand All @@ -216,17 +213,23 @@ impl ScalarUDF {
self.inner.is_nullable(args, schema)
}

/// Invoke the function with `args` and number of rows, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_batch`] for more details.
#[deprecated(since = "43.0.0", note = "Use `invoke_batch` instead")]
pub fn invoke_batch(
&self,
args: &[ColumnarValue],
number_rows: usize,
) -> Result<ColumnarValue> {
#[allow(deprecated)]
self.inner.invoke_batch(args, number_rows)
}

/// Invoke the function on `args`, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_with_args`] for more details.
pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
self.inner.invoke_with_args(args)
}

/// Invoke the function without `args` but number of rows, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_no_args`] for more details.
Expand Down Expand Up @@ -324,6 +327,16 @@ where
}
}

pub struct ScalarFunctionArgs<'a> {
// The evaluated arguments to the function
pub args: &'a [ColumnarValue],
// The number of rows in record batch being evaluated
pub number_rows: usize,
// The return type of the scalar function returned (from `return_type` or `return_type_from_exprs`)
// when creating the physical expression from the logical expression
pub return_type: &'a DataType,
}

/// Trait for implementing [`ScalarUDF`].
///
/// This trait exposes the full API for implementing user defined functions and
Expand Down Expand Up @@ -356,7 +369,7 @@ where
/// }
/// }
/// }
///
///
/// static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
///
/// fn get_doc() -> &'static Documentation {
Expand Down Expand Up @@ -518,6 +531,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
///
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
/// to arrays, which will likely be simpler code, but be slower.
#[deprecated(since = "43.0.0", note = "Use `invoke_with_args` instead")]
fn invoke_batch(
&self,
args: &[ColumnarValue],
Expand All @@ -537,6 +551,23 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
}
}

/// Invoke the function with `args: ScalarFunctionArgs` returning the appropriate result.
///
/// The function will be invoked with a struct `ScalarFunctionArgs`
///
/// # Performance
///
/// For the best performance, the implementations should handle the common case
/// when one or more of their arguments are constant values (aka
/// [`ColumnarValue::Scalar`]).
///
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
/// to arrays, which will likely be simpler code, but be slower.
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
#[allow(deprecated)]
self.invoke_batch(args.args, args.number_rows)
}

/// Invoke the function without `args`, instead the number of rows are provided,
/// returning the appropriate result.
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
Expand Down Expand Up @@ -767,6 +798,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
args: &[ColumnarValue],
number_rows: usize,
) -> Result<ColumnarValue> {
#[allow(deprecated)]
self.inner.invoke_batch(args, number_rows)
}

Expand Down
2 changes: 2 additions & 0 deletions datafusion/functions/benches/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("random_1M_rows_batch_8192", |b| {
b.iter(|| {
for _ in 0..iterations {
#[allow(deprecated)] // TODO: migrate to invoke_with_args
black_box(random_func.invoke_batch(&[], 8192).unwrap());
}
})
Expand All @@ -39,6 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("random_1M_rows_batch_128", |b| {
b.iter(|| {
for _ in 0..iterations_128 {
#[allow(deprecated)] // TODO: migrate to invoke_with_args
black_box(random_func.invoke_batch(&[], 128).unwrap());
}
})
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions/src/core/version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ mod test {
#[tokio::test]
async fn test_version_udf() {
let version_udf = ScalarUDF::from(VersionFunc::new());
#[allow(deprecated)] // TODO: migrate to invoke_with_args
let version = version_udf.invoke_batch(&[], 1).unwrap();

if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(version))) = version {
Expand Down
9 changes: 7 additions & 2 deletions datafusion/functions/src/datetime/to_local_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ mod tests {
use arrow::datatypes::{DataType, TimeUnit};
use chrono::NaiveDateTime;
use datafusion_common::ScalarValue;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

use super::{adjust_to_local_time, ToLocalTimeFunc};

Expand Down Expand Up @@ -558,7 +558,11 @@ mod tests {

fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) {
let res = ToLocalTimeFunc::new()
.invoke_batch(&[ColumnarValue::Scalar(input)], 1)
.invoke_with_args(ScalarFunctionArgs {
args: &[ColumnarValue::Scalar(input)],
number_rows: 1,
return_type: &expected.data_type(),
})
.unwrap();
match res {
ColumnarValue::Scalar(res) => {
Expand Down Expand Up @@ -617,6 +621,7 @@ mod tests {
.map(|s| Some(string_to_timestamp_nanos(s).unwrap()))
.collect::<TimestampNanosecondArray>();
let batch_size = input.len();
#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = ToLocalTimeFunc::new()
.invoke_batch(&[ColumnarValue::Array(Arc::new(input))], batch_size)
.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions/src/datetime/to_timestamp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,7 @@ mod tests {
for array in arrays {
let rt = udf.return_type(&[array.data_type()]).unwrap();
assert!(matches!(rt, Timestamp(_, Some(_))));

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let res = udf
.invoke_batch(&[array.clone()], 1)
.expect("that to_timestamp parsed values without error");
Expand Down Expand Up @@ -1051,7 +1051,7 @@ mod tests {
for array in arrays {
let rt = udf.return_type(&[array.data_type()]).unwrap();
assert!(matches!(rt, Timestamp(_, None)));

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let res = udf
.invoke_batch(&[array.clone()], 1)
.expect("that to_timestamp parsed values without error");
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions/src/datetime/to_unixtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ impl ScalarUDFImpl for ToUnixtimeFunc {
DataType::Date64 | DataType::Date32 | DataType::Timestamp(_, None) => args[0]
.cast_to(&DataType::Timestamp(TimeUnit::Second, None), None)?
.cast_to(&DataType::Int64, None),
#[allow(deprecated)] // TODO: migrate to invoke_with_args
DataType::Utf8 => ToTimestampSecondsFunc::new()
.invoke_batch(args, batch_size)?
.cast_to(&DataType::Int64, None),
Expand Down
20 changes: 10 additions & 10 deletions datafusion/functions/src/math/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ mod tests {
]))), // num
ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))),
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let _ = LogFunc::new().invoke_batch(&args, 4);
}

Expand All @@ -286,7 +286,7 @@ mod tests {
let args = [
ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new().invoke_batch(&args, 1);
result.expect_err("expected error");
}
Expand All @@ -296,7 +296,7 @@ mod tests {
let args = [
ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 1)
.expect("failed to initialize function log");
Expand All @@ -320,7 +320,7 @@ mod tests {
let args = [
ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 1)
.expect("failed to initialize function log");
Expand All @@ -345,7 +345,7 @@ mod tests {
ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num
ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 1)
.expect("failed to initialize function log");
Expand All @@ -370,7 +370,7 @@ mod tests {
ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num
ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 1)
.expect("failed to initialize function log");
Expand All @@ -396,7 +396,7 @@ mod tests {
10.0, 100.0, 1000.0, 10000.0,
]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function log");
Expand Down Expand Up @@ -425,7 +425,7 @@ mod tests {
10.0, 100.0, 1000.0, 10000.0,
]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function log");
Expand Down Expand Up @@ -455,7 +455,7 @@ mod tests {
8.0, 4.0, 81.0, 625.0,
]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function log");
Expand Down Expand Up @@ -485,7 +485,7 @@ mod tests {
8.0, 4.0, 81.0, 625.0,
]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function log");
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions/src/math/power.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ mod tests {
ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0]))), // base
ColumnarValue::Array(Arc::new(Float64Array::from(vec![3.0, 2.0, 4.0, 4.0]))), // exponent
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = PowerFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function power");
Expand All @@ -232,7 +232,7 @@ mod tests {
ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base
ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = PowerFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function power");
Expand Down
2 changes: 2 additions & 0 deletions datafusion/functions/src/math/signum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ mod test {
f32::NEG_INFINITY,
]));
let batch_size = array.len();
#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = SignumFunc::new()
.invoke_batch(&[ColumnarValue::Array(array)], batch_size)
.expect("failed to initialize function signum");
Expand Down Expand Up @@ -207,6 +208,7 @@ mod test {
f64::NEG_INFINITY,
]));
let batch_size = array.len();
#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = SignumFunc::new()
.invoke_batch(&[ColumnarValue::Array(array)], batch_size)
.expect("failed to initialize function signum");
Expand Down
Loading

0 comments on commit ae73371

Please sign in to comment.