Skip to content

Commit

Permalink
core: Deprecate methods that should be using BlockInsertPoint (#3705)
Browse files Browse the repository at this point in the history
  • Loading branch information
math-fehr authored Jan 20, 2025
1 parent 4e01836 commit 32d338c
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 32 deletions.
14 changes: 9 additions & 5 deletions tests/pattern_rewriter/test_pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
attr_type_rewrite_pattern,
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint
from xdsl.rewriter import BlockInsertPoint, InsertPoint


def rewrite_and_compare(
Expand Down Expand Up @@ -1225,7 +1225,7 @@ def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter):
if op.parent is None:
return

rewriter.inline_region_before(op.regions[0], op.parent)
rewriter.inline_region(op.regions[0], BlockInsertPoint.before(op.parent))
rewriter.erase_matched_op()

rewrite_and_compare(
Expand Down Expand Up @@ -1272,7 +1272,7 @@ def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter):
if op.parent is None:
return

rewriter.inline_region_after(op.regions[0], op.parent)
rewriter.inline_region(op.regions[0], BlockInsertPoint.after(op.parent))
rewriter.erase_matched_op()

rewrite_and_compare(
Expand Down Expand Up @@ -1320,7 +1320,9 @@ def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter):
if parent_region is None:
return

rewriter.inline_region_at_start(op.regions[0], parent_region)
rewriter.inline_region(
op.regions[0], BlockInsertPoint.at_start(parent_region)
)
rewriter.erase_matched_op()

rewrite_and_compare(
Expand Down Expand Up @@ -1368,7 +1370,9 @@ def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter):
if parent_region is None:
return

rewriter.inline_region_at_end(op.regions[0], parent_region)
rewriter.inline_region(
op.regions[0], BlockInsertPoint.at_end(parent_region)
)
rewriter.erase_matched_op()

rewrite_and_compare(
Expand Down
16 changes: 8 additions & 8 deletions tests/test_op_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,29 +116,29 @@ def test_builder_create_block():
target = Region([block1, block2])
builder = Builder(InsertPoint.at_start(block1))

new_block1 = builder.create_block_at_start(target, (i32,))
new_block1 = builder.create_block(BlockInsertPoint.at_start(target), (i32,))
assert len(new_block1.args) == 1
assert new_block1.args[0].type == i32
assert len(target.blocks) == 3
assert target.blocks[0] == new_block1
assert builder.insertion_point == InsertPoint.at_start(new_block1)

new_block2 = builder.create_block_at_end(target, (i64,))
new_block2 = builder.create_block(BlockInsertPoint.at_end(target), (i64,))
assert len(new_block2.args) == 1
assert new_block2.args[0].type == i64
assert len(target.blocks) == 4
assert target.blocks[3] == new_block2
assert builder.insertion_point == InsertPoint.at_start(new_block2)

new_block3 = builder.create_block_before(block2, (i32, i64))
new_block3 = builder.create_block(BlockInsertPoint.before(block2), (i32, i64))
assert len(new_block3.args) == 2
assert new_block3.args[0].type == i32
assert new_block3.args[1].type == i64
assert len(target.blocks) == 5
assert target.blocks[2] == new_block3
assert builder.insertion_point == InsertPoint.at_start(new_block3)

new_block4 = builder.create_block_after(block2, (i64, i32))
new_block4 = builder.create_block(BlockInsertPoint.after(block2), (i64, i32))
assert len(new_block4.args) == 2
assert new_block4.args[0].type == i64
assert new_block4.args[1].type == i32
Expand Down Expand Up @@ -182,10 +182,10 @@ def add_block_on_create(b: Block):

b.block_creation_handler = [add_block_on_create]

b1 = b.create_block_at_start(region)
b2 = b.create_block_at_end(region)
b3 = b.create_block_before(block)
b4 = b.create_block_after(block)
b1 = b.create_block(BlockInsertPoint.at_start(region))
b2 = b.create_block(BlockInsertPoint.at_end(region))
b3 = b.create_block(BlockInsertPoint.before(block))
b4 = b.create_block(BlockInsertPoint.after(block))

assert created_blocks == [b1, b2, b3, b4]

Expand Down
26 changes: 15 additions & 11 deletions tests/test_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from xdsl.dialects.builtin import Builtin, Float32Type, Float64Type, ModuleOp, i32, i64
from xdsl.ir import Block, Region
from xdsl.parser import Parser
from xdsl.rewriter import InsertPoint, Rewriter
from xdsl.rewriter import BlockInsertPoint, InsertPoint, Rewriter


def rewrite_and_compare(
Expand Down Expand Up @@ -289,7 +289,9 @@ def test_insert_block_before():
"""

def insert_empty_block_before(module: ModuleOp, rewriter: Rewriter) -> None:
rewriter.insert_block_before(Block(), module.regions[0].blocks[0])
rewriter.insert_block(
Block(), BlockInsertPoint.before(module.regions[0].blocks[0])
)

rewrite_and_compare(prog, expected, insert_empty_block_before)

Expand All @@ -312,7 +314,9 @@ def test_insert_block_after():
"""

def insert_empty_block_after(module: ModuleOp, rewriter: Rewriter) -> None:
rewriter.insert_block_after(Block(), module.regions[0].blocks[0])
rewriter.insert_block(
Block(), BlockInsertPoint.after(module.regions[0].blocks[0])
)

rewrite_and_compare(prog, expected, insert_empty_block_after)

Expand Down Expand Up @@ -510,7 +514,7 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None:
Block((test.TestOp(result_types=(Float64Type(),)),)),
)
)
rewriter.inline_region_before(region, module.body.blocks[1])
rewriter.inline_region(region, BlockInsertPoint.before(module.body.blocks[1]))

rewrite_and_compare(prog, expected, transformation)

Expand Down Expand Up @@ -544,7 +548,7 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None:
Block((test.TestOp(result_types=(Float64Type(),)),)),
)
)
rewriter.inline_region_after(region, module.body.blocks[0])
rewriter.inline_region(region, BlockInsertPoint.after(module.body.blocks[0]))

rewrite_and_compare(prog, expected, transformation)

Expand Down Expand Up @@ -578,7 +582,7 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None:
Block((test.TestOp(result_types=(Float64Type(),)),)),
)
)
rewriter.inline_region_at_start(region, module.body)
rewriter.inline_region(region, BlockInsertPoint.at_start(module.body))

rewrite_and_compare(prog, expected, transformation)

Expand Down Expand Up @@ -612,7 +616,7 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None:
Block((test.TestOp(result_types=(Float64Type(),)),)),
)
)
rewriter.inline_region_at_end(region, module.body)
rewriter.inline_region(region, BlockInsertPoint.at_end(module.body))

rewrite_and_compare(prog, expected, transformation)

Expand All @@ -621,13 +625,13 @@ def test_verify_inline_region():
region = Region(Block())

with pytest.raises(ValueError, match="Cannot move region into itself."):
Rewriter.inline_region_before(region, region.block)
Rewriter.inline_region(region, BlockInsertPoint.before(region.block))

with pytest.raises(ValueError, match="Cannot move region into itself."):
Rewriter.inline_region_after(region, region.block)
Rewriter.inline_region(region, BlockInsertPoint.after(region.block))

with pytest.raises(ValueError, match="Cannot move region into itself."):
Rewriter.inline_region_at_start(region, region)
Rewriter.inline_region(region, BlockInsertPoint.at_start(region))

with pytest.raises(ValueError, match="Cannot move region into itself."):
Rewriter.inline_region_at_end(region, region)
Rewriter.inline_region(region, BlockInsertPoint.at_end(region))
4 changes: 2 additions & 2 deletions xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint
from xdsl.rewriter import BlockInsertPoint, InsertPoint


class LowerRiscvScfForPattern(RewritePattern):
Expand Down Expand Up @@ -119,7 +119,7 @@ def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter, /):
),
)

rewriter.inline_region_before(op.body, end_block)
rewriter.inline_region(op.body, BlockInsertPoint.before(end_block))

# Move lb to new register to initialize the iv.
# Skip for loop if condition is not satisfied at start.
Expand Down
14 changes: 13 additions & 1 deletion xdsl/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from types import TracebackType
from typing import ClassVar, TypeAlias, overload

from typing_extensions import deprecated

from xdsl.dialects.builtin import ArrayAttr
from xdsl.ir import Attribute, Block, BlockArgument, Operation, OperationInvT, Region
from xdsl.rewriter import BlockInsertPoint, InsertPoint, Rewriter
Expand Down Expand Up @@ -75,7 +77,7 @@ def insert(self, op: OperationInvT) -> OperationInvT:
return op

def create_block(
self, insert_point: BlockInsertPoint, arg_types: Iterable[Attribute]
self, insert_point: BlockInsertPoint, arg_types: Iterable[Attribute] = ()
) -> Block:
"""
Create a block at the given location, and set the operation insertion point
Expand All @@ -89,6 +91,9 @@ def create_block(
self.handle_block_creation(block)
return block

@deprecated(
"Use create_block(BlockInsertPoint.before(insert_before), arg_types) instead"
)
def create_block_before(
self, insert_before: Block, arg_types: Iterable[Attribute] = ()
) -> Block:
Expand All @@ -98,6 +103,9 @@ def create_block_before(
"""
return self.create_block(BlockInsertPoint.before(insert_before), arg_types)

@deprecated(
"Use create_block(BlockInsertPoint.after(insert_after), arg_types) instead"
)
def create_block_after(
self, insert_after: Block, arg_types: Iterable[Attribute] = ()
) -> Block:
Expand All @@ -107,6 +115,9 @@ def create_block_after(
"""
return self.create_block(BlockInsertPoint.after(insert_after), arg_types)

@deprecated(
"Use create_block(BlockInsertPoint.at_start(region), arg_types) instead"
)
def create_block_at_start(
self, region: Region, arg_types: Iterable[Attribute] = ()
) -> Block:
Expand All @@ -116,6 +127,7 @@ def create_block_at_start(
"""
return self.create_block(BlockInsertPoint.at_start(region), arg_types)

@deprecated("Use create_block(BlockInsertPoint.at_end(region), arg_types) instead")
def create_block_at_end(
self, region: Region, arg_types: Iterable[Attribute] = ()
) -> Block:
Expand Down
12 changes: 12 additions & 0 deletions xdsl/pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,18 +356,30 @@ def inline_region(self, region: Region, insertion_point: BlockInsertPoint) -> No
self.has_done_action = True
Rewriter.inline_region(region, insertion_point)

@deprecated(
"Please use `inline_region(region, BlockInsertPoint.before(target))` instead"
)
def inline_region_before(self, region: Region, target: Block) -> None:
"""Move the region blocks to an existing region."""
self.inline_region(region, BlockInsertPoint.before(target))

@deprecated(
"Please use `inline_region(region, BlockInsertPoint.after(target))` instead"
)
def inline_region_after(self, region: Region, target: Block) -> None:
"""Move the region blocks to an existing region."""
self.inline_region(region, BlockInsertPoint.after(target))

@deprecated(
"Please use `inline_region(region, BlockInsertPoint.at_start(target))` instead"
)
def inline_region_at_start(self, region: Region, target: Region) -> None:
"""Move the region blocks to an existing region."""
self.inline_region(region, BlockInsertPoint.at_start(target))

@deprecated(
"Please use `inline_region(region, BlockInsertPoint.at_end(target))` instead"
)
def inline_region_at_end(self, region: Region, target: Region) -> None:
"""Move the region blocks to an existing region."""
self.inline_region(region, BlockInsertPoint.at_end(target))
Expand Down
10 changes: 10 additions & 0 deletions xdsl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field

from typing_extensions import deprecated

from xdsl.ir import Block, Operation, Region, SSAValue


Expand Down Expand Up @@ -240,6 +242,7 @@ def insert_block(block: Block | Iterable[Block], insert_point: BlockInsertPoint)
else:
region.add_block(block)

@deprecated("Use `insert_block(block, BlockInsertPoint.after(target))` instead")
@staticmethod
def insert_block_after(block: Block | list[Block], target: Block):
"""
Expand All @@ -249,6 +252,7 @@ def insert_block_after(block: Block | list[Block], target: Block):
"""
Rewriter.insert_block(block, BlockInsertPoint.after(target))

@deprecated("Use `insert_block(block, BlockInsertPoint.before(target))` instead")
@staticmethod
def insert_block_before(block: Block | list[Block], target: Block):
"""
Expand Down Expand Up @@ -284,21 +288,27 @@ def inline_region(region: Region, insertion_point: BlockInsertPoint) -> None:
else:
region.move_blocks(insertion_point.region)

@deprecated("Use `inline_region(region, BlockInsertPoint.before(target))` instead")
@staticmethod
def inline_region_before(region: Region, target: Block) -> None:
"""Move the region blocks to an existing region, before `target`."""
Rewriter.inline_region(region, BlockInsertPoint.before(target))

@deprecated("Use `inline_region(region, BlockInsertPoint.after(target))` instead")
@staticmethod
def inline_region_after(region: Region, target: Block) -> None:
"""Move the region blocks to an existing region, after `target`."""
Rewriter.inline_region(region, BlockInsertPoint.after(target))

@deprecated(
"Use `inline_region(region, BlockInsertPoint.at_start(target))` instead"
)
@staticmethod
def inline_region_at_start(region: Region, target: Region) -> None:
"""Move the region blocks to the start of an existing region."""
Rewriter.inline_region(region, BlockInsertPoint.at_start(target))

@deprecated("Use `inline_region(region, BlockInsertPoint.at_end(target))` instead")
@staticmethod
def inline_region_at_end(region: Region, target: Region) -> None:
"""Move the region blocks to the end of an existing region."""
Expand Down
10 changes: 5 additions & 5 deletions xdsl/transforms/convert_scf_to_cf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint
from xdsl.rewriter import BlockInsertPoint, InsertPoint
from xdsl.traits import IsTerminator


Expand Down Expand Up @@ -60,7 +60,7 @@ def match_and_rewrite(self, if_op: IfOp, rewriter: PatternRewriter, /):
)

rewriter.erase_op(then_terminator)
rewriter.inline_region_before(then_region, continue_block)
rewriter.inline_region(then_region, BlockInsertPoint.before(continue_block))

# Move blocks from the "else" region (if present) to the region containing
# 'scf.if', place it before the continuation block and branch to it. It
Expand All @@ -78,7 +78,7 @@ def match_and_rewrite(self, if_op: IfOp, rewriter: PatternRewriter, /):
)

rewriter.erase_op(else_terminator)
rewriter.inline_region_before(else_region, continue_block)
rewriter.inline_region(else_region, BlockInsertPoint.before(continue_block))
else:
else_block = continue_block

Expand Down Expand Up @@ -116,7 +116,7 @@ def match_and_rewrite(self, for_op: ForOp, rewriter: PatternRewriter):
first_body_block = condition_block.split_before(first_op)
last_body_block = for_op.body.last_block
assert last_body_block is not None
rewriter.inline_region_before(for_op.body, end_block)
rewriter.inline_region(for_op.body, BlockInsertPoint.before(end_block))
iv = condition_block.args[0]

# Append the induction variable stepping logic to the last body block and
Expand Down Expand Up @@ -169,7 +169,7 @@ def _convert_region(
rewriter.replace_op(yield_op, BranchOp(continue_block, *yield_op.operands))

# Inline the region
rewriter.inline_region_before(region, continue_block)
rewriter.inline_region(region, BlockInsertPoint.before(continue_block))
return block

@op_type_rewrite_pattern
Expand Down

0 comments on commit 32d338c

Please sign in to comment.