{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}
module Data.Massiv.Array.Mutable
( Mutable
, MArray
, msize
, new
, thaw
, freeze
, read
, read'
, write
, write'
, modify
, modify'
, swap
, swap'
, RealWorld
, computeInto
, generateM
, generateLinearM
, mapM
, imapM
, forM
, iforM
, sequenceM
) where
import Prelude hiding (mapM, read)
import Control.Monad (unless)
import Control.Monad.Primitive (PrimMonad (..))
import Data.Massiv.Array.Manifest.Internal
import Data.Massiv.Array.Unsafe
import Data.Massiv.Core.Common
import GHC.Base (Int (..))
import GHC.Prim
-- | Allocate a new mutable array of size @sz@.
--
-- Negative components of the size are clamped to @0@ before allocation; the
-- initial contents are whatever 'unsafeNewZero' provides (presumably
-- zero/default-filled — confirm per representation).
new :: (Mutable r ix e, PrimMonad m) => ix -> m (MArray (PrimState m) r ix e)
new = unsafeNewZero . liftIndex (max 0)
{-# INLINE new #-}
-- | Turn an immutable array into a mutable one.
--
-- The array is 'clone'd first and only the copy is thawed, so writes to the
-- mutable result are not meant to be visible through @arr@.
thaw :: (Mutable r ix e, PrimMonad m) => Array r ix e -> m (MArray (PrimState m) r ix e)
thaw arr = unsafeThaw (clone arr)
{-# INLINE thaw #-}
-- | Produce an immutable copy of a mutable array, using computation strategy
-- @comp@ for the result.
--
-- The copy is forced before this action returns. With a plain
-- @return . clone@ the copy is a lazy thunk that is only evaluated when the
-- result is first demanded. That may happen after further writes to @marr@,
-- which the copy would then observe through the unsafely frozen view of
-- @marr@.
freeze :: (Mutable r ix e, PrimMonad m) => Comp -> MArray (PrimState m) r ix e -> m (Array r ix e)
freeze comp marr = do
  aliased <- unsafeFreeze comp marr
  -- NOTE(review): assumes forcing the cloned array to WHNF completes the copy
  -- (i.e. 'clone' does not leave the element buffer as a lazy thunk) — confirm
  -- for every 'Mutable' representation.
  return $! clone aliased
{-# INLINE freeze #-}
-- | Read the element at index @ix@.
--
-- Returns 'Nothing' when @ix@ is outside the bounds given by 'msize'.
read :: (Mutable r ix e, PrimMonad m) =>
        MArray (PrimState m) r ix e -> ix -> m (Maybe e)
read marr ix
  | isSafeIndex (msize marr) ix = fmap Just (unsafeRead marr ix)
  | otherwise = return Nothing
{-# INLINE read #-}
-- | Like 'read', but raises an index error (via 'errorIx') instead of
-- returning 'Nothing' when @ix@ is out of bounds.
read' :: (Mutable r ix e, PrimMonad m) =>
         MArray (PrimState m) r ix e -> ix -> m e
read' marr ix =
  read marr ix >>= maybe outOfBounds return
  where
    outOfBounds = errorIx "Data.Massiv.Array.Mutable.read'" (msize marr) ix
{-# INLINE read' #-}
-- | Write @e@ at index @ix@.
--
-- Returns 'True' if the write happened, or 'False' (without touching the
-- array) when @ix@ is outside the bounds given by 'msize'.
write :: (Mutable r ix e, PrimMonad m) =>
         MArray (PrimState m) r ix e -> ix -> e -> m Bool
write marr ix e
  | isSafeIndex (msize marr) ix = do
      unsafeWrite marr ix e
      return True
  | otherwise = return False
{-# INLINE write #-}
-- | Like 'write', but raises an index error (via 'errorIx') instead of
-- returning 'False' when @ix@ is out of bounds.
write' :: (Mutable r ix e, PrimMonad m) =>
          MArray (PrimState m) r ix e -> ix -> e -> m ()
write' marr ix e = do
  written <- write marr ix e
  unless written $
    errorIx "Data.Massiv.Array.Mutable.write'" (msize marr) ix
{-# INLINE write' #-}
-- | Replace the element at @ix@ with @f@ applied to its current value.
--
-- Returns 'True' on success, or 'False' (without touching the array) when
-- @ix@ is outside the bounds given by 'msize'. @f old@ is not forced here;
-- whether it is evaluated on write is up to the representation's
-- 'unsafeWrite'.
modify :: (Mutable r ix e, PrimMonad m) =>
          MArray (PrimState m) r ix e -> (e -> e) -> ix -> m Bool
modify marr f ix
  | isSafeIndex (msize marr) ix = do
      old <- unsafeRead marr ix
      unsafeWrite marr ix (f old)
      return True
  | otherwise = return False
{-# INLINE modify #-}
-- | Like 'modify', but raises an index error (via 'errorIx') instead of
-- returning 'False' when @ix@ is out of bounds.
modify' :: (Mutable r ix e, PrimMonad m) =>
           MArray (PrimState m) r ix e -> (e -> e) -> ix -> m ()
modify' marr f ix = do
  modified <- modify marr f ix
  unless modified $
    errorIx "Data.Massiv.Array.Mutable.modify'" (msize marr) ix
{-# INLINE modify' #-}
-- | Exchange the elements stored at @ix1@ and @ix2@.
--
-- Returns 'True' on success, or 'False' (leaving the array untouched) when
-- either index is outside the bounds given by 'msize'.
swap :: (Mutable r ix e, PrimMonad m) =>
        MArray (PrimState m) r ix e -> ix -> ix -> m Bool
swap marr ix1 ix2
  | isSafeIndex sz ix1, isSafeIndex sz ix2 = do
      x1 <- unsafeRead marr ix1
      x2 <- unsafeRead marr ix2
      unsafeWrite marr ix1 x2
      unsafeWrite marr ix2 x1
      return True
  | otherwise = return False
  where
    sz = msize marr
{-# INLINE swap #-}
-- | Like 'swap', but raises an index error (via 'errorIx') instead of
-- returning 'False'.
--
-- The error reports @ix1@ if it is out of bounds, otherwise @ix2@.
swap' :: (Mutable r ix e, PrimMonad m) =>
         MArray (PrimState m) r ix e -> ix -> ix -> m ()
swap' marr ix1 ix2 = do
  swapped <- swap marr ix1 ix2
  unless swapped $ do
    let sz = msize marr
        offending
          | isSafeIndex sz ix1 = ix2
          | otherwise = ix1
    errorIx "Data.Massiv.Array.Mutable.swap'" sz offending
{-# INLINE swap' #-}
-- | Fill every linear position @0 .. totalElem (msize ma) - 1@ of @ma@, in
-- increasing order, with the result of running @f@ on that position.
--
-- The primitive state token is threaded explicitly through
-- 'unsafeLinearWriteA'; the final token is returned so the caller can
-- continue the chain. No bounds checks are performed.
unsafeLinearFillM :: (Mutable r ix e, Monad m) =>
                     MArray RealWorld r ix e -> (Int -> m e) -> WorldState -> m WorldState
unsafeLinearFillM ma f (State s_#) = go 0# s_#
  where
    -- Number of elements to write, unboxed for the primop loop below.
    !(I# k#) = totalElem (msize ma)
    go i# s# =
      -- (<#) yields an unboxed Int#: 0# means "i# is not below k#", i.e. done.
      case i# <# k# of
        0# -> return (State s#)
        _ -> do
          let i = I# i#
          res <- f i
          -- Pattern-match on the returned token so each write is sequenced
          -- before the next iteration uses the new state.
          State s'# <- unsafeLinearWriteA ma i res (State s#)
          go (i# +# 1#) s'#
{-# INLINE unsafeLinearFillM #-}
-- | Build an array of size @sz@ by running @f@ on each linear index, from @0@
-- upwards, and storing the results. The fill is done by 'unsafeLinearFillM'.
-- Negative size components are clamped to @0@. The filled buffer is then
-- frozen with 'unsafeFreezeA' using computation strategy @comp@.
--
-- NOTE(review): the buffer's state token comes from
-- @noDuplicate# realWorld#@ rather than from @m@ (only 'Monad' is required).
-- This behaves like 'unsafePerformIO' inside an arbitrary monad. In a monad
-- that runs its continuation more than once (e.g. the list monad), every
-- branch appears to write to and freeze the same buffer. Because the
-- allocation expression depends only on @sz@, GHC may also be able to share
-- it between calls after inlining. Consider requiring 'PrimMonad' — TODO
-- confirm whether callers rely on non-'PrimMonad' monads.
generateLinearM :: (Monad m, Mutable r ix e) => Comp -> ix -> (Int -> m e) -> m (Array r ix e)
generateLinearM comp sz f = do
  (s, mba) <- unsafeNewA (liftIndex (max 0) sz) (State (noDuplicate# realWorld#))
  s' <- unsafeLinearFillM mba f s
  (_, ba) <- unsafeFreezeA comp mba s'
  return ba
{-# INLINE generateLinearM #-}
-- | Like 'generateLinearM', but @f@ receives the multi-dimensional index.
--
-- Each linear position is converted to an index of shape @sz@ with
-- 'fromLinearIndex'.
generateM :: (Monad m, Mutable r ix e) => Comp -> ix -> (ix -> m e) -> m (Array r ix e)
generateM comp sz f = generateLinearM comp sz atLinear
  where
    atLinear i = f (fromLinearIndex sz i)
{-# INLINE generateM #-}
-- | Map a monadic function over an array, also passing each element's index.
--
-- The first argument is ignored; it only fixes the result representation
-- @r'@. Results are produced through 'generateLinearM' in linear-index
-- order. The new array keeps the source array's 'getComp' strategy.
imapM
  :: (Monad m, Source r ix e, Mutable r' ix e') =>
     r' -> (ix -> e -> m e') -> Array r ix e -> m (Array r' ix e')
imapM _ f arr = generateLinearM (getComp arr) sz visit
  where
    !sz = size arr
    visit !i = f (fromLinearIndex sz i) (unsafeLinearIndex arr i)
{-# INLINE imapM #-}
-- | Map a monadic function over an array's elements into a new array of
-- representation @r'@. This is 'imapM' with the index ignored.
mapM
  :: (Monad m, Source r ix e, Mutable r' ix e') =>
     r' -> (e -> m e') -> Array r ix e -> m (Array r' ix e')
mapM r f = imapM r (\_ -> f)
{-# INLINE mapM #-}
-- | 'mapM' with the array and function arguments swapped.
forM ::
     (Monad m, Source r ix e, Mutable r' ix e')
  => r'
  -> Array r ix e
  -> (e -> m e')
  -> m (Array r' ix e')
forM r arr f = mapM r f arr
{-# INLINE forM #-}
-- | 'imapM' with the array and function arguments swapped.
iforM :: (Monad m, Source r ix e, Mutable r' ix e') =>
         r' -> Array r ix e -> (ix -> e -> m e') -> m (Array r' ix e')
iforM r arr f = imapM r f arr
{-# INLINE iforM #-}
-- | Run every monadic action stored in the array, in linear-index order, and
-- collect the results into a new array of representation @r'@.
sequenceM
  :: (Monad m, Source r ix (m e), Mutable r' ix e) =>
     r' -> Array r ix (m e) -> m (Array r' ix e)
sequenceM r = imapM r (\_ action -> action)
{-# INLINE sequenceM #-}