Skip to content

Commit

Permalink
feat: simp +arith sorts linear atoms (#7040)
Browse files Browse the repository at this point in the history
This PR ensures that terms such as `f (2*x + y)` and `f (y + x + x)`
have the same normal form when using `simp +arith`
  • Loading branch information
leodemoura authored Feb 11, 2025
1 parent 0f1133f commit b87c01b
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 33 deletions.
32 changes: 16 additions & 16 deletions src/Init/Data/Nat/Linear.lean
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ inductive Expr where
deriving Inhabited

def Expr.denote (ctx : Context) : Expr → Nat
| Expr.add a b => Nat.add (denote ctx a) (denote ctx b)
| Expr.num k => k
| Expr.var v => v.denote ctx
| Expr.mulL k e => Nat.mul k (denote ctx e)
| Expr.mulR e k => Nat.mul (denote ctx e) k
| .add a b => Nat.add (denote ctx a) (denote ctx b)
| .num k => k
| .var v => v.denote ctx
| .mulL k e => Nat.mul k (denote ctx e)
| .mulR e k => Nat.mul (denote ctx e) k

abbrev Poly := List (Nat × Var)

Expand Down Expand Up @@ -146,17 +146,17 @@ where
-- Implementation note: This assembles the result using difference lists
-- to avoid `++` on lists.
go (coeff : Nat) : Expr → (Poly → Poly)
| Expr.num k => bif k == 0 then id else ((coeff * k, fixedVar) :: ·)
| Expr.var i => ((coeff, i) :: ·)
| Expr.add a b => go coeff a ∘ go coeff b
| Expr.mulL k a
| Expr.mulR a k => bif k == 0 then id else go (coeff * k) a
| .num k => bif k == 0 then id else ((coeff * k, fixedVar) :: ·)
| .var i => ((coeff, i) :: ·)
| .add a b => go coeff a ∘ go coeff b
| .mulL k a
| .mulR a k => bif k == 0 then id else go (coeff * k) a

def Expr.toNormPoly (e : Expr) : Poly :=
e.toPoly.norm

def Expr.inc (e : Expr) : Expr :=
Expr.add e (Expr.num 1)
.add e (.num 1)

structure PolyCnstr where
eq : Bool
Expand Down Expand Up @@ -244,21 +244,21 @@ def Certificate.denote (ctx : Context) (c : Certificate) : Prop :=

def monomialToExpr (k : Nat) (v : Var) : Expr :=
bif v == fixedVar then
Expr.num k
.num k
else bif k == 1 then
Expr.var v
.var v
else
Expr.mulL k (Expr.var v)
.mulL k (.var v)

def Poly.toExpr (p : Poly) : Expr :=
match p with
| [] => Expr.num 0
| [] => .num 0
| (k, v) :: p => go (monomialToExpr k v) p
where
go (e : Expr) (p : Poly) : Expr :=
match p with
| [] => e
| (k, v) :: p => go (Expr.add e (monomialToExpr k v)) p
| (k, v) :: p => go (.add e (monomialToExpr k v)) p

def PolyCnstr.toExpr (c : PolyCnstr) : ExprCnstr :=
{ c with lhs := c.lhs.toExpr, rhs := c.rhs.toExpr }
Expand Down
38 changes: 37 additions & 1 deletion src/Lean/Meta/Tactic/LinearArith/Int/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Authors: Leonardo de Moura
-/
prelude
import Init.Data.Int.Linear
import Lean.Util.SortExprs
import Lean.Meta.Check
import Lean.Meta.Offset
import Lean.Meta.IntInstTesters
Expand All @@ -31,6 +32,24 @@ def PolyCnstr.toExprCnstr : PolyCnstr → ExprCnstr
| .eq p => .eq p.toExpr (.num 0)
| .le p => .le p.toExpr (.num 0)

/-- Applies the given variable permutation to `e` -/
def Expr.applyPerm (perm : Lean.Perm) (e : Expr) : Expr :=
go e
where
go : Expr → Expr
| .num v => .num v
| .var i => .var (perm[(i : Nat)]?.getD i)
| .neg a => .neg (go a)
| .add a b => .add (go a) (go b)
| .sub a b => .sub (go a) (go b)
| .mulL k a => .mulL k (go a)
| .mulR a k => .mulR (go a) k

/-- Applies the given variable permutation to the given expression constraint. -/
def ExprCnstr.applyPerm (perm : Lean.Perm) : ExprCnstr → ExprCnstr
| .eq a b => .eq (a.applyPerm perm) (b.applyPerm perm)
| .le a b => .le (a.applyPerm perm) (b.applyPerm perm)

end Int.Linear

namespace Lean.Meta.Linear.Int
Expand Down Expand Up @@ -187,7 +206,24 @@ def run (x : M α) : MetaM (α × Array Expr) := do

end ToLinear

export ToLinear (toLinearCnstr? toLinearExpr)
def toLinearExpr (e : Expr) : MetaM (LinearExpr × Array Expr) := do
let (e, atoms) ← ToLinear.run (ToLinear.toLinearExpr e)
if atoms.size == 1 then
return (e, atoms)
else
let (atoms, perm) := sortExprs atoms
let e := e.applyPerm perm
return (e, atoms)

def toLinearCnstr? (e : Expr) : MetaM (Option (LinearCnstr × Array Expr)) := do
let (some c, atoms) ← ToLinear.run (ToLinear.toLinearCnstr? e)
| return none
if atoms.size <= 1 then
return some (c, atoms)
else
let (atoms, perm) := sortExprs atoms
let c := c.applyPerm perm
return some (c, atoms)

def toContextExpr (ctx : Array Expr) : Expr :=
if h : 0 < ctx.size then
Expand Down
8 changes: 4 additions & 4 deletions src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def Int.Linear.PolyCnstr.getConst : PolyCnstr → Int
namespace Lean.Meta.Linear.Int

def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do
let (some c, atoms) ← ToLinear.run (ToLinear.toLinearCnstr? e) | return none
let some (c, atoms) ← toLinearCnstr? e | return none
withAbstractAtoms atoms ``Int fun atoms => do
let lhs ← c.toArith atoms
let p := c.toPoly
Expand Down Expand Up @@ -127,13 +127,13 @@ def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
simpCnstrPos? e

def simpExpr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
let (e, ctx) ← ToLinear.run (ToLinear.toLinearExpr e)
let (e, atoms) ← toLinearExpr e
let p := e.toPoly
let e' := p.toExpr
if e != e' then
-- We only return some if monomials were fused
let p := mkApp4 (mkConst ``Int.Linear.Expr.eq_of_toPoly_eq) (toContextExpr ctx) (toExpr e) (toExpr e') reflBoolTrue
let r ← LinearExpr.toArith ctx e'
let p := mkApp4 (mkConst ``Int.Linear.Expr.eq_of_toPoly_eq) (toContextExpr atoms) (toExpr e) (toExpr e') reflBoolTrue
let r ← LinearExpr.toArith atoms e'
return some (r, p)
else
return none
Expand Down
41 changes: 39 additions & 2 deletions src/Lean/Meta/Tactic/LinearArith/Nat/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,32 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Util.SortExprs
import Lean.Meta.Check
import Lean.Meta.Offset
import Lean.Meta.AppBuilder
import Lean.Meta.KExprMap
import Lean.Data.RArray

namespace Nat.Linear

/-- Applies the given variable permutation to `e` -/
def Expr.applyPerm (perm : Lean.Perm) (e : Expr) : Expr :=
go e
where
go : Expr → Expr
| .num v => .num v
| .var i => .var (perm[(i : Nat)]?.getD i)
| .add a b => .add (go a) (go b)
| .mulL k a => .mulL k (go a)
| .mulR a k => .mulR (go a) k

/-- Applies the given variable permutation to the given expression constraint. -/
def ExprCnstr.applyPerm (perm : Lean.Perm) : ExprCnstr → ExprCnstr
| { eq, lhs, rhs } => { eq, lhs := lhs.applyPerm perm, rhs := rhs.applyPerm perm }

end Nat.Linear

namespace Lean.Meta.Linear.Nat

deriving instance Repr for Nat.Linear.Expr
Expand Down Expand Up @@ -140,12 +160,29 @@ def run (x : M α) : MetaM (α × Array Expr) := do

end ToLinear

export ToLinear (toLinearCnstr? toLinearExpr)
def toLinearExpr (e : Expr) : MetaM (LinearExpr × Array Expr) := do
let (e, atoms) ← ToLinear.run (ToLinear.toLinearExpr e)
if atoms.size == 1 then
return (e, atoms)
else
let (atoms, perm) := sortExprs atoms
let e := e.applyPerm perm
return (e, atoms)

def toLinearCnstr? (e : Expr) : MetaM (Option (LinearCnstr × Array Expr)) := do
let (some c, atoms) ← ToLinear.run (ToLinear.toLinearCnstr? e)
| return none
if atoms.size <= 1 then
return some (c, atoms)
else
let (atoms, perm) := sortExprs atoms
let c := c.applyPerm perm
return some (c, atoms)

def toContextExpr (ctx : Array Expr) : Expr :=
if h : 0 < ctx.size then
RArray.toExpr (mkConst ``Nat) id (RArray.ofArray ctx h)
else
RArray.toExpr (mkConst ``Nat) id (RArray.leaf (mkNatLit 0))

end Lean.Meta.Linear.Nat
namespace Lean.Meta.Linear.Nat
5 changes: 3 additions & 2 deletions src/Lean/Meta/Tactic/LinearArith/Nat/Simp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import Lean.Meta.Tactic.LinearArith.Nat.Basic
namespace Lean.Meta.Linear.Nat

def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do
let (some c, atoms) ← ToLinear.run (ToLinear.toLinearCnstr? e) | return none
let some (c, atoms) ← toLinearCnstr? e
| return none
withAbstractAtoms atoms ``Nat fun atoms => do
let lhs ← c.toArith atoms
let c₁ := c.toPoly
Expand Down Expand Up @@ -67,7 +68,7 @@ def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
simpCnstrPos? e

def simpExpr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
let (e, ctx) ← ToLinear.run (ToLinear.toLinearExpr e)
let (e, ctx) ← toLinearExpr e
let p := e.toPoly
let p' := p.norm
if p'.length < p.length then
Expand Down
1 change: 1 addition & 0 deletions src/Lean/Util.lean
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@ import Lean.Util.SafeExponentiation
import Lean.Util.NumObjs
import Lean.Util.NumApps
import Lean.Util.FVarSubset
import Lean.Util.SortExprs
23 changes: 23 additions & 0 deletions src/Lean/Util/SortExprs.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Expr

namespace Lean

abbrev Perm := Std.HashMap Nat Nat

/--
Sorts the given expressions using `Expr.lt`, and creates a "permutation map" storing the new position of each expression.
-/
def sortExprs (es : Array Expr) : Array Expr × Perm :=
let es := es.mapIdx fun i e => (e, i)
let es := es.qsort fun (e₁, _) (e₂, _) => e₁.lt e₂
let (_, perm) := es.foldl (init := (0, Std.HashMap.empty)) fun (i, perm) (_, j) => (i+1, perm.insert j i)
let es := es.map (·.1)
(es, perm)

end Lean
25 changes: 17 additions & 8 deletions tests/lean/run/simp_int_arith.lean
Original file line number Diff line number Diff line change
Expand Up @@ -171,18 +171,18 @@ fun x y z f =>
id
(Int.Linear.ExprCnstr.eq_true_of_isValid
(Lean.RArray.branch 1 (Lean.RArray.leaf x)
(Lean.RArray.branch 2 (Lean.RArray.leaf x_1) (Lean.RArray.leaf z)))
(Lean.RArray.branch 2 (Lean.RArray.leaf z) (Lean.RArray.leaf x_1)))
(Int.Linear.ExprCnstr.le
((((((Int.Linear.Expr.var 0).add (Int.Linear.Expr.var 1)).add (Int.Linear.Expr.num 2)).add
(Int.Linear.Expr.var 1)).add
(Int.Linear.Expr.var 2)).add
(Int.Linear.Expr.var 2))
(((((((Int.Linear.Expr.var 1).add (Int.Linear.Expr.mulL 3 (Int.Linear.Expr.var 2))).add
((((((Int.Linear.Expr.var 0).add (Int.Linear.Expr.var 2)).add (Int.Linear.Expr.num 2)).add
(Int.Linear.Expr.var 2)).add
(Int.Linear.Expr.var 1)).add
(Int.Linear.Expr.var 1))
(((((((Int.Linear.Expr.var 2).add (Int.Linear.Expr.mulL 3 (Int.Linear.Expr.var 1))).add
(Int.Linear.Expr.num 1)).add
(Int.Linear.Expr.num 1)).add
(Int.Linear.Expr.var 0)).add
(Int.Linear.Expr.var 1)).sub
(Int.Linear.Expr.var 2)))
(Int.Linear.Expr.var 2)).sub
(Int.Linear.Expr.var 1)))
(Eq.refl true)))
(f y))
-/
Expand Down Expand Up @@ -256,3 +256,12 @@ example (x : Int) : (11*x ≤ 10) ↔ (x ≤ 0) := by

example (x : Int) : (11*x > 10) ↔ (x ≥ 1) := by
simp +arith only

example (x y : Int) : (2*x + y + y = 4) ↔ (y + x = 2) := by
simp +arith

example (x y : Int) : (2*x + y + y ≤ 3) ↔ (y + x ≤ 1) := by
simp +arith

example (f : Int → Int) (x y : Int) : f (2*x + y) = f (y + x + x) := by
simp +arith
8 changes: 8 additions & 0 deletions tests/lean/run/simp_nat_arith.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
example (x y : Nat) : (2*x + y = 4) ↔ (y + x + x = 4) := by
simp +arith

example (x y : Nat) : (2*x + y ≤ 3) ↔ (y + x + x ≤ 3) := by
simp +arith

example (f : Nat → Nat) (x y : Nat) : f (2*x + y) = f (y + x + x) := by
simp +arith

0 comments on commit b87c01b

Please sign in to comment.