diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 09a4ad619a48f..560348da8a994 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -18,6 +18,7 @@ //! [`TreeNode`] for visiting and rewriting expression and plan trees use crate::Result; +use std::marker::PhantomData; use std::sync::Arc; /// These macros are used to determine continuation during transforming traversals. @@ -912,104 +913,6 @@ macro_rules! map_until_stop_and_collect { }} } -macro_rules! rewrite_recursive { - ($START:ident, $NAME:ident, $TRANSFORM_UP:expr, $TRANSFORM_DOWN:expr) => { - let mut queue = vec![ProcessingState::NotStarted($START)]; - - while let Some(item) = queue.pop() { - match item { - ProcessingState::NotStarted($NAME) => { - let node = $TRANSFORM_DOWN?; - - queue.push(match node.tnr { - TreeNodeRecursion::Continue => { - ProcessingState::ProcessingChildren { - non_processed_children: node - .data - .arc_children() - .into_iter() - .cloned() - .rev() - .collect(), - item: node, - processed_children: vec![], - } - } - TreeNodeRecursion::Jump => ProcessingState::ProcessedAllChildren( - node.with_tnr(TreeNodeRecursion::Continue), - ), - TreeNodeRecursion::Stop => { - ProcessingState::ProcessedAllChildren(node) - } - }) - } - ProcessingState::ProcessingChildren { - mut item, - mut non_processed_children, - mut processed_children, - } => match item.tnr { - TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { - if let Some(non_processed_item) = non_processed_children.pop() { - queue.push(ProcessingState::ProcessingChildren { - item, - non_processed_children, - processed_children, - }); - queue.push(ProcessingState::NotStarted(non_processed_item)); - } else { - item.transformed |= - processed_children.iter().any(|item| item.transformed); - item.data = item.data.with_new_arc_children( - processed_children.into_iter().map(|c| c.data).collect(), - )?; - queue.push(ProcessingState::ProcessedAllChildren(item)) - } - } - TreeNodeRecursion::Stop => { - processed_children.extend( - non_processed_children - .into_iter() - .rev() - .map(Transformed::no), - ); - item.transformed |= - processed_children.iter().any(|item| item.transformed); - item.data = item.data.with_new_arc_children( - processed_children.into_iter().map(|c| c.data).collect(), - )?; - queue.push(ProcessingState::ProcessedAllChildren(item)); - } - }, - ProcessingState::ProcessedAllChildren(node) => { - let node = node.transform_parent(|$NAME| $TRANSFORM_UP)?; - - if let Some(ProcessingState::ProcessingChildren { - item: mut parent_node, - non_processed_children, - mut processed_children, - .. - }) = queue.pop() - { - parent_node.tnr = node.tnr; - processed_children.push(node); - - queue.push(ProcessingState::ProcessingChildren { - item: parent_node, - non_processed_children, - processed_children, - }) - } else { - debug_assert_eq!(queue.len(), 0); - return Ok(node); - } - } - } - } - - unreachable!(); - }; -} - /// Transformation helper to access [`Transformed`] fields in a [`Result`] easily. /// /// # Example @@ -1063,6 +966,59 @@ pub trait DynTreeNode { ) -> Result>; } +pub struct LegacyRewriter< + FD: FnMut(Node) -> Result>, + FU: FnMut(Node) -> Result>, + Node: TreeNode, +> { + f_down_func: FD, + f_up_func: FU, + _node: PhantomData, +} + +impl< + FD: FnMut(Node) -> Result>, + FU: FnMut(Node) -> Result>, + Node: TreeNode, + > LegacyRewriter +{ + pub fn new(f_down_func: FD, f_up_func: FU) -> Self { + Self { + f_down_func, + f_up_func, + _node: PhantomData, + } + } +} +impl< + FD: FnMut(Node) -> Result>, + FU: FnMut(Node) -> Result>, + Node: TreeNode, + > TreeNodeRewriter for LegacyRewriter +{ + type Node = Node; + + fn f_down(&mut self, node: Self::Node) -> Result> { + (self.f_down_func)(node) + } + + fn f_up(&mut self, node: Self::Node) -> Result> { + (self.f_up_func)(node) + } +} + +macro_rules! update_rec_node { + ($NAME:ident, $CHILDREN:ident) => {{ + $NAME.transformed |= $CHILDREN.iter().any(|item| item.transformed); + + $NAME.data = $NAME + .data + .with_new_arc_children($CHILDREN.into_iter().map(|c| c.data).collect())?; + + $NAME + }}; +} + /// Blanket implementation for any `Arc` where `T` implements [`DynTreeNode`] /// (such as [`Arc`]). impl TreeNode for Arc { @@ -1102,30 +1058,121 @@ impl TreeNode for Arc { FU: FnMut(Self) -> Result>, >( self, - mut f_down: FD, - mut f_up: FU, + f_down: FD, + f_up: FU, ) -> Result> { - rewrite_recursive!(self, node, f_up(node), f_down(node)); + self.rewrite(&mut LegacyRewriter::new(f_down, f_up)) } fn transform_down Result>>( self, f: F, ) -> Result> { - self.transform_down_up(f, |node| Ok(Transformed::no(node))) + self.rewrite(&mut LegacyRewriter::new(f, |node| { + Ok(Transformed::no(node)) + })) } fn transform_up Result>>( self, f: F, ) -> Result> { - self.transform_down_up(|node| Ok(Transformed::no(node)), f) + self.rewrite(&mut LegacyRewriter::new( + |node| Ok(Transformed::no(node)), + f, + )) } fn rewrite>( self, rewriter: &mut R, ) -> Result> { - rewrite_recursive!(self, node, rewriter.f_up(node), rewriter.f_down(node)); + let mut queue = vec![ProcessingState::NotStarted(self)]; + + while let Some(item) = queue.pop() { + match item { + ProcessingState::NotStarted(node) => { + let node = rewriter.f_down(node)?; + + queue.push(match node.tnr { + TreeNodeRecursion::Continue => { + ProcessingState::ProcessingChildren { + non_processed_children: node + .data + .arc_children() + .into_iter() + .cloned() + .rev() + .collect(), + item: node, + processed_children: vec![], + } + } + TreeNodeRecursion::Jump => ProcessingState::ProcessedAllChildren( + node.with_tnr(TreeNodeRecursion::Continue), + ), + TreeNodeRecursion::Stop => { + ProcessingState::ProcessedAllChildren(node) + } + }) + } + ProcessingState::ProcessingChildren { + mut item, + mut non_processed_children, + mut processed_children, + } => match item.tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { + if let Some(non_processed_item) = non_processed_children.pop() { + queue.push(ProcessingState::ProcessingChildren { + item, + non_processed_children, + processed_children, + }); + queue.push(ProcessingState::NotStarted(non_processed_item)); + } else { + queue.push(ProcessingState::ProcessedAllChildren( + update_rec_node!(item, processed_children), + )) + } + } + TreeNodeRecursion::Stop => { + processed_children.extend( + non_processed_children + .into_iter() + .rev() + .map(Transformed::no), + ); + queue.push(ProcessingState::ProcessedAllChildren( + update_rec_node!(item, processed_children), + )); + } + }, + ProcessingState::ProcessedAllChildren(node) => { + let node = node.transform_parent(|n| rewriter.f_up(n))?; + + if let Some(ProcessingState::ProcessingChildren { + item: mut parent_node, + non_processed_children, + mut processed_children, + .. + }) = queue.pop() + { + parent_node.tnr = node.tnr; + processed_children.push(node); + + queue.push(ProcessingState::ProcessingChildren { + item: parent_node, + non_processed_children, + processed_children, + }) + } else { + debug_assert_eq!(queue.len(), 0); + return Ok(node); + } + } + } + } + + unreachable!(); } fn visit<'n, V: TreeNodeVisitor<'n, Node = Self>>(