-- |
-- Module:     Control.ContStuff
-- Copyright:  (c) 2010 Ertugrul Soeylemez
-- License:    BSD3
-- Maintainer: Ertugrul Soeylemez <es@ertes.de>
-- Stability:  experimental
--
-- This module implements a number of monad transformers using a CPS
-- approach internally.

{-# LANGUAGE
  FlexibleInstances,
  MultiParamTypeClasses,
  TypeFamilies #-}

module Control.ContStuff
  ( -- * The identity monad
    Id(..),

    -- * Monad transformers
    -- ** Identity transformer
    IdT(..),
    -- ** Basic CPS monads
    Cont, runCont, evalCont, modifyCont,
    ContT(..), runContT, evalContT, modifyContT,
    -- ** Advanced CPS monads
    -- *** Choice/nondeterminism
    ChoiceT(..), runChoiceT, findFirst, findAll, listChoiceT, listA,
    -- *** State
    State, runState, evalState, execState,
    StateT(..), runStateT, evalStateT, execStateT,
    -- ** Writer monads
    OldWriter, runOldWriter, evalOldWriter, execOldWriter,
    OldWriterT, runOldWriterT, evalOldWriterT, execOldWriterT,
    WriterT, runWriterT,

    -- * Effect classes
    Abortable(..),
    CallCC(..), Label, labelCC, goto,
    LiftBase(..), io,
    Runnable(..),
    Stateful(..), getField, modify, modifyField, modifyFieldLazy, modifyLazy,
    Transformer(..),
    Writable(..),

    -- * Module reexports
    module Control.Applicative,
    module Control.Monad
  )
  where

import Control.Applicative
import Control.Arrow
import Control.Monad
import Control.Monad.Fix
import Control.Monad.ST
import Data.Monoid


-- ================== --
-- The identity monad --
-- ================== --


-- | The identity monad: a plain value wrapped in a newtype, i.e. a
-- computation without any effects.  Binding simply hands the wrapped
-- value to the next function.

newtype Id a = Id { getId :: a }

instance Functor Id where
  fmap f = Id . f . getId

instance Applicative Id where
  pure = Id
  mf <*> mx = Id (getId mf (getId mx))

instance Monad Id where
  return = pure
  m >>= f = f (getId m)

-- Knot-tying: the result is fed back into @f@ lazily through 'getId'.
instance MonadFix Id where
  mfix f = let r = f (getId r) in r


-- ================== --
-- Monad transformers --
-- ================== --


-------------
-- ChoiceT --
-------------

-- | The choice monad transformer, which most commonly models
-- nondeterminism.  Internally a collection of choices is represented as
-- a CPS-based left fold.  The computation receives a step function, an
-- initial accumulator of type @i@ and a final continuation.  It feeds
-- each choice to the step function, together with the current
-- accumulator and a continuation for the updated accumulator.

newtype ChoiceT r i m a =
  ChoiceT { getChoiceT ::
              (i -> a -> (i -> m r) -> m r)  -- step: accumulator, choice, continuation
                -> i                          -- initial accumulator
                -> (i -> m r)                 -- final continuation, gets the last accumulator
                -> m r }


-- No choices: the accumulator goes straight to the final continuation.
-- A choice between two computations folds over the left one first and
-- then continues the fold over the right one with the new accumulator.
instance Alternative (ChoiceT r i m) where
  empty = ChoiceT (\_ acc done -> done acc)
  ChoiceT left <|> ChoiceT right =
    ChoiceT $ \step acc done ->
      left step acc (\acc' -> right step acc' done)

-- Applying a choice of functions to a choice of arguments enumerates
-- every pair with the functions in the outer loop.  This gives the same
-- order of results (and of lifted effects) as 'ap' derived from the
-- 'Monad' instance, so @(<*>) = ap@ holds.  For example,
-- @listA [f, g] <*> listA [x, y]@ yields @f x, f y, g x, g y@.
instance Applicative (ChoiceT r i m) where
  pure x = ChoiceT $ \fold z k -> fold z x k
  ChoiceT cf <*> ChoiceT cx =
    ChoiceT $ \fold z k ->
      -- For each function (threading the accumulator), fold over all
      -- arguments, then continue with the next function.
      cf (\zf f kf -> cx (\zx x kx -> fold zx (f x) kx) zf kf) z k

-- Map over every choice by applying @f@ before handing it to the step
-- function.
instance Functor (ChoiceT r i m) where
  fmap f (ChoiceT c) =
    ChoiceT $ \step acc0 done ->
      c (\acc y next -> step acc (f y) next) acc0 done

-- Binding runs the computation produced for each choice in turn,
-- threading the fold accumulator through all of them.
instance Monad (ChoiceT r i m) where
  return x = ChoiceT (\step acc done -> step acc x done)
  m >>= f =
    ChoiceT $ \step acc0 done ->
      getChoiceT m (\acc y next -> getChoiceT (f y) step acc next) acc0 done

-- A lifted action contributes exactly one choice: its result.
instance Transformer (ChoiceT r i) where
  lift c = ChoiceT $ \step acc done -> do
    x <- c
    step acc x done


-- | Run a choice computation with an explicit fold.  The fold consists of
-- a step function, an initial accumulator and a final continuation that
-- receives the last accumulator.

runChoiceT ::
  (i -> a -> (i -> m r) -> m r)
    -> i
    -> (i -> m r)
    -> ChoiceT r i m a
    -> m r
runChoiceT step z done c = getChoiceT c step z done


-- | Find the first solution.  The step function ignores its continuation,
-- so the fold stops at the first choice.  If there are no choices the
-- result is 'empty'.

findFirst :: (Alternative f, Applicative m) => ChoiceT (f a) (f a) m a -> m (f a)
findFirst = runChoiceT firstOnly empty pure
  where
    firstOnly _ y _ = pure (pure y)


-- | Find all solutions, collecting them in order with '<|>'.

findAll :: (Alternative f, Applicative m) => ChoiceT (f a) (f a) m a -> m (f a)
findAll = runChoiceT collect empty pure
  where
    collect acc y next = next (acc <|> pure y)


-- | Get the list of solutions.  This is faster than 'findAll' because
-- each solution is consed onto the accumulator, but the solutions come
-- out in reverse order.

listChoiceT :: Applicative m => ChoiceT [a] [a] m a -> m [a]
listChoiceT = runChoiceT (\acc y next -> next (y : acc)) [] pure


-- | Turn a list into a computation with alternatives, preserving order:
-- @listA [x, y] = pure x <|> (pure y <|> empty)@, and @listA [] = empty@.

listA :: Alternative f => [a] -> f a
listA = foldr (\x rest -> pure x <|> rest) empty


-----------
-- ContT --
-----------

-- | The continuation passing style monad transformer.  It models the most
-- basic form of CPS: a computation receives the continuation for its
-- result and produces the final answer in the underlying monad.

newtype ContT r m a =
  ContT { getContT :: (a -> m r) -> m r }

-- Aborting ignores the continuation and returns the given final answer.
instance Applicative m => Abortable (ContT r m) where
  type Result (ContT r m) = r
  abort x = ContT (\_ -> pure x)

-- Alternatives are combined at the level of the final answer: both
-- branches run with the same continuation and their answers are combined
-- with the underlying monad's '<|>'.  This satisfies the identity laws
-- (@empty <|> c = c = c <|> empty@) and matches the 'Alternative'
-- instance of 'StateT'.  The previous definition nested @d@ inside
-- @c@'s continuation: it collapsed to 'empty' whenever either side was
-- 'empty' and duplicated results of @d@.
instance Alternative m => Alternative (ContT r m) where
  empty = ContT $ const empty
  ContT c <|> ContT d =
    ContT $ \k -> c k <|> d k

-- The function computation runs first, then the argument computation;
-- the application is passed to the continuation.
instance Applicative (ContT r m) where
  pure x = ContT ($ x)
  ContT cf <*> ContT cx = ContT $ \k -> cf (\f -> cx (k . f))

-- The escape function ignores whatever continuation is current when it
-- is invoked and resumes with the continuation captured here.
instance CallCC (ContT r m) where
  callCC f = ContT $ \k ->
    let escape x = ContT (\_ -> k x)
    in getContT (f escape) k

-- Post-compose the continuation with @f@.
instance Functor (ContT r m) where
  fmap f (ContT c) = ContT (\k -> c (k . f))

-- Binding passes the result of @m@ to @f@ and runs the resulting
-- computation with the outer continuation.
instance Monad (ContT r m) where
  return x = ContT ($ x)
  m >>= f = ContT $ \k -> getContT m (\x -> getContT (f x) k)

-- Running a 'ContT' means supplying its final continuation.
instance Runnable (ContT r) r m a where
  type Argument (ContT r) r m a = a -> m r
  runT k c = getContT c k

-- Lifting binds the underlying action directly to the continuation.
instance Transformer (ContT r) where
  lift c = ContT (c >>=)

-- Each logged value becomes an alternative final answer, placed in front
-- of whatever the rest of the computation produces.
instance Alternative m => Writable (ContT r m) r where
  tell x = ContT (\rest -> pure x <|> rest ())

-- Traditional writer logging: the log produced by the rest of the
-- computation is prefixed with @x@.  So @tell a >> tell b@ yields the
-- log @a `mappend` b@ (program order), and the appends nest to the
-- right.  Appending @x@ after the rest's log instead would reverse the
-- log and nest the appends to the left.
instance (Functor m, Monoid w) => Writable (ContT (r, w) m) w where
  tell x = ContT $ \k -> fmap (second (x `mappend`)) (k ())


-- | Run a CPS-style computation, passing its result to the given final
-- continuation.

runContT :: (a -> m r) -> ContT r m a -> m r
runContT k c = getContT c k


-- | Evaluate a CPS-style computation whose result is its final answer.

evalContT :: Applicative m => ContT r m r -> m r
evalContT = runContT pure


-- | Apply a function to the final answer produced by the rest of the
-- computation.

modifyContT :: Functor m => (r -> r) -> ContT r m ()
modifyContT f = ContT $ \k -> f <$> k ()


-- | Pure CPS monad: 'ContT' over the identity monad 'Id'.

type Cont r a = ContT r Id a


-- | Run a pure CPS computation, passing its result to the given final
-- continuation.

runCont :: (a -> r) -> Cont r a -> r
runCont k c = getId (getContT c (Id . k))


-- | Evaluate a pure CPS computation whose result is its final answer.

evalCont :: Cont r r -> r
evalCont c = getId (getContT c Id)


-- | Apply a function to the final answer produced by the rest of a pure
-- CPS computation.

modifyCont :: (r -> r) -> Cont r ()
modifyCont f = modifyContT f


---------
-- IdT --
---------

-- | The identity monad transformer.  It adds no effects of its own; every
-- operation is delegated to the underlying monad.  Unlike most other
-- monad transformers in this module it is not implemented in
-- continuation passing style.

newtype IdT m a = IdT { getIdT :: m a }

instance Alternative m => Alternative (IdT m) where
  empty = IdT empty
  c <|> d = IdT (getIdT c <|> getIdT d)

instance Applicative m => Applicative (IdT m) where
  pure x = IdT (pure x)
  cf <*> cx = IdT (getIdT cf <*> getIdT cx)

instance Functor m => Functor (IdT m) where
  fmap f = IdT . fmap f . getIdT

instance Monad m => Monad (IdT m) where
  return x = IdT (return x)
  c >>= f = IdT (getIdT c >>= \x -> getIdT (f x))

instance MonadFix m => MonadFix (IdT m) where
  mfix f = IdT (mfix (\x -> getIdT (f x)))

-- Running an 'IdT' needs no extra input; it just unwraps.
instance Runnable IdT r m r where
  type Argument IdT r m r = ()
  runT _ = getIdT

instance Transformer IdT where
  lift c = IdT c


----------------
-- OldWriterT --
----------------

-- | The traditional writer monad transformer: a 'ContT' whose final
-- answer pairs the result with the accumulated log.

type OldWriterT r w m a = ContT (r, w) m a


-- | Run a traditional writer transformer, returning its result together
-- with its log.  The final continuation starts the log at 'mempty'.

runOldWriterT :: (Applicative m, Monoid w) => OldWriterT r w m r -> m (r, w)
runOldWriterT c = getContT c (\x -> pure (x, mempty))


-- | Run a traditional writer transformer and return its result.

evalOldWriterT :: (Applicative m, Monoid w) => OldWriterT r w m r -> m r
evalOldWriterT c = fst <$> runOldWriterT c


-- | Run a traditional writer transformer and return its log.

execOldWriterT :: (Applicative m, Monoid w) => OldWriterT r w m r -> m w
execOldWriterT c = snd <$> runOldWriterT c


-- | The traditional writer monad.

type OldWriter r w a = ContT (r, w) Id a


-- | Run a traditional writer computation, returning its result together
-- with its log.

runOldWriter :: Monoid w => OldWriter r w r -> (r, w)
runOldWriter c = getId (runOldWriterT c)


-- | Run a traditional writer computation and return its result.

evalOldWriter :: Monoid w => OldWriter r w r -> r
evalOldWriter c = fst (runOldWriter c)


-- | Run a traditional writer computation and return its log.

execOldWriter :: Monoid w => OldWriter r w r -> w
execOldWriter c = snd (runOldWriter c)


------------
-- StateT --
------------

-- | The state monad transformer in CPS.  A computation receives the
-- current state and a continuation, which it calls with the new state
-- and its result.

newtype StateT r s m a =
  StateT { getStateT :: s -> (s -> a -> m r) -> m r }

-- Aborting ignores both the state and the continuation and returns the
-- given final answer.
instance Applicative m => Abortable (StateT r s m) where
  type Result (StateT r s m) = r
  abort = StateT . const . const . pure

-- Both branches start from the same state and continuation; their final
-- answers are combined with the underlying monad's '<|>'.
instance Alternative m => Alternative (StateT r s m) where
  empty = StateT (\_ _ -> empty)
  StateT c <|> StateT d = StateT (\s k -> c s k <|> d s k)

-- The state is threaded through the function computation first, then
-- through the argument computation.
instance Applicative (StateT r s m) where
  pure x = StateT (\s k -> k s x)
  StateT cf <*> StateT cx =
    StateT $ \s0 k ->
      cf s0 $ \s1 f ->
        cx s1 $ \s2 x ->
          k s2 (f x)

-- The escape function resumes at the capture point with the state that
-- is current at the time of the jump, discarding the continuation in
-- effect at that point.
instance CallCC (StateT r s m) where
  callCC f = StateT $ \s0 k ->
    let escape x = StateT (\s1 _ -> k s1 x)
    in getStateT (f escape) s0 k

-- Apply @f@ to the result before handing it to the continuation.
instance Functor (StateT r s m) where
  fmap f (StateT c) = StateT (\s0 k -> c s0 (\s1 x -> k s1 (f x)))

-- Binding threads the state from @m@ into the computation chosen by @f@.
instance Monad (StateT r s m) where
  return x = StateT (\s k -> k s x)
  m >>= f =
    StateT $ \s0 k -> getStateT m s0 (\s1 x -> getStateT (f x) s1 k)

-- Running a 'StateT' needs the initial state and the final continuation.
instance Runnable (StateT r s) r m a where
  type Argument (StateT r s) r m a = (s, s -> a -> m r)
  runT (s0, k) c = getStateT c s0 k

-- 'put' forces the new state to weak head normal form; 'putLazy' stores
-- it unevaluated.
instance Stateful (StateT r s m) where
  type StateOf (StateT r s m) = s
  get = StateT (\s k -> k s s)
  put s1 = seq s1 (StateT (\_ k -> k s1 ()))
  putLazy s1 = StateT (\_ k -> k s1 ())

-- A lifted action leaves the state unchanged.
instance Transformer (StateT r s) where
  lift c = StateT $ \s k -> c >>= \x -> k s x

-- Each logged value becomes an alternative final answer, placed in front
-- of whatever the rest of the computation produces.
instance Alternative m => Writable (StateT r s m) r where
  tell x = StateT (\s rest -> pure x <|> rest s ())

-- Traditional writer logging: the log produced by the rest of the
-- computation is prefixed with @x@.  So @tell a >> tell b@ yields the
-- log @a `mappend` b@ (program order), with right-nested appends.
-- Appending @x@ after the rest's log instead would reverse the log.
instance (Functor m, Monoid w) => Writable (StateT (r, w) s m) w where
  tell x = StateT $ \s0 k -> fmap (second (x `mappend`)) (k s0 ())


-- | Run a state transformer from the given initial state, passing the
-- final state and result to the given continuation.

runStateT :: s -> (s -> a -> m r) -> StateT r s m a -> m r
runStateT s0 k c = getStateT c s0 k


-- | Run a state transformer returning its result.

evalStateT :: Applicative m => s -> StateT r s m r -> m r
evalStateT s0 = runStateT s0 (\_ x -> pure x)


-- | Run a state transformer returning its final state.

execStateT :: Applicative m => s -> StateT s s m a -> m s
execStateT s0 = runStateT s0 (\s1 _ -> pure s1)


-- | Pure state monad derived from 'StateT'.

type State r s a = StateT r s Id a


-- | Run a stateful computation from the given initial state, passing the
-- final state and result to the given continuation.

runState :: s -> (s -> a -> r) -> State r s a -> r
runState s0 k c = getId (runStateT s0 (\s1 x -> Id (k s1 x)) c)


-- | Run a stateful computation returning its result.

evalState :: s -> State r s r -> r
evalState s0 c = getId (evalStateT s0 c)


-- | Run a stateful computation returning its final state.

execState :: s -> State s s a -> s
execState s0 c = getId (execStateT s0 c)


-------------
-- WriterT --
-------------

-- | The writer monad transformer, supporting logging effects.  It is a
-- plain alias for 'ContT'.

type WriterT = ContT


-- | Run a writer transformer.  The final continuation yields 'empty', so
-- the computation's own result is discarded and only the answers
-- produced along the way (for example via 'tell') remain.

runWriterT :: Alternative m => WriterT r m a -> m r
runWriterT c = getContT c (\_ -> empty)


-- ============== --
-- Effect classes --
-- ============== --


-- | Monads supporting abortion: 'abort' ends the computation early with
-- the given final result.  The 'ContT' and 'StateT' instances in this
-- module do so by ignoring the current continuation.

class Abortable m where
  -- | The type of the final result of the whole computation.
  type Result m
  -- | Abort the computation, making the argument its final result.
  abort :: Result m -> m a


-- | Monads supporting /call-with-current-continuation/ (aka callCC).

class CallCC m where
  -- | Call with current continuation.  The function receives an escape
  -- action; applying it to a value resumes with that value as the result
  -- of the 'callCC' expression.  The 'ContT' and 'StateT' instances in
  -- this module abandon the continuation in effect at that point.
  callCC :: ((a -> m b) -> m a) -> m a


-- | A label created by 'labelCC'.  It wraps a function that takes a new
-- value and the label itself; 'goto' applies it to jump back.

newtype Label m a = Label (a -> Label m a -> m ())


-- | Capture the current continuation as a label, returning the given
-- initial value alongside it.  A later 'goto' to the label resumes here,
-- returning the value passed to 'goto' together with the label.

labelCC :: (Applicative m, CallCC m) => a -> m (a, Label m a)
labelCC x = callCC (\jump -> pure (x, Label (\y lbl -> jump (y, lbl))))


-- | Jump to a label, resuming at the corresponding 'labelCC' with the
-- given value.  The jump is expected to escape, as it does for the
-- 'ContT' and 'StateT' 'CallCC' instances, so the result of 'goto' is
-- never produced.  Forcing it means the underlying 'CallCC' instance
-- returned instead of escaping; that is reported with a descriptive
-- error rather than a bare 'undefined'.

goto :: Applicative m => Label m a -> a -> m b
goto lk@(Label k) x = k x lk *> pure unreachable
  where
    unreachable =
      error "Control.ContStuff.goto: label jump returned instead of escaping"


-- | Monads which support lifting computations of their base monad.  For
-- the base monads below, the base is the monad itself.  The transformers
-- lift the computation one level and recurse into the monad underneath.

class LiftBase m a where
  type Base m a
  base :: Base m a -> m a

instance LiftBase IO a where
  type Base IO a = IO a
  base = id

instance LiftBase Id a where
  type Base Id a = Id a
  base = id

instance LiftBase Maybe a where
  type Base Maybe a = Maybe a
  base = id

instance LiftBase (ST s) a where
  type Base (ST s) a = ST s a
  base = id

instance LiftBase [] a where
  type Base [] a = [a]
  base = id

instance LiftBase ((->) r) a where
  type Base ((->) r) a = r -> a
  base = id

instance (LiftBase m a, Monad m) => LiftBase (IdT m) a where
  type Base (IdT m) a = Base m a
  base c = lift (base c)

instance (LiftBase m a, Monad m) => LiftBase (ChoiceT r i m) a where
  type Base (ChoiceT r i m) a = Base m a
  base c = lift (base c)

instance (LiftBase m a, Monad m) => LiftBase (ContT r m) a where
  type Base (ContT r m) a = Base m a
  base c = lift (base c)

instance (LiftBase m a, Monad m) => LiftBase (StateT r s m) a where
  type Base (StateT r s m) a = Base m a
  base c = lift (base c)


-- | Handy alias for 'base' in monads whose base monad is 'IO'.

io :: (LiftBase m a, Base m a ~ IO a) => Base m a -> m a
io c = base c


-- | Every monad transformer @t@ that supports running a computation of
-- type @t m a@ to a computation of type @m r@ in the underlying monad
-- can be an instance of this class.  The extra input needed to run it
-- is given by the associated type 'Argument'.

class Runnable t r m a where
  -- | The extra input 'runT' needs, e.g. a final continuation.
  type Argument t r m a
  -- | Run a transformer computation in the underlying monad.
  runT :: Argument t r m a -> t m a -> m r


-- | Stateful monads.

class Stateful m where
  -- | The type of the state.
  type StateOf m
  -- | Get the current state.
  get :: m (StateOf m)
  -- | Set the current state and force it (to weak head normal form).
  put :: StateOf m -> m ()
  put x = seq x (putLazy x)
  -- | Set the current state, but don't force it.
  putLazy :: StateOf m -> m ()

-- 'ContT' has no state of its own: every state operation is lifted from
-- the underlying monad.
instance (Monad m, Stateful m) => Stateful (ContT r m) where
  type StateOf (ContT r m) = StateOf m
  get = lift get
  put x = lift (put x)
  putLazy x = lift (putLazy x)


-- | Get a certain field of the current state, given its accessor.

getField :: (Functor m, Stateful m) => (StateOf m -> a) -> m a
getField accessor = fmap accessor get


-- | Apply a function to the current state, storing the result with
-- 'put'.

modify :: (Monad m, Stateful m) => (StateOf m -> StateOf m) -> m ()
modify f = do
  s <- get
  put (f s)


-- | Read a field of the current state and compute the new state from it,
-- storing the result with 'put'.

modifyField :: (Monad m, Stateful m) =>
               (StateOf m -> a) -> (a -> StateOf m) -> m ()
modifyField accessor f = do
  s <- get
  put (f (accessor s))


-- | Read a field of the current state and compute the new state from it.
-- Lazy version, storing the result with 'putLazy'.

modifyFieldLazy :: (Monad m, Stateful m) =>
                   (StateOf m -> a) -> (a -> StateOf m) -> m ()
modifyFieldLazy accessor f = do
  s <- get
  putLazy (f (accessor s))


-- | Apply a function to the current state.  Lazy version, storing the
-- result with 'putLazy'.

modifyLazy :: (Monad m, Stateful m) => (StateOf m -> StateOf m) -> m ()
modifyLazy f = do
  s <- get
  putLazy (f s)


-- | The monad transformer class.  It promotes a computation of the
-- underlying monad to a computation of the transformer, i.e. lifts it
-- one level up the monad stack.

class Transformer t where
  -- | Promote a monadic computation to the transformer.
  lift :: Monad m => m a -> t m a


-- | Monads with support for logging.  Traditionally these are called
-- /writer monads/.

class Writable m w where
  -- | Log the given value.
  tell :: w -> m ()