#ifndef MIN_VERSION_transformers
#define MIN_VERSION_transformers(x,y,z) 1
#endif
#ifndef MIN_VERSION_mtl
#define MIN_VERSION_mtl(x,y,z) 1
#endif
module Control.Monad.Catch (
MonadCatch(..)
, CatchT(..), Catch
, runCatch
, mapCatchT
, catchAll
, catchIOError
, catchJust
, catchIf
, Handler(..), catches
, handle
, handleJust
, try
, tryJust
, onException
, bracket
, bracket_
, finally
, bracketOnError
, Exception(..)
, SomeException(..)
) where
#if defined(__GLASGOW_HASKELL__) && (__GLASGOW_HASKELL__ >= 706)
import Prelude hiding (foldr)
#else
import Prelude hiding (catch, foldr)
#endif
import Control.Applicative
import Control.Exception (Exception(..), SomeException(..))
import qualified Control.Exception as ControlException
import qualified Control.Monad.Trans.RWS.Lazy as LazyRWS
import qualified Control.Monad.Trans.RWS.Strict as StrictRWS
import qualified Control.Monad.Trans.State.Lazy as LazyS
import qualified Control.Monad.Trans.State.Strict as StrictS
import qualified Control.Monad.Trans.Writer.Lazy as LazyW
import qualified Control.Monad.Trans.Writer.Strict as StrictW
import Control.Monad.Trans.Identity
import Control.Monad.Reader as Reader
import Control.Monad.RWS
import Data.Foldable
import Data.Functor.Identity
import Data.Traversable as Traversable
-- | Monads that can throw and catch exceptions and mask asynchronous
-- exceptions.
--
-- This module gives instances for 'IO', for several transformers from
-- @transformers@ (each lifting the operations through the transformer),
-- and for the pure 'CatchT' transformer.
class Monad m => MonadCatch m where
  -- | Throw an exception in @m@. The 'IO' instance is
  -- 'ControlException.throwIO'.
  throwM :: Exception e => e -> m a
  -- | Run an action, passing exceptions of type @e@ that it raises to the
  -- handler. Exceptions of other types are not handled here.
  catch :: Exception e => m a -> (e -> m a) -> m a
  -- | Run an action with asynchronous exceptions masked. The action is
  -- given a function that runs a sub-action under the enclosing masking
  -- state. The 'IO' instance is 'ControlException.mask'.
  mask :: ((forall a. m a -> m a) -> m b) -> m b
-- | 'IO' delegates every operation straight to "Control.Exception".
instance MonadCatch IO where
  throwM exc = ControlException.throwIO exc
  catch action handler = ControlException.catch action handler
  mask body = ControlException.mask body
-- | 'IdentityT' adds no structure: each operation unwraps to the
-- underlying monad, runs there, and rewraps the result.
instance MonadCatch m => MonadCatch (IdentityT m) where
  throwM = lift . throwM
  catch action handler =
    IdentityT (catch (runIdentityT action) (\exc -> runIdentityT (handler exc)))
  mask body = IdentityT $ mask $ \restore ->
    -- re-expose the underlying 'restore' at the IdentityT layer
    runIdentityT (body (IdentityT . restore . runIdentityT))
-- | Lazy 'LazyS.StateT': 'catch' reuses transformers' 'LazyS.liftCatch',
-- and 'mask' runs the state-passing function inside the underlying 'mask'.
instance MonadCatch m => MonadCatch (LazyS.StateT s m) where
  throwM = lift . throwM
  catch = LazyS.liftCatch catch
  mask body = LazyS.StateT $ \s0 -> mask $ \restore ->
    -- lift 'restore' so it wraps the underlying state-passing action
    LazyS.runStateT
      (body (\act -> LazyS.StateT (restore . LazyS.runStateT act)))
      s0
-- | Strict 'StrictS.StateT': 'catch' reuses transformers'
-- 'StrictS.liftCatch', and 'mask' runs the state-passing function inside
-- the underlying 'mask'.
instance MonadCatch m => MonadCatch (StrictS.StateT s m) where
  throwM = lift . throwM
  catch = StrictS.liftCatch catch
  mask body = StrictS.StateT $ \s0 -> mask $ \restore ->
    -- lift 'restore' so it wraps the underlying state-passing action
    StrictS.runStateT
      (body (\act -> StrictS.StateT (restore . StrictS.runStateT act)))
      s0
-- | 'ReaderT': the environment is passed both to the protected action and
-- to the handler, so the handler sees the same environment.
instance MonadCatch m => MonadCatch (ReaderT r m) where
  throwM = lift . throwM
  catch action handler = ReaderT $ \env ->
    catch (runReaderT action env) (\exc -> runReaderT (handler exc) env)
  mask body = ReaderT $ \env -> mask $ \restore ->
    runReaderT (body (\act -> ReaderT (restore . runReaderT act))) env
-- | Strict 'StrictW.WriterT'. When the handler runs, its output replaces
-- whatever the failed action would have produced.
instance (MonadCatch m, Monoid w) => MonadCatch (StrictW.WriterT w m) where
  throwM = lift . throwM
  catch action handler = StrictW.WriterT $
    catch (StrictW.runWriterT action) (StrictW.runWriterT . handler)
  mask body = StrictW.WriterT $ mask $ \restore ->
    StrictW.runWriterT (body (StrictW.WriterT . restore . StrictW.runWriterT))
-- | Lazy 'LazyW.WriterT'. When the handler runs, its output replaces
-- whatever the failed action would have produced.
instance (MonadCatch m, Monoid w) => MonadCatch (LazyW.WriterT w m) where
  throwM = lift . throwM
  catch action handler = LazyW.WriterT $
    catch (LazyW.runWriterT action) (LazyW.runWriterT . handler)
  mask body = LazyW.WriterT $ mask $ \restore ->
    LazyW.runWriterT (body (LazyW.WriterT . restore . LazyW.runWriterT))
-- | Lazy 'LazyRWS.RWST'. The handler is run with the environment and the
-- state as they were on entry to 'catch'.
instance (MonadCatch m, Monoid w) => MonadCatch (LazyRWS.RWST r w s m) where
  throwM = lift . throwM
  catch action handler = LazyRWS.RWST $ \env st ->
    catch (LazyRWS.runRWST action env st)
          (\exc -> LazyRWS.runRWST (handler exc) env st)
  mask body = LazyRWS.RWST $ \env st -> mask $ \restore ->
    LazyRWS.runRWST
      (body (\act -> LazyRWS.RWST (\env' st' -> restore (LazyRWS.runRWST act env' st'))))
      env
      st
-- | Strict 'StrictRWS.RWST'. The handler is run with the environment and
-- the state as they were on entry to 'catch'.
instance (MonadCatch m, Monoid w) => MonadCatch (StrictRWS.RWST r w s m) where
  throwM = lift . throwM
  catch action handler = StrictRWS.RWST $ \env st ->
    catch (StrictRWS.runRWST action env st)
          (\exc -> StrictRWS.runRWST (handler exc) env st)
  mask body = StrictRWS.RWST $ \env st -> mask $ \restore ->
    StrictRWS.runRWST
      (body (\act -> StrictRWS.RWST (\env' st' -> restore (StrictRWS.runRWST act env' st'))))
      env
      st
-- | A transformer that represents a thrown exception as a 'Left'
-- 'SomeException' result in the underlying monad. This gives a
-- 'MonadCatch' instance for any 'Monad', without needing 'IO'.
newtype CatchT m a = CatchT { runCatchT :: m (Either SomeException a) }
-- | 'CatchT' over 'Identity': a pure computation that may throw.
type Catch = CatchT Identity
-- | Run a pure 'Catch' computation, returning either the uncaught
-- exception ('Left') or the result ('Right').
runCatch :: Catch a -> Either SomeException a
runCatch action = runIdentity (runCatchT action)
-- | Map over a successful result; an exception passes through untouched.
instance Monad m => Functor (CatchT m) where
  fmap f = CatchT . liftM (either Left (Right . f)) . runCatchT
-- | 'pure' wraps a value as a successful result; '<*>' sequences via the
-- 'Monad' instance.
instance Monad m => Applicative (CatchT m) where
  pure = CatchT . return . Right
  mf <*> mx = ap mf mx
-- | Sequencing stops at the first exception: once a computation yields a
-- 'Left', the continuation is not run and the 'Left' is returned as-is.
instance Monad m => Monad (CatchT m) where
  -- Canonical definition: 'return' is 'pure', so the two cannot diverge
  -- (and -Wnoncanonical-monad-instances stays quiet).
  return = pure
  CatchT m >>= k = CatchT $ m >>= \ea -> case ea of
    Left e -> return (Left e)
    Right a -> runCatchT (k a)
  -- 'fail' is reported as a thrown 'userError'.
  fail = CatchT . return . Left . toException . userError
-- | Ties the knot through the underlying monad's 'mfix'. If the recursive
-- value is demanded after the computation produced a 'Left', evaluation
-- fails with an 'error'.
instance MonadFix m => MonadFix (CatchT m) where
  mfix f = CatchT (mfix (runCatchT . f . either (const (error "empty mfix argument")) id))
-- | Folds over successful results only; an exception contributes 'mempty'.
instance Foldable m => Foldable (CatchT m) where
  foldMap f = foldMap (either (const mempty) f) . runCatchT
-- | Traverses successful results; an exception is rebuilt unchanged with
-- 'pure'.
instance (Monad m, Traversable m) => Traversable (CatchT m) where
  traverse f = fmap CatchT . Traversable.traverse (either (pure . Left) (fmap Right . f)) . runCatchT
-- | Defined entirely in terms of the 'MonadPlus' instance.
instance Monad m => Alternative (CatchT m) where
  empty = mzero
  left <|> right = mplus left right
-- | 'mzero' is a thrown empty 'userError'. 'mplus' runs the second
-- computation only when the first produced an exception, discarding that
-- exception.
instance Monad m => MonadPlus (CatchT m) where
  mzero = CatchT (return (Left (toException (userError ""))))
  mplus first second =
    CatchT (runCatchT first >>= either (const (runCatchT second)) (return . Right))
-- | A lifted action is treated as successful; its result is wrapped in
-- 'Right'.
instance MonadTrans CatchT where
  lift = CatchT . liftM Right
-- | Lifts 'IO' through the underlying monad, wrapping the result in 'Right'.
instance MonadIO m => MonadIO (CatchT m) where
  liftIO = CatchT . liftM Right . liftIO
-- | Pure exception handling. 'throwM' produces a 'Left'. 'catch' hands a
-- 'Left' to the handler when it matches the handler's exception type and
-- otherwise keeps it. 'mask' is a no-op, since asynchronous exceptions
-- cannot arise here.
instance Monad m => MonadCatch (CatchT m) where
  throwM = CatchT . return . Left . toException
  catch action handler =
    CatchT (runCatchT action >>= either recover (return . Right))
    where
      recover exc =
        maybe (return (Left exc)) (runCatchT . handler) (fromException exc)
  mask body = body id
-- | State operations are lifted unchanged from the underlying monad.
instance MonadState s m => MonadState s (CatchT m) where
  get = lift get
  put s = lift (put s)
#if MIN_VERSION_mtl(2,1,0)
  state f = lift (state f)
#endif
-- | Reader operations are lifted from the underlying monad; 'local'
-- modifies the environment of the wrapped computation.
instance MonadReader e m => MonadReader e (CatchT m) where
  ask = lift ask
  local f (CatchT m) = CatchT (local f m)
#if MIN_VERSION_mtl(2,1,0)
  -- Lift the underlying 'reader' rather than using the class default,
  -- matching how the sibling instances lift 'state' and 'writer'.
  reader = lift . reader
#endif
-- | Writer operations are lifted. 'listen' and 'pass' act on the underlying
-- computation and carry a thrown exception through unchanged.
instance MonadWriter w m => MonadWriter w (CatchT m) where
  tell w = lift (tell w)
  listen = mapCatchT $ \inner -> do
    (outcome, output) <- listen inner
    -- attach the captured output only to a successful result
    return $! fmap (\value -> (value, output)) outcome
  pass = mapCatchT $ \inner -> pass $ do
    outcome <- inner
    -- on an exception the output is passed through unchanged ('id')
    return $! case outcome of
      Right (value, transform) -> (Right value, transform)
      Left exc -> (Left exc, id)
#if MIN_VERSION_mtl(2,1,0)
  writer aw = lift (writer aw)
#endif
-- | The instance body is empty; it relies on the 'MonadReader',
-- 'MonadWriter' and 'MonadState' instances for 'CatchT' in this module.
instance MonadRWS r w s m => MonadRWS r w s (CatchT m)
-- | Apply a function to the underlying computation of a 'CatchT',
-- including its 'Either' result. The function may change the underlying
-- monad.
mapCatchT :: (m (Either SomeException a) -> n (Either SomeException b))
          -> CatchT m a
          -> CatchT n b
mapCatchT f = CatchT . f . runCatchT
-- | 'catch' specialised to 'SomeException', so the handler sees every
-- exception.
catchAll :: MonadCatch m => m a -> (SomeException -> m a) -> m a
catchAll action handler = catch action handler
-- | 'catch' specialised to 'IOError'; other exceptions are not handled.
catchIOError :: MonadCatch m => m a -> (IOError -> m a) -> m a
catchIOError action handler = catch action handler
-- | Catch exceptions of type @e@ that satisfy the predicate; any other
-- @e@ is rethrown with 'throwM'.
catchIf :: (MonadCatch m, Exception e) =>
  (e -> Bool) -> m a -> (e -> m a) -> m a
catchIf accepts action handler = action `catch` dispatch
  where
    dispatch exc
      | accepts exc = handler exc
      | otherwise = throwM exc
-- | Catch exceptions of type @e@ for which the selector returns 'Just'.
-- The handler receives the selected value; on 'Nothing' the exception is
-- rethrown.
catchJust :: (MonadCatch m, Exception e) =>
  (e -> Maybe b) -> m a -> (b -> m a) -> m a
catchJust select action handler = catch action $ \exc ->
  case select exc of
    Just picked -> handler picked
    Nothing -> throwM exc
-- | 'catch' with the handler given first.
handle :: (MonadCatch m, Exception e) => (e -> m a) -> m a -> m a
handle handler action = catch action handler
-- | 'catchJust' with the handler given before the action.
handleJust :: (MonadCatch m, Exception e) => (e -> Maybe b) -> (b -> m a) -> m a -> m a
handleJust select handler action = catchJust select action handler
-- | Run an action, returning a caught exception of type @e@ as 'Left' and
-- a normal result as 'Right'.
try :: (MonadCatch m, Exception e) => m a -> m (Either e a)
try action = catch (liftM Right action) (\exc -> return (Left exc))
-- | Like 'try', but only exceptions for which the selector returns 'Just'
-- become 'Left'. Others are rethrown.
tryJust :: (MonadCatch m, Exception e) =>
  (e -> Maybe b) -> m a -> m (Either b a)
tryJust select action = catch (liftM Right action) $ \exc ->
  case select exc of
    Just picked -> return (Left picked)
    Nothing -> throwM exc
-- | A handler for one exception type, used with 'catches'. The exception
-- type @e@ is existentially hidden, so handlers for different exception
-- types can share one container.
data Handler m a = forall e . ControlException.Exception e => Handler (e -> m a)
-- | Post-compose a function onto the handler's result.
instance Monad m => Functor (Handler m) where
  fmap f (Handler h) = Handler (\exc -> liftM f (h exc))
-- | Run an action. If it throws, try the handlers in 'foldr' order and
-- use the first whose exception type matches. If none matches, the
-- exception is rethrown.
catches :: (Foldable f, MonadCatch m) => m a -> f (Handler m a) -> m a
catches action handlers = catch action dispatch
  where
    dispatch exc = foldr (attempt exc) (throwM exc) handlers
    -- use this handler if the exception has its type, else fall through
    attempt exc (Handler h) fallback =
      maybe fallback h (ControlException.fromException exc)
-- | Run an action. If it throws any exception, run the second action and
-- then rethrow the original exception.
onException :: MonadCatch m => m a -> m b -> m a
onException action handler = action `catchAll` cleanupThenRethrow
  where
    cleanupThenRethrow exc = handler >> throwM exc
-- | Acquire a resource with asynchronous exceptions masked, and use it
-- under the enclosing masking state. Release it whether or not the use
-- throws. On success the use's result is returned; the release result is
-- discarded.
bracket :: MonadCatch m => m a -> (a -> m b) -> (a -> m c) -> m c
bracket acquire release use = mask $ \restore -> do
  handle' <- acquire
  outcome <- onException (restore (use handle')) (release handle')
  _ <- release handle'
  return outcome
-- | 'bracket' for when neither the release nor the use needs the acquired
-- value.
bracket_ :: MonadCatch m => m a -> m b -> m c -> m c
bracket_ before after action = bracket before (\_ -> after) (\_ -> action)
-- | Run the action, then the finalizer, whether or not the action throws.
finally :: MonadCatch m => m a -> m b -> m a
finally action finalizer = bracket_ noAcquire finalizer action
  where
    noAcquire = return ()
-- | Like 'bracket', but the release action runs only if the use throws.
-- On success the resource is left for the caller to manage.
bracketOnError :: MonadCatch m => m a -> (a -> m b) -> (a -> m c) -> m c
bracketOnError acquire release use = mask $ \restore ->
  acquire >>= \held -> restore (use held) `onException` release held