-- | Boxed arrays: an immutable 'Array' and a mutable 'MArray', both thin
-- wrappers over GHC's primitive boxed array types.
--
-- Only the two types are exported by name; the functions defined in this
-- module back the class instances declared in it.
module Foundation.Array.Boxed
( Array
, MArray
) where
import GHC.Prim
import GHC.Types
import GHC.ST
import Foundation.Number
import Foundation.Internal.Base
import Foundation.Internal.Types
import Foundation.Primitive.Types
import Foundation.Primitive.Monad
import Foundation.Array.Common
import qualified Foundation.Collection as C
import qualified Prelude
-- | Immutable array of boxed values, wrapping GHC's primitive 'Array#'.
data Array a = Array (Array# a)
    deriving (Typeable)
-- | Mutable array of boxed values, wrapping GHC's primitive 'MutableArray#'.
--
-- The first parameter is the element type, the second the state token type.
data MArray a st = MArray (MutableArray# st a)
    deriving (Typeable)
-- | 'fmap' delegates to this module's element-wise 'map'.
instance Functor Array where
    fmap f arr = map f arr
-- | Arrays form a monoid under concatenation, with the empty array as identity.
instance Monoid (Array a) where
    mempty          = empty
    mappend lhs rhs = append lhs rhs
    mconcat chunks  = concat chunks
-- | An array is shown exactly like the list of its elements.
instance Show a => Show (Array a) where
    show = show . toList
-- | Equality is decided by 'equal': same length and pairwise-equal elements.
instance Eq a => Eq (Array a) where
    lhs == rhs = equal lhs rhs
-- | Ordering is decided by 'vCompare' (element by element, then by length).
instance Ord a => Ord (Array a) where
    compare lhs rhs = vCompare lhs rhs
-- | The element type of an @Array ty@ is @ty@.
type instance C.Element (Array ty) = ty
-- | List conversions, implemented by 'vFromList' and 'vToList'.
instance IsList (Array ty) where
    type Item (Array ty) = ty
    fromList xs  = vFromList xs
    toList   arr = vToList arr
-- | No methods are defined here, so the class's default definitions apply.
instance C.InnerFunctor (Array ty)
-- | Every method forwards to the same-named top-level function of this
-- module. The name on the left of each @=@ is the class method; the one on
-- the right is the local implementation.
instance C.Sequential (Array ty) where
    -- size queries
    null        = null
    length      = length
    -- slicing from the front
    take        = take
    drop        = drop
    splitAt     = splitAt
    -- slicing from the back
    revTake     = revTake
    revDrop     = revDrop
    revSplitAt  = revSplitAt
    -- predicate-driven splitting and searching
    splitOn     = splitOn
    break       = break
    span        = span
    filter      = filter
    find        = find
    -- restructuring
    intersperse = intersperse
    reverse     = reverse
    sortBy      = sortBy
    -- adding or removing one element at either end
    cons        = cons
    snoc        = snoc
    uncons      = uncons
    unsnoc      = unsnoc
    -- a one-element array, built through the list conversion
    singleton e = fromList [e]
-- | Mutable-collection interface for 'MArray': the frozen form is 'Array',
-- keys are 'Int' indices, and each method forwards to the matching top-level
-- function of this module.
instance C.MutableCollection (MArray ty) where
    type Collection   (MArray ty) = Array ty
    type MutableKey   (MArray ty) = Int
    type MutableValue (MArray ty) = ty

    -- copying conversions
    thaw           = thaw
    freeze         = freeze
    -- non-copying conversions
    unsafeThaw     = unsafeThaw
    unsafeFreeze   = unsafeFreeze
    -- element access without bounds checks
    mutUnsafeWrite = unsafeWrite
    mutUnsafeRead  = unsafeRead
    -- bounds-checked element access
    mutWrite       = write
    mutRead        = read
-- | Indexed access for 'Array'.
instance C.IndexedCollection (Array ty) where
    -- | Safe lookup: 'Nothing' when the index is outside @[0, length)@.
    (!) l n
        | n < 0 || n >= length l = Nothing
        | otherwise              = Just $ index l n

    -- | Index of the first element satisfying the predicate, scanning from
    -- index 0, or 'Nothing' when no element matches.
    findIndex predicate c = loop 0
      where
        !len = length c
        loop i
            | i == len                    = Nothing
            | predicate (unsafeIndex c i) = Just i
            -- Keep scanning. Giving up here would mean only index 0 is ever
            -- examined.
            | otherwise                   = loop (i + 1)
-- | Zipping into an 'Array'.
instance C.Zippable (Array ty) where
    -- | Combine two collections element-wise. The result is allocated with
    -- the length of the shorter input and filled by walking both inputs as
    -- lists.
    zipWith f a b = runST $ do
        out <- new outLen
        fill out 0 f (toList a) (toList b)
        unsafeFreeze out
      where
        !outLen = Size $ min (C.length a) (C.length b)

        -- The combining function is passed as an argument rather than
        -- captured, so this helper has no free local variables and stays
        -- polymorphic in the state thread that 'runST' requires.
        fill dst i combine (x:xs) (y:ys) =
            write dst i (combine x y) >> fill dst (i + 1) combine xs ys
        fill _ _ _ _ _ = return ()
-- | No methods are defined here, so the class's default definitions apply.
instance C.BoxedZippable (Array ty)
-- | Number of elements in a mutable array.
mutableLength :: MArray ty st -> Int
mutableLength (MArray marr) =
    case sizeofMutableArray# marr of
        n -> I# n
-- | Bounds-checked element access.
--
-- Throws @OutOfBound OOB_Index@ when the index is outside @[0, length)@.
index :: Array ty -> Int -> ty
index array n
    | 0 <= n && n < len = unsafeIndex array n
    | otherwise         = throw (OutOfBound OOB_Index n len)
  where
    len = length array
-- | Element access without any bounds check; the caller must supply a valid
-- index.
unsafeIndex :: Array ty -> Int -> ty
unsafeIndex (Array arr) (I# i) =
    case indexArray# arr i of
        (# e #) -> e
-- | Bounds-checked read from a mutable array.
--
-- Throws @OutOfBound OOB_Read@ (via 'primThrow') when the index is outside
-- @[0, mutableLength)@.
read :: PrimMonad prim => MArray ty (PrimState prim) -> Int -> prim ty
read array n
    | 0 <= n && n < len = unsafeRead array n
    | otherwise         = primThrow (OutOfBound OOB_Read n len)
  where
    len = mutableLength array
-- | Read from a mutable array without any bounds check.
unsafeRead :: PrimMonad prim => MArray ty (PrimState prim) -> Int -> prim ty
unsafeRead (MArray marr) (I# idx) = primitive (\st -> readArray# marr idx st)
-- | Bounds-checked write to a mutable array.
--
-- Throws @OutOfBound OOB_Write@ (via 'primThrow') when the index is outside
-- @[0, mutableLength)@.
write :: PrimMonad prim => MArray ty (PrimState prim) -> Int -> ty -> prim ()
write array n val
    | 0 <= n && n < len = unsafeWrite array n val
    | otherwise         = primThrow (OutOfBound OOB_Write n len)
  where
    len = mutableLength array
-- | Write to a mutable array without any bounds check.
unsafeWrite :: PrimMonad prim => MArray ty (PrimState prim) -> Int -> ty -> prim ()
unsafeWrite (MArray marr) (I# idx) val = primitive $ \st ->
    case writeArray# marr idx val st of
        st' -> (# st', () #)
-- | Turn a mutable array into an immutable one without copying.
--
-- The mutable array must not be written to afterwards, since both values
-- share the same storage.
unsafeFreeze :: PrimMonad prim => MArray ty (PrimState prim) -> prim (Array ty)
unsafeFreeze (MArray marr) = primitive freezeIt
  where
    freezeIt st =
        case unsafeFreezeArray# marr st of
            (# st', frozen #) -> (# st', Array frozen #)
-- | Turn an immutable array into a mutable one without copying.
--
-- The result shares storage with the input, so writing through it changes the
-- original 'Array' as well.
--
-- This uses the dedicated 'unsafeThawArray#' primop rather than a bare
-- 'unsafeCoerce#', so the runtime is told the array has become mutable.
unsafeThaw :: PrimMonad prim => Array ty -> prim (MArray ty (PrimState prim))
unsafeThaw (Array a) = primitive $ \s1 ->
    case unsafeThawArray# a s1 of
        (# s2, ma #) -> (# s2, MArray ma #)
-- | Make a mutable copy of an immutable array; the original is left untouched.
thaw :: PrimMonad prim => Array ty -> prim (MArray ty (PrimState prim))
thaw array =
    new sz >>= \fresh ->
        unsafeCopyAtRO fresh (Offset 0) array (Offset 0) sz >> return fresh
  where
    sz = lengthSize array
-- | Make an immutable copy of a mutable array; later writes to the source are
-- not reflected in the result.
freeze :: PrimMonad prim => MArray ty (PrimState prim) -> prim (Array ty)
freeze marray =
    new total >>= \fresh ->
        copyAt fresh (Offset 0) marray (Offset 0) total >> unsafeFreeze fresh
  where
    total = Size $ mutableLength marray
-- | Copy @n@ elements from one mutable array to another, one at a time.
--
-- No bounds are checked on either array.
copyAt :: PrimMonad prim
       => MArray ty (PrimState prim) -- ^ destination array
       -> Offset ty                  -- ^ first destination index
       -> MArray ty (PrimState prim) -- ^ source array
       -> Offset ty                  -- ^ first source index
       -> Size ty                    -- ^ number of elements to copy
       -> prim ()
copyAt dst od src os n = go od os
  where
    -- Source index one past the last element to copy.
    !stop = os `offsetPlusE` n

    go (Offset to) from@(Offset fromI)
        | from == stop = return ()
        | otherwise    = do
            e <- unsafeRead src fromI
            unsafeWrite dst to e
            go (Offset $ to + 1) (Offset $ fromI + 1)
-- | Copy @n@ elements from an immutable array into a mutable one, one at a
-- time.
--
-- No bounds are checked on either array.
unsafeCopyAtRO :: PrimMonad prim
               => MArray ty (PrimState prim) -- ^ destination array
               -> Offset ty                  -- ^ first destination index
               -> Array ty                   -- ^ source array
               -> Offset ty                  -- ^ first source index
               -> Size ty                    -- ^ number of elements to copy
               -> prim ()
unsafeCopyAtRO dst od src os n = go od os
  where
    -- Source index one past the last element to copy.
    !stop = os `offsetPlusE` n

    go (Offset to) from@(Offset fromI)
        | from == stop = return ()
        | otherwise    = do
            unsafeWrite dst to (unsafeIndex src fromI)
            go (Offset $ to + 1) (Offset $ fromI + 1)
-- | Build a new array of a given size by running a callback once for every
-- index of a source array.
--
-- The callback receives the source array, the current source index and the
-- new mutable array, and decides what to write. Nothing checks that those
-- writes stay within the new size.
unsafeCopyFrom :: Array ty  -- ^ source array
               -> Size ty   -- ^ size of the array to allocate
               -> (Array ty -> Offset ty -> MArray ty s -> ST s ())
                            -- ^ called for each source index, in increasing order
               -> ST s (Array ty)
unsafeCopyFrom v' newLen f = do
    fresh  <- new newLen
    filled <- fill (Offset 0) f fresh
    unsafeFreeze filled
  where
    srcLen = lengthSize v'
    stop   = Offset 0 `offsetPlusE` srcLen

    fill i step target
        | i == stop = return target
        | otherwise = step v' i target >> fill (i + Offset 1) step target
-- | Allocate a mutable array of the given size.
--
-- Every slot initially holds an 'error' thunk, so reading a slot before
-- writing to it raises that error when the value is forced.
new :: PrimMonad prim => Size ty -> prim (MArray ty (PrimState prim))
new (Size (I# count)) = primitive allocate
  where
    allocate st =
        case newArray# count (error "vector: internal error uninitialized vector") st of
            (# st', marr #) -> (# st', MArray marr #)
-- | Build an array of @n@ elements where the element at index @i@ is
-- @initializer i@.
create :: Int          -- ^ number of elements
       -> (Int -> ty)  -- ^ produces the element for a given index
       -> Array ty
create n initializer = runST (new (Size n) >>= iter initializer)
  where
    iter :: PrimMonad prim => (Int -> ty) -> MArray ty (PrimState prim) -> prim (Array ty)
    iter gen target = go (Offset 0)
      where
        !stop = Offset 0 `offsetPlusE` Size n

        go cur@(Offset i)
            | cur == stop = unsafeFreeze target
            | otherwise   = do
                unsafeWrite target i (gen i)
                go (Offset $ i + 1)
-- | Two arrays are equal when they have the same length and no position holds
-- differing elements.
equal :: Eq a => Array a -> Array a -> Bool
equal a b
    | len /= length b = False
    | otherwise       = scan 0
  where
    len = length a

    scan !i
        | i == len                           = True
        | unsafeIndex a i /= unsafeIndex b i = False
        | otherwise                          = scan (i + 1)
-- | Compare two arrays element by element. The first differing pair decides
-- the result; if one array is a prefix of the other, the shorter one is the
-- smaller.
vCompare :: Ord a => Array a -> Array a -> Ordering
vCompare a b = walk 0
  where
    !lenA = length a
    !lenB = length b

    walk i
        | i == lenA = if lenA == lenB then EQ else LT
        | i == lenB = GT
        | otherwise =
            case compare (unsafeIndex a i) (unsafeIndex b i) of
                EQ      -> walk (i + 1)
                decided -> decided
-- | The array with no elements.
empty :: Array a
empty = runST (onNewArray 0 (\_ st -> st))
-- | Number of elements in an array, as an 'Int'.
length :: Array a -> Int
length (Array arr) =
    case sizeofArray# arr of
        n -> I# n
-- | Number of elements in an array, as a typed 'Size'.
lengthSize :: Array a -> Size a
lengthSize (Array arr) = Size (I# (sizeofArray# arr))
-- | Build an array from a list. The list is traversed once to measure it and
-- once more to copy its elements.
vFromList :: [a] -> Array a
vFromList l = runST (new total >>= fill 0 l)
  where
    total = Size $ C.length l

    -- No free local variables, so this stays polymorphic in the state thread
    -- that 'runST' requires.
    fill i (e:rest) target = unsafeWrite target i e >> fill (i + 1) rest target
    fill _ []       target = unsafeFreeze target
-- | Convert an array to a list, preserving element order.
--
-- An empty array gives the empty list, because the index range
-- @[0 .. -1]@ is empty.
vToList :: Array a -> [a]
vToList v = fmap (unsafeIndex v) [0 .. (length v - 1)]
-- | Concatenate two arrays into a freshly allocated one.
--
-- The inputs are only read, so they are copied with the read-only helper
-- 'unsafeCopyAtRO' rather than being unsafely thawed first.
append :: Array ty -> Array ty -> Array ty
append a b = runST $ do
    r <- new (la + lb)
    unsafeCopyAtRO r (Offset 0)        a (Offset 0) la
    -- The second array starts right after the elements of the first.
    unsafeCopyAtRO r (sizeAsOffset la) b (Offset 0) lb
    unsafeFreeze r
  where
    la = lengthSize a
    lb = lengthSize b
-- | Concatenate a list of arrays into a freshly allocated one.
--
-- The result is sized as the sum of the input lengths, then each input is
-- copied in turn. Inputs are only read, so they are copied with the read-only
-- helper 'unsafeCopyAtRO' rather than being unsafely thawed first.
concat :: [Array ty] -> Array ty
concat l = runST $ do
    r <- new (Size $ Prelude.sum $ fmap length l)
    loop r (Offset 0) l
    unsafeFreeze r
  where
    -- No free local variables, so this stays polymorphic in the state thread
    -- that 'runST' requires.
    loop _ _ []     = return ()
    loop r i (x:xs) = do
        unsafeCopyAtRO r i x (Offset 0) lx
        loop r (i `offsetPlusE` lx) xs
      where
        lx = lengthSize x
-- | Allocate a primitive mutable array, run a state-passing action on it, and
-- freeze the result without copying.
--
-- Every slot initially holds an 'error' thunk; the action is responsible for
-- writing each slot that will later be read.
onNewArray :: PrimMonad m
           => Int  -- ^ number of elements to allocate
           -> (MutableArray# (PrimState m) a -> State# (PrimState m) -> State# (PrimState m))
                   -- ^ action that fills the array
           -> m (Array a)
onNewArray (I# len) f = primitive $ \st0 ->
    case newArray# len (error "onArray") st0 of
        (# st1, marr #) ->
            case f marr st1 of
                st2 ->
                    case unsafeFreezeArray# marr st2 of
                        (# st3, frozen #) -> (# st3, Array frozen #)
-- | 'True' exactly when the array has no elements.
null :: Array ty -> Bool
null arr = 0 == length arr
-- | Copy of the first @nbElems@ elements. A non-positive count gives the
-- empty array; a count beyond the length gives a copy of the whole array.
take :: Int -> Array ty -> Array ty
take nbElems v
    | nbElems <= 0 = empty
    | otherwise    = runST $ do
        out <- new kept
        unsafeCopyAtRO out (Offset 0) v (Offset 0) kept
        unsafeFreeze out
  where
    kept = Size $ min nbElems (length v)
-- | The array without its first @nbElems@ elements.
--
-- A non-positive count returns the input unchanged (no copy). A count at or
-- beyond the length gives an empty array.
drop :: Int -> Array ty -> Array ty
drop nbElems v
    | nbElems <= 0 = v
    | otherwise    = runST $ do
        muv <- new n
        unsafeCopyAtRO muv (Offset 0) v (Offset offset) n
        unsafeFreeze muv
  where
    -- Clamp so that the remaining count below can never be negative.
    offset = min nbElems (length v)
    n      = Size (length v - offset)
-- | Split into the first @n@ elements and the rest, i.e.
-- @('take' n v, 'drop' n v)@.
splitAt :: Int -> Array ty -> (Array ty, Array ty)
splitAt n v = (front, back)
  where
    front = take n v
    back  = drop n v
-- | The last @nbElems@ elements, obtained by dropping everything before them.
--
-- When @nbElems@ exceeds the length, the drop count is negative and 'drop'
-- returns the whole array.
revTake :: Int -> Array ty -> Array ty
revTake nbElems v = drop (length v - nbElems) v
-- | The array without its last @nbElems@ elements, obtained by taking
-- everything before them.
--
-- When @nbElems@ is at least the length, the take count is non-positive and
-- 'take' returns the empty array.
revDrop :: Int -> Array ty -> Array ty
revDrop nbElems v = take (length v - nbElems) v
-- | Split off the last @n@ elements. The first component is those last @n@
-- elements; the second is everything before them.
revSplitAt :: Int -> Array ty -> (Array ty, Array ty)
revSplitAt n v = (drop idx v, take idx v)
  where
    -- Index at which the trailing @n@ elements begin.
    idx = length v - n
-- | Split an array at every element satisfying the predicate. The matching
-- elements are dropped from the output.
--
-- An empty input gives @[]@. Otherwise a final piece is always emitted, and
-- it is empty when the last element matches.
splitOn :: (ty -> Bool) -> Array ty -> [Array ty]
splitOn predicate vec
    | total == Size 0 = []
    | otherwise       = walk (Offset 0) (Offset 0)
  where
    !total = lengthSize vec
    !stop  = Offset 0 `offsetPlusE` total

    -- @pieceStart@ is where the current piece begins; @cur@ is the element
    -- being examined.
    walk pieceStart cur@(Offset i)
        | cur == stop                   = [runST $ sub vec pieceStart cur]
        | predicate (unsafeIndex vec i) = runST (sub vec pieceStart cur) : walk next next
        | otherwise                     = walk pieceStart next
      where
        next = cur + Offset 1
-- | Copy the half-open range @[startIdx, expectedEndIdx)@ of an array into a
-- new array.
--
-- The end is clamped to the array length. An empty or inverted range gives
-- the empty array.
sub :: PrimMonad prim => Array ty -> Offset ty -> Offset ty -> prim (Array ty)
sub vec startIdx expectedEndIdx
    -- Covers an empty range, a start at or beyond the end of the array, and
    -- an inverted range, which would otherwise give a negative size.
    | startIdx >= endIdx = return empty
    | otherwise          = new sz >>= loop startIdx (Offset 0)
  where
    loop os@(Offset s) od@(Offset d) mv
        | os == endIdx = unsafeFreeze mv
        | otherwise    = unsafeWrite mv d (unsafeIndex vec s) >> loop (os + Offset 1) (od + Offset 1) mv

    -- Number of elements to copy, computed on the underlying 'Int's.
    !sz = let Offset e = endIdx
              Offset s = startIdx
           in Size (e - s)
    !endIdx = min expectedEndIdx (sizeAsOffset $ lengthSize vec)
-- | Split at the first element satisfying the predicate. That element begins
-- the second component.
--
-- When no element matches, the result is the original array paired with the
-- empty array.
break :: (ty -> Bool) -> Array ty -> (Array ty, Array ty)
break predicate v = scan 0
  where
    scan i
        | i == length v               = (v, empty)
        | predicate (unsafeIndex v i) = splitAt i v
        | otherwise                   = scan (i + 1)
-- | Insert a separator between every pair of adjacent elements.
--
-- Arrays with at most one element are returned unchanged. Otherwise the
-- result has @2 * length - 1@ elements.
intersperse :: ty -> Array ty -> Array ty
intersperse sep v
    | len <= Size 1 = v
    | otherwise     = runST $ unsafeCopyFrom v newLen (go lastIdx sep)
  where
    len@(Size n) = lengthSize v
    -- Sizes are computed on the underlying 'Int': n elements plus n - 1
    -- separators.
    newLen  = Size (n + n - 1)
    -- Index of the last source element, which must not get a separator.
    lastIdx = Offset (n - 1)

    -- Called once per source index: writes the element at twice its source
    -- index, followed by a separator unless it is the last element.
    go :: Offset ty -> ty -> Array ty -> Offset ty -> MArray ty s -> ST s ()
    go endI sep' oldV oldI@(Offset oi) newV
        | oldI == endI = unsafeWrite newV dst e
        | otherwise    = do
            unsafeWrite newV dst e
            unsafeWrite newV (dst + 1) sep'
      where
        e            = unsafeIndex oldV oi
        (Offset dst) = oldI + oldI
-- | Split at the first element that does not satisfy the predicate; this is
-- 'break' with the predicate negated.
span :: (ty -> Bool) -> Array ty -> (Array ty, Array ty)
span p v = break (\e -> not (p e)) v
-- | Apply a function to every element, giving a new array of the same length.
map :: (a -> b) -> Array a -> Array b
map f a = create (length a) (f . unsafeIndex a)
-- | Prepend an element, giving a new array one element longer.
cons :: ty -> Array ty -> Array ty
cons e vec
    | oldLen == Size 0 = C.singleton e
    | otherwise        = runST $ do
        out <- new (oldLen + Size 1)
        unsafeWrite out 0 e
        -- The existing elements follow the new head, shifted by one.
        unsafeCopyAtRO out (Offset 1) vec (Offset 0) oldLen
        unsafeFreeze out
  where
    !oldLen = lengthSize vec
-- | Append an element, giving a new array one element longer.
snoc :: Array ty -> ty -> Array ty
snoc vec e
    | oldLen == Size 0 = C.singleton e
    | otherwise        = runST $ do
        out <- new (oldLen + Size 1)
        unsafeCopyAtRO out (Offset 0) vec (Offset 0) oldLen
        -- The old length is also the index of the new last slot.
        unsafeWrite out tailIdx e
        unsafeFreeze out
  where
    !oldLen@(Size tailIdx) = lengthSize vec
-- | Split into the first element and the remaining elements, or 'Nothing'
-- for an empty array.
uncons :: Array ty -> Maybe (ty, Array ty)
uncons vec
    | count == 0 = Nothing
    | otherwise  = Just (first, rest)
  where
    !count = length vec
    first  = unsafeIndex vec 0
    rest   = drop 1 vec
-- | Split into the leading elements and the last element, or 'Nothing' for an
-- empty array.
unsnoc :: Array ty -> Maybe (Array ty, ty)
unsnoc vec
    | len == 0  = Nothing
    | otherwise = Just (take lastIdx vec, unsafeIndex vec lastIdx)
  where
    !len    = length vec
    -- Index of the last element; only used when the array is non-empty.
    lastIdx = len - 1
-- | The first element satisfying the predicate, scanning from index 0, or
-- 'Nothing' when none matches.
find :: (ty -> Bool) -> Array ty -> Maybe ty
find predicate vec = scan 0
  where
    !count = length vec

    scan i
        | i == count     = Nothing
        | predicate cand = Just cand
        | otherwise      = scan (i + 1)
      where
        -- Only forced once the bounds guard above has failed.
        cand = unsafeIndex vec i
-- | Sort an array with a caller-supplied comparison.
--
-- The input is copied into a mutable array ('thaw'), sorted in place with a
-- quicksort that uses the last element of each range as pivot, and then
-- frozen.
sortBy :: (ty -> ty -> Ordering) -> Array ty -> Array ty
sortBy xford vec = runST (thaw vec >>= doSort xford)
  where
    len = length vec

    doSort :: PrimMonad prim => (ty -> ty -> Ordering) -> MArray ty (PrimState prim) -> prim (Array ty)
    doSort ford ma = qsort 0 (len - 1) >> unsafeFreeze ma
      where
        -- Sort the inclusive index range [lo, hi]. Ranges with fewer than
        -- two elements, including the empty array's [0, -1], are already
        -- sorted.
        qsort lo hi
            | lo >= hi  = return ()
            | otherwise = do
                p <- partition lo hi
                qsort lo (p - 1)
                qsort (p + 1) hi

        -- Rearrange [lo, hi] around the pivot at index hi and return the
        -- pivot's final index. Elements not greater than the pivot end up
        -- before that index.
        partition lo hi = do
            pivot <- unsafeRead ma hi
            let loop i j
                    | j == hi   = return i
                    | otherwise = do
                        aj <- unsafeRead ma j
                        i' <- if ford aj pivot == GT
                                then return i
                                else do
                                    -- Swap slots i and j to extend the
                                    -- "not greater than pivot" prefix.
                                    ai <- unsafeRead ma i
                                    unsafeWrite ma j ai
                                    unsafeWrite ma i aj
                                    return $ i + 1
                        loop i' (j + 1)
            i <- loop lo lo
            -- Move the pivot to its final slot.
            ai  <- unsafeRead ma i
            ahi <- unsafeRead ma hi
            unsafeWrite ma hi ai
            unsafeWrite ma i ahi
            return i
-- | Keep only the elements satisfying the predicate, in their original order.
--
-- Matching elements are first written to a scratch array as long as the
-- input. The filled prefix is then copied into a right-sized array by
-- 'freezeUntilIndex'.
filter :: (ty -> Bool) -> Array ty -> Array ty
filter predicate vec = runST (new len >>= copyFilterFreeze predicate (unsafeIndex vec))
  where
    !len = lengthSize vec
    !end = Offset 0 `offsetPlusE` len

    copyFilterFreeze :: PrimMonad prim => (ty -> Bool) -> (Int -> ty) -> MArray ty (PrimState prim) -> prim (Array ty)
    copyFilterFreeze keep fetch scratch = do
        used <- go (Offset 0) (Offset 0)
        freezeUntilIndex scratch used
      where
        -- @dst@ is the next free slot in the scratch array; @src@ is the
        -- source index being examined. Returns the final @dst@.
        go dst@(Offset dstI) src@(Offset srcI)
            | src == end = return dst
            | keep e     = unsafeWrite scratch dstI e >> go (dst + Offset 1) (src + Offset 1)
            | otherwise  = go dst (src + Offset 1)
          where
            e = fetch srcI
-- | Copy the elements before the given offset into a new array of exactly
-- that size and freeze it. The source mutable array is not modified.
freezeUntilIndex :: PrimMonad prim => MArray ty (PrimState prim) -> Offset ty -> prim (Array ty)
freezeUntilIndex mvec d =
    new prefix >>= \out ->
        copyAt out (Offset 0) mvec (Offset 0) prefix >> unsafeFreeze out
  where
    prefix = offsetAsSize d
-- | A new array with the elements in reverse order.
reverse :: Array ty -> Array ty
reverse a = create len toEnd
  where
    len = length a
    -- Result index i takes the element at the mirrored source index.
    toEnd i = unsafeIndex a (len - i - 1)