{-# OPTIONS -fglasgow-exts -fno-warn-orphans  #-} 
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE OverlappingInstances #-}


-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Monad.ErrorX
-- Copyright   :  (c) Mark Snyder 2008.
-- License     :  BSD-style
-- Maintainer  :  Mark Snyder, marks@ittc.ku.edu
-- Portability :  non-portable (multi-param classes, functional dependencies)
--
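-- Indexed error monads: the 'ErrorTX' transformer, tagged by an index type,
-- together with its instances for the mtl classes and the indexed classes.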
--
--      This module is inspired by the paper
--      /Functional Programming with Overloading and
--          Higher-Order Polymorphism/,
--        Mark P Jones (<http://web.cecs.pdx.edu/~mpj/>)
--          Advanced School of Functional Programming, 1995.

-----------------------------------------------------------------------------

module Control.Monad.ErrorX (
    module Control.Monad.ErrorX.Class,
    ErrorTX(..),
    runErrorTX,
    mapErrorTX,
    module Control.Monad,
    module Control.Monad.Fix,
    module Control.Monad.Trans,
    module Control.Monad.Index
  ) where

import Control.Monad
import Control.Monad.Cont.Class
import Control.Monad.Error
import Control.Monad.Fix
import Control.Monad.RWS.Class
import Control.Monad.RWS
import Control.Monad.Reader.Class
import Control.Monad.State.Class
import Control.Monad.Trans
import Control.Monad.Writer.Class

import Data.Monoid

import Control.Monad.Instances ()
import System.IO

import Control.Monad.Index
import Control.Monad.ErrorX.Class
import Control.Monad.ReaderX.Class
import Control.Monad.StateX.Class
import Control.Monad.WriterX.Class

-- | In plain 'IO', indexed errors are 'IOError's at every index.
-- The index argument is ignored. Throwing goes through 'ioError', and
-- catching uses the 'catch' in scope, which here must accept an
-- @IOError -> IO a@ handler.
instance (Index ix) => MonadErrorX ix IOError IO where
    throwErrorx _ err         = ioError err
    catchErrorx _ act handler = act `catch` handler
-- ---------------------------------------------------------------------------
-- ---------------------------------------------------------------------------

--data (Index ix) => ErrorTX ix e m a = ErrorTX ix (m (Either e a))
--runErrorTX :: (Index ix) => ix -> ErrorTX ix e m a -> m (Either e a)
--runErrorTX (_::ix) (ErrorTX (_::ix) f) = f
-- | Error monad transformer tagged with an index type @ix@.
-- The index does not appear in the representation (it is phantom). It only
-- selects which 'MonadErrorX' instance a computation belongs to.
newtype ErrorTX ix e m a = ErrorTX
    { runErrorTX' :: m (Either e a) -- ^ underlying computation: Left = error, Right = result
    }
-- | Wrap an @m (Either e a)@ computation as an 'ErrorTX' at the index given
-- by the (unevaluated) first argument.
mkErrorTX :: (Index ix) => ix -> m (Either e a) -> ErrorTX ix e m a
mkErrorTX _ = ErrorTX
-- | Unwrap an 'ErrorTX' at the given index. The index value itself is never
-- inspected; it only fixes the type @ix@.
runErrorTX :: (Index ix) => ix -> ErrorTX ix e m a -> m (Either e a)
runErrorTX _ = runErrorTX'



-- | Transform the wrapped computation of an 'ErrorTX'. The inner monad,
-- the error type and the result type may all change; the index stays the same.
mapErrorTX :: (Index ix) => ix
          -> (m (Either e a) -> n (Either e' b))
          -> ErrorTX ix e m a
          -> ErrorTX ix e' n b
mapErrorTX ixv f = mkErrorTX ixv . f . runErrorTX'

-- basic instances for any monad.
-- | Map over a successful result. An error ('Left') is passed through untouched.
instance (Monad m, Index ix) => Functor (ErrorTX ix e m) where
    fmap f m = mkErrorTX (getVal::ix)
        (runErrorTX' m >>= either (return . Left) (return . Right . f))

-- | Sequencing short-circuits on the first 'Left'. 'fail' builds the error
-- value with 'strMsgx' at this index.
instance (Monad m, ErrorX ix e, Index ix) => Monad (ErrorTX ix e m) where
    return = mkErrorTX (getVal::ix) . return . Right
    m >>= k = mkErrorTX (getVal::ix)
        (runErrorTX' m >>= either (return . Left) (runErrorTX' . k))
    fail = mkErrorTX (getVal::ix) . return . Left . strMsgx (getVal::ix)

-- | 'mzero' fails with the 'noMsgx' error for this index. 'mplus' runs the
-- second computation only when the first one produced a 'Left'; the first
-- error is discarded.
instance (Monad m, ErrorX ix e, Index ix) => MonadPlus (ErrorTX ix e m) where
    mzero = mkErrorTX (getVal::ix) (return (Left (noMsgx (getVal::ix))))
    m `mplus` n = mkErrorTX (getVal::ix)
        (runErrorTX' m >>= either (const (runErrorTX' n)) (return . Right))

-- | Value recursion through the inner monad's 'mfix'. If the fixed point
-- turns out to be a 'Left', demanding the recursive value raises
-- @error "empty mfix argument"@.
instance (MonadFix m, ErrorX ix e, Index ix) => MonadFix (ErrorTX ix e m) where
    mfix f = mkErrorTX (getVal::ix) (mfix (runErrorTX' . f . unRight))
      where
        -- Applied lazily: the argument is only inspected when f forces it.
        unRight (Right r) = r
        unRight _         = error "empty mfix argument"

-- ErrorX (same index)
-- | Errors at this transformer's own index: throwing yields 'Left', and
-- catching hands a 'Left' to the handler. A 'Right' is returned unchanged.
instance (Monad m, ErrorX ix e, Index ix) => MonadErrorX ix e (ErrorTX ix e m) where
    throwErrorx (ixv::ix) = mkErrorTX ixv . return . Left
    catchErrorx (ixv::ix) m h = mkErrorTX ixv
        (runErrorTX' m >>= either (runErrorTX' . h) (return . Right))

-- ErrorX (different indexes)
-- | Errors at a different index @ix1@ are delegated to the inner monad @m@.
-- Throwing lifts the inner 'throwErrorx'. Catching runs the wrapped
-- computation under the inner 'catchErrorx', tagging an inner failure as
-- 'Left' and a normal outcome (which is itself an @Either e2 a@) as 'Right'.
-- A caught error is then handed to @h@.
instance (Monad m, ErrorX ix1 e1, ErrorX ix2 e2, Index ix1, Index ix2, MonadErrorX ix1 e1 m) => MonadErrorX ix1 e1 (ErrorTX ix2 e2 m) where
    throwErrorx (_::ix1) (v::e1) =
        mkErrorTX (getVal::ix2) (liftM Right (throwErrorx (getVal::ix1) v))
    catchErrorx (_::ix1) (m::ErrorTX ix2 e2 m a) (h::e1->ErrorTX ix2 e2 m a) =
        mkErrorTX (getVal::ix2) $
            catchErrorx (getVal::ix1) (liftM Right (runErrorTX' m)) (return . Left)
                >>= either (runErrorTX' . h) return

-- ---------------------------------------------------------------------------
-- Instances for other mtl transformers

-- | Lifting an inner action always succeeds: its result is wrapped in 'Right'.
instance (ErrorX ix e, Index ix) => MonadTrans (ErrorTX ix e) where
    lift = mkErrorTX (getVal::ix) . liftM Right

-- | IO actions pass through the inner monad's 'liftIO' and are then lifted.
instance (ErrorX ix e, MonadIO m, Index ix) => MonadIO (ErrorTX ix e m) where
    liftIO act = lift (liftIO act)

-- | 'callCC' is built on the inner monad's 'callCC'. Invoking the escape
-- continuation with @a@ jumps out with the successful result @Right a@.
instance (ErrorX ix e, MonadCont m, Index ix) => MonadCont (ErrorTX ix e m) where
    callCC f = mkErrorTX (getVal::ix) $ callCC $ \exit ->
        runErrorTX' (f (mkErrorTX (getVal::ix) . exit . Right))

-- RWS
-- | 'MonadRWS' has no methods of its own here. The instance exists because
-- the Reader, Writer and State instances for 'ErrorTX' are defined below.
instance ( ErrorX ix e
         , MonadRWS r w s m
         , Index ix
         ) => MonadRWS r w s (ErrorTX ix e m)

-- | Indexed errors inside 'RWST' are delegated to the inner monad. When an
-- error is caught, the handler is run with the same reader environment @r@
-- and the state @s@ that the failing computation started from.
instance (Monoid w, MonadErrorX ixe e m, Index ixe) => MonadErrorX ixe e (RWST r w s m) where
    throwErrorx (ixv::ixe) err = lift (throwErrorx ixv err)
    catchErrorx (ixv::ixe) m h = RWST $ \r s ->
        let recover e = runRWST (h e) r s
        in  catchErrorx ixv (runRWST m r s) recover

-- Reader
-- | Reader operations pass straight through to the inner monad.
instance (ErrorX ix e, MonadReader r m, Index ix) => MonadReader r (ErrorTX ix e m) where
    ask     = lift ask
    local f = mkErrorTX (getVal::ix) . local f . runErrorTX'

-- State
-- | State operations are lifted from the inner monad.
instance (ErrorX ix e, MonadState s m, Index ix) => MonadState s (ErrorTX ix e m) where
    get       = lift get
    put new   = lift (put new)

--Writer
-- | Writer operations go through the inner monad. 'listen' pairs only a
-- successful result with the output. 'pass' applies the returned output
-- transformer on success and 'id' on error.
instance (ErrorX ix e, MonadWriter w m, Index ix) => MonadWriter w (ErrorTX ix e m) where
    tell out = lift (tell out)
    listen m = mkErrorTX (getVal::ix) $
        listen (runErrorTX' m) >>= \(res, out) ->
            either (return . Left) (\r -> return (Right (r, out))) res
    pass m = mkErrorTX (getVal::ix) $ pass $
        runErrorTX' m >>= either (\l -> return (Left l, id))
                                 (\(r, g) -> return (Right r, g))


-- instances for the other indexed monads.

-- ReaderX
-- | Indexed reader operations at index @ixr@ are delegated to the inner monad.
instance (ErrorX ixe e, MonadReaderX ixr r m, Index ixr, Index ixe) => MonadReaderX ixr r (ErrorTX ixe e m) where
    askx (ixv::ixr)     = lift (askx ixv)
    localx (ixv::ixr) f = mkErrorTX (getVal::ixe) . localx ixv f . runErrorTX'

--StateX
-- | Indexed state operations at index @ixs@ are lifted from the inner monad.
instance (ErrorX ixe e, Index ixs, MonadStateX ixs s m, Index ixe) => MonadStateX ixs s (ErrorTX ixe e m) where
    getx (ixv::ixs)     = lift (getx ixv)
    putx (ixv::ixs) new = lift (putx ixv new)

--WriterX
-- | Indexed writer operations at index @ixw@ go through the inner monad.
-- 'listenx' pairs only a successful result with the output. 'passx' applies
-- the returned output transformer on success and 'id' on error.
instance (ErrorX ixe e, Index ixw, MonadWriterX ixw w m, Index ixe) => MonadWriterX ixw w (ErrorTX ixe e m) where
    tellx (ixv::ixw) out = lift (tellx ixv out)
    listenx (ixv::ixw) m = mkErrorTX (getVal::ixe) $
        listenx ixv (runErrorTX' m) >>= \(res, out) ->
            either (return . Left) (\r -> return (Right (r, out))) res
    passx (ixv::ixw) m = mkErrorTX (getVal::ixe) $ passx ixv $
        runErrorTX' m >>= either (\l -> return (Left l, id))
                                 (\(r, g) -> return (Right r, g))