{-# OPTIONS -fno-warn-orphans #-}
{-# LANGUAGE UndecidableInstances, ScopedTypeVariables, OverlappingInstances, FlexibleInstances, MultiParamTypeClasses #-}

-----------------------------------------------------------------------------
-- Module      :  Control.Monad.WriterX.Strict
-- Copyright   :  (c) Mark Snyder 2008.
-- License     :  BSD-style
-- Maintainer  :  Mark Snyder, marks@ittc.ku.edu
-- Stability   :  experimental
-- Portability :  non-portable (multi-param classes, functional dependencies)
-----------------------------------------------------------------------------

module Control.Monad.WriterX.Strict (
    module Control.Monad.WriterX.Class,
    WriterX(..),
    mkWriterX,
    runWriterX,
    execWriterX,
    mapWriterX,

    WriterTX(..),
    mkWriterTX,
    runWriterTX,
    execWriterTX,
    mapWriterTX,
    module Control.Monad,
    module Control.Monad.Fix,
    module Control.Monad.Trans,
    module Data.Monoid,
  ) where

import Control.Monad
import Control.Monad.Cont.Class
import Control.Monad.Error.Class
import Control.Monad.Fix
import Control.Monad.Reader.Class
import Control.Monad.State.Class
import Control.Monad.Trans
import Control.Monad.Writer.Class
import Control.Monad.RWS
import Data.Monoid

import Control.Monad.Index

import Control.Monad.ErrorX.Class
import Control.Monad.ReaderX.Class
import Control.Monad.StateX.Class
import Control.Monad.WriterX.Class

-- ---------------------------------------------------------------------------

-- | A strict writer computation tagged with a phantom index type @ix@, so
-- several writers can coexist and be addressed by index.
newtype WriterX ix w a = WriterX { runWriterX' :: (a, w) }

-- | Build a 'WriterX' from a (result, output) pair. The index argument is
-- used only to fix the type @ix@; its value is ignored.
mkWriterX :: (Index ix) => ix -> (a,w) -> WriterX ix w a
mkWriterX _ = WriterX

-- | Unwrap a 'WriterX' into its (result, output) pair. The index value is
-- ignored.
runWriterX :: (Index ix) => ix -> WriterX ix w a -> (a,w)
runWriterX _ = runWriterX'


-- | Run a 'WriterX' and keep only the accumulated output.
execWriterX :: (Index ix) => ix -> WriterX ix w a -> w
execWriterX _ = snd . runWriterX'

-- | Transform both the result and the output of a 'WriterX' with a
-- function on the underlying pair.
mapWriterX :: (Index ix) => ix -> ((a, w) -> (b, w')) -> WriterX ix w a -> WriterX ix w' b
mapWriterX ixv f = mkWriterX ixv . f . runWriterX'

-- | Apply a function to the result and pass the output through unchanged.
-- The pair is matched strictly, as it is throughout this module.
instance (Index ix) => Functor (WriterX ix w) where
    fmap f m = mkWriterX (getVal::ix) (onResult (runWriterX' m))
      where
        onResult (a, w) = (f a, w)

-- | Sequencing concatenates outputs with 'mappend'. 'return' produces
-- 'mempty' output.
instance (Monoid w, Index ix) => Monad (WriterX ix w) where
    return a = mkWriterX (getVal::ix) (a, mempty)
    m >>= k  = mkWriterX (getVal::ix) (continueWith (runWriterX' m))
      where
        -- Both pairs are matched strictly through function-argument
        -- patterns, which is equivalent to a nested strict 'case'.
        continueWith (a, w)    = appendOutput w (runWriterX' (k a))
        appendOutput w (b, w') = (b, w `mappend` w')


-- | Value recursion for 'WriterX'. The recursive @let@ feeds the result
-- component @a@ back into @m@. Pattern bindings in @let@ are lazy, so the
-- knot can be tied as long as @m@ does not demand its argument too early.
-- The body always rebuilds a fresh pair @(a, w)@ rather than returning the
-- pair from @m a@ directly.
instance (Monoid w, Index ix) => MonadFix (WriterX ix w) where
    mfix m = mkWriterX (getVal::ix) $
            let (a, w) = runWriterX' (m a) in (a, w)

-- | 'WriterX' is the base indexed writer. 'tellx' emits output,
-- 'listenx' exposes the output alongside the result, and 'passx' applies
-- a result-supplied function to the output.
instance (Monoid w, Index ix) => MonadWriterX ix w (WriterX ix w) where
    tellx ixv w   = mkWriterX ixv ((), w)
    listenx ixv m = mkWriterX ixv (exposeOutput (runWriterX' m))
      where
        exposeOutput (a, w) = ((a, w), w)
    passx ixv m   = mkWriterX ixv (rewriteOutput (runWriterX' m))
      where
        rewriteOutput ((a, f), w) = (a, f w)

-- ---------------------------------------------------------------------------

-- | Writer monad transformer tagged with a phantom index type @ix@, so
-- several writer layers can be addressed independently.
newtype WriterTX ix w m a = WriterTX  { runWriterTX' :: m (a, w) }

-- | Wrap an underlying computation that yields a (result, output) pair.
-- The index value only fixes the type @ix@ and is ignored.
mkWriterTX :: (Index ix) => ix -> m (a,w) -> WriterTX ix w m a
mkWriterTX _ = WriterTX

-- | Unwrap a 'WriterTX' into the underlying computation. The index value
-- is ignored.
runWriterTX :: (Index ix) => ix -> WriterTX ix w m a -> m (a,w)
runWriterTX _ = runWriterTX'

-- | Run a 'WriterTX' and keep only the accumulated output. The lambda's
-- tuple pattern matches the pair strictly, exactly like a do-bind would.
execWriterTX :: (Index ix, Monad m) => ix -> WriterTX ix w m a -> m w
execWriterTX _ m = runWriterTX' m >>= \(_, w) -> return w

-- | Transform the underlying computation, possibly changing the inner
-- monad, the result type and the output type.
mapWriterTX :: (Index ix) => ix -> (m (a, w) -> n (b, w')) -> WriterTX ix w m a -> WriterTX ix w' n b
mapWriterTX ixv f = mkWriterTX ixv . f . runWriterTX'

-- | Map over the result inside the underlying monad. The output is left
-- unchanged.
instance (Monad m, Index ix) => Functor (WriterTX ix w m) where
    fmap f m = mkWriterTX (getVal::ix) $
        runWriterTX' m >>= \(a, w) -> return (f a, w)

-- | Sequencing runs both computations in the underlying monad and
-- concatenates their outputs. 'fail' is delegated to the underlying monad.
instance (Monoid w, Monad m, Index ix) => Monad (WriterTX ix w m) where
    return a = mkWriterTX (getVal::ix) (return (a, mempty))
    m >>= k  = mkWriterTX (getVal::ix) $
        runWriterTX' m     >>= \(a, w)  ->
        runWriterTX' (k a) >>= \(b, w') ->
        return (b, w `mappend` w')
    fail msg = mkWriterTX (getVal::ix) (fail msg)

-- | Choice and failure are inherited directly from the underlying monad.
instance (Monoid w, MonadPlus m, Index ix) => MonadPlus (WriterTX ix w m) where
    mzero     = mkWriterTX (getVal::ix) mzero
    mplus m n = mkWriterTX (getVal::ix) (mplus (runWriterTX' m) (runWriterTX' n))

-- | Value recursion, delegated to the underlying monad's 'mfix'.
instance (Monoid w, MonadFix m, Index ix) => MonadFix (WriterTX ix w m) where
    mfix f = mkWriterTX (getVal::ix) (mfix knot)
      where
        -- The irrefutable pattern matters: 'mfix' feeds back a result
        -- that is not yet computed, so a strict match would typically
        -- diverge.
        knot ~(a, _) = runWriterTX' (f a)

-- | 'WriterTX' provides the indexed writer for its own index @ix@.
instance (Monoid w, Monad m, Index ix) => MonadWriterX ix w (WriterTX ix w m) where
    tellx ixv w   = mkWriterTX ixv (return ((), w))
    listenx ixv m = mkWriterTX ixv $
        runWriterTX' m >>= \(a, w) -> return ((a, w), w)
    passx ixv m   = mkWriterTX ixv $
        runWriterTX' m >>= \((a, f), w) -> return (a, f w)

-- ---------------------------------------------------------------------------

-- | Lifting runs the inner action and contributes no output ('mempty').
instance (Monoid w, Index ix) => MonadTrans (WriterTX ix w) where
    lift m = mkWriterTX (getVal::ix) (m >>= \a -> return (a, mempty))

-- | IO actions are lifted through the underlying monad.
instance (Monoid w, MonadIO m, Index ix) => MonadIO (WriterTX ix w m) where
    liftIO io = lift (liftIO io)

-- | 'callCC' is delegated to the underlying monad. Jumping through the
-- escape continuation contributes 'mempty' output.
instance (Monoid w, MonadCont m, Index ix) => MonadCont (WriterTX ix w m) where
    callCC f = mkWriterTX (getVal::ix) (callCC inner)
      where
        inner c    = runWriterTX' (f (escape c))
        escape c a = mkWriterTX (getVal::ix) (c (a, mempty))


-- ---------------------------------------------------------------------------

-- Instances for other mtl transformers

-- Error
-- | Errors come from the underlying monad. A handler's writer output
-- replaces the output of the failed computation.
instance (Monoid w, MonadError e m, Index ix) => MonadError e (WriterTX ix w m) where
    throwError e   = lift (throwError e)
    catchError m h = mkWriterTX (getVal::ix) $
        catchError (runWriterTX' m) (runWriterTX' . h)

-- ErrorX
-- | Indexed errors (index @ixe@) are handled by the underlying monad. This
-- writer layer keeps its own index @ixw@.
instance (Monoid w, Index ixe, Index ixw, MonadErrorX ixe e m) => MonadErrorX ixe e (WriterTX ixw w m) where
    throwErrorx ixv e   = lift (throwErrorx ixv e)
    catchErrorx ixv m h = mkWriterTX (getVal::ixw) $
        catchErrorx ixv (runWriterTX' m) (runWriterTX' . h)

-- Reader
-- Requires UndecidableInstances (enabled by this module's LANGUAGE pragma)
-- because the instance does not satisfy the coverage condition.
instance (Monoid w, MonadReader r m, Index ixw) => MonadReader r (WriterTX ixw w m) where
    ask       = lift ask
    local f m = mkWriterTX (getVal::ixw) (local f (runWriterTX' m))

-- ReaderX
-- Requires UndecidableInstances (enabled by this module's LANGUAGE pragma)
-- because the instance does not satisfy the coverage condition.
instance (Monoid w, Index ixr, Index ixw, MonadReaderX ixr r m) => MonadReaderX ixr r (WriterTX ixw w m) where
    askx ixv       = lift (askx ixv)
    localx ixv f m = mkWriterTX (getVal::ixw) (localx ixv f (runWriterTX' m))

-- State (requires UndecidableInstances): state operations are lifted from
-- the underlying monad.
instance (Monoid w, MonadState s m, Index ixw) => MonadState s (WriterTX ixw w m) where
    get   = lift get
    put s = lift (put s)

-- StateX (requires UndecidableInstances): indexed state operations are
-- lifted from the underlying monad.
instance (Monoid w, Index ixw, MonadStateX ixs s m) => MonadStateX ixs s (WriterTX ixw w m) where
    getx ixv   = lift (getx ixv)
    putx ixv s = lift (putx ixv s)

-- Writer
-- | The plain mtl writer output @w@ lives in the underlying monad. This
-- layer's own output @s@ is threaded through unchanged.
instance (Index ixw2, Monoid w, Monoid s, MonadWriter w m) => MonadWriter w (WriterTX ixw2 s m) where
    tell w   = lift (tell w)
    listen m = mkWriterTX (getVal::ixw2) $
        listen (runWriterTX' m) >>= \((a, s'), w) -> return ((a, w), s')
    pass m   = mkWriterTX (getVal::ixw2) $
        pass (runWriterTX' m >>= \((a, f), s') -> return ((a, s'), f))

-- WriterX
-- | A writer for a different index @ixw1@ is handled by the underlying
-- monad. This layer's output @w2@ is threaded through unchanged.
instance (Index ixw1, Index ixw2, Monoid w1, Monoid w2, MonadWriterX ixw1 w1 m)
    => MonadWriterX ixw1 w1 (WriterTX ixw2 w2 m) where
    tellx ixv w   = lift (tellx ixv w)
    listenx ixv m = mkWriterTX (getVal::ixw2) $
        listenx ixv (runWriterTX' m) >>= \((a, s'), w) -> return ((a, w), s')
    passx ixv m   = mkWriterTX (getVal::ixw2) $
        passx ixv (runWriterTX' m >>= \((a, f), s') -> return ((a, s'), f))

-- RWS
-- | 'WriterTX' is an RWS monad whenever the underlying monad supplies the
-- reader, state and (plain) writer capabilities. The instance has no
-- methods of its own.
instance ( Monoid w1, Monoid w2, Index ix2
         , MonadReader r m, MonadState s m, MonadWriter w1 m
         ) => MonadRWS r w1 s (WriterTX ix2 w2 m)

-- | Lift an indexed writer (index @ix1@, output @w1@) through mtl's 'RWST'.
-- The RWST's own output @w2@ is kept separate, and every operation on @w1@
-- is delegated to the underlying monad @m@.
instance (Monoid w1, Monoid w2, Monad m, Index ix1, MonadWriterX ix1 w1 m) => MonadWriterX ix1 w1 (RWST r2 w2 s2 m) where
    -- Emit @w1@ in the underlying monad. The RWST layer adds no @w2@
    -- output and leaves its state unchanged.
    tellx (_::ix1)  w1 = RWST $ \_ s2 -> tellx (getVal::ix1) w1 >> return ((),s2,mempty)
    -- Observe the @w1@ produced by the inner RWST computation while
    -- threading its state and @w2@ output through unchanged.
    listenx (_::ix1) m = RWST $ \r2 s2 -> do
        ((a,s2',w2'),w1) <- listenx (getVal::ix1) (runRWST m r2 s2)
        return ((a,w1),s2',w2')
    -- The whole inner computation must run *inside* the underlying
    -- 'passx'. Only then is the function @f@ it returns applied to the @w1@
    -- output that computation produced. Wrapping just @return (a, f)@
    -- would apply @f@ to an empty output, and the real @w1@ would pass
    -- through unmodified. This mirrors the 'WriterTX' lifting instance.
    passx (_::ix1) m = RWST $ \r2 s2 -> passx (getVal::ix1) $ do
        ((a,f),s2',w2') <- runRWST m r2 s2
        return ((a,s2',w2'),f)