{-# LANGUAGE CPP #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE TupleSections #-} ----------------------------------------------------------------------------- -- | -- Maintainer : me@joelt.io -- Stability : experimental -- Portability : portable -- -- Concurrent state monad, providing a State-like interface but allowing -- for multiple threads to operate on the same value simultaneously. -- -- This module performs state computations lazily. For a strict version, -- see "Control.Monad.State.Concurrent.Strict". ----------------------------------------------------------------------------- module Control.Monad.State.Concurrent.Lazy ( module Control.Monad.State, -- *** The StateC monad transformer StateC, -- *** Concurrent state operations runStateC, evalStateC, execStateC, -- *** Running concurrent operations on a single input runStatesC, evalStatesC, execStatesC, -- *** Lifting other operations liftCallCCC, liftCatchC, liftListenC, liftPassC ) where import Control.Applicative import Control.Concurrent.Lifted.Fork import Control.Concurrent.MVar import Control.Concurrent.STM import Control.Exception (throwIO) import Control.Monad import Control.Monad.Catch import Control.Monad.State #if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ < 706 import Prelude hiding (catch) #endif -- --------------------------------------------------------------------------- -- | A concurrent state transformer monad parameterized by: -- -- * @s@ - The state. This is contained in a 'TVar'. -- -- * @m@ - The inner monad. -- -- The 'return' function leaves the state unchanged, while @>>=@ performs -- actions atomically on the held 'TVar'. -- -- This is very similar to @transformers@' 'StateT', with the exception of -- the 'MonadIO' constraint on every instance, which is necessary to -- perform STM actions. 
--
-- Implementation note: within the instances below the 'TVar' is threaded
-- through unchanged; its contents are read and written only by 'state',
-- inside a single 'atomically' transaction.
newtype StateC s m a = StateC { _runStateC :: TVar s -> m (a, TVar s) }

instance MonadTrans (StateC s) where
    -- Run the inner action and pair its result with the untouched 'TVar'.
    lift act = StateC $ \tv -> act >>= \x -> return (x, tv)

instance (Functor m, MonadIO m) => Functor (StateC s m) where
    -- The irrefutable pattern keeps this module lazy: the inner pair is not
    -- forced merely to map over its first component.
    fmap f (StateC run) = StateC (fmap relabel . run)
      where
        relabel ~(x, tv) = (f x, tv)

instance (Functor m, MonadIO m) => Applicative (StateC s m) where
    pure x = return x
    mf <*> mx = mf `ap` mx

instance (MonadIO m, Functor m, MonadPlus m) => Alternative (StateC s m) where
    empty = mzero
    l <|> r = l `mplus` r

instance (MonadPlus m, MonadIO m) => MonadPlus (StateC s m) where
    mzero = StateC (\_ -> mzero)
    -- Both alternatives are handed the same 'TVar'.
    StateC l `mplus` StateC r = StateC $ \tv -> mplus (l tv) (r tv)

instance MonadIO m => Monad (StateC s m) where
    return x = StateC $ \tv -> return (x, tv)
    -- Lazy bind: the pair produced by the first action is matched
    -- irrefutably, so it is only forced when its components are demanded.
    StateC run >>= k = StateC $ \tv ->
        run tv >>= \ ~(x, tv') -> _runStateC (k x) tv'

instance (Functor m, MonadIO m) => MonadState s (StateC s m) where
    -- Read, transform and write the shared value in one STM transaction so
    -- that concurrent callers cannot interleave between the read and the
    -- write. The let-bound pair is lazy, so the new state is stored
    -- unevaluated.
    state f = StateC $ \tv ->
        liftM (\x -> (x, tv)) . liftIO . atomically $ do
            current <- readTVar tv
            let (x, next) = f current
            writeTVar tv next
            return x

instance (MonadIO m, MonadFix m) => MonadFix (StateC s m) where
    -- Tie the knot on the result only; the 'TVar' is the one passed in.
    mfix f = StateC $ \tv -> mfix (\result -> _runStateC (f (fst result)) tv)

instance MonadIO m => MonadIO (StateC s m) where
    liftIO = lift . liftIO

instance (MonadIO m, MonadCatch m) => MonadCatch (StateC s m) where
    throwM e = liftIO (throwIO e)
    catch = liftCatchC catch
    -- Masking is delegated to the inner monad; the restore function it
    -- supplies is lifted so that it can wrap 'StateC' actions.
    mask body = StateC $ \tv ->
        mask $ \restore -> _runStateC (body (underRestore restore)) tv
      where
        underRestore restore (StateC run) = StateC (restore . run)
    uninterruptibleMask body = StateC $ \tv ->
        uninterruptibleMask $ \restore ->
            _runStateC (body (underRestore restore)) tv
      where
        underRestore restore (StateC run) = StateC (restore . run)

instance MonadFork m => MonadFork (StateC s m) where
    fork       = liftFork fork
    forkOn cap = liftFork (forkOn cap)
    forkOS     = liftFork forkOS

-- | Lift a forking primitive of the inner monad into 'StateC'.
--
-- The action given to the primitive runs the computation against the
-- caller's 'TVar' and discards its result. The caller receives whatever the
-- primitive returns, paired with the unchanged 'TVar'.
liftFork :: Monad m => (m () -> m a) -> StateC t m () -> StateC t m a
liftFork forker (StateC run) = StateC $ \tv ->
    forker (run tv >> return ()) >>= \res -> return (res, tv)

-- | Unwrap a concurrent state monad computation as a function.
--
-- The computation runs against the supplied 'TVar'. After it finishes, the
-- 'TVar' it hands back is read once with 'readTVarIO' to obtain the final
-- state.
runStateC :: MonadIO m
          => StateC s m a -- ^ state-passing computation to execute
          -> TVar s       -- ^ initial state
          -> m (a, s)     -- ^ return value and final state
runStateC m s =
    _runStateC m s >>= \ ~(result, tv) ->
        liftIO (readTVarIO tv) >>= \final -> return (result, final)

-- | Run a concurrent state computation against the given 'TVar' and return
-- only its result. The final state is still read (through 'runStateC')
-- before being discarded.
--
-- * @'evalStateC' m s = 'liftM' 'fst' ('runStateC' m s)@
evalStateC :: MonadIO m
           => StateC s m a -- ^ state-passing computation to execute
           -> TVar s       -- ^ initial state
           -> m a          -- ^ return value
evalStateC m s = runStateC m s >>= return . fst

-- | Run a concurrent state computation against the given 'TVar' and return
-- only the final state, discarding the computation's result.
--
-- * @'execStateC' m s = 'liftM' 'snd' ('runStateC' m s)@
execStateC :: MonadIO m
           => StateC s m a -- ^ state-passing computation to execute
           -> TVar s       -- ^ initial state
           -> m s          -- ^ final state
execStateC m s = runStateC m s >>= return . snd

-- | Uniform lifting of a @callCC@ operation to the new monad.
--
-- The escape continuation ignores the 'TVar' it is invoked with. It resumes
-- with the one captured when 'liftCallCCC' started, and escaping does not
-- undo updates already written to that 'TVar'.
liftCallCCC :: ((((a, TVar s) -> m (b, TVar s)) -> m (a, TVar s)) -> m (a, TVar s))
            -> ((a -> StateC s m b) -> StateC s m a)
            -> StateC s m a
liftCallCCC callCC f = StateC $ \tv ->
    let escapeWith k x = StateC (\_ -> k (x, tv))
    in  callCC (\k -> _runStateC (f (escapeWith k)) tv)

-- | Lift a @catchError@ operation to the new monad.
--
-- The handler runs against the same 'TVar' as the guarded computation, so
-- updates committed before the exception remain in effect.
liftCatchC :: (m (a, TVar s) -> (e -> m (a, TVar s)) -> m (a, TVar s))
           -> StateC s m a
           -> (e -> StateC s m a)
           -> StateC s m a
liftCatchC catchOp m h = StateC $ \tv ->
    catchOp (_runStateC m tv) (\e -> _runStateC (h e) tv)

-- | Lift a @listen@ operation to the new monad. The listened output @w@ is
-- paired with the computation's result.
liftListenC :: Monad m
            => (m (a, TVar s) -> m ((a, TVar s), w))
            -> StateC s m a
            -> StateC s m (a, w)
liftListenC listen m = StateC $ \tv ->
    listen (_runStateC m tv) >>= \ ~((x, tv'), w) -> return ((x, w), tv')

-- | Lift a @pass@ operation to the new monad.
--
-- The second component of the computation's result is moved out next to the
-- @(result, 'TVar')@ pair so that @pass@ can consume it. The 'TVar' itself
-- is threaded through unchanged.
liftPassC :: Monad m
          => (m ((a, TVar s), b) -> m (a, TVar s))
          -> StateC s m (a, b)
          -> StateC s m a
liftPassC pass m = StateC $ \tv ->
    pass (_runStateC m tv >>= \ ~((x, g), tv') -> return ((x, tv'), g))

-- | Run multiple state operations concurrently on the same value, returning
-- the resultant state and the value produced by each operation.
--
-- The steps are:
--
-- * A fresh 'TVar' is created from the initial state.
--
-- * One result 'MVar' is allocated per operation.
--
-- * Each operation is started via 'fork' against the shared 'TVar'.
--
-- * Results are collected in the same order as the input list.
--
-- The final state is read only after every operation has delivered its
-- result.
--
-- Note: an operation that throws never fills its result 'MVar', so this
-- function blocks waiting for it. GHC's runtime typically reports this as
-- @BlockedIndefinitelyOnMVar@. The original exception is not re-thrown here.
runStatesC :: MonadFork m
           => [StateC s m a] -- ^ state-passing computations to execute
           -> s              -- ^ initial state
           -> m ([a], s)     -- ^ return values and final state
runStatesC ms s = do
    shared <- liftIO (newTVarIO s)
    slots  <- forM ms (\_ -> liftIO newEmptyMVar)
    let spawn slot op = fork (evalStateC op shared >>= liftIO . putMVar slot)
    zipWithM_ spawn slots ms
    results <- mapM (liftIO . takeMVar) slots
    final   <- liftIO (readTVarIO shared)
    return (results, final)

-- | Run multiple state operations concurrently on the same value, returning
-- the value produced by each operation (in input order).
--
-- * @'evalStatesC' ms s = 'liftM' 'fst' ('runStatesC' ms s)@
evalStatesC :: MonadFork m
            => [StateC s m a] -- ^ state-passing computations to execute
            -> s              -- ^ initial state
            -> m [a]          -- ^ return values
evalStatesC ms s = runStatesC ms s >>= return . fst

-- | Run multiple state operations concurrently on the same value, returning
-- only the resultant state.
--
-- * @'execStatesC' ms s = 'liftM' 'snd' ('runStatesC' ms s)@
execStatesC :: MonadFork m
            => [StateC s m a] -- ^ state-passing computations to execute
            -> s              -- ^ initial state
            -> m s            -- ^ final state
execStatesC ms s = runStatesC ms s >>= return . snd