{-# OPTIONS -fno-warn-orphans #-}
-- UndecidableInstances (enabled below) is required by several lifting instances;
-- search for "Needs -fallow-undecidable-instances" to find examples.

{-# LANGUAGE UndecidableInstances,ScopedTypeVariables, FlexibleInstances, MultiParamTypeClasses #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Monad.StateX.Strict
-- Copyright   :  (c) Mark Snyder 2008.
-- License     :  BSD-style
-- Maintainer  :  Mark Snyder, marks@ittc.ku.edu
-- Stability   :  experimental
-- Portability :  non-portable (multi-param classes, functional dependencies)
--
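-- Strict state monads ('StateX') and state monad transformers ('StateTX')
-- that carry an extra index type parameter @ix@, together with instances
-- lifting the mtl classes and their indexed (...X) counterparts through them.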
--
-----------------------------------------------------------------------------

module Control.Monad.StateX.Strict (
    module Control.Monad.StateX.Class,
    -- * The StateX Monad
    StateX(..),
    mkStateX,
    runStateX,
    evalStatex,
    execStatex,
    mapStatex,
    withStatex,
    -- * The StateTX Monad
    StateTX(..),
    mkStateTX,
    runStateTX,
    evalStateTX,
    execStateTX,
    mapStateTX,
    withStateTX,
    module Control.Monad,
    module Control.Monad.Fix,
    module Control.Monad.Trans,
  ) where

import Control.Monad
import Control.Monad.Cont.Class
--import Control.Monad.Error.Class
import Control.Monad.Fix
--import Control.Monad.Reader.Class
--import Control.Monad.State.Class
import Control.Monad.Trans
--import Control.Monad.Writer.Class

import Control.Monad.Error
import Control.Monad.State
import Control.Monad.Reader
import Control.Monad.Writer
import Control.Monad.RWS

import Control.Monad.Index

import Control.Monad.ErrorX.Class
import Control.Monad.ReaderX.Class
import Control.Monad.StateX.Class
import Control.Monad.WriterX.Class

-- ---------------------------------------------------------------------------
--data (Index ix) => StateX ix s a = StateX ix (s -> (a,s))
--runStateX :: (Index ix) => ix -> StateX ix s a -> (s -> (a,s))
--runStateX (_::ix) (StateX (_::ix) f) s = f s
-- | A state computation over state @s@ producing @a@, tagged with an index
-- type @ix@.  The index is a phantom parameter: it does not occur in the
-- representation, which is just a state-transition function.
newtype StateX ix s a = StateX { runStateX' :: s -> (a, s) }

-- | Wrap a state-transition function as a 'StateX'.  The index argument only
-- fixes the type @ix@; its value is ignored.
mkStateX :: (Index ix) => ix -> (s->(a,s)) -> StateX ix s a
mkStateX _ = StateX

-- | Unwrap a 'StateX' into its state-transition function.  The index argument
-- only fixes the type @ix@; its value is ignored.
runStateX :: (Index ix) => ix -> StateX ix s a -> (s->(a,s))
runStateX _ = runStateX'

-- | Run a 'StateX' from an initial state and keep only the final result,
-- discarding the final state.
evalStatex :: (Index ix) => ix -> StateX ix s a -> s -> a
evalStatex ixv m = fst . runStateX ixv m

-- | Run a 'StateX' from an initial state and keep only the final state,
-- discarding the result.
execStatex :: (Index ix) => ix -> StateX ix s a -> s -> s
execStatex ixv m = snd . runStateX ixv m

-- | Post-process the (result, final state) pair produced by a 'StateX'
-- with @f@.
mapStatex :: (Index ix) => ix -> ((a, s) -> (b, s)) -> StateX ix s a -> StateX ix s b
mapStatex ixv f m = mkStateX ixv (\s -> f (runStateX ixv m s))

-- | Run a 'StateX' after first applying @f@ to the incoming state.
withStatex :: (Index ix) => ix -> (s -> s) -> StateX ix s a -> StateX ix s a
withStatex ixv f m = mkStateX ixv (\s -> runStateX ixv m (f s))


-- Map over the result of a StateX.  The inner (result, state) pair is
-- forced by the case before the new pair is built (strict variant).
instance (Index ix) => Functor (StateX ix s) where
    fmap f m = mkStateX (getVal::ix) step
      where
        step s0 = case runStateX' m s0 of
                    (x, s1) -> (f x, s1)

-- Strict state monad: '>>=' pattern-matches the (result, state) pair of the
-- first computation before running the continuation.
instance (Index ix) => Monad (StateX ix s) where
    -- Yield x and leave the state unchanged.
    return x = mkStateX (getVal::ix) (\s -> (x, s))
    -- The tuple pattern in 'step' is strict, forcing m's pair.
    m >>= k  = mkStateX (getVal::ix) $ \s0 ->
        let step (x, s1) = runStateX' (k x) s1
        in  step (runStateX' m s0)

-- Value recursion for StateX: the result fed back into f is taken lazily
-- from the very pair that f's computation produces.
instance (Index ix) => MonadFix (StateX ix s) where
    mfix f = mkStateX (getVal::ix) $ \s ->
        let knot = runStateX' (f (fst knot)) s
        in  (fst knot, snd knot)

-- Indexed state access for StateX: getx returns the current state, and
-- putx replaces it.
instance (Index ix) => MonadStateX ix s (StateX ix s) where
    getx (ixv::ix)   = mkStateX ixv (\st -> (st, st))
    putx (ixv::ix) s = mkStateX ixv (const ((), s))

-- ---------------------------------------------------------------------------
--data (Index ix) => StateTX ix s m a = StateTX ix (s-> m (a,s))
--runStateTX :: (Index ix) => ix -> StateTX ix s m a -> (s -> m (a,s))
--runStateTX (_::ix) (StateTX (_::ix) f) s = f s
-- | A state transformer over monad @m@ with state @s@, tagged with an index
-- type @ix@.  The index is a phantom parameter: the representation is just
-- @s -> m (a, s)@.
newtype StateTX ix s m a = StateTX { runStateTX' :: s -> m (a,s) }

-- | Wrap a monadic state-transition function as a 'StateTX'.  The index
-- argument only fixes the type @ix@; its value is ignored.
mkStateTX :: (Index ix) => ix -> (s->m(a,s)) -> StateTX ix s m a
mkStateTX _ = StateTX

-- | Unwrap a 'StateTX' and apply it to an initial state.  The index argument
-- only fixes the type @ix@; its value is ignored.
runStateTX :: (Index ix) => ix -> StateTX ix s m a -> s -> m (a,s)
runStateTX _ = runStateTX'

-- | Run a 'StateTX' from an initial state and return only the final result.
-- The (result, state) pair is matched with a strict tuple pattern.
evalStateTX :: (Monad m, Index ix) => ix -> StateTX ix s m a -> s -> m a
evalStateTX ixv m s = runStateTX ixv m s >>= \(result, _) -> return result

-- | Run a 'StateTX' from an initial state and return only the final state.
-- The (result, state) pair is matched with a strict tuple pattern.
execStateTX :: (Monad m, Index ix) => ix -> StateTX ix s m a -> s -> m s
execStateTX ixv m s = runStateTX ixv m s >>= \(_, final) -> return final

-- | Transform the underlying monadic action (and its result/state pair),
-- possibly changing the inner monad from @m@ to @n@.
mapStateTX :: (Index ix) => ix -> (m (a, s) -> n (b, s)) -> StateTX ix s m a -> StateTX ix s n b
mapStateTX ixv f m = mkStateTX ixv (\s -> f (runStateTX ixv m s))

-- | Run a 'StateTX' after first applying @f@ to the incoming state.
withStateTX :: (Index ix) => ix -> (s -> s) -> StateTX ix s m a -> StateTX ix s m a
withStateTX ixv f m = mkStateTX ixv (\s -> runStateTX ixv m (f s))

-- Map over the result; the state is threaded through untouched.  The inner
-- (result, state) pair is matched with a strict tuple pattern.
instance (Monad m, Index ix) => Functor (StateTX ix s m) where
    fmap f m = mkStateTX (getVal::ix) $ \s0 ->
        runStateTX' m s0 >>= \(x, s1) -> return (f x, s1)

-- Strict state transformer: '>>=' matches the (result, state) pair of the
-- first computation before running the continuation; 'fail' delegates to m.
instance (Monad m, Index ix) => Monad (StateTX ix s m) where
    return x = mkStateTX (getVal::ix) (\s -> return (x, s))
    m >>= k  = mkStateTX (getVal::ix) $ \s0 ->
        runStateTX' m s0 >>= \(x, s1) -> runStateTX' (k x) s1
    fail msg = mkStateTX (getVal::ix) (const (fail msg))

-- Choice is delegated to the underlying monad; both alternatives start from
-- the same incoming state.
instance (MonadPlus m, Index ix) => MonadPlus (StateTX ix s m) where
    mzero     = mkStateTX (getVal::ix) (const mzero)
    mplus m n = mkStateTX (getVal::ix) $ \s ->
        mplus (runStateTX' m s) (runStateTX' n s)

-- Value recursion via the underlying monad's mfix; the fed-back result is
-- projected lazily (fst) so the knot is not forced prematurely.
instance (MonadFix m, Index ix) => MonadFix (StateTX ix s m) where
    mfix f = mkStateTX (getVal::ix) $ \s ->
        mfix (\pair -> runStateTX' (f (fst pair)) s)

-- Indexed state access for StateTX's own state: getx returns it and putx
-- replaces it.
instance (Monad m, Index ix) => MonadStateX ix s (StateTX ix s m) where
    getx (ixv::ix)   = mkStateTX ixv (\st -> return (st, st))
    putx (ixv::ix) s = mkStateTX ixv (const (return ((), s)))

-- ---------------------------------------------------------------------------
-- Instances for other mtl transformers

-- Run an action of the underlying monad, leaving the state unchanged.
instance (Index ix) => MonadTrans (StateTX ix s) where
    lift m = mkStateTX (getVal::ix) $ \s -> m >>= \x -> return (x, s)

-- Lift IO into the underlying monad, then into StateTX.
instance (MonadIO m, Index ix) => MonadIO (StateTX ix s m) where
    liftIO io = lift (liftIO io)

-- callCC for StateTX: the escape continuation resumes with the state that is
-- current when it is invoked (sNow), not the state at callCC entry.
instance (MonadCont m, Index ix) => MonadCont (StateTX ix s m) where
    callCC f = mkStateTX (getVal::ix) $ \s0 ->
        callCC $ \exit ->
            runStateTX' (f (\x -> mkStateTX (getVal::ix) $ \sNow -> exit (x, sNow))) s0




-- Error
-- Lift mtl errors through StateTX.  The handler runs from the state that was
-- current when catchError was entered.
instance (MonadError e m, Index ix) => MonadError e (StateTX ix s m) where
    throwError err = lift (throwError err)
    catchError m h = mkStateTX (getVal::ix) $ \s ->
        catchError (runStateTX' m s) (\e -> runStateTX' (h e) s)

-- Lift indexed state access through mtl's ErrorT.
instance (Error e, MonadStateX ixs s m) => MonadStateX ixs s (ErrorT e m) where
    getx (ixv::ixs)   = lift (getx ixv)
    putx (ixv::ixs) s = lift (putx ixv s)



-- ErrorX
-- Lift indexed errors (index ixe) through StateTX (index ixs).  The handler
-- starts from the state s that was current when catchErrorx was entered.
instance (MonadErrorX ixe e m, Index ixe, Index ixs) => MonadErrorX ixe e (StateTX ixs s m) where
    throwErrorx (ixv::ixe) err = lift (throwErrorx ixv err)
    catchErrorx (ixv::ixe) m h = mkStateTX (getVal::ixs) $ \s ->
        let body      = runStateTX' m s
            handler e = runStateTX' (h e) s
        in  catchErrorx ixv body handler


-- State
-- Lift mtl's MonadState of the underlying monad m through StateTX: get/put
-- act on m's state (s1), while the StateTX state (s2) is passed through.
instance (MonadState s1 m,Index ix) => MonadState s1 (StateTX ix s2 m) where
    get         = mkStateTX (getVal::ix) $ \s -> get >>= \n -> return (n, s)
    put (v::s1) = mkStateTX (getVal::ix) $ \s -> put v >> return ((), s)

-- Lift indexed state access through mtl's StateT: the StateT state (s2) is
-- passed through unchanged, while getx/putx act on the underlying monad m.
instance (MonadStateX ix s1 m, Index ix)
    => MonadStateX ix s1 (StateT s2 m) where
    getx (ixv::ix) = StateT $ \s ->
        getx ixv >>= \n -> return (n, s)
    putx (ixv::ix) (v::s1) = StateT $ \s ->
        putx ixv v >> return ((), s)




-- StateX
-- Lift indexed state (index ix1) from the underlying monad through a StateTX
-- with a different index (ix2); the outer StateTX state (s2) is passed through.
-- NOTE(review): this head overlaps 'MonadStateX ix s (StateTX ix s m)' when
-- ix1 ~ ix2 and s1 ~ s2; presumably OverlappingInstances is enabled elsewhere
-- (cabal file or class module) -- confirm.
instance (Index ix1, Index ix2, MonadStateX ix1 s1 m )
    => MonadStateX ix1 s1 (StateTX ix2 s2 m) where
    getx (ixv::ix1) = mkStateTX (getVal::ix2) $ \(s::s2) ->
        do v1 <- getx ixv
           return (v1, s)
    putx (ixv::ix1) v1 = mkStateTX (getVal::ix2) $ \(s::s2) ->
        do putx ixv v1
           return ((), s)




--Reader
-- Needs -fallow-undecidable-instances
-- Lift mtl's MonadReader through StateTX; 'local' modifies the environment
-- for the inner computation run from the current state.
instance (Index ix, MonadReader r m) =>  MonadReader r (StateTX ix st m) where
    ask       = lift ask
    local f m = mkStateTX (getVal::ix) (local f . runStateTX' m)

-- Lift indexed state access through mtl's ReaderT.
instance (Index ix, MonadStateX ix s m) => MonadStateX ix s (ReaderT r m) where
    getx (ixv::ix)     = lift (getx ixv)
    putx (ixv::ix) val = lift (putx ixv val)




-- ReaderX
-- Needs -fallow-undecidable-instances
-- Lift an indexed reader (index ixr) through StateTX (index ixs).
instance (Index ixs, MonadReaderX ixr r m) =>  MonadReaderX ixr r (StateTX ixs st m) where
    askx (ixv::ixr)       = lift (askx ixv)
    localx (ixv::ixr) f m = mkStateTX (getVal::ixs) (localx ixv f . runStateTX' m)


--Writer
-- Needs -fallow-undecidable-instances
-- Lift mtl's MonadWriter through StateTX.  listen/pass re-associate the
-- tuples so the writer output (or modifier) sits next to the result while
-- the StateTX state is carried alongside; the lazy patterns mirror mtl.
instance (Index ix, MonadWriter w m) => MonadWriter w (StateTX ix s m) where
    tell w   = lift (tell w)
    listen m = mkStateTX (getVal::ix) $ \s ->
        listen (runStateTX' m s) >>= \ ~((x, s'), out) -> return ((x, out), s')
    pass m   = mkStateTX (getVal::ix) $ \s -> pass $
        runStateTX' m s >>= \ ~((x, g), s') -> return ((x, s'), g)

-- Lift indexed state access through mtl's WriterT.
instance (Index ix, MonadStateX ix s m, Monoid w) => MonadStateX ix s (WriterT w m) where
    getx (ixv::ix)     = lift (getx ixv)
    putx (ixv::ix) val = lift (putx ixv val)


-- WriterX
-- Needs -fallow-undecidable-instances
-- Lift an indexed writer (index ixw) through StateTX (index ixs), with the
-- same tuple re-association (and lazy patterns) as the mtl Writer lifting.
instance (Index ixs, MonadWriterX ixw w m) => MonadWriterX ixw w (StateTX ixs s m) where
    tellx (ixv::ixw) w   = lift (tellx ixv w)
    listenx (ixv::ixw) m = mkStateTX (getVal::ixs) $ \s ->
        listenx ixv (runStateTX' m s) >>= \ ~((x, s'), out) -> return ((x, out), s')
    passx (ixv::ixw) m   = mkStateTX (getVal::ixs) $ \s -> passx ixv $
        runStateTX' m s >>= \ ~((x, g), s') -> return ((x, s'), g)

-- RWS
-- MonadRWS defines no methods here: this instance only bundles the
-- MonadReader, MonadWriter and MonadState instances for StateTX in this
-- module (the MonadState one exposes the state of the underlying monad m).
instance (Monoid w, Index ix2, MonadReader r m, MonadState s m, MonadWriter w m)
    => MonadRWS r w s (StateTX ix2 s2 m)

-- Lift indexed state access through mtl's RWST.  The RWST state (s2) is
-- threaded through unchanged and no writer output is produced (mempty),
-- while the indexed state s1 is read or written in the underlying monad m.
-- The caller-supplied index value is forwarded, as in the other lifting
-- instances of this module, rather than being replaced by 'getVal'.
instance (Monoid w2, Monad m, Index ix1, MonadStateX ix1 s1 m) => MonadStateX ix1 s1 (RWST r2 w2 s2 m) where
    getx (ixv::ix1)    = RWST $ \_ (s::s2) ->
        getx ixv >>= \v1 -> return (v1, s, mempty)
    putx (ixv::ix1) s1 = RWST $ \_ (s::s2) ->
        putx ixv s1 >> return ((), s, mempty)