Skip to content

Commit

Permalink
core: Make PatternRewriter a Builder
Browse files Browse the repository at this point in the history
  • Loading branch information
math-fehr committed Dec 30, 2024
1 parent 308bd7b commit a3414af
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 12 deletions.
40 changes: 40 additions & 0 deletions tests/pattern_rewriter/test_pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from conftest import assert_print_op

from xdsl.builder import ImplicitBuilder
from xdsl.context import MLContext
from xdsl.dialects import test
from xdsl.dialects.arith import AddiOp, Arith, ConstantOp, MuliOp
Expand All @@ -14,6 +15,7 @@
IntegerType,
ModuleOp,
StringAttr,
UnitAttr,
i32,
i64,
)
Expand Down Expand Up @@ -1411,6 +1413,44 @@ def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter):
)


def test_pattern_rewriter_as_op_builder():
"""Test that the PatternRewriter works as an OpBuilder."""
prog = """
"builtin.module"() ({
"test.op"() : () -> ()
"test.op"() {"nomatch"} : () -> ()
"test.op"() : () -> ()
}) : () -> ()"""
expected = """
"builtin.module"() ({
"test.op"() {"inserted"} : () -> ()
"test.op"() {"replaced"} : () -> ()
"test.op"() {"nomatch"} : () -> ()
"test.op"() {"inserted"} : () -> ()
"test.op"() {"replaced"} : () -> ()
}) : () -> ()"""

class Rewrite(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter):
if "nomatch" in op.attributes:
return
with ImplicitBuilder(rewriter):
test.TestOp.create(attributes={"inserted": UnitAttr()})
rewriter.replace_matched_op(
test.TestOp.create(attributes={"replaced": UnitAttr()})
)

rewrite_and_compare(
prog,
expected,
PatternRewriteWalker(Rewrite(), apply_recursively=False),
op_inserted=4,
op_removed=2,
op_replaced=2,
)


def test_type_conversion():
"""Test rewriter on ops without results"""
prog = """\
Expand Down
6 changes: 3 additions & 3 deletions xdsl/backend/riscv/riscv_scf_to_asm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,15 @@ def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter, /):
yield_op = body.last_op
assert isinstance(yield_op, riscv_scf.YieldOp)

body.insert_ops_after(
rewriter.insert_op(
[
riscv.AddOp(get_loop_var, op.step, rd=loop_var_reg),
riscv.BltOp(get_loop_var, op.ub, scf_body),
riscv.LabelOp(scf_body_end),
],
yield_op,
InsertPoint.after(yield_op),
)
body.erase_op(yield_op)
rewriter.erase_op(yield_op)

# We know that the body is not empty now.
assert body.first_op is not None
Expand Down
12 changes: 9 additions & 3 deletions xdsl/pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from typing_extensions import deprecated

from xdsl.builder import BuilderListener
from xdsl.builder import Builder, BuilderListener
from xdsl.dialects.builtin import ArrayAttr, DictionaryAttr, ModuleOp
from xdsl.ir import (
Attribute,
Expand Down Expand Up @@ -77,8 +77,8 @@ def extend_from_listener(self, listener: BuilderListener | PatternRewriterListen
)


@dataclass(eq=False)
class PatternRewriter(PatternRewriterListener):
@dataclass(eq=False, init=False)
class PatternRewriter(Builder, PatternRewriterListener):
"""
A rewriter used during pattern matching.
Once an operation is matched, this rewriter is used to apply
Expand All @@ -91,6 +91,11 @@ class PatternRewriter(PatternRewriterListener):
has_done_action: bool = field(default=False, init=False)
"""Has the rewriter done any action during the current match."""

def __init__(self, current_operation: Operation):
PatternRewriterListener.__init__(self)
self.current_operation = current_operation
Builder.__init__(self, InsertPoint.before(current_operation))

def insert_op(
self, op: Operation | Sequence[Operation], insertion_point: InsertPoint
):
Expand Down Expand Up @@ -833,6 +838,7 @@ def _process_worklist(self, listener: PatternRewriterListener) -> bool:
# Reset the rewriter on `op`
rewriter.has_done_action = False
rewriter.current_operation = op
rewriter.insertion_point = InsertPoint.before(op)

# Apply the pattern on the operation
try:
Expand Down
23 changes: 17 additions & 6 deletions xdsl/transforms/convert_scf_to_openmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint


@dataclass
Expand Down Expand Up @@ -43,7 +44,8 @@ def match_and_rewrite(self, loop: scf.ParallelOp, rewriter: PatternRewriter, /):
regions=[Region(Block())],
operands=[[], [], [], [], [], []],
)
with ImplicitBuilder(parallel.region):
rewriter.insertion_point = InsertPoint.at_end(parallel.region.block)
with ImplicitBuilder(rewriter):
if self.chunk is None:
chunk_op = []
else:
Expand All @@ -65,7 +67,9 @@ def match_and_rewrite(self, loop: scf.ParallelOp, rewriter: PatternRewriter, /):
omp.ScheduleKind(self.schedule)
)
omp.TerminatorOp()
with ImplicitBuilder(wsloop.body):

rewriter.insertion_point = InsertPoint.at_end(wsloop.body.block)
with ImplicitBuilder(rewriter):
loop_nest = omp.LoopNestOp(
operands=[
loop.lowerBound[:collapse],
Expand All @@ -75,15 +79,21 @@ def match_and_rewrite(self, loop: scf.ParallelOp, rewriter: PatternRewriter, /):
regions=[Region(Block(arg_types=[IndexType()] * collapse))],
)
omp.TerminatorOp()
with ImplicitBuilder(loop_nest.body):

rewriter.insertion_point = InsertPoint.at_end(loop_nest.body.block)
with ImplicitBuilder(rewriter):
scope = memref.AllocaScopeOp(result_types=[[]], regions=[Region(Block())])
omp.YieldOp()
with ImplicitBuilder(scope.scope):

rewriter.insertion_point = InsertPoint.at_end(scope.scope.block)
with ImplicitBuilder(rewriter):
scope_terminator = memref.AllocaScopeReturnOp(operands=[[]])

for newarg, oldarg in zip(
loop_nest.body.block.args, loop.body.block.args[:collapse]
):
oldarg.replace_by(newarg)

for _ in range(collapse):
loop.body.block.erase_arg(loop.body.block.args[0])
if collapse < len(loop.lowerBound):
Expand All @@ -96,8 +106,9 @@ def match_and_rewrite(self, loop: scf.ParallelOp, rewriter: PatternRewriter, /):
new_ops = [new_loop]
else:
new_ops = [loop.body.block.detach_op(o) for o in loop.body.block.ops]
new_ops.pop()
scope.scope.block.insert_ops_before(new_ops, scope_terminator)
last_op = new_ops.pop()
rewriter.erase_op(last_op)
rewriter.insert_op(new_ops, InsertPoint.before(scope_terminator))

rewriter.replace_matched_op(parallel)

Expand Down

0 comments on commit a3414af

Please sign in to comment.