diff --git a/Batteries.lean b/Batteries.lean index f56659960b..c987926b24 100644 --- a/Batteries.lean +++ b/Batteries.lean @@ -20,6 +20,7 @@ import Batteries.Data.BitVec import Batteries.Data.Bool import Batteries.Data.ByteArray import Batteries.Data.Char +import Batteries.Data.DArray import Batteries.Data.DList import Batteries.Data.Fin import Batteries.Data.HashMap diff --git a/Batteries/Data/DArray.lean b/Batteries/Data/DArray.lean new file mode 100644 index 0000000000..65bab05079 --- /dev/null +++ b/Batteries/Data/DArray.lean @@ -0,0 +1,2 @@ +import Batteries.Data.DArray.Basic +import Batteries.Data.DArray.Lemmas diff --git a/Batteries/Data/DArray/Basic.lean b/Batteries/Data/DArray/Basic.lean new file mode 100644 index 0000000000..31f2fdfbb7 --- /dev/null +++ b/Batteries/Data/DArray/Basic.lean @@ -0,0 +1,263 @@ +/- +Copyright (c) 2024 François G. Dorais. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: François G. Dorais +-/ + +import Batteries.Data.Fin.Basic + +namespace Batteries + +/-! +# Dependent Arrays + +`DArray` is a heterogenous array where the type of each item depends on the index. The model +for this type is the dependent function type `(i : Fin n) → α i` where `α i` is the type assigned +to items at index `i`. + +The implementation of `DArray` is based on Lean's dynamic array type. This means that the array +values are stored in a contiguous memory region and can be accessed in constant time. Lean's arrays +also support destructive updates when the array is exclusive (RC=1). + +### Implementation Details + +Lean's array API does not directly support dependent arrays. Each `DArray n α` is internally stored +as an `Array NonScalar` with length `n`. This is sound since Lean's array implementation does not +record nor use the type of the items stored in the array. So it is safe to use `UnsafeCast` to +convert array items to the appropriate type when necessary. +-/ + +/-- `DArray` is a heterogenous array where the type of each item depends on the index. -/ +-- TODO: Use a structure once [lean4#2292](https://github.com/leanprover/lean4/pull/2292) is fixed. +inductive DArray (n) (α : Fin n → Type _) where + /-- Makes a new `DArray` with given item values. `O(n*g)` where `get i` is `O(g)`. -/ + | mk (get : (i : Fin n) → α i) + +namespace DArray + +section unsafe_implementation + +private unsafe abbrev data : DArray n α → Array NonScalar := unsafeCast + +private unsafe def mkImpl (get : (i : Fin n) → α i) : DArray n α := + unsafeCast <| Array.ofFn fun i => (unsafeCast (get i) : NonScalar) + +private unsafe def getImpl (a : DArray n α) (i) : α i := + unsafeCast <| a.data.get ⟨i.val, lcProof⟩ + +private unsafe def ugetImpl (a : DArray n α) (i : USize) (h : i.toNat < n) : α ⟨i.toNat, h⟩ := + unsafeCast <| a.data.uget i lcProof + +private unsafe def setImpl (a : DArray n α) (i) (v : α i) : DArray n α := + unsafeCast <| a.data.set ⟨i.val, lcProof⟩ <| unsafeCast v + +private unsafe def usetImpl (a : DArray n α) (i : USize) (h : i.toNat < n) (v : α ⟨i.toNat, h⟩) : + DArray n α := unsafeCast <| a.data.uset i (unsafeCast v) lcProof + +private unsafe def modifyFImpl [Functor f] (a : DArray n α) (i : Fin n) + (t : α i → f (α i)) : f (DArray n α) := + let v := unsafeCast <| a.data.get ⟨i.val, lcProof⟩ + -- Make sure `v` is unshared, if possible, by replacing its array entry by `box(0)`. + let a := unsafeCast <| a.data.set ⟨i.val, lcProof⟩ (unsafeCast ()) + setImpl a i <$> t v + +private unsafe def umodifyFImpl [Functor f] (a : DArray n α) (i : USize) (h : i.toNat < n) + (t : α ⟨i.toNat, h⟩ → f (α ⟨i.toNat, h⟩)) : f (DArray n α) := + let v := unsafeCast <| a.data.uget i lcProof + -- Make sure `v` is unshared, if possible, by replacing its array entry by `box(0)`. + let a := unsafeCast <| a.data.uset i (unsafeCast ()) lcProof + usetImpl a i h <$> t v + +private unsafe def pushImpl (a : DArray n α) (v : β) : + DArray (n+1) fun i => if h : i.val < n then α ⟨i.val, h⟩ else β := + unsafeCast <| a.data.push <| unsafeCast v + +private unsafe def popImpl (a : DArray (n+1) α) : DArray n fun i => α i.castSucc := + unsafeCast <| a.data.pop + +private unsafe def copyImpl (a : DArray n α) : DArray n α := + unsafeCast <| a.data.extract 0 n + +private unsafe def foldlMImpl [Monad m] (a : DArray n α) (f : β → {i : Fin n} → α i → m β) + (init : β) : m β := + if n < USize.size then + loop 0 init + else + have : Inhabited β := ⟨init⟩ + -- array data exceeds the entire address space! + panic! "out of memory" +where + -- loop invariant: `i.toNat ≤ n` + loop (i : USize) (x : β) : m β := + if i.toNat ≥ n then pure x else + f x (a.ugetImpl i lcProof) >>= loop (i+1) + +private unsafe def foldrMImpl [Monad m] (a : DArray n α) (f : {i : Fin n} → α i → β → m β) + (init : β) : m β := + if h : n < USize.size then + loop (.ofNatCore n h) init + else + have : Inhabited β := ⟨init⟩ + -- array data exceeds the entire address space! + panic! "out of memory" +where + -- loop invariant: `i.toNat ≤ n` + loop (i : USize) (x : β) : m β := + if i = 0 then pure x else f (a.ugetImpl (i-1) lcProof) x >>= loop (i-1) + +@[specialize] +private unsafe def mapImpl (f : {i : Fin n} → α i → β i) (a : DArray n α) : DArray n β := + let f := fun i x => (unsafeCast (f (i:=i.cast lcProof) (unsafeCast x)) : NonScalar) + unsafeCast <| a.data.mapIdx f + +@[specialize] +private unsafe def amapImpl (f : {i : Fin n} → α i → β) (a : DArray n α) : Array β := + unsafeCast <| a.mapImpl f + +end unsafe_implementation + +attribute [implemented_by mkImpl] DArray.mk + +instance (α : Fin n → Type _) [(i : Fin n) → Inhabited (α i)] : Inhabited (DArray n α) where + default := mk fun _ => default + +/-- Gets the `DArray` item at index `i`. `O(1)`. -/ +@[implemented_by getImpl] +protected def get : DArray n α → (i : Fin n) → α i + | mk get => get + +@[simp, inherit_doc DArray.get] +protected abbrev getN (a : DArray n α) (i) (h : i < n := by get_elem_tactic) : α ⟨i, h⟩ := + a.get ⟨i, h⟩ + +/-- Gets the `DArray` item at index `i : USize`. Slightly faster than `get`; `O(1)`. -/ +@[implemented_by ugetImpl] +protected def uget (a : DArray n α) (i : USize) (h : i.toNat < n) : α ⟨i.toNat, h⟩ := + a.get ⟨i.toNat, h⟩ + +private def casesOnImpl.{u} {motive : DArray n α → Sort u} (a : DArray n α) + (h : (get : (i : Fin n) → α i) → motive (.mk get)) : motive a := + h a.get + +attribute [implemented_by casesOnImpl] DArray.casesOn + +/-- Sets the `DArray` item at index `i`. `O(1)` if exclusive else `O(n)`. -/ +@[implemented_by setImpl] +protected def set (a : DArray n α) (i : Fin n) (v : α i) : DArray n α := + mk fun j => if h : i = j then h ▸ v else a.get j + +/-- +Sets the `DArray` item at index `i : USize`. +Slightly faster than `set` and `O(1)` if exclusive else `O(n)`. +-/ +@[implemented_by usetImpl] +protected def uset (a : DArray n α) (i : USize) (h : i.toNat < n) (v : α ⟨i.toNat, h⟩) := + a.set ⟨i.toNat, h⟩ v + +@[simp, inherit_doc DArray.set] +protected abbrev setN (a : DArray n α) (i) (h : i < n := by get_elem_tactic) (v : α ⟨i, h⟩) := + a.set ⟨i, h⟩ v + +/-- Modifies the `DArray` item at index `i` using transform `t` and the functor `f`. -/ +@[implemented_by modifyFImpl] +protected def modifyF [Functor f] (a : DArray n α) (i : Fin n) + (t : α i → f (α i)) : f (DArray n α) := a.set i <$> t (a.get i) + +/-- Modifies the `DArray` item at index `i` using transform `t`. -/ +@[inline] +protected def modify (a : DArray n α) (i : Fin n) (t : α i → α i) : DArray n α := + a.modifyF (f:=Id) i t + +/-- Modifies the `DArray` item at index `i : USize` using transform `t` and the functor `f`. -/ +@[implemented_by umodifyFImpl] +protected def umodifyF [Functor f] (a : DArray n α) (i : USize) (h : i.toNat < n) + (t : α ⟨i.toNat, h⟩ → f (α ⟨i.toNat, h⟩)) : f (DArray n α) := a.uset i h <$> t (a.uget i h) + +/-- Modifies the `DArray` item at index `i : USize` using transform `t`. -/ +@[inline] +protected def umodify (a : DArray n α) (i : USize) (h : i.toNat < n) + (t : α ⟨i.toNat, h⟩ → α ⟨i.toNat, h⟩) : DArray n α := + a.umodifyF (f:=Id) i h t + +/-- Copies the `DArray` to an exclusive `DArray`. `O(1)` if exclusive else `O(n)`. -/ +@[implemented_by copyImpl] +protected def copy (a : DArray n α) : DArray n α := mk a.get + +/-- Push an element onto the end of a `DArray`. `O(1)` if exclusive else `O(n)`. -/ +@[implemented_by pushImpl] +protected def push (a : DArray n α) (v : β) : + DArray (n+1) fun i => if h : i.val < n then α ⟨i.val, h⟩ else β := + mk fun i => if h : i.val < n then dif_pos h ▸ a.get ⟨i.val, h⟩ else dif_neg h ▸ v + +/-- Delete the last item of a `DArray`. `O(1)`. -/ +@[implemented_by popImpl] +protected def pop (a : DArray (n+1) α) : DArray n fun i => α i.castSucc := + mk fun i => a.get i.castSucc + +/-- +Folds a monadic function over a `DArray` from left to right: +``` +DArray.foldlM a f x₀ = do + let x₁ ← f x₀ (a.get 0) + let x₂ ← f x₁ (a.get 1) + ... + let xₙ ← f xₙ₋₁ (a.get (n-1)) + pure xₙ +``` +-/ +@[implemented_by foldlMImpl] +def foldlM [Monad m] (a : DArray n α) (f : β → {i : Fin n} → α i → m β) (init : β) : m β := + Fin.foldlM n (f · <| a.get ·) init + +/-- Folds a function over a `DArray` from the left. -/ +def foldl (a : DArray n α) (f : β → {i : Fin n} → α i → β) (init : β) : β := + a.foldlM (m:=Id) f init + +/-- +Folds a monadic function over a `DArray` from right to left: +``` +DArray.foldrM a f x₀ = do + let x₁ ← f (a.get (n-1)) x₀ + let x₂ ← f (a.get (n-2)) x₁ + ... + let xₙ ← f (a.get 0) xₙ₋₁ + pure xₙ +``` +-/ +def foldrM [Monad m] (a : DArray n α) (f : {i : Fin n} → α i → β → m β) (init : β) : m β := + Fin.foldrM n (f <| a.get ·) init + +/-- Folds a function over a `DArray` from the right. -/ +def foldr (a : DArray n α) (f : {i : Fin n} → α i → β → β) (init : β) : β := + a.foldrM (m:=Id) f init + +/-- Implementation of `ForIn` for `DArray`. -/ +@[specialize] +def forIn [Monad m] (a : DArray n α) (init : β) (f : Sigma α → β → m (ForInStep β)) : m β := do + match ← a.foldlM step (.yield init) with + | .done r => pure r + | .yield r => pure r +where + /-- Step function for `forIn`. -/ + step : ForInStep β → {i : Fin n} → α i → m (ForInStep β) + | .done r, _, _ => pure (.done r) + | .yield r, i, x => f ⟨i, x⟩ r + +instance (m : Type _ → Type _) (α : Fin n → Type _) : ForIn m (DArray n α) (Sigma α) where + forIn := forIn + +/-- +Applies `f : {i : Fin n} → α i → β i` to each element of a `DArray n α`, +returns the dependent array of results. +-/ +@[implemented_by mapImpl] +protected def map (f : {i : Fin n} → α i → β i) (a : DArray n α) : DArray n β := + mk fun i => f (a.get i) + +/-- +Applies `f : {i : Fin n} → α i → β` to each element of a `DArray n α`, +returns the (non-dependent) array of results. +-/ +@[implemented_by amapImpl] +def amap (f : {i : Fin n} → α i → β) (a : DArray n α) : Array β := + Array.ofFn fun i => f (a.get i) diff --git a/Batteries/Data/DArray/Lemmas.lean b/Batteries/Data/DArray/Lemmas.lean new file mode 100644 index 0000000000..551c7d17ab --- /dev/null +++ b/Batteries/Data/DArray/Lemmas.lean @@ -0,0 +1,98 @@ +/- +Copyright (c) 2024 François G. Dorais. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: François G. Dorais +-/ + +import Batteries.Data.DArray.Basic + +namespace Batteries.DArray + +@[ext] +protected theorem ext : {a b : DArray n α} → (∀ i, a.get i = b.get i) → a = b + | mk _, mk _, h => congrArg _ <| funext fun i => h i + +@[simp] +theorem get_mk (i : Fin n) : DArray.get (.mk init) i = init i := rfl + +theorem set_mk {α : Fin n → Type _} {init : (i : Fin n) → α i} (i : Fin n) (v : α i) : + DArray.set (.mk init) i v = .mk fun j => if h : i = j then h ▸ v else init j := rfl + +@[simp] +theorem get_set (a : DArray n α) (i : Fin n) (v : α i) : (a.set i v).get i = v := by + simp only [DArray.get, DArray.set, dif_pos] + +theorem get_set_ne (a : DArray n α) (v : α i) (h : i ≠ j) : (a.set i v).get j = a.get j := by + simp only [DArray.get, DArray.set, dif_neg h] + +@[simp] +theorem set_set (a : DArray n α) (i : Fin n) (v w : α i) : (a.set i v).set i w = a.set i w := by + ext j + if h : i = j then + rw [← h, get_set, get_set] + else + rw [get_set_ne _ _ h, get_set_ne _ _ h, get_set_ne _ _ h] + +theorem get_modifyF [Functor f] [LawfulFunctor f] (a : DArray n α) (i : Fin n) (t : α i → f (α i)) : + (DArray.get . i) <$> a.modifyF i t = t (a.get i) := by + simp [DArray.modifyF, ← comp_map] + conv => rhs; rw [← id_map (t (a.get i))] + congr; ext; simp + +@[simp] +theorem get_modify (a : DArray n α) (i : Fin n) (t : α i → α i) : + (a.modify i t).get i = t (a.get i) := get_modifyF (f:=Id) a i t + +theorem get_modify_ne (a : DArray n α) (t : α i → α i) (h : i ≠ j) : + (a.modify i t).get j = a.get j := get_set_ne _ _ h + +@[simp] +theorem set_modify (a : DArray n α) (i : Fin n) (t : α i → α i) (v : α i) : + (a.set i v).modify i t = a.set i (t v) := by + ext j + if h : i = j then + cases h; simp + else + simp [h, get_modify_ne, get_set_ne] + +@[simp] +theorem uget_eq_get (a : DArray n α) (i : USize) (h : i.toNat < n) : + a.uget i h = a.get ⟨i.toNat, h⟩ := rfl + +@[simp] +theorem uset_eq_set (a : DArray n α) (i : USize) (h : i.toNat < n) (v : α ⟨i.toNat, h⟩) : + a.uset i h v = a.set ⟨i.toNat, h⟩ v := rfl + +@[simp] +theorem umodifyF_eq_modifyF [Functor f] (a : DArray n α) (i : USize) (h : i.toNat < n) + (t : α ⟨i.toNat, h⟩ → f (α ⟨i.toNat, h⟩)) : a.umodifyF i h t = a.modifyF ⟨i.toNat, h⟩ t := rfl + +@[simp] +theorem umodify_eq_modify (a : DArray n α) (i : USize) (h : i.toNat < n) + (t : α ⟨i.toNat, h⟩ → α ⟨i.toNat, h⟩) : a.umodify i h t = a.modify ⟨i.toNat, h⟩ t := rfl + +theorem foldlM_eq_fin_foldlM [Monad m] (a : DArray n α) (f : β → {i : Fin n} → α i → m β) (init) : + a.foldlM f init = Fin.foldlM n (f · <| a.get ·) init := rfl + +theorem foldl_eq_foldlM (a : DArray n α) (f : β → {i : Fin n} → α i → β) (init) : + a.foldl f init = a.foldlM (m:=Id) f init := rfl + +theorem foldrM_eq_fin_foldrM [Monad m] (a : DArray n α) (f : {i : Fin n} → α i → β → m β) (init) : + a.foldrM f init = Fin.foldrM n (f <| a.get ·) init := rfl + +theorem foldr_eq_foldrM (a : DArray n α) (f : {i : Fin n} → α i → β → β) (init) : + a.foldr f init = a.foldrM (m:=Id) f init := rfl + +@[simp] +theorem copy_eq (a : DArray n α) : a.copy = a := rfl + +theorem get_map {β : Fin n → Type _} (f : {i : Fin n} → α i → β i) (a : DArray n α) (i : Fin n) : + (a.map f).get i = f (a.get i) := rfl + +@[simp] +theorem size_amap (f : {i : Fin n} → α i → β) (a : DArray n α) : + (a.amap f).size = n := Array.size_ofFn .. + +theorem getElem_amap (f : {i : Fin n} → α i → β) (a : DArray n α) (i : Fin n) + (h : i.val < (a.amap f).size := (a.size_amap f).symm ▸ i.is_lt) : + (a.amap f)[i] = f (a.get i) := by simp [amap] diff --git a/Batteries/Data/Fin/Basic.lean b/Batteries/Data/Fin/Basic.lean index 495344f90f..0651aac9af 100644 --- a/Batteries/Data/Fin/Basic.lean +++ b/Batteries/Data/Fin/Basic.lean @@ -15,16 +15,66 @@ def enum (n) : Array (Fin n) := Array.ofFn id /-- `list n` is the list of all elements of `Fin n` in order -/ def list (n) : List (Fin n) := (enum n).data -/-- Folds over `Fin n` from the left: `foldl 3 f x = f (f (f x 0) 1) 2`. -/ -@[inline] def foldl (n) (f : α → Fin n → α) (init : α) : α := loop init 0 where - /-- Inner loop for `Fin.foldl`. `Fin.foldl.loop n f x i = f (f (f x i) ...) (n-1)` -/ - loop (x : α) (i : Nat) : α := - if h : i < n then loop (f x ⟨i, h⟩) (i+1) else x +/-- +Folds a monadic function over `Fin n` from left to right: +``` +Fin.foldlM n f x₀ = do + let x₁ ← f x₀ 0 + let x₂ ← f x₁ 1 + ... + let xₙ ← f xₙ₋₁ (n-1) + pure xₙ +``` +-/ +@[inline] def foldlM [Monad m] (n) (f : α → Fin n → m α) (init : α) : m α := loop init 0 where + /-- + Inner loop for `Fin.foldlM`. + ``` + Fin.foldlM.loop n f xᵢ i = do + let xᵢ₊₁ ← f xᵢ i + ... + let xₙ ← f xₙ₋₁ (n-1) + pure xₙ + ``` + -/ + loop (x : α) (i : Nat) : m α := do + if h : i < n then f x ⟨i, h⟩ >>= (loop · (i+1)) else pure x termination_by n - i +/-- +Folds a monadic function over `Fin n` from right to left: +``` +Fin.foldrM n f xₙ = do + let xₙ₋₁ ← f (n-1) xₙ + let xₙ₋₂ ← f (n-2) xₙ₋₁ + ... + let x₀ ← f 0 x₁ + pure x₀ +``` +-/ +@[inline] def foldrM [Monad m] (n) (f : Fin n → α → m α) (init : α) : m α := + loop ⟨n, Nat.le_refl n⟩ init where + /-- + Inner loop for `Fin.foldrM`. + ``` + Fin.foldrM.loop n f i xᵢ = do + let xᵢ₋₁ ← f (i-1) xᵢ + ... + let x₁ ← f 1 x₂ + let x₀ ← f 0 x₁ + pure x₀ + ``` + -/ + loop : {i // i ≤ n} → α → m α + | ⟨0, _⟩, x => pure x + | ⟨i+1, h⟩, x => f ⟨i, h⟩ x >>= loop ⟨i, Nat.le_of_lt h⟩ + +-- These are also defined in core in the root namespace! + /-- Folds over `Fin n` from the right: `foldr 3 f x = f 0 (f 1 (f 2 x))`. -/ -@[inline] def foldr (n) (f : Fin n → α → α) (init : α) : α := loop ⟨n, Nat.le_refl n⟩ init where - /-- Inner loop for `Fin.foldr`. `Fin.foldr.loop n f i x = f 0 (f ... (f (i-1) x))` -/ - loop : {i // i ≤ n} → α → α - | ⟨0, _⟩, x => x - | ⟨i+1, h⟩, x => loop ⟨i, Nat.le_of_lt h⟩ (f ⟨i, h⟩ x) +@[inline] def foldr (n) (f : Fin n → α → α) (init : α) : α := + foldrM (m:=Id) n f init + +/-- Folds over `Fin n` from the left: `foldl 3 f x = f (f (f x 0) 1) 2`. -/ +@[inline] def foldl (n) (f : α → Fin n → α) (init : α) : α := + foldlM (m:=Id) n f init diff --git a/Batteries/Data/Fin/Lemmas.lean b/Batteries/Data/Fin/Lemmas.lean index 7167efa0ed..1510ea7cb3 100644 --- a/Batteries/Data/Fin/Lemmas.lean +++ b/Batteries/Data/Fin/Lemmas.lean @@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Mario Carneiro -/ import Batteries.Data.Fin.Basic +import Batteries.Data.List.Lemmas namespace Fin @@ -52,28 +53,79 @@ theorem list_reverse (n) : (list n).reverse = (list n).map rev := by conv => rhs; rw [list_succ] simp [List.reverse_map, ih, Function.comp_def, rev_succ] -/-! ### foldl -/ +/-! ### foldlM -/ -theorem foldl_loop_lt (f : α → Fin n → α) (x) (h : m < n) : - foldl.loop n f x m = foldl.loop n f (f x ⟨m, h⟩) (m+1) := by - rw [foldl.loop, dif_pos h] +theorem foldlM_loop_lt [Monad m] (f : α → Fin n → m α) (x) (h : i < n) : + foldlM.loop n f x i = f x ⟨i, h⟩ >>= (foldlM.loop n f . (i+1)) := by + rw [foldlM.loop, dif_pos h] -theorem foldl_loop_eq (f : α → Fin n → α) (x) : foldl.loop n f x n = x := by - rw [foldl.loop, dif_neg (Nat.lt_irrefl _)] +theorem foldlM_loop_eq [Monad m] (f : α → Fin n → m α) (x) : foldlM.loop n f x n = pure x := by + rw [foldlM.loop, dif_neg (Nat.lt_irrefl _)] -theorem foldl_loop (f : α → Fin (n+1) → α) (x) (h : m < n+1) : - foldl.loop (n+1) f x m = foldl.loop n (fun x i => f x i.succ) (f x ⟨m, h⟩) m := by - if h' : m < n then - rw [foldl_loop_lt _ _ h, foldl_loop_lt _ _ h', foldl_loop]; rfl +theorem foldlM_loop [Monad m] (f : α → Fin (n+1) → m α) (x) (h : i < n+1) : + foldlM.loop (n+1) f x i = f x ⟨i, h⟩ >>= (foldlM.loop n (fun x j => f x j.succ) . i) := by + if h' : i < n then + rw [foldlM_loop_lt _ _ h] + congr; funext + rw [foldlM_loop_lt _ _ h', foldlM_loop]; rfl else cases Nat.le_antisymm (Nat.le_of_lt_succ h) (Nat.not_lt.1 h') - rw [foldl_loop_lt, foldl_loop_eq, foldl_loop_eq] -termination_by n - m + rw [foldlM_loop_lt] + congr; funext + rw [foldlM_loop_eq, foldlM_loop_eq] +termination_by n - i + +@[simp] theorem foldlM_zero [Monad m] (f : α → Fin 0 → m α) (x) : foldlM 0 f x = pure x := rfl + +theorem foldlM_succ [Monad m] (f : α → Fin (n+1) → m α) (x) : + foldlM (n+1) f x = f x 0 >>= foldlM n (fun x j => f x j.succ) := foldlM_loop .. + +theorem foldlM_eq_foldlM_list [Monad m] (f : α → Fin n → m α) (x) : + foldlM n f x = (list n).foldlM f x := by + induction n generalizing x with + | zero => rfl + | succ n ih => + rw [foldlM_succ, list_succ, List.foldlM_cons] + congr; funext + rw [List.foldlM_map, ih] + +/-! ### foldrM -/ + +theorem foldrM_loop_zero [Monad m] (f : Fin n → α → m α) (x) : + foldrM.loop n f ⟨0, Nat.zero_le _⟩ x = pure x := rfl + +theorem foldrM_loop_succ [Monad m] (f : Fin n → α → m α) (x) (h : i < n) : + foldrM.loop n f ⟨i+1, h⟩ x = f ⟨i, h⟩ x >>= foldrM.loop n f ⟨i, Nat.le_of_lt h⟩ := rfl + +theorem foldrM_loop [Monad m] [LawfulMonad m] (f : Fin (n+1) → α → m α) (x) (h : i+1 ≤ n+1) : + foldrM.loop (n+1) f ⟨i+1, h⟩ x = + foldrM.loop n (fun j => f j.succ) ⟨i, Nat.le_of_succ_le_succ h⟩ x >>= f 0 := by + induction i generalizing x with + | zero => + rw [foldrM_loop_zero, foldrM_loop_succ, pure_bind] + conv => rhs; rw [←bind_pure (f 0 x)] + congr + | succ i ih => + rw [foldrM_loop_succ, foldrM_loop_succ, bind_assoc] + congr; funext; exact ih .. + +@[simp] theorem foldrM_zero [Monad m] (f : Fin 0 → α → m α) (x) : foldrM 0 f x = pure x := rfl + +theorem foldrM_succ [Monad m] [LawfulMonad m] (f : Fin (n+1) → α → m α) (x) : + foldrM (n+1) f x = foldrM n (fun i => f i.succ) x >>= f 0 := foldrM_loop .. + +theorem foldrM_eq_foldrM_list [Monad m] [LawfulMonad m] (f : Fin n → α → m α) (x) : + foldrM n f x = (list n).foldrM f x := by + induction n with + | zero => rfl + | succ n ih => rw [foldrM_succ, ih, list_succ, List.foldrM_cons, List.foldrM_map] + +/-! ### foldl/foldr -/ @[simp] theorem foldl_zero (f : α → Fin 0 → α) (x) : foldl 0 f x = x := rfl theorem foldl_succ (f : α → Fin (n+1) → α) (x) : - foldl (n+1) f x = foldl n (fun x i => f x i.succ) (f x 0) := foldl_loop .. + foldl (n+1) f x = foldl n (fun x i => f x i.succ) (f x 0) := foldlM_succ .. theorem foldl_succ_last (f : α → Fin (n+1) → α) (x) : foldl (n+1) f x = f (foldl n (f · ·.castSucc) x) (last n) := by @@ -87,24 +139,10 @@ theorem foldl_eq_foldl_list (f : α → Fin n → α) (x) : foldl n f x = (list | zero => rfl | succ n ih => rw [foldl_succ, ih, list_succ, List.foldl_cons, List.foldl_map] -/-! ### foldr -/ - -theorem foldr_loop_zero (f : Fin n → α → α) (x) : foldr.loop n f ⟨0, Nat.zero_le _⟩ x = x := rfl - -theorem foldr_loop_succ (f : Fin n → α → α) (x) (h : m < n) : - foldr.loop n f ⟨m+1, h⟩ x = foldr.loop n f ⟨m, Nat.le_of_lt h⟩ (f ⟨m, h⟩ x) := rfl - -theorem foldr_loop (f : Fin (n+1) → α → α) (x) (h : m+1 ≤ n+1) : - foldr.loop (n+1) f ⟨m+1, h⟩ x = - f 0 (foldr.loop n (fun i => f i.succ) ⟨m, Nat.le_of_succ_le_succ h⟩ x) := by - induction m generalizing x with - | zero => simp [foldr_loop_zero, foldr_loop_succ] - | succ m ih => rw [foldr_loop_succ, ih]; rfl - @[simp] theorem foldr_zero (f : Fin 0 → α → α) (x) : foldr 0 f x = x := rfl theorem foldr_succ (f : Fin (n+1) → α → α) (x) : - foldr (n+1) f x = f 0 (foldr n (fun i => f i.succ) x) := foldr_loop .. + foldr (n+1) f x = f 0 (foldr n (fun i => f i.succ) x) := foldrM_succ .. theorem foldr_succ_last (f : Fin (n+1) → α → α) (x) : foldr (n+1) f x = foldr n (f ·.castSucc) (f (last n) x) := by @@ -118,8 +156,6 @@ theorem foldr_eq_foldr_list (f : Fin n → α → α) (x) : foldr n f x = (list | zero => rfl | succ n ih => rw [foldr_succ, ih, list_succ, List.foldr_cons, List.foldr_map] -/-! ### foldl/foldr -/ - theorem foldl_rev (f : Fin n → α → α) (x) : foldl n (fun x i => f i.rev x) x = foldr n f x := by induction n generalizing x with diff --git a/Batteries/Data/List/Lemmas.lean b/Batteries/Data/List/Lemmas.lean index d23566b21f..d61c2c03d2 100644 --- a/Batteries/Data/List/Lemmas.lean +++ b/Batteries/Data/List/Lemmas.lean @@ -2822,3 +2822,13 @@ theorem lt_antisymm' [LT α] have ab : ¬a < b := fun ab => h₁ (.head _ _ ab) cases lt_antisymm ab (fun ba => h₂ (.head _ _ ba)) rw [ih (fun ll => h₁ (.tail ab ab ll)) (fun ll => h₂ (.tail ab ab ll))] + +/-! ### foldlM and foldrM -/ + +theorem foldlM_map [Monad m] (f : β₁ → β₂) (g : α → β₂ → m α) (l : List β₁) (init : α) : + (l.map f).foldlM g init = l.foldlM (fun x y => g x (f y)) init := by + induction l generalizing g init <;> simp [*] + +theorem foldrM_map [Monad m] [LawfulMonad m] (f : β₁ → β₂) (g : β₂ → α → m α) (l : List β₁) + (init : α) : (l.map f).foldrM g init = l.foldrM (fun x y => g (f x) y) init := by + induction l generalizing g init <;> simp [*] diff --git a/test/darray.lean b/test/darray.lean new file mode 100644 index 0000000000..5f3747d5eb --- /dev/null +++ b/test/darray.lean @@ -0,0 +1,43 @@ +import Batteries.Data.DArray.Basic + +open Batteries + +def foo : DArray 3 fun | 0 => String | 1 => Nat | 2 => Array Nat := + .mk fun | 0 => "foo" | 1 => 42 | 2 => #[4, 2] + +def bar := foo.set 0 "bar" + +#guard foo.get 0 == "foo" +#guard foo.get 1 == 42 +#guard foo.get 2 == #[4, 2] + +#guard (foo.set 1 1).get 0 == "foo" +#guard (foo.set 1 1).get 1 == 1 +#guard (foo.set 1 1).get 2 == #[4, 2] + +#guard bar.get 0 == "bar" +#guard (bar.set 0 (foo.get 0)).get 0 == "foo" +#guard ((bar.set 0 "baz").set 1 1).get 0 == "baz" +#guard ((bar.set 0 "baz").set 0 "foo").get 0 == "foo" +#guard ((bar.set 0 "foo").set 0 "baz").get 0 == "baz" + +def Batteries.DArray.head : DArray (n+1) α → α 0 + | mk f => f 0 + +#guard foo.head == "foo" +#guard bar.head == "bar" + +abbrev Data (n : Nat) : Type _ := DArray (n+1) fun | 0 => String | ⟨_+1,_⟩ => Nat + +def Data.sum (a : Data n) : String × Nat := Id.run do + let mut r := ("", 0) + for ⟨i, x⟩ in a do + match i with + | 0 => r := (x, r.snd) + | ⟨_+1,_⟩ => r := ⟨r.fst, r.snd+x⟩ + return r + +def test : Data 2 := + .mk fun | 0 => "foo" | 1 => 4 | 2 => 2 + +#guard test.sum == ("foo", 6)