diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs index e120c5e7bf8e..f1d763ba6e41 100644 --- a/datafusion-examples/examples/simple_udtf.rs +++ b/datafusion-examples/examples/simple_udtf.rs @@ -125,6 +125,7 @@ impl TableProvider for LocalCsvTable { )?)) } } + struct LocalCsvTableFunc {} impl TableFunctionImpl for LocalCsvTableFunc { diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index 1e710bc321a2..11cf52eb3fcf 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -17,17 +17,18 @@ under the License. --> -# Adding User Defined Functions: Scalar/Window/Aggregate +# Adding User Defined Functions: Scalar/Window/Aggregate/Table Functions User Defined Functions (UDFs) are functions that can be used in the context of DataFusion execution. This page covers how to add UDFs to DataFusion. In particular, it covers how to add Scalar, Window, and Aggregate UDFs. -| UDF Type | Description | Example | -| --------- | ---------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------ | -| Scalar | A function that takes a row of data and returns a single value. | [simple_udf.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs) | -| Window | A function that takes a row of data and returns a single value, but also has access to the rows around it. | [simple_udwf.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs) | -| Aggregate | A function that takes a group of rows and returns a single value. 
| [simple_udaf.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs) | +| UDF Type | Description | Example | +| --------- | ---------------------------------------------------------------------------------------------------------- | ------------------- | +| Scalar | A function that takes a row of data and returns a single value. | [simple_udf.rs][1] | +| Window | A function that takes a row of data and returns a single value, but also has access to the rows around it. | [simple_udwf.rs][2] | +| Aggregate | A function that takes a group of rows and returns a single value. | [simple_udaf.rs][3] | +| Table | A function that takes parameters and returns a `TableProvider` to be used in a query plan. | [simple_udtf.rs][4] | First we'll talk about adding an Scalar UDF end-to-end, then we'll talk about the differences between the different types of UDFs. @@ -432,3 +433,100 @@ Then, we can query like below: ```rust let df = ctx.sql("SELECT geo_mean(a) FROM t").await?; ``` + +## Adding a User-Defined Table Function + +A User-Defined Table Function (UDTF) is a function that takes parameters and returns a `TableProvider`. + +Because we're returning a `TableProvider`, in this example we'll use the `MemTable` data source to represent a table. This is a simple struct that holds a set of RecordBatches in memory and treats them as a table. In your case, this would be replaced with your own struct that implements `TableProvider`. + +While this is a simple example for illustrative purposes, UDTFs have a lot of potential use cases, and can be particularly useful for reading data from external sources and interactive analysis. For example, see the [example][4] for a working example that reads from a CSV file. As another example, you could use the built-in UDTF `parquet_metadata` in the CLI to read the metadata from a Parquet file.
+ +```console +❯ select filename, row_group_id, row_group_num_rows, row_group_bytes, stats_min, stats_max from parquet_metadata('./benchmarks/data/hits.parquet') where column_id = 17 limit 10; ++--------------------------------+--------------+--------------------+-----------------+-----------+-----------+ +| filename | row_group_id | row_group_num_rows | row_group_bytes | stats_min | stats_max | ++--------------------------------+--------------+--------------------+-----------------+-----------+-----------+ +| ./benchmarks/data/hits.parquet | 0 | 450560 | 188921521 | 0 | 73256 | +| ./benchmarks/data/hits.parquet | 1 | 612174 | 210338885 | 0 | 109827 | +| ./benchmarks/data/hits.parquet | 2 | 344064 | 161242466 | 0 | 122484 | +| ./benchmarks/data/hits.parquet | 3 | 606208 | 235549898 | 0 | 121073 | +| ./benchmarks/data/hits.parquet | 4 | 335872 | 137103898 | 0 | 108996 | +| ./benchmarks/data/hits.parquet | 5 | 311296 | 145453612 | 0 | 108996 | +| ./benchmarks/data/hits.parquet | 6 | 303104 | 138833963 | 0 | 108996 | +| ./benchmarks/data/hits.parquet | 7 | 303104 | 191140113 | 0 | 73256 | +| ./benchmarks/data/hits.parquet | 8 | 573440 | 208038598 | 0 | 95823 | +| ./benchmarks/data/hits.parquet | 9 | 344064 | 147838157 | 0 | 73256 | ++--------------------------------+--------------+--------------------+-----------------+-----------+-----------+ +``` + +### Writing the UDTF + +The simple UDTF used here takes a single `Int64` argument and returns a table with a single column with the value of the argument. To create a function in DataFusion, you need to implement the `TableFunctionImpl` trait. This trait has a single method, `call`, that takes a slice of `Expr`s and returns a `Result<Arc<dyn TableProvider>>`. + +In the `call` method, you parse the input `Expr`s and return a `TableProvider`. You might also want to do some validation of the input `Expr`s, e.g. checking that the number of arguments is correct.
+ +```rust +use datafusion::common::plan_err; +use datafusion::datasource::function::TableFunctionImpl; +// Other imports here + +/// A table function that returns a table provider with the value as a single column +#[derive(Default)] +pub struct EchoFunction {} + +impl TableFunctionImpl for EchoFunction { + fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> { + let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else { + return plan_err!("First argument must be an integer"); + }; + + // Create the schema for the table + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + + // Create a single RecordBatch with the value as a single column + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int64Array::from(vec![*value]))], + )?; + + // Create a MemTable plan that returns the RecordBatch + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + + Ok(Arc::new(provider)) + } +} +``` + +### Registering and Using the UDTF + +With the UDTF implemented, you can register it with the `SessionContext`: + +```rust +use datafusion::execution::context::SessionContext; + +let ctx = SessionContext::new(); + +ctx.register_udtf("echo", Arc::new(EchoFunction::default())); +``` + +And if all goes well, you can use it in your query: + +```rust +use datafusion::arrow::util::pretty; + +let df = ctx.sql("SELECT * FROM echo(1)").await?; + +let results = df.collect().await?; +pretty::print_batches(&results)?; +// +---+ +// | a | +// +---+ +// | 1 | +// +---+ +``` + +[1]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs +[2]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs +[3]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs +[4]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udtf.rs