From 6b1e9c6a3ae95b7065e902d99d9fde66f0f8e054 Mon Sep 17 00:00:00 2001 From: junxiangMu <63799833+guojidan@users.noreply.github.com> Date: Wed, 3 Jan 2024 20:24:58 +0800 Subject: [PATCH] Implement trait based API for defining WindowUDF (#8719) * Implement trait based API for defining WindowUDF * add test case & docs * fix docs * rename WindowUDFImpl function --- datafusion-examples/README.md | 1 + datafusion-examples/examples/advanced_udwf.rs | 230 ++++++++++++++++++ .../user_defined_window_functions.rs | 64 +++-- datafusion/expr/src/expr_fn.rs | 67 ++++- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/udwf.rs | 116 ++++++++- .../tests/cases/roundtrip_logical_plan.rs | 55 +++-- docs/source/library-user-guide/adding-udfs.md | 7 +- 8 files changed, 498 insertions(+), 44 deletions(-) create mode 100644 datafusion-examples/examples/advanced_udwf.rs diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 1296c74ea277..aae451add9e7 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -63,6 +63,7 @@ cargo run --example csv_sql - [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF) - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) - [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) +- [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF) ## Distributed diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs new file mode 100644 index 000000000000..91869d80a41a --- /dev/null +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -0,0 +1,230 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; +use std::any::Any; + +use arrow::{ + array::{ArrayRef, AsArray, Float64Array}, + datatypes::Float64Type, +}; +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_common::ScalarValue; +use datafusion_expr::{ + PartitionEvaluator, Signature, WindowFrame, WindowUDF, WindowUDFImpl, +}; + +/// This example shows how to use the full WindowUDFImpl API to implement a user +/// defined window function. As in the `simple_udwf.rs` example, this struct implements +/// a function `partition_evaluator` that returns the `MyPartitionEvaluator` instance. +/// +/// To do so, we must implement the `WindowUDFImpl` trait. +struct SmoothItUdf { + signature: Signature, +} + +impl SmoothItUdf { + /// Create a new instance of the SmoothItUdf struct + fn new() -> Self { + Self { + signature: Signature::exact( + // this function will always take one arguments of type f64 + vec![DataType::Float64], + // this function is deterministic and will always return the same + // result for the same input + Volatility::Immutable, + ), + } + } +} + +impl WindowUDFImpl for SmoothItUdf { + /// We implement as_any so that we can downcast the WindowUDFImpl trait object + fn as_any(&self) -> &dyn Any { + self + } + + /// Return the name of this function + fn name(&self) -> &str { + "smooth_it" + } + + /// Return the "signature" of this function -- namely that types of arguments it will take + fn signature(&self) -> &Signature { + &self.signature + } + + /// What is the type of value that will be returned by this function. + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + /// Create a `PartitionEvalutor` to evaluate this function on a new + /// partition. + fn partition_evaluator(&self) -> Result> { + Ok(Box::new(MyPartitionEvaluator::new())) + } +} + +/// This implements the lowest level evaluation for a window function +/// +/// It handles calculating the value of the window function for each +/// distinct values of `PARTITION BY` (each car type in our example) +#[derive(Clone, Debug)] +struct MyPartitionEvaluator {} + +impl MyPartitionEvaluator { + fn new() -> Self { + Self {} + } +} + +/// Different evaluation methods are called depending on the various +/// settings of WindowUDF. This example uses the simplest and most +/// general, `evaluate`. See `PartitionEvaluator` for the other more +/// advanced uses. +impl PartitionEvaluator for MyPartitionEvaluator { + /// Tell DataFusion the window function varies based on the value + /// of the window frame. + fn uses_window_frame(&self) -> bool { + true + } + + /// This function is called once per input row. + /// + /// `range`specifies which indexes of `values` should be + /// considered for the calculation. + /// + /// Note this is the SLOWEST, but simplest, way to evaluate a + /// window function. It is much faster to implement + /// evaluate_all or evaluate_all_with_rank, if possible + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &std::ops::Range, + ) -> Result { + // Again, the input argument is an array of floating + // point numbers to calculate a moving average + let arr: &Float64Array = values[0].as_ref().as_primitive::(); + + let range_len = range.end - range.start; + + // our smoothing function will average all the values in the + let output = if range_len > 0 { + let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum(); + Some(sum / range_len as f64) + } else { + None + }; + + Ok(ScalarValue::Float64(output)) + } +} + +// create local execution context with `cars.csv` registered as a table named `cars` +async fn create_context() -> Result { + // declare a new context. In spark API, this corresponds to a new spark SQL session + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + println!("pwd: {}", std::env::current_dir().unwrap().display()); + let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); + let read_options = CsvReadOptions::default().has_header(true); + + ctx.register_csv("cars", &csv_path, read_options).await?; + Ok(ctx) +} + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context().await?; + let smooth_it = WindowUDF::from(SmoothItUdf::new()); + ctx.register_udwf(smooth_it.clone()); + + // Use SQL to run the new window function + let df = ctx.sql("SELECT * from cars").await?; + // print the results + df.show().await?; + + // Use SQL to run the new window function: + // + // `PARTITION BY car`:each distinct value of car (red, and green) + // should be treated as a separate partition (and will result in + // creating a new `PartitionEvaluator`) + // + // `ORDER BY time`: within each partition ('green' or 'red') the + // rows will be be ordered by the value in the `time` column + // + // `evaluate_inside_range` is invoked with a window defined by the + // SQL. In this case: + // + // The first invocation will be passed row 0, the first row in the + // partition. + // + // The second invocation will be passed rows 0 and 1, the first + // two rows in the partition. + // + // etc. + let df = ctx + .sql( + "SELECT \ + car, \ + speed, \ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed,\ + time \ + from cars \ + ORDER BY \ + car", + ) + .await?; + // print the results + df.show().await?; + + // this time, call the new widow function with an explicit + // window so evaluate will be invoked with each window. + // + // `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING`: each invocation + // sees at most 3 rows: the row before, the current row, and the 1 + // row afterward. + let df = ctx.sql( + "SELECT \ + car, \ + speed, \ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS smooth_speed,\ + time \ + from cars \ + ORDER BY \ + car", + ).await?; + // print the results + df.show().await?; + + // Now, run the function using the DataFrame API: + let window_expr = smooth_it.call( + vec![col("speed")], // smooth_it(speed) + vec![col("car")], // PARTITION BY car + vec![col("time").sort(true, true)], // ORDER BY time ASC + WindowFrame::new(false), + ); + let df = ctx.table("cars").await?.window(vec![window_expr])?; + + // print the results + df.show().await?; + + Ok(()) +} diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 5f9939157217..3040fbafe81a 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -19,6 +19,7 @@ //! user defined window functions use std::{ + any::Any, ops::Range, sync::{ atomic::{AtomicUsize, Ordering}, @@ -32,8 +33,7 @@ use arrow_schema::DataType; use datafusion::{assert_batches_eq, prelude::SessionContext}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ - function::PartitionEvaluatorFactory, PartitionEvaluator, ReturnTypeFunction, - Signature, Volatility, WindowUDF, + PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl, }; /// A query with a window function evaluated over the entire partition @@ -471,24 +471,48 @@ impl OddCounter { } fn register(ctx: &mut SessionContext, test_state: Arc) { - let name = "odd_counter"; - let volatility = Volatility::Immutable; - - let signature = Signature::exact(vec![DataType::Int64], volatility); - - let return_type = Arc::new(DataType::Int64); - let return_type: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::clone(&return_type))); - - let partition_evaluator_factory: PartitionEvaluatorFactory = - Arc::new(move || Ok(Box::new(OddCounter::new(Arc::clone(&test_state))))); - - ctx.register_udwf(WindowUDF::new( - name, - &signature, - &return_type, - &partition_evaluator_factory, - )) + struct SimpleWindowUDF { + signature: Signature, + return_type: DataType, + test_state: Arc, + } + + impl SimpleWindowUDF { + fn new(test_state: Arc) -> Self { + let signature = + Signature::exact(vec![DataType::Float64], Volatility::Immutable); + let return_type = DataType::Int64; + Self { + signature, + return_type, + test_state, + } + } + } + + impl WindowUDFImpl for SimpleWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "odd_counter" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn partition_evaluator(&self) -> Result> { + Ok(Box::new(OddCounter::new(Arc::clone(&self.test_state)))) + } + } + + ctx.register_udwf(WindowUDF::from(SimpleWindowUDF::new(test_state))) } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index eed41d97ccba..f76fb17b38bb 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -28,7 +28,7 @@ use crate::{ BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility, }; -use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF}; +use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; use arrow::datatypes::DataType; use datafusion_common::{Column, Result}; use std::any::Any; @@ -1059,13 +1059,66 @@ pub fn create_udwf( volatility: Volatility, partition_evaluator_factory: PartitionEvaluatorFactory, ) -> WindowUDF { - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); - WindowUDF::new( + let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); + WindowUDF::from(SimpleWindowUDF::new( name, - &Signature::exact(vec![input_type], volatility), - &return_type, - &partition_evaluator_factory, - ) + input_type, + return_type, + volatility, + partition_evaluator_factory, + )) +} + +/// Implements [`WindowUDFImpl`] for functions that have a single signature and +/// return type. +pub struct SimpleWindowUDF { + name: String, + signature: Signature, + return_type: DataType, + partition_evaluator_factory: PartitionEvaluatorFactory, +} + +impl SimpleWindowUDF { + /// Create a new `SimpleWindowUDF` from a name, input types, return type and + /// implementation. Implementing [`WindowUDFImpl`] allows more flexibility + pub fn new( + name: impl Into, + input_type: DataType, + return_type: DataType, + volatility: Volatility, + partition_evaluator_factory: PartitionEvaluatorFactory, + ) -> Self { + let name = name.into(); + let signature = Signature::exact([input_type].to_vec(), volatility); + Self { + name, + signature, + return_type, + partition_evaluator_factory, + } + } +} + +impl WindowUDFImpl for SimpleWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn partition_evaluator(&self) -> Result> { + (self.partition_evaluator_factory)() + } } /// Calls a named built in function diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index ab213a19a352..077681d21725 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -82,7 +82,7 @@ pub use signature::{ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::AggregateUDF; pub use udf::{ScalarUDF, ScalarUDFImpl}; -pub use udwf::WindowUDF; +pub use udwf::{WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; #[cfg(test)] diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index a97a68341f5c..800386bfc77b 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -24,6 +24,7 @@ use crate::{ use arrow::datatypes::DataType; use datafusion_common::Result; use std::{ + any::Any, fmt::{self, Debug, Display, Formatter}, sync::Arc, }; @@ -80,7 +81,11 @@ impl std::hash::Hash for WindowUDF { } impl WindowUDF { - /// Create a new WindowUDF + /// Create a new WindowUDF from low level details. + /// + /// See [`WindowUDFImpl`] for a more convenient way to create a + /// `WindowUDF` using trait objects + #[deprecated(since = "34.0.0", note = "please implement ScalarUDFImpl instead")] pub fn new( name: &str, signature: &Signature, @@ -95,6 +100,32 @@ impl WindowUDF { } } + /// Create a new `WindowUDF` from a `[WindowUDFImpl]` trait object + /// + /// Note this is the same as using the `From` impl (`WindowUDF::from`) + pub fn new_from_impl(fun: F) -> WindowUDF + where + F: WindowUDFImpl + Send + Sync + 'static, + { + let arc_fun = Arc::new(fun); + let captured_self = arc_fun.clone(); + let return_type: ReturnTypeFunction = Arc::new(move |arg_types| { + let return_type = captured_self.return_type(arg_types)?; + Ok(Arc::new(return_type)) + }); + + let captured_self = arc_fun.clone(); + let partition_evaluator_factory: PartitionEvaluatorFactory = + Arc::new(move || captured_self.partition_evaluator()); + + Self { + name: arc_fun.name().to_string(), + signature: arc_fun.signature().clone(), + return_type: return_type.clone(), + partition_evaluator_factory, + } + } + /// creates a [`Expr`] that calls the window function given /// the `partition_by`, `order_by`, and `window_frame` definition /// @@ -140,3 +171,86 @@ impl WindowUDF { (self.partition_evaluator_factory)() } } + +impl From for WindowUDF +where + F: WindowUDFImpl + Send + Sync + 'static, +{ + fn from(fun: F) -> Self { + Self::new_from_impl(fun) + } +} + +/// Trait for implementing [`WindowUDF`]. +/// +/// This trait exposes the full API for implementing user defined window functions and +/// can be used to implement any function. +/// +/// See [`advanced_udwf.rs`] for a full example with complete implementation and +/// [`WindowUDF`] for other available options. +/// +/// +/// [`advanced_udwf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs +/// # Basic Example +/// ``` +/// # use std::any::Any; +/// # use arrow::datatypes::DataType; +/// # use datafusion_common::{DataFusionError, plan_err, Result}; +/// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame}; +/// # use datafusion_expr::{WindowUDFImpl, WindowUDF}; +/// struct SmoothIt { +/// signature: Signature +/// }; +/// +/// impl SmoothIt { +/// fn new() -> Self { +/// Self { +/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable) +/// } +/// } +/// } +/// +/// /// Implement the WindowUDFImpl trait for AddOne +/// impl WindowUDFImpl for SmoothIt { +/// fn as_any(&self) -> &dyn Any { self } +/// fn name(&self) -> &str { "smooth_it" } +/// fn signature(&self) -> &Signature { &self.signature } +/// fn return_type(&self, args: &[DataType]) -> Result { +/// if !matches!(args.get(0), Some(&DataType::Int32)) { +/// return plan_err!("smooth_it only accepts Int32 arguments"); +/// } +/// Ok(DataType::Int32) +/// } +/// // The actual implementation would add one to the argument +/// fn partition_evaluator(&self) -> Result> { unimplemented!() } +/// } +/// +/// // Create a new ScalarUDF from the implementation +/// let smooth_it = WindowUDF::from(SmoothIt::new()); +/// +/// // Call the function `add_one(col)` +/// let expr = smooth_it.call( +/// vec![col("speed")], // smooth_it(speed) +/// vec![col("car")], // PARTITION BY car +/// vec![col("time").sort(true, true)], // ORDER BY time ASC +/// WindowFrame::new(false), +/// ); +/// ``` +pub trait WindowUDFImpl { + /// Returns this object as an [`Any`] trait object + fn as_any(&self) -> &dyn Any; + + /// Returns this function's name + fn name(&self) -> &str; + + /// Returns the function's [`Signature`] for information about what input + /// types are accepted and the function's Volatility. + fn signature(&self) -> &Signature; + + /// What [`DataType`] will be returned by this function, given the types of + /// the arguments + fn return_type(&self, arg_types: &[DataType]) -> Result; + + /// Invoke the function, returning the [`PartitionEvaluator`] instance + fn partition_evaluator(&self) -> Result>; +} diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index dea99f91e392..402781e17e6f 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; use std::collections::HashMap; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -54,6 +55,7 @@ use datafusion_expr::{ BuiltinScalarFunction::{Sqrt, Substr}, Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, + WindowUDFImpl, }; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, @@ -1785,27 +1787,52 @@ fn roundtrip_window() { } } - fn return_type(arg_types: &[DataType]) -> Result> { - if arg_types.len() != 1 { - return plan_err!( - "dummy_udwf expects 1 argument, got {}: {:?}", - arg_types.len(), - arg_types - ); + struct SimpleWindowUDF { + signature: Signature, + } + + impl SimpleWindowUDF { + fn new() -> Self { + let signature = + Signature::exact(vec![DataType::Float64], Volatility::Immutable); + Self { signature } + } + } + + impl WindowUDFImpl for SimpleWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "dummy_udwf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 1 { + return plan_err!( + "dummy_udwf expects 1 argument, got {}: {:?}", + arg_types.len(), + arg_types + ); + } + Ok(arg_types[0].clone()) + } + + fn partition_evaluator(&self) -> Result> { + make_partition_evaluator() } - Ok(Arc::new(arg_types[0].clone())) } fn make_partition_evaluator() -> Result> { Ok(Box::new(DummyWindow {})) } - let dummy_window_udf = WindowUDF::new( - "dummy_udwf", - &Signature::exact(vec![DataType::Float64], Volatility::Immutable), - &(Arc::new(return_type) as _), - &(Arc::new(make_partition_evaluator) as _), - ); + let dummy_window_udf = WindowUDF::from(SimpleWindowUDF::new()); let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index c51e4de3236c..1f687f978f30 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -201,7 +201,8 @@ fn make_partition_evaluator() -> Result> { ### Registering a Window UDF -To register a Window UDF, you need to wrap the function implementation in a `WindowUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udwf` helper functions to make this easier. +To register a Window UDF, you need to wrap the function implementation in a [`WindowUDF`] struct and then register it with the `SessionContext`. DataFusion provides the [`create_udwf`] helper functions to make this easier. +There is a lower level API with more functionality but is more complex, that is documented in [`advanced_udwf.rs`]. ```rust use datafusion::logical_expr::{Volatility, create_udwf}; @@ -218,6 +219,10 @@ let smooth_it = create_udwf( ); ``` +[`windowudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.WindowUDF.html +[`create_udwf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udwf.html +[`advanced_udwf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs + The `create_udwf` has five arguments to check: - The first argument is the name of the function. This is the name that will be used in SQL queries.