Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: Make PatternRewriter a Builder #3540

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions tests/pattern_rewriter/test_pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from conftest import assert_print_op

from xdsl.builder import ImplicitBuilder
from xdsl.context import MLContext
from xdsl.dialects import test
from xdsl.dialects.arith import AddiOp, Arith, ConstantOp, MuliOp
Expand All @@ -14,6 +15,7 @@
IntegerType,
ModuleOp,
StringAttr,
UnitAttr,
i32,
i64,
)
Expand Down Expand Up @@ -1410,6 +1412,44 @@ def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter):
)


def test_pattern_rewriter_as_op_builder():
"""Test that the PatternRewriter works as an OpBuilder."""
prog = """
"builtin.module"() ({
"test.op"() : () -> ()
"test.op"() {"nomatch"} : () -> ()
"test.op"() : () -> ()
}) : () -> ()"""
expected = """
"builtin.module"() ({
"test.op"() {"inserted"} : () -> ()
"test.op"() {"replaced"} : () -> ()
"test.op"() {"nomatch"} : () -> ()
"test.op"() {"inserted"} : () -> ()
"test.op"() {"replaced"} : () -> ()
}) : () -> ()"""

class Rewrite(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter):
if "nomatch" in op.attributes:
return
with ImplicitBuilder(rewriter):
test.TestOp.create(attributes={"inserted": UnitAttr()})
rewriter.replace_matched_op(
test.TestOp.create(attributes={"replaced": UnitAttr()})
)

rewrite_and_compare(
prog,
expected,
PatternRewriteWalker(Rewrite(), apply_recursively=False),
op_inserted=4,
op_removed=2,
op_replaced=2,
)


def test_type_conversion():
"""Test rewriter on ops without results"""
prog = """\
Expand Down Expand Up @@ -1777,3 +1817,32 @@ def convert_type(self, typ: IntegerType) -> IndexType:
op_replaced=1,
op_modified=1,
)


def test_pattern_rewriter_erase_op_with_region():
"""Test that erasing an operation with a region works correctly."""
prog = """
"builtin.module"() ({
"test.op"() ({
"test.op"() {"error_if_matching"} : () -> ()
}): () -> ()
}) : () -> ()"""
expected = """
"builtin.module"() ({
^0:
}) : () -> ()"""

class Rewrite(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter):
if "error_if_matching" in op.attributes:
raise Exception("operation that is supposed to be deleted was matched")
assert not op.attributes
rewriter.erase_matched_op()

rewrite_and_compare(
prog,
expected,
PatternRewriteWalker(Rewrite(), apply_recursively=False),
op_removed=1,
)
7 changes: 7 additions & 0 deletions tests/test_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,13 @@ def transformation_unsafe(module: ModuleOp, rewriter: Rewriter) -> None:
rewrite_and_compare(prog, expected, transformation_safe)


def test_erase_orphan_op():
"""Test that we can erase an orphan operation."""
module = ModuleOp([])
rewriter = Rewriter()
rewriter.erase_op(module)


def test_inline_region_before():
"""Test the insertion of a block in a region."""
prog = """\
Expand Down
6 changes: 3 additions & 3 deletions xdsl/backend/riscv/riscv_scf_to_asm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,15 @@ def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter, /):
yield_op = body.last_op
assert isinstance(yield_op, riscv_scf.YieldOp)

body.insert_ops_after(
rewriter.insert_op(
compor marked this conversation as resolved.
Show resolved Hide resolved
[
riscv.AddOp(get_loop_var, op.step, rd=loop_var_reg),
riscv.BltOp(get_loop_var, op.ub, scf_body),
riscv.LabelOp(scf_body_end),
],
yield_op,
InsertPoint.after(yield_op),
)
body.erase_op(yield_op)
rewriter.erase_op(yield_op)

# We know that the body is not empty now.
assert body.first_op is not None
Expand Down
18 changes: 14 additions & 4 deletions xdsl/pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from typing_extensions import deprecated

from xdsl.builder import BuilderListener
from xdsl.builder import Builder, BuilderListener
from xdsl.dialects.builtin import ArrayAttr, DictionaryAttr, ModuleOp
from xdsl.ir import (
Attribute,
Expand Down Expand Up @@ -77,8 +77,8 @@ def extend_from_listener(self, listener: BuilderListener | PatternRewriterListen
)


@dataclass(eq=False)
class PatternRewriter(PatternRewriterListener):
@dataclass(eq=False, init=False)
class PatternRewriter(Builder, PatternRewriterListener):
"""
A rewriter used during pattern matching.
Once an operation is matched, this rewriter is used to apply
Expand All @@ -91,6 +91,11 @@ class PatternRewriter(PatternRewriterListener):
has_done_action: bool = field(default=False, init=False)
math-fehr marked this conversation as resolved.
Show resolved Hide resolved
"""Has the rewriter done any action during the current match."""

def __init__(self, current_operation: Operation):
math-fehr marked this conversation as resolved.
Show resolved Hide resolved
PatternRewriterListener.__init__(self)
self.current_operation = current_operation
Builder.__init__(self, InsertPoint.before(current_operation))

def insert_op(
self, op: Operation | Sequence[Operation], insertion_point: InsertPoint
):
Expand Down Expand Up @@ -726,7 +731,11 @@ def _handle_operation_removal(self, op: Operation) -> None:
"""Handle removal of an operation."""
if self.apply_recursively:
self._add_operands_to_worklist(op.operands)
self._worklist.remove(op)
if op.regions:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a separate change to making PatternRewriter a Builder?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will rebase this branch on a new one, this was indeed a bug before that can be refactored.

for sub_op in op.walk():
self._worklist.remove(sub_op)
else:
self._worklist.remove(op)
math-fehr marked this conversation as resolved.
Show resolved Hide resolved

def _handle_operation_modification(self, op: Operation) -> None:
"""Handle modification of an operation."""
Expand Down Expand Up @@ -829,6 +838,7 @@ def _process_worklist(self, listener: PatternRewriterListener) -> bool:
# Reset the rewriter on `op`
rewriter.has_done_action = False
rewriter.current_operation = op
rewriter.insertion_point = InsertPoint.before(op)

# Apply the pattern on the operation
try:
Expand Down
8 changes: 4 additions & 4 deletions xdsl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ def erase_op(op: Operation, safe_erase: bool = True):
If safe_erase is True, check that the operation has no uses.
Otherwise, replace its uses with ErasedSSAValue.
"""
assert op.parent is not None, "Cannot erase an operation that has no parents"

block = op.parent
block.erase_op(op, safe_erase=safe_erase)
if (block := op.parent) is not None:
block.erase_op(op, safe_erase=safe_erase)
else:
op.erase(safe_erase=safe_erase)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it now ok to erase a operation with no parents?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll move this to a separate PR as well, the idea is that sometimes we detach an operation and then destroy it.

Comment on lines +72 to +75
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TBH this feels like it could be its own PR, with an accompanying test


@staticmethod
def replace_op(
Expand Down
23 changes: 17 additions & 6 deletions xdsl/transforms/convert_scf_to_openmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint


@dataclass
Expand Down Expand Up @@ -43,7 +44,8 @@ def match_and_rewrite(self, loop: scf.ParallelOp, rewriter: PatternRewriter, /):
regions=[Region(Block())],
operands=[[], [], [], [], [], []],
)
with ImplicitBuilder(parallel.region):
rewriter.insertion_point = InsertPoint.at_end(parallel.region.block)
with ImplicitBuilder(rewriter):
math-fehr marked this conversation as resolved.
Show resolved Hide resolved
if self.chunk is None:
chunk_op = []
else:
Expand All @@ -65,7 +67,9 @@ def match_and_rewrite(self, loop: scf.ParallelOp, rewriter: PatternRewriter, /):
omp.ScheduleKind(self.schedule)
)
omp.TerminatorOp()
with ImplicitBuilder(wsloop.body):

rewriter.insertion_point = InsertPoint.at_end(wsloop.body.block)
with ImplicitBuilder(rewriter):
loop_nest = omp.LoopNestOp(
operands=[
loop.lowerBound[:collapse],
Expand All @@ -75,15 +79,21 @@ def match_and_rewrite(self, loop: scf.ParallelOp, rewriter: PatternRewriter, /):
regions=[Region(Block(arg_types=[IndexType()] * collapse))],
)
omp.TerminatorOp()
with ImplicitBuilder(loop_nest.body):

rewriter.insertion_point = InsertPoint.at_end(loop_nest.body.block)
with ImplicitBuilder(rewriter):
scope = memref.AllocaScopeOp(result_types=[[]], regions=[Region(Block())])
omp.YieldOp()
with ImplicitBuilder(scope.scope):

rewriter.insertion_point = InsertPoint.at_end(scope.scope.block)
with ImplicitBuilder(rewriter):
scope_terminator = memref.AllocaScopeReturnOp(operands=[[]])

for newarg, oldarg in zip(
loop_nest.body.block.args, loop.body.block.args[:collapse]
):
oldarg.replace_by(newarg)

for _ in range(collapse):
loop.body.block.erase_arg(loop.body.block.args[0])
if collapse < len(loop.lowerBound):
Expand All @@ -96,8 +106,9 @@ def match_and_rewrite(self, loop: scf.ParallelOp, rewriter: PatternRewriter, /):
new_ops = [new_loop]
else:
new_ops = [loop.body.block.detach_op(o) for o in loop.body.block.ops]
new_ops.pop()
scope.scope.block.insert_ops_before(new_ops, scope_terminator)
last_op = new_ops.pop()
rewriter.erase_op(last_op)
rewriter.insert_op(new_ops, InsertPoint.before(scope_terminator))

rewriter.replace_matched_op(parallel)

Expand Down
Loading