Skip to content

Commit

Permalink
Add smapWithBounds
Browse files Browse the repository at this point in the history
Co-authored-by: Christiaan Baaij <[email protected]>
  • Loading branch information
kleinreact and christiaanb committed Oct 11, 2024
1 parent 5706eaf commit e0b8905
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 17 deletions.
2 changes: 2 additions & 0 deletions changelog/2024-04-08T06_51_45+00_00_smap_with_bounds
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
CHANGED: `dfold` now offers a proof witness for the upper bound of the vector size to the folding function. Note that this change may require additional type annotations, as solutions working in the past may complain with an untouchable type error now.
ADDED: `smapWithBounds` extending `smap` via offering a proof witness for the upper bound of the vector size to the mapping function.
22 changes: 19 additions & 3 deletions clash-ghc/src-ghc/Clash/GHC/Evaluator/Primitive.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ module Clash.GHC.Evaluator.Primitive
, isUndefinedXPrimVal
) where

import qualified Control.Lens as Lens
import Control.Concurrent.Supply (Supply,freshId)
import Control.DeepSeq (force)
import Control.Exception (ArithException(..), Exception, tryJust, evaluate)
Expand Down Expand Up @@ -93,24 +94,26 @@ import TysWiredIn (tupleTyCon)
import Clash.Class.BitPack (pack,unpack)
import Clash.Core.DataCon (DataCon (..))
import Clash.Core.Evaluator.Types
import Clash.Core.FreeVars (typeFreeVars)
import Clash.Core.HasType (piResultTys, applyTypeToArgs)
import Clash.Core.Literal (Literal (..))
import Clash.Core.Name
(Name (..), NameSort (..), mkUnsafeSystemName)
import Clash.Core.Pretty (showPpr)
import Clash.Core.Subst (extendTvSubst, mkSubst, substTy)
import Clash.Core.Term
(IsMultiPrim (..), Pat (..), PrimInfo (..), Term (..), WorkInfo (..), mkApps,
PrimUnfolding(..), collectArgs)
import Clash.Core.Type
(Type (..), ConstTy (..), LitTy (..), TypeView (..), mkFunTy, mkTyConApp,
splitFunForallTy, tyView)
normalizeType, splitFunForallTy, tyView)
import Clash.Core.TyCon
(TyConMap, TyConName, tyConDataCons)
import Clash.Core.TysPrim
import Clash.Core.Util
(mkRTree,mkVec,tyNatSize,dataConInstArgTys,primCo, mkSelectorCase,undefinedPrims,
undefinedXPrims)
import Clash.Core.Var (mkLocalId, mkTyVar)
import Clash.Core.Var (mkLocalId, mkTyVar, varName)
import qualified Clash.Data.UniqMap as UniqMap
import Clash.Debug
import Clash.GHC.GHC2Core (modNameM)
Expand Down Expand Up @@ -4370,13 +4373,24 @@ ghcPrimStep tcm isSubj pInfo tys args mach = case primName pInfo of
0 -> reduce (valToTerm z)
_ -> let (tyArgs,_) = splitFunForallTy ty
(tyArgs',_) = splitFunForallTy (Either.rights tyArgs !! 2)
TyConApp snatTcNm _ = tyView (Either.rights tyArgs' !! 0)
ubpT = Either.rights tyArgs' !! 0
fTVs = Lens.toListOf typeFreeVars ubpT
Just tvN = List.find ((== "n") . nameOcc . varName) fTVs
subst0 = extendTvSubst (mkSubst is0) tvN k'ty
Just tvK = List.find ((== "k") . nameOcc . varName) fTVs
subst1 = extendTvSubst subst0 tvK (LitTy (NumTy k'))
witness = normalizeType tcm (substTy subst1 ubpT)
TyConApp tupTcNm _ = tyView witness
Just witnessTc = UniqMap.lookup tupTcNm tcm
ubp : _ = tyConDataCons witnessTc
TyConApp snatTcNm _ = tyView (Either.rights tyArgs' !! 1)
Just snatTc = UniqMap.lookup snatTcNm tcm
[snatDc] = tyConDataCons snatTc
k'ty = LitTy (NumTy (k'-1))
in reduceWHNF $
mkApps (valToTerm f)
[Right k'ty
,Left (Data ubp)
,Left (mkApps (Data snatDc)
[Right k'ty
,Left (Literal (NaturalLiteral (k'-1)))])
Expand All @@ -4392,6 +4406,8 @@ ghcPrimStep tcm isSubj pInfo tys args mach = case primName pInfo of
,Left (Either.lefts vArgs !! 2)
])
]
where
is0 = mScopeNames mach
"Clash.Sized.Vector.dtfold"
| isSubj
, pTy : kTy : aTy : _ <- tys
Expand Down
28 changes: 22 additions & 6 deletions clash-lib/src/Clash/Normalize/PrimitiveReductions.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import Control.Lens ((.=))
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Maybe (MaybeT (..))
import Data.Bifunctor (second)
import Data.List (mapAccumR)
import Data.List (mapAccumR, uncons)
import Data.List.Extra (zipEqual)
#if MIN_VERSION_base(4,20,0)
import qualified Data.List.NonEmpty as NE hiding (unzip)
Expand All @@ -69,17 +69,21 @@ import SrcLoc (wiredInSrcSpan)
#endif

import Clash.Core.DataCon (DataCon)
import Clash.Core.FreeVars (typeFreeVars)
import Clash.Core.HasType

import Clash.Core.Literal (Literal (..))
import Clash.Core.Name
(nameOcc, Name(..), NameSort(User), mkUnsafeSystemName)
import Clash.Core.Pretty (showPpr)
import Clash.Core.Subst (extendTvSubst, mkSubst, substTy)
import Clash.Core.Term
(IsMultiPrim (..), CoreContext (..), PrimInfo (..), Term (..), WorkInfo (..), Pat (..),
collectTermIds, mkApps, PrimUnfolding(..))
import Clash.Core.Type (LitTy (..), Type (..),
TypeView (..), coreView1,
mkFunTy, mkTyConApp,
normalizeType,
splitFunForallTy, tyView)
import Clash.Core.TyCon
(TyConMap, TyConName, tyConDataCons, tyConName)
Expand Down Expand Up @@ -735,22 +739,34 @@ reduceDFold n aTy _kn _motive fun start arg (TransformContext is0 _ctx) = do
(uniqs1,(vars,elems)) = second (second sconcat . NE.unzip)
$ extractElems uniqs0 is1 consCon aTy 'D' n arg
snatDc = Maybe.fromMaybe (error "reduceDFold: faild to build SNat") $ do
(_ltv:Right snTy:_,_) <- pure (splitFunForallTy (inferCoreTypeOf tcm fun))
(_ltv:_rubp:Right snTy:_,_) <- pure (splitFunForallTy (inferCoreTypeOf tcm fun))
(TyConApp snatTcNm _) <- pure (tyView snTy)
snatTc <- UniqMap.lookup snatTcNm tcm
Maybe.listToMaybe (tyConDataCons snatTc)
lbody = doFold (buildSNat snatDc) (n-1) (NE.toList vars)
ubp k = Maybe.fromMaybe
(error "reduceDFold: failed to extract upper bound proof") $ do
(_ltv:Right ubpT:_,_) <- pure (splitFunForallTy (inferCoreTypeOf tcm fun))
-- toListOf does not de-duplicate, but we know that there is only
-- one free variable in here, thus, taking the first element is fine
(tvN, _) <- uncons $ Lens.toListOf typeFreeVars ubpT
let subst = extendTvSubst (mkSubst is0) tvN (LitTy (NumTy k))
let witness = normalizeType tcm (substTy subst ubpT)
(TyConApp tupTcNm _) <- pure (tyView witness)
witnessTc <- UniqMap.lookup tupTcNm tcm
Maybe.listToMaybe (tyConDataCons witnessTc)
lbody = doFold ubp (buildSNat snatDc) (n-1) (NE.toList vars)
lb = Letrec (NE.init elems) lbody
uniqSupply Lens..= uniqs1
changed lb
go _ ty = error $ $(curLoc) ++ "reduceDFold: argument does not have a vector type: " ++ showPpr ty

doFold _ _ [] = start
doFold snDc k (x:xs) = mkApps fun
doFold _ _ _ [] = start
doFold ubp snDc k (x:xs) = mkApps fun
[Right (LitTy (NumTy k))
,Left (Data (ubp k))
,Left (snDc k)
,Left x
,Left (doFold snDc (k-1) xs)
,Left (doFold ubp snDc (k-1) xs)
]

-- | Replace an application of the @Clash.Sized.Vector.head@ primitive on
Expand Down
33 changes: 26 additions & 7 deletions clash-prelude/src/Clash/Sized/Vector.hs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ module Clash.Sized.Vector
, rotateLeft, rotateRight, rotateLeftS, rotateRightS
-- * Element-wise operations
-- ** Mapping
, map, imap, smap
, map, imap, smap, smapWithBounds
-- ** Zipping
, zipWith, zipWith3, zipWith4, zipWith5, zipWith6, zipWith7
, zip, zip3, zip4, zip5, zip6, zip7
Expand Down Expand Up @@ -2299,7 +2299,10 @@ lazyV = lazyV' (repeat ())
-- >>> import Data.Singletons (Apply, Proxy (..), TyFun)
-- >>> data Append (m :: Nat) (a :: Type) (f :: TyFun Nat Type) :: Type
-- >>> type instance Apply (Append m a) l = Vec (l + m) a
-- >>> let append' xs ys = dfold (Proxy :: Proxy (Append m a)) (const (:>)) ys xs
-- >>> :{
-- >>> append' :: forall a k m. KnownNat k => Vec k a -> Vec m a -> Vec (k + m) a
-- >>> append' xs ys = dfold (Proxy :: Proxy (Append m a)) (const ((:>) @a)) ys xs
-- >>> :}
--
-- === Example usage
--
Expand Down Expand Up @@ -2377,7 +2380,7 @@ lazyV = lazyV' (repeat ())
-- fold that produces a structure with a depth of O(log_2(@'length' xs@)).
dfold :: forall p k a . KnownNat k
=> Proxy (p :: TyFun Nat Type -> Type) -- ^ The /motive/
-> (forall l . SNat l -> a -> (p @@ l) -> (p @@ (l + 1)))
-> (forall n . n + 1 <= k => SNat n -> a -> (p @@ n) -> (p @@ (n + 1)))
-- ^ Function to fold.
--
-- __NB__: The @SNat l@ is __not__ the index (see (`!!`)) to the
Expand All @@ -2388,7 +2391,7 @@ dfold :: forall p k a . KnownNat k
-> (p @@ k)
dfold _ f z xs = go (snatProxy (asNatProxy xs)) xs
where
go :: SNat n -> Vec n a -> (p @@ n)
go :: n <= k => SNat n -> Vec n a -> (p @@ n)
go _ Nil = z
go s (y `Cons` ys) =
let s' = s `subSNat` d1
Expand Down Expand Up @@ -2557,7 +2560,7 @@ __NB__: The depth, or delay, of the structure produced by
dtfold :: forall p k a . KnownNat k
=> Proxy (p :: TyFun Nat Type -> Type) -- ^ The /motive/
-> (a -> (p @@ 0)) -- ^ Function to apply to every element
-> (forall l . SNat l -> (p @@ l) -> (p @@ l) -> (p @@ (l + 1)))
-> (forall n . SNat n -> (p @@ n) -> (p @@ n) -> (p @@ (n + 1)))
-- ^ Function to combine results.
--
-- __NB__: The @SNat l@ indicates the depth/height of the node in the
Expand Down Expand Up @@ -2616,7 +2619,7 @@ type instance Apply (VCons a) l = Vec l a
--
-- <<doc/csSort.svg>>
vfold :: forall k a b . KnownNat k
=> (forall l . SNat l -> a -> Vec l b -> Vec (l + 1) b)
=> (forall n . SNat n -> a -> Vec n b -> Vec (n + 1) b)
-> Vec k a
-> Vec k b
vfold f xs = dfold (Proxy @(VCons b)) f Nil xs
Expand Down Expand Up @@ -2645,13 +2648,29 @@ minimum = fold (\x y -> if x <= y then x else y)
-- (1 :> 2 :> 3 :> Nil) :> (1 :> 2 :> 3 :> Nil) :> (1 :> 2 :> 3 :> Nil) :> Nil
-- >>> rotateMatrix xss
-- (1 :> 2 :> 3 :> Nil) :> (3 :> 1 :> 2 :> Nil) :> (2 :> 3 :> 1 :> Nil) :> Nil
smap :: forall k a b . KnownNat k => (forall l . SNat l -> a -> b) -> Vec k a -> Vec k b
smap :: forall k a b . KnownNat k => (forall n . SNat n -> a -> b) -> Vec k a -> Vec k b
smap f xs = reverse
$ dfold (Proxy @(VCons b))
(\sn x xs' -> f sn x :> xs')
Nil (reverse xs)
{-# INLINE smap #-}

-- | Extended version of 'smap' offering an additional boundary proof to
-- the mapped function. Note that the type checker may need additional type
-- annotations to resolve type ambiguity for this. Thus, if the boundary constraint
-- is not needed it is recommended to stay with 'smap' instead.
smapWithBounds ::
forall k a b .
KnownNat k =>
(forall n . n + 1 <= k => SNat n -> a -> b) ->
Vec k a ->
Vec k b
smapWithBounds f xs = reverse
$ dfold (Proxy @(VCons b))
(\sn x xs' -> f sn x :> xs')
Nil (reverse xs)
{-# INLINE smapWithBounds #-}

instance (KnownNat n, BitPack a) => BitPack (Vec n a) where
type BitSize (Vec n a) = n * (BitSize a)
pack = packXWith (concatBitVector# . map pack)
Expand Down
3 changes: 2 additions & 1 deletion tests/shouldwork/Vector/DFold.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import Data.Kind (Type)
data Append (m :: Nat) (a :: Type) (f :: TyFun Nat Type) :: Type
type instance Apply (Append m a) l = Vec (l + m) a

append' xs ys = dfold (Proxy :: Proxy (Append m a)) (const (:>)) ys xs
append' :: forall a k m. KnownNat k => Vec k a -> Vec m a -> Vec (k + m) a
append' xs ys = dfold (Proxy :: Proxy (Append m a)) (const ((:>) @a)) ys xs

topEntity :: (Vec 3 Int,Vec 7 Int) -> Vec 10 Int
topEntity = uncurry append'
Expand Down

0 comments on commit e0b8905

Please sign in to comment.