From e6cd2ebb24144b31c3377facc896bb4b92e77289 Mon Sep 17 00:00:00 2001 From: Hugo <hpeters1024@gmail.com> Date: Tue, 19 Oct 2021 23:56:05 +0200 Subject: [PATCH 01/86] Vector indexing operations and empty vector constructor --- .gitignore | 2 ++ src/Data/Array/Accelerate.hs | 3 ++ src/Data/Array/Accelerate/AST.hs | 12 ++++++++ src/Data/Array/Accelerate/Classes/Enum.hs | 3 +- src/Data/Array/Accelerate/Classes/Vector.hs | 31 +++++++++++++++++++++ src/Data/Array/Accelerate/Smart.hs | 12 ++++++++ src/Data/Primitive/Vec.hs | 14 ++++++++++ 7 files changed, 75 insertions(+), 2 deletions(-) create mode 100644 src/Data/Array/Accelerate/Classes/Vector.hs diff --git a/.gitignore b/.gitignore index 2dc9bad21..eec9590ea 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,5 @@ /docs/_build *.hi *.o + +hie.yaml diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index ff1729f27..5654cd9f9 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -310,6 +310,7 @@ module Data.Array.Accelerate ( -- ** SIMD vectors Vec, VecElt, + mkVec, -- ** Type classes -- *** Basic type classes @@ -317,6 +318,7 @@ module Data.Array.Accelerate ( Ord(..), Ordering(..), pattern LT_, pattern EQ_, pattern GT_, Enum, succ, pred, Bounded, minBound, maxBound, + Vectoring(..), -- Functor(..), (<$>), ($>), void, -- Monad(..), @@ -445,6 +447,7 @@ import Data.Array.Accelerate.Classes.Rational import Data.Array.Accelerate.Classes.RealFloat import Data.Array.Accelerate.Classes.RealFrac import Data.Array.Accelerate.Classes.ToFloating +import Data.Array.Accelerate.Classes.Vector import Data.Array.Accelerate.Data.Either import Data.Array.Accelerate.Data.Maybe import Data.Array.Accelerate.Language diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index c84f5723f..0a887802f 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -748,6 +748,9 @@ data PrimFun sig where PrimLOr :: PrimFun ((PrimBool, PrimBool) -> PrimBool) PrimLNot :: PrimFun (PrimBool -> PrimBool) + -- local array operators + PrimVectorIndex :: KnownNat n => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, i) -> a) + -- general conversion between types PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b) PrimToFloating :: NumType a -> FloatingType b -> PrimFun (a -> b) @@ -924,6 +927,12 @@ primFunType = \case PrimLOr -> binary' tbool PrimLNot -> unary' tbool +-- Local Vector operations + PrimVectorIndex v'@(VectorType _ a) i' -> + let v = singleVector v' + i = integral i' + in (v `TupRpair` i, single a) + -- general conversion between types PrimFromIntegral a b -> unary (integral a) (num b) PrimToFloating a b -> unary (num a) (floating b) @@ -936,6 +945,7 @@ primFunType = \case compare' a = binary (single a) tbool single = TupRsingle . SingleScalarType + singleVector = TupRsingle . VectorScalarType num = TupRsingle . SingleScalarType . NumSingleType integral = num . IntegralNumType floating = num . FloatingNumType @@ -1165,6 +1175,7 @@ rnfPrimFun (PrimMin t) = rnfSingleType t rnfPrimFun PrimLAnd = () rnfPrimFun PrimLOr = () rnfPrimFun PrimLNot = () +rnfPrimFun (PrimVectorIndex v i) = rnfVectorType v `seq` rnfIntegralType i rnfPrimFun (PrimFromIntegral i n) = rnfIntegralType i `seq` rnfNumType n rnfPrimFun (PrimToFloating n f) = rnfNumType n `seq` rnfFloatingType f @@ -1391,6 +1402,7 @@ liftPrimFun (PrimMin t) = [|| PrimMin $$(liftSingleType t) ||] liftPrimFun PrimLAnd = [|| PrimLAnd ||] liftPrimFun PrimLOr = [|| PrimLOr ||] liftPrimFun PrimLNot = [|| PrimLNot ||] +liftPrimFun (PrimVectorIndex v i) = [||PrimVectorIndex $$(liftVectorType v) $$(liftIntegralType i) ||] liftPrimFun (PrimFromIntegral ta tb) = [|| PrimFromIntegral $$(liftIntegralType ta) $$(liftNumType tb) ||] liftPrimFun (PrimToFloating ta tb) = [|| PrimToFloating $$(liftNumType ta) $$(liftFloatingType tb) ||] diff --git a/src/Data/Array/Accelerate/Classes/Enum.hs b/src/Data/Array/Accelerate/Classes/Enum.hs index 84b344273..10e946ee5 100644 --- a/src/Data/Array/Accelerate/Classes/Enum.hs +++ b/src/Data/Array/Accelerate/Classes/Enum.hs @@ -187,8 +187,7 @@ defaultFromEnum = preludeError "fromEnum" preludeError :: String -> a preludeError x = error - $ unlines [ printf "Prelude.%s is not supported for Accelerate types" x - , "" + $ unlines [ printf "Prelude.%s is not supported for Accelerate types" x , "" , "These Prelude.Enum instances are present only to fulfil superclass" , "constraints for subsequent classes in the standard Haskell numeric hierarchy." ] diff --git a/src/Data/Array/Accelerate/Classes/Vector.hs b/src/Data/Array/Accelerate/Classes/Vector.hs new file mode 100644 index 000000000..69f62e7eb --- /dev/null +++ b/src/Data/Array/Accelerate/Classes/Vector.hs @@ -0,0 +1,31 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MonoLocalBinds #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE GADTs #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} +-- | +-- Module : Data.Array.Accelerate.Classes.Vector +-- Copyright : [2016..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- +module Data.Array.Accelerate.Classes.Vector where + +import GHC.TypeLits +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Vec +import Data.Array.Accelerate.Smart +import Data.Primitive.Vec + +class Vectoring a b c | a -> b where + indexAt :: a -> c -> b + +instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) (Exp Int) where + indexAt = mkVectorIndex + + diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 8fa577f41..14c043d1f 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -12,6 +12,7 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE PolyKinds #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Smart @@ -71,6 +72,9 @@ module Data.Array.Accelerate.Smart ( -- ** Smart constructors for type coercion functions mkFromIntegral, mkToFloating, mkBitcast, mkCoerce, Coerce(..), + -- ** Smart constructors for vector operations + mkVectorIndex, + -- ** Auxiliary functions ($$), ($$$), ($$$$), ($$$$$), ApplyAcc(..), @@ -83,6 +87,7 @@ module Data.Array.Accelerate.Smart ( ) where +import Data.Proxy import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.Error import Data.Array.Accelerate.Representation.Array @@ -95,6 +100,7 @@ import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Representation.Vec import Data.Array.Accelerate.Sugar.Array ( Arrays ) import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Sugar.Foreign import Data.Array.Accelerate.Sugar.Shape ( (:.)(..) ) import Data.Array.Accelerate.Type @@ -1172,6 +1178,12 @@ mkLNot (Exp a) = mkExp $ SmartExp (PrimApp PrimLNot x) `Pair` SmartExp Nil where x = SmartExp $ Prj PairIdxLeft a +-- Operators from Vec +mkVectorIndex :: forall n a. (KnownNat n, Elt a, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a +mkVectorIndex = let n :: Int + n = fromIntegral $ natVal $ Proxy @n + in mkPrimBinary $ PrimVectorIndex @n (VectorType n singleType) integralType + -- Numeric conversions mkFromIntegral :: (Elt a, Elt b, IsIntegral (EltR a), IsNum (EltR b)) => Exp a -> Exp b diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 34a77635b..93b0395c0 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -10,6 +10,7 @@ {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE TypeApplications #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Primitive.Vec @@ -31,12 +32,16 @@ module Data.Primitive.Vec ( Vec8, pattern Vec8, Vec16, pattern Vec16, + mkVec, + listOfVec, liftVec, ) where +import Data.Proxy import Control.Monad.ST +import Control.Monad.Reader import Data.Primitive.ByteArray import Data.Primitive.Types import Data.Text.Prettyprint.Doc @@ -83,6 +88,14 @@ import GHC.Word -- data Vec (n :: Nat) a = Vec ByteArray# +mkVec :: forall n a. (KnownNat n, Prim a) => [a] -> Vec n a +mkVec vs = runST $ do + let n :: Int = fromIntegral $ natVal $ Proxy @n + mba <- newByteArray (n * sizeOf (undefined :: a)) + zipWithM_ (writeByteArray mba) [0..n] vs + ByteArray ba# <- unsafeFreezeByteArray mba + return $! Vec ba# + type role Vec nominal representational instance (Show a, Prim a, KnownNat n) => Show (Vec n a) where @@ -259,6 +272,7 @@ packVec16 a b c d e f g h i j k l m n o p = runST $ do ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# + -- O(n) at runtime to copy from the Addr# to the ByteArray#. We should be able -- to do this without copying, but I don't think the definition of ByteArray# is -- exported (or it is deeply magical). From 0c80d44a4c150d479d64af4c4b6d727eb6aa9d72 Mon Sep 17 00:00:00 2001 From: Hugo <hpeters1024@gmail.com> Date: Wed, 20 Oct 2021 16:46:35 +0200 Subject: [PATCH 02/86] created empty vector lifted in Exp --- src/Data/Array/Accelerate.hs | 4 +-- src/Data/Array/Accelerate/AST.hs | 35 +++++++++++++-------- src/Data/Array/Accelerate/Classes/Vector.hs | 27 ++++++++++------ src/Data/Array/Accelerate/Smart.hs | 8 ++++- src/Data/Primitive/Vec.hs | 2 -- 5 files changed, 48 insertions(+), 28 deletions(-) diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index 5654cd9f9..8811695b8 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -310,7 +310,7 @@ module Data.Array.Accelerate ( -- ** SIMD vectors Vec, VecElt, - mkVec, + Vectoring(..), -- ** Type classes -- *** Basic type classes @@ -318,7 +318,7 @@ module Data.Array.Accelerate ( Ord(..), Ordering(..), pattern LT_, pattern EQ_, pattern GT_, Enum, succ, pred, Bounded, minBound, maxBound, - Vectoring(..), + -- Functor(..), (<$>), ($>), void, -- Monad(..), diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 0a887802f..a07920466 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -655,6 +655,9 @@ data PrimConst ty where -- constant from Floating PrimPi :: FloatingType a -> PrimConst a + -- constant for empty Vec + PrimVectorCreate :: KnownNat n => VectorType (Vec n a) -> PrimConst (Vec n a) + -- |Primitive scalar operations -- @@ -828,7 +831,7 @@ expType = \case While _ (Lam lhs _) _ -> lhsToTupR lhs While{} -> error "What's the matter, you're running in the shadows" Const tR _ -> TupRsingle tR - PrimConst c -> TupRsingle $ SingleScalarType $ primConstType c + PrimConst c -> TupRsingle $ primConstType c PrimApp f _ -> snd $ primFunType f Index (Var repr _) _ -> arrayRtype repr LinearIndex (Var repr _) _ -> arrayRtype repr @@ -837,17 +840,21 @@ expType = \case Undef tR -> TupRsingle tR Coerce _ tR _ -> TupRsingle tR -primConstType :: PrimConst a -> SingleType a +primConstType :: PrimConst a -> ScalarType a primConstType = \case PrimMinBound t -> bounded t PrimMaxBound t -> bounded t PrimPi t -> floating t + PrimVectorCreate t -> vector t where - bounded :: BoundedType a -> SingleType a - bounded (IntegralBoundedType t) = NumSingleType $ IntegralNumType t + bounded :: BoundedType a -> ScalarType a + bounded (IntegralBoundedType t) = SingleScalarType $ NumSingleType $ IntegralNumType t + + floating :: FloatingType t -> ScalarType t + floating = SingleScalarType . NumSingleType . FloatingNumType - floating :: FloatingType t -> SingleType t - floating = NumSingleType . FloatingNumType + vector :: forall n a. (KnownNat n) => VectorType (Vec n a) -> ScalarType (Vec n a) + vector = VectorScalarType primFunType :: PrimFun (a -> b) -> (TypeR a, TypeR b) primFunType = \case @@ -1110,9 +1117,10 @@ rnfConst (TupRsingle t) !_ = rnfScalarType t -- scalars should have (nf = rnfConst (TupRpair ta tb) (a,b) = rnfConst ta a `seq` rnfConst tb b rnfPrimConst :: PrimConst c -> () -rnfPrimConst (PrimMinBound t) = rnfBoundedType t -rnfPrimConst (PrimMaxBound t) = rnfBoundedType t -rnfPrimConst (PrimPi t) = rnfFloatingType t +rnfPrimConst (PrimMinBound t) = rnfBoundedType t +rnfPrimConst (PrimMaxBound t) = rnfBoundedType t +rnfPrimConst (PrimPi t) = rnfFloatingType t +rnfPrimConst (PrimVectorCreate t) = rnfVectorType t rnfPrimFun :: PrimFun f -> () rnfPrimFun (PrimAdd t) = rnfNumType t @@ -1337,9 +1345,10 @@ liftBoundary (ArrayR _ tp) (Constant v) = [|| Constant $$(liftElt tp v) ||] liftBoundary _ (Function f) = [|| Function $$(liftOpenFun f) ||] liftPrimConst :: PrimConst c -> CodeQ (PrimConst c) -liftPrimConst (PrimMinBound t) = [|| PrimMinBound $$(liftBoundedType t) ||] -liftPrimConst (PrimMaxBound t) = [|| PrimMaxBound $$(liftBoundedType t) ||] -liftPrimConst (PrimPi t) = [|| PrimPi $$(liftFloatingType t) ||] +liftPrimConst (PrimMinBound t) = [|| PrimMinBound $$(liftBoundedType t) ||] +liftPrimConst (PrimMaxBound t) = [|| PrimMaxBound $$(liftBoundedType t) ||] +liftPrimConst (PrimPi t) = [|| PrimPi $$(liftFloatingType t) ||] +liftPrimConst (PrimVectorCreate t) = [|| PrimVectorCreate $$(liftVectorType t) ||] liftPrimFun :: PrimFun f -> CodeQ (PrimFun f) liftPrimFun (PrimAdd t) = [|| PrimAdd $$(liftNumType t) ||] @@ -1402,7 +1411,7 @@ liftPrimFun (PrimMin t) = [|| PrimMin $$(liftSingleType t) ||] liftPrimFun PrimLAnd = [|| PrimLAnd ||] liftPrimFun PrimLOr = [|| PrimLOr ||] liftPrimFun PrimLNot = [|| PrimLNot ||] -liftPrimFun (PrimVectorIndex v i) = [||PrimVectorIndex $$(liftVectorType v) $$(liftIntegralType i) ||] +liftPrimFun (PrimVectorIndex v i) = [|| PrimVectorIndex $$(liftVectorType v) $$(liftIntegralType i) ||] liftPrimFun (PrimFromIntegral ta tb) = [|| PrimFromIntegral $$(liftIntegralType ta) $$(liftNumType tb) ||] liftPrimFun (PrimToFloating ta tb) = [|| PrimToFloating $$(liftNumType ta) $$(liftFloatingType tb) ||] diff --git a/src/Data/Array/Accelerate/Classes/Vector.hs b/src/Data/Array/Accelerate/Classes/Vector.hs index 69f62e7eb..32a618761 100644 --- a/src/Data/Array/Accelerate/Classes/Vector.hs +++ b/src/Data/Array/Accelerate/Classes/Vector.hs @@ -1,8 +1,10 @@ -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE MonoLocalBinds #-} -{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MonoLocalBinds #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE GADTs #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | @@ -16,16 +18,21 @@ -- module Data.Array.Accelerate.Classes.Vector where +import Data.Kind import GHC.TypeLits -import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Smart import Data.Primitive.Vec -class Vectoring a b c | a -> b where - indexAt :: a -> c -> b +class Vectoring vector a | vector -> a where + type IndexType vector :: Type + vecIndex :: vector -> IndexType vector -> a + vecEmpty :: vector -instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) (Exp Int) where - indexAt = mkVectorIndex + +instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) where + type IndexType (Exp (Vec n a)) = Exp Int + vecIndex = mkVectorIndex + vecEmpty = mkVectorCreate diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 14c043d1f..ab6650300 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -73,6 +73,7 @@ module Data.Array.Accelerate.Smart ( mkFromIntegral, mkToFloating, mkBitcast, mkCoerce, Coerce(..), -- ** Smart constructors for vector operations + mkVectorCreate, mkVectorIndex, -- ** Auxiliary functions @@ -865,7 +866,7 @@ instance HasTypeR exp => HasTypeR (PreSmartExp acc exp) where Case{} -> internalError "encountered empty case" Cond _ e _ -> typeR e While t _ _ _ -> t - PrimConst c -> TupRsingle $ SingleScalarType $ primConstType c + PrimConst c -> TupRsingle $ primConstType c PrimApp f _ -> snd $ primFunType f Index tp _ _ -> tp LinearIndex tp _ _ -> tp @@ -1179,6 +1180,11 @@ mkLNot (Exp a) = mkExp $ SmartExp (PrimApp PrimLNot x) `Pair` SmartExp Nil x = SmartExp $ Prj PairIdxLeft a -- Operators from Vec +mkVectorCreate :: forall n a. (KnownNat n, Elt a, VecElt a) => Exp (Vec n a) +mkVectorCreate = let n :: Int + n = fromIntegral $ natVal $ Proxy @n + in mkExp $ PrimConst $ PrimVectorCreate $ VectorType n singleType + mkVectorIndex :: forall n a. (KnownNat n, Elt a, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a mkVectorIndex = let n :: Int n = fromIntegral $ natVal $ Proxy @n diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 93b0395c0..34b22ef13 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -32,8 +32,6 @@ module Data.Primitive.Vec ( Vec8, pattern Vec8, Vec16, pattern Vec16, - mkVec, - listOfVec, liftVec, From 2c90dd5300dc19fc23964e254b99ea6f002b56d3 Mon Sep 17 00:00:00 2001 From: Hugo <hpeters1024@gmail.com> Date: Tue, 26 Oct 2021 21:33:51 +0200 Subject: [PATCH 03/86] Add implementation of empty vector and indexing --- src/Data/Array/Accelerate.hs | 2 ++ src/Data/Array/Accelerate/AST.hs | 5 +-- src/Data/Array/Accelerate/Classes/Vector.hs | 6 ---- src/Data/Array/Accelerate/Interpreter.hs | 9 ++++++ src/Data/Array/Accelerate/Smart.hs | 2 +- src/Data/Array/Accelerate/Trafo/Algebra.hs | 3 ++ src/Data/Primitive/Vec.hs | 34 ++++++++++++++++++--- 7 files changed, 48 insertions(+), 13 deletions(-) diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index 8811695b8..e2543c6ae 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -311,6 +311,8 @@ module Data.Array.Accelerate ( -- ** SIMD vectors Vec, VecElt, Vectoring(..), + vecOfList, + listOfVec, -- ** Type classes -- *** Basic type classes diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index a07920466..066704093 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -149,6 +149,7 @@ import Data.Array.Accelerate.Sugar.Foreign import Data.Array.Accelerate.Type import Data.Primitive.Vec +import Data.Primitive.Types import Control.DeepSeq import Data.Kind import Data.Maybe @@ -656,7 +657,7 @@ data PrimConst ty where PrimPi :: FloatingType a -> PrimConst a -- constant for empty Vec - PrimVectorCreate :: KnownNat n => VectorType (Vec n a) -> PrimConst (Vec n a) + PrimVectorCreate :: (KnownNat n, Prim a) => VectorType (Vec n a) -> PrimConst (Vec n a) -- |Primitive scalar operations @@ -752,7 +753,7 @@ data PrimFun sig where PrimLNot :: PrimFun (PrimBool -> PrimBool) -- local array operators - PrimVectorIndex :: KnownNat n => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, i) -> a) + PrimVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, i) -> a) -- general conversion between types PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b) diff --git a/src/Data/Array/Accelerate/Classes/Vector.hs b/src/Data/Array/Accelerate/Classes/Vector.hs index 32a618761..0ab3c4942 100644 --- a/src/Data/Array/Accelerate/Classes/Vector.hs +++ b/src/Data/Array/Accelerate/Classes/Vector.hs @@ -24,12 +24,6 @@ import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Smart import Data.Primitive.Vec -class Vectoring vector a | vector -> a where - type IndexType vector :: Type - vecIndex :: vector -> IndexType vector -> a - vecEmpty :: vector - - instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) where type IndexType (Exp (Vec n a)) = Exp Int vecIndex = mkVectorIndex diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index 5b8e6401a..344a8691d 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -69,6 +69,7 @@ import qualified Data.Array.Accelerate.Sugar.Array as Sugar import qualified Data.Array.Accelerate.Sugar.Elt as Sugar import qualified Data.Array.Accelerate.Trafo.Delayed as AST +import GHC.TypeLits import Control.DeepSeq import Control.Exception import Control.Monad @@ -1082,6 +1083,7 @@ evalPrimConst :: PrimConst a -> a evalPrimConst (PrimMinBound ty) = evalMinBound ty evalPrimConst (PrimMaxBound ty) = evalMaxBound ty evalPrimConst (PrimPi ty) = evalPi ty +evalPrimConst (PrimVectorCreate ty) = evalVectorCreate ty evalPrim :: PrimFun (a -> r) -> (a -> r) evalPrim (PrimAdd ty) = evalAdd ty @@ -1144,6 +1146,7 @@ evalPrim (PrimMin ty) = evalMin ty evalPrim PrimLAnd = evalLAnd evalPrim PrimLOr = evalLOr evalPrim PrimLNot = evalLNot +evalPrim (PrimVectorIndex v i) = evalVectorIndex v i evalPrim (PrimFromIntegral ta tb) = evalFromIntegral ta tb evalPrim (PrimToFloating ta tb) = evalToFloating ta tb @@ -1168,6 +1171,9 @@ evalLOr (x, y) = fromBool (toBool x || toBool y) evalLNot :: PrimBool -> PrimBool evalLNot = fromBool . not . toBool +evalVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> (Vec n a, i) -> a +evalVectorIndex (VectorType n _) ti (v, i) | IntegralDict <- integralDict ti = vecIndex v (fromIntegral i) + evalFromIntegral :: IntegralType a -> NumType b -> a -> b evalFromIntegral ta (IntegralNumType tb) | IntegralDict <- integralDict ta @@ -1213,6 +1219,9 @@ evalMaxBound (IntegralBoundedType ty) evalPi :: FloatingType a -> a evalPi ty | FloatingDict <- floatingDict ty = pi +evalVectorCreate :: (KnownNat n, Prim a) => VectorType (Vec n a) -> Vec n a +evalVectorCreate (VectorType n _) = vecEmpty + evalSin :: FloatingType a -> (a -> a) evalSin ty | FloatingDict <- floatingDict ty = sin diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index ab6650300..7693ebf45 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -1,5 +1,5 @@ {-# LANGUAGE AllowAmbiguousTypes #-} -{-# LANGUAGE CPP #-} + {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} diff --git a/src/Data/Array/Accelerate/Trafo/Algebra.hs b/src/Data/Array/Accelerate/Trafo/Algebra.hs index 9cfea36ae..1e620435b 100644 --- a/src/Data/Array/Accelerate/Trafo/Algebra.hs +++ b/src/Data/Array/Accelerate/Trafo/Algebra.hs @@ -33,12 +33,14 @@ import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Pretty.Print ( primOperator, isInfix, opName ) import Data.Array.Accelerate.Trafo.Environment import Data.Array.Accelerate.Type +import Data.Array.Accelerate.Classes.Vector import qualified Data.Array.Accelerate.Debug.Internal.Stats as Stats import Data.Bits import Data.Monoid import Data.Text ( Text ) +import Data.Primitive.Vec import Data.Text.Prettyprint.Doc import Data.Text.Prettyprint.Doc.Render.Text import GHC.Float ( float2Double, double2Float ) @@ -142,6 +144,7 @@ evalPrimApp env f x PrimNEq ty -> evalNEq ty x env PrimMax ty -> evalMax ty x env PrimMin ty -> evalMin ty x env + PrimVectorIndex _ _ -> Nothing PrimLAnd -> evalLAnd x env PrimLOr -> evalLOr x env PrimLNot -> evalLNot x env diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 34b22ef13..10930d4e4 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -5,12 +5,16 @@ {-# LANGUAGE MagicHash #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE RoleAnnotations #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE FlexibleInstances #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Primitive.Vec @@ -33,10 +37,13 @@ module Data.Primitive.Vec ( Vec16, pattern Vec16, listOfVec, + vecOfList, liftVec, + Vectoring(..) ) where +import Data.Kind import Data.Proxy import Control.Monad.ST import Control.Monad.Reader @@ -86,14 +93,25 @@ import GHC.Word -- data Vec (n :: Nat) a = Vec ByteArray# -mkVec :: forall n a. (KnownNat n, Prim a) => [a] -> Vec n a -mkVec vs = runST $ do +class Vectoring vector a | vector -> a where + type IndexType vector :: Data.Kind.Type + vecIndex :: vector -> IndexType vector -> a + vecEmpty :: vector + +instance (KnownNat n, Prim a) => Vectoring (Vec n a) a where + type IndexType (Vec n a) = Int + vecIndex (Vec ba#) (I# i#) = indexByteArray# ba# i# + vecEmpty = mkVec + + +mkVec :: forall n a. (KnownNat n, Prim a) => Vec n a +mkVec = runST $ do let n :: Int = fromIntegral $ natVal $ Proxy @n mba <- newByteArray (n * sizeOf (undefined :: a)) - zipWithM_ (writeByteArray mba) [0..n] vs ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# + type role Vec nominal representational instance (Show a, Prim a, KnownNat n) => Show (Vec n a) where @@ -104,6 +122,14 @@ instance (Show a, Prim a, KnownNat n) => Show (Vec n a) where . group . encloseSep (flatAlt "< " "<") (flatAlt " >" ">") ", " . map viaShow +vecOfList :: forall n a. (KnownNat n, Prim a) => [a] -> Vec n a +vecOfList vs = runST $ do + let n :: Int = fromIntegral $ natVal $ Proxy @n + mba <- newByteArray (n * sizeOf (undefined :: a)) + zipWithM_ (writeByteArray mba) [0..n] vs + ByteArray ba# <- unsafeFreezeByteArray mba + return $! Vec ba# + listOfVec :: forall a n. (Prim a, KnownNat n) => Vec n a -> [a] listOfVec (Vec ba#) = go 0# where From 74feaecf673bc615e8464f3a68f102ff721c3f8d Mon Sep 17 00:00:00 2001 From: Hugo <hpeters1024@gmail.com> Date: Wed, 27 Oct 2021 14:22:34 +0200 Subject: [PATCH 04/86] Add bounds check on vector index --- src/Data/Primitive/Vec.hs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 10930d4e4..3fa13bf09 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -100,7 +100,10 @@ class Vectoring vector a | vector -> a where instance (KnownNat n, Prim a) => Vectoring (Vec n a) a where type IndexType (Vec n a) = Int - vecIndex (Vec ba#) (I# i#) = indexByteArray# ba# i# + vecIndex (Vec ba#) i@(I# iu#) = let + n :: Int + n = fromIntegral $ natVal $ Proxy @n + in if i >= 0 && i < n then indexByteArray# ba# iu# else error ("index " <> show i <> " out of range in Vec of size " <> show n) vecEmpty = mkVec From 977669bd9c286a446259036689bb00dc5bc28e59 Mon Sep 17 00:00:00 2001 From: Hugo <hpeters1024@gmail.com> Date: Thu, 28 Oct 2021 20:55:15 +0200 Subject: [PATCH 05/86] Fix vector creation (todo delete the prim const) --- src/Data/Array/Accelerate/Analysis/Hash.hs | 2 ++ src/Data/Array/Accelerate/Classes/Vector.hs | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Data/Array/Accelerate/Analysis/Hash.hs b/src/Data/Array/Accelerate/Analysis/Hash.hs index 75625b9ec..8587742cc 100644 --- a/src/Data/Array/Accelerate/Analysis/Hash.hs +++ b/src/Data/Array/Accelerate/Analysis/Hash.hs @@ -389,6 +389,7 @@ encodePrimConst :: PrimConst c -> Builder encodePrimConst (PrimMinBound t) = intHost $(hashQ "PrimMinBound") <> encodeBoundedType t encodePrimConst (PrimMaxBound t) = intHost $(hashQ "PrimMaxBound") <> encodeBoundedType t encodePrimConst (PrimPi t) = intHost $(hashQ "PrimPi") <> encodeFloatingType t +encodePrimConst (PrimVectorCreate t) = intHost $(hashQ "PrimVectorCreate") <> encodeVectorType t encodePrimFun :: PrimFun f -> Builder encodePrimFun (PrimAdd a) = intHost $(hashQ "PrimAdd") <> encodeNumType a @@ -448,6 +449,7 @@ encodePrimFun (PrimEq a) = intHost $(hashQ "PrimEq") encodePrimFun (PrimNEq a) = intHost $(hashQ "PrimNEq") <> encodeSingleType a encodePrimFun (PrimMax a) = intHost $(hashQ "PrimMax") <> encodeSingleType a encodePrimFun (PrimMin a) = intHost $(hashQ "PrimMin") <> encodeSingleType a +encodePrimFun (PrimVectorIndex (VectorType _ a) b) = intHost $(hashQ "PrimVectorIndex") <> encodeSingleType a <> encodeNumType (IntegralNumType b) encodePrimFun (PrimFromIntegral a b) = intHost $(hashQ "PrimFromIntegral") <> encodeIntegralType a <> encodeNumType b encodePrimFun (PrimToFloating a b) = intHost $(hashQ "PrimToFloating") <> encodeNumType a <> encodeFloatingType b encodePrimFun PrimLAnd = intHost $(hashQ "PrimLAnd") diff --git a/src/Data/Array/Accelerate/Classes/Vector.hs b/src/Data/Array/Accelerate/Classes/Vector.hs index 0ab3c4942..1eef95abf 100644 --- a/src/Data/Array/Accelerate/Classes/Vector.hs +++ b/src/Data/Array/Accelerate/Classes/Vector.hs @@ -27,6 +27,6 @@ import Data.Primitive.Vec instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) where type IndexType (Exp (Vec n a)) = Exp Int vecIndex = mkVectorIndex - vecEmpty = mkVectorCreate + vecEmpty = undef From dc7d849f1256ea88cccb5437bf86d08b10f8fb79 Mon Sep 17 00:00:00 2001 From: Hugo <hpeters1024@gmail.com> Date: Tue, 2 Nov 2021 12:24:45 +0100 Subject: [PATCH 06/86] implement interpreter and fix bugs --- src/Data/Array/Accelerate/AST.hs | 1 + src/Data/Array/Accelerate/Analysis/Hash.hs | 1 + src/Data/Array/Accelerate/Classes/Vector.hs | 1 + src/Data/Array/Accelerate/Interpreter.hs | 4 ++++ src/Data/Array/Accelerate/Smart.hs | 9 +++++++++ src/Data/Array/Accelerate/Trafo/Algebra.hs | 1 + src/Data/Primitive/Vec.hs | 10 ++++++++++ 7 files changed, 27 insertions(+) diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 066704093..3952c9c60 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -754,6 +754,7 @@ data PrimFun sig where -- local array operators PrimVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, i) -> a) + PrimVectorWrite :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, (i, a)) -> Vec n a) -- general conversion between types PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b) diff --git a/src/Data/Array/Accelerate/Analysis/Hash.hs b/src/Data/Array/Accelerate/Analysis/Hash.hs index 8587742cc..f7b22e47f 100644 --- a/src/Data/Array/Accelerate/Analysis/Hash.hs +++ b/src/Data/Array/Accelerate/Analysis/Hash.hs @@ -450,6 +450,7 @@ encodePrimFun (PrimNEq a) = intHost $(hashQ "PrimNEq") encodePrimFun (PrimMax a) = intHost $(hashQ "PrimMax") <> encodeSingleType a encodePrimFun (PrimMin a) = intHost $(hashQ "PrimMin") <> encodeSingleType a encodePrimFun (PrimVectorIndex (VectorType _ a) b) = intHost $(hashQ "PrimVectorIndex") <> encodeSingleType a <> encodeNumType (IntegralNumType b) +encodePrimFun (PrimVectorWrite (VectorType _ a) b) = intHost $(hashQ "PrimVectorWrite") <> encodeSingleType a <> encodeNumType (IntegralNumType b) encodePrimFun (PrimFromIntegral a b) = intHost $(hashQ "PrimFromIntegral") <> encodeIntegralType a <> encodeNumType b encodePrimFun (PrimToFloating a b) = intHost $(hashQ "PrimToFloating") <> encodeNumType a <> encodeFloatingType b encodePrimFun PrimLAnd = intHost $(hashQ "PrimLAnd") diff --git a/src/Data/Array/Accelerate/Classes/Vector.hs b/src/Data/Array/Accelerate/Classes/Vector.hs index 1eef95abf..87586985d 100644 --- a/src/Data/Array/Accelerate/Classes/Vector.hs +++ b/src/Data/Array/Accelerate/Classes/Vector.hs @@ -27,6 +27,7 @@ import Data.Primitive.Vec instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) where type IndexType (Exp (Vec n a)) = Exp Int vecIndex = mkVectorIndex + vecWrite = mkVectorWrite vecEmpty = undef diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index 344a8691d..06f184348 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -1147,6 +1147,7 @@ evalPrim PrimLAnd = evalLAnd evalPrim PrimLOr = evalLOr evalPrim PrimLNot = evalLNot evalPrim (PrimVectorIndex v i) = evalVectorIndex v i +evalPrim (PrimVectorWrite v i) = evalVectorWrite v i evalPrim (PrimFromIntegral ta tb) = evalFromIntegral ta tb evalPrim (PrimToFloating ta tb) = evalToFloating ta tb @@ -1174,6 +1175,9 @@ evalLNot = fromBool . not . toBool evalVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> (Vec n a, i) -> a evalVectorIndex (VectorType n _) ti (v, i) | IntegralDict <- integralDict ti = vecIndex v (fromIntegral i) +evalVectorWrite :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> (Vec n a, (i, a)) -> Vec n a +evalVectorWrite (VectorType n _) ti (v, (i, a)) | IntegralDict <- integralDict ti = vecWrite v (fromIntegral i) a + evalFromIntegral :: IntegralType a -> NumType b -> a -> b evalFromIntegral ta (IntegralNumType tb) | IntegralDict <- integralDict ta diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 7693ebf45..4da5568ad 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -75,6 +75,7 @@ module Data.Array.Accelerate.Smart ( -- ** Smart constructors for vector operations mkVectorCreate, mkVectorIndex, + mkVectorWrite, -- ** Auxiliary functions ($$), ($$$), ($$$$), ($$$$$), @@ -1190,6 +1191,11 @@ mkVectorIndex = let n :: Int n = fromIntegral $ natVal $ Proxy @n in mkPrimBinary $ PrimVectorIndex @n (VectorType n singleType) integralType +mkVectorWrite :: forall n a. (KnownNat n, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a -> Exp (Vec n a) +mkVectorWrite = let n :: Int + n = fromIntegral $ natVal $ Proxy @n + in mkPrimTernary $ PrimVectorWrite @n (VectorType n singleType) integralType + -- Numeric conversions mkFromIntegral :: (Elt a, Elt b, IsIntegral (EltR a), IsNum (EltR b)) => Exp a -> Exp b @@ -1277,6 +1283,9 @@ mkPrimUnary prim (Exp a) = mkExp $ PrimApp prim a mkPrimBinary :: (Elt a, Elt b, Elt c) => PrimFun ((EltR a, EltR b) -> EltR c) -> Exp a -> Exp b -> Exp c mkPrimBinary prim (Exp a) (Exp b) = mkExp $ PrimApp prim (SmartExp $ Pair a b) +mkPrimTernary :: (Elt a, Elt b, Elt c, Elt d) => PrimFun ((EltR a, (EltR b, EltR c)) -> EltR d) -> Exp a -> Exp b -> Exp c -> Exp d +mkPrimTernary prim (Exp a) (Exp b) (Exp c) = mkExp $ PrimApp prim (SmartExp $ Pair a (SmartExp (Pair b c))) + mkPrimUnaryBool :: Elt a => PrimFun (EltR a -> PrimBool) -> Exp a -> Exp Bool mkPrimUnaryBool = mkCoerce @PrimBool $$ mkPrimUnary diff --git a/src/Data/Array/Accelerate/Trafo/Algebra.hs b/src/Data/Array/Accelerate/Trafo/Algebra.hs index 1e620435b..d8a655b06 100644 --- a/src/Data/Array/Accelerate/Trafo/Algebra.hs +++ b/src/Data/Array/Accelerate/Trafo/Algebra.hs @@ -145,6 +145,7 @@ evalPrimApp env f x PrimMax ty -> evalMax ty x env PrimMin ty -> evalMin ty x env PrimVectorIndex _ _ -> Nothing + PrimVectorWrite _ _ -> Nothing PrimLAnd -> evalLAnd x env PrimLOr -> evalLOr x env PrimLNot -> evalLNot x env diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 3fa13bf09..36c4f9570 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -15,6 +15,7 @@ {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TupleSections #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Primitive.Vec @@ -96,6 +97,7 @@ data Vec (n :: Nat) a = Vec ByteArray# class Vectoring vector a | vector -> a where type IndexType vector :: Data.Kind.Type vecIndex :: vector -> IndexType vector -> a + vecWrite :: vector -> IndexType vector -> a -> vector vecEmpty :: vector instance (KnownNat n, Prim a) => Vectoring (Vec n a) a where @@ -104,6 +106,14 @@ instance (KnownNat n, Prim a) => Vectoring (Vec n a) a where n :: Int n = fromIntegral $ natVal $ Proxy @n in if i >= 0 && i < n then indexByteArray# ba# iu# else error ("index " <> show i <> " out of range in Vec of size " <> show n) + vecWrite vec@(Vec ba#) i@(I# iu#) v = runST $ do + let n :: Int + n = fromIntegral $ natVal $ Proxy @n + mba <- newByteArray (n * sizeOf (undefined :: a)) + let new_vs = zipWith (\i' v' -> if i' == i then v else v') [0..n] (listOfVec vec) + zipWithM_ (writeByteArray mba) [0..n] new_vs + ByteArray nba# <- unsafeFreezeByteArray mba + return $! Vec nba# vecEmpty = mkVec From 3fe1e808ebe1ae8bbe17dc5203b82d812f8c026c Mon Sep 17 00:00:00 2001 From: Hugo <hpeters1024@gmail.com> Date: Wed, 3 Nov 2021 16:01:04 +0100 Subject: [PATCH 07/86] Remove vector create constant --- src/Data/Array/Accelerate/AST.hs | 7 ------- src/Data/Array/Accelerate/Analysis/Hash.hs | 1 - src/Data/Array/Accelerate/Interpreter.hs | 1 - src/Data/Array/Accelerate/Smart.hs | 6 ------ 4 files changed, 15 deletions(-) diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 3952c9c60..6b0f83d24 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -656,10 +656,6 @@ data PrimConst ty where -- constant from Floating PrimPi :: FloatingType a -> PrimConst a - -- constant for empty Vec - PrimVectorCreate :: (KnownNat n, Prim a) => VectorType (Vec n a) -> PrimConst (Vec n a) - - -- |Primitive scalar operations -- data PrimFun sig where @@ -847,7 +843,6 @@ primConstType = \case PrimMinBound t -> bounded t PrimMaxBound t -> bounded t PrimPi t -> floating t - PrimVectorCreate t -> vector t where bounded :: BoundedType a -> ScalarType a bounded (IntegralBoundedType t) = SingleScalarType $ NumSingleType $ IntegralNumType t @@ -1122,7 +1117,6 @@ rnfPrimConst :: PrimConst c -> () rnfPrimConst (PrimMinBound t) = rnfBoundedType t rnfPrimConst (PrimMaxBound t) = rnfBoundedType t rnfPrimConst (PrimPi t) = rnfFloatingType t -rnfPrimConst (PrimVectorCreate t) = rnfVectorType t rnfPrimFun :: PrimFun f -> () rnfPrimFun (PrimAdd t) = rnfNumType t @@ -1350,7 +1344,6 @@ liftPrimConst :: PrimConst c -> CodeQ (PrimConst c) liftPrimConst (PrimMinBound t) = [|| PrimMinBound $$(liftBoundedType t) ||] liftPrimConst (PrimMaxBound t) = [|| PrimMaxBound $$(liftBoundedType t) ||] liftPrimConst (PrimPi t) = [|| PrimPi $$(liftFloatingType t) ||] -liftPrimConst (PrimVectorCreate t) = [|| PrimVectorCreate $$(liftVectorType t) ||] liftPrimFun :: PrimFun f -> CodeQ (PrimFun f) liftPrimFun (PrimAdd t) = [|| PrimAdd $$(liftNumType t) ||] diff --git a/src/Data/Array/Accelerate/Analysis/Hash.hs b/src/Data/Array/Accelerate/Analysis/Hash.hs index f7b22e47f..2b399aa46 100644 --- a/src/Data/Array/Accelerate/Analysis/Hash.hs +++ b/src/Data/Array/Accelerate/Analysis/Hash.hs @@ -389,7 +389,6 @@ encodePrimConst :: PrimConst c -> Builder encodePrimConst (PrimMinBound t) = intHost $(hashQ "PrimMinBound") <> encodeBoundedType t encodePrimConst (PrimMaxBound t) = intHost $(hashQ "PrimMaxBound") <> encodeBoundedType t encodePrimConst (PrimPi t) = intHost $(hashQ "PrimPi") <> encodeFloatingType t -encodePrimConst (PrimVectorCreate t) = intHost $(hashQ "PrimVectorCreate") <> encodeVectorType t encodePrimFun :: PrimFun f -> Builder encodePrimFun (PrimAdd a) = intHost $(hashQ "PrimAdd") <> encodeNumType a diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index 06f184348..c304051ed 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -1083,7 +1083,6 @@ evalPrimConst :: PrimConst a -> a evalPrimConst (PrimMinBound ty) = evalMinBound ty evalPrimConst (PrimMaxBound ty) = evalMaxBound ty evalPrimConst (PrimPi ty) = evalPi ty -evalPrimConst (PrimVectorCreate ty) = evalVectorCreate ty evalPrim :: PrimFun (a -> r) -> (a -> r) evalPrim (PrimAdd ty) = evalAdd ty diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 4da5568ad..30981c660 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -73,7 +73,6 @@ module Data.Array.Accelerate.Smart ( mkFromIntegral, mkToFloating, mkBitcast, mkCoerce, Coerce(..), -- ** Smart constructors for vector operations - mkVectorCreate, mkVectorIndex, mkVectorWrite, @@ -1181,11 +1180,6 @@ mkLNot (Exp a) = mkExp $ SmartExp (PrimApp PrimLNot x) `Pair` SmartExp Nil x = SmartExp $ Prj PairIdxLeft a -- Operators from Vec -mkVectorCreate :: forall n a. (KnownNat n, Elt a, VecElt a) => Exp (Vec n a) -mkVectorCreate = let n :: Int - n = fromIntegral $ natVal $ Proxy @n - in mkExp $ PrimConst $ PrimVectorCreate $ VectorType n singleType - mkVectorIndex :: forall n a. (KnownNat n, Elt a, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a mkVectorIndex = let n :: Int n = fromIntegral $ natVal $ Proxy @n From faa139bcdb58fd3f457556508153c4bb2438989d Mon Sep 17 00:00:00 2001 From: Hugo <hpeters1024@gmail.com> Date: Thu, 4 Nov 2021 20:29:50 +0100 Subject: [PATCH 08/86] add missing pattern match and module in cabal file --- accelerate.cabal | 1 + src/Data/Array/Accelerate/AST.hs | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/accelerate.cabal b/accelerate.cabal index 0b95607e4..2e64e1e1f 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -402,6 +402,7 @@ library Data.Array.Accelerate.Classes.RealFloat Data.Array.Accelerate.Classes.RealFrac Data.Array.Accelerate.Classes.ToFloating + Data.Array.Accelerate.Classes.Vector Data.Array.Accelerate.Debug.Internal.Clock Data.Array.Accelerate.Debug.Internal.Flags Data.Array.Accelerate.Debug.Internal.Graph diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 6b0f83d24..242d015af 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -937,6 +937,11 @@ primFunType = \case i = integral i' in (v `TupRpair` i, single a) + PrimVectorWrite v'@(VectorType _ a) i' -> + let v = singleVector v' + i = integral i' + in (v `TupRpair` (i `TupRpair` single a), v) + -- general conversion between types PrimFromIntegral a b -> unary (integral a) (num b) PrimToFloating a b -> unary (num a) (floating b) From 0e250b8a05494a6f7aff4561add31c62d5321d38 Mon Sep 17 00:00:00 2001 From: Hugo <hpeters1024@gmail.com> Date: Thu, 2 Dec 2021 16:04:12 +0100 Subject: [PATCH 09/86] Move vec operations to correct AST --- src/Data/Array/Accelerate/AST.hs | 47 ++++----- src/Data/Array/Accelerate/Analysis/Hash.hs | 4 +- src/Data/Array/Accelerate/Classes/Vector.hs | 5 +- src/Data/Array/Accelerate/Interpreter.hs | 2 - .../Array/Accelerate/Representation/Vec.hs | 4 + src/Data/Array/Accelerate/Smart.hs | 30 ++++-- src/Data/Array/Accelerate/Trafo/Algebra.hs | 2 - src/Data/Array/Accelerate/Trafo/Sharing.hs | 68 +++++++------ src/Data/Array/Accelerate/Trafo/Shrink.hs | 6 ++ src/Data/Array/Accelerate/Trafo/Simplify.hs | 4 + .../Array/Accelerate/Trafo/Substitution.hs | 96 ++++++++++--------- 11 files changed, 155 insertions(+), 113 deletions(-) diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 242d015af..31b2512ad 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -9,6 +9,8 @@ {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE AllowAmbiguousTypes #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.AST @@ -149,7 +151,6 @@ import Data.Array.Accelerate.Sugar.Foreign import Data.Array.Accelerate.Type import Data.Primitive.Vec -import Data.Primitive.Types import Control.DeepSeq import Data.Kind import Data.Maybe @@ -560,6 +561,21 @@ data OpenExp env aenv t where -> OpenExp env aenv (Vec n s) -> OpenExp env aenv tup + VecIndex :: (KnownNat n, v ~ Vec n s) + => VectorType v + -> IntegralType i + -> OpenExp env aenv (Vec n s) + -> OpenExp env aenv i + -> OpenExp env aenv s + + VecWrite :: (KnownNat n, v ~ Vec n s) + => VectorType v + -> IntegralType i + -> OpenExp env aenv (Vec n s) + -> OpenExp env aenv i + -> OpenExp env aenv s + -> OpenExp env aenv (Vec n s) + -- Array indices & shapes IndexSlice :: SliceIndex slix sl co sh -> OpenExp env aenv slix @@ -748,10 +764,6 @@ data PrimFun sig where PrimLOr :: PrimFun ((PrimBool, PrimBool) -> PrimBool) PrimLNot :: PrimFun (PrimBool -> PrimBool) - -- local array operators - PrimVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, i) -> a) - PrimVectorWrite :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, (i, a)) -> Vec n a) - -- general conversion between types PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b) PrimToFloating :: NumType a -> FloatingType b -> PrimFun (a -> b) @@ -818,6 +830,8 @@ expType = \case Nil -> TupRunit VecPack vecR _ -> TupRsingle $ VectorScalarType $ vecRvector vecR VecUnpack vecR _ -> vecRtuple vecR + VecIndex vecT _ _ _ -> let (VectorType _ s) = vecT in TupRsingle $ SingleScalarType s + VecWrite vecT _ _ _ _ -> TupRsingle $ VectorScalarType vecT IndexSlice si _ _ -> shapeType $ sliceShapeR si IndexFull si _ _ -> shapeType $ sliceDomainR si ToIndex{} -> TupRsingle scalarTypeInt @@ -850,9 +864,6 @@ primConstType = \case floating :: FloatingType t -> ScalarType t floating = SingleScalarType . NumSingleType . FloatingNumType - vector :: forall n a. (KnownNat n) => VectorType (Vec n a) -> ScalarType (Vec n a) - vector = VectorScalarType - primFunType :: PrimFun (a -> b) -> (TypeR a, TypeR b) primFunType = \case -- Num @@ -931,17 +942,6 @@ primFunType = \case PrimLOr -> binary' tbool PrimLNot -> unary' tbool --- Local Vector operations - PrimVectorIndex v'@(VectorType _ a) i' -> - let v = singleVector v' - i = integral i' - in (v `TupRpair` i, single a) - - PrimVectorWrite v'@(VectorType _ a) i' -> - let v = singleVector v' - i = integral i' - in (v `TupRpair` (i `TupRpair` single a), v) - -- general conversion between types PrimFromIntegral a b -> unary (integral a) (num b) PrimToFloating a b -> unary (num a) (floating b) @@ -954,7 +954,6 @@ primFunType = \case compare' a = binary (single a) tbool single = TupRsingle . SingleScalarType - singleVector = TupRsingle . VectorScalarType num = TupRsingle . SingleScalarType . NumSingleType integral = num . IntegralNumType floating = num . FloatingNumType @@ -1092,6 +1091,8 @@ rnfOpenExp topExp = Nil -> () VecPack vecr e -> rnfVecR vecr `seq` rnfE e VecUnpack vecr e -> rnfVecR vecr `seq` rnfE e + VecIndex vt it v i -> rnfVectorType vt `seq` rnfIntegralType it `seq` rnfE v `seq` rnfE i + VecWrite vt it v i e -> rnfVectorType vt `seq` rnfIntegralType it `seq` rnfE v `seq` rnfE i `seq` rnfE e IndexSlice slice slix sh -> rnfSliceIndex slice `seq` rnfE slix `seq` rnfE sh IndexFull slice slix sl -> rnfSliceIndex slice `seq` rnfE slix `seq` rnfE sl ToIndex shr sh ix -> rnfShapeR shr `seq` rnfE sh `seq` rnfE ix @@ -1184,7 +1185,6 @@ rnfPrimFun (PrimMin t) = rnfSingleType t rnfPrimFun PrimLAnd = () rnfPrimFun PrimLOr = () rnfPrimFun PrimLNot = () -rnfPrimFun (PrimVectorIndex v i) = rnfVectorType v `seq` rnfIntegralType i rnfPrimFun (PrimFromIntegral i n) = rnfIntegralType i `seq` rnfNumType n rnfPrimFun (PrimToFloating n f) = rnfNumType n `seq` rnfFloatingType f @@ -1313,6 +1313,8 @@ liftOpenExp pexp = Nil -> [|| Nil ||] VecPack vecr e -> [|| VecPack $$(liftVecR vecr) $$(liftE e) ||] VecUnpack vecr e -> [|| VecUnpack $$(liftVecR vecr) $$(liftE e) ||] + VecIndex vt it v i -> [|| VecIndex $$(liftVectorType vt) $$(liftIntegralType it) $$(liftE v) $$(liftE i) ||] + VecWrite vt it v i e -> [|| VecWrite $$(liftVectorType vt) $$(liftIntegralType it) $$(liftE v) $$(liftE i) $$(liftE e) ||] IndexSlice slice slix sh -> [|| IndexSlice $$(liftSliceIndex slice) $$(liftE slix) $$(liftE sh) ||] IndexFull slice slix sl -> [|| IndexFull $$(liftSliceIndex slice) $$(liftE slix) $$(liftE sl) ||] ToIndex shr sh ix -> [|| ToIndex $$(liftShapeR shr) $$(liftE sh) $$(liftE ix) ||] @@ -1411,7 +1413,6 @@ liftPrimFun (PrimMin t) = [|| PrimMin $$(liftSingleType t) ||] liftPrimFun PrimLAnd = [|| PrimLAnd ||] liftPrimFun PrimLOr = [|| PrimLOr ||] liftPrimFun PrimLNot = [|| PrimLNot ||] -liftPrimFun (PrimVectorIndex v i) = [|| PrimVectorIndex $$(liftVectorType v) $$(liftIntegralType i) ||] liftPrimFun (PrimFromIntegral ta tb) = [|| PrimFromIntegral $$(liftIntegralType ta) $$(liftNumType tb) ||] liftPrimFun (PrimToFloating ta tb) = [|| PrimToFloating $$(liftNumType ta) $$(liftFloatingType tb) ||] @@ -1461,6 +1462,8 @@ formatExpOp = later $ \case Nil{} -> "Nil" VecPack{} -> "VecPack" VecUnpack{} -> "VecUnpack" + VecIndex{} -> "VecIndex" + VecWrite{} -> "VecWrite" IndexSlice{} -> "IndexSlice" IndexFull{} -> "IndexFull" ToIndex{} -> "ToIndex" diff --git a/src/Data/Array/Accelerate/Analysis/Hash.hs b/src/Data/Array/Accelerate/Analysis/Hash.hs index 2b399aa46..964a5f11a 100644 --- a/src/Data/Array/Accelerate/Analysis/Hash.hs +++ b/src/Data/Array/Accelerate/Analysis/Hash.hs @@ -320,6 +320,8 @@ encodeOpenExp exp = Pair e1 e2 -> intHost $(hashQ "Pair") <> travE e1 <> travE e2 VecPack _ e -> intHost $(hashQ "VecPack") <> travE e VecUnpack _ e -> intHost $(hashQ "VecUnpack") <> travE e + VecIndex _ _ v i -> intHost $(hashQ "VecIndex") <> travE v <> travE i + VecWrite _ _ v i e -> intHost $(hashQ "VecWrite") <> travE v <> travE i <> travE e Const tp c -> intHost $(hashQ "Const") <> encodeScalarConst tp c Undef tp -> intHost $(hashQ "Undef") <> encodeScalarType tp IndexSlice spec ix sh -> intHost $(hashQ "IndexSlice") <> travE ix <> travE sh <> encodeSliceIndex spec @@ -448,8 +450,6 @@ encodePrimFun (PrimEq a) = intHost $(hashQ "PrimEq") encodePrimFun (PrimNEq a) = intHost $(hashQ "PrimNEq") <> encodeSingleType a encodePrimFun (PrimMax a) = intHost $(hashQ "PrimMax") <> encodeSingleType a encodePrimFun (PrimMin a) = intHost $(hashQ "PrimMin") <> encodeSingleType a -encodePrimFun (PrimVectorIndex (VectorType _ a) b) = intHost $(hashQ "PrimVectorIndex") <> encodeSingleType a <> encodeNumType (IntegralNumType b) -encodePrimFun (PrimVectorWrite (VectorType _ a) b) = intHost $(hashQ "PrimVectorWrite") <> encodeSingleType a <> encodeNumType (IntegralNumType b) encodePrimFun (PrimFromIntegral a b) = intHost $(hashQ "PrimFromIntegral") <> encodeIntegralType a <> encodeNumType b encodePrimFun (PrimToFloating a b) = intHost $(hashQ "PrimToFloating") <> encodeNumType a <> encodeFloatingType b encodePrimFun PrimLAnd = intHost $(hashQ "PrimLAnd") diff --git a/src/Data/Array/Accelerate/Classes/Vector.hs b/src/Data/Array/Accelerate/Classes/Vector.hs index 87586985d..21c7a7be2 100644 --- a/src/Data/Array/Accelerate/Classes/Vector.hs +++ b/src/Data/Array/Accelerate/Classes/Vector.hs @@ -5,6 +5,8 @@ {-# LANGUAGE MonoLocalBinds #-} {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE GADTs #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | @@ -18,12 +20,13 @@ -- module Data.Array.Accelerate.Classes.Vector where -import Data.Kind import GHC.TypeLits import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Smart import Data.Primitive.Vec + + instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) where type IndexType (Exp (Vec n a)) = Exp Int vecIndex = mkVectorIndex diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index c304051ed..aee68443f 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -1145,8 +1145,6 @@ evalPrim (PrimMin ty) = evalMin ty evalPrim PrimLAnd = evalLAnd evalPrim PrimLOr = evalLOr evalPrim PrimLNot = evalLNot -evalPrim (PrimVectorIndex v i) = evalVectorIndex v i -evalPrim (PrimVectorWrite v i) = evalVectorWrite v i evalPrim (PrimFromIntegral ta tb) = evalFromIntegral ta tb evalPrim (PrimToFloating ta tb) = evalToFloating ta tb diff --git a/src/Data/Array/Accelerate/Representation/Vec.hs b/src/Data/Array/Accelerate/Representation/Vec.hs index 35eac3b6c..bd37c7f18 100644 --- a/src/Data/Array/Accelerate/Representation/Vec.hs +++ b/src/Data/Array/Accelerate/Representation/Vec.hs @@ -41,6 +41,7 @@ data VecR (n :: Nat) single tuple where VecRnil :: SingleType s -> VecR 0 s () VecRsucc :: VecR n s t -> VecR (n + 1) s (t, s) + vecRvector :: KnownNat n => VecR n s tuple -> VectorType (Vec n s) vecRvector = uncurry VectorType . go where @@ -48,6 +49,9 @@ vecRvector = uncurry VectorType . go go (VecRnil tp) = (0, tp) go (VecRsucc vec) | (n, tp) <- go vec = (n + 1, tp) +vecRSingle :: KnownNat n => VecR n s tuple -> SingleType s +vecRSingle vecr = let (VectorType _ s) = vecRvector vecr in s + vecRtuple :: VecR n s tuple -> TypeR tuple vecRtuple = snd . go where diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 30981c660..ccb38e7ab 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -527,6 +527,21 @@ data PreSmartExp acc exp t where -> exp (Vec n s) -> PreSmartExp acc exp tup + VecIndex :: (KnownNat n, v ~ Vec n s) + => VectorType v + -> IntegralType i + -> exp (Vec n s) + -> exp i + -> PreSmartExp acc exp s + + VecWrite :: (KnownNat n, v ~ Vec n s) + => VectorType v + -> IntegralType i + -> exp (Vec n s) + -> exp i + -> exp s + -> PreSmartExp acc exp (Vec n s) + ToIndex :: ShapeR sh -> exp sh -> exp sh @@ -860,6 +875,8 @@ instance HasTypeR exp => HasTypeR (PreSmartExp acc exp) where Prj _ _ -> error "I never joke about my work" VecPack vecR _ -> TupRsingle $ VectorScalarType $ vecRvector vecR VecUnpack vecR _ -> vecRtuple vecR + VecIndex vecT _ _ _ -> let (VectorType _ s) = vecT in TupRsingle $ SingleScalarType s + VecWrite vecT _ _ _ _ -> TupRsingle $ VectorScalarType vecT ToIndex _ _ _ -> TupRsingle scalarTypeInt FromIndex shr _ _ -> shapeType shr Case _ ((_,c):_) -> typeR c @@ -1179,16 +1196,15 @@ mkLNot (Exp a) = mkExp $ SmartExp (PrimApp PrimLNot x) `Pair` SmartExp Nil where x = SmartExp $ Prj PairIdxLeft a --- Operators from Vec + +inferNat :: forall n. KnownNat n => Int +inferNat = fromInteger $ natVal (Proxy @n) + mkVectorIndex :: forall n a. (KnownNat n, Elt a, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a -mkVectorIndex = let n :: Int - n = fromIntegral $ natVal $ Proxy @n - in mkPrimBinary $ PrimVectorIndex @n (VectorType n singleType) integralType +mkVectorIndex (Exp v) (Exp i) = mkExp $ VecIndex (VectorType (inferNat @n) singleType) integralType v i mkVectorWrite :: forall n a. (KnownNat n, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a -> Exp (Vec n a) -mkVectorWrite = let n :: Int - n = fromIntegral $ natVal $ Proxy @n - in mkPrimTernary $ PrimVectorWrite @n (VectorType n singleType) integralType +mkVectorWrite (Exp v) (Exp i) (Exp el) = mkExp $ VecWrite (VectorType (inferNat @n) singleType) integralType v i el -- Numeric conversions diff --git a/src/Data/Array/Accelerate/Trafo/Algebra.hs b/src/Data/Array/Accelerate/Trafo/Algebra.hs index d8a655b06..807ffe474 100644 --- a/src/Data/Array/Accelerate/Trafo/Algebra.hs +++ b/src/Data/Array/Accelerate/Trafo/Algebra.hs @@ -144,8 +144,6 @@ evalPrimApp env f x PrimNEq ty -> evalNEq ty x env PrimMax ty -> evalMax ty x env PrimMin ty -> evalMin ty x env - PrimVectorIndex _ _ -> Nothing - PrimVectorWrite _ _ -> Nothing PrimLAnd -> evalLAnd x env PrimLOr -> evalLOr x env PrimLNot -> evalLNot x env diff --git a/src/Data/Array/Accelerate/Trafo/Sharing.hs b/src/Data/Array/Accelerate/Trafo/Sharing.hs index 67ead04f0..9a740cb06 100644 --- a/src/Data/Array/Accelerate/Trafo/Sharing.hs +++ b/src/Data/Array/Accelerate/Trafo/Sharing.hs @@ -764,6 +764,8 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp Pair e1 e2 -> AST.Pair (cvt e1) (cvt e2) VecPack vec e -> AST.VecPack vec (cvt e) VecUnpack vec e -> AST.VecUnpack vec (cvt e) + VecIndex vt it v i -> AST.VecIndex vt it (cvt v) (cvt i) + VecWrite vt it v i e -> AST.VecWrite vt it (cvt v) (cvt i) (cvt e) ToIndex shr sh ix -> AST.ToIndex shr (cvt sh) (cvt ix) FromIndex shr sh e -> AST.FromIndex shr (cvt sh) (cvt e) Case e rhs -> cvtCase (cvt e) (over (mapped . _2) cvt rhs) @@ -1841,37 +1843,39 @@ makeOccMapSharingExp config accOccMap expOccMap = travE return (UnscopedExp [] (ExpSharing (StableNameHeight sn height) exp), height) reconstruct $ case pexp of - Tag tp i -> return (Tag tp i, 0) -- height is 0! - Const tp c -> return (Const tp c, 1) - Undef tp -> return (Undef tp, 1) - Nil -> return (Nil, 1) - Pair e1 e2 -> travE2 Pair e1 e2 - Prj i e -> travE1 (Prj i) e - VecPack vec e -> travE1 (VecPack vec) e - VecUnpack vec e -> travE1 (VecUnpack vec) e - ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix - FromIndex shr sh e -> travE2 (FromIndex shr) sh e - Match t e -> travE1 (Match t) e - Case e rhs -> do - (e', h1) <- travE lvl e - (rhs', h2) <- unzip <$> sequence [ travE1 (t,) c | (t,c) <- rhs ] - return (Case e' rhs', h1 `max` maximum h2 + 1) - Cond e1 e2 e3 -> travE3 Cond e1 e2 e3 - While t p iter init -> do - (p' , h1) <- traverseFun1 lvl t p - (iter', h2) <- traverseFun1 lvl t iter - (init', h3) <- travE lvl init - return (While t p' iter' init', h1 `max` h2 `max` h3 + 1) - PrimConst c -> return (PrimConst c, 1) - PrimApp p e -> travE1 (PrimApp p) e - Index tp a e -> travAE (Index tp) a e - LinearIndex tp a i -> travAE (LinearIndex tp) a i - Shape shr a -> travA (Shape shr) a - ShapeSize shr e -> travE1 (ShapeSize shr) e - Foreign tp ff f e -> do - (e', h) <- travE lvl e - return (Foreign tp ff f e', h+1) - Coerce t1 t2 e -> travE1 (Coerce t1 t2) e + Tag tp i -> return (Tag tp i, 0) -- height is 0! + Const tp c -> return (Const tp c, 1) + Undef tp -> return (Undef tp, 1) + Nil -> return (Nil, 1) + Pair e1 e2 -> travE2 Pair e1 e2 + Prj i e -> travE1 (Prj i) e + VecPack vec e -> travE1 (VecPack vec) e + VecUnpack vec e -> travE1 (VecUnpack vec) e + VecIndex vt ti v i -> travE2 (VecIndex vt ti) v i + VecWrite vt ti v i e -> travE3 (VecWrite vt ti) v i e + ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix + FromIndex shr sh e -> travE2 (FromIndex shr) sh e + Match t e -> travE1 (Match t) e + Case e rhs -> do + (e', h1) <- travE lvl e + (rhs', h2) <- unzip <$> sequence [ travE1 (t,) c | (t,c) <- rhs ] + return (Case e' rhs', h1 `max` maximum h2 + 1) + Cond e1 e2 e3 -> travE3 Cond e1 e2 e3 + While t p iter init -> do + (p' , h1) <- traverseFun1 lvl t p + (iter', h2) <- traverseFun1 lvl t iter + (init', h3) <- travE lvl init + return (While t p' iter' init', h1 `max` h2 `max` h3 + 1) + PrimConst c -> return (PrimConst c, 1) + PrimApp p e -> travE1 (PrimApp p) e + Index tp a e -> travAE (Index tp) a e + LinearIndex tp a i -> travAE (LinearIndex tp) a i + Shape shr a -> travA (Shape shr) a + ShapeSize shr e -> travE1 (ShapeSize shr) e + Foreign tp ff f e -> do + (e', h) <- travE lvl e + return (Foreign tp ff f e', h+1) + Coerce t1 t2 e -> travE1 (Coerce t1 t2) e where traverseAcc :: HasCallStack => Level -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int) @@ -2755,6 +2759,8 @@ determineScopesSharingExp config accOccMap expOccMap = scopesExp Prj i e -> travE1 (Prj i) e VecPack vec e -> travE1 (VecPack vec) e VecUnpack vec e -> travE1 (VecUnpack vec) e + VecIndex vt it v i -> travE2 (VecIndex vt it) v i + VecWrite vt it v i e -> travE3 (VecWrite vt it) v i e ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix FromIndex shr sh e -> travE2 (FromIndex shr) sh e Match t e -> travE1 (Match t) e diff --git a/src/Data/Array/Accelerate/Trafo/Shrink.hs b/src/Data/Array/Accelerate/Trafo/Shrink.hs index 574747865..636043113 100644 --- a/src/Data/Array/Accelerate/Trafo/Shrink.hs +++ b/src/Data/Array/Accelerate/Trafo/Shrink.hs @@ -293,6 +293,8 @@ shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE Pair x y -> Pair <$> shrinkE x <*> shrinkE y VecPack vec e -> VecPack vec <$> shrinkE e VecUnpack vec e -> VecUnpack vec <$> shrinkE e + VecIndex vt it v i -> VecIndex vt it <$> shrinkE v <*> shrinkE i + VecWrite vt it v i e -> VecWrite vt it <$> shrinkE v <*> shrinkE i <*> shrinkE e IndexSlice x ix sh -> IndexSlice x <$> shrinkE ix <*> shrinkE sh IndexFull x ix sl -> IndexFull x <$> shrinkE ix <*> shrinkE sl ToIndex shr sh ix -> ToIndex shr <$> shrinkE sh <*> shrinkE ix @@ -494,6 +496,8 @@ usesOfExp range = countE Pair e1 e2 -> countE e1 <> countE e2 VecPack _ e -> countE e VecUnpack _ e -> countE e + VecIndex _ _ v i -> countE v <> countE i + VecWrite _ _ v i e -> countE v <> countE i <> countE e IndexSlice _ ix sh -> countE ix <> countE sh IndexFull _ ix sl -> countE ix <> countE sl FromIndex _ sh i -> countE sh <> countE i @@ -581,6 +585,8 @@ usesOfPreAcc withShape countAcc idx = count Pair x y -> countE x + countE y VecPack _ e -> countE e VecUnpack _ e -> countE e + VecIndex _ _ v i -> countE v + countE i + VecWrite _ _ v i e -> countE v + countE i + countE e IndexSlice _ ix sh -> countE ix + countE sh IndexFull _ ix sl -> countE ix + countE sl ToIndex _ sh ix -> countE sh + countE ix diff --git a/src/Data/Array/Accelerate/Trafo/Simplify.hs b/src/Data/Array/Accelerate/Trafo/Simplify.hs index 71be5aad3..6fe611f7a 100644 --- a/src/Data/Array/Accelerate/Trafo/Simplify.hs +++ b/src/Data/Array/Accelerate/Trafo/Simplify.hs @@ -226,6 +226,8 @@ simplifyOpenExp env = first getAny . cvtE Pair e1 e2 -> Pair <$> cvtE e1 <*> cvtE e2 VecPack vec e -> VecPack vec <$> cvtE e VecUnpack vec e -> VecUnpack vec <$> cvtE e + VecIndex vt it v i -> VecIndex vt it <$> cvtE v <*> cvtE i + VecWrite vt it v i e -> VecWrite vt it <$> cvtE v <*> cvtE i <*> cvtE e IndexSlice x ix sh -> IndexSlice x <$> cvtE ix <*> cvtE sh IndexFull x ix sl -> IndexFull x <$> cvtE ix <*> cvtE sl ToIndex shr sh ix -> toIndex shr (cvtE sh) (cvtE ix) @@ -548,6 +550,8 @@ summariseOpenExp = (terms +~ 1) . goE Pair e1 e2 -> travE e1 +++ travE e2 & terms +~ 1 VecPack _ e -> travE e VecUnpack _ e -> travE e + VecIndex _ _ v i -> travE v +++ travE i + VecWrite _ _ v i e -> travE v +++ travE i +++ travE e IndexSlice _ slix sh -> travE slix +++ travE sh & terms +~ 1 -- +1 for sliceIndex IndexFull _ slix sl -> travE slix +++ travE sl & terms +~ 1 -- +1 for sliceIndex ToIndex _ sh ix -> travE sh +++ travE ix diff --git a/src/Data/Array/Accelerate/Trafo/Substitution.hs b/src/Data/Array/Accelerate/Trafo/Substitution.hs index e1aa1176b..7debd6d07 100644 --- a/src/Data/Array/Accelerate/Trafo/Substitution.hs +++ b/src/Data/Array/Accelerate/Trafo/Substitution.hs @@ -149,29 +149,31 @@ inlineVars lhsBound expr bound substitute k1 k2 vars topExp = case topExp of Let lhs e1 e2 | Exists lhs' <- rebuildLHS lhs - -> Let lhs' <$> travE e1 <*> substitute (strengthenAfter lhs lhs' k1) (weakenWithLHS lhs' .> k2) (weakenWithLHS lhs `weakenVars` vars) e2 - Evar (Var t ix) -> Evar . Var t <$> k1 ix - Foreign tp asm f e1 -> Foreign tp asm f <$> travE e1 - Pair e1 e2 -> Pair <$> travE e1 <*> travE e2 - Nil -> Just Nil - VecPack vec e1 -> VecPack vec <$> travE e1 - VecUnpack vec e1 -> VecUnpack vec <$> travE e1 - IndexSlice si e1 e2 -> IndexSlice si <$> travE e1 <*> travE e2 - IndexFull si e1 e2 -> IndexFull si <$> travE e1 <*> travE e2 - ToIndex shr e1 e2 -> ToIndex shr <$> travE e1 <*> travE e2 - FromIndex shr e1 e2 -> FromIndex shr <$> travE e1 <*> travE e2 - Case e1 rhs def -> Case <$> travE e1 <*> mapM (\(t,c) -> (t,) <$> travE c) rhs <*> travMaybeE def - Cond e1 e2 e3 -> Cond <$> travE e1 <*> travE e2 <*> travE e3 - While f1 f2 e1 -> While <$> travF f1 <*> travF f2 <*> travE e1 - Const t c -> Just $ Const t c - PrimConst c -> Just $ PrimConst c - PrimApp p e1 -> PrimApp p <$> travE e1 - Index a e1 -> Index a <$> travE e1 - LinearIndex a e1 -> LinearIndex a <$> travE e1 - Shape a -> Just $ Shape a - ShapeSize shr e1 -> ShapeSize shr <$> travE e1 - Undef t -> Just $ Undef t - Coerce t1 t2 e1 -> Coerce t1 t2 <$> travE e1 + -> Let lhs' <$> travE e1 <*> substitute (strengthenAfter lhs lhs' k1) (weakenWithLHS lhs' .> k2) (weakenWithLHS lhs `weakenVars` vars) e2 + Evar (Var t ix) -> Evar . Var t <$> k1 ix + Foreign tp asm f e1 -> Foreign tp asm f <$> travE e1 + Pair e1 e2 -> Pair <$> travE e1 <*> travE e2 + Nil -> Just Nil + VecPack vec e1 -> VecPack vec <$> travE e1 + VecUnpack vec e1 -> VecUnpack vec <$> travE e1 + VecIndex vt it v i -> VecIndex vt it <$> travE v <*> travE i + VecWrite vt it v i e -> VecWrite vt it <$> travE v <*> travE i <*> travE e + IndexSlice si e1 e2 -> IndexSlice si <$> travE e1 <*> travE e2 + IndexFull si e1 e2 -> IndexFull si <$> travE e1 <*> travE e2 + ToIndex shr e1 e2 -> ToIndex shr <$> travE e1 <*> travE e2 + FromIndex shr e1 e2 -> FromIndex shr <$> travE e1 <*> travE e2 + Case e1 rhs def -> Case <$> travE e1 <*> mapM (\(t,c) -> (t,) <$> travE c) rhs <*> travMaybeE def + Cond e1 e2 e3 -> Cond <$> travE e1 <*> travE e2 <*> travE e3 + While f1 f2 e1 -> While <$> travF f1 <*> travF f2 <*> travE e1 + Const t c -> Just $ Const t c + PrimConst c -> Just $ PrimConst c + PrimApp p e1 -> PrimApp p <$> travE e1 + Index a e1 -> Index a <$> travE e1 + LinearIndex a e1 -> LinearIndex a <$> travE e1 + Shape a -> Just $ Shape a + ShapeSize shr e1 -> ShapeSize shr <$> travE e1 + Undef t -> Just $ Undef t + Coerce t1 t2 e1 -> Coerce t1 t2 <$> travE e1 where travE :: OpenExp env1 aenv s -> Maybe (OpenExp env2 aenv s) @@ -546,31 +548,33 @@ rebuildOpenExp -> f (OpenExp env' aenv' t) rebuildOpenExp v av@(ReindexAvar reindex) exp = case exp of - Const t c -> pure $ Const t c - PrimConst c -> pure $ PrimConst c - Undef t -> pure $ Undef t - Evar var -> expOut <$> v var + Const t c -> pure $ Const t c + PrimConst c -> pure $ PrimConst c + Undef t -> pure $ Undef t + Evar var -> expOut <$> v var Let lhs a b | Exists lhs' <- rebuildLHS lhs - -> Let lhs' <$> rebuildOpenExp v av a <*> rebuildOpenExp (shiftE' lhs lhs' v) av b - Pair e1 e2 -> Pair <$> rebuildOpenExp v av e1 <*> rebuildOpenExp v av e2 - Nil -> pure Nil - VecPack vec e -> VecPack vec <$> rebuildOpenExp v av e - VecUnpack vec e -> VecUnpack vec <$> rebuildOpenExp v av e - IndexSlice x ix sh -> IndexSlice x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sh - IndexFull x ix sl -> IndexFull x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sl - ToIndex shr sh ix -> ToIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix - FromIndex shr sh ix -> FromIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix - Case e rhs def -> Case <$> rebuildOpenExp v av e <*> sequenceA [ (t,) <$> rebuildOpenExp v av c | (t,c) <- rhs ] <*> rebuildMaybeExp v av def - Cond p t e -> Cond <$> rebuildOpenExp v av p <*> rebuildOpenExp v av t <*> rebuildOpenExp v av e - While p f x -> While <$> rebuildFun v av p <*> rebuildFun v av f <*> rebuildOpenExp v av x - PrimApp f x -> PrimApp f <$> rebuildOpenExp v av x - Index a sh -> Index <$> reindex a <*> rebuildOpenExp v av sh - LinearIndex a i -> LinearIndex <$> reindex a <*> rebuildOpenExp v av i - Shape a -> Shape <$> reindex a - ShapeSize shr sh -> ShapeSize shr <$> rebuildOpenExp v av sh - Foreign tp ff f e -> Foreign tp ff f <$> rebuildOpenExp v av e - Coerce t1 t2 e -> Coerce t1 t2 <$> rebuildOpenExp v av e + -> Let lhs' <$> rebuildOpenExp v av a <*> rebuildOpenExp (shiftE' lhs lhs' v) av b + Pair e1 e2 -> Pair <$> rebuildOpenExp v av e1 <*> rebuildOpenExp v av e2 + Nil -> pure Nil + VecPack vec e -> VecPack vec <$> rebuildOpenExp v av e + VecUnpack vec e -> VecUnpack vec <$> rebuildOpenExp v av e + VecIndex vt it v' i -> VecIndex vt it <$> rebuildOpenExp v av v' <*> rebuildOpenExp v av i + VecWrite vt it v' i e -> VecWrite vt it <$> rebuildOpenExp v av v' <*> rebuildOpenExp v av i <*> rebuildOpenExp v av e + IndexSlice x ix sh -> IndexSlice x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sh + IndexFull x ix sl -> IndexFull x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sl + ToIndex shr sh ix -> ToIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix + FromIndex shr sh ix -> FromIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix + Case e rhs def -> Case <$> rebuildOpenExp v av e <*> sequenceA [ (t,) <$> rebuildOpenExp v av c | (t,c) <- rhs ] <*> rebuildMaybeExp v av def + Cond p t e -> Cond <$> rebuildOpenExp v av p <*> rebuildOpenExp v av t <*> rebuildOpenExp v av e + While p f x -> While <$> rebuildFun v av p <*> rebuildFun v av f <*> rebuildOpenExp v av x + PrimApp f x -> PrimApp f <$> rebuildOpenExp v av x + Index a sh -> Index <$> reindex a <*> rebuildOpenExp v av sh + LinearIndex a i -> LinearIndex <$> reindex a <*> rebuildOpenExp v av i + Shape a -> Shape <$> reindex a + ShapeSize shr sh -> ShapeSize shr <$> rebuildOpenExp v av sh + Foreign tp ff f e -> Foreign tp ff f <$> rebuildOpenExp v av e + Coerce t1 t2 e -> Coerce t1 t2 <$> rebuildOpenExp v av e {-# INLINEABLE rebuildFun #-} rebuildFun From 21d6dab8fad5890675c2184d137a261b8e6ace21 Mon Sep 17 00:00:00 2001 From: Hugo <hpeters1024@gmail.com> Date: Wed, 8 Dec 2021 10:58:44 +0100 Subject: [PATCH 10/86] fix off by one errors --- src/Data/Primitive/Vec.hs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 36c4f9570..ff60d7d2e 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -110,8 +110,8 @@ instance (KnownNat n, Prim a) => Vectoring (Vec n a) a where let n :: Int n = fromIntegral $ natVal $ Proxy @n mba <- newByteArray (n * sizeOf (undefined :: a)) - let new_vs = zipWith (\i' v' -> if i' == i then v else v') [0..n] (listOfVec vec) - zipWithM_ (writeByteArray mba) [0..n] new_vs + let new_vs = zipWith (\i' v' -> if i' == i then v else v') [0..n-1] (listOfVec vec) + zipWithM_ (writeByteArray mba) [0..n-1] new_vs ByteArray nba# <- unsafeFreezeByteArray mba return $! Vec nba# vecEmpty = mkVec @@ -139,7 +139,7 @@ vecOfList :: forall n a. (KnownNat n, Prim a) => [a] -> Vec n a vecOfList vs = runST $ do let n :: Int = fromIntegral $ natVal $ Proxy @n mba <- newByteArray (n * sizeOf (undefined :: a)) - zipWithM_ (writeByteArray mba) [0..n] vs + zipWithM_ (writeByteArray mba) [0..n-1] vs ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# From ad1f995dfa11b8b78f4dde01a95e4e05fb1ea4d3 Mon Sep 17 00:00:00 2001 From: Hugo <hpeters1024@gmail.com> Date: Mon, 13 Dec 2021 12:32:59 +0100 Subject: [PATCH 11/86] style changes --- src/Data/Array/Accelerate/AST.hs | 4 ++-- src/Data/Primitive/Vec.hs | 36 ++++++++++++++++---------------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 31b2512ad..d3a26353e 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE BangPatterns #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} @@ -7,10 +8,9 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE AllowAmbiguousTypes #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.AST diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index ff60d7d2e..52e5ccc39 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -1,21 +1,21 @@ -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE MagicHash #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RoleAnnotations #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE UnboxedTuples #-} -{-# LANGUAGE ViewPatterns #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE FunctionalDependencies #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE TupleSections #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE UnboxedTuples #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TupleSections #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Primitive.Vec From f9556e3c6c1dfbfe946563d274670ae849ab122c Mon Sep 17 00:00:00 2001 From: Hugo <hpeters1024@gmail.com> Date: Wed, 19 Jan 2022 16:08:39 +0100 Subject: [PATCH 12/86] prevent memcpy using unsafe mutable coercion --- src/Data/Primitive/Vec.hs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 52e5ccc39..a50f643c2 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -109,9 +109,8 @@ instance (KnownNat n, Prim a) => Vectoring (Vec n a) a where vecWrite vec@(Vec ba#) i@(I# iu#) v = runST $ do let n :: Int n = fromIntegral $ natVal $ Proxy @n - mba <- newByteArray (n * sizeOf (undefined :: a)) - let new_vs = zipWith (\i' v' -> if i' == i then v else v') [0..n-1] (listOfVec vec) - zipWithM_ (writeByteArray mba) [0..n-1] new_vs + mba <- unsafeThawByteArray (ByteArray ba#) + writeByteArray mba i v ByteArray nba# <- unsafeFreezeByteArray mba return $! Vec nba# vecEmpty = mkVec From 1eaa378b7018959f91a9edc4f58400439a74cc7a Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Thu, 10 Mar 2022 17:00:42 +0100 Subject: [PATCH 13/86] stack: update resolver --- stack-8.10.yaml | 2 +- stack-9.0.yaml | 2 +- stack-9.2.yaml | 8 +++----- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/stack-8.10.yaml b/stack-8.10.yaml index d0823dcbd..a7a50be2d 100644 --- a/stack-8.10.yaml +++ b/stack-8.10.yaml @@ -2,7 +2,7 @@ # For advanced use and comprehensive documentation of the format, please see: # https://docs.haskellstack.org/en/stable/yaml_configuration/ -resolver: lts-18.25 +resolver: lts-18.27 packages: - . diff --git a/stack-9.0.yaml b/stack-9.0.yaml index 1349abd27..19296b890 100644 --- a/stack-9.0.yaml +++ b/stack-9.0.yaml @@ -2,7 +2,7 @@ # For advanced use and comprehensive documentation of the format, please see: # https://docs.haskellstack.org/en/stable/yaml_configuration/ -resolver: nightly-2022-02-16 +resolver: nightly-2022-03-10 packages: - . diff --git a/stack-9.2.yaml b/stack-9.2.yaml index 69365d734..422f6750f 100644 --- a/stack-9.2.yaml +++ b/stack-9.2.yaml @@ -2,15 +2,13 @@ # For advanced use and comprehensive documentation of the format, please see: # https://docs.haskellstack.org/en/stable/yaml_configuration/ -compiler: ghc-9.2.1 -resolver: nightly-2022-02-19 +compiler: ghc-9.2.2 +resolver: nightly-2022-03-10 packages: - . -extra-deps: -- base-compat-0.12.1 -- doctest-0.20.0 +# extra-deps: [] # Override default flag values for local packages and extra-deps # flags: {} From 837c20c1d4492c3bf53ad3554bfd25c8617aefba Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Tue, 7 Jun 2022 15:40:28 +0200 Subject: [PATCH 14/86] add operations for 128-bit floating point numbers --- accelerate.cabal | 22 +- cbits/float128.c | 119 +++++++ src/Data/Numeric/Float128.hs | 591 +++++++++++++++++++++++++++++++++++ 3 files changed, 731 insertions(+), 1 deletion(-) create mode 100644 cbits/float128.c create mode 100644 src/Data/Numeric/Float128.hs diff --git a/accelerate.cabal b/accelerate.cabal index 24f4b65f9..e9bdc317d 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -204,6 +204,15 @@ custom-setup , directory >= 1.0 , filepath >= 1.0 +flag float128 + manual: True + default: False + description: + Enable support for 128-bit floating point numbers + . + This requires the library 'quadmath' to be installed. Note that not all + targets support 128-bit floating-point numbers. + flag debug manual: True default: False @@ -430,9 +439,10 @@ library Data.Array.Accelerate.Test.Similar -- Other + Crypto.Hash.XKCP Data.BitSet Data.Primitive.Vec - Crypto.Hash.XKCP + Data.Numeric.Float128 other-modules: Data.Array.Accelerate.Analysis.Hash.TH @@ -587,6 +597,16 @@ library -caf-all -auto-all + if flag(float128) + cc-options: + -DFLOAT128_ENABLE + + cpp-options: + -DFLOAT128_ENABLE + + extra-libraries: + quadmath + if flag(debug) cc-options: -DACCELERATE_DEBUG diff --git a/cbits/float128.c b/cbits/float128.c new file mode 100644 index 000000000..f16aad747 --- /dev/null +++ b/cbits/float128.c @@ -0,0 +1,119 @@ + +#include <quadmath.h> +#include <stdio.h> + +typedef _Float128 f128; + +union ieee754_quad { + f128 as_float128; + struct { +#if WORDS_BIGENDIAN + uint64_t negative:1; + uint64_t exponent:15; + uint64_t mantissa0:48; + uint64_t mantissa1; +#else + uint64_t mantissa1; + uint64_t mantissa0:48; + uint64_t exponent:15; + uint64_t negative:1; +#endif + } as_uint128; +}; + +/* Operations from Read and Show + */ +void _readq(f128* r, const char* str) { *r = strtoflt128(str, NULL); } +void _showq(char* buf, size_t n, f128 *a) { quadmath_snprintf(buf, n, "%Qf", *a); } + +/* Operations from Num + */ +void _addq(f128* r, const f128* a, const f128* b) { *r = *a + *b; } +void _subq(f128* r, const f128* a, const f128* b) { *r = *a - *b; } +void _mulq(f128* r, const f128* a, const f128* b) { *r = *a * *b; } +void _negateq(f128* r, const f128* a) { *r = - *a; } +void _absq(f128* r, const f128* a) { *r = fabsq(*a); } +void _signumq(f128* r, const f128* a) { *r = (*a > 0.0q) - (*a < 0.0q); } + +/* Operations from Fractional + */ +void _divq(f128* r, const f128* a, const f128* b) { *r = *a / *b; } +void _recipq(f128* r, const f128* a) { *r = 1.0q / *a; } + +/* Operations from Floating + */ +void _piq(f128* r) { *r = M_PIq; } +void _expq(f128* r, const f128* a) { *r = expq(*a); } +void _logq(f128* r, const f128* a) { *r = logq(*a); } +void _sqrtq(f128* r, const f128* a) { *r = sqrtq(*a); } +void _powq(f128* r, const f128* a, const f128* b) { *r = powq(*a, *b); } +void _sinq(f128* r, const f128* a) { *r = sinq(*a); } +void _cosq(f128* r, const f128* a) { *r = cosq(*a); } +void _tanq(f128* r, const f128* a) { *r = tanq(*a); } +void _asinq(f128* r, const f128* a) { *r = asinq(*a); } +void _acosq(f128* r, const f128* a) { *r = acosq(*a); } +void _atanq(f128* r, const f128* a) { *r = atanq(*a); } +void _sinhq(f128* r, const f128* a) { *r = sinhq(*a); } +void _coshq(f128* r, const f128* a) { *r = coshq(*a); } +void _tanhq(f128* r, const f128* a) { *r = tanhq(*a); } +void _asinhq(f128* r, const f128* a) { *r = asinhq(*a); } +void _acoshq(f128* r, const f128* a) { *r = acoshq(*a); } +void _atanhq(f128* r, const f128* a) { *r = atanhq(*a); } +void _log1pq(f128* r, const f128* a) { *r = log1pq(*a); } +void _expm1q(f128* r, const f128* a) { *r = expm1q(*a); } + +/* Operations from RealFrac + */ +void _roundq(f128* r, const f128* a) { *r = roundq(*a); } +void _truncq(f128* r, const f128* a) { *r = truncq(*a); } +void _floorq(f128* r, const f128* a) { *r = floorq(*a); } +void _ceilq(f128* r, const f128* a) { *r = ceilq(*a); } + +/* Operations from RealFloat + */ +uint32_t _isnanq(const f128* a) { return isnanq(*a); } +uint32_t _isinfq(const f128* a) { return isinfq(*a); } +void _frexpq(f128* r, const f128* a, int32_t* b) { *r = frexpq(*a, b); } +void _ldexpq(f128* r, const f128* a, int32_t b) { *r = ldexpq(*a, b); } +void _atan2q(f128* r, const f128* a, const f128* b) { *r = atan2q(*a, *b); } + +/* A (single/double/quad) precision floating point number is denormalized iff: + * - exponent is zero + * - mantissa is non-zero + * - (don't care about the sign bit) + */ +uint32_t _isdenormq(const f128* a) +{ + union ieee754_quad u; + u.as_float128 = *a; + + return (u.as_uint128.exponent == 0 + && (u.as_uint128.mantissa0 != 0 || u.as_uint128.mantissa1 != 0)); +} + +/* A (single/double/quad) precision floating point number is negative zero iff: + * - sign bit is set + * - all other bits are zero + */ +uint32_t _isnegzeroq(const f128* a) +{ + union ieee754_quad u; + u.as_float128 = *a; + + return ( + u.as_uint128.negative && + u.as_uint128.exponent == 0 && + u.as_uint128.mantissa0 == 0 && + u.as_uint128.mantissa1 == 0 + ); +} + +/* Operations from Ord + */ +uint32_t _ltq(const f128* a, const f128* b) { return *a < *b; } +uint32_t _leq(const f128* a, const f128* b) { return *a <= *b; } +uint32_t _gtq(const f128* a, const f128* b) { return *a > *b; } +uint32_t _geq(const f128* a, const f128* b) { return *a <= *b; } +void _fminq(f128* r, const f128* a, const f128* b) { *r = fminq(*a, *b); } +void _fmaxq(f128* r, const f128* a, const f128* b) { *r = fmaxq(*a, *b); } + diff --git a/src/Data/Numeric/Float128.hs b/src/Data/Numeric/Float128.hs new file mode 100644 index 000000000..fba3757d9 --- /dev/null +++ b/src/Data/Numeric/Float128.hs @@ -0,0 +1,591 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE ForeignFunctionInterface #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE UnboxedTuples #-} +{-# OPTIONS_GHC -fobject-code #-} +{-# OPTIONS_HADDOCK hide #-} +-- | +-- Module : Data.Numeric.Float128 +-- Copyright : [2008..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- +-- IEEE 128-bit floating point type (quadruple precision), consisting of +-- a 15-bit signed exponent and 113-bit mantissa (compared to 11-bit and +-- 53-bit respectively for IEEE double precision). +-- +-- Partly stolen from the (unmaintained) float128 package (BSD3 license). +-- We required the Float128 data type even if operations on that type are +-- not implemented, and link against the methods from the quadmath library +-- (as this is what LLVM will generate). +-- + +module Data.Numeric.Float128 ( + + Float128(..), + +) where + +import Numeric +import Data.Bits +import Data.Primitive.Types +import Data.Ratio +import Foreign.Marshal.Alloc +import Foreign.Marshal.Utils +import Foreign.Ptr +import Foreign.Storable +import System.IO.Unsafe + +import GHC.Base +import GHC.Int +import GHC.Integer.Logarithms +import GHC.Word + +#if defined(FLOAT128_ENABLE) && !defined(__GHCIDE__) +import Language.Haskell.TH.Syntax +#endif + + +-- | A 128-bit floating point number +-- +data Float128 = Float128 !Word64 !Word64 + deriving Eq + +instance Show Float128 where + showsPrec p x = showParen (x < 0 && p > 6) (showFloat x) + +instance Read Float128 where + readsPrec _ = readSigned readFloat + +instance Ord Float128 where + (<) = cmp c_ltq + (>) = cmp c_gtq + (<=) = cmp c_leq + (>=) = cmp c_geq + min = call2 c_fminq + max = call2 c_fmaxq + +instance Num Float128 where + (+) = call2 c_addq + (-) = call2 c_subq + (*) = call2 c_mulq + negate = call1 c_negateq + abs = call1 c_absq + signum = call1 c_signumq + fromInteger z = encodeFloat z 0 + +instance Fractional Float128 where + (/) = call2 c_divq + recip = call1 c_recipq + fromRational q = -- FIXME accuracy? + let a = fromInteger (numerator q) / fromInteger (denominator q) + b = fromInteger (numerator r) / fromInteger (denominator r) + r = q - toRational a + in a + b + +instance Floating Float128 where + pi = unsafePerformIO $ alloca $ \pr -> do + c_piq pr + peek pr + exp = call1 c_expq + log = call1 c_logq + sqrt = call1 c_sqrtq + (**) = call2 c_powq + sin = call1 c_sinq + cos = call1 c_cosq + tan = call1 c_tanq + asin = call1 c_asinq + acos = call1 c_acosq + atan = call1 c_atanq + sinh = call1 c_sinhq + cosh = call1 c_coshq + tanh = call1 c_tanhq + asinh = call1 c_asinhq + acosh = call1 c_acoshq + atanh = call1 c_atanhq + log1p = call1 c_log1pq + expm1 = call1 c_expm1q + +instance Real Float128 where + toRational l = + case decodeFloat l of + (m, e) + | e >= 0 -> m `shiftL` e % 1 + | otherwise -> m % bit (negate e) + +instance RealFrac Float128 where + properFraction l + | l >= 0 = let n' = floor' l + n = fromInteger (toInteger' n') + f = l - n' + in (n, f) + | l < 0 = let n' = ceil' l + n = fromInteger (toInteger' n') + f = l - n' + in (n, f) + | otherwise = (0, l) -- NaN + + truncate = fromInteger . toInteger' . trunc' + round = fromInteger . toInteger' . round' + floor = fromInteger . toInteger' . floor' + ceiling = fromInteger . toInteger' . ceil' + +toInteger' :: Float128 -> Integer +toInteger' l = + case decodeFloat l of + (m, e) + | e >= 0 -> m `shiftL` e + | otherwise -> m `shiftR` negate e + +round', trunc', floor', ceil' :: Float128 -> Float128 +round' = call1 c_roundq +trunc' = call1 c_truncq +floor' = call1 c_floorq +ceil' = call1 c_ceilq + +instance RealFloat Float128 where + isIEEE _ = True + floatRadix _ = 2 + floatDigits _ = 113 -- quadmath.h:FLT128_MANT_DIG + floatRange _ = (-16381,16384) -- quadmath.h:FLT128_MIN_EXP, FLT128_MAX_EXP + isNaN = tst c_isnanq + isInfinite = tst c_isinfq + isNegativeZero = tst c_isnegzeroq + isDenormalized = tst c_isdenormq + atan2 = call2 c_atan2q + + decodeFloat l@(Float128 msw lsw) + | isNaN l = (0, 0) + | isInfinite l = (0, 0) + | l == 0 = (0, 0) + | isDenormalized l = + case decodeFloat (scaleFloat 128 l) of + (m, e) -> (m, e - 128) + | otherwise = + let s = shiftR msw 48 `testBit` 15 + m0 = shiftL (0x1000000000000 .|. toInteger msw .&. 0xFFFFffffFFFF) 64 .|. toInteger lsw + e0 = shiftR msw 48 .&. (bit 15 - 1) + m = if s then negate m0 else m0 + e = fromIntegral e0 - 16383 - 112 -- FIXME verify + in (m, e) + + encodeFloat m e + | m == 0 = Float128 0 0 + | m < 0 = negate (encodeFloat (negate m) e) + | b >= bit 15 - 1 = Float128 (shiftL (bit 15 - 1) 48) 0 -- infinity + | b <= 0 = scaleFloat (b - 128) (encodeFloat m (e - b + 128)) -- denormal + | otherwise = Float128 msw lsw -- normal + where + l = I# (integerLog2# m) + t = l - 112 -- FIXME verify + m' | t >= 0 = m `shiftR` t + | otherwise = m `shiftL` negate t + -- FIXME: verify that m' `testBit` 112 == True + lsw = fromInteger (m' .&. 0xFFFFffffFFFFffff) + msw = fromInteger (shiftR m' 64 .&. 0xFFFFffffFFFF) .|. shiftL (fromIntegral b) 48 + b = e + t + 16383 + 112 -- FIXME verify + + exponent l@(Float128 msw _) + | isNaN l = 0 + | isInfinite l = 0 + | l == 0 = 0 + | isDenormalized l = snd (decodeFloat l) + 113 + | otherwise = let e0 = shiftR msw 48 .&. (bit 15 - 1) + in fromIntegral e0 - 16383 - 112 + 113 + + significand l = unsafePerformIO $ + with l $ \lp -> + alloca $ \ep -> do + c_frexpq lp lp ep + peek lp + + scaleFloat e l = unsafePerformIO $ + with l $ \lp -> do + c_ldexpq lp lp (fromIntegral e) + peek lp + +instance Storable Float128 where + {-# INLINE sizeOf #-} + {-# INLINE alignment #-} + {-# INLINE peek #-} + {-# INLINE poke #-} + sizeOf _ = 16 + alignment _ = 16 + peek = peek128 + peekElemOff = peekElemOff128 + poke = poke128 + pokeElemOff = pokeElemOff128 + +instance Prim Float128 where + {-# INLINE sizeOf# #-} + {-# INLINE alignment# #-} + {-# INLINE indexByteArray# #-} + {-# INLINE readByteArray# #-} + {-# INLINE writeByteArray# #-} + {-# INLINE setByteArray# #-} + {-# INLINE indexOffAddr# #-} + {-# INLINE readOffAddr# #-} + {-# INLINE writeOffAddr# #-} + {-# INLINE setOffAddr# #-} + sizeOf# _ = 16# + alignment# _ = 16# + indexByteArray# = indexByteArray128# + readByteArray# = readByteArray128# + writeByteArray# = writeByteArray128# + setByteArray# = setByteArray128# + indexOffAddr# = indexOffAddr128# + readOffAddr# = readOffAddr128# + writeOffAddr# = writeOffAddr128# + setOffAddr# = setOffAddr128# + +{-# INLINE peek128 #-} +peek128 :: Ptr Float128 -> IO Float128 +peek128 ptr = Float128 <$> peekElemOff (castPtr ptr) index1 + <*> peekElemOff (castPtr ptr) index0 + +{-# INLINE peekElemOff128 #-} +peekElemOff128 :: Ptr Float128 -> Int -> IO Float128 +peekElemOff128 ptr i = + let i2 = 2 * i + in Float128 <$> peekElemOff (castPtr ptr) (i2 + index1) + <*> peekElemOff (castPtr ptr) (i2 + index0) + +{-# INLINE poke128 #-} +poke128 :: Ptr Float128 -> Float128 -> IO () +poke128 ptr (Float128 a b) = do + pokeElemOff (castPtr ptr) index1 a + pokeElemOff (castPtr ptr) index0 b + +{-# INLINE pokeElemOff128 #-} +pokeElemOff128 :: Ptr Float128 -> Int -> Float128 -> IO () +pokeElemOff128 ptr i (Float128 a1 a0) = + let i2 = 2 * i + in do pokeElemOff (castPtr ptr) (i2 + index0) a0 + pokeElemOff (castPtr ptr) (i2 + index1) a1 + +{-# INLINE indexByteArray128# #-} +indexByteArray128# :: ByteArray# -> Int# -> Float128 +indexByteArray128# arr# i# = + let i2# = 2# *# i# + x = indexByteArray# arr# (i2# +# unInt index1) + y = indexByteArray# arr# (i2# +# unInt index0) + in Float128 x y + +{-# INLINE readByteArray128# #-} +readByteArray128# :: MutableByteArray# s -> Int# -> State# s -> (# State# s, Float128 #) +readByteArray128# arr# i# s0 = + let i2# = 2# *# i# + in case readByteArray# arr# (i2# +# unInt index1) s0 of { (# s1, x #) -> + case readByteArray# arr# (i2# +# unInt index0) s1 of { (# s2, y #) -> + (# s2, Float128 x y #) + }} + +{-# INLINE writeByteArray128# #-} +writeByteArray128# :: MutableByteArray# s -> Int# -> Float128 -> State# s -> State# s +writeByteArray128# arr# i# (Float128 a b) s0 = + let i2# = 2# *# i# + in case writeByteArray# arr# (i2# +# unInt index1) a s0 of { s1 -> + case writeByteArray# arr# (i2# +# unInt index0) b s1 of { s2 -> + s2 + }} + +{-# INLINE setByteArray128# #-} +setByteArray128# :: MutableByteArray# s -> Int# -> Int# -> Float128 -> State# s -> State# s +setByteArray128# = defaultSetByteArray# + +{-# INLINE indexOffAddr128# #-} +indexOffAddr128# :: Addr# -> Int# -> Float128 +indexOffAddr128# addr# i# = + let i2# = 2# *# i# + x = indexOffAddr# addr# (i2# +# unInt index1) + y = indexOffAddr# addr# (i2# +# unInt index0) + in Float128 x y + +{-# INLINE readOffAddr128# #-} +readOffAddr128# :: Addr# -> Int# -> State# s -> (# State# s, Float128 #) +readOffAddr128# addr# i# s0 = + let i2# = 2# *# i# + in case readOffAddr# addr# (i2# +# unInt index1) s0 of { (# s1, x #) -> + case readOffAddr# addr# (i2# +# unInt index0) s1 of { (# s2, y #) -> + (# s2, Float128 x y #) + }} + +{-# INLINE writeOffAddr128# #-} +writeOffAddr128# :: Addr# -> Int# -> Float128 -> State# s -> State# s +writeOffAddr128# addr# i# (Float128 a b) s0 = + let i2# = 2# *# i# + in case writeOffAddr# addr# (i2# +# unInt index1) a s0 of { s1 -> + case writeOffAddr# addr# (i2# +# unInt index0) b s1 of { s2 -> + s2 + }} + +{-# INLINE setOffAddr128# #-} +setOffAddr128# :: Addr# -> Int# -> Int# -> Float128 -> State# s -> State# s +setOffAddr128# = defaultSetOffAddr# + +{-# INLINE unInt #-} +unInt :: Int -> Int# +unInt (I# i#) = i# + +-- Use these indices to get the peek/poke ordering endian correct. +{-# INLINE index0 #-} +{-# INLINE index1 #-} +index0, index1 :: Int +#if WORDS_BIGENDIAN +index0 = 1 +index1 = 0 +#else +index0 = 0 +index1 = 1 +#endif + +type F1 = Ptr Float128 -> Ptr Float128 -> IO () +type F2 = Ptr Float128 -> Ptr Float128 -> Ptr Float128 -> IO () +type CMP = Ptr Float128 -> Ptr Float128 -> IO Int32 +type TST = Ptr Float128 -> IO Int32 + +{-# INLINE call1 #-} +call1 :: F1 -> Float128 -> Float128 +call1 f x = unsafePerformIO $ + alloca $ \pr -> + with x $ \px -> do + f pr px + peek pr + +{-# INLINE call2 #-} +call2 :: F2 -> Float128 -> Float128 -> Float128 +call2 f x y = unsafePerformIO $ + alloca $ \pr -> + with x $ \px -> do + with y $ \py -> do + f pr px py + peek pr + +{-# INLINE cmp #-} +cmp :: CMP -> Float128 -> Float128 -> Bool +cmp f x y = unsafePerformIO $ + with x $ \px -> + with y $ \py -> + toBool <$> f px py + +{-# INLINE tst #-} +tst :: TST -> Float128 -> Bool +tst f x = unsafePerformIO $ + with x $ \px -> + toBool <$> f px + + +-- SEE: [HLS and GHC IDE] +-- +#if defined(FLOAT128_ENABLE) && !defined(__GHCIDE__) + +foreign import ccall unsafe "_addq" c_addq :: F2 +foreign import ccall unsafe "_subq" c_subq :: F2 +foreign import ccall unsafe "_mulq" c_mulq :: F2 +foreign import ccall unsafe "_absq" c_absq :: F1 +foreign import ccall unsafe "_negateq" c_negateq :: F1 +foreign import ccall unsafe "_signumq" c_signumq :: F1 + +foreign import ccall unsafe "_divq" c_divq :: F2 +foreign import ccall unsafe "_recipq" c_recipq :: F1 + +foreign import ccall unsafe "_piq" c_piq :: Ptr Float128 -> IO () +foreign import ccall unsafe "_expq" c_expq :: F1 +foreign import ccall unsafe "_logq" c_logq :: F1 +foreign import ccall unsafe "_sqrtq" c_sqrtq :: F1 +foreign import ccall unsafe "_powq" c_powq :: F2 +foreign import ccall unsafe "_sinq" c_sinq :: F1 +foreign import ccall unsafe "_cosq" c_cosq :: F1 +foreign import ccall unsafe "_tanq" c_tanq :: F1 +foreign import ccall unsafe "_asinq" c_asinq :: F1 +foreign import ccall unsafe "_acosq" c_acosq :: F1 +foreign import ccall unsafe "_atanq" c_atanq :: F1 +foreign import ccall unsafe "_sinhq" c_sinhq :: F1 +foreign import ccall unsafe "_coshq" c_coshq :: F1 +foreign import ccall unsafe "_tanhq" c_tanhq :: F1 +foreign import ccall unsafe "_asinhq" c_asinhq :: F1 +foreign import ccall unsafe "_acoshq" c_acoshq :: F1 +foreign import ccall unsafe "_atanhq" c_atanhq :: F1 +foreign import ccall unsafe "_log1pq" c_log1pq :: F1 +foreign import ccall unsafe "_expm1q" c_expm1q :: F1 + +foreign import ccall unsafe "_roundq" c_roundq :: F1 +foreign import ccall unsafe "_truncq" c_truncq :: F1 +foreign import ccall unsafe "_floorq" c_floorq :: F1 +foreign import ccall unsafe "_ceilq" c_ceilq :: F1 + +foreign import ccall unsafe "_isnanq" c_isnanq :: TST +foreign import ccall unsafe "_isinfq" c_isinfq :: TST +foreign import ccall unsafe "_isnegzeroq" c_isnegzeroq :: TST +foreign import ccall unsafe "_isdenormq" c_isdenormq :: TST +foreign import ccall unsafe "_frexpq" c_frexpq :: Ptr Float128 -> Ptr Float128 -> Ptr Int32 -> IO () +foreign import ccall unsafe "_ldexpq" c_ldexpq :: Ptr Float128 -> Ptr Float128 -> Int32 -> IO () +foreign import ccall unsafe "_atan2q" c_atan2q :: F2 + +foreign import ccall unsafe "_ltq" c_ltq :: CMP +foreign import ccall unsafe "_ltq" c_gtq :: CMP +foreign import ccall unsafe "_ltq" c_leq :: CMP +foreign import ccall unsafe "_ltq" c_geq :: CMP +foreign import ccall unsafe "_fminq" c_fminq :: F2 +foreign import ccall unsafe "_fmaxq" c_fmaxq :: F2 + +-- foreign import ccall unsafe "_readq" c_readq :: Ptr Float128 -> Ptr CChar -> IO () +-- foreign import ccall unsafe "_showq" c_showq :: Ptr CChar -> CSize -> Ptr Float128 -> IO () + +-- SEE: [linking to .c files] +-- +runQ $ do + addForeignFilePath LangC "cbits/float128.c" + return [] + +#else + +c_addq :: F2 +c_addq = not_enabled + +c_subq :: F2 +c_subq = not_enabled + +c_mulq :: F2 +c_mulq = not_enabled + +c_absq :: F1 +c_absq = not_enabled + +c_negateq :: F1 +c_negateq = not_enabled + +c_signumq :: F1 +c_signumq = not_enabled + +c_divq :: F2 +c_divq = not_enabled + +c_recipq :: F1 +c_recipq = not_enabled + +c_piq :: Ptr Float128 -> IO () +c_piq = not_enabled + +c_expq :: F1 +c_expq = not_enabled + +c_logq :: F1 +c_logq = not_enabled + +c_sqrtq :: F1 +c_sqrtq = not_enabled + +c_powq :: F2 +c_powq = not_enabled + +c_sinq :: F1 +c_sinq = not_enabled + +c_cosq :: F1 +c_cosq = not_enabled + +c_tanq :: F1 +c_tanq = not_enabled + +c_asinq :: F1 +c_asinq = not_enabled + +c_acosq :: F1 +c_acosq = not_enabled + +c_atanq :: F1 +c_atanq = not_enabled + +c_sinhq :: F1 +c_sinhq = not_enabled + +c_coshq :: F1 +c_coshq = not_enabled + +c_tanhq :: F1 +c_tanhq = not_enabled + +c_asinhq :: F1 +c_asinhq = not_enabled + +c_acoshq :: F1 +c_acoshq = not_enabled + +c_atanhq :: F1 +c_atanhq = not_enabled + +c_log1pq :: F1 +c_log1pq = not_enabled + +c_expm1q :: F1 +c_expm1q = not_enabled + +c_roundq :: F1 +c_roundq = not_enabled + +c_truncq :: F1 +c_truncq = not_enabled + +c_floorq :: F1 +c_floorq = not_enabled + +c_ceilq :: F1 +c_ceilq = not_enabled + +c_isnanq :: TST +c_isnanq = not_enabled + +c_isinfq :: TST +c_isinfq = not_enabled + +c_isnegzeroq :: TST +c_isnegzeroq = not_enabled + +c_isdenormq :: TST +c_isdenormq = not_enabled + +c_frexpq :: Ptr Float128 -> Ptr Float128 -> Ptr Int32 -> IO () +c_frexpq = not_enabled + +c_ldexpq :: Ptr Float128 -> Ptr Float128 -> Int32 -> IO () +c_ldexpq = not_enabled + +c_atan2q :: F2 +c_atan2q = not_enabled + +c_ltq :: CMP +c_ltq = not_enabled + +c_gtq :: CMP +c_gtq = not_enabled + +c_leq :: CMP +c_leq = not_enabled + +c_geq :: CMP +c_geq = not_enabled + +c_fminq :: F2 +c_fminq = not_enabled + +c_fmaxq :: F2 +c_fmaxq = not_enabled + +-- c_readq :: Ptr Float128 -> Ptr CChar -> IO () +-- c_readq = not_enabled +-- +-- c_showq :: Ptr CChar -> CSize -> Ptr Float128 -> IO () +-- c_showq = not_enabled + +not_enabled :: a +not_enabled = error $ + unlines [ "128-bit floating point numbers are not enabled." + , "Reinstall package 'accelerate' with '-ffloat128' to enable them." + ] +#endif + From dff0ba70495c8e7d9e3fd24632dcb30baf849be6 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Tue, 7 Jun 2022 15:43:27 +0200 Subject: [PATCH 15/86] add basic types for single bits, bit vectors --- accelerate.cabal | 1 + src/Data/Primitive/Bit.hs | 244 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 245 insertions(+) create mode 100644 src/Data/Primitive/Bit.hs diff --git a/accelerate.cabal b/accelerate.cabal index e9bdc317d..f51f51829 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -441,6 +441,7 @@ library -- Other Crypto.Hash.XKCP Data.BitSet + Data.Primitive.Bit Data.Primitive.Vec Data.Numeric.Float128 diff --git a/src/Data/Primitive/Bit.hs b/src/Data/Primitive/Bit.hs new file mode 100644 index 000000000..f21b7edc0 --- /dev/null +++ b/src/Data/Primitive/Bit.hs @@ -0,0 +1,244 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UnboxedTuples #-} +-- | +-- Module : Data.Primitive.Bit +-- Copyright : [2008..2022] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Primitive.Bit ( + + Bit(..), + BitMask(..), extract, insert, zeros, ones, + +) where + +import Data.Array.Accelerate.Error + +import Data.Bits +import Data.Typeable +import Control.Monad.ST +import Control.Exception +import qualified Foreign.Storable as Foreign + +import Data.Primitive.ByteArray +import Data.Primitive.Types +import Data.Primitive.Vec ( Vec(..) ) + +import GHC.Base ( isTrue# ) +import GHC.Generics +import GHC.Int +import GHC.Prim +import GHC.TypeLits +import GHC.Types ( IO(..) ) +import GHC.Word +import qualified GHC.Exts as GHC + + +-- | A newtype wrapper over 'Bool' whose instances pack bits as efficiently +-- as possible (8 values per byte). Arrays of 'Bit' use 8x less memory than +-- arrays of 'Bool' (which stores one value per byte). However, (parallel) +-- random writes are slower. +-- +newtype Bit = Bit { unBit :: Bool } + deriving (Eq, Ord, Bounded, Enum, FiniteBits, Bits, Typeable, Generic) + +instance Show Bit where + showsPrec _ (Bit False) = showString "0" + showsPrec _ (Bit True) = showString "1" + +instance Read Bit where + readsPrec p = \case + ' ':rest -> readsPrec p rest + '0':rest -> [(Bit False, rest)] + '1':rest -> [(Bit True, rest)] + _ -> [] + +instance Num Bit where + Bit a * Bit b = Bit (a && b) + Bit a + Bit b = Bit (a /= b) + Bit a - Bit b = Bit (a /= b) + negate = id + abs = id + signum = id + fromInteger = Bit . odd + +instance Real Bit where + toRational = fromIntegral + +instance Integral Bit where + quotRem _ (Bit False) = throw DivideByZero + quotRem x (Bit True) = (x, Bit False) + toInteger (Bit False) = 0 + toInteger (Bit True) = 1 + + +-- | A SIMD vector of 'Bit's +-- +newtype BitMask n = BitMask { unMask :: Vec n Bit } + deriving Eq + +instance KnownNat n => Show (BitMask n) where + show = bin . toList + where + bin :: [Bit] -> String + bin bs = '0':'b': go bs + -- + go [] = [] + go (Bit True :rest) = '1' : go rest + go (Bit False:rest) = '0' : go rest + + +instance KnownNat n => GHC.IsList (BitMask n) where + type Item (BitMask n) = Bit + {-# INLINE toList #-} + {-# INLINE fromList #-} + toList = toList + fromList = fromList + +instance KnownNat n => Foreign.Storable (BitMask n) where + {-# INLINE sizeOf #-} + {-# INLINE alignment #-} + {-# INLINE peek #-} + {-# INLINE poke #-} + + alignment _ = 1 + sizeOf _ = + let k = fromIntegral (natVal' (proxy# :: Proxy# n)) + (q,r) = quotRem k 8 + in if r == 0 + then q + else q+1 + + peek (Ptr addr#) = + IO $ \s0 -> + case Foreign.sizeOf (undefined :: BitMask n) of { I# bytes# -> + case newByteArray# bytes# s0 of { (# s1, mba# #) -> + case copyAddrToByteArray# addr# mba# 0# bytes# s1 of { s2 -> + case unsafeFreezeByteArray# mba# s2 of { (# s3, ba# #) -> + (# s3, BitMask (Vec ba#) #) + }}}} + + poke (Ptr addr#) (BitMask (Vec ba#)) = + IO $ \s0 -> + case Foreign.sizeOf (undefined :: BitMask n) of { I# bytes# -> + case copyByteArrayToAddr# ba# 0# addr# bytes# s0 of { + s1 -> (# s1, () #) + }} + +{-# INLINE toList #-} +toList :: forall n. KnownNat n => BitMask n -> [Bit] +toList (BitMask (Vec ba#)) = go 0# + where + !(I# n#) = fromInteger (natVal' (proxy# :: Proxy# n)) + + go :: Int# -> [Bit] + go i# + | isTrue# (i# <# n#) = + let !(# q#, r# #) = quotRemInt# i# 8# + w# = indexWord8Array# ba# q# + b# = testBitWord8# w# r# + in + Bit (isTrue# b#) : go (i# +# 1#) + | otherwise = [] + +{-# INLINE fromList #-} +fromList :: forall n. KnownNat n => [Bit] -> BitMask n +fromList bits = case byteArrayFromListN bytes (pack bits) of + ByteArray ba# -> BitMask (Vec ba#) + where + bytes = Foreign.sizeOf (undefined :: BitMask n) + + pack :: [Bit] -> [Word8] + pack xs = + let (h,t) = splitAt 8 xs + w = w8 0 0 h + in if null t + then [w] + else w : pack t + + w8 :: Int -> Word8 -> [Bit] -> Word8 + w8 !_ !w [] = w + w8 !i !w (Bit True :bs) = w8 (i+1) (setBit w i) bs + w8 !i !w (Bit False:bs) = w8 (i+1) w bs + +{-# INLINE extract #-} +extract :: forall n. KnownNat n => BitMask n -> Int -> Bit +extract (BitMask (Vec ba#)) i@(I# i#) = + let n = fromInteger (natVal' (proxy# :: Proxy# n)) + !(# q#, r# #) = quotRemInt# i# 8# + w# = indexWord8Array# ba# q# + b# = testBitWord8# w# r# + in + boundsCheck "out of range" (i >= 0 && i < n) (Bit (isTrue# b#)) + +{-# INLINE insert #-} +insert :: forall n. KnownNat n => BitMask n -> Int -> Bit -> BitMask n +insert (BitMask (Vec ba#)) i (Bit b) = runST $ do + let n = fromInteger (natVal' (proxy# :: Proxy# n)) + (u,v) = quotRem i 8 + (q,r) = quotRem n 8 + bytes = if r == 0 + then q + else q + 1 + -- + mba <- newByteArray n + copyByteArray mba 0 (ByteArray ba#) 0 bytes + x :: Word8 <- readByteArray mba u + writeByteArray mba u $ if b then setBit x v + else clearBit x v + ByteArray ba'# <- unsafeFreezeByteArray mba + return (BitMask (Vec ba'#)) + +{-# INLINE zeros #-} +zeros :: forall n. KnownNat n => BitMask n +zeros = + let n = fromInteger (natVal' (proxy# :: Proxy# n)) + (q,r) = quotRem n 8 + l = if r == 0 + then q + else q + 1 + in + case byteArrayFromListN l (replicate l (0 :: Word8)) of + ByteArray ba# -> BitMask (Vec ba#) + +{-# INLINE ones #-} +ones :: forall n. KnownNat n => BitMask n +ones = + let n = fromInteger (natVal' (proxy# :: Proxy# n)) + (q,r) = quotRem n 8 + l = if r == 0 + then q + else q + 1 + in + case byteArrayFromListN l (replicate l (0xff :: Word8)) of + ByteArray ba# -> BitMask (Vec ba#) + + +#if __GLASGOW_HASKELL__ < 902 +testBitWord8# :: Word# -> Int# -> Int# +testBitWord8# x# i# = (x# `and#` bitWord8# i#) `neWord#` 0## + +bitWord8# :: Int# -> Word# +bitWord8# i# = narrow8Word# (1## `uncheckedShiftL#` i#) + +#else +testBitWord8# :: Word8# -> Int# -> Int# +testBitWord8# x# i# = (x# `andWord8#` bitWord8# i#) `neWord8#` (wordToWord8# 0##) + +bitWord8# :: Int# -> Word8# +bitWord8# i# = (wordToWord8# 1##) `uncheckedShiftLWord8#` i# +#endif + From 69c36a07ec25b1b1c8ba0daac82a968661a860fb Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Fri, 10 Jun 2022 23:54:37 +0200 Subject: [PATCH 16/86] fix to/fromList for BitMask --- src/Data/Primitive/Bit.hs | 41 ++++++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/src/Data/Primitive/Bit.hs b/src/Data/Primitive/Bit.hs index f21b7edc0..6c71e5d6b 100644 --- a/src/Data/Primitive/Bit.hs +++ b/src/Data/Primitive/Bit.hs @@ -21,7 +21,9 @@ module Data.Primitive.Bit ( Bit(..), - BitMask(..), extract, insert, zeros, ones, + BitMask(..), + toList, fromList, + extract, insert, zeros, ones, ) where @@ -89,6 +91,8 @@ instance Integral Bit where -- newtype BitMask n = BitMask { unMask :: Vec n Bit } deriving Eq + -- XXX: We should mask off the unused bits before testing for equality, + -- otherwise we are including junk in the test. TLM 2022-06-07 instance KnownNat n => Show (BitMask n) where show = bin . toList @@ -140,39 +144,44 @@ instance KnownNat n => Foreign.Storable (BitMask n) where {-# INLINE toList #-} toList :: forall n. KnownNat n => BitMask n -> [Bit] -toList (BitMask (Vec ba#)) = go 0# +toList (BitMask (Vec ba#)) = concat (unpack 0# []) where !(I# n#) = fromInteger (natVal' (proxy# :: Proxy# n)) - go :: Int# -> [Bit] - go i# + unpack :: Int# -> [[Bit]] -> [[Bit]] + unpack i# acc | isTrue# (i# <# n#) = - let !(# q#, r# #) = quotRemInt# i# 8# - w# = indexWord8Array# ba# q# - b# = testBitWord8# w# r# + let q# = quotInt# i# 8# + w# = indexWord8Array# ba# q# + lim# = minInt# 8# (n# -# i#) + w8 j# = if isTrue# (j# <# lim#) + then let b# = testBitWord8# w# (7# -# j#) + in Bit (isTrue# b#) : w8 (j# +# 1#) + else [] in - Bit (isTrue# b#) : go (i# +# 1#) - | otherwise = [] + unpack (i# +# 8#) (w8 0# : acc) + | otherwise = acc {-# INLINE fromList #-} fromList :: forall n. KnownNat n => [Bit] -> BitMask n -fromList bits = case byteArrayFromListN bytes (pack bits) of +fromList bits = case byteArrayFromListN bytes (pack bits') of ByteArray ba# -> BitMask (Vec ba#) where + bits' = take (fromInteger (natVal' (proxy# :: Proxy# n))) bits bytes = Foreign.sizeOf (undefined :: BitMask n) pack :: [Bit] -> [Word8] pack xs = let (h,t) = splitAt 8 xs - w = w8 0 0 h + w = w8 7 0 h in if null t then [w] else w : pack t w8 :: Int -> Word8 -> [Bit] -> Word8 w8 !_ !w [] = w - w8 !i !w (Bit True :bs) = w8 (i+1) (setBit w i) bs - w8 !i !w (Bit False:bs) = w8 (i+1) w bs + w8 !i !w (Bit True :bs) = w8 (i-1) (setBit w i) bs + w8 !i !w (Bit False:bs) = w8 (i-1) w bs {-# INLINE extract #-} extract :: forall n. KnownNat n => BitMask n -> Int -> Bit @@ -227,6 +236,12 @@ ones = ByteArray ba# -> BitMask (Vec ba#) +minInt# :: Int# -> Int# -> Int# +minInt# a# b# = + case a# <# b# of + 0# -> b# + _ -> a# + #if __GLASGOW_HASKELL__ < 902 testBitWord8# :: Word# -> Int# -> Int# testBitWord8# x# i# = (x# `and#` bitWord8# i#) `neWord#` 0## From ac1aa94c41038aaf5fcb533149908528ca127f92 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Sat, 11 Jun 2022 02:31:38 +0200 Subject: [PATCH 17/86] add support for computation on SIMD types Previously we could represent SIMD types (e.g. for complex numbers) but calculations directly in this representation were not possible. User defined (sum and product) types can also be stored in SIMD format through the (generic derivable) SIMD class. Nested SIMD types are not supported. New type classes VOrd and VEq which provide operations from Eq and Ord that compute their result per-lane. Types of member functions of several (otherwise standard H98) classes are changed in order to support SIMD types, in particular Bits. --- accelerate.cabal | 20 +- icebox/Interpreter.hs | 480 ++++++ icebox/Vec.hs | 99 ++ src/Data/Array/Accelerate.hs | 27 +- src/Data/Array/Accelerate/AST.hs | 372 +++-- src/Data/Array/Accelerate/AST/Idx.hs | 11 +- src/Data/Array/Accelerate/Analysis/Hash.hs | 150 +- src/Data/Array/Accelerate/Analysis/Match.hs | 280 ++-- src/Data/Array/Accelerate/Array/Data.hs | 647 +++++--- .../Array/Accelerate/Array/Remote/Class.hs | 4 +- src/Data/Array/Accelerate/Classes/Bounded.hs | 157 +- src/Data/Array/Accelerate/Classes/Enum.hs | 194 +-- src/Data/Array/Accelerate/Classes/Eq.hs | 96 +- src/Data/Array/Accelerate/Classes/Floating.hs | 169 +- .../Array/Accelerate/Classes/Fractional.hs | 68 +- .../Array/Accelerate/Classes/FromIntegral.hs | 94 +- src/Data/Array/Accelerate/Classes/Integral.hs | 198 +-- src/Data/Array/Accelerate/Classes/Num.hs | 282 +--- src/Data/Array/Accelerate/Classes/Ord.hs | 147 +- src/Data/Array/Accelerate/Classes/Rational.hs | 74 +- src/Data/Array/Accelerate/Classes/Real.hs | 4 +- .../Array/Accelerate/Classes/RealFloat.hs | 168 +- .../Accelerate/Classes/RealFloat.hs-boot | 67 - src/Data/Array/Accelerate/Classes/RealFrac.hs | 249 ++- .../Array/Accelerate/Classes/RealFrac.hs-boot | 24 - .../Array/Accelerate/Classes/ToFloating.hs | 74 +- src/Data/Array/Accelerate/Classes/VEq.hs | 211 +++ src/Data/Array/Accelerate/Classes/VEq.hs-boot | 26 + src/Data/Array/Accelerate/Classes/VOrd.hs | 166 ++ .../Array/Accelerate/Classes/VOrd.hs-boot | 41 + src/Data/Array/Accelerate/Classes/Vector.hs | 7 +- src/Data/Array/Accelerate/Data/Bits.hs | 723 ++------- src/Data/Array/Accelerate/Data/Complex.hs | 204 ++- src/Data/Array/Accelerate/Data/Either.hs | 7 +- src/Data/Array/Accelerate/Data/Maybe.hs | 11 +- src/Data/Array/Accelerate/Data/Ratio.hs | 11 +- src/Data/Array/Accelerate/Error.hs | 2 +- src/Data/Array/Accelerate/Interpreter.hs | 1429 +++++++---------- .../Accelerate/Interpreter/Arithmetic.hs | 591 +++++++ src/Data/Array/Accelerate/Language.hs | 53 +- src/Data/Array/Accelerate/Lift.hs | 6 +- src/Data/Array/Accelerate/Orphans.hs | 2 - src/Data/Array/Accelerate/Pattern.hs | 114 +- src/Data/Array/Accelerate/Pattern/Bool.hs | 59 +- src/Data/Array/Accelerate/Pattern/TH.hs | 14 +- src/Data/Array/Accelerate/Prelude.hs | 48 +- src/Data/Array/Accelerate/Pretty/Graphviz.hs | 11 +- src/Data/Array/Accelerate/Pretty/Print.hs | 78 +- .../Array/Accelerate/Representation/Array.hs | 55 +- .../Array/Accelerate/Representation/Elt.hs | 204 ++- .../Array/Accelerate/Representation/Shape.hs | 39 +- .../Array/Accelerate/Representation/Slice.hs | 17 +- .../Accelerate/Representation/Stencil.hs | 11 +- .../Array/Accelerate/Representation/Tag.hs | 40 +- .../Array/Accelerate/Representation/Type.hs | 52 +- .../Array/Accelerate/Representation/Vec.hs | 162 +- src/Data/Array/Accelerate/Smart.hs | 562 +++++-- src/Data/Array/Accelerate/Sugar/Array.hs | 6 +- src/Data/Array/Accelerate/Sugar/Elt.hs | 49 +- src/Data/Array/Accelerate/Sugar/Shape.hs | 14 +- src/Data/Array/Accelerate/Sugar/Vec.hs | 396 ++++- src/Data/Array/Accelerate/Test/NoFib/Base.hs | 31 +- .../Accelerate/Test/NoFib/Prelude/Map.hs | 24 +- .../Accelerate/Test/NoFib/Prelude/SIMD.hs | 61 +- .../Accelerate/Test/NoFib/Prelude/Stencil.hs | 2 +- .../Accelerate/Test/NoFib/Prelude/ZipWith.hs | 80 +- .../Test/NoFib/Spectral/RadixSort.hs | 8 +- src/Data/Array/Accelerate/Test/Similar.hs | 22 +- src/Data/Array/Accelerate/Trafo/Algebra.hs | 37 +- src/Data/Array/Accelerate/Trafo/Delayed.hs | 3 +- src/Data/Array/Accelerate/Trafo/Fusion.hs | 34 +- src/Data/Array/Accelerate/Trafo/Sharing.hs | 51 +- src/Data/Array/Accelerate/Trafo/Shrink.hs | 35 +- src/Data/Array/Accelerate/Trafo/Simplify.hs | 194 ++- .../Array/Accelerate/Trafo/Substitution.hs | 22 +- src/Data/Array/Accelerate/Type.hs | 641 ++++---- src/Data/Primitive/Vec.hs | 171 +- src/GHC/TypeLits/Extra.hs | 31 + stack-9.2.yaml | 9 +- 79 files changed, 6400 insertions(+), 4632 deletions(-) create mode 100644 icebox/Interpreter.hs create mode 100644 icebox/Vec.hs delete mode 100644 src/Data/Array/Accelerate/Classes/RealFloat.hs-boot delete mode 100644 src/Data/Array/Accelerate/Classes/RealFrac.hs-boot create mode 100644 src/Data/Array/Accelerate/Classes/VEq.hs create mode 100644 src/Data/Array/Accelerate/Classes/VEq.hs-boot create mode 100644 src/Data/Array/Accelerate/Classes/VOrd.hs create mode 100644 src/Data/Array/Accelerate/Classes/VOrd.hs-boot create mode 100644 src/Data/Array/Accelerate/Interpreter/Arithmetic.hs create mode 100644 src/GHC/TypeLits/Extra.hs diff --git a/accelerate.cabal b/accelerate.cabal index f51f51829..f120c4a64 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -370,6 +370,11 @@ library , unique , unordered-containers >= 0.2 , vector >= 0.10 + , wide-word >= 0.1 + + if impl(ghc < 9.0) + build-depends: + integer-gmp exposed-modules: -- The core language and reference implementation @@ -398,14 +403,15 @@ library Data.Array.Accelerate.Analysis.Hash Data.Array.Accelerate.Analysis.Match Data.Array.Accelerate.Array.Data - Data.Array.Accelerate.Array.Remote - Data.Array.Accelerate.Array.Remote.Class - Data.Array.Accelerate.Array.Remote.LRU - Data.Array.Accelerate.Array.Remote.Table + -- Data.Array.Accelerate.Array.Remote + -- Data.Array.Accelerate.Array.Remote.Class + -- Data.Array.Accelerate.Array.Remote.LRU + -- Data.Array.Accelerate.Array.Remote.Table Data.Array.Accelerate.Array.Unique Data.Array.Accelerate.Async - Data.Array.Accelerate.Error Data.Array.Accelerate.Debug.Internal + Data.Array.Accelerate.Error + Data.Array.Accelerate.Interpreter.Arithmetic Data.Array.Accelerate.Lifetime Data.Array.Accelerate.Pretty Data.Array.Accelerate.Representation.Array @@ -462,7 +468,8 @@ library Data.Array.Accelerate.Classes.RealFloat Data.Array.Accelerate.Classes.RealFrac Data.Array.Accelerate.Classes.ToFloating - Data.Array.Accelerate.Classes.Vector + Data.Array.Accelerate.Classes.VEq + Data.Array.Accelerate.Classes.VOrd Data.Array.Accelerate.Debug.Internal.Clock Data.Array.Accelerate.Debug.Internal.Flags Data.Array.Accelerate.Debug.Internal.Graph @@ -498,6 +505,7 @@ library Data.Array.Accelerate.Test.NoFib.Config Language.Haskell.TH.Extra + GHC.TypeLits.Extra if flag(nofib) build-depends: diff --git a/icebox/Interpreter.hs b/icebox/Interpreter.hs new file mode 100644 index 000000000..b3733425d --- /dev/null +++ b/icebox/Interpreter.hs @@ -0,0 +1,480 @@ +-- | +-- Module : Data.Array.Accelerate.Interpreter +-- Description : Reference backend (interpreted) +-- Copyright : [2008..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + + +-- | Stream a lazily read list of input arrays through the given program, +-- collecting results as we go +-- +streamOut :: Arrays a => Sugar.Seq [a] -> [a] +streamOut seq = let seq' = convertSeqWith config seq + in evalDelayedSeq defaultSeqConfig seq' + + +toSeqOp :: forall slix sl dim co e proxy. (Elt slix, Shape sl, Shape dim, Elt e) + => SliceIndex (EltRepr slix) + (EltRepr sl) + co + (EltRepr dim) + -> proxy slix + -> Array dim e + -> [Array sl e] +toSeqOp sliceIndex _ arr = map (sliceOp sliceIndex arr :: slix -> Array sl e) + (enumSlices sliceIndex (shape arr)) + +-- Sequence evaluation +-- --------------- + +-- Position in sequence. +-- +type SeqPos = Int + +-- Configuration for sequence evaluation. +-- +data SeqConfig = SeqConfig + { chunkSize :: Int -- Allocation limit for a sequence in + -- words. Actual runtime allocation should be the + -- maximum of this size and the size of the + -- largest element in the sequence. + } + +-- Default sequence evaluation configuration for testing purposes. +-- +defaultSeqConfig :: SeqConfig +defaultSeqConfig = SeqConfig { chunkSize = 2 } + +type Chunk a = Vector' a + +-- The empty chunk. O(1). +emptyChunk :: Arrays a => Chunk a +emptyChunk = empty' + +-- Number of arrays in chunk. O(1). +-- +clen :: Arrays a => Chunk a -> Int +clen = length' + +elemsPerChunk :: SeqConfig -> Int -> Int +elemsPerChunk conf n + | n < 1 = chunkSize conf + | otherwise = + let (a,b) = chunkSize conf `quotRem` n + in a + signum b + +-- Drop a number of arrays from a chunk. O(1). Note: Require keeping a +-- scan of element sizes. +-- +cdrop :: Arrays a => Int -> Chunk a -> Chunk a +cdrop = drop' dropOp (fst . offsetsOp) + +-- Get all the shapes of a chunk of arrays. O(1). +-- +chunkShapes :: Chunk (Array sh a) -> Vector sh +chunkShapes = shapes' + +-- Get all the elements of a chunk of arrays. O(1). +-- +chunkElems :: Chunk (Array sh a) -> Vector a +chunkElems = elements' + +-- Convert a vector to a chunk of scalars. +-- +vec2Chunk :: Elt e => Vector e -> Chunk (Scalar e) +vec2Chunk = vec2Vec' + +-- Convert a list of arrays to a chunk. +-- +fromListChunk :: Arrays a => [a] -> Vector' a +fromListChunk = fromList' concatOp + +-- Convert a chunk to a list of arrays. +-- +toListChunk :: Arrays a => Vector' a -> [a] +toListChunk = toList' fetchAllOp + +-- fmap for Chunk. O(n). +-- TODO: Use vectorised function. +mapChunk :: (Arrays a, Arrays b) + => (a -> b) + -> Chunk a -> Chunk b +mapChunk f c = fromListChunk $ map f (toListChunk c) + +-- zipWith for Chunk. O(n). +-- TODO: Use vectorised function. +zipWithChunk :: (Arrays a, Arrays b, Arrays c) + => (a -> b -> c) + -> Chunk a -> Chunk b -> Chunk c +zipWithChunk f c1 c2 = fromListChunk $ zipWith f (toListChunk c1) (toListChunk c2) + +-- A window on a sequence. +-- +data Window a = Window + { chunk :: Chunk a -- Current allocated chunk. + , wpos :: SeqPos -- Position of the window on the sequence, given + -- in number of elements. + } + +-- The initial empty window. +-- +window0 :: Arrays a => Window a +window0 = Window { chunk = emptyChunk, wpos = 0 } + +-- Index the given window by the given index on the sequence. +-- +(!#) :: Arrays a => Window a -> SeqPos -> Chunk a +w !# i + | j <- i - wpos w + , j >= 0 + = cdrop j (chunk w) + -- + | otherwise + = error $ "Window indexed before position. wpos = " ++ show (wpos w) ++ " i = " ++ show i + +-- Move the give window by supplying the next chunk. +-- +moveWin :: Arrays a => Window a -> Chunk a -> Window a +moveWin w c = w { chunk = c + , wpos = wpos w + clen (chunk w) + } + +-- A cursor on a sequence. +-- +data Cursor senv a = Cursor + { ref :: Idx senv a -- Reference to the sequence. + , cpos :: SeqPos -- Position of the cursor on the sequence, + -- given in number of elements. + } + +-- Initial cursor. +-- +cursor0 :: Idx senv a -> Cursor senv a +cursor0 x = Cursor { ref = x, cpos = 0 } + +-- Advance cursor by a relative amount. +-- +moveCursor :: Int -> Cursor senv a -> Cursor senv a +moveCursor k c = c { cpos = cpos c + k } + +-- Valuation for an environment of sequence windows. +-- +data Val' senv where + Empty' :: Val' () + Push' :: Val' senv -> Window t -> Val' (senv, t) + +-- Projection of a window from a window valuation using a de Bruijn +-- index. +-- +prj' :: Idx senv t -> Val' senv -> Window t +prj' ZeroIdx (Push' _ v) = v +prj' (SuccIdx idx) (Push' val _) = prj' idx val + +-- Projection of a chunk from a window valuation using a sequence +-- cursor. +-- +prjChunk :: Arrays a => Cursor senv a -> Val' senv -> Chunk a +prjChunk c senv = prj' (ref c) senv !# cpos c + +-- An executable sequence. +-- +data ExecSeq senv arrs where + ExecP :: Arrays a => Window a -> ExecP senv a -> ExecSeq (senv, a) arrs -> ExecSeq senv arrs + ExecC :: Arrays a => ExecC senv a -> ExecSeq senv a + ExecR :: Arrays a => Cursor senv a -> ExecSeq senv [a] + +-- An executable producer. +-- +data ExecP senv a where + ExecStreamIn :: Int + -> [a] + -> ExecP senv a + + ExecMap :: Arrays a + => (Chunk a -> Chunk b) + -> Cursor senv a + -> ExecP senv b + + ExecZipWith :: (Arrays a, Arrays b) + => (Chunk a -> Chunk b -> Chunk c) + -> Cursor senv a + -> Cursor senv b + -> ExecP senv c + + -- Stream scan skeleton. + ExecScan :: Arrays a + => (s -> Chunk a -> (Chunk r, s)) -- Chunk scanner. + -> s -- Accumulator (internal state). + -> Cursor senv a -- Input stream. + -> ExecP senv r + +-- An executable consumer. +-- +data ExecC senv a where + + -- Stream reduction skeleton. + ExecFold :: Arrays a + => (s -> Chunk a -> s) -- Chunk consumer function. + -> (s -> r) -- Finalizer function. + -> s -- Accumulator (internal state). + -> Cursor senv a -- Input stream. + -> ExecC senv r + + ExecStuple :: IsAtuple a + => Atuple (ExecC senv) (TupleRepr a) + -> ExecC senv a + +minCursor :: ExecSeq senv a -> SeqPos +minCursor s = travS s 0 + where + travS :: ExecSeq senv a -> Int -> SeqPos + travS s i = + case s of + ExecP _ p s' -> travP p i `min` travS s' (i+1) + ExecC c -> travC c i + ExecR _ -> maxBound + + k :: Cursor senv a -> Int -> SeqPos + k c i + | i == idxToInt (ref c) = cpos c + | otherwise = maxBound + + travP :: ExecP senv a -> Int -> SeqPos + travP p i = + case p of + ExecStreamIn _ _ -> maxBound + ExecMap _ c -> k c i + ExecZipWith _ c1 c2 -> k c1 i `min` k c2 i + ExecScan _ _ c -> k c i + + travT :: Atuple (ExecC senv) t -> Int -> SeqPos + travT NilAtup _ = maxBound + travT (SnocAtup t c) i = travT t i `min` travC c i + + travC :: ExecC senv a -> Int -> SeqPos + travC c i = + case c of + ExecFold _ _ _ cu -> k cu i + ExecStuple t -> travT t i + + +evalDelayedSeq + :: SeqConfig + -> DelayedSeq arrs + -> arrs +evalDelayedSeq cfg (DelayedSeq aenv s) | aenv' <- evalExtend aenv Empty + = evalSeq cfg s aenv' + +evalSeq :: forall aenv arrs. + SeqConfig + -> PreOpenSeq DelayedOpenAcc aenv () arrs + -> Val aenv -> arrs +evalSeq conf s aenv = evalSeq' s + where + evalSeq' :: PreOpenSeq DelayedOpenAcc aenv senv arrs -> arrs + evalSeq' (Producer _ s) = evalSeq' s + evalSeq' (Consumer _) = loop (initSeq aenv s) + evalSeq' (Reify _) = reify (initSeq aenv s) + + -- Initialize the producers and the accumulators of the consumers + -- with the given array enviroment. + initSeq :: forall senv arrs'. + Val aenv + -> PreOpenSeq DelayedOpenAcc aenv senv arrs' + -> ExecSeq senv arrs' + initSeq aenv s = + case s of + Producer p s' -> ExecP window0 (initProducer p) (initSeq aenv s') + Consumer c -> ExecC (initConsumer c) + Reify ix -> ExecR (cursor0 ix) + + -- Generate a list from the sequence. + reify :: forall arrs. ExecSeq () [arrs] + -> [arrs] + reify s = case step s Empty' of + (Just s', a) -> a ++ reify s' + (Nothing, a) -> a + + -- Iterate the given sequence until it terminates. + -- A sequence only terminates when one of the producers are exhausted. + loop :: Arrays arrs + => ExecSeq () arrs + -> arrs + loop s = + case step' s of + (Nothing, arrs) -> arrs + (Just s', _) -> loop s' + + where + step' :: ExecSeq () arrs -> (Maybe (ExecSeq () arrs), arrs) + step' s = step s Empty' + + -- One iteration of a sequence. + step :: forall senv arrs'. + ExecSeq senv arrs' + -> Val' senv + -> (Maybe (ExecSeq senv arrs'), arrs') + step s senv = + case s of + ExecP w p s' -> + let (c, mp') = produce p senv + finished = 0 == clen (w !# minCursor s') + w' = if finished then moveWin w c else w + (ms'', a) = step s' (senv `Push'` w') + in case ms'' of + Nothing -> (Nothing, a) + Just s'' | finished + , Just p' <- mp' + -> (Just (ExecP w' p' s''), a) + | not finished + -> (Just (ExecP w' p s''), a) + | otherwise + -> (Nothing, a) + ExecC c -> let (c', acc) = consume c senv + in (Just (ExecC c'), acc) + ExecR ix -> let c = prjChunk ix senv in (Just (ExecR (moveCursor (clen c) ix)), toListChunk c) + + evalA :: DelayedOpenAcc aenv a -> a + evalA acc = evalOpenAcc acc aenv + + evalAF :: DelayedOpenAfun aenv f -> f + evalAF f = evalOpenAfun f aenv + + evalE :: DelayedExp aenv t -> t + evalE exp = evalExp exp aenv + + evalF :: DelayedFun aenv f -> f + evalF fun = evalFun fun aenv + + initProducer :: forall a senv. + Producer DelayedOpenAcc aenv senv a + -> ExecP senv a + initProducer p = + case p of + StreamIn arrs -> ExecStreamIn 1 arrs + ToSeq sliceIndex slix (delayed -> Delayed sh ix _) -> + let n = R.size (R.sliceShape sliceIndex (fromElt sh)) + k = elemsPerChunk conf n + in ExecStreamIn k (toSeqOp sliceIndex slix (fromFunction sh ix)) + MapSeq f x -> ExecMap (mapChunk (evalAF f)) (cursor0 x) + ChunkedMapSeq f x -> ExecMap (evalAF f) (cursor0 x) + ZipWithSeq f x y -> ExecZipWith (zipWithChunk (evalAF f)) (cursor0 x) (cursor0 y) + ScanSeq f e x -> ExecScan scanner (evalE e) (cursor0 x) + where + scanner a c = + let v0 = chunkElems c + (v1, a') = scanl'Op (evalF f) a (delayArray v0) + in (vec2Chunk v1, fromScalar a') + + initConsumer :: forall a senv. + Consumer DelayedOpenAcc aenv senv a + -> ExecC senv a + initConsumer c = + case c of + FoldSeq f e x -> + let f' = evalF f + a0 = fromFunction (Z :. chunkSize conf) (const (evalE e)) + consumer v c = zipWith'Op f' (delayArray v) (delayArray (chunkElems c)) + finalizer = fold1Op f' . delayArray + in ExecFold consumer finalizer a0 (cursor0 x) + FoldSeqFlatten f acc x -> + let f' = evalAF f + a0 = evalA acc + consumer a c = f' a (chunkShapes c) (chunkElems c) + in ExecFold consumer id a0 (cursor0 x) + Stuple t -> + let initTup :: Atuple (Consumer DelayedOpenAcc aenv senv) t -> Atuple (ExecC senv) t + initTup NilAtup = NilAtup + initTup (SnocAtup t c) = SnocAtup (initTup t) (initConsumer c) + in ExecStuple (initTup t) + + delayed :: DelayedOpenAcc aenv (Array sh e) -> Delayed (Array sh e) + delayed AST.Manifest{} = $internalError "evalOpenAcc" "expected delayed array" + delayed AST.Delayed{..} = Delayed (evalExp extentD aenv) + (evalFun indexD aenv) + (evalFun linearIndexD aenv) + +produce :: Arrays a => ExecP senv a -> Val' senv -> (Chunk a, Maybe (ExecP senv a)) +produce p senv = + case p of + ExecStreamIn k xs -> + let (xs', xs'') = (take k xs, drop k xs) + c = fromListChunk xs' + mp = if null xs'' + then Nothing + else Just (ExecStreamIn k xs'') + in (c, mp) + ExecMap f x -> + let c = prjChunk x senv + in (f c, Just $ ExecMap f (moveCursor (clen c) x)) + ExecZipWith f x y -> + let c1 = prjChunk x senv + c2 = prjChunk y senv + k = clen c1 `min` clen c2 + in (f c1 c2, Just $ ExecZipWith f (moveCursor k x) (moveCursor k y)) + ExecScan scanner a x -> + let c = prjChunk x senv + (c', a') = scanner a c + k = clen c + in (c', Just $ ExecScan scanner a' (moveCursor k x)) + +consume :: forall senv a. ExecC senv a -> Val' senv -> (ExecC senv a, a) +consume c senv = + case c of + ExecFold f g acc x -> + let c = prjChunk x senv + acc' = f acc c + -- Even though we call g here, lazy evaluation should guarantee it is + -- only ever called once. + in (ExecFold f g acc' (moveCursor (clen c) x), g acc') + ExecStuple t -> + let consT :: Atuple (ExecC senv) t -> (Atuple (ExecC senv) t, t) + consT NilAtup = (NilAtup, ()) + consT (SnocAtup t c) | (c', acc) <- consume c senv + , (t', acc') <- consT t + = (SnocAtup t' c', (acc', acc)) + (t', acc) = consT t + in (ExecStuple t', toAtuple acc) + +evalExtend :: Extend DelayedOpenAcc aenv aenv' -> Val aenv -> Val aenv' +evalExtend BaseEnv aenv = aenv +evalExtend (PushEnv ext1 ext2) aenv | aenv' <- evalExtend ext1 aenv + = Push aenv' (evalOpenAcc ext2 aenv') + +delayArray :: Array sh e -> Delayed (Array sh e) +delayArray arr@(Array _ adata) = Delayed (shape arr) (arr!) (toElt . unsafeIndexArrayData adata) + +fromScalar :: Scalar a -> a +fromScalar = (!Z) + +concatOp :: forall e. Elt e => [Vector e] -> Vector e +concatOp = concatVectors + +fetchAllOp :: (Shape sh, Elt e) => Segments sh -> Vector e -> [Array sh e] +fetchAllOp segs elts + | (offsets, n) <- offsetsOp segs + , (n ! Z) <= size (shape elts) + = [fetch (segs ! (Z :. i)) (offsets ! (Z :. i)) | i <- [0 .. size (shape segs) - 1]] + | otherwise = error $ "illegal argument to fetchAllOp" + where + fetch sh offset = fromFunction sh (\ ix -> elts ! (Z :. ((toIndex sh ix) + offset))) + +dropOp :: Elt e => Int -> Vector e -> Vector e +dropOp i v -- TODO + -- * Implement using C-style pointer-plus. + -- ; dropOp is used often (from prjChunk), + -- so it ought to be efficient O(1). + | n <- size (shape v) + , i <= n + , i >= 0 + = fromFunction (Z :. n - i) (\ (Z :. j) -> v ! (Z :. i + j)) + | otherwise = error $ "illegal argument to drop" + +offsetsOp :: Shape sh => Segments sh -> (Vector Int, Scalar Int) +offsetsOp segs = scanl'Op (+) 0 $ delayArray (mapOp size (delayArray segs)) + diff --git a/icebox/Vec.hs b/icebox/Vec.hs new file mode 100644 index 000000000..5a324ed8e --- /dev/null +++ b/icebox/Vec.hs @@ -0,0 +1,99 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_HADDOCK hide #-} +-- | +-- Module : Data.Array.Accelerate.Representation.Vec +-- Copyright : [2008..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Representation.Vec + where + +import Data.Array.Accelerate.Type +import Data.Array.Accelerate.Representation.Type +import qualified Data.Primitive.Vec as Prim + +import Control.Monad.ST +import Data.Primitive.ByteArray +import Data.Primitive.Types +import Language.Haskell.TH.Extra + +import GHC.Base ( Int(..), Int#, (-#) ) +import GHC.TypeNats + + +-- | Declares the size of a SIMD vector and the type of its elements. This +-- data type is used to denote the relation between a vector type (Vec +-- n single) with its tuple representation (tuple). Conversions between +-- those types are exposed through 'pack' and 'unpack'. +-- +data VecR (n :: Nat) single tuple where + VecRnil :: SingleType s -> VecR 0 s () + VecRsucc :: VecR n s t -> VecR (n + 1) s (t, s) + + +vecRvector :: KnownNat n => VecR n s tuple -> VectorType (Vec n s) +vecRvector = uncurry VectorType . go + where + go :: VecR n s tuple -> (Int, SingleType s) + go (VecRnil tp) = (0, tp) + go (VecRsucc vec) | (n, tp) <- go vec = (n + 1, tp) + +vecRSingle :: KnownNat n => VecR n s tuple -> SingleType s +vecRSingle vecr = let (VectorType _ s) = vecRvector vecr in s + +vecRtuple :: VecR n s tuple -> TypeR tuple +vecRtuple = snd . go + where + go :: VecR n s tuple -> (SingleType s, TypeR tuple) + go (VecRnil tp) = (tp, TupRunit) + go (VecRsucc vec) | (tp, tuple) <- go vec = (tp, TupRpair tuple (TupRsingle (SingleScalarType tp))) + +pack :: forall n single tuple. KnownNat n => VecR n single tuple -> tuple -> Vec n single +pack vecR tuple + | VectorType n single <- vecRvector vecR + , SingleDict <- singleDict single + = runST $ do + mba <- newByteArray (n * sizeOf (undefined :: single)) + go (n - 1) vecR tuple mba + ByteArray ba# <- unsafeFreezeByteArray mba + return $! Vec ba# + where + go :: Prim single => Int -> VecR n' single tuple' -> tuple' -> MutableByteArray s -> ST s () + go _ (VecRnil _) () _ = return () + go i (VecRsucc r) (xs, x) mba = do + writeByteArray mba i x + go (i - 1) r xs mba + +unpack :: forall n single tuple. KnownNat n => VecR n single tuple -> Vec n single -> tuple +unpack vecR (Vec ba#) + | VectorType n single <- vecRvector vecR + , (I# n#) <- n + , SingleDict <- singleDict single + = go (n# -# 1#) vecR + where + go :: Prim single => Int# -> VecR n' single tuple' -> tuple' + go _ (VecRnil _) = () + go i# (VecRsucc r) = x `seq` xs `seq` (xs, x) + where + xs = go (i# -# 1#) r + x = indexByteArray# ba# i# + +rnfVecR :: VecR n single tuple -> () +rnfVecR (VecRnil tp) = rnfSingleType tp +rnfVecR (VecRsucc vec) = rnfVecR vec + +liftVecR :: VecR n single tuple -> CodeQ (VecR n single tuple) +liftVecR (VecRnil tp) = [|| VecRnil $$(liftSingleType tp) ||] +liftVecR (VecRsucc vec) = [|| VecRsucc $$(liftVecR vec) ||] + diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index e2543c6ae..22afa2a6b 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -309,15 +309,12 @@ module Data.Array.Accelerate ( Exp, -- ** SIMD vectors - Vec, VecElt, - Vectoring(..), - vecOfList, - listOfVec, + Vec, SIMD, -- ** Type classes -- *** Basic type classes - Eq(..), - Ord(..), Ordering(..), pattern LT_, pattern EQ_, pattern GT_, + Eq(..), VEq(..), + Ord(..), VOrd(..), Ordering(..), pattern LT_, pattern EQ_, pattern GT_, Enum, succ, pred, Bounded, minBound, maxBound, @@ -356,11 +353,11 @@ module Data.Array.Accelerate ( pattern I0, pattern I1, pattern I2, pattern I3, pattern I4, pattern I5, pattern I6, pattern I7, pattern I8, pattern I9, - pattern Vec2, pattern V2, - pattern Vec3, pattern V3, - pattern Vec4, pattern V4, - pattern Vec8, pattern V8, - pattern Vec16, pattern V16, + Vec2, pattern V2, + Vec3, pattern V3, + Vec4, pattern V4, + Vec8, pattern V8, + Vec16, pattern V16, mkPattern, mkPatterns, @@ -430,10 +427,6 @@ module Data.Array.Accelerate ( Either(..), pattern Left_, pattern Right_, Char, - CFloat, CDouble, - CShort, CUShort, CInt, CUInt, CLong, CULong, CLLong, CULLong, - CChar, CSChar, CUChar, - ) where import Data.Array.Accelerate.Classes.Bounded @@ -449,7 +442,8 @@ import Data.Array.Accelerate.Classes.Rational import Data.Array.Accelerate.Classes.RealFloat import Data.Array.Accelerate.Classes.RealFrac import Data.Array.Accelerate.Classes.ToFloating -import Data.Array.Accelerate.Classes.Vector +import Data.Array.Accelerate.Classes.VEq +import Data.Array.Accelerate.Classes.VOrd import Data.Array.Accelerate.Data.Either import Data.Array.Accelerate.Data.Maybe import Data.Array.Accelerate.Language @@ -463,7 +457,6 @@ import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Shape hiding ( size, toIndex, fromIndex, intersect ) import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type -import Data.Primitive.Vec import qualified Data.Array.Accelerate.Sugar.Array as S import qualified Data.Array.Accelerate.Sugar.Shape as S diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 57d1a53c6..c307275ef 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -1,5 +1,6 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} @@ -88,15 +89,14 @@ module Data.Array.Accelerate.AST ( Fun, OpenFun(..), Exp, OpenExp(..), Boundary(..), - PrimConst(..), PrimFun(..), PrimBool, PrimMaybe, + BitOrMask, -- ** Extracting type information HasArraysR(..), arrayR, expType, - primConstType, primFunType, -- ** Normal-form @@ -111,7 +111,6 @@ module Data.Array.Accelerate.AST ( rnfExpVar, rnfBoundary, rnfConst, - rnfPrimConst, rnfPrimFun, -- ** Template Haskell @@ -125,7 +124,6 @@ module Data.Array.Accelerate.AST ( liftELeftHandSide, liftExpVar, liftBoundary, - liftPrimConst, liftPrimFun, liftMessage, @@ -146,7 +144,6 @@ import Data.Array.Accelerate.Representation.Slice import Data.Array.Accelerate.Representation.Stencil import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type -import Data.Array.Accelerate.Representation.Vec import Data.Array.Accelerate.Sugar.Foreign import Data.Array.Accelerate.Type import Data.Primitive.Vec @@ -161,8 +158,6 @@ import Language.Haskell.TH.Extra ( CodeQ ) import qualified Language.Haskell.TH.Extra as TH import qualified Language.Haskell.TH.Syntax as TH -import GHC.TypeLits - -- Array expressions -- ----------------- @@ -199,8 +194,8 @@ type ALeftHandSide = LeftHandSide ArrayR type ArrayVar = Var ArrayR type ArrayVars aenv = Vars ArrayR aenv --- Bool is not a primitive type -type PrimBool = TAG +type PrimBool = Bit +type PrimMask n = Vec n Bit type PrimMaybe a = (TAG, ((), a)) -- Trace messages @@ -370,18 +365,18 @@ data PreOpenAcc (acc :: Type -> Type -> Type) aenv a where -- Fold :: Fun aenv (e -> e -> e) -- combination function -> Maybe (Exp aenv e) -- default value - -> acc aenv (Array (sh, Int) e) -- folded array + -> acc aenv (Array (sh, INT) e) -- folded array -> PreOpenAcc acc aenv (Array sh e) -- Segmented fold along the innermost dimension of an array with a given -- /associative/ function -- - FoldSeg :: IntegralType i + FoldSeg :: SingleIntegralType i -> Fun aenv (e -> e -> e) -- combination function -> Maybe (Exp aenv e) -- default value - -> acc aenv (Array (sh, Int) e) -- folded array + -> acc aenv (Array (sh, INT) e) -- folded array -> acc aenv (Segments i) -- segment descriptor - -> PreOpenAcc acc aenv (Array (sh, Int) e) + -> PreOpenAcc acc aenv (Array (sh, INT) e) -- Haskell-style scan of a linear array with a given -- /associative/ function and optionally an initial element @@ -391,8 +386,8 @@ data PreOpenAcc (acc :: Type -> Type -> Type) aenv a where Scan :: Direction -> Fun aenv (e -> e -> e) -- combination function -> Maybe (Exp aenv e) -- initial value - -> acc aenv (Array (sh, Int) e) - -> PreOpenAcc acc aenv (Array (sh, Int) e) + -> acc aenv (Array (sh, INT) e) + -> PreOpenAcc acc aenv (Array (sh, INT) e) -- Like 'Scan', but produces a rightmost (in case of a left-to-right scan) -- fold value and an array with the same length as the input array (the @@ -401,8 +396,8 @@ data PreOpenAcc (acc :: Type -> Type -> Type) aenv a where Scan' :: Direction -> Fun aenv (e -> e -> e) -- combination function -> Exp aenv e -- initial value - -> acc aenv (Array (sh, Int) e) - -> PreOpenAcc acc aenv (Array (sh, Int) e, Array sh e) + -> acc aenv (Array (sh, INT) e) + -> PreOpenAcc acc aenv (Array (sh, INT) e, Array sh e) -- Generalised forward permutation is characterised by a permutation function -- that determines for each element of the source array where it should go in @@ -551,30 +546,30 @@ data OpenExp env aenv t where Nil :: OpenExp env aenv () -- SIMD vectors - VecPack :: KnownNat n - => VecR n s tup - -> OpenExp env aenv tup - -> OpenExp env aenv (Vec n s) - - VecUnpack :: KnownNat n - => VecR n s tup - -> OpenExp env aenv (Vec n s) - -> OpenExp env aenv tup - - VecIndex :: (KnownNat n, v ~ Vec n s) - => VectorType v - -> IntegralType i - -> OpenExp env aenv (Vec n s) + Extract :: ScalarType (Vec n a) + -> SingleIntegralType i + -> OpenExp env aenv (Vec n a) -> OpenExp env aenv i - -> OpenExp env aenv s + -> OpenExp env aenv a - VecWrite :: (KnownNat n, v ~ Vec n s) - => VectorType v - -> IntegralType i - -> OpenExp env aenv (Vec n s) + Insert :: ScalarType (Vec n a) + -> SingleIntegralType i + -> OpenExp env aenv (Vec n a) -> OpenExp env aenv i - -> OpenExp env aenv s - -> OpenExp env aenv (Vec n s) + -> OpenExp env aenv a + -> OpenExp env aenv (Vec n a) + + Shuffle :: ScalarType (Vec m a) + -> SingleIntegralType i + -> OpenExp env aenv (Vec n a) + -> OpenExp env aenv (Vec n a) + -> OpenExp env aenv (Vec m i) + -> OpenExp env aenv (Vec m a) + + Select :: OpenExp env aenv (PrimMask n) + -> OpenExp env aenv (Vec n a) + -> OpenExp env aenv (Vec n a) + -> OpenExp env aenv (Vec n a) -- Array indices & shapes IndexSlice :: SliceIndex slix sl co sh @@ -591,15 +586,16 @@ data OpenExp env aenv t where ToIndex :: ShapeR sh -> OpenExp env aenv sh -- shape of the array -> OpenExp env aenv sh -- index into the array - -> OpenExp env aenv Int + -> OpenExp env aenv INT FromIndex :: ShapeR sh -> OpenExp env aenv sh -- shape of the array - -> OpenExp env aenv Int -- index into linear representation + -> OpenExp env aenv INT -- index into linear representation -> OpenExp env aenv sh -- Case statement - Case :: OpenExp env aenv TAG + Case :: ScalarType TAG + -> OpenExp env aenv TAG -> [(TAG, OpenExp env aenv b)] -- list of equations -> Maybe (OpenExp env aenv b) -- default case -> OpenExp env aenv b @@ -621,9 +617,6 @@ data OpenExp env aenv t where -> t -> OpenExp env aenv t - PrimConst :: PrimConst t - -> OpenExp env aenv t - -- Primitive scalar operations PrimApp :: PrimFun (a -> r) -> OpenExp env aenv a @@ -636,7 +629,7 @@ data OpenExp env aenv t where -> OpenExp env aenv t LinearIndex :: ArrayVar aenv (Array dim t) - -> OpenExp env aenv Int + -> OpenExp env aenv INT -> OpenExp env aenv t -- Array shape. @@ -647,7 +640,7 @@ data OpenExp env aenv t where -- Number of elements of an array given its shape ShapeSize :: ShapeR dim -> OpenExp env aenv dim - -> OpenExp env aenv Int + -> OpenExp env aenv INT -- Unsafe operations (may fail or result in undefined behaviour) -- An unspecified bit pattern @@ -661,18 +654,14 @@ data OpenExp env aenv t where -> OpenExp env aenv a -> OpenExp env aenv b --- |Primitive constant values --- -data PrimConst ty where - - -- constants from Bounded - PrimMinBound :: BoundedType a -> PrimConst a - PrimMaxBound :: BoundedType a -> PrimConst a - -- constant from Floating - PrimPi :: FloatingType a -> PrimConst a +-- | A bit mask at the width of the corresponding type +-- +type family BitOrMask a where + BitOrMask (Vec n _) = PrimMask n + BitOrMask _ = PrimBool --- |Primitive scalar operations +-- | Primitive scalar operations -- data PrimFun sig where @@ -693,17 +682,17 @@ data PrimFun sig where PrimDivMod :: IntegralType a -> PrimFun ((a, a) -> (a, a)) -- operators from Bits & FiniteBits - PrimBAnd :: IntegralType a -> PrimFun ((a, a) -> a) - PrimBOr :: IntegralType a -> PrimFun ((a, a) -> a) - PrimBXor :: IntegralType a -> PrimFun ((a, a) -> a) - PrimBNot :: IntegralType a -> PrimFun (a -> a) - PrimBShiftL :: IntegralType a -> PrimFun ((a, Int) -> a) - PrimBShiftR :: IntegralType a -> PrimFun ((a, Int) -> a) - PrimBRotateL :: IntegralType a -> PrimFun ((a, Int) -> a) - PrimBRotateR :: IntegralType a -> PrimFun ((a, Int) -> a) - PrimPopCount :: IntegralType a -> PrimFun (a -> Int) - PrimCountLeadingZeros :: IntegralType a -> PrimFun (a -> Int) - PrimCountTrailingZeros :: IntegralType a -> PrimFun (a -> Int) + PrimBAnd :: IntegralType a -> PrimFun ((a, a) -> a) + PrimBOr :: IntegralType a -> PrimFun ((a, a) -> a) + PrimBXor :: IntegralType a -> PrimFun ((a, a) -> a) + PrimBNot :: IntegralType a -> PrimFun (a -> a) + PrimBShiftL :: IntegralType a -> PrimFun ((a, a) -> a) + PrimBShiftR :: IntegralType a -> PrimFun ((a, a) -> a) + PrimBRotateL :: IntegralType a -> PrimFun ((a, a) -> a) + PrimBRotateR :: IntegralType a -> PrimFun ((a, a) -> a) + PrimPopCount :: IntegralType a -> PrimFun (a -> a) + PrimCountLeadingZeros :: IntegralType a -> PrimFun (a -> a) + PrimCountTrailingZeros :: IntegralType a -> PrimFun (a -> a) -- operators from Fractional and Floating PrimFDiv :: FloatingType a -> PrimFun ((a, a) -> a) @@ -726,8 +715,6 @@ data PrimFun sig where PrimFPow :: FloatingType a -> PrimFun ((a, a) -> a) PrimLogBase :: FloatingType a -> PrimFun ((a, a) -> a) - -- FIXME: add missing operations from RealFrac & RealFloat - -- operators from RealFrac PrimTruncate :: FloatingType a -> IntegralType b -> PrimFun (a -> b) PrimRound :: FloatingType a -> IntegralType b -> PrimFun (a -> b) @@ -737,18 +724,18 @@ data PrimFun sig where -- operators from RealFloat PrimAtan2 :: FloatingType a -> PrimFun ((a, a) -> a) - PrimIsNaN :: FloatingType a -> PrimFun (a -> PrimBool) - PrimIsInfinite :: FloatingType a -> PrimFun (a -> PrimBool) + PrimIsNaN :: FloatingType a -> PrimFun (a -> BitOrMask a) + PrimIsInfinite :: FloatingType a -> PrimFun (a -> BitOrMask a) -- relational and equality operators - PrimLt :: SingleType a -> PrimFun ((a, a) -> PrimBool) - PrimGt :: SingleType a -> PrimFun ((a, a) -> PrimBool) - PrimLtEq :: SingleType a -> PrimFun ((a, a) -> PrimBool) - PrimGtEq :: SingleType a -> PrimFun ((a, a) -> PrimBool) - PrimEq :: SingleType a -> PrimFun ((a, a) -> PrimBool) - PrimNEq :: SingleType a -> PrimFun ((a, a) -> PrimBool) - PrimMax :: SingleType a -> PrimFun ((a, a) -> a) - PrimMin :: SingleType a -> PrimFun ((a, a) -> a) + PrimLt :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) + PrimGt :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) + PrimLtEq :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) + PrimGtEq :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) + PrimEq :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) + PrimNEq :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) + PrimMax :: ScalarType a -> PrimFun ((a, a) -> a) + PrimMin :: ScalarType a -> PrimFun ((a, a) -> a) -- logical operators -- @@ -760,13 +747,15 @@ data PrimFun sig where -- short-circuiting, while (&&!) and (||!) are strict versions of these -- operators, which are defined using PrimLAnd and PrimLOr. -- - PrimLAnd :: PrimFun ((PrimBool, PrimBool) -> PrimBool) - PrimLOr :: PrimFun ((PrimBool, PrimBool) -> PrimBool) - PrimLNot :: PrimFun (PrimBool -> PrimBool) + PrimLAnd :: BitType a -> PrimFun ((a, a) -> a) + PrimLOr :: BitType a -> PrimFun ((a, a) -> a) + PrimLNot :: BitType a -> PrimFun (a -> a) -- general conversion between types PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b) PrimToFloating :: NumType a -> FloatingType b -> PrimFun (a -> b) + PrimToBool :: IntegralType a -> BitType b -> PrimFun (a -> b) + PrimFromBool :: BitType a -> IntegralType b -> PrimFun (a -> b) -- Type utilities @@ -828,41 +817,49 @@ expType = \case Foreign tR _ _ _ -> tR Pair e1 e2 -> TupRpair (expType e1) (expType e2) Nil -> TupRunit - VecPack vecR _ -> TupRsingle $ VectorScalarType $ vecRvector vecR - VecUnpack vecR _ -> vecRtuple vecR - VecIndex vecT _ _ _ -> let (VectorType _ s) = vecT in TupRsingle $ SingleScalarType s - VecWrite vecT _ _ _ _ -> TupRsingle $ VectorScalarType vecT + Extract tR _ _ _ -> TupRsingle (scalar tR) + where + scalar :: ScalarType (Vec n a) -> ScalarType a + scalar (NumScalarType t) = NumScalarType (num t) + scalar (BitScalarType t) = BitScalarType (bit t) + + bit :: BitType (Vec n a) -> BitType a + bit TypeMask{} = TypeBit + + num :: NumType (Vec n a) -> NumType a + num (IntegralNumType t) = IntegralNumType (integral t) + num (FloatingNumType t) = FloatingNumType (floating t) + + integral :: IntegralType (Vec n a) -> IntegralType a + integral (SingleIntegralType t) = case t of + integral (VectorIntegralType _ t) = SingleIntegralType t + + floating :: FloatingType (Vec n a) -> FloatingType a + floating (SingleFloatingType t) = case t of + floating (VectorFloatingType _ t) = SingleFloatingType t + -- + Insert t _ _ _ _ -> TupRsingle t + Shuffle t _ _ _ _ -> TupRsingle t + Select _ x _ -> expType x IndexSlice si _ _ -> shapeType $ sliceShapeR si IndexFull si _ _ -> shapeType $ sliceDomainR si - ToIndex{} -> TupRsingle scalarTypeInt - FromIndex shr _ _ -> shapeType shr - Case _ ((_,e):_) _ -> expType e - Case _ [] (Just e) -> expType e + ToIndex{} -> TupRsingle scalarType + FromIndex shR _ _ -> shapeType shR + Case _ _ ((_,e):_) _ -> expType e + Case _ _ [] (Just e) -> expType e Case{} -> internalError "empty case encountered" Cond _ e _ -> expType e While _ (Lam lhs _) _ -> lhsToTupR lhs While{} -> error "What's the matter, you're running in the shadows" Const tR _ -> TupRsingle tR - PrimConst c -> TupRsingle $ primConstType c PrimApp f _ -> snd $ primFunType f Index (Var repr _) _ -> arrayRtype repr LinearIndex (Var repr _) _ -> arrayRtype repr Shape (Var repr _) -> shapeType $ arrayRshape repr - ShapeSize{} -> TupRsingle scalarTypeInt + ShapeSize{} -> TupRsingle (scalarType @INT) Undef tR -> TupRsingle tR Coerce _ tR _ -> TupRsingle tR -primConstType :: PrimConst a -> ScalarType a -primConstType = \case - PrimMinBound t -> bounded t - PrimMaxBound t -> bounded t - PrimPi t -> floating t - where - bounded :: BoundedType a -> ScalarType a - bounded (IntegralBoundedType t) = SingleScalarType $ NumSingleType $ IntegralNumType t - - floating :: FloatingType t -> ScalarType t - floating = SingleScalarType . NumSingleType . FloatingNumType primFunType :: PrimFun (a -> b) -> (TypeR a, TypeR b) primFunType = \case @@ -887,13 +884,13 @@ primFunType = \case PrimBOr t -> binary' $ integral t PrimBXor t -> binary' $ integral t PrimBNot t -> unary' $ integral t - PrimBShiftL t -> (integral t `TupRpair` tint, integral t) - PrimBShiftR t -> (integral t `TupRpair` tint, integral t) - PrimBRotateL t -> (integral t `TupRpair` tint, integral t) - PrimBRotateR t -> (integral t `TupRpair` tint, integral t) - PrimPopCount t -> unary (integral t) tint - PrimCountLeadingZeros t -> unary (integral t) tint - PrimCountTrailingZeros t -> unary (integral t) tint + PrimBShiftL t -> (integral t `TupRpair` integral t, integral t) + PrimBShiftR t -> (integral t `TupRpair` integral t, integral t) + PrimBRotateL t -> (integral t `TupRpair` integral t, integral t) + PrimBRotateR t -> (integral t `TupRpair` integral t, integral t) + PrimPopCount t -> unary (integral t) (integral t) + PrimCountLeadingZeros t -> unary (integral t) (integral t) + PrimCountTrailingZeros t -> unary (integral t) (integral t) -- Fractional, Floating PrimFDiv t -> binary' $ floating t @@ -924,8 +921,8 @@ primFunType = \case -- RealFloat PrimAtan2 t -> binary' $ floating t - PrimIsNaN t -> unary (floating t) tbool - PrimIsInfinite t -> unary (floating t) tbool + PrimIsNaN t -> unary (floating t) (floating_mask t) + PrimIsInfinite t -> unary (floating t) (floating_mask t) -- Relational and equality PrimLt t -> compare' t @@ -938,28 +935,62 @@ primFunType = \case PrimMin t -> binary' $ single t -- Logical - PrimLAnd -> binary' tbool - PrimLOr -> binary' tbool - PrimLNot -> unary' tbool + PrimLAnd t -> binary' (bit t) + PrimLOr t -> binary' (bit t) + PrimLNot t -> unary' (bit t) -- general conversion between types PrimFromIntegral a b -> unary (integral a) (num b) PrimToFloating a b -> unary (num a) (floating b) + PrimToBool a b -> unary (integral a) (bit b) + PrimFromBool a b -> unary (bit a) (integral b) where unary a b = (a, b) unary' a = unary a a binary a b = (a `TupRpair` a, b) binary' a = binary a a - compare' a = binary (single a) tbool - - single = TupRsingle . SingleScalarType - num = TupRsingle . SingleScalarType . NumSingleType - integral = num . IntegralNumType - floating = num . FloatingNumType - - tbool = TupRsingle scalarTypeWord8 - tint = TupRsingle scalarTypeInt + compare' a = binary (single a) (scalar_mask a) + + single = TupRsingle + num = single . NumScalarType + bit = single . BitScalarType + integral = num . IntegralNumType + floating = num . FloatingNumType + + scalar_mask :: ScalarType t -> TypeR (BitOrMask t) + scalar_mask (NumScalarType t) = num_mask t + scalar_mask (BitScalarType t) = bit_mask t + + bit_mask :: BitType t -> TypeR (BitOrMask t) + bit_mask TypeBit = bit TypeBit + bit_mask (TypeMask n) = bit (TypeMask n) + + num_mask :: NumType t -> TypeR (BitOrMask t) + num_mask (IntegralNumType t) = integral_mask t + num_mask (FloatingNumType t) = floating_mask t + + integral_mask :: IntegralType t -> TypeR (BitOrMask t) + integral_mask (VectorIntegralType n _) = single (BitScalarType (TypeMask n)) + integral_mask (SingleIntegralType t) = case t of + TypeInt8 -> single (scalarType @Bit) + TypeInt16 -> single (scalarType @Bit) + TypeInt32 -> single (scalarType @Bit) + TypeInt64 -> single (scalarType @Bit) + TypeInt128 -> single (scalarType @Bit) + TypeWord8 -> single (scalarType @Bit) + TypeWord16 -> single (scalarType @Bit) + TypeWord32 -> single (scalarType @Bit) + TypeWord64 -> single (scalarType @Bit) + TypeWord128 -> single (scalarType @Bit) + + floating_mask :: FloatingType t -> TypeR (BitOrMask t) + floating_mask (VectorFloatingType n _) = single (BitScalarType (TypeMask n)) + floating_mask (SingleFloatingType t) = case t of + TypeFloat16 -> single (scalarType @Bit) + TypeFloat32 -> single (scalarType @Bit) + TypeFloat64 -> single (scalarType @Bit) + TypeFloat128 -> single (scalarType @Bit) -- Normal form data @@ -1028,7 +1059,7 @@ rnfPreOpenAcc rnfA pacc = Map tp f a -> rnfTypeR tp `seq` rnfF f `seq` rnfA a ZipWith tp f a1 a2 -> rnfTypeR tp `seq` rnfF f `seq` rnfA a1 `seq` rnfA a2 Fold f z a -> rnfF f `seq` rnfMaybe rnfE z `seq` rnfA a - FoldSeg i f z a s -> rnfIntegralType i `seq` rnfF f `seq` rnfMaybe rnfE z `seq` rnfA a `seq` rnfA s + FoldSeg i f z a s -> rnfSingleIntegralType i `seq` rnfF f `seq` rnfMaybe rnfE z `seq` rnfA a `seq` rnfA s Scan d f z a -> d `seq` rnfF f `seq` rnfMaybe rnfE z `seq` rnfA a Scan' d f z a -> d `seq` rnfF f `seq` rnfE z `seq` rnfA a Permute f d p a -> rnfF f `seq` rnfA d `seq` rnfF p `seq` rnfA a @@ -1089,18 +1120,17 @@ rnfOpenExp topExp = Undef tp -> rnfScalarType tp Pair a b -> rnfE a `seq` rnfE b Nil -> () - VecPack vecr e -> rnfVecR vecr `seq` rnfE e - VecUnpack vecr e -> rnfVecR vecr `seq` rnfE e - VecIndex vt it v i -> rnfVectorType vt `seq` rnfIntegralType it `seq` rnfE v `seq` rnfE i - VecWrite vt it v i e -> rnfVectorType vt `seq` rnfIntegralType it `seq` rnfE v `seq` rnfE i `seq` rnfE e + Extract vR iR v i -> rnfScalarType vR `seq` rnfSingleIntegralType iR `seq` rnfE v `seq` rnfE i + Insert vR iR v i x -> rnfScalarType vR `seq` rnfSingleIntegralType iR `seq` rnfE v `seq` rnfE i `seq` rnfE x + Shuffle eR iR x y i -> rnfScalarType eR `seq` rnfSingleIntegralType iR `seq` rnfE x `seq` rnfE y `seq` rnfE i + Select m x y -> rnfE m `seq` rnfE x `seq` rnfE y IndexSlice slice slix sh -> rnfSliceIndex slice `seq` rnfE slix `seq` rnfE sh IndexFull slice slix sl -> rnfSliceIndex slice `seq` rnfE slix `seq` rnfE sl ToIndex shr sh ix -> rnfShapeR shr `seq` rnfE sh `seq` rnfE ix FromIndex shr sh ix -> rnfShapeR shr `seq` rnfE sh `seq` rnfE ix - Case e rhs def -> rnfE e `seq` rnfList (\(t,c) -> t `seq` rnfE c) rhs `seq` rnfMaybe rnfE def + Case pR p rhs def -> rnfScalarType pR `seq` rnfE p `seq` rnfList (\(t,c) -> t `seq` rnfE c) rhs `seq` rnfMaybe rnfE def Cond p e1 e2 -> rnfE p `seq` rnfE e1 `seq` rnfE e2 While p f x -> rnfF p `seq` rnfF f `seq` rnfE x - PrimConst c -> rnfPrimConst c PrimApp f x -> rnfPrimFun f `seq` rnfE x Index a ix -> rnfArrayVar a `seq` rnfE ix LinearIndex a ix -> rnfArrayVar a `seq` rnfE ix @@ -1119,11 +1149,6 @@ rnfConst TupRunit () = () rnfConst (TupRsingle t) !_ = rnfScalarType t -- scalars should have (nf == whnf) rnfConst (TupRpair ta tb) (a,b) = rnfConst ta a `seq` rnfConst tb b -rnfPrimConst :: PrimConst c -> () -rnfPrimConst (PrimMinBound t) = rnfBoundedType t -rnfPrimConst (PrimMaxBound t) = rnfBoundedType t -rnfPrimConst (PrimPi t) = rnfFloatingType t - rnfPrimFun :: PrimFun f -> () rnfPrimFun (PrimAdd t) = rnfNumType t rnfPrimFun (PrimSub t) = rnfNumType t @@ -1174,19 +1199,21 @@ rnfPrimFun (PrimCeiling f i) = rnfFloatingType f `seq` rnfIntegralType rnfPrimFun (PrimIsNaN t) = rnfFloatingType t rnfPrimFun (PrimIsInfinite t) = rnfFloatingType t rnfPrimFun (PrimAtan2 t) = rnfFloatingType t -rnfPrimFun (PrimLt t) = rnfSingleType t -rnfPrimFun (PrimGt t) = rnfSingleType t -rnfPrimFun (PrimLtEq t) = rnfSingleType t -rnfPrimFun (PrimGtEq t) = rnfSingleType t -rnfPrimFun (PrimEq t) = rnfSingleType t -rnfPrimFun (PrimNEq t) = rnfSingleType t -rnfPrimFun (PrimMax t) = rnfSingleType t -rnfPrimFun (PrimMin t) = rnfSingleType t -rnfPrimFun PrimLAnd = () -rnfPrimFun PrimLOr = () -rnfPrimFun PrimLNot = () +rnfPrimFun (PrimLt t) = rnfScalarType t +rnfPrimFun (PrimGt t) = rnfScalarType t +rnfPrimFun (PrimLtEq t) = rnfScalarType t +rnfPrimFun (PrimGtEq t) = rnfScalarType t +rnfPrimFun (PrimEq t) = rnfScalarType t +rnfPrimFun (PrimNEq t) = rnfScalarType t +rnfPrimFun (PrimMax t) = rnfScalarType t +rnfPrimFun (PrimMin t) = rnfScalarType t +rnfPrimFun (PrimLAnd t) = rnfBitType t +rnfPrimFun (PrimLOr t) = rnfBitType t +rnfPrimFun (PrimLNot t) = rnfBitType t rnfPrimFun (PrimFromIntegral i n) = rnfIntegralType i `seq` rnfNumType n rnfPrimFun (PrimToFloating n f) = rnfNumType n `seq` rnfFloatingType f +rnfPrimFun (PrimToBool i b) = rnfIntegralType i `seq` rnfBitType b +rnfPrimFun (PrimFromBool b i) = rnfBitType b `seq` rnfIntegralType i -- Template Haskell @@ -1238,7 +1265,7 @@ liftPreOpenAcc liftA pacc = Map tp f a -> [|| Map $$(liftTypeR tp) $$(liftF f) $$(liftA a) ||] ZipWith tp f a b -> [|| ZipWith $$(liftTypeR tp) $$(liftF f) $$(liftA a) $$(liftA b) ||] Fold f z a -> [|| Fold $$(liftF f) $$(liftMaybe liftE z) $$(liftA a) ||] - FoldSeg i f z a s -> [|| FoldSeg $$(liftIntegralType i) $$(liftF f) $$(liftMaybe liftE z) $$(liftA a) $$(liftA s) ||] + FoldSeg i f z a s -> [|| FoldSeg $$(liftSingleIntegralType i) $$(liftF f) $$(liftMaybe liftE z) $$(liftA a) $$(liftA s) ||] Scan d f z a -> [|| Scan $$(liftDirection d) $$(liftF f) $$(liftMaybe liftE z) $$(liftA a) ||] Scan' d f z a -> [|| Scan' $$(liftDirection d) $$(liftF f) $$(liftE z) $$(liftA a) ||] Permute f d p a -> [|| Permute $$(liftF f) $$(liftA d) $$(liftF p) $$(liftA a) ||] @@ -1311,18 +1338,17 @@ liftOpenExp pexp = Undef tp -> [|| Undef $$(liftScalarType tp) ||] Pair a b -> [|| Pair $$(liftE a) $$(liftE b) ||] Nil -> [|| Nil ||] - VecPack vecr e -> [|| VecPack $$(liftVecR vecr) $$(liftE e) ||] - VecUnpack vecr e -> [|| VecUnpack $$(liftVecR vecr) $$(liftE e) ||] - VecIndex vt it v i -> [|| VecIndex $$(liftVectorType vt) $$(liftIntegralType it) $$(liftE v) $$(liftE i) ||] - VecWrite vt it v i e -> [|| VecWrite $$(liftVectorType vt) $$(liftIntegralType it) $$(liftE v) $$(liftE i) $$(liftE e) ||] + Extract vR iR v i -> [|| Extract $$(liftScalarType vR) $$(liftSingleIntegralType iR) $$(liftE v) $$(liftE i) ||] + Insert vR iR v i x -> [|| Insert $$(liftScalarType vR) $$(liftSingleIntegralType iR) $$(liftE v) $$(liftE i) $$(liftE x) ||] + Shuffle eR iR x y i -> [|| Shuffle $$(liftScalarType eR) $$(liftSingleIntegralType iR) $$(liftE x) $$(liftE y) $$(liftE i) ||] + Select m x y -> [|| Select $$(liftE m) $$(liftE x) $$(liftE y) ||] IndexSlice slice slix sh -> [|| IndexSlice $$(liftSliceIndex slice) $$(liftE slix) $$(liftE sh) ||] IndexFull slice slix sl -> [|| IndexFull $$(liftSliceIndex slice) $$(liftE slix) $$(liftE sl) ||] ToIndex shr sh ix -> [|| ToIndex $$(liftShapeR shr) $$(liftE sh) $$(liftE ix) ||] FromIndex shr sh ix -> [|| FromIndex $$(liftShapeR shr) $$(liftE sh) $$(liftE ix) ||] - Case p rhs def -> [|| Case $$(liftE p) $$(liftList (\(t,c) -> [|| (t, $$(liftE c)) ||]) rhs) $$(liftMaybe liftE def) ||] + Case pR p rhs def -> [|| Case $$(liftScalarType pR) $$(liftE p) $$(liftList (\(t,c) -> [|| ($$(liftScalar pR t), $$(liftE c)) ||]) rhs) $$(liftMaybe liftE def) ||] Cond p t e -> [|| Cond $$(liftE p) $$(liftE t) $$(liftE e) ||] While p f x -> [|| While $$(liftF p) $$(liftF f) $$(liftE x) ||] - PrimConst t -> [|| PrimConst $$(liftPrimConst t) ||] PrimApp f x -> [|| PrimApp $$(liftPrimFun f) $$(liftE x) ||] Index a ix -> [|| Index $$(liftArrayVar a) $$(liftE ix) ||] LinearIndex a ix -> [|| LinearIndex $$(liftArrayVar a) $$(liftE ix) ||] @@ -1347,11 +1373,6 @@ liftBoundary _ Wrap = [|| Wrap ||] liftBoundary (ArrayR _ tp) (Constant v) = [|| Constant $$(liftElt tp v) ||] liftBoundary _ (Function f) = [|| Function $$(liftOpenFun f) ||] -liftPrimConst :: PrimConst c -> CodeQ (PrimConst c) -liftPrimConst (PrimMinBound t) = [|| PrimMinBound $$(liftBoundedType t) ||] -liftPrimConst (PrimMaxBound t) = [|| PrimMaxBound $$(liftBoundedType t) ||] -liftPrimConst (PrimPi t) = [|| PrimPi $$(liftFloatingType t) ||] - liftPrimFun :: PrimFun f -> CodeQ (PrimFun f) liftPrimFun (PrimAdd t) = [|| PrimAdd $$(liftNumType t) ||] liftPrimFun (PrimSub t) = [|| PrimSub $$(liftNumType t) ||] @@ -1402,19 +1423,21 @@ liftPrimFun (PrimCeiling ta tb) = [|| PrimCeiling $$(liftFloatingType ta) liftPrimFun (PrimIsNaN t) = [|| PrimIsNaN $$(liftFloatingType t) ||] liftPrimFun (PrimIsInfinite t) = [|| PrimIsInfinite $$(liftFloatingType t) ||] liftPrimFun (PrimAtan2 t) = [|| PrimAtan2 $$(liftFloatingType t) ||] -liftPrimFun (PrimLt t) = [|| PrimLt $$(liftSingleType t) ||] -liftPrimFun (PrimGt t) = [|| PrimGt $$(liftSingleType t) ||] -liftPrimFun (PrimLtEq t) = [|| PrimLtEq $$(liftSingleType t) ||] -liftPrimFun (PrimGtEq t) = [|| PrimGtEq $$(liftSingleType t) ||] -liftPrimFun (PrimEq t) = [|| PrimEq $$(liftSingleType t) ||] -liftPrimFun (PrimNEq t) = [|| PrimNEq $$(liftSingleType t) ||] -liftPrimFun (PrimMax t) = [|| PrimMax $$(liftSingleType t) ||] -liftPrimFun (PrimMin t) = [|| PrimMin $$(liftSingleType t) ||] -liftPrimFun PrimLAnd = [|| PrimLAnd ||] -liftPrimFun PrimLOr = [|| PrimLOr ||] -liftPrimFun PrimLNot = [|| PrimLNot ||] +liftPrimFun (PrimLt t) = [|| PrimLt $$(liftScalarType t) ||] +liftPrimFun (PrimGt t) = [|| PrimGt $$(liftScalarType t) ||] +liftPrimFun (PrimLtEq t) = [|| PrimLtEq $$(liftScalarType t) ||] +liftPrimFun (PrimGtEq t) = [|| PrimGtEq $$(liftScalarType t) ||] +liftPrimFun (PrimEq t) = [|| PrimEq $$(liftScalarType t) ||] +liftPrimFun (PrimNEq t) = [|| PrimNEq $$(liftScalarType t) ||] +liftPrimFun (PrimMax t) = [|| PrimMax $$(liftScalarType t) ||] +liftPrimFun (PrimMin t) = [|| PrimMin $$(liftScalarType t) ||] +liftPrimFun (PrimLAnd t) = [|| PrimLAnd $$(liftBitType t) ||] +liftPrimFun (PrimLOr t) = [|| PrimLOr $$(liftBitType t) ||] +liftPrimFun (PrimLNot t) = [|| PrimLNot $$(liftBitType t) ||] liftPrimFun (PrimFromIntegral ta tb) = [|| PrimFromIntegral $$(liftIntegralType ta) $$(liftNumType tb) ||] liftPrimFun (PrimToFloating ta tb) = [|| PrimToFloating $$(liftNumType ta) $$(liftFloatingType tb) ||] +liftPrimFun (PrimToBool ta tb) = [|| PrimToBool $$(liftIntegralType ta) $$(liftBitType tb) ||] +liftPrimFun (PrimFromBool ta tb) = [|| PrimFromBool $$(liftBitType ta) $$(liftIntegralType tb) ||] formatDirection :: Format r (Direction -> r) @@ -1460,10 +1483,10 @@ formatExpOp = later $ \case Foreign{} -> "Foreign" Pair{} -> "Pair" Nil{} -> "Nil" - VecPack{} -> "VecPack" - VecUnpack{} -> "VecUnpack" - VecIndex{} -> "VecIndex" - VecWrite{} -> "VecWrite" + Insert{} -> "Insert" + Extract{} -> "Extract" + Shuffle{} -> "Shuffle" + Select{} -> "Select" IndexSlice{} -> "IndexSlice" IndexFull{} -> "IndexFull" ToIndex{} -> "ToIndex" @@ -1471,7 +1494,6 @@ formatExpOp = later $ \case Case{} -> "Case" Cond{} -> "Cond" While{} -> "While" - PrimConst{} -> "PrimConst" PrimApp{} -> "PrimApp" Index{} -> "Index" LinearIndex{} -> "LinearIndex" diff --git a/src/Data/Array/Accelerate/AST/Idx.hs b/src/Data/Array/Accelerate/AST/Idx.hs index 548453e2b..0f76c7a4c 100644 --- a/src/Data/Array/Accelerate/AST/Idx.hs +++ b/src/Data/Array/Accelerate/AST/Idx.hs @@ -2,6 +2,7 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE EmptyCase #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -31,11 +32,12 @@ module Data.Array.Accelerate.AST.Idx ( ) where -import Language.Haskell.TH.Extra +import Data.Kind +import Language.Haskell.TH.Extra hiding ( Type ) #ifndef ACCELERATE_INTERNAL_CHECKS -import Data.Type.Equality ((:~:)(Refl)) -import Unsafe.Coerce (unsafeCoerce) +import Data.Type.Equality ( (:~:)(..) ) +import Unsafe.Coerce ( unsafeCoerce ) #endif @@ -72,7 +74,8 @@ liftIdx (SuccIdx ix) = [|| SuccIdx $$(liftIdx ix) ||] -- -- For performance, it uses an Int under the hood. -- -newtype Idx env t = UnsafeIdxConstructor { unsafeRunIdx :: Int } +newtype Idx :: Type -> Type -> Type where + UnsafeIdxConstructor :: { unsafeRunIdx :: Int } -> Idx env t {-# COMPLETE ZeroIdx, SuccIdx #-} diff --git a/src/Data/Array/Accelerate/Analysis/Hash.hs b/src/Data/Array/Accelerate/Analysis/Hash.hs index 9be163e5c..c97abcb77 100644 --- a/src/Data/Array/Accelerate/Analysis/Hash.hs +++ b/src/Data/Array/Accelerate/Analysis/Hash.hs @@ -49,6 +49,7 @@ import Data.Array.Accelerate.Type import Data.Primitive.Vec import Crypto.Hash.XKCP +import Foreign.C.Types import Data.ByteString.Builder import Data.ByteString.Builder.Extra import Data.ByteString.Short.Internal ( ShortByteString(..) ) @@ -58,6 +59,8 @@ import System.Mem.StableName ( hashStable import Prelude hiding ( exp ) import qualified Data.Hashable as Hashable +import GHC.TypeLits + -- Hashing -- ------- @@ -263,7 +266,7 @@ encodeArraysType :: ArraysR arrs -> Builder encodeArraysType = encodeTupR encodeArrayType encodeShapeR :: ShapeR sh -> Builder -encodeShapeR = intHost . rank +encodeShapeR = int64Host . rank encodePreOpenAfun :: forall acc aenv f. @@ -288,13 +291,13 @@ encodeBoundary encodeBoundary _ Wrap = intHost $(hashQ "Wrap") encodeBoundary _ Clamp = intHost $(hashQ "Clamp") encodeBoundary _ Mirror = intHost $(hashQ "Mirror") -encodeBoundary tp (Constant c) = intHost $(hashQ "Constant") <> encodeConst tp c +encodeBoundary tR (Constant c) = intHost $(hashQ "Constant") <> encodeConst tR c encodeBoundary _ (Function f) = intHost $(hashQ "Function") <> encodeOpenFun f encodeSliceIndex :: SliceIndex slix sl co sh -> Builder -encodeSliceIndex SliceNil = intHost $(hashQ "SliceNil") -encodeSliceIndex (SliceAll r) = intHost $(hashQ "SliceAll") <> encodeSliceIndex r -encodeSliceIndex (SliceFixed r) = intHost $(hashQ "sliceFixed") <> encodeSliceIndex r +encodeSliceIndex SliceNil = intHost $(hashQ "SliceNil") +encodeSliceIndex (SliceAll r) = intHost $(hashQ "SliceAll") <> encodeSliceIndex r +encodeSliceIndex (SliceFixed r) = intHost $(hashQ "sliceFixed") <> encodeSliceIndex r -- Scalar expressions @@ -318,21 +321,20 @@ encodeOpenExp exp = Evar (Var tp ix) -> intHost $(hashQ "Evar") <> encodeScalarType tp <> encodeIdx ix Nil -> intHost $(hashQ "Nil") Pair e1 e2 -> intHost $(hashQ "Pair") <> travE e1 <> travE e2 - VecPack _ e -> intHost $(hashQ "VecPack") <> travE e - VecUnpack _ e -> intHost $(hashQ "VecUnpack") <> travE e - VecIndex _ _ v i -> intHost $(hashQ "VecIndex") <> travE v <> travE i - VecWrite _ _ v i e -> intHost $(hashQ "VecWrite") <> travE v <> travE i <> travE e + Extract vR iR v i -> intHost $(hashQ "Extract") <> encodeScalarType vR <> encodeSingleIntegralType iR <> travE v <> travE i + Insert vR iR v i x -> intHost $(hashQ "Insert") <> encodeScalarType vR <> encodeSingleIntegralType iR <> travE v <> travE i <> travE x + Shuffle eR iR x y i -> intHost $(hashQ "Shuffle") <> encodeScalarType eR <> encodeSingleIntegralType iR <> travE x <> travE y <> travE i + Select m x y -> intHost $(hashQ "Select") <> travE m <>travE x <> travE y Const tp c -> intHost $(hashQ "Const") <> encodeScalarConst tp c Undef tp -> intHost $(hashQ "Undef") <> encodeScalarType tp IndexSlice spec ix sh -> intHost $(hashQ "IndexSlice") <> travE ix <> travE sh <> encodeSliceIndex spec IndexFull spec ix sl -> intHost $(hashQ "IndexFull") <> travE ix <> travE sl <> encodeSliceIndex spec ToIndex _ sh i -> intHost $(hashQ "ToIndex") <> travE sh <> travE i FromIndex _ sh i -> intHost $(hashQ "FromIndex") <> travE sh <> travE i - Case e rhs def -> intHost $(hashQ "Case") <> travE e <> mconcat [ word8 t <> travE c | (t,c) <- rhs ] <> encodeMaybe travE def + Case eR e rhs def -> intHost $(hashQ "Case") <> encodeScalarType eR <> travE e <> mconcat [ encodeScalarConst eR t <> travE c | (t,c) <- rhs ] <> encodeMaybe travE def Cond c t e -> intHost $(hashQ "Cond") <> travE c <> travE t <> travE e While p f x -> intHost $(hashQ "While") <> travF p <> travF f <> travE x PrimApp f x -> intHost $(hashQ "PrimApp") <> encodePrimFun f <> travE x - PrimConst c -> intHost $(hashQ "PrimConst") <> encodePrimConst c Index a ix -> intHost $(hashQ "Index") <> encodeArrayVar a <> travE ix LinearIndex a ix -> intHost $(hashQ "LinearIndex") <> encodeArrayVar a <> travE ix Shape a -> intHost $(hashQ "Shape") <> encodeArrayVar a @@ -357,40 +359,43 @@ encodeConst (TupRsingle t) c = encodeScalarConst t c encodeConst (TupRpair ta tb) (a,b) = intHost $(hashQ "pair") <> encodeConst ta a <> encodeConst tb b encodeScalarConst :: ScalarType t -> t -> Builder -encodeScalarConst (SingleScalarType t) = encodeSingleConst t -encodeScalarConst (VectorScalarType t) = encodeVectorConst t - -encodeSingleConst :: SingleType t -> t -> Builder -encodeSingleConst (NumSingleType t) = encodeNumConst t +encodeScalarConst (NumScalarType t) = encodeNumConst t +encodeScalarConst (BitScalarType t) = encodeBitConst t -encodeVectorConst :: VectorType (Vec n t) -> Vec n t -> Builder -encodeVectorConst (VectorType n t) (Vec ba#) = intHost $(hashQ "Vec") <> intHost n <> encodeSingleType t <> shortByteString (SBS ba#) +encodeBitConst :: BitType t -> t -> Builder +encodeBitConst TypeBit (Bit False) = intHost $(hashQ "Bit") <> int8 0 +encodeBitConst TypeBit (Bit True) = intHost $(hashQ "Bit") <> int8 1 +encodeBitConst (TypeMask n) (Vec ba#) = intHost $(hashQ "BitMask") <> int8 (fromIntegral (natVal' n)) <> shortByteString (SBS ba#) encodeNumConst :: NumType t -> t -> Builder encodeNumConst (IntegralNumType t) = encodeIntegralConst t encodeNumConst (FloatingNumType t) = encodeFloatingConst t encodeIntegralConst :: IntegralType t -> t -> Builder -encodeIntegralConst TypeInt{} x = intHost $(hashQ "Int") <> intHost x -encodeIntegralConst TypeInt8{} x = intHost $(hashQ "Int8") <> int8 x -encodeIntegralConst TypeInt16{} x = intHost $(hashQ "Int16") <> int16Host x -encodeIntegralConst TypeInt32{} x = intHost $(hashQ "Int32") <> int32Host x -encodeIntegralConst TypeInt64{} x = intHost $(hashQ "Int64") <> int64Host x -encodeIntegralConst TypeWord{} x = intHost $(hashQ "Word") <> wordHost x -encodeIntegralConst TypeWord8{} x = intHost $(hashQ "Word8") <> word8 x -encodeIntegralConst TypeWord16{} x = intHost $(hashQ "Word16") <> word16Host x -encodeIntegralConst TypeWord32{} x = intHost $(hashQ "Word32") <> word32Host x -encodeIntegralConst TypeWord64{} x = intHost $(hashQ "Word64") <> word64Host x +encodeIntegralConst (SingleIntegralType t) x = encodeSingleIntegralConst t x +encodeIntegralConst (VectorIntegralType n t) (Vec ba#) = intHost $(hashQ "Vec") <> int8 (fromIntegral (natVal' n)) <> encodeSingleIntegralType t <> shortByteString (SBS ba#) + +encodeSingleIntegralConst :: SingleIntegralType t -> t -> Builder +encodeSingleIntegralConst TypeInt8 x = intHost $(hashQ "Int8") <> int8 x +encodeSingleIntegralConst TypeInt16 x = intHost $(hashQ "Int16") <> int16Host x +encodeSingleIntegralConst TypeInt32 x = intHost $(hashQ "Int32") <> int32Host x +encodeSingleIntegralConst TypeInt64 x = intHost $(hashQ "Int64") <> int64Host x +encodeSingleIntegralConst TypeInt128 (Int128 x y) = intHost $(hashQ "Int128") <> word64Host x <> word64Host y +encodeSingleIntegralConst TypeWord8 x = intHost $(hashQ "Word8") <> word8 x +encodeSingleIntegralConst TypeWord16 x = intHost $(hashQ "Word16") <> word16Host x +encodeSingleIntegralConst TypeWord32 x = intHost $(hashQ "Word32") <> word32Host x +encodeSingleIntegralConst TypeWord64 x = intHost $(hashQ "Word64") <> word64Host x +encodeSingleIntegralConst TypeWord128 (Word128 x y) = intHost $(hashQ "Word128") <> word64Host x <> word64Host y encodeFloatingConst :: FloatingType t -> t -> Builder -encodeFloatingConst TypeHalf{} (Half (CUShort x)) = intHost $(hashQ "Half") <> word16Host x -encodeFloatingConst TypeFloat{} x = intHost $(hashQ "Float") <> floatHost x -encodeFloatingConst TypeDouble{} x = intHost $(hashQ "Double") <> doubleHost x +encodeFloatingConst (SingleFloatingType t) x = encodeSingleFloatingConst t x +encodeFloatingConst (VectorFloatingType n t) (Vec ba#) = intHost $(hashQ "Vec") <> int8 (fromIntegral (natVal' n)) <> encodeSingleFloatingType t <> shortByteString (SBS ba#) -encodePrimConst :: PrimConst c -> Builder -encodePrimConst (PrimMinBound t) = intHost $(hashQ "PrimMinBound") <> encodeBoundedType t -encodePrimConst (PrimMaxBound t) = intHost $(hashQ "PrimMaxBound") <> encodeBoundedType t -encodePrimConst (PrimPi t) = intHost $(hashQ "PrimPi") <> encodeFloatingType t +encodeSingleFloatingConst :: SingleFloatingType t -> t -> Builder +encodeSingleFloatingConst TypeFloat16 (Half (CUShort x)) = intHost $(hashQ "Half") <> word16Host x +encodeSingleFloatingConst TypeFloat32 x = intHost $(hashQ "Float") <> floatHost x +encodeSingleFloatingConst TypeFloat64 x = intHost $(hashQ "Double") <> doubleHost x +encodeSingleFloatingConst TypeFloat128 (Float128 x y) = intHost $(hashQ "Float128") <> word64Host x <> word64Host y encodePrimFun :: PrimFun f -> Builder encodePrimFun (PrimAdd a) = intHost $(hashQ "PrimAdd") <> encodeNumType a @@ -442,20 +447,21 @@ encodePrimFun (PrimFloor a b) = intHost $(hashQ "PrimFloor") encodePrimFun (PrimCeiling a b) = intHost $(hashQ "PrimCeiling") <> encodeFloatingType a <> encodeIntegralType b encodePrimFun (PrimIsNaN a) = intHost $(hashQ "PrimIsNaN") <> encodeFloatingType a encodePrimFun (PrimIsInfinite a) = intHost $(hashQ "PrimIsInfinite") <> encodeFloatingType a -encodePrimFun (PrimLt a) = intHost $(hashQ "PrimLt") <> encodeSingleType a -encodePrimFun (PrimGt a) = intHost $(hashQ "PrimGt") <> encodeSingleType a -encodePrimFun (PrimLtEq a) = intHost $(hashQ "PrimLtEq") <> encodeSingleType a -encodePrimFun (PrimGtEq a) = intHost $(hashQ "PrimGtEq") <> encodeSingleType a -encodePrimFun (PrimEq a) = intHost $(hashQ "PrimEq") <> encodeSingleType a -encodePrimFun (PrimNEq a) = intHost $(hashQ "PrimNEq") <> encodeSingleType a -encodePrimFun (PrimMax a) = intHost $(hashQ "PrimMax") <> encodeSingleType a -encodePrimFun (PrimMin a) = intHost $(hashQ "PrimMin") <> encodeSingleType a +encodePrimFun (PrimLt a) = intHost $(hashQ "PrimLt") <> encodeScalarType a +encodePrimFun (PrimGt a) = intHost $(hashQ "PrimGt") <> encodeScalarType a +encodePrimFun (PrimLtEq a) = intHost $(hashQ "PrimLtEq") <> encodeScalarType a +encodePrimFun (PrimGtEq a) = intHost $(hashQ "PrimGtEq") <> encodeScalarType a +encodePrimFun (PrimEq a) = intHost $(hashQ "PrimEq") <> encodeScalarType a +encodePrimFun (PrimNEq a) = intHost $(hashQ "PrimNEq") <> encodeScalarType a +encodePrimFun (PrimMax a) = intHost $(hashQ "PrimMax") <> encodeScalarType a +encodePrimFun (PrimMin a) = intHost $(hashQ "PrimMin") <> encodeScalarType a +encodePrimFun (PrimLAnd a) = intHost $(hashQ "PrimLAnd") <> encodeBitType a +encodePrimFun (PrimLOr a) = intHost $(hashQ "PrimLOr") <> encodeBitType a +encodePrimFun (PrimLNot a) = intHost $(hashQ "PrimLNot") <> encodeBitType a encodePrimFun (PrimFromIntegral a b) = intHost $(hashQ "PrimFromIntegral") <> encodeIntegralType a <> encodeNumType b encodePrimFun (PrimToFloating a b) = intHost $(hashQ "PrimToFloating") <> encodeNumType a <> encodeFloatingType b -encodePrimFun PrimLAnd = intHost $(hashQ "PrimLAnd") -encodePrimFun PrimLOr = intHost $(hashQ "PrimLOr") -encodePrimFun PrimLNot = intHost $(hashQ "PrimLNot") - +encodePrimFun (PrimToBool a b) = intHost $(hashQ "PrimToBool") <> encodeIntegralType a <> encodeBitType b +encodePrimFun (PrimFromBool a b) = intHost $(hashQ "PrimFromBool") <> encodeBitType a <> encodeIntegralType b encodeTypeR :: TypeR t -> Builder encodeTypeR TupRunit = intHost $(hashQ "TupRunit") @@ -469,38 +475,42 @@ depthTypeR TupRsingle{} = 1 depthTypeR (TupRpair a b) = depthTypeR a + depthTypeR b encodeScalarType :: ScalarType t -> Builder -encodeScalarType (SingleScalarType t) = intHost $(hashQ "SingleScalarType") <> encodeSingleType t -encodeScalarType (VectorScalarType t) = intHost $(hashQ "VectorScalarType") <> encodeVectorType t - -encodeSingleType :: SingleType t -> Builder -encodeSingleType (NumSingleType t) = intHost $(hashQ "NumSingleType") <> encodeNumType t - -encodeVectorType :: VectorType (Vec n t) -> Builder -encodeVectorType (VectorType n t) = intHost $(hashQ "VectorType") <> intHost n <> encodeSingleType t +encodeScalarType (NumScalarType t) = encodeNumType t +encodeScalarType (BitScalarType t) = encodeBitType t -encodeBoundedType :: BoundedType t -> Builder -encodeBoundedType (IntegralBoundedType t) = intHost $(hashQ "IntegralBoundedType") <> encodeIntegralType t +encodeBitType :: BitType t -> Builder +encodeBitType TypeBit = intHost $(hashQ "Bit") +encodeBitType (TypeMask n) = intHost $(hashQ "BitMask") <> int8 (fromIntegral (natVal' n)) encodeNumType :: NumType t -> Builder encodeNumType (IntegralNumType t) = intHost $(hashQ "IntegralNumType") <> encodeIntegralType t encodeNumType (FloatingNumType t) = intHost $(hashQ "FloatingNumType") <> encodeFloatingType t encodeIntegralType :: IntegralType t -> Builder -encodeIntegralType TypeInt{} = intHost $(hashQ "Int") -encodeIntegralType TypeInt8{} = intHost $(hashQ "Int8") -encodeIntegralType TypeInt16{} = intHost $(hashQ "Int16") -encodeIntegralType TypeInt32{} = intHost $(hashQ "Int32") -encodeIntegralType TypeInt64{} = intHost $(hashQ "Int64") -encodeIntegralType TypeWord{} = intHost $(hashQ "Word") -encodeIntegralType TypeWord8{} = intHost $(hashQ "Word8") -encodeIntegralType TypeWord16{} = intHost $(hashQ "Word16") -encodeIntegralType TypeWord32{} = intHost $(hashQ "Word32") -encodeIntegralType TypeWord64{} = intHost $(hashQ "Word64") +encodeIntegralType (SingleIntegralType t) = encodeSingleIntegralType t +encodeIntegralType (VectorIntegralType n t) = intHost $(hashQ "Vec") <> int8 (fromIntegral (natVal' n)) <> encodeSingleIntegralType t + +encodeSingleIntegralType :: SingleIntegralType t -> Builder +encodeSingleIntegralType TypeInt8 = intHost $(hashQ "Int8") +encodeSingleIntegralType TypeInt16 = intHost $(hashQ "Int16") +encodeSingleIntegralType TypeInt32 = intHost $(hashQ "Int32") +encodeSingleIntegralType TypeInt64 = intHost $(hashQ "Int64") +encodeSingleIntegralType TypeInt128 = intHost $(hashQ "Int128") +encodeSingleIntegralType TypeWord8 = intHost $(hashQ "Word8") +encodeSingleIntegralType TypeWord16 = intHost $(hashQ "Word16") +encodeSingleIntegralType TypeWord32 = intHost $(hashQ "Word32") +encodeSingleIntegralType TypeWord64 = intHost $(hashQ "Word64") +encodeSingleIntegralType TypeWord128 = intHost $(hashQ "Word128") encodeFloatingType :: FloatingType t -> Builder -encodeFloatingType TypeHalf{} = intHost $(hashQ "Half") -encodeFloatingType TypeFloat{} = intHost $(hashQ "Float") -encodeFloatingType TypeDouble{} = intHost $(hashQ "Double") +encodeFloatingType (SingleFloatingType t) = encodeSingleFloatingType t +encodeFloatingType (VectorFloatingType n t) = intHost $(hashQ "Vec") <> int8 (fromIntegral (natVal' n)) <> encodeSingleFloatingType t + +encodeSingleFloatingType :: SingleFloatingType t -> Builder +encodeSingleFloatingType TypeFloat16 = intHost $(hashQ "Half") +encodeSingleFloatingType TypeFloat32 = intHost $(hashQ "Float") +encodeSingleFloatingType TypeFloat64 = intHost $(hashQ "Double") +encodeSingleFloatingType TypeFloat128 = intHost $(hashQ "Float128") encodeMaybe :: (a -> Builder) -> Maybe a -> Builder encodeMaybe _ Nothing = intHost $(hashQ "Nothing") diff --git a/src/Data/Array/Accelerate/Analysis/Match.hs b/src/Data/Array/Accelerate/Analysis/Match.hs index 98a0818a4..c25a2620d 100644 --- a/src/Data/Array/Accelerate/Analysis/Match.hs +++ b/src/Data/Array/Accelerate/Analysis/Match.hs @@ -1,6 +1,8 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE CPP #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MagicHash #-} {-# LANGUAGE PatternGuards #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -30,8 +32,12 @@ module Data.Array.Accelerate.Analysis.Match ( -- auxiliary matchIdx, matchVar, matchVars, matchArrayR, matchArraysR, matchTypeR, matchShapeR, - matchShapeType, matchIntegralType, matchFloatingType, matchNumType, matchScalarType, - matchLeftHandSide, matchALeftHandSide, matchELeftHandSide, matchSingleType, matchTupR + matchShapeType, + matchIntegralType, matchSingleIntegralType, + matchFloatingType, matchSingleFloatingType, + matchNumType, + matchScalarType, + matchLeftHandSide, matchALeftHandSide, matchELeftHandSide, matchTupR, ) where @@ -40,21 +46,23 @@ import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.AST.LeftHandSide import Data.Array.Accelerate.AST.Var import Data.Array.Accelerate.Analysis.Hash -import Data.Array.Accelerate.Representation.Array -import Data.Array.Accelerate.Representation.Shape -import Data.Array.Accelerate.Representation.Slice +import Data.Array.Accelerate.Interpreter.Arithmetic +import Data.Array.Accelerate.Representation.Array ( Array(..), ArraysR, ArrayR(..) ) +import Data.Array.Accelerate.Representation.Shape ( ShapeR(..) ) +import Data.Array.Accelerate.Representation.Slice ( SliceIndex(..) ) import Data.Array.Accelerate.Representation.Stencil import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Type -import Data.Primitive.Vec -import qualified Data.Array.Accelerate.Sugar.Shape as Sugar +import qualified Data.Array.Accelerate.Sugar.Shape as Sugar import Data.Maybe import Data.Typeable -import Unsafe.Coerce ( unsafeCoerce ) -import System.IO.Unsafe ( unsafePerformIO ) +import Unsafe.Coerce ( unsafeCoerce ) +import System.IO.Unsafe ( unsafePerformIO ) import System.Mem.StableName -import Prelude hiding ( exp ) +import Prelude hiding ( exp ) + +import GHC.TypeLits.Extra -- The type of matching array computations @@ -455,24 +463,41 @@ matchOpenExp (Foreign _ ff1 f1 e1) (Foreign _ ff2 f2 e2) , Just Refl <- matchOpenFun f1 f2 = Just Refl -matchOpenExp (Const t1 c1) (Const t2 c2) - | Just Refl <- matchScalarType t1 t2 - , matchConst (TupRsingle t1) c1 c2 +matchOpenExp (Pair a1 b1) (Pair a2 b2) + | Just Refl <- matchOpenExp a1 a2 + , Just Refl <- matchOpenExp b1 b2 = Just Refl -matchOpenExp (Undef t1) (Undef t2) = matchScalarType t1 t2 +matchOpenExp Nil Nil + = Just Refl -matchOpenExp (Coerce _ t1 e1) (Coerce _ t2 e2) - | Just Refl <- matchScalarType t1 t2 - , Just Refl <- matchOpenExp e1 e2 +matchOpenExp (Extract vR1 iR1 v1 i1) (Extract vR2 iR2 v2 i2) + | Just Refl <- matchScalarType vR1 vR2 + , Just Refl <- matchSingleIntegralType iR1 iR2 + , Just Refl <- matchOpenExp v1 v2 + , Just Refl <- matchOpenExp i1 i2 = Just Refl -matchOpenExp (Pair a1 b1) (Pair a2 b2) - | Just Refl <- matchOpenExp a1 a2 - , Just Refl <- matchOpenExp b1 b2 +matchOpenExp (Insert vR1 iR1 v1 i1 x1) (Insert vR2 iR2 v2 i2 x2) + | Just Refl <- matchScalarType vR1 vR2 + , Just Refl <- matchSingleIntegralType iR1 iR2 + , Just Refl <- matchOpenExp v1 v2 + , Just Refl <- matchOpenExp i1 i2 + , Just Refl <- matchOpenExp x1 x2 = Just Refl -matchOpenExp Nil Nil +matchOpenExp (Shuffle eR1 iR1 x1 y1 i1) (Shuffle eR2 iR2 x2 y2 i2) + | Just Refl <- matchScalarType eR1 eR2 + , Just Refl <- matchSingleIntegralType iR1 iR2 + , Just Refl <- matchOpenExp x1 x2 + , Just Refl <- matchOpenExp y1 y2 + , Just Refl <- matchOpenExp i1 i2 + = Just Refl + +matchOpenExp (Select p1 x1 y1) (Select p2 x2 y2) + | Just Refl <- matchOpenExp p1 p2 + , Just Refl <- matchOpenExp x1 x2 + , Just Refl <- matchOpenExp y1 y2 = Just Refl matchOpenExp (IndexSlice sliceIndex1 ix1 sh1) (IndexSlice sliceIndex2 ix2 sh2) @@ -497,6 +522,29 @@ matchOpenExp (FromIndex _ sh1 i1) (FromIndex _ sh2 i2) , Just Refl <- matchOpenExp sh1 sh2 = Just Refl +matchOpenExp (Case eR1 e1 rhs1 def1) (Case eR2 e2 rhs2 def2) + | Just Refl <- matchScalarType eR1 eR2 + , Just Refl <- matchOpenExp e1 e2 + , Just Refl <- matchCaseEqs eR1 rhs1 rhs2 + , Just Refl <- matchCaseDef def1 def2 + = Just Refl + where + matchCaseEqs :: ScalarType tag -> [(tag, OpenExp env aenv a)] -> [(tag, OpenExp env aenv b)] -> Maybe (a :~: b) + matchCaseEqs _ [] [] + = unsafeCoerce Refl + matchCaseEqs tR ((s,x):xs) ((t,y):ys) + | evalEq tR (s,t) + , Just Refl <- matchOpenExp x y + , Just Refl <- matchCaseEqs tR xs ys + = Just Refl + matchCaseEqs _ _ _ + = Nothing + + matchCaseDef :: Maybe (OpenExp env aenv a) -> Maybe (OpenExp env aenv b) -> Maybe (a :~: b) + matchCaseDef Nothing Nothing = unsafeCoerce Refl + matchCaseDef (Just x) (Just y) = matchOpenExp x y + matchCaseDef _ _ = Nothing + matchOpenExp (Cond p1 t1 e1) (Cond p2 t2 e2) | Just Refl <- matchOpenExp p1 p2 , Just Refl <- matchOpenExp t1 t2 @@ -509,8 +557,10 @@ matchOpenExp (While p1 f1 x1) (While p2 f2 x2) , Just Refl <- matchOpenFun f1 f2 = Just Refl -matchOpenExp (PrimConst c1) (PrimConst c2) - = matchPrimConst c1 c2 +matchOpenExp (Const t1 c1) (Const t2 c2) + | Just Refl <- matchScalarType t1 t2 + , matchConst (TupRsingle t1) c1 c2 + = Just Refl matchOpenExp (PrimApp f1 x1) (PrimApp f2 x2) | Just x1' <- commutes f1 x1 @@ -541,6 +591,13 @@ matchOpenExp (ShapeSize _ sh1) (ShapeSize _ sh2) | Just Refl <- matchOpenExp sh1 sh2 = Just Refl +matchOpenExp (Undef t1) (Undef t2) = matchScalarType t1 t2 + +matchOpenExp (Coerce _ t1 e1) (Coerce _ t2 e2) + | Just Refl <- matchScalarType t1 t2 + , Just Refl <- matchOpenExp e1 e2 + = Just Refl + matchOpenExp _ _ = Nothing @@ -568,18 +625,41 @@ matchConst (TupRsingle ty) a b = evalEq ty (a,b) matchConst (TupRpair ta tb) (a1,b1) (a2,b2) = matchConst ta a1 a2 && matchConst tb b1 b2 evalEq :: ScalarType a -> (a, a) -> Bool -evalEq (SingleScalarType t) = evalEqSingle t -evalEq (VectorScalarType t) = evalEqVector t - -evalEqSingle :: SingleType a -> (a, a) -> Bool -evalEqSingle (NumSingleType t) = evalEqNum t - -evalEqVector :: VectorType a -> (a, a) -> Bool -evalEqVector VectorType{} = uncurry (==) - -evalEqNum :: NumType a -> (a, a) -> Bool -evalEqNum (IntegralNumType t) | IntegralDict <- integralDict t = uncurry (==) -evalEqNum (FloatingNumType t) | FloatingDict <- floatingDict t = uncurry (==) +evalEq t x = lall (scalar t) (eq t x) + where + scalar :: ScalarType s -> BitType (BitOrMask s) + scalar (NumScalarType s) = num s + scalar (BitScalarType s) = bit s + + bit :: BitType s -> BitType (BitOrMask s) + bit TypeBit = TypeBit + bit (TypeMask n) = TypeMask n + + num :: NumType s -> BitType (BitOrMask s) + num (IntegralNumType s) = integral s + num (FloatingNumType s) = floating s + + integral :: IntegralType s -> BitType (BitOrMask s) + integral (VectorIntegralType n _) = TypeMask n + integral (SingleIntegralType s) = case s of + TypeInt8 -> TypeBit + TypeInt16 -> TypeBit + TypeInt32 -> TypeBit + TypeInt64 -> TypeBit + TypeInt128 -> TypeBit + TypeWord8 -> TypeBit + TypeWord16 -> TypeBit + TypeWord32 -> TypeBit + TypeWord64 -> TypeBit + TypeWord128 -> TypeBit + + floating :: FloatingType s -> BitType (BitOrMask s) + floating (VectorFloatingType n _) = TypeMask n + floating (SingleFloatingType s) = case s of + TypeFloat16 -> TypeBit + TypeFloat32 -> TypeBit + TypeFloat64 -> TypeBit + TypeFloat128 -> TypeBit -- Environment projection indices @@ -622,14 +702,8 @@ matchSliceIndex (SliceFixed sl1) (SliceFixed sl2) matchSliceIndex _ _ = Nothing --- Primitive constants and functions --- -matchPrimConst :: PrimConst s -> PrimConst t -> Maybe (s :~: t) -matchPrimConst (PrimMinBound s) (PrimMinBound t) = matchBoundedType s t -matchPrimConst (PrimMaxBound s) (PrimMaxBound t) = matchBoundedType s t -matchPrimConst (PrimPi s) (PrimPi t) = matchFloatingType s t -matchPrimConst _ _ = Nothing - +-- Primitive functions +-- ------------------- -- Covariant function matching -- @@ -692,11 +766,13 @@ matchPrimFun (PrimEq _) (PrimEq _) = Just Refl matchPrimFun (PrimNEq _) (PrimNEq _) = Just Refl matchPrimFun (PrimMax _) (PrimMax _) = Just Refl matchPrimFun (PrimMin _) (PrimMin _) = Just Refl +matchPrimFun (PrimLAnd s) (PrimLAnd t) = matchBitType s t +matchPrimFun (PrimLOr s) (PrimLOr t) = matchBitType s t +matchPrimFun (PrimLNot s) (PrimLNot t) = matchBitType s t matchPrimFun (PrimFromIntegral _ s) (PrimFromIntegral _ t) = matchNumType s t matchPrimFun (PrimToFloating _ s) (PrimToFloating _ t) = matchFloatingType s t -matchPrimFun PrimLAnd PrimLAnd = Just Refl -matchPrimFun PrimLOr PrimLOr = Just Refl -matchPrimFun PrimLNot PrimLNot = Just Refl +matchPrimFun (PrimToBool _ s) (PrimToBool _ t) = matchBitType s t +matchPrimFun (PrimFromBool _ s) (PrimFromBool _ t) = matchIntegralType s t matchPrimFun _ _ = Nothing @@ -757,34 +833,36 @@ matchPrimFun' (PrimIsNaN s) (PrimIsNaN t) = matchFloat matchPrimFun' (PrimIsInfinite s) (PrimIsInfinite t) = matchFloatingType s t matchPrimFun' (PrimMax _) (PrimMax _) = Just Refl matchPrimFun' (PrimMin _) (PrimMin _) = Just Refl +matchPrimFun' (PrimLAnd _) (PrimLAnd _) = Just Refl +matchPrimFun' (PrimLOr _) (PrimLOr _) = Just Refl +matchPrimFun' (PrimLNot _) (PrimLNot _) = Just Refl matchPrimFun' (PrimFromIntegral s _) (PrimFromIntegral t _) = matchIntegralType s t matchPrimFun' (PrimToFloating s _) (PrimToFloating t _) = matchNumType s t -matchPrimFun' PrimLAnd PrimLAnd = Just Refl -matchPrimFun' PrimLOr PrimLOr = Just Refl -matchPrimFun' PrimLNot PrimLNot = Just Refl +matchPrimFun' (PrimToBool s _) (PrimToBool t _) = matchIntegralType s t +matchPrimFun' (PrimFromBool s _) (PrimFromBool t _) = matchBitType s t matchPrimFun' (PrimLt s) (PrimLt t) - | Just Refl <- matchSingleType s t + | Just Refl <- matchScalarType s t = Just Refl matchPrimFun' (PrimGt s) (PrimGt t) - | Just Refl <- matchSingleType s t + | Just Refl <- matchScalarType s t = Just Refl matchPrimFun' (PrimLtEq s) (PrimLtEq t) - | Just Refl <- matchSingleType s t + | Just Refl <- matchScalarType s t = Just Refl matchPrimFun' (PrimGtEq s) (PrimGtEq t) - | Just Refl <- matchSingleType s t + | Just Refl <- matchScalarType s t = Just Refl matchPrimFun' (PrimEq s) (PrimEq t) - | Just Refl <- matchSingleType s t + | Just Refl <- matchScalarType s t = Just Refl matchPrimFun' (PrimNEq s) (PrimNEq t) - | Just Refl <- matchSingleType s t + | Just Refl <- matchScalarType s t = Just Refl matchPrimFun' _ _ @@ -831,24 +909,17 @@ matchShapeR _ _ = Nothing -- {-# INLINEABLE matchScalarType #-} matchScalarType :: ScalarType s -> ScalarType t -> Maybe (s :~: t) -matchScalarType (SingleScalarType s) (SingleScalarType t) = matchSingleType s t -matchScalarType (VectorScalarType s) (VectorScalarType t) = matchVectorType s t -matchScalarType _ _ = Nothing - -{-# INLINEABLE matchSingleType #-} -matchSingleType :: SingleType s -> SingleType t -> Maybe (s :~: t) -matchSingleType (NumSingleType s) (NumSingleType t) = matchNumType s t - -{-# INLINEABLE matchVectorType #-} -matchVectorType :: forall m n s t. VectorType (Vec n s) -> VectorType (Vec m t) -> Maybe (Vec n s :~: Vec m t) -matchVectorType (VectorType n s) (VectorType m t) - | Just Refl <- if n == m - then Just (unsafeCoerce Refl :: n :~: m) -- XXX: we don't have an embedded KnownNat constraint, but - else Nothing -- this implementation is the same as 'GHC.TypeLits.sameNat' - , Just Refl <- matchSingleType s t - = Just Refl -matchVectorType _ _ - = Nothing +matchScalarType (NumScalarType s) (NumScalarType t) = matchNumType s t +matchScalarType (BitScalarType s) (BitScalarType t) = matchBitType s t +matchScalarType _ _ = Nothing + +{-# INLINEABLE matchBitType #-} +matchBitType :: BitType s -> BitType t -> Maybe (s :~: t) +matchBitType TypeBit TypeBit = Just Refl +matchBitType (TypeMask n) (TypeMask m) + | Just Refl <- sameNat' n m + = Just Refl +matchBitType _ _ = Nothing {-# INLINEABLE matchNumType #-} matchNumType :: NumType s -> NumType t -> Maybe (s :~: t) @@ -856,30 +927,45 @@ matchNumType (IntegralNumType s) (IntegralNumType t) = matchIntegralType s t matchNumType (FloatingNumType s) (FloatingNumType t) = matchFloatingType s t matchNumType _ _ = Nothing -{-# INLINEABLE matchBoundedType #-} -matchBoundedType :: BoundedType s -> BoundedType t -> Maybe (s :~: t) -matchBoundedType (IntegralBoundedType s) (IntegralBoundedType t) = matchIntegralType s t - {-# INLINEABLE matchIntegralType #-} matchIntegralType :: IntegralType s -> IntegralType t -> Maybe (s :~: t) -matchIntegralType TypeInt TypeInt = Just Refl -matchIntegralType TypeInt8 TypeInt8 = Just Refl -matchIntegralType TypeInt16 TypeInt16 = Just Refl -matchIntegralType TypeInt32 TypeInt32 = Just Refl -matchIntegralType TypeInt64 TypeInt64 = Just Refl -matchIntegralType TypeWord TypeWord = Just Refl -matchIntegralType TypeWord8 TypeWord8 = Just Refl -matchIntegralType TypeWord16 TypeWord16 = Just Refl -matchIntegralType TypeWord32 TypeWord32 = Just Refl -matchIntegralType TypeWord64 TypeWord64 = Just Refl -matchIntegralType _ _ = Nothing +matchIntegralType (SingleIntegralType s) (SingleIntegralType t) = matchSingleIntegralType s t +matchIntegralType (VectorIntegralType n s) (VectorIntegralType m t) + | Just Refl <- sameNat' n m + , Just Refl <- matchSingleIntegralType s t + = Just Refl +matchIntegralType _ _ = Nothing + +{-# INLINEABLE matchSingleIntegralType #-} +matchSingleIntegralType :: SingleIntegralType s -> SingleIntegralType t -> Maybe (s :~: t) +matchSingleIntegralType TypeInt8 TypeInt8 = Just Refl +matchSingleIntegralType TypeInt64 TypeInt64 = Just Refl +matchSingleIntegralType TypeInt32 TypeInt32 = Just Refl +matchSingleIntegralType TypeInt16 TypeInt16 = Just Refl +matchSingleIntegralType TypeInt128 TypeInt128 = Just Refl +matchSingleIntegralType TypeWord8 TypeWord8 = Just Refl +matchSingleIntegralType TypeWord64 TypeWord64 = Just Refl +matchSingleIntegralType TypeWord32 TypeWord32 = Just Refl +matchSingleIntegralType TypeWord16 TypeWord16 = Just Refl +matchSingleIntegralType TypeWord128 TypeWord128 = Just Refl +matchSingleIntegralType _ _ = Nothing {-# INLINEABLE matchFloatingType #-} matchFloatingType :: FloatingType s -> FloatingType t -> Maybe (s :~: t) -matchFloatingType TypeHalf TypeHalf = Just Refl -matchFloatingType TypeFloat TypeFloat = Just Refl -matchFloatingType TypeDouble TypeDouble = Just Refl -matchFloatingType _ _ = Nothing +matchFloatingType (SingleFloatingType s) (SingleFloatingType t) = matchSingleFloatingType s t +matchFloatingType (VectorFloatingType n s) (VectorFloatingType m t) + | Just Refl <- sameNat' n m + , Just Refl <- matchSingleFloatingType s t + = Just Refl +matchFloatingType _ _ = Nothing + +{-# INLINEABLE matchSingleFloatingType #-} +matchSingleFloatingType :: SingleFloatingType s -> SingleFloatingType t -> Maybe (s :~: t) +matchSingleFloatingType TypeFloat16 TypeFloat16 = Just Refl +matchSingleFloatingType TypeFloat32 TypeFloat32 = Just Refl +matchSingleFloatingType TypeFloat64 TypeFloat64 = Just Refl +matchSingleFloatingType TypeFloat128 TypeFloat128 = Just Refl +matchSingleFloatingType _ _ = Nothing -- Auxiliary @@ -904,14 +990,14 @@ commutes f x = case f of PrimNEq{} -> Just (swizzle x) PrimMax{} -> Just (swizzle x) PrimMin{} -> Just (swizzle x) - PrimLAnd -> Just (swizzle x) - PrimLOr -> Just (swizzle x) + PrimLAnd{} -> Just (swizzle x) + PrimLOr{} -> Just (swizzle x) _ -> Nothing where swizzle :: OpenExp env aenv (a',a') -> OpenExp env aenv (a',a') - swizzle exp - | (a `Pair` b) <- exp + swizzle e + | (a `Pair` b) <- e , hashOpenExp a > hashOpenExp b = b `Pair` a -- - | otherwise = exp + | otherwise = e diff --git a/src/Data/Array/Accelerate/Array/Data.hs b/src/Data/Array/Accelerate/Array/Data.hs index a22475bce..24227a75d 100644 --- a/src/Data/Array/Accelerate/Array/Data.hs +++ b/src/Data/Array/Accelerate/Array/Data.hs @@ -1,6 +1,8 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -35,16 +37,9 @@ module Data.Array.Accelerate.Array.Data ( touchArrayData, rnfArrayData, - -- * Type macros - HTYPE_INT, HTYPE_WORD, HTYPE_CLONG, HTYPE_CULONG, HTYPE_CCHAR, - -- * Allocator internals registerForeignPtrAllocator, - -- * Utilities for type classes - ScalarArrayDict(..), scalarArrayDict, - SingleArrayDict(..), singleArrayDict, - -- * TemplateHaskell liftArrayData, @@ -54,6 +49,7 @@ import Data.Array.Accelerate.Array.Unique import Data.Array.Accelerate.Error import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Type +import Data.Primitive.Bit import Data.Primitive.Vec #ifdef ACCELERATE_DEBUG import Data.Array.Accelerate.Lifetime @@ -66,18 +62,17 @@ import Data.Array.Accelerate.Debug.Internal.Trace import Control.Applicative import Control.DeepSeq import Control.Monad ( (<=<) ) -import Data.Bits +import Data.Bits ( testBit, setBit, clearBit ) import Data.IORef -import Data.Primitive ( sizeOf# ) import Foreign.ForeignPtr -import Foreign.Storable import Formatting hiding ( bytes ) -import Language.Haskell.TH.Extra hiding ( Type ) -import Prelude hiding ( mapM ) +import Language.Haskell.TH.Extra import System.IO.Unsafe +import Prelude hiding ( mapM ) import GHC.Exts hiding ( build ) import GHC.ForeignPtr +import GHC.TypeLits import GHC.Types @@ -100,168 +95,405 @@ type family GArrayDataR ba a where type ScalarArrayData a = UniqueArray (ScalarArrayDataR a) --- | Mapping from scalar type to the type as represented in memory in an --- array. +-- | Mapping from scalar type to the type as represented in memory in an array -- type family ScalarArrayDataR t where - ScalarArrayDataR Int = Int - ScalarArrayDataR Int8 = Int8 - ScalarArrayDataR Int16 = Int16 - ScalarArrayDataR Int32 = Int32 - ScalarArrayDataR Int64 = Int64 - ScalarArrayDataR Word = Word - ScalarArrayDataR Word8 = Word8 - ScalarArrayDataR Word16 = Word16 - ScalarArrayDataR Word32 = Word32 - ScalarArrayDataR Word64 = Word64 - ScalarArrayDataR Half = Half - ScalarArrayDataR Float = Float - ScalarArrayDataR Double = Double - ScalarArrayDataR (Vec n t) = ScalarArrayDataR t - - -data ScalarArrayDict a where - ScalarArrayDict :: ( ArrayData a ~ ScalarArrayData a, ScalarArrayDataR a ~ ScalarArrayDataR b ) - => {-# UNPACK #-} !Int -- vector width - -> SingleType b -- base type - -> ScalarArrayDict a - -data SingleArrayDict a where - SingleArrayDict :: ( ArrayData a ~ ScalarArrayData a, ScalarArrayDataR a ~ a ) - => SingleArrayDict a - -scalarArrayDict :: ScalarType a -> ScalarArrayDict a -scalarArrayDict = scalar - where - scalar :: ScalarType a -> ScalarArrayDict a - scalar (VectorScalarType t) = vector t - scalar (SingleScalarType t) - | SingleArrayDict <- singleArrayDict t - = ScalarArrayDict 1 t - - vector :: VectorType a -> ScalarArrayDict a - vector (VectorType w s) - | SingleArrayDict <- singleArrayDict s - = ScalarArrayDict w s - -singleArrayDict :: SingleType a -> SingleArrayDict a -singleArrayDict = single - where - single :: SingleType a -> SingleArrayDict a - single (NumSingleType t) = num t - - num :: NumType a -> SingleArrayDict a - num (IntegralNumType t) = integral t - num (FloatingNumType t) = floating t - - integral :: IntegralType a -> SingleArrayDict a - integral TypeInt = SingleArrayDict - integral TypeInt8 = SingleArrayDict - integral TypeInt16 = SingleArrayDict - integral TypeInt32 = SingleArrayDict - integral TypeInt64 = SingleArrayDict - integral TypeWord = SingleArrayDict - integral TypeWord8 = SingleArrayDict - integral TypeWord16 = SingleArrayDict - integral TypeWord32 = SingleArrayDict - integral TypeWord64 = SingleArrayDict - - floating :: FloatingType a -> SingleArrayDict a - floating TypeHalf = SingleArrayDict - floating TypeFloat = SingleArrayDict - floating TypeDouble = SingleArrayDict + ScalarArrayDataR Bit = Word8 + ScalarArrayDataR Int8 = Int8 + ScalarArrayDataR Int16 = Int16 + ScalarArrayDataR Int32 = Int32 + ScalarArrayDataR Int64 = Int64 + ScalarArrayDataR Int128 = Int128 + ScalarArrayDataR Word8 = Word8 + ScalarArrayDataR Word16 = Word16 + ScalarArrayDataR Word32 = Word32 + ScalarArrayDataR Word64 = Word64 + ScalarArrayDataR Word128 = Word128 + ScalarArrayDataR Half = Half + ScalarArrayDataR Float = Float + ScalarArrayDataR Double = Double + ScalarArrayDataR Float128 = Float128 + -- + ScalarArrayDataR (Vec n Bit) = BitMask n + ScalarArrayDataR (Vec n Int8) = Vec n Int8 + ScalarArrayDataR (Vec n Int16) = Vec n Int16 + ScalarArrayDataR (Vec n Int32) = Vec n Int32 + ScalarArrayDataR (Vec n Int64) = Vec n Int64 + ScalarArrayDataR (Vec n Int128) = Vec n Int128 + ScalarArrayDataR (Vec n Word8) = Vec n Word8 + ScalarArrayDataR (Vec n Word16) = Vec n Word16 + ScalarArrayDataR (Vec n Word32) = Vec n Word32 + ScalarArrayDataR (Vec n Word64) = Vec n Word64 + ScalarArrayDataR (Vec n Word128) = Vec n Word128 + ScalarArrayDataR (Vec n Half) = Vec n Half + ScalarArrayDataR (Vec n Float) = Vec n Float + ScalarArrayDataR (Vec n Double) = Vec n Double + ScalarArrayDataR (Vec n Float128) = Vec n Float128 -- Array operations -- ---------------- -newArrayData :: HasCallStack => TupR ScalarType e -> Int -> IO (MutableArrayData e) +newArrayData :: HasCallStack => TypeR e -> Int -> IO (MutableArrayData e) newArrayData TupRunit !_ = return () newArrayData (TupRpair t1 t2) !size = (,) <$> newArrayData t1 size <*> newArrayData t2 size -newArrayData (TupRsingle t) !size - | SingleScalarType s <- t - , SingleDict <- singleDict s - , SingleArrayDict <- singleArrayDict s - = allocateArray size - -- - | VectorScalarType v <- t - , VectorType w s <- v - , SingleDict <- singleDict s - , SingleArrayDict <- singleArrayDict s - = allocateArray (w * size) +newArrayData (TupRsingle _t) !size = scalar _t + where + scalar :: ScalarType t -> IO (MutableArrayData t) + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t + + bit :: BitType t -> IO (MutableArrayData t) + bit TypeBit = let (q,r) = quotRem size 8 + in if r == 0 + then allocateArray q + else allocateArray (q+1) + bit (TypeMask n) = let k = fromInteger (natVal' n) + (q,r) = quotRem k 8 + in if r == 0 + then allocateArray (size * q) + else allocateArray (size * (q + 1)) + + num :: NumType t -> IO (MutableArrayData t) + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t -indexArrayData :: TupR ScalarType e -> ArrayData e -> Int -> e + integral :: IntegralType t -> IO (MutableArrayData t) + integral = \case + SingleIntegralType t -> single t + VectorIntegralType n t -> vector n (fromInteger $ natVal' n) t + where + single :: SingleIntegralType t -> IO (MutableArrayData t) + single = \case + TypeInt8 -> allocateArray size + TypeInt16 -> allocateArray (2 * size) + TypeInt32 -> allocateArray (4 * size) + TypeInt64 -> allocateArray (8 * size) + TypeInt128 -> allocateArray (16 * size) + TypeWord8 -> allocateArray size + TypeWord16 -> allocateArray (2 * size) + TypeWord32 -> allocateArray (4 * size) + TypeWord64 -> allocateArray (8 * size) + TypeWord128 -> allocateArray (16 * size) + + vector :: Proxy# n -> Int -> SingleIntegralType t -> IO (MutableArrayData (Vec n t)) + vector _ !k = \case + TypeInt8 -> allocateArray (k * size) + TypeInt16 -> allocateArray (2 * k * size) + TypeInt32 -> allocateArray (4 * k * size) + TypeInt64 -> allocateArray (8 * k * size) + TypeInt128 -> allocateArray (16 * k * size) + TypeWord8 -> allocateArray (k * size) + TypeWord16 -> allocateArray (2 * k * size) + TypeWord32 -> allocateArray (4 * k * size) + TypeWord64 -> allocateArray (8 * k * size) + TypeWord128 -> allocateArray (16 * k * size) + + floating :: FloatingType t -> IO (MutableArrayData t) + floating = \case + SingleFloatingType t -> single t + VectorFloatingType n t -> vector n (fromInteger $ natVal' n) t + where + single :: SingleFloatingType t -> IO (MutableArrayData t) + single = \case + TypeFloat16 -> allocateArray (2 * size) + TypeFloat32 -> allocateArray (4 * size) + TypeFloat64 -> allocateArray (8 * size) + TypeFloat128 -> allocateArray (16 * size) + + vector :: Proxy# n -> Int -> SingleFloatingType t -> IO (MutableArrayData (Vec n t)) + vector _ !k = \case + TypeFloat16 -> allocateArray (2 * k * size) + TypeFloat32 -> allocateArray (4 * k * size) + TypeFloat64 -> allocateArray (8 * k * size) + TypeFloat128 -> allocateArray (16 * k * size) + + +indexArrayData :: TypeR e -> ArrayData e -> Int -> e indexArrayData tR arr ix = unsafePerformIO $ readArrayData tR arr ix -readArrayData :: forall e. TupR ScalarType e -> MutableArrayData e -> Int -> IO e +readArrayData :: TypeR e -> MutableArrayData e -> Int -> IO e readArrayData TupRunit () !_ = return () readArrayData (TupRpair t1 t2) (a1, a2) !ix = (,) <$> readArrayData t1 a1 ix <*> readArrayData t2 a2 ix -readArrayData (TupRsingle t) arr !ix - | SingleScalarType s <- t - , SingleDict <- singleDict s - , SingleArrayDict <- singleArrayDict s - = unsafeReadArray arr ix - -- - | VectorScalarType v <- t - , VectorType w s <- v - , I# w# <- w - , I# ix# <- ix - , SingleDict <- singleDict s - , SingleArrayDict <- singleArrayDict s - = let - !bytes# = w# *# sizeOf# (undefined :: ScalarArrayDataR e) - !addr# = unPtr# (unsafeUniqueArrayPtr arr) `plusAddr#` (ix# *# bytes#) - in - IO $ \s0 -> - case newAlignedPinnedByteArray# bytes# 16# s0 of { (# s1, mba# #) -> - case copyAddrToByteArray# addr# mba# 0# bytes# s1 of { s2 -> - case unsafeFreezeByteArray# mba# s2 of { (# s3, ba# #) -> - (# s3, Vec ba# #) - }}} - -writeArrayData :: forall e. TupR ScalarType e -> MutableArrayData e -> Int -> e -> IO () -writeArrayData TupRunit () !_ () = return () +readArrayData (TupRsingle _t) arr !ix = scalar _t arr ix + where + scalar :: ScalarType t -> MutableArrayData t -> Int -> IO t + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t + + bit :: BitType t -> MutableArrayData t -> Int -> IO t + bit TypeMask{} ua i = unMask <$> unsafeReadArray ua i + bit TypeBit ua i = + let (q,r) = quotRem i 8 + in do + w <- unsafeReadArray ua q + return $ Bit (testBit w r) + + num :: NumType t -> MutableArrayData t -> Int -> IO t + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType t -> MutableArrayData t -> Int -> IO t + integral = \case + SingleIntegralType t -> single t + VectorIntegralType n t -> vector n t + where + single :: SingleIntegralType t -> MutableArrayData t -> Int -> IO t + single = \case + TypeInt8 -> unsafeReadArray + TypeInt16 -> unsafeReadArray + TypeInt32 -> unsafeReadArray + TypeInt64 -> unsafeReadArray + TypeInt128 -> unsafeReadArray + TypeWord8 -> unsafeReadArray + TypeWord16 -> unsafeReadArray + TypeWord32 -> unsafeReadArray + TypeWord64 -> unsafeReadArray + TypeWord128 -> unsafeReadArray + + vector :: KnownNat n => Proxy# n -> SingleIntegralType t -> MutableArrayData (Vec n t) -> Int -> IO (Vec n t) + vector _ = \case + TypeInt8 -> unsafeReadArray + TypeInt16 -> unsafeReadArray + TypeInt32 -> unsafeReadArray + TypeInt64 -> unsafeReadArray + TypeInt128 -> unsafeReadArray + TypeWord8 -> unsafeReadArray + TypeWord16 -> unsafeReadArray + TypeWord32 -> unsafeReadArray + TypeWord64 -> unsafeReadArray + TypeWord128 -> unsafeReadArray + + floating :: FloatingType t -> MutableArrayData t -> Int -> IO t + floating = \case + SingleFloatingType t -> single t + VectorFloatingType n t -> vector n t + where + single :: SingleFloatingType t -> MutableArrayData t -> Int -> IO t + single = \case + TypeFloat16 -> unsafeReadArray + TypeFloat32 -> unsafeReadArray + TypeFloat64 -> unsafeReadArray + TypeFloat128 -> unsafeReadArray + + vector :: KnownNat n => Proxy# n -> SingleFloatingType t -> MutableArrayData (Vec n t) -> Int -> IO (Vec n t) + vector _ = \case + TypeFloat16 -> unsafeReadArray + TypeFloat32 -> unsafeReadArray + TypeFloat64 -> unsafeReadArray + TypeFloat128 -> unsafeReadArray + + +writeArrayData :: TypeR e -> MutableArrayData e -> Int -> e -> IO () +writeArrayData TupRunit () !_ !() = return () writeArrayData (TupRpair t1 t2) (a1, a2) !ix (v1, v2) = writeArrayData t1 a1 ix v1 >> writeArrayData t2 a2 ix v2 -writeArrayData (TupRsingle t) arr !ix !val - | SingleScalarType s <- t - , SingleDict <- singleDict s - , SingleArrayDict <- singleArrayDict s - = unsafeWriteArray arr ix val - -- - | VectorScalarType v <- t - , VectorType w s <- v - , Vec ba# <- val - , I# w# <- w - , I# ix# <- ix - , SingleDict <- singleDict s - , SingleArrayDict <- singleArrayDict s - = let - !bytes# = w# *# sizeOf# (undefined :: ScalarArrayDataR e) - !addr# = unPtr# (unsafeUniqueArrayPtr arr) `plusAddr#` (ix# *# bytes#) - in - IO $ \s0 -> case copyByteArrayToAddr# ba# 0# addr# bytes# s0 of - s1 -> (# s1, () #) +writeArrayData (TupRsingle _t) arr !ix !val = scalar _t arr ix val + where + scalar :: ScalarType t -> MutableArrayData t -> Int -> t -> IO () + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t + + bit :: BitType t -> MutableArrayData t -> Int -> t -> IO () + bit TypeMask{} ua i m = unsafeWriteArray ua i (BitMask m) + bit TypeBit ua i (Bit b) = + let (q,r) = quotRem i 8 + update x = if b then setBit x r else clearBit x r + in do + w <- unsafeReadArray ua q + unsafeWriteArray ua q (update w) + + num :: NumType t -> MutableArrayData t -> Int -> t -> IO () + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType t -> MutableArrayData t -> Int -> t -> IO () + integral = \case + SingleIntegralType t -> single t + VectorIntegralType n t -> vector n t + where + single :: SingleIntegralType t -> MutableArrayData t -> Int -> t -> IO () + single = \case + TypeInt8 -> unsafeWriteArray + TypeInt16 -> unsafeWriteArray + TypeInt32 -> unsafeWriteArray + TypeInt64 -> unsafeWriteArray + TypeInt128 -> unsafeWriteArray + TypeWord8 -> unsafeWriteArray + TypeWord16 -> unsafeWriteArray + TypeWord32 -> unsafeWriteArray + TypeWord64 -> unsafeWriteArray + TypeWord128 -> unsafeWriteArray + + vector :: KnownNat n => Proxy# n -> SingleIntegralType t -> MutableArrayData (Vec n t) -> Int -> Vec n t -> IO () + vector _ = \case + TypeInt8 -> unsafeWriteArray + TypeInt16 -> unsafeWriteArray + TypeInt32 -> unsafeWriteArray + TypeInt64 -> unsafeWriteArray + TypeInt128 -> unsafeWriteArray + TypeWord8 -> unsafeWriteArray + TypeWord16 -> unsafeWriteArray + TypeWord32 -> unsafeWriteArray + TypeWord64 -> unsafeWriteArray + TypeWord128 -> unsafeWriteArray + + floating :: FloatingType t -> MutableArrayData t -> Int -> t -> IO () + floating = \case + SingleFloatingType t -> single t + VectorFloatingType n t -> vector n t + where + single :: SingleFloatingType t -> MutableArrayData t -> Int -> t -> IO () + single = \case + TypeFloat16 -> unsafeWriteArray + TypeFloat32 -> unsafeWriteArray + TypeFloat64 -> unsafeWriteArray + TypeFloat128 -> unsafeWriteArray + + vector :: KnownNat n => Proxy# n -> SingleFloatingType t -> MutableArrayData (Vec n t) -> Int -> Vec n t -> IO () + vector _ = \case + TypeFloat16 -> unsafeWriteArray + TypeFloat32 -> unsafeWriteArray + TypeFloat64 -> unsafeWriteArray + TypeFloat128 -> unsafeWriteArray unsafeArrayDataPtr :: ScalarType e -> ArrayData e -> Ptr (ScalarArrayDataR e) -unsafeArrayDataPtr t arr - | ScalarArrayDict{} <- scalarArrayDict t - = unsafeUniqueArrayPtr arr +unsafeArrayDataPtr = scalar + where + scalar :: ScalarType t -> ArrayData t -> Ptr (ScalarArrayDataR t) + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t + + bit :: BitType t -> ArrayData t -> Ptr (ScalarArrayDataR t) + bit TypeBit = unsafeUniqueArrayPtr + bit TypeMask{} = unsafeUniqueArrayPtr + + num :: NumType t -> ArrayData t -> Ptr (ScalarArrayDataR t) + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType t -> ArrayData t -> Ptr (ScalarArrayDataR t) + integral = \case + SingleIntegralType t -> single t + VectorIntegralType n t -> vector n t + where + single :: SingleIntegralType t -> ArrayData t -> Ptr (ScalarArrayDataR t) + single = \case + TypeInt8 -> unsafeUniqueArrayPtr + TypeInt16 -> unsafeUniqueArrayPtr + TypeInt32 -> unsafeUniqueArrayPtr + TypeInt64 -> unsafeUniqueArrayPtr + TypeInt128 -> unsafeUniqueArrayPtr + TypeWord8 -> unsafeUniqueArrayPtr + TypeWord16 -> unsafeUniqueArrayPtr + TypeWord32 -> unsafeUniqueArrayPtr + TypeWord64 -> unsafeUniqueArrayPtr + TypeWord128 -> unsafeUniqueArrayPtr + + vector :: KnownNat n => Proxy# n -> SingleIntegralType t -> ArrayData (Vec n t) -> Ptr (ScalarArrayDataR (Vec n t)) + vector _ = \case + TypeInt8 -> unsafeUniqueArrayPtr + TypeInt16 -> unsafeUniqueArrayPtr + TypeInt32 -> unsafeUniqueArrayPtr + TypeInt64 -> unsafeUniqueArrayPtr + TypeInt128 -> unsafeUniqueArrayPtr + TypeWord8 -> unsafeUniqueArrayPtr + TypeWord16 -> unsafeUniqueArrayPtr + TypeWord32 -> unsafeUniqueArrayPtr + TypeWord64 -> unsafeUniqueArrayPtr + TypeWord128 -> unsafeUniqueArrayPtr + + floating :: FloatingType t -> ArrayData t -> Ptr (ScalarArrayDataR t) + floating = \case + SingleFloatingType t -> single t + VectorFloatingType n t -> vector n t + where + single :: SingleFloatingType t -> ArrayData t -> Ptr (ScalarArrayDataR t) + single = \case + TypeFloat16 -> unsafeUniqueArrayPtr + TypeFloat32 -> unsafeUniqueArrayPtr + TypeFloat64 -> unsafeUniqueArrayPtr + TypeFloat128 -> unsafeUniqueArrayPtr + + vector :: KnownNat n => Proxy# n -> SingleFloatingType t -> ArrayData (Vec n t) -> Ptr (ScalarArrayDataR (Vec n t)) + vector _ = \case + TypeFloat16 -> unsafeUniqueArrayPtr + TypeFloat32 -> unsafeUniqueArrayPtr + TypeFloat64 -> unsafeUniqueArrayPtr + TypeFloat128 -> unsafeUniqueArrayPtr touchArrayData :: TupR ScalarType e -> ArrayData e -> IO () touchArrayData TupRunit () = return () touchArrayData (TupRpair t1 t2) (a1, a2) = touchArrayData t1 a1 >> touchArrayData t2 a2 -touchArrayData (TupRsingle t) arr - | ScalarArrayDict{} <- scalarArrayDict t - = touchUniqueArray arr +touchArrayData (TupRsingle ta) arr = scalar ta arr + where + scalar :: ScalarType t -> ArrayData t -> IO () + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t + + bit :: BitType t -> ArrayData t -> IO () + bit TypeBit = touchUniqueArray + bit TypeMask{} = touchUniqueArray + + num :: NumType t -> ArrayData t -> IO () + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType t -> ArrayData t -> IO () + integral = \case + SingleIntegralType t -> single t + VectorIntegralType n t -> vector n t + where + single :: SingleIntegralType t -> ArrayData t -> IO () + single = \case + TypeInt8 -> touchUniqueArray + TypeInt16 -> touchUniqueArray + TypeInt32 -> touchUniqueArray + TypeInt64 -> touchUniqueArray + TypeInt128 -> touchUniqueArray + TypeWord8 -> touchUniqueArray + TypeWord16 -> touchUniqueArray + TypeWord32 -> touchUniqueArray + TypeWord64 -> touchUniqueArray + TypeWord128 -> touchUniqueArray + + vector :: KnownNat n => Proxy# n -> SingleIntegralType t -> ArrayData (Vec n t) -> IO () + vector _ = \case + TypeInt8 -> touchUniqueArray + TypeInt16 -> touchUniqueArray + TypeInt32 -> touchUniqueArray + TypeInt64 -> touchUniqueArray + TypeInt128 -> touchUniqueArray + TypeWord8 -> touchUniqueArray + TypeWord16 -> touchUniqueArray + TypeWord32 -> touchUniqueArray + TypeWord64 -> touchUniqueArray + TypeWord128 -> touchUniqueArray + + floating :: FloatingType t -> ArrayData t -> IO () + floating = \case + SingleFloatingType t -> single t + VectorFloatingType n t -> vector n t + where + single :: SingleFloatingType t -> ArrayData t -> IO () + single = \case + TypeFloat16 -> touchUniqueArray + TypeFloat32 -> touchUniqueArray + TypeFloat64 -> touchUniqueArray + TypeFloat128 -> touchUniqueArray + + vector :: KnownNat n => Proxy# n -> SingleFloatingType t -> ArrayData (Vec n t) -> IO () + vector _ = \case + TypeFloat16 -> touchUniqueArray + TypeFloat32 -> touchUniqueArray + TypeFloat64 -> touchUniqueArray + TypeFloat128 -> touchUniqueArray rnfArrayData :: TupR ScalarType e -> ArrayData e -> () rnfArrayData TupRunit () = () rnfArrayData (TupRpair t1 t2) (a1, a2) = rnfArrayData t1 a1 `seq` rnfArrayData t2 a2 `seq` () rnfArrayData (TupRsingle t) arr = rnf (unsafeArrayDataPtr t arr) -unPtr# :: Ptr a -> Addr# -unPtr# (Ptr addr#) = addr# -- | Safe combination of creating and fast freezing of array data. -- @@ -272,22 +504,20 @@ runArrayData st = unsafePerformIO $ do (mad, r) <- st return (mad, r) --- Allocate a new array with enough storage to hold the given number of --- elements. +-- Allocate a new array of the given number of bytes. -- -- The array is uninitialised and, in particular, allocated lazily. The latter -- is important because it means that for backends that have discrete memory -- spaces (e.g. GPUs), we will not increase host memory pressure simply to track -- intermediate arrays that contain meaningful data only on the device. -- -allocateArray :: forall e. (HasCallStack, Storable e) => Int -> IO (UniqueArray e) +allocateArray :: forall e. HasCallStack => Int -> IO (UniqueArray e) allocateArray !size = internalCheck "size must be >= 0" (size >= 0) $ do arr <- newUniqueArray <=< unsafeInterleaveIO $ do - let bytes = size * sizeOf (undefined :: e) new <- readIORef __mallocForeignPtrBytes - ptr <- new bytes - traceM dump_gc ("gc: allocated new host array (size=" % int % ", ptr=" % build % ")") bytes (unsafeForeignPtrToPtr ptr) - local_memory_alloc (unsafeForeignPtrToPtr ptr) bytes + ptr <- new size + traceM dump_gc ("gc: allocated new host array (size=" % int % ", ptr=" % build % ")") size (unsafeForeignPtrToPtr ptr) + local_memory_alloc (unsafeForeignPtrToPtr ptr) size return (castForeignPtr ptr) #ifdef ACCELERATE_DEBUG addFinalizer (uniqueArrayData arr) (local_memory_free (unsafeUniqueArrayPtr arr)) @@ -333,66 +563,69 @@ liftArrayData n = tuple tuple (TupRsingle s) adata = scalar s adata scalar :: ScalarType e -> ArrayData e -> CodeQ (ArrayData e) - scalar (SingleScalarType t) = single t - scalar (VectorScalarType t) = vector t - - vector :: forall n e. VectorType (Vec n e) -> ArrayData (Vec n e) -> CodeQ (ArrayData (Vec n e)) - vector (VectorType w t) - | SingleArrayDict <- singleArrayDict t - = liftArrayData (w * n) (TupRsingle (SingleScalarType t)) - - single :: SingleType e -> ArrayData e -> CodeQ (ArrayData e) - single (NumSingleType t) = num t + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t + + bit :: BitType e -> ArrayData e -> CodeQ (ArrayData e) + bit TypeBit ua = + let (q,r) = quotRem n 8 + in if r == 0 + then liftUniqueArray q ua + else liftUniqueArray (q+1) ua + bit (TypeMask n') ua = + let k = fromInteger (natVal' n') + (q,r) = quotRem k 8 + in if r == 0 + then liftUniqueArray (n * q) ua + else liftUniqueArray (n * (q+1)) ua num :: NumType e -> ArrayData e -> CodeQ (ArrayData e) num (IntegralNumType t) = integral t num (FloatingNumType t) = floating t integral :: IntegralType e -> ArrayData e -> CodeQ (ArrayData e) - integral TypeInt = liftUniqueArray n - integral TypeInt8 = liftUniqueArray n - integral TypeInt16 = liftUniqueArray n - integral TypeInt32 = liftUniqueArray n - integral TypeInt64 = liftUniqueArray n - integral TypeWord = liftUniqueArray n - integral TypeWord8 = liftUniqueArray n - integral TypeWord16 = liftUniqueArray n - integral TypeWord32 = liftUniqueArray n - integral TypeWord64 = liftUniqueArray n + integral = \case + SingleIntegralType t -> single t n + VectorIntegralType n' t -> vector n' t (n * fromInteger (natVal' n')) + where + single :: SingleIntegralType e -> Int -> ArrayData e -> CodeQ (ArrayData e) + single TypeInt8 = liftUniqueArray + single TypeInt16 = liftUniqueArray + single TypeInt32 = liftUniqueArray + single TypeInt64 = liftUniqueArray + single TypeInt128 = liftUniqueArray + single TypeWord8 = liftUniqueArray + single TypeWord16 = liftUniqueArray + single TypeWord32 = liftUniqueArray + single TypeWord64 = liftUniqueArray + single TypeWord128 = liftUniqueArray + + vector :: KnownNat n => Proxy# n -> SingleIntegralType e -> Int -> ArrayData (Vec n e) -> CodeQ (ArrayData (Vec n e)) + vector _ TypeInt8 = liftUniqueArray + vector _ TypeInt16 = liftUniqueArray + vector _ TypeInt32 = liftUniqueArray + vector _ TypeInt64 = liftUniqueArray + vector _ TypeInt128 = liftUniqueArray + vector _ TypeWord8 = liftUniqueArray + vector _ TypeWord16 = liftUniqueArray + vector _ TypeWord32 = liftUniqueArray + vector _ TypeWord64 = liftUniqueArray + vector _ TypeWord128 = liftUniqueArray floating :: FloatingType e -> ArrayData e -> CodeQ (ArrayData e) - floating TypeHalf = liftUniqueArray n - floating TypeFloat = liftUniqueArray n - floating TypeDouble = liftUniqueArray n - --- Determine the underlying type of a Haskell CLong or CULong. --- -runQ [d| type HTYPE_INT = $( - case finiteBitSize (undefined::Int) of - 32 -> [t| Int32 |] - 64 -> [t| Int64 |] - _ -> error "I don't know what architecture I am" ) |] - -runQ [d| type HTYPE_WORD = $( - case finiteBitSize (undefined::Word) of - 32 -> [t| Word32 |] - 64 -> [t| Word64 |] - _ -> error "I don't know what architecture I am" ) |] - -runQ [d| type HTYPE_CLONG = $( - case finiteBitSize (undefined::CLong) of - 32 -> [t| Int32 |] - 64 -> [t| Int64 |] - _ -> error "I don't know what architecture I am" ) |] - -runQ [d| type HTYPE_CULONG = $( - case finiteBitSize (undefined::CULong) of - 32 -> [t| Word32 |] - 64 -> [t| Word64 |] - _ -> error "I don't know what architecture I am" ) |] - -runQ [d| type HTYPE_CCHAR = $( - if isSigned (undefined::CChar) - then [t| Int8 |] - else [t| Word8 |] ) |] + floating = \case + SingleFloatingType t -> single t n + VectorFloatingType n' t -> vector n' t (n * fromInteger (natVal' n')) + where + single :: SingleFloatingType e -> Int -> ArrayData e -> CodeQ (ArrayData e) + single TypeFloat16 = liftUniqueArray + single TypeFloat32 = liftUniqueArray + single TypeFloat64 = liftUniqueArray + single TypeFloat128 = liftUniqueArray + + vector :: KnownNat n => Proxy# n -> SingleFloatingType e -> Int -> ArrayData (Vec n e) -> CodeQ (ArrayData (Vec n e)) + vector _ TypeFloat16 = liftUniqueArray + vector _ TypeFloat32 = liftUniqueArray + vector _ TypeFloat64 = liftUniqueArray + vector _ TypeFloat128 = liftUniqueArray diff --git a/src/Data/Array/Accelerate/Array/Remote/Class.hs b/src/Data/Array/Accelerate/Array/Remote/Class.hs index 7a871bd1a..067161121 100644 --- a/src/Data/Array/Accelerate/Array/Remote/Class.hs +++ b/src/Data/Array/Accelerate/Array/Remote/Class.hs @@ -54,10 +54,10 @@ class (Applicative m, Monad m, MonadCatch m, MonadMask m) => RemoteMemory m wher mallocRemote :: Int -> m (Maybe (RemotePtr m Word8)) -- | Copy the given number of elements from the host array into remote memory. - pokeRemote :: SingleType e -> Int -> RemotePtr m (ScalarArrayDataR e) -> ArrayData e -> m () + pokeRemote :: ScalarType e -> Int -> RemotePtr m (ScalarArrayDataR e) -> ArrayData e -> m () -- | Copy the given number of elements from remote memory to the host array. - peekRemote :: SingleType e -> Int -> RemotePtr m (ScalarArrayDataR e) -> MutableArrayData e -> m () + peekRemote :: ScalarType e -> Int -> RemotePtr m (ScalarArrayDataR e) -> MutableArrayData e -> m () -- | Cast a remote pointer. castRemotePtr :: RemotePtr m a -> RemotePtr m b diff --git a/src/Data/Array/Accelerate/Classes/Bounded.hs b/src/Data/Array/Accelerate/Classes/Bounded.hs index c567fb14c..4e495e7d1 100644 --- a/src/Data/Array/Accelerate/Classes/Bounded.hs +++ b/src/Data/Array/Accelerate/Classes/Bounded.hs @@ -21,14 +21,15 @@ module Data.Array.Accelerate.Classes.Bounded ( ) where -import Data.Array.Accelerate.Array.Data import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type +import qualified Data.Primitive.Vec as Prim -import Prelude ( ($), (<$>), Num(..), Char, Bool, show, concat, map, mapM ) import Language.Haskell.TH.Extra hiding ( Exp ) +import Prelude hiding ( Bounded ) import qualified Prelude as P @@ -42,113 +43,57 @@ instance P.Bounded (Exp ()) where minBound = constant () maxBound = constant () -instance P.Bounded (Exp Int) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp Int8) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp Int16) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp Int32) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp Int64) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp Word) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp Word8) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp Word16) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp Word32) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp Word64) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp CShort) where - minBound = mkBitcast (mkMinBound @Int16) - maxBound = mkBitcast (mkMaxBound @Int16) - -instance P.Bounded (Exp CUShort) where - minBound = mkBitcast (mkMinBound @Word16) - maxBound = mkBitcast (mkMaxBound @Word16) - -instance P.Bounded (Exp CInt) where - minBound = mkBitcast (mkMinBound @Int32) - maxBound = mkBitcast (mkMaxBound @Int32) - -instance P.Bounded (Exp CUInt) where - minBound = mkBitcast (mkMinBound @Word32) - maxBound = mkBitcast (mkMaxBound @Word32) - -instance P.Bounded (Exp CLong) where - minBound = mkBitcast (mkMinBound @HTYPE_CLONG) - maxBound = mkBitcast (mkMaxBound @HTYPE_CLONG) - -instance P.Bounded (Exp CULong) where - minBound = mkBitcast (mkMinBound @HTYPE_CULONG) - maxBound = mkBitcast (mkMaxBound @HTYPE_CULONG) - -instance P.Bounded (Exp CLLong) where - minBound = mkBitcast (mkMinBound @Int64) - maxBound = mkBitcast (mkMaxBound @Int64) - -instance P.Bounded (Exp CULLong) where - minBound = mkBitcast (mkMinBound @Word64) - maxBound = mkBitcast (mkMaxBound @Word64) - instance P.Bounded (Exp Bool) where minBound = constant P.minBound maxBound = constant P.maxBound instance P.Bounded (Exp Char) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp CChar) where - minBound = mkBitcast (mkMinBound @HTYPE_CCHAR) - maxBound = mkBitcast (mkMaxBound @HTYPE_CCHAR) - -instance P.Bounded (Exp CSChar) where - minBound = mkBitcast (mkMinBound @Int8) - maxBound = mkBitcast (mkMaxBound @Int8) - -instance P.Bounded (Exp CUChar) where - minBound = mkBitcast (mkMinBound @Word8) - maxBound = mkBitcast (mkMaxBound @Word8) - -$(runQ $ do - let - mkInstance :: Int -> Q [Dec] - mkInstance n = - let - xs = [ mkName ('x':show i) | i <- [0 .. n-1] ] - cst = tupT (map (\x -> [t| Bounded $(varT x) |]) xs) - res = tupT (map varT xs) - app x = appsE (conE (mkName ('T':show n)) : P.replicate n x) - in - [d| instance $cst => P.Bounded (Exp $res) where - minBound = $(app [| P.minBound |]) - maxBound = $(app [| P.maxBound |]) - |] - -- - concat <$> mapM mkInstance [2..16] - ) + minBound = constant P.minBound + maxBound = constant P.maxBound + +runQ $ do + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + mkBounded :: Name -> Q [Dec] + mkBounded a = + [d| instance P.Bounded (Exp $(conT a)) where + minBound = constant P.minBound + maxBound = constant P.maxBound + + instance KnownNat n => P.Bounded (Exp (Vec n $(conT a))) where + minBound = constant (Vec (Prim.splat minBound)) + maxBound = constant (Vec (Prim.splat maxBound)) + |] + + mkTuple :: Int -> Q [Dec] + mkTuple n = + let + xs = [ mkName ('x':show i) | i <- [0 .. n-1] ] + cst = tupT (map (\x -> [t| Bounded $(varT x) |]) xs) + res = tupT (map varT xs) + app x = appsE (conE (mkName ('T':show n)) : P.replicate n x) + in + [d| instance $cst => P.Bounded (Exp $res) where + minBound = $(app [| P.minBound |]) + maxBound = $(app [| P.maxBound |]) + |] + -- + as <- mapM mkBounded integralTypes + ts <- mapM mkTuple [2..16] + return $ concat (as ++ ts) diff --git a/src/Data/Array/Accelerate/Classes/Enum.hs b/src/Data/Array/Accelerate/Classes/Enum.hs index 10e946ee5..cd2f3e2c6 100644 --- a/src/Data/Array/Accelerate/Classes/Enum.hs +++ b/src/Data/Array/Accelerate/Classes/Enum.hs @@ -1,7 +1,7 @@ {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE MonoLocalBinds #-} +{-# LANGUAGE TemplateHaskell #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | -- Module : Data.Array.Accelerate.Classes.Enum @@ -22,10 +22,13 @@ module Data.Array.Accelerate.Classes.Enum ( import Data.Array.Accelerate.Classes.Num import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type -import Text.Printf -import Prelude ( ($), String, error, unlines, succ, pred ) +import Control.Monad +import Language.Haskell.TH hiding ( Exp ) +import Text.Printf +import Prelude hiding ( Num, Enum ) import qualified Prelude as P @@ -33,145 +36,6 @@ import qualified Prelude as P -- type Enum a = P.Enum (Exp a) - -instance P.Enum (Exp Int) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Int8) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Int16) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Int32) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Int64) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Word) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Word8) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Word16) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Word32) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Word64) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp CInt) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp CUInt) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp CLong) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp CULong) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp CLLong) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp CULLong) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp CShort) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp CUShort) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Half) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Float) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Double) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp CFloat) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp CDouble) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - defaultSucc :: Num a => Exp a -> Exp a defaultSucc x = x + 1 @@ -192,3 +56,49 @@ preludeError x , "constraints for subsequent classes in the standard Haskell numeric hierarchy." ] +runQ $ + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + , ''Float128 + ] + + numTypes :: [Name] + numTypes = integralTypes ++ floatingTypes + + mkEnum :: Name -> Q [Dec] + mkEnum a = + [d| instance P.Enum (Exp $(conT a)) where + succ = defaultSucc + pred = defaultPred + toEnum = defaultToEnum + fromEnum = defaultFromEnum + + instance KnownNat n => P.Enum (Exp (Vec n $(conT a))) where + succ = defaultSucc + pred = defaultPred + toEnum = defaultToEnum + fromEnum = defaultFromEnum + |] + in + concat <$> mapM mkEnum numTypes + diff --git a/src/Data/Array/Accelerate/Classes/Eq.hs b/src/Data/Array/Accelerate/Classes/Eq.hs index 6985facdd..4f770a340 100644 --- a/src/Data/Array/Accelerate/Classes/Eq.hs +++ b/src/Data/Array/Accelerate/Classes/Eq.hs @@ -1,7 +1,10 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} @@ -30,22 +33,28 @@ module Data.Array.Accelerate.Classes.Eq ( ) where -import Data.Array.Accelerate.AST.Idx +import Data.Array.Accelerate.AST ( PrimFun(..), BitOrMask ) +import Data.Array.Accelerate.Error import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Pattern.Bool import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Shape +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type +import {-# SOURCE #-} Data.Array.Accelerate.Classes.VEq import Data.Bool ( Bool(..) ) -import Data.Char ( Char ) +import Data.Bits import Text.Printf import Prelude ( ($), String, Num(..), Ordering(..), show, error, return, concat, map, zipWith, foldr1, mapM ) import Language.Haskell.TH.Extra hiding ( Exp ) import qualified Prelude as P +import GHC.Exts +import GHC.TypeLits + infix 4 == infix 4 /= @@ -56,10 +65,7 @@ infix 4 /= infixr 3 && (&&) :: Exp Bool -> Exp Bool -> Exp Bool (&&) (Exp x) (Exp y) = - mkExp $ SmartExp (Cond (SmartExp $ Prj PairIdxLeft x) - (SmartExp $ Prj PairIdxLeft y) - (SmartExp $ Const scalarTypeWord8 0)) - `Pair` SmartExp Nil + mkExp $ Cond x y (SmartExp $ Const (scalarType @PrimBool) 0) -- | Conjunction: True if both arguments are true. This is a strict version of -- '(&&)': it will always evaluate both arguments, even when the first is false. @@ -77,10 +83,7 @@ infixr 3 &&! infixr 2 || (||) :: Exp Bool -> Exp Bool -> Exp Bool (||) (Exp x) (Exp y) = - mkExp $ SmartExp (Cond (SmartExp $ Prj PairIdxLeft x) - (SmartExp $ Const scalarTypeWord8 1) - (SmartExp $ Prj PairIdxLeft y)) - `Pair` SmartExp Nil + mkExp $ Cond x (SmartExp $ Const (scalarType @PrimBool) 1) y -- | Disjunction: True if either argument is true. This is a strict version of @@ -98,17 +101,19 @@ not :: Exp Bool -> Exp Bool not = mkLNot --- | The 'Eq' class defines equality '==' and inequality '/=' for scalar +-- | The 'Eq' class defines equality '(==)' and inequality '(/=)' for -- Accelerate expressions. -- --- For convenience, we include 'Elt' as a superclass. +-- Vector types behave analogously to tuple types. For testing equality +-- lane-wise on each element of a vector, see the class +-- 'Data.Array.Accelerate.VEq'. -- class Elt a => Eq a where (==) :: Exp a -> Exp a -> Exp Bool (/=) :: Exp a -> Exp a -> Exp Bool {-# MINIMAL (==) | (/=) #-} - x == y = mkLNot (x /= y) - x /= y = mkLNot (x == y) + x == y = not (x /= y) + x /= y = not (x == y) instance Eq () where @@ -127,8 +132,15 @@ instance P.Eq (Exp a) where (==) = preludeError "Eq.(==)" "(==)" (/=) = preludeError "Eq.(/=)" "(/=)" -preludeError :: String -> String -> a -preludeError x y = error (printf "Prelude.%s applied to EDSL types: use Data.Array.Accelerate.%s instead" x y) +preludeError :: HasCallStack => String -> String -> a +preludeError x y + = error + $ P.unlines [ printf "Prelude.%s applied to EDSL types: use Data.Array.Accelerate.%s instead" x y + , "" + , "These Prelude.Eq instances are present only to fulfil superclass" + , "constraints for subsequent classes in the standard Haskell numeric" + , "hierarchy." + ] runQ $ do let @@ -139,11 +151,13 @@ runQ $ do , ''Int16 , ''Int32 , ''Int64 + , ''Int128 , ''Word , ''Word8 , ''Word16 , ''Word32 , ''Word64 + , ''Word128 ] floatingTypes :: [Name] @@ -151,6 +165,7 @@ runQ $ do [ ''Half , ''Float , ''Double + , ''Float128 ] nonNumTypes :: [Name] @@ -158,23 +173,6 @@ runQ $ do [ ''Char ] - cTypes :: [Name] - cTypes = - [ ''CInt - , ''CUInt - , ''CLong - , ''CULong - , ''CLLong - , ''CULLong - , ''CShort - , ''CUShort - , ''CChar - , ''CUChar - , ''CSChar - , ''CFloat - , ''CDouble - ] - mkPrim :: Name -> Q [Dec] mkPrim t = [d| instance Eq $(conT t) where @@ -199,19 +197,43 @@ runQ $ do is <- mapM mkPrim integralTypes fs <- mapM mkPrim floatingTypes ns <- mapM mkPrim nonNumTypes - cs <- mapM mkPrim cTypes ts <- mapM mkTup [2..16] - return $ concat (concat [is,fs,ns,cs,ts]) + return $ concat (concat [is,fs,ns,ts]) instance Eq sh => Eq (sh :. Int) where x == y = indexHead x == indexHead y && indexTail x == indexTail y x /= y = indexHead x /= indexHead y || indexTail x /= indexTail y instance Eq Bool where - x == y = mkCoerce x == (mkCoerce y :: Exp PrimBool) - x /= y = mkCoerce x /= (mkCoerce y :: Exp PrimBool) + (==) = mkEq + (/=) = mkNEq instance Eq Ordering where x == y = mkCoerce x == (mkCoerce y :: Exp TAG) x /= y = mkCoerce x /= (mkCoerce y :: Exp TAG) +instance VEq n a => Eq (Vec n a) where + (==) = vcmp (==*) + (/=) = vcmp (/=*) + +vcmp :: forall n a. KnownNat n + => (Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool)) + -> (Exp (Vec n a) -> Exp (Vec n a) -> Exp Bool) +vcmp op x y = + let n = fromInteger $ natVal' (proxy# :: Proxy# n) + v = op x y + -- + cmp :: forall t. (Elt t, Num t, Bits t, IsScalar (EltR t), IsIntegral (EltR t), BitOrMask (EltR t) ~ Bit) + => Exp (Vec n Bool) + -> Exp Bool + cmp u = + let u' = mkPrimUnary (PrimFromBool bitType integralType) u :: Exp t + in mkEq (constant ((1 `unsafeShiftL` n) - 1)) u' + in + if n P.<= 8 then cmp @Word8 v else + if n P.<= 16 then cmp @Word16 v else + if n P.<= 32 then cmp @Word32 v else + if n P.<= 64 then cmp @Word64 v else + if n P.<= 128 then cmp @Word128 v else + internalError "Can not handle Vec types with more than 128 lanes" + diff --git a/src/Data/Array/Accelerate/Classes/Floating.hs b/src/Data/Array/Accelerate/Classes/Floating.hs index b2f067af9..01821e208 100644 --- a/src/Data/Array/Accelerate/Classes/Floating.hs +++ b/src/Data/Array/Accelerate/Classes/Floating.hs @@ -1,8 +1,13 @@ -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UnboxedTuples #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | -- Module : Data.Array.Accelerate.Classes.Floating @@ -31,10 +36,14 @@ module Data.Array.Accelerate.Classes.Floating ( ) where import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type +import qualified Data.Primitive.Vec as Prim import Data.Array.Accelerate.Classes.Fractional +import Language.Haskell.TH hiding ( Exp ) +import Prelude hiding ( Fractional, Floating ) import qualified Prelude as P @@ -42,104 +51,58 @@ import qualified Prelude as P -- type Floating a = (Fractional a, P.Floating (Exp a)) +runQ $ + let + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + , ''Float128 + ] -instance P.Floating (Exp Half) where - pi = mkPi - sin = mkSin - cos = mkCos - tan = mkTan - asin = mkAsin - acos = mkAcos - atan = mkAtan - sinh = mkSinh - cosh = mkCosh - tanh = mkTanh - asinh = mkAsinh - acosh = mkAcosh - atanh = mkAtanh - exp = mkExpFloating - sqrt = mkSqrt - log = mkLog - (**) = mkFPow - logBase = mkLogBase + thFloating :: Name -> Q [Dec] + thFloating a = + [d| instance P.Floating (Exp $(conT a)) where + pi = constant pi + sin = mkSin + cos = mkCos + tan = mkTan + asin = mkAsin + acos = mkAcos + atan = mkAtan + sinh = mkSinh + cosh = mkCosh + tanh = mkTanh + asinh = mkAsinh + acosh = mkAcosh + atanh = mkAtanh + exp = mkExpFloating + sqrt = mkSqrt + log = mkLog + (**) = mkFPow + logBase = mkLogBase -instance P.Floating (Exp Float) where - pi = mkPi - sin = mkSin - cos = mkCos - tan = mkTan - asin = mkAsin - acos = mkAcos - atan = mkAtan - sinh = mkSinh - cosh = mkCosh - tanh = mkTanh - asinh = mkAsinh - acosh = mkAcosh - atanh = mkAtanh - exp = mkExpFloating - sqrt = mkSqrt - log = mkLog - (**) = mkFPow - logBase = mkLogBase - -instance P.Floating (Exp Double) where - pi = mkPi - sin = mkSin - cos = mkCos - tan = mkTan - asin = mkAsin - acos = mkAcos - atan = mkAtan - sinh = mkSinh - cosh = mkCosh - tanh = mkTanh - asinh = mkAsinh - acosh = mkAcosh - atanh = mkAtanh - exp = mkExpFloating - sqrt = mkSqrt - log = mkLog - (**) = mkFPow - logBase = mkLogBase - -instance P.Floating (Exp CFloat) where - pi = mkBitcast (mkPi @Float) - sin = mkSin - cos = mkCos - tan = mkTan - asin = mkAsin - acos = mkAcos - atan = mkAtan - sinh = mkSinh - cosh = mkCosh - tanh = mkTanh - asinh = mkAsinh - acosh = mkAcosh - atanh = mkAtanh - exp = mkExpFloating - sqrt = mkSqrt - log = mkLog - (**) = mkFPow - logBase = mkLogBase - -instance P.Floating (Exp CDouble) where - pi = mkBitcast (mkPi @Double) - sin = mkSin - cos = mkCos - tan = mkTan - asin = mkAsin - acos = mkAcos - atan = mkAtan - sinh = mkSinh - cosh = mkCosh - tanh = mkTanh - asinh = mkAsinh - acosh = mkAcosh - atanh = mkAtanh - exp = mkExpFloating - sqrt = mkSqrt - log = mkLog - (**) = mkFPow - logBase = mkLogBase + instance KnownNat n => P.Floating (Exp (Vec n $(conT a))) where + pi = constant (Vec (Prim.splat pi)) + sin = mkSin + cos = mkCos + tan = mkTan + asin = mkAsin + acos = mkAcos + atan = mkAtan + sinh = mkSinh + cosh = mkCosh + tanh = mkTanh + asinh = mkAsinh + acosh = mkAcosh + atanh = mkAtanh + exp = mkExpFloating + sqrt = mkSqrt + log = mkLog + (**) = mkFPow + logBase = mkLogBase + |] + in + concat <$> mapM thFloating floatingTypes diff --git a/src/Data/Array/Accelerate/Classes/Fractional.hs b/src/Data/Array/Accelerate/Classes/Fractional.hs index 52bdc61e9..dd554f0f6 100644 --- a/src/Data/Array/Accelerate/Classes/Fractional.hs +++ b/src/Data/Array/Accelerate/Classes/Fractional.hs @@ -1,7 +1,13 @@ -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UnboxedTuples #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | -- Module : Data.Array.Accelerate.Classes.Fractional @@ -21,11 +27,14 @@ module Data.Array.Accelerate.Classes.Fractional ( ) where import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type import Data.Array.Accelerate.Classes.Num -import Prelude ( (.) ) +import Language.Haskell.TH hiding ( Exp ) +import Prelude hiding ( Num, Fractional ) +import qualified Data.Primitive.Vec as Prim import qualified Prelude as P @@ -44,29 +53,28 @@ import qualified Prelude as P -- type Fractional a = (Num a, P.Fractional (Exp a)) - -instance P.Fractional (Exp Half) where - (/) = mkFDiv - recip = mkRecip - fromRational = constant . P.fromRational - -instance P.Fractional (Exp Float) where - (/) = mkFDiv - recip = mkRecip - fromRational = constant . P.fromRational - -instance P.Fractional (Exp Double) where - (/) = mkFDiv - recip = mkRecip - fromRational = constant . P.fromRational - -instance P.Fractional (Exp CFloat) where - (/) = mkFDiv - recip = mkRecip - fromRational = constant . P.fromRational - -instance P.Fractional (Exp CDouble) where - (/) = mkFDiv - recip = mkRecip - fromRational = constant . P.fromRational +runQ $ + let + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + , ''Float128 + ] + + thFractional :: Name -> Q [Dec] + thFractional a = + [d| instance P.Fractional (Exp $(conT a)) where + (/) = mkFDiv + recip = mkRecip + fromRational = constant . P.fromRational + + instance KnownNat n => P.Fractional (Exp (Vec n $(conT a))) where + (/) = mkFDiv + recip = mkRecip + fromRational = constant . Vec . Prim.splat . P.fromRational + |] + in + concat <$> mapM thFractional floatingTypes diff --git a/src/Data/Array/Accelerate/Classes/FromIntegral.hs b/src/Data/Array/Accelerate/Classes/FromIntegral.hs index d8678ee9d..f59fd6f54 100644 --- a/src/Data/Array/Accelerate/Classes/FromIntegral.hs +++ b/src/Data/Array/Accelerate/Classes/FromIntegral.hs @@ -1,6 +1,5 @@ -{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE MonoLocalBinds #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE TemplateHaskell #-} -- | @@ -20,6 +19,7 @@ module Data.Array.Accelerate.Classes.FromIntegral ( ) where import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type import Data.Array.Accelerate.Classes.Integral @@ -40,7 +40,6 @@ class FromIntegral a b where -- instance {-# OVERLAPPABLE #-} (Elt a, Elt b, IsIntegral a, IsNum b) => FromIntegral a b where -- fromIntegral = mkFromIntegral - -- Reify in ghci: -- -- $( stringE . show =<< reify ''Thing ) @@ -49,37 +48,60 @@ class FromIntegral a b where -- messages when we don't have an instance available, rather than a "can not -- deduce IsNum..." style error (which the user can do nothing about). -- -$(runQ $ do - let - -- Get all the types that our dictionaries reify - digItOut :: Name -> Q [Name] - digItOut name = do - TyConI (DataD _ _ _ _ cons _) <- reify name - let - -- This is what a constructor such as IntegralNumType will be reified - -- as prior to GHC 8.4... - dig (NormalC _ [(_, AppT (ConT n) (VarT _))]) = digItOut n - -- ...but this is what IntegralNumType will be reified as on GHC 8.4 - -- and later, after the changes described in - -- https://ghc.haskell.org/trac/ghc/wiki/Migration/8.4#TemplateHaskellreificationchangesforGADTs - dig (ForallC _ _ (GadtC _ [(_, AppT (ConT n) (VarT _))] _)) = digItOut n - dig (GadtC _ _ (AppT (ConT _) (ConT n))) = return [n] - dig _ = error "Unexpected case generating FromIntegral instances" - -- - concat `fmap` mapM dig cons - - thFromIntegral :: Name -> Name -> Q Dec - thFromIntegral a b = - let - ty = AppT (AppT (ConT (mkName "FromIntegral")) (ConT a)) (ConT b) - dec = ValD (VarP (mkName "fromIntegral")) (NormalB (VarE (mkName f))) [] - f | a == b = "id" - | otherwise = "mkFromIntegral" - in - instanceD (return []) (return ty) [return dec] - -- - as <- digItOut ''IntegralType - bs <- digItOut ''NumType - sequence [ thFromIntegral a b | a <- as, b <- bs ] - ) +-- > -- Get all the types that our dictionaries reify +-- > digItOut :: Name -> Q [Name] +-- > digItOut name = do +-- > TyConI (DataD _ _ _ _ cons _) <- reify name +-- > let +-- > -- This is what a constructor such as IntegralNumType will be reified +-- > -- as prior to GHC 8.4... +-- > dig (NormalC _ [(_, AppT (ConT n) (VarT _))]) = digItOut n +-- > -- ...but this is what IntegralNumType will be reified as on GHC 8.4 +-- > -- and later, after the changes described in +-- > -- https://ghc.haskell.org/trac/ghc/wiki/Migration/8.4#TemplateHaskellreificationchangesforGADTs +-- > dig (ForallC _ _ (GadtC _ [(_, AppT (ConT n) (VarT _))] _)) = digItOut n +-- > dig (GadtC _ _ (AppT (ConT _) (ConT n))) = return [n] +-- > dig _ = error "Unexpected case generating FromIntegral instances" +-- > -- +-- > concat `fmap` mapM dig cons +-- +runQ $ + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + , ''Float128 + ] + + numTypes :: [Name] + numTypes = integralTypes ++ floatingTypes + + thFromIntegral :: Name -> Name -> Q [Dec] + thFromIntegral a b = + [d| instance FromIntegral $(conT a) $(conT b) where + fromIntegral = $(varE $ if a == b then 'id else 'mkFromIntegral ) + + instance KnownNat n => FromIntegral (Vec n $(conT a)) (Vec n $(conT b)) where + fromIntegral = $(varE $ if a == b then 'id else 'mkFromIntegral ) + |] + in + concat <$> sequence [ thFromIntegral from to | from <- integralTypes, to <- numTypes ] diff --git a/src/Data/Array/Accelerate/Classes/Integral.hs b/src/Data/Array/Accelerate/Classes/Integral.hs index c6cadf7cc..110bd721d 100644 --- a/src/Data/Array/Accelerate/Classes/Integral.hs +++ b/src/Data/Array/Accelerate/Classes/Integral.hs @@ -3,6 +3,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MonoLocalBinds #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TemplateHaskell #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | -- Module : Data.Array.Accelerate.Classes.Integral @@ -34,7 +35,9 @@ import Data.Array.Accelerate.Classes.Num import Data.Array.Accelerate.Classes.Ord import Data.Array.Accelerate.Classes.Real () -import Prelude ( error ) +import Control.Monad +import Language.Haskell.TH hiding ( Exp ) +import Prelude hiding ( Enum, Ord, Num, Integral ) import qualified Prelude as P @@ -42,166 +45,35 @@ import qualified Prelude as P -- type Integral a = (Enum a, Ord a, Num a, P.Integral (Exp a)) - -instance P.Integral (Exp Int) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp Int8) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp Int16) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp Int32) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp Int64) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp Word) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp Word8) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp Word16) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp Word32) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp Word64) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp CInt) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp CUInt) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp CLong) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp CULong) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp CLLong) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp CULLong) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp CShort) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp CUShort) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" +runQ $ + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + mkIntegral :: Name -> Q [Dec] + mkIntegral a = + [d| instance P.Integral (Exp $(conT a)) where + quot = mkQuot + rem = mkRem + div = mkIDiv + mod = mkMod + quotRem = mkQuotRem + divMod = mkDivMod + toInteger = P.error "Prelude.toInteger not supported for Accelerate types" + |] + in + concat <$> mapM mkIntegral integralTypes diff --git a/src/Data/Array/Accelerate/Classes/Num.hs b/src/Data/Array/Accelerate/Classes/Num.hs index 942b38673..a236074ff 100644 --- a/src/Data/Array/Accelerate/Classes/Num.hs +++ b/src/Data/Array/Accelerate/Classes/Num.hs @@ -1,7 +1,13 @@ -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UnboxedTuples #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | -- Module : Data.Array.Accelerate.Classes.Num @@ -15,16 +21,19 @@ module Data.Array.Accelerate.Classes.Num ( - Num, + Num, Integer, (P.+), (P.-), (P.*), P.negate, P.abs, P.signum, P.fromInteger, ) where -import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type -import Prelude ( (.) ) +import Language.Haskell.TH hiding ( Exp ) +import Prelude hiding ( Num ) +import qualified Data.Primitive.Vec as Prim import qualified Prelude as P @@ -50,7 +59,6 @@ import qualified Prelude as P -- much, _much_ better. -- - -- | Conversion from an 'Integer'. -- -- An integer literal represents the application of the function 'fromInteger' @@ -61,215 +69,59 @@ import qualified Prelude as P -- fromInteger :: Num a => Integer -> Exp a -- fromInteger = P.fromInteger - -- | Basic numeric class -- type Num a = (Elt a, P.Num (Exp a)) +runQ $ + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + , ''Float128 + ] + + numTypes :: [Name] + numTypes = integralTypes ++ floatingTypes + + thNum :: Name -> Q [Dec] + thNum a = + [d| instance P.Num (Exp $(conT a)) where + (+) = mkAdd + (-) = mkSub + (*) = mkMul + negate = mkNeg + abs = mkAbs + signum = mkSig + fromInteger = constant . P.fromInteger + + instance KnownNat n => P.Num (Exp (Vec n $(conT a))) where + (+) = mkAdd + (-) = mkSub + (*) = mkMul + negate = mkNeg + abs = mkAbs + signum = mkSig + fromInteger = constant . Vec . Prim.splat . P.fromInteger + |] + in + concat <$> mapM thNum numTypes -instance P.Num (Exp Int) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Int8) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Int16) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Int32) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Int64) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Word) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Word8) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Word16) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Word32) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Word64) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp CInt) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp CUInt) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp CLong) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp CULong) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp CLLong) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp CULLong) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp CShort) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp CUShort) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Half) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Float) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Double) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp CFloat) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp CDouble) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger diff --git a/src/Data/Array/Accelerate/Classes/Ord.hs b/src/Data/Array/Accelerate/Classes/Ord.hs index fa80dd9d2..477e1adde 100644 --- a/src/Data/Array/Accelerate/Classes/Ord.hs +++ b/src/Data/Array/Accelerate/Classes/Ord.hs @@ -1,9 +1,11 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RebindableSyntax #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} @@ -28,26 +30,30 @@ module Data.Array.Accelerate.Classes.Ord ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..), BitOrMask ) import Data.Array.Accelerate.Analysis.Match +import Data.Array.Accelerate.Classes.Eq +import Data.Array.Accelerate.Error import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Pattern.Ordering import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Shape +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type +import {-# SOURCE #-} Data.Array.Accelerate.Classes.VOrd --- We must hide (==), as that operator is used for the literals 0, 1 and 2 in the pattern synonyms for Ordering. --- As RebindableSyntax is enabled, a literal pattern is compiled to a call to (==), meaning that the Prelude.(==) should be in scope as (==). -import Data.Array.Accelerate.Classes.Eq hiding ( (==) ) -import qualified Data.Array.Accelerate.Classes.Eq as A - +import Data.Bits import Data.Char import Language.Haskell.TH.Extra hiding ( Exp ) -import Prelude ( ($), (>>=), Ordering(..), Num(..), Maybe(..), String, show, error, unlines, return, concat, map, mapM ) +import Prelude ( ($), Num(..), Ordering(..), Maybe(..), String, show, error, unlines, return, concat, map, mapM ) import Text.Printf import qualified Prelude as P +import GHC.Exts +import GHC.TypeLits + infix 4 < infix 4 > @@ -66,41 +72,23 @@ class Eq a => Ord a where max :: Exp a -> Exp a -> Exp a compare :: Exp a -> Exp a -> Exp Ordering - x < y = if compare x y A.== constant LT then constant True else constant False - x <= y = if compare x y A.== constant GT then constant False else constant True - x > y = if compare x y A.== constant GT then constant True else constant False - x >= y = if compare x y A.== constant LT then constant False else constant True + x < y = cond (compare x y == LT_) True_ False_ + x <= y = cond (compare x y == GT_) False_ True_ + x > y = cond (compare x y == GT_) True_ False_ + x >= y = cond (compare x y == LT_) False_ True_ - min x y = if x <= y then x else y - max x y = if x <= y then y else x + min x y = cond (x <= y) x y + max x y = cond (x <= y) y x - compare x y = - if x A.== y then constant EQ else - if x <= y then constant LT - else constant GT + compare x y + = cond (x == y) EQ_ + $ cond (x <= y) LT_ + {- else -} GT_ --- Local redefinition for use with RebindableSyntax (pulled forward from Prelude.hs) +-- Local redefinition to prevent cyclic imports -- -ifThenElse :: Elt a => Exp Bool -> Exp a -> Exp a -> Exp a -ifThenElse (Exp c) (Exp x) (Exp y) = Exp $ SmartExp $ Cond (mkCoerce' c) x y - -instance Ord () where - (<) _ _ = constant False - (>) _ _ = constant False - (>=) _ _ = constant True - (<=) _ _ = constant True - min _ _ = constant () - max _ _ = constant () - compare _ _ = constant EQ - -instance Ord Z where - (<) _ _ = constant False - (>) _ _ = constant False - (<=) _ _ = constant True - (>=) _ _ = constant True - min _ _ = constant Z - max _ _ = constant Z - +cond :: Elt a => Exp Bool -> Exp a -> Exp a -> Exp a +cond (Exp c) (Exp x) (Exp y) = Exp $ SmartExp $ Cond (mkCoerce' c) x y -- Instances of 'Prelude.Ord' (mostly) don't make sense with the standard -- signatures as the return type is fixed to 'Bool'. This instance is provided @@ -118,7 +106,7 @@ instance Ord a => P.Ord (Exp a) where min = min max = max -preludeError :: String -> String -> a +preludeError :: HasCallStack => String -> String -> a preludeError x y = error $ unlines [ printf "Prelude.%s applied to EDSL types: use Data.Array.Accelerate.%s instead" x y @@ -137,11 +125,13 @@ runQ $ do , ''Int16 , ''Int32 , ''Int64 + , ''Int128 , ''Word , ''Word8 , ''Word16 , ''Word32 , ''Word64 + , ''Word128 ] floatingTypes :: [Name] @@ -149,6 +139,7 @@ runQ $ do [ ''Half , ''Float , ''Double + , ''Float128 ] nonNumTypes :: [Name] @@ -156,23 +147,6 @@ runQ $ do [ ''Char ] - cTypes :: [Name] - cTypes = - [ ''CInt - , ''CUInt - , ''CLong - , ''CULong - , ''CLLong - , ''CULLong - , ''CShort - , ''CUShort - , ''CChar - , ''CUChar - , ''CSChar - , ''CFloat - , ''CDouble - ] - mkPrim :: Name -> Q [Dec] mkPrim t = [d| instance Ord $(conT t) where @@ -186,22 +160,22 @@ runQ $ do mkLt' :: [ExpQ] -> [ExpQ] -> ExpQ mkLt' [x] [y] = [| $x < $y |] - mkLt' (x:xs) (y:ys) = [| $x < $y || ( $x A.== $y && $(mkLt' xs ys) ) |] + mkLt' (x:xs) (y:ys) = [| $x < $y || ( $x == $y && $(mkLt' xs ys) ) |] mkLt' _ _ = error "mkLt'" mkGt' :: [ExpQ] -> [ExpQ] -> ExpQ mkGt' [x] [y] = [| $x > $y |] - mkGt' (x:xs) (y:ys) = [| $x > $y || ( $x A.== $y && $(mkGt' xs ys) ) |] + mkGt' (x:xs) (y:ys) = [| $x > $y || ( $x == $y && $(mkGt' xs ys) ) |] mkGt' _ _ = error "mkGt'" mkLtEq' :: [ExpQ] -> [ExpQ] -> ExpQ - mkLtEq' [x] [y] = [| $x < $y |] - mkLtEq' (x:xs) (y:ys) = [| $x < $y || ( $x A.== $y && $(mkLtEq' xs ys) ) |] + mkLtEq' [x] [y] = [| $x <= $y |] + mkLtEq' (x:xs) (y:ys) = [| $x < $y || ( $x == $y && $(mkLtEq' xs ys) ) |] mkLtEq' _ _ = error "mkLtEq'" mkGtEq' :: [ExpQ] -> [ExpQ] -> ExpQ - mkGtEq' [x] [y] = [| $x > $y |] - mkGtEq' (x:xs) (y:ys) = [| $x > $y || ( $x A.== $y && $(mkGtEq' xs ys) ) |] + mkGtEq' [x] [y] = [| $x >= $y |] + mkGtEq' (x:xs) (y:ys) = [| $x > $y || ( $x == $y && $(mkGtEq' xs ys) ) |] mkGtEq' _ _ = error "mkGtEq'" mkTup :: Int -> Q [Dec] @@ -223,9 +197,27 @@ runQ $ do is <- mapM mkPrim integralTypes fs <- mapM mkPrim floatingTypes ns <- mapM mkPrim nonNumTypes - cs <- mapM mkPrim cTypes ts <- mapM mkTup [2..16] - return $ concat (concat [is,fs,ns,cs,ts]) + return $ concat (concat [is,fs,ns,ts]) + + +instance Ord () where + (<) _ _ = constant False + (>) _ _ = constant False + (>=) _ _ = constant True + (<=) _ _ = constant True + min _ _ = constant () + max _ _ = constant () + compare _ _ = constant EQ + +instance Ord Z where + (<) _ _ = constant False + (>) _ _ = constant False + (<=) _ _ = constant True + (>=) _ _ = constant True + min _ _ = constant Z + max _ _ = constant Z + compare _ _ = constant EQ instance Ord sh => Ord (sh :. Int) where x <= y = indexHead x <= indexHead y && indexTail x <= indexTail y @@ -247,3 +239,30 @@ instance Ord Ordering where min x y = mkCoerce $ min (mkCoerce x) (mkCoerce y :: Exp TAG) max x y = mkCoerce $ max (mkCoerce x) (mkCoerce y :: Exp TAG) +instance VOrd n a => Ord (Vec n a) where + (<) = vcmp (<*) + (>) = vcmp (>*) + (<=) = vcmp (<=*) + (>=) = vcmp (>=*) + +vcmp :: forall n a. KnownNat n + => (Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool)) + -> (Exp (Vec n a) -> Exp (Vec n a) -> Exp Bool) +vcmp op x y = + let n = fromInteger $ natVal' (proxy# :: Proxy# n) + v = op x y + -- + cmp :: forall t. (Elt t, Num t, Bits t, IsScalar (EltR t), IsIntegral (EltR t), BitOrMask (EltR t) ~ Bit) + => Exp (Vec n Bool) + -> Exp Bool + cmp u = + let u' = mkPrimUnary (PrimFromBool bitType integralType) u :: Exp t + in mkEq (constant ((1 `unsafeShiftL` n) - 1)) u' + in + if n P.<= 8 then cmp @Word8 v else + if n P.<= 16 then cmp @Word16 v else + if n P.<= 32 then cmp @Word32 v else + if n P.<= 64 then cmp @Word64 v else + if n P.<= 128 then cmp @Word128 v else + internalError "Can not handle Vec types with more than 128 lanes" + diff --git a/src/Data/Array/Accelerate/Classes/Rational.hs b/src/Data/Array/Accelerate/Classes/Rational.hs index 7ec238459..0a65805c8 100644 --- a/src/Data/Array/Accelerate/Classes/Rational.hs +++ b/src/Data/Array/Accelerate/Classes/Rational.hs @@ -1,4 +1,6 @@ -{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} -- | -- Module : Data.Array.Accelerate.Classes.Rational -- Copyright : [2016..2020] The Accelerate Team @@ -29,7 +31,9 @@ import Data.Array.Accelerate.Classes.Integral import Data.Array.Accelerate.Classes.Num import Data.Array.Accelerate.Classes.Ord import Data.Array.Accelerate.Classes.RealFloat +import Data.Array.Accelerate.Classes.RealFrac +import Data.Kind import Prelude ( ($) ) @@ -40,57 +44,67 @@ import Prelude ( ($) ) -- package. -- class (Num a, Ord a) => Rational a where + type Embedding a :: Type + -- | Convert a number to the quotient of two integers -- - toRational :: (FromIntegral Int64 b, Integral b) => Exp a -> Exp (Ratio b) - -instance Rational Int where toRational = integralToRational -instance Rational Int8 where toRational = integralToRational -instance Rational Int16 where toRational = integralToRational -instance Rational Int32 where toRational = integralToRational -instance Rational Int64 where toRational = integralToRational -instance Rational Word where toRational = integralToRational -instance Rational Word8 where toRational = integralToRational -instance Rational Word16 where toRational = integralToRational -instance Rational Word32 where toRational = integralToRational -instance Rational Word64 where toRational = integralToRational - -instance Rational Half where toRational = floatingToRational -instance Rational Float where toRational = floatingToRational -instance Rational Double where toRational = floatingToRational - + toRational :: (FromIntegral (Embedding a) b, Integral b) => Exp a -> Exp (Ratio b) + +instance Rational Int where type Embedding Int = Int; toRational = integralToRational +instance Rational Int8 where type Embedding Int8 = Int8; toRational = integralToRational +instance Rational Int16 where type Embedding Int16 = Int16; toRational = integralToRational +instance Rational Int32 where type Embedding Int32 = Int32; toRational = integralToRational +instance Rational Int64 where type Embedding Int64 = Int64; toRational = integralToRational +instance Rational Int128 where type Embedding Int128 = Int128; toRational = integralToRational +instance Rational Word where type Embedding Word = Word; toRational = integralToRational +instance Rational Word8 where type Embedding Word8 = Word8; toRational = integralToRational +instance Rational Word16 where type Embedding Word16 = Word16; toRational = integralToRational +instance Rational Word32 where type Embedding Word32 = Word32; toRational = integralToRational +instance Rational Word64 where type Embedding Word64 = Word64; toRational = integralToRational +instance Rational Word128 where type Embedding Word128 = Word128; toRational = integralToRational + +instance Rational Half where type Embedding Half = Int16; toRational = floatingToRational +instance Rational Float where type Embedding Float = Int32; toRational = floatingToRational +instance Rational Double where type Embedding Double = Int64; toRational = floatingToRational +instance Rational Float128 where type Embedding Float128 = Int128; toRational = floatingToRational integralToRational - :: (Integral a, Integral b, FromIntegral a Int64, FromIntegral Int64 b) + :: (Integral a, Integral b, FromIntegral a b) => Exp a -> Exp (Ratio b) -integralToRational x = fromIntegral (fromIntegral x :: Exp Int64) :% 1 +integralToRational x = fromIntegral x :% 1 floatingToRational - :: (RealFloat a, Integral b, FromIntegral Int64 b) + :: (RealFloat a, Integral b, FromIntegral (Significand a) b, FromIntegral Int (Significand a), FiniteBits (Significand a)) => Exp a -> Exp (Ratio b) floatingToRational x = fromIntegral u :% fromIntegral v where - (m, e) = decodeFloat x - (n, d) = elimZeros m (negate e) - u :% v = cond (e >= 0) ((m `shiftL` e) :% 1) $ - cond (m .&. 1 == 0) (n :% shiftL 1 d) $ - (m :% shiftL 1 (negate e)) + T2 m e = decodeFloat x + T2 n d = elimZeros m ne' + ne' = negate e' + e' = fromIntegral e + u :% v = cond (e' >= 0) ((m `shiftL` e') :% 1) $ + cond (m .&. 1 == 0) (n :% shiftL 1 d) $ + (m :% shiftL 1 ne') -- Stolen from GHC.Float.ConversionUtils -- Double mantissa have 53 bits, which fits in an Int64 -- -elimZeros :: Exp Int64 -> Exp Int -> (Exp Int64, Exp Int) -elimZeros x y = (u, v) +elimZeros + :: forall e. (Num e, Ord e, FiniteBits e) + => Exp e + -> Exp e + -> Exp (e, e) +elimZeros x y = T2 u v where T3 _ u v = while (\(T3 p _ _) -> p) elim (T3 moar x y) kthxbai = constant False moar = constant True - elim :: Exp (Bool, Int64, Int) -> Exp (Bool, Int64, Int) + elim :: Exp (Bool, e, e) -> Exp (Bool, e, e) elim (T3 _ n e) = - let t = countTrailingZeros (fromIntegral n :: Exp Word8) + let t = countTrailingZeros n in cond (e <= t) (T3 kthxbai (shiftR n e) 0) $ cond (t < 8) (T3 kthxbai (shiftR n t) (e-t)) $ diff --git a/src/Data/Array/Accelerate/Classes/Real.hs b/src/Data/Array/Accelerate/Classes/Real.hs index a6bd1b185..aa38158cf 100644 --- a/src/Data/Array/Accelerate/Classes/Real.hs +++ b/src/Data/Array/Accelerate/Classes/Real.hs @@ -34,8 +34,8 @@ type Real a = (Num a, Ord a, P.Real (Exp a)) -- Instances of 'Real' don't make sense in Accelerate at the moment. These are -- only provided to fulfil superclass constraints; e.g. Integral. -- --- We won't need `toRational' until we support rational numbers in AP --- computations. +-- We won't need `toRational' until we support rational numbers in scalar +-- expressions. -- instance (Num a, Ord a) => P.Real (Exp a) where toRational diff --git a/src/Data/Array/Accelerate/Classes/RealFloat.hs b/src/Data/Array/Accelerate/Classes/RealFloat.hs index 0b3366ec2..4d11059c0 100644 --- a/src/Data/Array/Accelerate/Classes/RealFloat.hs +++ b/src/Data/Array/Accelerate/Classes/RealFloat.hs @@ -9,6 +9,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fno-warn-orphans #-} +{-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Classes.RealFloat -- Copyright : [2016..2020] The Accelerate Team @@ -22,11 +23,12 @@ module Data.Array.Accelerate.Classes.RealFloat ( RealFloat(..), + defaultProperFraction, ) where import Data.Array.Accelerate.Error -import Data.Array.Accelerate.Language ( cond, while ) +import Data.Array.Accelerate.Language ( (^), cond, while ) import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Type @@ -36,14 +38,15 @@ import Data.Array.Accelerate.Data.Bits import Data.Array.Accelerate.Classes.Eq import Data.Array.Accelerate.Classes.Floating import Data.Array.Accelerate.Classes.FromIntegral +import Data.Array.Accelerate.Classes.Integral import Data.Array.Accelerate.Classes.Num import Data.Array.Accelerate.Classes.Ord import Data.Array.Accelerate.Classes.RealFrac import Data.Text.Lazy.Builder import Formatting -import Text.Printf import Prelude ( (.), ($), String, error, undefined, unlines, otherwise ) +import Text.Printf import qualified Prelude as P @@ -52,8 +55,8 @@ import qualified Prelude as P -- class (RealFrac a, Floating a) => RealFloat a where -- | The radix of the representation (often 2) (constant) - floatRadix :: Exp a -> Exp Int64 -- Integer - default floatRadix :: P.RealFloat a => Exp a -> Exp Int64 + floatRadix :: Exp a -> Exp Int -- Integer + default floatRadix :: P.RealFloat a => Exp a -> Exp Int floatRadix _ = P.fromInteger (P.floatRadix (undefined::a)) -- | The number of digits of 'floatRadix' in the significand (constant) @@ -62,45 +65,40 @@ class (RealFrac a, Floating a) => RealFloat a where floatDigits _ = constant (P.floatDigits (undefined::a)) -- | The lowest and highest values the exponent may assume (constant) - floatRange :: Exp a -> (Exp Int, Exp Int) - default floatRange :: P.RealFloat a => Exp a -> (Exp Int, Exp Int) - floatRange _ = let (m,n) = P.floatRange (undefined::a) - in (constant m, constant n) + floatRange :: Exp a -> Exp (Int, Int) + default floatRange :: P.RealFloat a => Exp a -> Exp (Int, Int) + floatRange _ = constant $ P.floatRange (undefined::a) -- | Return the significand and an appropriately scaled exponent. If -- @(m,n) = 'decodeFloat' x@ then @x = m*b^^n@, where @b@ is the -- floating-point radix ('floatRadix'). Furthermore, either @m@ and @n@ are -- both zero, or @b^(d-1) <= 'abs' m < b^d@, where @d = 'floatDigits' x@. - decodeFloat :: Exp a -> (Exp Int64, Exp Int) -- Integer + decodeFloat :: Exp a -> Exp (Significand a, Int) -- | Inverse of 'decodeFloat' - encodeFloat :: Exp Int64 -> Exp Int -> Exp a -- Integer - default encodeFloat :: (FromIntegral Int a, FromIntegral Int64 a) => Exp Int64 -> Exp Int -> Exp a + encodeFloat :: Exp (Significand a) -> Exp Int -> Exp a + default encodeFloat :: (FromIntegral Int a, FromIntegral (Significand a) a) => Exp (Significand a) -> Exp Int -> Exp a encodeFloat x e = fromIntegral x * (fromIntegral (floatRadix (undefined :: Exp a)) ** fromIntegral e) -- | Corresponds to the second component of 'decodeFloat' exponent :: Exp a -> Exp Int - exponent x = let (m,n) = decodeFloat x - in cond (m == 0) - 0 - (n + floatDigits x) + exponent x = let T2 m n = decodeFloat x + in cond (m == 0) 0 (n + floatDigits x) -- | Corresponds to the first component of 'decodeFloat' significand :: Exp a -> Exp a - significand x = let (m,_) = decodeFloat x + significand x = let T2 m _ = decodeFloat x in encodeFloat m (negate (floatDigits x)) -- | Multiply a floating point number by an integer power of the radix scaleFloat :: Exp Int -> Exp a -> Exp a - scaleFloat k x = - cond (k == 0 || isFix) x - $ encodeFloat m (n + clamp b) + scaleFloat k x = cond (k == 0 || isFix) x (encodeFloat m (n + clamp b)) where - isFix = x == 0 || isNaN x || isInfinite x - (m,n) = decodeFloat x - (l,h) = floatRange x - d = floatDigits x - b = h - l + 4*d + isFix = x == 0 || isNaN x || isInfinite x + T2 m n = decodeFloat x + T2 l h = floatRange x + d = floatDigits x + b = h - l + 4*d -- n+k may overflow, which would lead to incorrect results, hence we clamp -- the scaling parameter. If (n+k) would be larger than h, (n + clamp b k) -- must be too, similar for smaller than (l-d). @@ -131,14 +129,29 @@ class (RealFrac a, Floating a) => RealFloat a where atan2 :: Exp a -> Exp a -> Exp a +instance RealFrac Half where + type Significand Half = Int16 + properFraction = defaultProperFraction + +instance RealFrac Float where + type Significand Float = Int32 + properFraction = defaultProperFraction + +instance RealFrac Double where + type Significand Double = Int64 + properFraction = defaultProperFraction + +instance RealFrac Float128 where + type Significand Float128 = Int128 + properFraction = defaultProperFraction + instance RealFloat Half where atan2 = mkAtan2 isNaN = mkIsNaN isInfinite = mkIsInfinite isDenormalized = ieee754 "isDenormalized" (ieee754_f16_is_denormalized . mkBitcast) isNegativeZero = ieee754 "isNegativeZero" (ieee754_f16_is_negative_zero . mkBitcast) - decodeFloat = ieee754 "decodeFloat" (\x -> let T2 m n = ieee754_f16_decode (mkBitcast x) - in (fromIntegral m, n)) + decodeFloat = ieee754 "decodeFloat" (ieee754_f16_decode . mkBitcast) instance RealFloat Float where atan2 = mkAtan2 @@ -146,8 +159,7 @@ instance RealFloat Float where isInfinite = mkIsInfinite isDenormalized = ieee754 "isDenormalized" (ieee754_f32_is_denormalized . mkBitcast) isNegativeZero = ieee754 "isNegativeZero" (ieee754_f32_is_negative_zero . mkBitcast) - decodeFloat = ieee754 "decodeFloat" (\x -> let T2 m n = ieee754_f32_decode (mkBitcast x) - in (fromIntegral m, n)) + decodeFloat = ieee754 "decodeFloat" (ieee754_f32_decode . mkBitcast) instance RealFloat Double where atan2 = mkAtan2 @@ -155,28 +167,15 @@ instance RealFloat Double where isInfinite = mkIsInfinite isDenormalized = ieee754 "isDenormalized" (ieee754_f64_is_denormalized . mkBitcast) isNegativeZero = ieee754 "isNegativeZero" (ieee754_f64_is_negative_zero . mkBitcast) - decodeFloat = ieee754 "decodeFloat" (\x -> let T2 m n = ieee754_f64_decode (mkBitcast x) - in (m, n)) - -instance RealFloat CFloat where - atan2 = mkAtan2 - isNaN = mkIsNaN . mkBitcast @Float - isInfinite = mkIsInfinite . mkBitcast @Float - isDenormalized = ieee754 "isDenormalized" (ieee754_f32_is_denormalized . mkBitcast) - isNegativeZero = ieee754 "isNegativeZero" (ieee754_f32_is_negative_zero . mkBitcast) - decodeFloat = ieee754 "decodeFloat" (\x -> let T2 m n = ieee754_f32_decode (mkBitcast x) - in (fromIntegral m, n)) - encodeFloat x e = mkBitcast (encodeFloat @Float x e) + decodeFloat = ieee754 "decodeFloat" (ieee754_f64_decode . mkBitcast) -instance RealFloat CDouble where +instance RealFloat Float128 where atan2 = mkAtan2 - isNaN = mkIsNaN . mkBitcast @Double - isInfinite = mkIsInfinite . mkBitcast @Double - isDenormalized = ieee754 "isDenormalized" (ieee754_f64_is_denormalized . mkBitcast) - isNegativeZero = ieee754 "isNegativeZero" (ieee754_f64_is_negative_zero . mkBitcast) - decodeFloat = ieee754 "decodeFloat" (\x -> let T2 m n = ieee754_f64_decode (mkBitcast x) - in (m, n)) - encodeFloat x e = mkBitcast (encodeFloat @Double x e) + isNaN = mkIsNaN + isInfinite = mkIsInfinite + isDenormalized = ieee754 "isDenormalized" (ieee754_f128_is_denormalized . mkBitcast) + isNegativeZero = ieee754 "isNegativeZero" (ieee754_f128_is_negative_zero . mkBitcast) + decodeFloat = ieee754 "decodeFloat" (ieee754_f128_decode . mkBitcast) -- To satisfy superclass constraints @@ -202,12 +201,42 @@ preludeError x , "constraints for subsequent classes in the standard Haskell numeric hierarchy." ] - ieee754 :: forall a b. HasCallStack => P.RealFloat a => Builder -> (Exp a -> b) -> Exp a -> b ieee754 name f x | P.isIEEE (undefined::a) = f x | otherwise = internalError (builder % ": Not implemented for non-IEEE floating point") name + +-- Must test for ±0.0 to avoid returning -0.0 in the second component of the +-- pair. Unfortunately the branching costs a lot of performance. +-- +-- Orphaned from RealFrac module +-- +-- defaultProperFraction +-- :: (ToFloating b a, RealFrac a, IsIntegral b, Num b, Floating a) +-- => Exp a +-- -> (Exp b, Exp a) +-- defaultProperFraction x = +-- unlift $ Exp +-- $ Cond (x == 0) (tup2 (0, 0)) +-- (tup2 (n, f)) +-- where +-- n = truncate x +-- f = x - toFloating n + +defaultProperFraction + :: (RealFloat a, FromIntegral (Significand a) b, Integral b) + => Exp a + -> Exp (b, a) +defaultProperFraction x = + cond (n >= 0) + (T2 (fromIntegral m * (2 ^ n)) 0.0) + (T2 (fromIntegral q) (encodeFloat r n)) + where + T2 m n = decodeFloat x + (q, r) = quotRem m (2 ^ (negate n)) + + -- From: ghc/libraries/base/cbits/primFloat.c -- ------------------------------------------ @@ -216,6 +245,11 @@ ieee754 name f x -- * mantissa is non-zero. -- * (don't care about setting of sign bit.) -- +ieee754_f128_is_denormalized :: Exp Word128 -> Exp Bool +ieee754_f128_is_denormalized x = + ieee754_f128_mantissa x == 0 && + ieee754_f128_exponent x /= 0 + ieee754_f64_is_denormalized :: Exp Word64 -> Exp Bool ieee754_f64_is_denormalized x = ieee754_f64_mantissa x == 0 && @@ -233,6 +267,12 @@ ieee754_f16_is_denormalized x = -- Negative zero if only the sign bit is set -- +ieee754_f128_is_negative_zero :: Exp Word128 -> Exp Bool +ieee754_f128_is_negative_zero x = + ieee754_f128_negative x && + ieee754_f128_exponent x == 0 && + ieee754_f128_mantissa x == 0 + ieee754_f64_is_negative_zero :: Exp Word64 -> Exp Bool ieee754_f64_is_negative_zero x = ieee754_f64_negative x && @@ -255,9 +295,24 @@ ieee754_f16_is_negative_zero x = -- Assume the host processor stores integers and floating point numbers in the -- same endianness (true for modern processors). -- --- To recap, here's the representation of a double precision +-- To recap, here's the representation of a quadruple precision -- IEEE floating point number: -- +-- sign 127 sign bit (0==positive, 1==negative) +-- exponent 126-112 exponent (biased by 16383) +-- fraction 111-0 fraction (bits to right of binary part) +-- +ieee754_f128_mantissa :: Exp Word128 -> Exp Word128 +ieee754_f128_mantissa x = x .&. 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFF + +ieee754_f128_exponent :: Exp Word128 -> Exp Word16 +ieee754_f128_exponent x = fromIntegral (x `unsafeShiftR` 112) .&. 0x7FFF + +ieee754_f128_negative :: Exp Word128 -> Exp Bool +ieee754_f128_negative x = testBit x 127 + +-- Representation of a double precision IEEE floating point number: +-- -- sign 63 sign bit (0==positive, 1==negative) -- exponent 62-52 exponent (biased by 1023) -- fraction 51-0 fraction (bits to right of binary point) @@ -271,7 +326,7 @@ ieee754_f64_exponent x = fromIntegral (x `unsafeShiftR` 52) .&. 0x7FF ieee754_f64_negative :: Exp Word64 -> Exp Bool ieee754_f64_negative x = testBit x 63 --- Representation of single precision IEEE floating point number: +-- Representation of a single precision IEEE floating point number: -- -- sign 31 sign bit (0==positive, 1==negative) -- exponent 30-23 exponent (biased by 127) @@ -286,7 +341,7 @@ ieee754_f32_exponent x = fromIntegral (x `unsafeShiftR` 23) ieee754_f32_negative :: Exp Word32 -> Exp Bool ieee754_f32_negative x = testBit x 31 --- Representation of half precision IEEE floating point number: +-- Representation of a half precision IEEE floating point number: -- -- sign 15 sign bit (0==positive, 1==negative) -- exponent 14-10 exponent (biased by 15) @@ -311,7 +366,7 @@ ieee754_f16_decode i = _HMSBIT = 0x8000 _HMINEXP = ((_HALF_MIN_EXP) - (_HALF_MANT_DIG) - 1) _HALF_MANT_DIG = floatDigits (undefined::Exp Half) - (_HALF_MIN_EXP, _HALF_MAX_EXP) = floatRange (undefined::Exp Half) + T2 _HALF_MIN_EXP _HALF_MAX_EXP = floatRange (undefined::Exp Half) high1 = fromIntegral i high2 = high1 .&. (_HHIGHBIT - 1) @@ -345,7 +400,7 @@ ieee754_f32_decode i = _FMSBIT = 0x80000000 _FMINEXP = ((_FLT_MIN_EXP) - (_FLT_MANT_DIG) - 1) _FLT_MANT_DIG = floatDigits (undefined::Exp Float) - (_FLT_MIN_EXP, _FLT_MAX_EXP) = floatRange (undefined::Exp Float) + T2 _FLT_MIN_EXP _FLT_MAX_EXP = floatRange (undefined::Exp Float) high1 = fromIntegral i high2 = high1 .&. (_FHIGHBIT - 1) @@ -381,7 +436,7 @@ ieee754_f64_decode2 i = _DMSBIT = 0x80000000 _DMINEXP = ((_DBL_MIN_EXP) - (_DBL_MANT_DIG) - 1) _DBL_MANT_DIG = floatDigits (undefined::Exp Double) - (_DBL_MIN_EXP, _DBL_MAX_EXP) = floatRange (undefined::Exp Double) + T2 _DBL_MIN_EXP _DBL_MAX_EXP = floatRange (undefined::Exp Double) low = fromIntegral i high = fromIntegral (i `unsafeShiftR` 32) @@ -409,3 +464,6 @@ ieee754_f64_decode2 i = (T4 1 0 0 0) (T4 sign hi lo ie) +ieee754_f128_decode :: Exp Word128 -> Exp (Int128, Int) +ieee754_f128_decode = undefined + diff --git a/src/Data/Array/Accelerate/Classes/RealFloat.hs-boot b/src/Data/Array/Accelerate/Classes/RealFloat.hs-boot deleted file mode 100644 index a4f2878bd..000000000 --- a/src/Data/Array/Accelerate/Classes/RealFloat.hs-boot +++ /dev/null @@ -1,67 +0,0 @@ -{-# LANGUAGE DefaultSignatures #-} -{-# LANGUAGE FlexibleContexts #-} --- | --- Module : Data.Array.Accelerate.Classes.RealFloat --- Copyright : [2019..2020] The Accelerate Team --- License : BSD3 --- --- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> --- Stability : experimental --- Portability : non-portable (GHC extensions) --- - -module Data.Array.Accelerate.Classes.RealFloat - where - -import Data.Array.Accelerate.Smart -import Data.Array.Accelerate.Type - -import Data.Array.Accelerate.Classes.Floating -import Data.Array.Accelerate.Classes.FromIntegral -import {-# SOURCE #-} Data.Array.Accelerate.Classes.RealFrac - -import Prelude ( Bool ) -import qualified Prelude as P - - -class (RealFrac a, Floating a) => RealFloat a where - floatRadix :: Exp a -> Exp Int64 -- Integer - floatDigits :: Exp a -> Exp Int - floatRange :: Exp a -> (Exp Int, Exp Int) - decodeFloat :: Exp a -> (Exp Int64, Exp Int) -- Integer - encodeFloat :: Exp Int64 -> Exp Int -> Exp a -- Integer - exponent :: Exp a -> Exp Int - significand :: Exp a -> Exp a - scaleFloat :: Exp Int -> Exp a -> Exp a - isNaN :: Exp a -> Exp Bool - isInfinite :: Exp a -> Exp Bool - isDenormalized :: Exp a -> Exp Bool - isNegativeZero :: Exp a -> Exp Bool - isIEEE :: Exp a -> Exp Bool - atan2 :: Exp a -> Exp a -> Exp a - - exponent = P.undefined - significand = P.undefined - scaleFloat = P.undefined - - default floatRadix :: P.RealFloat a => Exp a -> Exp Int64 - floatRadix _ = P.undefined - - default floatDigits :: P.RealFloat a => Exp a -> Exp Int - floatDigits _ = P.undefined - - default floatRange :: P.RealFloat a => Exp a -> (Exp Int, Exp Int) - floatRange _ = P.undefined - - default encodeFloat :: (FromIntegral Int a, FromIntegral Int64 a) => Exp Int64 -> Exp Int -> Exp a - encodeFloat _ _ = P.undefined - - default isIEEE :: P.RealFloat a => Exp a -> Exp Bool - isIEEE _ = P.undefined - -instance RealFloat Half -instance RealFloat Float -instance RealFloat Double -instance RealFloat CFloat -instance RealFloat CDouble - diff --git a/src/Data/Array/Accelerate/Classes/RealFrac.hs b/src/Data/Array/Accelerate/Classes/RealFrac.hs index 9a12e5029..b3701a5af 100644 --- a/src/Data/Array/Accelerate/Classes/RealFrac.hs +++ b/src/Data/Array/Accelerate/Classes/RealFrac.hs @@ -3,6 +3,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# OPTIONS_GHC -fno-warn-orphans #-} @@ -23,8 +24,7 @@ module Data.Array.Accelerate.Classes.RealFrac ( ) where -import Data.Array.Accelerate.Language ( (^), cond, even ) -import Data.Array.Accelerate.Lift ( unlift ) +import Data.Array.Accelerate.Language ( cond, even ) import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Smart @@ -32,38 +32,41 @@ import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Type import Data.Array.Accelerate.Classes.Eq -import Data.Array.Accelerate.Classes.Ord import Data.Array.Accelerate.Classes.Floating import Data.Array.Accelerate.Classes.Fractional import Data.Array.Accelerate.Classes.FromIntegral import Data.Array.Accelerate.Classes.Integral import Data.Array.Accelerate.Classes.Num +import Data.Array.Accelerate.Classes.Ord import Data.Array.Accelerate.Classes.ToFloating -import {-# SOURCE #-} Data.Array.Accelerate.Classes.RealFloat -- defaultProperFraction import Data.Maybe +import Data.Kind +import Prelude ( ($), String, error, otherwise, unlines ) import Text.Printf -import Prelude ( ($), String, error, unlines, otherwise ) import qualified Prelude as P -- | Generalisation of 'P.div' to any instance of 'RealFrac' -- -div' :: (RealFrac a, FromIntegral Int64 b, Integral b) => Exp a -> Exp a -> Exp b +div' :: (RealFrac a, FromIntegral (Significand a) b, Integral b) => Exp a -> Exp a -> Exp b div' n d = floor (n / d) -- | Generalisation of 'P.mod' to any instance of 'RealFrac' -- -mod' :: (Floating a, RealFrac a, ToFloating Int64 a) => Exp a -> Exp a -> Exp a +mod' :: forall a. (Floating a, RealFrac a, Integral (Significand a), ToFloating (Significand a) a, FromIntegral (Significand a) (Significand a)) + => Exp a + -> Exp a + -> Exp a mod' n d = n - (toFloating f) * d where - f :: Exp Int64 + f :: Exp (Significand a) f = div' n d -- | Generalisation of 'P.divMod' to any instance of 'RealFrac' -- divMod' - :: (Floating a, RealFrac a, Integral b, FromIntegral Int64 b, ToFloating b a) + :: (Floating a, RealFrac a, Integral b, FromIntegral (Significand a) b, ToFloating b a) => Exp a -> Exp a -> (Exp b, Exp a) @@ -74,7 +77,15 @@ divMod' n d = (f, n - (toFloating f) * d) -- | Extracting components of fractions. -- -class (Ord a, Fractional a) => RealFrac a where +class (Ord a, Fractional a, Integral (Significand a)) => RealFrac a where + -- | The significand (also known as the mantissa) is the part of a number in + -- floating point representation consisting of the significant digits. + -- Generally speaking, this is the integral part of a fractional number. + -- + type Significand a :: Type + + {-# MINIMAL properFraction #-} + -- | The function 'properFraction' takes a real fractional number @x@ and -- returns a pair @(n,f)@ such that @x = n+f@, and: -- @@ -85,7 +96,7 @@ class (Ord a, Fractional a) => RealFrac a where -- -- The default definitions of the 'ceiling', 'floor', 'truncate' -- and 'round' functions are in terms of 'properFraction'. - properFraction :: (Integral b, FromIntegral Int64 b) => Exp a -> (Exp b, Exp a) + properFraction :: (FromIntegral (Significand a) b, Integral b) => Exp a -> Exp (b, a) -- The function 'splitFraction' takes a real fractional number @x@ and -- returns a pair @(n,f)@ such that @x = n+f@, and: @@ -104,109 +115,58 @@ class (Ord a, Fractional a) => RealFrac a where -- splitFraction / fraction are from numeric-prelude Algebra.RealRing -- | @truncate x@ returns the integer nearest @x@ between zero and @x@ - truncate :: (Integral b, FromIntegral Int64 b) => Exp a -> Exp b + truncate :: (Integral b, FromIntegral (Significand a) b) => Exp a -> Exp b truncate = defaultTruncate -- | @'round' x@ returns the nearest integer to @x@; the even integer if @x@ -- is equidistant between two integers - round :: (Integral b, FromIntegral Int64 b) => Exp a -> Exp b - round = defaultRound + round :: (Integral b, FromIntegral (Significand a) b) => Exp a -> Exp b + round = defaultRound -- | @'ceiling' x@ returns the least integer not less than @x@ - ceiling :: (Integral b, FromIntegral Int64 b) => Exp a -> Exp b - ceiling = defaultCeiling + ceiling :: (Integral b, FromIntegral (Significand a) b) => Exp a -> Exp b + ceiling = defaultCeiling -- | @'floor' x@ returns the greatest integer not greater than @x@ - floor :: (Integral b, FromIntegral Int64 b) => Exp a -> Exp b - floor = defaultFloor - -instance RealFrac Half where - properFraction = defaultProperFraction - -instance RealFrac Float where - properFraction = defaultProperFraction + floor :: (Integral b, FromIntegral (Significand a) b) => Exp a -> Exp b + floor = defaultFloor -instance RealFrac Double where - properFraction = defaultProperFraction -instance RealFrac CFloat where - properFraction = defaultProperFraction - truncate = defaultTruncate - round = defaultRound - ceiling = defaultCeiling - floor = defaultFloor - -instance RealFrac CDouble where - properFraction = defaultProperFraction - truncate = defaultTruncate - round = defaultRound - ceiling = defaultCeiling - floor = defaultFloor - - --- Must test for ±0.0 to avoid returning -0.0 in the second component of the --- pair. Unfortunately the branching costs a lot of performance. --- --- defaultProperFraction --- :: (ToFloating b a, RealFrac a, IsIntegral b, Num b, Floating a) --- => Exp a --- -> (Exp b, Exp a) --- defaultProperFraction x = --- unlift $ Exp --- $ Cond (x == 0) (tup2 (0, 0)) --- (tup2 (n, f)) --- where --- n = truncate x --- f = x - toFloating n - -defaultProperFraction - :: (RealFloat a, FromIntegral Int64 b, Integral b) - => Exp a - -> (Exp b, Exp a) -defaultProperFraction x - = unlift - $ cond (n >= 0) - (T2 (fromIntegral m * (2 ^ n)) 0.0) - (T2 (fromIntegral q) (encodeFloat r n)) - where - (m, n) = decodeFloat x - (q, r) = quotRem m (2 ^ (negate n)) - -defaultTruncate :: forall a b. (RealFrac a, Integral b, FromIntegral Int64 b) => Exp a -> Exp b +defaultTruncate :: forall a b. (RealFrac a, Integral b, FromIntegral (Significand a) b) => Exp a -> Exp b defaultTruncate x - | Just IsFloatingDict <- isFloating @a - , Just IsIntegralDict <- isIntegral @b + | Just FloatingDict <- floatingDict @a + , Just IntegralDict <- integralDict @b = mkTruncate x -- | otherwise - = let (n, _) = properFraction x in n + = let T2 n _ = properFraction x in n -defaultCeiling :: forall a b. (RealFrac a, Integral b, FromIntegral Int64 b) => Exp a -> Exp b +defaultCeiling :: forall a b. (RealFrac a, Integral b, FromIntegral (Significand a) b) => Exp a -> Exp b defaultCeiling x - | Just IsFloatingDict <- isFloating @a - , Just IsIntegralDict <- isIntegral @b + | Just FloatingDict <- floatingDict @a + , Just IntegralDict <- integralDict @b = mkCeiling x -- | otherwise - = let (n, r) = properFraction x in cond (r > 0) (n+1) n + = let T2 n r = properFraction x in cond (r > 0) (n+1) n -defaultFloor :: forall a b. (RealFrac a, Integral b, FromIntegral Int64 b) => Exp a -> Exp b +defaultFloor :: forall a b. (RealFrac a, Integral b, FromIntegral (Significand a) b) => Exp a -> Exp b defaultFloor x - | Just IsFloatingDict <- isFloating @a - , Just IsIntegralDict <- isIntegral @b - = mkFloor x + | Just FloatingDict <- floatingDict @a + , Just IntegralDict <- integralDict @b + = mkCeiling x -- | otherwise - = let (n, r) = properFraction x in cond (r < 0) (n-1) n + = let T2 n r = properFraction x in cond (r < 0) (n-1) n -defaultRound :: forall a b. (RealFrac a, Integral b, FromIntegral Int64 b) => Exp a -> Exp b +defaultRound :: forall a b. (RealFrac a, Integral b, FromIntegral (Significand a) b) => Exp a -> Exp b defaultRound x - | Just IsFloatingDict <- isFloating @a - , Just IsIntegralDict <- isIntegral @b - = mkRound x + | Just FloatingDict <- floatingDict @a + , Just IntegralDict <- integralDict @b + = mkCeiling x -- | otherwise - = let (n, r) = properFraction x + = let T2 n r = properFraction x m = cond (r < 0.0) (n-1) (n+1) half_down = abs r - 0.5 p = compare half_down 0.0 @@ -216,46 +176,81 @@ defaultRound x {- otherwise -} m -data IsFloatingDict a where - IsFloatingDict :: IsFloating a => IsFloatingDict a +data FloatingDict a where + FloatingDict :: IsFloating a => FloatingDict a -data IsIntegralDict a where - IsIntegralDict :: IsIntegral a => IsIntegralDict a +data IntegralDict a where + IntegralDict :: IsIntegral a => IntegralDict a -isFloating :: forall a. Elt a => Maybe (IsFloatingDict (EltR a)) -isFloating - | TupRsingle t <- eltR @a - , SingleScalarType s <- t - , NumSingleType n <- s - , FloatingNumType f <- n - = case f of - TypeHalf{} -> Just IsFloatingDict - TypeFloat{} -> Just IsFloatingDict - TypeDouble{} -> Just IsFloatingDict - -- - | otherwise - = Nothing - -isIntegral :: forall a. Elt a => Maybe (IsIntegralDict (EltR a)) -isIntegral - | TupRsingle t <- eltR @a - , SingleScalarType s <- t - , NumSingleType n <- s - , IntegralNumType i <- n - = case i of - TypeInt{} -> Just IsIntegralDict - TypeInt8{} -> Just IsIntegralDict - TypeInt16{} -> Just IsIntegralDict - TypeInt32{} -> Just IsIntegralDict - TypeInt64{} -> Just IsIntegralDict - TypeWord{} -> Just IsIntegralDict - TypeWord8{} -> Just IsIntegralDict - TypeWord16{} -> Just IsIntegralDict - TypeWord32{} -> Just IsIntegralDict - TypeWord64{} -> Just IsIntegralDict - -- - | otherwise - = Nothing +floatingDict :: forall a. Elt a => Maybe (FloatingDict (EltR a)) +floatingDict = go (eltR @a) + where + go :: TypeR t -> Maybe (FloatingDict t) + go (TupRsingle t) = scalar t + go _ = Nothing + + scalar :: ScalarType t -> Maybe (FloatingDict t) + scalar (NumScalarType t) = num t + scalar _ = Nothing + + num :: NumType t -> Maybe (FloatingDict t) + num (FloatingNumType t) = floating t + num _ = Nothing + + floating :: forall t. FloatingType t -> Maybe (FloatingDict t) + floating (SingleFloatingType t) = + case t of + TypeFloat16 -> Just FloatingDict + TypeFloat32 -> Just FloatingDict + TypeFloat64 -> Just FloatingDict + TypeFloat128 -> Just FloatingDict + floating (VectorFloatingType _ t) = + case t of + TypeFloat16 -> Just FloatingDict + TypeFloat32 -> Just FloatingDict + TypeFloat64 -> Just FloatingDict + TypeFloat128 -> Just FloatingDict + +integralDict :: forall a. Elt a => Maybe (IntegralDict (EltR a)) +integralDict = go (eltR @a) + where + go :: TypeR t -> Maybe (IntegralDict t) + go (TupRsingle t) = scalar t + go _ = Nothing + + scalar :: ScalarType t -> Maybe (IntegralDict t) + scalar (NumScalarType t) = num t + scalar _ = Nothing + + num :: NumType t -> Maybe (IntegralDict t) + num (IntegralNumType t) = integral t + num _ = Nothing + + integral :: forall t. IntegralType t -> Maybe (IntegralDict t) + integral (SingleIntegralType t) = + case t of + TypeInt8 -> Just IntegralDict + TypeInt16 -> Just IntegralDict + TypeInt32 -> Just IntegralDict + TypeInt64 -> Just IntegralDict + TypeInt128 -> Just IntegralDict + TypeWord8 -> Just IntegralDict + TypeWord16 -> Just IntegralDict + TypeWord32 -> Just IntegralDict + TypeWord64 -> Just IntegralDict + TypeWord128 -> Just IntegralDict + integral (VectorIntegralType _ t) = + case t of + TypeInt8 -> Just IntegralDict + TypeInt16 -> Just IntegralDict + TypeInt32 -> Just IntegralDict + TypeInt64 -> Just IntegralDict + TypeInt128 -> Just IntegralDict + TypeWord8 -> Just IntegralDict + TypeWord16 -> Just IntegralDict + TypeWord32 -> Just IntegralDict + TypeWord64 -> Just IntegralDict + TypeWord128 -> Just IntegralDict -- To satisfy superclass constraints @@ -275,3 +270,7 @@ preludeError x , "These Prelude.RealFrac instances are present only to fulfil superclass" , "constraints for subsequent classes in the standard Haskell numeric hierarchy." ] + +-- Instances declared in Data.Array.Accelerate.Classes.RealFloat to avoid +-- recursive modules + diff --git a/src/Data/Array/Accelerate/Classes/RealFrac.hs-boot b/src/Data/Array/Accelerate/Classes/RealFrac.hs-boot deleted file mode 100644 index 0c2fa7307..000000000 --- a/src/Data/Array/Accelerate/Classes/RealFrac.hs-boot +++ /dev/null @@ -1,24 +0,0 @@ -{-# LANGUAGE NoImplicitPrelude #-} --- | --- Module : Data.Array.Accelerate.Classes.RealFrac --- Copyright : [2019..2020] The Accelerate Team --- License : BSD3 --- --- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> --- Stability : experimental --- Portability : non-portable (GHC extensions) --- - -module Data.Array.Accelerate.Classes.RealFrac - where - -import Data.Array.Accelerate.Type - -class RealFrac a - -instance RealFrac Half -instance RealFrac Float -instance RealFrac Double -instance RealFrac CFloat -instance RealFrac CDouble - diff --git a/src/Data/Array/Accelerate/Classes/ToFloating.hs b/src/Data/Array/Accelerate/Classes/ToFloating.hs index c3f4545c6..1d45bb525 100644 --- a/src/Data/Array/Accelerate/Classes/ToFloating.hs +++ b/src/Data/Array/Accelerate/Classes/ToFloating.hs @@ -20,6 +20,7 @@ module Data.Array.Accelerate.Classes.ToFloating ( ) where import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type import Data.Array.Accelerate.Classes.Floating @@ -42,39 +43,46 @@ class ToFloating a b where -- instance (Elt a, Elt b, IsNum a, IsFloating b) => ToFloating a b where -- toFloating = mkToFloating + -- Generate standard instances explicitly. See also: 'FromIntegral'. -- -$(runQ $ do - let - -- Get all the types that our dictionaries reify - digItOut :: Name -> Q [Name] - digItOut name = do - TyConI (DataD _ _ _ _ cons _) <- reify name - let - -- This is what a constructor such as IntegralNumType will be reified - -- as prior to GHC 8.4... - dig (NormalC _ [(_, AppT (ConT n) (VarT _))]) = digItOut n - -- ...but this is what IntegralNumType will be reified as on GHC 8.4 - -- and later, after the changes described in - -- https://ghc.haskell.org/trac/ghc/wiki/Migration/8.4#TemplateHaskellreificationchangesforGADTs - dig (ForallC _ _ (GadtC _ [(_, AppT (ConT n) (VarT _))] _)) = digItOut n - dig (GadtC _ _ (AppT (ConT _) (ConT n))) = return [n] - dig _ = error "Unexpected case generating ToFloating instances" - -- - concat `fmap` mapM dig cons - - thToFloating :: Name -> Name -> Q Dec - thToFloating a b = - let - ty = AppT (AppT (ConT (mkName "ToFloating")) (ConT a)) (ConT b) - dec = ValD (VarP (mkName "toFloating")) (NormalB (VarE (mkName f))) [] - f | a == b = "id" - | otherwise = "mkToFloating" - in - instanceD (return []) (return ty) [return dec] - -- - as <- digItOut ''NumType - bs <- digItOut ''FloatingType - sequence [ thToFloating a b | a <- as, b <- bs ] - ) +runQ $ + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + , ''Float128 + ] + + numTypes :: [Name] + numTypes = integralTypes ++ floatingTypes + + thToFloating :: Name -> Name -> Q [Dec] + thToFloating a b = + [d| instance ToFloating $(conT a) $(conT b) where + toFloating = $(varE $ if a == b then 'id else 'mkToFloating) + + instance KnownNat n => ToFloating (Vec n $(conT a)) (Vec n $(conT b)) where + toFloating = $(varE $ if a == b then 'id else 'mkToFloating) + |] + in + concat <$> sequence [ thToFloating from to | from <- numTypes, to <- floatingTypes ] diff --git a/src/Data/Array/Accelerate/Classes/VEq.hs b/src/Data/Array/Accelerate/Classes/VEq.hs new file mode 100644 index 000000000..a24ebf433 --- /dev/null +++ b/src/Data/Array/Accelerate/Classes/VEq.hs @@ -0,0 +1,211 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +-- | +-- Module : Data.Array.Accelerate.Classes.VEq +-- Copyright : [2016..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Classes.VEq ( + + VEq(..), + (&&*), + (||*), + vnot, + +) where + +import Data.Array.Accelerate.AST ( PrimFun(..) ) +import Data.Array.Accelerate.Classes.Eq +import Data.Array.Accelerate.Classes.Num +import Data.Array.Accelerate.Representation.Tag +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Shape +import Data.Array.Accelerate.Sugar.Vec +import Data.Array.Accelerate.Type + +import Data.Array.Accelerate.Error + +import qualified Data.Primitive.Bit as Prim + +import Language.Haskell.TH.Extra hiding ( Type, Exp ) + +import Prelude hiding ( Eq(..) ) + +import GHC.Exts +import GHC.TypeLits + + +-- | Vectorised conjunction: Element-wise returns true if both arguments in +-- the corresponding lane are True. This is a strict vectorised version of +-- '(Data.Array.Accelerate.&&)' that always evaluates both arguments. +-- +infixr 3 &&* +(&&*) :: KnownNat n => Exp (Vec n Bool) -> Exp (Vec n Bool) -> Exp (Vec n Bool) +(&&*) = mkLAnd + +-- | Vectorised disjunction: Element-wise returns true if either argument +-- in the corresponding lane is true. This is a strict vectorised version +-- of '(Data.Array.Accelerate.||)' that always evaluates both arguments. +-- +infixr 2 ||* +(||*) :: KnownNat n => Exp (Vec n Bool) -> Exp (Vec n Bool) -> Exp (Vec n Bool) +(||*) = mkLOr + +-- | Vectorised logical negation +-- +vnot :: KnownNat n => Exp (Vec n Bool) -> Exp (Vec n Bool) +vnot = mkLNot + + +infix 4 ==* +infix 4 /=* + +-- | The 'VEq' class defines lane-wise equality '(==*)' and inequality +-- '(/=*)' for Accelerate vector expressions. +-- +class SIMD n a => VEq n a where + (==*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + (/=*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + {-# MINIMAL (==*) | (/=*) #-} + x ==* y = vnot (x /=* y) + x /=* y = vnot (x ==* y) + +runQ $ do + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + , ''Float128 + ] + + nonNumTypes :: [Name] + nonNumTypes = + [ ''Char + ] + + numTypes :: [Name] + numTypes = integralTypes ++ floatingTypes + + mkPrim :: Name -> Q [Dec] + mkPrim name = + [d| instance KnownNat n => VEq n $(conT name) where + (==*) = mkPrimBinary $ PrimEq scalarType + (/=*) = mkPrimBinary $ PrimNEq scalarType + |] + + mkTup :: Word8 -> Q Dec + mkTup n = do + w <- newName "w" + x <- newName "x" + y <- newName "y" + let + xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + ts = map varT xs + res = tupT ts + ctx = (++) <$> mapM (appT [t| Eq |]) ts + <*> mapM (appT [t| SIMD $(varT w) |]) ts + + cmp f = [| mkPack (zipWith $f (mkUnpack $(varE x)) (mkUnpack $(varE y))) |] + -- + instanceD ctx [t| VEq $(varT w) $res |] + [ funD (mkName "==*") [ clause [varP x, varP y] (normalB (cmp [| (==) |])) [] ] + , funD (mkName "/=*") [ clause [varP x, varP y] (normalB (cmp [| (/=) |])) [] ] + ] + -- + ps <- concat <$> mapM mkPrim (numTypes ++ nonNumTypes) + ts <- mapM mkTup [2..16] + return (ps ++ ts) + +vtrue, vfalse :: KnownNat n => Exp (Vec n Bool) +vtrue = constant (Vec (Prim.unMask Prim.ones)) +vfalse = constant (Vec (Prim.unMask Prim.zeros)) + +instance KnownNat n => VEq n () where + _ ==* _ = vtrue + _ /=* _ = vfalse + +instance KnownNat n => VEq n Z where + _ ==* _ = vtrue + _ /=* _ = vfalse + +instance KnownNat n => VEq n Bool where + (==*) = + let n = natVal' (proxy# :: Proxy# n) + -- + cmp :: forall t. (Elt t, IsIntegral (EltR t)) + => Exp (Vec n Bool) + -> Exp (Vec n Bool) + -> Exp (Vec n Bool) + cmp x y = + let x' = mkPrimUnary (PrimFromBool bitType integralType) x :: Exp t + y' = mkPrimUnary (PrimFromBool bitType integralType) y + in + mkPrimUnary (PrimToBool integralType bitType) (mkBAnd x' y') + in + if n <= 8 then cmp @Word8 else + if n <= 16 then cmp @Word16 else + if n <= 32 then cmp @Word32 else + if n <= 64 then cmp @Word64 else + if n <= 128 then cmp @Word128 else + internalError "Can not handle Vec types with more than 128 lanes" + + (/=*) = + let n = natVal' (proxy# :: Proxy# n) + -- + cmp :: forall t. (Elt t, IsIntegral (EltR t)) + => Exp (Vec n Bool) + -> Exp (Vec n Bool) + -> Exp (Vec n Bool) + cmp x y = + let x' = mkPrimUnary (PrimFromBool bitType integralType) x :: Exp t + y' = mkPrimUnary (PrimFromBool bitType integralType) y + in + mkPrimUnary (PrimToBool integralType bitType) (mkBXor x' y') + in + if n <= 8 then cmp @Word8 else + if n <= 16 then cmp @Word16 else + if n <= 32 then cmp @Word32 else + if n <= 64 then cmp @Word64 else + if n <= 128 then cmp @Word128 else + internalError "Can not handle SIMD vector types with more than 128 lanes" + +instance (Eq sh, SIMD n sh) => VEq n (sh :. Int) where + x ==* y = mkPack (zipWith (==) (mkUnpack x) (mkUnpack y)) + x /=* y = mkPack (zipWith (/=) (mkUnpack x) (mkUnpack y)) + +instance KnownNat n => VEq n Ordering where + x ==* y = mkCoerce x ==* (mkCoerce y :: Exp (Vec n TAG)) + x /=* y = mkCoerce x /=* (mkCoerce y :: Exp (Vec n TAG)) + diff --git a/src/Data/Array/Accelerate/Classes/VEq.hs-boot b/src/Data/Array/Accelerate/Classes/VEq.hs-boot new file mode 100644 index 000000000..082de6b18 --- /dev/null +++ b/src/Data/Array/Accelerate/Classes/VEq.hs-boot @@ -0,0 +1,26 @@ +{-# LANGUAGE MultiParamTypeClasses #-} +-- | +-- Module : Data.Array.Accelerate.Classes.VEq +-- Copyright : [2016..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Classes.VEq + where + +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Vec + +class SIMD n a => VEq n a where + (==*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + (/=*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + {-# MINIMAL (==*) | (/=*) #-} + x ==* y = vnot (x /=* y) + x /=* y = vnot (x ==* y) + +vnot :: KnownNat n => Exp (Vec n Bool) -> Exp (Vec n Bool) + diff --git a/src/Data/Array/Accelerate/Classes/VOrd.hs b/src/Data/Array/Accelerate/Classes/VOrd.hs new file mode 100644 index 000000000..30a260edc --- /dev/null +++ b/src/Data/Array/Accelerate/Classes/VOrd.hs @@ -0,0 +1,166 @@ +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeOperators #-} +-- | +-- Module : Data.Array.Accelerate.Classes.VOrd +-- Copyright : [2016..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Classes.VOrd ( + + VOrd(..), + +) where + +import Data.Array.Accelerate.AST ( PrimFun(..) ) +import Data.Array.Accelerate.Classes.Ord +import Data.Array.Accelerate.Classes.VEq +import Data.Array.Accelerate.Representation.Tag +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Shape +import Data.Array.Accelerate.Sugar.Vec +import Data.Array.Accelerate.Type + +import qualified Data.Primitive.Vec as Prim +import qualified Data.Primitive.Bit as Prim + +import Language.Haskell.TH.Extra hiding ( Type, Exp ) + +import Prelude hiding ( Ord(..), (<*) ) + + +infix 4 <* +infix 4 >* +infix 4 <=* +infix 4 >=* + +-- | The 'VOrd' class defines lane-wise comparisons for totally ordered +-- datatypes. +-- +class VEq n a => VOrd n a where + {-# MINIMAL (<=*) | vcompare #-} + (<*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + (>*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + (<=*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + (>=*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + vcompare :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Ordering) + + x <* y = select (vcompare x y ==* vlt) vtrue vfalse + x <=* y = select (vcompare x y ==* vgt) vfalse vtrue + x >* y = select (vcompare x y ==* vgt) vtrue vfalse + x >=* y = select (vcompare x y ==* vlt) vfalse vtrue + + vcompare x y + = select (x ==* y) veq + $ select (x <=* y) vlt vgt + +vlt, veq, vgt :: KnownNat n => Exp (Vec n Ordering) +vlt = constant (Vec (let (tag,()) = fromElt LT in Prim.splat tag, ())) +veq = constant (Vec (let (tag,()) = fromElt EQ in Prim.splat tag, ())) +vgt = constant (Vec (let (tag,()) = fromElt EQ in Prim.splat tag, ())) + +vtrue, vfalse :: KnownNat n => Exp (Vec n Bool) +vtrue = constant (Vec (Prim.unMask Prim.ones)) +vfalse = constant (Vec (Prim.unMask Prim.zeros)) + +runQ $ do + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + , ''Float128 + ] + + nonNumTypes :: [Name] + nonNumTypes = + [ ''Char + ] + + numTypes :: [Name] + numTypes = integralTypes ++ floatingTypes + + mkPrim :: Name -> Q [Dec] + mkPrim name = + [d| instance KnownNat n => VOrd n $(conT name) where + (<*) = mkPrimBinary $ PrimLt scalarType + (>*) = mkPrimBinary $ PrimGt scalarType + (<=*) = mkPrimBinary $ PrimLtEq scalarType + (>=*) = mkPrimBinary $ PrimGtEq scalarType + |] + + mkTup :: Word8 -> Q Dec + mkTup n = do + w <- newName "w" + x <- newName "x" + y <- newName "y" + let + xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + ts = map varT xs + res = tupT ts + ctx = (++) <$> mapM (appT [t| Ord |]) ts + <*> mapM (appT [t| SIMD $(varT w) |]) ts + cmp f = [| mkPack (zipWith $f (mkUnpack $(varE x)) (mkUnpack $(varE y))) |] + -- + instanceD ctx [t| VOrd $(varT w) $res |] + [ funD (mkName "<*") [ clause [varP x, varP y] (normalB (cmp [| (<) |])) [] ] + , funD (mkName ">*") [ clause [varP x, varP y] (normalB (cmp [| (>) |])) [] ] + , funD (mkName "<=*") [ clause [varP x, varP y] (normalB (cmp [| (<=) |])) [] ] + , funD (mkName ">=*") [ clause [varP x, varP y] (normalB (cmp [| (>=) |])) [] ] + ] + -- + ps <- concat <$> mapM mkPrim (numTypes ++ nonNumTypes) + ts <- mapM mkTup [2..16] + return (ps ++ ts) + +instance KnownNat n => VOrd n () where + (<*) _ _ = vfalse + (>*) _ _ = vfalse + (<=*) _ _ = vtrue + (>=*) _ _ = vtrue + vcompare _ _ = veq + +instance KnownNat n => VOrd n Z where + (<*) _ _ = vfalse + (>*) _ _ = vfalse + (<=*) _ _ = vtrue + (>=*) _ _ = vtrue + vcompare _ _ = veq + +instance KnownNat n => VOrd n Ordering where + x <* y = mkCoerce x <* (mkCoerce y :: Exp (Vec n TAG)) + x >* y = mkCoerce x >* (mkCoerce y :: Exp (Vec n TAG)) + x <=* y = mkCoerce x <=* (mkCoerce y :: Exp (Vec n TAG)) + x >=* y = mkCoerce x >=* (mkCoerce y :: Exp (Vec n TAG)) + +instance (Ord sh, VOrd n sh) => VOrd n (sh :. Int) where + x <* y = mkPack (zipWith (<) (mkUnpack x) (mkUnpack y)) + x >* y = mkPack (zipWith (>) (mkUnpack x) (mkUnpack y)) + x <=* y = mkPack (zipWith (<=) (mkUnpack x) (mkUnpack y)) + x >=* y = mkPack (zipWith (>=) (mkUnpack x) (mkUnpack y)) + diff --git a/src/Data/Array/Accelerate/Classes/VOrd.hs-boot b/src/Data/Array/Accelerate/Classes/VOrd.hs-boot new file mode 100644 index 000000000..b1b83feb5 --- /dev/null +++ b/src/Data/Array/Accelerate/Classes/VOrd.hs-boot @@ -0,0 +1,41 @@ +{-# LANGUAGE MultiParamTypeClasses #-} +-- | +-- Module : Data.Array.Accelerate.Classes.VOrd +-- Copyright : [2016..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Classes.VOrd ( + + VOrd(..), + +) where + +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Vec +import Data.Array.Accelerate.Classes.VEq + + +class VEq n a => VOrd n a where + {-# MINIMAL (<=*) | vcompare #-} + (<*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + (>*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + (<=*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + (>=*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + vcompare :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Ordering) + + x <* y = select (vcompare x y ==* vlt) vtrue vfalse + x <=* y = select (vcompare x y ==* vgt) vfalse vtrue + x >* y = select (vcompare x y ==* vgt) vtrue vfalse + x >=* y = select (vcompare x y ==* vlt) vfalse vtrue + + vcompare x y + = select (x ==* y) veq + $ select (x <=* y) vlt vgt + +vtrue, vfalse :: KnownNat n => Exp (Vec n Bool) + diff --git a/src/Data/Array/Accelerate/Classes/Vector.hs b/src/Data/Array/Accelerate/Classes/Vector.hs index 21c7a7be2..e624261bb 100644 --- a/src/Data/Array/Accelerate/Classes/Vector.hs +++ b/src/Data/Array/Accelerate/Classes/Vector.hs @@ -18,19 +18,18 @@ -- Stability : experimental -- Portability : non-portable (GHC extensions) -- -module Data.Array.Accelerate.Classes.Vector where + +module Data.Array.Accelerate.Classes.Vector + where import GHC.TypeLits import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Smart import Data.Primitive.Vec - - instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) where type IndexType (Exp (Vec n a)) = Exp Int vecIndex = mkVectorIndex vecWrite = mkVectorWrite vecEmpty = undef - diff --git a/src/Data/Array/Accelerate/Data/Bits.hs b/src/Data/Array/Accelerate/Data/Bits.hs index 9696e3c3e..c93332ca0 100644 --- a/src/Data/Array/Accelerate/Data/Bits.hs +++ b/src/Data/Array/Accelerate/Data/Bits.hs @@ -1,9 +1,14 @@ {-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE RebindableSyntax #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE ViewPatterns #-} -- | -- Module : Data.Array.Accelerate.Data.Bits @@ -24,21 +29,31 @@ module Data.Array.Accelerate.Data.Bits ( ) where -import Data.Array.Accelerate.Array.Data +import Data.Array.Accelerate.AST ( BitOrMask ) import Data.Array.Accelerate.Language import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type import Data.Array.Accelerate.Classes.Eq +import Data.Array.Accelerate.Classes.FromIntegral +import Data.Array.Accelerate.Classes.Integral ( ) +import Data.Array.Accelerate.Classes.Num import Data.Array.Accelerate.Classes.Ord -import Data.Array.Accelerate.Classes.Integral () +import Data.Array.Accelerate.Classes.VEq +import Data.Array.Accelerate.Classes.VOrd -import Prelude ( (.), ($), undefined, otherwise ) +import Data.Kind +import Language.Haskell.TH hiding ( Exp, Type ) +import Prelude ( ($), (<$>), undefined, otherwise, concat, mapM, toInteger ) +import qualified Prelude as P import qualified Data.Bits as B +import GHC.Exts +import GHC.TypeLits -infixl 8 `shift`, `rotate`, `shiftL`, `shiftR`, `rotateL`, `rotateR` +infixl 8 `shiftL`, `shiftR`, `rotateL`, `rotateR` infixl 7 .&. infixl 6 `xor` infixl 5 .|. @@ -49,10 +64,8 @@ infixl 5 .|. -- significant bit. -- class Eq a => Bits a where - {-# MINIMAL (.&.), (.|.), xor, complement, - (shift | (shiftL, shiftR)), - (rotate | (rotateL, rotateR)), - isSigned, testBit, bit, popCount #-} + type Bools a :: Type + {-# MINIMAL (.&.), (.|.), xor, complement, shiftL, shiftR, rotateL, rotateR, isSigned, bit, testBit, popCount #-} -- | Bitwise "and" (.&.) :: Exp a -> Exp a -> Exp a @@ -66,58 +79,40 @@ class Eq a => Bits a where -- | Reverse all bits in the argument complement :: Exp a -> Exp a - -- | @'shift' x i@ shifts @x@ left by @i@ bits if @i@ is positive, or right by - -- @-i@ bits otherwise. Right shifts perform sign extension on signed number - -- types; i.e. they fill the top bits with 1 if the @x@ is negative and with - -- 0 otherwise. - shift :: Exp a -> Exp Int -> Exp a - shift x i - = cond (i < 0) (x `shiftR` (-i)) - $ cond (i > 0) (x `shiftL` i) - $ x - - -- | @'rotate' x i@ rotates @x@ left by @i@ bits if @i@ is positive, or right - -- by @-i@ bits otherwise. - rotate :: Exp a -> Exp Int -> Exp a - rotate x i - = cond (i < 0) (x `rotateR` (-i)) - $ cond (i > 0) (x `rotateL` i) - $ x - -- | The value with all bits unset zeroBits :: Exp a + default zeroBits :: Num a => Exp a zeroBits = clearBit (bit 0) 0 -- | @bit /i/@ is a value with the @/i/@th bit set and all other bits clear. - bit :: Exp Int -> Exp a + bit :: Exp a -> Exp a -- | @x \`setBit\` i@ is the same as @x .|. bit i@ - setBit :: Exp a -> Exp Int -> Exp a + setBit :: Exp a -> Exp a -> Exp a setBit x i = x .|. bit i -- | @x \`clearBit\` i@ is the same as @x .&. complement (bit i)@ - clearBit :: Exp a -> Exp Int -> Exp a + clearBit :: Exp a -> Exp a -> Exp a clearBit x i = x .&. complement (bit i) -- | @x \`complementBit\` i@ is the same as @x \`xor\` bit i@ - complementBit :: Exp a -> Exp Int -> Exp a + complementBit :: Exp a -> Exp a -> Exp a complementBit x i = x `xor` bit i -- | Return 'True' if the @n@th bit of the argument is 1 - testBit :: Exp a -> Exp Int -> Exp Bool + testBit :: Exp a -> Exp a -> Exp (Bools a) -- | Return 'True' if the argument is a signed type. isSigned :: Exp a -> Exp Bool -- | Shift the argument left by the specified number of bits (which must be - -- non-negative). - shiftL :: Exp a -> Exp Int -> Exp a - shiftL x i = x `shift` i + -- non-negative) + shiftL :: Exp a -> Exp a -> Exp a -- | Shift the argument left by the specified number of bits. The result is -- undefined for negative shift amounts and shift amounts greater or equal to -- the 'finiteBitSize'. - unsafeShiftL :: Exp a -> Exp Int -> Exp a + unsafeShiftL :: Exp a -> Exp a -> Exp a unsafeShiftL = shiftL -- | Shift the first argument right by the specified number of bits (which @@ -125,39 +120,36 @@ class Eq a => Bits a where -- -- Right shifts perform sign extension on signed number types; i.e. they fill -- the top bits with 1 if @x@ is negative and with 0 otherwise. - shiftR :: Exp a -> Exp Int -> Exp a - shiftR x i = x `shift` (-i) + shiftR :: Exp a -> Exp a -> Exp a -- | Shift the first argument right by the specified number of bits. The -- result is undefined for negative shift amounts and shift amounts greater or -- equal to the 'finiteBitSize'. - unsafeShiftR :: Exp a -> Exp Int -> Exp a + unsafeShiftR :: Exp a -> Exp a -> Exp a unsafeShiftR = shiftR -- | Rotate the argument left by the specified number of bits (which must be -- non-negative). - rotateL :: Exp a -> Exp Int -> Exp a - rotateL x i = x `rotate` i + rotateL :: Exp a -> Exp a -> Exp a -- | Rotate the argument right by the specified number of bits (which must be non-negative). - rotateR :: Exp a -> Exp Int -> Exp a - rotateR x i = x `rotate` (-i) + rotateR :: Exp a -> Exp a -> Exp a -- | Return the number of set bits in the argument. This number is known as -- the population count or the Hamming weight. - popCount :: Exp a -> Exp Int + popCount :: Exp a -> Exp a -class Bits b => FiniteBits b where +class Bits a => FiniteBits a where -- | Return the number of bits in the type of the argument. - finiteBitSize :: Exp b -> Exp Int + finiteBitSize :: Exp a -> Exp Int -- | Count the number of zero bits preceding the most significant set bit. -- This can be used to compute a base-2 logarithm via: -- -- > logBase2 x = finiteBitSize x - 1 - countLeadingZeros x -- - countLeadingZeros :: Exp b -> Exp Int + countLeadingZeros :: Exp a -> Exp a -- | Count the number of zero bits following the least significant set bit. -- The related @@ -166,558 +158,75 @@ class Bits b => FiniteBits b where -- -- > findFirstSet x = 1 + countTrailingZeros x -- - countTrailingZeros :: Exp b -> Exp Int + countTrailingZeros :: Exp a -> Exp a -- Instances for Bits -- ------------------ instance Bits Bool where + type Bools Bool = Bool (.&.) = (&&) (.|.) = (||) xor = (/=) complement = not - shift x i = cond (i == 0) x (constant False) - testBit x i = cond (i == 0) x (constant False) - rotate x _ = x - bit i = i == 0 - isSigned = isSignedDefault - popCount = boolToInt - -instance Bits Int where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = bitDefault - testBit = testBitDefault - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount - -instance Bits Int8 where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = bitDefault - testBit = testBitDefault - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount - -instance Bits Int16 where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = bitDefault - testBit = testBitDefault - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount - -instance Bits Int32 where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = bitDefault - testBit = testBitDefault - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount - -instance Bits Int64 where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = bitDefault - testBit = testBitDefault - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount - -instance Bits Word where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = bitDefault - testBit = testBitDefault - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount - -instance Bits Word8 where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = bitDefault - testBit = testBitDefault - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount - -instance Bits Word16 where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = bitDefault - testBit = testBitDefault - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount - -instance Bits Word32 where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = bitDefault - testBit = testBitDefault - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount - -instance Bits Word64 where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = bitDefault - testBit = testBitDefault - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount - -instance Bits CInt where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @Int32 - testBit b = testBitDefault (mkBitcast @Int32 b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @Int32 - -instance Bits CUInt where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @Word32 - testBit b = testBitDefault (mkBitcast @Word32 b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @Word32 - -instance Bits CLong where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @HTYPE_CLONG - testBit b = testBitDefault (mkBitcast @HTYPE_CLONG b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault + zeroBits = False_ + testBit = (&&) + bit x = x isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @HTYPE_CLONG - -instance Bits CULong where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @HTYPE_CULONG - testBit b = testBitDefault (mkBitcast @HTYPE_CULONG b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @HTYPE_CULONG - -instance Bits CLLong where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @Int64 - testBit b = testBitDefault (mkBitcast @Int64 b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @Int64 - -instance Bits CULLong where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @Word64 - testBit b = testBitDefault (mkBitcast @Word64 b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @Word64 - -instance Bits CShort where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @Int16 - testBit b = testBitDefault (mkBitcast @Int16 b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @Int16 - -instance Bits CUShort where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @Word16 - testBit b = testBitDefault (mkBitcast @Word16 b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @Word16 - -instance Bits CChar where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @HTYPE_CCHAR - testBit b = testBitDefault (mkBitcast @HTYPE_CCHAR b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @HTYPE_CCHAR - -instance Bits CSChar where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @Int8 - testBit b = testBitDefault (mkBitcast @Int8 b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @Int8 - -instance Bits CUChar where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @Word8 - testBit b = testBitDefault (mkBitcast @Word8 b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @Word8 - - - --- Instances for FiniteBits --- ------------------------ + shiftL x i = cond i False_ x + shiftR x i = cond i False_ x + rotateL x _ = x + rotateR x _ = x + popCount x = x instance FiniteBits Bool where - finiteBitSize _ = constInt 8 -- stored as Word8 {- (B.finiteBitSize (undefined::Bool)) -} - countLeadingZeros x = cond x 0 1 - countTrailingZeros x = cond x 0 1 - -instance FiniteBits Int where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::Int)) - countLeadingZeros = mkCountLeadingZeros - countTrailingZeros = mkCountTrailingZeros - -instance FiniteBits Int8 where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::Int8)) - countLeadingZeros = mkCountLeadingZeros - countTrailingZeros = mkCountTrailingZeros - -instance FiniteBits Int16 where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::Int16)) - countLeadingZeros = mkCountLeadingZeros - countTrailingZeros = mkCountTrailingZeros - -instance FiniteBits Int32 where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::Int32)) - countLeadingZeros = mkCountLeadingZeros - countTrailingZeros = mkCountTrailingZeros - -instance FiniteBits Int64 where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::Int64)) - countLeadingZeros = mkCountLeadingZeros - countTrailingZeros = mkCountTrailingZeros - -instance FiniteBits Word where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::Word)) - countLeadingZeros = mkCountLeadingZeros - countTrailingZeros = mkCountTrailingZeros - -instance FiniteBits Word8 where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::Word8)) - countLeadingZeros = mkCountLeadingZeros - countTrailingZeros = mkCountTrailingZeros - -instance FiniteBits Word16 where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::Word16)) - countLeadingZeros = mkCountLeadingZeros - countTrailingZeros = mkCountTrailingZeros - -instance FiniteBits Word32 where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::Word32)) - countLeadingZeros = mkCountLeadingZeros - countTrailingZeros = mkCountTrailingZeros - -instance FiniteBits Word64 where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::Word64)) - countLeadingZeros = mkCountLeadingZeros - countTrailingZeros = mkCountTrailingZeros - -instance FiniteBits CInt where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CInt)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @Int32 - countTrailingZeros = mkCountTrailingZeros . mkBitcast @Int32 - -instance FiniteBits CUInt where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CUInt)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @Word32 - countTrailingZeros = mkCountTrailingZeros . mkBitcast @Word32 - -instance FiniteBits CLong where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CLong)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @HTYPE_CLONG - countTrailingZeros = mkCountTrailingZeros . mkBitcast @HTYPE_CLONG - -instance FiniteBits CULong where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CULong)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @HTYPE_CULONG - countTrailingZeros = mkCountTrailingZeros . mkBitcast @HTYPE_CULONG - -instance FiniteBits CLLong where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CLLong)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @Int64 - countTrailingZeros = mkCountTrailingZeros . mkBitcast @Int64 - -instance FiniteBits CULLong where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CULLong)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @Word64 - countTrailingZeros = mkCountTrailingZeros . mkBitcast @Word64 - -instance FiniteBits CShort where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CShort)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @Int16 - countTrailingZeros = mkCountTrailingZeros . mkBitcast @Int16 - -instance FiniteBits CUShort where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CUShort)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @Word16 - countTrailingZeros = mkCountTrailingZeros . mkBitcast @Word16 - -instance FiniteBits CChar where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CChar)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @HTYPE_CCHAR - countTrailingZeros = mkCountTrailingZeros . mkBitcast @HTYPE_CCHAR - -instance FiniteBits CSChar where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CSChar)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @Int8 - countTrailingZeros = mkCountTrailingZeros . mkBitcast @Int8 - -instance FiniteBits CUChar where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CUChar)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @Word8 - countTrailingZeros = mkCountTrailingZeros . mkBitcast @Word8 + finiteBitSize _ = fromInteger (toInteger (B.finiteBitSize (undefined::Bool))) + countLeadingZeros x = x + countTrailingZeros x = x -- Default implementations -- ----------------------- -bitDefault :: (IsIntegral (EltR t), Bits t) => Exp Int -> Exp t -bitDefault x = constInt 1 `shiftL` x -testBitDefault :: (IsIntegral (EltR t), Bits t) => Exp t -> Exp Int -> Exp Bool -testBitDefault x i = (x .&. bit i) /= constInt 0 +bitDefault :: (Num t, Bits t) => Exp t -> Exp t +bitDefault x = 1 `shiftL` x + +testBitDefault :: (Num t, Bits t) => Exp t -> Exp t -> Exp Bool +testBitDefault x i = (x .&. bit i) /= 0 -shiftDefault :: (FiniteBits t, IsIntegral (EltR t), B.Bits t) => Exp t -> Exp Int -> Exp t -shiftDefault x i - = cond (i >= 0) (shiftLDefault x i) - (shiftRDefault x (-i)) +-- shiftDefault :: (FiniteBits t, IsIntegral (EltR t), B.Bits t) => Exp t -> Exp t -> Exp t +-- shiftDefault x i +-- = cond (i >= 0) (shiftLDefault x i) +-- (shiftRDefault x (-i)) -shiftLDefault :: (FiniteBits t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t +shiftLDefault :: forall t. (B.FiniteBits t, Num t, Ord t, FromIntegral Int t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t shiftLDefault x i - = cond (i >= finiteBitSize x) (constInt 0) + = cond (i >= P.fromIntegral (B.finiteBitSize (undefined::t))) 0 $ mkBShiftL x i -shiftRDefault :: forall t. (B.Bits t, FiniteBits t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t +shiftRDefault :: forall t. (B.Bits t, B.FiniteBits t, Num t, Ord t, FromIntegral Int t, IsScalar (EltR t), IsIntegral (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp t -> Exp t shiftRDefault | B.isSigned (undefined::t) = shiftRADefault | otherwise = shiftRLDefault -- Shift the argument right (signed) -shiftRADefault :: (FiniteBits t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t +shiftRADefault :: forall t. (B.FiniteBits t, Num t, Ord t, FromIntegral Int t, IsScalar (EltR t), IsIntegral (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp t -> Exp t shiftRADefault x i - = cond (i >= finiteBitSize x) (cond (mkLt x (constInt 0)) (constInt (-1)) (constInt 0)) + = cond (i >= P.fromIntegral (B.finiteBitSize (undefined::t))) (cond (mkLt x 0) (-1) 0) $ mkBShiftR x i -- Shift the argument right (unsigned) -shiftRLDefault :: (FiniteBits t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t +shiftRLDefault :: forall t. (B.FiniteBits t, Num t, Ord t, FromIntegral Int t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t shiftRLDefault x i - = cond (i >= finiteBitSize x) (constInt 0) + = cond (i >= P.fromIntegral (B.finiteBitSize (undefined::t))) 0 $ mkBShiftR x i -rotateDefault :: forall t. (FiniteBits t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t -rotateDefault x i - = cond (i < 0) (mkBRotateR x (-i)) - $ cond (i > 0) (mkBRotateL x i) - $ x +-- rotateDefault :: forall t. (FiniteBits t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t +-- rotateDefault x i +-- = cond (i < 0) (mkBRotateR x (-i)) +-- $ cond (i > 0) (mkBRotateL x i) +-- $ x {-- -- Rotation can be implemented in terms of two shifts, but care is needed @@ -759,22 +268,19 @@ rotateDefault' _ x i wsib = finiteBitSize x --} -rotateLDefault :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t -rotateLDefault x i - = cond (i == 0) x - $ mkBRotateL x i +-- rotateLDefault :: (Num t, Eq t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t +-- rotateLDefault x i +-- = cond (i == 0) x +-- $ mkBRotateL x i -rotateRDefault :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t -rotateRDefault x i - = cond (i == 0) x - $ mkBRotateR x i +-- rotateRDefault :: (Num t, Eq t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t +-- rotateRDefault x i +-- = cond (i == 0) x +-- $ mkBRotateR x i isSignedDefault :: forall b. B.Bits b => Exp b -> Exp Bool isSignedDefault _ = constant (B.isSigned (undefined::b)) -constInt :: IsIntegral (EltR e) => EltR e -> Exp e -constInt = mkExp . Const (SingleScalarType (NumSingleType (IntegralNumType integralType))) - {-- _popCountDefault :: forall a. (B.FiniteBits a, IsScalar a, Bits a, Num a) => Exp a -> Exp Int _popCountDefault = @@ -828,3 +334,72 @@ popCnt64 v1 = mkFromIntegral c c = (v4 * 0x0101010101010101) `unsafeShiftR` 56 --} +runQ $ + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + thBits :: Name -> Q [Dec] + thBits a = + [d| instance Bits $(conT a) where + type Bools $(conT a) = Bool + (.&.) = mkBAnd + (.|.) = mkBOr + xor = mkBXor + complement = mkBNot + bit = bitDefault + testBit = testBitDefault + shiftL = shiftLDefault + shiftR = shiftRDefault + unsafeShiftL = mkBShiftL + unsafeShiftR = mkBShiftR + rotateL = mkBRotateL + rotateR = mkBRotateR + isSigned = isSignedDefault + popCount = mkPopCount + + instance KnownNat n => Bits (Vec n $(conT a)) where + type Bools (Vec n $(conT a)) = (Vec n Bool) + (.&.) = mkBAnd + (.|.) = mkBOr + xor = mkBXor + complement = mkBNot + bit = bitDefault + testBit x i = (x .&. bit i) /=* 0 + shiftL x i = select (i >=* P.fromIntegral (B.finiteBitSize (undefined :: $(conT a)))) 0 (mkBShiftL x i) + shiftR x i + | B.isSigned (undefined :: $(conT a)) = select (i >=* P.fromIntegral (B.finiteBitSize (undefined :: $(conT a)))) (select (x <* 0) (P.fromInteger (-1)) 0) (mkBShiftR x i) + | otherwise = select (i >=* P.fromIntegral (B.finiteBitSize (undefined :: $(conT a)))) 0 (mkBShiftR x i) + unsafeShiftL = mkBShiftL + unsafeShiftR = mkBShiftR + rotateL = mkBRotateL + rotateR = mkBRotateR + isSigned _ = constant (B.isSigned (undefined :: $(conT a))) + popCount = mkPopCount + + instance FiniteBits $(conT a) where + finiteBitSize _ = fromInteger (toInteger (B.finiteBitSize (undefined :: $(conT a)))) + countLeadingZeros = mkCountLeadingZeros + countTrailingZeros = mkCountTrailingZeros + + instance KnownNat n => FiniteBits (Vec n $(conT a)) where + finiteBitSize _ = fromInteger (natVal' (proxy# :: Proxy# n) * toInteger (B.finiteBitSize (undefined :: $(conT a)))) + countLeadingZeros = mkCountLeadingZeros + countTrailingZeros = mkCountTrailingZeros + |] + in + concat <$> mapM thBits integralTypes + diff --git a/src/Data/Array/Accelerate/Data/Complex.hs b/src/Data/Array/Accelerate/Data/Complex.hs index 1a1c46767..c64cef772 100644 --- a/src/Data/Array/Accelerate/Data/Complex.hs +++ b/src/Data/Array/Accelerate/Data/Complex.hs @@ -1,15 +1,15 @@ -{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE MagicHash #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RebindableSyntax #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeSynonymInstances #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fno-warn-orphans #-} @@ -44,6 +44,7 @@ module Data.Array.Accelerate.Data.Complex ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.Classes.Eq import Data.Array.Accelerate.Classes.Floating import Data.Array.Accelerate.Classes.Fractional @@ -56,14 +57,13 @@ import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Prelude import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type -import Data.Array.Accelerate.Representation.Vec import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt -import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type -import Data.Primitive.Vec +import qualified Data.Primitive.Vec as Prim import Data.Complex ( Complex(..) ) +import Data.Primitive.Types import Prelude ( ($) ) import qualified Data.Complex as C import qualified Prelude as P @@ -87,12 +87,12 @@ instance Elt a => Elt (Complex a) where type EltR (Complex a) = ComplexR (EltR a) eltR = let tR = eltR @a in case complexR tR of - ComplexVec s -> TupRsingle $ VectorScalarType $ VectorType 2 s + ComplexVec t -> TupRsingle (NumScalarType t) ComplexTup -> TupRunit `TupRpair` tR `TupRpair` tR tagsR = let tR = eltR @a in case complexR tR of - ComplexVec s -> [ TagRsingle (VectorScalarType (VectorType 2 s)) ] + ComplexVec t -> [ TagRsingle (NumScalarType t) ] ComplexTup -> let go :: TypeR t -> [TagR t] go TupRunit = [TagRunit] go (TupRsingle s) = [TagRsingle s] @@ -101,36 +101,33 @@ instance Elt a => Elt (Complex a) where [ TagRunit `TagRpair` ta `TagRpair` tb | ta <- go tR, tb <- go tR ] toElt = case complexR $ eltR @a of - ComplexVec _ -> \(Vec2 r i) -> toElt r :+ toElt i - ComplexTup -> \(((), r), i) -> toElt r :+ toElt i + ComplexVec _ -> \(Prim.Vec2 r i) -> toElt r :+ toElt i + ComplexTup -> \(((), r), i) -> toElt r :+ toElt i fromElt (r :+ i) = case complexR $ eltR @a of - ComplexVec _ -> Vec2 (fromElt r) (fromElt i) + ComplexVec _ -> Prim.Vec2 (fromElt r) (fromElt i) ComplexTup -> (((), fromElt r), fromElt i) type family ComplexR a where - ComplexR Half = Vec2 Half - ComplexR Float = Vec2 Float - ComplexR Double = Vec2 Double - ComplexR Int = Vec2 Int - ComplexR Int8 = Vec2 Int8 - ComplexR Int16 = Vec2 Int16 - ComplexR Int32 = Vec2 Int32 - ComplexR Int64 = Vec2 Int64 - ComplexR Word = Vec2 Word - ComplexR Word8 = Vec2 Word8 - ComplexR Word16 = Vec2 Word16 - ComplexR Word32 = Vec2 Word32 - ComplexR Word64 = Vec2 Word64 - ComplexR a = (((), a), a) - --- This isn't ideal because we gather the evidence based on the --- representation type, so we really get the evidence (VecElt (EltR a)), --- which is not very useful... --- - TLM 2020-07-16 + ComplexR Half = Prim.Vec2 Float16 + ComplexR Float = Prim.Vec2 Float32 + ComplexR Double = Prim.Vec2 Float64 + ComplexR Float128 = Prim.Vec2 Float128 + ComplexR Int8 = Prim.Vec2 Int8 + ComplexR Int16 = Prim.Vec2 Int16 + ComplexR Int32 = Prim.Vec2 Int32 + ComplexR Int64 = Prim.Vec2 Int64 + ComplexR Int128 = Prim.Vec2 Int128 + ComplexR Word8 = Prim.Vec2 Word8 + ComplexR Word16 = Prim.Vec2 Word16 + ComplexR Word32 = Prim.Vec2 Word32 + ComplexR Word64 = Prim.Vec2 Word64 + ComplexR Word128 = Prim.Vec2 Word128 + ComplexR a = (((), a), a) + data ComplexType a c where - ComplexVec :: VecElt a => SingleType a -> ComplexType a (Vec2 a) - ComplexTup :: ComplexType a (((), a), a) + ComplexVec :: Prim a => NumType (Prim.Vec2 a) -> ComplexType a (Prim.Vec2 a) + ComplexTup :: ComplexType a (((), a), a) complexR :: TypeR a -> ComplexType a (ComplexR a) complexR = tuple @@ -141,49 +138,126 @@ complexR = tuple tuple (TupRsingle s) = scalar s scalar :: ScalarType a -> ComplexType a (ComplexR a) - scalar (SingleScalarType t) = single t - scalar VectorScalarType{} = ComplexTup + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t - single :: SingleType a -> ComplexType a (ComplexR a) - single (NumSingleType t) = num t + bit :: BitType t -> ComplexType t (ComplexR t) + bit TypeBit = ComplexTup + bit TypeMask{} = ComplexTup num :: NumType a -> ComplexType a (ComplexR a) num (IntegralNumType t) = integral t num (FloatingNumType t) = floating t integral :: IntegralType a -> ComplexType a (ComplexR a) - integral TypeInt = ComplexVec singleType - integral TypeInt8 = ComplexVec singleType - integral TypeInt16 = ComplexVec singleType - integral TypeInt32 = ComplexVec singleType - integral TypeInt64 = ComplexVec singleType - integral TypeWord = ComplexVec singleType - integral TypeWord8 = ComplexVec singleType - integral TypeWord16 = ComplexVec singleType - integral TypeWord32 = ComplexVec singleType - integral TypeWord64 = ComplexVec singleType + integral = \case + VectorIntegralType{} -> ComplexTup + SingleIntegralType t -> case t of + TypeInt8 -> ComplexVec numType + TypeInt16 -> ComplexVec numType + TypeInt32 -> ComplexVec numType + TypeInt64 -> ComplexVec numType + TypeInt128 -> ComplexVec numType + TypeWord8 -> ComplexVec numType + TypeWord16 -> ComplexVec numType + TypeWord32 -> ComplexVec numType + TypeWord64 -> ComplexVec numType + TypeWord128 -> ComplexVec numType floating :: FloatingType a -> ComplexType a (ComplexR a) - floating TypeHalf = ComplexVec singleType - floating TypeFloat = ComplexVec singleType - floating TypeDouble = ComplexVec singleType - + floating = \case + VectorFloatingType{} -> ComplexTup + SingleFloatingType t -> case t of + TypeFloat16 -> ComplexVec numType + TypeFloat32 -> ComplexVec numType + TypeFloat64 -> ComplexVec numType + TypeFloat128 -> ComplexVec numType constructComplex :: forall a. Elt a => Exp a -> Exp a -> Exp (Complex a) -constructComplex r i = +constructComplex r@(Exp r') i@(Exp i') = case complexR (eltR @a) of - ComplexTup -> coerce $ T2 r i - ComplexVec _ -> V2 (coerce @a @(EltR a) r) (coerce @a @(EltR a) i) + ComplexTup -> Pattern (r,i) + ComplexVec t -> Exp $ num t r' i' + where + num :: NumType (Prim.Vec2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType (Prim.Vec2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) + integral (SingleIntegralType t) = case t of + integral (VectorIntegralType n t) = + let v = NumScalarType (IntegralNumType (VectorIntegralType n t)) + in case t of + TypeInt8 -> pack v + TypeInt16 -> pack v + TypeInt32 -> pack v + TypeInt64 -> pack v + TypeInt128 -> pack v + TypeWord8 -> pack v + TypeWord16 -> pack v + TypeWord32 -> pack v + TypeWord64 -> pack v + TypeWord128 -> pack v + + floating :: FloatingType (Prim.Vec2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) + floating (SingleFloatingType t) = case t of + floating (VectorFloatingType n t) = + let v = NumScalarType (FloatingNumType (VectorFloatingType n t)) + in case t of + TypeFloat16 -> pack v + TypeFloat32 -> pack v + TypeFloat64 -> pack v + TypeFloat128 -> pack v + + pack :: ScalarType (Prim.Vec 2 t) -> SmartExp t -> SmartExp t -> SmartExp (Prim.Vec 2 t) + pack v x y + = SmartExp (Insert v TypeWord8 + (SmartExp (Insert v TypeWord8 (SmartExp (Undef v)) (SmartExp (Const scalarType 0)) x)) + (SmartExp (Const scalarType 1)) y) deconstructComplex :: forall a. Elt a => Exp (Complex a) -> (Exp a, Exp a) deconstructComplex c@(Exp c') = case complexR (eltR @a) of - ComplexTup -> let T2 r i = coerce c in (r, i) - ComplexVec t -> let T2 r i = Exp (SmartExp (VecUnpack (VecRsucc (VecRsucc (VecRnil t))) c')) - in (r, i) + ComplexTup -> let Pattern (r,i) = c in (r, i) + ComplexVec t -> let (r', i') = num t c' in (Exp r', Exp i') + where + num :: NumType (Prim.Vec2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType (Prim.Vec2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) + integral (SingleIntegralType t) = case t of + integral (VectorIntegralType n t) = + let v = NumScalarType (IntegralNumType (VectorIntegralType n t)) + in case t of + TypeInt8 -> unpack v + TypeInt16 -> unpack v + TypeInt32 -> unpack v + TypeInt64 -> unpack v + TypeInt128 -> unpack v + TypeWord8 -> unpack v + TypeWord16 -> unpack v + TypeWord32 -> unpack v + TypeWord64 -> unpack v + TypeWord128 -> unpack v + + floating :: FloatingType (Prim.Vec2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) + floating (SingleFloatingType t) = case t of + floating (VectorFloatingType n t) = + let v = NumScalarType (FloatingNumType (VectorFloatingType n t)) + in case t of + TypeFloat16 -> unpack v + TypeFloat32 -> unpack v + TypeFloat64 -> unpack v + TypeFloat128 -> unpack v + + unpack :: ScalarType (Prim.Vec 2 t) -> SmartExp (Prim.Vec 2 t) -> (SmartExp t, SmartExp t) + unpack v x = + let r = SmartExp (Extract v TypeWord8 x (SmartExp (Const scalarType 0))) + i = SmartExp (Extract v TypeWord8 x (SmartExp (Const scalarType 1))) + in + (r, i) -coerce :: EltR a ~ EltR b => Exp a -> Exp b -coerce (Exp e) = Exp e instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Complex a) where type Plain (Complex a) = Complex (Plain a) @@ -198,10 +272,16 @@ instance Eq a => Eq (Complex a) where r1 ::+ c1 /= r2 ::+ c2 = r1 /= r2 || c1 /= c2 instance RealFloat a => P.Num (Exp (Complex a)) where - (+) = lift2 ((+) :: Complex (Exp a) -> Complex (Exp a) -> Complex (Exp a)) - (-) = lift2 ((-) :: Complex (Exp a) -> Complex (Exp a) -> Complex (Exp a)) - (*) = lift2 ((*) :: Complex (Exp a) -> Complex (Exp a) -> Complex (Exp a)) - negate = lift1 (negate :: Complex (Exp a) -> Complex (Exp a)) + (+) = case complexR (eltR @a) of + ComplexTup -> lift2 ((+) :: Complex (Exp a) -> Complex (Exp a) -> Complex (Exp a)) + ComplexVec t -> mkPrimBinary $ PrimAdd t + (-) = case complexR (eltR @a) of + ComplexTup -> lift2 ((-) :: Complex (Exp a) -> Complex (Exp a) -> Complex (Exp a)) + ComplexVec t -> mkPrimBinary $ PrimSub t + (*) = lift2 ((*) :: Complex (Exp a) -> Complex (Exp a) -> Complex (Exp a)) + negate = case complexR (eltR @a) of + ComplexTup -> lift1 (negate :: Complex (Exp a) -> Complex (Exp a)) + ComplexVec t -> mkPrimUnary $ PrimNeg t signum z@(x ::+ y) = if z == 0 then z diff --git a/src/Data/Array/Accelerate/Data/Either.hs b/src/Data/Array/Accelerate/Data/Either.hs index 3c5c7401d..af50ce6d5 100644 --- a/src/Data/Array/Accelerate/Data/Either.hs +++ b/src/Data/Array/Accelerate/Data/Either.hs @@ -33,6 +33,7 @@ module Data.Array.Accelerate.Data.Either ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.Language import Data.Array.Accelerate.Lift @@ -64,7 +65,7 @@ isLeft = not . isRight -- | Return 'True' if the argument is a 'Right'-value -- isRight :: (Elt a, Elt b) => Exp (Either a b) -> Exp Bool -isRight (Exp e) = Exp $ SmartExp $ (SmartExp $ Prj PairIdxLeft e) `Pair` SmartExp Nil +isRight (Exp e) = mkExp $ PrimApp (PrimToBool integralType bitType) (SmartExp $ Prj PairIdxLeft e) -- TLM: This is a sneaky hack because we know that the tag bits for Right -- and True are identical. @@ -73,14 +74,14 @@ isRight (Exp e) = Exp $ SmartExp $ (SmartExp $ Prj PairIdxLeft e) `Pair` SmartEx -- instead. -- fromLeft :: (Elt a, Elt b) => Exp (Either a b) -> Exp a -fromLeft (Exp e) = Exp $ SmartExp $ Prj PairIdxRight $ SmartExp $ Prj PairIdxLeft $ SmartExp $ Prj PairIdxRight e +fromLeft (Exp e) = mkExp $ Prj PairIdxRight $ SmartExp $ Prj PairIdxLeft $ SmartExp $ Prj PairIdxRight e -- | The 'fromRight' function extracts the element out of the 'Right' -- constructor. If the argument was actually 'Left', you will get an undefined -- value instead. -- fromRight :: (Elt a, Elt b) => Exp (Either a b) -> Exp b -fromRight (Exp e) = Exp $ SmartExp $ Prj PairIdxRight $ SmartExp $ Prj PairIdxRight e +fromRight (Exp e) = mkExp $ Prj PairIdxRight $ SmartExp $ Prj PairIdxRight e -- | The 'either' function performs case analysis on the 'Either' type. If the -- value is @'Left' a@, apply the first function to @a@; if it is @'Right' b@, diff --git a/src/Data/Array/Accelerate/Data/Maybe.hs b/src/Data/Array/Accelerate/Data/Maybe.hs index 14e8b2ade..11366398b 100644 --- a/src/Data/Array/Accelerate/Data/Maybe.hs +++ b/src/Data/Array/Accelerate/Data/Maybe.hs @@ -33,6 +33,7 @@ module Data.Array.Accelerate.Data.Maybe ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.Language import Data.Array.Accelerate.Lift @@ -65,7 +66,7 @@ isNothing = not . isJust -- | Returns 'True' if the argument is of the form @Just _@ -- isJust :: Elt a => Exp (Maybe a) -> Exp Bool -isJust (Exp x) = Exp $ SmartExp $ (SmartExp $ Prj PairIdxLeft x) `Pair` SmartExp Nil +isJust (Exp x) = mkExp $ PrimApp (PrimToBool integralType bitType) (SmartExp $ Prj PairIdxLeft x) -- TLM: This is a sneaky hack because we know that the tag bits for Just -- and True are identical. @@ -134,9 +135,11 @@ instance (Monoid (Exp a), Elt a) => Monoid (Exp (Maybe a)) where mempty = Nothing_ instance (Semigroup (Exp a), Elt a) => Semigroup (Exp (Maybe a)) where - ma <> mb = cond (isNothing ma) mb - $ cond (isNothing mb) mb - $ lift (Just (fromJust ma <> fromJust mb)) + (<>) = match go + where + go Nothing_ b = b + go a Nothing_ = a + go (Just_ a) (Just_ b) = Just_ (a <> b) instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Maybe a) where type Plain (Maybe a) = Maybe (Plain a) diff --git a/src/Data/Array/Accelerate/Data/Ratio.hs b/src/Data/Array/Accelerate/Data/Ratio.hs index 190317297..d54d5b960 100644 --- a/src/Data/Array/Accelerate/Data/Ratio.hs +++ b/src/Data/Array/Accelerate/Data/Ratio.hs @@ -6,6 +6,7 @@ {-# LANGUAGE RebindableSyntax #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | @@ -34,7 +35,6 @@ import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Prelude import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt -import Data.Array.Accelerate.Type import Data.Array.Accelerate.Classes.Enum import Data.Array.Accelerate.Classes.Eq @@ -75,7 +75,7 @@ reduce x y = -- infixl 7 % (%) :: Integral a => Exp a -> Exp a -> Exp (Ratio a) -x % y = reduce (x * signum y) (abs y) +x % y = reduce (x * signum y) (abs y) infinity :: Integral a => Exp (Ratio a) infinity = 1 :% 0 @@ -109,16 +109,17 @@ instance Integral a => P.Fractional (Exp (Ratio a)) where else y :% x fromRational r = fromInteger (P.numerator r) % fromInteger (P.denominator r) -instance (Integral a, FromIntegral a Int64) => RealFrac (Ratio a) where +instance Integral a => RealFrac (Ratio a) where + type Significand (Ratio a) = a properFraction (x :% y) = let (q,r) = quotRem x y - in (fromIntegral (fromIntegral q :: Exp Int64), r :% y) + in T2 (fromIntegral q) (r :% y) instance (Integral a, ToFloating a b) => ToFloating (Ratio a) b where toFloating (x :% y) = let x' :% y' = reduce x y - in toFloating x' / toFloating y' + in toFloating x' / toFloating y' instance (FromIntegral a b, Integral b) => FromIntegral a (Ratio b) where fromIntegral x = fromIntegral x :% 1 diff --git a/src/Data/Array/Accelerate/Error.hs b/src/Data/Array/Accelerate/Error.hs index 3f00a6b5c..de6741a96 100644 --- a/src/Data/Array/Accelerate/Error.hs +++ b/src/Data/Array/Accelerate/Error.hs @@ -58,7 +58,7 @@ unsafeCheck = withFrozenCallStack $ check Unsafe -- | Throw an error if the index is not in range, otherwise evaluate the result. -- -indexCheck :: HasCallStack => Int -> Int -> a -> a +indexCheck :: (HasCallStack, Integral i) => i -> i -> a -> a indexCheck i n = boundsCheck (bformat ("index out of bounds: i=" % int % ", n=" % int) i n) (i >= 0 && i < n) diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index aee68443f..2f7548344 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -1,8 +1,11 @@ {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ParallelListComp #-} {-# LANGUAGE PatternGuards #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RecordWildCards #-} @@ -37,7 +40,7 @@ module Data.Array.Accelerate.Interpreter ( run, run1, runN, -- Internal (hidden) - evalPrim, evalPrimConst, evalCoerceScalar, atraceOp, + evalPrim, evalCoerceScalar, atraceOp, ) where @@ -53,28 +56,28 @@ import Data.Array.Accelerate.Representation.Slice import Data.Array.Accelerate.Representation.Stencil import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type -import Data.Array.Accelerate.Representation.Vec import Data.Array.Accelerate.Trafo import Data.Array.Accelerate.Trafo.Delayed ( DelayedOpenAfun, DelayedOpenAcc ) import Data.Array.Accelerate.Trafo.Sharing ( AfunctionR, AfunctionRepr(..), afunctionRepr ) import Data.Array.Accelerate.Type -import Data.Primitive.Vec +import Data.Primitive.Bit as Bit +import Data.Primitive.Vec as Vec +import Data.Primitive.Vec as Prim import qualified Data.Array.Accelerate.AST as AST import qualified Data.Array.Accelerate.Debug.Internal.Flags as Debug import qualified Data.Array.Accelerate.Debug.Internal.Graph as Debug import qualified Data.Array.Accelerate.Debug.Internal.Stats as Debug import qualified Data.Array.Accelerate.Debug.Internal.Timed as Debug +import qualified Data.Array.Accelerate.Interpreter.Arithmetic as A import qualified Data.Array.Accelerate.Smart as Smart import qualified Data.Array.Accelerate.Sugar.Array as Sugar import qualified Data.Array.Accelerate.Sugar.Elt as Sugar import qualified Data.Array.Accelerate.Trafo.Delayed as AST -import GHC.TypeLits import Control.DeepSeq import Control.Exception import Control.Monad import Control.Monad.ST -import Data.Bits import Data.Primitive.ByteArray import Data.Primitive.Types import Data.Text.Lazy.Builder @@ -82,8 +85,10 @@ import Formatting import System.IO import System.IO.Unsafe ( unsafePerformIO ) import Unsafe.Coerce -import qualified Data.Text.IO as T import Prelude hiding ( (!!), sum ) +import qualified Data.Text.IO as T + +import GHC.TypeLits -- Program execution @@ -126,13 +131,6 @@ runN f = go eval AfunctionReprBody (Abody b) aenv = unsafePerformIO $ phase "execute" Debug.elapsed (Sugar.toArr . snd <$> evaluate (evalOpenAcc b aenv)) eval _ _aenv _ = error "Two men say they're Jesus; one of them must be wrong" --- -- | Stream a lazily read list of input arrays through the given program, --- -- collecting results as we go --- -- --- streamOut :: Arrays a => Sugar.Seq [a] -> [a] --- streamOut seq = let seq' = convertSeqWith config seq --- in evalDelayedSeq defaultSeqConfig seq' - -- Debugging -- --------- @@ -152,7 +150,7 @@ data Delayed a where Delayed :: ArrayR (Array sh e) -> sh -> (sh -> e) - -> (Int -> e) + -> (INT -> e) -> Delayed (Array sh e) @@ -226,7 +224,7 @@ evalOpenAcc (AST.Manifest pacc) aenv = p = evalOpenAfun cond aenv f = evalOpenAfun body aenv go !x - | toBool (linearIndexArray (Sugar.eltR @Word8) (p x) 0) = go (f x) + | toBool (linearIndexArray (Sugar.eltR @Bool) (p x) 0) = go (f x) | otherwise = x Use repr arr -> (TupRsingle repr, arr) @@ -372,7 +370,7 @@ zipWithOp tp f (Delayed (ArrayR shr _) shx xs _) (Delayed _ shy ys _) foldOp :: (e -> e -> e) -> e - -> Delayed (Array (sh, Int) e) + -> Delayed (Array (sh, INT) e) -> WithReprs (Array sh e) foldOp f z (Delayed (ArrayR (ShapeRsnoc shr) tp) (sh, n) arr _) = fromFunction' (ArrayR shr tp) sh (\ix -> iter (ShapeRsnoc ShapeRz) ((), n) (\((), i) -> arr (ix, i)) f z) @@ -381,7 +379,7 @@ foldOp f z (Delayed (ArrayR (ShapeRsnoc shr) tp) (sh, n) arr _) fold1Op :: HasCallStack => (e -> e -> e) - -> Delayed (Array (sh, Int) e) + -> Delayed (Array (sh, INT) e) -> WithReprs (Array sh e) fold1Op f (Delayed (ArrayR (ShapeRsnoc shr) tp) (sh, n) arr _) = boundsCheck "empty array" (n > 0) @@ -390,12 +388,12 @@ fold1Op f (Delayed (ArrayR (ShapeRsnoc shr) tp) (sh, n) arr _) foldSegOp :: HasCallStack - => IntegralType i + => SingleIntegralType i -> (e -> e -> e) -> e - -> Delayed (Array (sh, Int) e) + -> Delayed (Array (sh, INT) e) -> Delayed (Segments i) - -> WithReprs (Array (sh, Int) e) + -> WithReprs (Array (sh, INT) e) foldSegOp itp f z (Delayed repr (sh, _) arr _) (Delayed _ ((), n) _ seg) | IntegralDict <- integralDict itp = boundsCheck "empty segment descriptor" (n > 0) @@ -409,11 +407,11 @@ foldSegOp itp f z (Delayed repr (sh, _) arr _) (Delayed _ ((), n) _ seg) fold1SegOp :: HasCallStack - => IntegralType i + => SingleIntegralType i -> (e -> e -> e) - -> Delayed (Array (sh, Int) e) + -> Delayed (Array (sh, INT) e) -> Delayed (Segments i) - -> WithReprs (Array (sh, Int) e) + -> WithReprs (Array (sh, INT) e) fold1SegOp itp f (Delayed repr (sh, _) arr _) (Delayed _ ((), n) _ seg) | IntegralDict <- integralDict itp = boundsCheck "empty segment descriptor" (n > 0) @@ -428,8 +426,8 @@ fold1SegOp itp f (Delayed repr (sh, _) arr _) (Delayed _ ((), n) _ seg) scanl1Op :: forall sh e. HasCallStack => (e -> e -> e) - -> Delayed (Array (sh, Int) e) - -> WithReprs (Array (sh, Int) e) + -> Delayed (Array (sh, INT) e) + -> WithReprs (Array (sh, INT) e) scanl1Op f (Delayed (ArrayR shr tp) sh ain _) = ( TupRsingle $ ArrayR shr tp , adata `seq` Array sh adata @@ -437,13 +435,13 @@ scanl1Op f (Delayed (ArrayR shr tp) sh ain _) where -- (adata, _) = runArrayData @e $ do - aout <- newArrayData tp (size shr sh) + aout <- newArrayData tp (fromIntegral $ size shr sh) - let write (sz, 0) = writeArrayData tp aout (toIndex shr sh (sz, 0)) (ain (sz, 0)) + let write (sz, 0) = writeArrayData tp aout (fromIntegral $ toIndex shr sh (sz, 0)) (ain (sz, 0)) write (sz, i) = do - x <- readArrayData tp aout (toIndex shr sh (sz, i-1)) + x <- readArrayData tp aout (fromIntegral $ toIndex shr sh (sz, i-1)) let y = ain (sz, i) - writeArrayData tp aout (toIndex shr sh (sz, i)) (f x y) + writeArrayData tp aout (fromIntegral $ toIndex shr sh (sz, i)) (f x y) iter shr sh write (>>) (return ()) return (aout, undefined) @@ -453,8 +451,8 @@ scanlOp :: forall sh e. (e -> e -> e) -> e - -> Delayed (Array (sh, Int) e) - -> WithReprs (Array (sh, Int) e) + -> Delayed (Array (sh, INT) e) + -> WithReprs (Array (sh, INT) e) scanlOp f z (Delayed (ArrayR shr tp) (sh, n) ain _) = ( TupRsingle $ ArrayR shr tp , adata `seq` Array sh' adata @@ -463,13 +461,13 @@ scanlOp f z (Delayed (ArrayR shr tp) (sh, n) ain _) sh' = (sh, n+1) -- (adata, _) = runArrayData @e $ do - aout <- newArrayData tp (size shr sh') + aout <- newArrayData tp (fromIntegral $ size shr sh') - let write (sz, 0) = writeArrayData tp aout (toIndex shr sh' (sz, 0)) z + let write (sz, 0) = writeArrayData tp aout (fromIntegral $ toIndex shr sh' (sz, 0)) z write (sz, i) = do - x <- readArrayData tp aout (toIndex shr sh' (sz, i-1)) + x <- readArrayData tp aout (fromIntegral $ toIndex shr sh' (sz, i-1)) let y = ain (sz, i-1) - writeArrayData tp aout (toIndex shr sh' (sz, i)) (f x y) + writeArrayData tp aout (fromIntegral $ toIndex shr sh' (sz, i)) (f x y) iter shr sh' write (>>) (return ()) return (aout, undefined) @@ -479,26 +477,26 @@ scanl'Op :: forall sh e. (e -> e -> e) -> e - -> Delayed (Array (sh, Int) e) - -> WithReprs (Array (sh, Int) e, Array sh e) + -> Delayed (Array (sh, INT) e) + -> WithReprs (Array (sh, INT) e, Array sh e) scanl'Op f z (Delayed (ArrayR shr@(ShapeRsnoc shr') tp) (sh, n) ain _) = ( TupRsingle (ArrayR shr tp) `TupRpair` TupRsingle (ArrayR shr' tp) , aout `seq` asum `seq` ( Array (sh, n) aout, Array sh asum ) ) where ((aout, asum), _) = runArrayData @(e, e) $ do - aout <- newArrayData tp (size shr (sh, n)) - asum <- newArrayData tp (size shr' sh) + aout <- newArrayData tp (fromIntegral $ size shr (sh, n)) + asum <- newArrayData tp (fromIntegral $ size shr' sh) let write (sz, 0) - | n == 0 = writeArrayData tp asum (toIndex shr' sh sz) z - | otherwise = writeArrayData tp aout (toIndex shr (sh, n) (sz, 0)) z + | n == 0 = writeArrayData tp asum (fromIntegral $ toIndex shr' sh sz) z + | otherwise = writeArrayData tp aout (fromIntegral $ toIndex shr (sh, n) (sz, 0)) z write (sz, i) = do - x <- readArrayData tp aout (toIndex shr (sh, n) (sz, i-1)) + x <- readArrayData tp aout (fromIntegral $ toIndex shr (sh, n) (sz, i-1)) let y = ain (sz, i-1) if i == n - then writeArrayData tp asum (toIndex shr' sh sz) (f x y) - else writeArrayData tp aout (toIndex shr (sh, n) (sz, i)) (f x y) + then writeArrayData tp asum (fromIntegral $ toIndex shr' sh sz) (f x y) + else writeArrayData tp aout (fromIntegral $ toIndex shr (sh, n) (sz, i)) (f x y) iter shr (sh, n+1) write (>>) (return ()) return ((aout, asum), undefined) @@ -508,8 +506,8 @@ scanrOp :: forall sh e. (e -> e -> e) -> e - -> Delayed (Array (sh, Int) e) - -> WithReprs (Array (sh, Int) e) + -> Delayed (Array (sh, INT) e) + -> WithReprs (Array (sh, INT) e) scanrOp f z (Delayed (ArrayR shr tp) (sz, n) ain _) = ( TupRsingle (ArrayR shr tp) , adata `seq` Array sh' adata @@ -518,13 +516,13 @@ scanrOp f z (Delayed (ArrayR shr tp) (sz, n) ain _) sh' = (sz, n+1) -- (adata, _) = runArrayData @e $ do - aout <- newArrayData tp (size shr sh') + aout <- newArrayData tp (fromIntegral $ size shr sh') - let write (sz, 0) = writeArrayData tp aout (toIndex shr sh' (sz, n)) z + let write (sz, 0) = writeArrayData tp aout (fromIntegral $ toIndex shr sh' (sz, n)) z write (sz, i) = do let x = ain (sz, n-i) - y <- readArrayData tp aout (toIndex shr sh' (sz, n-i+1)) - writeArrayData tp aout (toIndex shr sh' (sz, n-i)) (f x y) + y <- readArrayData tp aout (fromIntegral $ toIndex shr sh' (sz, n-i+1)) + writeArrayData tp aout (fromIntegral $ toIndex shr sh' (sz, n-i)) (f x y) iter shr sh' write (>>) (return ()) return (aout, undefined) @@ -533,21 +531,21 @@ scanrOp f z (Delayed (ArrayR shr tp) (sz, n) ain _) scanr1Op :: forall sh e. HasCallStack => (e -> e -> e) - -> Delayed (Array (sh, Int) e) - -> WithReprs (Array (sh, Int) e) + -> Delayed (Array (sh, INT) e) + -> WithReprs (Array (sh, INT) e) scanr1Op f (Delayed (ArrayR shr tp) sh@(_, n) ain _) = ( TupRsingle $ ArrayR shr tp , adata `seq` Array sh adata ) where (adata, _) = runArrayData @e $ do - aout <- newArrayData tp (size shr sh) + aout <- newArrayData tp (fromIntegral $ size shr sh) - let write (sz, 0) = writeArrayData tp aout (toIndex shr sh (sz, n-1)) (ain (sz, n-1)) + let write (sz, 0) = writeArrayData tp aout (fromIntegral $ toIndex shr sh (sz, n-1)) (ain (sz, n-1)) write (sz, i) = do let x = ain (sz, n-i-1) - y <- readArrayData tp aout (toIndex shr sh (sz, n-i)) - writeArrayData tp aout (toIndex shr sh (sz, n-i-1)) (f x y) + y <- readArrayData tp aout (fromIntegral $ toIndex shr sh (sz, n-i)) + writeArrayData tp aout (fromIntegral $ toIndex shr sh (sz, n-i-1)) (f x y) iter shr sh write (>>) (return ()) return (aout, undefined) @@ -557,27 +555,27 @@ scanr'Op :: forall sh e. (e -> e -> e) -> e - -> Delayed (Array (sh, Int) e) - -> WithReprs (Array (sh, Int) e, Array sh e) + -> Delayed (Array (sh, INT) e) + -> WithReprs (Array (sh, INT) e, Array sh e) scanr'Op f z (Delayed (ArrayR shr@(ShapeRsnoc shr') tp) (sh, n) ain _) = ( TupRsingle (ArrayR shr tp) `TupRpair` TupRsingle (ArrayR shr' tp) , aout `seq` asum `seq` ( Array (sh, n) aout, Array sh asum ) ) where ((aout, asum), _) = runArrayData @(e, e) $ do - aout <- newArrayData tp (size shr (sh, n)) - asum <- newArrayData tp (size shr' sh) + aout <- newArrayData tp (fromIntegral $ size shr (sh, n)) + asum <- newArrayData tp (fromIntegral $ size shr' sh) let write (sz, 0) - | n == 0 = writeArrayData tp asum (toIndex shr' sh sz) z - | otherwise = writeArrayData tp aout (toIndex shr (sh, n) (sz, n-1)) z + | n == 0 = writeArrayData tp asum (fromIntegral $ toIndex shr' sh sz) z + | otherwise = writeArrayData tp aout (fromIntegral $ toIndex shr (sh, n) (sz, n-1)) z write (sz, i) = do let x = ain (sz, n-i) - y <- readArrayData tp aout (toIndex shr (sh, n) (sz, n-i)) + y <- readArrayData tp aout (fromIntegral $ toIndex shr (sh, n) (sz, n-i)) if i == n - then writeArrayData tp asum (toIndex shr' sh sz) (f x y) - else writeArrayData tp aout (toIndex shr (sh, n) (sz, n-i-1)) (f x y) + then writeArrayData tp asum (fromIntegral $ toIndex shr' sh sz) (f x y) + else writeArrayData tp aout (fromIntegral $ toIndex shr (sh, n) (sz, n-i-1)) (f x y) iter shr (sh, n+1) write (>>) (return ()) return ((aout, asum), undefined) @@ -597,14 +595,14 @@ permuteOp f (TupRsingle (ArrayR shr' _), def@(Array _ adef)) p (Delayed (ArrayR n' = size shr' sh' -- (adata, _) = runArrayData @e $ do - aout <- newArrayData tp n' + aout <- newArrayData tp (fromIntegral n') let -- initialise array with default values init i | i >= n' = return () | otherwise = do - x <- readArrayData tp adef i - writeArrayData tp aout i x + x <- readArrayData tp adef (fromIntegral i) + writeArrayData tp aout (fromIntegral i) x init (i+1) -- project each element onto the destination array and update @@ -616,8 +614,8 @@ permuteOp f (TupRsingle (ArrayR shr' _), def@(Array _ adef)) p (Delayed (ArrayR j = toIndex shr' sh' dst x = ain i -- - y <- readArrayData tp aout j - writeArrayData tp aout j (f x y) + y <- readArrayData tp aout (fromIntegral j) + writeArrayData tp aout (fromIntegral j) (f x y) _ -> internalError "unexpected tag" init 0 @@ -783,13 +781,13 @@ stencilAccess stencil = goR (stencilShapeR stencil) stencil -- Add a left-most component to an index -- - cons :: ShapeR sh -> Int -> sh -> (sh, Int) + cons :: ShapeR sh -> INT -> sh -> (sh, INT) cons ShapeRz ix () = ((), ix) cons (ShapeRsnoc shr) ix (sh, sz) = (cons shr ix sh, sz) -- Remove the left-most index of an index, and return the remainder -- - uncons :: ShapeR sh -> (sh, Int) -> (Int, sh) + uncons :: ShapeR sh -> (sh, INT) -> (INT, sh) uncons ShapeRz ((), v) = (v, ()) uncons (ShapeRsnoc shr) (v1, v2) = let (i, v1') = uncons shr v1 in (i, (v1', v2)) @@ -839,17 +837,6 @@ bounded shr bnd (Delayed _ sh f _) ix = _ -> internalError "unexpected boundary condition" | otherwise = iz --- toSeqOp :: forall slix sl dim co e proxy. (Elt slix, Shape sl, Shape dim, Elt e) --- => SliceIndex (EltRepr slix) --- (EltRepr sl) --- co --- (EltRepr dim) --- -> proxy slix --- -> Array dim e --- -> [Array sl e] --- toSeqOp sliceIndex _ arr = map (sliceOp sliceIndex arr :: slix -> Array sl e) --- (enumSlices sliceIndex (shape arr)) - -- Stencil boundary conditions -- --------------------------- @@ -934,16 +921,18 @@ evalOpenExp pexp env aenv = Evar (Var _ ix) -> prj ix env Const _ c -> c Undef tp -> undefElt (TupRsingle tp) - PrimConst c -> evalPrimConst c PrimApp f x -> evalPrim f (evalE x) Nil -> () Pair e1 e2 -> let !x1 = evalE e1 !x2 = evalE e2 in (x1, x2) - VecPack vecR e -> pack vecR $! evalE e - VecUnpack vecR e -> unpack vecR $! evalE e - IndexSlice slice slix sh -> restrict slice (evalE slix) - (evalE sh) + Extract vR iR v i -> evalExtract vR iR (evalE v) (evalE i) + Insert vR iR v i x -> evalInsert vR iR (evalE v) (evalE i) (evalE x) + Shuffle rR iR x y i -> let TupRsingle eR = expType x + in evalShuffle eR rR iR (evalE x) (evalE y) (evalE i) + Select m x y -> let TupRsingle eR = expType x + in evalSelect eR (evalE m) (evalE x) (evalE y) + IndexSlice slice slix sh -> restrict slice (evalE slix) (evalE sh) where restrict :: SliceIndex slix sl co sh -> slix -> sh -> sl restrict SliceNil () () = () @@ -953,8 +942,7 @@ evalOpenExp pexp env aenv = restrict (SliceFixed sliceIdx) (slx, _i) (sl, _sz) = restrict sliceIdx slx sl - IndexFull slice slix sh -> extend slice (evalE slix) - (evalE sh) + IndexFull slice slix sh -> extend slice (evalE slix) (evalE sh) where extend :: SliceIndex slix sl co sh -> slix -> sl -> sh extend SliceNil () () = () @@ -967,7 +955,7 @@ evalOpenExp pexp env aenv = ToIndex shr sh ix -> toIndex shr (evalE sh) (evalE ix) FromIndex shr sh ix -> fromIndex shr (evalE sh) (evalE ix) - Case e rhs def -> evalE (caseof (evalE e) rhs) + Case _ e rhs def -> evalE (caseof (evalE e) rhs) where caseof :: TAG -> [(TAG, OpenExp env aenv t)] -> OpenExp env aenv t caseof tag = go @@ -1009,6 +997,185 @@ evalOpenExp pexp env aenv = -- destination values are equal (this is not checked at this point). -- evalCoerceScalar :: ScalarType a -> ScalarType b -> a -> b +evalCoerceScalar = scalar + where + scalar :: ScalarType a -> ScalarType b -> a -> b + scalar (NumScalarType a) = num a + scalar (BitScalarType a) = bit a + + bit :: BitType a -> ScalarType b -> a -> b + bit TypeBit = \case + BitScalarType TypeBit -> id + _ -> internalError "evalCoerceScalar @Bit" + bit (TypeMask _) = \case + NumScalarType b -> num' b + BitScalarType b -> bit' b + where + bit' :: BitType b -> Vec n Bit -> b + bit' TypeMask{} = unsafeCoerce + bit' TypeBit = internalError "evalCoerceScalar @Bit" + + num' :: NumType b -> Vec n Bit -> b + num' (IntegralNumType b) = integral' b + num' (FloatingNumType b) = floating' b + + integral' :: IntegralType b -> Vec n Bit -> b + integral' (VectorIntegralType _ _) = unsafeCoerce + integral' (SingleIntegralType b) + | IntegralDict <- integralDict b + = peek + + floating' :: FloatingType b -> Vec n Bit -> b + floating' (VectorFloatingType _ _) = unsafeCoerce + floating' (SingleFloatingType b) + | FloatingDict <- floatingDict b + = peek + + num :: NumType a -> ScalarType b -> a -> b + num (IntegralNumType a) = integral a + num (FloatingNumType t) = floating t + + integral :: IntegralType a -> ScalarType b -> a -> b + integral (SingleIntegralType a) = \case + NumScalarType b -> num' b a + BitScalarType b -> bit' b a + where + bit' :: BitType b -> SingleIntegralType a -> a -> b + bit' TypeBit _ = unsafeCoerce + bit' TypeMask{} a + | IntegralDict <- integralDict a + = poke + + num' :: NumType b -> SingleIntegralType a -> a -> b + num' (IntegralNumType b) = integral' b + num' (FloatingNumType b) = floating' b + + integral' :: IntegralType b -> SingleIntegralType a -> a -> b + integral' (SingleIntegralType _) _ = unsafeCoerce + integral' (VectorIntegralType _ b) a + | IntegralDict <- integralDict a + = case b of + TypeInt8 -> poke + TypeInt16 -> poke + TypeInt32 -> poke + TypeInt64 -> poke + TypeInt128 -> poke + TypeWord8 -> poke + TypeWord16 -> poke + TypeWord32 -> poke + TypeWord64 -> poke + TypeWord128 -> poke + + floating' :: FloatingType b -> SingleIntegralType a -> a -> b + floating' (SingleFloatingType _) _ = unsafeCoerce + floating' (VectorFloatingType _ b) a + | IntegralDict <- integralDict a + = case b of + TypeFloat16 -> poke + TypeFloat32 -> poke + TypeFloat64 -> poke + TypeFloat128 -> poke + + integral (VectorIntegralType _ a) = \case + NumScalarType b -> num' b a + BitScalarType b -> bit' b a + where + bit' :: BitType b -> SingleIntegralType a -> Vec n a -> b + bit' TypeBit _ = unsafeCoerce + bit' TypeMask{} _ = unsafeCoerce + + num' :: NumType b -> SingleIntegralType a -> Vec n a -> b + num' (IntegralNumType b) = integral' b + num' (FloatingNumType b) = floating' b + + integral' :: IntegralType b -> SingleIntegralType a -> Vec n a -> b + integral' (VectorIntegralType _ _) _ = unsafeCoerce + integral' (SingleIntegralType b) _ + | IntegralDict <- integralDict b + = peek + + floating' :: FloatingType b -> SingleIntegralType a -> Vec n a -> b + floating' (VectorFloatingType _ _) _ = unsafeCoerce + floating' (SingleFloatingType b) _ + | FloatingDict <- floatingDict b + = peek + + floating :: FloatingType a -> ScalarType b -> a -> b + floating (SingleFloatingType a) = \case + NumScalarType b -> num' b a + BitScalarType b -> bit' b a + where + bit' :: BitType b -> SingleFloatingType a -> a -> b + bit' TypeBit _ = unsafeCoerce + bit' TypeMask{} _ = unsafeCoerce + + num' :: NumType b -> SingleFloatingType a -> a -> b + num' (IntegralNumType b) = integral' b + num' (FloatingNumType b) = floating' b + + integral' :: IntegralType b -> SingleFloatingType a -> a -> b + integral' (SingleIntegralType _) _ = unsafeCoerce + integral' (VectorIntegralType _ b) a + | FloatingDict <- floatingDict a + = case b of + TypeInt8 -> poke + TypeInt16 -> poke + TypeInt32 -> poke + TypeInt64 -> poke + TypeInt128 -> poke + TypeWord8 -> poke + TypeWord16 -> poke + TypeWord32 -> poke + TypeWord64 -> poke + TypeWord128 -> poke + + floating' :: FloatingType b -> SingleFloatingType a -> a -> b + floating' (SingleFloatingType _) _ = unsafeCoerce + floating' (VectorFloatingType _ b) a + | FloatingDict <- floatingDict a + = case b of + TypeFloat16 -> poke + TypeFloat32 -> poke + TypeFloat64 -> poke + TypeFloat128 -> poke + + floating (VectorFloatingType _ a) = \case + NumScalarType b -> num' b a + BitScalarType b -> bit' b a + where + bit' :: BitType b -> SingleFloatingType a -> Vec n a -> b + bit' TypeBit _ = unsafeCoerce + bit' TypeMask{} _ = unsafeCoerce + + num' :: NumType b -> SingleFloatingType a -> Vec n a -> b + num' (IntegralNumType b) = integral' b + num' (FloatingNumType b) = floating' b + + integral' :: IntegralType b -> SingleFloatingType a -> Vec n a -> b + integral' (VectorIntegralType _ _) _ = unsafeCoerce + integral' (SingleIntegralType b) _ + | IntegralDict <- integralDict b + = peek + + floating' :: FloatingType b -> SingleFloatingType a -> Vec n a -> b + floating' (VectorFloatingType _ _) _ = unsafeCoerce + floating' (SingleFloatingType b) _ + | FloatingDict <- floatingDict b + = peek + + {-# INLINE poke #-} + poke :: forall a b n. Prim a => a -> Vec n b + poke x = runST $ do + mba <- newByteArray (sizeOf (undefined::a)) + writeByteArray mba 0 x + ByteArray ba# <- unsafeFreezeByteArray mba + return $ Vec ba# + + {-# INLINE peek #-} + peek :: Prim b => Vec n a -> b + peek (Vec ba#) = indexByteArray (ByteArray ba#) 0 + +{-- evalCoerceScalar SingleScalarType{} SingleScalarType{} a = unsafeCoerce a evalCoerceScalar VectorScalarType{} VectorScalarType{} a = unsafeCoerce a -- XXX: or just unpack/repack the (Vec ba#) evalCoerceScalar (SingleScalarType ta) VectorScalarType{} a = vector ta a @@ -1074,806 +1241,308 @@ evalCoerceScalar VectorScalarType{} (SingleScalarType tb) a = scalar tb a {-# INLINE peek #-} peek :: Prim a => Vec n b -> a peek (Vec ba#) = indexByteArray (ByteArray ba#) 0 - +--} -- Scalar primitives -- ----------------- -evalPrimConst :: PrimConst a -> a -evalPrimConst (PrimMinBound ty) = evalMinBound ty -evalPrimConst (PrimMaxBound ty) = evalMaxBound ty -evalPrimConst (PrimPi ty) = evalPi ty - evalPrim :: PrimFun (a -> r) -> (a -> r) -evalPrim (PrimAdd ty) = evalAdd ty -evalPrim (PrimSub ty) = evalSub ty -evalPrim (PrimMul ty) = evalMul ty -evalPrim (PrimNeg ty) = evalNeg ty -evalPrim (PrimAbs ty) = evalAbs ty -evalPrim (PrimSig ty) = evalSig ty -evalPrim (PrimQuot ty) = evalQuot ty -evalPrim (PrimRem ty) = evalRem ty -evalPrim (PrimQuotRem ty) = evalQuotRem ty -evalPrim (PrimIDiv ty) = evalIDiv ty -evalPrim (PrimMod ty) = evalMod ty -evalPrim (PrimDivMod ty) = evalDivMod ty -evalPrim (PrimBAnd ty) = evalBAnd ty -evalPrim (PrimBOr ty) = evalBOr ty -evalPrim (PrimBXor ty) = evalBXor ty -evalPrim (PrimBNot ty) = evalBNot ty -evalPrim (PrimBShiftL ty) = evalBShiftL ty -evalPrim (PrimBShiftR ty) = evalBShiftR ty -evalPrim (PrimBRotateL ty) = evalBRotateL ty -evalPrim (PrimBRotateR ty) = evalBRotateR ty -evalPrim (PrimPopCount ty) = evalPopCount ty -evalPrim (PrimCountLeadingZeros ty) = evalCountLeadingZeros ty -evalPrim (PrimCountTrailingZeros ty) = evalCountTrailingZeros ty -evalPrim (PrimFDiv ty) = evalFDiv ty -evalPrim (PrimRecip ty) = evalRecip ty -evalPrim (PrimSin ty) = evalSin ty -evalPrim (PrimCos ty) = evalCos ty -evalPrim (PrimTan ty) = evalTan ty -evalPrim (PrimAsin ty) = evalAsin ty -evalPrim (PrimAcos ty) = evalAcos ty -evalPrim (PrimAtan ty) = evalAtan ty -evalPrim (PrimSinh ty) = evalSinh ty -evalPrim (PrimCosh ty) = evalCosh ty -evalPrim (PrimTanh ty) = evalTanh ty -evalPrim (PrimAsinh ty) = evalAsinh ty -evalPrim (PrimAcosh ty) = evalAcosh ty -evalPrim (PrimAtanh ty) = evalAtanh ty -evalPrim (PrimExpFloating ty) = evalExpFloating ty -evalPrim (PrimSqrt ty) = evalSqrt ty -evalPrim (PrimLog ty) = evalLog ty -evalPrim (PrimFPow ty) = evalFPow ty -evalPrim (PrimLogBase ty) = evalLogBase ty -evalPrim (PrimTruncate ta tb) = evalTruncate ta tb -evalPrim (PrimRound ta tb) = evalRound ta tb -evalPrim (PrimFloor ta tb) = evalFloor ta tb -evalPrim (PrimCeiling ta tb) = evalCeiling ta tb -evalPrim (PrimAtan2 ty) = evalAtan2 ty -evalPrim (PrimIsNaN ty) = evalIsNaN ty -evalPrim (PrimIsInfinite ty) = evalIsInfinite ty -evalPrim (PrimLt ty) = evalLt ty -evalPrim (PrimGt ty) = evalGt ty -evalPrim (PrimLtEq ty) = evalLtEq ty -evalPrim (PrimGtEq ty) = evalGtEq ty -evalPrim (PrimEq ty) = evalEq ty -evalPrim (PrimNEq ty) = evalNEq ty -evalPrim (PrimMax ty) = evalMax ty -evalPrim (PrimMin ty) = evalMin ty -evalPrim PrimLAnd = evalLAnd -evalPrim PrimLOr = evalLOr -evalPrim PrimLNot = evalLNot -evalPrim (PrimFromIntegral ta tb) = evalFromIntegral ta tb -evalPrim (PrimToFloating ta tb) = evalToFloating ta tb - - --- Implementation of scalar primitives --- ----------------------------------- - -toBool :: PrimBool -> Bool -toBool 0 = False -toBool _ = True - -fromBool :: Bool -> PrimBool -fromBool False = 0 -fromBool True = 1 - -evalLAnd :: (PrimBool, PrimBool) -> PrimBool -evalLAnd (x, y) = fromBool (toBool x && toBool y) - -evalLOr :: (PrimBool, PrimBool) -> PrimBool -evalLOr (x, y) = fromBool (toBool x || toBool y) - -evalLNot :: PrimBool -> PrimBool -evalLNot = fromBool . not . toBool - -evalVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> (Vec n a, i) -> a -evalVectorIndex (VectorType n _) ti (v, i) | IntegralDict <- integralDict ti = vecIndex v (fromIntegral i) - -evalVectorWrite :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> (Vec n a, (i, a)) -> Vec n a -evalVectorWrite (VectorType n _) ti (v, (i, a)) | IntegralDict <- integralDict ti = vecWrite v (fromIntegral i) a - -evalFromIntegral :: IntegralType a -> NumType b -> a -> b -evalFromIntegral ta (IntegralNumType tb) - | IntegralDict <- integralDict ta - , IntegralDict <- integralDict tb - = fromIntegral - -evalFromIntegral ta (FloatingNumType tb) - | IntegralDict <- integralDict ta - , FloatingDict <- floatingDict tb - = fromIntegral - -evalToFloating :: NumType a -> FloatingType b -> a -> b -evalToFloating (IntegralNumType ta) tb - | IntegralDict <- integralDict ta - , FloatingDict <- floatingDict tb - = realToFrac - -evalToFloating (FloatingNumType ta) tb - | FloatingDict <- floatingDict ta - , FloatingDict <- floatingDict tb - = realToFrac - - --- Extract methods from reified dictionaries --- - --- Constant methods of Bounded --- - -evalMinBound :: BoundedType a -> a -evalMinBound (IntegralBoundedType ty) - | IntegralDict <- integralDict ty - = minBound - -evalMaxBound :: BoundedType a -> a -evalMaxBound (IntegralBoundedType ty) - | IntegralDict <- integralDict ty - = maxBound - --- Constant method of floating --- - -evalPi :: FloatingType a -> a -evalPi ty | FloatingDict <- floatingDict ty = pi - -evalVectorCreate :: (KnownNat n, Prim a) => VectorType (Vec n a) -> Vec n a -evalVectorCreate (VectorType n _) = vecEmpty - -evalSin :: FloatingType a -> (a -> a) -evalSin ty | FloatingDict <- floatingDict ty = sin - -evalCos :: FloatingType a -> (a -> a) -evalCos ty | FloatingDict <- floatingDict ty = cos - -evalTan :: FloatingType a -> (a -> a) -evalTan ty | FloatingDict <- floatingDict ty = tan - -evalAsin :: FloatingType a -> (a -> a) -evalAsin ty | FloatingDict <- floatingDict ty = asin - -evalAcos :: FloatingType a -> (a -> a) -evalAcos ty | FloatingDict <- floatingDict ty = acos - -evalAtan :: FloatingType a -> (a -> a) -evalAtan ty | FloatingDict <- floatingDict ty = atan - -evalSinh :: FloatingType a -> (a -> a) -evalSinh ty | FloatingDict <- floatingDict ty = sinh - -evalCosh :: FloatingType a -> (a -> a) -evalCosh ty | FloatingDict <- floatingDict ty = cosh - -evalTanh :: FloatingType a -> (a -> a) -evalTanh ty | FloatingDict <- floatingDict ty = tanh - -evalAsinh :: FloatingType a -> (a -> a) -evalAsinh ty | FloatingDict <- floatingDict ty = asinh - -evalAcosh :: FloatingType a -> (a -> a) -evalAcosh ty | FloatingDict <- floatingDict ty = acosh - -evalAtanh :: FloatingType a -> (a -> a) -evalAtanh ty | FloatingDict <- floatingDict ty = atanh - -evalExpFloating :: FloatingType a -> (a -> a) -evalExpFloating ty | FloatingDict <- floatingDict ty = exp - -evalSqrt :: FloatingType a -> (a -> a) -evalSqrt ty | FloatingDict <- floatingDict ty = sqrt - -evalLog :: FloatingType a -> (a -> a) -evalLog ty | FloatingDict <- floatingDict ty = log - -evalFPow :: FloatingType a -> ((a, a) -> a) -evalFPow ty | FloatingDict <- floatingDict ty = uncurry (**) - -evalLogBase :: FloatingType a -> ((a, a) -> a) -evalLogBase ty | FloatingDict <- floatingDict ty = uncurry logBase - -evalTruncate :: FloatingType a -> IntegralType b -> (a -> b) -evalTruncate ta tb - | FloatingDict <- floatingDict ta - , IntegralDict <- integralDict tb - = truncate - -evalRound :: FloatingType a -> IntegralType b -> (a -> b) -evalRound ta tb - | FloatingDict <- floatingDict ta - , IntegralDict <- integralDict tb - = round - -evalFloor :: FloatingType a -> IntegralType b -> (a -> b) -evalFloor ta tb - | FloatingDict <- floatingDict ta - , IntegralDict <- integralDict tb - = floor - -evalCeiling :: FloatingType a -> IntegralType b -> (a -> b) -evalCeiling ta tb - | FloatingDict <- floatingDict ta - , IntegralDict <- integralDict tb - = ceiling - -evalAtan2 :: FloatingType a -> ((a, a) -> a) -evalAtan2 ty | FloatingDict <- floatingDict ty = uncurry atan2 - -evalIsNaN :: FloatingType a -> (a -> PrimBool) -evalIsNaN ty | FloatingDict <- floatingDict ty = fromBool . isNaN - -evalIsInfinite :: FloatingType a -> (a -> PrimBool) -evalIsInfinite ty | FloatingDict <- floatingDict ty = fromBool . isInfinite - - --- Methods of Num --- - -evalAdd :: NumType a -> ((a, a) -> a) -evalAdd (IntegralNumType ty) | IntegralDict <- integralDict ty = uncurry (+) -evalAdd (FloatingNumType ty) | FloatingDict <- floatingDict ty = uncurry (+) - -evalSub :: NumType a -> ((a, a) -> a) -evalSub (IntegralNumType ty) | IntegralDict <- integralDict ty = uncurry (-) -evalSub (FloatingNumType ty) | FloatingDict <- floatingDict ty = uncurry (-) - -evalMul :: NumType a -> ((a, a) -> a) -evalMul (IntegralNumType ty) | IntegralDict <- integralDict ty = uncurry (*) -evalMul (FloatingNumType ty) | FloatingDict <- floatingDict ty = uncurry (*) - -evalNeg :: NumType a -> (a -> a) -evalNeg (IntegralNumType ty) | IntegralDict <- integralDict ty = negate -evalNeg (FloatingNumType ty) | FloatingDict <- floatingDict ty = negate - -evalAbs :: NumType a -> (a -> a) -evalAbs (IntegralNumType ty) | IntegralDict <- integralDict ty = abs -evalAbs (FloatingNumType ty) | FloatingDict <- floatingDict ty = abs - -evalSig :: NumType a -> (a -> a) -evalSig (IntegralNumType ty) | IntegralDict <- integralDict ty = signum -evalSig (FloatingNumType ty) | FloatingDict <- floatingDict ty = signum - -evalQuot :: IntegralType a -> ((a, a) -> a) -evalQuot ty | IntegralDict <- integralDict ty = uncurry quot - -evalRem :: IntegralType a -> ((a, a) -> a) -evalRem ty | IntegralDict <- integralDict ty = uncurry rem - -evalQuotRem :: IntegralType a -> ((a, a) -> (a, a)) -evalQuotRem ty | IntegralDict <- integralDict ty = uncurry quotRem - -evalIDiv :: IntegralType a -> ((a, a) -> a) -evalIDiv ty | IntegralDict <- integralDict ty = uncurry div - -evalMod :: IntegralType a -> ((a, a) -> a) -evalMod ty | IntegralDict <- integralDict ty = uncurry mod - -evalDivMod :: IntegralType a -> ((a, a) -> (a, a)) -evalDivMod ty | IntegralDict <- integralDict ty = uncurry divMod - -evalBAnd :: IntegralType a -> ((a, a) -> a) -evalBAnd ty | IntegralDict <- integralDict ty = uncurry (.&.) - -evalBOr :: IntegralType a -> ((a, a) -> a) -evalBOr ty | IntegralDict <- integralDict ty = uncurry (.|.) - -evalBXor :: IntegralType a -> ((a, a) -> a) -evalBXor ty | IntegralDict <- integralDict ty = uncurry xor - -evalBNot :: IntegralType a -> (a -> a) -evalBNot ty | IntegralDict <- integralDict ty = complement - -evalBShiftL :: IntegralType a -> ((a, Int) -> a) -evalBShiftL ty | IntegralDict <- integralDict ty = uncurry shiftL - -evalBShiftR :: IntegralType a -> ((a, Int) -> a) -evalBShiftR ty | IntegralDict <- integralDict ty = uncurry shiftR - -evalBRotateL :: IntegralType a -> ((a, Int) -> a) -evalBRotateL ty | IntegralDict <- integralDict ty = uncurry rotateL - -evalBRotateR :: IntegralType a -> ((a, Int) -> a) -evalBRotateR ty | IntegralDict <- integralDict ty = uncurry rotateR - -evalPopCount :: IntegralType a -> (a -> Int) -evalPopCount ty | IntegralDict <- integralDict ty = popCount - -evalCountLeadingZeros :: IntegralType a -> (a -> Int) -evalCountLeadingZeros ty | IntegralDict <- integralDict ty = countLeadingZeros - -evalCountTrailingZeros :: IntegralType a -> (a -> Int) -evalCountTrailingZeros ty | IntegralDict <- integralDict ty = countTrailingZeros - -evalFDiv :: FloatingType a -> ((a, a) -> a) -evalFDiv ty | FloatingDict <- floatingDict ty = uncurry (/) - -evalRecip :: FloatingType a -> (a -> a) -evalRecip ty | FloatingDict <- floatingDict ty = recip - - -evalLt :: SingleType a -> ((a, a) -> PrimBool) -evalLt (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = fromBool . uncurry (<) -evalLt (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = fromBool . uncurry (<) - -evalGt :: SingleType a -> ((a, a) -> PrimBool) -evalGt (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = fromBool . uncurry (>) -evalGt (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = fromBool . uncurry (>) - -evalLtEq :: SingleType a -> ((a, a) -> PrimBool) -evalLtEq (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = fromBool . uncurry (<=) -evalLtEq (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = fromBool . uncurry (<=) - -evalGtEq :: SingleType a -> ((a, a) -> PrimBool) -evalGtEq (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = fromBool . uncurry (>=) -evalGtEq (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = fromBool . uncurry (>=) - -evalEq :: SingleType a -> ((a, a) -> PrimBool) -evalEq (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = fromBool . uncurry (==) -evalEq (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = fromBool . uncurry (==) - -evalNEq :: SingleType a -> ((a, a) -> PrimBool) -evalNEq (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = fromBool . uncurry (/=) -evalNEq (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = fromBool . uncurry (/=) - -evalMax :: SingleType a -> ((a, a) -> a) -evalMax (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = uncurry max -evalMax (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = uncurry max - -evalMin :: SingleType a -> ((a, a) -> a) -evalMin (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = uncurry min -evalMin (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = uncurry min - - -{-- --- Sequence evaluation --- --------------- - --- Position in sequence. --- -type SeqPos = Int - --- Configuration for sequence evaluation. --- -data SeqConfig = SeqConfig - { chunkSize :: Int -- Allocation limit for a sequence in - -- words. Actual runtime allocation should be the - -- maximum of this size and the size of the - -- largest element in the sequence. - } - --- Default sequence evaluation configuration for testing purposes. --- -defaultSeqConfig :: SeqConfig -defaultSeqConfig = SeqConfig { chunkSize = 2 } - -type Chunk a = Vector' a - --- The empty chunk. O(1). -emptyChunk :: Arrays a => Chunk a -emptyChunk = empty' - --- Number of arrays in chunk. O(1). --- -clen :: Arrays a => Chunk a -> Int -clen = length' - -elemsPerChunk :: SeqConfig -> Int -> Int -elemsPerChunk conf n - | n < 1 = chunkSize conf - | otherwise = - let (a,b) = chunkSize conf `quotRem` n - in a + signum b - --- Drop a number of arrays from a chunk. O(1). Note: Require keeping a --- scan of element sizes. --- -cdrop :: Arrays a => Int -> Chunk a -> Chunk a -cdrop = drop' dropOp (fst . offsetsOp) - --- Get all the shapes of a chunk of arrays. O(1). --- -chunkShapes :: Chunk (Array sh a) -> Vector sh -chunkShapes = shapes' - --- Get all the elements of a chunk of arrays. O(1). --- -chunkElems :: Chunk (Array sh a) -> Vector a -chunkElems = elements' - --- Convert a vector to a chunk of scalars. --- -vec2Chunk :: Elt e => Vector e -> Chunk (Scalar e) -vec2Chunk = vec2Vec' - --- Convert a list of arrays to a chunk. --- -fromListChunk :: Arrays a => [a] -> Vector' a -fromListChunk = fromList' concatOp - --- Convert a chunk to a list of arrays. --- -toListChunk :: Arrays a => Vector' a -> [a] -toListChunk = toList' fetchAllOp - --- fmap for Chunk. O(n). --- TODO: Use vectorised function. -mapChunk :: (Arrays a, Arrays b) - => (a -> b) - -> Chunk a -> Chunk b -mapChunk f c = fromListChunk $ map f (toListChunk c) - --- zipWith for Chunk. O(n). --- TODO: Use vectorised function. -zipWithChunk :: (Arrays a, Arrays b, Arrays c) - => (a -> b -> c) - -> Chunk a -> Chunk b -> Chunk c -zipWithChunk f c1 c2 = fromListChunk $ zipWith f (toListChunk c1) (toListChunk c2) - --- A window on a sequence. --- -data Window a = Window - { chunk :: Chunk a -- Current allocated chunk. - , wpos :: SeqPos -- Position of the window on the sequence, given - -- in number of elements. - } - --- The initial empty window. --- -window0 :: Arrays a => Window a -window0 = Window { chunk = emptyChunk, wpos = 0 } - --- Index the given window by the given index on the sequence. --- -(!#) :: Arrays a => Window a -> SeqPos -> Chunk a -w !# i - | j <- i - wpos w - , j >= 0 - = cdrop j (chunk w) - -- - | otherwise - = error $ "Window indexed before position. wpos = " ++ show (wpos w) ++ " i = " ++ show i - --- Move the give window by supplying the next chunk. --- -moveWin :: Arrays a => Window a -> Chunk a -> Window a -moveWin w c = w { chunk = c - , wpos = wpos w + clen (chunk w) - } - --- A cursor on a sequence. --- -data Cursor senv a = Cursor - { ref :: Idx senv a -- Reference to the sequence. - , cpos :: SeqPos -- Position of the cursor on the sequence, - -- given in number of elements. - } - --- Initial cursor. --- -cursor0 :: Idx senv a -> Cursor senv a -cursor0 x = Cursor { ref = x, cpos = 0 } - --- Advance cursor by a relative amount. --- -moveCursor :: Int -> Cursor senv a -> Cursor senv a -moveCursor k c = c { cpos = cpos c + k } +evalPrim (PrimAdd t) = A.add t +evalPrim (PrimSub t) = A.sub t +evalPrim (PrimMul t) = A.mul t +evalPrim (PrimNeg t) = A.negate t +evalPrim (PrimAbs t) = A.abs t +evalPrim (PrimSig t) = A.signum t +evalPrim (PrimQuot t) = A.quot t +evalPrim (PrimRem t) = A.rem t +evalPrim (PrimQuotRem t) = A.quotRem t +evalPrim (PrimIDiv t) = A.div t +evalPrim (PrimMod t) = A.mod t +evalPrim (PrimDivMod t) = A.divMod t +evalPrim (PrimBAnd t) = A.band t +evalPrim (PrimBOr t) = A.bor t +evalPrim (PrimBXor t) = A.xor t +evalPrim (PrimBNot t) = A.complement t +evalPrim (PrimBShiftL t) = A.shiftL t +evalPrim (PrimBShiftR t) = A.shiftR t +evalPrim (PrimBRotateL t) = A.rotateL t +evalPrim (PrimBRotateR t) = A.rotateR t +evalPrim (PrimPopCount t) = A.popCount t +evalPrim (PrimCountLeadingZeros t) = A.countLeadingZeros t +evalPrim (PrimCountTrailingZeros t) = A.countTrailingZeros t +evalPrim (PrimFDiv t) = A.fdiv t +evalPrim (PrimRecip t) = A.recip t +evalPrim (PrimSin t) = A.sin t +evalPrim (PrimCos t) = A.cos t +evalPrim (PrimTan t) = A.tan t +evalPrim (PrimAsin t) = A.asin t +evalPrim (PrimAcos t) = A.acos t +evalPrim (PrimAtan t) = A.atan t +evalPrim (PrimSinh t) = A.sinh t +evalPrim (PrimCosh t) = A.cosh t +evalPrim (PrimTanh t) = A.tanh t +evalPrim (PrimAsinh t) = A.asinh t +evalPrim (PrimAcosh t) = A.acosh t +evalPrim (PrimAtanh t) = A.atanh t +evalPrim (PrimExpFloating t) = A.exp t +evalPrim (PrimSqrt t) = A.sqrt t +evalPrim (PrimLog t) = A.log t +evalPrim (PrimFPow t) = A.pow t +evalPrim (PrimLogBase t) = A.logBase t +evalPrim (PrimTruncate ta tb) = A.truncate ta tb +evalPrim (PrimRound ta tb) = A.round ta tb +evalPrim (PrimFloor ta tb) = A.floor ta tb +evalPrim (PrimCeiling ta tb) = A.ceiling ta tb +evalPrim (PrimAtan2 t) = A.atan2 t +evalPrim (PrimIsNaN t) = A.isNaN t +evalPrim (PrimIsInfinite t) = A.isInfinite t +evalPrim (PrimLt t) = A.lt t +evalPrim (PrimGt t) = A.gt t +evalPrim (PrimLtEq t) = A.lte t +evalPrim (PrimGtEq t) = A.gte t +evalPrim (PrimEq t) = A.eq t +evalPrim (PrimNEq t) = A.neq t +evalPrim (PrimMax t) = A.max t +evalPrim (PrimMin t) = A.min t +evalPrim (PrimLAnd t) = A.land t +evalPrim (PrimLOr t) = A.lor t +evalPrim (PrimLNot t) = A.lnot t +evalPrim (PrimFromIntegral ta tb) = A.fromIntegral ta tb +evalPrim (PrimToFloating ta tb) = A.toFloating ta tb +evalPrim (PrimToBool i b) = A.toBool i b +evalPrim (PrimFromBool b i) = A.fromBool b i + + +-- Vector primitives +-- ----------------- --- Valuation for an environment of sequence windows. --- -data Val' senv where - Empty' :: Val' () - Push' :: Val' senv -> Window t -> Val' (senv, t) +evalExtract :: ScalarType (Prim.Vec n a) -> SingleIntegralType i -> Prim.Vec n a -> i -> a +evalExtract vR iR v i = scalar vR v + where + scalar :: ScalarType (Prim.Vec n t) -> Prim.Vec n t -> t + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t --- Projection of a window from a window valuation using a de Bruijn --- index. --- -prj' :: Idx senv t -> Val' senv -> Window t -prj' ZeroIdx (Push' _ v) = v -prj' (SuccIdx idx) (Push' val _) = prj' idx val + num :: NumType (Prim.Vec n t) -> Prim.Vec n t -> t + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t --- Projection of a chunk from a window valuation using a sequence --- cursor. --- -prjChunk :: Arrays a => Cursor senv a -> Val' senv -> Chunk a -prjChunk c senv = prj' (ref c) senv !# cpos c + bit :: BitType (Prim.Vec n t) -> Prim.Vec n t -> t + bit TypeMask{} v + | IntegralDict <- integralDict iR + = Bit.extract (BitMask v) (fromIntegral i) + + integral :: IntegralType (Prim.Vec n t) -> Prim.Vec n t -> t + integral (SingleIntegralType tR) _ = case tR of + integral (VectorIntegralType _ tR) v + | IntegralDict <- integralDict tR + , IntegralDict <- integralDict iR + = Vec.extract v (fromIntegral i) + + floating :: FloatingType (Prim.Vec n t) -> Prim.Vec n t -> t + floating (SingleFloatingType tR) _ = case tR of + floating (VectorFloatingType _ tR) v + | FloatingDict <- floatingDict tR + , IntegralDict <- integralDict iR + = Vec.extract v (fromIntegral i) + +evalInsert + :: ScalarType (Prim.Vec n a) + -> SingleIntegralType i + -> Prim.Vec n a + -> i + -> a + -> Prim.Vec n a +evalInsert vR iR v i x = scalar vR v x + where + scalar :: ScalarType (Prim.Vec n t) -> Prim.Vec n t -> t -> Prim.Vec n t + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t --- An executable sequence. --- -data ExecSeq senv arrs where - ExecP :: Arrays a => Window a -> ExecP senv a -> ExecSeq (senv, a) arrs -> ExecSeq senv arrs - ExecC :: Arrays a => ExecC senv a -> ExecSeq senv a - ExecR :: Arrays a => Cursor senv a -> ExecSeq senv [a] + num :: NumType (Prim.Vec n t) -> Prim.Vec n t -> t -> Prim.Vec n t + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t --- An executable producer. --- -data ExecP senv a where - ExecStreamIn :: Int - -> [a] - -> ExecP senv a - - ExecMap :: Arrays a - => (Chunk a -> Chunk b) - -> Cursor senv a - -> ExecP senv b - - ExecZipWith :: (Arrays a, Arrays b) - => (Chunk a -> Chunk b -> Chunk c) - -> Cursor senv a - -> Cursor senv b - -> ExecP senv c - - -- Stream scan skeleton. - ExecScan :: Arrays a - => (s -> Chunk a -> (Chunk r, s)) -- Chunk scanner. - -> s -- Accumulator (internal state). - -> Cursor senv a -- Input stream. - -> ExecP senv r - --- An executable consumer. --- -data ExecC senv a where - - -- Stream reduction skeleton. - ExecFold :: Arrays a - => (s -> Chunk a -> s) -- Chunk consumer function. - -> (s -> r) -- Finalizer function. - -> s -- Accumulator (internal state). - -> Cursor senv a -- Input stream. - -> ExecC senv r - - ExecStuple :: IsAtuple a - => Atuple (ExecC senv) (TupleRepr a) - -> ExecC senv a - -minCursor :: ExecSeq senv a -> SeqPos -minCursor s = travS s 0 + bit :: BitType (Prim.Vec n t) -> Prim.Vec n t -> t -> Prim.Vec n t + bit TypeMask{} v x + | IntegralDict <- integralDict iR + = unMask $ Bit.insert (BitMask v) (fromIntegral i) x + + integral :: IntegralType (Prim.Vec n t) -> Prim.Vec n t -> t -> Prim.Vec n t + integral (SingleIntegralType tR) _ _ = case tR of + integral (VectorIntegralType _ tR) v x + | IntegralDict <- integralDict tR + , IntegralDict <- integralDict iR + = Vec.insert v (fromIntegral i) x + + floating :: FloatingType (Prim.Vec n t) -> Prim.Vec n t -> t -> Prim.Vec n t + floating (SingleFloatingType tR) _ _ = case tR of + floating (VectorFloatingType _ tR) v x + | FloatingDict <- floatingDict tR + , IntegralDict <- integralDict iR + = Vec.insert v (fromIntegral i) x + +evalShuffle + :: ScalarType (Prim.Vec n a) + -> ScalarType (Prim.Vec m a) + -> SingleIntegralType i + -> Prim.Vec n a + -> Prim.Vec n a + -> Prim.Vec m i + -> Prim.Vec m a +evalShuffle = scalar where - travS :: ExecSeq senv a -> Int -> SeqPos - travS s i = - case s of - ExecP _ p s' -> travP p i `min` travS s' (i+1) - ExecC c -> travC c i - ExecR _ -> maxBound - - k :: Cursor senv a -> Int -> SeqPos - k c i - | i == idxToInt (ref c) = cpos c - | otherwise = maxBound - - travP :: ExecP senv a -> Int -> SeqPos - travP p i = - case p of - ExecStreamIn _ _ -> maxBound - ExecMap _ c -> k c i - ExecZipWith _ c1 c2 -> k c1 i `min` k c2 i - ExecScan _ _ c -> k c i - - travT :: Atuple (ExecC senv) t -> Int -> SeqPos - travT NilAtup _ = maxBound - travT (SnocAtup t c) i = travT t i `min` travC c i - - travC :: ExecC senv a -> Int -> SeqPos - travC c i = - case c of - ExecFold _ _ _ cu -> k cu i - ExecStuple t -> travT t i - - -evalDelayedSeq - :: SeqConfig - -> DelayedSeq arrs - -> arrs -evalDelayedSeq cfg (DelayedSeq aenv s) | aenv' <- evalExtend aenv Empty - = evalSeq cfg s aenv' - -evalSeq :: forall aenv arrs. - SeqConfig - -> PreOpenSeq DelayedOpenAcc aenv () arrs - -> Val aenv -> arrs -evalSeq conf s aenv = evalSeq' s + scalar :: ScalarType (Prim.Vec n t) + -> ScalarType (Prim.Vec m t) + -> SingleIntegralType i + -> Prim.Vec n t + -> Prim.Vec n t + -> Prim.Vec m i + -> Prim.Vec m t + scalar (NumScalarType s) (NumScalarType t) = num s t + scalar (BitScalarType s) (BitScalarType t) = bit s t + scalar _ _ = internalError "unexpected vector encoding" + + num :: NumType (Prim.Vec n t) + -> NumType (Prim.Vec m t) + -> SingleIntegralType i + -> Prim.Vec n t + -> Prim.Vec n t + -> Prim.Vec m i + -> Prim.Vec m t + num (IntegralNumType s) (IntegralNumType t) = integral s t + num (FloatingNumType s) (FloatingNumType t) = floating s t + num _ _ = internalError "unexpected vector encoding" + + bit :: BitType (Prim.Vec n t) + -> BitType (Prim.Vec m t) + -> SingleIntegralType i + -> Prim.Vec n t + -> Prim.Vec n t + -> Prim.Vec m i + -> Prim.Vec m t + bit (TypeMask n#) TypeMask{} iR x y i + | IntegralDict <- integralDict iR + = let n = fromInteger (natVal' n#) + in unMask + $ Bit.fromList [ boundsCheck "vector index" (j >= 0 && j < 2*n) + $ if j < n then Bit.extract (BitMask x) j + else Bit.extract (BitMask y) (j - n) + | j <- map fromIntegral (Vec.toList i) ] + + integral :: IntegralType (Prim.Vec n t) + -> IntegralType (Prim.Vec m t) + -> SingleIntegralType i + -> Prim.Vec n t + -> Prim.Vec n t + -> Prim.Vec m i + -> Prim.Vec m t + integral (SingleIntegralType s) _ _ _ _ _ = case s of + integral _ (SingleIntegralType t) _ _ _ _ = case t of + integral (VectorIntegralType n# sR) (VectorIntegralType _ tR) iR x y i + | IntegralDict <- integralDict iR + , IntegralDict <- integralDict sR + , IntegralDict <- integralDict tR + = let n = fromInteger (natVal' n#) + in Vec.fromList [ boundsCheck "vector index" (j >= 0 && j < 2*n) + $ if j < n then Vec.extract x j + else Vec.extract y (j - n) + | j <- map fromIntegral (Vec.toList i) ] + + floating :: FloatingType (Prim.Vec n t) + -> FloatingType (Prim.Vec m t) + -> SingleIntegralType i + -> Prim.Vec n t + -> Prim.Vec n t + -> Prim.Vec m i + -> Prim.Vec m t + floating (SingleFloatingType s) _ _ _ _ _ = case s of + floating _ (SingleFloatingType t) _ _ _ _ = case t of + floating (VectorFloatingType n# sR) (VectorFloatingType _ tR) iR x y i + | IntegralDict <- integralDict iR + , FloatingDict <- floatingDict sR + , FloatingDict <- floatingDict tR + = let n = fromInteger (natVal' n#) + in Vec.fromList [ boundsCheck "vector index" (j >= 0 && j < 2*n) + $ if j < n then Vec.extract x j + else Vec.extract y (j - n) + | j <- map fromIntegral (Vec.toList i) ] + +evalSelect + :: ScalarType (Prim.Vec n a) + -> Prim.Vec n Bit + -> Prim.Vec n a + -> Prim.Vec n a + -> Prim.Vec n a +evalSelect = scalar where - evalSeq' :: PreOpenSeq DelayedOpenAcc aenv senv arrs -> arrs - evalSeq' (Producer _ s) = evalSeq' s - evalSeq' (Consumer _) = loop (initSeq aenv s) - evalSeq' (Reify _) = reify (initSeq aenv s) - - -- Initialize the producers and the accumulators of the consumers - -- with the given array enviroment. - initSeq :: forall senv arrs'. - Val aenv - -> PreOpenSeq DelayedOpenAcc aenv senv arrs' - -> ExecSeq senv arrs' - initSeq aenv s = - case s of - Producer p s' -> ExecP window0 (initProducer p) (initSeq aenv s') - Consumer c -> ExecC (initConsumer c) - Reify ix -> ExecR (cursor0 ix) - - -- Generate a list from the sequence. - reify :: forall arrs. ExecSeq () [arrs] - -> [arrs] - reify s = case step s Empty' of - (Just s', a) -> a ++ reify s' - (Nothing, a) -> a - - -- Iterate the given sequence until it terminates. - -- A sequence only terminates when one of the producers are exhausted. - loop :: Arrays arrs - => ExecSeq () arrs - -> arrs - loop s = - case step' s of - (Nothing, arrs) -> arrs - (Just s', _) -> loop s' + scalar :: ScalarType (Prim.Vec n t) -> Prim.Vec n Bit -> Prim.Vec n t -> Prim.Vec n t -> Prim.Vec n t + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t - where - step' :: ExecSeq () arrs -> (Maybe (ExecSeq () arrs), arrs) - step' s = step s Empty' - - -- One iteration of a sequence. - step :: forall senv arrs'. - ExecSeq senv arrs' - -> Val' senv - -> (Maybe (ExecSeq senv arrs'), arrs') - step s senv = - case s of - ExecP w p s' -> - let (c, mp') = produce p senv - finished = 0 == clen (w !# minCursor s') - w' = if finished then moveWin w c else w - (ms'', a) = step s' (senv `Push'` w') - in case ms'' of - Nothing -> (Nothing, a) - Just s'' | finished - , Just p' <- mp' - -> (Just (ExecP w' p' s''), a) - | not finished - -> (Just (ExecP w' p s''), a) - | otherwise - -> (Nothing, a) - ExecC c -> let (c', acc) = consume c senv - in (Just (ExecC c'), acc) - ExecR ix -> let c = prjChunk ix senv in (Just (ExecR (moveCursor (clen c) ix)), toListChunk c) - - evalA :: DelayedOpenAcc aenv a -> a - evalA acc = evalOpenAcc acc aenv - - evalAF :: DelayedOpenAfun aenv f -> f - evalAF f = evalOpenAfun f aenv - - evalE :: DelayedExp aenv t -> t - evalE exp = evalExp exp aenv - - evalF :: DelayedFun aenv f -> f - evalF fun = evalFun fun aenv - - initProducer :: forall a senv. - Producer DelayedOpenAcc aenv senv a - -> ExecP senv a - initProducer p = - case p of - StreamIn arrs -> ExecStreamIn 1 arrs - ToSeq sliceIndex slix (delayed -> Delayed sh ix _) -> - let n = R.size (R.sliceShape sliceIndex (fromElt sh)) - k = elemsPerChunk conf n - in ExecStreamIn k (toSeqOp sliceIndex slix (fromFunction sh ix)) - MapSeq f x -> ExecMap (mapChunk (evalAF f)) (cursor0 x) - ChunkedMapSeq f x -> ExecMap (evalAF f) (cursor0 x) - ZipWithSeq f x y -> ExecZipWith (zipWithChunk (evalAF f)) (cursor0 x) (cursor0 y) - ScanSeq f e x -> ExecScan scanner (evalE e) (cursor0 x) - where - scanner a c = - let v0 = chunkElems c - (v1, a') = scanl'Op (evalF f) a (delayArray v0) - in (vec2Chunk v1, fromScalar a') - - initConsumer :: forall a senv. - Consumer DelayedOpenAcc aenv senv a - -> ExecC senv a - initConsumer c = - case c of - FoldSeq f e x -> - let f' = evalF f - a0 = fromFunction (Z :. chunkSize conf) (const (evalE e)) - consumer v c = zipWith'Op f' (delayArray v) (delayArray (chunkElems c)) - finalizer = fold1Op f' . delayArray - in ExecFold consumer finalizer a0 (cursor0 x) - FoldSeqFlatten f acc x -> - let f' = evalAF f - a0 = evalA acc - consumer a c = f' a (chunkShapes c) (chunkElems c) - in ExecFold consumer id a0 (cursor0 x) - Stuple t -> - let initTup :: Atuple (Consumer DelayedOpenAcc aenv senv) t -> Atuple (ExecC senv) t - initTup NilAtup = NilAtup - initTup (SnocAtup t c) = SnocAtup (initTup t) (initConsumer c) - in ExecStuple (initTup t) - - delayed :: DelayedOpenAcc aenv (Array sh e) -> Delayed (Array sh e) - delayed AST.Manifest{} = $internalError "evalOpenAcc" "expected delayed array" - delayed AST.Delayed{..} = Delayed (evalExp extentD aenv) - (evalFun indexD aenv) - (evalFun linearIndexD aenv) - -produce :: Arrays a => ExecP senv a -> Val' senv -> (Chunk a, Maybe (ExecP senv a)) -produce p senv = - case p of - ExecStreamIn k xs -> - let (xs', xs'') = (take k xs, drop k xs) - c = fromListChunk xs' - mp = if null xs'' - then Nothing - else Just (ExecStreamIn k xs'') - in (c, mp) - ExecMap f x -> - let c = prjChunk x senv - in (f c, Just $ ExecMap f (moveCursor (clen c) x)) - ExecZipWith f x y -> - let c1 = prjChunk x senv - c2 = prjChunk y senv - k = clen c1 `min` clen c2 - in (f c1 c2, Just $ ExecZipWith f (moveCursor k x) (moveCursor k y)) - ExecScan scanner a x -> - let c = prjChunk x senv - (c', a') = scanner a c - k = clen c - in (c', Just $ ExecScan scanner a' (moveCursor k x)) - -consume :: forall senv a. ExecC senv a -> Val' senv -> (ExecC senv a, a) -consume c senv = - case c of - ExecFold f g acc x -> - let c = prjChunk x senv - acc' = f acc c - -- Even though we call g here, lazy evaluation should guarantee it is - -- only ever called once. - in (ExecFold f g acc' (moveCursor (clen c) x), g acc') - ExecStuple t -> - let consT :: Atuple (ExecC senv) t -> (Atuple (ExecC senv) t, t) - consT NilAtup = (NilAtup, ()) - consT (SnocAtup t c) | (c', acc) <- consume c senv - , (t', acc') <- consT t - = (SnocAtup t' c', (acc', acc)) - (t', acc) = consT t - in (ExecStuple t', toAtuple acc) - -evalExtend :: Extend DelayedOpenAcc aenv aenv' -> Val aenv -> Val aenv' -evalExtend BaseEnv aenv = aenv -evalExtend (PushEnv ext1 ext2) aenv | aenv' <- evalExtend ext1 aenv - = Push aenv' (evalOpenAcc ext2 aenv') - -delayArray :: Array sh e -> Delayed (Array sh e) -delayArray arr@(Array _ adata) = Delayed (shape arr) (arr!) (toElt . unsafeIndexArrayData adata) - -fromScalar :: Scalar a -> a -fromScalar = (!Z) - -concatOp :: forall e. Elt e => [Vector e] -> Vector e -concatOp = concatVectors - -fetchAllOp :: (Shape sh, Elt e) => Segments sh -> Vector e -> [Array sh e] -fetchAllOp segs elts - | (offsets, n) <- offsetsOp segs - , (n ! Z) <= size (shape elts) - = [fetch (segs ! (Z :. i)) (offsets ! (Z :. i)) | i <- [0 .. size (shape segs) - 1]] - | otherwise = error $ "illegal argument to fetchAllOp" - where - fetch sh offset = fromFunction sh (\ ix -> elts ! (Z :. ((toIndex sh ix) + offset))) - -dropOp :: Elt e => Int -> Vector e -> Vector e -dropOp i v -- TODO - -- * Implement using C-style pointer-plus. - -- ; dropOp is used often (from prjChunk), - -- so it ought to be efficient O(1). - | n <- size (shape v) - , i <= n - , i >= 0 - = fromFunction (Z :. n - i) (\ (Z :. j) -> v ! (Z :. i + j)) - | otherwise = error $ "illegal argument to drop" - -offsetsOp :: Shape sh => Segments sh -> (Vector Int, Scalar Int) -offsetsOp segs = scanl'Op (+) 0 $ delayArray (mapOp size (delayArray segs)) ---} + num :: NumType (Prim.Vec n t) -> Prim.Vec n Bit -> Prim.Vec n t -> Prim.Vec n t -> Prim.Vec n t + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + bit :: BitType (Prim.Vec n t) -> Prim.Vec n Bit -> Prim.Vec n t -> Prim.Vec n t -> Prim.Vec n t + bit TypeMask{} m x y + = unMask + $ Bit.fromList [ if unBit b then Bit.extract (BitMask x) i + else Bit.extract (BitMask y) i + | b <- Bit.toList (BitMask m) + | i <- [0..] + ] + + integral :: IntegralType (Prim.Vec n t) -> Prim.Vec n Bit -> Prim.Vec n t -> Prim.Vec n t -> Prim.Vec n t + integral (SingleIntegralType tR) _ _ _ = case tR of + integral (VectorIntegralType _ tR) m x y + | IntegralDict <- integralDict tR + = Vec.fromList [ if unBit b then Vec.extract x i + else Vec.extract y i + | b <- Bit.toList (BitMask m) + | i <- [0..] + ] + + floating :: FloatingType (Prim.Vec n t) -> Prim.Vec n Bit -> Prim.Vec n t -> Prim.Vec n t -> Prim.Vec n t + floating (SingleFloatingType tR) _ _ _ = case tR of + floating (VectorFloatingType _ tR) m x y + | FloatingDict <- floatingDict tR + = Vec.fromList [ if unBit b then Vec.extract x i + else Vec.extract y i + | b <- Bit.toList (BitMask m) + | i <- [0..] + ] + + +-- Utilities +-- --------- + +toBool :: PrimBool -> Bool +toBool = unBit + +data IntegralDict t where + IntegralDict :: (Integral t, Prim t) => IntegralDict t + +data FloatingDict t where + FloatingDict :: (RealFloat t, Prim t) => FloatingDict t + +{-# INLINE integralDict #-} +integralDict :: SingleIntegralType t -> IntegralDict t +integralDict TypeInt8 = IntegralDict +integralDict TypeInt16 = IntegralDict +integralDict TypeInt32 = IntegralDict +integralDict TypeInt64 = IntegralDict +integralDict TypeInt128 = IntegralDict +integralDict TypeWord8 = IntegralDict +integralDict TypeWord16 = IntegralDict +integralDict TypeWord32 = IntegralDict +integralDict TypeWord64 = IntegralDict +integralDict TypeWord128 = IntegralDict + +{-# INLINE floatingDict #-} +floatingDict :: SingleFloatingType t -> FloatingDict t +floatingDict TypeFloat16 = FloatingDict +floatingDict TypeFloat32 = FloatingDict +floatingDict TypeFloat64 = FloatingDict +floatingDict TypeFloat128 = FloatingDict diff --git a/src/Data/Array/Accelerate/Interpreter/Arithmetic.hs b/src/Data/Array/Accelerate/Interpreter/Arithmetic.hs new file mode 100644 index 000000000..6850a1684 --- /dev/null +++ b/src/Data/Array/Accelerate/Interpreter/Arithmetic.hs @@ -0,0 +1,591 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_HADDOCK hide #-} +-- | +-- Module : Data.Array.Accelerate.Interpreter.Arithmetic +-- Copyright : [2008..2022] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Interpreter.Arithmetic ( + + add, sub, mul, negate, abs, signum, + quot, rem, quotRem, div, mod, divMod, + band, bor, xor, complement, shiftL, shiftR, rotateL, rotateR, popCount, countLeadingZeros, countTrailingZeros, + fdiv, recip, sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh, exp, sqrt, log, pow, logBase, + truncate, round, floor, ceiling, + atan2, isNaN, isInfinite, + lt, gt, lte, gte, eq, neq, min, max, + land, lor, lnot, lall, lany, + fromIntegral, toFloating, toBool, fromBool, + +) where + +import Data.Array.Accelerate.AST +import Data.Array.Accelerate.Error +import Data.Array.Accelerate.Type + +import Data.Primitive.Bit ( BitMask(..) ) +import qualified Data.Primitive.Vec as Vec + +import Data.Primitive.Types + +import Data.Bits ( (.&.), (.|.) ) +import Data.Bool +import Data.Maybe +import Data.Type.Equality +import Formatting +import Prelude ( ($), (.) ) +import qualified Data.Bits as P +import qualified Prelude as P + +import GHC.Exts +import GHC.TypeLits +import GHC.TypeLits.Extra + + +-- Operators from Num +-- ------------------ + +add :: NumType a -> ((a, a) -> a) +add = num2 (P.+) + +sub :: NumType a -> ((a, a) -> a) +sub = num2 (P.-) + +mul :: NumType a -> ((a, a) -> a) +mul = num2 (P.*) + +negate :: NumType a -> (a -> a) +negate = num1 P.negate + +abs :: NumType a -> (a -> a) +abs = num1 P.abs + +signum :: NumType a -> (a -> a) +signum = num1 P.signum + +num1 :: (forall t. P.Num t => t -> t) -> NumType a -> (a -> a) +num1 f = \case + IntegralNumType t -> integral t + FloatingNumType t -> floating t + where + integral :: IntegralType t -> (t -> t) + integral (SingleIntegralType t) | IntegralDict <- integralDict t = f + integral (VectorIntegralType _ t) | IntegralDict <- integralDict t = map f + -- + floating :: FloatingType t -> (t -> t) + floating (SingleFloatingType t) | FloatingDict <- floatingDict t = f + floating (VectorFloatingType _ t) | FloatingDict <- floatingDict t = map f + +num2 :: (forall t. P.Num t => t -> t -> t) -> NumType a -> ((a, a) -> a) +num2 f = \case + IntegralNumType t -> integral t + FloatingNumType t -> floating t + where + integral :: IntegralType t -> ((t, t) -> t) + integral (SingleIntegralType t) | IntegralDict <- integralDict t = P.uncurry f + integral (VectorIntegralType _ t) | IntegralDict <- integralDict t = P.uncurry (zipWith f) + -- + floating :: FloatingType t -> ((t, t) -> t) + floating (SingleFloatingType t) | FloatingDict <- floatingDict t = P.uncurry f + floating (VectorFloatingType _ t) | FloatingDict <- floatingDict t = P.uncurry (zipWith f) + + +-- Operators from Integral +-- ----------------------- + +quot :: IntegralType a -> ((a, a) -> a) +quot = int2 P.quot + +rem :: IntegralType a -> ((a, a) -> a) +rem = int2 P.rem + +quotRem :: IntegralType a -> ((a, a) -> (a, a)) +quotRem = int2' P.quotRem + +div :: IntegralType a -> ((a, a) -> a) +div = int2 P.div + +mod :: IntegralType a -> ((a, a) -> a) +mod = int2 P.mod + +divMod :: IntegralType a -> ((a, a) -> (a, a)) +divMod = int2' P.divMod + +int2 :: (forall t. P.Integral t => t -> t -> t) -> IntegralType a -> ((a, a) -> a) +int2 f (SingleIntegralType t) | IntegralDict <- integralDict t = P.uncurry f +int2 f (VectorIntegralType _ t) | IntegralDict <- integralDict t = P.uncurry (zipWith f) + +int2' :: (forall t. P.Integral t => t -> t -> (t, t)) -> IntegralType a -> ((a, a) -> (a, a)) +int2' f (SingleIntegralType t) | IntegralDict <- integralDict t = P.uncurry f +int2' f (VectorIntegralType _ t) | IntegralDict <- integralDict t = P.uncurry (zipWith' f) + + +-- Operators from Bits & FiniteBits +-- -------------------------------- + +band :: IntegralType a -> ((a, a) -> a) +band = bits2 (.&.) + +bor :: IntegralType a -> ((a, a) -> a) +bor = bits2 (.|.) + +xor :: IntegralType a -> ((a, a) -> a) +xor = bits2 P.xor + +complement :: IntegralType a -> (a -> a) +complement = bits1 P.complement + +shiftL :: IntegralType a -> ((a, a) -> a) +shiftL = bits2 (\x i -> x `P.shiftL` P.fromIntegral i) + +shiftR :: IntegralType a -> ((a, a) -> a) +shiftR = bits2 (\x i -> x `P.shiftR` P.fromIntegral i) + +rotateL :: IntegralType a -> ((a, a) -> a) +rotateL = bits2 (\x i -> x `P.rotateL` P.fromIntegral i) + +rotateR :: IntegralType a -> ((a, a) -> a) +rotateR = bits2 (\x i -> x `P.rotateR` P.fromIntegral i) + +popCount :: IntegralType a -> (a -> a) +popCount = bits1 (P.fromIntegral . P.popCount) + +countLeadingZeros :: IntegralType a -> (a -> a) +countLeadingZeros = bits1 (P.fromIntegral . P.countLeadingZeros) + +countTrailingZeros :: IntegralType a -> (a -> a) +countTrailingZeros = bits1 (P.fromIntegral . P.countTrailingZeros) + +bits1 :: (forall t. (P.Integral t, P.FiniteBits t) => t -> t) -> IntegralType a -> (a -> a) +bits1 f (SingleIntegralType t) | IntegralDict <- integralDict t = f +bits1 f (VectorIntegralType _ t) | IntegralDict <- integralDict t = map f + +bits2 :: (forall t. (P.Integral t, P.FiniteBits t) => t -> t -> t) -> IntegralType a -> ((a, a) -> a) +bits2 f (SingleIntegralType t) | IntegralDict <- integralDict t = P.uncurry f +bits2 f (VectorIntegralType _ t) | IntegralDict <- integralDict t = P.uncurry (zipWith f) + + +-- Operators from Fractional and Floating +-- -------------------------------------- + +fdiv :: FloatingType a -> ((a, a) -> a) +fdiv = float2 (P./) + +recip :: FloatingType a -> (a -> a) +recip = float1 P.recip + +sin :: FloatingType a -> (a -> a) +sin = float1 P.sin + +cos :: FloatingType a -> (a -> a) +cos = float1 P.cos + +tan :: FloatingType a -> (a -> a) +tan = float1 P.tan + +asin :: FloatingType a -> (a -> a) +asin = float1 P.asin + +acos :: FloatingType a -> (a -> a) +acos = float1 P.acos + +atan :: FloatingType a -> (a -> a) +atan = float1 P.atan + +sinh :: FloatingType a -> (a -> a) +sinh = float1 P.sinh + +cosh :: FloatingType a -> (a -> a) +cosh = float1 P.cosh + +tanh :: FloatingType a -> (a -> a) +tanh = float1 P.tanh + +asinh :: FloatingType a -> (a -> a) +asinh = float1 P.asinh + +acosh :: FloatingType a -> (a -> a) +acosh = float1 P.acosh + +atanh :: FloatingType a -> (a -> a) +atanh = float1 P.atanh + +exp :: FloatingType a -> (a -> a) +exp = float1 P.exp + +sqrt :: FloatingType a -> (a -> a) +sqrt = float1 P.sqrt + +log :: FloatingType a -> (a -> a) +log = float1 P.log + +pow :: FloatingType a -> ((a, a) -> a) +pow = float2 (P.**) + +logBase :: FloatingType a -> ((a, a) -> a) +logBase = float2 P.logBase + + +float1 :: (forall t. P.RealFloat t => t -> t) -> FloatingType a -> (a -> a) +float1 f (SingleFloatingType t) | FloatingDict <- floatingDict t = f +float1 f (VectorFloatingType _ t) | FloatingDict <- floatingDict t = map f + +float2 :: (forall t. P.RealFloat t => t -> t -> t) -> FloatingType a -> ((a, a) -> a) +float2 f (SingleFloatingType t) | FloatingDict <- floatingDict t = P.uncurry f +float2 f (VectorFloatingType _ t) | FloatingDict <- floatingDict t = P.uncurry (zipWith f) + + +-- Operators from RealFrac +-- ----------------------- + +truncate :: FloatingType a -> IntegralType b -> (a -> b) +truncate (SingleFloatingType a) (SingleIntegralType b) + | FloatingDict <- floatingDict a + , IntegralDict <- integralDict b + = P.truncate +truncate (VectorFloatingType n a) (VectorIntegralType m b) + | Just Refl <- sameNat' n m + , FloatingDict <- floatingDict a + , IntegralDict <- integralDict b + = map P.truncate +truncate a b + = internalError ("truncate: cannot reconcile `" % formatFloatingType % "' with `" % formatIntegralType % "'") a b + +round :: FloatingType a -> IntegralType b -> (a -> b) +round (SingleFloatingType a) (SingleIntegralType b) + | FloatingDict <- floatingDict a + , IntegralDict <- integralDict b + = P.round +round (VectorFloatingType n a) (VectorIntegralType m b) + | Just Refl <- sameNat' n m + , FloatingDict <- floatingDict a + , IntegralDict <- integralDict b + = map P.round +round a b + = internalError ("round: cannot reconcile `" % formatFloatingType % "' with `" % formatIntegralType % "'") a b + +floor :: FloatingType a -> IntegralType b -> (a -> b) +floor (SingleFloatingType a) (SingleIntegralType b) + | FloatingDict <- floatingDict a + , IntegralDict <- integralDict b + = P.floor +floor (VectorFloatingType n a) (VectorIntegralType m b) + | Just Refl <- sameNat' n m + , FloatingDict <- floatingDict a + , IntegralDict <- integralDict b + = map P.floor +floor a b + = internalError ("floor: cannot reconcile `" % formatFloatingType % "' with `" % formatIntegralType % "'") a b + +ceiling :: FloatingType a -> IntegralType b -> (a -> b) +ceiling (SingleFloatingType a) (SingleIntegralType b) + | FloatingDict <- floatingDict a + , IntegralDict <- integralDict b + = P.ceiling +ceiling (VectorFloatingType n a) (VectorIntegralType m b) + | Just Refl <- sameNat' n m + , FloatingDict <- floatingDict a + , IntegralDict <- integralDict b + = map P.ceiling +ceiling a b + = internalError ("ceiling: cannot reconcile `" % formatFloatingType % "' with `" % formatIntegralType % "'") a b + + +-- Operators from RealFloat +-- ------------------------ + +atan2 :: FloatingType a -> ((a, a) -> a) +atan2 = float2 P.atan2 + +isNaN :: FloatingType a -> (a -> BitOrMask a) +isNaN = \case + SingleFloatingType t | FloatingDict <- floatingDict t -> isNaN' + VectorFloatingType _ t | FloatingDict <- floatingDict t -> unMask . map isNaN' + where + isNaN' x = Bit (P.isNaN x) + +isInfinite :: FloatingType a -> (a -> BitOrMask a) +isInfinite = \case + SingleFloatingType t | FloatingDict <- floatingDict t -> isInfinite' + VectorFloatingType _ t | FloatingDict <- floatingDict t -> unMask . map isInfinite' + where + isInfinite' x = Bit (P.isInfinite x) + + +-- Operators from Eq & Ord +-- ----------------------- + +lt :: ScalarType a -> ((a, a) -> BitOrMask a) +lt = cmp (P.<) + +gt :: ScalarType a -> ((a, a) -> BitOrMask a) +gt = cmp (P.>) + +lte :: ScalarType a -> ((a, a) -> BitOrMask a) +lte = cmp (P.<=) + +gte :: ScalarType a -> ((a, a) -> BitOrMask a) +gte = cmp (P.>=) + +eq :: ScalarType a -> ((a, a) -> BitOrMask a) +eq = cmp (P.==) + +neq :: ScalarType a -> ((a, a) -> BitOrMask a) +neq = cmp (P./=) + +cmp :: (forall t. P.Ord t => (t -> t -> Bool)) -> ScalarType a -> ((a, a) -> BitOrMask a) +cmp f = \case + NumScalarType t -> num t + BitScalarType t -> bit t + where + bit :: BitType t -> ((t, t) -> BitOrMask t) + bit TypeBit = Bit . P.uncurry f + bit TypeMask{} = \(x,y) -> unMask (zipWith (Bit $$ f) (BitMask x) (BitMask y)) + + num :: NumType t -> ((t, t) -> BitOrMask t) + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType t -> ((t, t) -> BitOrMask t) + integral (SingleIntegralType t) | IntegralDict <- integralDict t = P.uncurry (Bit $$ f) + integral (VectorIntegralType _ t) | IntegralDict <- integralDict t = P.uncurry (unMask $$ zipWith (Bit $$ f)) + + floating :: FloatingType t -> ((t, t) -> BitOrMask t) + floating (SingleFloatingType t) | FloatingDict <- floatingDict t = P.uncurry (Bit $$ f) + floating (VectorFloatingType _ t) | FloatingDict <- floatingDict t = P.uncurry (unMask $$ zipWith (Bit $$ f)) + +min :: ScalarType a -> ((a, a) -> a) +min = \case + NumScalarType t -> num t + BitScalarType t -> bit t + where + bit :: BitType t -> ((t, t) -> t) + bit TypeBit = P.uncurry P.min + bit TypeMask{} = \(x,y) -> unMask (zipWith P.min (BitMask x) (BitMask y)) + + num :: NumType t -> ((t, t) -> t) + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType t -> ((t, t) -> t) + integral (SingleIntegralType t) | IntegralDict <- integralDict t = P.uncurry P.min + integral (VectorIntegralType _ t) | IntegralDict <- integralDict t = P.uncurry (zipWith P.min) + + floating :: FloatingType t -> ((t, t) -> t) + floating (SingleFloatingType t) | FloatingDict <- floatingDict t = P.uncurry P.min + floating (VectorFloatingType _ t) | FloatingDict <- floatingDict t = P.uncurry (zipWith P.min) + +max :: ScalarType a -> ((a, a) -> a) +max = \case + NumScalarType t -> num t + BitScalarType t -> bit t + where + bit :: BitType t -> ((t, t) -> t) + bit TypeBit = P.uncurry P.max + bit TypeMask{} = \(x,y) -> unMask (zipWith P.max (BitMask x) (BitMask y)) + + num :: NumType t -> ((t, t) -> t) + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType t -> ((t, t) -> t) + integral (SingleIntegralType t) | IntegralDict <- integralDict t = P.uncurry P.max + integral (VectorIntegralType _ t) | IntegralDict <- integralDict t = P.uncurry (zipWith P.max) + + floating :: FloatingType t -> ((t, t) -> t) + floating (SingleFloatingType t) | FloatingDict <- floatingDict t = P.uncurry P.max + floating (VectorFloatingType _ t) | FloatingDict <- floatingDict t = P.uncurry (zipWith P.max) + + +-- Logical operators +-- ----------------- + +land :: BitType a -> ((a, a) -> a) +land = \case + TypeBit -> P.uncurry land' + TypeMask{} -> \(x,y) -> unMask (zipWith land' (BitMask x) (BitMask y)) + where + land' (Bit x) (Bit y) = Bit (x && y) + +lor :: BitType a -> ((a, a) -> a) +lor = \case + TypeBit -> P.uncurry lor' + TypeMask{} -> \(x,y) -> unMask (zipWith lor' (BitMask x) (BitMask y)) + where + lor' (Bit x) (Bit y) = Bit (x || y) + +lnot :: BitType a -> (a -> a) +lnot = \case + TypeBit -> not' + TypeMask{} -> unMask . map not' . BitMask + where + not' (Bit x) = Bit (not x) + +lall :: BitType a -> a -> Bool +lall TypeBit x = unBit x +lall TypeMask{} x = P.all unBit (toList (BitMask x)) + +lany :: BitType a -> a -> Bool +lany TypeBit x = unBit x +lany TypeMask{} x = P.any unBit (toList (BitMask x)) + + +-- Conversion +-- ---------- + +fromIntegral :: forall a b. IntegralType a -> NumType b -> (a -> b) +fromIntegral (SingleIntegralType a) (IntegralNumType (SingleIntegralType b)) + | IntegralDict <- integralDict a + , IntegralDict <- integralDict b + = P.fromIntegral +fromIntegral (SingleIntegralType a) (FloatingNumType (SingleFloatingType b)) + | IntegralDict <- integralDict a + , FloatingDict <- floatingDict b + = P.fromIntegral +fromIntegral (VectorIntegralType n a) (IntegralNumType (VectorIntegralType m b)) + | Just Refl <- sameNat' n m + , IntegralDict <- integralDict a + , IntegralDict <- integralDict b + = map P.fromIntegral +fromIntegral (VectorIntegralType n a) (FloatingNumType (VectorFloatingType m b)) + | Just Refl <- sameNat' n m + , IntegralDict <- integralDict a + , FloatingDict <- floatingDict b + = map P.fromIntegral +fromIntegral a b + = internalError ("fromIntegral: cannot reconcile `" % formatIntegralType % "' with `" % formatNumType % "'") a b + + +toFloating :: forall a b. NumType a -> FloatingType b -> (a -> b) +toFloating (IntegralNumType (SingleIntegralType a)) (SingleFloatingType b) + | IntegralDict <- integralDict a + , FloatingDict <- floatingDict b + = P.realToFrac +toFloating (FloatingNumType (SingleFloatingType a)) (SingleFloatingType b) + | FloatingDict <- floatingDict a + , FloatingDict <- floatingDict b + = P.realToFrac +toFloating (IntegralNumType (VectorIntegralType n a)) (VectorFloatingType m b) + | Just Refl <- sameNat' n m + , IntegralDict <- integralDict a + , FloatingDict <- floatingDict b + = map P.realToFrac +toFloating (FloatingNumType (VectorFloatingType n a)) (VectorFloatingType m b) + | Just Refl <- sameNat' n m + , FloatingDict <- floatingDict a + , FloatingDict <- floatingDict b + = map P.realToFrac +toFloating a b + = internalError ("toFloating: cannot reconcile `" % formatNumType % "' with `" % formatFloatingType % "'") a b + + +toBool :: IntegralType a -> BitType b -> (a -> b) +toBool iR bR = + case iR of + SingleIntegralType t | IntegralDict <- integralDict t -> + case bR of + TypeBit -> P.fromIntegral + TypeMask n -> \x -> let m = P.finiteBitSize x + bits = P.map (Bit . P.testBit x) [0 .. m P.- 1] P.++ P.repeat (Bit False) + in + unMask $ fromList (P.reverse (P.take (P.fromIntegral (natVal' n)) bits)) + -- + VectorIntegralType _ t | IntegralDict <- integralDict t -> + case bR of + TypeBit -> \x -> P.fromIntegral (Vec.extract x 0) -- XXX: first or last lane? + TypeMask _ -> unMask . map P.fromIntegral + +fromBool :: forall a b. BitType a -> IntegralType b -> (a -> b) +fromBool bR iR = + case bR of + TypeBit -> + case iR of + SingleIntegralType t | IntegralDict <- integralDict t -> P.fromIntegral + VectorIntegralType _ t | IntegralDict <- integralDict t -> \x -> + if unBit x + then Vec.insert (Vec.splat 0) 0 1 -- XXX: first or last lane? + else Vec.splat 0 + -- + TypeMask _ -> + case iR of + SingleIntegralType t | IntegralDict <- integralDict t -> \x -> + let bits = toList (BitMask x) + m = P.finiteBitSize (P.undefined :: b) + -- + go !_ !w [] = w + go !i !w (Bit True : bs) = go (i P.+ 1) (P.setBit w i) bs + go !i !w (Bit False: bs) = go (i P.+ 1) w bs + in + go 0 0 (P.reverse (P.take m bits)) + -- + VectorIntegralType _ t | IntegralDict <- integralDict t -> map P.fromIntegral . BitMask + + +-- Vector element-wise operations +-- ------------------------------ + +-- XXX: These implementations lose the type safety that the length of the +-- underlying Vec or BitMask is preserved + +map :: (IsList a, IsList b) => (Item a -> Item b) -> a -> b +map f xs = fromList $ P.map f (toList xs) + +zipWith :: (IsList a, IsList b, IsList c) => (Item a -> Item b -> Item c) -> a -> b -> c +zipWith f xs ys = fromList $ P.zipWith f (toList xs) (toList ys) + +zipWith' + :: (IsList a, IsList b, IsList c, IsList d) + => (Item a -> Item b -> (Item c, Item d)) + -> a + -> b + -> (c, d) +zipWith' f xs ys = + let (us, vs) = P.unzip $ P.zipWith f (toList xs) (toList ys) + in (fromList us, fromList vs) + + +-- Utilities +-- --------- + +data IntegralDict t where + IntegralDict :: (P.Integral t, P.FiniteBits t, Prim t, BitOrMask t ~ Bit) => IntegralDict t + +data FloatingDict t where + FloatingDict :: (P.RealFloat t, Prim t, BitOrMask t ~ Bit) => FloatingDict t + +{-# INLINE integralDict #-} +integralDict :: SingleIntegralType t -> IntegralDict t +integralDict TypeInt8 = IntegralDict +integralDict TypeInt16 = IntegralDict +integralDict TypeInt32 = IntegralDict +integralDict TypeInt64 = IntegralDict +integralDict TypeInt128 = IntegralDict +integralDict TypeWord8 = IntegralDict +integralDict TypeWord16 = IntegralDict +integralDict TypeWord32 = IntegralDict +integralDict TypeWord64 = IntegralDict +integralDict TypeWord128 = IntegralDict + +{-# INLINE floatingDict #-} +floatingDict :: SingleFloatingType t -> FloatingDict t +floatingDict TypeFloat16 = FloatingDict +floatingDict TypeFloat32 = FloatingDict +floatingDict TypeFloat64 = FloatingDict +floatingDict TypeFloat128 = FloatingDict + +infixr 0 $$ +($$) :: (b -> a) -> (c -> d -> b) -> c -> d -> a +(f $$ g) x y = f (g x y) + diff --git a/src/Data/Array/Accelerate/Language.hs b/src/Data/Array/Accelerate/Language.hs index 727e9f7b9..da65e597d 100644 --- a/src/Data/Array/Accelerate/Language.hs +++ b/src/Data/Array/Accelerate/Language.hs @@ -38,18 +38,6 @@ module Data.Array.Accelerate.Language ( -- * Map-like functions map, zipWith, - -- -- * Sequence collection - -- collect, - - -- -- * Sequence producers - -- streamIn, toSeq, - - -- -- * Sequence transducers - -- mapSeq, zipWithSeq, scanSeq, - - -- -- * Sequence consumers - -- foldSeq, foldSeqFlatten, - -- * Reductions fold, fold1, foldSeg', fold1Seg', @@ -197,7 +185,7 @@ unit (Exp e) = Acc $ SmartAcc $ Unit (eltR @e) e -- -- ...or as columns: -- --- >>> run $ replicate (lift (Z :. All :. (4::Int))) (use vec) +-- >>> run $ replicate (Z_ ::. All_ ::. (4 :: Exp Int)) (use vec) -- Matrix (Z :. 10 :. 4) -- [ 0, 0, 0, 0, -- 1, 1, 1, 1, @@ -222,7 +210,7 @@ unit (Exp e) = Acc $ SmartAcc $ Unit (eltR @e) e -- -- >>> :{ -- let rep0 :: (Shape sh, Elt e) => Exp Int -> Acc (Array sh e) -> Acc (Array (sh :. Int) e) --- rep0 n a = replicate (lift (Any :. n)) a +-- rep0 n a = replicate (Any_ ::. n) a -- :} -- -- >>> let x = unit 42 :: Acc (Scalar Int) @@ -246,7 +234,7 @@ unit (Exp e) = Acc $ SmartAcc $ Unit (eltR @e) e -- -- >>> :{ -- let rep1 :: (Shape sh, Elt e) => Exp Int -> Acc (Array (sh :. Int) e) -> Acc (Array (sh :. Int :. Int) e) --- rep1 n a = replicate (lift (Any :. n :. All)) a +-- rep1 n a = replicate (Any_ ::. n ::. All_) a -- :} -- -- >>> run $ rep1 5 (use vec) @@ -360,7 +348,7 @@ reshape = Acc $$ applyAcc (Reshape $ shapeR @sh) -- >>> :{ -- let -- sl0 :: (Shape sh, Elt e) => Acc (Array (sh:.Int) e) -> Exp Int -> Acc (Array sh e) --- sl0 a n = slice a (lift (Any :. n)) +-- sl0 a n = slice a (Any_ ::. n) -- :} -- -- >>> let vec = fromList (Z:.10) [0..] :: Vector Int @@ -374,7 +362,7 @@ reshape = Acc $$ applyAcc (Reshape $ shapeR @sh) -- -- >>> :{ -- let sl1 :: (Shape sh, Elt e) => Acc (Array (sh:.Int:.Int) e) -> Exp Int -> Acc (Array (sh:.Int) e) --- sl1 a n = slice a (lift (Any :. n :. All)) +-- sl1 a n = slice a (Any_ ::. n ::. All_) -- :} -- -- >>> run $ sl1 (use mat) 4 @@ -515,8 +503,7 @@ zipWith = Acc $$$ applyAcc (ZipWith (eltR @a) (eltR @b) (eltR @c)) -- See also 'Data.Array.Accelerate.Data.Fold.Fold', which can be a useful way to -- compute multiple results from a single reduction. -- -fold :: forall sh a. - (Shape sh, Elt a) +fold :: forall sh a. (Shape sh, Elt a) => (Exp a -> Exp a -> Exp a) -> Exp a -> Acc (Array (sh:.Int) a) @@ -533,8 +520,7 @@ fold f (Exp x) = Acc . applyAcc (Fold (eltR @a) (unExpBinaryFunction f) (Just x) -- The first argument needs to be an /associative/ function to enable an -- efficient parallel implementation, but does not need to be commutative. -- -fold1 :: forall sh a. - (Shape sh, Elt a) +fold1 :: forall sh a. (Shape sh, Elt a) => (Exp a -> Exp a -> Exp a) -> Acc (Array (sh:.Int) a) -> Acc (Array sh a) @@ -554,14 +540,13 @@ fold1 f = Acc . applyAcc (Fold (eltR @a) (unExpBinaryFunction f) Nothing) -- @since 1.3.0.0 -- foldSeg' - :: forall sh a i. - (Shape sh, Elt a, Elt i, IsIntegral i, i ~ EltR i) + :: forall sh a i. (Shape sh, Elt a, Elt i, IsSingleIntegral (EltR i)) => (Exp a -> Exp a -> Exp a) -> Exp a -> Acc (Array (sh:.Int) a) -> Acc (Segments i) -> Acc (Array (sh:.Int) a) -foldSeg' f (Exp x) = Acc $$ applyAcc (FoldSeg (integralType @i) (eltR @a) (unExpBinaryFunction f) (Just x)) +foldSeg' f (Exp x) = Acc $$ applyAcc (FoldSeg (singleIntegralType @(EltR i)) (eltR @a) (unExpBinaryFunction f) (Just x)) -- | Variant of 'foldSeg'' that requires /all/ segments of the reduced -- array to be non-empty, and doesn't need a default value. The segment @@ -571,13 +556,12 @@ foldSeg' f (Exp x) = Acc $$ applyAcc (FoldSeg (integralType @i) (eltR @a) (unExp -- @since 1.3.0.0 -- fold1Seg' - :: forall sh a i. - (Shape sh, Elt a, Elt i, IsIntegral i, i ~ EltR i) + :: forall sh a i. (Shape sh, Elt a, Elt i, IsSingleIntegral (EltR i)) => (Exp a -> Exp a -> Exp a) -> Acc (Array (sh:.Int) a) -> Acc (Segments i) -> Acc (Array (sh:.Int) a) -fold1Seg' f = Acc $$ applyAcc (FoldSeg (integralType @i) (eltR @a) (unExpBinaryFunction f) Nothing) +fold1Seg' f = Acc $$ applyAcc (FoldSeg (singleIntegralType @(EltR i)) (eltR @a) (unExpBinaryFunction f) Nothing) -- Scan functions -- -------------- @@ -1262,11 +1246,7 @@ awhile :: forall a. Arrays a -> (Acc a -> Acc a) -- ^ function to apply -> Acc a -- ^ initial value -> Acc a -awhile f = Acc $$ applyAcc $ Awhile (arraysR @a) (unAccFunction g) - where - -- FIXME: This should be a no-op! - g :: Acc a -> Acc (Scalar PrimBool) - g = map mkCoerce . f +awhile f = Acc $$ applyAcc $ Awhile (arraysR @a) (unAccFunction f) -- Shapes and indices @@ -1297,7 +1277,7 @@ intersect (Exp shx) (Exp shy) = Exp $ intersect' (shapeR @sh) shx shy intersect' (ShapeRsnoc shR) (unPair -> (xs, x)) (unPair -> (ys, y)) = SmartExp $ intersect' shR xs ys `Pair` - SmartExp (PrimApp (PrimMin singleType) $ SmartExp $ Pair x y) + SmartExp (PrimApp (PrimMin scalarType) $ SmartExp $ Pair x y) -- | Union of two shapes @@ -1310,7 +1290,7 @@ union (Exp shx) (Exp shy) = Exp $ union' (shapeR @sh) shx shy union' (ShapeRsnoc shR) (unPair -> (xs, x)) (unPair -> (ys, y)) = SmartExp $ union' shR xs ys `Pair` - SmartExp (PrimApp (PrimMax singleType) $ SmartExp $ Pair x y) + SmartExp (PrimApp (PrimMax scalarType) $ SmartExp $ Pair x y) -- Flow-control @@ -1339,7 +1319,8 @@ while :: forall e. Elt e while c f (Exp e) = mkExp $ While @(EltR e) (eltR @e) (mkCoerce' . unExp . c . Exp) - (unExp . f . Exp) e + (unExp . f . Exp) + e -- Array operations with a scalar result @@ -1499,7 +1480,7 @@ chr = mkFromIntegral -- into '1'. -- boolToInt :: Exp Bool -> Exp Int -boolToInt = mkFromIntegral . mkCoerce @_ @Word8 +boolToInt = mkFromBool -- |Reinterpret a value as another type. The two representations must have the -- same bit size. diff --git a/src/Data/Array/Accelerate/Lift.hs b/src/Data/Array/Accelerate/Lift.hs index 27482a86f..d0a7093f7 100644 --- a/src/Data/Array/Accelerate/Lift.hs +++ b/src/Data/Array/Accelerate/Lift.hs @@ -43,6 +43,7 @@ import Data.Array.Accelerate.Sugar.Shape import Data.Array.Accelerate.Type import Language.Haskell.TH.Extra hiding ( Exp ) +import Foreign.C.Types -- | Lift a unary function into 'Exp'. @@ -276,8 +277,9 @@ instance Lift Exp CDouble where instance Lift Exp Bool where type Plain Bool = Bool - lift True = Exp . SmartExp $ SmartExp (Const scalarType 1) `Pair` SmartExp Nil - lift False = Exp . SmartExp $ SmartExp (Const scalarType 0) `Pair` SmartExp Nil + lift = Exp . SmartExp . Const scalarType . Bit + -- lift True = Exp . SmartExp $ SmartExp (Const scalarType 1) `Pair` SmartExp Nil + -- lift False = Exp . SmartExp $ SmartExp (Const scalarType 0) `Pair` SmartExp Nil instance Lift Exp Char where type Plain Char = Char diff --git a/src/Data/Array/Accelerate/Orphans.hs b/src/Data/Array/Accelerate/Orphans.hs index ea570dd0d..0388accb5 100644 --- a/src/Data/Array/Accelerate/Orphans.hs +++ b/src/Data/Array/Accelerate/Orphans.hs @@ -28,7 +28,6 @@ import Numeric.Half -- base --- deriving instance (Show a, Show b, Show c, Show d, Show e, Show f, Show g, Show h, Show i, Show j, Show k, Show l, Show m, Show n, Show o, Show p) => Show (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p) @@ -47,6 +46,5 @@ deriving instance Generic (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p) deriving instance Generic (Ratio a) -- primitive --- deriving instance Prim Half diff --git a/src/Data/Array/Accelerate/Pattern.hs b/src/Data/Array/Accelerate/Pattern.hs index e212c0869..4de7d22a3 100644 --- a/src/Data/Array/Accelerate/Pattern.hs +++ b/src/Data/Array/Accelerate/Pattern.hs @@ -34,20 +34,19 @@ module Data.Array.Accelerate.Pattern ( pattern I0, pattern I1, pattern I2, pattern I3, pattern I4, pattern I5, pattern I6, pattern I7, pattern I8, pattern I9, + pattern SIMD, pattern V2, pattern V3, pattern V4, pattern V8, pattern V16, ) where import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.Representation.Tag -import Data.Array.Accelerate.Representation.Vec import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Array import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Shape import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type -import Data.Primitive.Vec import Language.Haskell.TH.Extra hiding ( Exp, Match ) @@ -58,19 +57,21 @@ import Language.Haskell.TH.Extra hiding ( Exp pattern Pattern :: forall b a context. IsPattern context a b => b -> context a pattern Pattern vars <- (matcher @context -> vars) where Pattern = builder @context +{-# COMPLETE Pattern :: Exp #-} +{-# COMPLETE Pattern :: Acc #-} class IsPattern context a b where builder :: b -> context a matcher :: context a -> b -pattern Vector :: forall b a context. IsVector context a b => b -> context a -pattern Vector vars <- (vunpack @context -> vars) - where Vector = vpack @context +pattern SIMD :: forall b a context. IsSIMD context a b => b -> context a +pattern SIMD vars <- (vmatcher @context -> vars) + where SIMD = vbuilder @context -class IsVector context a b where - vpack :: b -> context a - vunpack :: context a -> b +class IsSIMD context a b where + vbuilder :: b -> context a + vmatcher :: context a -> b -- | Pattern synonyms for indices, which may be more convenient to use than -- 'Data.Array.Accelerate.Lift.lift' and @@ -120,7 +121,8 @@ runQ $ do -- instance (Arrays x, ArraysR x ~ (((), ArraysR a), ArraysR b), Arrays a, Arrays b,) => IsPattern Acc x (Acc a, Acc b) mkAccPattern :: Int -> Q [Dec] mkAccPattern n = do - a <- newName "a" + a <- newName "a" + _x <- newName "_x" let -- Type variables for the elements xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] @@ -138,7 +140,6 @@ runQ $ do get x 0 = [| Acc (SmartAcc (Aprj PairIdxRight $x)) |] get x i = get [| SmartAcc (Aprj PairIdxLeft $x) |] (i-1) -- - _x <- newName "_x" [d| instance $context => IsPattern Acc $(varT a) $b where builder $(tupP (map (\x -> [p| Acc $(varP x)|]) xs)) = Acc $(foldl (\vs v -> [| SmartAcc ($vs `Apair` $(varE v)) |]) [| SmartAcc Anil |] xs) @@ -150,7 +151,9 @@ runQ $ do -- instance (Elt x, EltR x ~ (((), EltR a), EltR b), Elt a, Elt b,) => IsPattern Exp x (Exp a, Exp b) mkExpPattern :: Int -> Q [Dec] mkExpPattern n = do - a <- newName "a" + a <- newName "a" + _x <- newName "_x" + _y <- newName "_y" let -- Type variables for the elements xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] @@ -171,8 +174,6 @@ runQ $ do get x 0 = [| SmartExp (Prj PairIdxRight $x) |] get x i = get [| SmartExp (Prj PairIdxLeft $x) |] (i-1) -- - _x <- newName "_x" - _y <- newName "_y" [d| instance $context => IsPattern Exp $(varT a) $b where builder $(tupP (map (\x -> [p| Exp $(varP x)|]) xs)) = let _unmatch :: SmartExp a -> SmartExp a @@ -187,6 +188,71 @@ runQ $ do _ -> $(tupE [[| Exp $(get (varE _x) i) |] | i <- [(n-1), (n-2) .. 0]]) |] + -- Generate instance declarations for IsSIMD of the form: + -- instance (Elt a, Elt v, EltR v ~ VecR n a) => IsSIMD Exp v (Exp a, Exp a) + mkVecPattern :: Int -> Q [Dec] + mkVecPattern n = do + a <- newName "a" + v <- newName "v" + _x <- newName "_x" + _y <- newName "_y" + let + aT = varT a + vT = varT v + nT = litT (numTyLit (toInteger n)) + -- Last argument to `IsSIMD`, eg (Exp, a, Exp a) in the example + tup = tupT (replicate n ([t| Exp $aT |])) + -- Constraints for the type class, consisting of the Elt + -- constraints and the equality on representation types + context = [t| (Elt $aT, Elt $vT, SIMD $nT $aT, EltR $vT ~ VecR $nT $aT) |] + -- Type variables for the elements + xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + -- Variables for sub-pattern matches + -- ms = [ mkName ('m' : show i) | i <- [0 .. n-1] ] + -- tags = foldl (\ts t -> [p| $ts `TagRpair` $(varP t) |]) [p| TagRunit |] ms + -- + [d| instance $context => IsSIMD Exp $vT $tup where + vbuilder $(tupP (map (\x -> [p| Exp $(varP x)|]) xs)) = + let _unmatch :: SmartExp a -> SmartExp a + _unmatch (SmartExp (Match _ $(varP _y))) = $(varE _y) + _unmatch x = x + in + Exp $(foldl (\vs (i, x) -> [| mkInsert + $(varE 'vecR `appTypeE` nT `appTypeE` aT) + $(varE 'eltR `appTypeE` aT) + TypeWord8 + $vs + (SmartExp (Const (NumScalarType (IntegralNumType (SingleIntegralType TypeWord8))) i)) + (_unmatch $(varE x)) + |]) + [| unExp (undef :: Exp (Vec $nT $aT)) |] + (zip [0 .. n-1] xs) + ) + + vmatcher (Exp $(varP _x)) = + case $(varE _x) of + -- SmartExp (Match $tags $(varP _y)) + -- -> $(tupE [[| Exp (SmartExp (Match $(varE m) (unExp (extract (Exp $(varE _x) :: Exp $vec) (constant (i :: Word8)))))) |] | m <- ms | i <- [0 .. n-1]]) + -- -> $(tupE [[| Exp (SmartExp (Match $(varE m) (mkExtract + -- $(varE 'vecR `appTypeE` nT `appTypeE` aT) + -- $(varE 'eltR `appTypeE` aT) + -- TypeWord8 + -- $(varE _x) + -- (SmartExp (Const (NumScalarType (IntegralNumType (SingleIntegralType TypeWord8))) i))))) |] + -- | m <- ms + -- | i <- [0 .. n-1] ]) + + _ -> $(tupE [[| Exp $ mkExtract + $(varE 'vecR `appTypeE` nT `appTypeE` aT) + $(varE 'eltR `appTypeE` aT) + TypeWord8 + $(varE _x) + (SmartExp (Const (NumScalarType (IntegralNumType (SingleIntegralType TypeWord8))) i)) + |] + | i <- [0 .. n-1] ]) + |] + +{-- -- Generate instance declarations for IsVector of the form: -- instance (Elt v, EltR v ~ Vec 2 a, Elt a) => IsVector Exp v (Exp a, Exp a) mkVecPattern :: Int -> Q [Dec] @@ -210,11 +276,12 @@ runQ $ do Exp x' -> Exp (SmartExp (VecPack $vecR x')) vunpack (Exp x) = matcher (Exp (SmartExp (VecUnpack $vecR x)) :: Exp $tR) |] +--} -- es <- mapM mkExpPattern [0..16] as <- mapM mkAccPattern [0..16] vs <- mapM mkVecPattern [2,3,4,8,16] - return $ concat (es ++ as ++ vs) + return $ concat (es ++ as++ vs) -- | Specialised pattern synonyms for tuples, which may be more convenient to @@ -256,6 +323,23 @@ runQ $ do , pragCompleteD [name] (Just ''Exp) ] + mkV :: Int -> Q [Dec] + mkV n = + let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + a = varT (mkName "a") + ts = replicate n a + name = mkName ('V':show n) + tup = tupT (map (\t -> [t| Exp $t |]) ts) + vec = [t| Vec $(litT (numTyLit (toInteger n))) $a |] + cst = [t| (Elt $a, SIMD $(litT (numTyLit (toInteger n))) $a, IsSIMD Exp $vec $tup) |] + sig = foldr (\t r -> [t| Exp $t -> $r |]) [t| Exp $vec |] ts + in + sequence + [ patSynSigD name [t| $cst => $sig |] + , patSynD name (prefixPatSyn xs) implBidir [p| SIMD $(tupP (map varP xs)) |] + , pragCompleteD [name] Nothing + ] + mkI :: Int -> Q [Dec] mkI n = let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] @@ -272,6 +356,7 @@ runQ $ do , pragCompleteD [name] Nothing ] +{-- mkV :: Int -> Q [Dec] mkV n = let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] @@ -287,6 +372,7 @@ runQ $ do , patSynD name (prefixPatSyn xs) implBidir [p| Vector $(tupP (map varP xs)) |] , pragCompleteD [name] (Just ''Exp) ] +--} -- ts <- mapM mkT [2..16] is <- mapM mkI [0..9] diff --git a/src/Data/Array/Accelerate/Pattern/Bool.hs b/src/Data/Array/Accelerate/Pattern/Bool.hs index d968aaf34..093caae60 100644 --- a/src/Data/Array/Accelerate/Pattern/Bool.hs +++ b/src/Data/Array/Accelerate/Pattern/Bool.hs @@ -20,7 +20,62 @@ module Data.Array.Accelerate.Pattern.Bool ( ) where -import Data.Array.Accelerate.Pattern.TH +import Data.Array.Accelerate.Representation.Tag +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Type -mkPattern ''Bool +import GHC.Stack + + +{-# COMPLETE False_, True_ #-} +pattern False_ :: HasCallStack => Exp Bool +pattern False_ <- (matchFalse -> Just ()) + where False_ = buildFalse + +pattern True_ :: HasCallStack => Exp Bool +pattern True_ <- (matchTrue -> Just ()) + where True_ = buildTrue + + +buildFalse :: Exp Bool +buildFalse = mkExp $ Const scalarType 0 + +matchFalse :: HasCallStack => Exp Bool -> Maybe () +matchFalse (Exp e) = + case e of + SmartExp (Match (TagRbit TypeBit 0) _) -> Just () + SmartExp Match{} -> Nothing + _ -> error $ unlines + [ "Embedded pattern synonym used outside 'match' context." + , "" + , "To use case statements in the embedded language the case statement must" + , "be applied as an n-ary function to the 'match' operator. For single" + , "argument case statements this can be done inline using LambdaCase, for" + , "example:" + , "" + , "> x & match \\case" + , "> False_ -> ..." + , "> _ -> ..." + ] + +buildTrue :: Exp Bool +buildTrue = mkExp $ Const scalarType 1 + +matchTrue :: HasCallStack => Exp Bool -> Maybe () +matchTrue (Exp e) = + case e of + SmartExp (Match (TagRbit TypeBit 1) _) -> Just () + SmartExp Match{} -> Nothing + _ -> error $ unlines + [ "Embedded pattern synonym used outside 'match' context." + , "" + , "To use case statements in the embedded language the case statement must" + , "be applied as an n-ary function to the 'match' operator. For single" + , "argument case statements this can be done inline using LambdaCase, for" + , "example:" + , "" + , "> x & match \\case" + , "> True_ -> ..." + , "> _ -> ..." + ] diff --git a/src/Data/Array/Accelerate/Pattern/TH.hs b/src/Data/Array/Accelerate/Pattern/TH.hs index 0323f8d1a..2e9a2183b 100644 --- a/src/Data/Array/Accelerate/Pattern/TH.hs +++ b/src/Data/Array/Accelerate/Pattern/TH.hs @@ -293,7 +293,7 @@ mkConS tn' tvs' prev' next' tag' con' = do ++ map varE xs ++ map (\t -> [| unExp $(varE 'undef `appTypeE` return t) |] ) (concat fs1) - tagged = [| Exp $ SmartExp $ Pair (SmartExp (Const (SingleScalarType (NumSingleType (IntegralNumType TypeWord8))) $(litE (IntegerL (toInteger tag))))) $vs |] + tagged = [| Exp $ SmartExp $ Pair (SmartExp (Const (NumScalarType (IntegralNumType (SingleIntegralType TypeWord8))) $(litE (IntegerL (toInteger tag))))) $vs |] body = clause (map (\x -> [p| (Exp $(varP x)) |]) xs) (normalB tagged) [] r <- sequence [ sigD fun sig @@ -314,7 +314,7 @@ mkConS tn' tvs' prev' next' tag' con' = do fun <- newName ("_match" ++ cn) e <- newName "_e" x <- newName "_x" - (ps,es) <- extract vs [| Prj PairIdxRight $(varE x) |] [] [] + (ps,es) <- prj vs [| Prj PairIdxRight $(varE x) |] [] [] unbind <- isExtEnabled RebindableSyntax let eqE = if unbind then letE [funD (mkName "==") [clause [] (normalB (varE '(==))) []]] else id @@ -335,17 +335,17 @@ mkConS tn' tvs' prev' next' tag' con' = do (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs)) [t| Exp $(foldl' appT (conT tn) (map varT tvs)) -> Maybe $(tupT (map (\t -> [t| Exp $(return t) |]) fs)) |] - matchP us = [p| TagRtag $(litP (IntegerL (toInteger tag))) $pat |] + matchP us = [p| TagRtag _ $(litP (IntegerL (toInteger tag))) $pat |] where pat = [p| $(foldl (\ps p -> [p| TagRpair $ps $p |]) [p| TagRunit |] us) |] - extract [] _ ps es = return (ps, es) - extract (u:us) x ps es = do + prj [] _ ps es = return (ps, es) + prj (u:us) x ps es = do _u <- newName "_u" let x' = [| Prj PairIdxLeft (SmartExp $x) |] if not u - then extract us x' (wildP:ps) es - else extract us x' (varP _u:ps) ([| Exp (SmartExp (Match $(varE _u) (SmartExp (Prj PairIdxRight (SmartExp $x))))) |] : es) + then prj us x' (wildP:ps) es + else prj us x' (varP _u:ps) ([| Exp (SmartExp (Match $(varE _u) (SmartExp (Prj PairIdxRight (SmartExp $x))))) |] : es) vs = reverse $ [ False | _ <- concat fs0 ] ++ [ True | _ <- fs ] ++ [ False | _ <- concat fs1 ] diff --git a/src/Data/Array/Accelerate/Prelude.hs b/src/Data/Array/Accelerate/Prelude.hs index 149c46347..2d52a1231 100644 --- a/src/Data/Array/Accelerate/Prelude.hs +++ b/src/Data/Array/Accelerate/Prelude.hs @@ -705,7 +705,7 @@ fold1All f arr = fold1 f (flatten arr) -- 40, 170, 0, 138] -- foldSeg - :: forall sh e i. (Shape sh, Elt e, Elt i, i ~ EltR i, IsIntegral i) + :: forall sh e i. (Shape sh, Elt e, Num i, IsSingleIntegral (EltR i)) => (Exp e -> Exp e -> Exp e) -> Exp e -> Acc (Array (sh:.Int) e) @@ -714,17 +714,17 @@ foldSeg foldSeg f z arr seg = foldSeg' f z arr (scanl plus zero seg) where (plus, zero) = - case integralType @i of - TypeInt{} -> ((+), 0) - TypeInt8{} -> ((+), 0) - TypeInt16{} -> ((+), 0) - TypeInt32{} -> ((+), 0) - TypeInt64{} -> ((+), 0) - TypeWord{} -> ((+), 0) - TypeWord8{} -> ((+), 0) - TypeWord16{} -> ((+), 0) - TypeWord32{} -> ((+), 0) - TypeWord64{} -> ((+), 0) + case singleIntegralType @(EltR i) of + TypeInt8{} -> ((+), 0) + TypeInt16{} -> ((+), 0) + TypeInt32{} -> ((+), 0) + TypeInt64{} -> ((+), 0) + TypeInt128{} -> ((+), 0) + TypeWord8{} -> ((+), 0) + TypeWord16{} -> ((+), 0) + TypeWord32{} -> ((+), 0) + TypeWord64{} -> ((+), 0) + TypeWord128{} -> ((+), 0) -- | Variant of 'foldSeg' that requires /all/ segments of the reduced array @@ -732,7 +732,7 @@ foldSeg f z arr seg = foldSeg' f z arr (scanl plus zero seg) -- descriptor species the length of each of the logical sub-arrays. -- fold1Seg - :: forall sh e i. (Shape sh, Elt e, Elt i, i ~ EltR i, IsIntegral i) + :: forall sh e i. (Shape sh, Elt e, Num i, IsSingleIntegral (EltR i)) => (Exp e -> Exp e -> Exp e) -> Acc (Array (sh:.Int) e) -> Acc (Segments i) @@ -742,17 +742,17 @@ fold1Seg f arr seg = fold1Seg' f arr (scanl plus zero seg) plus :: Exp i -> Exp i -> Exp i zero :: Exp i (plus, zero) = - case integralType @(EltR i) of - TypeInt{} -> ((+), 0) - TypeInt8{} -> ((+), 0) - TypeInt16{} -> ((+), 0) - TypeInt32{} -> ((+), 0) - TypeInt64{} -> ((+), 0) - TypeWord{} -> ((+), 0) - TypeWord8{} -> ((+), 0) - TypeWord16{} -> ((+), 0) - TypeWord32{} -> ((+), 0) - TypeWord64{} -> ((+), 0) + case singleIntegralType @(EltR i) of + TypeInt8{} -> ((+), 0) + TypeInt16{} -> ((+), 0) + TypeInt32{} -> ((+), 0) + TypeInt64{} -> ((+), 0) + TypeInt128{} -> ((+), 0) + TypeWord8{} -> ((+), 0) + TypeWord16{} -> ((+), 0) + TypeWord32{} -> ((+), 0) + TypeWord64{} -> ((+), 0) + TypeWord128{} -> ((+), 0) -- Specialised reductions diff --git a/src/Data/Array/Accelerate/Pretty/Graphviz.hs b/src/Data/Array/Accelerate/Pretty/Graphviz.hs index ca63fd323..f3954e12e 100644 --- a/src/Data/Array/Accelerate/Pretty/Graphviz.hs +++ b/src/Data/Array/Accelerate/Pretty/Graphviz.hs @@ -519,19 +519,20 @@ fvOpenExp env aenv = fv fv Evar{} = [] fv Undef{} = [] fv Const{} = [] - fv PrimConst{} = [] fv (PrimApp _ x) = fv x - fv (Pair e1 e2) = concat [ fv e1, fv e2] + fv (Pair e1 e2) = concat [ fv e1, fv e2 ] fv Nil = [] - fv (VecPack _ e) = fv e - fv (VecUnpack _ e) = fv e + fv (Extract _ _ v i) = concat [ fv v, fv i ] + fv (Insert _ _ v i x) = concat [ fv v, fv i, fv x ] + fv (Shuffle _ _ x y i) = concat [ fv x, fv y, fv i ] + fv (Select m x y) = concat [ fv m, fv x, fv y ] fv (IndexSlice _ slix sh) = concat [ fv slix, fv sh ] fv (IndexFull _ slix sh) = concat [ fv slix, fv sh ] fv (ToIndex _ sh ix) = concat [ fv sh, fv ix ] fv (FromIndex _ sh ix) = concat [ fv sh, fv ix ] fv (ShapeSize _ sh) = fv sh fv Foreign{} = [] - fv (Case e rhs def) = concat [ fv e, concat [ fv c | (_,c) <- rhs ], maybe [] fv def ] + fv (Case _ e rhs def) = concat [ fv e, concat [ fv c | (_,c) <- rhs ], maybe [] fv def ] fv (Cond p t e) = concat [ fv p, fv t, fv e ] fv (While p f x) = concat [ fvF p, fvF f, fv x ] fv (Coerce _ _ e) = fv e diff --git a/src/Data/Array/Accelerate/Pretty/Print.hs b/src/Data/Array/Accelerate/Pretty/Print.hs index 965aecd99..15f5a09df 100644 --- a/src/Data/Array/Accelerate/Pretty/Print.hs +++ b/src/Data/Array/Accelerate/Pretty/Print.hs @@ -59,10 +59,10 @@ import Data.Array.Accelerate.AST hiding ( Dir import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.AST.LeftHandSide import Data.Array.Accelerate.AST.Var +import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Representation.Array import Data.Array.Accelerate.Representation.Elt import Data.Array.Accelerate.Representation.Stencil -import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Sugar.Foreign import Data.Array.Accelerate.Type @@ -412,13 +412,14 @@ prettyOpenExp ctx env aenv exp = op = primOperator f op' = isInfix op ? (Operator (parens (opName op)) App L 10, op) -- - PrimConst c -> prettyPrimConst c Const tp c -> prettyConst (TupRsingle tp) c Pair{} -> prettyTuple ctx env aenv exp Nil -> "()" - VecPack _ e -> ppF1 "pack" (ppE e) - VecUnpack _ e -> ppF1 "unpack" (ppE e) - Case x xs d -> prettyCase env aenv x xs d + Extract _ _ v i -> ppF2 (Operator "#" Infix L 9) (ppE v) (ppE i) + Insert{} -> prettyInsert ctx env aenv exp + Shuffle _ _ x y i -> ppF3 "shuffle" (ppE x) (ppE y) (ppE i) + Select m x y -> ppF3 "select" (ppE m) (ppE x) (ppE y) + Case tR x xs d -> prettyCase env aenv tR x xs d Cond p t e -> flatAlt multi single where p' = ppE p context0 @@ -558,24 +559,58 @@ prettyTuple ctx env aenv exp = case collect exp of prettyCase :: Val env -> Val aenv + -> ScalarType tag -> OpenExp env aenv a - -> [(TAG, OpenExp env aenv b)] + -> [(tag, OpenExp env aenv b)] -> Maybe (OpenExp env aenv b) -> Adoc -prettyCase env aenv x xs def +prettyCase env aenv tagR x xs def = hang shiftwidth $ vsep [ case_ <+> x' <+> of_ , flatAlt (vcat xs') (encloseSep "{ " " }" "; " xs') ] where x' = prettyOpenExp context0 env aenv x - xs' = map (\(t,e) -> pretty t <+> "->" <+> prettyOpenExp context0 env aenv e) xs + xs' = map (\(t,e) -> prettyConst (TupRsingle tagR) t <+> "->" <+> prettyOpenExp context0 env aenv e) xs ++ case def of Nothing -> [] Just d -> ["_" <+> "->" <+> prettyOpenExp context0 env aenv d] -{- - +prettyInsert + :: forall env aenv t. + Context + -> Val env + -> Val aenv + -> OpenExp env aenv t + -> Adoc +prettyInsert ctx env aenv exp = + case collect exp of + Just (c, xs) -> align $ parensIf (ctxPrecedence ctx > 0) ("V" <> pretty c <+> align (sep (reverse xs))) + Nothing -> align $ ppInsert exp + where + ppInsert :: OpenExp env aenv t' -> Adoc + ppInsert (Insert _ _ v i x) = + let v' = prettyOpenExp ctx env aenv v + i' = prettyOpenExp ctx env aenv i + x' = prettyOpenExp ctx env aenv x + in + parensIf (ctxPrecedence ctx > 0) + $ hang 2 + $ sep [ "insert", brackets i', x', v'] + ppInsert e = + prettyOpenExp context0 env aenv e + + collect :: OpenExp env aenv t' -> Maybe (Word8, [Adoc]) + collect Undef{} + = Just (0, []) + collect (Insert _ _ v i x) + | Just (i', xs) <- collect v + , Just Refl <- matchOpenExp i (Const scalarType i') + = Just (i'+1, prettyOpenExp app env aenv x : xs) + collect _ + = Nothing + +{-- prettyAtuple :: forall acc aenv arrs. PrettyAcc acc @@ -597,17 +632,16 @@ prettyAtuple prettyAcc extractAcc aenv0 acc = case collect acc of | Just tup <- collect $ extractAcc a1 = Just $ tup ++ [prettyAcc app aenv0 a2] collect _ = Nothing --} +--} prettyConst :: TypeR e -> e -> Adoc prettyConst tp x = - let y = showElt tp x - in parensIf (any isSpace y) (pretty y) - -prettyPrimConst :: PrimConst a -> Adoc -prettyPrimConst PrimMinBound{} = "minBound" -prettyPrimConst PrimMaxBound{} = "maxBound" -prettyPrimConst PrimPi{} = "pi" + let y = showElt tp x + -- + isVec [] = False + isVec xs = head xs == '<' && last xs == '>' + in + parensIf (any isSpace y && not (isVec y)) (pretty y) -- Primitive operators @@ -722,11 +756,13 @@ primOperator PrimEq{} = Operator "==" Infix N 4 primOperator PrimNEq{} = Operator "/=" Infix N 4 primOperator PrimMax{} = Operator "max" App L 10 primOperator PrimMin{} = Operator "min" App L 10 -primOperator PrimLAnd = Operator "&&" Infix R 3 -primOperator PrimLOr = Operator "||" Infix R 2 -primOperator PrimLNot = Operator "not" App L 10 +primOperator PrimLAnd{} = Operator "&&" Infix R 3 +primOperator PrimLOr{} = Operator "||" Infix R 2 +primOperator PrimLNot{} = Operator "not" App L 10 primOperator PrimFromIntegral{} = Operator "fromIntegral" App L 10 primOperator PrimToFloating{} = Operator "toFloating" App L 10 +primOperator PrimToBool{} = Operator "toBool" App L 10 +primOperator PrimFromBool{} = Operator "fromBool" App L 10 -- Environments diff --git a/src/Data/Array/Accelerate/Representation/Array.hs b/src/Data/Array/Accelerate/Representation/Array.hs index d61304e76..151294c13 100644 --- a/src/Data/Array/Accelerate/Representation/Array.hs +++ b/src/Data/Array/Accelerate/Representation/Array.hs @@ -98,7 +98,7 @@ arraysRpair a b = TupRunit `TupRpair` TupRsingle a `TupRpair` TupRsingle b -- allocateArray :: ArrayR (Array sh e) -> sh -> IO (Array sh e) allocateArray (ArrayR shR eR) sh = do - adata <- newArrayData eR (size shR sh) + adata <- newArrayData eR (fromIntegral (size shR sh)) return $! Array sh adata -- | Create an array from its representation function, applied at each @@ -114,13 +114,13 @@ fromFunction repr sh f = unsafePerformIO $! fromFunctionM repr sh (return . f) fromFunctionM :: ArrayR (Array sh e) -> sh -> (sh -> IO e) -> IO (Array sh e) fromFunctionM (ArrayR shR eR) sh f = do let !n = size shR sh - arr <- newArrayData eR n + arr <- newArrayData eR (fromIntegral n) -- let write !i | i >= n = return () | otherwise = do v <- f (fromIndex shR sh i) - writeArrayData eR arr i v + writeArrayData eR arr (fromIntegral i) v write (i+1) -- write 0 @@ -135,7 +135,7 @@ fromList (ArrayR shR eR) sh xs = adata `seq` Array sh adata -- Assume the array is in dense row-major order. This is safe because -- otherwise backends would not be able to directly memcpy. -- - !n = size shR sh + !n = fromIntegral (size shR sh) (adata, _) = runArrayData @e $ do arr <- newArrayData eR n let go !i _ | i >= n = return () @@ -154,21 +154,21 @@ toList (ArrayR shR eR) (Array sh adata) = go 0 -- Assume underling array is in row-major order. This is safe because -- otherwise backends would not be able to directly memcpy. -- - !n = size shR sh + !n = fromIntegral (size shR sh) go !i | i >= n = [] | otherwise = indexArrayData eR adata i : go (i+1) -concatVectors :: forall e. TypeR e -> [Vector e] -> Vector e -concatVectors tR vs = adata `seq` Array ((), len) adata - where - offsets = scanl (+) 0 (map (size dim1 . shape) vs) - len = last offsets - (adata, _) = runArrayData @e $ do - arr <- newArrayData tR len - sequence_ [ writeArrayData tR arr (i + k) (indexArrayData tR ad i) - | (Array ((), n) ad, k) <- vs `zip` offsets - , i <- [0 .. n - 1] ] - return (arr, undefined) +-- concatVectors :: forall e. TypeR e -> [Vector e] -> Vector e +-- concatVectors tR vs = adata `seq` Array ((), len) adata +-- where +-- offsets = scanl (+) 0 (map (size dim1 . shape) vs) +-- len = last offsets +-- (adata, _) = runArrayData @e $ do +-- arr <- newArrayData tR len +-- sequence_ [ writeArrayData tR arr (i + k) (indexArrayData tR ad i) +-- | (Array ((), n) ad, k) <- vs `zip` offsets +-- , i <- [0 .. n - 1] ] +-- return (arr, undefined) shape :: Array sh e -> sh shape (Array sh _) = sh @@ -181,14 +181,14 @@ reshape shR sh shR' (Array sh' adata) (!) :: (ArrayR (Array sh e), Array sh e) -> sh -> e (!) = uncurry indexArray -(!!) :: (TypeR e, Array sh e) -> Int -> e +(!!) :: (TypeR e, Array sh e) -> INT -> e (!!) = uncurry linearIndexArray indexArray :: ArrayR (Array sh e) -> Array sh e -> sh -> e -indexArray (ArrayR shR adR) (Array sh adata) ix = indexArrayData adR adata (toIndex shR sh ix) +indexArray (ArrayR shR adR) (Array sh adata) ix = indexArrayData adR adata (fromIntegral (toIndex shR sh ix)) -linearIndexArray :: TypeR e -> Array sh e -> Int -> e -linearIndexArray adR (Array _ adata) = indexArrayData adR adata +linearIndexArray :: TypeR e -> Array sh e -> INT -> e +linearIndexArray adR (Array _ adata) = indexArrayData adR adata . fromIntegral showArray :: (e -> ShowS) -> ArrayR (Array sh e) -> Array sh e -> String showArray f arrR@(ArrayR shR _) arr@(Array sh _) = case shR of @@ -217,9 +217,11 @@ showMatrix f (ArrayR _ arrR) arr@(Array sh _) | rows * cols == 0 = "[]" | otherwise = "\n [" ++ ppMat 0 0 where - (((), rows), cols) = sh - lengths = U.generate (rows*cols) (\i -> length (f (linearIndexArray arrR arr i) "")) - widths = U.generate cols (\c -> U.maximum (U.generate rows (\r -> lengths U.! (r*cols+c)))) + (((), ri), ci) = sh + rows = fromIntegral ri + cols = fromIntegral ci + lengths = U.generate (rows*cols) (\i -> length (f (linearIndexArray arrR arr (fromIntegral i)) "")) + widths = U.generate cols (\c -> U.maximum (U.generate rows (\r -> lengths U.! (r*cols+c)))) -- ppMat :: Int -> Int -> String ppMat !r !c | c >= cols = ppMat (r+1) 0 @@ -229,7 +231,7 @@ showMatrix f (ArrayR _ arrR) arr@(Array sh _) !l = lengths U.! i !w = widths U.! c !pad = 1 - cell = replicate (w-l+pad) ' ' ++ f (linearIndexArray arrR arr i) "" + cell = replicate (w-l+pad) ' ' ++ f (linearIndexArray arrR arr (fromIntegral i)) "" -- before | r > 0 && c == 0 = "\n " @@ -294,7 +296,7 @@ showsArrays repr arrs = go 0 repr arrs needsParens repr'@(TupRpair _ _) as = isJust $ extractTuple repr' as needsParens _ _ = True -reduceRank :: ArrayR (Array (sh, Int) e) -> ArrayR (Array sh e) +reduceRank :: ArrayR (Array (sh, INT) e) -> ArrayR (Array sh e) reduceRank (ArrayR (ShapeRsnoc shR) aeR) = ArrayR shR aeR rnfArray :: ArrayR a -> a -> () @@ -320,8 +322,7 @@ liftArray :: forall sh e. ArrayR (Array sh e) -> Array sh e -> CodeQ (Array sh e liftArray (ArrayR shR adR) (Array sh adata) = [|| Array $$(liftElt (shapeType shR) sh) $$(liftArrayData sz adR adata) ||] `at` [t| Array $(liftTypeQ (shapeType shR)) $(liftTypeQ adR) |] where - sz :: Int - sz = size shR sh + sz = fromIntegral (size shR sh) at :: CodeQ t -> Q Type -> CodeQ t at e t = unsafeCodeCoerce $ sigE (unTypeCode e) t diff --git a/src/Data/Array/Accelerate/Representation/Elt.hs b/src/Data/Array/Accelerate/Representation/Elt.hs index d72cd7165..cdaed373f 100644 --- a/src/Data/Array/Accelerate/Representation/Elt.hs +++ b/src/Data/Array/Accelerate/Representation/Elt.hs @@ -1,4 +1,5 @@ {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} @@ -18,14 +19,16 @@ module Data.Array.Accelerate.Representation.Elt import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Type +import Data.Primitive.Bit import Data.Primitive.Vec import Control.Monad.ST -import Data.List ( intercalate ) import Data.Primitive.ByteArray -import Foreign.Storable import Language.Haskell.TH.Extra +import GHC.TypeLits +import GHC.Base + undefElt :: TypeR t -> t undefElt = tuple @@ -36,38 +39,66 @@ undefElt = tuple tuple (TupRsingle t) = scalar t scalar :: ScalarType t -> t - scalar (SingleScalarType t) = single t - scalar (VectorScalarType t) = vector t - - vector :: VectorType t -> t - vector (VectorType n t) = runST $ do - mba <- newByteArray (n * bytesElt (TupRsingle (SingleScalarType t))) - ByteArray ba# <- unsafeFreezeByteArray mba - return (Vec ba#) - - single :: SingleType t -> t - single (NumSingleType t) = num t + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t + + bit :: BitType t -> t + bit TypeBit = Bit False + bit (TypeMask n) = + let (q, r) = quotRem (fromInteger (natVal' n)) 8 + bytes = if r == 0 then q else q + 1 + in + runST $ do + mba <- newByteArray bytes + ByteArray ba# <- unsafeFreezeByteArray mba + return $! Vec ba# num :: NumType t -> t num (IntegralNumType t) = integral t num (FloatingNumType t) = floating t integral :: IntegralType t -> t - integral TypeInt = 0 - integral TypeInt8 = 0 - integral TypeInt16 = 0 - integral TypeInt32 = 0 - integral TypeInt64 = 0 - integral TypeWord = 0 - integral TypeWord8 = 0 - integral TypeWord16 = 0 - integral TypeWord32 = 0 - integral TypeWord64 = 0 + integral = \case + SingleIntegralType t -> single t + VectorIntegralType n t -> vector n t + where + single :: SingleIntegralType t -> t + single TypeInt8 = 0 + single TypeInt16 = 0 + single TypeInt32 = 0 + single TypeInt64 = 0 + single TypeInt128 = 0 + single TypeWord8 = 0 + single TypeWord16 = 0 + single TypeWord32 = 0 + single TypeWord64 = 0 + single TypeWord128 = 0 + + vector :: KnownNat n => Proxy# n -> SingleIntegralType t -> Vec n t + vector n t = runST $ do + let bytes = bytesElt (TupRsingle (NumScalarType (IntegralNumType (VectorIntegralType n t)))) + mba <- newAlignedPinnedByteArray bytes 16 + ByteArray ba# <- unsafeFreezeByteArray mba + return $! Vec ba# floating :: FloatingType t -> t - floating TypeHalf = 0 - floating TypeFloat = 0 - floating TypeDouble = 0 + floating = \case + SingleFloatingType t -> single t + VectorFloatingType n t -> vector n t + where + single :: SingleFloatingType t -> t + single TypeFloat16 = 0 + single TypeFloat32 = 0 + single TypeFloat64 = 0 + single TypeFloat128 = 0 + + vector :: KnownNat n => Proxy# n -> SingleFloatingType t -> Vec n t + vector n t = runST $ do + let bytes = bytesElt (TupRsingle (NumScalarType (FloatingNumType (VectorFloatingType n t)))) + mba <- newAlignedPinnedByteArray bytes 16 + ByteArray ba# <- unsafeFreezeByteArray mba + return $! Vec ba# + bytesElt :: TypeR e -> Int bytesElt = tuple @@ -78,35 +109,46 @@ bytesElt = tuple tuple (TupRsingle t) = scalar t scalar :: ScalarType t -> Int - scalar (SingleScalarType t) = single t - scalar (VectorScalarType t) = vector t - - vector :: VectorType t -> Int - vector (VectorType n t) = n * single t + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t - single :: SingleType t -> Int - single (NumSingleType t) = num t + bit :: BitType t -> Int + bit TypeBit = 1 -- stored as Word8 + bit (TypeMask n) = + let (q,r) = quotRem (fromInteger (natVal' n)) 8 + in if r == 0 then q else q+1 num :: NumType t -> Int num (IntegralNumType t) = integral t num (FloatingNumType t) = floating t integral :: IntegralType t -> Int - integral TypeInt = sizeOf (undefined::Int) - integral TypeInt8 = 1 - integral TypeInt16 = 2 - integral TypeInt32 = 4 - integral TypeInt64 = 8 - integral TypeWord = sizeOf (undefined::Word) - integral TypeWord8 = 1 - integral TypeWord16 = 2 - integral TypeWord32 = 4 - integral TypeWord64 = 8 + integral = \case + SingleIntegralType t -> single t + VectorIntegralType n t -> fromInteger (natVal' n) * single t + where + single :: SingleIntegralType t -> Int + single TypeInt8 = 1 + single TypeInt16 = 2 + single TypeInt32 = 4 + single TypeInt64 = 8 + single TypeInt128 = 16 + single TypeWord8 = 1 + single TypeWord16 = 2 + single TypeWord32 = 4 + single TypeWord64 = 8 + single TypeWord128 = 16 floating :: FloatingType t -> Int - floating TypeHalf = 2 - floating TypeFloat = 4 - floating TypeDouble = 8 + floating = \case + SingleFloatingType t -> single t + VectorFloatingType n t -> fromInteger (natVal' n) * single t + where + single :: SingleFloatingType t -> Int + single TypeFloat16 = 2 + single TypeFloat32 = 4 + single TypeFloat64 = 8 + single TypeFloat128 = 16 showElt :: TypeR e -> e -> String showElt t v = showsElt t v "" @@ -120,38 +162,62 @@ showsElt = tuple tuple (TupRsingle tp) val = scalar tp val scalar :: ScalarType e -> e -> ShowS - scalar (SingleScalarType t) e = single t e - scalar (VectorScalarType t) e = vector t e + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t - single :: SingleType e -> e -> ShowS - single (NumSingleType t) = num t + bit :: BitType e -> e -> ShowS + bit TypeBit = shows + bit TypeMask{} = shows . BitMask num :: NumType e -> e -> ShowS num (IntegralNumType t) = integral t num (FloatingNumType t) = floating t integral :: IntegralType e -> e -> ShowS - integral TypeInt = shows - integral TypeInt8 = shows - integral TypeInt16 = shows - integral TypeInt32 = shows - integral TypeInt64 = shows - integral TypeWord = shows - integral TypeWord8 = shows - integral TypeWord16 = shows - integral TypeWord32 = shows - integral TypeWord64 = shows + integral = \case + SingleIntegralType t -> single t + VectorIntegralType _ t -> vector t + where + single :: SingleIntegralType t -> t -> ShowS + single TypeInt8 = shows + single TypeInt16 = shows + single TypeInt32 = shows + single TypeInt64 = shows + single TypeInt128 = shows + single TypeWord8 = shows + single TypeWord16 = shows + single TypeWord32 = shows + single TypeWord64 = shows + single TypeWord128 = shows + + vector :: KnownNat n => SingleIntegralType t -> Vec n t -> ShowS + vector TypeInt8 = shows + vector TypeInt16 = shows + vector TypeInt32 = shows + vector TypeInt64 = shows + vector TypeInt128 = shows + vector TypeWord8 = shows + vector TypeWord16 = shows + vector TypeWord32 = shows + vector TypeWord64 = shows + vector TypeWord128 = shows floating :: FloatingType e -> e -> ShowS - floating TypeHalf = shows - floating TypeFloat = shows - floating TypeDouble = shows - - vector :: VectorType (Vec n a) -> Vec n a -> ShowS - vector (VectorType _ s) vec - | SingleDict <- singleDict s - = showString - $ "<" ++ intercalate ", " ((\v -> single s v "") <$> listOfVec vec) ++ ">" + floating = \case + SingleFloatingType t -> single t + VectorFloatingType _ t -> vector t + where + single :: SingleFloatingType t -> t -> ShowS + single TypeFloat16 = shows + single TypeFloat32 = shows + single TypeFloat64 = shows + single TypeFloat128 = shows + + vector :: KnownNat n => SingleFloatingType t -> Vec n t -> ShowS + vector TypeFloat16 = shows + vector TypeFloat32 = shows + vector TypeFloat64 = shows + vector TypeFloat128 = shows liftElt :: TypeR t -> t -> CodeQ t liftElt TupRunit () = [|| () ||] diff --git a/src/Data/Array/Accelerate/Representation/Shape.hs b/src/Data/Array/Accelerate/Representation/Shape.hs index fa3651c03..fb68d71ab 100644 --- a/src/Data/Array/Accelerate/Representation/Shape.hs +++ b/src/Data/Array/Accelerate/Representation/Shape.hs @@ -23,14 +23,12 @@ import Data.Array.Accelerate.Representation.Type import Language.Haskell.TH.Extra import Prelude hiding ( zip ) -import GHC.Base ( quotInt, remInt ) - -- | Shape and index representations as nested pairs -- data ShapeR sh where ShapeRz :: ShapeR () - ShapeRsnoc :: ShapeR sh -> ShapeR (sh, Int) + ShapeRsnoc :: ShapeR sh -> ShapeR (sh, INT) -- | Nicely format a shape as a string -- @@ -40,9 +38,9 @@ showShape shr = foldr (\sh str -> str ++ " :. " ++ show sh) "Z" . shapeToList sh -- Synonyms for common shape types -- type DIM0 = () -type DIM1 = ((), Int) -type DIM2 = (((), Int), Int) -type DIM3 = ((((), Int), Int), Int) +type DIM1 = ((), INT) +type DIM2 = (((), INT), INT) +type DIM3 = ((((), INT), INT), INT) dim0 :: ShapeR DIM0 dim0 = ShapeRz @@ -58,13 +56,13 @@ dim3 = ShapeRsnoc dim2 -- | Number of dimensions of a /shape/ or /index/ (>= 0) -- -rank :: ShapeR sh -> Int +rank :: ShapeR sh -> INT rank ShapeRz = 0 rank (ShapeRsnoc shr) = rank shr + 1 -- | Total number of elements in an array of the given shape -- -size :: ShapeR sh -> sh -> Int +size :: ShapeR sh -> sh -> INT size ShapeRz () = 1 size (ShapeRsnoc shr) (sh, sz) | sz <= 0 = 0 @@ -86,7 +84,7 @@ intersect = zip min union :: ShapeR sh -> sh -> sh -> sh union = zip max -zip :: (Int -> Int -> Int) -> ShapeR sh -> sh -> sh -> sh +zip :: (INT -> INT -> INT) -> ShapeR sh -> sh -> sh -> sh zip _ ShapeRz () () = () zip f (ShapeRsnoc shr) (as, a) (bs, b) = (zip f shr as bs, f a b) @@ -99,25 +97,25 @@ eq (ShapeRsnoc shr) (sh, i) (sh', i') = i == i' && eq shr sh sh' -- representation of the array (first argument is the /shape/, second -- argument is the /index/). -- -toIndex :: HasCallStack => ShapeR sh -> sh -> sh -> Int -toIndex ShapeRz () () = 0 +toIndex :: HasCallStack => ShapeR sh -> sh -> sh -> INT +toIndex ShapeRz () () = 0 toIndex (ShapeRsnoc shr) (sh, sz) (ix, i) = indexCheck i sz $ toIndex shr sh ix * sz + i -- | Inverse of 'toIndex' -- -fromIndex :: HasCallStack => ShapeR sh -> sh -> Int -> sh -fromIndex ShapeRz () _ = () +fromIndex :: HasCallStack => ShapeR sh -> sh -> INT -> sh +fromIndex ShapeRz () _ = () fromIndex (ShapeRsnoc shr) (sh, sz) i - = (fromIndex shr sh (i `quotInt` sz), r) + = (fromIndex shr sh (i `quot` sz), r) -- If we assume that the index is in range, there is no point in computing -- the remainder for the highest dimension since i < sz must hold. -- where r = case shr of -- Check if rank of shr is 0 ShapeRz -> indexCheck i sz i - _ -> i `remInt` sz + _ -> i `rem` sz -- | Iterate through the entire shape, applying the function in the second -- argument; third argument combines results and fourth is an initial value @@ -157,13 +155,13 @@ shapeToRange (ShapeRsnoc shr) (sh, sz) = let (low, high) = shapeToRange shr sh i -- | Convert a shape or index into its list of dimensions -- -shapeToList :: ShapeR sh -> sh -> [Int] +shapeToList :: ShapeR sh -> sh -> [INT] shapeToList ShapeRz () = [] shapeToList (ShapeRsnoc shr) (sh,sz) = sz : shapeToList shr sh -- | Convert a list of dimensions into a shape -- -listToShape :: HasCallStack => ShapeR sh -> [Int] -> sh +listToShape :: HasCallStack => ShapeR sh -> [INT] -> sh listToShape shr ds = case listToShape' shr ds of Just sh -> sh @@ -171,17 +169,14 @@ listToShape shr ds = -- | Attempt to convert a list of dimensions into a shape -- -listToShape' :: ShapeR sh -> [Int] -> Maybe sh +listToShape' :: ShapeR sh -> [INT] -> Maybe sh listToShape' ShapeRz [] = Just () listToShape' (ShapeRsnoc shr) (x:xs) = (, x) <$> listToShape' shr xs listToShape' _ _ = Nothing shapeType :: ShapeR sh -> TypeR sh shapeType ShapeRz = TupRunit -shapeType (ShapeRsnoc shr) = - shapeType shr - `TupRpair` - TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeInt))) +shapeType (ShapeRsnoc shR) = shapeType shR `TupRpair` TupRsingle scalarType rnfShape :: ShapeR sh -> sh -> () rnfShape ShapeRz () = () diff --git a/src/Data/Array/Accelerate/Representation/Slice.hs b/src/Data/Array/Accelerate/Representation/Slice.hs index dee059a37..8e0246ad1 100644 --- a/src/Data/Array/Accelerate/Representation/Slice.hs +++ b/src/Data/Array/Accelerate/Representation/Slice.hs @@ -19,6 +19,7 @@ module Data.Array.Accelerate.Representation.Slice where import Data.Array.Accelerate.Representation.Shape +import Data.Array.Accelerate.Type import Language.Haskell.TH.Extra @@ -38,15 +39,15 @@ instance Slice () where sliceIndex = SliceNil instance Slice sl => Slice (sl, ()) where - type SliceShape (sl, ()) = (SliceShape sl, Int) + type SliceShape (sl, ()) = (SliceShape sl, INT) type CoSliceShape (sl, ()) = CoSliceShape sl - type FullShape (sl, ()) = (FullShape sl, Int) + type FullShape (sl, ()) = (FullShape sl, INT) sliceIndex = SliceAll (sliceIndex @sl) -instance Slice sl => Slice (sl, Int) where - type SliceShape (sl, Int) = SliceShape sl - type CoSliceShape (sl, Int) = (CoSliceShape sl, Int) - type FullShape (sl, Int) = (FullShape sl, Int) +instance Slice sl => Slice (sl, INT) where + type SliceShape (sl, INT) = SliceShape sl + type CoSliceShape (sl, INT) = (CoSliceShape sl, INT) + type FullShape (sl, INT) = (FullShape sl, INT) sliceIndex = SliceFixed (sliceIndex @sl) -- |Generalised array index, which may index only in a subset of the dimensions @@ -54,8 +55,8 @@ instance Slice sl => Slice (sl, Int) where -- data SliceIndex ix slice coSlice sliceDim where SliceNil :: SliceIndex () () () () - SliceAll :: SliceIndex ix slice co dim -> SliceIndex (ix, ()) (slice, Int) co (dim, Int) - SliceFixed :: SliceIndex ix slice co dim -> SliceIndex (ix, Int) slice (co, Int) (dim, Int) + SliceAll :: SliceIndex ix slice co dim -> SliceIndex (ix, ()) (slice, INT) co (dim, INT) + SliceFixed :: SliceIndex ix slice co dim -> SliceIndex (ix, INT) slice (co, INT) (dim, INT) instance Show (SliceIndex ix slice coSlice sliceDim) where show SliceNil = "SliceNil" diff --git a/src/Data/Array/Accelerate/Representation/Stencil.hs b/src/Data/Array/Accelerate/Representation/Stencil.hs index dd546721c..804892800 100644 --- a/src/Data/Array/Accelerate/Representation/Stencil.hs +++ b/src/Data/Array/Accelerate/Representation/Stencil.hs @@ -25,6 +25,7 @@ module Data.Array.Accelerate.Representation.Stencil ( import Data.Array.Accelerate.Representation.Array import Data.Array.Accelerate.Representation.Shape import Data.Array.Accelerate.Representation.Type +import Data.Array.Accelerate.Type import Language.Haskell.TH.Extra @@ -40,14 +41,14 @@ data StencilR sh e pat where StencilRtup3 :: StencilR sh e pat1 -> StencilR sh e pat2 -> StencilR sh e pat3 - -> StencilR (sh, Int) e (Tup3 pat1 pat2 pat3) + -> StencilR (sh, INT) e (Tup3 pat1 pat2 pat3) StencilRtup5 :: StencilR sh e pat1 -> StencilR sh e pat2 -> StencilR sh e pat3 -> StencilR sh e pat4 -> StencilR sh e pat5 - -> StencilR (sh, Int) e (Tup5 pat1 pat2 pat3 pat4 pat5) + -> StencilR (sh, INT) e (Tup5 pat1 pat2 pat3 pat4 pat5) StencilRtup7 :: StencilR sh e pat1 -> StencilR sh e pat2 @@ -56,7 +57,7 @@ data StencilR sh e pat where -> StencilR sh e pat5 -> StencilR sh e pat6 -> StencilR sh e pat7 - -> StencilR (sh, Int) e (Tup7 pat1 pat2 pat3 pat4 pat5 pat6 pat7) + -> StencilR (sh, INT) e (Tup7 pat1 pat2 pat3 pat4 pat5 pat6 pat7) StencilRtup9 :: StencilR sh e pat1 -> StencilR sh e pat2 @@ -67,7 +68,7 @@ data StencilR sh e pat where -> StencilR sh e pat7 -> StencilR sh e pat8 -> StencilR sh e pat9 - -> StencilR (sh, Int) e (Tup9 pat1 pat2 pat3 pat4 pat5 pat6 pat7 pat8 pat9) + -> StencilR (sh, INT) e (Tup9 pat1 pat2 pat3 pat4 pat5 pat6 pat7 pat8 pat9) stencilEltR :: StencilR sh e pat -> TypeR e stencilEltR (StencilRunit3 t) = t @@ -123,7 +124,7 @@ stencilHalo = go' go :: StencilR sh e stencil -> sh go = snd . go' - cons :: ShapeR sh -> Int -> sh -> (sh, Int) + cons :: ShapeR sh -> INT -> sh -> (sh, INT) cons ShapeRz ix () = ((), ix) cons (ShapeRsnoc shr) ix (sh, sz) = (cons shr ix sh, sz) diff --git a/src/Data/Array/Accelerate/Representation/Tag.hs b/src/Data/Array/Accelerate/Representation/Tag.hs index ed7e07e80..034390dbf 100644 --- a/src/Data/Array/Accelerate/Representation/Tag.hs +++ b/src/Data/Array/Accelerate/Representation/Tag.hs @@ -1,5 +1,6 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeFamilies #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Representation.Tag @@ -15,6 +16,7 @@ module Data.Array.Accelerate.Representation.Tag where import Data.Array.Accelerate.Type +import Data.Primitive.Bit import Language.Haskell.TH.Extra @@ -32,7 +34,7 @@ type TAG = Word8 -- The function 'eltTags' produces all valid paths through the type. For -- example the type '(Bool,Bool)' produces the following: -- --- ghci> putStrLn . unlines . map show $ eltTags @(Bool,Bool) +-- ghci> putStrLn . unlines . map show $ tagsR @(Bool,Bool) -- (((),(0#,())),(0#,())) -- (False, False) -- (((),(0#,())),(1#,())) -- (False, True) -- (((),(1#,())),(0#,())) -- (True, False) @@ -42,27 +44,47 @@ data TagR a where TagRunit :: TagR () TagRsingle :: ScalarType a -> TagR a TagRundef :: ScalarType a -> TagR a - TagRtag :: TAG -> TagR a -> TagR (TAG, a) TagRpair :: TagR a -> TagR b -> TagR (a, b) + TagRtag :: SingleIntegralType t -> t -> TagR a -> TagR (t, a) + TagRbit :: BitType t -> t -> TagR t instance Show (TagR a) where show TagRunit = "()" show TagRsingle{} = "." show TagRundef{} = "undef" - show (TagRtag v t) = "(" ++ show v ++ "#," ++ show t ++ ")" show (TagRpair ta tb) = "(" ++ show ta ++ "," ++ show tb ++ ")" + show (TagRtag tR t e) = "(" ++ integral tR t ++ "#," ++ show e ++ ")" + where + integral :: SingleIntegralType t -> t -> String + integral TypeInt8 = show + integral TypeInt16 = show + integral TypeInt32 = show + integral TypeInt64 = show + integral TypeInt128 = show + integral TypeWord8 = show + integral TypeWord16 = show + integral TypeWord32 = show + integral TypeWord64 = show + integral TypeWord128 = show + show (TagRbit tR t) = bit tR t + where + bit :: BitType t -> t -> String + bit TypeBit x = shows x "#" + bit TypeMask{} x = shows (BitMask x) "#" rnfTag :: TagR a -> () rnfTag TagRunit = () -rnfTag (TagRsingle t) = rnfScalarType t -rnfTag (TagRundef t) = rnfScalarType t -rnfTag (TagRtag v t) = v `seq` rnfTag t +rnfTag (TagRsingle e) = rnfScalarType e +rnfTag (TagRundef e) = rnfScalarType e rnfTag (TagRpair ta tb) = rnfTag ta `seq` rnfTag tb +rnfTag (TagRtag tR t e) = rnfSingleIntegralType tR `seq` t `seq` rnfTag e +rnfTag (TagRbit tR t) = rnfBitType tR `seq` t `seq` () liftTag :: TagR a -> CodeQ (TagR a) liftTag TagRunit = [|| TagRunit ||] -liftTag (TagRsingle t) = [|| TagRsingle $$(liftScalarType t) ||] -liftTag (TagRundef t) = [|| TagRundef $$(liftScalarType t) ||] -liftTag (TagRtag v t) = [|| TagRtag v $$(liftTag t) ||] +liftTag (TagRsingle e) = [|| TagRsingle $$(liftScalarType e) ||] +liftTag (TagRundef e) = [|| TagRundef $$(liftScalarType e) ||] liftTag (TagRpair ta tb) = [|| TagRpair $$(liftTag ta) $$(liftTag tb) ||] +liftTag (TagRtag tR t e) = [|| TagRtag $$(liftSingleIntegralType tR) $$(liftSingleIntegral tR t) $$(liftTag e) ||] +liftTag (TagRbit tR t) = [|| TagRbit $$(liftBitType tR) $$(liftBit tR t) ||] diff --git a/src/Data/Array/Accelerate/Representation/Type.hs b/src/Data/Array/Accelerate/Representation/Type.hs index 477f09a00..13ecc126d 100644 --- a/src/Data/Array/Accelerate/Representation/Type.hs +++ b/src/Data/Array/Accelerate/Representation/Type.hs @@ -24,6 +24,8 @@ import Data.Primitive.Vec import Formatting import Language.Haskell.TH.Extra +import GHC.TypeLits + -- | Both arrays (Acc) and expressions (Exp) are represented as nested -- pairs consisting of: @@ -83,35 +85,45 @@ liftTypeQ = tuple tuple (TupRsingle t) = scalar t scalar :: ScalarType t -> TypeQ - scalar (SingleScalarType t) = single t - scalar (VectorScalarType t) = vector t - - vector :: VectorType (Vec n a) -> TypeQ - vector (VectorType n t) = [t| Vec $(litT (numTyLit (toInteger n))) $(single t) |] + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t - single :: SingleType t -> TypeQ - single (NumSingleType t) = num t + bit :: BitType t -> TypeQ + bit TypeBit = [t| Bit |] + bit (TypeMask n) = [t| Vec $(litT (numTyLit (natVal' n))) Bit |] num :: NumType t -> TypeQ num (IntegralNumType t) = integral t num (FloatingNumType t) = floating t integral :: IntegralType t -> TypeQ - integral TypeInt = [t| Int |] - integral TypeInt8 = [t| Int8 |] - integral TypeInt16 = [t| Int16 |] - integral TypeInt32 = [t| Int32 |] - integral TypeInt64 = [t| Int64 |] - integral TypeWord = [t| Word |] - integral TypeWord8 = [t| Word8 |] - integral TypeWord16 = [t| Word16 |] - integral TypeWord32 = [t| Word32 |] - integral TypeWord64 = [t| Word64 |] + integral = \case + SingleIntegralType t -> [t| $(single t) |] + VectorIntegralType n t -> [t| Vec $(litT (numTyLit (natVal' n))) $(single t) |] + where + single :: SingleIntegralType t -> TypeQ + single TypeInt8 = [t| Int8 |] + single TypeInt16 = [t| Int16 |] + single TypeInt32 = [t| Int32 |] + single TypeInt64 = [t| Int64 |] + single TypeInt128 = [t| Int128 |] + single TypeWord8 = [t| Word8 |] + single TypeWord16 = [t| Word16 |] + single TypeWord32 = [t| Word32 |] + single TypeWord64 = [t| Word64 |] + single TypeWord128 = [t| Word128 |] floating :: FloatingType t -> TypeQ - floating TypeHalf = [t| Half |] - floating TypeFloat = [t| Float |] - floating TypeDouble = [t| Double |] + floating = \case + SingleFloatingType t -> [t| $(single t) |] + VectorFloatingType n t -> [t| Vec $(litT (numTyLit (natVal' n))) $(single t) |] + where + single :: SingleFloatingType t -> TypeQ + single TypeFloat16 = [t| Half |] + single TypeFloat32 = [t| Float |] + single TypeFloat64 = [t| Double |] + single TypeFloat128 = [t| Float128 |] + runQ $ let diff --git a/src/Data/Array/Accelerate/Representation/Vec.hs b/src/Data/Array/Accelerate/Representation/Vec.hs index bd37c7f18..c6f31570d 100644 --- a/src/Data/Array/Accelerate/Representation/Vec.hs +++ b/src/Data/Array/Accelerate/Representation/Vec.hs @@ -1,11 +1,5 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE MagicHash #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TypeOperators #-} -{-# OPTIONS_HADDOCK hide #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE OverloadedStrings #-} -- | -- Module : Data.Array.Accelerate.Representation.Vec -- Copyright : [2008..2020] The Accelerate Team @@ -19,81 +13,99 @@ module Data.Array.Accelerate.Representation.Vec where -import Data.Array.Accelerate.Type +import Data.Array.Accelerate.Error import Data.Array.Accelerate.Representation.Type -import Data.Primitive.Vec +import Data.Array.Accelerate.Type -import Control.Monad.ST -import Data.Primitive.ByteArray -import Data.Primitive.Types -import Language.Haskell.TH.Extra +import Data.Primitive.Bit as Prim -import GHC.Base ( Int(..), Int#, (-#) ) -import GHC.TypeNats +import qualified GHC.Exts as GHC --- | Declares the size of a SIMD vector and the type of its elements. This --- data type is used to denote the relation between a vector type (Vec --- n single) with its tuple representation (tuple). Conversions between --- those types are exposed through 'pack' and 'unpack'. --- -data VecR (n :: Nat) single tuple where - VecRnil :: SingleType s -> VecR 0 s () - VecRsucc :: VecR n s t -> VecR (n + 1) s (t, s) +toList :: TypeR v -> TypeR a -> v -> [a] +toList = go + where + go :: TypeR v -> TypeR t -> v -> [t] + go TupRunit TupRunit _ = repeat () + go (TupRpair va vb) (TupRpair ta tb) (a, b) = zip (go va ta a) (go vb tb b) + go (TupRsingle v) (TupRsingle t) xs = scalar v t xs + go _ _ _ = internalError "unexpected vector encoding" + scalar :: ScalarType v -> ScalarType t -> v -> [t] + scalar (NumScalarType v) (NumScalarType t) = num v t + scalar (BitScalarType v) (BitScalarType t) = bit v t + scalar _ _ = internalError "unexpected vector encoding" -vecRvector :: KnownNat n => VecR n s tuple -> VectorType (Vec n s) -vecRvector = uncurry VectorType . go - where - go :: VecR n s tuple -> (Int, SingleType s) - go (VecRnil tp) = (0, tp) - go (VecRsucc vec) | (n, tp) <- go vec = (n + 1, tp) + bit :: BitType v -> BitType t -> v -> [t] + bit (TypeMask _) TypeBit = GHC.toList . Prim.BitMask + bit _ _ = internalError "unexpected vector encoding" -vecRSingle :: KnownNat n => VecR n s tuple -> SingleType s -vecRSingle vecr = let (VectorType _ s) = vecRvector vecr in s + num :: NumType v -> NumType t -> v -> [t] + num (IntegralNumType v) (IntegralNumType t) = integral v t + num (FloatingNumType v) (FloatingNumType t) = floating v t + num _ _ = internalError "unexpected vector encoding" -vecRtuple :: VecR n s tuple -> TypeR tuple -vecRtuple = snd . go - where - go :: VecR n s tuple -> (SingleType s, TypeR tuple) - go (VecRnil tp) = (tp, TupRunit) - go (VecRsucc vec) | (tp, tuple) <- go vec = (tp, TupRpair tuple (TupRsingle (SingleScalarType tp))) - -pack :: forall n single tuple. KnownNat n => VecR n single tuple -> tuple -> Vec n single -pack vecR tuple - | VectorType n single <- vecRvector vecR - , SingleDict <- singleDict single - = runST $ do - mba <- newByteArray (n * sizeOf (undefined :: single)) - go (n - 1) vecR tuple mba - ByteArray ba# <- unsafeFreezeByteArray mba - return $! Vec ba# - where - go :: Prim single => Int -> VecR n' single tuple' -> tuple' -> MutableByteArray s -> ST s () - go _ (VecRnil _) () _ = return () - go i (VecRsucc r) (xs, x) mba = do - writeByteArray mba i x - go (i - 1) r xs mba - -unpack :: forall n single tuple. KnownNat n => VecR n single tuple -> Vec n single -> tuple -unpack vecR (Vec ba#) - | VectorType n single <- vecRvector vecR - , (I# n#) <- n - , SingleDict <- singleDict single - = go (n# -# 1#) vecR + integral :: IntegralType v -> IntegralType t -> v -> [t] + integral (VectorIntegralType _ TypeInt8) (SingleIntegralType TypeInt8) = GHC.toList + integral (VectorIntegralType _ TypeInt16) (SingleIntegralType TypeInt16) = GHC.toList + integral (VectorIntegralType _ TypeInt32) (SingleIntegralType TypeInt32) = GHC.toList + integral (VectorIntegralType _ TypeInt64) (SingleIntegralType TypeInt64) = GHC.toList + integral (VectorIntegralType _ TypeInt128) (SingleIntegralType TypeInt128) = GHC.toList + integral (VectorIntegralType _ TypeWord8) (SingleIntegralType TypeWord8) = GHC.toList + integral (VectorIntegralType _ TypeWord16) (SingleIntegralType TypeWord16) = GHC.toList + integral (VectorIntegralType _ TypeWord32) (SingleIntegralType TypeWord32) = GHC.toList + integral (VectorIntegralType _ TypeWord64) (SingleIntegralType TypeWord64) = GHC.toList + integral (VectorIntegralType _ TypeWord128) (SingleIntegralType TypeWord128) = GHC.toList + integral _ _ = internalError "unexpected vector encoding" + + floating :: FloatingType v -> FloatingType t -> v -> [t] + floating (VectorFloatingType _ TypeFloat16) (SingleFloatingType TypeFloat16) = GHC.toList + floating (VectorFloatingType _ TypeFloat32) (SingleFloatingType TypeFloat32) = GHC.toList + floating (VectorFloatingType _ TypeFloat64) (SingleFloatingType TypeFloat64) = GHC.toList + floating (VectorFloatingType _ TypeFloat128) (SingleFloatingType TypeFloat128) = GHC.toList + floating _ _ = internalError "unexpected vector encoding" + + +fromList :: TypeR v -> TypeR a -> [a] -> v +fromList = go where - go :: Prim single => Int# -> VecR n' single tuple' -> tuple' - go _ (VecRnil _) = () - go i# (VecRsucc r) = x `seq` xs `seq` (xs, x) - where - xs = go (i# -# 1#) r - x = indexByteArray# ba# i# - -rnfVecR :: VecR n single tuple -> () -rnfVecR (VecRnil tp) = rnfSingleType tp -rnfVecR (VecRsucc vec) = rnfVecR vec - -liftVecR :: VecR n single tuple -> CodeQ (VecR n single tuple) -liftVecR (VecRnil tp) = [|| VecRnil $$(liftSingleType tp) ||] -liftVecR (VecRsucc vec) = [|| VecRsucc $$(liftVecR vec) ||] + go :: TypeR v -> TypeR t -> [t] -> v + go TupRunit TupRunit _ = () + go (TupRpair va vb) (TupRpair ta tb) xs = let (as, bs) = unzip xs in (go va ta as, go vb tb bs) + go (TupRsingle v) (TupRsingle t) xs = scalar v t xs + go _ _ _ = error "unexpected vector encoding" + + scalar :: ScalarType v -> ScalarType t -> [t] -> v + scalar (NumScalarType v) (NumScalarType t) = num v t + scalar (BitScalarType v) (BitScalarType t) = bit v t + scalar _ _ = internalError "unexpected vector encoding" + + bit :: BitType v -> BitType t -> [t] -> v + bit (TypeMask _) TypeBit = Prim.unMask . GHC.fromList + bit _ _ = internalError "unexpected vector encoding" + + num :: NumType v -> NumType t -> [t] -> v + num (IntegralNumType v) (IntegralNumType t) = integral v t + num (FloatingNumType v) (FloatingNumType t) = floating v t + num _ _ = internalError "unexpected vector encoding" + + integral :: IntegralType v -> IntegralType t -> [t] -> v + integral (VectorIntegralType _ TypeInt8) (SingleIntegralType TypeInt8) = GHC.fromList + integral (VectorIntegralType _ TypeInt16) (SingleIntegralType TypeInt16) = GHC.fromList + integral (VectorIntegralType _ TypeInt32) (SingleIntegralType TypeInt32) = GHC.fromList + integral (VectorIntegralType _ TypeInt64) (SingleIntegralType TypeInt64) = GHC.fromList + integral (VectorIntegralType _ TypeInt128) (SingleIntegralType TypeInt128) = GHC.fromList + integral (VectorIntegralType _ TypeWord8) (SingleIntegralType TypeWord8) = GHC.fromList + integral (VectorIntegralType _ TypeWord16) (SingleIntegralType TypeWord16) = GHC.fromList + integral (VectorIntegralType _ TypeWord32) (SingleIntegralType TypeWord32) = GHC.fromList + integral (VectorIntegralType _ TypeWord64) (SingleIntegralType TypeWord64) = GHC.fromList + integral (VectorIntegralType _ TypeWord128) (SingleIntegralType TypeWord128) = GHC.fromList + integral _ _ = internalError "unexpected vector encoding" + + floating :: FloatingType v -> FloatingType t -> [t] -> v + floating (VectorFloatingType _ TypeFloat16) (SingleFloatingType TypeFloat16) = GHC.fromList + floating (VectorFloatingType _ TypeFloat32) (SingleFloatingType TypeFloat32) = GHC.fromList + floating (VectorFloatingType _ TypeFloat64) (SingleFloatingType TypeFloat64) = GHC.fromList + floating (VectorFloatingType _ TypeFloat128) (SingleFloatingType TypeFloat128) = GHC.fromList + floating _ _ = internalError "unexpected vector encoding" diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index ccb38e7ab..54419f438 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -1,10 +1,11 @@ {-# LANGUAGE AllowAmbiguousTypes #-} - {-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MagicHash #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -12,7 +13,6 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE PolyKinds #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Smart @@ -52,34 +52,52 @@ module Data.Array.Accelerate.Smart ( -- ** Smart destructors for shapes indexHead, indexTail, - -- ** Smart constructors for constants - mkMinBound, mkMaxBound, mkPi, + -- ** Vector operations + mkPack, mkUnpack, + extract, mkExtract, + insert, mkInsert, + shuffle, + select, + + -- ** Smart constructors for primitive functions + -- *** Operators from Num + mkAdd, mkSub, mkMul, mkNeg, mkAbs, mkSig, + + -- *** Operators from Integral + mkQuot, mkRem, mkQuotRem, mkIDiv, mkMod, mkDivMod, + + -- *** Operators from FiniteBits + mkBAnd, mkBOr, mkBXor, mkBNot, + mkBShiftL, mkBShiftR, mkBRotateL, mkBRotateR, + mkPopCount, mkCountLeadingZeros, mkCountTrailingZeros, + + -- *** Operators from Fractional and Floating + mkFDiv, mkRecip, mkSin, mkCos, mkTan, mkAsin, mkAcos, mkAtan, mkSinh, mkCosh, mkTanh, mkAsinh, mkAcosh, mkAtanh, - mkExpFloating, mkSqrt, mkLog, + mkExpFloating, + mkSqrt, mkLog, mkFPow, mkLogBase, + + -- *** Operators from RealFrac and RealFloat mkTruncate, mkRound, mkFloor, mkCeiling, mkAtan2, + mkIsNaN, mkIsInfinite, - -- ** Smart constructors for primitive functions - mkAdd, mkSub, mkMul, mkNeg, mkAbs, mkSig, mkQuot, mkRem, mkQuotRem, mkIDiv, mkMod, mkDivMod, - mkBAnd, mkBOr, mkBXor, mkBNot, mkBShiftL, mkBShiftR, mkBRotateL, mkBRotateR, mkPopCount, mkCountLeadingZeros, mkCountTrailingZeros, - mkFDiv, mkRecip, mkLt, mkGt, mkLtEq, mkGtEq, mkEq, mkNEq, mkMax, mkMin, - mkLAnd, mkLOr, mkLNot, mkIsNaN, mkIsInfinite, + -- *** Relational and equality operators + mkLt, mkGt, mkLtEq, mkGtEq, mkEq, mkNEq, mkMax, mkMin, mkLAnd, mkLOr, mkLNot, -- ** Smart constructors for type coercion functions - mkFromIntegral, mkToFloating, mkBitcast, mkCoerce, Coerce(..), - - -- ** Smart constructors for vector operations - mkVectorIndex, - mkVectorWrite, + mkFromIntegral, mkToFloating, mkToBool, mkFromBool, mkBitcast, mkCoerce, Coerce(..), -- ** Auxiliary functions ($$), ($$$), ($$$$), ($$$$$), ApplyAcc(..), - unAcc, unAccFunction, mkExp, unExp, unExpFunction, unExpBinaryFunction, unPair, mkPairToTuple, + unAcc, unAccFunction, + mkExp, unExp, unExpFunction, unExpBinaryFunction, mkPrimUnary, mkPrimBinary, + unPair, mkPairToTuple, -- ** Miscellaneous formatPreAccOp, @@ -88,8 +106,9 @@ module Data.Array.Accelerate.Smart ( ) where -import Data.Proxy +import Data.Array.Accelerate.AST ( Direction(..), Message(..), PrimBool, PrimMaybe, PrimFun(..), BitOrMask, primFunType ) import Data.Array.Accelerate.AST.Idx +import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Error import Data.Array.Accelerate.Representation.Array import Data.Array.Accelerate.Representation.Elt @@ -98,28 +117,25 @@ import Data.Array.Accelerate.Representation.Slice import Data.Array.Accelerate.Representation.Stencil hiding ( StencilR, stencilR ) import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type -import Data.Array.Accelerate.Representation.Vec import Data.Array.Accelerate.Sugar.Array ( Arrays ) import Data.Array.Accelerate.Sugar.Elt -import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Sugar.Foreign import Data.Array.Accelerate.Sugar.Shape ( (:.)(..) ) +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type import qualified Data.Array.Accelerate.Representation.Stencil as R import qualified Data.Array.Accelerate.Sugar.Array as Sugar import qualified Data.Array.Accelerate.Sugar.Shape as Sugar -import Data.Array.Accelerate.AST ( Direction(..), Message(..) - , PrimBool, PrimMaybe - , PrimFun(..), primFunType - , PrimConst(..), primConstType ) -import Data.Primitive.Vec +import qualified Data.Primitive.Vec as Prim import Data.Kind import Data.Text.Lazy.Builder import Formatting +import GHC.Prim import GHC.TypeLits +import GHC.TypeLits.Extra -- Array computations @@ -413,30 +429,30 @@ data PreSmartAcc acc exp as where Fold :: TypeR e -> (SmartExp e -> SmartExp e -> exp e) -> Maybe (exp e) - -> acc (Array (sh, Int) e) + -> acc (Array (sh, INT) e) -> PreSmartAcc acc exp (Array sh e) - FoldSeg :: IntegralType i + FoldSeg :: SingleIntegralType i -> TypeR e -> (SmartExp e -> SmartExp e -> exp e) -> Maybe (exp e) - -> acc (Array (sh, Int) e) + -> acc (Array (sh, INT) e) -> acc (Segments i) - -> PreSmartAcc acc exp (Array (sh, Int) e) + -> PreSmartAcc acc exp (Array (sh, INT) e) Scan :: Direction -> TypeR e -> (SmartExp e -> SmartExp e -> exp e) -> Maybe (exp e) - -> acc (Array (sh, Int) e) - -> PreSmartAcc acc exp (Array (sh, Int) e) + -> acc (Array (sh, INT) e) + -> PreSmartAcc acc exp (Array (sh, INT) e) Scan' :: Direction -> TypeR e -> (SmartExp e -> SmartExp e -> exp e) -> exp e - -> acc (Array (sh, Int) e) - -> PreSmartAcc acc exp (Array (sh, Int) e, Array sh e) + -> acc (Array (sh, INT) e) + -> PreSmartAcc acc exp (Array (sh, INT) e, Array sh e) Permute :: ArrayR (Array sh e) -> (SmartExp e -> SmartExp e -> exp e) @@ -517,39 +533,39 @@ data PreSmartExp acc exp t where -> exp (t1, t2) -> PreSmartExp acc exp t - VecPack :: KnownNat n - => VecR n s tup - -> exp tup - -> PreSmartExp acc exp (Vec n s) - - VecUnpack :: KnownNat n - => VecR n s tup - -> exp (Vec n s) - -> PreSmartExp acc exp tup - - VecIndex :: (KnownNat n, v ~ Vec n s) - => VectorType v - -> IntegralType i - -> exp (Vec n s) + Extract :: ScalarType (Prim.Vec n a) + -> SingleIntegralType i + -> exp (Prim.Vec n a) -> exp i - -> PreSmartExp acc exp s + -> PreSmartExp acc exp a - VecWrite :: (KnownNat n, v ~ Vec n s) - => VectorType v - -> IntegralType i - -> exp (Vec n s) + Insert :: ScalarType (Prim.Vec n a) + -> SingleIntegralType i + -> exp (Prim.Vec n a) -> exp i - -> exp s - -> PreSmartExp acc exp (Vec n s) + -> exp a + -> PreSmartExp acc exp (Prim.Vec n a) + + Shuffle :: ScalarType (Prim.Vec m a) + -> SingleIntegralType i + -> exp (Prim.Vec n a) + -> exp (Prim.Vec n a) + -> exp (Prim.Vec m i) + -> PreSmartExp acc exp (Prim.Vec m a) + + Select :: exp (Prim.Vec n Bit) + -> exp (Prim.Vec n a) + -> exp (Prim.Vec n a) + -> PreSmartExp acc exp (Prim.Vec n a) ToIndex :: ShapeR sh -> exp sh -> exp sh - -> PreSmartExp acc exp Int + -> PreSmartExp acc exp INT FromIndex :: ShapeR sh -> exp sh - -> exp Int + -> exp INT -> PreSmartExp acc exp sh Case :: exp a @@ -567,9 +583,6 @@ data PreSmartExp acc exp t where -> exp t -> PreSmartExp acc exp t - PrimConst :: PrimConst t - -> PreSmartExp acc exp t - PrimApp :: PrimFun (a -> r) -> exp a -> PreSmartExp acc exp r @@ -581,7 +594,7 @@ data PreSmartExp acc exp t where LinearIndex :: TypeR t -> acc (Array sh t) - -> exp Int + -> exp INT -> PreSmartExp acc exp t Shape :: ShapeR sh @@ -590,7 +603,7 @@ data PreSmartExp acc exp t where ShapeSize :: ShapeR sh -> exp sh - -> PreSmartExp acc exp Int + -> PreSmartExp acc exp INT Foreign :: Foreign asm => TypeR y @@ -830,29 +843,29 @@ instance HasArraysR acc => HasArraysR (PreSmartAcc acc exp) where Aprj _ _ -> error "Ejector seat? You're joking!" Atrace _ _ a -> arraysR a Use repr _ -> TupRsingle repr - Unit tp _ -> TupRsingle $ ArrayR ShapeRz $ tp + Unit aR _ -> TupRsingle $ ArrayR ShapeRz $ aR Generate repr _ _ -> TupRsingle repr - Reshape shr _ a -> let ArrayR _ tp = arrayR a - in TupRsingle $ ArrayR shr tp - Replicate si _ a -> let ArrayR _ tp = arrayR a - in TupRsingle $ ArrayR (sliceDomainR si) tp - Slice si a _ -> let ArrayR _ tp = arrayR a - in TupRsingle $ ArrayR (sliceShapeR si) tp - Map _ tp _ a -> let ArrayR shr _ = arrayR a - in TupRsingle $ ArrayR shr tp - ZipWith _ _ tp _ a _ -> let ArrayR shr _ = arrayR a - in TupRsingle $ ArrayR shr tp - Fold _ _ _ a -> let ArrayR (ShapeRsnoc shr) tp = arrayR a - in TupRsingle (ArrayR shr tp) + Reshape shr _ a -> let ArrayR _ aR = arrayR a + in TupRsingle $ ArrayR shr aR + Replicate si _ a -> let ArrayR _ aR = arrayR a + in TupRsingle $ ArrayR (sliceDomainR si) aR + Slice si a _ -> let ArrayR _ aR = arrayR a + in TupRsingle $ ArrayR (sliceShapeR si) aR + Map _ aR _ a -> let ArrayR shr _ = arrayR a + in TupRsingle $ ArrayR shr aR + ZipWith _ _ aR _ a _ -> let ArrayR shr _ = arrayR a + in TupRsingle $ ArrayR shr aR + Fold _ _ _ a -> let ArrayR (ShapeRsnoc shr) aR = arrayR a + in TupRsingle (ArrayR shr aR) FoldSeg _ _ _ _ a _ -> arraysR a Scan _ _ _ _ a -> arraysR a - Scan' _ _ _ _ a -> let repr@(ArrayR (ShapeRsnoc shr) tp) = arrayR a - in TupRsingle repr `TupRpair` TupRsingle (ArrayR shr tp) + Scan' _ _ _ _ a -> let repr@(ArrayR (ShapeRsnoc shr) aR) = arrayR a + in TupRsingle repr `TupRpair` TupRsingle (ArrayR shr aR) Permute _ _ a _ _ -> arraysR a - Backpermute shr _ _ a -> let ArrayR _ tp = arrayR a - in TupRsingle (ArrayR shr tp) - Stencil s tp _ _ _ -> TupRsingle $ ArrayR (stencilShapeR s) tp - Stencil2 s _ tp _ _ _ _ _ -> TupRsingle $ ArrayR (stencilShapeR s) tp + Backpermute shr _ _ a -> let ArrayR _ aR = arrayR a + in TupRsingle (ArrayR shr aR) + Stencil s aR _ _ _ -> TupRsingle $ ArrayR (stencilShapeR s) aR + Stencil2 s _ aR _ _ _ _ _ -> TupRsingle $ ArrayR (stencilShapeR s) aR class HasTypeR f where @@ -863,9 +876,9 @@ instance HasTypeR SmartExp where instance HasTypeR exp => HasTypeR (PreSmartExp acc exp) where typeR = \case - Tag tp _ -> tp + Tag tR _ -> tR Match _ e -> typeR e - Const tp _ -> TupRsingle tp + Const tR _ -> TupRsingle tR Nil -> TupRunit Pair e1 e2 -> typeR e1 `TupRpair` typeR e2 Prj idx e @@ -873,25 +886,44 @@ instance HasTypeR exp => HasTypeR (PreSmartExp acc exp) where PairIdxLeft -> t1 PairIdxRight -> t2 Prj _ _ -> error "I never joke about my work" - VecPack vecR _ -> TupRsingle $ VectorScalarType $ vecRvector vecR - VecUnpack vecR _ -> vecRtuple vecR - VecIndex vecT _ _ _ -> let (VectorType _ s) = vecT in TupRsingle $ SingleScalarType s - VecWrite vecT _ _ _ _ -> TupRsingle $ VectorScalarType vecT - ToIndex _ _ _ -> TupRsingle scalarTypeInt + Extract vR _ _ _ -> TupRsingle (scalar vR) + where + scalar :: ScalarType (Prim.Vec n a) -> ScalarType a + scalar (NumScalarType t) = NumScalarType (num t) + scalar (BitScalarType t) = BitScalarType (bit t) + + bit :: BitType (Prim.Vec n a) -> BitType a + bit TypeMask{} = TypeBit + + num :: NumType (Prim.Vec n a) -> NumType a + num (IntegralNumType t) = IntegralNumType (integral t) + num (FloatingNumType t) = FloatingNumType (floating t) + + integral :: IntegralType (Prim.Vec n a) -> IntegralType a + integral (SingleIntegralType t) = case t of + integral (VectorIntegralType _ t) = SingleIntegralType t + + floating :: FloatingType (Prim.Vec n a) -> FloatingType a + floating (SingleFloatingType t) = case t of + floating (VectorFloatingType _ t) = SingleFloatingType t + -- + Insert t _ _ _ _ -> TupRsingle t + Shuffle t _ _ _ _ -> TupRsingle t + Select _ x _ -> typeR x + ToIndex _ _ _ -> TupRsingle (scalarType @INT) FromIndex shr _ _ -> shapeType shr Case _ ((_,c):_) -> typeR c Case{} -> internalError "encountered empty case" Cond _ e _ -> typeR e While t _ _ _ -> t - PrimConst c -> TupRsingle $ primConstType c PrimApp f _ -> snd $ primFunType f - Index tp _ _ -> tp - LinearIndex tp _ _ -> tp + Index tR _ _ -> tR + LinearIndex tR _ _ -> tR Shape shr _ -> shapeType shr - ShapeSize _ _ -> TupRsingle scalarTypeInt - Foreign tp _ _ _ -> tp - Undef tp -> TupRsingle tp - Coerce _ tp _ -> TupRsingle tp + ShapeSize _ _ -> TupRsingle (scalarType @INT) + Foreign tR _ _ _ -> tR + Undef tR -> TupRsingle tR + Coerce _ tR _ -> TupRsingle tR -- Smart constructors @@ -914,9 +946,10 @@ constant = Exp . go (eltR @e) . fromElt where go :: HasCallStack => TypeR t -> t -> SmartExp t go TupRunit () = SmartExp $ Nil - go (TupRsingle tp) c = SmartExp $ Const tp c + go (TupRsingle tR) c = SmartExp $ Const tR c go (TupRpair t1 t2) (c1, c2) = SmartExp $ go t1 c1 `Pair` go t2 c2 + -- | 'undef' can be used anywhere a constant is expected, and indicates that the -- consumer of the value can receive an unspecified bit pattern. -- @@ -948,6 +981,7 @@ undef = Exp $ go $ eltR @e go (TupRsingle t) = SmartExp $ Undef t go (TupRpair t1 t2) = SmartExp $ go t1 `Pair` go t2 + -- | Get the innermost dimension of a shape. -- -- The innermost dimension (right-most component of the shape) is the index of @@ -967,17 +1001,223 @@ indexTail :: (Elt sh, Elt a) => Exp (sh :. a) -> Exp sh indexTail (Exp x) = mkExp $ Prj PairIdxLeft x --- Smart constructor for constants --- - -mkMinBound :: (Elt t, IsBounded (EltR t)) => Exp t -mkMinBound = mkExp $ PrimConst (PrimMinBound boundedType) +mkUnpack :: forall n a. (SIMD n a, Elt a) => Exp (Vec n a) -> [Exp a] +mkUnpack v = + let n = fromIntegral (natVal' (proxy# :: Proxy# n)) :: Word8 + in map (extract v . constant) [0 .. n-1] -mkMaxBound :: (Elt t, IsBounded (EltR t)) => Exp t -mkMaxBound = mkExp $ PrimConst (PrimMaxBound boundedType) +mkPack :: forall n a. (SIMD n a, Elt a) => [Exp a] -> Exp (Vec n a) +mkPack xs = + let go :: Word8 -> [Exp a] -> Exp (Vec n a) -> Exp (Vec n a) + go _ [] vec = vec + go i (v:vs) vec = go (i+1) vs (insert vec (constant i) v) + in + go 0 xs undef -mkPi :: (Elt r, IsFloating (EltR r)) => Exp r -mkPi = mkExp $ PrimConst (PrimPi floatingType) +-- | Extract a single scalar element from the given SIMD vector at the +-- specified index +-- +-- @since 1.4.0.0 +-- +extract :: forall n a i. (Elt a, SIMD n a, IsSingleIntegral (EltR i)) + => Exp (Vec n a) + -> Exp i + -> Exp a +extract (Exp v) (Exp i) = Exp $ mkExtract (vecR @n @a) (eltR @a) singleIntegralType v i + +mkExtract :: TypeR v -> TypeR e -> SingleIntegralType i -> SmartExp v -> SmartExp i -> SmartExp e +mkExtract _vR _eR iR _v i = go _vR _eR _v + where + go :: TypeR v -> TypeR e -> SmartExp v -> SmartExp e + go TupRunit TupRunit _ = SmartExp Nil + go (TupRsingle vR) (TupRsingle eR) v = scalar vR eR v + go (TupRpair vR1 vR2) (TupRpair eR1 eR2) v = + let (v1, v2) = unPair v + in SmartExp $ go vR1 eR1 v1 `Pair` go vR2 eR2 v2 + go _ _ _ = error "impossible" + + scalar :: ScalarType v -> ScalarType e -> SmartExp v -> SmartExp e + scalar (NumScalarType vR) (NumScalarType eR) = num vR eR + scalar (BitScalarType vR) (BitScalarType eR) = bit vR eR + scalar _ _ = error "impossible" + + bit :: BitType v -> BitType e -> SmartExp v -> SmartExp e + bit TypeMask{} TypeBit v = SmartExp $ Extract scalarType iR v i + bit _ _ _ = error "impossible" + + num :: NumType v -> NumType e -> SmartExp v -> SmartExp e + num (IntegralNumType vR) (IntegralNumType eR) = integral vR eR + num (FloatingNumType vR) (FloatingNumType eR) = floating vR eR + num _ _ = error "impossible" + + integral :: IntegralType v -> IntegralType e -> SmartExp v -> SmartExp e + integral (VectorIntegralType n vR) (SingleIntegralType tR) v + | Just Refl <- matchSingleIntegralType vR tR + = SmartExp $ Extract (NumScalarType (IntegralNumType (VectorIntegralType n vR))) iR v i + integral _ _ _ = error "impossible" + + floating :: FloatingType v -> FloatingType e -> SmartExp v -> SmartExp e + floating (VectorFloatingType n vR) (SingleFloatingType tR) v + | Just Refl <- matchSingleFloatingType vR tR + = SmartExp $ Extract (NumScalarType (FloatingNumType (VectorFloatingType n vR))) iR v i + floating _ _ _ = error "impossible" + + +-- | Insert a scalar element into the given SIMD vector at the specified +-- index +-- +-- @since 1.4.0.0 +-- +insert :: forall n a i. (Elt a, SIMD n a, IsSingleIntegral (EltR i)) + => Exp (Vec n a) + -> Exp i + -> Exp a + -> Exp (Vec n a) +insert (Exp v) (Exp i) (Exp x) = Exp $ mkInsert (vecR @n @a) (eltR @a) singleIntegralType v i x + +mkInsert :: TypeR v -> TypeR e -> SingleIntegralType i -> SmartExp v -> SmartExp i -> SmartExp e -> SmartExp v +mkInsert _vR _eR iR _v i _x = go _vR _eR _v _x + where + go :: TypeR v -> TypeR e -> SmartExp v -> SmartExp e -> SmartExp v + go TupRunit TupRunit _ _ = SmartExp Nil + go (TupRsingle vR) (TupRsingle eR) v e = scalar vR eR v e + go (TupRpair vR1 vR2) (TupRpair eR1 eR2) v e = + let (v1, v2) = unPair v + (e1, e2) = unPair e + in SmartExp $ go vR1 eR1 v1 e1 `Pair` go vR2 eR2 v2 e2 + go _ _ _ _ = error "impossible" + + scalar :: ScalarType v -> ScalarType e -> SmartExp v -> SmartExp e -> SmartExp v + scalar (NumScalarType vR) (NumScalarType eR) = num vR eR + scalar (BitScalarType vR) (BitScalarType eR) = bit vR eR + scalar _ _ = error "impossible" + + bit :: BitType v -> BitType e -> SmartExp v -> SmartExp e -> SmartExp v + bit TypeMask{} TypeBit v = SmartExp . Insert scalarType iR v i + bit _ _ _ = error "impossible" + + num :: NumType v -> NumType e -> SmartExp v -> SmartExp e -> SmartExp v + num (IntegralNumType vR) (IntegralNumType eR) = integral vR eR + num (FloatingNumType vR) (FloatingNumType eR) = floating vR eR + num _ _ = error "impossible" + + integral :: IntegralType v -> IntegralType e -> SmartExp v -> SmartExp e -> SmartExp v + integral (VectorIntegralType n vR) (SingleIntegralType tR) v + | Just Refl <- matchSingleIntegralType vR tR + = SmartExp . Insert (NumScalarType (IntegralNumType (VectorIntegralType n vR))) iR v i + integral _ _ _ = error "impossible" + + floating :: FloatingType v -> FloatingType e -> SmartExp v -> SmartExp e -> SmartExp v + floating (VectorFloatingType n vR) (SingleFloatingType tR) v + | Just Refl <- matchSingleFloatingType vR tR + = SmartExp . Insert (NumScalarType (FloatingNumType (VectorFloatingType n vR))) iR v i + floating _ _ _ = error "impossible" + + +-- | Construct a permutation of elements from two input vectors +-- +-- The elements of the two input vectors are concatenated and numbered from +-- zero left-to-right. For each element in the result vector, the shuffle +-- mask selects the corresponding element from this concatenated vector to +-- copy to the result. +-- +-- @since 1.4.0.0 +-- +shuffle :: forall m n a i. (SIMD n a, SIMD m a, SIMD m i, IsSingleIntegral (EltR i)) + => Exp (Vec n a) + -> Exp (Vec n a) + -> Exp (Vec m i) + -> Exp (Vec m a) +shuffle (Exp xs) (Exp ys) (Exp i) = Exp $ go (vecR @n @a) (vecR @m @a) xs ys + where + go :: TypeR t -> TypeR r -> SmartExp t -> SmartExp t -> SmartExp r + go TupRunit TupRunit _ _ = SmartExp Nil + go (TupRsingle vR) (TupRsingle rR) x y = scalar vR rR x y + go (TupRpair vR1 vR2) (TupRpair rR1 rR2) x y = + let (x1, x2) = unPair x + (y1, y2) = unPair y + in SmartExp $ go vR1 rR1 x1 y1 `Pair` go vR2 rR2 x2 y2 + go _ _ _ _ = error "impossible" + + scalar :: ScalarType t -> ScalarType r -> SmartExp t -> SmartExp t -> SmartExp r + scalar (NumScalarType vR) (NumScalarType rR) = num vR rR + scalar (BitScalarType vR) (BitScalarType rR) = bit vR rR + scalar _ _ = error "impossible" + + bit :: BitType t -> BitType r -> SmartExp t -> SmartExp t -> SmartExp r + bit (TypeMask _) (TypeMask m) x y + | TupRsingle (NumScalarType (IntegralNumType (VectorIntegralType m' iR))) <- vecR @m @i + , Just Refl <- sameNat' m m' + = SmartExp $ Shuffle (BitScalarType (TypeMask m)) iR x y i + bit _ _ _ _ = error "impossible" + + num :: NumType t -> NumType r -> SmartExp t -> SmartExp t -> SmartExp r + num (IntegralNumType vR) (IntegralNumType rR) = integral vR rR + num (FloatingNumType vR) (FloatingNumType rR) = floating vR rR + num _ _ = error "impossible" + + integral :: IntegralType t -> IntegralType r -> SmartExp t -> SmartExp t -> SmartExp r + integral (VectorIntegralType _ vR) (VectorIntegralType m rR) x y + | TupRsingle (NumScalarType (IntegralNumType (VectorIntegralType m' iR))) <- vecR @m @i + , Just Refl <- matchSingleIntegralType vR rR + , Just Refl <- sameNat' m m' + = SmartExp $ Shuffle (NumScalarType (IntegralNumType (VectorIntegralType m vR))) iR x y i + integral _ _ _ _ = error "impossible" + + floating :: FloatingType t -> FloatingType r -> SmartExp t -> SmartExp t -> SmartExp r + floating (VectorFloatingType _ vR) (VectorFloatingType m rR) x y + | TupRsingle (NumScalarType (IntegralNumType (VectorIntegralType m' iR))) <- vecR @m @i + , Just Refl <- matchSingleFloatingType vR rR + , Just Refl <- sameNat' m m' + = SmartExp $ Shuffle (NumScalarType (FloatingNumType (VectorFloatingType m vR))) iR x y i + floating _ _ _ _ = error "impossible" + + +-- | Choose one value based on a condition, without branching. This is strict in +-- both arguments. +-- +-- @since 1.4.0.0 +-- +select :: forall n a. SIMD n a + => Exp (Vec n Bool) + -> Exp (Vec n a) + -> Exp (Vec n a) + -> Exp (Vec n a) +select (Exp mask) (Exp tt) (Exp ff) = Exp $ go (vecR @n @a) tt ff + where + go :: TypeR t -> SmartExp t -> SmartExp t -> SmartExp t + go TupRunit _ _ = SmartExp Nil + go (TupRsingle vR) t f = scalar vR t f + go (TupRpair vR1 vR2) t f = + let (t1, t2) = unPair t + (f1, f2) = unPair f + in SmartExp $ go vR1 t1 f1 `Pair` go vR2 t2 f2 + + scalar :: ScalarType t -> SmartExp t -> SmartExp t -> SmartExp t + scalar (NumScalarType vR) = num vR + scalar (BitScalarType vR) = bit vR + + bit :: BitType t -> SmartExp t -> SmartExp t -> SmartExp t + bit (TypeMask n) + | Just Refl <- sameNat' n (proxy# :: Proxy# n) + = SmartExp $$ Select mask + bit _ = error "impossible" + + num :: NumType t -> SmartExp t -> SmartExp t -> SmartExp t + num (IntegralNumType vR) = integral vR + num (FloatingNumType vR) = floating vR + + integral :: IntegralType t -> SmartExp t -> SmartExp t -> SmartExp t + integral (VectorIntegralType n _) + | Just Refl <- sameNat' n (proxy# :: Proxy# n) + = SmartExp $$ Select mask + integral _ = error "impossible" + + floating :: FloatingType t -> SmartExp t -> SmartExp t -> SmartExp t + floating (VectorFloatingType n _) + | Just Refl <- sameNat' n (proxy# :: Proxy# n) + = SmartExp $$ Select mask + floating _ = error "impossible" -- Smart constructors for primitive applications @@ -1067,7 +1307,8 @@ mkRem = mkPrimBinary $ PrimRem integralType mkQuotRem :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> (Exp t, Exp t) mkQuotRem (Exp x) (Exp y) = let pair = SmartExp $ PrimQuotRem integralType `PrimApp` SmartExp (Pair x y) - in (mkExp $ Prj PairIdxLeft pair, mkExp $ Prj PairIdxRight pair) + in ( mkExp $ Prj PairIdxLeft pair + , mkExp $ Prj PairIdxRight pair) mkIDiv :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t mkIDiv = mkPrimBinary $ PrimIDiv integralType @@ -1078,7 +1319,8 @@ mkMod = mkPrimBinary $ PrimMod integralType mkDivMod :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> (Exp t, Exp t) mkDivMod (Exp x) (Exp y) = let pair = SmartExp $ PrimDivMod integralType `PrimApp` SmartExp (Pair x y) - in (mkExp $ Prj PairIdxLeft pair, mkExp $ Prj PairIdxRight pair) + in ( mkExp $ Prj PairIdxLeft pair + , mkExp $ Prj PairIdxRight pair) -- Operators from Bits and FiniteBits @@ -1094,28 +1336,27 @@ mkBXor = mkPrimBinary $ PrimBXor integralType mkBNot :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t mkBNot = mkPrimUnary $ PrimBNot integralType -mkBShiftL :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t +mkBShiftL :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t mkBShiftL = mkPrimBinary $ PrimBShiftL integralType -mkBShiftR :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t +mkBShiftR :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t mkBShiftR = mkPrimBinary $ PrimBShiftR integralType -mkBRotateL :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t +mkBRotateL :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t mkBRotateL = mkPrimBinary $ PrimBRotateL integralType -mkBRotateR :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t +mkBRotateR :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t mkBRotateR = mkPrimBinary $ PrimBRotateR integralType -mkPopCount :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int +mkPopCount :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t mkPopCount = mkPrimUnary $ PrimPopCount integralType -mkCountLeadingZeros :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int +mkCountLeadingZeros :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t mkCountLeadingZeros = mkPrimUnary $ PrimCountLeadingZeros integralType -mkCountTrailingZeros :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int +mkCountTrailingZeros :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t mkCountTrailingZeros = mkPrimUnary $ PrimCountTrailingZeros integralType - -- Operators from Fractional mkFDiv :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -> Exp t @@ -1143,68 +1384,48 @@ mkCeiling = mkPrimUnary $ PrimCeiling floatingType integralType mkAtan2 :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -> Exp t mkAtan2 = mkPrimBinary $ PrimAtan2 floatingType -mkIsNaN :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp Bool -mkIsNaN = mkPrimUnaryBool $ PrimIsNaN floatingType - -mkIsInfinite :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp Bool -mkIsInfinite = mkPrimUnaryBool $ PrimIsInfinite floatingType +mkIsNaN :: (Elt t, IsFloating (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp Bool +mkIsNaN = mkPrimUnary $ PrimIsNaN floatingType --- FIXME: add missing operations from Floating, RealFrac & RealFloat +mkIsInfinite :: (Elt t, IsFloating (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp Bool +mkIsInfinite = mkPrimUnary $ PrimIsInfinite floatingType -- Relational and equality operators -mkLt :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp Bool -mkLt = mkPrimBinaryBool $ PrimLt singleType +mkLt :: (Elt t, IsScalar (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp t -> Exp Bool +mkLt = mkPrimBinary $ PrimLt scalarType -mkGt :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp Bool -mkGt = mkPrimBinaryBool $ PrimGt singleType +mkGt :: (Elt t, IsScalar (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp t -> Exp Bool +mkGt = mkPrimBinary $ PrimGt scalarType -mkLtEq :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp Bool -mkLtEq = mkPrimBinaryBool $ PrimLtEq singleType +mkLtEq :: (Elt t, IsScalar (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp t -> Exp Bool +mkLtEq = mkPrimBinary $ PrimLtEq scalarType -mkGtEq :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp Bool -mkGtEq = mkPrimBinaryBool $ PrimGtEq singleType +mkGtEq :: (Elt t, IsScalar (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp t -> Exp Bool +mkGtEq = mkPrimBinary $ PrimGtEq scalarType -mkEq :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp Bool -mkEq = mkPrimBinaryBool $ PrimEq singleType +mkEq :: (Elt t, IsScalar (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp t -> Exp Bool +mkEq = mkPrimBinary $ PrimEq scalarType -mkNEq :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp Bool -mkNEq = mkPrimBinaryBool $ PrimNEq singleType +mkNEq :: (Elt t, IsScalar (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp t -> Exp Bool +mkNEq = mkPrimBinary $ PrimNEq scalarType -mkMax :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp t -mkMax = mkPrimBinary $ PrimMax singleType +mkMax :: (Elt t, IsScalar (EltR t)) => Exp t -> Exp t -> Exp t +mkMax = mkPrimBinary $ PrimMax scalarType -mkMin :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp t -mkMin = mkPrimBinary $ PrimMin singleType +mkMin :: (Elt t, IsScalar (EltR t)) => Exp t -> Exp t -> Exp t +mkMin = mkPrimBinary $ PrimMin scalarType -- Logical operators -mkLAnd :: Exp Bool -> Exp Bool -> Exp Bool -mkLAnd (Exp a) (Exp b) = mkExp $ SmartExp (PrimApp PrimLAnd (SmartExp $ Pair x y)) `Pair` SmartExp Nil - where - x = SmartExp $ Prj PairIdxLeft a - y = SmartExp $ Prj PairIdxLeft b +mkLAnd :: (Elt t, IsBit (EltR t)) => Exp t -> Exp t -> Exp t +mkLAnd = mkPrimBinary $ PrimLAnd bitType -mkLOr :: Exp Bool -> Exp Bool -> Exp Bool -mkLOr (Exp a) (Exp b) = mkExp $ SmartExp (PrimApp PrimLOr (SmartExp $ Pair x y)) `Pair` SmartExp Nil - where - x = SmartExp $ Prj PairIdxLeft a - y = SmartExp $ Prj PairIdxLeft b - -mkLNot :: Exp Bool -> Exp Bool -mkLNot (Exp a) = mkExp $ SmartExp (PrimApp PrimLNot x) `Pair` SmartExp Nil - where - x = SmartExp $ Prj PairIdxLeft a - - -inferNat :: forall n. KnownNat n => Int -inferNat = fromInteger $ natVal (Proxy @n) - -mkVectorIndex :: forall n a. (KnownNat n, Elt a, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a -mkVectorIndex (Exp v) (Exp i) = mkExp $ VecIndex (VectorType (inferNat @n) singleType) integralType v i +mkLOr :: (Elt t, IsBit (EltR t)) => Exp t -> Exp t -> Exp t +mkLOr = mkPrimBinary $ PrimLOr bitType -mkVectorWrite :: forall n a. (KnownNat n, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a -> Exp (Vec n a) -mkVectorWrite (Exp v) (Exp i) (Exp el) = mkExp $ VecWrite (VectorType (inferNat @n) singleType) integralType v i el +mkLNot :: (Elt t, IsBit (EltR t)) => Exp t -> Exp t +mkLNot = mkPrimUnary $ PrimLNot bitType -- Numeric conversions @@ -1214,6 +1435,12 @@ mkFromIntegral = mkPrimUnary $ PrimFromIntegral integralType numType mkToFloating :: (Elt a, Elt b, IsNum (EltR a), IsFloating (EltR b)) => Exp a -> Exp b mkToFloating = mkPrimUnary $ PrimToFloating numType floatingType +mkToBool :: (Elt a, IsSingleIntegral (EltR a), BitOrMask (EltR a) ~ Bit) => Exp a -> Exp Bool +mkToBool = mkPrimUnary $ PrimToBool (SingleIntegralType singleIntegralType) bitType + +mkFromBool :: (Elt a, IsSingleIntegral (EltR a), BitOrMask (EltR a) ~ Bit) => Exp Bool -> Exp a +mkFromBool = mkPrimUnary $ PrimFromBool bitType (SingleIntegralType singleIntegralType) + -- Other conversions -- NOTE: Restricted to scalar types with a type-level BitSizeEq constraint to @@ -1249,7 +1476,6 @@ instance Coerce a (a, ()) where mkCoerce' a = SmartExp (Pair a (SmartExp Nil)) - -- Auxiliary functions -- -------------------- @@ -1293,15 +1519,6 @@ mkPrimUnary prim (Exp a) = mkExp $ PrimApp prim a mkPrimBinary :: (Elt a, Elt b, Elt c) => PrimFun ((EltR a, EltR b) -> EltR c) -> Exp a -> Exp b -> Exp c mkPrimBinary prim (Exp a) (Exp b) = mkExp $ PrimApp prim (SmartExp $ Pair a b) -mkPrimTernary :: (Elt a, Elt b, Elt c, Elt d) => PrimFun ((EltR a, (EltR b, EltR c)) -> EltR d) -> Exp a -> Exp b -> Exp c -> Exp d -mkPrimTernary prim (Exp a) (Exp b) (Exp c) = mkExp $ PrimApp prim (SmartExp $ Pair a (SmartExp (Pair b c))) - -mkPrimUnaryBool :: Elt a => PrimFun (EltR a -> PrimBool) -> Exp a -> Exp Bool -mkPrimUnaryBool = mkCoerce @PrimBool $$ mkPrimUnary - -mkPrimBinaryBool :: (Elt a, Elt b) => PrimFun ((EltR a, EltR b) -> PrimBool) -> Exp a -> Exp b -> Exp Bool -mkPrimBinaryBool = mkCoerce @PrimBool $$$ mkPrimBinary - unPair :: SmartExp (a, b) -> (SmartExp a, SmartExp b) unPair e = (SmartExp $ Prj PairIdxLeft e, SmartExp $ Prj PairIdxRight e) @@ -1386,14 +1603,15 @@ formatPreExpOp = later $ \case Nil{} -> "Nil" Pair{} -> "Pair" Prj{} -> "Prj" - VecPack{} -> "VecPack" - VecUnpack{} -> "VecUnpack" + Extract{} -> "Extract" + Insert{} -> "Insert" + Shuffle{} -> "Shuffle" + Select{} -> "Select" ToIndex{} -> "ToIndex" FromIndex{} -> "FromIndex" Case{} -> "Case" Cond{} -> "Cond" While{} -> "While" - PrimConst{} -> "PrimConst" PrimApp{} -> "PrimApp" Index{} -> "Index" LinearIndex{} -> "LinearIndex" diff --git a/src/Data/Array/Accelerate/Sugar/Array.hs b/src/Data/Array/Accelerate/Sugar/Array.hs index 87a490246..4e6d9fb21 100644 --- a/src/Data/Array/Accelerate/Sugar/Array.hs +++ b/src/Data/Array/Accelerate/Sugar/Array.hs @@ -153,7 +153,7 @@ infixl 9 ! -- infixl 9 !! (!!) :: forall sh e. Elt e => Array sh e -> Int -> e -(!!) (Array arr) i = toElt $ R.linearIndexArray (eltR @e) arr i +(!!) (Array arr) i = toElt $ R.linearIndexArray (eltR @e) arr (fromIntegral i) -- | Create an array from its representation function, applied at each -- index of the array @@ -174,8 +174,8 @@ fromFunctionM sh f = Array <$> R.fromFunctionM (arrayR @sh @e) (fromElt sh) f' -- | Create a vector from the concatenation of the given list of vectors -- -concatVectors :: forall e. Elt e => [Vector e] -> Vector e -concatVectors = toArr . R.concatVectors (eltR @e) . map fromArr +-- concatVectors :: forall e. Elt e => [Vector e] -> Vector e +-- concatVectors = toArr . R.concatVectors (eltR @e) . map fromArr -- | Creates a new, uninitialized Accelerate array -- diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index b55158900..441c45223 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -32,6 +32,7 @@ import Data.Array.Accelerate.Type import Data.Bits import Data.Char import Data.Kind +import Foreign.C.Types import Language.Haskell.TH.Extra hiding ( Type ) import GHC.Generics @@ -45,8 +46,8 @@ import GHC.Generics -- tuples thereof, stored efficiently in memory as consecutive unpacked -- elements without pointers. It roughly consists of: -- --- * Signed and unsigned integers (8, 16, 32, and 64-bits wide) --- * Floating point numbers (half, single, and double precision) +-- * Signed and unsigned integers (8, 16, 32, 64, and 128-bits wide) +-- * Floating point numbers (IEEE half, single, double, and (optionally) quadruple precision) -- * 'Char' -- * 'Bool' -- * () @@ -163,7 +164,7 @@ instance (GElt a, GElt b) => GElt (a :*: b) where instance (GElt a, GElt b, GSumElt (a :+: b)) => GElt (a :+: b) where type GEltR t (a :+: b) = (TAG, GSumEltR t (a :+: b)) geltR t = TupRpair (TupRsingle scalarType) (gsumEltR @(a :+: b) t) - gtagsR t = uncurry TagRtag <$> gsumTagsR @(a :+: b) 0 t + gtagsR t = uncurry (TagRtag singleIntegralType) <$> gsumTagsR @(a :+: b) 0 t gfromElt = gsumFromElt 0 gtoElt (k,x) = gsumToElt k x gundef t = (0xff, gsumUndef @(a :+: b) t) @@ -282,11 +283,31 @@ untag (TupRpair ta tb) = TagRpair (untag ta) (untag tb) -- instance Elt () -instance Elt Bool instance Elt Ordering instance Elt a => Elt (Maybe a) instance (Elt a, Elt b) => Elt (Either a b) +instance Elt Bool where + type EltR Bool = Bit + eltR = TupRsingle scalarType + tagsR = [TagRbit TypeBit 0, TagRbit TypeBit 1] + toElt = unBit + fromElt = Bit + +instance Elt Int where + type EltR Int = INT + eltR = TupRsingle scalarType + tagsR = [TagRsingle scalarType] + toElt = fromIntegral + fromElt = fromIntegral + +instance Elt Word where + type EltR Word = WORD + eltR = TupRsingle scalarType + tagsR = [TagRsingle scalarType] + toElt = fromIntegral + fromElt = fromIntegral + instance Elt Char where type EltR Char = Word32 eltR = TupRsingle scalarType @@ -300,16 +321,16 @@ runQ $ do -- integralTypes :: [Name] integralTypes = - [ ''Int - , ''Int8 + [ ''Int8 , ''Int16 , ''Int32 , ''Int64 - , ''Word + , ''Int128 , ''Word8 , ''Word16 , ''Word32 , ''Word64 + , ''Word128 ] floatingTypes :: [Name] @@ -317,6 +338,7 @@ runQ $ do [ ''Half , ''Float , ''Double + , ''Float128 ] newtypes :: [Name] @@ -358,18 +380,6 @@ runQ $ do in instanceD ctx [t| Elt $res |] [] - -- mkVecElt :: Name -> Integer -> Q [Dec] - -- mkVecElt name n = - -- let t = conT name - -- v = [t| Vec $(litT (numTyLit n)) $t |] - -- in - -- [d| instance Elt $v where - -- type EltR $v = $v - -- eltR = TupRsingle scalarType - -- fromElt = id - -- toElt = id - -- |] - -- ghci> $( stringE . show =<< reify ''CFloat ) -- TyConI (NewtypeD [] Foreign.C.Types.CFloat [] Nothing (NormalC Foreign.C.Types.CFloat [(Bang NoSourceUnpackedness NoSourceStrictness,ConT GHC.Types.Float)]) []) -- @@ -391,6 +401,5 @@ runQ $ do ss <- mapM mkSimple (integralTypes ++ floatingTypes) ns <- mapM mkNewtype newtypes ts <- mapM mkTuple [2..16] - -- vs <- sequence [ mkVecElt t n | t <- integralTypes ++ floatingTypes, n <- [2,3,4,8,16] ] return (concat ss ++ concat ns ++ ts) diff --git a/src/Data/Array/Accelerate/Sugar/Shape.hs b/src/Data/Array/Accelerate/Sugar/Shape.hs index 1ac8bd0c4..a57366075 100644 --- a/src/Data/Array/Accelerate/Sugar/Shape.hs +++ b/src/Data/Array/Accelerate/Sugar/Shape.hs @@ -136,12 +136,12 @@ data Divide sh = Divide -- | Number of dimensions of a /shape/ or /index/ (>= 0) -- rank :: forall sh. Shape sh => Int -rank = R.rank (shapeR @sh) +rank = fromIntegral $ R.rank (shapeR @sh) -- | Total number of elements in an array of the given /shape/ -- size :: forall sh. Shape sh => sh -> Int -size = R.size (shapeR @sh) . fromElt +size = fromIntegral . R.size (shapeR @sh) . fromElt -- | The empty /shape/ -- @@ -165,7 +165,7 @@ toIndex :: forall sh. Shape sh => sh -- ^ Total shape (extent) of the array -> sh -- ^ The argument index -> Int -- ^ Corresponding linear index -toIndex sh ix = R.toIndex (shapeR @sh) (fromElt sh) (fromElt ix) +toIndex sh ix = fromIntegral $ R.toIndex (shapeR @sh) (fromElt sh) (fromElt ix) -- | Inverse of 'toIndex'. -- @@ -173,7 +173,7 @@ fromIndex :: forall sh. Shape sh => sh -- ^ Total shape (extent) of the array -> Int -- ^ The argument index -> sh -- ^ Corresponding multi-dimensional index -fromIndex sh = toElt . R.fromIndex (shapeR @sh) (fromElt sh) +fromIndex sh = toElt . R.fromIndex (shapeR @sh) (fromElt sh) . fromIntegral -- | Iterate through all of the indices of a shape, applying the given -- function at each index. The index space is traversed in row-major order. @@ -210,19 +210,19 @@ shapeToRange ix = -- | Convert a shape to a list of dimensions -- shapeToList :: forall sh. Shape sh => sh -> [Int] -shapeToList = R.shapeToList (shapeR @sh) . fromElt +shapeToList = map fromIntegral . R.shapeToList (shapeR @sh) . fromElt -- | Convert a list of dimensions into a shape. If the list does not -- contain exactly the number of elements as specified by the type of the -- shape: error. -- listToShape :: forall sh. Shape sh => [Int] -> sh -listToShape = toElt . R.listToShape (shapeR @sh) +listToShape = toElt . R.listToShape (shapeR @sh) . map fromIntegral -- | Attempt to convert a list of dimensions into a shape -- listToShape' :: forall sh. Shape sh => [Int] -> Maybe sh -listToShape' = fmap toElt . R.listToShape' (shapeR @sh) +listToShape' = fmap toElt . R.listToShape' (shapeR @sh) . map fromIntegral -- | Nicely format a shape as a string -- diff --git a/src/Data/Array/Accelerate/Sugar/Vec.hs b/src/Data/Array/Accelerate/Sugar/Vec.hs index 723d32c7b..80a4cb64a 100644 --- a/src/Data/Array/Accelerate/Sugar/Vec.hs +++ b/src/Data/Array/Accelerate/Sugar/Vec.hs @@ -1,9 +1,21 @@ -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE MagicHash #-} -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_HADDOCK hide #-} -{-# OPTIONS_GHC -fno-warn-orphans #-} -- | -- Module : Data.Array.Accelerate.Sugar.Vec -- Copyright : [2008..2020] The Accelerate Team @@ -14,26 +26,374 @@ -- Portability : non-portable (GHC extensions) -- -module Data.Array.Accelerate.Sugar.Vec - where +module Data.Array.Accelerate.Sugar.Vec ( + + Vec(..), KnownNat, + Vec2, + Vec3, + Vec4, + Vec8, + Vec16, + SIMD(..), + +) where -import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Shape import Data.Array.Accelerate.Type -import Data.Primitive.Types -import Data.Primitive.Vec +import qualified Data.Array.Accelerate.Representation.Vec as R +import qualified Data.Primitive.Vec as Prim + +import Data.Kind +import Prettyprinter +import Language.Haskell.TH.Extra hiding ( Type ) import GHC.TypeLits -import GHC.Prim +import GHC.Generics +import qualified GHC.Exts as GHC + + +-- | SIMD vectors of fixed width +-- +data Vec n a = Vec (VecR n a) + +-- Synonyms for common vector sizes +-- +type Vec2 = Vec 2 +type Vec3 = Vec 3 +type Vec4 = Vec 4 +type Vec8 = Vec 8 +type Vec16 = Vec 16 + +instance (Show a, Elt a, SIMD n a) => Show (Vec n a) where + show = vec . toList + where + vec :: [a] -> String + vec = show + . group . encloseSep (flatAlt "< " "<") (flatAlt " >" ">") ", " + . map viaShow + +instance (Eq a, SIMD n a) => Eq (Vec n a) where + Vec x == Vec y = tuple (vecR @n @a) x y + where + tuple :: TypeR v -> v -> v -> Bool + tuple TupRunit () () = True + tuple (TupRpair aR bR) (a1,b1) (a2,b2) = tuple aR a1 a2 && tuple bR b1 b2 + tuple (TupRsingle t) a b = scalar t a b + + scalar :: ScalarType v -> v -> v -> Bool + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t + + bit :: BitType v -> v -> v -> Bool + bit TypeBit = (==) + bit TypeMask{} = (==) + + num :: NumType v -> v -> v -> Bool + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType v -> v -> v -> Bool + integral (SingleIntegralType t) = case t of + TypeInt8 -> (==) + TypeInt16 -> (==) + TypeInt32 -> (==) + TypeInt64 -> (==) + TypeInt128 -> (==) + TypeWord8 -> (==) + TypeWord16 -> (==) + TypeWord32 -> (==) + TypeWord64 -> (==) + TypeWord128 -> (==) + integral (VectorIntegralType _ t) = case t of + TypeInt8 -> (==) + TypeInt16 -> (==) + TypeInt32 -> (==) + TypeInt64 -> (==) + TypeInt128 -> (==) + TypeWord8 -> (==) + TypeWord16 -> (==) + TypeWord32 -> (==) + TypeWord64 -> (==) + TypeWord128 -> (==) + + floating :: FloatingType v -> v -> v -> Bool + floating (SingleFloatingType t) = case t of + TypeFloat16 -> (==) + TypeFloat32 -> (==) + TypeFloat64 -> (==) + TypeFloat128 -> (==) + floating (VectorFloatingType _ t) = case t of + TypeFloat16 -> (==) + TypeFloat32 -> (==) + TypeFloat64 -> (==) + TypeFloat128 -> (==) + + +instance (Elt a, SIMD n a) => GHC.IsList (Vec n a) where + type Item (Vec n a) = a + toList = toList + fromList = fromList + +toList :: forall n a. (Elt a, SIMD n a) => Vec n a -> [a] +toList (Vec vs) = map toElt $ R.toList (vecR @n @a) (eltR @a) vs + +fromList :: forall n a. (Elt a, SIMD n a) => [a] -> Vec n a +fromList = Vec . R.fromList (vecR @n @a) (eltR @a) . map fromElt + +instance SIMD n a => Elt (Vec n a) where + type EltR (Vec n a) = VecR n a + eltR = vecR @n @a + tagsR = vtagsR @n @a + toElt = Vec + fromElt (Vec a) = a + +-- | The 'SIMD' class characterises the subset of scalar element types from +-- 'Elt' that can be packed into SIMD vectors. +-- +-- @since 1.4.0.0 +-- +class KnownNat n => SIMD n a where + type VecR n a :: Type + type VecR n a = GVecR () n (Rep a) + + vecR :: TypeR (VecR n a) + vtagsR :: [TagR (VecR n a)] -- this will quickly get out of hand! + + default vecR + :: (GVec n (Rep a), VecR n a ~ GVecR () n (Rep a)) + => TypeR (VecR n a) + vecR = gvecR @n @(Rep a) TupRunit + + default vtagsR + :: (GVec n (Rep a), VecR n a ~ GVecR () n (Rep a)) + => [TagR (VecR n a)] + vtagsR = gvtagsR @n @(Rep a) TagRunit + +class KnownNat n => GVec n (f :: Type -> Type) where + type GVecR t n f + gvecR :: TypeR t -> TypeR (GVecR t n f) + gvtagsR :: TagR t -> [TagR (GVecR t n f)] + +instance KnownNat n => GVec n U1 where + type GVecR t n U1 = t + gvecR t = t + gvtagsR t = [t] + +instance GVec n a => GVec n (M1 i c a) where + type GVecR t n (M1 i c a) = GVecR t n a + gvecR = gvecR @n @a + gvtagsR = gvtagsR @n @a + +instance SIMD n a => GVec n (K1 i a) where + type GVecR t n (K1 i a) = (t, VecR n a) + gvecR t = TupRpair t (vecR @n @a) + gvtagsR t = TagRpair t <$> vtagsR @n @a + +instance (GVec n a, GVec n b) => GVec n (a :*: b) where + type GVecR t n (a :*: b) = GVecR (GVecR t n a) n b + gvecR = gvecR @n @b . gvecR @n @a + gvtagsR = concatMap (gvtagsR @n @b) . gvtagsR @n @a + +instance (GVec n a, GVec n b, GSumVec n (a :+: b)) => GVec n (a :+: b) where + type GVecR t n (a :+: b) = (Prim.Vec n TAG, GSumVecR t n (a :+: b)) + gvecR t = TupRpair (TupRsingle scalarType) (gsumvecR @n @(a :+: b) t) + -- gvtagsR t = let zero = Prim.fromList (replicate (fromInteger (natVal' (proxy# @n))) 0) + -- in uncurry TagRtag <$> gsumvtagsR @n @(a :+: b) zero t + gvtagsR _ = error "TODO: gvtagsR (:+:)" + +class KnownNat n => GSumVec n (f :: Type -> Type) where + type GSumVecR t n f + gsumvecR :: TypeR t -> TypeR (GSumVecR t n f) + -- gsumvtagsR :: Prim.Vec n TAG -> TagR t -> [(Prim.Vec n TAG, TagR (GSumVecR t n f))] + +instance KnownNat n => GSumVec n U1 where + type GSumVecR t n U1 = t + gsumvecR t = t + +instance GSumVec n a => GSumVec n (M1 i c a) where + type GSumVecR t n (M1 i c a) = GSumVecR t n a + gsumvecR = gsumvecR @n @a + +instance (KnownNat n, SIMD n a) => GSumVec n (K1 i a) where + type GSumVecR t n (K1 i a) = (t, VecR n a) + gsumvecR t = TupRpair t (vecR @n @a) + +instance (GVec n a, GVec n b) => GSumVec n (a :*: b) where + type GSumVecR t n (a :*: b) = GVecR t n (a :*: b) + gsumvecR = gvecR @n @(a :*: b) + +instance (GSumVec n a, GSumVec n b) => GSumVec n (a :+: b) where + type GSumVecR t n (a :+: b) = GSumVecR (GSumVecR t n a) n b + gsumvecR = gsumvecR @n @b . gsumvecR @n @a + + +instance KnownNat n => SIMD n Z +instance KnownNat n => SIMD n () +instance KnownNat n => SIMD n Ordering +instance SIMD n a => SIMD n (Maybe a) +instance SIMD n sh => SIMD n (sh :. Int) +instance (SIMD n a, SIMD n b) => SIMD n (Either a b) + +instance KnownNat n => SIMD n Bool where + type VecR n Bool = Prim.Vec n Bit + vecR = TupRsingle scalarType + vtagsR = [TagRsingle scalarType] + +instance KnownNat n => SIMD n Int where + type VecR n Int = Prim.Vec n (EltR Int) + vecR = TupRsingle scalarType + vtagsR = [TagRsingle scalarType] + +instance KnownNat n => SIMD n Word where + type VecR n Word = Prim.Vec n (EltR Word) + vecR = TupRsingle scalarType + vtagsR = [TagRsingle scalarType] + +instance KnownNat n => SIMD n Char where + type VecR n Char = Prim.Vec n (EltR Char) + vecR = TupRsingle scalarType + vtagsR = [TagRsingle scalarType] + +runQ $ do + let + integralTypes :: [Name] + integralTypes = + [ ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + , ''Float128 + ] + + numTypes :: [Name] + numTypes = integralTypes ++ floatingTypes + + mkPrim :: Name -> Q [Dec] + mkPrim name = + let t = conT name + in + [d| instance KnownNat n => SIMD n $t where + type VecR n $t = Prim.Vec n $t + vecR = TupRsingle scalarType + vtagsR = [TagRsingle scalarType] + |] + + mkTuple :: Int -> Q Dec + mkTuple n = do + w <- newName "w" + let + xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + ts = map varT xs + res = tupT ts + ctx = mapM (appT [t| SIMD $(varT w) |]) ts + -- + instanceD ctx [t| SIMD $(varT w) $res |] [] + -- + ps <- mapM mkPrim numTypes + ts <- mapM mkTuple [2..16] + return (concat ps ++ ts) + + +{-- +type NoTypeError = (() :: Constraint) + +type family NoNestedVec (f :: k) :: Constraint where + NoNestedVec t = If (HasNestedVec (Rep t)) + (TypeError (NestedVecError t)) + NoTypeError + +type family HasNestedVec (f :: k -> Type) :: Bool where + HasNestedVec (f :+: g) = HasNestedVec f || HasNestedVec g + HasNestedVec (f :*: g) = HasNestedVec f || HasNestedVec g + HasNestedVec (M1 _ _ a) = HasNestedVec a + HasNestedVec U1 = 'False + -- HasNestedVec (Rec0 Int) = 'False + -- HasNestedVec (Rec0 Int8) = 'False + -- HasNestedVec (Rec0 Int16) = 'False + -- HasNestedVec (Rec0 Int32) = 'False + -- HasNestedVec (Rec0 Int64) = 'False + -- HasNestedVec (Rec0 Int128) = 'False + -- HasNestedVec (Rec0 Word) = 'False + -- HasNestedVec (Rec0 Word8) = 'False + -- HasNestedVec (Rec0 Word16) = 'False + -- HasNestedVec (Rec0 Word32) = 'False + -- HasNestedVec (Rec0 Word64) = 'False + -- HasNestedVec (Rec0 Word128) = 'False + -- HasNestedVec (Rec0 Half) = 'False + -- HasNestedVec (Rec0 Float) = 'False + -- HasNestedVec (Rec0 Double) = 'False + -- HasNestedVec (Rec0 Float128) = 'False + -- HasNestedVec (Rec0 (Vec _ _)) = 'True + -- HasNestedVec (Rec0 a) = 'False + HasNestedVec (Rec0 a) = Stuck (Rep a) 'False (HasNestedVec (Rep a)) + +-- type family NoNestedVec (f :: k -> Type) :: Constraint where +-- NoNestedVec (f :+: g) = (NoNestedVec f, NoNestedVec g) +-- NoNestedVec (f :*: g) = (NoNestedVec f, NoNestedVec g) +-- NoNestedVec (M1 _ _ a) = NoNestedVec a +-- NoNestedVec U1 = NoTypeError +-- NoNestedVec (Rec0 Int) = NoTypeError +-- NoNestedVec (Rec0 Int8) = NoTypeError +-- NoNestedVec (Rec0 Int16) = NoTypeError +-- NoNestedVec (Rec0 Int32) = NoTypeError +-- NoNestedVec (Rec0 Int64) = NoTypeError +-- NoNestedVec (Rec0 Int128) = NoTypeError +-- NoNestedVec (Rec0 Word) = NoTypeError +-- NoNestedVec (Rec0 Word8) = NoTypeError +-- NoNestedVec (Rec0 Word16) = NoTypeError +-- NoNestedVec (Rec0 Word32) = NoTypeError +-- NoNestedVec (Rec0 Word64) = NoTypeError +-- NoNestedVec (Rec0 Word128) = NoTypeError +-- NoNestedVec (Rec0 Half) = NoTypeError +-- NoNestedVec (Rec0 Float) = NoTypeError +-- NoNestedVec (Rec0 Double) = NoTypeError +-- NoNestedVec (Rec0 Float128) = NoTypeError +-- NoNestedVec (Rec0 (Vec _ _)) = TypeError (NestedVecError Int) +-- NoNestedVec (Rec0 a) = Stuck (Rep a) NoTypeError (NoNestedVec (Rep a)) + +-- NoNestedVec _ = NoTypeError + -- NoNestedVec (K1 R a) = NoNestedVec (Rep a) + -- NoNestedVec (K1 R a) = Stuck (Rep a) (NoGeneric a) (NoNestedVec (Rep a)) + +-- type family IsVec a :: Bool where +-- IsVec (Vec n a) = 'True +-- IsVec _ = 'False + +-- foo :: forall a. Stuck (Rep a) (NoGeneric a) NoTypeError => () +foo :: forall a. NoNestedVec a => () +foo = () + +data T x +type family Any :: k + +type family Stuck (f :: Type -> Type) (c :: Bool) (a :: k) :: k where + Stuck T _ _ = Any + Stuck _ _ k = k +type family NoGeneric t where + NoGeneric x = TypeError ('Text "No instance for " ':<>: 'ShowType (Generic x)) -type VecElt a = (Elt a, Prim a, IsSingle a, EltR a ~ a) +-- type family NoSIMD n t where +-- NoSIMD n t = TypeError ('Text "No instance for " ':<>: 'ShowType (SIMD n t)) -instance (KnownNat n, VecElt a) => Elt (Vec n a) where - type EltR (Vec n a) = Vec n a - eltR = TupRsingle (VectorScalarType (VectorType (fromIntegral (natVal' (proxy# :: Proxy# n))) singleType)) - tagsR = [TagRsingle (VectorScalarType (VectorType (fromIntegral (natVal' (proxy# :: Proxy# n))) singleType))] - toElt = id - fromElt = id +type family NestedVecError (t :: k) :: ErrorMessage where + NestedVecError t = 'Text "Can not derive SIMD class" + ':$$: 'Text "Because '" ':<>: 'ShowType t ':<>: 'Text "' already contains a SIMD vector." +--} diff --git a/src/Data/Array/Accelerate/Test/NoFib/Base.hs b/src/Data/Array/Accelerate/Test/NoFib/Base.hs index 1307c6cbb..31061c664 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Base.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Base.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} @@ -14,14 +16,13 @@ module Data.Array.Accelerate.Test.NoFib.Base where +import Data.Array.Accelerate.Type import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Array import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Sugar.Shape import Data.Array.Accelerate.Trafo.Sharing -import Data.Array.Accelerate.Type -import Data.Primitive.Vec - import Control.Monad import Data.Primitive.Types @@ -29,6 +30,8 @@ import Hedgehog import qualified Hedgehog.Gen as Gen import qualified Hedgehog.Range as Range +import qualified GHC.Exts as GHC + type Run = forall a. Arrays a => Acc a -> a type RunN = forall f. Afunction f => f -> AfunctionR f @@ -94,21 +97,21 @@ f32 = Gen.float (Range.linearFracFrom 0 (-log_flt_max) log_flt_max) f64 :: Gen Double f64 = Gen.double (Range.linearFracFrom 0 (-log_flt_max) log_flt_max) -v2 :: Prim a => Gen a -> Gen (Vec2 a) -v2 a = Vec2 <$> a <*> a +v2 :: (Elt a, SIMD 2 a) => Gen a -> Gen (Vec2 a) +v2 a = GHC.fromList <$> replicateM 2 a + +v3 :: (Elt a, SIMD 3 a) => Gen a -> Gen (Vec3 a) +v3 a = GHC.fromList <$> replicateM 3 a -v3 :: Prim a => Gen a -> Gen (Vec3 a) -v3 a = Vec3 <$> a <*> a <*> a +v4 :: (Elt a, SIMD 4 a) => Gen a -> Gen (Vec4 a) +v4 a = GHC.fromList <$> replicateM 4 a -v4 :: Prim a => Gen a -> Gen (Vec4 a) -v4 a = Vec4 <$> a <*> a <*> a <*> a +v8 :: (Elt a, SIMD 8 a) => Gen a -> Gen (Vec8 a) +v8 a = GHC.fromList <$> replicateM 8 a -v8 :: Prim a => Gen a -> Gen (Vec8 a) -v8 a = Vec8 <$> a <*> a <*> a <*> a <*> a <*> a <*> a <*> a +v16 :: (Elt a, SIMD 16 a) => Gen a -> Gen (Vec16 a) +v16 a = GHC.fromList <$> replicateM 16 a -v16 :: Prim a => Gen a -> Gen (Vec16 a) -v16 a = Vec16 <$> a <*> a <*> a <*> a <*> a <*> a <*> a <*> a - <*> a <*> a <*> a <*> a <*> a <*> a <*> a <*> a log_flt_max :: RealFloat a => a log_flt_max = log flt_max diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Map.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Map.hs index dec03973d..71aa3b2a4 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Map.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Map.hs @@ -61,7 +61,7 @@ test_map runN = testIntegralElt :: forall a. ( P.Integral a, P.FiniteBits a , A.Integral a, A.FiniteBits a - , A.FromIntegral a Double + , A.FromIntegral a Int, A.FromIntegral a Double , Similar a, Show a ) => Gen a -> TestTree @@ -94,7 +94,7 @@ test_map runN = ] testFloatingElt - :: forall a. (P.RealFloat a, A.Floating a, A.RealFrac a, Similar a, Show a) + :: forall a. (P.RealFloat a, A.Floating a, A.RealFrac a, Similar a, Show a, FromIntegral (Significand a) Int) => (Range a -> Gen a) -> TestTree testFloatingElt e = @@ -194,7 +194,7 @@ test_complement runN dim e = let !go = runN (A.map A.complement) in go xs ~~~ mapRef P.complement xs test_popCount - :: (Shape sh, Show sh, Show e, A.Bits e, P.Bits e, P.Eq sh) + :: (Shape sh, Show sh, Show e, A.Bits e, P.Bits e, P.Eq sh, A.Ord e, A.Integral e, A.FromIntegral e Int) => RunN -> Gen sh -> Gen e @@ -203,10 +203,10 @@ test_popCount runN dim e = property $ do sh <- forAll dim xs <- forAll (array sh e) - let !go = runN (A.map A.popCount) in go xs ~~~ mapRef P.popCount xs + let !go = runN (A.map (A.fromIntegral . A.popCount)) in go xs ~~~ mapRef P.popCount xs test_countLeadingZeros - :: (Shape sh, Show sh, Show e, A.FiniteBits e, P.FiniteBits e, P.Eq sh) + :: (Shape sh, Show sh, Show e, A.FiniteBits e, P.FiniteBits e, P.Eq sh, A.Ord e, A.Integral e, A.FromIntegral e Int) => RunN -> Gen sh -> Gen e @@ -215,10 +215,10 @@ test_countLeadingZeros runN dim e = property $ do sh <- forAll dim xs <- forAll (array sh e) - let !go = runN (A.map A.countLeadingZeros) in go xs ~~~ mapRef countLeadingZerosRef xs + let !go = runN (A.map (A.fromIntegral . A.countLeadingZeros)) in go xs ~~~ mapRef countLeadingZerosRef xs test_countTrailingZeros - :: (Shape sh, Show sh, Show e, A.FiniteBits e, P.FiniteBits e, P.Eq sh) + :: (Shape sh, Show sh, Show e, A.FiniteBits e, P.FiniteBits e, P.Eq sh, A.Ord e, A.Integral e, A.FromIntegral e Int) => RunN -> Gen sh -> Gen e @@ -227,7 +227,7 @@ test_countTrailingZeros runN dim e = property $ do sh <- forAll dim xs <- forAll (array sh e) - let !go = runN (A.map A.countTrailingZeros) in go xs ~~~ mapRef countTrailingZerosRef xs + let !go = runN (A.map (A.fromIntegral . A.countTrailingZeros)) in go xs ~~~ mapRef countTrailingZerosRef xs test_fromIntegral :: forall sh e. (Shape sh, Show sh, Show e, P.Eq sh, P.Integral e, A.Integral e, A.FromIntegral e Double) @@ -398,7 +398,7 @@ test_log runN dim e = let !go = runN (A.map log) in go xs ~~~ mapRef log xs test_truncate - :: forall sh e. (Shape sh, Show sh, Show e, P.Eq sh, P.RealFrac e, A.RealFrac e) + :: forall sh e. (Shape sh, Show sh, Show e, P.Eq sh, P.RealFrac e, A.RealFrac e, FromIntegral (Significand e) Int) => RunN -> Gen sh -> Gen e @@ -410,7 +410,7 @@ test_truncate runN dim e = let !go = runN (A.map A.truncate) in go xs ~~~ mapRef (P.truncate :: e -> Int) xs test_round - :: forall sh e. (Shape sh, Show sh, Show e, P.Eq sh, P.RealFrac e, A.RealFrac e) + :: forall sh e. (Shape sh, Show sh, Show e, P.Eq sh, P.RealFrac e, A.RealFrac e, FromIntegral (Significand e) Int) => RunN -> Gen sh -> Gen e @@ -422,7 +422,7 @@ test_round runN dim e = let !go = runN (A.map A.round) in go xs ~~~ mapRef (P.round :: e -> Int) xs test_floor - :: forall sh e. (Shape sh, Show sh, Show e, P.Eq sh, P.RealFrac e, A.RealFrac e) + :: forall sh e. (Shape sh, Show sh, Show e, P.Eq sh, P.RealFrac e, A.RealFrac e, FromIntegral (Significand e) Int) => RunN -> Gen sh -> Gen e @@ -434,7 +434,7 @@ test_floor runN dim e = let !go = runN (A.map A.floor) in go xs ~~~ mapRef (P.floor :: e -> Int) xs test_ceiling - :: forall sh e. (Shape sh, Show sh, Show e, P.Eq sh, P.RealFrac e, A.RealFrac e) + :: forall sh e. (Shape sh, Show sh, Show e, P.Eq sh, P.RealFrac e, A.RealFrac e, FromIntegral (Significand e) Int) => RunN -> Gen sh -> Gen e diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/SIMD.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/SIMD.hs index 525f9e2c7..71b8889a3 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/SIMD.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/SIMD.hs @@ -1,9 +1,11 @@ {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} -- | -- Module : Data.Array.Accelerate.Test.NoFib.Prelude.SIMD -- Copyright : [2009..2020] The Accelerate Team @@ -20,8 +22,8 @@ module Data.Array.Accelerate.Test.NoFib.Prelude.SIMD ( ) where -import Lens.Micro ( _1, _2, _3, _4 ) -import Lens.Micro.Extras ( view ) +import Lens.Micro ( _1, _2, _3, _4 ) +import Lens.Micro.Extras ( view ) import Prelude as P import Data.Array.Accelerate as A @@ -30,8 +32,8 @@ import Data.Array.Accelerate.Sugar.Elt as S import Data.Array.Accelerate.Sugar.Shape as S import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config -import Data.Primitive.Vec import Data.Primitive.Types +import qualified Data.Primitive.Vec as Prim import Hedgehog import qualified Hedgehog.Gen as Gen @@ -39,6 +41,8 @@ import qualified Hedgehog.Gen as Gen import Test.Tasty import Test.Tasty.Hedgehog +import GHC.Exts as GHC + test_simd :: RunN -> TestTree test_simd runN = @@ -56,7 +60,7 @@ test_simd runN = , at @TestDouble $ testElt f64 ] where - testElt :: forall e. (VecElt e, P.Eq e, Show e) + testElt :: forall e. (Elt e, SIMD 2 e, SIMD 3 e, SIMD 4 e, P.Eq e, Show e) => Gen e -> TestTree testElt e = @@ -65,7 +69,7 @@ test_simd runN = , testInject e ] - testExtract :: forall e. (VecElt e, P.Eq e, Show e) + testExtract :: (Elt e, SIMD 2 e, SIMD 3 e, SIMD 4 e, P.Eq e, Show e) => Gen e -> TestTree testExtract e = @@ -75,7 +79,7 @@ test_simd runN = , testProperty "V4" $ test_extract_v4 runN dim1 e ] - testInject :: forall e. (VecElt e, P.Eq e, Show e) + testInject :: (Elt e, SIMD 2 e, SIMD 3 e, SIMD 4 e, P.Eq e, Show e) => Gen e -> TestTree testInject e = @@ -87,7 +91,7 @@ test_simd runN = test_extract_v2 - :: (Shape sh, Show sh, Show e, VecElt e, P.Eq e, P.Eq sh) + :: (Shape sh, Show sh, Show e, Elt e, SIMD 2 e, P.Eq e, P.Eq sh) => RunN -> Gen sh -> Gen e @@ -100,7 +104,7 @@ test_extract_v2 runN dim e = let !go = runN (A.map (view _m . unpackVec2')) in go xs === mapRef (view _l . unpackVec2) xs test_extract_v3 - :: (Shape sh, Show sh, Show e, VecElt e, P.Eq e, P.Eq sh) + :: (Shape sh, Show sh, Show e, Elt e, SIMD 3 e, P.Eq e, P.Eq sh) => RunN -> Gen sh -> Gen e @@ -113,7 +117,7 @@ test_extract_v3 runN dim e = let !go = runN (A.map (view _m . unpackVec3')) in go xs === mapRef (view _l . unpackVec3) xs test_extract_v4 - :: (Shape sh, Show sh, Show e, VecElt e, P.Eq e, P.Eq sh) + :: (Shape sh, Show sh, Show e, Elt e, SIMD 4 e, P.Eq e, P.Eq sh) => RunN -> Gen sh -> Gen e @@ -126,7 +130,7 @@ test_extract_v4 runN dim e = let !go = runN (A.map (view _m . unpackVec4')) in go xs === mapRef (view _l . unpackVec4) xs test_inject_v2 - :: (Shape sh, Show sh, Show e, VecElt e, P.Eq e, P.Eq sh) + :: (Shape sh, Show sh, Show e, Elt e, SIMD 2 e, P.Eq e, P.Eq sh) => RunN -> Gen sh -> Gen e @@ -137,10 +141,10 @@ test_inject_v2 runN dim e = sh2 <- forAll dim xs <- forAll (array sh1 e) ys <- forAll (array sh2 e) - let !go = runN (A.zipWith A.V2) in go xs ys === zipWithRef Vec2 xs ys + let !go = runN (A.zipWith A.V2) in go xs ys === zipWithRef (\x y -> GHC.fromList [x,y]) xs ys test_inject_v3 - :: (Shape sh, Show sh, Show e, VecElt e, P.Eq e, P.Eq sh) + :: (Shape sh, Show sh, Show e, Elt e, SIMD 3 e, P.Eq e, P.Eq sh) => RunN -> Gen sh -> Gen e @@ -153,10 +157,10 @@ test_inject_v3 runN dim e = xs <- forAll (array sh1 e) ys <- forAll (array sh2 e) zs <- forAll (array sh3 e) - let !go = runN (A.zipWith3 A.V3) in go xs ys zs === zipWith3Ref Vec3 xs ys zs + let !go = runN (A.zipWith3 A.V3) in go xs ys zs === zipWith3Ref (\x y z -> GHC.fromList [x,y,z]) xs ys zs test_inject_v4 - :: (Shape sh, Show sh, Show e, VecElt e, P.Eq e, P.Eq sh) + :: (Shape sh, Show sh, Show e, Elt e, SIMD 4 e, P.Eq e, P.Eq sh) => RunN -> Gen sh -> Gen e @@ -171,25 +175,34 @@ test_inject_v4 runN dim e = ys <- forAll (array sh2 e) zs <- forAll (array sh3 e) ws <- forAll (array sh4 e) - let !go = runN (A.zipWith4 A.V4) in go xs ys zs ws === zipWith4Ref Vec4 xs ys zs ws + let !go = runN (A.zipWith4 A.V4) in go xs ys zs ws === zipWith4Ref (\x y z w -> GHC.fromList [x,y,z,w]) xs ys zs ws -unpackVec2 :: Prim e => Vec2 e -> (e, e) -unpackVec2 (Vec2 a b) = (a, b) +unpackVec2 :: (Elt e, SIMD 2 e) => Vec2 e -> (e, e) +unpackVec2 v = + case GHC.toList v of + [a,b] -> (a, b) + _ -> undefined -unpackVec3 :: Prim e => Vec3 e -> (e, e, e) -unpackVec3 (Vec3 a b c) = (a, b, c) +unpackVec3 :: (Elt e, SIMD 3 e) => Vec3 e -> (e, e, e) +unpackVec3 v = + case GHC.toList v of + [a,b,c] -> (a, b, c) + _ -> undefined -unpackVec4 :: Prim e => Vec4 e -> (e, e, e, e) -unpackVec4 (Vec4 a b c d) = (a, b, c, d) +unpackVec4 :: (Elt e, SIMD 4 e) => Vec4 e -> (e, e, e, e) +unpackVec4 v = + case GHC.toList v of + [a,b,c,d] -> (a, b, c, d) + _ -> undefined -unpackVec2' :: VecElt e => Exp (Vec2 e) -> (Exp e, Exp e) +unpackVec2' :: (Elt e, SIMD 2 e) => Exp (Vec2 e) -> (Exp e, Exp e) unpackVec2' (A.V2 a b) = (a, b) -unpackVec3' :: VecElt e => Exp (Vec3 e) -> (Exp e, Exp e, Exp e) +unpackVec3' :: (Elt e, SIMD 3 e) => Exp (Vec3 e) -> (Exp e, Exp e, Exp e) unpackVec3' (A.V3 a b c) = (a, b, c) -unpackVec4' :: VecElt e => Exp (Vec4 e) -> (Exp e, Exp e, Exp e, Exp e) +unpackVec4' :: (Elt e, SIMD 4 e) => Exp (Vec4 e) -> (Exp e, Exp e, Exp e, Exp e) unpackVec4' (A.V4 a b c d) = (a, b, c, d) diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs index dec121c68..da012bc32 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs @@ -635,7 +635,7 @@ bound bnd sh0 ix0 = go TupRunit () () = Right () go (TupRpair tsh tsz) (sh,sz) (ih,iz) = go tsh sh ih `addDim` go tsz sz iz go (TupRsingle t) sh i - | Just Refl <- matchScalarType t (scalarType :: ScalarType Int) + | Just Refl <- matchScalarType t (scalarType :: ScalarType INT) = if i P.< 0 then case bnd of Clamp -> Right 0 diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/ZipWith.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/ZipWith.hs index b73785db2..e04b1ca00 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/ZipWith.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/ZipWith.hs @@ -61,7 +61,7 @@ test_zipWith runN = where testIntegralElt :: forall a. ( P.Integral a, P.FiniteBits a - , A.Integral a, A.FiniteBits a + , A.Integral a, A.FiniteBits a, A.FromIntegral Int a , Similar a, Show a ) => Gen a -> TestTree @@ -93,10 +93,10 @@ test_zipWith runN = , testProperty "(.&.)" $ test_band runN sh e , testProperty "(.|.)" $ test_bor runN sh e , testProperty "xor" $ test_xor runN sh e - , testProperty "shift" $ test_shift runN sh e + -- , testProperty "shift" $ test_shift runN sh e , testProperty "shiftL" $ test_shiftL runN sh e , testProperty "shiftR" $ test_shiftR runN sh e - , testProperty "rotate" $ test_rotate runN sh e + -- , testProperty "rotate" $ test_rotate runN sh e , testProperty "rotateL" $ test_rotateL runN sh e , testProperty "rotateR" $ test_rotateR runN sh e @@ -381,23 +381,23 @@ test_xor runN dim e = ys <- forAll (array sh2 e) let !go = runN (A.zipWith A.xor) in go xs ys ~~~ zipWithRef P.xor xs ys -test_shift - :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e) - => RunN - -> Gen sh - -> Gen e - -> Property -test_shift runN dim e = - property $ do - let s = P.finiteBitSize (undefined::e) - sh1 <- forAll dim - sh2 <- forAll dim - xs <- forAll (array sh1 e) - ys <- forAll (array sh2 (Gen.int (Range.linearFrom 0 (-s) s))) - let !go = runN (A.zipWith A.shift) in go xs ys ~~~ zipWithRef P.shift xs ys +-- test_shift +-- :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e) +-- => RunN +-- -> Gen sh +-- -> Gen e +-- -> Property +-- test_shift runN dim e = +-- property $ do +-- let s = P.finiteBitSize (undefined::e) +-- sh1 <- forAll dim +-- sh2 <- forAll dim +-- xs <- forAll (array sh1 e) +-- ys <- forAll (array sh2 (Gen.int (Range.linearFrom 0 (-s) s))) +-- let !go = runN (A.zipWith A.shift) in go xs ys ~~~ zipWithRef P.shift xs ys test_shiftL - :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e) + :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e, A.FromIntegral Int e) => RunN -> Gen sh -> Gen e @@ -409,10 +409,10 @@ test_shiftL runN dim e = sh2 <- forAll dim xs <- forAll (array sh1 e) ys <- forAll (array sh2 (Gen.int (Range.linear 0 s))) - let !go = runN (A.zipWith A.shiftL) in go xs ys ~~~ zipWithRef P.shiftL xs ys + let !go = runN (A.zipWith (\x -> A.shiftL x . A.fromIntegral)) in go xs ys ~~~ zipWithRef P.shiftL xs ys test_shiftR - :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e) + :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e, A.FromIntegral Int e) => RunN -> Gen sh -> Gen e @@ -424,25 +424,25 @@ test_shiftR runN dim e = sh2 <- forAll dim xs <- forAll (array sh1 e) ys <- forAll (array sh2 (Gen.int (Range.linear 0 s))) - let !go = runN (A.zipWith A.shiftR) in go xs ys ~~~ zipWithRef P.shiftR xs ys - -test_rotate - :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e) - => RunN - -> Gen sh - -> Gen e - -> Property -test_rotate runN dim e = - property $ do - let s = P.finiteBitSize (undefined::e) - sh1 <- forAll dim - sh2 <- forAll dim - xs <- forAll (array sh1 e) - ys <- forAll (array sh2 (Gen.int (Range.linearFrom 0 (-s) s))) - let !go = runN (A.zipWith A.rotate) in go xs ys ~~~ zipWithRef P.rotate xs ys + let !go = runN (A.zipWith (\x -> A.shiftR x . A.fromIntegral)) in go xs ys ~~~ zipWithRef P.shiftR xs ys + +-- test_rotate +-- :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e) +-- => RunN +-- -> Gen sh +-- -> Gen e +-- -> Property +-- test_rotate runN dim e = +-- property $ do +-- let s = P.finiteBitSize (undefined::e) +-- sh1 <- forAll dim +-- sh2 <- forAll dim +-- xs <- forAll (array sh1 e) +-- ys <- forAll (array sh2 (Gen.int (Range.linearFrom 0 (-s) s))) +-- let !go = runN (A.zipWith A.rotate) in go xs ys ~~~ zipWithRef P.rotate xs ys test_rotateL - :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e) + :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e, A.FromIntegral Int e) => RunN -> Gen sh -> Gen e @@ -454,10 +454,10 @@ test_rotateL runN dim e = sh2 <- forAll dim xs <- forAll (array sh1 e) ys <- forAll (array sh2 (Gen.int (Range.linear 0 s))) - let !go = runN (A.zipWith A.rotateL) in go xs ys ~~~ zipWithRef P.rotateL xs ys + let !go = runN (A.zipWith (\x -> A.rotateL x . A.fromIntegral)) in go xs ys ~~~ zipWithRef P.rotateL xs ys test_rotateR - :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e) + :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e, A.FromIntegral Int e) => RunN -> Gen sh -> Gen e @@ -469,7 +469,7 @@ test_rotateR runN dim e = sh2 <- forAll dim xs <- forAll (array sh1 e) ys <- forAll (array sh2 (Gen.int (Range.linear 0 s))) - let !go = runN (A.zipWith A.rotateR) in go xs ys ~~~ zipWithRef P.rotateR xs ys + let !go = runN (A.zipWith (\x -> A.rotateR x . A.fromIntegral)) in go xs ys ~~~ zipWithRef P.rotateR xs ys test_lt :: (Shape sh, Show sh, Show e, P.Eq sh, P.Ord e, A.Ord e) diff --git a/src/Data/Array/Accelerate/Test/NoFib/Spectral/RadixSort.hs b/src/Data/Array/Accelerate/Test/NoFib/Spectral/RadixSort.hs index a65c4b3ef..f0e5ea996 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Spectral/RadixSort.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Spectral/RadixSort.hs @@ -139,21 +139,21 @@ instance Radix Word64 where radix = radixOfUnsigned radixOfSigned - :: forall e. (Radix e, A.Bounded e, A.Integral e, A.FromIntegral e Int) + :: forall e. (Radix e, A.Bounded e, A.Integral e, A.FromIntegral Int e, A.FromIntegral e Int) => Exp Int -> Exp e -> Exp Int radixOfSigned i e = i A.== (passes' - 1) ? (radix' (e `xor` minBound), radix' e) where - radix' x = A.fromIntegral $ (x `A.shiftR` i) .&. 1 + radix' x = A.fromIntegral $ (x `A.shiftR` A.fromIntegral i) .&. 1 passes' = constant (passes (undefined :: e)) radixOfUnsigned - :: (Radix e, A.Integral e, A.FromIntegral e Int) + :: (Radix e, A.Integral e, A.FromIntegral Int e, A.FromIntegral e Int) => Exp Int -> Exp e -> Exp Int -radixOfUnsigned i e = A.fromIntegral $ (e `A.shiftR` i) .&. 1 +radixOfUnsigned i e = A.fromIntegral $ (e `A.shiftR` A.fromIntegral i) .&. 1 -- A simple (parallel) radix sort implementation [1]. diff --git a/src/Data/Array/Accelerate/Test/Similar.hs b/src/Data/Array/Accelerate/Test/Similar.hs index 1017408cc..1ca864506 100644 --- a/src/Data/Array/Accelerate/Test/Similar.hs +++ b/src/Data/Array/Accelerate/Test/Similar.hs @@ -64,33 +64,23 @@ instance Similar Int8 instance Similar Int16 instance Similar Int32 instance Similar Int64 +instance Similar Int128 instance Similar Word8 instance Similar Word16 instance Similar Word32 instance Similar Word64 +instance Similar Word128 instance Similar Char instance Similar Bool -instance Similar CShort -instance Similar CUShort -instance Similar CInt -instance Similar CUInt -instance Similar CLong -instance Similar CULong -instance Similar CLLong -instance Similar CULLong -instance Similar CChar -instance Similar CSChar -instance Similar CUChar instance Similar (Any Z) instance (Eq sh, Eq sz) => Similar (sh:.sz) instance (Eq sh) => Similar (Any (sh:.Int)) -instance Similar Half where (~=) = absRelTol 0.05 0.5 -instance Similar Float where (~=) = absRelTol 0.00005 0.005 -instance Similar Double where (~=) = absRelTol 0.00005 0.005 -instance Similar CFloat where (~=) = absRelTol 0.00005 0.005 -instance Similar CDouble where (~=) = absRelTol 0.00005 0.005 +instance Similar Float16 where (~=) = absRelTol 0.05 0.5 +instance Similar Float32 where (~=) = absRelTol 0.00005 0.005 +instance Similar Float64 where (~=) = absRelTol 0.000005 0.0005 +instance Similar Float128 where (~=) = absRelTol 0.0000005 0.0005 instance (Similar a, Similar b) => Similar (a, b) where (x1, x2) ~= (y1, y2) = x1 ~= y1 && x2 ~= y2 diff --git a/src/Data/Array/Accelerate/Trafo/Algebra.hs b/src/Data/Array/Accelerate/Trafo/Algebra.hs index 3ef22973e..270862dcd 100644 --- a/src/Data/Array/Accelerate/Trafo/Algebra.hs +++ b/src/Data/Array/Accelerate/Trafo/Algebra.hs @@ -32,21 +32,19 @@ import Data.Array.Accelerate.AST.Var import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Pretty.Print ( primOperator, isInfix, opName ) import Data.Array.Accelerate.Trafo.Environment -import Data.Array.Accelerate.Type -import Data.Array.Accelerate.Classes.Vector import qualified Data.Array.Accelerate.Debug.Internal.Stats as Stats -import Data.Bits +-- import Data.Bits import Data.Monoid -import Data.Primitive.Vec +-- import Data.Primitive.Vec import Data.Text ( Text ) import Prettyprinter import Prettyprinter.Render.Text import Prelude hiding ( exp ) -import qualified Prelude as P +-- import qualified Prelude as P -import GHC.Float ( float2Double, double2Float ) +-- import GHC.Float ( float2Double, double2Float ) -- Propagate constant expressions, which are either constant valued expressions @@ -62,7 +60,6 @@ propagate env = cvtE cvtE :: OpenExp env aenv e -> Maybe e cvtE exp = case exp of Const _ c -> Just c - PrimConst c -> Just (evalPrimConst c) Evar (Var _ ix) | e <- prjExp ix env , Nothing <- matchOpenExp exp e -> cvtE e @@ -88,6 +85,8 @@ evalPrimApp env f x | otherwise = maybe (Any False, PrimApp f x) (Any True,) $ case f of + _ -> Nothing + {-- PrimAdd ty -> evalAdd ty x env PrimSub ty -> evalSub ty x env PrimMul ty -> evalMul ty x env @@ -150,6 +149,7 @@ evalPrimApp env f x PrimLNot -> evalLNot x env PrimFromIntegral ta tb -> evalFromIntegral ta tb x env PrimToFloating ta tb -> evalToFloating ta tb x env + --} -- Discriminate binary functions that commute, and if so return the operands in @@ -251,6 +251,7 @@ associates fun exp = case fun of -- Helper functions -- ---------------- +{-- type a :-> b = forall env aenv. OpenExp env aenv a -> Gamma env env aenv -> Maybe (OpenExp env aenv b) eval1 :: SingleType b -> (a -> b) -> a :-> b @@ -299,6 +300,7 @@ untup2 :: OpenExp env aenv (a, b) -> Maybe (OpenExp env aenv a, OpenExp env aenv untup2 exp | Pair a b <- exp = Just (a, b) | otherwise = Nothing +--} pprFun :: Text -> PrimFun f -> Text @@ -312,7 +314,7 @@ pprFun rule f then parens (opName op) else opName op - +{-- -- Methods of Num -- -------------- @@ -725,22 +727,5 @@ evalToFloating (FloatingNumType ta) tb x env | FloatingDict <- floatingDict ta , FloatingDict <- floatingDict tb = eval1 (NumSingleType $ FloatingNumType tb) realToFrac x env - - --- Scalar primitives --- ----------------- - -evalPrimConst :: PrimConst a -> a -evalPrimConst (PrimMinBound ty) = evalMinBound ty -evalPrimConst (PrimMaxBound ty) = evalMaxBound ty -evalPrimConst (PrimPi ty) = evalPi ty - -evalMinBound :: BoundedType a -> a -evalMinBound (IntegralBoundedType ty) | IntegralDict <- integralDict ty = minBound - -evalMaxBound :: BoundedType a -> a -evalMaxBound (IntegralBoundedType ty) | IntegralDict <- integralDict ty = maxBound - -evalPi :: FloatingType a -> a -evalPi ty | FloatingDict <- floatingDict ty = pi +--} diff --git a/src/Data/Array/Accelerate/Trafo/Delayed.hs b/src/Data/Array/Accelerate/Trafo/Delayed.hs index 045d205e0..48d7d44cb 100644 --- a/src/Data/Array/Accelerate/Trafo/Delayed.hs +++ b/src/Data/Array/Accelerate/Trafo/Delayed.hs @@ -30,6 +30,7 @@ import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Representation.Array import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Trafo.Substitution +import Data.Array.Accelerate.Type import Data.Array.Accelerate.Debug.Internal.Stats as Stats @@ -56,7 +57,7 @@ data DelayedOpenAcc aenv a where { reprD :: ArrayR (Array sh e) , extentD :: Exp aenv sh , indexD :: Fun aenv (sh -> e) - , linearIndexD :: Fun aenv (Int -> e) + , linearIndexD :: Fun aenv (INT -> e) } -> DelayedOpenAcc aenv (Array sh e) instance HasArraysR DelayedOpenAcc where diff --git a/src/Data/Array/Accelerate/Trafo/Fusion.hs b/src/Data/Array/Accelerate/Trafo/Fusion.hs index 81487be09..d05abfb5d 100644 --- a/src/Data/Array/Accelerate/Trafo/Fusion.hs +++ b/src/Data/Array/Accelerate/Trafo/Fusion.hs @@ -1472,15 +1472,16 @@ aletD' embedAcc elimAcc (LeftHandSideSingle ArrayR{}) (Embed env1 cc1) (Embed en Undef tR -> Undef tR Nil -> Nil Pair e1 e2 -> Pair (cvtE e1) (cvtE e2) - VecPack vR e -> VecPack vR (cvtE e) - VecUnpack vR e -> VecUnpack vR (cvtE e) + Extract vR iR v i -> Extract vR iR (cvtE v) (cvtE i) + Insert vR iR v i x -> Insert vR iR (cvtE v) (cvtE i) (cvtE x) + Shuffle vR iR x y i -> Shuffle vR iR (cvtE x) (cvtE y) (cvtE i) + Select m x y -> Select (cvtE m) (cvtE x) (cvtE y) IndexSlice x ix sh -> IndexSlice x (cvtE ix) (cvtE sh) IndexFull x ix sl -> IndexFull x (cvtE ix) (cvtE sl) ToIndex shR' sh ix -> ToIndex shR' (cvtE sh) (cvtE ix) FromIndex shR' sh i -> FromIndex shR' (cvtE sh) (cvtE i) - Case e rhs def -> Case (cvtE e) (over (mapped . _2) cvtE rhs) (fmap cvtE def) + Case eR e rhs def -> Case eR (cvtE e) (over (mapped . _2) cvtE rhs) (fmap cvtE def) Cond p t e -> Cond (cvtE p) (cvtE t) (cvtE e) - PrimConst c -> PrimConst c PrimApp g x -> PrimApp g (cvtE x) ShapeSize shR' sh -> ShapeSize shR' (cvtE sh) While p f x -> While (replaceF sh' f' avar p) (replaceF sh' f' avar f) (cvtE x) @@ -1500,10 +1501,10 @@ aletD' embedAcc elimAcc (LeftHandSideSingle ArrayR{}) (Embed env1 cc1) (Embed en , Lam lhs (Body b) <- f' -> Stats.substitution "replaceE/!!" . cvtE $ Let lhs - (Let (LeftHandSideSingle scalarTypeInt) i + (Let (LeftHandSideSingle scalarType) i $ FromIndex shR (weakenE (weakenSucc' weakenId) sh') $ Evar - $ Var scalarTypeInt ZeroIdx) + $ Var scalarType ZeroIdx) b | otherwise -> LinearIndex a (cvtE i) @@ -1673,23 +1674,23 @@ identity t | DeclareVars lhs _ value <- declareVars t = Lam lhs $ Body $ expVars $ value weakenId -toIndex :: ShapeR sh -> OpenExp env aenv sh -> OpenFun env aenv (sh -> Int) +toIndex :: ShapeR sh -> OpenExp env aenv sh -> OpenFun env aenv (sh -> INT) toIndex shR sh | DeclareVars lhs k value <- declareVars $ shapeType shR = Lam lhs $ Body $ ToIndex shR (weakenE k sh) $ expVars $ value weakenId -fromIndex :: ShapeR sh -> OpenExp env aenv sh -> OpenFun env aenv (Int -> sh) +fromIndex :: ShapeR sh -> OpenExp env aenv sh -> OpenFun env aenv (INT -> sh) fromIndex shR sh - = Lam (LeftHandSideSingle scalarTypeInt) + = Lam (LeftHandSideSingle scalarType) $ Body $ FromIndex shR (weakenE (weakenSucc' weakenId) sh) $ Evar - $ Var scalarTypeInt ZeroIdx + $ Var scalarType ZeroIdx intersect :: ShapeR sh -> OpenExp env aenv sh -> OpenExp env aenv sh -> OpenExp env aenv sh intersect = mkShapeBinary f where - f a b = PrimApp (PrimMin singleType) $ Pair a b + f a b = PrimApp (PrimMin scalarType) $ Pair a b -- union :: ShapeR sh -> OpenExp env aenv sh -> OpenExp env aenv sh -> OpenExp env aenv sh -- union = mkShapeBinary f @@ -1697,7 +1698,7 @@ intersect = mkShapeBinary f -- f a b = PrimApp (PrimMax singleType) $ Pair a b mkShapeBinary - :: (forall env'. OpenExp env' aenv Int -> OpenExp env' aenv Int -> OpenExp env' aenv Int) + :: (forall env'. OpenExp env' aenv INT -> OpenExp env' aenv INT -> OpenExp env' aenv INT) -> ShapeR sh -> OpenExp env aenv sh -> OpenExp env aenv sh @@ -1744,8 +1745,13 @@ indexArray v@(Var (ArrayR shR _) _) | DeclareVars lhs _ value <- declareVars $ shapeType shR = Lam lhs $ Body $ Index v $ expVars $ value weakenId -linearIndex :: ArrayVar aenv (Array sh e) -> Fun aenv (Int -> e) -linearIndex v = Lam (LeftHandSideSingle scalarTypeInt) $ Body $ LinearIndex v $ Evar $ Var scalarTypeInt ZeroIdx +linearIndex :: ArrayVar aenv (Array sh e) -> Fun aenv (INT -> e) +linearIndex v + = Lam (LeftHandSideSingle scalarType) + $ Body + $ LinearIndex v + $ Evar + $ Var scalarType ZeroIdx extractOpenAcc :: ExtractAcc OpenAcc diff --git a/src/Data/Array/Accelerate/Trafo/Sharing.hs b/src/Data/Array/Accelerate/Trafo/Sharing.hs index 33f1b1be4..3aa082e33 100644 --- a/src/Data/Array/Accelerate/Trafo/Sharing.hs +++ b/src/Data/Array/Accelerate/Trafo/Sharing.hs @@ -758,16 +758,15 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp Prj idx e -> cvtPrj idx (cvt e) Nil -> AST.Nil Pair e1 e2 -> AST.Pair (cvt e1) (cvt e2) - VecPack vec e -> AST.VecPack vec (cvt e) - VecUnpack vec e -> AST.VecUnpack vec (cvt e) - VecIndex vt it v i -> AST.VecIndex vt it (cvt v) (cvt i) - VecWrite vt it v i e -> AST.VecWrite vt it (cvt v) (cvt i) (cvt e) + Extract vR iR v i -> AST.Extract vR iR (cvt v) (cvt i) + Insert vR iR v i x -> AST.Insert vR iR (cvt v) (cvt i) (cvt x) + Shuffle eR iR x y i -> AST.Shuffle eR iR (cvt x) (cvt y) (cvt i) + Select m x y -> AST.Select (cvt m) (cvt x) (cvt y) ToIndex shr sh ix -> AST.ToIndex shr (cvt sh) (cvt ix) FromIndex shr sh e -> AST.FromIndex shr (cvt sh) (cvt e) Case e rhs -> cvtCase (cvt e) (over (mapped . _2) cvt rhs) Cond e1 e2 e3 -> AST.Cond (cvt e1) (cvt e2) (cvt e3) While tp p it i -> AST.While (cvtFun1 tp p) (cvtFun1 tp it) (cvt i) - PrimConst c -> AST.PrimConst c PrimApp f e -> cvtPrimFun f (cvt e) Index _ a e -> AST.Index (cvtAvar a) (cvt e) LinearIndex _ a i -> AST.LinearIndex (cvtAvar a) (cvt i) @@ -826,7 +825,7 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp e = prjT (fst (head rs)) s rhs = map (nested s . map (over _1 ignore)) groups in - AST.Case e (zip tags rhs) Nothing + AST.Case scalarType e (zip tags rhs) Nothing -- Extract the variable representing this particular tag from the -- scrutinee. This is safe because we let-bind the argument first. @@ -834,8 +833,9 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp prjT = fromJust $$ go where go :: TagR a -> AST.OpenExp env' aenv' a -> Maybe (AST.OpenExp env' aenv' TAG) - go TagRtag{} (AST.Pair l _) = Just l - go (TagRpair ta tb) (AST.Pair l r) = + go TagRbit{} _ = error "TODO: TagRbit" + go (TagRtag TypeWord8 _ _) (AST.Pair l _) = Just l + go (TagRpair ta tb) (AST.Pair l r) = case go ta l of Just t -> Just t Nothing -> go tb r @@ -846,11 +846,12 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp eqT a b = snd $ go a b where go :: TagR a -> TagR a -> (Any, Bool) - go TagRunit TagRunit = no True - go TagRsingle{} TagRsingle{} = no True - go TagRundef{} TagRundef{} = no True - go (TagRtag v1 _) (TagRtag v2 _) = yes (v1 == v2) - go (TagRpair a1 b1) (TagRpair a2 b2) = + go TagRunit TagRunit = no True + go TagRsingle{} TagRsingle{} = no True + go TagRundef{} TagRundef{} = no True + go (TagRtag TypeWord8 v1 _) (TagRtag TypeWord8 v2 _) = yes (v1 == v2) + go (TagRbit TypeBit v1) (TagRbit TypeBit v2) = yes (v1 == v2) + go (TagRpair a1 b1) (TagRpair a2 b2) = let (Any r, s) = go a1 a2 in case r of True -> yes s @@ -861,7 +862,8 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp firstT = fromJust . go where go :: TagR a -> Maybe TAG - go (TagRtag v _) = Just v + go TagRbit{} = error "TODO: TagRbit" + go (TagRtag TypeWord8 v _) = Just v go (TagRpair a b) = case go a of Just t -> Just t @@ -877,7 +879,8 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp go TagRunit = no $ TagRunit go (TagRsingle t) = no $ TagRsingle t go (TagRundef t) = no $ TagRundef t - go (TagRtag _ a) = yes $ TagRpair (TagRundef scalarType) a + go (TagRtag t _ a) = yes $ TagRpair (TagRundef (NumScalarType (IntegralNumType (SingleIntegralType t)))) a + go TagRbit{} = error "TODO: TagRbit" go (TagRpair a1 a2) = let (Any r, a1') = go a1 in case r of @@ -1845,10 +1848,10 @@ makeOccMapSharingExp config accOccMap expOccMap = travE Nil -> return (Nil, 1) Pair e1 e2 -> travE2 Pair e1 e2 Prj i e -> travE1 (Prj i) e - VecPack vec e -> travE1 (VecPack vec) e - VecUnpack vec e -> travE1 (VecUnpack vec) e - VecIndex vt ti v i -> travE2 (VecIndex vt ti) v i - VecWrite vt ti v i e -> travE3 (VecWrite vt ti) v i e + Extract vR iR v i -> travE2 (Extract vR iR) v i + Insert vR iR v i x -> travE3 (Insert vR iR) v i x + Shuffle eR iR x y i -> travE3 (Shuffle eR iR) x y i + Select m x y -> travE3 Select m x y ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix FromIndex shr sh e -> travE2 (FromIndex shr) sh e Match t e -> travE1 (Match t) e @@ -1862,7 +1865,6 @@ makeOccMapSharingExp config accOccMap expOccMap = travE (iter', h2) <- traverseFun1 lvl t iter (init', h3) <- travE lvl init return (While t p' iter' init', h1 `max` h2 `max` h3 + 1) - PrimConst c -> return (PrimConst c, 1) PrimApp p e -> travE1 (PrimApp p) e Index tp a e -> travAE (Index tp) a e LinearIndex tp a i -> travAE (LinearIndex tp) a i @@ -2753,10 +2755,10 @@ determineScopesSharingExp config accOccMap expOccMap = scopesExp Pair e1 e2 -> travE2 Pair e1 e2 Nil -> reconstruct Nil noNodeCounts Prj i e -> travE1 (Prj i) e - VecPack vec e -> travE1 (VecPack vec) e - VecUnpack vec e -> travE1 (VecUnpack vec) e - VecIndex vt it v i -> travE2 (VecIndex vt it) v i - VecWrite vt it v i e -> travE3 (VecWrite vt it) v i e + Extract vR iR v i -> travE2 (Extract vR iR) v i + Insert vR iR v i x -> travE3 (Insert vR iR) v i x + Shuffle eR iR x y i -> travE3 (Shuffle eR iR) x y i + Select m x y -> travE3 Select m x y ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix FromIndex shr sh e -> travE2 (FromIndex shr) sh e Match t e -> travE1 (Match t) e @@ -2768,7 +2770,6 @@ determineScopesSharingExp config accOccMap expOccMap = scopesExp (it', accCount2) = scopesFun1 it (i' , accCount3) = scopesExp i in reconstruct (While tp p' it' i') (accCount1 +++ accCount2 +++ accCount3) - PrimConst c -> reconstruct (PrimConst c) noNodeCounts PrimApp p e -> travE1 (PrimApp p) e Index tp a e -> travAE (Index tp) a e LinearIndex tp a e -> travAE (LinearIndex tp) a e diff --git a/src/Data/Array/Accelerate/Trafo/Shrink.hs b/src/Data/Array/Accelerate/Trafo/Shrink.hs index 636043113..8f80dbc68 100644 --- a/src/Data/Array/Accelerate/Trafo/Shrink.hs +++ b/src/Data/Array/Accelerate/Trafo/Shrink.hs @@ -244,7 +244,6 @@ shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE cheap (Pair e1 e2) = cheap e1 && cheap e2 cheap Nil = True cheap Const{} = True - cheap PrimConst{} = True cheap Undef{} = True cheap (Coerce _ _ e) = cheap e cheap _ = False @@ -291,18 +290,17 @@ shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE Undef t -> pure (Undef t) Nil -> pure Nil Pair x y -> Pair <$> shrinkE x <*> shrinkE y - VecPack vec e -> VecPack vec <$> shrinkE e - VecUnpack vec e -> VecUnpack vec <$> shrinkE e - VecIndex vt it v i -> VecIndex vt it <$> shrinkE v <*> shrinkE i - VecWrite vt it v i e -> VecWrite vt it <$> shrinkE v <*> shrinkE i <*> shrinkE e + Extract vR iR v i -> Extract vR iR <$> shrinkE v <*> shrinkE i + Insert vR iR v i x -> Insert vR iR <$> shrinkE v <*> shrinkE i <*> shrinkE x + Shuffle eR iR x y i -> Shuffle eR iR <$> shrinkE x <*> shrinkE y <*> shrinkE i + Select m x y -> Select <$> shrinkE m <*> shrinkE x <*> shrinkE y IndexSlice x ix sh -> IndexSlice x <$> shrinkE ix <*> shrinkE sh IndexFull x ix sl -> IndexFull x <$> shrinkE ix <*> shrinkE sl ToIndex shr sh ix -> ToIndex shr <$> shrinkE sh <*> shrinkE ix FromIndex shr sh i -> FromIndex shr <$> shrinkE sh <*> shrinkE i - Case e rhs def -> Case <$> shrinkE e <*> sequenceA [ (t,) <$> shrinkE c | (t,c) <- rhs ] <*> shrinkMaybeE def + Case eR e rhs def -> Case eR <$> shrinkE e <*> sequenceA [ (t,) <$> shrinkE c | (t,c) <- rhs ] <*> shrinkMaybeE def Cond p t e -> Cond <$> shrinkE p <*> shrinkE t <*> shrinkE e While p f x -> While <$> shrinkF p <*> shrinkF f <*> shrinkE x - PrimConst c -> pure (PrimConst c) PrimApp f x -> PrimApp f <$> shrinkE x Index a sh -> Index a <$> shrinkE sh LinearIndex a i -> LinearIndex a <$> shrinkE i @@ -449,7 +447,6 @@ shrinkPreAcc shrinkAcc reduceAcc = Stats.substitution "shrinkA" shrinkA FromIndex sh i -> FromIndex (shrinkE sh) (shrinkE i) Cond p t e -> Cond (shrinkE p) (shrinkE t) (shrinkE e) While p f x -> While (shrinkF p) (shrinkF f) (shrinkE x) - PrimConst c -> PrimConst c PrimApp f x -> PrimApp f (shrinkE x) Index a sh -> Index (shrinkAcc a) (shrinkE sh) LinearIndex a i -> LinearIndex (shrinkAcc a) (shrinkE i) @@ -494,18 +491,17 @@ usesOfExp range = countE Undef _ -> Finite 0 Nil -> Finite 0 Pair e1 e2 -> countE e1 <> countE e2 - VecPack _ e -> countE e - VecUnpack _ e -> countE e - VecIndex _ _ v i -> countE v <> countE i - VecWrite _ _ v i e -> countE v <> countE i <> countE e + Extract _ _ v i -> countE v <> countE i + Insert _ _ v i x -> countE v <> countE i <> countE x + Shuffle _ _ x y i -> countE x <> countE y <> countE i + Select m x y -> countE m <> countE x <> countE y IndexSlice _ ix sh -> countE ix <> countE sh IndexFull _ ix sl -> countE ix <> countE sl FromIndex _ sh i -> countE sh <> countE i ToIndex _ sh e -> countE sh <> countE e - Case e rhs def -> countE e <> mconcat [ countE c | (_,c) <- rhs ] <> maybe (Finite 0) countE def + Case _ e rhs def -> countE e <> mconcat [ countE c | (_,c) <- rhs ] <> maybe (Finite 0) countE def Cond p t e -> countE p <> countE t <> countE e While p f x -> countE x <> loopCount (usesOfFun range p) <> loopCount (usesOfFun range f) - PrimConst _ -> Finite 0 PrimApp _ x -> countE x Index _ sh -> countE sh LinearIndex _ i -> countE i @@ -583,18 +579,17 @@ usesOfPreAcc withShape countAcc idx = count Undef _ -> 0 Nil -> 0 Pair x y -> countE x + countE y - VecPack _ e -> countE e - VecUnpack _ e -> countE e - VecIndex _ _ v i -> countE v + countE i - VecWrite _ _ v i e -> countE v + countE i + countE e + Extract _ _ v i -> countE v + countE i + Insert _ _ v i x -> countE v + countE i + countE x + Shuffle _ _ x y i -> countE x + countE y + countE i + Select m x y -> countE m + countE x + countE y IndexSlice _ ix sh -> countE ix + countE sh IndexFull _ ix sl -> countE ix + countE sl ToIndex _ sh ix -> countE sh + countE ix FromIndex _ sh i -> countE sh + countE i - Case e rhs def -> countE e + sum [ countE c | (_,c) <- rhs ] + maybe 0 countE def + Case _ e rhs def -> countE e + sum [ countE c | (_,c) <- rhs ] + maybe 0 countE def Cond p t e -> countE p + countE t + countE e While p f x -> countF p + countF f + countE x - PrimConst _ -> 0 PrimApp _ x -> countE x Index a sh -> countAvar a + countE sh LinearIndex a i -> countAvar a + countE i diff --git a/src/Data/Array/Accelerate/Trafo/Simplify.hs b/src/Data/Array/Accelerate/Trafo/Simplify.hs index 9b315f110..55a473dfb 100644 --- a/src/Data/Array/Accelerate/Trafo/Simplify.hs +++ b/src/Data/Array/Accelerate/Trafo/Simplify.hs @@ -276,17 +276,16 @@ simplifyOpenExp env = first getAny . cvtE Undef tp -> pure $ Undef tp Nil -> pure Nil Pair e1 e2 -> Pair <$> cvtE e1 <*> cvtE e2 - VecPack vec e -> VecPack vec <$> cvtE e - VecUnpack vec e -> VecUnpack vec <$> cvtE e - VecIndex vt it v i -> VecIndex vt it <$> cvtE v <*> cvtE i - VecWrite vt it v i e -> VecWrite vt it <$> cvtE v <*> cvtE i <*> cvtE e + Extract vR iR v i -> Extract vR iR <$> cvtE v <*> cvtE i + Insert vR iR v i x -> Insert vR iR <$> cvtE v <*> cvtE i <*> cvtE x + Shuffle eR iR x y i -> Shuffle eR iR <$> cvtE x <*> cvtE y <*> cvtE i + Select m x y -> Select <$> cvtE m <*> cvtE x <*> cvtE y IndexSlice x ix sh -> IndexSlice x <$> cvtE ix <*> cvtE sh IndexFull x ix sl -> IndexFull x <$> cvtE ix <*> cvtE sl ToIndex shr sh ix -> toIndex shr (cvtE sh) (cvtE ix) FromIndex shr sh ix -> fromIndex shr (cvtE sh) (cvtE ix) - Case e rhs def -> caseof (cvtE e) (sequenceA [ (t,) <$> cvtE c | (t,c) <- rhs ]) (cvtMaybeE def) + Case eR e rhs def -> caseof eR (cvtE e) (sequenceA [ (t,) <$> cvtE c | (t,c) <- rhs ]) (cvtMaybeE def) Cond p t e -> cond (cvtE p) (cvtE t) (cvtE e) - PrimConst c -> pure $ PrimConst c PrimApp f x -> (u<>v, fx) where (u, x') = cvtE x @@ -335,11 +334,12 @@ simplifyOpenExp env = first getAny . cvtE | Just Refl <- matchOpenExp t' e' = Stats.knownBranch "redundant" (yes e') | otherwise = Cond <$> p <*> t <*> e - caseof :: (Any, OpenExp env aenv TAG) + caseof :: ScalarType TAG + -> (Any, OpenExp env aenv TAG) -> (Any, [(TAG, OpenExp env aenv b)]) -> (Any, Maybe (OpenExp env aenv b)) -> (Any, OpenExp env aenv b) - caseof x@(_,x') xs@(_,xs') md@(_,md') + caseof tagR x@(_,x') xs@(_,xs') md@(_,md') | Const _ t <- x' = Stats.caseElim "known" (yes (fromJust $ lookup t xs')) | Just d <- md' @@ -348,16 +348,16 @@ simplifyOpenExp env = first getAny . cvtE | Just d <- md' , [(_,(_,u))] <- us , Just Refl <- matchOpenExp d u - = Stats.caseDefault "merge" $ yes (Case x' (map snd vs) (Just u)) + = Stats.caseDefault "merge" $ yes (Case tagR x' (map snd vs) (Just u)) | Nothing <- md' , [] <- vs , [(_,(_,u))] <- us = Stats.caseElim "overlap" (yes u) | Nothing <- md' , [(_,(_,u))] <- us - = Stats.caseDefault "introduction" $ yes (Case x' (map snd vs) (Just u)) + = Stats.caseDefault "introduction" $ yes (Case tagR x' (map snd vs) (Just u)) | otherwise - = Case <$> x <*> xs <*> md + = Case tagR <$> x <*> xs <*> md where (us,vs) = partition (\(n,_) -> n > 1) $ Map.elems @@ -377,24 +377,24 @@ simplifyOpenExp env = first getAny . cvtE shape a = pure $ Shape a - shapeSize :: ShapeR sh -> (Any, OpenExp env aenv sh) -> (Any, OpenExp env aenv Int) + shapeSize :: ShapeR sh -> (Any, OpenExp env aenv sh) -> (Any, OpenExp env aenv INT) shapeSize shr (_, sh) | Just c <- extractConstTuple sh - = Stats.ruleFired "shapeSize/const" $ yes (Const scalarTypeInt (product (shapeToList shr c))) + = Stats.ruleFired "shapeSize/const" $ yes (Const (scalarType @INT) (product (shapeToList shr c))) shapeSize shr sh = ShapeSize shr <$> sh toIndex :: ShapeR sh -> (Any, OpenExp env aenv sh) -> (Any, OpenExp env aenv sh) - -> (Any, OpenExp env aenv Int) + -> (Any, OpenExp env aenv INT) toIndex _ (_,sh) (_,FromIndex _ sh' ix) | Just Refl <- matchOpenExp sh sh' = Stats.ruleFired "toIndex/fromIndex" $ yes ix toIndex shr sh ix = ToIndex shr <$> sh <*> ix fromIndex :: ShapeR sh -> (Any, OpenExp env aenv sh) - -> (Any, OpenExp env aenv Int) + -> (Any, OpenExp env aenv INT) -> (Any, OpenExp env aenv sh) fromIndex _ (_,sh) (_,ToIndex _ sh' ix) | Just Refl <- matchOpenExp sh sh' = Stats.ruleFired "fromIndex/toIndex" $ yes ix @@ -557,37 +557,22 @@ summariseOpenExp = (terms +~ 1) . goE travA :: acc aenv a -> Stats travA _ = zero & vars +~ 1 -- assume an array index, else we should have failed elsewhere - travC :: PrimConst c -> Stats - travC (PrimMinBound t) = travBoundedType t & terms +~ 1 - travC (PrimMaxBound t) = travBoundedType t & terms +~ 1 - travC (PrimPi t) = travFloatingType t & terms +~ 1 - travIntegralType :: IntegralType t -> Stats travIntegralType _ = zero & types +~ 1 travFloatingType :: FloatingType t -> Stats travFloatingType _ = zero & types +~ 1 + travBitType :: BitType t -> Stats + travBitType _ = zero & types +~ 1 + travNumType :: NumType t -> Stats travNumType (IntegralNumType t) = travIntegralType t & types +~ 1 travNumType (FloatingNumType t) = travFloatingType t & types +~ 1 - travBoundedType :: BoundedType t -> Stats - travBoundedType (IntegralBoundedType t) = travIntegralType t & types +~ 1 - - -- travScalarType :: ScalarType t -> Stats - -- travScalarType (SingleScalarType t) = travSingleType t & types +~ 1 - -- travScalarType (VectorScalarType t) = travVectorType t & types +~ 1 - - travSingleType :: SingleType t -> Stats - travSingleType (NumSingleType t) = travNumType t & types +~ 1 - - -- travVectorType :: VectorType t -> Stats - -- travVectorType (Vector2Type t) = travSingleType t & types +~ 1 - -- travVectorType (Vector3Type t) = travSingleType t & types +~ 1 - -- travVectorType (Vector4Type t) = travSingleType t & types +~ 1 - -- travVectorType (Vector8Type t) = travSingleType t & types +~ 1 - -- travVectorType (Vector16Type t) = travSingleType t & types +~ 1 + travScalarType :: ScalarType t -> Stats + travScalarType (NumScalarType t) = travNumType t & types +~ 1 + travScalarType (BitScalarType _) = zero & types +~ 1 -- The scrutinee has already been counted goE :: OpenExp env aenv t -> Stats @@ -599,19 +584,18 @@ summariseOpenExp = (terms +~ 1) . goE Const{} -> zero Undef _ -> zero Nil -> zero & terms +~ 1 - Pair e1 e2 -> travE e1 +++ travE e2 & terms +~ 1 - VecPack _ e -> travE e - VecUnpack _ e -> travE e - VecIndex _ _ v i -> travE v +++ travE i - VecWrite _ _ v i e -> travE v +++ travE i +++ travE e + Pair e1 e2 -> travE e1 +++ travE e2 + Extract _ _ v i -> travE v +++ travE i + Insert _ _ v i x -> travE v +++ travE i +++ travE x + Shuffle _ _ x y i -> travE x +++ travE y +++ travE i + Select m x y -> travE m +++ travE x +++ travE y IndexSlice _ slix sh -> travE slix +++ travE sh & terms +~ 1 -- +1 for sliceIndex IndexFull _ slix sl -> travE slix +++ travE sl & terms +~ 1 -- +1 for sliceIndex ToIndex _ sh ix -> travE sh +++ travE ix FromIndex _ sh ix -> travE sh +++ travE ix - Case e rhs def -> travE e +++ mconcat [ travE c | (_,c) <- rhs ] +++ maybe zero travE def + Case _ e rhs def -> travE e +++ mconcat [ travE c | (_,c) <- rhs ] +++ maybe zero travE def Cond p t e -> travE p +++ travE t +++ travE e While p f x -> travF p +++ travF f +++ travE x - PrimConst c -> travC c Index a ix -> travA a +++ travE ix LinearIndex a ix -> travA a +++ travE ix Shape a -> travA a @@ -625,66 +609,68 @@ summariseOpenExp = (terms +~ 1) . goE goF :: PrimFun f -> Stats goF fun = case fun of - PrimAdd t -> travNumType t - PrimSub t -> travNumType t - PrimMul t -> travNumType t - PrimNeg t -> travNumType t - PrimAbs t -> travNumType t - PrimSig t -> travNumType t - PrimQuot t -> travIntegralType t - PrimRem t -> travIntegralType t - PrimQuotRem t -> travIntegralType t - PrimIDiv t -> travIntegralType t - PrimMod t -> travIntegralType t - PrimDivMod t -> travIntegralType t - PrimBAnd t -> travIntegralType t - PrimBOr t -> travIntegralType t - PrimBXor t -> travIntegralType t - PrimBNot t -> travIntegralType t - PrimBShiftL t -> travIntegralType t - PrimBShiftR t -> travIntegralType t - PrimBRotateL t -> travIntegralType t - PrimBRotateR t -> travIntegralType t - PrimPopCount t -> travIntegralType t - PrimCountLeadingZeros t -> travIntegralType t + PrimAdd t -> travNumType t + PrimSub t -> travNumType t + PrimMul t -> travNumType t + PrimNeg t -> travNumType t + PrimAbs t -> travNumType t + PrimSig t -> travNumType t + PrimQuot t -> travIntegralType t + PrimRem t -> travIntegralType t + PrimQuotRem t -> travIntegralType t + PrimIDiv t -> travIntegralType t + PrimMod t -> travIntegralType t + PrimDivMod t -> travIntegralType t + PrimBAnd t -> travIntegralType t + PrimBOr t -> travIntegralType t + PrimBXor t -> travIntegralType t + PrimBNot t -> travIntegralType t + PrimBShiftL t -> travIntegralType t + PrimBShiftR t -> travIntegralType t + PrimBRotateL t -> travIntegralType t + PrimBRotateR t -> travIntegralType t + PrimPopCount t -> travIntegralType t + PrimCountLeadingZeros t -> travIntegralType t PrimCountTrailingZeros t -> travIntegralType t - PrimFDiv t -> travFloatingType t - PrimRecip t -> travFloatingType t - PrimSin t -> travFloatingType t - PrimCos t -> travFloatingType t - PrimTan t -> travFloatingType t - PrimAsin t -> travFloatingType t - PrimAcos t -> travFloatingType t - PrimAtan t -> travFloatingType t - PrimSinh t -> travFloatingType t - PrimCosh t -> travFloatingType t - PrimTanh t -> travFloatingType t - PrimAsinh t -> travFloatingType t - PrimAcosh t -> travFloatingType t - PrimAtanh t -> travFloatingType t - PrimExpFloating t -> travFloatingType t - PrimSqrt t -> travFloatingType t - PrimLog t -> travFloatingType t - PrimFPow t -> travFloatingType t - PrimLogBase t -> travFloatingType t - PrimTruncate f i -> travFloatingType f +++ travIntegralType i - PrimRound f i -> travFloatingType f +++ travIntegralType i - PrimFloor f i -> travFloatingType f +++ travIntegralType i - PrimCeiling f i -> travFloatingType f +++ travIntegralType i - PrimIsNaN t -> travFloatingType t - PrimIsInfinite t -> travFloatingType t - PrimAtan2 t -> travFloatingType t - PrimLt t -> travSingleType t - PrimGt t -> travSingleType t - PrimLtEq t -> travSingleType t - PrimGtEq t -> travSingleType t - PrimEq t -> travSingleType t - PrimNEq t -> travSingleType t - PrimMax t -> travSingleType t - PrimMin t -> travSingleType t - PrimLAnd -> zero - PrimLOr -> zero - PrimLNot -> zero - PrimFromIntegral i n -> travIntegralType i +++ travNumType n - PrimToFloating n f -> travNumType n +++ travFloatingType f + PrimFDiv t -> travFloatingType t + PrimRecip t -> travFloatingType t + PrimSin t -> travFloatingType t + PrimCos t -> travFloatingType t + PrimTan t -> travFloatingType t + PrimAsin t -> travFloatingType t + PrimAcos t -> travFloatingType t + PrimAtan t -> travFloatingType t + PrimSinh t -> travFloatingType t + PrimCosh t -> travFloatingType t + PrimTanh t -> travFloatingType t + PrimAsinh t -> travFloatingType t + PrimAcosh t -> travFloatingType t + PrimAtanh t -> travFloatingType t + PrimExpFloating t -> travFloatingType t + PrimSqrt t -> travFloatingType t + PrimLog t -> travFloatingType t + PrimFPow t -> travFloatingType t + PrimLogBase t -> travFloatingType t + PrimTruncate f i -> travFloatingType f +++ travIntegralType i + PrimRound f i -> travFloatingType f +++ travIntegralType i + PrimFloor f i -> travFloatingType f +++ travIntegralType i + PrimCeiling f i -> travFloatingType f +++ travIntegralType i + PrimIsNaN t -> travFloatingType t + PrimIsInfinite t -> travFloatingType t + PrimAtan2 t -> travFloatingType t + PrimLt t -> travScalarType t + PrimGt t -> travScalarType t + PrimLtEq t -> travScalarType t + PrimGtEq t -> travScalarType t + PrimEq t -> travScalarType t + PrimNEq t -> travScalarType t + PrimMax t -> travScalarType t + PrimMin t -> travScalarType t + PrimLAnd _ -> zero & types +~ 1 + PrimLOr _ -> zero & types +~ 1 + PrimLNot _ -> zero & types +~ 1 + PrimFromIntegral i n -> travIntegralType i +++ travNumType n + PrimToFloating n f -> travNumType n +++ travFloatingType f + PrimToBool i b -> travIntegralType i +++ travBitType b + PrimFromBool b i -> travBitType b +++ travIntegralType i diff --git a/src/Data/Array/Accelerate/Trafo/Substitution.hs b/src/Data/Array/Accelerate/Trafo/Substitution.hs index 7debd6d07..fe496765c 100644 --- a/src/Data/Array/Accelerate/Trafo/Substitution.hs +++ b/src/Data/Array/Accelerate/Trafo/Substitution.hs @@ -154,19 +154,18 @@ inlineVars lhsBound expr bound Foreign tp asm f e1 -> Foreign tp asm f <$> travE e1 Pair e1 e2 -> Pair <$> travE e1 <*> travE e2 Nil -> Just Nil - VecPack vec e1 -> VecPack vec <$> travE e1 - VecUnpack vec e1 -> VecUnpack vec <$> travE e1 - VecIndex vt it v i -> VecIndex vt it <$> travE v <*> travE i - VecWrite vt it v i e -> VecWrite vt it <$> travE v <*> travE i <*> travE e + Extract vR iR v i -> Extract vR iR <$> travE v <*> travE i + Insert vR iR v i x -> Insert vR iR <$> travE v <*> travE i <*> travE x + Shuffle vR iR x y m -> Shuffle vR iR <$> travE x <*> travE y <*> travE m + Select m x y -> Select <$> travE m <*> travE x <*> travE y IndexSlice si e1 e2 -> IndexSlice si <$> travE e1 <*> travE e2 IndexFull si e1 e2 -> IndexFull si <$> travE e1 <*> travE e2 ToIndex shr e1 e2 -> ToIndex shr <$> travE e1 <*> travE e2 FromIndex shr e1 e2 -> FromIndex shr <$> travE e1 <*> travE e2 - Case e1 rhs def -> Case <$> travE e1 <*> mapM (\(t,c) -> (t,) <$> travE c) rhs <*> travMaybeE def + Case eR e1 rhs def -> Case eR <$> travE e1 <*> mapM (\(t,c) -> (t,) <$> travE c) rhs <*> travMaybeE def Cond e1 e2 e3 -> Cond <$> travE e1 <*> travE e2 <*> travE e3 While f1 f2 e1 -> While <$> travF f1 <*> travF f2 <*> travE e1 Const t c -> Just $ Const t c - PrimConst c -> Just $ PrimConst c PrimApp p e1 -> PrimApp p <$> travE e1 Index a e1 -> Index a <$> travE e1 LinearIndex a e1 -> LinearIndex a <$> travE e1 @@ -549,7 +548,6 @@ rebuildOpenExp rebuildOpenExp v av@(ReindexAvar reindex) exp = case exp of Const t c -> pure $ Const t c - PrimConst c -> pure $ PrimConst c Undef t -> pure $ Undef t Evar var -> expOut <$> v var Let lhs a b @@ -557,15 +555,15 @@ rebuildOpenExp v av@(ReindexAvar reindex) exp = -> Let lhs' <$> rebuildOpenExp v av a <*> rebuildOpenExp (shiftE' lhs lhs' v) av b Pair e1 e2 -> Pair <$> rebuildOpenExp v av e1 <*> rebuildOpenExp v av e2 Nil -> pure Nil - VecPack vec e -> VecPack vec <$> rebuildOpenExp v av e - VecUnpack vec e -> VecUnpack vec <$> rebuildOpenExp v av e - VecIndex vt it v' i -> VecIndex vt it <$> rebuildOpenExp v av v' <*> rebuildOpenExp v av i - VecWrite vt it v' i e -> VecWrite vt it <$> rebuildOpenExp v av v' <*> rebuildOpenExp v av i <*> rebuildOpenExp v av e + Extract vR iR u i -> Extract vR iR <$> rebuildOpenExp v av u <*> rebuildOpenExp v av i + Insert vR iR u i x -> Insert vR iR <$> rebuildOpenExp v av u <*> rebuildOpenExp v av i <*> rebuildOpenExp v av x + Shuffle vR iR x y m -> Shuffle vR iR <$> rebuildOpenExp v av x <*> rebuildOpenExp v av y <*> rebuildOpenExp v av m + Select m x y -> Select <$> rebuildOpenExp v av m <*> rebuildOpenExp v av x <*> rebuildOpenExp v av y IndexSlice x ix sh -> IndexSlice x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sh IndexFull x ix sl -> IndexFull x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sl ToIndex shr sh ix -> ToIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix FromIndex shr sh ix -> FromIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix - Case e rhs def -> Case <$> rebuildOpenExp v av e <*> sequenceA [ (t,) <$> rebuildOpenExp v av c | (t,c) <- rhs ] <*> rebuildMaybeExp v av def + Case eR e rhs def -> Case eR <$> rebuildOpenExp v av e <*> sequenceA [ (t,) <$> rebuildOpenExp v av c | (t,c) <- rhs ] <*> rebuildMaybeExp v av def Cond p t e -> Cond <$> rebuildOpenExp v av p <*> rebuildOpenExp v av t <*> rebuildOpenExp v av e While p f x -> While <$> rebuildFun v av p <*> rebuildFun v av f <*> rebuildOpenExp v av x PrimApp f x -> PrimApp f <$> rebuildOpenExp v av x diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index 94e891cc1..7e603d0be 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -27,54 +27,49 @@ -- Primitive scalar types supported by Accelerate -- -- Integral types: --- * Int -- * Int8 -- * Int16 -- * Int32 -- * Int64 --- * Word +-- * Int128 -- * Word8 -- * Word16 -- * Word32 -- * Word64 +-- * Word128 -- --- Floating types: +-- Floating types (IEEE): -- * Half -- * Float -- * Double +-- * Float128 -- --- SIMD vector types of the above: --- * Vec2 --- * Vec3 --- * Vec4 --- * Vec8 --- * Vec16 +-- A single bit -- --- Note that 'Int' has the same bit width as in plain Haskell computations. --- 'Float' and 'Double' represent IEEE single and double precision floating --- point numbers, respectively. +-- and SIMD vector types of all of the above -- module Data.Array.Accelerate.Type ( - Half(..), Float, Double, - module Data.Int, - module Data.Word, - module Foreign.C.Types, + Bit(..), Half(..), Float, Double, Float128(..), + module Data.Int, Int128(..), + module Data.Word, Word128(..), module Data.Array.Accelerate.Type, ) where import Data.Array.Accelerate.Orphans () -- Prim Half + +import Data.Primitive.Bit import Data.Primitive.Vec +import Data.Numeric.Float128 import Data.Bits import Data.Int -import Data.Primitive.Types import Data.Type.Equality +import Data.WideWord.Int128 +import Data.WideWord.Word128 import Data.Word -import Foreign.C.Types -import Foreign.Storable ( Storable ) import Formatting import Language.Haskell.TH.Extra import Numeric.Half @@ -83,342 +78,281 @@ import Text.Printf import GHC.Prim import GHC.TypeLits +type Float16 = Half +type Float32 = Float +type Float64 = Double -- Scalar types -- ------------ --- Reified dictionaries --- -data SingleDict a where - SingleDict :: ( Eq a, Ord a, Show a, Storable a, Prim a ) - => SingleDict a - -data IntegralDict a where - IntegralDict :: ( Eq a, Ord a, Show a - , Bounded a, Bits a, FiniteBits a, Integral a, Num a, Real a, Storable a ) - => IntegralDict a - -data FloatingDict a where - FloatingDict :: ( Eq a, Ord a, Show a - , Floating a, Fractional a, Num a, Real a, RealFrac a, RealFloat a, Storable a ) - => FloatingDict a - - --- Scalar type representation +-- | Scalar element types are values that can be stored in machine +-- registers: ground types (int32, float64, etc.) and SIMD vectors of these -- +data ScalarType a where + NumScalarType :: NumType a -> ScalarType a + BitScalarType :: BitType a -> ScalarType a + -- Void? --- | Integral types supported in array computations. --- -data IntegralType a where - TypeInt :: IntegralType Int - TypeInt8 :: IntegralType Int8 - TypeInt16 :: IntegralType Int16 - TypeInt32 :: IntegralType Int32 - TypeInt64 :: IntegralType Int64 - TypeWord :: IntegralType Word - TypeWord8 :: IntegralType Word8 - TypeWord16 :: IntegralType Word16 - TypeWord32 :: IntegralType Word32 - TypeWord64 :: IntegralType Word64 - --- | Floating-point types supported in array computations. --- -data FloatingType a where - TypeHalf :: FloatingType Half - TypeFloat :: FloatingType Float - TypeDouble :: FloatingType Double +data BitType a where + TypeBit :: BitType Bit + TypeMask :: KnownNat n => Proxy# n -> BitType (Vec n Bit) --- | Numeric element types implement Num & Real --- data NumType a where IntegralNumType :: IntegralType a -> NumType a FloatingNumType :: FloatingType a -> NumType a --- | Bounded element types implement Bounded --- -data BoundedType a where - IntegralBoundedType :: IntegralType a -> BoundedType a - --- | All scalar element types implement Eq & Ord --- -data ScalarType a where - SingleScalarType :: SingleType a -> ScalarType a - VectorScalarType :: VectorType (Vec n a) -> ScalarType (Vec n a) +data IntegralType a where + SingleIntegralType :: SingleIntegralType a -> IntegralType a + VectorIntegralType :: KnownNat n => Proxy# n -> SingleIntegralType a -> IntegralType (Vec n a) + +data SingleIntegralType a where + TypeInt8 :: SingleIntegralType Int8 + TypeInt16 :: SingleIntegralType Int16 + TypeInt32 :: SingleIntegralType Int32 + TypeInt64 :: SingleIntegralType Int64 + TypeInt128 :: SingleIntegralType Int128 + TypeWord8 :: SingleIntegralType Word8 + TypeWord16 :: SingleIntegralType Word16 + TypeWord32 :: SingleIntegralType Word32 + TypeWord64 :: SingleIntegralType Word64 + TypeWord128 :: SingleIntegralType Word128 -data SingleType a where - NumSingleType :: NumType a -> SingleType a +data FloatingType a where + SingleFloatingType :: SingleFloatingType a -> FloatingType a + VectorFloatingType :: KnownNat n => Proxy# n -> SingleFloatingType a -> FloatingType (Vec n a) -data VectorType a where - VectorType :: KnownNat n => {-# UNPACK #-} !Int -> SingleType a -> VectorType (Vec n a) +data SingleFloatingType a where + -- TypeFloat8 :: SingleFloatingType Float8 + -- TypeBFloat16 :: SingleFloatingType BFloat16 + TypeFloat16 :: SingleFloatingType Float16 + TypeFloat32 :: SingleFloatingType Float32 + TypeFloat64 :: SingleFloatingType Float64 + TypeFloat128 :: SingleFloatingType Float128 instance Show (IntegralType a) where - show TypeInt = "Int" - show TypeInt8 = "Int8" - show TypeInt16 = "Int16" - show TypeInt32 = "Int32" - show TypeInt64 = "Int64" - show TypeWord = "Word" - show TypeWord8 = "Word8" - show TypeWord16 = "Word16" - show TypeWord32 = "Word32" - show TypeWord64 = "Word64" + show (SingleIntegralType t) = show t + show (VectorIntegralType n t) = printf "<%d x %s>" (natVal' n) (show t) instance Show (FloatingType a) where - show TypeHalf = "Half" - show TypeFloat = "Float" - show TypeDouble = "Double" + show (SingleFloatingType t) = show t + show (VectorFloatingType n t) = printf "<%d x %s>" (natVal' n) (show t) + +instance Show (SingleIntegralType a) where + show TypeInt8 = "Int8" + show TypeInt16 = "Int16" + show TypeInt32 = "Int32" + show TypeInt64 = "Int64" + show TypeInt128 = "Int128" + show TypeWord8 = "Word8" + show TypeWord16 = "Word16" + show TypeWord32 = "Word32" + show TypeWord64 = "Word64" + show TypeWord128 = "Word128" + +instance Show (SingleFloatingType a) where + show TypeFloat16 = "Float16" + show TypeFloat32 = "Float32" + show TypeFloat64 = "Float64" + show TypeFloat128 = "Float128" instance Show (NumType a) where - show (IntegralNumType ty) = show ty - show (FloatingNumType ty) = show ty - -instance Show (BoundedType a) where - show (IntegralBoundedType ty) = show ty + show (IntegralNumType t) = show t + show (FloatingNumType t) = show t -instance Show (SingleType a) where - show (NumSingleType ty) = show ty - -instance Show (VectorType a) where - show (VectorType n ty) = printf "<%d x %s>" n (show ty) +instance Show (BitType t) where + show (TypeBit) = "Bit" + show (TypeMask n) = printf "<%d x Bit>" (natVal' n) instance Show (ScalarType a) where - show (SingleScalarType ty) = show ty - show (VectorScalarType ty) = show ty + show (NumScalarType t) = show t + show (BitScalarType t) = show t formatIntegralType :: Format r (IntegralType a -> r) formatIntegralType = later $ \case - TypeInt -> "Int" - TypeInt8 -> "Int8" - TypeInt16 -> "Int16" - TypeInt32 -> "Int32" - TypeInt64 -> "Int64" - TypeWord -> "Word" - TypeWord8 -> "Word8" - TypeWord16 -> "Word16" - TypeWord32 -> "Word32" - TypeWord64 -> "Word64" + SingleIntegralType t -> bformat formatSingleIntegralType t + VectorIntegralType n t -> bformat (angled (int % " x " % formatSingleIntegralType)) (natVal' n) t + +formatSingleIntegralType :: Format r (SingleIntegralType a -> r) +formatSingleIntegralType = later $ \case + TypeInt8 -> "Int8" + TypeInt16 -> "Int16" + TypeInt32 -> "Int32" + TypeInt64 -> "Int64" + TypeInt128 -> "Int128" + TypeWord8 -> "Word8" + TypeWord16 -> "Word16" + TypeWord32 -> "Word32" + TypeWord64 -> "Word64" + TypeWord128 -> "Word128" formatFloatingType :: Format r (FloatingType a -> r) formatFloatingType = later $ \case - TypeHalf -> "Half" - TypeFloat -> "Float" - TypeDouble -> "Double" + SingleFloatingType t -> bformat formatSingleFloatingType t + VectorFloatingType n t -> bformat (angled (int % " x " % formatSingleFloatingType)) (natVal' n) t + +formatSingleFloatingType :: Format r (SingleFloatingType a -> r) +formatSingleFloatingType = later $ \case + TypeFloat16 -> "Float16" + TypeFloat32 -> "Float32" + TypeFloat64 -> "Float64" + TypeFloat128 -> "Float128" formatNumType :: Format r (NumType a -> r) formatNumType = later $ \case - IntegralNumType ty -> bformat formatIntegralType ty - FloatingNumType ty -> bformat formatFloatingType ty + IntegralNumType t -> bformat formatIntegralType t + FloatingNumType t -> bformat formatFloatingType t -formatBoundedType :: Format r (BoundedType a -> r) -formatBoundedType = later $ \case - IntegralBoundedType ty -> bformat formatIntegralType ty - -formatSingleType :: Format r (SingleType a -> r) -formatSingleType = later $ \case - NumSingleType ty -> bformat formatNumType ty - -formatVectorType :: Format r (VectorType a -> r) -formatVectorType = later $ \case - VectorType n ty -> bformat (angled (int % " x " % formatSingleType)) n ty +formatBitType :: Format r (BitType t -> r) +formatBitType = later $ \case + TypeBit -> "Bit" + TypeMask n -> bformat (angled (int % " x Bit")) (natVal' n) formatScalarType :: Format r (ScalarType a -> r) formatScalarType = later $ \case - SingleScalarType ty -> bformat formatSingleType ty - VectorScalarType ty -> bformat formatVectorType ty - - --- | Querying Integral types --- -class (IsSingle a, IsNum a, IsBounded a) => IsIntegral a where - integralType :: IntegralType a - --- | Querying Floating types --- -class (Floating a, IsSingle a, IsNum a) => IsFloating a where - floatingType :: FloatingType a - --- | Querying Numeric types --- -class (Num a, IsSingle a) => IsNum a where - numType :: NumType a - --- | Querying Bounded types --- -class IsBounded a where - boundedType :: BoundedType a - --- | Querying single value types --- -class IsScalar a => IsSingle a where - singleType :: SingleType a - --- | Querying all scalar types --- -class IsScalar a where - scalarType :: ScalarType a + NumScalarType t -> bformat formatNumType t + BitScalarType t -> bformat formatBitType t -integralDict :: IntegralType a -> IntegralDict a -integralDict TypeInt = IntegralDict -integralDict TypeInt8 = IntegralDict -integralDict TypeInt16 = IntegralDict -integralDict TypeInt32 = IntegralDict -integralDict TypeInt64 = IntegralDict -integralDict TypeWord = IntegralDict -integralDict TypeWord8 = IntegralDict -integralDict TypeWord16 = IntegralDict -integralDict TypeWord32 = IntegralDict -integralDict TypeWord64 = IntegralDict - -floatingDict :: FloatingType a -> FloatingDict a -floatingDict TypeHalf = FloatingDict -floatingDict TypeFloat = FloatingDict -floatingDict TypeDouble = FloatingDict - -singleDict :: SingleType a -> SingleDict a -singleDict = single - where - single :: SingleType a -> SingleDict a - single (NumSingleType t) = num t - - num :: NumType a -> SingleDict a - num (IntegralNumType t) = integral t - num (FloatingNumType t) = floating t - - integral :: IntegralType a -> SingleDict a - integral TypeInt = SingleDict - integral TypeInt8 = SingleDict - integral TypeInt16 = SingleDict - integral TypeInt32 = SingleDict - integral TypeInt64 = SingleDict - integral TypeWord = SingleDict - integral TypeWord8 = SingleDict - integral TypeWord16 = SingleDict - integral TypeWord32 = SingleDict - integral TypeWord64 = SingleDict - - floating :: FloatingType a -> SingleDict a - floating TypeHalf = SingleDict - floating TypeFloat = SingleDict - floating TypeDouble = SingleDict - - -scalarTypeInt :: ScalarType Int -scalarTypeInt = SingleScalarType $ NumSingleType $ IntegralNumType TypeInt - -scalarTypeWord :: ScalarType Word -scalarTypeWord = SingleScalarType $ NumSingleType $ IntegralNumType TypeWord - -scalarTypeInt32 :: ScalarType Int32 -scalarTypeInt32 = SingleScalarType $ NumSingleType $ IntegralNumType TypeInt32 - -scalarTypeWord8 :: ScalarType Word8 -scalarTypeWord8 = SingleScalarType $ NumSingleType $ IntegralNumType TypeWord8 - -scalarTypeWord32 :: ScalarType Word32 -scalarTypeWord32 = SingleScalarType $ NumSingleType $ IntegralNumType TypeWord32 - rnfScalarType :: ScalarType t -> () -rnfScalarType (SingleScalarType t) = rnfSingleType t -rnfScalarType (VectorScalarType t) = rnfVectorType t - -rnfSingleType :: SingleType t -> () -rnfSingleType (NumSingleType t) = rnfNumType t - -rnfVectorType :: VectorType t -> () -rnfVectorType (VectorType !_ t) = rnfSingleType t +rnfScalarType (NumScalarType t) = rnfNumType t +rnfScalarType (BitScalarType t) = rnfBitType t -rnfBoundedType :: BoundedType t -> () -rnfBoundedType (IntegralBoundedType t) = rnfIntegralType t +rnfBitType :: BitType t -> () +rnfBitType TypeBit = () +rnfBitType (TypeMask !_) = () rnfNumType :: NumType t -> () rnfNumType (IntegralNumType t) = rnfIntegralType t rnfNumType (FloatingNumType t) = rnfFloatingType t rnfIntegralType :: IntegralType t -> () -rnfIntegralType TypeInt = () -rnfIntegralType TypeInt8 = () -rnfIntegralType TypeInt16 = () -rnfIntegralType TypeInt32 = () -rnfIntegralType TypeInt64 = () -rnfIntegralType TypeWord = () -rnfIntegralType TypeWord8 = () -rnfIntegralType TypeWord16 = () -rnfIntegralType TypeWord32 = () -rnfIntegralType TypeWord64 = () +rnfIntegralType (SingleIntegralType t) = rnfSingleIntegralType t +rnfIntegralType (VectorIntegralType !_ t) = rnfSingleIntegralType t + +rnfSingleIntegralType :: SingleIntegralType t -> () +rnfSingleIntegralType TypeInt8 = () +rnfSingleIntegralType TypeInt16 = () +rnfSingleIntegralType TypeInt32 = () +rnfSingleIntegralType TypeInt64 = () +rnfSingleIntegralType TypeInt128 = () +rnfSingleIntegralType TypeWord8 = () +rnfSingleIntegralType TypeWord16 = () +rnfSingleIntegralType TypeWord32 = () +rnfSingleIntegralType TypeWord64 = () +rnfSingleIntegralType TypeWord128 = () rnfFloatingType :: FloatingType t -> () -rnfFloatingType TypeHalf = () -rnfFloatingType TypeFloat = () -rnfFloatingType TypeDouble = () +rnfFloatingType (SingleFloatingType t) = rnfSingleFloatingType t +rnfFloatingType (VectorFloatingType !_ t) = rnfSingleFloatingType t +rnfSingleFloatingType :: SingleFloatingType t -> () +rnfSingleFloatingType TypeFloat16 = () +rnfSingleFloatingType TypeFloat32 = () +rnfSingleFloatingType TypeFloat64 = () +rnfSingleFloatingType TypeFloat128 = () -liftScalar :: ScalarType t -> t -> CodeQ t -liftScalar (SingleScalarType t) = liftSingle t -liftScalar (VectorScalarType t) = liftVector t -liftSingle :: SingleType t -> t -> CodeQ t -liftSingle (NumSingleType t) = liftNum t +liftScalar :: ScalarType t -> t -> CodeQ t +liftScalar (NumScalarType t) = liftNum t +liftScalar (BitScalarType t) = liftBit t -liftVector :: VectorType t -> t -> CodeQ t -liftVector VectorType{} = liftVec +liftBit :: BitType t -> t -> CodeQ t +liftBit TypeBit (Bit x) = [|| Bit x ||] +liftBit TypeMask{} x = liftVec x liftNum :: NumType t -> t -> CodeQ t liftNum (IntegralNumType t) = liftIntegral t liftNum (FloatingNumType t) = liftFloating t liftIntegral :: IntegralType t -> t -> CodeQ t -liftIntegral TypeInt x = [|| x ||] -liftIntegral TypeInt8 x = [|| x ||] -liftIntegral TypeInt16 x = [|| x ||] -liftIntegral TypeInt32 x = [|| x ||] -liftIntegral TypeInt64 x = [|| x ||] -liftIntegral TypeWord x = [|| x ||] -liftIntegral TypeWord8 x = [|| x ||] -liftIntegral TypeWord16 x = [|| x ||] -liftIntegral TypeWord32 x = [|| x ||] -liftIntegral TypeWord64 x = [|| x ||] +liftIntegral (SingleIntegralType t) = liftSingleIntegral t +liftIntegral (VectorIntegralType _ _) = liftVec + +liftSingleIntegral :: SingleIntegralType t -> t -> CodeQ t +liftSingleIntegral TypeInt8 x = [|| x ||] +liftSingleIntegral TypeInt16 x = [|| x ||] +liftSingleIntegral TypeInt32 x = [|| x ||] +liftSingleIntegral TypeInt64 x = [|| x ||] +liftSingleIntegral TypeWord8 x = [|| x ||] +liftSingleIntegral TypeWord16 x = [|| x ||] +liftSingleIntegral TypeWord32 x = [|| x ||] +liftSingleIntegral TypeWord64 x = [|| x ||] +liftSingleIntegral TypeInt128 (Int128 x y) = [|| Int128 x y ||] +liftSingleIntegral TypeWord128 (Word128 x y) = [|| Word128 x y ||] liftFloating :: FloatingType t -> t -> CodeQ t -liftFloating TypeHalf x = [|| x ||] -liftFloating TypeFloat x = [|| x ||] -liftFloating TypeDouble x = [|| x ||] +liftFloating (SingleFloatingType t) = liftSingleFloating t +liftFloating (VectorFloatingType _ _) = liftVec +liftSingleFloating :: SingleFloatingType t -> t -> CodeQ t +liftSingleFloating TypeFloat16 x = [|| x ||] +liftSingleFloating TypeFloat32 x = [|| x ||] +liftSingleFloating TypeFloat64 x = [|| x ||] +liftSingleFloating TypeFloat128 (Float128 x y) = [|| Float128 x y ||] -liftScalarType :: ScalarType t -> CodeQ (ScalarType t) -liftScalarType (SingleScalarType t) = [|| SingleScalarType $$(liftSingleType t) ||] -liftScalarType (VectorScalarType t) = [|| VectorScalarType $$(liftVectorType t) ||] -liftSingleType :: SingleType t -> CodeQ (SingleType t) -liftSingleType (NumSingleType t) = [|| NumSingleType $$(liftNumType t) ||] +liftScalarType :: ScalarType t -> CodeQ (ScalarType t) +liftScalarType (NumScalarType t) = [|| NumScalarType $$(liftNumType t) ||] +liftScalarType (BitScalarType t) = [|| BitScalarType $$(liftBitType t) ||] -liftVectorType :: VectorType t -> CodeQ (VectorType t) -liftVectorType (VectorType n t) = [|| VectorType n $$(liftSingleType t) ||] +liftBitType :: BitType t -> CodeQ (BitType t) +liftBitType TypeBit = [|| TypeBit ||] +liftBitType TypeMask{} = [|| TypeMask proxy# ||] liftNumType :: NumType t -> CodeQ (NumType t) liftNumType (IntegralNumType t) = [|| IntegralNumType $$(liftIntegralType t) ||] liftNumType (FloatingNumType t) = [|| FloatingNumType $$(liftFloatingType t) ||] -liftBoundedType :: BoundedType t -> CodeQ (BoundedType t) -liftBoundedType (IntegralBoundedType t) = [|| IntegralBoundedType $$(liftIntegralType t) ||] - liftIntegralType :: IntegralType t -> CodeQ (IntegralType t) -liftIntegralType TypeInt = [|| TypeInt ||] -liftIntegralType TypeInt8 = [|| TypeInt8 ||] -liftIntegralType TypeInt16 = [|| TypeInt16 ||] -liftIntegralType TypeInt32 = [|| TypeInt32 ||] -liftIntegralType TypeInt64 = [|| TypeInt64 ||] -liftIntegralType TypeWord = [|| TypeWord ||] -liftIntegralType TypeWord8 = [|| TypeWord8 ||] -liftIntegralType TypeWord16 = [|| TypeWord16 ||] -liftIntegralType TypeWord32 = [|| TypeWord32 ||] -liftIntegralType TypeWord64 = [|| TypeWord64 ||] +liftIntegralType (SingleIntegralType t) = [|| SingleIntegralType $$(liftSingleIntegralType t) ||] +liftIntegralType (VectorIntegralType _ t) = [|| VectorIntegralType proxy# $$(liftSingleIntegralType t) ||] + +liftSingleIntegralType :: SingleIntegralType t -> CodeQ (SingleIntegralType t) +liftSingleIntegralType TypeInt8 = [|| TypeInt8 ||] +liftSingleIntegralType TypeInt16 = [|| TypeInt16 ||] +liftSingleIntegralType TypeInt32 = [|| TypeInt32 ||] +liftSingleIntegralType TypeInt64 = [|| TypeInt64 ||] +liftSingleIntegralType TypeInt128 = [|| TypeInt128 ||] +liftSingleIntegralType TypeWord8 = [|| TypeWord8 ||] +liftSingleIntegralType TypeWord16 = [|| TypeWord16 ||] +liftSingleIntegralType TypeWord32 = [|| TypeWord32 ||] +liftSingleIntegralType TypeWord64 = [|| TypeWord64 ||] +liftSingleIntegralType TypeWord128 = [|| TypeWord128 ||] liftFloatingType :: FloatingType t -> CodeQ (FloatingType t) -liftFloatingType TypeHalf = [|| TypeHalf ||] -liftFloatingType TypeFloat = [|| TypeFloat ||] -liftFloatingType TypeDouble = [|| TypeDouble ||] +liftFloatingType (SingleFloatingType t) = [|| SingleFloatingType $$(liftSingleFloatingType t) ||] +liftFloatingType (VectorFloatingType _ t) = [|| VectorFloatingType proxy# $$(liftSingleFloatingType t) ||] + +liftSingleFloatingType :: SingleFloatingType t -> CodeQ (SingleFloatingType t) +liftSingleFloatingType TypeFloat16 = [|| TypeFloat16 ||] +liftSingleFloatingType TypeFloat32 = [|| TypeFloat32 ||] +liftSingleFloatingType TypeFloat64 = [|| TypeFloat64 ||] +liftSingleFloatingType TypeFloat128 = [|| TypeFloat128 ||] + + +-- Querying types +-- -------------- + +class IsScalar a where + scalarType :: ScalarType a + +class IsBit a where + bitType :: BitType a + +class IsNum a where + numType :: NumType a + +class IsIntegral a where + integralType :: IntegralType a + +class IsFloating a where + floatingType :: FloatingType a + +class IsSingleIntegral a where + singleIntegralType :: SingleIntegralType a +class IsSingleFloating a where + singleFloatingType :: SingleFloatingType a -- Type-level bit sizes -- -------------------- @@ -440,81 +374,114 @@ type family BitSize a :: Nat runQ $ do let - bits :: FiniteBits b => b -> Integer - bits = toInteger . finiteBitSize - - integralTypes :: [(Name, Integer)] - integralTypes = - [ (''Int, bits (undefined::Int)) - , (''Int8, 8) - , (''Int16, 16) - , (''Int32, 32) - , (''Int64, 64) - , (''Word, bits (undefined::Word)) - , (''Word8, 8) - , (''Word16, 16) - , (''Word32, 32) - , (''Word64, 64) - ] + integralTypes :: [Integer] + integralTypes = [8,16,32,64,128] floatingTypes :: [(Name, Integer)] floatingTypes = - [ (''Half, 16) - , (''Float, 32) - , (''Double, 64) + [ (''Half, 16) + , (''Float, 32) + , (''Double, 64) + , (''Float128, 128) ] - vectorTypes :: [(Name, Integer)] - vectorTypes = integralTypes ++ floatingTypes + mkIntegral :: String -> Integer -> Q [Dec] + mkIntegral name bits = + let t = conT $ mkName $ printf "%s%d" name bits + c = conE $ mkName $ printf "Type%s%d" name bits + in + [d| instance IsScalar $t where + scalarType = NumScalarType numType - mkIntegral :: Name -> Integer -> Q [Dec] - mkIntegral t n = - [d| instance IsIntegral $(conT t) where - integralType = $(conE (mkName ("Type" ++ nameBase t))) - - instance IsNum $(conT t) where + instance IsNum $t where numType = IntegralNumType integralType - instance IsBounded $(conT t) where - boundedType = IntegralBoundedType integralType + instance IsIntegral $t where + integralType = SingleIntegralType singleIntegralType + + instance IsSingleIntegral $t where + singleIntegralType = $c - instance IsSingle $(conT t) where - singleType = NumSingleType numType + instance KnownNat n => IsIntegral (Vec n $t) where + integralType = VectorIntegralType proxy# $c - instance IsScalar $(conT t) where - scalarType = SingleScalarType singleType + instance KnownNat n => IsNum (Vec n $t) where + numType = IntegralNumType integralType + + instance KnownNat n => IsScalar (Vec n $t) where + scalarType = NumScalarType numType - type instance BitSize $(conT t) = $(litT (numTyLit n)) + type instance BitSize $t = $(litT (numTyLit bits)) + type instance BitSize (Vec n $t) = n GHC.TypeLits.* $(litT (numTyLit bits)) |] mkFloating :: Name -> Integer -> Q [Dec] - mkFloating t n = - [d| instance IsFloating $(conT t) where - floatingType = $(conE (mkName ("Type" ++ nameBase t))) - - instance IsNum $(conT t) where + mkFloating name bits = + let t = conT name + c = conE $ mkName $ printf "TypeFloat%d" bits + in + [d| instance IsScalar $t where + scalarType = NumScalarType numType + + instance IsNum $t where numType = FloatingNumType floatingType - instance IsSingle $(conT t) where - singleType = NumSingleType numType + instance IsFloating $t where + floatingType = SingleFloatingType singleFloatingType - instance IsScalar $(conT t) where - scalarType = SingleScalarType singleType + instance IsSingleFloating $t where + singleFloatingType = $c - type instance BitSize $(conT t) = $(litT (numTyLit n)) - |] + instance KnownNat n => IsFloating (Vec n $t) where + floatingType = VectorFloatingType proxy# $c + + instance KnownNat n => IsNum (Vec n $t) where + numType = FloatingNumType floatingType - mkVector :: Name -> Integer -> Q [Dec] - mkVector t n = - [d| instance KnownNat n => IsScalar (Vec n $(conT t)) where - scalarType = VectorScalarType (VectorType (fromIntegral (natVal' (proxy# :: Proxy# n))) singleType) + instance KnownNat n => IsScalar (Vec n $t) where + scalarType = NumScalarType numType - type instance BitSize (Vec w $(conT t)) = w GHC.TypeLits.* $(litT (numTyLit n)) + type instance BitSize $t = $(litT (numTyLit bits)) + type instance BitSize (Vec n $t) = n GHC.TypeLits.* $(litT (numTyLit bits)) |] - -- - is <- mapM (uncurry mkIntegral) integralTypes + + ss <- mapM (mkIntegral "Int") integralTypes + us <- mapM (mkIntegral "Word") integralTypes fs <- mapM (uncurry mkFloating) floatingTypes - vs <- mapM (uncurry mkVector) vectorTypes -- - return (concat is ++ concat fs ++ concat vs) + return (concat ss ++ concat us ++ concat fs) + +type instance BitSize Bit = 1 +type instance BitSize (Vec n Bit) = n + +instance IsScalar Bit where + scalarType = BitScalarType bitType + +instance KnownNat n => IsScalar (Vec n Bit) where + scalarType = BitScalarType bitType + +instance IsBit Bit where + bitType = TypeBit + +instance KnownNat n => IsBit (Vec n Bit) where + bitType = TypeMask proxy# + + +-- Determine the underlying type of a Haskell Int and Word +-- +runQ [d| type INT = $( + case finiteBitSize (undefined::Int) of + 8 -> [t| Int8 |] + 16 -> [t| Int16 |] + 32 -> [t| Int32 |] + 64 -> [t| Int64 |] + _ -> error "I don't know what architecture I am" ) |] + +runQ [d| type WORD = $( + case finiteBitSize (undefined::Word) of + 8 -> [t| Word8 |] + 16 -> [t| Word16 |] + 32 -> [t| Word32 |] + 64 -> [t| Word64 |] + _ -> error "I don't know what architecture I am" ) |] diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 5b770f0ac..2fefd2185 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -1,25 +1,21 @@ -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE FunctionalDependencies #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE MagicHash #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RoleAnnotations #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE UnboxedTuples #-} -{-# LANGUAGE ViewPatterns #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TupleSections #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UnboxedTuples #-} +{-# LANGUAGE ViewPatterns #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Primitive.Vec --- Copyright : [2008..2020] The Accelerate Team +-- Copyright : [2008..2022] The Accelerate Team -- License : BSD3 -- -- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> @@ -30,34 +26,37 @@ module Data.Primitive.Vec ( -- * SIMD vector types - Vec(..), + Vec(..), KnownNat, Vec2, pattern Vec2, Vec3, pattern Vec3, Vec4, pattern Vec4, Vec8, pattern Vec8, Vec16, pattern Vec16, - listOfVec, - vecOfList, + toList, fromList, + extract, insert, splat, liftVec, - Vectoring(..) ) where -import Data.Kind -import Data.Proxy +import Data.Array.Accelerate.Error + import Control.Monad.ST -import Control.Monad.Reader import Data.Primitive.ByteArray import Data.Primitive.Types import Language.Haskell.TH.Extra import Prettyprinter +import Prelude + +import qualified Foreign.Storable as Foreign import GHC.Base ( isTrue# ) import GHC.Int import GHC.Prim import GHC.TypeLits +import GHC.Types ( IO(..) ) import GHC.Word +import qualified GHC.Exts as GHC -- Note: [Representing SIMD vector types] @@ -94,66 +93,102 @@ import GHC.Word -- data Vec (n :: Nat) a = Vec ByteArray# -class Vectoring vector a | vector -> a where - type IndexType vector :: Data.Kind.Type - vecIndex :: vector -> IndexType vector -> a - vecWrite :: vector -> IndexType vector -> a -> vector - vecEmpty :: vector - -instance (KnownNat n, Prim a) => Vectoring (Vec n a) a where - type IndexType (Vec n a) = Int - vecIndex (Vec ba#) i@(I# iu#) = let - n :: Int - n = fromIntegral $ natVal $ Proxy @n - in if i >= 0 && i < n then indexByteArray# ba# iu# else error ("index " <> show i <> " out of range in Vec of size " <> show n) - vecWrite vec@(Vec ba#) i@(I# iu#) v = runST $ do - let n :: Int - n = fromIntegral $ natVal $ Proxy @n - mba <- unsafeThawByteArray (ByteArray ba#) - writeByteArray mba i v - ByteArray nba# <- unsafeFreezeByteArray mba - return $! Vec nba# - vecEmpty = mkVec - - -mkVec :: forall n a. (KnownNat n, Prim a) => Vec n a -mkVec = runST $ do - let n :: Int = fromIntegral $ natVal $ Proxy @n - mba <- newByteArray (n * sizeOf (undefined :: a)) - ByteArray ba# <- unsafeFreezeByteArray mba - return $! Vec ba# - - type role Vec nominal representational instance (Show a, Prim a, KnownNat n) => Show (Vec n a) where - show = vec . listOfVec + show = vec . toList where vec :: [a] -> String vec = show . group . encloseSep (flatAlt "< " "<") (flatAlt " >" ">") ", " . map viaShow -vecOfList :: forall n a. (KnownNat n, Prim a) => [a] -> Vec n a -vecOfList vs = runST $ do - let n :: Int = fromIntegral $ natVal $ Proxy @n - mba <- newByteArray (n * sizeOf (undefined :: a)) - zipWithM_ (writeByteArray mba) [0..n-1] vs - ByteArray ba# <- unsafeFreezeByteArray mba - return $! Vec ba# - -listOfVec :: forall a n. (Prim a, KnownNat n) => Vec n a -> [a] -listOfVec (Vec ba#) = go 0# +instance (Prim a, KnownNat n) => GHC.IsList (Vec n a) where + type Item (Vec n a) = a + {-# INLINE toList #-} + {-# INLINE fromList #-} + toList = toList + fromList = fromList + +instance (Foreign.Storable a, KnownNat n) => Foreign.Storable (Vec n a) where + {-# INLINE sizeOf #-} + {-# INLINE alignment #-} + {-# INLINE peek #-} + {-# INLINE poke #-} + + sizeOf _ = fromInteger (natVal' (proxy# @n)) * Foreign.sizeOf (undefined :: a) + alignment _ = Foreign.alignment (undefined :: a) + + peek (Ptr addr#) = + IO $ \s0 -> + case Foreign.sizeOf (undefined :: Vec n a) of { I# bytes# -> + case newAlignedPinnedByteArray# bytes# 16# s0 of { (# s1, mba# #) -> + case copyAddrToByteArray# addr# mba# 0# bytes# s1 of { s2 -> + case unsafeFreezeByteArray# mba# s2 of { (# s3, ba# #) -> + (# s3, Vec ba# #) + }}}} + + poke (Ptr addr#) (Vec ba#) = + IO $ \s0 -> + case Foreign.sizeOf (undefined :: Vec n a) of { I# bytes# -> + case copyByteArrayToAddr# ba# 0# addr# bytes# s0 of { + s1 -> (# s1, () #) + }} + +{-# INLINE toList #-} +toList :: forall a n. (Prim a, KnownNat n) => Vec n a -> [a] +toList (Vec ba#) = go 0# where go :: Int# -> [a] go i# | isTrue# (i# <# n#) = indexByteArray# ba# i# : go (i# +# 1#) | otherwise = [] + -- + !(I# n#) = fromInteger (natVal' (proxy# :: Proxy# n)) - !(I# n#) = fromIntegral (natVal' (proxy# :: Proxy# n)) +{-# INLINE fromList #-} +fromList :: forall a n. (Prim a, KnownNat n) => [a] -> Vec n a +fromList xs = + case byteArrayFromListN (fromInteger (natVal' (proxy# :: Proxy# n))) xs of + ByteArray ba# -> Vec ba# instance Eq (Vec n a) where Vec ba1# == Vec ba2# = ByteArray ba1# == ByteArray ba2# +-- | Extract an element from a vector at the given index +-- +{-# INLINE extract #-} +extract :: forall n a. (Prim a, KnownNat n) => Vec n a -> Int -> a +extract (Vec ba#) i@(I# i#) = + let n = fromInteger (natVal' (proxy# :: Proxy# n)) + in boundsCheck "out of range" (i >= 0 && i < n) $ indexByteArray# ba# i# + +-- | Returns a new vector where the element at the specified index has been +-- replaced with the supplied value. +-- +{-# INLINE insert #-} +insert :: forall n a. (Prim a, KnownNat n) => Vec n a -> Int -> a -> Vec n a +insert (Vec ba#) i a = + let n = fromInteger (natVal' (proxy# :: Proxy# n)) + bytes = n * sizeOf (undefined :: a) + in boundsCheck "out of range" (i >= 0 && i < n) + $ runST $ do + mba <- newByteArray bytes + copyByteArray mba 0 (ByteArray ba#) 0 bytes + writeByteArray mba i a + ByteArray ba'# <- unsafeFreezeByteArray mba + return $! Vec ba'# + +-- | Fill all lanes of a vector with the same value +-- +{-# INLINE splat #-} +splat :: forall n a. (Prim a, KnownNat n) => a -> Vec n a +splat x = runST $ do + let n = fromInteger (natVal' (proxy# :: Proxy# n)) + mba <- newByteArray (n * sizeOf (undefined :: a)) + setByteArray mba 0 n x + ByteArray ba# <- unsafeFreezeByteArray mba + return $! Vec ba# + -- Type synonyms for common SIMD vector types -- -- Note that non-power-of-two sized SIMD vectors are a bit dubious, and diff --git a/src/GHC/TypeLits/Extra.hs b/src/GHC/TypeLits/Extra.hs new file mode 100644 index 000000000..9d1f8dfa2 --- /dev/null +++ b/src/GHC/TypeLits/Extra.hs @@ -0,0 +1,31 @@ +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE TypeOperators #-} +-- | +-- Module : GHC.TypeLits.Extra +-- Copyright : [2012..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module GHC.TypeLits.Extra + where + +import Data.Typeable + +import GHC.Exts +import GHC.TypeLits +import Unsafe.Coerce + + +-- | We either get evidence that this function was instantiated with the same +-- type-level numbers, or 'Nothing'. +-- +{-# INLINEABLE sameNat' #-} +sameNat' :: (KnownNat n, KnownNat m) => Proxy# n -> Proxy# m -> Maybe (n :~: m) +sameNat' n m + | natVal' n == natVal' m = Just (unsafeCoerce Refl) -- same as 'GHC.TypeLits.sameNat' but for 'Proxy#' + | otherwise = Nothing + diff --git a/stack-9.2.yaml b/stack-9.2.yaml index 422f6750f..d0de63bce 100644 --- a/stack-9.2.yaml +++ b/stack-9.2.yaml @@ -2,8 +2,7 @@ # For advanced use and comprehensive documentation of the format, please see: # https://docs.haskellstack.org/en/stable/yaml_configuration/ -compiler: ghc-9.2.2 -resolver: nightly-2022-03-10 +resolver: nightly-2022-06-10 packages: - . @@ -11,7 +10,11 @@ packages: # extra-deps: [] # Override default flag values for local packages and extra-deps -# flags: {} +flags: + accelerate: + nofib: true + # debug: true + # float128: true # Extra package databases containing global packages # extra-package-dbs: [] From 30fa47e3d688e87b204bba4ff3a7403476f27661 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 13 Jun 2022 15:33:09 +0200 Subject: [PATCH 18/86] stack/9.2: drop non-default flags --- stack-9.2.yaml | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/stack-9.2.yaml b/stack-9.2.yaml index d0de63bce..0fe59daa6 100644 --- a/stack-9.2.yaml +++ b/stack-9.2.yaml @@ -10,11 +10,7 @@ packages: # extra-deps: [] # Override default flag values for local packages and extra-deps -flags: - accelerate: - nofib: true - # debug: true - # float128: true +# flags: {} # Extra package databases containing global packages # extra-package-dbs: [] From 361582fd663369657164edb853841dfcd329e6f6 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 13 Jun 2022 16:35:47 +0200 Subject: [PATCH 19/86] stack/8.10: update resolver --- stack-8.10.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stack-8.10.yaml b/stack-8.10.yaml index a7a50be2d..99c7ebd8e 100644 --- a/stack-8.10.yaml +++ b/stack-8.10.yaml @@ -2,7 +2,7 @@ # For advanced use and comprehensive documentation of the format, please see: # https://docs.haskellstack.org/en/stable/yaml_configuration/ -resolver: lts-18.27 +resolver: lts-18.28 packages: - . From cdbeb0a8b2d2c306a9592f907a9dd5fd20aab7e9 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 13 Jun 2022 16:38:55 +0200 Subject: [PATCH 20/86] stack/9.0: update resolver --- stack-9.0.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stack-9.0.yaml b/stack-9.0.yaml index 19296b890..e94584e09 100644 --- a/stack-9.0.yaml +++ b/stack-9.0.yaml @@ -2,7 +2,7 @@ # For advanced use and comprehensive documentation of the format, please see: # https://docs.haskellstack.org/en/stable/yaml_configuration/ -resolver: nightly-2022-03-10 +resolver: lts-19.11 packages: - . From 9c892b197c77b5789371b37e897bde74385c5efa Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 13 Jun 2022 16:44:38 +0200 Subject: [PATCH 21/86] build fixes --- src/Data/Array/Accelerate/Classes/Eq.hs | 3 ++- src/Data/Array/Accelerate/Classes/Ord.hs | 3 ++- src/Data/Array/Accelerate/Data/Complex.hs | 12 +++++++++--- src/Data/Array/Accelerate/Test/NoFib/Base.hs | 1 - src/Data/Primitive/Bit.hs | 2 +- src/Data/Primitive/Vec.hs | 7 ++++++- 6 files changed, 20 insertions(+), 8 deletions(-) diff --git a/src/Data/Array/Accelerate/Classes/Eq.hs b/src/Data/Array/Accelerate/Classes/Eq.hs index 4f770a340..e6d8eb0d9 100644 --- a/src/Data/Array/Accelerate/Classes/Eq.hs +++ b/src/Data/Array/Accelerate/Classes/Eq.hs @@ -220,7 +220,8 @@ vcmp :: forall n a. KnownNat n => (Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool)) -> (Exp (Vec n a) -> Exp (Vec n a) -> Exp Bool) vcmp op x y = - let n = fromInteger $ natVal' (proxy# :: Proxy# n) + let n :: Int + n = fromInteger $ natVal' (proxy# :: Proxy# n) v = op x y -- cmp :: forall t. (Elt t, Num t, Bits t, IsScalar (EltR t), IsIntegral (EltR t), BitOrMask (EltR t) ~ Bit) diff --git a/src/Data/Array/Accelerate/Classes/Ord.hs b/src/Data/Array/Accelerate/Classes/Ord.hs index 477e1adde..c0c3581e5 100644 --- a/src/Data/Array/Accelerate/Classes/Ord.hs +++ b/src/Data/Array/Accelerate/Classes/Ord.hs @@ -249,7 +249,8 @@ vcmp :: forall n a. KnownNat n => (Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool)) -> (Exp (Vec n a) -> Exp (Vec n a) -> Exp Bool) vcmp op x y = - let n = fromInteger $ natVal' (proxy# :: Proxy# n) + let n :: Int + n = fromInteger $ natVal' (proxy# :: Proxy# n) v = op x y -- cmp :: forall t. (Elt t, Num t, Bits t, IsScalar (EltR t), IsIntegral (EltR t), BitOrMask (EltR t) ~ Bit) diff --git a/src/Data/Array/Accelerate/Data/Complex.hs b/src/Data/Array/Accelerate/Data/Complex.hs index c64cef772..ee1cee8fc 100644 --- a/src/Data/Array/Accelerate/Data/Complex.hs +++ b/src/Data/Array/Accelerate/Data/Complex.hs @@ -45,6 +45,7 @@ module Data.Array.Accelerate.Data.Complex ( ) where import Data.Array.Accelerate.AST ( PrimFun(..) ) +import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.Classes.Eq import Data.Array.Accelerate.Classes.Floating import Data.Array.Accelerate.Classes.Fractional @@ -216,10 +217,15 @@ constructComplex r@(Exp r') i@(Exp i') = (SmartExp (Const scalarType 1)) y) deconstructComplex :: forall a. Elt a => Exp (Complex a) -> (Exp a, Exp a) -deconstructComplex c@(Exp c') = +deconstructComplex (Exp c) = case complexR (eltR @a) of - ComplexTup -> let Pattern (r,i) = c in (r, i) - ComplexVec t -> let (r', i') = num t c' in (Exp r', Exp i') + ComplexTup -> + let i = SmartExp (Prj PairIdxRight c) + r = SmartExp (Prj PairIdxRight (SmartExp (Prj PairIdxLeft c))) + in (Exp r, Exp i) + ComplexVec t -> + let (r, i) = num t c + in (Exp r, Exp i) where num :: NumType (Prim.Vec2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) num (IntegralNumType t) = integral t diff --git a/src/Data/Array/Accelerate/Test/NoFib/Base.hs b/src/Data/Array/Accelerate/Test/NoFib/Base.hs index 31061c664..2475cea1a 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Base.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Base.hs @@ -24,7 +24,6 @@ import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Sugar.Shape import Data.Array.Accelerate.Trafo.Sharing import Control.Monad -import Data.Primitive.Types import Hedgehog import qualified Hedgehog.Gen as Gen diff --git a/src/Data/Primitive/Bit.hs b/src/Data/Primitive/Bit.hs index 6c71e5d6b..acfbc75d2 100644 --- a/src/Data/Primitive/Bit.hs +++ b/src/Data/Primitive/Bit.hs @@ -36,13 +36,13 @@ import Control.Exception import qualified Foreign.Storable as Foreign import Data.Primitive.ByteArray -import Data.Primitive.Types import Data.Primitive.Vec ( Vec(..) ) import GHC.Base ( isTrue# ) import GHC.Generics import GHC.Int import GHC.Prim +import GHC.Ptr import GHC.TypeLits import GHC.Types ( IO(..) ) import GHC.Word diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 2fefd2185..f7d26c834 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -1,4 +1,5 @@ {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE KindSignatures #-} @@ -58,6 +59,10 @@ import GHC.Types ( IO(..) ) import GHC.Word import qualified GHC.Exts as GHC +#if __GLASGOW_HASKELL__ < 808 +import GHC.Ptr +#endif + -- Note: [Representing SIMD vector types] -- @@ -116,7 +121,7 @@ instance (Foreign.Storable a, KnownNat n) => Foreign.Storable (Vec n a) where {-# INLINE peek #-} {-# INLINE poke #-} - sizeOf _ = fromInteger (natVal' (proxy# @n)) * Foreign.sizeOf (undefined :: a) + sizeOf _ = fromInteger (natVal' (proxy# :: Proxy# n)) * Foreign.sizeOf (undefined :: a) alignment _ = Foreign.alignment (undefined :: a) peek (Ptr addr#) = From 0b769c68a22d3fe0175335f36ec3e21ee6ee1e47 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 13 Jun 2022 18:58:45 +0200 Subject: [PATCH 22/86] doctest fixes --- .github/workflows/ci-linux.yml | 2 +- src/Data/Array/Accelerate/Sugar/Array.hs | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci-linux.yml b/.github/workflows/ci-linux.yml index b3258a84a..7dd1d722e 100644 --- a/.github/workflows/ci-linux.yml +++ b/.github/workflows/ci-linux.yml @@ -72,7 +72,7 @@ jobs: - name: Test doctest run: stack test accelerate:doctest $STACK_FLAGS - if: ${{ matrix.ghc != '9.2' }} + if: ${{ matrix.ghc != '8.6' && matrix.ghc != '8.8' }} - name: Test nofib run: stack test accelerate:nofib-interpreter $STACK_FLAGS diff --git a/src/Data/Array/Accelerate/Sugar/Array.hs b/src/Data/Array/Accelerate/Sugar/Array.hs index 4e6d9fb21..5d48c245a 100644 --- a/src/Data/Array/Accelerate/Sugar/Array.hs +++ b/src/Data/Array/Accelerate/Sugar/Array.hs @@ -38,6 +38,7 @@ import GHC.Generics import qualified GHC.Exts as GHC -- $setup +-- >>> import Prelude -- >>> :seti -XOverloadedLists From eacdced4aab2b0499fdb8f13b8c6225d8caa2248 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 13 Jun 2022 21:35:57 +0200 Subject: [PATCH 23/86] add pattern synonyms for shape constructors --- accelerate.cabal | 1 + src/Data/Array/Accelerate.hs | 9 +- src/Data/Array/Accelerate/Lift.hs | 22 ++--- src/Data/Array/Accelerate/Pattern.hs | 46 +-------- src/Data/Array/Accelerate/Pattern/Shape.hs | 103 +++++++++++++++++++++ src/Data/Array/Accelerate/Prelude.hs | 79 ++++++++-------- 6 files changed, 164 insertions(+), 96 deletions(-) create mode 100644 src/Data/Array/Accelerate/Pattern/Shape.hs diff --git a/accelerate.cabal b/accelerate.cabal index f120c4a64..5fc01625c 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -486,6 +486,7 @@ library Data.Array.Accelerate.Pattern.Either Data.Array.Accelerate.Pattern.Maybe Data.Array.Accelerate.Pattern.Ordering + Data.Array.Accelerate.Pattern.Shape Data.Array.Accelerate.Pattern.TH Data.Array.Accelerate.Prelude Data.Array.Accelerate.Pretty.Graphviz diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index 22afa2a6b..958088e19 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -179,9 +179,9 @@ module Data.Array.Accelerate ( -- *** Array shapes & indices -- $shapes_and_indices -- - Z(..), (:.)(..), + Z, (:.), DIM0, DIM1, DIM2, DIM3, DIM4, DIM5, DIM6, DIM7, DIM8, DIM9, - Shape, Slice(..), All(..), Any(..), + Shape, Slice(..), All, Any, -- Split(..), Divide(..), Division(..), -- ** Array access @@ -349,7 +349,7 @@ module Data.Array.Accelerate ( pattern T7, pattern T8, pattern T9, pattern T10, pattern T11, pattern T12, pattern T13, pattern T14, pattern T15, pattern T16, - pattern Z_, pattern Ix, pattern (::.), pattern All_, pattern Any_, + pattern Z, pattern (:.), pattern All, pattern Any, pattern I0, pattern I1, pattern I2, pattern I3, pattern I4, pattern I5, pattern I6, pattern I7, pattern I8, pattern I9, @@ -448,13 +448,14 @@ import Data.Array.Accelerate.Data.Either import Data.Array.Accelerate.Data.Maybe import Data.Array.Accelerate.Language import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.Pattern.Shape import Data.Array.Accelerate.Pattern.TH import Data.Array.Accelerate.Prelude import Data.Array.Accelerate.Pretty () -- show instances import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Array ( Array, Arrays, Scalar, Vector, Matrix, Segments, fromFunction, fromFunctionM, toList, fromList ) import Data.Array.Accelerate.Sugar.Elt -import Data.Array.Accelerate.Sugar.Shape hiding ( size, toIndex, fromIndex, intersect ) +import Data.Array.Accelerate.Sugar.Shape hiding ( Z(..), (:.)(..), Any(..), All(..), size, toIndex, fromIndex, intersect ) import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type import qualified Data.Array.Accelerate.Sugar.Array as S diff --git a/src/Data/Array/Accelerate/Lift.hs b/src/Data/Array/Accelerate/Lift.hs index d0a7093f7..539dd2823 100644 --- a/src/Data/Array/Accelerate/Lift.hs +++ b/src/Data/Array/Accelerate/Lift.hs @@ -35,11 +35,11 @@ module Data.Array.Accelerate.Lift ( ) where import Data.Array.Accelerate.AST.Idx -import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.Pattern.Shape import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Array import Data.Array.Accelerate.Sugar.Elt -import Data.Array.Accelerate.Sugar.Shape +import Data.Array.Accelerate.Sugar.Shape ( Shape, DIM1 ) import Data.Array.Accelerate.Type import Language.Haskell.TH.Extra hiding ( Exp ) @@ -76,17 +76,17 @@ lift3 f x y z = lift $ f (unlift x) (unlift y) (unlift z) -- | Lift a unary function to a computation over rank-1 indices. -- ilift1 :: (Exp Int -> Exp Int) -> Exp DIM1 -> Exp DIM1 -ilift1 f = lift1 (\(Z:.i) -> Z :. f i) +ilift1 f (Z:.i) = Z :. f i -- | Lift a binary function to a computation over rank-1 indices. -- ilift2 :: (Exp Int -> Exp Int -> Exp Int) -> Exp DIM1 -> Exp DIM1 -> Exp DIM1 -ilift2 f = lift2 (\(Z:.i) (Z:.j) -> Z :. f i j) +ilift2 f (Z:.i) (Z:.j) = Z :. f i j -- | Lift a ternary function to a computation over rank-1 indices. -- ilift3 :: (Exp Int -> Exp Int -> Exp Int -> Exp Int) -> Exp DIM1 -> Exp DIM1 -> Exp DIM1 -> Exp DIM1 -ilift3 f = lift3 (\(Z:.i) (Z:.j) (Z:.k) -> Z :. f i j k) +ilift3 f (Z :. i) (Z :. j) (Z :. k) = Z :. f i j k -- | The class of types @e@ which can be lifted into @c@. @@ -149,28 +149,28 @@ instance Unlift Acc (Acc a) where instance Lift Exp Z where type Plain Z = Z - lift _ = Z_ + lift _ = Z instance Unlift Exp Z where unlift _ = Z instance (Elt (Plain ix), Lift Exp ix) => Lift Exp (ix :. Int) where type Plain (ix :. Int) = Plain ix :. Int - lift (ix :. i) = lift ix ::. lift i + lift (ix :. i) = lift ix :. lift i instance (Elt (Plain ix), Lift Exp ix) => Lift Exp (ix :. All) where type Plain (ix :. All) = Plain ix :. All - lift (ix :. i) = lift ix ::. constant i + lift (ix :. i) = lift ix :. constant i instance (Elt e, Elt (Plain ix), Lift Exp ix) => Lift Exp (ix :. Exp e) where type Plain (ix :. Exp e) = Plain ix :. e - lift (ix :. i) = lift ix ::. i + lift (ix :. i) = lift ix :. i instance {-# OVERLAPPABLE #-} (Elt e, Elt (Plain ix), Unlift Exp ix) => Unlift Exp (ix :. Exp e) where - unlift (ix ::. i) = unlift ix :. i + unlift (ix :. i) = unlift ix :. i instance {-# OVERLAPPABLE #-} (Elt e, Elt ix) => Unlift Exp (Exp ix :. Exp e) where - unlift (ix ::. i) = ix :. i + unlift (ix :. i) = ix :. i instance (Shape sh, Elt (Any sh)) => Lift Exp (Any sh) where type Plain (Any sh) = Any sh diff --git a/src/Data/Array/Accelerate/Pattern.hs b/src/Data/Array/Accelerate/Pattern.hs index 4de7d22a3..faa730af3 100644 --- a/src/Data/Array/Accelerate/Pattern.hs +++ b/src/Data/Array/Accelerate/Pattern.hs @@ -30,7 +30,6 @@ module Data.Array.Accelerate.Pattern ( pattern T7, pattern T8, pattern T9, pattern T10, pattern T11, pattern T12, pattern T13, pattern T14, pattern T15, pattern T16, - pattern Z_, pattern Ix, pattern (::.), pattern All_, pattern Any_, pattern I0, pattern I1, pattern I2, pattern I3, pattern I4, pattern I5, pattern I6, pattern I7, pattern I8, pattern I9, @@ -40,11 +39,12 @@ module Data.Array.Accelerate.Pattern ( ) where import Data.Array.Accelerate.AST.Idx +import Data.Array.Accelerate.Pattern.Shape import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Array import Data.Array.Accelerate.Sugar.Elt -import Data.Array.Accelerate.Sugar.Shape +-- import Data.Array.Accelerate.Sugar.Shape import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type @@ -73,44 +73,6 @@ class IsSIMD context a b where vbuilder :: b -> context a vmatcher :: context a -> b --- | Pattern synonyms for indices, which may be more convenient to use than --- 'Data.Array.Accelerate.Lift.lift' and --- 'Data.Array.Accelerate.Lift.unlift'. --- -pattern Z_ :: Exp DIM0 -pattern Z_ = Pattern Z -{-# COMPLETE Z_ #-} - -infixl 3 ::. -pattern (::.) :: (Elt a, Elt b) => Exp a -> Exp b -> Exp (a :. b) -pattern a ::. b = Pattern (a :. b) -{-# COMPLETE (::.) #-} - -infixl 3 `Ix` -pattern Ix :: (Elt a, Elt b) => Exp a -> Exp b -> Exp (a :. b) -pattern a `Ix` b = a ::. b -{-# COMPLETE Ix #-} - -pattern All_ :: Exp All -pattern All_ <- (const True -> True) - where All_ = constant All -{-# COMPLETE All_ #-} - -pattern Any_ :: (Shape sh, Elt (Any sh)) => Exp (Any sh) -pattern Any_ <- (const True -> True) - where Any_ = constant Any -{-# COMPLETE Any_ #-} - --- IsPattern instances for Shape nil and cons --- -instance IsPattern Exp Z Z where - builder _ = constant Z - matcher _ = Z - -instance (Elt a, Elt b) => IsPattern Exp (a :. b) (Exp a :. Exp b) where - builder (Exp a :. Exp b) = Exp $ SmartExp $ Pair a b - matcher (Exp t) = Exp (SmartExp $ Prj PairIdxLeft t) :. Exp (SmartExp $ Prj PairIdxRight t) - -- IsPattern instances for up to 16-tuples (Acc and Exp). TH takes care of -- the (unremarkable) boilerplate for us. @@ -345,14 +307,14 @@ runQ $ do let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] ts = map varT xs name = mkName ('I':show n) - ix = mkName "Ix" + ix = mkName ":." cst = tupT (map (\t -> [t| Elt $t |]) ts) dim = foldl (\h t -> [t| $h :. $t |]) [t| Z |] ts sig = foldr (\t r -> [t| Exp $t -> $r |]) [t| Exp $dim |] ts in sequence [ patSynSigD name [t| $cst => $sig |] - , patSynD name (prefixPatSyn xs) implBidir (foldl (\ps p -> infixP ps ix (varP p)) [p| Z_ |] xs) + , patSynD name (prefixPatSyn xs) implBidir (foldl (\ps p -> infixP ps ix (varP p)) [p| Z |] xs) , pragCompleteD [name] Nothing ] diff --git a/src/Data/Array/Accelerate/Pattern/Shape.hs b/src/Data/Array/Accelerate/Pattern/Shape.hs new file mode 100644 index 000000000..732e9906c --- /dev/null +++ b/src/Data/Array/Accelerate/Pattern/Shape.hs @@ -0,0 +1,103 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +-- | +-- Module : Data.Array.Accelerate.Pattern.Shape +-- Copyright : [2018..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Pattern.Shape ( + + pattern Z, Z, + pattern (:.), (:.), + pattern All, All, + pattern Any, Any, + +) where + +import Data.Array.Accelerate.AST.Idx +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Shape ( (:.), Z, All, Any ) +import qualified Data.Array.Accelerate.Sugar.Shape as Sugar + + +pattern Z :: IsShapeZ z => z +pattern Z <- (z_matcher -> True) + where Z = z_builder +{-# COMPLETE Z #-} + +pattern All :: IsShapeAll all => all +pattern All <- (all_matcher -> True) + where All = all_builder +{-# COMPLETE All #-} + +pattern Any :: IsShapeAny any => any +pattern Any <- (any_matcher -> True) + where Any = any_builder +{-# COMPLETE Any #-} + +infixl 3 :. +pattern (:.) :: IsShapeSnoc t h s => t -> h -> s +pattern t :. h <- (snoc_matcher -> (t Sugar.:. h)) + where t :. h = snoc_builder t h +{-# COMPLETE (:.) #-} + + +class IsShapeZ z where + z_matcher :: z -> Bool + z_builder :: z + +instance IsShapeZ Z where + z_matcher _ = True + z_builder = Sugar.Z + +instance IsShapeZ (Exp Z) where + z_matcher _ = True + z_builder = constant Sugar.Z + +class IsShapeAll all where + all_matcher :: all -> Bool + all_builder :: all + +instance IsShapeAll All where + all_matcher _ = True + all_builder = Sugar.All + +instance IsShapeAll (Exp All) where + all_matcher _ = True + all_builder = constant Sugar.All + +class IsShapeAny any where + any_matcher :: any -> Bool + any_builder :: any + +instance IsShapeAny (Any sh) where + any_matcher _ = True + any_builder = Sugar.Any + +instance Elt (Any sh) => IsShapeAny (Exp (Any sh)) where + any_matcher _ = True + any_builder = constant Sugar.Any + +class IsShapeSnoc t h s | s -> t, s -> h where + snoc_matcher :: s -> (t :. h) + snoc_builder :: t -> h -> s + +instance IsShapeSnoc (Exp t) (Exp h) (Exp (t :. h)) where + snoc_builder (Exp a) (Exp b) = Exp $ SmartExp $ Pair a b + snoc_matcher (Exp t) = Exp (SmartExp $ Prj PairIdxLeft t) Sugar.:. Exp (SmartExp $ Prj PairIdxRight t) + +instance IsShapeSnoc t h (t :. h) where + snoc_builder = (Sugar.:.) + snoc_matcher = id + diff --git a/src/Data/Array/Accelerate/Prelude.hs b/src/Data/Array/Accelerate/Prelude.hs index 2d52a1231..f0b144dba 100644 --- a/src/Data/Array/Accelerate/Prelude.hs +++ b/src/Data/Array/Accelerate/Prelude.hs @@ -123,10 +123,11 @@ import Data.Array.Accelerate.Language import Data.Array.Accelerate.Lift import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Pattern.Maybe +import Data.Array.Accelerate.Pattern.Shape import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Array ( Arrays, Array, Scalar, Vector, Segments, fromList ) import Data.Array.Accelerate.Sugar.Elt -import Data.Array.Accelerate.Sugar.Shape ( Shape, Slice, Z(..), (:.)(..), All(..), DIM1, DIM2, empty ) +import Data.Array.Accelerate.Sugar.Shape ( Shape, Slice, DIM1, DIM2, empty ) import Data.Array.Accelerate.Type import Data.Array.Accelerate.Classes.Eq @@ -982,7 +983,7 @@ scanlSeg -> Acc (Array (sh:.Int) e) scanlSeg f z arr seg = if null arr || null flags - then fill (sh ::. sz + length seg) z + then fill (sh :. sz + length seg) z else scanl1Seg f arr' seg' where -- Segmented exclusive scan is implemented by first injecting the seed @@ -993,11 +994,11 @@ scanlSeg f z arr seg = -- overlaying the input data in all places other than at the start of -- a segment. -- - sh ::. sz = shape arr + sh :. sz = shape arr seg' = map (+1) seg arr' = permute const - (fill (sh ::. sz + length seg) z) - (\(sx ::. i) -> Just_ (sx ::. i + fromIntegral (inc ! I1 i))) + (fill (sh :. sz + length seg) z) + (\(sx :. i) -> Just_ (sx :. i + fromIntegral (inc ! I1 i))) (take (length flags) arr) -- Each element in the segments must be shifted to the right one additional @@ -1055,7 +1056,7 @@ scanl'Seg -> Acc (Array (sh:.Int) e, Array (sh:.Int) e) scanl'Seg f z arr seg = if null arr - then T2 arr (fill (indexTail (shape arr) ::. length seg) z) + then T2 arr (fill (indexTail (shape arr) :. length seg) z) else T2 body sums where -- Segmented scan' is implemented by deconstructing a segmented exclusive @@ -1075,8 +1076,8 @@ scanl'Seg f z arr seg = seg' = map (+1) seg tails = zipWith (+) seg $ prescanl (+) 0 seg' sums = backpermute - (indexTail (shape arr') ::. length seg) - (\(sz ::. i) -> sz ::. fromIntegral (tails ! I1 i)) + (indexTail (shape arr') :. length seg) + (\(sz :. i) -> sz :. fromIntegral (tails ! I1 i)) arr' -- Slice out the body of each segment. @@ -1095,8 +1096,8 @@ scanl'Seg f z arr seg = len = offset ! I1 (length offset - 1) body = backpermute - (indexTail (shape arr) ::. fromIntegral len) - (\(sz ::. i) -> sz ::. i + fromIntegral (inc ! I1 i)) + (indexTail (shape arr) :. fromIntegral len) + (\(sz :. i) -> sz :. i + fromIntegral (inc ! I1 i)) arr' @@ -1133,7 +1134,7 @@ scanl'Seg f z arr seg = -- 40, 41, 83, 126, 170, 45, 91, 138] -- scanl1Seg - :: (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int) + :: forall sh i e. (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int) => (Exp e -> Exp e -> Exp e) -> Acc (Array (sh:.Int) e) -> Acc (Segments i) @@ -1141,7 +1142,7 @@ scanl1Seg scanl1Seg f arr seg = map snd . scanl1 (segmentedL f) - $ zip (replicate (lift (indexTail (shape arr) :. All)) (mkHeadFlags seg)) arr + $ zip (replicate @(sh :. All) (indexTail (shape arr) :. All) (mkHeadFlags seg)) arr -- |Segmented version of 'prescanl'. -- @@ -1203,10 +1204,10 @@ scanrSeg -> Acc (Array (sh:.Int) e) scanrSeg f z arr seg = if null arr || null flags - then fill (sh ::. sz + length seg) z + then fill (sh :. sz + length seg) z else scanr1Seg f arr' seg' where - sh ::. sz = shape arr + sh :. sz = shape arr -- Using technique described for 'scanlSeg', where we intersperse the array -- with the seed element at the start of each segment, and then perform an @@ -1217,8 +1218,8 @@ scanrSeg f z arr seg = seg' = map (+1) seg arr' = permute const - (fill (sh ::. sz + length seg) z) - (\(sx ::. i) -> Just_ (sx ::. i + fromIntegral (inc !! i) - 1)) + (fill (sh :. sz + length seg) z) + (\(sx :. i) -> Just_ (sx :. i + fromIntegral (inc !! i) - 1)) (drop (sz - length flags) arr) @@ -1262,7 +1263,7 @@ scanr'Seg -> Acc (Array (sh:.Int) e, Array (sh:.Int) e) scanr'Seg f z arr seg = if null arr - then T2 arr (fill (indexTail (shape arr) ::. length seg) z) + then T2 arr (fill (indexTail (shape arr) :. length seg) z) else T2 body sums where -- Using technique described for scanl'Seg @@ -1273,16 +1274,16 @@ scanr'Seg f z arr seg = seg' = map (+1) seg heads = prescanl (+) 0 seg' sums = backpermute - (indexTail (shape arr') ::. length seg) - (\(sz ::.i) -> sz ::. fromIntegral (heads ! I1 i)) + (indexTail (shape arr') :. length seg) + (\(sz :.i) -> sz :. fromIntegral (heads ! I1 i)) arr' -- body segments flags = mkHeadFlags seg inc = scanl1 (+) flags body = backpermute - (indexTail (shape arr) ::. indexHead (shape flags)) - (\(sz ::. i) -> sz ::. i + fromIntegral (inc ! I1 i)) + (indexTail (shape arr) :. indexHead (shape flags)) + (\(sz :. i) -> sz :. i + fromIntegral (inc ! I1 i)) arr' @@ -1310,7 +1311,7 @@ scanr'Seg f z arr seg = -- 40, 170, 129, 87, 44, 138, 93, 47] -- scanr1Seg - :: (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int) + :: forall sh i e. (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int) => (Exp e -> Exp e -> Exp e) -> Acc (Array (sh:.Int) e) -> Acc (Segments i) @@ -1318,7 +1319,7 @@ scanr1Seg scanr1Seg f arr seg = map snd . scanr1 (segmentedR f) - $ zip (replicate (lift (indexTail (shape arr) :. All)) (mkTailFlags seg)) arr + $ zip (replicate @(sh :. All) (indexTail (shape arr) :. All) (mkTailFlags seg)) arr -- |Segmented version of 'prescanr'. @@ -1415,7 +1416,7 @@ segmentedR f y x = segmentedL (flip f) x y -- base library, because it results in too many ambiguity errors. -- index1' :: (Integral i, FromIntegral i Int) => Exp i -> Exp DIM1 -index1' i = lift (Z :. fromIntegral i) +index1' i = Z :. fromIntegral i -- Reshaping of arrays @@ -1663,7 +1664,7 @@ compact keep arr result = permute const dummy prj arr in if null arr - then T2 emptyArray (fill Z_ 0) + then T2 emptyArray (fill Z 0) else if the len == unindex1 (shape arr) then T2 arr len @@ -2339,7 +2340,7 @@ sfoldl :: (Shape sh, Elt a, Elt b) -> Exp a sfoldl f z ix xs = let n = indexHead (shape xs) - step (T2 i acc) = T2 (i+1) (acc `f` (xs ! (ix ::. i))) + step (T2 i acc) = T2 (i+1) (acc `f` (xs ! (ix :. i))) in snd $ while (\v -> fst v < n) step (T2 0 z) @@ -2382,17 +2383,17 @@ uncurry f t = let (x, y) = unlift t in f x y -- | The one index for a rank-0 array. -- index0 :: Exp Z -index0 = lift Z +index0 = Z -- | Turn an 'Int' expression into a rank-1 indexing expression. -- index1 :: Elt i => Exp i -> Exp (Z :. i) -index1 i = lift (Z :. i) +index1 i = Z :. i -- | Turn a rank-1 indexing expression into an 'Int' expression. -- unindex1 :: Elt i => Exp (Z :. i) -> Exp i -unindex1 ix = let Z :. i = unlift ix in i +unindex1 (Z :. i) = i -- | Creates a rank-2 index from two Exp Int`s -- @@ -2401,7 +2402,7 @@ index2 => Exp i -> Exp i -> Exp (Z :. i :. i) -index2 i j = lift (Z :. i :. j) +index2 i j = Z :. i :. j -- | Destructs a rank-2 index to an Exp tuple of two Int`s. -- @@ -2409,7 +2410,7 @@ unindex2 :: Elt i => Exp (Z :. i :. i) -> Exp (i, i) -unindex2 (Z_ ::. i ::. j) = T2 i j +unindex2 (Z :. i :. j) = T2 i j -- | Create a rank-3 index from three Exp Int`s -- @@ -2419,14 +2420,14 @@ index3 -> Exp i -> Exp i -> Exp (Z :. i :. i :. i) -index3 k j i = Z_ ::. k ::. j ::. i +index3 k j i = Z :. k :. j :. i -- | Destruct a rank-3 index into an Exp tuple of Int`s unindex3 :: Elt i => Exp (Z :. i :. i :. i) -> Exp (i, i, i) -unindex3 (Z_ ::. k ::. j ::. i) = T3 k j i +unindex3 (Z :. k :. j :. i) = T3 k j i -- Array operations with a scalar result @@ -2615,14 +2616,14 @@ emptyArray = fill (constant empty) undef -- Imported from `lens-accelerate` (which provides more general Field instances) -- _1 :: forall sh. Elt sh => Lens' (Exp (sh:.Int)) (Exp Int) -_1 = lens (\ix -> let _ :. x = unlift ix :: Exp sh :. Exp Int in x) - (\ix x -> let sh :. _ = unlift ix :: Exp sh :. Exp Int in lift (sh :. x)) +_1 = lens (\ix -> let _ :. x = ix in x) + (\ix x -> let sh :. _ = ix in sh :. x) _2 :: forall sh. Elt sh => Lens' (Exp (sh:.Int:.Int)) (Exp Int) -_2 = lens (\ix -> let _ :. y :. _ = unlift ix :: Exp sh :. Exp Int :. Exp Int in y) - (\ix y -> let sh :. _ :. x = unlift ix :: Exp sh :. Exp Int :. Exp Int in lift (sh :. y :. x)) +_2 = lens (\ix -> let _ :. y :. _ = ix in y) + (\ix y -> let sh :. _ :. x = ix in sh :. y :. x) _3 :: forall sh. Elt sh => Lens' (Exp (sh:.Int:.Int:.Int)) (Exp Int) -_3 = lens (\ix -> let _ :. z :. _ :. _ = unlift ix :: Exp sh :. Exp Int :. Exp Int :. Exp Int in z) - (\ix z -> let sh :. _ :. y :. x = unlift ix :: Exp sh :. Exp Int :. Exp Int :. Exp Int in lift (sh :. z :. y :. x)) +_3 = lens (\ix -> let _ :. z :. _ :. _ = ix in z) + (\ix z -> let sh :. _ :. y :. x = ix in sh :. z :. y :. x) From 9d61a0165591994cd244bbb1cc14b99ac6796a9d Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 13 Jun 2022 21:45:45 +0200 Subject: [PATCH 24/86] update CHANGELOG.md --- CHANGELOG.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 53ff0a1a7..f9c6245c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,12 +9,24 @@ Policy (PVP)](https://pvp.haskell.org) ## [next] ### Added * Added debugging functions in module `Data.Array.Accelerate.Debug.Trace` ([#485](https://github.com/AccelerateHS/accelerate/pull/485)) + * Support for SIMD data types in expressions. Support for storing a type `a` + in a SIMD vector can be added by deriving an instance for the class `SIMD`. + * Instances for SIMD types in basic numeric classes (e.g. `Num` for `<4 x Float>`) + * Support for 128-bit integers (signed and unsigned) + * Support for 128-bit floating point types (build with cabal flag `float128`) ### Changed * Removed dependency on lens ([#493](https://github.com/AccelerateHS/accelerate/pull/493)) + * The shape constructors (e.g. `Z` and `(:.)`) are now pattern synonyms that + work on both Haskell values and embedded expressions ### Fixed * Graphviz graph generation of `-ddump-dot` and `-ddump-simpl-dot` ([#384](https://github.com/AccelerateHS/accelerate/issues/384)) + * Bug in `Semigroup` instance for `Maybe` ([#517](https://github.com/AccelerateHS/accelerate/issues/517)) + * Bug in `Ord` instances or tuple types + +### Removed + * Pattern synonyms `Z_`, `(::.)`, `Any_`, `All_`, which are no longer required ### Contributors From 007e3cd7740a409dce03d6654c9f574b74d2ce5b Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Tue, 14 Jun 2022 11:57:54 +0200 Subject: [PATCH 25/86] test fixes --- src/Data/Array/Accelerate/Language.hs | 26 ++++++------ .../Accelerate/Test/NoFib/Imaginary/SAXPY.hs | 6 +-- .../Accelerate/Test/NoFib/Issues/Issue102.hs | 33 +++++++-------- .../Accelerate/Test/NoFib/Issues/Issue137.hs | 14 +++---- .../Accelerate/Test/NoFib/Issues/Issue264.hs | 28 ++++++------- .../Accelerate/Test/NoFib/Issues/Issue364.hs | 15 ++++--- .../Accelerate/Test/NoFib/Issues/Issue427.hs | 1 + .../Accelerate/Test/NoFib/Issues/Issue439.hs | 2 +- .../Test/NoFib/Prelude/Backpermute.hs | 19 +++++---- .../Accelerate/Test/NoFib/Prelude/Filter.hs | 10 ++--- .../Accelerate/Test/NoFib/Prelude/Fold.hs | 14 +++---- .../Accelerate/Test/NoFib/Prelude/SIMD.hs | 2 - .../Accelerate/Test/NoFib/Prelude/Scan.hs | 40 +++++++++---------- .../Array/Accelerate/Test/NoFib/Sharing.hs | 6 +-- 14 files changed, 107 insertions(+), 109 deletions(-) diff --git a/src/Data/Array/Accelerate/Language.hs b/src/Data/Array/Accelerate/Language.hs index da65e597d..d1b2d3bdb 100644 --- a/src/Data/Array/Accelerate/Language.hs +++ b/src/Data/Array/Accelerate/Language.hs @@ -176,7 +176,7 @@ unit (Exp e) = Acc $ SmartAcc $ Unit (eltR @e) e -- ...we can replicate these elements to form a two-dimensional array either by -- replicating those elements as new rows: -- --- >>> run $ replicate (constant (Z :. (4::Int) :. All)) (use vec) +-- >>> run $ replicate @(Z :. Int :. All) (Z :. 4 :. All) (use vec) -- Matrix (Z :. 4 :. 10) -- [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, @@ -185,7 +185,7 @@ unit (Exp e) = Acc $ SmartAcc $ Unit (eltR @e) e -- -- ...or as columns: -- --- >>> run $ replicate (Z_ ::. All_ ::. (4 :: Exp Int)) (use vec) +-- >>> run $ replicate @(Z :. All :. Int) (Z :. All :. 4) (use vec) -- Matrix (Z :. 10 :. 4) -- [ 0, 0, 0, 0, -- 1, 1, 1, 1, @@ -201,7 +201,7 @@ unit (Exp e) = Acc $ SmartAcc $ Unit (eltR @e) e -- Replication along more than one dimension is also possible. Here we replicate -- twice across the first dimension and three times across the third dimension: -- --- >>> run $ replicate (constant (Z :. (2::Int) :. All :. (3::Int))) (use vec) +-- >>> run $ replicate @(Z :. Int :. All :. Int) (Z :. 2 :. All :. 3) (use vec) -- Array (Z :. 2 :. 10 :. 3) [0,0,0,1,1,1,2,2,2,3,3,3,4,4,4,5,5,5,6,6,6,7,7,7,8,8,8,9,9,9,0,0,0,1,1,1,2,2,2,3,3,3,4,4,4,5,5,5,6,6,6,7,7,7,8,8,8,9,9,9] -- -- The marker 'Any' can be used in the slice specification to match against some @@ -209,8 +209,8 @@ unit (Exp e) = Acc $ SmartAcc $ Unit (eltR @e) e -- type variable @sh@ takes. -- -- >>> :{ --- let rep0 :: (Shape sh, Elt e) => Exp Int -> Acc (Array sh e) -> Acc (Array (sh :. Int) e) --- rep0 n a = replicate (Any_ ::. n) a +-- let rep0 :: forall sh e. (Shape sh, Elt e) => Exp Int -> Acc (Array sh e) -> Acc (Array (sh :. Int) e) +-- rep0 n a = replicate @(Any sh :. Int) (Any :. n) a -- :} -- -- >>> let x = unit 42 :: Acc (Scalar Int) @@ -233,8 +233,8 @@ unit (Exp e) = Acc $ SmartAcc $ Unit (eltR @e) e -- Of course, 'Any' and 'All' can be used together. -- -- >>> :{ --- let rep1 :: (Shape sh, Elt e) => Exp Int -> Acc (Array (sh :. Int) e) -> Acc (Array (sh :. Int :. Int) e) --- rep1 n a = replicate (Any_ ::. n ::. All_) a +-- let rep1 :: forall sh e. (Shape sh, Elt e) => Exp Int -> Acc (Array (sh :. Int) e) -> Acc (Array (sh :. Int :. Int) e) +-- rep1 n a = replicate @(Any sh :. Int :. All) (Any :. n :. All) a -- :} -- -- >>> run $ rep1 5 (use vec) @@ -332,13 +332,13 @@ reshape = Acc $$ applyAcc (Reshape $ shapeR @sh) -- ...will can select a specific row to yield a one dimensional result by fixing -- the row index (2) while allowing the column index to vary (via 'All'): -- --- >>> run $ slice (use mat) (constant (Z :. (2::Int) :. All)) +-- >>> run $ slice @(Z :. Int :. All) (use mat) (Z :. 2 :. All) -- Vector (Z :. 10) [20,21,22,23,24,25,26,27,28,29] -- -- A fully specified index (with no 'All's) returns a single element (zero -- dimensional array). -- --- >>> run $ slice (use mat) (constant (Z :. 4 :. 2 :: DIM2)) +-- >>> run $ slice @DIM2 (use mat) (Z :. 4 :. 2) -- Scalar Z [42] -- -- The marker 'Any' can be used in the slice specification to match against some @@ -347,8 +347,8 @@ reshape = Acc $$ applyAcc (Reshape $ shapeR @sh) -- -- >>> :{ -- let --- sl0 :: (Shape sh, Elt e) => Acc (Array (sh:.Int) e) -> Exp Int -> Acc (Array sh e) --- sl0 a n = slice a (Any_ ::. n) +-- sl0 :: forall sh e. (Shape sh, Elt e) => Acc (Array (sh:.Int) e) -> Exp Int -> Acc (Array sh e) +-- sl0 a n = slice @(Any sh :. Int) a (Any :. n) -- :} -- -- >>> let vec = fromList (Z:.10) [0..] :: Vector Int @@ -361,8 +361,8 @@ reshape = Acc $$ applyAcc (Reshape $ shapeR @sh) -- Of course, 'Any' and 'All' can be used together. -- -- >>> :{ --- let sl1 :: (Shape sh, Elt e) => Acc (Array (sh:.Int:.Int) e) -> Exp Int -> Acc (Array (sh:.Int) e) --- sl1 a n = slice a (Any_ ::. n ::. All_) +-- let sl1 :: forall sh e. (Shape sh, Elt e) => Acc (Array (sh:.Int:.Int) e) -> Exp Int -> Acc (Array (sh:.Int) e) +-- sl1 a n = slice @(Any sh :. Int :. All) a (Any :. n :. All) -- :} -- -- >>> run $ sl1 (use mat) 4 diff --git a/src/Data/Array/Accelerate/Test/NoFib/Imaginary/SAXPY.hs b/src/Data/Array/Accelerate/Test/NoFib/Imaginary/SAXPY.hs index 7537d5625..af69b5b28 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Imaginary/SAXPY.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Imaginary/SAXPY.hs @@ -23,12 +23,12 @@ module Data.Array.Accelerate.Test.NoFib.Imaginary.SAXPY ( import Prelude as P import Data.Array.Accelerate as A -import Data.Array.Accelerate.Sugar.Array as S -import Data.Array.Accelerate.Sugar.Elt as S -import Data.Array.Accelerate.Sugar.Shape as S +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config import Data.Array.Accelerate.Test.Similar +import qualified Data.Array.Accelerate.Sugar.Array as S +import qualified Data.Array.Accelerate.Sugar.Shape as S import Hedgehog import qualified Hedgehog.Gen as Gen diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue102.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue102.hs index dedec6cf6..99960adb0 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue102.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue102.hs @@ -1,5 +1,6 @@ -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} -- | -- Module : Data.Array.Accelerate.Test.NoFib.Issues.Issue102 -- Copyright : [2009..2020] The Accelerate Team @@ -40,29 +41,29 @@ test1 = rts = 1 rustride = 1 - v = fill (constant (Z:.(p-1))) (constant 2) - ru' = fill (constant (Z:.(p-1))) (constant 1) + v = fill (I1 (p-1)) 2 + ru' = fill (I1 (p-1)) 1 -- generate a vector with phi(p)=p-1 elements - x' = reshape (constant (Z :. lts :. (p-1) :. rts)) v + x' = reshape (I3 lts (p-1) rts) v --embed into a vector of length p - y = generate (constant (Z :. lts :. p :. rts)) - (\ix -> let (Z :. l :. i :. r) = unlift ix :: Z :. Exp Int :. Exp Int :. Exp Int - in i A.== 0 ? (0, x' ! (lift $ Z :. l :. i-1 :. r))) + y = generate (I3 lts p rts) + (\ix -> let I3 l i r = ix + in i A.== 0 ? (0, x' ! (I3 l (i-1) r))) -- do a DFT_p - y' = reshape (constant (Z :. lts :. p :. rts)) (flatten y) - dftrus = generate (constant (Z :. p :. p)) - (\ix -> let (Z :. i :. j) = unlift ix :: Z :. Exp Int :. Exp Int - in ru' ! (lift (Z :. (i*j*rustride `mod` (constant p))))) + y' = reshape (I3 lts p rts) (flatten y) + dftrus = generate (I2 p p) + (\ix -> let I2 i j = ix + in ru' ! (I1 (i*j*rustride `mod` p))) - tensorDFTCoeffs = A.replicate (lift (Z:.lts:.All:.rts:.All)) dftrus + tensorDFTCoeffs = A.replicate @(Z :. Int :. All :. Int :. All) (I4 lts All rts All) dftrus tensorInputCoeffs = generate (shape tensorDFTCoeffs) - (\ix -> let (Z:.l:._:.r:.col) = unlift ix :: Z :. Exp Int :. Exp Int :. Exp Int :. Exp Int - in y' ! (lift $ Z:.l:.col:.r)) + (\ix -> let I4 l _ r col = ix + in y' ! (I3 l col r)) - dftans = flatten $ fold (+) (constant 0) $ A.zipWith (*) tensorDFTCoeffs tensorInputCoeffs + dftans = flatten $ fold (+) 0 $ A.zipWith (*) tensorDFTCoeffs tensorInputCoeffs --continue the alternate transform, but this line breaks dfty = reshape (shape y) $ dftans diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue137.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue137.hs index 444687f44..ce0cf613f 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue137.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue137.hs @@ -38,17 +38,17 @@ test1 :: Acc (Vector (Int,Int)) test1 = let sz = 3000 :: Int - interm_arrA = use $ A.fromList (Z :. sz) [ 8 - (a `mod` 17) | a <- [1..sz]] + interm_arrA = use $ A.fromList (Z :. sz) [ 8 - (a `mod` 17) | a <- [1..sz]] :: Acc (Vector Int) msA = use $ A.fromList (Z :. sz) [ (a `div` 8) | a <- [1..sz]] inf = 10000 :: Exp Int - infsA = A.generate (index1 (384 :: Exp Int)) (\_ -> lift (inf,inf)) - inpA = A.map (\v -> lift (abs v,inf) :: Exp (Int,Int)) interm_arrA + infsA = A.generate (index1 (384 :: Exp Int)) (\_ -> T2 inf inf) + inpA = A.map (\v -> T2 (abs v) inf) interm_arrA in - A.permute (\a12 b12 -> let (a1,a2) = unlift a12 - (b1,b2) = unlift b12 + A.permute (\a12 b12 -> let T2 a1 a2 = a12 + T2 b1 b2 = b12 in (a1 A.<= b1) - ? ( lift (a1, A.min a2 b1) - , lift (b1, A.min b2 a1) + ? ( T2 a1 (A.min a2 b1) + , T2 b1 (A.min b2 a1) )) infsA (\ix -> Just_ (index1 (msA A.! ix))) diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue264.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue264.hs index afdbe9294..187984df7 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue264.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue264.hs @@ -26,12 +26,12 @@ module Data.Array.Accelerate.Test.NoFib.Issues.Issue264 ( import Prelude as P import Data.Array.Accelerate as A -import Data.Array.Accelerate.Sugar.Array as S -import Data.Array.Accelerate.Sugar.Elt as S -import Data.Array.Accelerate.Sugar.Shape as S +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config import Data.Array.Accelerate.Test.Similar +import qualified Data.Array.Accelerate.Sugar.Array as S +import qualified Data.Array.Accelerate.Sugar.Shape as S import Hedgehog import qualified Hedgehog.Gen as Gen @@ -80,7 +80,7 @@ test_not_not -> Property test_not_not runN = property $ do - xs <- forAll (array Z Gen.bool) + xs <- forAll (array @Z Z Gen.bool) let !go = runN (A.map A.not . A.map A.not) in go xs === mapRef (P.not . P.not) xs test_not_and @@ -88,8 +88,8 @@ test_not_and -> Property test_not_and runN = property $ do - xs <- forAll (array Z Gen.bool) - ys <- forAll (array Z Gen.bool) + xs <- forAll (array @Z Z Gen.bool) + ys <- forAll (array @Z Z Gen.bool) let !go = runN (A.zipWith (\u v -> A.not (u A.&& v))) in go xs ys === zipWithRef (\u v -> P.not (u P.&& v)) xs ys test_not_or @@ -97,8 +97,8 @@ test_not_or -> Property test_not_or runN = property $ do - xs <- forAll (array Z Gen.bool) - ys <- forAll (array Z Gen.bool) + xs <- forAll (array @Z Z Gen.bool) + ys <- forAll (array @Z Z Gen.bool) let !go = runN (A.zipWith (\u v -> A.not (u A.|| v))) in go xs ys === zipWithRef (\u v -> P.not (u P.|| v)) xs ys test_not_not_and @@ -106,8 +106,8 @@ test_not_not_and -> Property test_not_not_and runN = property $ do - xs <- forAll (array Z Gen.bool) - ys <- forAll (array Z Gen.bool) + xs <- forAll (array @Z Z Gen.bool) + ys <- forAll (array @Z Z Gen.bool) let !go = runN (A.zipWith (\u v -> A.not (A.not (u A.&& v)))) in go xs ys === zipWithRef (\u v -> P.not (P.not (u P.&& v))) xs ys test_not_not_or @@ -115,8 +115,8 @@ test_not_not_or -> Property test_not_not_or runN = property $ do - xs <- forAll (array Z Gen.bool) - ys <- forAll (array Z Gen.bool) + xs <- forAll (array @Z Z Gen.bool) + ys <- forAll (array @Z Z Gen.bool) let !go = runN (A.zipWith (\u v -> A.not (A.not (u A.|| v)))) in go xs ys === zipWithRef (\u v -> P.not (P.not (u P.|| v))) xs ys test_neg_neg @@ -133,7 +133,7 @@ test_neg_neg runN e = mapRef :: (Shape sh, Elt a, Elt b) => (a -> b) -> Array sh a -> Array sh b -mapRef f xs = fromFunction (S.shape xs) (\ix -> f (xs S.! ix)) +mapRef f xs = fromFunction (arrayShape xs) (\ix -> f (xs S.! ix)) zipWithRef :: (Shape sh, Elt a, Elt b, Elt c) @@ -143,6 +143,6 @@ zipWithRef -> Array sh c zipWithRef f xs ys = fromFunction - (S.shape xs `S.intersect` S.shape ys) + (arrayShape xs `S.intersect` arrayShape ys) (\ix -> f (xs S.! ix) (ys S.! ix)) diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue364.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue364.hs index ddc167342..8365a70eb 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue364.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue364.hs @@ -23,17 +23,16 @@ module Data.Array.Accelerate.Test.NoFib.Issues.Issue364 ( ) where -import Prelude ( fromInteger, show ) -import qualified Prelude as P +import Prelude ( show ) +import qualified Prelude as P -import Data.Array.Accelerate hiding ( fromInteger ) -import Data.Array.Accelerate.Sugar.Elt as S -import Data.Array.Accelerate.Sugar.Shape as S +import Data.Array.Accelerate +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config +import qualified Data.Array.Accelerate.Sugar.Shape as S import Hedgehog - import Test.Tasty import Test.Tasty.HUnit @@ -55,8 +54,8 @@ test_issue364 runN = -> TestTree testElt _ = testGroup (show (eltR @e)) - [ testCase "A" $ expectedArray @_ @e Z 64 @=? runN (scanl iappend one) (intervalArray Z 64) - , testCase "B" $ expectedArray @_ @e Z 65 @=? runN (scanl iappend one) (intervalArray Z 65) -- failed for integral types + [ testCase "A" $ expectedArray @Z @e Z 64 @=? runN (scanl iappend one) (intervalArray Z 64) + , testCase "B" $ expectedArray @Z @e Z 65 @=? runN (scanl iappend one) (intervalArray Z 65) -- failed for integral types ] diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue427.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue427.hs index 1988d8c16..17a40b098 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue427.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue427.hs @@ -41,6 +41,7 @@ test_issue427 runN , testProperty "n-by-m" $ test_indicesOfTruth runN dim2 ] where + by :: Int -> Gen DIM2 by x = do y <- Gen.int (Range.linear 0 1024) pure (Z :. y :. x) diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue439.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue439.hs index 3bfa18b40..d6f9b1a81 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue439.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue439.hs @@ -33,5 +33,5 @@ e1 :: Scalar Float e1 = fromList Z [2] t1 :: Acc (Scalar Float) -t1 = compute . A.map (* 2) . compute $ fill Z_ 1 +t1 = compute . A.map (* 2) . compute $ fill Z 1 diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Backpermute.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Backpermute.hs index 32bad367e..8620d2700 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Backpermute.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Backpermute.hs @@ -24,12 +24,11 @@ module Data.Array.Accelerate.Test.NoFib.Prelude.Backpermute ( import Prelude as P import Data.Array.Accelerate as A -import Data.Array.Accelerate.Sugar.Array as S -import Data.Array.Accelerate.Sugar.Elt as S -import Data.Array.Accelerate.Sugar.Shape as S +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config import Data.Array.Accelerate.Test.Similar +import qualified Data.Array.Accelerate.Sugar.Shape as S import Hedgehog import qualified Hedgehog.Gen as Gen @@ -71,7 +70,7 @@ test_backpermute runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "take" $ test_take runN sh e , testProperty "drop" $ test_drop runN sh e @@ -127,7 +126,7 @@ test_gather runN dim dim' e = -- let !go = runN $ \i -> A.backpermute (A.shape i) (i A.!) -- - go ix xs ~~~ backpermuteRef sh' (ix S.!) xs + go ix xs ~~~ backpermuteRef sh' (ix `indexArray`) xs scalar :: Elt e => e -> Scalar e @@ -140,7 +139,7 @@ backpermuteRef -> Array sh e -> Array sh' e backpermuteRef sh' p arr = - fromFunction sh' (\ix -> arr S.! p ix) + fromFunction sh' (\ix -> arr `indexArray` p ix) takeRef :: (Shape sh, Slice sh, Elt e) @@ -148,8 +147,8 @@ takeRef -> Array (sh:.Int) e -> Array (sh:.Int) e takeRef n arr = - let sh :. m = S.shape arr - in fromFunction (sh :. P.min m n) (arr S.!) + let sh :. m = arrayShape arr + in fromFunction (sh :. P.min m n) (arr `indexArray`) dropRef :: (Shape sh, Slice sh, Elt e) @@ -157,7 +156,7 @@ dropRef -> Array (sh:.Int) e -> Array (sh:.Int) e dropRef n arr = - let sh :. m = S.shape arr + let sh :. m = arrayShape arr n' = P.max 0 n - in fromFunction (sh :. P.max 0 (m - n')) (\(sz:.i) -> arr S.! (sz :. i+n')) + in fromFunction (sh :. P.max 0 (m - n')) (\(sz:.i) -> arr `indexArray` (sz :. i+n')) diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Filter.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Filter.hs index cfafc2da1..164f598ed 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Filter.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Filter.hs @@ -26,12 +26,12 @@ module Data.Array.Accelerate.Test.NoFib.Prelude.Filter ( import Prelude as P import Data.Array.Accelerate as A -import Data.Array.Accelerate.Sugar.Array as S -import Data.Array.Accelerate.Sugar.Elt as S -import Data.Array.Accelerate.Sugar.Shape as S +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config import Data.Array.Accelerate.Test.Similar +import qualified Data.Array.Accelerate.Sugar.Array as S +import qualified Data.Array.Accelerate.Sugar.Shape as S import Hedgehog @@ -71,7 +71,7 @@ test_filter runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "even" $ test_even runN sh e ] @@ -91,7 +91,7 @@ test_filter runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "positive" $ test_positive runN sh e ] diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Fold.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Fold.hs index 6e19ee3dc..6dce22603 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Fold.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Fold.hs @@ -25,11 +25,11 @@ module Data.Array.Accelerate.Test.NoFib.Prelude.Fold ( import Prelude as P import Data.Array.Accelerate as A -import Data.Array.Accelerate.Sugar.Elt as S -import Data.Array.Accelerate.Sugar.Shape as S +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config import Data.Array.Accelerate.Test.Similar +import qualified Data.Array.Accelerate.Sugar.Shape as S import Hedgehog import qualified Hedgehog.Gen as Gen @@ -72,7 +72,7 @@ test_fold runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_sum runN sh (pure 0) e , testProperty "non-neutral sum" $ test_sum runN sh e e @@ -112,7 +112,7 @@ test_foldSeg runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_segmented_sum runN sh (pure 0) e , testProperty "non-neutral sum" $ test_segmented_sum runN sh e e @@ -187,7 +187,7 @@ test_segmented_sum runN dim z e = sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 0 64)) n <- pure (P.min n1 n2) -- don't generate too many segments - seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (rank @sh))))) + seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (S.rank @sh))))) xs <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (\v -> A.foldSeg (+) (the v)) in go (scalar x) xs seg ~~~ foldSegRef (+) x xs seg @@ -202,7 +202,7 @@ test_segmented_minimum runN dim e = sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 0 64)) n <- pure (P.min n1 n2) -- don't generate too many segments - seg <- forAll (array (Z:.n) (Gen.int (Range.linear 1 (128 `quot` 2 P.^ (rank @sh))))) + seg <- forAll (array (Z:.n) (Gen.int (Range.linear 1 (128 `quot` 2 P.^ (S.rank @sh))))) xs <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (A.fold1Seg A.min) in go xs seg ~~~ fold1SegRef P.min xs seg @@ -217,7 +217,7 @@ test_segmented_maximum runN dim e = sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 0 64)) n <- pure (P.min n1 n2) -- don't generate too many segments - seg <- forAll (array (Z:.n) (Gen.int (Range.linear 1 (128 `quot` 2 P.^ (rank @sh))))) + seg <- forAll (array (Z:.n) (Gen.int (Range.linear 1 (128 `quot` 2 P.^ (S.rank @sh))))) xs <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (A.fold1Seg A.max) in go xs seg ~~~ fold1SegRef P.max xs seg diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/SIMD.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/SIMD.hs index 71b8889a3..45a0de821 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/SIMD.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/SIMD.hs @@ -32,8 +32,6 @@ import Data.Array.Accelerate.Sugar.Elt as S import Data.Array.Accelerate.Sugar.Shape as S import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config -import Data.Primitive.Types -import qualified Data.Primitive.Vec as Prim import Hedgehog import qualified Hedgehog.Gen as Gen diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Scan.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Scan.hs index f1a0f587d..4b94704c9 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Scan.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Scan.hs @@ -30,11 +30,11 @@ module Data.Array.Accelerate.Test.NoFib.Prelude.Scan ( import Prelude as P import Data.Array.Accelerate as A -import Data.Array.Accelerate.Sugar.Elt as S -import Data.Array.Accelerate.Sugar.Shape as S +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config import Data.Array.Accelerate.Test.Similar +import qualified Data.Array.Accelerate.Sugar.Shape as S import Hedgehog import qualified Hedgehog.Gen as Gen @@ -76,7 +76,7 @@ test_scanl runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanl_sum runN sh (pure 0) e , testProperty "non-neutral sum" $ test_scanl_sum runN sh e e , testProperty "non-commutative" $ test_scanl_interval runN sh e @@ -112,7 +112,7 @@ test_scanl1 runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanl1_sum runN sh e , testProperty "non-commutative" $ test_scanl1_interval runN sh e ] @@ -147,7 +147,7 @@ test_scanl' runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanl'_sum runN sh (pure 0) e , testProperty "non-neutral sum" $ test_scanl'_sum runN sh e e , testProperty "non-commutative" $ test_scanl'_interval runN sh e @@ -183,7 +183,7 @@ test_scanr runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanr_sum runN sh (pure 0) e , testProperty "non-neutral sum" $ test_scanr_sum runN sh e e , testProperty "non-commutative" $ test_scanr_interval runN sh e @@ -219,7 +219,7 @@ test_scanr1 runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanr1_sum runN sh e , testProperty "non-commutative" $ test_scanr1_interval runN sh e ] @@ -254,7 +254,7 @@ test_scanr' runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanr'_sum runN sh (pure 0) e , testProperty "non-neutral sum" $ test_scanr'_sum runN sh e e , testProperty "non-commutative" $ test_scanr'_interval runN sh e @@ -290,7 +290,7 @@ test_scanlSeg runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanlSeg_sum runN sh (pure 0) e , testProperty "non-neutral sum" $ test_scanlSeg_sum runN sh e e ] @@ -325,7 +325,7 @@ test_scanl1Seg runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanl1Seg_sum runN sh e ] @@ -359,7 +359,7 @@ test_scanl'Seg runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanl'Seg_sum runN sh (pure 0) e , testProperty "non-neutral sum" $ test_scanl'Seg_sum runN sh e e ] @@ -394,7 +394,7 @@ test_scanrSeg runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanrSeg_sum runN sh (pure 0) e , testProperty "non-neutral sum" $ test_scanrSeg_sum runN sh e e ] @@ -429,7 +429,7 @@ test_scanr1Seg runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanr1Seg_sum runN sh e ] @@ -463,7 +463,7 @@ test_scanr'Seg runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanr'Seg_sum runN sh (pure 0) e , testProperty "non-neutral sum" $ test_scanr'Seg_sum runN sh e e ] @@ -636,7 +636,7 @@ test_scanlSeg_sum runN dim z e = sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 0 64)) n <- pure (P.min n1 n2) -- don't generate too many segments - seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (rank @sh))))) + seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (S.rank @sh))))) arr <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (\v -> A.scanlSeg (+) (the v)) in go (scalar x) arr seg ~~~ scanlSegRef (+) x arr seg @@ -651,7 +651,7 @@ test_scanl1Seg_sum runN dim e = sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 1 64)) n <- pure (P.min n1 n2) -- don't generate too many segments - seg <- forAll (array (Z:.n) (Gen.int (Range.linear 1 (128 `quot` 2 P.^ (rank @sh))))) + seg <- forAll (array (Z:.n) (Gen.int (Range.linear 1 (128 `quot` 2 P.^ (S.rank @sh))))) arr <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (A.scanl1Seg (+)) in go arr seg ~~~ scanl1SegRef (+) arr seg @@ -668,7 +668,7 @@ test_scanl'Seg_sum runN dim z e = sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 0 64)) n <- pure (P.min n1 n2) -- don't generate too many segments - seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (rank @sh))))) + seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (S.rank @sh))))) arr <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (\v -> A.scanl'Seg (+) (the v)) in go (scalar x) arr seg ~~~ scanl'SegRef (+) x arr seg @@ -685,7 +685,7 @@ test_scanrSeg_sum runN dim z e = sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 0 64)) n <- pure (P.min n1 n2) -- don't generate too many segments - seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (rank @sh))))) + seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (S.rank @sh))))) arr <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (\v -> A.scanrSeg (+) (the v)) in go (scalar x) arr seg ~~~ scanrSegRef (+) x arr seg @@ -700,7 +700,7 @@ test_scanr1Seg_sum runN dim e = sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 1 64)) n <- pure (P.min n1 n2) -- don't generate too many segments - seg <- forAll (array (Z:.n) (Gen.int (Range.linear 1 (128 `quot` 2 P.^ (rank @sh))))) + seg <- forAll (array (Z:.n) (Gen.int (Range.linear 1 (128 `quot` 2 P.^ (S.rank @sh))))) arr <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (A.scanr1Seg (+)) in go arr seg ~~~ scanr1SegRef (+) arr seg @@ -717,7 +717,7 @@ test_scanr'Seg_sum runN dim z e = sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 0 64)) n <- pure (P.min n1 n2) -- don't generate too many segments - seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (rank @sh))))) + seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (S.rank @sh))))) arr <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (\v -> A.scanr'Seg (+) (the v)) in go (scalar x) arr seg ~~~ scanr'SegRef (+) x arr seg diff --git a/src/Data/Array/Accelerate/Test/NoFib/Sharing.hs b/src/Data/Array/Accelerate/Test/NoFib/Sharing.hs index 7fada0655..722666522 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Sharing.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Sharing.hs @@ -85,13 +85,13 @@ mkArray n = use $ fromList (Z:.1) [n] test_blowup :: Int -> Acc (Array DIM1 Int) test_blowup 0 = (mkArray 0) -test_blowup n = A.map (\_ -> newArr ! (lift (Z:.(0::Int))) + - newArr ! (lift (Z:.(1::Int)))) (mkArray n) +test_blowup n = A.map (\_ -> newArr ! (Z :. 0) + + newArr ! (Z :. 1)) (mkArray n) where newArr = test_blowup (n-1) idx :: Int -> Exp DIM1 -idx i = lift (Z:.i) +idx i = constant (Z:.i) test_bfs :: Acc (Array DIM1 Int) test_bfs = A.map (\x -> (map2 ! (idx 1)) + (map1 ! (idx 2)) + x) arr From 4b2c1d72aeaf159f56134625c2bacec4eaf41d45 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Tue, 14 Jun 2022 12:20:03 +0200 Subject: [PATCH 26/86] add COMPLETE pragmas --- src/Data/Array/Accelerate/Pattern/Shape.hs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Shape.hs b/src/Data/Array/Accelerate/Pattern/Shape.hs index 732e9906c..e1660f40f 100644 --- a/src/Data/Array/Accelerate/Pattern/Shape.hs +++ b/src/Data/Array/Accelerate/Pattern/Shape.hs @@ -34,23 +34,27 @@ import qualified Data.Array.Accelerate.Sugar.Shape as Sugar pattern Z :: IsShapeZ z => z pattern Z <- (z_matcher -> True) where Z = z_builder -{-# COMPLETE Z #-} +{-# COMPLETE Z :: Z #-} +{-# COMPLETE Z :: Exp #-} pattern All :: IsShapeAll all => all pattern All <- (all_matcher -> True) where All = all_builder -{-# COMPLETE All #-} +{-# COMPLETE All :: All #-} +{-# COMPLETE All :: Exp #-} pattern Any :: IsShapeAny any => any pattern Any <- (any_matcher -> True) where Any = any_builder -{-# COMPLETE Any #-} +{-# COMPLETE Any :: Any #-} +{-# COMPLETE Any :: Exp #-} infixl 3 :. pattern (:.) :: IsShapeSnoc t h s => t -> h -> s pattern t :. h <- (snoc_matcher -> (t Sugar.:. h)) where t :. h = snoc_builder t h -{-# COMPLETE (:.) #-} +{-# COMPLETE (:.) :: (:.) #-} +{-# COMPLETE (:.) :: Exp #-} class IsShapeZ z where From f390343a53e369d300e0b007f6d156f27b73c6ee Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Tue, 14 Jun 2022 14:23:34 +0200 Subject: [PATCH 27/86] fix doctests --- src/Data/Array/Accelerate/Language.hs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Data/Array/Accelerate/Language.hs b/src/Data/Array/Accelerate/Language.hs index d1b2d3bdb..ad3f27284 100644 --- a/src/Data/Array/Accelerate/Language.hs +++ b/src/Data/Array/Accelerate/Language.hs @@ -112,6 +112,7 @@ import Prelude ( ($), (.), -- $setup -- >>> :seti -XFlexibleContexts -- >>> :seti -XScopedTypeVariables +-- >>> :seti -XTypeApplications -- >>> :seti -XTypeOperators -- >>> :seti -XViewPatterns -- >>> import Data.Array.Accelerate From 11829ff9de5375f29c6bbf57a1a2aeaed4a02163 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Wed, 15 Jun 2022 18:24:36 +0200 Subject: [PATCH 28/86] drop ghc-8.6 as it crashes the compiler [ 99 of 112] Compiling Data.Array.Accelerate.Lift ghc: panic! (the 'impossible' happened) (GHC version 8.6.5 for x86_64-unknown-linux): expectJust mkOneConFull CallStack (from HasCallStack): error, called at compiler/utils/Maybes.hs:55:27 in ghc:Maybes expectJust, called at compiler/deSugar/Check.hs:1312:37 in ghc:Check --- .github/workflows/ci-linux.yml | 3 +-- .github/workflows/ci-macos.yml | 1 - .github/workflows/ci-windows.yml | 1 - accelerate.cabal | 2 +- stack-8.6.yaml | 39 -------------------------------- 5 files changed, 2 insertions(+), 44 deletions(-) delete mode 100644 stack-8.6.yaml diff --git a/.github/workflows/ci-linux.yml b/.github/workflows/ci-linux.yml index 7dd1d722e..5fd41a986 100644 --- a/.github/workflows/ci-linux.yml +++ b/.github/workflows/ci-linux.yml @@ -24,7 +24,6 @@ jobs: - "9.0" - "8.10" - "8.8" - - "8.6" env: STACK_FLAGS: "--fast --flag accelerate:nofib" HADDOCK_FLAGS: "--haddock --no-haddock-deps --no-haddock-hyperlink-source --haddock-arguments=\"--no-print-missing-docs\"" @@ -72,7 +71,7 @@ jobs: - name: Test doctest run: stack test accelerate:doctest $STACK_FLAGS - if: ${{ matrix.ghc != '8.6' && matrix.ghc != '8.8' }} + if: ${{ matrix.ghc != '8.8' }} - name: Test nofib run: stack test accelerate:nofib-interpreter $STACK_FLAGS diff --git a/.github/workflows/ci-macos.yml b/.github/workflows/ci-macos.yml index f3e83e266..edb79c953 100644 --- a/.github/workflows/ci-macos.yml +++ b/.github/workflows/ci-macos.yml @@ -24,7 +24,6 @@ jobs: - "9.0" - "8.10" - "8.8" - - "8.6" env: STACK_FLAGS: "--fast --flag accelerate:nofib" HADDOCK_FLAGS: "--haddock --no-haddock-deps --no-haddock-hyperlink-source --haddock-arguments=\"--no-print-missing-docs\"" diff --git a/.github/workflows/ci-windows.yml b/.github/workflows/ci-windows.yml index 03f066140..49cb7f290 100644 --- a/.github/workflows/ci-windows.yml +++ b/.github/workflows/ci-windows.yml @@ -24,7 +24,6 @@ jobs: - "9.0" - "8.10" - "8.8" - - "8.6" env: __COMPAT_LAYER: "" diff --git a/accelerate.cabal b/accelerate.cabal index 5fc01625c..deac8aed4 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -341,7 +341,7 @@ flag nofib library build-depends: - base >= 4.12 && < 4.17 + base >= 4.13 && < 4.17 , ansi-terminal >= 0.6.2 , base-orphans >= 0.3 , bytestring >= 0.10.2 diff --git a/stack-8.6.yaml b/stack-8.6.yaml deleted file mode 100644 index 5d3724662..000000000 --- a/stack-8.6.yaml +++ /dev/null @@ -1,39 +0,0 @@ -# For more information, see: https://github.com/commercialhaskell/stack/blob/release/doc/yaml_configuration.md -# vim: nospell - -resolver: lts-14.27 - -packages: -- . - -extra-deps: -- formatting-7.1.3 -- prettyprinter-1.7.1 -- prettyprinter-ansi-terminal-1.1.3 -- tasty-rerun-1.1.18 -- text-1.2.4.1 - -# Override default flag values for local packages and extra-deps -# flags: {} - -# Extra global and per-package GHC options -# ghc-options: {} - -# Extra package databases containing global packages -# extra-package-dbs: [] - -# Control whether we use the GHC we find on the path -# system-ghc: true - -# Require a specific version of stack, using version ranges -# require-stack-version: -any # Default -# require-stack-version: >= 0.1.4.0 - -# Override the architecture used by stack, especially useful on Windows -# arch: i386 -# arch: x86_64 - -# Extra directories used by stack for building -# extra-include-dirs: [/path/to/dir] -# extra-lib-dirs: [/path/to/dir] - From 7b614fc294d272294b3a00ae27ea443ae8dd2630 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Sat, 18 Jun 2022 22:58:07 +0200 Subject: [PATCH 29/86] fix pattern matching for bit tags --- src/Data/Array/Accelerate/AST.hs | 10 +- src/Data/Array/Accelerate/Analysis/Hash.hs | 14 ++- src/Data/Array/Accelerate/Analysis/Match.hs | 25 ++++- src/Data/Array/Accelerate/Interpreter.hs | 21 +++- src/Data/Array/Accelerate/Pattern/Bool.hs | 8 +- src/Data/Array/Accelerate/Pretty/Print.hs | 9 +- .../Array/Accelerate/Representation/Tag.hs | 74 ++++++------ src/Data/Array/Accelerate/Sugar/Elt.hs | 4 +- src/Data/Array/Accelerate/Sugar/Vec.hs | 15 ++- src/Data/Array/Accelerate/Trafo/Sharing.hs | 105 ++++++++++++------ src/Data/Array/Accelerate/Trafo/Simplify.hs | 29 +++-- 11 files changed, 212 insertions(+), 102 deletions(-) diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index c307275ef..d3c765e75 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -594,9 +594,9 @@ data OpenExp env aenv t where -> OpenExp env aenv sh -- Case statement - Case :: ScalarType TAG - -> OpenExp env aenv TAG - -> [(TAG, OpenExp env aenv b)] -- list of equations + Case :: TagType tag + -> OpenExp env aenv tag + -> [(tag, OpenExp env aenv b)] -- list of equations -> Maybe (OpenExp env aenv b) -- default case -> OpenExp env aenv b @@ -1128,7 +1128,7 @@ rnfOpenExp topExp = IndexFull slice slix sl -> rnfSliceIndex slice `seq` rnfE slix `seq` rnfE sl ToIndex shr sh ix -> rnfShapeR shr `seq` rnfE sh `seq` rnfE ix FromIndex shr sh ix -> rnfShapeR shr `seq` rnfE sh `seq` rnfE ix - Case pR p rhs def -> rnfScalarType pR `seq` rnfE p `seq` rnfList (\(t,c) -> t `seq` rnfE c) rhs `seq` rnfMaybe rnfE def + Case pR p rhs def -> rnfTagType pR `seq` rnfE p `seq` rnfList (\(t,c) -> t `seq` rnfE c) rhs `seq` rnfMaybe rnfE def Cond p e1 e2 -> rnfE p `seq` rnfE e1 `seq` rnfE e2 While p f x -> rnfF p `seq` rnfF f `seq` rnfE x PrimApp f x -> rnfPrimFun f `seq` rnfE x @@ -1346,7 +1346,7 @@ liftOpenExp pexp = IndexFull slice slix sl -> [|| IndexFull $$(liftSliceIndex slice) $$(liftE slix) $$(liftE sl) ||] ToIndex shr sh ix -> [|| ToIndex $$(liftShapeR shr) $$(liftE sh) $$(liftE ix) ||] FromIndex shr sh ix -> [|| FromIndex $$(liftShapeR shr) $$(liftE sh) $$(liftE ix) ||] - Case pR p rhs def -> [|| Case $$(liftScalarType pR) $$(liftE p) $$(liftList (\(t,c) -> [|| ($$(liftScalar pR t), $$(liftE c)) ||]) rhs) $$(liftMaybe liftE def) ||] + Case pR p rhs def -> [|| Case $$(liftTagType pR) $$(liftE p) $$(liftList (\(t,c) -> [|| ($$(liftTag pR t), $$(liftE c)) ||]) rhs) $$(liftMaybe liftE def) ||] Cond p t e -> [|| Cond $$(liftE p) $$(liftE t) $$(liftE e) ||] While p f x -> [|| While $$(liftF p) $$(liftF f) $$(liftE x) ||] PrimApp f x -> [|| PrimApp $$(liftPrimFun f) $$(liftE x) ||] diff --git a/src/Data/Array/Accelerate/Analysis/Hash.hs b/src/Data/Array/Accelerate/Analysis/Hash.hs index c97abcb77..25c96975a 100644 --- a/src/Data/Array/Accelerate/Analysis/Hash.hs +++ b/src/Data/Array/Accelerate/Analysis/Hash.hs @@ -44,6 +44,7 @@ import Data.Array.Accelerate.Representation.Array import Data.Array.Accelerate.Representation.Shape import Data.Array.Accelerate.Representation.Slice import Data.Array.Accelerate.Representation.Stencil +import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Type import Data.Primitive.Vec @@ -331,7 +332,7 @@ encodeOpenExp exp = IndexFull spec ix sl -> intHost $(hashQ "IndexFull") <> travE ix <> travE sl <> encodeSliceIndex spec ToIndex _ sh i -> intHost $(hashQ "ToIndex") <> travE sh <> travE i FromIndex _ sh i -> intHost $(hashQ "FromIndex") <> travE sh <> travE i - Case eR e rhs def -> intHost $(hashQ "Case") <> encodeScalarType eR <> travE e <> mconcat [ encodeScalarConst eR t <> travE c | (t,c) <- rhs ] <> encodeMaybe travE def + Case eR e rhs def -> intHost $(hashQ "Case") <> encodeTagType eR <> travE e <> mconcat [ encodeTagConst eR t <> travE c | (t,c) <- rhs ] <> encodeMaybe travE def Cond c t e -> intHost $(hashQ "Cond") <> travE c <> travE t <> travE e While p f x -> intHost $(hashQ "While") <> travF p <> travF f <> travE x PrimApp f x -> intHost $(hashQ "PrimApp") <> encodePrimFun f <> travE x @@ -397,6 +398,12 @@ encodeSingleFloatingConst TypeFloat32 x = intHost $(hashQ "Flo encodeSingleFloatingConst TypeFloat64 x = intHost $(hashQ "Double") <> doubleHost x encodeSingleFloatingConst TypeFloat128 (Float128 x y) = intHost $(hashQ "Float128") <> word64Host x <> word64Host y +encodeTagConst :: TagType t -> t -> Builder +encodeTagConst TagBit (Bit False) = intHost $(hashQ "Bit") <> int8 0 +encodeTagConst TagBit (Bit True) = intHost $(hashQ "Bit") <> int8 1 +encodeTagConst TagWord8 x = intHost $(hashQ "Tag8") <> word8 x +encodeTagConst TagWord16 x = intHost $(hashQ "Tag16") <> word16Host x + encodePrimFun :: PrimFun f -> Builder encodePrimFun (PrimAdd a) = intHost $(hashQ "PrimAdd") <> encodeNumType a encodePrimFun (PrimSub a) = intHost $(hashQ "PrimSub") <> encodeNumType a @@ -512,6 +519,11 @@ encodeSingleFloatingType TypeFloat32 = intHost $(hashQ "Float") encodeSingleFloatingType TypeFloat64 = intHost $(hashQ "Double") encodeSingleFloatingType TypeFloat128 = intHost $(hashQ "Float128") +encodeTagType :: TagType t -> Builder +encodeTagType TagBit = intHost $(hashQ "TagBit") +encodeTagType TagWord8 = intHost $(hashQ "TagWord8") +encodeTagType TagWord16 = intHost $(hashQ "TagWord16") + encodeMaybe :: (a -> Builder) -> Maybe a -> Builder encodeMaybe _ Nothing = intHost $(hashQ "Nothing") encodeMaybe f (Just x) = intHost $(hashQ "Just") <> f x diff --git a/src/Data/Array/Accelerate/Analysis/Match.hs b/src/Data/Array/Accelerate/Analysis/Match.hs index c25a2620d..db94b1a2a 100644 --- a/src/Data/Array/Accelerate/Analysis/Match.hs +++ b/src/Data/Array/Accelerate/Analysis/Match.hs @@ -36,6 +36,8 @@ module Data.Array.Accelerate.Analysis.Match ( matchIntegralType, matchSingleIntegralType, matchFloatingType, matchSingleFloatingType, matchNumType, + matchBitType, + matchTagType, matchScalarType, matchLeftHandSide, matchALeftHandSide, matchELeftHandSide, matchTupR, @@ -51,6 +53,7 @@ import Data.Array.Accelerate.Representation.Array ( Array(..), import Data.Array.Accelerate.Representation.Shape ( ShapeR(..) ) import Data.Array.Accelerate.Representation.Slice ( SliceIndex(..) ) import Data.Array.Accelerate.Representation.Stencil +import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Type import qualified Data.Array.Accelerate.Sugar.Shape as Sugar @@ -523,17 +526,18 @@ matchOpenExp (FromIndex _ sh1 i1) (FromIndex _ sh2 i2) = Just Refl matchOpenExp (Case eR1 e1 rhs1 def1) (Case eR2 e2 rhs2 def2) - | Just Refl <- matchScalarType eR1 eR2 + | Just Refl <- matchTagType eR1 eR2 , Just Refl <- matchOpenExp e1 e2 , Just Refl <- matchCaseEqs eR1 rhs1 rhs2 , Just Refl <- matchCaseDef def1 def2 = Just Refl where - matchCaseEqs :: ScalarType tag -> [(tag, OpenExp env aenv a)] -> [(tag, OpenExp env aenv b)] -> Maybe (a :~: b) + matchCaseEqs :: TagType tag -> [(tag, OpenExp env aenv a)] -> [(tag, OpenExp env aenv b)] -> Maybe (a :~: b) matchCaseEqs _ [] [] = unsafeCoerce Refl matchCaseEqs tR ((s,x):xs) ((t,y):ys) - | evalEq tR (s,t) + | TagDict <- tagDict tR + , s == t , Just Refl <- matchOpenExp x y , Just Refl <- matchCaseEqs tR xs ys = Just Refl @@ -967,6 +971,13 @@ matchSingleFloatingType TypeFloat64 TypeFloat64 = Just Refl matchSingleFloatingType TypeFloat128 TypeFloat128 = Just Refl matchSingleFloatingType _ _ = Nothing +{-# INLINEABLE matchTagType #-} +matchTagType :: TagType s -> TagType t -> Maybe (s :~: t) +matchTagType TagBit TagBit = Just Refl +matchTagType TagWord8 TagWord8 = Just Refl +matchTagType TagWord16 TagWord16 = Just Refl +matchTagType _ _ = Nothing + -- Auxiliary -- --------- @@ -1001,3 +1012,11 @@ commutes f x = case f of -- | otherwise = e +data TagDict t where + TagDict :: Eq t => TagDict t + +tagDict :: TagType t -> TagDict t +tagDict TagBit = TagDict +tagDict TagWord8 = TagDict +tagDict TagWord16 = TagDict + diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index 2f7548344..dc5c1a059 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -71,7 +71,6 @@ import qualified Data.Array.Accelerate.Debug.Internal.Timed as Debug import qualified Data.Array.Accelerate.Interpreter.Arithmetic as A import qualified Data.Array.Accelerate.Smart as Smart import qualified Data.Array.Accelerate.Sugar.Array as Sugar -import qualified Data.Array.Accelerate.Sugar.Elt as Sugar import qualified Data.Array.Accelerate.Trafo.Delayed as AST import Control.DeepSeq @@ -224,8 +223,8 @@ evalOpenAcc (AST.Manifest pacc) aenv = p = evalOpenAfun cond aenv f = evalOpenAfun body aenv go !x - | toBool (linearIndexArray (Sugar.eltR @Bool) (p x) 0) = go (f x) - | otherwise = x + | toBool (linearIndexArray (TupRsingle (BitScalarType TypeBit)) (p x) 0) = go (f x) + | otherwise = x Use repr arr -> (TupRsingle repr, arr) Unit tp e -> unitOp tp (evalE e) @@ -955,11 +954,12 @@ evalOpenExp pexp env aenv = ToIndex shr sh ix -> toIndex shr (evalE sh) (evalE ix) FromIndex shr sh ix -> fromIndex shr (evalE sh) (evalE ix) - Case _ e rhs def -> evalE (caseof (evalE e) rhs) + Case tagR e rhs def -> evalE (caseof tagR (evalE e) rhs) where - caseof :: TAG -> [(TAG, OpenExp env aenv t)] -> OpenExp env aenv t - caseof tag = go + caseof :: forall tag. TagType tag -> tag -> [(tag, OpenExp env aenv t)] -> OpenExp env aenv t + caseof tagR tag | TagDict <- tagDict tagR = go where + go :: Eq tag => [(tag, OpenExp env aenv t)] -> OpenExp env aenv t go ((t,c):cs) | tag == t = c | otherwise = go cs @@ -1526,6 +1526,9 @@ data IntegralDict t where data FloatingDict t where FloatingDict :: (RealFloat t, Prim t) => FloatingDict t +data TagDict t where + TagDict :: Eq t => TagDict t + {-# INLINE integralDict #-} integralDict :: SingleIntegralType t -> IntegralDict t integralDict TypeInt8 = IntegralDict @@ -1546,3 +1549,9 @@ floatingDict TypeFloat32 = FloatingDict floatingDict TypeFloat64 = FloatingDict floatingDict TypeFloat128 = FloatingDict +{-# INLINE tagDict #-} +tagDict :: TagType t -> TagDict t +tagDict TagBit = TagDict +tagDict TagWord8 = TagDict +tagDict TagWord16 = TagDict + diff --git a/src/Data/Array/Accelerate/Pattern/Bool.hs b/src/Data/Array/Accelerate/Pattern/Bool.hs index 093caae60..8efd6f6a1 100644 --- a/src/Data/Array/Accelerate/Pattern/Bool.hs +++ b/src/Data/Array/Accelerate/Pattern/Bool.hs @@ -43,8 +43,8 @@ buildFalse = mkExp $ Const scalarType 0 matchFalse :: HasCallStack => Exp Bool -> Maybe () matchFalse (Exp e) = case e of - SmartExp (Match (TagRbit TypeBit 0) _) -> Just () - SmartExp Match{} -> Nothing + SmartExp (Match (TagRbit 0) _) -> Just () + SmartExp Match{} -> Nothing _ -> error $ unlines [ "Embedded pattern synonym used outside 'match' context." , "" @@ -64,8 +64,8 @@ buildTrue = mkExp $ Const scalarType 1 matchTrue :: HasCallStack => Exp Bool -> Maybe () matchTrue (Exp e) = case e of - SmartExp (Match (TagRbit TypeBit 1) _) -> Just () - SmartExp Match{} -> Nothing + SmartExp (Match (TagRbit 1) _) -> Just () + SmartExp Match{} -> Nothing _ -> error $ unlines [ "Embedded pattern synonym used outside 'match' context." , "" diff --git a/src/Data/Array/Accelerate/Pretty/Print.hs b/src/Data/Array/Accelerate/Pretty/Print.hs index 15f5a09df..37f6cf9b0 100644 --- a/src/Data/Array/Accelerate/Pretty/Print.hs +++ b/src/Data/Array/Accelerate/Pretty/Print.hs @@ -63,6 +63,7 @@ import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Representation.Array import Data.Array.Accelerate.Representation.Elt import Data.Array.Accelerate.Representation.Stencil +import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Sugar.Foreign import Data.Array.Accelerate.Type @@ -559,7 +560,7 @@ prettyTuple ctx env aenv exp = case collect exp of prettyCase :: Val env -> Val aenv - -> ScalarType tag + -> TagType tag -> OpenExp env aenv a -> [(tag, OpenExp env aenv b)] -> Maybe (OpenExp env aenv b) @@ -571,7 +572,11 @@ prettyCase env aenv tagR x xs def ] where x' = prettyOpenExp context0 env aenv x - xs' = map (\(t,e) -> prettyConst (TupRsingle tagR) t <+> "->" <+> prettyOpenExp context0 env aenv e) xs + tR = case tagR of + TagBit -> scalarType + TagWord8 -> scalarType + TagWord16 -> scalarType + xs' = map (\(t,e) -> prettyConst (TupRsingle tR) t <+> "->" <+> prettyOpenExp context0 env aenv e) xs ++ case def of Nothing -> [] Just d -> ["_" <+> "->" <+> prettyOpenExp context0 env aenv d] diff --git a/src/Data/Array/Accelerate/Representation/Tag.hs b/src/Data/Array/Accelerate/Representation/Tag.hs index 034390dbf..081f53e22 100644 --- a/src/Data/Array/Accelerate/Representation/Tag.hs +++ b/src/Data/Array/Accelerate/Representation/Tag.hs @@ -16,7 +16,6 @@ module Data.Array.Accelerate.Representation.Tag where import Data.Array.Accelerate.Type -import Data.Primitive.Bit import Language.Haskell.TH.Extra @@ -26,6 +25,11 @@ import Language.Haskell.TH.Extra -- type TAG = Word8 +data TagType t where + TagBit :: TagType Bit + TagWord8 :: TagType Word8 + TagWord16 :: TagType Word16 + -- | This structure both witnesses the layout of our representation types -- (as TupR does) and represents a complete path of pattern matching -- through this type. It indicates which fields of the structure represent @@ -44,47 +48,51 @@ data TagR a where TagRunit :: TagR () TagRsingle :: ScalarType a -> TagR a TagRundef :: ScalarType a -> TagR a + TagRtag :: TagType t -> t -> TagR a -> TagR (t, a) + TagRbit :: Bit -> TagR Bit -- redundant with TagRtag but simplifies abstract syntax for Bool TagRpair :: TagR a -> TagR b -> TagR (a, b) - TagRtag :: SingleIntegralType t -> t -> TagR a -> TagR (t, a) - TagRbit :: BitType t -> t -> TagR t instance Show (TagR a) where show TagRunit = "()" show TagRsingle{} = "." show TagRundef{} = "undef" show (TagRpair ta tb) = "(" ++ show ta ++ "," ++ show tb ++ ")" - show (TagRtag tR t e) = "(" ++ integral tR t ++ "#," ++ show e ++ ")" - where - integral :: SingleIntegralType t -> t -> String - integral TypeInt8 = show - integral TypeInt16 = show - integral TypeInt32 = show - integral TypeInt64 = show - integral TypeInt128 = show - integral TypeWord8 = show - integral TypeWord16 = show - integral TypeWord32 = show - integral TypeWord64 = show - integral TypeWord128 = show - show (TagRbit tR t) = bit tR t + show (TagRbit b) = shows b "#" + show (TagRtag tR t e) = "(" ++ tag tR t ++ "#," ++ show e ++ ")" where - bit :: BitType t -> t -> String - bit TypeBit x = shows x "#" - bit TypeMask{} x = shows (BitMask x) "#" + tag :: TagType t -> t -> String + tag TagBit = show + tag TagWord8 = show + tag TagWord16 = show rnfTag :: TagR a -> () -rnfTag TagRunit = () -rnfTag (TagRsingle e) = rnfScalarType e -rnfTag (TagRundef e) = rnfScalarType e -rnfTag (TagRpair ta tb) = rnfTag ta `seq` rnfTag tb -rnfTag (TagRtag tR t e) = rnfSingleIntegralType tR `seq` t `seq` rnfTag e -rnfTag (TagRbit tR t) = rnfBitType tR `seq` t `seq` () +rnfTag TagRunit = () +rnfTag (TagRsingle e) = rnfScalarType e +rnfTag (TagRundef e) = rnfScalarType e +rnfTag (TagRpair ta tb) = rnfTag ta `seq` rnfTag tb +rnfTag (TagRbit (Bit b)) = b `seq` () +rnfTag (TagRtag tR t e) = rnfTagType tR `seq` t `seq` rnfTag e + +rnfTagType :: TagType t -> () +rnfTagType TagBit = () +rnfTagType TagWord8 = () +rnfTagType TagWord16 = () + +liftTagR :: TagR a -> CodeQ (TagR a) +liftTagR TagRunit = [|| TagRunit ||] +liftTagR (TagRsingle e) = [|| TagRsingle $$(liftScalarType e) ||] +liftTagR (TagRundef e) = [|| TagRundef $$(liftScalarType e) ||] +liftTagR (TagRpair ta tb) = [|| TagRpair $$(liftTagR ta) $$(liftTagR tb) ||] +liftTagR (TagRtag tR t e) = [|| TagRtag $$(liftTagType tR) $$(liftTag tR t) $$(liftTagR e) ||] +liftTagR (TagRbit (Bit b)) = [|| TagRbit (Bit b) ||] + +liftTag :: TagType t -> t -> CodeQ t +liftTag TagBit (Bit x) = [|| Bit x ||] +liftTag TagWord8 x = [|| x ||] +liftTag TagWord16 x = [|| x ||] -liftTag :: TagR a -> CodeQ (TagR a) -liftTag TagRunit = [|| TagRunit ||] -liftTag (TagRsingle e) = [|| TagRsingle $$(liftScalarType e) ||] -liftTag (TagRundef e) = [|| TagRundef $$(liftScalarType e) ||] -liftTag (TagRpair ta tb) = [|| TagRpair $$(liftTag ta) $$(liftTag tb) ||] -liftTag (TagRtag tR t e) = [|| TagRtag $$(liftSingleIntegralType tR) $$(liftSingleIntegral tR t) $$(liftTag e) ||] -liftTag (TagRbit tR t) = [|| TagRbit $$(liftBitType tR) $$(liftBit tR t) ||] +liftTagType :: TagType t -> CodeQ (TagType t) +liftTagType TagBit = [|| TagBit ||] +liftTagType TagWord8 = [|| TagWord8 ||] +liftTagType TagWord16 = [|| TagWord16 ||] diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index 441c45223..e3acfdbcb 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -164,7 +164,7 @@ instance (GElt a, GElt b) => GElt (a :*: b) where instance (GElt a, GElt b, GSumElt (a :+: b)) => GElt (a :+: b) where type GEltR t (a :+: b) = (TAG, GSumEltR t (a :+: b)) geltR t = TupRpair (TupRsingle scalarType) (gsumEltR @(a :+: b) t) - gtagsR t = uncurry (TagRtag singleIntegralType) <$> gsumTagsR @(a :+: b) 0 t + gtagsR t = uncurry (TagRtag TagWord8) <$> gsumTagsR @(a :+: b) 0 t gfromElt = gsumFromElt 0 gtoElt (k,x) = gsumToElt k x gundef t = (0xff, gsumUndef @(a :+: b) t) @@ -290,7 +290,7 @@ instance (Elt a, Elt b) => Elt (Either a b) instance Elt Bool where type EltR Bool = Bit eltR = TupRsingle scalarType - tagsR = [TagRbit TypeBit 0, TagRbit TypeBit 1] + tagsR = [TagRbit 0, TagRbit 1] toElt = unBit fromElt = Bit diff --git a/src/Data/Array/Accelerate/Sugar/Vec.hs b/src/Data/Array/Accelerate/Sugar/Vec.hs index 80a4cb64a..16066ab67 100644 --- a/src/Data/Array/Accelerate/Sugar/Vec.hs +++ b/src/Data/Array/Accelerate/Sugar/Vec.hs @@ -160,17 +160,18 @@ class KnownNat n => SIMD n a where type VecR n a = GVecR () n (Rep a) vecR :: TypeR (VecR n a) - vtagsR :: [TagR (VecR n a)] -- this will quickly get out of hand! + vtagsR :: [TagR (VecR n a)] default vecR :: (GVec n (Rep a), VecR n a ~ GVecR () n (Rep a)) => TypeR (VecR n a) vecR = gvecR @n @(Rep a) TupRunit - default vtagsR - :: (GVec n (Rep a), VecR n a ~ GVecR () n (Rep a)) - => [TagR (VecR n a)] - vtagsR = gvtagsR @n @(Rep a) TagRunit + -- default vtagsR + -- :: (GVec n (Rep a), VecR n a ~ GVecR () n (Rep a)) + -- => [TagR (VecR n a)] + -- vtagsR = gvtagsR @n @(Rep a) TagRunit + vtagsR = [tagOfType (vecR @n @a)] class KnownNat n => GVec n (f :: Type -> Type) where type GVecR t n f @@ -229,6 +230,10 @@ instance (GSumVec n a, GSumVec n b) => GSumVec n (a :+: b) where type GSumVecR t n (a :+: b) = GSumVecR (GSumVecR t n a) n b gsumvecR = gsumvecR @n @b . gsumvecR @n @a +tagOfType :: TypeR a -> TagR a +tagOfType TupRunit = TagRunit +tagOfType (TupRpair s t) = TagRpair (tagOfType s) (tagOfType t) +tagOfType (TupRsingle t) = TagRsingle t instance KnownNat n => SIMD n Z instance KnownNat n => SIMD n () diff --git a/src/Data/Array/Accelerate/Trafo/Sharing.hs b/src/Data/Array/Accelerate/Trafo/Sharing.hs index 3aa082e33..0beedf2b3 100644 --- a/src/Data/Array/Accelerate/Trafo/Sharing.hs +++ b/src/Data/Array/Accelerate/Trafo/Sharing.hs @@ -807,7 +807,7 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp AST.Let lhs bnd body -> AST.Let lhs bnd (cvtPrimFun f body) x -> AST.PrimApp f x - -- Convert the flat list of equations into nested case statement + -- Convert the flat list of equations into nested case statements -- directly on the tag variables. -- cvtCase :: HasCallStack => AST.OpenExp env' aenv' a -> [(TagR a, AST.OpenExp env' aenv' b)] -> AST.OpenExp env' aenv' b @@ -818,27 +818,43 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp = AST.Let lhs s $ nested (expVars (value weakenId)) (over (mapped . _2) (weakenE (weakenWithLHS lhs)) es) where nested :: HasCallStack => AST.OpenExp env' aenv' a -> [(TagR a, AST.OpenExp env' aenv' b)] -> AST.OpenExp env' aenv' b + nested _ [] = internalError "empty case" nested _ [(_,r)] = r - nested s rs = - let groups = groupBy (eqT `on` fst) rs - tags = map (firstT . fst . head) groups - e = prjT (fst (head rs)) s - rhs = map (nested s . map (over _1 ignore)) groups - in - AST.Case scalarType e (zip tags rhs) Nothing + nested s rs@(r:_) + | Exists tag <- tagT (fst r) + = let groups = groupBy (eqT `on` fst) rs + rhs = map (nested s . map (over _1 ignore)) groups + e = prjT tag (fst r) s + tags = map (firstT tag . fst . head) groups + in + AST.Case tag e (zip tags rhs) Nothing + + tagT :: TagR a -> Exists TagType + tagT = fromJust . go + where + go ::TagR a -> Maybe (Exists TagType) + go (TagRtag t _ _) = Just (Exists t) + go (TagRbit _) = Just (Exists TagBit) + go (TagRpair a b) + | Just t <- go a = Just t + | otherwise = go b + go _ = Nothing -- Extract the variable representing this particular tag from the -- scrutinee. This is safe because we let-bind the argument first. - prjT :: TagR a -> AST.OpenExp env' aenv' a -> AST.OpenExp env' aenv' TAG - prjT = fromJust $$ go + prjT :: forall a t env' aenv'. TagType t -> TagR a -> AST.OpenExp env' aenv' a -> AST.OpenExp env' aenv' t + prjT t = fromJust $$ go where - go :: TagR a -> AST.OpenExp env' aenv' a -> Maybe (AST.OpenExp env' aenv' TAG) - go TagRbit{} _ = error "TODO: TagRbit" - go (TagRtag TypeWord8 _ _) (AST.Pair l _) = Just l - go (TagRpair ta tb) (AST.Pair l r) = - case go ta l of - Just t -> Just t - Nothing -> go tb r + go :: TagR s -> AST.OpenExp env' aenv' s -> Maybe (AST.OpenExp env' aenv' t) + go (TagRtag s _ _) (AST.Pair l _) + | Just Refl <- matchTagType s t + = Just l + go (TagRbit _) x + | Just Refl <- matchTagType t TagBit + = Just x + go (TagRpair ta tb) (AST.Pair l r) + | Just x <- go ta l = Just x + | otherwise = go tb r go _ _ = Nothing -- Equality up to the first constructor tag encountered @@ -846,28 +862,33 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp eqT a b = snd $ go a b where go :: TagR a -> TagR a -> (Any, Bool) - go TagRunit TagRunit = no True - go TagRsingle{} TagRsingle{} = no True - go TagRundef{} TagRundef{} = no True - go (TagRtag TypeWord8 v1 _) (TagRtag TypeWord8 v2 _) = yes (v1 == v2) - go (TagRbit TypeBit v1) (TagRbit TypeBit v2) = yes (v1 == v2) - go (TagRpair a1 b1) (TagRpair a2 b2) = + go TagRunit TagRunit = no True + go TagRsingle{} TagRsingle{} = no True + go TagRundef{} TagRundef{} = no True + go (TagRtag t1 v1 _) (TagRtag t2 v2 _) + | Just Refl <- matchTagType t1 t2 + , TagDict <- tagDict t1 + = yes (v1 == v2) + go (TagRpair a1 b1) (TagRpair a2 b2) = let (Any r, s) = go a1 a2 in case r of True -> yes s False -> go b1 b2 go _ _ = no False - firstT :: TagR a -> TAG - firstT = fromJust . go + firstT :: forall a t. TagType t -> TagR a -> t + firstT t = fromJust . go where - go :: TagR a -> Maybe TAG - go TagRbit{} = error "TODO: TagRbit" - go (TagRtag TypeWord8 v _) = Just v - go (TagRpair a b) = - case go a of - Just t -> Just t - Nothing -> go b + go :: TagR s -> Maybe t + go (TagRtag s v _) + | Just Refl <- matchTagType s t + = Just v + go (TagRbit v) + | Just Refl <- matchTagType t TagBit + = Just v + go (TagRpair a b) + | Just v <- go a = Just v + | otherwise = go b go _ = Nothing -- Replace the first constructor tag encountered with a regular @@ -879,8 +900,12 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp go TagRunit = no $ TagRunit go (TagRsingle t) = no $ TagRsingle t go (TagRundef t) = no $ TagRundef t - go (TagRtag t _ a) = yes $ TagRpair (TagRundef (NumScalarType (IntegralNumType (SingleIntegralType t)))) a - go TagRbit{} = error "TODO: TagRbit" + go (TagRbit _) = yes $ TagRsingle scalarType + go (TagRtag t _ a) = + case t of + TagBit -> yes $ TagRpair (TagRundef scalarType) a + TagWord8 -> yes $ TagRpair (TagRundef scalarType) a + TagWord16 -> yes $ TagRpair (TagRundef scalarType) a go (TagRpair a1 a2) = let (Any r, a1') = go a1 in case r of @@ -3153,6 +3178,18 @@ recoverSharingSeq config seq --} +-- Utilities +-- --------- + +data TagDict t where + TagDict :: Eq t => TagDict t + +tagDict :: TagType t -> TagDict t +tagDict TagBit = TagDict +tagDict TagWord8 = TagDict +tagDict TagWord16 = TagDict + + -- Debugging -- --------- diff --git a/src/Data/Array/Accelerate/Trafo/Simplify.hs b/src/Data/Array/Accelerate/Trafo/Simplify.hs index 55a473dfb..5b9bd3fe5 100644 --- a/src/Data/Array/Accelerate/Trafo/Simplify.hs +++ b/src/Data/Array/Accelerate/Trafo/Simplify.hs @@ -334,13 +334,14 @@ simplifyOpenExp env = first getAny . cvtE | Just Refl <- matchOpenExp t' e' = Stats.knownBranch "redundant" (yes e') | otherwise = Cond <$> p <*> t <*> e - caseof :: ScalarType TAG - -> (Any, OpenExp env aenv TAG) - -> (Any, [(TAG, OpenExp env aenv b)]) + caseof :: TagType tag + -> (Any, OpenExp env aenv tag) + -> (Any, [(tag, OpenExp env aenv b)]) -> (Any, Maybe (OpenExp env aenv b)) -> (Any, OpenExp env aenv b) caseof tagR x@(_,x') xs@(_,xs') md@(_,md') | Const _ t <- x' + , TagDict <- tagDict tagR = Stats.caseElim "known" (yes (fromJust $ lookup t xs')) | Just d <- md' , [] <- xs' @@ -361,13 +362,14 @@ simplifyOpenExp env = first getAny . cvtE where (us,vs) = partition (\(n,_) -> n > 1) $ Map.elems - . Map.fromListWith merge + . Map.fromListWith (merge tagR) $ [ (hashOpenExp e, (1,(t, e))) | (t,e) <- xs' ] - merge :: (Int, (TAG, OpenExp env aenv b)) -> (Int, (TAG, OpenExp env aenv b)) -> (Int, (TAG, OpenExp env aenv b)) - merge (n,(_,a)) (m,(_,b)) + merge :: TagType tag -> (Int, (tag, OpenExp env aenv b)) -> (Int, (tag, OpenExp env aenv b)) -> (Int, (tag, OpenExp env aenv b)) + merge t (n,(_,a)) (m,(_,b)) + | TagDict <- tagDict t = internalCheck "hashOpenExp/collision" (maybe False (const True) (matchOpenExp a b)) - $ (n+m, (0xff, a)) + $ (n+m, (maxBound, a)) -- Shape manipulations -- @@ -406,6 +408,7 @@ simplifyOpenExp env = first getAny . cvtE yes :: x -> (Any, x) yes x = (Any True, x) + extractConstTuple :: OpenExp env aenv t -> Maybe t extractConstTuple Nil = Just () extractConstTuple (Pair e1 e2) = (,) <$> extractConstTuple e1 <*> extractConstTuple e2 @@ -674,3 +677,15 @@ summariseOpenExp = (terms +~ 1) . goE PrimToBool i b -> travIntegralType i +++ travBitType b PrimFromBool b i -> travBitType b +++ travIntegralType i + +-- Utilities +-- --------- + +data TagDict t where + TagDict :: (Eq t, Bounded t) => TagDict t + +tagDict :: TagType t -> TagDict t +tagDict TagBit = TagDict +tagDict TagWord8 = TagDict +tagDict TagWord16 = TagDict + From 6a3234b887ab2e863e71abbcbaabad7fc6185ee4 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 20 Jun 2022 11:38:12 +0200 Subject: [PATCH 30/86] clean up constructor/enum tags --- src/Data/Array/Accelerate/Pattern/Bool.hs | 8 +-- src/Data/Array/Accelerate/Pattern/TH.hs | 2 +- .../Array/Accelerate/Representation/Tag.hs | 50 +++++++++---------- src/Data/Array/Accelerate/Sugar/Elt.hs | 4 +- src/Data/Array/Accelerate/Trafo/Sharing.hs | 36 +++++++------ 5 files changed, 54 insertions(+), 46 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Bool.hs b/src/Data/Array/Accelerate/Pattern/Bool.hs index 8efd6f6a1..08293e975 100644 --- a/src/Data/Array/Accelerate/Pattern/Bool.hs +++ b/src/Data/Array/Accelerate/Pattern/Bool.hs @@ -43,8 +43,8 @@ buildFalse = mkExp $ Const scalarType 0 matchFalse :: HasCallStack => Exp Bool -> Maybe () matchFalse (Exp e) = case e of - SmartExp (Match (TagRbit 0) _) -> Just () - SmartExp Match{} -> Nothing + SmartExp (Match (TagRenum TagBit 0) _) -> Just () + SmartExp Match{} -> Nothing _ -> error $ unlines [ "Embedded pattern synonym used outside 'match' context." , "" @@ -64,8 +64,8 @@ buildTrue = mkExp $ Const scalarType 1 matchTrue :: HasCallStack => Exp Bool -> Maybe () matchTrue (Exp e) = case e of - SmartExp (Match (TagRbit 1) _) -> Just () - SmartExp Match{} -> Nothing + SmartExp (Match (TagRenum TagBit 1) _) -> Just () + SmartExp Match{} -> Nothing _ -> error $ unlines [ "Embedded pattern synonym used outside 'match' context." , "" diff --git a/src/Data/Array/Accelerate/Pattern/TH.hs b/src/Data/Array/Accelerate/Pattern/TH.hs index 2e9a2183b..7874f4a4c 100644 --- a/src/Data/Array/Accelerate/Pattern/TH.hs +++ b/src/Data/Array/Accelerate/Pattern/TH.hs @@ -335,7 +335,7 @@ mkConS tn' tvs' prev' next' tag' con' = do (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs)) [t| Exp $(foldl' appT (conT tn) (map varT tvs)) -> Maybe $(tupT (map (\t -> [t| Exp $(return t) |]) fs)) |] - matchP us = [p| TagRtag _ $(litP (IntegerL (toInteger tag))) $pat |] + matchP us = [p| TagRcon _ $(litP (IntegerL (toInteger tag))) $pat |] where pat = [p| $(foldl (\ps p -> [p| TagRpair $ps $p |]) [p| TagRunit |] us) |] diff --git a/src/Data/Array/Accelerate/Representation/Tag.hs b/src/Data/Array/Accelerate/Representation/Tag.hs index 081f53e22..eace299f6 100644 --- a/src/Data/Array/Accelerate/Representation/Tag.hs +++ b/src/Data/Array/Accelerate/Representation/Tag.hs @@ -33,45 +33,45 @@ data TagType t where -- | This structure both witnesses the layout of our representation types -- (as TupR does) and represents a complete path of pattern matching -- through this type. It indicates which fields of the structure represent --- the union tags (TagRtag) or store undefined values (TagRundef). +-- the union tags (TagRcon or TagRenum) or store undefined values (TagRundef). -- --- The function 'eltTags' produces all valid paths through the type. For +-- The function 'tagsR' produces all valid paths through the type. For -- example the type '(Bool,Bool)' produces the following: -- -- ghci> putStrLn . unlines . map show $ tagsR @(Bool,Bool) --- (((),(0#,())),(0#,())) -- (False, False) --- (((),(0#,())),(1#,())) -- (False, True) --- (((),(1#,())),(0#,())) -- (True, False) --- (((),(1#,())),(1#,())) -- (True, True) +-- (((),0#),0#) -- (False, False) +-- (((),0#),1#) -- (False, True) +-- (((),1#),0#) -- (True, False) +-- (((),1#),1#) -- (True, True) -- data TagR a where TagRunit :: TagR () TagRsingle :: ScalarType a -> TagR a TagRundef :: ScalarType a -> TagR a - TagRtag :: TagType t -> t -> TagR a -> TagR (t, a) - TagRbit :: Bit -> TagR Bit -- redundant with TagRtag but simplifies abstract syntax for Bool TagRpair :: TagR a -> TagR b -> TagR (a, b) + TagRcon :: TagType t -> t -> TagR a -> TagR (t, a) -- data constructors + TagRenum :: TagType t -> t -> TagR t -- enumerations instance Show (TagR a) where show TagRunit = "()" show TagRsingle{} = "." show TagRundef{} = "undef" - show (TagRpair ta tb) = "(" ++ show ta ++ "," ++ show tb ++ ")" - show (TagRbit b) = shows b "#" - show (TagRtag tR t e) = "(" ++ tag tR t ++ "#," ++ show e ++ ")" - where - tag :: TagType t -> t -> String - tag TagBit = show - tag TagWord8 = show - tag TagWord16 = show + show (TagRpair a b) = "(" ++ show a ++ "," ++ show b ++ ")" + show (TagRcon tR t e) = "(" ++ showTag tR t ++ "#," ++ show e ++ ")" + show (TagRenum tR t) = showTag tR t ++ "#" + +showTag :: TagType t -> t -> String +showTag TagBit = show +showTag TagWord8 = show +showTag TagWord16 = show rnfTag :: TagR a -> () rnfTag TagRunit = () rnfTag (TagRsingle e) = rnfScalarType e rnfTag (TagRundef e) = rnfScalarType e -rnfTag (TagRpair ta tb) = rnfTag ta `seq` rnfTag tb -rnfTag (TagRbit (Bit b)) = b `seq` () -rnfTag (TagRtag tR t e) = rnfTagType tR `seq` t `seq` rnfTag e +rnfTag (TagRpair a b) = rnfTag a `seq` rnfTag b +rnfTag (TagRenum tR t) = rnfTagType tR `seq` t `seq` () +rnfTag (TagRcon tR t e) = rnfTagType tR `seq` t `seq` rnfTag e rnfTagType :: TagType t -> () rnfTagType TagBit = () @@ -79,12 +79,12 @@ rnfTagType TagWord8 = () rnfTagType TagWord16 = () liftTagR :: TagR a -> CodeQ (TagR a) -liftTagR TagRunit = [|| TagRunit ||] -liftTagR (TagRsingle e) = [|| TagRsingle $$(liftScalarType e) ||] -liftTagR (TagRundef e) = [|| TagRundef $$(liftScalarType e) ||] -liftTagR (TagRpair ta tb) = [|| TagRpair $$(liftTagR ta) $$(liftTagR tb) ||] -liftTagR (TagRtag tR t e) = [|| TagRtag $$(liftTagType tR) $$(liftTag tR t) $$(liftTagR e) ||] -liftTagR (TagRbit (Bit b)) = [|| TagRbit (Bit b) ||] +liftTagR TagRunit = [|| TagRunit ||] +liftTagR (TagRsingle e) = [|| TagRsingle $$(liftScalarType e) ||] +liftTagR (TagRundef e) = [|| TagRundef $$(liftScalarType e) ||] +liftTagR (TagRpair a b) = [|| TagRpair $$(liftTagR a) $$(liftTagR b) ||] +liftTagR (TagRcon tR t e) = [|| TagRcon $$(liftTagType tR) $$(liftTag tR t) $$(liftTagR e) ||] +liftTagR (TagRenum tR t) = [|| TagRenum $$(liftTagType tR) $$(liftTag tR t) ||] liftTag :: TagType t -> t -> CodeQ t liftTag TagBit (Bit x) = [|| Bit x ||] diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index e3acfdbcb..a873bd221 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -164,7 +164,7 @@ instance (GElt a, GElt b) => GElt (a :*: b) where instance (GElt a, GElt b, GSumElt (a :+: b)) => GElt (a :+: b) where type GEltR t (a :+: b) = (TAG, GSumEltR t (a :+: b)) geltR t = TupRpair (TupRsingle scalarType) (gsumEltR @(a :+: b) t) - gtagsR t = uncurry (TagRtag TagWord8) <$> gsumTagsR @(a :+: b) 0 t + gtagsR t = uncurry (TagRcon TagWord8) <$> gsumTagsR @(a :+: b) 0 t gfromElt = gsumFromElt 0 gtoElt (k,x) = gsumToElt k x gundef t = (0xff, gsumUndef @(a :+: b) t) @@ -290,7 +290,7 @@ instance (Elt a, Elt b) => Elt (Either a b) instance Elt Bool where type EltR Bool = Bit eltR = TupRsingle scalarType - tagsR = [TagRbit 0, TagRbit 1] + tagsR = [TagRenum TagBit 0, TagRenum TagBit 1] toElt = unBit fromElt = Bit diff --git a/src/Data/Array/Accelerate/Trafo/Sharing.hs b/src/Data/Array/Accelerate/Trafo/Sharing.hs index 0beedf2b3..d05f5ad29 100644 --- a/src/Data/Array/Accelerate/Trafo/Sharing.hs +++ b/src/Data/Array/Accelerate/Trafo/Sharing.hs @@ -833,8 +833,8 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp tagT = fromJust . go where go ::TagR a -> Maybe (Exists TagType) - go (TagRtag t _ _) = Just (Exists t) - go (TagRbit _) = Just (Exists TagBit) + go (TagRcon t _ _) = Just (Exists t) + go (TagRenum t _) = Just (Exists t) go (TagRpair a b) | Just t <- go a = Just t | otherwise = go b @@ -846,11 +846,11 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp prjT t = fromJust $$ go where go :: TagR s -> AST.OpenExp env' aenv' s -> Maybe (AST.OpenExp env' aenv' t) - go (TagRtag s _ _) (AST.Pair l _) + go (TagRcon s _ _) (AST.Pair l _) | Just Refl <- matchTagType s t = Just l - go (TagRbit _) x - | Just Refl <- matchTagType t TagBit + go (TagRenum s _) x + | Just Refl <- matchTagType s t = Just x go (TagRpair ta tb) (AST.Pair l r) | Just x <- go ta l = Just x @@ -865,7 +865,11 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp go TagRunit TagRunit = no True go TagRsingle{} TagRsingle{} = no True go TagRundef{} TagRundef{} = no True - go (TagRtag t1 v1 _) (TagRtag t2 v2 _) + go (TagRenum t1 v1) (TagRenum t2 v2) + | Just Refl <- matchTagType t1 t2 + , TagDict <- tagDict t1 + = yes (v1 == v2) + go (TagRcon t1 v1 _) (TagRcon t2 v2 _) | Just Refl <- matchTagType t1 t2 , TagDict <- tagDict t1 = yes (v1 == v2) @@ -880,11 +884,11 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp firstT t = fromJust . go where go :: TagR s -> Maybe t - go (TagRtag s v _) + go (TagRcon s v _) | Just Refl <- matchTagType s t = Just v - go (TagRbit v) - | Just Refl <- matchTagType t TagBit + go (TagRenum s v) + | Just Refl <- matchTagType s t = Just v go (TagRpair a b) | Just v <- go a = Just v @@ -900,12 +904,16 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp go TagRunit = no $ TagRunit go (TagRsingle t) = no $ TagRsingle t go (TagRundef t) = no $ TagRundef t - go (TagRbit _) = yes $ TagRsingle scalarType - go (TagRtag t _ a) = + go (TagRenum t _) = yes $ + case t of + TagBit -> TagRundef scalarType + TagWord8 -> TagRundef scalarType + TagWord16 -> TagRundef scalarType + go (TagRcon t _ a) = yes $ case t of - TagBit -> yes $ TagRpair (TagRundef scalarType) a - TagWord8 -> yes $ TagRpair (TagRundef scalarType) a - TagWord16 -> yes $ TagRpair (TagRundef scalarType) a + TagBit -> TagRpair (TagRundef scalarType) a + TagWord8 -> TagRpair (TagRundef scalarType) a + TagWord16 -> TagRpair (TagRundef scalarType) a go (TagRpair a1 a2) = let (Any r, a1') = go a1 in case r of From 0b10e22b05850107b2dea56fde4d0d7d2007a914 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Sun, 26 Jun 2022 15:20:15 +0200 Subject: [PATCH 31/86] add vectorised min, max --- src/Data/Array/Accelerate/Classes/VOrd.hs | 15 +++++++++++---- src/Data/Array/Accelerate/Classes/VOrd.hs-boot | 13 +++++++++---- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/src/Data/Array/Accelerate/Classes/VOrd.hs b/src/Data/Array/Accelerate/Classes/VOrd.hs index 30a260edc..5d609ffba 100644 --- a/src/Data/Array/Accelerate/Classes/VOrd.hs +++ b/src/Data/Array/Accelerate/Classes/VOrd.hs @@ -51,12 +51,17 @@ class VEq n a => VOrd n a where (>*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) (<=*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) (>=*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + vmin :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n a) + vmax :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n a) vcompare :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Ordering) - x <* y = select (vcompare x y ==* vlt) vtrue vfalse - x <=* y = select (vcompare x y ==* vgt) vfalse vtrue - x >* y = select (vcompare x y ==* vgt) vtrue vfalse - x >=* y = select (vcompare x y ==* vlt) vfalse vtrue + x <* y = select (vcompare x y ==* vlt) vtrue vfalse + x <=* y = select (vcompare x y ==* vgt) vfalse vtrue + x >* y = select (vcompare x y ==* vgt) vtrue vfalse + x >=* y = select (vcompare x y ==* vlt) vfalse vtrue + + vmin x y = select (x <=* y) x y + vmax x y = select (x <=* y) y x vcompare x y = select (x ==* y) veq @@ -112,6 +117,8 @@ runQ $ do (>*) = mkPrimBinary $ PrimGt scalarType (<=*) = mkPrimBinary $ PrimLtEq scalarType (>=*) = mkPrimBinary $ PrimGtEq scalarType + vmin = mkMin + vmax = mkMax |] mkTup :: Word8 -> Q Dec diff --git a/src/Data/Array/Accelerate/Classes/VOrd.hs-boot b/src/Data/Array/Accelerate/Classes/VOrd.hs-boot index b1b83feb5..083f879e2 100644 --- a/src/Data/Array/Accelerate/Classes/VOrd.hs-boot +++ b/src/Data/Array/Accelerate/Classes/VOrd.hs-boot @@ -26,12 +26,17 @@ class VEq n a => VOrd n a where (>*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) (<=*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) (>=*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + vmin :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n a) + vmax :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n a) vcompare :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Ordering) - x <* y = select (vcompare x y ==* vlt) vtrue vfalse - x <=* y = select (vcompare x y ==* vgt) vfalse vtrue - x >* y = select (vcompare x y ==* vgt) vtrue vfalse - x >=* y = select (vcompare x y ==* vlt) vfalse vtrue + x <* y = select (vcompare x y ==* vlt) vtrue vfalse + x <=* y = select (vcompare x y ==* vgt) vfalse vtrue + x >* y = select (vcompare x y ==* vgt) vtrue vfalse + x >=* y = select (vcompare x y ==* vlt) vfalse vtrue + + vmin x y = select (x <=* y) x y + vmax x y = select (x <=* y) y x vcompare x y = select (x ==* y) veq From ba107c8224d40f0b22cde456a6d450ed7af2c128 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 27 Jun 2022 15:00:50 +0200 Subject: [PATCH 32/86] value/expression polymorphic pattern synonyms for Vec --- accelerate.cabal | 1 + src/Data/Array/Accelerate.hs | 12 +- src/Data/Array/Accelerate/Data/Complex.hs | 50 ++--- src/Data/Array/Accelerate/Pattern.hs | 176 +---------------- src/Data/Array/Accelerate/Pattern/SIMD.hs | 194 +++++++++++++++++++ src/Data/Array/Accelerate/Pattern/Shape.hs | 100 ++++++---- src/Data/Array/Accelerate/Sugar/Vec.hs | 16 +- src/Data/Array/Accelerate/Test/NoFib/Base.hs | 10 +- src/Data/Primitive/Vec.hs | 108 +++++------ 9 files changed, 361 insertions(+), 306 deletions(-) create mode 100644 src/Data/Array/Accelerate/Pattern/SIMD.hs diff --git a/accelerate.cabal b/accelerate.cabal index deac8aed4..e411aafba 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -486,6 +486,7 @@ library Data.Array.Accelerate.Pattern.Either Data.Array.Accelerate.Pattern.Maybe Data.Array.Accelerate.Pattern.Ordering + Data.Array.Accelerate.Pattern.SIMD Data.Array.Accelerate.Pattern.Shape Data.Array.Accelerate.Pattern.TH Data.Array.Accelerate.Prelude diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index 958088e19..d4461f2b0 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -353,11 +353,12 @@ module Data.Array.Accelerate ( pattern I0, pattern I1, pattern I2, pattern I3, pattern I4, pattern I5, pattern I6, pattern I7, pattern I8, pattern I9, - Vec2, pattern V2, - Vec3, pattern V3, - Vec4, pattern V4, - Vec8, pattern V8, - Vec16, pattern V16, + pattern SIMD, + V2, pattern V2, + V3, pattern V3, + V4, pattern V4, + V8, pattern V8, + V16, pattern V16, mkPattern, mkPatterns, @@ -449,6 +450,7 @@ import Data.Array.Accelerate.Data.Maybe import Data.Array.Accelerate.Language import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Pattern.Shape +import Data.Array.Accelerate.Pattern.SIMD import Data.Array.Accelerate.Pattern.TH import Data.Array.Accelerate.Prelude import Data.Array.Accelerate.Pretty () -- show instances diff --git a/src/Data/Array/Accelerate/Data/Complex.hs b/src/Data/Array/Accelerate/Data/Complex.hs index ee1cee8fc..8ec644454 100644 --- a/src/Data/Array/Accelerate/Data/Complex.hs +++ b/src/Data/Array/Accelerate/Data/Complex.hs @@ -102,33 +102,33 @@ instance Elt a => Elt (Complex a) where [ TagRunit `TagRpair` ta `TagRpair` tb | ta <- go tR, tb <- go tR ] toElt = case complexR $ eltR @a of - ComplexVec _ -> \(Prim.Vec2 r i) -> toElt r :+ toElt i - ComplexTup -> \(((), r), i) -> toElt r :+ toElt i + ComplexVec _ -> \(Prim.V2 r i) -> toElt r :+ toElt i + ComplexTup -> \(((), r), i) -> toElt r :+ toElt i fromElt (r :+ i) = case complexR $ eltR @a of - ComplexVec _ -> Prim.Vec2 (fromElt r) (fromElt i) + ComplexVec _ -> Prim.V2 (fromElt r) (fromElt i) ComplexTup -> (((), fromElt r), fromElt i) type family ComplexR a where - ComplexR Half = Prim.Vec2 Float16 - ComplexR Float = Prim.Vec2 Float32 - ComplexR Double = Prim.Vec2 Float64 - ComplexR Float128 = Prim.Vec2 Float128 - ComplexR Int8 = Prim.Vec2 Int8 - ComplexR Int16 = Prim.Vec2 Int16 - ComplexR Int32 = Prim.Vec2 Int32 - ComplexR Int64 = Prim.Vec2 Int64 - ComplexR Int128 = Prim.Vec2 Int128 - ComplexR Word8 = Prim.Vec2 Word8 - ComplexR Word16 = Prim.Vec2 Word16 - ComplexR Word32 = Prim.Vec2 Word32 - ComplexR Word64 = Prim.Vec2 Word64 - ComplexR Word128 = Prim.Vec2 Word128 + ComplexR Half = Prim.V2 Float16 + ComplexR Float = Prim.V2 Float32 + ComplexR Double = Prim.V2 Float64 + ComplexR Float128 = Prim.V2 Float128 + ComplexR Int8 = Prim.V2 Int8 + ComplexR Int16 = Prim.V2 Int16 + ComplexR Int32 = Prim.V2 Int32 + ComplexR Int64 = Prim.V2 Int64 + ComplexR Int128 = Prim.V2 Int128 + ComplexR Word8 = Prim.V2 Word8 + ComplexR Word16 = Prim.V2 Word16 + ComplexR Word32 = Prim.V2 Word32 + ComplexR Word64 = Prim.V2 Word64 + ComplexR Word128 = Prim.V2 Word128 ComplexR a = (((), a), a) data ComplexType a c where - ComplexVec :: Prim a => NumType (Prim.Vec2 a) -> ComplexType a (Prim.Vec2 a) - ComplexTup :: ComplexType a (((), a), a) + ComplexVec :: Prim a => NumType (Prim.V2 a) -> ComplexType a (Prim.V2 a) + ComplexTup :: ComplexType a (((), a), a) complexR :: TypeR a -> ComplexType a (ComplexR a) complexR = tuple @@ -180,11 +180,11 @@ constructComplex r@(Exp r') i@(Exp i') = ComplexTup -> Pattern (r,i) ComplexVec t -> Exp $ num t r' i' where - num :: NumType (Prim.Vec2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) + num :: NumType (Prim.V2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) num (IntegralNumType t) = integral t num (FloatingNumType t) = floating t - integral :: IntegralType (Prim.Vec2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) + integral :: IntegralType (Prim.V2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) integral (SingleIntegralType t) = case t of integral (VectorIntegralType n t) = let v = NumScalarType (IntegralNumType (VectorIntegralType n t)) @@ -200,7 +200,7 @@ constructComplex r@(Exp r') i@(Exp i') = TypeWord64 -> pack v TypeWord128 -> pack v - floating :: FloatingType (Prim.Vec2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) + floating :: FloatingType (Prim.V2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) floating (SingleFloatingType t) = case t of floating (VectorFloatingType n t) = let v = NumScalarType (FloatingNumType (VectorFloatingType n t)) @@ -227,11 +227,11 @@ deconstructComplex (Exp c) = let (r, i) = num t c in (Exp r, Exp i) where - num :: NumType (Prim.Vec2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) + num :: NumType (Prim.V2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) num (IntegralNumType t) = integral t num (FloatingNumType t) = floating t - integral :: IntegralType (Prim.Vec2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) + integral :: IntegralType (Prim.V2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) integral (SingleIntegralType t) = case t of integral (VectorIntegralType n t) = let v = NumScalarType (IntegralNumType (VectorIntegralType n t)) @@ -247,7 +247,7 @@ deconstructComplex (Exp c) = TypeWord64 -> unpack v TypeWord128 -> unpack v - floating :: FloatingType (Prim.Vec2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) + floating :: FloatingType (Prim.V2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) floating (SingleFloatingType t) = case t of floating (VectorFloatingType n t) = let v = NumScalarType (FloatingNumType (VectorFloatingType n t)) diff --git a/src/Data/Array/Accelerate/Pattern.hs b/src/Data/Array/Accelerate/Pattern.hs index faa730af3..fa0f5f958 100644 --- a/src/Data/Array/Accelerate/Pattern.hs +++ b/src/Data/Array/Accelerate/Pattern.hs @@ -10,8 +10,6 @@ {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} -- | -- Module : Data.Array.Accelerate.Pattern @@ -30,23 +28,13 @@ module Data.Array.Accelerate.Pattern ( pattern T7, pattern T8, pattern T9, pattern T10, pattern T11, pattern T12, pattern T13, pattern T14, pattern T15, pattern T16, - pattern I0, pattern I1, pattern I2, pattern I3, pattern I4, - pattern I5, pattern I6, pattern I7, pattern I8, pattern I9, - - pattern SIMD, - pattern V2, pattern V3, pattern V4, pattern V8, pattern V16, - ) where import Data.Array.Accelerate.AST.Idx -import Data.Array.Accelerate.Pattern.Shape import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Array import Data.Array.Accelerate.Sugar.Elt --- import Data.Array.Accelerate.Sugar.Shape -import Data.Array.Accelerate.Sugar.Vec -import Data.Array.Accelerate.Type import Language.Haskell.TH.Extra hiding ( Exp, Match ) @@ -65,15 +53,6 @@ class IsPattern context a b where matcher :: context a -> b -pattern SIMD :: forall b a context. IsSIMD context a b => b -> context a -pattern SIMD vars <- (vmatcher @context -> vars) - where SIMD = vbuilder @context - -class IsSIMD context a b where - vbuilder :: b -> context a - vmatcher :: context a -> b - - -- IsPattern instances for up to 16-tuples (Acc and Exp). TH takes care of -- the (unremarkable) boilerplate for us. -- @@ -149,101 +128,10 @@ runQ $ do -> $(tupE [[| Exp (SmartExp (Match $(varE m) $(get (varE _x) i))) |] | m <- ms | i <- [(n-1), (n-2) .. 0]]) _ -> $(tupE [[| Exp $(get (varE _x) i) |] | i <- [(n-1), (n-2) .. 0]]) |] - - -- Generate instance declarations for IsSIMD of the form: - -- instance (Elt a, Elt v, EltR v ~ VecR n a) => IsSIMD Exp v (Exp a, Exp a) - mkVecPattern :: Int -> Q [Dec] - mkVecPattern n = do - a <- newName "a" - v <- newName "v" - _x <- newName "_x" - _y <- newName "_y" - let - aT = varT a - vT = varT v - nT = litT (numTyLit (toInteger n)) - -- Last argument to `IsSIMD`, eg (Exp, a, Exp a) in the example - tup = tupT (replicate n ([t| Exp $aT |])) - -- Constraints for the type class, consisting of the Elt - -- constraints and the equality on representation types - context = [t| (Elt $aT, Elt $vT, SIMD $nT $aT, EltR $vT ~ VecR $nT $aT) |] - -- Type variables for the elements - xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - -- Variables for sub-pattern matches - -- ms = [ mkName ('m' : show i) | i <- [0 .. n-1] ] - -- tags = foldl (\ts t -> [p| $ts `TagRpair` $(varP t) |]) [p| TagRunit |] ms - -- - [d| instance $context => IsSIMD Exp $vT $tup where - vbuilder $(tupP (map (\x -> [p| Exp $(varP x)|]) xs)) = - let _unmatch :: SmartExp a -> SmartExp a - _unmatch (SmartExp (Match _ $(varP _y))) = $(varE _y) - _unmatch x = x - in - Exp $(foldl (\vs (i, x) -> [| mkInsert - $(varE 'vecR `appTypeE` nT `appTypeE` aT) - $(varE 'eltR `appTypeE` aT) - TypeWord8 - $vs - (SmartExp (Const (NumScalarType (IntegralNumType (SingleIntegralType TypeWord8))) i)) - (_unmatch $(varE x)) - |]) - [| unExp (undef :: Exp (Vec $nT $aT)) |] - (zip [0 .. n-1] xs) - ) - - vmatcher (Exp $(varP _x)) = - case $(varE _x) of - -- SmartExp (Match $tags $(varP _y)) - -- -> $(tupE [[| Exp (SmartExp (Match $(varE m) (unExp (extract (Exp $(varE _x) :: Exp $vec) (constant (i :: Word8)))))) |] | m <- ms | i <- [0 .. n-1]]) - -- -> $(tupE [[| Exp (SmartExp (Match $(varE m) (mkExtract - -- $(varE 'vecR `appTypeE` nT `appTypeE` aT) - -- $(varE 'eltR `appTypeE` aT) - -- TypeWord8 - -- $(varE _x) - -- (SmartExp (Const (NumScalarType (IntegralNumType (SingleIntegralType TypeWord8))) i))))) |] - -- | m <- ms - -- | i <- [0 .. n-1] ]) - - _ -> $(tupE [[| Exp $ mkExtract - $(varE 'vecR `appTypeE` nT `appTypeE` aT) - $(varE 'eltR `appTypeE` aT) - TypeWord8 - $(varE _x) - (SmartExp (Const (NumScalarType (IntegralNumType (SingleIntegralType TypeWord8))) i)) - |] - | i <- [0 .. n-1] ]) - |] - -{-- - -- Generate instance declarations for IsVector of the form: - -- instance (Elt v, EltR v ~ Vec 2 a, Elt a) => IsVector Exp v (Exp a, Exp a) - mkVecPattern :: Int -> Q [Dec] - mkVecPattern n = do - a <- newName "a" - v <- newName "v" - let - -- Last argument to `IsVector`, eg (Exp, a, Exp a) in the example - tup = tupT (replicate n ([t| Exp $(varT a)|])) - -- Representation as a vector, eg (Vec 2 a) - vec = [t| Vec $(litT (numTyLit (fromIntegral n))) $(varT a) |] - -- Constraints for the type class, consisting of Elt constraints on all type variables, - -- and an equality constraint on the representation type of `a` and the vector representation `vec`. - context = [t| (Elt $(varT v), VecElt $(varT a), EltR $(varT v) ~ $vec) |] - -- - vecR = foldr appE ([| VecRnil |] `appE` (varE 'singleType `appTypeE` varT a)) (replicate n [| VecRsucc |]) - tR = tupT (replicate n (varT a)) - -- - [d| instance $context => IsVector Exp $(varT v) $tup where - vpack x = case builder x :: Exp $tR of - Exp x' -> Exp (SmartExp (VecPack $vecR x')) - vunpack (Exp x) = matcher (Exp (SmartExp (VecUnpack $vecR x)) :: Exp $tR) - |] ---} -- es <- mapM mkExpPattern [0..16] as <- mapM mkAccPattern [0..16] - vs <- mapM mkVecPattern [2,3,4,8,16] - return $ concat (es ++ as++ vs) + return $ concat (es ++ as) -- | Specialised pattern synonyms for tuples, which may be more convenient to @@ -260,12 +148,6 @@ runQ $ do -- -- These pattern synonyms can be used for both 'Exp' and 'Acc' terms. -- --- Similarly, we have patterns for constructing and destructing indices of --- a given dimensionality: --- --- > let ix = Ix 2 3 -- :: Exp DIM2 --- > let I2 y x = ix -- y :: Exp Int, x :: Exp Int --- runQ $ do let mkT :: Int -> Q [Dec] @@ -284,60 +166,6 @@ runQ $ do , pragCompleteD [name] (Just ''Acc) , pragCompleteD [name] (Just ''Exp) ] - - mkV :: Int -> Q [Dec] - mkV n = - let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - a = varT (mkName "a") - ts = replicate n a - name = mkName ('V':show n) - tup = tupT (map (\t -> [t| Exp $t |]) ts) - vec = [t| Vec $(litT (numTyLit (toInteger n))) $a |] - cst = [t| (Elt $a, SIMD $(litT (numTyLit (toInteger n))) $a, IsSIMD Exp $vec $tup) |] - sig = foldr (\t r -> [t| Exp $t -> $r |]) [t| Exp $vec |] ts - in - sequence - [ patSynSigD name [t| $cst => $sig |] - , patSynD name (prefixPatSyn xs) implBidir [p| SIMD $(tupP (map varP xs)) |] - , pragCompleteD [name] Nothing - ] - - mkI :: Int -> Q [Dec] - mkI n = - let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - ts = map varT xs - name = mkName ('I':show n) - ix = mkName ":." - cst = tupT (map (\t -> [t| Elt $t |]) ts) - dim = foldl (\h t -> [t| $h :. $t |]) [t| Z |] ts - sig = foldr (\t r -> [t| Exp $t -> $r |]) [t| Exp $dim |] ts - in - sequence - [ patSynSigD name [t| $cst => $sig |] - , patSynD name (prefixPatSyn xs) implBidir (foldl (\ps p -> infixP ps ix (varP p)) [p| Z |] xs) - , pragCompleteD [name] Nothing - ] - -{-- - mkV :: Int -> Q [Dec] - mkV n = - let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - ts = map varT xs - name = mkName ('V':show n) - con = varT (mkName "con") - ty1 = varT (mkName "vec") - ty2 = tupT (map (con `appT`) ts) - sig = foldr (\t r -> [t| $con $t -> $r |]) (appT con ty1) ts - in - sequence - [ patSynSigD name [t| IsVector $con $ty1 $ty2 => $sig |] - , patSynD name (prefixPatSyn xs) implBidir [p| Vector $(tupP (map varP xs)) |] - , pragCompleteD [name] (Just ''Exp) - ] ---} -- - ts <- mapM mkT [2..16] - is <- mapM mkI [0..9] - vs <- mapM mkV [2,3,4,8,16] - return $ concat (ts ++ is ++ vs) + concat <$> mapM mkT [2..16] diff --git a/src/Data/Array/Accelerate/Pattern/SIMD.hs b/src/Data/Array/Accelerate/Pattern/SIMD.hs new file mode 100644 index 000000000..a8ae3e4fb --- /dev/null +++ b/src/Data/Array/Accelerate/Pattern/SIMD.hs @@ -0,0 +1,194 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE ViewPatterns #-} +-- | +-- Module : Data.Array.Accelerate.Pattern +-- Copyright : [2018..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Pattern.SIMD ( + + pattern SIMD, + pattern V2, pattern V3, pattern V4, pattern V8, pattern V16, + +) where + +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Vec +import Data.Array.Accelerate.Type + +import Language.Haskell.TH.Extra hiding ( Exp, Match ) +import GHC.Exts ( IsList(..) ) + + +pattern SIMD :: forall b a context. IsSIMD context a b => b -> context a +pattern SIMD vars <- (vmatcher @context -> vars) + where SIMD = vbuilder @context +{-# COMPLETE SIMD #-} + +class IsSIMD context a b where + vbuilder :: b -> context a + vmatcher :: context a -> b + + +runQ $ + let + -- Generate instance declarations for IsSIMD of the form: + -- instance (Elt a, Elt v, EltR v ~ VecR n a) => IsSIMD Exp v (Exp a, Exp a) + mkVecPattern :: Int -> Q [Dec] + mkVecPattern n = do + a <- newName "a" + v <- newName "v" + _x <- newName "_x" + _y <- newName "_y" + let + aT = varT a + vT = varT v + nT = litT (numTyLit (toInteger n)) + -- Last argument to `IsSIMD`, eg (Exp, a, Exp a) in the example + tup = tupT (replicate n ([t| Exp $aT |])) + -- Constraints for the type class, consisting of the Elt + -- constraints and the equality on representation types + context = [t| (Elt $aT, Elt $vT, SIMD $nT $aT, EltR $vT ~ VecR $nT $aT) |] + -- Type variables for the elements + xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + -- Variables for sub-pattern matches + -- ms = [ mkName ('m' : show i) | i <- [0 .. n-1] ] + -- tags = foldl (\ts t -> [p| $ts `TagRpair` $(varP t) |]) [p| TagRunit |] ms + -- + [d| instance $context => IsSIMD Exp $vT $tup where + vbuilder $(tupP (map (\x -> [p| Exp $(varP x)|]) xs)) = + let _unmatch :: SmartExp a -> SmartExp a + _unmatch (SmartExp (Match _ $(varP _y))) = $(varE _y) + _unmatch x = x + in + Exp $(foldl (\vs (i, x) -> [| mkInsert + $(varE 'vecR `appTypeE` nT `appTypeE` aT) + $(varE 'eltR `appTypeE` aT) + TypeWord8 + $vs + (SmartExp (Const (NumScalarType (IntegralNumType (SingleIntegralType TypeWord8))) i)) + (_unmatch $(varE x)) + |]) + [| unExp (undef :: Exp (Vec $nT $aT)) |] + (zip [0 .. n-1] xs) + ) + + vmatcher (Exp $(varP _x)) = + case $(varE _x) of + -- SmartExp (Match $tags $(varP _y)) + -- -> $(tupE [[| Exp (SmartExp (Match $(varE m) (unExp (extract (Exp $(varE _x) :: Exp $vec) (constant (i :: Word8)))))) |] | m <- ms | i <- [0 .. n-1]]) + -- -> $(tupE [[| Exp (SmartExp (Match $(varE m) (mkExtract + -- $(varE 'vecR `appTypeE` nT `appTypeE` aT) + -- $(varE 'eltR `appTypeE` aT) + -- TypeWord8 + -- $(varE _x) + -- (SmartExp (Const (NumScalarType (IntegralNumType (SingleIntegralType TypeWord8))) i))))) |] + -- | m <- ms + -- | i <- [0 .. n-1] ]) + + _ -> $(tupE [[| Exp $ mkExtract + $(varE 'vecR `appTypeE` nT `appTypeE` aT) + $(varE 'eltR `appTypeE` aT) + TypeWord8 + $(varE _x) + (SmartExp (Const (NumScalarType (IntegralNumType (SingleIntegralType TypeWord8))) i)) + |] + | i <- [0 .. n-1] ]) + |] + in + concat <$> mapM mkVecPattern [2,3,4,8,16] + + -- mkV :: Int -> Q [Dec] + -- mkV n = + -- let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + -- a = varT (mkName "a") + -- ts = replicate n a + -- name = mkName ('V':show n) + -- tup = tupT (map (\t -> [t| Exp $t |]) ts) + -- vec = [t| Vec $(litT (numTyLit (toInteger n))) $a |] + -- cst = [t| (Elt $a, SIMD $(litT (numTyLit (toInteger n))) $a, IsSIMD Exp $vec $tup) |] + -- sig = foldr (\t r -> [t| Exp $t -> $r |]) [t| Exp $vec |] ts + -- in + -- sequence + -- [ patSynSigD name [t| $cst => $sig |] + -- , patSynD name (prefixPatSyn xs) implBidir [p| SIMD $(tupP (map varP xs)) |] + -- , pragCompleteD [name] Nothing + -- ] + + -- mkV :: Int -> Q [Dec] + -- mkV n = + -- let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + -- ts = map varT xs + -- name = mkName ('V':show n) + -- con = varT (mkName "con") + -- ty1 = varT (mkName "vec") + -- ty2 = tupT (map (con `appT`) ts) + -- sig = foldr (\t r -> [t| $con $t -> $r |]) (appT con ty1) ts + -- in + -- sequence + -- [ patSynSigD name [t| IsVector $con $ty1 $ty2 => $sig |] + -- , patSynD name (prefixPatSyn xs) implBidir [p| Vector $(tupP (map varP xs)) |] + -- , pragCompleteD [name] (Just ''Exp) + -- ] + +-- Generate polymorphic pattern synonyms which operate on both Haskell values +-- as well as embedded expressions +-- +runQ $ + let + mkV :: Int -> Q [Dec] + mkV n = do + a <- newName "a" + v <- newName "v" + let + as = replicate n (varT a) + xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + xsP = map varP xs + xsE = map varE xs + vn = mkName ("V" ++ show n) + isV = mkName ("IsV" ++ show n) + builder = mkName ("buildV" ++ show n) + matcher = mkName ("matchV" ++ show n) + ctx = return [ ConT ''Elt `AppT` VarT a + , ConT ''SIMD `AppT` LitT (NumTyLit (toInteger n)) `AppT` VarT a + ] + -- + sequence + [ patSynSigD vn [t| $(conT isV) $(varT a) $(varT v) => $(foldr (\t r -> [t| $t -> $r |]) (varT v) as) |] + , patSynD vn (prefixPatSyn xs) (explBidir [clause [] (normalB (varE builder)) []]) (parensP $ viewP (varE matcher) (tupP xsP)) + , pragCompleteD [vn] Nothing + -- + , classD (return []) isV [PlainTV a (), PlainTV v ()] [funDep [v] [a]] + [ sigD builder (foldr (\t r -> [t| $t -> $r |]) (varT v) as) + , sigD matcher [t| $(varT v) -> $(tupT as) |] + ] + -- This instance which goes via toList is horrible and I feel bad for using it + -- TLM 2022-06-27 + , instanceD ctx [t| $(conT isV) $(varT a) ($(conT vn) $(varT a)) |] + [ funD builder [ clause xsP (normalB [| fromList $(listE xsE) |]) []] + , funD matcher [ clause [viewP (varE 'toList) (listP xsP)] (normalB (tupE xsE)) [] ] + ] + , instanceD ctx [t| $(conT isV) (Exp $(varT a)) (Exp ($(conT vn) $(varT a))) |] + [ funD builder [ clause xsP (normalB [| SIMD $(tupE xsE) |]) []] + , funD matcher [ clause [conP (mkName "SIMD") [tupP xsP]] (normalB (tupE xsE)) [] ] + ] + ] + in + concat <$> mapM mkV [2,3,4,8,16] + diff --git a/src/Data/Array/Accelerate/Pattern/Shape.hs b/src/Data/Array/Accelerate/Pattern/Shape.hs index e1660f40f..8f4c1beee 100644 --- a/src/Data/Array/Accelerate/Pattern/Shape.hs +++ b/src/Data/Array/Accelerate/Pattern/Shape.hs @@ -1,8 +1,10 @@ +{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} -- | @@ -22,6 +24,9 @@ module Data.Array.Accelerate.Pattern.Shape ( pattern All, All, pattern Any, Any, + pattern I0, pattern I1, pattern I2, pattern I3, pattern I4, + pattern I5, pattern I6, pattern I7, pattern I8, pattern I9, + ) where import Data.Array.Accelerate.AST.Idx @@ -30,78 +35,107 @@ import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Shape ( (:.), Z, All, Any ) import qualified Data.Array.Accelerate.Sugar.Shape as Sugar +import Language.Haskell.TH.Extra hiding ( Exp, Match ) + pattern Z :: IsShapeZ z => z -pattern Z <- (z_matcher -> True) - where Z = z_builder +pattern Z <- (matchZ -> True) + where Z = buildZ {-# COMPLETE Z :: Z #-} {-# COMPLETE Z :: Exp #-} pattern All :: IsShapeAll all => all -pattern All <- (all_matcher -> True) - where All = all_builder +pattern All <- (matchAll -> True) + where All = buildAll {-# COMPLETE All :: All #-} {-# COMPLETE All :: Exp #-} pattern Any :: IsShapeAny any => any -pattern Any <- (any_matcher -> True) - where Any = any_builder +pattern Any <- (matchAny -> True) + where Any = buildAny {-# COMPLETE Any :: Any #-} {-# COMPLETE Any :: Exp #-} infixl 3 :. pattern (:.) :: IsShapeSnoc t h s => t -> h -> s -pattern t :. h <- (snoc_matcher -> (t Sugar.:. h)) - where t :. h = snoc_builder t h +pattern t :. h <- (matchSnoc -> (t Sugar.:. h)) + where t :. h = buildSnoc t h {-# COMPLETE (:.) :: (:.) #-} {-# COMPLETE (:.) :: Exp #-} class IsShapeZ z where - z_matcher :: z -> Bool - z_builder :: z + matchZ :: z -> Bool + buildZ :: z instance IsShapeZ Z where - z_matcher _ = True - z_builder = Sugar.Z + matchZ _ = True + buildZ = Sugar.Z instance IsShapeZ (Exp Z) where - z_matcher _ = True - z_builder = constant Sugar.Z + matchZ _ = True + buildZ = constant Sugar.Z class IsShapeAll all where - all_matcher :: all -> Bool - all_builder :: all + matchAll :: all -> Bool + buildAll :: all instance IsShapeAll All where - all_matcher _ = True - all_builder = Sugar.All + matchAll _ = True + buildAll = Sugar.All instance IsShapeAll (Exp All) where - all_matcher _ = True - all_builder = constant Sugar.All + matchAll _ = True + buildAll = constant Sugar.All class IsShapeAny any where - any_matcher :: any -> Bool - any_builder :: any + matchAny :: any -> Bool + buildAny :: any instance IsShapeAny (Any sh) where - any_matcher _ = True - any_builder = Sugar.Any + matchAny _ = True + buildAny = Sugar.Any instance Elt (Any sh) => IsShapeAny (Exp (Any sh)) where - any_matcher _ = True - any_builder = constant Sugar.Any + matchAny _ = True + buildAny = constant Sugar.Any -class IsShapeSnoc t h s | s -> t, s -> h where - snoc_matcher :: s -> (t :. h) - snoc_builder :: t -> h -> s +class IsShapeSnoc t h s | s -> h t where + matchSnoc :: s -> (t :. h) + buildSnoc :: t -> h -> s instance IsShapeSnoc (Exp t) (Exp h) (Exp (t :. h)) where - snoc_builder (Exp a) (Exp b) = Exp $ SmartExp $ Pair a b - snoc_matcher (Exp t) = Exp (SmartExp $ Prj PairIdxLeft t) Sugar.:. Exp (SmartExp $ Prj PairIdxRight t) + buildSnoc (Exp a) (Exp b) = Exp $ SmartExp $ Pair a b + matchSnoc (Exp t) = Exp (SmartExp $ Prj PairIdxLeft t) Sugar.:. Exp (SmartExp $ Prj PairIdxRight t) instance IsShapeSnoc t h (t :. h) where - snoc_builder = (Sugar.:.) - snoc_matcher = id + buildSnoc = (Sugar.:.) + matchSnoc = id + + +-- Generate patterns for constructing and destructing indices of a given +-- dimensionality: +-- +-- > let ix = Ix 2 3 -- :: Exp DIM2 +-- > let I2 y x = ix -- y :: Exp Int, x :: Exp Int +-- +runQ $ + let + mkI :: Int -> Q [Dec] + mkI n = + let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + ts = map varT xs + name = mkName ('I':show n) + ix = mkName ":." + cst = tupT (map (\t -> [t| Elt $t |]) ts) + dim = foldl (\h t -> [t| $h :. $t |]) [t| Z |] ts + sig = foldr (\t r -> [t| Exp $t -> $r |]) [t| Exp $dim |] ts + in + sequence + [ patSynSigD name [t| $cst => $sig |] + , patSynD name (prefixPatSyn xs) implBidir (foldl (\ps p -> infixP ps ix (varP p)) [p| Z |] xs) + , pragCompleteD [name] Nothing + ] + in + concat <$> mapM mkI [0..9] diff --git a/src/Data/Array/Accelerate/Sugar/Vec.hs b/src/Data/Array/Accelerate/Sugar/Vec.hs index 16066ab67..77f0232be 100644 --- a/src/Data/Array/Accelerate/Sugar/Vec.hs +++ b/src/Data/Array/Accelerate/Sugar/Vec.hs @@ -29,11 +29,7 @@ module Data.Array.Accelerate.Sugar.Vec ( Vec(..), KnownNat, - Vec2, - Vec3, - Vec4, - Vec8, - Vec16, + V2, V3, V4, V8, V16, SIMD(..), ) where @@ -61,11 +57,11 @@ data Vec n a = Vec (VecR n a) -- Synonyms for common vector sizes -- -type Vec2 = Vec 2 -type Vec3 = Vec 3 -type Vec4 = Vec 4 -type Vec8 = Vec 8 -type Vec16 = Vec 16 +type V2 = Vec 2 +type V3 = Vec 3 +type V4 = Vec 4 +type V8 = Vec 8 +type V16 = Vec 16 instance (Show a, Elt a, SIMD n a) => Show (Vec n a) where show = vec . toList diff --git a/src/Data/Array/Accelerate/Test/NoFib/Base.hs b/src/Data/Array/Accelerate/Test/NoFib/Base.hs index 2475cea1a..0b39afb9e 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Base.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Base.hs @@ -96,19 +96,19 @@ f32 = Gen.float (Range.linearFracFrom 0 (-log_flt_max) log_flt_max) f64 :: Gen Double f64 = Gen.double (Range.linearFracFrom 0 (-log_flt_max) log_flt_max) -v2 :: (Elt a, SIMD 2 a) => Gen a -> Gen (Vec2 a) +v2 :: (Elt a, SIMD 2 a) => Gen a -> Gen (V2 a) v2 a = GHC.fromList <$> replicateM 2 a -v3 :: (Elt a, SIMD 3 a) => Gen a -> Gen (Vec3 a) +v3 :: (Elt a, SIMD 3 a) => Gen a -> Gen (V3 a) v3 a = GHC.fromList <$> replicateM 3 a -v4 :: (Elt a, SIMD 4 a) => Gen a -> Gen (Vec4 a) +v4 :: (Elt a, SIMD 4 a) => Gen a -> Gen (V4 a) v4 a = GHC.fromList <$> replicateM 4 a -v8 :: (Elt a, SIMD 8 a) => Gen a -> Gen (Vec8 a) +v8 :: (Elt a, SIMD 8 a) => Gen a -> Gen (V8 a) v8 a = GHC.fromList <$> replicateM 8 a -v16 :: (Elt a, SIMD 16 a) => Gen a -> Gen (Vec16 a) +v16 :: (Elt a, SIMD 16 a) => Gen a -> Gen (V16 a) v16 a = GHC.fromList <$> replicateM 16 a diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index f7d26c834..851a7dbbd 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -28,11 +28,11 @@ module Data.Primitive.Vec ( -- * SIMD vector types Vec(..), KnownNat, - Vec2, pattern Vec2, - Vec3, pattern Vec3, - Vec4, pattern Vec4, - Vec8, pattern Vec8, - Vec16, pattern Vec16, + V2, pattern V2, + V3, pattern V3, + V4, pattern V4, + V8, pattern V8, + V16, pattern V16, toList, fromList, extract, insert, splat, @@ -68,25 +68,25 @@ import GHC.Ptr -- -- A simple polymorphic representation of SIMD types such as the following: -- --- > data Vec2 a = Vec2 !a !a +-- > data V2 a = V2 !a !a -- -- is not able to unpack the values into the constructor, meaning that --- 'Vec2' is storing pointers to (strict) values on the heap, which is +-- 'V2' is storing pointers to (strict) values on the heap, which is -- a very inefficient representation. -- -- We might try defining a data family instead so that we can get efficient -- unboxed representations, and even make use of the unlifted SIMD types GHC -- knows about: -- --- > data family Vec2 a :: * --- > data instance Vec2 Float = Vec2_Float Float# Float# -- reasonable --- > data instance Vec2 Double = Vec2_Double DoubleX2# -- built in! +-- > data family V2 a :: * +-- > data instance V2 Float = V2_Float Float# Float# -- reasonable +-- > data instance V2 Double = V2_Double DoubleX2# -- built in! -- -- However, this runs into the problem that GHC stores all values as word sized -- entities: -- --- > data instance Vec2 Int = Vec2_Int Int# Int# --- > data instance Vec2 Int8 = Vec2_Int8 Int8# Int8# -- Int8# does not exist; requires a full Int# +-- > data instance V2 Int = V2_Int Int# Int# +-- > data instance V2 Int8 = V2_Int8 Int8# Int8# -- Int8# does not exist; requires a full Int# -- -- which, again, is very memory inefficient. -- @@ -198,54 +198,54 @@ splat x = runST $ do -- -- Note that non-power-of-two sized SIMD vectors are a bit dubious, and -- special care must be taken in the code generator. For example, LLVM will --- treat a Vec3 with alignment of _4_, meaning that reads and writes will +-- treat a V3 with alignment of _4_, meaning that reads and writes will -- be (without further action) incorrect. -- -type Vec2 a = Vec 2 a -type Vec3 a = Vec 3 a -type Vec4 a = Vec 4 a -type Vec8 a = Vec 8 a -type Vec16 a = Vec 16 a - -pattern Vec2 :: Prim a => a -> a -> Vec2 a -pattern Vec2 a b <- (unpackVec2 -> (a,b)) - where Vec2 = packVec2 -{-# COMPLETE Vec2 #-} - -pattern Vec3 :: Prim a => a -> a -> a -> Vec3 a -pattern Vec3 a b c <- (unpackVec3 -> (a,b,c)) - where Vec3 = packVec3 -{-# COMPLETE Vec3 #-} - -pattern Vec4 :: Prim a => a -> a -> a -> a -> Vec4 a -pattern Vec4 a b c d <- (unpackVec4 -> (a,b,c,d)) - where Vec4 = packVec4 -{-# COMPLETE Vec4 #-} - -pattern Vec8 :: Prim a => a -> a -> a -> a -> a -> a -> a -> a -> Vec8 a -pattern Vec8 a b c d e f g h <- (unpackVec8 -> (a,b,c,d,e,f,g,h)) - where Vec8 = packVec8 -{-# COMPLETE Vec8 #-} - -pattern Vec16 :: Prim a => a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> Vec16 a -pattern Vec16 a b c d e f g h i j k l m n o p <- (unpackVec16 -> (a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p)) - where Vec16 = packVec16 -{-# COMPLETE Vec16 #-} - -unpackVec2 :: Prim a => Vec2 a -> (a,a) +type V2 a = Vec 2 a +type V3 a = Vec 3 a +type V4 a = Vec 4 a +type V8 a = Vec 8 a +type V16 a = Vec 16 a + +pattern V2 :: Prim a => a -> a -> V2 a +pattern V2 a b <- (unpackVec2 -> (a,b)) + where V2 = packVec2 +{-# COMPLETE V2 #-} + +pattern V3 :: Prim a => a -> a -> a -> V3 a +pattern V3 a b c <- (unpackVec3 -> (a,b,c)) + where V3 = packVec3 +{-# COMPLETE V3 #-} + +pattern V4 :: Prim a => a -> a -> a -> a -> V4 a +pattern V4 a b c d <- (unpackVec4 -> (a,b,c,d)) + where V4 = packVec4 +{-# COMPLETE V4 #-} + +pattern V8 :: Prim a => a -> a -> a -> a -> a -> a -> a -> a -> V8 a +pattern V8 a b c d e f g h <- (unpackVec8 -> (a,b,c,d,e,f,g,h)) + where V8 = packVec8 +{-# COMPLETE V8 #-} + +pattern V16 :: Prim a => a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> V16 a +pattern V16 a b c d e f g h i j k l m n o p <- (unpackVec16 -> (a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p)) + where V16 = packVec16 +{-# COMPLETE V16 #-} + +unpackVec2 :: Prim a => V2 a -> (a,a) unpackVec2 (Vec ba#) = ( indexByteArray# ba# 0# , indexByteArray# ba# 1# ) -unpackVec3 :: Prim a => Vec3 a -> (a,a,a) +unpackVec3 :: Prim a => V3 a -> (a,a,a) unpackVec3 (Vec ba#) = ( indexByteArray# ba# 0# , indexByteArray# ba# 1# , indexByteArray# ba# 2# ) -unpackVec4 :: Prim a => Vec4 a -> (a,a,a,a) +unpackVec4 :: Prim a => V4 a -> (a,a,a,a) unpackVec4 (Vec ba#) = ( indexByteArray# ba# 0# , indexByteArray# ba# 1# @@ -253,7 +253,7 @@ unpackVec4 (Vec ba#) = , indexByteArray# ba# 3# ) -unpackVec8 :: Prim a => Vec8 a -> (a,a,a,a,a,a,a,a) +unpackVec8 :: Prim a => V8 a -> (a,a,a,a,a,a,a,a) unpackVec8 (Vec ba#) = ( indexByteArray# ba# 0# , indexByteArray# ba# 1# @@ -265,7 +265,7 @@ unpackVec8 (Vec ba#) = , indexByteArray# ba# 7# ) -unpackVec16 :: Prim a => Vec16 a -> (a,a,a,a,a,a,a,a,a,a,a,a,a,a,a,a) +unpackVec16 :: Prim a => V16 a -> (a,a,a,a,a,a,a,a,a,a,a,a,a,a,a,a) unpackVec16 (Vec ba#) = ( indexByteArray# ba# 0# , indexByteArray# ba# 1# @@ -285,7 +285,7 @@ unpackVec16 (Vec ba#) = , indexByteArray# ba# 15# ) -packVec2 :: Prim a => a -> a -> Vec2 a +packVec2 :: Prim a => a -> a -> V2 a packVec2 a b = runST $ do mba <- newByteArray (2 * sizeOf a) writeByteArray mba 0 a @@ -293,7 +293,7 @@ packVec2 a b = runST $ do ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# -packVec3 :: Prim a => a -> a -> a -> Vec3 a +packVec3 :: Prim a => a -> a -> a -> V3 a packVec3 a b c = runST $ do mba <- newByteArray (3 * sizeOf a) writeByteArray mba 0 a @@ -302,7 +302,7 @@ packVec3 a b c = runST $ do ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# -packVec4 :: Prim a => a -> a -> a -> a -> Vec4 a +packVec4 :: Prim a => a -> a -> a -> a -> V4 a packVec4 a b c d = runST $ do mba <- newByteArray (4 * sizeOf a) writeByteArray mba 0 a @@ -312,7 +312,7 @@ packVec4 a b c d = runST $ do ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# -packVec8 :: Prim a => a -> a -> a -> a -> a -> a -> a -> a -> Vec8 a +packVec8 :: Prim a => a -> a -> a -> a -> a -> a -> a -> a -> V8 a packVec8 a b c d e f g h = runST $ do mba <- newByteArray (8 * sizeOf a) writeByteArray mba 0 a @@ -326,7 +326,7 @@ packVec8 a b c d e f g h = runST $ do ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# -packVec16 :: Prim a => a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> Vec16 a +packVec16 :: Prim a => a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> V16 a packVec16 a b c d e f g h i j k l m n o p = runST $ do mba <- newByteArray (16 * sizeOf a) writeByteArray mba 0 a From 0c2b6c23a95fc4da7fde16d196f8f44caaee6f38 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 27 Jun 2022 18:31:51 +0200 Subject: [PATCH 33/86] value/expression polymorphic pattern synonyms for tuples --- accelerate.cabal | 1 + src/Data/Array/Accelerate.hs | 3 +- src/Data/Array/Accelerate/Classes/Bounded.hs | 2 +- src/Data/Array/Accelerate/Classes/Eq.hs | 2 +- src/Data/Array/Accelerate/Classes/Ord.hs | 2 +- src/Data/Array/Accelerate/Classes/Rational.hs | 2 +- .../Array/Accelerate/Classes/RealFloat.hs | 12 +- src/Data/Array/Accelerate/Classes/RealFrac.hs | 2 +- src/Data/Array/Accelerate/Data/Complex.hs | 3 +- src/Data/Array/Accelerate/Data/Monoid.hs | 1 + src/Data/Array/Accelerate/Data/Ratio.hs | 1 + src/Data/Array/Accelerate/Language.hs | 10 +- src/Data/Array/Accelerate/Pattern.hs | 187 +++++++----------- src/Data/Array/Accelerate/Pattern/SIMD.hs | 53 +---- src/Data/Array/Accelerate/Pattern/Tuple.hs | 86 ++++++++ src/Data/Array/Accelerate/Prelude.hs | 2 +- 16 files changed, 194 insertions(+), 175 deletions(-) create mode 100644 src/Data/Array/Accelerate/Pattern/Tuple.hs diff --git a/accelerate.cabal b/accelerate.cabal index e411aafba..3edf5cbc3 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -489,6 +489,7 @@ library Data.Array.Accelerate.Pattern.SIMD Data.Array.Accelerate.Pattern.Shape Data.Array.Accelerate.Pattern.TH + Data.Array.Accelerate.Pattern.Tuple Data.Array.Accelerate.Prelude Data.Array.Accelerate.Pretty.Graphviz Data.Array.Accelerate.Pretty.Graphviz.Monad diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index d4461f2b0..b1c4f810c 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -449,8 +449,9 @@ import Data.Array.Accelerate.Data.Either import Data.Array.Accelerate.Data.Maybe import Data.Array.Accelerate.Language import Data.Array.Accelerate.Pattern -import Data.Array.Accelerate.Pattern.Shape import Data.Array.Accelerate.Pattern.SIMD +import Data.Array.Accelerate.Pattern.Shape +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Pattern.TH import Data.Array.Accelerate.Prelude import Data.Array.Accelerate.Pretty () -- show instances diff --git a/src/Data/Array/Accelerate/Classes/Bounded.hs b/src/Data/Array/Accelerate/Classes/Bounded.hs index 4e495e7d1..2229556ac 100644 --- a/src/Data/Array/Accelerate/Classes/Bounded.hs +++ b/src/Data/Array/Accelerate/Classes/Bounded.hs @@ -21,7 +21,7 @@ module Data.Array.Accelerate.Classes.Bounded ( ) where -import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Vec diff --git a/src/Data/Array/Accelerate/Classes/Eq.hs b/src/Data/Array/Accelerate/Classes/Eq.hs index e6d8eb0d9..0ffbc0738 100644 --- a/src/Data/Array/Accelerate/Classes/Eq.hs +++ b/src/Data/Array/Accelerate/Classes/Eq.hs @@ -35,8 +35,8 @@ module Data.Array.Accelerate.Classes.Eq ( import Data.Array.Accelerate.AST ( PrimFun(..), BitOrMask ) import Data.Array.Accelerate.Error -import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Pattern.Bool +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt diff --git a/src/Data/Array/Accelerate/Classes/Ord.hs b/src/Data/Array/Accelerate/Classes/Ord.hs index c0c3581e5..4749b6a86 100644 --- a/src/Data/Array/Accelerate/Classes/Ord.hs +++ b/src/Data/Array/Accelerate/Classes/Ord.hs @@ -34,8 +34,8 @@ import Data.Array.Accelerate.AST ( PrimFun(.. import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Classes.Eq import Data.Array.Accelerate.Error -import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Pattern.Ordering +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt diff --git a/src/Data/Array/Accelerate/Classes/Rational.hs b/src/Data/Array/Accelerate/Classes/Rational.hs index 0a65805c8..bed30c67d 100644 --- a/src/Data/Array/Accelerate/Classes/Rational.hs +++ b/src/Data/Array/Accelerate/Classes/Rational.hs @@ -21,7 +21,7 @@ import Data.Array.Accelerate.Data.Ratio import Data.Array.Accelerate.Data.Bits import Data.Array.Accelerate.Language -import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Type diff --git a/src/Data/Array/Accelerate/Classes/RealFloat.hs b/src/Data/Array/Accelerate/Classes/RealFloat.hs index 4d11059c0..31ebde034 100644 --- a/src/Data/Array/Accelerate/Classes/RealFloat.hs +++ b/src/Data/Array/Accelerate/Classes/RealFloat.hs @@ -29,7 +29,7 @@ module Data.Array.Accelerate.Classes.RealFloat ( import Data.Array.Accelerate.Error import Data.Array.Accelerate.Language ( (^), cond, while ) -import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Type @@ -377,9 +377,9 @@ ieee754_f16_decode i = T2 high3 exp3 = cond (exp1 /= _HMINEXP) -- don't add hidden bit to denorms - (T2 (high2 .|. _HHIGHBIT) exp1) + (T2 (high2 .|. _HHIGHBIT) exp1 :: Exp (Int16, Int)) -- a denorm, normalise the mantissa - (while (\(T2 h _) -> (h .&. _HHIGHBIT) /= 0 ) + (while (\(T2 h _) -> (h .&. _HHIGHBIT) /= 0) (\(T2 h e) -> T2 (h `unsafeShiftL` 1) (e-1)) (T2 high2 exp2)) @@ -411,9 +411,9 @@ ieee754_f32_decode i = T2 high3 exp3 = cond (exp1 /= _FMINEXP) -- don't add hidden bit to denorms - (T2 (high2 .|. _FHIGHBIT) exp1) + (T2 (high2 .|. _FHIGHBIT) exp1 :: Exp (Int32, Int)) -- a denorm, normalise the mantissa - (while (\(T2 h _) -> (h .&. _FHIGHBIT) /= 0 ) + (while (\(T2 h _) -> (h .&. _FHIGHBIT) /= 0) (\(T2 h e) -> T2 (h `unsafeShiftL` 1) (e-1)) (T2 high2 exp2)) @@ -450,7 +450,7 @@ ieee754_f64_decode2 i = T3 hi lo ie = cond (iexp2 /= _DMINEXP) -- don't add hidden bit to denorms - (T3 (high2 .|. _DHIGHBIT) low iexp) + (T3 (high2 .|. _DHIGHBIT) low iexp :: Exp (Word32, Word32, Int)) -- a denorm, nermalise the mantissa (while (\(T3 h _ _) -> (h .&. _DHIGHBIT) /= 0) (\(T3 h l e) -> diff --git a/src/Data/Array/Accelerate/Classes/RealFrac.hs b/src/Data/Array/Accelerate/Classes/RealFrac.hs index b3701a5af..d13e3c570 100644 --- a/src/Data/Array/Accelerate/Classes/RealFrac.hs +++ b/src/Data/Array/Accelerate/Classes/RealFrac.hs @@ -25,7 +25,7 @@ module Data.Array.Accelerate.Classes.RealFrac ( ) where import Data.Array.Accelerate.Language ( cond, even ) -import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt diff --git a/src/Data/Array/Accelerate/Data/Complex.hs b/src/Data/Array/Accelerate/Data/Complex.hs index 8ec644454..500b022f0 100644 --- a/src/Data/Array/Accelerate/Data/Complex.hs +++ b/src/Data/Array/Accelerate/Data/Complex.hs @@ -55,6 +55,7 @@ import Data.Array.Accelerate.Classes.Ord import Data.Array.Accelerate.Classes.RealFloat import Data.Array.Accelerate.Data.Functor import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Prelude import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type @@ -318,7 +319,7 @@ instance RealFloat a => P.Floating (Exp (Complex a)) where then 0 else u ::+ (y < 0 ? (-v, v)) where - T2 u v = x < 0 ? (T2 v' u', T2 u' v') + T2 u v = x < 0 ? (T2 v' u', T2 u' v') :: Exp (a,a) v' = abs y / (u'*2) u' = sqrt ((magnitude z + abs x) / 2) diff --git a/src/Data/Array/Accelerate/Data/Monoid.hs b/src/Data/Array/Accelerate/Data/Monoid.hs index 23576be77..4cb6c9b10 100644 --- a/src/Data/Array/Accelerate/Data/Monoid.hs +++ b/src/Data/Array/Accelerate/Data/Monoid.hs @@ -43,6 +43,7 @@ import Data.Array.Accelerate.Data.Semigroup () import Data.Array.Accelerate.Language import Data.Array.Accelerate.Lift import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Type diff --git a/src/Data/Array/Accelerate/Data/Ratio.hs b/src/Data/Array/Accelerate/Data/Ratio.hs index d54d5b960..22ce4af26 100644 --- a/src/Data/Array/Accelerate/Data/Ratio.hs +++ b/src/Data/Array/Accelerate/Data/Ratio.hs @@ -32,6 +32,7 @@ module Data.Array.Accelerate.Data.Ratio ( import Data.Array.Accelerate.Language import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Prelude import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt diff --git a/src/Data/Array/Accelerate/Language.hs b/src/Data/Array/Accelerate/Language.hs index ad3f27284..bfbb6bb77 100644 --- a/src/Data/Array/Accelerate/Language.hs +++ b/src/Data/Array/Accelerate/Language.hs @@ -88,7 +88,7 @@ module Data.Array.Accelerate.Language ( ) where import Data.Array.Accelerate.AST ( PrimFun(..) ) -import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Representation.Array ( ArrayR(..) ) import Data.Array.Accelerate.Representation.Shape ( ShapeR(..) ) import Data.Array.Accelerate.Representation.Type @@ -1414,11 +1414,11 @@ odd n = n `rem` 2 /= 0 gcd :: Integral a => Exp a -> Exp a -> Exp a gcd x y = gcd' (abs x) (abs y) where - gcd' :: Integral a => Exp a -> Exp a -> Exp a + gcd' :: forall a. Integral a => Exp a -> Exp a -> Exp a gcd' u v = let T2 r _ = while (\(T2 _ b) -> b /= 0) (\(T2 a b) -> T2 b (a `rem` b)) - (T2 u v) + (T2 u v) :: Exp (a,a) in r @@ -1440,7 +1440,7 @@ x0 ^ y0 = cond (y0 <= 0) 1 (f x0 y0) f x y = let T2 x' y' = while (\(T2 _ v) -> even v) (\(T2 u v) -> T2 (u * u) (v `quot` 2)) - (T2 x y) + (T2 x y) :: Exp (a,b) in cond (y' == 1) x' (g (x'*x') ((y'-1) `quot` 2) x') @@ -1450,7 +1450,7 @@ x0 ^ y0 = cond (y0 <= 0) 1 (f x0 y0) (\(T3 u v w) -> cond (even v) (T3 (u*u) (v `quot` 2) w) (T3 (u*u) ((v-1) `quot` 2) (w*u))) - (T3 x y z) + (T3 x y z) :: Exp (a,b,a) in x' * z' diff --git a/src/Data/Array/Accelerate/Pattern.hs b/src/Data/Array/Accelerate/Pattern.hs index fa0f5f958..2b002b256 100644 --- a/src/Data/Array/Accelerate/Pattern.hs +++ b/src/Data/Array/Accelerate/Pattern.hs @@ -24,9 +24,6 @@ module Data.Array.Accelerate.Pattern ( pattern Pattern, - pattern T2, pattern T3, pattern T4, pattern T5, pattern T6, - pattern T7, pattern T8, pattern T9, pattern T10, pattern T11, - pattern T12, pattern T13, pattern T14, pattern T15, pattern T16, ) where @@ -57,115 +54,79 @@ class IsPattern context a b where -- the (unremarkable) boilerplate for us. -- runQ $ do - let - -- Generate instance declarations for IsPattern of the form: - -- instance (Arrays x, ArraysR x ~ (((), ArraysR a), ArraysR b), Arrays a, Arrays b,) => IsPattern Acc x (Acc a, Acc b) - mkAccPattern :: Int -> Q [Dec] - mkAccPattern n = do - a <- newName "a" - _x <- newName "_x" - let - -- Type variables for the elements - xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - -- Last argument to `IsPattern`, eg (Acc a, Acc b) in the example - b = tupT (map (\t -> [t| Acc $(varT t)|]) xs) - -- Representation as snoc-list of pairs, eg (((), ArraysR a), ArraysR b) - snoc = foldl (\sn t -> [t| ($sn, ArraysR $(varT t)) |]) [t| () |] xs - -- Constraints for the type class, consisting of Arrays constraints on all type variables, - -- and an equality constraint on the representation type of `a` and the snoc representation `snoc`. - context = tupT - $ [t| Arrays $(varT a) |] - : [t| ArraysR $(varT a) ~ $snoc |] - : map (\t -> [t| Arrays $(varT t)|]) xs - -- - get x 0 = [| Acc (SmartAcc (Aprj PairIdxRight $x)) |] - get x i = get [| SmartAcc (Aprj PairIdxLeft $x) |] (i-1) - -- - [d| instance $context => IsPattern Acc $(varT a) $b where - builder $(tupP (map (\x -> [p| Acc $(varP x)|]) xs)) = - Acc $(foldl (\vs v -> [| SmartAcc ($vs `Apair` $(varE v)) |]) [| SmartAcc Anil |] xs) - matcher (Acc $(varP _x)) = - $(tupE (map (get (varE _x)) [(n-1), (n-2) .. 0])) - |] + let + -- Generate instance declarations for IsPattern of the form: + -- instance (Arrays x, ArraysR x ~ (((), ArraysR a), ArraysR b), Arrays a, Arrays b,) => IsPattern Acc x (Acc a, Acc b) + mkAccPattern :: Int -> Q [Dec] + mkAccPattern n = do + a <- newName "a" + _x <- newName "_x" + let + -- Type variables for the elements + xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + -- Last argument to `IsPattern`, eg (Acc a, Acc b) in the example + b = tupT (map (\t -> [t| Acc $(varT t)|]) xs) + -- Representation as snoc-list of pairs, eg (((), ArraysR a), ArraysR b) + snoc = foldl (\sn t -> [t| ($sn, ArraysR $(varT t)) |]) [t| () |] xs + -- Constraints for the type class, consisting of Arrays constraints on all type variables, + -- and an equality constraint on the representation type of `a` and the snoc representation `snoc`. + context = tupT + $ [t| Arrays $(varT a) |] + : [t| ArraysR $(varT a) ~ $snoc |] + : map (\t -> [t| Arrays $(varT t)|]) xs + -- + get x 0 = [| Acc (SmartAcc (Aprj PairIdxRight $x)) |] + get x i = get [| SmartAcc (Aprj PairIdxLeft $x) |] (i-1) + -- + [d| instance $context => IsPattern Acc $(varT a) $b where + builder $(tupP (map (\x -> [p| Acc $(varP x)|]) xs)) = + Acc $(foldl (\vs v -> [| SmartAcc ($vs `Apair` $(varE v)) |]) [| SmartAcc Anil |] xs) + matcher (Acc $(varP _x)) = + $(tupE (map (get (varE _x)) [(n-1), (n-2) .. 0])) + |] - -- Generate instance declarations for IsPattern of the form: - -- instance (Elt x, EltR x ~ (((), EltR a), EltR b), Elt a, Elt b,) => IsPattern Exp x (Exp a, Exp b) - mkExpPattern :: Int -> Q [Dec] - mkExpPattern n = do - a <- newName "a" - _x <- newName "_x" - _y <- newName "_y" - let - -- Type variables for the elements - xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - -- Variables for sub-pattern matches - ms = [ mkName ('m' : show i) | i <- [0 .. n-1] ] - tags = foldl (\ts t -> [p| $ts `TagRpair` $(varP t) |]) [p| TagRunit |] ms - -- Last argument to `IsPattern`, eg (Exp, a, Exp b) in the example - b = tupT (map (\t -> [t| Exp $(varT t)|]) xs) - -- Representation as snoc-list of pairs, eg (((), EltR a), EltR b) - snoc = foldl (\sn t -> [t| ($sn, EltR $(varT t)) |]) [t| () |] xs - -- Constraints for the type class, consisting of Elt constraints on all type variables, - -- and an equality constraint on the representation type of `a` and the snoc representation `snoc`. - context = tupT - $ [t| Elt $(varT a) |] - : [t| EltR $(varT a) ~ $snoc |] - : map (\t -> [t| Elt $(varT t)|]) xs - -- - get x 0 = [| SmartExp (Prj PairIdxRight $x) |] - get x i = get [| SmartExp (Prj PairIdxLeft $x) |] (i-1) - -- - [d| instance $context => IsPattern Exp $(varT a) $b where - builder $(tupP (map (\x -> [p| Exp $(varP x)|]) xs)) = - let _unmatch :: SmartExp a -> SmartExp a - _unmatch (SmartExp (Match _ $(varP _y))) = $(varE _y) - _unmatch x = x - in - Exp $(foldl (\vs v -> [| SmartExp ($vs `Pair` _unmatch $(varE v)) |]) [| SmartExp Nil |] xs) - matcher (Exp $(varP _x)) = - case $(varE _x) of - SmartExp (Match $tags $(varP _y)) - -> $(tupE [[| Exp (SmartExp (Match $(varE m) $(get (varE _x) i))) |] | m <- ms | i <- [(n-1), (n-2) .. 0]]) - _ -> $(tupE [[| Exp $(get (varE _x) i) |] | i <- [(n-1), (n-2) .. 0]]) - |] - -- - es <- mapM mkExpPattern [0..16] - as <- mapM mkAccPattern [0..16] - return $ concat (es ++ as) - - --- | Specialised pattern synonyms for tuples, which may be more convenient to --- use than 'Data.Array.Accelerate.Lift.lift' and --- 'Data.Array.Accelerate.Lift.unlift'. For example, to construct a pair: --- --- > let a = 4 :: Exp Int --- > let b = 2 :: Exp Float --- > let c = T2 a b -- :: Exp (Int, Float); equivalent to 'lift (a,b)' --- --- Similarly they can be used to destruct values: --- --- > let T2 x y = c -- x :: Exp Int, y :: Exp Float; equivalent to 'let (x,y) = unlift c' --- --- These pattern synonyms can be used for both 'Exp' and 'Acc' terms. --- -runQ $ do - let - mkT :: Int -> Q [Dec] - mkT n = - let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - ts = map varT xs - name = mkName ('T':show n) - con = varT (mkName "con") - ty1 = tupT ts - ty2 = tupT (map (con `appT`) ts) - sig = foldr (\t r -> [t| $con $t -> $r |]) (appT con ty1) ts - in - sequence - [ patSynSigD name [t| IsPattern $con $ty1 $ty2 => $sig |] - , patSynD name (prefixPatSyn xs) implBidir [p| Pattern $(tupP (map varP xs)) |] - , pragCompleteD [name] (Just ''Acc) - , pragCompleteD [name] (Just ''Exp) - ] - -- - concat <$> mapM mkT [2..16] + -- Generate instance declarations for IsPattern of the form: + -- instance (Elt x, EltR x ~ (((), EltR a), EltR b), Elt a, Elt b,) => IsPattern Exp x (Exp a, Exp b) + mkExpPattern :: Int -> Q [Dec] + mkExpPattern n = do + a <- newName "a" + _x <- newName "_x" + _y <- newName "_y" + let + -- Type variables for the elements + xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + -- Variables for sub-pattern matches + ms = [ mkName ('m' : show i) | i <- [0 .. n-1] ] + tags = foldl (\ts t -> [p| $ts `TagRpair` $(varP t) |]) [p| TagRunit |] ms + -- Last argument to `IsPattern`, eg (Exp, a, Exp b) in the example + b = tupT (map (\t -> [t| Exp $(varT t)|]) xs) + -- Representation as snoc-list of pairs, eg (((), EltR a), EltR b) + snoc = foldl (\sn t -> [t| ($sn, EltR $(varT t)) |]) [t| () |] xs + -- Constraints for the type class, consisting of Elt constraints on all type variables, + -- and an equality constraint on the representation type of `a` and the snoc representation `snoc`. + context = tupT + $ [t| Elt $(varT a) |] + : [t| EltR $(varT a) ~ $snoc |] + : map (\t -> [t| Elt $(varT t)|]) xs + -- + get x 0 = [| SmartExp (Prj PairIdxRight $x) |] + get x i = get [| SmartExp (Prj PairIdxLeft $x) |] (i-1) + -- + [d| instance $context => IsPattern Exp $(varT a) $b where + builder $(tupP (map (\x -> [p| Exp $(varP x)|]) xs)) = + let _unmatch :: SmartExp a -> SmartExp a + _unmatch (SmartExp (Match _ $(varP _y))) = $(varE _y) + _unmatch x = x + in + Exp $(foldl (\vs v -> [| SmartExp ($vs `Pair` _unmatch $(varE v)) |]) [| SmartExp Nil |] xs) + matcher (Exp $(varP _x)) = + case $(varE _x) of + SmartExp (Match $tags $(varP _y)) + -> $(tupE [[| Exp (SmartExp (Match $(varE m) $(get (varE _x) i))) |] | m <- ms | i <- [(n-1), (n-2) .. 0]]) + _ -> $(tupE [[| Exp $(get (varE _x) i) |] | i <- [(n-1), (n-2) .. 0]]) + |] + -- + es <- mapM mkExpPattern [0..16] + as <- mapM mkAccPattern [0..16] + return $ concat (es ++ as) diff --git a/src/Data/Array/Accelerate/Pattern/SIMD.hs b/src/Data/Array/Accelerate/Pattern/SIMD.hs index a8ae3e4fb..075185ab4 100644 --- a/src/Data/Array/Accelerate/Pattern/SIMD.hs +++ b/src/Data/Array/Accelerate/Pattern/SIMD.hs @@ -11,7 +11,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE ViewPatterns #-} -- | --- Module : Data.Array.Accelerate.Pattern +-- Module : Data.Array.Accelerate.Pattern.SIMD -- Copyright : [2018..2020] The Accelerate Team -- License : BSD3 -- @@ -50,8 +50,8 @@ runQ $ let -- Generate instance declarations for IsSIMD of the form: -- instance (Elt a, Elt v, EltR v ~ VecR n a) => IsSIMD Exp v (Exp a, Exp a) - mkVecPattern :: Int -> Q [Dec] - mkVecPattern n = do + mkV :: Int -> Q [Dec] + mkV n = do a <- newName "a" v <- newName "v" _x <- newName "_x" @@ -112,40 +112,7 @@ runQ $ | i <- [0 .. n-1] ]) |] in - concat <$> mapM mkVecPattern [2,3,4,8,16] - - -- mkV :: Int -> Q [Dec] - -- mkV n = - -- let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - -- a = varT (mkName "a") - -- ts = replicate n a - -- name = mkName ('V':show n) - -- tup = tupT (map (\t -> [t| Exp $t |]) ts) - -- vec = [t| Vec $(litT (numTyLit (toInteger n))) $a |] - -- cst = [t| (Elt $a, SIMD $(litT (numTyLit (toInteger n))) $a, IsSIMD Exp $vec $tup) |] - -- sig = foldr (\t r -> [t| Exp $t -> $r |]) [t| Exp $vec |] ts - -- in - -- sequence - -- [ patSynSigD name [t| $cst => $sig |] - -- , patSynD name (prefixPatSyn xs) implBidir [p| SIMD $(tupP (map varP xs)) |] - -- , pragCompleteD [name] Nothing - -- ] - - -- mkV :: Int -> Q [Dec] - -- mkV n = - -- let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - -- ts = map varT xs - -- name = mkName ('V':show n) - -- con = varT (mkName "con") - -- ty1 = varT (mkName "vec") - -- ty2 = tupT (map (con `appT`) ts) - -- sig = foldr (\t r -> [t| $con $t -> $r |]) (appT con ty1) ts - -- in - -- sequence - -- [ patSynSigD name [t| IsVector $con $ty1 $ty2 => $sig |] - -- , patSynD name (prefixPatSyn xs) implBidir [p| Vector $(tupP (map varP xs)) |] - -- , pragCompleteD [name] (Just ''Exp) - -- ] + concat <$> mapM mkV [2,3,4,8,16] -- Generate polymorphic pattern synonyms which operate on both Haskell values -- as well as embedded expressions @@ -161,7 +128,7 @@ runQ $ xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] xsP = map varP xs xsE = map varE xs - vn = mkName ("V" ++ show n) + name = mkName ("V" ++ show n) isV = mkName ("IsV" ++ show n) builder = mkName ("buildV" ++ show n) matcher = mkName ("matchV" ++ show n) @@ -170,9 +137,9 @@ runQ $ ] -- sequence - [ patSynSigD vn [t| $(conT isV) $(varT a) $(varT v) => $(foldr (\t r -> [t| $t -> $r |]) (varT v) as) |] - , patSynD vn (prefixPatSyn xs) (explBidir [clause [] (normalB (varE builder)) []]) (parensP $ viewP (varE matcher) (tupP xsP)) - , pragCompleteD [vn] Nothing + [ patSynSigD name [t| $(conT isV) $(varT a) $(varT v) => $(foldr (\t r -> [t| $t -> $r |]) (varT v) as) |] + , patSynD name (prefixPatSyn xs) (explBidir [clause [] (normalB (varE builder)) []]) (parensP $ viewP (varE matcher) (tupP xsP)) + , pragCompleteD [name] Nothing -- , classD (return []) isV [PlainTV a (), PlainTV v ()] [funDep [v] [a]] [ sigD builder (foldr (\t r -> [t| $t -> $r |]) (varT v) as) @@ -180,11 +147,11 @@ runQ $ ] -- This instance which goes via toList is horrible and I feel bad for using it -- TLM 2022-06-27 - , instanceD ctx [t| $(conT isV) $(varT a) ($(conT vn) $(varT a)) |] + , instanceD ctx [t| $(conT isV) $(varT a) ($(conT name) $(varT a)) |] [ funD builder [ clause xsP (normalB [| fromList $(listE xsE) |]) []] , funD matcher [ clause [viewP (varE 'toList) (listP xsP)] (normalB (tupE xsE)) [] ] ] - , instanceD ctx [t| $(conT isV) (Exp $(varT a)) (Exp ($(conT vn) $(varT a))) |] + , instanceD ctx [t| $(conT isV) (Exp $(varT a)) (Exp ($(conT name) $(varT a))) |] [ funD builder [ clause xsP (normalB [| SIMD $(tupE xsE) |]) []] , funD matcher [ clause [conP (mkName "SIMD") [tupP xsP]] (normalB (tupE xsE)) [] ] ] diff --git a/src/Data/Array/Accelerate/Pattern/Tuple.hs b/src/Data/Array/Accelerate/Pattern/Tuple.hs new file mode 100644 index 000000000..a4298244d --- /dev/null +++ b/src/Data/Array/Accelerate/Pattern/Tuple.hs @@ -0,0 +1,86 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE ViewPatterns #-} +-- | +-- Module : Data.Array.Accelerate.Pattern.Tuple +-- Copyright : [2018..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Pattern.Tuple ( + + pattern T2, pattern T3, pattern T4, pattern T5, pattern T6, + pattern T7, pattern T8, pattern T9, pattern T10, pattern T11, + pattern T12, pattern T13, pattern T14, pattern T15, pattern T16, + +) where + +import Data.Array.Accelerate.Pattern ( pattern Pattern ) +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Array +import Data.Array.Accelerate.Sugar.Elt + +import Language.Haskell.TH.Extra hiding ( Exp, Match ) + + +-- Generate polymorphic pattern synonyms to construct and destruct tuples on +-- both Haskell values and embedded expressions. This isn't really necessary but +-- provides for a more consistent interface. +-- +runQ $ + let + mkT :: Int -> Q [Dec] + mkT n = + let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + res = mkName "r" + xsT = map varT xs + xsP = map varP xs + xsE = map varE xs + name = mkName ('T':show n) + isT = mkName ("IsT" ++ show n) + builder = mkName ("buildT" ++ show n) + matcher = mkName ("matchT" ++ show n) + sig = foldr (\t r -> [t| $t -> $r |]) (varT res) xsT + hdr ts r = foldl appT (conT isT) (ts ++ [r]) + in + sequence + -- Value/Embedded polymorphic pattern synonym + [ patSynSigD name [t| $(hdr xsT (varT res)) => $sig |] + , patSynD name (prefixPatSyn xs) (explBidir [clause [] (normalB (varE builder)) []]) (parensP $ viewP (varE matcher) (tupP xsP)) + , pragCompleteD [name] (Just (tupleTypeName n)) + , pragCompleteD [name] (Just ''Acc) + , pragCompleteD [name] (Just ''Exp) + -- + , classD (return []) isT (map plainTV (xs ++ [res])) [funDep [res] xs] + [ sigD builder sig + , sigD matcher [t| $(varT res) -> $(tupT xsT) |] + ] + , instanceD (return []) [t| $(hdr xsT (tupT xsT)) |] + [ funD builder [ clause xsP (normalB (tupE xsE)) [] ] + , funD matcher [ clause [] (normalB [| id |]) [] ] + ] + , instanceD (mapM (\x -> [t| Elt $x |]) xsT) [t| $(hdr (map (\x -> [t| Exp $x |]) xsT) [t| Exp $(tupT xsT) |]) |] + [ funD builder [ clause xsP (normalB [| Pattern $(tupE xsE) |]) [] ] + , funD matcher [ clause [conP (mkName "Pattern") [tupP xsP]] (normalB (tupE xsE)) []] + ] + , instanceD (mapM (\x -> [t| Arrays $x |]) xsT) [t| $(hdr (map (\x -> [t| Acc $x |]) xsT) [t| Acc $(tupT xsT) |]) |] + [ funD builder [ clause xsP (normalB [| Pattern $(tupE xsE) |]) [] ] + , funD matcher [ clause [conP (mkName "Pattern") [tupP xsP]] (normalB (tupE xsE)) []] + ] + ] + in + concat <$> mapM mkT [2..16] + diff --git a/src/Data/Array/Accelerate/Prelude.hs b/src/Data/Array/Accelerate/Prelude.hs index f0b144dba..2cfdaa38b 100644 --- a/src/Data/Array/Accelerate/Prelude.hs +++ b/src/Data/Array/Accelerate/Prelude.hs @@ -121,9 +121,9 @@ module Data.Array.Accelerate.Prelude ( import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Language import Data.Array.Accelerate.Lift -import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Pattern.Maybe import Data.Array.Accelerate.Pattern.Shape +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Array ( Arrays, Array, Scalar, Vector, Segments, fromList ) import Data.Array.Accelerate.Sugar.Elt From eb62a0d7ef977db97e350a7fb44fa507e3d35820 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Tue, 28 Jun 2022 10:57:20 +0200 Subject: [PATCH 34/86] improve type checking for tuple patterns --- src/Data/Array/Accelerate/Classes/RealFloat.hs | 6 +++--- src/Data/Array/Accelerate/Data/Complex.hs | 2 +- src/Data/Array/Accelerate/Language.hs | 8 ++++---- src/Data/Array/Accelerate/Pattern/Tuple.hs | 5 +++-- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/Data/Array/Accelerate/Classes/RealFloat.hs b/src/Data/Array/Accelerate/Classes/RealFloat.hs index 31ebde034..8312f20af 100644 --- a/src/Data/Array/Accelerate/Classes/RealFloat.hs +++ b/src/Data/Array/Accelerate/Classes/RealFloat.hs @@ -377,7 +377,7 @@ ieee754_f16_decode i = T2 high3 exp3 = cond (exp1 /= _HMINEXP) -- don't add hidden bit to denorms - (T2 (high2 .|. _HHIGHBIT) exp1 :: Exp (Int16, Int)) + (T2 (high2 .|. _HHIGHBIT) exp1) -- a denorm, normalise the mantissa (while (\(T2 h _) -> (h .&. _HHIGHBIT) /= 0) (\(T2 h e) -> T2 (h `unsafeShiftL` 1) (e-1)) @@ -411,7 +411,7 @@ ieee754_f32_decode i = T2 high3 exp3 = cond (exp1 /= _FMINEXP) -- don't add hidden bit to denorms - (T2 (high2 .|. _FHIGHBIT) exp1 :: Exp (Int32, Int)) + (T2 (high2 .|. _FHIGHBIT) exp1) -- a denorm, normalise the mantissa (while (\(T2 h _) -> (h .&. _FHIGHBIT) /= 0) (\(T2 h e) -> T2 (h `unsafeShiftL` 1) (e-1)) @@ -450,7 +450,7 @@ ieee754_f64_decode2 i = T3 hi lo ie = cond (iexp2 /= _DMINEXP) -- don't add hidden bit to denorms - (T3 (high2 .|. _DHIGHBIT) low iexp :: Exp (Word32, Word32, Int)) + (T3 (high2 .|. _DHIGHBIT) low iexp) -- a denorm, nermalise the mantissa (while (\(T3 h _ _) -> (h .&. _DHIGHBIT) /= 0) (\(T3 h l e) -> diff --git a/src/Data/Array/Accelerate/Data/Complex.hs b/src/Data/Array/Accelerate/Data/Complex.hs index 500b022f0..f10a10249 100644 --- a/src/Data/Array/Accelerate/Data/Complex.hs +++ b/src/Data/Array/Accelerate/Data/Complex.hs @@ -319,7 +319,7 @@ instance RealFloat a => P.Floating (Exp (Complex a)) where then 0 else u ::+ (y < 0 ? (-v, v)) where - T2 u v = x < 0 ? (T2 v' u', T2 u' v') :: Exp (a,a) + T2 u v = x < 0 ? (T2 v' u', T2 u' v') v' = abs y / (u'*2) u' = sqrt ((magnitude z + abs x) / 2) diff --git a/src/Data/Array/Accelerate/Language.hs b/src/Data/Array/Accelerate/Language.hs index bfbb6bb77..255a5e9c2 100644 --- a/src/Data/Array/Accelerate/Language.hs +++ b/src/Data/Array/Accelerate/Language.hs @@ -1414,11 +1414,11 @@ odd n = n `rem` 2 /= 0 gcd :: Integral a => Exp a -> Exp a -> Exp a gcd x y = gcd' (abs x) (abs y) where - gcd' :: forall a. Integral a => Exp a -> Exp a -> Exp a + gcd' :: Integral a => Exp a -> Exp a -> Exp a gcd' u v = let T2 r _ = while (\(T2 _ b) -> b /= 0) (\(T2 a b) -> T2 b (a `rem` b)) - (T2 u v) :: Exp (a,a) + (T2 u v) in r @@ -1440,7 +1440,7 @@ x0 ^ y0 = cond (y0 <= 0) 1 (f x0 y0) f x y = let T2 x' y' = while (\(T2 _ v) -> even v) (\(T2 u v) -> T2 (u * u) (v `quot` 2)) - (T2 x y) :: Exp (a,b) + (T2 x y) in cond (y' == 1) x' (g (x'*x') ((y'-1) `quot` 2) x') @@ -1450,7 +1450,7 @@ x0 ^ y0 = cond (y0 <= 0) 1 (f x0 y0) (\(T3 u v w) -> cond (even v) (T3 (u*u) (v `quot` 2) w) (T3 (u*u) ((v-1) `quot` 2) (w*u))) - (T3 x y z) :: Exp (a,b,a) + (T3 x y z) in x' * z' diff --git a/src/Data/Array/Accelerate/Pattern/Tuple.hs b/src/Data/Array/Accelerate/Pattern/Tuple.hs index a4298244d..64312af49 100644 --- a/src/Data/Array/Accelerate/Pattern/Tuple.hs +++ b/src/Data/Array/Accelerate/Pattern/Tuple.hs @@ -9,6 +9,7 @@ {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} -- | -- Module : Data.Array.Accelerate.Pattern.Tuple @@ -72,11 +73,11 @@ runQ $ [ funD builder [ clause xsP (normalB (tupE xsE)) [] ] , funD matcher [ clause [] (normalB [| id |]) [] ] ] - , instanceD (mapM (\x -> [t| Elt $x |]) xsT) [t| $(hdr (map (\x -> [t| Exp $x |]) xsT) [t| Exp $(tupT xsT) |]) |] + , instanceD (sequence ( [t| $(varT res) ~ $(tupT xsT) |] : map (\x -> [t| Elt $x |]) xsT )) [t| $(hdr (map (\x -> [t| Exp $x |]) xsT) [t| Exp $(varT res) |]) |] [ funD builder [ clause xsP (normalB [| Pattern $(tupE xsE) |]) [] ] , funD matcher [ clause [conP (mkName "Pattern") [tupP xsP]] (normalB (tupE xsE)) []] ] - , instanceD (mapM (\x -> [t| Arrays $x |]) xsT) [t| $(hdr (map (\x -> [t| Acc $x |]) xsT) [t| Acc $(tupT xsT) |]) |] + , instanceD (sequence ( [t| $(varT res) ~ $(tupT xsT) |] : map (\x -> [t| Arrays $x |]) xsT)) [t| $(hdr (map (\x -> [t| Acc $x |]) xsT) [t| Acc $(varT res) |]) |] [ funD builder [ clause xsP (normalB [| Pattern $(tupE xsE) |]) [] ] , funD matcher [ clause [conP (mkName "Pattern") [tupP xsP]] (normalB (tupE xsE)) []] ] From 68f7de7b13f87c8628c0cca5197c041c506454c9 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Tue, 28 Jun 2022 16:25:51 +0200 Subject: [PATCH 35/86] more polymorphic pattern synonyms --- CHANGELOG.md | 7 +- src/Data/Array/Accelerate.hs | 8 +- src/Data/Array/Accelerate/Classes/Eq.hs | 11 ++- src/Data/Array/Accelerate/Classes/Ord.hs | 18 ++-- src/Data/Array/Accelerate/Classes/VOrd.hs | 9 +- src/Data/Array/Accelerate/Data/Bits.hs | 6 +- src/Data/Array/Accelerate/Data/Either.hs | 31 ++++--- src/Data/Array/Accelerate/Data/Maybe.hs | 47 +++++----- src/Data/Array/Accelerate/Pattern/Bool.hs | 73 +++++++++++----- src/Data/Array/Accelerate/Pattern/Either.hs | 74 ++++++++++++++-- src/Data/Array/Accelerate/Pattern/Maybe.hs | 78 +++++++++++++++-- src/Data/Array/Accelerate/Pattern/Ordering.hs | 78 ++++++++++++++++- src/Data/Array/Accelerate/Pattern/SIMD.hs | 7 +- src/Data/Array/Accelerate/Pattern/TH.hs | 87 ++++++++++++------- src/Data/Array/Accelerate/Prelude.hs | 54 ++++++------ .../Accelerate/Test/NoFib/Issues/Issue137.hs | 3 +- .../Accelerate/Test/NoFib/Issues/Issue185.hs | 4 +- .../Accelerate/Test/NoFib/Issues/Issue288.hs | 4 +- .../Accelerate/Test/NoFib/Issues/Issue407.hs | 2 +- .../Accelerate/Test/NoFib/Issues/Issue436.hs | 3 +- .../Accelerate/Test/NoFib/Issues/Issue93.hs | 3 +- .../Accelerate/Test/NoFib/Prelude/Permute.hs | 2 +- .../Accelerate/Test/NoFib/Prelude/SIMD.hs | 41 ++++----- .../Accelerate/Test/NoFib/Prelude/Stencil.hs | 2 +- .../Test/NoFib/Spectral/RadixSort.hs | 4 +- 25 files changed, 451 insertions(+), 205 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f9c6245c3..c83e3ae25 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ Policy (PVP)](https://pvp.haskell.org) * Added debugging functions in module `Data.Array.Accelerate.Debug.Trace` ([#485](https://github.com/AccelerateHS/accelerate/pull/485)) * Support for SIMD data types in expressions. Support for storing a type `a` in a SIMD vector can be added by deriving an instance for the class `SIMD`. + Pattern synonyms `V2`, `V3`, `V4`, `V8` and `V16` are provided to work with + these at both the Haskell value and embedded expression level. * Instances for SIMD types in basic numeric classes (e.g. `Num` for `<4 x Float>`) * Support for 128-bit integers (signed and unsigned) * Support for 128-bit floating point types (build with cabal flag `float128`) @@ -18,7 +20,8 @@ Policy (PVP)](https://pvp.haskell.org) ### Changed * Removed dependency on lens ([#493](https://github.com/AccelerateHS/accelerate/pull/493)) * The shape constructors (e.g. `Z` and `(:.)`) are now pattern synonyms that - work on both Haskell values and embedded expressions + work on both Haskell values and embedded expressions Similarly for the + constructors of `Maybe`, `Either`, `Bool`, and `Ordering`. ### Fixed * Graphviz graph generation of `-ddump-dot` and `-ddump-simpl-dot` ([#384](https://github.com/AccelerateHS/accelerate/issues/384)) @@ -27,6 +30,8 @@ Policy (PVP)](https://pvp.haskell.org) ### Removed * Pattern synonyms `Z_`, `(::.)`, `Any_`, `All_`, which are no longer required + * Pattern synonyms `Just_`, `Nothing_` etc., which have been renamed to no + longer require the trailing underscore. ### Contributors diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index b1c4f810c..53e0046ab 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -314,7 +314,7 @@ module Data.Array.Accelerate ( -- ** Type classes -- *** Basic type classes Eq(..), VEq(..), - Ord(..), VOrd(..), Ordering(..), pattern LT_, pattern EQ_, pattern GT_, + Ord(..), VOrd(..), Ordering, pattern LT, pattern EQ, pattern GT, Enum, succ, pred, Bounded, minBound, maxBound, @@ -423,9 +423,9 @@ module Data.Array.Accelerate ( Int, Int8, Int16, Int32, Int64, Word, Word8, Word16, Word32, Word64, Half(..), Float, Double, - Bool(..), pattern True_, pattern False_, - Maybe(..), pattern Nothing_, pattern Just_, - Either(..), pattern Left_, pattern Right_, + Bool, pattern True, pattern False, + Maybe, pattern Nothing, pattern Just, + Either, pattern Left, pattern Right, Char, ) where diff --git a/src/Data/Array/Accelerate/Classes/Eq.hs b/src/Data/Array/Accelerate/Classes/Eq.hs index 0ffbc0738..498ea501e 100644 --- a/src/Data/Array/Accelerate/Classes/Eq.hs +++ b/src/Data/Array/Accelerate/Classes/Eq.hs @@ -25,7 +25,7 @@ module Data.Array.Accelerate.Classes.Eq ( - Bool(..), pattern True_, pattern False_, + Bool, pattern True, pattern False, Eq(..), (&&), (&&!), (||), (||!), @@ -45,7 +45,6 @@ import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type import {-# SOURCE #-} Data.Array.Accelerate.Classes.VEq -import Data.Bool ( Bool(..) ) import Data.Bits import Text.Printf import Prelude ( ($), String, Num(..), Ordering(..), show, error, return, concat, map, zipWith, foldr1, mapM ) @@ -117,12 +116,12 @@ class Elt a => Eq a where instance Eq () where - _ == _ = True_ - _ /= _ = False_ + _ == _ = True + _ /= _ = False instance Eq Z where - _ == _ = True_ - _ /= _ = False_ + _ == _ = True + _ /= _ = False -- Instances of 'Prelude.Eq' don't make sense with the standard signatures as -- the return type is fixed to 'Bool'. This instance is provided to provide diff --git a/src/Data/Array/Accelerate/Classes/Ord.hs b/src/Data/Array/Accelerate/Classes/Ord.hs index 4749b6a86..82a1c11ef 100644 --- a/src/Data/Array/Accelerate/Classes/Ord.hs +++ b/src/Data/Array/Accelerate/Classes/Ord.hs @@ -26,7 +26,7 @@ module Data.Array.Accelerate.Classes.Ord ( Ord(..), - Ordering(..), pattern LT_, pattern EQ_, pattern GT_, + Ordering, pattern LT, pattern EQ, pattern GT, ) where @@ -47,7 +47,7 @@ import {-# SOURCE #-} Data.Array.Accelerate.Classes.VOrd import Data.Bits import Data.Char import Language.Haskell.TH.Extra hiding ( Exp ) -import Prelude ( ($), Num(..), Ordering(..), Maybe(..), String, show, error, unlines, return, concat, map, mapM ) +import Prelude ( ($), Num(..), Maybe(..), String, show, error, unlines, return, concat, map, mapM ) import Text.Printf import qualified Prelude as P @@ -72,18 +72,18 @@ class Eq a => Ord a where max :: Exp a -> Exp a -> Exp a compare :: Exp a -> Exp a -> Exp Ordering - x < y = cond (compare x y == LT_) True_ False_ - x <= y = cond (compare x y == GT_) False_ True_ - x > y = cond (compare x y == GT_) True_ False_ - x >= y = cond (compare x y == LT_) False_ True_ + x < y = cond (compare x y == LT) True False + x <= y = cond (compare x y == GT) False True + x > y = cond (compare x y == GT) True False + x >= y = cond (compare x y == LT) False True min x y = cond (x <= y) x y max x y = cond (x <= y) y x compare x y - = cond (x == y) EQ_ - $ cond (x <= y) LT_ - {- else -} GT_ + = cond (x == y) EQ + $ cond (x <= y) LT + {- else -} GT -- Local redefinition to prevent cyclic imports -- diff --git a/src/Data/Array/Accelerate/Classes/VOrd.hs b/src/Data/Array/Accelerate/Classes/VOrd.hs index 5d609ffba..b83af2c0b 100644 --- a/src/Data/Array/Accelerate/Classes/VOrd.hs +++ b/src/Data/Array/Accelerate/Classes/VOrd.hs @@ -34,7 +34,8 @@ import qualified Data.Primitive.Bit as Prim import Language.Haskell.TH.Extra hiding ( Type, Exp ) -import Prelude hiding ( Ord(..), (<*) ) +import Prelude hiding ( Ord(..), Ordering(..), (<*) ) +import qualified Prelude as P infix 4 <* @@ -68,9 +69,9 @@ class VEq n a => VOrd n a where $ select (x <=* y) vlt vgt vlt, veq, vgt :: KnownNat n => Exp (Vec n Ordering) -vlt = constant (Vec (let (tag,()) = fromElt LT in Prim.splat tag, ())) -veq = constant (Vec (let (tag,()) = fromElt EQ in Prim.splat tag, ())) -vgt = constant (Vec (let (tag,()) = fromElt EQ in Prim.splat tag, ())) +vlt = constant (Vec (let (tag,()) = fromElt P.LT in Prim.splat tag, ())) +veq = constant (Vec (let (tag,()) = fromElt P.EQ in Prim.splat tag, ())) +vgt = constant (Vec (let (tag,()) = fromElt P.GT in Prim.splat tag, ())) vtrue, vfalse :: KnownNat n => Exp (Vec n Bool) vtrue = constant (Vec (Prim.unMask Prim.ones)) diff --git a/src/Data/Array/Accelerate/Data/Bits.hs b/src/Data/Array/Accelerate/Data/Bits.hs index c93332ca0..25ba79cdf 100644 --- a/src/Data/Array/Accelerate/Data/Bits.hs +++ b/src/Data/Array/Accelerate/Data/Bits.hs @@ -170,12 +170,12 @@ instance Bits Bool where (.|.) = (||) xor = (/=) complement = not - zeroBits = False_ + zeroBits = False testBit = (&&) bit x = x isSigned = isSignedDefault - shiftL x i = cond i False_ x - shiftR x i = cond i False_ x + shiftL x i = cond i False x + shiftR x i = cond i False x rotateL x _ = x rotateR x _ = x popCount x = x diff --git a/src/Data/Array/Accelerate/Data/Either.hs b/src/Data/Array/Accelerate/Data/Either.hs index af50ce6d5..97b4c13f9 100644 --- a/src/Data/Array/Accelerate/Data/Either.hs +++ b/src/Data/Array/Accelerate/Data/Either.hs @@ -28,7 +28,7 @@ module Data.Array.Accelerate.Data.Either ( - Either(..), pattern Left_, pattern Right_, + Either, pattern Left, pattern Right, either, isLeft, isRight, fromLeft, fromRight, lefts, rights, ) where @@ -53,7 +53,6 @@ import Data.Array.Accelerate.Data.Functor import Data.Array.Accelerate.Data.Monoid import Data.Array.Accelerate.Data.Semigroup -import Data.Either ( Either(..) ) import Prelude ( (.), ($) ) @@ -89,8 +88,8 @@ fromRight (Exp e) = mkExp $ Prj PairIdxRight $ SmartExp $ Prj PairIdxRight e -- either :: (Elt a, Elt b, Elt c) => (Exp a -> Exp c) -> (Exp b -> Exp c) -> Exp (Either a b) -> Exp c either f g = match \case - Left_ x -> f x - Right_ x -> g x + Left x -> f x + Right x -> g x -- | Extract from the array of 'Either' all of the 'Left' elements, together -- with a segment descriptor indicating how many elements along each dimension @@ -112,32 +111,32 @@ rights es = compact (map isRight es) (map fromRight es) instance Elt a => Functor (Either a) where - fmap f = either Left_ (Right_ . f) + fmap f = either Left (Right . f) instance Elt a => Monad (Either a) where - return = Right_ - x >>= f = either Left_ f x + return = Right + x >>= f = either Left f x instance (Eq a, Eq b) => Eq (Either a b) where (==) = match go where - go (Left_ x) (Left_ y) = x == y - go (Right_ x) (Right_ y) = x == y - go _ _ = False_ + go (Left x) (Left y) = x == y + go (Right x) (Right y) = x == y + go _ _ = False instance (Ord a, Ord b) => Ord (Either a b) where compare = match go where - go (Left_ x) (Left_ y) = compare x y - go (Right_ x) (Right_ y) = compare x y - go Left_{} Right_{} = LT_ - go Right_{} Left_{} = GT_ + go (Left x) (Left y) = compare x y + go (Right x) (Right y) = compare x y + go Left{} Right{} = LT + go Right{} Left{} = GT instance (Elt a, Elt b) => Semigroup (Exp (Either a b)) where ex <> ey = isLeft ex ? ( ey, ex ) instance (Lift Exp a, Lift Exp b, Elt (Plain a), Elt (Plain b)) => Lift Exp (Either a b) where type Plain (Either a b) = Either (Plain a) (Plain b) - lift (Left a) = Left_ (lift a) - lift (Right b) = Right_ (lift b) + lift (Left a) = Left (lift a) + lift (Right b) = Right (lift b) diff --git a/src/Data/Array/Accelerate/Data/Maybe.hs b/src/Data/Array/Accelerate/Data/Maybe.hs index 11366398b..979ae0510 100644 --- a/src/Data/Array/Accelerate/Data/Maybe.hs +++ b/src/Data/Array/Accelerate/Data/Maybe.hs @@ -28,7 +28,7 @@ module Data.Array.Accelerate.Data.Maybe ( - Maybe(..), pattern Nothing_, pattern Just_, + Maybe, pattern Nothing, pattern Just, maybe, isJust, isNothing, fromMaybe, fromJust, justs, ) where @@ -54,7 +54,6 @@ import Data.Array.Accelerate.Data.Monoid import Data.Array.Accelerate.Data.Semigroup import Data.Function ( (&) ) -import Data.Maybe ( Maybe(..) ) import Prelude ( ($), (.) ) @@ -76,8 +75,8 @@ isJust (Exp x) = mkExp $ PrimApp (PrimToBool integralType bitType) (SmartExp $ P -- fromMaybe :: Elt a => Exp a -> Exp (Maybe a) -> Exp a fromMaybe d = match \case - Nothing_ -> d - Just_ x -> x + Nothing -> d + Just x -> x -- | The 'fromJust' function extracts the element out of the 'Just' constructor. -- If the argument was actually 'Nothing', you will get an undefined value @@ -93,8 +92,8 @@ fromJust (Exp x) = Exp $ SmartExp (PairIdxRight `Prj` SmartExp (PairIdxRight `Pr -- maybe :: (Elt a, Elt b) => Exp b -> (Exp a -> Exp b) -> Exp (Maybe a) -> Exp b maybe d f = match \case - Nothing_ -> d - Just_ x -> f x + Nothing -> d + Just x -> f x -- | Extract from an array all of the 'Just' values, together with a segment -- descriptor indicating how many elements along each dimension were returned. @@ -107,42 +106,42 @@ justs xs = compact (map isJust xs) (map fromJust xs) instance Functor Maybe where fmap f = match \case - Nothing_ -> Nothing_ - Just_ x -> Just_ (f x) + Nothing -> Nothing + Just x -> Just (f x) instance Monad Maybe where - return = Just_ + return = Just mx >>= f = mx & match \case - Nothing_ -> Nothing_ - Just_ x -> f x + Nothing -> Nothing + Just x -> f x instance Eq a => Eq (Maybe a) where (==) = match go where - go Nothing_ Nothing_ = True_ - go (Just_ x) (Just_ y) = x == y - go _ _ = False_ + go Nothing Nothing = True + go (Just x) (Just y) = x == y + go _ _ = False instance Ord a => Ord (Maybe a) where compare = match go where - go (Just_ x) (Just_ y) = compare x y - go Nothing_ Nothing_ = EQ_ - go Nothing_ Just_{} = LT_ - go Just_{} Nothing_{} = GT_ + go (Just x) (Just y) = compare x y + go Nothing Nothing = EQ + go Nothing Just{} = LT + go Just{} Nothing{} = GT instance (Monoid (Exp a), Elt a) => Monoid (Exp (Maybe a)) where - mempty = Nothing_ + mempty = Nothing instance (Semigroup (Exp a), Elt a) => Semigroup (Exp (Maybe a)) where (<>) = match go where - go Nothing_ b = b - go a Nothing_ = a - go (Just_ a) (Just_ b) = Just_ (a <> b) + go Nothing b = b + go a Nothing = a + go (Just a) (Just b) = Just (a <> b) instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Maybe a) where type Plain (Maybe a) = Maybe (Plain a) - lift Nothing = Nothing_ - lift (Just a) = Just_ (lift a) + lift Nothing = Nothing + lift (Just a) = Just (lift a) diff --git a/src/Data/Array/Accelerate/Pattern/Bool.hs b/src/Data/Array/Accelerate/Pattern/Bool.hs index 08293e975..941238e41 100644 --- a/src/Data/Array/Accelerate/Pattern/Bool.hs +++ b/src/Data/Array/Accelerate/Pattern/Bool.hs @@ -1,9 +1,7 @@ -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ViewPatterns #-} -- | -- Module : Data.Array.Accelerate.Pattern.Bool -- Copyright : [2018..2020] The Accelerate Team @@ -16,7 +14,7 @@ module Data.Array.Accelerate.Pattern.Bool ( - Bool, pattern True_, pattern False_, + Bool, pattern True, pattern False, ) where @@ -24,24 +22,55 @@ import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Type +import Data.Bool ( Bool ) +import Prelude hiding ( Bool(..) ) +import qualified Prelude as P + import GHC.Stack -{-# COMPLETE False_, True_ #-} -pattern False_ :: HasCallStack => Exp Bool -pattern False_ <- (matchFalse -> Just ()) - where False_ = buildFalse +{-# COMPLETE False, True #-} +pattern False :: (HasCallStack, IsFalse r) => r +pattern False <- (matchFalse -> Just ()) + where False = buildFalse + +pattern True :: (HasCallStack, IsTrue r) => r +pattern True <- (matchTrue -> Just ()) + where True = buildTrue + +class IsFalse r where + buildFalse :: r + matchFalse :: r -> Maybe () + +instance IsFalse Bool where + buildFalse = P.False + matchFalse P.False = Just () + matchFalse _ = Nothing + +instance IsFalse (Exp Bool) where + buildFalse = _buildFalse + matchFalse = _matchFalse + +class IsTrue r where + buildTrue :: r + matchTrue :: r -> Maybe () + +instance IsTrue Bool where + buildTrue = P.True + matchTrue P.True = Just () + matchTrue _ = Nothing + +instance IsTrue (Exp Bool) where + buildTrue = _buildTrue + matchTrue = _matchTrue -pattern True_ :: HasCallStack => Exp Bool -pattern True_ <- (matchTrue -> Just ()) - where True_ = buildTrue -buildFalse :: Exp Bool -buildFalse = mkExp $ Const scalarType 0 +_buildFalse :: Exp Bool +_buildFalse = mkExp $ Const scalarType 0 -matchFalse :: HasCallStack => Exp Bool -> Maybe () -matchFalse (Exp e) = +_matchFalse :: HasCallStack => Exp Bool -> Maybe () +_matchFalse (Exp e) = case e of SmartExp (Match (TagRenum TagBit 0) _) -> Just () SmartExp Match{} -> Nothing @@ -58,11 +87,11 @@ matchFalse (Exp e) = , "> _ -> ..." ] -buildTrue :: Exp Bool -buildTrue = mkExp $ Const scalarType 1 +_buildTrue :: Exp Bool +_buildTrue = mkExp $ Const scalarType 1 -matchTrue :: HasCallStack => Exp Bool -> Maybe () -matchTrue (Exp e) = +_matchTrue :: HasCallStack => Exp Bool -> Maybe () +_matchTrue (Exp e) = case e of SmartExp (Match (TagRenum TagBit 1) _) -> Just () SmartExp Match{} -> Nothing diff --git a/src/Data/Array/Accelerate/Pattern/Either.hs b/src/Data/Array/Accelerate/Pattern/Either.hs index 67c7b3a3f..8a02ba708 100644 --- a/src/Data/Array/Accelerate/Pattern/Either.hs +++ b/src/Data/Array/Accelerate/Pattern/Either.hs @@ -1,9 +1,11 @@ -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} -- | -- Module : Data.Array.Accelerate.Pattern.Either -- Copyright : [2018..2020] The Accelerate Team @@ -16,11 +18,67 @@ module Data.Array.Accelerate.Pattern.Either ( - Either, pattern Left_, pattern Right_, + Either, pattern Left, pattern Right, ) where import Data.Array.Accelerate.Pattern.TH +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Smart -mkPattern ''Either +import Data.Either ( Either ) +import Language.Haskell.TH.Extra hiding ( Exp ) +import Prelude hiding ( Either(..) ) +import qualified Data.List as P +import qualified Prelude as P + + +runQ $ do + let it SigD{} = True + it FunD{} = True + it _ = False + + find _ [] = error "could not find specified function" + find pat (d:ds) = + case d of + SigD n _ | pat `P.isPrefixOf` nameBase n -> varE n + _ -> find pat ds + + decs <- filter it <$> mkPattern ''Either + rest <- [d| {-# COMPLETE Left, Right #-} + pattern Left :: IsLeft a b r => a -> r + pattern Left x <- (matchLeft -> P.Just x) + where Left = buildLeft + + class IsLeft a b r | r -> a b where + buildLeft :: a -> r + matchLeft :: r -> Maybe a + + instance IsLeft a b (Either a b) where + buildLeft = P.Left + matchLeft (P.Left a) = P.Just a + matchLeft _ = P.Nothing + + instance (Elt a, Elt b) => IsLeft (Exp a) (Exp b) (Exp (Either a b)) where + buildLeft = $(find "_buildLeft" decs) + matchLeft = $(find "_matchLeft" decs) + + pattern Right :: IsRight a b r => b -> r + pattern Right x <- (matchRight -> P.Just x) + where Right = buildRight + + class IsRight a b r | r -> a b where + buildRight :: b -> r + matchRight :: r -> Maybe b + + instance IsRight a b (Either a b) where + buildRight = P.Right + matchRight (P.Right b) = P.Just b + matchRight _ = P.Nothing + + instance (Elt a, Elt b) => IsRight (Exp a) (Exp b) (Exp (Either a b)) where + buildRight = $(find "_buildRight" decs) + matchRight = $(find "_matchRight" decs) + |] + return (decs ++ rest) diff --git a/src/Data/Array/Accelerate/Pattern/Maybe.hs b/src/Data/Array/Accelerate/Pattern/Maybe.hs index 67e341d64..88aa1ea67 100644 --- a/src/Data/Array/Accelerate/Pattern/Maybe.hs +++ b/src/Data/Array/Accelerate/Pattern/Maybe.hs @@ -1,9 +1,12 @@ -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} -- | -- Module : Data.Array.Accelerate.Pattern.Maybe -- Copyright : [2018..2020] The Accelerate Team @@ -16,11 +19,70 @@ module Data.Array.Accelerate.Pattern.Maybe ( - Maybe, pattern Nothing_, pattern Just_, + Maybe, pattern Nothing, pattern Just, ) where import Data.Array.Accelerate.Pattern.TH +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Smart -mkPattern ''Maybe +import Data.Maybe ( Maybe ) +import Language.Haskell.TH.Extra hiding ( Exp ) +import Prelude hiding ( Maybe(..) ) +import qualified Data.List as P +import qualified Prelude as P + +import GHC.Stack + + +-- TODO: We should make this a feature of the mkPattern machinery +-- +runQ $ do + let it SigD{} = True + it FunD{} = True + it _ = False + + find _ [] = error "could not find specified function" + find pat (d:ds) = + case d of + SigD n _ | pat `P.isPrefixOf` nameBase n -> varE n + _ -> find pat ds + + decs <- filter it <$> mkPattern ''Maybe + rest <- [d| {-# COMPLETE Nothing, Just #-} + pattern Nothing :: (HasCallStack, IsNothing r) => r + pattern Nothing <- (matchNothing -> P.Just ()) + where Nothing = buildNothing + + pattern Just :: (HasCallStack, IsJust a r) => a -> r + pattern Just x <- (matchJust -> P.Just x) + where Just = buildJust + + class IsNothing r where + buildNothing :: r + matchNothing :: r -> Maybe () + + instance IsNothing (Maybe a) where + buildNothing = P.Nothing + matchNothing P.Nothing = P.Just () + matchNothing _ = P.Nothing + + instance Elt a => IsNothing (Exp (Maybe a)) where + buildNothing = $(find "_buildNothing" decs) + matchNothing = $(find "_matchNothing" decs) + + class IsJust a r | r -> a where + buildJust :: a -> r + matchJust :: r -> Maybe a + + instance IsJust a (Maybe a) where + buildJust = P.Just + matchJust = P.id + + instance Elt a => IsJust (Exp a) (Exp (Maybe a)) where + buildJust = $(find "_buildJust" decs) + matchJust = $(find "_matchJust" decs) + |] + return (decs ++ rest) diff --git a/src/Data/Array/Accelerate/Pattern/Ordering.hs b/src/Data/Array/Accelerate/Pattern/Ordering.hs index 2407cf9e9..896639e51 100644 --- a/src/Data/Array/Accelerate/Pattern/Ordering.hs +++ b/src/Data/Array/Accelerate/Pattern/Ordering.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -16,11 +17,84 @@ module Data.Array.Accelerate.Pattern.Ordering ( - Ordering, pattern LT_, pattern EQ_, pattern GT_, + Ordering, pattern LT, pattern EQ, pattern GT, ) where import Data.Array.Accelerate.Pattern.TH +import Data.Array.Accelerate.Smart -mkPattern ''Ordering +import Data.Ord ( Ordering ) +import Language.Haskell.TH.Extra hiding ( Exp ) +import Prelude hiding ( Ordering(..) ) +import qualified Data.List as P +import qualified Prelude as P + + +runQ $ do + let it SigD{} = True + it FunD{} = True + it _ = False + + find _ [] = error "could not find specified function" + find pat (d:ds) = + case d of + SigD n _ | pat `P.isPrefixOf` nameBase n -> varE n + _ -> find pat ds + + decs <- filter it <$> mkPattern ''Ordering + rest <- [d| {-# COMPLETE LT, EQ, GT #-} + pattern LT :: IsLT a => a + pattern LT <- (matchLT -> Just ()) + where LT = buildLT + + class IsLT a where + buildLT :: a + matchLT :: a -> Maybe () + + instance IsLT Ordering where + buildLT = P.LT + matchLT P.LT = Just () + matchLT _ = Nothing + + instance IsLT (Exp Ordering) where + buildLT = $(find "_buildLT" decs) + matchLT = $(find "_matchLT" decs) + + pattern EQ :: IsEQ a => a + pattern EQ <- (matchEQ -> Just ()) + where EQ = buildEQ + + class IsEQ a where + buildEQ :: a + matchEQ :: a -> Maybe () + + instance IsEQ Ordering where + buildEQ = P.EQ + matchEQ P.EQ = Just () + matchEQ _ = Nothing + + instance IsEQ (Exp Ordering) where + buildEQ = $(find "_buildEQ" decs) + matchEQ = $(find "_matchEQ" decs) + + pattern GT :: IsGT a => a + pattern GT <- (matchGT -> Just ()) + where GT = buildGT + + class IsGT a where + buildGT :: a + matchGT :: a -> Maybe () + + instance IsGT Ordering where + buildGT = P.GT + matchGT P.GT = Just () + matchGT _ = Nothing + + instance IsGT (Exp Ordering) where + buildGT = $(find "_buildGT" decs) + matchGT = $(find "_matchGT" decs) + + |] + return (decs ++ rest) diff --git a/src/Data/Array/Accelerate/Pattern/SIMD.hs b/src/Data/Array/Accelerate/Pattern/SIMD.hs index 075185ab4..bc1a9d479 100644 --- a/src/Data/Array/Accelerate/Pattern/SIMD.hs +++ b/src/Data/Array/Accelerate/Pattern/SIMD.hs @@ -9,6 +9,7 @@ {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} -- | -- Module : Data.Array.Accelerate.Pattern.SIMD @@ -135,6 +136,10 @@ runQ $ ctx = return [ ConT ''Elt `AppT` VarT a , ConT ''SIMD `AppT` LitT (NumTyLit (toInteger n)) `AppT` VarT a ] + ctx' = return [ ConT ''Elt `AppT` VarT a + , ConT ''SIMD `AppT` LitT (NumTyLit (toInteger n)) `AppT` VarT a + , EqualityT `AppT` VarT v `AppT` (ConT name `AppT` VarT a) + ] -- sequence [ patSynSigD name [t| $(conT isV) $(varT a) $(varT v) => $(foldr (\t r -> [t| $t -> $r |]) (varT v) as) |] @@ -151,7 +156,7 @@ runQ $ [ funD builder [ clause xsP (normalB [| fromList $(listE xsE) |]) []] , funD matcher [ clause [viewP (varE 'toList) (listP xsP)] (normalB (tupE xsE)) [] ] ] - , instanceD ctx [t| $(conT isV) (Exp $(varT a)) (Exp ($(conT name) $(varT a))) |] + , instanceD ctx' [t| $(conT isV) (Exp $(varT a)) (Exp $(varT v)) |] [ funD builder [ clause xsP (normalB [| SIMD $(tupE xsE) |]) []] , funD matcher [ clause [conP (mkName "SIMD") [tupP xsP]] (normalB (tupE xsE)) [] ] ] diff --git a/src/Data/Array/Accelerate/Pattern/TH.hs b/src/Data/Array/Accelerate/Pattern/TH.hs index 7874f4a4c..6a3f8303e 100644 --- a/src/Data/Array/Accelerate/Pattern/TH.hs +++ b/src/Data/Array/Accelerate/Pattern/TH.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} -- | @@ -12,8 +13,9 @@ module Data.Array.Accelerate.Pattern.TH ( - mkPattern, - mkPatterns, + Options(..), defaultOptions, + mkPattern, mkPatternWith, + mkPatterns, mkPatternsWith, ) where @@ -36,10 +38,24 @@ import qualified Language.Haskell.TH.Extra as TH import GHC.Stack +-- | Options to control what is generated +-- +data Options = Options + { renameNormalC :: Name -> Name + , renameInfixC :: Name -> Name + } + +defaultOptions :: Options +defaultOptions = Options defaultRenameNormalC defaultRenameInfixC + + -- | As 'mkPattern', but for a list of types -- mkPatterns :: [Name] -> DecsQ -mkPatterns nms = concat <$> mapM mkPattern nms +mkPatterns = mkPatternsWith defaultOptions + +mkPatternsWith :: Options -> [Name] -> DecsQ +mkPatternsWith opts nms = concat <$> mapM (mkPatternWith opts) nms -- | Generate pattern synonyms for the given simple (Haskell'98) sum or -- product data type. @@ -61,24 +77,27 @@ mkPatterns nms = concat <$> mapM mkPattern nms -- > ycoord :: Exp Point -> Exp Float -- mkPattern :: Name -> DecsQ -mkPattern nm = do +mkPattern = mkPatternWith defaultOptions + +mkPatternWith :: Options -> Name -> DecsQ +mkPatternWith opt nm = do info <- reify nm case info of - TyConI dec -> mkDec dec + TyConI dec -> mkDec opt dec _ -> fail "mkPatterns: expected the name of a newtype or datatype" -mkDec :: Dec -> DecsQ -mkDec dec = +mkDec :: Options -> Dec -> DecsQ +mkDec opt dec = case dec of - DataD _ nm tv _ cs _ -> mkDataD nm tv cs - NewtypeD _ nm tv _ c _ -> mkNewtypeD nm tv c + DataD _ nm tv _ cs _ -> mkDataD opt nm tv cs + NewtypeD _ nm tv _ c _ -> mkNewtypeD opt nm tv c _ -> fail "mkPatterns: expected the name of a newtype or datatype" -mkNewtypeD :: Name -> [TyVarBndr ()] -> Con -> DecsQ -mkNewtypeD tn tvs c = mkDataD tn tvs [c] +mkNewtypeD :: Options -> Name -> [TyVarBndr ()] -> Con -> DecsQ +mkNewtypeD opt tn tvs c = mkDataD opt tn tvs [c] -mkDataD :: Name -> [TyVarBndr ()] -> [Con] -> DecsQ -mkDataD tn tvs cs = do +mkDataD :: Options -> Name -> [TyVarBndr ()] -> [Con] -> DecsQ +mkDataD opt tn tvs cs = do (pats, decs) <- unzip <$> go cs comp <- pragCompleteD pats Nothing return $ comp : concat decs @@ -86,14 +105,14 @@ mkDataD tn tvs cs = do -- For single-constructor types we create the pattern synonym for the -- type directly in terms of Pattern go [] = fail "mkPatterns: empty data declarations not supported" - go [c] = return <$> mkConP tn tvs c + go [c] = return <$> mkConP opt tn tvs c go _ = go' [] (map fieldTys cs) ctags cs -- For sum-types, when creating the pattern for an individual -- constructor we need to know about the types of the fields all other -- constructors as well go' prev (this:next) (tag:tags) (con:cons) = do - r <- mkConS tn tvs prev next tag con + r <- mkConS opt tn tvs prev next tag con rs <- go' (this:prev) next tags cons return (r : rs) go' _ [] [] [] = return [] @@ -122,12 +141,12 @@ mkDataD tn tvs cs = do map bitsToTag (l ++ r) -mkConP :: Name -> [TyVarBndr ()] -> Con -> Q (Name, [Dec]) -mkConP tn' tvs' con' = do +mkConP :: Options -> Name -> [TyVarBndr ()] -> Con -> Q (Name, [Dec]) +mkConP Options{..} tn' tvs' con' = do checkExts [ PatternSynonyms ] case con' of NormalC cn fs -> mkNormalC tn' cn (map tyVarBndrName tvs') (map snd fs) - RecC cn fs -> mkRecC tn' cn (map tyVarBndrName tvs') (map (rename . fst3) fs) (map thd3 fs) + RecC cn fs -> mkRecC tn' cn (map tyVarBndrName tvs') (map (renameNormalC . fst3) fs) (map thd3 fs) InfixC a cn b -> mkInfixC tn' cn (map tyVarBndrName tvs') [snd a, snd b] _ -> fail "mkPatterns: only constructors for \"vanilla\" syntax are supported" where @@ -142,7 +161,7 @@ mkConP tn' tvs' con' = do ] return (pat, r) where - pat = rename cn + pat = renameNormalC cn sig = forallT (map (`plainInvisTV` specifiedSpec) tvs) (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) @@ -160,7 +179,7 @@ mkConP tn' tvs' con' = do ] return (pat, r) where - pat = rename cn + pat = renameNormalC cn sig = forallT (map (`plainInvisTV` specifiedSpec) tvs) (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) @@ -184,7 +203,7 @@ mkConP tn' tvs' con' = do Just f -> return (InfixD f pat : r) return (pat, r') where - pat = mkName (':' : nameBase cn) + pat = renameInfixC cn sig = forallT (map (`plainInvisTV` specifiedSpec) tvs) (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) @@ -192,18 +211,18 @@ mkConP tn' tvs' con' = do [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] (map (\t -> [t| Exp $(return t) |]) fs)) -mkConS :: Name -> [TyVarBndr ()] -> [[Type]] -> [[Type]] -> Word8 -> Con -> Q (Name, [Dec]) -mkConS tn' tvs' prev' next' tag' con' = do +mkConS :: Options -> Name -> [TyVarBndr ()] -> [[Type]] -> [[Type]] -> Word8 -> Con -> Q (Name, [Dec]) +mkConS Options{..} tn' tvs' prev' next' tag' con' = do checkExts [GADTs, PatternSynonyms, ScopedTypeVariables, TypeApplications, ViewPatterns] case con' of NormalC cn fs -> mkNormalC tn' cn tag' (map tyVarBndrName tvs') prev' (map snd fs) next' - RecC cn fs -> mkRecC tn' cn tag' (map tyVarBndrName tvs') (map (rename . fst3) fs) prev' (map thd3 fs) next' + RecC cn fs -> mkRecC tn' cn tag' (map tyVarBndrName tvs') (map (renameNormalC . fst3) fs) prev' (map thd3 fs) next' InfixC a cn b -> mkInfixC tn' cn tag' (map tyVarBndrName tvs') prev' [snd a, snd b] next' _ -> fail "mkPatterns: only constructors for \"vanilla\" syntax are supported" where mkNormalC :: Name -> Name -> Word8 -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) mkNormalC tn cn tag tvs ps fs ns = do - let pat = rename cn + let pat = renameNormalC cn (fun_build, dec_build) <- mkBuild tn (nameBase cn) tvs tag ps fs ns (fun_match, dec_match) <- mkMatch tn (nameBase pat) (nameBase cn) tvs tag ps fs ns dec_pat <- mkNormalC_pattern tn pat tvs fs fun_build fun_match @@ -211,7 +230,7 @@ mkConS tn' tvs' prev' next' tag' con' = do mkRecC :: Name -> Name -> Word8 -> [Name] -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) mkRecC tn cn tag tvs xs ps fs ns = do - let pat = rename cn + let pat = renameNormalC cn (fun_build, dec_build) <- mkBuild tn (nameBase cn) tvs tag ps fs ns (fun_match, dec_match) <- mkMatch tn (nameBase pat) (nameBase cn) tvs tag ps fs ns dec_pat <- mkRecC_pattern tn pat tvs xs fs fun_build fun_match @@ -219,9 +238,10 @@ mkConS tn' tvs' prev' next' tag' con' = do mkInfixC :: Name -> Name -> Word8 -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) mkInfixC tn cn tag tvs ps fs ns = do - let pat = mkName (':' : nameBase cn) - (fun_build, dec_build) <- mkBuild tn (zencode (nameBase cn)) tvs tag ps fs ns - (fun_match, dec_match) <- mkMatch tn ("(" ++ nameBase pat ++ ")") (zencode (nameBase cn)) tvs tag ps fs ns + let pat = renameInfixC cn + zcn = zencode (nameBase cn) + (fun_build, dec_build) <- mkBuild tn zcn tvs tag ps fs ns + (fun_match, dec_match) <- mkMatch tn ("(" ++ nameBase pat ++ ")") zcn tvs tag ps fs ns dec_pat <- mkInfixC_pattern tn cn pat tvs fs fun_build fun_match return $ (pat, concat [dec_pat, dec_build, dec_match]) @@ -293,7 +313,7 @@ mkConS tn' tvs' prev' next' tag' con' = do ++ map varE xs ++ map (\t -> [| unExp $(varE 'undef `appTypeE` return t) |] ) (concat fs1) - tagged = [| Exp $ SmartExp $ Pair (SmartExp (Const (NumScalarType (IntegralNumType (SingleIntegralType TypeWord8))) $(litE (IntegerL (toInteger tag))))) $vs |] + tagged = [| Exp (SmartExp (Pair (SmartExp (Const (NumScalarType (IntegralNumType (SingleIntegralType TypeWord8))) $(litE (IntegerL (toInteger tag))))) $vs)) |] body = clause (map (\x -> [p| (Exp $(varP x)) |]) xs) (normalB tagged) [] r <- sequence [ sigD fun sig @@ -374,8 +394,8 @@ fst3 (a,_,_) = a thd3 :: (a,b,c) -> c thd3 (_,_,c) = c -rename :: Name -> Name -rename nm = +defaultRenameNormalC :: Name -> Name +defaultRenameNormalC nm = let split acc [] = (reverse acc, '\0') -- shouldn't happen split acc [l] = (reverse acc, l) @@ -388,6 +408,9 @@ rename nm = '_' -> mkName base _ -> mkName (nm' ++ "_") +defaultRenameInfixC :: Name -> Name +defaultRenameInfixC nm = mkName (':' : nameBase nm) + checkExts :: [Extension] -> Q () checkExts req = do enabled <- extsEnabled diff --git a/src/Data/Array/Accelerate/Prelude.hs b/src/Data/Array/Accelerate/Prelude.hs index 2cfdaa38b..eedd257b2 100644 --- a/src/Data/Array/Accelerate/Prelude.hs +++ b/src/Data/Array/Accelerate/Prelude.hs @@ -139,7 +139,7 @@ import Data.Array.Accelerate.Classes.Ord import Data.Array.Accelerate.Data.Bits import Lens.Micro ( Lens', (&), (^.), (.~), (+~), (-~), lens, over ) -import Prelude ( (.), ($), Maybe(..), const, id, flip ) +import Prelude ( (.), ($), const, id, flip ) -- $setup @@ -805,14 +805,14 @@ any f = or . map f and :: Shape sh => Acc (Array (sh:.Int) Bool) -> Acc (Array sh Bool) -and = fold (&&) True_ +and = fold (&&) True -- | Check if any element along the innermost dimension is 'True'. -- or :: Shape sh => Acc (Array (sh:.Int) Bool) -> Acc (Array sh Bool) -or = fold (||) False_ +or = fold (||) False -- | Compute the sum of elements along the innermost dimension of the array. To -- find the sum of the entire array, 'flatten' it first. @@ -998,7 +998,7 @@ scanlSeg f z arr seg = seg' = map (+1) seg arr' = permute const (fill (sh :. sz + length seg) z) - (\(sx :. i) -> Just_ (sx :. i + fromIntegral (inc ! I1 i))) + (\(sx :. i) -> Just (sx :. i + fromIntegral (inc ! I1 i))) (take (length flags) arr) -- Each element in the segments must be shifted to the right one additional @@ -1091,7 +1091,7 @@ scanl'Seg f z arr seg = offset = scanl1 (+) seg inc = scanl1 (+) $ permute (+) (fill (I1 $ size arr + 1) 0) - (\ix -> Just_ (index1' (offset ! ix))) + (\ix -> Just (index1' (offset ! ix))) (fill (shape seg) (1 :: Exp i)) len = offset ! I1 (length offset - 1) @@ -1219,7 +1219,7 @@ scanrSeg f z arr seg = seg' = map (+1) seg arr' = permute const (fill (sh :. sz + length seg) z) - (\(sx :. i) -> Just_ (sx :. i + fromIntegral (inc !! i) - 1)) + (\(sx :. i) -> Just (sx :. i + fromIntegral (inc !! i) - 1)) (drop (sz - length flags) arr) @@ -1365,7 +1365,7 @@ mkHeadFlags -> Acc (Segments i) mkHeadFlags seg = init - $ permute (+) zeros (\ix -> Just_ (index1' (offset ! ix))) ones + $ permute (+) zeros (\ix -> Just (index1' (offset ! ix))) ones where T2 offset len = scanl' (+) 0 seg zeros = fill (index1' $ the len + 1) 0 @@ -1380,7 +1380,7 @@ mkTailFlags -> Acc (Segments i) mkTailFlags seg = init - $ permute (+) zeros (\ix -> Just_ (index1' (the len - 1 - offset ! ix))) ones + $ permute (+) zeros (\ix -> Just (index1' (the len - 1 - offset ! ix))) ones where T2 offset len = scanr' (+) 0 seg zeros = fill (index1' $ the len + 1) 0 @@ -1658,8 +1658,8 @@ compact keep arr = let T2 target len = scanl' (+) 0 (map boolToInt keep) prj ix = if keep!ix - then Just_ (I1 (target!ix)) - else Nothing_ + then Just (I1 (target!ix)) + else Nothing dummy = fill (I1 (the len)) undef result = permute const dummy prj arr in @@ -1676,8 +1676,8 @@ compact keep arr T2 target len = scanl' (+) 0 (map boolToInt keep) T2 offset valid = scanl' (+) 0 (flatten len) prj ix = if keep!ix - then Just_ (I1 (offset !! (toIndex sz (indexTail ix)) + target!ix)) - else Nothing_ + then Just (I1 (offset !! (toIndex sz (indexTail ix)) + target!ix)) + else Nothing dummy = fill (I1 (the valid)) undef result = permute const dummy prj arr in @@ -1758,7 +1758,7 @@ scatter -> Acc (Vector e) scatter to defaults input = permute const defaults pf input' where - pf ix = Just_ (I1 (to ! ix)) + pf ix = Just (I1 (to ! ix)) input' = backpermute (shape to `intersect` shape input) id input @@ -1787,8 +1787,8 @@ scatterIf to maskV pred defaults input = permute const defaults pf input' where input' = backpermute (shape to `intersect` shape input) id input pf ix = if pred (maskV ! ix) - then Just_ (I1 (to ! ix)) - else Nothing_ + then Just (I1 (to ! ix)) + else Nothing -- Permutations @@ -2224,9 +2224,9 @@ instance Arrays a => IfThenElse (Exp Bool) (Acc a) where -- argument. For example, given the function: -- -- > example1 :: Exp (Maybe Bool) -> Exp Int --- > example1 Nothing_ = 0 --- > example1 (Just_ False_) = 1 --- > example1 (Just_ True_) = 2 +-- > example1 Nothing = 0 +-- > example1 (Just False) = 1 +-- > example1 (Just True) = 2 -- -- In order to use this function it must be applied to the 'match' -- operator: @@ -2237,14 +2237,14 @@ instance Arrays a => IfThenElse (Exp Bool) (Acc a) where -- case statements inline. For example, instead of this: -- -- > example2 x = case f x of --- > Nothing_ -> ... -- error: embedded pattern synonym... --- > Just_ y -> ... -- ...used outside of 'match' context +-- > Nothing -> ... -- error: embedded pattern synonym... +-- > Just y -> ... -- ...used outside of 'match' context -- -- This can be written instead as: -- -- > example3 x = f x & match \case --- > Nothing_ -> ... --- > Just_ y -> ... +-- > Nothing -> ... +-- > Just y -> ... -- -- And utilising the @LambdaCase@ and @BlockArguments@ syntactic extensions. -- @@ -2261,8 +2261,8 @@ instance Arrays a => IfThenElse (Exp Bool) (Acc a) where -- -- > isNone :: Elt a => Exp (Option a) -> Exp Bool -- > isNone = match \case --- > None_ -> True_ --- > Some_{} -> False_ +-- > None_ -> True +-- > Some_{} -> False -- -- @since 1.3.0.0 -- @@ -2495,7 +2495,7 @@ length = unindex1 . shape -- new = -- let m = c2-c1 -- put i = let s = sieves ! i --- in s >= 0 && s < m ? (Just_ (I1 s), Nothing_) +-- in s >= 0 && s < m ? (Just (I1 s), Nothing) -- in -- afst -- $ filter (> 0) @@ -2530,7 +2530,7 @@ expand f g xs = else let n = m + 1 - put ix = Just_ (I1 (offset ! ix)) + put ix = Just (I1 (offset ! ix)) head_flags :: Acc (Vector Int) head_flags = permute const (fill (I1 n) 0) put (fill (shape szs) 1) @@ -2552,7 +2552,7 @@ expand f g xs = -- also the same, which is undefined behaviour (\ix -> if szs ! ix > 0 then put ix - else Nothing_) + else Nothing) $ enumFromN (shape xs) 0 in zipWith g (gather iotas xs) idxs diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue137.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue137.hs index ce0cf613f..0caa35b2b 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue137.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue137.hs @@ -24,6 +24,7 @@ import Data.Array.Accelerate.Test.NoFib.Base import Test.Tasty import Test.Tasty.HUnit +import Prelude hiding ( Maybe(..) ) test_issue137 :: RunN -> TestTree @@ -51,6 +52,6 @@ test1 = , T2 b1 (A.min b2 a1) )) infsA - (\ix -> Just_ (index1 (msA A.! ix))) + (\ix -> Just (index1 (msA A.! ix))) inpA diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue185.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue185.hs index f14f83520..53c834915 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue185.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue185.hs @@ -28,7 +28,7 @@ import Data.Array.Accelerate.Test.NoFib.Base import Test.Tasty import Test.Tasty.HUnit -import Prelude as P +import Prelude as P hiding ( Maybe(..) ) test_issue185 :: RunN -> TestTree @@ -144,6 +144,6 @@ scatterIf -> Acc (Vector e') scatterIf to maskV p def input = permute const def pf input' where - pf ix = p (maskV ! ix) ? ( Just_ (index1 (to ! ix)), Nothing_ ) + pf ix = p (maskV ! ix) ? ( Just (index1 (to ! ix)), Nothing ) input' = backpermute (shape to `intersect` shape input) id input diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue288.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue288.hs index 84c18a7d5..101d5bede 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue288.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue288.hs @@ -23,7 +23,7 @@ import Data.Array.Accelerate.Test.NoFib.Base import Test.Tasty import Test.Tasty.HUnit -import Prelude as P +import Prelude as P hiding ( Bool(..) ) test_issue288 :: RunN -> TestTree @@ -31,7 +31,7 @@ test_issue288 runN = testCase "288" $ xs @=? runN (A.map f) xs f :: Exp (Int, Int) -> Exp (Int, Int) -f e = while (const (lift False)) id e +f e = while (const False) id e xs :: Vector (Int, Int) xs = fromList (Z:.10) (P.zip [1..] [1..]) diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue407.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue407.hs index 1f42aa3f9..c69fb3c8c 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue407.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue407.hs @@ -24,7 +24,7 @@ module Data.Array.Accelerate.Test.NoFib.Issues.Issue407 ( ) where -import Prelude as P +import Prelude as P hiding ( Bool(..) ) import Data.Array.Accelerate as A import Data.Array.Accelerate.Sugar.Elt as S diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue436.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue436.hs index 57420f447..3663bf9bc 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue436.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue436.hs @@ -18,11 +18,12 @@ module Data.Array.Accelerate.Test.NoFib.Issues.Issue436 ( ) where -import Data.Array.Accelerate as A +import Data.Array.Accelerate as A import Data.Array.Accelerate.Test.NoFib.Base import Test.Tasty import Test.Tasty.HUnit +import Prelude hiding ( Bool(..) ) test_issue436 :: RunN -> TestTree diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue93.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue93.hs index 556b27ffb..a2bd5fdf8 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue93.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue93.hs @@ -22,6 +22,7 @@ import Data.Array.Accelerate.Test.NoFib.Base import Test.Tasty import Test.Tasty.HUnit +import Prelude hiding ( Maybe(..) ) test_issue93 :: RunN -> TestTree @@ -32,7 +33,7 @@ xs :: Array DIM2 Int xs = fromList (Z :. 1 :. 1) [5] test1 :: Acc (Array DIM2 Int) -test1 = permute (\c _ -> c) (fill (shape xs') (constant 0)) Just_ xs' +test1 = permute (\c _ -> c) (fill (shape xs') (constant 0)) Just xs' where xs' = use xs diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Permute.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Permute.hs index 91ebc7484..13131a118 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Permute.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Permute.hs @@ -39,7 +39,7 @@ import Test.Tasty import Test.Tasty.Hedgehog import System.IO.Unsafe -import Prelude as P +import Prelude as P hiding ( Bool(..), Maybe(..) ) import qualified Data.Set as Set diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/SIMD.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/SIMD.hs index 45a0de821..a7ea66761 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/SIMD.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/SIMD.hs @@ -39,8 +39,6 @@ import qualified Hedgehog.Gen as Gen import Test.Tasty import Test.Tasty.Hedgehog -import GHC.Exts as GHC - test_simd :: RunN -> TestTree test_simd runN = @@ -139,7 +137,7 @@ test_inject_v2 runN dim e = sh2 <- forAll dim xs <- forAll (array sh1 e) ys <- forAll (array sh2 e) - let !go = runN (A.zipWith A.V2) in go xs ys === zipWithRef (\x y -> GHC.fromList [x,y]) xs ys + let !go = runN (A.zipWith V2) in go xs ys === zipWithRef V2 xs ys test_inject_v3 :: (Shape sh, Show sh, Show e, Elt e, SIMD 3 e, P.Eq e, P.Eq sh) @@ -155,7 +153,7 @@ test_inject_v3 runN dim e = xs <- forAll (array sh1 e) ys <- forAll (array sh2 e) zs <- forAll (array sh3 e) - let !go = runN (A.zipWith3 A.V3) in go xs ys zs === zipWith3Ref (\x y z -> GHC.fromList [x,y,z]) xs ys zs + let !go = runN (A.zipWith3 V3) in go xs ys zs === zipWith3Ref V3 xs ys zs test_inject_v4 :: (Shape sh, Show sh, Show e, Elt e, SIMD 4 e, P.Eq e, P.Eq sh) @@ -173,35 +171,26 @@ test_inject_v4 runN dim e = ys <- forAll (array sh2 e) zs <- forAll (array sh3 e) ws <- forAll (array sh4 e) - let !go = runN (A.zipWith4 A.V4) in go xs ys zs ws === zipWith4Ref (\x y z w -> GHC.fromList [x,y,z,w]) xs ys zs ws + let !go = runN (A.zipWith4 V4) in go xs ys zs ws === zipWith4Ref V4 xs ys zs ws -unpackVec2 :: (Elt e, SIMD 2 e) => Vec2 e -> (e, e) -unpackVec2 v = - case GHC.toList v of - [a,b] -> (a, b) - _ -> undefined +unpackVec2 :: (Elt e, SIMD 2 e) => V2 e -> (e, e) +unpackVec2 (V2 a b) = (a, b) -unpackVec3 :: (Elt e, SIMD 3 e) => Vec3 e -> (e, e, e) -unpackVec3 v = - case GHC.toList v of - [a,b,c] -> (a, b, c) - _ -> undefined +unpackVec3 :: (Elt e, SIMD 3 e) => V3 e -> (e, e, e) +unpackVec3 (V3 a b c) = (a, b, c) -unpackVec4 :: (Elt e, SIMD 4 e) => Vec4 e -> (e, e, e, e) -unpackVec4 v = - case GHC.toList v of - [a,b,c,d] -> (a, b, c, d) - _ -> undefined +unpackVec4 :: (Elt e, SIMD 4 e) => V4 e -> (e, e, e, e) +unpackVec4 (V4 a b c d) = (a, b, c, d) -unpackVec2' :: (Elt e, SIMD 2 e) => Exp (Vec2 e) -> (Exp e, Exp e) -unpackVec2' (A.V2 a b) = (a, b) +unpackVec2' :: (Elt e, SIMD 2 e) => Exp (V2 e) -> (Exp e, Exp e) +unpackVec2' (V2 a b) = (a, b) -unpackVec3' :: (Elt e, SIMD 3 e) => Exp (Vec3 e) -> (Exp e, Exp e, Exp e) -unpackVec3' (A.V3 a b c) = (a, b, c) +unpackVec3' :: (Elt e, SIMD 3 e) => Exp (V3 e) -> (Exp e, Exp e, Exp e) +unpackVec3' (V3 a b c) = (a, b, c) -unpackVec4' :: (Elt e, SIMD 4 e) => Exp (Vec4 e) -> (Exp e, Exp e, Exp e, Exp e) -unpackVec4' (A.V4 a b c d) = (a, b, c, d) +unpackVec4' :: (Elt e, SIMD 4 e) => Exp (V4 e) -> (Exp e, Exp e, Exp e, Exp e) +unpackVec4' (V4 a b c d) = (a, b, c, d) -- Reference Implementation diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs index da012bc32..3c60fbdd6 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs @@ -25,7 +25,7 @@ module Data.Array.Accelerate.Test.NoFib.Prelude.Stencil ( ) where import Data.Typeable -import Prelude as P +import Prelude as P hiding ( Maybe(..), Either(..) ) import Data.Array.Accelerate as A import Data.Array.Accelerate.Sugar.Elt as S diff --git a/src/Data/Array/Accelerate/Test/NoFib/Spectral/RadixSort.hs b/src/Data/Array/Accelerate/Test/NoFib/Spectral/RadixSort.hs index f0e5ea996..2721e6820 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Spectral/RadixSort.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Spectral/RadixSort.hs @@ -24,7 +24,7 @@ module Data.Array.Accelerate.Test.NoFib.Spectral.RadixSort ( import Data.Function import Data.List ( sortBy ) -import Prelude as P +import Prelude as P hiding ( Maybe(..) ) import qualified Data.Bits as P import Data.Array.Accelerate as A @@ -176,7 +176,7 @@ radixsortBy rdx arr = foldr1 (>->) (P.map radixPass [0..p-1]) arr iup = A.map (size v - 1 -) . prescanr (+) 0 $ flags index = A.zipWith deal flags (A.zip idown iup) in - permute const v (\ix -> Just_ (index1 (index!ix))) v + permute const v (\ix -> Just (index1 (index!ix))) v -- This is rather slow. Speeding up the reference implementation by using, say, From 7d4b276ad33931101580a7bd751a16643e922ec9 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Tue, 28 Jun 2022 17:40:23 +0200 Subject: [PATCH 36/86] fix doctests --- src/Data/Array/Accelerate/Language.hs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/Data/Array/Accelerate/Language.hs b/src/Data/Array/Accelerate/Language.hs index 255a5e9c2..b0ff4f9b6 100644 --- a/src/Data/Array/Accelerate/Language.hs +++ b/src/Data/Array/Accelerate/Language.hs @@ -88,6 +88,7 @@ module Data.Array.Accelerate.Language ( ) where import Data.Array.Accelerate.AST ( PrimFun(..) ) +import Data.Array.Accelerate.Pattern.Maybe import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Representation.Array ( ArrayR(..) ) import Data.Array.Accelerate.Representation.Shape ( ShapeR(..) ) @@ -106,7 +107,7 @@ import Data.Array.Accelerate.Classes.Integral import Data.Array.Accelerate.Classes.Num import Data.Array.Accelerate.Classes.Ord -import Prelude ( ($), (.), Maybe(..), Char ) +import Prelude ( ($), (.), Char ) -- $setup @@ -700,7 +701,7 @@ scanr1 f (Acc a) = Acc $ SmartAcc $ Scan RightToLeft (eltR @a) (unExpBinaryFunct -- let zeros = fill (constant (Z:.10)) 0 -- ones = fill (shape xs) 1 -- in --- permute (+) zeros (\ix -> Just_ (I1 (xs!ix))) ones +-- permute (+) zeros (\ix -> Just (I1 (xs!ix))) ones -- :} -- -- >>> let xs = fromList (Z :. 20) [0,0,1,2,1,1,2,4,8,3,4,9,8,3,2,5,5,3,1,2] :: Vector Int @@ -717,7 +718,7 @@ scanr1 f (Acc a) = Acc $ SmartAcc $ Scan RightToLeft (eltR @a) (unExpBinaryFunct -- let zeros = fill (I2 n n) 0 -- ones = fill (I1 n) 1 -- in --- permute const zeros (\(I1 i) -> Just_ (I2 i i)) ones +-- permute const zeros (\(I1 i) -> Just (I2 i i)) ones -- :} -- -- >>> run $ identity 5 :: Matrix Int From d90ee5193c08021f38594cfd74b38a861b1a0245 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Tue, 28 Jun 2022 17:57:39 +0200 Subject: [PATCH 37/86] build fixes --- src/Data/Array/Accelerate/Pattern/Bool.hs | 3 ++- src/Data/Array/Accelerate/Pattern/Either.hs | 3 ++- src/Data/Array/Accelerate/Pattern/Maybe.hs | 3 ++- src/Data/Array/Accelerate/Pattern/Ordering.hs | 3 ++- src/Data/Array/Accelerate/Pattern/SIMD.hs | 5 +++-- 5 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/Bool.hs b/src/Data/Array/Accelerate/Pattern/Bool.hs index 941238e41..d3dfbb0b0 100644 --- a/src/Data/Array/Accelerate/Pattern/Bool.hs +++ b/src/Data/Array/Accelerate/Pattern/Bool.hs @@ -29,7 +29,8 @@ import qualified Prelude as P import GHC.Stack -{-# COMPLETE False, True #-} +{-# COMPLETE False, True :: Exp #-} +{-# COMPLETE False, True :: Bool #-} pattern False :: (HasCallStack, IsFalse r) => r pattern False <- (matchFalse -> Just ()) where False = buildFalse diff --git a/src/Data/Array/Accelerate/Pattern/Either.hs b/src/Data/Array/Accelerate/Pattern/Either.hs index 8a02ba708..e94ab74a9 100644 --- a/src/Data/Array/Accelerate/Pattern/Either.hs +++ b/src/Data/Array/Accelerate/Pattern/Either.hs @@ -45,7 +45,8 @@ runQ $ do _ -> find pat ds decs <- filter it <$> mkPattern ''Either - rest <- [d| {-# COMPLETE Left, Right #-} + rest <- [d| {-# COMPLETE Left, Right :: Exp #-} + {-# COMPLETE Left, Right :: Either #-} pattern Left :: IsLeft a b r => a -> r pattern Left x <- (matchLeft -> P.Just x) where Left = buildLeft diff --git a/src/Data/Array/Accelerate/Pattern/Maybe.hs b/src/Data/Array/Accelerate/Pattern/Maybe.hs index 88aa1ea67..f7a7395bb 100644 --- a/src/Data/Array/Accelerate/Pattern/Maybe.hs +++ b/src/Data/Array/Accelerate/Pattern/Maybe.hs @@ -50,7 +50,8 @@ runQ $ do _ -> find pat ds decs <- filter it <$> mkPattern ''Maybe - rest <- [d| {-# COMPLETE Nothing, Just #-} + rest <- [d| {-# COMPLETE Nothing, Just :: Exp #-} + {-# COMPLETE Nothing, Just :: Maybe #-} pattern Nothing :: (HasCallStack, IsNothing r) => r pattern Nothing <- (matchNothing -> P.Just ()) where Nothing = buildNothing diff --git a/src/Data/Array/Accelerate/Pattern/Ordering.hs b/src/Data/Array/Accelerate/Pattern/Ordering.hs index 896639e51..41441733a 100644 --- a/src/Data/Array/Accelerate/Pattern/Ordering.hs +++ b/src/Data/Array/Accelerate/Pattern/Ordering.hs @@ -43,7 +43,8 @@ runQ $ do _ -> find pat ds decs <- filter it <$> mkPattern ''Ordering - rest <- [d| {-# COMPLETE LT, EQ, GT #-} + rest <- [d| {-# COMPLETE LT, EQ, GT :: Exp #-} + {-# COMPLETE LT, EQ, GT :: Ordering #-} pattern LT :: IsLT a => a pattern LT <- (matchLT -> Just ()) where LT = buildLT diff --git a/src/Data/Array/Accelerate/Pattern/SIMD.hs b/src/Data/Array/Accelerate/Pattern/SIMD.hs index bc1a9d479..8eb6d6677 100644 --- a/src/Data/Array/Accelerate/Pattern/SIMD.hs +++ b/src/Data/Array/Accelerate/Pattern/SIMD.hs @@ -40,7 +40,7 @@ import GHC.Exts ( IsList(..) pattern SIMD :: forall b a context. IsSIMD context a b => b -> context a pattern SIMD vars <- (vmatcher @context -> vars) where SIMD = vbuilder @context -{-# COMPLETE SIMD #-} +{-# COMPLETE SIMD :: Exp #-} class IsSIMD context a b where vbuilder :: b -> context a @@ -144,7 +144,8 @@ runQ $ sequence [ patSynSigD name [t| $(conT isV) $(varT a) $(varT v) => $(foldr (\t r -> [t| $t -> $r |]) (varT v) as) |] , patSynD name (prefixPatSyn xs) (explBidir [clause [] (normalB (varE builder)) []]) (parensP $ viewP (varE matcher) (tupP xsP)) - , pragCompleteD [name] Nothing + , pragCompleteD [name] (Just ''Vec) + , pragCompleteD [name] (Just ''Exp) -- , classD (return []) isV [PlainTV a (), PlainTV v ()] [funDep [v] [a]] [ sigD builder (foldr (\t r -> [t| $t -> $r |]) (varT v) as) From 53019c83c7840597a376bc6fcbad0f3c35d65f50 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Thu, 30 Jun 2022 10:55:05 +0200 Subject: [PATCH 38/86] build fix --- src/Data/Array/Accelerate/Pattern/SIMD.hs | 2 +- src/Language/Haskell/TH/Extra.hs | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/Data/Array/Accelerate/Pattern/SIMD.hs b/src/Data/Array/Accelerate/Pattern/SIMD.hs index 8eb6d6677..aa608462d 100644 --- a/src/Data/Array/Accelerate/Pattern/SIMD.hs +++ b/src/Data/Array/Accelerate/Pattern/SIMD.hs @@ -147,7 +147,7 @@ runQ $ , pragCompleteD [name] (Just ''Vec) , pragCompleteD [name] (Just ''Exp) -- - , classD (return []) isV [PlainTV a (), PlainTV v ()] [funDep [v] [a]] + , classD (return []) isV [plainTV a, plainTV v] [funDep [v] [a]] [ sigD builder (foldr (\t r -> [t| $t -> $r |]) (varT v) as) , sigD matcher [t| $(varT v) -> $(tupT as) |] ] diff --git a/src/Language/Haskell/TH/Extra.hs b/src/Language/Haskell/TH/Extra.hs index 7688167fc..ab166f4d1 100644 --- a/src/Language/Haskell/TH/Extra.hs +++ b/src/Language/Haskell/TH/Extra.hs @@ -23,7 +23,7 @@ module Language.Haskell.TH.Extra ( #if MIN_VERSION_template_haskell(2,17,0) import Language.Haskell.TH hiding ( plainInvisTV, tupP, tupE ) #else -import Language.Haskell.TH hiding ( TyVarBndr, tupP, tupE ) +import Language.Haskell.TH hiding ( TyVarBndr, tupP, tupE, plainTV ) import Language.Haskell.TH.Syntax ( unTypeQ, unsafeTExpCoerce ) #if MIN_VERSION_template_haskell(2,16,0) import GHC.Exts ( RuntimeRep, TYPE ) @@ -71,6 +71,9 @@ tyVarBndrName :: TyVarBndr flag -> Name tyVarBndrName (PlainTV n) = n tyVarBndrName (KindedTV n _) = n +plainTV :: Name -> TyVarBndr () +plainTV = PlainTV + plainInvisTV :: Name -> Specificity -> TyVarBndr Specificity plainInvisTV n _ = PlainTV n From 1d16894e41afc12685d71385389d83ec1aad30d2 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Thu, 30 Jun 2022 10:55:16 +0200 Subject: [PATCH 39/86] warning police --- src/Data/Array/Accelerate/Data/Either.hs | 1 + src/Data/Array/Accelerate/Data/Maybe.hs | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/Data/Array/Accelerate/Data/Either.hs b/src/Data/Array/Accelerate/Data/Either.hs index 97b4c13f9..a2c4cf52e 100644 --- a/src/Data/Array/Accelerate/Data/Either.hs +++ b/src/Data/Array/Accelerate/Data/Either.hs @@ -127,6 +127,7 @@ instance (Eq a, Eq b) => Eq (Either a b) where instance (Ord a, Ord b) => Ord (Either a b) where compare = match go where + go :: Exp (Either a b) -> Exp (Either a b) -> Exp Ordering go (Left x) (Left y) = compare x y go (Right x) (Right y) = compare x y go Left{} Right{} = LT diff --git a/src/Data/Array/Accelerate/Data/Maybe.hs b/src/Data/Array/Accelerate/Data/Maybe.hs index 979ae0510..b42b5c0fc 100644 --- a/src/Data/Array/Accelerate/Data/Maybe.hs +++ b/src/Data/Array/Accelerate/Data/Maybe.hs @@ -125,6 +125,7 @@ instance Eq a => Eq (Maybe a) where instance Ord a => Ord (Maybe a) where compare = match go where + go :: Exp (Maybe a) -> Exp (Maybe a) -> Exp Ordering go (Just x) (Just y) = compare x y go Nothing Nothing = EQ go Nothing Just{} = LT @@ -136,6 +137,7 @@ instance (Monoid (Exp a), Elt a) => Monoid (Exp (Maybe a)) where instance (Semigroup (Exp a), Elt a) => Semigroup (Exp (Maybe a)) where (<>) = match go where + go :: Exp (Maybe a) -> Exp (Maybe a) -> Exp (Maybe a) go Nothing b = b go a Nothing = a go (Just a) (Just b) = Just (a <> b) From c1a32feeef08a63a0563f0c539c4dbfe313a1eeb Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Thu, 30 Jun 2022 14:29:35 +0200 Subject: [PATCH 40/86] warning police --- src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs index 3c60fbdd6..54ee71e7e 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs @@ -655,6 +655,7 @@ bound bnd sh0 ix0 = | otherwise = error "bound: expected shape with Int dimensions" + addDim :: Either e ds -> Either e d -> Either e (ds, d) Right ds `addDim` Right d = Right (ds, d) _ `addDim` Left e = Left e Left e `addDim` _ = Left e From f5482d8e28ae8569ec75e4814d0332257fefda9e Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Thu, 30 Jun 2022 14:29:41 +0200 Subject: [PATCH 41/86] build fix --- src/Data/Array/Accelerate/Test/NoFib/Issues/Issue437.hs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue437.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue437.hs index 956b6956e..a4d516deb 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue437.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue437.hs @@ -73,7 +73,10 @@ test_issue437 runN (alloc P.> 0 P.&& to P.== 4 P.&& from P.== 0) -- remote memory space where xs :: (Scalar Float, Matrix Float) - xs = runN $ T2 (unit 42) (fill (constant $ Z:.10000:.10000) 1) + xs = runN g + where + g :: Acc (Scalar Float, Matrix Float) + g = T2 (unit 42) (fill (Z :. 10000 :. 10000) 1) go :: Arrays a => a -> a go = runN f From 63a0bb72bbfc8bd4446f9aa416e7137f7a016434 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Wed, 31 Aug 2022 10:18:06 +0200 Subject: [PATCH 42/86] be a bit smarter --- src/Data/Array/Accelerate/Array/Data.hs | 32 ++++++++----------- .../Array/Accelerate/Representation/Elt.hs | 16 ++++------ src/Data/Primitive/Bit.hs | 26 ++++----------- 3 files changed, 26 insertions(+), 48 deletions(-) diff --git a/src/Data/Array/Accelerate/Array/Data.hs b/src/Data/Array/Accelerate/Array/Data.hs index 24227a75d..6ff993f68 100644 --- a/src/Data/Array/Accelerate/Array/Data.hs +++ b/src/Data/Array/Accelerate/Array/Data.hs @@ -143,16 +143,17 @@ newArrayData (TupRsingle _t) !size = scalar _t scalar (NumScalarType t) = num t scalar (BitScalarType t) = bit t + -- XXX: Arrays of BitMask are stored with each mask aligned to a byte + -- boundary, rather than being packed tightly together. This might be a bit + -- surprising if we want to cast between types? We don't support any + -- non-power-of-two sized integer types though so perhaps this can not come + -- up in practice, but it is a bit strange that a 'Vec 4 Bool' will require + -- trice as much memory as necessary. ---TLM 2022-08-30 + -- bit :: BitType t -> IO (MutableArrayData t) - bit TypeBit = let (q,r) = quotRem size 8 - in if r == 0 - then allocateArray q - else allocateArray (q+1) - bit (TypeMask n) = let k = fromInteger (natVal' n) - (q,r) = quotRem k 8 - in if r == 0 - then allocateArray (size * q) - else allocateArray (size * (q + 1)) + bit TypeBit = allocateArray ((size + 7) `quot` 8) + bit (TypeMask n) = let bytes = quot (fromInteger (natVal' n)+7) 8 + in allocateArray (size * bytes) num :: NumType t -> IO (MutableArrayData t) num (IntegralNumType t) = integral t @@ -567,17 +568,10 @@ liftArrayData n = tuple scalar (BitScalarType t) = bit t bit :: BitType e -> ArrayData e -> CodeQ (ArrayData e) - bit TypeBit ua = - let (q,r) = quotRem n 8 - in if r == 0 - then liftUniqueArray q ua - else liftUniqueArray (q+1) ua + bit TypeBit ua = liftUniqueArray ((n+7) `quot` 8) ua bit (TypeMask n') ua = - let k = fromInteger (natVal' n') - (q,r) = quotRem k 8 - in if r == 0 - then liftUniqueArray (n * q) ua - else liftUniqueArray (n * (q+1)) ua + let bytes = quot (fromInteger (natVal' n')+7) 8 + in liftUniqueArray (n * bytes) ua num :: NumType e -> ArrayData e -> CodeQ (ArrayData e) num (IntegralNumType t) = integral t diff --git a/src/Data/Array/Accelerate/Representation/Elt.hs b/src/Data/Array/Accelerate/Representation/Elt.hs index cdaed373f..b888074fa 100644 --- a/src/Data/Array/Accelerate/Representation/Elt.hs +++ b/src/Data/Array/Accelerate/Representation/Elt.hs @@ -45,13 +45,11 @@ undefElt = tuple bit :: BitType t -> t bit TypeBit = Bit False bit (TypeMask n) = - let (q, r) = quotRem (fromInteger (natVal' n)) 8 - bytes = if r == 0 then q else q + 1 - in - runST $ do - mba <- newByteArray bytes - ByteArray ba# <- unsafeFreezeByteArray mba - return $! Vec ba# + let bytes = quot (fromInteger (natVal' n) + 7) 8 + in runST $ do + mba <- newByteArray bytes + ByteArray ba# <- unsafeFreezeByteArray mba + return $! Vec ba# num :: NumType t -> t num (IntegralNumType t) = integral t @@ -114,9 +112,7 @@ bytesElt = tuple bit :: BitType t -> Int bit TypeBit = 1 -- stored as Word8 - bit (TypeMask n) = - let (q,r) = quotRem (fromInteger (natVal' n)) 8 - in if r == 0 then q else q+1 + bit (TypeMask n) = quot (fromInteger (natVal' n)+7) 8 num :: NumType t -> Int num (IntegralNumType t) = integral t diff --git a/src/Data/Primitive/Bit.hs b/src/Data/Primitive/Bit.hs index acfbc75d2..0d200cae4 100644 --- a/src/Data/Primitive/Bit.hs +++ b/src/Data/Primitive/Bit.hs @@ -120,11 +120,8 @@ instance KnownNat n => Foreign.Storable (BitMask n) where alignment _ = 1 sizeOf _ = - let k = fromIntegral (natVal' (proxy# :: Proxy# n)) - (q,r) = quotRem k 8 - in if r == 0 - then q - else q+1 + let k = fromIntegral (natVal' (proxy# :: Proxy# n)) + in quot (k + 7) 8 peek (Ptr addr#) = IO $ \s0 -> @@ -197,11 +194,8 @@ extract (BitMask (Vec ba#)) i@(I# i#) = insert :: forall n. KnownNat n => BitMask n -> Int -> Bit -> BitMask n insert (BitMask (Vec ba#)) i (Bit b) = runST $ do let n = fromInteger (natVal' (proxy# :: Proxy# n)) + bytes = quot (n+7) 8 (u,v) = quotRem i 8 - (q,r) = quotRem n 8 - bytes = if r == 0 - then q - else q + 1 -- mba <- newByteArray n copyByteArray mba 0 (ByteArray ba#) 0 bytes @@ -214,11 +208,8 @@ insert (BitMask (Vec ba#)) i (Bit b) = runST $ do {-# INLINE zeros #-} zeros :: forall n. KnownNat n => BitMask n zeros = - let n = fromInteger (natVal' (proxy# :: Proxy# n)) - (q,r) = quotRem n 8 - l = if r == 0 - then q - else q + 1 + let n = fromInteger (natVal' (proxy# :: Proxy# n)) + l = quot (n+7) 8 in case byteArrayFromListN l (replicate l (0 :: Word8)) of ByteArray ba# -> BitMask (Vec ba#) @@ -226,11 +217,8 @@ zeros = {-# INLINE ones #-} ones :: forall n. KnownNat n => BitMask n ones = - let n = fromInteger (natVal' (proxy# :: Proxy# n)) - (q,r) = quotRem n 8 - l = if r == 0 - then q - else q + 1 + let n = fromInteger (natVal' (proxy# :: Proxy# n)) + l = quot (n+7) 8 in case byteArrayFromListN l (replicate l (0xff :: Word8)) of ByteArray ba# -> BitMask (Vec ba#) From 13c835c4813db1d22014793b9b4e22907eb7dfcc Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Wed, 31 Aug 2022 10:45:42 +0200 Subject: [PATCH 43/86] add Integral instances for vector types --- src/Data/Array/Accelerate/Classes/Integral.hs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/Data/Array/Accelerate/Classes/Integral.hs b/src/Data/Array/Accelerate/Classes/Integral.hs index 110bd721d..fd2208c31 100644 --- a/src/Data/Array/Accelerate/Classes/Integral.hs +++ b/src/Data/Array/Accelerate/Classes/Integral.hs @@ -28,6 +28,7 @@ module Data.Array.Accelerate.Classes.Integral ( ) where import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type import Data.Array.Accelerate.Classes.Enum @@ -73,6 +74,15 @@ runQ $ quotRem = mkQuotRem divMod = mkDivMod toInteger = P.error "Prelude.toInteger not supported for Accelerate types" + + instance KnownNat n => P.Integral (Exp (Vec n $(conT a))) where + quot = mkQuot + rem = mkRem + div = mkIDiv + mod = mkMod + quotRem = mkQuotRem + divMod = mkDivMod + toInteger = P.error "Prelude.toInteger not supported for Accelerate types" |] in concat <$> mapM mkIntegral integralTypes From 8ea3e51947ba32c046a705d2b61cfc8066e29300 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Wed, 31 Aug 2022 10:46:20 +0200 Subject: [PATCH 44/86] copy-pasta error --- src/Data/Array/Accelerate/Classes/RealFrac.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Data/Array/Accelerate/Classes/RealFrac.hs b/src/Data/Array/Accelerate/Classes/RealFrac.hs index d13e3c570..07b90c63e 100644 --- a/src/Data/Array/Accelerate/Classes/RealFrac.hs +++ b/src/Data/Array/Accelerate/Classes/RealFrac.hs @@ -154,7 +154,7 @@ defaultFloor :: forall a b. (RealFrac a, Integral b, FromIntegral (Significand a defaultFloor x | Just FloatingDict <- floatingDict @a , Just IntegralDict <- integralDict @b - = mkCeiling x + = mkFloor x -- | otherwise = let T2 n r = properFraction x in cond (r < 0) (n-1) n @@ -163,7 +163,7 @@ defaultRound :: forall a b. (RealFrac a, Integral b, FromIntegral (Significand a defaultRound x | Just FloatingDict <- floatingDict @a , Just IntegralDict <- integralDict @b - = mkCeiling x + = mkRound x -- | otherwise = let T2 n r = properFraction x From 0b515628c0068b5c36419cdde4ebca9dddb08134 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Wed, 31 Aug 2022 10:47:06 +0200 Subject: [PATCH 45/86] drop unused file --- src/Data/Array/Accelerate/Classes/Vector.hs | 35 --------------------- 1 file changed, 35 deletions(-) delete mode 100644 src/Data/Array/Accelerate/Classes/Vector.hs diff --git a/src/Data/Array/Accelerate/Classes/Vector.hs b/src/Data/Array/Accelerate/Classes/Vector.hs deleted file mode 100644 index e624261bb..000000000 --- a/src/Data/Array/Accelerate/Classes/Vector.hs +++ /dev/null @@ -1,35 +0,0 @@ -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE AllowAmbiguousTypes #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE MonoLocalBinds #-} -{-# LANGUAGE FunctionalDependencies #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE GADTs #-} -{-# OPTIONS_GHC -fno-warn-orphans #-} --- | --- Module : Data.Array.Accelerate.Classes.Vector --- Copyright : [2016..2020] The Accelerate Team --- License : BSD3 --- --- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> --- Stability : experimental --- Portability : non-portable (GHC extensions) --- - -module Data.Array.Accelerate.Classes.Vector - where - -import GHC.TypeLits -import Data.Array.Accelerate.Sugar.Vec -import Data.Array.Accelerate.Smart -import Data.Primitive.Vec - -instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) where - type IndexType (Exp (Vec n a)) = Exp Int - vecIndex = mkVectorIndex - vecWrite = mkVectorWrite - vecEmpty = undef - From fdac92ffb139be162032317dd82886ed7dcc1d55 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Wed, 31 Aug 2022 11:39:15 +0200 Subject: [PATCH 46/86] vectorise type of smart constructors for logical operations --- src/Data/Array/Accelerate/AST.hs | 2 +- src/Data/Array/Accelerate/Classes/VEq.hs | 4 +- src/Data/Array/Accelerate/Classes/VOrd.hs | 9 +- src/Data/Array/Accelerate/Smart.hs | 138 +++++++++++----------- 4 files changed, 76 insertions(+), 77 deletions(-) diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index d3c765e75..00c20098f 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -90,7 +90,7 @@ module Data.Array.Accelerate.AST ( Exp, OpenExp(..), Boundary(..), PrimFun(..), - PrimBool, + PrimBool, PrimMask, PrimMaybe, BitOrMask, diff --git a/src/Data/Array/Accelerate/Classes/VEq.hs b/src/Data/Array/Accelerate/Classes/VEq.hs index a24ebf433..8ec9a347d 100644 --- a/src/Data/Array/Accelerate/Classes/VEq.hs +++ b/src/Data/Array/Accelerate/Classes/VEq.hs @@ -121,8 +121,8 @@ runQ $ do mkPrim :: Name -> Q [Dec] mkPrim name = [d| instance KnownNat n => VEq n $(conT name) where - (==*) = mkPrimBinary $ PrimEq scalarType - (/=*) = mkPrimBinary $ PrimNEq scalarType + (==*) = mkEq + (/=*) = mkNEq |] mkTup :: Word8 -> Q Dec diff --git a/src/Data/Array/Accelerate/Classes/VOrd.hs b/src/Data/Array/Accelerate/Classes/VOrd.hs index b83af2c0b..82e608a53 100644 --- a/src/Data/Array/Accelerate/Classes/VOrd.hs +++ b/src/Data/Array/Accelerate/Classes/VOrd.hs @@ -19,7 +19,6 @@ module Data.Array.Accelerate.Classes.VOrd ( ) where -import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.Classes.Ord import Data.Array.Accelerate.Classes.VEq import Data.Array.Accelerate.Representation.Tag @@ -114,10 +113,10 @@ runQ $ do mkPrim :: Name -> Q [Dec] mkPrim name = [d| instance KnownNat n => VOrd n $(conT name) where - (<*) = mkPrimBinary $ PrimLt scalarType - (>*) = mkPrimBinary $ PrimGt scalarType - (<=*) = mkPrimBinary $ PrimLtEq scalarType - (>=*) = mkPrimBinary $ PrimGtEq scalarType + (<*) = mkLt + (>*) = mkGt + (<=*) = mkLtEq + (>=*) = mkGtEq vmin = mkMin vmax = mkMax |] diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 54419f438..98acbf225 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -1225,98 +1225,98 @@ select (Exp mask) (Exp tt) (Exp ff) = Exp $ go (vecR @n @a) tt ff -- Operators from Floating -mkSin :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t +mkSin :: IsFloating (EltR t) => Exp t -> Exp t mkSin = mkPrimUnary $ PrimSin floatingType -mkCos :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t +mkCos :: IsFloating (EltR t) => Exp t -> Exp t mkCos = mkPrimUnary $ PrimCos floatingType -mkTan :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t +mkTan :: IsFloating (EltR t) => Exp t -> Exp t mkTan = mkPrimUnary $ PrimTan floatingType -mkAsin :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t +mkAsin :: IsFloating (EltR t) => Exp t -> Exp t mkAsin = mkPrimUnary $ PrimAsin floatingType -mkAcos :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t +mkAcos :: IsFloating (EltR t) => Exp t -> Exp t mkAcos = mkPrimUnary $ PrimAcos floatingType -mkAtan :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t +mkAtan :: IsFloating (EltR t) => Exp t -> Exp t mkAtan = mkPrimUnary $ PrimAtan floatingType -mkSinh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t +mkSinh :: IsFloating (EltR t) => Exp t -> Exp t mkSinh = mkPrimUnary $ PrimSinh floatingType -mkCosh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t +mkCosh :: IsFloating (EltR t) => Exp t -> Exp t mkCosh = mkPrimUnary $ PrimCosh floatingType -mkTanh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t +mkTanh :: IsFloating (EltR t) => Exp t -> Exp t mkTanh = mkPrimUnary $ PrimTanh floatingType -mkAsinh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t +mkAsinh :: IsFloating (EltR t) => Exp t -> Exp t mkAsinh = mkPrimUnary $ PrimAsinh floatingType -mkAcosh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t +mkAcosh :: IsFloating (EltR t) => Exp t -> Exp t mkAcosh = mkPrimUnary $ PrimAcosh floatingType -mkAtanh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t +mkAtanh :: IsFloating (EltR t) => Exp t -> Exp t mkAtanh = mkPrimUnary $ PrimAtanh floatingType -mkExpFloating :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t +mkExpFloating :: IsFloating (EltR t) => Exp t -> Exp t mkExpFloating = mkPrimUnary $ PrimExpFloating floatingType -mkSqrt :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t +mkSqrt :: IsFloating (EltR t) => Exp t -> Exp t mkSqrt = mkPrimUnary $ PrimSqrt floatingType -mkLog :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t +mkLog :: IsFloating (EltR t) => Exp t -> Exp t mkLog = mkPrimUnary $ PrimLog floatingType -mkFPow :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -> Exp t +mkFPow :: IsFloating (EltR t) => Exp t -> Exp t -> Exp t mkFPow = mkPrimBinary $ PrimFPow floatingType -mkLogBase :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -> Exp t +mkLogBase :: IsFloating (EltR t) => Exp t -> Exp t -> Exp t mkLogBase = mkPrimBinary $ PrimLogBase floatingType -- Operators from Num -mkAdd :: (Elt t, IsNum (EltR t)) => Exp t -> Exp t -> Exp t +mkAdd :: IsNum (EltR t) => Exp t -> Exp t -> Exp t mkAdd = mkPrimBinary $ PrimAdd numType -mkSub :: (Elt t, IsNum (EltR t)) => Exp t -> Exp t -> Exp t +mkSub :: IsNum (EltR t) => Exp t -> Exp t -> Exp t mkSub = mkPrimBinary $ PrimSub numType -mkMul :: (Elt t, IsNum (EltR t)) => Exp t -> Exp t -> Exp t +mkMul :: IsNum (EltR t) => Exp t -> Exp t -> Exp t mkMul = mkPrimBinary $ PrimMul numType -mkNeg :: (Elt t, IsNum (EltR t)) => Exp t -> Exp t +mkNeg :: IsNum (EltR t) => Exp t -> Exp t mkNeg = mkPrimUnary $ PrimNeg numType -mkAbs :: (Elt t, IsNum (EltR t)) => Exp t -> Exp t +mkAbs :: IsNum (EltR t) => Exp t -> Exp t mkAbs = mkPrimUnary $ PrimAbs numType -mkSig :: (Elt t, IsNum (EltR t)) => Exp t -> Exp t +mkSig :: IsNum (EltR t) => Exp t -> Exp t mkSig = mkPrimUnary $ PrimSig numType -- Operators from Integral -mkQuot :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t +mkQuot :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t mkQuot = mkPrimBinary $ PrimQuot integralType -mkRem :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t +mkRem :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t mkRem = mkPrimBinary $ PrimRem integralType -mkQuotRem :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> (Exp t, Exp t) +mkQuotRem :: IsIntegral (EltR t) => Exp t -> Exp t -> (Exp t, Exp t) mkQuotRem (Exp x) (Exp y) = let pair = SmartExp $ PrimQuotRem integralType `PrimApp` SmartExp (Pair x y) in ( mkExp $ Prj PairIdxLeft pair , mkExp $ Prj PairIdxRight pair) -mkIDiv :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t +mkIDiv :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t mkIDiv = mkPrimBinary $ PrimIDiv integralType -mkMod :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t +mkMod :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t mkMod = mkPrimBinary $ PrimMod integralType -mkDivMod :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> (Exp t, Exp t) +mkDivMod :: IsIntegral (EltR t) => Exp t -> Exp t -> (Exp t, Exp t) mkDivMod (Exp x) (Exp y) = let pair = SmartExp $ PrimDivMod integralType `PrimApp` SmartExp (Pair x y) in ( mkExp $ Prj PairIdxLeft pair @@ -1324,128 +1324,128 @@ mkDivMod (Exp x) (Exp y) = -- Operators from Bits and FiniteBits -mkBAnd :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t +mkBAnd :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t mkBAnd = mkPrimBinary $ PrimBAnd integralType -mkBOr :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t +mkBOr :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t mkBOr = mkPrimBinary $ PrimBOr integralType -mkBXor :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t +mkBXor :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t mkBXor = mkPrimBinary $ PrimBXor integralType -mkBNot :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t +mkBNot :: IsIntegral (EltR t) => Exp t -> Exp t mkBNot = mkPrimUnary $ PrimBNot integralType -mkBShiftL :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t +mkBShiftL :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t mkBShiftL = mkPrimBinary $ PrimBShiftL integralType -mkBShiftR :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t +mkBShiftR :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t mkBShiftR = mkPrimBinary $ PrimBShiftR integralType -mkBRotateL :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t +mkBRotateL :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t mkBRotateL = mkPrimBinary $ PrimBRotateL integralType -mkBRotateR :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t +mkBRotateR :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t mkBRotateR = mkPrimBinary $ PrimBRotateR integralType -mkPopCount :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t +mkPopCount :: IsIntegral (EltR t) => Exp t -> Exp t mkPopCount = mkPrimUnary $ PrimPopCount integralType -mkCountLeadingZeros :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t +mkCountLeadingZeros :: IsIntegral (EltR t) => Exp t -> Exp t mkCountLeadingZeros = mkPrimUnary $ PrimCountLeadingZeros integralType -mkCountTrailingZeros :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t +mkCountTrailingZeros :: IsIntegral (EltR t) => Exp t -> Exp t mkCountTrailingZeros = mkPrimUnary $ PrimCountTrailingZeros integralType -- Operators from Fractional -mkFDiv :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -> Exp t +mkFDiv :: IsFloating (EltR t) => Exp t -> Exp t -> Exp t mkFDiv = mkPrimBinary $ PrimFDiv floatingType -mkRecip :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t +mkRecip :: IsFloating (EltR t) => Exp t -> Exp t mkRecip = mkPrimUnary $ PrimRecip floatingType -- Operators from RealFrac -mkTruncate :: (Elt a, Elt b, IsFloating (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b +mkTruncate :: (IsFloating (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b mkTruncate = mkPrimUnary $ PrimTruncate floatingType integralType -mkRound :: (Elt a, Elt b, IsFloating (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b +mkRound :: (IsFloating (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b mkRound = mkPrimUnary $ PrimRound floatingType integralType -mkFloor :: (Elt a, Elt b, IsFloating (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b +mkFloor :: (IsFloating (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b mkFloor = mkPrimUnary $ PrimFloor floatingType integralType -mkCeiling :: (Elt a, Elt b, IsFloating (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b +mkCeiling :: (IsFloating (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b mkCeiling = mkPrimUnary $ PrimCeiling floatingType integralType -- Operators from RealFloat -mkAtan2 :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -> Exp t +mkAtan2 :: IsFloating (EltR t) => Exp t -> Exp t -> Exp t mkAtan2 = mkPrimBinary $ PrimAtan2 floatingType -mkIsNaN :: (Elt t, IsFloating (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp Bool +mkIsNaN :: (IsFloating (EltR t), BitOrMask (EltR t) ~ EltR b) => Exp t -> Exp b mkIsNaN = mkPrimUnary $ PrimIsNaN floatingType -mkIsInfinite :: (Elt t, IsFloating (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp Bool +mkIsInfinite :: (IsFloating (EltR t), BitOrMask (EltR t) ~ EltR b) => Exp t -> Exp b mkIsInfinite = mkPrimUnary $ PrimIsInfinite floatingType -- Relational and equality operators -mkLt :: (Elt t, IsScalar (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp t -> Exp Bool +mkLt :: (IsScalar (EltR t), BitOrMask (EltR t) ~ EltR b) => Exp t -> Exp t -> Exp b mkLt = mkPrimBinary $ PrimLt scalarType -mkGt :: (Elt t, IsScalar (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp t -> Exp Bool +mkGt :: (IsScalar (EltR t), BitOrMask (EltR t) ~ EltR b) => Exp t -> Exp t -> Exp b mkGt = mkPrimBinary $ PrimGt scalarType -mkLtEq :: (Elt t, IsScalar (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp t -> Exp Bool +mkLtEq :: (IsScalar (EltR t), BitOrMask (EltR t) ~ EltR b) => Exp t -> Exp t -> Exp b mkLtEq = mkPrimBinary $ PrimLtEq scalarType -mkGtEq :: (Elt t, IsScalar (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp t -> Exp Bool +mkGtEq :: (IsScalar (EltR t), BitOrMask (EltR t) ~ EltR b) => Exp t -> Exp t -> Exp b mkGtEq = mkPrimBinary $ PrimGtEq scalarType -mkEq :: (Elt t, IsScalar (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp t -> Exp Bool +mkEq :: (IsScalar (EltR t), BitOrMask (EltR t) ~ EltR b) => Exp t -> Exp t -> Exp b mkEq = mkPrimBinary $ PrimEq scalarType -mkNEq :: (Elt t, IsScalar (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp t -> Exp Bool +mkNEq :: (IsScalar (EltR t), BitOrMask (EltR t) ~ EltR b) => Exp t -> Exp t -> Exp b mkNEq = mkPrimBinary $ PrimNEq scalarType -mkMax :: (Elt t, IsScalar (EltR t)) => Exp t -> Exp t -> Exp t +mkMax :: IsScalar (EltR t) => Exp t -> Exp t -> Exp t mkMax = mkPrimBinary $ PrimMax scalarType -mkMin :: (Elt t, IsScalar (EltR t)) => Exp t -> Exp t -> Exp t +mkMin :: IsScalar (EltR t) => Exp t -> Exp t -> Exp t mkMin = mkPrimBinary $ PrimMin scalarType -- Logical operators -mkLAnd :: (Elt t, IsBit (EltR t)) => Exp t -> Exp t -> Exp t +mkLAnd :: IsBit (EltR t) => Exp t -> Exp t -> Exp t mkLAnd = mkPrimBinary $ PrimLAnd bitType -mkLOr :: (Elt t, IsBit (EltR t)) => Exp t -> Exp t -> Exp t +mkLOr :: IsBit (EltR t) => Exp t -> Exp t -> Exp t mkLOr = mkPrimBinary $ PrimLOr bitType -mkLNot :: (Elt t, IsBit (EltR t)) => Exp t -> Exp t +mkLNot :: IsBit (EltR t) => Exp t -> Exp t mkLNot = mkPrimUnary $ PrimLNot bitType -- Numeric conversions -mkFromIntegral :: (Elt a, Elt b, IsIntegral (EltR a), IsNum (EltR b)) => Exp a -> Exp b +mkFromIntegral :: (IsIntegral (EltR a), IsNum (EltR b)) => Exp a -> Exp b mkFromIntegral = mkPrimUnary $ PrimFromIntegral integralType numType -mkToFloating :: (Elt a, Elt b, IsNum (EltR a), IsFloating (EltR b)) => Exp a -> Exp b +mkToFloating :: (IsNum (EltR a), IsFloating (EltR b)) => Exp a -> Exp b mkToFloating = mkPrimUnary $ PrimToFloating numType floatingType -mkToBool :: (Elt a, IsSingleIntegral (EltR a), BitOrMask (EltR a) ~ Bit) => Exp a -> Exp Bool +mkToBool :: (IsSingleIntegral (EltR a), BitOrMask (EltR a) ~ Bit) => Exp a -> Exp Bool mkToBool = mkPrimUnary $ PrimToBool (SingleIntegralType singleIntegralType) bitType -mkFromBool :: (Elt a, IsSingleIntegral (EltR a), BitOrMask (EltR a) ~ Bit) => Exp Bool -> Exp a +mkFromBool :: (IsSingleIntegral (EltR a), BitOrMask (EltR a) ~ Bit) => Exp Bool -> Exp a mkFromBool = mkPrimUnary $ PrimFromBool bitType (SingleIntegralType singleIntegralType) -- Other conversions -- NOTE: Restricted to scalar types with a type-level BitSizeEq constraint to -- make this version "safe" -mkBitcast :: forall b a. (Elt a, Elt b, IsScalar (EltR a), IsScalar (EltR b), BitSizeEq (EltR a) (EltR b)) => Exp a -> Exp b +mkBitcast :: forall b a. (IsScalar (EltR a), IsScalar (EltR b), BitSizeEq (EltR a) (EltR b)) => Exp a -> Exp b mkBitcast (Exp a) = mkExp $ Coerce (scalarType @(EltR a)) (scalarType @(EltR b)) a mkCoerce :: Coerce (EltR a) (EltR b) => Exp a -> Exp b @@ -1507,16 +1507,16 @@ mkExp = Exp . SmartExp unExp :: Exp e -> SmartExp (EltR e) unExp (Exp e) = e -unExpFunction :: (Elt a, Elt b) => (Exp a -> Exp b) -> SmartExp (EltR a) -> SmartExp (EltR b) +unExpFunction :: (Exp a -> Exp b) -> SmartExp (EltR a) -> SmartExp (EltR b) unExpFunction f = unExp . f . Exp -unExpBinaryFunction :: (Elt a, Elt b, Elt c) => (Exp a -> Exp b -> Exp c) -> SmartExp (EltR a) -> SmartExp (EltR b) -> SmartExp (EltR c) +unExpBinaryFunction :: (Exp a -> Exp b -> Exp c) -> SmartExp (EltR a) -> SmartExp (EltR b) -> SmartExp (EltR c) unExpBinaryFunction f a b = unExp $ f (Exp a) (Exp b) -mkPrimUnary :: (Elt a, Elt b) => PrimFun (EltR a -> EltR b) -> Exp a -> Exp b +mkPrimUnary :: PrimFun (EltR a -> EltR b) -> Exp a -> Exp b mkPrimUnary prim (Exp a) = mkExp $ PrimApp prim a -mkPrimBinary :: (Elt a, Elt b, Elt c) => PrimFun ((EltR a, EltR b) -> EltR c) -> Exp a -> Exp b -> Exp c +mkPrimBinary :: PrimFun ((EltR a, EltR b) -> EltR c) -> Exp a -> Exp b -> Exp c mkPrimBinary prim (Exp a) (Exp b) = mkExp $ PrimApp prim (SmartExp $ Pair a b) unPair :: SmartExp (a, b) -> (SmartExp a, SmartExp b) From 31b8fbab73f6993d25e287cc031461509c436eb7 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Wed, 31 Aug 2022 11:43:42 +0200 Subject: [PATCH 47/86] NOTE --- src/Data/Array/Accelerate/Array/Data.hs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/Data/Array/Accelerate/Array/Data.hs b/src/Data/Array/Accelerate/Array/Data.hs index 6ff993f68..0956c7926 100644 --- a/src/Data/Array/Accelerate/Array/Data.hs +++ b/src/Data/Array/Accelerate/Array/Data.hs @@ -150,6 +150,10 @@ newArrayData (TupRsingle _t) !size = scalar _t -- up in practice, but it is a bit strange that a 'Vec 4 Bool' will require -- trice as much memory as necessary. ---TLM 2022-08-30 -- + -- XXX: Actually this is a problem as both BitSize(Bit) and BitSize(Vec 1 + -- Bit) are a single bit, so we should be able to coerce between them, but + -- they will be stored differently. ---TLM 2022-08-31 + -- bit :: BitType t -> IO (MutableArrayData t) bit TypeBit = allocateArray ((size + 7) `quot` 8) bit (TypeMask n) = let bytes = quot (fromInteger (natVal' n)+7) 8 From 1d1eacfc37eb2efcd064ae7fb0033ed57f66004d Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Tue, 20 Sep 2022 18:41:10 +0200 Subject: [PATCH 48/86] vectorised RealFloat --- src/Data/Array/Accelerate/AST.hs | 9 +- .../Array/Accelerate/Classes/RealFloat.hs | 501 +++++++++++++----- src/Data/Array/Accelerate/Smart.hs | 9 +- src/Data/Array/Accelerate/Type.hs | 6 +- 4 files changed, 368 insertions(+), 157 deletions(-) diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 00c20098f..4e48ea1ba 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -647,9 +647,12 @@ data OpenExp env aenv t where Undef :: ScalarType t -> OpenExp env aenv t - -- Reinterpret the bits of a value as a different type - Coerce :: BitSizeEq a b - => ScalarType a + -- Reinterpret the bits of a value as a different type. + -- + -- The types must have the same bit size, but that constraint is not include + -- at this point because GHC's typelits solver is often not powerful enough to + -- discharge that constraint. ---TLM 2022-09-20 + Coerce :: ScalarType a -> ScalarType b -> OpenExp env aenv a -> OpenExp env aenv b diff --git a/src/Data/Array/Accelerate/Classes/RealFloat.hs b/src/Data/Array/Accelerate/Classes/RealFloat.hs index 8312f20af..22157a185 100644 --- a/src/Data/Array/Accelerate/Classes/RealFloat.hs +++ b/src/Data/Array/Accelerate/Classes/RealFloat.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} @@ -27,10 +27,12 @@ module Data.Array.Accelerate.Classes.RealFloat ( ) where -import Data.Array.Accelerate.Error +import Data.Array.Accelerate.AST ( BitOrMask, PrimMask ) import Data.Array.Accelerate.Language ( (^), cond, while ) import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Elt ( Elt(..) ) +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type import Data.Array.Accelerate.Data.Bits @@ -42,140 +44,240 @@ import Data.Array.Accelerate.Classes.Integral import Data.Array.Accelerate.Classes.Num import Data.Array.Accelerate.Classes.Ord import Data.Array.Accelerate.Classes.RealFrac +import Data.Array.Accelerate.Classes.VEq +import Data.Array.Accelerate.Classes.VOrd -import Data.Text.Lazy.Builder -import Formatting -import Prelude ( (.), ($), String, error, undefined, unlines, otherwise ) +import Data.Coerce +import Data.Kind import Text.Printf +import Prelude ( (.), ($), String, error, undefined, unlines ) import qualified Prelude as P -- | Efficient, machine-independent access to the components of a floating-point -- number -- -class (RealFrac a, Floating a) => RealFloat a where +class (RealFrac a, Floating a, Integral (Exponent a)) => RealFloat a where + type Exponent a :: Type + -- | The radix of the representation (often 2) (constant) - floatRadix :: Exp a -> Exp Int -- Integer - default floatRadix :: P.RealFloat a => Exp a -> Exp Int - floatRadix _ = P.fromInteger (P.floatRadix (undefined::a)) + floatRadix :: Exp a -> Exp Int -- | The number of digits of 'floatRadix' in the significand (constant) - floatDigits :: Exp a -> Exp Int - default floatDigits :: P.RealFloat a => Exp a -> Exp Int - floatDigits _ = constant (P.floatDigits (undefined::a)) + floatDigits :: Exp a -> Exp Int -- | The lowest and highest values the exponent may assume (constant) - floatRange :: Exp a -> Exp (Int, Int) - default floatRange :: P.RealFloat a => Exp a -> Exp (Int, Int) - floatRange _ = constant $ P.floatRange (undefined::a) + floatRange :: Exp a -> Exp (Int, Int) -- | Return the significand and an appropriately scaled exponent. If -- @(m,n) = 'decodeFloat' x@ then @x = m*b^^n@, where @b@ is the -- floating-point radix ('floatRadix'). Furthermore, either @m@ and @n@ are -- both zero, or @b^(d-1) <= 'abs' m < b^d@, where @d = 'floatDigits' x@. - decodeFloat :: Exp a -> Exp (Significand a, Int) + decodeFloat :: Exp a -> Exp (Significand a, Exponent a) -- | Inverse of 'decodeFloat' - encodeFloat :: Exp (Significand a) -> Exp Int -> Exp a - default encodeFloat :: (FromIntegral Int a, FromIntegral (Significand a) a) => Exp (Significand a) -> Exp Int -> Exp a - encodeFloat x e = fromIntegral x * (fromIntegral (floatRadix (undefined :: Exp a)) ** fromIntegral e) + encodeFloat :: Exp (Significand a) -> Exp (Exponent a) -> Exp a -- | Corresponds to the second component of 'decodeFloat' - exponent :: Exp a -> Exp Int - exponent x = let T2 m n = decodeFloat x - in cond (m == 0) 0 (n + floatDigits x) + exponent :: Exp a -> Exp (Exponent a) - -- | Corresponds to the first component of 'decodeFloat' - significand :: Exp a -> Exp a - significand x = let T2 m _ = decodeFloat x - in encodeFloat m (negate (floatDigits x)) + -- | The first component of 'decodeFloat', scaled to lie in the open interval (-1,1). + significand :: Exp a -> Exp a -- | Multiply a floating point number by an integer power of the radix - scaleFloat :: Exp Int -> Exp a -> Exp a - scaleFloat k x = cond (k == 0 || isFix) x (encodeFloat m (n + clamp b)) - where - isFix = x == 0 || isNaN x || isInfinite x - T2 m n = decodeFloat x - T2 l h = floatRange x - d = floatDigits x - b = h - l + 4*d - -- n+k may overflow, which would lead to incorrect results, hence we clamp - -- the scaling parameter. If (n+k) would be larger than h, (n + clamp b k) - -- must be too, similar for smaller than (l-d). - clamp bd = max (-bd) (min bd k) + scaleFloat :: Exp Int -> Exp a -> Exp a -- | 'True' if the argument is an IEEE \"not-a-number\" (NaN) value - isNaN :: Exp a -> Exp Bool + isNaN :: (BitOrMask (EltR a) ~ EltR b) => Exp a -> Exp b -- | 'True' if the argument is an IEEE infinity or negative-infinity - isInfinite :: Exp a -> Exp Bool + isInfinite :: (BitOrMask (EltR a) ~ EltR b) => Exp a -> Exp b - -- | 'True' if the argument is too small to be represented in normalized - -- format - isDenormalized :: Exp a -> Exp Bool + -- | 'True' if the argument is too small to be represented in normalized format + isDenormalized :: (BitOrMask (EltR a) ~ EltR b) => Exp a -> Exp b -- | 'True' if the argument is an IEEE negative zero - isNegativeZero :: Exp a -> Exp Bool + isNegativeZero :: (BitOrMask (EltR a) ~ EltR b) => Exp a -> Exp b -- | 'True' if the argument is an IEEE floating point number - isIEEE :: Exp a -> Exp Bool - default isIEEE :: P.RealFloat a => Exp a -> Exp Bool - isIEEE _ = constant (P.isIEEE (undefined::a)) + isIEEE :: Exp a -> Exp Bool -- | A version of arctangent taking two real floating-point arguments. -- For real floating @x@ and @y@, @'atan2' y x@ computes the angle (from the -- positive x-axis) of the vector from the origin to the point @(x,y)@. -- @'atan2' y x@ returns a value in the range [@-pi@, @pi@]. - atan2 :: Exp a -> Exp a -> Exp a + atan2 :: Exp a -> Exp a -> Exp a -instance RealFrac Half where - type Significand Half = Int16 +instance RealFrac Float16 where + type Significand Float16 = Int16 properFraction = defaultProperFraction -instance RealFrac Float where - type Significand Float = Int32 +instance RealFrac Float32 where + type Significand Float32 = Int32 properFraction = defaultProperFraction -instance RealFrac Double where - type Significand Double = Int64 +instance RealFrac Float64 where + type Significand Float64 = Int64 properFraction = defaultProperFraction instance RealFrac Float128 where type Significand Float128 = Int128 properFraction = defaultProperFraction -instance RealFloat Half where +instance KnownNat n => RealFrac (Vec n Float16) where + type Significand (Vec n Float16) = Vec n Int16 + properFraction = defaultProperFraction' + +instance KnownNat n => RealFrac (Vec n Float32) where + type Significand (Vec n Float32) = Vec n Int32 + properFraction = defaultProperFraction' + +instance KnownNat n => RealFrac (Vec n Float64) where + type Significand (Vec n Float64) = Vec n Int64 + properFraction = defaultProperFraction' + +instance KnownNat n => RealFrac (Vec n Float128) where + type Significand (Vec n Float128) = Vec n Int128 + properFraction = defaultProperFraction' + +instance RealFloat Float16 where + type Exponent Float16 = Int + floatRadix = defaultFloatRadix + floatDigits = defaultFloatDigits + floatRange = defaultFloatRange + encodeFloat s = mkCoerce . defaultEncodeFloat @1 @Float16 (mkBitcast s) . mkBitcast + exponent = mkCoerce . defaultExponent . mkBitcast @(Vec 1 Float16) + significand = mkCoerce . defaultSignificand . mkBitcast @(Vec 1 Float16) + scaleFloat k = mkCoerce . defaultScaleFloat k . mkBitcast @(Vec 1 Float16) atan2 = mkAtan2 isNaN = mkIsNaN isInfinite = mkIsInfinite - isDenormalized = ieee754 "isDenormalized" (ieee754_f16_is_denormalized . mkBitcast) - isNegativeZero = ieee754 "isNegativeZero" (ieee754_f16_is_negative_zero . mkBitcast) - decodeFloat = ieee754 "decodeFloat" (ieee754_f16_decode . mkBitcast) - -instance RealFloat Float where + isIEEE = defaultIsIEEE + isDenormalized = mkCoerce . ieee754_f16_is_denormalized . mkBitcast @(Vec 1 Word16) + isNegativeZero = mkCoerce . ieee754_f16_is_negative_zero . mkBitcast @(Vec 1 Word16) + decodeFloat = mkCoerce . ieee754_f16_decode . mkBitcast @(Vec 1 Word16) + +instance RealFloat Float32 where + type Exponent Float32 = Int + floatRadix = defaultFloatRadix + floatDigits = defaultFloatDigits + floatRange = defaultFloatRange + encodeFloat s = mkCoerce . defaultEncodeFloat @1 @Float32 (mkBitcast s) . mkBitcast + exponent = mkCoerce . defaultExponent . mkBitcast @(Vec 1 Float32) + significand = mkCoerce . defaultSignificand . mkBitcast @(Vec 1 Float32) + scaleFloat k = mkCoerce . defaultScaleFloat k . mkBitcast @(Vec 1 Float32) atan2 = mkAtan2 isNaN = mkIsNaN isInfinite = mkIsInfinite - isDenormalized = ieee754 "isDenormalized" (ieee754_f32_is_denormalized . mkBitcast) - isNegativeZero = ieee754 "isNegativeZero" (ieee754_f32_is_negative_zero . mkBitcast) - decodeFloat = ieee754 "decodeFloat" (ieee754_f32_decode . mkBitcast) - -instance RealFloat Double where + isIEEE = defaultIsIEEE + isDenormalized = mkCoerce . ieee754_f32_is_denormalized . mkBitcast @(Vec 1 Word32) + isNegativeZero = mkCoerce . ieee754_f32_is_negative_zero . mkBitcast @(Vec 1 Word32) + decodeFloat = mkCoerce . ieee754_f32_decode . mkBitcast @(Vec 1 Word32) + +instance RealFloat Float64 where + type Exponent Float64 = Int + floatRadix = defaultFloatRadix + floatDigits = defaultFloatDigits + floatRange = defaultFloatRange + encodeFloat s = mkCoerce . defaultEncodeFloat @1 @Float64 (mkBitcast s) . mkBitcast + exponent = mkCoerce . defaultExponent . mkBitcast @(Vec 1 Float64) + significand = mkCoerce . defaultSignificand . mkBitcast @(Vec 1 Float64) + scaleFloat k = mkCoerce . defaultScaleFloat k . mkBitcast @(Vec 1 Float64) atan2 = mkAtan2 isNaN = mkIsNaN isInfinite = mkIsInfinite - isDenormalized = ieee754 "isDenormalized" (ieee754_f64_is_denormalized . mkBitcast) - isNegativeZero = ieee754 "isNegativeZero" (ieee754_f64_is_negative_zero . mkBitcast) - decodeFloat = ieee754 "decodeFloat" (ieee754_f64_decode . mkBitcast) + isIEEE = defaultIsIEEE + isDenormalized = mkCoerce . ieee754_f64_is_denormalized . mkBitcast @(Vec 1 Word64) + isNegativeZero = mkCoerce . ieee754_f64_is_negative_zero . mkBitcast @(Vec 1 Word64) + decodeFloat = mkCoerce . ieee754_f64_decode . mkBitcast @(Vec 1 Word64) instance RealFloat Float128 where + type Exponent Float128 = Int + floatRadix = defaultFloatRadix + floatDigits = defaultFloatDigits + floatRange = defaultFloatRange + encodeFloat s = mkCoerce . defaultEncodeFloat @1 @Float128 (mkBitcast s) . mkBitcast + exponent = mkCoerce . defaultExponent . mkBitcast @(Vec 1 Float128) + significand = mkCoerce . defaultSignificand . mkBitcast @(Vec 1 Float128) + scaleFloat k = mkCoerce . defaultScaleFloat k . mkBitcast @(Vec 1 Float128) + atan2 = mkAtan2 + isNaN = mkIsNaN + isInfinite = mkIsInfinite + isIEEE = defaultIsIEEE + isDenormalized = mkCoerce . ieee754_f128_is_denormalized . mkBitcast @(Vec 1 Word128) + isNegativeZero = mkCoerce . ieee754_f128_is_negative_zero . mkBitcast @(Vec 1 Word128) + decodeFloat = mkCoerce . ieee754_f128_decode . mkBitcast @(Vec 1 Word128) + +instance KnownNat n => RealFloat (Vec n Float16) where + type Exponent (Vec n Float16) = Vec n Int + floatRadix _ = defaultFloatRadix (undefined :: Exp Float16) + floatDigits _ = defaultFloatDigits (undefined :: Exp Float16) + floatRange _ = defaultFloatRange (undefined :: Exp Float16) + decodeFloat = ieee754_f16_decode . mkBitcast' + encodeFloat = defaultEncodeFloat + exponent = defaultExponent + significand = defaultSignificand + scaleFloat = defaultScaleFloat + isNaN = mkIsNaN + isInfinite = mkIsInfinite + isDenormalized = coerce . ieee754_f16_is_denormalized . mkBitcast' + isNegativeZero = coerce . ieee754_f16_is_negative_zero . mkBitcast' + isIEEE _ = defaultIsIEEE (undefined :: Exp Float16) atan2 = mkAtan2 + + +instance KnownNat n => RealFloat (Vec n Float32) where + type Exponent (Vec n Float32) = Vec n Int + floatRadix _ = defaultFloatRadix (undefined :: Exp Float32) + floatDigits _ = defaultFloatDigits (undefined :: Exp Float32) + floatRange _ = defaultFloatRange (undefined :: Exp Float32) + decodeFloat = ieee754_f32_decode . mkBitcast' + encodeFloat = defaultEncodeFloat + exponent = defaultExponent + significand = defaultSignificand + scaleFloat = defaultScaleFloat + isNaN = mkIsNaN + isInfinite = mkIsInfinite + isDenormalized = coerce . ieee754_f32_is_denormalized . mkBitcast' + isNegativeZero = coerce . ieee754_f32_is_negative_zero . mkBitcast' + isIEEE _ = defaultIsIEEE (undefined :: Exp Float32) + atan2 = mkAtan2 + +instance KnownNat n => RealFloat (Vec n Float64) where + type Exponent (Vec n Float64) = Vec n Int + floatRadix _ = defaultFloatRadix (undefined :: Exp Float64) + floatDigits _ = defaultFloatDigits (undefined :: Exp Float64) + floatRange _ = defaultFloatRange (undefined :: Exp Float64) + decodeFloat = ieee754_f64_decode . mkBitcast' + encodeFloat = defaultEncodeFloat + exponent = defaultExponent + significand = defaultSignificand + scaleFloat = defaultScaleFloat isNaN = mkIsNaN isInfinite = mkIsInfinite - isDenormalized = ieee754 "isDenormalized" (ieee754_f128_is_denormalized . mkBitcast) - isNegativeZero = ieee754 "isNegativeZero" (ieee754_f128_is_negative_zero . mkBitcast) - decodeFloat = ieee754 "decodeFloat" (ieee754_f128_decode . mkBitcast) + isDenormalized = coerce . ieee754_f64_is_denormalized . mkBitcast' + isNegativeZero = coerce . ieee754_f64_is_negative_zero . mkBitcast' + isIEEE _ = defaultIsIEEE (undefined :: Exp Float64) + atan2 = mkAtan2 + + +instance KnownNat n => RealFloat (Vec n Float128) where + type Exponent (Vec n Float128) = Vec n Int + floatRadix _ = defaultFloatRadix (undefined :: Exp Float128) + floatDigits _ = defaultFloatDigits (undefined :: Exp Float128) + floatRange _ = defaultFloatRange (undefined :: Exp Float128) + decodeFloat = ieee754_f128_decode . mkBitcast' + encodeFloat = defaultEncodeFloat + exponent = defaultExponent + significand = defaultSignificand + scaleFloat = defaultScaleFloat + isNaN = mkIsNaN + isInfinite = mkIsInfinite + isDenormalized = coerce . ieee754_f128_is_denormalized . mkBitcast' + isNegativeZero = coerce . ieee754_f128_is_negative_zero . mkBitcast' + isIEEE _ = defaultIsIEEE (undefined :: Exp Float128) + atan2 = mkAtan2 -- To satisfy superclass constraints @@ -201,11 +303,75 @@ preludeError x , "constraints for subsequent classes in the standard Haskell numeric hierarchy." ] -ieee754 :: forall a b. HasCallStack => P.RealFloat a => Builder -> (Exp a -> b) -> Exp a -> b -ieee754 name f x - | P.isIEEE (undefined::a) = f x - | otherwise = internalError (builder % ": Not implemented for non-IEEE floating point") name +-- GHC's type level natural normalisation isn't strong enough to deduce (n * 32) == (n * 32) +mkBitcast' + :: forall b a n. (IsScalar (VecR n a), IsScalar (VecR n b), BitSizeEq (EltR a) (EltR b)) + => Exp (Vec n a) + -> Exp (Vec n b) +mkBitcast' (Exp a) = mkExp $ Coerce (scalarType @(VecR n a)) (scalarType @(VecR n b)) a + +splat :: (KnownNat n, SIMD n a, Elt a) => Exp a -> Exp (Vec n a) +splat x = mkPack (P.repeat x) + +defaultFloatRadix :: forall a. P.RealFloat a => Exp a -> Exp Int +defaultFloatRadix _ = P.fromInteger (P.floatRadix (undefined::a)) + +defaultFloatDigits :: forall a. P.RealFloat a => Exp a -> Exp Int +defaultFloatDigits _ = constant (P.floatDigits (undefined::a)) + +defaultFloatRange :: forall a. P.RealFloat a => Exp a -> Exp (Int, Int) +defaultFloatRange _ = constant (P.floatRange (undefined::a)) + +defaultIsIEEE :: forall a. P.RealFloat a => Exp a -> Exp Bool +defaultIsIEEE _ = constant (P.isIEEE (undefined::a)) + +defaultEncodeFloat + :: forall n a. (SIMD n a, RealFloat a, RealFloat (Vec n a), FromIntegral Int a, FromIntegral (Significand (Vec n a)) (Vec n a), FromIntegral (Exponent (Vec n a)) (Vec n a)) + => Exp (Significand (Vec n a)) + -> Exp (Exponent (Vec n a)) + -> Exp (Vec n a) +defaultEncodeFloat x e = + let d = splat (fromIntegral (floatRadix (undefined :: Exp a))) + in fromIntegral x * (d ** fromIntegral e) + +defaultExponent + :: (RealFloat (Vec n a), Significand (Vec n a) ~ Vec n s, Exponent (Vec n a) ~ Vec n Int, VEq n a, VEq n s) + => Exp (Vec n a) + -> Exp (Vec n Int) +defaultExponent x = + let T2 m n = decodeFloat x + d = splat (floatDigits x) + in + select (m ==* 0) 0 (n + d) + +defaultSignificand + :: (RealFloat (Vec n a), Exponent (Vec n a) ~ Vec n Int, KnownNat n) + => Exp (Vec n a) + -> Exp (Vec n a) +defaultSignificand x = + let T2 m _ = decodeFloat x + d = splat (floatDigits x) + in encodeFloat m (negate d) + +defaultScaleFloat + :: (RealFloat (Vec n a), Exponent (Vec n a) ~ Vec n Int, BitOrMask (EltR (Vec n a)) ~ PrimMask n, VEq n a) + => Exp Int + -> Exp (Vec n a) + -> Exp (Vec n a) +defaultScaleFloat k x = + select (k' ==* 0 ||* isFix) x (encodeFloat m (n + clamp b)) + where + k' = splat k + isFix = x ==* 0 ||* isNaN x ||* isInfinite x + T2 m n = decodeFloat x + T2 l h = floatRange x + d = floatDigits x + b = splat (h - l + 4*d) + -- n+k may overflow, which would lead to incorrect results, hence we clamp + -- the scaling parameter. If (n+k) would be larger than h, (n + clamp b k) + -- must be too, similar for smaller than (l-d). + clamp bd = max (-bd) (min bd k') -- Must test for ±0.0 to avoid returning -0.0 in the second component of the -- pair. Unfortunately the branching costs a lot of performance. @@ -236,6 +402,32 @@ defaultProperFraction x = T2 m n = decodeFloat x (q, r) = quotRem m (2 ^ (negate n)) +-- defaultProperFraction' +-- :: (SIMD n a, SIMD n b, RealFloat (Vec n a), Exponent (Vec n a) ~ Vec n Int, FromIntegral (Significand (Vec n a)) (Vec n b), Integral (Vec n b)) +-- => Exp (Vec n a) +-- -> Exp (Vec n b, Vec n a) +-- defaultProperFraction' x = +-- T2 (select p (fromIntegral m * (2 ^ n)) (fromIntegral q)) +-- (select p 0.0 (encodeFloat r n)) +-- where +-- T2 m n = decodeFloat x +-- (q, r) = quotRem m (2 ^ (negate n)) +-- p = n >=* 0 + +-- This is a bit weird because we really want to apply the function late-wise, +-- but there isn't really a way we can do that. Boo. ---TLM 2022-09-20 +-- +defaultProperFraction' + :: (SIMD n a, RealFloat (Vec n a), Exponent (Vec n a) ~ Vec n Int, FromIntegral (Significand (Vec n a)) b, Integral b) + => Exp (Vec n a) + -> Exp (b, Vec n a) +defaultProperFraction' x = + T2 (cond (n >= 0) (fromIntegral m * (2 ^ n)) (fromIntegral q)) + (select (n >=* 0) 0.0 (encodeFloat r n)) + where + T2 m n = decodeFloat x + (q, r) = quotRem m (2 ^ (negate n)) + -- From: ghc/libraries/base/cbits/primFloat.c -- ------------------------------------------ @@ -245,51 +437,51 @@ defaultProperFraction x = -- * mantissa is non-zero. -- * (don't care about setting of sign bit.) -- -ieee754_f128_is_denormalized :: Exp Word128 -> Exp Bool +ieee754_f128_is_denormalized :: KnownNat n => Exp (Vec n Word128) -> Exp (Vec n Bool) ieee754_f128_is_denormalized x = - ieee754_f128_mantissa x == 0 && - ieee754_f128_exponent x /= 0 + ieee754_f128_mantissa x ==* 0 &&* + ieee754_f128_exponent x /=* 0 -ieee754_f64_is_denormalized :: Exp Word64 -> Exp Bool +ieee754_f64_is_denormalized :: KnownNat n => Exp (Vec n Word64) -> Exp (Vec n Bool) ieee754_f64_is_denormalized x = - ieee754_f64_mantissa x == 0 && - ieee754_f64_exponent x /= 0 + ieee754_f64_mantissa x ==* 0 &&* + ieee754_f64_exponent x /=* 0 -ieee754_f32_is_denormalized :: Exp Word32 -> Exp Bool +ieee754_f32_is_denormalized :: KnownNat n => Exp (Vec n Word32) -> Exp (Vec n Bool) ieee754_f32_is_denormalized x = - ieee754_f32_mantissa x == 0 && - ieee754_f32_exponent x /= 0 + ieee754_f32_mantissa x ==* 0 &&* + ieee754_f32_exponent x /=* 0 -ieee754_f16_is_denormalized :: Exp Word16 -> Exp Bool +ieee754_f16_is_denormalized :: KnownNat n => Exp (Vec n Word16) -> Exp (Vec n Bool) ieee754_f16_is_denormalized x = - ieee754_f16_mantissa x == 0 && - ieee754_f16_exponent x /= 0 + ieee754_f16_mantissa x ==* 0 &&* + ieee754_f16_exponent x /=* 0 -- Negative zero if only the sign bit is set -- -ieee754_f128_is_negative_zero :: Exp Word128 -> Exp Bool +ieee754_f128_is_negative_zero :: KnownNat n => Exp (Vec n Word128) -> Exp (Vec n Bool) ieee754_f128_is_negative_zero x = - ieee754_f128_negative x && - ieee754_f128_exponent x == 0 && - ieee754_f128_mantissa x == 0 + ieee754_f128_negative x &&* + ieee754_f128_exponent x ==* 0 &&* + ieee754_f128_mantissa x ==* 0 -ieee754_f64_is_negative_zero :: Exp Word64 -> Exp Bool +ieee754_f64_is_negative_zero :: KnownNat n => Exp (Vec n Word64) -> Exp (Vec n Bool) ieee754_f64_is_negative_zero x = - ieee754_f64_negative x && - ieee754_f64_exponent x == 0 && - ieee754_f64_mantissa x == 0 + ieee754_f64_negative x &&* + ieee754_f64_exponent x ==* 0 &&* + ieee754_f64_mantissa x ==* 0 -ieee754_f32_is_negative_zero :: Exp Word32 -> Exp Bool +ieee754_f32_is_negative_zero :: KnownNat n => Exp (Vec n Word32) -> Exp (Vec n Bool) ieee754_f32_is_negative_zero x = - ieee754_f32_negative x && - ieee754_f32_exponent x == 0 && - ieee754_f32_mantissa x == 0 + ieee754_f32_negative x &&* + ieee754_f32_exponent x ==* 0 &&* + ieee754_f32_mantissa x ==* 0 -ieee754_f16_is_negative_zero :: Exp Word16 -> Exp Bool +ieee754_f16_is_negative_zero :: KnownNat n => Exp (Vec n Word16) -> Exp (Vec n Bool) ieee754_f16_is_negative_zero x = - ieee754_f16_negative x && - ieee754_f16_exponent x == 0 && - ieee754_f16_mantissa x == 0 + ieee754_f16_negative x &&* + ieee754_f16_exponent x ==* 0 &&* + ieee754_f16_mantissa x ==* 0 -- Assume the host processor stores integers and floating point numbers in the @@ -302,13 +494,13 @@ ieee754_f16_is_negative_zero x = -- exponent 126-112 exponent (biased by 16383) -- fraction 111-0 fraction (bits to right of binary part) -- -ieee754_f128_mantissa :: Exp Word128 -> Exp Word128 +ieee754_f128_mantissa :: KnownNat n => Exp (Vec n Word128) -> Exp (Vec n Word128) ieee754_f128_mantissa x = x .&. 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFF -ieee754_f128_exponent :: Exp Word128 -> Exp Word16 +ieee754_f128_exponent :: KnownNat n => Exp (Vec n Word128) -> Exp (Vec n Word16) ieee754_f128_exponent x = fromIntegral (x `unsafeShiftR` 112) .&. 0x7FFF -ieee754_f128_negative :: Exp Word128 -> Exp Bool +ieee754_f128_negative :: KnownNat n => Exp (Vec n Word128) -> Exp (Vec n Bool) ieee754_f128_negative x = testBit x 127 -- Representation of a double precision IEEE floating point number: @@ -317,13 +509,13 @@ ieee754_f128_negative x = testBit x 127 -- exponent 62-52 exponent (biased by 1023) -- fraction 51-0 fraction (bits to right of binary point) -- -ieee754_f64_mantissa :: Exp Word64 -> Exp Word64 +ieee754_f64_mantissa :: KnownNat n => Exp (Vec n Word64) -> Exp (Vec n Word64) ieee754_f64_mantissa x = x .&. 0xFFFFFFFFFFFFF -ieee754_f64_exponent :: Exp Word64 -> Exp Word16 +ieee754_f64_exponent :: KnownNat n => Exp (Vec n Word64) -> Exp (Vec n Word16) ieee754_f64_exponent x = fromIntegral (x `unsafeShiftR` 52) .&. 0x7FF -ieee754_f64_negative :: Exp Word64 -> Exp Bool +ieee754_f64_negative :: KnownNat n => Exp (Vec n Word64) -> Exp (Vec n Bool) ieee754_f64_negative x = testBit x 63 -- Representation of a single precision IEEE floating point number: @@ -332,13 +524,13 @@ ieee754_f64_negative x = testBit x 63 -- exponent 30-23 exponent (biased by 127) -- fraction 22-0 fraction (bits to right of binary point) -- -ieee754_f32_mantissa :: Exp Word32 -> Exp Word32 +ieee754_f32_mantissa :: KnownNat n => Exp (Vec n Word32) -> Exp (Vec n Word32) ieee754_f32_mantissa x = x .&. 0x7FFFFF -ieee754_f32_exponent :: Exp Word32 -> Exp Word8 +ieee754_f32_exponent :: KnownNat n => Exp (Vec n Word32) -> Exp (Vec n Word8) ieee754_f32_exponent x = fromIntegral (x `unsafeShiftR` 23) -ieee754_f32_negative :: Exp Word32 -> Exp Bool +ieee754_f32_negative :: KnownNat n => Exp (Vec n Word32) -> Exp (Vec n Bool) ieee754_f32_negative x = testBit x 31 -- Representation of a half precision IEEE floating point number: @@ -347,26 +539,26 @@ ieee754_f32_negative x = testBit x 31 -- exponent 14-10 exponent (biased by 15) -- fraction 9-0 fraction (bits to right of binary point) -- -ieee754_f16_mantissa :: Exp Word16 -> Exp Word16 +ieee754_f16_mantissa :: KnownNat n => Exp (Vec n Word16) -> Exp (Vec n Word16) ieee754_f16_mantissa x = x .&. 0x3FF -ieee754_f16_exponent :: Exp Word16 -> Exp Word8 +ieee754_f16_exponent :: KnownNat n => Exp (Vec n Word16) -> Exp (Vec n Word8) ieee754_f16_exponent x = fromIntegral (x `unsafeShiftR` 10) .&. 0x1F -ieee754_f16_negative :: Exp Word16 -> Exp Bool +ieee754_f16_negative :: KnownNat n => Exp (Vec n Word16) -> Exp (Vec n Bool) ieee754_f16_negative x = testBit x 15 -- reverse engineered following the below -ieee754_f16_decode :: Exp Word16 -> Exp (Int16, Int) +ieee754_f16_decode :: forall n. KnownNat n => Exp (Vec n Word16) -> Exp (Vec n Int16, Vec n Int) ieee754_f16_decode i = let _HHIGHBIT = 0x0400 _HMSBIT = 0x8000 - _HMINEXP = ((_HALF_MIN_EXP) - (_HALF_MANT_DIG) - 1) - _HALF_MANT_DIG = floatDigits (undefined::Exp Half) - T2 _HALF_MIN_EXP _HALF_MAX_EXP = floatRange (undefined::Exp Half) + _HMINEXP = splat ((_HALF_MIN_EXP) - (_HALF_MANT_DIG) - 1) + _HALF_MANT_DIG = floatDigits (undefined::Exp Float16) + T2 _HALF_MIN_EXP _HALF_MAX_EXP = floatRange (undefined::Exp Float16) high1 = fromIntegral i high2 = high1 .&. (_HHIGHBIT - 1) @@ -380,27 +572,37 @@ ieee754_f16_decode i = (T2 (high2 .|. _HHIGHBIT) exp1) -- a denorm, normalise the mantissa (while (\(T2 h _) -> (h .&. _HHIGHBIT) /= 0) - (\(T2 h e) -> T2 (h `unsafeShiftL` 1) (e-1)) + (\(T2 h e) -> let p = (h .&. _HHIGHBIT) /=* 0 + in T2 (select p (h `unsafeShiftL` 1) h) + (select p (e-1) e)) (T2 high2 exp2)) - high4 = cond (fromIntegral i < (0 :: Exp Int16)) (-high3) high3 + high4 = select (fromIntegral i <* (0 :: Exp (Vec n Int16))) (-high3) high3 + z = high1 .&. complement _HMSBIT ==* 0 in - cond (high1 .&. complement _HMSBIT == 0) - (T2 0 0) - (T2 high4 exp3) + T2 (select z 0 high4) + (select z 0 exp3) -- From: ghc/rts/StgPrimFloat.c -- ---------------------------- +-- +-- The fast-path (no denormalised values) looks good to me, but if any one of +-- the lanes contains a denormalised value then all lanes need to continue +-- looping in predicated style to normalise the mantissa. We do a bit of +-- redundant work here that could be avoided; maybe the codegen will clean that +-- up for us, but if not it shouldn't matter too much anyway (slow path...). +-- -- TLM 2022-09-20. +-- -ieee754_f32_decode :: Exp Word32 -> Exp (Int32, Int) +ieee754_f32_decode :: forall n. KnownNat n => Exp (Vec n Word32) -> Exp (Vec n Int32, Vec n Int) ieee754_f32_decode i = let _FHIGHBIT = 0x00800000 _FMSBIT = 0x80000000 - _FMINEXP = ((_FLT_MIN_EXP) - (_FLT_MANT_DIG) - 1) - _FLT_MANT_DIG = floatDigits (undefined::Exp Float) - T2 _FLT_MIN_EXP _FLT_MAX_EXP = floatRange (undefined::Exp Float) + _FMINEXP = splat ((_FLT_MIN_EXP) - (_FLT_MANT_DIG) - 1) + _FLT_MANT_DIG = floatDigits (undefined::Exp Float32) + T2 _FLT_MIN_EXP _FLT_MAX_EXP = floatRange (undefined::Exp Float32) high1 = fromIntegral i high2 = high1 .&. (_FHIGHBIT - 1) @@ -414,35 +616,37 @@ ieee754_f32_decode i = (T2 (high2 .|. _FHIGHBIT) exp1) -- a denorm, normalise the mantissa (while (\(T2 h _) -> (h .&. _FHIGHBIT) /= 0) - (\(T2 h e) -> T2 (h `unsafeShiftL` 1) (e-1)) + (\(T2 h e) -> let p = (h .&. _FHIGHBIT) /=* 0 + in T2 (select p (h `unsafeShiftL` 1) h) + (select p (e-1) e)) (T2 high2 exp2)) - high4 = cond (fromIntegral i < (0 :: Exp Int32)) (-high3) high3 + high4 = select (fromIntegral i <* (0 :: Exp (Vec n Int32))) (-high3) high3 + z = high1 .&. complement _FMSBIT ==* 0 in - cond (high1 .&. complement _FMSBIT == 0) - (T2 0 0) - (T2 high4 exp3) + T2 (select z 0 high4) + (select z 0 exp3) -ieee754_f64_decode :: Exp Word64 -> Exp (Int64, Int) +ieee754_f64_decode :: KnownNat n => Exp (Vec n Word64) -> Exp (Vec n Int64, Vec n Int) ieee754_f64_decode i = let T4 s h l e = ieee754_f64_decode2 i in T2 (fromIntegral s * (fromIntegral h `unsafeShiftL` 32 .|. fromIntegral l)) e -ieee754_f64_decode2 :: Exp Word64 -> Exp (Int, Word32, Word32, Int) +ieee754_f64_decode2 :: forall n. KnownNat n => Exp (Vec n Word64) -> Exp (Vec n Int64, Vec n Word32, Vec n Word32, Vec n Int) ieee754_f64_decode2 i = let _DHIGHBIT = 0x00100000 _DMSBIT = 0x80000000 - _DMINEXP = ((_DBL_MIN_EXP) - (_DBL_MANT_DIG) - 1) - _DBL_MANT_DIG = floatDigits (undefined::Exp Double) - T2 _DBL_MIN_EXP _DBL_MAX_EXP = floatRange (undefined::Exp Double) + _DMINEXP = splat ((_DBL_MIN_EXP) - (_DBL_MANT_DIG) - 1) + _DBL_MANT_DIG = floatDigits (undefined::Exp Float64) + T2 _DBL_MIN_EXP _DBL_MAX_EXP = floatRange (undefined::Exp Float64) low = fromIntegral i high = fromIntegral (i `unsafeShiftR` 32) iexp = (fromIntegral ((high `unsafeShiftR` 20) .&. 0x7FF) + _DMINEXP) - sign = cond (fromIntegral i < (0 :: Exp Int64)) (-1) 1 + sign = select (fromIntegral i <* (0 :: Exp (Vec n Int64))) (-1) 1 high2 = high .&. (_DHIGHBIT - 1) iexp2 = iexp + 1 @@ -454,16 +658,21 @@ ieee754_f64_decode2 i = -- a denorm, nermalise the mantissa (while (\(T3 h _ _) -> (h .&. _DHIGHBIT) /= 0) (\(T3 h l e) -> - let h1 = h `unsafeShiftL` 1 + let p = (h .&. _DHIGHBIT) /=* 0 + h1 = h `unsafeShiftL` 1 h2 = cond ((l .&. _DMSBIT) /= 0) (h1+1) h1 - in T3 h2 (l `unsafeShiftL` 1) (e-1)) + in T3 (select p h2 h) + (select p (l `unsafeShiftL` 1) l) + (select p (e-1) e)) (T3 high2 low iexp2)) + z = low ==* 0 &&* (high .&. (complement _DMSBIT)) ==* 0 in - cond (low == 0 && (high .&. (complement _DMSBIT)) == 0) - (T4 1 0 0 0) - (T4 sign hi lo ie) + T4 (select z 1 sign) + (select z 0 hi) + (select z 0 lo) + (select z 0 ie) -ieee754_f128_decode :: Exp Word128 -> Exp (Int128, Int) -ieee754_f128_decode = undefined +ieee754_f128_decode :: KnownNat n => Exp (Vec n Word128) -> Exp (Vec n Int128, Vec n Int) +ieee754_f128_decode = error "TODO: ieee754_f128_decode" diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 98acbf225..7a6ca85c3 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -615,8 +615,7 @@ data PreSmartExp acc exp t where Undef :: ScalarType t -> PreSmartExp acc exp t - Coerce :: BitSizeEq a b - => ScalarType a + Coerce :: ScalarType a -> ScalarType b -> exp a -> PreSmartExp acc exp b @@ -1011,8 +1010,10 @@ mkPack xs = let go :: Word8 -> [Exp a] -> Exp (Vec n a) -> Exp (Vec n a) go _ [] vec = vec go i (v:vs) vec = go (i+1) vs (insert vec (constant i) v) + -- + n = fromIntegral (natVal' (proxy# :: Proxy# n)) in - go 0 xs undef + go 0 (take n xs) undef -- | Extract a single scalar element from the given SIMD vector at the -- specified index @@ -1443,8 +1444,6 @@ mkFromBool = mkPrimUnary $ PrimFromBool bitType (SingleIntegralType singleIntegr -- Other conversions --- NOTE: Restricted to scalar types with a type-level BitSizeEq constraint to --- make this version "safe" mkBitcast :: forall b a. (IsScalar (EltR a), IsScalar (EltR b), BitSizeEq (EltR a) (EltR b)) => Exp a -> Exp b mkBitcast (Exp a) = mkExp $ Coerce (scalarType @(EltR a)) (scalarType @(EltR b)) a diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index 7e603d0be..acfe32e0a 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -76,7 +76,7 @@ import Numeric.Half import Text.Printf import GHC.Prim -import GHC.TypeLits +import GHC.TypeNats type Float16 = Half type Float32 = Float @@ -412,7 +412,7 @@ runQ $ do scalarType = NumScalarType numType type instance BitSize $t = $(litT (numTyLit bits)) - type instance BitSize (Vec n $t) = n GHC.TypeLits.* $(litT (numTyLit bits)) + type instance BitSize (Vec n $t) = n GHC.TypeNats.* $(litT (numTyLit bits)) |] mkFloating :: Name -> Integer -> Q [Dec] @@ -442,7 +442,7 @@ runQ $ do scalarType = NumScalarType numType type instance BitSize $t = $(litT (numTyLit bits)) - type instance BitSize (Vec n $t) = n GHC.TypeLits.* $(litT (numTyLit bits)) + type instance BitSize (Vec n $t) = n GHC.TypeNats.* $(litT (numTyLit bits)) |] ss <- mapM (mkIntegral "Int") integralTypes From 6f3376e0c01dbb2c6e44ebd4aaffff57d392e2ea Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Wed, 21 Sep 2022 11:04:34 +0200 Subject: [PATCH 49/86] updates for vectorised RealFloat --- src/Data/Array/Accelerate/Classes/Rational.hs | 2 +- src/Data/Array/Accelerate/Data/Complex.hs | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/Data/Array/Accelerate/Classes/Rational.hs b/src/Data/Array/Accelerate/Classes/Rational.hs index bed30c67d..bb5b05cc2 100644 --- a/src/Data/Array/Accelerate/Classes/Rational.hs +++ b/src/Data/Array/Accelerate/Classes/Rational.hs @@ -75,7 +75,7 @@ integralToRational integralToRational x = fromIntegral x :% 1 floatingToRational - :: (RealFloat a, Integral b, FromIntegral (Significand a) b, FromIntegral Int (Significand a), FiniteBits (Significand a)) + :: (RealFloat a, Integral b, FromIntegral (Significand a) b, FromIntegral Int (Significand a), FiniteBits (Significand a), FromIntegral (Exponent a) (Significand a)) => Exp a -> Exp (Ratio b) floatingToRational x = fromIntegral u :% fromIntegral v diff --git a/src/Data/Array/Accelerate/Data/Complex.hs b/src/Data/Array/Accelerate/Data/Complex.hs index f10a10249..8826cdbab 100644 --- a/src/Data/Array/Accelerate/Data/Complex.hs +++ b/src/Data/Array/Accelerate/Data/Complex.hs @@ -44,7 +44,7 @@ module Data.Array.Accelerate.Data.Complex ( ) where -import Data.Array.Accelerate.AST ( PrimFun(..) ) +import Data.Array.Accelerate.AST ( PrimFun(..), BitOrMask ) import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.Classes.Eq import Data.Array.Accelerate.Classes.Floating @@ -278,7 +278,7 @@ instance Eq a => Eq (Complex a) where r1 ::+ c1 == r2 ::+ c2 = r1 == r2 && c1 == c2 r1 ::+ c1 /= r2 ::+ c2 = r1 /= r2 || c1 /= c2 -instance RealFloat a => P.Num (Exp (Complex a)) where +instance (RealFloat a, Exponent a ~ Int) => P.Num (Exp (Complex a)) where (+) = case complexR (eltR @a) of ComplexTup -> lift2 ((+) :: Complex (Exp a) -> Complex (Exp a) -> Complex (Exp a)) ComplexVec t -> mkPrimBinary $ PrimAdd t @@ -297,7 +297,7 @@ instance RealFloat a => P.Num (Exp (Complex a)) where abs z = magnitude z ::+ 0 fromInteger n = fromInteger n ::+ 0 -instance RealFloat a => P.Fractional (Exp (Complex a)) where +instance (RealFloat a, Exponent a ~ Int) => P.Fractional (Exp (Complex a)) where fromRational x = fromRational x ::+ 0 z / z' = (x*x''+y*y'') / d ::+ (y*x''-x*y'') / d where @@ -309,7 +309,7 @@ instance RealFloat a => P.Fractional (Exp (Complex a)) where k = - max (exponent x') (exponent y') d = x'*x'' + y'*y'' -instance RealFloat a => P.Floating (Exp (Complex a)) where +instance (RealFloat a, Exponent a ~ Int, BitOrMask (EltR a) ~ Bit) => P.Floating (Exp (Complex a)) where pi = pi ::+ 0 exp (x ::+ y) = let expx = exp x in expx * cos y ::+ expx * sin y @@ -387,7 +387,7 @@ instance Functor Complex where -- | The non-negative magnitude of a complex number -- -magnitude :: RealFloat a => Exp (Complex a) -> Exp a +magnitude :: (RealFloat a, Exponent a ~ Int) => Exp (Complex a) -> Exp a magnitude (r ::+ i) = scaleFloat k (sqrt (sqr (scaleFloat mk r) + sqr (scaleFloat mk i))) where k = max (exponent r) (exponent i) @@ -406,8 +406,8 @@ magnitude' (r ::+ i) = sqrt (r*r + i*i) -- magnitude is zero, then so is the phase. -- phase :: RealFloat a => Exp (Complex a) -> Exp a -phase z@(r ::+ i) = - if z == 0 +phase (r ::+ i) = + if r == 0 && i == 0 then 0 else atan2 i r @@ -415,7 +415,7 @@ phase z@(r ::+ i) = -- phase) pair in canonical form: the magnitude is non-negative, and the phase -- in the range @(-'pi', 'pi']@; if the magnitude is zero, then so is the phase. -- -polar :: RealFloat a => Exp (Complex a) -> Exp (a,a) +polar :: (RealFloat a, Exponent a ~ Int) => Exp (Complex a) -> Exp (a,a) polar z = T2 (magnitude z) (phase z) -- | Form a complex number from polar components of magnitude and phase. From 00819a25010646e61dbc801b325199c74cde96b7 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Thu, 29 Sep 2022 17:45:19 +0200 Subject: [PATCH 50/86] pack BitMask densely rather than being byte-aligned This ensures that we can coerce between `Bit` and `Vec 1 Bit` --- src/Data/Array/Accelerate/Array/Data.hs | 30 ++--- src/Data/Primitive/Bit.hs | 164 ++++++++++++++++++++---- 2 files changed, 146 insertions(+), 48 deletions(-) diff --git a/src/Data/Array/Accelerate/Array/Data.hs b/src/Data/Array/Accelerate/Array/Data.hs index 0956c7926..0204b87e5 100644 --- a/src/Data/Array/Accelerate/Array/Data.hs +++ b/src/Data/Array/Accelerate/Array/Data.hs @@ -143,21 +143,9 @@ newArrayData (TupRsingle _t) !size = scalar _t scalar (NumScalarType t) = num t scalar (BitScalarType t) = bit t - -- XXX: Arrays of BitMask are stored with each mask aligned to a byte - -- boundary, rather than being packed tightly together. This might be a bit - -- surprising if we want to cast between types? We don't support any - -- non-power-of-two sized integer types though so perhaps this can not come - -- up in practice, but it is a bit strange that a 'Vec 4 Bool' will require - -- trice as much memory as necessary. ---TLM 2022-08-30 - -- - -- XXX: Actually this is a problem as both BitSize(Bit) and BitSize(Vec 1 - -- Bit) are a single bit, so we should be able to coerce between them, but - -- they will be stored differently. ---TLM 2022-08-31 - -- bit :: BitType t -> IO (MutableArrayData t) bit TypeBit = allocateArray ((size + 7) `quot` 8) - bit (TypeMask n) = let bytes = quot (fromInteger (natVal' n)+7) 8 - in allocateArray (size * bytes) + bit (TypeMask n) = allocateArray (((size * fromInteger (natVal' n)) + 7) `quot` 8) num :: NumType t -> IO (MutableArrayData t) num (IntegralNumType t) = integral t @@ -560,7 +548,7 @@ mallocPlainForeignPtrBytesAligned (I# size#) = IO $ \s0 -> liftArrayData :: Int -> TypeR e -> ArrayData e -> CodeQ (ArrayData e) -liftArrayData n = tuple +liftArrayData !size = tuple where tuple :: TypeR e -> ArrayData e -> CodeQ (ArrayData e) tuple TupRunit () = [|| () ||] @@ -572,10 +560,8 @@ liftArrayData n = tuple scalar (BitScalarType t) = bit t bit :: BitType e -> ArrayData e -> CodeQ (ArrayData e) - bit TypeBit ua = liftUniqueArray ((n+7) `quot` 8) ua - bit (TypeMask n') ua = - let bytes = quot (fromInteger (natVal' n')+7) 8 - in liftUniqueArray (n * bytes) ua + bit TypeBit ua = liftUniqueArray ((size + 7) `quot` 8) ua + bit (TypeMask n) ua = liftUniqueArray (((size * fromInteger (natVal' n)) + 7) `quot` 8) ua num :: NumType e -> ArrayData e -> CodeQ (ArrayData e) num (IntegralNumType t) = integral t @@ -583,8 +569,8 @@ liftArrayData n = tuple integral :: IntegralType e -> ArrayData e -> CodeQ (ArrayData e) integral = \case - SingleIntegralType t -> single t n - VectorIntegralType n' t -> vector n' t (n * fromInteger (natVal' n')) + SingleIntegralType t -> single t size + VectorIntegralType n t -> vector n t (size * fromInteger (natVal' n)) where single :: SingleIntegralType e -> Int -> ArrayData e -> CodeQ (ArrayData e) single TypeInt8 = liftUniqueArray @@ -612,8 +598,8 @@ liftArrayData n = tuple floating :: FloatingType e -> ArrayData e -> CodeQ (ArrayData e) floating = \case - SingleFloatingType t -> single t n - VectorFloatingType n' t -> vector n' t (n * fromInteger (natVal' n')) + SingleFloatingType t -> single t size + VectorFloatingType n t -> vector n t (size * fromInteger (natVal' n)) where single :: SingleFloatingType e -> Int -> ArrayData e -> CodeQ (ArrayData e) single TypeFloat16 = liftUniqueArray diff --git a/src/Data/Primitive/Bit.hs b/src/Data/Primitive/Bit.hs index 0d200cae4..78ca506d9 100644 --- a/src/Data/Primitive/Bit.hs +++ b/src/Data/Primitive/Bit.hs @@ -29,10 +29,10 @@ module Data.Primitive.Bit ( import Data.Array.Accelerate.Error +import Control.Exception +import Control.Monad.ST import Data.Bits import Data.Typeable -import Control.Monad.ST -import Control.Exception import qualified Foreign.Storable as Foreign import Data.Primitive.ByteArray @@ -41,13 +41,18 @@ import Data.Primitive.Vec ( Vec(..) ) import GHC.Base ( isTrue# ) import GHC.Generics import GHC.Int -import GHC.Prim import GHC.Ptr import GHC.TypeLits import GHC.Types ( IO(..) ) import GHC.Word import qualified GHC.Exts as GHC +#if __GLASGOW_HASKELL__ < 902 +import GHC.Prim hiding ( subWord8# ) +#else +import GHC.Prim +#endif + -- | A newtype wrapper over 'Bool' whose instances pack bits as efficiently -- as possible (8 values per byte). Arrays of 'Bit' use 8x less memory than @@ -113,10 +118,12 @@ instance KnownNat n => GHC.IsList (BitMask n) where fromList = fromList instance KnownNat n => Foreign.Storable (BitMask n) where - {-# INLINE sizeOf #-} - {-# INLINE alignment #-} - {-# INLINE peek #-} - {-# INLINE poke #-} + {-# INLINE sizeOf #-} + {-# INLINE alignment #-} + {-# INLINE peek #-} + {-# INLINE poke #-} + {-# INLINE peekElemOff #-} + {-# INLINE pokeElemOff #-} alignment _ = 1 sizeOf _ = @@ -124,20 +131,101 @@ instance KnownNat n => Foreign.Storable (BitMask n) where in quot (k + 7) 8 peek (Ptr addr#) = - IO $ \s0 -> - case Foreign.sizeOf (undefined :: BitMask n) of { I# bytes# -> - case newByteArray# bytes# s0 of { (# s1, mba# #) -> - case copyAddrToByteArray# addr# mba# 0# bytes# s1 of { s2 -> - case unsafeFreezeByteArray# mba# s2 of { (# s3, ba# #) -> - (# s3, BitMask (Vec ba#) #) - }}}} + let k = natVal' (proxy# :: Proxy# n) + in if k `rem` 8 /= 0 + then error "TODO: use BitMask.peekElemOff for non-multiple-of-8 sized bit-masks" + else IO $ \s0 -> + case Foreign.sizeOf (undefined :: BitMask n) of { I# bytes# -> + case newByteArray# bytes# s0 of { (# s1, mba# #) -> + case copyAddrToByteArray# addr# mba# 0# bytes# s1 of { s2 -> + case unsafeFreezeByteArray# mba# s2 of { (# s3, ba# #) -> + (# s3, BitMask (Vec ba#) #) + }}}} + + peekElemOff (Ptr addr#) (I# i#) = + let !(I# k#) = fromInteger (natVal' (proxy# :: Proxy# n)) + !ki# = i# *# k# + in + if isTrue# (k# <=# 8#) + then let !(# q#, r# #) = quotRemInt# ki# 8# + !mask# = ((wordToWord8# 1## `uncheckedShiftLWord8#` k#) `subWord8#` wordToWord8# 1##) `uncheckedShiftLWord8#` (8# -# k#) + combine u# v# = (uncheckedShiftLWord8# u# r# `orWord8#` uncheckedShiftRLWord8# v# (8# -# r#)) `andWord8#` mask# + in + if isTrue# (r# +# k# <=# 8#) + -- This element does not not cross the byte boundary + then IO $ \s0 -> + case newByteArray# 1# s0 of { (# s1, mba# #) -> + case readWord8OffAddr# addr# q# s1 of { (# s2, w# #) -> + case writeWord8Array# mba# 0# (uncheckedShiftLWord8# w# r# `andWord8#` mask#) s2 of { s3 -> + case unsafeFreezeByteArray# mba# s3 of { (# s4, ba# #) -> + (# s4, BitMask (Vec ba#) #) + }}}} + -- This element crosses the byte boundary. Read two successive + -- bytes (note that on little-endian we can't just treat this + -- as a 16-bit load) to combine and extract the bits we need. + else IO $ \s0 -> + case newByteArray# 1# s0 of { (# s1, mba# #) -> + case readWord8OffAddr# addr# q# s1 of { (# s2, w0# #) -> + case readWord8OffAddr# addr# (q# +# 1#) s2 of { (# s3, w1# #) -> + case writeWord8Array# mba# 0# (combine w0# w1#) s3 of { s4 -> + case unsafeFreezeByteArray# mba# s4 of { (# s5, ba# #) -> + (# s5, BitMask (Vec ba#) #) + }}}}} + else + if isTrue# (k# <=# 16#) + then error "TODO: BitMask (8..16]" + else + if isTrue# (k# <=# 32#) + then error "TODO: BitMask (16..32]" + else + if isTrue# (k# <=# 64#) + then error "TODO: BitMask (32..64]" + else + error "TODO: BitMask.peekElemOff not yet supported at this size" poke (Ptr addr#) (BitMask (Vec ba#)) = - IO $ \s0 -> - case Foreign.sizeOf (undefined :: BitMask n) of { I# bytes# -> - case copyByteArrayToAddr# ba# 0# addr# bytes# s0 of { - s1 -> (# s1, () #) - }} + let k = natVal' (proxy# :: Proxy# n) + in if k `rem` 8 /= 0 + then error "TODO: use BitMask.pokeElemOff for non-multiple-of-8 sized bit-masks" + else IO $ \s0 -> + case Foreign.sizeOf (undefined :: BitMask n) of { I# bytes# -> + case copyByteArrayToAddr# ba# 0# addr# bytes# s0 of { + s1 -> (# s1, () #) + }} + + pokeElemOff (Ptr addr#) (I# i#) (BitMask (Vec ba#)) = + let !(I# k#) = fromInteger (natVal' (proxy# :: Proxy# n)) + !ki# = i# *# k# + in + if isTrue# (k# <=# 8#) + then let !(# q#, r# #) = quotRemInt# ki# 8# + !rk# = r# +# k# + in + if isTrue# (rk# <=# 8#) + -- This element does not cross the byte boundary + then let !w# = uncheckedShiftRLWord8# (indexWord8Array# ba# 0#) r# + !mask# = ((wordToWord8# 1## `uncheckedShiftLWord8#` k#) `subWord8#` wordToWord8# 1##) `uncheckedShiftLWord8#` (8# -# rk#) + in IO $ \s0 -> + case readWord8OffAddr# addr# q# s0 of { (# s1, v# #) -> + case writeWord8OffAddr# addr# q# ((v# `andWord8#` complementWord8# mask#) `orWord8#` (w# `andWord8#` mask#)) s1 of { s2 -> + (# s2, () #) + }} + -- This element crosses the byte boundary + else let !w# = indexWord8Array# ba# 0# + !w0# = w# `uncheckedShiftRLWord8#` r# + !w1# = w# `uncheckedShiftLWord8#` (8# -# r#) + !mask# = ((wordToWord8# 1## `uncheckedShiftLWord8#` k#) `subWord8#` wordToWord8# 1##) + !mask0# = mask# `uncheckedShiftRLWord8#` (rk# -# 8#) + !mask1# = mask# `uncheckedShiftLWord8#` (16# -# rk#) + in IO $ \s0 -> + case readWord8OffAddr# addr# q# s0 of { (# s1, v0# #) -> + case readWord8OffAddr# addr# (q# +# 1#) s1 of { (# s2, v1# #) -> + case writeWord8OffAddr# addr# q# ((v0# `andWord8#` complementWord8# mask0#) `orWord8#` (w0# `andWord8#` mask0#)) s2 of { s3 -> + case writeWord8OffAddr# addr# (q# +# 1#) ((v1# `andWord8#` complementWord8# mask1#) `orWord8#` (w1# `andWord8#` mask1#)) s3 of { s4 -> + (# s4, () #) + }}}} + else + error "TODO: BitMask.pokeElemOff not yet supported at this size" {-# INLINE toList #-} toList :: forall n. KnownNat n => BitMask n -> [Bit] @@ -161,19 +249,19 @@ toList (BitMask (Vec ba#)) = concat (unpack 0# []) {-# INLINE fromList #-} fromList :: forall n. KnownNat n => [Bit] -> BitMask n -fromList bits = case byteArrayFromListN bytes (pack bits') of +fromList bits = case byteArrayFromListN bytes (pack bits' []) of ByteArray ba# -> BitMask (Vec ba#) where bits' = take (fromInteger (natVal' (proxy# :: Proxy# n))) bits bytes = Foreign.sizeOf (undefined :: BitMask n) - pack :: [Bit] -> [Word8] - pack xs = + pack :: [Bit] -> [Word8] -> [Word8] + pack [] acc = acc + pack xs acc = let (h,t) = splitAt 8 xs w = w8 7 0 h - in if null t - then [w] - else w : pack t + in + pack t (w : acc) w8 :: Int -> Word8 -> [Bit] -> Word8 w8 !_ !w [] = w @@ -237,11 +325,35 @@ testBitWord8# x# i# = (x# `and#` bitWord8# i#) `neWord#` 0## bitWord8# :: Int# -> Word# bitWord8# i# = narrow8Word# (1## `uncheckedShiftL#` i#) +orWord8# :: Word# -> Word# -> Word# +orWord8# = or# + +andWord8# :: Word# -> Word# -> Word# +andWord8# = and# + +complementWord8# :: Word# -> Word# +complementWord8# x# = x# `xor#` 0xff## + +subWord8# :: Word# -> Word# -> Word# +subWord8# = minusWord# + +uncheckedShiftLWord8# :: Word# -> Int# -> Word# +uncheckedShiftLWord8# = uncheckedShiftL# + +uncheckedShiftRLWord8# :: Word# -> Int# -> Word# +uncheckedShiftRLWord8# = uncheckedShiftRL# + +wordToWord8# :: Word# -> Word# +wordToWord8# x = x + #else testBitWord8# :: Word8# -> Int# -> Int# -testBitWord8# x# i# = (x# `andWord8#` bitWord8# i#) `neWord8#` (wordToWord8# 0##) +testBitWord8# x# i# = (x# `andWord8#` bitWord8# i#) `neWord8#` wordToWord8# 0## bitWord8# :: Int# -> Word8# bitWord8# i# = (wordToWord8# 1##) `uncheckedShiftLWord8#` i# + +complementWord8# :: Word8# -> Word8# +complementWord8# x# = x# `xorWord8#` wordToWord8# 0xff## #endif From 4d3bd0c791203d8e91c45e8de20b687edfcc956c Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Thu, 29 Sep 2022 18:39:53 +0200 Subject: [PATCH 51/86] nofib build fix --- .../Accelerate/Test/NoFib/Issues/Issue407.hs | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue407.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue407.hs index c69fb3c8c..ee5d6131b 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue407.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue407.hs @@ -1,6 +1,7 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -27,7 +28,8 @@ module Data.Array.Accelerate.Test.NoFib.Issues.Issue407 ( import Prelude as P hiding ( Bool(..) ) import Data.Array.Accelerate as A -import Data.Array.Accelerate.Sugar.Elt as S +import Data.Array.Accelerate.AST ( BitOrMask, PrimBool ) +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Test.NoFib.Base import Test.Tasty @@ -42,16 +44,18 @@ test_issue407 runN = ] where testElt - :: forall a. (Show a, P.Fractional a, A.RealFloat a) + :: forall a. (Show a, P.Fractional a, A.RealFloat a, BitOrMask (EltR a) ~ PrimBool) => TestTree testElt = + let xs :: Vector a + xs = [0/0, -2/0, -0/0, 0.1, 1/0, 0.5, 5/0] + + eNaN, eInf :: Vector Bool + eNaN = [True, False, True, False, False, False, False] -- expected: isNaN + eInf = [False, True, False, False, True, False, True] -- expected: isInfinite + in testGroup (show (eltR @a)) [ testCase "isNaN" $ eNaN @=? runN (A.map A.isNaN) xs , testCase "isInfinite" $ eInf @=? runN (A.map A.isInfinite) xs ] - where - xs :: Vector a - xs = [0/0, -2/0, -0/0, 0.1, 1/0, 0.5, 5/0] - eNaN = [True, False, True, False, False, False, False] -- expected: isNaN - eInf = [False, True, False, False, True, False, True] -- expected: isInfinite From 970046798631fc614c9de6cbbb5375d3f941b444 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Sat, 1 Oct 2022 12:39:54 +0200 Subject: [PATCH 52/86] export strict (&&!) and (||!) --- src/Data/Array/Accelerate.hs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index 53e0046ab..74927aff6 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -377,6 +377,7 @@ module Data.Array.Accelerate ( -- *** Logical operations (&&), (||), not, + (&&!), (||!), -- *** Numeric operations subtract, even, odd, gcd, lcm, (^), (^^), From 86fd7fac856a605f88ace16a31c46d5f772ffc6b Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Sat, 1 Oct 2022 12:40:25 +0200 Subject: [PATCH 53/86] export 128-bit types --- src/Data/Array/Accelerate.hs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index 74927aff6..fd0907bf0 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -421,9 +421,9 @@ module Data.Array.Accelerate ( -- --------------------------------------------------------------------------- -- Types - Int, Int8, Int16, Int32, Int64, - Word, Word8, Word16, Word32, Word64, - Half(..), Float, Double, + Int, Int8, Int16, Int32, Int64, Int128, + Word, Word8, Word16, Word32, Word64, Word128, + Half(..), Float, Double, Float16, Float32, Float64, Float128, Bool, pattern True, pattern False, Maybe, pattern Nothing, pattern Just, Either, pattern Left, pattern Right, From 9fe3290b5efbb7b080f70aeb26efad3acf6f35b3 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Sat, 1 Oct 2022 12:41:03 +0200 Subject: [PATCH 54/86] export operators on SIMD vectors --- src/Data/Array/Accelerate.hs | 5 ++++- src/Data/Array/Accelerate/AST.hs | 9 +++++---- src/Data/Array/Accelerate/Analysis/Hash.hs | 2 +- src/Data/Array/Accelerate/Analysis/Match.hs | 5 +++-- src/Data/Array/Accelerate/Interpreter.hs | 5 ++--- src/Data/Array/Accelerate/Pretty/Graphviz.hs | 2 +- src/Data/Array/Accelerate/Pretty/Print.hs | 2 +- src/Data/Array/Accelerate/Smart.hs | 15 ++++++++------- src/Data/Array/Accelerate/Trafo/Fusion.hs | 2 +- src/Data/Array/Accelerate/Trafo/Sharing.hs | 6 +++--- src/Data/Array/Accelerate/Trafo/Shrink.hs | 6 +++--- src/Data/Array/Accelerate/Trafo/Simplify.hs | 4 ++-- src/Data/Array/Accelerate/Trafo/Substitution.hs | 4 ++-- 13 files changed, 36 insertions(+), 31 deletions(-) diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index fd0907bf0..439d4aa77 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -369,8 +369,11 @@ module Data.Array.Accelerate ( -- *** Tuples fst, afst, snd, asnd, curry, uncurry, + -- *** SIMD vectors + insert, extract, shuffle, + -- *** Flow control - (?), match, cond, while, iterate, + (?), select, match, cond, while, iterate, -- *** Scalar reduction sfoldl, diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 4e48ea1ba..7532daf6f 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -566,7 +566,8 @@ data OpenExp env aenv t where -> OpenExp env aenv (Vec m i) -> OpenExp env aenv (Vec m a) - Select :: OpenExp env aenv (PrimMask n) + Select :: ScalarType (Vec n a) + -> OpenExp env aenv (PrimMask n) -> OpenExp env aenv (Vec n a) -> OpenExp env aenv (Vec n a) -> OpenExp env aenv (Vec n a) @@ -843,7 +844,7 @@ expType = \case -- Insert t _ _ _ _ -> TupRsingle t Shuffle t _ _ _ _ -> TupRsingle t - Select _ x _ -> expType x + Select t _ _ _ -> TupRsingle t IndexSlice si _ _ -> shapeType $ sliceShapeR si IndexFull si _ _ -> shapeType $ sliceDomainR si ToIndex{} -> TupRsingle scalarType @@ -1126,7 +1127,7 @@ rnfOpenExp topExp = Extract vR iR v i -> rnfScalarType vR `seq` rnfSingleIntegralType iR `seq` rnfE v `seq` rnfE i Insert vR iR v i x -> rnfScalarType vR `seq` rnfSingleIntegralType iR `seq` rnfE v `seq` rnfE i `seq` rnfE x Shuffle eR iR x y i -> rnfScalarType eR `seq` rnfSingleIntegralType iR `seq` rnfE x `seq` rnfE y `seq` rnfE i - Select m x y -> rnfE m `seq` rnfE x `seq` rnfE y + Select eR m x y -> rnfScalarType eR `seq` rnfE m `seq` rnfE x `seq` rnfE y IndexSlice slice slix sh -> rnfSliceIndex slice `seq` rnfE slix `seq` rnfE sh IndexFull slice slix sl -> rnfSliceIndex slice `seq` rnfE slix `seq` rnfE sl ToIndex shr sh ix -> rnfShapeR shr `seq` rnfE sh `seq` rnfE ix @@ -1344,7 +1345,7 @@ liftOpenExp pexp = Extract vR iR v i -> [|| Extract $$(liftScalarType vR) $$(liftSingleIntegralType iR) $$(liftE v) $$(liftE i) ||] Insert vR iR v i x -> [|| Insert $$(liftScalarType vR) $$(liftSingleIntegralType iR) $$(liftE v) $$(liftE i) $$(liftE x) ||] Shuffle eR iR x y i -> [|| Shuffle $$(liftScalarType eR) $$(liftSingleIntegralType iR) $$(liftE x) $$(liftE y) $$(liftE i) ||] - Select m x y -> [|| Select $$(liftE m) $$(liftE x) $$(liftE y) ||] + Select eR m x y -> [|| Select $$(liftScalarType eR) $$(liftE m) $$(liftE x) $$(liftE y) ||] IndexSlice slice slix sh -> [|| IndexSlice $$(liftSliceIndex slice) $$(liftE slix) $$(liftE sh) ||] IndexFull slice slix sl -> [|| IndexFull $$(liftSliceIndex slice) $$(liftE slix) $$(liftE sl) ||] ToIndex shr sh ix -> [|| ToIndex $$(liftShapeR shr) $$(liftE sh) $$(liftE ix) ||] diff --git a/src/Data/Array/Accelerate/Analysis/Hash.hs b/src/Data/Array/Accelerate/Analysis/Hash.hs index 25c96975a..048596014 100644 --- a/src/Data/Array/Accelerate/Analysis/Hash.hs +++ b/src/Data/Array/Accelerate/Analysis/Hash.hs @@ -325,7 +325,7 @@ encodeOpenExp exp = Extract vR iR v i -> intHost $(hashQ "Extract") <> encodeScalarType vR <> encodeSingleIntegralType iR <> travE v <> travE i Insert vR iR v i x -> intHost $(hashQ "Insert") <> encodeScalarType vR <> encodeSingleIntegralType iR <> travE v <> travE i <> travE x Shuffle eR iR x y i -> intHost $(hashQ "Shuffle") <> encodeScalarType eR <> encodeSingleIntegralType iR <> travE x <> travE y <> travE i - Select m x y -> intHost $(hashQ "Select") <> travE m <>travE x <> travE y + Select eR m x y -> intHost $(hashQ "Select") <> encodeScalarType eR <> travE m <>travE x <> travE y Const tp c -> intHost $(hashQ "Const") <> encodeScalarConst tp c Undef tp -> intHost $(hashQ "Undef") <> encodeScalarType tp IndexSlice spec ix sh -> intHost $(hashQ "IndexSlice") <> travE ix <> travE sh <> encodeSliceIndex spec diff --git a/src/Data/Array/Accelerate/Analysis/Match.hs b/src/Data/Array/Accelerate/Analysis/Match.hs index db94b1a2a..bb6833a86 100644 --- a/src/Data/Array/Accelerate/Analysis/Match.hs +++ b/src/Data/Array/Accelerate/Analysis/Match.hs @@ -497,8 +497,9 @@ matchOpenExp (Shuffle eR1 iR1 x1 y1 i1) (Shuffle eR2 iR2 x2 y2 i2) , Just Refl <- matchOpenExp i1 i2 = Just Refl -matchOpenExp (Select p1 x1 y1) (Select p2 x2 y2) - | Just Refl <- matchOpenExp p1 p2 +matchOpenExp (Select eR1 p1 x1 y1) (Select eR2 p2 x2 y2) + | Just Refl <- matchScalarType eR1 eR2 + , Just Refl <- matchOpenExp p1 p2 , Just Refl <- matchOpenExp x1 x2 , Just Refl <- matchOpenExp y1 y2 = Just Refl diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index dc5c1a059..bd5a4c2d5 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -919,7 +919,7 @@ evalOpenExp pexp env aenv = in evalOpenExp exp2 env' aenv Evar (Var _ ix) -> prj ix env Const _ c -> c - Undef tp -> undefElt (TupRsingle tp) + Undef eR -> undefElt (TupRsingle eR) PrimApp f x -> evalPrim f (evalE x) Nil -> () Pair e1 e2 -> let !x1 = evalE e1 @@ -929,8 +929,7 @@ evalOpenExp pexp env aenv = Insert vR iR v i x -> evalInsert vR iR (evalE v) (evalE i) (evalE x) Shuffle rR iR x y i -> let TupRsingle eR = expType x in evalShuffle eR rR iR (evalE x) (evalE y) (evalE i) - Select m x y -> let TupRsingle eR = expType x - in evalSelect eR (evalE m) (evalE x) (evalE y) + Select eR m x y -> evalSelect eR (evalE m) (evalE x) (evalE y) IndexSlice slice slix sh -> restrict slice (evalE slix) (evalE sh) where restrict :: SliceIndex slix sl co sh -> slix -> sh -> sl diff --git a/src/Data/Array/Accelerate/Pretty/Graphviz.hs b/src/Data/Array/Accelerate/Pretty/Graphviz.hs index f3954e12e..93b9af69b 100644 --- a/src/Data/Array/Accelerate/Pretty/Graphviz.hs +++ b/src/Data/Array/Accelerate/Pretty/Graphviz.hs @@ -525,7 +525,7 @@ fvOpenExp env aenv = fv fv (Extract _ _ v i) = concat [ fv v, fv i ] fv (Insert _ _ v i x) = concat [ fv v, fv i, fv x ] fv (Shuffle _ _ x y i) = concat [ fv x, fv y, fv i ] - fv (Select m x y) = concat [ fv m, fv x, fv y ] + fv (Select _ m x y) = concat [ fv m, fv x, fv y ] fv (IndexSlice _ slix sh) = concat [ fv slix, fv sh ] fv (IndexFull _ slix sh) = concat [ fv slix, fv sh ] fv (ToIndex _ sh ix) = concat [ fv sh, fv ix ] diff --git a/src/Data/Array/Accelerate/Pretty/Print.hs b/src/Data/Array/Accelerate/Pretty/Print.hs index 37f6cf9b0..74b3e9aa5 100644 --- a/src/Data/Array/Accelerate/Pretty/Print.hs +++ b/src/Data/Array/Accelerate/Pretty/Print.hs @@ -419,7 +419,7 @@ prettyOpenExp ctx env aenv exp = Extract _ _ v i -> ppF2 (Operator "#" Infix L 9) (ppE v) (ppE i) Insert{} -> prettyInsert ctx env aenv exp Shuffle _ _ x y i -> ppF3 "shuffle" (ppE x) (ppE y) (ppE i) - Select m x y -> ppF3 "select" (ppE m) (ppE x) (ppE y) + Select _ m x y -> ppF3 "select" (ppE m) (ppE x) (ppE y) Case tR x xs d -> prettyCase env aenv tR x xs d Cond p t e -> flatAlt multi single where diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 7a6ca85c3..80bf0caf2 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -553,7 +553,8 @@ data PreSmartExp acc exp t where -> exp (Prim.Vec m i) -> PreSmartExp acc exp (Prim.Vec m a) - Select :: exp (Prim.Vec n Bit) + Select :: ScalarType (Prim.Vec n a) + -> exp (Prim.Vec n PrimBool) -> exp (Prim.Vec n a) -> exp (Prim.Vec n a) -> PreSmartExp acc exp (Prim.Vec n a) @@ -908,7 +909,7 @@ instance HasTypeR exp => HasTypeR (PreSmartExp acc exp) where -- Insert t _ _ _ _ -> TupRsingle t Shuffle t _ _ _ _ -> TupRsingle t - Select _ x _ -> typeR x + Select t _ _ _ -> TupRsingle t ToIndex _ _ _ -> TupRsingle (scalarType @INT) FromIndex shr _ _ -> shapeType shr Case _ ((_,c):_) -> typeR c @@ -1201,7 +1202,7 @@ select (Exp mask) (Exp tt) (Exp ff) = Exp $ go (vecR @n @a) tt ff bit :: BitType t -> SmartExp t -> SmartExp t -> SmartExp t bit (TypeMask n) | Just Refl <- sameNat' n (proxy# :: Proxy# n) - = SmartExp $$ Select mask + = SmartExp $$ Select (BitScalarType (TypeMask n)) mask bit _ = error "impossible" num :: NumType t -> SmartExp t -> SmartExp t -> SmartExp t @@ -1209,15 +1210,15 @@ select (Exp mask) (Exp tt) (Exp ff) = Exp $ go (vecR @n @a) tt ff num (FloatingNumType vR) = floating vR integral :: IntegralType t -> SmartExp t -> SmartExp t -> SmartExp t - integral (VectorIntegralType n _) + integral (VectorIntegralType n t) | Just Refl <- sameNat' n (proxy# :: Proxy# n) - = SmartExp $$ Select mask + = SmartExp $$ Select (NumScalarType (IntegralNumType (VectorIntegralType n t))) mask integral _ = error "impossible" floating :: FloatingType t -> SmartExp t -> SmartExp t -> SmartExp t - floating (VectorFloatingType n _) + floating (VectorFloatingType n t) | Just Refl <- sameNat' n (proxy# :: Proxy# n) - = SmartExp $$ Select mask + = SmartExp $$ Select (NumScalarType (FloatingNumType (VectorFloatingType n t))) mask floating _ = error "impossible" diff --git a/src/Data/Array/Accelerate/Trafo/Fusion.hs b/src/Data/Array/Accelerate/Trafo/Fusion.hs index d05abfb5d..8f82f6352 100644 --- a/src/Data/Array/Accelerate/Trafo/Fusion.hs +++ b/src/Data/Array/Accelerate/Trafo/Fusion.hs @@ -1475,7 +1475,7 @@ aletD' embedAcc elimAcc (LeftHandSideSingle ArrayR{}) (Embed env1 cc1) (Embed en Extract vR iR v i -> Extract vR iR (cvtE v) (cvtE i) Insert vR iR v i x -> Insert vR iR (cvtE v) (cvtE i) (cvtE x) Shuffle vR iR x y i -> Shuffle vR iR (cvtE x) (cvtE y) (cvtE i) - Select m x y -> Select (cvtE m) (cvtE x) (cvtE y) + Select eR m x y -> Select eR (cvtE m) (cvtE x) (cvtE y) IndexSlice x ix sh -> IndexSlice x (cvtE ix) (cvtE sh) IndexFull x ix sl -> IndexFull x (cvtE ix) (cvtE sl) ToIndex shR' sh ix -> ToIndex shR' (cvtE sh) (cvtE ix) diff --git a/src/Data/Array/Accelerate/Trafo/Sharing.hs b/src/Data/Array/Accelerate/Trafo/Sharing.hs index d05f5ad29..798e7abd7 100644 --- a/src/Data/Array/Accelerate/Trafo/Sharing.hs +++ b/src/Data/Array/Accelerate/Trafo/Sharing.hs @@ -761,7 +761,7 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp Extract vR iR v i -> AST.Extract vR iR (cvt v) (cvt i) Insert vR iR v i x -> AST.Insert vR iR (cvt v) (cvt i) (cvt x) Shuffle eR iR x y i -> AST.Shuffle eR iR (cvt x) (cvt y) (cvt i) - Select m x y -> AST.Select (cvt m) (cvt x) (cvt y) + Select eR m x y -> AST.Select eR (cvt m) (cvt x) (cvt y) ToIndex shr sh ix -> AST.ToIndex shr (cvt sh) (cvt ix) FromIndex shr sh e -> AST.FromIndex shr (cvt sh) (cvt e) Case e rhs -> cvtCase (cvt e) (over (mapped . _2) cvt rhs) @@ -1884,7 +1884,7 @@ makeOccMapSharingExp config accOccMap expOccMap = travE Extract vR iR v i -> travE2 (Extract vR iR) v i Insert vR iR v i x -> travE3 (Insert vR iR) v i x Shuffle eR iR x y i -> travE3 (Shuffle eR iR) x y i - Select m x y -> travE3 Select m x y + Select eR m x y -> travE3 (Select eR) m x y ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix FromIndex shr sh e -> travE2 (FromIndex shr) sh e Match t e -> travE1 (Match t) e @@ -2791,7 +2791,7 @@ determineScopesSharingExp config accOccMap expOccMap = scopesExp Extract vR iR v i -> travE2 (Extract vR iR) v i Insert vR iR v i x -> travE3 (Insert vR iR) v i x Shuffle eR iR x y i -> travE3 (Shuffle eR iR) x y i - Select m x y -> travE3 Select m x y + Select eR m x y -> travE3 (Select eR) m x y ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix FromIndex shr sh e -> travE2 (FromIndex shr) sh e Match t e -> travE1 (Match t) e diff --git a/src/Data/Array/Accelerate/Trafo/Shrink.hs b/src/Data/Array/Accelerate/Trafo/Shrink.hs index 8f80dbc68..d5cab4a95 100644 --- a/src/Data/Array/Accelerate/Trafo/Shrink.hs +++ b/src/Data/Array/Accelerate/Trafo/Shrink.hs @@ -293,7 +293,7 @@ shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE Extract vR iR v i -> Extract vR iR <$> shrinkE v <*> shrinkE i Insert vR iR v i x -> Insert vR iR <$> shrinkE v <*> shrinkE i <*> shrinkE x Shuffle eR iR x y i -> Shuffle eR iR <$> shrinkE x <*> shrinkE y <*> shrinkE i - Select m x y -> Select <$> shrinkE m <*> shrinkE x <*> shrinkE y + Select eR m x y -> Select eR <$> shrinkE m <*> shrinkE x <*> shrinkE y IndexSlice x ix sh -> IndexSlice x <$> shrinkE ix <*> shrinkE sh IndexFull x ix sl -> IndexFull x <$> shrinkE ix <*> shrinkE sl ToIndex shr sh ix -> ToIndex shr <$> shrinkE sh <*> shrinkE ix @@ -494,7 +494,7 @@ usesOfExp range = countE Extract _ _ v i -> countE v <> countE i Insert _ _ v i x -> countE v <> countE i <> countE x Shuffle _ _ x y i -> countE x <> countE y <> countE i - Select m x y -> countE m <> countE x <> countE y + Select _ m x y -> countE m <> countE x <> countE y IndexSlice _ ix sh -> countE ix <> countE sh IndexFull _ ix sl -> countE ix <> countE sl FromIndex _ sh i -> countE sh <> countE i @@ -582,7 +582,7 @@ usesOfPreAcc withShape countAcc idx = count Extract _ _ v i -> countE v + countE i Insert _ _ v i x -> countE v + countE i + countE x Shuffle _ _ x y i -> countE x + countE y + countE i - Select m x y -> countE m + countE x + countE y + Select _ m x y -> countE m + countE x + countE y IndexSlice _ ix sh -> countE ix + countE sh IndexFull _ ix sl -> countE ix + countE sl ToIndex _ sh ix -> countE sh + countE ix diff --git a/src/Data/Array/Accelerate/Trafo/Simplify.hs b/src/Data/Array/Accelerate/Trafo/Simplify.hs index 5b9bd3fe5..a3788b1e1 100644 --- a/src/Data/Array/Accelerate/Trafo/Simplify.hs +++ b/src/Data/Array/Accelerate/Trafo/Simplify.hs @@ -279,7 +279,7 @@ simplifyOpenExp env = first getAny . cvtE Extract vR iR v i -> Extract vR iR <$> cvtE v <*> cvtE i Insert vR iR v i x -> Insert vR iR <$> cvtE v <*> cvtE i <*> cvtE x Shuffle eR iR x y i -> Shuffle eR iR <$> cvtE x <*> cvtE y <*> cvtE i - Select m x y -> Select <$> cvtE m <*> cvtE x <*> cvtE y + Select eR m x y -> Select eR <$> cvtE m <*> cvtE x <*> cvtE y IndexSlice x ix sh -> IndexSlice x <$> cvtE ix <*> cvtE sh IndexFull x ix sl -> IndexFull x <$> cvtE ix <*> cvtE sl ToIndex shr sh ix -> toIndex shr (cvtE sh) (cvtE ix) @@ -591,7 +591,7 @@ summariseOpenExp = (terms +~ 1) . goE Extract _ _ v i -> travE v +++ travE i Insert _ _ v i x -> travE v +++ travE i +++ travE x Shuffle _ _ x y i -> travE x +++ travE y +++ travE i - Select m x y -> travE m +++ travE x +++ travE y + Select _ m x y -> travE m +++ travE x +++ travE y IndexSlice _ slix sh -> travE slix +++ travE sh & terms +~ 1 -- +1 for sliceIndex IndexFull _ slix sl -> travE slix +++ travE sl & terms +~ 1 -- +1 for sliceIndex ToIndex _ sh ix -> travE sh +++ travE ix diff --git a/src/Data/Array/Accelerate/Trafo/Substitution.hs b/src/Data/Array/Accelerate/Trafo/Substitution.hs index fe496765c..72f84e997 100644 --- a/src/Data/Array/Accelerate/Trafo/Substitution.hs +++ b/src/Data/Array/Accelerate/Trafo/Substitution.hs @@ -157,7 +157,7 @@ inlineVars lhsBound expr bound Extract vR iR v i -> Extract vR iR <$> travE v <*> travE i Insert vR iR v i x -> Insert vR iR <$> travE v <*> travE i <*> travE x Shuffle vR iR x y m -> Shuffle vR iR <$> travE x <*> travE y <*> travE m - Select m x y -> Select <$> travE m <*> travE x <*> travE y + Select eR m x y -> Select eR <$> travE m <*> travE x <*> travE y IndexSlice si e1 e2 -> IndexSlice si <$> travE e1 <*> travE e2 IndexFull si e1 e2 -> IndexFull si <$> travE e1 <*> travE e2 ToIndex shr e1 e2 -> ToIndex shr <$> travE e1 <*> travE e2 @@ -558,7 +558,7 @@ rebuildOpenExp v av@(ReindexAvar reindex) exp = Extract vR iR u i -> Extract vR iR <$> rebuildOpenExp v av u <*> rebuildOpenExp v av i Insert vR iR u i x -> Insert vR iR <$> rebuildOpenExp v av u <*> rebuildOpenExp v av i <*> rebuildOpenExp v av x Shuffle vR iR x y m -> Shuffle vR iR <$> rebuildOpenExp v av x <*> rebuildOpenExp v av y <*> rebuildOpenExp v av m - Select m x y -> Select <$> rebuildOpenExp v av m <*> rebuildOpenExp v av x <*> rebuildOpenExp v av y + Select eR m x y -> Select eR <$> rebuildOpenExp v av m <*> rebuildOpenExp v av x <*> rebuildOpenExp v av y IndexSlice x ix sh -> IndexSlice x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sh IndexFull x ix sl -> IndexFull x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sl ToIndex shr sh ix -> ToIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix From 259191058e6aabce6a3b64cfad1311ec042e454e Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 3 Oct 2022 14:29:25 +0200 Subject: [PATCH 55/86] add conversion between bool and integral types * Adds a new class FromBool that converts (vector of) Bool to various (vector of) integral types * FromIntegral class can now convert (vector of) integral types to (vector of) Bool --- accelerate.cabal | 1 + src/Data/Array/Accelerate.hs | 4 +- src/Data/Array/Accelerate/Classes/FromBool.hs | 65 +++++++++++++++++++ .../Array/Accelerate/Classes/FromIntegral.hs | 19 +++++- src/Data/Array/Accelerate/Language.hs | 8 +-- src/Data/Array/Accelerate/Prelude.hs | 5 +- src/Data/Array/Accelerate/Smart.hs | 8 +-- 7 files changed, 93 insertions(+), 17 deletions(-) create mode 100644 src/Data/Array/Accelerate/Classes/FromBool.hs diff --git a/accelerate.cabal b/accelerate.cabal index 3edf5cbc3..960dc8985 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -459,6 +459,7 @@ library Data.Array.Accelerate.Classes.Eq Data.Array.Accelerate.Classes.Floating Data.Array.Accelerate.Classes.Fractional + Data.Array.Accelerate.Classes.FromBool Data.Array.Accelerate.Classes.FromIntegral Data.Array.Accelerate.Classes.Integral Data.Array.Accelerate.Classes.Num diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index 439d4aa77..aaaacefc7 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -331,6 +331,7 @@ module Data.Array.Accelerate ( RealFloat(..), -- *** Numeric conversion classes + FromBool(..), FromIntegral(..), ToFloating(..), @@ -392,7 +393,7 @@ module Data.Array.Accelerate ( intersect, -- *** Conversions - ord, chr, boolToInt, bitcast, + ord, chr, bitcast, -- --------------------------------------------------------------------------- -- * Foreign Function Interface (FFI) @@ -439,6 +440,7 @@ import Data.Array.Accelerate.Classes.Enum import Data.Array.Accelerate.Classes.Eq import Data.Array.Accelerate.Classes.Floating import Data.Array.Accelerate.Classes.Fractional +import Data.Array.Accelerate.Classes.FromBool import Data.Array.Accelerate.Classes.FromIntegral import Data.Array.Accelerate.Classes.Integral import Data.Array.Accelerate.Classes.Num diff --git a/src/Data/Array/Accelerate/Classes/FromBool.hs b/src/Data/Array/Accelerate/Classes/FromBool.hs new file mode 100644 index 000000000..c680275b0 --- /dev/null +++ b/src/Data/Array/Accelerate/Classes/FromBool.hs @@ -0,0 +1,65 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE TemplateHaskell #-} +-- | +-- Module : Data.Array.Accelerate.Classes.FromBool +-- Copyright : [2016..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Classes.FromBool ( + + FromBool(..), + +) where + +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Vec +import Data.Array.Accelerate.Type + +import Language.Haskell.TH hiding ( Exp ) + + +-- | Convert from Bool to integral types +-- +-- @since 1.4.0.0 +-- +class FromBool a b where + fromBool :: Integral b => Exp a -> Exp b + + +runQ $ + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + thFromBool :: Name -> Q [Dec] + thFromBool b = + [d| instance FromBool Bool $(conT b) where + fromBool = mkFromBool + + instance KnownNat n => FromBool (Vec n Bool) (Vec n $(conT b)) where + fromBool = mkFromBool + |] + in + concat <$> mapM thFromBool integralTypes + + diff --git a/src/Data/Array/Accelerate/Classes/FromIntegral.hs b/src/Data/Array/Accelerate/Classes/FromIntegral.hs index f59fd6f54..422ce0b50 100644 --- a/src/Data/Array/Accelerate/Classes/FromIntegral.hs +++ b/src/Data/Array/Accelerate/Classes/FromIntegral.hs @@ -65,7 +65,7 @@ class FromIntegral a b where -- > -- -- > concat `fmap` mapM dig cons -- -runQ $ +runQ $ do let integralTypes :: [Name] integralTypes = @@ -102,6 +102,19 @@ runQ $ instance KnownNat n => FromIntegral (Vec n $(conT a)) (Vec n $(conT b)) where fromIntegral = $(varE $ if a == b then 'id else 'mkFromIntegral ) |] - in - concat <$> sequence [ thFromIntegral from to | from <- integralTypes, to <- numTypes ] + + thToBool :: Name -> Q [Dec] + thToBool a = + [d| -- | @since 1.4.0.0 + instance FromIntegral $(conT a) Bool where + fromIntegral = mkToBool + + -- | @since 1.4.0.0 + instance KnownNat n => FromIntegral (Vec n $(conT a)) (Vec n Bool) where + fromIntegral = mkToBool + |] + -- + x <- concat <$> sequence [ thFromIntegral from to | from <- integralTypes, to <- numTypes ] + y <- concat <$> mapM thToBool integralTypes + return (x ++ y) diff --git a/src/Data/Array/Accelerate/Language.hs b/src/Data/Array/Accelerate/Language.hs index b0ff4f9b6..b5a470615 100644 --- a/src/Data/Array/Accelerate/Language.hs +++ b/src/Data/Array/Accelerate/Language.hs @@ -83,7 +83,7 @@ module Data.Array.Accelerate.Language ( subtract, even, odd, gcd, lcm, (^), (^^), -- * Conversions - ord, chr, boolToInt, bitcast, + ord, chr, bitcast, ) where @@ -1478,12 +1478,6 @@ ord = mkFromIntegral chr :: Exp Int -> Exp Char chr = mkFromIntegral --- |Convert a Boolean value to an 'Int', where 'False' turns into '0' and 'True' --- into '1'. --- -boolToInt :: Exp Bool -> Exp Int -boolToInt = mkFromBool - -- |Reinterpret a value as another type. The two representations must have the -- same bit size. -- diff --git a/src/Data/Array/Accelerate/Prelude.hs b/src/Data/Array/Accelerate/Prelude.hs index eedd257b2..690558cf0 100644 --- a/src/Data/Array/Accelerate/Prelude.hs +++ b/src/Data/Array/Accelerate/Prelude.hs @@ -131,6 +131,7 @@ import Data.Array.Accelerate.Sugar.Shape ( Shape, Sli import Data.Array.Accelerate.Type import Data.Array.Accelerate.Classes.Eq +import Data.Array.Accelerate.Classes.FromBool import Data.Array.Accelerate.Classes.FromIntegral import Data.Array.Accelerate.Classes.Integral import Data.Array.Accelerate.Classes.Num @@ -1656,7 +1657,7 @@ compact keep arr -- for the offset indices. | Just Refl <- matchShapeType @sh @Z = let - T2 target len = scanl' (+) 0 (map boolToInt keep) + T2 target len = scanl' (+) 0 (map fromBool keep) prj ix = if keep!ix then Just (I1 (target!ix)) else Nothing @@ -1673,7 +1674,7 @@ compact keep arr compact keep arr = let sz = indexTail (shape arr) - T2 target len = scanl' (+) 0 (map boolToInt keep) + T2 target len = scanl' (+) 0 (map fromBool keep) T2 offset valid = scanl' (+) 0 (flatten len) prj ix = if keep!ix then Just (I1 (offset !! (toIndex sz (indexTail ix)) + target!ix)) diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 80bf0caf2..0f229a923 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -1437,11 +1437,11 @@ mkFromIntegral = mkPrimUnary $ PrimFromIntegral integralType numType mkToFloating :: (IsNum (EltR a), IsFloating (EltR b)) => Exp a -> Exp b mkToFloating = mkPrimUnary $ PrimToFloating numType floatingType -mkToBool :: (IsSingleIntegral (EltR a), BitOrMask (EltR a) ~ Bit) => Exp a -> Exp Bool -mkToBool = mkPrimUnary $ PrimToBool (SingleIntegralType singleIntegralType) bitType +mkToBool :: (IsIntegral (EltR a), IsBit (EltR b)) => Exp a -> Exp b +mkToBool = mkPrimUnary $ PrimToBool integralType bitType -mkFromBool :: (IsSingleIntegral (EltR a), BitOrMask (EltR a) ~ Bit) => Exp Bool -> Exp a -mkFromBool = mkPrimUnary $ PrimFromBool bitType (SingleIntegralType singleIntegralType) +mkFromBool :: (IsBit (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b +mkFromBool = mkPrimUnary $ PrimFromBool bitType integralType -- Other conversions From 6e8897ec8d8d44dc28592adaec191c31a9434455 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 3 Oct 2022 14:29:50 +0200 Subject: [PATCH 56/86] export vector splat --- src/Data/Array/Accelerate.hs | 2 +- .../Array/Accelerate/Classes/RealFloat.hs | 3 --- src/Data/Array/Accelerate/Smart.hs | 19 +++++++++++++------ 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index aaaacefc7..ba1d2711a 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -371,7 +371,7 @@ module Data.Array.Accelerate ( fst, afst, snd, asnd, curry, uncurry, -- *** SIMD vectors - insert, extract, shuffle, + splat, insert, extract, shuffle, -- *** Flow control (?), select, match, cond, while, iterate, diff --git a/src/Data/Array/Accelerate/Classes/RealFloat.hs b/src/Data/Array/Accelerate/Classes/RealFloat.hs index 22157a185..4ad290078 100644 --- a/src/Data/Array/Accelerate/Classes/RealFloat.hs +++ b/src/Data/Array/Accelerate/Classes/RealFloat.hs @@ -311,9 +311,6 @@ mkBitcast' -> Exp (Vec n b) mkBitcast' (Exp a) = mkExp $ Coerce (scalarType @(VecR n a)) (scalarType @(VecR n b)) a -splat :: (KnownNat n, SIMD n a, Elt a) => Exp a -> Exp (Vec n a) -splat x = mkPack (P.repeat x) - defaultFloatRadix :: forall a. P.RealFloat a => Exp a -> Exp Int defaultFloatRadix _ = P.fromInteger (P.floatRadix (undefined::a)) diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 0f229a923..0685ce8c4 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -53,7 +53,7 @@ module Data.Array.Accelerate.Smart ( indexHead, indexTail, -- ** Vector operations - mkPack, mkUnpack, + splat, mkPack, mkUnpack, extract, mkExtract, insert, mkInsert, shuffle, @@ -131,7 +131,7 @@ import qualified Data.Primitive.Vec as Prim import Data.Kind import Data.Text.Lazy.Builder -import Formatting +import Formatting hiding ( splat ) import GHC.Prim import GHC.TypeLits @@ -1001,10 +1001,12 @@ indexTail :: (Elt sh, Elt a) => Exp (sh :. a) -> Exp sh indexTail (Exp x) = mkExp $ Prj PairIdxLeft x -mkUnpack :: forall n a. (SIMD n a, Elt a) => Exp (Vec n a) -> [Exp a] -mkUnpack v = - let n = fromIntegral (natVal' (proxy# :: Proxy# n)) :: Word8 - in map (extract v . constant) [0 .. n-1] +-- | Fill all lanes of a SIMD vector with the given value +-- +-- @since 1.4.0.0 +-- +splat :: (SIMD n a, Elt a) => Exp a -> Exp (Vec n a) +splat x = mkPack (repeat x) mkPack :: forall n a. (SIMD n a, Elt a) => [Exp a] -> Exp (Vec n a) mkPack xs = @@ -1016,6 +1018,11 @@ mkPack xs = in go 0 (take n xs) undef +mkUnpack :: forall n a. (SIMD n a, Elt a) => Exp (Vec n a) -> [Exp a] +mkUnpack v = + let n = fromIntegral (natVal' (proxy# :: Proxy# n)) :: Word8 + in map (extract v . constant) [0 .. n-1] + -- | Extract a single scalar element from the given SIMD vector at the -- specified index -- From 3ddaed9f294116318eede0589fb764cd7bea0905 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 3 Oct 2022 17:09:38 +0200 Subject: [PATCH 57/86] fix FromBool --- src/Data/Array/Accelerate/Classes/FromBool.hs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/Data/Array/Accelerate/Classes/FromBool.hs b/src/Data/Array/Accelerate/Classes/FromBool.hs index c680275b0..59fe87f5c 100644 --- a/src/Data/Array/Accelerate/Classes/FromBool.hs +++ b/src/Data/Array/Accelerate/Classes/FromBool.hs @@ -22,7 +22,10 @@ import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type +import Data.Array.Accelerate.Classes.Integral + import Language.Haskell.TH hiding ( Exp ) +import Prelude hiding ( Integral ) -- | Convert from Bool to integral types From 8a09b111963954f13d64b55f2521f2a73a71511d Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 3 Oct 2022 23:44:20 +0200 Subject: [PATCH 58/86] add vand, vor operators --- src/Data/Array/Accelerate.hs | 1 + src/Data/Array/Accelerate/Classes/Eq.hs | 26 ++------------- src/Data/Array/Accelerate/Smart.hs | 44 +++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 24 deletions(-) diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index ba1d2711a..348c67a25 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -382,6 +382,7 @@ module Data.Array.Accelerate ( -- *** Logical operations (&&), (||), not, (&&!), (||!), + vand, vor, -- *** Numeric operations subtract, even, odd, gcd, lcm, (^), (^^), diff --git a/src/Data/Array/Accelerate/Classes/Eq.hs b/src/Data/Array/Accelerate/Classes/Eq.hs index 498ea501e..a57142888 100644 --- a/src/Data/Array/Accelerate/Classes/Eq.hs +++ b/src/Data/Array/Accelerate/Classes/Eq.hs @@ -212,28 +212,6 @@ instance Eq Ordering where x /= y = mkCoerce x /= (mkCoerce y :: Exp TAG) instance VEq n a => Eq (Vec n a) where - (==) = vcmp (==*) - (/=) = vcmp (/=*) - -vcmp :: forall n a. KnownNat n - => (Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool)) - -> (Exp (Vec n a) -> Exp (Vec n a) -> Exp Bool) -vcmp op x y = - let n :: Int - n = fromInteger $ natVal' (proxy# :: Proxy# n) - v = op x y - -- - cmp :: forall t. (Elt t, Num t, Bits t, IsScalar (EltR t), IsIntegral (EltR t), BitOrMask (EltR t) ~ Bit) - => Exp (Vec n Bool) - -> Exp Bool - cmp u = - let u' = mkPrimUnary (PrimFromBool bitType integralType) u :: Exp t - in mkEq (constant ((1 `unsafeShiftL` n) - 1)) u' - in - if n P.<= 8 then cmp @Word8 v else - if n P.<= 16 then cmp @Word16 v else - if n P.<= 32 then cmp @Word32 v else - if n P.<= 64 then cmp @Word64 v else - if n P.<= 128 then cmp @Word128 v else - internalError "Can not handle Vec types with more than 128 lanes" + x == y = vand (x ==* y) + x /= y = vor (x /=* y) diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 0685ce8c4..61fc09592 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -58,6 +58,7 @@ module Data.Array.Accelerate.Smart ( insert, mkInsert, shuffle, select, + vand, vor, -- ** Smart constructors for primitive functions -- *** Operators from Num @@ -129,6 +130,7 @@ import qualified Data.Array.Accelerate.Sugar.Shape as Sugar import qualified Data.Primitive.Vec as Prim +import Data.Bits ( Bits, unsafeShiftL ) import Data.Kind import Data.Text.Lazy.Builder import Formatting hiding ( splat ) @@ -1229,6 +1231,48 @@ select (Exp mask) (Exp tt) (Exp ff) = Exp $ go (vecR @n @a) tt ff floating _ = error "impossible" +-- | Return 'True' if all lanes of the vector are 'True' +-- +-- @since 1.4.0.0 +-- +vand :: forall n. KnownNat n => Exp (Vec n Bool) -> Exp Bool +vand v = + let n :: Int + n = fromInteger $ natVal' (proxy# :: Proxy# n) + -- + cmp :: forall t. (Elt t, Num t, Bits t, IsScalar (EltR t), IsIntegral (EltR t), BitOrMask (EltR t) ~ Bit) + => Exp Bool + cmp = mkEq (constant ((1 `unsafeShiftL` n) - 1)) + (mkPrimUnary (PrimFromBool bitType integralType) v :: Exp t) + in + if n <= 8 then cmp @Word8 else + if n <= 16 then cmp @Word16 else + if n <= 32 then cmp @Word32 else + if n <= 64 then cmp @Word64 else + if n <= 128 then cmp @Word128 else + internalError "Can not handle Vec types with more than 128 lanes" + +-- | Return 'True' if any lane of the vector is 'True' +-- +-- @since 1.4.0.0 +-- +vor :: forall n. KnownNat n => Exp (Vec n Bool) -> Exp Bool +vor v = + let n :: Int + n = fromInteger $ natVal' (proxy# :: Proxy# n) + -- + cmp :: forall t. (Elt t, Num t, IsScalar (EltR t), IsIntegral (EltR t), BitOrMask (EltR t) ~ Bit) + => Exp Bool + cmp = mkNEq (constant 0) (mkPrimUnary (PrimFromBool bitType integralType) v :: Exp t) + in + if n <= 8 then cmp @Word8 else + if n <= 16 then cmp @Word16 else + if n <= 32 then cmp @Word32 else + if n <= 64 then cmp @Word64 else + if n <= 128 then cmp @Word128 else + internalError "Can not handle Vec types with more than 128 lanes" + + -- Smart constructors for primitive applications -- From eb777139c05bfec70117744fdfe737bfa7634aa3 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 3 Oct 2022 23:45:08 +0200 Subject: [PATCH 59/86] export vnot, &&*, ||* --- src/Data/Array/Accelerate.hs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index 348c67a25..f17b1c356 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -380,8 +380,9 @@ module Data.Array.Accelerate ( sfoldl, -- *** Logical operations - (&&), (||), not, - (&&!), (||!), + not, vnot, + (&&), (&&!), (&&*), + (||), (||!), (||*), vand, vor, -- *** Numeric operations From fe8a0c855e79636cbf518a157acb32e01f670c96 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 3 Oct 2022 23:45:46 +0200 Subject: [PATCH 60/86] fix Ord instance for Vec This should now match the behaviour of the tuple instances --- src/Data/Array/Accelerate/Classes/Ord.hs | 30 ++++++------------------ 1 file changed, 7 insertions(+), 23 deletions(-) diff --git a/src/Data/Array/Accelerate/Classes/Ord.hs b/src/Data/Array/Accelerate/Classes/Ord.hs index 82a1c11ef..dc53675a7 100644 --- a/src/Data/Array/Accelerate/Classes/Ord.hs +++ b/src/Data/Array/Accelerate/Classes/Ord.hs @@ -30,9 +30,9 @@ module Data.Array.Accelerate.Classes.Ord ( ) where -import Data.Array.Accelerate.AST ( PrimFun(..), BitOrMask ) import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Classes.Eq +import Data.Array.Accelerate.Classes.VEq import Data.Array.Accelerate.Error import Data.Array.Accelerate.Pattern.Ordering import Data.Array.Accelerate.Pattern.Tuple @@ -44,16 +44,12 @@ import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type import {-# SOURCE #-} Data.Array.Accelerate.Classes.VOrd -import Data.Bits import Data.Char import Language.Haskell.TH.Extra hiding ( Exp ) import Prelude ( ($), Num(..), Maybe(..), String, show, error, unlines, return, concat, map, mapM ) import Text.Printf import qualified Prelude as P -import GHC.Exts -import GHC.TypeLits - infix 4 < infix 4 > @@ -245,25 +241,13 @@ instance VOrd n a => Ord (Vec n a) where (<=) = vcmp (<=*) (>=) = vcmp (>=*) -vcmp :: forall n a. KnownNat n +vcmp :: forall n a. VOrd n a => (Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool)) -> (Exp (Vec n a) -> Exp (Vec n a) -> Exp Bool) -vcmp op x y = - let n :: Int - n = fromInteger $ natVal' (proxy# :: Proxy# n) - v = op x y - -- - cmp :: forall t. (Elt t, Num t, Bits t, IsScalar (EltR t), IsIntegral (EltR t), BitOrMask (EltR t) ~ Bit) - => Exp (Vec n Bool) - -> Exp Bool - cmp u = - let u' = mkPrimUnary (PrimFromBool bitType integralType) u :: Exp t - in mkEq (constant ((1 `unsafeShiftL` n) - 1)) u' +vcmp cmp x y = + let go [u] [_] = u + go (u:us) (v:vs) = u || (v && go us vs) + go _ _ = internalError "unexpected vector encoding" in - if n P.<= 8 then cmp @Word8 v else - if n P.<= 16 then cmp @Word16 v else - if n P.<= 32 then cmp @Word32 v else - if n P.<= 64 then cmp @Word64 v else - if n P.<= 128 then cmp @Word128 v else - internalError "Can not handle Vec types with more than 128 lanes" + go (mkUnpack (cmp x y)) (mkUnpack (x ==* y)) From 4bebdae25658a4db3b023c27449becca45e8dda7 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Tue, 16 May 2023 19:02:46 +0200 Subject: [PATCH 61/86] arbitrary width signed and unsigned integers --- src/Data/Array/Accelerate/Analysis/Hash.hs | 12 +- src/Data/Array/Accelerate/Analysis/Match.hs | 14 +-- src/Data/Array/Accelerate/Type.hs | 124 +++++++++----------- 3 files changed, 63 insertions(+), 87 deletions(-) diff --git a/src/Data/Array/Accelerate/Analysis/Hash.hs b/src/Data/Array/Accelerate/Analysis/Hash.hs index 048596014..2b03d7d97 100644 --- a/src/Data/Array/Accelerate/Analysis/Hash.hs +++ b/src/Data/Array/Accelerate/Analysis/Hash.hs @@ -498,16 +498,8 @@ encodeIntegralType (SingleIntegralType t) = encodeSingleIntegralType t encodeIntegralType (VectorIntegralType n t) = intHost $(hashQ "Vec") <> int8 (fromIntegral (natVal' n)) <> encodeSingleIntegralType t encodeSingleIntegralType :: SingleIntegralType t -> Builder -encodeSingleIntegralType TypeInt8 = intHost $(hashQ "Int8") -encodeSingleIntegralType TypeInt16 = intHost $(hashQ "Int16") -encodeSingleIntegralType TypeInt32 = intHost $(hashQ "Int32") -encodeSingleIntegralType TypeInt64 = intHost $(hashQ "Int64") -encodeSingleIntegralType TypeInt128 = intHost $(hashQ "Int128") -encodeSingleIntegralType TypeWord8 = intHost $(hashQ "Word8") -encodeSingleIntegralType TypeWord16 = intHost $(hashQ "Word16") -encodeSingleIntegralType TypeWord32 = intHost $(hashQ "Word32") -encodeSingleIntegralType TypeWord64 = intHost $(hashQ "Word64") -encodeSingleIntegralType TypeWord128 = intHost $(hashQ "Word128") +encodeSingleIntegralType (TypeInt n) = intHost $(hashQ "Int") <> intHost n +encodeSingleIntegralType (TypeWord n) = intHost $(hashQ "Word") <> intHost n encodeFloatingType :: FloatingType t -> Builder encodeFloatingType (SingleFloatingType t) = encodeSingleFloatingType t diff --git a/src/Data/Array/Accelerate/Analysis/Match.hs b/src/Data/Array/Accelerate/Analysis/Match.hs index bb6833a86..a76e342b0 100644 --- a/src/Data/Array/Accelerate/Analysis/Match.hs +++ b/src/Data/Array/Accelerate/Analysis/Match.hs @@ -943,17 +943,9 @@ matchIntegralType _ _ = Nothing {-# INLINEABLE matchSingleIntegralType #-} matchSingleIntegralType :: SingleIntegralType s -> SingleIntegralType t -> Maybe (s :~: t) -matchSingleIntegralType TypeInt8 TypeInt8 = Just Refl -matchSingleIntegralType TypeInt64 TypeInt64 = Just Refl -matchSingleIntegralType TypeInt32 TypeInt32 = Just Refl -matchSingleIntegralType TypeInt16 TypeInt16 = Just Refl -matchSingleIntegralType TypeInt128 TypeInt128 = Just Refl -matchSingleIntegralType TypeWord8 TypeWord8 = Just Refl -matchSingleIntegralType TypeWord64 TypeWord64 = Just Refl -matchSingleIntegralType TypeWord32 TypeWord32 = Just Refl -matchSingleIntegralType TypeWord16 TypeWord16 = Just Refl -matchSingleIntegralType TypeWord128 TypeWord128 = Just Refl -matchSingleIntegralType _ _ = Nothing +matchSingleIntegralType (TypeInt n) (TypeInt m) | m == n = Just (unsafeCoerce Refl) +matchSingleIntegralType (TypeWord n) (TypeWord m) | m == n = Just (unsafeCoerce Refl) +matchSingleIntegralType _ _ = Nothing {-# INLINEABLE matchFloatingType #-} matchFloatingType :: FloatingType s -> FloatingType t -> Maybe (s :~: t) diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index acfe32e0a..4cc0d4dee 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -7,6 +7,7 @@ {-# LANGUAGE MagicHash #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RoleAnnotations #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} @@ -66,14 +67,16 @@ import Data.Numeric.Float128 import Data.Bits import Data.Int +import Data.Kind import Data.Type.Equality import Data.WideWord.Int128 import Data.WideWord.Word128 import Data.Word import Formatting -import Language.Haskell.TH.Extra +import Language.Haskell.TH.Extra hiding ( Type ) import Numeric.Half import Text.Printf +import Unsafe.Coerce import GHC.Prim import GHC.TypeNats @@ -105,17 +108,14 @@ data IntegralType a where SingleIntegralType :: SingleIntegralType a -> IntegralType a VectorIntegralType :: KnownNat n => Proxy# n -> SingleIntegralType a -> IntegralType (Vec n a) -data SingleIntegralType a where - TypeInt8 :: SingleIntegralType Int8 - TypeInt16 :: SingleIntegralType Int16 - TypeInt32 :: SingleIntegralType Int32 - TypeInt64 :: SingleIntegralType Int64 - TypeInt128 :: SingleIntegralType Int128 - TypeWord8 :: SingleIntegralType Word8 - TypeWord16 :: SingleIntegralType Word16 - TypeWord32 :: SingleIntegralType Word32 - TypeWord64 :: SingleIntegralType Word64 - TypeWord128 :: SingleIntegralType Word128 +-- Note: [Arbitrary width integers] +-- +-- We support arbitrary width signed and unsigned integers, but for almost all +-- cases you should use the type synonyms generated below. +-- +data SingleIntegralType :: Type -> Type where + TypeInt :: {-# UNPACK #-} !Int -> SingleIntegralType a + TypeWord :: {-# UNPACK #-} !Int -> SingleIntegralType a data FloatingType a where SingleFloatingType :: SingleFloatingType a -> FloatingType a @@ -129,6 +129,44 @@ data SingleFloatingType a where TypeFloat64 :: SingleFloatingType Float64 TypeFloat128 :: SingleFloatingType Float128 + +typeIntBits :: SingleIntegralType a -> Maybe Int +typeIntBits (TypeInt x) = Just x +typeIntBits _ = Nothing + +typeWordBits :: SingleIntegralType a -> Maybe Int +typeWordBits (TypeWord x) = Just x +typeWordBits _ = Nothing + +-- Generate pattern synonyms for fixed sized signed and unsigned integers. In +-- practice this is what we'll use most of the time, but occasionally we need to +-- convert via an arbitrary width integer (e.g. coercing between a BitMask and +-- an integral value). +-- +-- SEE: [Arbitrary width integers] +-- +runQ $ do + let + integralTypes :: [Integer] + integralTypes = [8,16,32,64,128] + + mkIntegral :: String -> Integer -> Q [Dec] + mkIntegral name bits = do + let t = conT $ mkName $ printf "%s%d" name bits + e = varE $ mkName $ printf "type%sBits" name + c = mkName $ printf "Type%s%d" name bits + -- + a <- newName "a" + sequence [ patSynSigD c (forallT [plainInvisTV a specifiedSpec] (return [TupleT 0]) [t| () => ($(varT a) ~ $t) => SingleIntegralType $(varT a) |]) + , patSynD c (prefixPatSyn []) (explBidir [clause [] (normalB [| $(conE (mkName (printf "Type%s" name))) $(litE (integerL bits)) |]) [] ]) + [p| (\x -> ($e x, unsafeCoerce Refl) -> (Just $(litP (integerL bits)), Refl :: $(varT a) :~: $t)) |] + ] + -- + cs <- pragCompleteD [ mkName (printf "Type%s%d" name bits) | name <- ["Int", "Word" :: String], bits <- integralTypes ] Nothing + ss <- mapM (mkIntegral "Int") integralTypes + us <- mapM (mkIntegral "Word") integralTypes + return (cs : concat ss ++ concat us) + instance Show (IntegralType a) where show (SingleIntegralType t) = show t show (VectorIntegralType n t) = printf "<%d x %s>" (natVal' n) (show t) @@ -138,16 +176,8 @@ instance Show (FloatingType a) where show (VectorFloatingType n t) = printf "<%d x %s>" (natVal' n) (show t) instance Show (SingleIntegralType a) where - show TypeInt8 = "Int8" - show TypeInt16 = "Int16" - show TypeInt32 = "Int32" - show TypeInt64 = "Int64" - show TypeInt128 = "Int128" - show TypeWord8 = "Word8" - show TypeWord16 = "Word16" - show TypeWord32 = "Word32" - show TypeWord64 = "Word64" - show TypeWord128 = "Word128" + show (TypeInt n) = printf "Int%d" n + show (TypeWord n) = printf "Word%d" n instance Show (SingleFloatingType a) where show TypeFloat16 = "Float16" @@ -174,16 +204,8 @@ formatIntegralType = later $ \case formatSingleIntegralType :: Format r (SingleIntegralType a -> r) formatSingleIntegralType = later $ \case - TypeInt8 -> "Int8" - TypeInt16 -> "Int16" - TypeInt32 -> "Int32" - TypeInt64 -> "Int64" - TypeInt128 -> "Int128" - TypeWord8 -> "Word8" - TypeWord16 -> "Word16" - TypeWord32 -> "Word32" - TypeWord64 -> "Word64" - TypeWord128 -> "Word128" + TypeInt n -> bformat ("Int" % int) n + TypeWord n -> bformat ("Word" % int) n formatFloatingType :: Format r (FloatingType a -> r) formatFloatingType = later $ \case @@ -230,16 +252,8 @@ rnfIntegralType (SingleIntegralType t) = rnfSingleIntegralType t rnfIntegralType (VectorIntegralType !_ t) = rnfSingleIntegralType t rnfSingleIntegralType :: SingleIntegralType t -> () -rnfSingleIntegralType TypeInt8 = () -rnfSingleIntegralType TypeInt16 = () -rnfSingleIntegralType TypeInt32 = () -rnfSingleIntegralType TypeInt64 = () -rnfSingleIntegralType TypeInt128 = () -rnfSingleIntegralType TypeWord8 = () -rnfSingleIntegralType TypeWord16 = () -rnfSingleIntegralType TypeWord32 = () -rnfSingleIntegralType TypeWord64 = () -rnfSingleIntegralType TypeWord128 = () +rnfSingleIntegralType (TypeInt !_) = () +rnfSingleIntegralType (TypeWord !_) = () rnfFloatingType :: FloatingType t -> () rnfFloatingType (SingleFloatingType t) = rnfSingleFloatingType t @@ -308,16 +322,8 @@ liftIntegralType (SingleIntegralType t) = [|| SingleIntegralType $$(liftSingle liftIntegralType (VectorIntegralType _ t) = [|| VectorIntegralType proxy# $$(liftSingleIntegralType t) ||] liftSingleIntegralType :: SingleIntegralType t -> CodeQ (SingleIntegralType t) -liftSingleIntegralType TypeInt8 = [|| TypeInt8 ||] -liftSingleIntegralType TypeInt16 = [|| TypeInt16 ||] -liftSingleIntegralType TypeInt32 = [|| TypeInt32 ||] -liftSingleIntegralType TypeInt64 = [|| TypeInt64 ||] -liftSingleIntegralType TypeInt128 = [|| TypeInt128 ||] -liftSingleIntegralType TypeWord8 = [|| TypeWord8 ||] -liftSingleIntegralType TypeWord16 = [|| TypeWord16 ||] -liftSingleIntegralType TypeWord32 = [|| TypeWord32 ||] -liftSingleIntegralType TypeWord64 = [|| TypeWord64 ||] -liftSingleIntegralType TypeWord128 = [|| TypeWord128 ||] +liftSingleIntegralType (TypeInt n) = [|| TypeInt n ||] +liftSingleIntegralType (TypeWord n) = [|| TypeWord n ||] liftFloatingType :: FloatingType t -> CodeQ (FloatingType t) liftFloatingType (SingleFloatingType t) = [|| SingleFloatingType $$(liftSingleFloatingType t) ||] @@ -354,24 +360,11 @@ class IsSingleIntegral a where class IsSingleFloating a where singleFloatingType :: SingleFloatingType a --- Type-level bit sizes --- -------------------- - -- | Constraint that values of these two types have the same bit width -- type BitSizeEq a b = (BitSize a == BitSize b) ~ 'True type family BitSize a :: Nat - --- Instances --- --------- --- --- Generate instances for the IsX classes. It would be preferable to do this --- automatically based on the members of the IntegralType (etc.) representations --- (see for example FromIntegral.hs) but TH phase restrictions would require us --- to split this into a separate module. --- - runQ $ do let integralTypes :: [Integer] @@ -466,7 +459,6 @@ instance IsBit Bit where instance KnownNat n => IsBit (Vec n Bit) where bitType = TypeMask proxy# - -- Determine the underlying type of a Haskell Int and Word -- runQ [d| type INT = $( From b70d5cc4e8f2ac4a81f8b5528c20cd909b96279f Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Tue, 16 May 2023 19:06:02 +0200 Subject: [PATCH 62/86] actually unsafe coerce --- src/Data/Array/Accelerate/Analysis/Match.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Data/Array/Accelerate/Analysis/Match.hs b/src/Data/Array/Accelerate/Analysis/Match.hs index a76e342b0..1bbbbd018 100644 --- a/src/Data/Array/Accelerate/Analysis/Match.hs +++ b/src/Data/Array/Accelerate/Analysis/Match.hs @@ -535,7 +535,7 @@ matchOpenExp (Case eR1 e1 rhs1 def1) (Case eR2 e2 rhs2 def2) where matchCaseEqs :: TagType tag -> [(tag, OpenExp env aenv a)] -> [(tag, OpenExp env aenv b)] -> Maybe (a :~: b) matchCaseEqs _ [] [] - = unsafeCoerce Refl + = Just (unsafeCoerce Refl) matchCaseEqs tR ((s,x):xs) ((t,y):ys) | TagDict <- tagDict tR , s == t @@ -546,7 +546,7 @@ matchOpenExp (Case eR1 e1 rhs1 def1) (Case eR2 e2 rhs2 def2) = Nothing matchCaseDef :: Maybe (OpenExp env aenv a) -> Maybe (OpenExp env aenv b) -> Maybe (a :~: b) - matchCaseDef Nothing Nothing = unsafeCoerce Refl + matchCaseDef Nothing Nothing = Just (unsafeCoerce Refl) matchCaseDef (Just x) (Just y) = matchOpenExp x y matchCaseDef _ _ = Nothing From d50e6de73a6436826331246f83355e428f6480e5 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Tue, 16 May 2023 19:08:32 +0200 Subject: [PATCH 63/86] rename Coerce to Bitcast --- src/Data/Array/Accelerate/AST.hs | 10 ++++----- src/Data/Array/Accelerate/Analysis/Hash.hs | 2 +- src/Data/Array/Accelerate/Analysis/Match.hs | 2 +- .../Array/Accelerate/Classes/RealFloat.hs | 2 +- src/Data/Array/Accelerate/Interpreter.hs | 12 +++++----- src/Data/Array/Accelerate/Pretty/Graphviz.hs | 2 +- src/Data/Array/Accelerate/Pretty/Print.hs | 2 +- src/Data/Array/Accelerate/Smart.hs | 10 ++++----- src/Data/Array/Accelerate/Trafo/Fusion.hs | 2 +- src/Data/Array/Accelerate/Trafo/Sharing.hs | 6 ++--- src/Data/Array/Accelerate/Trafo/Shrink.hs | 22 +++++++++---------- src/Data/Array/Accelerate/Trafo/Simplify.hs | 4 ++-- .../Array/Accelerate/Trafo/Substitution.hs | 4 ++-- 13 files changed, 40 insertions(+), 40 deletions(-) diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 7532daf6f..97b4ae93e 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -653,7 +653,7 @@ data OpenExp env aenv t where -- The types must have the same bit size, but that constraint is not include -- at this point because GHC's typelits solver is often not powerful enough to -- discharge that constraint. ---TLM 2022-09-20 - Coerce :: ScalarType a + Bitcast :: ScalarType a -> ScalarType b -> OpenExp env aenv a -> OpenExp env aenv b @@ -862,7 +862,7 @@ expType = \case Shape (Var repr _) -> shapeType $ arrayRshape repr ShapeSize{} -> TupRsingle (scalarType @INT) Undef tR -> TupRsingle tR - Coerce _ tR _ -> TupRsingle tR + Bitcast _ tR _ -> TupRsingle tR primFunType :: PrimFun (a -> b) -> (TypeR a, TypeR b) @@ -1140,7 +1140,7 @@ rnfOpenExp topExp = LinearIndex a ix -> rnfArrayVar a `seq` rnfE ix Shape a -> rnfArrayVar a ShapeSize shr sh -> rnfShapeR shr `seq` rnfE sh - Coerce t1 t2 e -> rnfScalarType t1 `seq` rnfScalarType t2 `seq` rnfE e + Bitcast t1 t2 e -> rnfScalarType t1 `seq` rnfScalarType t2 `seq` rnfE e rnfExpVar :: ExpVar env t -> () rnfExpVar = rnfVar rnfScalarType @@ -1358,7 +1358,7 @@ liftOpenExp pexp = LinearIndex a ix -> [|| LinearIndex $$(liftArrayVar a) $$(liftE ix) ||] Shape a -> [|| Shape $$(liftArrayVar a) ||] ShapeSize shr ix -> [|| ShapeSize $$(liftShapeR shr) $$(liftE ix) ||] - Coerce t1 t2 e -> [|| Coerce $$(liftScalarType t1) $$(liftScalarType t2) $$(liftE e) ||] + Bitcast t1 t2 e -> [|| Bitcast $$(liftScalarType t1) $$(liftScalarType t2) $$(liftE e) ||] liftELeftHandSide :: ELeftHandSide t env env' -> CodeQ (ELeftHandSide t env env') liftELeftHandSide = liftLeftHandSide liftScalarType @@ -1503,5 +1503,5 @@ formatExpOp = later $ \case LinearIndex{} -> "LinearIndex" Shape{} -> "Shape" ShapeSize{} -> "ShapeSize" - Coerce{} -> "Coerce" + Bitcast{} -> "Coerce" diff --git a/src/Data/Array/Accelerate/Analysis/Hash.hs b/src/Data/Array/Accelerate/Analysis/Hash.hs index 2b03d7d97..1fd54f0b3 100644 --- a/src/Data/Array/Accelerate/Analysis/Hash.hs +++ b/src/Data/Array/Accelerate/Analysis/Hash.hs @@ -341,7 +341,7 @@ encodeOpenExp exp = Shape a -> intHost $(hashQ "Shape") <> encodeArrayVar a ShapeSize _ sh -> intHost $(hashQ "ShapeSize") <> travE sh Foreign _ _ f e -> intHost $(hashQ "Foreign") <> travF f <> travE e - Coerce _ tp e -> intHost $(hashQ "Coerce") <> encodeScalarType tp <> travE e + Bitcast _ tp e -> intHost $(hashQ "Bitcast") <> encodeScalarType tp <> travE e encodeArrayVar :: ArrayVar aenv a -> Builder encodeArrayVar (Var repr v) = encodeArrayType repr <> encodeIdx v diff --git a/src/Data/Array/Accelerate/Analysis/Match.hs b/src/Data/Array/Accelerate/Analysis/Match.hs index 1bbbbd018..147f93adb 100644 --- a/src/Data/Array/Accelerate/Analysis/Match.hs +++ b/src/Data/Array/Accelerate/Analysis/Match.hs @@ -598,7 +598,7 @@ matchOpenExp (ShapeSize _ sh1) (ShapeSize _ sh2) matchOpenExp (Undef t1) (Undef t2) = matchScalarType t1 t2 -matchOpenExp (Coerce _ t1 e1) (Coerce _ t2 e2) +matchOpenExp (Bitcast _ t1 e1) (Bitcast _ t2 e2) | Just Refl <- matchScalarType t1 t2 , Just Refl <- matchOpenExp e1 e2 = Just Refl diff --git a/src/Data/Array/Accelerate/Classes/RealFloat.hs b/src/Data/Array/Accelerate/Classes/RealFloat.hs index 4ad290078..2cbbee5d3 100644 --- a/src/Data/Array/Accelerate/Classes/RealFloat.hs +++ b/src/Data/Array/Accelerate/Classes/RealFloat.hs @@ -309,7 +309,7 @@ mkBitcast' :: forall b a n. (IsScalar (VecR n a), IsScalar (VecR n b), BitSizeEq (EltR a) (EltR b)) => Exp (Vec n a) -> Exp (Vec n b) -mkBitcast' (Exp a) = mkExp $ Coerce (scalarType @(VecR n a)) (scalarType @(VecR n b)) a +mkBitcast' (Exp a) = mkExp $ Bitcast (scalarType @(VecR n a)) (scalarType @(VecR n b)) a defaultFloatRadix :: forall a. P.RealFloat a => Exp a -> Exp Int defaultFloatRadix _ = P.fromInteger (P.floatRadix (undefined::a)) diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index bd5a4c2d5..60a2d55bb 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -40,7 +40,7 @@ module Data.Array.Accelerate.Interpreter ( run, run1, runN, -- Internal (hidden) - evalPrim, evalCoerceScalar, atraceOp, + evalPrim, evalBitcastScalar, atraceOp, ) where @@ -986,7 +986,7 @@ evalOpenExp pexp env aenv = Shape acc -> shape $ snd $ evalA acc ShapeSize shr sh -> size shr (evalE sh) Foreign _ _ f e -> evalOpenFun f Empty Empty $ evalE e - Coerce t1 t2 e -> evalCoerceScalar t1 t2 (evalE e) + Bitcast t1 t2 e -> evalBitcastScalar t1 t2 (evalE e) -- Coercions @@ -995,8 +995,8 @@ evalOpenExp pexp env aenv = -- Coercion between two scalar types. We require that the size of the source and -- destination values are equal (this is not checked at this point). -- -evalCoerceScalar :: ScalarType a -> ScalarType b -> a -> b -evalCoerceScalar = scalar +evalBitcastScalar :: ScalarType a -> ScalarType b -> a -> b +evalBitcastScalar = scalar where scalar :: ScalarType a -> ScalarType b -> a -> b scalar (NumScalarType a) = num a @@ -1005,14 +1005,14 @@ evalCoerceScalar = scalar bit :: BitType a -> ScalarType b -> a -> b bit TypeBit = \case BitScalarType TypeBit -> id - _ -> internalError "evalCoerceScalar @Bit" + _ -> internalError "evalBitcastScalar @Bit" bit (TypeMask _) = \case NumScalarType b -> num' b BitScalarType b -> bit' b where bit' :: BitType b -> Vec n Bit -> b bit' TypeMask{} = unsafeCoerce - bit' TypeBit = internalError "evalCoerceScalar @Bit" + bit' TypeBit = internalError "evalBitcastScalar @Bit" num' :: NumType b -> Vec n Bit -> b num' (IntegralNumType b) = integral' b diff --git a/src/Data/Array/Accelerate/Pretty/Graphviz.hs b/src/Data/Array/Accelerate/Pretty/Graphviz.hs index 93b9af69b..6bf0b6b21 100644 --- a/src/Data/Array/Accelerate/Pretty/Graphviz.hs +++ b/src/Data/Array/Accelerate/Pretty/Graphviz.hs @@ -535,5 +535,5 @@ fvOpenExp env aenv = fv fv (Case _ e rhs def) = concat [ fv e, concat [ fv c | (_,c) <- rhs ], maybe [] fv def ] fv (Cond p t e) = concat [ fv p, fv t, fv e ] fv (While p f x) = concat [ fvF p, fvF f, fv x ] - fv (Coerce _ _ e) = fv e + fv (Bitcast _ _ e) = fv e diff --git a/src/Data/Array/Accelerate/Pretty/Print.hs b/src/Data/Array/Accelerate/Pretty/Print.hs index 74b3e9aa5..1bf370d91 100644 --- a/src/Data/Array/Accelerate/Pretty/Print.hs +++ b/src/Data/Array/Accelerate/Pretty/Print.hs @@ -444,7 +444,7 @@ prettyOpenExp ctx env aenv exp = ShapeSize _ sh -> ppF1 "shapeSize" (ppE sh) Index arr ix -> ppF2 (Operator (pretty '!') Infix L 9) (ppA arr) (ppE ix) LinearIndex arr ix -> ppF2 (Operator "!!" Infix L 9) (ppA arr) (ppE ix) - Coerce _ tp x -> ppF1 (Operator (withTypeRep tp "coerce") App L 10) (ppE x) + Bitcast _ tp x -> ppF1 (Operator (withTypeRep tp "bitcast") App L 10) (ppE x) Undef tp -> withTypeRep tp "undef" where diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 61fc09592..2fb6c48a1 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -618,7 +618,7 @@ data PreSmartExp acc exp t where Undef :: ScalarType t -> PreSmartExp acc exp t - Coerce :: ScalarType a + Bitcast :: ScalarType a -> ScalarType b -> exp a -> PreSmartExp acc exp b @@ -925,7 +925,7 @@ instance HasTypeR exp => HasTypeR (PreSmartExp acc exp) where ShapeSize _ _ -> TupRsingle (scalarType @INT) Foreign tR _ _ _ -> tR Undef tR -> TupRsingle tR - Coerce _ tR _ -> TupRsingle tR + Bitcast _ tR _ -> TupRsingle tR -- Smart constructors @@ -1497,7 +1497,7 @@ mkFromBool = mkPrimUnary $ PrimFromBool bitType integralType -- Other conversions mkBitcast :: forall b a. (IsScalar (EltR a), IsScalar (EltR b), BitSizeEq (EltR a) (EltR b)) => Exp a -> Exp b -mkBitcast (Exp a) = mkExp $ Coerce (scalarType @(EltR a)) (scalarType @(EltR b)) a +mkBitcast (Exp a) = mkExp $ Bitcast (scalarType @(EltR a)) (scalarType @(EltR b)) a mkCoerce :: Coerce (EltR a) (EltR b) => Exp a -> Exp b mkCoerce (Exp a) = Exp $ mkCoerce' a @@ -1506,7 +1506,7 @@ class Coerce a b where mkCoerce' :: SmartExp a -> SmartExp b instance {-# OVERLAPS #-} (IsScalar a, IsScalar b, BitSizeEq a b) => Coerce a b where - mkCoerce' = SmartExp . Coerce (scalarType @a) (scalarType @b) + mkCoerce' = SmartExp . Bitcast (scalarType @a) (scalarType @b) instance (Coerce a1 b1, Coerce a2 b2) => Coerce (a1, a2) (b1, b2) where mkCoerce' a = SmartExp $ Pair (mkCoerce' $ SmartExp $ Prj PairIdxLeft a) (mkCoerce' $ SmartExp $ Prj PairIdxRight a) @@ -1669,5 +1669,5 @@ formatPreExpOp = later $ \case Shape{} -> "Shape" ShapeSize{} -> "ShapeSize" Foreign{} -> "Foreign" - Coerce{} -> "Coerce" + Bitcast{} -> "Bitcast" diff --git a/src/Data/Array/Accelerate/Trafo/Fusion.hs b/src/Data/Array/Accelerate/Trafo/Fusion.hs index 8f82f6352..7b0f4fd3e 100644 --- a/src/Data/Array/Accelerate/Trafo/Fusion.hs +++ b/src/Data/Array/Accelerate/Trafo/Fusion.hs @@ -1485,7 +1485,7 @@ aletD' embedAcc elimAcc (LeftHandSideSingle ArrayR{}) (Embed env1 cc1) (Embed en PrimApp g x -> PrimApp g (cvtE x) ShapeSize shR' sh -> ShapeSize shR' (cvtE sh) While p f x -> While (replaceF sh' f' avar p) (replaceF sh' f' avar f) (cvtE x) - Coerce t1 t2 e -> Coerce t1 t2 (cvtE e) + Bitcast t1 t2 e -> Bitcast t1 t2 (cvtE e) Shape a | Just Refl <- matchVar a avar -> Stats.substitution "replaceE/shape" sh' diff --git a/src/Data/Array/Accelerate/Trafo/Sharing.hs b/src/Data/Array/Accelerate/Trafo/Sharing.hs index 798e7abd7..071ca6dac 100644 --- a/src/Data/Array/Accelerate/Trafo/Sharing.hs +++ b/src/Data/Array/Accelerate/Trafo/Sharing.hs @@ -773,7 +773,7 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp Shape _ a -> AST.Shape (cvtAvar a) ShapeSize shr e -> AST.ShapeSize shr (cvt e) Foreign repr ff f e -> AST.Foreign repr ff (convertSmartFun config (typeR e) f) (cvt e) - Coerce t1 t2 e -> AST.Coerce t1 t2 (cvt e) + Bitcast t1 t2 e -> AST.Bitcast t1 t2 (cvt e) cvtPrj :: forall a b c env1 aenv1. PairIdx (a, b) c -> AST.OpenExp env1 aenv1 (a, b) -> AST.OpenExp env1 aenv1 c cvtPrj PairIdxLeft (AST.Pair a _) = a @@ -1906,7 +1906,7 @@ makeOccMapSharingExp config accOccMap expOccMap = travE Foreign tp ff f e -> do (e', h) <- travE lvl e return (Foreign tp ff f e', h+1) - Coerce t1 t2 e -> travE1 (Coerce t1 t2) e + Bitcast t1 t2 e -> travE1 (Bitcast t1 t2) e where traverseAcc :: HasCallStack => Level -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int) @@ -2809,7 +2809,7 @@ determineScopesSharingExp config accOccMap expOccMap = scopesExp Shape shr a -> travA (Shape shr) a ShapeSize shr e -> travE1 (ShapeSize shr) e Foreign tp ff f e -> travE1 (Foreign tp ff f) e - Coerce t1 t2 e -> travE1 (Coerce t1 t2) e + Bitcast t1 t2 e -> travE1 (Bitcast t1 t2) e where travE1 :: HasCallStack => (ScopedExp a -> PreSmartExp ScopedAcc ScopedExp t) diff --git a/src/Data/Array/Accelerate/Trafo/Shrink.hs b/src/Data/Array/Accelerate/Trafo/Shrink.hs index d5cab4a95..f2840de2b 100644 --- a/src/Data/Array/Accelerate/Trafo/Shrink.hs +++ b/src/Data/Array/Accelerate/Trafo/Shrink.hs @@ -240,13 +240,13 @@ shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE lIMIT = 1 cheap :: OpenExp env aenv t -> Bool - cheap (Evar _) = True - cheap (Pair e1 e2) = cheap e1 && cheap e2 - cheap Nil = True - cheap Const{} = True - cheap Undef{} = True - cheap (Coerce _ _ e) = cheap e - cheap _ = False + cheap (Evar _) = True + cheap (Pair e1 e2) = cheap e1 && cheap e2 + cheap Nil = True + cheap Const{} = True + cheap Undef{} = True + cheap (Bitcast _ _ e) = cheap e + cheap _ = False shrinkE :: HasCallStack => OpenExp env aenv t -> (Any, OpenExp env aenv t) shrinkE exp = case exp of @@ -307,7 +307,7 @@ shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE Shape a -> pure (Shape a) ShapeSize shr sh -> ShapeSize shr <$> shrinkE sh Foreign repr ff f e -> Foreign repr ff <$> shrinkF f <*> shrinkE e - Coerce t1 t2 e -> Coerce t1 t2 <$> shrinkE e + Bitcast t1 t2 e -> Bitcast t1 t2 <$> shrinkE e shrinkF :: HasCallStack => OpenFun env aenv t -> (Any, OpenFun env aenv t) shrinkF = first Any . shrinkFun @@ -455,7 +455,7 @@ shrinkPreAcc shrinkAcc reduceAcc = Stats.substitution "shrinkA" shrinkA Intersect sh sz -> Intersect (shrinkE sh) (shrinkE sz) Union sh sz -> Union (shrinkE sh) (shrinkE sz) Foreign ff f e -> Foreign ff (shrinkF f) (shrinkE e) - Coerce e -> Coerce (shrinkE e) + Bitcast e -> Bitcast (shrinkE e) shrinkF :: OpenFun env aenv' f -> OpenFun env aenv' f shrinkF (Lam f) = Lam (shrinkF f) @@ -508,7 +508,7 @@ usesOfExp range = countE Shape _ -> Finite 0 ShapeSize _ sh -> countE sh Foreign _ _ _ e -> countE e - Coerce _ _ e -> countE e + Bitcast _ _ e -> countE e usesOfFun :: VarsRange env -> OpenFun env aenv f -> Count usesOfFun range (Lam lhs f) = usesOfFun (weakenVarsRange lhs range) f @@ -598,7 +598,7 @@ usesOfPreAcc withShape countAcc idx = count | withShape -> countAvar a | otherwise -> 0 Foreign _ _ _ e -> countE e - Coerce _ _ e -> countE e + Bitcast _ _ e -> countE e countME :: Maybe (OpenExp env aenv e) -> Int countME = maybe 0 countE diff --git a/src/Data/Array/Accelerate/Trafo/Simplify.hs b/src/Data/Array/Accelerate/Trafo/Simplify.hs index a3788b1e1..c78da38ca 100644 --- a/src/Data/Array/Accelerate/Trafo/Simplify.hs +++ b/src/Data/Array/Accelerate/Trafo/Simplify.hs @@ -296,7 +296,7 @@ simplifyOpenExp env = first getAny . cvtE ShapeSize shr sh -> shapeSize shr (cvtE sh) Foreign tp ff f e -> Foreign tp ff <$> first Any (simplifyOpenFun EmptyExp f) <*> cvtE e While p f x -> While <$> cvtF env p <*> cvtF env f <*> cvtE x - Coerce t1 t2 e -> Coerce t1 t2 <$> cvtE e + Bitcast t1 t2 e -> Bitcast t1 t2 <$> cvtE e cvtE' :: Gamma env' env' aenv -> OpenExp env' aenv e' -> (Any, OpenExp env' aenv e') cvtE' env' = first Any . simplifyOpenExp env' @@ -604,7 +604,7 @@ summariseOpenExp = (terms +~ 1) . goE Shape a -> travA a ShapeSize _ sh -> travE sh PrimApp f x -> travPrimFun f +++ travE x - Coerce _ _ e -> travE e + Bitcast _ _ e -> travE e travPrimFun :: PrimFun f -> Stats travPrimFun = (ops +~ 1) . goF diff --git a/src/Data/Array/Accelerate/Trafo/Substitution.hs b/src/Data/Array/Accelerate/Trafo/Substitution.hs index 72f84e997..d187da63e 100644 --- a/src/Data/Array/Accelerate/Trafo/Substitution.hs +++ b/src/Data/Array/Accelerate/Trafo/Substitution.hs @@ -172,7 +172,7 @@ inlineVars lhsBound expr bound Shape a -> Just $ Shape a ShapeSize shr e1 -> ShapeSize shr <$> travE e1 Undef t -> Just $ Undef t - Coerce t1 t2 e1 -> Coerce t1 t2 <$> travE e1 + Bitcast t1 t2 e1 -> Bitcast t1 t2 <$> travE e1 where travE :: OpenExp env1 aenv s -> Maybe (OpenExp env2 aenv s) @@ -572,7 +572,7 @@ rebuildOpenExp v av@(ReindexAvar reindex) exp = Shape a -> Shape <$> reindex a ShapeSize shr sh -> ShapeSize shr <$> rebuildOpenExp v av sh Foreign tp ff f e -> Foreign tp ff f <$> rebuildOpenExp v av e - Coerce t1 t2 e -> Coerce t1 t2 <$> rebuildOpenExp v av e + Bitcast t1 t2 e -> Bitcast t1 t2 <$> rebuildOpenExp v av e {-# INLINEABLE rebuildFun #-} rebuildFun From 6830362da95b0454e099e93fa7b10d26e4e30c67 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Tue, 16 May 2023 19:09:10 +0200 Subject: [PATCH 64/86] use associated data family --- src/Data/Array/Accelerate/Prelude.hs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/Data/Array/Accelerate/Prelude.hs b/src/Data/Array/Accelerate/Prelude.hs index 690558cf0..6cf549119 100644 --- a/src/Data/Array/Accelerate/Prelude.hs +++ b/src/Data/Array/Accelerate/Prelude.hs @@ -2270,19 +2270,17 @@ instance Arrays a => IfThenElse (Exp Bool) (Acc a) where match :: Matching f => f -> f match f = mkFun (mkMatch f) id -data Args f where - (:->) :: Exp a -> Args b -> Args (Exp a -> b) - Result :: Args (Exp a) - class Matching a where type ResultT a - mkMatch :: a -> Args a -> Exp (ResultT a) - mkFun :: (Args f -> Exp (ResultT a)) - -> (Args a -> Args f) + data ArgsR a + mkMatch :: a -> ArgsR a -> Exp (ResultT a) + mkFun :: (ArgsR f -> Exp (ResultT a)) + -> (ArgsR a -> ArgsR f) -> a instance Elt a => Matching (Exp a) where type ResultT (Exp a) = a + data ArgsR (Exp a) = Result mkFun f k = f (k Result) mkMatch (Exp e) Result = @@ -2292,6 +2290,7 @@ instance Elt a => Matching (Exp a) where instance (Elt e, Matching r) => Matching (Exp e -> r) where type ResultT (Exp e -> r) = ResultT r + data ArgsR (Exp e -> r) = Exp e :-> ArgsR r mkFun f k x = mkFun f (\xs -> k (x :-> xs)) mkMatch f (x@(Exp p) :-> xs) = From 72ef2dfce4244587fbc1b08a87b5bedfb9776a69 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Tue, 16 May 2023 19:09:16 +0200 Subject: [PATCH 65/86] unused imports --- src/Data/Array/Accelerate/Classes/Eq.hs | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/Data/Array/Accelerate/Classes/Eq.hs b/src/Data/Array/Accelerate/Classes/Eq.hs index a57142888..00e3a1d50 100644 --- a/src/Data/Array/Accelerate/Classes/Eq.hs +++ b/src/Data/Array/Accelerate/Classes/Eq.hs @@ -33,7 +33,6 @@ module Data.Array.Accelerate.Classes.Eq ( ) where -import Data.Array.Accelerate.AST ( PrimFun(..), BitOrMask ) import Data.Array.Accelerate.Error import Data.Array.Accelerate.Pattern.Bool import Data.Array.Accelerate.Pattern.Tuple @@ -45,14 +44,12 @@ import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type import {-# SOURCE #-} Data.Array.Accelerate.Classes.VEq -import Data.Bits import Text.Printf import Prelude ( ($), String, Num(..), Ordering(..), show, error, return, concat, map, zipWith, foldr1, mapM ) import Language.Haskell.TH.Extra hiding ( Exp ) import qualified Prelude as P import GHC.Exts -import GHC.TypeLits infix 4 == From 4db654413132e50a7df9de11e479ea5467200321 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Tue, 16 May 2023 19:12:53 +0200 Subject: [PATCH 66/86] fix vand & vor for non-power-of-two vecs --- src/Data/Array/Accelerate/Smart.hs | 31 ++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 2fb6c48a1..ae36e165a 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -130,7 +130,7 @@ import qualified Data.Array.Accelerate.Sugar.Shape as Sugar import qualified Data.Primitive.Vec as Prim -import Data.Bits ( Bits, unsafeShiftL ) +import Data.Bits ( Bits, unsafeShiftL, countLeadingZeros ) import Data.Kind import Data.Text.Lazy.Builder import Formatting hiding ( splat ) @@ -1237,14 +1237,21 @@ select (Exp mask) (Exp tt) (Exp ff) = Exp $ go (vecR @n @a) tt ff -- vand :: forall n. KnownNat n => Exp (Vec n Bool) -> Exp Bool vand v = - let n :: Int + let n, m :: Int n = fromInteger $ natVal' (proxy# :: Proxy# n) + m = max 8 (1 `unsafeShiftL` (64 - countLeadingZeros (n-1))) -- - cmp :: forall t. (Elt t, Num t, Bits t, IsScalar (EltR t), IsIntegral (EltR t), BitOrMask (EltR t) ~ Bit) + cmp :: forall t. (Elt t, Num t, Bits t, IsScalar (EltR t), IsNum (EltR t), BitOrMask (EltR t) ~ Bit) => Exp Bool - cmp = mkEq (constant ((1 `unsafeShiftL` n) - 1)) - (mkPrimUnary (PrimFromBool bitType integralType) v :: Exp t) + cmp = let w = SingleIntegralType (TypeWord n) + b = mkExp $ Bitcast (BitScalarType bitType) (NumScalarType (IntegralNumType w)) (unExp v) + in + mkEq (constant ((1 `unsafeShiftL` n) - 1)) $ + if n == m + then b + else mkPrimUnary (PrimFromIntegral w numType) b :: Exp t in + if n == 1 then mkExp (Bitcast scalarType scalarType (unExp v)) else if n <= 8 then cmp @Word8 else if n <= 16 then cmp @Word16 else if n <= 32 then cmp @Word32 else @@ -1258,13 +1265,21 @@ vand v = -- vor :: forall n. KnownNat n => Exp (Vec n Bool) -> Exp Bool vor v = - let n :: Int + let n, m :: Int n = fromInteger $ natVal' (proxy# :: Proxy# n) + m = max 8 (1 `unsafeShiftL` (64 - countLeadingZeros (n-1))) -- - cmp :: forall t. (Elt t, Num t, IsScalar (EltR t), IsIntegral (EltR t), BitOrMask (EltR t) ~ Bit) + cmp :: forall t. (Elt t, Num t, IsScalar (EltR t), IsNum (EltR t), BitOrMask (EltR t) ~ Bit) => Exp Bool - cmp = mkNEq (constant 0) (mkPrimUnary (PrimFromBool bitType integralType) v :: Exp t) + cmp = let w = SingleIntegralType (TypeWord n) + b = mkExp $ Bitcast (BitScalarType bitType) (NumScalarType (IntegralNumType w)) (unExp v) + in + mkNEq (constant 0) $ + if n == m + then b + else mkPrimUnary (PrimFromIntegral w numType) b :: Exp t in + if n == 1 then mkExp (Bitcast scalarType scalarType (unExp v)) else if n <= 8 then cmp @Word8 else if n <= 16 then cmp @Word16 else if n <= 32 then cmp @Word32 else From 34f05d4752a4b0f747a287fb53b350b6fbb2c016 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 14 Aug 2023 10:57:28 +0200 Subject: [PATCH 67/86] ci: don't run haddock on windows ghc-iserv: munmapForLinker: m32_release_page: Failed to unmap 4096 bytes at ...: Attempt to access invalid address. --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 35bb01eae..9373ba626 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -114,7 +114,7 @@ jobs: - name: Haddock # Behaviour of cabal haddock has changed for the worse: https://github.com/haskell/cabal/issues/8725 run: cabal haddock --disable-documentation - if: matrix.mode == 'release' + if: matrix.os != 'windows-latest' && matrix.mode == 'release' - name: Test doctest run: cabal test doctest From ab6b6fab763f087f7215d2ae8b75a94064cdf487 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 14 Aug 2023 10:59:23 +0200 Subject: [PATCH 68/86] add cc-options -std=c11 --- accelerate.cabal | 1 + 1 file changed, 1 insertion(+) diff --git a/accelerate.cabal b/accelerate.cabal index d49d5c575..6a80e0bbd 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -586,6 +586,7 @@ library cc-options: -O3 -Wall + -std=c11 cxx-options: -O3 From 0dcd2e38242f0febf6922740887f0219df517a7a Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 14 Aug 2023 11:11:05 +0200 Subject: [PATCH 69/86] updates for ghc-9.6 --- src/Data/Array/Accelerate/Classes/Eq.hs | 3 +-- src/Data/Array/Accelerate/Classes/RealFloat.hs | 2 ++ src/Data/Array/Accelerate/Data/Bits.hs | 4 +++- src/Data/Array/Accelerate/Language.hs | 2 +- src/Data/Array/Accelerate/Pattern.hs | 1 + src/Data/Array/Accelerate/Pattern/SIMD.hs | 1 + src/Data/Array/Accelerate/Pattern/Tuple.hs | 1 + src/Data/Array/Accelerate/Prelude.hs | 3 --- .../Accelerate/Test/NoFib/Issues/Issue407.hs | 1 + .../Accelerate/Test/NoFib/Issues/Issue517.hs | 16 +++++++++------- src/Data/Array/Accelerate/Unsafe.hs | 3 +-- 11 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/Data/Array/Accelerate/Classes/Eq.hs b/src/Data/Array/Accelerate/Classes/Eq.hs index 00e3a1d50..b2f71ed87 100644 --- a/src/Data/Array/Accelerate/Classes/Eq.hs +++ b/src/Data/Array/Accelerate/Classes/Eq.hs @@ -45,12 +45,11 @@ import Data.Array.Accelerate.Type import {-# SOURCE #-} Data.Array.Accelerate.Classes.VEq import Text.Printf +import Data.Char import Prelude ( ($), String, Num(..), Ordering(..), show, error, return, concat, map, zipWith, foldr1, mapM ) import Language.Haskell.TH.Extra hiding ( Exp ) import qualified Prelude as P -import GHC.Exts - infix 4 == infix 4 /= diff --git a/src/Data/Array/Accelerate/Classes/RealFloat.hs b/src/Data/Array/Accelerate/Classes/RealFloat.hs index 2cbbee5d3..f753659bc 100644 --- a/src/Data/Array/Accelerate/Classes/RealFloat.hs +++ b/src/Data/Array/Accelerate/Classes/RealFloat.hs @@ -7,6 +7,7 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fno-warn-orphans #-} {-# OPTIONS_HADDOCK hide #-} @@ -49,6 +50,7 @@ import Data.Array.Accelerate.Classes.VOrd import Data.Coerce import Data.Kind +import Data.Type.Equality import Text.Printf import Prelude ( (.), ($), String, error, undefined, unlines ) import qualified Prelude as P diff --git a/src/Data/Array/Accelerate/Data/Bits.hs b/src/Data/Array/Accelerate/Data/Bits.hs index 25ba79cdf..f3887f8b1 100644 --- a/src/Data/Array/Accelerate/Data/Bits.hs +++ b/src/Data/Array/Accelerate/Data/Bits.hs @@ -9,6 +9,7 @@ {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} -- | -- Module : Data.Array.Accelerate.Data.Bits @@ -45,12 +46,13 @@ import Data.Array.Accelerate.Classes.VEq import Data.Array.Accelerate.Classes.VOrd import Data.Kind +import Data.Type.Equality import Language.Haskell.TH hiding ( Exp, Type ) import Prelude ( ($), (<$>), undefined, otherwise, concat, mapM, toInteger ) import qualified Prelude as P import qualified Data.Bits as B -import GHC.Exts +import GHC.Prim ( Proxy#, proxy# ) import GHC.TypeLits infixl 8 `shiftL`, `shiftR`, `rotateL`, `rotateR` diff --git a/src/Data/Array/Accelerate/Language.hs b/src/Data/Array/Accelerate/Language.hs index 9ebd2f061..a7c0510e4 100644 --- a/src/Data/Array/Accelerate/Language.hs +++ b/src/Data/Array/Accelerate/Language.hs @@ -109,7 +109,7 @@ import Data.Array.Accelerate.Classes.Num import Data.Array.Accelerate.Classes.Ord import Prelude ( ($), (.), Char ) -#if __GLASGOW_HASKELL__ >= 904 +#if __GLASGOW_HASKELL__ == 904 import Data.Type.Equality #endif diff --git a/src/Data/Array/Accelerate/Pattern.hs b/src/Data/Array/Accelerate/Pattern.hs index 2b002b256..3ff082de7 100644 --- a/src/Data/Array/Accelerate/Pattern.hs +++ b/src/Data/Array/Accelerate/Pattern.hs @@ -10,6 +10,7 @@ {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} -- | -- Module : Data.Array.Accelerate.Pattern diff --git a/src/Data/Array/Accelerate/Pattern/SIMD.hs b/src/Data/Array/Accelerate/Pattern/SIMD.hs index aa608462d..429f28864 100644 --- a/src/Data/Array/Accelerate/Pattern/SIMD.hs +++ b/src/Data/Array/Accelerate/Pattern/SIMD.hs @@ -9,6 +9,7 @@ {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} -- | diff --git a/src/Data/Array/Accelerate/Pattern/Tuple.hs b/src/Data/Array/Accelerate/Pattern/Tuple.hs index 64312af49..4f34fe26c 100644 --- a/src/Data/Array/Accelerate/Pattern/Tuple.hs +++ b/src/Data/Array/Accelerate/Pattern/Tuple.hs @@ -9,6 +9,7 @@ {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} -- | diff --git a/src/Data/Array/Accelerate/Prelude.hs b/src/Data/Array/Accelerate/Prelude.hs index cb54f8c06..6cf549119 100644 --- a/src/Data/Array/Accelerate/Prelude.hs +++ b/src/Data/Array/Accelerate/Prelude.hs @@ -141,9 +141,6 @@ import Data.Array.Accelerate.Data.Bits import Lens.Micro ( Lens', (&), (^.), (.~), (+~), (-~), lens, over ) import Prelude ( (.), ($), const, id, flip ) -#if __GLASGOW_HASKELL__ >= 904 -import Data.Type.Equality -#endif -- $setup diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue407.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue407.hs index ee5d6131b..5ac158df3 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue407.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue407.hs @@ -6,6 +6,7 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} -- | -- Module : Data.Array.Accelerate.Test.NoFib.Issues.Issue407 -- Copyright : [2009..2020] The Accelerate Team diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue517.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue517.hs index 0d92a9400..17619264e 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue517.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue517.hs @@ -17,13 +17,15 @@ module Data.Array.Accelerate.Test.NoFib.Issues.Issue517 ( ) where -import Data.Array.Accelerate as A -import Data.Array.Accelerate.Data.Semigroup as A +import Data.Array.Accelerate as A +import Data.Array.Accelerate.Data.Semigroup as A import Data.Array.Accelerate.Test.NoFib.Base import Test.Tasty import Test.Tasty.HUnit +import Prelude hiding ( Maybe(..) ) + test_issue517 :: RunN -> TestTree test_issue517 runN @@ -37,9 +39,9 @@ e1 = fromList Z [(Nothing, Just 2, Just 3, Just 5, Just 7)] t1 :: Acc (Scalar (Tup5 (Maybe (Max Float)))) t1 = unit $ - T5 (Nothing_ <> Nothing_) - (Nothing_ <> Just_ 2) - (Just_ 3 <> Nothing_) - (Just_ 4 <> Just_ 5) - (Just_ 7 <> Just_ 6) + T5 (Nothing <> Nothing) + (Nothing <> Just 2) + (Just 3 <> Nothing) + (Just 4 <> Just 5) + (Just 7 <> Just 6) diff --git a/src/Data/Array/Accelerate/Unsafe.hs b/src/Data/Array/Accelerate/Unsafe.hs index 289cf7a2b..9b898f575 100644 --- a/src/Data/Array/Accelerate/Unsafe.hs +++ b/src/Data/Array/Accelerate/Unsafe.hs @@ -1,5 +1,4 @@ -{-# LANGUAGE MonoLocalBinds #-} -{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleContexts #-} -- | -- Module : Data.Array.Accelerate.Unsafe -- Copyright : [2009..2020] The Accelerate Team From 789d744e51860692bd58c1201a9edb7166448d78 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 14 Aug 2023 12:31:19 +0200 Subject: [PATCH 70/86] Haddock documentation not handled by Template Haskell until GHC-9 --- src/Data/Array/Accelerate/Classes/FromIntegral.hs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/Data/Array/Accelerate/Classes/FromIntegral.hs b/src/Data/Array/Accelerate/Classes/FromIntegral.hs index 422ce0b50..58787c859 100644 --- a/src/Data/Array/Accelerate/Classes/FromIntegral.hs +++ b/src/Data/Array/Accelerate/Classes/FromIntegral.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} @@ -105,11 +106,16 @@ runQ $ do thToBool :: Name -> Q [Dec] thToBool a = - [d| -- | @since 1.4.0.0 + [d| +#if __GLASGOW_HASKELL__ >= 900 + -- | @since 1.4.0.0 +#endif instance FromIntegral $(conT a) Bool where fromIntegral = mkToBool +#if __GLASGOW_HASKELL__ >= 900 -- | @since 1.4.0.0 +#endif instance KnownNat n => FromIntegral (Vec n $(conT a)) (Vec n Bool) where fromIntegral = mkToBool |] From 8fa749497770751cb72a539926673821a4b96b91 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 14 Aug 2023 12:32:34 +0200 Subject: [PATCH 71/86] OPTIONS_HADDOCK hide --- src/Data/Array/Accelerate/Representation/Vec.hs | 1 + src/Data/Array/Accelerate/Trafo/Var.hs | 1 + src/Data/Primitive/Bit.hs | 1 + 3 files changed, 3 insertions(+) diff --git a/src/Data/Array/Accelerate/Representation/Vec.hs b/src/Data/Array/Accelerate/Representation/Vec.hs index c6f31570d..378a640f6 100644 --- a/src/Data/Array/Accelerate/Representation/Vec.hs +++ b/src/Data/Array/Accelerate/Representation/Vec.hs @@ -1,5 +1,6 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedStrings #-} +{-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Representation.Vec -- Copyright : [2008..2020] The Accelerate Team diff --git a/src/Data/Array/Accelerate/Trafo/Var.hs b/src/Data/Array/Accelerate/Trafo/Var.hs index 76cb2b741..aef30b672 100644 --- a/src/Data/Array/Accelerate/Trafo/Var.hs +++ b/src/Data/Array/Accelerate/Trafo/Var.hs @@ -3,6 +3,7 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} +{-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Trafo.Var -- Copyright : [2012..2020] The Accelerate Team diff --git a/src/Data/Primitive/Bit.hs b/src/Data/Primitive/Bit.hs index 78ca506d9..57ff22b82 100644 --- a/src/Data/Primitive/Bit.hs +++ b/src/Data/Primitive/Bit.hs @@ -8,6 +8,7 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UnboxedTuples #-} +{-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Primitive.Bit -- Copyright : [2008..2022] The Accelerate Team From 165f317878284662122b74edc75ab17c2e5f0191 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Fri, 18 Aug 2023 15:51:53 +0200 Subject: [PATCH 72/86] warning police --- src/Data/Array/Accelerate/Pretty/Graphviz.hs | 7 +-- src/Data/Array/Accelerate/Trafo/Sharing.hs | 49 +++++++++++++------- src/Data/Array/Accelerate/Trafo/Shrink.hs | 20 ++++---- src/Data/Array/Accelerate/Trafo/Var.hs | 2 +- 4 files changed, 47 insertions(+), 31 deletions(-) diff --git a/src/Data/Array/Accelerate/Pretty/Graphviz.hs b/src/Data/Array/Accelerate/Pretty/Graphviz.hs index 6bf0b6b21..8d22e958b 100644 --- a/src/Data/Array/Accelerate/Pretty/Graphviz.hs +++ b/src/Data/Array/Accelerate/Pretty/Graphviz.hs @@ -216,9 +216,10 @@ prettyDelayedOpenAcc detail ctx aenv (Manifest pacc) = p' <- prettyDelayedAfun detail aenv p f' <- prettyDelayedAfun detail aenv f -- - let PNode _ (Leaf (Nothing,xb)) fvs = x' - loop = nest 2 (sep ["awhile", pretty p', pretty f', xb ]) - return $ PNode ident (Leaf (Nothing,loop)) fvs + case x' of + PNode _ (Leaf (Nothing,xb)) fvs -> let loop = nest 2 (sep ["awhile", pretty p', pretty f', xb ]) + in return $ PNode ident (Leaf (Nothing,loop)) fvs + _ -> internalError "unexpected node" Apair a1 a2 -> genNodeId >>= prettyDelayedApair detail aenv a1 a2 diff --git a/src/Data/Array/Accelerate/Trafo/Sharing.hs b/src/Data/Array/Accelerate/Trafo/Sharing.hs index 071ca6dac..bb591366b 100644 --- a/src/Data/Array/Accelerate/Trafo/Sharing.hs +++ b/src/Data/Array/Accelerate/Trafo/Sharing.hs @@ -1278,7 +1278,7 @@ instance HasTypeR UnscopedExp where typeR (UnscopedExp _ exp) = Smart.typeR exp -- Specifies a scalar expression AST with sharing. For expressions rooted in functions the list --- holds a sorted environment corresponding to the variables bound in the immediate surounding +-- holds a sorted environment corresponding to the variables bound in the immediate surrounding -- lambdas. data ScopedExp t = ScopedExp [StableSharingExp] (SharingExp ScopedAcc ScopedExp t) @@ -2653,12 +2653,14 @@ determineScopesSharingAcc config accOccMap = scopesAcc :: HasCallStack => (SmartAcc a1 -> UnscopedAcc a2) -> (SmartAcc a1 -> ScopedAcc a2, NodeCounts) - scopesAfun1 f = (const (ScopedAcc ssa body'), (counts', graph)) + scopesAfun1 f + | not (null env) = internalError "unexpected unbound variables" + | otherwise = (const (ScopedAcc ssa body'), (counts', graph)) where - body@(UnscopedAcc fvs _) = f undefined - (ScopedAcc [] body', (counts,graph)) = scopesAcc body - (freeCounts, counts') = partition isBoundHere counts - ssa = buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- freeCounts] + body@(UnscopedAcc fvs _) = f undefined + (ScopedAcc env body', (counts,graph)) = scopesAcc body + (freeCounts, counts') = partition isBoundHere counts + ssa = buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- freeCounts] isBoundHere (AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag _ i))) _) = i `elem` fvs isBoundHere _ = False @@ -2731,14 +2733,19 @@ determineScopesExp -> RootExp t -> (ScopedExp t, NodeCounts) -- Root (closed) expression plus Acc node counts determineScopesExp config accOccMap (RootExp expOccMap exp@(UnscopedExp fvs _)) - = let - (ScopedExp [] expWithScopes, (nodeCounts,graph)) = determineScopesSharingExp config accOccMap expOccMap exp - (expCounts, accCounts) = partition isExpNodeCount nodeCounts + | not (null env) + = internalError "unexpected unbound variables" + -- + | otherwise + = ( ScopedExp (buildInitialEnvExp fvs [se | ExpNodeCount se _ <- expCounts]) expWithScopes + , cleanCounts (accCounts,graph) + ) + where + (ScopedExp env expWithScopes, (nodeCounts,graph)) = determineScopesSharingExp config accOccMap expOccMap exp + (expCounts, accCounts) = partition isExpNodeCount nodeCounts - isExpNodeCount ExpNodeCount{} = True - isExpNodeCount _ = False - in - (ScopedExp (buildInitialEnvExp fvs [se | ExpNodeCount se _ <- expCounts]) expWithScopes, cleanCounts (accCounts,graph)) + isExpNodeCount ExpNodeCount{} = True + isExpNodeCount _ = False determineScopesSharingExp @@ -2760,12 +2767,18 @@ determineScopesSharingExp config accOccMap expOccMap = scopesExp :: HasCallStack => (SmartExp a -> UnscopedExp b) -> (SmartExp a -> ScopedExp b, NodeCounts) - scopesFun1 f = tracePure (bformat ("LAMBDA " % list formatStableSharingExp) ssa) (bformat (list formatNodeCount) counts) (const (ScopedExp ssa body'), (counts',graph)) + scopesFun1 f + | not (null env) + = internalError "unexpected unbound variables" + -- + | otherwise + = tracePure (bformat ("LAMBDA " % list formatStableSharingExp) ssa) (bformat (list formatNodeCount) counts) + $ (const (ScopedExp ssa body'), (counts',graph)) where - body@(UnscopedExp fvs _) = f undefined - (ScopedExp [] body', (counts, graph)) = scopesExp body - (freeCounts, counts') = partition isBoundHere counts - ssa = buildInitialEnvExp fvs [se | ExpNodeCount se _ <- freeCounts] + body@(UnscopedExp fvs _) = f undefined + (ScopedExp env body', (counts, graph)) = scopesExp body + (freeCounts, counts') = partition isBoundHere counts + ssa = buildInitialEnvExp fvs [se | ExpNodeCount se _ <- freeCounts] isBoundHere (ExpNodeCount (StableSharingExp _ (ExpSharing _ (Tag _ i))) _) = i `elem` fvs isBoundHere _ = False diff --git a/src/Data/Array/Accelerate/Trafo/Shrink.hs b/src/Data/Array/Accelerate/Trafo/Shrink.hs index f2840de2b..89548f345 100644 --- a/src/Data/Array/Accelerate/Trafo/Shrink.hs +++ b/src/Data/Array/Accelerate/Trafo/Shrink.hs @@ -1,11 +1,12 @@ -{-# LANGUAGE TupleSections #-} {-# LANGUAGE CPP #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternGuards #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} @@ -208,18 +209,19 @@ strengthenShrunkLHS -> env1 :?> env1' -> env2 :?> env2' strengthenShrunkLHS (LeftHandSideWildcard _) (LeftHandSideWildcard _) k = k -strengthenShrunkLHS (LeftHandSideSingle _) (LeftHandSideSingle _) k = \ix -> case ix of +strengthenShrunkLHS (LeftHandSideSingle _) (LeftHandSideSingle _) k = \case ZeroIdx -> Just ZeroIdx SuccIdx ix' -> SuccIdx <$> k ix' -strengthenShrunkLHS (LeftHandSidePair lA hA) (LeftHandSidePair lB hB) k = strengthenShrunkLHS hA hB $ strengthenShrunkLHS lA lB k -strengthenShrunkLHS (LeftHandSideSingle _) (LeftHandSideWildcard _) k = \ix -> case ix of +strengthenShrunkLHS (LeftHandSidePair lA hA) (LeftHandSidePair lB hB) k = + strengthenShrunkLHS hA hB $ strengthenShrunkLHS lA lB k +strengthenShrunkLHS (LeftHandSideSingle _) (LeftHandSideWildcard _) k = \case ZeroIdx -> Nothing SuccIdx ix' -> k ix' -strengthenShrunkLHS (LeftHandSidePair l h) (LeftHandSideWildcard t) k = strengthenShrunkLHS h (LeftHandSideWildcard t2) $ strengthenShrunkLHS l (LeftHandSideWildcard t1) k - where - TupRpair t1 t2 = t -strengthenShrunkLHS (LeftHandSideWildcard _) _ _ = internalError "Second LHS defines more variables" -strengthenShrunkLHS _ _ _ = internalError "Mismatch LHS single with LHS pair" +strengthenShrunkLHS (LeftHandSidePair l h) (LeftHandSideWildcard (TupRpair t1 t2)) k + = strengthenShrunkLHS h (LeftHandSideWildcard t2) + $ strengthenShrunkLHS l (LeftHandSideWildcard t1) k +strengthenShrunkLHS (LeftHandSideWildcard _) _ _ = internalError "Second LHS defines more variables" +strengthenShrunkLHS _ _ _ = internalError "Mismatch LHS single with LHS pair" -- Shrinking diff --git a/src/Data/Array/Accelerate/Trafo/Var.hs b/src/Data/Array/Accelerate/Trafo/Var.hs index aef30b672..c6bb254f9 100644 --- a/src/Data/Array/Accelerate/Trafo/Var.hs +++ b/src/Data/Array/Accelerate/Trafo/Var.hs @@ -34,7 +34,7 @@ data DeclareVars s t aenv where declareVars :: TupR s t -> DeclareVars s t env declareVars TupRunit - = DeclareVars LeftHandSideUnit weakenId $ const $ TupRunit + = DeclareVars LeftHandSideUnit weakenId $ const TupRunit declareVars (TupRsingle s) = DeclareVars (LeftHandSideSingle s) (weakenSucc weakenId) $ \k -> TupRsingle $ Var s $ k >:> ZeroIdx declareVars (TupRpair r1 r2) From 1e07380763dd33d4b5a5edaaab851c55f86f1c8d Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 21 Aug 2023 15:26:57 +0200 Subject: [PATCH 73/86] cleaning up smart constructor cruft --- src/Data/Array/Accelerate/Classes/Eq.hs | 15 +- src/Data/Array/Accelerate/Classes/Floating.hs | 69 ++-- .../Array/Accelerate/Classes/Fractional.hs | 9 +- src/Data/Array/Accelerate/Classes/FromBool.hs | 5 +- .../Array/Accelerate/Classes/FromIntegral.hs | 10 +- src/Data/Array/Accelerate/Classes/Integral.hs | 31 +- src/Data/Array/Accelerate/Classes/Num.hs | 25 +- src/Data/Array/Accelerate/Classes/Ord.hs | 57 +-- .../Array/Accelerate/Classes/RealFloat.hs | 11 +- src/Data/Array/Accelerate/Classes/RealFrac.hs | 9 +- .../Array/Accelerate/Classes/ToFloating.hs | 6 +- src/Data/Array/Accelerate/Classes/VEq.hs | 33 +- src/Data/Array/Accelerate/Classes/VOrd.hs | 27 +- src/Data/Array/Accelerate/Data/Bits.hs | 36 +- src/Data/Array/Accelerate/Data/Complex.hs | 2 +- src/Data/Array/Accelerate/Language.hs | 12 +- src/Data/Array/Accelerate/Smart.hs | 362 ++---------------- src/Data/Primitive/Bit.hs | 2 +- src/Data/Primitive/Vec.hs | 1 - 19 files changed, 246 insertions(+), 476 deletions(-) diff --git a/src/Data/Array/Accelerate/Classes/Eq.hs b/src/Data/Array/Accelerate/Classes/Eq.hs index b2f71ed87..dfb9bd492 100644 --- a/src/Data/Array/Accelerate/Classes/Eq.hs +++ b/src/Data/Array/Accelerate/Classes/Eq.hs @@ -33,6 +33,7 @@ module Data.Array.Accelerate.Classes.Eq ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.Error import Data.Array.Accelerate.Pattern.Bool import Data.Array.Accelerate.Pattern.Tuple @@ -69,7 +70,7 @@ infixr 3 && -- infixr 3 &&! (&&!) :: Exp Bool -> Exp Bool -> Exp Bool -(&&!) = mkLAnd +(&&!) = mkPrimBinary $ PrimLAnd bitType -- | Disjunction: True if either argument is true. This is a short-circuit -- operator, so the second argument will be evaluated only if the first is @@ -88,12 +89,12 @@ infixr 2 || -- infixr 2 ||! (||!) :: Exp Bool -> Exp Bool -> Exp Bool -(||!) = mkLOr +(||!) = mkPrimBinary $ PrimLOr bitType -- | Logical negation -- not :: Exp Bool -> Exp Bool -not = mkLNot +not = mkPrimUnary $ PrimLNot bitType -- | The 'Eq' class defines equality '(==)' and inequality '(/=)' for @@ -171,8 +172,8 @@ runQ $ do mkPrim :: Name -> Q [Dec] mkPrim t = [d| instance Eq $(conT t) where - (==) = mkEq - (/=) = mkNEq + (==) = mkPrimBinary $ PrimEq scalarType + (/=) = mkPrimBinary $ PrimNEq scalarType |] mkTup :: Int -> Q [Dec] @@ -200,8 +201,8 @@ instance Eq sh => Eq (sh :. Int) where x /= y = indexHead x /= indexHead y || indexTail x /= indexTail y instance Eq Bool where - (==) = mkEq - (/=) = mkNEq + (==) = mkPrimBinary $ PrimEq scalarType + (/=) = mkPrimBinary $ PrimNEq scalarType instance Eq Ordering where x == y = mkCoerce x == (mkCoerce y :: Exp TAG) diff --git a/src/Data/Array/Accelerate/Classes/Floating.hs b/src/Data/Array/Accelerate/Classes/Floating.hs index 01821e208..263fd9d1d 100644 --- a/src/Data/Array/Accelerate/Classes/Floating.hs +++ b/src/Data/Array/Accelerate/Classes/Floating.hs @@ -35,6 +35,7 @@ module Data.Array.Accelerate.Classes.Floating ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type @@ -65,43 +66,43 @@ runQ $ thFloating a = [d| instance P.Floating (Exp $(conT a)) where pi = constant pi - sin = mkSin - cos = mkCos - tan = mkTan - asin = mkAsin - acos = mkAcos - atan = mkAtan - sinh = mkSinh - cosh = mkCosh - tanh = mkTanh - asinh = mkAsinh - acosh = mkAcosh - atanh = mkAtanh - exp = mkExpFloating - sqrt = mkSqrt - log = mkLog - (**) = mkFPow - logBase = mkLogBase + sin = mkPrimUnary $ PrimSin floatingType + cos = mkPrimUnary $ PrimCos floatingType + tan = mkPrimUnary $ PrimTan floatingType + asin = mkPrimUnary $ PrimAsin floatingType + acos = mkPrimUnary $ PrimAcos floatingType + atan = mkPrimUnary $ PrimAtan floatingType + sinh = mkPrimUnary $ PrimSinh floatingType + cosh = mkPrimUnary $ PrimCosh floatingType + tanh = mkPrimUnary $ PrimTanh floatingType + asinh = mkPrimUnary $ PrimAsinh floatingType + acosh = mkPrimUnary $ PrimAcosh floatingType + atanh = mkPrimUnary $ PrimAtanh floatingType + exp = mkPrimUnary $ PrimExpFloating floatingType + sqrt = mkPrimUnary $ PrimSqrt floatingType + log = mkPrimUnary $ PrimLog floatingType + (**) = mkPrimBinary $ PrimFPow floatingType + logBase = mkPrimBinary $ PrimLogBase floatingType instance KnownNat n => P.Floating (Exp (Vec n $(conT a))) where pi = constant (Vec (Prim.splat pi)) - sin = mkSin - cos = mkCos - tan = mkTan - asin = mkAsin - acos = mkAcos - atan = mkAtan - sinh = mkSinh - cosh = mkCosh - tanh = mkTanh - asinh = mkAsinh - acosh = mkAcosh - atanh = mkAtanh - exp = mkExpFloating - sqrt = mkSqrt - log = mkLog - (**) = mkFPow - logBase = mkLogBase + sin = mkPrimUnary $ PrimSin floatingType + cos = mkPrimUnary $ PrimCos floatingType + tan = mkPrimUnary $ PrimTan floatingType + asin = mkPrimUnary $ PrimAsin floatingType + acos = mkPrimUnary $ PrimAcos floatingType + atan = mkPrimUnary $ PrimAtan floatingType + sinh = mkPrimUnary $ PrimSinh floatingType + cosh = mkPrimUnary $ PrimCosh floatingType + tanh = mkPrimUnary $ PrimTanh floatingType + asinh = mkPrimUnary $ PrimAsinh floatingType + acosh = mkPrimUnary $ PrimAcosh floatingType + atanh = mkPrimUnary $ PrimAtanh floatingType + exp = mkPrimUnary $ PrimExpFloating floatingType + sqrt = mkPrimUnary $ PrimSqrt floatingType + log = mkPrimUnary $ PrimLog floatingType + (**) = mkPrimBinary $ PrimFPow floatingType + logBase = mkPrimBinary $ PrimLogBase floatingType |] in concat <$> mapM thFloating floatingTypes diff --git a/src/Data/Array/Accelerate/Classes/Fractional.hs b/src/Data/Array/Accelerate/Classes/Fractional.hs index dd554f0f6..01eccea27 100644 --- a/src/Data/Array/Accelerate/Classes/Fractional.hs +++ b/src/Data/Array/Accelerate/Classes/Fractional.hs @@ -26,6 +26,7 @@ module Data.Array.Accelerate.Classes.Fractional ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type @@ -66,13 +67,13 @@ runQ $ thFractional :: Name -> Q [Dec] thFractional a = [d| instance P.Fractional (Exp $(conT a)) where - (/) = mkFDiv - recip = mkRecip + (/) = mkPrimBinary $ PrimFDiv floatingType + recip = mkPrimUnary $ PrimRecip floatingType fromRational = constant . P.fromRational instance KnownNat n => P.Fractional (Exp (Vec n $(conT a))) where - (/) = mkFDiv - recip = mkRecip + (/) = mkPrimBinary $ PrimFDiv floatingType + recip = mkPrimUnary $ PrimRecip floatingType fromRational = constant . Vec . Prim.splat . P.fromRational |] in diff --git a/src/Data/Array/Accelerate/Classes/FromBool.hs b/src/Data/Array/Accelerate/Classes/FromBool.hs index 59fe87f5c..23c5d994c 100644 --- a/src/Data/Array/Accelerate/Classes/FromBool.hs +++ b/src/Data/Array/Accelerate/Classes/FromBool.hs @@ -18,6 +18,7 @@ module Data.Array.Accelerate.Classes.FromBool ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type @@ -57,10 +58,10 @@ runQ $ thFromBool :: Name -> Q [Dec] thFromBool b = [d| instance FromBool Bool $(conT b) where - fromBool = mkFromBool + fromBool = mkPrimUnary $ PrimFromBool bitType integralType instance KnownNat n => FromBool (Vec n Bool) (Vec n $(conT b)) where - fromBool = mkFromBool + fromBool = mkPrimUnary $ PrimFromBool bitType integralType |] in concat <$> mapM thFromBool integralTypes diff --git a/src/Data/Array/Accelerate/Classes/FromIntegral.hs b/src/Data/Array/Accelerate/Classes/FromIntegral.hs index 58787c859..d9dccf743 100644 --- a/src/Data/Array/Accelerate/Classes/FromIntegral.hs +++ b/src/Data/Array/Accelerate/Classes/FromIntegral.hs @@ -19,7 +19,9 @@ module Data.Array.Accelerate.Classes.FromIntegral ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type @@ -38,8 +40,8 @@ class FromIntegral a b where -- | General coercion from integral types fromIntegral :: Integral a => Exp a -> Exp b --- instance {-# OVERLAPPABLE #-} (Elt a, Elt b, IsIntegral a, IsNum b) => FromIntegral a b where --- fromIntegral = mkFromIntegral +mkFromIntegral :: (IsIntegral (EltR a), IsNum (EltR b)) => Exp a -> Exp b +mkFromIntegral = mkPrimUnary $ PrimFromIntegral integralType numType -- Reify in ghci: -- @@ -111,13 +113,13 @@ runQ $ do -- | @since 1.4.0.0 #endif instance FromIntegral $(conT a) Bool where - fromIntegral = mkToBool + fromIntegral = mkPrimUnary $ PrimToBool integralType bitType #if __GLASGOW_HASKELL__ >= 900 -- | @since 1.4.0.0 #endif instance KnownNat n => FromIntegral (Vec n $(conT a)) (Vec n Bool) where - fromIntegral = mkToBool + fromIntegral = mkPrimUnary $ PrimToBool integralType bitType |] -- x <- concat <$> sequence [ thFromIntegral from to | from <- integralTypes, to <- numTypes ] diff --git a/src/Data/Array/Accelerate/Classes/Integral.hs b/src/Data/Array/Accelerate/Classes/Integral.hs index fd2208c31..faf61d320 100644 --- a/src/Data/Array/Accelerate/Classes/Integral.hs +++ b/src/Data/Array/Accelerate/Classes/Integral.hs @@ -27,8 +27,11 @@ module Data.Array.Accelerate.Classes.Integral ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) +import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Vec +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Type import Data.Array.Accelerate.Classes.Enum @@ -46,6 +49,18 @@ import qualified Prelude as P -- type Integral a = (Enum a, Ord a, Num a, P.Integral (Exp a)) +mkQuotRem :: IsIntegral (EltR t) => Exp t -> Exp t -> (Exp t, Exp t) +mkQuotRem (Exp x) (Exp y) = + let r = SmartExp $ PrimQuotRem integralType `PrimApp` SmartExp (Pair x y) + in ( mkExp $ Prj PairIdxLeft r + , mkExp $ Prj PairIdxRight r) + +mkDivMod :: IsIntegral (EltR t) => Exp t -> Exp t -> (Exp t, Exp t) +mkDivMod (Exp x) (Exp y) = + let r = SmartExp $ PrimDivMod integralType `PrimApp` SmartExp (Pair x y) + in ( mkExp $ Prj PairIdxLeft r + , mkExp $ Prj PairIdxRight r) + runQ $ let integralTypes :: [Name] @@ -67,19 +82,19 @@ runQ $ mkIntegral :: Name -> Q [Dec] mkIntegral a = [d| instance P.Integral (Exp $(conT a)) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod + quot = mkPrimBinary $ PrimQuot integralType + rem = mkPrimBinary $ PrimRem integralType + div = mkPrimBinary $ PrimIDiv integralType + mod = mkPrimBinary $ PrimMod integralType quotRem = mkQuotRem divMod = mkDivMod toInteger = P.error "Prelude.toInteger not supported for Accelerate types" instance KnownNat n => P.Integral (Exp (Vec n $(conT a))) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod + quot = mkPrimBinary $ PrimQuot integralType + rem = mkPrimBinary $ PrimRem integralType + div = mkPrimBinary $ PrimIDiv integralType + mod = mkPrimBinary $ PrimMod integralType quotRem = mkQuotRem divMod = mkDivMod toInteger = P.error "Prelude.toInteger not supported for Accelerate types" diff --git a/src/Data/Array/Accelerate/Classes/Num.hs b/src/Data/Array/Accelerate/Classes/Num.hs index a236074ff..551b55ad8 100644 --- a/src/Data/Array/Accelerate/Classes/Num.hs +++ b/src/Data/Array/Accelerate/Classes/Num.hs @@ -26,6 +26,7 @@ module Data.Array.Accelerate.Classes.Num ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Vec @@ -105,21 +106,21 @@ runQ $ thNum :: Name -> Q [Dec] thNum a = [d| instance P.Num (Exp $(conT a)) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig + (+) = mkPrimBinary $ PrimAdd numType + (-) = mkPrimBinary $ PrimSub numType + (*) = mkPrimBinary $ PrimSub numType + negate = mkPrimUnary $ PrimNeg numType + abs = mkPrimUnary $ PrimAbs numType + signum = mkPrimUnary $ PrimSig numType fromInteger = constant . P.fromInteger instance KnownNat n => P.Num (Exp (Vec n $(conT a))) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig + (+) = mkPrimBinary $ PrimAdd numType + (-) = mkPrimBinary $ PrimSub numType + (*) = mkPrimBinary $ PrimSub numType + negate = mkPrimUnary $ PrimNeg numType + abs = mkPrimUnary $ PrimAbs numType + signum = mkPrimUnary $ PrimSig numType fromInteger = constant . Vec . Prim.splat . P.fromInteger |] in diff --git a/src/Data/Array/Accelerate/Classes/Ord.hs b/src/Data/Array/Accelerate/Classes/Ord.hs index dc53675a7..9aad3c2a5 100644 --- a/src/Data/Array/Accelerate/Classes/Ord.hs +++ b/src/Data/Array/Accelerate/Classes/Ord.hs @@ -30,6 +30,7 @@ module Data.Array.Accelerate.Classes.Ord ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Classes.Eq import Data.Array.Accelerate.Classes.VEq @@ -56,7 +57,7 @@ infix 4 > infix 4 <= infix 4 >= --- | The 'Ord' class for totally ordered datatypes +-- | The 'Ord' class for totally ordered data types -- class Eq a => Ord a where {-# MINIMAL (<=) | compare #-} @@ -146,33 +147,33 @@ runQ $ do mkPrim :: Name -> Q [Dec] mkPrim t = [d| instance Ord $(conT t) where - (<) = mkLt - (>) = mkGt - (<=) = mkLtEq - (>=) = mkGtEq - min = mkMin - max = mkMax + (<) = mkPrimBinary $ PrimLt scalarType + (>) = mkPrimBinary $ PrimGt scalarType + (<=) = mkPrimBinary $ PrimLtEq scalarType + (>=) = mkPrimBinary $ PrimGtEq scalarType + min = mkPrimBinary $ PrimMin scalarType + max = mkPrimBinary $ PrimMax scalarType |] - mkLt' :: [ExpQ] -> [ExpQ] -> ExpQ - mkLt' [x] [y] = [| $x < $y |] - mkLt' (x:xs) (y:ys) = [| $x < $y || ( $x == $y && $(mkLt' xs ys) ) |] - mkLt' _ _ = error "mkLt'" + mkLt :: [ExpQ] -> [ExpQ] -> ExpQ + mkLt [x] [y] = [| $x < $y |] + mkLt (x:xs) (y:ys) = [| $x < $y || ( $x == $y && $(mkLt xs ys) ) |] + mkLt _ _ = error "mkLt" - mkGt' :: [ExpQ] -> [ExpQ] -> ExpQ - mkGt' [x] [y] = [| $x > $y |] - mkGt' (x:xs) (y:ys) = [| $x > $y || ( $x == $y && $(mkGt' xs ys) ) |] - mkGt' _ _ = error "mkGt'" + mkGt :: [ExpQ] -> [ExpQ] -> ExpQ + mkGt [x] [y] = [| $x > $y |] + mkGt (x:xs) (y:ys) = [| $x > $y || ( $x == $y && $(mkGt xs ys) ) |] + mkGt _ _ = error "mkGt" - mkLtEq' :: [ExpQ] -> [ExpQ] -> ExpQ - mkLtEq' [x] [y] = [| $x <= $y |] - mkLtEq' (x:xs) (y:ys) = [| $x < $y || ( $x == $y && $(mkLtEq' xs ys) ) |] - mkLtEq' _ _ = error "mkLtEq'" + mkLtEq :: [ExpQ] -> [ExpQ] -> ExpQ + mkLtEq [x] [y] = [| $x <= $y |] + mkLtEq (x:xs) (y:ys) = [| $x < $y || ( $x == $y && $(mkLtEq xs ys) ) |] + mkLtEq _ _ = error "mkLtEq" - mkGtEq' :: [ExpQ] -> [ExpQ] -> ExpQ - mkGtEq' [x] [y] = [| $x >= $y |] - mkGtEq' (x:xs) (y:ys) = [| $x > $y || ( $x == $y && $(mkGtEq' xs ys) ) |] - mkGtEq' _ _ = error "mkGtEq'" + mkGtEq :: [ExpQ] -> [ExpQ] -> ExpQ + mkGtEq [x] [y] = [| $x >= $y |] + mkGtEq (x:xs) (y:ys) = [| $x > $y || ( $x == $y && $(mkGtEq xs ys) ) |] + mkGtEq _ _ = error "mkGtEq" mkTup :: Int -> Q [Dec] mkTup n = @@ -184,10 +185,10 @@ runQ $ do pat vs = conP (mkName ('T':show n)) (map varP vs) in [d| instance $cst => Ord $res where - $(pat xs) < $(pat ys) = $( mkLt' (map varE xs) (map varE ys) ) - $(pat xs) > $(pat ys) = $( mkGt' (map varE xs) (map varE ys) ) - $(pat xs) >= $(pat ys) = $( mkGtEq' (map varE xs) (map varE ys) ) - $(pat xs) <= $(pat ys) = $( mkLtEq' (map varE xs) (map varE ys) ) + $(pat xs) < $(pat ys) = $( mkLt (map varE xs) (map varE ys) ) + $(pat xs) > $(pat ys) = $( mkGt (map varE xs) (map varE ys) ) + $(pat xs) >= $(pat ys) = $( mkGtEq (map varE xs) (map varE ys) ) + $(pat xs) <= $(pat ys) = $( mkLtEq (map varE xs) (map varE ys) ) |] is <- mapM mkPrim integralTypes @@ -249,5 +250,5 @@ vcmp cmp x y = go (u:us) (v:vs) = u || (v && go us vs) go _ _ = internalError "unexpected vector encoding" in - go (mkUnpack (cmp x y)) (mkUnpack (x ==* y)) + go (unpack (cmp x y)) (unpack (x ==* y)) diff --git a/src/Data/Array/Accelerate/Classes/RealFloat.hs b/src/Data/Array/Accelerate/Classes/RealFloat.hs index f753659bc..a24d55c9f 100644 --- a/src/Data/Array/Accelerate/Classes/RealFloat.hs +++ b/src/Data/Array/Accelerate/Classes/RealFloat.hs @@ -28,7 +28,7 @@ module Data.Array.Accelerate.Classes.RealFloat ( ) where -import Data.Array.Accelerate.AST ( BitOrMask, PrimMask ) +import Data.Array.Accelerate.AST ( PrimFun(..), BitOrMask, PrimMask ) import Data.Array.Accelerate.Language ( (^), cond, while ) import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Smart @@ -281,6 +281,15 @@ instance KnownNat n => RealFloat (Vec n Float128) where isIEEE _ = defaultIsIEEE (undefined :: Exp Float128) atan2 = mkAtan2 +mkIsNaN :: (IsFloating (EltR t), BitOrMask (EltR t) ~ EltR b) => Exp t -> Exp b +mkIsNaN = mkPrimUnary $ PrimIsNaN floatingType + +mkIsInfinite :: (IsFloating (EltR t), BitOrMask (EltR t) ~ EltR b) => Exp t -> Exp b +mkIsInfinite = mkPrimUnary $ PrimIsInfinite floatingType + +mkAtan2 :: IsFloating (EltR t) => Exp t -> Exp t -> Exp t +mkAtan2 = mkPrimBinary $ PrimAtan2 floatingType + -- To satisfy superclass constraints -- diff --git a/src/Data/Array/Accelerate/Classes/RealFrac.hs b/src/Data/Array/Accelerate/Classes/RealFrac.hs index 07b90c63e..6c7a285a5 100644 --- a/src/Data/Array/Accelerate/Classes/RealFrac.hs +++ b/src/Data/Array/Accelerate/Classes/RealFrac.hs @@ -24,6 +24,7 @@ module Data.Array.Accelerate.Classes.RealFrac ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.Language ( cond, even ) import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Representation.Type @@ -136,7 +137,7 @@ defaultTruncate :: forall a b. (RealFrac a, Integral b, FromIntegral (Significan defaultTruncate x | Just FloatingDict <- floatingDict @a , Just IntegralDict <- integralDict @b - = mkTruncate x + = mkPrimUnary (PrimTruncate floatingType integralType) x -- | otherwise = let T2 n _ = properFraction x in n @@ -145,7 +146,7 @@ defaultCeiling :: forall a b. (RealFrac a, Integral b, FromIntegral (Significand defaultCeiling x | Just FloatingDict <- floatingDict @a , Just IntegralDict <- integralDict @b - = mkCeiling x + = mkPrimUnary (PrimCeiling floatingType integralType) x -- | otherwise = let T2 n r = properFraction x in cond (r > 0) (n+1) n @@ -154,7 +155,7 @@ defaultFloor :: forall a b. (RealFrac a, Integral b, FromIntegral (Significand a defaultFloor x | Just FloatingDict <- floatingDict @a , Just IntegralDict <- integralDict @b - = mkFloor x + = mkPrimUnary (PrimFloor floatingType integralType) x -- | otherwise = let T2 n r = properFraction x in cond (r < 0) (n-1) n @@ -163,7 +164,7 @@ defaultRound :: forall a b. (RealFrac a, Integral b, FromIntegral (Significand a defaultRound x | Just FloatingDict <- floatingDict @a , Just IntegralDict <- integralDict @b - = mkRound x + = mkPrimUnary (PrimRound floatingType integralType) x -- | otherwise = let T2 n r = properFraction x diff --git a/src/Data/Array/Accelerate/Classes/ToFloating.hs b/src/Data/Array/Accelerate/Classes/ToFloating.hs index 1d45bb525..05adf84e6 100644 --- a/src/Data/Array/Accelerate/Classes/ToFloating.hs +++ b/src/Data/Array/Accelerate/Classes/ToFloating.hs @@ -19,7 +19,9 @@ module Data.Array.Accelerate.Classes.ToFloating ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type @@ -40,8 +42,8 @@ class ToFloating a b where -- | General coercion to floating types toFloating :: (Num a, Floating b) => Exp a -> Exp b --- instance (Elt a, Elt b, IsNum a, IsFloating b) => ToFloating a b where --- toFloating = mkToFloating +mkToFloating :: (IsNum (EltR a), IsFloating (EltR b)) => Exp a -> Exp b +mkToFloating = mkPrimUnary $ PrimToFloating numType floatingType -- Generate standard instances explicitly. See also: 'FromIntegral'. diff --git a/src/Data/Array/Accelerate/Classes/VEq.hs b/src/Data/Array/Accelerate/Classes/VEq.hs index 8ec9a347d..62a26e52d 100644 --- a/src/Data/Array/Accelerate/Classes/VEq.hs +++ b/src/Data/Array/Accelerate/Classes/VEq.hs @@ -32,43 +32,43 @@ import Data.Array.Accelerate.Classes.Eq import Data.Array.Accelerate.Classes.Num import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Smart -import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Shape import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type -import Data.Array.Accelerate.Error - import qualified Data.Primitive.Bit as Prim import Language.Haskell.TH.Extra hiding ( Type, Exp ) import Prelude hiding ( Eq(..) ) -import GHC.Exts -import GHC.TypeLits - -- | Vectorised conjunction: Element-wise returns true if both arguments in -- the corresponding lane are True. This is a strict vectorised version of -- '(Data.Array.Accelerate.&&)' that always evaluates both arguments. -- +-- @since 1.4.0.0 +-- infixr 3 &&* (&&*) :: KnownNat n => Exp (Vec n Bool) -> Exp (Vec n Bool) -> Exp (Vec n Bool) -(&&*) = mkLAnd +(&&*) = mkPrimBinary $ PrimLAnd bitType -- | Vectorised disjunction: Element-wise returns true if either argument -- in the corresponding lane is true. This is a strict vectorised version -- of '(Data.Array.Accelerate.||)' that always evaluates both arguments. -- +-- @since 1.4.0.0 +-- infixr 2 ||* (||*) :: KnownNat n => Exp (Vec n Bool) -> Exp (Vec n Bool) -> Exp (Vec n Bool) -(||*) = mkLOr +(||*) = mkPrimBinary $ PrimLOr bitType -- | Vectorised logical negation -- +-- @since 1.4.0.0 +-- vnot :: KnownNat n => Exp (Vec n Bool) -> Exp (Vec n Bool) -vnot = mkLNot +vnot = mkPrimUnary $ PrimLNot bitType infix 4 ==* @@ -121,8 +121,8 @@ runQ $ do mkPrim :: Name -> Q [Dec] mkPrim name = [d| instance KnownNat n => VEq n $(conT name) where - (==*) = mkEq - (/=*) = mkNEq + (==*) = mkPrimBinary $ PrimEq scalarType + (/=*) = mkPrimBinary $ PrimNEq scalarType |] mkTup :: Word8 -> Q Dec @@ -137,7 +137,7 @@ runQ $ do ctx = (++) <$> mapM (appT [t| Eq |]) ts <*> mapM (appT [t| SIMD $(varT w) |]) ts - cmp f = [| mkPack (zipWith $f (mkUnpack $(varE x)) (mkUnpack $(varE y))) |] + cmp f = [| pack (zipWith $f (unpack $(varE x)) (unpack $(varE y))) |] -- instanceD ctx [t| VEq $(varT w) $res |] [ funD (mkName "==*") [ clause [varP x, varP y] (normalB (cmp [| (==) |])) [] ] @@ -161,6 +161,10 @@ instance KnownNat n => VEq n Z where _ /=* _ = vfalse instance KnownNat n => VEq n Bool where + (==*) = mkPrimBinary $ PrimEq scalarType + (/=*) = mkPrimBinary $ PrimNEq scalarType + +{-- (==*) = let n = natVal' (proxy# :: Proxy# n) -- @@ -200,10 +204,11 @@ instance KnownNat n => VEq n Bool where if n <= 64 then cmp @Word64 else if n <= 128 then cmp @Word128 else internalError "Can not handle SIMD vector types with more than 128 lanes" +--} instance (Eq sh, SIMD n sh) => VEq n (sh :. Int) where - x ==* y = mkPack (zipWith (==) (mkUnpack x) (mkUnpack y)) - x /=* y = mkPack (zipWith (/=) (mkUnpack x) (mkUnpack y)) + x ==* y = pack (zipWith (==) (unpack x) (unpack y)) + x /=* y = pack (zipWith (/=) (unpack x) (unpack y)) instance KnownNat n => VEq n Ordering where x ==* y = mkCoerce x ==* (mkCoerce y :: Exp (Vec n TAG)) diff --git a/src/Data/Array/Accelerate/Classes/VOrd.hs b/src/Data/Array/Accelerate/Classes/VOrd.hs index 82e608a53..6f041fb3f 100644 --- a/src/Data/Array/Accelerate/Classes/VOrd.hs +++ b/src/Data/Array/Accelerate/Classes/VOrd.hs @@ -19,6 +19,7 @@ module Data.Array.Accelerate.Classes.VOrd ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.Classes.Ord import Data.Array.Accelerate.Classes.VEq import Data.Array.Accelerate.Representation.Tag @@ -43,7 +44,9 @@ infix 4 <=* infix 4 >=* -- | The 'VOrd' class defines lane-wise comparisons for totally ordered --- datatypes. +-- data types. +-- +-- @since 1.4.0.0 -- class VEq n a => VOrd n a where {-# MINIMAL (<=*) | vcompare #-} @@ -113,12 +116,12 @@ runQ $ do mkPrim :: Name -> Q [Dec] mkPrim name = [d| instance KnownNat n => VOrd n $(conT name) where - (<*) = mkLt - (>*) = mkGt - (<=*) = mkLtEq - (>=*) = mkGtEq - vmin = mkMin - vmax = mkMax + (<*) = mkPrimBinary $ PrimLt scalarType + (>*) = mkPrimBinary $ PrimGt scalarType + (<=*) = mkPrimBinary $ PrimLtEq scalarType + (>=*) = mkPrimBinary $ PrimGtEq scalarType + vmin = mkPrimBinary $ PrimMin scalarType + vmax = mkPrimBinary $ PrimMax scalarType |] mkTup :: Word8 -> Q Dec @@ -132,7 +135,7 @@ runQ $ do res = tupT ts ctx = (++) <$> mapM (appT [t| Ord |]) ts <*> mapM (appT [t| SIMD $(varT w) |]) ts - cmp f = [| mkPack (zipWith $f (mkUnpack $(varE x)) (mkUnpack $(varE y))) |] + cmp f = [| pack (zipWith $f (unpack $(varE x)) (unpack $(varE y))) |] -- instanceD ctx [t| VOrd $(varT w) $res |] [ funD (mkName "<*") [ clause [varP x, varP y] (normalB (cmp [| (<) |])) [] ] @@ -166,8 +169,8 @@ instance KnownNat n => VOrd n Ordering where x >=* y = mkCoerce x >=* (mkCoerce y :: Exp (Vec n TAG)) instance (Ord sh, VOrd n sh) => VOrd n (sh :. Int) where - x <* y = mkPack (zipWith (<) (mkUnpack x) (mkUnpack y)) - x >* y = mkPack (zipWith (>) (mkUnpack x) (mkUnpack y)) - x <=* y = mkPack (zipWith (<=) (mkUnpack x) (mkUnpack y)) - x >=* y = mkPack (zipWith (>=) (mkUnpack x) (mkUnpack y)) + x <* y = pack (zipWith (<) (unpack x) (unpack y)) + x >* y = pack (zipWith (>) (unpack x) (unpack y)) + x <=* y = pack (zipWith (<=) (unpack x) (unpack y)) + x >=* y = pack (zipWith (>=) (unpack x) (unpack y)) diff --git a/src/Data/Array/Accelerate/Data/Bits.hs b/src/Data/Array/Accelerate/Data/Bits.hs index f3887f8b1..d60a61872 100644 --- a/src/Data/Array/Accelerate/Data/Bits.hs +++ b/src/Data/Array/Accelerate/Data/Bits.hs @@ -31,6 +31,7 @@ module Data.Array.Accelerate.Data.Bits ( ) where import Data.Array.Accelerate.AST ( BitOrMask ) +import Data.Array.Accelerate.AST ( PrimFun(..), BitOrMask ) import Data.Array.Accelerate.Language import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt @@ -215,7 +216,7 @@ shiftRDefault -- Shift the argument right (signed) shiftRADefault :: forall t. (B.FiniteBits t, Num t, Ord t, FromIntegral Int t, IsScalar (EltR t), IsIntegral (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp t -> Exp t shiftRADefault x i - = cond (i >= P.fromIntegral (B.finiteBitSize (undefined::t))) (cond (mkLt x 0) (-1) 0) + = cond (i >= P.fromIntegral (B.finiteBitSize (undefined::t))) (cond (x < 0) (-1) 0) $ mkBShiftR x i -- Shift the argument right (unsigned) @@ -336,6 +337,39 @@ popCnt64 v1 = mkFromIntegral c c = (v4 * 0x0101010101010101) `unsafeShiftR` 56 --} +mkBAnd :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t +mkBAnd = mkPrimBinary $ PrimBAnd integralType + +mkBOr :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t +mkBOr = mkPrimBinary $ PrimBOr integralType + +mkBXor :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t +mkBXor = mkPrimBinary $ PrimBXor integralType + +mkBNot :: IsIntegral (EltR t) => Exp t -> Exp t +mkBNot = mkPrimUnary $ PrimBNot integralType + +mkBShiftL :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t +mkBShiftL = mkPrimBinary $ PrimBShiftL integralType + +mkBShiftR :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t +mkBShiftR = mkPrimBinary $ PrimBShiftR integralType + +mkBRotateL :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t +mkBRotateL = mkPrimBinary $ PrimBRotateL integralType + +mkBRotateR :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t +mkBRotateR = mkPrimBinary $ PrimBRotateR integralType + +mkPopCount :: IsIntegral (EltR t) => Exp t -> Exp t +mkPopCount = mkPrimUnary $ PrimPopCount integralType + +mkCountLeadingZeros :: IsIntegral (EltR t) => Exp t -> Exp t +mkCountLeadingZeros = mkPrimUnary $ PrimCountLeadingZeros integralType + +mkCountTrailingZeros :: IsIntegral (EltR t) => Exp t -> Exp t +mkCountTrailingZeros = mkPrimUnary $ PrimCountTrailingZeros integralType + runQ $ let integralTypes :: [Name] diff --git a/src/Data/Array/Accelerate/Data/Complex.hs b/src/Data/Array/Accelerate/Data/Complex.hs index 16a2f9863..58c188d87 100644 --- a/src/Data/Array/Accelerate/Data/Complex.hs +++ b/src/Data/Array/Accelerate/Data/Complex.hs @@ -61,7 +61,7 @@ import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Prelude import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type -import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Smart hiding ( pack, unpack ) import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Type import qualified Data.Primitive.Vec as Prim diff --git a/src/Data/Array/Accelerate/Language.hs b/src/Data/Array/Accelerate/Language.hs index a7c0510e4..345ab23d3 100644 --- a/src/Data/Array/Accelerate/Language.hs +++ b/src/Data/Array/Accelerate/Language.hs @@ -17,13 +17,6 @@ -- Stability : experimental -- Portability : non-portable (GHC extensions) -- --- We use the dictionary view of overloaded operations (such as arithmetic and --- bit manipulation) to reify such expressions. With non-overloaded --- operations (such as, the logical connectives) and partially overloaded --- operations (such as comparisons), we use the standard operator names with a --- \'*\' attached. We keep the standard alphanumeric names as they can be --- easily qualified. --- module Data.Array.Accelerate.Language ( @@ -55,7 +48,6 @@ module Data.Array.Accelerate.Language ( Boundary, Stencil, clamp, mirror, wrap, function, - -- ** Common stencil types Stencil3, Stencil5, Stencil7, Stencil9, Stencil3x3, Stencil5x3, Stencil3x5, Stencil5x5, @@ -1475,12 +1467,12 @@ x ^^ n -- |Convert a character to an 'Int'. -- ord :: Exp Char -> Exp Int -ord = mkFromIntegral +ord = mkPrimUnary $ PrimFromIntegral integralType numType -- |Convert an 'Int' into a character. -- chr :: Exp Int -> Exp Char -chr = mkFromIntegral +chr = mkPrimUnary $ PrimFromIntegral integralType numType -- |Reinterpret a value as another type. The two representations must have the -- same bit size. diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index ae36e165a..f77f9bbf5 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -46,68 +46,34 @@ module Data.Array.Accelerate.Smart ( HasArraysR(..), HasTypeR(..), - -- ** Smart constructors for literals + -- ** Constants constant, undef, - -- ** Smart destructors for shapes + -- ** Shapes indexHead, indexTail, - -- ** Vector operations - splat, mkPack, mkUnpack, + -- ** SIMD vectors + splat, pack, unpack, extract, mkExtract, insert, mkInsert, shuffle, select, - vand, vor, - - -- ** Smart constructors for primitive functions - -- *** Operators from Num - mkAdd, mkSub, mkMul, mkNeg, mkAbs, mkSig, - - -- *** Operators from Integral - mkQuot, mkRem, mkQuotRem, mkIDiv, mkMod, mkDivMod, - - -- *** Operators from FiniteBits - mkBAnd, mkBOr, mkBXor, mkBNot, - mkBShiftL, mkBShiftR, mkBRotateL, mkBRotateR, - mkPopCount, mkCountLeadingZeros, mkCountTrailingZeros, - - -- *** Operators from Fractional and Floating - mkFDiv, mkRecip, - mkSin, mkCos, mkTan, - mkAsin, mkAcos, mkAtan, - mkSinh, mkCosh, mkTanh, - mkAsinh, mkAcosh, mkAtanh, - mkExpFloating, - mkSqrt, mkLog, - mkFPow, mkLogBase, - - -- *** Operators from RealFrac and RealFloat - mkTruncate, mkRound, mkFloor, mkCeiling, - mkAtan2, - mkIsNaN, mkIsInfinite, - - -- *** Relational and equality operators - mkLt, mkGt, mkLtEq, mkGtEq, mkEq, mkNEq, mkMax, mkMin, mkLAnd, mkLOr, mkLNot, - - -- ** Smart constructors for type coercion functions - mkFromIntegral, mkToFloating, mkToBool, mkFromBool, mkBitcast, mkCoerce, Coerce(..), - - -- ** Auxiliary functions + + -- ** Type coercions + mkBitcast, mkCoerce, Coerce(..), + + -- ** Miscellaneous ($$), ($$$), ($$$$), ($$$$$), ApplyAcc(..), unAcc, unAccFunction, mkExp, unExp, unExpFunction, unExpBinaryFunction, mkPrimUnary, mkPrimBinary, unPair, mkPairToTuple, - - -- ** Miscellaneous - formatPreAccOp, - formatPreExpOp, + formatPreAccOp, formatPreExpOp, ) where -import Data.Array.Accelerate.AST ( Direction(..), Message(..), PrimBool, PrimMaybe, PrimFun(..), BitOrMask, primFunType ) +import Data.Array.Accelerate.AST ( Direction(..), Message(..), PrimBool, PrimMaybe, PrimFun(..), primFunType ) import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Error @@ -130,7 +96,6 @@ import qualified Data.Array.Accelerate.Sugar.Shape as Sugar import qualified Data.Primitive.Vec as Prim -import Data.Bits ( Bits, unsafeShiftL, countLeadingZeros ) import Data.Kind import Data.Text.Lazy.Builder import Formatting hiding ( splat ) @@ -486,6 +451,12 @@ data PreSmartAcc acc exp as where -> acc (Array sh b) -> PreSmartAcc acc exp (Array sh c) + -- Coerce :: ShapeR sh + -- -> TypeR a + -- -> TypeR b + -- -> acc (Array sh a) + -- -> PreSmartAcc acc exp (Array sh b) + -- Embedded expressions of the surface language -- -------------------------------------------- @@ -1008,10 +979,14 @@ indexTail (Exp x) = mkExp $ Prj PairIdxLeft x -- @since 1.4.0.0 -- splat :: (SIMD n a, Elt a) => Exp a -> Exp (Vec n a) -splat x = mkPack (repeat x) +splat x = pack (repeat x) -mkPack :: forall n a. (SIMD n a, Elt a) => [Exp a] -> Exp (Vec n a) -mkPack xs = +-- | Pack scalar expressions into a single SIMD vector +-- +-- @since 1.4.0.0 +-- +pack :: forall n a. (SIMD n a, Elt a) => [Exp a] -> Exp (Vec n a) +pack xs = let go :: Word8 -> [Exp a] -> Exp (Vec n a) -> Exp (Vec n a) go _ [] vec = vec go i (v:vs) vec = go (i+1) vs (insert vec (constant i) v) @@ -1020,8 +995,12 @@ mkPack xs = in go 0 (take n xs) undef -mkUnpack :: forall n a. (SIMD n a, Elt a) => Exp (Vec n a) -> [Exp a] -mkUnpack v = +-- | Unpack the lanes of a SIMD vector into scalar elements +-- +-- @since 1.4.0.0 +-- +unpack :: forall n a. (SIMD n a, Elt a) => Exp (Vec n a) -> [Exp a] +unpack v = let n = fromIntegral (natVal' (proxy# :: Proxy# n)) :: Word8 in map (extract v . constant) [0 .. n-1] @@ -1231,285 +1210,8 @@ select (Exp mask) (Exp tt) (Exp ff) = Exp $ go (vecR @n @a) tt ff floating _ = error "impossible" --- | Return 'True' if all lanes of the vector are 'True' --- --- @since 1.4.0.0 --- -vand :: forall n. KnownNat n => Exp (Vec n Bool) -> Exp Bool -vand v = - let n, m :: Int - n = fromInteger $ natVal' (proxy# :: Proxy# n) - m = max 8 (1 `unsafeShiftL` (64 - countLeadingZeros (n-1))) - -- - cmp :: forall t. (Elt t, Num t, Bits t, IsScalar (EltR t), IsNum (EltR t), BitOrMask (EltR t) ~ Bit) - => Exp Bool - cmp = let w = SingleIntegralType (TypeWord n) - b = mkExp $ Bitcast (BitScalarType bitType) (NumScalarType (IntegralNumType w)) (unExp v) - in - mkEq (constant ((1 `unsafeShiftL` n) - 1)) $ - if n == m - then b - else mkPrimUnary (PrimFromIntegral w numType) b :: Exp t - in - if n == 1 then mkExp (Bitcast scalarType scalarType (unExp v)) else - if n <= 8 then cmp @Word8 else - if n <= 16 then cmp @Word16 else - if n <= 32 then cmp @Word32 else - if n <= 64 then cmp @Word64 else - if n <= 128 then cmp @Word128 else - internalError "Can not handle Vec types with more than 128 lanes" - --- | Return 'True' if any lane of the vector is 'True' --- --- @since 1.4.0.0 --- -vor :: forall n. KnownNat n => Exp (Vec n Bool) -> Exp Bool -vor v = - let n, m :: Int - n = fromInteger $ natVal' (proxy# :: Proxy# n) - m = max 8 (1 `unsafeShiftL` (64 - countLeadingZeros (n-1))) - -- - cmp :: forall t. (Elt t, Num t, IsScalar (EltR t), IsNum (EltR t), BitOrMask (EltR t) ~ Bit) - => Exp Bool - cmp = let w = SingleIntegralType (TypeWord n) - b = mkExp $ Bitcast (BitScalarType bitType) (NumScalarType (IntegralNumType w)) (unExp v) - in - mkNEq (constant 0) $ - if n == m - then b - else mkPrimUnary (PrimFromIntegral w numType) b :: Exp t - in - if n == 1 then mkExp (Bitcast scalarType scalarType (unExp v)) else - if n <= 8 then cmp @Word8 else - if n <= 16 then cmp @Word16 else - if n <= 32 then cmp @Word32 else - if n <= 64 then cmp @Word64 else - if n <= 128 then cmp @Word128 else - internalError "Can not handle Vec types with more than 128 lanes" - - --- Smart constructors for primitive applications --- - --- Operators from Floating - -mkSin :: IsFloating (EltR t) => Exp t -> Exp t -mkSin = mkPrimUnary $ PrimSin floatingType - -mkCos :: IsFloating (EltR t) => Exp t -> Exp t -mkCos = mkPrimUnary $ PrimCos floatingType - -mkTan :: IsFloating (EltR t) => Exp t -> Exp t -mkTan = mkPrimUnary $ PrimTan floatingType - -mkAsin :: IsFloating (EltR t) => Exp t -> Exp t -mkAsin = mkPrimUnary $ PrimAsin floatingType - -mkAcos :: IsFloating (EltR t) => Exp t -> Exp t -mkAcos = mkPrimUnary $ PrimAcos floatingType - -mkAtan :: IsFloating (EltR t) => Exp t -> Exp t -mkAtan = mkPrimUnary $ PrimAtan floatingType - -mkSinh :: IsFloating (EltR t) => Exp t -> Exp t -mkSinh = mkPrimUnary $ PrimSinh floatingType - -mkCosh :: IsFloating (EltR t) => Exp t -> Exp t -mkCosh = mkPrimUnary $ PrimCosh floatingType - -mkTanh :: IsFloating (EltR t) => Exp t -> Exp t -mkTanh = mkPrimUnary $ PrimTanh floatingType - -mkAsinh :: IsFloating (EltR t) => Exp t -> Exp t -mkAsinh = mkPrimUnary $ PrimAsinh floatingType - -mkAcosh :: IsFloating (EltR t) => Exp t -> Exp t -mkAcosh = mkPrimUnary $ PrimAcosh floatingType - -mkAtanh :: IsFloating (EltR t) => Exp t -> Exp t -mkAtanh = mkPrimUnary $ PrimAtanh floatingType - -mkExpFloating :: IsFloating (EltR t) => Exp t -> Exp t -mkExpFloating = mkPrimUnary $ PrimExpFloating floatingType - -mkSqrt :: IsFloating (EltR t) => Exp t -> Exp t -mkSqrt = mkPrimUnary $ PrimSqrt floatingType - -mkLog :: IsFloating (EltR t) => Exp t -> Exp t -mkLog = mkPrimUnary $ PrimLog floatingType - -mkFPow :: IsFloating (EltR t) => Exp t -> Exp t -> Exp t -mkFPow = mkPrimBinary $ PrimFPow floatingType - -mkLogBase :: IsFloating (EltR t) => Exp t -> Exp t -> Exp t -mkLogBase = mkPrimBinary $ PrimLogBase floatingType - --- Operators from Num - -mkAdd :: IsNum (EltR t) => Exp t -> Exp t -> Exp t -mkAdd = mkPrimBinary $ PrimAdd numType - -mkSub :: IsNum (EltR t) => Exp t -> Exp t -> Exp t -mkSub = mkPrimBinary $ PrimSub numType - -mkMul :: IsNum (EltR t) => Exp t -> Exp t -> Exp t -mkMul = mkPrimBinary $ PrimMul numType - -mkNeg :: IsNum (EltR t) => Exp t -> Exp t -mkNeg = mkPrimUnary $ PrimNeg numType - -mkAbs :: IsNum (EltR t) => Exp t -> Exp t -mkAbs = mkPrimUnary $ PrimAbs numType - -mkSig :: IsNum (EltR t) => Exp t -> Exp t -mkSig = mkPrimUnary $ PrimSig numType - --- Operators from Integral - -mkQuot :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t -mkQuot = mkPrimBinary $ PrimQuot integralType - -mkRem :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t -mkRem = mkPrimBinary $ PrimRem integralType - -mkQuotRem :: IsIntegral (EltR t) => Exp t -> Exp t -> (Exp t, Exp t) -mkQuotRem (Exp x) (Exp y) = - let pair = SmartExp $ PrimQuotRem integralType `PrimApp` SmartExp (Pair x y) - in ( mkExp $ Prj PairIdxLeft pair - , mkExp $ Prj PairIdxRight pair) - -mkIDiv :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t -mkIDiv = mkPrimBinary $ PrimIDiv integralType - -mkMod :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t -mkMod = mkPrimBinary $ PrimMod integralType - -mkDivMod :: IsIntegral (EltR t) => Exp t -> Exp t -> (Exp t, Exp t) -mkDivMod (Exp x) (Exp y) = - let pair = SmartExp $ PrimDivMod integralType `PrimApp` SmartExp (Pair x y) - in ( mkExp $ Prj PairIdxLeft pair - , mkExp $ Prj PairIdxRight pair) - --- Operators from Bits and FiniteBits - -mkBAnd :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t -mkBAnd = mkPrimBinary $ PrimBAnd integralType - -mkBOr :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t -mkBOr = mkPrimBinary $ PrimBOr integralType - -mkBXor :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t -mkBXor = mkPrimBinary $ PrimBXor integralType - -mkBNot :: IsIntegral (EltR t) => Exp t -> Exp t -mkBNot = mkPrimUnary $ PrimBNot integralType - -mkBShiftL :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t -mkBShiftL = mkPrimBinary $ PrimBShiftL integralType - -mkBShiftR :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t -mkBShiftR = mkPrimBinary $ PrimBShiftR integralType - -mkBRotateL :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t -mkBRotateL = mkPrimBinary $ PrimBRotateL integralType - -mkBRotateR :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t -mkBRotateR = mkPrimBinary $ PrimBRotateR integralType - -mkPopCount :: IsIntegral (EltR t) => Exp t -> Exp t -mkPopCount = mkPrimUnary $ PrimPopCount integralType - -mkCountLeadingZeros :: IsIntegral (EltR t) => Exp t -> Exp t -mkCountLeadingZeros = mkPrimUnary $ PrimCountLeadingZeros integralType - -mkCountTrailingZeros :: IsIntegral (EltR t) => Exp t -> Exp t -mkCountTrailingZeros = mkPrimUnary $ PrimCountTrailingZeros integralType - --- Operators from Fractional - -mkFDiv :: IsFloating (EltR t) => Exp t -> Exp t -> Exp t -mkFDiv = mkPrimBinary $ PrimFDiv floatingType - -mkRecip :: IsFloating (EltR t) => Exp t -> Exp t -mkRecip = mkPrimUnary $ PrimRecip floatingType - --- Operators from RealFrac - -mkTruncate :: (IsFloating (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b -mkTruncate = mkPrimUnary $ PrimTruncate floatingType integralType - -mkRound :: (IsFloating (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b -mkRound = mkPrimUnary $ PrimRound floatingType integralType - -mkFloor :: (IsFloating (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b -mkFloor = mkPrimUnary $ PrimFloor floatingType integralType - -mkCeiling :: (IsFloating (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b -mkCeiling = mkPrimUnary $ PrimCeiling floatingType integralType - --- Operators from RealFloat - -mkAtan2 :: IsFloating (EltR t) => Exp t -> Exp t -> Exp t -mkAtan2 = mkPrimBinary $ PrimAtan2 floatingType - -mkIsNaN :: (IsFloating (EltR t), BitOrMask (EltR t) ~ EltR b) => Exp t -> Exp b -mkIsNaN = mkPrimUnary $ PrimIsNaN floatingType - -mkIsInfinite :: (IsFloating (EltR t), BitOrMask (EltR t) ~ EltR b) => Exp t -> Exp b -mkIsInfinite = mkPrimUnary $ PrimIsInfinite floatingType - --- Relational and equality operators - -mkLt :: (IsScalar (EltR t), BitOrMask (EltR t) ~ EltR b) => Exp t -> Exp t -> Exp b -mkLt = mkPrimBinary $ PrimLt scalarType - -mkGt :: (IsScalar (EltR t), BitOrMask (EltR t) ~ EltR b) => Exp t -> Exp t -> Exp b -mkGt = mkPrimBinary $ PrimGt scalarType - -mkLtEq :: (IsScalar (EltR t), BitOrMask (EltR t) ~ EltR b) => Exp t -> Exp t -> Exp b -mkLtEq = mkPrimBinary $ PrimLtEq scalarType - -mkGtEq :: (IsScalar (EltR t), BitOrMask (EltR t) ~ EltR b) => Exp t -> Exp t -> Exp b -mkGtEq = mkPrimBinary $ PrimGtEq scalarType - -mkEq :: (IsScalar (EltR t), BitOrMask (EltR t) ~ EltR b) => Exp t -> Exp t -> Exp b -mkEq = mkPrimBinary $ PrimEq scalarType - -mkNEq :: (IsScalar (EltR t), BitOrMask (EltR t) ~ EltR b) => Exp t -> Exp t -> Exp b -mkNEq = mkPrimBinary $ PrimNEq scalarType - -mkMax :: IsScalar (EltR t) => Exp t -> Exp t -> Exp t -mkMax = mkPrimBinary $ PrimMax scalarType - -mkMin :: IsScalar (EltR t) => Exp t -> Exp t -> Exp t -mkMin = mkPrimBinary $ PrimMin scalarType - --- Logical operators - -mkLAnd :: IsBit (EltR t) => Exp t -> Exp t -> Exp t -mkLAnd = mkPrimBinary $ PrimLAnd bitType - -mkLOr :: IsBit (EltR t) => Exp t -> Exp t -> Exp t -mkLOr = mkPrimBinary $ PrimLOr bitType - -mkLNot :: IsBit (EltR t) => Exp t -> Exp t -mkLNot = mkPrimUnary $ PrimLNot bitType - --- Numeric conversions - -mkFromIntegral :: (IsIntegral (EltR a), IsNum (EltR b)) => Exp a -> Exp b -mkFromIntegral = mkPrimUnary $ PrimFromIntegral integralType numType - -mkToFloating :: (IsNum (EltR a), IsFloating (EltR b)) => Exp a -> Exp b -mkToFloating = mkPrimUnary $ PrimToFloating numType floatingType - -mkToBool :: (IsIntegral (EltR a), IsBit (EltR b)) => Exp a -> Exp b -mkToBool = mkPrimUnary $ PrimToBool integralType bitType - -mkFromBool :: (IsBit (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b -mkFromBool = mkPrimUnary $ PrimFromBool bitType integralType - --- Other conversions +-- Coercions between types +-- ----------------------- mkBitcast :: forall b a. (IsScalar (EltR a), IsScalar (EltR b), BitSizeEq (EltR a) (EltR b)) => Exp a -> Exp b mkBitcast (Exp a) = mkExp $ Bitcast (scalarType @(EltR a)) (scalarType @(EltR b)) a diff --git a/src/Data/Primitive/Bit.hs b/src/Data/Primitive/Bit.hs index 57ff22b82..17e2f17e4 100644 --- a/src/Data/Primitive/Bit.hs +++ b/src/Data/Primitive/Bit.hs @@ -352,7 +352,7 @@ testBitWord8# :: Word8# -> Int# -> Int# testBitWord8# x# i# = (x# `andWord8#` bitWord8# i#) `neWord8#` wordToWord8# 0## bitWord8# :: Int# -> Word8# -bitWord8# i# = (wordToWord8# 1##) `uncheckedShiftLWord8#` i# +bitWord8# i# = wordToWord8# 1## `uncheckedShiftLWord8#` i# complementWord8# :: Word8# -> Word8# complementWord8# x# = x# `xorWord8#` wordToWord8# 0xff## diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 851a7dbbd..a0bf661da 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -2,7 +2,6 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} From d11f98c526f597cee33658833a0a819bbf5607f8 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 21 Aug 2023 15:31:33 +0200 Subject: [PATCH 74/86] add more [SIMD vector] primops - VNum class for horizontal sum/product of a vector - vminimum/vmaximum for horizontal vector min/max reduction - bitreverse - byteswap --- accelerate.cabal | 1 + src/Data/Array/Accelerate.hs | 2 + src/Data/Array/Accelerate/AST.hs | 414 ++++++++++-------- src/Data/Array/Accelerate/Analysis/Hash.hs | 15 +- src/Data/Array/Accelerate/Analysis/Match.hs | 50 +-- src/Data/Array/Accelerate/Classes/Ord.hs-boot | 43 ++ src/Data/Array/Accelerate/Classes/VEq.hs | 16 + src/Data/Array/Accelerate/Classes/VEq.hs-boot | 7 +- src/Data/Array/Accelerate/Classes/VNum.hs | 81 ++++ src/Data/Array/Accelerate/Classes/VOrd.hs | 10 + .../Array/Accelerate/Classes/VOrd.hs-boot | 29 +- src/Data/Array/Accelerate/Data/Bits.hs | 26 +- src/Data/Array/Accelerate/Interpreter.hs | 140 +++--- .../Accelerate/Interpreter/Arithmetic.hs | 148 +++++-- src/Data/Array/Accelerate/Pretty/Print.hs | 140 +++--- src/Data/Array/Accelerate/Sugar/Vec.hs | 10 +- src/Data/Array/Accelerate/Trafo/Simplify.hs | 19 +- 17 files changed, 767 insertions(+), 384 deletions(-) create mode 100644 src/Data/Array/Accelerate/Classes/Ord.hs-boot create mode 100644 src/Data/Array/Accelerate/Classes/VNum.hs diff --git a/accelerate.cabal b/accelerate.cabal index 6a80e0bbd..5228911e3 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -473,6 +473,7 @@ library Data.Array.Accelerate.Classes.RealFrac Data.Array.Accelerate.Classes.ToFloating Data.Array.Accelerate.Classes.VEq + Data.Array.Accelerate.Classes.VNum Data.Array.Accelerate.Classes.VOrd Data.Array.Accelerate.Debug.Internal.Clock Data.Array.Accelerate.Debug.Internal.Flags diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index f6eaee0b3..0c6508ad3 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -323,6 +323,7 @@ module Data.Array.Accelerate ( -- *** Numeric type classes Num, (+), (-), (*), negate, abs, signum, fromInteger, + VNum(..), Integral, quot, rem, div, mod, quotRem, divMod, Rational(..), Fractional, (/), recip, fromRational, @@ -452,6 +453,7 @@ import Data.Array.Accelerate.Classes.RealFloat import Data.Array.Accelerate.Classes.RealFrac import Data.Array.Accelerate.Classes.ToFloating import Data.Array.Accelerate.Classes.VEq +import Data.Array.Accelerate.Classes.VNum import Data.Array.Accelerate.Classes.VOrd import Data.Array.Accelerate.Data.Either import Data.Array.Accelerate.Data.Maybe diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 97b4ae93e..1e162f77a 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -676,6 +676,8 @@ data PrimFun sig where PrimNeg :: NumType a -> PrimFun (a -> a) PrimAbs :: NumType a -> PrimFun (a -> a) PrimSig :: NumType a -> PrimFun (a -> a) + PrimVAdd :: NumType (Vec n a) -> PrimFun (Vec n a -> a) + PrimVMul :: NumType (Vec n a) -> PrimFun (Vec n a -> a) -- operators from Integral PrimQuot :: IntegralType a -> PrimFun ((a, a) -> a) @@ -697,6 +699,11 @@ data PrimFun sig where PrimPopCount :: IntegralType a -> PrimFun (a -> a) PrimCountLeadingZeros :: IntegralType a -> PrimFun (a -> a) PrimCountTrailingZeros :: IntegralType a -> PrimFun (a -> a) + PrimBReverse :: IntegralType a -> PrimFun (a -> a) + PrimBSwap :: IntegralType a -> PrimFun (a -> a) -- prerequisite: BitSize a % 16 == 0 + PrimVBAnd :: IntegralType (Vec n a) -> PrimFun (Vec n a -> a) + PrimVBOr :: IntegralType (Vec n a) -> PrimFun (Vec n a -> a) + PrimVBXor :: IntegralType (Vec n a) -> PrimFun (Vec n a -> a) -- operators from Fractional and Floating PrimFDiv :: FloatingType a -> PrimFun ((a, a) -> a) @@ -732,14 +739,16 @@ data PrimFun sig where PrimIsInfinite :: FloatingType a -> PrimFun (a -> BitOrMask a) -- relational and equality operators - PrimLt :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) - PrimGt :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) - PrimLtEq :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) - PrimGtEq :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) - PrimEq :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) - PrimNEq :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) - PrimMax :: ScalarType a -> PrimFun ((a, a) -> a) - PrimMin :: ScalarType a -> PrimFun ((a, a) -> a) + PrimLt :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) + PrimGt :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) + PrimLtEq :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) + PrimGtEq :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) + PrimEq :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) + PrimNEq :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) + PrimMin :: ScalarType a -> PrimFun ((a, a) -> a) + PrimMax :: ScalarType a -> PrimFun ((a, a) -> a) + PrimVMin :: ScalarType (Vec n a) -> PrimFun (Vec n a -> a) + PrimVMax :: ScalarType (Vec n a) -> PrimFun (Vec n a -> a) -- logical operators -- @@ -751,9 +760,11 @@ data PrimFun sig where -- short-circuiting, while (&&!) and (||!) are strict versions of these -- operators, which are defined using PrimLAnd and PrimLOr. -- - PrimLAnd :: BitType a -> PrimFun ((a, a) -> a) - PrimLOr :: BitType a -> PrimFun ((a, a) -> a) - PrimLNot :: BitType a -> PrimFun (a -> a) + PrimLAnd :: BitType a -> PrimFun ((a, a) -> a) + PrimLOr :: BitType a -> PrimFun ((a, a) -> a) + PrimLNot :: BitType a -> PrimFun (a -> a) + PrimVLAnd :: BitType (Vec n a) -> PrimFun (Vec n a -> a) + PrimVLOr :: BitType (Vec n a) -> PrimFun (Vec n a -> a) -- general conversion between types PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b) @@ -874,27 +885,34 @@ primFunType = \case PrimNeg t -> unary' $ num t PrimAbs t -> unary' $ num t PrimSig t -> unary' $ num t + PrimVAdd t -> unary (num t) (scalar_num t) + PrimVMul t -> unary (num t) (scalar_num t) -- Integral PrimQuot t -> binary' $ integral t PrimRem t -> binary' $ integral t - PrimQuotRem t -> unary' $ integral t `TupRpair` integral t + PrimQuotRem t -> unary' $ integral t `TupRpair` integral t PrimIDiv t -> binary' $ integral t PrimMod t -> binary' $ integral t - PrimDivMod t -> unary' $ integral t `TupRpair` integral t + PrimDivMod t -> unary' $ integral t `TupRpair` integral t -- Bits & FiniteBits PrimBAnd t -> binary' $ integral t PrimBOr t -> binary' $ integral t PrimBXor t -> binary' $ integral t - PrimBNot t -> unary' $ integral t - PrimBShiftL t -> (integral t `TupRpair` integral t, integral t) - PrimBShiftR t -> (integral t `TupRpair` integral t, integral t) - PrimBRotateL t -> (integral t `TupRpair` integral t, integral t) - PrimBRotateR t -> (integral t `TupRpair` integral t, integral t) - PrimPopCount t -> unary (integral t) (integral t) - PrimCountLeadingZeros t -> unary (integral t) (integral t) - PrimCountTrailingZeros t -> unary (integral t) (integral t) + PrimBNot t -> unary' $ integral t + PrimBShiftL t -> binary' $ integral t + PrimBShiftR t -> binary' $ integral t + PrimBRotateL t -> binary' $ integral t + PrimBRotateR t -> binary' $ integral t + PrimPopCount t -> unary' $ integral t + PrimCountLeadingZeros t -> unary' $ integral t + PrimCountTrailingZeros t -> unary' $ integral t + PrimBSwap t -> unary' $ integral t + PrimBReverse t -> unary' $ integral t + PrimVBAnd t -> unary (integral t) (scalar_integral t) + PrimVBOr t -> unary (integral t) (scalar_integral t) + PrimVBXor t -> unary (integral t) (scalar_integral t) -- Fractional, Floating PrimFDiv t -> binary' $ floating t @@ -925,8 +943,8 @@ primFunType = \case -- RealFloat PrimAtan2 t -> binary' $ floating t - PrimIsNaN t -> unary (floating t) (floating_mask t) - PrimIsInfinite t -> unary (floating t) (floating_mask t) + PrimIsNaN t -> unary (floating t) (mask_floating t) + PrimIsInfinite t -> unary (floating t) (mask_floating t) -- Relational and equality PrimLt t -> compare' t @@ -935,13 +953,17 @@ primFunType = \case PrimGtEq t -> compare' t PrimEq t -> compare' t PrimNEq t -> compare' t - PrimMax t -> binary' $ single t PrimMin t -> binary' $ single t + PrimMax t -> binary' $ single t + PrimVMin t -> unary (single t) (scalar_scalar t) + PrimVMax t -> unary (single t) (scalar_scalar t) -- Logical PrimLAnd t -> binary' (bit t) PrimLOr t -> binary' (bit t) PrimLNot t -> unary' (bit t) + PrimVLAnd t -> unary (bit t) (scalar_bit t) + PrimVLOr t -> unary (bit t) (scalar_bit t) -- general conversion between types PrimFromIntegral a b -> unary (integral a) (num b) @@ -954,7 +976,7 @@ primFunType = \case unary' a = unary a a binary a b = (a `TupRpair` a, b) binary' a = binary a a - compare' a = binary (single a) (scalar_mask a) + compare' a = binary (single a) (mask_scalar a) single = TupRsingle num = single . NumScalarType @@ -962,21 +984,21 @@ primFunType = \case integral = num . IntegralNumType floating = num . FloatingNumType - scalar_mask :: ScalarType t -> TypeR (BitOrMask t) - scalar_mask (NumScalarType t) = num_mask t - scalar_mask (BitScalarType t) = bit_mask t + mask_scalar :: ScalarType t -> TypeR (BitOrMask t) + mask_scalar (NumScalarType t) = mask_num t + mask_scalar (BitScalarType t) = mask_bit t - bit_mask :: BitType t -> TypeR (BitOrMask t) - bit_mask TypeBit = bit TypeBit - bit_mask (TypeMask n) = bit (TypeMask n) + mask_bit :: BitType t -> TypeR (BitOrMask t) + mask_bit TypeBit = bit TypeBit + mask_bit (TypeMask n) = bit (TypeMask n) - num_mask :: NumType t -> TypeR (BitOrMask t) - num_mask (IntegralNumType t) = integral_mask t - num_mask (FloatingNumType t) = floating_mask t + mask_num :: NumType t -> TypeR (BitOrMask t) + mask_num (IntegralNumType t) = mask_integral t + mask_num (FloatingNumType t) = mask_floating t - integral_mask :: IntegralType t -> TypeR (BitOrMask t) - integral_mask (VectorIntegralType n _) = single (BitScalarType (TypeMask n)) - integral_mask (SingleIntegralType t) = case t of + mask_integral :: IntegralType t -> TypeR (BitOrMask t) + mask_integral (VectorIntegralType n _) = single (BitScalarType (TypeMask n)) + mask_integral (SingleIntegralType t) = case t of TypeInt8 -> single (scalarType @Bit) TypeInt16 -> single (scalarType @Bit) TypeInt32 -> single (scalarType @Bit) @@ -988,14 +1010,35 @@ primFunType = \case TypeWord64 -> single (scalarType @Bit) TypeWord128 -> single (scalarType @Bit) - floating_mask :: FloatingType t -> TypeR (BitOrMask t) - floating_mask (VectorFloatingType n _) = single (BitScalarType (TypeMask n)) - floating_mask (SingleFloatingType t) = case t of + mask_floating :: FloatingType t -> TypeR (BitOrMask t) + mask_floating (VectorFloatingType n _) = single (BitScalarType (TypeMask n)) + mask_floating (SingleFloatingType t) = case t of TypeFloat16 -> single (scalarType @Bit) TypeFloat32 -> single (scalarType @Bit) TypeFloat64 -> single (scalarType @Bit) TypeFloat128 -> single (scalarType @Bit) + scalar_scalar :: ScalarType (Vec n a) -> TypeR a + scalar_scalar (NumScalarType t) = scalar_num t + scalar_scalar (BitScalarType t) = scalar_bit t + + scalar_bit :: BitType (Vec n a) -> TypeR a + scalar_bit TypeMask{} = bit TypeBit + + scalar_num :: NumType (Vec n a) -> TypeR a + scalar_num (IntegralNumType t) = scalar_integral t + scalar_num (FloatingNumType t) = scalar_floating t + + scalar_integral :: IntegralType (Vec n a) -> TypeR a + scalar_integral = \case + SingleIntegralType t -> case t of + VectorIntegralType _ t -> integral (SingleIntegralType t) + + scalar_floating :: FloatingType (Vec n a) -> TypeR a + scalar_floating = \case + SingleFloatingType t -> case t of + VectorFloatingType _ t -> floating (SingleFloatingType t) + -- Normal form data -- ================ @@ -1154,70 +1197,82 @@ rnfConst (TupRsingle t) !_ = rnfScalarType t -- scalars should have (nf = rnfConst (TupRpair ta tb) (a,b) = rnfConst ta a `seq` rnfConst tb b rnfPrimFun :: PrimFun f -> () -rnfPrimFun (PrimAdd t) = rnfNumType t -rnfPrimFun (PrimSub t) = rnfNumType t -rnfPrimFun (PrimMul t) = rnfNumType t -rnfPrimFun (PrimNeg t) = rnfNumType t -rnfPrimFun (PrimAbs t) = rnfNumType t -rnfPrimFun (PrimSig t) = rnfNumType t -rnfPrimFun (PrimQuot t) = rnfIntegralType t -rnfPrimFun (PrimRem t) = rnfIntegralType t -rnfPrimFun (PrimQuotRem t) = rnfIntegralType t -rnfPrimFun (PrimIDiv t) = rnfIntegralType t -rnfPrimFun (PrimMod t) = rnfIntegralType t -rnfPrimFun (PrimDivMod t) = rnfIntegralType t -rnfPrimFun (PrimBAnd t) = rnfIntegralType t -rnfPrimFun (PrimBOr t) = rnfIntegralType t -rnfPrimFun (PrimBXor t) = rnfIntegralType t -rnfPrimFun (PrimBNot t) = rnfIntegralType t -rnfPrimFun (PrimBShiftL t) = rnfIntegralType t -rnfPrimFun (PrimBShiftR t) = rnfIntegralType t -rnfPrimFun (PrimBRotateL t) = rnfIntegralType t -rnfPrimFun (PrimBRotateR t) = rnfIntegralType t -rnfPrimFun (PrimPopCount t) = rnfIntegralType t -rnfPrimFun (PrimCountLeadingZeros t) = rnfIntegralType t -rnfPrimFun (PrimCountTrailingZeros t) = rnfIntegralType t -rnfPrimFun (PrimFDiv t) = rnfFloatingType t -rnfPrimFun (PrimRecip t) = rnfFloatingType t -rnfPrimFun (PrimSin t) = rnfFloatingType t -rnfPrimFun (PrimCos t) = rnfFloatingType t -rnfPrimFun (PrimTan t) = rnfFloatingType t -rnfPrimFun (PrimAsin t) = rnfFloatingType t -rnfPrimFun (PrimAcos t) = rnfFloatingType t -rnfPrimFun (PrimAtan t) = rnfFloatingType t -rnfPrimFun (PrimSinh t) = rnfFloatingType t -rnfPrimFun (PrimCosh t) = rnfFloatingType t -rnfPrimFun (PrimTanh t) = rnfFloatingType t -rnfPrimFun (PrimAsinh t) = rnfFloatingType t -rnfPrimFun (PrimAcosh t) = rnfFloatingType t -rnfPrimFun (PrimAtanh t) = rnfFloatingType t -rnfPrimFun (PrimExpFloating t) = rnfFloatingType t -rnfPrimFun (PrimSqrt t) = rnfFloatingType t -rnfPrimFun (PrimLog t) = rnfFloatingType t -rnfPrimFun (PrimFPow t) = rnfFloatingType t -rnfPrimFun (PrimLogBase t) = rnfFloatingType t -rnfPrimFun (PrimTruncate f i) = rnfFloatingType f `seq` rnfIntegralType i -rnfPrimFun (PrimRound f i) = rnfFloatingType f `seq` rnfIntegralType i -rnfPrimFun (PrimFloor f i) = rnfFloatingType f `seq` rnfIntegralType i -rnfPrimFun (PrimCeiling f i) = rnfFloatingType f `seq` rnfIntegralType i -rnfPrimFun (PrimIsNaN t) = rnfFloatingType t -rnfPrimFun (PrimIsInfinite t) = rnfFloatingType t -rnfPrimFun (PrimAtan2 t) = rnfFloatingType t -rnfPrimFun (PrimLt t) = rnfScalarType t -rnfPrimFun (PrimGt t) = rnfScalarType t -rnfPrimFun (PrimLtEq t) = rnfScalarType t -rnfPrimFun (PrimGtEq t) = rnfScalarType t -rnfPrimFun (PrimEq t) = rnfScalarType t -rnfPrimFun (PrimNEq t) = rnfScalarType t -rnfPrimFun (PrimMax t) = rnfScalarType t -rnfPrimFun (PrimMin t) = rnfScalarType t -rnfPrimFun (PrimLAnd t) = rnfBitType t -rnfPrimFun (PrimLOr t) = rnfBitType t -rnfPrimFun (PrimLNot t) = rnfBitType t -rnfPrimFun (PrimFromIntegral i n) = rnfIntegralType i `seq` rnfNumType n -rnfPrimFun (PrimToFloating n f) = rnfNumType n `seq` rnfFloatingType f -rnfPrimFun (PrimToBool i b) = rnfIntegralType i `seq` rnfBitType b -rnfPrimFun (PrimFromBool b i) = rnfBitType b `seq` rnfIntegralType i +rnfPrimFun = \case + PrimAdd t -> rnfNumType t + PrimSub t -> rnfNumType t + PrimMul t -> rnfNumType t + PrimNeg t -> rnfNumType t + PrimAbs t -> rnfNumType t + PrimSig t -> rnfNumType t + PrimVAdd t -> rnfNumType t + PrimVMul t -> rnfNumType t + PrimQuot t -> rnfIntegralType t + PrimRem t -> rnfIntegralType t + PrimQuotRem t -> rnfIntegralType t + PrimIDiv t -> rnfIntegralType t + PrimMod t -> rnfIntegralType t + PrimDivMod t -> rnfIntegralType t + PrimBAnd t -> rnfIntegralType t + PrimBOr t -> rnfIntegralType t + PrimBXor t -> rnfIntegralType t + PrimBNot t -> rnfIntegralType t + PrimBShiftL t -> rnfIntegralType t + PrimBShiftR t -> rnfIntegralType t + PrimBRotateL t -> rnfIntegralType t + PrimBRotateR t -> rnfIntegralType t + PrimPopCount t -> rnfIntegralType t + PrimCountLeadingZeros t -> rnfIntegralType t + PrimCountTrailingZeros t -> rnfIntegralType t + PrimBReverse t -> rnfIntegralType t + PrimBSwap t -> rnfIntegralType t + PrimVBAnd t -> rnfIntegralType t + PrimVBOr t -> rnfIntegralType t + PrimVBXor t -> rnfIntegralType t + PrimFDiv t -> rnfFloatingType t + PrimRecip t -> rnfFloatingType t + PrimSin t -> rnfFloatingType t + PrimCos t -> rnfFloatingType t + PrimTan t -> rnfFloatingType t + PrimAsin t -> rnfFloatingType t + PrimAcos t -> rnfFloatingType t + PrimAtan t -> rnfFloatingType t + PrimSinh t -> rnfFloatingType t + PrimCosh t -> rnfFloatingType t + PrimTanh t -> rnfFloatingType t + PrimAsinh t -> rnfFloatingType t + PrimAcosh t -> rnfFloatingType t + PrimAtanh t -> rnfFloatingType t + PrimExpFloating t -> rnfFloatingType t + PrimSqrt t -> rnfFloatingType t + PrimLog t -> rnfFloatingType t + PrimFPow t -> rnfFloatingType t + PrimLogBase t -> rnfFloatingType t + PrimTruncate f i -> rnfFloatingType f `seq` rnfIntegralType i + PrimRound f i -> rnfFloatingType f `seq` rnfIntegralType i + PrimFloor f i -> rnfFloatingType f `seq` rnfIntegralType i + PrimCeiling f i -> rnfFloatingType f `seq` rnfIntegralType i + PrimAtan2 t -> rnfFloatingType t + PrimIsNaN t -> rnfFloatingType t + PrimIsInfinite t -> rnfFloatingType t + PrimLt t -> rnfScalarType t + PrimGt t -> rnfScalarType t + PrimLtEq t -> rnfScalarType t + PrimGtEq t -> rnfScalarType t + PrimEq t -> rnfScalarType t + PrimNEq t -> rnfScalarType t + PrimMin t -> rnfScalarType t + PrimMax t -> rnfScalarType t + PrimVMin t -> rnfScalarType t + PrimVMax t -> rnfScalarType t + PrimLAnd t -> rnfBitType t + PrimLOr t -> rnfBitType t + PrimLNot t -> rnfBitType t + PrimVLAnd t -> rnfBitType t + PrimVLOr t -> rnfBitType t + PrimFromIntegral i n -> rnfIntegralType i `seq` rnfNumType n + PrimToFloating n f -> rnfNumType n `seq` rnfFloatingType f + PrimToBool i b -> rnfIntegralType i `seq` rnfBitType b + PrimFromBool b i -> rnfBitType b `seq` rnfIntegralType i -- Template Haskell @@ -1371,77 +1426,90 @@ liftBoundary ArrayR (Array sh e) -> Boundary aenv (Array sh e) -> CodeQ (Boundary aenv (Array sh e)) -liftBoundary _ Clamp = [|| Clamp ||] -liftBoundary _ Mirror = [|| Mirror ||] -liftBoundary _ Wrap = [|| Wrap ||] -liftBoundary (ArrayR _ tp) (Constant v) = [|| Constant $$(liftElt tp v) ||] -liftBoundary _ (Function f) = [|| Function $$(liftOpenFun f) ||] +liftBoundary (ArrayR _ tR) = \case + Clamp -> [|| Clamp ||] + Mirror -> [|| Mirror ||] + Wrap -> [|| Wrap ||] + Constant v -> [|| Constant $$(liftElt tR v) ||] + Function f -> [|| Function $$(liftOpenFun f) ||] liftPrimFun :: PrimFun f -> CodeQ (PrimFun f) -liftPrimFun (PrimAdd t) = [|| PrimAdd $$(liftNumType t) ||] -liftPrimFun (PrimSub t) = [|| PrimSub $$(liftNumType t) ||] -liftPrimFun (PrimMul t) = [|| PrimMul $$(liftNumType t) ||] -liftPrimFun (PrimNeg t) = [|| PrimNeg $$(liftNumType t) ||] -liftPrimFun (PrimAbs t) = [|| PrimAbs $$(liftNumType t) ||] -liftPrimFun (PrimSig t) = [|| PrimSig $$(liftNumType t) ||] -liftPrimFun (PrimQuot t) = [|| PrimQuot $$(liftIntegralType t) ||] -liftPrimFun (PrimRem t) = [|| PrimRem $$(liftIntegralType t) ||] -liftPrimFun (PrimQuotRem t) = [|| PrimQuotRem $$(liftIntegralType t) ||] -liftPrimFun (PrimIDiv t) = [|| PrimIDiv $$(liftIntegralType t) ||] -liftPrimFun (PrimMod t) = [|| PrimMod $$(liftIntegralType t) ||] -liftPrimFun (PrimDivMod t) = [|| PrimDivMod $$(liftIntegralType t) ||] -liftPrimFun (PrimBAnd t) = [|| PrimBAnd $$(liftIntegralType t) ||] -liftPrimFun (PrimBOr t) = [|| PrimBOr $$(liftIntegralType t) ||] -liftPrimFun (PrimBXor t) = [|| PrimBXor $$(liftIntegralType t) ||] -liftPrimFun (PrimBNot t) = [|| PrimBNot $$(liftIntegralType t) ||] -liftPrimFun (PrimBShiftL t) = [|| PrimBShiftL $$(liftIntegralType t) ||] -liftPrimFun (PrimBShiftR t) = [|| PrimBShiftR $$(liftIntegralType t) ||] -liftPrimFun (PrimBRotateL t) = [|| PrimBRotateL $$(liftIntegralType t) ||] -liftPrimFun (PrimBRotateR t) = [|| PrimBRotateR $$(liftIntegralType t) ||] -liftPrimFun (PrimPopCount t) = [|| PrimPopCount $$(liftIntegralType t) ||] -liftPrimFun (PrimCountLeadingZeros t) = [|| PrimCountLeadingZeros $$(liftIntegralType t) ||] -liftPrimFun (PrimCountTrailingZeros t) = [|| PrimCountTrailingZeros $$(liftIntegralType t) ||] -liftPrimFun (PrimFDiv t) = [|| PrimFDiv $$(liftFloatingType t) ||] -liftPrimFun (PrimRecip t) = [|| PrimRecip $$(liftFloatingType t) ||] -liftPrimFun (PrimSin t) = [|| PrimSin $$(liftFloatingType t) ||] -liftPrimFun (PrimCos t) = [|| PrimCos $$(liftFloatingType t) ||] -liftPrimFun (PrimTan t) = [|| PrimTan $$(liftFloatingType t) ||] -liftPrimFun (PrimAsin t) = [|| PrimAsin $$(liftFloatingType t) ||] -liftPrimFun (PrimAcos t) = [|| PrimAcos $$(liftFloatingType t) ||] -liftPrimFun (PrimAtan t) = [|| PrimAtan $$(liftFloatingType t) ||] -liftPrimFun (PrimSinh t) = [|| PrimSinh $$(liftFloatingType t) ||] -liftPrimFun (PrimCosh t) = [|| PrimCosh $$(liftFloatingType t) ||] -liftPrimFun (PrimTanh t) = [|| PrimTanh $$(liftFloatingType t) ||] -liftPrimFun (PrimAsinh t) = [|| PrimAsinh $$(liftFloatingType t) ||] -liftPrimFun (PrimAcosh t) = [|| PrimAcosh $$(liftFloatingType t) ||] -liftPrimFun (PrimAtanh t) = [|| PrimAtanh $$(liftFloatingType t) ||] -liftPrimFun (PrimExpFloating t) = [|| PrimExpFloating $$(liftFloatingType t) ||] -liftPrimFun (PrimSqrt t) = [|| PrimSqrt $$(liftFloatingType t) ||] -liftPrimFun (PrimLog t) = [|| PrimLog $$(liftFloatingType t) ||] -liftPrimFun (PrimFPow t) = [|| PrimFPow $$(liftFloatingType t) ||] -liftPrimFun (PrimLogBase t) = [|| PrimLogBase $$(liftFloatingType t) ||] -liftPrimFun (PrimTruncate ta tb) = [|| PrimTruncate $$(liftFloatingType ta) $$(liftIntegralType tb) ||] -liftPrimFun (PrimRound ta tb) = [|| PrimRound $$(liftFloatingType ta) $$(liftIntegralType tb) ||] -liftPrimFun (PrimFloor ta tb) = [|| PrimFloor $$(liftFloatingType ta) $$(liftIntegralType tb) ||] -liftPrimFun (PrimCeiling ta tb) = [|| PrimCeiling $$(liftFloatingType ta) $$(liftIntegralType tb) ||] -liftPrimFun (PrimIsNaN t) = [|| PrimIsNaN $$(liftFloatingType t) ||] -liftPrimFun (PrimIsInfinite t) = [|| PrimIsInfinite $$(liftFloatingType t) ||] -liftPrimFun (PrimAtan2 t) = [|| PrimAtan2 $$(liftFloatingType t) ||] -liftPrimFun (PrimLt t) = [|| PrimLt $$(liftScalarType t) ||] -liftPrimFun (PrimGt t) = [|| PrimGt $$(liftScalarType t) ||] -liftPrimFun (PrimLtEq t) = [|| PrimLtEq $$(liftScalarType t) ||] -liftPrimFun (PrimGtEq t) = [|| PrimGtEq $$(liftScalarType t) ||] -liftPrimFun (PrimEq t) = [|| PrimEq $$(liftScalarType t) ||] -liftPrimFun (PrimNEq t) = [|| PrimNEq $$(liftScalarType t) ||] -liftPrimFun (PrimMax t) = [|| PrimMax $$(liftScalarType t) ||] -liftPrimFun (PrimMin t) = [|| PrimMin $$(liftScalarType t) ||] -liftPrimFun (PrimLAnd t) = [|| PrimLAnd $$(liftBitType t) ||] -liftPrimFun (PrimLOr t) = [|| PrimLOr $$(liftBitType t) ||] -liftPrimFun (PrimLNot t) = [|| PrimLNot $$(liftBitType t) ||] -liftPrimFun (PrimFromIntegral ta tb) = [|| PrimFromIntegral $$(liftIntegralType ta) $$(liftNumType tb) ||] -liftPrimFun (PrimToFloating ta tb) = [|| PrimToFloating $$(liftNumType ta) $$(liftFloatingType tb) ||] -liftPrimFun (PrimToBool ta tb) = [|| PrimToBool $$(liftIntegralType ta) $$(liftBitType tb) ||] -liftPrimFun (PrimFromBool ta tb) = [|| PrimFromBool $$(liftBitType ta) $$(liftIntegralType tb) ||] +liftPrimFun = \case + PrimAdd t -> [|| PrimAdd $$(liftNumType t) ||] + PrimSub t -> [|| PrimSub $$(liftNumType t) ||] + PrimMul t -> [|| PrimMul $$(liftNumType t) ||] + PrimNeg t -> [|| PrimNeg $$(liftNumType t) ||] + PrimAbs t -> [|| PrimAbs $$(liftNumType t) ||] + PrimSig t -> [|| PrimSig $$(liftNumType t) ||] + PrimVAdd t -> [|| PrimVAdd $$(liftNumType t) ||] + PrimVMul t -> [|| PrimVMul $$(liftNumType t) ||] + PrimQuot t -> [|| PrimQuot $$(liftIntegralType t) ||] + PrimRem t -> [|| PrimRem $$(liftIntegralType t) ||] + PrimQuotRem t -> [|| PrimQuotRem $$(liftIntegralType t) ||] + PrimIDiv t -> [|| PrimIDiv $$(liftIntegralType t) ||] + PrimMod t -> [|| PrimMod $$(liftIntegralType t) ||] + PrimDivMod t -> [|| PrimDivMod $$(liftIntegralType t) ||] + PrimBAnd t -> [|| PrimBAnd $$(liftIntegralType t) ||] + PrimBOr t -> [|| PrimBOr $$(liftIntegralType t) ||] + PrimBXor t -> [|| PrimBXor $$(liftIntegralType t) ||] + PrimBNot t -> [|| PrimBNot $$(liftIntegralType t) ||] + PrimBShiftL t -> [|| PrimBShiftL $$(liftIntegralType t) ||] + PrimBShiftR t -> [|| PrimBShiftR $$(liftIntegralType t) ||] + PrimBRotateL t -> [|| PrimBRotateL $$(liftIntegralType t) ||] + PrimBRotateR t -> [|| PrimBRotateR $$(liftIntegralType t) ||] + PrimPopCount t -> [|| PrimPopCount $$(liftIntegralType t) ||] + PrimCountLeadingZeros t -> [|| PrimCountLeadingZeros $$(liftIntegralType t) ||] + PrimCountTrailingZeros t -> [|| PrimCountTrailingZeros $$(liftIntegralType t) ||] + PrimBReverse t -> [|| PrimBReverse $$(liftIntegralType t) ||] + PrimBSwap t -> [|| PrimBSwap $$(liftIntegralType t) ||] + PrimVBAnd t -> [|| PrimVBAnd $$(liftIntegralType t) ||] + PrimVBOr t -> [|| PrimVBOr $$(liftIntegralType t) ||] + PrimVBXor t -> [|| PrimVBXor $$(liftIntegralType t) ||] + PrimFDiv t -> [|| PrimFDiv $$(liftFloatingType t) ||] + PrimRecip t -> [|| PrimRecip $$(liftFloatingType t) ||] + PrimSin t -> [|| PrimSin $$(liftFloatingType t) ||] + PrimCos t -> [|| PrimCos $$(liftFloatingType t) ||] + PrimTan t -> [|| PrimTan $$(liftFloatingType t) ||] + PrimAsin t -> [|| PrimAsin $$(liftFloatingType t) ||] + PrimAcos t -> [|| PrimAcos $$(liftFloatingType t) ||] + PrimAtan t -> [|| PrimAtan $$(liftFloatingType t) ||] + PrimSinh t -> [|| PrimSinh $$(liftFloatingType t) ||] + PrimCosh t -> [|| PrimCosh $$(liftFloatingType t) ||] + PrimTanh t -> [|| PrimTanh $$(liftFloatingType t) ||] + PrimAsinh t -> [|| PrimAsinh $$(liftFloatingType t) ||] + PrimAcosh t -> [|| PrimAcosh $$(liftFloatingType t) ||] + PrimAtanh t -> [|| PrimAtanh $$(liftFloatingType t) ||] + PrimExpFloating t -> [|| PrimExpFloating $$(liftFloatingType t) ||] + PrimSqrt t -> [|| PrimSqrt $$(liftFloatingType t) ||] + PrimLog t -> [|| PrimLog $$(liftFloatingType t) ||] + PrimFPow t -> [|| PrimFPow $$(liftFloatingType t) ||] + PrimLogBase t -> [|| PrimLogBase $$(liftFloatingType t) ||] + PrimTruncate ta tb -> [|| PrimTruncate $$(liftFloatingType ta) $$(liftIntegralType tb) ||] + PrimRound ta tb -> [|| PrimRound $$(liftFloatingType ta) $$(liftIntegralType tb) ||] + PrimFloor ta tb -> [|| PrimFloor $$(liftFloatingType ta) $$(liftIntegralType tb) ||] + PrimCeiling ta tb -> [|| PrimCeiling $$(liftFloatingType ta) $$(liftIntegralType tb) ||] + PrimIsNaN t -> [|| PrimIsNaN $$(liftFloatingType t) ||] + PrimIsInfinite t -> [|| PrimIsInfinite $$(liftFloatingType t) ||] + PrimAtan2 t -> [|| PrimAtan2 $$(liftFloatingType t) ||] + PrimLt t -> [|| PrimLt $$(liftScalarType t) ||] + PrimGt t -> [|| PrimGt $$(liftScalarType t) ||] + PrimLtEq t -> [|| PrimLtEq $$(liftScalarType t) ||] + PrimGtEq t -> [|| PrimGtEq $$(liftScalarType t) ||] + PrimEq t -> [|| PrimEq $$(liftScalarType t) ||] + PrimNEq t -> [|| PrimNEq $$(liftScalarType t) ||] + PrimMin t -> [|| PrimMin $$(liftScalarType t) ||] + PrimMax t -> [|| PrimMax $$(liftScalarType t) ||] + PrimVMin t -> [|| PrimVMin $$(liftScalarType t) ||] + PrimVMax t -> [|| PrimVMax $$(liftScalarType t) ||] + PrimLAnd t -> [|| PrimLAnd $$(liftBitType t) ||] + PrimLOr t -> [|| PrimLOr $$(liftBitType t) ||] + PrimLNot t -> [|| PrimLNot $$(liftBitType t) ||] + PrimVLAnd t -> [|| PrimVLAnd $$(liftBitType t) ||] + PrimVLOr t -> [|| PrimVLOr $$(liftBitType t) ||] + PrimFromIntegral ta tb -> [|| PrimFromIntegral $$(liftIntegralType ta) $$(liftNumType tb) ||] + PrimToFloating ta tb -> [|| PrimToFloating $$(liftNumType ta) $$(liftFloatingType tb) ||] + PrimToBool ta tb -> [|| PrimToBool $$(liftIntegralType ta) $$(liftBitType tb) ||] + PrimFromBool ta tb -> [|| PrimFromBool $$(liftBitType ta) $$(liftIntegralType tb) ||] formatDirection :: Format r (Direction -> r) diff --git a/src/Data/Array/Accelerate/Analysis/Hash.hs b/src/Data/Array/Accelerate/Analysis/Hash.hs index 1fd54f0b3..7cbc17956 100644 --- a/src/Data/Array/Accelerate/Analysis/Hash.hs +++ b/src/Data/Array/Accelerate/Analysis/Hash.hs @@ -411,6 +411,8 @@ encodePrimFun (PrimMul a) = intHost $(hashQ "PrimMul") encodePrimFun (PrimNeg a) = intHost $(hashQ "PrimNeg") <> encodeNumType a encodePrimFun (PrimAbs a) = intHost $(hashQ "PrimAbs") <> encodeNumType a encodePrimFun (PrimSig a) = intHost $(hashQ "PrimSig") <> encodeNumType a +encodePrimFun (PrimVAdd a) = intHost $(hashQ "PrimVAdd") <> encodeNumType a +encodePrimFun (PrimVMul a) = intHost $(hashQ "PrimVMul") <> encodeNumType a encodePrimFun (PrimQuot a) = intHost $(hashQ "PrimQuot") <> encodeIntegralType a encodePrimFun (PrimRem a) = intHost $(hashQ "PrimRem") <> encodeIntegralType a encodePrimFun (PrimQuotRem a) = intHost $(hashQ "PrimQuotRem") <> encodeIntegralType a @@ -428,6 +430,11 @@ encodePrimFun (PrimBRotateR a) = intHost $(hashQ "PrimBRotateR") encodePrimFun (PrimPopCount a) = intHost $(hashQ "PrimPopCount") <> encodeIntegralType a encodePrimFun (PrimCountLeadingZeros a) = intHost $(hashQ "PrimCountLeadingZeros") <> encodeIntegralType a encodePrimFun (PrimCountTrailingZeros a) = intHost $(hashQ "PrimCountTrailingZeros") <> encodeIntegralType a +encodePrimFun (PrimBReverse a) = intHost $(hashQ "PrimBReverse") <> encodeIntegralType a +encodePrimFun (PrimBSwap a) = intHost $(hashQ "PrimBSwap") <> encodeIntegralType a +encodePrimFun (PrimVBAnd a) = intHost $(hashQ "PrimVBAnd") <> encodeIntegralType a +encodePrimFun (PrimVBOr a) = intHost $(hashQ "PrimVBOr") <> encodeIntegralType a +encodePrimFun (PrimVBXor a) = intHost $(hashQ "PrimVBXor") <> encodeIntegralType a encodePrimFun (PrimFDiv a) = intHost $(hashQ "PrimFDiv") <> encodeFloatingType a encodePrimFun (PrimRecip a) = intHost $(hashQ "PrimRecip") <> encodeFloatingType a encodePrimFun (PrimSin a) = intHost $(hashQ "PrimSin") <> encodeFloatingType a @@ -460,13 +467,17 @@ encodePrimFun (PrimLtEq a) = intHost $(hashQ "PrimLtEq") encodePrimFun (PrimGtEq a) = intHost $(hashQ "PrimGtEq") <> encodeScalarType a encodePrimFun (PrimEq a) = intHost $(hashQ "PrimEq") <> encodeScalarType a encodePrimFun (PrimNEq a) = intHost $(hashQ "PrimNEq") <> encodeScalarType a -encodePrimFun (PrimMax a) = intHost $(hashQ "PrimMax") <> encodeScalarType a encodePrimFun (PrimMin a) = intHost $(hashQ "PrimMin") <> encodeScalarType a +encodePrimFun (PrimMax a) = intHost $(hashQ "PrimMax") <> encodeScalarType a +encodePrimFun (PrimVMin a) = intHost $(hashQ "PrimVMin") <> encodeScalarType a +encodePrimFun (PrimVMax a) = intHost $(hashQ "PrimVMax") <> encodeScalarType a encodePrimFun (PrimLAnd a) = intHost $(hashQ "PrimLAnd") <> encodeBitType a encodePrimFun (PrimLOr a) = intHost $(hashQ "PrimLOr") <> encodeBitType a encodePrimFun (PrimLNot a) = intHost $(hashQ "PrimLNot") <> encodeBitType a +encodePrimFun (PrimVLAnd a) = intHost $(hashQ "PrimVLAnd") <> encodeBitType a +encodePrimFun (PrimVLOr a) = intHost $(hashQ "PrimVLOr") <> encodeBitType a encodePrimFun (PrimFromIntegral a b) = intHost $(hashQ "PrimFromIntegral") <> encodeIntegralType a <> encodeNumType b -encodePrimFun (PrimToFloating a b) = intHost $(hashQ "PrimToFloating") <> encodeNumType a <> encodeFloatingType b +encodePrimFun (PrimToFloating a b) = intHost $(hashQ "PrimToFloating") <> encodeNumType a <> encodeFloatingType b encodePrimFun (PrimToBool a b) = intHost $(hashQ "PrimToBool") <> encodeIntegralType a <> encodeBitType b encodePrimFun (PrimFromBool a b) = intHost $(hashQ "PrimFromBool") <> encodeBitType a <> encodeIntegralType b diff --git a/src/Data/Array/Accelerate/Analysis/Match.hs b/src/Data/Array/Accelerate/Analysis/Match.hs index 147f93adb..11e1dc7ab 100644 --- a/src/Data/Array/Accelerate/Analysis/Match.hs +++ b/src/Data/Array/Accelerate/Analysis/Match.hs @@ -630,41 +630,41 @@ matchConst (TupRsingle ty) a b = evalEq ty (a,b) matchConst (TupRpair ta tb) (a1,b1) (a2,b2) = matchConst ta a1 a2 && matchConst tb b1 b2 evalEq :: ScalarType a -> (a, a) -> Bool -evalEq t x = lall (scalar t) (eq t x) +evalEq t x = unBit $ scalar t (eq t x) where - scalar :: ScalarType s -> BitType (BitOrMask s) + scalar :: ScalarType s -> BitOrMask s -> PrimBool scalar (NumScalarType s) = num s scalar (BitScalarType s) = bit s - bit :: BitType s -> BitType (BitOrMask s) - bit TypeBit = TypeBit - bit (TypeMask n) = TypeMask n + bit :: BitType s -> BitOrMask s -> PrimBool + bit TypeBit = id + bit TypeMask{} = vland bitType - num :: NumType s -> BitType (BitOrMask s) + num :: NumType s -> BitOrMask s -> PrimBool num (IntegralNumType s) = integral s num (FloatingNumType s) = floating s - integral :: IntegralType s -> BitType (BitOrMask s) - integral (VectorIntegralType n _) = TypeMask n + integral :: IntegralType s -> BitOrMask s -> PrimBool + integral (VectorIntegralType n _) = vland (TypeMask n) integral (SingleIntegralType s) = case s of - TypeInt8 -> TypeBit - TypeInt16 -> TypeBit - TypeInt32 -> TypeBit - TypeInt64 -> TypeBit - TypeInt128 -> TypeBit - TypeWord8 -> TypeBit - TypeWord16 -> TypeBit - TypeWord32 -> TypeBit - TypeWord64 -> TypeBit - TypeWord128 -> TypeBit - - floating :: FloatingType s -> BitType (BitOrMask s) - floating (VectorFloatingType n _) = TypeMask n + TypeInt8 -> id + TypeInt16 -> id + TypeInt32 -> id + TypeInt64 -> id + TypeInt128 -> id + TypeWord8 -> id + TypeWord16 -> id + TypeWord32 -> id + TypeWord64 -> id + TypeWord128 -> id + + floating :: FloatingType s -> BitOrMask s -> PrimBool + floating (VectorFloatingType n _) = vland (TypeMask n) floating (SingleFloatingType s) = case s of - TypeFloat16 -> TypeBit - TypeFloat32 -> TypeBit - TypeFloat64 -> TypeBit - TypeFloat128 -> TypeBit + TypeFloat16 -> id + TypeFloat32 -> id + TypeFloat64 -> id + TypeFloat128 -> id -- Environment projection indices diff --git a/src/Data/Array/Accelerate/Classes/Ord.hs-boot b/src/Data/Array/Accelerate/Classes/Ord.hs-boot new file mode 100644 index 000000000..5a1354c38 --- /dev/null +++ b/src/Data/Array/Accelerate/Classes/Ord.hs-boot @@ -0,0 +1,43 @@ +{-# LANGUAGE NoImplicitPrelude #-} +-- | +-- Module : Data.Array.Accelerate.Classes.VOrd +-- Copyright : [2016..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Classes.Ord ( + + Ord(..), + Ordering, + +) where + +import Data.Array.Accelerate.Classes.Eq +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Pattern.Ordering + + +class Eq a => Ord a where + {-# MINIMAL (<=) | compare #-} + (<) :: Exp a -> Exp a -> Exp Bool + (>) :: Exp a -> Exp a -> Exp Bool + (<=) :: Exp a -> Exp a -> Exp Bool + (>=) :: Exp a -> Exp a -> Exp Bool + min :: Exp a -> Exp a -> Exp a + max :: Exp a -> Exp a -> Exp a + compare :: Exp a -> Exp a -> Exp Ordering + + (<) = undefined + (<=) = undefined + (>) = undefined + (>=) = undefined + + min = undefined + max = undefined + + compare = undefined + diff --git a/src/Data/Array/Accelerate/Classes/VEq.hs b/src/Data/Array/Accelerate/Classes/VEq.hs index 62a26e52d..1a248de22 100644 --- a/src/Data/Array/Accelerate/Classes/VEq.hs +++ b/src/Data/Array/Accelerate/Classes/VEq.hs @@ -24,6 +24,8 @@ module Data.Array.Accelerate.Classes.VEq ( (&&*), (||*), vnot, + vand, + vor, ) where @@ -70,6 +72,20 @@ infixr 2 ||* vnot :: KnownNat n => Exp (Vec n Bool) -> Exp (Vec n Bool) vnot = mkPrimUnary $ PrimLNot bitType +-- | Return 'True' if all lanes of the vector are 'True' +-- +-- @since 1.4.0.0 +-- +vand :: KnownNat n => Exp (Vec n Bool) -> Exp Bool +vand = mkPrimUnary $ PrimVLAnd bitType + +-- | Return 'True' if any lane of the vector is 'True' +-- +-- @since 1.4.0.0 +-- +vor :: KnownNat n => Exp (Vec n Bool) -> Exp Bool +vor = mkPrimUnary $ PrimVLOr bitType + infix 4 ==* infix 4 /=* diff --git a/src/Data/Array/Accelerate/Classes/VEq.hs-boot b/src/Data/Array/Accelerate/Classes/VEq.hs-boot index 082de6b18..4b99452f4 100644 --- a/src/Data/Array/Accelerate/Classes/VEq.hs-boot +++ b/src/Data/Array/Accelerate/Classes/VEq.hs-boot @@ -19,8 +19,9 @@ class SIMD n a => VEq n a where (==*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) (/=*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) {-# MINIMAL (==*) | (/=*) #-} - x ==* y = vnot (x /=* y) - x /=* y = vnot (x ==* y) + (==*) = undefined + (/=*) = undefined -vnot :: KnownNat n => Exp (Vec n Bool) -> Exp (Vec n Bool) +vand :: KnownNat n => Exp (Vec n Bool) -> Exp Bool +vor :: KnownNat n => Exp (Vec n Bool) -> Exp Bool diff --git a/src/Data/Array/Accelerate/Classes/VNum.hs b/src/Data/Array/Accelerate/Classes/VNum.hs new file mode 100644 index 000000000..c1964b092 --- /dev/null +++ b/src/Data/Array/Accelerate/Classes/VNum.hs @@ -0,0 +1,81 @@ +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TemplateHaskell #-} +-- | +-- Module : Data.Array.Accelerate.Classes.VNum +-- Copyright : [2016..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com> +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Classes.VNum ( + + VNum(..), + +) where + +import Data.Array.Accelerate.AST ( PrimFun(..) ) +import Data.Array.Accelerate.Sugar.Vec +import Data.Array.Accelerate.Type +import Data.Array.Accelerate.Smart + +import Language.Haskell.TH.Extra hiding ( Type, Exp ) + + +-- | The 'VNum' class defines numeric operations over SIMD vectors. +-- +-- @since 1.4.0.0 +-- +class SIMD n a => VNum n a where + -- | Horizontal reduction of a vector with addition. This operation is not + -- guaranteed to preserve the associativity of an equivalent scalarised + -- counterpart. + vadd :: Exp (Vec n a) -> Exp a + + -- | Horizontal reduction of a vector with multiplication. This operation is + -- not guaranteed to preserve the associativity of an equivalent scalarised + -- counterpart. + vmul :: Exp (Vec n a) -> Exp a + + +runQ $ + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + , ''Float128 + ] + + numTypes :: [Name] + numTypes = integralTypes ++ floatingTypes + + thVNum :: Name -> Q [Dec] + thVNum name = + [d| instance KnownNat n => VNum n $(conT name) where + vadd = mkPrimUnary $ PrimVAdd numType + vmul = mkPrimUnary $ PrimVMul numType + |] + in + concat <$> mapM thVNum numTypes + diff --git a/src/Data/Array/Accelerate/Classes/VOrd.hs b/src/Data/Array/Accelerate/Classes/VOrd.hs index 6f041fb3f..bd4971eff 100644 --- a/src/Data/Array/Accelerate/Classes/VOrd.hs +++ b/src/Data/Array/Accelerate/Classes/VOrd.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -56,6 +57,8 @@ class VEq n a => VOrd n a where (>=*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) vmin :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n a) vmax :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n a) + vminimum :: Exp (Vec n a) -> Exp a + vmaximum :: Exp (Vec n a) -> Exp a vcompare :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Ordering) x <* y = select (vcompare x y ==* vlt) vtrue vfalse @@ -66,6 +69,11 @@ class VEq n a => VOrd n a where vmin x y = select (x <=* y) x y vmax x y = select (x <=* y) y x + default vminimum :: Ord a => Exp (Vec n a) -> Exp a + default vmaximum :: Ord a => Exp (Vec n a) -> Exp a + vminimum x = P.minimum (unpack x) + vmaximum x = P.maximum (unpack x) + vcompare x y = select (x ==* y) veq $ select (x <=* y) vlt vgt @@ -122,6 +130,8 @@ runQ $ do (>=*) = mkPrimBinary $ PrimGtEq scalarType vmin = mkPrimBinary $ PrimMin scalarType vmax = mkPrimBinary $ PrimMax scalarType + vminimum = mkPrimUnary $ PrimVMin scalarType + vmaximum = mkPrimUnary $ PrimVMax scalarType |] mkTup :: Word8 -> Q Dec diff --git a/src/Data/Array/Accelerate/Classes/VOrd.hs-boot b/src/Data/Array/Accelerate/Classes/VOrd.hs-boot index 083f879e2..f379e784b 100644 --- a/src/Data/Array/Accelerate/Classes/VOrd.hs-boot +++ b/src/Data/Array/Accelerate/Classes/VOrd.hs-boot @@ -1,3 +1,4 @@ +{-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE MultiParamTypeClasses #-} -- | -- Module : Data.Array.Accelerate.Classes.VOrd @@ -15,9 +16,12 @@ module Data.Array.Accelerate.Classes.VOrd ( ) where +import Data.Array.Accelerate.Classes.VEq import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Vec -import Data.Array.Accelerate.Classes.VEq +import {-# SOURCE #-} Data.Array.Accelerate.Classes.Ord + +import Prelude hiding ( Ord(..), Ordering(..), (<*) ) class VEq n a => VOrd n a where @@ -28,19 +32,22 @@ class VEq n a => VOrd n a where (>=*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) vmin :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n a) vmax :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n a) + vminimum :: Exp (Vec n a) -> Exp a + vmaximum :: Exp (Vec n a) -> Exp a vcompare :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Ordering) - x <* y = select (vcompare x y ==* vlt) vtrue vfalse - x <=* y = select (vcompare x y ==* vgt) vfalse vtrue - x >* y = select (vcompare x y ==* vgt) vtrue vfalse - x >=* y = select (vcompare x y ==* vlt) vfalse vtrue + (<*) = undefined + (<=*) = undefined + (>*) = undefined + (>=*) = undefined - vmin x y = select (x <=* y) x y - vmax x y = select (x <=* y) y x + vmin = undefined + vmax = undefined - vcompare x y - = select (x ==* y) veq - $ select (x <=* y) vlt vgt + default vminimum :: Ord a => Exp (Vec n a) -> Exp a + default vmaximum :: Ord a => Exp (Vec n a) -> Exp a + vminimum = undefined + vmaximum = undefined -vtrue, vfalse :: KnownNat n => Exp (Vec n Bool) + vcompare = undefined diff --git a/src/Data/Array/Accelerate/Data/Bits.hs b/src/Data/Array/Accelerate/Data/Bits.hs index d60a61872..491a4e159 100644 --- a/src/Data/Array/Accelerate/Data/Bits.hs +++ b/src/Data/Array/Accelerate/Data/Bits.hs @@ -30,7 +30,6 @@ module Data.Array.Accelerate.Data.Bits ( ) where -import Data.Array.Accelerate.AST ( BitOrMask ) import Data.Array.Accelerate.AST ( PrimFun(..), BitOrMask ) import Data.Array.Accelerate.Language import Data.Array.Accelerate.Smart @@ -163,6 +162,18 @@ class Bits a => FiniteBits a where -- countTrailingZeros :: Exp a -> Exp a + -- | Reverse the order of bits + -- + -- @since 1.4.0.0 + -- + bitreverse :: Exp a -> Exp a + + -- | Reverse the order of bytes + -- + -- @since 1.4.0.0 + -- + byteswap :: Exp a -> Exp a + -- Instances for Bits -- ------------------ @@ -187,6 +198,8 @@ instance FiniteBits Bool where finiteBitSize _ = fromInteger (toInteger (B.finiteBitSize (undefined::Bool))) countLeadingZeros x = x countTrailingZeros x = x + bitreverse x = x + byteswap x = x -- Default implementations @@ -370,6 +383,13 @@ mkCountLeadingZeros = mkPrimUnary $ PrimCountLeadingZeros integralType mkCountTrailingZeros :: IsIntegral (EltR t) => Exp t -> Exp t mkCountTrailingZeros = mkPrimUnary $ PrimCountTrailingZeros integralType +mkBReverse :: IsIntegral (EltR t) => Exp t -> Exp t +mkBReverse = mkPrimUnary $ PrimBReverse integralType + +mkBSwap :: IsIntegral (EltR t) => Exp t -> Exp t +mkBSwap = mkPrimUnary $ PrimBSwap integralType + + runQ $ let integralTypes :: [Name] @@ -430,11 +450,15 @@ runQ $ finiteBitSize _ = fromInteger (toInteger (B.finiteBitSize (undefined :: $(conT a)))) countLeadingZeros = mkCountLeadingZeros countTrailingZeros = mkCountTrailingZeros + bitreverse = mkBReverse + byteswap = mkBSwap instance KnownNat n => FiniteBits (Vec n $(conT a)) where finiteBitSize _ = fromInteger (natVal' (proxy# :: Proxy# n) * toInteger (B.finiteBitSize (undefined :: $(conT a)))) countLeadingZeros = mkCountLeadingZeros countTrailingZeros = mkCountTrailingZeros + bitreverse = mkBReverse + byteswap = mkBSwap |] in concat <$> mapM thBits integralTypes diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index 60a2d55bb..f49e06526 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -1246,70 +1246,82 @@ evalCoerceScalar VectorScalarType{} (SingleScalarType tb) a = scalar tb a -- ----------------- evalPrim :: PrimFun (a -> r) -> (a -> r) -evalPrim (PrimAdd t) = A.add t -evalPrim (PrimSub t) = A.sub t -evalPrim (PrimMul t) = A.mul t -evalPrim (PrimNeg t) = A.negate t -evalPrim (PrimAbs t) = A.abs t -evalPrim (PrimSig t) = A.signum t -evalPrim (PrimQuot t) = A.quot t -evalPrim (PrimRem t) = A.rem t -evalPrim (PrimQuotRem t) = A.quotRem t -evalPrim (PrimIDiv t) = A.div t -evalPrim (PrimMod t) = A.mod t -evalPrim (PrimDivMod t) = A.divMod t -evalPrim (PrimBAnd t) = A.band t -evalPrim (PrimBOr t) = A.bor t -evalPrim (PrimBXor t) = A.xor t -evalPrim (PrimBNot t) = A.complement t -evalPrim (PrimBShiftL t) = A.shiftL t -evalPrim (PrimBShiftR t) = A.shiftR t -evalPrim (PrimBRotateL t) = A.rotateL t -evalPrim (PrimBRotateR t) = A.rotateR t -evalPrim (PrimPopCount t) = A.popCount t -evalPrim (PrimCountLeadingZeros t) = A.countLeadingZeros t -evalPrim (PrimCountTrailingZeros t) = A.countTrailingZeros t -evalPrim (PrimFDiv t) = A.fdiv t -evalPrim (PrimRecip t) = A.recip t -evalPrim (PrimSin t) = A.sin t -evalPrim (PrimCos t) = A.cos t -evalPrim (PrimTan t) = A.tan t -evalPrim (PrimAsin t) = A.asin t -evalPrim (PrimAcos t) = A.acos t -evalPrim (PrimAtan t) = A.atan t -evalPrim (PrimSinh t) = A.sinh t -evalPrim (PrimCosh t) = A.cosh t -evalPrim (PrimTanh t) = A.tanh t -evalPrim (PrimAsinh t) = A.asinh t -evalPrim (PrimAcosh t) = A.acosh t -evalPrim (PrimAtanh t) = A.atanh t -evalPrim (PrimExpFloating t) = A.exp t -evalPrim (PrimSqrt t) = A.sqrt t -evalPrim (PrimLog t) = A.log t -evalPrim (PrimFPow t) = A.pow t -evalPrim (PrimLogBase t) = A.logBase t -evalPrim (PrimTruncate ta tb) = A.truncate ta tb -evalPrim (PrimRound ta tb) = A.round ta tb -evalPrim (PrimFloor ta tb) = A.floor ta tb -evalPrim (PrimCeiling ta tb) = A.ceiling ta tb -evalPrim (PrimAtan2 t) = A.atan2 t -evalPrim (PrimIsNaN t) = A.isNaN t -evalPrim (PrimIsInfinite t) = A.isInfinite t -evalPrim (PrimLt t) = A.lt t -evalPrim (PrimGt t) = A.gt t -evalPrim (PrimLtEq t) = A.lte t -evalPrim (PrimGtEq t) = A.gte t -evalPrim (PrimEq t) = A.eq t -evalPrim (PrimNEq t) = A.neq t -evalPrim (PrimMax t) = A.max t -evalPrim (PrimMin t) = A.min t -evalPrim (PrimLAnd t) = A.land t -evalPrim (PrimLOr t) = A.lor t -evalPrim (PrimLNot t) = A.lnot t -evalPrim (PrimFromIntegral ta tb) = A.fromIntegral ta tb -evalPrim (PrimToFloating ta tb) = A.toFloating ta tb -evalPrim (PrimToBool i b) = A.toBool i b -evalPrim (PrimFromBool b i) = A.fromBool b i +evalPrim = \case + PrimAdd t -> A.add t + PrimSub t -> A.sub t + PrimMul t -> A.mul t + PrimNeg t -> A.negate t + PrimAbs t -> A.abs t + PrimSig t -> A.signum t + PrimVAdd t -> A.vadd t + PrimVMul t -> A.vmul t + PrimQuot t -> A.quot t + PrimRem t -> A.rem t + PrimQuotRem t -> A.quotRem t + PrimIDiv t -> A.div t + PrimMod t -> A.mod t + PrimDivMod t -> A.divMod t + PrimBAnd t -> A.band t + PrimBOr t -> A.bor t + PrimBXor t -> A.xor t + PrimBNot t -> A.complement t + PrimBShiftL t -> A.shiftL t + PrimBShiftR t -> A.shiftR t + PrimBRotateL t -> A.rotateL t + PrimBRotateR t -> A.rotateR t + PrimPopCount t -> A.popCount t + PrimCountLeadingZeros t -> A.countLeadingZeros t + PrimCountTrailingZeros t -> A.countTrailingZeros t + PrimBReverse t -> A.bitreverse t + PrimBSwap t -> A.byteswap t + PrimVBAnd t -> A.vband t + PrimVBOr t -> A.vbor t + PrimVBXor t -> A.vbxor t + PrimFDiv t -> A.fdiv t + PrimRecip t -> A.recip t + PrimSin t -> A.sin t + PrimCos t -> A.cos t + PrimTan t -> A.tan t + PrimAsin t -> A.asin t + PrimAcos t -> A.acos t + PrimAtan t -> A.atan t + PrimSinh t -> A.sinh t + PrimCosh t -> A.cosh t + PrimTanh t -> A.tanh t + PrimAsinh t -> A.asinh t + PrimAcosh t -> A.acosh t + PrimAtanh t -> A.atanh t + PrimExpFloating t -> A.exp t + PrimSqrt t -> A.sqrt t + PrimLog t -> A.log t + PrimFPow t -> A.pow t + PrimLogBase t -> A.logBase t + PrimTruncate ta tb -> A.truncate ta tb + PrimRound ta tb -> A.round ta tb + PrimFloor ta tb -> A.floor ta tb + PrimCeiling ta tb -> A.ceiling ta tb + PrimAtan2 t -> A.atan2 t + PrimIsNaN t -> A.isNaN t + PrimIsInfinite t -> A.isInfinite t + PrimLt t -> A.lt t + PrimGt t -> A.gt t + PrimLtEq t -> A.lte t + PrimGtEq t -> A.gte t + PrimEq t -> A.eq t + PrimNEq t -> A.neq t + PrimMin t -> A.min t + PrimMax t -> A.max t + PrimVMin t -> A.vmin t + PrimVMax t -> A.vmax t + PrimLAnd t -> A.land t + PrimLOr t -> A.lor t + PrimLNot t -> A.lnot t + PrimVLAnd t -> A.vland t + PrimVLOr t -> A.vlor t + PrimFromIntegral ta tb -> A.fromIntegral ta tb + PrimToFloating ta tb -> A.toFloating ta tb + PrimToBool i b -> A.toBool i b + PrimFromBool b i -> A.fromBool b i -- Vector primitives diff --git a/src/Data/Array/Accelerate/Interpreter/Arithmetic.hs b/src/Data/Array/Accelerate/Interpreter/Arithmetic.hs index 6850a1684..80d02415e 100644 --- a/src/Data/Array/Accelerate/Interpreter/Arithmetic.hs +++ b/src/Data/Array/Accelerate/Interpreter/Arithmetic.hs @@ -1,4 +1,5 @@ {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE EmptyCase #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} @@ -19,14 +20,14 @@ module Data.Array.Accelerate.Interpreter.Arithmetic ( - add, sub, mul, negate, abs, signum, + add, sub, mul, negate, abs, signum, vadd, vmul, quot, rem, quotRem, div, mod, divMod, - band, bor, xor, complement, shiftL, shiftR, rotateL, rotateR, popCount, countLeadingZeros, countTrailingZeros, + band, bor, xor, complement, shiftL, shiftR, rotateL, rotateR, popCount, countLeadingZeros, countTrailingZeros, bitreverse, byteswap, vband, vbor, vbxor, fdiv, recip, sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh, exp, sqrt, log, pow, logBase, truncate, round, floor, ceiling, atan2, isNaN, isInfinite, - lt, gt, lte, gte, eq, neq, min, max, - land, lor, lnot, lall, lany, + lt, gt, lte, gte, eq, neq, min, max, vmin, vmax, + land, lor, lnot, vland, vlor, fromIntegral, toFloating, toBool, fromBool, ) where @@ -36,6 +37,8 @@ import Data.Array.Accelerate.Error import Data.Array.Accelerate.Type import Data.Primitive.Bit ( BitMask(..) ) +import Data.Primitive.Vec ( Vec ) +import qualified Data.Primitive.Bit as Bit import qualified Data.Primitive.Vec as Vec import Data.Primitive.Types @@ -49,6 +52,8 @@ import Prelude ( ($), (.) ) import qualified Data.Bits as P import qualified Prelude as P +import GHC.Int +import GHC.Word import GHC.Exts import GHC.TypeLits import GHC.TypeLits.Extra @@ -75,6 +80,12 @@ abs = num1 P.abs signum :: NumType a -> (a -> a) signum = num1 P.signum +vadd :: NumType (Vec n a) -> (Vec n a -> a) +vadd = vnum2 (P.+) + +vmul :: NumType (Vec n a) -> (Vec n a -> a) +vmul = vnum2 (P.*) + num1 :: (forall t. P.Num t => t -> t) -> NumType a -> (a -> a) num1 f = \case IntegralNumType t -> integral t @@ -101,6 +112,20 @@ num2 f = \case floating (SingleFloatingType t) | FloatingDict <- floatingDict t = P.uncurry f floating (VectorFloatingType _ t) | FloatingDict <- floatingDict t = P.uncurry (zipWith f) +vnum2 :: (forall t. P.Num t => t -> t -> t) -> NumType (Vec n a) -> (Vec n a -> a) +vnum2 f = \case + IntegralNumType t -> integral t + FloatingNumType t -> floating t + where + integral :: IntegralType (Vec n t) -> (Vec n t -> t) + integral = \case + SingleIntegralType t -> case t of + VectorIntegralType _ t | IntegralDict <- integralDict t -> P.foldl1 f . Vec.toList + + floating :: FloatingType (Vec n t) -> (Vec n t -> t) + floating = \case + SingleFloatingType t -> case t of + VectorFloatingType _ t | FloatingDict <- floatingDict t -> P.foldl1 f . Vec.toList -- Operators from Integral -- ----------------------- @@ -168,6 +193,49 @@ countLeadingZeros = bits1 (P.fromIntegral . P.countLeadingZeros) countTrailingZeros :: IntegralType a -> (a -> a) countTrailingZeros = bits1 (P.fromIntegral . P.countTrailingZeros) +bitreverse :: IntegralType a -> (a -> a) +bitreverse = \case + VectorIntegralType _ t | IntegralDict <- integralDict t -> Vec.fromList . P.map (bitreverse (SingleIntegralType t)) . Vec.toList + SingleIntegralType t -> single t + where + single :: SingleIntegralType t -> t -> t + single TypeInt8 (I8# i#) = I8# (word8ToInt8# (wordToWord8# (bitReverse8# (word8ToWord# (int8ToWord8# i#))))) + single TypeInt16 (I16# i#) = I16# (word16ToInt16# (wordToWord16# (bitReverse16# (word16ToWord# (int16ToWord16# i#))))) + single TypeInt32 (I32# i#) = I32# (word32ToInt32# (wordToWord32# (bitReverse32# (word32ToWord# (int32ToWord32# i#))))) + single TypeInt64 (I64# i#) = I64# (word64ToInt64# (bitReverse64# (int64ToWord64# i#))) + single TypeInt128 (Int128 h l) = Int128 (bitreverse integralType l) (bitreverse integralType h) + single TypeWord8 (W8# w#) = W8# (wordToWord8# (bitReverse8# (word8ToWord# w#))) + single TypeWord16 (W16# w#) = W16# (wordToWord16# (bitReverse16# (word16ToWord# w#))) + single TypeWord32 (W32# w#) = W32# (wordToWord32# (bitReverse32# (word32ToWord# w#))) + single TypeWord64 (W64# w#) = W64# (bitReverse64# w#) + single TypeWord128 (Word128 h l) = Word128 (bitreverse integralType l) (bitreverse integralType h) + +byteswap :: IntegralType a -> (a -> a) +byteswap = \case + VectorIntegralType _ t | IntegralDict <- integralDict t -> Vec.fromList . P.map (byteswap (SingleIntegralType t)) . Vec.toList + SingleIntegralType t -> single t + where + single :: SingleIntegralType t -> t -> t + single TypeInt8 x = x + single TypeInt16 (I16# i#) = I16# (word16ToInt16# (wordToWord16# (byteSwap16# (word16ToWord# (int16ToWord16# i#))))) + single TypeInt32 (I32# i#) = I32# (word32ToInt32# (wordToWord32# (byteSwap32# (word32ToWord# (int32ToWord32# i#))))) + single TypeInt64 (I64# i#) = I64# (word64ToInt64# (byteSwap64# (int64ToWord64# i#))) + single TypeInt128 (Int128 h l) = Int128 (byteswap integralType l) (byteswap integralType h) + single TypeWord8 x = x + single TypeWord16 (W16# w#) = W16# (wordToWord16# (byteSwap16# (word16ToWord# w#))) + single TypeWord32 (W32# w#) = W32# (wordToWord32# (byteSwap32# (word32ToWord# w#))) + single TypeWord64 (W64# w#) = W64# (byteSwap64# w#) + single TypeWord128 (Word128 h l) = Word128 (byteswap integralType l) (byteswap integralType h) + +vband :: IntegralType (Vec n a) -> (Vec n a -> a) +vband = vbits2 (.&.) + +vbor :: IntegralType (Vec n a) -> (Vec n a -> a) +vbor = vbits2 (.|.) + +vbxor :: IntegralType (Vec n a) -> (Vec n a -> a) +vbxor = vbits2 P.xor + bits1 :: (forall t. (P.Integral t, P.FiniteBits t) => t -> t) -> IntegralType a -> (a -> a) bits1 f (SingleIntegralType t) | IntegralDict <- integralDict t = f bits1 f (VectorIntegralType _ t) | IntegralDict <- integralDict t = map f @@ -176,6 +244,10 @@ bits2 :: (forall t. (P.Integral t, P.FiniteBits t) => t -> t -> t) -> IntegralTy bits2 f (SingleIntegralType t) | IntegralDict <- integralDict t = P.uncurry f bits2 f (VectorIntegralType _ t) | IntegralDict <- integralDict t = P.uncurry (zipWith f) +vbits2 :: (forall t. P.FiniteBits t => t -> t -> t) -> IntegralType (Vec n a) -> (Vec n a -> a) +vbits2 _ (SingleIntegralType t) = case t of +vbits2 f (VectorIntegralType _ t) | IntegralDict <- integralDict t = P.foldl1 f . Vec.toList + -- Operators from Fractional and Floating -- -------------------------------------- @@ -367,47 +439,61 @@ cmp f = \case floating (VectorFloatingType _ t) | FloatingDict <- floatingDict t = P.uncurry (unMask $$ zipWith (Bit $$ f)) min :: ScalarType a -> ((a, a) -> a) -min = \case +min = ord2 P.min + +max :: ScalarType a -> ((a, a) -> a) +max = ord2 P.max + +vmin :: ScalarType (Vec n a) -> (Vec n a -> a) +vmin = vord2 P.min + +vmax :: ScalarType (Vec n a) -> (Vec n a -> a) +vmax = vord2 P.max + +ord2 :: (forall t. P.Ord t => t -> t -> t) -> ScalarType a -> ((a, a) -> a) +ord2 f = \case NumScalarType t -> num t BitScalarType t -> bit t where bit :: BitType t -> ((t, t) -> t) - bit TypeBit = P.uncurry P.min - bit TypeMask{} = \(x,y) -> unMask (zipWith P.min (BitMask x) (BitMask y)) + bit TypeBit = P.uncurry f + bit TypeMask{} = \(x,y) -> unMask (zipWith f (BitMask x) (BitMask y)) num :: NumType t -> ((t, t) -> t) num (IntegralNumType t) = integral t num (FloatingNumType t) = floating t integral :: IntegralType t -> ((t, t) -> t) - integral (SingleIntegralType t) | IntegralDict <- integralDict t = P.uncurry P.min - integral (VectorIntegralType _ t) | IntegralDict <- integralDict t = P.uncurry (zipWith P.min) + integral (SingleIntegralType t) | IntegralDict <- integralDict t = P.uncurry f + integral (VectorIntegralType _ t) | IntegralDict <- integralDict t = P.uncurry (zipWith f) floating :: FloatingType t -> ((t, t) -> t) - floating (SingleFloatingType t) | FloatingDict <- floatingDict t = P.uncurry P.min - floating (VectorFloatingType _ t) | FloatingDict <- floatingDict t = P.uncurry (zipWith P.min) + floating (SingleFloatingType t) | FloatingDict <- floatingDict t = P.uncurry f + floating (VectorFloatingType _ t) | FloatingDict <- floatingDict t = P.uncurry (zipWith f) -max :: ScalarType a -> ((a, a) -> a) -max = \case + +vord2 :: (forall t. P.Ord t => t -> t -> t) -> ScalarType (Vec n a) -> (Vec n a -> a) +vord2 f = \case NumScalarType t -> num t BitScalarType t -> bit t where - bit :: BitType t -> ((t, t) -> t) - bit TypeBit = P.uncurry P.max - bit TypeMask{} = \(x,y) -> unMask (zipWith P.max (BitMask x) (BitMask y)) + bit :: BitType (Vec n t) -> (Vec n t -> t) + bit TypeMask{} = P.foldl1 f . Bit.toList . BitMask - num :: NumType t -> ((t, t) -> t) - num (IntegralNumType t) = integral t - num (FloatingNumType t) = floating t + num :: NumType (Vec n t) -> (Vec n t -> t) + num = \case + IntegralNumType t -> integral t + FloatingNumType t -> floating t - integral :: IntegralType t -> ((t, t) -> t) - integral (SingleIntegralType t) | IntegralDict <- integralDict t = P.uncurry P.max - integral (VectorIntegralType _ t) | IntegralDict <- integralDict t = P.uncurry (zipWith P.max) - - floating :: FloatingType t -> ((t, t) -> t) - floating (SingleFloatingType t) | FloatingDict <- floatingDict t = P.uncurry P.max - floating (VectorFloatingType _ t) | FloatingDict <- floatingDict t = P.uncurry (zipWith P.max) + integral :: IntegralType (Vec n t) -> (Vec n t -> t) + integral = \case + SingleIntegralType t -> case t of + VectorIntegralType _ t | IntegralDict <- integralDict t -> P.foldl1 f . Vec.toList + floating :: FloatingType (Vec n t) -> (Vec n t -> t) + floating = \case + SingleFloatingType t -> case t of + VectorFloatingType _ t | FloatingDict <- floatingDict t -> P.foldl1 f . Vec.toList -- Logical operators -- ----------------- @@ -433,13 +519,11 @@ lnot = \case where not' (Bit x) = Bit (not x) -lall :: BitType a -> a -> Bool -lall TypeBit x = unBit x -lall TypeMask{} x = P.all unBit (toList (BitMask x)) +vland :: BitType (Vec n a) -> (Vec n a -> a) +vland TypeMask{} x = Bit $ P.all unBit (toList (BitMask x)) -lany :: BitType a -> a -> Bool -lany TypeBit x = unBit x -lany TypeMask{} x = P.any unBit (toList (BitMask x)) +vlor :: BitType (Vec n a) -> (Vec n a -> a) +vlor TypeMask{} x = Bit $ P.any unBit (toList (BitMask x)) -- Conversion diff --git a/src/Data/Array/Accelerate/Pretty/Print.hs b/src/Data/Array/Accelerate/Pretty/Print.hs index 1bf370d91..05eee3dfe 100644 --- a/src/Data/Array/Accelerate/Pretty/Print.hs +++ b/src/Data/Array/Accelerate/Pretty/Print.hs @@ -704,70 +704,82 @@ isInfix :: Operator -> Bool isInfix Operator{..} = opFixity == Infix primOperator :: PrimFun a -> Operator -primOperator PrimAdd{} = Operator (pretty '+') Infix L 6 -primOperator PrimSub{} = Operator (pretty '-') Infix L 6 -primOperator PrimMul{} = Operator (pretty '*') Infix L 7 -primOperator PrimNeg{} = Operator (pretty '-') Prefix L 6 -- Haskell's only prefix operator -primOperator PrimAbs{} = Operator "abs" App L 10 -primOperator PrimSig{} = Operator "signum" App L 10 -primOperator PrimQuot{} = Operator "quot" App L 10 -primOperator PrimRem{} = Operator "rem" App L 10 -primOperator PrimQuotRem{} = Operator "quotRem" App L 10 -primOperator PrimIDiv{} = Operator "div" App L 10 -primOperator PrimMod{} = Operator "mod" App L 10 -primOperator PrimDivMod{} = Operator "divMod" App L 10 -primOperator PrimBAnd{} = Operator ".&." Infix L 7 -primOperator PrimBOr{} = Operator ".|." Infix L 5 -primOperator PrimBXor{} = Operator "xor" App L 10 -primOperator PrimBNot{} = Operator "complement" App L 10 -primOperator PrimBShiftL{} = Operator "shiftL" App L 10 -primOperator PrimBShiftR{} = Operator "shiftR" App L 10 -primOperator PrimBRotateL{} = Operator "rotateL" App L 10 -primOperator PrimBRotateR{} = Operator "rotateR" App L 10 -primOperator PrimPopCount{} = Operator "popCount" App L 10 -primOperator PrimCountLeadingZeros{} = Operator "countLeadingZeros" App L 10 -primOperator PrimCountTrailingZeros{} = Operator "countTrailingZeros" App L 10 -primOperator PrimFDiv{} = Operator (pretty '/') Infix L 7 -primOperator PrimRecip{} = Operator "recip" App L 10 -primOperator PrimSin{} = Operator "sin" App L 10 -primOperator PrimCos{} = Operator "cos" App L 10 -primOperator PrimTan{} = Operator "tan" App L 10 -primOperator PrimAsin{} = Operator "asin" App L 10 -primOperator PrimAcos{} = Operator "acos" App L 10 -primOperator PrimAtan{} = Operator "atan" App L 10 -primOperator PrimSinh{} = Operator "sinh" App L 10 -primOperator PrimCosh{} = Operator "cosh" App L 10 -primOperator PrimTanh{} = Operator "tanh" App L 10 -primOperator PrimAsinh{} = Operator "asinh" App L 10 -primOperator PrimAcosh{} = Operator "acosh" App L 10 -primOperator PrimAtanh{} = Operator "atanh" App L 10 -primOperator PrimExpFloating{} = Operator "exp" App L 10 -primOperator PrimSqrt{} = Operator "sqrt" App L 10 -primOperator PrimLog{} = Operator "log" App L 10 -primOperator PrimFPow{} = Operator "**" Infix R 8 -primOperator PrimLogBase{} = Operator "logBase" App L 10 -primOperator PrimTruncate{} = Operator "truncate" App L 10 -primOperator PrimRound{} = Operator "round" App L 10 -primOperator PrimFloor{} = Operator "floor" App L 10 -primOperator PrimCeiling{} = Operator "ceiling" App L 10 -primOperator PrimAtan2{} = Operator "atan2" App L 10 -primOperator PrimIsNaN{} = Operator "isNaN" App L 10 -primOperator PrimIsInfinite{} = Operator "isInfinite" App L 10 -primOperator PrimLt{} = Operator "<" Infix N 4 -primOperator PrimGt{} = Operator ">" Infix N 4 -primOperator PrimLtEq{} = Operator "<=" Infix N 4 -primOperator PrimGtEq{} = Operator ">=" Infix N 4 -primOperator PrimEq{} = Operator "==" Infix N 4 -primOperator PrimNEq{} = Operator "/=" Infix N 4 -primOperator PrimMax{} = Operator "max" App L 10 -primOperator PrimMin{} = Operator "min" App L 10 -primOperator PrimLAnd{} = Operator "&&" Infix R 3 -primOperator PrimLOr{} = Operator "||" Infix R 2 -primOperator PrimLNot{} = Operator "not" App L 10 -primOperator PrimFromIntegral{} = Operator "fromIntegral" App L 10 -primOperator PrimToFloating{} = Operator "toFloating" App L 10 -primOperator PrimToBool{} = Operator "toBool" App L 10 -primOperator PrimFromBool{} = Operator "fromBool" App L 10 +primOperator = \case + PrimAdd{} -> Operator (pretty '+') Infix L 6 + PrimSub{} -> Operator (pretty '-') Infix L 6 + PrimMul{} -> Operator (pretty '*') Infix L 7 + PrimNeg{} -> Operator (pretty '-') Prefix L 6 -- Haskell's only prefix operator + PrimAbs{} -> Operator "abs" App L 10 + PrimSig{} -> Operator "signum" App L 10 + PrimVAdd{} -> Operator "vadd" App L 10 + PrimVMul{} -> Operator "vmul" App L 10 + PrimQuot{} -> Operator "quot" App L 10 + PrimRem{} -> Operator "rem" App L 10 + PrimQuotRem{} -> Operator "quotRem" App L 10 + PrimIDiv{} -> Operator "div" App L 10 + PrimMod{} -> Operator "mod" App L 10 + PrimDivMod{} -> Operator "divMod" App L 10 + PrimBAnd{} -> Operator ".&." Infix L 7 + PrimBOr{} -> Operator ".|." Infix L 5 + PrimBXor{} -> Operator "xor" App L 10 + PrimBNot{} -> Operator "complement" App L 10 + PrimBShiftL{} -> Operator "shiftL" App L 10 + PrimBShiftR{} -> Operator "shiftR" App L 10 + PrimBRotateL{} -> Operator "rotateL" App L 10 + PrimBRotateR{} -> Operator "rotateR" App L 10 + PrimPopCount{} -> Operator "popCount" App L 10 + PrimCountLeadingZeros{} -> Operator "countLeadingZeros" App L 10 + PrimCountTrailingZeros{} -> Operator "countTrailingZeros" App L 10 + PrimBReverse{} -> Operator "bitreverse" App L 10 + PrimBSwap{} -> Operator "byteswap" App L 10 + PrimVBAnd{} -> Operator "vand" App L 10 + PrimVBOr{} -> Operator "vor" App L 10 + PrimVBXor{} -> Operator "vxor" App L 10 + PrimFDiv{} -> Operator (pretty '/') Infix L 7 + PrimRecip{} -> Operator "recip" App L 10 + PrimSin{} -> Operator "sin" App L 10 + PrimCos{} -> Operator "cos" App L 10 + PrimTan{} -> Operator "tan" App L 10 + PrimAsin{} -> Operator "asin" App L 10 + PrimAcos{} -> Operator "acos" App L 10 + PrimAtan{} -> Operator "atan" App L 10 + PrimSinh{} -> Operator "sinh" App L 10 + PrimCosh{} -> Operator "cosh" App L 10 + PrimTanh{} -> Operator "tanh" App L 10 + PrimAsinh{} -> Operator "asinh" App L 10 + PrimAcosh{} -> Operator "acosh" App L 10 + PrimAtanh{} -> Operator "atanh" App L 10 + PrimExpFloating{} -> Operator "exp" App L 10 + PrimSqrt{} -> Operator "sqrt" App L 10 + PrimLog{} -> Operator "log" App L 10 + PrimFPow{} -> Operator "**" Infix R 8 + PrimLogBase{} -> Operator "logBase" App L 10 + PrimTruncate{} -> Operator "truncate" App L 10 + PrimRound{} -> Operator "round" App L 10 + PrimFloor{} -> Operator "floor" App L 10 + PrimCeiling{} -> Operator "ceiling" App L 10 + PrimAtan2{} -> Operator "atan2" App L 10 + PrimIsNaN{} -> Operator "isNaN" App L 10 + PrimIsInfinite{} -> Operator "isInfinite" App L 10 + PrimLt{} -> Operator "<" Infix N 4 + PrimGt{} -> Operator ">" Infix N 4 + PrimLtEq{} -> Operator "<=" Infix N 4 + PrimGtEq{} -> Operator ">=" Infix N 4 + PrimEq{} -> Operator "==" Infix N 4 + PrimNEq{} -> Operator "/=" Infix N 4 + PrimMin{} -> Operator "min" App L 10 + PrimMax{} -> Operator "max" App L 10 + PrimVMin{} -> Operator "minimum" App L 10 + PrimVMax{} -> Operator "maximum" App L 10 + PrimLAnd{} -> Operator "&&" Infix R 3 + PrimLOr{} -> Operator "||" Infix R 2 + PrimLNot{} -> Operator "not" App L 10 + PrimVLAnd{} -> Operator "and" App L 10 + PrimVLOr{} -> Operator "or" App L 10 + PrimFromIntegral{} -> Operator "fromIntegral" App L 10 + PrimToFloating{} -> Operator "toFloating" App L 10 + PrimToBool{} -> Operator "toBool" App L 10 + PrimFromBool{} -> Operator "fromBool" App L 10 -- Environments diff --git a/src/Data/Array/Accelerate/Sugar/Vec.hs b/src/Data/Array/Accelerate/Sugar/Vec.hs index 77f0232be..cee537e7c 100644 --- a/src/Data/Array/Accelerate/Sugar/Vec.hs +++ b/src/Data/Array/Accelerate/Sugar/Vec.hs @@ -63,7 +63,7 @@ type V4 = Vec 4 type V8 = Vec 8 type V16 = Vec 16 -instance (Show a, Elt a, SIMD n a) => Show (Vec n a) where +instance (Show a, SIMD n a) => Show (Vec n a) where show = vec . toList where vec :: [a] -> String @@ -128,15 +128,15 @@ instance (Eq a, SIMD n a) => Eq (Vec n a) where TypeFloat128 -> (==) -instance (Elt a, SIMD n a) => GHC.IsList (Vec n a) where +instance SIMD n a => GHC.IsList (Vec n a) where type Item (Vec n a) = a toList = toList fromList = fromList -toList :: forall n a. (Elt a, SIMD n a) => Vec n a -> [a] +toList :: forall n a. SIMD n a => Vec n a -> [a] toList (Vec vs) = map toElt $ R.toList (vecR @n @a) (eltR @a) vs -fromList :: forall n a. (Elt a, SIMD n a) => [a] -> Vec n a +fromList :: forall n a. SIMD n a => [a] -> Vec n a fromList = Vec . R.fromList (vecR @n @a) (eltR @a) . map fromElt instance SIMD n a => Elt (Vec n a) where @@ -151,7 +151,7 @@ instance SIMD n a => Elt (Vec n a) where -- -- @since 1.4.0.0 -- -class KnownNat n => SIMD n a where +class (KnownNat n, Elt a) => SIMD n a where type VecR n a :: Type type VecR n a = GVecR () n (Rep a) diff --git a/src/Data/Array/Accelerate/Trafo/Simplify.hs b/src/Data/Array/Accelerate/Trafo/Simplify.hs index c78da38ca..c7a31aae7 100644 --- a/src/Data/Array/Accelerate/Trafo/Simplify.hs +++ b/src/Data/Array/Accelerate/Trafo/Simplify.hs @@ -618,6 +618,8 @@ summariseOpenExp = (terms +~ 1) . goE PrimNeg t -> travNumType t PrimAbs t -> travNumType t PrimSig t -> travNumType t + PrimVAdd t -> travNumType t + PrimVMul t -> travNumType t PrimQuot t -> travIntegralType t PrimRem t -> travIntegralType t PrimQuotRem t -> travIntegralType t @@ -635,6 +637,11 @@ summariseOpenExp = (terms +~ 1) . goE PrimPopCount t -> travIntegralType t PrimCountLeadingZeros t -> travIntegralType t PrimCountTrailingZeros t -> travIntegralType t + PrimBReverse t -> travIntegralType t + PrimBSwap t -> travIntegralType t + PrimVBAnd t -> travIntegralType t + PrimVBOr t -> travIntegralType t + PrimVBXor t -> travIntegralType t PrimFDiv t -> travFloatingType t PrimRecip t -> travFloatingType t PrimSin t -> travFloatingType t @@ -667,11 +674,15 @@ summariseOpenExp = (terms +~ 1) . goE PrimGtEq t -> travScalarType t PrimEq t -> travScalarType t PrimNEq t -> travScalarType t - PrimMax t -> travScalarType t PrimMin t -> travScalarType t - PrimLAnd _ -> zero & types +~ 1 - PrimLOr _ -> zero & types +~ 1 - PrimLNot _ -> zero & types +~ 1 + PrimMax t -> travScalarType t + PrimVMin t -> travScalarType t + PrimVMax t -> travScalarType t + PrimLAnd t -> travBitType t + PrimLOr t -> travBitType t + PrimLNot t -> travBitType t + PrimVLAnd t -> travBitType t + PrimVLOr t -> travBitType t PrimFromIntegral i n -> travIntegralType i +++ travNumType n PrimToFloating n f -> travNumType n +++ travFloatingType f PrimToBool i b -> travIntegralType i +++ travBitType b From 8fbf6615b38835998d266f016080bc820e49bd4b Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 28 Aug 2023 20:14:42 +0200 Subject: [PATCH 75/86] copy pasta error --- src/Data/Array/Accelerate/Classes/Num.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Data/Array/Accelerate/Classes/Num.hs b/src/Data/Array/Accelerate/Classes/Num.hs index 551b55ad8..351f050a0 100644 --- a/src/Data/Array/Accelerate/Classes/Num.hs +++ b/src/Data/Array/Accelerate/Classes/Num.hs @@ -108,7 +108,7 @@ runQ $ [d| instance P.Num (Exp $(conT a)) where (+) = mkPrimBinary $ PrimAdd numType (-) = mkPrimBinary $ PrimSub numType - (*) = mkPrimBinary $ PrimSub numType + (*) = mkPrimBinary $ PrimMul numType negate = mkPrimUnary $ PrimNeg numType abs = mkPrimUnary $ PrimAbs numType signum = mkPrimUnary $ PrimSig numType @@ -117,7 +117,7 @@ runQ $ instance KnownNat n => P.Num (Exp (Vec n $(conT a))) where (+) = mkPrimBinary $ PrimAdd numType (-) = mkPrimBinary $ PrimSub numType - (*) = mkPrimBinary $ PrimSub numType + (*) = mkPrimBinary $ PrimMul numType negate = mkPrimUnary $ PrimNeg numType abs = mkPrimUnary $ PrimAbs numType signum = mkPrimUnary $ PrimSig numType From 673cc77fc4f5c57ac721dd5861327784ecc9ddca Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 28 Aug 2023 20:15:12 +0200 Subject: [PATCH 76/86] show instance for TupR --- src/Data/Array/Accelerate/Representation/Type.hs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/Data/Array/Accelerate/Representation/Type.hs b/src/Data/Array/Accelerate/Representation/Type.hs index 028a7a3d6..b275b58a0 100644 --- a/src/Data/Array/Accelerate/Representation/Type.hs +++ b/src/Data/Array/Accelerate/Representation/Type.hs @@ -47,7 +47,11 @@ data TupR s a where TupRsingle :: s a -> TupR s a TupRpair :: TupR s a -> TupR s b -> TupR s (a, b) -deriving instance (forall a. Show (s a)) => Show (TupR s t) +instance (forall a. Show (s a)) => Show (TupR s t) where + show = \case + TupRunit -> "()" + TupRsingle t -> show t + TupRpair a b -> "(" ++ show a ++ "," ++ show b ++ ")" formatTypeR :: Format r (TypeR a -> r) formatTypeR = later $ \case From a9f062ee80f16d02ce0c1e6d64faaeabbe113107 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Mon, 28 Aug 2023 20:22:15 +0200 Subject: [PATCH 77/86] coerce arrays between different types --- src/Data/Array/Accelerate/AST.hs | 29 +- src/Data/Array/Accelerate/Analysis/Hash.hs | 1 + src/Data/Array/Accelerate/Array/Unique.hs | 10 + src/Data/Array/Accelerate/Interpreter.hs | 265 ++++++++++++++++++ src/Data/Array/Accelerate/Pretty/Graphviz.hs | 7 +- src/Data/Array/Accelerate/Pretty/Print.hs | 1 + src/Data/Array/Accelerate/Smart.hs | 106 ++++++- src/Data/Array/Accelerate/Trafo/Fusion.hs | 5 +- src/Data/Array/Accelerate/Trafo/LetSplit.hs | 1 + src/Data/Array/Accelerate/Trafo/Sharing.hs | 10 +- src/Data/Array/Accelerate/Trafo/Shrink.hs | 1 + src/Data/Array/Accelerate/Trafo/Simplify.hs | 1 + .../Array/Accelerate/Trafo/Substitution.hs | 1 + src/Data/Array/Accelerate/Type.hs | 1 - src/Data/Array/Accelerate/Unsafe.hs | 53 +++- src/Data/Primitive/Bit.hs | 2 +- 16 files changed, 469 insertions(+), 25 deletions(-) diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 1e162f77a..40bc8fd6a 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -81,7 +81,7 @@ module Data.Array.Accelerate.AST ( -- * Internal AST -- ** Array computations Afun, PreAfun, OpenAfun, PreOpenAfun(..), - Acc, OpenAcc(..), PreOpenAcc(..), Direction(..), Message(..), + Acc, OpenAcc(..), PreOpenAcc(..), Direction(..), Message(..), RescaleFactor, ALeftHandSide, ArrayVar, ArrayVars, -- ** Scalar expressions @@ -205,6 +205,13 @@ data Message a where -> Text -> Message a +-- Coercing an array to a different type may involve scaling the size of the +-- array by the given factor. Positive values mean the size of the innermost +-- dimension is multiplied by that value (i.e. the number of elements in the +-- array grows by that factor), negative meaning it shrinks. +-- +type RescaleFactor = INT + -- | Collective array computations parametrised over array variables -- represented with de Bruijn indices. -- @@ -287,6 +294,21 @@ data PreOpenAcc (acc :: Type -> Type -> Type) aenv a where -> acc aenv arrs2 -> PreOpenAcc acc aenv arrs2 + -- Reinterpret the bits of the array as a different type. The size of the + -- innermost dimension is adjusted as necessary. The old and new sizes must be + -- compatible, but this may not be checked; e.g. in the conversion + -- + -- > Array (Z :. n) Float <~~> Array (Z :. m) (Vec 3 Float) + -- + -- then we require + -- + -- > (m, r) = quotRem n 3 and r == 0 + -- + Acoerce :: RescaleFactor + -> TypeR b + -> acc aenv (Array (sh, INT) a) + -> PreOpenAcc acc aenv (Array (sh, INT) b) + -- Array inlet. Triggers (possibly) asynchronous host->device transfer if -- necessary. -- @@ -792,6 +814,8 @@ instance HasArraysR acc => HasArraysR (PreOpenAcc acc) where arraysR (Apair as bs) = TupRpair (arraysR as) (arraysR bs) arraysR Anil = TupRunit arraysR (Atrace _ _ bs) = arraysR bs + arraysR (Acoerce _ bR as) = let ArrayR shR _ = arrayR as + in arraysRarray shR bR arraysR (Apply aR _ _) = aR arraysR (Aforeign r _ _ _) = r arraysR (Acond _ a _) = arraysR a @@ -1092,6 +1116,7 @@ rnfPreOpenAcc rnfA pacc = Apair as bs -> rnfA as `seq` rnfA bs Anil -> () Atrace msg as bs -> rnfM msg `seq` rnfA as `seq` rnfA bs + Acoerce scale bR as -> scale `seq` rnfTypeR bR `seq` rnfA as Apply repr afun acc -> rnfTupR rnfArrayR repr `seq` rnfAF afun `seq` rnfA acc Aforeign repr asm afun a -> rnfTupR rnfArrayR repr `seq` rnf (strForeign asm) `seq` rnfAF afun `seq` rnfA a Acond p a1 a2 -> rnfE p `seq` rnfA a1 `seq` rnfA a2 @@ -1310,6 +1335,7 @@ liftPreOpenAcc liftA pacc = Apair as bs -> [|| Apair $$(liftA as) $$(liftA bs) ||] Anil -> [|| Anil ||] Atrace msg as bs -> [|| Atrace $$(liftMessage (arraysR as) msg) $$(liftA as) $$(liftA bs) ||] + Acoerce scale bR a -> [|| Acoerce scale $$(liftTypeR bR) $$(liftA a) ||] Apply repr f a -> [|| Apply $$(liftArraysR repr) $$(liftAF f) $$(liftA a) ||] Aforeign repr asm f a -> [|| Aforeign $$(liftArraysR repr) $$(liftForeign asm) $$(liftPreOpenAfun liftA f) $$(liftA a) ||] Acond p t e -> [|| Acond $$(liftE p) $$(liftA t) $$(liftA e) ||] @@ -1523,6 +1549,7 @@ formatPreAccOp = later $ \case Avar (Var _ ix) -> bformat ("Avar a" % int) (idxToInt ix) Use aR a -> bformat ("Use " % string) (showArrayShort 5 (showsElt (arrayRtype aR)) aR a) Atrace{} -> "Atrace" + Acoerce{} -> "Acoerce" Apply{} -> "Apply" Aforeign{} -> "Aforeign" Acond{} -> "Acond" diff --git a/src/Data/Array/Accelerate/Analysis/Hash.hs b/src/Data/Array/Accelerate/Analysis/Hash.hs index 7cbc17956..51e959da3 100644 --- a/src/Data/Array/Accelerate/Analysis/Hash.hs +++ b/src/Data/Array/Accelerate/Analysis/Hash.hs @@ -174,6 +174,7 @@ encodePreOpenAcc options encodeAcc pacc = Apair a1 a2 -> intHost $(hashQ "Apair") <> travA a1 <> travA a2 Anil -> intHost $(hashQ "Anil") Atrace (Message _ _ msg) as bs -> intHost $(hashQ "Atrace") <> intHost (Hashable.hash msg) <> travA as <> travA bs + Acoerce _ bR a -> intHost $(hashQ "Acoerce") <> encodeTypeR bR <> travA a Apply _ f a -> intHost $(hashQ "Apply") <> travAF f <> travA a Aforeign _ _ f a -> intHost $(hashQ "Aforeign") <> travAF f <> travA a Use repr a -> intHost $(hashQ "Use") <> encodeArrayType repr <> deep (encodeArray a) diff --git a/src/Data/Array/Accelerate/Array/Unique.hs b/src/Data/Array/Accelerate/Array/Unique.hs index 10aa97bd2..875203346 100644 --- a/src/Data/Array/Accelerate/Array/Unique.hs +++ b/src/Data/Array/Accelerate/Array/Unique.hs @@ -20,6 +20,7 @@ import Data.Array.Accelerate.Lifetime import Control.Applicative import Control.Concurrent.Unique import Control.DeepSeq +import Data.Coerce import Data.Word import Foreign.ForeignPtr import Foreign.ForeignPtr.Unsafe @@ -109,6 +110,15 @@ unsafeUniqueArrayPtr :: UniqueArray a -> Ptr a unsafeUniqueArrayPtr = unsafeForeignPtrToPtr . unsafeGetValue . uniqueArrayData +-- | Cast a unique array parameterised by one type into another type +-- +-- @since 1.4.0.0 +-- +{-# INLINE castUniqueArray #-} +castUniqueArray :: UniqueArray a -> UniqueArray b +castUniqueArray = coerce + + -- | Ensure that the unique array is alive at the given place in a sequence of -- IO actions. Note that this does not force the actual array payload. -- diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index f49e06526..f302dd4b8 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -47,7 +47,9 @@ module Data.Array.Accelerate.Interpreter ( import Data.Array.Accelerate.AST hiding ( Boundary(..) ) import Data.Array.Accelerate.AST.Environment import Data.Array.Accelerate.AST.Var +import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Array.Data +import Data.Array.Accelerate.Array.Unique import Data.Array.Accelerate.Error import Data.Array.Accelerate.Representation.Array import Data.Array.Accelerate.Representation.Elt @@ -87,6 +89,7 @@ import Unsafe.Coerce import Prelude hiding ( (!!), sum ) import qualified Data.Text.IO as T +import GHC.Prim import GHC.TypeLits @@ -211,6 +214,7 @@ evalOpenAcc (AST.Manifest pacc) aenv = (TupRpair r1 r2, (a1, a2)) Anil -> (TupRunit, ()) Atrace msg as bs -> unsafePerformIO $ manifest bs <$ atraceOp msg (snd $ manifest as) + Acoerce scale bR acc -> acoerceOp scale bR (manifest acc) Apply repr afun acc -> (repr, evalOpenAfun afun aenv $ snd $ manifest acc) Aforeign repr _ afun acc -> (repr, evalOpenAfun afun Empty $ snd $ manifest acc) Acond p acc1 acc2 @@ -992,6 +996,267 @@ evalOpenExp pexp env aenv = -- Coercions -- --------- +acoerceOp + :: HasCallStack + => RescaleFactor + -> TypeR b + -> WithReprs (Array (sh, INT) a) + -> WithReprs (Array (sh, INT) b) +acoerceOp scale bR (TupRsingle (ArrayR shR aR), Array (sz,sh) adata) = (repr', arr') + where + repr' = TupRsingle (ArrayR shR bR) + arr' = Array (sz,sh') adata' + sh' = case compare scale 0 of + EQ -> sh + GT -> sh * scale + LT -> let (q,r) = quotRem sh (negate scale) + in boundsCheck "shape mismatch" (r == 0) q + adata' = acoerce aR bR adata + + acoerce :: TypeR a -> TypeR b -> ArrayData a -> ArrayData b + acoerce TupRunit TupRunit () = () + acoerce (TupRpair TupRunit aR) bR ((), ad) | Just Refl <- matchTypeR aR bR = ad + acoerce (TupRpair aR TupRunit) bR (ad, ()) | Just Refl <- matchTypeR aR bR = ad + acoerce aR (TupRpair TupRunit bR) ad | Just Refl <- matchTypeR aR bR = ((), ad) + acoerce aR (TupRpair bR TupRunit) ad | Just Refl <- matchTypeR aR bR = (ad, ()) + acoerce (TupRpair aR1 aR2) (TupRpair bR1 bR2) (a1, a2) = (acoerce aR1 bR1 a1, acoerce aR2 bR2 a2) + acoerce (TupRsingle aR) (TupRsingle bR) ad = scalar aR bR ad + acoerce _ _ _ = internalError "missing cases for class Acoerce" + + scalar :: ScalarType a -> ScalarType b -> ArrayData a -> ArrayData b + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t + + num :: NumType a -> ScalarType b -> ArrayData a -> ArrayData b + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + bit :: forall a b. BitType a -> ScalarType b -> ArrayData a -> ArrayData b + bit = \case + TypeBit -> scalar' + TypeMask{} -> scalar' + where + scalar' :: ScalarType b -> ScalarArrayData a -> ArrayData b + scalar' (NumScalarType t) = num' t + scalar' (BitScalarType t) = bit' t + + bit' :: BitType b -> ScalarArrayData a -> ArrayData b + bit' TypeBit = castUniqueArray + bit' TypeMask{} = castUniqueArray + + num' :: NumType b -> ScalarArrayData a -> ArrayData b + num' (IntegralNumType t) = integral' t + num' (FloatingNumType t) = floating' t + + integral' :: IntegralType b -> ScalarArrayData a -> ArrayData b + integral' = \case + SingleIntegralType t -> single' t + VectorIntegralType n t -> vector' n t + where + single' :: SingleIntegralType b -> ScalarArrayData a -> ArrayData b + single' TypeInt8 = castUniqueArray + single' TypeInt16 = castUniqueArray + single' TypeInt32 = castUniqueArray + single' TypeInt64 = castUniqueArray + single' TypeInt128 = castUniqueArray + single' TypeWord8 = castUniqueArray + single' TypeWord16 = castUniqueArray + single' TypeWord32 = castUniqueArray + single' TypeWord64 = castUniqueArray + single' TypeWord128 = castUniqueArray + + vector' :: KnownNat n => Proxy# n -> SingleIntegralType c -> ScalarArrayData a -> ArrayData (Vec n c) + vector' _ TypeInt8 = castUniqueArray + vector' _ TypeInt16 = castUniqueArray + vector' _ TypeInt32 = castUniqueArray + vector' _ TypeInt64 = castUniqueArray + vector' _ TypeInt128 = castUniqueArray + vector' _ TypeWord8 = castUniqueArray + vector' _ TypeWord16 = castUniqueArray + vector' _ TypeWord32 = castUniqueArray + vector' _ TypeWord64 = castUniqueArray + vector' _ TypeWord128 = castUniqueArray + + floating' :: FloatingType b -> ScalarArrayData a -> ArrayData b + floating' = \case + SingleFloatingType t -> single' t + VectorFloatingType n t -> vector' n t + where + single' :: SingleFloatingType b -> ScalarArrayData a -> ArrayData b + single' TypeFloat16 = castUniqueArray + single' TypeFloat32 = castUniqueArray + single' TypeFloat64 = castUniqueArray + single' TypeFloat128 = castUniqueArray + + vector' :: KnownNat n => Proxy# n -> SingleFloatingType c -> ScalarArrayData a -> ArrayData (Vec n c) + vector' _ TypeFloat16 = castUniqueArray + vector' _ TypeFloat32 = castUniqueArray + vector' _ TypeFloat64 = castUniqueArray + vector' _ TypeFloat128 = castUniqueArray + + integral :: forall a b. IntegralType a -> ScalarType b -> ArrayData a -> ArrayData b + integral = \case + SingleIntegralType t -> single t + VectorIntegralType n t -> vector n t + where + single :: SingleIntegralType a -> ScalarType b -> ArrayData a -> ArrayData b + single TypeInt8 = scalar' + single TypeInt16 = scalar' + single TypeInt32 = scalar' + single TypeInt64 = scalar' + single TypeInt128 = scalar' + single TypeWord8 = scalar' + single TypeWord16 = scalar' + single TypeWord32 = scalar' + single TypeWord64 = scalar' + single TypeWord128 = scalar' + + vector :: (KnownNat n, a ~ Vec n c) => Proxy# n -> SingleIntegralType c -> ScalarType b -> ArrayData (Vec n c) -> ArrayData b + vector _ TypeInt8 = scalar' + vector _ TypeInt16 = scalar' + vector _ TypeInt32 = scalar' + vector _ TypeInt64 = scalar' + vector _ TypeInt128 = scalar' + vector _ TypeWord8 = scalar' + vector _ TypeWord16 = scalar' + vector _ TypeWord32 = scalar' + vector _ TypeWord64 = scalar' + vector _ TypeWord128 = scalar' + + scalar' :: ScalarType b -> ScalarArrayData a -> ArrayData b + scalar' (NumScalarType t) = num' t + scalar' (BitScalarType t) = bit' t + + bit' :: BitType b -> ScalarArrayData a -> ArrayData b + bit' TypeBit = castUniqueArray + bit' TypeMask{} = castUniqueArray + + num' :: NumType b -> ScalarArrayData a -> ArrayData b + num' (IntegralNumType t) = integral' t + num' (FloatingNumType t) = floating' t + + integral' :: IntegralType b -> ScalarArrayData a -> ArrayData b + integral' = \case + SingleIntegralType t -> single' t + VectorIntegralType n t -> vector' n t + where + single' :: SingleIntegralType b -> ScalarArrayData a -> ArrayData b + single' TypeInt8 = castUniqueArray + single' TypeInt16 = castUniqueArray + single' TypeInt32 = castUniqueArray + single' TypeInt64 = castUniqueArray + single' TypeInt128 = castUniqueArray + single' TypeWord8 = castUniqueArray + single' TypeWord16 = castUniqueArray + single' TypeWord32 = castUniqueArray + single' TypeWord64 = castUniqueArray + single' TypeWord128 = castUniqueArray + + vector' :: KnownNat n => Proxy# n -> SingleIntegralType c -> ScalarArrayData a -> ArrayData (Vec n c) + vector' _ TypeInt8 = castUniqueArray + vector' _ TypeInt16 = castUniqueArray + vector' _ TypeInt32 = castUniqueArray + vector' _ TypeInt64 = castUniqueArray + vector' _ TypeInt128 = castUniqueArray + vector' _ TypeWord8 = castUniqueArray + vector' _ TypeWord16 = castUniqueArray + vector' _ TypeWord32 = castUniqueArray + vector' _ TypeWord64 = castUniqueArray + vector' _ TypeWord128 = castUniqueArray + + floating' :: FloatingType b -> ScalarArrayData a -> ArrayData b + floating' = \case + SingleFloatingType t -> single' t + VectorFloatingType n t -> vector' n t + where + single' :: SingleFloatingType b -> ScalarArrayData a -> ArrayData b + single' TypeFloat16 = castUniqueArray + single' TypeFloat32 = castUniqueArray + single' TypeFloat64 = castUniqueArray + single' TypeFloat128 = castUniqueArray + + vector' :: KnownNat n => Proxy# n -> SingleFloatingType c -> ScalarArrayData a -> ArrayData (Vec n c) + vector' _ TypeFloat16 = castUniqueArray + vector' _ TypeFloat32 = castUniqueArray + vector' _ TypeFloat64 = castUniqueArray + vector' _ TypeFloat128 = castUniqueArray + + floating :: forall a b. FloatingType a -> ScalarType b -> ArrayData a -> ArrayData b + floating = \case + SingleFloatingType t -> single t + VectorFloatingType n t -> vector n t + where + single :: SingleFloatingType a -> ScalarType b -> ArrayData a -> ArrayData b + single TypeFloat16 = scalar' + single TypeFloat32 = scalar' + single TypeFloat64 = scalar' + single TypeFloat128 = scalar' + + vector :: (KnownNat n, a ~ Vec n c) => Proxy# n -> SingleFloatingType c -> ScalarType b -> ArrayData (Vec n c) -> ArrayData b + vector _ TypeFloat16 = scalar' + vector _ TypeFloat32 = scalar' + vector _ TypeFloat64 = scalar' + vector _ TypeFloat128 = scalar' + + scalar' :: ScalarType b -> ScalarArrayData a -> ArrayData b + scalar' (NumScalarType t) = num' t + scalar' (BitScalarType t) = bit' t + + bit' :: BitType b -> ScalarArrayData a -> ArrayData b + bit' TypeBit = castUniqueArray + bit' TypeMask{} = castUniqueArray + + num' :: NumType b -> ScalarArrayData a -> ArrayData b + num' (IntegralNumType t) = integral' t + num' (FloatingNumType t) = floating' t + + integral' :: IntegralType b -> ScalarArrayData a -> ArrayData b + integral' = \case + SingleIntegralType t -> single' t + VectorIntegralType n t -> vector' n t + where + single' :: SingleIntegralType b -> ScalarArrayData a -> ArrayData b + single' TypeInt8 = castUniqueArray + single' TypeInt16 = castUniqueArray + single' TypeInt32 = castUniqueArray + single' TypeInt64 = castUniqueArray + single' TypeInt128 = castUniqueArray + single' TypeWord8 = castUniqueArray + single' TypeWord16 = castUniqueArray + single' TypeWord32 = castUniqueArray + single' TypeWord64 = castUniqueArray + single' TypeWord128 = castUniqueArray + + vector' :: KnownNat n => Proxy# n -> SingleIntegralType c -> ScalarArrayData a -> ArrayData (Vec n c) + vector' _ TypeInt8 = castUniqueArray + vector' _ TypeInt16 = castUniqueArray + vector' _ TypeInt32 = castUniqueArray + vector' _ TypeInt64 = castUniqueArray + vector' _ TypeInt128 = castUniqueArray + vector' _ TypeWord8 = castUniqueArray + vector' _ TypeWord16 = castUniqueArray + vector' _ TypeWord32 = castUniqueArray + vector' _ TypeWord64 = castUniqueArray + vector' _ TypeWord128 = castUniqueArray + + floating' :: FloatingType b -> ScalarArrayData a -> ArrayData b + floating' = \case + SingleFloatingType t -> single' t + VectorFloatingType n t -> vector' n t + where + single' :: SingleFloatingType b -> ScalarArrayData a -> ArrayData b + single' TypeFloat16 = castUniqueArray + single' TypeFloat32 = castUniqueArray + single' TypeFloat64 = castUniqueArray + single' TypeFloat128 = castUniqueArray + + vector' :: KnownNat n => Proxy# n -> SingleFloatingType c -> ScalarArrayData a -> ArrayData (Vec n c) + vector' _ TypeFloat16 = castUniqueArray + vector' _ TypeFloat32 = castUniqueArray + vector' _ TypeFloat64 = castUniqueArray + vector' _ TypeFloat128 = castUniqueArray + + -- Coercion between two scalar types. We require that the size of the source and -- destination values are equal (this is not checked at this point). -- diff --git a/src/Data/Array/Accelerate/Pretty/Graphviz.hs b/src/Data/Array/Accelerate/Pretty/Graphviz.hs index 8d22e958b..04547e616 100644 --- a/src/Data/Array/Accelerate/Pretty/Graphviz.hs +++ b/src/Data/Array/Accelerate/Pretty/Graphviz.hs @@ -192,9 +192,9 @@ prettyDelayedOpenAcc detail ctx aenv (Manifest pacc) = Avar ix -> pnode (avar ix) Alet lhs bnd body -> do bnd'@(PNode ident _ _) <- prettyDelayedOpenAcc detail context0 aenv bnd - (aenv1, a) <- prettyLetALeftHandSide ident aenv lhs - _ <- mkNode bnd' (Just a) - body' <- prettyDelayedOpenAcc detail context0 aenv1 body + (aenv1, a) <- prettyLetALeftHandSide ident aenv lhs + _ <- mkNode bnd' (Just a) + body' <- prettyDelayedOpenAcc detail context0 aenv1 body return body' Acond p t e -> do @@ -225,6 +225,7 @@ prettyDelayedOpenAcc detail ctx aenv (Manifest pacc) = Anil -> "()" .$ [] Atrace (Message _ _ msg) as bs -> "atrace" .$ [ return $ PDoc (pretty msg) [], ppA as, ppA bs ] + Acoerce _ bR a -> "coerce" .$ [ return $ PDoc ("@" <> pretty (show bR)) [], ppA a ] Use repr arr -> "use" .$ [ return $ PDoc (prettyArray repr arr) [] ] Unit _ e -> "unit" .$ [ ppE e ] Generate _ sh f -> "generate" .$ [ ppE sh, ppF f ] diff --git a/src/Data/Array/Accelerate/Pretty/Print.hs b/src/Data/Array/Accelerate/Pretty/Print.hs index 05eee3dfe..34b7dd93a 100644 --- a/src/Data/Array/Accelerate/Pretty/Print.hs +++ b/src/Data/Array/Accelerate/Pretty/Print.hs @@ -195,6 +195,7 @@ prettyPreOpenAcc config ctx prettyAcc extractAcc aenv pacc = Atrace (Message _ _ msg) as bs -> ppN "atrace" .$ [ fromString (show msg), ppA as, ppA bs ] + Acoerce _ bR a -> ppN "coerce" .$ [ "@" <> pretty (show bR), ppA a ] Aforeign _ ff _ a -> ppN "aforeign" .$ [ pretty (strForeign ff), ppA a ] Awhile p f a -> ppN "awhile" .$ [ ppAF p, ppAF f, ppA a ] Use repr arr -> ppN "use" .$ [ prettyArray repr arr ] diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index f77f9bbf5..4940678da 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -60,7 +60,9 @@ module Data.Array.Accelerate.Smart ( select, -- ** Type coercions - mkBitcast, mkCoerce, Coerce(..), + mkBitcast, + mkCoerce, Coerce(..), + mkAcoerce, Acoerce(..), -- ** Miscellaneous ($$), ($$$), ($$$$), ($$$$$), @@ -73,7 +75,7 @@ module Data.Array.Accelerate.Smart ( ) where -import Data.Array.Accelerate.AST ( Direction(..), Message(..), PrimBool, PrimMaybe, PrimFun(..), primFunType ) +import Data.Array.Accelerate.AST ( Direction(..), Message(..), RescaleFactor, PrimBool, PrimMaybe, PrimFun(..), primFunType ) import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Error @@ -351,6 +353,11 @@ data PreSmartAcc acc exp as where -> acc arrs2 -> PreSmartAcc acc exp arrs2 + Acoerce :: RescaleFactor + -> TypeR b + -> acc (Array (sh, INT) a) + -> PreSmartAcc acc exp (Array (sh, INT) b) + Use :: ArrayR (Array sh e) -> Array sh e -> PreSmartAcc acc exp (Array sh e) @@ -451,12 +458,6 @@ data PreSmartAcc acc exp as where -> acc (Array sh b) -> PreSmartAcc acc exp (Array sh c) - -- Coerce :: ShapeR sh - -- -> TypeR a - -- -> TypeR b - -- -> acc (Array sh a) - -- -> PreSmartAcc acc exp (Array sh b) - -- Embedded expressions of the surface language -- -------------------------------------------- @@ -815,8 +816,10 @@ instance HasArraysR acc => HasArraysR (PreSmartAcc acc exp) where PairIdxRight -> t2 Aprj _ _ -> error "Ejector seat? You're joking!" Atrace _ _ a -> arraysR a + Acoerce _ bR a -> let ArrayR shR _ = arrayR a + in TupRsingle $ ArrayR shR bR Use repr _ -> TupRsingle repr - Unit aR _ -> TupRsingle $ ArrayR ShapeRz $ aR + Unit aR _ -> TupRsingle $ ArrayR ShapeRz aR Generate repr _ _ -> TupRsingle repr Reshape shr _ a -> let ArrayR _ aR = arrayR a in TupRsingle $ ArrayR shr aR @@ -1244,6 +1247,90 @@ instance Coerce a (a, ()) where mkCoerce' a = SmartExp (Pair a (SmartExp Nil)) +mkAcoerce + :: forall b a sh. (Acoerce (EltR a) (EltR b)) + => Acc (Sugar.Array (sh :. Int) a) + -> Acc (Sugar.Array (sh :. Int) b) +mkAcoerce (Acc a) = + let ArrayR _ aR = arrayR a + (bR, sz) = mkAcoerce' @(EltR a) aR + in Acc $ SmartAcc (Acoerce sz bR a) + +class Acoerce a b where + mkAcoerce' :: TypeR a -> (TypeR b, RescaleFactor) + +instance {-# OVERLAPS #-} (IsScalar a, IsScalar b) => Acoerce a b where + mkAcoerce' _ = + let ta = scalarType @a + tb = scalarType @b + sa = scalar ta + sb = scalar tb + sz = case compare sa sb of + EQ -> 0 + GT -> sa `quot` sb + LT -> negate $ sb `quot` sa + -- + scalar :: ScalarType t -> RescaleFactor + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t + + bit :: BitType t -> RescaleFactor + bit TypeBit = 1 + bit (TypeMask n) = fromInteger (natVal' n) + + num :: NumType t -> RescaleFactor + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType t -> RescaleFactor + integral (VectorIntegralType n t) = fromInteger (natVal' n) * integral (SingleIntegralType t) + integral (SingleIntegralType t) = case t of + TypeInt n -> fromIntegral n + TypeWord n -> fromIntegral n + + floating :: FloatingType t -> RescaleFactor + floating (VectorFloatingType n t) = fromInteger (natVal' n) * floating (SingleFloatingType t) + floating (SingleFloatingType t) = case t of + TypeFloat16 -> 16 + TypeFloat32 -> 32 + TypeFloat64 -> 64 + TypeFloat128 -> 128 + in + (TupRsingle tb, sz) + +-- TODO: Make this a compile time error. This should be possible with type +-- families, but GHC seems to have problems normalising the BitSize of 'Vec's +-- (and the GHC.TypeLits.Normalise package unfortunately does not seem to help). +-- +instance (Acoerce a1 b1, Acoerce a2 b2) => Acoerce (a1, a2) (b1, b2) where + mkAcoerce' (TupRpair a1 a2) = + let (b1, s1) = mkAcoerce' a1 + (b2, s2) = mkAcoerce' a2 + in + if s1 == s2 + then (TupRpair b1 b2, s1) + else error $ formatToString ("Could not coerce type `" % formatTypeR % "' to `" % formatTypeR % "'") + (TupRpair a1 a2) (TupRpair b1 b2) + mkAcoerce' _ = error "impossible" + +instance Acoerce ((), a) a where + mkAcoerce' (TupRpair TupRunit a) = (a, 0) + mkAcoerce' _ = error "impossible" + +instance Acoerce (a, ()) a where + mkAcoerce' (TupRpair a TupRunit) = (a, 0) + mkAcoerce' _ = error "impossible" + +instance Acoerce a ((), a) where + mkAcoerce' a = (TupRpair TupRunit a, 0) + +instance Acoerce a (a, ()) where + mkAcoerce' a = (TupRpair a TupRunit, 0) + +instance Acoerce a a where + mkAcoerce' a = (a, 0) + + -- Auxiliary functions -- -------------------- @@ -1361,6 +1448,7 @@ formatPreAccOp = later $ \case Stencil{} -> "Stencil" Stencil2{} -> "Stencil2" Aforeign{} -> "Aforeign" + Acoerce{} -> "Acoerce" formatPreExpOp :: Format r (PreSmartExp acc exp t -> r) formatPreExpOp = later $ \case diff --git a/src/Data/Array/Accelerate/Trafo/Fusion.hs b/src/Data/Array/Accelerate/Trafo/Fusion.hs index 7b0f4fd3e..d4bcf3487 100644 --- a/src/Data/Array/Accelerate/Trafo/Fusion.hs +++ b/src/Data/Array/Accelerate/Trafo/Fusion.hs @@ -178,6 +178,7 @@ manifest config (OpenAcc pacc) = Apair a1 a2 -> Apair (manifest config a1) (manifest config a2) Anil -> Anil Atrace msg a1 a2 -> Atrace msg (manifest config a1) (manifest config a2) + Acoerce scale bR a -> Acoerce scale bR (manifest config a) Apply repr f a -> apply repr (cvtAF f) (manifest config a) Aforeign repr ff f a -> Aforeign repr ff (cvtAF f) (manifest config a) @@ -364,12 +365,13 @@ embedPreOpenAcc config matchAcc embedAcc elimAcc pacc -- duplication. SEE: [Sharing vs. Fusion] -- Alet lhs bnd body -> aletD embedAcc elimAcc lhs bnd body - Anil -> done $ Anil Acond p at ae -> acondD matchAcc embedAcc (cvtE p) at ae + Anil -> done $ Anil Apply aR f a -> done $ Apply aR (cvtAF f) (cvtA a) Awhile p f a -> done $ Awhile (cvtAF p) (cvtAF f) (cvtA a) Apair a1 a2 -> done $ Apair (cvtA a1) (cvtA a2) Atrace msg a1 a2 -> done $ Atrace msg (cvtA a1) (cvtA a2) + Acoerce scale bR a -> done $ Acoerce scale bR (cvtA a) Aforeign aR ff f a -> done $ Aforeign aR ff (cvtAF f) (cvtA a) -- Collect s -> collectD s @@ -1549,6 +1551,7 @@ aletD' embedAcc elimAcc (LeftHandSideSingle ArrayR{}) (Embed env1 cc1) (Embed en Acond p at ae -> Acond (cvtE p) (cvtA at) (cvtA ae) Anil -> Anil Atrace msg a b -> Atrace msg (cvtA a) (cvtA b) + Acoerce scale bR a -> Acoerce scale bR (cvtA a) Apair a1 a2 -> Apair (cvtA a1) (cvtA a2) Awhile p f a -> Awhile (cvtAF p) (cvtAF f) (cvtA a) Apply repr f a -> Apply repr (cvtAF f) (cvtA a) diff --git a/src/Data/Array/Accelerate/Trafo/LetSplit.hs b/src/Data/Array/Accelerate/Trafo/LetSplit.hs index 9d3c1250e..efb3b94f5 100644 --- a/src/Data/Array/Accelerate/Trafo/LetSplit.hs +++ b/src/Data/Array/Accelerate/Trafo/LetSplit.hs @@ -39,6 +39,7 @@ convertPreOpenAcc = \case Apair a1 a2 -> Apair (convertAcc a1) (convertAcc a2) Anil -> Anil Atrace msg as bs -> Atrace msg (convertAcc as) (convertAcc bs) + Acoerce scale bR a -> Acoerce scale bR (convertAcc a) Apply repr f a -> Apply repr (convertAfun f) (convertAcc a) Aforeign repr asm f a -> Aforeign repr asm (convertAfun f) (convertAcc a) Acond e a1 a2 -> Acond e (convertAcc a1) (convertAcc a2) diff --git a/src/Data/Array/Accelerate/Trafo/Sharing.hs b/src/Data/Array/Accelerate/Trafo/Sharing.hs index bb591366b..e8bca9635 100644 --- a/src/Data/Array/Accelerate/Trafo/Sharing.hs +++ b/src/Data/Array/Accelerate/Trafo/Sharing.hs @@ -337,6 +337,7 @@ convertSharingAcc config alyt aenv (ScopedAcc lams (AccSharing _ preAcc)) Aprj ix a -> let AST.OpenAcc a' = cvtAprj ix a in a' Atrace msg acc1 acc2 -> AST.Atrace msg (cvtA acc1) (cvtA acc2) + Acoerce scale bR acc -> AST.Acoerce scale bR (cvtA acc) Use repr array -> AST.Use repr array Unit tp e -> AST.Unit tp (cvtE e) Generate repr@(ArrayR shr _) sh f @@ -1557,6 +1558,9 @@ makeOccMapSharingAcc config accOccMap = traverseAcc (a', h1) <- traverseAcc lvl acc1 (b', h2) <- traverseAcc lvl acc2 return (Atrace msg a' b', h1 `max` h2 + 1) + Acoerce scale bR acc -> do + (a', h) <- traverseAcc lvl acc + return (Acoerce scale bR a', h + 1) Use repr arr -> return (Use repr arr, 1) Unit tp e -> do (e', h) <- traverseExp lvl e @@ -2421,7 +2425,11 @@ determineScopesSharingAcc config accOccMap = scopesAcc (a1', accCount1) = scopesAcc a1 (a2', accCount2) = scopesAcc a2 in - reconstruct (Atrace msg a1' a2') (accCount1 +++ accCount2) + reconstruct (Atrace msg a1' a2') (accCount1 +++ accCount2) + Acoerce scale bR acc -> let + (acc', accCount) = scopesAcc acc + in + reconstruct (Acoerce scale bR acc') accCount Use repr arr -> reconstruct (Use repr arr) noNodeCounts Unit tp e -> let (e', accCount) = scopesExp e diff --git a/src/Data/Array/Accelerate/Trafo/Shrink.hs b/src/Data/Array/Accelerate/Trafo/Shrink.hs index 89548f345..f76013a86 100644 --- a/src/Data/Array/Accelerate/Trafo/Shrink.hs +++ b/src/Data/Array/Accelerate/Trafo/Shrink.hs @@ -546,6 +546,7 @@ usesOfPreAcc withShape countAcc idx = count Apair a1 a2 -> countA a1 + countA a2 Anil -> 0 Atrace _ a1 a2 -> countA a1 + countA a2 + Acoerce _ _ a -> countA a Apply _ f a -> countAF f idx + countA a Aforeign _ _ _ a -> countA a Acond p t e -> countE p + countA t + countA e diff --git a/src/Data/Array/Accelerate/Trafo/Simplify.hs b/src/Data/Array/Accelerate/Trafo/Simplify.hs index c7a31aae7..7bf19b44a 100644 --- a/src/Data/Array/Accelerate/Trafo/Simplify.hs +++ b/src/Data/Array/Accelerate/Trafo/Simplify.hs @@ -91,6 +91,7 @@ simplifyPreOpenAcc = \case Apair a1 a2 -> Apair (simplifyOpenAcc a1) (simplifyOpenAcc a2) Anil -> Anil Atrace msg as bs -> Atrace msg (simplifyOpenAcc as) (simplifyOpenAcc bs) + Acoerce scale bR a -> Acoerce scale bR (simplifyOpenAcc a) Apply repr f a -> Apply repr (simplifyOpenAfun f) (simplifyOpenAcc a) Aforeign repr asm f a -> Aforeign repr asm (simplifyOpenAfun f) (simplifyOpenAcc a) Acond e a1 a2 -> Acond (simplifyExp e) (simplifyOpenAcc a1) (simplifyOpenAcc a2) diff --git a/src/Data/Array/Accelerate/Trafo/Substitution.hs b/src/Data/Array/Accelerate/Trafo/Substitution.hs index d187da63e..8cb0a3512 100644 --- a/src/Data/Array/Accelerate/Trafo/Substitution.hs +++ b/src/Data/Array/Accelerate/Trafo/Substitution.hs @@ -683,6 +683,7 @@ rebuildPreOpenAcc k av acc = Apair as bs -> Apair <$> k av as <*> k av bs Anil -> pure Anil Atrace msg as bs -> Atrace msg <$> k av as <*> k av bs + Acoerce s bR a -> Acoerce s bR <$> k av a Apply repr f a -> Apply repr <$> rebuildAfun k av f <*> k av a Acond p t e -> Acond <$> rebuildOpenExp (pure . IE) av' p <*> k av t <*> k av e Awhile p f a -> Awhile <$> rebuildAfun k av p <*> rebuildAfun k av f <*> k av a diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index 4cc0d4dee..3330a4f0e 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -7,7 +7,6 @@ {-# LANGUAGE MagicHash #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RoleAnnotations #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} diff --git a/src/Data/Array/Accelerate/Unsafe.hs b/src/Data/Array/Accelerate/Unsafe.hs index 9b898f575..70b7fc677 100644 --- a/src/Data/Array/Accelerate/Unsafe.hs +++ b/src/Data/Array/Accelerate/Unsafe.hs @@ -1,4 +1,8 @@ -{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} -- | -- Module : Data.Array.Accelerate.Unsafe -- Copyright : [2009..2020] The Accelerate Team @@ -23,10 +27,16 @@ module Data.Array.Accelerate.Unsafe ( import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Array +import Data.Array.Accelerate.Sugar.Shape --- | The function 'coerce' allows you to convert a value between any two types --- whose underlying representations have the same bit size at each component. +-- | The class 'Coercible' reinterprets the bits of a value or array of values +-- as that of a different type. +-- +-- At the expression level, this allows you to convert a value between any two +-- types whose underlying representations have the same bit size at each +-- component. -- -- For example: -- @@ -41,11 +51,38 @@ import Data.Array.Accelerate.Sugar.Elt -- abstract type to the concrete type by dropping the extra @()@ from the -- representation, and vice-versa. -- --- The type class 'Coerce' assures that there is a coercion between the two --- types. +-- At the array level this may also entail changing the size of the innermost +-- dimension. -- --- @since 1.2.0.0 +-- For example: +-- +-- > coerce (x :: Acc (Vector Float)) :: Acc (Vector (Complex Float)) +-- +-- will result in an array with half as many elements, as each element now +-- consists of two values (the real and imaginary values laid out consecutively +-- in memory, and now interpreted as a single packed 'Vec 2 Float'). For this to +-- be safe, the size of 'x' must therefore be even. +-- +-- Note that when applied at the array level 'coerce' prevents array fusion. +-- Therefore if the bit size of the source and target value types is the same, +-- then: +-- +-- > map f . coerce . map g -- -coerce :: Coerce (EltR a) (EltR b) => Exp a -> Exp b -coerce = mkCoerce +-- will result in two kernels being executed, whereas: +-- +-- > map f . map coerce . map g +-- +-- will fuse into a single kernel. +-- +-- @since 1.4.0.0 +-- +class Coercible f a b where + coerce :: f a -> f b + +instance Acoerce (EltR a) (EltR b) => Coercible Acc (Array (sh :. Int) a) (Array (sh :. Int) b) where + coerce = mkAcoerce + +instance Coerce (EltR a) (EltR b) => Coercible Exp a b where + coerce = mkCoerce diff --git a/src/Data/Primitive/Bit.hs b/src/Data/Primitive/Bit.hs index 17e2f17e4..773f7feeb 100644 --- a/src/Data/Primitive/Bit.hs +++ b/src/Data/Primitive/Bit.hs @@ -58,7 +58,7 @@ import GHC.Prim -- | A newtype wrapper over 'Bool' whose instances pack bits as efficiently -- as possible (8 values per byte). Arrays of 'Bit' use 8x less memory than -- arrays of 'Bool' (which stores one value per byte). However, (parallel) --- random writes are slower. +-- random writes are (almost certainly) slower. -- newtype Bit = Bit { unBit :: Bool } deriving (Eq, Ord, Bounded, Enum, FiniteBits, Bits, Typeable, Generic) From ac5ff616556a7c3ec8c5fcb908b4db3c7191db85 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Tue, 29 Aug 2023 11:36:26 +0200 Subject: [PATCH 78/86] minor cleanup --- src/Data/Array/Accelerate/Interpreter.hs | 294 +++++++---------------- 1 file changed, 89 insertions(+), 205 deletions(-) diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index f302dd4b8..7c5a9583c 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE BangPatterns #-} {-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleContexts #-} @@ -1032,67 +1033,8 @@ acoerceOp scale bR (TupRsingle (ArrayR shR aR), Array (sz,sh) adata) = (repr', a num (FloatingNumType t) = floating t bit :: forall a b. BitType a -> ScalarType b -> ArrayData a -> ArrayData b - bit = \case - TypeBit -> scalar' - TypeMask{} -> scalar' - where - scalar' :: ScalarType b -> ScalarArrayData a -> ArrayData b - scalar' (NumScalarType t) = num' t - scalar' (BitScalarType t) = bit' t - - bit' :: BitType b -> ScalarArrayData a -> ArrayData b - bit' TypeBit = castUniqueArray - bit' TypeMask{} = castUniqueArray - - num' :: NumType b -> ScalarArrayData a -> ArrayData b - num' (IntegralNumType t) = integral' t - num' (FloatingNumType t) = floating' t - - integral' :: IntegralType b -> ScalarArrayData a -> ArrayData b - integral' = \case - SingleIntegralType t -> single' t - VectorIntegralType n t -> vector' n t - where - single' :: SingleIntegralType b -> ScalarArrayData a -> ArrayData b - single' TypeInt8 = castUniqueArray - single' TypeInt16 = castUniqueArray - single' TypeInt32 = castUniqueArray - single' TypeInt64 = castUniqueArray - single' TypeInt128 = castUniqueArray - single' TypeWord8 = castUniqueArray - single' TypeWord16 = castUniqueArray - single' TypeWord32 = castUniqueArray - single' TypeWord64 = castUniqueArray - single' TypeWord128 = castUniqueArray - - vector' :: KnownNat n => Proxy# n -> SingleIntegralType c -> ScalarArrayData a -> ArrayData (Vec n c) - vector' _ TypeInt8 = castUniqueArray - vector' _ TypeInt16 = castUniqueArray - vector' _ TypeInt32 = castUniqueArray - vector' _ TypeInt64 = castUniqueArray - vector' _ TypeInt128 = castUniqueArray - vector' _ TypeWord8 = castUniqueArray - vector' _ TypeWord16 = castUniqueArray - vector' _ TypeWord32 = castUniqueArray - vector' _ TypeWord64 = castUniqueArray - vector' _ TypeWord128 = castUniqueArray - - floating' :: FloatingType b -> ScalarArrayData a -> ArrayData b - floating' = \case - SingleFloatingType t -> single' t - VectorFloatingType n t -> vector' n t - where - single' :: SingleFloatingType b -> ScalarArrayData a -> ArrayData b - single' TypeFloat16 = castUniqueArray - single' TypeFloat32 = castUniqueArray - single' TypeFloat64 = castUniqueArray - single' TypeFloat128 = castUniqueArray - - vector' :: KnownNat n => Proxy# n -> SingleFloatingType c -> ScalarArrayData a -> ArrayData (Vec n c) - vector' _ TypeFloat16 = castUniqueArray - vector' _ TypeFloat32 = castUniqueArray - vector' _ TypeFloat64 = castUniqueArray - vector' _ TypeFloat128 = castUniqueArray + bit TypeBit = scalar' @a + bit TypeMask{} = scalar' @a integral :: forall a b. IntegralType a -> ScalarType b -> ArrayData a -> ArrayData b integral = \case @@ -1100,86 +1042,28 @@ acoerceOp scale bR (TupRsingle (ArrayR shR aR), Array (sz,sh) adata) = (repr', a VectorIntegralType n t -> vector n t where single :: SingleIntegralType a -> ScalarType b -> ArrayData a -> ArrayData b - single TypeInt8 = scalar' - single TypeInt16 = scalar' - single TypeInt32 = scalar' - single TypeInt64 = scalar' - single TypeInt128 = scalar' - single TypeWord8 = scalar' - single TypeWord16 = scalar' - single TypeWord32 = scalar' - single TypeWord64 = scalar' - single TypeWord128 = scalar' + single TypeInt8 = scalar' @a + single TypeInt16 = scalar' @a + single TypeInt32 = scalar' @a + single TypeInt64 = scalar' @a + single TypeInt128 = scalar' @a + single TypeWord8 = scalar' @a + single TypeWord16 = scalar' @a + single TypeWord32 = scalar' @a + single TypeWord64 = scalar' @a + single TypeWord128 = scalar' @a vector :: (KnownNat n, a ~ Vec n c) => Proxy# n -> SingleIntegralType c -> ScalarType b -> ArrayData (Vec n c) -> ArrayData b - vector _ TypeInt8 = scalar' - vector _ TypeInt16 = scalar' - vector _ TypeInt32 = scalar' - vector _ TypeInt64 = scalar' - vector _ TypeInt128 = scalar' - vector _ TypeWord8 = scalar' - vector _ TypeWord16 = scalar' - vector _ TypeWord32 = scalar' - vector _ TypeWord64 = scalar' - vector _ TypeWord128 = scalar' - - scalar' :: ScalarType b -> ScalarArrayData a -> ArrayData b - scalar' (NumScalarType t) = num' t - scalar' (BitScalarType t) = bit' t - - bit' :: BitType b -> ScalarArrayData a -> ArrayData b - bit' TypeBit = castUniqueArray - bit' TypeMask{} = castUniqueArray - - num' :: NumType b -> ScalarArrayData a -> ArrayData b - num' (IntegralNumType t) = integral' t - num' (FloatingNumType t) = floating' t - - integral' :: IntegralType b -> ScalarArrayData a -> ArrayData b - integral' = \case - SingleIntegralType t -> single' t - VectorIntegralType n t -> vector' n t - where - single' :: SingleIntegralType b -> ScalarArrayData a -> ArrayData b - single' TypeInt8 = castUniqueArray - single' TypeInt16 = castUniqueArray - single' TypeInt32 = castUniqueArray - single' TypeInt64 = castUniqueArray - single' TypeInt128 = castUniqueArray - single' TypeWord8 = castUniqueArray - single' TypeWord16 = castUniqueArray - single' TypeWord32 = castUniqueArray - single' TypeWord64 = castUniqueArray - single' TypeWord128 = castUniqueArray - - vector' :: KnownNat n => Proxy# n -> SingleIntegralType c -> ScalarArrayData a -> ArrayData (Vec n c) - vector' _ TypeInt8 = castUniqueArray - vector' _ TypeInt16 = castUniqueArray - vector' _ TypeInt32 = castUniqueArray - vector' _ TypeInt64 = castUniqueArray - vector' _ TypeInt128 = castUniqueArray - vector' _ TypeWord8 = castUniqueArray - vector' _ TypeWord16 = castUniqueArray - vector' _ TypeWord32 = castUniqueArray - vector' _ TypeWord64 = castUniqueArray - vector' _ TypeWord128 = castUniqueArray - - floating' :: FloatingType b -> ScalarArrayData a -> ArrayData b - floating' = \case - SingleFloatingType t -> single' t - VectorFloatingType n t -> vector' n t - where - single' :: SingleFloatingType b -> ScalarArrayData a -> ArrayData b - single' TypeFloat16 = castUniqueArray - single' TypeFloat32 = castUniqueArray - single' TypeFloat64 = castUniqueArray - single' TypeFloat128 = castUniqueArray - - vector' :: KnownNat n => Proxy# n -> SingleFloatingType c -> ScalarArrayData a -> ArrayData (Vec n c) - vector' _ TypeFloat16 = castUniqueArray - vector' _ TypeFloat32 = castUniqueArray - vector' _ TypeFloat64 = castUniqueArray - vector' _ TypeFloat128 = castUniqueArray + vector _ TypeInt8 = scalar' @a + vector _ TypeInt16 = scalar' @a + vector _ TypeInt32 = scalar' @a + vector _ TypeInt64 = scalar' @a + vector _ TypeInt128 = scalar' @a + vector _ TypeWord8 = scalar' @a + vector _ TypeWord16 = scalar' @a + vector _ TypeWord32 = scalar' @a + vector _ TypeWord64 = scalar' @a + vector _ TypeWord128 = scalar' @a floating :: forall a b. FloatingType a -> ScalarType b -> ArrayData a -> ArrayData b floating = \case @@ -1187,74 +1071,74 @@ acoerceOp scale bR (TupRsingle (ArrayR shR aR), Array (sz,sh) adata) = (repr', a VectorFloatingType n t -> vector n t where single :: SingleFloatingType a -> ScalarType b -> ArrayData a -> ArrayData b - single TypeFloat16 = scalar' - single TypeFloat32 = scalar' - single TypeFloat64 = scalar' - single TypeFloat128 = scalar' + single TypeFloat16 = scalar' @a + single TypeFloat32 = scalar' @a + single TypeFloat64 = scalar' @a + single TypeFloat128 = scalar' @a vector :: (KnownNat n, a ~ Vec n c) => Proxy# n -> SingleFloatingType c -> ScalarType b -> ArrayData (Vec n c) -> ArrayData b - vector _ TypeFloat16 = scalar' - vector _ TypeFloat32 = scalar' - vector _ TypeFloat64 = scalar' - vector _ TypeFloat128 = scalar' - - scalar' :: ScalarType b -> ScalarArrayData a -> ArrayData b - scalar' (NumScalarType t) = num' t - scalar' (BitScalarType t) = bit' t - - bit' :: BitType b -> ScalarArrayData a -> ArrayData b - bit' TypeBit = castUniqueArray - bit' TypeMask{} = castUniqueArray - - num' :: NumType b -> ScalarArrayData a -> ArrayData b - num' (IntegralNumType t) = integral' t - num' (FloatingNumType t) = floating' t - - integral' :: IntegralType b -> ScalarArrayData a -> ArrayData b - integral' = \case - SingleIntegralType t -> single' t - VectorIntegralType n t -> vector' n t - where - single' :: SingleIntegralType b -> ScalarArrayData a -> ArrayData b - single' TypeInt8 = castUniqueArray - single' TypeInt16 = castUniqueArray - single' TypeInt32 = castUniqueArray - single' TypeInt64 = castUniqueArray - single' TypeInt128 = castUniqueArray - single' TypeWord8 = castUniqueArray - single' TypeWord16 = castUniqueArray - single' TypeWord32 = castUniqueArray - single' TypeWord64 = castUniqueArray - single' TypeWord128 = castUniqueArray - - vector' :: KnownNat n => Proxy# n -> SingleIntegralType c -> ScalarArrayData a -> ArrayData (Vec n c) - vector' _ TypeInt8 = castUniqueArray - vector' _ TypeInt16 = castUniqueArray - vector' _ TypeInt32 = castUniqueArray - vector' _ TypeInt64 = castUniqueArray - vector' _ TypeInt128 = castUniqueArray - vector' _ TypeWord8 = castUniqueArray - vector' _ TypeWord16 = castUniqueArray - vector' _ TypeWord32 = castUniqueArray - vector' _ TypeWord64 = castUniqueArray - vector' _ TypeWord128 = castUniqueArray - - floating' :: FloatingType b -> ScalarArrayData a -> ArrayData b - floating' = \case - SingleFloatingType t -> single' t - VectorFloatingType n t -> vector' n t - where - single' :: SingleFloatingType b -> ScalarArrayData a -> ArrayData b - single' TypeFloat16 = castUniqueArray - single' TypeFloat32 = castUniqueArray - single' TypeFloat64 = castUniqueArray - single' TypeFloat128 = castUniqueArray - - vector' :: KnownNat n => Proxy# n -> SingleFloatingType c -> ScalarArrayData a -> ArrayData (Vec n c) - vector' _ TypeFloat16 = castUniqueArray - vector' _ TypeFloat32 = castUniqueArray - vector' _ TypeFloat64 = castUniqueArray - vector' _ TypeFloat128 = castUniqueArray + vector _ TypeFloat16 = scalar' @a + vector _ TypeFloat32 = scalar' @a + vector _ TypeFloat64 = scalar' @a + vector _ TypeFloat128 = scalar' @a + + scalar' :: forall a b. ScalarType b -> ScalarArrayData a -> ArrayData b + scalar' (NumScalarType t) = num' @a t + scalar' (BitScalarType t) = bit' @a t + + num' :: forall a b. NumType b -> ScalarArrayData a -> ArrayData b + num' (IntegralNumType t) = integral' @a t + num' (FloatingNumType t) = floating' @a t + + bit' :: forall a b. BitType b -> ScalarArrayData a -> ArrayData b + bit' TypeBit = castUniqueArray + bit' TypeMask{} = castUniqueArray + + integral' :: forall a b. IntegralType b -> ScalarArrayData a -> ArrayData b + integral' = \case + SingleIntegralType t -> single' t + VectorIntegralType n t -> vector' n t + where + single' :: SingleIntegralType b -> ScalarArrayData a -> ArrayData b + single' TypeInt8 = castUniqueArray + single' TypeInt16 = castUniqueArray + single' TypeInt32 = castUniqueArray + single' TypeInt64 = castUniqueArray + single' TypeInt128 = castUniqueArray + single' TypeWord8 = castUniqueArray + single' TypeWord16 = castUniqueArray + single' TypeWord32 = castUniqueArray + single' TypeWord64 = castUniqueArray + single' TypeWord128 = castUniqueArray + + vector' :: KnownNat n => Proxy# n -> SingleIntegralType c -> ScalarArrayData a -> ArrayData (Vec n c) + vector' _ TypeInt8 = castUniqueArray + vector' _ TypeInt16 = castUniqueArray + vector' _ TypeInt32 = castUniqueArray + vector' _ TypeInt64 = castUniqueArray + vector' _ TypeInt128 = castUniqueArray + vector' _ TypeWord8 = castUniqueArray + vector' _ TypeWord16 = castUniqueArray + vector' _ TypeWord32 = castUniqueArray + vector' _ TypeWord64 = castUniqueArray + vector' _ TypeWord128 = castUniqueArray + + floating' :: forall a b. FloatingType b -> ScalarArrayData a -> ArrayData b + floating' = \case + SingleFloatingType t -> single' t + VectorFloatingType n t -> vector' n t + where + single' :: SingleFloatingType b -> ScalarArrayData a -> ArrayData b + single' TypeFloat16 = castUniqueArray + single' TypeFloat32 = castUniqueArray + single' TypeFloat64 = castUniqueArray + single' TypeFloat128 = castUniqueArray + + vector' :: KnownNat n => Proxy# n -> SingleFloatingType c -> ScalarArrayData a -> ArrayData (Vec n c) + vector' _ TypeFloat16 = castUniqueArray + vector' _ TypeFloat32 = castUniqueArray + vector' _ TypeFloat64 = castUniqueArray + vector' _ TypeFloat128 = castUniqueArray -- Coercion between two scalar types. We require that the size of the source and From 2c4008f94f789cbfaf7dfe190cbfcb3ac5f23914 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Wed, 30 Aug 2023 15:43:11 +0200 Subject: [PATCH 79/86] build fix for ghc < 9.4 --- .../Accelerate/Interpreter/Arithmetic.hs | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/src/Data/Array/Accelerate/Interpreter/Arithmetic.hs b/src/Data/Array/Accelerate/Interpreter/Arithmetic.hs index 80d02415e..cc914d1a2 100644 --- a/src/Data/Array/Accelerate/Interpreter/Arithmetic.hs +++ b/src/Data/Array/Accelerate/Interpreter/Arithmetic.hs @@ -1,4 +1,5 @@ {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE EmptyCase #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} @@ -673,3 +674,49 @@ infixr 0 $$ ($$) :: (b -> a) -> (c -> d -> b) -> c -> d -> a (f $$ g) x y = f (g x y) +#if __GLASGOW_HASKELL__ < 904 +int64ToWord64# :: Int# -> Word# +int64ToWord64# = int2Word# + +word64ToInt64# :: Word# -> Int# +word64ToInt64# = word2Int# +#endif + +#if __GLASGOW_HASKELL__ < 902 +wordToWord8# :: Word# -> Word# +wordToWord8# x = x + +wordToWord16# :: Word# -> Word# +wordToWord16# x = x + +wordToWord32# :: Word# -> Word# +wordToWord32# x = x + +word8ToWord# :: Word# -> Word# +word8ToWord# x = x + +word16ToWord# :: Word# -> Word# +word16ToWord# x = x + +word32ToWord# :: Word# -> Word# +word32ToWord# x = x + +int8ToWord8# :: Int# -> Word# +int8ToWord8# = int2Word# + +int16ToWord16# :: Int# -> Word# +int16ToWord16# = int2Word# + +int32ToWord32# :: Int# -> Word# +int32ToWord32# = int2Word# + +word8ToInt8# :: Word# -> Int# +word8ToInt8# = word2Int# + +word16ToInt16# :: Word# -> Int# +word16ToInt16# = word2Int# + +word32ToInt32# :: Word# -> Int# +word32ToInt32# = word2Int# +#endif + From 73c2e989ee084a1939bd314c4cb0da899292fb30 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Sat, 9 Sep 2023 16:00:35 +0200 Subject: [PATCH 80/86] export acoerceOp --- src/Data/Array/Accelerate.hs | 3 +-- src/Data/Array/Accelerate/Interpreter.hs | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index 0c6508ad3..a3a25937f 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -51,7 +51,7 @@ -- reference implementation defining the semantics of the Accelerate language -- -- * <http://hackage.haskell.org/package/accelerate-llvm-native accelerate-llvm-native>: --- implementation supporting parallel execution on multicore CPUs (e.g. x86). +-- implementation supporting parallel execution on multicore CPUs (e.g. x86-64, AARCH64). -- -- * <http://hackage.haskell.org/package/accelerate-llvm-ptx accelerate-llvm-ptx>: -- implementation supporting parallel execution on CUDA-capable NVIDIA GPUs. @@ -700,4 +700,3 @@ arrayReshape = S.reshape -- * <https://hackage.haskell.org/package/accelerate-io-serialise accelerate-io-serialise>: binary serialisation of arrays using <https://hackage.haskell.org/package/serialise serialise> -- * <https://hackage.haskell.org/package/accelerate-io-vector accelerate-io-vector>: efficient boxed and unboxed one-dimensional arrays -- - diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index 7c5a9583c..101cb6a5c 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -41,7 +41,7 @@ module Data.Array.Accelerate.Interpreter ( run, run1, runN, -- Internal (hidden) - evalPrim, evalBitcastScalar, atraceOp, + evalPrim, evalBitcastScalar, atraceOp, acoerceOp, ) where From 671a7828f666e6e73973e5c63dbb936e365dd35a Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Tue, 12 Sep 2023 15:02:02 +0200 Subject: [PATCH 81/86] what was I thinking? --- src/Data/Primitive/Bit.hs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/Data/Primitive/Bit.hs b/src/Data/Primitive/Bit.hs index 773f7feeb..d286054e9 100644 --- a/src/Data/Primitive/Bit.hs +++ b/src/Data/Primitive/Bit.hs @@ -240,12 +240,12 @@ toList (BitMask (Vec ba#)) = concat (unpack 0# []) let q# = quotInt# i# 8# w# = indexWord8Array# ba# q# lim# = minInt# 8# (n# -# i#) - w8 j# = if isTrue# (j# <# lim#) - then let b# = testBitWord8# w# (7# -# j#) - in Bit (isTrue# b#) : w8 (j# +# 1#) - else [] + w8 j# acc' = + if isTrue# (j# <# lim#) + then w8 (j# +# 1#) (Bit (isTrue# (testBitWord8# w# j#)) : acc') + else acc' in - unpack (i# +# 8#) (w8 0# : acc) + unpack (i# +# 8#) (w8 0# [] : acc) | otherwise = acc {-# INLINE fromList #-} @@ -260,14 +260,14 @@ fromList bits = case byteArrayFromListN bytes (pack bits' []) of pack [] acc = acc pack xs acc = let (h,t) = splitAt 8 xs - w = w8 7 0 h + w = w8 0 0 h in pack t (w : acc) w8 :: Int -> Word8 -> [Bit] -> Word8 w8 !_ !w [] = w - w8 !i !w (Bit True :bs) = w8 (i-1) (setBit w i) bs - w8 !i !w (Bit False:bs) = w8 (i-1) w bs + w8 !i !w (Bit True :bs) = w8 (i+1) (setBit w i) bs + w8 !i !w (Bit False:bs) = w8 (i+1) w bs {-# INLINE extract #-} extract :: forall n. KnownNat n => BitMask n -> Int -> Bit From 03a0b608727cdec0f26d691cbe84170fb2624a6b Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Thu, 28 Sep 2023 15:11:09 +0200 Subject: [PATCH 82/86] update acoerceOp --- src/Data/Array/Accelerate/Interpreter.hs | 11 ++++++----- src/Data/Array/Accelerate/Smart.hs | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index 101cb6a5c..0e3a2e168 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -215,7 +215,8 @@ evalOpenAcc (AST.Manifest pacc) aenv = (TupRpair r1 r2, (a1, a2)) Anil -> (TupRunit, ()) Atrace msg as bs -> unsafePerformIO $ manifest bs <$ atraceOp msg (snd $ manifest as) - Acoerce scale bR acc -> acoerceOp scale bR (manifest acc) + Acoerce scale bR acc -> let (TupRsingle (ArrayR shR aR), as) = manifest acc + in (TupRsingle (ArrayR shR bR), acoerceOp scale aR bR as) Apply repr afun acc -> (repr, evalOpenAfun afun aenv $ snd $ manifest acc) Aforeign repr _ afun acc -> (repr, evalOpenAfun afun Empty $ snd $ manifest acc) Acond p acc1 acc2 @@ -1000,12 +1001,12 @@ evalOpenExp pexp env aenv = acoerceOp :: HasCallStack => RescaleFactor + -> TypeR a -> TypeR b - -> WithReprs (Array (sh, INT) a) - -> WithReprs (Array (sh, INT) b) -acoerceOp scale bR (TupRsingle (ArrayR shR aR), Array (sz,sh) adata) = (repr', arr') + -> Array (sh, INT) a + -> Array (sh, INT) b +acoerceOp scale aR bR (Array (sz,sh) adata) = arr' where - repr' = TupRsingle (ArrayR shR bR) arr' = Array (sz,sh') adata' sh' = case compare scale 0 of EQ -> sh diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 4940678da..c033c4471 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -1266,7 +1266,7 @@ instance {-# OVERLAPS #-} (IsScalar a, IsScalar b) => Acoerce a b where sa = scalar ta sb = scalar tb sz = case compare sa sb of - EQ -> 0 + EQ -> 0 -- TLM: reuse this value for something else? rescale of ±1 achieves the same thing GT -> sa `quot` sb LT -> negate $ sb `quot` sa -- From 934a12fe022b42d3441dec549c41513c7b8b08d2 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Thu, 28 Sep 2023 15:13:07 +0200 Subject: [PATCH 83/86] fix undef size computation --- .../Array/Accelerate/Representation/Elt.hs | 60 ++++--------------- 1 file changed, 10 insertions(+), 50 deletions(-) diff --git a/src/Data/Array/Accelerate/Representation/Elt.hs b/src/Data/Array/Accelerate/Representation/Elt.hs index b888074fa..af2ab2d64 100644 --- a/src/Data/Array/Accelerate/Representation/Elt.hs +++ b/src/Data/Array/Accelerate/Representation/Elt.hs @@ -74,7 +74,10 @@ undefElt = tuple vector :: KnownNat n => Proxy# n -> SingleIntegralType t -> Vec n t vector n t = runST $ do - let bytes = bytesElt (TupRsingle (NumScalarType (IntegralNumType (VectorIntegralType n t)))) + let bits = case t of + TypeInt w -> w + TypeWord w -> w + bytes = max 1 (quot (fromInteger (natVal' n) * bits) 8) mba <- newAlignedPinnedByteArray bytes 16 ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# @@ -92,60 +95,17 @@ undefElt = tuple vector :: KnownNat n => Proxy# n -> SingleFloatingType t -> Vec n t vector n t = runST $ do - let bytes = bytesElt (TupRsingle (NumScalarType (FloatingNumType (VectorFloatingType n t)))) + let bits = case t of + TypeFloat16 -> 16 + TypeFloat32 -> 32 + TypeFloat64 -> 64 + TypeFloat128 -> 128 + bytes = max 1 (quot (fromInteger (natVal' n) * bits) 8) mba <- newAlignedPinnedByteArray bytes 16 ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# -bytesElt :: TypeR e -> Int -bytesElt = tuple - where - tuple :: TypeR t -> Int - tuple TupRunit = 0 - tuple (TupRpair ta tb) = tuple ta + tuple tb - tuple (TupRsingle t) = scalar t - - scalar :: ScalarType t -> Int - scalar (NumScalarType t) = num t - scalar (BitScalarType t) = bit t - - bit :: BitType t -> Int - bit TypeBit = 1 -- stored as Word8 - bit (TypeMask n) = quot (fromInteger (natVal' n)+7) 8 - - num :: NumType t -> Int - num (IntegralNumType t) = integral t - num (FloatingNumType t) = floating t - - integral :: IntegralType t -> Int - integral = \case - SingleIntegralType t -> single t - VectorIntegralType n t -> fromInteger (natVal' n) * single t - where - single :: SingleIntegralType t -> Int - single TypeInt8 = 1 - single TypeInt16 = 2 - single TypeInt32 = 4 - single TypeInt64 = 8 - single TypeInt128 = 16 - single TypeWord8 = 1 - single TypeWord16 = 2 - single TypeWord32 = 4 - single TypeWord64 = 8 - single TypeWord128 = 16 - - floating :: FloatingType t -> Int - floating = \case - SingleFloatingType t -> single t - VectorFloatingType n t -> fromInteger (natVal' n) * single t - where - single :: SingleFloatingType t -> Int - single TypeFloat16 = 2 - single TypeFloat32 = 4 - single TypeFloat64 = 8 - single TypeFloat128 = 16 - showElt :: TypeR e -> e -> String showElt t v = showsElt t v "" From cefdcec1ef97e896f2fb6fe66e3ba88896f3152d Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Thu, 28 Sep 2023 15:53:20 +0200 Subject: [PATCH 84/86] export pack, unpack --- src/Data/Array/Accelerate.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index a3a25937f..a6aad65fd 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -372,7 +372,7 @@ module Data.Array.Accelerate ( fst, afst, snd, asnd, curry, uncurry, -- *** SIMD vectors - splat, insert, extract, shuffle, + splat, pack, unpack, insert, extract, shuffle, -- *** Flow control (?), select, match, cond, while, iterate, From 308fed44dc809d0d9acb5eb3899f246233538f0b Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Thu, 28 Sep 2023 15:53:44 +0200 Subject: [PATCH 85/86] embedding polymorphic Complex constructor --- src/Data/Array/Accelerate/Data/Complex.hs | 338 +++++++++++----------- 1 file changed, 176 insertions(+), 162 deletions(-) diff --git a/src/Data/Array/Accelerate/Data/Complex.hs b/src/Data/Array/Accelerate/Data/Complex.hs index 58c188d87..6872c1db7 100644 --- a/src/Data/Array/Accelerate/Data/Complex.hs +++ b/src/Data/Array/Accelerate/Data/Complex.hs @@ -1,19 +1,20 @@ -{-# LANGUAGE CPP #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RebindableSyntax #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RebindableSyntax #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | -- Module : Data.Array.Accelerate.Data.Complex @@ -30,7 +31,7 @@ module Data.Array.Accelerate.Data.Complex ( -- * Rectangular from - Complex(..), pattern (::+), + Complex, pattern (:+), real, imag, @@ -66,7 +67,7 @@ import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Type import qualified Data.Primitive.Vec as Prim -import Data.Complex ( Complex(..) ) +import Data.Complex ( Complex ) import Data.Primitive.Types import Prelude ( ($) ) import qualified Data.Complex as C @@ -76,19 +77,32 @@ import Data.Type.Equality #endif -infix 6 ::+ -pattern (::+) :: Elt a => Exp a -> Exp a -> Exp (Complex a) -pattern r ::+ i <- (deconstructComplex -> (r, i)) - where (::+) = constructComplex -{-# COMPLETE (::+) #-} +infix 6 :+ +pattern (:+) :: IsComplex a b => a -> a -> b +pattern r :+ i <- (matchComplex -> (r,i)) + where (:+) = buildComplex +{-# COMPLETE (:+) :: Complex #-} +{-# COMPLETE (:+) :: Exp #-} + +class IsComplex a b | b -> a where + matchComplex :: b -> (a, a) + buildComplex :: a -> a -> b + +instance IsComplex a (Complex a) where + buildComplex = (C.:+) + matchComplex (r C.:+ i) = (r, i) -- Use an array-of-structs representation for complex numbers if possible. --- This matches the standard C-style layout, but we can use this representation only at --- specific types (not for any type 'a') as we can only have vectors of primitive type. --- For other types, we use a structure-of-arrays representation. This is handled by the --- ComplexR. We use the GADT ComplexR and function complexR to reconstruct --- information on how the elements are represented. +-- +-- This matches the standard C-style layout, but we can use this representation +-- only at specific types (not for any type 'a') as we can only have vectors of +-- primitive type. For other types, we use a structure-of-arrays representation. +-- This is handled by the ComplexR. We use the GADT ComplexR and function +-- complexR to reconstruct information on how the elements are represented. +-- +-- TODO: This is no longer true, we could SIMD-ify more types here. +-- - TLM 2023-09-28 -- instance Elt a => Elt (Complex a) where type EltR (Complex a) = ComplexR (EltR a) @@ -180,108 +194,108 @@ complexR = tuple TypeFloat64 -> ComplexVec numType TypeFloat128 -> ComplexVec numType -constructComplex :: forall a. Elt a => Exp a -> Exp a -> Exp (Complex a) -constructComplex r@(Exp r') i@(Exp i') = - case complexR (eltR @a) of - ComplexTup -> Pattern (r,i) - ComplexVec t -> Exp $ num t r' i' - where - num :: NumType (Prim.V2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) - num (IntegralNumType t) = integral t - num (FloatingNumType t) = floating t - - integral :: IntegralType (Prim.V2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) - integral (SingleIntegralType t) = case t of - integral (VectorIntegralType n t) = - let v = NumScalarType (IntegralNumType (VectorIntegralType n t)) - in case t of - TypeInt8 -> pack v - TypeInt16 -> pack v - TypeInt32 -> pack v - TypeInt64 -> pack v - TypeInt128 -> pack v - TypeWord8 -> pack v - TypeWord16 -> pack v - TypeWord32 -> pack v - TypeWord64 -> pack v - TypeWord128 -> pack v - - floating :: FloatingType (Prim.V2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) - floating (SingleFloatingType t) = case t of - floating (VectorFloatingType n t) = - let v = NumScalarType (FloatingNumType (VectorFloatingType n t)) - in case t of - TypeFloat16 -> pack v - TypeFloat32 -> pack v - TypeFloat64 -> pack v - TypeFloat128 -> pack v - - pack :: ScalarType (Prim.Vec 2 t) -> SmartExp t -> SmartExp t -> SmartExp (Prim.Vec 2 t) - pack v x y - = SmartExp (Insert v TypeWord8 - (SmartExp (Insert v TypeWord8 (SmartExp (Undef v)) (SmartExp (Const scalarType 0)) x)) - (SmartExp (Const scalarType 1)) y) - -deconstructComplex :: forall a. Elt a => Exp (Complex a) -> (Exp a, Exp a) -deconstructComplex (Exp c) = - case complexR (eltR @a) of - ComplexTup -> - let i = SmartExp (Prj PairIdxRight c) - r = SmartExp (Prj PairIdxRight (SmartExp (Prj PairIdxLeft c))) - in (Exp r, Exp i) - ComplexVec t -> - let (r, i) = num t c - in (Exp r, Exp i) - where - num :: NumType (Prim.V2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) - num (IntegralNumType t) = integral t - num (FloatingNumType t) = floating t - - integral :: IntegralType (Prim.V2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) - integral (SingleIntegralType t) = case t of - integral (VectorIntegralType n t) = - let v = NumScalarType (IntegralNumType (VectorIntegralType n t)) - in case t of - TypeInt8 -> unpack v - TypeInt16 -> unpack v - TypeInt32 -> unpack v - TypeInt64 -> unpack v - TypeInt128 -> unpack v - TypeWord8 -> unpack v - TypeWord16 -> unpack v - TypeWord32 -> unpack v - TypeWord64 -> unpack v - TypeWord128 -> unpack v - - floating :: FloatingType (Prim.V2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) - floating (SingleFloatingType t) = case t of - floating (VectorFloatingType n t) = - let v = NumScalarType (FloatingNumType (VectorFloatingType n t)) - in case t of - TypeFloat16 -> unpack v - TypeFloat32 -> unpack v - TypeFloat64 -> unpack v - TypeFloat128 -> unpack v - - unpack :: ScalarType (Prim.Vec 2 t) -> SmartExp (Prim.Vec 2 t) -> (SmartExp t, SmartExp t) - unpack v x = - let r = SmartExp (Extract v TypeWord8 x (SmartExp (Const scalarType 0))) - i = SmartExp (Extract v TypeWord8 x (SmartExp (Const scalarType 1))) - in - (r, i) + +instance Elt a => IsComplex (Exp a) (Exp (Complex a)) where + matchComplex (Exp c) = + case complexR (eltR @a) of + ComplexTup -> + let i = SmartExp (Prj PairIdxRight c) + r = SmartExp (Prj PairIdxRight (SmartExp (Prj PairIdxLeft c))) + in (Exp r, Exp i) + ComplexVec t -> + let (r, i) = num t c + in (Exp r, Exp i) + where + num :: NumType (Prim.V2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType (Prim.V2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) + integral (SingleIntegralType t) = case t of + integral (VectorIntegralType n t) = + let v = NumScalarType (IntegralNumType (VectorIntegralType n t)) + in case t of + TypeInt8 -> unpack v + TypeInt16 -> unpack v + TypeInt32 -> unpack v + TypeInt64 -> unpack v + TypeInt128 -> unpack v + TypeWord8 -> unpack v + TypeWord16 -> unpack v + TypeWord32 -> unpack v + TypeWord64 -> unpack v + TypeWord128 -> unpack v + + floating :: FloatingType (Prim.V2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) + floating (SingleFloatingType t) = case t of + floating (VectorFloatingType n t) = + let v = NumScalarType (FloatingNumType (VectorFloatingType n t)) + in case t of + TypeFloat16 -> unpack v + TypeFloat32 -> unpack v + TypeFloat64 -> unpack v + TypeFloat128 -> unpack v + + unpack :: ScalarType (Prim.Vec 2 t) -> SmartExp (Prim.Vec 2 t) -> (SmartExp t, SmartExp t) + unpack v x = + let r = SmartExp (Extract v TypeWord8 x (SmartExp (Const scalarType 0))) + i = SmartExp (Extract v TypeWord8 x (SmartExp (Const scalarType 1))) + in + (r, i) + + buildComplex r@(Exp r') i@(Exp i') = + case complexR (eltR @a) of + ComplexTup -> Pattern (r,i) + ComplexVec t -> Exp $ num t r' i' + where + num :: NumType (Prim.V2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType (Prim.V2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) + integral (SingleIntegralType t) = case t of + integral (VectorIntegralType n t) = + let v = NumScalarType (IntegralNumType (VectorIntegralType n t)) + in case t of + TypeInt8 -> pack v + TypeInt16 -> pack v + TypeInt32 -> pack v + TypeInt64 -> pack v + TypeInt128 -> pack v + TypeWord8 -> pack v + TypeWord16 -> pack v + TypeWord32 -> pack v + TypeWord64 -> pack v + TypeWord128 -> pack v + + floating :: FloatingType (Prim.V2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) + floating (SingleFloatingType t) = case t of + floating (VectorFloatingType n t) = + let v = NumScalarType (FloatingNumType (VectorFloatingType n t)) + in case t of + TypeFloat16 -> pack v + TypeFloat32 -> pack v + TypeFloat64 -> pack v + TypeFloat128 -> pack v + + pack :: ScalarType (Prim.Vec 2 t) -> SmartExp t -> SmartExp t -> SmartExp (Prim.Vec 2 t) + pack v x y + = SmartExp (Insert v TypeWord8 + (SmartExp (Insert v TypeWord8 (SmartExp (Undef v)) (SmartExp (Const scalarType 0)) x)) + (SmartExp (Const scalarType 1)) y) instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Complex a) where type Plain (Complex a) = Complex (Plain a) - lift (r :+ i) = lift r ::+ lift i + lift (r :+ i) = lift r :+ lift i instance Elt a => Unlift Exp (Complex (Exp a)) where - unlift (r ::+ i) = r :+ i + unlift (r :+ i) = r :+ i instance Eq a => Eq (Complex a) where - r1 ::+ c1 == r2 ::+ c2 = r1 == r2 && c1 == c2 - r1 ::+ c1 /= r2 ::+ c2 = r1 /= r2 || c1 /= c2 + r1 :+ c1 == r2 :+ c2 = r1 == r2 && c1 == c2 + r1 :+ c1 /= r2 :+ c2 = r1 /= r2 || c1 /= c2 instance (RealFloat a, Exponent a ~ Int) => P.Num (Exp (Complex a)) where (+) = case complexR (eltR @a) of @@ -294,20 +308,20 @@ instance (RealFloat a, Exponent a ~ Int) => P.Num (Exp (Complex a)) where negate = case complexR (eltR @a) of ComplexTup -> lift1 (negate :: Complex (Exp a) -> Complex (Exp a)) ComplexVec t -> mkPrimUnary $ PrimNeg t - signum z@(x ::+ y) = + signum z@(x :+ y) = if z == 0 then z else let r = magnitude z - in x/r ::+ y/r - abs z = magnitude z ::+ 0 - fromInteger n = fromInteger n ::+ 0 + in x/r :+ y/r + abs z = magnitude z :+ 0 + fromInteger n = fromInteger n :+ 0 instance (RealFloat a, Exponent a ~ Int) => P.Fractional (Exp (Complex a)) where - fromRational x = fromRational x ::+ 0 - z / z' = (x*x''+y*y'') / d ::+ (y*x''-x*y'') / d + fromRational x = fromRational x :+ 0 + z / z' = (x*x''+y*y'') / d :+ (y*x''-x*y'') / d where - x :+ y = unlift z - x' :+ y' = unlift z' + x :+ y = z + x' :+ y' = z' -- x'' = scaleFloat k x' y'' = scaleFloat k y' @@ -315,14 +329,14 @@ instance (RealFloat a, Exponent a ~ Int) => P.Fractional (Exp (Complex a)) where d = x'*x'' + y'*y'' instance (RealFloat a, Exponent a ~ Int, BitOrMask (EltR a) ~ Bit) => P.Floating (Exp (Complex a)) where - pi = pi ::+ 0 - exp (x ::+ y) = let expx = exp x - in expx * cos y ::+ expx * sin y - log z = log (magnitude z) ::+ phase z - sqrt z@(x ::+ y) = + pi = pi :+ 0 + exp (x :+ y) = let expx = exp x + in expx * cos y :+ expx * sin y + log z = log (magnitude z) :+ phase z + sqrt z@(x :+ y) = if z == 0 then 0 - else u ::+ (y < 0 ? (-v, v)) + else u :+ (y < 0 ? (-v, v)) where T2 u v = x < 0 ? (T2 v' u', T2 u' v') v' = abs y / (u'*2) @@ -331,50 +345,50 @@ instance (RealFloat a, Exponent a ~ Int, BitOrMask (EltR a) ~ Bit) => P.Floating x ** y = if y == 0 then 1 else if x == 0 then if exp_r > 0 then 0 else - if exp_r < 0 then inf ::+ 0 - else nan ::+ nan + if exp_r < 0 then inf :+ 0 + else nan :+ nan else if isInfinite r || isInfinite i - then if exp_r > 0 then inf ::+ 0 else + then if exp_r > 0 then inf :+ 0 else if exp_r < 0 then 0 - else nan ::+ nan + else nan :+ nan else exp (log x * y) where - r ::+ i = x - exp_r ::+ _ = y + r :+ i = x + exp_r :+ _ = y -- inf = 1 / 0 nan = 0 / 0 - sin (x ::+ y) = sin x * cosh y ::+ cos x * sinh y - cos (x ::+ y) = cos x * cosh y ::+ (- sin x * sinh y) - tan (x ::+ y) = (sinx*coshy ::+ cosx*sinhy) / (cosx*coshy ::+ (-sinx*sinhy)) + sin (x :+ y) = sin x * cosh y :+ cos x * sinh y + cos (x :+ y) = cos x * cosh y :+ (- sin x * sinh y) + tan (x :+ y) = (sinx*coshy :+ cosx*sinhy) / (cosx*coshy :+ (-sinx*sinhy)) where sinx = sin x cosx = cos x sinhy = sinh y coshy = cosh y - sinh (x ::+ y) = cos y * sinh x ::+ sin y * cosh x - cosh (x ::+ y) = cos y * cosh x ::+ sin y * sinh x - tanh (x ::+ y) = (cosy*sinhx ::+ siny*coshx) / (cosy*coshx ::+ siny*sinhx) + sinh (x :+ y) = cos y * sinh x :+ sin y * cosh x + cosh (x :+ y) = cos y * cosh x :+ sin y * sinh x + tanh (x :+ y) = (cosy*sinhx :+ siny*coshx) / (cosy*coshx :+ siny*sinhx) where siny = sin y cosy = cos y sinhx = sinh x coshx = cosh x - asin z@(x ::+ y) = y' ::+ (-x') + asin z@(x :+ y) = y' :+ (-x') where - x' ::+ y' = log (((-y) ::+ x) + sqrt (1 - z*z)) + x' :+ y' = log (((-y) :+ x) + sqrt (1 - z*z)) - acos z = y'' ::+ (-x'') + acos z = y'' :+ (-x'') where - x'' ::+ y'' = log (z + ((-y') ::+ x')) - x' ::+ y' = sqrt (1 - z*z) + x'' :+ y'' = log (z + ((-y') :+ x')) + x' :+ y' = sqrt (1 - z*z) - atan z@(x ::+ y) = y' ::+ (-x') + atan z@(x :+ y) = y' :+ (-x') where - x' ::+ y' = log (((1-y) ::+ x) / sqrt (1+z*z)) + x' :+ y' = log (((1-y) :+ x) / sqrt (1+z*z)) asinh z = log (z + sqrt (1+z*z)) acosh z = log (z + (z+1) * sqrt ((z-1)/(z+1))) @@ -382,18 +396,18 @@ instance (RealFloat a, Exponent a ~ Int, BitOrMask (EltR a) ~ Bit) => P.Floating instance (FromIntegral a b, Num b, Elt (Complex b)) => FromIntegral a (Complex b) where - fromIntegral x = fromIntegral x ::+ 0 + fromIntegral x = fromIntegral x :+ 0 -- | @since 1.2.0.0 -- instance Functor Complex where - fmap f (r ::+ i) = f r ::+ f i + fmap f (r :+ i) = f r :+ f i -- | The non-negative magnitude of a complex number -- magnitude :: (RealFloat a, Exponent a ~ Int) => Exp (Complex a) -> Exp a -magnitude (r ::+ i) = scaleFloat k (sqrt (sqr (scaleFloat mk r) + sqr (scaleFloat mk i))) +magnitude (r :+ i) = scaleFloat k (sqrt (sqr (scaleFloat mk r) + sqr (scaleFloat mk i))) where k = max (exponent r) (exponent i) mk = -k @@ -405,13 +419,13 @@ magnitude (r ::+ i) = scaleFloat k (sqrt (sqr (scaleFloat mk r) + sqr (scaleFloa -- @since 1.3.0.0 -- magnitude' :: RealFloat a => Exp (Complex a) -> Exp a -magnitude' (r ::+ i) = sqrt (r*r + i*i) +magnitude' (r :+ i) = sqrt (r*r + i*i) -- | The phase of a complex number, in the range @(-'pi', 'pi']@. If the -- magnitude is zero, then so is the phase. -- phase :: RealFloat a => Exp (Complex a) -> Exp a -phase (r ::+ i) = +phase (r :+ i) = if r == 0 && i == 0 then 0 else atan2 i r @@ -437,17 +451,17 @@ cis = lift1 (C.cis :: Exp a -> Complex (Exp a)) -- | Return the real part of a complex number -- real :: Elt a => Exp (Complex a) -> Exp a -real (r ::+ _) = r +real (r :+ _) = r -- | Return the imaginary part of a complex number -- imag :: Elt a => Exp (Complex a) -> Exp a -imag (_ ::+ i) = i +imag (_ :+ i) = i -- | Return the complex conjugate of a complex number, defined as -- -- > conjugate(Z) = X - iY -- conjugate :: Num a => Exp (Complex a) -> Exp (Complex a) -conjugate z = real z ::+ (- imag z) +conjugate z = real z :+ (- imag z) From 2784ec6b212fba1aa49b4860d71c7a6bf2043a1f Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" <trevor.mcdonell@gmail.com> Date: Tue, 3 Oct 2023 17:51:36 +0200 Subject: [PATCH 86/86] embedding polymorphic containers from Monoid & Semigroup --- src/Data/Array/Accelerate/Data/Monoid.hs | 128 ++++++++++++-------- src/Data/Array/Accelerate/Data/Semigroup.hs | 128 ++++++++++++-------- 2 files changed, 158 insertions(+), 98 deletions(-) diff --git a/src/Data/Array/Accelerate/Data/Monoid.hs b/src/Data/Array/Accelerate/Data/Monoid.hs index 4cb6c9b10..20a13aa76 100644 --- a/src/Data/Array/Accelerate/Data/Monoid.hs +++ b/src/Data/Array/Accelerate/Data/Monoid.hs @@ -1,16 +1,17 @@ -{-# LANGUAGE CPP #-} -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fno-warn-orphans #-} #if __GLASGOW_HASKELL__ >= 806 -{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE UndecidableInstances #-} #endif -- | -- Module : Data.Array.Accelerate.Data.Monoid @@ -30,8 +31,8 @@ module Data.Array.Accelerate.Data.Monoid ( Monoid(..), (<>), - Sum(..), pattern Sum_, - Product(..), pattern Product_, + Sum, pattern Sum, + Product, pattern Product, ) where @@ -49,30 +50,45 @@ import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Type import Data.Function -import Data.Monoid hiding ( (<>) ) -import Data.Semigroup +import Data.Monoid ( Monoid(..), Product, Sum ) +import Data.Semigroup ( Semigroup(..) ) import qualified Prelude as P +import qualified Data.Monoid as P -- Sum: Monoid under addition -- -------------------------- -pattern Sum_ :: Elt a => Exp a -> Exp (Sum a) -pattern Sum_ x = Pattern x -{-# COMPLETE Sum_ #-} +pattern Sum :: IsSum a b => a -> b +pattern Sum x <- (matchSum -> x) + where Sum = buildSum +{-# COMPLETE Sum :: Sum #-} +{-# COMPLETE Sum :: Exp #-} + +class IsSum a b | b -> a where + matchSum :: b -> a + buildSum :: a -> b + +instance IsSum a (Sum a) where + matchSum = P.getSum + buildSum = P.Sum + +instance Elt a => IsSum (Exp a) (Exp (Sum a)) where + matchSum (Pattern x) = x + buildSum x = Pattern x instance Elt a => Elt (Sum a) instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Sum a) where type Plain (Sum a) = Sum (Plain a) - lift (Sum a) = Sum_ (lift a) + lift (Sum a) = Sum (lift a) instance Elt a => Unlift Exp (Sum (Exp a)) where - unlift (Sum_ a) = Sum a + unlift (Sum a) = Sum a instance Bounded a => P.Bounded (Exp (Sum a)) where - minBound = Sum_ minBound - maxBound = Sum_ maxBound + minBound = Sum minBound + maxBound = Sum maxBound instance Num a => P.Num (Exp (Sum a)) where (+) = lift2 ((+) :: Sum (Exp a) -> Sum (Exp a) -> Sum (Exp a)) @@ -84,45 +100,59 @@ instance Num a => P.Num (Exp (Sum a)) where fromInteger x = lift (P.fromInteger x :: Sum (Exp a)) instance Eq a => Eq (Sum a) where - (==) = lift2 ((==) `on` getSum) - (/=) = lift2 ((/=) `on` getSum) + (==) = lift2 ((==) `on` P.getSum) + (/=) = lift2 ((/=) `on` P.getSum) instance Ord a => Ord (Sum a) where - (<) = lift2 ((<) `on` getSum) - (>) = lift2 ((>) `on` getSum) - (<=) = lift2 ((<=) `on` getSum) - (>=) = lift2 ((>=) `on` getSum) - min x y = Sum_ $ lift2 (min `on` getSum) x y - max x y = Sum_ $ lift2 (max `on` getSum) x y + (<) = lift2 ((<) `on` P.getSum) + (>) = lift2 ((>) `on` P.getSum) + (<=) = lift2 ((<=) `on` P.getSum) + (>=) = lift2 ((>=) `on` P.getSum) + min x y = Sum $ lift2 (min `on` P.getSum) x y + max x y = Sum $ lift2 (max `on` P.getSum) x y instance Num a => Monoid (Exp (Sum a)) where mempty = 0 -- | @since 1.2.0.0 instance Num a => Semigroup (Exp (Sum a)) where - (<>) = (+) - stimes n (Sum_ x) = Sum_ $ P.fromIntegral n * x + (<>) = (+) + stimes n (Sum x) = Sum $ P.fromIntegral n * x -- Product: Monoid under multiplication -- ------------------------------------ -pattern Product_ :: Elt a => Exp a -> Exp (Product a) -pattern Product_ x = Pattern x -{-# COMPLETE Product_ #-} +pattern Product :: IsProduct a b => a -> b +pattern Product x <- (matchProduct -> x) + where Product = buildProduct +{-# COMPLETE Product :: Product #-} +{-# COMPLETE Product :: Exp #-} + +class IsProduct a b | b -> a where + matchProduct :: b -> a + buildProduct :: a -> b + +instance IsProduct a (Product a) where + matchProduct = P.getProduct + buildProduct = P.Product + +instance Elt a => IsProduct (Exp a) (Exp (Product a)) where + matchProduct (Pattern x) = x + buildProduct x = Pattern x instance Elt a => Elt (Product a) instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Product a) where type Plain (Product a) = Product (Plain a) - lift (Product a) = Product_ (lift a) + lift (Product a) = Product (lift a) instance Elt a => Unlift Exp (Product (Exp a)) where - unlift (Product_ a) = Product a + unlift (Product a) = Product a instance Bounded a => P.Bounded (Exp (Product a)) where - minBound = Product_ minBound - maxBound = Product_ maxBound + minBound = Product minBound + maxBound = Product maxBound instance Num a => P.Num (Exp (Product a)) where (+) = lift2 ((+) :: Product (Exp a) -> Product (Exp a) -> Product (Exp a)) @@ -134,24 +164,24 @@ instance Num a => P.Num (Exp (Product a)) where fromInteger x = lift (P.fromInteger x :: Product (Exp a)) instance Eq a => Eq (Product a) where - (==) = lift2 ((==) `on` getProduct) - (/=) = lift2 ((/=) `on` getProduct) + (==) = lift2 ((==) `on` P.getProduct) + (/=) = lift2 ((/=) `on` P.getProduct) instance Ord a => Ord (Product a) where - (<) = lift2 ((<) `on` getProduct) - (>) = lift2 ((>) `on` getProduct) - (<=) = lift2 ((<=) `on` getProduct) - (>=) = lift2 ((>=) `on` getProduct) - min x y = Product_ $ lift2 (min `on` getProduct) x y - max x y = Product_ $ lift2 (max `on` getProduct) x y + (<) = lift2 ((<) `on` P.getProduct) + (>) = lift2 ((>) `on` P.getProduct) + (<=) = lift2 ((<=) `on` P.getProduct) + (>=) = lift2 ((>=) `on` P.getProduct) + min x y = Product $ lift2 (min `on` P.getProduct) x y + max x y = Product $ lift2 (max `on` P.getProduct) x y instance Num a => Monoid (Exp (Product a)) where mempty = 1 -- | @since 1.2.0.0 instance Num a => Semigroup (Exp (Product a)) where - (<>) = (*) - stimes n (Product_ x) = Product_ $ x ^ (P.fromIntegral n :: Exp Int) + (<>) = (*) + stimes n (Product x) = Product $ x ^ (P.fromIntegral n :: Exp Int) -- Instances for unit and tuples diff --git a/src/Data/Array/Accelerate/Data/Semigroup.hs b/src/Data/Array/Accelerate/Data/Semigroup.hs index 030c51243..82ecdebe5 100644 --- a/src/Data/Array/Accelerate/Data/Semigroup.hs +++ b/src/Data/Array/Accelerate/Data/Semigroup.hs @@ -1,17 +1,18 @@ -{-# LANGUAGE CPP #-} -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RebindableSyntax #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RebindableSyntax #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fno-warn-orphans #-} #if __GLASGOW_HASKELL__ >= 806 -{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE UndecidableInstances #-} #endif -- | -- Module : Data.Array.Accelerate.Data.Semigroup @@ -31,8 +32,8 @@ module Data.Array.Accelerate.Data.Semigroup ( Semigroup(..), - Min(..), pattern Min_, - Max(..), pattern Max_, + Min, pattern Min, + Max, pattern Max, ) where @@ -47,26 +48,41 @@ import Data.Array.Accelerate.Sugar.Elt import Data.Function import Data.Monoid ( Monoid(..) ) -import Data.Semigroup +import Data.Semigroup ( Semigroup(..), Min, Max ) import qualified Prelude as P +import qualified Data.Semigroup as P -pattern Min_ :: Elt a => Exp a -> Exp (Min a) -pattern Min_ x = Pattern x -{-# COMPLETE Min_ #-} +pattern Min :: IsMin a b => a -> b +pattern Min x <- (matchMin -> x) + where Min = buildMin +{-# COMPLETE Min :: Min #-} +{-# COMPLETE Min :: Exp #-} + +class IsMin a b | b -> a where + matchMin :: b -> a + buildMin :: a -> b + +instance IsMin a (Min a) where + matchMin = P.getMin + buildMin = P.Min + +instance Elt a => IsMin (Exp a) (Exp (Min a)) where + matchMin (Pattern x) = x + buildMin x = Pattern x instance Elt a => Elt (Min a) instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Min a) where type Plain (Min a) = Min (Plain a) - lift (Min a) = Min_ (lift a) + lift (Min a) = Min (lift a) instance Elt a => Unlift Exp (Min (Exp a)) where - unlift (Min_ a) = Min a + unlift (Min a) = Min a instance Bounded a => P.Bounded (Exp (Min a)) where - minBound = lift $ Min (minBound :: Exp a) - maxBound = lift $ Min (maxBound :: Exp a) + minBound = Min minBound + maxBound = Min maxBound instance Num a => P.Num (Exp (Min a)) where (+) = lift2 ((+) :: Min (Exp a) -> Min (Exp a) -> Min (Exp a)) @@ -78,42 +94,56 @@ instance Num a => P.Num (Exp (Min a)) where fromInteger x = lift (P.fromInteger x :: Min (Exp a)) instance Eq a => Eq (Min a) where - (==) = lift2 ((==) `on` getMin) - (/=) = lift2 ((/=) `on` getMin) + (==) = lift2 ((==) `on` P.getMin) + (/=) = lift2 ((/=) `on` P.getMin) instance Ord a => Ord (Min a) where - (<) = lift2 ((<) `on` getMin) - (>) = lift2 ((>) `on` getMin) - (<=) = lift2 ((<=) `on` getMin) - (>=) = lift2 ((>=) `on` getMin) - min x y = lift . Min $ lift2 (min `on` getMin) x y - max x y = lift . Min $ lift2 (max `on` getMin) x y + (<) = lift2 ((<) `on` P.getMin) + (>) = lift2 ((>) `on` P.getMin) + (<=) = lift2 ((<=) `on` P.getMin) + (>=) = lift2 ((>=) `on` P.getMin) + min x y = Min $ lift2 (min `on` P.getMin) x y + max x y = Min $ lift2 (max `on` P.getMin) x y instance Ord a => Semigroup (Exp (Min a)) where - x <> y = lift . Min $ lift2 (min `on` getMin) x y - stimes = stimesIdempotent + x <> y = Min $ lift2 (min `on` P.getMin) x y + stimes = P.stimesIdempotent instance (Ord a, Bounded a) => Monoid (Exp (Min a)) where mempty = maxBound mappend = (<>) -pattern Max_ :: Elt a => Exp a -> Exp (Max a) -pattern Max_ x = Pattern x -{-# COMPLETE Max_ #-} +pattern Max :: IsMax a b => a -> b +pattern Max x <- (matchMax -> x) + where Max = buildMax +{-# COMPLETE Max :: Max #-} +{-# COMPLETE Max :: Exp #-} + +class IsMax a b | b -> a where + matchMax :: b -> a + buildMax :: a -> b + +instance IsMax a (Max a) where + matchMax = P.getMax + buildMax = P.Max + +instance Elt a => IsMax (Exp a) (Exp (Max a)) where + matchMax (Pattern x) = x + buildMax x = Pattern x instance Elt a => Elt (Max a) instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Max a) where type Plain (Max a) = Max (Plain a) - lift (Max a) = Max_ (lift a) + lift (Max a) = Max (lift a) instance Elt a => Unlift Exp (Max (Exp a)) where - unlift (Max_ a) = Max a + unlift (Max a) = Max a instance Bounded a => P.Bounded (Exp (Max a)) where - minBound = Max_ minBound - maxBound = Max_ maxBound + minBound = Max minBound + maxBound = Max maxBound instance Num a => P.Num (Exp (Max a)) where (+) = lift2 ((+) :: Max (Exp a) -> Max (Exp a) -> Max (Exp a)) @@ -125,20 +155,20 @@ instance Num a => P.Num (Exp (Max a)) where fromInteger x = lift (P.fromInteger x :: Max (Exp a)) instance Eq a => Eq (Max a) where - (==) = lift2 ((==) `on` getMax) - (/=) = lift2 ((/=) `on` getMax) + (==) = lift2 ((==) `on` P.getMax) + (/=) = lift2 ((/=) `on` P.getMax) instance Ord a => Ord (Max a) where - (<) = lift2 ((<) `on` getMax) - (>) = lift2 ((>) `on` getMax) - (<=) = lift2 ((<=) `on` getMax) - (>=) = lift2 ((>=) `on` getMax) - min x y = Max_ $ lift2 (min `on` getMax) x y - max x y = Max_ $ lift2 (max `on` getMax) x y + (<) = lift2 ((<) `on` P.getMax) + (>) = lift2 ((>) `on` P.getMax) + (<=) = lift2 ((<=) `on` P.getMax) + (>=) = lift2 ((>=) `on` P.getMax) + min x y = Max $ lift2 (min `on` P.getMax) x y + max x y = Max $ lift2 (max `on` P.getMax) x y instance Ord a => Semigroup (Exp (Max a)) where - x <> y = Max_ $ lift2 (max `on` getMax) x y - stimes = stimesIdempotent + x <> y = Max $ lift2 (max `on` P.getMax) x y + stimes = P.stimesIdempotent instance (Ord a, Bounded a) => Monoid (Exp (Max a)) where mempty = minBound