{-# LANGUAGE MultiParamTypeClasses #-}

{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}

{- |

Module      :  Control.Monad.Resumable
Copyright   :  Copyright Nicolas Frisby 2010
License     :  <http://creativecommons.org/licenses/by/3.0/>

Maintainer  :  nicolas.frisby@gmail.com
Stability   :  experimental
Portability :  non-portable (GHC extensions)

A monad transformer for resumable exceptions. The 'ResumableT' transformer is
isomorphic to @ReactT@, the dominant reactivity monad in the literature. The
differences serve to match the @mtl@ style.

-}

module Control.Monad.Resumable
  (-- * Monadic interface
   MonadResumable(..),
   -- * Monad transformer
   Resumable, ResumableT(..),
   -- ** Run functions
   runResumableT, runResumableT_responder,
   runResumableT', runResumableT_responder',
   -- ** Scope manipulation 
   Static(..), Dynamic(..), asStatic, asDynamic, statically, dynamically,
   -- ** Manipulating the inner monad
   mapResumableT_static, mapResumableT_dynamic
  ) where

import {-# SOURCE #-} Control.Monad.Resumable.Scoped
import Control.Monad.Resumable.Class

import Data.Monoid (Monoid(..))

import Control.Monad (liftM, (<=<))
import Control.Monad.Trans (MonadTrans(..), MonadIO(..))
import Control.Monad.Fix (MonadFix(..))

import Control.Monad.RWS (MonadRWS)
import Control.Monad.Reader (MonadReader(..))
import Control.Monad.Writer (MonadWriter(..))
import Control.Monad.State (MonadState(..))
import Control.Monad.Error (MonadError(..))
import Control.Monad.Cont (MonadCont(..))
import Control.Monad.Identity

import qualified Control.Arrow as Arrow



-- | 'ResumableT' specialised to the 'Identity' inner monad.
type Resumable scope req res = ResumableT scope req res Identity

-- | A computation over the inner monad @m@ whose single step either finishes
-- with a result (@Right a@) or yields a request paired with a continuation
-- that awaits the response (@Left (req, k)@). The @scope@ parameter does not
-- occur on the right-hand side; it only selects among the instances below.
newtype ResumableT scope req res m a =
  ResumableT {unResumableT ::
               m (Either (req, res -> ResumableT scope req res m a) a)}


-- | Identity on computations; its type pins the scope parameter to 'Static',
-- which makes static scoping the default for the argument.
asStatic :: ResumableT Static req res m a -> ResumableT Static req res m a
asStatic computation = computation

-- | Identity on computations; its type pins the scope parameter to 'Dynamic',
-- which makes dynamic scoping the default for the argument.
asDynamic :: ResumableT Dynamic req res m a -> ResumableT Dynamic req res m a
asDynamic computation = computation



-- | The preferred top-level interface never allows exceptions (requests) to
-- go unhandled. Each time the computation yields, the handler is fetched by
-- running the second argument in the inner monad, applied to the request and
-- its pending continuation, and the computation it returns is run in turn.
runResumableT :: Monad m =>
  ResumableT scope req res m a ->
  m (req -> (res -> ResumableT scope req res m a) ->
     ResumableT scope req res m a) ->
  m a
runResumableT m mf = go m where
  go (ResumableT step) = do
    outcome <- step
    case outcome of
      -- Finished: hand the result to the inner monad.
      Right a -> return a
      -- Yielded: the handler is re-fetched for every request.
      Left sig -> do
        handler <- mf
        go (uncurry handler sig)

-- | This variation recognizes that the handling of requests primarily
-- involves generating responses: the responder computes a response for each
-- request, and that response is fed to the pending continuation.
runResumableT_responder :: Monad m =>
  ResumableT scope req res m a ->
  m (req -> ResumableT scope req res m res) ->
  m a
runResumableT_responder m mr = runResumableT m (liftM toHandler mr) where
  -- Turn a responder into a full handler by resuming with its response.
  toHandler respond req k = respond req >>= k

-- | Like 'runResumableT', for a handler that does not depend on the inner
-- monad: the pure handler is injected with 'return'.
runResumableT' :: Monad m =>
  ResumableT scope req res m a ->
  (req -> (res -> ResumableT scope req res m a) ->
   ResumableT scope req res m a) ->
  m a
runResumableT' m = runResumableT m . return

-- | Like 'runResumableT_responder', for a responder that does not depend on
-- the inner monad: the pure responder is injected with 'return'.
runResumableT_responder' :: Monad m =>
  ResumableT scope req res m a ->
  (req -> ResumableT scope req res m res) ->
  m a
runResumableT_responder' m = runResumableT_responder m . return


-- | This manipulation of the inner monad achieves static scoping: the
-- function is applied once, to the first step of the computation, and is
-- not re-applied to any resumption it leaves behind.
mapResumableT_static ::
  (m (Either (req, res -> ResumableT scope req res m a) a) ->
   n (Either (req', res' -> ResumableT scope' req' res' n b) b)) ->
  ResumableT scope req res m a -> ResumableT scope' req' res' n b
mapResumableT_static f (ResumableT inner) = ResumableT (f inner)

-- | This manipulation of the inner monad achieves dynamic scoping: the
-- function is applied to the first step and, whenever that step yields, the
-- same mapping is composed onto the pending continuation, so it is
-- preserved in every resumption.
mapResumableT_dynamic :: (Monad m, Monad n) =>
  (m (Either (req, res -> ResumableT scope req res m a) a) ->
   n (Either (req, res -> ResumableT scope req res m a) b)) ->
  ResumableT scope req res m a -> ResumableT scope req res n b
mapResumableT_dynamic f = go where
  go (ResumableT inner) = ResumableT (liftM (Arrow.left rewrap) (f inner))
  -- Re-install the mapping on whatever the continuation produces.
  rewrap pending = Arrow.second (go .) pending



-- | An inner action is lifted to a computation that finishes immediately
-- with the action's result, without yielding.
instance MonadTrans (ResumableT scope req res) where
  lift inner = ResumableT (liftM Right inner)

-- | Maps over the final result; when the computation yields, the mapping is
-- pushed into the pending continuation so it reaches the eventual result.
instance Functor m => Functor (ResumableT scope req res m) where
  fmap f = go where
    go (ResumableT inner) =
      ResumableT (fmap resume (fmap (Arrow.right f) inner))
    -- Keep mapping after the computation is resumed.
    resume step = Arrow.left (Arrow.second (go .)) step

-- | Sequencing runs the first step in the inner monad. A finished step
-- continues with the bound function; a yielded step stays yielded, with the
-- bound function composed after its continuation.
instance Monad m => Monad (ResumableT scope req res m) where
  return a = lift (return a)
  ResumableT m >>= f = ResumableT $ do
    outcome <- m
    case outcome of
      Left pending -> return (Left (Arrow.second (\ k -> f <=< k) pending))
      Right a -> unResumableT (f a)

-- | Statically scoped requests: 'handle' intercepts only the first step of
-- the handled computation (via 'mapResumableT_static'); the handler is not
-- re-installed on resumptions.
instance Monad m => MonadResumable req res (ResumableT Static req res m) where
  yield req k = ResumableT . return $ Left (req, k)
  handle m h = mapResumableT_static (dispatch =<<) m where
    -- A yielded request goes to the handler; a result passes through.
    dispatch (Left pending) = unResumableT (uncurry h pending)
    dispatch (Right a) = return (Right a)

-- | Dynamically scoped requests: 'handle' intercepts the first step and, via
-- 'mapResumableT_dynamic', is re-installed on every continuation that
-- escapes the handler.
instance Monad m => MonadResumable req res (ResumableT Dynamic req res m) where
  yield req k = ResumableT . return $ Left (req, k)
  handle m h = mapResumableT_dynamic (dispatch =<<) m where
    -- A yielded request goes to the handler; a result passes through.
    dispatch (Left pending) = unResumableT (uncurry h pending)
    dispatch (Right a) = return (Right a)

-- | No methods are defined here; the instance only combines the reader,
-- writer and state instances of the statically scoped transformer.
instance (Monoid w, MonadReader r m, MonadState s m, MonadWriter w m) =>
  MonadRWS r w s (ResumableT Static req res m)

-- | No methods are defined here; the instance only combines the reader,
-- writer and state instances of the dynamically scoped transformer.
instance (Monoid w, MonadReader r m, MonadState s m, MonadWriter w m) =>
  MonadRWS r w s (ResumableT Dynamic req res m)

-- | Static scoping: 'local' is applied to the first inner step only (via
-- 'mapResumableT_static'); resumptions are not wrapped in it.
instance MonadReader r m => MonadReader r (ResumableT Static req res m) where
  ask = ResumableT (liftM Right ask)
  local = mapResumableT_static . local

-- | Dynamic scoping: 'local' is applied to the first inner step and, via
-- 'mapResumableT_dynamic', to every resumption as well.
instance MonadReader r m => MonadReader r (ResumableT Dynamic req res m) where
  ask = ResumableT (liftM Right ask)
  local = mapResumableT_dynamic . local

-- | State operations are lifted unchanged from the inner monad, for either
-- scope.
instance MonadState s m => MonadState s (ResumableT scope req res m) where
  get = ResumableT (liftM Right get)
  put s = lift (put s)

-- | Static scoping: the error handler guards the first inner step only (via
-- 'mapResumableT_static'); resumptions run outside of it.
instance MonadError e m => MonadError e (ResumableT Static req res m) where
  throwError e = lift (throwError e)
  catchError m h = mapResumableT_static recover m where
    recover inner = inner `catchError` (unResumableT . h)

-- | Dynamic scoping: the error handler guards the first inner step and, via
-- 'mapResumableT_dynamic', every resumption as well.
instance MonadError e m => MonadError e (ResumableT Dynamic req res m) where
  throwError e = lift (throwError e)
  catchError m h = mapResumableT_dynamic recover m where
    recover inner = inner `catchError` (unResumableT . h)

-- | Writer operations are lifted from the inner monad and follow the
-- computation across resumptions, for either scope.
instance MonadWriter w m => MonadWriter w (ResumableT scope req res m) where
  tell = lift . tell
  -- Each segment between two yields is listened to separately in the inner
  -- monad. The output collected so far is threaded through the resumptions,
  -- so the final result reports the output of the whole computation rather
  -- than only that of the segment following the last yield.
  listen = go mempty where
    go seen (ResumableT inner) = ResumableT $ do
      (step, w) <- listen inner
      let seen' = seen `mappend` w
      return $ case step of
        -- Yielded: resume later, carrying the accumulated output along.
        Left pending -> Left (Arrow.second (go seen' .) pending)
        -- Finished: pair the result with everything written so far.
        Right a -> Right (a, seen')
  -- A segment that ends in a yield has not produced its output-transforming
  -- function yet, so its output is passed through unchanged ('id'); only the
  -- output of the segment that delivers the final result is transformed.
  -- NOTE(review): output written before a yield therefore escapes the
  -- transformation; fixing that would require buffering output across yields.
  pass = mapResumableT_dynamic (pass . liftM split) where
    split (Left pending) = (Left pending, id)
    split (Right (a, f)) = (Right a, f)

-- | 'callCC' is taken from the inner monad; invoking the captured
-- continuation finishes the computation with the supplied value.
instance MonadCont m => MonadCont (ResumableT scope req res m) where
  callCC f = ResumableT (callCC (unResumableT . f . escape)) where
    -- Wrap the inner continuation so it accepts a plain result.
    escape k = ResumableT . k . Right

-- | 'IO' actions are lifted through the inner monad.
instance MonadIO m => MonadIO (ResumableT scope req res m) where
  liftIO io = lift (liftIO io)

-- | The fixed point is taken in the inner monad over the first step. It is
-- only defined when that step finishes; if it yields, demanding the
-- fed-back value raises an error.
instance MonadFix m => MonadFix (ResumableT scope req res m) where
  mfix f = ResumableT (mfix step) where
    step outcome = unResumableT (f (result outcome))
    -- Examined lazily, so the knot can still be tied for finished steps.
    result outcome = case outcome of
      Right a -> a
      Left _ ->
        error "mfix fails when applied to a yielding ResumableT computation"