-- UndecidableInstances is required by several instances below; search for
-- -fallow-undecidable-instances to find them.
{-# LANGUAGE UndecidableInstances,ScopedTypeVariables, OverlappingInstances #-}

-----------------------------------------------------------------------------
-- Module      :  Control.Monad.WriterX.Strict
-- Copyright   :  (c) Mark Snyder 2008.
-- License     :  BSD-style
-- Maintainer  :  Mark Snyder, marks@ittc.ku.edu
-- Stability   :  experimental
-- Portability :  non-portable (multi-param classes, functional dependencies)
-----------------------------------------------------------------------------

module Control.Monad.WriterX.Strict (
    module Control.Monad.WriterX.Class,
    WriterX(..),
    runWriterX,
    execWriterX,
    mapWriterX,

    WriterTX(..),
    runWriterTX,
    execWriterTX,
    mapWriterTX,
    module Control.Monad,
    module Control.Monad.Fix,
    module Control.Monad.Trans,
    module Data.Monoid,
  ) where

import Control.Monad
import Control.Monad.Cont.Class
import Control.Monad.Error.Class
import Control.Monad.Fix
import Control.Monad.Reader.Class
import Control.Monad.State.Class
import Control.Monad.Trans
import Control.Monad.Writer.Class
import Data.Monoid

import Control.Monad.Index

--import Control.Monad.ErrorX.Class
import Control.Monad.ReaderX.Class
import Control.Monad.StateX.Class
import Control.Monad.WriterX.Class

-- ---------------------------------------------------------------------------

-- | An index-tagged writer computation: a result of type @a@ paired with
-- output of type @w@. Analogous to
-- @newtype Writer w a = Writer { runWriter :: (a, w) }@, plus an index tag.
data (Index ix) => WriterX ix w a = WriterX ix (a,w)

-- | Unwrap a 'WriterX' into its (result, output) pair. The index argument
-- only pins down the type @ix@; neither index value is inspected.
runWriterX :: (Index ix) => ix -> WriterX ix w a -> (a,w)
runWriterX _ (WriterX _ pair) = pair

-- | Run a 'WriterX' and keep only its accumulated output.
execWriterX :: (Index ix) => ix -> WriterX ix w a -> w
execWriterX ixv = snd . runWriterX ixv

-- | Transform the (result, output) pair of a 'WriterX', keeping the same index.
mapWriterX :: (Index ix) => ix -> ((a,w) -> (b,w')) -> WriterX ix w a -> WriterX ix w' b
mapWriterX ixv f = WriterX ixv . f . runWriterX ixv

-- | Map over the result of a 'WriterX'; the output is passed through unchanged.
instance (Index ix) => Functor (WriterX ix w) where
    fmap g wx = WriterX tag (onResult (runWriterX tag wx))
      where
        tag = getVal :: ix
        -- strict match on the pair, as in the original case expression
        onResult (x, out) = (g x, out)

-- | Sequencing runs the first computation, feeds its result to the
-- continuation, and combines the two outputs with 'mappend'.
instance (Monoid w, Index ix) => Monad (WriterX ix w) where
    return x = WriterX tag (x, mempty)
      where
        tag = getVal :: ix
    wx >>= f = WriterX tag (step (runWriterX tag wx))
      where
        tag = getVal :: ix
        -- Both pairs are matched strictly.
        step (x, out1) = case runWriterX tag (f x) of
            (y, out2) -> (y, out1 `mappend` out2)

-- | Value recursion: the result produced by @f@ is fed back in as its own
-- argument through a lazy (recursive) pattern binding.
instance (Monoid w, Index ix) => MonadFix (WriterX ix w) where
    mfix f = WriterX tag (fixed, output)
      where
        tag = getVal :: ix
        (fixed, output) = runWriterX tag (f fixed)

-- | 'WriterX' services writer operations addressed to its own index @ix@.
instance (Monoid w, Index ix) => MonadWriterX ix w (WriterX ix w) where
    -- Emit the given output with a unit result.
    tellx (tag::ix) out = WriterX tag ((), out)
    -- Expose the output alongside the result.
    listenx (tag::ix) wx = WriterX tag (dup (runWriterX tag wx))
      where
        dup (x, out) = ((x, out), out)
    -- Apply the function carried in the result to the output.
    passx (tag::ix) wx = WriterX tag (apply (runWriterX tag wx))
      where
        apply ((x, g), out) = (x, g out)

-- ---------------------------------------------------------------------------
-- Our parameterizable writer monad, with an inner monad

-- | Index-tagged writer transformer: an inner computation yielding a
-- (result, output) pair. Analogous to
-- @newtype WriterT w m a = WriterT { runWriterT :: m (a, w) }@, plus an index tag.
data (Index ix) => WriterTX ix w m a = WriterTX ix (m (a,w))

-- | Unwrap a 'WriterTX' into its inner computation. The index argument only
-- pins down the type @ix@; neither index value is inspected.
runWriterTX :: (Index ix) => ix -> WriterTX ix w m a -> m (a,w)
runWriterTX _ (WriterTX _ inner) = inner

-- | Run a 'WriterTX' and return only its accumulated output. The pair is
-- matched strictly before the output is returned.
execWriterTX :: (Monad m, Index ix) => ix -> WriterTX ix w m a -> m w
execWriterTX ixv m = runWriterTX ixv m >>= \(_, out) -> return out

-- | Transform the inner computation of a 'WriterTX', keeping the same index.
mapWriterTX :: (Index ix) => ix -> (m (a, w) -> n (b, w')) -> WriterTX ix w m a -> WriterTX ix w' n b
mapWriterTX ixv f = WriterTX ixv . f . runWriterTX ixv

-- | Map over the result of a 'WriterTX' inside the inner monad; the output
-- is passed through unchanged.
instance (Monad m, Index ix) => Functor (WriterTX ix w m) where
    fmap g wx = WriterTX tag (runWriterTX tag wx >>= \(x, out) -> return (g x, out))
      where
        tag = getVal :: ix

-- | Sequencing runs both inner computations in order and combines their
-- outputs with 'mappend'; 'fail' is delegated to the inner monad.
instance (Monoid w, Monad m, Index ix) => Monad (WriterTX ix w m) where
    return x = WriterTX tag (return (x, mempty))
      where
        tag = getVal :: ix
    wx >>= f = WriterTX tag $
        runWriterTX tag wx >>= \(x, out1) ->
        runWriterTX tag (f x) >>= \(y, out2) ->
        return (y, out1 `mappend` out2)
      where
        tag = getVal :: ix
    fail msg = WriterTX tag (fail msg)
      where
        tag = getVal :: ix

-- | Choice is delegated to the inner monad's 'mzero' / 'mplus'.
instance (Monoid w, MonadPlus m, Index ix) => MonadPlus (WriterTX ix w m) where
    mzero = WriterTX tag mzero
      where
        tag = getVal :: ix
    mplus lhs rhs = WriterTX tag (runWriterTX tag lhs `mplus` runWriterTX tag rhs)
      where
        tag = getVal :: ix

-- | Value recursion through the inner monad's 'mfix'. The fed-back pair is
-- matched lazily so the knot can be tied.
instance (Monoid w, MonadFix m, Index ix) => MonadFix (WriterTX ix w m) where
    mfix f = WriterTX tag (mfix knot)
      where
        tag = getVal :: ix
        knot ~(x, _) = runWriterTX tag (f x)

-- | A 'WriterTX' handles writer operations addressed to its own index @ix@
-- itself, instead of lifting them to the inner monad.
instance (Monoid w, Monad m, Index ix) => MonadWriterX ix w (WriterTX ix w m) where
    -- Emit @w@ with a unit result.
    tellx   (ixv::ix) w = WriterTX ixv $ return ((), w)
    -- Run @m@ and expose its output alongside its result.
    -- Uses the supplied index @ixv@, consistent with tellx/passx.
    listenx (ixv::ix) m = WriterTX ixv $ do
        (a, w) <- runWriterTX ixv m
        return ((a, w), w)
    -- Run @m@, whose result carries a function that is applied to its output.
    passx   (ixv::ix) m = WriterTX ixv $ do
        ((a, f), w) <- runWriterTX ixv m
        return (a, f w)

-- ---------------------------------------------------------------------------
-- Instances for other mtl transformers

-- | Lifting runs the inner action and contributes 'mempty' as output.
instance (Monoid w, Index ix) => MonadTrans (WriterTX ix w) where
    lift inner = WriterTX tag (inner >>= \x -> return (x, mempty))
      where
        tag = getVal :: ix

-- | IO actions are lifted through the inner monad.
instance (Monoid w, MonadIO m, Index ix) => MonadIO (WriterTX ix w m) where
    liftIO io = lift (liftIO io)

-- | 'callCC' is delegated to the inner monad. Invoking the escape
-- continuation passes its value on with 'mempty' as output.
instance (Monoid w, MonadCont m, Index ix) => MonadCont (WriterTX ix w m) where
    callCC f = WriterTX tag (callCC (\exit -> runWriterTX tag (f (resume exit))))
      where
        tag = getVal :: ix
        -- Wrap the inner continuation so it can be called from WriterTX.
        resume exit x = WriterTX tag (exit (x, mempty))




-- Error
-- | Errors are thrown and caught in the inner monad.
instance (Monoid w, MonadError e m, Index ix) => MonadError e (WriterTX ix w m) where
    throwError err = lift (throwError err)
    catchError action handler =
        WriterTX tag (runWriterTX tag action `catchError` (runWriterTX tag . handler))
      where
        tag = getVal :: ix



{-- Not Ready.
-- ErrorX
instance (Monoid w, MonadErrorX ixe e m, Index ixw) => MonadErrorX ixe e (WriterTX ixw w m) where
    throwErrorx (ixv::ixe)       = lift . throwErrorx ixv
    catchErrorx (ixv::ixe) m h = WriterTX (getVal::ixw) $ 
                    catchErrorx
                    (ixv::ixe)
                    (runWriterTX (getVal::ixw) m)
                    (\e -> runWriterTX (getVal::ixw) (h e))

--}


-- Reader
-- This instance needs -fallow-undecidable-instances, because
-- it does not satisfy the coverage condition
instance (Monoid w, MonadReader r m, Index ix) => MonadReader r (WriterTX ix w m) where
    ask = lift ask
    -- Modify the environment for the wrapped inner computation only.
    local adjust wx = WriterTX tag (local adjust (runWriterTX tag wx))
      where
        tag = getVal :: ix




-- ReaderX
-- This instance needs -fallow-undecidable-instances because it does not
-- satisfy the coverage condition.
instance (Monoid w, MonadReaderX ixr r m, Index ixw) => MonadReaderX ixr r (WriterTX ixw w m) where
    -- Reader requests at index @ixr@ are forwarded to the inner monad.
    askx (rtag::ixr) = lift (askx rtag)
    localx (rtag::ixr) adjust wx = WriterTX wtag (localx rtag adjust (runWriterTX wtag wx))
      where
        wtag = getVal :: ixw




-- Writer
instance (Index ixw2, MonadWriter w m, Monoid s) => MonadWriter w (WriterTX ixw2 s m) where
   -- Plain (unindexed) writer operations go to the inner monad; this layer's
   -- own output @s@ is threaded through untouched.
   tell out = lift (tell out)
   -- NOTE(review): the lazy (~) patterns below differ from the strict matches
   -- used by this module's own WriterTX instances; confirm that is intended.
   listen wx = WriterTX tag $
       listen (runWriterTX tag wx) >>= \ ~((x, acc), out) ->
       return ((x, out), acc)
     where
       tag = getVal :: ixw2
   pass wx = WriterTX tag $ pass $
       runWriterTX tag wx >>= \ ~((x, g), acc) ->
       return ((x, acc), g)
     where
       tag = getVal :: ixw2




--WriterX
instance (Index ixw2, MonadWriterX ixw1 w m, Monoid w, Monoid s)
  => MonadWriterX ixw1 w (WriterTX ixw2 s m) where
   -- Writer operations at a different index @ixw1@ go to the inner monad;
   -- this layer's own output @s@ is threaded through untouched.
   tellx (otag::ixw1) out = lift (tellx otag out)
   -- NOTE(review): the lazy (~) patterns below differ from the strict matches
   -- used by this module's own WriterTX instances; confirm that is intended.
   listenx (otag::ixw1) wx = WriterTX wtag $
       listenx otag (runWriterTX wtag wx) >>= \ ~((x, acc), out) ->
       return ((x, out), acc)
     where
       wtag = getVal :: ixw2
   passx (otag::ixw1) wx = WriterTX wtag $ passx otag $
       runWriterTX wtag wx >>= \ ~((x, g), acc) ->
       return ((x, acc), g)
     where
       wtag = getVal :: ixw2




-- State
-- Needs -fallow-undecidable-instances
instance (Monoid w, MonadState s m, Index ix) => MonadState s (WriterTX ix w m) where
    -- State access is lifted straight to the inner monad.
    get = lift get
    put st = lift (put st)




-- StateX
-- Needs -fallow-undecidable-instances
instance (Monoid w, MonadStateX ixs s m, Index ixw) => MonadStateX ixs s (WriterTX ixw w m) where
    -- Indexed state access at @ixs@ is lifted straight to the inner monad.
    getx (stag::ixs) = lift (getx stag)
    putx (stag::ixs) st = lift (putx stag st)