Skip to content

Commit

Permalink
compiler: Add groupby mode to MapNodes
Browse files Browse the repository at this point in the history
  • Loading branch information
ZoeLeibowitz committed Dec 19, 2024
1 parent 5a15896 commit e114923
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
7 changes: 5 additions & 2 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,7 @@ def default_retval(cls):
the nodes of type ``child_types`` retrieved by the search. This behaviour
can be changed through this parameter. Accepted values are:
- 'immediate': only the closest matching ancestor is mapped.
- 'groupby': the matching ancestors are grouped together as a single key.
"""

def __init__(self, parent_type=None, child_types=None, mode=None):
Expand All @@ -885,7 +886,7 @@ def __init__(self, parent_type=None, child_types=None, mode=None):
assert issubclass(parent_type, Node)
self.parent_type = parent_type
self.child_types = as_tuple(child_types) or (Call, Expression)
assert mode in (None, 'immediate')
assert mode in (None, 'immediate', 'groupby')
self.mode = mode

def visit_object(self, o, ret=None, **kwargs):
Expand All @@ -902,7 +903,9 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
if parents is None:
parents = []
if isinstance(o, self.child_types):
if self.mode == 'immediate':
if self.mode == 'groupby':
ret.setdefault(as_tuple(parents), []).append(o)
elif self.mode == 'immediate':
if in_parent:
ret.setdefault(parents[-1], []).append(o)
else:
Expand Down
22 changes: 21 additions & 1 deletion tests/test_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from devito.ir.equations import DummyEq
from devito.ir.iet import (Block, Expression, Callable, FindNodes, FindSections,
FindSymbols, IsPerfectIteration, Transformer,
Conditional, printAST, Iteration)
Conditional, printAST, Iteration, MapNodes, Call)
from devito.types import SpaceDimension, Array


Expand Down Expand Up @@ -376,3 +376,23 @@ def test_find_symbols_with_duplicates():
# So we expect FindSymbols to catch five Indexeds in total
symbols = FindSymbols('indexeds').visit(op)
assert len(symbols) == 5


def test_map_nodes(block1):
"""
Tests MapNodes visitor. When MapNodes is created with mode='groupby',
matching ancestors are grouped together under a single key.
This can be useful, for example, when applying transformations to the
outermost Iteration containing a specific node.
"""
map_nodes = MapNodes(Iteration, Expression, mode='groupby').visit(block1)

assert len(map_nodes.keys()) == 1

for iters, (expr,) in map_nodes.items():
# Replace the outermost `Iteration` with a placeholder, represented in
# this example by a `Call` to a `Callable`
callback = Callable('solver', iters[0], 'void', ())
processed = Transformer({iters[0]: Call(callback.name)}).visit(block1)

assert str(processed) == 'solver();'

0 comments on commit e114923

Please sign in to comment.