-- Search for -fallow-undecidable-instances to see why this is needed
{-# LANGUAGE UndecidableInstances,ScopedTypeVariables, OverlappingInstances, FlexibleInstances, MultiParamTypeClasses #-}

-----------------------------------------------------------------------------
-- Module      :  Control.Monad.WriterX.Strict
-- Copyright   :  (c) Mark Snyder 2008.
-- License     :  BSD-style
-- Maintainer  :  Mark Snyder, marks@ittc.ku.edu
-- Stability   :  experimental
-- Portability :  non-portable (multi-param classes, functional dependencies,
--                overlapping/undecidable instances, scoped type variables)
-----------------------------------------------------------------------------

module Control.Monad.WriterX.Strict (
    module Control.Monad.WriterX.Class,
    WriterX(..),
    runWriterX,
    execWriterX,
    mapWriterX,

    WriterTX(..),
    runWriterTX,
    execWriterTX,
    mapWriterTX,
    module Control.Monad,
    module Control.Monad.Fix,
    module Control.Monad.Trans,
    module Data.Monoid,
  ) where

import Control.Monad
import Control.Monad.Cont.Class
import Control.Monad.Error.Class
import Control.Monad.Fix
import Control.Monad.Reader.Class
import Control.Monad.State.Class
import Control.Monad.Trans
import Control.Monad.Writer.Class
import Control.Monad.RWS

import Data.Monoid

import Control.Monad.Index

import Control.Monad.ErrorX.Class
import Control.Monad.ReaderX.Class
import Control.Monad.StateX.Class
import Control.Monad.WriterX.Class

-- ---------------------------------------------------------------------------

--data (Index ix) => WriterX ix w a = WriterX ix (a,w)
--runWriterX :: (Index ix) => ix -> WriterX ix w a -> (a,w)
--runWriterX (_::ix) (WriterX (_::ix) f) = f
-- | Writer computation tagged with an index type @ix@.  The index does not
-- appear in the representation: a 'WriterX' is just a result paired with
-- its accumulated output.
newtype WriterX ix w a = WriterX
    { runWriterX' :: (a, w) -- ^ (result, accumulated output)
    }
-- | Wrap a (result, output) pair as a 'WriterX'.  The index argument is
-- never inspected; it only fixes the @ix@ type parameter.
mkWriterX :: (Index ix) => ix -> (a, w) -> WriterX ix w a
mkWriterX = const WriterX
-- | Unwrap a 'WriterX' into its (result, output) pair.  The index argument
-- is ignored and only selects which indexed writer is meant.
runWriterX :: (Index ix) => ix -> WriterX ix w a -> (a, w)
runWriterX = const runWriterX'

-- | Run a 'WriterX' and keep only the accumulated output, discarding the
-- result.  The index argument is ignored.
execWriterX :: (Index ix) => ix -> WriterX ix w a -> w
execWriterX _ = snd . runWriterX'

-- | Transform both the result and the output of a 'WriterX' with a single
-- function on the underlying (result, output) pair.
mapWriterX :: (Index ix) => ix -> ((a, w) -> (b, w')) -> WriterX ix w a -> WriterX ix w' b
mapWriterX ixv f = mkWriterX ixv . f . runWriterX'

-- Map over the result, leaving the accumulated output unchanged.
instance (Index ix) => Functor (WriterX ix w) where
    fmap f = mkWriterX (getVal :: ix) . onResult . runWriterX'
      where
        -- Matching the pair in the argument keeps fmap strict in the pair,
        -- exactly as a case expression would.
        onResult (a, w) = (f a, w)

-- Sequencing concatenates outputs with mappend.  Both intermediate pairs
-- are pattern-matched strictly.
instance (Monoid w, Index ix) => Monad (WriterX ix w) where
    return a = mkWriterX (getVal :: ix) (a, mempty)
    m >>= k  = mkWriterX (getVal :: ix) (bindOut (runWriterX' m))
      where
        -- Force the first pair, then run the continuation on its result.
        bindOut (a, out) = appendOut out (runWriterX' (k a))
        -- Force the second pair and append its output after the first.
        appendOut out (b, out') = (b, out `mappend` out')

-- Value recursion: the result of the computation is fed back into it.
instance (Monoid w, Index ix) => MonadFix (WriterX ix w) where
    mfix m = mkWriterX (getVal :: ix) (a, w)
      where
        -- A where-bound pattern is lazy, so a can be passed to m before the
        -- pair produced by m has been evaluated.
        (a, w) = runWriterX' (m a)

-- Indexed writer operations acting directly on the (result, output) pair.
instance (Monoid w, Index ix) => MonadWriterX ix w (WriterX ix w) where
    -- Emit w as output with a unit result.
    tellx ixv w = mkWriterX ixv ((), w)
    -- Expose the output of m alongside its result (output is kept).
    listenx ixv m = mkWriterX ixv (withOutput (runWriterX' m))
      where
        withOutput (a, w) = ((a, w), w)
    -- Apply the function returned by m to m's own output.
    passx ixv m = mkWriterX ixv (applyFn (runWriterX' m))
      where
        applyFn ((a, f), w) = (a, f w)

-- ---------------------------------------------------------------------------
-- Our parameterizable writer monad, with an inner monad

--data (Index ix) => WriterTX ix w m a = WriterTX ix (m (a,w))
--runWriterTX :: (Index ix) => ix -> WriterTX ix w m a -> m (a,w)
--runWriterTX (_::ix) (WriterTX (_::ix) m) = m
-- | Writer transformer tagged with an index type @ix@: an inner-monad
-- computation that yields a result paired with accumulated output.  The
-- index does not appear in the representation.
newtype WriterTX ix w m a = WriterTX
    { runWriterTX' :: m (a, w) -- ^ inner action producing (result, output)
    }
-- | Wrap an inner action as a 'WriterTX'.  The index argument is never
-- inspected; it only fixes the @ix@ type parameter.
mkWriterTX :: (Index ix) => ix -> m (a, w) -> WriterTX ix w m a
mkWriterTX = const WriterTX
-- | Unwrap a 'WriterTX' into its inner action.  The index argument is
-- ignored and only selects which indexed writer is meant.
runWriterTX :: (Index ix) => ix -> WriterTX ix w m a -> m (a, w)
runWriterTX = const runWriterTX'

-- | Run a 'WriterTX' in the inner monad and return only the accumulated
-- output.  The index argument is ignored.
execWriterTX :: (Monad m, Index ix) => ix -> WriterTX ix w m a -> m w
execWriterTX _ m = runWriterTX' m >>= \(_, w) -> return w

-- | Transform the inner action of a 'WriterTX'.  The inner monad, result
-- type and output type may all change.
mapWriterTX :: (Index ix) => ix -> (m (a, w) -> n (b, w')) -> WriterTX ix w m a -> WriterTX ix w' n b
mapWriterTX ixv f = mkWriterTX ixv . f . runWriterTX'

-- Map over the result inside the inner monad, keeping the output as is.
instance (Monad m, Index ix) => Functor (WriterTX ix w m) where
    fmap f m = mkWriterTX (getVal :: ix) $
        runWriterTX' m >>= \(a, w) -> return (f a, w)

-- Sequencing runs both inner actions in order and concatenates their
-- outputs with mappend; fail is delegated to the inner monad.
instance (Monoid w, Monad m, Index ix) => Monad (WriterTX ix w m) where
    return a = mkWriterTX (getVal :: ix) (return (a, mempty))
    m >>= k  = mkWriterTX (getVal :: ix) $
        runWriterTX' m >>= \(a, w) ->
        runWriterTX' (k a) >>= \(b, w') ->
        return (b, w `mappend` w')
    fail msg = mkWriterTX (getVal :: ix) (fail msg)

-- Choice and failure come straight from the inner monad's MonadPlus.
instance (Monoid w, MonadPlus m, Index ix) => MonadPlus (WriterTX ix w m) where
    mzero     = mkWriterTX (getVal :: ix) mzero
    mplus m n = mkWriterTX (getVal :: ix) (mplus (runWriterTX' m) (runWriterTX' n))

-- Value recursion via the inner monad's mfix.  Only the result component
-- of the fixed-point pair is fed back, and it is projected lazily.
instance (Monoid w, MonadFix m, Index ix) => MonadFix (WriterTX ix w m) where
    mfix m = mkWriterTX (getVal :: ix) (mfix (runWriterTX' . m . fst))

-- Indexed writer operations for the writer layer whose index is ix.
instance (Monoid w, Monad m, Index ix) => MonadWriterX ix w (WriterTX ix w m) where
    -- Emit w as output with a unit result.
    tellx ixv w = mkWriterTX ixv (return ((), w))
    -- Expose the output of m alongside its result (output is kept).
    listenx ixv m = mkWriterTX ixv $
        runWriterTX' m >>= \(a, w) -> return ((a, w), w)
    -- Apply the function returned by m to m's own output.
    passx ixv m = mkWriterTX ixv $
        runWriterTX' m >>= \((a, f), w) -> return (a, f w)

-- ---------------------------------------------------------------------------

-- A lifted inner action contributes no output (mempty).
instance (Monoid w, Index ix) => MonadTrans (WriterTX ix w) where
    lift m = mkWriterTX (getVal :: ix) (m >>= \a -> return (a, mempty))

-- IO actions reach the writer layer through the inner monad's liftIO.
instance (Monoid w, MonadIO m, Index ix) => MonadIO (WriterTX ix w m) where
    liftIO io = lift (liftIO io)

-- callCC is implemented with the inner monad's callCC.  Invoking the
-- captured escape continuation contributes mempty as output.
instance (Monoid w, MonadCont m, Index ix) => MonadCont (WriterTX ix w m) where
    callCC f = mkWriterTX (getVal :: ix) $ callCC $ \exit ->
        let exit' a = mkWriterTX (getVal :: ix) (exit (a, mempty))
        in  runWriterTX' (f exit')

-- ---------------------------------------------------------------------------

-- Instances for other mtl transformers

-- Error
-- Errors are thrown and caught by the inner monad.  The handler's
-- computation is unwrapped to run under the inner catchError.
instance (Monoid w, MonadError e m, Index ix) => MonadError e (WriterTX ix w m) where
    throwError e   = lift (throwError e)
    catchError m h = mkWriterTX (getVal :: ix) $
        catchError (runWriterTX' m) (runWriterTX' . h)

-- ErrorX
-- Indexed errors (index ixe) are delegated to the inner monad; this writer
-- layer (index ixw) only unwraps and rewraps its own computations.
instance (Monoid w, MonadErrorX ixe e m, Index ixw) => MonadErrorX ixe e (WriterTX ixw w m) where
    throwErrorx ixv e   = lift (throwErrorx ixv e)
    catchErrorx ixv m h = mkWriterTX (getVal :: ixw) $
        catchErrorx ixv (runWriterTX' m) (runWriterTX' . h)

-- Reader
-- This instance needs -fallow-undecidable-instances, because it does not satisfy the coverage condition
instance (Monoid w, MonadReader r m, Index ix) => MonadReader r (WriterTX ix w m) where
    -- The environment lives in the inner monad.
    ask     = lift ask
    -- Run the unwrapped computation under the inner monad's local.
    local f = mkWriterTX (getVal :: ix) . local f . runWriterTX'

-- ReaderX
-- This instance needs -fallow-undecidable-instances, because it does not satisfy the coverage condition
instance (Monoid w, MonadReaderX ixr r m, Index ixw) => MonadReaderX ixr r (WriterTX ixw w m) where
    -- The indexed environment (index ixr) lives in the inner monad.
    askx ixv       = lift (askx ixv)
    -- Run the unwrapped computation under the inner monad's localx.
    localx ixv f = mkWriterTX (getVal :: ixw) . localx ixv f . runWriterTX'

-- Writer
-- Pass the inner monad's (non-indexed) writer through this layer.  The
-- layer's own output s is carried along unchanged; patterns are lazy.
instance (Index ixw2, MonadWriter w m, Monoid s) => MonadWriter w (WriterTX ixw2 s m) where
    tell = lift . tell
    -- Inner listen yields ((result, s), w); regroup to ((result, w), s).
    listen m = mkWriterTX (getVal :: ixw2) $
        listen (runWriterTX' m) >>= \ ~((a, s'), w) -> return ((a, w), s')
    -- Move the function f out to the position the inner pass expects.
    pass m = mkWriterTX (getVal :: ixw2) $ pass $
        runWriterTX' m >>= \ ~((a, f), s') -> return ((a, s'), f)

--WriterX
-- Pass an inner indexed writer (index ixw1) through this writer layer
-- (index ixw2).  The layer's own output s is carried along unchanged;
-- patterns are lazy.
instance (Index ixw2, MonadWriterX ixw1 w m, Monoid w, Monoid s)
  => MonadWriterX ixw1 w (WriterTX ixw2 s m) where
    tellx ixv = lift . tellx ixv
    -- Inner listenx yields ((result, s), w); regroup to ((result, w), s).
    listenx ixv m = mkWriterTX (getVal :: ixw2) $
        listenx ixv (runWriterTX' m) >>= \ ~((a, s'), w) -> return ((a, w), s')
    -- Move the function f out to the position the inner passx expects.
    passx ixv m = mkWriterTX (getVal :: ixw2) $ passx ixv $
        runWriterTX' m >>= \ ~((a, f), s') -> return ((a, s'), f)

-- State (Needs -fallow-undecidable-instances)
-- The state lives in the inner monad; both operations are lifted.
instance (Monoid w, MonadState s m, Index ix) => MonadState s (WriterTX ix w m) where
    get   = lift get
    put s = lift (put s)

-- StateX (Needs -fallow-undecidable-instances)
-- The indexed state (index ixs) lives in the inner monad; both operations
-- are lifted.
instance (Monoid w, MonadStateX ixs s m, Index ixw) => MonadStateX ixs s (WriterTX ixw w m) where
    getx         = lift . getx
    putx ixv s   = lift (putx ixv s)


-- RWS
-- WriterTX over an RWS-capable monad is itself RWS-capable.  The instance
-- has no body: its requirements are met by the MonadReader, MonadWriter and
-- MonadState instances for WriterTX defined in this module.
instance ( Monoid w1, Monoid w2, Index ix2
         , MonadReader r m, MonadState s m, MonadWriter w1 m
         ) => MonadRWS r w1 s (WriterTX ix2 w2 m)

-- Pass an inner indexed writer (index ix1, output w1) through RWST.  Output
-- written through ix1 goes to the inner monad m; the RWST's own output w2
-- and state s2 are threaded through untouched.
instance (Monoid w1, Monoid w2, Monad m, Index ix1, MonadWriterX ix1 w1 m)
  => MonadWriterX ix1 w1 (RWST r2 w2 s2 m) where
    -- Emit w1 in the inner monad; the RWST layer adds no output of its own.
    tellx (_::ix1) w1 = RWST $ \_ s2 ->
        tellx (getVal::ix1) w1 >> return ((), s2, mempty)
    -- Listen to the inner w1 output produced while running the RWST body.
    listenx (_::ix1) m = RWST $ \r2 s2 -> do
        ~((a, s2', w2'), w1) <- listenx (getVal::ix1) (runRWST m r2 s2)
        return ((a, w1), s2', w2')
    -- The whole RWST body must run *inside* passx, so that the returned
    -- function f is applied to the w1 output that body produces.  Running
    -- passx on a bare 'return' would apply f to mempty instead and leave
    -- the body's real output untouched.
    passx (_::ix1) m = RWST $ \r2 s2 -> passx (getVal::ix1) $ do
        ~((a, f), s2', w2') <- runRWST m r2 s2
        return ((a, s2', w2'), f)