module Control.Monad.CatchIO
  (
    MonadCatchIO(..)
  , E.Exception(..)
  , throw
  , try, tryJust
  , Handler(..), catches
  -- * Utilities
  , bracket
  , bracket_
  , bracketOnError
  , finally
  , onException
  )

where

import Prelude hiding ( catch )

import qualified Control.Exception.Extensible as E

import Control.Monad.Trans        (MonadIO,liftIO)
import Control.Monad.Trans.Error  (ErrorT          ,runErrorT ,mapErrorT ,Error)
import Control.Monad.Trans.Reader (ReaderT(ReaderT),runReaderT,mapReaderT)
import Control.Monad.Trans.State  (StateT(StateT)  ,runStateT ,mapStateT )
import Control.Monad.Trans.Writer (WriterT(WriterT),runWriterT,mapWriterT)
import Control.Monad.Trans.RWS    (RWST(RWST)      ,runRWST   ,mapRWST   )
import Data.Monoid                (Monoid)


-- | Monads that can catch exceptions and control the blocking of
-- asynchronous exceptions. Every instance is also a 'MonadIO'.
class MonadIO m => MonadCatchIO m where
  -- | Generalized version of 'E.catch': run the action, passing any
  -- exception of type @e@ that it throws to the handler.
  catch :: E.Exception e => m a -> (e -> m a) -> m a

  -- | Generalized version of 'E.block': run the action with asynchronous
  -- exceptions blocked.
  block :: m a -> m a

  -- | Generalized version of 'E.unblock': run the action with asynchronous
  -- exceptions unblocked.
  unblock :: m a -> m a


-- | Base case: each method delegates straight to the corresponding
-- function from "Control.Exception.Extensible".
instance MonadCatchIO IO where
  catch act handler = E.catch act handler
  block act         = E.block act
  unblock act       = E.unblock act

-- | Reader layer: the same environment is supplied to both the protected
-- action and the handler; blocking is pushed down with 'mapReaderT'.
instance MonadCatchIO m => MonadCatchIO (ReaderT r m) where
  catch act handler =
    ReaderT $ \env ->
      catch (runReaderT act env) (\ex -> runReaderT (handler ex) env)
  block act   = mapReaderT block act
  unblock act = mapReaderT unblock act

-- | State layer: the handler starts from the state that was current when
-- 'catch' was entered, so state changes made by the failed action are
-- not seen by the handler.
instance MonadCatchIO m => MonadCatchIO (StateT s m) where
  catch act handler =
    StateT $ \st ->
      catch (runStateT act st) (\ex -> runStateT (handler ex) st)
  block act   = mapStateT block act
  unblock act = mapStateT unblock act

-- | Error layer: only exceptions raised in the underlying monad are
-- caught. An 'ErrorT' failure is an ordinary 'Left' result of the inner
-- computation, so it passes through 'catch' untouched.
instance (MonadCatchIO m, Error e) => MonadCatchIO (ErrorT e m) where
  catch act handler = mapErrorT (`catch` (runErrorT . handler)) act
  block act         = mapErrorT block act
  unblock act       = mapErrorT unblock act

-- | Writer layer: if the action throws, its accumulated output is never
-- produced; the result and output come from the handler instead.
instance (Monoid w, MonadCatchIO m) => MonadCatchIO (WriterT w m) where
  catch act handler = WriterT (catch (runWriterT act) (runWriterT . handler))
  block act         = mapWriterT block act
  unblock act       = mapWriterT unblock act

-- | Reader/writer/state layer: the handler receives the same environment
-- and the state from the point where 'catch' was entered.
instance (Monoid w, MonadCatchIO m) => MonadCatchIO (RWST r w s m) where
  catch act handler =
    RWST $ \env st ->
      catch (runRWST act env st) (\ex -> runRWST (handler ex) env st)
  block act   = mapRWST block act
  unblock act = mapRWST unblock act


-- | Generalized version of 'E.throwIO': raise an exception in any
-- 'MonadIO' monad by lifting the 'IO' throw with 'liftIO'.
throw :: (MonadIO m, E.Exception e) => e -> m a
throw ex = liftIO (E.throwIO ex)

-- | Generalized version of 'E.try': run the action and return its result
-- in 'Right', or an exception of type @e@ that it throws in 'Left'.
try :: (MonadCatchIO m, E.Exception e) => m a -> m (Either e a)
try act = catch (act >>= return . Right) (return . Left)

-- | Generalized version of 'E.tryJust'.
--
-- Like 'try', but the predicate decides which exceptions are reported.
-- When it returns @'Just' b@ the result is @'Left' b@. When it returns
-- 'Nothing', the original exception is rethrown with 'throw'.
tryJust :: (MonadCatchIO m, E.Exception e)
        => (e -> Maybe b) -> m a -> m (Either b a)
tryJust p a = do
  r <- try a
  case r of
    Right v -> return (Right v)
    -- The case already fixes the result type, so rethrowing needs no
    -- 'asTypeOf' / 'undefined' annotation.
    Left  e -> maybe (throw e) (return . Left) (p e)

-- | Generalized version of 'E.Handler'.
--
-- Wraps a handler for one particular exception type, so that handlers for
-- different exception types can be collected in one list for 'catches'.
data Handler m a =
  forall exc . E.Exception exc => Handler (exc -> m a)

-- | Generalized version of 'E.catches'.
--
-- Runs the action. If it throws, the exception is offered to the handlers
-- in list order, and the first one whose exception type matches is run.
-- If no handler matches, the exception is rethrown with 'throw'.
catches :: MonadCatchIO m => m a -> [Handler m a] -> m a
catches act hs = catch act dispatch
  where
    -- Fold the handler list into a chain of type tests that ends in a
    -- rethrow.
    dispatch ex = foldr (attempt ex) (throw ex) hs
    attempt ex (Handler h) rest = maybe rest h (E.fromException ex)

-- | Generalized version of 'E.bracket'.
--
-- Acquires a resource with asynchronous exceptions blocked, then runs the
-- body with them unblocked. The release action runs afterwards, and also
-- when the body throws an exception; in that case 'onException' rethrows
-- after releasing. Returns the body's result.
--
-- NOTE(review): release is triggered only by exceptions seen by 'catch'.
-- In a short-circuiting layer such as 'ErrorT', a body that fails with an
-- 'ErrorT' error (a 'Left', not an exception) skips the final release step.
bracket :: MonadCatchIO m => m a -> (a -> m b) -> (a -> m c) -> m c
bracket acquire release use = block $ do
  res <- acquire
  out <- onException (unblock (use res)) (release res)
  release res
  return out

-- | Generalized version of 'E.onException'.
--
-- Runs the first action. If it throws any exception, runs the second
-- action and then rethrows the original exception.
onException :: MonadCatchIO m => m a -> m b -> m a
onException act cleanup =
  act `catch` \ex -> cleanup >> throw (ex :: E.SomeException)

-- | A variant of 'bracket' where the return value from the first
-- computation is not required. The release computation runs whether or not
-- the in-between computation throws an exception.
bracket_ :: MonadCatchIO m
         => m a  -- ^ computation to run first (\"acquire resource\")
         -> m b  -- ^ computation to run last (\"release resource\")
         -> m c  -- ^ computation to run in-between
         -> m c  -- ^ returns the value from the in-between computation
bracket_ acquire release use = block $ do
  acquire
  out <- onException (unblock use) release
  release
  return out

-- | A specialised variant of 'bracket' with just a computation to run
-- afterwards. The second computation runs after the first one, including
-- when the first one throws (the exception is then rethrown).
finally :: MonadCatchIO m
        => m a -- ^ computation to run first
        -> m b -- ^ computation to run afterward (even if an exception was
               -- raised)
        -> m a -- ^ returns the value from the first computation
finally act cleanup = block $ do
  out <- onException (unblock act) cleanup
  cleanup
  return out

-- | Like 'bracket', but the release computation runs only if the
-- in-between computation throws an exception. The exception is then
-- rethrown. On success the resource is left unreleased, and the body's
-- result is returned.
bracketOnError :: MonadCatchIO m
               => m a        -- ^ computation to run first (\"acquire resource\")
               -> (a -> m b) -- ^ computation to run last (\"release resource\")
               -> (a -> m c) -- ^ computation to run in-between
               -> m c        -- ^ returns the value from the in-between computation
bracketOnError acquire release use =
  block $ acquire >>= \res -> unblock (use res) `onException` release res