From e11492338a38a86e877373e39ef47de42a61c60c Mon Sep 17 00:00:00 2001 From: ZoeLeibowitz Date: Thu, 19 Dec 2024 13:56:00 +0000 Subject: [PATCH 1/2] compiler: Add groupby mode to MapNodes --- devito/ir/iet/visitors.py | 7 +++++-- tests/test_visitors.py | 22 +++++++++++++++++++++- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 9b068a7d25..505fe2e001 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -873,6 +873,7 @@ def default_retval(cls): the nodes of type ``child_types`` retrieved by the search. This behaviour can be changed through this parameter. Accepted values are: - 'immediate': only the closest matching ancestor is mapped. + - 'groupby': the matching ancestors are grouped together as a single key. """ def __init__(self, parent_type=None, child_types=None, mode=None): @@ -885,7 +886,7 @@ def __init__(self, parent_type=None, child_types=None, mode=None): assert issubclass(parent_type, Node) self.parent_type = parent_type self.child_types = as_tuple(child_types) or (Call, Expression) - assert mode in (None, 'immediate') + assert mode in (None, 'immediate', 'groupby') self.mode = mode def visit_object(self, o, ret=None, **kwargs): @@ -902,7 +903,9 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False): if parents is None: parents = [] if isinstance(o, self.child_types): - if self.mode == 'immediate': + if self.mode == 'groupby': + ret.setdefault(as_tuple(parents), []).append(o) + elif self.mode == 'immediate': if in_parent: ret.setdefault(parents[-1], []).append(o) else: diff --git a/tests/test_visitors.py b/tests/test_visitors.py index 0d003d68a0..93cb9c31e2 100644 --- a/tests/test_visitors.py +++ b/tests/test_visitors.py @@ -6,7 +6,7 @@ from devito.ir.equations import DummyEq from devito.ir.iet import (Block, Expression, Callable, FindNodes, FindSections, FindSymbols, IsPerfectIteration, Transformer, - Conditional, printAST, Iteration) + Conditional, printAST, Iteration, MapNodes, Call) from devito.types import SpaceDimension, Array @@ -376,3 +376,23 @@ def test_find_symbols_with_duplicates(): # So we expect FindSymbols to catch five Indexeds in total symbols = FindSymbols('indexeds').visit(op) assert len(symbols) == 5 + + +def test_map_nodes(block1): + """ + Tests MapNodes visitor. When MapNodes is created with mode='groupby', + matching ancestors are grouped together under a single key. + This can be useful, for example, when applying transformations to the + outermost Iteration containing a specific node. + """ + map_nodes = MapNodes(Iteration, Expression, mode='groupby').visit(block1) + + assert len(map_nodes.keys()) == 1 + + for iters, (expr,) in map_nodes.items(): + # Replace the outermost `Iteration` with a placeholder, represented in + # this example by a `Call` to a `Callable` + callback = Callable('solver', iters[0], 'void', ()) + processed = Transformer({iters[0]: Call(callback.name)}).visit(block1) + + assert str(processed) == 'solver();' From 0dae417e9794f8c903c99546446f01b6983e50fd Mon Sep 17 00:00:00 2001 From: ZoeLeibowitz Date: Thu, 19 Dec 2024 14:04:05 +0000 Subject: [PATCH 2/2] misc: Comments --- tests/test_visitors.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_visitors.py b/tests/test_visitors.py index 93cb9c31e2..937b33d09f 100644 --- a/tests/test_visitors.py +++ b/tests/test_visitors.py @@ -390,8 +390,7 @@ def test_map_nodes(block1): assert len(map_nodes.keys()) == 1 for iters, (expr,) in map_nodes.items(): - # Replace the outermost `Iteration` with a placeholder, represented in - # this example by a `Call` to a `Callable` + # Replace the outermost `Iteration` with a `Call` callback = Callable('solver', iters[0], 'void', ()) processed = Transformer({iters[0]: Call(callback.name)}).visit(block1)