Skip to content

Commit

Permalink
feat: correctness of the translation from formulas to mathlib automata
Browse files Browse the repository at this point in the history
The proof that the automaton recognizes the solutions of the formula is
now sorry free.
  • Loading branch information
ineol committed Nov 14, 2024
1 parent 1c4dc9e commit ea91313
Show file tree
Hide file tree
Showing 9 changed files with 3,118 additions and 613 deletions.
954 changes: 903 additions & 51 deletions SSA/Experimental/Bits/AutoStructs/Basic.lean

Large diffs are not rendered by default.

439 changes: 248 additions & 191 deletions SSA/Experimental/Bits/AutoStructs/Constructions.lean

Large diffs are not rendered by default.

287 changes: 271 additions & 16 deletions SSA/Experimental/Bits/AutoStructs/Defs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,29 @@ Released under Apache 2.0 license as described in the file LICENSE.
-/
import Mathlib.Data.Bool.Basic
import Mathlib.Data.Fin.Basic
import Mathlib.Tactic
import SSA.Projects.InstCombine.ForLean
import SSA.Experimental.Bits.Fast.BitStream
import SSA.Experimental.Bits.AutoStructs.ForMathlib

namespace AutoStructs

-- A bunch of maps from `Fin n` to `Fin m` that we use to
-- lift and project variables when we interpret formulas
def liftMaxSucc1 (n m : Nat) : Fin (n + 1) → Fin (max n m + 2) :=
fun k => if _ : k = n then Fin.last (max n m) else k.castLE (by omega)
def liftMaxSucc2 (n m : Nat) : Fin (m + 1) → Fin (max n m + 2) :=
fun k => if _ : k = m then Fin.last (max n m + 1) else k.castLE (by omega)
def liftLast2 n : Fin 2 → Fin (n + 2)
| 0 => n
| 1 => Fin.last (n + 1)
def liftExcept2 n : Fin n → Fin (n + 2) :=
fun k => Fin.castLE (by omega) k
def liftMax1 (n m : Nat) : Fin n → Fin (max n m) :=
fun k => k.castLE (by omega)
def liftMax2 (n m : Nat) : Fin m → Fin (max n m) :=
fun k => k.castLE (by omega)

/-!
# Term Language
This file defines the term language the decision procedure operates on,
Expand Down Expand Up @@ -39,10 +57,6 @@ inductive Term : Type
| sub : Term → Term → Term
/-- Negation -/
| neg : Term → Term
/-- Increment (i.e., add one) -/
| incr : Term → Term
/-- Decrement (i.e., subtract one) -/
| decr : Term → Term
deriving Repr
-- /-- `repeatBit` is an operation that will repeat the infinitely repeat the
-- least significant `true` bit of the input.
Expand Down Expand Up @@ -81,8 +95,6 @@ a term like `var 10` only has a single free variable, but its arity will be `11`
| add t₁ t₂ => max (arity t₁) (arity t₂)
| sub t₁ t₂ => max (arity t₁) (arity t₂)
| neg t => arity t
| incr t => arity t
| decr t => arity t
-- | repeatBit t => arity t


Expand Down Expand Up @@ -122,13 +134,19 @@ and only require that many bitstream values to be given in `vars`.
let x₂ := t₂.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
x₁ - x₂
| neg t => -(Term.evalFin t vars)
| incr t => Term.evalFin t vars + 1
| decr t => Term.evalFin t vars - 1
-- | repeatBit t => BitStream.repeatBit (Term.evalFin t vars)

lemma evalFin_eq {t : Term} {vars1 : Fin t.arity → BitVec w1} {vars2 : Fin t.arity → BitVec w2} :
∀ (heq : w1 = w2),
(∀ n, vars1 n = heq ▸ vars2 n) →
t.evalFin vars1 = heq ▸ t.evalFin vars2 := by
rintro rfl heqs
simp only
congr; ext1; simp_all

@[simp] def Term.evalNat (t : Term) (vars : Nat → BitVec w) : BitVec w :=
match t with
| var n => vars (Fin.last n)
| var n => vars n
| zero => BitVec.zero w
| one => 1
| negOne => -1
Expand All @@ -155,8 +173,6 @@ and only require that many bitstream values to be given in `vars`.
let x₂ := t₂.evalNat vars
x₁ - x₂
| neg t => -(Term.evalNat t vars)
| incr t => Term.evalNat t vars + 1
| decr t => Term.evalNat t vars - 1
-- | repeatBit t => BitStream.repeatBit (Term.evalFin t vars)
@[simp] def Term.evalFinStream (t : Term) (vars : Fin (arity t) → BitStream) : BitStream :=
match t with
Expand Down Expand Up @@ -186,8 +202,9 @@ and only require that many bitstream values to be given in `vars`.
let x₂ := t₂.evalFinStream (fun i => vars (Fin.castLE (by simp [arity]) i))
x₁ - x₂
| neg t => -(Term.evalFinStream t vars)
| incr t => BitStream.incr (Term.evalFinStream t vars)
| decr t => BitStream.decr (Term.evalFinStream t vars)

def Term.language (t : Term) : Set (BitVecs (t.arity + 1)) :=
{ bvs : BitVecs (t.arity + 1) | t.evalFin (fun n => bvs.bvs.get n) = bvs.bvs.get t.arity }

inductive RelationOrdering
| lt | le | gt | ge
Expand All @@ -199,8 +216,7 @@ inductive Relation
| unsigned (ord : RelationOrdering)
deriving Repr

@[simp]
def evalRelation {w} (rel : Relation) (bv1 bv2 : BitVec w) : Bool :=
def evalRelation (rel : Relation) {w} (bv1 bv2 : BitVec w) : Bool :=
match rel with
| .eq => bv1 = bv2
| .signed .lt => bv1 <ₛ bv2
Expand All @@ -212,11 +228,19 @@ def evalRelation {w} (rel : Relation) (bv1 bv2 : BitVec w) : Bool :=
| .unsigned .gt => bv1 >ᵤ bv2
| .unsigned .ge => bv1 ≥ᵤ bv2

@[simp]
lemma evalRelation_coe (rel : Relation) (bv1 bv2 : BitVec w1) (heq : w1 = w2) :
evalRelation rel (heq ▸ bv1) (heq ▸ bv2) = evalRelation rel bv1 bv2 := by
rcases heq; simp

@[simp]
def Relation.language (rel : Relation) : Set (BitVecs 2) :=
{ bvs | evalRelation rel (bvs.bvs.get 0) (bvs.bvs.get 1) }

inductive Binop
| and | or | impl | equiv
deriving Repr

@[simp]
def evalBinop (op : Binop) (b1 b2 : Bool) : Bool :=
match op with
| .and => b1 && b2
Expand All @@ -231,6 +255,14 @@ def evalBinop' (op : Binop) (b1 b2 : Prop) : Prop :=
| .or => b1 ∨ b2
| .impl => b1 → b2
| .equiv => b1 ↔ b2

def langBinop (op : Binop) (l1 l2 : Set (BitVecs n)) : Set (BitVecs n) :=
match op with
| .and => l1 ∩ l2
| .or => l1 ∪ l2
| .impl => l1ᶜ ∪ l2
| .equiv => (l1ᶜ ∪ l2) ∩ (l2ᶜ ∪ l1)

inductive Unop
| neg
deriving Repr
Expand Down Expand Up @@ -265,6 +297,206 @@ def Formula.sat {w : Nat} (φ : Formula) (ρ : Fin φ.arity → BitVec w) : Bool
evalBinop op b1 b2
| .msbSet t => (t.evalFin ρ).msb

@[simp]
def _root_.Set.lift (f : Fin n → Fin m) (bvs : Set (BitVecs n)) : Set (BitVecs m) :=
BitVecs.transport f ⁻¹' bvs

@[simp]
def _root_.Set.proj (f : Fin n → Fin m) (bvs : Set (BitVecs m)) : Set (BitVecs n) :=
BitVecs.transport f '' bvs

@[simp]
def langMsb : Set (BitVecs 1) := { bvs | bvs.bvs.get 0 |>.msb }

@[simp]
def Formula.language (φ : Formula) : Set (BitVecs φ.arity) :=
match φ with
| .atom rel t1 t2 =>
let l1 := t1.language.lift (liftMaxSucc1 (FinEnum.card $ Fin t1.arity) (FinEnum.card $ Fin t2.arity))
let l2 := t2.language.lift (liftMaxSucc2 (FinEnum.card $ Fin t1.arity) (FinEnum.card $ Fin t2.arity))
let lrel := rel.language.lift $ liftLast2 (max (FinEnum.card (Fin t1.arity)) (FinEnum.card (Fin t2.arity)))
let l := lrel ∩ l1 ∩ l2
l.proj (liftExcept2 _)
| .unop .neg φ => φ.languageᶜ
| .binop op φ1 φ2 =>
let l1 := φ1.language.lift $ liftMax1 φ1.arity φ2.arity
let l2 := φ2.language.lift $ liftMax2 φ1.arity φ2.arity
langBinop op l1 l2
| .msbSet t =>
let lmsb := langMsb.lift $ fun _ => Fin.last t.arity
let l' := t.language ∩ lmsb
l'.proj fun n => n.castLE (by simp [Formula.arity, FinEnum.card])

lemma helper1 : (k = 0) → (x ::ᵥ vs).get k = x := by rintro rfl; simp
lemma helper2 : (k = 1) → (x ::ᵥ y ::ᵥ vs).get k = y := by rintro rfl; simp [Mathlib.Vector.get]
lemma msb_coe {x : BitVec w1} (heq : w1 = w2) : x.msb = (heq ▸ x).msb := by rcases heq; simp

lemma formula_language_case_atom :
let φ := Formula.atom rel t1 t2
φ.language = λ (bvs : BitVecs φ.arity) => (φ.sat (fun k => bvs.bvs.get k) = true) := by
unfold Formula.language
rintro φ
let n := φ.arity
unfold φ
dsimp (config := { zeta := false })
lift_lets
intros l1 l2 lrel l
ext bvs
constructor
· intros h; simp at h
obtain ⟨bvsb, h, heqb⟩ := h
unfold l at h
simp at h
unfold lrel l1 l2 at h
obtain ⟨⟨hrel, h1⟩, h2⟩ := h
have _ : n+1 < bvsb.bvs.length := by simp_all [n]
have _ : n < bvsb.bvs.length := by simp_all [n]
have hrel : evalRelation rel (bvsb.bvs.get n) (bvsb.bvs.get (Fin.last (n + 1))) := by
simp at hrel
apply hrel
have ht1 : bvsb.bvs.get n = t1.evalFin fun n => bvsb.bvs.get n := by
unfold Term.language at h1
simp [Mathlib.Vector.transport, liftMaxSucc1] at h1
unfold n; simp; rw [←h1]
congr; ext1 k
split_ifs with h
· exfalso
have _ : k < t1.arity := by simp
have _ : k = t1.arity := by rcases k with ⟨k, hk⟩; simp_all [Fin.last]
omega
· congr; ext; simp; rw [Nat.mod_eq_of_lt]; omega
have ht2 : bvsb.bvs.get (Fin.last (n+1)) = t2.evalFin fun n => bvsb.bvs.get n := by
unfold Term.language at h2
simp [Mathlib.Vector.transport, liftMaxSucc2] at h2
unfold n; simp only [Formula.arity, Fin.natCast_self]; rw [←h2]
congr; ext1 k
split_ifs with h
· exfalso
have _ : k < t2.arity := by simp
have _ : k = t2.arity := by rcases k with ⟨k, hk⟩; simp_all [Fin.last]
omega
· congr; ext; simp; rw [Nat.mod_eq_of_lt]; omega
have hw : bvsb.w = bvs.w := by rw [←heqb]; simp
have heq1 : (t1.evalFin fun n => bvsb.bvs.get n) =
hw ▸ t1.evalFin fun n => bvs.bvs.get $ n.castLE (by simp) := by
apply evalFin_eq hw; intros k
rcases bvs with ⟨w, bvs⟩; rcases hw
injection heqb with _ heqb; rw [←heqb]
simp [Mathlib.Vector.transport, liftExcept2]
congr; ext; simp; omega
have heq2 : (t2.evalFin fun n => bvsb.bvs.get n) =
hw ▸ t2.evalFin fun n => bvs.bvs.get $ n.castLE (by simp) := by
apply evalFin_eq hw; intros k
rcases bvs with ⟨w, bvs⟩; rcases hw
injection heqb with _ heqb; rw [←heqb]
simp [Mathlib.Vector.transport, liftExcept2]
congr; ext; simp; omega
rw [ht1, ht2, heq1, heq2, evalRelation_coe] at hrel
dsimp only [Set.instMembership, Set.Mem]
simp_all
· intros h
simp
let bv1 := t1.evalFin fun k => bvs.bvs.get $ k.castLE (by simp)
let bv2 := t2.evalFin fun k => bvs.bvs.get $ k.castLE (by simp)
use ⟨bvs.w, bvs.bvs.append $ bv1 ::ᵥ bv2 ::ᵥ Mathlib.Vector.nil⟩
rcases bvs with ⟨w, bvs⟩
simp
constructor
· unfold l; simp; split_ands
· unfold lrel; simp only [Fin.isValue, BitVecs.transport_getElem,
liftLast2, Set.mem_setOf_eq, Fin.val_last, le_add_iff_nonneg_right, zero_le,
Mathlib.Vector.append_get_ge]
rw [Mathlib.Vector.append_get_ge (by dsimp; rw [Nat.mod_eq_of_lt]; omega)]
simp [Set.instMembership, Set.Mem] at h
convert h using 2
· apply helper1; ext; simp; rw [Nat.mod_eq_of_lt] <;> omega
· apply helper2; ext; simp
· unfold l1 Term.language; simp [Mathlib.Vector.transport, liftMaxSucc1]
rw [Mathlib.Vector.append_get_ge (by dsimp; rw [Nat.mod_eq_of_lt]; omega)]
rw [helper1 (by ext; simp; rw [Nat.mod_eq_of_lt] <;> omega)]
unfold bv1
congr; ext1 k; split_ifs
· exfalso
have _ : k < t1.arity := by simp
have _ : k = t1.arity := by rcases k with ⟨k, hk⟩; simp_all [Fin.last]
omega
· simp; congr 1
· unfold l2 Term.language; simp [Mathlib.Vector.transport, liftMaxSucc2]
rw [helper2 (by ext; simp)]
unfold bv2
congr; ext1 k; split_ifs
· exfalso
have _ : k < t2.arity := by simp
have _ : k = t2.arity := by rcases k with ⟨k, hk⟩; simp_all [Fin.last]
omega
· simp; congr 1
· ext1; simp
next i =>
simp [Mathlib.Vector.transport, liftExcept2]
rw [Mathlib.Vector.append_get_lt i.isLt]
congr 1

theorem formula_language (φ : Formula) :
φ.language = { (bvs : BitVecs φ.arity) | φ.sat (fun k => bvs.bvs.get k) = true } := by
let n : Nat := φ.arity
induction φ
case atom rel t1 t2 =>
apply formula_language_case_atom
case unop op φ ih =>
rcases op; simp [ih, Set.compl_def]
case binop op φ1 φ2 ih1 ih2 =>
unfold Formula.language
ext1 bvs
simp [ih1, ih2]
have heq1 : (φ1.sat fun k => bvs.bvs.get (liftMax1 φ1.arity φ2.arity k)) = true
1.sat fun n => bvs.bvs.get (Fin.castLE (by simp) n)) = true := by
simp; congr
have heq2 : (φ2.sat fun k => bvs.bvs.get (liftMax2 φ1.arity φ2.arity k)) = true
2.sat fun n => bvs.bvs.get (Fin.castLE (by simp) n)) = true := by
simp; congr
rcases op <;>
simp [evalBinop, langBinop, Set.compl, Set.instMembership,
Set.Mem, Mathlib.Vector.transport] <;> aesop
case msbSet t =>
ext1 bvs; simp only [Formula.arity, Formula.language, Set.proj, Set.lift, langMsb, Fin.isValue,
Set.preimage_setOf_eq, Set.mem_image, Set.mem_inter_iff,
Set.mem_setOf_eq, Formula.sat]
rcases bvs with ⟨w, bvs⟩
constructor
· rintro ⟨bvsb, ⟨ht, hmsb⟩, heq⟩
simp only [Fin.isValue, Formula.arity] at ht hmsb ⊢
unfold Term.language at ht
simp only [BitVecs.transport, Mathlib.Vector.transport] at hmsb
simp at ht; rw [←ht] at hmsb; rw [←hmsb]
simp [BitVecs.transport] at heq
obtain ⟨hw, hbvs⟩ := heq
simp [hw]; congr 1; simp [hw]
rcases hw; simp
congr 1; ext1 k
simp at hbvs; simp [←hbvs, Mathlib.Vector.transport]; congr
· intros heq
use ⟨w,
bvs.append ((t.evalFin fun k => bvs.get $ k.castLE (by simp)) ::ᵥ Mathlib.Vector.nil)⟩
unfold Term.language
simp [BitVecs.transport, Mathlib.Vector.transport] at heq ⊢
constructor; assumption
ext1 k; simp; congr 1

/--
The formula `φ` is true for evey valuation.
-/
@[simp]
def Formula.Tautology (φ : Formula) := φ.language = ⊤

/--
The formula `φ` is true for evey valuation made up of non-empty bitvectors.
-/
@[simp]
def Formula.Tautology' (φ : Formula) := φ.language ∪ BitVecs0 = ⊤

/--
Same as `Formula.sat` but the environment is indexed by unbounded natural number.
-/
@[simp]
def Formula.sat' {w : Nat} (φ : Formula) (ρ : Nat → BitVec w) : Prop :=
match φ with
Expand All @@ -279,6 +511,29 @@ def Formula.sat' {w : Nat} (φ : Formula) (ρ : Nat → BitVec w) : Prop :=
evalBinop' op b1 b2
| .msbSet t => (t.evalNat ρ).msb

lemma evalFin_evalNat (t : Term):
t.evalFin (fun k => ρ k.val) = t.evalNat ρ := by
induction t <;> simp_all

lemma sat_impl_sat' {φ : Formula} :
(φ.sat fun k => ρ k.val) ↔ φ.sat' ρ := by
induction φ
case atom rel t1 t2 =>
simp [←evalFin_evalNat]
case binop op φ1 φ2 ih1 ih2 =>
simp [evalBinop, ←ih1, ←ih2]
rcases op <;> simp
rw [←Bool.eq_false_eq_not_eq_true]
tauto
case unop op φ ih => rcases op; simp [←ih]
case msbSet t =>
simp [←evalFin_evalNat]

lemma env_to_bvs (φ : Formula) (ρ : Fin φ.arity → BitVec w) :
let bvs : BitVecs φ.arity := ⟨w, Mathlib.Vector.ofFn fun k => ρ k⟩
ρ = fun k => bvs.bvs.get k := by
simp

@[simp]
abbrev envOfArray {w} (a : Array (BitVec w)) : Nat → BitVec w := fun n => a.getD n 0

Expand Down
Loading

0 comments on commit ea91313

Please sign in to comment.