From 7961d5712b84233d1a86b6e0637fc4a57144cf26 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Date: Tue, 3 Dec 2024 11:49:16 +0100
Subject: [PATCH] Don't apply `local_add_neg_to_sub` rewrite if negative
 variabe is a constant

---
 pytensor/tensor/rewriting/math.py   | 80 ++++++++++++++++++-----------
 tests/tensor/rewriting/test_math.py | 19 -------
 2 files changed, 50 insertions(+), 49 deletions(-)

diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index f36a58fcc3..aa2d279f43 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -535,30 +535,59 @@ def local_mul_pow_to_pow_add(fgraph, node):
 @register_stabilize
 @register_specialize
 @register_canonicalize
-@node_rewriter([sub])
+@node_rewriter([add, sub])
 def local_expm1(fgraph, node):
-    """Detect ``exp(a) - 1`` and convert them to ``expm1(a)``."""
-    in1, in2 = node.inputs
-    out = node.outputs[0]
+    """Detect ``exp(a) - 1`` or ``-1 + exp(a)`` and convert them to ``expm1(a)``."""
+    if len(node.inputs) != 2:
+        # TODO: handle more than two inputs in add
+        return None
 
-    if (
-        in1.owner
-        and isinstance(in1.owner.op, Elemwise)
-        and isinstance(in1.owner.op.scalar_op, ps.Exp)
-        and get_underlying_scalar_constant_value(in2, raise_not_constant=False) == 1
-    ):
-        in11 = in1.owner.inputs[0]
-        new_out = expm1(in11)
+    if isinstance(node.op.scalar_op, ps.Sub):
+        exp_x, other_inp = node.inputs
+        if not (
+            exp_x.owner
+            and isinstance(exp_x.owner.op, Elemwise)
+            and isinstance(exp_x.owner.op.scalar_op, ps.Exp)
+            and get_underlying_scalar_constant_value(
+                other_inp, raise_not_constant=False
+            )
+            == 1
+        ):
+            return None
+    else:
+        # Try both orders
+        other_inp, exp_x = node.inputs
+        for i in range(2):
+            if i == 1:
+                other_inp, exp_x = exp_x, other_inp
+            if (
+                exp_x.owner
+                and isinstance(exp_x.owner.op, Elemwise)
+                and isinstance(exp_x.owner.op.scalar_op, ps.Exp)
+                and get_underlying_scalar_constant_value(
+                    other_inp, raise_not_constant=False
+                )
+                == -1
+            ):
+                break
+        else:  # no break
+            return None
 
-        if new_out.type.broadcastable != out.type.broadcastable:
-            new_out = broadcast_arrays(in11, in2)[0]
+    [old_out] = node.outputs
 
-        if new_out.dtype != out.dtype:
-            new_out = cast(new_out, dtype=out.dtype)
+    [x] = exp_x.owner.inputs
+    if x.type.broadcastable != old_out.type.broadcastable:
+        x = broadcast_arrays(x, other_inp)[0]
 
-        if not out.type.is_super(new_out.type):
-            return
-        return [new_out]
+    new_out = expm1(x)
+
+    if new_out.dtype != old_out.dtype:
+        new_out = cast(new_out, dtype=old_out.dtype)
+
+    if not old_out.type.is_super(new_out.type):
+        return None
+
+    return [new_out]
 
 
 @register_specialize
@@ -1824,15 +1853,6 @@ def local_add_neg_to_sub(fgraph, node):
                     new_out = sub(first, pre_neg)
                     return [new_out]
 
-            # Check if it is a negative constant
-            if (
-                isinstance(second, TensorConstant)
-                and second.unique_value is not None
-                and second.unique_value < 0
-            ):
-                new_out = sub(first, np.abs(second.data))
-                return [new_out]
-
 
 @register_canonicalize
 @node_rewriter([mul])
@@ -2606,9 +2626,9 @@ def local_greedy_distributor(fgraph, node):
 register_stabilize(local_one_minus_erfc)
 register_specialize(local_one_minus_erfc)
 
-# erfc(-x)-1=>erf(x)
+# -1 + erfc(-x)=>erf(x)
 local_erf_neg_minus_one = PatternNodeRewriter(
-    (sub, (erfc, (neg, "x")), 1),
+    (add, -1, (erfc, (neg, "x"))),
     (erf, "x"),
     allow_multiple_clients=True,
     name="local_erf_neg_minus_one",
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index f2f421c6a5..4ff1a29048 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -4440,25 +4440,6 @@ def test_local_add_neg_to_sub(first_negative):
     assert np.allclose(f(x_test, y_test), exp)
 
 
-@pytest.mark.parametrize("const_left", (True, False))
-def test_local_add_neg_to_sub_const(const_left):
-    x = vector("x")
-    const = np.full((3, 2), 5.0)
-    out = -const + x if const_left else x + (-const)
-
-    f = function([x], out, mode=Mode("py"))
-
-    nodes = [
-        node.op
-        for node in f.maker.fgraph.toposort()
-        if not isinstance(node.op, DimShuffle | Alloc)
-    ]
-    assert nodes == [pt.sub]
-
-    x_test = np.array([3, 4], dtype=config.floatX)
-    assert np.allclose(f(x_test), x_test + (-const))
-
-
 def test_log1mexp_stabilization():
     mode = Mode("py").including("stabilize")