{-# LANGUAGE UnicodeSyntax #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Control.Monad.Abort.Class (
    MonadAbort(..),
    MonadRecover(..)
  ) where

import Data.Monoid
import Control.Monad.Cont
import Control.Monad.Error
import Control.Monad.List
import Control.Monad.Reader
import Control.Monad.State (MonadState(..))
import qualified Control.Monad.State.Lazy as L
import qualified Control.Monad.State.Strict as S
import Control.Monad.Writer (MonadWriter(..))
import qualified Control.Monad.Writer.Lazy as L
import qualified Control.Monad.Writer.Strict as S
import Control.Monad.RWS (MonadRWS)
import qualified Control.Monad.RWS.Lazy as L
import qualified Control.Monad.RWS.Strict as S
import Control.Monad.Trans.Abort (AbortT(..))
import qualified Control.Monad.Trans.Abort as A

-- | Monads that can abort the current computation with a value of type @e@.
--
-- The functional dependency @μ -> e@ makes the abort type a function of the
-- monad, so @abort x@ needs no annotation to fix the type of @x@.
class Monad μ => MonadAbort e μ | μ -> e where
  -- | Abort the current computation with the given value.
  abort :: e -> μ α

-- | Monads that, in addition to aborting, can intercept an abort.
--
-- As with 'MonadAbort', the abort type @e@ is determined by the monad.
class MonadAbort e μ => MonadRecover e μ | μ -> e where
  -- | @recover m h@ runs @m@; if @m@ aborts with some value @e@, the result
  -- is that of the handler @h e@ instead.
  recover :: μ α -> (e -> μ α) -> μ α

-- 'AbortT' is the base implementation of both classes. Each method delegates
-- directly to the function of the same name in "Control.Monad.Trans.Abort".
instance Monad μ => MonadAbort e (AbortT e μ) where
  abort = A.abort

instance Monad μ => MonadRecover e (AbortT e μ) where
  recover = A.recover

-- Exposes 'AbortT' through mtl's 'MonadError' interface. 'throwError' is
-- 'A.abort' and 'catchError' is 'A.recover', so both interfaces behave the
-- same way on an 'AbortT' computation.
instance Monad μ => MonadError e (AbortT e μ) where
  throwError = A.abort
  catchError = A.recover

-- Lifts 'callCC' through 'AbortT'. The captured continuation @f@ belongs to
-- the underlying monad and expects an @Either e a@. Calling the escape
-- function with @a@ therefore resumes it with @Right a@, which is the same
-- constructor the 'pass' instance below uses for a non-aborted result.
instance MonadCont μ => MonadCont (AbortT e μ) where
  callCC k = AbortT $ callCC $ \f -> runAbortT $ k (lift . f . Right)

-- Reader operations come from the underlying monad. 'local' is applied to
-- the unwrapped computation (@runAbortT m@), so the modified environment is
-- in effect for the whole of @m@.
instance MonadReader r μ => MonadReader r (AbortT e μ) where
  ask = lift ask
  local f = AbortT . local f . runAbortT

-- State operations are lifted unchanged from the underlying monad.
instance MonadState s μ => MonadState s (AbortT e μ) where
  get = lift get
  put = lift . put

-- Writer operations for 'AbortT'. 'listen' and 'pass' work on the
-- unwrapped @Either e a@ computation in the underlying monad:
--
-- * 'listen' pairs a successful ('Right') result with the collected output;
--   an aborted ('Left') result is returned unchanged.
--
-- * 'pass' applies the returned output transformer only on success; on
--   abort the output is passed through untouched ('id').
--
-- @return $!@ evaluates the rebuilt value to weak head normal form before
-- returning it.
instance MonadWriter w μ => MonadWriter w (AbortT e μ) where
  tell = lift . tell
  listen m = AbortT $ do
    (lr, w) <- listen $ runAbortT m
    return $! fmap (, w) lr
  pass m = AbortT $ pass $ do
    lr <- runAbortT m
    return $! either ((, id) . Left) (\(r, f) -> (Right r, f)) lr

-- No method definitions are needed here; this instance relies on the
-- MonadReader, MonadWriter and MonadState instances for 'AbortT'.
instance MonadRWS r w s μ => MonadRWS r w s (AbortT e μ)

-- 'abort' is lifted through 'ContT' from the underlying monad.
instance MonadAbort e μ => MonadAbort e (ContT r μ) where
  abort = lift . abort

-- 'abort' is lifted through 'ListT'.
instance MonadAbort e μ => MonadAbort e (ListT μ) where
  abort = lift . abort

-- 'recover' operates on the whole underlying list computation
-- (@runListT m@). If it aborts, the list produced by the handler is used
-- in its place.
instance MonadRecover e μ => MonadRecover e (ListT μ) where
  recover m h = ListT $ runListT m `recover` (runListT . h)

-- 'abort' is lifted through 'ReaderT'.
instance MonadAbort e μ => MonadAbort e (ReaderT r μ) where
  abort = lift . abort

-- 'recover' runs both the action and the handler with the same
-- environment @r@.
instance MonadRecover e μ => MonadRecover e (ReaderT r μ) where
  recover m h = ReaderT $ \r -> runReaderT m r `recover` ((`runReaderT` r) . h)

-- 'abort' is lifted through the lazy 'L.StateT'.
instance MonadAbort e μ => MonadAbort e (L.StateT s μ) where
  abort = lift . abort

-- The handler starts from the state @s@ that was current when 'recover'
-- began. Any state changes made by the aborted action are discarded.
instance MonadRecover e μ => MonadRecover e (L.StateT s μ) where
  recover m h = L.StateT $ \s ->
    L.runStateT m s `recover` ((`L.runStateT` s) . h)

-- 'abort' is lifted through the strict 'S.StateT'.
instance MonadAbort e μ => MonadAbort e (S.StateT s μ) where
  abort = lift . abort

-- The handler starts from the state @s@ that was current when 'recover'
-- began. Any state changes made by the aborted action are discarded.
instance MonadRecover e μ => MonadRecover e (S.StateT s μ) where
  recover m h = S.StateT $ \s ->
    S.runStateT m s `recover` ((`S.runStateT` s) . h)

-- 'abort' is lifted through the lazy 'L.WriterT'.
instance (MonadAbort e μ, Monoid w) => MonadAbort e (L.WriterT w μ) where
  abort = lift . abort

-- If the action aborts, the handler's (result, output) pair replaces it.
-- Output written by the aborted action is therefore not kept.
instance (MonadRecover e μ, Monoid w) => MonadRecover e (L.WriterT w μ) where
  recover m h = L.WriterT $ L.runWriterT m `recover` (L.runWriterT . h)

-- 'abort' is lifted through the strict 'S.WriterT'.
instance (MonadAbort e μ, Monoid w) => MonadAbort e (S.WriterT w μ) where
  abort = lift . abort

-- If the action aborts, the handler's (result, output) pair replaces it.
-- Output written by the aborted action is therefore not kept.
instance (MonadRecover e μ, Monoid w) => MonadRecover e (S.WriterT w μ) where
  recover m h = S.WriterT $ S.runWriterT m `recover` (S.runWriterT . h)

-- 'abort' is lifted through the lazy 'L.RWST'.
instance (MonadAbort e μ, Monoid w) => MonadAbort e (L.RWST r w s μ) where
  abort = lift . abort

-- The handler runs with the same environment @r@ and the same starting
-- state @s@ as the action. State and output produced by the aborted action
-- are discarded.
instance (MonadRecover e μ, Monoid w) => MonadRecover e (L.RWST r w s μ) where
  recover m h = L.RWST $ \r s ->
    L.runRWST m r s `recover` (\e -> L.runRWST (h e) r s)

-- 'abort' is lifted through the strict 'S.RWST'.
instance (MonadAbort e μ, Monoid w) => MonadAbort e (S.RWST r w s μ) where
  abort = lift . abort

-- The handler runs with the same environment @r@ and the same starting
-- state @s@ as the action. State and output produced by the aborted action
-- are discarded.
instance (MonadRecover e μ, Monoid w) => MonadRecover e (S.RWST r w s μ) where
  recover m h = S.RWST $ \r s ->
    S.runRWST m r s `recover` (\e -> S.runRWST (h e) r s)