From b73b7efb5923f287bfe65a6fa07cc052deb36698 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Tue, 15 Oct 2024 22:22:53 +0100 Subject: [PATCH] transformations: (cf) switch canonicalization (#3291) Adds all the switch canonicalization patterns present in mlir. --------- Co-authored-by: Sasha Lopoukhine --- tests/filecheck/dialects/cf/canonicalize.mlir | 283 +++++++++++++++++ xdsl/dialects/cf.py | 22 +- .../canonicalization_patterns/cf.py | 290 +++++++++++++++++- 3 files changed, 591 insertions(+), 4 deletions(-) diff --git a/tests/filecheck/dialects/cf/canonicalize.mlir b/tests/filecheck/dialects/cf/canonicalize.mlir index 82ed7e14fc..2e8e4d492b 100644 --- a/tests/filecheck/dialects/cf/canonicalize.mlir +++ b/tests/filecheck/dialects/cf/canonicalize.mlir @@ -198,3 +198,286 @@ func.func @branchCondProp(%arg0: i1) { ^exit: return } + +/// Test the folding of SwitchOp +// CHECK: func.func @switch_only_default(%flag : i32, %caseOperand0 : f32) { +// CHECK-NEXT: "test.termop"() [^0, ^1] : () -> () +// CHECK-NEXT: ^0: +// CHECK-NEXT: cf.br ^1(%caseOperand0 : f32) +// CHECK-NEXT: ^1(%arg : f32): +// CHECK-NEXT: "test.termop"(%arg) : (f32) -> () +// CHECK-NEXT: } +func.func @switch_only_default(%flag : i32, %caseOperand0 : f32) { + // add predecessors for all blocks to avoid other canonicalizations. + "test.termop"() [^0, ^1] : () -> () + ^0: + cf.switch %flag : i32, [ + default: ^1(%caseOperand0 : f32) + ] + ^1(%arg : f32): + "test.termop"(%arg) : (f32) -> () +} + + +// CHECK: func.func @switch_case_matching_default(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) { +// CHECK-NEXT: "test.termop"() [^0, ^1, ^2] : () -> () +// CHECK-NEXT: ^0: +// CHECK-NEXT: cf.switch %flag : i32, [ +// CHECK-NEXT: default: ^1(%caseOperand0 : f32), +// CHECK-NEXT: 10: ^2(%caseOperand1 : f32) +// CHECK-NEXT: ] +// CHECK-NEXT: ^1(%arg : f32): +// CHECK-NEXT: "test.termop"(%arg) : (f32) -> () +// CHECK-NEXT: ^2(%arg2 : f32): +// CHECK-NEXT: "test.termop"(%arg2) : (f32) -> () +// CHECK-NEXT: } +func.func @switch_case_matching_default(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) { + // add predecessors for all blocks to avoid other canonicalizations. + "test.termop"() [^0, ^1, ^2] : () -> () + ^0: + cf.switch %flag : i32, [ + default: ^1(%caseOperand0 : f32), + 42: ^1(%caseOperand0 : f32), + 10: ^2(%caseOperand1 : f32), + 17: ^1(%caseOperand0 : f32) + ] + ^1(%arg : f32): + "test.termop"(%arg) : (f32) -> () + ^2(%arg2 : f32): + "test.termop"(%arg2) : (f32) -> () +} + + +// CHECK: func.func @switch_on_const_no_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) { +// CHECK-NEXT: "test.termop"() [^0, ^1, ^2, ^3] : () -> () +// CHECK-NEXT: ^0: +// CHECK-NEXT: cf.br ^1(%caseOperand0 : f32) +// CHECK-NEXT: ^1(%arg : f32): +// CHECK-NEXT: "test.termop"(%arg) : (f32) -> () +// CHECK-NEXT: ^2(%arg2 : f32): +// CHECK-NEXT: "test.termop"(%arg2) : (f32) -> () +// CHECK-NEXT: ^3(%arg3 : f32): +// CHECK-NEXT: "test.termop"(%arg3) : (f32) -> () +// CHECK-NEXT: } +func.func @switch_on_const_no_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) { + // add predecessors for all blocks to avoid other canonicalizations. + "test.termop"() [^0, ^1, ^2, ^3] : () -> () + ^0: + %c0_i32 = arith.constant 0 : i32 + cf.switch %c0_i32 : i32, [ + default: ^1(%caseOperand0 : f32), + -1: ^2(%caseOperand1 : f32), + 1: ^3(%caseOperand2 : f32) + ] + ^1(%arg : f32): + "test.termop"(%arg) : (f32) -> () + ^2(%arg2 : f32): + "test.termop"(%arg2) : (f32) -> () + ^3(%arg3 : f32): + "test.termop"(%arg3) : (f32) -> () +} + +// CHECK: func.func @switch_on_const_with_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) { +// CHECK-NEXT: "test.termop"() [^0, ^1, ^2, ^3] : () -> () +// CHECK-NEXT: ^0: +// CHECK-NEXT: cf.br ^3(%caseOperand2 : f32) +// CHECK-NEXT: ^1(%arg : f32): +// CHECK-NEXT: "test.termop"(%arg) : (f32) -> () +// CHECK-NEXT: ^2(%arg2 : f32): +// CHECK-NEXT: "test.termop"(%arg2) : (f32) -> () +// CHECK-NEXT: ^3(%arg3 : f32): +// CHECK-NEXT: "test.termop"(%arg3) : (f32) -> () +// CHECK-NEXT: } +func.func @switch_on_const_with_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) { + // add predecessors for all blocks to avoid other canonicalizations. + "test.termop"() [^0, ^1, ^2, ^3] : () -> () + ^0: + %c0_i32 = arith.constant 1 : i32 + cf.switch %c0_i32 : i32, [ + default: ^1(%caseOperand0 : f32), + -1: ^2(%caseOperand1 : f32), + 1: ^3(%caseOperand2 : f32) + ] + ^1(%arg : f32): + "test.termop"(%arg) : (f32) -> () + ^2(%arg2 : f32): + "test.termop"(%arg2) : (f32) -> () + ^3(%arg3 : f32): + "test.termop"(%arg3) : (f32) -> () +} + +// CHECK: func.func @switch_passthrough(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32, %caseOperand3 : f32) { +// CHECK-NEXT: "test.termop"() [^0, ^1, ^2, ^3, ^4, ^5] : () -> () +// CHECK-NEXT: ^0: +// CHECK-NEXT: cf.switch %flag : i32, [ +// CHECK-NEXT: default: ^4(%caseOperand0 : f32), +// CHECK-NEXT: 43: ^5(%caseOperand1 : f32), +// CHECK-NEXT: 44: ^3(%caseOperand2 : f32) +// CHECK-NEXT: ] +// CHECK-NEXT: ^1(%arg : f32): +// CHECK-NEXT: cf.br ^4(%arg : f32) +// CHECK-NEXT: ^2(%arg2 : f32): +// CHECK-NEXT: cf.br ^5(%arg2 : f32) +// CHECK-NEXT: ^3(%arg3 : f32): +// CHECK-NEXT: "test.termop"(%arg3) : (f32) -> () +// CHECK-NEXT: ^4(%arg4 : f32): +// CHECK-NEXT: "test.termop"(%arg4) : (f32) -> () +// CHECK-NEXT: ^5(%arg5 : f32): +// CHECK-NEXT: "test.termop"(%arg5) : (f32) -> () +// CHECK-NEXT: } +func.func @switch_passthrough(%flag : i32, + %caseOperand0 : f32, + %caseOperand1 : f32, + %caseOperand2 : f32, + %caseOperand3 : f32) { + // add predecessors for all blocks to avoid other canonicalizations. + "test.termop"() [^0, ^1, ^2, ^3, ^4, ^5] : () -> () + ^0: + cf.switch %flag : i32, [ + default: ^1(%caseOperand0 : f32), + 43: ^2(%caseOperand1 : f32), + 44: ^3(%caseOperand2 : f32) + ] + ^1(%arg : f32): + cf.br ^4(%arg : f32) + ^2(%arg2 : f32): + cf.br ^5(%arg2 : f32) + ^3(%arg3 : f32): + "test.termop"(%arg3) : (f32) -> () + ^4(%arg4 : f32): + "test.termop"(%arg4) : (f32) -> () + ^5(%arg5 : f32): + "test.termop"(%arg5) : (f32) -> () +} + +// CHECK: func.func @switch_from_switch_with_same_value_with_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) { +// CHECK-NEXT: "test.termop"() [^0, ^1, ^2, ^3] : () -> () +// CHECK-NEXT: ^0: +// CHECK-NEXT: cf.switch %flag : i32, [ +// CHECK-NEXT: default: ^1, +// CHECK-NEXT: 42: ^4 +// CHECK-NEXT: ] +// CHECK-NEXT: ^1: +// CHECK-NEXT: "test.termop"() : () -> () +// CHECK-NEXT: ^4: +// CHECK-NEXT: "test.op"() : () -> () +// CHECK-NEXT: cf.br ^3(%caseOperand1 : f32) +// CHECK-NEXT: ^2(%arg3 : f32): +// CHECK-NEXT: "test.termop"(%arg3) : (f32) -> () +// CHECK-NEXT: ^3(%arg4 : f32): +// CHECK-NEXT: "test.termop"(%arg4) : (f32) -> () +// CHECK-NEXT: } +func.func @switch_from_switch_with_same_value_with_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) { + // add predecessors for all blocks except ^2 to avoid other canonicalizations. + "test.termop"() [^0, ^1, ^3, ^4] : () -> () + ^0: + cf.switch %flag : i32, [ + default: ^1, + 42: ^2 + ] + + ^1: + "test.termop"() : () -> () + ^2: + // prevent this block from being simplified away + "test.op"() : () -> () + cf.switch %flag : i32, [ + default: ^3(%caseOperand0 : f32), + 42: ^4(%caseOperand1 : f32) + ] + ^3(%arg3 : f32): + "test.termop"(%arg3) : (f32) -> () + ^4(%arg4 : f32): + "test.termop"(%arg4) : (f32) -> () +} + +// CHECK: func.func @switch_from_switch_with_same_value_no_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) { +// CHECK-NEXT: "test.termop"() [^0, ^1, ^2, ^3, ^4] : () -> () +// CHECK-NEXT: ^0: +// CHECK-NEXT: cf.switch %flag : i32, [ +// CHECK-NEXT: default: ^1, +// CHECK-NEXT: 42: ^5 +// CHECK-NEXT: ] +// CHECK-NEXT: ^1: +// CHECK-NEXT: "test.termop"() : () -> () +// CHECK-NEXT: ^5: +// CHECK-NEXT: "test.op"() : () -> () +// CHECK-NEXT: cf.br ^2(%caseOperand0 : f32) +// CHECK-NEXT: ^2(%arg3 : f32): +// CHECK-NEXT: "test.termop"(%arg3) : (f32) -> () +// CHECK-NEXT: ^3(%arg4 : f32): +// CHECK-NEXT: "test.termop"(%arg4) : (f32) -> () +// CHECK-NEXT: ^4(%arg5 : f32): +// CHECK-NEXT: "test.termop"(%arg5) : (f32) -> () +// CHECK-NEXT: } +func.func @switch_from_switch_with_same_value_no_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) { + // add predecessors for all blocks except ^2 to avoid other canonicalizations. + "test.termop"() [^0, ^1, ^3, ^4, ^5] : () -> () + ^0: + cf.switch %flag : i32, [ + default: ^1, + 42: ^2 + ] + ^1: + "test.termop"() : () -> () + ^2: + "test.op"() : () -> () + cf.switch %flag : i32, [ + default: ^3(%caseOperand0 : f32), + 0: ^4(%caseOperand1 : f32), + 43: ^5(%caseOperand2 : f32) + ] + ^3(%arg3 : f32): + "test.termop"(%arg3) : (f32) -> () + ^4(%arg4 : f32): + "test.termop"(%arg4) : (f32) -> () + ^5(%arg5 : f32): + "test.termop"(%arg5) : (f32) -> () +} + +// CHECK: func.func @switch_from_switch_default_with_same_value(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) { +// CHECK-NEXT: "test.termop"() [^0, ^1, ^2, ^3, ^4] : () -> () +// CHECK-NEXT: ^0: +// CHECK-NEXT: cf.switch %flag : i32, [ +// CHECK-NEXT: default: ^5, +// CHECK-NEXT: 42: ^1 +// CHECK-NEXT: ] +// CHECK-NEXT: ^1: +// CHECK-NEXT: "test.termop"() : () -> () +// CHECK-NEXT: ^5: +// CHECK-NEXT: "test.op"() : () -> () +// CHECK-NEXT: cf.switch %flag : i32, [ +// CHECK-NEXT: default: ^2(%caseOperand0 : f32), +// CHECK-NEXT: 43: ^4(%caseOperand2 : f32) +// CHECK-NEXT: ] +// CHECK-NEXT: ^2(%arg3 : f32): +// CHECK-NEXT: "test.termop"(%arg3) : (f32) -> () +// CHECK-NEXT: ^3(%arg4 : f32): +// CHECK-NEXT: "test.termop"(%arg4) : (f32) -> () +// CHECK-NEXT: ^4(%arg5 : f32): +// CHECK-NEXT: "test.termop"(%arg5) : (f32) -> () +// CHECK-NEXT: } +func.func @switch_from_switch_default_with_same_value(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) { + // add predecessors for all blocks except ^2 to avoid other canonicalizations. + "test.termop"() [^0, ^1, ^3, ^4, ^5] : () -> () + ^0: + cf.switch %flag : i32, [ + default: ^2, + 42: ^1 + ] + ^1: + "test.termop"() : () -> () + ^2: + "test.op"() : () -> () + cf.switch %flag : i32, [ + default: ^3(%caseOperand0 : f32), + 42: ^4(%caseOperand1 : f32), + 43: ^5(%caseOperand2 : f32) + ] + ^3(%arg3 : f32): + "test.termop"(%arg3) : (f32) -> () + ^4(%arg4 : f32): + "test.termop"(%arg4) : (f32) -> () + ^5(%arg5 : f32): + "test.termop"(%arg5) : (f32) -> () +} diff --git a/xdsl/dialects/cf.py b/xdsl/dialects/cf.py index 3f69138639..41199104cf 100644 --- a/xdsl/dialects/cf.py +++ b/xdsl/dialects/cf.py @@ -151,6 +151,26 @@ def __init__( """ +class SwitchHasCanonicalizationPatterns(HasCanonicalizationPatternsTrait): + @classmethod + def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: + from xdsl.transforms.canonicalization_patterns.cf import ( + DropSwitchCasesThatMatchDefault, + SimplifyConstSwitchValue, + SimplifyPassThroughSwitch, + SimplifySwitchFromSwitchOnSameCondition, + SimplifySwitchWithOnlyDefault, + ) + + return ( + SimplifySwitchWithOnlyDefault(), + SimplifyConstSwitchValue(), + SimplifyPassThroughSwitch(), + DropSwitchCasesThatMatchDefault(), + SimplifySwitchFromSwitchOnSameCondition(), + ) + + @irdl_op_definition class Switch(IRDLOperation): """Switch operation""" @@ -174,7 +194,7 @@ class Switch(IRDLOperation): irdl_options = [AttrSizedOperandSegments(as_property=True)] - traits = frozenset([IsTerminator(), Pure()]) + traits = frozenset([IsTerminator(), Pure(), SwitchHasCanonicalizationPatterns()]) def __init__( self, diff --git a/xdsl/transforms/canonicalization_patterns/cf.py b/xdsl/transforms/canonicalization_patterns/cf.py index 9cee170662..20910f357d 100644 --- a/xdsl/transforms/canonicalization_patterns/cf.py +++ b/xdsl/transforms/canonicalization_patterns/cf.py @@ -1,8 +1,14 @@ -from collections.abc import Sequence +from collections.abc import Callable, Sequence +from typing import cast from xdsl.dialects import arith, cf -from xdsl.dialects.builtin import BoolAttr, IntegerAttr -from xdsl.ir import Block, BlockArgument, SSAValue +from xdsl.dialects.builtin import ( + AnyIntegerAttr, + BoolAttr, + DenseIntOrFPElementsAttr, + IntegerAttr, +) +from xdsl.ir import Block, BlockArgument, Operation, SSAValue from xdsl.pattern_rewriter import ( PatternRewriter, RewritePattern, @@ -257,3 +263,281 @@ def match_and_rewrite(self, op: cf.ConditionalBranch, rewriter: PatternRewriter) const_false.result, lambda use: use.operation.parent_block() is op.else_block, ) + + +class SimplifySwitchWithOnlyDefault(RewritePattern): + """ + switch %flag : i32, [ + default: ^bb1 + ] + -> br ^bb1 + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: cf.Switch, rewriter: PatternRewriter): + if not op.case_blocks: + rewriter.replace_matched_op( + cf.Branch(op.default_block, *op.default_operands) + ) + + +def drop_case_helper( + rewriter: PatternRewriter, + op: cf.Switch, + predicate: Callable[[AnyIntegerAttr, Block, Sequence[Operation | SSAValue]], bool], +): + case_values = op.case_values + if case_values is None: + return + requires_change = False + + new_case_values: list[int] = [] + new_case_blocks: list[Block] = [] + new_case_operands: list[Sequence[Operation | SSAValue]] = [] + + for switch_case, block, operands in zip( + case_values.data.data, + op.case_blocks, + op.case_operand, + strict=True, + ): + int_switch_case = cast(AnyIntegerAttr, switch_case) + if predicate(int_switch_case, block, operands): + requires_change = True + continue + new_case_values.append(cast(AnyIntegerAttr, switch_case).value.data) + new_case_blocks.append(block) + new_case_operands.append(operands) + + if requires_change: + rewriter.replace_matched_op( + cf.Switch( + op.flag, + op.default_block, + op.default_operands, + DenseIntOrFPElementsAttr.vector_from_list( + new_case_values, case_values.get_element_type() + ), + new_case_blocks, + new_case_operands, + ) + ) + + +class DropSwitchCasesThatMatchDefault(RewritePattern): + """ + switch %flag : i32, [ + default: ^bb1, + 42: ^bb1, + 43: ^bb2 + ] + -> + switch %flag : i32, [ + default: ^bb1, + 43: ^bb2 + ] + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: cf.Switch, rewriter: PatternRewriter): + def predicate( + switch_case: AnyIntegerAttr, + block: Block, + operands: Sequence[Operation | SSAValue], + ) -> bool: + return block == op.default_block and operands == op.default_operands + + drop_case_helper(rewriter, op, predicate) + + +def fold_switch(switch: cf.Switch, rewriter: PatternRewriter, flag: int): + """ + Helper for folding a switch with a constant value. + switch %c_42 : i32, [ + default: ^bb1 , + 42: ^bb2, + 43: ^bb3 + ] + -> br ^bb2 + """ + case_values = () if switch.case_values is None else switch.case_values.data.data + + new_block, new_operands = next( + ( + (block, operand) + for (c, block, operand) in zip( + case_values, switch.case_blocks, switch.case_operand, strict=True + ) + if flag == c.value.data + ), + (switch.default_block, switch.default_operands), + ) + + rewriter.replace_matched_op(cf.Branch(new_block, *new_operands)) + + +class SimplifyConstSwitchValue(RewritePattern): + """ + switch %c_42 : i32, [ + default: ^bb1, + 42: ^bb2, + 43: ^bb3 + ] + -> br ^bb2 + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: cf.Switch, rewriter: PatternRewriter): + if (flag := const_evaluate_operand(op.flag)) is not None: + fold_switch(op, rewriter, flag) + + +class SimplifyPassThroughSwitch(RewritePattern): + """ + switch %c_42 : i32, [ + default: ^bb1, + 42: ^bb2, + ] + ^bb2: + br ^bb3 + -> + switch %c_42 : i32, [ + default: ^bb1, + 42: ^bb3, + ] + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: cf.Switch, rewriter: PatternRewriter): + requires_change = False + + new_case_blocks: list[Block] = [] + new_case_operands: list[Sequence[Operation | SSAValue]] = [] + + for block, operands in zip(op.case_blocks, op.case_operand, strict=True): + collapsed = collapse_branch(block, operands) + requires_change |= collapsed is not None + (new_block, new_operands) = collapsed or (block, operands) + new_case_blocks.append(new_block) + new_case_operands.append(new_operands) + + collapsed = collapse_branch(op.default_block, op.default_operands) + + requires_change |= collapsed is not None + + (default_block, default_operands) = collapsed or ( + op.default_block, + op.default_operands, + ) + + if requires_change: + rewriter.replace_matched_op( + cf.Switch( + op.flag, + default_block, + default_operands, + op.case_values, + new_case_blocks, + new_case_operands, + ) + ) + + +class SimplifySwitchFromSwitchOnSameCondition(RewritePattern): + """ + switch %flag : i32, [ + default: ^bb1, + 42: ^bb2, + ] + ^bb2: + switch %flag : i32, [ + default: ^bb3, + 42: ^bb4 + ] + -> + switch %flag : i32, [ + default: ^bb1, + 42: ^bb2, + ] + ^bb2: + br ^bb4 + + and + + switch %flag : i32, [ + default: ^bb1, + 42: ^bb2, + ] + ^bb2: + switch %flag : i32, [ + default: ^bb3, + 43: ^bb4 + ] + -> + switch %flag : i32, [ + default: ^bb1, + 42: ^bb2, + ] + ^bb2: + br ^bb3 + + and + + switch %flag : i32, [ + default: ^bb1, + 42: ^bb2 + ] + ^bb1: + switch %flag : i32, [ + default: ^bb3, + 42: ^bb4, + 43: ^bb5 + ] + -> + switch %flag : i32, [ + default: ^bb1, + 42: ^bb2, + ] + ^bb1: + switch %flag : i32, [ + default: ^bb3, + 43: ^bb5 + ] + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: cf.Switch, rewriter: PatternRewriter): + block = op.parent_block() + if block is None: + return + preds = block.uses + if len(preds) != 1: + return + pred = next(iter(preds)) + switch = pred.operation + if not isinstance(switch, cf.Switch): + return + + if switch.flag != op.flag: + return + + case_values = switch.case_values + if case_values is None: + return + + if pred.index != 0: + fold_switch( + op, + rewriter, + cast(int, case_values.data.data[pred.index - 1].value.data), + ) + else: + + def predicate( + switch_case: AnyIntegerAttr, + block: Block, + operands: Sequence[Operation | SSAValue], + ) -> bool: + return switch_case in case_values.data.data + + drop_case_helper(rewriter, op, predicate)