Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vector indexing and insertion operations #509

Draft
wants to merge 12 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@
/docs/_build
*.hi
*.o

hie.yaml
1 change: 1 addition & 0 deletions accelerate.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ library
Data.Array.Accelerate.Classes.RealFloat
Data.Array.Accelerate.Classes.RealFrac
Data.Array.Accelerate.Classes.ToFloating
Data.Array.Accelerate.Classes.Vector
Data.Array.Accelerate.Debug.Internal.Clock
Data.Array.Accelerate.Debug.Internal.Flags
Data.Array.Accelerate.Debug.Internal.Graph
Expand Down
5 changes: 5 additions & 0 deletions src/Data/Array/Accelerate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -310,13 +310,17 @@ module Data.Array.Accelerate (

-- ** SIMD vectors
Vec, VecElt,
Vectoring(..),
vecOfList,
listOfVec,

-- ** Type classes
-- *** Basic type classes
Eq(..),
Ord(..), Ordering(..), pattern LT_, pattern EQ_, pattern GT_,
Enum, succ, pred,
Bounded, minBound, maxBound,

-- Functor(..), (<$>), ($>), void,
-- Monad(..),

Expand Down Expand Up @@ -445,6 +449,7 @@ import Data.Array.Accelerate.Classes.Rational
import Data.Array.Accelerate.Classes.RealFloat
import Data.Array.Accelerate.Classes.RealFrac
import Data.Array.Accelerate.Classes.ToFloating
import Data.Array.Accelerate.Classes.Vector
import Data.Array.Accelerate.Data.Either
import Data.Array.Accelerate.Data.Maybe
import Data.Array.Accelerate.Language
Expand Down
47 changes: 34 additions & 13 deletions src/Data/Array/Accelerate/AST.hs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ import Data.Array.Accelerate.Sugar.Foreign
import Data.Array.Accelerate.Type
import Data.Primitive.Vec

import Data.Primitive.Types
import Control.DeepSeq
import Data.Kind
import Data.Maybe
Expand Down Expand Up @@ -655,7 +656,6 @@ data PrimConst ty where
-- constant from Floating
PrimPi :: FloatingType a -> PrimConst a


-- |Primitive scalar operations
--
data PrimFun sig where
Expand Down Expand Up @@ -748,6 +748,10 @@ data PrimFun sig where
PrimLOr :: PrimFun ((PrimBool, PrimBool) -> PrimBool)
PrimLNot :: PrimFun (PrimBool -> PrimBool)

-- local array operators
HugoPeters1024 marked this conversation as resolved.
Show resolved Hide resolved
PrimVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, i) -> a)
PrimVectorWrite :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, (i, a)) -> Vec n a)

-- general conversion between types
PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b)
PrimToFloating :: NumType a -> FloatingType b -> PrimFun (a -> b)
Expand Down Expand Up @@ -825,7 +829,7 @@ expType = \case
While _ (Lam lhs _) _ -> lhsToTupR lhs
While{} -> error "What's the matter, you're running in the shadows"
Const tR _ -> TupRsingle tR
PrimConst c -> TupRsingle $ SingleScalarType $ primConstType c
PrimConst c -> TupRsingle $ primConstType c
PrimApp f _ -> snd $ primFunType f
Index (Var repr _) _ -> arrayRtype repr
LinearIndex (Var repr _) _ -> arrayRtype repr
Expand All @@ -834,17 +838,20 @@ expType = \case
Undef tR -> TupRsingle tR
Coerce _ tR _ -> TupRsingle tR

primConstType :: PrimConst a -> SingleType a
primConstType :: PrimConst a -> ScalarType a
primConstType = \case
PrimMinBound t -> bounded t
PrimMaxBound t -> bounded t
PrimPi t -> floating t
where
bounded :: BoundedType a -> SingleType a
bounded (IntegralBoundedType t) = NumSingleType $ IntegralNumType t
bounded :: BoundedType a -> ScalarType a
bounded (IntegralBoundedType t) = SingleScalarType $ NumSingleType $ IntegralNumType t

floating :: FloatingType t -> ScalarType t
floating = SingleScalarType . NumSingleType . FloatingNumType

floating :: FloatingType t -> SingleType t
floating = NumSingleType . FloatingNumType
vector :: forall n a. (KnownNat n) => VectorType (Vec n a) -> ScalarType (Vec n a)
vector = VectorScalarType

primFunType :: PrimFun (a -> b) -> (TypeR a, TypeR b)
primFunType = \case
Expand Down Expand Up @@ -924,6 +931,17 @@ primFunType = \case
PrimLOr -> binary' tbool
PrimLNot -> unary' tbool

-- Local Vector operations
PrimVectorIndex v'@(VectorType _ a) i' ->
let v = singleVector v'
i = integral i'
in (v `TupRpair` i, single a)

PrimVectorWrite v'@(VectorType _ a) i' ->
let v = singleVector v'
i = integral i'
in (v `TupRpair` (i `TupRpair` single a), v)

-- general conversion between types
PrimFromIntegral a b -> unary (integral a) (num b)
PrimToFloating a b -> unary (num a) (floating b)
Expand All @@ -936,6 +954,7 @@ primFunType = \case
compare' a = binary (single a) tbool

single = TupRsingle . SingleScalarType
singleVector = TupRsingle . VectorScalarType
num = TupRsingle . SingleScalarType . NumSingleType
integral = num . IntegralNumType
floating = num . FloatingNumType
Expand Down Expand Up @@ -1100,9 +1119,9 @@ rnfConst (TupRsingle t) !_ = rnfScalarType t -- scalars should have (nf =
rnfConst (TupRpair ta tb) (a,b) = rnfConst ta a `seq` rnfConst tb b

rnfPrimConst :: PrimConst c -> ()
rnfPrimConst (PrimMinBound t) = rnfBoundedType t
rnfPrimConst (PrimMaxBound t) = rnfBoundedType t
rnfPrimConst (PrimPi t) = rnfFloatingType t
rnfPrimConst (PrimMinBound t) = rnfBoundedType t
rnfPrimConst (PrimMaxBound t) = rnfBoundedType t
rnfPrimConst (PrimPi t) = rnfFloatingType t

rnfPrimFun :: PrimFun f -> ()
rnfPrimFun (PrimAdd t) = rnfNumType t
Expand Down Expand Up @@ -1165,6 +1184,7 @@ rnfPrimFun (PrimMin t) = rnfSingleType t
rnfPrimFun PrimLAnd = ()
rnfPrimFun PrimLOr = ()
rnfPrimFun PrimLNot = ()
rnfPrimFun (PrimVectorIndex v i) = rnfVectorType v `seq` rnfIntegralType i
rnfPrimFun (PrimFromIntegral i n) = rnfIntegralType i `seq` rnfNumType n
rnfPrimFun (PrimToFloating n f) = rnfNumType n `seq` rnfFloatingType f

Expand Down Expand Up @@ -1326,9 +1346,9 @@ liftBoundary (ArrayR _ tp) (Constant v) = [|| Constant $$(liftElt tp v) ||]
liftBoundary _ (Function f) = [|| Function $$(liftOpenFun f) ||]

liftPrimConst :: PrimConst c -> CodeQ (PrimConst c)
liftPrimConst (PrimMinBound t) = [|| PrimMinBound $$(liftBoundedType t) ||]
liftPrimConst (PrimMaxBound t) = [|| PrimMaxBound $$(liftBoundedType t) ||]
liftPrimConst (PrimPi t) = [|| PrimPi $$(liftFloatingType t) ||]
liftPrimConst (PrimMinBound t) = [|| PrimMinBound $$(liftBoundedType t) ||]
liftPrimConst (PrimMaxBound t) = [|| PrimMaxBound $$(liftBoundedType t) ||]
liftPrimConst (PrimPi t) = [|| PrimPi $$(liftFloatingType t) ||]

liftPrimFun :: PrimFun f -> CodeQ (PrimFun f)
liftPrimFun (PrimAdd t) = [|| PrimAdd $$(liftNumType t) ||]
Expand Down Expand Up @@ -1391,6 +1411,7 @@ liftPrimFun (PrimMin t) = [|| PrimMin $$(liftSingleType t) ||]
liftPrimFun PrimLAnd = [|| PrimLAnd ||]
liftPrimFun PrimLOr = [|| PrimLOr ||]
liftPrimFun PrimLNot = [|| PrimLNot ||]
liftPrimFun (PrimVectorIndex v i) = [|| PrimVectorIndex $$(liftVectorType v) $$(liftIntegralType i) ||]
liftPrimFun (PrimFromIntegral ta tb) = [|| PrimFromIntegral $$(liftIntegralType ta) $$(liftNumType tb) ||]
liftPrimFun (PrimToFloating ta tb) = [|| PrimToFloating $$(liftNumType ta) $$(liftFloatingType tb) ||]

Expand Down
2 changes: 2 additions & 0 deletions src/Data/Array/Accelerate/Analysis/Hash.hs
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,8 @@ encodePrimFun (PrimEq a) = intHost $(hashQ "PrimEq")
encodePrimFun (PrimNEq a) = intHost $(hashQ "PrimNEq") <> encodeSingleType a
encodePrimFun (PrimMax a) = intHost $(hashQ "PrimMax") <> encodeSingleType a
encodePrimFun (PrimMin a) = intHost $(hashQ "PrimMin") <> encodeSingleType a
encodePrimFun (PrimVectorIndex (VectorType _ a) b) = intHost $(hashQ "PrimVectorIndex") <> encodeSingleType a <> encodeNumType (IntegralNumType b)
encodePrimFun (PrimVectorWrite (VectorType _ a) b) = intHost $(hashQ "PrimVectorWrite") <> encodeSingleType a <> encodeNumType (IntegralNumType b)
encodePrimFun (PrimFromIntegral a b) = intHost $(hashQ "PrimFromIntegral") <> encodeIntegralType a <> encodeNumType b
encodePrimFun (PrimToFloating a b) = intHost $(hashQ "PrimToFloating") <> encodeNumType a <> encodeFloatingType b
encodePrimFun PrimLAnd = intHost $(hashQ "PrimLAnd")
Expand Down
3 changes: 1 addition & 2 deletions src/Data/Array/Accelerate/Classes/Enum.hs
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,7 @@ defaultFromEnum = preludeError "fromEnum"
preludeError :: String -> a
preludeError x
= error
$ unlines [ printf "Prelude.%s is not supported for Accelerate types" x
, ""
$ unlines [ printf "Prelude.%s is not supported for Accelerate types" x , ""
, "These Prelude.Enum instances are present only to fulfil superclass"
, "constraints for subsequent classes in the standard Haskell numeric hierarchy."
]
Expand Down
33 changes: 33 additions & 0 deletions src/Data/Array/Accelerate/Classes/Vector.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GADTs #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- |
-- Module : Data.Array.Accelerate.Classes.Vector
-- Copyright : [2016..2020] The Accelerate Team
-- License : BSD3
--
-- Maintainer : Trevor L. McDonell <[email protected]>
-- Stability : experimental
-- Portability : non-portable (GHC extensions)
--
module Data.Array.Accelerate.Classes.Vector where

import Data.Kind
import GHC.TypeLits
import Data.Array.Accelerate.Sugar.Vec
import Data.Array.Accelerate.Smart
import Data.Primitive.Vec

instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) where
type IndexType (Exp (Vec n a)) = Exp Int
vecIndex = mkVectorIndex
vecWrite = mkVectorWrite
vecEmpty = undef


12 changes: 12 additions & 0 deletions src/Data/Array/Accelerate/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ import qualified Data.Array.Accelerate.Sugar.Array as Sugar
import qualified Data.Array.Accelerate.Sugar.Elt as Sugar
import qualified Data.Array.Accelerate.Trafo.Delayed as AST

import GHC.TypeLits
import Control.DeepSeq
import Control.Exception
import Control.Monad
Expand Down Expand Up @@ -1144,6 +1145,8 @@ evalPrim (PrimMin ty) = evalMin ty
evalPrim PrimLAnd = evalLAnd
evalPrim PrimLOr = evalLOr
evalPrim PrimLNot = evalLNot
evalPrim (PrimVectorIndex v i) = evalVectorIndex v i
evalPrim (PrimVectorWrite v i) = evalVectorWrite v i
evalPrim (PrimFromIntegral ta tb) = evalFromIntegral ta tb
evalPrim (PrimToFloating ta tb) = evalToFloating ta tb

Expand All @@ -1168,6 +1171,12 @@ evalLOr (x, y) = fromBool (toBool x || toBool y)
evalLNot :: PrimBool -> PrimBool
evalLNot = fromBool . not . toBool

evalVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> (Vec n a, i) -> a
evalVectorIndex (VectorType n _) ti (v, i) | IntegralDict <- integralDict ti = vecIndex v (fromIntegral i)

evalVectorWrite :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> (Vec n a, (i, a)) -> Vec n a
evalVectorWrite (VectorType n _) ti (v, (i, a)) | IntegralDict <- integralDict ti = vecWrite v (fromIntegral i) a

evalFromIntegral :: IntegralType a -> NumType b -> a -> b
evalFromIntegral ta (IntegralNumType tb)
| IntegralDict <- integralDict ta
Expand Down Expand Up @@ -1213,6 +1222,9 @@ evalMaxBound (IntegralBoundedType ty)
evalPi :: FloatingType a -> a
evalPi ty | FloatingDict <- floatingDict ty = pi

evalVectorCreate :: (KnownNat n, Prim a) => VectorType (Vec n a) -> Vec n a
evalVectorCreate (VectorType n _) = vecEmpty

evalSin :: FloatingType a -> (a -> a)
evalSin ty | FloatingDict <- floatingDict ty = sin

Expand Down
25 changes: 23 additions & 2 deletions src/Data/Array/Accelerate/Smart.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
Expand All @@ -12,6 +12,7 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE PolyKinds #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module : Data.Array.Accelerate.Smart
Expand Down Expand Up @@ -71,6 +72,10 @@ module Data.Array.Accelerate.Smart (
-- ** Smart constructors for type coercion functions
mkFromIntegral, mkToFloating, mkBitcast, mkCoerce, Coerce(..),

-- ** Smart constructors for vector operations
mkVectorIndex,
mkVectorWrite,

-- ** Auxiliary functions
($$), ($$$), ($$$$), ($$$$$),
ApplyAcc(..),
Expand All @@ -83,6 +88,7 @@ module Data.Array.Accelerate.Smart (
) where


import Data.Proxy
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Array
Expand All @@ -95,6 +101,7 @@ import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Representation.Vec
import Data.Array.Accelerate.Sugar.Array ( Arrays )
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Vec
import Data.Array.Accelerate.Sugar.Foreign
import Data.Array.Accelerate.Sugar.Shape ( (:.)(..) )
import Data.Array.Accelerate.Type
Expand Down Expand Up @@ -859,7 +866,7 @@ instance HasTypeR exp => HasTypeR (PreSmartExp acc exp) where
Case{} -> internalError "encountered empty case"
Cond _ e _ -> typeR e
While t _ _ _ -> t
PrimConst c -> TupRsingle $ SingleScalarType $ primConstType c
PrimConst c -> TupRsingle $ primConstType c
PrimApp f _ -> snd $ primFunType f
Index tp _ _ -> tp
LinearIndex tp _ _ -> tp
Expand Down Expand Up @@ -1172,6 +1179,17 @@ mkLNot (Exp a) = mkExp $ SmartExp (PrimApp PrimLNot x) `Pair` SmartExp Nil
where
x = SmartExp $ Prj PairIdxLeft a

-- Operators from Vec
mkVectorIndex :: forall n a. (KnownNat n, Elt a, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a
mkVectorIndex = let n :: Int
n = fromIntegral $ natVal $ Proxy @n
in mkPrimBinary $ PrimVectorIndex @n (VectorType n singleType) integralType

mkVectorWrite :: forall n a. (KnownNat n, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a -> Exp (Vec n a)
mkVectorWrite = let n :: Int
n = fromIntegral $ natVal $ Proxy @n
in mkPrimTernary $ PrimVectorWrite @n (VectorType n singleType) integralType
HugoPeters1024 marked this conversation as resolved.
Show resolved Hide resolved

-- Numeric conversions

mkFromIntegral :: (Elt a, Elt b, IsIntegral (EltR a), IsNum (EltR b)) => Exp a -> Exp b
Expand Down Expand Up @@ -1259,6 +1277,9 @@ mkPrimUnary prim (Exp a) = mkExp $ PrimApp prim a
mkPrimBinary :: (Elt a, Elt b, Elt c) => PrimFun ((EltR a, EltR b) -> EltR c) -> Exp a -> Exp b -> Exp c
mkPrimBinary prim (Exp a) (Exp b) = mkExp $ PrimApp prim (SmartExp $ Pair a b)

mkPrimTernary :: (Elt a, Elt b, Elt c, Elt d) => PrimFun ((EltR a, (EltR b, EltR c)) -> EltR d) -> Exp a -> Exp b -> Exp c -> Exp d
mkPrimTernary prim (Exp a) (Exp b) (Exp c) = mkExp $ PrimApp prim (SmartExp $ Pair a (SmartExp (Pair b c)))

mkPrimUnaryBool :: Elt a => PrimFun (EltR a -> PrimBool) -> Exp a -> Exp Bool
mkPrimUnaryBool = mkCoerce @PrimBool $$ mkPrimUnary

Expand Down
4 changes: 4 additions & 0 deletions src/Data/Array/Accelerate/Trafo/Algebra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Pretty.Print ( primOperator, isInfix, opName )
import Data.Array.Accelerate.Trafo.Environment
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Classes.Vector

import qualified Data.Array.Accelerate.Debug.Internal.Stats as Stats

import Data.Bits
import Data.Monoid
import Data.Text ( Text )
import Data.Primitive.Vec
import Data.Text.Prettyprint.Doc
import Data.Text.Prettyprint.Doc.Render.Text
import GHC.Float ( float2Double, double2Float )
Expand Down Expand Up @@ -142,6 +144,8 @@ evalPrimApp env f x
PrimNEq ty -> evalNEq ty x env
PrimMax ty -> evalMax ty x env
PrimMin ty -> evalMin ty x env
PrimVectorIndex _ _ -> Nothing
PrimVectorWrite _ _ -> Nothing
PrimLAnd -> evalLAnd x env
PrimLOr -> evalLOr x env
PrimLNot -> evalLNot x env
Expand Down
Loading