{-# OPTIONS -fno-warn-orphans #-}
{-# LANGUAGE MultiParamTypeClasses, FunctionalDependencies, ScopedTypeVariables, FlexibleInstances, UndecidableInstances #-}


-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Monad.StateX.Class
-- Copyright   :  (c) Mark Snyder 2008.
-- License     :  BSD-style
-- Maintainer  :  Mark Snyder, marks@ittc.ku.edu
-- Stability   :  experimental
-- Portability :  non-portable (multi-param classes, functional dependencies)

-----------------------------------------------------------------------------

module Control.Monad.StateX.Class (
    MonadStateX(..),
    modifyx,
    getsx,
  ) where

import Control.Monad.Index

import Control.Monad.Error
import Control.Monad.State
import Control.Monad.Reader
import Control.Monad.Writer
import Control.Monad.RWS


-- ---------------------------------------------------------------------------

-- | Monads that carry one or more pieces of state, each selected by an
-- index type.
--
-- The functional dependency @i m -> st@ means that, for a fixed monad,
-- the index type alone determines the type of the state it selects.
class (Index i, Monad m) => MonadStateX i st m | i m -> st where
    -- | Read the state component selected by the given index.
    getx :: i -> m st

    -- | Overwrite the state component selected by the given index.
    putx :: i -> st -> m ()

-- | Read the state component selected by the index, apply the update
-- function to it, and write the result back with 'putx'.  'modifyx'
-- itself does not force the updated value.
modifyx :: (MonadStateX ix s m) => ix -> (s -> s) -> m ()
modifyx idx update = getx idx >>= putx idx . update

-- | Read the state component selected by the index and return the result
-- of applying the projection function to it.  The state is not modified.
getsx :: (MonadStateX ix s m) => ix -> (s -> a) -> m a
getsx idx project = getx idx >>= return . project

-- We need to be careful to support both the strict and lazy variants
-- in these instances, but they are defined here to avoid orphan instances
-- in the Lazy.hs and Strict.hs files.  Instances defined using 'lift'
-- should sidestep any strictness concerns.  We do no pattern matching on
-- tuples, so we *should* be okay with all of these instances.

-- Error
instance (Error e, MonadStateX ix s m) => MonadStateX ix s (ErrorT e m) where
    getx (ixv::ix) = lift $ getx ixv
    putx (ixv::ix) = lift . putx ixv


-- State: the indexed state lives in the inner monad, and the outer
-- StateT's own state is passed through unchanged.
instance (MonadStateX ix s1 m, Index ix)
    => MonadStateX ix s1 (StateT s2 m) where
    getx idx   = StateT $ \outer -> getx idx >>= \n -> return (n, outer)
    putx idx v = StateT $ \outer -> putx idx v >> return ((), outer)


-- Reader: both operations are lifted into the inner monad; the
-- environment is never consulted.
instance (Index ix, MonadStateX ix s m) => MonadStateX ix s (ReaderT r m) where
    getx     = lift . getx
    putx idx = lift . putx idx


-- Writer: both operations are lifted into the inner monad.
instance (Index ix, MonadStateX ix s m, Monoid w) => MonadStateX ix s (WriterT w m) where
    getx idx = lift (getx idx)
    putx idx = lift . putx idx


-- RWS: the indexed state lives in the inner monad.  The caller's index
-- value is forwarded unchanged, as every other instance in this module
-- does, rather than being replaced by a canonical value.  The RWST's own
-- state is threaded through untouched, and each operation emits 'mempty'
-- as its writer output.
--
-- 'Monad m' is not listed in the context: it already follows from the
-- 'MonadStateX' superclass, just as in the ErrorT instance.
instance (Monoid w2, Index ix1, MonadStateX ix1 s1 m) => MonadStateX ix1 s1 (RWST r2 w2 s2 m) where
    getx idx   = RWST $ \_ st -> getx idx >>= \v -> return (v, st, mempty)
    putx idx v = RWST $ \_ st -> putx idx v >> return ((), st, mempty)