Skip to content

Commit

Permalink
[mlir][index] Fold cmp(x, x) when x isn't a constant (llvm#78812)
Browse files Browse the repository at this point in the history
Such cases show up in the middle of optimizations passes, e.g., after
some rewrites and then CSE. The current folder can fold such cases when
the inputs are constant; this patch improves it to fold even if the
inputs are non-constant.
  • Loading branch information
StrongerXi authored Jan 19, 2024
1 parent b86d023 commit c17aa14
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
22 changes: 22 additions & 0 deletions mlir/lib/Dialect/Index/IR/IndexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,24 @@ static std::optional<bool> foldCmpOfMaxOrMin(Operation *lhsOp,
lhsRange, ConstantIntRanges::constant(cstB));
}

/// Return the result of `cmp(pred, x, x)`
static bool compareSameArgs(IndexCmpPredicate pred) {
switch (pred) {
case IndexCmpPredicate::EQ:
case IndexCmpPredicate::SGE:
case IndexCmpPredicate::SLE:
case IndexCmpPredicate::UGE:
case IndexCmpPredicate::ULE:
return true;
case IndexCmpPredicate::NE:
case IndexCmpPredicate::SGT:
case IndexCmpPredicate::SLT:
case IndexCmpPredicate::UGT:
case IndexCmpPredicate::ULT:
return false;
}
}

OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
// Attempt to fold if both inputs are constant.
auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
Expand Down Expand Up @@ -606,6 +624,10 @@ OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
return BoolAttr::get(getContext(), *result64);
}

// Fold `cmp(x, x)`
if (getLhs() == getRhs())
return BoolAttr::get(getContext(), compareSameArgs(getPred()));

return {};
}

Expand Down
20 changes: 20 additions & 0 deletions mlir/test/Dialect/Index/index-canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,26 @@ func.func @cmp(%arg0: index) -> (i1, i1, i1, i1, i1, i1) {
return %0, %1, %2, %3, %5, %7 : i1, i1, i1, i1, i1, i1
}

// CHECK-LABEL: @cmp_same_args
func.func @cmp_same_args(%a: index) -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
%0 = index.cmp eq(%a, %a)
%1 = index.cmp sge(%a, %a)
%2 = index.cmp sle(%a, %a)
%3 = index.cmp uge(%a, %a)
%4 = index.cmp ule(%a, %a)
%5 = index.cmp ne(%a, %a)
%6 = index.cmp sgt(%a, %a)
%7 = index.cmp slt(%a, %a)
%8 = index.cmp ugt(%a, %a)
%9 = index.cmp ult(%a, %a)

// CHECK-DAG: %[[TRUE:.*]] = index.bool.constant true
// CHECK-DAG: %[[FALSE:.*]] = index.bool.constant false
// CHECK-NEXT: return %[[TRUE]], %[[TRUE]], %[[TRUE]], %[[TRUE]], %[[TRUE]],
// CHECK-SAME: %[[FALSE]], %[[FALSE]], %[[FALSE]], %[[FALSE]], %[[FALSE]]
return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
}

// CHECK-LABEL: @cmp_nofold
func.func @cmp_nofold() -> i1 {
%lhs = index.constant 1
Expand Down

0 comments on commit c17aa14

Please sign in to comment.