module MonadLib (
  
  
  Id, Lift, IdT, ReaderT, WriterT,
  StateT,
  ExceptionT,
  
  
  ChoiceT, ContT,
  
  
  MonadT(..), BaseM(..),
  
  
  ReaderM(..), WriterM(..), StateM(..), ExceptionM(..), ContM(..), AbortM(..),
  Label, labelCC, jump,
  
  
  
  runId, runLift,
  runIdT, runReaderT, runWriterT,
  runStateT, runExceptionT, runContT,
  runChoiceT, findOne, findAll,
  
  
  RunReaderM(..), RunWriterM(..), RunExceptionM(..),
  
  asks, puts, sets, sets_, raises,
  mapReader, mapWriter, mapException,
  handle,
  
  version,
  module Control.Monad
) where
import Control.Applicative
import Control.Monad
import Control.Monad.Fix
import Control.Monad.ST (ST)
import qualified Control.Exception as IO (throwIO,try)
#ifdef USE_BASE3
import qualified Control.Exception as IO (Exception)
#else
import qualified Control.Exception as IO (SomeException)
#endif
import System.Exit(ExitCode,exitWith)
import Data.Monoid
import Prelude hiding (Ordering(..))
-- | Version of this library as a @(major, minor, patch)@ triple.
version :: (Int,Int,Int)
version = (major, minor, patch)
  where major = 3
        minor = 5
        patch = 2
-- | The identity monad: 'I' just wraps a value.
newtype Id a              = I a
-- | An identity monad built with @data@ rather than @newtype@, so
-- pattern matching on 'L' (as its bind does) forces the computation.
data Lift a               = L a
-- | Identity transformer: adds no effects to the underlying monad @m@.
newtype IdT m a           = IT (m a)
-- | Reader transformer: a computation that receives a context of type @i@.
newtype ReaderT i m a     = R (i -> m a)
-- | Writer transformer: the underlying computation yields the result
-- paired with the output of type @i@ (see 'P').
newtype WriterT i m a = W { unW :: m (P a i) }
-- | Result/output pair used by 'WriterT'; the output field is strict.
data P a i = P a !i
-- | State transformer: given an initial state of type @i@, produces the
-- result together with the final state.
newtype StateT     i m a  = S (i -> m (a,i))
-- | Exception transformer: the underlying computation yields either an
-- exception of type @i@ ('Left') or a result ('Right').
newtype ExceptionT i m a  = X (m (Either i a))
-- | Backtracking transformer, represented as an explicit search tree
-- that 'runChoiceT' explores left to right.
data ChoiceT m a          = NoAnswer                              -- ^ no results
                          | Answer a                              -- ^ exactly one result
                          | Choice (ChoiceT m a) (ChoiceT m a)    -- ^ results of the left tree, then the right
                          | ChoiceEff (m (ChoiceT m a))           -- ^ run an effect to obtain the rest of the tree
-- | Continuation transformer whose final answer has type @i@.
newtype ContT i m a  = C ((a -> m i) -> m i)
-- | Extract the value of an 'Id' computation.
runId         :: Id a -> a
runId (I x) = x

-- | Extract the value of a 'Lift' computation (forces the 'L' constructor).
runLift       :: Lift a -> a
runLift l = case l of
  L x -> x

-- | Unwrap an 'IdT' computation to the underlying monad.
runIdT        :: IdT m a -> m a
runIdT (IT inner) = inner

-- | Run a reader computation with the given context.
runReaderT    :: i -> ReaderT i m a -> m a
runReaderT env (R body) = body env

-- | Run a writer computation, returning the result and the collected output.
-- The lazy pattern avoids forcing the 'P' pair before its parts are demanded.
runWriterT :: (Monad m) => WriterT i m a -> m (a,i)
runWriterT (W inner) = inner >>= \ ~(P a w) -> return (a, w)

-- | Run a state computation from the given initial state, returning the
-- result and the final state.
runStateT     :: i -> StateT i m a -> m (a,i)
runStateT s0 (S step) = step s0

-- | Run an exception computation: 'Left' is an exception, 'Right' a result.
runExceptionT :: ExceptionT i m a -> m (Either i a)
runExceptionT (X inner) = inner
-- | Find the first answer of a search tree, if any, together with a tree
-- describing the remaining answers.  The left branch of a 'Choice' is
-- explored before the right one.
runChoiceT :: (Monad m) => ChoiceT m a -> m (Maybe (a,ChoiceT m a))
runChoiceT t = case t of
  NoAnswer    -> return Nothing
  Answer a    -> return (Just (a, NoAnswer))
  Choice l r  ->
    let resume (a, rest) = return (Just (a, Choice rest r))
    in  runChoiceT l >>= maybe (runChoiceT r) resume
  ChoiceEff m -> m >>= runChoiceT

-- | The first answer of a search, if there is one.
findOne :: (Monad m) => ChoiceT m a -> m (Maybe a)
findOne m = runChoiceT m >>= \r -> return (fmap fst r)

-- | All answers of a search, in the order 'runChoiceT' produces them.
findAll :: (Monad m) => ChoiceT m a -> m [a]
findAll m = runChoiceT m >>= collectRest
  where collectRest Nothing          = return []
        collectRest (Just (a, rest)) = liftM (a :) (findAll rest)

-- | Run a continuation computation with the given final continuation.
runContT      :: (a -> m i) -> ContT i m a -> m i
runContT k (C body) = body k
-- | Monad transformers: @t m@ can run computations of the underlying
-- monad @m@.
class MonadT t where
  -- | Promote a computation of the underlying monad.
  lift :: (Monad m) => m a -> t m a

-- Each instance wraps the underlying computation in the transformer
-- representation without adding any effect of its own.
instance MonadT IdT where
  lift = IT
instance MonadT (ReaderT i) where
  lift m = R (const m)
instance MonadT (StateT i) where
  lift m = S $ \s -> m >>= \a -> return (a, s)
instance (Monoid i) => MonadT (WriterT i) where
  lift m = W (m >>= \a -> return (P a mempty))
instance MonadT (ExceptionT i) where
  lift = X . liftM Right
instance MonadT ChoiceT where
  lift = ChoiceEff . liftM Answer
instance MonadT (ContT i) where
  lift m = C (m >>=)
-- | Generic 'inBase' for transformers: reach the base monad through the
-- inner monad, then lift one layer.
t_inBase   :: (MonadT t, BaseM m n) => n a -> t m a
t_inBase    = lift . inBase

-- | Generic 'return' for transformers: return in the inner monad and lift.
t_return   :: (MonadT t, Monad m) => a -> t m a
t_return    = lift . return

-- | Generic 'fail' for transformers: fail in the inner monad and lift.
t_fail     :: (MonadT t, Monad m) => String -> t m a
t_fail      = lift . fail

-- | Generic 'mzero' for transformers.
t_mzero    :: (MonadT t, MonadPlus m) => t m a
t_mzero     = lift mzero

-- | Generic 'ask' for transformers that do not own the reader context.
t_ask      :: (MonadT t, ReaderM m i) => t m i
t_ask       = lift ask

-- | Generic 'put' for transformers that do not own the output.
t_put      :: (MonadT t, WriterM m i) => i -> t m ()
t_put       = lift . put

-- | Generic 'get' for transformers that do not own the state.
t_get      :: (MonadT t, StateM m i) => t m i
t_get       = lift get

-- | Generic 'set' for transformers that do not own the state.
t_set      :: (MonadT t, StateM m i) => i -> t m ()
t_set       = lift . set

-- | Generic 'raise' for transformers that do not own the exception type.
t_raise    :: (MonadT t, ExceptionM m i) => i -> t m a
t_raise     = lift . raise

-- | Generic 'abort' for transformers that do not own the final answer.
t_abort    :: (MonadT t, AbortM m i) => i -> t m a
t_abort     = lift . abort
-- | Monads @m@ built on top of a base monad @n@.  The functional
-- dependency makes the base monad a function of @m@.
class (Monad m, Monad n) => BaseM m n | m -> n where
  -- | Run a computation of the base monad inside @m@.
  inBase :: n a -> m a

-- Each base monad is its own base, so lifting is the identity.
instance BaseM IO IO where
  inBase m = m
instance BaseM Maybe Maybe where
  inBase m = m
instance BaseM [] [] where
  inBase m = m
instance BaseM Id Id where
  inBase m = m
instance BaseM Lift Lift where
  inBase m = m
instance BaseM (ST s) (ST s) where
  inBase m = m

-- Transformers share the base of the monad they wrap and reach it by
-- lifting through one layer at a time.
instance (BaseM m n) => BaseM (IdT m) n where
  inBase m = t_inBase m
instance (BaseM m n) => BaseM (ReaderT i m) n where
  inBase m = t_inBase m
instance (BaseM m n) => BaseM (StateT i m) n where
  inBase m = t_inBase m
instance (BaseM m n, Monoid i) => BaseM (WriterT i m) n where
  inBase m = t_inBase m
instance (BaseM m n) => BaseM (ExceptionT i m) n where
  inBase m = t_inBase m
instance (BaseM m n) => BaseM (ChoiceT m) n where
  inBase m = t_inBase m
instance (BaseM m n) => BaseM (ContT i m) n where
  inBase m = t_inBase m
-- Pure identity monad.  Because 'Id' is a newtype, matching on 'I'
-- does not force the computation.
instance Monad Id where
  return    = I
  fail      = error
  I a >>= f = f a

-- Identity monad over a @data@ type: bind matches the 'L' constructor,
-- so the left-hand computation is forced before continuing.
instance Monad Lift where
  return  = L
  fail    = error
  m >>= f = case m of
    L a -> f a
-- 'IdT' sequences exactly like the underlying monad.
instance (Monad m) => Monad (IdT m) where
  return     = t_return
  fail       = t_fail
  IT m >>= k = IT (m >>= \a -> runIdT (k a))

-- Both computations in a 'ReaderT' bind see the same context r.
instance (Monad m) => Monad (ReaderT i m) where
  return    = t_return
  fail      = t_fail
  R m >>= k = R $ \r -> m r >>= \a -> runReaderT r (k a)
-- The state produced by the first computation is the input state of the
-- second.  The lazy pattern on the (result, state) pair means the pair is
-- not forced at bind time, only when its components are demanded.
instance (Monad m) => Monad (StateT i m) where
  return  = t_return
  fail    = t_fail
  m >>= k = S (\s -> runStateT s m >>= \ ~(a,s') -> runStateT s' (k a))
-- Outputs of the two computations are combined with 'mappend', output of
-- the first computation on the left.  The lazy patterns avoid forcing the
-- intermediate 'P' values at bind time.
instance (Monad m,Monoid i) => Monad (WriterT i m) where
  return  = t_return
  fail    = t_fail
  m >>= k = W $ unW m     >>= \ ~(P a w1) ->
                unW (k a) >>= \ ~(P b w2) ->
                return (P b (mappend w1 w2))
-- A 'Left' (raised exception) short-circuits the rest of the computation;
-- a 'Right' result is passed on to the continuation.
instance (Monad m) => Monad (ExceptionT i m) where
  return    = t_return
  fail      = t_fail
  X m >>= k = X (m >>= either (return . Left) (runExceptionT . k))

-- Bind substitutes the continuation into every leaf of the search tree;
-- effects are extended lazily under 'ChoiceEff'.
instance (Monad m) => Monad (ChoiceT m) where
  return  = Answer
  fail    = lift . fail
  t >>= k = case t of
    NoAnswer    -> NoAnswer
    Answer a    -> k a
    Choice l r  -> Choice (l >>= k) (r >>= k)
    ChoiceEff m -> ChoiceEff (liftM (>>= k) m)

-- Continuation-passing bind: the continuation of m runs k and then the
-- outer continuation c.
instance (Monad m) => Monad (ContT i m) where
  return    = t_return
  fail      = t_fail
  C m >>= k = C $ \c -> m (\a -> runContT c (k a))
-- Every type here is already a monad, so 'fmap' is obtained from the
-- monad operations through 'liftM'.
instance Functor Id where
  fmap f = liftM f
instance Functor Lift where
  fmap f = liftM f
instance (Monad m) => Functor (IdT m) where
  fmap f = liftM f
instance (Monad m) => Functor (ReaderT i m) where
  fmap f = liftM f
instance (Monad m) => Functor (StateT i m) where
  fmap f = liftM f
instance (Monad m, Monoid i) => Functor (WriterT i m) where
  fmap f = liftM f
instance (Monad m) => Functor (ExceptionT i m) where
  fmap f = liftM f
instance (Monad m) => Functor (ChoiceT m) where
  fmap f = liftM f
instance (Monad m) => Functor (ContT i m) where
  fmap f = liftM f
-- Applicative structure derived from the monad: 'pure' is 'return' and
-- application is 'ap'.
instance Applicative Id where
  pure  = return
  (<*>) = ap
instance Applicative Lift where
  pure  = return
  (<*>) = ap
instance (Monad m) => Applicative (IdT m) where
  pure  = return
  (<*>) = ap
instance (Monad m) => Applicative (ReaderT i m) where
  pure  = return
  (<*>) = ap
instance (Monad m) => Applicative (StateT i m) where
  pure  = return
  (<*>) = ap
instance (Monad m, Monoid i) => Applicative (WriterT i m) where
  pure  = return
  (<*>) = ap
instance (Monad m) => Applicative (ExceptionT i m) where
  pure  = return
  (<*>) = ap
instance (Monad m) => Applicative (ChoiceT m) where
  pure  = return
  (<*>) = ap
instance (Monad m) => Applicative (ContT i m) where
  pure  = return
  (<*>) = ap
-- 'Alternative' mirrors 'MonadPlus': 'empty' is 'mzero' and choice is 'mplus'.
instance (MonadPlus m) => Alternative (IdT m) where
  empty = mzero
  (<|>) = mplus
instance (MonadPlus m) => Alternative (ReaderT i m) where
  empty = mzero
  (<|>) = mplus
instance (MonadPlus m) => Alternative (StateT i m) where
  empty = mzero
  (<|>) = mplus
instance (MonadPlus m, Monoid i) => Alternative (WriterT i m) where
  empty = mzero
  (<|>) = mplus
instance (MonadPlus m) => Alternative (ExceptionT i m) where
  empty = mzero
  (<|>) = mplus
-- 'ChoiceT' provides its own choice, so the inner monad only needs 'Monad'.
instance (Monad m) => Alternative (ChoiceT m) where
  empty = mzero
  (<|>) = mplus
instance (MonadPlus m) => Alternative (ContT i m) where
  empty = mzero
  (<|>) = mplus
-- Knot-tying fixpoint: the value inside the result is fed back as the
-- argument of f.
instance MonadFix Id where
  mfix f  = let m = f (runId m) in m
-- As for 'Id', but 'runLift' matches the 'L' constructor, so f must not
-- demand its argument before it has produced an 'L'.
instance MonadFix Lift where
  mfix f  = let m = f (runLift m) in m
-- Delegates to the fixpoint of the underlying monad.
instance (MonadFix m) => MonadFix (IdT m) where
  mfix f  = IT (mfix (runIdT . f))
-- The context r is fixed; the fixpoint is taken in the underlying monad.
instance (MonadFix m) => MonadFix (ReaderT i m) where
  mfix f  = R $ \r -> mfix (runReaderT r . f)
-- Only the result component (fst) of the (result, state) pair is fed back
-- to f; f runs from the input state s.
instance (MonadFix m) => MonadFix (StateT i m) where
  mfix f  = S $ \s -> mfix (runStateT s . f . fst)
-- Only the result component of the 'P' pair is fed back; the lazy pattern
-- in val keeps the knot from forcing the pair too early.
instance (MonadFix m,Monoid i) => MonadFix (WriterT i m) where
  mfix f  = W $ mfix (unW . f . val)
    where val ~(P a _) = a
-- The fed-back value only exists when the computation succeeds; if f
-- demands it while the result is a 'Left', evaluation reaches the error below.
instance (MonadFix m) => MonadFix (ExceptionT i m) where
  mfix f  = X $ mfix (runExceptionT . f . fromRight)
    where fromRight (Right a) = a
          fromRight _         = error "ExceptionT: mfix looped."
-- For most transformers, 'mzero' fails in the underlying monad and 'mplus'
-- combines the unwrapped computations there (with the same context or
-- state passed to both branches).
instance (MonadPlus m) => MonadPlus (IdT m) where
  mzero             = t_mzero
  IT a `mplus` IT b = IT (a `mplus` b)
instance (MonadPlus m) => MonadPlus (ReaderT i m) where
  mzero           = t_mzero
  R a `mplus` R b = R $ \r -> a r `mplus` b r
instance (MonadPlus m) => MonadPlus (StateT i m) where
  mzero           = t_mzero
  S a `mplus` S b = S $ \s -> a s `mplus` b s
instance (MonadPlus m, Monoid i) => MonadPlus (WriterT i m) where
  mzero           = t_mzero
  W a `mplus` W b = W (a `mplus` b)
instance (MonadPlus m) => MonadPlus (ExceptionT i m) where
  mzero           = t_mzero
  X a `mplus` X b = X (a `mplus` b)
-- 'ChoiceT' builds the search tree directly.
instance (Monad m) => MonadPlus (ChoiceT m) where
  mzero = NoAnswer
  mplus = Choice
instance (MonadPlus m) => MonadPlus (ContT i m) where
  mzero           = t_mzero
  C a `mplus` C b = C $ \k -> a k `mplus` b k
-- | Monads that provide read-only access to a context of type @i@.
class (Monad m) => ReaderM m i | m -> i where
  -- | Get the context.
  ask :: m i

-- 'ReaderT' returns its context directly.
instance (Monad m) => ReaderM (ReaderT i m) i where
  ask = R (\env -> return env)
-- Other transformers lift 'ask' from the underlying monad.
instance (ReaderM m j) => ReaderM (IdT m) j where
  ask = t_ask
instance (ReaderM m j, Monoid i) => ReaderM (WriterT i m) j where
  ask = t_ask
instance (ReaderM m j) => ReaderM (StateT i m) j where
  ask = t_ask
instance (ReaderM m j) => ReaderM (ExceptionT i m) j where
  ask = t_ask
instance (ReaderM m j) => ReaderM (ChoiceT m) j where
  ask = t_ask
instance (ReaderM m j) => ReaderM (ContT i m) j where
  ask = t_ask
-- | Monads that collect output of type @i@.
class (Monad m) => WriterM m i | m -> i where
  -- | Emit a piece of output.
  put  :: i -> m ()

-- 'WriterT' records the output next to a unit result.
instance (Monad m, Monoid i) => WriterM (WriterT i m) i where
  put w = W $ return $ P () w
-- Other transformers lift 'put' from the underlying monad.
instance (WriterM m j) => WriterM (IdT m) j where
  put x = t_put x
instance (WriterM m j) => WriterM (ReaderT i m) j where
  put x = t_put x
instance (WriterM m j) => WriterM (StateT i m) j where
  put x = t_put x
instance (WriterM m j) => WriterM (ExceptionT i m) j where
  put x = t_put x
instance (WriterM m j) => WriterM (ChoiceT m) j where
  put x = t_put x
instance (WriterM m j) => WriterM (ContT i m) j where
  put x = t_put x
-- | Monads with a state of type @i@.
class (Monad m) => StateM m i | m -> i where
  -- | Read the current state.
  get :: m i
  -- | Replace the current state.
  set :: i -> m ()

-- 'StateT' owns the state: 'get' returns it unchanged and 'set' discards
-- the old value.
instance (Monad m) => StateM (StateT i m) i where
  get     = S $ \s -> return (s, s)
  set new = S $ \_ -> return ((), new)
-- Other transformers forward both operations to the underlying monad.
instance (StateM m j) => StateM (IdT m) j where
  get = t_get
  set = t_set
instance (StateM m j) => StateM (ReaderT i m) j where
  get = t_get
  set = t_set
instance (StateM m j, Monoid i) => StateM (WriterT i m) j where
  get = t_get
  set = t_set
instance (StateM m j) => StateM (ExceptionT i m) j where
  get = t_get
  set = t_set
instance (StateM m j) => StateM (ChoiceT m) j where
  get = t_get
  set = t_set
instance (StateM m j) => StateM (ContT i m) j where
  get = t_get
  set = t_set
-- | Monads that can raise exceptions of type @i@; the functional
-- dependency fixes the exception type for each monad.
class (Monad m) => ExceptionM m i | m -> i where
  -- | Raise an exception.
  raise :: i -> m a
-- IO raises through 'IO.throwIO'; the exception type depends on whether
-- USE_BASE3 is defined at build time.
#ifdef USE_BASE3
instance ExceptionM IO IO.Exception where
  raise = IO.throwIO
#else
instance ExceptionM IO IO.SomeException where
  raise = IO.throwIO
#endif
-- 'ExceptionT' represents a raised exception as a 'Left' result.
instance (Monad m) => ExceptionM (ExceptionT i m) i where
  raise x = X (return (Left x))
-- The remaining transformers lift 'raise' from the underlying monad.
instance (ExceptionM m j) => ExceptionM (IdT m) j where
  raise = t_raise
instance (ExceptionM m j) => ExceptionM (ReaderT i m) j where
  raise = t_raise
instance (ExceptionM m j,Monoid i) => ExceptionM (WriterT i m) j where
  raise = t_raise
instance (ExceptionM m j) => ExceptionM (StateT  i m) j where
  raise = t_raise
instance (ExceptionM m j) => ExceptionM (ChoiceT   m) j where
  raise = t_raise
instance (ExceptionM m j) => ExceptionM (ContT   i m) j where
  raise = t_raise
-- | Monads that support first-class continuations.
class Monad m => ContM m where
  -- | Call the argument with the current continuation; applying that
  -- continuation to x makes the callCC expression return x.
  callCC :: ((a -> m b) -> m a) -> m a
-- Lifted through 'IdT' by unwrapping and rewrapping.
instance (ContM m) => ContM (IdT m) where
  callCC f = IT $ callCC $ \k -> runIdT $ f $ \a -> lift $ k a
-- The context r is passed to the body unchanged.
instance (ContM m) => ContM (ReaderT i m) where
  callCC f = R $ \r -> callCC $ \k -> runReaderT r $ f $ \a -> lift $ k a
-- Jumping resumes with the state s that was current when callCC was
-- entered; state changes made inside the body before the jump are dropped.
instance (ContM m) => ContM (StateT i m) where
  callCC f = S $ \s -> callCC $ \k -> runStateT s $ f $ \a -> lift $ k (a,s)
-- Jumping resumes with mempty as the output of the callCC; output written
-- inside the body before the jump is dropped.
instance (ContM m,Monoid i) => ContM (WriterT i m) where
  callCC f = W $ callCC $ \k -> unW $ f $ \a -> lift $ k (P a mempty)
-- Jumping resumes with a successful ('Right') result.
instance (ContM m) => ContM (ExceptionT i m) where
  callCC f = X $ callCC $ \k -> runExceptionT $ f $ \a -> lift $ k $ Right a
-- The body is wrapped in 'ChoiceEff'; jumping resumes with a single 'Answer'.
instance (ContM m) => ContM (ChoiceT m) where
  callCC f = ChoiceEff $ callCC $ \k -> return $ f $ \a -> lift $ k $ Answer a

-- Native implementation: the escape function ignores its own continuation
-- and passes its argument straight to the continuation k captured here.
instance (Monad m) => ContM (ContT i m) where
  callCC f = C $ \k -> runContT k $ f $ \a -> C $ \_ -> k a
-- | Readers whose context can be replaced for a sub-computation.
class (ReaderM m i) => RunReaderM m i | m -> i where
  -- | Run the computation with the given context instead of the current one.
  local        :: i -> m a -> m a

-- 'ReaderT' runs the computation with the new context and lifts the result.
instance (Monad m) => RunReaderM (ReaderT i m) i where
  local env m = lift (runReaderT env m)
-- Other transformers push 'local' down to the underlying monad.
instance (RunReaderM m j) => RunReaderM (IdT m) j where
  local env (IT m) = IT (local env m)
instance (RunReaderM m j, Monoid i) => RunReaderM (WriterT i m) j where
  local env (W m) = W (local env m)
instance (RunReaderM m j) => RunReaderM (StateT i m) j where
  local env (S m) = S $ \s -> local env (m s)
instance (RunReaderM m j) => RunReaderM (ExceptionT i m) j where
  local env (X m) = X (local env m)
-- | Writers whose output can be captured for a sub-computation.
class WriterM m i => RunWriterM m i | m -> i where
  -- | Run a computation and return its output alongside its result.
  collect :: m a -> m (a,i)

-- 'WriterT' runs the computation and lifts the (result, output) pair.
instance (Monad m, Monoid i) => RunWriterM (WriterT i m) i where
  collect = lift . runWriterT
-- Other transformers push 'collect' down to the underlying monad.
instance (RunWriterM m j) => RunWriterM (IdT m) j where
  collect (IT m) = IT (collect m)
instance (RunWriterM m j) => RunWriterM (ReaderT i m) j where
  collect (R m) = R $ \r -> collect (m r)
-- Re-nest ((result, state), output) as ((result, output), state); the
-- inner pair is accessed lazily through fst and snd.
instance (RunWriterM m j) => RunWriterM (StateT i m) j where
  collect (S m) = S $ \s -> liftM reshuffle (collect (m s))
    where reshuffle (r, w) = ((fst r, w), snd r)
-- Output is paired with the result only on success; an exception drops it.
instance (RunWriterM m j) => RunWriterM (ExceptionT i m) j where
  collect (X m) = X (liftM attach (collect m))
    where attach (r, w) = either Left (\a -> Right (a, w)) r
-- | Monads in which raised exceptions can be caught.
class ExceptionM m i => RunExceptionM m i | m -> i where
  -- | Run a computation, returning either the exception it raised
  -- ('Left') or its result ('Right').
  --
  -- The exception is returned as a value instead of propagating further.
  try :: m a -> m (Either i a)
#ifdef USE_BASE3
instance RunExceptionM IO IO.Exception where
  try = IO.try
#else
instance RunExceptionM IO IO.SomeException where
  try = IO.try
#endif
-- Runs the 'ExceptionT' layer and lifts its 'Either' result.
instance (Monad m) => RunExceptionM (ExceptionT i m) i where
  try m = lift (runExceptionT m)
-- 'IdT' and 'ReaderT' push 'try' down to the underlying monad.
instance (RunExceptionM m i) => RunExceptionM (IdT m) i where
  try (IT m) = IT (try m)
instance (RunExceptionM m i) => RunExceptionM (ReaderT j m) i where
  try (R m) = R (try . m)
-- Output of the tried computation is kept on success; on failure the
-- output becomes mempty.
instance (RunExceptionM m i,Monoid j) => RunExceptionM (WriterT j m) i where
  try (W m) = W (liftM swap (try m))
    where swap (Right (P a w))  = P (Right a) w
          swap (Left e)         = P (Left e) mempty
-- On success the final state is kept; on failure the state rolls back to
-- s, the state at the start of try.
instance (RunExceptionM m i) => RunExceptionM (StateT j m) i where
  try (S m) = S (\s -> liftM (swap s) (try (m s)))
    where swap _ (Right ~(a,s)) = (Right a,s)
          swap s (Left e)       = (Left e, s)
-- | Monads that can stop the whole computation with a final value of
-- type @i@.
class Monad m => AbortM m i where
  -- | Stop immediately, producing the given final value.
  abort :: i -> m a

-- 'ContT' ignores the current continuation and returns the value as the
-- final answer.
instance Monad m => AbortM (ContT i m) i where
  abort r = C (const (return r))
-- In IO, aborting exits the program with the given code.
instance AbortM IO ExitCode where
  abort code = exitWith code
-- Other transformers lift 'abort' from the underlying monad.
instance AbortM m i => AbortM (IdT m) i where
  abort = t_abort
instance AbortM m i => AbortM (ReaderT j m) i where
  abort = t_abort
instance (AbortM m i, Monoid j) => AbortM (WriterT j m) i where
  abort = t_abort
instance AbortM m i => AbortM (StateT j m) i where
  abort = t_abort
instance AbortM m i => AbortM (ExceptionT j m) i where
  abort = t_abort
instance AbortM m i => AbortM (ChoiceT m) i where
  abort = t_abort
-- | A point in the computation captured by 'labelCC'.  Jumping to it
-- resumes at that point with a new value.
newtype Label m a    = Lab ((a, Label m a) -> m ())

-- | Capture the current continuation as a label, returning the initial
-- value @x@ together with the label.
labelCC            :: (ContM m) => a -> m (a, Label m a)
labelCC x           = callCC $ \k -> return (x, Lab k)

-- | Jump to a label, passing it a new value.  The code after the jump is
-- treated as unreachable.
jump               :: (ContM m) => a -> Label m a -> m b
jump x lbl@(Lab k)  = k (x, lbl) >> return unreachable
  where unreachable = error "(bug) jump: unreachable"
-- | Get the context of a reader and apply a function to it.
asks :: ReaderM m r => (r -> a) -> m a
asks f = liftM f ask

-- | Emit the output half of a pair and return the result half.  The pair
-- is only taken apart when its parts are demanded.
puts :: WriterM m w => (a,w) -> m a
puts pair = put (snd pair) >> return (fst pair)

-- | Update the state with a function that also produces a result.
sets :: StateM m s => (s -> (a,s)) -> m a
sets f = get >>= \s ->
  let (a, s1) = f s
  in  set s1 >> return a

-- | Update the state with a function.
sets_ :: StateM m s => (s -> s) -> m ()
sets_ f = get >>= set . f

-- | Return a 'Right' value, or raise the exception in a 'Left'.
raises :: ExceptionM m x => Either x a -> m a
raises = either raise return
-- | Run a computation with a context derived from the current one.
mapReader        :: RunReaderM m r => (r -> r) -> m a -> m a
mapReader f m     = ask >>= \r -> local (f r) m

-- | Run a computation and emit its output after transforming it.
mapWriter        :: RunWriterM m w => (w -> w) -> m a -> m a
mapWriter f m     = collect m >>= \ ~(a, w) -> put (f w) >> return a

-- | Run a computation, re-raising any exception it raises after
-- transforming it.
mapException     :: RunExceptionM m x => (x -> x) -> m a -> m a
mapException f m  = try m >>= either (raise . f) return

-- | Run a computation; if it raises an exception, continue with the handler.
handle           :: RunExceptionM m x => m a -> (x -> m a) -> m a
handle m f        = try m >>= either f return