diff --git a/tests/pattern_rewriter/test_pattern_rewriter.py b/tests/pattern_rewriter/test_pattern_rewriter.py index 0d2d1d1190..b1b6270318 100644 --- a/tests/pattern_rewriter/test_pattern_rewriter.py +++ b/tests/pattern_rewriter/test_pattern_rewriter.py @@ -4,6 +4,7 @@ import pytest from conftest import assert_print_op +from xdsl.builder import ImplicitBuilder from xdsl.context import MLContext from xdsl.dialects import test from xdsl.dialects.arith import AddiOp, Arith, ConstantOp, MuliOp @@ -14,6 +15,7 @@ IntegerType, ModuleOp, StringAttr, + UnitAttr, i32, i64, ) @@ -1411,6 +1413,44 @@ def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter): ) +def test_pattern_rewriter_as_op_builder(): + """Test that the PatternRewriter works as an OpBuilder.""" + prog = """ +"builtin.module"() ({ + "test.op"() : () -> () + "test.op"() {"nomatch"} : () -> () + "test.op"() : () -> () +}) : () -> ()""" + expected = """ +"builtin.module"() ({ + "test.op"() {"inserted"} : () -> () + "test.op"() {"replaced"} : () -> () + "test.op"() {"nomatch"} : () -> () + "test.op"() {"inserted"} : () -> () + "test.op"() {"replaced"} : () -> () +}) : () -> ()""" + + class Rewrite(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter): + if "nomatch" in op.attributes: + return + with ImplicitBuilder(rewriter): + test.TestOp.create(attributes={"inserted": UnitAttr()}) + rewriter.replace_matched_op( + test.TestOp.create(attributes={"replaced": UnitAttr()}) + ) + + rewrite_and_compare( + prog, + expected, + PatternRewriteWalker(Rewrite(), apply_recursively=False), + op_inserted=4, + op_removed=2, + op_replaced=2, + ) + + def test_type_conversion(): """Test rewriter on ops without results""" prog = """\ diff --git a/xdsl/backend/riscv/riscv_scf_to_asm.py b/xdsl/backend/riscv/riscv_scf_to_asm.py index 9214853c66..a4b4dcc972 100644 --- a/xdsl/backend/riscv/riscv_scf_to_asm.py +++ b/xdsl/backend/riscv/riscv_scf_to_asm.py @@ -69,15 +69,15 @@ def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter, /): yield_op = body.last_op assert isinstance(yield_op, riscv_scf.YieldOp) - body.insert_ops_after( + rewriter.insert_op( [ riscv.AddOp(get_loop_var, op.step, rd=loop_var_reg), riscv.BltOp(get_loop_var, op.ub, scf_body), riscv.LabelOp(scf_body_end), ], - yield_op, + InsertPoint.after(yield_op), ) - body.erase_op(yield_op) + rewriter.erase_op(yield_op) # We know that the body is not empty now. assert body.first_op is not None diff --git a/xdsl/pattern_rewriter.py b/xdsl/pattern_rewriter.py index 06a6bd5239..ec10ff20c5 100644 --- a/xdsl/pattern_rewriter.py +++ b/xdsl/pattern_rewriter.py @@ -10,7 +10,7 @@ from typing_extensions import deprecated -from xdsl.builder import BuilderListener +from xdsl.builder import Builder, BuilderListener from xdsl.dialects.builtin import ArrayAttr, DictionaryAttr, ModuleOp from xdsl.ir import ( Attribute, @@ -77,8 +77,8 @@ def extend_from_listener(self, listener: BuilderListener | PatternRewriterListen ) -@dataclass(eq=False) -class PatternRewriter(PatternRewriterListener): +@dataclass(eq=False, init=False) +class PatternRewriter(Builder, PatternRewriterListener): """ A rewriter used during pattern matching. Once an operation is matched, this rewriter is used to apply @@ -91,6 +91,11 @@ class PatternRewriter(PatternRewriterListener): has_done_action: bool = field(default=False, init=False) """Has the rewriter done any action during the current match.""" + def __init__(self, current_operation: Operation): + PatternRewriterListener.__init__(self) + self.current_operation = current_operation + Builder.__init__(self, InsertPoint.before(current_operation)) + def insert_op( self, op: Operation | Sequence[Operation], insertion_point: InsertPoint ): @@ -833,6 +838,7 @@ def _process_worklist(self, listener: PatternRewriterListener) -> bool: # Reset the rewriter on `op` rewriter.has_done_action = False rewriter.current_operation = op + rewriter.insertion_point = InsertPoint.before(op) # Apply the pattern on the operation try: diff --git a/xdsl/transforms/convert_scf_to_openmp.py b/xdsl/transforms/convert_scf_to_openmp.py index 8a13c29042..c60395ccd8 100644 --- a/xdsl/transforms/convert_scf_to_openmp.py +++ b/xdsl/transforms/convert_scf_to_openmp.py @@ -14,6 +14,7 @@ RewritePattern, op_type_rewrite_pattern, ) +from xdsl.rewriter import InsertPoint @dataclass @@ -43,7 +44,8 @@ def match_and_rewrite(self, loop: scf.ParallelOp, rewriter: PatternRewriter, /): regions=[Region(Block())], operands=[[], [], [], [], [], []], ) - with ImplicitBuilder(parallel.region): + rewriter.insertion_point = InsertPoint.at_end(parallel.region.block) + with ImplicitBuilder(rewriter): if self.chunk is None: chunk_op = [] else: @@ -65,7 +67,9 @@ def match_and_rewrite(self, loop: scf.ParallelOp, rewriter: PatternRewriter, /): omp.ScheduleKind(self.schedule) ) omp.TerminatorOp() - with ImplicitBuilder(wsloop.body): + + rewriter.insertion_point = InsertPoint.at_end(wsloop.body.block) + with ImplicitBuilder(rewriter): loop_nest = omp.LoopNestOp( operands=[ loop.lowerBound[:collapse], @@ -75,15 +79,21 @@ def match_and_rewrite(self, loop: scf.ParallelOp, rewriter: PatternRewriter, /): regions=[Region(Block(arg_types=[IndexType()] * collapse))], ) omp.TerminatorOp() - with ImplicitBuilder(loop_nest.body): + + rewriter.insertion_point = InsertPoint.at_end(loop_nest.body.block) + with ImplicitBuilder(rewriter): scope = memref.AllocaScopeOp(result_types=[[]], regions=[Region(Block())]) omp.YieldOp() - with ImplicitBuilder(scope.scope): + + rewriter.insertion_point = InsertPoint.at_end(scope.scope.block) + with ImplicitBuilder(rewriter): scope_terminator = memref.AllocaScopeReturnOp(operands=[[]]) + for newarg, oldarg in zip( loop_nest.body.block.args, loop.body.block.args[:collapse] ): oldarg.replace_by(newarg) + for _ in range(collapse): loop.body.block.erase_arg(loop.body.block.args[0]) if collapse < len(loop.lowerBound): @@ -96,8 +106,9 @@ def match_and_rewrite(self, loop: scf.ParallelOp, rewriter: PatternRewriter, /): new_ops = [new_loop] else: new_ops = [loop.body.block.detach_op(o) for o in loop.body.block.ops] - new_ops.pop() - scope.scope.block.insert_ops_before(new_ops, scope_terminator) + last_op = new_ops.pop() + rewriter.erase_op(last_op) + rewriter.insert_op(new_ops, InsertPoint.before(scope_terminator)) rewriter.replace_matched_op(parallel)