From 26798bd9a6c8a6e0ed5e7e8280770b33373f3d88 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Wed, 23 Feb 2022 10:54:12 +0100 Subject: [PATCH 01/67] link posable library --- accelerate.cabal | 3 + .../Array/Accelerate/Representation/POS.hs | 146 ++++++++++++++++++ .../Array/Accelerate/Representation/Shape.hs | 39 +++++ stack-9.0.yaml | 4 +- 4 files changed, 191 insertions(+), 1 deletion(-) create mode 100644 src/Data/Array/Accelerate/Representation/POS.hs diff --git a/accelerate.cabal b/accelerate.cabal index 118210a38..b7f647747 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -361,6 +361,8 @@ library , unique , unordered-containers >= 0.2 , vector >= 0.10 + , posable >= 0.9.0.0 + , ghc-typelits-knownnat >= 0.7.6 exposed-modules: -- The core language and reference implementation @@ -400,6 +402,7 @@ library Data.Array.Accelerate.Lifetime Data.Array.Accelerate.Pretty Data.Array.Accelerate.Representation.Array + Data.Array.Accelerate.Representation.POS Data.Array.Accelerate.Representation.Elt Data.Array.Accelerate.Representation.Shape Data.Array.Accelerate.Representation.Slice diff --git a/src/Data/Array/Accelerate/Representation/POS.hs b/src/Data/Array/Accelerate/Representation/POS.hs new file mode 100644 index 000000000..1bf3d3507 --- /dev/null +++ b/src/Data/Array/Accelerate/Representation/POS.hs @@ -0,0 +1,146 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE TypeFamilyDependencies #-} +{-# OPTIONS_HADDOCK hide #-} + +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +-- This is needed to derive POSable for tuples of size more then 4 +{-# OPTIONS_GHC -fconstraint-solver-iterations=16 #-} +-- | +-- Module : Data.Array.Accelerate.Representation.POS +-- Copyright : [2008..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Representation.POS (POSable(..), POS, POST, mkPOS, fromPOS) + where + +import Data.Array.Accelerate.Type + +import Data.Bits +import Data.Char +import Data.Kind +import Language.Haskell.TH.Extra hiding ( Type ) + +import GHC.Generics +import GHC.TypeLits + +import Data.Type.POSable.POSable +import Data.Type.POSable.Representation + +type POS a = (Finite (Choices a), Product (Fields a)) + +mkPOS :: (POSable a) => a -> POS a +mkPOS x = (choices x, fields x) + +fromPOS :: (POSable a) => POS a -> a +fromPOS (cs, fs) = fromPOSable cs fs + +type POST a = (Finite (Choices a), ProductType (Fields a)) + +mkPOST :: forall a . (POSable a) => POST a +mkPOST = (0, emptyFields @a) + +runQ $ do + let + -- XXX: we might want to do the digItOut trick used by FromIntegral? + -- + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + ] + + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + ] + + newtypes :: [Name] + newtypes = + [ ''CShort + , ''CUShort + , ''CInt + , ''CUInt + , ''CLong + , ''CULong + , ''CLLong + , ''CULLong + , ''CFloat + , ''CDouble + , ''CChar + , ''CSChar + , ''CUChar + ] + + mkSimple :: Name -> Q [Dec] + mkSimple name = + let t = conT name + in + [d| instance POSable $t where + type Choices $t = 1 + choices _ = 0 + + type Fields $t = '[ '[$t]] + fields x = Cons (Pick x) Nil + + fromPOSable 0 (Cons (Pick x) Nil) = x + fromPOSable _ _ = error "index out of range" + + emptyFields = PTCons (STSucc 0 STZero) PTNil + |] + + mkTuple :: Int -> Q Dec + mkTuple n = + let + xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + ts = map varT xs + res = tupT ts + ctx = mapM (appT [t| POSable |]) ts + in + instanceD ctx [t| POSable $res |] [] + + mkNewtype :: Name -> Q [Dec] + mkNewtype name = do + r <- reify name + base <- case r of + TyConI (NewtypeD _ _ _ _ (NormalC _ [(_, ConT b)]) _) -> return b + _ -> error "unexpected case generating newtype Elt instance" + -- + [d| instance POSable $(conT name) + |] + -- + ss <- mapM mkSimple (integralTypes ++ floatingTypes) + ns <- mapM mkNewtype newtypes + ts <- mapM mkTuple [2..16] + -- vs <- sequence [ mkVecElt t n | t <- integralTypes ++ floatingTypes, n <- [2,3,4,8,16] ] + return (concat ss ++ concat ns ++ ts) + + +type family Snoc2List x = xs | xs -> x where + Snoc2List () = '[] + Snoc2List (xs, x) = (x ': Snoc2List xs) diff --git a/src/Data/Array/Accelerate/Representation/Shape.hs b/src/Data/Array/Accelerate/Representation/Shape.hs index fa3651c03..2c4f3906c 100644 --- a/src/Data/Array/Accelerate/Representation/Shape.hs +++ b/src/Data/Array/Accelerate/Representation/Shape.hs @@ -2,6 +2,14 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE FlexibleContexts #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Representation.Shape @@ -19,6 +27,8 @@ module Data.Array.Accelerate.Representation.Shape import Data.Array.Accelerate.Error import Data.Array.Accelerate.Type import Data.Array.Accelerate.Representation.Type +import Data.Array.Accelerate.Representation.POS +import Data.Type.POSable.Representation import Language.Haskell.TH.Extra import Prelude hiding ( zip ) @@ -195,3 +205,32 @@ liftShapeR :: ShapeR sh -> CodeQ (ShapeR sh) liftShapeR ShapeRz = [|| ShapeRz ||] liftShapeR (ShapeRsnoc sh) = [|| ShapeRsnoc $$(liftShapeR sh) ||] + +instance POSable (ShapeR ()) where + type Choices (ShapeR ()) = 1 + choices x = 0 + + emptyChoices = 0 + + fromPOSable cs fs = ShapeRz + + type Fields (ShapeR ()) = '[] + + fields ShapeRz = Nil + + emptyFields = PTNil + + +instance (POSable (ShapeR sh)) => POSable (ShapeR (sh, Int)) where + type Choices (ShapeR (sh, Int)) = 1 + choices x = 0 + + emptyChoices = 0 + + fromPOSable 0 (Cons x xs) = ShapeRsnoc (fromPOSable 0 xs) + + type Fields (ShapeR (sh, Int)) = '[] ': Fields (ShapeR sh) + + fields (ShapeRsnoc sh) = Cons Undef (fields sh) + + emptyFields = PTCons STZero (emptyFields @(ShapeR sh)) diff --git a/stack-9.0.yaml b/stack-9.0.yaml index 1349abd27..b5b82f65a 100644 --- a/stack-9.0.yaml +++ b/stack-9.0.yaml @@ -7,7 +7,9 @@ resolver: nightly-2022-02-16 packages: - . -# extra-deps: [] +extra-deps: +- ../sizeof + # Override default flag values for local packages and extra-deps # flags: {} From d46f7c748f34fc2dbfa32a3d32d6dcf2da3f1d9e Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 24 Feb 2022 11:00:09 +0100 Subject: [PATCH 02/67] POS instances for primary types, Vec --- .../Array/Accelerate/Representation/POS.hs | 9 +- src/Data/Array/Accelerate/Sugar/Elt.hs | 213 ++---------------- src/Data/Array/Accelerate/Sugar/Vec.hs | 35 ++- 3 files changed, 60 insertions(+), 197 deletions(-) diff --git a/src/Data/Array/Accelerate/Representation/POS.hs b/src/Data/Array/Accelerate/Representation/POS.hs index 1bf3d3507..94ca4dd2f 100644 --- a/src/Data/Array/Accelerate/Representation/POS.hs +++ b/src/Data/Array/Accelerate/Representation/POS.hs @@ -26,7 +26,9 @@ -- Portability : non-portable (GHC extensions) -- -module Data.Array.Accelerate.Representation.POS (POSable(..), POS, POST, mkPOS, fromPOS) +module Data.Array.Accelerate.Representation.POS ( + POSable(..), POS, POST, mkPOS, mkPOST, fromPOS, Product(..), Sum(..), + GroundType, Finite) where import Data.Array.Accelerate.Type @@ -101,7 +103,10 @@ runQ $ do mkSimple name = let t = conT name in - [d| instance POSable $t where + [d| + instance GroundType $t + + instance POSable $t where type Choices $t = 1 choices _ = 0 diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index b55158900..0d05f9da8 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -11,6 +11,7 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_HADDOCK hide #-} +{-# OPTIONS_GHC -ddump-splices #-} -- | -- Module : Data.Array.Accelerate.Sugar.Elt -- Copyright : [2008..2020] The Accelerate Team @@ -27,6 +28,7 @@ module Data.Array.Accelerate.Sugar.Elt ( Elt(..) ) import Data.Array.Accelerate.Representation.Elt import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type +import Data.Array.Accelerate.Representation.POS import Data.Array.Accelerate.Type import Data.Bits @@ -79,172 +81,21 @@ class Elt a where -- from the surface type into the internal representation type consisting -- only of simple primitive types, unit '()', and pair '(,)'. -- - type EltR a :: Type - type EltR a = GEltR () (Rep a) - -- - eltR :: TypeR (EltR a) - tagsR :: [TagR (EltR a)] - fromElt :: a -> EltR a - toElt :: EltR a -> a - - default eltR - :: (GElt (Rep a), EltR a ~ GEltR () (Rep a)) - => TypeR (EltR a) - eltR = geltR @(Rep a) TupRunit - - default tagsR - :: (Generic a, GElt (Rep a), EltR a ~ GEltR () (Rep a)) - => [TagR (EltR a)] - tagsR = gtagsR @(Rep a) TagRunit - - default fromElt - :: (Generic a, GElt (Rep a), EltR a ~ GEltR () (Rep a)) - => a - -> EltR a - fromElt = gfromElt () . from - - default toElt - :: (Generic a, GElt (Rep a), EltR a ~ GEltR () (Rep a)) - => EltR a - -> a - toElt = to . snd . gtoElt @(Rep a) @() - -class GElt f where - type GEltR t f - geltR :: TypeR t -> TypeR (GEltR t f) - gtagsR :: TagR t -> [TagR (GEltR t f)] - gfromElt :: t -> f a -> GEltR t f - gtoElt :: GEltR t f -> (t, f a) -- - gundef :: t -> GEltR t f - guntag :: TagR t -> TagR (GEltR t f) - -instance GElt U1 where - type GEltR t U1 = t - geltR t = t - gtagsR t = [t] - gfromElt t U1 = t - gtoElt t = (t, U1) - gundef t = t - guntag t = t - -instance GElt a => GElt (M1 i c a) where - type GEltR t (M1 i c a) = GEltR t a - geltR = geltR @a - gtagsR = gtagsR @a - gfromElt t (M1 x) = gfromElt t x - gtoElt x = let (t, x1) = gtoElt x in (t, M1 x1) - gundef = gundef @a - guntag = guntag @a - -instance Elt a => GElt (K1 i a) where - type GEltR t (K1 i a) = (t, EltR a) - geltR t = TupRpair t (eltR @a) - gtagsR t = TagRpair t <$> tagsR @a - gfromElt t (K1 x) = (t, fromElt x) - gtoElt (t, x) = (t, K1 (toElt x)) - gundef t = (t, undefElt (eltR @a)) - guntag t = TagRpair t (untag (eltR @a)) - -instance (GElt a, GElt b) => GElt (a :*: b) where - type GEltR t (a :*: b) = GEltR (GEltR t a) b - geltR = geltR @b . geltR @a - gtagsR = concatMap (gtagsR @b) . gtagsR @a - gfromElt t (a :*: b) = gfromElt (gfromElt t a) b - gtoElt t = - let (t1, b) = gtoElt t - (t2, a) = gtoElt t1 - in - (t2, a :*: b) - gundef t = gundef @b (gundef @a t) - guntag t = guntag @b (guntag @a t) + -- eltR :: EltRT a + -- tagsR :: [TagR (EltR a)] + fromElt :: a -> POS a + toElt :: POS a -> a -instance (GElt a, GElt b, GSumElt (a :+: b)) => GElt (a :+: b) where - type GEltR t (a :+: b) = (TAG, GSumEltR t (a :+: b)) - geltR t = TupRpair (TupRsingle scalarType) (gsumEltR @(a :+: b) t) - gtagsR t = uncurry TagRtag <$> gsumTagsR @(a :+: b) 0 t - gfromElt = gsumFromElt 0 - gtoElt (k,x) = gsumToElt k x - gundef t = (0xff, gsumUndef @(a :+: b) t) - guntag t = TagRpair (TagRundef scalarType) (gsumUntag @(a :+: b) t) + -- default eltR :: (POSable a) => EltRT a + -- eltR = mkPOST @a + default fromElt :: (POSable a) => a -> POS a + fromElt a = mkPOS a -class GSumElt f where - type GSumEltR t f - gsumEltR :: TypeR t -> TypeR (GSumEltR t f) - gsumTagsR :: TAG -> TagR t -> [(TAG, TagR (GSumEltR t f))] - gsumFromElt :: TAG -> t -> f a -> (TAG, GSumEltR t f) - gsumToElt :: TAG -> GSumEltR t f -> (t, f a) - gsumUndef :: t -> GSumEltR t f - gsumUntag :: TagR t -> TagR (GSumEltR t f) - -instance GSumElt U1 where - type GSumEltR t U1 = t - gsumEltR t = t - gsumTagsR n t = [(n, t)] - gsumFromElt n t U1 = (n, t) - gsumToElt _ t = (t, U1) - gsumUndef t = t - gsumUntag t = t - -instance GSumElt a => GSumElt (M1 i c a) where - type GSumEltR t (M1 i c a) = GSumEltR t a - gsumEltR = gsumEltR @a - gsumTagsR = gsumTagsR @a - gsumFromElt n t (M1 x) = gsumFromElt n t x - gsumToElt k x = let (t, x') = gsumToElt k x in (t, M1 x') - gsumUntag = gsumUntag @a - gsumUndef = gsumUndef @a - -instance Elt a => GSumElt (K1 i a) where - type GSumEltR t (K1 i a) = (t, EltR a) - gsumEltR t = TupRpair t (eltR @a) - gsumTagsR n t = (n,) . TagRpair t <$> tagsR @a - gsumFromElt n t (K1 x) = (n, (t, fromElt x)) - gsumToElt _ (t, x) = (t, K1 (toElt x)) - gsumUntag t = TagRpair t (untag (eltR @a)) - gsumUndef t = (t, undefElt (eltR @a)) - -instance (GElt a, GElt b) => GSumElt (a :*: b) where - type GSumEltR t (a :*: b) = GEltR t (a :*: b) - gsumEltR = geltR @(a :*: b) - gsumTagsR n t = (n,) <$> gtagsR @(a :*: b) t - gsumFromElt n t (a :*: b) = (n, gfromElt (gfromElt t a) b) - gsumToElt _ t0 = - let (t1, b) = gtoElt t0 - (t2, a) = gtoElt t1 - in - (t2, a :*: b) - gsumUndef = gundef @(a :*: b) - gsumUntag = guntag @(a :*: b) - -instance (GSumElt a, GSumElt b) => GSumElt (a :+: b) where - type GSumEltR t (a :+: b) = GSumEltR (GSumEltR t a) b - gsumEltR = gsumEltR @b . gsumEltR @a - - gsumFromElt n t (L1 a) = let (m,r) = gsumFromElt n t a - in (shiftL m 1, gsumUndef @b r) - gsumFromElt n t (R1 b) = let (m,r) = gsumFromElt n (gsumUndef @a t) b - in (setBit (m `shiftL` 1) 0, r) - - gsumToElt k t0 = - let (t1, b) = gsumToElt (shiftR k 1) t0 - (t2, a) = gsumToElt (shiftR k 1) t1 - in - if testBit k 0 - then (t2, R1 b) - else (t2, L1 a) - - gsumTagsR k t = - let a = gsumTagsR @a k t - b = gsumTagsR @b k (gsumUntag @a t) - in - map (\(x,y) -> (x `shiftL` 1, gsumUntag @b y)) a ++ - map (\(x,y) -> (setBit (x `shiftL` 1) 0, y)) b - - gsumUndef t = gsumUndef @b (gsumUndef @a t) - gsumUntag t = gsumUntag @b (gsumUntag @a t) + default toElt :: (POSable a) => POS a -> a + toElt a = fromPOS a untag :: TypeR t -> TagR t @@ -281,19 +132,14 @@ untag (TupRpair ta tb) = TagRpair (untag ta) (untag tb) -- Instances for basic types are generated at the end of this module. -- -instance Elt () -instance Elt Bool -instance Elt Ordering -instance Elt a => Elt (Maybe a) -instance (Elt a, Elt b) => Elt (Either a b) - -instance Elt Char where - type EltR Char = Word32 - eltR = TupRsingle scalarType - tagsR = [TagRsingle scalarType] - toElt = chr . fromIntegral - fromElt = fromIntegral . ord +instance (POSable ()) => Elt () +instance (POSable Bool) => Elt Bool +instance (POSable Ordering) => Elt Ordering +instance (POSable (Maybe a), Elt a) => Elt (Maybe a) +instance (POSable (Either a b), Elt a, Elt b) => Elt (Either a b) +-- Anything that has a POS instance has a default Elt instance +-- TODO: build instances for the sections of newtypes runQ $ do let -- XXX: we might want to do the digItOut trick used by FromIntegral? @@ -340,12 +186,7 @@ runQ $ do mkSimple name = let t = conT name in - [d| instance Elt $t where - type EltR $t = $t - eltR = TupRsingle scalarType - tagsR = [TagRsingle scalarType] - fromElt = id - toElt = id + [d| instance Elt $t |] mkTuple :: Int -> Q Dec @@ -380,17 +221,13 @@ runQ $ do TyConI (NewtypeD _ _ _ _ (NormalC _ [(_, ConT b)]) _) -> return b _ -> error "unexpected case generating newtype Elt instance" -- - [d| instance Elt $(conT name) where - type EltR $(conT name) = $(conT base) - eltR = TupRsingle scalarType - tagsR = [TagRsingle scalarType] - fromElt $(conP (mkName (nameBase name)) [varP (mkName "x")]) = x - toElt = $(conE (mkName (nameBase name))) + [d| instance Elt $(conT name) |] -- ss <- mapM mkSimple (integralTypes ++ floatingTypes) - ns <- mapM mkNewtype newtypes - ts <- mapM mkTuple [2..16] + -- TODO: + -- ns <- mapM mkNewtype newtypes + -- ts <- mapM mkTuple [2..8] -- vs <- sequence [ mkVecElt t n | t <- integralTypes ++ floatingTypes, n <- [2,3,4,8,16] ] - return (concat ss ++ concat ns ++ ts) + return (concat ss) diff --git a/src/Data/Array/Accelerate/Sugar/Vec.hs b/src/Data/Array/Accelerate/Sugar/Vec.hs index 723d32c7b..dd48b440a 100644 --- a/src/Data/Array/Accelerate/Sugar/Vec.hs +++ b/src/Data/Array/Accelerate/Sugar/Vec.hs @@ -2,8 +2,15 @@ {-# LANGUAGE MagicHash #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeFamilyDependencies #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_HADDOCK hide #-} {-# OPTIONS_GHC -fno-warn-orphans #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -- | -- Module : Data.Array.Accelerate.Sugar.Vec -- Copyright : [2008..2020] The Accelerate Team @@ -20,20 +27,34 @@ module Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type +import Data.Array.Accelerate.Representation.POS import Data.Array.Accelerate.Type import Data.Primitive.Types import Data.Primitive.Vec +import Data.Kind import GHC.TypeLits import GHC.Prim -type VecElt a = (Elt a, Prim a, IsSingle a, EltR a ~ a) +type VecElt a = (Elt a, Prim a, IsSingle a) -instance (KnownNat n, VecElt a) => Elt (Vec n a) where - type EltR (Vec n a) = Vec n a - eltR = TupRsingle (VectorScalarType (VectorType (fromIntegral (natVal' (proxy# :: Proxy# n))) singleType)) - tagsR = [TagRsingle (VectorScalarType (VectorType (fromIntegral (natVal' (proxy# :: Proxy# n))) singleType))] - toElt = id - fromElt = id +instance GroundType (Vec n a) +instance (KnownNat n, VecElt a) => POSable (Vec n a) where + type Choices (Vec n a) = 1 + + choices _ = 0 + + emptyChoices = 0 + + fromPOSable 0 (Cons (Pick x) Nil) = x + + type Fields (Vec n a) = '[ '[Vec n a]] + fields x = Cons (Pick x) Nil + + emptyFields = undefined + + +-- Elt instance automatically derived from POSable instance +instance (KnownNat n, VecElt a) => (Elt (Vec n a)) From 078d4f037624325761949040befa649c480529a0 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 24 Feb 2022 11:36:21 +0100 Subject: [PATCH 03/67] emptyFields implementation for Vec --- src/Data/Array/Accelerate/Representation/POS.hs | 2 +- src/Data/Array/Accelerate/Sugar/Vec.hs | 6 +++--- src/Data/Primitive/Vec.hs | 12 ++++++++++++ 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/Data/Array/Accelerate/Representation/POS.hs b/src/Data/Array/Accelerate/Representation/POS.hs index 94ca4dd2f..311d2eeab 100644 --- a/src/Data/Array/Accelerate/Representation/POS.hs +++ b/src/Data/Array/Accelerate/Representation/POS.hs @@ -28,7 +28,7 @@ module Data.Array.Accelerate.Representation.POS ( POSable(..), POS, POST, mkPOS, mkPOST, fromPOS, Product(..), Sum(..), - GroundType, Finite) + GroundType, Finite, ProductType(..), SumType(..)) where import Data.Array.Accelerate.Type diff --git a/src/Data/Array/Accelerate/Sugar/Vec.hs b/src/Data/Array/Accelerate/Sugar/Vec.hs index dd48b440a..7184bdbd8 100644 --- a/src/Data/Array/Accelerate/Sugar/Vec.hs +++ b/src/Data/Array/Accelerate/Sugar/Vec.hs @@ -41,7 +41,7 @@ type VecElt a = (Elt a, Prim a, IsSingle a) instance GroundType (Vec n a) -instance (KnownNat n, VecElt a) => POSable (Vec n a) where +instance (KnownNat n, VecElt a, Num a) => POSable (Vec n a) where type Choices (Vec n a) = 1 choices _ = 0 @@ -53,8 +53,8 @@ instance (KnownNat n, VecElt a) => POSable (Vec n a) where type Fields (Vec n a) = '[ '[Vec n a]] fields x = Cons (Pick x) Nil - emptyFields = undefined + emptyFields = PTCons (STSucc (replicateVecN 0) STZero) PTNil -- Elt instance automatically derived from POSable instance -instance (KnownNat n, VecElt a) => (Elt (Vec n a)) +instance (KnownNat n, VecElt a, Num a) => (Elt (Vec n a)) diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 0342f401c..4472b38cf 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -10,6 +10,7 @@ {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE TypeApplications #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Primitive.Vec @@ -34,6 +35,8 @@ module Data.Primitive.Vec ( listOfVec, liftVec, + replicateVecN, + ) where import Control.Monad.ST @@ -48,6 +51,7 @@ import GHC.Prim import GHC.TypeLits import GHC.Word +import Data.Proxy -- Note: [Representing SIMD vector types] -- @@ -259,6 +263,14 @@ packVec16 a b c d e f g h i j k l m n o p = runST $ do ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# +replicateVecN :: forall a n . (KnownNat n, Prim a) => a -> Vec n a +replicateVecN x = runST $ do + let n = fromInteger $ natVal (Proxy :: Proxy n) + mba <- newByteArray (n * sizeOf x) + mapM_ (\n' -> writeByteArray mba n x) [0..n] + ByteArray ba# <- unsafeFreezeByteArray mba + return $! Vec ba# + -- O(n) at runtime to copy from the Addr# to the ByteArray#. We should be able -- to do this without copying, but I don't think the definition of ByteArray# is -- exported (or it is deeply magical). From a327abe4341e55d951d9be9958dcb730f148b4a5 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 24 Feb 2022 13:50:53 +0100 Subject: [PATCH 04/67] Redefine Elt for Shapes --- .../Array/Accelerate/Representation/POS.hs | 5 +- .../Array/Accelerate/Representation/Shape.hs | 8 +-- src/Data/Array/Accelerate/Sugar/Shape.hs | 61 ++++++++++++++----- 3 files changed, 52 insertions(+), 22 deletions(-) diff --git a/src/Data/Array/Accelerate/Representation/POS.hs b/src/Data/Array/Accelerate/Representation/POS.hs index 311d2eeab..32d722a5f 100644 --- a/src/Data/Array/Accelerate/Representation/POS.hs +++ b/src/Data/Array/Accelerate/Representation/POS.hs @@ -43,6 +43,7 @@ import GHC.TypeLits import Data.Type.POSable.POSable import Data.Type.POSable.Representation +import Data.Type.POSable.Instances type POS a = (Finite (Choices a), Product (Fields a)) @@ -141,9 +142,9 @@ runQ $ do -- ss <- mapM mkSimple (integralTypes ++ floatingTypes) ns <- mapM mkNewtype newtypes - ts <- mapM mkTuple [2..16] + -- ts <- mapM mkTuple [2..16] -- vs <- sequence [ mkVecElt t n | t <- integralTypes ++ floatingTypes, n <- [2,3,4,8,16] ] - return (concat ss ++ concat ns ++ ts) + return (concat ss ++ concat ns) type family Snoc2List x = xs | xs -> x where diff --git a/src/Data/Array/Accelerate/Representation/Shape.hs b/src/Data/Array/Accelerate/Representation/Shape.hs index 2c4f3906c..f3d0cad5c 100644 --- a/src/Data/Array/Accelerate/Representation/Shape.hs +++ b/src/Data/Array/Accelerate/Representation/Shape.hs @@ -208,11 +208,11 @@ liftShapeR (ShapeRsnoc sh) = [|| ShapeRsnoc $$(liftShapeR sh) ||] instance POSable (ShapeR ()) where type Choices (ShapeR ()) = 1 - choices x = 0 + choices _ = 0 emptyChoices = 0 - fromPOSable cs fs = ShapeRz + fromPOSable _ _ = ShapeRz type Fields (ShapeR ()) = '[] @@ -223,11 +223,11 @@ instance POSable (ShapeR ()) where instance (POSable (ShapeR sh)) => POSable (ShapeR (sh, Int)) where type Choices (ShapeR (sh, Int)) = 1 - choices x = 0 + choices _ = 0 emptyChoices = 0 - fromPOSable 0 (Cons x xs) = ShapeRsnoc (fromPOSable 0 xs) + fromPOSable 0 (Cons _ xs) = ShapeRsnoc (fromPOSable 0 xs) type Fields (ShapeR (sh, Int)) = '[] ': Fields (ShapeR sh) diff --git a/src/Data/Array/Accelerate/Sugar/Shape.hs b/src/Data/Array/Accelerate/Sugar/Shape.hs index 1ac8bd0c4..e6cce4d99 100644 --- a/src/Data/Array/Accelerate/Sugar/Shape.hs +++ b/src/Data/Array/Accelerate/Sugar/Shape.hs @@ -30,7 +30,7 @@ module Data.Array.Accelerate.Sugar.Shape where -import Data.Array.Accelerate.Sugar.Elt +-- import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type import qualified Data.Array.Accelerate.Representation.Shape as R @@ -40,6 +40,30 @@ import Data.Kind import GHC.Generics +class Elt' a where + type EltR a + + fromElt :: a -> EltR a + toElt :: EltR a -> a + +instance Elt' Int where + type EltR Int = Int + + fromElt = id + toElt = id + +instance Elt' Z where + type EltR Z = () + + fromElt Z = () + toElt () = Z + +instance Elt' All where + type EltR All = () + + fromElt All = () + toElt () = All + -- Shorthand for common shape types -- type DIM0 = Z @@ -56,14 +80,14 @@ type DIM9 = DIM8 :. Int -- | Rank-0 index -- data Z = Z - deriving (Show, Eq, Generic, Elt) + deriving (Show, Eq) -- | Increase an index rank by one dimension. The ':.' operator is used to -- construct both values and types. -- infixl 3 :. data tail :. head = !tail :. !head - deriving (Eq, Generic) -- Not deriving Elt or Show + deriving (Eq, Generic) -- Not deriving Elt' or Show -- We don't we use a derived Show instance for (:.) because this will insert -- parenthesis to demonstrate which order the operator is applied, i.e.: @@ -97,7 +121,7 @@ instance (Show sh, Show sz) => Show (sh :. sz) where -- 'Data.Array.Accelerate.Language.replicate' for examples. -- data All = All - deriving (Show, Eq, Generic, Elt) + deriving (Show, Eq) -- | Marker for arbitrary dimensions in 'Data.Array.Accelerate.Language.slice' -- and 'Data.Array.Accelerate.Language.replicate' descriptors. @@ -126,7 +150,7 @@ data Split = Split -- For example, in the following definition, 'Divide' matches against any shape -- and flattens everything but the innermost dimension. -- --- > vectors :: (Shape sh, Elt e) => Acc (Array (sh:.Int) e) -> Seq [Vector e] +-- > vectors :: (Shape sh, Elt' e) => Acc (Array (sh:.Int) e) -> Seq [Vector e] -- > vectors = toSeq (Divide :. All) -- data Divide sh = Divide @@ -241,7 +265,7 @@ sliceShape slx = toElt . R.sliceShape slx . fromElt -- | Project the full shape from a slice -- sliceDomain - :: (Elt slix, Shape sl, Shape dim) + :: (Elt' slix, Shape sl, Shape dim) => R.SliceIndex (EltR slix) (EltR sl) co (EltR dim) -> slix -> sl @@ -258,7 +282,7 @@ sliceDomain slx slix sl = toElt $ R.sliceDomain slx (fromElt slix) (fromElt sl) -- > in -- > enumSlices slix sh :: [ Z :. Int :. Int :. All ] -- -enumSlices :: forall slix co sl dim. (Elt slix, Elt dim) +enumSlices :: forall slix co sl dim. (Elt' slix, Elt' dim) => R.SliceIndex (EltR slix) sl co (EltR dim) -> dim -- Bounds -> [slix] -- All slices within bounds. @@ -266,7 +290,7 @@ enumSlices slix = map toElt . R.enumSlices slix . fromElt -- | Shapes and indices of multi-dimensional arrays -- -class (Elt sh, Elt (Any sh), FullShape sh ~ sh, CoSliceShape sh ~ sh, SliceShape sh ~ Z) +class (Elt' sh, Elt' (Any sh), FullShape sh ~ sh, CoSliceShape sh ~ sh, SliceShape sh ~ Z) => Shape sh where -- | Reified type witness for shapes @@ -282,7 +306,7 @@ class (Elt sh, Elt (Any sh), FullShape sh ~ sh, CoSliceShape sh ~ sh, SliceShape -- | Slices, aka generalised indices, as /n/-tuples and mappings of slice -- indices to slices, co-slices, and slice dimensions -- -class (Elt sl, Shape (SliceShape sl), Shape (CoSliceShape sl), Shape (FullShape sl)) +class (Elt' sl, Shape (SliceShape sl), Shape (CoSliceShape sl), Shape (FullShape sl)) => Slice sl where type SliceShape sl :: Type -- the projected slice type CoSliceShape sl :: Type -- the complement of the slice @@ -303,18 +327,23 @@ class (Slice (DivisionSlice sl)) => Division sl where (EltR (CoSliceShape slix)) (EltR (FullShape slix)) -instance (Elt t, Elt h) => Elt (t :. h) where +instance (Elt' t, Elt' h) => Elt' (t :. h) where type EltR (t :. h) = (EltR t, EltR h) - eltR = TupRpair (eltR @t) (eltR @h) - tagsR = [TagRpair t h | t <- tagsR @t, h <- tagsR @h] + -- eltR = TupRpair (eltR @t) (eltR @h) + -- tagsR = [TagRpair t h | t <- tagsR @t, h <- tagsR @h] fromElt (t:.h) = (fromElt t, fromElt h) toElt (t, h) = toElt t :. toElt h -instance Elt (Any Z) -instance Shape sh => Elt (Any (sh :. Int)) where +instance Elt' (Any Z) where + type EltR (Any Z) = () + + fromElt Any = () + toElt () = Any + +instance Shape sh => Elt' (Any (sh :. Int)) where type EltR (Any (sh :. Int)) = (EltR (Any sh), ()) - eltR = TupRpair (eltR @(Any sh)) TupRunit - tagsR = [TagRpair t TagRunit | t <- tagsR @(Any sh)] + -- eltR = TupRpair (eltR @(Any sh)) TupRunit + -- tagsR = [TagRpair t TagRunit | t <- tagsR @(Any sh)] fromElt _ = (fromElt (Any :: Any sh), ()) toElt _ = Any From 7d2411780d2abbc01d74db47d5c28e54511a482c Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 24 Feb 2022 13:57:13 +0100 Subject: [PATCH 05/67] actually fill array in replicateVecN --- src/Data/Primitive/Vec.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 4472b38cf..f0a65ca2d 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -267,7 +267,7 @@ replicateVecN :: forall a n . (KnownNat n, Prim a) => a -> Vec n a replicateVecN x = runST $ do let n = fromInteger $ natVal (Proxy :: Proxy n) mba <- newByteArray (n * sizeOf x) - mapM_ (\n' -> writeByteArray mba n x) [0..n] + mapM_ (\n' -> writeByteArray mba n' x) [0..n] ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# From df2ecff5cb16e63611557cdf53edf449043d322e Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Fri, 25 Feb 2022 14:07:21 +0100 Subject: [PATCH 06/67] Array with Elt' --- src/Data/Array/Accelerate/Sugar/Array.hs | 38 ++++++++++++++---------- src/Data/Array/Accelerate/Sugar/Elt.hs | 1 + src/Data/Array/Accelerate/Sugar/Shape.hs | 1 + 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/src/Data/Array/Accelerate/Sugar/Array.hs b/src/Data/Array/Accelerate/Sugar/Array.hs index 87a490246..00a52a743 100644 --- a/src/Data/Array/Accelerate/Sugar/Array.hs +++ b/src/Data/Array/Accelerate/Sugar/Array.hs @@ -22,13 +22,15 @@ module Data.Array.Accelerate.Sugar.Array where -import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Type import Data.Array.Accelerate.Sugar.Shape import Data.Array.Accelerate.Representation.Type import qualified Data.Array.Accelerate.Representation.Array as R import Control.DeepSeq import Data.Kind +import Data.Char +import Data.Word import Data.Typeable import Language.Haskell.TH.Extra hiding ( Type ) import System.IO.Unsafe @@ -40,6 +42,12 @@ import qualified GHC.Exts as GHC -- $setup -- >>> :seti -XOverloadedLists +instance Elt' Char where + type EltR Char = Word32 + eltR = TupRsingle scalarType + -- tagsR = [TagRsingle scalarType] + toElt = chr . fromIntegral + fromElt = fromIntegral . ord type Scalar = Array DIM0 -- ^ Scalar arrays hold a single element type Vector = Array DIM1 -- ^ Vectors are one-dimensional arrays @@ -95,14 +103,14 @@ type Segments = Vector newtype Array sh e = Array (R.Array (EltR sh) (EltR e)) deriving Typeable -instance (Shape sh, Elt e, Eq sh, Eq e) => Eq (Array sh e) where +instance (Shape sh, Elt' e, Eq sh, Eq e) => Eq (Array sh e) where arr1 == arr2 = shape arr1 == shape arr2 && toList arr1 == toList arr2 arr1 /= arr2 = shape arr1 /= shape arr2 || toList arr1 /= toList arr2 -instance (Shape sh, Elt e, Show e) => Show (Array sh e) where +instance (Shape sh, Elt' e, Show e) => Show (Array sh e) where show (Array arr) = R.showArray (shows . toElt @e) (arrayR @sh @e) arr -instance Elt e => IsList (Vector e) where +instance Elt' e => IsList (Vector e) where type Item (Vector e) = e toList = toList fromListN n = fromList (Z:.n) @@ -111,7 +119,7 @@ instance Elt e => IsList (Vector e) where instance IsString (Vector Char) where fromString s = fromList (Z :. length s) s -instance (Shape sh, Elt e) => NFData (Array sh e) where +instance (Shape sh, Elt' e) => NFData (Array sh e) where rnf (Array arr) = R.rnfArray (arrayR @sh @e) arr -- Note: [Embedded class constraints on Array] @@ -146,26 +154,26 @@ reshape sh (Array arr) = Array $ R.reshape (shapeR @sh) (fromElt sh) (shapeR @sh -- | Return the value of an array at the given multidimensional index -- infixl 9 ! -(!) :: forall sh e. (Shape sh, Elt e) => Array sh e -> sh -> e +(!) :: forall sh e. (Shape sh, Elt' e) => Array sh e -> sh -> e (!) (Array arr) ix = toElt $ R.indexArray (arrayR @sh @e) arr (fromElt ix) -- | Return the value of an array at given the linear (row-major) index -- infixl 9 !! -(!!) :: forall sh e. Elt e => Array sh e -> Int -> e +(!!) :: forall sh e. Elt' e => Array sh e -> Int -> e (!!) (Array arr) i = toElt $ R.linearIndexArray (eltR @e) arr i -- | Create an array from its representation function, applied at each -- index of the array -- -fromFunction :: (Shape sh, Elt e) => sh -> (sh -> e) -> Array sh e +fromFunction :: (Shape sh, Elt' e) => sh -> (sh -> e) -> Array sh e fromFunction sh f = unsafePerformIO $! fromFunctionM sh (return . f) -- | Create an array using a monadic function applied at each index -- -- @since 1.2.0.0 -- -fromFunctionM :: forall sh e. (Shape sh, Elt e) => sh -> (sh -> IO e) -> IO (Array sh e) +fromFunctionM :: forall sh e. (Shape sh, Elt' e) => sh -> (sh -> IO e) -> IO (Array sh e) fromFunctionM sh f = Array <$> R.fromFunctionM (arrayR @sh @e) (fromElt sh) f' where f' x = do @@ -174,12 +182,12 @@ fromFunctionM sh f = Array <$> R.fromFunctionM (arrayR @sh @e) (fromElt sh) f' -- | Create a vector from the concatenation of the given list of vectors -- -concatVectors :: forall e. Elt e => [Vector e] -> Vector e +concatVectors :: forall e. Elt' e => [Vector e] -> Vector e concatVectors = toArr . R.concatVectors (eltR @e) . map fromArr -- | Creates a new, uninitialized Accelerate array -- -allocateArray :: forall sh e. (Shape sh, Elt e) => sh -> IO (Array sh e) +allocateArray :: forall sh e. (Shape sh, Elt' e) => sh -> IO (Array sh e) allocateArray sh = Array <$> R.allocateArray (arrayR @sh @e) (fromElt sh) -- | Convert elements of a list into an Accelerate 'Array' @@ -212,12 +220,12 @@ allocateArray sh = Array <$> R.allocateArray (arrayR @sh @e) (fromElt sh) -- and then traversing it a second time to collect the elements into the array, -- thus forcing the spine of the list to be manifest on the heap. -- -fromList :: forall sh e. (Shape sh, Elt e) => sh -> [e] -> Array sh e +fromList :: forall sh e. (Shape sh, Elt' e) => sh -> [e] -> Array sh e fromList sh xs = toArr $ R.fromList (arrayR @sh @e) (fromElt sh) $ map fromElt xs -- | Convert an accelerated 'Array' to a list in row-major order -- -toList :: forall sh e. (Shape sh, Elt e) => Array sh e -> [e] +toList :: forall sh e. (Shape sh, Elt' e) => Array sh e -> [e] toList = map toElt . R.toList (arrayR @sh @e) . fromArr @@ -255,7 +263,7 @@ class Arrays a where => a -> ArraysR a fromArr = (`gfromArr` ()) . from -arrayR :: forall sh e. (Shape sh, Elt e) => R.ArrayR (R.Array (EltR sh) (EltR e)) +arrayR :: forall sh e. (Shape sh, Elt' e) => R.ArrayR (R.Array (EltR sh) (EltR e)) arrayR = R.ArrayR (shapeR @sh) (eltR @e) class GArrays f where @@ -299,7 +307,7 @@ instance Arrays () where fromArr = id toArr = id -instance (Shape sh, Elt e) => Arrays (Array sh e) where +instance (Shape sh, Elt' e) => Arrays (Array sh e) where type ArraysR (Array sh e) = R.Array (EltR sh) (EltR e) arraysR = R.arraysRarray (shapeR @sh) (eltR @e) fromArr (Array arr) = arr diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index 0d05f9da8..f97cdef20 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -34,6 +34,7 @@ import Data.Array.Accelerate.Type import Data.Bits import Data.Char import Data.Kind +import Data.Word import Language.Haskell.TH.Extra hiding ( Type ) import GHC.Generics diff --git a/src/Data/Array/Accelerate/Sugar/Shape.hs b/src/Data/Array/Accelerate/Sugar/Shape.hs index e6cce4d99..5034f2f8d 100644 --- a/src/Data/Array/Accelerate/Sugar/Shape.hs +++ b/src/Data/Array/Accelerate/Sugar/Shape.hs @@ -43,6 +43,7 @@ import GHC.Generics class Elt' a where type EltR a + eltR :: TypeR (EltR a) fromElt :: a -> EltR a toElt :: EltR a -> a From 99f482b56d199c6d31d10f75f804e32d7ea19f95 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Tue, 1 Mar 2022 10:11:30 +0100 Subject: [PATCH 07/67] don't import Type from POS --- src/Data/Array/Accelerate/Representation/POS.hs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/Data/Array/Accelerate/Representation/POS.hs b/src/Data/Array/Accelerate/Representation/POS.hs index 32d722a5f..4bdd899af 100644 --- a/src/Data/Array/Accelerate/Representation/POS.hs +++ b/src/Data/Array/Accelerate/Representation/POS.hs @@ -31,7 +31,7 @@ module Data.Array.Accelerate.Representation.POS ( GroundType, Finite, ProductType(..), SumType(..)) where -import Data.Array.Accelerate.Type +-- import Data.Array.Accelerate.Type import Data.Bits import Data.Char @@ -45,6 +45,11 @@ import Data.Type.POSable.POSable import Data.Type.POSable.Representation import Data.Type.POSable.Instances +import Data.Int +import Data.Word +import Numeric.Half +import Foreign.C.Types + type POS a = (Finite (Choices a), Product (Fields a)) mkPOS :: (POSable a) => a -> POS a From 8476596a7ae8eabd5d618d4b408684f8b66e3583 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Tue, 1 Mar 2022 13:34:18 +0100 Subject: [PATCH 08/67] convert typelists to tuples --- src/Data/Array/Accelerate/Representation/POS.hs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/Data/Array/Accelerate/Representation/POS.hs b/src/Data/Array/Accelerate/Representation/POS.hs index 4bdd899af..3168ae1da 100644 --- a/src/Data/Array/Accelerate/Representation/POS.hs +++ b/src/Data/Array/Accelerate/Representation/POS.hs @@ -8,6 +8,7 @@ {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE PolyKinds #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE TypeFamilyDependencies #-} @@ -52,6 +53,14 @@ import Foreign.C.Types type POS a = (Finite (Choices a), Product (Fields a)) +type family EltR (cs :: Nat) (fs :: f (g a)) :: Type where + EltR 1 x = FlattenProduct x + EltR n x = (Finite n, FlattenProduct x) + +type family FlattenProduct (xss :: f (g a)) :: Type where + FlattenProduct '[] = () + FlattenProduct (x ': xs) = (Sum x, FlattenProduct xs) + mkPOS :: (POSable a) => a -> POS a mkPOS x = (choices x, fields x) From d994fddc9ce61da92f28a4c7f8d04db87becf577 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Wed, 2 Mar 2022 10:43:28 +0100 Subject: [PATCH 09/67] Convert POS to EltR --- .../Array/Accelerate/Representation/POS.hs | 20 ++++++- .../Array/Accelerate/Representation/Type.hs | 4 ++ src/Data/Array/Accelerate/Type.hs | 53 ++++++++++++++++--- 3 files changed, 67 insertions(+), 10 deletions(-) diff --git a/src/Data/Array/Accelerate/Representation/POS.hs b/src/Data/Array/Accelerate/Representation/POS.hs index 3168ae1da..ea9674e50 100644 --- a/src/Data/Array/Accelerate/Representation/POS.hs +++ b/src/Data/Array/Accelerate/Representation/POS.hs @@ -51,15 +51,31 @@ import Data.Word import Numeric.Half import Foreign.C.Types +-- import Data.Array.Accelerate.Representation.Type + type POS a = (Finite (Choices a), Product (Fields a)) -type family EltR (cs :: Nat) (fs :: f (g a)) :: Type where +type family EltR (cs :: Nat) (fs :: f (g a)) = (r :: Type) where EltR 1 x = FlattenProduct x EltR n x = (Finite n, FlattenProduct x) type family FlattenProduct (xss :: f (g a)) :: Type where FlattenProduct '[] = () - FlattenProduct (x ': xs) = (Sum x, FlattenProduct xs) + FlattenProduct (x ': xs) = (FlattenSum x, FlattenProduct xs) + +type family FlattenSum (xss :: f a) :: Type where + FlattenSum '[] = () + FlattenSum (x ': xs) = (x, FlattenSum xs) + +mkEltR :: (POSable a) => a -> EltR (Choices a) (Fields a) +mkEltR x = undefined + where + cs = choices x + fs = fields x + +-- productToTupR :: Product a -> TypeR (FlattenProduct a) +-- productToTupR Nil = TupRunit +-- productToTupR (Cons x xs) = TupRpair x (productToTupR xs) mkPOS :: (POSable a) => a -> POS a mkPOS x = (choices x, fields x) diff --git a/src/Data/Array/Accelerate/Representation/Type.hs b/src/Data/Array/Accelerate/Representation/Type.hs index 477f09a00..a647929f2 100644 --- a/src/Data/Array/Accelerate/Representation/Type.hs +++ b/src/Data/Array/Accelerate/Representation/Type.hs @@ -43,6 +43,10 @@ data TupR s a where TupRsingle :: s a -> TupR s a TupRpair :: TupR s a -> TupR s b -> TupR s (a, b) +productToTupR :: Product a -> TypeR (FlattenProduct a) +productToTupR Nil = TupRunit +productToTupR (Cons x xs) = TupRpair x (productToTupR xs) + instance Show (TupR ScalarType a) where show TupRunit = "()" show (TupRsingle t) = show t diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index 94e891cc1..f61414848 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -1,19 +1,25 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RoleAnnotations #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE TypeFamilyDependencies #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Type @@ -66,6 +72,7 @@ module Data.Array.Accelerate.Type ( ) where import Data.Array.Accelerate.Orphans () -- Prim Half +import Data.Array.Accelerate.Representation.POS import Data.Primitive.Vec import Data.Bits @@ -73,16 +80,45 @@ import Data.Int import Data.Primitive.Types import Data.Type.Equality import Data.Word +import Data.Kind import Foreign.C.Types import Foreign.Storable ( Storable ) import Formatting -import Language.Haskell.TH.Extra +import Language.Haskell.TH.Extra hiding (Type) import Numeric.Half import Text.Printf import GHC.Prim import GHC.TypeLits +import Unsafe.Coerce + + +type family EltR (cs :: Nat) fs :: Type where + EltR 1 x = FlattenProduct x + EltR n x = (Finite n, FlattenProduct x) + +type family FlattenProduct (xss :: f (g a)) :: Type where + FlattenProduct '[] = () + FlattenProduct (x ': xs) = (ScalarType (FlattenSum x), FlattenProduct xs) + +type family FlattenSum (xss :: f a) :: Type where + FlattenSum '[] = () + FlattenSum (x ': xs) = (x, FlattenSum xs) + +flattenProduct :: Product a -> FlattenProduct a +flattenProduct Nil = () +flattenProduct (Cons x xs) = (SumScalarType x, flattenProduct xs) + +mkEltR :: forall a . (POSable a) => a -> EltR (Choices a) (Fields a) +mkEltR x = case natVal cs of + -- This distinction is hard to express in a type-correct way, + -- hence the unsafeCoerce + 1 -> unsafeCoerce fs + _ -> unsafeCoerce (cs, fs) + where + cs = choices x + fs = flattenProduct (fields x) -- Scalar types -- ------------ @@ -142,6 +178,7 @@ data BoundedType a where -- | All scalar element types implement Eq & Ord -- data ScalarType a where + SumScalarType :: Sum a -> ScalarType (FlattenSum a) SingleScalarType :: SingleType a -> ScalarType a VectorScalarType :: VectorType (Vec n a) -> ScalarType (Vec n a) From 27eac67cb9027ddbe1169bf4b4a1ab352fda682d Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Wed, 2 Mar 2022 12:38:57 +0100 Subject: [PATCH 10/67] integrate Elt and POS? --- .../Array/Accelerate/Representation/Type.hs | 7 +++--- src/Data/Array/Accelerate/Sugar/Elt.hs | 18 +++++++------- src/Data/Array/Accelerate/Type.hs | 24 ++++++++++++++----- 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/src/Data/Array/Accelerate/Representation/Type.hs b/src/Data/Array/Accelerate/Representation/Type.hs index a647929f2..bb0e516e3 100644 --- a/src/Data/Array/Accelerate/Representation/Type.hs +++ b/src/Data/Array/Accelerate/Representation/Type.hs @@ -20,6 +20,7 @@ module Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Type import Data.Primitive.Vec +import Data.Array.Accelerate.Representation.POS import Formatting import Language.Haskell.TH.Extra @@ -43,9 +44,9 @@ data TupR s a where TupRsingle :: s a -> TupR s a TupRpair :: TupR s a -> TupR s b -> TupR s (a, b) -productToTupR :: Product a -> TypeR (FlattenProduct a) -productToTupR Nil = TupRunit -productToTupR (Cons x xs) = TupRpair x (productToTupR xs) +-- productToTupR :: Product a -> TypeR (FlattenProduct a) +-- productToTupR Nil = TupRunit +-- productToTupR (Cons x xs) = TupRpair x (productToTupR xs) instance Show (TupR ScalarType a) where show TupRunit = "()" diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index f97cdef20..39ced1dd9 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -82,21 +82,23 @@ class Elt a where -- from the surface type into the internal representation type consisting -- only of simple primitive types, unit '()', and pair '(,)'. -- + type EltR a :: Type + type EltR a = POStoEltR (Choices a) (Fields a) -- - -- eltR :: EltRT a - -- tagsR :: [TagR (EltR a)] - fromElt :: a -> POS a - toElt :: POS a -> a + eltR :: TypeR (EltR a) + tagsR :: [TagR (EltR a)] + fromElt :: a -> EltR a + toElt :: EltR a -> a -- default eltR :: (POSable a) => EltRT a -- eltR = mkPOST @a - default fromElt :: (POSable a) => a -> POS a - fromElt a = mkPOS a + default fromElt :: (POSable a, POStoEltR (Choices a) (Fields a) ~ EltR a) => a -> EltR a + fromElt a = mkEltR a - default toElt :: (POSable a) => POS a -> a - toElt a = fromPOS a + default toElt :: (POSable a, POStoEltR (Choices a) (Fields a) ~ EltR a) => EltR a -> a + toElt a = fromEltR a untag :: TypeR t -> TagR t diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index f61414848..daae1f3ca 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -94,15 +94,15 @@ import GHC.TypeLits import Unsafe.Coerce -type family EltR (cs :: Nat) fs :: Type where - EltR 1 x = FlattenProduct x - EltR n x = (Finite n, FlattenProduct x) +type family POStoEltR (cs :: Nat) fs :: Type where + POStoEltR 1 x = FlattenProduct x + POStoEltR n x = (Finite n, FlattenProduct x) -type family FlattenProduct (xss :: f (g a)) :: Type where +type family FlattenProduct (xss :: f (g a)) = (r :: Type) | r -> f where FlattenProduct '[] = () FlattenProduct (x ': xs) = (ScalarType (FlattenSum x), FlattenProduct xs) -type family FlattenSum (xss :: f a) :: Type where +type family FlattenSum (xss :: f a) = (r :: Type) | r -> f where FlattenSum '[] = () FlattenSum (x ': xs) = (x, FlattenSum xs) @@ -110,7 +110,11 @@ flattenProduct :: Product a -> FlattenProduct a flattenProduct Nil = () flattenProduct (Cons x xs) = (SumScalarType x, flattenProduct xs) -mkEltR :: forall a . (POSable a) => a -> EltR (Choices a) (Fields a) +-- unFlattenProduct :: FlattenProduct a -> Product a +-- unFlattenProduct () = Nil +-- unFlattenProduct (SumScalarType x, xs) = Cons x (unFlattenProduct xs) + +mkEltR :: forall a . (POSable a) => a -> POStoEltR (Choices a) (Fields a) mkEltR x = case natVal cs of -- This distinction is hard to express in a type-correct way, -- hence the unsafeCoerce @@ -120,6 +124,14 @@ mkEltR x = case natVal cs of cs = choices x fs = flattenProduct (fields x) + +fromEltR :: forall a . (POSable a) => POStoEltR (Choices a) (Fields a) -> a +fromEltR x = fromPOSable cs fs + where + (cs, fs) = case natVal (emptyChoices @a) of + 1 -> (0, unsafeCoerce x) + _ -> unsafeCoerce x + -- Scalar types -- ------------ From 3080b142f6f6979a5a34845baf9bd8cd6638c5b0 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Wed, 2 Mar 2022 15:45:27 +0100 Subject: [PATCH 11/67] sorta kinda integrated POS into shapes --- .../Array/Accelerate/Representation/POS.hs | 39 ++++---- .../Array/Accelerate/Representation/Shape.hs | 92 +++++++++---------- .../Array/Accelerate/Representation/Slice.hs | 19 ++-- src/Data/Array/Accelerate/Sugar/Shape.hs | 71 +++++--------- src/Data/Array/Accelerate/Type.hs | 3 + 5 files changed, 101 insertions(+), 123 deletions(-) diff --git a/src/Data/Array/Accelerate/Representation/POS.hs b/src/Data/Array/Accelerate/Representation/POS.hs index ea9674e50..d0d7c30ad 100644 --- a/src/Data/Array/Accelerate/Representation/POS.hs +++ b/src/Data/Array/Accelerate/Representation/POS.hs @@ -29,7 +29,7 @@ module Data.Array.Accelerate.Representation.POS ( POSable(..), POS, POST, mkPOS, mkPOST, fromPOS, Product(..), Sum(..), - GroundType, Finite, ProductType(..), SumType(..)) + GroundType, Finite, ProductType(..), SumType(..), POSable.Generic) where -- import Data.Array.Accelerate.Type @@ -42,7 +42,7 @@ import Language.Haskell.TH.Extra hiding ( Typ import GHC.Generics import GHC.TypeLits -import Data.Type.POSable.POSable +import Data.Type.POSable.POSable as POSable import Data.Type.POSable.Representation import Data.Type.POSable.Instances @@ -55,27 +55,28 @@ import Foreign.C.Types type POS a = (Finite (Choices a), Product (Fields a)) -type family EltR (cs :: Nat) (fs :: f (g a)) = (r :: Type) where - EltR 1 x = FlattenProduct x - EltR n x = (Finite n, FlattenProduct x) +-- type family EltR (cs :: Nat) (fs :: f (g a)) = (r :: Type) where +-- EltR 1 x = FlattenProduct x +-- EltR n x = (Finite n, FlattenProduct x) -type family FlattenProduct (xss :: f (g a)) :: Type where - FlattenProduct '[] = () - FlattenProduct (x ': xs) = (FlattenSum x, FlattenProduct xs) +-- type family FlattenProduct (xss :: f (g a)) :: Type where +-- FlattenProduct '[] = () +-- FlattenProduct '[ '[x]] = x +-- FlattenProduct (x ': xs) = (FlattenSum x, FlattenProduct xs) -type family FlattenSum (xss :: f a) :: Type where - FlattenSum '[] = () - FlattenSum (x ': xs) = (x, FlattenSum xs) +-- type family FlattenSum (xss :: f a) :: Type where +-- FlattenSum '[] = () +-- FlattenSum (x ': xs) = (x, FlattenSum xs) -mkEltR :: (POSable a) => a -> EltR (Choices a) (Fields a) -mkEltR x = undefined - where - cs = choices x - fs = fields x +-- mkEltR :: (POSable a) => a -> EltR (Choices a) (Fields a) +-- mkEltR x = undefined +-- where +-- cs = choices x +-- fs = fields x --- productToTupR :: Product a -> TypeR (FlattenProduct a) --- productToTupR Nil = TupRunit --- productToTupR (Cons x xs) = TupRpair x (productToTupR xs) +-- -- productToTupR :: Product a -> TypeR (FlattenProduct a) +-- -- productToTupR Nil = TupRunit +-- -- productToTupR (Cons x xs) = TupRpair x (productToTupR xs) mkPOS :: (POSable a) => a -> POS a mkPOS x = (choices x, fields x) diff --git a/src/Data/Array/Accelerate/Representation/Shape.hs b/src/Data/Array/Accelerate/Representation/Shape.hs index f3d0cad5c..3b7f6493b 100644 --- a/src/Data/Array/Accelerate/Representation/Shape.hs +++ b/src/Data/Array/Accelerate/Representation/Shape.hs @@ -40,7 +40,7 @@ import GHC.Base ( quotInt, r -- data ShapeR sh where ShapeRz :: ShapeR () - ShapeRsnoc :: ShapeR sh -> ShapeR (sh, Int) + ShapeRsnoc :: ShapeR sh -> ShapeR (sh, SingletonType Int) -- | Nicely format a shape as a string -- @@ -57,14 +57,14 @@ type DIM3 = ((((), Int), Int), Int) dim0 :: ShapeR DIM0 dim0 = ShapeRz -dim1 :: ShapeR DIM1 -dim1 = ShapeRsnoc dim0 +-- dim1 :: ShapeR DIM1 +-- dim1 = ShapeRsnoc dim0 -dim2 :: ShapeR DIM2 -dim2 = ShapeRsnoc dim1 +-- dim2 :: ShapeR DIM2 +-- dim2 = ShapeRsnoc dim1 -dim3 :: ShapeR DIM3 -dim3 = ShapeRsnoc dim2 +-- dim3 :: ShapeR DIM3 +-- dim3 = ShapeRsnoc dim2 -- | Number of dimensions of a /shape/ or /index/ (>= 0) -- @@ -76,15 +76,15 @@ rank (ShapeRsnoc shr) = rank shr + 1 -- size :: ShapeR sh -> sh -> Int size ShapeRz () = 1 -size (ShapeRsnoc shr) (sh, sz) - | sz <= 0 = 0 - | otherwise = size shr sh * sz +-- size (ShapeRsnoc shr) (sh, sz) +-- | sz <= 0 = 0 +-- | otherwise = size shr sh * sz -- | The empty shape -- empty :: ShapeR sh -> sh empty ShapeRz = () -empty (ShapeRsnoc shr) = (empty shr, 0) +-- empty (ShapeRsnoc shr) = (empty shr, 0) -- | Yield the intersection of two shapes -- @@ -98,11 +98,11 @@ union = zip max zip :: (Int -> Int -> Int) -> ShapeR sh -> sh -> sh -> sh zip _ ShapeRz () () = () -zip f (ShapeRsnoc shr) (as, a) (bs, b) = (zip f shr as bs, f a b) +-- zip f (ShapeRsnoc shr) (as, a) (bs, b) = (zip f shr as bs, f a b) eq :: ShapeR sh -> sh -> sh -> Bool eq ShapeRz () () = True -eq (ShapeRsnoc shr) (sh, i) (sh', i') = i == i' && eq shr sh sh' +-- eq (ShapeRsnoc shr) (sh, i) (sh', i') = i == i' && eq shr sh sh' -- | Map a multi-dimensional index into one in a linear, row-major @@ -111,23 +111,23 @@ eq (ShapeRsnoc shr) (sh, i) (sh', i') = i == i' && eq shr sh sh' -- toIndex :: HasCallStack => ShapeR sh -> sh -> sh -> Int toIndex ShapeRz () () = 0 -toIndex (ShapeRsnoc shr) (sh, sz) (ix, i) - = indexCheck i sz - $ toIndex shr sh ix * sz + i +-- toIndex (ShapeRsnoc shr) (sh, sz) (ix, i) +-- = indexCheck i sz +-- $ toIndex shr sh ix * sz + i -- | Inverse of 'toIndex' -- fromIndex :: HasCallStack => ShapeR sh -> sh -> Int -> sh fromIndex ShapeRz () _ = () -fromIndex (ShapeRsnoc shr) (sh, sz) i - = (fromIndex shr sh (i `quotInt` sz), r) +-- fromIndex (ShapeRsnoc shr) (sh, sz) i +-- = (fromIndex shr sh (i `quotInt` sz), r) -- If we assume that the index is in range, there is no point in computing -- the remainder for the highest dimension since i < sz must hold. -- where - r = case shr of -- Check if rank of shr is 0 - ShapeRz -> indexCheck i sz i - _ -> i `remInt` sz + -- r = case shr of -- Check if rank of shr is 0 + -- ShapeRz -> indexCheck i sz i + -- _ -> i `remInt` sz -- | Iterate through the entire shape, applying the function in the second -- argument; third argument combines results and fourth is an initial value @@ -136,20 +136,20 @@ fromIndex (ShapeRsnoc shr) (sh, sz) i -- iter :: ShapeR sh -> sh -> (sh -> a) -> (a -> a -> a) -> a -> a iter ShapeRz () f _ _ = f () -iter (ShapeRsnoc shr) (sh, sz) f c z = iter shr sh (\ix -> iter' (ix,0) z) c z - where - iter' (ix,i) r | i >= sz = r - | otherwise = iter' (ix,i+1) (r `c` f (ix,i)) +-- iter (ShapeRsnoc shr) (sh, sz) f c z = iter shr sh (\ix -> iter' (ix,0) z) c z +-- where +-- iter' (ix,i) r | i >= sz = r +-- | otherwise = iter' (ix,i+1) (r `c` f (ix,i)) -- | Variant of 'iter' without an initial value -- iter1 :: HasCallStack => ShapeR sh -> sh -> (sh -> a) -> (a -> a -> a) -> a iter1 ShapeRz () f _ = f () -iter1 (ShapeRsnoc _ ) (_, 0) _ _ = boundsError "empty iteration space" -iter1 (ShapeRsnoc shr) (sh, sz) f c = iter1 shr sh (\ix -> iter1' (ix,0)) c - where - iter1' (ix,i) | i == sz-1 = f (ix,i) - | otherwise = f (ix,i) `c` iter1' (ix,i+1) +-- iter1 (ShapeRsnoc _ ) (_, 0) _ _ = boundsError "empty iteration space" +-- iter1 (ShapeRsnoc shr) (sh, sz) f c = iter1 shr sh (\ix -> iter1' (ix,0)) c +-- where +-- iter1' (ix,i) | i == sz-1 = f (ix,i) +-- | otherwise = f (ix,i) `c` iter1' (ix,i+1) -- Operations to facilitate conversion with IArray @@ -157,19 +157,19 @@ iter1 (ShapeRsnoc shr) (sh, sz) f c = iter1 shr sh (\ix -> iter1' (ix,0)) c -- rangeToShape :: ShapeR sh -> (sh, sh) -> sh rangeToShape ShapeRz ((), ()) = () -rangeToShape (ShapeRsnoc shr) ((sh1, sz1), (sh2, sz2)) = (rangeToShape shr (sh1, sh2), sz2 - sz1 + 1) +-- rangeToShape (ShapeRsnoc shr) ((sh1, sz1), (sh2, sz2)) = (rangeToShape shr (sh1, sh2), sz2 - sz1 + 1) -- | Converse of 'rangeToShape' -- shapeToRange :: ShapeR sh -> sh -> (sh, sh) shapeToRange ShapeRz () = ((), ()) -shapeToRange (ShapeRsnoc shr) (sh, sz) = let (low, high) = shapeToRange shr sh in ((low, 0), (high, sz - 1)) +-- shapeToRange (ShapeRsnoc shr) (sh, sz) = let (low, high) = shapeToRange shr sh in ((low, 0), (high, sz - 1)) -- | Convert a shape or index into its list of dimensions -- shapeToList :: ShapeR sh -> sh -> [Int] shapeToList ShapeRz () = [] -shapeToList (ShapeRsnoc shr) (sh,sz) = sz : shapeToList shr sh +-- shapeToList (ShapeRsnoc shr) (sh,sz) = sz : shapeToList shr sh -- | Convert a list of dimensions into a shape -- @@ -183,15 +183,15 @@ listToShape shr ds = -- listToShape' :: ShapeR sh -> [Int] -> Maybe sh listToShape' ShapeRz [] = Just () -listToShape' (ShapeRsnoc shr) (x:xs) = (, x) <$> listToShape' shr xs +-- listToShape' (ShapeRsnoc shr) (x:xs) = (, x) <$> listToShape' shr xs listToShape' _ _ = Nothing shapeType :: ShapeR sh -> TypeR sh shapeType ShapeRz = TupRunit -shapeType (ShapeRsnoc shr) = - shapeType shr - `TupRpair` - TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeInt))) +-- shapeType (ShapeRsnoc shr) = +-- shapeType shr +-- `TupRpair` +-- TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeInt))) rnfShape :: ShapeR sh -> sh -> () rnfShape ShapeRz () = () @@ -221,16 +221,16 @@ instance POSable (ShapeR ()) where emptyFields = PTNil -instance (POSable (ShapeR sh)) => POSable (ShapeR (sh, Int)) where - type Choices (ShapeR (sh, Int)) = 1 - choices _ = 0 +-- instance (POSable (ShapeR sh)) => POSable (ShapeR (sh, Int)) where +-- type Choices (ShapeR (sh, Int)) = 1 +-- choices _ = 0 - emptyChoices = 0 +-- emptyChoices = 0 - fromPOSable 0 (Cons _ xs) = ShapeRsnoc (fromPOSable 0 xs) +-- fromPOSable 0 (Cons _ xs) = ShapeRsnoc (fromPOSable 0 xs) - type Fields (ShapeR (sh, Int)) = '[] ': Fields (ShapeR sh) +-- type Fields (ShapeR (sh, Int)) = '[] ': Fields (ShapeR sh) - fields (ShapeRsnoc sh) = Cons Undef (fields sh) +-- fields (ShapeRsnoc sh) = Cons Undef (fields sh) - emptyFields = PTCons STZero (emptyFields @(ShapeR sh)) +-- emptyFields = PTCons STZero (emptyFields @(ShapeR sh)) diff --git a/src/Data/Array/Accelerate/Representation/Slice.hs b/src/Data/Array/Accelerate/Representation/Slice.hs index dee059a37..a40951e4a 100644 --- a/src/Data/Array/Accelerate/Representation/Slice.hs +++ b/src/Data/Array/Accelerate/Representation/Slice.hs @@ -19,6 +19,7 @@ module Data.Array.Accelerate.Representation.Slice where import Data.Array.Accelerate.Representation.Shape +import Data.Array.Accelerate.Type import Language.Haskell.TH.Extra @@ -41,21 +42,21 @@ instance Slice sl => Slice (sl, ()) where type SliceShape (sl, ()) = (SliceShape sl, Int) type CoSliceShape (sl, ()) = CoSliceShape sl type FullShape (sl, ()) = (FullShape sl, Int) - sliceIndex = SliceAll (sliceIndex @sl) + -- sliceIndex = SliceAll (sliceIndex @sl) instance Slice sl => Slice (sl, Int) where type SliceShape (sl, Int) = SliceShape sl type CoSliceShape (sl, Int) = (CoSliceShape sl, Int) type FullShape (sl, Int) = (FullShape sl, Int) - sliceIndex = SliceFixed (sliceIndex @sl) + -- sliceIndex = SliceFixed (sliceIndex @sl) -- |Generalised array index, which may index only in a subset of the dimensions -- of a shape. -- data SliceIndex ix slice coSlice sliceDim where SliceNil :: SliceIndex () () () () - SliceAll :: SliceIndex ix slice co dim -> SliceIndex (ix, ()) (slice, Int) co (dim, Int) - SliceFixed :: SliceIndex ix slice co dim -> SliceIndex (ix, Int) slice (co, Int) (dim, Int) + SliceAll :: SliceIndex ix slice co dim -> SliceIndex (ix, ()) (slice, SingletonType Int) co (dim, SingletonType Int) + SliceFixed :: SliceIndex ix slice co dim -> SliceIndex (ix, SingletonType Int) slice (co, SingletonType Int) (dim, SingletonType Int) instance Show (SliceIndex ix slice coSlice sliceDim) where show SliceNil = "SliceNil" @@ -74,17 +75,17 @@ sliceShape (SliceFixed slix) (sh, _) = sliceShape slix sh sliceDomain :: SliceIndex slix sl co dim -> slix -> sl -> dim sliceDomain SliceNil () () = () sliceDomain (SliceAll slix) (slx, ()) (sl, sz) = (sliceDomain slix slx sl, sz) -sliceDomain (SliceFixed slix) (slx, sz) sl = (sliceDomain slix slx sl, sz) +-- sliceDomain (SliceFixed slix) (slx, sz) sl = (sliceDomain slix slx sl, sz) sliceShapeR :: SliceIndex slix sl co dim -> ShapeR sl sliceShapeR SliceNil = ShapeRz -sliceShapeR (SliceAll sl) = ShapeRsnoc $ sliceShapeR sl +-- sliceShapeR (SliceAll sl) = ShapeRsnoc $ sliceShapeR sl sliceShapeR (SliceFixed sl) = sliceShapeR sl sliceDomainR :: SliceIndex slix sl co dim -> ShapeR dim sliceDomainR SliceNil = ShapeRz -sliceDomainR (SliceAll sl) = ShapeRsnoc $ sliceDomainR sl -sliceDomainR (SliceFixed sl) = ShapeRsnoc $ sliceDomainR sl +-- sliceDomainR (SliceAll sl) = ShapeRsnoc $ sliceDomainR sl +-- sliceDomainR (SliceFixed sl) = ShapeRsnoc $ sliceDomainR sl -- | Enumerate all slices within a given bound. The innermost dimension changes -- most rapidly. @@ -98,7 +99,7 @@ enumSlices -> [slix] enumSlices SliceNil () = [()] enumSlices (SliceAll sl) (sh, _) = [ (sh', ()) | sh' <- enumSlices sl sh] -enumSlices (SliceFixed sl) (sh, n) = [ (sh', i) | sh' <- enumSlices sl sh, i <- [0..n-1]] +-- enumSlices (SliceFixed sl) (sh, n) = [ (sh', i) | sh' <- enumSlices sl sh, i <- [0..n-1]] rnfSliceIndex :: SliceIndex ix slice co sh -> () rnfSliceIndex SliceNil = () diff --git a/src/Data/Array/Accelerate/Sugar/Shape.hs b/src/Data/Array/Accelerate/Sugar/Shape.hs index 5034f2f8d..a86c9f10b 100644 --- a/src/Data/Array/Accelerate/Sugar/Shape.hs +++ b/src/Data/Array/Accelerate/Sugar/Shape.hs @@ -30,41 +30,17 @@ module Data.Array.Accelerate.Sugar.Shape where --- import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type +import Data.Array.Accelerate.Representation.POS as POS import qualified Data.Array.Accelerate.Representation.Shape as R import qualified Data.Array.Accelerate.Representation.Slice as R import Data.Kind -import GHC.Generics +import GHC.Generics as GHC -class Elt' a where - type EltR a - - eltR :: TypeR (EltR a) - fromElt :: a -> EltR a - toElt :: EltR a -> a - -instance Elt' Int where - type EltR Int = Int - - fromElt = id - toElt = id - -instance Elt' Z where - type EltR Z = () - - fromElt Z = () - toElt () = Z - -instance Elt' All where - type EltR All = () - - fromElt All = () - toElt () = All - -- Shorthand for common shape types -- type DIM0 = Z @@ -81,14 +57,14 @@ type DIM9 = DIM8 :. Int -- | Rank-0 index -- data Z = Z - deriving (Show, Eq) + deriving (Show, Eq, GHC.Generic, POS.Generic, POSable, Elt) -- | Increase an index rank by one dimension. The ':.' operator is used to -- construct both values and types. -- infixl 3 :. data tail :. head = !tail :. !head - deriving (Eq, Generic) -- Not deriving Elt' or Show + deriving (Eq, GHC.Generic) -- Not deriving Elt or Show -- We don't we use a derived Show instance for (:.) because this will insert -- parenthesis to demonstrate which order the operator is applied, i.e.: @@ -122,7 +98,7 @@ instance (Show sh, Show sz) => Show (sh :. sz) where -- 'Data.Array.Accelerate.Language.replicate' for examples. -- data All = All - deriving (Show, Eq) + deriving (Show, Eq, GHC.Generic, POS.Generic, POSable, Elt) -- | Marker for arbitrary dimensions in 'Data.Array.Accelerate.Language.slice' -- and 'Data.Array.Accelerate.Language.replicate' descriptors. @@ -134,7 +110,7 @@ data All = All -- 'Data.Array.Accelerate.Language.replicate' for examples. -- data Any sh = Any - deriving (Show, Eq, Generic) + deriving (Show, Eq, GHC.Generic) -- | Marker for splitting along an entire dimension in division descriptors. -- @@ -151,7 +127,7 @@ data Split = Split -- For example, in the following definition, 'Divide' matches against any shape -- and flattens everything but the innermost dimension. -- --- > vectors :: (Shape sh, Elt' e) => Acc (Array (sh:.Int) e) -> Seq [Vector e] +-- > vectors :: (Shape sh, Elt e) => Acc (Array (sh:.Int) e) -> Seq [Vector e] -- > vectors = toSeq (Divide :. All) -- data Divide sh = Divide @@ -266,7 +242,7 @@ sliceShape slx = toElt . R.sliceShape slx . fromElt -- | Project the full shape from a slice -- sliceDomain - :: (Elt' slix, Shape sl, Shape dim) + :: (Elt slix, Shape sl, Shape dim) => R.SliceIndex (EltR slix) (EltR sl) co (EltR dim) -> slix -> sl @@ -283,7 +259,7 @@ sliceDomain slx slix sl = toElt $ R.sliceDomain slx (fromElt slix) (fromElt sl) -- > in -- > enumSlices slix sh :: [ Z :. Int :. Int :. All ] -- -enumSlices :: forall slix co sl dim. (Elt' slix, Elt' dim) +enumSlices :: forall slix co sl dim. (Elt slix, Elt dim) => R.SliceIndex (EltR slix) sl co (EltR dim) -> dim -- Bounds -> [slix] -- All slices within bounds. @@ -291,7 +267,7 @@ enumSlices slix = map toElt . R.enumSlices slix . fromElt -- | Shapes and indices of multi-dimensional arrays -- -class (Elt' sh, Elt' (Any sh), FullShape sh ~ sh, CoSliceShape sh ~ sh, SliceShape sh ~ Z) +class (Elt sh, Elt (Any sh), FullShape sh ~ sh, CoSliceShape sh ~ sh, SliceShape sh ~ Z) => Shape sh where -- | Reified type witness for shapes @@ -307,7 +283,7 @@ class (Elt' sh, Elt' (Any sh), FullShape sh ~ sh, CoSliceShape sh ~ sh, SliceSha -- | Slices, aka generalised indices, as /n/-tuples and mappings of slice -- indices to slices, co-slices, and slice dimensions -- -class (Elt' sl, Shape (SliceShape sl), Shape (CoSliceShape sl), Shape (FullShape sl)) +class (Elt sl, Shape (SliceShape sl), Shape (CoSliceShape sl), Shape (FullShape sl)) => Slice sl where type SliceShape sl :: Type -- the projected slice type CoSliceShape sl :: Type -- the complement of the slice @@ -328,29 +304,26 @@ class (Slice (DivisionSlice sl)) => Division sl where (EltR (CoSliceShape slix)) (EltR (FullShape slix)) -instance (Elt' t, Elt' h) => Elt' (t :. h) where +instance (Elt t, Elt h) => Elt (t :. h) where type EltR (t :. h) = (EltR t, EltR h) - -- eltR = TupRpair (eltR @t) (eltR @h) - -- tagsR = [TagRpair t h | t <- tagsR @t, h <- tagsR @h] + eltR = TupRpair (eltR @t) (eltR @h) + tagsR = [TagRpair t h | t <- tagsR @t, h <- tagsR @h] fromElt (t:.h) = (fromElt t, fromElt h) toElt (t, h) = toElt t :. toElt h -instance Elt' (Any Z) where - type EltR (Any Z) = () - - fromElt Any = () - toElt () = Any - -instance Shape sh => Elt' (Any (sh :. Int)) where +instance POS.Generic (Any Z) +instance POSable (Any Z) +instance Elt (Any Z) +instance Shape sh => Elt (Any (sh :. Int)) where type EltR (Any (sh :. Int)) = (EltR (Any sh), ()) - -- eltR = TupRpair (eltR @(Any sh)) TupRunit - -- tagsR = [TagRpair t TagRunit | t <- tagsR @(Any sh)] + eltR = TupRpair (eltR @(Any sh)) TupRunit + tagsR = [TagRpair t TagRunit | t <- tagsR @(Any sh)] fromElt _ = (fromElt (Any :: Any sh), ()) toElt _ = Any instance Shape Z where shapeR = R.ShapeRz - sliceAnyIndex = R.SliceNil + -- sliceAnyIndex = R.SliceNil sliceNoneIndex = R.SliceNil -- Note that the constraint 'i ~ Int' allows the compiler to infer that diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index daae1f3ca..4ea331550 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -132,6 +132,9 @@ fromEltR x = fromPOSable cs fs 1 -> (0, unsafeCoerce x) _ -> unsafeCoerce x + +type SingletonType x = (ScalarType (Int, ()), ()) + -- Scalar types -- ------------ From b878ec6bce32452c458e08425f5deb2779f8f600 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 3 Mar 2022 11:01:27 +0100 Subject: [PATCH 12/67] Slices understand SingletonTypes now --- .../Array/Accelerate/Representation/Slice.hs | 48 ++++++++++++++----- 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/src/Data/Array/Accelerate/Representation/Slice.hs b/src/Data/Array/Accelerate/Representation/Slice.hs index a40951e4a..c28d50cee 100644 --- a/src/Data/Array/Accelerate/Representation/Slice.hs +++ b/src/Data/Array/Accelerate/Representation/Slice.hs @@ -19,6 +19,7 @@ module Data.Array.Accelerate.Representation.Slice where import Data.Array.Accelerate.Representation.Shape +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Type import Language.Haskell.TH.Extra @@ -39,16 +40,16 @@ instance Slice () where sliceIndex = SliceNil instance Slice sl => Slice (sl, ()) where - type SliceShape (sl, ()) = (SliceShape sl, Int) + type SliceShape (sl, ()) = (SliceShape sl, SingletonType Int) type CoSliceShape (sl, ()) = CoSliceShape sl - type FullShape (sl, ()) = (FullShape sl, Int) - -- sliceIndex = SliceAll (sliceIndex @sl) + type FullShape (sl, ()) = (FullShape sl, SingletonType Int) + sliceIndex = SliceAll (sliceIndex @sl) -instance Slice sl => Slice (sl, Int) where - type SliceShape (sl, Int) = SliceShape sl - type CoSliceShape (sl, Int) = (CoSliceShape sl, Int) - type FullShape (sl, Int) = (FullShape sl, Int) - -- sliceIndex = SliceFixed (sliceIndex @sl) +instance Slice sl => Slice (sl, SingletonType Int) where + type SliceShape (sl, SingletonType Int) = SliceShape sl + type CoSliceShape (sl, SingletonType Int) = (CoSliceShape sl, SingletonType Int) + type FullShape (sl, SingletonType Int) = (FullShape sl, SingletonType Int) + sliceIndex = SliceFixed (sliceIndex @sl) -- |Generalised array index, which may index only in a subset of the dimensions -- of a shape. @@ -75,17 +76,17 @@ sliceShape (SliceFixed slix) (sh, _) = sliceShape slix sh sliceDomain :: SliceIndex slix sl co dim -> slix -> sl -> dim sliceDomain SliceNil () () = () sliceDomain (SliceAll slix) (slx, ()) (sl, sz) = (sliceDomain slix slx sl, sz) --- sliceDomain (SliceFixed slix) (slx, sz) sl = (sliceDomain slix slx sl, sz) +sliceDomain (SliceFixed slix) (slx, sz) sl = (sliceDomain slix slx sl, sz) sliceShapeR :: SliceIndex slix sl co dim -> ShapeR sl sliceShapeR SliceNil = ShapeRz --- sliceShapeR (SliceAll sl) = ShapeRsnoc $ sliceShapeR sl +sliceShapeR (SliceAll sl) = ShapeRsnoc $ sliceShapeR sl sliceShapeR (SliceFixed sl) = sliceShapeR sl sliceDomainR :: SliceIndex slix sl co dim -> ShapeR dim sliceDomainR SliceNil = ShapeRz --- sliceDomainR (SliceAll sl) = ShapeRsnoc $ sliceDomainR sl --- sliceDomainR (SliceFixed sl) = ShapeRsnoc $ sliceDomainR sl +sliceDomainR (SliceAll sl) = ShapeRsnoc $ sliceDomainR sl +sliceDomainR (SliceFixed sl) = ShapeRsnoc $ sliceDomainR sl -- | Enumerate all slices within a given bound. The innermost dimension changes -- most rapidly. @@ -99,7 +100,28 @@ enumSlices -> [slix] enumSlices SliceNil () = [()] enumSlices (SliceAll sl) (sh, _) = [ (sh', ()) | sh' <- enumSlices sl sh] --- enumSlices (SliceFixed sl) (sh, n) = [ (sh', i) | sh' <- enumSlices sl sh, i <- [0..n-1]] +enumSlices (SliceFixed sl) (sh, n) = [ (sh', i) | sh' <- enumSlices sl sh, i <- [0..n-1]] + +-- These functions and the Num, Enum instance make sure we can use the range +-- syntax used above. We might have to provide these instances for all +-- SingletonTypes maybe? +liftSingNumBinary :: (Elt a, EltR a ~ SingletonType a) => (a -> a -> a) -> SingletonType a -> SingletonType a -> SingletonType a +liftSingNumBinary f x y = fromElt $ f (toElt x) (toElt y) + +liftSingNumUnary :: (Elt a, EltR a ~ SingletonType a) => (a -> a) -> SingletonType a -> SingletonType a +liftSingNumUnary f x = fromElt $ f (toElt x) + +instance Num (SingletonType Int) where + (+) = liftSingNumBinary @Int (+) + (*) = liftSingNumBinary @Int (*) + (-) = liftSingNumBinary @Int (-) + abs = liftSingNumUnary @Int abs + signum = liftSingNumUnary @Int abs + fromInteger = fromElt . fromInteger @Int + +instance Enum (SingletonType Int) where + toEnum = fromElt + fromEnum = toElt rnfSliceIndex :: SliceIndex ix slice co sh -> () rnfSliceIndex SliceNil = () From a1e1497ba355886b3d4c340936ed1a358d8f4527 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 3 Mar 2022 14:13:08 +0100 Subject: [PATCH 13/67] shapes with singletontypes --- .../Array/Accelerate/Representation/Shape.hs | 101 ++++++++++-------- .../Array/Accelerate/Representation/Slice.hs | 21 ---- src/Data/Array/Accelerate/Sugar/Elt.hs | 32 +++++- src/Data/Array/Accelerate/Type.hs | 7 +- 4 files changed, 88 insertions(+), 73 deletions(-) diff --git a/src/Data/Array/Accelerate/Representation/Shape.hs b/src/Data/Array/Accelerate/Representation/Shape.hs index 3b7f6493b..907e177e5 100644 --- a/src/Data/Array/Accelerate/Representation/Shape.hs +++ b/src/Data/Array/Accelerate/Representation/Shape.hs @@ -27,6 +27,7 @@ module Data.Array.Accelerate.Representation.Shape import Data.Array.Accelerate.Error import Data.Array.Accelerate.Type import Data.Array.Accelerate.Representation.Type +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Representation.POS import Data.Type.POSable.Representation @@ -50,21 +51,21 @@ showShape shr = foldr (\sh str -> str ++ " :. " ++ show sh) "Z" . shapeToList sh -- Synonyms for common shape types -- type DIM0 = () -type DIM1 = ((), Int) -type DIM2 = (((), Int), Int) -type DIM3 = ((((), Int), Int), Int) +type DIM1 = ((), SingletonType Int) +type DIM2 = (((), SingletonType Int), SingletonType Int) +type DIM3 = ((((), SingletonType Int), SingletonType Int), SingletonType Int) dim0 :: ShapeR DIM0 dim0 = ShapeRz --- dim1 :: ShapeR DIM1 --- dim1 = ShapeRsnoc dim0 +dim1 :: ShapeR DIM1 +dim1 = ShapeRsnoc dim0 --- dim2 :: ShapeR DIM2 --- dim2 = ShapeRsnoc dim1 +dim2 :: ShapeR DIM2 +dim2 = ShapeRsnoc dim1 --- dim3 :: ShapeR DIM3 --- dim3 = ShapeRsnoc dim2 +dim3 :: ShapeR DIM3 +dim3 = ShapeRsnoc dim2 -- | Number of dimensions of a /shape/ or /index/ (>= 0) -- @@ -74,17 +75,17 @@ rank (ShapeRsnoc shr) = rank shr + 1 -- | Total number of elements in an array of the given shape -- -size :: ShapeR sh -> sh -> Int +size :: ShapeR sh -> sh -> SingletonType Int size ShapeRz () = 1 --- size (ShapeRsnoc shr) (sh, sz) --- | sz <= 0 = 0 --- | otherwise = size shr sh * sz +size (ShapeRsnoc shr) (sh, sz) + -- | toElt sz <= 0 = 0 -- TODO fix Ord instance + | otherwise = size shr sh * sz -- | The empty shape -- empty :: ShapeR sh -> sh empty ShapeRz = () --- empty (ShapeRsnoc shr) = (empty shr, 0) +empty (ShapeRsnoc shr) = (empty shr, 0) -- | Yield the intersection of two shapes -- @@ -96,13 +97,13 @@ intersect = zip min union :: ShapeR sh -> sh -> sh -> sh union = zip max -zip :: (Int -> Int -> Int) -> ShapeR sh -> sh -> sh -> sh +zip :: (SingletonType Int -> SingletonType Int -> SingletonType Int) -> ShapeR sh -> sh -> sh -> sh zip _ ShapeRz () () = () --- zip f (ShapeRsnoc shr) (as, a) (bs, b) = (zip f shr as bs, f a b) +zip f (ShapeRsnoc shr) (as, a) (bs, b) = (zip f shr as bs, f a b) eq :: ShapeR sh -> sh -> sh -> Bool eq ShapeRz () () = True --- eq (ShapeRsnoc shr) (sh, i) (sh', i') = i == i' && eq shr sh sh' +eq (ShapeRsnoc shr) (sh, i) (sh', i') = i == i' && eq shr sh sh' -- | Map a multi-dimensional index into one in a linear, row-major @@ -111,23 +112,29 @@ eq ShapeRz () () = True -- toIndex :: HasCallStack => ShapeR sh -> sh -> sh -> Int toIndex ShapeRz () () = 0 --- toIndex (ShapeRsnoc shr) (sh, sz) (ix, i) --- = indexCheck i sz --- $ toIndex shr sh ix * sz + i +toIndex (ShapeRsnoc shr) (sh, sz) (ix, i) + = indexCheck (toElt i) (toElt sz) + $ toIndex shr sh ix * toElt sz + toElt i -- | Inverse of 'toIndex' -- -fromIndex :: HasCallStack => ShapeR sh -> sh -> Int -> sh +fromIndex :: HasCallStack => ShapeR sh -> sh -> SingletonType Int -> sh fromIndex ShapeRz () _ = () --- fromIndex (ShapeRsnoc shr) (sh, sz) i --- = (fromIndex shr sh (i `quotInt` sz), r) +fromIndex (ShapeRsnoc shr) (sh, sz) i + = (fromIndex shr sh (i `liftQuotInt` sz), r) -- If we assume that the index is in range, there is no point in computing -- the remainder for the highest dimension since i < sz must hold. - -- + where - -- r = case shr of -- Check if rank of shr is 0 - -- ShapeRz -> indexCheck i sz i - -- _ -> i `remInt` sz + r = case shr of -- Check if rank of shr is 0 + ShapeRz -> indexCheck (toElt i) (toElt sz) i + _ -> i `liftRemInt` sz + +liftQuotInt :: SingletonType Int -> SingletonType Int -> SingletonType Int +liftQuotInt = liftSingNumBinary quotInt + +liftRemInt :: SingletonType Int -> SingletonType Int -> SingletonType Int +liftRemInt = liftSingNumBinary remInt -- | Iterate through the entire shape, applying the function in the second -- argument; third argument combines results and fourth is an initial value @@ -136,20 +143,20 @@ fromIndex ShapeRz () _ = () -- iter :: ShapeR sh -> sh -> (sh -> a) -> (a -> a -> a) -> a -> a iter ShapeRz () f _ _ = f () --- iter (ShapeRsnoc shr) (sh, sz) f c z = iter shr sh (\ix -> iter' (ix,0) z) c z --- where --- iter' (ix,i) r | i >= sz = r --- | otherwise = iter' (ix,i+1) (r `c` f (ix,i)) +iter (ShapeRsnoc shr) (sh, sz) f c z = iter shr sh (\ix -> iter' (ix,0) z) c z + where + iter' (ix,i) r | i >= sz = r + | otherwise = iter' (ix,i+1) (r `c` f (ix,i)) -- | Variant of 'iter' without an initial value -- iter1 :: HasCallStack => ShapeR sh -> sh -> (sh -> a) -> (a -> a -> a) -> a iter1 ShapeRz () f _ = f () --- iter1 (ShapeRsnoc _ ) (_, 0) _ _ = boundsError "empty iteration space" --- iter1 (ShapeRsnoc shr) (sh, sz) f c = iter1 shr sh (\ix -> iter1' (ix,0)) c --- where --- iter1' (ix,i) | i == sz-1 = f (ix,i) --- | otherwise = f (ix,i) `c` iter1' (ix,i+1) +iter1 (ShapeRsnoc _ ) (_, 0) _ _ = boundsError "empty iteration space" +iter1 (ShapeRsnoc shr) (sh, sz) f c = iter1 shr sh (\ix -> iter1' (ix,0)) c + where + iter1' (ix,i) | i == sz-1 = f (ix,i) + | otherwise = f (ix,i) `c` iter1' (ix,i+1) -- Operations to facilitate conversion with IArray @@ -157,23 +164,23 @@ iter1 ShapeRz () f _ = f () -- rangeToShape :: ShapeR sh -> (sh, sh) -> sh rangeToShape ShapeRz ((), ()) = () --- rangeToShape (ShapeRsnoc shr) ((sh1, sz1), (sh2, sz2)) = (rangeToShape shr (sh1, sh2), sz2 - sz1 + 1) +rangeToShape (ShapeRsnoc shr) ((sh1, sz1), (sh2, sz2)) = (rangeToShape shr (sh1, sh2), sz2 - sz1 + 1) -- | Converse of 'rangeToShape' -- shapeToRange :: ShapeR sh -> sh -> (sh, sh) shapeToRange ShapeRz () = ((), ()) --- shapeToRange (ShapeRsnoc shr) (sh, sz) = let (low, high) = shapeToRange shr sh in ((low, 0), (high, sz - 1)) +shapeToRange (ShapeRsnoc shr) (sh, sz) = let (low, high) = shapeToRange shr sh in ((low, 0), (high, sz - 1)) -- | Convert a shape or index into its list of dimensions -- -shapeToList :: ShapeR sh -> sh -> [Int] +shapeToList :: ShapeR sh -> sh -> [SingletonType Int] shapeToList ShapeRz () = [] --- shapeToList (ShapeRsnoc shr) (sh,sz) = sz : shapeToList shr sh +shapeToList (ShapeRsnoc shr) (sh,sz) = sz : shapeToList shr sh -- | Convert a list of dimensions into a shape -- -listToShape :: HasCallStack => ShapeR sh -> [Int] -> sh +listToShape :: HasCallStack => ShapeR sh -> [SingletonType Int] -> sh listToShape shr ds = case listToShape' shr ds of Just sh -> sh @@ -181,17 +188,17 @@ listToShape shr ds = -- | Attempt to convert a list of dimensions into a shape -- -listToShape' :: ShapeR sh -> [Int] -> Maybe sh +listToShape' :: ShapeR sh -> [SingletonType Int] -> Maybe sh listToShape' ShapeRz [] = Just () --- listToShape' (ShapeRsnoc shr) (x:xs) = (, x) <$> listToShape' shr xs +listToShape' (ShapeRsnoc shr) (x:xs) = (, x) <$> listToShape' shr xs listToShape' _ _ = Nothing shapeType :: ShapeR sh -> TypeR sh shapeType ShapeRz = TupRunit --- shapeType (ShapeRsnoc shr) = --- shapeType shr --- `TupRpair` --- TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeInt))) +shapeType (ShapeRsnoc shr) = + shapeType shr + `TupRpair` + TupRsingle (SingleScalarType (NumSingleType (IntegralNumType (TypeSingletonType @Int)))) rnfShape :: ShapeR sh -> sh -> () rnfShape ShapeRz () = () diff --git a/src/Data/Array/Accelerate/Representation/Slice.hs b/src/Data/Array/Accelerate/Representation/Slice.hs index c28d50cee..5cd5905f7 100644 --- a/src/Data/Array/Accelerate/Representation/Slice.hs +++ b/src/Data/Array/Accelerate/Representation/Slice.hs @@ -102,27 +102,6 @@ enumSlices SliceNil () = [()] enumSlices (SliceAll sl) (sh, _) = [ (sh', ()) | sh' <- enumSlices sl sh] enumSlices (SliceFixed sl) (sh, n) = [ (sh', i) | sh' <- enumSlices sl sh, i <- [0..n-1]] --- These functions and the Num, Enum instance make sure we can use the range --- syntax used above. We might have to provide these instances for all --- SingletonTypes maybe? -liftSingNumBinary :: (Elt a, EltR a ~ SingletonType a) => (a -> a -> a) -> SingletonType a -> SingletonType a -> SingletonType a -liftSingNumBinary f x y = fromElt $ f (toElt x) (toElt y) - -liftSingNumUnary :: (Elt a, EltR a ~ SingletonType a) => (a -> a) -> SingletonType a -> SingletonType a -liftSingNumUnary f x = fromElt $ f (toElt x) - -instance Num (SingletonType Int) where - (+) = liftSingNumBinary @Int (+) - (*) = liftSingNumBinary @Int (*) - (-) = liftSingNumBinary @Int (-) - abs = liftSingNumUnary @Int abs - signum = liftSingNumUnary @Int abs - fromInteger = fromElt . fromInteger @Int - -instance Enum (SingletonType Int) where - toEnum = fromElt - fromEnum = toElt - rnfSliceIndex :: SliceIndex ix slice co sh -> () rnfSliceIndex SliceNil = () rnfSliceIndex (SliceAll sh) = rnfSliceIndex sh diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index 39ced1dd9..7ff09f2b7 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -22,7 +22,7 @@ -- Portability : non-portable (GHC extensions) -- -module Data.Array.Accelerate.Sugar.Elt ( Elt(..) ) +module Data.Array.Accelerate.Sugar.Elt ( Elt(..), SingletonType, liftSingNumBinary, liftSingNumUnary ) where import Data.Array.Accelerate.Representation.Elt @@ -100,7 +100,6 @@ class Elt a where default toElt :: (POSable a, POStoEltR (Choices a) (Fields a) ~ EltR a) => EltR a -> a toElt a = fromEltR a - untag :: TypeR t -> TagR t untag TupRunit = TagRunit untag (TupRsingle t) = TagRundef t @@ -234,3 +233,32 @@ runQ $ do -- vs <- sequence [ mkVecElt t n | t <- integralTypes ++ floatingTypes, n <- [2,3,4,8,16] ] return (concat ss) + + + +-- These functions and the Num, Enum instance make sure we can use the range +-- syntax and other Num stuff. We might have to provide these instances for all +-- SingletonTypes maybe? +liftSingNumBinary :: (Elt a, EltR a ~ SingletonType a) => (a -> a -> a) -> SingletonType a -> SingletonType a -> SingletonType a +liftSingNumBinary f x y = fromElt $ f (toElt x) (toElt y) + +liftSingNumUnary :: (Elt a, EltR a ~ SingletonType a) => (a -> a) -> SingletonType a -> SingletonType a +liftSingNumUnary f x = fromElt $ f (toElt x) + +instance Num (SingletonType Int) where + (+) = liftSingNumBinary @Int (+) + (*) = liftSingNumBinary @Int (*) + (-) = liftSingNumBinary @Int (-) + abs = liftSingNumUnary @Int abs + signum = liftSingNumUnary @Int abs + fromInteger = fromElt . fromInteger @Int + +instance Enum (SingletonType Int) where + toEnum = fromElt + fromEnum = toElt + +instance {-# OVERLAPPING #-} Eq (SingletonType Int) where + (==) x y = toElt @Int x == toElt @Int y + +instance {-# OVERLAPPING #-} Ord (SingletonType Int) where + compare x y = compare (toElt @Int x) (toElt @Int y) diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index 4ea331550..b29843417 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -132,9 +132,6 @@ fromEltR x = fromPOSable cs fs 1 -> (0, unsafeCoerce x) _ -> unsafeCoerce x - -type SingletonType x = (ScalarType (Int, ()), ()) - -- Scalar types -- ------------ @@ -171,6 +168,10 @@ data IntegralType a where TypeWord16 :: IntegralType Word16 TypeWord32 :: IntegralType Word32 TypeWord64 :: IntegralType Word64 + TypeSingletonType :: IntegralType (SingletonType a) + + +type SingletonType x = (ScalarType (Int, ()), ()) -- | Floating-point types supported in array computations. -- From 1e4a1f5eb19ac4e3ae8287a218296d05db2a5c4b Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 3 Mar 2022 14:19:03 +0100 Subject: [PATCH 14/67] shape sugar with singletontypes --- src/Data/Array/Accelerate/Sugar/Shape.hs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/Data/Array/Accelerate/Sugar/Shape.hs b/src/Data/Array/Accelerate/Sugar/Shape.hs index a86c9f10b..16b81971b 100644 --- a/src/Data/Array/Accelerate/Sugar/Shape.hs +++ b/src/Data/Array/Accelerate/Sugar/Shape.hs @@ -141,7 +141,7 @@ rank = R.rank (shapeR @sh) -- | Total number of elements in an array of the given /shape/ -- -size :: forall sh. Shape sh => sh -> Int +size :: forall sh. Shape sh => sh -> SingletonType Int size = R.size (shapeR @sh) . fromElt -- | The empty /shape/ @@ -172,7 +172,7 @@ toIndex sh ix = R.toIndex (shapeR @sh) (fromElt sh) (fromElt ix) -- fromIndex :: forall sh. Shape sh => sh -- ^ Total shape (extent) of the array - -> Int -- ^ The argument index + -> SingletonType Int -- ^ The argument index -> sh -- ^ Corresponding multi-dimensional index fromIndex sh = toElt . R.fromIndex (shapeR @sh) (fromElt sh) @@ -210,19 +210,19 @@ shapeToRange ix = -- | Convert a shape to a list of dimensions -- -shapeToList :: forall sh. Shape sh => sh -> [Int] +shapeToList :: forall sh. Shape sh => sh -> [SingletonType Int] shapeToList = R.shapeToList (shapeR @sh) . fromElt -- | Convert a list of dimensions into a shape. If the list does not -- contain exactly the number of elements as specified by the type of the -- shape: error. -- -listToShape :: forall sh. Shape sh => [Int] -> sh +listToShape :: forall sh. Shape sh => [SingletonType Int] -> sh listToShape = toElt . R.listToShape (shapeR @sh) -- | Attempt to convert a list of dimensions into a shape -- -listToShape' :: forall sh. Shape sh => [Int] -> Maybe sh +listToShape' :: forall sh. Shape sh => [SingletonType Int] -> Maybe sh listToShape' = fmap toElt . R.listToShape' (shapeR @sh) -- | Nicely format a shape as a string From c66b0f7e4bd07e55aec0e7705e982300a5acd4f7 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 3 Mar 2022 14:52:33 +0100 Subject: [PATCH 15/67] more array with singletontypes --- .../Array/Accelerate/Representation/Array.hs | 25 +++++++++++-------- .../Array/Accelerate/Representation/Shape.hs | 2 +- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/Data/Array/Accelerate/Representation/Array.hs b/src/Data/Array/Accelerate/Representation/Array.hs index d61304e76..c337f301d 100644 --- a/src/Data/Array/Accelerate/Representation/Array.hs +++ b/src/Data/Array/Accelerate/Representation/Array.hs @@ -26,6 +26,7 @@ import Data.Array.Accelerate.Type import Data.Array.Accelerate.Representation.Elt import Data.Array.Accelerate.Representation.Shape hiding ( zip ) import Data.Array.Accelerate.Representation.Type +import Data.Array.Accelerate.Sugar.Elt import Data.List ( intersperse ) import Data.Maybe ( isJust ) @@ -98,7 +99,7 @@ arraysRpair a b = TupRunit `TupRpair` TupRsingle a `TupRpair` TupRsingle b -- allocateArray :: ArrayR (Array sh e) -> sh -> IO (Array sh e) allocateArray (ArrayR shR eR) sh = do - adata <- newArrayData eR (size shR sh) + adata <- newArrayData eR (toElt $ size shR sh) return $! Array sh adata -- | Create an array from its representation function, applied at each @@ -114,13 +115,13 @@ fromFunction repr sh f = unsafePerformIO $! fromFunctionM repr sh (return . f) fromFunctionM :: ArrayR (Array sh e) -> sh -> (sh -> IO e) -> IO (Array sh e) fromFunctionM (ArrayR shR eR) sh f = do let !n = size shR sh - arr <- newArrayData eR n + arr <- newArrayData eR (toElt n) -- let write !i | i >= n = return () | otherwise = do v <- f (fromIndex shR sh i) - writeArrayData eR arr i v + writeArrayData eR arr (toElt i) v write (i+1) -- write 0 @@ -137,9 +138,9 @@ fromList (ArrayR shR eR) sh xs = adata `seq` Array sh adata -- !n = size shR sh (adata, _) = runArrayData @e $ do - arr <- newArrayData eR n + arr <- newArrayData eR (toElt n) let go !i _ | i >= n = return () - go !i (v:vs) = writeArrayData eR arr i v >> go (i+1) vs + go !i (v:vs) = writeArrayData eR arr (toElt i) v >> go (i+1) vs go _ [] = error "Data.Array.Accelerate.fromList: not enough input data" -- go 0 xs @@ -156,16 +157,16 @@ toList (ArrayR shR eR) (Array sh adata) = go 0 -- !n = size shR sh go !i | i >= n = [] - | otherwise = indexArrayData eR adata i : go (i+1) + | otherwise = indexArrayData eR adata (toElt i) : go (i+1) concatVectors :: forall e. TypeR e -> [Vector e] -> Vector e -concatVectors tR vs = adata `seq` Array ((), len) adata +concatVectors tR vs = adata `seq` Array ((), fromElt len) adata where offsets = scanl (+) 0 (map (size dim1 . shape) vs) - len = last offsets + len = toElt $ last offsets (adata, _) = runArrayData @e $ do arr <- newArrayData tR len - sequence_ [ writeArrayData tR arr (i + k) (indexArrayData tR ad i) + sequence_ [ writeArrayData tR arr (toElt (i + k)) (indexArrayData tR ad (toElt i)) | (Array ((), n) ad, k) <- vs `zip` offsets , i <- [0 .. n - 1] ] return (arr, undefined) @@ -217,7 +218,9 @@ showMatrix f (ArrayR _ arrR) arr@(Array sh _) | rows * cols == 0 = "[]" | otherwise = "\n [" ++ ppMat 0 0 where - (((), rows), cols) = sh + rows = toElt rows' + cols = toElt cols' + (((), rows'), cols') = sh lengths = U.generate (rows*cols) (\i -> length (f (linearIndexArray arrR arr i) "")) widths = U.generate cols (\c -> U.maximum (U.generate rows (\r -> lengths U.! (r*cols+c)))) -- @@ -321,7 +324,7 @@ liftArray (ArrayR shR adR) (Array sh adata) = [|| Array $$(liftElt (shapeType shR) sh) $$(liftArrayData sz adR adata) ||] `at` [t| Array $(liftTypeQ (shapeType shR)) $(liftTypeQ adR) |] where sz :: Int - sz = size shR sh + sz = toElt (size shR sh) at :: CodeQ t -> Q Type -> CodeQ t at e t = unsafeCodeCoerce $ sigE (unTypeCode e) t diff --git a/src/Data/Array/Accelerate/Representation/Shape.hs b/src/Data/Array/Accelerate/Representation/Shape.hs index 907e177e5..76a30fc0e 100644 --- a/src/Data/Array/Accelerate/Representation/Shape.hs +++ b/src/Data/Array/Accelerate/Representation/Shape.hs @@ -78,7 +78,7 @@ rank (ShapeRsnoc shr) = rank shr + 1 size :: ShapeR sh -> sh -> SingletonType Int size ShapeRz () = 1 size (ShapeRsnoc shr) (sh, sz) - -- | toElt sz <= 0 = 0 -- TODO fix Ord instance + | sz <= 0 = 0 | otherwise = size shr sh * sz -- | The empty shape From 88401e1146edf05152d1520e82453e99de2dd2e3 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 3 Mar 2022 14:56:20 +0100 Subject: [PATCH 16/67] reverted elt' change --- src/Data/Array/Accelerate/Sugar/Array.hs | 38 ++++++++++-------------- src/Data/Array/Accelerate/Sugar/Elt.hs | 13 ++++++-- 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/src/Data/Array/Accelerate/Sugar/Array.hs b/src/Data/Array/Accelerate/Sugar/Array.hs index 00a52a743..87a490246 100644 --- a/src/Data/Array/Accelerate/Sugar/Array.hs +++ b/src/Data/Array/Accelerate/Sugar/Array.hs @@ -22,15 +22,13 @@ module Data.Array.Accelerate.Sugar.Array where -import Data.Array.Accelerate.Type +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Shape import Data.Array.Accelerate.Representation.Type import qualified Data.Array.Accelerate.Representation.Array as R import Control.DeepSeq import Data.Kind -import Data.Char -import Data.Word import Data.Typeable import Language.Haskell.TH.Extra hiding ( Type ) import System.IO.Unsafe @@ -42,12 +40,6 @@ import qualified GHC.Exts as GHC -- $setup -- >>> :seti -XOverloadedLists -instance Elt' Char where - type EltR Char = Word32 - eltR = TupRsingle scalarType - -- tagsR = [TagRsingle scalarType] - toElt = chr . fromIntegral - fromElt = fromIntegral . ord type Scalar = Array DIM0 -- ^ Scalar arrays hold a single element type Vector = Array DIM1 -- ^ Vectors are one-dimensional arrays @@ -103,14 +95,14 @@ type Segments = Vector newtype Array sh e = Array (R.Array (EltR sh) (EltR e)) deriving Typeable -instance (Shape sh, Elt' e, Eq sh, Eq e) => Eq (Array sh e) where +instance (Shape sh, Elt e, Eq sh, Eq e) => Eq (Array sh e) where arr1 == arr2 = shape arr1 == shape arr2 && toList arr1 == toList arr2 arr1 /= arr2 = shape arr1 /= shape arr2 || toList arr1 /= toList arr2 -instance (Shape sh, Elt' e, Show e) => Show (Array sh e) where +instance (Shape sh, Elt e, Show e) => Show (Array sh e) where show (Array arr) = R.showArray (shows . toElt @e) (arrayR @sh @e) arr -instance Elt' e => IsList (Vector e) where +instance Elt e => IsList (Vector e) where type Item (Vector e) = e toList = toList fromListN n = fromList (Z:.n) @@ -119,7 +111,7 @@ instance Elt' e => IsList (Vector e) where instance IsString (Vector Char) where fromString s = fromList (Z :. length s) s -instance (Shape sh, Elt' e) => NFData (Array sh e) where +instance (Shape sh, Elt e) => NFData (Array sh e) where rnf (Array arr) = R.rnfArray (arrayR @sh @e) arr -- Note: [Embedded class constraints on Array] @@ -154,26 +146,26 @@ reshape sh (Array arr) = Array $ R.reshape (shapeR @sh) (fromElt sh) (shapeR @sh -- | Return the value of an array at the given multidimensional index -- infixl 9 ! -(!) :: forall sh e. (Shape sh, Elt' e) => Array sh e -> sh -> e +(!) :: forall sh e. (Shape sh, Elt e) => Array sh e -> sh -> e (!) (Array arr) ix = toElt $ R.indexArray (arrayR @sh @e) arr (fromElt ix) -- | Return the value of an array at given the linear (row-major) index -- infixl 9 !! -(!!) :: forall sh e. Elt' e => Array sh e -> Int -> e +(!!) :: forall sh e. Elt e => Array sh e -> Int -> e (!!) (Array arr) i = toElt $ R.linearIndexArray (eltR @e) arr i -- | Create an array from its representation function, applied at each -- index of the array -- -fromFunction :: (Shape sh, Elt' e) => sh -> (sh -> e) -> Array sh e +fromFunction :: (Shape sh, Elt e) => sh -> (sh -> e) -> Array sh e fromFunction sh f = unsafePerformIO $! fromFunctionM sh (return . f) -- | Create an array using a monadic function applied at each index -- -- @since 1.2.0.0 -- -fromFunctionM :: forall sh e. (Shape sh, Elt' e) => sh -> (sh -> IO e) -> IO (Array sh e) +fromFunctionM :: forall sh e. (Shape sh, Elt e) => sh -> (sh -> IO e) -> IO (Array sh e) fromFunctionM sh f = Array <$> R.fromFunctionM (arrayR @sh @e) (fromElt sh) f' where f' x = do @@ -182,12 +174,12 @@ fromFunctionM sh f = Array <$> R.fromFunctionM (arrayR @sh @e) (fromElt sh) f' -- | Create a vector from the concatenation of the given list of vectors -- -concatVectors :: forall e. Elt' e => [Vector e] -> Vector e +concatVectors :: forall e. Elt e => [Vector e] -> Vector e concatVectors = toArr . R.concatVectors (eltR @e) . map fromArr -- | Creates a new, uninitialized Accelerate array -- -allocateArray :: forall sh e. (Shape sh, Elt' e) => sh -> IO (Array sh e) +allocateArray :: forall sh e. (Shape sh, Elt e) => sh -> IO (Array sh e) allocateArray sh = Array <$> R.allocateArray (arrayR @sh @e) (fromElt sh) -- | Convert elements of a list into an Accelerate 'Array' @@ -220,12 +212,12 @@ allocateArray sh = Array <$> R.allocateArray (arrayR @sh @e) (fromElt sh) -- and then traversing it a second time to collect the elements into the array, -- thus forcing the spine of the list to be manifest on the heap. -- -fromList :: forall sh e. (Shape sh, Elt' e) => sh -> [e] -> Array sh e +fromList :: forall sh e. (Shape sh, Elt e) => sh -> [e] -> Array sh e fromList sh xs = toArr $ R.fromList (arrayR @sh @e) (fromElt sh) $ map fromElt xs -- | Convert an accelerated 'Array' to a list in row-major order -- -toList :: forall sh e. (Shape sh, Elt' e) => Array sh e -> [e] +toList :: forall sh e. (Shape sh, Elt e) => Array sh e -> [e] toList = map toElt . R.toList (arrayR @sh @e) . fromArr @@ -263,7 +255,7 @@ class Arrays a where => a -> ArraysR a fromArr = (`gfromArr` ()) . from -arrayR :: forall sh e. (Shape sh, Elt' e) => R.ArrayR (R.Array (EltR sh) (EltR e)) +arrayR :: forall sh e. (Shape sh, Elt e) => R.ArrayR (R.Array (EltR sh) (EltR e)) arrayR = R.ArrayR (shapeR @sh) (eltR @e) class GArrays f where @@ -307,7 +299,7 @@ instance Arrays () where fromArr = id toArr = id -instance (Shape sh, Elt' e) => Arrays (Array sh e) where +instance (Shape sh, Elt e) => Arrays (Array sh e) where type ArraysR (Array sh e) = R.Array (EltR sh) (EltR e) arraysR = R.arraysRarray (shapeR @sh) (eltR @e) fromArr (Array arr) = arr diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index 7ff09f2b7..5ab86fdbd 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -134,12 +134,19 @@ untag (TupRpair ta tb) = TagRpair (untag ta) (untag tb) -- Instances for basic types are generated at the end of this module. -- -instance (POSable ()) => Elt () -instance (POSable Bool) => Elt Bool -instance (POSable Ordering) => Elt Ordering +instance Elt () +instance Elt Bool +instance Elt Ordering instance (POSable (Maybe a), Elt a) => Elt (Maybe a) instance (POSable (Either a b), Elt a, Elt b) => Elt (Either a b) +instance Elt Char where + type EltR Char = Word32 + eltR = TupRsingle scalarType + tagsR = [TagRsingle scalarType] + toElt = chr . fromIntegral + fromElt = fromIntegral . ord + -- Anything that has a POS instance has a default Elt instance -- TODO: build instances for the sections of newtypes runQ $ do From fdca1a24f0ef99564bd5a4b0ef8924ce919e3b50 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Fri, 4 Mar 2022 14:38:36 +0100 Subject: [PATCH 17/67] AST understands POS --- src/Data/Array/Accelerate/AST.hs | 29 +++++++++++-------- .../Accelerate/Representation/Stencil.hs | 22 ++++++++------ src/Data/Array/Accelerate/Smart.hs | 18 ++++-------- src/Data/Array/Accelerate/Type.hs | 2 ++ 4 files changed, 37 insertions(+), 34 deletions(-) diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index a6d0f75f7..4a5a9477e 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -9,6 +9,8 @@ {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE DataKinds #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.AST @@ -146,6 +148,7 @@ import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Representation.Vec import Data.Array.Accelerate.Sugar.Foreign +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Type import Data.Primitive.Vec @@ -198,9 +201,8 @@ type ArrayVar = Var ArrayR type ArrayVars aenv = Vars ArrayR aenv -- Bool is not a primitive type -type PrimBool = TAG -type PrimMaybe a = (TAG, ((), a)) - +type PrimBool = EltR Bool +type PrimMaybe a = EltR (Maybe a) -- Trace messages data Message a where Message :: (a -> String) -- embedded show @@ -681,13 +683,13 @@ data PrimFun sig where PrimBOr :: IntegralType a -> PrimFun ((a, a) -> a) PrimBXor :: IntegralType a -> PrimFun ((a, a) -> a) PrimBNot :: IntegralType a -> PrimFun (a -> a) - PrimBShiftL :: IntegralType a -> PrimFun ((a, Int) -> a) - PrimBShiftR :: IntegralType a -> PrimFun ((a, Int) -> a) - PrimBRotateL :: IntegralType a -> PrimFun ((a, Int) -> a) - PrimBRotateR :: IntegralType a -> PrimFun ((a, Int) -> a) - PrimPopCount :: IntegralType a -> PrimFun (a -> Int) - PrimCountLeadingZeros :: IntegralType a -> PrimFun (a -> Int) - PrimCountTrailingZeros :: IntegralType a -> PrimFun (a -> Int) + PrimBShiftL :: IntegralType a -> PrimFun ((a, SingletonType Int) -> a) + PrimBShiftR :: IntegralType a -> PrimFun ((a, SingletonType Int) -> a) + PrimBRotateL :: IntegralType a -> PrimFun ((a, SingletonType Int) -> a) + PrimBRotateR :: IntegralType a -> PrimFun ((a, SingletonType Int) -> a) + PrimPopCount :: IntegralType a -> PrimFun (a -> SingletonType Int) + PrimCountLeadingZeros :: IntegralType a -> PrimFun (a -> SingletonType Int) + PrimCountTrailingZeros :: IntegralType a -> PrimFun (a -> SingletonType Int) -- operators from Fractional and Floating PrimFDiv :: FloatingType a -> PrimFun ((a, a) -> a) @@ -940,8 +942,11 @@ primFunType = \case integral = num . IntegralNumType floating = num . FloatingNumType - tbool = TupRsingle scalarTypeWord8 - tint = TupRsingle scalarTypeInt + tbool :: TypeR PrimBool + tbool = TupRpair (TupRsingle (TagScalarType @2 0)) TupRunit + + tint :: TypeR (SingletonType Int) + tint = TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeSingletonType))) -- Normal form data diff --git a/src/Data/Array/Accelerate/Representation/Stencil.hs b/src/Data/Array/Accelerate/Representation/Stencil.hs index dd546721c..7761f4d41 100644 --- a/src/Data/Array/Accelerate/Representation/Stencil.hs +++ b/src/Data/Array/Accelerate/Representation/Stencil.hs @@ -1,5 +1,6 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Representation.Stencil @@ -25,6 +26,9 @@ module Data.Array.Accelerate.Representation.Stencil ( import Data.Array.Accelerate.Representation.Array import Data.Array.Accelerate.Representation.Shape import Data.Array.Accelerate.Representation.Type +import Data.Array.Accelerate.Representation.Elt +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Type import Language.Haskell.TH.Extra @@ -40,14 +44,14 @@ data StencilR sh e pat where StencilRtup3 :: StencilR sh e pat1 -> StencilR sh e pat2 -> StencilR sh e pat3 - -> StencilR (sh, Int) e (Tup3 pat1 pat2 pat3) + -> StencilR (sh, SingletonType Int) e (Tup3 pat1 pat2 pat3) StencilRtup5 :: StencilR sh e pat1 -> StencilR sh e pat2 -> StencilR sh e pat3 -> StencilR sh e pat4 -> StencilR sh e pat5 - -> StencilR (sh, Int) e (Tup5 pat1 pat2 pat3 pat4 pat5) + -> StencilR (sh, SingletonType Int) e (Tup5 pat1 pat2 pat3 pat4 pat5) StencilRtup7 :: StencilR sh e pat1 -> StencilR sh e pat2 @@ -56,7 +60,7 @@ data StencilR sh e pat where -> StencilR sh e pat5 -> StencilR sh e pat6 -> StencilR sh e pat7 - -> StencilR (sh, Int) e (Tup7 pat1 pat2 pat3 pat4 pat5 pat6 pat7) + -> StencilR (sh, SingletonType Int) e (Tup7 pat1 pat2 pat3 pat4 pat5 pat6 pat7) StencilRtup9 :: StencilR sh e pat1 -> StencilR sh e pat2 @@ -67,7 +71,7 @@ data StencilR sh e pat where -> StencilR sh e pat7 -> StencilR sh e pat8 -> StencilR sh e pat9 - -> StencilR (sh, Int) e (Tup9 pat1 pat2 pat3 pat4 pat5 pat6 pat7 pat8 pat9) + -> StencilR (sh, SingletonType Int) e (Tup9 pat1 pat2 pat3 pat4 pat5 pat6 pat7 pat8 pat9) stencilEltR :: StencilR sh e pat -> TypeR e stencilEltR (StencilRunit3 t) = t @@ -111,19 +115,19 @@ stencilHalo = go' go' StencilRunit7{} = (dim1, ((), 3)) go' StencilRunit9{} = (dim1, ((), 4)) -- - go' (StencilRtup3 a b c ) = (ShapeRsnoc shR, cons shR 1 $ foldl1 (union shR) [a', go b, go c]) + go' (StencilRtup3 a b c ) = (ShapeRsnoc shR, cons shR (fromElt @Int 1) $ foldl1 (union shR) [a', go b, go c]) where (shR, a') = go' a - go' (StencilRtup5 a b c d e ) = (ShapeRsnoc shR, cons shR 2 $ foldl1 (union shR) [a', go b, go c, go d, go e]) + go' (StencilRtup5 a b c d e ) = (ShapeRsnoc shR, cons shR (fromElt @Int 2) $ foldl1 (union shR) [a', go b, go c, go d, go e]) where (shR, a') = go' a - go' (StencilRtup7 a b c d e f g ) = (ShapeRsnoc shR, cons shR 3 $ foldl1 (union shR) [a', go b, go c, go d, go e, go f, go g]) + go' (StencilRtup7 a b c d e f g ) = (ShapeRsnoc shR, cons shR (fromElt @Int 3) $ foldl1 (union shR) [a', go b, go c, go d, go e, go f, go g]) where (shR, a') = go' a - go' (StencilRtup9 a b c d e f g h i) = (ShapeRsnoc shR, cons shR 4 $ foldl1 (union shR) [a', go b, go c, go d, go e, go f, go g, go h, go i]) + go' (StencilRtup9 a b c d e f g h i) = (ShapeRsnoc shR, cons shR (fromElt @Int 4) $ foldl1 (union shR) [a', go b, go c, go d, go e, go f, go g, go h, go i]) where (shR, a') = go' a go :: StencilR sh e stencil -> sh go = snd . go' - cons :: ShapeR sh -> Int -> sh -> (sh, Int) + cons :: ShapeR sh -> SingletonType Int -> sh -> (sh, SingletonType Int) cons ShapeRz ix () = ((), ix) cons (ShapeRsnoc shr) ix (sh, sz) = (cons shr ix sh, sz) diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 8fa577f41..a4be7a5ae 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -1156,21 +1156,13 @@ mkMin = mkPrimBinary $ PrimMin singleType -- Logical operators mkLAnd :: Exp Bool -> Exp Bool -> Exp Bool -mkLAnd (Exp a) (Exp b) = mkExp $ SmartExp (PrimApp PrimLAnd (SmartExp $ Pair x y)) `Pair` SmartExp Nil - where - x = SmartExp $ Prj PairIdxLeft a - y = SmartExp $ Prj PairIdxLeft b +mkLAnd (Exp a) (Exp b) = mkExp $ PrimApp PrimLAnd (SmartExp $ Pair a b) mkLOr :: Exp Bool -> Exp Bool -> Exp Bool -mkLOr (Exp a) (Exp b) = mkExp $ SmartExp (PrimApp PrimLOr (SmartExp $ Pair x y)) `Pair` SmartExp Nil - where - x = SmartExp $ Prj PairIdxLeft a - y = SmartExp $ Prj PairIdxLeft b +mkLOr (Exp a) (Exp b) = mkExp $ PrimApp PrimLOr (SmartExp $ Pair a b) mkLNot :: Exp Bool -> Exp Bool -mkLNot (Exp a) = mkExp $ SmartExp (PrimApp PrimLNot x) `Pair` SmartExp Nil - where - x = SmartExp $ Prj PairIdxLeft a +mkLNot (Exp a) = mkExp $ PrimApp PrimLNot a -- Numeric conversions @@ -1260,10 +1252,10 @@ mkPrimBinary :: (Elt a, Elt b, Elt c) => PrimFun ((EltR a, EltR b) -> EltR c) -> mkPrimBinary prim (Exp a) (Exp b) = mkExp $ PrimApp prim (SmartExp $ Pair a b) mkPrimUnaryBool :: Elt a => PrimFun (EltR a -> PrimBool) -> Exp a -> Exp Bool -mkPrimUnaryBool = mkCoerce @PrimBool $$ mkPrimUnary +mkPrimUnaryBool = mkCoerce @Bool $$ mkPrimUnary mkPrimBinaryBool :: (Elt a, Elt b) => PrimFun ((EltR a, EltR b) -> PrimBool) -> Exp a -> Exp b -> Exp Bool -mkPrimBinaryBool = mkCoerce @PrimBool $$$ mkPrimBinary +mkPrimBinaryBool = mkCoerce @Bool $$$ mkPrimBinary unPair :: SmartExp (a, b) -> (SmartExp a, SmartExp b) unPair e = (SmartExp $ Prj PairIdxLeft e, SmartExp $ Prj PairIdxRight e) diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index b29843417..265f35d3a 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -169,6 +169,7 @@ data IntegralType a where TypeWord32 :: IntegralType Word32 TypeWord64 :: IntegralType Word64 TypeSingletonType :: IntegralType (SingletonType a) + TypeTAG :: IntegralType (Finite n) type SingletonType x = (ScalarType (Int, ()), ()) @@ -195,6 +196,7 @@ data BoundedType a where -- data ScalarType a where SumScalarType :: Sum a -> ScalarType (FlattenSum a) + TagScalarType :: Finite n -> ScalarType (Finite n) SingleScalarType :: SingleType a -> ScalarType a VectorScalarType :: VectorType (Vec n a) -> ScalarType (Vec n a) From bce517385def4d43d4f7736d45790c455aa27f6b Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Tue, 8 Mar 2022 09:56:41 +0100 Subject: [PATCH 18/67] AST understands POS --- src/Data/Array/Accelerate/Smart.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index a4be7a5ae..9223c534c 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -1252,10 +1252,10 @@ mkPrimBinary :: (Elt a, Elt b, Elt c) => PrimFun ((EltR a, EltR b) -> EltR c) -> mkPrimBinary prim (Exp a) (Exp b) = mkExp $ PrimApp prim (SmartExp $ Pair a b) mkPrimUnaryBool :: Elt a => PrimFun (EltR a -> PrimBool) -> Exp a -> Exp Bool -mkPrimUnaryBool = mkCoerce @Bool $$ mkPrimUnary +mkPrimUnaryBool = mkPrimUnary mkPrimBinaryBool :: (Elt a, Elt b) => PrimFun ((EltR a, EltR b) -> PrimBool) -> Exp a -> Exp b -> Exp Bool -mkPrimBinaryBool = mkCoerce @Bool $$$ mkPrimBinary +mkPrimBinaryBool = mkPrimBinary unPair :: SmartExp (a, b) -> (SmartExp a, SmartExp b) unPair e = (SmartExp $ Prj PairIdxLeft e, SmartExp $ Prj PairIdxRight e) From 238551f215f98233814780931aa3f9f84f72ffb0 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Tue, 8 Mar 2022 10:09:27 +0100 Subject: [PATCH 19/67] stencil --- src/Data/Array/Accelerate/Smart.hs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 9223c534c..e8e547b56 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -627,7 +627,7 @@ class Stencil sh e stencil where -- DIM1 instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e) where type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e) - = EltR (e, e, e) + = ((((), EltR e), EltR e), EltR e) stencilR = StencilRunit3 @(EltR e) $ eltR @e stencilPrj s = (Exp $ prj2 s, Exp $ prj1 s, @@ -635,7 +635,7 @@ instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e) where instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e) where type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e) - = EltR (e, e, e, e, e) + = ((((((), EltR e), EltR e), EltR e), EltR e), EltR e) stencilR = StencilRunit5 $ eltR @e stencilPrj s = (Exp $ prj4 s, Exp $ prj3 s, @@ -645,7 +645,7 @@ instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e) where instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) where type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) - = EltR (e, e, e, e, e, e, e) + = ((((((((), EltR e), EltR e), EltR e), EltR e), EltR e), EltR e), EltR e) stencilR = StencilRunit7 $ eltR @e stencilPrj s = (Exp $ prj6 s, Exp $ prj5 s, @@ -658,7 +658,7 @@ instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) where type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) - = EltR (e, e, e, e, e, e, e, e, e) + = ((((((((((), EltR e), EltR e), EltR e), EltR e), EltR e), EltR e), EltR e), EltR e), EltR e) stencilR = StencilRunit9 $ eltR @e stencilPrj s = (Exp $ prj8 s, Exp $ prj7 s, From 05b4adee7f9bc29a6bd438b4f81195f9474f6ec6 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Tue, 8 Mar 2022 14:55:34 +0100 Subject: [PATCH 20/67] Make Singletontypes behave as original --- src/Data/Array/Accelerate/Type.hs | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index 265f35d3a..3f3d956a6 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -95,10 +95,11 @@ import Unsafe.Coerce type family POStoEltR (cs :: Nat) fs :: Type where - POStoEltR 1 x = FlattenProduct x - POStoEltR n x = (Finite n, FlattenProduct x) + POStoEltR 1 '[ '[x]] = x -- singletontypes + POStoEltR 1 x = FlattenProduct x -- tagless types (could / should be represented without Sums in the Product) + POStoEltR n x = (Finite n, FlattenProduct x) -- all other types -type family FlattenProduct (xss :: f (g a)) = (r :: Type) | r -> f where +type family FlattenProduct (xss :: f (g a)) = (r :: Type) where FlattenProduct '[] = () FlattenProduct (x ': xs) = (ScalarType (FlattenSum x), FlattenProduct xs) @@ -117,8 +118,10 @@ flattenProduct (Cons x xs) = (SumScalarType x, flattenProduct xs) mkEltR :: forall a . (POSable a) => a -> POStoEltR (Choices a) (Fields a) mkEltR x = case natVal cs of -- This distinction is hard to express in a type-correct way, - -- hence the unsafeCoerce - 1 -> unsafeCoerce fs + -- hence the unsafeCoerce's + 1 -> case emptyFields @a of + PTCons (STSucc _ STZero) PTNil | Cons (Pick f) Nil <- fields x -> unsafeCoerce f + _ -> unsafeCoerce fs _ -> unsafeCoerce (cs, fs) where cs = choices x @@ -126,11 +129,11 @@ mkEltR x = case natVal cs of fromEltR :: forall a . (POSable a) => POStoEltR (Choices a) (Fields a) -> a -fromEltR x = fromPOSable cs fs - where - (cs, fs) = case natVal (emptyChoices @a) of - 1 -> (0, unsafeCoerce x) - _ -> unsafeCoerce x +fromEltR x = case natVal (emptyChoices @a) of + 1 -> case emptyFields @a of + PTCons (STSucc _ STZero) PTNil -> unsafeCoerce x + _ -> fromPOSable 0 (unsafeCoerce x) + _ -> uncurry fromPOSable (unsafeCoerce x) -- Scalar types -- ------------ @@ -180,6 +183,7 @@ data FloatingType a where TypeHalf :: FloatingType Half TypeFloat :: FloatingType Float TypeDouble :: FloatingType Double + TypeFloatingSingletonType :: FloatingType (SingletonType a) -- | Numeric element types implement Num & Real -- From f714a7d8c7fe8befd756aba4a50e7f0ef6c56b55 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Tue, 8 Mar 2022 15:14:23 +0100 Subject: [PATCH 21/67] revert shape, slice singletons --- .../Array/Accelerate/Representation/Shape.hs | 57 ++++++++----------- .../Array/Accelerate/Representation/Slice.hs | 18 +++--- 2 files changed, 33 insertions(+), 42 deletions(-) diff --git a/src/Data/Array/Accelerate/Representation/Shape.hs b/src/Data/Array/Accelerate/Representation/Shape.hs index 76a30fc0e..f3d0cad5c 100644 --- a/src/Data/Array/Accelerate/Representation/Shape.hs +++ b/src/Data/Array/Accelerate/Representation/Shape.hs @@ -27,7 +27,6 @@ module Data.Array.Accelerate.Representation.Shape import Data.Array.Accelerate.Error import Data.Array.Accelerate.Type import Data.Array.Accelerate.Representation.Type -import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Representation.POS import Data.Type.POSable.Representation @@ -41,7 +40,7 @@ import GHC.Base ( quotInt, r -- data ShapeR sh where ShapeRz :: ShapeR () - ShapeRsnoc :: ShapeR sh -> ShapeR (sh, SingletonType Int) + ShapeRsnoc :: ShapeR sh -> ShapeR (sh, Int) -- | Nicely format a shape as a string -- @@ -51,9 +50,9 @@ showShape shr = foldr (\sh str -> str ++ " :. " ++ show sh) "Z" . shapeToList sh -- Synonyms for common shape types -- type DIM0 = () -type DIM1 = ((), SingletonType Int) -type DIM2 = (((), SingletonType Int), SingletonType Int) -type DIM3 = ((((), SingletonType Int), SingletonType Int), SingletonType Int) +type DIM1 = ((), Int) +type DIM2 = (((), Int), Int) +type DIM3 = ((((), Int), Int), Int) dim0 :: ShapeR DIM0 dim0 = ShapeRz @@ -75,7 +74,7 @@ rank (ShapeRsnoc shr) = rank shr + 1 -- | Total number of elements in an array of the given shape -- -size :: ShapeR sh -> sh -> SingletonType Int +size :: ShapeR sh -> sh -> Int size ShapeRz () = 1 size (ShapeRsnoc shr) (sh, sz) | sz <= 0 = 0 @@ -97,7 +96,7 @@ intersect = zip min union :: ShapeR sh -> sh -> sh -> sh union = zip max -zip :: (SingletonType Int -> SingletonType Int -> SingletonType Int) -> ShapeR sh -> sh -> sh -> sh +zip :: (Int -> Int -> Int) -> ShapeR sh -> sh -> sh -> sh zip _ ShapeRz () () = () zip f (ShapeRsnoc shr) (as, a) (bs, b) = (zip f shr as bs, f a b) @@ -113,28 +112,22 @@ eq (ShapeRsnoc shr) (sh, i) (sh', i') = i == i' && eq shr sh sh' toIndex :: HasCallStack => ShapeR sh -> sh -> sh -> Int toIndex ShapeRz () () = 0 toIndex (ShapeRsnoc shr) (sh, sz) (ix, i) - = indexCheck (toElt i) (toElt sz) - $ toIndex shr sh ix * toElt sz + toElt i + = indexCheck i sz + $ toIndex shr sh ix * sz + i -- | Inverse of 'toIndex' -- -fromIndex :: HasCallStack => ShapeR sh -> sh -> SingletonType Int -> sh +fromIndex :: HasCallStack => ShapeR sh -> sh -> Int -> sh fromIndex ShapeRz () _ = () fromIndex (ShapeRsnoc shr) (sh, sz) i - = (fromIndex shr sh (i `liftQuotInt` sz), r) + = (fromIndex shr sh (i `quotInt` sz), r) -- If we assume that the index is in range, there is no point in computing -- the remainder for the highest dimension since i < sz must hold. - + -- where r = case shr of -- Check if rank of shr is 0 - ShapeRz -> indexCheck (toElt i) (toElt sz) i - _ -> i `liftRemInt` sz - -liftQuotInt :: SingletonType Int -> SingletonType Int -> SingletonType Int -liftQuotInt = liftSingNumBinary quotInt - -liftRemInt :: SingletonType Int -> SingletonType Int -> SingletonType Int -liftRemInt = liftSingNumBinary remInt + ShapeRz -> indexCheck i sz i + _ -> i `remInt` sz -- | Iterate through the entire shape, applying the function in the second -- argument; third argument combines results and fourth is an initial value @@ -174,13 +167,13 @@ shapeToRange (ShapeRsnoc shr) (sh, sz) = let (low, high) = shapeToRange shr sh i -- | Convert a shape or index into its list of dimensions -- -shapeToList :: ShapeR sh -> sh -> [SingletonType Int] +shapeToList :: ShapeR sh -> sh -> [Int] shapeToList ShapeRz () = [] shapeToList (ShapeRsnoc shr) (sh,sz) = sz : shapeToList shr sh -- | Convert a list of dimensions into a shape -- -listToShape :: HasCallStack => ShapeR sh -> [SingletonType Int] -> sh +listToShape :: HasCallStack => ShapeR sh -> [Int] -> sh listToShape shr ds = case listToShape' shr ds of Just sh -> sh @@ -188,7 +181,7 @@ listToShape shr ds = -- | Attempt to convert a list of dimensions into a shape -- -listToShape' :: ShapeR sh -> [SingletonType Int] -> Maybe sh +listToShape' :: ShapeR sh -> [Int] -> Maybe sh listToShape' ShapeRz [] = Just () listToShape' (ShapeRsnoc shr) (x:xs) = (, x) <$> listToShape' shr xs listToShape' _ _ = Nothing @@ -198,7 +191,7 @@ shapeType ShapeRz = TupRunit shapeType (ShapeRsnoc shr) = shapeType shr `TupRpair` - TupRsingle (SingleScalarType (NumSingleType (IntegralNumType (TypeSingletonType @Int)))) + TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeInt))) rnfShape :: ShapeR sh -> sh -> () rnfShape ShapeRz () = () @@ -228,16 +221,16 @@ instance POSable (ShapeR ()) where emptyFields = PTNil --- instance (POSable (ShapeR sh)) => POSable (ShapeR (sh, Int)) where --- type Choices (ShapeR (sh, Int)) = 1 --- choices _ = 0 +instance (POSable (ShapeR sh)) => POSable (ShapeR (sh, Int)) where + type Choices (ShapeR (sh, Int)) = 1 + choices _ = 0 --- emptyChoices = 0 + emptyChoices = 0 --- fromPOSable 0 (Cons _ xs) = ShapeRsnoc (fromPOSable 0 xs) + fromPOSable 0 (Cons _ xs) = ShapeRsnoc (fromPOSable 0 xs) --- type Fields (ShapeR (sh, Int)) = '[] ': Fields (ShapeR sh) + type Fields (ShapeR (sh, Int)) = '[] ': Fields (ShapeR sh) --- fields (ShapeRsnoc sh) = Cons Undef (fields sh) + fields (ShapeRsnoc sh) = Cons Undef (fields sh) --- emptyFields = PTCons STZero (emptyFields @(ShapeR sh)) + emptyFields = PTCons STZero (emptyFields @(ShapeR sh)) diff --git a/src/Data/Array/Accelerate/Representation/Slice.hs b/src/Data/Array/Accelerate/Representation/Slice.hs index 5cd5905f7..dee059a37 100644 --- a/src/Data/Array/Accelerate/Representation/Slice.hs +++ b/src/Data/Array/Accelerate/Representation/Slice.hs @@ -19,8 +19,6 @@ module Data.Array.Accelerate.Representation.Slice where import Data.Array.Accelerate.Representation.Shape -import Data.Array.Accelerate.Sugar.Elt -import Data.Array.Accelerate.Type import Language.Haskell.TH.Extra @@ -40,15 +38,15 @@ instance Slice () where sliceIndex = SliceNil instance Slice sl => Slice (sl, ()) where - type SliceShape (sl, ()) = (SliceShape sl, SingletonType Int) + type SliceShape (sl, ()) = (SliceShape sl, Int) type CoSliceShape (sl, ()) = CoSliceShape sl - type FullShape (sl, ()) = (FullShape sl, SingletonType Int) + type FullShape (sl, ()) = (FullShape sl, Int) sliceIndex = SliceAll (sliceIndex @sl) -instance Slice sl => Slice (sl, SingletonType Int) where - type SliceShape (sl, SingletonType Int) = SliceShape sl - type CoSliceShape (sl, SingletonType Int) = (CoSliceShape sl, SingletonType Int) - type FullShape (sl, SingletonType Int) = (FullShape sl, SingletonType Int) +instance Slice sl => Slice (sl, Int) where + type SliceShape (sl, Int) = SliceShape sl + type CoSliceShape (sl, Int) = (CoSliceShape sl, Int) + type FullShape (sl, Int) = (FullShape sl, Int) sliceIndex = SliceFixed (sliceIndex @sl) -- |Generalised array index, which may index only in a subset of the dimensions @@ -56,8 +54,8 @@ instance Slice sl => Slice (sl, SingletonType Int) where -- data SliceIndex ix slice coSlice sliceDim where SliceNil :: SliceIndex () () () () - SliceAll :: SliceIndex ix slice co dim -> SliceIndex (ix, ()) (slice, SingletonType Int) co (dim, SingletonType Int) - SliceFixed :: SliceIndex ix slice co dim -> SliceIndex (ix, SingletonType Int) slice (co, SingletonType Int) (dim, SingletonType Int) + SliceAll :: SliceIndex ix slice co dim -> SliceIndex (ix, ()) (slice, Int) co (dim, Int) + SliceFixed :: SliceIndex ix slice co dim -> SliceIndex (ix, Int) slice (co, Int) (dim, Int) instance Show (SliceIndex ix slice coSlice sliceDim) where show SliceNil = "SliceNil" From ae5f19fc9ddbd815d863f4fdace71c8f8790f8a5 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Tue, 8 Mar 2022 15:15:35 +0100 Subject: [PATCH 22/67] revert sugar shape singleton --- src/Data/Array/Accelerate/Sugar/Shape.hs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Data/Array/Accelerate/Sugar/Shape.hs b/src/Data/Array/Accelerate/Sugar/Shape.hs index 16b81971b..02cd46b49 100644 --- a/src/Data/Array/Accelerate/Sugar/Shape.hs +++ b/src/Data/Array/Accelerate/Sugar/Shape.hs @@ -141,7 +141,7 @@ rank = R.rank (shapeR @sh) -- | Total number of elements in an array of the given /shape/ -- -size :: forall sh. Shape sh => sh -> SingletonType Int +size :: forall sh. Shape sh => sh -> Int size = R.size (shapeR @sh) . fromElt -- | The empty /shape/ @@ -172,7 +172,7 @@ toIndex sh ix = R.toIndex (shapeR @sh) (fromElt sh) (fromElt ix) -- fromIndex :: forall sh. Shape sh => sh -- ^ Total shape (extent) of the array - -> SingletonType Int -- ^ The argument index + -> Int -- ^ The argument index -> sh -- ^ Corresponding multi-dimensional index fromIndex sh = toElt . R.fromIndex (shapeR @sh) (fromElt sh) @@ -210,19 +210,19 @@ shapeToRange ix = -- | Convert a shape to a list of dimensions -- -shapeToList :: forall sh. Shape sh => sh -> [SingletonType Int] +shapeToList :: forall sh. Shape sh => sh -> [Int] shapeToList = R.shapeToList (shapeR @sh) . fromElt -- | Convert a list of dimensions into a shape. If the list does not -- contain exactly the number of elements as specified by the type of the -- shape: error. -- -listToShape :: forall sh. Shape sh => [SingletonType Int] -> sh +listToShape :: forall sh. Shape sh => [Int] -> sh listToShape = toElt . R.listToShape (shapeR @sh) -- | Attempt to convert a list of dimensions into a shape -- -listToShape' :: forall sh. Shape sh => [SingletonType Int] -> Maybe sh +listToShape' :: forall sh. Shape sh => [Int] -> Maybe sh listToShape' = fmap toElt . R.listToShape' (shapeR @sh) -- | Nicely format a shape as a string @@ -323,7 +323,7 @@ instance Shape sh => Elt (Any (sh :. Int)) where instance Shape Z where shapeR = R.ShapeRz - -- sliceAnyIndex = R.SliceNil + sliceAnyIndex = R.SliceNil sliceNoneIndex = R.SliceNil -- Note that the constraint 'i ~ Int' allows the compiler to infer that From fd1398f4a018cb76409666cdc27573b5e94a3424 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Tue, 8 Mar 2022 15:17:01 +0100 Subject: [PATCH 23/67] revert stencil singletontype --- .../Accelerate/Representation/Stencil.hs | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/src/Data/Array/Accelerate/Representation/Stencil.hs b/src/Data/Array/Accelerate/Representation/Stencil.hs index 7761f4d41..dd546721c 100644 --- a/src/Data/Array/Accelerate/Representation/Stencil.hs +++ b/src/Data/Array/Accelerate/Representation/Stencil.hs @@ -1,6 +1,5 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TypeApplications #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Representation.Stencil @@ -26,9 +25,6 @@ module Data.Array.Accelerate.Representation.Stencil ( import Data.Array.Accelerate.Representation.Array import Data.Array.Accelerate.Representation.Shape import Data.Array.Accelerate.Representation.Type -import Data.Array.Accelerate.Representation.Elt -import Data.Array.Accelerate.Sugar.Elt -import Data.Array.Accelerate.Type import Language.Haskell.TH.Extra @@ -44,14 +40,14 @@ data StencilR sh e pat where StencilRtup3 :: StencilR sh e pat1 -> StencilR sh e pat2 -> StencilR sh e pat3 - -> StencilR (sh, SingletonType Int) e (Tup3 pat1 pat2 pat3) + -> StencilR (sh, Int) e (Tup3 pat1 pat2 pat3) StencilRtup5 :: StencilR sh e pat1 -> StencilR sh e pat2 -> StencilR sh e pat3 -> StencilR sh e pat4 -> StencilR sh e pat5 - -> StencilR (sh, SingletonType Int) e (Tup5 pat1 pat2 pat3 pat4 pat5) + -> StencilR (sh, Int) e (Tup5 pat1 pat2 pat3 pat4 pat5) StencilRtup7 :: StencilR sh e pat1 -> StencilR sh e pat2 @@ -60,7 +56,7 @@ data StencilR sh e pat where -> StencilR sh e pat5 -> StencilR sh e pat6 -> StencilR sh e pat7 - -> StencilR (sh, SingletonType Int) e (Tup7 pat1 pat2 pat3 pat4 pat5 pat6 pat7) + -> StencilR (sh, Int) e (Tup7 pat1 pat2 pat3 pat4 pat5 pat6 pat7) StencilRtup9 :: StencilR sh e pat1 -> StencilR sh e pat2 @@ -71,7 +67,7 @@ data StencilR sh e pat where -> StencilR sh e pat7 -> StencilR sh e pat8 -> StencilR sh e pat9 - -> StencilR (sh, SingletonType Int) e (Tup9 pat1 pat2 pat3 pat4 pat5 pat6 pat7 pat8 pat9) + -> StencilR (sh, Int) e (Tup9 pat1 pat2 pat3 pat4 pat5 pat6 pat7 pat8 pat9) stencilEltR :: StencilR sh e pat -> TypeR e stencilEltR (StencilRunit3 t) = t @@ -115,19 +111,19 @@ stencilHalo = go' go' StencilRunit7{} = (dim1, ((), 3)) go' StencilRunit9{} = (dim1, ((), 4)) -- - go' (StencilRtup3 a b c ) = (ShapeRsnoc shR, cons shR (fromElt @Int 1) $ foldl1 (union shR) [a', go b, go c]) + go' (StencilRtup3 a b c ) = (ShapeRsnoc shR, cons shR 1 $ foldl1 (union shR) [a', go b, go c]) where (shR, a') = go' a - go' (StencilRtup5 a b c d e ) = (ShapeRsnoc shR, cons shR (fromElt @Int 2) $ foldl1 (union shR) [a', go b, go c, go d, go e]) + go' (StencilRtup5 a b c d e ) = (ShapeRsnoc shR, cons shR 2 $ foldl1 (union shR) [a', go b, go c, go d, go e]) where (shR, a') = go' a - go' (StencilRtup7 a b c d e f g ) = (ShapeRsnoc shR, cons shR (fromElt @Int 3) $ foldl1 (union shR) [a', go b, go c, go d, go e, go f, go g]) + go' (StencilRtup7 a b c d e f g ) = (ShapeRsnoc shR, cons shR 3 $ foldl1 (union shR) [a', go b, go c, go d, go e, go f, go g]) where (shR, a') = go' a - go' (StencilRtup9 a b c d e f g h i) = (ShapeRsnoc shR, cons shR (fromElt @Int 4) $ foldl1 (union shR) [a', go b, go c, go d, go e, go f, go g, go h, go i]) + go' (StencilRtup9 a b c d e f g h i) = (ShapeRsnoc shR, cons shR 4 $ foldl1 (union shR) [a', go b, go c, go d, go e, go f, go g, go h, go i]) where (shR, a') = go' a go :: StencilR sh e stencil -> sh go = snd . go' - cons :: ShapeR sh -> SingletonType Int -> sh -> (sh, SingletonType Int) + cons :: ShapeR sh -> Int -> sh -> (sh, Int) cons ShapeRz ix () = ((), ix) cons (ShapeRsnoc shr) ix (sh, sz) = (cons shr ix sh, sz) From 996ad1bab211fabce288270bdc932dea9f5b4bca Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Tue, 8 Mar 2022 15:26:00 +0100 Subject: [PATCH 24/67] revert singletontype completely --- src/Data/Array/Accelerate/AST.hs | 18 ++++----- .../Array/Accelerate/Representation/POS.hs | 2 +- src/Data/Array/Accelerate/Sugar/Elt.hs | 35 +++-------------- src/Data/Array/Accelerate/Sugar/Vec.hs | 38 ++++++++++++++----- src/Data/Array/Accelerate/Type.hs | 5 --- 5 files changed, 45 insertions(+), 53 deletions(-) diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 4a5a9477e..84e3529c1 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -683,13 +683,13 @@ data PrimFun sig where PrimBOr :: IntegralType a -> PrimFun ((a, a) -> a) PrimBXor :: IntegralType a -> PrimFun ((a, a) -> a) PrimBNot :: IntegralType a -> PrimFun (a -> a) - PrimBShiftL :: IntegralType a -> PrimFun ((a, SingletonType Int) -> a) - PrimBShiftR :: IntegralType a -> PrimFun ((a, SingletonType Int) -> a) - PrimBRotateL :: IntegralType a -> PrimFun ((a, SingletonType Int) -> a) - PrimBRotateR :: IntegralType a -> PrimFun ((a, SingletonType Int) -> a) - PrimPopCount :: IntegralType a -> PrimFun (a -> SingletonType Int) - PrimCountLeadingZeros :: IntegralType a -> PrimFun (a -> SingletonType Int) - PrimCountTrailingZeros :: IntegralType a -> PrimFun (a -> SingletonType Int) + PrimBShiftL :: IntegralType a -> PrimFun ((a, Int) -> a) + PrimBShiftR :: IntegralType a -> PrimFun ((a, Int) -> a) + PrimBRotateL :: IntegralType a -> PrimFun ((a, Int) -> a) + PrimBRotateR :: IntegralType a -> PrimFun ((a, Int) -> a) + PrimPopCount :: IntegralType a -> PrimFun (a -> Int) + PrimCountLeadingZeros :: IntegralType a -> PrimFun (a -> Int) + PrimCountTrailingZeros :: IntegralType a -> PrimFun (a -> Int) -- operators from Fractional and Floating PrimFDiv :: FloatingType a -> PrimFun ((a, a) -> a) @@ -945,8 +945,8 @@ primFunType = \case tbool :: TypeR PrimBool tbool = TupRpair (TupRsingle (TagScalarType @2 0)) TupRunit - tint :: TypeR (SingletonType Int) - tint = TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeSingletonType))) + tint :: TypeR Int + tint = TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeInt))) -- Normal form data diff --git a/src/Data/Array/Accelerate/Representation/POS.hs b/src/Data/Array/Accelerate/Representation/POS.hs index d0d7c30ad..3cc00610f 100644 --- a/src/Data/Array/Accelerate/Representation/POS.hs +++ b/src/Data/Array/Accelerate/Representation/POS.hs @@ -29,7 +29,7 @@ module Data.Array.Accelerate.Representation.POS ( POSable(..), POS, POST, mkPOS, mkPOST, fromPOS, Product(..), Sum(..), - GroundType, Finite, ProductType(..), SumType(..), POSable.Generic) + GroundType, Finite, ProductType(..), SumType(..), POSable.Generic, type (++)) where -- import Data.Array.Accelerate.Type diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index 5ab86fdbd..f007a30e8 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -22,7 +22,7 @@ -- Portability : non-portable (GHC extensions) -- -module Data.Array.Accelerate.Sugar.Elt ( Elt(..), SingletonType, liftSingNumBinary, liftSingNumUnary ) +module Data.Array.Accelerate.Sugar.Elt ( Elt(..) ) where import Data.Array.Accelerate.Representation.Elt @@ -241,31 +241,8 @@ runQ $ do return (concat ss) - - --- These functions and the Num, Enum instance make sure we can use the range --- syntax and other Num stuff. We might have to provide these instances for all --- SingletonTypes maybe? -liftSingNumBinary :: (Elt a, EltR a ~ SingletonType a) => (a -> a -> a) -> SingletonType a -> SingletonType a -> SingletonType a -liftSingNumBinary f x y = fromElt $ f (toElt x) (toElt y) - -liftSingNumUnary :: (Elt a, EltR a ~ SingletonType a) => (a -> a) -> SingletonType a -> SingletonType a -liftSingNumUnary f x = fromElt $ f (toElt x) - -instance Num (SingletonType Int) where - (+) = liftSingNumBinary @Int (+) - (*) = liftSingNumBinary @Int (*) - (-) = liftSingNumBinary @Int (-) - abs = liftSingNumUnary @Int abs - signum = liftSingNumUnary @Int abs - fromInteger = fromElt . fromInteger @Int - -instance Enum (SingletonType Int) where - toEnum = fromElt - fromEnum = toElt - -instance {-# OVERLAPPING #-} Eq (SingletonType Int) where - (==) x y = toElt @Int x == toElt @Int y - -instance {-# OVERLAPPING #-} Ord (SingletonType Int) where - compare x y = compare (toElt @Int x) (toElt @Int y) +-- TODO: bring this back into TH +instance (POSable a, POSable b) => Elt (a, b) +instance (POSable a, POSable b, POSable c) => Elt (a, b, c) +instance (POSable a, POSable b, POSable c, POSable d) => Elt (a, b, c, d) +instance (POSable a, POSable b, POSable c, POSable d, POSable e) => Elt (a, b, c, d, e) diff --git a/src/Data/Array/Accelerate/Sugar/Vec.hs b/src/Data/Array/Accelerate/Sugar/Vec.hs index 7184bdbd8..dcdaa87e4 100644 --- a/src/Data/Array/Accelerate/Sugar/Vec.hs +++ b/src/Data/Array/Accelerate/Sugar/Vec.hs @@ -8,6 +8,7 @@ {-# LANGUAGE TypeFamilyDependencies #-} {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE FlexibleInstances #-} {-# OPTIONS_HADDOCK hide #-} {-# OPTIONS_GHC -fno-warn-orphans #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} @@ -37,24 +38,43 @@ import GHC.TypeLits import GHC.Prim -type VecElt a = (Elt a, Prim a, IsSingle a) +type VecElt a = (Elt a, Prim a, IsSingle a, GroundType a, Num a) -instance GroundType (Vec n a) +instance VecElt a => POSable (Vec2 a) where + type Choices (Vec2 a) = 1 -instance (KnownNat n, VecElt a, Num a) => POSable (Vec n a) where - type Choices (Vec n a) = 1 + choices _ = 0 + + emptyChoices = 0 + + fromPOSable 0 (Cons (Pick a) (Cons (Pick b) Nil)) = Vec2 a b + + type Fields (Vec2 a) = '[ '[a], '[a]] + fields (Vec2 a b) = Cons (Pick a) (Cons (Pick b) Nil) + + emptyFields = PTCons (STSucc 0 STZero) (PTCons (STSucc 0 STZero) PTNil) + + +-- Elt instance automatically derived from POSable instance +instance VecElt a => Elt (Vec2 a) + + +instance VecElt a => POSable (Vec4 a) where + type Choices (Vec4 a) = 1 choices _ = 0 emptyChoices = 0 - fromPOSable 0 (Cons (Pick x) Nil) = x + fromPOSable 0 ( Cons (Pick a) (Cons (Pick b) (Cons (Pick c) (Cons (Pick d) Nil)))) = Vec4 a b c d - type Fields (Vec n a) = '[ '[Vec n a]] - fields x = Cons (Pick x) Nil + type Fields (Vec4 a) = '[ '[a], '[a], '[a], '[a]] + fields (Vec4 a b c d) = Cons (Pick a) (Cons (Pick b) (Cons (Pick c) (Cons (Pick d) Nil))) - emptyFields = PTCons (STSucc (replicateVecN 0) STZero) PTNil + emptyFields = PTCons (STSucc 0 STZero) (PTCons (STSucc 0 STZero) (PTCons (STSucc 0 STZero) (PTCons (STSucc 0 STZero) PTNil))) -- Elt instance automatically derived from POSable instance -instance (KnownNat n, VecElt a, Num a) => (Elt (Vec n a)) +instance VecElt a => Elt (Vec4 a) + +-- TODO: instances for 8 and 16, probably with some TH diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index 3f3d956a6..68c5791a0 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -171,19 +171,14 @@ data IntegralType a where TypeWord16 :: IntegralType Word16 TypeWord32 :: IntegralType Word32 TypeWord64 :: IntegralType Word64 - TypeSingletonType :: IntegralType (SingletonType a) TypeTAG :: IntegralType (Finite n) - -type SingletonType x = (ScalarType (Int, ()), ()) - -- | Floating-point types supported in array computations. -- data FloatingType a where TypeHalf :: FloatingType Half TypeFloat :: FloatingType Float TypeDouble :: FloatingType Double - TypeFloatingSingletonType :: FloatingType (SingletonType a) -- | Numeric element types implement Num & Real -- From e1e00f79bfcb51d2db2e2701d0936c31d2bf9d3f Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 10 Mar 2022 16:51:52 +0100 Subject: [PATCH 25/67] create groundtypes with POSable instance via TH --- src/Data/Array/Accelerate/Pattern/TH.hs | 2 +- .../Array/Accelerate/Representation/POS.hs | 92 +----------- src/Data/Array/Accelerate/Sugar/POS.hs | 135 ++++++++++++++++++ src/Data/Array/Accelerate/Type.hs | 12 +- 4 files changed, 146 insertions(+), 95 deletions(-) create mode 100644 src/Data/Array/Accelerate/Sugar/POS.hs diff --git a/src/Data/Array/Accelerate/Pattern/TH.hs b/src/Data/Array/Accelerate/Pattern/TH.hs index 0323f8d1a..c9ec918ac 100644 --- a/src/Data/Array/Accelerate/Pattern/TH.hs +++ b/src/Data/Array/Accelerate/Pattern/TH.hs @@ -293,7 +293,7 @@ mkConS tn' tvs' prev' next' tag' con' = do ++ map varE xs ++ map (\t -> [| unExp $(varE 'undef `appTypeE` return t) |] ) (concat fs1) - tagged = [| Exp $ SmartExp $ Pair (SmartExp (Const (SingleScalarType (NumSingleType (IntegralNumType TypeWord8))) $(litE (IntegerL (toInteger tag))))) $vs |] + tagged = [| Exp $ SmartExp $ Pair (SmartExp (Const (SingleScalarType (NumSingleType (IntegralNumType TypeTAG))) $(litE (IntegerL (toInteger tag))))) $vs |] body = clause (map (\x -> [p| (Exp $(varP x)) |]) xs) (normalB tagged) [] r <- sequence [ sigD fun sig diff --git a/src/Data/Array/Accelerate/Representation/POS.hs b/src/Data/Array/Accelerate/Representation/POS.hs index 3cc00610f..b62d4521b 100644 --- a/src/Data/Array/Accelerate/Representation/POS.hs +++ b/src/Data/Array/Accelerate/Representation/POS.hs @@ -29,7 +29,8 @@ module Data.Array.Accelerate.Representation.POS ( POSable(..), POS, POST, mkPOS, mkPOST, fromPOS, Product(..), Sum(..), - GroundType, Finite, ProductType(..), SumType(..), POSable.Generic, type (++)) + GroundType(..), Finite, ProductType(..), SumType(..), POSable.Generic, type (++), + mkPOSableGroundType) where -- import Data.Array.Accelerate.Type @@ -45,6 +46,7 @@ import GHC.TypeLits import Data.Type.POSable.POSable as POSable import Data.Type.POSable.Representation import Data.Type.POSable.Instances +import Data.Type.POSable.TH import Data.Int import Data.Word @@ -89,94 +91,6 @@ type POST a = (Finite (Choices a), ProductType (Fields a)) mkPOST :: forall a . (POSable a) => POST a mkPOST = (0, emptyFields @a) -runQ $ do - let - -- XXX: we might want to do the digItOut trick used by FromIntegral? - -- - integralTypes :: [Name] - integralTypes = - [ ''Int - , ''Int8 - , ''Int16 - , ''Int32 - , ''Int64 - , ''Word - , ''Word8 - , ''Word16 - , ''Word32 - , ''Word64 - ] - - floatingTypes :: [Name] - floatingTypes = - [ ''Half - , ''Float - , ''Double - ] - - newtypes :: [Name] - newtypes = - [ ''CShort - , ''CUShort - , ''CInt - , ''CUInt - , ''CLong - , ''CULong - , ''CLLong - , ''CULLong - , ''CFloat - , ''CDouble - , ''CChar - , ''CSChar - , ''CUChar - ] - - mkSimple :: Name -> Q [Dec] - mkSimple name = - let t = conT name - in - [d| - instance GroundType $t - - instance POSable $t where - type Choices $t = 1 - choices _ = 0 - - type Fields $t = '[ '[$t]] - fields x = Cons (Pick x) Nil - - fromPOSable 0 (Cons (Pick x) Nil) = x - fromPOSable _ _ = error "index out of range" - - emptyFields = PTCons (STSucc 0 STZero) PTNil - |] - - mkTuple :: Int -> Q Dec - mkTuple n = - let - xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - ts = map varT xs - res = tupT ts - ctx = mapM (appT [t| POSable |]) ts - in - instanceD ctx [t| POSable $res |] [] - - mkNewtype :: Name -> Q [Dec] - mkNewtype name = do - r <- reify name - base <- case r of - TyConI (NewtypeD _ _ _ _ (NormalC _ [(_, ConT b)]) _) -> return b - _ -> error "unexpected case generating newtype Elt instance" - -- - [d| instance POSable $(conT name) - |] - -- - ss <- mapM mkSimple (integralTypes ++ floatingTypes) - ns <- mapM mkNewtype newtypes - -- ts <- mapM mkTuple [2..16] - -- vs <- sequence [ mkVecElt t n | t <- integralTypes ++ floatingTypes, n <- [2,3,4,8,16] ] - return (concat ss ++ concat ns) - type family Snoc2List x = xs | xs -> x where Snoc2List () = '[] diff --git a/src/Data/Array/Accelerate/Sugar/POS.hs b/src/Data/Array/Accelerate/Sugar/POS.hs new file mode 100644 index 000000000..665576941 --- /dev/null +++ b/src/Data/Array/Accelerate/Sugar/POS.hs @@ -0,0 +1,135 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE TypeFamilyDependencies #-} +{-# OPTIONS_HADDOCK hide #-} + +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +-- This is needed to derive POSable for tuples of size more then 4 +{-# OPTIONS_GHC -fconstraint-solver-iterations=16 #-} +-- | +-- Module : Data.Array.Accelerate.Representation.POS +-- Copyright : [2008..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Sugar.POS + where + +-- import Data.Array.Accelerate.Type + +import Data.Bits +import Data.Char +import Data.Kind +import Language.Haskell.TH.Extra hiding ( Type ) + +import GHC.Generics +import GHC.TypeLits + +import Data.Type.POSable.POSable as POSable +import Data.Type.POSable.Representation +import Data.Type.POSable.Instances +import Data.Type.POSable.TH + +import Data.Int +import Data.Word +import Numeric.Half +import Foreign.C.Types + +import Data.Array.Accelerate.Type + +runQ $ do + let + -- XXX: we might want to do the digItOut trick used by FromIntegral? + -- + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + ] + + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + ] + + newtypes :: [Name] + newtypes = + [ ''CShort + , ''CUShort + , ''CInt + , ''CUInt + , ''CLong + , ''CULong + , ''CLLong + , ''CULLong + , ''CFloat + , ''CDouble + , ''CChar + , ''CSChar + , ''CUChar + ] + + mkSimple :: Name -> Name -> Q [Dec] + mkSimple typ name = + let t = conT name + tt = conT typ + tr = pure $ ConE $ mkName ("Type" ++ nameBase name) + in + [d| + instance GroundType $t where + type TypeRep $t = $tt $t + + mkTypeRep = $tr + |] + + mkTuple :: Int -> Q Dec + mkTuple n = + let + xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + ts = map varT xs + res = tupT ts + ctx = mapM (appT [t| POSable |]) ts + in + instanceD ctx [t| POSable $res |] [] + + mkNewtype :: Name -> Q [Dec] + mkNewtype name = do + r <- reify name + base <- case r of + TyConI (NewtypeD _ _ _ _ (NormalC _ [(_, ConT b)]) _) -> return b + _ -> error "unexpected case generating newtype Elt instance" + -- + mkPOSableGroundType name + -- + si <- mapM (mkSimple ''IntegralType) integralTypes + sf <- mapM (mkSimple ''FloatingType) floatingTypes + ns <- mapM mkPOSableGroundType (floatingTypes ++ integralTypes) + -- ns <- mapM mkNewtype newtypes + -- ts <- mapM mkTuple [2..16] + -- vs <- sequence [ mkVecElt t n | t <- integralTypes ++ floatingTypes, n <- [2,3,4,8,16] ] + return (concat si ++ concat sf ++ concat ns) + \ No newline at end of file diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index 68c5791a0..becfaab77 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -101,11 +101,11 @@ type family POStoEltR (cs :: Nat) fs :: Type where type family FlattenProduct (xss :: f (g a)) = (r :: Type) where FlattenProduct '[] = () - FlattenProduct (x ': xs) = (ScalarType (FlattenSum x), FlattenProduct xs) + FlattenProduct (x ': xs) = (ScalarType (Sum x), FlattenProduct xs) -type family FlattenSum (xss :: f a) = (r :: Type) | r -> f where - FlattenSum '[] = () - FlattenSum (x ': xs) = (x, FlattenSum xs) +type family FlattenProductType (xss :: f (g a)) = (r :: Type) where + FlattenProductType '[] = () + FlattenProductType (x ': xs) = (SumType x, FlattenProductType xs) flattenProduct :: Product a -> FlattenProduct a flattenProduct Nil = () @@ -194,9 +194,11 @@ data BoundedType a where -- | All scalar element types implement Eq & Ord -- data ScalarType a where - SumScalarType :: Sum a -> ScalarType (FlattenSum a) + SumScalarType :: Sum a -> ScalarType (Sum a) + SumScalarTypeR :: SumType a -> ScalarType (SumType a) TagScalarType :: Finite n -> ScalarType (Finite n) SingleScalarType :: SingleType a -> ScalarType a + SingletonScalarType :: ScalarType a VectorScalarType :: VectorType (Vec n a) -> ScalarType (Vec n a) data SingleType a where From 520adc6fd3a7fff956e49d7184eb48fdfebb5af7 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Fri, 11 Mar 2022 15:10:20 +0100 Subject: [PATCH 26/67] default definition for eltR, including ugly hacks --- src/Data/Array/Accelerate/Sugar/Elt.hs | 42 +++++++++++++++++++++++--- src/Data/Array/Accelerate/Sugar/POS.hs | 25 ++++++--------- src/Data/Array/Accelerate/Sugar/Vec.hs | 14 ++------- 3 files changed, 49 insertions(+), 32 deletions(-) diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index f007a30e8..6a6d200df 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -29,6 +29,7 @@ import Data.Array.Accelerate.Representation.Elt import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Representation.POS +import Data.Array.Accelerate.Sugar.POS import Data.Array.Accelerate.Type import Data.Bits @@ -38,7 +39,9 @@ import Data.Word import Language.Haskell.TH.Extra hiding ( Type ) import GHC.Generics - +import GHC.TypeLits +import Unsafe.Coerce +import Data.Type.Equality -- | The 'Elt' class characterises the allowable array element types, and -- hence the types which can appear in scalar Accelerate expressions of @@ -91,14 +94,43 @@ class Elt a where fromElt :: a -> EltR a toElt :: EltR a -> a - -- default eltR :: (POSable a) => EltRT a - -- eltR = mkPOST @a + default eltR :: (POSable a, POStoEltR (Choices a) (Fields a) ~ EltR a) => TypeR (EltR a) + eltR = mkEltRT @a default fromElt :: (POSable a, POStoEltR (Choices a) (Fields a) ~ EltR a) => a -> EltR a - fromElt a = mkEltR a + fromElt = mkEltR default toElt :: (POSable a, POStoEltR (Choices a) (Fields a) ~ EltR a) => EltR a -> a - toElt a = fromEltR a + toElt = fromEltR + +flattenProductType :: ProductType a -> TypeR (FlattenProductType a) +flattenProductType PTNil = TupRunit +flattenProductType (PTCons x xs) = TupRpair (TupRsingle (SumScalarTypeR x)) (flattenProductType xs) + +mkEltRT :: forall a . (POSable a) => TypeR (POStoEltR (Choices a) (Fields a)) +mkEltRT = case natVal cs of + -- This distinction is hard to express in a type-correct way, + -- hence the unsafeCoerce's + 1 -> case emptyFields @a of + PTCons (STSucc x STZero) PTNil -> TupRsingle (typeRep2scalarT (unsafeCoerce x)) + x -> unsafeCoerce $ flattenProductType x + _ -> unsafeCoerce $ TupRpair (TupRsingle (TagScalarType cs)) (flattenProductType (emptyFields @a)) + where + cs = emptyChoices @a + + -- This means we should NEVER add a GroundType x for which TypeRep x !~ ScalarType x, + -- because that will lead to nasty bugs. + -- We should never expose GroundType to Accelerate users (which wouldn't make + -- much sense anyhow, why would a user add a machine type?), and also try + -- to limit the visibility of GroundType within Accelerate itself. + -- We cannot enforce this however without inlining the POSable library and replacing + -- the definition of TypeRep. + typeRep2scalarT :: forall x . TypeRep x -> ScalarType x + typeRep2scalarT a + | Refl :: (TypeRep x :~: ScalarType x) <- unsafeCoerce Refl + = a + + untag :: TypeR t -> TagR t untag TupRunit = TagRunit diff --git a/src/Data/Array/Accelerate/Sugar/POS.hs b/src/Data/Array/Accelerate/Sugar/POS.hs index 665576941..5f8a76d5d 100644 --- a/src/Data/Array/Accelerate/Sugar/POS.hs +++ b/src/Data/Array/Accelerate/Sugar/POS.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE FlexibleContexts #-} @@ -12,6 +11,7 @@ {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE TypeFamilyDependencies #-} {-# OPTIONS_HADDOCK hide #-} +{-# OPTIONS_GHC -ddump-splices #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -- This is needed to derive POSable for tuples of size more then 4 @@ -31,17 +31,10 @@ module Data.Array.Accelerate.Sugar.POS -- import Data.Array.Accelerate.Type -import Data.Bits -import Data.Char -import Data.Kind import Language.Haskell.TH.Extra hiding ( Type ) -import GHC.Generics -import GHC.TypeLits - import Data.Type.POSable.POSable as POSable import Data.Type.POSable.Representation -import Data.Type.POSable.Instances import Data.Type.POSable.TH import Data.Int @@ -51,6 +44,7 @@ import Foreign.C.Types import Data.Array.Accelerate.Type + runQ $ do let -- XXX: we might want to do the digItOut trick used by FromIntegral? @@ -93,17 +87,16 @@ runQ $ do , ''CUChar ] - mkSimple :: Name -> Name -> Q [Dec] - mkSimple typ name = + mkSimple :: Name -> Name -> Name -> Q [Dec] + mkSimple typ val name = let t = conT name - tt = conT typ - tr = pure $ ConE $ mkName ("Type" ++ nameBase name) + tr = pure $ AppE (ConE val) (ConE $ mkName ("Type" ++ nameBase name)) in [d| instance GroundType $t where - type TypeRep $t = $tt $t + type TypeRep $t = ScalarType $t - mkTypeRep = $tr + mkTypeRep = SingleScalarType (NumSingleType $tr) |] mkTuple :: Int -> Q Dec @@ -125,8 +118,8 @@ runQ $ do -- mkPOSableGroundType name -- - si <- mapM (mkSimple ''IntegralType) integralTypes - sf <- mapM (mkSimple ''FloatingType) floatingTypes + si <- mapM (mkSimple ''IntegralType 'IntegralNumType) integralTypes + sf <- mapM (mkSimple ''FloatingType 'FloatingNumType) floatingTypes ns <- mapM mkPOSableGroundType (floatingTypes ++ integralTypes) -- ns <- mapM mkNewtype newtypes -- ts <- mapM mkTuple [2..16] diff --git a/src/Data/Array/Accelerate/Sugar/Vec.hs b/src/Data/Array/Accelerate/Sugar/Vec.hs index dcdaa87e4..0c40ba407 100644 --- a/src/Data/Array/Accelerate/Sugar/Vec.hs +++ b/src/Data/Array/Accelerate/Sugar/Vec.hs @@ -1,9 +1,7 @@ {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE MagicHash #-} {-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TypeFamilyDependencies #-} {-# LANGUAGE NoStarIsType #-} @@ -26,16 +24,10 @@ module Data.Array.Accelerate.Sugar.Vec where import Data.Array.Accelerate.Sugar.Elt -import Data.Array.Accelerate.Representation.Tag -import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Representation.POS import Data.Array.Accelerate.Type import Data.Primitive.Types import Data.Primitive.Vec -import Data.Kind - -import GHC.TypeLits -import GHC.Prim type VecElt a = (Elt a, Prim a, IsSingle a, GroundType a, Num a) @@ -52,7 +44,7 @@ instance VecElt a => POSable (Vec2 a) where type Fields (Vec2 a) = '[ '[a], '[a]] fields (Vec2 a b) = Cons (Pick a) (Cons (Pick b) Nil) - emptyFields = PTCons (STSucc 0 STZero) (PTCons (STSucc 0 STZero) PTNil) + emptyFields = PTCons (STSucc (mkTypeRep @a) STZero) (PTCons (STSucc (mkTypeRep @a) STZero) PTNil) -- Elt instance automatically derived from POSable instance @@ -71,7 +63,7 @@ instance VecElt a => POSable (Vec4 a) where type Fields (Vec4 a) = '[ '[a], '[a], '[a], '[a]] fields (Vec4 a b c d) = Cons (Pick a) (Cons (Pick b) (Cons (Pick c) (Cons (Pick d) Nil))) - emptyFields = PTCons (STSucc 0 STZero) (PTCons (STSucc 0 STZero) (PTCons (STSucc 0 STZero) (PTCons (STSucc 0 STZero) PTNil))) + emptyFields = PTCons (STSucc (mkTypeRep @a) STZero) (PTCons (STSucc (mkTypeRep @a) STZero) (PTCons (STSucc (mkTypeRep @a) STZero) (PTCons (STSucc (mkTypeRep @a) STZero) PTNil))) -- Elt instance automatically derived from POSable instance From df1fd120daf8d0b14f43e94f850166788293555a Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Tue, 29 Mar 2022 12:23:24 +0200 Subject: [PATCH 27/67] add OuterChoices / outerChoice to POS instances --- .../Array/Accelerate/Representation/Shape.hs | 14 ++++-- src/Data/Array/Accelerate/Sugar/Vec.hs | 5 +++ src/Data/Array/Accelerate/Type.hs | 45 +++++++++++++------ 3 files changed, 47 insertions(+), 17 deletions(-) diff --git a/src/Data/Array/Accelerate/Representation/Shape.hs b/src/Data/Array/Accelerate/Representation/Shape.hs index f3d0cad5c..ce7b78b6d 100644 --- a/src/Data/Array/Accelerate/Representation/Shape.hs +++ b/src/Data/Array/Accelerate/Representation/Shape.hs @@ -220,6 +220,9 @@ instance POSable (ShapeR ()) where emptyFields = PTNil + type OuterChoices (ShapeR ()) = 1 + outerChoice _ = 0 + instance (POSable (ShapeR sh)) => POSable (ShapeR (sh, Int)) where type Choices (ShapeR (sh, Int)) = 1 @@ -229,8 +232,13 @@ instance (POSable (ShapeR sh)) => POSable (ShapeR (sh, Int)) where fromPOSable 0 (Cons _ xs) = ShapeRsnoc (fromPOSable 0 xs) - type Fields (ShapeR (sh, Int)) = '[] ': Fields (ShapeR sh) + type Fields (ShapeR (sh, Int)) = '[Undef] ': Fields (ShapeR sh) + + fields (ShapeRsnoc sh) = Cons (Pick Undef) (fields sh) + + emptyFields = PTCons (STSucc Undef STZero) (emptyFields @(ShapeR sh)) + + type OuterChoices (ShapeR (sh, Int)) = 1 + outerChoice _ = 0 - fields (ShapeRsnoc sh) = Cons Undef (fields sh) - emptyFields = PTCons STZero (emptyFields @(ShapeR sh)) diff --git a/src/Data/Array/Accelerate/Sugar/Vec.hs b/src/Data/Array/Accelerate/Sugar/Vec.hs index 0c40ba407..d9d94ae1d 100644 --- a/src/Data/Array/Accelerate/Sugar/Vec.hs +++ b/src/Data/Array/Accelerate/Sugar/Vec.hs @@ -46,6 +46,8 @@ instance VecElt a => POSable (Vec2 a) where emptyFields = PTCons (STSucc (mkTypeRep @a) STZero) (PTCons (STSucc (mkTypeRep @a) STZero) PTNil) + type OuterChoices (Vec2 a) = 1 + outerChoice _ = 0 -- Elt instance automatically derived from POSable instance instance VecElt a => Elt (Vec2 a) @@ -65,6 +67,9 @@ instance VecElt a => POSable (Vec4 a) where emptyFields = PTCons (STSucc (mkTypeRep @a) STZero) (PTCons (STSucc (mkTypeRep @a) STZero) (PTCons (STSucc (mkTypeRep @a) STZero) (PTCons (STSucc (mkTypeRep @a) STZero) PTNil))) + type OuterChoices (Vec4 a) = 1 + outerChoice _ = 0 + -- Elt instance automatically derived from POSable instance instance VecElt a => Elt (Vec4 a) diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index becfaab77..0db76532a 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -76,6 +76,7 @@ import Data.Array.Accelerate.Representation.POS import Data.Primitive.Vec import Data.Bits +import Data.Proxy import Data.Int import Data.Primitive.Types import Data.Type.Equality @@ -96,7 +97,7 @@ import Unsafe.Coerce type family POStoEltR (cs :: Nat) fs :: Type where POStoEltR 1 '[ '[x]] = x -- singletontypes - POStoEltR 1 x = FlattenProduct x -- tagless types (could / should be represented without Sums in the Product) + POStoEltR 1 x = FlattenProduct x -- tagless types POStoEltR n x = (Finite n, FlattenProduct x) -- all other types type family FlattenProduct (xss :: f (g a)) = (r :: Type) where @@ -111,29 +112,44 @@ flattenProduct :: Product a -> FlattenProduct a flattenProduct Nil = () flattenProduct (Cons x xs) = (SumScalarType x, flattenProduct xs) --- unFlattenProduct :: FlattenProduct a -> Product a --- unFlattenProduct () = Nil --- unFlattenProduct (SumScalarType x, xs) = Cons x (unFlattenProduct xs) +-- This might typecheck without unsafeCoerce by using cmpNat, but that requires a very new version of base mkEltR :: forall a . (POSable a) => a -> POStoEltR (Choices a) (Fields a) -mkEltR x = case natVal cs of - -- This distinction is hard to express in a type-correct way, - -- hence the unsafeCoerce's - 1 -> case emptyFields @a of - PTCons (STSucc _ STZero) PTNil | Cons (Pick f) Nil <- fields x -> unsafeCoerce f - _ -> unsafeCoerce fs - _ -> unsafeCoerce (cs, fs) +mkEltR x = case sameNat cs (Proxy :: Proxy 1) of + Just Refl -> case emptyFields @a of + -- Lots of cases because GHC does not understand inequality + -- First up: singleton type + PTCons (STSucc _ STZero) PTNil | Cons (Pick f) Nil <- fields x -> f + -- Singleton type, but we don't get an actual value out of + -- the call to `fields` (Should never occur). + PTCons (STSucc _ STZero) PTNil -> error "Value does not match representation" + -- unit type + PTNil -> fs + -- weird type with no value in the first field (thus having + -- no constructors) + PTCons STZero _ -> fs + -- weird type with a sum in the first field (weird because + -- we already have `Choices a ~ 1` in scope) + PTCons (STSucc _ (STSucc _ _)) _ -> fs + -- type with two or more fields + PTCons _ (PTCons _ _) -> fs + -- unsafeCoerce because we cannot prove to the compiler that + -- `Choices a !~ 1` (and thus the 3th branch of the POStoEltR + -- type family holds) + Nothing -> unsafeCoerce (cs, fs) where cs = choices x fs = flattenProduct (fields x) +-- TODO: this might not be correct +-- This might typecheck with cmpNat, but that requires a very new version of base fromEltR :: forall a . (POSable a) => POStoEltR (Choices a) (Fields a) -> a -fromEltR x = case natVal (emptyChoices @a) of - 1 -> case emptyFields @a of +fromEltR x = case sameNat (emptyChoices @a) (Proxy :: Proxy 1) of + Just Refl -> case emptyFields @a of PTCons (STSucc _ STZero) PTNil -> unsafeCoerce x _ -> fromPOSable 0 (unsafeCoerce x) - _ -> uncurry fromPOSable (unsafeCoerce x) + Nothing -> uncurry fromPOSable (unsafeCoerce x) -- Scalar types -- ------------ @@ -240,6 +256,7 @@ instance Show (VectorType a) where instance Show (ScalarType a) where show (SingleScalarType ty) = show ty show (VectorScalarType ty) = show ty + -- TODO add all constructors formatIntegralType :: Format r (IntegralType a -> r) formatIntegralType = later $ \case From c299d35e1f4ffba2278e8dcbe63225881adc4d21 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Tue, 29 Mar 2022 13:16:15 +0200 Subject: [PATCH 28/67] remove unused stuff from Representation/POS --- .../Array/Accelerate/Representation/POS.hs | 74 +------------------ 1 file changed, 2 insertions(+), 72 deletions(-) diff --git a/src/Data/Array/Accelerate/Representation/POS.hs b/src/Data/Array/Accelerate/Representation/POS.hs index b62d4521b..8edae74b5 100644 --- a/src/Data/Array/Accelerate/Representation/POS.hs +++ b/src/Data/Array/Accelerate/Representation/POS.hs @@ -1,17 +1,4 @@ -{-# LANGUAGE AllowAmbiguousTypes #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DefaultSignatures #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE PolyKinds #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE TypeFamilyDependencies #-} {-# OPTIONS_HADDOCK hide #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} @@ -28,70 +15,13 @@ -- module Data.Array.Accelerate.Representation.POS ( - POSable(..), POS, POST, mkPOS, mkPOST, fromPOS, Product(..), Sum(..), + POSable(..), Product(..), Sum(..), GroundType(..), Finite, ProductType(..), SumType(..), POSable.Generic, type (++), mkPOSableGroundType) where --- import Data.Array.Accelerate.Type - -import Data.Bits -import Data.Char -import Data.Kind -import Language.Haskell.TH.Extra hiding ( Type ) - -import GHC.Generics -import GHC.TypeLits import Data.Type.POSable.POSable as POSable import Data.Type.POSable.Representation -import Data.Type.POSable.Instances +import Data.Type.POSable.Instances () import Data.Type.POSable.TH - -import Data.Int -import Data.Word -import Numeric.Half -import Foreign.C.Types - --- import Data.Array.Accelerate.Representation.Type - -type POS a = (Finite (Choices a), Product (Fields a)) - --- type family EltR (cs :: Nat) (fs :: f (g a)) = (r :: Type) where --- EltR 1 x = FlattenProduct x --- EltR n x = (Finite n, FlattenProduct x) - --- type family FlattenProduct (xss :: f (g a)) :: Type where --- FlattenProduct '[] = () --- FlattenProduct '[ '[x]] = x --- FlattenProduct (x ': xs) = (FlattenSum x, FlattenProduct xs) - --- type family FlattenSum (xss :: f a) :: Type where --- FlattenSum '[] = () --- FlattenSum (x ': xs) = (x, FlattenSum xs) - --- mkEltR :: (POSable a) => a -> EltR (Choices a) (Fields a) --- mkEltR x = undefined --- where --- cs = choices x --- fs = fields x - --- -- productToTupR :: Product a -> TypeR (FlattenProduct a) --- -- productToTupR Nil = TupRunit --- -- productToTupR (Cons x xs) = TupRpair x (productToTupR xs) - -mkPOS :: (POSable a) => a -> POS a -mkPOS x = (choices x, fields x) - -fromPOS :: (POSable a) => POS a -> a -fromPOS (cs, fs) = fromPOSable cs fs - -type POST a = (Finite (Choices a), ProductType (Fields a)) - -mkPOST :: forall a . (POSable a) => POST a -mkPOST = (0, emptyFields @a) - - -type family Snoc2List x = xs | xs -> x where - Snoc2List () = '[] - Snoc2List (xs, x) = (x ': Snoc2List xs) From fac73b7065a01704835f4005b6b10999825c22c1 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 7 Apr 2022 11:09:33 +0200 Subject: [PATCH 29/67] convert Sums to tuple representation --- .../Array/Accelerate/Representation/POS.hs | 4 +- .../Array/Accelerate/Representation/Tag.hs | 6 +- src/Data/Array/Accelerate/Sugar/Elt.hs | 84 +++++++++++++------ src/Data/Array/Accelerate/Sugar/POS.hs | 12 ++- src/Data/Array/Accelerate/Sugar/Vec.hs | 6 +- src/Data/Array/Accelerate/Type.hs | 45 +++++++--- 6 files changed, 104 insertions(+), 53 deletions(-) diff --git a/src/Data/Array/Accelerate/Representation/POS.hs b/src/Data/Array/Accelerate/Representation/POS.hs index 8edae74b5..93a69dd61 100644 --- a/src/Data/Array/Accelerate/Representation/POS.hs +++ b/src/Data/Array/Accelerate/Representation/POS.hs @@ -16,8 +16,8 @@ module Data.Array.Accelerate.Representation.POS ( POSable(..), Product(..), Sum(..), - GroundType(..), Finite, ProductType(..), SumType(..), POSable.Generic, type (++), - mkPOSableGroundType) + Ground(..), Finite, ProductType(..), SumType(..), POSable.Generic, type (++), + mkPOSableGround) where diff --git a/src/Data/Array/Accelerate/Representation/Tag.hs b/src/Data/Array/Accelerate/Representation/Tag.hs index ed7e07e80..31d3e82c0 100644 --- a/src/Data/Array/Accelerate/Representation/Tag.hs +++ b/src/Data/Array/Accelerate/Representation/Tag.hs @@ -11,7 +11,7 @@ -- Portability : non-portable (GHC extensions) -- -module Data.Array.Accelerate.Representation.Tag +module Data.Array.Accelerate.Representation.Tag (TAG, TagR(..), rnfTag, liftTag) where import Data.Array.Accelerate.Type @@ -19,10 +19,6 @@ import Data.Array.Accelerate.Type import Language.Haskell.TH.Extra --- | The type of the runtime value used to distinguish constructor --- alternatives in a sum type. --- -type TAG = Word8 -- | This structure both witnesses the layout of our representation types -- (as TupR does) and represents a complete path of pattern matching diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index 6a6d200df..1bd6af341 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -10,6 +10,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ConstraintKinds #-} {-# OPTIONS_HADDOCK hide #-} {-# OPTIONS_GHC -ddump-splices #-} -- | @@ -25,23 +26,21 @@ module Data.Array.Accelerate.Sugar.Elt ( Elt(..) ) where -import Data.Array.Accelerate.Representation.Elt -import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Representation.POS -import Data.Array.Accelerate.Sugar.POS +import Data.Array.Accelerate.Representation.Tag +import Data.Array.Accelerate.Sugar.POS () import Data.Array.Accelerate.Type -import Data.Bits import Data.Char import Data.Kind -import Data.Word import Language.Haskell.TH.Extra hiding ( Type ) -import GHC.Generics import GHC.TypeLits import Unsafe.Coerce import Data.Type.Equality +import Data.Proxy +import Data.Typeable -- | The 'Elt' class characterises the allowable array element types, and -- hence the types which can appear in scalar Accelerate expressions of @@ -103,34 +102,69 @@ class Elt a where default toElt :: (POSable a, POStoEltR (Choices a) (Fields a) ~ EltR a) => EltR a -> a toElt = fromEltR -flattenProductType :: ProductType a -> TypeR (FlattenProductType a) +flattenProductType :: ProductType a -> TypeR (FlattenProduct a) flattenProductType PTNil = TupRunit -flattenProductType (PTCons x xs) = TupRpair (TupRsingle (SumScalarTypeR x)) (flattenProductType xs) +flattenProductType (PTCons x xs) = TupRpair (TupRsingle (flattenSumType x)) (flattenProductType xs) + +flattenSumType :: SumType a -> ScalarType (SumScalar (FlattenSum a)) +flattenSumType STZero = SumScalarType ZeroScalarType +flattenSumType (STSucc (x :: x) xs) + = SumScalarType (SuccScalarType (mkScalarType x) (flattenSumType xs)) + +-- This is an unsafe conversion, and should be kept strictly in sync with the +-- set of types that implement Ground +mkScalarType :: forall a . (Typeable a, Ground a) => a -> ScalarType a +mkScalarType _ + | Just Refl <- eqT @a @Int + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Int8 + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Int16 + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Int32 + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Int64 + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Word + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Word8 + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Word16 + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Word32 + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Word64 + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Half + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Float + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Double + = scalarType @a mkEltRT :: forall a . (POSable a) => TypeR (POStoEltR (Choices a) (Fields a)) -mkEltRT = case natVal cs of +mkEltRT = case sameNat cs (Proxy :: Proxy 1) of -- This distinction is hard to express in a type-correct way, -- hence the unsafeCoerce's - 1 -> case emptyFields @a of - PTCons (STSucc x STZero) PTNil -> TupRsingle (typeRep2scalarT (unsafeCoerce x)) + Just Refl -> case emptyFields @a of + PTCons (STSucc x STZero) PTNil -> TupRsingle (mkScalarType x) x -> unsafeCoerce $ flattenProductType x - _ -> unsafeCoerce $ TupRpair (TupRsingle (TagScalarType cs)) (flattenProductType (emptyFields @a)) + Nothing -> unsafeCoerce $ TupRpair (TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeTAG)))) (flattenProductType (emptyFields @a)) where cs = emptyChoices @a - -- This means we should NEVER add a GroundType x for which TypeRep x !~ ScalarType x, - -- because that will lead to nasty bugs. - -- We should never expose GroundType to Accelerate users (which wouldn't make - -- much sense anyhow, why would a user add a machine type?), and also try - -- to limit the visibility of GroundType within Accelerate itself. - -- We cannot enforce this however without inlining the POSable library and replacing - -- the definition of TypeRep. - typeRep2scalarT :: forall x . TypeRep x -> ScalarType x - typeRep2scalarT a - | Refl :: (TypeRep x :~: ScalarType x) <- unsafeCoerce Refl - = a - - untag :: TypeR t -> TagR t untag TupRunit = TagRunit diff --git a/src/Data/Array/Accelerate/Sugar/POS.hs b/src/Data/Array/Accelerate/Sugar/POS.hs index 5f8a76d5d..70ecd806e 100644 --- a/src/Data/Array/Accelerate/Sugar/POS.hs +++ b/src/Data/Array/Accelerate/Sugar/POS.hs @@ -90,13 +90,11 @@ runQ $ do mkSimple :: Name -> Name -> Name -> Q [Dec] mkSimple typ val name = let t = conT name - tr = pure $ AppE (ConE val) (ConE $ mkName ("Type" ++ nameBase name)) + -- tr = pure $ AppE (ConE val) (ConE $ mkName ("Type" ++ nameBase name)) in [d| - instance GroundType $t where - type TypeRep $t = ScalarType $t - - mkTypeRep = SingleScalarType (NumSingleType $tr) + instance Ground $t where + mkGround = 0 |] mkTuple :: Int -> Q Dec @@ -116,11 +114,11 @@ runQ $ do TyConI (NewtypeD _ _ _ _ (NormalC _ [(_, ConT b)]) _) -> return b _ -> error "unexpected case generating newtype Elt instance" -- - mkPOSableGroundType name + mkPOSableGround name -- si <- mapM (mkSimple ''IntegralType 'IntegralNumType) integralTypes sf <- mapM (mkSimple ''FloatingType 'FloatingNumType) floatingTypes - ns <- mapM mkPOSableGroundType (floatingTypes ++ integralTypes) + ns <- mapM mkPOSableGround (floatingTypes ++ integralTypes) -- ns <- mapM mkNewtype newtypes -- ts <- mapM mkTuple [2..16] -- vs <- sequence [ mkVecElt t n | t <- integralTypes ++ floatingTypes, n <- [2,3,4,8,16] ] diff --git a/src/Data/Array/Accelerate/Sugar/Vec.hs b/src/Data/Array/Accelerate/Sugar/Vec.hs index d9d94ae1d..4bcc3b2e9 100644 --- a/src/Data/Array/Accelerate/Sugar/Vec.hs +++ b/src/Data/Array/Accelerate/Sugar/Vec.hs @@ -30,7 +30,7 @@ import Data.Primitive.Types import Data.Primitive.Vec -type VecElt a = (Elt a, Prim a, IsSingle a, GroundType a, Num a) +type VecElt a = (Elt a, Prim a, IsSingle a, Ground a, Num a) instance VecElt a => POSable (Vec2 a) where type Choices (Vec2 a) = 1 @@ -44,7 +44,7 @@ instance VecElt a => POSable (Vec2 a) where type Fields (Vec2 a) = '[ '[a], '[a]] fields (Vec2 a b) = Cons (Pick a) (Cons (Pick b) Nil) - emptyFields = PTCons (STSucc (mkTypeRep @a) STZero) (PTCons (STSucc (mkTypeRep @a) STZero) PTNil) + emptyFields = PTCons (STSucc (mkGround @a) STZero) (PTCons (STSucc (mkGround @a) STZero) PTNil) type OuterChoices (Vec2 a) = 1 outerChoice _ = 0 @@ -65,7 +65,7 @@ instance VecElt a => POSable (Vec4 a) where type Fields (Vec4 a) = '[ '[a], '[a], '[a], '[a]] fields (Vec4 a b c d) = Cons (Pick a) (Cons (Pick b) (Cons (Pick c) (Cons (Pick d) Nil))) - emptyFields = PTCons (STSucc (mkTypeRep @a) STZero) (PTCons (STSucc (mkTypeRep @a) STZero) (PTCons (STSucc (mkTypeRep @a) STZero) (PTCons (STSucc (mkTypeRep @a) STZero) PTNil))) + emptyFields = PTCons (STSucc (mkGround @a) STZero) (PTCons (STSucc (mkGround @a) STZero) (PTCons (STSucc (mkGround @a) STZero) (PTCons (STSucc (mkGround @a) STZero) PTNil))) type OuterChoices (Vec4 a) = 1 outerChoice _ = 0 diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index 0db76532a..00a641c43 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -95,23 +95,41 @@ import GHC.TypeLits import Unsafe.Coerce +-- | The type of the runtime value used to distinguish constructor +-- alternatives in a sum type. +-- +type TAG = Word8 + + type family POStoEltR (cs :: Nat) fs :: Type where POStoEltR 1 '[ '[x]] = x -- singletontypes POStoEltR 1 x = FlattenProduct x -- tagless types - POStoEltR n x = (Finite n, FlattenProduct x) -- all other types + POStoEltR n x = (TAG, FlattenProduct x) -- all other types -type family FlattenProduct (xss :: f (g a)) = (r :: Type) where +type family FlattenProduct (xss :: [[a]]) :: Type where FlattenProduct '[] = () - FlattenProduct (x ': xs) = (ScalarType (Sum x), FlattenProduct xs) + FlattenProduct (x ': xs) = (SumScalar (FlattenSum x), FlattenProduct xs) + +type family FlattenSum (xs :: [a]) :: Type where + FlattenSum '[] = () + FlattenSum (x ': xs) = (x, FlattenSum xs) -type family FlattenProductType (xss :: f (g a)) = (r :: Type) where +type family FlattenProductType (xss :: [[a]]) :: Type where FlattenProductType '[] = () - FlattenProductType (x ': xs) = (SumType x, FlattenProductType xs) + FlattenProductType (x ': xs) = (SumScalarType (FlattenSumType x), FlattenProductType xs) + +type family FlattenSumType (xs :: [a]) :: Type where + FlattenSumType '[] = () + FlattenSumType (x ': xs) = (x, FlattenSumType xs) + flattenProduct :: Product a -> FlattenProduct a flattenProduct Nil = () -flattenProduct (Cons x xs) = (SumScalarType x, flattenProduct xs) +flattenProduct (Cons x xs) = (flattenSum x, flattenProduct xs) +flattenSum :: Sum a -> SumScalar (FlattenSum a) +flattenSum (Pick x) = PickScalar x +flattenSum (Skip xs) = SkipScalar (flattenSum xs) -- This might typecheck without unsafeCoerce by using cmpNat, but that requires a very new version of base mkEltR :: forall a . (POSable a) => a -> POStoEltR (Choices a) (Fields a) @@ -187,7 +205,7 @@ data IntegralType a where TypeWord16 :: IntegralType Word16 TypeWord32 :: IntegralType Word32 TypeWord64 :: IntegralType Word64 - TypeTAG :: IntegralType (Finite n) + TypeTAG :: IntegralType TAG -- | Floating-point types supported in array computations. -- @@ -210,12 +228,17 @@ data BoundedType a where -- | All scalar element types implement Eq & Ord -- data ScalarType a where - SumScalarType :: Sum a -> ScalarType (Sum a) - SumScalarTypeR :: SumType a -> ScalarType (SumType a) - TagScalarType :: Finite n -> ScalarType (Finite n) SingleScalarType :: SingleType a -> ScalarType a - SingletonScalarType :: ScalarType a VectorScalarType :: VectorType (Vec n a) -> ScalarType (Vec n a) + SumScalarType :: SumScalarType a -> ScalarType a + +data SumScalar x where + PickScalar :: a -> SumScalar (a, b) + SkipScalar :: SumScalar b -> SumScalar (a, b) + +data SumScalarType a where + SuccScalarType :: ScalarType a -> ScalarType (SumScalar b) -> SumScalarType (SumScalar (a, b)) + ZeroScalarType :: SumScalarType (SumScalar ()) data SingleType a where NumSingleType :: NumType a -> SingleType a From b8d1fa0268a12e09297b97b14aa5a45226100bd9 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 7 Apr 2022 11:51:48 +0200 Subject: [PATCH 30/67] pretty print POS structures --- src/Data/Array/Accelerate/Representation/POS.hs | 2 +- src/Data/Array/Accelerate/Representation/Type.hs | 2 +- src/Data/Array/Accelerate/Sugar/Elt.hs | 4 ++++ src/Data/Array/Accelerate/Type.hs | 16 +++++++++++++++- 4 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/Data/Array/Accelerate/Representation/POS.hs b/src/Data/Array/Accelerate/Representation/POS.hs index 93a69dd61..3f9204db3 100644 --- a/src/Data/Array/Accelerate/Representation/POS.hs +++ b/src/Data/Array/Accelerate/Representation/POS.hs @@ -17,7 +17,7 @@ module Data.Array.Accelerate.Representation.POS ( POSable(..), Product(..), Sum(..), Ground(..), Finite, ProductType(..), SumType(..), POSable.Generic, type (++), - mkPOSableGround) + mkPOSableGround, Undef(..)) where diff --git a/src/Data/Array/Accelerate/Representation/Type.hs b/src/Data/Array/Accelerate/Representation/Type.hs index bb0e516e3..57318798c 100644 --- a/src/Data/Array/Accelerate/Representation/Type.hs +++ b/src/Data/Array/Accelerate/Representation/Type.hs @@ -51,7 +51,7 @@ data TupR s a where instance Show (TupR ScalarType a) where show TupRunit = "()" show (TupRsingle t) = show t - show (TupRpair a b) = "(" ++ show a ++ "," ++ show b ++ ")" + show (TupRpair a b) = show a ++ " ✕ " ++ show b formatTypeR :: Format r (TypeR a -> r) formatTypeR = later $ \case diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index 1bd6af341..742c9a519 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -153,6 +153,10 @@ mkScalarType _ mkScalarType _ | Just Refl <- eqT @a @Double = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Undef + = scalarType @a + mkEltRT :: forall a . (POSable a) => TypeR (POStoEltR (Choices a) (Fields a)) mkEltRT = case sameNat cs (Proxy :: Proxy 1) of diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index 00a641c43..59483ec34 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -242,6 +242,7 @@ data SumScalarType a where data SingleType a where NumSingleType :: NumType a -> SingleType a + UndefSingleType :: SingleType Undef data VectorType a where VectorType :: KnownNat n => {-# UNPACK #-} !Int -> SingleType a -> VectorType (Vec n a) @@ -257,6 +258,7 @@ instance Show (IntegralType a) where show TypeWord16 = "Word16" show TypeWord32 = "Word32" show TypeWord64 = "Word64" + show TypeTAG = "TAG" instance Show (FloatingType a) where show TypeHalf = "Half" @@ -272,6 +274,7 @@ instance Show (BoundedType a) where instance Show (SingleType a) where show (NumSingleType ty) = show ty + show UndefSingleType = "Undef" instance Show (VectorType a) where show (VectorType n ty) = printf "<%d x %s>" n (show ty) @@ -279,7 +282,12 @@ instance Show (VectorType a) where instance Show (ScalarType a) where show (SingleScalarType ty) = show ty show (VectorScalarType ty) = show ty - -- TODO add all constructors + show (SumScalarType ty) = show ty + +instance Show (SumScalarType a) where + show ZeroScalarType = "" + show (SuccScalarType x (SumScalarType (ZeroScalarType))) = show x + show (SuccScalarType x xs) = show x ++ " + " ++ show xs formatIntegralType :: Format r (IntegralType a -> r) formatIntegralType = later $ \case @@ -614,3 +622,9 @@ runQ $ do -- return (concat is ++ concat fs ++ concat vs) + +instance IsSingle Undef where + singleType = UndefSingleType + +instance IsScalar Undef where + scalarType = SingleScalarType singleType From ead51a2ac5d613caf019ded67fd860c37236bb3a Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 7 Apr 2022 14:18:03 +0200 Subject: [PATCH 31/67] IsScalar instances for SumScalarType --- src/Data/Array/Accelerate/Type.hs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index 59483ec34..844332c8c 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -628,3 +628,19 @@ instance IsSingle Undef where instance IsScalar Undef where scalarType = SingleScalarType singleType + +instance IsScalar (SumScalar ()) where + scalarType = SumScalarType ZeroScalarType + +instance (IsScalar a, IsSumScalar b) => IsScalar (SumScalar (a, b)) where + scalarType = SumScalarType $ SuccScalarType (scalarType @a) (SumScalarType + (sumScalarType @b)) + +class IsSumScalar a where + sumScalarType :: SumScalarType (SumScalar a) + +instance IsSumScalar () where + sumScalarType = ZeroScalarType + +instance (IsScalar a, IsSumScalar b) => IsSumScalar (a, b) where + sumScalarType = SuccScalarType (scalarType @a) (SumScalarType (sumScalarType @b)) From 8f825d3da651b3727996b9abbb27ce8c0542b310 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 7 Apr 2022 14:33:08 +0200 Subject: [PATCH 32/67] build Maybe in Matchable --- accelerate.cabal | 2 + src/Data/Array/Accelerate/AST.hs | 8 +- .../Array/Accelerate/Pattern/Matchable.hs | 243 ++++++++++++++++++ src/Data/Array/Accelerate/Smart.hs | 4 + 4 files changed, 255 insertions(+), 2 deletions(-) create mode 100644 src/Data/Array/Accelerate/Pattern/Matchable.hs diff --git a/accelerate.cabal b/accelerate.cabal index b7f647747..85c87b236 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -363,6 +363,7 @@ library , vector >= 0.10 , posable >= 0.9.0.0 , ghc-typelits-knownnat >= 0.7.6 + , generics-sop >= 0.5.1.1 exposed-modules: -- The core language and reference implementation @@ -466,6 +467,7 @@ library Data.Array.Accelerate.Lift Data.Array.Accelerate.Orphans Data.Array.Accelerate.Pattern + Data.Array.Accelerate.Pattern.Matchable Data.Array.Accelerate.Pattern.Bool Data.Array.Accelerate.Pattern.Either Data.Array.Accelerate.Pattern.Maybe diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 84e3529c1..33bdc7c4e 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -147,6 +147,7 @@ import Data.Array.Accelerate.Representation.Stencil import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Representation.Vec +import Data.Array.Accelerate.Representation.POS (Finite) import Data.Array.Accelerate.Sugar.Foreign import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Type @@ -164,6 +165,7 @@ import qualified Language.Haskell.TH.Syntax as TH import GHC.TypeLits +import Data.Proxy -- Array expressions -- ----------------- @@ -670,6 +672,9 @@ data PrimFun sig where PrimAbs :: NumType a -> PrimFun (a -> a) PrimSig :: NumType a -> PrimFun (a -> a) + -- operator on Finite + PrimShiftFinite :: Proxy a -> PrimFun (Finite b -> Finite (a + b)) + -- operators from Integral PrimQuot :: IntegralType a -> PrimFun ((a, a) -> a) PrimRem :: IntegralType a -> PrimFun ((a, a) -> a) @@ -943,8 +948,7 @@ primFunType = \case floating = num . FloatingNumType tbool :: TypeR PrimBool - tbool = TupRpair (TupRsingle (TagScalarType @2 0)) TupRunit - + tbool = TupRpair (TupRsingle (SingleScalarType (NumSingleType (IntegralNumType (TypeTAG))))) TupRunit tint :: TypeR Int tint = TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeInt))) diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs new file mode 100644 index 000000000..204828e8c --- /dev/null +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -0,0 +1,243 @@ +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE FlexibleInstances #-} + + +module Data.Array.Accelerate.Pattern.Matchable where + +import Data.Array.Accelerate.Smart as Smart +import GHC.TypeLits +import Data.Proxy +import Data.Kind +import Generics.SOP as SOP +import Data.Type.Equality +import Data.Array.Accelerate.Representation.POS as POS +import Data.Array.Accelerate.Representation.Tag +import Unsafe.Coerce +import qualified Data.Array.Accelerate.AST as AST +import Data.Array.Accelerate.Type +import Data.Array.Accelerate.AST.Idx +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Representation.Type +-- import Data.Array.Accelerate.Pretty + + +class Matchable a where + type SOPCode a :: [[Type]] + type SOPCode a = Code a + + type Choices' a :: Nat + type Choices' a = Choices a + + build :: + ( KnownNat n + ) => Proxy n + -> NP Exp (Index (SOPCode a) n) + -> Exp a + default build :: + ( KnownNat n + , POSable a + ) => Proxy n + -> NP Exp (Index (SOPCode a) n) + -> Exp a + + build n _ = case sameNat (emptyChoices @a) (Proxy :: Proxy 1) of + -- no tag + Just Refl -> undefined + -- tagged + Nothing -> undefined + + match :: ( KnownNat n + ) => Proxy n + -> Exp a + -> Maybe (NP Exp (Index (SOPCode a) n)) + +buildTag :: SOP.All POSable xs => NP Exp xs -> Exp TAG +buildTag SOP.Nil = constant 0 -- exp of 0 :: Finite 1 +buildTag (((Exp x) :: (Exp x)) :* (xs :: xs)) = case sameNat (emptyChoices @x) (Proxy :: Proxy 1) of + -- x doesn't contain a tag, skip + Just Refl + -> buildTag xs + -- x contains a tag, build an Exp to calculate the product + Nothing + | Refl :: (EltR x :~: (TAG, _r)) <- unsafeCoerce Refl + -- TODO: this is incorrect, we need the size of the TAG here (return to Finite?) + -> mkMul (Exp (SmartExp (Prj PairIdxLeft x))) (buildTag xs) + +-- flattenProduct :: Product a -> FlattenProduct a +-- flattenProduct Nil = () +-- flattenProduct (Cons x xs) = (SumScalarType x, flattenProduct xs) + +buildFields :: forall n a . (POSable a, Elt a) => Proxy n -> NP SmartExp (Index (SOPCode a) n) -> SmartExp (FlattenProduct (Fields a)) +buildFields _ a = case emptyFields @a of + PTNil -> case constant () of { Exp se -> se } + PTCons st pt -> case a of + -- SOP.Nil -> SmartExp (Pair (someFunction st) undefined) + (x :* xs) -> SmartExp (Pair undefined undefined) + +buildFields' :: Proxy n -> ProductType (Fields a) -> NP SmartExp (Index (SOPCode a) n) -> SmartExp (FlattenProduct (Fields a)) +buildFields' _ PTNil _ = SmartExp Smart.Nil +buildFields' n (PTCons x xs) SOP.Nil = undefined -- SmartExp $ Pair _ (buildFields' n xs SOP.Nil) +buildFields' _ (PTCons x xs) (y :* ys) = undefined + +someFunction :: SumType x -> SmartExp (ScalarType (Sum x)) +someFunction = undefined + +newtype SEFPF a = SEFPF (SmartExp (FlattenProduct (Fields a))) + +-- mapBuildField :: (All POSable xs, All Elt xs) => NP SmartExp xs -> NP SEFPF xs +-- mapBuildField SOP.Nil = SOP.Nil +-- mapBuildField ((x :: SmartExp x) :* xs) = SEFPF (buildField @x x) :* mapBuildField xs + + +buildField :: forall a . (POSable a, Elt a, EltR a ~ POStoEltR (Choices a) (Fields a)) => SmartExp (EltR a) -> SmartExp (FlattenProduct (Fields a)) +buildField (SmartExp a) = case sameNat (emptyChoices @a) (Proxy :: Proxy 1) of + Just Refl -> + case emptyFields @a of + -- singleton types + PTCons (STSucc _ STZero) PTNil + | Refl :: (POStoEltR (Choices a) (Fields a) :~: a) <- unsafeCoerce Refl + -> SmartExp $ Pair (SmartExp (undefined a)) (SmartExp Smart.Nil) + -- tagless types + _ | Refl :: (POStoEltR (Choices a) (Fields a) :~: FlattenProduct (Fields a)) <- unsafeCoerce Refl + -> SmartExp a + -- tagged types + Nothing + -- We know that this is true because Choices a is not equal to 1 + | Refl :: (POStoEltR (Choices a) (Fields a) :~: (_x, FlattenProduct (Fields a))) <- unsafeCoerce Refl + -> SmartExp (Prj PairIdxRight (SmartExp a)) + + +type family Index (xs :: [[Type]]) (y :: Nat) :: [Type] where + Index (x ': xs) 0 = x + Index (x ': xs) n = Index xs (n - 1) + +type family ListToCons (xs :: [Type]) :: Type where + ListToCons '[] = () + ListToCons (x ': xs) = (x, ListToCons xs) + +-- copied from POSable library +type family Products (xs :: [Nat]) :: Nat where + Products '[] = 1 + Products (x ': xs) = x * Products xs + +-- idem +type family MapChoices (xs :: [Type]) :: [Nat] where + MapChoices '[] = '[] + MapChoices (x ': xs) = Choices x ': MapChoices xs + +-- idem +type family Concat (xss :: [[x]]) :: [x] where + Concat '[] = '[] + Concat (xs ': xss) = xs ++ Concat xss + + + +instance Matchable Bool where + type Choices' Bool = 2 + + build n _ = Exp (SmartExp (Pair (undefined (fromInteger $ natVal n)) (SmartExp Smart.Nil))) + + match n (Exp e) = case sameNat n (Proxy :: Proxy 0) of + Just Refl -> + case e of + SmartExp (Match (TagRtag 0 TagRunit) _x) -> Just SOP.Nil + + SmartExp Match {} -> Nothing + + _ -> error "Embedded pattern synonym used outside 'match' context." + Nothing -> + case sameNat n (Proxy :: Proxy 1) of + Just Refl -> + case e of + SmartExp (Match (TagRtag 1 TagRunit) _x) -> Just SOP.Nil + + SmartExp Match {} -> Nothing + + _ -> error "Embedded pattern synonym used outside 'match' context." + + Nothing -> + error "Impossible type encountered" + +makeTag :: TAG -> SmartExp TAG +makeTag x = undefined -- SmartExp (Const (TupRsingle (tagType x))) + +tagType :: TupR ScalarType TAG +tagType = TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeTAG))) + + +instance Matchable (Maybe Int) where + type Choices' (Maybe Int) = 2 + + build n x = case sameNat n (Proxy :: Proxy 0) of + Just Refl -> + Exp ( + SmartExp ( + Pair + (makeTag 0) + (SmartExp ( + Pair + (SmartExp ( + (Const + (scalarType @(SumScalar (Undef, (Int, ())))) + (PickScalar POS.Undef) + ) + )) + (SmartExp Smart.Nil) + )) + ) + ) + Nothing -> case sameNat n (Proxy :: Proxy 1) of + Just Refl | (Exp x' :* SOP.Nil) <- x -> Exp ( + SmartExp ( + Pair + (makeTag 1) + (SmartExp ( + Pair + (SmartExp + (Union + (Right ( + SmartExp + (Union + (Left x') + ) + )) + ) + ) + (SmartExp Smart.Nil) + )) + ) + ) + Nothing -> error "Impossible type encountered" + + match n exp@(Exp e) = case sameNat n (Proxy :: Proxy 0) of + Just Refl -> + case e of + SmartExp (Match (TagRtag 0 (TagRpair _ TagRunit)) _x) + -> Just SOP.Nil + + SmartExp Match {} -> Nothing + + _ -> error "Embedded pattern synonym used outside 'match' context." + Nothing -> -- matchJust + case sameNat n (Proxy :: Proxy 1) of + Just Refl -> + case e of + SmartExp (Match (TagRtag 1 _) x) -> undefined + -- -> Just (Exp (SmartExp (Match (TagRsingle (scalarType @Int)) + -- (Prj PairIdxLeft x))) :* SOP.Nil) + SmartExp Match {} -> Nothing + + _ -> error "Embedded pattern synonym used outside 'match' context." + + Nothing -> + error "Impossible type encountered" diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index e8e547b56..f5d538978 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -93,6 +93,7 @@ import Data.Array.Accelerate.Representation.Stencil hiding ( Ste import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Representation.Vec +import Data.Array.Accelerate.Representation.POS hiding (Nil, Undef) import Data.Array.Accelerate.Sugar.Array ( Arrays ) import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Foreign @@ -510,6 +511,9 @@ data PreSmartExp acc exp t where -> exp (t1, t2) -> PreSmartExp acc exp t + Union :: Either (exp t1) (exp (SumScalar t2)) + -> PreSmartExp acc exp (SumScalar (t1, t2)) + VecPack :: KnownNat n => VecR n s tup -> exp tup From e13bed11fb5ebfcdc413c6a7cb54df2604e137a4 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 7 Apr 2022 14:50:00 +0200 Subject: [PATCH 33/67] compiling Maybe Int pattern match --- src/Data/Array/Accelerate/AST/Idx.hs | 9 ++++++++- .../Array/Accelerate/Pattern/Matchable.hs | 19 ++++++++++++++++--- src/Data/Array/Accelerate/Smart.hs | 4 ++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/Data/Array/Accelerate/AST/Idx.hs b/src/Data/Array/Accelerate/AST/Idx.hs index 548453e2b..00ff8eee3 100644 --- a/src/Data/Array/Accelerate/AST/Idx.hs +++ b/src/Data/Array/Accelerate/AST/Idx.hs @@ -27,11 +27,13 @@ module Data.Array.Accelerate.AST.Idx ( idxToInt, rnfIdx, liftIdx, - PairIdx(..) + PairIdx(..), + UnionIdx(..) ) where import Language.Haskell.TH.Extra +import Data.Array.Accelerate.Type #ifndef ACCELERATE_INTERNAL_CHECKS import Data.Type.Equality ((:~:)(Refl)) @@ -112,3 +114,8 @@ data PairIdx p a where PairIdxLeft :: PairIdx (a, b) a PairIdxRight :: PairIdx (a, b) b +data UnionIdx p a where + UnionIdxLeft :: UnionIdx (a, b) a + UnionIdxRight :: UnionIdx (a, b) (SumScalar b) + + \ No newline at end of file diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index 204828e8c..92f82ecfb 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -232,9 +232,22 @@ instance Matchable (Maybe Int) where case sameNat n (Proxy :: Proxy 1) of Just Refl -> case e of - SmartExp (Match (TagRtag 1 _) x) -> undefined - -- -> Just (Exp (SmartExp (Match (TagRsingle (scalarType @Int)) - -- (Prj PairIdxLeft x))) :* SOP.Nil) + SmartExp (Match (TagRtag 1 _) x) + -> Just + (Exp + (SmartExp + (PrjUnion + UnionIdxLeft + (SmartExp + (PrjUnion + UnionIdxRight + (SmartExp + (Prj + PairIdxLeft + (SmartExp (Prj PairIdxRight x)) + )) + )) + )) :* SOP.Nil) SmartExp Match {} -> Nothing _ -> error "Embedded pattern synonym used outside 'match' context." diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index f5d538978..39d4d9e25 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -514,6 +514,10 @@ data PreSmartExp acc exp t where Union :: Either (exp t1) (exp (SumScalar t2)) -> PreSmartExp acc exp (SumScalar (t1, t2)) + PrjUnion :: UnionIdx (t1, t2) t + -> exp (SumScalar (t1, t2)) + -> PreSmartExp acc exp t + VecPack :: KnownNat n => VecR n s tup -> exp tup From 09447f79030faf13f88d0591d1070b22b72d6861 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Fri, 8 Apr 2022 10:37:47 +0200 Subject: [PATCH 34/67] build TAG --- .../Array/Accelerate/Pattern/Matchable.hs | 140 ++++++++++++++---- src/Data/Array/Accelerate/Type.hs | 4 +- 2 files changed, 118 insertions(+), 26 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index 92f82ecfb..adb0e41d5 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -49,7 +49,7 @@ class Matchable a where ) => Proxy n -> NP Exp (Index (SOPCode a) n) -> Exp a - + build n _ = case sameNat (emptyChoices @a) (Proxy :: Proxy 1) of -- no tag Just Refl -> undefined @@ -139,8 +139,8 @@ type family MapChoices (xs :: [Type]) :: [Nat] where type family Concat (xss :: [[x]]) :: [x] where Concat '[] = '[] Concat (xs ': xss) = xs ++ Concat xss - - + + instance Matchable Bool where type Choices' Bool = 2 @@ -151,9 +151,9 @@ instance Matchable Bool where Just Refl -> case e of SmartExp (Match (TagRtag 0 TagRunit) _x) -> Just SOP.Nil - + SmartExp Match {} -> Nothing - + _ -> error "Embedded pattern synonym used outside 'match' context." Nothing -> case sameNat n (Proxy :: Proxy 1) of @@ -169,7 +169,7 @@ instance Matchable Bool where error "Impossible type encountered" makeTag :: TAG -> SmartExp TAG -makeTag x = undefined -- SmartExp (Const (TupRsingle (tagType x))) +makeTag x = SmartExp (Const (SingleScalarType (NumSingleType (IntegralNumType TypeTAG))) x) tagType :: TupR ScalarType TAG tagType = TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeTAG))) @@ -187,10 +187,9 @@ instance Matchable (Maybe Int) where (SmartExp ( Pair (SmartExp ( - (Const + Const (scalarType @(SumScalar (Undef, (Int, ())))) (PickScalar POS.Undef) - ) )) (SmartExp Smart.Nil) )) @@ -219,14 +218,14 @@ instance Matchable (Maybe Int) where ) Nothing -> error "Impossible type encountered" - match n exp@(Exp e) = case sameNat n (Proxy :: Proxy 0) of + match n (Exp e) = case sameNat n (Proxy :: Proxy 0) of Just Refl -> case e of SmartExp (Match (TagRtag 0 (TagRpair _ TagRunit)) _x) -> Just SOP.Nil - + SmartExp Match {} -> Nothing - + _ -> error "Embedded pattern synonym used outside 'match' context." Nothing -> -- matchJust case sameNat n (Proxy :: Proxy 1) of @@ -234,23 +233,114 @@ instance Matchable (Maybe Int) where case e of SmartExp (Match (TagRtag 1 _) x) -> Just - (Exp - (SmartExp - (PrjUnion - UnionIdxLeft - (SmartExp - (PrjUnion - UnionIdxRight - (SmartExp - (Prj - PairIdxLeft - (SmartExp (Prj PairIdxRight x)) - )) - )) - )) :* SOP.Nil) + (mkExp + (PrjUnion + UnionIdxLeft + (SmartExp + (PrjUnion + UnionIdxRight + (SmartExp + (Prj + PairIdxLeft + (SmartExp (Prj PairIdxRight x)) + )) + )) + ) :* SOP.Nil) SmartExp Match {} -> Nothing _ -> error "Embedded pattern synonym used outside 'match' context." Nothing -> error "Impossible type encountered" + +instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) where + type Choices' (Either a b) = OuterChoices (Either a b) + + build n x + | Refl :: (EltR (Either a b) :~: (TAG, FlattenProduct (Fields (Either a b)))) <- unsafeCoerce Refl -- this should be easily provable, I'm just lazy + = case sameNat n (Proxy :: Proxy 0) of + Just Refl + -> Exp ( + SmartExp ( + Pair + (unExp $ buildTAG x) + _ + ) + ) + where + tag = undefined --foldl 1 (*) (mapChoices x) + test = natVal (Proxy :: Proxy (Choices a)) + -- Nothing -> case sameNat n (Proxy :: Proxy 1) of + -- Just Refl | (Exp x' :* SOP.Nil) <- x -> Exp ( + -- SmartExp ( + -- Pair + -- (makeTag 1) + -- (SmartExp ( + -- Pair + -- (SmartExp + -- (Union + -- (Right ( + -- SmartExp + -- (Union + -- (Left x') + -- ) + -- )) + -- ) + -- ) + -- (SmartExp Smart.Nil) + -- )) + -- ) + -- ) + -- Nothing -> error "Impossible type encountered" + + -- match n (Exp e) = case sameNat n (Proxy :: Proxy 0) of + -- Just Refl -> + -- case e of + -- SmartExp (Match (TagRtag 0 (TagRpair _ TagRunit)) _x) + -- -> Just SOP.Nil + + -- SmartExp Match {} -> Nothing + + -- _ -> error "Embedded pattern synonym used outside 'match' context." + -- Nothing -> -- matchJust + -- case sameNat n (Proxy :: Proxy 1) of + -- Just Refl -> + -- case e of + -- SmartExp (Match (TagRtag 1 _) x) + -- -> Just + -- (Exp + -- (SmartExp + -- (PrjUnion + -- UnionIdxLeft + -- (SmartExp + -- (PrjUnion + -- UnionIdxRight + -- (SmartExp + -- (Prj + -- PairIdxLeft + -- (SmartExp (Prj PairIdxRight x)) + -- )) + -- )) + -- )) :* SOP.Nil) + -- SmartExp Match {} -> Nothing + + -- _ -> error "Embedded pattern synonym used outside 'match' context." + + -- Nothing -> + -- error "Impossible type encountered" + +-- like combineProducts, but lifted to the AST +buildTAG :: (All POSable xs) => NP Exp xs -> Exp TAG +buildTAG SOP.Nil = Exp $ makeTag 0 +buildTAG (x :* xs) = combineProduct x (buildTAG xs) + +-- like Finite.combineProduct, but lifted to the AST +-- basically `tag x + tag y * natVal x` +combineProduct :: forall x. (POSable x) => Exp x -> Exp TAG -> Exp TAG +combineProduct x y = case sameNat (Proxy :: Proxy (Choices x)) (Proxy :: Proxy 1) of + -- untagged type: `tag x = 0`, `natVal x = 1` + Just Refl -> y + -- tagged type + Nothing + | Refl :: (EltR x :~: (TAG, FlattenProduct (Fields x))) <- unsafeCoerce Refl + -> mkAdd (mkExp $ Prj PairIdxLeft (unExp x)) (mkMul y (constant (fromInteger $ natVal (Proxy :: Proxy (Choices x))))) diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index 844332c8c..fa75eee45 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -301,6 +301,7 @@ formatIntegralType = later $ \case TypeWord16 -> "Word16" TypeWord32 -> "Word32" TypeWord64 -> "Word64" + TypeTAG -> "TAG" formatFloatingType :: Format r (FloatingType a -> r) formatFloatingType = later $ \case @@ -373,6 +374,7 @@ integralDict TypeWord8 = IntegralDict integralDict TypeWord16 = IntegralDict integralDict TypeWord32 = IntegralDict integralDict TypeWord64 = IntegralDict +integralDict TypeTAG = IntegralDict floatingDict :: FloatingType a -> FloatingDict a floatingDict TypeHalf = FloatingDict @@ -633,7 +635,7 @@ instance IsScalar (SumScalar ()) where scalarType = SumScalarType ZeroScalarType instance (IsScalar a, IsSumScalar b) => IsScalar (SumScalar (a, b)) where - scalarType = SumScalarType $ SuccScalarType (scalarType @a) (SumScalarType + scalarType = SumScalarType $ SuccScalarType (scalarType @a) (SumScalarType (sumScalarType @b)) class IsSumScalar a where From 6bce2064561020ffdf86f1a8c5743bf627f096bb Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Tue, 12 Apr 2022 15:55:34 +0200 Subject: [PATCH 35/67] split EltR with helper function --- src/Data/Array/Accelerate/Sugar/Elt.hs | 30 +++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index 742c9a519..e77c038ae 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -11,6 +11,7 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE GADTs #-} {-# OPTIONS_HADDOCK hide #-} {-# OPTIONS_GHC -ddump-splices #-} -- | @@ -23,7 +24,7 @@ -- Portability : non-portable (GHC extensions) -- -module Data.Array.Accelerate.Sugar.Elt ( Elt(..) ) +module Data.Array.Accelerate.Sugar.Elt ( Elt(..), eltRType, EltRType(..) ) where import Data.Array.Accelerate.Representation.Type @@ -102,14 +103,36 @@ class Elt a where default toElt :: (POSable a, POStoEltR (Choices a) (Fields a) ~ EltR a) => EltR a -> a toElt = fromEltR +-- function to bring the contraints in scope that are needed to work with EltR, +-- without needing to inspect how POS2EltR works +data EltRType x where + SingletonType :: (EltR x ~ x, Fields x ~ '[ '[x]]) => EltRType x + TaglessType :: (EltR x ~ FlattenProduct (Fields x)) => EltRType x + TaggedType :: (EltR x ~ (TAG, FlattenProduct (Fields x))) => EltRType x + +eltRType :: forall x . POSable x => EltRType x +eltRType = case sameNat (Proxy :: Proxy (Choices x)) (Proxy :: Proxy 1) of + Just Refl -> case emptyFields @x of + PTCons (STSucc _ STZero) PTNil + | Refl :: (EltR x :~: x) <- unsafeCoerce Refl + , Refl :: (Fields x :~: '[ '[x]]) <- unsafeCoerce Refl + -> SingletonType + _ + | Refl :: (EltR x :~: FlattenProduct (Fields x)) <- unsafeCoerce Refl + -> TaglessType + Nothing + | Refl :: (EltR x :~: (TAG, FlattenProduct (Fields x))) <- unsafeCoerce Refl + -> TaggedType + + flattenProductType :: ProductType a -> TypeR (FlattenProduct a) flattenProductType PTNil = TupRunit flattenProductType (PTCons x xs) = TupRpair (TupRsingle (flattenSumType x)) (flattenProductType xs) flattenSumType :: SumType a -> ScalarType (SumScalar (FlattenSum a)) flattenSumType STZero = SumScalarType ZeroScalarType -flattenSumType (STSucc (x :: x) xs) - = SumScalarType (SuccScalarType (mkScalarType x) (flattenSumType xs)) +flattenSumType (STSucc x xs) = case flattenSumType xs of + SumScalarType xs' -> SumScalarType (SuccScalarType (mkScalarType x) xs') -- This is an unsafe conversion, and should be kept strictly in sync with the -- set of types that implement Ground @@ -310,6 +333,7 @@ runQ $ do -- vs <- sequence [ mkVecElt t n | t <- integralTypes ++ floatingTypes, n <- [2,3,4,8,16] ] return (concat ss) +instance Elt Undef -- TODO: bring this back into TH instance (POSable a, POSable b) => Elt (a, b) From a3530442f06dfde01bbc60cac7c6922e06e68478 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Tue, 12 Apr 2022 15:56:18 +0200 Subject: [PATCH 36/67] simplify SumScalarType --- src/Data/Array/Accelerate/Type.hs | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index fa75eee45..7d81ac734 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -106,7 +106,7 @@ type family POStoEltR (cs :: Nat) fs :: Type where POStoEltR 1 x = FlattenProduct x -- tagless types POStoEltR n x = (TAG, FlattenProduct x) -- all other types -type family FlattenProduct (xss :: [[a]]) :: Type where +type family FlattenProduct (xss :: f [a]) = (r :: Type) | r -> f where FlattenProduct '[] = () FlattenProduct (x ': xs) = (SumScalar (FlattenSum x), FlattenProduct xs) @@ -230,15 +230,18 @@ data BoundedType a where data ScalarType a where SingleScalarType :: SingleType a -> ScalarType a VectorScalarType :: VectorType (Vec n a) -> ScalarType (Vec n a) - SumScalarType :: SumScalarType a -> ScalarType a + SumScalarType :: SumScalarType a -> ScalarType (SumScalar a) + +class IsSumScalar a where + sumScalarType :: SumScalarType a data SumScalar x where PickScalar :: a -> SumScalar (a, b) SkipScalar :: SumScalar b -> SumScalar (a, b) data SumScalarType a where - SuccScalarType :: ScalarType a -> ScalarType (SumScalar b) -> SumScalarType (SumScalar (a, b)) - ZeroScalarType :: SumScalarType (SumScalar ()) + SuccScalarType :: ScalarType a -> SumScalarType b -> SumScalarType (a, b) + ZeroScalarType :: SumScalarType () data SingleType a where NumSingleType :: NumType a -> SingleType a @@ -286,7 +289,7 @@ instance Show (ScalarType a) where instance Show (SumScalarType a) where show ZeroScalarType = "" - show (SuccScalarType x (SumScalarType (ZeroScalarType))) = show x + show (SuccScalarType x (ZeroScalarType)) = show x show (SuccScalarType x xs) = show x ++ " + " ++ show xs formatIntegralType :: Format r (IntegralType a -> r) @@ -631,18 +634,11 @@ instance IsSingle Undef where instance IsScalar Undef where scalarType = SingleScalarType singleType -instance IsScalar (SumScalar ()) where - scalarType = SumScalarType ZeroScalarType - -instance (IsScalar a, IsSumScalar b) => IsScalar (SumScalar (a, b)) where - scalarType = SumScalarType $ SuccScalarType (scalarType @a) (SumScalarType - (sumScalarType @b)) - -class IsSumScalar a where - sumScalarType :: SumScalarType (SumScalar a) +instance (IsSumScalar a) => IsScalar (SumScalar a) where + scalarType = SumScalarType (sumScalarType @a) instance IsSumScalar () where sumScalarType = ZeroScalarType instance (IsScalar a, IsSumScalar b) => IsSumScalar (a, b) where - sumScalarType = SuccScalarType (scalarType @a) (SumScalarType (sumScalarType @b)) + sumScalarType = SuccScalarType (scalarType @a) (sumScalarType @b) From 331449165c13095701416277bcde8dcd3d7f7b35 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Tue, 12 Apr 2022 15:56:55 +0200 Subject: [PATCH 37/67] more stuff for Matchable --- .../Array/Accelerate/Pattern/Matchable.hs | 207 +++++++++++++++++- 1 file changed, 198 insertions(+), 9 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index adb0e41d5..6f988322a 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -10,6 +10,9 @@ {-# LANGUAGE PolyKinds #-} {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TypeFamilyDependencies #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE ConstraintKinds #-} module Data.Array.Accelerate.Pattern.Matchable where @@ -77,12 +80,12 @@ buildTag (((Exp x) :: (Exp x)) :* (xs :: xs)) = case sameNat (emptyChoices @x) ( -- flattenProduct Nil = () -- flattenProduct (Cons x xs) = (SumScalarType x, flattenProduct xs) -buildFields :: forall n a . (POSable a, Elt a) => Proxy n -> NP SmartExp (Index (SOPCode a) n) -> SmartExp (FlattenProduct (Fields a)) -buildFields _ a = case emptyFields @a of - PTNil -> case constant () of { Exp se -> se } - PTCons st pt -> case a of - -- SOP.Nil -> SmartExp (Pair (someFunction st) undefined) - (x :* xs) -> SmartExp (Pair undefined undefined) +-- buildFields :: forall n a . (POSable a, Elt a) => Proxy n -> NP SmartExp (Index (SOPCode a) n) -> SmartExp (FlattenProduct (Fields a)) +-- buildFields _ a = case emptyFields @a of +-- PTNil -> case constant () of { Exp se -> se } +-- PTCons st pt -> case a of +-- -- SOP.Nil -> SmartExp (Pair (someFunction st) undefined) +-- (x :* xs) -> SmartExp (Pair undefined undefined) buildFields' :: Proxy n -> ProductType (Fields a) -> NP SmartExp (Index (SOPCode a) n) -> SmartExp (FlattenProduct (Fields a)) buildFields' _ PTNil _ = SmartExp Smart.Nil @@ -140,7 +143,32 @@ type family Concat (xss :: [[x]]) :: [x] where Concat '[] = '[] Concat (xs ': xss) = xs ++ Concat xss - +-- idem +type family MapFields (xs :: [Type]) :: [[[Type]]] where + MapFields '[] = '[] + MapFields (x ': xs) = Fields x ': MapFields xs + +type family MapFlattenProduct (xs :: [[[Type]]]) :: [Type] where + MapFlattenProduct '[] = '[] + MapFlattenProduct (x ': xs) = FlattenProduct x ': MapFlattenProduct xs + +type family ConcatT (xss :: [x]) :: x where + ConcatT '[] = () + ConcatT (x ': xs) = (x, ConcatT xs) + +-- type family RealConcatT (xss :: [[Type]]) :: Type where +-- ConcatT '[] = () +-- ConcatT ('[] ': xs) = RealConcatT xs +-- ConcatT ((x ': xs) ': ys) = (x, RealConcatT xs ys) + +type family ConcatAST (x :: Type) (y :: Type) :: Type where + ConcatAST xs () = xs + ConcatAST () ys = ys + ConcatAST (x, xs) ys = (x, ConcatAST xs ys) + +type family ConcatASTs (xs :: [Type]) :: Type where + ConcatASTs '[] = () + ConcatASTs (x ': xs) = ConcatAST x (ConcatASTs xs) instance Matchable Bool where type Choices' Bool = 2 @@ -253,7 +281,7 @@ instance Matchable (Maybe Int) where Nothing -> error "Impossible type encountered" -instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) where +instance (POSable (Either a b), POSable a, POSable b, Elt a) => Matchable (Either a b) where type Choices' (Either a b) = OuterChoices (Either a b) build n x @@ -264,7 +292,7 @@ instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) SmartExp ( Pair (unExp $ buildTAG x) - _ + undefined --(understandConcatPlease @a @b (mergeLeft @a @b (getSingleElem x))) ) ) where @@ -329,6 +357,167 @@ instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) -- Nothing -> -- error "Impossible type encountered" +-- weirdConvert :: forall x . Elt x => TypeR (EltR x) +-- weirdConvert = eltR @x + +-- getSingleElem :: NP Exp '[a] -> Exp a +-- getSingleElem (x :* SOP.Nil) = x + +-- understandConcatPlease :: SmartExp (FlattenProduct (Merge (Fields a) (Fields b))) -> SmartExp (FlattenProduct (Merge (Fields a ++ '[]) (Fields b ++ '[]))) +-- understandConcatPlease = unsafeCoerce + +-- mergeLeft :: forall a b . (POSable a, Elt a) => Exp a -> SmartExp (FlattenProduct (Merge (Fields a) (Fields b))) +-- mergeLeft (Exp a) = case buildFields1 @a a of +-- a' -> case weirdConvert2 @a (eltR @a) of +-- a3 -> undefined + +-- mergeLeft' :: forall a b . TypeR (EltR a -> ProductType b -> SmartExp (FlattenProduct a) -> SmartExp (Merge' (FlattenProduct a) (FlattenProduct b)) +-- mergeLeft' PTNil PTNil a = a +-- mergeLeft' PTNil (PTCons gb (gbs :: (ProductType (Fields b')))) a = SmartExp $ Pair (fromSumType gb) (mergeLeft' @a @b' PTNil gbs a) + + +-- buildFields :: (All POSable xs) => NP SmartExp xs -> SmartExp (ConcatT (MapFlattenProduct (MapFields xs))) +-- buildFields SOP.Nil = () +-- buildFields (x :* xs) = SmartExp $ Pair +-- where +-- fieldsx = buildFields1 x + +merge :: forall a b . TypeR a -> TypeR b -> SmartExp a -> SmartExp b -> SmartExp (Merge' a b) +merge TupRunit TupRunit a b = unExp $ constant () +merge TupRunit (TupRpair (TupRsingle (SumScalarType x)) gbs) a b + = SmartExp $ Pair + (mergeSumUndefLeft x (SmartExp $ Prj PairIdxLeft b)) + (merge TupRunit gbs a (SmartExp $ Prj PairIdxRight b)) +merge (TupRpair (TupRsingle (SumScalarType x)) gas) TupRunit a b + = SmartExp $ Pair + (mergeSumUndefRight x (SmartExp $ Prj PairIdxLeft a)) + (merge gas TupRunit (SmartExp $ Prj PairIdxRight a) b) +merge (TupRpair (TupRsingle (SumScalarType ga)) gas) (TupRpair (TupRsingle (SumScalarType gb)) gbs) a b + = SmartExp $ Pair + (undefined) -- mergeSum + (merge gas gbs (SmartExp $ Prj PairIdxRight a) (SmartExp $ Prj PairIdxRight b)) + + +mergeSumUndefRight :: SumScalarType x -> SmartExp (SumScalar x) -> SmartExp (SumScalar (Concat' x (Undef, ()))) +mergeSumUndefRight ZeroScalarType a = SmartExp $ Const (SumScalarType (SuccScalarType (SingleScalarType UndefSingleType) ZeroScalarType)) (PickScalar POS.Undef) +mergeSumUndefRight (SuccScalarType x xs) a = SmartExp $ Union (Left (SmartExp $ PrjUnion UnionIdxLeft a)) + +mergeSumUndefLeft :: SumScalarType x -> SmartExp (SumScalar x) -> SmartExp (SumScalar (Undef, x)) +mergeSumUndefLeft ZeroScalarType a = SmartExp $ Const (SumScalarType (SuccScalarType (SingleScalarType UndefSingleType) ZeroScalarType)) (PickScalar POS.Undef) +mergeSumUndefLeft (SuccScalarType x xs) a = SmartExp $ Union (Right a) + +-- mergeSum :: SumScalar + +class AllSumScalar (xs :: Type) where + +instance AllSumScalar () where + +instance (x' ~ SumScalar x, IsSumScalar x', AllSumScalar xs) => AllSumScalar (x, xs) where + +class All' (c :: k -> Constraint) (xs :: Type) where + +instance All' c () where + +instance (c x, All' c xs) => All' c (x, xs) where + + + +-- ZipWith Concat +-- like POSable.Merge, but lifted to tuple lists +type family Merge' (a :: Type) (b :: Type) = (r :: Type) where + Merge' () () = () + Merge' () (SumScalar b, bs) = (SumScalar (Undef, b), Merge' () bs) + Merge' (SumScalar a, as) () = (SumScalar (Concat' a (Undef, ())), Merge' as ()) + Merge' (SumScalar a, as) (SumScalar b, bs) = (SumScalar (Concat' a b), Merge' as bs) + +type family Concat' (a :: Type) (b :: Type) = (r :: Type) where + Concat' () ys = ys + Concat' (x, xs) ys = (x, Concat' xs ys) + +fromSumType :: SumType x -> SmartExp (SumScalar (Undef, FlattenSum x)) +fromSumType x = SmartExp (Union (Left (SmartExp (Const (SingleScalarType UndefSingleType) POS.Undef)))) + +buildFields1 :: forall x . (POSable x) => SmartExp (EltR x) -> SmartExp (FlattenProduct (Fields x)) +buildFields1 x = case eltRType @x of + SingletonType -> SmartExp $ Pair (SmartExp $ Union (Left x)) (SmartExp Smart.Nil) + TaglessType -> x + TaggedType -> SmartExp $ Prj PairIdxRight x + +weirdConvert2 :: forall x . (Elt x, POSable x) => TypeR (EltR x) -> TypeR (FlattenProduct (Fields x)) +weirdConvert2 x = case eltRType @x of + SingletonType -> case x of + TupRsingle x' -> TupRpair (TupRsingle (SumScalarType (SuccScalarType x' ZeroScalarType))) TupRunit + TaglessType -> x + TaggedType -> case x of + TupRpair _ x' -> x' + +-- guidedAppend :: forall x y . TypeR (FlattenProduct (Fields x)) -> SmartExp (FlattenProduct (Fields x)) -> SmartExp (FlattenProduct (Fields y)) -> SmartExp (FlattenProduct (Fields x ++ Fields y)) +-- guidedAppend TupRunit x y | Refl :: (FlattenProduct (Fields y) :~: FlattenProduct (Fields x ++ Fields y)) <- unsafeCoerce Refl = y +-- guidedAppend (TupRsingle g) x y = SmartExp (Pair x y) +-- guidedAppend (TupRpair g1 g2) x y = undefined + +buildFields2 :: forall x y . (POSable x) => SmartExp (EltR x) -> SmartExp (FlattenProduct (Fields y)) -> SmartExp (FlattenProduct (Fields x ++ Fields y)) +buildFields2 x y = case eltRType @x of + SingletonType -> SmartExp $ Pair (SmartExp $ Union (Left x)) y + TaglessType -> buildFields3 @x @y x y + TaggedType -> buildFields3 @x @y (SmartExp $ Prj PairIdxRight x) y + + +buildFields3 :: forall x y . (POSable x) => SmartExp (FlattenProduct (Fields x)) -> SmartExp (FlattenProduct (Fields y)) -> SmartExp (FlattenProduct (Fields x ++ Fields y)) +buildFields3 = buildFields4 @x @y (emptyFields @x) + +buildFields4 :: forall x y . ProductType (Fields x) -> SmartExp (FlattenProduct (Fields x)) -> SmartExp (FlattenProduct (Fields y)) -> SmartExp (FlattenProduct (Fields x ++ Fields y)) +buildFields4 PTNil x y = y +buildFields4 (PTCons g gs) x y = undefined + where + x' :: SmartExp (SumScalar (FlattenSum (Head (Fields x)))) + x' = SmartExp $ Prj PairIdxLeft x + xs' :: SmartExp (FlattenProduct (Tail (Fields x))) + xs' = SmartExp $ Prj PairIdxRight x + f :: SmartExp (FlattenProduct (Fields x ++ Fields y)) + f = SmartExp (Pair x' xy) + xy :: SmartExp (FlattenProduct (Tail (Fields x) ++ Fields y)) + xy = undefined + +type family Head (xs :: [x]) :: x where + Head (x ': xs) = x + +type family Tail (xs :: [x]) :: [x] where + Tail (x ': xs) = xs +-- concatFields :: forall x y. (POSable x) => SmartExp (FlattenProduct (Fields x)) -> SmartExp (FlattenProduct (Fields y)) -> SmartExp (FlattenProduct (Fields x ++ Fields y)) +-- concatFields x y = case emptyFields @x of +-- PTNil +-- -> y +-- (PTCons x' (xs' :: (ProductType xs))) +-- -> SmartExp $ Pair (SmartExp $ Prj PairIdxLeft x) (f @xs @y (SmartExp $ Prj PairIdxRight x) y) +-- where +-- -- rec' :: SmartExp (FlattenProduct ) +-- -- rec' x y = concatFields (SmartExp $ Prj PairIdxRight x) y +-- f :: SmartExp (FlattenProduct ys) -> SmartExp (FlattenProduct (Fields z)) -> SmartExp (FlattenProduct (ys ++ Fields z)) +-- f xs y' = undefined + +-- concatFields' :: forall xs ys . SmartExp (FlattenProduct xs) -> SmartExp (FlattenProduct ys) -> SmartExp (FlattenProduct (xs ++ ys)) +-- concatFields' = undefined + +-- convertASTtoNP :: forall x . POSable x => SmartExp (FlattenProduct (Fields x)) -> NP SmartExp (MapFlattenSum (Fields x)) +-- convertASTtoNP = convertASTtoNP' @(Fields x) (emptyFields @x) + +-- convertASTtoNP' :: ProductType x -> SmartExp (FlattenProduct x) -> NP SmartExp (MapFlattenSum x) +-- convertASTtoNP' PTNil _ = SOP.Nil +-- convertASTtoNP' (PTCons _ xs) y = SmartExp (Prj PairIdxLeft y) :* convertASTtoNP' xs (SmartExp $ Prj PairIdxRight y) + +-- nPtoAST :: NP SmartExp xs -> SmartExp (ConcatASTs xs) +-- nPtoAST SOP.Nil = SmartExp Smart.Nil +-- nPtoAST (x :* xs) = SmartExp $ Pair x (nPtoAST xs) + +-- concatAST :: forall x y . (Elt x) => SmartExp x -> SmartExp y -> SmartExp (ConcatAST x y) +-- concatAST x y | TupRunit <- eltR @x = y +-- concatAST x y | (TupRpair _ _) <- eltR @x = SmartExp $ Pair (SmartExp $ Prj PairIdxLeft x) (concatAST (SmartExp (Prj PairIdxRight x)) y) + +type family MapFlattenSum (x :: [[Type]]) :: [Type] where + MapFlattenSum '[] = '[] + MapFlattenSum (x ': xs) = SumScalar (FlattenSum x) ': MapFlattenSum xs + -- like combineProducts, but lifted to the AST buildTAG :: (All POSable xs) => NP Exp xs -> Exp TAG buildTAG SOP.Nil = Exp $ makeTag 0 From 0f6cf5beb8ef5e599678f17858b08f27dbf44a69 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 14 Apr 2022 14:34:35 +0200 Subject: [PATCH 38/67] simpler union operators --- .../Array/Accelerate/Pattern/Matchable.hs | 80 +++++++++++++------ src/Data/Array/Accelerate/Smart.hs | 13 +-- 2 files changed, 62 insertions(+), 31 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index 6f988322a..68915ea39 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -232,12 +232,10 @@ instance Matchable (Maybe Int) where Pair (SmartExp (Union - (Right ( - SmartExp - (Union - (Left x') - ) - )) + scalarTypeUndefLeft + (SmartExp + (LiftUnion x') + ) ) ) (SmartExp Smart.Nil) @@ -260,20 +258,9 @@ instance Matchable (Maybe Int) where Just Refl -> case e of SmartExp (Match (TagRtag 1 _) x) - -> Just - (mkExp - (PrjUnion - UnionIdxLeft - (SmartExp - (PrjUnion - UnionIdxRight - (SmartExp - (Prj - PairIdxLeft - (SmartExp (Prj PairIdxRight x)) - )) - )) - ) :* SOP.Nil) + -> Just ( + (mkExp $ PrjUnion $ SmartExp $ Union (unConcatSumScalarType (SumScalarType $ SuccScalarType (SingleScalarType UndefSingleType) ZeroScalarType)) (SmartExp $ Prj PairIdxLeft (SmartExp $ Prj PairIdxRight x))) + :* SOP.Nil) SmartExp Match {} -> Nothing _ -> error "Embedded pattern synonym used outside 'match' context." @@ -281,6 +268,10 @@ instance Matchable (Maybe Int) where Nothing -> error "Impossible type encountered" +unConcatSumScalarType :: ScalarType (SumScalar a) -> ScalarType (SumScalar (Concat' a b)) -> ScalarType (SumScalar b) +unConcatSumScalarType (SumScalarType ZeroScalarType) xs = xs +unConcatSumScalarType (SumScalarType (SuccScalarType a as)) (SumScalarType (SuccScalarType x xs)) = unConcatSumScalarType (SumScalarType as) (SumScalarType xs) + instance (POSable (Either a b), POSable a, POSable b, Elt a) => Matchable (Either a b) where type Choices' (Either a b) = OuterChoices (Either a b) @@ -382,6 +373,33 @@ instance (POSable (Either a b), POSable a, POSable b, Elt a) => Matchable (Eithe -- where -- fieldsx = buildFields1 x +-- mergeLeft :: forall a b . TypeR a -> TypeR b -> SmartExp a -> SmartExp (Merge' a b) +-- mergeLeft TupRunit TupRunit a = unExp $ constant () +-- mergeLeft TupRunit (TupRpair (TupRsingle (SumScalarType x)) gbs) a +-- = SmartExp $ Pair +-- (makeUndefLeft x) +-- (mergeLeft TupRunit gbs a) +-- mergeLeft (TupRpair (TupRsingle (SumScalarType x)) gas) TupRunit a +-- = SmartExp $ Pair +-- (mergeSumUndefRight x (SmartExp $ Prj PairIdxLeft a)) +-- (mergeLeft gas TupRunit (SmartExp $ Prj PairIdxRight a)) +-- mergeLeft (TupRpair (TupRsingle (SumScalarType ga)) gas) (TupRpair (TupRsingle (SumScalarType gb)) gbs) a +-- = SmartExp $ Pair +-- (SmartExp $ Union _ (SmartExp $ Prj PairIdxLeft a)) +-- (mergeLeft gas gbs (SmartExp $ Prj PairIdxRight a)) + +makeUndefLeft :: SumScalarType x -> SmartExp (SumScalar (Undef, x)) +makeUndefLeft x = SmartExp $ Const (SumScalarType (SuccScalarType (SingleScalarType UndefSingleType) x)) (PickScalar POS.Undef) + +mergeSumLeft :: forall a b . SumScalarType a -> SumScalarType b -> SmartExp (SumScalar a) -> SmartExp (SumScalar (Concat' a b)) +mergeSumLeft ls rs x = SmartExp $ Union (const $ scalarSumConcat ls rs) x + +scalarSumConcat:: SumScalarType xs -> SumScalarType ys -> ScalarType (SumScalar (Concat' xs ys)) +scalarSumConcat ZeroScalarType rs = SumScalarType rs +scalarSumConcat (SuccScalarType l ls) rs = SumScalarType $ SuccScalarType l ls' + where + SumScalarType ls' = scalarSumConcat ls rs + merge :: forall a b . TypeR a -> TypeR b -> SmartExp a -> SmartExp b -> SmartExp (Merge' a b) merge TupRunit TupRunit a b = unExp $ constant () merge TupRunit (TupRpair (TupRsingle (SumScalarType x)) gbs) a b @@ -400,11 +418,21 @@ merge (TupRpair (TupRsingle (SumScalarType ga)) gas) (TupRpair (TupRsingle (SumS mergeSumUndefRight :: SumScalarType x -> SmartExp (SumScalar x) -> SmartExp (SumScalar (Concat' x (Undef, ()))) mergeSumUndefRight ZeroScalarType a = SmartExp $ Const (SumScalarType (SuccScalarType (SingleScalarType UndefSingleType) ZeroScalarType)) (PickScalar POS.Undef) -mergeSumUndefRight (SuccScalarType x xs) a = SmartExp $ Union (Left (SmartExp $ PrjUnion UnionIdxLeft a)) +mergeSumUndefRight (SuccScalarType x xs) a = SmartExp $ Union scalarTypeUndefRight a mergeSumUndefLeft :: SumScalarType x -> SmartExp (SumScalar x) -> SmartExp (SumScalar (Undef, x)) mergeSumUndefLeft ZeroScalarType a = SmartExp $ Const (SumScalarType (SuccScalarType (SingleScalarType UndefSingleType) ZeroScalarType)) (PickScalar POS.Undef) -mergeSumUndefLeft (SuccScalarType x xs) a = SmartExp $ Union (Right a) +mergeSumUndefLeft (SuccScalarType x xs) a = SmartExp $ Union scalarTypeUndefLeft a + +scalarTypeUndefLeft :: ScalarType (SumScalar a) -> ScalarType (SumScalar (Undef, a)) +scalarTypeUndefLeft (SumScalarType x) = SumScalarType (SuccScalarType (scalarType @Undef) x) + +scalarTypeUndefRight :: ScalarType (SumScalar a) -> ScalarType (SumScalar (Concat' a (Undef, ()))) +scalarTypeUndefRight (SumScalarType ZeroScalarType) = SumScalarType (SuccScalarType (scalarType @Undef) ZeroScalarType) +scalarTypeUndefRight (SumScalarType (SuccScalarType x xs)) + = SumScalarType (SuccScalarType x xs') + where + (SumScalarType xs') = scalarTypeUndefRight (SumScalarType xs) -- mergeSum :: SumScalar @@ -434,12 +462,12 @@ type family Concat' (a :: Type) (b :: Type) = (r :: Type) where Concat' () ys = ys Concat' (x, xs) ys = (x, Concat' xs ys) -fromSumType :: SumType x -> SmartExp (SumScalar (Undef, FlattenSum x)) -fromSumType x = SmartExp (Union (Left (SmartExp (Const (SingleScalarType UndefSingleType) POS.Undef)))) +-- fromSumType :: SumType x -> SmartExp (SumScalar (Undef, FlattenSum x)) +-- fromSumType x = SmartExp $ Union (_) $ SmartExp (LiftUnion (SmartExp (Const (SingleScalarType UndefSingleType) POS.Undef))) buildFields1 :: forall x . (POSable x) => SmartExp (EltR x) -> SmartExp (FlattenProduct (Fields x)) buildFields1 x = case eltRType @x of - SingletonType -> SmartExp $ Pair (SmartExp $ Union (Left x)) (SmartExp Smart.Nil) + SingletonType -> SmartExp $ Pair (SmartExp $ LiftUnion x) (SmartExp Smart.Nil) TaglessType -> x TaggedType -> SmartExp $ Prj PairIdxRight x @@ -458,7 +486,7 @@ weirdConvert2 x = case eltRType @x of buildFields2 :: forall x y . (POSable x) => SmartExp (EltR x) -> SmartExp (FlattenProduct (Fields y)) -> SmartExp (FlattenProduct (Fields x ++ Fields y)) buildFields2 x y = case eltRType @x of - SingletonType -> SmartExp $ Pair (SmartExp $ Union (Left x)) y + SingletonType -> SmartExp $ Pair (SmartExp $ LiftUnion x) y TaglessType -> buildFields3 @x @y x y TaggedType -> buildFields3 @x @y (SmartExp $ Prj PairIdxRight x) y diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 39d4d9e25..aa45fcc35 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -511,12 +511,15 @@ data PreSmartExp acc exp t where -> exp (t1, t2) -> PreSmartExp acc exp t - Union :: Either (exp t1) (exp (SumScalar t2)) - -> PreSmartExp acc exp (SumScalar (t1, t2)) + LiftUnion :: exp t1 + -> PreSmartExp acc exp (SumScalar (t1, ())) - PrjUnion :: UnionIdx (t1, t2) t - -> exp (SumScalar (t1, t2)) - -> PreSmartExp acc exp t + Union :: (ScalarType (SumScalar t1) -> ScalarType (SumScalar t2)) + -> exp (SumScalar t1) + -> PreSmartExp acc exp (SumScalar t2) + + PrjUnion :: exp (SumScalar (t1, ())) + -> PreSmartExp acc exp t1 VecPack :: KnownNat n => VecR n s tup From 26b41c8284898a20750da018c38369ac9bf1a3a1 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Tue, 17 May 2022 11:40:42 +0200 Subject: [PATCH 39/67] new union ast constructors --- .../Array/Accelerate/Pattern/Matchable.hs | 42 ++++++++++--------- src/Data/Array/Accelerate/Smart.hs | 6 +-- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index 68915ea39..b7f95db27 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -178,7 +178,7 @@ instance Matchable Bool where match n (Exp e) = case sameNat n (Proxy :: Proxy 0) of Just Refl -> case e of - SmartExp (Match (TagRtag 0 TagRunit) _x) -> Just SOP.Nil + SmartExp (Match (0,1) _x) -> Just SOP.Nil SmartExp Match {} -> Nothing @@ -187,7 +187,7 @@ instance Matchable Bool where case sameNat n (Proxy :: Proxy 1) of Just Refl -> case e of - SmartExp (Match (TagRtag 1 TagRunit) _x) -> Just SOP.Nil + SmartExp (Match (1,2) _x) -> Just SOP.Nil SmartExp Match {} -> Nothing @@ -247,7 +247,7 @@ instance Matchable (Maybe Int) where match n (Exp e) = case sameNat n (Proxy :: Proxy 0) of Just Refl -> case e of - SmartExp (Match (TagRtag 0 (TagRpair _ TagRunit)) _x) + SmartExp (Match (0,1) _x) -> Just SOP.Nil SmartExp Match {} -> Nothing @@ -257,7 +257,7 @@ instance Matchable (Maybe Int) where case sameNat n (Proxy :: Proxy 1) of Just Refl -> case e of - SmartExp (Match (TagRtag 1 _) x) + SmartExp (Match (1,2) x) -> Just ( (mkExp $ PrjUnion $ SmartExp $ Union (unConcatSumScalarType (SumScalarType $ SuccScalarType (SingleScalarType UndefSingleType) ZeroScalarType)) (SmartExp $ Prj PairIdxLeft (SmartExp $ Prj PairIdxRight x))) :* SOP.Nil) @@ -283,7 +283,7 @@ instance (POSable (Either a b), POSable a, POSable b, Elt a) => Matchable (Eithe SmartExp ( Pair (unExp $ buildTAG x) - undefined --(understandConcatPlease @a @b (mergeLeft @a @b (getSingleElem x))) + (mergeLeft _ _ _) ) ) where @@ -373,20 +373,20 @@ instance (POSable (Either a b), POSable a, POSable b, Elt a) => Matchable (Eithe -- where -- fieldsx = buildFields1 x --- mergeLeft :: forall a b . TypeR a -> TypeR b -> SmartExp a -> SmartExp (Merge' a b) --- mergeLeft TupRunit TupRunit a = unExp $ constant () --- mergeLeft TupRunit (TupRpair (TupRsingle (SumScalarType x)) gbs) a --- = SmartExp $ Pair --- (makeUndefLeft x) --- (mergeLeft TupRunit gbs a) --- mergeLeft (TupRpair (TupRsingle (SumScalarType x)) gas) TupRunit a --- = SmartExp $ Pair --- (mergeSumUndefRight x (SmartExp $ Prj PairIdxLeft a)) --- (mergeLeft gas TupRunit (SmartExp $ Prj PairIdxRight a)) --- mergeLeft (TupRpair (TupRsingle (SumScalarType ga)) gas) (TupRpair (TupRsingle (SumScalarType gb)) gbs) a --- = SmartExp $ Pair --- (SmartExp $ Union _ (SmartExp $ Prj PairIdxLeft a)) --- (mergeLeft gas gbs (SmartExp $ Prj PairIdxRight a)) +mergeLeft :: forall a b . TypeR a -> TypeR b -> SmartExp a -> SmartExp (Merge' a b) +mergeLeft TupRunit TupRunit a = unExp $ constant () +mergeLeft TupRunit (TupRpair (TupRsingle (SumScalarType x)) gbs) a + = SmartExp $ Pair + (makeUndefLeft x) + (mergeLeft TupRunit gbs a) +mergeLeft (TupRpair (TupRsingle (SumScalarType x)) gas) TupRunit a + = SmartExp $ Pair + (mergeSumUndefRight x (SmartExp $ Prj PairIdxLeft a)) + (mergeLeft gas TupRunit (SmartExp $ Prj PairIdxRight a)) +mergeLeft (TupRpair (TupRsingle (SumScalarType (ga :: (SumScalarType ga)))) gas) (TupRpair (TupRsingle (SumScalarType (gb :: (SumScalarType gb)))) gbs) a + = SmartExp $ Pair + (SmartExp $ Union (\y -> scalarSumConcat' @ga @gb y (SumScalarType gb)) (SmartExp $ Prj PairIdxLeft a)) -- (scalarSumConcat' @ga @gb (SumScalarType ga)) + (mergeLeft gas gbs (SmartExp $ Prj PairIdxRight a)) makeUndefLeft :: SumScalarType x -> SmartExp (SumScalar (Undef, x)) makeUndefLeft x = SmartExp $ Const (SumScalarType (SuccScalarType (SingleScalarType UndefSingleType) x)) (PickScalar POS.Undef) @@ -394,6 +394,10 @@ makeUndefLeft x = SmartExp $ Const (SumScalarType (SuccScalarType (SingleScalarT mergeSumLeft :: forall a b . SumScalarType a -> SumScalarType b -> SmartExp (SumScalar a) -> SmartExp (SumScalar (Concat' a b)) mergeSumLeft ls rs x = SmartExp $ Union (const $ scalarSumConcat ls rs) x + +scalarSumConcat':: ScalarType (SumScalar xs) -> ScalarType (SumScalar ys) -> ScalarType (SumScalar (Concat' xs ys)) +scalarSumConcat' (SumScalarType ls) (SumScalarType rs) = scalarSumConcat ls rs + scalarSumConcat:: SumScalarType xs -> SumScalarType ys -> ScalarType (SumScalar (Concat' xs ys)) scalarSumConcat ZeroScalarType rs = SumScalarType rs scalarSumConcat (SuccScalarType l ls) rs = SumScalarType $ SuccScalarType l ls' diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index aa45fcc35..83f35b4fc 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -492,7 +492,7 @@ data PreSmartExp acc exp t where -> PreSmartExp acc exp t -- Needed for embedded pattern matching - Match :: TagR t + Match :: (TAG, TAG) -- inclusive tag lower bound inclusive, exclusive tag upper bound -> exp t -> PreSmartExp acc exp t @@ -542,7 +542,7 @@ data PreSmartExp acc exp t where -> PreSmartExp acc exp sh Case :: exp a - -> [(TagR a, exp b)] + -> [(TAG, TAG, exp b)] -> PreSmartExp acc exp b Cond :: exp PrimBool @@ -866,7 +866,7 @@ instance HasTypeR exp => HasTypeR (PreSmartExp acc exp) where VecUnpack vecR _ -> vecRtuple vecR ToIndex _ _ _ -> TupRsingle scalarTypeInt FromIndex shr _ _ -> shapeType shr - Case _ ((_,c):_) -> typeR c + Case _ ((_,_,c):_) -> typeR c Case{} -> internalError "encountered empty case" Cond _ e _ -> typeR e While t _ _ _ -> t From 5caac306585cfdbdb3a2ae0a55e6fe62ecb5e16c Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Tue, 17 May 2022 12:32:54 +0200 Subject: [PATCH 40/67] only allow singleTypes in sums --- .../Array/Accelerate/Pattern/Matchable.hs | 26 +++++----- src/Data/Array/Accelerate/Sugar/Elt.hs | 49 ++++++++++++++++++- src/Data/Array/Accelerate/Type.hs | 6 +-- 3 files changed, 64 insertions(+), 17 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index b7f95db27..482bd74b7 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -259,7 +259,7 @@ instance Matchable (Maybe Int) where case e of SmartExp (Match (1,2) x) -> Just ( - (mkExp $ PrjUnion $ SmartExp $ Union (unConcatSumScalarType (SumScalarType $ SuccScalarType (SingleScalarType UndefSingleType) ZeroScalarType)) (SmartExp $ Prj PairIdxLeft (SmartExp $ Prj PairIdxRight x))) + (mkExp $ PrjUnion $ SmartExp $ Union (unConcatSumScalarType (SumScalarType $ SuccScalarType (UndefSingleType) ZeroScalarType)) (SmartExp $ Prj PairIdxLeft (SmartExp $ Prj PairIdxRight x))) :* SOP.Nil) SmartExp Match {} -> Nothing @@ -389,7 +389,7 @@ mergeLeft (TupRpair (TupRsingle (SumScalarType (ga :: (SumScalarType ga)))) gas) (mergeLeft gas gbs (SmartExp $ Prj PairIdxRight a)) makeUndefLeft :: SumScalarType x -> SmartExp (SumScalar (Undef, x)) -makeUndefLeft x = SmartExp $ Const (SumScalarType (SuccScalarType (SingleScalarType UndefSingleType) x)) (PickScalar POS.Undef) +makeUndefLeft x = SmartExp $ Const (SumScalarType (SuccScalarType (UndefSingleType) x)) (PickScalar POS.Undef) mergeSumLeft :: forall a b . SumScalarType a -> SumScalarType b -> SmartExp (SumScalar a) -> SmartExp (SumScalar (Concat' a b)) mergeSumLeft ls rs x = SmartExp $ Union (const $ scalarSumConcat ls rs) x @@ -421,18 +421,18 @@ merge (TupRpair (TupRsingle (SumScalarType ga)) gas) (TupRpair (TupRsingle (SumS mergeSumUndefRight :: SumScalarType x -> SmartExp (SumScalar x) -> SmartExp (SumScalar (Concat' x (Undef, ()))) -mergeSumUndefRight ZeroScalarType a = SmartExp $ Const (SumScalarType (SuccScalarType (SingleScalarType UndefSingleType) ZeroScalarType)) (PickScalar POS.Undef) +mergeSumUndefRight ZeroScalarType a = SmartExp $ Const (SumScalarType (SuccScalarType (UndefSingleType) ZeroScalarType)) (PickScalar POS.Undef) mergeSumUndefRight (SuccScalarType x xs) a = SmartExp $ Union scalarTypeUndefRight a mergeSumUndefLeft :: SumScalarType x -> SmartExp (SumScalar x) -> SmartExp (SumScalar (Undef, x)) -mergeSumUndefLeft ZeroScalarType a = SmartExp $ Const (SumScalarType (SuccScalarType (SingleScalarType UndefSingleType) ZeroScalarType)) (PickScalar POS.Undef) +mergeSumUndefLeft ZeroScalarType a = SmartExp $ Const (SumScalarType (SuccScalarType (UndefSingleType) ZeroScalarType)) (PickScalar POS.Undef) mergeSumUndefLeft (SuccScalarType x xs) a = SmartExp $ Union scalarTypeUndefLeft a scalarTypeUndefLeft :: ScalarType (SumScalar a) -> ScalarType (SumScalar (Undef, a)) -scalarTypeUndefLeft (SumScalarType x) = SumScalarType (SuccScalarType (scalarType @Undef) x) +scalarTypeUndefLeft (SumScalarType x) = SumScalarType (SuccScalarType (singleType @Undef) x) scalarTypeUndefRight :: ScalarType (SumScalar a) -> ScalarType (SumScalar (Concat' a (Undef, ()))) -scalarTypeUndefRight (SumScalarType ZeroScalarType) = SumScalarType (SuccScalarType (scalarType @Undef) ZeroScalarType) +scalarTypeUndefRight (SumScalarType ZeroScalarType) = SumScalarType (SuccScalarType (singleType @Undef) ZeroScalarType) scalarTypeUndefRight (SumScalarType (SuccScalarType x xs)) = SumScalarType (SuccScalarType x xs') where @@ -475,13 +475,13 @@ buildFields1 x = case eltRType @x of TaglessType -> x TaggedType -> SmartExp $ Prj PairIdxRight x -weirdConvert2 :: forall x . (Elt x, POSable x) => TypeR (EltR x) -> TypeR (FlattenProduct (Fields x)) -weirdConvert2 x = case eltRType @x of - SingletonType -> case x of - TupRsingle x' -> TupRpair (TupRsingle (SumScalarType (SuccScalarType x' ZeroScalarType))) TupRunit - TaglessType -> x - TaggedType -> case x of - TupRpair _ x' -> x' +-- weirdConvert2 :: forall x . (Elt x, POSable x) => TypeR (EltR x) -> TypeR (FlattenProduct (Fields x)) +-- weirdConvert2 x = case eltRType @x of +-- SingletonType -> case x of +-- TupRsingle x' -> TupRpair (TupRsingle (SumScalarType (SuccScalarType x' ZeroScalarType))) TupRunit +-- TaglessType -> x +-- TaggedType -> case x of +-- TupRpair _ x' -> x' -- guidedAppend :: forall x y . TypeR (FlattenProduct (Fields x)) -> SmartExp (FlattenProduct (Fields x)) -> SmartExp (FlattenProduct (Fields y)) -> SmartExp (FlattenProduct (Fields x ++ Fields y)) -- guidedAppend TupRunit x y | Refl :: (FlattenProduct (Fields y) :~: FlattenProduct (Fields x ++ Fields y)) <- unsafeCoerce Refl = y diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index e77c038ae..278ae3466 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -132,7 +132,7 @@ flattenProductType (PTCons x xs) = TupRpair (TupRsingle (flattenSumType x)) (fla flattenSumType :: SumType a -> ScalarType (SumScalar (FlattenSum a)) flattenSumType STZero = SumScalarType ZeroScalarType flattenSumType (STSucc x xs) = case flattenSumType xs of - SumScalarType xs' -> SumScalarType (SuccScalarType (mkScalarType x) xs') + SumScalarType xs' -> SumScalarType (SuccScalarType (mkSingleType x) xs') -- This is an unsafe conversion, and should be kept strictly in sync with the -- set of types that implement Ground @@ -181,6 +181,53 @@ mkScalarType _ = scalarType @a +-- This is an unsafe conversion, and should be kept strictly in sync with the +-- set of types that implement Ground +mkSingleType :: forall a . (Typeable a, Ground a) => a -> SingleType a +mkSingleType _ + | Just Refl <- eqT @a @Int + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Int8 + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Int16 + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Int32 + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Int64 + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Word + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Word8 + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Word16 + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Word32 + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Word64 + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Half + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Float + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Double + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Undef + = singleType @a + + mkEltRT :: forall a . (POSable a) => TypeR (POStoEltR (Choices a) (Fields a)) mkEltRT = case sameNat cs (Proxy :: Proxy 1) of -- This distinction is hard to express in a type-correct way, diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index 7d81ac734..e7d492926 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -240,7 +240,7 @@ data SumScalar x where SkipScalar :: SumScalar b -> SumScalar (a, b) data SumScalarType a where - SuccScalarType :: ScalarType a -> SumScalarType b -> SumScalarType (a, b) + SuccScalarType :: SingleType a -> SumScalarType b -> SumScalarType (a, b) ZeroScalarType :: SumScalarType () data SingleType a where @@ -640,5 +640,5 @@ instance (IsSumScalar a) => IsScalar (SumScalar a) where instance IsSumScalar () where sumScalarType = ZeroScalarType -instance (IsScalar a, IsSumScalar b) => IsSumScalar (a, b) where - sumScalarType = SuccScalarType (scalarType @a) (sumScalarType @b) +instance (IsSingle a, IsSumScalar b) => IsSumScalar (a, b) where + sumScalarType = SuccScalarType (singleType @a) (sumScalarType @b) From 2b9a753e3ceb59bc88bbd1ea41206673d285fdcb Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Tue, 17 May 2022 12:36:20 +0200 Subject: [PATCH 41/67] rename sumscalar to unionscalar --- src/Data/Array/Accelerate/AST/Idx.hs | 9 +- .../Array/Accelerate/Pattern/Matchable.hs | 88 +++++++++---------- src/Data/Array/Accelerate/Smart.hs | 10 +-- src/Data/Array/Accelerate/Sugar/Elt.hs | 6 +- src/Data/Array/Accelerate/Type.hs | 40 ++++----- 5 files changed, 73 insertions(+), 80 deletions(-) diff --git a/src/Data/Array/Accelerate/AST/Idx.hs b/src/Data/Array/Accelerate/AST/Idx.hs index 00ff8eee3..acb5212fe 100644 --- a/src/Data/Array/Accelerate/AST/Idx.hs +++ b/src/Data/Array/Accelerate/AST/Idx.hs @@ -27,8 +27,7 @@ module Data.Array.Accelerate.AST.Idx ( idxToInt, rnfIdx, liftIdx, - PairIdx(..), - UnionIdx(..) + PairIdx(..) ) where @@ -113,9 +112,3 @@ pattern VoidIdx a <- (\case{} -> a) data PairIdx p a where PairIdxLeft :: PairIdx (a, b) a PairIdxRight :: PairIdx (a, b) b - -data UnionIdx p a where - UnionIdxLeft :: UnionIdx (a, b) a - UnionIdxRight :: UnionIdx (a, b) (SumScalar b) - - \ No newline at end of file diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index 482bd74b7..47a80230e 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -78,7 +78,7 @@ buildTag (((Exp x) :: (Exp x)) :* (xs :: xs)) = case sameNat (emptyChoices @x) ( -- flattenProduct :: Product a -> FlattenProduct a -- flattenProduct Nil = () --- flattenProduct (Cons x xs) = (SumScalarType x, flattenProduct xs) +-- flattenProduct (Cons x xs) = (UnionScalarType x, flattenProduct xs) -- buildFields :: forall n a . (POSable a, Elt a) => Proxy n -> NP SmartExp (Index (SOPCode a) n) -> SmartExp (FlattenProduct (Fields a)) -- buildFields _ a = case emptyFields @a of @@ -216,7 +216,7 @@ instance Matchable (Maybe Int) where Pair (SmartExp ( Const - (scalarType @(SumScalar (Undef, (Int, ())))) + (scalarType @(UnionScalar (Undef, (Int, ())))) (PickScalar POS.Undef) )) (SmartExp Smart.Nil) @@ -259,7 +259,7 @@ instance Matchable (Maybe Int) where case e of SmartExp (Match (1,2) x) -> Just ( - (mkExp $ PrjUnion $ SmartExp $ Union (unConcatSumScalarType (SumScalarType $ SuccScalarType (UndefSingleType) ZeroScalarType)) (SmartExp $ Prj PairIdxLeft (SmartExp $ Prj PairIdxRight x))) + (mkExp $ PrjUnion $ SmartExp $ Union (unConcatSumScalarType (UnionScalarType $ SuccScalarType (UndefSingleType) ZeroScalarType)) (SmartExp $ Prj PairIdxLeft (SmartExp $ Prj PairIdxRight x))) :* SOP.Nil) SmartExp Match {} -> Nothing @@ -268,9 +268,9 @@ instance Matchable (Maybe Int) where Nothing -> error "Impossible type encountered" -unConcatSumScalarType :: ScalarType (SumScalar a) -> ScalarType (SumScalar (Concat' a b)) -> ScalarType (SumScalar b) -unConcatSumScalarType (SumScalarType ZeroScalarType) xs = xs -unConcatSumScalarType (SumScalarType (SuccScalarType a as)) (SumScalarType (SuccScalarType x xs)) = unConcatSumScalarType (SumScalarType as) (SumScalarType xs) +unConcatSumScalarType :: ScalarType (UnionScalar a) -> ScalarType (UnionScalar (Concat' a b)) -> ScalarType (UnionScalar b) +unConcatSumScalarType (UnionScalarType ZeroScalarType) xs = xs +unConcatSumScalarType (UnionScalarType (SuccScalarType a as)) (UnionScalarType (SuccScalarType x xs)) = unConcatSumScalarType (UnionScalarType as) (UnionScalarType xs) instance (POSable (Either a b), POSable a, POSable b, Elt a) => Matchable (Either a b) where type Choices' (Either a b) = OuterChoices (Either a b) @@ -375,76 +375,76 @@ instance (POSable (Either a b), POSable a, POSable b, Elt a) => Matchable (Eithe mergeLeft :: forall a b . TypeR a -> TypeR b -> SmartExp a -> SmartExp (Merge' a b) mergeLeft TupRunit TupRunit a = unExp $ constant () -mergeLeft TupRunit (TupRpair (TupRsingle (SumScalarType x)) gbs) a +mergeLeft TupRunit (TupRpair (TupRsingle (UnionScalarType x)) gbs) a = SmartExp $ Pair (makeUndefLeft x) (mergeLeft TupRunit gbs a) -mergeLeft (TupRpair (TupRsingle (SumScalarType x)) gas) TupRunit a +mergeLeft (TupRpair (TupRsingle (UnionScalarType x)) gas) TupRunit a = SmartExp $ Pair (mergeSumUndefRight x (SmartExp $ Prj PairIdxLeft a)) (mergeLeft gas TupRunit (SmartExp $ Prj PairIdxRight a)) -mergeLeft (TupRpair (TupRsingle (SumScalarType (ga :: (SumScalarType ga)))) gas) (TupRpair (TupRsingle (SumScalarType (gb :: (SumScalarType gb)))) gbs) a +mergeLeft (TupRpair (TupRsingle (UnionScalarType (ga :: (UnionScalarType ga)))) gas) (TupRpair (TupRsingle (UnionScalarType (gb :: (UnionScalarType gb)))) gbs) a = SmartExp $ Pair - (SmartExp $ Union (\y -> scalarSumConcat' @ga @gb y (SumScalarType gb)) (SmartExp $ Prj PairIdxLeft a)) -- (scalarSumConcat' @ga @gb (SumScalarType ga)) + (SmartExp $ Union (\y -> scalarSumConcat' @ga @gb y (UnionScalarType gb)) (SmartExp $ Prj PairIdxLeft a)) -- (scalarSumConcat' @ga @gb (UnionScalarType ga)) (mergeLeft gas gbs (SmartExp $ Prj PairIdxRight a)) -makeUndefLeft :: SumScalarType x -> SmartExp (SumScalar (Undef, x)) -makeUndefLeft x = SmartExp $ Const (SumScalarType (SuccScalarType (UndefSingleType) x)) (PickScalar POS.Undef) +makeUndefLeft :: UnionScalarType x -> SmartExp (UnionScalar (Undef, x)) +makeUndefLeft x = SmartExp $ Const (UnionScalarType (SuccScalarType (UndefSingleType) x)) (PickScalar POS.Undef) -mergeSumLeft :: forall a b . SumScalarType a -> SumScalarType b -> SmartExp (SumScalar a) -> SmartExp (SumScalar (Concat' a b)) +mergeSumLeft :: forall a b . UnionScalarType a -> UnionScalarType b -> SmartExp (UnionScalar a) -> SmartExp (UnionScalar (Concat' a b)) mergeSumLeft ls rs x = SmartExp $ Union (const $ scalarSumConcat ls rs) x -scalarSumConcat':: ScalarType (SumScalar xs) -> ScalarType (SumScalar ys) -> ScalarType (SumScalar (Concat' xs ys)) -scalarSumConcat' (SumScalarType ls) (SumScalarType rs) = scalarSumConcat ls rs +scalarSumConcat':: ScalarType (UnionScalar xs) -> ScalarType (UnionScalar ys) -> ScalarType (UnionScalar (Concat' xs ys)) +scalarSumConcat' (UnionScalarType ls) (UnionScalarType rs) = scalarSumConcat ls rs -scalarSumConcat:: SumScalarType xs -> SumScalarType ys -> ScalarType (SumScalar (Concat' xs ys)) -scalarSumConcat ZeroScalarType rs = SumScalarType rs -scalarSumConcat (SuccScalarType l ls) rs = SumScalarType $ SuccScalarType l ls' +scalarSumConcat:: UnionScalarType xs -> UnionScalarType ys -> ScalarType (UnionScalar (Concat' xs ys)) +scalarSumConcat ZeroScalarType rs = UnionScalarType rs +scalarSumConcat (SuccScalarType l ls) rs = UnionScalarType $ SuccScalarType l ls' where - SumScalarType ls' = scalarSumConcat ls rs + UnionScalarType ls' = scalarSumConcat ls rs merge :: forall a b . TypeR a -> TypeR b -> SmartExp a -> SmartExp b -> SmartExp (Merge' a b) merge TupRunit TupRunit a b = unExp $ constant () -merge TupRunit (TupRpair (TupRsingle (SumScalarType x)) gbs) a b +merge TupRunit (TupRpair (TupRsingle (UnionScalarType x)) gbs) a b = SmartExp $ Pair (mergeSumUndefLeft x (SmartExp $ Prj PairIdxLeft b)) (merge TupRunit gbs a (SmartExp $ Prj PairIdxRight b)) -merge (TupRpair (TupRsingle (SumScalarType x)) gas) TupRunit a b +merge (TupRpair (TupRsingle (UnionScalarType x)) gas) TupRunit a b = SmartExp $ Pair (mergeSumUndefRight x (SmartExp $ Prj PairIdxLeft a)) (merge gas TupRunit (SmartExp $ Prj PairIdxRight a) b) -merge (TupRpair (TupRsingle (SumScalarType ga)) gas) (TupRpair (TupRsingle (SumScalarType gb)) gbs) a b +merge (TupRpair (TupRsingle (UnionScalarType ga)) gas) (TupRpair (TupRsingle (UnionScalarType gb)) gbs) a b = SmartExp $ Pair (undefined) -- mergeSum (merge gas gbs (SmartExp $ Prj PairIdxRight a) (SmartExp $ Prj PairIdxRight b)) -mergeSumUndefRight :: SumScalarType x -> SmartExp (SumScalar x) -> SmartExp (SumScalar (Concat' x (Undef, ()))) -mergeSumUndefRight ZeroScalarType a = SmartExp $ Const (SumScalarType (SuccScalarType (UndefSingleType) ZeroScalarType)) (PickScalar POS.Undef) +mergeSumUndefRight :: UnionScalarType x -> SmartExp (UnionScalar x) -> SmartExp (UnionScalar (Concat' x (Undef, ()))) +mergeSumUndefRight ZeroScalarType a = SmartExp $ Const (UnionScalarType (SuccScalarType (UndefSingleType) ZeroScalarType)) (PickScalar POS.Undef) mergeSumUndefRight (SuccScalarType x xs) a = SmartExp $ Union scalarTypeUndefRight a -mergeSumUndefLeft :: SumScalarType x -> SmartExp (SumScalar x) -> SmartExp (SumScalar (Undef, x)) -mergeSumUndefLeft ZeroScalarType a = SmartExp $ Const (SumScalarType (SuccScalarType (UndefSingleType) ZeroScalarType)) (PickScalar POS.Undef) +mergeSumUndefLeft :: UnionScalarType x -> SmartExp (UnionScalar x) -> SmartExp (UnionScalar (Undef, x)) +mergeSumUndefLeft ZeroScalarType a = SmartExp $ Const (UnionScalarType (SuccScalarType (UndefSingleType) ZeroScalarType)) (PickScalar POS.Undef) mergeSumUndefLeft (SuccScalarType x xs) a = SmartExp $ Union scalarTypeUndefLeft a -scalarTypeUndefLeft :: ScalarType (SumScalar a) -> ScalarType (SumScalar (Undef, a)) -scalarTypeUndefLeft (SumScalarType x) = SumScalarType (SuccScalarType (singleType @Undef) x) +scalarTypeUndefLeft :: ScalarType (UnionScalar a) -> ScalarType (UnionScalar (Undef, a)) +scalarTypeUndefLeft (UnionScalarType x) = UnionScalarType (SuccScalarType (singleType @Undef) x) -scalarTypeUndefRight :: ScalarType (SumScalar a) -> ScalarType (SumScalar (Concat' a (Undef, ()))) -scalarTypeUndefRight (SumScalarType ZeroScalarType) = SumScalarType (SuccScalarType (singleType @Undef) ZeroScalarType) -scalarTypeUndefRight (SumScalarType (SuccScalarType x xs)) - = SumScalarType (SuccScalarType x xs') +scalarTypeUndefRight :: ScalarType (UnionScalar a) -> ScalarType (UnionScalar (Concat' a (Undef, ()))) +scalarTypeUndefRight (UnionScalarType ZeroScalarType) = UnionScalarType (SuccScalarType (singleType @Undef) ZeroScalarType) +scalarTypeUndefRight (UnionScalarType (SuccScalarType x xs)) + = UnionScalarType (SuccScalarType x xs') where - (SumScalarType xs') = scalarTypeUndefRight (SumScalarType xs) + (UnionScalarType xs') = scalarTypeUndefRight (UnionScalarType xs) --- mergeSum :: SumScalar +-- mergeSum :: UnionScalar -class AllSumScalar (xs :: Type) where +class AllUnionScalar (xs :: Type) where -instance AllSumScalar () where +instance AllUnionScalar () where -instance (x' ~ SumScalar x, IsSumScalar x', AllSumScalar xs) => AllSumScalar (x, xs) where +instance (x' ~ UnionScalar x, IsUnionScalar x', AllUnionScalar xs) => AllUnionScalar (x, xs) where class All' (c :: k -> Constraint) (xs :: Type) where @@ -458,15 +458,15 @@ instance (c x, All' c xs) => All' c (x, xs) where -- like POSable.Merge, but lifted to tuple lists type family Merge' (a :: Type) (b :: Type) = (r :: Type) where Merge' () () = () - Merge' () (SumScalar b, bs) = (SumScalar (Undef, b), Merge' () bs) - Merge' (SumScalar a, as) () = (SumScalar (Concat' a (Undef, ())), Merge' as ()) - Merge' (SumScalar a, as) (SumScalar b, bs) = (SumScalar (Concat' a b), Merge' as bs) + Merge' () (UnionScalar b, bs) = (UnionScalar (Undef, b), Merge' () bs) + Merge' (UnionScalar a, as) () = (UnionScalar (Concat' a (Undef, ())), Merge' as ()) + Merge' (UnionScalar a, as) (UnionScalar b, bs) = (UnionScalar (Concat' a b), Merge' as bs) type family Concat' (a :: Type) (b :: Type) = (r :: Type) where Concat' () ys = ys Concat' (x, xs) ys = (x, Concat' xs ys) --- fromSumType :: SumType x -> SmartExp (SumScalar (Undef, FlattenSum x)) +-- fromSumType :: SumType x -> SmartExp (UnionScalar (Undef, FlattenSum x)) -- fromSumType x = SmartExp $ Union (_) $ SmartExp (LiftUnion (SmartExp (Const (SingleScalarType UndefSingleType) POS.Undef))) buildFields1 :: forall x . (POSable x) => SmartExp (EltR x) -> SmartExp (FlattenProduct (Fields x)) @@ -478,7 +478,7 @@ buildFields1 x = case eltRType @x of -- weirdConvert2 :: forall x . (Elt x, POSable x) => TypeR (EltR x) -> TypeR (FlattenProduct (Fields x)) -- weirdConvert2 x = case eltRType @x of -- SingletonType -> case x of --- TupRsingle x' -> TupRpair (TupRsingle (SumScalarType (SuccScalarType x' ZeroScalarType))) TupRunit +-- TupRsingle x' -> TupRpair (TupRsingle (UnionScalarType (SuccScalarType x' ZeroScalarType))) TupRunit -- TaglessType -> x -- TaggedType -> case x of -- TupRpair _ x' -> x' @@ -502,7 +502,7 @@ buildFields4 :: forall x y . ProductType (Fields x) -> SmartExp (FlattenProduct buildFields4 PTNil x y = y buildFields4 (PTCons g gs) x y = undefined where - x' :: SmartExp (SumScalar (FlattenSum (Head (Fields x)))) + x' :: SmartExp (UnionScalar (FlattenSum (Head (Fields x)))) x' = SmartExp $ Prj PairIdxLeft x xs' :: SmartExp (FlattenProduct (Tail (Fields x))) xs' = SmartExp $ Prj PairIdxRight x @@ -548,7 +548,7 @@ type family Tail (xs :: [x]) :: [x] where type family MapFlattenSum (x :: [[Type]]) :: [Type] where MapFlattenSum '[] = '[] - MapFlattenSum (x ': xs) = SumScalar (FlattenSum x) ': MapFlattenSum xs + MapFlattenSum (x ': xs) = UnionScalar (FlattenSum x) ': MapFlattenSum xs -- like combineProducts, but lifted to the AST buildTAG :: (All POSable xs) => NP Exp xs -> Exp TAG diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 83f35b4fc..af3416355 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -512,13 +512,13 @@ data PreSmartExp acc exp t where -> PreSmartExp acc exp t LiftUnion :: exp t1 - -> PreSmartExp acc exp (SumScalar (t1, ())) + -> PreSmartExp acc exp (UnionScalar (t1, ())) - Union :: (ScalarType (SumScalar t1) -> ScalarType (SumScalar t2)) - -> exp (SumScalar t1) - -> PreSmartExp acc exp (SumScalar t2) + Union :: (ScalarType (UnionScalar t1) -> ScalarType (UnionScalar t2)) + -> exp (UnionScalar t1) + -> PreSmartExp acc exp (UnionScalar t2) - PrjUnion :: exp (SumScalar (t1, ())) + PrjUnion :: exp (UnionScalar (t1, ())) -> PreSmartExp acc exp t1 VecPack :: KnownNat n diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index 278ae3466..a3c2f999a 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -129,10 +129,10 @@ flattenProductType :: ProductType a -> TypeR (FlattenProduct a) flattenProductType PTNil = TupRunit flattenProductType (PTCons x xs) = TupRpair (TupRsingle (flattenSumType x)) (flattenProductType xs) -flattenSumType :: SumType a -> ScalarType (SumScalar (FlattenSum a)) -flattenSumType STZero = SumScalarType ZeroScalarType +flattenSumType :: SumType a -> ScalarType (UnionScalar (FlattenSum a)) +flattenSumType STZero = UnionScalarType ZeroScalarType flattenSumType (STSucc x xs) = case flattenSumType xs of - SumScalarType xs' -> SumScalarType (SuccScalarType (mkSingleType x) xs') + UnionScalarType xs' -> UnionScalarType (SuccScalarType (mkSingleType x) xs') -- This is an unsafe conversion, and should be kept strictly in sync with the -- set of types that implement Ground diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index e7d492926..8746d8c3a 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -108,7 +108,7 @@ type family POStoEltR (cs :: Nat) fs :: Type where type family FlattenProduct (xss :: f [a]) = (r :: Type) | r -> f where FlattenProduct '[] = () - FlattenProduct (x ': xs) = (SumScalar (FlattenSum x), FlattenProduct xs) + FlattenProduct (x ': xs) = (UnionScalar (FlattenSum x), FlattenProduct xs) type family FlattenSum (xs :: [a]) :: Type where FlattenSum '[] = () @@ -116,7 +116,7 @@ type family FlattenSum (xs :: [a]) :: Type where type family FlattenProductType (xss :: [[a]]) :: Type where FlattenProductType '[] = () - FlattenProductType (x ': xs) = (SumScalarType (FlattenSumType x), FlattenProductType xs) + FlattenProductType (x ': xs) = (UnionScalarType (FlattenSumType x), FlattenProductType xs) type family FlattenSumType (xs :: [a]) :: Type where FlattenSumType '[] = () @@ -127,7 +127,7 @@ flattenProduct :: Product a -> FlattenProduct a flattenProduct Nil = () flattenProduct (Cons x xs) = (flattenSum x, flattenProduct xs) -flattenSum :: Sum a -> SumScalar (FlattenSum a) +flattenSum :: Sum a -> UnionScalar (FlattenSum a) flattenSum (Pick x) = PickScalar x flattenSum (Skip xs) = SkipScalar (flattenSum xs) @@ -230,18 +230,18 @@ data BoundedType a where data ScalarType a where SingleScalarType :: SingleType a -> ScalarType a VectorScalarType :: VectorType (Vec n a) -> ScalarType (Vec n a) - SumScalarType :: SumScalarType a -> ScalarType (SumScalar a) + UnionScalarType :: UnionScalarType a -> ScalarType (UnionScalar a) -class IsSumScalar a where - sumScalarType :: SumScalarType a +class IsUnionScalar a where + unionScalarType :: UnionScalarType a -data SumScalar x where - PickScalar :: a -> SumScalar (a, b) - SkipScalar :: SumScalar b -> SumScalar (a, b) +data UnionScalar x where + PickScalar :: a -> UnionScalar (a, b) + SkipScalar :: UnionScalar b -> UnionScalar (a, b) -data SumScalarType a where - SuccScalarType :: SingleType a -> SumScalarType b -> SumScalarType (a, b) - ZeroScalarType :: SumScalarType () +data UnionScalarType a where + SuccScalarType :: SingleType a -> UnionScalarType b -> UnionScalarType (a, b) + ZeroScalarType :: UnionScalarType () data SingleType a where NumSingleType :: NumType a -> SingleType a @@ -285,9 +285,9 @@ instance Show (VectorType a) where instance Show (ScalarType a) where show (SingleScalarType ty) = show ty show (VectorScalarType ty) = show ty - show (SumScalarType ty) = show ty + show (UnionScalarType ty) = show ty -instance Show (SumScalarType a) where +instance Show (UnionScalarType a) where show ZeroScalarType = "" show (SuccScalarType x (ZeroScalarType)) = show x show (SuccScalarType x xs) = show x ++ " + " ++ show xs @@ -634,11 +634,11 @@ instance IsSingle Undef where instance IsScalar Undef where scalarType = SingleScalarType singleType -instance (IsSumScalar a) => IsScalar (SumScalar a) where - scalarType = SumScalarType (sumScalarType @a) +instance (IsUnionScalar a) => IsScalar (UnionScalar a) where + scalarType = UnionScalarType (unionScalarType @a) -instance IsSumScalar () where - sumScalarType = ZeroScalarType +instance IsUnionScalar () where + unionScalarType = ZeroScalarType -instance (IsSingle a, IsSumScalar b) => IsSumScalar (a, b) where - sumScalarType = SuccScalarType (singleType @a) (sumScalarType @b) +instance (IsSingle a, IsUnionScalar b) => IsUnionScalar (a, b) where + unionScalarType = SuccScalarType (singleType @a) (unionScalarType @b) From 79fed7274def2adbe552ec0a6de51495725e7c0e Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 19 May 2022 11:19:16 +0200 Subject: [PATCH 42/67] rewrote Matchable without POSable references --- .../Array/Accelerate/Pattern/Matchable.hs | 189 +++--------------- src/Data/Array/Accelerate/Sugar/Elt.hs | 7 +- src/Data/Array/Accelerate/Sugar/Shape.hs | 6 + 3 files changed, 42 insertions(+), 160 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index 47a80230e..afe12d793 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -13,6 +13,7 @@ {-# LANGUAGE TypeFamilyDependencies #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE ConstraintKinds #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Data.Array.Accelerate.Pattern.Matchable where @@ -23,7 +24,7 @@ import Data.Proxy import Data.Kind import Generics.SOP as SOP import Data.Type.Equality -import Data.Array.Accelerate.Representation.POS as POS +import Data.Array.Accelerate.Representation.POS as POS (Undef(..)) import Data.Array.Accelerate.Representation.Tag import Unsafe.Coerce import qualified Data.Array.Accelerate.AST as AST @@ -38,8 +39,8 @@ class Matchable a where type SOPCode a :: [[Type]] type SOPCode a = Code a - type Choices' a :: Nat - type Choices' a = Choices a + -- type Choices' a :: Nat + -- type Choices' a = Choices a build :: ( KnownNat n @@ -48,12 +49,12 @@ class Matchable a where -> Exp a default build :: ( KnownNat n - , POSable a + , Elt a ) => Proxy n -> NP Exp (Index (SOPCode a) n) -> Exp a - build n _ = case sameNat (emptyChoices @a) (Proxy :: Proxy 1) of + build n _ = case sameNat (Proxy :: Proxy (EltChoices a)) (Proxy :: Proxy 1) of -- no tag Just Refl -> undefined -- tagged @@ -64,9 +65,9 @@ class Matchable a where -> Exp a -> Maybe (NP Exp (Index (SOPCode a) n)) -buildTag :: SOP.All POSable xs => NP Exp xs -> Exp TAG +buildTag :: SOP.All Elt xs => NP Exp xs -> Exp TAG buildTag SOP.Nil = constant 0 -- exp of 0 :: Finite 1 -buildTag (((Exp x) :: (Exp x)) :* (xs :: xs)) = case sameNat (emptyChoices @x) (Proxy :: Proxy 1) of +buildTag (((Exp x) :: (Exp x)) :* (xs :: xs)) = case sameNat (Proxy :: Proxy (EltChoices x)) (Proxy :: Proxy 1) of -- x doesn't contain a tag, skip Just Refl -> buildTag xs @@ -76,48 +77,6 @@ buildTag (((Exp x) :: (Exp x)) :* (xs :: xs)) = case sameNat (emptyChoices @x) ( -- TODO: this is incorrect, we need the size of the TAG here (return to Finite?) -> mkMul (Exp (SmartExp (Prj PairIdxLeft x))) (buildTag xs) --- flattenProduct :: Product a -> FlattenProduct a --- flattenProduct Nil = () --- flattenProduct (Cons x xs) = (UnionScalarType x, flattenProduct xs) - --- buildFields :: forall n a . (POSable a, Elt a) => Proxy n -> NP SmartExp (Index (SOPCode a) n) -> SmartExp (FlattenProduct (Fields a)) --- buildFields _ a = case emptyFields @a of --- PTNil -> case constant () of { Exp se -> se } --- PTCons st pt -> case a of --- -- SOP.Nil -> SmartExp (Pair (someFunction st) undefined) --- (x :* xs) -> SmartExp (Pair undefined undefined) - -buildFields' :: Proxy n -> ProductType (Fields a) -> NP SmartExp (Index (SOPCode a) n) -> SmartExp (FlattenProduct (Fields a)) -buildFields' _ PTNil _ = SmartExp Smart.Nil -buildFields' n (PTCons x xs) SOP.Nil = undefined -- SmartExp $ Pair _ (buildFields' n xs SOP.Nil) -buildFields' _ (PTCons x xs) (y :* ys) = undefined - -someFunction :: SumType x -> SmartExp (ScalarType (Sum x)) -someFunction = undefined - -newtype SEFPF a = SEFPF (SmartExp (FlattenProduct (Fields a))) - --- mapBuildField :: (All POSable xs, All Elt xs) => NP SmartExp xs -> NP SEFPF xs --- mapBuildField SOP.Nil = SOP.Nil --- mapBuildField ((x :: SmartExp x) :* xs) = SEFPF (buildField @x x) :* mapBuildField xs - - -buildField :: forall a . (POSable a, Elt a, EltR a ~ POStoEltR (Choices a) (Fields a)) => SmartExp (EltR a) -> SmartExp (FlattenProduct (Fields a)) -buildField (SmartExp a) = case sameNat (emptyChoices @a) (Proxy :: Proxy 1) of - Just Refl -> - case emptyFields @a of - -- singleton types - PTCons (STSucc _ STZero) PTNil - | Refl :: (POStoEltR (Choices a) (Fields a) :~: a) <- unsafeCoerce Refl - -> SmartExp $ Pair (SmartExp (undefined a)) (SmartExp Smart.Nil) - -- tagless types - _ | Refl :: (POStoEltR (Choices a) (Fields a) :~: FlattenProduct (Fields a)) <- unsafeCoerce Refl - -> SmartExp a - -- tagged types - Nothing - -- We know that this is true because Choices a is not equal to 1 - | Refl :: (POStoEltR (Choices a) (Fields a) :~: (_x, FlattenProduct (Fields a))) <- unsafeCoerce Refl - -> SmartExp (Prj PairIdxRight (SmartExp a)) type family Index (xs :: [[Type]]) (y :: Nat) :: [Type] where @@ -128,25 +87,21 @@ type family ListToCons (xs :: [Type]) :: Type where ListToCons '[] = () ListToCons (x ': xs) = (x, ListToCons xs) --- copied from POSable library +-- copied from Elt library type family Products (xs :: [Nat]) :: Nat where Products '[] = 1 Products (x ': xs) = x * Products xs -- idem -type family MapChoices (xs :: [Type]) :: [Nat] where - MapChoices '[] = '[] - MapChoices (x ': xs) = Choices x ': MapChoices xs +-- type family MapChoices (xs :: [Type]) :: [Nat] where +-- MapChoices '[] = '[] +-- MapChoices (x ': xs) = Choices x ': MapChoices xs -- idem -type family Concat (xss :: [[x]]) :: [x] where - Concat '[] = '[] - Concat (xs ': xss) = xs ++ Concat xss +-- type family Concat (xss :: [[x]]) :: [x] where +-- Concat '[] = '[] +-- Concat (xs ': xss) = xs ++ Concat xss --- idem -type family MapFields (xs :: [Type]) :: [[[Type]]] where - MapFields '[] = '[] - MapFields (x ': xs) = Fields x ': MapFields xs type family MapFlattenProduct (xs :: [[[Type]]]) :: [Type] where MapFlattenProduct '[] = '[] @@ -171,7 +126,7 @@ type family ConcatASTs (xs :: [Type]) :: Type where ConcatASTs (x ': xs) = ConcatAST x (ConcatASTs xs) instance Matchable Bool where - type Choices' Bool = 2 + -- type Choices' Bool = 2 build n _ = Exp (SmartExp (Pair (undefined (fromInteger $ natVal n)) (SmartExp Smart.Nil))) @@ -204,7 +159,7 @@ tagType = TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeTAG)) instance Matchable (Maybe Int) where - type Choices' (Maybe Int) = 2 + -- type Choices' (Maybe Int) = 2 build n x = case sameNat n (Proxy :: Proxy 0) of Just Refl -> @@ -272,23 +227,25 @@ unConcatSumScalarType :: ScalarType (UnionScalar a) -> ScalarType (UnionScalar ( unConcatSumScalarType (UnionScalarType ZeroScalarType) xs = xs unConcatSumScalarType (UnionScalarType (SuccScalarType a as)) (UnionScalarType (SuccScalarType x xs)) = unConcatSumScalarType (UnionScalarType as) (UnionScalarType xs) -instance (POSable (Either a b), POSable a, POSable b, Elt a) => Matchable (Either a b) where - type Choices' (Either a b) = OuterChoices (Either a b) + + +instance (Elt (Either a b), Elt a, Elt b, Elt a) => Matchable (Either a b) where + -- type Choices' (Either a b) = OuterChoices (Either a b) build n x - | Refl :: (EltR (Either a b) :~: (TAG, FlattenProduct (Fields (Either a b)))) <- unsafeCoerce Refl -- this should be easily provable, I'm just lazy + | Refl :: (EltR (Either a b) :~: (TAG, y)) <- unsafeCoerce Refl -- this should be provable, I'm just lazy = case sameNat n (Proxy :: Proxy 0) of Just Refl -> Exp ( SmartExp ( Pair (unExp $ buildTAG x) - (mergeLeft _ _ _) + _ ) ) where tag = undefined --foldl 1 (*) (mapChoices x) - test = natVal (Proxy :: Proxy (Choices a)) + test = natVal (Proxy :: Proxy (EltChoices a)) -- Nothing -> case sameNat n (Proxy :: Proxy 1) of -- Just Refl | (Exp x' :* SOP.Nil) <- x -> Exp ( -- SmartExp ( @@ -354,20 +311,7 @@ instance (POSable (Either a b), POSable a, POSable b, Elt a) => Matchable (Eithe -- getSingleElem :: NP Exp '[a] -> Exp a -- getSingleElem (x :* SOP.Nil) = x --- understandConcatPlease :: SmartExp (FlattenProduct (Merge (Fields a) (Fields b))) -> SmartExp (FlattenProduct (Merge (Fields a ++ '[]) (Fields b ++ '[]))) --- understandConcatPlease = unsafeCoerce - --- mergeLeft :: forall a b . (POSable a, Elt a) => Exp a -> SmartExp (FlattenProduct (Merge (Fields a) (Fields b))) --- mergeLeft (Exp a) = case buildFields1 @a a of --- a' -> case weirdConvert2 @a (eltR @a) of --- a3 -> undefined - --- mergeLeft' :: forall a b . TypeR (EltR a -> ProductType b -> SmartExp (FlattenProduct a) -> SmartExp (Merge' (FlattenProduct a) (FlattenProduct b)) --- mergeLeft' PTNil PTNil a = a --- mergeLeft' PTNil (PTCons gb (gbs :: (ProductType (Fields b')))) a = SmartExp $ Pair (fromSumType gb) (mergeLeft' @a @b' PTNil gbs a) - - --- buildFields :: (All POSable xs) => NP SmartExp xs -> SmartExp (ConcatT (MapFlattenProduct (MapFields xs))) +-- buildFields :: (All Elt xs) => NP SmartExp xs -> SmartExp (ConcatT (MapFlattenProduct (MapFields xs))) -- buildFields SOP.Nil = () -- buildFields (x :* xs) = SmartExp $ Pair -- where @@ -466,102 +410,29 @@ type family Concat' (a :: Type) (b :: Type) = (r :: Type) where Concat' () ys = ys Concat' (x, xs) ys = (x, Concat' xs ys) --- fromSumType :: SumType x -> SmartExp (UnionScalar (Undef, FlattenSum x)) --- fromSumType x = SmartExp $ Union (_) $ SmartExp (LiftUnion (SmartExp (Const (SingleScalarType UndefSingleType) POS.Undef))) - -buildFields1 :: forall x . (POSable x) => SmartExp (EltR x) -> SmartExp (FlattenProduct (Fields x)) -buildFields1 x = case eltRType @x of - SingletonType -> SmartExp $ Pair (SmartExp $ LiftUnion x) (SmartExp Smart.Nil) - TaglessType -> x - TaggedType -> SmartExp $ Prj PairIdxRight x - --- weirdConvert2 :: forall x . (Elt x, POSable x) => TypeR (EltR x) -> TypeR (FlattenProduct (Fields x)) --- weirdConvert2 x = case eltRType @x of --- SingletonType -> case x of --- TupRsingle x' -> TupRpair (TupRsingle (UnionScalarType (SuccScalarType x' ZeroScalarType))) TupRunit --- TaglessType -> x --- TaggedType -> case x of --- TupRpair _ x' -> x' - --- guidedAppend :: forall x y . TypeR (FlattenProduct (Fields x)) -> SmartExp (FlattenProduct (Fields x)) -> SmartExp (FlattenProduct (Fields y)) -> SmartExp (FlattenProduct (Fields x ++ Fields y)) --- guidedAppend TupRunit x y | Refl :: (FlattenProduct (Fields y) :~: FlattenProduct (Fields x ++ Fields y)) <- unsafeCoerce Refl = y --- guidedAppend (TupRsingle g) x y = SmartExp (Pair x y) --- guidedAppend (TupRpair g1 g2) x y = undefined - -buildFields2 :: forall x y . (POSable x) => SmartExp (EltR x) -> SmartExp (FlattenProduct (Fields y)) -> SmartExp (FlattenProduct (Fields x ++ Fields y)) -buildFields2 x y = case eltRType @x of - SingletonType -> SmartExp $ Pair (SmartExp $ LiftUnion x) y - TaglessType -> buildFields3 @x @y x y - TaggedType -> buildFields3 @x @y (SmartExp $ Prj PairIdxRight x) y - - -buildFields3 :: forall x y . (POSable x) => SmartExp (FlattenProduct (Fields x)) -> SmartExp (FlattenProduct (Fields y)) -> SmartExp (FlattenProduct (Fields x ++ Fields y)) -buildFields3 = buildFields4 @x @y (emptyFields @x) - -buildFields4 :: forall x y . ProductType (Fields x) -> SmartExp (FlattenProduct (Fields x)) -> SmartExp (FlattenProduct (Fields y)) -> SmartExp (FlattenProduct (Fields x ++ Fields y)) -buildFields4 PTNil x y = y -buildFields4 (PTCons g gs) x y = undefined - where - x' :: SmartExp (UnionScalar (FlattenSum (Head (Fields x)))) - x' = SmartExp $ Prj PairIdxLeft x - xs' :: SmartExp (FlattenProduct (Tail (Fields x))) - xs' = SmartExp $ Prj PairIdxRight x - f :: SmartExp (FlattenProduct (Fields x ++ Fields y)) - f = SmartExp (Pair x' xy) - xy :: SmartExp (FlattenProduct (Tail (Fields x) ++ Fields y)) - xy = undefined type family Head (xs :: [x]) :: x where Head (x ': xs) = x type family Tail (xs :: [x]) :: [x] where Tail (x ': xs) = xs --- concatFields :: forall x y. (POSable x) => SmartExp (FlattenProduct (Fields x)) -> SmartExp (FlattenProduct (Fields y)) -> SmartExp (FlattenProduct (Fields x ++ Fields y)) --- concatFields x y = case emptyFields @x of --- PTNil --- -> y --- (PTCons x' (xs' :: (ProductType xs))) --- -> SmartExp $ Pair (SmartExp $ Prj PairIdxLeft x) (f @xs @y (SmartExp $ Prj PairIdxRight x) y) --- where --- -- rec' :: SmartExp (FlattenProduct ) --- -- rec' x y = concatFields (SmartExp $ Prj PairIdxRight x) y --- f :: SmartExp (FlattenProduct ys) -> SmartExp (FlattenProduct (Fields z)) -> SmartExp (FlattenProduct (ys ++ Fields z)) --- f xs y' = undefined - --- concatFields' :: forall xs ys . SmartExp (FlattenProduct xs) -> SmartExp (FlattenProduct ys) -> SmartExp (FlattenProduct (xs ++ ys)) --- concatFields' = undefined - --- convertASTtoNP :: forall x . POSable x => SmartExp (FlattenProduct (Fields x)) -> NP SmartExp (MapFlattenSum (Fields x)) --- convertASTtoNP = convertASTtoNP' @(Fields x) (emptyFields @x) - --- convertASTtoNP' :: ProductType x -> SmartExp (FlattenProduct x) -> NP SmartExp (MapFlattenSum x) --- convertASTtoNP' PTNil _ = SOP.Nil --- convertASTtoNP' (PTCons _ xs) y = SmartExp (Prj PairIdxLeft y) :* convertASTtoNP' xs (SmartExp $ Prj PairIdxRight y) - --- nPtoAST :: NP SmartExp xs -> SmartExp (ConcatASTs xs) --- nPtoAST SOP.Nil = SmartExp Smart.Nil --- nPtoAST (x :* xs) = SmartExp $ Pair x (nPtoAST xs) - --- concatAST :: forall x y . (Elt x) => SmartExp x -> SmartExp y -> SmartExp (ConcatAST x y) --- concatAST x y | TupRunit <- eltR @x = y --- concatAST x y | (TupRpair _ _) <- eltR @x = SmartExp $ Pair (SmartExp $ Prj PairIdxLeft x) (concatAST (SmartExp (Prj PairIdxRight x)) y) type family MapFlattenSum (x :: [[Type]]) :: [Type] where MapFlattenSum '[] = '[] MapFlattenSum (x ': xs) = UnionScalar (FlattenSum x) ': MapFlattenSum xs -- like combineProducts, but lifted to the AST -buildTAG :: (All POSable xs) => NP Exp xs -> Exp TAG +buildTAG :: (All Elt xs) => NP Exp xs -> Exp TAG buildTAG SOP.Nil = Exp $ makeTag 0 buildTAG (x :* xs) = combineProduct x (buildTAG xs) -- like Finite.combineProduct, but lifted to the AST -- basically `tag x + tag y * natVal x` -combineProduct :: forall x. (POSable x) => Exp x -> Exp TAG -> Exp TAG -combineProduct x y = case sameNat (Proxy :: Proxy (Choices x)) (Proxy :: Proxy 1) of +combineProduct :: forall x. (Elt x) => Exp x -> Exp TAG -> Exp TAG +combineProduct x y = case sameNat (Proxy :: Proxy (EltChoices x)) (Proxy :: Proxy 1) of -- untagged type: `tag x = 0`, `natVal x = 1` Just Refl -> y -- tagged type Nothing - | Refl :: (EltR x :~: (TAG, FlattenProduct (Fields x))) <- unsafeCoerce Refl - -> mkAdd (mkExp $ Prj PairIdxLeft (unExp x)) (mkMul y (constant (fromInteger $ natVal (Proxy :: Proxy (Choices x))))) + | Refl :: (EltR x :~: (TAG, y)) <- unsafeCoerce Refl + -> mkAdd (mkExp $ Prj PairIdxLeft (unExp x)) (mkMul y (constant (fromInteger $ natVal (Proxy :: Proxy (EltChoices x))))) diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index a3c2f999a..e92a903a0 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -12,6 +12,7 @@ {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE GADTs #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-# OPTIONS_HADDOCK hide #-} {-# OPTIONS_GHC -ddump-splices #-} -- | @@ -80,7 +81,7 @@ import Data.Typeable -- See the function 'Data.Array.Accelerate.match' for details on how to use -- sum types in embedded code. -- -class Elt a where +class (KnownNat (EltChoices a)) => Elt a where -- | Type representation mapping, which explains how to convert a type -- from the surface type into the internal representation type consisting -- only of simple primitive types, unit '()', and pair '(,)'. @@ -88,6 +89,9 @@ class Elt a where type EltR a :: Type type EltR a = POStoEltR (Choices a) (Fields a) + type EltChoices a :: Nat + type EltChoices a = Choices a + -- eltR :: TypeR (EltR a) tagsR :: [TagR (EltR a)] @@ -282,6 +286,7 @@ instance (POSable (Either a b), Elt a, Elt b) => Elt (Either a b) instance Elt Char where type EltR Char = Word32 + type EltChoices Char = 1 eltR = TupRsingle scalarType tagsR = [TagRsingle scalarType] toElt = chr . fromIntegral diff --git a/src/Data/Array/Accelerate/Sugar/Shape.hs b/src/Data/Array/Accelerate/Sugar/Shape.hs index 02cd46b49..fb2d9bee2 100644 --- a/src/Data/Array/Accelerate/Sugar/Shape.hs +++ b/src/Data/Array/Accelerate/Sugar/Shape.hs @@ -8,6 +8,8 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE DataKinds #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Sugar.Shape @@ -39,6 +41,7 @@ import qualified Data.Array.Accelerate.Representation.Slice as R import Data.Kind import GHC.Generics as GHC +import GHC.TypeLits (type (+)) -- Shorthand for common shape types @@ -306,6 +309,7 @@ class (Slice (DivisionSlice sl)) => Division sl where instance (Elt t, Elt h) => Elt (t :. h) where type EltR (t :. h) = (EltR t, EltR h) + type EltChoices (t :. h) = 1 eltR = TupRpair (eltR @t) (eltR @h) tagsR = [TagRpair t h | t <- tagsR @t, h <- tagsR @h] fromElt (t:.h) = (fromElt t, fromElt h) @@ -314,8 +318,10 @@ instance (Elt t, Elt h) => Elt (t :. h) where instance POS.Generic (Any Z) instance POSable (Any Z) instance Elt (Any Z) + instance Shape sh => Elt (Any (sh :. Int)) where type EltR (Any (sh :. Int)) = (EltR (Any sh), ()) + type EltChoices (Any (sh :. Int)) = 1 eltR = TupRpair (eltR @(Any sh)) TupRunit tagsR = [TagRpair t TagRunit | t <- tagsR @(Any sh)] fromElt _ = (fromElt (Any :: Any sh), ()) From da6380e773cd3a26504cf44c60b502d90c8ad6f8 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Fri, 20 May 2022 10:21:42 +0200 Subject: [PATCH 43/67] index operator with beauty notation --- .../Array/Accelerate/Pattern/Matchable.hs | 125 +++++------------- .../Array/Accelerate/Representation/POS.hs | 2 +- 2 files changed, 31 insertions(+), 96 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index afe12d793..768c93755 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -24,7 +24,7 @@ import Data.Proxy import Data.Kind import Generics.SOP as SOP import Data.Type.Equality -import Data.Array.Accelerate.Representation.POS as POS (Undef(..)) +import Data.Array.Accelerate.Representation.POS as POS import Data.Array.Accelerate.Representation.Tag import Unsafe.Coerce import qualified Data.Array.Accelerate.AST as AST @@ -45,13 +45,13 @@ class Matchable a where build :: ( KnownNat n ) => Proxy n - -> NP Exp (Index (SOPCode a) n) + -> NP Exp (SOPCode a !! n) -> Exp a default build :: ( KnownNat n , Elt a ) => Proxy n - -> NP Exp (Index (SOPCode a) n) + -> NP Exp (SOPCode a !! n) -> Exp a build n _ = case sameNat (Proxy :: Proxy (EltChoices a)) (Proxy :: Proxy 1) of @@ -63,7 +63,7 @@ class Matchable a where match :: ( KnownNat n ) => Proxy n -> Exp a - -> Maybe (NP Exp (Index (SOPCode a) n)) + -> Maybe (NP Exp (SOPCode a !! n)) buildTag :: SOP.All Elt xs => NP Exp xs -> Exp TAG buildTag SOP.Nil = constant 0 -- exp of 0 :: Finite 1 @@ -79,9 +79,11 @@ buildTag (((Exp x) :: (Exp x)) :* (xs :: xs)) = case sameNat (Proxy :: Proxy (El -type family Index (xs :: [[Type]]) (y :: Nat) :: [Type] where - Index (x ': xs) 0 = x - Index (x ': xs) n = Index xs (n - 1) +type family (!!) (xs :: [[Type]]) (y :: Nat) :: [Type] where + (x ': xs) !! 0 = x + (x ': xs) !! n = xs !! (n - 1) + +infixl 9 !! type family ListToCons (xs :: [Type]) :: Type where ListToCons '[] = () @@ -229,93 +231,26 @@ unConcatSumScalarType (UnionScalarType (SuccScalarType a as)) (UnionScalarType ( -instance (Elt (Either a b), Elt a, Elt b, Elt a) => Matchable (Either a b) where +instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) where -- type Choices' (Either a b) = OuterChoices (Either a b) - build n x - | Refl :: (EltR (Either a b) :~: (TAG, y)) <- unsafeCoerce Refl -- this should be provable, I'm just lazy + build n fs + -- | Refl :: (EltR (Either a b) :~: (TAG, y)) <- unsafeCoerce Refl -- this should be provable, I'm just lazy = case sameNat n (Proxy :: Proxy 0) of - Just Refl - -> Exp ( - SmartExp ( - Pair - (unExp $ buildTAG x) - _ - ) - ) - where - tag = undefined --foldl 1 (*) (mapChoices x) - test = natVal (Proxy :: Proxy (EltChoices a)) - -- Nothing -> case sameNat n (Proxy :: Proxy 1) of - -- Just Refl | (Exp x' :* SOP.Nil) <- x -> Exp ( - -- SmartExp ( - -- Pair - -- (makeTag 1) - -- (SmartExp ( - -- Pair - -- (SmartExp - -- (Union - -- (Right ( - -- SmartExp - -- (Union - -- (Left x') - -- ) - -- )) - -- ) - -- ) - -- (SmartExp Smart.Nil) - -- )) - -- ) - -- ) - -- Nothing -> error "Impossible type encountered" - - -- match n (Exp e) = case sameNat n (Proxy :: Proxy 0) of - -- Just Refl -> - -- case e of - -- SmartExp (Match (TagRtag 0 (TagRpair _ TagRunit)) _x) - -- -> Just SOP.Nil - - -- SmartExp Match {} -> Nothing - - -- _ -> error "Embedded pattern synonym used outside 'match' context." - -- Nothing -> -- matchJust - -- case sameNat n (Proxy :: Proxy 1) of - -- Just Refl -> - -- case e of - -- SmartExp (Match (TagRtag 1 _) x) - -- -> Just - -- (Exp - -- (SmartExp - -- (PrjUnion - -- UnionIdxLeft - -- (SmartExp - -- (PrjUnion - -- UnionIdxRight - -- (SmartExp - -- (Prj - -- PairIdxLeft - -- (SmartExp (Prj PairIdxRight x)) - -- )) - -- )) - -- )) :* SOP.Nil) - -- SmartExp Match {} -> Nothing - - -- _ -> error "Embedded pattern synonym used outside 'match' context." - - -- Nothing -> - -- error "Impossible type encountered" - --- weirdConvert :: forall x . Elt x => TypeR (EltR x) --- weirdConvert = eltR @x - --- getSingleElem :: NP Exp '[a] -> Exp a --- getSingleElem (x :* SOP.Nil) = x - --- buildFields :: (All Elt xs) => NP SmartExp xs -> SmartExp (ConcatT (MapFlattenProduct (MapFields xs))) --- buildFields SOP.Nil = () --- buildFields (x :* xs) = SmartExp $ Pair --- where --- fieldsx = buildFields1 x + Just Refl -> + case emptyFields @(Either a b) of + x + | Refl :: POStoEltR (Choices a + Choices b) (Merge (Fields a ++ '[]) (Fields b ++ '[])) :~: (TAG, FlattenProduct (Merge (Fields a) (Fields b))) <- unsafeCoerce Refl + -> + Exp (SmartExp (Pair (unExp $ buildTAG fs) undefined)) + Nothing -> + case emptyFields @(Either a b) of + PTNil -> undefined + (PTCons x xs) -> undefined + +-- convert :: forall a b . NP Exp a -> SmartExp (FlattenProduct (Merge a b)) +-- convert SOP.Nil = SmartExp (Pair _ _) +-- convert (x :* xs) = undefined mergeLeft :: forall a b . TypeR a -> TypeR b -> SmartExp a -> SmartExp (Merge' a b) mergeLeft TupRunit TupRunit a = unExp $ constant () @@ -422,17 +357,17 @@ type family MapFlattenSum (x :: [[Type]]) :: [Type] where MapFlattenSum (x ': xs) = UnionScalar (FlattenSum x) ': MapFlattenSum xs -- like combineProducts, but lifted to the AST -buildTAG :: (All Elt xs) => NP Exp xs -> Exp TAG +buildTAG :: (All POSable xs) => NP Exp xs -> Exp TAG buildTAG SOP.Nil = Exp $ makeTag 0 buildTAG (x :* xs) = combineProduct x (buildTAG xs) -- like Finite.combineProduct, but lifted to the AST -- basically `tag x + tag y * natVal x` -combineProduct :: forall x. (Elt x) => Exp x -> Exp TAG -> Exp TAG -combineProduct x y = case sameNat (Proxy :: Proxy (EltChoices x)) (Proxy :: Proxy 1) of +combineProduct :: forall x. (POSable x) => Exp x -> Exp TAG -> Exp TAG +combineProduct x y = case sameNat (emptyChoices @x) (Proxy :: Proxy 1) of -- untagged type: `tag x = 0`, `natVal x = 1` Just Refl -> y -- tagged type Nothing | Refl :: (EltR x :~: (TAG, y)) <- unsafeCoerce Refl - -> mkAdd (mkExp $ Prj PairIdxLeft (unExp x)) (mkMul y (constant (fromInteger $ natVal (Proxy :: Proxy (EltChoices x))))) + -> mkAdd (mkExp $ Prj PairIdxLeft (unExp x)) (mkMul y (constant (fromInteger $ natVal (emptyChoices @x)))) diff --git a/src/Data/Array/Accelerate/Representation/POS.hs b/src/Data/Array/Accelerate/Representation/POS.hs index 3f9204db3..dad960650 100644 --- a/src/Data/Array/Accelerate/Representation/POS.hs +++ b/src/Data/Array/Accelerate/Representation/POS.hs @@ -17,7 +17,7 @@ module Data.Array.Accelerate.Representation.POS ( POSable(..), Product(..), Sum(..), Ground(..), Finite, ProductType(..), SumType(..), POSable.Generic, type (++), - mkPOSableGround, Undef(..)) + mkPOSableGround, Undef(..), type Merge) where From 02bf6cb6faf99e0951a7a179065768ff465d773d Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Fri, 20 May 2022 10:33:49 +0200 Subject: [PATCH 44/67] removed outerchoices --- src/Data/Array/Accelerate/Representation/Shape.hs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/Data/Array/Accelerate/Representation/Shape.hs b/src/Data/Array/Accelerate/Representation/Shape.hs index ce7b78b6d..af06873b0 100644 --- a/src/Data/Array/Accelerate/Representation/Shape.hs +++ b/src/Data/Array/Accelerate/Representation/Shape.hs @@ -220,9 +220,6 @@ instance POSable (ShapeR ()) where emptyFields = PTNil - type OuterChoices (ShapeR ()) = 1 - outerChoice _ = 0 - instance (POSable (ShapeR sh)) => POSable (ShapeR (sh, Int)) where type Choices (ShapeR (sh, Int)) = 1 @@ -238,7 +235,4 @@ instance (POSable (ShapeR sh)) => POSable (ShapeR (sh, Int)) where emptyFields = PTCons (STSucc Undef STZero) (emptyFields @(ShapeR sh)) - type OuterChoices (ShapeR (sh, Int)) = 1 - outerChoice _ = 0 - From 5d05a98b3371342e43fdcc7f97ca8350c7d21677 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Fri, 20 May 2022 12:07:30 +0200 Subject: [PATCH 45/67] more Either build AST --- .../Array/Accelerate/Pattern/Matchable.hs | 37 +++++++++++++------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index 768c93755..fe51974f1 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -65,9 +65,9 @@ class Matchable a where -> Exp a -> Maybe (NP Exp (SOPCode a !! n)) -buildTag :: SOP.All Elt xs => NP Exp xs -> Exp TAG +buildTag :: SOP.All POSable xs => NP Exp xs -> Exp TAG buildTag SOP.Nil = constant 0 -- exp of 0 :: Finite 1 -buildTag (((Exp x) :: (Exp x)) :* (xs :: xs)) = case sameNat (Proxy :: Proxy (EltChoices x)) (Proxy :: Proxy 1) of +buildTag (((Exp x) :: (Exp x)) :* (xs :: xs)) = case sameNat (emptyChoices @x) (Proxy :: Proxy 1) of -- x doesn't contain a tag, skip Just Refl -> buildTag xs @@ -235,18 +235,31 @@ instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) -- type Choices' (Either a b) = OuterChoices (Either a b) build n fs - -- | Refl :: (EltR (Either a b) :~: (TAG, y)) <- unsafeCoerce Refl -- this should be provable, I'm just lazy + | Refl :: (EltR (Either a b) :~: (TAG, FlattenProduct (Fields (Either a b)))) <- unsafeCoerce Refl -- this should be provable, I'm just lazy = case sameNat n (Proxy :: Proxy 0) of - Just Refl -> - case emptyFields @(Either a b) of - x - | Refl :: POStoEltR (Choices a + Choices b) (Merge (Fields a ++ '[]) (Fields b ++ '[])) :~: (TAG, FlattenProduct (Merge (Fields a) (Fields b))) <- unsafeCoerce Refl - -> - Exp (SmartExp (Pair (unExp $ buildTAG fs) undefined)) + Just Refl -> case emptyFields @a of + PTNil -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (undefPairs @(Fields b) (emptyFields @b)))) + PTCons st pt -> Exp (SmartExp (Pair (unExp $ buildTAG fs) undefined)) Nothing -> - case emptyFields @(Either a b) of - PTNil -> undefined - (PTCons x xs) -> undefined + case sameNat n (Proxy :: Proxy 1) of + Just Refl -> case emptyFields @a of + PTNil -> case fs of + x :* SOP.Nil -> case eltRType @b of -- disambiguate between tagless and tagged b's + SingletonType -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (SmartExp (Pair (SmartExp (Union undefined (SmartExp (LiftUnion (unExp x))))) (SmartExp Smart.Nil))))) + TaglessType -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (mergePairs @(Fields b) (emptyFields @b) (unExp x)))) + TaggedType -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (mergePairs @(Fields b) (emptyFields @b) (SmartExp (Prj PairIdxRight (unExp x)))))) + (PTCons x xs) -> Exp (SmartExp (Pair (unExp $ buildTAG fs) undefined)) + Nothing -> error "Index out of bounds" + +undefPairs :: forall xs . ProductType xs -> SmartExp (FlattenProduct (Merge '[] (xs ++ '[]))) +undefPairs PTNil = SmartExp Smart.Nil +undefPairs (PTCons x xs) = SmartExp (Pair (SmartExp (Union undefined (SmartExp (LiftUnion (unExp $ constant POS.Undef))))) (undefPairs xs)) + +mergePairs :: forall xs . ProductType xs -> SmartExp (FlattenProduct xs) -> SmartExp (FlattenProduct (Merge '[] (xs ++ '[]))) +mergePairs PTNil _ = SmartExp Smart.Nil +mergePairs (PTCons x xs) y = SmartExp (Pair (SmartExp (Union undefined (SmartExp (Prj PairIdxLeft y)))) (mergePairs xs (SmartExp (Prj PairIdxRight y)))) + +-- -- convert :: forall a b . NP Exp a -> SmartExp (FlattenProduct (Merge a b)) -- convert SOP.Nil = SmartExp (Pair _ _) From 912b66c86817d249955cabca792e459a0cb817b9 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Wed, 1 Jun 2022 13:39:15 +0200 Subject: [PATCH 46/67] cleanup Matchable --- .../Array/Accelerate/Pattern/Matchable.hs | 206 ++++-------------- .../Array/Accelerate/Representation/POS.hs | 8 +- .../Array/Accelerate/Representation/Shape.hs | 2 +- src/Data/Array/Accelerate/Smart.hs | 3 +- src/Data/Array/Accelerate/Sugar/POS.hs | 6 +- 5 files changed, 50 insertions(+), 175 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index fe51974f1..c1c94874e 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -1,6 +1,6 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE TypeFamilies #-} + {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE TypeApplications #-} @@ -85,52 +85,10 @@ type family (!!) (xs :: [[Type]]) (y :: Nat) :: [Type] where infixl 9 !! -type family ListToCons (xs :: [Type]) :: Type where - ListToCons '[] = () - ListToCons (x ': xs) = (x, ListToCons xs) - --- copied from Elt library -type family Products (xs :: [Nat]) :: Nat where - Products '[] = 1 - Products (x ': xs) = x * Products xs - --- idem --- type family MapChoices (xs :: [Type]) :: [Nat] where --- MapChoices '[] = '[] --- MapChoices (x ': xs) = Choices x ': MapChoices xs - --- idem --- type family Concat (xss :: [[x]]) :: [x] where --- Concat '[] = '[] --- Concat (xs ': xss) = xs ++ Concat xss - - -type family MapFlattenProduct (xs :: [[[Type]]]) :: [Type] where - MapFlattenProduct '[] = '[] - MapFlattenProduct (x ': xs) = FlattenProduct x ': MapFlattenProduct xs - -type family ConcatT (xss :: [x]) :: x where - ConcatT '[] = () - ConcatT (x ': xs) = (x, ConcatT xs) - --- type family RealConcatT (xss :: [[Type]]) :: Type where --- ConcatT '[] = () --- ConcatT ('[] ': xs) = RealConcatT xs --- ConcatT ((x ': xs) ': ys) = (x, RealConcatT xs ys) - -type family ConcatAST (x :: Type) (y :: Type) :: Type where - ConcatAST xs () = xs - ConcatAST () ys = ys - ConcatAST (x, xs) ys = (x, ConcatAST xs ys) - -type family ConcatASTs (xs :: [Type]) :: Type where - ConcatASTs '[] = () - ConcatASTs (x ': xs) = ConcatAST x (ConcatASTs xs) - instance Matchable Bool where -- type Choices' Bool = 2 - build n _ = Exp (SmartExp (Pair (undefined (fromInteger $ natVal n)) (SmartExp Smart.Nil))) + build n _ = Exp (SmartExp (Pair (unExp $ constant @TAG (fromInteger $ natVal n)) (SmartExp Smart.Nil))) match n (Exp e) = case sameNat n (Proxy :: Proxy 0) of Just Refl -> @@ -161,8 +119,6 @@ tagType = TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeTAG)) instance Matchable (Maybe Int) where - -- type Choices' (Maybe Int) = 2 - build n x = case sameNat n (Proxy :: Proxy 0) of Just Refl -> Exp ( @@ -189,7 +145,6 @@ instance Matchable (Maybe Int) where Pair (SmartExp (Union - scalarTypeUndefLeft (SmartExp (LiftUnion x') ) @@ -216,7 +171,7 @@ instance Matchable (Maybe Int) where case e of SmartExp (Match (1,2) x) -> Just ( - (mkExp $ PrjUnion $ SmartExp $ Union (unConcatSumScalarType (UnionScalarType $ SuccScalarType (UndefSingleType) ZeroScalarType)) (SmartExp $ Prj PairIdxLeft (SmartExp $ Prj PairIdxRight x))) + mkExp (PrjUnion $ SmartExp $ Union (SmartExp $ Prj PairIdxLeft (SmartExp $ Prj PairIdxRight x))) :* SOP.Nil) SmartExp Match {} -> Nothing @@ -225,27 +180,58 @@ instance Matchable (Maybe Int) where Nothing -> error "Impossible type encountered" -unConcatSumScalarType :: ScalarType (UnionScalar a) -> ScalarType (UnionScalar (Concat' a b)) -> ScalarType (UnionScalar b) -unConcatSumScalarType (UnionScalarType ZeroScalarType) xs = xs -unConcatSumScalarType (UnionScalarType (SuccScalarType a as)) (UnionScalarType (SuccScalarType x xs)) = unConcatSumScalarType (UnionScalarType as) (UnionScalarType xs) +instance Matchable (Maybe a) where + build n x = case sameNat n (Proxy :: Proxy 0) of + Just Refl -> + Exp undefined + Nothing -> case sameNat n (Proxy :: Proxy 1) of + Just Refl | (Exp x' :* SOP.Nil) <- x -> Exp undefined + Nothing -> error "Impossible type encountered" + + match n (Exp e) = case sameNat n (Proxy :: Proxy 0) of + Just Refl -> + case e of + SmartExp (Match (0,1) _x) + -> Just SOP.Nil + + SmartExp Match {} -> Nothing + + _ -> error "Embedded pattern synonym used outside 'match' context." + Nothing -> -- matchJust + case sameNat n (Proxy :: Proxy 1) of + Just Refl -> + case e of + SmartExp (Match (1,2) x) + -> Just (undefined) + SmartExp Match {} -> Nothing + + _ -> error "Embedded pattern synonym used outside 'match' context." + Nothing -> + error "Impossible type encountered" instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) where -- type Choices' (Either a b) = OuterChoices (Either a b) build n fs - | Refl :: (EltR (Either a b) :~: (TAG, FlattenProduct (Fields (Either a b)))) <- unsafeCoerce Refl -- this should be provable, I'm just lazy + -- this is only not true if either left or right has a tag of type Finite 0 + -- types with tags of Finite 0 have no constructors, and are quite useless + | Refl :: (EltR (Either a b) :~: (TAG, FlattenProduct (Fields (Either a b)))) <- unsafeCoerce Refl = case sameNat n (Proxy :: Proxy 0) of + -- we have chosen constructor 0 (Left) Just Refl -> case emptyFields @a of + -- Left has no fields PTNil -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (undefPairs @(Fields b) (emptyFields @b)))) + -- Left has fields PTCons st pt -> Exp (SmartExp (Pair (unExp $ buildTAG fs) undefined)) Nothing -> case sameNat n (Proxy :: Proxy 1) of + -- we have chosen constructor 1 (Right) Just Refl -> case emptyFields @a of PTNil -> case fs of x :* SOP.Nil -> case eltRType @b of -- disambiguate between tagless and tagged b's - SingletonType -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (SmartExp (Pair (SmartExp (Union undefined (SmartExp (LiftUnion (unExp x))))) (SmartExp Smart.Nil))))) + SingletonType -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (SmartExp (Pair (SmartExp (Union (SmartExp (LiftUnion (unExp x))))) (SmartExp Smart.Nil))))) TaglessType -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (mergePairs @(Fields b) (emptyFields @b) (unExp x)))) TaggedType -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (mergePairs @(Fields b) (emptyFields @b) (SmartExp (Prj PairIdxRight (unExp x)))))) (PTCons x xs) -> Exp (SmartExp (Pair (unExp $ buildTAG fs) undefined)) @@ -253,121 +239,11 @@ instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) undefPairs :: forall xs . ProductType xs -> SmartExp (FlattenProduct (Merge '[] (xs ++ '[]))) undefPairs PTNil = SmartExp Smart.Nil -undefPairs (PTCons x xs) = SmartExp (Pair (SmartExp (Union undefined (SmartExp (LiftUnion (unExp $ constant POS.Undef))))) (undefPairs xs)) +undefPairs (PTCons x xs) = SmartExp (Pair (SmartExp (Union (SmartExp (LiftUnion (unExp $ constant POS.Undef))))) (undefPairs xs)) mergePairs :: forall xs . ProductType xs -> SmartExp (FlattenProduct xs) -> SmartExp (FlattenProduct (Merge '[] (xs ++ '[]))) mergePairs PTNil _ = SmartExp Smart.Nil -mergePairs (PTCons x xs) y = SmartExp (Pair (SmartExp (Union undefined (SmartExp (Prj PairIdxLeft y)))) (mergePairs xs (SmartExp (Prj PairIdxRight y)))) - --- - --- convert :: forall a b . NP Exp a -> SmartExp (FlattenProduct (Merge a b)) --- convert SOP.Nil = SmartExp (Pair _ _) --- convert (x :* xs) = undefined - -mergeLeft :: forall a b . TypeR a -> TypeR b -> SmartExp a -> SmartExp (Merge' a b) -mergeLeft TupRunit TupRunit a = unExp $ constant () -mergeLeft TupRunit (TupRpair (TupRsingle (UnionScalarType x)) gbs) a - = SmartExp $ Pair - (makeUndefLeft x) - (mergeLeft TupRunit gbs a) -mergeLeft (TupRpair (TupRsingle (UnionScalarType x)) gas) TupRunit a - = SmartExp $ Pair - (mergeSumUndefRight x (SmartExp $ Prj PairIdxLeft a)) - (mergeLeft gas TupRunit (SmartExp $ Prj PairIdxRight a)) -mergeLeft (TupRpair (TupRsingle (UnionScalarType (ga :: (UnionScalarType ga)))) gas) (TupRpair (TupRsingle (UnionScalarType (gb :: (UnionScalarType gb)))) gbs) a - = SmartExp $ Pair - (SmartExp $ Union (\y -> scalarSumConcat' @ga @gb y (UnionScalarType gb)) (SmartExp $ Prj PairIdxLeft a)) -- (scalarSumConcat' @ga @gb (UnionScalarType ga)) - (mergeLeft gas gbs (SmartExp $ Prj PairIdxRight a)) - -makeUndefLeft :: UnionScalarType x -> SmartExp (UnionScalar (Undef, x)) -makeUndefLeft x = SmartExp $ Const (UnionScalarType (SuccScalarType (UndefSingleType) x)) (PickScalar POS.Undef) - -mergeSumLeft :: forall a b . UnionScalarType a -> UnionScalarType b -> SmartExp (UnionScalar a) -> SmartExp (UnionScalar (Concat' a b)) -mergeSumLeft ls rs x = SmartExp $ Union (const $ scalarSumConcat ls rs) x - - -scalarSumConcat':: ScalarType (UnionScalar xs) -> ScalarType (UnionScalar ys) -> ScalarType (UnionScalar (Concat' xs ys)) -scalarSumConcat' (UnionScalarType ls) (UnionScalarType rs) = scalarSumConcat ls rs - -scalarSumConcat:: UnionScalarType xs -> UnionScalarType ys -> ScalarType (UnionScalar (Concat' xs ys)) -scalarSumConcat ZeroScalarType rs = UnionScalarType rs -scalarSumConcat (SuccScalarType l ls) rs = UnionScalarType $ SuccScalarType l ls' - where - UnionScalarType ls' = scalarSumConcat ls rs - -merge :: forall a b . TypeR a -> TypeR b -> SmartExp a -> SmartExp b -> SmartExp (Merge' a b) -merge TupRunit TupRunit a b = unExp $ constant () -merge TupRunit (TupRpair (TupRsingle (UnionScalarType x)) gbs) a b - = SmartExp $ Pair - (mergeSumUndefLeft x (SmartExp $ Prj PairIdxLeft b)) - (merge TupRunit gbs a (SmartExp $ Prj PairIdxRight b)) -merge (TupRpair (TupRsingle (UnionScalarType x)) gas) TupRunit a b - = SmartExp $ Pair - (mergeSumUndefRight x (SmartExp $ Prj PairIdxLeft a)) - (merge gas TupRunit (SmartExp $ Prj PairIdxRight a) b) -merge (TupRpair (TupRsingle (UnionScalarType ga)) gas) (TupRpair (TupRsingle (UnionScalarType gb)) gbs) a b - = SmartExp $ Pair - (undefined) -- mergeSum - (merge gas gbs (SmartExp $ Prj PairIdxRight a) (SmartExp $ Prj PairIdxRight b)) - - -mergeSumUndefRight :: UnionScalarType x -> SmartExp (UnionScalar x) -> SmartExp (UnionScalar (Concat' x (Undef, ()))) -mergeSumUndefRight ZeroScalarType a = SmartExp $ Const (UnionScalarType (SuccScalarType (UndefSingleType) ZeroScalarType)) (PickScalar POS.Undef) -mergeSumUndefRight (SuccScalarType x xs) a = SmartExp $ Union scalarTypeUndefRight a - -mergeSumUndefLeft :: UnionScalarType x -> SmartExp (UnionScalar x) -> SmartExp (UnionScalar (Undef, x)) -mergeSumUndefLeft ZeroScalarType a = SmartExp $ Const (UnionScalarType (SuccScalarType (UndefSingleType) ZeroScalarType)) (PickScalar POS.Undef) -mergeSumUndefLeft (SuccScalarType x xs) a = SmartExp $ Union scalarTypeUndefLeft a - -scalarTypeUndefLeft :: ScalarType (UnionScalar a) -> ScalarType (UnionScalar (Undef, a)) -scalarTypeUndefLeft (UnionScalarType x) = UnionScalarType (SuccScalarType (singleType @Undef) x) - -scalarTypeUndefRight :: ScalarType (UnionScalar a) -> ScalarType (UnionScalar (Concat' a (Undef, ()))) -scalarTypeUndefRight (UnionScalarType ZeroScalarType) = UnionScalarType (SuccScalarType (singleType @Undef) ZeroScalarType) -scalarTypeUndefRight (UnionScalarType (SuccScalarType x xs)) - = UnionScalarType (SuccScalarType x xs') - where - (UnionScalarType xs') = scalarTypeUndefRight (UnionScalarType xs) - --- mergeSum :: UnionScalar - -class AllUnionScalar (xs :: Type) where - -instance AllUnionScalar () where - -instance (x' ~ UnionScalar x, IsUnionScalar x', AllUnionScalar xs) => AllUnionScalar (x, xs) where - -class All' (c :: k -> Constraint) (xs :: Type) where - -instance All' c () where - -instance (c x, All' c xs) => All' c (x, xs) where - - - --- ZipWith Concat --- like POSable.Merge, but lifted to tuple lists -type family Merge' (a :: Type) (b :: Type) = (r :: Type) where - Merge' () () = () - Merge' () (UnionScalar b, bs) = (UnionScalar (Undef, b), Merge' () bs) - Merge' (UnionScalar a, as) () = (UnionScalar (Concat' a (Undef, ())), Merge' as ()) - Merge' (UnionScalar a, as) (UnionScalar b, bs) = (UnionScalar (Concat' a b), Merge' as bs) - -type family Concat' (a :: Type) (b :: Type) = (r :: Type) where - Concat' () ys = ys - Concat' (x, xs) ys = (x, Concat' xs ys) - - -type family Head (xs :: [x]) :: x where - Head (x ': xs) = x - -type family Tail (xs :: [x]) :: [x] where - Tail (x ': xs) = xs - -type family MapFlattenSum (x :: [[Type]]) :: [Type] where - MapFlattenSum '[] = '[] - MapFlattenSum (x ': xs) = UnionScalar (FlattenSum x) ': MapFlattenSum xs +mergePairs (PTCons x xs) y = SmartExp (Pair (SmartExp (Union (SmartExp (Prj PairIdxLeft y)))) (mergePairs xs (SmartExp (Prj PairIdxRight y)))) -- like combineProducts, but lifted to the AST buildTAG :: (All POSable xs) => NP Exp xs -> Exp TAG diff --git a/src/Data/Array/Accelerate/Representation/POS.hs b/src/Data/Array/Accelerate/Representation/POS.hs index dad960650..fb4f53394 100644 --- a/src/Data/Array/Accelerate/Representation/POS.hs +++ b/src/Data/Array/Accelerate/Representation/POS.hs @@ -21,7 +21,7 @@ module Data.Array.Accelerate.Representation.POS ( where -import Data.Type.POSable.POSable as POSable -import Data.Type.POSable.Representation -import Data.Type.POSable.Instances () -import Data.Type.POSable.TH +import Generics.POSable.POSable as POSable +import Generics.POSable.Representation +import Generics.POSable.Instances () +import Generics.POSable.TH diff --git a/src/Data/Array/Accelerate/Representation/Shape.hs b/src/Data/Array/Accelerate/Representation/Shape.hs index af06873b0..fa9bd58ff 100644 --- a/src/Data/Array/Accelerate/Representation/Shape.hs +++ b/src/Data/Array/Accelerate/Representation/Shape.hs @@ -28,7 +28,7 @@ import Data.Array.Accelerate.Error import Data.Array.Accelerate.Type import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Representation.POS -import Data.Type.POSable.Representation +import Generics.POSable.Representation import Language.Haskell.TH.Extra import Prelude hiding ( zip ) diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index af3416355..e5906653f 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -514,8 +514,7 @@ data PreSmartExp acc exp t where LiftUnion :: exp t1 -> PreSmartExp acc exp (UnionScalar (t1, ())) - Union :: (ScalarType (UnionScalar t1) -> ScalarType (UnionScalar t2)) - -> exp (UnionScalar t1) + Union :: exp (UnionScalar t1) -> PreSmartExp acc exp (UnionScalar t2) PrjUnion :: exp (UnionScalar (t1, ())) diff --git a/src/Data/Array/Accelerate/Sugar/POS.hs b/src/Data/Array/Accelerate/Sugar/POS.hs index 70ecd806e..1d6762945 100644 --- a/src/Data/Array/Accelerate/Sugar/POS.hs +++ b/src/Data/Array/Accelerate/Sugar/POS.hs @@ -33,9 +33,9 @@ module Data.Array.Accelerate.Sugar.POS import Language.Haskell.TH.Extra hiding ( Type ) -import Data.Type.POSable.POSable as POSable -import Data.Type.POSable.Representation -import Data.Type.POSable.TH +import Generics.POSable.POSable as POSable +import Generics.POSable.Representation +import Generics.POSable.TH import Data.Int import Data.Word From b17614c2ccb3fb7bbbdfb7b31b1b0fa84690932f Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Wed, 1 Jun 2022 14:39:00 +0200 Subject: [PATCH 47/67] makeLeft works :O --- .../Array/Accelerate/Pattern/Matchable.hs | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index c1c94874e..772612453 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -181,13 +181,21 @@ instance Matchable (Maybe Int) where error "Impossible type encountered" -instance Matchable (Maybe a) where - build n x = case sameNat n (Proxy :: Proxy 0) of +instance (POSable a) => Matchable (Maybe a) where + build n fs = case sameNat n (Proxy :: Proxy 0) of + -- Produce a Nothing Just Refl -> - Exp undefined + case sameNat (Proxy @(Choices a)) (Proxy @0) of + -- a has 0 valid choices (which means we cannot create a Just of this type) + -- we ignore the implementation for now, because this is not really useful + Just Refl -> undefined + -- a has at least 1 choice. + -- this means that it always has a tag + Nothing + | Refl :: (EltR (Maybe a) :~: (TAG, FlattenProduct (Fields (Maybe a)))) <- unsafeCoerce Refl + -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (makeLeft @() @a (SmartExp Smart.Nil)))) Nothing -> case sameNat n (Proxy :: Proxy 1) of - Just Refl | (Exp x' :* SOP.Nil) <- x -> Exp undefined - Nothing -> error "Impossible type encountered" + x -> undefined match n (Exp e) = case sameNat n (Proxy :: Proxy 0) of Just Refl -> @@ -211,6 +219,15 @@ instance Matchable (Maybe a) where Nothing -> error "Impossible type encountered" +makeLeft :: forall a b . (POSable a, POSable b) => SmartExp (FlattenProduct (Fields a)) -> SmartExp (FlattenProduct (Merge (Fields a ++ '[]) (Fields b ++ '[]))) +makeLeft x = makeLeft' x (emptyFields @a) (emptyFields @b) + +makeLeft' :: forall a b . SmartExp (FlattenProduct a) -> ProductType a -> ProductType b -> SmartExp (FlattenProduct (Merge (a ++ '[]) (b ++ '[]))) +makeLeft' _ PTNil PTNil = SmartExp Smart.Nil +makeLeft' x PTNil (PTCons _ rs) = SmartExp (Pair (SmartExp (Union (SmartExp (LiftUnion (SmartExp (Const (SingleScalarType UndefSingleType) POS.Undef)))))) (makeLeft' x PTNil rs)) +makeLeft' x (PTCons _ ls) PTNil = SmartExp (Pair (SmartExp (Union (SmartExp $ Prj PairIdxLeft x))) (makeLeft' (SmartExp $ Prj PairIdxRight x) ls PTNil)) +makeLeft' x (PTCons _ ls) (PTCons r rs) = SmartExp (Pair (SmartExp (Union (SmartExp $ Prj PairIdxLeft x))) (makeLeft' (SmartExp $ Prj PairIdxRight x) ls rs)) + instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) where -- type Choices' (Either a b) = OuterChoices (Either a b) From 25bbffb8dbf2550a0cb7d0bafff38e1e196d6a7b Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 2 Jun 2022 11:46:36 +0200 Subject: [PATCH 48/67] more Maybe build --- .../Array/Accelerate/Pattern/Matchable.hs | 47 +++++++++++++------ 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index 772612453..fa8c87a95 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -181,21 +181,29 @@ instance Matchable (Maybe Int) where error "Impossible type encountered" -instance (POSable a) => Matchable (Maybe a) where - build n fs = case sameNat n (Proxy :: Proxy 0) of - -- Produce a Nothing - Just Refl -> - case sameNat (Proxy @(Choices a)) (Proxy @0) of - -- a has 0 valid choices (which means we cannot create a Just of this type) - -- we ignore the implementation for now, because this is not really useful - Just Refl -> undefined - -- a has at least 1 choice. - -- this means that it always has a tag +instance (POSable (Maybe a), POSable a) => Matchable (Maybe a) where + build n fs = case sameNat (Proxy @(Choices a)) (Proxy @0) of + -- a has 0 valid choices (which means we cannot create a Just of this type) + -- we ignore the implementation for now, because this is not really useful + Just Refl -> undefined + -- a has at least 1 choice. + -- this means that Maybe a always has a tag + Nothing + | Refl :: (EltR (Maybe a) :~: (TAG, FlattenProduct (Fields (Maybe a)))) <- unsafeCoerce Refl + -> case sameNat n (Proxy :: Proxy 0) of + -- Produce a Nothing + Just Refl -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (makeLeft @() @a (SmartExp Smart.Nil)))) Nothing - | Refl :: (EltR (Maybe a) :~: (TAG, FlattenProduct (Fields (Maybe a)))) <- unsafeCoerce Refl - -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (makeLeft @() @a (SmartExp Smart.Nil)))) - Nothing -> case sameNat n (Proxy :: Proxy 1) of - x -> undefined + | Exp x :* SOP.Nil <- fs + -> case sameNat n (Proxy :: Proxy 1) of + Just Refl -> case eltRType @a of + -- TOOD: lift type + SingletonType -> undefined + -- TODO: add 1 to the tag + TaglessType -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (makeRight @() @a x))) + -- TODO: remove tag + TaggedType -> undefined + Nothing -> error $ "Impossible situation requested: Maybe has 2 constructors, constructor " ++ show (natVal n) ++ "is out of bound" match n (Exp e) = case sameNat n (Proxy :: Proxy 0) of Just Refl -> @@ -226,7 +234,16 @@ makeLeft' :: forall a b . SmartExp (FlattenProduct a) -> ProductType a -> Produc makeLeft' _ PTNil PTNil = SmartExp Smart.Nil makeLeft' x PTNil (PTCons _ rs) = SmartExp (Pair (SmartExp (Union (SmartExp (LiftUnion (SmartExp (Const (SingleScalarType UndefSingleType) POS.Undef)))))) (makeLeft' x PTNil rs)) makeLeft' x (PTCons _ ls) PTNil = SmartExp (Pair (SmartExp (Union (SmartExp $ Prj PairIdxLeft x))) (makeLeft' (SmartExp $ Prj PairIdxRight x) ls PTNil)) -makeLeft' x (PTCons _ ls) (PTCons r rs) = SmartExp (Pair (SmartExp (Union (SmartExp $ Prj PairIdxLeft x))) (makeLeft' (SmartExp $ Prj PairIdxRight x) ls rs)) +makeLeft' x (PTCons _ ls) (PTCons _ rs) = SmartExp (Pair (SmartExp (Union (SmartExp $ Prj PairIdxLeft x))) (makeLeft' (SmartExp $ Prj PairIdxRight x) ls rs)) + +makeRight :: forall a b . (POSable a, POSable b) => SmartExp (FlattenProduct (Fields b)) -> SmartExp (FlattenProduct (Merge (Fields a ++ '[]) (Fields b ++ '[]))) +makeRight x = makeRight' x (emptyFields @a) (emptyFields @b) + +makeRight' :: forall a b . SmartExp (FlattenProduct b) -> ProductType a -> ProductType b -> SmartExp (FlattenProduct (Merge (a ++ '[]) (b ++ '[]))) +makeRight' _ PTNil PTNil = SmartExp Smart.Nil +makeRight' x PTNil (PTCons _ rs) = SmartExp (Pair (SmartExp (Union (SmartExp $ Prj PairIdxLeft x))) (makeRight' (SmartExp $ Prj PairIdxRight x) PTNil rs)) +makeRight' x (PTCons _ ls) PTNil = SmartExp (Pair (SmartExp (Union (SmartExp (LiftUnion (SmartExp (Const (SingleScalarType UndefSingleType) POS.Undef)))))) (makeRight' x ls PTNil)) +makeRight' x (PTCons _ ls) (PTCons _ rs) = SmartExp (Pair (SmartExp (Union (SmartExp $ Prj PairIdxLeft x))) (makeRight' (SmartExp $ Prj PairIdxRight x) ls rs)) instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) where -- type Choices' (Either a b) = OuterChoices (Either a b) From 7a52b2c2ae4fea38f8c83b03b6a22534867a09e3 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 2 Jun 2022 13:20:58 +0200 Subject: [PATCH 49/67] build implemented for Maybe a --- src/Data/Array/Accelerate/Pattern/Matchable.hs | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index fa8c87a95..afd21ac44 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -196,14 +196,10 @@ instance (POSable (Maybe a), POSable a) => Matchable (Maybe a) where Nothing | Exp x :* SOP.Nil <- fs -> case sameNat n (Proxy :: Proxy 1) of - Just Refl -> case eltRType @a of - -- TOOD: lift type - SingletonType -> undefined - -- TODO: add 1 to the tag - TaglessType -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (makeRight @() @a x))) - -- TODO: remove tag - TaggedType -> undefined - Nothing -> error $ "Impossible situation requested: Maybe has 2 constructors, constructor " ++ show (natVal n) ++ "is out of bound" + -- Add 1 to the tag because we have skipped 1 choice: Nothing + Just Refl -> Exp (SmartExp (Pair (unExp $ mkAdd @TAG (constant 1) (buildTAG fs)) (makeRight @() @a (unTag @a x)))) + Nothing -> error $ "Impossible situation requested: Maybe has 2 constructors, constructor " ++ show (natVal n) ++ "is out of bound" + Nothing -> error "Impossible situation requested: Just a expects a single value, got 0 or more then 1" match n (Exp e) = case sameNat n (Proxy :: Proxy 0) of Just Refl -> @@ -245,6 +241,12 @@ makeRight' x PTNil (PTCons _ rs) = SmartExp (Pair (SmartExp (Union (SmartExp $ P makeRight' x (PTCons _ ls) PTNil = SmartExp (Pair (SmartExp (Union (SmartExp (LiftUnion (SmartExp (Const (SingleScalarType UndefSingleType) POS.Undef)))))) (makeRight' x ls PTNil)) makeRight' x (PTCons _ ls) (PTCons _ rs) = SmartExp (Pair (SmartExp (Union (SmartExp $ Prj PairIdxLeft x))) (makeRight' (SmartExp $ Prj PairIdxRight x) ls rs)) +unTag :: forall x . (POSable x) => SmartExp (EltR x) -> SmartExp (FlattenProduct (Fields x)) +unTag x = case eltRType @x of + SingletonType -> SmartExp (Pair (SmartExp (LiftUnion x)) (SmartExp Smart.Nil)) + TaglessType -> x + TaggedType -> SmartExp $ Prj PairIdxRight x + instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) where -- type Choices' (Either a b) = OuterChoices (Either a b) From e183544518df3c1c2e0376b0c9ab55261b4ece0d Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 2 Jun 2022 15:20:58 +0200 Subject: [PATCH 50/67] match on maybe --- .../Array/Accelerate/Pattern/Matchable.hs | 100 +++++++++++++----- src/Data/Array/Accelerate/Smart.hs | 2 +- src/Data/Array/Accelerate/Sugar/Vec.hs | 7 -- 3 files changed, 72 insertions(+), 37 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index afd21ac44..6ae102f5e 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -93,7 +93,7 @@ instance Matchable Bool where match n (Exp e) = case sameNat n (Proxy :: Proxy 0) of Just Refl -> case e of - SmartExp (Match (0,1) _x) -> Just SOP.Nil + SmartExp (Match n _x) | n == 0 -> Just SOP.Nil SmartExp Match {} -> Nothing @@ -102,7 +102,7 @@ instance Matchable Bool where case sameNat n (Proxy :: Proxy 1) of Just Refl -> case e of - SmartExp (Match (1,2) _x) -> Just SOP.Nil + SmartExp (Match n _x) | n == 1 -> Just SOP.Nil SmartExp Match {} -> Nothing @@ -159,7 +159,8 @@ instance Matchable (Maybe Int) where match n (Exp e) = case sameNat n (Proxy :: Proxy 0) of Just Refl -> case e of - SmartExp (Match (0,1) _x) + SmartExp (Match m _x) + | m == 0 -> Just SOP.Nil SmartExp Match {} -> Nothing @@ -169,9 +170,10 @@ instance Matchable (Maybe Int) where case sameNat n (Proxy :: Proxy 1) of Just Refl -> case e of - SmartExp (Match (1,2) x) + SmartExp (Match m x) + | m == 1 -> Just ( - mkExp (PrjUnion $ SmartExp $ Union (SmartExp $ Prj PairIdxLeft (SmartExp $ Prj PairIdxRight x))) + mkExp (PrjUnion $ SmartExp $ Union (prjLeft (prjRight x))) :* SOP.Nil) SmartExp Match {} -> Nothing @@ -188,8 +190,7 @@ instance (POSable (Maybe a), POSable a) => Matchable (Maybe a) where Just Refl -> undefined -- a has at least 1 choice. -- this means that Maybe a always has a tag - Nothing - | Refl :: (EltR (Maybe a) :~: (TAG, FlattenProduct (Fields (Maybe a)))) <- unsafeCoerce Refl + Nothing | Refl :: (EltR (Maybe a) :~: (TAG, FlattenProduct (Fields (Maybe a)))) <- unsafeCoerce Refl -> case sameNat n (Proxy :: Proxy 0) of -- Produce a Nothing Just Refl -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (makeLeft @() @a (SmartExp Smart.Nil)))) @@ -201,27 +202,56 @@ instance (POSable (Maybe a), POSable a) => Matchable (Maybe a) where Nothing -> error $ "Impossible situation requested: Maybe has 2 constructors, constructor " ++ show (natVal n) ++ "is out of bound" Nothing -> error "Impossible situation requested: Just a expects a single value, got 0 or more then 1" - match n (Exp e) = case sameNat n (Proxy :: Proxy 0) of - Just Refl -> - case e of - SmartExp (Match (0,1) _x) - -> Just SOP.Nil + match n (Exp e) = case sameNat (Proxy @(Choices a)) (Proxy @0) of + -- a has 0 valid choices (which means we cannot create a Just of this type) + -- we ignore the implementation for now, because this is not really useful + Just Refl -> undefined + -- a has at least 1 choice. + -- this means that Maybe a always has a tag + Nothing | Refl :: (EltR (Maybe a) :~: (TAG, FlattenProduct (Fields (Maybe a)))) <- unsafeCoerce Refl + -> case sameNat n (Proxy :: Proxy 0) of + Just Refl -> + case e of + SmartExp (Match m _x) + | m >= 0 + , m < 1 + -> Just SOP.Nil - SmartExp Match {} -> Nothing + SmartExp Match {} -> Nothing - _ -> error "Embedded pattern synonym used outside 'match' context." - Nothing -> -- matchJust - case sameNat n (Proxy :: Proxy 1) of - Just Refl -> - case e of - SmartExp (Match (1,2) x) - -> Just (undefined) - SmartExp Match {} -> Nothing + _ -> error "Embedded pattern synonym used outside 'match' context." + Nothing -> -- matchJust + case sameNat n (Proxy :: Proxy 1) of + Just Refl -> + case e of + SmartExp (Match m x) + | m >= 1 + , m < fromInteger (natVal $ Proxy @(Choices a)) + -- remove one from the tag as we are not in left anymore + -- the `tag` function will apply the new tag if necessary + -> Just (Exp (tag @a (unExp $ mkMin @TAG (constant 1) (Exp $ prjLeft x)) (splitRight @() @a $ prjRight x)) :* SOP.Nil) + SmartExp Match {} -> Nothing - _ -> error "Embedded pattern synonym used outside 'match' context." + _ -> error "Embedded pattern synonym used outside 'match' context." - Nothing -> - error "Impossible type encountered" + Nothing -> + error "Impossible type encountered" + +splitLeft :: forall a b . (POSable a, POSable b) => SmartExp (FlattenProduct (Merge (Fields a ++ '[]) (Fields b ++ '[]))) -> SmartExp (FlattenProduct (Fields a)) +splitLeft x = splitLeft' x (emptyFields @a) (emptyFields @b) + +splitLeft' :: forall a b . SmartExp (FlattenProduct (Merge (a ++ '[]) (b ++ '[]))) -> ProductType a -> ProductType b -> SmartExp (FlattenProduct a) +splitLeft' _ PTNil _ = SmartExp Smart.Nil +splitLeft' x (PTCons _ ls) PTNil = SmartExp $ Pair (SmartExp $ Union (prjLeft x)) (splitLeft' (prjRight x) ls PTNil) +splitLeft' x (PTCons _ ls) (PTCons _ rs) = SmartExp $ Pair (SmartExp $ Union (prjLeft x)) (splitLeft' (prjRight x) ls rs) + +splitRight :: forall a b . (POSable a, POSable b) => SmartExp (FlattenProduct (Merge (Fields a ++ '[]) (Fields b ++ '[]))) -> SmartExp (FlattenProduct (Fields b)) +splitRight x = splitRight' x (emptyFields @a) (emptyFields @b) + +splitRight' :: forall a b . SmartExp (FlattenProduct (Merge (a ++ '[]) (b ++ '[]))) -> ProductType a -> ProductType b -> SmartExp (FlattenProduct b) +splitRight' _ _ PTNil = SmartExp Smart.Nil +splitRight' x PTNil (PTCons _ rs) = SmartExp $ Pair (SmartExp $ Union (prjLeft x)) (splitRight' (prjRight x) PTNil rs) +splitRight' x (PTCons _ ls) (PTCons _ rs) = SmartExp $ Pair (SmartExp $ Union (prjLeft x)) (splitRight' (prjRight x) ls rs) makeLeft :: forall a b . (POSable a, POSable b) => SmartExp (FlattenProduct (Fields a)) -> SmartExp (FlattenProduct (Merge (Fields a ++ '[]) (Fields b ++ '[]))) makeLeft x = makeLeft' x (emptyFields @a) (emptyFields @b) @@ -229,23 +259,35 @@ makeLeft x = makeLeft' x (emptyFields @a) (emptyFields @b) makeLeft' :: forall a b . SmartExp (FlattenProduct a) -> ProductType a -> ProductType b -> SmartExp (FlattenProduct (Merge (a ++ '[]) (b ++ '[]))) makeLeft' _ PTNil PTNil = SmartExp Smart.Nil makeLeft' x PTNil (PTCons _ rs) = SmartExp (Pair (SmartExp (Union (SmartExp (LiftUnion (SmartExp (Const (SingleScalarType UndefSingleType) POS.Undef)))))) (makeLeft' x PTNil rs)) -makeLeft' x (PTCons _ ls) PTNil = SmartExp (Pair (SmartExp (Union (SmartExp $ Prj PairIdxLeft x))) (makeLeft' (SmartExp $ Prj PairIdxRight x) ls PTNil)) -makeLeft' x (PTCons _ ls) (PTCons _ rs) = SmartExp (Pair (SmartExp (Union (SmartExp $ Prj PairIdxLeft x))) (makeLeft' (SmartExp $ Prj PairIdxRight x) ls rs)) +makeLeft' x (PTCons _ ls) PTNil = SmartExp (Pair (SmartExp (Union (prjLeft x))) (makeLeft' (prjRight x) ls PTNil)) +makeLeft' x (PTCons _ ls) (PTCons _ rs) = SmartExp (Pair (SmartExp (Union (prjLeft x))) (makeLeft' (prjRight x) ls rs)) + +prjLeft :: SmartExp (x, xs) -> SmartExp x +prjLeft = SmartExp . Prj PairIdxLeft + +prjRight :: SmartExp (x, xs) -> SmartExp xs +prjRight = SmartExp . Prj PairIdxRight makeRight :: forall a b . (POSable a, POSable b) => SmartExp (FlattenProduct (Fields b)) -> SmartExp (FlattenProduct (Merge (Fields a ++ '[]) (Fields b ++ '[]))) makeRight x = makeRight' x (emptyFields @a) (emptyFields @b) makeRight' :: forall a b . SmartExp (FlattenProduct b) -> ProductType a -> ProductType b -> SmartExp (FlattenProduct (Merge (a ++ '[]) (b ++ '[]))) makeRight' _ PTNil PTNil = SmartExp Smart.Nil -makeRight' x PTNil (PTCons _ rs) = SmartExp (Pair (SmartExp (Union (SmartExp $ Prj PairIdxLeft x))) (makeRight' (SmartExp $ Prj PairIdxRight x) PTNil rs)) +makeRight' x PTNil (PTCons _ rs) = SmartExp (Pair (SmartExp (Union (prjLeft x))) (makeRight' (prjRight x) PTNil rs)) makeRight' x (PTCons _ ls) PTNil = SmartExp (Pair (SmartExp (Union (SmartExp (LiftUnion (SmartExp (Const (SingleScalarType UndefSingleType) POS.Undef)))))) (makeRight' x ls PTNil)) -makeRight' x (PTCons _ ls) (PTCons _ rs) = SmartExp (Pair (SmartExp (Union (SmartExp $ Prj PairIdxLeft x))) (makeRight' (SmartExp $ Prj PairIdxRight x) ls rs)) +makeRight' x (PTCons _ ls) (PTCons _ rs) = SmartExp (Pair (SmartExp (Union (prjLeft x))) (makeRight' (prjRight x) ls rs)) unTag :: forall x . (POSable x) => SmartExp (EltR x) -> SmartExp (FlattenProduct (Fields x)) unTag x = case eltRType @x of SingletonType -> SmartExp (Pair (SmartExp (LiftUnion x)) (SmartExp Smart.Nil)) TaglessType -> x - TaggedType -> SmartExp $ Prj PairIdxRight x + TaggedType -> prjRight x + +tag :: forall x . (POSable x) => SmartExp TAG -> SmartExp (FlattenProduct (Fields x)) -> SmartExp (EltR x) +tag t x = case eltRType @x of + SingletonType -> SmartExp $ PrjUnion $ prjLeft x + TaglessType -> x + TaggedType -> SmartExp $ Pair t x instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) where -- type Choices' (Either a b) = OuterChoices (Either a b) diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index e5906653f..3124630d6 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -492,7 +492,7 @@ data PreSmartExp acc exp t where -> PreSmartExp acc exp t -- Needed for embedded pattern matching - Match :: (TAG, TAG) -- inclusive tag lower bound inclusive, exclusive tag upper bound + Match :: TAG -> exp t -> PreSmartExp acc exp t diff --git a/src/Data/Array/Accelerate/Sugar/Vec.hs b/src/Data/Array/Accelerate/Sugar/Vec.hs index 4bcc3b2e9..e61bb2457 100644 --- a/src/Data/Array/Accelerate/Sugar/Vec.hs +++ b/src/Data/Array/Accelerate/Sugar/Vec.hs @@ -46,9 +46,6 @@ instance VecElt a => POSable (Vec2 a) where emptyFields = PTCons (STSucc (mkGround @a) STZero) (PTCons (STSucc (mkGround @a) STZero) PTNil) - type OuterChoices (Vec2 a) = 1 - outerChoice _ = 0 - -- Elt instance automatically derived from POSable instance instance VecElt a => Elt (Vec2 a) @@ -67,10 +64,6 @@ instance VecElt a => POSable (Vec4 a) where emptyFields = PTCons (STSucc (mkGround @a) STZero) (PTCons (STSucc (mkGround @a) STZero) (PTCons (STSucc (mkGround @a) STZero) (PTCons (STSucc (mkGround @a) STZero) PTNil))) - type OuterChoices (Vec4 a) = 1 - outerChoice _ = 0 - - -- Elt instance automatically derived from POSable instance instance VecElt a => Elt (Vec4 a) From 7f1a7bc32a5281730a4e0f79562982218d2fe55b Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 2 Jun 2022 16:19:46 +0200 Subject: [PATCH 51/67] Matchable instance for polymorphic Either --- .../Array/Accelerate/Pattern/Matchable.hs | 91 +++++++++++++------ 1 file changed, 65 insertions(+), 26 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index 6ae102f5e..abc5c8f0d 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -199,7 +199,7 @@ instance (POSable (Maybe a), POSable a) => Matchable (Maybe a) where -> case sameNat n (Proxy :: Proxy 1) of -- Add 1 to the tag because we have skipped 1 choice: Nothing Just Refl -> Exp (SmartExp (Pair (unExp $ mkAdd @TAG (constant 1) (buildTAG fs)) (makeRight @() @a (unTag @a x)))) - Nothing -> error $ "Impossible situation requested: Maybe has 2 constructors, constructor " ++ show (natVal n) ++ "is out of bound" + Nothing -> error $ "Impossible situation requested: Maybe has 2 constructors, constructor " ++ show (natVal n) ++ "is out of bounds" Nothing -> error "Impossible situation requested: Just a expects a single value, got 0 or more then 1" match n (Exp e) = case sameNat (Proxy @(Choices a)) (Proxy @0) of @@ -229,7 +229,7 @@ instance (POSable (Maybe a), POSable a) => Matchable (Maybe a) where , m < fromInteger (natVal $ Proxy @(Choices a)) -- remove one from the tag as we are not in left anymore -- the `tag` function will apply the new tag if necessary - -> Just (Exp (tag @a (unExp $ mkMin @TAG (constant 1) (Exp $ prjLeft x)) (splitRight @() @a $ prjRight x)) :* SOP.Nil) + -> Just (Exp (tag @a (unExp $ mkMin @TAG (Exp $ prjLeft x) (constant 1)) (splitRight @() @a $ prjRight x)) :* SOP.Nil) SmartExp Match {} -> Nothing _ -> error "Embedded pattern synonym used outside 'match' context." @@ -290,30 +290,69 @@ tag t x = case eltRType @x of TaggedType -> SmartExp $ Pair t x instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) where - -- type Choices' (Either a b) = OuterChoices (Either a b) - - build n fs - -- this is only not true if either left or right has a tag of type Finite 0 - -- types with tags of Finite 0 have no constructors, and are quite useless - | Refl :: (EltR (Either a b) :~: (TAG, FlattenProduct (Fields (Either a b)))) <- unsafeCoerce Refl - = case sameNat n (Proxy :: Proxy 0) of - -- we have chosen constructor 0 (Left) - Just Refl -> case emptyFields @a of - -- Left has no fields - PTNil -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (undefPairs @(Fields b) (emptyFields @b)))) - -- Left has fields - PTCons st pt -> Exp (SmartExp (Pair (unExp $ buildTAG fs) undefined)) - Nothing -> - case sameNat n (Proxy :: Proxy 1) of - -- we have chosen constructor 1 (Right) - Just Refl -> case emptyFields @a of - PTNil -> case fs of - x :* SOP.Nil -> case eltRType @b of -- disambiguate between tagless and tagged b's - SingletonType -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (SmartExp (Pair (SmartExp (Union (SmartExp (LiftUnion (unExp x))))) (SmartExp Smart.Nil))))) - TaglessType -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (mergePairs @(Fields b) (emptyFields @b) (unExp x)))) - TaggedType -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (mergePairs @(Fields b) (emptyFields @b) (SmartExp (Prj PairIdxRight (unExp x)))))) - (PTCons x xs) -> Exp (SmartExp (Pair (unExp $ buildTAG fs) undefined)) - Nothing -> error "Index out of bounds" + + build n fs = case sameNat (Proxy @(Choices a)) (Proxy @0) of + -- a has 0 valid choices (which means we cannot create a Left of this type) + -- we ignore the implementation for now, because this is not really useful + Just Refl -> undefined + Nothing -> case sameNat (Proxy @(Choices b)) (Proxy @0) of + -- b has 0 valid choices (which means we cannot create a Right of this type) + -- we ignore the implementation too + Just Refl -> undefined + -- a and b have at least 1 choice. + -- this means that Either a b always has a tag + Nothing | Refl :: EltR (Either a b) :~: (TAG, FlattenProduct (Fields (Either a b))) <- unsafeCoerce Refl + -> case sameNat n (Proxy :: Proxy 0) of + -- Product a Left + Just Refl + | Exp x :* SOP.Nil <- fs + -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (makeLeft @a @b (unTag @a x)))) + Nothing + | Exp x :* SOP.Nil <- fs + -> case sameNat n (Proxy :: Proxy 1) of + -- Add natVal @(Choices to the tag) + Just Refl -> Exp (SmartExp (Pair (unExp $ mkAdd @TAG (constant $ fromInteger $ natVal $ Proxy @(Choices a)) (buildTag fs)) (makeRight @a @b (unTag @b x)))) + Nothing -> error $ "Impossible situation requested: Maybe has 2 constructors, constructor " ++ show (natVal n) ++ "is out of bounds" + Nothing -> error "Impossible situation requested: Just a expects a single value, got 0 or more then 1" + + match n (Exp e) = case sameNat (Proxy @(Choices a)) (Proxy @0) of + -- a has 0 valid choices (which means we cannot create a Left of this type) + -- we ignore the implementation for now, because this is not really useful + Just Refl -> undefined + Nothing -> case sameNat (Proxy @(Choices b)) (Proxy @0) of + -- b has 0 valid choices (which means we cannot create a Right of this type) + -- we ignore the implementation too + Just Refl -> undefined + -- a and b have at least 1 choice. + -- this means that Either a b always has a tag + Nothing | Refl :: EltR (Either a b) :~: (TAG, FlattenProduct (Fields (Either a b))) <- unsafeCoerce Refl + -> case sameNat n (Proxy :: Proxy 0) of -- matchLeft + Just Refl -> + case e of + SmartExp (Match m x) + | m >= 0 + , m < fromInteger (natVal $ Proxy @(Choices a)) + -> Just (Exp (tag @a (unExp $ mkMin @TAG (Exp $ prjLeft x) (constant $ fromInteger $ natVal $ Proxy @(Choices a))) (splitLeft @a @b $ prjRight x)) :* SOP.Nil) + + SmartExp Match {} -> Nothing + + _ -> error "Embedded pattern synonym used outside 'match' context." + Nothing -> -- matchRight + case sameNat n (Proxy :: Proxy 1) of + Just Refl -> + case e of + SmartExp (Match m x) + | m >= fromInteger (natVal $ Proxy @(Choices a)) + , m < fromInteger (natVal $ Proxy @(Choices b)) + -- remove one from the tag as we are not in left anymore + -- the `tag` function will apply the new tag if necessary + -> Just (Exp (tag @b (unExp $ mkMin @TAG (Exp $ prjLeft x) (constant $ fromInteger $ natVal $ Proxy @(Choices a))) (splitRight @a @b $ prjRight x)) :* SOP.Nil) + SmartExp Match {} -> Nothing + + _ -> error "Embedded pattern synonym used outside 'match' context." + + Nothing -> + error "Impossible type encountered" undefPairs :: forall xs . ProductType xs -> SmartExp (FlattenProduct (Merge '[] (xs ++ '[]))) undefPairs PTNil = SmartExp Smart.Nil From 15716b7ccc73995b267324b3ba1375bc81d60033 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 2 Jun 2022 16:30:17 +0200 Subject: [PATCH 52/67] tag building in terms of tagVal --- .../Array/Accelerate/Pattern/Matchable.hs | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index abc5c8f0d..90c5433f8 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -86,8 +86,6 @@ type family (!!) (xs :: [[Type]]) (y :: Nat) :: [Type] where infixl 9 !! instance Matchable Bool where - -- type Choices' Bool = 2 - build n _ = Exp (SmartExp (Pair (unExp $ constant @TAG (fromInteger $ natVal n)) (SmartExp Smart.Nil))) match n (Exp e) = case sameNat n (Proxy :: Proxy 0) of @@ -226,7 +224,7 @@ instance (POSable (Maybe a), POSable a) => Matchable (Maybe a) where case e of SmartExp (Match m x) | m >= 1 - , m < fromInteger (natVal $ Proxy @(Choices a)) + , m < tagVal @(Choices a) -- remove one from the tag as we are not in left anymore -- the `tag` function will apply the new tag if necessary -> Just (Exp (tag @a (unExp $ mkMin @TAG (Exp $ prjLeft x) (constant 1)) (splitRight @() @a $ prjRight x)) :* SOP.Nil) @@ -311,7 +309,7 @@ instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) | Exp x :* SOP.Nil <- fs -> case sameNat n (Proxy :: Proxy 1) of -- Add natVal @(Choices to the tag) - Just Refl -> Exp (SmartExp (Pair (unExp $ mkAdd @TAG (constant $ fromInteger $ natVal $ Proxy @(Choices a)) (buildTag fs)) (makeRight @a @b (unTag @b x)))) + Just Refl -> Exp (SmartExp (Pair (unExp $ mkAdd @TAG (constant $ tagVal @(Choices a)) (buildTag fs)) (makeRight @a @b (unTag @b x)))) Nothing -> error $ "Impossible situation requested: Maybe has 2 constructors, constructor " ++ show (natVal n) ++ "is out of bounds" Nothing -> error "Impossible situation requested: Just a expects a single value, got 0 or more then 1" @@ -331,8 +329,8 @@ instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) case e of SmartExp (Match m x) | m >= 0 - , m < fromInteger (natVal $ Proxy @(Choices a)) - -> Just (Exp (tag @a (unExp $ mkMin @TAG (Exp $ prjLeft x) (constant $ fromInteger $ natVal $ Proxy @(Choices a))) (splitLeft @a @b $ prjRight x)) :* SOP.Nil) + , m < tagVal @(Choices a) + -> Just (Exp (tag @a (unExp $ mkMin @TAG (Exp $ prjLeft x) (constant $ tagVal @(Choices a))) (splitLeft @a @b $ prjRight x)) :* SOP.Nil) SmartExp Match {} -> Nothing @@ -342,11 +340,11 @@ instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) Just Refl -> case e of SmartExp (Match m x) - | m >= fromInteger (natVal $ Proxy @(Choices a)) - , m < fromInteger (natVal $ Proxy @(Choices b)) + | m >= tagVal @(Choices a) + , m < tagVal @(Choices b) -- remove one from the tag as we are not in left anymore -- the `tag` function will apply the new tag if necessary - -> Just (Exp (tag @b (unExp $ mkMin @TAG (Exp $ prjLeft x) (constant $ fromInteger $ natVal $ Proxy @(Choices a))) (splitRight @a @b $ prjRight x)) :* SOP.Nil) + -> Just (Exp (tag @b (unExp $ mkMin @TAG (Exp $ prjLeft x) (constant $ tagVal @(Choices a))) (splitRight @a @b $ prjRight x)) :* SOP.Nil) SmartExp Match {} -> Nothing _ -> error "Embedded pattern synonym used outside 'match' context." @@ -376,4 +374,7 @@ combineProduct x y = case sameNat (emptyChoices @x) (Proxy :: Proxy 1) of -- tagged type Nothing | Refl :: (EltR x :~: (TAG, y)) <- unsafeCoerce Refl - -> mkAdd (mkExp $ Prj PairIdxLeft (unExp x)) (mkMul y (constant (fromInteger $ natVal (emptyChoices @x)))) + -> mkAdd (mkExp $ Prj PairIdxLeft (unExp x)) (mkMul y (constant (tagVal @(Choices x)))) + +tagVal :: forall a . (KnownNat a) => TAG +tagVal = fromInteger $ natVal (Proxy @a) From 24d245825c591e31242f9d7855d4ff8c0cb67189 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Fri, 3 Jun 2022 12:50:16 +0200 Subject: [PATCH 53/67] pattern matching up to pattern synonyms --- src/Data/Array/Accelerate/Pattern.hs | 198 ++--- .../Array/Accelerate/Pattern/Matchable.hs | 48 +- src/Data/Array/Accelerate/Pattern/Maybe.hs | 4 +- src/Data/Array/Accelerate/Pattern/Ordering.hs | 4 +- src/Data/Array/Accelerate/Pattern/TH.hs | 832 +++++++++--------- .../Array/Accelerate/Representation/Shape.hs | 4 +- .../Array/Accelerate/Representation/Tag.hs | 36 +- src/Data/Array/Accelerate/Smart.hs | 6 +- src/Data/Array/Accelerate/Sugar/Elt.hs | 24 +- src/Data/Array/Accelerate/Sugar/Shape.hs | 4 +- src/Data/Array/Accelerate/Sugar/Vec.hs | 4 +- src/Data/Array/Accelerate/Type.hs | 2 +- 12 files changed, 578 insertions(+), 588 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern.hs b/src/Data/Array/Accelerate/Pattern.hs index e212c0869..dffe94a42 100644 --- a/src/Data/Array/Accelerate/Pattern.hs +++ b/src/Data/Array/Accelerate/Pattern.hs @@ -114,107 +114,107 @@ instance (Elt a, Elt b) => IsPattern Exp (a :. b) (Exp a :. Exp b) where -- IsPattern instances for up to 16-tuples (Acc and Exp). TH takes care of -- the (unremarkable) boilerplate for us. -- -runQ $ do - let - -- Generate instance declarations for IsPattern of the form: - -- instance (Arrays x, ArraysR x ~ (((), ArraysR a), ArraysR b), Arrays a, Arrays b,) => IsPattern Acc x (Acc a, Acc b) - mkAccPattern :: Int -> Q [Dec] - mkAccPattern n = do - a <- newName "a" - let - -- Type variables for the elements - xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - -- Last argument to `IsPattern`, eg (Acc a, Acc b) in the example - b = tupT (map (\t -> [t| Acc $(varT t)|]) xs) - -- Representation as snoc-list of pairs, eg (((), ArraysR a), ArraysR b) - snoc = foldl (\sn t -> [t| ($sn, ArraysR $(varT t)) |]) [t| () |] xs - -- Constraints for the type class, consisting of Arrays constraints on all type variables, - -- and an equality constraint on the representation type of `a` and the snoc representation `snoc`. - context = tupT - $ [t| Arrays $(varT a) |] - : [t| ArraysR $(varT a) ~ $snoc |] - : map (\t -> [t| Arrays $(varT t)|]) xs - -- - get x 0 = [| Acc (SmartAcc (Aprj PairIdxRight $x)) |] - get x i = get [| SmartAcc (Aprj PairIdxLeft $x) |] (i-1) - -- - _x <- newName "_x" - [d| instance $context => IsPattern Acc $(varT a) $b where - builder $(tupP (map (\x -> [p| Acc $(varP x)|]) xs)) = - Acc $(foldl (\vs v -> [| SmartAcc ($vs `Apair` $(varE v)) |]) [| SmartAcc Anil |] xs) - matcher (Acc $(varP _x)) = - $(tupE (map (get (varE _x)) [(n-1), (n-2) .. 0])) - |] +-- runQ $ do +-- let +-- -- Generate instance declarations for IsPattern of the form: +-- -- instance (Arrays x, ArraysR x ~ (((), ArraysR a), ArraysR b), Arrays a, Arrays b,) => IsPattern Acc x (Acc a, Acc b) +-- mkAccPattern :: Int -> Q [Dec] +-- mkAccPattern n = do +-- a <- newName "a" +-- let +-- -- Type variables for the elements +-- xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] +-- -- Last argument to `IsPattern`, eg (Acc a, Acc b) in the example +-- b = tupT (map (\t -> [t| Acc $(varT t)|]) xs) +-- -- Representation as snoc-list of pairs, eg (((), ArraysR a), ArraysR b) +-- snoc = foldl (\sn t -> [t| ($sn, ArraysR $(varT t)) |]) [t| () |] xs +-- -- Constraints for the type class, consisting of Arrays constraints on all type variables, +-- -- and an equality constraint on the representation type of `a` and the snoc representation `snoc`. +-- context = tupT +-- $ [t| Arrays $(varT a) |] +-- : [t| ArraysR $(varT a) ~ $snoc |] +-- : map (\t -> [t| Arrays $(varT t)|]) xs +-- -- +-- get x 0 = [| Acc (SmartAcc (Aprj PairIdxRight $x)) |] +-- get x i = get [| SmartAcc (Aprj PairIdxLeft $x) |] (i-1) +-- -- +-- _x <- newName "_x" +-- [d| instance $context => IsPattern Acc $(varT a) $b where +-- builder $(tupP (map (\x -> [p| Acc $(varP x)|]) xs)) = +-- Acc $(foldl (\vs v -> [| SmartAcc ($vs `Apair` $(varE v)) |]) [| SmartAcc Anil |] xs) +-- matcher (Acc $(varP _x)) = +-- $(tupE (map (get (varE _x)) [(n-1), (n-2) .. 0])) +-- |] - -- Generate instance declarations for IsPattern of the form: - -- instance (Elt x, EltR x ~ (((), EltR a), EltR b), Elt a, Elt b,) => IsPattern Exp x (Exp a, Exp b) - mkExpPattern :: Int -> Q [Dec] - mkExpPattern n = do - a <- newName "a" - let - -- Type variables for the elements - xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - -- Variables for sub-pattern matches - ms = [ mkName ('m' : show i) | i <- [0 .. n-1] ] - tags = foldl (\ts t -> [p| $ts `TagRpair` $(varP t) |]) [p| TagRunit |] ms - -- Last argument to `IsPattern`, eg (Exp, a, Exp b) in the example - b = tupT (map (\t -> [t| Exp $(varT t)|]) xs) - -- Representation as snoc-list of pairs, eg (((), EltR a), EltR b) - snoc = foldl (\sn t -> [t| ($sn, EltR $(varT t)) |]) [t| () |] xs - -- Constraints for the type class, consisting of Elt constraints on all type variables, - -- and an equality constraint on the representation type of `a` and the snoc representation `snoc`. - context = tupT - $ [t| Elt $(varT a) |] - : [t| EltR $(varT a) ~ $snoc |] - : map (\t -> [t| Elt $(varT t)|]) xs - -- - get x 0 = [| SmartExp (Prj PairIdxRight $x) |] - get x i = get [| SmartExp (Prj PairIdxLeft $x) |] (i-1) - -- - _x <- newName "_x" - _y <- newName "_y" - [d| instance $context => IsPattern Exp $(varT a) $b where - builder $(tupP (map (\x -> [p| Exp $(varP x)|]) xs)) = - let _unmatch :: SmartExp a -> SmartExp a - _unmatch (SmartExp (Match _ $(varP _y))) = $(varE _y) - _unmatch x = x - in - Exp $(foldl (\vs v -> [| SmartExp ($vs `Pair` _unmatch $(varE v)) |]) [| SmartExp Nil |] xs) - matcher (Exp $(varP _x)) = - case $(varE _x) of - SmartExp (Match $tags $(varP _y)) - -> $(tupE [[| Exp (SmartExp (Match $(varE m) $(get (varE _x) i))) |] | m <- ms | i <- [(n-1), (n-2) .. 0]]) - _ -> $(tupE [[| Exp $(get (varE _x) i) |] | i <- [(n-1), (n-2) .. 0]]) - |] +-- -- Generate instance declarations for IsPattern of the form: +-- -- instance (Elt x, EltR x ~ (((), EltR a), EltR b), Elt a, Elt b,) => IsPattern Exp x (Exp a, Exp b) +-- mkExpPattern :: Int -> Q [Dec] +-- mkExpPattern n = do +-- a <- newName "a" +-- let +-- -- Type variables for the elements +-- xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] +-- -- Variables for sub-pattern matches +-- ms = [ mkName ('m' : show i) | i <- [0 .. n-1] ] +-- tags = foldl (\ts t -> [p| $ts `TagRpair` $(varP t) |]) [p| TagRunit |] ms +-- -- Last argument to `IsPattern`, eg (Exp, a, Exp b) in the example +-- b = tupT (map (\t -> [t| Exp $(varT t)|]) xs) +-- -- Representation as snoc-list of pairs, eg (((), EltR a), EltR b) +-- snoc = foldl (\sn t -> [t| ($sn, EltR $(varT t)) |]) [t| () |] xs +-- -- Constraints for the type class, consisting of Elt constraints on all type variables, +-- -- and an equality constraint on the representation type of `a` and the snoc representation `snoc`. +-- context = tupT +-- $ [t| Elt $(varT a) |] +-- : [t| EltR $(varT a) ~ $snoc |] +-- : map (\t -> [t| Elt $(varT t)|]) xs +-- -- +-- get x 0 = [| SmartExp (Prj PairIdxRight $x) |] +-- get x i = get [| SmartExp (Prj PairIdxLeft $x) |] (i-1) +-- -- +-- _x <- newName "_x" +-- _y <- newName "_y" +-- [d| instance $context => IsPattern Exp $(varT a) $b where +-- builder $(tupP (map (\x -> [p| Exp $(varP x)|]) xs)) = +-- let _unmatch :: SmartExp a -> SmartExp a +-- _unmatch (SmartExp (Match _ $(varP _y))) = $(varE _y) +-- _unmatch x = x +-- in +-- Exp $(foldl (\vs v -> [| SmartExp ($vs `Pair` _unmatch $(varE v)) |]) [| SmartExp Nil |] xs) +-- matcher (Exp $(varP _x)) = +-- case $(varE _x) of +-- SmartExp (Match $tags $(varP _y)) +-- -> $(tupE [[| Exp (SmartExp (Match $(varE m) $(get (varE _x) i))) |] | m <- ms | i <- [(n-1), (n-2) .. 0]]) +-- _ -> $(tupE [[| Exp $(get (varE _x) i) |] | i <- [(n-1), (n-2) .. 0]]) +-- |] - -- Generate instance declarations for IsVector of the form: - -- instance (Elt v, EltR v ~ Vec 2 a, Elt a) => IsVector Exp v (Exp a, Exp a) - mkVecPattern :: Int -> Q [Dec] - mkVecPattern n = do - a <- newName "a" - v <- newName "v" - let - -- Last argument to `IsVector`, eg (Exp, a, Exp a) in the example - tup = tupT (replicate n ([t| Exp $(varT a)|])) - -- Representation as a vector, eg (Vec 2 a) - vec = [t| Vec $(litT (numTyLit (fromIntegral n))) $(varT a) |] - -- Constraints for the type class, consisting of Elt constraints on all type variables, - -- and an equality constraint on the representation type of `a` and the vector representation `vec`. - context = [t| (Elt $(varT v), VecElt $(varT a), EltR $(varT v) ~ $vec) |] - -- - vecR = foldr appE ([| VecRnil |] `appE` (varE 'singleType `appTypeE` varT a)) (replicate n [| VecRsucc |]) - tR = tupT (replicate n (varT a)) - -- - [d| instance $context => IsVector Exp $(varT v) $tup where - vpack x = case builder x :: Exp $tR of - Exp x' -> Exp (SmartExp (VecPack $vecR x')) - vunpack (Exp x) = matcher (Exp (SmartExp (VecUnpack $vecR x)) :: Exp $tR) - |] - -- - es <- mapM mkExpPattern [0..16] - as <- mapM mkAccPattern [0..16] - vs <- mapM mkVecPattern [2,3,4,8,16] - return $ concat (es ++ as ++ vs) +-- -- Generate instance declarations for IsVector of the form: +-- -- instance (Elt v, EltR v ~ Vec 2 a, Elt a) => IsVector Exp v (Exp a, Exp a) +-- mkVecPattern :: Int -> Q [Dec] +-- mkVecPattern n = do +-- a <- newName "a" +-- v <- newName "v" +-- let +-- -- Last argument to `IsVector`, eg (Exp, a, Exp a) in the example +-- tup = tupT (replicate n ([t| Exp $(varT a)|])) +-- -- Representation as a vector, eg (Vec 2 a) +-- vec = [t| Vec $(litT (numTyLit (fromIntegral n))) $(varT a) |] +-- -- Constraints for the type class, consisting of Elt constraints on all type variables, +-- -- and an equality constraint on the representation type of `a` and the vector representation `vec`. +-- context = [t| (Elt $(varT v), VecElt $(varT a), EltR $(varT v) ~ $vec) |] +-- -- +-- vecR = foldr appE ([| VecRnil |] `appE` (varE 'singleType `appTypeE` varT a)) (replicate n [| VecRsucc |]) +-- tR = tupT (replicate n (varT a)) +-- -- +-- [d| instance $context => IsVector Exp $(varT v) $tup where +-- vpack x = case builder x :: Exp $tR of +-- Exp x' -> Exp (SmartExp (VecPack $vecR x')) +-- vunpack (Exp x) = matcher (Exp (SmartExp (VecUnpack $vecR x)) :: Exp $tR) +-- |] +-- -- +-- es <- mapM mkExpPattern [0..16] +-- as <- mapM mkAccPattern [0..16] +-- vs <- mapM mkVecPattern [2,3,4,8,16] +-- return $ concat (es ++ as ++ vs) -- | Specialised pattern synonyms for tuples, which may be more convenient to diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index 90c5433f8..a5b58bfc5 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -67,7 +67,7 @@ class Matchable a where buildTag :: SOP.All POSable xs => NP Exp xs -> Exp TAG buildTag SOP.Nil = constant 0 -- exp of 0 :: Finite 1 -buildTag (((Exp x) :: (Exp x)) :* (xs :: xs)) = case sameNat (emptyChoices @x) (Proxy :: Proxy 1) of +buildTag (((Exp x) :: (Exp x)) :* (xs :: xs)) = case sameNat (Proxy @(Choices x)) (Proxy :: Proxy 1) of -- x doesn't contain a tag, skip Just Refl -> buildTag xs @@ -91,7 +91,10 @@ instance Matchable Bool where match n (Exp e) = case sameNat n (Proxy :: Proxy 0) of Just Refl -> case e of - SmartExp (Match n _x) | n == 0 -> Just SOP.Nil + SmartExp (Match (TagR l u) _x) + | l == 0 + , u == 1 + -> Just SOP.Nil SmartExp Match {} -> Nothing @@ -100,7 +103,10 @@ instance Matchable Bool where case sameNat n (Proxy :: Proxy 1) of Just Refl -> case e of - SmartExp (Match n _x) | n == 1 -> Just SOP.Nil + SmartExp (Match (TagR l u) _x) + | l == 1 + , u == 2 + -> Just SOP.Nil SmartExp Match {} -> Nothing @@ -157,8 +163,9 @@ instance Matchable (Maybe Int) where match n (Exp e) = case sameNat n (Proxy :: Proxy 0) of Just Refl -> case e of - SmartExp (Match m _x) - | m == 0 + SmartExp (Match (TagR l u) _x) + | l == 0 + , u == 1 -> Just SOP.Nil SmartExp Match {} -> Nothing @@ -168,8 +175,9 @@ instance Matchable (Maybe Int) where case sameNat n (Proxy :: Proxy 1) of Just Refl -> case e of - SmartExp (Match m x) - | m == 1 + SmartExp (Match (TagR l u) x) + | l == 1 + , u == 2 -> Just ( mkExp (PrjUnion $ SmartExp $ Union (prjLeft (prjRight x))) :* SOP.Nil) @@ -210,9 +218,9 @@ instance (POSable (Maybe a), POSable a) => Matchable (Maybe a) where -> case sameNat n (Proxy :: Proxy 0) of Just Refl -> case e of - SmartExp (Match m _x) - | m >= 0 - , m < 1 + SmartExp (Match (TagR l u) _x) + | l == 0 + , u == 1 -> Just SOP.Nil SmartExp Match {} -> Nothing @@ -222,9 +230,9 @@ instance (POSable (Maybe a), POSable a) => Matchable (Maybe a) where case sameNat n (Proxy :: Proxy 1) of Just Refl -> case e of - SmartExp (Match m x) - | m >= 1 - , m < tagVal @(Choices a) + SmartExp (Match (TagR l u) x) + | l == 1 + , u == tagVal @(Choices a) -- remove one from the tag as we are not in left anymore -- the `tag` function will apply the new tag if necessary -> Just (Exp (tag @a (unExp $ mkMin @TAG (Exp $ prjLeft x) (constant 1)) (splitRight @() @a $ prjRight x)) :* SOP.Nil) @@ -327,9 +335,9 @@ instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) -> case sameNat n (Proxy :: Proxy 0) of -- matchLeft Just Refl -> case e of - SmartExp (Match m x) - | m >= 0 - , m < tagVal @(Choices a) + SmartExp (Match (TagR l u) x) + | l == 0 + , u == tagVal @(Choices a) -> Just (Exp (tag @a (unExp $ mkMin @TAG (Exp $ prjLeft x) (constant $ tagVal @(Choices a))) (splitLeft @a @b $ prjRight x)) :* SOP.Nil) SmartExp Match {} -> Nothing @@ -339,9 +347,9 @@ instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) case sameNat n (Proxy :: Proxy 1) of Just Refl -> case e of - SmartExp (Match m x) - | m >= tagVal @(Choices a) - , m < tagVal @(Choices b) + SmartExp (Match (TagR l u) x) + | l == tagVal @(Choices a) + , u == tagVal @(Choices b) -- remove one from the tag as we are not in left anymore -- the `tag` function will apply the new tag if necessary -> Just (Exp (tag @b (unExp $ mkMin @TAG (Exp $ prjLeft x) (constant $ tagVal @(Choices a))) (splitRight @a @b $ prjRight x)) :* SOP.Nil) @@ -368,7 +376,7 @@ buildTAG (x :* xs) = combineProduct x (buildTAG xs) -- like Finite.combineProduct, but lifted to the AST -- basically `tag x + tag y * natVal x` combineProduct :: forall x. (POSable x) => Exp x -> Exp TAG -> Exp TAG -combineProduct x y = case sameNat (emptyChoices @x) (Proxy :: Proxy 1) of +combineProduct x y = case sameNat (Proxy @(Choices x)) (Proxy :: Proxy 1) of -- untagged type: `tag x = 0`, `natVal x = 1` Just Refl -> y -- tagged type diff --git a/src/Data/Array/Accelerate/Pattern/Maybe.hs b/src/Data/Array/Accelerate/Pattern/Maybe.hs index 67e341d64..642f88028 100644 --- a/src/Data/Array/Accelerate/Pattern/Maybe.hs +++ b/src/Data/Array/Accelerate/Pattern/Maybe.hs @@ -22,5 +22,7 @@ module Data.Array.Accelerate.Pattern.Maybe ( import Data.Array.Accelerate.Pattern.TH -mkPattern ''Maybe +-- mkPattern ''Maybe +pattern Nothing_ <- match (Proxy :: Proxy 0) SOP.Nil where + Nothing_{} = build (Proxy @0) SOP.Nil diff --git a/src/Data/Array/Accelerate/Pattern/Ordering.hs b/src/Data/Array/Accelerate/Pattern/Ordering.hs index 2407cf9e9..e6c783043 100644 --- a/src/Data/Array/Accelerate/Pattern/Ordering.hs +++ b/src/Data/Array/Accelerate/Pattern/Ordering.hs @@ -16,11 +16,11 @@ module Data.Array.Accelerate.Pattern.Ordering ( - Ordering, pattern LT_, pattern EQ_, pattern GT_, + -- Ordering, pattern LT_, pattern EQ_, pattern GT_, ) where import Data.Array.Accelerate.Pattern.TH -mkPattern ''Ordering +-- mkPattern ''Ordering diff --git a/src/Data/Array/Accelerate/Pattern/TH.hs b/src/Data/Array/Accelerate/Pattern/TH.hs index c9ec918ac..f14de628d 100644 --- a/src/Data/Array/Accelerate/Pattern/TH.hs +++ b/src/Data/Array/Accelerate/Pattern/TH.hs @@ -12,8 +12,8 @@ module Data.Array.Accelerate.Pattern.TH ( - mkPattern, - mkPatterns, + -- mkPattern, + -- mkPatterns, ) where @@ -38,418 +38,418 @@ import GHC.Stack -- | As 'mkPattern', but for a list of types -- -mkPatterns :: [Name] -> DecsQ -mkPatterns nms = concat <$> mapM mkPattern nms - --- | Generate pattern synonyms for the given simple (Haskell'98) sum or --- product data type. --- --- Constructor and record selectors are renamed to add a trailing --- underscore if it does not exist, or to remove it if it does. For infix --- constructors, the name is prepended with a colon ':'. For example: --- --- > data Point = Point { xcoord_ :: Float, ycoord_ :: Float } --- > deriving (Generic, Elt) --- --- Will create the pattern synonym: --- --- > Point_ :: Exp Float -> Exp Float -> Exp Point --- --- together with the selector functions --- --- > xcoord :: Exp Point -> Exp Float --- > ycoord :: Exp Point -> Exp Float --- -mkPattern :: Name -> DecsQ -mkPattern nm = do - info <- reify nm - case info of - TyConI dec -> mkDec dec - _ -> fail "mkPatterns: expected the name of a newtype or datatype" - -mkDec :: Dec -> DecsQ -mkDec dec = - case dec of - DataD _ nm tv _ cs _ -> mkDataD nm tv cs - NewtypeD _ nm tv _ c _ -> mkNewtypeD nm tv c - _ -> fail "mkPatterns: expected the name of a newtype or datatype" - -mkNewtypeD :: Name -> [TyVarBndr ()] -> Con -> DecsQ -mkNewtypeD tn tvs c = mkDataD tn tvs [c] - -mkDataD :: Name -> [TyVarBndr ()] -> [Con] -> DecsQ -mkDataD tn tvs cs = do - (pats, decs) <- unzip <$> go cs - comp <- pragCompleteD pats Nothing - return $ comp : concat decs - where - -- For single-constructor types we create the pattern synonym for the - -- type directly in terms of Pattern - go [] = fail "mkPatterns: empty data declarations not supported" - go [c] = return <$> mkConP tn tvs c - go _ = go' [] (map fieldTys cs) ctags cs - - -- For sum-types, when creating the pattern for an individual - -- constructor we need to know about the types of the fields all other - -- constructors as well - go' prev (this:next) (tag:tags) (con:cons) = do - r <- mkConS tn tvs prev next tag con - rs <- go' (this:prev) next tags cons - return (r : rs) - go' _ [] [] [] = return [] - go' _ _ _ _ = fail "mkPatterns: unexpected error" - - fieldTys (NormalC _ fs) = map snd fs - fieldTys (RecC _ fs) = map (\(_,_,t) -> t) fs - fieldTys (InfixC a _ b) = [snd a, snd b] - fieldTys _ = fail "mkPatterns: only constructors for \"vanilla\" syntax are supported" - - -- TODO: The GTags class demonstrates a way to generate the tags for - -- a given constructor, rather than backwards-engineering the structure - -- as we've done here. We should use that instead! - -- - ctags = - let n = length cs - m = n `quot` 2 - l = take m (iterate (True:) [False]) - r = take (n-m) (iterate (True:) [True]) - -- - bitsToTag = foldl' f 0 - where - f i False = i `shiftL` 1 - f i True = setBit (i `shiftL` 1) 0 - in - map bitsToTag (l ++ r) - - -mkConP :: Name -> [TyVarBndr ()] -> Con -> Q (Name, [Dec]) -mkConP tn' tvs' con' = do - checkExts [ PatternSynonyms ] - case con' of - NormalC cn fs -> mkNormalC tn' cn (map tyVarBndrName tvs') (map snd fs) - RecC cn fs -> mkRecC tn' cn (map tyVarBndrName tvs') (map (rename . fst3) fs) (map thd3 fs) - InfixC a cn b -> mkInfixC tn' cn (map tyVarBndrName tvs') [snd a, snd b] - _ -> fail "mkPatterns: only constructors for \"vanilla\" syntax are supported" - where - mkNormalC :: Name -> Name -> [Name] -> [Type] -> Q (Name, [Dec]) - mkNormalC tn cn tvs fs = do - xs <- replicateM (length fs) (newName "_x") - r <- sequence [ patSynSigD pat sig - , patSynD pat - (prefixPatSyn xs) - implBidir - [p| Pattern $(tupP (map varP xs)) |] - ] - return (pat, r) - where - pat = rename cn - sig = forallT - (map (`plainInvisTV` specifiedSpec) tvs) - (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) - (foldr (\t ts -> [t| $t -> $ts |]) - [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] - (map (\t -> [t| Exp $(return t) |]) fs)) - - mkRecC :: Name -> Name -> [Name] -> [Name] -> [Type] -> Q (Name, [Dec]) - mkRecC tn cn tvs xs fs = do - r <- sequence [ patSynSigD pat sig - , patSynD pat - (recordPatSyn xs) - implBidir - [p| Pattern $(tupP (map varP xs)) |] - ] - return (pat, r) - where - pat = rename cn - sig = forallT - (map (`plainInvisTV` specifiedSpec) tvs) - (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) - (foldr (\t ts -> [t| $t -> $ts |]) - [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] - (map (\t -> [t| Exp $(return t) |]) fs)) - - mkInfixC :: Name -> Name -> [Name] -> [Type] -> Q (Name, [Dec]) - mkInfixC tn cn tvs fs = do - mf <- reifyFixity cn - _a <- newName "_a" - _b <- newName "_b" - r <- sequence [ patSynSigD pat sig - , patSynD pat - (infixPatSyn _a _b) - implBidir - [p| Pattern $(tupP [varP _a, varP _b]) |] - ] - r' <- case mf of - Nothing -> return r - Just f -> return (InfixD f pat : r) - return (pat, r') - where - pat = mkName (':' : nameBase cn) - sig = forallT - (map (`plainInvisTV` specifiedSpec) tvs) - (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) - (foldr (\t ts -> [t| $t -> $ts |]) - [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] - (map (\t -> [t| Exp $(return t) |]) fs)) - -mkConS :: Name -> [TyVarBndr ()] -> [[Type]] -> [[Type]] -> Word8 -> Con -> Q (Name, [Dec]) -mkConS tn' tvs' prev' next' tag' con' = do - checkExts [GADTs, PatternSynonyms, ScopedTypeVariables, TypeApplications, ViewPatterns] - case con' of - NormalC cn fs -> mkNormalC tn' cn tag' (map tyVarBndrName tvs') prev' (map snd fs) next' - RecC cn fs -> mkRecC tn' cn tag' (map tyVarBndrName tvs') (map (rename . fst3) fs) prev' (map thd3 fs) next' - InfixC a cn b -> mkInfixC tn' cn tag' (map tyVarBndrName tvs') prev' [snd a, snd b] next' - _ -> fail "mkPatterns: only constructors for \"vanilla\" syntax are supported" - where - mkNormalC :: Name -> Name -> Word8 -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) - mkNormalC tn cn tag tvs ps fs ns = do - let pat = rename cn - (fun_build, dec_build) <- mkBuild tn (nameBase cn) tvs tag ps fs ns - (fun_match, dec_match) <- mkMatch tn (nameBase pat) (nameBase cn) tvs tag ps fs ns - dec_pat <- mkNormalC_pattern tn pat tvs fs fun_build fun_match - return $ (pat, concat [dec_pat, dec_build, dec_match]) - - mkRecC :: Name -> Name -> Word8 -> [Name] -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) - mkRecC tn cn tag tvs xs ps fs ns = do - let pat = rename cn - (fun_build, dec_build) <- mkBuild tn (nameBase cn) tvs tag ps fs ns - (fun_match, dec_match) <- mkMatch tn (nameBase pat) (nameBase cn) tvs tag ps fs ns - dec_pat <- mkRecC_pattern tn pat tvs xs fs fun_build fun_match - return $ (pat, concat [dec_pat, dec_build, dec_match]) - - mkInfixC :: Name -> Name -> Word8 -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) - mkInfixC tn cn tag tvs ps fs ns = do - let pat = mkName (':' : nameBase cn) - (fun_build, dec_build) <- mkBuild tn (zencode (nameBase cn)) tvs tag ps fs ns - (fun_match, dec_match) <- mkMatch tn ("(" ++ nameBase pat ++ ")") (zencode (nameBase cn)) tvs tag ps fs ns - dec_pat <- mkInfixC_pattern tn cn pat tvs fs fun_build fun_match - return $ (pat, concat [dec_pat, dec_build, dec_match]) - - mkNormalC_pattern :: Name -> Name -> [Name] -> [Type] -> Name -> Name -> Q [Dec] - mkNormalC_pattern tn pat tvs fs build match = do - xs <- replicateM (length fs) (newName "_x") - r <- sequence [ patSynSigD pat sig - , patSynD pat - (prefixPatSyn xs) - (explBidir [clause [] (normalB (varE build)) []]) - (parensP $ viewP (varE match) [p| Just $(tupP (map varP xs)) |]) - ] - return r - where - sig = forallT - (map (`plainInvisTV` specifiedSpec) tvs) - (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs)) - (foldr (\t ts -> [t| $t -> $ts |]) - [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] - (map (\t -> [t| Exp $(return t) |]) fs)) - - mkRecC_pattern :: Name -> Name -> [Name] -> [Name] -> [Type] -> Name -> Name -> Q [Dec] - mkRecC_pattern tn pat tvs xs fs build match = do - r <- sequence [ patSynSigD pat sig - , patSynD pat - (recordPatSyn xs) - (explBidir [clause [] (normalB (varE build)) []]) - (parensP $ viewP (varE match) [p| Just $(tupP (map varP xs)) |]) - ] - return r - where - sig = forallT - (map (`plainInvisTV` specifiedSpec) tvs) - (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs)) - (foldr (\t ts -> [t| $t -> $ts |]) - [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] - (map (\t -> [t| Exp $(return t) |]) fs)) - - mkInfixC_pattern :: Name -> Name -> Name -> [Name] -> [Type] -> Name -> Name -> Q [Dec] - mkInfixC_pattern tn cn pat tvs fs build match = do - mf <- reifyFixity cn - _a <- newName "_a" - _b <- newName "_b" - r <- sequence [ patSynSigD pat sig - , patSynD pat - (infixPatSyn _a _b) - (explBidir [clause [] (normalB (varE build)) []]) - (parensP $ viewP (varE match) [p| Just $(tupP [varP _a, varP _b]) |]) - ] - r' <- case mf of - Nothing -> return r - Just f -> return (InfixD f pat : r) - return r' - where - sig = forallT - (map (`plainInvisTV` specifiedSpec) tvs) - (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs)) - (foldr (\t ts -> [t| $t -> $ts |]) - [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] - (map (\t -> [t| Exp $(return t) |]) fs)) - - mkBuild :: Name -> String -> [Name] -> Word8 -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) - mkBuild tn cn tvs tag fs0 fs fs1 = do - fun <- newName ("_build" ++ cn) - xs <- replicateM (length fs) (newName "_x") - let - vs = foldl' (\es e -> [| SmartExp ($es `Pair` $e) |]) [| SmartExp Nil |] - $ map (\t -> [| unExp $(varE 'undef `appTypeE` return t) |] ) (concat (reverse fs0)) - ++ map varE xs - ++ map (\t -> [| unExp $(varE 'undef `appTypeE` return t) |] ) (concat fs1) - - tagged = [| Exp $ SmartExp $ Pair (SmartExp (Const (SingleScalarType (NumSingleType (IntegralNumType TypeTAG))) $(litE (IntegerL (toInteger tag))))) $vs |] - body = clause (map (\x -> [p| (Exp $(varP x)) |]) xs) (normalB tagged) [] - - r <- sequence [ sigD fun sig - , funD fun [body] - ] - return (fun, r) - where - sig = forallT - (map (`plainInvisTV` specifiedSpec) tvs) - (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) - (foldr (\t ts -> [t| $t -> $ts |]) - [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] - (map (\t -> [t| Exp $(return t) |]) fs)) - - - mkMatch :: Name -> String -> String -> [Name] -> Word8 -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) - mkMatch tn pn cn tvs tag fs0 fs fs1 = do - fun <- newName ("_match" ++ cn) - e <- newName "_e" - x <- newName "_x" - (ps,es) <- extract vs [| Prj PairIdxRight $(varE x) |] [] [] - unbind <- isExtEnabled RebindableSyntax - let - eqE = if unbind then letE [funD (mkName "==") [clause [] (normalB (varE '(==))) []]] else id - lhs = [p| (Exp $(varP e)) |] - body = normalB $ eqE $ caseE (varE e) - [ TH.match (conP 'SmartExp [(conP 'Match [matchP ps, varP x])]) (normalB [| Just $(tupE es) |]) [] - , TH.match (conP 'SmartExp [(recP 'Match [])]) (normalB [| Nothing |]) [] - , TH.match wildP (normalB [| error $error_msg |]) [] - ] - - r <- sequence [ sigD fun sig - , funD fun [clause [lhs] body []] - ] - return (fun, r) - where - sig = forallT - (map (`plainInvisTV` specifiedSpec) tvs) - (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs)) - [t| Exp $(foldl' appT (conT tn) (map varT tvs)) -> Maybe $(tupT (map (\t -> [t| Exp $(return t) |]) fs)) |] - - matchP us = [p| TagRtag $(litP (IntegerL (toInteger tag))) $pat |] - where - pat = [p| $(foldl (\ps p -> [p| TagRpair $ps $p |]) [p| TagRunit |] us) |] - - extract [] _ ps es = return (ps, es) - extract (u:us) x ps es = do - _u <- newName "_u" - let x' = [| Prj PairIdxLeft (SmartExp $x) |] - if not u - then extract us x' (wildP:ps) es - else extract us x' (varP _u:ps) ([| Exp (SmartExp (Match $(varE _u) (SmartExp (Prj PairIdxRight (SmartExp $x))))) |] : es) - - vs = reverse - $ [ False | _ <- concat fs0 ] ++ [ True | _ <- fs ] ++ [ False | _ <- concat fs1 ] - - error_msg = - let pv = unwords - $ take (length fs + 1) - $ concatMap (map reverse) - $ iterate (concatMap (\xs -> [ x:xs | x <- ['a'..'z'] ])) [""] - in stringE $ unlines - [ "Embedded pattern synonym used outside 'match' context." - , "" - , "To use case statements in the embedded language the case statement must" - , "be applied as an n-ary function to the 'match' operator. For single" - , "argument case statements this can be done inline using LambdaCase, for" - , "example:" - , "" - , "> x & match \\case" - , printf "> %s%s -> ..." pn pv - , printf "> _%s -> ..." (replicate (length pn + length pv - 1) ' ') - ] - -fst3 :: (a,b,c) -> a -fst3 (a,_,_) = a - -thd3 :: (a,b,c) -> c -thd3 (_,_,c) = c - -rename :: Name -> Name -rename nm = - let - split acc [] = (reverse acc, '\0') -- shouldn't happen - split acc [l] = (reverse acc, l) - split acc (l:ls) = split (l:acc) ls - -- - nm' = nameBase nm - (base, suffix) = split [] nm' - in - case suffix of - '_' -> mkName base - _ -> mkName (nm' ++ "_") - -checkExts :: [Extension] -> Q () -checkExts req = do - enabled <- extsEnabled - let missing = req \\ enabled - unless (null missing) . fail . unlines - $ printf "You must enable the following language extensions to generate pattern synonyms:" - : map (printf " {-# LANGUAGE %s #-}" . show) missing - --- A simplified version of that stolen from GHC/Utils/Encoding.hs --- -type EncodedString = String - -zencode :: String -> EncodedString -zencode [] = [] -zencode (h:rest) = encode_digit h ++ go rest - where - go [] = [] - go (c:cs) = encode_ch c ++ go cs - -unencoded_char :: Char -> Bool -unencoded_char 'z' = False -unencoded_char 'Z' = False -unencoded_char c = isAlphaNum c - -encode_digit :: Char -> EncodedString -encode_digit c | isDigit c = encode_as_unicode_char c - | otherwise = encode_ch c - -encode_ch :: Char -> EncodedString -encode_ch c | unencoded_char c = [c] -- Common case first -encode_ch '(' = "ZL" -encode_ch ')' = "ZR" -encode_ch '[' = "ZM" -encode_ch ']' = "ZN" -encode_ch ':' = "ZC" -encode_ch 'Z' = "ZZ" -encode_ch 'z' = "zz" -encode_ch '&' = "za" -encode_ch '|' = "zb" -encode_ch '^' = "zc" -encode_ch '$' = "zd" -encode_ch '=' = "ze" -encode_ch '>' = "zg" -encode_ch '#' = "zh" -encode_ch '.' = "zi" -encode_ch '<' = "zl" -encode_ch '-' = "zm" -encode_ch '!' = "zn" -encode_ch '+' = "zp" -encode_ch '\'' = "zq" -encode_ch '\\' = "zr" -encode_ch '/' = "zs" -encode_ch '*' = "zt" -encode_ch '_' = "zu" -encode_ch '%' = "zv" -encode_ch c = encode_as_unicode_char c - -encode_as_unicode_char :: Char -> EncodedString -encode_as_unicode_char c - = 'z' - : if isDigit (head hex_str) then hex_str - else '0':hex_str - where - hex_str = showHex (ord c) "U" +-- mkPatterns :: [Name] -> DecsQ +-- mkPatterns nms = concat <$> mapM mkPattern nms + +-- -- | Generate pattern synonyms for the given simple (Haskell'98) sum or +-- -- product data type. +-- -- +-- -- Constructor and record selectors are renamed to add a trailing +-- -- underscore if it does not exist, or to remove it if it does. For infix +-- -- constructors, the name is prepended with a colon ':'. For example: +-- -- +-- -- > data Point = Point { xcoord_ :: Float, ycoord_ :: Float } +-- -- > deriving (Generic, Elt) +-- -- +-- -- Will create the pattern synonym: +-- -- +-- -- > Point_ :: Exp Float -> Exp Float -> Exp Point +-- -- +-- -- together with the selector functions +-- -- +-- -- > xcoord :: Exp Point -> Exp Float +-- -- > ycoord :: Exp Point -> Exp Float +-- -- +-- mkPattern :: Name -> DecsQ +-- mkPattern nm = do +-- info <- reify nm +-- case info of +-- TyConI dec -> mkDec dec +-- _ -> fail "mkPatterns: expected the name of a newtype or datatype" + +-- mkDec :: Dec -> DecsQ +-- mkDec dec = +-- case dec of +-- DataD _ nm tv _ cs _ -> mkDataD nm tv cs +-- NewtypeD _ nm tv _ c _ -> mkNewtypeD nm tv c +-- _ -> fail "mkPatterns: expected the name of a newtype or datatype" + +-- mkNewtypeD :: Name -> [TyVarBndr ()] -> Con -> DecsQ +-- mkNewtypeD tn tvs c = mkDataD tn tvs [c] + +-- mkDataD :: Name -> [TyVarBndr ()] -> [Con] -> DecsQ +-- mkDataD tn tvs cs = do +-- (pats, decs) <- unzip <$> go cs +-- comp <- pragCompleteD pats Nothing +-- return $ comp : concat decs +-- where +-- -- For single-constructor types we create the pattern synonym for the +-- -- type directly in terms of Pattern +-- go [] = fail "mkPatterns: empty data declarations not supported" +-- go [c] = return <$> mkConP tn tvs c +-- go _ = go' [] (map fieldTys cs) ctags cs + +-- -- For sum-types, when creating the pattern for an individual +-- -- constructor we need to know about the types of the fields all other +-- -- constructors as well +-- go' prev (this:next) (tag:tags) (con:cons) = do +-- r <- mkConS tn tvs prev next tag con +-- rs <- go' (this:prev) next tags cons +-- return (r : rs) +-- go' _ [] [] [] = return [] +-- go' _ _ _ _ = fail "mkPatterns: unexpected error" + +-- fieldTys (NormalC _ fs) = map snd fs +-- fieldTys (RecC _ fs) = map (\(_,_,t) -> t) fs +-- fieldTys (InfixC a _ b) = [snd a, snd b] +-- fieldTys _ = fail "mkPatterns: only constructors for \"vanilla\" syntax are supported" + +-- -- TODO: The GTags class demonstrates a way to generate the tags for +-- -- a given constructor, rather than backwards-engineering the structure +-- -- as we've done here. We should use that instead! +-- -- +-- ctags = +-- let n = length cs +-- m = n `quot` 2 +-- l = take m (iterate (True:) [False]) +-- r = take (n-m) (iterate (True:) [True]) +-- -- +-- bitsToTag = foldl' f 0 +-- where +-- f i False = i `shiftL` 1 +-- f i True = setBit (i `shiftL` 1) 0 +-- in +-- map bitsToTag (l ++ r) + + +-- mkConP :: Name -> [TyVarBndr ()] -> Con -> Q (Name, [Dec]) +-- mkConP tn' tvs' con' = do +-- checkExts [ PatternSynonyms ] +-- case con' of +-- NormalC cn fs -> mkNormalC tn' cn (map tyVarBndrName tvs') (map snd fs) +-- RecC cn fs -> mkRecC tn' cn (map tyVarBndrName tvs') (map (rename . fst3) fs) (map thd3 fs) +-- InfixC a cn b -> mkInfixC tn' cn (map tyVarBndrName tvs') [snd a, snd b] +-- _ -> fail "mkPatterns: only constructors for \"vanilla\" syntax are supported" +-- where +-- mkNormalC :: Name -> Name -> [Name] -> [Type] -> Q (Name, [Dec]) +-- mkNormalC tn cn tvs fs = do +-- xs <- replicateM (length fs) (newName "_x") +-- r <- sequence [ patSynSigD pat sig +-- , patSynD pat +-- (prefixPatSyn xs) +-- implBidir +-- [p| Pattern $(tupP (map varP xs)) |] +-- ] +-- return (pat, r) +-- where +-- pat = rename cn +-- sig = forallT +-- (map (`plainInvisTV` specifiedSpec) tvs) +-- (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) +-- (foldr (\t ts -> [t| $t -> $ts |]) +-- [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] +-- (map (\t -> [t| Exp $(return t) |]) fs)) + +-- mkRecC :: Name -> Name -> [Name] -> [Name] -> [Type] -> Q (Name, [Dec]) +-- mkRecC tn cn tvs xs fs = do +-- r <- sequence [ patSynSigD pat sig +-- , patSynD pat +-- (recordPatSyn xs) +-- implBidir +-- [p| Pattern $(tupP (map varP xs)) |] +-- ] +-- return (pat, r) +-- where +-- pat = rename cn +-- sig = forallT +-- (map (`plainInvisTV` specifiedSpec) tvs) +-- (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) +-- (foldr (\t ts -> [t| $t -> $ts |]) +-- [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] +-- (map (\t -> [t| Exp $(return t) |]) fs)) + +-- mkInfixC :: Name -> Name -> [Name] -> [Type] -> Q (Name, [Dec]) +-- mkInfixC tn cn tvs fs = do +-- mf <- reifyFixity cn +-- _a <- newName "_a" +-- _b <- newName "_b" +-- r <- sequence [ patSynSigD pat sig +-- , patSynD pat +-- (infixPatSyn _a _b) +-- implBidir +-- [p| Pattern $(tupP [varP _a, varP _b]) |] +-- ] +-- r' <- case mf of +-- Nothing -> return r +-- Just f -> return (InfixD f pat : r) +-- return (pat, r') +-- where +-- pat = mkName (':' : nameBase cn) +-- sig = forallT +-- (map (`plainInvisTV` specifiedSpec) tvs) +-- (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) +-- (foldr (\t ts -> [t| $t -> $ts |]) +-- [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] +-- (map (\t -> [t| Exp $(return t) |]) fs)) + +-- mkConS :: Name -> [TyVarBndr ()] -> [[Type]] -> [[Type]] -> Word8 -> Con -> Q (Name, [Dec]) +-- mkConS tn' tvs' prev' next' tag' con' = do +-- checkExts [GADTs, PatternSynonyms, ScopedTypeVariables, TypeApplications, ViewPatterns] +-- case con' of +-- NormalC cn fs -> mkNormalC tn' cn tag' (map tyVarBndrName tvs') prev' (map snd fs) next' +-- RecC cn fs -> mkRecC tn' cn tag' (map tyVarBndrName tvs') (map (rename . fst3) fs) prev' (map thd3 fs) next' +-- InfixC a cn b -> mkInfixC tn' cn tag' (map tyVarBndrName tvs') prev' [snd a, snd b] next' +-- _ -> fail "mkPatterns: only constructors for \"vanilla\" syntax are supported" +-- where +-- mkNormalC :: Name -> Name -> Word8 -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) +-- mkNormalC tn cn tag tvs ps fs ns = do +-- let pat = rename cn +-- (fun_build, dec_build) <- mkBuild tn (nameBase cn) tvs tag ps fs ns +-- (fun_match, dec_match) <- mkMatch tn (nameBase pat) (nameBase cn) tvs tag ps fs ns +-- dec_pat <- mkNormalC_pattern tn pat tvs fs fun_build fun_match +-- return $ (pat, concat [dec_pat, dec_build, dec_match]) + +-- mkRecC :: Name -> Name -> Word8 -> [Name] -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) +-- mkRecC tn cn tag tvs xs ps fs ns = do +-- let pat = rename cn +-- (fun_build, dec_build) <- mkBuild tn (nameBase cn) tvs tag ps fs ns +-- (fun_match, dec_match) <- mkMatch tn (nameBase pat) (nameBase cn) tvs tag ps fs ns +-- dec_pat <- mkRecC_pattern tn pat tvs xs fs fun_build fun_match +-- return $ (pat, concat [dec_pat, dec_build, dec_match]) + +-- mkInfixC :: Name -> Name -> Word8 -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) +-- mkInfixC tn cn tag tvs ps fs ns = do +-- let pat = mkName (':' : nameBase cn) +-- (fun_build, dec_build) <- mkBuild tn (zencode (nameBase cn)) tvs tag ps fs ns +-- (fun_match, dec_match) <- mkMatch tn ("(" ++ nameBase pat ++ ")") (zencode (nameBase cn)) tvs tag ps fs ns +-- dec_pat <- mkInfixC_pattern tn cn pat tvs fs fun_build fun_match +-- return $ (pat, concat [dec_pat, dec_build, dec_match]) + +-- mkNormalC_pattern :: Name -> Name -> [Name] -> [Type] -> Name -> Name -> Q [Dec] +-- mkNormalC_pattern tn pat tvs fs build match = do +-- xs <- replicateM (length fs) (newName "_x") +-- r <- sequence [ patSynSigD pat sig +-- , patSynD pat +-- (prefixPatSyn xs) +-- (explBidir [clause [] (normalB (varE build)) []]) +-- (parensP $ viewP (varE match) [p| Just $(tupP (map varP xs)) |]) +-- ] +-- return r +-- where +-- sig = forallT +-- (map (`plainInvisTV` specifiedSpec) tvs) +-- (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs)) +-- (foldr (\t ts -> [t| $t -> $ts |]) +-- [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] +-- (map (\t -> [t| Exp $(return t) |]) fs)) + +-- mkRecC_pattern :: Name -> Name -> [Name] -> [Name] -> [Type] -> Name -> Name -> Q [Dec] +-- mkRecC_pattern tn pat tvs xs fs build match = do +-- r <- sequence [ patSynSigD pat sig +-- , patSynD pat +-- (recordPatSyn xs) +-- (explBidir [clause [] (normalB (varE build)) []]) +-- (parensP $ viewP (varE match) [p| Just $(tupP (map varP xs)) |]) +-- ] +-- return r +-- where +-- sig = forallT +-- (map (`plainInvisTV` specifiedSpec) tvs) +-- (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs)) +-- (foldr (\t ts -> [t| $t -> $ts |]) +-- [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] +-- (map (\t -> [t| Exp $(return t) |]) fs)) + +-- mkInfixC_pattern :: Name -> Name -> Name -> [Name] -> [Type] -> Name -> Name -> Q [Dec] +-- mkInfixC_pattern tn cn pat tvs fs build match = do +-- mf <- reifyFixity cn +-- _a <- newName "_a" +-- _b <- newName "_b" +-- r <- sequence [ patSynSigD pat sig +-- , patSynD pat +-- (infixPatSyn _a _b) +-- (explBidir [clause [] (normalB (varE build)) []]) +-- (parensP $ viewP (varE match) [p| Just $(tupP [varP _a, varP _b]) |]) +-- ] +-- r' <- case mf of +-- Nothing -> return r +-- Just f -> return (InfixD f pat : r) +-- return r' +-- where +-- sig = forallT +-- (map (`plainInvisTV` specifiedSpec) tvs) +-- (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs)) +-- (foldr (\t ts -> [t| $t -> $ts |]) +-- [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] +-- (map (\t -> [t| Exp $(return t) |]) fs)) + +-- mkBuild :: Name -> String -> [Name] -> Word8 -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) +-- mkBuild tn cn tvs tag fs0 fs fs1 = do +-- fun <- newName ("_build" ++ cn) +-- xs <- replicateM (length fs) (newName "_x") +-- let +-- vs = foldl' (\es e -> [| SmartExp ($es `Pair` $e) |]) [| SmartExp Nil |] +-- $ map (\t -> [| unExp $(varE 'undef `appTypeE` return t) |] ) (concat (reverse fs0)) +-- ++ map varE xs +-- ++ map (\t -> [| unExp $(varE 'undef `appTypeE` return t) |] ) (concat fs1) + +-- tagged = [| Exp $ SmartExp $ Pair (SmartExp (Const (SingleScalarType (NumSingleType (IntegralNumType TypeTAG))) $(litE (IntegerL (toInteger tag))))) $vs |] +-- body = clause (map (\x -> [p| (Exp $(varP x)) |]) xs) (normalB tagged) [] + +-- r <- sequence [ sigD fun sig +-- , funD fun [body] +-- ] +-- return (fun, r) +-- where +-- sig = forallT +-- (map (`plainInvisTV` specifiedSpec) tvs) +-- (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) +-- (foldr (\t ts -> [t| $t -> $ts |]) +-- [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] +-- (map (\t -> [t| Exp $(return t) |]) fs)) + + +-- mkMatch :: Name -> String -> String -> [Name] -> Word8 -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) +-- mkMatch tn pn cn tvs tag fs0 fs fs1 = do +-- fun <- newName ("_match" ++ cn) +-- e <- newName "_e" +-- x <- newName "_x" +-- (ps,es) <- extract vs [| Prj PairIdxRight $(varE x) |] [] [] +-- unbind <- isExtEnabled RebindableSyntax +-- let +-- eqE = if unbind then letE [funD (mkName "==") [clause [] (normalB (varE '(==))) []]] else id +-- lhs = [p| (Exp $(varP e)) |] +-- body = normalB $ eqE $ caseE (varE e) +-- [ TH.match (conP 'SmartExp [(conP 'Match [matchP ps, varP x])]) (normalB [| Just $(tupE es) |]) [] +-- , TH.match (conP 'SmartExp [(recP 'Match [])]) (normalB [| Nothing |]) [] +-- , TH.match wildP (normalB [| error $error_msg |]) [] +-- ] + +-- r <- sequence [ sigD fun sig +-- , funD fun [clause [lhs] body []] +-- ] +-- return (fun, r) +-- where +-- sig = forallT +-- (map (`plainInvisTV` specifiedSpec) tvs) +-- (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs)) +-- [t| Exp $(foldl' appT (conT tn) (map varT tvs)) -> Maybe $(tupT (map (\t -> [t| Exp $(return t) |]) fs)) |] + +-- matchP us = [p| TagRtag $(litP (IntegerL (toInteger tag))) $pat |] +-- where +-- pat = [p| $(foldl (\ps p -> [p| TagRpair $ps $p |]) [p| TagRunit |] us) |] + +-- extract [] _ ps es = return (ps, es) +-- extract (u:us) x ps es = do +-- _u <- newName "_u" +-- let x' = [| Prj PairIdxLeft (SmartExp $x) |] +-- if not u +-- then extract us x' (wildP:ps) es +-- else extract us x' (varP _u:ps) ([| Exp (SmartExp (Match $(varE _u) (SmartExp (Prj PairIdxRight (SmartExp $x))))) |] : es) + +-- vs = reverse +-- $ [ False | _ <- concat fs0 ] ++ [ True | _ <- fs ] ++ [ False | _ <- concat fs1 ] + +-- error_msg = +-- let pv = unwords +-- $ take (length fs + 1) +-- $ concatMap (map reverse) +-- $ iterate (concatMap (\xs -> [ x:xs | x <- ['a'..'z'] ])) [""] +-- in stringE $ unlines +-- [ "Embedded pattern synonym used outside 'match' context." +-- , "" +-- , "To use case statements in the embedded language the case statement must" +-- , "be applied as an n-ary function to the 'match' operator. For single" +-- , "argument case statements this can be done inline using LambdaCase, for" +-- , "example:" +-- , "" +-- , "> x & match \\case" +-- , printf "> %s%s -> ..." pn pv +-- , printf "> _%s -> ..." (replicate (length pn + length pv - 1) ' ') +-- ] + +-- fst3 :: (a,b,c) -> a +-- fst3 (a,_,_) = a + +-- thd3 :: (a,b,c) -> c +-- thd3 (_,_,c) = c + +-- rename :: Name -> Name +-- rename nm = +-- let +-- split acc [] = (reverse acc, '\0') -- shouldn't happen +-- split acc [l] = (reverse acc, l) +-- split acc (l:ls) = split (l:acc) ls +-- -- +-- nm' = nameBase nm +-- (base, suffix) = split [] nm' +-- in +-- case suffix of +-- '_' -> mkName base +-- _ -> mkName (nm' ++ "_") + +-- checkExts :: [Extension] -> Q () +-- checkExts req = do +-- enabled <- extsEnabled +-- let missing = req \\ enabled +-- unless (null missing) . fail . unlines +-- $ printf "You must enable the following language extensions to generate pattern synonyms:" +-- : map (printf " {-# LANGUAGE %s #-}" . show) missing + +-- -- A simplified version of that stolen from GHC/Utils/Encoding.hs +-- -- +-- type EncodedString = String + +-- zencode :: String -> EncodedString +-- zencode [] = [] +-- zencode (h:rest) = encode_digit h ++ go rest +-- where +-- go [] = [] +-- go (c:cs) = encode_ch c ++ go cs + +-- unencoded_char :: Char -> Bool +-- unencoded_char 'z' = False +-- unencoded_char 'Z' = False +-- unencoded_char c = isAlphaNum c + +-- encode_digit :: Char -> EncodedString +-- encode_digit c | isDigit c = encode_as_unicode_char c +-- | otherwise = encode_ch c + +-- encode_ch :: Char -> EncodedString +-- encode_ch c | unencoded_char c = [c] -- Common case first +-- encode_ch '(' = "ZL" +-- encode_ch ')' = "ZR" +-- encode_ch '[' = "ZM" +-- encode_ch ']' = "ZN" +-- encode_ch ':' = "ZC" +-- encode_ch 'Z' = "ZZ" +-- encode_ch 'z' = "zz" +-- encode_ch '&' = "za" +-- encode_ch '|' = "zb" +-- encode_ch '^' = "zc" +-- encode_ch '$' = "zd" +-- encode_ch '=' = "ze" +-- encode_ch '>' = "zg" +-- encode_ch '#' = "zh" +-- encode_ch '.' = "zi" +-- encode_ch '<' = "zl" +-- encode_ch '-' = "zm" +-- encode_ch '!' = "zn" +-- encode_ch '+' = "zp" +-- encode_ch '\'' = "zq" +-- encode_ch '\\' = "zr" +-- encode_ch '/' = "zs" +-- encode_ch '*' = "zt" +-- encode_ch '_' = "zu" +-- encode_ch '%' = "zv" +-- encode_ch c = encode_as_unicode_char c + +-- encode_as_unicode_char :: Char -> EncodedString +-- encode_as_unicode_char c +-- = 'z' +-- : if isDigit (head hex_str) then hex_str +-- else '0':hex_str +-- where +-- hex_str = showHex (ord c) "U" diff --git a/src/Data/Array/Accelerate/Representation/Shape.hs b/src/Data/Array/Accelerate/Representation/Shape.hs index fa9bd58ff..1abf0ac2f 100644 --- a/src/Data/Array/Accelerate/Representation/Shape.hs +++ b/src/Data/Array/Accelerate/Representation/Shape.hs @@ -210,7 +210,7 @@ instance POSable (ShapeR ()) where type Choices (ShapeR ()) = 1 choices _ = 0 - emptyChoices = 0 + tags = [1] fromPOSable _ _ = ShapeRz @@ -225,7 +225,7 @@ instance (POSable (ShapeR sh)) => POSable (ShapeR (sh, Int)) where type Choices (ShapeR (sh, Int)) = 1 choices _ = 0 - emptyChoices = 0 + tags = [1] fromPOSable 0 (Cons _ xs) = ShapeRsnoc (fromPOSable 0 xs) diff --git a/src/Data/Array/Accelerate/Representation/Tag.hs b/src/Data/Array/Accelerate/Representation/Tag.hs index 31d3e82c0..f61116b79 100644 --- a/src/Data/Array/Accelerate/Representation/Tag.hs +++ b/src/Data/Array/Accelerate/Representation/Tag.hs @@ -11,12 +11,10 @@ -- Portability : non-portable (GHC extensions) -- -module Data.Array.Accelerate.Representation.Tag (TAG, TagR(..), rnfTag, liftTag) +module Data.Array.Accelerate.Representation.Tag (TAG, TagR(..)) where -import Data.Array.Accelerate.Type - -import Language.Haskell.TH.Extra +import Data.Array.Accelerate.Type ( TAG ) @@ -34,31 +32,5 @@ import Language.Haskell.TH.Extra -- (((),(1#,())),(0#,())) -- (True, False) -- (((),(1#,())),(1#,())) -- (True, True) -- -data TagR a where - TagRunit :: TagR () - TagRsingle :: ScalarType a -> TagR a - TagRundef :: ScalarType a -> TagR a - TagRtag :: TAG -> TagR a -> TagR (TAG, a) - TagRpair :: TagR a -> TagR b -> TagR (a, b) - -instance Show (TagR a) where - show TagRunit = "()" - show TagRsingle{} = "." - show TagRundef{} = "undef" - show (TagRtag v t) = "(" ++ show v ++ "#," ++ show t ++ ")" - show (TagRpair ta tb) = "(" ++ show ta ++ "," ++ show tb ++ ")" - -rnfTag :: TagR a -> () -rnfTag TagRunit = () -rnfTag (TagRsingle t) = rnfScalarType t -rnfTag (TagRundef t) = rnfScalarType t -rnfTag (TagRtag v t) = v `seq` rnfTag t -rnfTag (TagRpair ta tb) = rnfTag ta `seq` rnfTag tb - -liftTag :: TagR a -> CodeQ (TagR a) -liftTag TagRunit = [|| TagRunit ||] -liftTag (TagRsingle t) = [|| TagRsingle $$(liftScalarType t) ||] -liftTag (TagRundef t) = [|| TagRundef $$(liftScalarType t) ||] -liftTag (TagRtag v t) = [|| TagRtag v $$(liftTag t) ||] -liftTag (TagRpair ta tb) = [|| TagRpair $$(liftTag ta) $$(liftTag tb) ||] - +data TagR a = TagR TAG TAG + deriving Show diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 3124630d6..c0f640c83 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -492,7 +492,7 @@ data PreSmartExp acc exp t where -> PreSmartExp acc exp t -- Needed for embedded pattern matching - Match :: TAG + Match :: TagR t -> exp t -> PreSmartExp acc exp t @@ -541,7 +541,7 @@ data PreSmartExp acc exp t where -> PreSmartExp acc exp sh Case :: exp a - -> [(TAG, TAG, exp b)] + -> [(TagR b, exp b)] -> PreSmartExp acc exp b Cond :: exp PrimBool @@ -865,7 +865,7 @@ instance HasTypeR exp => HasTypeR (PreSmartExp acc exp) where VecUnpack vecR _ -> vecRtuple vecR ToIndex _ _ _ -> TupRsingle scalarTypeInt FromIndex shr _ _ -> shapeType shr - Case _ ((_,_,c):_) -> typeR c + Case _ ((_,c):_) -> typeR c Case{} -> internalError "encountered empty case" Cond _ e _ -> typeR e While t _ _ _ -> t diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index e92a903a0..546991c90 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -32,6 +32,7 @@ import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Representation.POS import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Sugar.POS () +import Data.Array.Accelerate.Representation.POS import Data.Array.Accelerate.Type import Data.Char @@ -107,6 +108,15 @@ class (KnownNat (EltChoices a)) => Elt a where default toElt :: (POSable a, POStoEltR (Choices a) (Fields a) ~ EltR a) => EltR a -> a toElt = fromEltR + default tagsR :: (POSable a) => [TagR (EltR a)] + tagsR = f 0 (map fromInteger (tags @a)) + where + f :: TAG -> [TAG] -> [TagR (EltR a)] + f n l = case l of + [] -> [] + x : xs -> (TagR n (n + x)) : f (n + x) xs + + -- function to bring the contraints in scope that are needed to work with EltR, -- without needing to inspect how POS2EltR works data EltRType x where @@ -233,21 +243,19 @@ mkSingleType _ mkEltRT :: forall a . (POSable a) => TypeR (POStoEltR (Choices a) (Fields a)) -mkEltRT = case sameNat cs (Proxy :: Proxy 1) of +mkEltRT = case sameNat (Proxy @(Choices a)) (Proxy :: Proxy 1) of -- This distinction is hard to express in a type-correct way, -- hence the unsafeCoerce's Just Refl -> case emptyFields @a of PTCons (STSucc x STZero) PTNil -> TupRsingle (mkScalarType x) x -> unsafeCoerce $ flattenProductType x Nothing -> unsafeCoerce $ TupRpair (TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeTAG)))) (flattenProductType (emptyFields @a)) - where - cs = emptyChoices @a -untag :: TypeR t -> TagR t -untag TupRunit = TagRunit -untag (TupRsingle t) = TagRundef t -untag (TupRpair ta tb) = TagRpair (untag ta) (untag tb) +-- untag :: TypeR t -> TagR t +-- untag TupRunit = TagRunit +-- untag (TupRsingle t) = TagRundef t +-- untag (TupRpair ta tb) = TagRpair (untag ta) (untag tb) -- Note: [Deriving Elt] @@ -288,7 +296,7 @@ instance Elt Char where type EltR Char = Word32 type EltChoices Char = 1 eltR = TupRsingle scalarType - tagsR = [TagRsingle scalarType] + tagsR = [TagR 0 1] toElt = chr . fromIntegral fromElt = fromIntegral . ord diff --git a/src/Data/Array/Accelerate/Sugar/Shape.hs b/src/Data/Array/Accelerate/Sugar/Shape.hs index fb2d9bee2..debd97360 100644 --- a/src/Data/Array/Accelerate/Sugar/Shape.hs +++ b/src/Data/Array/Accelerate/Sugar/Shape.hs @@ -311,7 +311,7 @@ instance (Elt t, Elt h) => Elt (t :. h) where type EltR (t :. h) = (EltR t, EltR h) type EltChoices (t :. h) = 1 eltR = TupRpair (eltR @t) (eltR @h) - tagsR = [TagRpair t h | t <- tagsR @t, h <- tagsR @h] + tagsR = [TagR 0 1] fromElt (t:.h) = (fromElt t, fromElt h) toElt (t, h) = toElt t :. toElt h @@ -323,7 +323,7 @@ instance Shape sh => Elt (Any (sh :. Int)) where type EltR (Any (sh :. Int)) = (EltR (Any sh), ()) type EltChoices (Any (sh :. Int)) = 1 eltR = TupRpair (eltR @(Any sh)) TupRunit - tagsR = [TagRpair t TagRunit | t <- tagsR @(Any sh)] + tagsR = [TagR 0 1] fromElt _ = (fromElt (Any :: Any sh), ()) toElt _ = Any diff --git a/src/Data/Array/Accelerate/Sugar/Vec.hs b/src/Data/Array/Accelerate/Sugar/Vec.hs index e61bb2457..4972793a3 100644 --- a/src/Data/Array/Accelerate/Sugar/Vec.hs +++ b/src/Data/Array/Accelerate/Sugar/Vec.hs @@ -37,7 +37,7 @@ instance VecElt a => POSable (Vec2 a) where choices _ = 0 - emptyChoices = 0 + tags = [1] -- TODO: can a Vec contain non-singleton values? fromPOSable 0 (Cons (Pick a) (Cons (Pick b) Nil)) = Vec2 a b @@ -55,7 +55,7 @@ instance VecElt a => POSable (Vec4 a) where choices _ = 0 - emptyChoices = 0 + tags = [1] -- TODO: can a Vec contain non-singleton values? fromPOSable 0 ( Cons (Pick a) (Cons (Pick b) (Cons (Pick c) (Cons (Pick d) Nil)))) = Vec4 a b c d diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index 8746d8c3a..41052afd5 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -163,7 +163,7 @@ mkEltR x = case sameNat cs (Proxy :: Proxy 1) of -- TODO: this might not be correct -- This might typecheck with cmpNat, but that requires a very new version of base fromEltR :: forall a . (POSable a) => POStoEltR (Choices a) (Fields a) -> a -fromEltR x = case sameNat (emptyChoices @a) (Proxy :: Proxy 1) of +fromEltR x = case sameNat (Proxy @(Choices a)) (Proxy :: Proxy 1) of Just Refl -> case emptyFields @a of PTCons (STSucc _ STZero) PTNil -> unsafeCoerce x _ -> fromPOSable 0 (unsafeCoerce x) From 28fc675d14721f14d6f781b71a5cc08e3925054f Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Fri, 3 Jun 2022 13:17:21 +0200 Subject: [PATCH 54/67] Patterns for Maybe --- .../Array/Accelerate/Pattern/Matchable.hs | 68 ------------------- src/Data/Array/Accelerate/Pattern/Maybe.hs | 42 +++++++++++- 2 files changed, 39 insertions(+), 71 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index a5b58bfc5..750a433a4 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -121,74 +121,6 @@ makeTag x = SmartExp (Const (SingleScalarType (NumSingleType (IntegralNumType Ty tagType :: TupR ScalarType TAG tagType = TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeTAG))) - -instance Matchable (Maybe Int) where - build n x = case sameNat n (Proxy :: Proxy 0) of - Just Refl -> - Exp ( - SmartExp ( - Pair - (makeTag 0) - (SmartExp ( - Pair - (SmartExp ( - Const - (scalarType @(UnionScalar (Undef, (Int, ())))) - (PickScalar POS.Undef) - )) - (SmartExp Smart.Nil) - )) - ) - ) - Nothing -> case sameNat n (Proxy :: Proxy 1) of - Just Refl | (Exp x' :* SOP.Nil) <- x -> Exp ( - SmartExp ( - Pair - (makeTag 1) - (SmartExp ( - Pair - (SmartExp - (Union - (SmartExp - (LiftUnion x') - ) - ) - ) - (SmartExp Smart.Nil) - )) - ) - ) - Nothing -> error "Impossible type encountered" - - match n (Exp e) = case sameNat n (Proxy :: Proxy 0) of - Just Refl -> - case e of - SmartExp (Match (TagR l u) _x) - | l == 0 - , u == 1 - -> Just SOP.Nil - - SmartExp Match {} -> Nothing - - _ -> error "Embedded pattern synonym used outside 'match' context." - Nothing -> -- matchJust - case sameNat n (Proxy :: Proxy 1) of - Just Refl -> - case e of - SmartExp (Match (TagR l u) x) - | l == 1 - , u == 2 - -> Just ( - mkExp (PrjUnion $ SmartExp $ Union (prjLeft (prjRight x))) - :* SOP.Nil) - SmartExp Match {} -> Nothing - - _ -> error "Embedded pattern synonym used outside 'match' context." - - Nothing -> - error "Impossible type encountered" - - instance (POSable (Maybe a), POSable a) => Matchable (Maybe a) where build n fs = case sameNat (Proxy @(Choices a)) (Proxy @0) of -- a has 0 valid choices (which means we cannot create a Just of this type) diff --git a/src/Data/Array/Accelerate/Pattern/Maybe.hs b/src/Data/Array/Accelerate/Pattern/Maybe.hs index 642f88028..caa03c447 100644 --- a/src/Data/Array/Accelerate/Pattern/Maybe.hs +++ b/src/Data/Array/Accelerate/Pattern/Maybe.hs @@ -4,6 +4,7 @@ {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE DataKinds #-} -- | -- Module : Data.Array.Accelerate.Pattern.Maybe -- Copyright : [2018..2020] The Accelerate Team @@ -20,9 +21,44 @@ module Data.Array.Accelerate.Pattern.Maybe ( ) where -import Data.Array.Accelerate.Pattern.TH +import Data.Array.Accelerate.Smart as Smart +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Pattern.Matchable +import Generics.SOP as SOP +import Data.Array.Accelerate.Representation.POS as POS -- mkPattern ''Maybe +{-# COMPLETE Nothing_, Just_ #-} +pattern Nothing_ :: + forall a . + ( Elt a + , POSable a + , Matchable a + ) => Exp (Maybe a) +pattern Nothing_ <- (matchNothing -> Just ()) where + Nothing_ = buildNothing -pattern Nothing_ <- match (Proxy :: Proxy 0) SOP.Nil where - Nothing_{} = build (Proxy @0) SOP.Nil +matchNothing :: forall a . (POSable a, Elt a) => Exp (Maybe a) -> Maybe () +matchNothing x = case match (Proxy @0) x of + Just SOP.Nil -> Just () + Nothing -> Nothing + +buildNothing :: forall a . (Elt a, POSable a) => Exp (Maybe a) +buildNothing = build (Proxy @0) SOP.Nil + +pattern Just_ :: + forall a . + ( Elt a + , POSable a + , Matchable a + ) => Exp a -> Exp (Maybe a) +pattern Just_ x <- (matchJust -> Just x) where + Just_ = buildJust + +matchJust :: forall a . (Elt a, POSable a) => Exp (Maybe a) -> Maybe (Exp a) +matchJust x = case match (Proxy @1) x of + Just (x' :* SOP.Nil) -> Just x' + Nothing -> Nothing + +buildJust :: forall a . (Elt a, POSable a) => Exp a -> Exp (Maybe a) +buildJust x = build (Proxy @1) (x :* SOP.Nil) From a9b4d34f7e29a1879e58c7728ccf50f45b8799d0 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Fri, 3 Jun 2022 13:27:57 +0200 Subject: [PATCH 55/67] pattern synonyms for Either and Bool --- src/Data/Array/Accelerate/Pattern/Bool.hs | 32 +++++++++++++-- src/Data/Array/Accelerate/Pattern/Either.hs | 44 +++++++++++++++++++-- src/Data/Array/Accelerate/Pattern/Maybe.hs | 2 - 3 files changed, 70 insertions(+), 8 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Bool.hs b/src/Data/Array/Accelerate/Pattern/Bool.hs index d968aaf34..4b98cbe79 100644 --- a/src/Data/Array/Accelerate/Pattern/Bool.hs +++ b/src/Data/Array/Accelerate/Pattern/Bool.hs @@ -1,9 +1,9 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE DataKinds #-} -- | -- Module : Data.Array.Accelerate.Pattern.Bool -- Copyright : [2018..2020] The Accelerate Team @@ -20,7 +20,33 @@ module Data.Array.Accelerate.Pattern.Bool ( ) where -import Data.Array.Accelerate.Pattern.TH +import Data.Array.Accelerate.Smart as Smart +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Pattern.Matchable +import Generics.SOP as SOP +import Data.Array.Accelerate.Representation.POS as POS -mkPattern ''Bool +{-# COMPLETE False_, True_ #-} +pattern False_ :: Exp Bool +pattern False_ <- (matchFalse -> Just ()) where + False_ = buildFalse +matchFalse :: Exp Bool -> Maybe () +matchFalse x = case match (Proxy @0) x of + Just SOP.Nil -> Just () + Nothing -> Nothing + +buildFalse :: Exp Bool +buildFalse = build (Proxy @0) SOP.Nil + +pattern True_ :: Exp Bool +pattern True_ <- (matchTrue -> Just x) where + True_ = buildTrue + +matchTrue :: Exp Bool -> Maybe () +matchTrue x = case match (Proxy @1) x of + Just SOP.Nil -> Just () + Nothing -> Nothing + +buildTrue :: Exp Bool +buildTrue = build (Proxy @1) SOP.Nil diff --git a/src/Data/Array/Accelerate/Pattern/Either.hs b/src/Data/Array/Accelerate/Pattern/Either.hs index 67c7b3a3f..59e052667 100644 --- a/src/Data/Array/Accelerate/Pattern/Either.hs +++ b/src/Data/Array/Accelerate/Pattern/Either.hs @@ -1,9 +1,9 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE DataKinds #-} -- | -- Module : Data.Array.Accelerate.Pattern.Either -- Copyright : [2018..2020] The Accelerate Team @@ -20,7 +20,45 @@ module Data.Array.Accelerate.Pattern.Either ( ) where -import Data.Array.Accelerate.Pattern.TH +import Data.Array.Accelerate.Smart as Smart +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Pattern.Matchable +import Generics.SOP as SOP +import Data.Array.Accelerate.Representation.POS as POS -mkPattern ''Either +{-# COMPLETE Left_, Right_ #-} +pattern Left_ :: + forall a b . + ( Elt a + , POSable a + , POSable b + , Matchable a + ) => Exp a -> Exp (Either a b) +pattern Left_ x <- (matchLeft -> Just x) where + Left_ = buildLeft +matchLeft :: forall a b . (POSable a, Elt a, POSable b) => Exp (Either a b) -> Maybe (Exp a) +matchLeft x = case match (Proxy @0) x of + Just (x' :* SOP.Nil) -> Just x' + Nothing -> Nothing + +buildLeft :: forall a b . (Elt a, POSable a, POSable b) => Exp a -> Exp (Either a b) +buildLeft x = build (Proxy @0) (x :* SOP.Nil) + +pattern Right_ :: + forall a b . + ( Elt a + , POSable a + , POSable b + , Matchable a + ) => Exp b -> Exp (Either a b) +pattern Right_ x <- (matchRight -> Just x) where + Right_ = buildRight + +matchRight :: forall a b . (Elt a, POSable a, POSable b) => Exp (Either a b) -> Maybe (Exp b) +matchRight x = case match (Proxy @1) x of + Just (x' :* SOP.Nil) -> Just x' + Nothing -> Nothing + +buildRight :: forall a b . (Elt a, POSable a, POSable b) => Exp b -> Exp (Either a b) +buildRight x = build (Proxy @1) (x :* SOP.Nil) diff --git a/src/Data/Array/Accelerate/Pattern/Maybe.hs b/src/Data/Array/Accelerate/Pattern/Maybe.hs index caa03c447..de6121ea9 100644 --- a/src/Data/Array/Accelerate/Pattern/Maybe.hs +++ b/src/Data/Array/Accelerate/Pattern/Maybe.hs @@ -1,7 +1,6 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE DataKinds #-} @@ -27,7 +26,6 @@ import Data.Array.Accelerate.Pattern.Matchable import Generics.SOP as SOP import Data.Array.Accelerate.Representation.POS as POS --- mkPattern ''Maybe {-# COMPLETE Nothing_, Just_ #-} pattern Nothing_ :: forall a . From 908f47ae12705b3ed400533bf56d48546948f25c Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Fri, 3 Jun 2022 13:35:05 +0200 Subject: [PATCH 56/67] make integer synonyms Ground and POSable --- src/Data/Array/Accelerate/Sugar/Elt.hs | 15 ++++++--------- src/Data/Array/Accelerate/Sugar/POS.hs | 20 +++++++++++--------- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index 546991c90..9467da26e 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -377,21 +377,18 @@ runQ $ do -- TyConI (NewtypeD [] Foreign.C.Types.CFloat [] Nothing (NormalC Foreign.C.Types.CFloat [(Bang NoSourceUnpackedness NoSourceStrictness,ConT GHC.Types.Float)]) []) -- mkNewtype :: Name -> Q [Dec] - mkNewtype name = do - r <- reify name - base <- case r of - TyConI (NewtypeD _ _ _ _ (NormalC _ [(_, ConT b)]) _) -> return b - _ -> error "unexpected case generating newtype Elt instance" - -- - [d| instance Elt $(conT name) + mkNewtype name = + let t = conT name + in + [d| instance Elt $t |] -- ss <- mapM mkSimple (integralTypes ++ floatingTypes) -- TODO: - -- ns <- mapM mkNewtype newtypes + ns <- mapM mkNewtype newtypes -- ts <- mapM mkTuple [2..8] -- vs <- sequence [ mkVecElt t n | t <- integralTypes ++ floatingTypes, n <- [2,3,4,8,16] ] - return (concat ss) + return (concat ss ++ concat ns) instance Elt Undef diff --git a/src/Data/Array/Accelerate/Sugar/POS.hs b/src/Data/Array/Accelerate/Sugar/POS.hs index 1d6762945..3c943968e 100644 --- a/src/Data/Array/Accelerate/Sugar/POS.hs +++ b/src/Data/Array/Accelerate/Sugar/POS.hs @@ -108,19 +108,21 @@ runQ $ do instanceD ctx [t| POSable $res |] [] mkNewtype :: Name -> Q [Dec] - mkNewtype name = do - r <- reify name - base <- case r of - TyConI (NewtypeD _ _ _ _ (NormalC _ [(_, ConT b)]) _) -> return b - _ -> error "unexpected case generating newtype Elt instance" - -- - mkPOSableGround name + mkNewtype name = + let t = conT name + in + [d| + instance Ground $t where + mkGround = 0 + |] + -- si <- mapM (mkSimple ''IntegralType 'IntegralNumType) integralTypes sf <- mapM (mkSimple ''FloatingType 'FloatingNumType) floatingTypes ns <- mapM mkPOSableGround (floatingTypes ++ integralTypes) - -- ns <- mapM mkNewtype newtypes + ts <- mapM mkNewtype newtypes + nts <- mapM mkPOSableGround newtypes -- ts <- mapM mkTuple [2..16] -- vs <- sequence [ mkVecElt t n | t <- integralTypes ++ floatingTypes, n <- [2,3,4,8,16] ] - return (concat si ++ concat sf ++ concat ns) + return (concat si ++ concat sf ++ concat ns ++ concat ts ++ concat nts) \ No newline at end of file From ff615b98e11f27a99d817a4ad54cfdf817bda5b3 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Wed, 22 Jun 2022 15:27:01 +0200 Subject: [PATCH 57/67] correct definition of mkEltR and fromEltR --- accelerate.cabal | 1 + src/Data/Array/Accelerate/Sugar/Elt.hs | 41 +++++++++++++++++++++-- src/Data/Array/Accelerate/Type.hs | 46 -------------------------- 3 files changed, 39 insertions(+), 49 deletions(-) diff --git a/accelerate.cabal b/accelerate.cabal index 85c87b236..47cae90e9 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -364,6 +364,7 @@ library , posable >= 0.9.0.0 , ghc-typelits-knownnat >= 0.7.6 , generics-sop >= 0.5.1.1 + , finite-typelits >= 0.1.4.2 exposed-modules: -- The core language and reference implementation diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index 9467da26e..760058d71 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -44,6 +44,7 @@ import Unsafe.Coerce import Data.Type.Equality import Data.Proxy import Data.Typeable +import Data.Finite.Internal (Finite(..)) -- | The 'Elt' class characterises the allowable array element types, and -- hence the types which can appear in scalar Accelerate expressions of @@ -120,9 +121,9 @@ class (KnownNat (EltChoices a)) => Elt a where -- function to bring the contraints in scope that are needed to work with EltR, -- without needing to inspect how POS2EltR works data EltRType x where - SingletonType :: (EltR x ~ x, Fields x ~ '[ '[x]]) => EltRType x - TaglessType :: (EltR x ~ FlattenProduct (Fields x)) => EltRType x - TaggedType :: (EltR x ~ (TAG, FlattenProduct (Fields x))) => EltRType x + SingletonType :: (EltR x ~ POStoEltR (Choices x) (Fields x), EltR x ~ x, Fields x ~ '[ '[x]]) => EltRType x + TaglessType :: (EltR x ~ POStoEltR (Choices x) (Fields x), EltR x ~ FlattenProduct (Fields x)) => EltRType x + TaggedType :: (EltR x ~ POStoEltR (Choices x) (Fields x), EltR x ~ (TAG, FlattenProduct (Fields x))) => EltRType x eltRType :: forall x . POSable x => EltRType x eltRType = case sameNat (Proxy :: Proxy (Choices x)) (Proxy :: Proxy 1) of @@ -133,9 +134,11 @@ eltRType = case sameNat (Proxy :: Proxy (Choices x)) (Proxy :: Proxy 1) of -> SingletonType _ | Refl :: (EltR x :~: FlattenProduct (Fields x)) <- unsafeCoerce Refl + , Refl :: (POStoEltR 1 (Fields x) :~: EltR x) <- unsafeCoerce Refl -> TaglessType Nothing | Refl :: (EltR x :~: (TAG, FlattenProduct (Fields x))) <- unsafeCoerce Refl + , Refl :: (POStoEltR (Choices x) (Fields x) :~: (TAG, FlattenProduct (Fields x))) <- unsafeCoerce Refl -> TaggedType @@ -252,6 +255,38 @@ mkEltRT = case sameNat (Proxy @(Choices a)) (Proxy :: Proxy 1) of Nothing -> unsafeCoerce $ TupRpair (TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeTAG)))) (flattenProductType (emptyFields @a)) +mkEltR :: forall a . (POSable a) => a -> POStoEltR (Choices a) (Fields a) +mkEltR x = case eltRType @a of + SingletonType | Cons (Pick f) Nil <- fields x -> f + TaglessType -> fs + TaggedType -> (cs, fs) + where + cs = fromInteger @TAG $ toInteger $ choices x + fs = flattenProduct (fields x) + +fromEltR :: forall a . (POSable a) => POStoEltR (Choices a) (Fields a) -> a +fromEltR x = case eltRType @a of + SingletonType -> x + TaglessType -> fromPOSable 0 (unFlattenProduct (emptyFields @a) x) + TaggedType | (t, fs) <- x -> fromPOSable (Finite $ toInteger t) (unFlattenProduct (emptyFields @a) fs) + +unFlattenProduct :: ProductType a -> FlattenProduct a -> Product a +unFlattenProduct PTNil () = Nil +unFlattenProduct (PTCons x xs) (y, ys) = Cons (unFlattenSum x y) (unFlattenProduct xs ys) + +unFlattenSum :: SumType a -> UnionScalar (FlattenSum a) -> Sum a +unFlattenSum (STSucc x xs) (PickScalar y) = Pick y +unFlattenSum (STSucc x xs) (SkipScalar ys) = Skip $ unFlattenSum xs ys + + +flattenProduct :: Product a -> FlattenProduct a +flattenProduct Nil = () +flattenProduct (Cons x xs) = (flattenSum x, flattenProduct xs) + +flattenSum :: Sum a -> UnionScalar (FlattenSum a) +flattenSum (Pick x) = PickScalar x +flattenSum (Skip xs) = SkipScalar (flattenSum xs) + -- untag :: TypeR t -> TagR t -- untag TupRunit = TagRunit -- untag (TupRsingle t) = TagRundef t diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index 41052afd5..02e2d2e67 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -123,52 +123,6 @@ type family FlattenSumType (xs :: [a]) :: Type where FlattenSumType (x ': xs) = (x, FlattenSumType xs) -flattenProduct :: Product a -> FlattenProduct a -flattenProduct Nil = () -flattenProduct (Cons x xs) = (flattenSum x, flattenProduct xs) - -flattenSum :: Sum a -> UnionScalar (FlattenSum a) -flattenSum (Pick x) = PickScalar x -flattenSum (Skip xs) = SkipScalar (flattenSum xs) - --- This might typecheck without unsafeCoerce by using cmpNat, but that requires a very new version of base -mkEltR :: forall a . (POSable a) => a -> POStoEltR (Choices a) (Fields a) -mkEltR x = case sameNat cs (Proxy :: Proxy 1) of - Just Refl -> case emptyFields @a of - -- Lots of cases because GHC does not understand inequality - -- First up: singleton type - PTCons (STSucc _ STZero) PTNil | Cons (Pick f) Nil <- fields x -> f - -- Singleton type, but we don't get an actual value out of - -- the call to `fields` (Should never occur). - PTCons (STSucc _ STZero) PTNil -> error "Value does not match representation" - -- unit type - PTNil -> fs - -- weird type with no value in the first field (thus having - -- no constructors) - PTCons STZero _ -> fs - -- weird type with a sum in the first field (weird because - -- we already have `Choices a ~ 1` in scope) - PTCons (STSucc _ (STSucc _ _)) _ -> fs - -- type with two or more fields - PTCons _ (PTCons _ _) -> fs - -- unsafeCoerce because we cannot prove to the compiler that - -- `Choices a !~ 1` (and thus the 3th branch of the POStoEltR - -- type family holds) - Nothing -> unsafeCoerce (cs, fs) - where - cs = choices x - fs = flattenProduct (fields x) - - --- TODO: this might not be correct --- This might typecheck with cmpNat, but that requires a very new version of base -fromEltR :: forall a . (POSable a) => POStoEltR (Choices a) (Fields a) -> a -fromEltR x = case sameNat (Proxy @(Choices a)) (Proxy :: Proxy 1) of - Just Refl -> case emptyFields @a of - PTCons (STSucc _ STZero) PTNil -> unsafeCoerce x - _ -> fromPOSable 0 (unsafeCoerce x) - Nothing -> uncurry fromPOSable (unsafeCoerce x) - -- Scalar types -- ------------ From 6fe8dd39bcf8b9472bf903342a66f8a0b3d6cf67 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 23 Jun 2022 10:22:03 +0200 Subject: [PATCH 58/67] define mkEltRT in terms of eltRType --- src/Data/Array/Accelerate/Sugar/Elt.hs | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index 760058d71..1b09314a5 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -246,13 +246,10 @@ mkSingleType _ mkEltRT :: forall a . (POSable a) => TypeR (POStoEltR (Choices a) (Fields a)) -mkEltRT = case sameNat (Proxy @(Choices a)) (Proxy :: Proxy 1) of - -- This distinction is hard to express in a type-correct way, - -- hence the unsafeCoerce's - Just Refl -> case emptyFields @a of - PTCons (STSucc x STZero) PTNil -> TupRsingle (mkScalarType x) - x -> unsafeCoerce $ flattenProductType x - Nothing -> unsafeCoerce $ TupRpair (TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeTAG)))) (flattenProductType (emptyFields @a)) +mkEltRT = case eltRType @a of + SingletonType | PTCons (STSucc x STZero) PTNil <- emptyFields @a -> TupRsingle (mkScalarType x) + TaglessType -> flattenProductType (emptyFields @a) + TaggedType -> TupRpair (TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeTAG)))) (flattenProductType (emptyFields @a)) mkEltR :: forall a . (POSable a) => a -> POStoEltR (Choices a) (Fields a) From 1930f77f6913f8946477bac749f5733f4aad3c5c Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 23 Jun 2022 10:42:44 +0200 Subject: [PATCH 59/67] simplify scalarTypeTAGg --- src/Data/Array/Accelerate/Sugar/Elt.hs | 2 +- src/Data/Array/Accelerate/Type.hs | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index 1b09314a5..b882300b8 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -249,7 +249,7 @@ mkEltRT :: forall a . (POSable a) => TypeR (POStoEltR (Choices a) (Fields a)) mkEltRT = case eltRType @a of SingletonType | PTCons (STSucc x STZero) PTNil <- emptyFields @a -> TupRsingle (mkScalarType x) TaglessType -> flattenProductType (emptyFields @a) - TaggedType -> TupRpair (TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeTAG)))) (flattenProductType (emptyFields @a)) + TaggedType -> TupRpair (TupRsingle scalarTypeTAG) (flattenProductType (emptyFields @a)) mkEltR :: forall a . (POSable a) => a -> POStoEltR (Choices a) (Fields a) diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index 02e2d2e67..5e5ed68ba 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -381,6 +381,9 @@ scalarTypeWord8 = SingleScalarType $ NumSingleType $ IntegralNumType TypeWord8 scalarTypeWord32 :: ScalarType Word32 scalarTypeWord32 = SingleScalarType $ NumSingleType $ IntegralNumType TypeWord32 +scalarTypeTAG :: ScalarType TAG +scalarTypeTAG = SingleScalarType $ NumSingleType $ IntegralNumType TypeTAG + rnfScalarType :: ScalarType t -> () rnfScalarType (SingleScalarType t) = rnfSingleType t rnfScalarType (VectorScalarType t) = rnfVectorType t From f845bd296760074c593e0e782fbb7a224c6f7b59 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 23 Jun 2022 11:36:03 +0200 Subject: [PATCH 60/67] replace mkMin by correct mkSub --- src/Data/Array/Accelerate/Pattern/Matchable.hs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index 750a433a4..86697201b 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -167,7 +167,7 @@ instance (POSable (Maybe a), POSable a) => Matchable (Maybe a) where , u == tagVal @(Choices a) -- remove one from the tag as we are not in left anymore -- the `tag` function will apply the new tag if necessary - -> Just (Exp (tag @a (unExp $ mkMin @TAG (Exp $ prjLeft x) (constant 1)) (splitRight @() @a $ prjRight x)) :* SOP.Nil) + -> Just (Exp (tag @a (unExp $ mkSub @TAG (Exp $ prjLeft x) (constant 1)) (splitRight @() @a $ prjRight x)) :* SOP.Nil) SmartExp Match {} -> Nothing _ -> error "Embedded pattern synonym used outside 'match' context." @@ -270,7 +270,7 @@ instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) SmartExp (Match (TagR l u) x) | l == 0 , u == tagVal @(Choices a) - -> Just (Exp (tag @a (unExp $ mkMin @TAG (Exp $ prjLeft x) (constant $ tagVal @(Choices a))) (splitLeft @a @b $ prjRight x)) :* SOP.Nil) + -> Just (Exp (tag @a (unExp $ mkSub @TAG (Exp $ prjLeft x) (constant $ tagVal @(Choices a))) (splitLeft @a @b $ prjRight x)) :* SOP.Nil) SmartExp Match {} -> Nothing @@ -284,7 +284,7 @@ instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) , u == tagVal @(Choices b) -- remove one from the tag as we are not in left anymore -- the `tag` function will apply the new tag if necessary - -> Just (Exp (tag @b (unExp $ mkMin @TAG (Exp $ prjLeft x) (constant $ tagVal @(Choices a))) (splitRight @a @b $ prjRight x)) :* SOP.Nil) + -> Just (Exp (tag @b (unExp $ mkSub @TAG (Exp $ prjLeft x) (constant $ tagVal @(Choices a))) (splitRight @a @b $ prjRight x)) :* SOP.Nil) SmartExp Match {} -> Nothing _ -> error "Embedded pattern synonym used outside 'match' context." From 781ac2d231f7ff6253cf78202d032069a9e07cf6 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Fri, 24 Jun 2022 11:31:05 +0200 Subject: [PATCH 61/67] use type lists for unionscalars --- src/Data/Array/Accelerate/Smart.hs | 4 ++-- src/Data/Array/Accelerate/Sugar/Elt.hs | 6 +++--- src/Data/Array/Accelerate/Type.hs | 27 +++++++++----------------- 3 files changed, 14 insertions(+), 23 deletions(-) diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index c0f640c83..dfa798d03 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -512,12 +512,12 @@ data PreSmartExp acc exp t where -> PreSmartExp acc exp t LiftUnion :: exp t1 - -> PreSmartExp acc exp (UnionScalar (t1, ())) + -> PreSmartExp acc exp (UnionScalar '[t2]) Union :: exp (UnionScalar t1) -> PreSmartExp acc exp (UnionScalar t2) - PrjUnion :: exp (UnionScalar (t1, ())) + PrjUnion :: exp (UnionScalar '[t1]) -> PreSmartExp acc exp t1 VecPack :: KnownNat n diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index b882300b8..ea9d0eaa3 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -146,7 +146,7 @@ flattenProductType :: ProductType a -> TypeR (FlattenProduct a) flattenProductType PTNil = TupRunit flattenProductType (PTCons x xs) = TupRpair (TupRsingle (flattenSumType x)) (flattenProductType xs) -flattenSumType :: SumType a -> ScalarType (UnionScalar (FlattenSum a)) +flattenSumType :: SumType a -> ScalarType (UnionScalar a) flattenSumType STZero = UnionScalarType ZeroScalarType flattenSumType (STSucc x xs) = case flattenSumType xs of UnionScalarType xs' -> UnionScalarType (SuccScalarType (mkSingleType x) xs') @@ -271,7 +271,7 @@ unFlattenProduct :: ProductType a -> FlattenProduct a -> Product a unFlattenProduct PTNil () = Nil unFlattenProduct (PTCons x xs) (y, ys) = Cons (unFlattenSum x y) (unFlattenProduct xs ys) -unFlattenSum :: SumType a -> UnionScalar (FlattenSum a) -> Sum a +unFlattenSum :: SumType a -> UnionScalar a -> Sum a unFlattenSum (STSucc x xs) (PickScalar y) = Pick y unFlattenSum (STSucc x xs) (SkipScalar ys) = Skip $ unFlattenSum xs ys @@ -280,7 +280,7 @@ flattenProduct :: Product a -> FlattenProduct a flattenProduct Nil = () flattenProduct (Cons x xs) = (flattenSum x, flattenProduct xs) -flattenSum :: Sum a -> UnionScalar (FlattenSum a) +flattenSum :: Sum a -> UnionScalar a flattenSum (Pick x) = PickScalar x flattenSum (Skip xs) = SkipScalar (flattenSum xs) diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index 5e5ed68ba..aa44caadf 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -108,20 +108,11 @@ type family POStoEltR (cs :: Nat) fs :: Type where type family FlattenProduct (xss :: f [a]) = (r :: Type) | r -> f where FlattenProduct '[] = () - FlattenProduct (x ': xs) = (UnionScalar (FlattenSum x), FlattenProduct xs) - -type family FlattenSum (xs :: [a]) :: Type where - FlattenSum '[] = () - FlattenSum (x ': xs) = (x, FlattenSum xs) + FlattenProduct (x ': xs) = (UnionScalar x, FlattenProduct xs) type family FlattenProductType (xss :: [[a]]) :: Type where FlattenProductType '[] = () - FlattenProductType (x ': xs) = (UnionScalarType (FlattenSumType x), FlattenProductType xs) - -type family FlattenSumType (xs :: [a]) :: Type where - FlattenSumType '[] = () - FlattenSumType (x ': xs) = (x, FlattenSumType xs) - + FlattenProductType (x ': xs) = (UnionScalarType x, FlattenProductType xs) -- Scalar types -- ------------ @@ -190,12 +181,12 @@ class IsUnionScalar a where unionScalarType :: UnionScalarType a data UnionScalar x where - PickScalar :: a -> UnionScalar (a, b) - SkipScalar :: UnionScalar b -> UnionScalar (a, b) + PickScalar :: x -> UnionScalar (x ': xs) + SkipScalar :: UnionScalar xs -> UnionScalar (x ': xs) data UnionScalarType a where - SuccScalarType :: SingleType a -> UnionScalarType b -> UnionScalarType (a, b) - ZeroScalarType :: UnionScalarType () + SuccScalarType :: SingleType x -> UnionScalarType xs -> UnionScalarType (x ': xs) + ZeroScalarType :: UnionScalarType '[] data SingleType a where NumSingleType :: NumType a -> SingleType a @@ -594,8 +585,8 @@ instance IsScalar Undef where instance (IsUnionScalar a) => IsScalar (UnionScalar a) where scalarType = UnionScalarType (unionScalarType @a) -instance IsUnionScalar () where +instance IsUnionScalar '[] where unionScalarType = ZeroScalarType -instance (IsSingle a, IsUnionScalar b) => IsUnionScalar (a, b) where - unionScalarType = SuccScalarType (singleType @a) (unionScalarType @b) +instance (IsSingle x, IsUnionScalar xs) => IsUnionScalar (x ': xs) where + unionScalarType = SuccScalarType (singleType @x) (unionScalarType @xs) From dc522eab9d108882b88ba49bba1bc8ae536409b7 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 30 Jun 2022 11:04:24 +0200 Subject: [PATCH 62/67] use posable from hackage --- accelerate.cabal | 2 +- stack-8.10.yaml | 3 ++- stack-8.6.yaml | 1 + stack-8.8.yaml | 1 + stack-9.0.yaml | 2 +- 5 files changed, 6 insertions(+), 3 deletions(-) diff --git a/accelerate.cabal b/accelerate.cabal index 47cae90e9..bb472a57f 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -361,7 +361,7 @@ library , unique , unordered-containers >= 0.2 , vector >= 0.10 - , posable >= 0.9.0.0 + , posable >= 1.0.0.1 , ghc-typelits-knownnat >= 0.7.6 , generics-sop >= 0.5.1.1 , finite-typelits >= 0.1.4.2 diff --git a/stack-8.10.yaml b/stack-8.10.yaml index d0823dcbd..dd0347506 100644 --- a/stack-8.10.yaml +++ b/stack-8.10.yaml @@ -7,7 +7,8 @@ resolver: lts-18.25 packages: - . -# extra-deps: +extra-deps: +- posable-1.0.0.1 # Override default flag values for local packages and extra-deps # flags: {} diff --git a/stack-8.6.yaml b/stack-8.6.yaml index 5d3724662..4e7118946 100644 --- a/stack-8.6.yaml +++ b/stack-8.6.yaml @@ -12,6 +12,7 @@ extra-deps: - prettyprinter-ansi-terminal-1.1.3 - tasty-rerun-1.1.18 - text-1.2.4.1 +- posable-1.0.0.1 # Override default flag values for local packages and extra-deps # flags: {} diff --git a/stack-8.8.yaml b/stack-8.8.yaml index f9565e8b4..d4a9f59a0 100644 --- a/stack-8.8.yaml +++ b/stack-8.8.yaml @@ -9,6 +9,7 @@ packages: extra-deps: - formatting-7.1.3 - prettyprinter-1.7.1 +- posable-1.0.0.1 # Override default flag values for local packages and extra-deps # flags: {} diff --git a/stack-9.0.yaml b/stack-9.0.yaml index b5b82f65a..9579a9e81 100644 --- a/stack-9.0.yaml +++ b/stack-9.0.yaml @@ -8,7 +8,7 @@ packages: - . extra-deps: -- ../sizeof +- posable-1.0.0.1 # Override default flag values for local packages and extra-deps From a338b9fbeb6b08c139e7f7a94cfd44126b8d829b Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 30 Jun 2022 11:08:42 +0200 Subject: [PATCH 63/67] update version ranges to match posable --- accelerate.cabal | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/accelerate.cabal b/accelerate.cabal index bb472a57f..f8eae9527 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -362,9 +362,9 @@ library , unordered-containers >= 0.2 , vector >= 0.10 , posable >= 1.0.0.1 - , ghc-typelits-knownnat >= 0.7.6 - , generics-sop >= 0.5.1.1 - , finite-typelits >= 0.1.4.2 + , ghc-typelits-knownnat >= 0.6 + , generics-sop >= 0.4.0 + , finite-typelits >= 0.1.4 exposed-modules: -- The core language and reference implementation From 52a862139e1ad6462c7d54ce3df4b80b91b75876 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 30 Jun 2022 11:38:27 +0200 Subject: [PATCH 64/67] remove unused PrimShiftFinite operator --- src/Data/Array/Accelerate/AST.hs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 33bdc7c4e..bdd550562 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -10,7 +10,6 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeApplications #-} -{-# LANGUAGE DataKinds #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.AST @@ -147,7 +146,6 @@ import Data.Array.Accelerate.Representation.Stencil import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Representation.Vec -import Data.Array.Accelerate.Representation.POS (Finite) import Data.Array.Accelerate.Sugar.Foreign import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Type @@ -165,7 +163,6 @@ import qualified Language.Haskell.TH.Syntax as TH import GHC.TypeLits -import Data.Proxy -- Array expressions -- ----------------- @@ -672,9 +669,6 @@ data PrimFun sig where PrimAbs :: NumType a -> PrimFun (a -> a) PrimSig :: NumType a -> PrimFun (a -> a) - -- operator on Finite - PrimShiftFinite :: Proxy a -> PrimFun (Finite b -> Finite (a + b)) - -- operators from Integral PrimQuot :: IntegralType a -> PrimFun ((a, a) -> a) PrimRem :: IntegralType a -> PrimFun ((a, a) -> a) @@ -948,9 +942,9 @@ primFunType = \case floating = num . FloatingNumType tbool :: TypeR PrimBool - tbool = TupRpair (TupRsingle (SingleScalarType (NumSingleType (IntegralNumType (TypeTAG))))) TupRunit + tbool = TupRpair (TupRsingle (scalarType @TAG)) TupRunit tint :: TypeR Int - tint = TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeInt))) + tint = TupRsingle (scalarType @Int) -- Normal form data From b974e6218bc7909d9de41228f8440eaccf87c708 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 30 Jun 2022 11:39:55 +0200 Subject: [PATCH 65/67] revert unchanged file --- src/Data/Array/Accelerate/AST/Idx.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Data/Array/Accelerate/AST/Idx.hs b/src/Data/Array/Accelerate/AST/Idx.hs index acb5212fe..548453e2b 100644 --- a/src/Data/Array/Accelerate/AST/Idx.hs +++ b/src/Data/Array/Accelerate/AST/Idx.hs @@ -32,7 +32,6 @@ module Data.Array.Accelerate.AST.Idx ( ) where import Language.Haskell.TH.Extra -import Data.Array.Accelerate.Type #ifndef ACCELERATE_INTERNAL_CHECKS import Data.Type.Equality ((:~:)(Refl)) @@ -112,3 +111,4 @@ pattern VoidIdx a <- (\case{} -> a) data PairIdx p a where PairIdxLeft :: PairIdx (a, b) a PairIdxRight :: PairIdx (a, b) b + From 12fa79a2f7d255245c5041a88b2d84926bedb874 Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 30 Jun 2022 11:56:59 +0200 Subject: [PATCH 66/67] bit of cleanup --- .../Array/Accelerate/Pattern/Matchable.hs | 13 +- src/Data/Array/Accelerate/Pattern/TH.hs | 437 ------------------ .../Array/Accelerate/Representation/Type.hs | 7 +- src/Data/Array/Accelerate/Smart.hs | 2 +- src/Data/Array/Accelerate/Sugar/Elt.hs | 9 +- src/Data/Array/Accelerate/Sugar/POS.hs | 1 - 6 files changed, 4 insertions(+), 465 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs index 86697201b..2c9ef60c6 100644 --- a/src/Data/Array/Accelerate/Pattern/Matchable.hs +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -16,7 +16,7 @@ {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Accelerate.Pattern.Matchable where +module Data.Array.Accelerate.Pattern.Matchable (Matchable(..)) where import Data.Array.Accelerate.Smart as Smart import GHC.TypeLits @@ -118,9 +118,6 @@ instance Matchable Bool where makeTag :: TAG -> SmartExp TAG makeTag x = SmartExp (Const (SingleScalarType (NumSingleType (IntegralNumType TypeTAG))) x) -tagType :: TupR ScalarType TAG -tagType = TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeTAG))) - instance (POSable (Maybe a), POSable a) => Matchable (Maybe a) where build n fs = case sameNat (Proxy @(Choices a)) (Proxy @0) of -- a has 0 valid choices (which means we cannot create a Just of this type) @@ -292,14 +289,6 @@ instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) Nothing -> error "Impossible type encountered" -undefPairs :: forall xs . ProductType xs -> SmartExp (FlattenProduct (Merge '[] (xs ++ '[]))) -undefPairs PTNil = SmartExp Smart.Nil -undefPairs (PTCons x xs) = SmartExp (Pair (SmartExp (Union (SmartExp (LiftUnion (unExp $ constant POS.Undef))))) (undefPairs xs)) - -mergePairs :: forall xs . ProductType xs -> SmartExp (FlattenProduct xs) -> SmartExp (FlattenProduct (Merge '[] (xs ++ '[]))) -mergePairs PTNil _ = SmartExp Smart.Nil -mergePairs (PTCons x xs) y = SmartExp (Pair (SmartExp (Union (SmartExp (Prj PairIdxLeft y)))) (mergePairs xs (SmartExp (Prj PairIdxRight y)))) - -- like combineProducts, but lifted to the AST buildTAG :: (All POSable xs) => NP Exp xs -> Exp TAG buildTAG SOP.Nil = Exp $ makeTag 0 diff --git a/src/Data/Array/Accelerate/Pattern/TH.hs b/src/Data/Array/Accelerate/Pattern/TH.hs index f14de628d..bf26ee8bb 100644 --- a/src/Data/Array/Accelerate/Pattern/TH.hs +++ b/src/Data/Array/Accelerate/Pattern/TH.hs @@ -16,440 +16,3 @@ module Data.Array.Accelerate.Pattern.TH ( -- mkPatterns, ) where - -import Data.Array.Accelerate.AST.Idx -import Data.Array.Accelerate.Pattern -import Data.Array.Accelerate.Representation.Tag -import Data.Array.Accelerate.Smart -import Data.Array.Accelerate.Sugar.Elt -import Data.Array.Accelerate.Type - -import Control.Monad -import Data.Bits -import Data.Char -import Data.List ( (\\), foldl' ) -import Language.Haskell.TH.Extra hiding ( Exp, Match, match ) -import Numeric -import Text.Printf -import qualified Language.Haskell.TH.Extra as TH - -import GHC.Stack - - --- | As 'mkPattern', but for a list of types --- --- mkPatterns :: [Name] -> DecsQ --- mkPatterns nms = concat <$> mapM mkPattern nms - --- -- | Generate pattern synonyms for the given simple (Haskell'98) sum or --- -- product data type. --- -- --- -- Constructor and record selectors are renamed to add a trailing --- -- underscore if it does not exist, or to remove it if it does. For infix --- -- constructors, the name is prepended with a colon ':'. For example: --- -- --- -- > data Point = Point { xcoord_ :: Float, ycoord_ :: Float } --- -- > deriving (Generic, Elt) --- -- --- -- Will create the pattern synonym: --- -- --- -- > Point_ :: Exp Float -> Exp Float -> Exp Point --- -- --- -- together with the selector functions --- -- --- -- > xcoord :: Exp Point -> Exp Float --- -- > ycoord :: Exp Point -> Exp Float --- -- --- mkPattern :: Name -> DecsQ --- mkPattern nm = do --- info <- reify nm --- case info of --- TyConI dec -> mkDec dec --- _ -> fail "mkPatterns: expected the name of a newtype or datatype" - --- mkDec :: Dec -> DecsQ --- mkDec dec = --- case dec of --- DataD _ nm tv _ cs _ -> mkDataD nm tv cs --- NewtypeD _ nm tv _ c _ -> mkNewtypeD nm tv c --- _ -> fail "mkPatterns: expected the name of a newtype or datatype" - --- mkNewtypeD :: Name -> [TyVarBndr ()] -> Con -> DecsQ --- mkNewtypeD tn tvs c = mkDataD tn tvs [c] - --- mkDataD :: Name -> [TyVarBndr ()] -> [Con] -> DecsQ --- mkDataD tn tvs cs = do --- (pats, decs) <- unzip <$> go cs --- comp <- pragCompleteD pats Nothing --- return $ comp : concat decs --- where --- -- For single-constructor types we create the pattern synonym for the --- -- type directly in terms of Pattern --- go [] = fail "mkPatterns: empty data declarations not supported" --- go [c] = return <$> mkConP tn tvs c --- go _ = go' [] (map fieldTys cs) ctags cs - --- -- For sum-types, when creating the pattern for an individual --- -- constructor we need to know about the types of the fields all other --- -- constructors as well --- go' prev (this:next) (tag:tags) (con:cons) = do --- r <- mkConS tn tvs prev next tag con --- rs <- go' (this:prev) next tags cons --- return (r : rs) --- go' _ [] [] [] = return [] --- go' _ _ _ _ = fail "mkPatterns: unexpected error" - --- fieldTys (NormalC _ fs) = map snd fs --- fieldTys (RecC _ fs) = map (\(_,_,t) -> t) fs --- fieldTys (InfixC a _ b) = [snd a, snd b] --- fieldTys _ = fail "mkPatterns: only constructors for \"vanilla\" syntax are supported" - --- -- TODO: The GTags class demonstrates a way to generate the tags for --- -- a given constructor, rather than backwards-engineering the structure --- -- as we've done here. We should use that instead! --- -- --- ctags = --- let n = length cs --- m = n `quot` 2 --- l = take m (iterate (True:) [False]) --- r = take (n-m) (iterate (True:) [True]) --- -- --- bitsToTag = foldl' f 0 --- where --- f i False = i `shiftL` 1 --- f i True = setBit (i `shiftL` 1) 0 --- in --- map bitsToTag (l ++ r) - - --- mkConP :: Name -> [TyVarBndr ()] -> Con -> Q (Name, [Dec]) --- mkConP tn' tvs' con' = do --- checkExts [ PatternSynonyms ] --- case con' of --- NormalC cn fs -> mkNormalC tn' cn (map tyVarBndrName tvs') (map snd fs) --- RecC cn fs -> mkRecC tn' cn (map tyVarBndrName tvs') (map (rename . fst3) fs) (map thd3 fs) --- InfixC a cn b -> mkInfixC tn' cn (map tyVarBndrName tvs') [snd a, snd b] --- _ -> fail "mkPatterns: only constructors for \"vanilla\" syntax are supported" --- where --- mkNormalC :: Name -> Name -> [Name] -> [Type] -> Q (Name, [Dec]) --- mkNormalC tn cn tvs fs = do --- xs <- replicateM (length fs) (newName "_x") --- r <- sequence [ patSynSigD pat sig --- , patSynD pat --- (prefixPatSyn xs) --- implBidir --- [p| Pattern $(tupP (map varP xs)) |] --- ] --- return (pat, r) --- where --- pat = rename cn --- sig = forallT --- (map (`plainInvisTV` specifiedSpec) tvs) --- (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) --- (foldr (\t ts -> [t| $t -> $ts |]) --- [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] --- (map (\t -> [t| Exp $(return t) |]) fs)) - --- mkRecC :: Name -> Name -> [Name] -> [Name] -> [Type] -> Q (Name, [Dec]) --- mkRecC tn cn tvs xs fs = do --- r <- sequence [ patSynSigD pat sig --- , patSynD pat --- (recordPatSyn xs) --- implBidir --- [p| Pattern $(tupP (map varP xs)) |] --- ] --- return (pat, r) --- where --- pat = rename cn --- sig = forallT --- (map (`plainInvisTV` specifiedSpec) tvs) --- (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) --- (foldr (\t ts -> [t| $t -> $ts |]) --- [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] --- (map (\t -> [t| Exp $(return t) |]) fs)) - --- mkInfixC :: Name -> Name -> [Name] -> [Type] -> Q (Name, [Dec]) --- mkInfixC tn cn tvs fs = do --- mf <- reifyFixity cn --- _a <- newName "_a" --- _b <- newName "_b" --- r <- sequence [ patSynSigD pat sig --- , patSynD pat --- (infixPatSyn _a _b) --- implBidir --- [p| Pattern $(tupP [varP _a, varP _b]) |] --- ] --- r' <- case mf of --- Nothing -> return r --- Just f -> return (InfixD f pat : r) --- return (pat, r') --- where --- pat = mkName (':' : nameBase cn) --- sig = forallT --- (map (`plainInvisTV` specifiedSpec) tvs) --- (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) --- (foldr (\t ts -> [t| $t -> $ts |]) --- [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] --- (map (\t -> [t| Exp $(return t) |]) fs)) - --- mkConS :: Name -> [TyVarBndr ()] -> [[Type]] -> [[Type]] -> Word8 -> Con -> Q (Name, [Dec]) --- mkConS tn' tvs' prev' next' tag' con' = do --- checkExts [GADTs, PatternSynonyms, ScopedTypeVariables, TypeApplications, ViewPatterns] --- case con' of --- NormalC cn fs -> mkNormalC tn' cn tag' (map tyVarBndrName tvs') prev' (map snd fs) next' --- RecC cn fs -> mkRecC tn' cn tag' (map tyVarBndrName tvs') (map (rename . fst3) fs) prev' (map thd3 fs) next' --- InfixC a cn b -> mkInfixC tn' cn tag' (map tyVarBndrName tvs') prev' [snd a, snd b] next' --- _ -> fail "mkPatterns: only constructors for \"vanilla\" syntax are supported" --- where --- mkNormalC :: Name -> Name -> Word8 -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) --- mkNormalC tn cn tag tvs ps fs ns = do --- let pat = rename cn --- (fun_build, dec_build) <- mkBuild tn (nameBase cn) tvs tag ps fs ns --- (fun_match, dec_match) <- mkMatch tn (nameBase pat) (nameBase cn) tvs tag ps fs ns --- dec_pat <- mkNormalC_pattern tn pat tvs fs fun_build fun_match --- return $ (pat, concat [dec_pat, dec_build, dec_match]) - --- mkRecC :: Name -> Name -> Word8 -> [Name] -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) --- mkRecC tn cn tag tvs xs ps fs ns = do --- let pat = rename cn --- (fun_build, dec_build) <- mkBuild tn (nameBase cn) tvs tag ps fs ns --- (fun_match, dec_match) <- mkMatch tn (nameBase pat) (nameBase cn) tvs tag ps fs ns --- dec_pat <- mkRecC_pattern tn pat tvs xs fs fun_build fun_match --- return $ (pat, concat [dec_pat, dec_build, dec_match]) - --- mkInfixC :: Name -> Name -> Word8 -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) --- mkInfixC tn cn tag tvs ps fs ns = do --- let pat = mkName (':' : nameBase cn) --- (fun_build, dec_build) <- mkBuild tn (zencode (nameBase cn)) tvs tag ps fs ns --- (fun_match, dec_match) <- mkMatch tn ("(" ++ nameBase pat ++ ")") (zencode (nameBase cn)) tvs tag ps fs ns --- dec_pat <- mkInfixC_pattern tn cn pat tvs fs fun_build fun_match --- return $ (pat, concat [dec_pat, dec_build, dec_match]) - --- mkNormalC_pattern :: Name -> Name -> [Name] -> [Type] -> Name -> Name -> Q [Dec] --- mkNormalC_pattern tn pat tvs fs build match = do --- xs <- replicateM (length fs) (newName "_x") --- r <- sequence [ patSynSigD pat sig --- , patSynD pat --- (prefixPatSyn xs) --- (explBidir [clause [] (normalB (varE build)) []]) --- (parensP $ viewP (varE match) [p| Just $(tupP (map varP xs)) |]) --- ] --- return r --- where --- sig = forallT --- (map (`plainInvisTV` specifiedSpec) tvs) --- (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs)) --- (foldr (\t ts -> [t| $t -> $ts |]) --- [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] --- (map (\t -> [t| Exp $(return t) |]) fs)) - --- mkRecC_pattern :: Name -> Name -> [Name] -> [Name] -> [Type] -> Name -> Name -> Q [Dec] --- mkRecC_pattern tn pat tvs xs fs build match = do --- r <- sequence [ patSynSigD pat sig --- , patSynD pat --- (recordPatSyn xs) --- (explBidir [clause [] (normalB (varE build)) []]) --- (parensP $ viewP (varE match) [p| Just $(tupP (map varP xs)) |]) --- ] --- return r --- where --- sig = forallT --- (map (`plainInvisTV` specifiedSpec) tvs) --- (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs)) --- (foldr (\t ts -> [t| $t -> $ts |]) --- [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] --- (map (\t -> [t| Exp $(return t) |]) fs)) - --- mkInfixC_pattern :: Name -> Name -> Name -> [Name] -> [Type] -> Name -> Name -> Q [Dec] --- mkInfixC_pattern tn cn pat tvs fs build match = do --- mf <- reifyFixity cn --- _a <- newName "_a" --- _b <- newName "_b" --- r <- sequence [ patSynSigD pat sig --- , patSynD pat --- (infixPatSyn _a _b) --- (explBidir [clause [] (normalB (varE build)) []]) --- (parensP $ viewP (varE match) [p| Just $(tupP [varP _a, varP _b]) |]) --- ] --- r' <- case mf of --- Nothing -> return r --- Just f -> return (InfixD f pat : r) --- return r' --- where --- sig = forallT --- (map (`plainInvisTV` specifiedSpec) tvs) --- (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs)) --- (foldr (\t ts -> [t| $t -> $ts |]) --- [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] --- (map (\t -> [t| Exp $(return t) |]) fs)) - --- mkBuild :: Name -> String -> [Name] -> Word8 -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) --- mkBuild tn cn tvs tag fs0 fs fs1 = do --- fun <- newName ("_build" ++ cn) --- xs <- replicateM (length fs) (newName "_x") --- let --- vs = foldl' (\es e -> [| SmartExp ($es `Pair` $e) |]) [| SmartExp Nil |] --- $ map (\t -> [| unExp $(varE 'undef `appTypeE` return t) |] ) (concat (reverse fs0)) --- ++ map varE xs --- ++ map (\t -> [| unExp $(varE 'undef `appTypeE` return t) |] ) (concat fs1) - --- tagged = [| Exp $ SmartExp $ Pair (SmartExp (Const (SingleScalarType (NumSingleType (IntegralNumType TypeTAG))) $(litE (IntegerL (toInteger tag))))) $vs |] --- body = clause (map (\x -> [p| (Exp $(varP x)) |]) xs) (normalB tagged) [] - --- r <- sequence [ sigD fun sig --- , funD fun [body] --- ] --- return (fun, r) --- where --- sig = forallT --- (map (`plainInvisTV` specifiedSpec) tvs) --- (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) --- (foldr (\t ts -> [t| $t -> $ts |]) --- [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] --- (map (\t -> [t| Exp $(return t) |]) fs)) - - --- mkMatch :: Name -> String -> String -> [Name] -> Word8 -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) --- mkMatch tn pn cn tvs tag fs0 fs fs1 = do --- fun <- newName ("_match" ++ cn) --- e <- newName "_e" --- x <- newName "_x" --- (ps,es) <- extract vs [| Prj PairIdxRight $(varE x) |] [] [] --- unbind <- isExtEnabled RebindableSyntax --- let --- eqE = if unbind then letE [funD (mkName "==") [clause [] (normalB (varE '(==))) []]] else id --- lhs = [p| (Exp $(varP e)) |] --- body = normalB $ eqE $ caseE (varE e) --- [ TH.match (conP 'SmartExp [(conP 'Match [matchP ps, varP x])]) (normalB [| Just $(tupE es) |]) [] --- , TH.match (conP 'SmartExp [(recP 'Match [])]) (normalB [| Nothing |]) [] --- , TH.match wildP (normalB [| error $error_msg |]) [] --- ] - --- r <- sequence [ sigD fun sig --- , funD fun [clause [lhs] body []] --- ] --- return (fun, r) --- where --- sig = forallT --- (map (`plainInvisTV` specifiedSpec) tvs) --- (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs)) --- [t| Exp $(foldl' appT (conT tn) (map varT tvs)) -> Maybe $(tupT (map (\t -> [t| Exp $(return t) |]) fs)) |] - --- matchP us = [p| TagRtag $(litP (IntegerL (toInteger tag))) $pat |] --- where --- pat = [p| $(foldl (\ps p -> [p| TagRpair $ps $p |]) [p| TagRunit |] us) |] - --- extract [] _ ps es = return (ps, es) --- extract (u:us) x ps es = do --- _u <- newName "_u" --- let x' = [| Prj PairIdxLeft (SmartExp $x) |] --- if not u --- then extract us x' (wildP:ps) es --- else extract us x' (varP _u:ps) ([| Exp (SmartExp (Match $(varE _u) (SmartExp (Prj PairIdxRight (SmartExp $x))))) |] : es) - --- vs = reverse --- $ [ False | _ <- concat fs0 ] ++ [ True | _ <- fs ] ++ [ False | _ <- concat fs1 ] - --- error_msg = --- let pv = unwords --- $ take (length fs + 1) --- $ concatMap (map reverse) --- $ iterate (concatMap (\xs -> [ x:xs | x <- ['a'..'z'] ])) [""] --- in stringE $ unlines --- [ "Embedded pattern synonym used outside 'match' context." --- , "" --- , "To use case statements in the embedded language the case statement must" --- , "be applied as an n-ary function to the 'match' operator. For single" --- , "argument case statements this can be done inline using LambdaCase, for" --- , "example:" --- , "" --- , "> x & match \\case" --- , printf "> %s%s -> ..." pn pv --- , printf "> _%s -> ..." (replicate (length pn + length pv - 1) ' ') --- ] - --- fst3 :: (a,b,c) -> a --- fst3 (a,_,_) = a - --- thd3 :: (a,b,c) -> c --- thd3 (_,_,c) = c - --- rename :: Name -> Name --- rename nm = --- let --- split acc [] = (reverse acc, '\0') -- shouldn't happen --- split acc [l] = (reverse acc, l) --- split acc (l:ls) = split (l:acc) ls --- -- --- nm' = nameBase nm --- (base, suffix) = split [] nm' --- in --- case suffix of --- '_' -> mkName base --- _ -> mkName (nm' ++ "_") - --- checkExts :: [Extension] -> Q () --- checkExts req = do --- enabled <- extsEnabled --- let missing = req \\ enabled --- unless (null missing) . fail . unlines --- $ printf "You must enable the following language extensions to generate pattern synonyms:" --- : map (printf " {-# LANGUAGE %s #-}" . show) missing - --- -- A simplified version of that stolen from GHC/Utils/Encoding.hs --- -- --- type EncodedString = String - --- zencode :: String -> EncodedString --- zencode [] = [] --- zencode (h:rest) = encode_digit h ++ go rest --- where --- go [] = [] --- go (c:cs) = encode_ch c ++ go cs - --- unencoded_char :: Char -> Bool --- unencoded_char 'z' = False --- unencoded_char 'Z' = False --- unencoded_char c = isAlphaNum c - --- encode_digit :: Char -> EncodedString --- encode_digit c | isDigit c = encode_as_unicode_char c --- | otherwise = encode_ch c - --- encode_ch :: Char -> EncodedString --- encode_ch c | unencoded_char c = [c] -- Common case first --- encode_ch '(' = "ZL" --- encode_ch ')' = "ZR" --- encode_ch '[' = "ZM" --- encode_ch ']' = "ZN" --- encode_ch ':' = "ZC" --- encode_ch 'Z' = "ZZ" --- encode_ch 'z' = "zz" --- encode_ch '&' = "za" --- encode_ch '|' = "zb" --- encode_ch '^' = "zc" --- encode_ch '$' = "zd" --- encode_ch '=' = "ze" --- encode_ch '>' = "zg" --- encode_ch '#' = "zh" --- encode_ch '.' = "zi" --- encode_ch '<' = "zl" --- encode_ch '-' = "zm" --- encode_ch '!' = "zn" --- encode_ch '+' = "zp" --- encode_ch '\'' = "zq" --- encode_ch '\\' = "zr" --- encode_ch '/' = "zs" --- encode_ch '*' = "zt" --- encode_ch '_' = "zu" --- encode_ch '%' = "zv" --- encode_ch c = encode_as_unicode_char c - --- encode_as_unicode_char :: Char -> EncodedString --- encode_as_unicode_char c --- = 'z' --- : if isDigit (head hex_str) then hex_str --- else '0':hex_str --- where --- hex_str = showHex (ord c) "U" - diff --git a/src/Data/Array/Accelerate/Representation/Type.hs b/src/Data/Array/Accelerate/Representation/Type.hs index 57318798c..477f09a00 100644 --- a/src/Data/Array/Accelerate/Representation/Type.hs +++ b/src/Data/Array/Accelerate/Representation/Type.hs @@ -20,7 +20,6 @@ module Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Type import Data.Primitive.Vec -import Data.Array.Accelerate.Representation.POS import Formatting import Language.Haskell.TH.Extra @@ -44,14 +43,10 @@ data TupR s a where TupRsingle :: s a -> TupR s a TupRpair :: TupR s a -> TupR s b -> TupR s (a, b) --- productToTupR :: Product a -> TypeR (FlattenProduct a) --- productToTupR Nil = TupRunit --- productToTupR (Cons x xs) = TupRpair x (productToTupR xs) - instance Show (TupR ScalarType a) where show TupRunit = "()" show (TupRsingle t) = show t - show (TupRpair a b) = show a ++ " ✕ " ++ show b + show (TupRpair a b) = "(" ++ show a ++ "," ++ show b ++ ")" formatTypeR :: Format r (TypeR a -> r) formatTypeR = later $ \case diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index dfa798d03..5c6bfa87e 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -541,7 +541,7 @@ data PreSmartExp acc exp t where -> PreSmartExp acc exp sh Case :: exp a - -> [(TagR b, exp b)] + -> [(TagR a, exp b)] -> PreSmartExp acc exp b Cond :: exp PrimBool diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index ea9d0eaa3..6cf9fee8b 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -32,7 +32,6 @@ import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Representation.POS import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Sugar.POS () -import Data.Array.Accelerate.Representation.POS import Data.Array.Accelerate.Type import Data.Char @@ -283,12 +282,6 @@ flattenProduct (Cons x xs) = (flattenSum x, flattenProduct xs) flattenSum :: Sum a -> UnionScalar a flattenSum (Pick x) = PickScalar x flattenSum (Skip xs) = SkipScalar (flattenSum xs) - --- untag :: TypeR t -> TagR t --- untag TupRunit = TagRunit --- untag (TupRsingle t) = TagRundef t --- untag (TupRpair ta tb) = TagRpair (untag ta) (untag tb) - -- Note: [Deriving Elt] -- @@ -418,7 +411,7 @@ runQ $ do ss <- mapM mkSimple (integralTypes ++ floatingTypes) -- TODO: ns <- mapM mkNewtype newtypes - -- ts <- mapM mkTuple [2..8] + -- ts <- mapM mkTuple [2..16] -- vs <- sequence [ mkVecElt t n | t <- integralTypes ++ floatingTypes, n <- [2,3,4,8,16] ] return (concat ss ++ concat ns) diff --git a/src/Data/Array/Accelerate/Sugar/POS.hs b/src/Data/Array/Accelerate/Sugar/POS.hs index 3c943968e..f4f36a4e0 100644 --- a/src/Data/Array/Accelerate/Sugar/POS.hs +++ b/src/Data/Array/Accelerate/Sugar/POS.hs @@ -125,4 +125,3 @@ runQ $ do -- ts <- mapM mkTuple [2..16] -- vs <- sequence [ mkVecElt t n | t <- integralTypes ++ floatingTypes, n <- [2,3,4,8,16] ] return (concat si ++ concat sf ++ concat ns ++ concat ts ++ concat nts) - \ No newline at end of file From e2dcdace825f560cf9dac54408efe58f7b664f6f Mon Sep 17 00:00:00 2001 From: Rick van Hoef Date: Thu, 30 Jun 2022 12:38:14 +0200 Subject: [PATCH 67/67] remove unused imports --- src/Data/Array/Accelerate/Type.hs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index aa44caadf..7b1f03d0d 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -76,7 +76,6 @@ import Data.Array.Accelerate.Representation.POS import Data.Primitive.Vec import Data.Bits -import Data.Proxy import Data.Int import Data.Primitive.Types import Data.Type.Equality @@ -92,7 +91,6 @@ import Text.Printf import GHC.Prim import GHC.TypeLits -import Unsafe.Coerce -- | The type of the runtime value used to distinguish constructor