Skip to content

Commit

Permalink
transformations: (cf) cf.cond_br truth propagation (#3285)
Browse files Browse the repository at this point in the history
Truth propagation canonicalization pattern.

I have not implemented the
`SimplifyCondBranchFromCondBranchOnSameCondition` pattern from mlir,
because it seems to be implied from this pattern and previous const
folding pattern.
  • Loading branch information
alexarice authored Oct 15, 2024
1 parent f0a74f1 commit b452684
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 1 deletion.
44 changes: 44 additions & 0 deletions tests/filecheck/dialects/cf/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,47 @@ func.func @cond_br_same_successor_insert_select(
^bb1(%result : i32, %result2 : tensor<2xi32>):
return %result, %result2 : i32, tensor<2xi32>
}

/// Test folding conditional branches that are successors of conditional
/// branches with the same condition.
// CHECK: func.func @cond_br_from_cond_br_with_same_condition(%cond : i1) {
// CHECK-NEXT: cf.cond_br %cond, ^0, ^1
// CHECK-NEXT: ^0:
// CHECK-NEXT: func.return
// CHECK-NEXT: ^1:
// CHECK-NEXT: "test.termop"() : () -> ()
// CHECK-NEXT: }
func.func @cond_br_from_cond_br_with_same_condition(%cond : i1) {
cf.cond_br %cond, ^bb1, ^bb2
^bb1:
cf.cond_br %cond, ^bb3, ^bb2
^bb2:
"test.termop"() : () -> ()
^bb3:
return
}

// CHECK: func.func @branchCondProp(%arg0 : i1) {
// CHECK-NEXT: %arg0_1 = arith.constant true
// CHECK-NEXT: %arg0_2 = arith.constant false
// CHECK-NEXT: cf.cond_br %arg0, ^0, ^1
// CHECK-NEXT: ^0:
// CHECK-NEXT: "test.op"(%arg0_1) : (i1) -> ()
// CHECK-NEXT: cf.br ^2
// CHECK-NEXT: ^1:
// CHECK-NEXT: "test.op"(%arg0_2) : (i1) -> ()
// CHECK-NEXT: cf.br ^2
// CHECK-NEXT: ^2:
// CHECK-NEXT: func.return
// CHECK-NEXT: }
func.func @branchCondProp(%arg0: i1) {
cf.cond_br %arg0, ^trueB, ^falseB
^trueB:
"test.op"(%arg0) : (i1) -> ()
cf.br ^exit
^falseB:
"test.op"(%arg0) : (i1) -> ()
cf.br ^exit
^exit:
return
}
4 changes: 4 additions & 0 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,10 @@ def from_int_and_width(value: int, width: int) -> IntegerAttr[IntegerType]:
def from_index_int_value(value: int) -> IntegerAttr[IndexType]:
return IntegerAttr(value, IndexType())

@staticmethod
def from_bool(value: bool) -> BoolAttr:
return IntegerAttr(value, 1)

def verify(self) -> None:
if isinstance(int_type := self.type, IndexType):
return
Expand Down
2 changes: 2 additions & 0 deletions xdsl/dialects/cf.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class ConditionalBranchHasCanonicalizationPatterns(HasCanonicalizationPatternsTr
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.cf import (
CondBranchTruthPropagation,
SimplifyCondBranchIdenticalSuccessors,
SimplifyConstCondBranchPred,
SimplifyPassThroughCondBranch,
Expand All @@ -109,6 +110,7 @@ def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
SimplifyConstCondBranchPred(),
SimplifyPassThroughCondBranch(),
SimplifyCondBranchIdenticalSuccessors(),
CondBranchTruthPropagation(),
)


Expand Down
50 changes: 49 additions & 1 deletion xdsl/transforms/canonicalization_patterns/cf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Sequence

from xdsl.dialects import arith, cf
from xdsl.dialects.builtin import IntegerAttr
from xdsl.dialects.builtin import BoolAttr, IntegerAttr
from xdsl.ir import Block, BlockArgument, SSAValue
from xdsl.pattern_rewriter import (
PatternRewriter,
Expand Down Expand Up @@ -209,3 +209,51 @@ def match_and_rewrite(self, op: cf.ConditionalBranch, rewriter: PatternRewriter)
)

rewriter.replace_matched_op(cf.Branch(op.then_block, *merged_operands))


class CondBranchTruthPropagation(RewritePattern):
"""
cf.cond_br %arg0, ^trueB, ^falseB
^trueB:
"test.consumer1"(%arg0) : (i1) -> ()
...
^falseB:
"test.consumer2"(%arg0) : (i1) -> ()
...
->
cf.cond_br %arg0, ^trueB, ^falseB
^trueB:
"test.consumer1"(%true) : (i1) -> ()
...
^falseB:
"test.consumer2"(%false) : (i1) -> ()
...
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: cf.ConditionalBranch, rewriter: PatternRewriter):
if len(op.then_block.predecessors()) == 1:
if any(
use.operation.parent_block() is op.then_block for use in op.cond.uses
):
const_true = arith.Constant(BoolAttr.from_bool(True))
rewriter.insert_op(const_true, InsertPoint.before(op))
op.cond.replace_by_if(
const_true.result,
lambda use: use.operation.parent_block() is op.then_block,
)
if len(op.else_block.predecessors()) == 1:
if any(
use.operation.parent_block() is op.else_block for use in op.cond.uses
):
const_false = arith.Constant(BoolAttr.from_bool(False))
rewriter.insert_op(const_false, InsertPoint.before(op))
op.cond.replace_by_if(
const_false.result,
lambda use: use.operation.parent_block() is op.else_block,
)

0 comments on commit b452684

Please sign in to comment.