-- |
-- Module:     Control.ContStuff.Trans
-- Copyright:  (c) 2010 Ertugrul Soeylemez
-- License:    BSD3
-- Maintainer: Ertugrul Soeylemez <es@ertes.de>
-- Stability:  experimental
--
-- This module implements a number of monad transformers using a CPS
-- approach internally.

{-# LANGUAGE
  FlexibleInstances,
  MultiParamTypeClasses,
  RankNTypes,
  TypeFamilies #-}

module Control.ContStuff.Trans
    ( -- * Monad transformers
      -- ** Identity transformer
      IdentityT(..),

      -- ** ContT
      ContT(..),
      runContT, evalContT, modifyContT,

      -- ** Choice/nondeterminism
      ChoiceT(..),
      runChoiceT, choice, findAll, findAll_, findFirst,
      findFirst_, listA, listChoiceT, maybeChoiceT,

      -- ** Exceptions
      EitherT(..),
      runEitherT, evalEitherT, modifyEitherT, testEitherT,

      MaybeT(..),
      runMaybeT, evalMaybeT, modifyMaybeT, testMaybeT,

      -- ** State
      ReaderT,
      forkReaderT,
      runReaderT,

      StateT(..),
      runStateT, evalStateT, execStateT,

      -- ** Writer monads
      WriterT,
      runWriterT,

      OldWriterT,
      runOldWriterT, evalOldWriterT, execOldWriterT
    )
    where

import Control.Applicative
import Control.Arrow
import Control.Concurrent hiding (forkIO, forkOS)
import Control.ContStuff.Classes
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.Identity
import Data.Monoid


-------------
-- ChoiceT --
-------------

-- | Choice (nondeterminism) monad transformer.
--
-- A computation is represented by the CPS-based left fold it performs
-- over its choices: it receives a step function, an initial accumulator
-- of type @i@ and a final continuation.  For every choice the step
-- function is handed the current accumulator, the choice itself and a
-- continuation for the remaining choices; once no choices are left the
-- final continuation receives the accumulator.

newtype ChoiceT r i m a =
    ChoiceT {
      getChoiceT
          :: (i -> a -> (i -> m r) -> m r)  -- step function
          -> i                              -- initial accumulator
          -> (i -> m r)                     -- final continuation
          -> m r
    }

-- Aborting ignores the step function, accumulator and final
-- continuation and returns the given answer directly.
instance Applicative m => Abortable (ChoiceT r i m) where
    type Result (ChoiceT r i m) = r
    abort = ChoiceT . const . const . const . pure

-- 'empty' offers no choices: the accumulator goes straight to the final
-- continuation.  @c <|> d@ folds over the choices of @c@ and then keeps
-- folding over those of @d@ with the resulting accumulator.
instance Alternative (ChoiceT r i m) where
    empty = ChoiceT (\_ acc done -> done acc)
    ChoiceT c <|> ChoiceT d = ChoiceT both
        where
        both step acc done = c step acc (\acc' -> d step acc' done)

-- 'pure' offers exactly one choice.  @cf <*> cx@ enumerates the
-- functions of @cf@ in the outer loop and the arguments of @cx@ in the
-- inner one.  This agrees with the 'Monad' instance (@(<*>) = ap@) and
-- with the ordering of the list applicative, and it runs the effects of
-- @cf@ before those of @cx@.
instance Applicative (ChoiceT r i m) where
    pure x = ChoiceT $ \fold z k -> fold z x k
    ChoiceT cf <*> ChoiceT cx =
        ChoiceT $ \fold z k ->
            cf (\zf f kf -> cx (\zx x kx -> fold zx (f x) kx) zf kf) z k

-- The new thread folds over every choice of the computation (via
-- 'findAll_'), discarding the results.
instance (Applicative m, Forkable m) => Forkable (ChoiceT r i m) where
    forkIO c = lift (forkIO (findAll_ c))
    forkOS c = lift (forkOS (findAll_ c))

-- Mapping applies the function to each choice before it reaches the
-- step function.
instance Functor (ChoiceT r i m) where
    fmap f (ChoiceT c) = ChoiceT (\fold z k -> c (mapped fold) z k)
        where
        mapped fold acc x next = fold acc (f x) next

-- Lifts an inner computation yielding a list and offers every element
-- of that list as a separate choice.
instance LiftFunctor (ChoiceT r i) where
    type InnerFunctor (ChoiceT r i) = []
    liftF = (>>= choice) . lift

-- For every choice of the first computation, the fold continues over
-- all choices of the computation that @f@ produces for it.
instance Monad (ChoiceT r i m) where
    return x = ChoiceT (\fold z k -> fold z x k)
    ChoiceT c >>= f =
        ChoiceT $ \fold z0 k0 ->
            let visit z x k = getChoiceT (f x) fold z k
            in c visit z0 k0

-- An IO action is lifted through the inner monad and yields one choice.
instance MonadIO m => MonadIO (ChoiceT r i m) where
    liftIO io = lift (liftIO io)

-- Identical to the 'Alternative' instance.
instance MonadPlus (ChoiceT r i m) where
    mzero = empty
    mplus c d = c <|> d

-- A lifted action contributes exactly one choice: its result.
instance MonadTrans (ChoiceT r i) where
    lift action = ChoiceT $ \fold z k -> do
        x <- action
        fold z x k


-- | Run a choice computation with an explicit step function, initial
-- accumulator and final continuation.

runChoiceT ::
    (i -> a -> (i -> m r) -> m r)  -- ^ Step function.
    -> i                           -- ^ Initial accumulator.
    -> (i -> m r)                  -- ^ Final continuation.
    -> ChoiceT r i m a             -- ^ Computation to run.
    -> m r
runChoiceT step acc done c = getChoiceT c step acc done


-- | Turn a list into a 'ChoiceT' computation efficiently: every element
-- becomes one choice, offered in list order.

choice :: [a] -> ChoiceT r i m a
choice = ChoiceT . foldr offer exhausted
    where
    exhausted _ acc done = done acc
    offer x rest step acc done = step acc x (\acc' -> rest step acc' done)


-- | Find all solutions, combining them with '<|>' (earlier solutions on
-- the left).

findAll :: (Alternative f, Applicative m) => ChoiceT (f a) (f a) m a -> m (f a)
findAll = runChoiceT collect empty pure
    where
    collect found x next = next (found <|> pure x)


-- | Find all solutions and ignore them.  Every choice is visited, but the
-- accumulator and the final answer are never inspected.

findAll_ :: Applicative m => ChoiceT r i m a -> m ()
findAll_ c = () <$ runChoiceT skip undef finish c
    where
    skip _ _ next = next undef
    finish _ = pure undef


-- | Find the first solution.  The step function ignores the continuation
-- for the remaining choices, so the fold stops at the first one.

findFirst :: (Alternative f, Applicative m) => ChoiceT (f a) (f a) m a -> m (f a)
findFirst = runChoiceT stop empty pure
    where
    stop _ x _ = pure (pure x)


-- | Find the first solution and ignore it.

findFirst_ :: Applicative m => ChoiceT r i m a -> m ()
findFirst_ c = () <$ runChoiceT stop undef (\_ -> pure undef) c
    where
    stop _ _ _ = pure undef


-- | Turn a list into a computation with alternatives, combining the
-- elements from left to right with '<|>'.

listA :: Alternative f => [a] -> f a
listA = foldr ((<|>) . pure) empty


-- | Get the list of solutions.  Faster than 'findAll' because each
-- solution is consed onto the front, but the solutions therefore come
-- out in reversed order.

listChoiceT :: Applicative m => ChoiceT [a] [a] m a -> m [a]
listChoiceT = runChoiceT (\acc x next -> next (x : acc)) [] pure


-- | Get the first solution, if any (faster than 'findFirst').

maybeChoiceT :: Applicative m => ChoiceT (Maybe a) (Maybe a) m a -> m (Maybe a)
maybeChoiceT = runChoiceT firstOne Nothing pure
    where
    firstOne _ x _ = pure (Just x)


-----------
-- ContT --
-----------

-- | Continuation passing style monad transformer, the most basic form of
-- CPS: a computation receives the continuation for its result and
-- produces the final answer of type @r@ in the inner monad.

newtype ContT r m a =
    ContT {
      getContT :: (a -> m r) -> m r
    }

-- Aborting ignores the continuation and returns the given answer.
instance Applicative m => Abortable (ContT r m) where
    type Result (ContT r m) = r
    abort result = ContT (\_ -> pure result)

-- Both branches run with the same continuation; their answers are
-- combined by the inner monad's '<|>'.
instance Alternative m => Alternative (ContT r m) where
    empty = ContT (\_ -> empty)
    c <|> d = ContT (\k -> getContT c k <|> getContT d k)

-- The function is computed first, then its argument.
instance Applicative (ContT r m) where
    pure x = ContT (\k -> k x)
    ContT cf <*> ContT cx =
        ContT $ \k -> cf (\f -> cx (k . f))

-- Invoking the captured continuation discards the continuation at the
-- call site and resumes with the one of the 'callCC' call.
instance CallCC (ContT r m) where
    callCC f =
        ContT $ \k -> getContT (f (\x -> ContT (\_ -> k x))) k

-- The forked computation runs with 'toUnitM' as its final continuation;
-- the resulting thread id is passed to the current continuation.
instance Forkable m => Forkable (ContT () m) where
    forkIO (ContT c) = lift (forkIO (c toUnitM))
    forkOS (ContT c) = lift (forkOS (c toUnitM))

-- The function is applied to the result before it reaches the
-- continuation.
instance Functor (ContT r m) where
    fmap f c = ContT (\k -> getContT c (k . f))

-- The first result selects the next computation, which then runs with
-- the outer continuation.
instance Monad (ContT r m) where
    return x = ContT ($ x)
    m >>= f = ContT (\k -> getContT m (\x -> getContT (f x) k))

-- Identical to the 'Alternative' instance.
instance Alternative m => MonadPlus (ContT r m) where
    mzero = empty
    mplus c d = c <|> d

-- IO actions are lifted through the inner monad.
instance MonadIO m => MonadIO (ContT r m) where
    liftIO io = lift (liftIO io)

-- Lifting binds the inner action directly to the continuation.
instance MonadTrans (ContT r) where
    lift = ContT . (>>=)

-- Each told value becomes an alternative answer, placed before the
-- answers produced by the rest of the computation.
instance Alternative m => Writable (ContT r m) r where
    tell x = ContT (\k -> pure x <|> k ())

-- Traditional writer: the told value is prepended to the log produced
-- by the rest of the computation.  As a result @tell a >> tell b@ logs
-- @a `mappend` b@ (program order), and list logs stay right-nested.
instance (Functor m, Monoid w) => Writable (ContT (r, w) m) w where
    tell x = ContT $ \k -> fmap (second (x `mappend`)) (k ())


-- | Run a CPS computation, handing its result to the given final
-- continuation.

runContT :: (a -> m r) -> ContT r m a -> m r
runContT = flip getContT


-- | Evaluate a CPS computation whose result is the final answer, using
-- 'pure' as the final continuation.

evalContT :: Applicative m => ContT r m r -> m r
evalContT c = getContT c pure


-- | Apply a function to the final answer produced by the rest of the
-- computation.

modifyContT :: Functor m => (r -> r) -> ContT r m ()
modifyContT f = ContT (fmap f . ($ ()))


-------------
-- EitherT --
-------------

-- | Monad transformer for CPS computations with an additional exception
-- continuation: a computation receives one continuation for its result
-- and one for a raised exception of type @e@.

newtype EitherT r e m a =
    EitherT {
      getEitherT :: (a -> m r) -> (e -> m r) -> m r
    }

-- Aborting ignores both continuations and returns the given answer.
instance Applicative m => Abortable (EitherT r e m) where
    type Result (EitherT r e m) = r
    abort = EitherT . const . const . pure

-- Exceptions from either side go to the shared exception continuation;
-- the function is computed before its argument.
instance Applicative (EitherT r e m) where
    pure x = EitherT (\k _ -> k x)
    EitherT cf <*> EitherT cx =
        EitherT $ \k onErr -> cf (\f -> cx (k . f) onErr) onErr

-- Both branches share both continuations; their answers are combined by
-- the inner monad's '<|>'.
instance Alternative m => Alternative (EitherT r e m) where
    empty = EitherT (\_ _ -> empty)
    c <|> d =
        EitherT (\k onErr -> getEitherT c k onErr <|> getEitherT d k onErr)

-- Invoking the captured continuation jumps back to the result
-- continuation of the 'callCC' call.
instance CallCC (EitherT r e m) where
    callCC f = EitherT $ \k onErr ->
        getEitherT (f (\x -> EitherT (\_ _ -> k x))) k onErr

-- The forked computation gets 'toUnitM' as both its result and its
-- exception continuation.
instance Forkable m => Forkable (EitherT () e m) where
    forkIO c = lift (forkIO (getEitherT c toUnitM toUnitM))
    forkOS c = lift (forkOS (getEitherT c toUnitM toUnitM))

-- Only the result path is mapped; exceptions pass through unchanged.
instance Functor (EitherT r e m) where
    fmap f c = EitherT (\k onErr -> getEitherT c (\x -> k (f x)) onErr)

-- 'raise' jumps to the exception continuation; 'try' reifies the
-- outcome as an 'Either' and passes it to the result continuation.
instance HasExceptions (EitherT r e m) where
    type Exception (EitherT r e m) = e
    raise err = EitherT (\_ onErr -> onErr err)
    try c = EitherT (\k _ -> getEitherT c (k . Right) (k . Left))

-- Lifts an inner computation yielding an 'Either': 'Left' goes to the
-- exception continuation, 'Right' to the result continuation.
instance LiftFunctor (EitherT r e) where
    type InnerFunctor (EitherT r e) = Either e
    liftF c = EitherT $ \k onErr -> do
        result <- c
        case result of
            Left err -> onErr err
            Right x  -> k x

-- The exception continuation is shared by both stages.
instance Monad (EitherT r e m) where
    return x = EitherT (\k _ -> k x)
    m >>= f = EitherT $ \k onErr ->
        let next x = getEitherT (f x) k onErr
        in getEitherT m next onErr

-- IO actions are lifted through the inner monad.
instance MonadIO m => MonadIO (EitherT r e m) where
    liftIO io = lift (liftIO io)

-- Identical to the 'Alternative' instance.
instance Alternative m => MonadPlus (EitherT r e m) where
    mzero = empty
    mplus c d = c <|> d

-- The lifted action's result goes to the result continuation; the
-- exception continuation is never used.
instance MonadTrans (EitherT r e) where
    lift action = EitherT $ \k _ -> do
        x <- action
        k x

-- Each told value becomes an alternative answer, placed before the
-- answers of the rest of the computation.
instance Alternative m => Writable (EitherT r e m) r where
    tell x = EitherT (\k _ -> pure x <|> k ())

-- Traditional writer: the told value is prepended to the log produced
-- by the rest of the computation.  As a result @tell a >> tell b@ logs
-- @a `mappend` b@ (program order).
instance (Functor m, Monoid w) => Writable (EitherT (r, w) e m) w where
    tell x = EitherT $ \k _ -> fmap (second (x `mappend`)) (k ())


-- | Run an 'EitherT' computation with the given result and exception
-- continuations.

runEitherT :: (a -> m r) -> (e -> m r) -> EitherT r e m a -> m r
runEitherT onResult onErr c = getEitherT c onResult onErr


-- | Run an 'EitherT' computation, reifying its outcome as 'Right' (result)
-- or 'Left' (exception).

evalEitherT :: Applicative m => EitherT (Either e a) e m a -> m (Either e a)
evalEitherT c = getEitherT c onResult onErr
    where
    onResult x = pure (Right x)
    onErr err = pure (Left err)


-- | Apply a function to the final answer produced by the rest of an
-- 'EitherT' computation (on its result path).

modifyEitherT :: Functor m => (r -> r) -> EitherT r e m ()
modifyEitherT f = EitherT (\k _ -> f <$> k ())


-- | Run the 'EitherT' computation and return 'True' if it finishes with
-- a result, 'False' if it raises an exception.

testEitherT :: Applicative m => EitherT Bool e m a -> m Bool
testEitherT = runEitherT (const (pure True)) (const (pure False))


------------
-- MaybeT --
------------

-- | Monad transformer for CPS computations with an additional exception
-- continuation that takes no argument: a computation receives one
-- continuation for its result and one action to run on failure.

newtype MaybeT r m a =
    MaybeT {
      getMaybeT :: (a -> m r) -> m r -> m r
    }

-- Aborting ignores both continuations and returns the given answer.
instance Applicative m => Abortable (MaybeT r m) where
    type Result (MaybeT r m) = r
    abort = MaybeT . const . const . pure

-- Failure on either side runs the shared failure continuation; the
-- function is computed before its argument.
instance Applicative (MaybeT r m) where
    pure x = MaybeT (\just _ -> just x)
    MaybeT cf <*> MaybeT cx =
        MaybeT $ \just noth -> cf (\f -> cx (just . f) noth) noth

-- 'empty' fails immediately; @c <|> d@ runs @d@ as the failure
-- continuation of @c@.
instance Alternative (MaybeT r m) where
    empty = MaybeT (\_ noth -> noth)
    MaybeT c <|> MaybeT d =
        MaybeT $ \just noth -> c just (d just noth)

-- Invoking the captured continuation jumps back to the result
-- continuation of the 'callCC' call.
instance CallCC (MaybeT r m) where
    callCC f = MaybeT $ \just noth ->
        getMaybeT (f (\x -> MaybeT (\_ _ -> just x))) just noth

-- The forked computation gets 'toUnitM' as its result continuation and
-- @return ()@ as its failure continuation.
instance Forkable m => Forkable (MaybeT () m) where
    forkIO c = lift (forkIO (getMaybeT c toUnitM (return ())))
    forkOS c = lift (forkOS (getMaybeT c toUnitM (return ())))

-- 'raise' runs the failure continuation; 'try' reifies the outcome as
-- @Right x@ or @Left ()@ and passes it to the result continuation.
instance HasExceptions (MaybeT r m) where
    type Exception (MaybeT r m) = ()
    raise _ = MaybeT (\_ noth -> noth)
    try c = MaybeT (\just _ -> getMaybeT c (just . Right) (just (Left ())))

-- Only the result path is mapped; failure passes through unchanged.
instance Functor (MaybeT r m) where
    fmap f c = MaybeT (\just noth -> getMaybeT c (\x -> just (f x)) noth)

-- Lifts an inner computation yielding a 'Maybe': 'Nothing' runs the
-- failure continuation, 'Just' feeds the result continuation.
instance LiftFunctor (MaybeT r) where
    type InnerFunctor (MaybeT r) = Maybe
    liftF c = MaybeT $ \just noth -> do
        result <- c
        case result of
            Nothing -> noth
            Just x  -> just x

-- The failure continuation is shared by both stages.
instance Monad (MaybeT r m) where
    return x = MaybeT (\just _ -> just x)
    m >>= f = MaybeT $ \just noth ->
        let next x = getMaybeT (f x) just noth
        in getMaybeT m next noth

-- IO actions are lifted through the inner monad.
instance MonadIO m => MonadIO (MaybeT r m) where
    liftIO io = lift (liftIO io)

-- Identical to the 'Alternative' instance.  That instance places no
-- requirement on the inner monad, so neither does this one.
instance MonadPlus (MaybeT r m) where
    mzero = empty
    mplus = (<|>)

-- The lifted action's result goes to the result continuation; it never
-- fails.
instance MonadTrans (MaybeT r) where
    lift action = MaybeT $ \just _ -> do
        x <- action
        just x

-- Each told value becomes an alternative answer, placed before the
-- answers of the rest of the computation.
instance Alternative m => Writable (MaybeT r m) r where
    tell x = MaybeT (\just _ -> pure x <|> just ())

-- Traditional writer: the told value is prepended to the log produced
-- by the rest of the computation.  As a result @tell a >> tell b@ logs
-- @a `mappend` b@ (program order).
instance (Functor m, Monoid w) => Writable (MaybeT (r, w) m) w where
    tell x = MaybeT $ \just _ -> fmap (second (x `mappend`)) (just ())


-- | Run a 'MaybeT' computation with the given result continuation and
-- failure action.

runMaybeT :: (a -> m r) -> m r -> MaybeT r m a -> m r
runMaybeT onResult onFail c = getMaybeT c onResult onFail


-- | Run a 'MaybeT' computation, reifying its outcome as 'Just' (result)
-- or 'Nothing' (failure).

evalMaybeT :: Applicative m => MaybeT (Maybe a) m a -> m (Maybe a)
evalMaybeT c = getMaybeT c (\x -> pure (Just x)) (pure Nothing)


-- | Apply a function to the final answer produced by the rest of a
-- 'MaybeT' computation (on its result path).

modifyMaybeT :: Functor m => (r -> r) -> MaybeT r m ()
modifyMaybeT f = MaybeT (\just _ -> f <$> just ())


-- | Run the 'MaybeT' computation and return 'True' if it finishes with a
-- result, 'False' if it fails.

testMaybeT :: Applicative m => MaybeT Bool m a -> m Bool
testMaybeT c = getMaybeT c (\_ -> pure True) (pure False)


----------------
-- OldWriterT --
----------------

-- | The traditional writer monad transformer: a 'ContT' whose final
-- answer is paired with the accumulated log.
--
-- The synonym only takes the answer and log types.  @OldWriterT r w m a@
-- keeps working as before, and @OldWriterT r w m@ can now also be used
-- unsaturated, e.g. as the monad argument of another type.

type OldWriterT r w = ContT (r, w)


-- | Run a traditional writer transformer.  The final continuation pairs
-- the result with 'mempty'; told values are added to that log.

runOldWriterT :: (Applicative m, Monoid w) => OldWriterT r w m r -> m (r, w)
runOldWriterT c = getContT c finish
    where
    finish x = pure (x, mempty)


-- | Run a traditional writer transformer and keep only its result.

evalOldWriterT :: (Applicative m, Monoid w) => OldWriterT r w m r -> m r
evalOldWriterT c = fst <$> runOldWriterT c


-- | Run a traditional writer transformer and keep only its log.

execOldWriterT :: (Applicative m, Monoid w) => OldWriterT r w m r -> m w
execOldWriterT c = snd <$> runOldWriterT c


-------------
-- ReaderT --
-------------

-- | Monad transformer for computations with a readable environment.
-- Internally it is a 'StateT' that is polymorphic in the final answer
-- type, so unlike the other transformers here it allows no CPS effects.
-- It also hides its constructor (only the type is exported), which
-- makes it commutative.
--
-- If you need CPS effects, consider using 'StateT'.

newtype ReaderT e m a =
    ReaderT {
      getReaderT :: forall r. StateT r e m a
    }

-- Delegates to the underlying 'StateT'.
instance Applicative (ReaderT e m) where
    pure x = ReaderT (pure x)
    rf <*> rx = ReaderT (getReaderT rf <*> getReaderT rx)

-- Delegates to the underlying 'StateT'.
instance Functor (ReaderT e m) where
    fmap f r = ReaderT (f <$> getReaderT r)

-- Delegates to the underlying 'StateT'.
instance Monad (ReaderT e m) where
    return x = ReaderT (return x)
    m >>= f = ReaderT (getReaderT m >>= \x -> getReaderT (f x))

-- IO actions are lifted through the inner monad.
instance MonadIO m => MonadIO (ReaderT e m) where
    liftIO = lift . liftIO

-- The environment is the state of the underlying 'StateT'.
instance Readable (ReaderT e m) where
    type StateOf (ReaderT e m) = e
    get = ReaderT (StateT (\k env -> k env env))

-- The lifted action's result is passed on with the environment
-- unchanged.
instance MonadTrans (ReaderT e) where
    lift action = ReaderT (StateT (\k env -> action >>= \x -> k x env))


-- | Fork a concurrent thread that runs the given computation with the
-- current environment.

forkReaderT :: (Applicative m, Forkable m) => ReaderT e m () -> ReaderT e m ThreadId
forkReaderT c = get >>= \env -> lift (forkIO (runReaderT env c))


-- | Run a computation with the given environment.

runReaderT :: Applicative m => e -> ReaderT e m a -> m a
runReaderT env c = evalStateT env (getReaderT c)


------------
-- StateT --
------------

-- | Monad transformer for stateful computations: the continuation
-- receives both the result and the current state.

newtype StateT r s m a =
    StateT {
      getStateT :: (a -> s -> m r) -> s -> m r
    }

-- Aborting ignores the continuation and the state and returns the given
-- answer.
instance Applicative m => Abortable (StateT r s m) where
    type Result (StateT r s m) = r
    abort = StateT . const . const . pure

-- Both alternatives start from the same state and share the
-- continuation; their answers are combined by the inner monad's '<|>'.
instance Alternative m => Alternative (StateT r s m) where
    empty = StateT (\_ _ -> empty)
    c <|> d = StateT (\k s0 -> getStateT c k s0 <|> getStateT d k s0)

-- The function is computed first; the state flows through both.
instance Applicative (StateT r s m) where
    pure x = StateT (\k -> k x)
    StateT cf <*> StateT cx = StateT (\k -> cf (\f -> cx (\x -> k (f x))))

-- Invoking the captured continuation resumes the 'callCC' continuation
-- with the state that is current at the point of invocation.
instance CallCC (StateT r s m) where
    callCC f =
        StateT $ \k -> getStateT (f (\x -> StateT (\_ -> k x))) k

-- The forked computation starts from the current state, and its result
-- and final state are discarded.  The parent keeps its state unchanged.
instance Forkable m => Forkable (StateT () s m) where
    forkIO c = get >>= \s0 ->
        lift (forkIO (getStateT c (\_ _ -> return ()) s0))
    forkOS c = get >>= \s0 ->
        lift (forkOS (getStateT c (\_ _ -> return ()) s0))

-- The function is applied to the result before it reaches the
-- continuation.
instance Functor (StateT r s m) where
    fmap f c = StateT (\k -> getStateT c (\x -> k (f x)))

-- The state is threaded implicitly through the continuations.
instance Monad (StateT r s m) where
    return x = StateT (\k -> k x)
    m >>= f = StateT $ \k ->
        getStateT m (\x -> getStateT (f x) k)

-- IO actions are lifted through the inner monad.
instance MonadIO m => MonadIO (StateT r s m) where
    liftIO io = lift (liftIO io)

-- Identical to the 'Alternative' instance.
instance Alternative m => MonadPlus (StateT r s m) where
    mzero = empty
    mplus c d = c <|> d

-- 'get' passes the current state both as result and as unchanged state.
instance Readable (StateT r s m) where
    type StateOf (StateT r s m) = s
    get = StateT (\k s -> k s s)

-- 'put' forces the new state to WHNF before building the action, while
-- 'putLazy' stores it unevaluated.  Both ignore the previous state.
instance Stateful (StateT r s m) where
    put s1 = s1 `seq` putLazy s1
    putLazy s1 = StateT (\k _ -> k () s1)

-- The lifted action leaves the state unchanged.
instance MonadTrans (StateT r s) where
    lift action = StateT $ \k s0 -> do
        x <- action
        k x s0

-- Each told value becomes an alternative answer, placed before the
-- answers of the rest of the computation; the state is not changed.
instance Alternative m => Writable (StateT r s m) r where
    tell x = StateT (\k s -> pure x <|> k () s)

-- Traditional writer: the told value is prepended to the log produced
-- by the rest of the computation.  As a result @tell a >> tell b@ logs
-- @a `mappend` b@ (program order).
instance (Functor m, Monoid w) => Writable (StateT (r, w) s m) w where
    tell x = StateT $ \k -> fmap (second (x `mappend`)) . k ()


-- | Run a state transformer from the given initial state with the given
-- final continuation.

runStateT :: s -> (a -> s -> m r) -> StateT r s m a -> m r
runStateT s0 k c = getStateT c k s0


-- | Run a state transformer and return its result, dropping the final
-- state.

evalStateT :: Applicative m => s -> StateT r s m r -> m r
evalStateT s0 c = getStateT c (\x _ -> pure x) s0


-- | Run a state transformer and return its final state, dropping the
-- result.

execStateT :: Applicative m => s -> StateT s s m a -> m s
execStateT s0 c = getStateT c (const pure) s0


-------------
-- WriterT --
-------------

-- | The writer monad transformer.  Supports logging effects.
--
-- This is just another name for 'ContT' (left unsaturated so it can be
-- used as a transformer).  Logging comes from the 'Writable' instances of
-- 'ContT': with an 'Alternative' inner monad, every told value becomes
-- one alternative answer (see 'runWriterT').

type WriterT = ContT


-- | Run a writer transformer.  The final continuation yields 'empty', so
-- the computation's own result is dropped.  Only the answers it
-- contributes along the way (e.g. via 'tell') remain.

runWriterT :: Alternative m => WriterT r m a -> m r
runWriterT c = getContT c (\_ -> empty)


----------------------
-- Helper functions --
----------------------

-- | Ignore the argument and 'return' an undefined unit value.  The unit
-- is never supposed to be inspected; forcing it raises the 'undef'
-- error.

toUnitM :: Monad m => a -> m ()
toUnitM _ = return undef


-- | Placeholder for values that must never be evaluated.  Forcing it
-- raises an 'error' whose message names this library and flags a bug.

undef :: a
undef = error message
    where
    message = "contstuff: Undefined value evaluated. This is a bug!"