Skip to content

Commit

Permalink
transformations: (cf) cf.cond_br identical successors (#3284)
Browse files Browse the repository at this point in the history
Adds simplification to `cf.cond_br` when both branches point to the same
successor.

The mlir version only performs the reduction when either all the
operands are the same, or the common destination branch only has one
predecessor. For now this precondition has not been included in the implementation.

Also updates some of the other tests which now reduce further

---------

Co-authored-by: Sasha Lopoukhine <[email protected]>
  • Loading branch information
alexarice and superlopuh authored Oct 15, 2024
1 parent a7d8ebe commit f0a74f1
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 5 deletions.
31 changes: 27 additions & 4 deletions tests/filecheck/dialects/cf/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ func.func @br_dead_passthrough() {
/// This will reduce further with other rewrites

// CHECK: func.func @cond_br_folding(%cond : i1, %a : i32) {
// CHECK-NEXT: cf.cond_br %cond, ^[[#b0:]], ^[[#b0]]
// CHECK-NEXT: ^[[#b0]]:
// CHECK-NEXT: func.return
// CHECK-NEXT: }
func.func @cond_br_folding(%cond : i1, %a : i32) {
Expand Down Expand Up @@ -101,8 +99,8 @@ func.func @cond_br_and_br_folding(%a : i32) {

/// Test that pass-through successors of CondBranchOp get folded.
// CHECK: func.func @cond_br_passthrough(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
// CHECK-NEXT: cf.cond_br %cond, ^[[#b0:]](%arg0, %arg1 : i32, i32), ^[[#b0]](%arg2, %arg2 : i32, i32)
// CHECK-NEXT: ^[[#b0]](%arg4 : i32, %arg5 : i32):
// CHECK-NEXT: %arg4 = arith.select %cond, %arg0, %arg2 : i32
// CHECK-NEXT: %arg5 = arith.select %cond, %arg1, %arg2 : i32
// CHECK-NEXT: func.return %arg4, %arg5 : i32, i32
// CHECK-NEXT: }
func.func @cond_br_passthrough(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
Expand Down Expand Up @@ -131,3 +129,28 @@ func.func @cond_br_pass_through_fail(%cond : i1) {
^bb2:
return
}

/// Test the folding of CondBranchOp when the successors are identical.
// CHECK: func.func @cond_br_same_successor(%cond : i1, %a : i32) {
// CHECK-NEXT: func.return
// CHECK-NEXT: }
func.func @cond_br_same_successor(%cond : i1, %a : i32) {
cf.cond_br %cond, ^bb1(%a : i32), ^bb1(%a : i32)
^bb1(%result : i32):
return
}

/// Test the folding of CondBranchOp when the successors are identical, but the
/// arguments are different.
// CHECK: func.func @cond_br_same_successor_insert_select(%cond : i1, %a : i32, %b : i32, %c : tensor<2xi32>, %d : tensor<2xi32>) -> (i32, tensor<2xi32>) {
// CHECK-NEXT: %result = arith.select %cond, %a, %b : i32
// CHECK-NEXT: %result2 = arith.select %cond, %c, %d : tensor<2xi32>
// CHECK-NEXT: func.return %result, %result2 : i32, tensor<2xi32>
// CHECK-NEXT: }
func.func @cond_br_same_successor_insert_select(
%cond : i1, %a : i32, %b : i32, %c : tensor<2xi32>, %d : tensor<2xi32>
) -> (i32, tensor<2xi32>) {
cf.cond_br %cond, ^bb1(%a, %c : i32, tensor<2xi32>), ^bb1(%b, %d : i32, tensor<2xi32>)
^bb1(%result : i32, %result2 : tensor<2xi32>):
return %result, %result2 : i32, tensor<2xi32>
}
7 changes: 6 additions & 1 deletion xdsl/dialects/cf.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,16 @@ class ConditionalBranchHasCanonicalizationPatterns(HasCanonicalizationPatternsTr
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.cf import (
SimplifyCondBranchIdenticalSuccessors,
SimplifyConstCondBranchPred,
SimplifyPassThroughCondBranch,
)

return (SimplifyConstCondBranchPred(), SimplifyPassThroughCondBranch())
return (
SimplifyConstCondBranchPred(),
SimplifyPassThroughCondBranch(),
SimplifyCondBranchIdenticalSuccessors(),
)


@irdl_op_definition
Expand Down
37 changes: 37 additions & 0 deletions xdsl/transforms/canonicalization_patterns/cf.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,40 @@ def match_and_rewrite(self, op: cf.ConditionalBranch, rewriter: PatternRewriter)
op.cond, new_then, new_then_args, new_else, new_else_args
)
)


class SimplifyCondBranchIdenticalSuccessors(RewritePattern):
"""
cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
-> br ^bb1(A, ..., N)
cf.cond_br %cond, ^bb1(A), ^bb1(B)
-> %select = arith.select %cond, A, B
br ^bb1(%select)
"""

@staticmethod
def _merge_operand(
op1: SSAValue,
op2: SSAValue,
rewriter: PatternRewriter,
cond_br: cf.ConditionalBranch,
) -> SSAValue:
if op1 == op2:
return op1
select = arith.Select(cond_br.cond, op1, op2)
rewriter.insert_op(select, InsertPoint.before(cond_br))
return select.result

@op_type_rewrite_pattern
def match_and_rewrite(self, op: cf.ConditionalBranch, rewriter: PatternRewriter):
# Check that the true and false destinations are the same
if op.then_block != op.else_block:
return

merged_operands = tuple(
self._merge_operand(op1, op2, rewriter, op)
for (op1, op2) in zip(op.then_arguments, op.else_arguments, strict=True)
)

rewriter.replace_matched_op(cf.Branch(op.then_block, *merged_operands))

0 comments on commit f0a74f1

Please sign in to comment.