module Control.Concurrent.MState
(
MState
, runMState
, evalMState
, execMState
, mapMState
, withMState
, modifyM
, Forkable (..)
, forkM
) where
import Control.Applicative
import Control.Concurrent
import qualified Control.Exception as E
import Control.Monad
import Control.Monad.Cont
import Control.Monad.Error
import Control.Monad.Reader
import Control.Monad.State.Class
import Control.Monad.Writer
-- | A computation over the base monad @m@ whose state of type @t@ is shared
-- by every thread started with 'forkM'.
--
-- The environment pairs the 'MVar' holding the state with a channel on which
-- each thread started by 'forkM' registers an @MVar ()@ that is filled once
-- that thread's computation returns.
newtype MState t m a = MState
  { runMState' :: (MVar t, Chan (MVar ())) -> m a
  }
-- | Monads in which an action can be started in a separate thread, as needed
-- by 'forkM'.
class MonadIO m => Forkable m where
  -- | Start the given action in a new thread and return that thread's id.
  fork :: m () -> m ThreadId
-- Plain 'IO' forks with 'forkIO'.
instance Forkable IO where
  fork action = forkIO action
-- Captures the current reader environment and runs the reader action against
-- it in a new 'forkIO' thread.
instance Forkable (ReaderT s IO) where
  fork action = do
    env <- ask
    liftIO (forkIO (runReaderT action env))
-- | 'E.catch' specialised to 'E.BlockedIndefinitelyOnMVar'.
--
-- Runs @action@; if it raises 'E.BlockedIndefinitelyOnMVar', the result of
-- @handler@ applied to that exception is returned instead.
catchMVar :: IO a -> (E.BlockedIndefinitelyOnMVar -> IO a) -> IO a
catchMVar action handler = action `E.catch` handler
-- | Block until every thread registered on the channel has signalled that
-- it is done.
--
-- 'forkM' writes a thread's @MVar ()@ to the channel /before/ forking it,
-- and the thread fills that 'MVar' only after its computation returns.
-- Children are therefore always registered before their parent finishes.
-- Once a complete pass over the channel finds nothing to wait for, no
-- registered thread can still be running.
--
-- The end of a pass is found by writing a fresh marker 'MVar' and reading
-- up to it. This replaces polling 'isEmptyChan', which is deprecated in
-- @base@ and not available in current releases. 'MVar' equality compares
-- identity, so the marker can never be mistaken for a registration.
--
-- If waiting on a thread's 'MVar' raises 'E.BlockedIndefinitelyOnMVar'
-- (e.g. the thread died before filling it), waiting stops and this returns.
waitForTermination :: MonadIO m
                   => Chan (MVar ())
                   -> m ()
waitForTermination c = liftIO $ catchMVar waitAll (const $ return ())
  where
    -- Run passes until one finds no registered thread to wait for.
    waitAll = do
      marker <- newEmptyMVar
      writeChan c marker
      sawThreads <- drainUntil marker False
      when sawThreads waitAll

    -- Wait on every registration queued ahead of the marker, reporting
    -- whether there was at least one.
    drainUntil marker seen = do
      mv <- readChan c
      if mv == marker
        then return seen
        else takeMVar mv >> drainUntil marker True
-- | Run an 'MState' computation starting from state @t@.
--
-- The computation is started with 'forkM', so it is registered on the
-- completion channel like every thread it forks. After 'waitForTermination'
-- returns, the result is taken and the final shared state is read.
--
-- Returns the computation's result paired with the final state. The result
-- 'MVar' is only filled when the computation returns normally.
runMState :: Forkable m
          => MState t m a
          -> t
          -> m (a,t)
runMState m t = do
  (stateVar, chan, resultVar) <- liftIO $ do
    sv <- newMVar t
    ch <- newChan
    rv <- newEmptyMVar
    return (sv, ch, rv)
  let deliver = m >>= liftIO . putMVar resultVar
  _ <- runMState' (forkM deliver) (stateVar, chan)
  waitForTermination chan
  liftIO $ do
    result <- takeMVar resultVar
    finalState <- readMVar stateVar
    return (result, finalState)
-- | Like 'runMState', but keep only the computation's result.
evalMState :: Forkable m
           => MState t m a
           -> t
           -> m a
evalMState m t = liftM fst (runMState m t)
-- | Like 'runMState', but keep only the final shared state.
execMState :: Forkable m
           => MState t m a
           -> t
           -> m t
execMState m t = liftM snd (runMState m t)
-- | Transform the computation in the underlying monad.
--
-- @f@ receives an action that runs @m@ and pairs its result with the shared
-- state read immediately afterwards. The state component of @f@'s result is
-- then written back to the shared 'MVar', and the other component is
-- returned.
--
-- The read and the write are separate 'MVar' operations. Another thread may
-- change the state in between, and that change is overwritten.
mapMState :: (MonadIO m, MonadIO n)
          => (m (a,t) -> n (b,t))
          -> MState t m a
          -> MState t n b
mapMState f m = MState $ \env@(stateVar, _) -> do
  let runAndRead = do
        x <- runMState' m env
        st <- liftIO (readMVar stateVar)
        return (x, st)
  outcome <- f runAndRead
  _ <- liftIO (swapMVar stateVar (snd outcome))
  return (fst outcome)
-- | Apply @f@ to the shared state, then run @m@ in the same environment.
withMState :: (MonadIO m)
           => (t -> t)
           -> MState t m a
           -> MState t m a
withMState f m = MState $ \env@(stateVar, _) ->
  liftIO (modifyMVar_ stateVar (return . f)) >> runMState' m env
-- | Run a computation in a new thread that shares this computation's state.
--
-- A fresh @MVar ()@ is written to the completion channel /before/ the thread
-- is forked, and is filled once the computation returns.
-- 'waitForTermination' waits on this 'MVar'.
--
-- If the computation throws instead of returning, the 'MVar' is never
-- filled.
forkM :: Forkable m
      => MState t m ()
      -> MState t m ThreadId
forkM m = MState $ \env@(_, chan) -> do
  done <- liftIO $ do
    finished <- newEmptyMVar
    writeChan chan finished
    return finished
  fork (runMState' m env >> liftIO (putMVar done ()))
-- | Apply @f@ to the shared state inside a single 'modifyMVar_'.
--
-- Because of this, writers that first take the 'MVar' (such as 'put') cannot
-- slip in between the read and the write. The new state is stored
-- unevaluated.
modifyM :: (MonadIO m) => (t -> t) -> MState t m ()
modifyM f = MState $ \(stateVar, _) ->
  liftIO (modifyMVar_ stateVar (\old -> return (f old)))
-- Sequencing passes the same shared environment (state 'MVar' and
-- completion channel) to both computations.
-- NOTE: defining 'fail' as a 'Monad' method ties this to GHC < 8.8, where
-- 'fail' still belongs to 'Monad' (MonadFail proposal).
instance (Monad m) => Monad (MState t m) where
  return a = MState $ \_ -> return a
  m >>= k = MState $ \t -> do
    a <- runMState' m t
    runMState' (k a) t
  fail str = MState $ \_ -> fail str

instance (Monad m) => Functor (MState t m) where
  fmap f m = MState $ \t -> do
    a <- runMState' m t
    return (f a)

-- Mandatory superclass of 'Monad' since GHC 7.10 (Applicative-Monad
-- Proposal). It is defined via the 'Monad' instance so the two always agree.
instance (Monad m) => Applicative (MState t m) where
  pure = return
  (<*>) = ap

-- Mandatory superclass of 'MonadPlus' since GHC 7.10. It delegates to the
-- 'MonadPlus' methods below.
instance (MonadPlus m) => Alternative (MState t m) where
  empty = mzero
  (<|>) = mplus

-- Both alternatives run against the same shared environment; the choice is
-- made by the underlying monad's 'mplus'.
instance (MonadPlus m) => MonadPlus (MState t m) where
  mzero = MState $ \_ -> mzero
  m `mplus` n = MState $ \t -> runMState' m t `mplus` runMState' n t
-- 'get' reads the shared 'MVar'; 'put' replaces its contents with
-- 'swapMVar'. Only these two methods are defined, so class defaults such as
-- 'modify' do a separate read and write and are not atomic across threads.
-- Use 'modifyM' for an atomic update.
instance (MonadIO m) => MonadState t (MState t m) where
  get = MState $ \(stateVar, _) -> liftIO (readMVar stateVar)
  put new = MState $ \(stateVar, _) ->
    liftIO (swapMVar stateVar new >> return ())
-- The fixed point is taken in the underlying monad, running @f a@ against the
-- same shared environment.
instance (MonadFix m) => MonadFix (MState t m) where
  mfix f = MState $ \env -> mfix (flip runMState' env . f)
-- A lifted action ignores the shared environment entirely.
instance MonadTrans (MState t) where
  lift m = MState (const m)
-- IO actions are lifted into the base monad first, then into 'MState'.
instance (MonadIO m) => MonadIO (MState t m) where
  liftIO action = lift (liftIO action)
-- The escape continuation captured in the base monad is wrapped as an
-- 'MState' action that ignores the environment.
instance (MonadCont m) => MonadCont (MState t m) where
  callCC f = MState $ \env ->
    callCC $ \exit -> runMState' (f (MState . const . exit)) env
-- Errors are thrown and caught in the base monad, and the handler runs with
-- the same shared environment. Because the state lives in an 'MVar', updates
-- made before the error remain visible to the handler.
instance (MonadError e m) => MonadError e (MState t m) where
  throwError err = lift (throwError err)
  catchError action handler = MState $ \env ->
    catchError (runMState' action env) (\err -> runMState' (handler err) env)
-- The reader environment comes from the base monad; 'local' modifies it
-- around the inner computation.
instance (MonadReader r m) => MonadReader r (MState t m) where
  ask = MState (const ask)
  local f m = MState (local f . runMState' m)
-- Output is accumulated in the base monad; 'listen' and 'pass' wrap the
-- inner computation run against the shared environment.
instance (MonadWriter w m) => MonadWriter w (MState t m) where
  tell msg = MState (const (tell msg))
  listen m = MState (\env -> listen (runMState' m env))
  pass m = MState (\env -> pass (runMState' m env))