{- Copyright 2012 Google Inc. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -} {-# LANGUAGE CPP #-} {-# LANGUAGE ExistentialQuantification #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE UndecidableInstances #-} #if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 702 {-# LANGUAGE Trustworthy #-} #endif #ifndef MIN_VERSION_transformers #define MIN_VERSION_transformers(x,y,z) 1 #endif #ifndef MIN_VERSION_mtl #define MIN_VERSION_mtl(x,y,z) 1 #endif -------------------------------------------------------------------- -- | -- Copyright : (c) Edward Kmett 2013, (c) Google Inc. 2012 -- Maintainer: Edward Kmett -- Stability : experimental -- Portability: non-portable -- -- This module supports monads that can throw extensible exceptions. The -- exceptions are the very same from "Control.Exception", and the operations -- offered very similar, but here they are not limited to 'IO'. -- -- This code is in the style of both transformers and mtl, and is compatible -- with them, though doesn't mimic the module structure or offer the complete -- range of features in those packages. -- -- This is very similar to 'ErrorT' and 'MonadError', but based on features of -- "Control.Exception". In particular, it handles the complex case of -- asynchronous exceptions by including 'mask' in the typeclass. Note that the -- extensible extensions feature relies the RankNTypes language extension. -------------------------------------------------------------------- module Control.Monad.Catch ( -- * Typeclass -- $mtl MonadCatch(..) -- * Transformer -- $transformer , CatchT(..), Catch , runCatch , mapCatchT -- * Utilities -- $utilities , catchAll , catchIOError , catchJust , catchIf , Handler(..), catches , handle , handleJust , try , tryJust , onException , bracket , bracket_ , finally , bracketOnError -- * Re-exports from Control.Exception , Exception(..) , SomeException(..) ) where #if defined(__GLASGOW_HASKELL__) && (__GLASGOW_HASKELL__ >= 706) import Prelude hiding (foldr) #else import Prelude hiding (catch, foldr) #endif import Control.Applicative import Control.Exception (Exception(..), SomeException(..)) import qualified Control.Exception as ControlException import qualified Control.Monad.Trans.RWS.Lazy as LazyRWS import qualified Control.Monad.Trans.RWS.Strict as StrictRWS import qualified Control.Monad.Trans.State.Lazy as LazyS import qualified Control.Monad.Trans.State.Strict as StrictS import qualified Control.Monad.Trans.Writer.Lazy as LazyW import qualified Control.Monad.Trans.Writer.Strict as StrictW import Control.Monad.Trans.Identity import Control.Monad.Reader as Reader import Control.Monad.RWS import Data.Foldable import Data.Functor.Identity import Data.Traversable as Traversable ------------------------------------------------------------------------------ -- $mtl -- The mtl style typeclass ------------------------------------------------------------------------------ class Monad m => MonadCatch m where -- | Throw an exception. Note that this throws when this action is run in -- the monad @m@, not when it is applied. It is a generalization of -- "Control.Exception"'s 'ControlException.throwIO'. throwM :: Exception e => e -> m a -- | Provide a handler for exceptions thrown during execution of the first -- action. Note that type of the type of the argument to the handler will -- constrain which exceptions are caught. See "Control.Exception"'s -- 'ControlException.catch'. catch :: Exception e => m a -> (e -> m a) -> m a -- | Runs an action with asynchronous exceptions disabled. The action is -- provided a method for restoring the async. environment to what it was -- at the 'mask' call. See "Control.Exception"'s 'ControlException.mask'. mask :: ((forall a. m a -> m a) -> m b) -> m b -- | Like 'mask', but the masked computation is not interruptible (see -- "Control.Exception"'s 'ControlException.uninterruptibleMask'. WARNING: -- Only use if you need to mask exceptions around an interruptible operation -- AND you can guarantee the interruptible operation will only block for a -- short period of time. Otherwise you render the program/thread unresponsive -- and/or unkillable. uninterruptibleMask :: ((forall a. m a -> m a) -> m b) -> m b instance MonadCatch IO where throwM = ControlException.throwIO catch = ControlException.catch mask = ControlException.mask uninterruptibleMask = ControlException.uninterruptibleMask instance MonadCatch m => MonadCatch (IdentityT m) where throwM e = lift $ throwM e catch (IdentityT m) f = IdentityT (catch m (runIdentityT . f)) mask a = IdentityT $ mask $ \u -> runIdentityT (a $ q u) where q u = IdentityT . u . runIdentityT uninterruptibleMask a = IdentityT $ uninterruptibleMask $ \u -> runIdentityT (a $ q u) where q u = IdentityT . u . runIdentityT instance MonadCatch m => MonadCatch (LazyS.StateT s m) where throwM e = lift $ throwM e catch = LazyS.liftCatch catch mask a = LazyS.StateT $ \s -> mask $ \u -> LazyS.runStateT (a $ q u) s where q u (LazyS.StateT b) = LazyS.StateT (u . b) uninterruptibleMask a = LazyS.StateT $ \s -> uninterruptibleMask $ \u -> LazyS.runStateT (a $ q u) s where q u (LazyS.StateT b) = LazyS.StateT (u . b) instance MonadCatch m => MonadCatch (StrictS.StateT s m) where throwM e = lift $ throwM e catch = StrictS.liftCatch catch mask a = StrictS.StateT $ \s -> mask $ \u -> StrictS.runStateT (a $ q u) s where q u (StrictS.StateT b) = StrictS.StateT (u . b) uninterruptibleMask a = StrictS.StateT $ \s -> uninterruptibleMask $ \u -> StrictS.runStateT (a $ q u) s where q u (StrictS.StateT b) = StrictS.StateT (u . b) instance MonadCatch m => MonadCatch (ReaderT r m) where throwM e = lift $ throwM e catch (ReaderT m) c = ReaderT $ \r -> m r `catch` \e -> runReaderT (c e) r mask a = ReaderT $ \e -> mask $ \u -> Reader.runReaderT (a $ q u) e where q u (ReaderT b) = ReaderT (u . b) uninterruptibleMask a = ReaderT $ \e -> uninterruptibleMask $ \u -> Reader.runReaderT (a $ q u) e where q u (ReaderT b) = ReaderT (u . b) instance (MonadCatch m, Monoid w) => MonadCatch (StrictW.WriterT w m) where throwM e = lift $ throwM e catch (StrictW.WriterT m) h = StrictW.WriterT $ m `catch ` \e -> StrictW.runWriterT (h e) mask a = StrictW.WriterT $ mask $ \u -> StrictW.runWriterT (a $ q u) where q u b = StrictW.WriterT $ u (StrictW.runWriterT b) uninterruptibleMask a = StrictW.WriterT $ uninterruptibleMask $ \u -> StrictW.runWriterT (a $ q u) where q u b = StrictW.WriterT $ u (StrictW.runWriterT b) instance (MonadCatch m, Monoid w) => MonadCatch (LazyW.WriterT w m) where throwM e = lift $ throwM e catch (LazyW.WriterT m) h = LazyW.WriterT $ m `catch ` \e -> LazyW.runWriterT (h e) mask a = LazyW.WriterT $ mask $ \u -> LazyW.runWriterT (a $ q u) where q u b = LazyW.WriterT $ u (LazyW.runWriterT b) uninterruptibleMask a = LazyW.WriterT $ uninterruptibleMask $ \u -> LazyW.runWriterT (a $ q u) where q u b = LazyW.WriterT $ u (LazyW.runWriterT b) instance (MonadCatch m, Monoid w) => MonadCatch (LazyRWS.RWST r w s m) where throwM e = lift $ throwM e catch (LazyRWS.RWST m) h = LazyRWS.RWST $ \r s -> m r s `catch` \e -> LazyRWS.runRWST (h e) r s mask a = LazyRWS.RWST $ \r s -> mask $ \u -> LazyRWS.runRWST (a $ q u) r s where q u (LazyRWS.RWST b) = LazyRWS.RWST $ \ r s -> u (b r s) uninterruptibleMask a = LazyRWS.RWST $ \r s -> uninterruptibleMask $ \u -> LazyRWS.runRWST (a $ q u) r s where q u (LazyRWS.RWST b) = LazyRWS.RWST $ \ r s -> u (b r s) instance (MonadCatch m, Monoid w) => MonadCatch (StrictRWS.RWST r w s m) where throwM e = lift $ throwM e catch (StrictRWS.RWST m) h = StrictRWS.RWST $ \r s -> m r s `catch` \e -> StrictRWS.runRWST (h e) r s mask a = StrictRWS.RWST $ \r s -> mask $ \u -> StrictRWS.runRWST (a $ q u) r s where q u (StrictRWS.RWST b) = StrictRWS.RWST $ \ r s -> u (b r s) uninterruptibleMask a = StrictRWS.RWST $ \r s -> uninterruptibleMask $ \u -> StrictRWS.runRWST (a $ q u) r s where q u (StrictRWS.RWST b) = StrictRWS.RWST $ \ r s -> u (b r s) ------------------------------------------------------------------------------ -- $transformer -- The @transformers@-style monad transfomer ------------------------------------------------------------------------------ -- | Add 'Exception' handling abilities to a 'Monad'. newtype CatchT m a = CatchT { runCatchT :: m (Either SomeException a) } type Catch = CatchT Identity runCatch :: Catch a -> Either SomeException a runCatch = runIdentity . runCatchT instance Monad m => Functor (CatchT m) where fmap f (CatchT m) = CatchT (liftM (fmap f) m) instance Monad m => Applicative (CatchT m) where pure a = CatchT (return (Right a)) (<*>) = ap instance Monad m => Monad (CatchT m) where return a = CatchT (return (Right a)) CatchT m >>= k = CatchT $ m >>= \ea -> case ea of Left e -> return (Left e) Right a -> runCatchT (k a) fail = CatchT . return . Left . toException . userError instance MonadFix m => MonadFix (CatchT m) where mfix f = CatchT $ mfix $ \a -> runCatchT $ f $ case a of Right r -> r _ -> error "empty mfix argument" instance Foldable m => Foldable (CatchT m) where foldMap f (CatchT m) = foldMap (foldMapEither f) m where foldMapEither g (Right a) = g a foldMapEither _ (Left _) = mempty instance (Monad m, Traversable m) => Traversable (CatchT m) where traverse f (CatchT m) = CatchT <$> Traversable.traverse (traverseEither f) m where traverseEither g (Right a) = Right <$> g a traverseEither _ (Left e) = pure (Left e) instance Monad m => Alternative (CatchT m) where empty = mzero (<|>) = mplus instance Monad m => MonadPlus (CatchT m) where mzero = CatchT $ return $ Left $ toException $ userError "" mplus (CatchT m) (CatchT n) = CatchT $ m >>= \ea -> case ea of Left _ -> n Right a -> return (Right a) instance MonadTrans CatchT where lift m = CatchT $ do a <- m return $ Right a instance MonadIO m => MonadIO (CatchT m) where liftIO m = CatchT $ do a <- liftIO m return $ Right a instance Monad m => MonadCatch (CatchT m) where throwM = CatchT . return . Left . toException catch (CatchT m) c = CatchT $ m >>= \ea -> case ea of Left e -> case fromException e of Just e' -> runCatchT (c e') Nothing -> return (Left e) Right a -> return (Right a) mask a = a id uninterruptibleMask a = a id instance MonadState s m => MonadState s (CatchT m) where get = lift get put = lift . put #if MIN_VERSION_mtl(2,1,0) state = lift . state #endif instance MonadReader e m => MonadReader e (CatchT m) where ask = lift ask local f (CatchT m) = CatchT (local f m) instance MonadWriter w m => MonadWriter w (CatchT m) where tell = lift . tell listen = mapCatchT $ \ m -> do (a, w) <- listen m return $! fmap (\ r -> (r, w)) a pass = mapCatchT $ \ m -> pass $ do a <- m return $! case a of Left l -> (Left l, id) Right (r, f) -> (Right r, f) #if MIN_VERSION_mtl(2,1,0) writer aw = CatchT (Right `liftM` writer aw) #endif instance MonadRWS r w s m => MonadRWS r w s (CatchT m) -- | Map the unwrapped computation using the given function. -- -- * @'runErrorT' ('mapErrorT' f m) = f ('runErrorT' m@) mapCatchT :: (m (Either SomeException a) -> n (Either SomeException b)) -> CatchT m a -> CatchT n b mapCatchT f m = CatchT $ f (runCatchT m) ------------------------------------------------------------------------------ -- $utilities -- These functions follow those from "Control.Exception", except that they are -- based on methods from the 'MonadCatch' typeclass. See -- "Control.Exception" for API usage. ------------------------------------------------------------------------------ -- | Catches all exceptions, and somewhat defeats the purpose of the extensible -- exception system. Use sparingly. catchAll :: MonadCatch m => m a -> (SomeException -> m a) -> m a catchAll = catch -- | Catch all 'IOError' (eqv. 'IOException') exceptions. Still somewhat too -- general, but better than using 'catchAll'. See 'catchIf' for an easy way -- of catching specific 'IOError's based on the predicates in "System.IO.Error". catchIOError :: MonadCatch m => m a -> (IOError -> m a) -> m a catchIOError = catch -- | Catch exceptions only if they pass some predicate. Often useful with the -- predicates for testing 'IOError' values in "System.IO.Error". catchIf :: (MonadCatch m, Exception e) => (e -> Bool) -> m a -> (e -> m a) -> m a catchIf f a b = a `catch` \e -> if f e then b e else throwM e -- | A more generalized way of determining which exceptions to catch at -- run time. catchJust :: (MonadCatch m, Exception e) => (e -> Maybe b) -> m a -> (b -> m a) -> m a catchJust f a b = a `catch` \e -> maybe (throwM e) b $ f e -- | Flipped 'catch'. See "Control.Exception"'s 'ControlException.handle'. handle :: (MonadCatch m, Exception e) => (e -> m a) -> m a -> m a handle = flip catch {-# INLINE handle #-} -- | Flipped 'catchJust'. See "Control.Exception"'s 'ControlException.handleJust'. handleJust :: (MonadCatch m, Exception e) => (e -> Maybe b) -> (b -> m a) -> m a -> m a handleJust f = flip (catchJust f) {-# INLINE handleJust #-} -- | Similar to 'catch', but returns an 'Either' result. See "Control.Exception"'s -- 'Control.Exception.try'. try :: (MonadCatch m, Exception e) => m a -> m (Either e a) try a = catch (Right `liftM` a) (return . Left) -- | A variant of 'try' that takes an exception predicate to select -- which exceptions are caught. See "Control.Exception"'s 'ControlException.tryJust' tryJust :: (MonadCatch m, Exception e) => (e -> Maybe b) -> m a -> m (Either b a) tryJust f a = catch (Right `liftM` a) (\e -> maybe (throwM e) (return . Left) (f e)) -- | Generalized version of 'ControlException.Handler' data Handler m a = forall e . ControlException.Exception e => Handler (e -> m a) instance Monad m => Functor (Handler m) where fmap f (Handler h) = Handler (liftM f . h) -- | Catches different sorts of exceptions. See "Control.Exception"'s 'ControlException.catches' catches :: (Foldable f, MonadCatch m) => m a -> f (Handler m a) -> m a catches a hs = a `catch` handler where handler e = foldr probe (throwM e) hs where probe (Handler h) xs = maybe xs h (ControlException.fromException e) -- | Run an action only if an exception is thrown in the main action. The -- exception is not caught, simply rethrown. onException :: MonadCatch m => m a -> m b -> m a onException action handler = action `catchAll` \e -> handler >> throwM e -- | Generalized abstracted pattern of safe resource acquisition and release -- in the face of exceptions. The first action \"acquires\" some value, which -- is \"released\" by the second action at the end. The third action \"uses\" -- the value and its result is the result of the 'bracket'. -- -- If an exception occurs during the use, the release still happens before the -- exception is rethrown. bracket :: MonadCatch m => m a -> (a -> m b) -> (a -> m c) -> m c bracket acquire release use = mask $ \unmasked -> do resource <- acquire result <- unmasked (use resource) `onException` release resource _ <- release resource return result -- | Version of 'bracket' without any value being passed to the second and -- third actions. bracket_ :: MonadCatch m => m a -> m b -> m c -> m c bracket_ before after action = bracket before (const after) (const action) -- | Perform an action with a finalizer action that is run, even if an -- exception occurs. finally :: MonadCatch m => m a -> m b -> m a finally action finalizer = bracket_ (return ()) finalizer action -- | Like 'bracket', but only performs the final action if there was an -- exception raised by the in-between computation. bracketOnError :: MonadCatch m => m a -> (a -> m b) -> (a -> m c) -> m c bracketOnError acquire release use = mask $ \unmasked -> do resource <- acquire unmasked (use resource) `onException` release resource