-- |
-- Module:     Control.ContStuff
-- Copyright:  (c) 2010 Ertugrul Soeylemez
-- License:    BSD3
-- Maintainer: Ertugrul Soeylemez
-- Stability:  experimental
--
-- This module implements a number of monad transformers using a CPS
-- approach internally.

{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}

module Control.ContStuff
    ( -- * The identity monad
      Id(..)

      -- * Monad transformers
      -- ** Identity transformer
    , IdT(..)

      -- ** Basic CPS monads
    , Cont, runCont, evalCont, modifyCont
    , ContT(..), runContT, evalContT, modifyContT

      -- ** Advanced CPS monads
      -- *** Choice/nondeterminism
    , ChoiceT(..), runChoiceT, findFirst, findAll, listChoiceT, listA

      -- *** State
    , State, runState, evalState, execState
    , StateT(..), runStateT, evalStateT, execStateT

      -- ** Writer monads
    , OldWriter, runOldWriter, evalOldWriter, execOldWriter
    , OldWriterT, runOldWriterT, evalOldWriterT, execOldWriterT
    , WriterT, runWriterT

      -- * Effect classes
    , Abortable(..)
    , CallCC(..), Label, labelCC, goto
    , LiftBase(..), io
    , Runnable(..)
    , Stateful(..), getField, modify, modifyField, modifyFieldLazy, modifyLazy
    , Transformer(..)
    , Writable(..)

      -- * Module reexports
    , module Control.Applicative
    , module Control.Monad
    ) where

import Control.Applicative
import Control.Arrow
import Control.Monad
import Control.Monad.Fix
import Control.Monad.ST
import Data.Monoid


-- ================== --
-- The identity monad --
-- ================== --

-- | The identity monad.  This monad represents values themselves,
-- i.e. computations without effects.
newtype Id a = Id { getId :: a }

instance Functor Id where
    fmap f (Id x) = Id (f x)

instance Applicative Id where
    pure = Id
    Id f <*> Id x = Id (f x)

instance Monad Id where
    -- Canonical definition: 'return' must agree with 'pure'.
    return = pure
    Id x >>= f = f x

instance MonadFix Id where
    -- 'Id' is a newtype, so 'getId' never forces its argument and the
    -- knot can be tied with plain 'fix'.
    mfix f = fix (f . getId)


-- ================== --
-- Monad transformers --
-- ================== --

-------------
-- ChoiceT --
-------------

-- | The choice monad transformer, which models, as the most common
-- interpretation, nondeterminism.
-- Internally a list of choices is represented as a CPS-based left-fold
-- function: the folding function receives the accumulator, one choice
-- and a continuation that processes the remaining choices.
newtype ChoiceT r i m a =
    ChoiceT {
      getChoiceT :: (i -> a -> (i -> m r) -> m r) -> i -> (i -> m r) -> m r
    }

instance Alternative (ChoiceT r i m) where
    -- No choices: hand the accumulator straight to the continuation.
    empty = ChoiceT $ \_ z k -> k z
    -- Fold over the left choices first, then continue with the right ones.
    ChoiceT c <|> ChoiceT d =
        ChoiceT $ \fold z k -> c fold z (\zc -> d fold zc k)

instance Applicative (ChoiceT r i m) where
    pure x = ChoiceT $ \fold z k -> fold z x k
    -- Fold over the functions in the outer loop and over the arguments in
    -- the inner loop.  This makes '<*>' agree with 'ap' (same order of
    -- results) and runs lifted effects of the left operand before those
    -- of the right operand.
    ChoiceT cf <*> ChoiceT cx =
        ChoiceT $ \fold z k ->
            cf (\zf f kf -> cx (\zx x kx -> fold zx (f x) kx) zf kf) z k

instance Functor (ChoiceT r i m) where
    fmap f (ChoiceT c) =
        ChoiceT $ \fold z k -> c (\x y kf -> fold x (f y) kf) z k

instance Monad (ChoiceT r i m) where
    return = pure
    ChoiceT c >>= f =
        ChoiceT $ \fold z k -> c (\x y kc -> getChoiceT (f y) fold x kc) z k

instance Transformer (ChoiceT r i) where
    lift c = ChoiceT $ \fold z k -> c >>= \x -> fold z x k

-- | Run a choice computation with the given folding function, initial
-- accumulator and final continuation.
runChoiceT ::
    (i -> a -> (i -> m r) -> m r) -> i -> (i -> m r) -> ChoiceT r i m a -> m r
runChoiceT fold z k (ChoiceT c) = c fold z k

-- | Find the first solution.  The folding function ignores the
-- continuation, so the remaining choices are never explored.
findFirst :: (Alternative f, Applicative m) => ChoiceT (f a) (f a) m a -> m (f a)
findFirst = runChoiceT (\_ y _ -> pure (pure y)) empty pure

-- | Find all solutions, in order.
findAll :: (Alternative f, Applicative m) => ChoiceT (f a) (f a) m a -> m (f a)
findAll = runChoiceT (\x y k -> k (x <|> pure y)) empty pure

-- | Get list of solutions (faster than 'findAll', but returns solutions
-- in reversed order).
listChoiceT :: Applicative m => ChoiceT [a] [a] m a -> m [a]
listChoiceT = runChoiceT (\x y k -> k (y:x)) [] pure

-- | Turn a list into a computation with alternatives.
listA :: Alternative f => [a] -> f a
listA = foldr (<|>) empty . map pure


-----------
-- ContT --
-----------

-- | The continuation passing style monad transformer.  This monad
-- transformer models the most basic form of CPS.
newtype ContT r m a = ContT { getContT :: (a -> m r) -> m r }

instance Applicative m => Abortable (ContT r m) where
    type Result (ContT r m) = r
    -- Ignore the continuation and return the given final result.
    abort = ContT . const . pure

instance (Alternative m, Monad m) => Alternative (ContT r m) where
    empty = ContT $ const empty
    ContT c <|> ContT d = ContT $ \k -> c k <|> d k

instance Applicative (ContT r m) where
    pure x = ContT $ \k -> k x
    ContT cf <*> ContT cx = ContT $ \k -> cf (\f -> cx (\x -> k (f x)))

instance CallCC (ContT r m) where
    -- The escape function discards its own continuation and jumps to @k@.
    callCC f = ContT $ \k -> getContT (f (ContT . const . k)) k

instance Functor (ContT r m) where
    fmap f (ContT c) = ContT $ \k -> c (\x -> k (f x))

instance Monad (ContT r m) where
    return = pure
    ContT c >>= f = ContT $ \k -> c (\x -> getContT (f x) k)

instance Runnable (ContT r) r m a where
    type Argument (ContT r) r m a = a -> m r
    runT k (ContT c) = c k

instance Transformer (ContT r) where
    lift c = ContT $ \k -> c >>= k

-- | Each 'tell' yields its argument as an additional final result,
-- before the results of the rest of the computation.
instance Alternative m => Writable (ContT r m) r where
    tell x = ContT $ \k -> pure x <|> k ()

-- | Traditional writer: the log lives in the second component of the
-- final result.  @x@ is prepended to the log produced by the rest of the
-- computation, so @tell a >> tell b@ logs @a `mappend` b@.
instance (Functor m, Monoid w) => Writable (ContT (r, w) m) w where
    tell x = ContT $ \k -> fmap (second (x `mappend`)) (k ())

-- | Run a CPS-style computation given the supplied final continuation.
runContT :: (a -> m r) -> ContT r m a -> m r
runContT k (ContT c) = c k

-- | Evaluate a CPS-style computation to its final result.
evalContT :: Applicative m => ContT r m r -> m r
evalContT (ContT c) = c pure

-- | Transform the final result along the way.
modifyContT :: Functor m => (r -> r) -> ContT r m ()
modifyContT f = ContT $ \k -> fmap f (k ())

-- | Pure CPS monad derived from ContT.
type Cont r a = ContT r Id a

-- | Run a pure CPS computation.
runCont :: (a -> r) -> Cont r a -> r
runCont k (ContT c) = getId $ c (Id . k)

-- | Evaluate a pure CPS computation to its final result.
evalCont :: Cont r r -> r
evalCont (ContT c) = getId $ c pure

-- | Modify the result of a CPS computation along the way.
modifyCont :: (r -> r) -> Cont r ()
modifyCont = modifyContT


---------
-- IdT --
---------

-- | The identity monad transformer.  This monad transformer represents
-- computations themselves without further side effects.  Unlike most
-- other monad transformers in this module it is not implemented in
-- terms of continuation passing style.
newtype IdT m a = IdT { getIdT :: m a }

instance Alternative m => Alternative (IdT m) where
    empty = IdT empty
    IdT c <|> IdT d = IdT (c <|> d)

instance Applicative m => Applicative (IdT m) where
    pure = IdT . pure
    IdT cf <*> IdT cx = IdT $ cf <*> cx

instance Functor m => Functor (IdT m) where
    fmap f (IdT c) = IdT (fmap f c)

instance Monad m => Monad (IdT m) where
    return = IdT . return
    IdT c >>= f = IdT $ c >>= getIdT . f

instance MonadFix m => MonadFix (IdT m) where
    mfix f = IdT $ mfix (getIdT . f)

-- Running an 'IdT' computation needs no argument; it just unwraps it.
instance Runnable IdT r m r where
    type Argument IdT r m r = ()
    runT _ (IdT c) = c

instance Transformer IdT where
    lift = IdT


----------------
-- OldWriterT --
----------------

-- | The traditional writer monad transformer.  The log is carried in the
-- second component of the final result of a 'ContT'.
type OldWriterT r w m a = ContT (r, w) m a

-- | Run a traditional writer transformer.
runOldWriterT :: (Applicative m, Monoid w) => OldWriterT r w m r -> m (r, w)
runOldWriterT (ContT c) = c (\x -> pure (x, mempty))

-- | Run a traditional writer transformer and return its result.
evalOldWriterT :: (Applicative m, Monoid w) => OldWriterT r w m r -> m r
evalOldWriterT = fmap fst . runOldWriterT

-- | Run a traditional writer transformer and return its log.
execOldWriterT :: (Applicative m, Monoid w) => OldWriterT r w m r -> m w
execOldWriterT = fmap snd . runOldWriterT

-- | The traditional writer monad.
type OldWriter r w a = ContT (r, w) Id a

-- | Run a traditional writer computation.
runOldWriter :: Monoid w => OldWriter r w r -> (r, w)
runOldWriter = getId . runOldWriterT

-- | Run a traditional writer computation and return its result.
evalOldWriter :: Monoid w => OldWriter r w r -> r
evalOldWriter = fst . getId . runOldWriterT

-- | Run a traditional writer computation and return its log.
execOldWriter :: Monoid w => OldWriter r w r -> w
execOldWriter = snd . getId . runOldWriterT


------------
-- StateT --
------------

-- | The state monad transformer in continuation passing style.  A
-- computation receives the current state and a continuation that takes
-- the new state together with the computation's value.
newtype StateT r s m a =
    StateT { getStateT :: s -> (s -> a -> m r) -> m r }

instance Applicative m => Abortable (StateT r s m) where
    type Result (StateT r s m) = r
    -- Ignore state and continuation and return the final result.
    abort x = StateT $ \_ _ -> pure x

instance Alternative m => Alternative (StateT r s m) where
    empty = StateT . const . const $ empty
    -- Both branches start from the same state.
    StateT c <|> StateT d = StateT $ \s0 k -> c s0 k <|> d s0 k

instance Applicative (StateT r s m) where
    pure x = StateT $ \s0 k -> k s0 x
    StateT cf <*> StateT cx =
        StateT $ \s0 k -> cf s0 (\s1 f -> cx s1 (\s2 x -> k s2 (f x)))

instance CallCC (StateT r s m) where
    -- The escape function resumes @k@ with the state current at the
    -- time of the jump.
    callCC f =
        StateT $ \s0 k -> getStateT (f (\x -> StateT $ \s1 _ -> k s1 x)) s0 k

instance Functor (StateT r s m) where
    fmap f (StateT c) = StateT $ \s0 k -> c s0 (\s1 -> k s1 . f)

instance Monad (StateT r s m) where
    return = pure
    StateT c >>= f = StateT $ \s0 k -> c s0 (\s1 x -> getStateT (f x) s1 k)

instance Runnable (StateT r s) r m a where
    type Argument (StateT r s) r m a = (s, s -> a -> m r)
    runT (s0, k) (StateT c) = c s0 k

instance Stateful (StateT r s m) where
    type StateOf (StateT r s m) = s
    get = StateT $ \s0 k -> k s0 s0
    put s1 = s1 `seq` StateT $ \_ k -> k s1 ()
    putLazy s1 = StateT $ \_ k -> k s1 ()

instance Transformer (StateT r s) where
    lift c = StateT $ \s0 k -> c >>= k s0

-- | Each 'tell' yields its argument as an additional final result,
-- before the results of the rest of the computation.
instance Alternative m => Writable (StateT r s m) r where
    tell x = StateT $ \s0 k -> pure x <|> k s0 ()

-- | Traditional writer: @x@ is prepended to the log produced by the rest
-- of the computation, so @tell a >> tell b@ logs @a `mappend` b@.
instance (Functor m, Monoid w) => Writable (StateT (r, w) s m) w where
    tell x = StateT $ \s0 k -> fmap (second (x `mappend`)) (k s0 ())

-- | Run a state transformer with the given initial state and final
-- continuation.
runStateT :: s -> (s -> a -> m r) -> StateT r s m a -> m r
runStateT s0 k (StateT c) = c s0 k

-- | Run a state transformer returning its result.
evalStateT :: Applicative m => s -> StateT r s m r -> m r
evalStateT s0 (StateT c) = c s0 (\_ x -> pure x)

-- | Run a state transformer returning its final state.
execStateT :: Applicative m => s -> StateT s s m a -> m s
execStateT s0 (StateT c) = c s0 (\s1 _ -> pure s1)

-- | Pure state monad derived from StateT.
type State r s a = StateT r s Id a

-- | Run a stateful computation with the given initial state and final
-- continuation.
runState :: s -> (s -> a -> r) -> State r s a -> r
runState s0 k c = getId $ runStateT s0 (\s1 -> Id . k s1) c

-- | Run a stateful computation returning its result.
evalState :: s -> State r s r -> r
evalState s0 = getId . evalStateT s0

-- | Run a stateful computation returning its final state.
execState :: s -> State s s a -> s
execState s0 = getId . execStateT s0


-------------
-- WriterT --
-------------

-- | The writer monad transformer.  Supports logging effects.
type WriterT = ContT

-- | Run a writer transformer.  The final continuation yields 'empty', so
-- the result consists only of what was emitted along the way (e.g. via
-- the 'Alternative'-based 'Writable' instance of 'ContT').
runWriterT :: Alternative m => WriterT r m a -> m r
runWriterT (ContT c) = c (const empty)


-- ============== --
-- Effect classes --
-- ============== --

-- | Monads supporting abortion.
class Abortable m where
    -- | Type of the final result an aborted computation yields.
    type Result m
    -- | Abort the computation with the given final result.
    abort :: Result m -> m a

-- | Monads supporting /call-with-current-continuation/ (aka callCC).
class CallCC m where
    -- | Call with current continuation.
    callCC :: ((a -> m b) -> m a) -> m a

-- | A captured continuation that can be jumped to with 'goto'.
newtype Label m a = Label (a -> Label m a -> m ())

-- | Capture the current continuation for later use.
labelCC :: (Applicative m, CallCC m) => a -> m (a, Label m a)
labelCC x = callCC $ \k -> pure (x, Label $ curry k)

-- | Jump to a label.
goto :: Applicative m => Label m a -> a -> m b
goto lk@(Label k) x = k x lk *> pure unreachable
  where
    -- The jump is expected not to return (escape continuations from
    -- 'callCC' discard their own continuation), so this value is never
    -- demanded; if it is, fail with a descriptive message.
    unreachable = error "Control.ContStuff.goto: label continuation returned"

-- | Monads which support lifting base monad computations.
class LiftBase m a where
    -- | The type of base monad computations that can be lifted into @m@.
    type Base m a
    -- | Promote a base monad computation into @m@.
    base :: Base m a -> m a

-- Base monads: the base computation is the computation itself.
instance LiftBase IO a where
    type Base IO a = IO a
    base = id

instance LiftBase Id a where
    type Base Id a = Id a
    base = id

instance LiftBase Maybe a where
    type Base Maybe a = Maybe a
    base = id

instance LiftBase (ST s) a where
    type Base (ST s) a = ST s a
    base = id

instance LiftBase [] a where
    type Base [] a = [a]
    base = id

instance LiftBase ((->) r) a where
    type Base ((->) r) a = r -> a
    base = id

-- Transformers: lift the underlying monad's base computation one level.
instance (LiftBase m a, Monad m) => LiftBase (IdT m) a where
    type Base (IdT m) a = Base m a
    base c = lift (base c)

instance (LiftBase m a, Monad m) => LiftBase (ChoiceT r i m) a where
    type Base (ChoiceT r i m) a = Base m a
    base c = lift (base c)

instance (LiftBase m a, Monad m) => LiftBase (ContT r m) a where
    type Base (ContT r m) a = Base m a
    base c = lift (base c)

instance (LiftBase m a, Monad m) => LiftBase (StateT r s m) a where
    type Base (StateT r s m) a = Base m a
    base c = lift (base c)

-- | Handy alias for 'base' when the base monad is 'IO'.
io :: (LiftBase m a, Base m a ~ IO a) => Base m a -> m a
io c = base c

-- | Monad transformers @t@ whose computations @t m a@ can be run down to
-- @m r@, given an argument of type @'Argument' t r m a@.
class Runnable t r m a where
    -- | Extra input needed to run the transformer.
    type Argument t r m a
    -- | Run a transformer computation in the underlying monad.
    runT :: Argument t r m a -> t m a -> m r

-- | Stateful monads.
class Stateful m where
    -- | The type of the state.
    type StateOf m

    -- | Get the current state.
    get :: m (StateOf m)

    -- | Set the current state and force it.
    put :: StateOf m -> m ()
    put new = new `seq` putLazy new

    -- | Set the current state, but don't force it.
    putLazy :: StateOf m -> m ()

-- State operations of the underlying monad are lifted through 'ContT'.
instance (Monad m, Stateful m) => Stateful (ContT r m) where
    type StateOf (ContT r m) = StateOf m
    get = lift get
    put new = lift (put new)
    putLazy new = lift (putLazy new)

-- | Get a certain field of the current state.
getField :: (Functor m, Stateful m) => (StateOf m -> a) -> m a
getField accessor = accessor <$> get

-- | Apply a function to the current state (storing the result strictly).
modify :: (Monad m, Stateful m) => (StateOf m -> StateOf m) -> m ()
modify f = put =<< liftM f get

-- | Read a field of the current state, build a new state from it and
-- store that state strictly.
modifyField ::
    (Monad m, Stateful m) => (StateOf m -> a) -> (a -> StateOf m) -> m ()
modifyField accessor f = put =<< liftM (f . accessor) get

-- | Read a field of the current state, build a new state from it and
-- store that state without forcing it.
modifyFieldLazy ::
    (Monad m, Stateful m) => (StateOf m -> a) -> (a -> StateOf m) -> m ()
modifyFieldLazy accessor f = putLazy =<< liftM (f . accessor) get

-- | Apply a function to the current state without forcing the result.
modifyLazy :: (Monad m, Stateful m) => (StateOf m -> StateOf m) -> m ()
modifyLazy f = putLazy =<< liftM f get

-- | The monad transformer class: promote a computation of the underlying
-- monad into the transformer.
class Transformer t where
    -- | Promote a monadic computation to the transformer.
    lift :: Monad m => m a -> t m a

-- | Monads with support for logging.  Traditionally these are called
-- /writer monads/.
class Writable m w where
    -- | Emit the given value into the computation's log/output.
    tell :: w -> m ()