module MonadLib (
Id, ReaderT, WriterT, StateT, ExceptionT, ContT,
MonadT(..), BaseM(..),
ReaderM(..), WriterM(..), StateM(..), ExceptionM(..), ContM(..),
Label, labelCC, jump,
runId, runReaderT, runWriterT, runStateT, runExceptionT, runContT,
RunReaderM(..), RunWriterM(..), RunStateM(..), RunExceptionM(..),
version,
module Control.Monad
) where
import Control.Applicative (Applicative(..), Alternative(..))
import Control.Monad
import Control.Monad.Fix
import Data.Monoid
-- | Version of this library, as a @(major, minor, patch)@ triple.
version :: (Int, Int, Int)
version = (major, minor, patch)
  where
    major = 3
    minor = 0
    patch = 0
-- | The identity monad: wraps a plain value with no extra effects.
newtype Id a = I a
-- | Reader transformer: a computation in @m@ that can consult an
-- environment of type @i@.
newtype ReaderT i m a = R (i -> m a)
-- | Writer transformer: a computation in @m@ that returns its result
-- paired with accumulated output of type @i@.
newtype WriterT i m a = W (m (a,i))
-- | State transformer: a computation in @m@ that takes a state of type @i@
-- and returns its result paired with the updated state.
newtype StateT i m a = S (i -> m (a,i))
-- | Exception transformer: a computation in @m@ that ends either with an
-- exception of type @i@ ('Left') or with a result ('Right').
newtype ExceptionT i m a = X (m (Either i a))
-- | Continuation transformer: given a continuation producing the final
-- answer of type @m i@, produces that final answer.
newtype ContT i m a = C ((a -> m i) -> m i)
-- | Extract the value computed in the identity monad.
runId :: Id a -> a
runId m = case m of
  I result -> result
-- | Run a reader computation, supplying @env@ as its environment.
runReaderT :: i -> ReaderT i m a -> m a
runReaderT env m = case m of
  R withEnv -> withEnv env
-- | Run a writer computation, returning its result paired with the output
-- it produced.
runWriterT :: WriterT i m a -> m (a,i)
runWriterT m = case m of
  W inner -> inner
-- | Run a state computation from the initial state @s0@, returning its
-- result paired with the final state.
runStateT :: i -> StateT i m a -> m (a,i)
runStateT s0 m = case m of
  S step -> step s0
-- | Run a computation that may raise an exception: 'Left' carries the
-- exception, 'Right' the normal result.
runExceptionT :: ExceptionT i m a -> m (Either i a)
runExceptionT m = case m of
  X outcome -> outcome
-- | Run a continuation computation, passing its result to the final
-- continuation @k@.
runContT :: (a -> m i) -> ContT i m a -> m i
runContT k m = case m of
  C body -> body k
-- | Monad transformers: 'lift' promotes a computation of the underlying
-- monad @m@ into the transformed monad @t m@.
class MonadT t where
  lift :: (Monad m) => m a -> t m a

-- | The environment is ignored.
instance MonadT (ReaderT i) where
  lift m = R (const m)

-- | The incoming state is passed through unchanged.
instance MonadT (StateT i) where
  lift m = S (\s -> m >>= \a -> return (a, s))

-- | The lifted computation contributes empty output.
instance (Monoid i) => MonadT (WriterT i) where
  lift m = W (m >>= \a -> return (a, mempty))

-- | The lifted computation never raises an exception.
instance MonadT (ExceptionT i) where
  lift m = X (m >>= \a -> return (Right a))

-- | The result is fed directly to the current continuation.
instance MonadT (ContT i) where
  lift m = C (\k -> m >>= k)
-- | Monads built on a base monad @n@: 'inBase' runs a computation of the
-- base monad inside @m@. The functional dependency makes @n@ determined
-- by @m@.
class (Monad m, Monad n) => BaseM m n | m -> n where
  inBase :: n a -> m a

-- Base monads are their own base.
instance BaseM IO IO where inBase b = b
instance BaseM Maybe Maybe where inBase b = b
instance BaseM [] [] where inBase b = b
instance BaseM Id Id where inBase b = b

-- Each transformer reaches its base by first running the computation in the
-- base of the monad it wraps, then lifting it one layer.
instance (BaseM m n) => BaseM (ReaderT i m) n where
  inBase b = lift (inBase b)
instance (BaseM m n) => BaseM (StateT i m) n where
  inBase b = lift (inBase b)
instance (BaseM m n, Monoid i) => BaseM (WriterT i m) n where
  inBase b = lift (inBase b)
instance (BaseM m n) => BaseM (ExceptionT i m) n where
  inBase b = lift (inBase b)
instance (BaseM m n) => BaseM (ContT i m) n where
  inBase b = lift (inBase b)
-- | The identity monad: binding simply applies the continuation to the
-- wrapped value. 'fail' raises an 'error' with the given message.
instance Monad Id where
  return = I
  fail msg = error msg
  -- Matching on a newtype constructor is irrefutable, so this is as lazy
  -- as projecting the value out with 'runId'.
  I a >>= k = k a
-- | Both halves of a bind see the same environment.
instance (Monad m) => Monad (ReaderT i m) where
  return = lift . return
  fail = lift . fail
  m >>= k = R $ \env -> do
    a <- runReaderT env m
    runReaderT env (k a)
-- | The state produced by the first computation is fed to the second.
-- The result pair is matched lazily, so the bind does not force it.
instance (Monad m) => Monad (StateT i m) where
  return = lift . return
  fail = lift . fail
  m >>= k = S $ \s -> do
    ~(a, s') <- runStateT s m
    runStateT s' (k a)
-- | Output of the two computations is combined with 'mappend', first
-- computation's output on the left. Result pairs are matched lazily.
instance (Monad m, Monoid i) => Monad (WriterT i m) where
  return = lift . return
  fail = lift . fail
  m >>= k = W $ do
    ~(a, w1) <- runWriterT m
    ~(b, w2) <- runWriterT (k a)
    return (b, w1 `mappend` w2)
-- | An exception ('Left') from the first computation short-circuits the
-- bind; otherwise the continuation runs on the result.
instance (Monad m) => Monad (ExceptionT i m) where
  return = lift . return
  fail = lift . fail
  m >>= k = X (runExceptionT m >>= either (return . Left) (runExceptionT . k))
-- | Binding runs @m@ with a continuation that passes its result to @k@ and
-- then on to the outer continuation.
instance (Monad m) => Monad (ContT i m) where
  return = lift . return
  fail = lift . fail
  m >>= k = C (\c -> runContT (runContT c . k) m)
-- Functor instances ----------------------------------------------------------
-- Each 'fmap' is defined via the corresponding 'Monad' instance.

instance Functor Id where fmap = liftM
instance (Monad m) => Functor (ReaderT i m) where fmap = liftM
instance (Monad m) => Functor (StateT i m) where fmap = liftM
instance (Monad m, Monoid i) => Functor (WriterT i m) where fmap = liftM
instance (Monad m) => Functor (ExceptionT i m) where fmap = liftM
instance (Monad m) => Functor (ContT i m) where fmap = liftM

-- Applicative instances ------------------------------------------------------
-- Since GHC 7.10 (the Applicative-Monad Proposal), every 'Monad' instance
-- needs an 'Applicative' instance. Defining these via 'return' and 'ap'
-- keeps them consistent with the existing 'Monad' instances. They also
-- compile on older compilers, where 'Applicative' comes from
-- "Control.Applicative".

instance Applicative Id where
  pure  = return
  (<*>) = ap
instance (Monad m) => Applicative (ReaderT i m) where
  pure  = return
  (<*>) = ap
instance (Monad m) => Applicative (StateT i m) where
  pure  = return
  (<*>) = ap
instance (Monad m, Monoid i) => Applicative (WriterT i m) where
  pure  = return
  (<*>) = ap
instance (Monad m) => Applicative (ExceptionT i m) where
  pure  = return
  (<*>) = ap
instance (Monad m) => Applicative (ContT i m) where
  pure  = return
  (<*>) = ap

-- MonadFix instances ---------------------------------------------------------

-- | Value recursion in the identity monad is plain lazy recursion.
instance MonadFix Id where
  mfix f = let m = f (runId m) in m
-- | The environment is fixed; recursion is delegated to the inner monad.
instance (MonadFix m) => MonadFix (ReaderT i m) where
  mfix f = R $ \r -> mfix (runReaderT r . f)
-- | Only the result component (not the state) is fed back into @f@.
instance (MonadFix m) => MonadFix (StateT i m) where
  mfix f = S $ \s -> mfix (runStateT s . f . fst)
-- | Only the result component (not the output) is fed back into @f@.
instance (MonadFix m, Monoid i) => MonadFix (WriterT i m) where
  mfix f = W $ mfix (runWriterT . f . fst)
-- | If the recursive computation raises an exception, any attempt to use the
-- fed-back value is an 'error'.
instance (MonadFix m) => MonadFix (ExceptionT i m) where
  mfix f = X $ mfix (runExceptionT . f . fromRight)
    where fromRight (Right a) = a
          fromRight _         = error "ExceptionT: mfix looped."

-- MonadPlus instances --------------------------------------------------------
-- Choice is delegated to the inner monad's 'mplus', with both alternatives
-- started from the same environment (ReaderT) or state (StateT).

instance (MonadPlus m) => MonadPlus (ReaderT i m) where
  mzero = lift mzero
  mplus (R m) (R n) = R (\r -> mplus (m r) (n r))
instance (MonadPlus m) => MonadPlus (StateT i m) where
  mzero = lift mzero
  mplus (S m) (S n) = S (\s -> mplus (m s) (n s))
instance (MonadPlus m, Monoid i) => MonadPlus (WriterT i m) where
  mzero = lift mzero
  mplus (W m) (W n) = W (mplus m n)
instance (MonadPlus m) => MonadPlus (ExceptionT i m) where
  mzero = lift mzero
  mplus (X m) (X n) = X (mplus m n)

-- Alternative instances ------------------------------------------------------
-- Since GHC 7.10, 'MonadPlus' has 'Alternative' as a superclass. These
-- instances simply reuse the 'MonadPlus' operations above.

instance (MonadPlus m) => Alternative (ReaderT i m) where
  empty = mzero
  (<|>) = mplus
instance (MonadPlus m) => Alternative (StateT i m) where
  empty = mzero
  (<|>) = mplus
instance (MonadPlus m, Monoid i) => Alternative (WriterT i m) where
  empty = mzero
  (<|>) = mplus
instance (MonadPlus m) => Alternative (ExceptionT i m) where
  empty = mzero
  (<|>) = mplus
-- | Monads that can read an environment of type @i@.
class (Monad m) => ReaderM m i | m -> i where
  -- | Obtain the current environment.
  ask :: m i

instance (Monad m) => ReaderM (ReaderT i m) i where
  ask = R (\env -> return env)

-- Other transformers forward 'ask' to the monad they wrap.
instance (ReaderM m j, Monoid i) => ReaderM (WriterT i m) j where
  ask = lift ask
instance (ReaderM m j) => ReaderM (StateT i m) j where
  ask = lift ask
instance (ReaderM m j) => ReaderM (ExceptionT i m) j where
  ask = lift ask
instance (ReaderM m j) => ReaderM (ContT i m) j where
  ask = lift ask
-- | Monads that can emit output of type @i@.
class (Monad m) => WriterM m i | m -> i where
  -- | Emit a piece of output.
  put :: i -> m ()

instance (Monad m, Monoid i) => WriterM (WriterT i m) i where
  put w = W (return ((), w))

-- Other transformers forward 'put' to the monad they wrap.
instance (WriterM m j) => WriterM (ReaderT i m) j where
  put w = lift (put w)
instance (WriterM m j) => WriterM (StateT i m) j where
  put w = lift (put w)
instance (WriterM m j) => WriterM (ExceptionT i m) j where
  put w = lift (put w)
instance (WriterM m j) => WriterM (ContT i m) j where
  put w = lift (put w)
-- | Monads with a state of type @i@ that can be read and replaced.
class (Monad m) => StateM m i | m -> i where
  -- | Read the current state.
  get :: m i
  -- | Replace the current state.
  set :: i -> m ()

instance (Monad m) => StateM (StateT i m) i where
  get   = S (\s -> return (s, s))
  set s = S (const (return ((), s)))

-- Other transformers forward both operations to the monad they wrap.
instance (StateM m j) => StateM (ReaderT i m) j where
  get   = lift get
  set x = lift (set x)
instance (StateM m j, Monoid i) => StateM (WriterT i m) j where
  get   = lift get
  set x = lift (set x)
instance (StateM m j) => StateM (ExceptionT i m) j where
  get   = lift get
  set x = lift (set x)
instance (StateM m j) => StateM (ContT i m) j where
  get   = lift get
  set x = lift (set x)
-- | Monads that can raise exceptions of type @i@.
class (Monad m) => ExceptionM m i | m -> i where
  -- | Abort the current computation with the given exception.
  raise :: i -> m a

instance (Monad m) => ExceptionM (ExceptionT i m) i where
  raise e = X (return (Left e))

-- Other transformers forward 'raise' to the monad they wrap.
instance (ExceptionM m j) => ExceptionM (ReaderT i m) j where
  raise e = lift (raise e)
instance (ExceptionM m j, Monoid i) => ExceptionM (WriterT i m) j where
  raise e = lift (raise e)
instance (ExceptionM m j) => ExceptionM (StateT i m) j where
  raise e = lift (raise e)
instance (ExceptionM m j) => ExceptionM (ContT i m) j where
  raise e = lift (raise e)
-- | Monads supporting call-with-current-continuation.
class Monad m => ContM m where
  -- | @callCC f@ passes @f@ an escape function; applying it to a value makes
  -- the whole @callCC@ expression return that value.
  callCC :: ((a -> m b) -> m a) -> m a

-- | The environment inside the body is the one current at 'callCC'.
instance (ContM m) => ContM (ReaderT i m) where
  callCC f = R (\r -> callCC (\k -> runReaderT r (f (lift . k))))

-- | Escaping resumes with the state @s@ that was current when 'callCC' was
-- entered, not the state at the point of the jump.
instance (ContM m) => ContM (StateT i m) where
  callCC f = S $ \s ->
    let body k = runStateT s (f (\a -> lift (k (a, s))))
    in callCC body

-- | Escaping reports 'mempty' as the body's output, so output written inside
-- the body before the jump is dropped.
instance (ContM m, Monoid i) => ContM (WriterT i m) where
  callCC f = W (callCC (\k -> runWriterT (f (\a -> lift (k (a, mempty))))))

-- | Escaping delivers the value as a normal ('Right') result.
instance (ContM m) => ContM (ExceptionT i m) where
  callCC f = X (callCC (\k -> runExceptionT (f (lift . k . Right))))

-- | Native implementation: the escape function ignores its own continuation
-- and jumps to the one captured by 'callCC'.
instance (Monad m) => ContM (ContT i m) where
  callCC f = C (\k -> runContT k (f (\a -> C (const (k a)))))
-- | Reader monads whose environment can be replaced for a sub-computation.
class (ReaderM m i) => RunReaderM m i | m -> i where
  -- | @local env m@ runs @m@ with @env@ as its environment.
  local :: i -> m a -> m a

instance (Monad m) => RunReaderM (ReaderT i m) i where
  local env m = R (const (runReaderT env m))

-- Other transformers unwrap themselves and forward 'local' to the inner monad.
instance (RunReaderM m j, Monoid i) => RunReaderM (WriterT i m) j where
  local env m = W (local env (runWriterT m))
instance (RunReaderM m j) => RunReaderM (StateT i m) j where
  local env m = S (\s -> local env (runStateT s m))
instance (RunReaderM m j) => RunReaderM (ExceptionT i m) j where
  local env m = X (local env (runExceptionT m))
-- | Writer monads whose output can be captured.
class WriterM m i => RunWriterM m i | m -> i where
  -- | Run a computation and return its result together with the output it
  -- produced.
  collect :: m a -> m (a,i)

instance (RunWriterM m j) => RunWriterM (ReaderT i m) j where
  collect m = R (\r -> collect (runReaderT r m))

-- | The collected output is returned as a value; the lifted result adds only
-- 'mempty' to the surrounding output.
instance (Monad m, Monoid i) => RunWriterM (WriterT i m) i where
  collect m = lift (runWriterT m)

instance (RunWriterM m j) => RunWriterM (StateT i m) j where
  collect m = S (\s -> liftM reassoc (collect (runStateT s m)))
    -- ((result, state), output) -> ((result, output), state); the inner
    -- pair is only projected lazily.
    where reassoc (p, w) = ((fst p, w), snd p)

instance (RunWriterM m j) => RunWriterM (ExceptionT i m) j where
  collect m = X (liftM attach (collect (runExceptionT m)))
    -- An exception drops the collected output; a result is paired with it.
    where attach (r, w) = either Left (\a -> Right (a, w)) r
-- | State monads that can run a sub-computation from a given state.
class (StateM m i) => RunStateM m i | m -> i where
  -- | @runS s m@ runs @m@ starting from state @s@ and returns its result
  -- together with the final state.
  runS :: i -> m a -> m (a,i)

instance (RunStateM m j) => RunStateM (ReaderT i m) j where
  runS s m = R (\r -> runS s (runReaderT r m))

instance (RunStateM m j, Monoid i) => RunStateM (WriterT i m) j where
  runS s m = W (liftM reassoc (runS s (runWriterT m)))
    -- ((result, output), state) -> ((result, state), output); the inner
    -- pair is only projected lazily.
    where reassoc (p, st) = ((fst p, st), snd p)

-- | The sub-computation runs in the inner monad, so the surrounding state is
-- left untouched.
instance (Monad m) => RunStateM (StateT i m) i where
  runS s0 m = S (\s -> liftM (\r -> (r, s)) (runStateT s0 m))

instance (RunStateM m j) => RunStateM (ExceptionT i m) j where
  runS s m = X (liftM attach (runS s (runExceptionT m)))
    -- An exception drops the final state; a result is paired with it.
    where attach (r, st) = either Left (\a -> Right (a, st)) r
-- | Exception monads whose exceptions can be caught.
class ExceptionM m i => RunExceptionM m i | m -> i where
  -- | Run a computation, returning a raised exception as 'Left' and a normal
  -- result as 'Right'.
  try :: m a -> m (Either i a)

instance (RunExceptionM m i) => RunExceptionM (ReaderT j m) i where
  try m = R (\r -> try (runReaderT r m))

-- | When an exception is caught, the output of the failed computation is
-- replaced by 'mempty'.
instance (RunExceptionM m i, Monoid j) => RunExceptionM (WriterT j m) i where
  try m = W (liftM reshape (try (runWriterT m)))
    where reshape r = either (\e -> (Left e, mempty))
                             (\p -> (Right (fst p), snd p))
                             r

-- | When an exception is caught, the state is reset to the one current at
-- 'try'.
instance (RunExceptionM m i) => RunExceptionM (StateT j m) i where
  try m = S (\s -> liftM (reshape s) (try (runStateT s m)))
    where reshape s0 r = either (\e -> (Left e, s0))
                                (\p -> (Right (fst p), snd p))
                                r

-- | The outcome is obtained by running the computation in the inner monad, so
-- 'try' itself never raises.
instance (Monad m) => RunExceptionM (ExceptionT i m) i where
  try m = X (liftM Right (runExceptionT m))
-- | A re-entry point created by 'labelCC'. It wraps a continuation that
-- expects a new value together with the label itself, so that 'jump' can
-- hand the label back when re-entering.
newtype Label m a = L ((a, Label m a) -> m ())
-- | Mark the current point as a label. The first time through it returns
-- @(x, label)@; a later @'jump' y label@ makes it return @(y, label)@ again.
labelCC :: (ContM m) => a -> m (a, Label m a)
labelCC x = callCC capture
  where capture k = return (x, L k)
-- | Re-enter the point where a label was created with 'labelCC', making it
-- return the new value @x@ (together with the same label).
jump :: (ContM m) => a -> Label m a -> m b
jump x lbl@(L k) = k (x, lbl) >> return unreachable
  where
    -- Invoking the continuation is expected never to return here; the
    -- message marks reaching this value as a bug.
    unreachable = error "(bug) jump: unreachable"