Skip to content

Commit

Permalink
transformations: (cf) switch canonicalization (xdslproject#3291)
Browse files Browse the repository at this point in the history
Adds all the switch canonicalization patterns present in mlir. 

---------

Co-authored-by: Sasha Lopoukhine <[email protected]>
  • Loading branch information
2 people authored and EdmundGoodman committed Dec 6, 2024
1 parent 4c65e2a commit b73b7ef
Show file tree
Hide file tree
Showing 3 changed files with 591 additions and 4 deletions.
283 changes: 283 additions & 0 deletions tests/filecheck/dialects/cf/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,286 @@ func.func @branchCondProp(%arg0: i1) {
^exit:
return
}

/// Test the folding of SwitchOp
// CHECK: func.func @switch_only_default(%flag : i32, %caseOperand0 : f32) {
// CHECK-NEXT: "test.termop"() [^0, ^1] : () -> ()
// CHECK-NEXT: ^0:
// CHECK-NEXT: cf.br ^1(%caseOperand0 : f32)
// CHECK-NEXT: ^1(%arg : f32):
// CHECK-NEXT: "test.termop"(%arg) : (f32) -> ()
// CHECK-NEXT: }
func.func @switch_only_default(%flag : i32, %caseOperand0 : f32) {
// add predecessors for all blocks to avoid other canonicalizations.
"test.termop"() [^0, ^1] : () -> ()
^0:
cf.switch %flag : i32, [
default: ^1(%caseOperand0 : f32)
]
^1(%arg : f32):
"test.termop"(%arg) : (f32) -> ()
}


// CHECK: func.func @switch_case_matching_default(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) {
// CHECK-NEXT: "test.termop"() [^0, ^1, ^2] : () -> ()
// CHECK-NEXT: ^0:
// CHECK-NEXT: cf.switch %flag : i32, [
// CHECK-NEXT: default: ^1(%caseOperand0 : f32),
// CHECK-NEXT: 10: ^2(%caseOperand1 : f32)
// CHECK-NEXT: ]
// CHECK-NEXT: ^1(%arg : f32):
// CHECK-NEXT: "test.termop"(%arg) : (f32) -> ()
// CHECK-NEXT: ^2(%arg2 : f32):
// CHECK-NEXT: "test.termop"(%arg2) : (f32) -> ()
// CHECK-NEXT: }
func.func @switch_case_matching_default(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) {
// add predecessors for all blocks to avoid other canonicalizations.
"test.termop"() [^0, ^1, ^2] : () -> ()
^0:
cf.switch %flag : i32, [
default: ^1(%caseOperand0 : f32),
42: ^1(%caseOperand0 : f32),
10: ^2(%caseOperand1 : f32),
17: ^1(%caseOperand0 : f32)
]
^1(%arg : f32):
"test.termop"(%arg) : (f32) -> ()
^2(%arg2 : f32):
"test.termop"(%arg2) : (f32) -> ()
}


// CHECK: func.func @switch_on_const_no_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
// CHECK-NEXT: "test.termop"() [^0, ^1, ^2, ^3] : () -> ()
// CHECK-NEXT: ^0:
// CHECK-NEXT: cf.br ^1(%caseOperand0 : f32)
// CHECK-NEXT: ^1(%arg : f32):
// CHECK-NEXT: "test.termop"(%arg) : (f32) -> ()
// CHECK-NEXT: ^2(%arg2 : f32):
// CHECK-NEXT: "test.termop"(%arg2) : (f32) -> ()
// CHECK-NEXT: ^3(%arg3 : f32):
// CHECK-NEXT: "test.termop"(%arg3) : (f32) -> ()
// CHECK-NEXT: }
func.func @switch_on_const_no_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
// add predecessors for all blocks to avoid other canonicalizations.
"test.termop"() [^0, ^1, ^2, ^3] : () -> ()
^0:
%c0_i32 = arith.constant 0 : i32
cf.switch %c0_i32 : i32, [
default: ^1(%caseOperand0 : f32),
-1: ^2(%caseOperand1 : f32),
1: ^3(%caseOperand2 : f32)
]
^1(%arg : f32):
"test.termop"(%arg) : (f32) -> ()
^2(%arg2 : f32):
"test.termop"(%arg2) : (f32) -> ()
^3(%arg3 : f32):
"test.termop"(%arg3) : (f32) -> ()
}

// CHECK: func.func @switch_on_const_with_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
// CHECK-NEXT: "test.termop"() [^0, ^1, ^2, ^3] : () -> ()
// CHECK-NEXT: ^0:
// CHECK-NEXT: cf.br ^3(%caseOperand2 : f32)
// CHECK-NEXT: ^1(%arg : f32):
// CHECK-NEXT: "test.termop"(%arg) : (f32) -> ()
// CHECK-NEXT: ^2(%arg2 : f32):
// CHECK-NEXT: "test.termop"(%arg2) : (f32) -> ()
// CHECK-NEXT: ^3(%arg3 : f32):
// CHECK-NEXT: "test.termop"(%arg3) : (f32) -> ()
// CHECK-NEXT: }
func.func @switch_on_const_with_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
// add predecessors for all blocks to avoid other canonicalizations.
"test.termop"() [^0, ^1, ^2, ^3] : () -> ()
^0:
%c0_i32 = arith.constant 1 : i32
cf.switch %c0_i32 : i32, [
default: ^1(%caseOperand0 : f32),
-1: ^2(%caseOperand1 : f32),
1: ^3(%caseOperand2 : f32)
]
^1(%arg : f32):
"test.termop"(%arg) : (f32) -> ()
^2(%arg2 : f32):
"test.termop"(%arg2) : (f32) -> ()
^3(%arg3 : f32):
"test.termop"(%arg3) : (f32) -> ()
}

// CHECK: func.func @switch_passthrough(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32, %caseOperand3 : f32) {
// CHECK-NEXT: "test.termop"() [^0, ^1, ^2, ^3, ^4, ^5] : () -> ()
// CHECK-NEXT: ^0:
// CHECK-NEXT: cf.switch %flag : i32, [
// CHECK-NEXT: default: ^4(%caseOperand0 : f32),
// CHECK-NEXT: 43: ^5(%caseOperand1 : f32),
// CHECK-NEXT: 44: ^3(%caseOperand2 : f32)
// CHECK-NEXT: ]
// CHECK-NEXT: ^1(%arg : f32):
// CHECK-NEXT: cf.br ^4(%arg : f32)
// CHECK-NEXT: ^2(%arg2 : f32):
// CHECK-NEXT: cf.br ^5(%arg2 : f32)
// CHECK-NEXT: ^3(%arg3 : f32):
// CHECK-NEXT: "test.termop"(%arg3) : (f32) -> ()
// CHECK-NEXT: ^4(%arg4 : f32):
// CHECK-NEXT: "test.termop"(%arg4) : (f32) -> ()
// CHECK-NEXT: ^5(%arg5 : f32):
// CHECK-NEXT: "test.termop"(%arg5) : (f32) -> ()
// CHECK-NEXT: }
func.func @switch_passthrough(%flag : i32,
%caseOperand0 : f32,
%caseOperand1 : f32,
%caseOperand2 : f32,
%caseOperand3 : f32) {
// add predecessors for all blocks to avoid other canonicalizations.
"test.termop"() [^0, ^1, ^2, ^3, ^4, ^5] : () -> ()
^0:
cf.switch %flag : i32, [
default: ^1(%caseOperand0 : f32),
43: ^2(%caseOperand1 : f32),
44: ^3(%caseOperand2 : f32)
]
^1(%arg : f32):
cf.br ^4(%arg : f32)
^2(%arg2 : f32):
cf.br ^5(%arg2 : f32)
^3(%arg3 : f32):
"test.termop"(%arg3) : (f32) -> ()
^4(%arg4 : f32):
"test.termop"(%arg4) : (f32) -> ()
^5(%arg5 : f32):
"test.termop"(%arg5) : (f32) -> ()
}

// CHECK: func.func @switch_from_switch_with_same_value_with_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) {
// CHECK-NEXT: "test.termop"() [^0, ^1, ^2, ^3] : () -> ()
// CHECK-NEXT: ^0:
// CHECK-NEXT: cf.switch %flag : i32, [
// CHECK-NEXT: default: ^1,
// CHECK-NEXT: 42: ^4
// CHECK-NEXT: ]
// CHECK-NEXT: ^1:
// CHECK-NEXT: "test.termop"() : () -> ()
// CHECK-NEXT: ^4:
// CHECK-NEXT: "test.op"() : () -> ()
// CHECK-NEXT: cf.br ^3(%caseOperand1 : f32)
// CHECK-NEXT: ^2(%arg3 : f32):
// CHECK-NEXT: "test.termop"(%arg3) : (f32) -> ()
// CHECK-NEXT: ^3(%arg4 : f32):
// CHECK-NEXT: "test.termop"(%arg4) : (f32) -> ()
// CHECK-NEXT: }
func.func @switch_from_switch_with_same_value_with_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) {
// add predecessors for all blocks except ^2 to avoid other canonicalizations.
"test.termop"() [^0, ^1, ^3, ^4] : () -> ()
^0:
cf.switch %flag : i32, [
default: ^1,
42: ^2
]

^1:
"test.termop"() : () -> ()
^2:
// prevent this block from being simplified away
"test.op"() : () -> ()
cf.switch %flag : i32, [
default: ^3(%caseOperand0 : f32),
42: ^4(%caseOperand1 : f32)
]
^3(%arg3 : f32):
"test.termop"(%arg3) : (f32) -> ()
^4(%arg4 : f32):
"test.termop"(%arg4) : (f32) -> ()
}

// CHECK: func.func @switch_from_switch_with_same_value_no_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
// CHECK-NEXT: "test.termop"() [^0, ^1, ^2, ^3, ^4] : () -> ()
// CHECK-NEXT: ^0:
// CHECK-NEXT: cf.switch %flag : i32, [
// CHECK-NEXT: default: ^1,
// CHECK-NEXT: 42: ^5
// CHECK-NEXT: ]
// CHECK-NEXT: ^1:
// CHECK-NEXT: "test.termop"() : () -> ()
// CHECK-NEXT: ^5:
// CHECK-NEXT: "test.op"() : () -> ()
// CHECK-NEXT: cf.br ^2(%caseOperand0 : f32)
// CHECK-NEXT: ^2(%arg3 : f32):
// CHECK-NEXT: "test.termop"(%arg3) : (f32) -> ()
// CHECK-NEXT: ^3(%arg4 : f32):
// CHECK-NEXT: "test.termop"(%arg4) : (f32) -> ()
// CHECK-NEXT: ^4(%arg5 : f32):
// CHECK-NEXT: "test.termop"(%arg5) : (f32) -> ()
// CHECK-NEXT: }
func.func @switch_from_switch_with_same_value_no_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
// add predecessors for all blocks except ^2 to avoid other canonicalizations.
"test.termop"() [^0, ^1, ^3, ^4, ^5] : () -> ()
^0:
cf.switch %flag : i32, [
default: ^1,
42: ^2
]
^1:
"test.termop"() : () -> ()
^2:
"test.op"() : () -> ()
cf.switch %flag : i32, [
default: ^3(%caseOperand0 : f32),
0: ^4(%caseOperand1 : f32),
43: ^5(%caseOperand2 : f32)
]
^3(%arg3 : f32):
"test.termop"(%arg3) : (f32) -> ()
^4(%arg4 : f32):
"test.termop"(%arg4) : (f32) -> ()
^5(%arg5 : f32):
"test.termop"(%arg5) : (f32) -> ()
}

// CHECK: func.func @switch_from_switch_default_with_same_value(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
// CHECK-NEXT: "test.termop"() [^0, ^1, ^2, ^3, ^4] : () -> ()
// CHECK-NEXT: ^0:
// CHECK-NEXT: cf.switch %flag : i32, [
// CHECK-NEXT: default: ^5,
// CHECK-NEXT: 42: ^1
// CHECK-NEXT: ]
// CHECK-NEXT: ^1:
// CHECK-NEXT: "test.termop"() : () -> ()
// CHECK-NEXT: ^5:
// CHECK-NEXT: "test.op"() : () -> ()
// CHECK-NEXT: cf.switch %flag : i32, [
// CHECK-NEXT: default: ^2(%caseOperand0 : f32),
// CHECK-NEXT: 43: ^4(%caseOperand2 : f32)
// CHECK-NEXT: ]
// CHECK-NEXT: ^2(%arg3 : f32):
// CHECK-NEXT: "test.termop"(%arg3) : (f32) -> ()
// CHECK-NEXT: ^3(%arg4 : f32):
// CHECK-NEXT: "test.termop"(%arg4) : (f32) -> ()
// CHECK-NEXT: ^4(%arg5 : f32):
// CHECK-NEXT: "test.termop"(%arg5) : (f32) -> ()
// CHECK-NEXT: }
func.func @switch_from_switch_default_with_same_value(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
// add predecessors for all blocks except ^2 to avoid other canonicalizations.
"test.termop"() [^0, ^1, ^3, ^4, ^5] : () -> ()
^0:
cf.switch %flag : i32, [
default: ^2,
42: ^1
]
^1:
"test.termop"() : () -> ()
^2:
"test.op"() : () -> ()
cf.switch %flag : i32, [
default: ^3(%caseOperand0 : f32),
42: ^4(%caseOperand1 : f32),
43: ^5(%caseOperand2 : f32)
]
^3(%arg3 : f32):
"test.termop"(%arg3) : (f32) -> ()
^4(%arg4 : f32):
"test.termop"(%arg4) : (f32) -> ()
^5(%arg5 : f32):
"test.termop"(%arg5) : (f32) -> ()
}
22 changes: 21 additions & 1 deletion xdsl/dialects/cf.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,26 @@ def __init__(
"""


class SwitchHasCanonicalizationPatterns(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.cf import (
DropSwitchCasesThatMatchDefault,
SimplifyConstSwitchValue,
SimplifyPassThroughSwitch,
SimplifySwitchFromSwitchOnSameCondition,
SimplifySwitchWithOnlyDefault,
)

return (
SimplifySwitchWithOnlyDefault(),
SimplifyConstSwitchValue(),
SimplifyPassThroughSwitch(),
DropSwitchCasesThatMatchDefault(),
SimplifySwitchFromSwitchOnSameCondition(),
)


@irdl_op_definition
class Switch(IRDLOperation):
"""Switch operation"""
Expand All @@ -174,7 +194,7 @@ class Switch(IRDLOperation):

irdl_options = [AttrSizedOperandSegments(as_property=True)]

traits = frozenset([IsTerminator(), Pure()])
traits = frozenset([IsTerminator(), Pure(), SwitchHasCanonicalizationPatterns()])

def __init__(
self,
Expand Down
Loading

0 comments on commit b73b7ef

Please sign in to comment.