{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE UndecidableInstances #-}

-- | A task monad transformer built on 'ContT', together with the 'Trace'
-- type that describes a task's control flow and the scheduler ('runTrace')
-- that interprets it in the base monad.
module Control.Monad.Trans.Task
  ( -- * Task monad transformer
    TaskT (..)
    -- * Trace of a base monad
  , Trace (..)
  , runTrace
    -- * Task functions
  , taskToTrace
  , runTask
  ) where

import Control.Applicative
import Control.Monad.Reader.Class
import Control.Monad.State.Class
import Control.Monad.IO.Class
import Control.Monad.Trans
import Control.Monad.Trans.Cont
import Control.Monad.Task.Class
import Data.Either (partitionEithers)
#if __GLASGOW_HASKELL__ >= 800
import qualified Control.Monad.Fail as Fail
#endif

-- | A @Trace m e@ represents the control flow of a multi-threaded task monad
-- defined over a base monad @m@ and event type @e@.
data Trace m e where
  -- | Stop the scheduler immediately, abandoning every other task.
  EXIT   :: Trace m e
  -- | The current task has finished; other tasks keep running.
  RET    :: Trace m e
  -- | Suspend the current task; the argument is its continuation.
  YIELD  :: m (Trace m e) -> Trace m e
  -- | Split into two tasks. 'fork' puts the child first and the parent's
  -- continuation second.
  FORK   :: m (Trace m e) -> m (Trace m e) -> Trace m e
  -- | Wait for an event. The first field selects the events of interest,
  -- the second resumes the task with the value extracted from the event.
  WATCH  :: (e -> Maybe v) -> (v -> m (Trace m e)) -> Trace m e
  -- | Raise an event; the second field is the signalling task's
  -- continuation.
  SIGNAL :: e -> m (Trace m e) -> Trace m e

-- | @runTrace@ runs a trace to its completion in the base monad with a simple
-- round-robin scheduler.
--
-- The scheduler keeps a run queue of runnable tasks and a list of parked
-- watchers:
--
-- * 'EXIT' ends the whole run at once.
--
-- * 'RET' drops the current task.
--
-- * 'YIELD' moves the current task to the back of the run queue.
--
-- * 'FORK' puts both tasks at the front of the run queue, first argument
--   first.
--
-- * 'WATCH' parks the current task on the watcher list.
--
-- * 'SIGNAL' wakes every watcher whose selector accepts the event (most
--   recently parked first), runs them ahead of the signaller's continuation,
--   and leaves the remaining watchers parked.
--
-- The run ends as soon as the run queue is empty, even if watchers are still
-- parked.
runTrace :: Monad m => m (Trace m e) -> m ()
runTrace prog = loop [prog] []
  where
    -- First argument: run queue. Second argument: parked 'WATCH' traces.
    loop [] _ = return ()
    loop (m:ms) ss = m >>= step
      where
        step EXIT         = return ()
        step RET          = loop ms ss
        step (YIELD t)    = loop (ms ++ [t]) ss
        step (FORK t1 t2) = loop (t1 : t2 : ms) ss
        step (WATCH f g)  = loop ms (WATCH f g : ss)
        step (SIGNAL e t) = loop (ms' ++ [t] ++ ms) ss'
          where
            (ms', ss') = partitionEithers evs
            -- Left: a woken watcher, ready to run. Right: a watcher that
            -- did not match the event and stays parked.
            evs = [ maybe (Right x) (Left . g) (f e) | x@(WATCH f g) <- ss ]

-- | Task monad transformer.
newtype TaskT e m a = TaskT { runTaskT :: ContT (Trace m e) m a }
  deriving (Functor, Applicative, MonadIO)

-- | @taskToTrace@ CPS-converts a task monad into a trace in its base monad.
-- The task's final result is discarded and the trace ends with 'RET'.
taskToTrace :: Monad m => TaskT e m a -> m (Trace m e)
taskToTrace (TaskT (ContT f)) = f (\_ -> return RET)

-- | @runTask@ runs a task monad until its completion, i.e., no more active
-- tasks to run, or until it exits.
--
-- * @'runTask' = 'runTrace' . 'taskToTrace'@
runTask :: Monad m => TaskT e m a -> m ()
runTask = runTrace . taskToTrace

-- | Discard the current continuation and produce 'EXIT'. Shared by 'exit'
-- and the @fail@ implementations so they cannot drift apart.
exitTask :: Monad m => TaskT e m a
exitTask = TaskT $ ContT $ \_ -> return EXIT

instance Monad m => Monad (TaskT e m) where
  return  = TaskT . return
  m >>= f = TaskT $ runTaskT m >>= runTaskT . f
#if __GLASGOW_HASKELL__ < 808
  -- 'fail' stopped being a method of 'Monad' in base-4.13 (GHC 8.8), so it
  -- may only be defined here on older compilers.
  fail _ = exitTask
#endif

#if __GLASGOW_HASKELL__ >= 800
-- | A failed pattern match exits the scheduler; the message is ignored.
instance Monad m => Fail.MonadFail (TaskT e m) where
  fail _ = exitTask
#endif

instance MonadTrans (TaskT e) where
  lift = TaskT . lift

instance MonadReader s m => MonadReader s (TaskT e m) where
  ask     = TaskT ask
  local f = TaskT . local f . runTaskT

instance MonadState s m => MonadState s (TaskT e m) where
  get = TaskT get
  put = TaskT . put

instance Monad m => MonadTask e (TaskT e m) where
  exit     = exitTask
  yield    = TaskT $ ContT $ return . YIELD . ($ ())
  fork p   = TaskT $ ContT $ return . FORK (taskToTrace p) . ($ ())
  watch f  = TaskT $ ContT $ return . WATCH f
  signal e = TaskT $ ContT $ return . SIGNAL e . ($ ())