{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}

-- Overlapping/incoherent instances are used so that instance selection can
-- depend on whether indexes match or not.
{-# LANGUAGE IncoherentInstances #-}
{-# LANGUAGE OverlappingInstances #-}

{-# LANGUAGE TypeFamilies #-}

-----------------------------------------------------------------------------
-- Module      :  Control.Monad.RWSX.Strict
-- Copyright   :  (c) Mark Snyder 2008.
-- License     :  BSD-style
-- Maintainer  :  Mark Snyder, marks@ittc.ku.edu
-- Stability   :  experimental
-- Portability :  non-portable (multi-param classes, functional dependencies)
-----------------------------------------------------------------------------

module Control.Monad.RWSX.Strict (
    -- * The indexed RWS monad
    RWSX(..),
    mkRWSX,
    runRWSX,
    evalRWSX,
    execRWSX,
    mapRWSX,
    withRWSX,
    -- * The indexed RWS monad transformer
    RWSTX(..),
    mkRWSTX,
    runRWSTX,
    evalRWSTX,
    execRWSTX,
    mapRWSTX,
    withRWSTX,
    -- * Re-exported modules
    module Control.Monad.RWSX.Class,
    module Control.Monad,
    module Control.Monad.Fix,
    module Control.Monad.Trans,
    module Data.Monoid,
  ) where

import Control.Monad
import Control.Monad.Cont.Class
import Control.Monad.Error.Class
import Control.Monad.Fix

import Control.Monad.RWS.Class
import Control.Monad.RWS
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Control.Monad.Trans
import Data.Monoid

import Control.Monad.Index

import Control.Monad.ErrorX.Class

import Control.Monad.ReaderX
import Control.Monad.StateX
import Control.Monad.WriterX
import Control.Monad.ReaderX.Class()
import Control.Monad.StateX.Class()
import Control.Monad.WriterX.Class()
import Control.Monad.RWSX.Class

-- ---------------------------------------------------------------------------
-- Earlier design, kept for reference: the index was carried as a value inside
-- a @data@ constructor instead of appearing only as a type parameter.
--data (Index ix) => RWSX ix r w s a = RWSX ix (r -> s -> (a,s,w))
--runRWSX :: (Index ix) => ix -> RWSX ix r w s a -> (r -> s -> (a,s,w))
--runRWSX (_::ix) (RWSX (_::ix) f) r s = f r s

-- | An indexed reader\/writer\/state computation: a function from an
-- environment @r@ and a state @s@ to a result, a new state and an output @w@.
-- The index @ix@ is a phantom parameter; it does not occur in the wrapped
-- function.  'runRWSX'' unwraps without taking an index argument.
newtype RWSX ix r w s a = RWSX { runRWSX' :: r -> s -> (a, s, w) }
-- | Wrap a raw @r -> s -> (a, s, w)@ function as an 'RWSX' computation.
-- The index argument is ignored; it only fixes the phantom type @ix@.
mkRWSX :: (Index ix) => ix -> (r -> s -> (a, s, w)) -> RWSX ix r w s a
mkRWSX _ = RWSX
-- | Run an 'RWSX' computation with the given environment and initial state,
-- yielding the result, the final state and the output.  The index argument is
-- ignored at the value level.
runRWSX :: (Index ix) => ix -> RWSX ix r w s a -> r -> s -> (a, s, w)
runRWSX _ (RWSX step) env st = step env st

-- | Run an 'RWSX' computation and keep only the result and the output,
-- discarding the final state.
evalRWSX :: (Index ix) => ix -> RWSX ix r w s a -> r -> s -> (a, w)
evalRWSX _ m r s = dropState (runRWSX' m r s)
  where
    -- Matching in the argument position forces the triple, just like a @case@.
    dropState (a, _, w) = (a, w)

-- | Run an 'RWSX' computation and keep only the final state and the output,
-- discarding the result.
execRWSX :: (Index ix) => ix -> RWSX ix r w s a -> r -> s -> (s, w)
execRWSX _ m r s = dropResult (runRWSX' m r s)
  where
    -- Matching in the argument position forces the triple, just like a @case@.
    dropResult (_, s', w) = (s', w)

-- | Post-process the @(result, state, output)@ triple produced by a
-- computation.  The supplied function may change the result and output types
-- but not the state type.
mapRWSX :: (Index ix) => ix -> ((a, s, w) -> (b, s, w')) -> RWSX ix r w s a -> RWSX ix r w' s b
mapRWSX _ f (RWSX step) = RWSX (\r s -> f (step r s))

-- | Pre-process the environment and state before running a computation.
-- The supplied function receives the outer environment @r'@ and the current
-- state, and produces the environment and state the wrapped computation sees.
withRWSX :: (Index ix) => ix -> (r' -> s -> (r, s)) -> RWSX ix r w s a -> RWSX ix r' w s a
withRWSX _ f (RWSX step) = RWSX $ \outerEnv st ->
    -- Lazy pattern binding: the pair is only forced if @step@ demands it.
    let (env, st') = f outerEnv st
    in  step env st'

-- | Map over the result of an 'RWSX' computation, leaving the state and the
-- output untouched.
instance (Index ix) => Functor (RWSX ix r w s) where
    fmap f (RWSX step) = RWSX $ \r s -> onResult (step r s)
      where
        onResult (a, s', w) = (f a, s', w)

-- | Sequencing for 'RWSX'.  'return' leaves the state unchanged and emits
-- 'mempty'.  '>>=' runs both computations under the same environment, threads
-- the state from the first into the second, and combines the outputs with
-- 'mappend' (first computation's output on the left).
instance (Monoid w, Index ix) => Monad (RWSX ix r w s) where
    return a = RWSX $ \_ s -> (a, s, mempty)
    RWSX step >>= k = RWSX $ \r s0 ->
        case step r s0 of
            (a, s1, w1) -> case runRWSX' (k a) r s1 of
                (b, s2, w2) -> (b, s2, w1 `mappend` w2)

-- | Value recursion for 'RWSX': the result of the computation is fed back in
-- as the argument of the function that builds it.
instance (Monoid w, Index ix) => MonadFix (RWSX ix r w s) where
    mfix f = RWSX knot
      where
        -- The triple is bound lazily so that @a@ can refer to itself.
        knot r s = (a, s', w)
          where
            (a, s', w) = runRWSX' (f a) r s

-- providing the three monads' behaviors
-- | Reader behaviour of 'RWSX' at its own index: 'askx' returns the
-- environment, 'localx' runs a computation under a modified environment.
-- Neither touches the state; 'askx' emits 'mempty'.
instance (Monoid w, Index ix) => MonadReaderX ix r (RWSX ix r w s) where
    askx _ = RWSX $ \env st -> (env, st, mempty)
    localx _ f (RWSX step) = RWSX $ \env st -> step (f env) st


-- | Writer behaviour of 'RWSX' at its own index.
--
-- * 'tellx' emits the given output.
-- * 'listenx' additionally returns the output produced by the computation.
-- * 'passx' applies the function returned by the computation to its output.
instance (Monoid w, Index ix) => MonadWriterX ix w (RWSX ix r w s) where
    tellx _ out = RWSX $ \_ st -> ((), st, out)
    listenx _ (RWSX step) = RWSX $ \env st -> expose (step env st)
      where
        expose (a, st', out) = ((a, out), st', out)
    passx _ (RWSX step) = RWSX $ \env st -> rewrite (step env st)
      where
        rewrite ((a, f), st', out) = (a, st', f out)

-- | State behaviour of 'RWSX' at its own index: 'getx' returns the current
-- state, 'putx' replaces it.  Both emit 'mempty'.
instance (Monoid w, Index ix) => MonadStateX ix s (RWSX ix r w s) where
    getx _ = RWSX $ \_ st -> (st, st, mempty)
    putx _ new = RWSX $ \_ _ -> ((), new, mempty)

-- | 'RWSX' offers the combined indexed interface at its own index.  The
-- instance body is empty: no methods are defined here.
instance (Index ix, Monoid w)
    => MonadRWSX ix r w s (RWSX ix r w s)

-- ---------------------------------------------------------------------------
-- Our parameterizable RWS monad, with an inner monad


-- Earlier design, kept for reference: the index was carried as a value inside
-- a @data@ constructor instead of appearing only as a type parameter.
--data (Index ix) => RWSTX ix r w s m a = RWSTX ix (r -> s -> m (a,s,w))
--runRWSTX :: (Index ix) => ix -> RWSTX ix r w s m a -> (r -> s -> (m (a,s,w)))
--runRWSTX (_::ix) (RWSTX (_::ix) f) = f

-- | An indexed reader\/writer\/state transformer over an inner monad @m@:
-- a function from an environment @r@ and a state @s@ to an @m@-action that
-- yields a result, a new state and an output @w@.  The index @ix@ is a phantom
-- parameter; it does not occur in the wrapped function.  'runRWSTX'' unwraps
-- without taking an index argument.
newtype RWSTX ix r w s m a = RWSTX { runRWSTX' :: r -> s -> m (a, s, w) }
-- | Wrap a raw @r -> s -> m (a, s, w)@ function as an 'RWSTX' computation.
-- The index argument is ignored; it only fixes the phantom type @ix@.
mkRWSTX :: (Index ix) => ix -> (r -> s -> m (a, s, w)) -> RWSTX ix r w s m a
mkRWSTX _ = RWSTX
-- | Unwrap an 'RWSTX' computation into the function it contains.  The index
-- argument is ignored at the value level.
runRWSTX :: (Index ix) => ix -> RWSTX ix r w s m a -> (r -> s -> m (a, s, w))
runRWSTX _ (RWSTX step) = step

-- | Run an 'RWSTX' computation in the inner monad and keep only the result
-- and the output, discarding the final state.
evalRWSTX :: (Monad m, Index ix) => ix -> RWSTX ix r w s m a -> r -> s -> m (a, w)
evalRWSTX _ m r s =
    runRWSTX' m r s >>= \(a, _, w) -> return (a, w)

-- | Run an 'RWSTX' computation in the inner monad and keep only the final
-- state and the output, discarding the result.
execRWSTX :: (Monad m, Index ix) => ix -> RWSTX ix r w s m a -> r -> s -> m (s, w)
execRWSTX _ m r s =
    runRWSTX' m r s >>= \(_, s', w) -> return (s', w)

-- | Transform the inner-monad action produced by a computation.  The supplied
-- function may change the inner monad, the result type and the output type,
-- but not the state type.
mapRWSTX :: (Index ix) => ix -> (m (a, s, w) -> n (b, s, w')) -> RWSTX ix r w s m a -> RWSTX ix r w' s n b
mapRWSTX _ f (RWSTX step) = RWSTX (\r s -> f (step r s))

-- | Pre-process the environment and state before running a computation.
-- The supplied function receives the outer environment @r'@ and the current
-- state, and produces the environment and state the wrapped computation sees.
withRWSTX :: (Index ix) => ix -> (r' -> s -> (r, s)) -> RWSTX ix r w s m a -> RWSTX ix r' w s m a
withRWSTX _ f (RWSTX step) = RWSTX $ \outerEnv st ->
    -- Lazy pattern binding: the pair is only forced if @step@ demands it.
    let (env, st') = f outerEnv st
    in  step env st'

-- | Map over the result of an 'RWSTX' computation, leaving the state and the
-- output untouched.  Implemented with the inner monad's '>>=' and 'return'.
instance (Monad m, Index ix) => Functor (RWSTX ix r w s m) where
    fmap f (RWSTX step) = RWSTX $ \r s ->
        step r s >>= \(a, s', w) -> return (f a, s', w)

-- | Sequencing for 'RWSTX'.  'return' leaves the state unchanged and emits
-- 'mempty'.  '>>=' runs both computations under the same environment, threads
-- the state from the first into the second, and combines the outputs with
-- 'mappend' (first computation's output on the left).  'fail' delegates to the
-- inner monad.
instance (Monoid w, Monad m, Index ix) => Monad (RWSTX ix r w s m) where
    return a = RWSTX $ \_ s -> return (a, s, mempty)
    RWSTX step >>= k = RWSTX $ \r s0 ->
        step r s0 >>= \(a, s1, w1) ->
        runRWSTX' (k a) r s1 >>= \(b, s2, w2) ->
        return (b, s2, w1 `mappend` w2)
    fail msg = RWSTX $ \_ _ -> fail msg

-- | Choice is delegated to the inner monad: both alternatives are run from
-- the same environment and the same starting state.
instance (Monoid w, MonadPlus m, Index ix) => MonadPlus (RWSTX ix r w s m) where
    mzero = RWSTX $ \_ _ -> mzero
    mplus (RWSTX left) (RWSTX right) = RWSTX $ \r s ->
        left r s `mplus` right r s

-- | Value recursion for 'RWSTX', delegated to the inner monad's 'mfix'.
--
-- The triple fed back by the inner 'mfix' is matched with an irrefutable
-- pattern.  A strict match would force the recursive value before the action
-- that produces it has been built, making the fixpoint diverge.
instance (Monoid w, MonadFix m, Index ix) => MonadFix (RWSTX ix r w s m) where
    mfix f = mkRWSTX (getVal::ix) $ \r s ->
        mfix $ \ ~(a, _, _) -> runRWSTX' (f a) r s


-- providing the three monads' behaviors
-- ReaderX
-- | Reader behaviour of 'RWSTX' at its own index: 'askx' returns this layer's
-- environment, 'localx' runs a computation under a modified environment.
instance (Monoid w, Monad m, Index ix) => MonadReaderX ix r (RWSTX ix r w s m) where
    askx _ = RWSTX $ \env st -> return (env, st, mempty)
    localx _ f (RWSTX step) = RWSTX $ \env st -> step (f env) st
-- WriterX
-- | Writer behaviour of 'RWSTX' at its own index.
--
-- * 'tellx' emits the given output.
-- * 'listenx' additionally returns the output produced by the computation.
-- * 'passx' applies the function returned by the computation to its output.
instance (Monoid w, Monad m, Index ix) => MonadWriterX ix w (RWSTX ix r w s m) where
    tellx _ out = RWSTX $ \_ st -> return ((), st, out)
    listenx _ (RWSTX step) = RWSTX $ \env st ->
        step env st >>= \(a, st', out) -> return ((a, out), st', out)
    passx _ (RWSTX step) = RWSTX $ \env st ->
        step env st >>= \((a, f), st', out) -> return (a, st', f out)
-- StateX
-- | State behaviour of 'RWSTX' at its own index: 'getx' returns this layer's
-- current state, 'putx' replaces it.  Both emit 'mempty'.
instance (Monoid w, Monad m, Index ix) => MonadStateX ix s (RWSTX ix r w s m) where
    getx _ = RWSTX $ \_ st -> return (st, st, mempty)
    putx _ new = RWSTX $ \_ _ -> return ((), new, mempty)

-- RWSX same index.
-- | Combined indexed interface for 'RWSTX' at its own index.  The instance
-- body is empty: no methods are defined here.
--
-- The head names @r1@\/@w1@\/@s1@ and @r2@\/@w2@\/@s2@ separately and equates
-- them in the context.
-- NOTE(review): presumably this helps because the head then matches whenever
-- the two indexes agree, and the equalities are only imposed after the
-- instance has been chosen -- confirm against the overlapping different-index
-- instance for 'RWSTX'.
instance ( Index ix, Monad m
         , Monoid w1, Monoid w2
         , r1 ~ r2, s1 ~ s2, w1 ~ w2
         )
    => MonadRWSX ix r1 w1 s1 (RWSTX ix r2 w2 s2 m)


-- ---------------------------------------------------------------------------
-- Instances for other mtl transformers


-- | Lift an inner-monad action into 'RWSTX': the action's result is returned,
-- the state is passed through unchanged and the output is 'mempty'.
instance (Monoid w, Index ix) => MonadTrans (RWSTX ix r w s) where
    lift inner = RWSTX $ \_ st ->
        inner >>= \a -> return (a, st, mempty)

-- | 'IO' actions are lifted into the inner monad first and then through this
-- layer with 'lift'.
instance (Monoid w, MonadIO m, Index ix) => MonadIO (RWSTX ix r w s m) where
    liftIO io = lift (liftIO io)

-- | 'callCC' delegated to the inner monad.  When the captured continuation is
-- invoked it resumes with the state current at the point of invocation and
-- with 'mempty' as this layer's output.
instance (Monoid w, MonadCont m, Index ix) => MonadCont (RWSTX ix r w s m) where
    callCC f = RWSTX $ \r s ->
        callCC $ \c ->
            let escape a = RWSTX $ \_ s' -> c (a, s', mempty)
            in  runRWSTX' (f escape) r s


-- ---------------------------------------------------------------------------
-- ---------------------------------------------------------------------------
-- ---------------------------------------------------------------------------
-- ---------------------------------------------------------------------------

-- Error
-- | Error handling delegated to the inner monad.  The handler is run with the
-- environment and state that were in effect when the guarded action started;
-- state changes made by the failed action are not passed to the handler.
instance (Monoid w, MonadError e m, Index ix) => MonadError e (RWSTX ix r w s m) where
    throwError err = lift (throwError err)
    catchError (RWSTX body) handler = RWSTX $ \r s ->
        body r s `catchError` \err -> runRWSTX' (handler err) r s

-- ErrorX
-- | Indexed error handling delegated to the inner monad at index @ixe@.  The
-- handler is run with the environment and state that were in effect when the
-- guarded action started.
instance (Monoid w, MonadErrorX ixe e m, Index ix, Index ixe) => MonadErrorX ixe e (RWSTX ix r w s m) where
    throwErrorx ixv err = lift (throwErrorx ixv err)
    catchErrorx ixv (RWSTX body) handler = RWSTX $ \r s ->
        catchErrorx ixv (body r s) (\err -> runRWSTX' (handler err) r s)

-- ---------------------------------------------------------------------------

-- interactions with the basic MTL monads are simple...

-- Reader
-- | Plain mtl reader operations pass through this layer to the inner monad:
-- 'ask' reads the inner environment and 'local' modifies it around the inner
-- action.  This layer's own environment and state are passed along unchanged.
instance (Monoid w2, Monad m, Index ix2, MonadReader r1 m) => MonadReader r1 (RWSTX ix2 r2 w2 s2 m) where
    ask = RWSTX $ \_ st -> do
        innerEnv <- ask
        return (innerEnv, st, mempty)
    local f (RWSTX body) = RWSTX $ \env st -> local f (body env st)

-- | A plain mtl 'ReaderT' layer is transparent to the indexed RWS interface of
-- the monad beneath it.  The instance body is empty: no methods are defined
-- here.
instance ( Index ix1, Monoid w1
         , MonadRWSX ix1 r1 w1 s1 m
         )
    => MonadRWSX ix1 r1 w1 s1 (ReaderT r2 m)

-- Writer
-- | Plain mtl writer operations pass through this layer to the inner monad.
-- This layer's own output (@w2@) and state are passed along unchanged.
--
-- * 'tell' writes to the inner monad's output.
-- * 'listen' returns the inner output (@w1@) produced while running the action.
-- * 'pass' applies the function returned by the action to the inner output
--   produced by that action.
instance (Monoid w1, Monoid w2, Monad m, Index ix2, MonadWriter w1 m) => MonadWriter w1 (RWSTX ix2 r2 w2 s2 m) where
    tell   w1 = mkRWSTX (getVal::ix2) $ \_ s2 -> tell w1 >> return ((),s2,mempty)
    listen  m = mkRWSTX (getVal::ix2) $ \r2 s2 -> do
        ((a,s2',w2'),w1) <- listen (runRWSTX' {-(getVal::ix2)-} m r2 s2)
        return ((a,w1),s2',w2')
    -- The action must run *inside* the inner 'pass'; running it first and then
    -- calling @pass (return ...)@ would apply the function to 'mempty' only
    -- and leave the action's real inner output untransformed.
    pass m = mkRWSTX (getVal::ix2) $ \r2 s2 -> pass $ do
        ((a,f),s2',w2') <- runRWSTX' {-(getVal::ix2)-} m r2 s2
        return ((a,s2',w2'),f)

-- | A plain mtl 'WriterT' layer is transparent to the indexed RWS interface of
-- the monad beneath it.  The instance body is empty: no methods are defined
-- here.
instance ( Index ix1, Monoid w1, Monoid w2
         , MonadRWSX ix1 r1 w1 s1 m
         )
    => MonadRWSX ix1 r1 w1 s1 (WriterT w2 m)

-- State
-- | Plain mtl state operations pass through this layer to the inner monad:
-- 'get' and 'put' act on the inner state (@s1@).  This layer's own state
-- (@s2@) is passed along unchanged and the output is 'mempty'.
instance (Monoid w2, Monad m, Index ix2, MonadState s1 m) => MonadState s1 (RWSTX ix2 r2 w2 s2 m) where
    get = RWSTX $ \_ own -> do
        inner <- get
        return (inner, own, mempty)
    put new = RWSTX $ \_ own -> do
        put new
        return ((), own, mempty)

-- | A plain mtl 'StateT' layer is transparent to the indexed RWS interface of
-- the monad beneath it.  The instance body is empty: no methods are defined
-- here.
instance ( Index ix1, Monoid w1
         , MonadRWSX ix1 r1 w1 s1 m
         )
    => MonadRWSX ix1 r1 w1 s1 (StateT s2 m)

-- ---------------------------------------------------------------------------

-- RWS
-- | An 'RWSTX' layer is transparent to the plain mtl RWS interface of the
-- monad beneath it.  The instance body is empty: no methods are defined here.
instance ( Index ix2, Monoid w1, Monoid w2
         , MonadRWS r1 w1 s1 m
         )
    => MonadRWS r1 w1 s1 (RWSTX ix2 r2 w2 s2 m)

-- | A plain mtl 'RWST' layer is transparent to the indexed RWS interface of
-- the monad beneath it.  The instance body is empty: no methods are defined
-- here.
instance ( Index ix1, Monoid w1, Monoid w2
         , MonadRWSX ix1 r1 w1 s1 m
         )
    => MonadRWSX ix1 r1 w1 s1 (RWST r2 w2 s2 m)


-- interactions with other MTLX monads must be handled with different indexes.

-- RWSX. Collects ReaderX, WriterX, and StateX constraints on different indexes to claim a further RWSX relation on different indexes.
-- | An 'RWSTX' layer at index @ix2@ is transparent to the indexed RWS
-- interface that the monad beneath it offers at index @ix1@.  The instance
-- body is empty: no methods are defined here.
instance ( Index ix1, Index ix2
         , Monoid w1, Monoid w2
         , MonadRWSX ix1 r1 w1 s1 m
         )
    => MonadRWSX ix1 r1 w1 s1 (RWSTX ix2 r2 w2 s2 m)

-- ---------------------------------------------------------------------------
-- ---------------------------------------------------------------------------

-- ReaderX, WriterX, and StateX with different indexes.

-- ReaderX
-- | Reader operations at index @ix1@ pass through an 'RWSTX' layer at index
-- @ix2@ to the inner monad.  The index value handed in is ignored; the inner
-- call is made with @getVal :: ix1@.  This layer's own environment and state
-- are passed along unchanged.
instance (Monoid w2, Monad m, Index ix1, Index ix2, MonadReaderX ix1 r1 m) => MonadReaderX ix1 r1 (RWSTX ix2 r2 w2 s2 m) where
    askx _ = RWSTX $ \_ own -> do
        innerEnv <- askx (getVal :: ix1)
        return (innerEnv, own, mempty)
    localx _ f (RWSTX body) = RWSTX $ \env own ->
        localx (getVal :: ix1) f (body env own)

-- StateX
-- | State operations at index @ix1@ pass through an 'RWSTX' layer at index
-- @ix2@ to the inner monad.  The index value handed in is ignored; the inner
-- call is made with @getVal :: ix1@.  This layer's own state (@s2@) is passed
-- along unchanged and the output is 'mempty'.
instance (Monoid w2, Monad m, Index ix1, Index ix2, MonadStateX ix1 s1 m) => MonadStateX ix1 s1 (RWSTX ix2 r2 w2 s2 m) where
    getx _ = RWSTX $ \_ own -> do
        inner <- getx (getVal :: ix1)
        return (inner, own, mempty)
    putx _ new = RWSTX $ \_ own -> do
        putx (getVal :: ix1) new
        return ((), own, mempty)

-- WriterX
-- | Writer operations at index @ix1@ pass through an 'RWSTX' layer at index
-- @ix2@ to the inner monad.  The index value handed in is ignored; inner calls
-- are made with @getVal :: ix1@.  This layer's own output (@w2@) and state are
-- passed along unchanged.
--
-- * 'tellx' writes to the inner monad's output.
-- * 'listenx' returns the inner output (@w1@) produced while running the action.
-- * 'passx' applies the function returned by the action to the inner output
--   produced by that action.
instance (Monoid w1, Monoid w2, Monad m, Index ix1, Index ix2, MonadWriterX ix1 w1 m) => MonadWriterX ix1 w1 (RWSTX ix2 r2 w2 s2 m) where
    tellx (_::ix1)  w1 = mkRWSTX (getVal::ix2) $ \_ s2 -> tellx (getVal::ix1) w1 >> return ((),s2,mempty)
    listenx (_::ix1) m = mkRWSTX (getVal::ix2) $ \r2 s2 -> do
        ((a,s2',w2'),w1) <- listenx (getVal::ix1) (runRWSTX' {-(getVal::ix2)-} m r2 s2)
        return ((a,w1),s2',w2')
    -- The action must run *inside* the inner 'passx'; running it first and
    -- then calling @passx ix (return ...)@ would apply the function to
    -- 'mempty' only and leave the action's real inner output untransformed.
    passx (_::ix1)  m = mkRWSTX (getVal::ix2) $ \r2 s2 -> passx (getVal::ix1) $ do
        ((a,f),s2',w2') <- runRWSTX' {-(getVal::ix2)-} m r2 s2
        return ((a,s2',w2'),f)