From b452684085f7845760a8303bd39b721ce8194406 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Tue, 15 Oct 2024 14:17:41 +0100 Subject: [PATCH] transformations: (cf) cf.cond_br truth propagation (#3285) Truth propagation canonicalization pattern. I have not implemented the `SimplifyCondBranchFromCondBranchOnSameCondition` pattern from mlir, because it seems to be implied from this pattern and previous const folding pattern. --- tests/filecheck/dialects/cf/canonicalize.mlir | 44 ++++++++++++++++ xdsl/dialects/builtin.py | 4 ++ xdsl/dialects/cf.py | 2 + .../canonicalization_patterns/cf.py | 50 ++++++++++++++++++- 4 files changed, 99 insertions(+), 1 deletion(-) diff --git a/tests/filecheck/dialects/cf/canonicalize.mlir b/tests/filecheck/dialects/cf/canonicalize.mlir index a6f82ccc07..82ed7e14fc 100644 --- a/tests/filecheck/dialects/cf/canonicalize.mlir +++ b/tests/filecheck/dialects/cf/canonicalize.mlir @@ -154,3 +154,47 @@ func.func @cond_br_same_successor_insert_select( ^bb1(%result : i32, %result2 : tensor<2xi32>): return %result, %result2 : i32, tensor<2xi32> } + +/// Test folding conditional branches that are successors of conditional +/// branches with the same condition. +// CHECK: func.func @cond_br_from_cond_br_with_same_condition(%cond : i1) { +// CHECK-NEXT: cf.cond_br %cond, ^0, ^1 +// CHECK-NEXT: ^0: +// CHECK-NEXT: func.return +// CHECK-NEXT: ^1: +// CHECK-NEXT: "test.termop"() : () -> () +// CHECK-NEXT: } +func.func @cond_br_from_cond_br_with_same_condition(%cond : i1) { + cf.cond_br %cond, ^bb1, ^bb2 +^bb1: + cf.cond_br %cond, ^bb3, ^bb2 +^bb2: + "test.termop"() : () -> () +^bb3: + return +} + +// CHECK: func.func @branchCondProp(%arg0 : i1) { +// CHECK-NEXT: %arg0_1 = arith.constant true +// CHECK-NEXT: %arg0_2 = arith.constant false +// CHECK-NEXT: cf.cond_br %arg0, ^0, ^1 +// CHECK-NEXT: ^0: +// CHECK-NEXT: "test.op"(%arg0_1) : (i1) -> () +// CHECK-NEXT: cf.br ^2 +// CHECK-NEXT: ^1: +// CHECK-NEXT: "test.op"(%arg0_2) : (i1) -> () +// CHECK-NEXT: cf.br ^2 +// CHECK-NEXT: ^2: +// CHECK-NEXT: func.return +// CHECK-NEXT: } +func.func @branchCondProp(%arg0: i1) { + cf.cond_br %arg0, ^trueB, ^falseB +^trueB: + "test.op"(%arg0) : (i1) -> () + cf.br ^exit +^falseB: + "test.op"(%arg0) : (i1) -> () + cf.br ^exit +^exit: + return +} diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index ec73f8dbc8..443337b029 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -485,6 +485,10 @@ def from_int_and_width(value: int, width: int) -> IntegerAttr[IntegerType]: def from_index_int_value(value: int) -> IntegerAttr[IndexType]: return IntegerAttr(value, IndexType()) + @staticmethod + def from_bool(value: bool) -> BoolAttr: + return IntegerAttr(value, 1) + def verify(self) -> None: if isinstance(int_type := self.type, IndexType): return diff --git a/xdsl/dialects/cf.py b/xdsl/dialects/cf.py index ff3280a818..3f69138639 100644 --- a/xdsl/dialects/cf.py +++ b/xdsl/dialects/cf.py @@ -100,6 +100,7 @@ class ConditionalBranchHasCanonicalizationPatterns(HasCanonicalizationPatternsTr @classmethod def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: from xdsl.transforms.canonicalization_patterns.cf import ( + CondBranchTruthPropagation, SimplifyCondBranchIdenticalSuccessors, SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch, @@ -109,6 +110,7 @@ def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: SimplifyConstCondBranchPred(), SimplifyPassThroughCondBranch(), SimplifyCondBranchIdenticalSuccessors(), + CondBranchTruthPropagation(), ) diff --git a/xdsl/transforms/canonicalization_patterns/cf.py b/xdsl/transforms/canonicalization_patterns/cf.py index db42cefd8d..9cee170662 100644 --- a/xdsl/transforms/canonicalization_patterns/cf.py +++ b/xdsl/transforms/canonicalization_patterns/cf.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from xdsl.dialects import arith, cf -from xdsl.dialects.builtin import IntegerAttr +from xdsl.dialects.builtin import BoolAttr, IntegerAttr from xdsl.ir import Block, BlockArgument, SSAValue from xdsl.pattern_rewriter import ( PatternRewriter, @@ -209,3 +209,51 @@ def match_and_rewrite(self, op: cf.ConditionalBranch, rewriter: PatternRewriter) ) rewriter.replace_matched_op(cf.Branch(op.then_block, *merged_operands)) + + +class CondBranchTruthPropagation(RewritePattern): + """ + cf.cond_br %arg0, ^trueB, ^falseB + + ^trueB: + "test.consumer1"(%arg0) : (i1) -> () + ... + + ^falseB: + "test.consumer2"(%arg0) : (i1) -> () + ... + + -> + + cf.cond_br %arg0, ^trueB, ^falseB + ^trueB: + "test.consumer1"(%true) : (i1) -> () + ... + + ^falseB: + "test.consumer2"(%false) : (i1) -> () + ... + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: cf.ConditionalBranch, rewriter: PatternRewriter): + if len(op.then_block.predecessors()) == 1: + if any( + use.operation.parent_block() is op.then_block for use in op.cond.uses + ): + const_true = arith.Constant(BoolAttr.from_bool(True)) + rewriter.insert_op(const_true, InsertPoint.before(op)) + op.cond.replace_by_if( + const_true.result, + lambda use: use.operation.parent_block() is op.then_block, + ) + if len(op.else_block.predecessors()) == 1: + if any( + use.operation.parent_block() is op.else_block for use in op.cond.uses + ): + const_false = arith.Constant(BoolAttr.from_bool(False)) + rewriter.insert_op(const_false, InsertPoint.before(op)) + op.cond.replace_by_if( + const_false.result, + lambda use: use.operation.parent_block() is op.else_block, + )