From 2e573ef20c85da5a77a37bba2eb4eb5a5db2ca81 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sun, 27 Aug 2023 08:00:32 -0500 Subject: [PATCH] [Arith] Added simplification rule for multiple equality compares The expression `(x==y) && (x==z)` requires that `y==z`. When `y` and `z` are constants, this can allow better constant folding by rewriting `(x==c1) && (x==c2)` into `(x==c1) && (c1==c2)`. This commit adds the above rewrite, and the corresponding rewrite of the negative expression. --- src/arith/rewrite_simplify.cc | 2 ++ tests/python/unittest/test_arith_rewrite_simplify.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 40088fd963d7..63becf8eb77f 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -1856,6 +1856,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { }), cfalse, c2.Eval()->value > c1.Eval()->value); + TVM_TRY_REWRITE((x == c1) && (x == c2), (x == c1) && (c1 == c2)); TVM_TRY_REWRITE(matches_one_of(x == c1 && x != c2, x != c2 && x == c1), x == c1 && c1 != c2); TVM_TRY_RECURSIVE_REWRITE(matches_one_of(floordiv(x, c2) == c1 && floormod(x, c2) == c3, @@ -2000,6 +2001,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value + 1); TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value + 1); + TVM_TRY_REWRITE(x != c1 || x != c2, x != c1 || c1 != c2); TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2); TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2); diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 46ac0f975157..0b0a43a7d3d3 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -951,6 +951,7 @@ class TestLogical(BaseCompare): TestCase(tvm.tir.And(x <= 1, 2 <= x), tvm.tir.const(False, "bool")), TestCase(tvm.tir.And(2 <= x, x <= 1), tvm.tir.const(False, "bool")), TestCase(tvm.tir.And(x == 1, x != 2), x == 1), + TestCase(tvm.tir.And(x == 1, x == 2), tvm.tir.const(False, "bool")), TestCase(tvm.tir.Or(tvm.tir.EQ(x, y), tvm.tir.NE(x, y)), tvm.tir.const(True, "bool")), TestCase(tvm.tir.Or(tvm.tir.NE(x, y), tvm.tir.EQ(x, y)), tvm.tir.const(True, "bool")), TestCase(tvm.tir.Or(x > y, tvm.tir.Not(x > y)), tvm.tir.const(True, "bool")), @@ -965,6 +966,7 @@ class TestLogical(BaseCompare): TestCase(tvm.tir.Or(x <= 1, 2 <= x), tvm.tir.const(True, "bool")), TestCase(tvm.tir.Or(2 <= x, x <= 1), tvm.tir.const(True, "bool")), TestCase(tvm.tir.Or(x != 1, x == 2), x != 1), + TestCase(tvm.tir.Or(x != 1, x != 2), tvm.tir.const(True, "bool")), TestCase( tvm.tir.Or(x == 1, tvm.tir.Or(y == 1, z == 1)), tvm.tir.Or(tvm.tir.Or(x == 1, y == 1), z == 1),