Skip to content

Commit

Permalink
Generalize NFA.correct to n-ary
Browse files Browse the repository at this point in the history
  • Loading branch information
ineol committed Nov 14, 2024
1 parent 05dbec6 commit 05953af
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 171 deletions.
161 changes: 0 additions & 161 deletions SSA/Experimental/Bits/AutoStructs/FiniteStateMachine.lean
Original file line number Diff line number Diff line change
Expand Up @@ -612,165 +612,4 @@ def termEvalEqFSM : ∀ (t : Term), FSMSolution t

abbrev FSM.ofTerm (t : Term) : FSM (Fin t.arity) := termEvalEqFSM t |>.toFSM

/-!
FSM that implement bitwise-and. Since we use `0` as the good state,
we keep the invariant that if both inputs are good and our state is `0`, then we produce a `0`.
If not, we produce an infinite sequence of `1`.
-/
def and : FSM Bool :=
{ α := Unit,
initCarry := fun _ => false,
nextBitCirc := fun a =>
match a with
| some () =>
-- Only if both are `0` we produce a `0`.
(Circuit.var true (inr false) |||
((Circuit.var false (inr true) |||
-- But if we have failed and have value `1`, then we produce a `1` from our state.
(Circuit.var true (inl ())))))
| none => -- must succeed in both arguments, so we are `0` if both are `0`.
Circuit.var true (inr true) |||
Circuit.var true (inr false)
}

/-!
FSM that implement bitwise-or. Since we use `0` as the good state,
we keep the invariant that if either inputs is `0` then our state is `0`.
If not, we produce a `1`.
-/
def or : FSM Bool :=
{ α := Unit,
initCarry := fun _ => false,
nextBitCirc := fun a =>
match a with
| some () =>
-- If either succeeds, then the full thing succeeds
((Circuit.var true (inr false) &&&
((Circuit.var false (inr true)) |||
-- On the other hand, if we have failed, then propagate failure.
(Circuit.var true (inl ())))))
| none => -- can succeed in either argument, so we are `0` if either is `0`.
Circuit.var true (inr true) &&&
Circuit.var true (inr false)
}

/-!
FSM that implement logical not.
we keep the invariant that if the input ever fails and becomes a `1`, then we produce a `0`.
IF not, we produce an infinite sequence of `1`.
EDIT: Aha, this doesn't work!
We need CNFA to DFA here (as the presburger book does),
where we must produce an infinite sequence of`0` iff the input can *ever* become a `1`.
But here, since we phrase things directly in terms of producing sequences, it's a bit less clear
what we should do :)
- Alternatively, we need to be able to decide `eventually always zero`.
- Alternatively, we push negations inside, and decide `⬝ ≠ ⬝` and `⬝ ≰ ⬝`.
-/

inductive Result : Type
| falseAfter (n : ℕ) : Result
| trueFor (n : ℕ) : Result
| trueForall : Result
deriving Repr, DecidableEq

def card_compl [Fintype α] [DecidableEq α] (c : Circuit α) : ℕ :=
Finset.card $ (@Finset.univ (α → Bool) _).filter (fun a => c.eval a = false)

theorem decideIfZeroAux_wf {α : Type _} [Fintype α] [DecidableEq α]
{c c' : Circuit α} (h : ¬c' ≤ c) : card_compl (c' ||| c) < card_compl c := by
apply Finset.card_lt_card
simp [Finset.ssubset_iff, Finset.subset_iff]
simp only [Circuit.le_def, not_forall, Bool.not_eq_true] at h
rcases h with ⟨x, hx, h⟩
use x
simp [hx, h]

def decideIfZerosAux {arity : Type _} [DecidableEq arity]
(p : FSM arity) (c : Circuit p.α) : Bool :=
if c.eval p.initCarry
then false
else
have c' := (c.bind (p.nextBitCirc ∘ some)).fst
if h : c' ≤ c then true
else
have _wf : card_compl (c' ||| c) < card_compl c :=
decideIfZeroAux_wf h
decideIfZerosAux p (c' ||| c)
termination_by card_compl c

def decideIfZeros {arity : Type _} [DecidableEq arity]
(p : FSM arity) : Bool :=
decideIfZerosAux p (p.nextBitCirc none).fst

theorem decideIfZerosAux_correct {arity : Type _} [DecidableEq arity]
(p : FSM arity) (c : Circuit p.α)
(hc : ∀ s, c.eval s = true
∃ m y, (p.changeInitCarry s).eval y m = true)
(hc₂ : ∀ (x : arity → Bool) (s : p.α → Bool),
(FSM.nextBit p s x).snd = true → Circuit.eval c s = true) :
decideIfZerosAux p c = true ↔ ∀ n x, p.eval x n = false := by
rw [decideIfZerosAux]
split_ifs with h
· simp
exact hc p.initCarry h
· dsimp
split_ifs with h'
· simp only [true_iff]
intro n x
rw [p.eval_eq_zero_of_set {x | c.eval x = true}]
· intro y s
simp [Circuit.le_def, Circuit.eval_fst, Circuit.eval_bind] at h'
simp [Circuit.eval_fst, FSM.nextBit]
apply h'
· assumption
· exact hc₂
· let c' := (c.bind (p.nextBitCirc ∘ some)).fst
have _wf : card_compl (c' ||| c) < card_compl c :=
decideIfZeroAux_wf h'
apply decideIfZerosAux_correct p (c' ||| c)
simp [c', Circuit.eval_fst, Circuit.eval_bind]
intro s hs
rcases hs with ⟨x, hx⟩ | h
· rcases hc _ hx with ⟨m, y, hmy⟩
use (m+1)
use fun a i => Nat.casesOn i x (fun i a => y a i) a
rw [FSM.eval_changeInitCarry_succ]
rw [← hmy]
simp only [FSM.nextBit, Nat.rec_zero, Nat.rec_add_one]
· exact hc _ h
· intro x s h
have := hc₂ _ _ h
simp only [Circuit.eval_bind, Bool.or_eq_true, Circuit.eval_fst,
Circuit.eval_or, this, or_true]
termination_by card_compl c

theorem decideIfZeros_correct {arity : Type _} [DecidableEq arity]
(p : FSM arity) : decideIfZeros p = true ↔ ∀ n x, p.eval x n = false := by
apply decideIfZerosAux_correct
· simp only [Circuit.eval_fst, forall_exists_index]
intro s x h
use 0
use (fun a _ => x a)
simpa [FSM.eval, FSM.changeInitCarry, FSM.nextBit, FSM.carry]
· simp only [Circuit.eval_fst]
intro x s h
use x
exact h

end FSM

/--
The fragment of predicate logic that we support in `bv_automata`.
Currently, we support equality, conjunction, disjunction, and negation.
This can be expanded to also support arithmetic constraints such as unsigned-less-than.
-/
inductive Predicate : Nat → Type _ where
| eq (t1 t2 : Term) : Predicate ((max t1.arity t2.arity))
| and (p : Predicate n) (q : Predicate m) : Predicate (max n m)
| or (p : Predicate n) (q : Predicate m) : Predicate (max n m)
-- For now, we can't prove `not`, because it needs CNFA → DFA conversion
-- the way Sid knows how to build it, or negation normal form,
-- both of which is machinery we lack.
-- | not (p : Predicate n) : Predicate n
54 changes: 44 additions & 10 deletions SSA/Experimental/Bits/AutoStructs/FormulaToAuto.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import Std.Data.HashMap
import Mathlib.Data.Fintype.Basic
import Mathlib.Data.Finset.Basic
import Mathlib.Data.FinEnum
import Mathlib.Data.Vector.Basic
import Mathlib.Data.Vector.Defs
import Mathlib.Tactic.FinCases
import SSA.Experimental.Bits.AutoStructs.Basic
import SSA.Experimental.Bits.AutoStructs.Constructions
Expand All @@ -15,6 +17,7 @@ import SSA.Experimental.Bits.AutoStructs.FiniteStateMachine
import SSA.Experimental.Bits.AutoStructs.GoodNFA

open AutoStructs
open Mathlib

section fsm

Expand All @@ -25,7 +28,8 @@ variable {arity : Type} [FinEnum arity]
def finFunToBitVec (c : carry → Bool) [FinEnum carry] : BitVec (FinEnum.card carry) :=
((FinEnum.toList carry).enum.map (fun (i, x) => c x |> Bool.toNat * 2^i)).foldl (init := 0) Nat.add |> BitVec.ofNat _

def bitVecToFinFun [FinEnum ar] (bv : BitVec $ FinEnum.card ar) : ar → Bool := fun c => bv[FinEnum.equiv.toFun c]
def bitVecToFinFun [FinEnum ar] (bv : BitVec $ FinEnum.card ar) : ar → Bool :=
fun c => bv[FinEnum.equiv.toFun c]

def NFA.ofFSM (p : FSM arity) : NFA (Alphabet arity) (p.α → Bool) where
start s := s = p.initCarry
Expand Down Expand Up @@ -106,8 +110,12 @@ lemma NFA.correct_spec {M : NFA α σ} {ζ : M.sa} {L : Language α} :
rfl

abbrev BVRel := ∀ ⦃w⦄, BitVec w → BitVec w → Prop
abbrev BVNRel n := ∀ ⦃w⦄, Vector (BitVec w) n → Prop

def langRel (R : BVRel) : Set (BitVecs 2) :=
def langRel (R : BVNRel n) : Set (BitVecs n) :=
{ bvs | R bvs.bvs }

def langRel2 (R : BVRel) : Set (BitVecs 2) :=
{ bvs | R (bvs.bvs.get 0) (bvs.bvs.get 1) }

def NFA.autEq : NFA (BitVec 2) Unit :=
Expand All @@ -133,21 +141,34 @@ lemma BitVec.ofFn_0 {f : Fin 0 → Bool} : ofFn f = .nil := by
apply eq_nil

@[simp]
lemma dec_snoc_in_langRel :
dec (w ++ [a]) ∈ langRel R ↔ R (.cons (a.getLsbD 0) ((dec w).bvs.get 0))
(.cons (a.getLsbD 1) ((dec w).bvs.get 1)) := by
lemma dec_snoc_in_langRel {n} {R : BVNRel n} {w : BitVecs' n} {a : BitVec n} :
dec (w ++ [a]) ∈ langRel R ↔
R (Vector.ofFn fun k => .cons (a.getLsbD k) ((dec w).bvs.get k)) := by
simp [langRel]

def GoodNFA.sa (M : GoodNFA n) := M.σ → (∀ ⦃w⦄, BitVec w → BitVec w → Prop)
@[simp]
lemma dec_snoc_in_langRel2 :
dec (w ++ [a]) ∈ langRel2 R ↔ R (.cons (a.getLsbD 0) ((dec w).bvs.get 0))
(.cons (a.getLsbD 1) ((dec w).bvs.get 1)) := by
simp [langRel2]

def GoodNFA.sa (M : GoodNFA n) := M.σ → BVNRel n
def GoodNFA.sa2 (M : GoodNFA 2) := M.σ → BVRel

structure GoodNFA.correct2 (M : GoodNFA 2) (ζ : M.sa) (L : BVRel) where
structure GoodNFA.correct (M : GoodNFA n) (ζ : M.sa) (L : BVNRel n) where
cond1 : ∀ ⦃w⦄ (bvn : Vector (BitVec w) n), (L bvn ↔ ∃ q ∈ M.M.accept, ζ q bvn)
cond2 q : q ∈ M.M.start ↔ ζ q (Vector.replicate n .nil)
cond3 q a {w} (bvn : Vector (BitVec w) n) : q ∈ M.M.stepSet { q | ζ q bvn } a ↔
ζ q (Vector.ofFn fun k => BitVec.cons (a.getLsbD k) (bvn.get k))

structure GoodNFA.correct2 (M : GoodNFA 2) (ζ : M.sa2) (L : BVRel) where
cond1 : ∀ (bv1 bv2 : BitVec w), (L bv1 bv2 ↔ ∃ q ∈ M.M.accept, ζ q bv1 bv2)
cond2 q : q ∈ M.M.start ↔ ζ q .nil .nil
cond3 q a w (bv1 bv2 : BitVec w) : q ∈ M.M.stepSet { q | ζ q bv1 bv2 } a ↔
ζ q (BitVec.cons (a.getLsbD 0) bv1) (BitVec.cons (a.getLsbD 1) bv2)

lemma GoodNFA.correct2_spec (M : GoodNFA 2) (ζ : M.sa) (L : BVRel) :
M.correct2 ζ L → M.accepts = langRel L := by
lemma GoodNFA.correct_spec (M : GoodNFA n) {ζ : M.sa} {L : BVNRel n} :
M.correct ζ L → M.accepts = langRel L := by
rintro ⟨h1, h2, h3⟩
simp [accepts, accepts']
have heq : dec '' (enc '' langRel L) = langRel L := by simp
Expand All @@ -159,7 +180,7 @@ lemma GoodNFA.correct2_spec (M : GoodNFA 2) (ζ : M.sa) (L : BVRel) :
· intros w; rw [in_enc]; simp [langRel, h1]; simp_rw [@in_enc _ _ w]; rfl
intros w; induction w using List.list_reverse_induction
case base =>
intros q; simp [autEq]; rw [in_enc]; simp [h2, langRel]; rfl
intros q; simp [autEq]; rw [in_enc]; simp [h2, langRel]
case ind w a ih =>
rintro q
simp
Expand All @@ -170,6 +191,18 @@ lemma GoodNFA.correct2_spec (M : GoodNFA 2) (ζ : M.sa) (L : BVRel) :
rw [h]; simp_rw [in_enc]
simp [langRel, h3]

lemma GoodNFA.correct2_spec (M : GoodNFA 2) (ζ : M.sa2) (L : BVRel) :
M.correct2 ζ L → M.accepts = langRel2 L := by
rintro ⟨h1, h2, h3⟩
suffices hc : M.correct (fun q w (bvn : Vector (BitVec w) 2) => ζ q (bvn.get 0) (bvn.get 1))
(fun w bvn => L (bvn.get 0) (bvn.get 1)) by
rw [M.correct_spec hc]
simp [langRel2, langRel]
constructor
· simp_all
· intros q; simp_all; rfl
· simp_all

-- move
@[simp]
theorem BitVec.cast_inj (h : w = w') {x y : BitVec w} : cast h x = cast h y ↔ x = y := by
Expand Down Expand Up @@ -484,6 +517,7 @@ def NFA.msbSA (q : msbState) : Language (BitVec 1) :=

@[simp] theorem Language.trivial : x ∈ (⊤ : Language α) := by trivial

-- TODO: rewrite with the n-ary `correct` predicate!
def NFA.msbCorrect : msb.correct msbSA msbLang := by
constructor
· simp [msb, msbSA]
Expand Down

0 comments on commit 05953af

Please sign in to comment.