{-# OPTIONS
            -fno-warn-orphans
  #-}

{-# LANGUAGE ScopedTypeVariables, UndecidableInstances, OverlappingInstances, FlexibleInstances, MultiParamTypeClasses #-}

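{- |
Module      :  Control.Monad.ReaderX
Copyright   :  (c) Mark Snyder 2008.
License     :  BSD-style
Maintainer  :  Mark Snyder, marks@ittc.ku.edu
Stability   :  experimental
Portability :  non-portable (multi-param classes, functional dependencies, overlapping instances)

Indexed reader monads.  'ReaderX' and its transformer 'ReaderTX' carry a
phantom index type @ix@, so that several reader layers can live in one monad
stack and each be addressed through 'askx' and 'localx' with its own index.
-}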

module Control.Monad.ReaderX (
    module Control.Monad.ReaderX.Class,
    ReaderX(..),
    mkReaderX,
    runReaderX,
    mapReaderx,
    withReaderx,

    ReaderTX(..),
    mkReaderTX,
    runReaderTX,
    mapReaderTX,
    withReaderTX,
    module Control.Monad,
    module Control.Monad.Fix,
    module Control.Monad.Trans,    
    module Control.Monad.Index
    ) where

import Control.Monad
import Control.Monad.Cont.Class
import Control.Monad.Error
import Control.Monad.Fix
import Control.Monad.Instances ()
import Control.Monad.Trans

import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Control.Monad.RWS

import Control.Monad.Index


import Control.Monad.ErrorX
import Control.Monad.ErrorX.Class

import Control.Monad.ReaderX.Class
import Control.Monad.StateX.Class
import Control.Monad.WriterX.Class


-- ----------------------------------------------------------------------------
-- The partially applied function type is a simple reader monad

instance (Index ix) => MonadReaderX ix r ((->) r) where
    -- The environment is simply the function's argument, handed back as-is.
    askx   (_::ix)               = \env -> env
    -- Run the reader against a modified environment.
    localx (_::ix) modify reader = \env -> reader (modify env)

-- | A reader computation over environment @r@.  The index @ix@ is phantom:
-- no value of it is stored, it only tags which reader this is.
newtype ReaderX ix r a = ReaderX {runReaderX' :: r -> a}

-- | Wrap a plain function as a 'ReaderX'.  The index argument is ignored at
-- runtime; it only fixes the type @ix@.
mkReaderX :: (Index ix) => ix -> (r->a) -> ReaderX ix r a
mkReaderX _ = ReaderX

-- | Run a 'ReaderX' against an environment.  The index argument only selects
-- the type.
runReaderX :: (Index ix) => ix -> ReaderX ix r a -> (r -> a)
runReaderX _ (ReaderX fn) env = fn env

-- | Transform the result of a 'ReaderX' computation; the environment is
-- passed through unchanged.
mapReaderx :: (Index ix) => ix -> (a -> b) -> ReaderX ix r a -> ReaderX ix r b
mapReaderx ixv f m = mkReaderX ixv (\env -> f (runReaderX' m env))

-- | A more general version of 'local': the supplied function converts the
-- outer environment of type @r'@ into the @r@ the wrapped computation
-- expects, so the environment type may change.
withReaderx :: (Index ix) => ix -> (r' -> r) -> ReaderX ix r a -> ReaderX ix r' a
withReaderx ixv f m = mkReaderX ixv (\env -> runReaderX' m (f env))

-- Mapping over a 'ReaderX' post-composes the function with the reader.
instance (Index ix) => Functor (ReaderX ix r) where
    fmap f m = ReaderX (f . runReaderX' m)

-- 'return' ignores the environment; '>>=' feeds the same environment to the
-- first computation and to the continuation built from its result.
instance (Index ix) => Monad (ReaderX ix r) where
    return a = ReaderX (const a)
    m >>= k  = ReaderX $ \env ->
        let a = runReaderX' m env
        in  runReaderX' (k a) env

-- Value recursion: the result is fed back into 'f' lazily, under the same
-- environment.
instance (Index ix) => MonadFix (ReaderX ix r) where
    mfix f = ReaderX $ \env -> fix (\a -> runReaderX' (f a) env)

-- A 'ReaderX' is a reader for its own index: 'askx' yields the environment,
-- 'localx' runs the computation on a modified environment.
instance (Index ix) => MonadReaderX ix r (ReaderX ix r) where
    askx   ixv          = mkReaderX ixv (\env -> env)
    localx ixv modify m = mkReaderX ixv (\env -> runReaderX' m (modify env))

-- | The indexed reader monad transformer: adds environment reading to an
-- inner monad @m@.  The index @ix@ is phantom (no value of it is stored).
newtype ReaderTX ix r m a = ReaderTX { runReaderTX' :: r -> m a }

-- | Wrap an environment-consuming action as a 'ReaderTX'.  The index argument
-- is ignored at runtime; it only fixes the type @ix@.
mkReaderTX :: (Index ix) => ix -> (r->m a) -> ReaderTX ix r m a
mkReaderTX _ = ReaderTX

-- | Run a 'ReaderTX' against an environment, yielding the inner action.
runReaderTX :: (Index ix) => ix -> ReaderTX ix r m a -> r -> m a
runReaderTX _ (ReaderTX act) env = act env

-- | Transform the inner action produced by a 'ReaderTX'; both the inner monad
-- and the result type may change, the environment type may not.
mapReaderTX :: (Index ix) => ix -> (m a -> n b) -> ReaderTX ix r m a -> ReaderTX ix r n b
mapReaderTX ixv f m = mkReaderTX ixv (\env -> f (runReaderTX ixv m env))

-- | Run a 'ReaderTX' under an environment derived from a (possibly
-- differently typed) outer environment.
withReaderTX :: (Index ix) => ix -> (r' -> r) -> ReaderTX ix r m a -> ReaderTX ix r' m a
withReaderTX ixv f m = mkReaderTX ixv (\env -> runReaderTX ixv m (f env))

-- Map over the inner action's result; the environment is passed through.
instance (Monad m, Index ix) => Functor (ReaderTX ix r m) where
    fmap f m = ReaderTX (liftM f . runReaderTX' m)

-- Sequencing threads one environment through both steps; 'return' and
-- 'fail' ignore the environment and defer to the inner monad.
instance (Monad m, Index ix) => Monad (ReaderTX ix r m) where
    return a = ReaderTX (const (return a))
    m >>= k  = ReaderTX $ \env ->
        runReaderTX' m env >>= \a -> runReaderTX' (k a) env
    fail msg = ReaderTX (const (fail msg))

-- Choice is delegated to the inner monad; both alternatives see the same
-- environment.
instance (MonadPlus m, Index ix) => MonadPlus (ReaderTX ix r m) where
    mzero       = ReaderTX (const mzero)
    m `mplus` n = ReaderTX $ \env ->
        mplus (runReaderTX' m env) (runReaderTX' n env)

-- Value recursion is delegated to the inner monad's 'mfix', with the
-- environment fixed for every unfolding.
instance (MonadFix m, Index ix) => MonadFix (ReaderTX ix r m) where
    mfix f = ReaderTX $ \env ->
        let step a = runReaderTX' (f a) env
        in  mfix step

-- A 'ReaderTX' answers indexed reads for its own index directly: 'askx'
-- returns the environment, 'localx' pre-composes the modifier.
instance (Monad m, Index ix) => MonadReaderX ix r (ReaderTX ix r m) where
    askx   ixv          = mkReaderTX ixv (\env -> return env)
    localx ixv modify m = mkReaderTX ixv (runReaderTX ixv m . modify)

-- ---------------------------------------------------------------------------
-- Instances for other mtl transformers

-- Lifting ignores the environment.
instance (Index ix) => MonadTrans (ReaderTX ix r) where
    lift m = ReaderTX (const m)

-- IO is lifted into the inner monad first, then through this layer.
instance (MonadIO m, Index ix) => MonadIO (ReaderTX ix r m) where
    liftIO io = lift (liftIO io)

-- 'callCC' is delegated to the inner monad; the escape continuation handed
-- to 'f' is wrapped so that it ignores whatever environment it is run in.
instance (MonadCont m, Index ix) => MonadCont (ReaderTX ix r m) where
    callCC f = ReaderTX $ \env -> callCC $ \exit ->
        let jump a = ReaderTX (const (exit a))
        in  runReaderTX' (f jump) env




-- Error
-- Errors are thrown in the inner monad; a handler runs with the same
-- environment as the computation it protects.
instance (MonadError e m, Index ix) => MonadError e (ReaderTX ix r m) where
    throwError err       = lift (throwError err)
    catchError m handler = ReaderTX $ \env ->
        catchError (runReaderTX' m env) (\exc -> runReaderTX' (handler exc) env)

-- Indexed reads pass through an mtl 'ErrorT' layer: 'askx' is lifted,
-- 'localx' is applied to the unwrapped inner action.
instance (Index ix, Error e, MonadReaderX ix r m) => MonadReaderX ix r (ErrorT e m) where
    askx   ixv          = lift (askx ixv)
    localx ixv modify m = ErrorT (localx ixv modify (runErrorT m))



--ErrorX
-- Indexed errors pass through a 'ReaderTX' layer: throwing is lifted, and a
-- handler runs with the same environment as the protected computation.
instance (MonadErrorX ixe e m, Index ixe, Index ixr) => MonadErrorX ixe e (ReaderTX ixr r m) where
    throwErrorx ixv err       = lift (throwErrorx ixv err)
    catchErrorx ixv m handler = ReaderTX $ \env ->
        catchErrorx ixv (runReaderTX' m env) (\exc -> runReaderTX' (handler exc) env)

--Reader
-- The un-indexed mtl 'ask'/'local' refer to the inner monad's environment
-- @r@, not to this layer's own environment @r2@.
instance (MonadReader r m, Index ix) => MonadReader r (ReaderTX ix r2 m) where
    ask       = lift ask
    local f m = ReaderTX (local f . runReaderTX' m)

-- Indexed reads pass through an mtl 'ReaderT' layer, ignoring its own
-- environment @r2@.
instance (Monad m, MonadReaderX ix r1 m, Index ix) =>
    MonadReaderX ix r1 (ReaderT r2 m) where
    askx   ixv          = ReaderT (const (askx ixv))
    localx ixv modify m = ReaderT (localx ixv modify . runReaderT m)

--ReaderX
-- Reads for index @ix1@ pass through a 'ReaderTX' tagged with a different
-- index @ix2@, ignoring that layer's environment @r2@.  The instance for a
-- matching index and environment is the more specific one and takes
-- precedence (hence OverlappingInstances).
instance (Index ix1, Index ix2, MonadReaderX ix1 r1 m) =>
    MonadReaderX ix1 r1 (ReaderTX ix2 r2 m) where
    askx   ixv          = ReaderTX (const (askx ixv))
    localx ixv modify m = ReaderTX (localx ixv modify . runReaderTX' m)

-- State
-- State operations are lifted unchanged into the inner monad.
instance (Index ix, MonadState s m) => MonadState s (ReaderTX ix r m) where
    get = lift get
    put = lift . put

-- Indexed reads pass through an mtl 'StateT' layer: 'askx' is lifted,
-- 'localx' wraps the inner action produced for each incoming state.
instance (Index ix, MonadReaderX ix r m) => MonadReaderX ix r (StateT s m) where
    askx   ixv          = lift (askx ixv)
    localx ixv modify m = StateT (localx ixv modify . runStateT m)

-- StateX
-- Indexed state operations are lifted unchanged into the inner monad.
instance (Index ixr, Index ixs, MonadStateX ixs s m) => MonadStateX ixs s (ReaderTX ixr r m) where
    getx ixv          = lift (getx ixv)
    putx ixv newState = lift (putx ixv newState)

-- Writer
-- Writer operations act on the inner monad; 'listen' and 'pass' wrap the
-- inner action produced for the current environment.
instance (Index ix, MonadWriter w m) => MonadWriter w (ReaderTX ix r m) where
    tell out = lift (tell out)
    listen m = ReaderTX (listen . runReaderTX' m)
    pass   m = ReaderTX (pass . runReaderTX' m)

-- Indexed reads pass through an mtl 'WriterT' layer: 'askx' is lifted,
-- 'localx' is applied to the unwrapped inner action.
instance (Index ix, MonadReaderX ix r m, Monoid w) => MonadReaderX ix r (WriterT w m) where
    askx   ixv          = lift (askx ixv)
    localx ixv modify m = WriterT (localx ixv modify (runWriterT m))

-- WriterX
-- Indexed writer operations act on the inner monad; 'listenx' and 'passx'
-- wrap the inner action produced for the current environment.
instance (Index ixr, MonadWriterX ixw w m) => MonadWriterX ixw w (ReaderTX ixr r m) where
    tellx   ixv out = lift (tellx ixv out)
    listenx ixv m   = ReaderTX (listenx ixv . runReaderTX' m)
    passx   ixv m   = ReaderTX (passx ixv . runReaderTX' m)

-- RWS
-- 'MonadRWS' for 'ReaderTX' defines no methods of its own; it relies on the
-- 'MonadReader', 'MonadState' and 'MonadWriter' instances for 'ReaderTX',
-- which forward to the inner monad @m@.  Note that the reader part refers
-- to the inner monad's environment @r@, not to this layer's @r2@.
instance (Monoid w, Index ix2, MonadReader r m, MonadState s m, MonadWriter w m) => MonadRWS r w s (ReaderTX ix2 r2 m) where

-- Indexed reads pass through an mtl 'RWST' layer.  'askx' runs the inner
-- monad's 'askx', leaves the RWST state unchanged and writes nothing
-- ('mempty').  'localx' wraps the inner action produced for the given RWST
-- environment and state.  The caller's index value is forwarded unchanged
-- rather than being replaced by a fresh 'getVal'.
instance (Monoid w2, Monad m, Index ix1, MonadReaderX ix1 r1 m) => MonadReaderX ix1 r1 (RWST r2 w2 s2 m) where
    askx   ixv          = RWST $ \_ s2 -> askx ixv >>= \r1 -> return (r1, s2, mempty)
    localx ixv modify m = RWST $ \r2 s2 -> localx ixv modify (runRWST m r2 s2)