-
Notifications
You must be signed in to change notification settings - Fork 70
/
Copy pathFunction.hs
154 lines (136 loc) · 4.21 KB
/
Function.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
{-# LANGUAGE BangPatterns, CPP, FlexibleContexts, Rank2Types #-}
#if __GLASGOW_HASKELL__ >= 704
{-# OPTIONS_GHC -fsimpl-tick-factor=200 #-}
#endif
-- |
-- Module : Statistics.Function
-- Copyright : (c) 2009, 2010, 2011 Bryan O'Sullivan
-- License : BSD3
--
-- Maintainer : [email protected]
-- Stability : experimental
-- Portability : portable
--
-- Useful functions.
module Statistics.Function
(
-- * Scanning
minMax
-- * Sorting
, sort
, gsort
, sortBy
, partialSort
-- * Indexing
, indexed
, indices
-- * Bit twiddling
, nextHighestPowerOfTwo
-- * Comparison
, within
-- * Arithmetic
, square
-- * Vectors
, unsafeModify
-- * Combinators
, for
, rfor
, for_
) where
#include "MachDeps.h"
import Control.Monad.ST (ST)
import Data.Bits ((.|.), shiftR)
import qualified Data.Vector.Algorithms.Intro as I
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M
import Numeric.MathFunctions.Comparison (within)
-- | Sort an unboxed vector of 'Double's according to 'compare'.
--
-- This is the monomorphic instance of 'gsort'. It is marked @NOINLINE@,
-- so call sites share one compiled copy instead of inlining the sort.
-- The input vector is not mutated; a sorted copy is returned.
sort :: U.Vector Double -> U.Vector Double
sort = gsort
{-# NOINLINE sort #-}
-- | Sort any generic vector of ordered elements according to 'compare'.
--
-- Sorting is done by 'I.sortBy' from "Data.Vector.Algorithms.Intro"
-- under 'G.modify', so the caller's vector is left unchanged.
gsort :: (Ord e, G.Vector v e) => v e -> v e
gsort = G.modify (I.sortBy compare)
{-# INLINE gsort #-}
-- | Sort a vector using a caller-supplied ordering.
--
-- The comparison function is handed directly to 'I.sortBy'. The sort runs
-- under 'G.modify', so a new vector is returned and the input is not
-- mutated.
sortBy :: (G.Vector v e) => I.Comparison e -> v e -> v e
sortBy cmp = G.modify (I.sortBy cmp)
{-# INLINE sortBy #-}
-- | Partially sort a vector so that its least /k/ elements are moved to
-- the front, using 'I.partialSort'.
--
-- The work happens under 'G.modify', so the input vector is not mutated.
partialSort :: (G.Vector v e, Ord e)
            => Int  -- ^ The number /k/ of least elements.
            -> v e
            -> v e
partialSort k = G.modify (\mv -> I.partialSort mv k)
{-# SPECIALIZE partialSort :: Int -> U.Vector Double -> U.Vector Double #-}
-- | Return the indices of a vector: @0, 1, ..., length - 1@, as a vector
-- of the same container type. An empty input yields an empty result.
indices :: (G.Vector v a, G.Vector v Int) => v a -> v Int
indices a = G.generate (G.length a) id
{-# INLINE indices #-}
-- | Pair each element of a vector with its zero-based index.
indexed :: (G.Vector v e, G.Vector v Int, G.Vector v (Int,e)) => v e -> v (Int,e)
indexed = G.imap (,)
{-# INLINE indexed #-}
-- Running (minimum, maximum) accumulator for 'minMax'. Both fields are
-- strict and unpacked, so the fold does not build up thunks.
data MM = MM {-# UNPACK #-} !Double {-# UNPACK #-} !Double

-- | Compute the minimum and maximum of a vector in a single pass.
--
-- The fold starts from @(Infinity, -Infinity)@, so an empty vector
-- yields exactly that pair. NaN values get no special treatment; they are
-- handled however 'min' and 'max' on 'Double' handle them.
minMax :: (G.Vector v Double) => v Double -> (Double, Double)
minMax xs =
    case G.foldl' step (MM posInf (negate posInf)) xs of
      MM lo hi -> (lo, hi)
  where
    posInf = 1 / 0
    step (MM lo hi) x = MM (min lo x) (max hi x)
{-# INLINE minMax #-}
-- | Efficiently compute the next highest power of two for a
-- non-negative integer. If the given value is already a power of
-- two, it is returned unchanged. If the given value is zero or
-- negative (including 'minBound'), zero is returned.
--
-- The bit-smearing below covers 63 bits, so this assumes 'Int' is at
-- most 64 bits wide. Inputs whose next power of two does not fit in an
-- 'Int' (greater than @2^(w-2)@ for a @w@-bit 'Int') wrap around.
nextHighestPowerOfTwo :: Int -> Int
nextHighestPowerOfTwo n
  -- Explicit guard: for n = minBound, n - 1 wraps to maxBound and the
  -- smearing would return minBound instead of the documented 0.
  | n <= 0    = 0
  | otherwise = 1 + i32
  where
    -- Smear the highest set bit of (n - 1) into every lower bit, so that
    -- adding one gives the next power of two.
    i0  = n - 1
    i1  = i0 .|. i0 `shiftR` 1
    i2  = i1 .|. i1 `shiftR` 2
    i4  = i2 .|. i2 `shiftR` 4
    i8  = i4 .|. i4 `shiftR` 8
    i16 = i8 .|. i8 `shiftR` 16
    -- The final 32-bit step is done as two 16-bit shifts, so every shift
    -- amount stays below the width of a 32-bit Int. On a 64-bit Int this
    -- equals `shiftR` 32. On a 32-bit Int i16 is non-negative here, so the
    -- step contributes 0. No CPP word-size switch is needed.
    i32 = i16 .|. ((i16 `shiftR` 16) `shiftR` 16)
-- It could be implemented as
--
-- > nextHighestPowerOfTwo n = 1 + foldl' go (n-1) [1, 2, 4, 8, 16, 32]
-- >   where go m i = m .|. m `shiftR` i
--
-- However, GHC does not inline foldl' (probably because it is recursive),
-- so the resulting function would walk a list of boxed Ints. The
-- hand-rolled version uses unboxed arithmetic instead.
-- | Square a number, i.e. multiply it by itself: @square x == x * x@.
square :: Double -> Double
square = \x -> x * x
-- | Simple for loop. Counts from /start/ up to /end/-1, running the action
-- on each index in ascending order.
--
-- If /start/ is greater than or equal to /end/, the action is never run.
for :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
for n0 !n f = loop n0
  where
    -- Stop on '>=' rather than '==': when n0 > n this exits immediately,
    -- instead of wrapping through the entire Int range.
    loop i
      | i >= n    = return ()
      | otherwise = f i >> loop (i + 1)
{-# INLINE for #-}
-- | Simple reverse-for loop. Counts from /start/-1 down to /end/, running
-- the action on each index in descending order.
--
-- If /start/ is less than or equal to /end/, the action is never run.
rfor :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
rfor n0 !n f = loop n0
  where
    -- Stop on '<=' rather than '==': when n0 < n this exits immediately,
    -- instead of wrapping through the entire Int range.
    loop i
      | i <= n    = return ()
      | otherwise = let i' = i - 1 in f i' >> loop i'
{-# INLINE rfor #-}
-- | Loop over the half-open range between /start/ and /end/, moving from
-- /start/ towards /end/.
--
-- When /start/ is greater than /end/, this delegates to 'rfor' (indices
-- /start/-1 down to /end/). Otherwise it delegates to 'for' (indices
-- /start/ up to /end/-1).
for_ :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
for_ start !end body =
  if start > end
    then rfor start end body
    else for start end body
{-# INLINE for_ #-}
-- | Replace the element at index @i@ of a mutable unboxed 'Double' vector
-- with @f@ applied to its current value.
--
-- This uses 'M.unsafeRead' and 'M.unsafeWrite', so no bounds check is
-- performed. The caller must ensure @i@ is a valid index.
unsafeModify :: M.MVector s Double -> Int -> (Double -> Double) -> ST s ()
unsafeModify v i f = M.unsafeRead v i >>= M.unsafeWrite v i . f
{-# INLINE unsafeModify #-}