From 6a4dd2c5eee28a91db0a989be73fe6650f547b99 Mon Sep 17 00:00:00 2001 From: blaginin Date: Tue, 29 Oct 2024 20:59:27 +0000 Subject: [PATCH] Rewrite all other methods in DynTreeNode to be iterative --- datafusion/common/src/tree_node.rs | 216 ++++++++++++++++------------- 1 file changed, 123 insertions(+), 93 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index be91b0a5e325..09a4ad619a48 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -912,6 +912,104 @@ macro_rules! map_until_stop_and_collect { }} } +macro_rules! rewrite_recursive { + ($START:ident, $NAME:ident, $TRANSFORM_UP:expr, $TRANSFORM_DOWN:expr) => { + let mut queue = vec![ProcessingState::NotStarted($START)]; + + while let Some(item) = queue.pop() { + match item { + ProcessingState::NotStarted($NAME) => { + let node = $TRANSFORM_DOWN?; + + queue.push(match node.tnr { + TreeNodeRecursion::Continue => { + ProcessingState::ProcessingChildren { + non_processed_children: node + .data + .arc_children() + .into_iter() + .cloned() + .rev() + .collect(), + item: node, + processed_children: vec![], + } + } + TreeNodeRecursion::Jump => ProcessingState::ProcessedAllChildren( + node.with_tnr(TreeNodeRecursion::Continue), + ), + TreeNodeRecursion::Stop => { + ProcessingState::ProcessedAllChildren(node) + } + }) + } + ProcessingState::ProcessingChildren { + mut item, + mut non_processed_children, + mut processed_children, + } => match item.tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { + if let Some(non_processed_item) = non_processed_children.pop() { + queue.push(ProcessingState::ProcessingChildren { + item, + non_processed_children, + processed_children, + }); + queue.push(ProcessingState::NotStarted(non_processed_item)); + } else { + item.transformed |= + processed_children.iter().any(|item| item.transformed); + item.data = item.data.with_new_arc_children( + processed_children.into_iter().map(|c| c.data).collect(), + )?; + queue.push(ProcessingState::ProcessedAllChildren(item)) + } + } + TreeNodeRecursion::Stop => { + processed_children.extend( + non_processed_children + .into_iter() + .rev() + .map(Transformed::no), + ); + item.transformed |= + processed_children.iter().any(|item| item.transformed); + item.data = item.data.with_new_arc_children( + processed_children.into_iter().map(|c| c.data).collect(), + )?; + queue.push(ProcessingState::ProcessedAllChildren(item)); + } + }, + ProcessingState::ProcessedAllChildren(node) => { + let node = node.transform_parent(|$NAME| $TRANSFORM_UP)?; + + if let Some(ProcessingState::ProcessingChildren { + item: mut parent_node, + non_processed_children, + mut processed_children, + .. + }) = queue.pop() + { + parent_node.tnr = node.tnr; + processed_children.push(node); + + queue.push(ProcessingState::ProcessingChildren { + item: parent_node, + non_processed_children, + processed_children, + }) + } else { + debug_assert_eq!(queue.len(), 0); + return Ok(node); + } + } + } + } + + unreachable!(); + }; +} + /// Transformation helper to access [`Transformed`] fields in a [`Result`] easily. /// /// # Example @@ -999,103 +1097,35 @@ impl TreeNode for Arc { } } - fn rewrite>( + fn transform_down_up< + FD: FnMut(Self) -> Result>, + FU: FnMut(Self) -> Result>, + >( self, - rewriter: &mut R, + mut f_down: FD, + mut f_up: FU, ) -> Result> { - let mut queue = vec![ProcessingState::NotStarted(self)]; - - while let Some(item) = queue.pop() { - match item { - ProcessingState::NotStarted(node) => { - let node = rewriter.f_down(node)?; - - queue.push(match node.tnr { - TreeNodeRecursion::Continue => { - ProcessingState::ProcessingChildren { - non_processed_children: node - .data - .arc_children() - .into_iter() - .cloned() - .rev() - .collect(), - item: node, - processed_children: vec![], - } - } - TreeNodeRecursion::Jump => ProcessingState::ProcessedAllChildren( - node.with_tnr(TreeNodeRecursion::Continue), - ), - TreeNodeRecursion::Stop => { - ProcessingState::ProcessedAllChildren(node) - } - }) - } - ProcessingState::ProcessingChildren { - mut item, - mut non_processed_children, - mut processed_children, - } => match item.tnr { - TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { - if let Some(non_processed_item) = non_processed_children.pop() { - queue.push(ProcessingState::ProcessingChildren { - item, - non_processed_children, - processed_children, - }); - queue.push(ProcessingState::NotStarted(non_processed_item)); - } else { - item.transformed = - processed_children.iter().any(|item| item.transformed); - item.data = item.data.with_new_arc_children( - processed_children.into_iter().map(|c| c.data).collect(), - )?; - queue.push(ProcessingState::ProcessedAllChildren(item)) - } - } - TreeNodeRecursion::Stop => { - processed_children.extend( - non_processed_children - .into_iter() - .rev() - .map(Transformed::no), - ); - item.transformed = - processed_children.iter().any(|item| item.transformed); - item.data = item.data.with_new_arc_children( - processed_children.into_iter().map(|c| c.data).collect(), - )?; - queue.push(ProcessingState::ProcessedAllChildren(item)); - } - }, - ProcessingState::ProcessedAllChildren(node) => { - let node = node.transform_parent(|n| rewriter.f_up(n))?; - - if let Some(ProcessingState::ProcessingChildren { - item: mut parent_node, - non_processed_children, - mut processed_children, - .. - }) = queue.pop() - { - parent_node.tnr = node.tnr; - processed_children.push(node); + rewrite_recursive!(self, node, f_up(node), f_down(node)); + } - queue.push(ProcessingState::ProcessingChildren { - item: parent_node, - non_processed_children, - processed_children, - }) - } else { - debug_assert_eq!(queue.len(), 0); - return Ok(node); - } - } - } - } + fn transform_down Result>>( + self, + f: F, + ) -> Result> { + self.transform_down_up(f, |node| Ok(Transformed::no(node))) + } - unreachable!(); + fn transform_up Result>>( + self, + f: F, + ) -> Result> { + self.transform_down_up(|node| Ok(Transformed::no(node)), f) + } + fn rewrite>( + self, + rewriter: &mut R, + ) -> Result> { + rewrite_recursive!(self, node, rewriter.f_up(node), rewriter.f_down(node)); } fn visit<'n, V: TreeNodeVisitor<'n, Node = Self>>(