Skip to content

Commit

Permalink
transformations: (cf) switch canonicalization
Browse files Browse the repository at this point in the history
  • Loading branch information
alexarice committed Oct 15, 2024
1 parent f0a74f1 commit fc0e077
Show file tree
Hide file tree
Showing 3 changed files with 585 additions and 4 deletions.
283 changes: 283 additions & 0 deletions tests/filecheck/dialects/cf/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,286 @@ func.func @cond_br_same_successor_insert_select(
^bb1(%result : i32, %result2 : tensor<2xi32>):
return %result, %result2 : i32, tensor<2xi32>
}

/// Test the folding of SwitchOp
// CHECK: func.func @switch_only_default(%flag : i32, %caseOperand0 : f32) {
// CHECK-NEXT: "test.termop"() [^0, ^1] : () -> ()
// CHECK-NEXT: ^0:
// CHECK-NEXT: cf.br ^1(%caseOperand0 : f32)
// CHECK-NEXT: ^1(%bb2Arg : f32):
// CHECK-NEXT: "test.termop"(%bb2Arg) : (f32) -> ()
// CHECK-NEXT: }
func.func @switch_only_default(%flag : i32, %caseOperand0 : f32) {
// add predecessors for all blocks to avoid other canonicalizations.
"test.termop"() [^bb1, ^bb2] : () -> ()
^bb1:
cf.switch %flag : i32, [
default: ^bb2(%caseOperand0 : f32)
]
^bb2(%bb2Arg : f32):
"test.termop"(%bb2Arg) : (f32) -> ()
}


// CHECK: func.func @switch_case_matching_default(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) {
// CHECK-NEXT: "test.termop"() [^0, ^1, ^2] : () -> ()
// CHECK-NEXT: ^0:
// CHECK-NEXT: cf.switch %flag : i32, [
// CHECK-NEXT: default: ^1(%caseOperand0 : f32),
// CHECK-NEXT: 10: ^2(%caseOperand1 : f32)
// CHECK-NEXT: ]
// CHECK-NEXT: ^1(%bb2Arg : f32):
// CHECK-NEXT: "test.termop"(%bb2Arg) : (f32) -> ()
// CHECK-NEXT: ^2(%bb3Arg : f32):
// CHECK-NEXT: "test.termop"(%bb3Arg) : (f32) -> ()
// CHECK-NEXT: }
func.func @switch_case_matching_default(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) {
// add predecessors for all blocks to avoid other canonicalizations.
"test.termop"() [^bb1, ^bb2, ^bb3] : () -> ()
^bb1:
cf.switch %flag : i32, [
default: ^bb2(%caseOperand0 : f32),
42: ^bb2(%caseOperand0 : f32),
10: ^bb3(%caseOperand1 : f32),
17: ^bb2(%caseOperand0 : f32)
]
^bb2(%bb2Arg : f32):
"test.termop"(%bb2Arg) : (f32) -> ()
^bb3(%bb3Arg : f32):
"test.termop"(%bb3Arg) : (f32) -> ()
}


// CHECK: func.func @switch_on_const_no_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
// CHECK-NEXT: "test.termop"() [^0, ^1, ^2, ^3] : () -> ()
// CHECK-NEXT: ^0:
// CHECK-NEXT: cf.br ^1(%caseOperand0 : f32)
// CHECK-NEXT: ^1(%bb2Arg : f32):
// CHECK-NEXT: "test.termop"(%bb2Arg) : (f32) -> ()
// CHECK-NEXT: ^2(%bb3Arg : f32):
// CHECK-NEXT: "test.termop"(%bb3Arg) : (f32) -> ()
// CHECK-NEXT: ^3(%bb4Arg : f32):
// CHECK-NEXT: "test.termop"(%bb4Arg) : (f32) -> ()
// CHECK-NEXT: }
func.func @switch_on_const_no_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
// add predecessors for all blocks to avoid other canonicalizations.
"test.termop"() [^bb1, ^bb2, ^bb3, ^bb4] : () -> ()
^bb1:
%c0_i32 = arith.constant 0 : i32
cf.switch %c0_i32 : i32, [
default: ^bb2(%caseOperand0 : f32),
-1: ^bb3(%caseOperand1 : f32),
1: ^bb4(%caseOperand2 : f32)
]
^bb2(%bb2Arg : f32):
"test.termop"(%bb2Arg) : (f32) -> ()
^bb3(%bb3Arg : f32):
"test.termop"(%bb3Arg) : (f32) -> ()
^bb4(%bb4Arg : f32):
"test.termop"(%bb4Arg) : (f32) -> ()
}

// CHECK: func.func @switch_on_const_with_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
// CHECK-NEXT: "test.termop"() [^0, ^1, ^2, ^3] : () -> ()
// CHECK-NEXT: ^0:
// CHECK-NEXT: cf.br ^3(%caseOperand2 : f32)
// CHECK-NEXT: ^1(%bb2Arg : f32):
// CHECK-NEXT: "test.termop"(%bb2Arg) : (f32) -> ()
// CHECK-NEXT: ^2(%bb3Arg : f32):
// CHECK-NEXT: "test.termop"(%bb3Arg) : (f32) -> ()
// CHECK-NEXT: ^3(%bb4Arg : f32):
// CHECK-NEXT: "test.termop"(%bb4Arg) : (f32) -> ()
// CHECK-NEXT: }
func.func @switch_on_const_with_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
// add predecessors for all blocks to avoid other canonicalizations.
"test.termop"() [^bb1, ^bb2, ^bb3, ^bb4] : () -> ()
^bb1:
%c0_i32 = arith.constant 1 : i32
cf.switch %c0_i32 : i32, [
default: ^bb2(%caseOperand0 : f32),
-1: ^bb3(%caseOperand1 : f32),
1: ^bb4(%caseOperand2 : f32)
]
^bb2(%bb2Arg : f32):
"test.termop"(%bb2Arg) : (f32) -> ()
^bb3(%bb3Arg : f32):
"test.termop"(%bb3Arg) : (f32) -> ()
^bb4(%bb4Arg : f32):
"test.termop"(%bb4Arg) : (f32) -> ()
}

// CHECK: func.func @switch_passthrough(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32, %caseOperand3 : f32) {
// CHECK-NEXT: "test.termop"() [^0, ^1, ^2, ^3, ^4, ^5] : () -> ()
// CHECK-NEXT: ^0:
// CHECK-NEXT: cf.switch %flag : i32, [
// CHECK-NEXT: default: ^4(%caseOperand0 : f32),
// CHECK-NEXT: 43: ^5(%caseOperand1 : f32),
// CHECK-NEXT: 44: ^3(%caseOperand2 : f32)
// CHECK-NEXT: ]
// CHECK-NEXT: ^1(%bb2Arg : f32):
// CHECK-NEXT: cf.br ^4(%bb2Arg : f32)
// CHECK-NEXT: ^2(%bb3Arg : f32):
// CHECK-NEXT: cf.br ^5(%bb3Arg : f32)
// CHECK-NEXT: ^3(%bb4Arg : f32):
// CHECK-NEXT: "test.termop"(%bb4Arg) : (f32) -> ()
// CHECK-NEXT: ^4(%bb5Arg : f32):
// CHECK-NEXT: "test.termop"(%bb5Arg) : (f32) -> ()
// CHECK-NEXT: ^5(%bb6Arg : f32):
// CHECK-NEXT: "test.termop"(%bb6Arg) : (f32) -> ()
// CHECK-NEXT: }
func.func @switch_passthrough(%flag : i32,
%caseOperand0 : f32,
%caseOperand1 : f32,
%caseOperand2 : f32,
%caseOperand3 : f32) {
// add predecessors for all blocks to avoid other canonicalizations.
"test.termop"() [^bb1, ^bb2, ^bb3, ^bb4, ^bb5, ^bb6] : () -> ()
^bb1:
cf.switch %flag : i32, [
default: ^bb2(%caseOperand0 : f32),
43: ^bb3(%caseOperand1 : f32),
44: ^bb4(%caseOperand2 : f32)
]
^bb2(%bb2Arg : f32):
cf.br ^bb5(%bb2Arg : f32)
^bb3(%bb3Arg : f32):
cf.br ^bb6(%bb3Arg : f32)
^bb4(%bb4Arg : f32):
"test.termop"(%bb4Arg) : (f32) -> ()
^bb5(%bb5Arg : f32):
"test.termop"(%bb5Arg) : (f32) -> ()
^bb6(%bb6Arg : f32):
"test.termop"(%bb6Arg) : (f32) -> ()
}

// CHECK: func.func @switch_from_switch_with_same_value_with_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) {
// CHECK-NEXT: "test.termop"() [^0, ^1, ^2, ^3] : () -> ()
// CHECK-NEXT: ^0:
// CHECK-NEXT: cf.switch %flag : i32, [
// CHECK-NEXT: default: ^1,
// CHECK-NEXT: 42: ^4
// CHECK-NEXT: ]
// CHECK-NEXT: ^1:
// CHECK-NEXT: "test.termop"() : () -> ()
// CHECK-NEXT: ^4:
// CHECK-NEXT: "test.op"() : () -> ()
// CHECK-NEXT: cf.br ^3(%caseOperand1 : f32)
// CHECK-NEXT: ^2(%bb4Arg : f32):
// CHECK-NEXT: "test.termop"(%bb4Arg) : (f32) -> ()
// CHECK-NEXT: ^3(%bb5Arg : f32):
// CHECK-NEXT: "test.termop"(%bb5Arg) : (f32) -> ()
// CHECK-NEXT: }
func.func @switch_from_switch_with_same_value_with_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) {
// add predecessors for all blocks except ^bb3 to avoid other canonicalizations.
"test.termop"() [^bb1, ^bb2, ^bb4, ^bb5] : () -> ()
^bb1:
cf.switch %flag : i32, [
default: ^bb2,
42: ^bb3
]

^bb2:
"test.termop"() : () -> ()
^bb3:
// prevent this block from being simplified away
"test.op"() : () -> ()
cf.switch %flag : i32, [
default: ^bb4(%caseOperand0 : f32),
42: ^bb5(%caseOperand1 : f32)
]
^bb4(%bb4Arg : f32):
"test.termop"(%bb4Arg) : (f32) -> ()
^bb5(%bb5Arg : f32):
"test.termop"(%bb5Arg) : (f32) -> ()
}

// CHECK: func.func @switch_from_switch_with_same_value_no_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
// CHECK-NEXT: "test.termop"() [^0, ^1, ^2, ^3, ^4] : () -> ()
// CHECK-NEXT: ^0:
// CHECK-NEXT: cf.switch %flag : i32, [
// CHECK-NEXT: default: ^1,
// CHECK-NEXT: 42: ^5
// CHECK-NEXT: ]
// CHECK-NEXT: ^1:
// CHECK-NEXT: "test.termop"() : () -> ()
// CHECK-NEXT: ^5:
// CHECK-NEXT: "test.op"() : () -> ()
// CHECK-NEXT: cf.br ^2(%caseOperand0 : f32)
// CHECK-NEXT: ^2(%bb4Arg : f32):
// CHECK-NEXT: "test.termop"(%bb4Arg) : (f32) -> ()
// CHECK-NEXT: ^3(%bb5Arg : f32):
// CHECK-NEXT: "test.termop"(%bb5Arg) : (f32) -> ()
// CHECK-NEXT: ^4(%bb6Arg : f32):
// CHECK-NEXT: "test.termop"(%bb6Arg) : (f32) -> ()
// CHECK-NEXT: }
func.func @switch_from_switch_with_same_value_no_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
// add predecessors for all blocks except ^bb3 to avoid other canonicalizations.
"test.termop"() [^bb1, ^bb2, ^bb4, ^bb5, ^bb6] : () -> ()
^bb1:
cf.switch %flag : i32, [
default: ^bb2,
42: ^bb3
]
^bb2:
"test.termop"() : () -> ()
^bb3:
"test.op"() : () -> ()
cf.switch %flag : i32, [
default: ^bb4(%caseOperand0 : f32),
0: ^bb5(%caseOperand1 : f32),
43: ^bb6(%caseOperand2 : f32)
]
^bb4(%bb4Arg : f32):
"test.termop"(%bb4Arg) : (f32) -> ()
^bb5(%bb5Arg : f32):
"test.termop"(%bb5Arg) : (f32) -> ()
^bb6(%bb6Arg : f32):
"test.termop"(%bb6Arg) : (f32) -> ()
}

// CHECK: func.func @switch_from_switch_default_with_same_value(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
// CHECK-NEXT: "test.termop"() [^0, ^1, ^2, ^3, ^4] : () -> ()
// CHECK-NEXT: ^0:
// CHECK-NEXT: cf.switch %flag : i32, [
// CHECK-NEXT: default: ^5,
// CHECK-NEXT: 42: ^1
// CHECK-NEXT: ]
// CHECK-NEXT: ^1:
// CHECK-NEXT: "test.termop"() : () -> ()
// CHECK-NEXT: ^5:
// CHECK-NEXT: "test.op"() : () -> ()
// CHECK-NEXT: cf.switch %flag : i32, [
// CHECK-NEXT: default: ^2(%caseOperand0 : f32),
// CHECK-NEXT: 43: ^4(%caseOperand2 : f32)
// CHECK-NEXT: ]
// CHECK-NEXT: ^2(%bb4Arg : f32):
// CHECK-NEXT: "test.termop"(%bb4Arg) : (f32) -> ()
// CHECK-NEXT: ^3(%bb5Arg : f32):
// CHECK-NEXT: "test.termop"(%bb5Arg) : (f32) -> ()
// CHECK-NEXT: ^4(%bb6Arg : f32):
// CHECK-NEXT: "test.termop"(%bb6Arg) : (f32) -> ()
// CHECK-NEXT: }
func.func @switch_from_switch_default_with_same_value(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
// add predecessors for all blocks except ^bb3 to avoid other canonicalizations.
"test.termop"() [^bb1, ^bb2, ^bb4, ^bb5, ^bb6] : () -> ()
^bb1:
cf.switch %flag : i32, [
default: ^bb3,
42: ^bb2
]
^bb2:
"test.termop"() : () -> ()
^bb3:
"test.op"() : () -> ()
cf.switch %flag : i32, [
default: ^bb4(%caseOperand0 : f32),
42: ^bb5(%caseOperand1 : f32),
43: ^bb6(%caseOperand2 : f32)
]
^bb4(%bb4Arg : f32):
"test.termop"(%bb4Arg) : (f32) -> ()
^bb5(%bb5Arg : f32):
"test.termop"(%bb5Arg) : (f32) -> ()
^bb6(%bb6Arg : f32):
"test.termop"(%bb6Arg) : (f32) -> ()
}
22 changes: 21 additions & 1 deletion xdsl/dialects/cf.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,26 @@ def __init__(
"""


class SwitchHasCanonicalizationPatterns(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.cf import (
DropSwitchCasesThatMatchDefault,
SimplifyConstSwitchValue,
SimplifyPassThroughSwitch,
SimplifySwitchFromSwitchOnSameCondition,
SimplifySwitchWithOnlyDefault,
)

return (
SimplifySwitchWithOnlyDefault(),
SimplifyConstSwitchValue(),
SimplifyPassThroughSwitch(),
DropSwitchCasesThatMatchDefault(),
SimplifySwitchFromSwitchOnSameCondition(),
)


@irdl_op_definition
class Switch(IRDLOperation):
"""Switch operation"""
Expand All @@ -172,7 +192,7 @@ class Switch(IRDLOperation):

irdl_options = [AttrSizedOperandSegments(as_property=True)]

traits = frozenset([IsTerminator(), Pure()])
traits = frozenset([IsTerminator(), Pure(), SwitchHasCanonicalizationPatterns()])

def __init__(
self,
Expand Down
Loading

0 comments on commit fc0e077

Please sign in to comment.