{-# LANGUAGE FlexibleInstances, MultiParamTypeClasses, UndecidableInstances #-}

---------------------------------------------------------------------------
-- |
-- Module      :  Control.Concurrent.MState
-- Copyright   :  (c) Nils Schweinsberg 2010
-- License     :  BSD3-style (see LICENSE)
--
-- Maintainer  :  mail@n-sch.de
-- Stability   :  unstable
-- Portability :  portable
--
-- MState: A consistent State monad for concurrent applications.
--
---------------------------------------------------------------------------

module Control.Concurrent.MState
    (
      -- * The MState Monad
      MState
    , runMState
    , evalMState
    , execMState
    , mapMState
    , withMState
    , modifyM

      -- * Concurrency
    , Forkable (..)
    , forkM

      -- * Example
      -- $example
    ) where

import Control.Monad
import Control.Monad.State.Class
import Control.Monad.Cont
import Control.Monad.Error
import Control.Monad.Reader
import Control.Monad.Writer

import Control.Concurrent
import qualified Control.Exception as E

-- | The MState is an abstract data definition for a State monad which can be
-- used in concurrent applications. It can be accessed with @evalMState@ and
-- @execMState@. To start a new state thread use @forkM@.
--
-- Internally an @MState@ action is a function from a shared environment to
-- an action in the underlying monad. The environment is a pair of the
-- @MVar@ holding the state value and a @Chan@ of @MVar ()@s.
newtype MState t m a = MState
    { runMState' :: (MVar t, Chan (MVar ())) -> m a
      -- ^ Run the action against an already existing environment.
    }

-- | The class which is needed to start new threads in the MState monad. Don't
-- confuse this with @forkM@ which should be used to fork new threads!
class (MonadIO m) => Forkable m where
    -- | Run the given action in a new thread and return that thread's id.
    fork :: m () -> m ThreadId

-- Plain IO actions are forked directly with 'forkIO'.
instance Forkable IO where
    fork = forkIO

-- A reader action is forked by running it with the parent's current
-- environment inside the new thread.
instance Forkable (ReaderT s IO) where
    fork newT = do
        env <- ask
        liftIO $ forkIO (runReaderT newT env)

-- | 'E.catch' with the handler's exception type fixed to
-- 'E.BlockedIndefinitelyOnMVar'.
catchMVar :: IO a -> (E.BlockedIndefinitelyOnMVar -> IO a) -> IO a
catchMVar action handler = E.catch action handler

-- | Read the Chan full of MVars and wait for all MVars to get filled by the
-- threads. On MVar-exception this will skip the current MVar and take the next
-- one (if available).
waitForTermination :: MonadIO m
                   => Chan (MVar ())    -- ^ Channel of per-thread completion MVars
                   -> m ()
waitForTermination c = liftIO loop
  where
    -- Drain the channel one MVar at a time until it is empty.
    -- NOTE(review): 'isEmptyChan' is reported as deprecated in newer versions
    -- of base -- confirm against the targeted GHC before upgrading.
    loop :: IO ()
    loop = do
        empty <- isEmptyChan c
        unless empty $ do
            -- Read next threads MVar and wait until it's filled
            mv <- readChan c
            -- Only this single wait is guarded. If the MVar can never be
            -- filled (e.g. its thread was killed before signalling) the
            -- resulting BlockedIndefinitelyOnMVar is dropped and we move on
            -- to the next MVar instead of abandoning all remaining threads.
            catchMVar (takeMVar mv) (const $ return ())
            loop

-- | Run the MState and return both, the function value and the state value.
--
-- The action itself is started through @forkM@, so it is registered in the
-- waiting channel like every other thread. The result is collected only
-- after 'waitForTermination' has returned.
runMState :: Forkable m
          => MState t m a      -- ^ Action to evaluate
          -> t                 -- ^ Initial state value
          -> m (a,t)
runMState m t = do
    ref <- liftIO $ newMVar t
    c   <- liftIO newChan
    -- Receives the return value of the main action once it has finished
    mv  <- liftIO newEmptyMVar
    _   <- runMState' (forkM $ m >>= liftIO . putMVar mv) (ref, c)
    waitForTermination c
    a   <- liftIO $ takeMVar mv
    t'  <- liftIO $ readMVar ref
    return (a,t')

-- | Evaluate the MState monad with the given initial state, throwing away the
-- final state stored in the MVar.
evalMState :: Forkable m
           => MState t m a     -- ^ Action to evaluate
           -> t                -- ^ Initial state value
           -> m a
evalMState m t = liftM fst $ runMState m t

-- | Execute the MState monad with a given initial state. Returns the value of
-- the final state.
execMState :: Forkable m
           => MState t m a     -- ^ Action to execute
           -> t                -- ^ Initial state value
           -> m t
execMState m t = liftM snd $ runMState m t

-- | Map a stateful computation from one @(return value, state)@ pair to
-- another. See @Control.Monad.State.Lazy.mapState@ for more information.
mapMState :: (MonadIO m, MonadIO n)
          => (m (a,t) -> n (b,t))
          -> MState t m a
          -> MState t n b
mapMState f m = MState $ \s@(r,_) -> do
    -- Run the original action, pair its result with the state read
    -- afterwards, and hand that pair to the mapping function.
    ~(b,v') <- f $ do
        a <- runMState' m s
        v <- liftIO $ readMVar r
        return (a,v)
    -- Store the state produced by the mapping function
    _ <- liftIO $ swapMVar r v'
    return b

-- | Apply this function to the current state, then run the given action on
-- the modified state.
withMState :: (MonadIO m)
           => (t -> t)
           -> MState t m a
           -> MState t m a
withMState f m = MState $ \s@(r,_) -> do
    liftIO $ modifyMVar_ r (return . f)
    runMState' m s

-- | Start a new thread, using @forkIO@. The main process will wait for all
-- child processes to finish.
forkM :: Forkable m
      => MState t m ()         -- ^ State action to be forked
      -> MState t m ThreadId
forkM m = MState $ \env@(_,threads) -> do
    -- Add new thread MVar to our waiting channel
    done <- liftIO newEmptyMVar
    liftIO $ writeChan threads done
    -- The MVar is filled only after the action has run to completion.
    -- NOTE(review): nothing fills it if the thread is killed or the action
    -- aborts early -- confirm this is intended.
    fork $ runMState' m env >> liftIO (putMVar done ())

-- | Modify the MState. Block all other threads from accessing the state.
modifyM :: (MonadIO m) => (t -> t) -> MState t m ()
modifyM f = MState $ \(ref,_) -> liftIO (modifyMVar_ ref (return . f))

--------------------------------------------------------------------------------
-- Monad instances
--------------------------------------------------------------------------------

-- Every instance below passes the same shared environment on to all
-- sub-computations; the state itself lives in the MVar.

instance (Monad m) => Monad (MState t m) where
    return a = MState (const (return a))
    m >>= k  = MState $ \t -> runMState' m t >>= \a -> runMState' (k a) t
    fail str = MState (const (fail str))

instance (Monad m) => Functor (MState t m) where
    fmap f m = MState $ \t -> liftM f (runMState' m t)

instance (MonadPlus m) => MonadPlus (MState t m) where
    mzero     = MState (const mzero)
    mplus m n = MState $ \t -> mplus (runMState' m t) (runMState' n t)

instance (MonadIO m) => MonadState t (MState t m) where
    -- Read the current content of the state MVar
    get     = MState $ \(r,_) -> liftIO (readMVar r)
    -- Replace the content of the state MVar, discarding the old value
    put val = MState $ \(r,_) -> liftIO (swapMVar r val) >>= \_ -> return ()

instance (MonadFix m) => MonadFix (MState t m) where
    mfix f = MState $ \s -> mfix (\a -> runMState' (f a) s)

--------------------------------------------------------------------------------
-- mtl instances
--------------------------------------------------------------------------------

instance MonadTrans (MState t) where
    -- A lifted action ignores the environment
    lift m = MState (const m)

instance (MonadIO m) => MonadIO (MState t m) where
    liftIO io = lift (liftIO io)

instance (MonadCont m) => MonadCont (MState t m) where
    callCC f = MState $ \s ->
        callCC $ \c -> runMState' (f (\a -> MState (const (c a)))) s

instance (MonadError e m) => MonadError e (MState t m) where
    throwError e   = lift (throwError e)
    -- Body and handler both run against the same environment
    catchError m h = MState $ \s ->
        catchError (runMState' m s) (\e -> runMState' (h e) s)

instance (MonadReader r m) => MonadReader r (MState t m) where
    ask       = lift ask
    local f m = MState $ \s -> local f (runMState' m s)

instance (MonadWriter w m) => MonadWriter w (MState t m) where
    tell w   = lift (tell w)
    listen m = MState $ \s -> listen (runMState' m s)
    pass   m = MState $ \s -> pass (runMState' m s)

{- $example

Example usage:

> import Control.Concurrent
> import Control.Concurrent.MState
> import Control.Monad.State
>
> type MyState a = MState Int IO a
>
> -- Expected state value: 2
> main = print =<< execMState incTwice 0
>
> incTwice :: MyState ()
> incTwice = do
>
>     -- First inc
>     inc
>
>     -- This thread should get killed before it can "inc" our state:
>     kill =<< forkM incDelayed
>     -- This thread should "inc" our state
>     forkM incDelayed
>
>     return ()
>
>   where
>     inc        = get >>= put . (+1)
>     kill       = liftIO . killThread
>     incDelayed = do liftIO $ threadDelay 2000000
>                     inc

-}