From 32d338c7b10e54cbfb3b1c3cf3981fc3fdcb92e9 Mon Sep 17 00:00:00 2001 From: Fehr Mathieu Date: Mon, 20 Jan 2025 06:17:59 +0100 Subject: [PATCH] core: Deprecate methods that should be using BlockInsertPoint (#3705) --- .../pattern_rewriter/test_pattern_rewriter.py | 14 ++++++---- tests/test_op_builder.py | 16 ++++++------ tests/test_rewriter.py | 26 +++++++++++-------- .../lowering/convert_riscv_scf_to_riscv_cf.py | 4 +-- xdsl/builder.py | 14 +++++++++- xdsl/pattern_rewriter.py | 12 +++++++++ xdsl/rewriter.py | 10 +++++++ xdsl/transforms/convert_scf_to_cf.py | 10 +++---- 8 files changed, 74 insertions(+), 32 deletions(-) diff --git a/tests/pattern_rewriter/test_pattern_rewriter.py b/tests/pattern_rewriter/test_pattern_rewriter.py index 0cb0495585..0cca22bd38 100644 --- a/tests/pattern_rewriter/test_pattern_rewriter.py +++ b/tests/pattern_rewriter/test_pattern_rewriter.py @@ -33,7 +33,7 @@ attr_type_rewrite_pattern, op_type_rewrite_pattern, ) -from xdsl.rewriter import InsertPoint +from xdsl.rewriter import BlockInsertPoint, InsertPoint def rewrite_and_compare( @@ -1225,7 +1225,7 @@ def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter): if op.parent is None: return - rewriter.inline_region_before(op.regions[0], op.parent) + rewriter.inline_region(op.regions[0], BlockInsertPoint.before(op.parent)) rewriter.erase_matched_op() rewrite_and_compare( @@ -1272,7 +1272,7 @@ def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter): if op.parent is None: return - rewriter.inline_region_after(op.regions[0], op.parent) + rewriter.inline_region(op.regions[0], BlockInsertPoint.after(op.parent)) rewriter.erase_matched_op() rewrite_and_compare( @@ -1320,7 +1320,9 @@ def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter): if parent_region is None: return - rewriter.inline_region_at_start(op.regions[0], parent_region) + rewriter.inline_region( + op.regions[0], BlockInsertPoint.at_start(parent_region) + ) rewriter.erase_matched_op() rewrite_and_compare( @@ -1368,7 +1370,9 @@ def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter): if parent_region is None: return - rewriter.inline_region_at_end(op.regions[0], parent_region) + rewriter.inline_region( + op.regions[0], BlockInsertPoint.at_end(parent_region) + ) rewriter.erase_matched_op() rewrite_and_compare( diff --git a/tests/test_op_builder.py b/tests/test_op_builder.py index a1ece0fecc..69caea5eb8 100644 --- a/tests/test_op_builder.py +++ b/tests/test_op_builder.py @@ -116,21 +116,21 @@ def test_builder_create_block(): target = Region([block1, block2]) builder = Builder(InsertPoint.at_start(block1)) - new_block1 = builder.create_block_at_start(target, (i32,)) + new_block1 = builder.create_block(BlockInsertPoint.at_start(target), (i32,)) assert len(new_block1.args) == 1 assert new_block1.args[0].type == i32 assert len(target.blocks) == 3 assert target.blocks[0] == new_block1 assert builder.insertion_point == InsertPoint.at_start(new_block1) - new_block2 = builder.create_block_at_end(target, (i64,)) + new_block2 = builder.create_block(BlockInsertPoint.at_end(target), (i64,)) assert len(new_block2.args) == 1 assert new_block2.args[0].type == i64 assert len(target.blocks) == 4 assert target.blocks[3] == new_block2 assert builder.insertion_point == InsertPoint.at_start(new_block2) - new_block3 = builder.create_block_before(block2, (i32, i64)) + new_block3 = builder.create_block(BlockInsertPoint.before(block2), (i32, i64)) assert len(new_block3.args) == 2 assert new_block3.args[0].type == i32 assert new_block3.args[1].type == i64 @@ -138,7 +138,7 @@ def test_builder_create_block(): assert target.blocks[2] == new_block3 assert builder.insertion_point == InsertPoint.at_start(new_block3) - new_block4 = builder.create_block_after(block2, (i64, i32)) + new_block4 = builder.create_block(BlockInsertPoint.after(block2), (i64, i32)) assert len(new_block4.args) == 2 assert new_block4.args[0].type == i64 assert new_block4.args[1].type == i32 @@ -182,10 +182,10 @@ def add_block_on_create(b: Block): b.block_creation_handler = [add_block_on_create] - b1 = b.create_block_at_start(region) - b2 = b.create_block_at_end(region) - b3 = b.create_block_before(block) - b4 = b.create_block_after(block) + b1 = b.create_block(BlockInsertPoint.at_start(region)) + b2 = b.create_block(BlockInsertPoint.at_end(region)) + b3 = b.create_block(BlockInsertPoint.before(block)) + b4 = b.create_block(BlockInsertPoint.after(block)) assert created_blocks == [b1, b2, b3, b4] diff --git a/tests/test_rewriter.py b/tests/test_rewriter.py index bdb668fd62..b4f6af7c4e 100644 --- a/tests/test_rewriter.py +++ b/tests/test_rewriter.py @@ -9,7 +9,7 @@ from xdsl.dialects.builtin import Builtin, Float32Type, Float64Type, ModuleOp, i32, i64 from xdsl.ir import Block, Region from xdsl.parser import Parser -from xdsl.rewriter import InsertPoint, Rewriter +from xdsl.rewriter import BlockInsertPoint, InsertPoint, Rewriter def rewrite_and_compare( @@ -289,7 +289,9 @@ def test_insert_block_before(): """ def insert_empty_block_before(module: ModuleOp, rewriter: Rewriter) -> None: - rewriter.insert_block_before(Block(), module.regions[0].blocks[0]) + rewriter.insert_block( + Block(), BlockInsertPoint.before(module.regions[0].blocks[0]) + ) rewrite_and_compare(prog, expected, insert_empty_block_before) @@ -312,7 +314,9 @@ def test_insert_block_after(): """ def insert_empty_block_after(module: ModuleOp, rewriter: Rewriter) -> None: - rewriter.insert_block_after(Block(), module.regions[0].blocks[0]) + rewriter.insert_block( + Block(), BlockInsertPoint.after(module.regions[0].blocks[0]) + ) rewrite_and_compare(prog, expected, insert_empty_block_after) @@ -510,7 +514,7 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None: Block((test.TestOp(result_types=(Float64Type(),)),)), ) ) - rewriter.inline_region_before(region, module.body.blocks[1]) + rewriter.inline_region(region, BlockInsertPoint.before(module.body.blocks[1])) rewrite_and_compare(prog, expected, transformation) @@ -544,7 +548,7 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None: Block((test.TestOp(result_types=(Float64Type(),)),)), ) ) - rewriter.inline_region_after(region, module.body.blocks[0]) + rewriter.inline_region(region, BlockInsertPoint.after(module.body.blocks[0])) rewrite_and_compare(prog, expected, transformation) @@ -578,7 +582,7 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None: Block((test.TestOp(result_types=(Float64Type(),)),)), ) ) - rewriter.inline_region_at_start(region, module.body) + rewriter.inline_region(region, BlockInsertPoint.at_start(module.body)) rewrite_and_compare(prog, expected, transformation) @@ -612,7 +616,7 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None: Block((test.TestOp(result_types=(Float64Type(),)),)), ) ) - rewriter.inline_region_at_end(region, module.body) + rewriter.inline_region(region, BlockInsertPoint.at_end(module.body)) rewrite_and_compare(prog, expected, transformation) @@ -621,13 +625,13 @@ def test_verify_inline_region(): region = Region(Block()) with pytest.raises(ValueError, match="Cannot move region into itself."): - Rewriter.inline_region_before(region, region.block) + Rewriter.inline_region(region, BlockInsertPoint.before(region.block)) with pytest.raises(ValueError, match="Cannot move region into itself."): - Rewriter.inline_region_after(region, region.block) + Rewriter.inline_region(region, BlockInsertPoint.after(region.block)) with pytest.raises(ValueError, match="Cannot move region into itself."): - Rewriter.inline_region_at_start(region, region) + Rewriter.inline_region(region, BlockInsertPoint.at_start(region)) with pytest.raises(ValueError, match="Cannot move region into itself."): - Rewriter.inline_region_at_end(region, region) + Rewriter.inline_region(region, BlockInsertPoint.at_end(region)) diff --git a/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py b/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py index 24abc9f874..c288bffbeb 100644 --- a/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py +++ b/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py @@ -7,7 +7,7 @@ RewritePattern, op_type_rewrite_pattern, ) -from xdsl.rewriter import InsertPoint +from xdsl.rewriter import BlockInsertPoint, InsertPoint class LowerRiscvScfForPattern(RewritePattern): @@ -119,7 +119,7 @@ def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter, /): ), ) - rewriter.inline_region_before(op.body, end_block) + rewriter.inline_region(op.body, BlockInsertPoint.before(end_block)) # Move lb to new register to initialize the iv. # Skip for loop if condition is not satisfied at start. diff --git a/xdsl/builder.py b/xdsl/builder.py index c9b9e7374d..75f5286c51 100644 --- a/xdsl/builder.py +++ b/xdsl/builder.py @@ -7,6 +7,8 @@ from types import TracebackType from typing import ClassVar, TypeAlias, overload +from typing_extensions import deprecated + from xdsl.dialects.builtin import ArrayAttr from xdsl.ir import Attribute, Block, BlockArgument, Operation, OperationInvT, Region from xdsl.rewriter import BlockInsertPoint, InsertPoint, Rewriter @@ -75,7 +77,7 @@ def insert(self, op: OperationInvT) -> OperationInvT: return op def create_block( - self, insert_point: BlockInsertPoint, arg_types: Iterable[Attribute] + self, insert_point: BlockInsertPoint, arg_types: Iterable[Attribute] = () ) -> Block: """ Create a block at the given location, and set the operation insertion point @@ -89,6 +91,9 @@ def create_block( self.handle_block_creation(block) return block + @deprecated( + "Use create_block(BlockInsertPoint.before(insert_before), arg_types) instead" + ) def create_block_before( self, insert_before: Block, arg_types: Iterable[Attribute] = () ) -> Block: @@ -98,6 +103,9 @@ def create_block_before( """ return self.create_block(BlockInsertPoint.before(insert_before), arg_types) + @deprecated( + "Use create_block(BlockInsertPoint.after(insert_after), arg_types) instead" + ) def create_block_after( self, insert_after: Block, arg_types: Iterable[Attribute] = () ) -> Block: @@ -107,6 +115,9 @@ def create_block_after( """ return self.create_block(BlockInsertPoint.after(insert_after), arg_types) + @deprecated( + "Use create_block(BlockInsertPoint.at_start(region), arg_types) instead" + ) def create_block_at_start( self, region: Region, arg_types: Iterable[Attribute] = () ) -> Block: @@ -116,6 +127,7 @@ def create_block_at_start( """ return self.create_block(BlockInsertPoint.at_start(region), arg_types) + @deprecated("Use create_block(BlockInsertPoint.at_end(region), arg_types) instead") def create_block_at_end( self, region: Region, arg_types: Iterable[Attribute] = () ) -> Block: diff --git a/xdsl/pattern_rewriter.py b/xdsl/pattern_rewriter.py index 340ad51299..a0b14ba17b 100644 --- a/xdsl/pattern_rewriter.py +++ b/xdsl/pattern_rewriter.py @@ -356,18 +356,30 @@ def inline_region(self, region: Region, insertion_point: BlockInsertPoint) -> No self.has_done_action = True Rewriter.inline_region(region, insertion_point) + @deprecated( + "Please use `inline_region(region, BlockInsertPoint.before(target))` instead" + ) def inline_region_before(self, region: Region, target: Block) -> None: """Move the region blocks to an existing region.""" self.inline_region(region, BlockInsertPoint.before(target)) + @deprecated( + "Please use `inline_region(region, BlockInsertPoint.after(target))` instead" + ) def inline_region_after(self, region: Region, target: Block) -> None: """Move the region blocks to an existing region.""" self.inline_region(region, BlockInsertPoint.after(target)) + @deprecated( + "Please use `inline_region(region, BlockInsertPoint.at_start(target))` instead" + ) def inline_region_at_start(self, region: Region, target: Region) -> None: """Move the region blocks to an existing region.""" self.inline_region(region, BlockInsertPoint.at_start(target)) + @deprecated( + "Please use `inline_region(region, BlockInsertPoint.at_end(target))` instead" + ) def inline_region_at_end(self, region: Region, target: Region) -> None: """Move the region blocks to an existing region.""" self.inline_region(region, BlockInsertPoint.at_end(target)) diff --git a/xdsl/rewriter.py b/xdsl/rewriter.py index a45cc5161a..e280aef608 100644 --- a/xdsl/rewriter.py +++ b/xdsl/rewriter.py @@ -3,6 +3,8 @@ from collections.abc import Iterable, Sequence from dataclasses import dataclass, field +from typing_extensions import deprecated + from xdsl.ir import Block, Operation, Region, SSAValue @@ -240,6 +242,7 @@ def insert_block(block: Block | Iterable[Block], insert_point: BlockInsertPoint) else: region.add_block(block) + @deprecated("Use `insert_block(block, BlockInsertPoint.after(target))` instead") @staticmethod def insert_block_after(block: Block | list[Block], target: Block): """ @@ -249,6 +252,7 @@ def insert_block_after(block: Block | list[Block], target: Block): """ Rewriter.insert_block(block, BlockInsertPoint.after(target)) + @deprecated("Use `insert_block(block, BlockInsertPoint.before(target))` instead") @staticmethod def insert_block_before(block: Block | list[Block], target: Block): """ @@ -284,21 +288,27 @@ def inline_region(region: Region, insertion_point: BlockInsertPoint) -> None: else: region.move_blocks(insertion_point.region) + @deprecated("Use `inline_region(region, BlockInsertPoint.before(target))` instead") @staticmethod def inline_region_before(region: Region, target: Block) -> None: """Move the region blocks to an existing region, before `target`.""" Rewriter.inline_region(region, BlockInsertPoint.before(target)) + @deprecated("Use `inline_region(region, BlockInsertPoint.after(target))` instead") @staticmethod def inline_region_after(region: Region, target: Block) -> None: """Move the region blocks to an existing region, after `target`.""" Rewriter.inline_region(region, BlockInsertPoint.after(target)) + @deprecated( + "Use `inline_region(region, BlockInsertPoint.at_start(target))` instead" + ) @staticmethod def inline_region_at_start(region: Region, target: Region) -> None: """Move the region blocks to the start of an existing region.""" Rewriter.inline_region(region, BlockInsertPoint.at_start(target)) + @deprecated("Use `inline_region(region, BlockInsertPoint.at_end(target))` instead") @staticmethod def inline_region_at_end(region: Region, target: Region) -> None: """Move the region blocks to the end of an existing region.""" diff --git a/xdsl/transforms/convert_scf_to_cf.py b/xdsl/transforms/convert_scf_to_cf.py index d1c07e683b..977c756fd1 100644 --- a/xdsl/transforms/convert_scf_to_cf.py +++ b/xdsl/transforms/convert_scf_to_cf.py @@ -16,7 +16,7 @@ RewritePattern, op_type_rewrite_pattern, ) -from xdsl.rewriter import InsertPoint +from xdsl.rewriter import BlockInsertPoint, InsertPoint from xdsl.traits import IsTerminator @@ -60,7 +60,7 @@ def match_and_rewrite(self, if_op: IfOp, rewriter: PatternRewriter, /): ) rewriter.erase_op(then_terminator) - rewriter.inline_region_before(then_region, continue_block) + rewriter.inline_region(then_region, BlockInsertPoint.before(continue_block)) # Move blocks from the "else" region (if present) to the region containing # 'scf.if', place it before the continuation block and branch to it. It @@ -78,7 +78,7 @@ def match_and_rewrite(self, if_op: IfOp, rewriter: PatternRewriter, /): ) rewriter.erase_op(else_terminator) - rewriter.inline_region_before(else_region, continue_block) + rewriter.inline_region(else_region, BlockInsertPoint.before(continue_block)) else: else_block = continue_block @@ -116,7 +116,7 @@ def match_and_rewrite(self, for_op: ForOp, rewriter: PatternRewriter): first_body_block = condition_block.split_before(first_op) last_body_block = for_op.body.last_block assert last_body_block is not None - rewriter.inline_region_before(for_op.body, end_block) + rewriter.inline_region(for_op.body, BlockInsertPoint.before(end_block)) iv = condition_block.args[0] # Append the induction variable stepping logic to the last body block and @@ -169,7 +169,7 @@ def _convert_region( rewriter.replace_op(yield_op, BranchOp(continue_block, *yield_op.operands)) # Inline the region - rewriter.inline_region_before(region, continue_block) + rewriter.inline_region(region, BlockInsertPoint.before(continue_block)) return block @op_type_rewrite_pattern