From 4a464cb5bea7d52712ca36ff48985cf81a145d71 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Fri, 27 Dec 2024 11:01:59 -0800 Subject: [PATCH] docs(substrait): document SubstraitProducer --- .../substrait/src/logical_plan/consumer.rs | 3 + .../substrait/src/logical_plan/producer.rs | 105 ++++++++++++++++-- .../tests/cases/roundtrip_logical_plan.rs | 2 +- 3 files changed, 97 insertions(+), 13 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 515553152659..d82237298436 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -114,6 +114,9 @@ use substrait::proto::{ /// This trait is used to consume Substrait plans, converting them into DataFusion Logical Plans. /// It can be implemented by users to allow for custom handling of relations, expressions, etc. /// +/// Combined with the [crate::logical_plan::producer::SubstraitProducer] this allows for fully +/// customizable Substrait serde. +/// /// # Example Usage /// /// ``` diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 0f4a062e2b1b..cf879ad65629 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -105,12 +105,89 @@ use substrait::{ version, }; +/// This trait is used to produce Substrait plans, converting them from DataFusion Logical Plans. +/// It can be implemented by users to allow for custom handling of relations, expressions, etc. +/// +/// Combined with the [crate::logical_plan::consumer::SubstraitConsumer] this allows for fully +/// customizable Substrait serde. +/// +/// # Example Usage +/// +/// ``` +/// # use std::sync::Arc; +/// # use substrait::proto::{Expression, Rel}; +/// # use substrait::proto::rel::RelType; +/// # use datafusion::common::DFSchemaRef; +/// # use datafusion::error::Result; +/// # use datafusion::execution::SessionState; +/// # use datafusion::logical_expr::{Between, Extension, Projection}; +/// # use datafusion_substrait::extensions::Extensions; +/// # use datafusion_substrait::logical_plan::producer::{from_projection, SubstraitProducer}; +/// +/// struct CustomSubstraitProducer { +/// extensions: Extensions, +/// state: Arc, +/// } +/// +/// impl SubstraitProducer for CustomSubstraitProducer { +/// +/// fn register_function(&mut self, signature: String) -> u32 { +/// self.extensions.register_function(signature) +/// } +/// +/// fn get_extensions(self) -> Extensions { +/// self.extensions +/// } +/// +/// // You can set additional metadata on the Rels you produce +/// fn consume_projection(&mut self, plan: &Projection) -> Result> { +/// let mut rel = from_projection(self, plan)?; +/// match rel.rel_type { +/// Some(RelType::Project(mut project)) => { +/// let mut project = project.clone(); +/// // set common metadata or advanced extension +/// project.common = None; +/// project.advanced_extension = None; +/// Ok(Box::new(Rel { +/// rel_type: Some(RelType::Project(project)), +/// })) +/// } +/// rel_type => Ok(Box::new(Rel { rel_type })), +/// } +/// } +/// +/// // You can tweak how you convert expressions for your target system +/// fn consume_between(&mut self, between: &Between, schema: &DFSchemaRef) -> Result { +/// // add your own encoding for Between +/// todo!() +/// } +/// +/// // You can fully control how you convert UserDefinedLogicalNodes into Substrait +/// fn consume_extension(&mut self, _plan: &Extension) -> Result> { +/// // implement your own serializer into Substrait +/// todo!() +/// } +/// } +/// ``` pub trait SubstraitProducer: Send + Sync + Sized { + /// Within a Substrait plan, functions are referenced using function anchors that are stored at + /// the top level of the [Plan] within + /// [ExtensionFunction](substrait::proto::extensions::simple_extension_declaration::ExtensionFunction) + /// messages. + /// + /// When given a function signature, this method should return the existing anchor for it if + /// there is one. Otherwise, it should generate a new anchor. + fn register_function(&mut self, signature: String) -> u32; + + /// Consume the producer to generate the [Extensions] for the Substrait plan based on the + /// functions that have been registered fn get_extensions(self) -> Extensions; - fn register_function(&mut self, signature: String) -> u32; + // Logical Plan Methods + // There is one method per LogicalPlan to allow for easy overriding of producer behaviour. + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. - // Logical Plans fn consume_plan(&mut self, plan: &LogicalPlan) -> Result> { to_substrait_rel(self, plan) } @@ -175,7 +252,11 @@ pub trait SubstraitProducer: Send + Sync + Sized { substrait_err!("Specify handling for LogicalPlan::Extension by implementing the SubstraitProducer trait") } - // Expressions + // Expression Methods + // There is one method per DataFusion Expr to allow for easy overriding of producer behaviour + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. + fn consume_expr(&mut self, expr: &Expr, schema: &DFSchemaRef) -> Result { to_substrait_rex(self, expr, schema) } @@ -212,7 +293,7 @@ pub trait SubstraitProducer: Send + Sync + Sized { from_like(self, like, schema) } - /// Handles: Not, IsNotNull, IsNull, IsTrue, IsFalse, IsUnknown, IsNotTrue, IsNotFalse, IsNotUnknown, Negative + /// For handling Not, IsNotNull, IsNull, IsTrue, IsFalse, IsUnknown, IsNotTrue, IsNotFalse, IsNotUnknown, Negative fn consume_unary_expr( &mut self, expr: &Expr, @@ -253,7 +334,7 @@ pub trait SubstraitProducer: Send + Sync + Sized { from_scalar_function(self, scalar_fn, schema) } - fn consume_agg_function( + fn consume_aggregate_function( &mut self, agg_fn: &expr::AggregateFunction, schema: &DFSchemaRef, @@ -301,14 +382,14 @@ impl<'a> DefaultSubstraitProducer<'a> { } impl SubstraitProducer for DefaultSubstraitProducer<'_> { - fn get_extensions(self) -> Extensions { - self.extensions - } - fn register_function(&mut self, fn_name: String) -> u32 { self.extensions.register_function(fn_name) } + fn get_extensions(self) -> Extensions { + self.extensions + } + fn consume_extension(&mut self, plan: &Extension) -> Result> { let extension_bytes = self .state @@ -1164,7 +1245,7 @@ pub fn to_substrait_agg_measure( ) -> Result { match expr { Expr::AggregateFunction(agg_fn) => from_aggregate_function(producer, agg_fn, schema), - Expr::Alias(Alias{expr,..}) => { + Expr::Alias(Alias { expr, .. }) => { to_substrait_agg_measure(producer, expr, schema) } _ => internal_err!( @@ -2631,7 +2712,7 @@ mod test { ], false, ) - .into(), + .into(), false, ))?; @@ -2640,7 +2721,7 @@ mod test { Field::new("c0", DataType::Int32, true), Field::new("c1", DataType::Utf8, true), ] - .into(), + .into(), ))?; round_trip_type(DataType::Interval(IntervalUnit::YearMonth))?; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 772bf2e7ad8e..7045729493b1 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -583,7 +583,7 @@ async fn self_join_introduces_aliases() -> Result<()> { \n TableScan: data projection=[b, c]", false, ) - .await + .await } #[tokio::test]