From 517399b5da644d2ffaba6887a387860f9f0b451a Mon Sep 17 00:00:00 2001 From: "F. G. Dorais" Date: Wed, 29 May 2024 21:05:26 -0400 Subject: [PATCH] feat: add `map` and `toArray` --- Batteries/Data/DArray/Basic.lean | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/Batteries/Data/DArray/Basic.lean b/Batteries/Data/DArray/Basic.lean index 17d9e29f91..abfc1e3bdb 100644 --- a/Batteries/Data/DArray/Basic.lean +++ b/Batteries/Data/DArray/Basic.lean @@ -78,6 +78,14 @@ private unsafe def popImpl (a : DArray (n+1) α) : DArray n fun i => α i.castSu private unsafe def copyImpl (a : DArray n α) : DArray n α := unsafeCast <| a.data.extract 0 n +private unsafe def toArrayImpl (a : DArray n fun _ => α) : Array α := + unsafeCast a + +@[specialize] +private unsafe def mapImpl (f : {i : Fin n} → α i → β i) (a : DArray n α) : DArray n β := + let f := fun i x => (unsafeCast (f (i:=i.cast lcProof) (unsafeCast x)) : NonScalar) + unsafeCast <| a.data.mapIdx f + private unsafe def foldlMImpl [Monad m] (a : DArray n α) (f : β → {i : Fin n} → α i → m β) (init : β) : m β := if n < USize.size then @@ -185,6 +193,16 @@ protected def push (a : DArray n α) (v : β) : protected def pop (a : DArray (n+1) α) : DArray n fun i => α i.castSucc := mk fun i => a.get i.castSucc +/-- Cast a dependent array with constant types to an array. `O(1)` if exclusive else `O(n)`. -/ +@[implemented_by toArrayImpl] +protected def toArray (a : DArray n fun _ => α) : Array α := + .ofFn fun i => a.get i + +/-- Applies `f` to each element of a dependent array, returns the array of results. -/ +@[implemented_by mapImpl] +protected def map (f : {i : Fin n} → α i → β i) (a : DArray n α) : DArray n β := + mk fun i => f (a.get i) + /-- Folds a monadic function over a `DArray` from left to right: ```