{-# OPTIONS 
   -fglasgow-exts
   #-}

{-# LANGUAGE ScopedTypeVariables, UndecidableInstances, IncoherentInstances #-}

-----------------------------------------------------------------------------
-- Module      :  Control.Monad.RWSX.Lazy
-- Copyright   :  (c) Mark Snyder 2008.
-- License     :  BSD-style
-- Maintainer  :  Mark Snyder, marks@ittc.ku.edu
-- Stability   :  experimental
-- Portability :  non-portable (multi-param classes, functional dependencies)
-----------------------------------------------------------------------------

module Control.Monad.RWSX.Lazy (
    RWSX(..),
    runRWSX,
    evalRWSX,
    execRWSX,
    mapRWSX,
    withRWSX,
    RWSTX(..),
    runRWSTX,
    evalRWSTX,
    execRWSTX,
    mapRWSTX,
    withRWSTX,
    module Control.Monad,
    module Control.Monad.Fix,
    module Control.Monad.Trans,
    module Data.Monoid,
    module Control.Monad.RWSX.Class,
  ) where

import Control.Monad
import Control.Monad.Cont.Class
import Control.Monad.Error.Class
import Control.Monad.Fix
--import Control.Monad.RWS.Class
--import Control.Monad.Reader.Class
--import Control.Monad.State.Class
import Control.Monad.Trans
--import Control.Monad.Writer.Class
import Data.Monoid

import Control.Monad.Index

import Control.Monad.ErrorX.Class
import Control.Monad.RWSX.Class
import Control.Monad.ReaderX.Class
import Control.Monad.StateX.Class
import Control.Monad.WriterX.Class

-- | An index-tagged reader/writer/state computation.  The constructor carries
-- an @ix@ value alongside the step function @r -> s -> (result, state, output)@;
-- the @*X@ class methods take an index argument of this type, so the index
-- selects which layer an operation targets.
data (Index ix) => RWSX ix r w s a = RWSX ix (r -> s -> (a,s,w))
-- | Unwrap an 'RWSX' computation into its step function.  The leading index
-- argument only fixes the type; neither index value is inspected.
runRWSX :: (Index ix) => ix -> RWSX ix r w s a -> (r -> s -> (a,s,w))
runRWSX _ (RWSX _ body) = body
--instead of... newtype RWS r w s a = RWS { runRWS :: r -> s -> (a, s, w) }

-- | Run an 'RWSX' computation from environment @r@ and initial state @s@,
-- keeping the result and accumulated output and dropping the final state.
evalRWSX :: (Index ix) => ix -> RWSX ix r w s a -> r -> s -> (a, w)
evalRWSX ixv m r s = (result, output)
  where
    -- Lazy pattern binding, same laziness as the original 'let'.
    (result, _, output) = runRWSX ixv m r s

-- | Run an 'RWSX' computation from environment @r@ and initial state @s@,
-- keeping the final state and accumulated output and dropping the result.
execRWSX :: (Index ix) => ix -> RWSX ix r w s a -> r -> s -> (s, w)
execRWSX ixv m r s = (finalState, output)
  where
    -- Lazy pattern binding, same laziness as the original 'let'.
    (_, finalState, output) = runRWSX ixv m r s

-- | Post-process the whole (result, state, output) triple of an 'RWSX'
-- computation with @f@.
mapRWSX :: (Index ix) => ix -> ((a, s, w) -> (b, s, w')) -> RWSX ix r w s a -> RWSX ix r w' s b
mapRWSX ixv f m = RWSX ixv step
  where
    step env st = f (runRWSX ixv m env st)

-- | Pre-process the environment and initial state with @f@ before running
-- @m@.  The pair produced by @f@ is taken apart lazily.
withRWSX :: (Index ix) => ix -> (r' -> s -> (r, s)) -> RWSX ix r w s a -> RWSX ix r' w s a
withRWSX ixv f m = RWSX ixv step
  where
    step outerEnv st =
        let (innerEnv, innerSt) = f outerEnv st
        in  runRWSX ixv m innerEnv innerSt

instance (Index ix) => Functor (RWSX ix r w s) where
    -- Map over the result only; state and output pass through unchanged.
    -- The irrefutable pattern mirrors a lazy 'let': nothing is forced until a
    -- component is demanded.
    fmap f m = RWSX ixv $ \r s ->
        case runRWSX ixv m r s of
          ~(a, s', w) -> (f a, s', w)
      where
        ixv = getVal :: ix

instance (Monoid w, Index ix) => Monad (RWSX ix r w s) where
    -- Yield @a@, leaving the state alone and writing nothing.
    return a = RWSX ixv $ \_ s -> (a, s, mempty)
      where
        ixv = getVal :: ix
    -- Run @m@, feed its result and state into @k@, and append both outputs.
    -- Irrefutable patterns keep sequencing as lazy as 'let' bindings.
    m >>= k = RWSX ixv $ \r s ->
        case runRWSX ixv m r s of
          ~(a, s1, w1) -> case runRWSX ixv (k a) r s1 of
            ~(b, s2, w2) -> (b, s2, w1 `mappend` w2)
      where
        ixv = getVal :: ix

instance (Monoid w, Index ix) => MonadFix (RWSX ix r w s) where
    -- Tie the knot: the result @a@ of running @f a@ is fed back into @f@ via
    -- the recursive 'let'.  The triple is rebuilt rather than returned as-is
    -- so the outer tuple is available without forcing the inner computation.
    mfix f = RWSX ixv $ \r s ->
        let (a, s', w) = runRWSX ixv (f a) r s
        in  (a, s', w)
      where
        ixv = getVal :: ix

instance (Monoid w, Index ix) => MonadReaderX ix r (RWSX ix r w s) where
    -- Expose the environment; state is untouched and nothing is written.
    askx ixv = RWSX ixv (\env st -> (env, st, mempty))
    -- Run @m@ with the environment transformed by @f@.
    localx ixv f m = RWSX ixv (\env st -> runRWSX ixv m (f env) st)

instance (Monoid w, Index ix) => MonadWriterX ix w (RWSX ix r w s) where
    -- Emit @w@ as this step's output.
    tellx ixv w = RWSX ixv (\_ st -> ((), st, w))
    -- Run @m@ and also return the output it produced (which is still emitted).
    listenx ixv m = RWSX ixv $ \env st ->
        case runRWSX ixv m env st of
          ~(a, st', out) -> ((a, out), st', out)
    -- Run @m@, which yields a result and an output-transformer @f@; emit
    -- @f@ applied to @m@'s output.
    passx ixv m = RWSX ixv $ \env st ->
        case runRWSX ixv m env st of
          ~((a, f), st', out) -> (a, st', f out)

instance (Monoid w, Index ix) => MonadStateX ix s (RWSX ix r w s) where
    -- Return the current state without changing it.
    getx ixv = RWSX ixv (\_ st -> (st, st, mempty))
    -- Replace the state with @new@.
    putx ixv new = RWSX ixv (\_ _ -> ((), new, mempty))

-- | 'RWSX' is a complete RWS monad at its own index.  The body is empty; the
-- reader, writer and state operations come from the 'MonadReaderX',
-- 'MonadWriterX' and 'MonadStateX' instances for @RWSX ix r w s@.
instance (Monoid w, Index ix) => MonadRWSX ix r w s (RWSX ix r w s) where

-- ---------------------------------------------------------------------------
-- Our parameterizable RWS monad, with an inner monad

-- | Transformer version of 'RWSX': an index-tagged step function
-- @r -> s -> m (result, state, output)@ running in the inner monad @m@.
data (Index ix) => RWSTX ix r w s m a = RWSTX ix (r -> s -> m (a,s,w))
-- | Run an 'RWSTX' computation with environment @r@ and initial state @s@.
-- The index argument only fixes the type; its value is not inspected.
runRWSTX :: (Index ix) => ix -> RWSTX ix r w s m a -> r -> s -> m (a,s,w)
runRWSTX _ (RWSTX _ step) env st = step env st
--instead of... newtype RWST r w s m a = RWST { runRWST :: r -> s -> m (a, s, w) }

-- | Run an 'RWSTX' computation, keeping the result and output and dropping
-- the final state.
evalRWSTX :: (Monad m, Index ix) => ix -> RWSTX ix r w s m a -> r -> s -> m (a, w)
evalRWSTX ixv m r s = liftM dropState (runRWSTX ixv m r s)
  where
    -- Irrefutable so the triple is only taken apart on demand.
    dropState ~(a, _, w) = (a, w)

-- | Run an 'RWSTX' computation, keeping the final state and output and
-- dropping the result.
execRWSTX :: (Monad m, Index ix) => ix -> RWSTX ix r w s m a -> r -> s -> m (s, w)
execRWSTX ixv m r s = liftM dropResult (runRWSTX ixv m r s)
  where
    -- Irrefutable so the triple is only taken apart on demand.
    dropResult ~(_, s', w) = (s', w)

-- | Transform the inner-monad action produced by an 'RWSTX' computation,
-- possibly changing the inner monad, result type and output type.
mapRWSTX :: (Index ix) => ix -> (m (a, s, w) -> n (b, s, w')) -> RWSTX ix r w s m a -> RWSTX ix r w' s n b
mapRWSTX ixv f m = RWSTX ixv step
  where
    step env st = f (runRWSTX ixv m env st)

-- | Pre-process the environment and initial state with @f@ before running
-- @m@.  The pair produced by @f@ is taken apart lazily.
withRWSTX :: (Index ix) => ix -> (r' -> s -> (r, s)) -> RWSTX ix r w s m a -> RWSTX ix r' w s m a
withRWSTX ixv f m = RWSTX ixv step
  where
    step outerEnv st =
        let (innerEnv, innerSt) = f outerEnv st
        in  runRWSTX ixv m innerEnv innerSt

instance (Monad m, Index ix) => Functor (RWSTX ix r w s m) where
    -- Map over the result; the lazy lambda pattern is exactly what the
    -- original @~pat <- ...@ do-statement desugars to.
    fmap f m = RWSTX ixv $ \env st ->
        runRWSTX ixv m env st >>= \ ~(a, st', out) -> return (f a, st', out)
      where
        ixv = getVal :: ix

instance (Monoid w, Monad m, Index ix) => Monad (RWSTX ix r w s m) where
    -- Yield @a@ in the inner monad with unchanged state and empty output.
    return a = RWSTX ixv (\_ st -> return (a, st, mempty))
      where
        ixv = getVal :: ix
    -- Thread state through @m@ then @k a@ and append both outputs; lazy
    -- lambda patterns match the original do-block's @~@ bindings.
    m >>= k = RWSTX ixv $ \env st ->
        runRWSTX ixv m env st >>= \ ~(a, st1, out1) ->
        runRWSTX ixv (k a) env st1 >>= \ ~(b, st2, out2) ->
        return (b, st2, out1 `mappend` out2)
      where
        ixv = getVal :: ix
    -- Delegate failure to the inner monad.
    fail msg = RWSTX ixv (\_ _ -> fail msg)
      where
        ixv = getVal :: ix

instance (Monoid w, MonadPlus m, Index ix) => MonadPlus (RWSTX ix r w s m) where
    -- The inner monad's 'mzero', whatever the environment and state.
    mzero = RWSTX ixv (\_ _ -> mzero)
      where
        ixv = getVal :: ix
    -- Run both alternatives from the same environment and state and combine
    -- them with the inner monad's 'mplus'.
    m `mplus` n = RWSTX ixv $ \env st ->
        mplus (runRWSTX ixv m env st) (runRWSTX ixv n env st)
      where
        ixv = getVal :: ix

instance (Monoid w, MonadFix m, Index ix) => MonadFix (RWSTX ix r w s m) where
    -- Delegate the knot-tying to the inner monad's 'mfix'; the irrefutable
    -- pattern lets the fed-back result stay lazy.
    mfix f = RWSTX ixv $ \env st ->
        let feedBack ~(a, _, _) = runRWSTX ixv (f a) env st
        in  mfix feedBack
      where
        ixv = getVal :: ix

instance (Monoid w, Monad m, Index ix) => MonadReaderX ix r (RWSTX ix r w s m) where
    -- Expose the environment; state is untouched and nothing is written.
    askx ixv = RWSTX ixv (\env st -> return (env, st, mempty))
    -- Run @m@ with the environment transformed by @f@.
    localx ixv f m = RWSTX ixv (\env st -> runRWSTX ixv m (f env) st)

instance (Monoid w, Monad m, Index ix) => MonadWriterX ix w (RWSTX ix r w s m) where
    -- Emit @w@ as this step's output.
    tellx ixv w = RWSTX ixv (\_ st -> return ((), st, w))
    -- Run @m@ and also return the output it produced (which is still emitted).
    listenx ixv m = RWSTX ixv $ \env st ->
        runRWSTX ixv m env st >>= \ ~(a, st', out) -> return ((a, out), st', out)
    -- Run @m@, which yields a result and an output-transformer @f@; emit
    -- @f@ applied to @m@'s output.
    passx ixv m = RWSTX ixv $ \env st ->
        runRWSTX ixv m env st >>= \ ~((a, f), st', out) -> return (a, st', f out)

instance (Monoid w, Monad m, Index ix) => MonadStateX ix s (RWSTX ix r w s m) where
    -- Return the current state without changing it.
    getx ixv = RWSTX ixv (\_ st -> return (st, st, mempty))
    -- Replace the state with @new@.
    putx ixv new = RWSTX ixv (\_ _ -> return ((), new, mempty))

-- | 'RWSTX' at index @ix@ is a complete RWS monad at that index for any inner
-- monad @m@.  The body is empty; the reader, writer and state operations come
-- from the 'MonadReaderX', 'MonadWriterX' and 'MonadStateX' instances for
-- @RWSTX ix r w s m@, which need only @Monoid w@, @Monad m@ and @Index ix@.
--
-- The inner monad does NOT have to be an RWS monad at @ix@ itself.  Requiring
-- that (e.g. @MonadRWSX ix r w s m@) would make this instance unusable for
-- ordinary stacks such as @RWSTX ix r w s IO@.
instance (Monoid w, Monad m, Index ix)
    => MonadRWSX ix r w s (RWSTX ix r w s m) where


{- class (Monoid w, Index ix, MonadReaderX ix r m, MonadWriterX ix w m, MonadStateX ix s m)
   => MonadRWSX ix r w s m | ix m -> r, ix m -> w, ix m -> s -}

-- ---------------------------------------------------------------------------
-- Instances of other mtl classes (MonadTrans, MonadIO, MonadCont, MonadError)

instance (Monoid w, Index ix) => MonadTrans (RWSTX ix r w s) where
    -- Run the inner action, keep the state and write nothing.
    lift inner = RWSTX ixv $ \_ st -> liftM (\a -> (a, st, mempty)) inner
      where
        ixv = getVal :: ix

instance (Monoid w, MonadIO m, Index ix) => MonadIO (RWSTX ix r w s m) where
    -- Lift the IO action into the inner monad, then through this layer.
    liftIO act = lift (liftIO act)

instance (Monoid w, MonadCont m, Index ix) => MonadCont (RWSTX ix r w s m) where
    -- Capture the inner continuation.  The escape function hands it the
    -- result, the state current at the escape point, and empty output.
    callCC f = RWSTX ixv $ \env st ->
        callCC $ \resume ->
            let escape a = RWSTX ixv (\_ st' -> resume (a, st', mempty))
            in  runRWSTX ixv (f escape) env st
      where
        ixv = getVal :: ix




-- Error
instance (Monoid w, MonadError e m, Index ix) => MonadError e (RWSTX ix r w s m) where
    -- Raise the error in the inner monad.
    throwError err = lift (throwError err)
    -- Catch in the inner monad; the handler restarts from the environment
    -- and state that were current when @m@ began.
    catchError m h = RWSTX ixv $ \env st ->
        catchError (runRWSTX ixv m env st) (\e -> runRWSTX ixv (h e) env st)
      where
        ixv = getVal :: ix




-- ErrorX
instance (Monoid w, MonadErrorX ixe e m, Index ix, Index ixe) => MonadErrorX ixe e (RWSTX ix r w s m) where
    -- Raise the error at index @ixe@ in the inner monad.
    throwErrorx ixv err = lift (throwErrorx ixv err)
    -- Catch at index @ixe@ in the inner monad; the handler restarts from
    -- the environment and state that were current when @m@ began.
    catchErrorx ixv m h = RWSTX outer $ \env st ->
        catchErrorx ixv
                    (runRWSTX outer m env st)
                    (\err -> runRWSTX outer (h err) env st)
      where
        outer = getVal :: ix

--instance (Monoid w1, Monoid w2, Index ix1, Index ix2,
--         MonadReaderX ix1 r1 m, MonadWriterX ix1 w1 m, MonadStateX ix1 s1 m)
--    => MonadRWSX ix1 r1 w1 s1 (RWSTX ix2 r2 w2 s2 m) where

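-- NOTE(review): the original TODO asked for these three lifting instances for
-- R, W and S on RWSTX (so that ManualTests.hs can load).  They pass
-- operations at index ixs/ixr/ixw through an RWSTX layer at index ix3 to the
-- inner monad.  They are defined below; confirm whether anything is still
-- unfinished.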


instance (MonadStateX  ixs s m, Monad m, Monoid w3, Index ixs, Index ix3)
    => MonadStateX     ixs s (RWSTX ix3 r3 w3 s3 m) where
    -- Both operations go to the inner monad's state at index @ixs@; this
    -- layer's own state @s3@ is untouched.
    getx ixv = lift (getx ixv)
    putx ixv v = lift (putx ixv v)

instance (MonadReaderX ixr r m, Monad m, Monoid w3, Index ixr, Index ix3)
    => MonadReaderX    ixr r (RWSTX ix3 r3 w3 s3 m) where
    -- Read the inner monad's environment at index @ixr@.
    askx ixv = lift (askx ixv)
    -- Apply the inner 'localx' around the whole inner action of @m@.  This
    -- layer's own environment is passed through unchanged.
    localx ixv f m = RWSTX outer $ \env st ->
        localx ixv f (runRWSTX outer m env st)
      where
        outer = getVal :: ix3

instance (MonadWriterX ixw w m, Monad m, Monoid w3, Monoid w, Index ixw, Index ix3)
    => MonadWriterX    ixw w (RWSTX ix3 r3 w3 s3 m) where
    -- Write to the inner monad's output at index @ixw@.
    tellx ixv w = lift (tellx ixv w)
    -- Listen on the inner writer while running @m@'s step.  The inner output
    -- (@heard@) is paired with the result; this layer's own output (@out@)
    -- passes through unchanged.
    listenx ixv m = RWSTX outer $ \env st ->
        listenx ixv (runRWSTX outer m env st) >>= \ ~((a, st', out), heard) ->
            return ((a, heard), st', out)
      where
        outer = getVal :: ix3
    -- Run @m@'s step; hand its output-transformer @f@ to the inner 'passx',
    -- which applies it to the inner output, and return the step's triple.
    passx ixv m = RWSTX outer $ \env st ->
        passx ixv $ runRWSTX outer m env st >>= \ ~((a, f :: w -> w), st', out) ->
            return ((a, st', out), f)
      where
        outer = getVal :: ix3

{-
instance (Monoid w, Monad m, Index ix) => MonadReaderX ix r (RWSTX ix r w s m) where
    askx (ixv::ix)       = RWSTX ixv $ \r s -> return (r, s, mempty)
    localx (ixv::ix) f m = RWSTX ixv $ \r s -> runRWSTX ixv m (f r) s

instance (Monoid w, Monad m, Index ix) => MonadWriterX ix w (RWSTX ix r w s m) where
    tellx (ixv::ix)   w = RWSTX ixv $ \_ s -> return ((),s,w)
    listenx (ixv::ix) m = RWSTX ixv $ \r s -> do
        ~(a, s', w) <- runRWSTX ixv m r s
        return ((a, w), s', w)
    passx (ixv::ix)   m = RWSTX ixv $ \r s -> do
        ~((a, f), s', w) <- runRWSTX ixv m r s
        return (a, s', f w)

instance (Monoid w, Monad m, Index ix) => MonadStateX ix s (RWSTX ix r w s m) where
    getx (ixv::ix)   = RWSTX ixv $ \_ s -> return (s, s, mempty)
    putx (ixv::ix) s = RWSTX ixv $ \_ _ -> return ((), s, mempty)
-}