{-# LANGUAGE ExistentialQuantification, ScopedTypeVariables #-}

-- | Exception handling generalised from 'IO' to any monad that is an
-- instance of 'MonadCatchIO'.
--
-- The module provides the class itself, instances for 'IO' and for the
-- @ReaderT@, @StateT@, @ErrorT@, @WriterT@ and @RWST@ transformers, and
-- lifted versions of the usual exception combinators.
--
-- The extensions declared above are required by the existentially
-- quantified 'Handler' type and by the type-annotated handler in
-- 'onException'.
module Control.Monad.CatchIO
  (
    MonadCatchIO(..)
  , E.Exception(..)
  , throw
  , try, tryJust
  , Handler(..), catches
  -- * Utilities
  , bracket
  , bracket_
  , bracketOnError
  , finally
  , onException
  )

where

import Prelude hiding ( catch )

import qualified Control.Exception.Extensible as E

import Control.Monad.Trans        (MonadIO,liftIO)
import Control.Monad.Trans.Error  (ErrorT          ,runErrorT ,mapErrorT ,Error)
import Control.Monad.Trans.Reader (ReaderT(ReaderT),runReaderT,mapReaderT)
import Control.Monad.Trans.State  (StateT(StateT)  ,runStateT ,mapStateT )
import Control.Monad.Trans.Writer (WriterT(WriterT),runWriterT,mapWriterT)
import Control.Monad.Trans.RWS    (RWST(RWST)      ,runRWST   ,mapRWST   )
import Data.Monoid                (Monoid)

-- | Monads built on top of 'IO' (every instance is also a 'MonadIO') in
-- which exceptions can be caught and in which the 'E.block' / 'E.unblock'
-- primitives are available.
--
-- Each method is the counterpart of the identically named function from
-- the exception module imported as @E@, with 'IO' replaced by @m@.
class MonadIO m => MonadCatchIO m where
    -- | Generalized version of 'E.catch'
    --
    -- The first argument is the computation to guard; the second is the
    -- handler invoked with an exception of type @e@.
    catch   :: E.Exception e => m a -> (e -> m a) -> m a

    -- | Generalized version of 'E.block'
    block   :: m a -> m a

    -- | Generalized version of 'E.unblock'
    unblock :: m a -> m a

-- | Generalized version of 'E.throwIO': raises the given exception in any
-- 'MonadCatchIO' monad by lifting 'E.throwIO' with 'liftIO'.
throw :: (MonadCatchIO m, E.Exception e) => e -> m a

-- | Generalized version of 'E.try': runs the computation and returns its
-- result wrapped in 'Right'; if an exception of type @e@ is caught instead,
-- that exception is returned wrapped in 'Left'.
try :: (MonadCatchIO m, E.Exception e) => m a -> m (Either e a)

-- | Generalized version of 'E.tryJust': like 'try', but a caught exception
-- is first passed to the given predicate.
--
-- When the predicate yields @'Just' b@ the result is @'Left' b@; when it
-- yields 'Nothing' the exception is rethrown with 'throw'.
tryJust :: (MonadCatchIO m, E.Exception e)
        => (e -> Maybe b) -> m a -> m (Either b a)

-- | Generalized version of 'E.Handler'
--
-- Wraps a handler function for one particular exception type @e@. The
-- type @e@ is hidden by the constructor, so handlers for different
-- exception types share the type @Handler m a@ and can be placed in the
-- same list, as 'catches' expects.
data Handler m a = forall e . E.Exception e => Handler (e -> m a)

-- | Generalized version of 'E.catches'.
--
-- Runs the computation and, if it raises an exception, offers that
-- exception to the handlers in list order. The first handler whose
-- exception type matches (according to 'E.fromException') handles it; if
-- none matches, the exception is rethrown with 'throw'.
catches :: MonadCatchIO m => m a -> [Handler m a] -> m a
catches action handlers = action `catch` dispatch
  where
    -- 'foldr' makes the head of the list the outermost candidate, so it is
    -- tried first; the rethrow is the fallback once every handler declined.
    dispatch exc = foldr attempt (throw exc) handlers
      where
        attempt (Handler h) fallback = maybe fallback h (E.fromException exc)


-- | Base instance: in 'IO' the methods are exactly 'E.catch', 'E.block'
-- and 'E.unblock'.
instance MonadCatchIO IO where
    catch action handler = E.catch action handler
    block action         = E.block action
    unblock action       = E.unblock action

-- | Exceptions are caught in the underlying monad. The handler runs with
-- the same environment that was supplied to the guarded computation.
instance MonadCatchIO m => MonadCatchIO (ReaderT r m) where
    catch action handler = ReaderT $ \env ->
        catch (runReaderT action env)
              (\exc -> runReaderT (handler exc) env)
    block   = mapReaderT block
    unblock = mapReaderT unblock

-- | Exceptions are caught in the underlying monad. The handler is started
-- from the state that was current when 'catch' was entered, so state
-- changes made by the guarded computation before the exception are not
-- visible to the handler.
instance MonadCatchIO m => MonadCatchIO (StateT s m) where
    catch action handler = StateT $ \st ->
        catch (runStateT action st)
              (\exc -> runStateT (handler exc) st)
    block   = mapStateT block
    unblock = mapStateT unblock

-- | Exceptions are caught in the underlying monad. The handler is only
-- consulted for exceptions raised there; a computation that merely
-- finishes with an 'ErrorT' failure (a @Left@ result) is passed through
-- unchanged.
instance (MonadCatchIO m, Error e) => MonadCatchIO (ErrorT e m) where
    catch action handler = mapErrorT guarded action
      where
        -- Guard the unwrapped computation, unwrapping the handler's
        -- result the same way.
        guarded inner = inner `catch` (runErrorT . handler)
    block   = mapErrorT block
    unblock = mapErrorT unblock

-- | Exceptions are caught in the underlying monad. When the handler runs,
-- its result and output replace those of the guarded computation, so
-- output written before the exception is not part of the final result.
instance (Monoid w, MonadCatchIO m) => MonadCatchIO (WriterT w m) where
    catch action handler =
        WriterT (catch (runWriterT action) (runWriterT . handler))
    block   = mapWriterT block
    unblock = mapWriterT unblock

-- | Exceptions are caught in the underlying monad. The handler runs with
-- the environment and the state that were current when 'catch' was
-- entered; its result, state and output replace those of the guarded
-- computation.
instance (Monoid w, MonadCatchIO m) => MonadCatchIO (RWST r w s m) where
    catch action handler = RWST $ \env st ->
        catch (runRWST action env st)
              (\exc -> runRWST (handler exc) env st)
    block   = mapRWST block
    unblock = mapRWST unblock

-- Raise the exception with 'E.throwIO' in 'IO', then lift that action
-- into the target monad.
throw exc = liftIO (E.throwIO exc)

-- A normal result is tagged with 'Right'; a caught exception is tagged
-- with 'Left'.
try action = (action >>= return . Right) `catch` (return . Left)

-- Run the action under 'try', then let the predicate decide what happens
-- to a caught exception: a 'Just' value is returned as 'Left', while an
-- exception the predicate rejects is rethrown unchanged.
tryJust p a = do
  r <- try a
  case r of
    Right v -> return (Right v)
    Left  e -> case p e of
      -- The 'Just' branch already fixes the result type, so the exception
      -- can be rethrown directly without a type-fixing placeholder.
      Nothing -> throw e
      Just b  -> return (Left b)

-- | Generalized version of 'E.bracket'.
--
-- The whole sequence runs inside 'block'; only the in-between computation
-- is wrapped in 'unblock'. The release action runs once after the
-- in-between computation returns, or, if that computation raises an
-- exception, once via 'onException' before the exception is rethrown.
--
-- NOTE(review): the normal-path release is an ordinary following statement
-- in @m@, so it is reached only if the in-between computation returns
-- normally in @m@. Presumably a transformer-level abort that is not an
-- exception (e.g. an @ErrorT@ failure) skips it; confirm before relying
-- on cleanup in such monads.
bracket :: MonadCatchIO m
        => m a        -- ^ computation to run first (\"acquire resource\")
        -> (a -> m b) -- ^ computation to run last (\"release resource\")
        -> (a -> m c) -- ^ computation to run in-between
        -> m c        -- ^ returns the value from the in-between computation
bracket before after thing = block $ do
    resource <- before
    result   <- unblock (thing resource) `onException` after resource
    after resource
    return result

-- | Generalized version of 'E.onException'.
--
-- Runs the first computation; if it raises any exception, runs the second
-- computation and then rethrows that same exception. The second
-- computation's result is discarded, and it is not run at all when the
-- first computation finishes normally.
onException :: MonadCatchIO m => m a -> m b -> m a
onException action cleanup = action `catch` cleanupThenRethrow
  where
    -- The annotation fixes the handled type to 'E.SomeException', so the
    -- handler matches every exception.
    cleanupThenRethrow exc = cleanup >> throw (exc :: E.SomeException)

-- | A variant of 'bracket' in which the result of the first computation is
-- not passed on to the other two.
--
-- As in 'bracket', everything runs inside 'block' and only the in-between
-- computation is wrapped in 'unblock'. The final computation runs after
-- the in-between computation returns, or via 'onException' if it raises
-- an exception.
bracket_ :: MonadCatchIO m
         => m a  -- ^ computation to run first (\"acquire resource\")
         -> m b  -- ^ computation to run last (\"release resource\")
         -> m c  -- ^ computation to run in-between
         -> m c  -- ^ returns the value from the in-between computation
bracket_ before after thing = block guarded
  where
    guarded = do
      before
      result <- unblock thing `onException` after
      after
      return result

-- | A specialised variant of 'bracket' that has no acquisition step: only
-- a computation to run and one to run afterward.
--
-- The first computation runs under 'unblock' inside an enclosing 'block'.
-- The second runs after it returns, or via 'onException' if it raises an
-- exception.
finally :: MonadCatchIO m
        => m a -- ^ computation to run first
        -> m b -- ^ computation to run afterward (even if an exception was
               -- raised)
        -> m a -- ^ returns the value from the first computation
finally thing after = block $
    (unblock thing `onException` after) >>= \result ->
        after >> return result

-- | Like 'bracket', but the final action is performed only if the
-- in-between computation raises an exception; on normal completion the
-- resource is left as it is.
bracketOnError :: MonadCatchIO m
               => m a        -- ^ computation to run first (\"acquire resource\")
               -> (a -> m b) -- ^ computation to run last (\"release resource\")
               -> (a -> m c) -- ^ computation to run in-between
               -> m c        -- ^ returns the value from the in-between computation
bracketOnError before after thing = block $
    before >>= \resource ->
        unblock (thing resource) `onException` after resource