Skip to content

Commit

Permalink
Optimize Data.Array.Mutable.Unlifted.Linear.map (#334)
Browse files Browse the repository at this point in the history
* Split map and toList benchmarks

* Optimize Data.Array.Mutable.Unlifted.Linear.map

Co-authored-by: Arnaud Spiwack <[email protected]>
  • Loading branch information
utdemir and aspiwack authored Jul 8, 2021
1 parent b4531f0 commit ead3a75
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 39 deletions.
36 changes: 29 additions & 7 deletions bench/Data/Mutable/Array.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,23 @@ module Data.Mutable.Array (benchmarks) where
import Gauge
import Data.Function ((&))
import qualified Data.Unrestricted.Linear as Linear
import Data.List (foldl')
import qualified Prelude.Linear as Linear
import Control.DeepSeq (rnf)

import qualified Data.Array.Mutable.Linear as Array.Linear
import qualified Data.Vector

dontFuse :: a -> a
dontFuse a = a
{-# NOINLINE dontFuse #-}

arr_size :: Int
arr_size = 10_000_000

benchmarks :: Benchmark
benchmarks = bgroup "arrays"
[ runImpls "map" bMap arr_size
[ runImpls "toList" bToList arr_size
, runImpls "map" bMap arr_size
, runImpls "reads" bReads arr_size
]

Expand All @@ -47,24 +51,42 @@ runImpls name impls size =

--------------------------------------------------------------------------------

bToList :: Impls
bToList = Impls linear dataVector
where
linear :: Array.Linear.Array Int %1-> ()
linear hm =
hm
Linear.& Array.Linear.toList
Linear.& Linear.lift rnf
Linear.& Linear.unur

dataVector :: Data.Vector.Vector Int -> ()
dataVector hm =
hm
& Data.Vector.toList
& rnf
{-# NOINLINE bToList #-}

bMap :: Impls
bMap = Impls linear dataVector
where
linear :: Array.Linear.Array Int %1-> ()
linear hm =
hm
Linear.& Array.Linear.map (+1)
Linear.& Array.Linear.toList
Linear.& Linear.lift (foldl' (+) 0)
Linear.& Linear.unur
Linear.& Array.Linear.unsafeGet 5
Linear.& (`Linear.lseq` ())

dataVector :: Data.Vector.Vector Int -> ()
dataVector hm =
hm
& Data.Vector.map (+1)
& Data.Vector.toList
& foldl' (+) 0
& dontFuse -- This looks like cheating, I know. But we're trying to measure
-- the speed of `map`, and without this, `vector` fuses the `map`
-- with the subsequent `index` to skip writing to the rest of the
-- vector.
& (`Data.Vector.unsafeIndex` 5)
& (`seq` ())
{-# NOINLINE bMap #-}

Expand Down
53 changes: 21 additions & 32 deletions src/Data/Array/Mutable/Unlifted/Linear.hs
Original file line number Diff line number Diff line change
Expand Up @@ -129,21 +129,27 @@ copyInto start@(GHC.I# start#) = Unsafe.toLinear2 go
{-# NOINLINE copyInto #-} -- prevents the runRW# effect from being reordered

map :: (a -> b) -> Array# a %1-> Array# b
map (f :: a -> b) arr =
size arr
`chain2` \(# Ur s, arr' #) -> go 0 s arr'
where
-- When we're mapping an array, we first insert `b`'s
-- inside an `Array# a` by unsafeCoerce'ing, and then we
-- unsafeCoerce the result to an `Array# b`.
go :: Int -> Int -> Array# a %1-> Array# b
go i s arr'
| i Prelude.== s =
Unsafe.toLinear GHC.unsafeCoerce# arr'
| Prelude.otherwise =
get i arr'
`chain2` \(# Ur a, arr'' #) -> set i (Unsafe.coerce (f a)) arr''
`chain` \arr''' -> go (i Prelude.+ 1) s arr'''
map (f :: a -> b) = Unsafe.toLinear (\(Array# as) ->
let -- We alias the input array to write the resulting -- 'b's to,
-- just to make the typechecker happy. Care must be taken to
-- only read indices from 'as' that is not yet written to 'bs'.
bs :: GHC.MutableArray# GHC.RealWorld b
bs = GHC.unsafeCoerce# as
len :: GHC.Int#
len = GHC.sizeofMutableArray# as

-- For each index ([0..len]), we read the element on 'as', pass
-- it through 'f' and write to the same location on 'bs'.
go :: GHC.Int# -> GHC.State# GHC.RealWorld -> ()
go i st
| GHC.I# i Prelude.== GHC.I# len = ()
| Prelude.otherwise =
case GHC.readArray# as i st of
(# st', a #) ->
case GHC.writeArray# bs i (f a) st' of
!st'' -> go (i GHC.+# 1#) st''
in GHC.runRW# (go 0#) `GHC.seq` Array# bs
)
{-# NOINLINE map #-}

-- | Return the array elements as a lazy list.
Expand Down Expand Up @@ -178,20 +184,3 @@ dup2 = Unsafe.toLinear go
(GHC.cloneMutableArray# arr 0# (GHC.sizeofMutableArray# arr)) of
(# _, new #) -> (# Array# arr, Array# new #)
{-# NOINLINE dup2 #-}

-- * Internal library

-- Below two are variants of (&) specialized for taking commonly used
-- unlifted values and returning a levity-polymorphic result.
--
-- They are not polymorphic on their first parameter since levity-polymorphism
-- disallows binding to levity-polymorphic values.

chain :: forall (r :: GHC.RuntimeRep) a (b :: GHC.TYPE r).
Array# a %1-> (Array# a %1-> b) %1-> b
chain a f = f a

chain2 :: forall (r :: GHC.RuntimeRep) a b (c :: GHC.TYPE r).
(# b, Array# a #) %1-> ((# b, Array# a #) %1-> c) %1-> c
chain2 a f = f a
infixl 1 `chain`, `chain2`

0 comments on commit ead3a75

Please sign in to comment.