From cee3ed5ee6f68a18566c614793a9fb523c4b43af Mon Sep 17 00:00:00 2001 From: "F. G. Dorais" Date: Wed, 29 May 2024 12:51:33 -0400 Subject: [PATCH 1/9] feat: dependent array type --- Batteries.lean | 1 + Batteries/Data/DArray.lean | 2 + Batteries/Data/DArray/Basic.lean | 157 ++++++++++++++++++++++++++++++ Batteries/Data/DArray/Lemmas.lean | 75 ++++++++++++++ test/darray.lean | 28 ++++++ 5 files changed, 263 insertions(+) create mode 100644 Batteries/Data/DArray.lean create mode 100644 Batteries/Data/DArray/Basic.lean create mode 100644 Batteries/Data/DArray/Lemmas.lean create mode 100644 test/darray.lean diff --git a/Batteries.lean b/Batteries.lean index f56659960b..c987926b24 100644 --- a/Batteries.lean +++ b/Batteries.lean @@ -20,6 +20,7 @@ import Batteries.Data.BitVec import Batteries.Data.Bool import Batteries.Data.ByteArray import Batteries.Data.Char +import Batteries.Data.DArray import Batteries.Data.DList import Batteries.Data.Fin import Batteries.Data.HashMap diff --git a/Batteries/Data/DArray.lean b/Batteries/Data/DArray.lean new file mode 100644 index 0000000000..65bab05079 --- /dev/null +++ b/Batteries/Data/DArray.lean @@ -0,0 +1,2 @@ +import Batteries.Data.DArray.Basic +import Batteries.Data.DArray.Lemmas diff --git a/Batteries/Data/DArray/Basic.lean b/Batteries/Data/DArray/Basic.lean new file mode 100644 index 0000000000..bf21a7958c --- /dev/null +++ b/Batteries/Data/DArray/Basic.lean @@ -0,0 +1,157 @@ +/- +Copyright (c) 2024 François G. Dorais. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: François G. Dorais +-/ + +namespace Batteries + +/-! +# Dependent Arrays + +`DArray` is a heterogenous array where the type of each item depends on the index. The model +for this type is the dependent function type `(i : Fin n) → α i` where `α i` is the type assigned +to items at index `i`. + +The implementation of `DArray` is based on Lean's dynamic array type. This means that the array +values are stored in a contiguous memory region and can be accessed in constant time. Lean's arrays +also support destructive updates when the array is exclusive (RC=1). + +### Implementation Details + +Lean's array API does not directly support dependent arrays. Each `DArray n α` is internally stored +as an `Array NonScalar` with length `n`. This is sound since Lean's array implementation does not +record nor use the type of the items stored in the array. So it is safe to use `UnsafeCast` to +convert array items to the appropriate type when necessary. +-/ + +/-- `DArray` is a heterogenous array where the type of each item depends on the index. -/ +-- TODO: Use a structure once [lean4#2292](https://github.com/leanprover/lean4/pull/2292) is fixed. +inductive DArray (n) (α : Fin n → Type _) where + /-- Makes a new `DArray` with given item values. `O(n*g)` where `get i` is `O(g)`. -/ + | mk (get : (i : Fin n) → α i) + +namespace DArray + +section unsafe_implementation + +private unsafe abbrev data : DArray n α → Array NonScalar := unsafeCast + +private unsafe def mkImpl (get : (i : Fin n) → α i) : DArray n α := + unsafeCast <| Array.ofFn fun i => (unsafeCast (get i) : NonScalar) + +private unsafe def getImpl (a : DArray n α) (i) : α i := + unsafeCast <| a.data.get ⟨i.val, lcProof⟩ + +private unsafe def ugetImpl (a : DArray n α) (i : USize) (h : i.toNat < n) : α ⟨i.toNat, h⟩ := + unsafeCast <| a.data.uget i lcProof + +private unsafe def setImpl (a : DArray n α) (i) (v : α i) : DArray n α := + unsafeCast <| a.data.set ⟨i.val, lcProof⟩ <| unsafeCast v + +private unsafe def usetImpl (a : DArray n α) (i : USize) (h : i.toNat < n) (v : α ⟨i.toNat, h⟩) : + DArray n α := unsafeCast <| a.data.uset i (unsafeCast v) lcProof + +private unsafe def modifyFImpl [Functor f] (a : DArray n α) (i : Fin n) + (t : α i → f (α i)) : f (DArray n α) := + let v := unsafeCast <| a.data.get ⟨i.val, lcProof⟩ + -- Make sure `v` is unshared, if possible, by replacing its array entry by `box(0)`. + let a := unsafeCast <| a.data.set ⟨i.val, lcProof⟩ (unsafeCast ()) + setImpl a i <$> t v + +private unsafe def umodifyFImpl [Functor f] (a : DArray n α) (i : USize) (h : i.toNat < n) + (t : α ⟨i.toNat, h⟩ → f (α ⟨i.toNat, h⟩)) : f (DArray n α) := + let v := unsafeCast <| a.data.uget i lcProof + -- Make sure `v` is unshared, if possible, by replacing its array entry by `box(0)`. + let a := unsafeCast <| a.data.uset i (unsafeCast ()) lcProof + usetImpl a i h <$> t v + +private unsafe def pushImpl (a : DArray n α) (v : β) : + DArray (n+1) fun i => if h : i.val < n then α ⟨i.val, h⟩ else β := + unsafeCast <| a.data.push <| unsafeCast v + +private unsafe def popImpl (a : DArray (n+1) α) : DArray n fun i => α i.castSucc := + unsafeCast <| a.data.pop + +private unsafe def copyImpl (a : DArray n α) : DArray n α := + unsafeCast <| a.data.extract 0 n + +end unsafe_implementation + +attribute [implemented_by mkImpl] DArray.mk + +instance (α : Fin n → Type _) [(i : Fin n) → Inhabited (α i)] : Inhabited (DArray n α) where + default := mk fun _ => default + +/-- Gets the `DArray` item at index `i`. `O(1)`. -/ +@[implemented_by getImpl] +protected def get : DArray n α → (i : Fin n) → α i + | mk get => get + +@[simp, inherit_doc DArray.get] +protected abbrev getN (a : DArray n α) (i) (h : i < n := by get_elem_tactic) : α ⟨i, h⟩ := + a.get ⟨i, h⟩ + +/-- Gets the `DArray` item at index `i : USize`. Slightly faster than `get`; `O(1)`. -/ +@[implemented_by ugetImpl] +protected def uget (a : DArray n α) (i : USize) (h : i.toNat < n) : α ⟨i.toNat, h⟩ := + a.get ⟨i.toNat, h⟩ + +private def casesOnImpl.{u} {motive : DArray n α → Sort u} (a : DArray n α) + (h : (get : (i : Fin n) → α i) → motive (.mk get)) : motive a := + h a.get + +attribute [implemented_by casesOnImpl] DArray.casesOn + +/-- Sets the `DArray` item at index `i`. `O(1)` if exclusive else `O(n)`. -/ +@[implemented_by setImpl] +protected def set (a : DArray n α) (i : Fin n) (v : α i) : DArray n α := + mk fun j => if h : i = j then h ▸ v else a.get j + +/-- +Sets the `DArray` item at index `i : USize`. +Slightly faster than `set` and `O(1)` if exclusive else `O(n)`. +-/ +@[implemented_by usetImpl] +protected def uset (a : DArray n α) (i : USize) (h : i.toNat < n) (v : α ⟨i.toNat, h⟩) := + a.set ⟨i.toNat, h⟩ v + +@[simp, inherit_doc DArray.set] +protected abbrev setN (a : DArray n α) (i) (h : i < n := by get_elem_tactic) (v : α ⟨i, h⟩) := + a.set ⟨i, h⟩ v + +/-- Modifies the `DArray` item at index `i` using transform `t` and the functor `f`. -/ +@[implemented_by modifyFImpl] +protected def modifyF [Functor f] (a : DArray n α) (i : Fin n) + (t : α i → f (α i)) : f (DArray n α) := a.set i <$> t (a.get i) + +/-- Modifies the `DArray` item at index `i` using transform `t`. -/ +@[inline] +protected def modify (a : DArray n α) (i : Fin n) (t : α i → α i) : DArray n α := + a.modifyF (f:=Id) i t + +/-- Modifies the `DArray` item at index `i : USize` using transform `t` and the functor `f`. -/ +@[implemented_by umodifyFImpl] +protected def umodifyF [Functor f] (a : DArray n α) (i : USize) (h : i.toNat < n) + (t : α ⟨i.toNat, h⟩ → f (α ⟨i.toNat, h⟩)) : f (DArray n α) := a.uset i h <$> t (a.uget i h) + +/-- Modifies the `DArray` item at index `i : USize` using transform `t`. -/ +@[inline] +protected def umodify (a : DArray n α) (i : USize) (h : i.toNat < n) + (t : α ⟨i.toNat, h⟩ → α ⟨i.toNat, h⟩) : DArray n α := + a.umodifyF (f:=Id) i h t + +/-- Copies the `DArray` to an exclusive `DArray`. `O(1)` if exclusive else `O(n)`. -/ +@[implemented_by copyImpl] +protected def copy (a : DArray n α) : DArray n α := mk a.get + +/-- Push an element onto the end of a `DArray`. `O(1)` if exclusive else `O(n)`. -/ +@[implemented_by pushImpl] +protected def push (a : DArray n α) (v : β) : + DArray (n+1) fun i => if h : i.val < n then α ⟨i.val, h⟩ else β := + mk fun i => if h : i.val < n then dif_pos h ▸ a.get ⟨i.val, h⟩ else dif_neg h ▸ v + +/-- Delete the last item of a `DArray`. `O(1)`. -/ +@[implemented_by popImpl] +protected def pop (a : DArray (n+1) α) : DArray n fun i => α i.castSucc := + mk fun i => a.get i.castSucc diff --git a/Batteries/Data/DArray/Lemmas.lean b/Batteries/Data/DArray/Lemmas.lean new file mode 100644 index 0000000000..bea7603e67 --- /dev/null +++ b/Batteries/Data/DArray/Lemmas.lean @@ -0,0 +1,75 @@ +/- +Copyright (c) 2024 François G. Dorais. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: François G. Dorais +-/ + +import Batteries.Data.DArray.Basic + +namespace Batteries.DArray + +@[ext] +protected theorem ext : {a b : DArray n α} → (∀ i, a.get i = b.get i) → a = b + | mk _, mk _, h => congrArg _ <| funext fun i => h i + +@[simp] +theorem get_mk (i : Fin n) : DArray.get (.mk init) i = init i := rfl + +theorem set_mk {α : Fin n → Type _} {init : (i : Fin n) → α i} (i : Fin n) (v : α i) : + DArray.set (.mk init) i v = .mk fun j => if h : i = j then h ▸ v else init j := rfl + +@[simp] +theorem get_set (a : DArray n α) (i : Fin n) (v : α i) : (a.set i v).get i = v := by + simp only [DArray.get, DArray.set, dif_pos] + +theorem get_set_ne (a : DArray n α) (v : α i) (h : i ≠ j) : (a.set i v).get j = a.get j := by + simp only [DArray.get, DArray.set, dif_neg h] + +@[simp] +theorem set_set (a : DArray n α) (i : Fin n) (v w : α i) : (a.set i v).set i w = a.set i w := by + ext j + if h : i = j then + rw [← h, get_set, get_set] + else + rw [get_set_ne _ _ h, get_set_ne _ _ h, get_set_ne _ _ h] + +theorem get_modifyF [Functor f] [LawfulFunctor f] (a : DArray n α) (i : Fin n) (t : α i → f (α i)) : + (DArray.get . i) <$> a.modifyF i t = t (a.get i) := by + simp [DArray.modifyF, ← comp_map] + conv => rhs; rw [← id_map (t (a.get i))] + congr; ext; simp + +@[simp] +theorem get_modify (a : DArray n α) (i : Fin n) (t : α i → α i) : + (a.modify i t).get i = t (a.get i) := get_modifyF (f:=Id) a i t + +theorem get_modify_ne (a : DArray n α) (t : α i → α i) (h : i ≠ j) : + (a.modify i t).get j = a.get j := get_set_ne _ _ h + +@[simp] +theorem set_modify (a : DArray n α) (i : Fin n) (t : α i → α i) (v : α i) : + (a.set i v).modify i t = a.set i (t v) := by + ext j + if h : i = j then + cases h; simp + else + simp [h, get_modify_ne, get_set_ne] + +@[simp] +theorem uget_eq_get (a : DArray n α) (i : USize) (h : i.toNat < n) : + a.uget i h = a.get ⟨i.toNat, h⟩ := rfl + +@[simp] +theorem uset_eq_set (a : DArray n α) (i : USize) (h : i.toNat < n) (v : α ⟨i.toNat, h⟩) : + a.uset i h v = a.set ⟨i.toNat, h⟩ v := rfl + +@[simp] +theorem umodifyF_eq_modifyF [Functor f] (a : DArray n α) (i : USize) (h : i.toNat < n) + (t : α ⟨i.toNat, h⟩ → f (α ⟨i.toNat, h⟩)) : a.umodifyF i h t = a.modifyF ⟨i.toNat, h⟩ t := rfl + +@[simp] +theorem umodify_eq_modify (a : DArray n α) (i : USize) (h : i.toNat < n) + (t : α ⟨i.toNat, h⟩ → α ⟨i.toNat, h⟩) : a.umodify i h t = a.modify ⟨i.toNat, h⟩ t := rfl + +@[simp] +theorem copy_eq (a : DArray n α) : a.copy = a := rfl diff --git a/test/darray.lean b/test/darray.lean new file mode 100644 index 0000000000..0af1a42520 --- /dev/null +++ b/test/darray.lean @@ -0,0 +1,28 @@ +import Batteries.Data.DArray + +open Batteries + +def foo : DArray 3 fun | 0 => String | 1 => Nat | 2 => Array Nat := + .mk fun | 0 => "foo" | 1 => 42 | 2 => #[4, 2] + +def bar := foo.set 0 "bar" + +#guard foo.get 0 == "foo" +#guard foo.get 1 == 42 +#guard foo.get 2 == #[4, 2] + +#guard (foo.set 1 1).get 0 == "foo" +#guard (foo.set 1 1).get 1 == 1 +#guard (foo.set 1 1).get 2 == #[4, 2] + +#guard bar.get 0 == "bar" +#guard (bar.set 0 (foo.get 0)).get 0 == "foo" +#guard ((bar.set 0 "baz").set 1 1).get 0 == "baz" +#guard ((bar.set 0 "baz").set 0 "foo").get 0 == "foo" +#guard ((bar.set 0 "foo").set 0 "baz").get 0 == "baz" + +def Batteries.DArray.head : DArray (n+1) α → α 0 + | mk f => f 0 + +#guard foo.head == "foo" +#guard bar.head == "bar" From 73765f193110e06a44dbad3238442bc98376ee0a Mon Sep 17 00:00:00 2001 From: "F. G. Dorais" Date: Wed, 29 May 2024 12:57:44 -0400 Subject: [PATCH 2/9] feat: `Fin.foldlM` and `Fin.foldrM` --- Batteries/Data/Fin/Basic.lean | 71 ++++++++++++++++++---- Batteries/Data/Fin/Lemmas.lean | 102 ++++++++++++++++++++++---------- Batteries/Data/List/Lemmas.lean | 10 ++++ 3 files changed, 142 insertions(+), 41 deletions(-) diff --git a/Batteries/Data/Fin/Basic.lean b/Batteries/Data/Fin/Basic.lean index 495344f90f..abd05d5ba1 100644 --- a/Batteries/Data/Fin/Basic.lean +++ b/Batteries/Data/Fin/Basic.lean @@ -15,16 +15,67 @@ def enum (n) : Array (Fin n) := Array.ofFn id /-- `list n` is the list of all elements of `Fin n` in order -/ def list (n) : List (Fin n) := (enum n).data -/-- Folds over `Fin n` from the left: `foldl 3 f x = f (f (f x 0) 1) 2`. -/ -@[inline] def foldl (n) (f : α → Fin n → α) (init : α) : α := loop init 0 where - /-- Inner loop for `Fin.foldl`. `Fin.foldl.loop n f x i = f (f (f x i) ...) (n-1)` -/ - loop (x : α) (i : Nat) : α := - if h : i < n then loop (f x ⟨i, h⟩) (i+1) else x +/-- +Folds a monadic function over `Fin n` from left to right: +``` +Fin.foldlM n f x₀ = do + let x₁ ← f x₀ 0 + let x₂ ← f x₁ 1 + ... + let xₙ ← f xₙ₋₁ (n-1) + pure xₙ +``` +-/ +@[inline] def foldlM [Monad m] (n) (f : α → Fin n → m α) (init : α) : m α := loop init 0 where + /-- + Inner loop for `Fin.foldlM`. + ``` + Fin.foldlM.loop n f xᵢ i = do + let xᵢ₊₁ ← f xᵢ i + ... + let xₙ ← f xₙ₋₁ (n-1) + pure xₙ + ``` + `Fin.foldlM.loop n f x i = f x i >>= fun x => f x (i+1) >>= ...` + -/ + loop (x : α) (i : Nat) : m α := do + if h : i < n then f x ⟨i, h⟩ >>= (loop · (i+1)) else pure x termination_by n - i +/-- +Folds a monadic function over `Fin n` from right to left: +``` +Fin.foldrM n f x₀ = do + let x₁ ← f (n-1) x₀ + let x₂ ← f (n-2) x₁ + ... + let xₙ ← f 0 xₙ₋₁ + pure xₙ +``` +-/ +@[inline] def foldrM [Monad m] (n) (f : Fin n → α → m α) (init : α) : m α := + loop ⟨n, Nat.le_refl n⟩ init where + /-- + Inner loop for `Fin.foldrM`. + ``` + Fin.foldrM n f i x₀ = do + let x₁ ← f (n-1) x₀ + let x₂ ← f (n-2) x₁ + ... + let xᵢ ← f (n-i) xᵢ₋₁ + pure xᵢ + ``` + -/ + loop : {i // i ≤ n} → α → m α + | ⟨0, _⟩, x => pure x + | ⟨i+1, h⟩, x => f ⟨i, h⟩ x >>= loop ⟨i, Nat.le_of_lt h⟩ + +-- These are also defined in core in the root namespace! + /-- Folds over `Fin n` from the right: `foldr 3 f x = f 0 (f 1 (f 2 x))`. -/ -@[inline] def foldr (n) (f : Fin n → α → α) (init : α) : α := loop ⟨n, Nat.le_refl n⟩ init where - /-- Inner loop for `Fin.foldr`. `Fin.foldr.loop n f i x = f 0 (f ... (f (i-1) x))` -/ - loop : {i // i ≤ n} → α → α - | ⟨0, _⟩, x => x - | ⟨i+1, h⟩, x => loop ⟨i, Nat.le_of_lt h⟩ (f ⟨i, h⟩ x) +@[inline] def foldr (n) (f : Fin n → α → α) (init : α) : α := + foldrM (m:=Id) n f init + +/-- Folds over `Fin n` from the left: `foldl 3 f x = f (f (f x 0) 1) 2`. -/ +@[inline] def foldl (n) (f : α → Fin n → α) (init : α) : α := + foldlM (m:=Id) n f init diff --git a/Batteries/Data/Fin/Lemmas.lean b/Batteries/Data/Fin/Lemmas.lean index f8d2569a54..0deb2b19af 100644 --- a/Batteries/Data/Fin/Lemmas.lean +++ b/Batteries/Data/Fin/Lemmas.lean @@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Mario Carneiro -/ import Batteries.Data.Fin.Basic +import Batteries.Data.List.Lemmas namespace Fin @@ -36,52 +37,91 @@ protected theorem le_antisymm {x y : Fin n} (h1 : x ≤ y) (h2 : y ≤ x) : x = theorem list_succ (n) : list (n+1) = 0 :: (list n).map Fin.succ := by apply List.ext_get; simp; intro i; cases i <;> simp -/-! ### foldl -/ +/-! ### foldlM -/ -theorem foldl_loop_lt (f : α → Fin n → α) (x) (h : m < n) : - foldl.loop n f x m = foldl.loop n f (f x ⟨m, h⟩) (m+1) := by - rw [foldl.loop, dif_pos h] +theorem foldlM_loop_lt [Monad m] (f : α → Fin n → m α) (x) (h : i < n) : + foldlM.loop n f x i = f x ⟨i, h⟩ >>= (foldlM.loop n f . (i+1)) := by + rw [foldlM.loop, dif_pos h] -theorem foldl_loop_eq (f : α → Fin n → α) (x) : foldl.loop n f x n = x := by - rw [foldl.loop, dif_neg (Nat.lt_irrefl _)] +theorem foldlM_loop_eq [Monad m] (f : α → Fin n → m α) (x) : foldlM.loop n f x n = pure x := by + rw [foldlM.loop, dif_neg (Nat.lt_irrefl _)] -theorem foldl_loop (f : α → Fin (n+1) → α) (x) (h : m < n+1) : - foldl.loop (n+1) f x m = foldl.loop n (fun x i => f x i.succ) (f x ⟨m, h⟩) m := by - if h' : m < n then - rw [foldl_loop_lt _ _ h, foldl_loop_lt _ _ h', foldl_loop]; rfl +theorem foldlM_loop [Monad m] (f : α → Fin (n+1) → m α) (x) (h : i < n+1) : + foldlM.loop (n+1) f x i = f x ⟨i, h⟩ >>= (foldlM.loop n (fun x j => f x j.succ) . i) := by + if h' : i < n then + rw [foldlM_loop_lt _ _ h] + congr; funext + rw [foldlM_loop_lt _ _ h', foldlM_loop]; rfl else cases Nat.le_antisymm (Nat.le_of_lt_succ h) (Nat.not_lt.1 h') - rw [foldl_loop_lt, foldl_loop_eq, foldl_loop_eq] -termination_by n - m + rw [foldlM_loop_lt] + congr; funext + rw [foldlM_loop_eq, foldlM_loop_eq] +termination_by n - i -theorem foldl_zero (f : α → Fin 0 → α) (x) : foldl 0 f x = x := rfl +theorem foldlM_zero [Monad m] (f : α → Fin 0 → m α) (x) : foldlM 0 f x = pure x := rfl -theorem foldl_succ (f : α → Fin (n+1) → α) (x) : - foldl (n+1) f x = foldl n (fun x i => f x i.succ) (f x 0) := foldl_loop .. +theorem foldlM_succ [Monad m] (f : α → Fin (n+1) → m α) (x) : + foldlM (n+1) f x = f x 0 >>= foldlM n (fun x j => f x j.succ) := foldlM_loop .. -theorem foldl_eq_foldl_list (f : α → Fin n → α) (x) : foldl n f x = (list n).foldl f x := by +theorem foldlM_eq_foldlM_list [Monad m] (f : α → Fin n → m α) (x) : + foldlM n f x = (list n).foldlM f x := by induction n generalizing x with | zero => rfl - | succ n ih => rw [foldl_succ, ih, list_succ, List.foldl_cons, List.foldl_map] + | succ n ih => + rw [foldlM_succ, list_succ, List.foldlM_cons] + congr; funext + rw [List.foldlM_map, ih] + +theorem foldl_eq_foldlM : foldl n f init = foldlM (m:=Id) n f init := rfl + +theorem foldl_zero (f : α → Fin 0 → α) (x) : foldl 0 f x = x := rfl + +theorem foldl_succ (f : α → Fin (n+1) → α) (x) : + foldl (n+1) f x = foldl n (fun x j => f x j.succ) (f x 0) := foldlM_succ .. -/-! ### foldr -/ +theorem foldl_eq_foldl_list (f : α → Fin n → α) (x) : + foldl n f x = (list n).foldl f x := + by simp only [foldl_eq_foldlM, foldlM_eq_foldlM_list, List.foldl_eq_foldlM] -theorem foldr_loop_zero (f : Fin n → α → α) (x) : foldr.loop n f ⟨0, Nat.zero_le _⟩ x = x := rfl +/-! ### foldrM -/ -theorem foldr_loop_succ (f : Fin n → α → α) (x) (h : m < n) : - foldr.loop n f ⟨m+1, h⟩ x = foldr.loop n f ⟨m, Nat.le_of_lt h⟩ (f ⟨m, h⟩ x) := rfl +theorem foldrM_loop_zero [Monad m] (f : Fin n → α → m α) (x) : + foldrM.loop n f ⟨0, Nat.zero_le _⟩ x = pure x := rfl -theorem foldr_loop (f : Fin (n+1) → α → α) (x) (h : m+1 ≤ n+1) : - foldr.loop (n+1) f ⟨m+1, h⟩ x = - f 0 (foldr.loop n (fun i => f i.succ) ⟨m, Nat.le_of_succ_le_succ h⟩ x) := by - induction m generalizing x with - | zero => simp [foldr_loop_zero, foldr_loop_succ] - | succ m ih => rw [foldr_loop_succ, ih]; rfl +theorem foldrM_loop_succ [Monad m] (f : Fin n → α → m α) (x) (h : i < n) : + foldrM.loop n f ⟨i+1, h⟩ x = f ⟨i, h⟩ x >>= foldrM.loop n f ⟨i, Nat.le_of_lt h⟩ := rfl -theorem foldr_succ (f : Fin (n+1) → α → α) (x) : - foldr (n+1) f x = f 0 (foldr n (fun i => f i.succ) x) := foldr_loop .. +theorem foldrM_loop [Monad m] [LawfulMonad m] (f : Fin (n+1) → α → m α) (x) (h : i+1 ≤ n+1) : + foldrM.loop (n+1) f ⟨i+1, h⟩ x = + foldrM.loop n (fun j => f j.succ) ⟨i, Nat.le_of_succ_le_succ h⟩ x >>= f 0 := by + induction i generalizing x with + | zero => + rw [foldrM_loop_zero, foldrM_loop_succ, pure_bind] + conv => rhs; rw [←bind_pure (f 0 x)] + congr + | succ i ih => + rw [foldrM_loop_succ, foldrM_loop_succ, bind_assoc] + congr; funext; exact ih .. -theorem foldr_eq_foldr_list (f : Fin n → α → α) (x) : foldr n f x = (list n).foldr f x := by +theorem foldrM_zero [Monad m] (f : Fin 0 → α → m α) (x) : foldrM 0 f x = pure x := rfl + +theorem foldrM_succ [Monad m] [LawfulMonad m] (f : Fin (n+1) → α → m α) (x) : + foldrM (n+1) f x = foldrM n (fun i => f i.succ) x >>= f 0 := foldrM_loop .. + +theorem foldrM_eq_foldrM_list [Monad m] [LawfulMonad m] (f : Fin n → α → m α) (x) : + foldrM n f x = (list n).foldrM f x := by induction n with | zero => rfl - | succ n ih => rw [foldr_succ, ih, list_succ, List.foldr_cons, List.foldr_map] + | succ n ih => rw [foldrM_succ, ih, list_succ, List.foldrM_cons, List.foldrM_map] + +theorem foldr_eq_foldrM (f : Fin n → α → α) (init) : foldr n f init = foldrM (m:=Id) n f init := rfl + +theorem foldr_zero (f : Fin 0 → α → α) (x) : foldr 0 f x = x := rfl + +theorem foldr_succ (f : Fin (n+1) → α → α) (x) : + foldr (n+1) f x = f 0 (foldr n (fun i => f i.succ) x) := foldrM_loop .. + +theorem foldr_eq_foldr_list (f : Fin n → α → α) (x) : + foldr n f x = (list n).foldr f x := by + simp only [foldr_eq_foldrM, foldrM_eq_foldrM_list, List.foldr_eq_foldrM] diff --git a/Batteries/Data/List/Lemmas.lean b/Batteries/Data/List/Lemmas.lean index d23566b21f..d61c2c03d2 100644 --- a/Batteries/Data/List/Lemmas.lean +++ b/Batteries/Data/List/Lemmas.lean @@ -2822,3 +2822,13 @@ theorem lt_antisymm' [LT α] have ab : ¬a < b := fun ab => h₁ (.head _ _ ab) cases lt_antisymm ab (fun ba => h₂ (.head _ _ ba)) rw [ih (fun ll => h₁ (.tail ab ab ll)) (fun ll => h₂ (.tail ab ab ll))] + +/-! ### foldlM and foldrM -/ + +theorem foldlM_map [Monad m] (f : β₁ → β₂) (g : α → β₂ → m α) (l : List β₁) (init : α) : + (l.map f).foldlM g init = l.foldlM (fun x y => g x (f y)) init := by + induction l generalizing g init <;> simp [*] + +theorem foldrM_map [Monad m] [LawfulMonad m] (f : β₁ → β₂) (g : β₂ → α → m α) (l : List β₁) + (init : α) : (l.map f).foldrM g init = l.foldrM (fun x y => g (f x) y) init := by + induction l generalizing g init <;> simp [*] From 01e73538f5add40af5fe25dbd768e44fa97f55f4 Mon Sep 17 00:00:00 2001 From: "F. G. Dorais" Date: Wed, 29 May 2024 14:01:12 -0400 Subject: [PATCH 3/9] feat: folds for dependent arrays --- Batteries/Data/DArray/Basic.lean | 66 ++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/Batteries/Data/DArray/Basic.lean b/Batteries/Data/DArray/Basic.lean index bf21a7958c..a8cbf221c8 100644 --- a/Batteries/Data/DArray/Basic.lean +++ b/Batteries/Data/DArray/Basic.lean @@ -4,6 +4,8 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: François G. Dorais -/ +import Batteries.Data.Fin.Basic + namespace Batteries /-! @@ -76,6 +78,33 @@ private unsafe def popImpl (a : DArray (n+1) α) : DArray n fun i => α i.castSu private unsafe def copyImpl (a : DArray n α) : DArray n α := unsafeCast <| a.data.extract 0 n +private unsafe def foldlMImpl [Monad m] (a : DArray n α) (f : β → {i : Fin n} → α i → m β) + (init : β) : m β := + if n < USize.size then + loop 0 init + else + have : Inhabited β := ⟨init⟩ + -- array data exceeds the entire address space! + panic! "out of memory" +where + -- loop invariant: `i.toNat ≤ n` + loop (i : USize) (x : β) : m β := + if i.toNat ≥ n then pure x else + f x (a.ugetImpl i lcProof) >>= loop (i+1) + +private unsafe def foldrMImpl [Monad m] (a : DArray n α) (f : {i : Fin n} → α i → β → m β) + (init : β) : m β := + if h : n < USize.size then + loop (.ofNatCore n h) init + else + have : Inhabited β := ⟨init⟩ + -- array data exceeds the entire address space! + panic! "out of memory" +where + -- loop invariant: `i.toNat ≤ n` + loop (i : USize) (x : β) : m β := + if i = 0 then pure x else f (a.ugetImpl (i-1) lcProof) x >>= loop (i-1) + end unsafe_implementation attribute [implemented_by mkImpl] DArray.mk @@ -155,3 +184,40 @@ protected def push (a : DArray n α) (v : β) : @[implemented_by popImpl] protected def pop (a : DArray (n+1) α) : DArray n fun i => α i.castSucc := mk fun i => a.get i.castSucc + +/-- +Folds a monadic function over a `DArray` from left to right: +``` +DArray.foldlM a f x₀ = do + let x₁ ← f x₀ (a.get 0) + let x₂ ← f x₁ (a.get 1) + ... + let xₙ ← f xₙ₋₁ (a.get (n-1)) + pure xₙ +``` +-/ +@[implemented_by foldlMImpl] +def foldlM [Monad m] (a : DArray n α) (f : β → {i : Fin n} → α i → m β) (init : β) : m β := + Fin.foldlM n (f · <| a.get ·) init + +/-- Folds a function over a `DArray` from the left. -/ +def foldl (a : DArray n α) (f : β → {i : Fin n} → α i → β) (init : β) : β := + a.foldlM (m:=Id) f init + +/-- +Folds a monadic function over a `DArray` from right to left: +``` +DArray.foldrM a f x₀ = do + let x₁ ← f (a.get (n-1)) x₀ + let x₂ ← f (a.get (n-2)) x₁ + ... + let xₙ ← f (a.get 0) xₙ₋₁ + pure xₙ +``` +-/ +def foldrM [Monad m] (a : DArray n α) (f : {i : Fin n} → α i → β → m β) (init : β) : m β := + Fin.foldrM n (f <| a.get ·) init + +/-- Folds a function over a `DArray` from the right. -/ +def foldr (a : DArray n α) (f : {i : Fin n} → α i → β → β) (init : β) : β := + a.foldrM (m:=Id) f init From bf6f52824d3e5760f040d57c1e535ff7c7cdf891 Mon Sep 17 00:00:00 2001 From: "F. G. Dorais" Date: Wed, 29 May 2024 20:41:39 -0400 Subject: [PATCH 4/9] feat: `ForIn` instance for `DArray` --- Batteries/Data/DArray/Basic.lean | 14 ++++++++++++++ test/darray.lean | 17 ++++++++++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/Batteries/Data/DArray/Basic.lean b/Batteries/Data/DArray/Basic.lean index a8cbf221c8..17d9e29f91 100644 --- a/Batteries/Data/DArray/Basic.lean +++ b/Batteries/Data/DArray/Basic.lean @@ -221,3 +221,17 @@ def foldrM [Monad m] (a : DArray n α) (f : {i : Fin n} → α i → β → m β /-- Folds a function over a `DArray` from the right. -/ def foldr (a : DArray n α) (f : {i : Fin n} → α i → β → β) (init : β) : β := a.foldrM (m:=Id) f init + +/-- Implementation of `ForIn` for `DArray`. -/ +@[specialize] +def forIn [Monad m] (a : DArray n α) (init : β) (f : Sigma α → β → m (ForInStep β)) : m β := do + match ← a.foldlM step (.yield init) with + | .done r => pure r + | .yield r => pure r +where + step : ForInStep β → {i : Fin n} → α i → m (ForInStep β) + | .done r, _, _ => pure (.done r) + | .yield r, i, x => f ⟨i, x⟩ r + +instance (α : Fin n → Type _) [Monad m] : ForIn m (DArray n α) (Sigma α) where + forIn := forIn diff --git a/test/darray.lean b/test/darray.lean index 0af1a42520..5f3747d5eb 100644 --- a/test/darray.lean +++ b/test/darray.lean @@ -1,4 +1,4 @@ -import Batteries.Data.DArray +import Batteries.Data.DArray.Basic open Batteries @@ -26,3 +26,18 @@ def Batteries.DArray.head : DArray (n+1) α → α 0 #guard foo.head == "foo" #guard bar.head == "bar" + +abbrev Data (n : Nat) : Type _ := DArray (n+1) fun | 0 => String | ⟨_+1,_⟩ => Nat + +def Data.sum (a : Data n) : String × Nat := Id.run do + let mut r := ("", 0) + for ⟨i, x⟩ in a do + match i with + | 0 => r := (x, r.snd) + | ⟨_+1,_⟩ => r := ⟨r.fst, r.snd+x⟩ + return r + +def test : Data 2 := + .mk fun | 0 => "foo" | 1 => 4 | 2 => 2 + +#guard test.sum == ("foo", 6) From 6d9236af2f7235ccee56689d82cb61e2e521aadb Mon Sep 17 00:00:00 2001 From: "F. G. Dorais" Date: Wed, 29 May 2024 21:05:26 -0400 Subject: [PATCH 5/9] feat: add `map` and `amap` --- Batteries/Data/DArray/Basic.lean | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/Batteries/Data/DArray/Basic.lean b/Batteries/Data/DArray/Basic.lean index 17d9e29f91..befccbece3 100644 --- a/Batteries/Data/DArray/Basic.lean +++ b/Batteries/Data/DArray/Basic.lean @@ -105,6 +105,15 @@ where loop (i : USize) (x : β) : m β := if i = 0 then pure x else f (a.ugetImpl (i-1) lcProof) x >>= loop (i-1) +@[specialize] +private unsafe def mapImpl (f : {i : Fin n} → α i → β i) (a : DArray n α) : DArray n β := + let f := fun i x => (unsafeCast (f (i:=i.cast lcProof) (unsafeCast x)) : NonScalar) + unsafeCast <| a.data.mapIdx f + +@[specialize] +private unsafe def amapImpl (f : {i : Fin n} → α i → β) (a : DArray n α) : Array β := + unsafeCast <| a.mapImpl f + end unsafe_implementation attribute [implemented_by mkImpl] DArray.mk @@ -235,3 +244,19 @@ where instance (α : Fin n → Type _) [Monad m] : ForIn m (DArray n α) (Sigma α) where forIn := forIn + +/-- +Applies `f : {i : Fin n} → α i → β i` to each element of a `DArray n α`, +returns the dependent array of results. +-/ +@[implemented_by mapImpl] +protected def map (f : {i : Fin n} → α i → β i) (a : DArray n α) : DArray n β := + mk fun i => f (a.get i) + +/-- +Applies `f : {i : Fin n} → α i → β` to each element of a `DArray n α`, +returns the (non-dependent) array of results. +-/ +@[implemented_by amapImpl] +def amap (f : {i : Fin n} → α i → β) (a : DArray n α) : Array β := + Array.ofFn fun i => f (a.get i) From c80ca8b5b4d38c26dcc0b3144584e77bb6f29ac1 Mon Sep 17 00:00:00 2001 From: "F. G. Dorais" Date: Wed, 29 May 2024 21:33:54 -0400 Subject: [PATCH 6/9] fix: missing docstring --- Batteries/Data/DArray/Basic.lean | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Batteries/Data/DArray/Basic.lean b/Batteries/Data/DArray/Basic.lean index befccbece3..31f2fdfbb7 100644 --- a/Batteries/Data/DArray/Basic.lean +++ b/Batteries/Data/DArray/Basic.lean @@ -238,11 +238,12 @@ def forIn [Monad m] (a : DArray n α) (init : β) (f : Sigma α → β → m (Fo | .done r => pure r | .yield r => pure r where + /-- Step function for `forIn`. -/ step : ForInStep β → {i : Fin n} → α i → m (ForInStep β) | .done r, _, _ => pure (.done r) | .yield r, i, x => f ⟨i, x⟩ r -instance (α : Fin n → Type _) [Monad m] : ForIn m (DArray n α) (Sigma α) where +instance (m : Type _ → Type _) (α : Fin n → Type _) : ForIn m (DArray n α) (Sigma α) where forIn := forIn /-- From d87b4cd773ee619e11e00abb54880c8b2345cf91 Mon Sep 17 00:00:00 2001 From: "F. G. Dorais" Date: Wed, 29 May 2024 21:44:40 -0400 Subject: [PATCH 7/9] feat: more lemmas --- Batteries/Data/DArray/Lemmas.lean | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/Batteries/Data/DArray/Lemmas.lean b/Batteries/Data/DArray/Lemmas.lean index bea7603e67..551c7d17ab 100644 --- a/Batteries/Data/DArray/Lemmas.lean +++ b/Batteries/Data/DArray/Lemmas.lean @@ -71,5 +71,28 @@ theorem umodifyF_eq_modifyF [Functor f] (a : DArray n α) (i : USize) (h : i.toN theorem umodify_eq_modify (a : DArray n α) (i : USize) (h : i.toNat < n) (t : α ⟨i.toNat, h⟩ → α ⟨i.toNat, h⟩) : a.umodify i h t = a.modify ⟨i.toNat, h⟩ t := rfl +theorem foldlM_eq_fin_foldlM [Monad m] (a : DArray n α) (f : β → {i : Fin n} → α i → m β) (init) : + a.foldlM f init = Fin.foldlM n (f · <| a.get ·) init := rfl + +theorem foldl_eq_foldlM (a : DArray n α) (f : β → {i : Fin n} → α i → β) (init) : + a.foldl f init = a.foldlM (m:=Id) f init := rfl + +theorem foldrM_eq_fin_foldrM [Monad m] (a : DArray n α) (f : {i : Fin n} → α i → β → m β) (init) : + a.foldrM f init = Fin.foldrM n (f <| a.get ·) init := rfl + +theorem foldr_eq_foldrM (a : DArray n α) (f : {i : Fin n} → α i → β → β) (init) : + a.foldr f init = a.foldrM (m:=Id) f init := rfl + @[simp] theorem copy_eq (a : DArray n α) : a.copy = a := rfl + +theorem get_map {β : Fin n → Type _} (f : {i : Fin n} → α i → β i) (a : DArray n α) (i : Fin n) : + (a.map f).get i = f (a.get i) := rfl + +@[simp] +theorem size_amap (f : {i : Fin n} → α i → β) (a : DArray n α) : + (a.amap f).size = n := Array.size_ofFn .. + +theorem getElem_amap (f : {i : Fin n} → α i → β) (a : DArray n α) (i : Fin n) + (h : i.val < (a.amap f).size := (a.size_amap f).symm ▸ i.is_lt) : + (a.amap f)[i] = f (a.get i) := by simp [amap] From 087e3596845d45cf18ce5c5edfa42adc9743dbbb Mon Sep 17 00:00:00 2001 From: "F. G. Dorais" Date: Thu, 6 Jun 2024 02:12:06 -0400 Subject: [PATCH 8/9] feat: lemmas `List.foldlM_map` and `List.foldrM_map` --- Batteries/Data/List/Lemmas.lean | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/Batteries/Data/List/Lemmas.lean b/Batteries/Data/List/Lemmas.lean index d23566b21f..d61c2c03d2 100644 --- a/Batteries/Data/List/Lemmas.lean +++ b/Batteries/Data/List/Lemmas.lean @@ -2822,3 +2822,13 @@ theorem lt_antisymm' [LT α] have ab : ¬a < b := fun ab => h₁ (.head _ _ ab) cases lt_antisymm ab (fun ba => h₂ (.head _ _ ba)) rw [ih (fun ll => h₁ (.tail ab ab ll)) (fun ll => h₂ (.tail ab ab ll))] + +/-! ### foldlM and foldrM -/ + +theorem foldlM_map [Monad m] (f : β₁ → β₂) (g : α → β₂ → m α) (l : List β₁) (init : α) : + (l.map f).foldlM g init = l.foldlM (fun x y => g x (f y)) init := by + induction l generalizing g init <;> simp [*] + +theorem foldrM_map [Monad m] [LawfulMonad m] (f : β₁ → β₂) (g : β₂ → α → m α) (l : List β₁) + (init : α) : (l.map f).foldrM g init = l.foldrM (fun x y => g (f x) y) init := by + induction l generalizing g init <;> simp [*] From 8fd38ad5db009cdb5cd784f9425a0c769ce1cb9b Mon Sep 17 00:00:00 2001 From: "F. G. Dorais" Date: Thu, 6 Jun 2024 02:15:45 -0400 Subject: [PATCH 9/9] feat: `Fin.foldlM` and `Fin.foldrM` --- Batteries/Data/Fin/Basic.lean | 70 +++++++++++++++++++++---- Batteries/Data/Fin/Lemmas.lean | 96 +++++++++++++++++++++++----------- 2 files changed, 126 insertions(+), 40 deletions(-) diff --git a/Batteries/Data/Fin/Basic.lean b/Batteries/Data/Fin/Basic.lean index 495344f90f..0651aac9af 100644 --- a/Batteries/Data/Fin/Basic.lean +++ b/Batteries/Data/Fin/Basic.lean @@ -15,16 +15,66 @@ def enum (n) : Array (Fin n) := Array.ofFn id /-- `list n` is the list of all elements of `Fin n` in order -/ def list (n) : List (Fin n) := (enum n).data -/-- Folds over `Fin n` from the left: `foldl 3 f x = f (f (f x 0) 1) 2`. -/ -@[inline] def foldl (n) (f : α → Fin n → α) (init : α) : α := loop init 0 where - /-- Inner loop for `Fin.foldl`. `Fin.foldl.loop n f x i = f (f (f x i) ...) (n-1)` -/ - loop (x : α) (i : Nat) : α := - if h : i < n then loop (f x ⟨i, h⟩) (i+1) else x +/-- +Folds a monadic function over `Fin n` from left to right: +``` +Fin.foldlM n f x₀ = do + let x₁ ← f x₀ 0 + let x₂ ← f x₁ 1 + ... + let xₙ ← f xₙ₋₁ (n-1) + pure xₙ +``` +-/ +@[inline] def foldlM [Monad m] (n) (f : α → Fin n → m α) (init : α) : m α := loop init 0 where + /-- + Inner loop for `Fin.foldlM`. + ``` + Fin.foldlM.loop n f xᵢ i = do + let xᵢ₊₁ ← f xᵢ i + ... + let xₙ ← f xₙ₋₁ (n-1) + pure xₙ + ``` + -/ + loop (x : α) (i : Nat) : m α := do + if h : i < n then f x ⟨i, h⟩ >>= (loop · (i+1)) else pure x termination_by n - i +/-- +Folds a monadic function over `Fin n` from right to left: +``` +Fin.foldrM n f xₙ = do + let xₙ₋₁ ← f (n-1) xₙ + let xₙ₋₂ ← f (n-2) xₙ₋₁ + ... + let x₀ ← f 0 x₁ + pure x₀ +``` +-/ +@[inline] def foldrM [Monad m] (n) (f : Fin n → α → m α) (init : α) : m α := + loop ⟨n, Nat.le_refl n⟩ init where + /-- + Inner loop for `Fin.foldrM`. + ``` + Fin.foldrM.loop n f i xᵢ = do + let xᵢ₋₁ ← f (i-1) xᵢ + ... + let x₁ ← f 1 x₂ + let x₀ ← f 0 x₁ + pure x₀ + ``` + -/ + loop : {i // i ≤ n} → α → m α + | ⟨0, _⟩, x => pure x + | ⟨i+1, h⟩, x => f ⟨i, h⟩ x >>= loop ⟨i, Nat.le_of_lt h⟩ + +-- These are also defined in core in the root namespace! + /-- Folds over `Fin n` from the right: `foldr 3 f x = f 0 (f 1 (f 2 x))`. -/ -@[inline] def foldr (n) (f : Fin n → α → α) (init : α) : α := loop ⟨n, Nat.le_refl n⟩ init where - /-- Inner loop for `Fin.foldr`. `Fin.foldr.loop n f i x = f 0 (f ... (f (i-1) x))` -/ - loop : {i // i ≤ n} → α → α - | ⟨0, _⟩, x => x - | ⟨i+1, h⟩, x => loop ⟨i, Nat.le_of_lt h⟩ (f ⟨i, h⟩ x) +@[inline] def foldr (n) (f : Fin n → α → α) (init : α) : α := + foldrM (m:=Id) n f init + +/-- Folds over `Fin n` from the left: `foldl 3 f x = f (f (f x 0) 1) 2`. -/ +@[inline] def foldl (n) (f : α → Fin n → α) (init : α) : α := + foldlM (m:=Id) n f init diff --git a/Batteries/Data/Fin/Lemmas.lean b/Batteries/Data/Fin/Lemmas.lean index 7167efa0ed..1510ea7cb3 100644 --- a/Batteries/Data/Fin/Lemmas.lean +++ b/Batteries/Data/Fin/Lemmas.lean @@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Mario Carneiro -/ import Batteries.Data.Fin.Basic +import Batteries.Data.List.Lemmas namespace Fin @@ -52,28 +53,79 @@ theorem list_reverse (n) : (list n).reverse = (list n).map rev := by conv => rhs; rw [list_succ] simp [List.reverse_map, ih, Function.comp_def, rev_succ] -/-! ### foldl -/ +/-! ### foldlM -/ -theorem foldl_loop_lt (f : α → Fin n → α) (x) (h : m < n) : - foldl.loop n f x m = foldl.loop n f (f x ⟨m, h⟩) (m+1) := by - rw [foldl.loop, dif_pos h] +theorem foldlM_loop_lt [Monad m] (f : α → Fin n → m α) (x) (h : i < n) : + foldlM.loop n f x i = f x ⟨i, h⟩ >>= (foldlM.loop n f . (i+1)) := by + rw [foldlM.loop, dif_pos h] -theorem foldl_loop_eq (f : α → Fin n → α) (x) : foldl.loop n f x n = x := by - rw [foldl.loop, dif_neg (Nat.lt_irrefl _)] +theorem foldlM_loop_eq [Monad m] (f : α → Fin n → m α) (x) : foldlM.loop n f x n = pure x := by + rw [foldlM.loop, dif_neg (Nat.lt_irrefl _)] -theorem foldl_loop (f : α → Fin (n+1) → α) (x) (h : m < n+1) : - foldl.loop (n+1) f x m = foldl.loop n (fun x i => f x i.succ) (f x ⟨m, h⟩) m := by - if h' : m < n then - rw [foldl_loop_lt _ _ h, foldl_loop_lt _ _ h', foldl_loop]; rfl +theorem foldlM_loop [Monad m] (f : α → Fin (n+1) → m α) (x) (h : i < n+1) : + foldlM.loop (n+1) f x i = f x ⟨i, h⟩ >>= (foldlM.loop n (fun x j => f x j.succ) . i) := by + if h' : i < n then + rw [foldlM_loop_lt _ _ h] + congr; funext + rw [foldlM_loop_lt _ _ h', foldlM_loop]; rfl else cases Nat.le_antisymm (Nat.le_of_lt_succ h) (Nat.not_lt.1 h') - rw [foldl_loop_lt, foldl_loop_eq, foldl_loop_eq] -termination_by n - m + rw [foldlM_loop_lt] + congr; funext + rw [foldlM_loop_eq, foldlM_loop_eq] +termination_by n - i + +@[simp] theorem foldlM_zero [Monad m] (f : α → Fin 0 → m α) (x) : foldlM 0 f x = pure x := rfl + +theorem foldlM_succ [Monad m] (f : α → Fin (n+1) → m α) (x) : + foldlM (n+1) f x = f x 0 >>= foldlM n (fun x j => f x j.succ) := foldlM_loop .. + +theorem foldlM_eq_foldlM_list [Monad m] (f : α → Fin n → m α) (x) : + foldlM n f x = (list n).foldlM f x := by + induction n generalizing x with + | zero => rfl + | succ n ih => + rw [foldlM_succ, list_succ, List.foldlM_cons] + congr; funext + rw [List.foldlM_map, ih] + +/-! ### foldrM -/ + +theorem foldrM_loop_zero [Monad m] (f : Fin n → α → m α) (x) : + foldrM.loop n f ⟨0, Nat.zero_le _⟩ x = pure x := rfl + +theorem foldrM_loop_succ [Monad m] (f : Fin n → α → m α) (x) (h : i < n) : + foldrM.loop n f ⟨i+1, h⟩ x = f ⟨i, h⟩ x >>= foldrM.loop n f ⟨i, Nat.le_of_lt h⟩ := rfl + +theorem foldrM_loop [Monad m] [LawfulMonad m] (f : Fin (n+1) → α → m α) (x) (h : i+1 ≤ n+1) : + foldrM.loop (n+1) f ⟨i+1, h⟩ x = + foldrM.loop n (fun j => f j.succ) ⟨i, Nat.le_of_succ_le_succ h⟩ x >>= f 0 := by + induction i generalizing x with + | zero => + rw [foldrM_loop_zero, foldrM_loop_succ, pure_bind] + conv => rhs; rw [←bind_pure (f 0 x)] + congr + | succ i ih => + rw [foldrM_loop_succ, foldrM_loop_succ, bind_assoc] + congr; funext; exact ih .. + +@[simp] theorem foldrM_zero [Monad m] (f : Fin 0 → α → m α) (x) : foldrM 0 f x = pure x := rfl + +theorem foldrM_succ [Monad m] [LawfulMonad m] (f : Fin (n+1) → α → m α) (x) : + foldrM (n+1) f x = foldrM n (fun i => f i.succ) x >>= f 0 := foldrM_loop .. + +theorem foldrM_eq_foldrM_list [Monad m] [LawfulMonad m] (f : Fin n → α → m α) (x) : + foldrM n f x = (list n).foldrM f x := by + induction n with + | zero => rfl + | succ n ih => rw [foldrM_succ, ih, list_succ, List.foldrM_cons, List.foldrM_map] + +/-! ### foldl/foldr -/ @[simp] theorem foldl_zero (f : α → Fin 0 → α) (x) : foldl 0 f x = x := rfl theorem foldl_succ (f : α → Fin (n+1) → α) (x) : - foldl (n+1) f x = foldl n (fun x i => f x i.succ) (f x 0) := foldl_loop .. + foldl (n+1) f x = foldl n (fun x i => f x i.succ) (f x 0) := foldlM_succ .. theorem foldl_succ_last (f : α → Fin (n+1) → α) (x) : foldl (n+1) f x = f (foldl n (f · ·.castSucc) x) (last n) := by @@ -87,24 +139,10 @@ theorem foldl_eq_foldl_list (f : α → Fin n → α) (x) : foldl n f x = (list | zero => rfl | succ n ih => rw [foldl_succ, ih, list_succ, List.foldl_cons, List.foldl_map] -/-! ### foldr -/ - -theorem foldr_loop_zero (f : Fin n → α → α) (x) : foldr.loop n f ⟨0, Nat.zero_le _⟩ x = x := rfl - -theorem foldr_loop_succ (f : Fin n → α → α) (x) (h : m < n) : - foldr.loop n f ⟨m+1, h⟩ x = foldr.loop n f ⟨m, Nat.le_of_lt h⟩ (f ⟨m, h⟩ x) := rfl - -theorem foldr_loop (f : Fin (n+1) → α → α) (x) (h : m+1 ≤ n+1) : - foldr.loop (n+1) f ⟨m+1, h⟩ x = - f 0 (foldr.loop n (fun i => f i.succ) ⟨m, Nat.le_of_succ_le_succ h⟩ x) := by - induction m generalizing x with - | zero => simp [foldr_loop_zero, foldr_loop_succ] - | succ m ih => rw [foldr_loop_succ, ih]; rfl - @[simp] theorem foldr_zero (f : Fin 0 → α → α) (x) : foldr 0 f x = x := rfl theorem foldr_succ (f : Fin (n+1) → α → α) (x) : - foldr (n+1) f x = f 0 (foldr n (fun i => f i.succ) x) := foldr_loop .. + foldr (n+1) f x = f 0 (foldr n (fun i => f i.succ) x) := foldrM_succ .. theorem foldr_succ_last (f : Fin (n+1) → α → α) (x) : foldr (n+1) f x = foldr n (f ·.castSucc) (f (last n) x) := by @@ -118,8 +156,6 @@ theorem foldr_eq_foldr_list (f : Fin n → α → α) (x) : foldr n f x = (list | zero => rfl | succ n ih => rw [foldr_succ, ih, list_succ, List.foldr_cons, List.foldr_map] -/-! ### foldl/foldr -/ - theorem foldl_rev (f : Fin n → α → α) (x) : foldl n (fun x i => f i.rev x) x = foldr n f x := by induction n generalizing x with