module Control.Monad.ST.Trans(
STT,
runST,
STRef,
newSTRef,
readSTRef,
writeSTRef,
STArray,
newSTArray,
readSTArray,
writeSTArray,
boundsSTArray,
numElementsSTArray,
freezeSTArray,
thawSTArray,
runSTArray
)where
import Control.Applicative (Applicative(..))
import GHC.Base
import GHC.Arr (Ix(..), safeRangeSize, safeIndex,
                Array(..), arrEleBottom)

import Control.Monad.Trans
import Control.Monad.Error.Class
import Control.Monad.Reader.Class
import Control.Monad.State.Class
import Control.Monad.Writer.Class
-- | A state-thread monad transformer over an underlying monad @m@.
--
-- A computation of type @'STT' s m a@ is a function that takes the current
-- state token and yields, inside @m@, the successor state token paired with
-- a result of type @a@. The @s@ parameter is universally quantified by
-- 'runST', which keeps references and arrays tied to one run.
newtype STT s m a = STT (State# s -> m (STTRet s a))

-- | Unwrap an 'STT' computation into its underlying state-passing function.
unSTT :: STT s m a -> State# s -> m (STTRet s a)
unSTT (STT f) = f

-- | The result of running one 'STT' step: the new state token together with
-- the step's value. This is an ordinary (boxed) data type rather than an
-- unboxed tuple so that it can be used as the argument of the monad @m@.
data STTRet s a = STTRet (State# s) a
-- | Map over the result of an 'STT' computation, threading the state token
-- through unchanged. Only @Monad m@ is required, matching the 'Monad'
-- instance, so this also works where 'Monad' does not imply 'Functor'.
instance Monad m => Functor (STT s m) where
  fmap f (STT m) = STT $ \st ->
    do ret <- m st
       case ret of
         STTRet new_st a -> return (STTRet new_st (f a))

-- | Sequential application: run the function-producing computation first,
-- then the argument-producing one with the state token it returned. This
-- agrees with @'ap'@ for the 'Monad' instance below.
instance Monad m => Applicative (STT s m) where
  pure a = STT $ \st -> return (STTRet st a)
  STT mf <*> STT mx = STT $ \st ->
    do retF <- mf st
       case retF of
         STTRet st1 f ->
           do retX <- mx st1
              case retX of
                STTRet st2 x -> return (STTRet st2 (f x))

-- | Monadic sequencing: @m >>= k@ runs @m@ in the underlying monad, then feeds
-- both its result and its successor state token into @k@.
instance Monad m => Monad (STT s m) where
  return a = STT $ \st -> return (STTRet st a)
  STT m >>= k = STT $ \st ->
    do ret <- m st
       case ret of
         STTRet new_st a ->
           unSTT (k a) new_st
-- | Lift an action of the underlying monad into 'STT'. The state token is
-- passed through untouched, because the lifted action cannot touch it.
instance MonadTrans (STT s) where
  lift action = STT $ \s0 ->
    action >>= \x -> return (STTRet s0 x)
-- | A mutable reference belonging to state thread @s@ and holding a value of
-- type @a@. It is a thin wrapper around the GHC primitive 'MutVar#'.
data STRef s a = STRef (MutVar# s a)
-- | Allocate a fresh 'STRef' whose contents start as the given value.
newSTRef :: Monad m => a -> STT s m (STRef s a)
newSTRef initial = STT $ \s0 ->
  case newMutVar# initial s0 of
    (# s1, ref #) -> return (STTRet s1 (STRef ref))
-- | Read the current contents of an 'STRef'.
readSTRef :: Monad m => STRef s a -> STT s m a
readSTRef (STRef ref) = STT step
  where
    -- Perform the primitive read and pair its value with the new token.
    step s0 = case readMutVar# ref s0 of
      (# s1, val #) -> return (STTRet s1 val)
-- | Replace the contents of an 'STRef' with a new value.
writeSTRef :: Monad m => STRef s a -> a -> STT s m ()
writeSTRef (STRef ref) x = STT $ \s0 ->
  case writeMutVar# ref x s0 of
    s1 -> return (STTRet s1 ())
-- | Reference identity: two 'STRef's are equal exactly when they wrap the
-- same underlying 'MutVar#'. Their contents are never compared.
instance Eq (STRef s a) where
  (==) (STRef x) (STRef y) = sameMutVar# x y
-- | Run an 'STT' computation in the underlying monad and return its result.
--
-- The computation is started with 'realWorld#' as its state token, and the
-- final token is discarded. The rank-2 type stops 'STRef's and 'STArray's
-- created inside from escaping this run.
--
-- NOTE(review): mutations are destructive primops, and the token is not
-- linear with respect to @m@. If @m@ can invoke a continuation more than
-- once (e.g. lists), writes made in one branch are visible in the others.
-- Confirm that callers only use such base monads deliberately.
runST :: Monad m => (forall s. STT s m a) -> m a
runST action =
  unSTT action realWorld# >>= \ret ->
    case ret of
      STTRet _ x -> return x
-- | Errors are thrown and caught in the underlying monad.
--
-- The handler is resumed with the state token from before the protected
-- computation started. Writes already performed by the failed computation
-- are destructive and are not rolled back.
instance MonadError e m => MonadError e (STT s m) where
  throwError = lift . throwError
  catchError (STT body) handler = STT $ \s0 ->
    catchError (body s0) (\err -> unSTT (handler err) s0)
-- | Reader operations are delegated to the underlying monad. 'local' wraps
-- the underlying action produced for the current state token.
instance MonadReader r m => MonadReader r (STT s m) where
  ask = lift ask
  local adjust (STT body) = STT (\s0 -> local adjust (body s0))
-- | State operations act on the underlying monad's state. This is unrelated
-- to the 'STT' state thread, whose parameter is named @tok@ here.
instance MonadState st m => MonadState st (STT tok m) where
  get = lift get
  put = lift . put
-- | Writer operations are delegated to the underlying monad.
--
-- 'listen' and 'pass' move the writer output (or output transformer) in and
-- out of the 'STTRet' payload, so the state token stays threaded through.
instance MonadWriter w m => MonadWriter w (STT s m) where
  tell = lift . tell
  listen (STT body) = STT $ \s0 ->
    listen (body s0) >>= \(ret, out) ->
      case ret of
        STTRet s1 x -> return (STTRet s1 (x, out))
  pass (STT body) = STT $ \s0 ->
    pass (body s0 >>= \ret ->
            case ret of
              STTRet s1 (x, g) -> return (STTRet s1 x, g))
-- | A mutable array belonging to state thread @s@, indexed by @i@ and
-- holding elements of type @e@.
--
-- The fields are, in order:
--   * the strict lower bound
--   * the strict upper bound
--   * the strict element count ('newSTArray' fills it from 'safeRangeSize')
--   * the primitive 'MutableArray#' storage
data STArray s i e = STArray !i !i !Int (MutableArray# s e)
-- | Array identity: two 'STArray's are equal exactly when they share the same
-- 'MutableArray#'. Bounds and element values are not inspected.
instance Eq (STArray s i e) where
  (==) (STArray _ _ _ x) (STArray _ _ _ y) = sameMutableArray# x y
-- | Allocate a new 'STArray' with the given inclusive bounds. Every slot is
-- initialised to the supplied element.
newSTArray :: (Ix i, Monad m) => (i,i) -> e -> STT s m (STArray s i e)
newSTArray (lo, hi) fill = STT allocate
  where
    -- Size the array from the bounds, then allocate and wrap the storage.
    allocate s0 =
      case safeRangeSize (lo, hi) of
        count@(I# raw) ->
          case newArray# raw fill s0 of
            (# s1, arr #) -> return (STTRet s1 (STArray lo hi count arr))
-- | The inclusive @(lower, upper)@ bounds the array was created with.
boundsSTArray :: STArray s i e -> (i,i)
boundsSTArray arr = case arr of
  STArray lo hi _ _ -> (lo, hi)
-- | The number of elements stored in the array.
numElementsSTArray :: STArray s i e -> Int
numElementsSTArray arr = case arr of
  STArray _ _ count _ -> count
-- | Read the element stored at an index. The index is checked against the
-- array's bounds by 'safeIndex' before the raw read is performed.
readSTArray :: (Ix i, Monad m) => STArray s i e -> i -> STT s m e
readSTArray arr@(STArray lo hi count _) ix = unsafeReadSTArray arr offset
  where
    offset = safeIndex (lo, hi) count ix
-- | Read the element at a raw 'Int' offset into the storage. No bounds check
-- is performed here; callers must supply a valid offset.
unsafeReadSTArray :: (Ix i, Monad m) => STArray s i e -> Int -> STT s m e
unsafeReadSTArray (STArray _ _ _ arr) (I# k) = STT step
  where
    step s0 = case readArray# arr k s0 of
      (# s1, x #) -> return (STTRet s1 x)
-- | Store an element at an index. The index is checked against the array's
-- bounds by 'safeIndex' before the raw write is performed.
writeSTArray :: (Ix i, Monad m) => STArray s i e -> i -> e -> STT s m ()
writeSTArray arr@(STArray lo hi count _) ix x = unsafeWriteSTArray arr offset x
  where
    offset = safeIndex (lo, hi) count ix
-- | Store an element at a raw 'Int' offset into the storage. No bounds check
-- is performed here; callers must supply a valid offset.
unsafeWriteSTArray :: (Ix i, Monad m) => STArray s i e -> Int -> e -> STT s m ()
unsafeWriteSTArray (STArray _ _ _ arr) (I# k) x = STT step
  where
    step s0 = case writeArray# arr k x s0 of
      s1 -> return (STTRet s1 ())
-- | Take an immutable snapshot of a mutable array.
--
-- A fresh array of the same size is allocated, with every slot initially
-- 'arrEleBottom'. Each element of the source is then copied into it in
-- increasing offset order, and the copy is frozen in place. Later writes to
-- the source 'STArray' therefore do not affect the returned 'Array'. The
-- copy visits every element once, so the cost is linear in the element
-- count.
freezeSTArray :: (Ix i,Monad m) => STArray s i e -> STT s m (Array i e)
freezeSTArray (STArray l u n@(I# n#) marr#) = STT $ \s1# ->
    -- Allocate the destination array that will become the snapshot.
    case newArray# n# arrEleBottom s1# of { (# s2#, marr'# #) ->
    -- copy i# walks offsets i#, i#+1, ..., n#-1, moving one element per step.
    let copy i# s3# | i# ==# n# = s3#
                    | otherwise =
            case readArray# marr# i# s3# of { (# s4#, e #) ->
            case writeArray# marr'# i# e s4# of { s5# ->
            copy (i# +# 1#) s5# }} in
    case copy 0# s2# of { s3# ->
    -- The destination is private to this function, so freezing it in place
    -- without a second copy is safe.
    case unsafeFreezeArray# marr'# s3# of { (# s4#, arr# #) ->
    return (STTRet s4# (Array l u n arr# )) }}}
-- | Freeze a mutable array in place, without copying.
--
-- The returned 'Array' shares storage with the 'STArray'. Any write made
-- through the 'STArray' afterwards would therefore be visible in the
-- supposedly immutable result.
unsafeFreezeSTArray :: (Ix i, Monad m) => STArray s i e -> STT s m (Array i e)
unsafeFreezeSTArray (STArray lo hi count marr) = STT $ \s0 ->
  case unsafeFreezeArray# marr s0 of
    (# s1, frozen #) -> return (STTRet s1 (Array lo hi count frozen))
-- | Make a mutable copy of an immutable array.
--
-- A fresh mutable array of the same size is allocated, with every slot
-- initially 'arrEleBottom'. Each element of the source 'Array' is copied
-- into it in increasing offset order. Writes to the result never affect the
-- original 'Array'. The copy visits every element once, so the cost is
-- linear in the element count.
thawSTArray :: (Ix i, Monad m) => Array i e -> STT s m (STArray s i e)
thawSTArray (Array l u n@(I# n#) arr#) = STT $ \s1# ->
    -- Allocate the destination mutable array.
    case newArray# n# arrEleBottom s1# of { (# s2#, marr# #) ->
    -- copy i# walks offsets i#, i#+1, ..., n#-1. indexArray# reads the
    -- immutable source purely; it returns a 1-tuple so the element is not
    -- forced.
    let copy i# s3# | i# ==# n# = s3#
                    | otherwise =
            case indexArray# arr# i# of { (# e #) ->
            case writeArray# marr# i# e s3# of { s4# ->
            copy (i# +# 1#) s4# }} in
    case copy 0# s2# of { s3# ->
    return (STTRet s3# (STArray l u n marr# )) }}
-- | Reinterpret an immutable array as a mutable one, without copying.
--
-- The resulting 'STArray' shares storage with the argument. Writes through
-- it would therefore be observable through the original 'Array'.
unsafeThawSTArray :: (Ix i, Monad m) => Array i e -> STT s m (STArray s i e)
unsafeThawSTArray (Array lo hi count frozen) = STT $ \s0 ->
  case unsafeThawArray# frozen s0 of
    (# s1, marr #) -> return (STTRet s1 (STArray lo hi count marr))
-- | Run an 'STT' computation that builds an 'STArray', and return the array
-- frozen in place (no copy) as an immutable 'Array'.
--
-- NOTE(review): the in-place freeze relies on nothing writing to the array
-- after the computation finishes. That holds for linear base monads. If @m@
-- can re-run continuations (e.g. lists), another branch sharing the same
-- array could still write to it after this branch froze it. Confirm
-- intended use.
runSTArray :: (Ix i, Monad m)
           => (forall s . STT s m (STArray s i e))
           -> m (Array i e)
runSTArray action = runST (unsafeFreezeSTArray =<< action)