module Control.Monad.CatchIO ( MonadCatchIO(..),
                               E.Exception(..),
                               throw,
                               try, tryJust,
                               onException, bracket,
                               Handler(..), catches )

where

import Prelude hiding ( catch )

import qualified Control.Exception.Extensible as E

import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Error
import Control.Monad.Writer
import Control.Monad.RWS

-- | Class of 'MonadIO' monads that support catching exceptions and toggling
-- the blocked state for asynchronous exceptions. Each method generalizes the
-- 'IO'-only function of the same name from the exception library; the 'IO'
-- instance in this module maps every method directly onto that function.
class MonadIO m => MonadCatchIO m where
    -- | Generalized version of 'E.catch': run the action given first and,
    -- if it raises an exception of type @e@, run the handler on it instead.
    catch   :: E.Exception e => m a -> (e -> m a) -> m a
 
    -- | Generalized version of 'E.block' (see the documentation of
    -- 'E.block' for the exact semantics of the blocked state).
    block   :: m a -> m a

    -- | Generalized version of 'E.unblock', the inverse of 'block'.
    unblock :: m a -> m a

-- | Generalized version of 'E.throwIO': raises the given exception from
-- within any 'MonadCatchIO' monad. It is implemented by lifting 'E.throwIO'
-- with 'liftIO', so the result type @a@ is unconstrained.
throw :: (MonadCatchIO m, E.Exception e) => e -> m a

-- | Generalized version of 'E.try': runs the action and returns its result
-- wrapped in 'Right', or, if an exception of type @e@ was raised, returns
-- that exception wrapped in 'Left'. Exceptions of other types are not caught.
try :: (MonadCatchIO m, E.Exception e) => m a -> m (Either e a)

-- | Generalized version of 'E.tryJust': like 'try', but a caught exception
-- is first passed to the given predicate. If the predicate yields @Just b@
-- the result is @Left b@; if it yields 'Nothing' the exception is re-thrown.
tryJust :: (MonadCatchIO m, E.Exception e)
        => (e -> Maybe b) -> m a -> m (Either b a)

-- | Generalized version of 'E.Handler': a handler for one particular
-- exception type @e@, running in monad @m@. The exception type is
-- existentially hidden so that handlers for different exception types can be
-- collected in a single list and passed to 'catches'.
data Handler m a = forall e . E.Exception e => Handler (e -> m a)

-- | Generalized version of 'E.catches'.
--
-- Runs the action and, if it raises an exception, offers that exception to
-- the handlers in list order. The first handler whose exception type matches
-- (as decided by 'E.fromException') deals with it; if no handler matches,
-- the exception is re-thrown.
catches :: MonadCatchIO m => m a -> [Handler m a] -> m a
catches action handlers = catch action dispatch
  where
    -- The right fold starts from "re-throw" and lets each handler, beginning
    -- at the front of the list, either claim the exception or defer to the
    -- remaining handlers.
    dispatch exc = foldr select (throw exc) handlers
      where
        select (Handler run) next = maybe next run (E.fromException exc)


-- Base case: in plain 'IO' every method is the corresponding function from
-- the exception library.
instance MonadCatchIO IO where
    action `catch` handler = action `E.catch` handler
    block action           = E.block action
    unblock action         = E.unblock action

-- The environment is captured once and handed unchanged to both the
-- protected action and the handler; catching itself happens in the
-- underlying monad.
instance MonadCatchIO m => MonadCatchIO (ReaderT r m) where
    catch action handler = ReaderT $ \env ->
        catch (runReaderT action env)
              (\exc -> runReaderT (handler exc) env)
    block   = mapReaderT block
    unblock = mapReaderT unblock

-- Both the protected action and the handler are started from the same
-- incoming state, so any state changes the action made before raising are
-- not visible to the handler.
instance MonadCatchIO m => MonadCatchIO (StateT s m) where
    catch action handler = StateT guarded
      where
        guarded st =
            runStateT action st `catch` \exc -> runStateT (handler exc) st
    block   = mapStateT block
    unblock = mapStateT unblock

-- Catching is delegated to the underlying monad, operating on the unwrapped
-- 'Either' computation. A 'Left' produced by the 'ErrorT' layer is an
-- ordinary return value at that level and is therefore not intercepted.
instance (MonadCatchIO m, Error e) => MonadCatchIO (ErrorT e m) where
    catch action handler =
        ErrorT $ runErrorT action `catch` (runErrorT . handler)
    block   = mapErrorT block
    unblock = mapErrorT unblock

-- Catching is delegated to the underlying monad on the unwrapped
-- (result, output) computation. When the handler runs, its own pair replaces
-- the action's, so output accumulated by the failed action is not kept.
instance (Monoid w, MonadCatchIO m) => MonadCatchIO (WriterT w m) where
    catch action handler =
        WriterT (catch (runWriterT action) (runWriterT . handler))
    block   = mapWriterT block
    unblock = mapWriterT unblock

-- The protected action and the handler both receive the same environment and
-- the same incoming state; catching happens in the underlying monad.
instance (Monoid w, MonadCatchIO m) => MonadCatchIO (RWST r w s m) where
    catch action handler = RWST guarded
      where
        guarded env st =
            catch (runRWST action env st)
                  (\exc -> runRWST (handler exc) env st)
    block   = mapRWST block
    unblock = mapRWST unblock

-- Lift 'E.throwIO' out of 'IO' so it can be used in any 'MonadCatchIO' stack.
throw exc = liftIO (E.throwIO exc)

-- A successful result is tagged with 'Right' inside the protected region;
-- the handler tags whatever exception it receives with 'Left'.
try action = catch tagged (return . Left)
  where
    tagged = do
        result <- action
        return (Right result)

-- Run the action under 'try', then let the predicate decide what happens to
-- a caught exception: 'Just' selects it (the selected value is returned in
-- 'Left'), 'Nothing' re-throws the original exception unchanged.
--
-- 'throw' is polymorphic in its result type, so the re-throw branch unifies
-- with the other branches on its own; no 'undefined' placeholder is needed
-- to pin the type.
tryJust p a = do
  r <- try a
  case r of
    Right v -> return (Right v)
    Left e  -> case p e of
      Nothing -> throw e
      Just b  -> return (Left b)

-- | Generalized version of 'E.bracket'.
--
-- @bracket before after thing@ acquires a resource with @before@, runs
-- @thing@ on it and then runs @after@ on the same resource. @after@ is run
-- both when @thing@ returns normally and when it raises an exception (the
-- exception is re-thrown once @after@ has finished). The whole sequence runs
-- under 'block'; only @thing@ itself is wrapped in 'unblock'.
--
-- Note: on the normal path @after@ is sequenced with the monad's bind, so if
-- @thing@ aborts through the monad itself rather than by raising an
-- exception (for example a short-circuiting error in a transformer layer),
-- @after@ is not run.
bracket :: MonadCatchIO m => m a -> (a -> m b) -> (a -> m c) -> m c
bracket before after thing = block $
    before >>= \resource ->
    (unblock (thing resource) `onException` after resource) >>= \result ->
    after resource >> return result

-- | Generalized version of 'E.onException'.
--
-- Runs the first action; if it raises any exception, the second action is
-- run (its result is discarded) and the original exception is then
-- re-thrown. If the first action completes normally the second action is
-- never run.
onException :: MonadCatchIO m => m a -> m b -> m a
onException action cleanup = catch action cleanupThenRethrow
  where
    -- Matching on 'E.SomeException' makes this handler fire for every
    -- exception type.
    cleanupThenRethrow (exc :: E.SomeException) = cleanup >> throw exc