Remove macros in favour of LegacyRewriter
blaginin committed Oct 29, 2024
1 parent 6a4dd2c commit b6fa0a7
Showing 1 changed file with 151 additions and 104 deletions.
datafusion/common/src/tree_node.rs (255 changes: 151 additions & 104 deletions)
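
For orientation before the diff: the sketch below is not part of the commit. It restates the pattern being adopted in plain Rust with hypothetical names (a Rewriter trait and a ClosureRewriter adapter standing in for TreeNodeRewriter and LegacyRewriter, plain values instead of Result<Transformed<Node>>, and no child traversal or recursion control). The point is that the f_down/f_up closures formerly spliced in by the rewrite_recursive! macro are wrapped in a struct and dispatched through trait methods, so transform_down, transform_up, and transform_down_up can all delegate to a single rewrite implementation.

use std::marker::PhantomData;

// Toy rewriter trait with hypothetical names; the real TreeNodeRewriter works on
// Result<Transformed<Node>> and drives traversal with TreeNodeRecursion.
trait Rewriter {
    type Node;
    fn f_down(&mut self, node: Self::Node) -> Self::Node;
    fn f_up(&mut self, node: Self::Node) -> Self::Node;
}

// Adapter turning two closures into a Rewriter, mirroring LegacyRewriter::new.
struct ClosureRewriter<FD: FnMut(N) -> N, FU: FnMut(N) -> N, N> {
    f_down_func: FD,
    f_up_func: FU,
    _node: PhantomData<N>,
}

impl<FD: FnMut(N) -> N, FU: FnMut(N) -> N, N> ClosureRewriter<FD, FU, N> {
    fn new(f_down_func: FD, f_up_func: FU) -> Self {
        Self {
            f_down_func,
            f_up_func,
            _node: PhantomData,
        }
    }
}

impl<FD: FnMut(N) -> N, FU: FnMut(N) -> N, N> Rewriter for ClosureRewriter<FD, FU, N> {
    type Node = N;

    fn f_down(&mut self, node: N) -> N {
        (self.f_down_func)(node)
    }

    fn f_up(&mut self, node: N) -> N {
        (self.f_up_func)(node)
    }
}

// Stand-ins for the entry points: pair the caller's closure with a no-op for the
// other direction and push both phases through the rewriter, which is the shape
// of the new transform_down / transform_up / transform_down_up bodies below.
fn transform_down_up<N>(node: N, f_down: impl FnMut(N) -> N, f_up: impl FnMut(N) -> N) -> N {
    let mut rewriter = ClosureRewriter::<_, _, N>::new(f_down, f_up);
    let node = rewriter.f_down(node);
    // (the real rewrite() also walks the node's children iteratively here)
    rewriter.f_up(node)
}

fn transform_down<N>(node: N, f: impl FnMut(N) -> N) -> N {
    transform_down_up(node, f, |n| n)
}

fn transform_up<N>(node: N, f: impl FnMut(N) -> N) -> N {
    transform_down_up(node, |n| n, f)
}

fn main() {
    // f_down runs on the way down, f_up on the way back up: (10 + 1) * 2 == 22.
    println!("{}", transform_down_up(10, |n| n + 1, |n| n * 2));
    println!("{}", transform_down(10, |n| n + 1)); // 11
    println!("{}", transform_up(10, |n| n * 2)); // 20
}

In the actual change below, rewrite keeps the iterative ProcessingState traversal that rewrite_recursive! used to expand in place, and the new update_rec_node! macro factors out the bookkeeping of folding processed children back into their parent node.
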
@@ -18,6 +18,7 @@
//! [`TreeNode`] for visiting and rewriting expression and plan trees
use crate::Result;
use std::marker::PhantomData;
use std::sync::Arc;

/// These macros are used to determine continuation during transforming traversals.
@@ -912,104 +913,6 @@ macro_rules! map_until_stop_and_collect {
}}
}

macro_rules! rewrite_recursive {
($START:ident, $NAME:ident, $TRANSFORM_UP:expr, $TRANSFORM_DOWN:expr) => {
let mut queue = vec![ProcessingState::NotStarted($START)];

while let Some(item) = queue.pop() {
match item {
ProcessingState::NotStarted($NAME) => {
let node = $TRANSFORM_DOWN?;

queue.push(match node.tnr {
TreeNodeRecursion::Continue => {
ProcessingState::ProcessingChildren {
non_processed_children: node
.data
.arc_children()
.into_iter()
.cloned()
.rev()
.collect(),
item: node,
processed_children: vec![],
}
}
TreeNodeRecursion::Jump => ProcessingState::ProcessedAllChildren(
node.with_tnr(TreeNodeRecursion::Continue),
),
TreeNodeRecursion::Stop => {
ProcessingState::ProcessedAllChildren(node)
}
})
}
ProcessingState::ProcessingChildren {
mut item,
mut non_processed_children,
mut processed_children,
} => match item.tnr {
TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
if let Some(non_processed_item) = non_processed_children.pop() {
queue.push(ProcessingState::ProcessingChildren {
item,
non_processed_children,
processed_children,
});
queue.push(ProcessingState::NotStarted(non_processed_item));
} else {
item.transformed |=
processed_children.iter().any(|item| item.transformed);
item.data = item.data.with_new_arc_children(
processed_children.into_iter().map(|c| c.data).collect(),
)?;
queue.push(ProcessingState::ProcessedAllChildren(item))
}
}
TreeNodeRecursion::Stop => {
processed_children.extend(
non_processed_children
.into_iter()
.rev()
.map(Transformed::no),
);
item.transformed |=
processed_children.iter().any(|item| item.transformed);
item.data = item.data.with_new_arc_children(
processed_children.into_iter().map(|c| c.data).collect(),
)?;
queue.push(ProcessingState::ProcessedAllChildren(item));
}
},
ProcessingState::ProcessedAllChildren(node) => {
let node = node.transform_parent(|$NAME| $TRANSFORM_UP)?;

if let Some(ProcessingState::ProcessingChildren {
item: mut parent_node,
non_processed_children,
mut processed_children,
..
}) = queue.pop()
{
parent_node.tnr = node.tnr;
processed_children.push(node);

queue.push(ProcessingState::ProcessingChildren {
item: parent_node,
non_processed_children,
processed_children,
})
} else {
debug_assert_eq!(queue.len(), 0);
return Ok(node);
}
}
}
}

unreachable!();
};
}

/// Transformation helper to access [`Transformed`] fields in a [`Result`] easily.
///
/// # Example
@@ -1063,6 +966,59 @@ pub trait DynTreeNode {
) -> Result<Arc<Self>>;
}

pub struct LegacyRewriter<
FD: FnMut(Node) -> Result<Transformed<Node>>,
FU: FnMut(Node) -> Result<Transformed<Node>>,
Node: TreeNode,
> {
f_down_func: FD,
f_up_func: FU,
_node: PhantomData<Node>,
}

impl<
FD: FnMut(Node) -> Result<Transformed<Node>>,
FU: FnMut(Node) -> Result<Transformed<Node>>,
Node: TreeNode,
> LegacyRewriter<FD, FU, Node>
{
pub fn new(f_down_func: FD, f_up_func: FU) -> Self {
Self {
f_down_func,
f_up_func,
_node: PhantomData,
}
}
}
impl<
FD: FnMut(Node) -> Result<Transformed<Node>>,
FU: FnMut(Node) -> Result<Transformed<Node>>,
Node: TreeNode,
> TreeNodeRewriter for LegacyRewriter<FD, FU, Node>
{
type Node = Node;

fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
(self.f_down_func)(node)
}

fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
(self.f_up_func)(node)
}
}

macro_rules! update_rec_node {
($NAME:ident, $CHILDREN:ident) => {{
$NAME.transformed |= $CHILDREN.iter().any(|item| item.transformed);

$NAME.data = $NAME
.data
.with_new_arc_children($CHILDREN.into_iter().map(|c| c.data).collect())?;

$NAME
}};
}

/// Blanket implementation for any `Arc<T>` where `T` implements [`DynTreeNode`]
/// (such as [`Arc<dyn PhysicalExpr>`]).
impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
@@ -1102,30 +1058,121 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
FU: FnMut(Self) -> Result<Transformed<Self>>,
>(
self,
mut f_down: FD,
mut f_up: FU,
f_down: FD,
f_up: FU,
) -> Result<Transformed<Self>> {
rewrite_recursive!(self, node, f_up(node), f_down(node));
self.rewrite(&mut LegacyRewriter::new(f_down, f_up))
}

fn transform_down<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
f: F,
) -> Result<Transformed<Self>> {
self.transform_down_up(f, |node| Ok(Transformed::no(node)))
self.rewrite(&mut LegacyRewriter::new(f, |node| {
Ok(Transformed::no(node))
}))
}

fn transform_up<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
f: F,
) -> Result<Transformed<Self>> {
self.transform_down_up(|node| Ok(Transformed::no(node)), f)
self.rewrite(&mut LegacyRewriter::new(
|node| Ok(Transformed::no(node)),
f,
))
}
fn rewrite<R: TreeNodeRewriter<Node = Self>>(
self,
rewriter: &mut R,
) -> Result<Transformed<Self>> {
rewrite_recursive!(self, node, rewriter.f_up(node), rewriter.f_down(node));
let mut queue = vec![ProcessingState::NotStarted(self)];

while let Some(item) = queue.pop() {
match item {
ProcessingState::NotStarted(node) => {
let node = rewriter.f_down(node)?;

queue.push(match node.tnr {
TreeNodeRecursion::Continue => {
ProcessingState::ProcessingChildren {
non_processed_children: node
.data
.arc_children()
.into_iter()
.cloned()
.rev()
.collect(),
item: node,
processed_children: vec![],
}
}
TreeNodeRecursion::Jump => ProcessingState::ProcessedAllChildren(
node.with_tnr(TreeNodeRecursion::Continue),
),
TreeNodeRecursion::Stop => {
ProcessingState::ProcessedAllChildren(node)
}
})
}
ProcessingState::ProcessingChildren {
mut item,
mut non_processed_children,
mut processed_children,
} => match item.tnr {
TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
if let Some(non_processed_item) = non_processed_children.pop() {
queue.push(ProcessingState::ProcessingChildren {
item,
non_processed_children,
processed_children,
});
queue.push(ProcessingState::NotStarted(non_processed_item));
} else {
queue.push(ProcessingState::ProcessedAllChildren(
update_rec_node!(item, processed_children),
))
}
}
TreeNodeRecursion::Stop => {
processed_children.extend(
non_processed_children
.into_iter()
.rev()
.map(Transformed::no),
);
queue.push(ProcessingState::ProcessedAllChildren(
update_rec_node!(item, processed_children),
));
}
},
ProcessingState::ProcessedAllChildren(node) => {
let node = node.transform_parent(|n| rewriter.f_up(n))?;

if let Some(ProcessingState::ProcessingChildren {
item: mut parent_node,
non_processed_children,
mut processed_children,
..
}) = queue.pop()
{
parent_node.tnr = node.tnr;
processed_children.push(node);

queue.push(ProcessingState::ProcessingChildren {
item: parent_node,
non_processed_children,
processed_children,
})
} else {
debug_assert_eq!(queue.len(), 0);
return Ok(node);
}
}
}
}

unreachable!();
}

fn visit<'n, V: TreeNodeVisitor<'n, Node = Self>>(