{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Type classes for throwing and catching typed errors, where the handler
-- passed to 'catch' is allowed to run in a different monad (@n@) than the
-- computation being guarded (@m@).  This makes it possible to change the
-- error type while handling an error; see 'mapE'.
module Control.Monad.Catch
  ( MonadThrow (..)
  , MonadCatch (..)
  , MonadError
  , mapE
  , WrappedMonadError (..)
  , WrappedMonadCatch (..)
  ) where

import Control.Applicative (Applicative (..))
import Control.Exception (IOException)
import Control.Monad.Error hiding (MonadError)
import qualified Control.Monad.Error.Class as Error
import Control.Monad.Trans.Identity
import Control.Monad.Trans.List
import Control.Monad.Trans.Maybe
import Control.Monad.Trans.Reader
import qualified Control.Monad.Trans.RWS.Lazy as Lazy
import qualified Control.Monad.Trans.RWS.Strict as Strict
import qualified Control.Monad.Trans.State.Lazy as Lazy
import qualified Control.Monad.Trans.State.Strict as Strict
import qualified Control.Monad.Trans.Writer.Lazy as Lazy
import qualified Control.Monad.Trans.Writer.Strict as Strict
import Data.Monoid

-- | Monads @m@ in which an error of type @e@ can be thrown.  The error type
-- is determined by the monad (@m -> e@).
class Monad m => MonadThrow e m | m -> e where
  -- | Abort the current computation with the given error.
  throw :: e -> m a

  -- The default applies when @m@ is a transformer @t@ over a monad @n@ that
  -- can itself throw @e@: the error is thrown in @n@ and lifted.  The
  -- equality constraint keeps the default signature's type identical to the
  -- method's type (@e -> m a@); only the context differs.
  default throw :: (MonadTrans t, MonadThrow e n, m ~ t n) => e -> m a
  throw = lift . throw

-- | Computations in @m@ whose errors of type @e@ can be handled by a handler
-- running in @n@.  The result of 'catch' lives in @n@.
class ( MonadThrow e m
      , Monad n
      ) => MonadCatch e m n | m -> e, n e -> m where
  -- | @m \`catch\` h@ runs @m@ and, if it throws an error @e@, continues
  -- with @h e@.
  catch :: m a -> (e -> n a) -> n a

-- | Constraint synonym for the common case where the guarded computation and
-- the handler share one monad.
type MonadError e m = (MonadThrow e m, MonadCatch e m m)

-- | Transform the error of a computation with the given function, rethrowing
-- the transformed error in the target monad @n@.
mapE :: (MonadThrow e' n, MonadCatch e m n) => (e -> e') -> m a -> n a
mapE f m = m `catch` (throw . f)

-- ---------------------------------------------------------------------------
-- ErrorT

-- | Throwing in 'ErrorT' is 'throwError'.
instance (Error e, Monad m) => MonadThrow e (ErrorT e m) where
  throw = throwError

-- | The handler may produce an 'ErrorT' with a different error type @e'@.
instance ( Error e
         , Error e'
         , Monad m
         ) => MonadCatch e (ErrorT e m) (ErrorT e' m) where
  -- A 'Left' from the guarded computation is passed to the handler; a
  -- 'Right' is re-wrapped unchanged at the new error type.
  m `catch` h = ErrorT $ runErrorT m >>= either (runErrorT . h) (return . Right)

-- ---------------------------------------------------------------------------
-- Transformers that delegate to the underlying monad.
--
-- Each 'MonadThrow' instance uses the class default (@lift . throw@).  Each
-- 'MonadCatch' instance unwraps the transformer, delegates to the underlying
-- 'catch', and re-wraps the result.

instance MonadThrow e m => MonadThrow e (IdentityT m)

instance MonadCatch e m n => MonadCatch e (IdentityT m) (IdentityT n) where
  m `catch` h = IdentityT $ runIdentityT m `catch` (runIdentityT . h)

instance MonadThrow e m => MonadThrow e (ListT m)

instance MonadCatch e m n => MonadCatch e (ListT m) (ListT n) where
  m `catch` h = ListT $ runListT m `catch` \ e -> runListT (h e)

instance MonadThrow e m => MonadThrow e (MaybeT m)

instance MonadCatch e m n => MonadCatch e (MaybeT m) (MaybeT n) where
  m `catch` h = MaybeT $ runMaybeT m `catch` (runMaybeT . h)

instance MonadThrow e m => MonadThrow e (ReaderT r m)

-- | The handler runs in the same environment as the guarded computation.
instance MonadCatch e m n => MonadCatch e (ReaderT r m) (ReaderT r n) where
  m `catch` h = ReaderT $ \ r ->
    runReaderT m r `catch` \ e -> runReaderT (h e) r

instance (Monoid w, MonadThrow e m) => MonadThrow e (Lazy.RWST r w s m)

-- | The handler is run with the environment and state that the guarded
-- computation started with.
instance (Monoid w, MonadCatch e m n) =>
         MonadCatch e (Lazy.RWST r w s m) (Lazy.RWST r w s n) where
  m `catch` h = Lazy.RWST $ \ r s ->
    Lazy.runRWST m r s `catch` \ e -> Lazy.runRWST (h e) r s

instance (Monoid w, MonadThrow e m) => MonadThrow e (Strict.RWST r w s m)

-- | The handler is run with the environment and state that the guarded
-- computation started with.
instance (Monoid w, MonadCatch e m n) =>
         MonadCatch e (Strict.RWST r w s m) (Strict.RWST r w s n) where
  m `catch` h = Strict.RWST $ \ r s ->
    Strict.runRWST m r s `catch` \ e -> Strict.runRWST (h e) r s

instance MonadThrow e m => MonadThrow e (Lazy.StateT s m)

-- | The handler is run with the state that the guarded computation started
-- with.
instance MonadCatch e m n =>
         MonadCatch e (Lazy.StateT s m) (Lazy.StateT s n) where
  m `catch` h = Lazy.StateT $ \ s ->
    Lazy.runStateT m s `catch` \ e -> Lazy.runStateT (h e) s

instance MonadThrow e m => MonadThrow e (Strict.StateT s m)

-- | The handler is run with the state that the guarded computation started
-- with.
instance MonadCatch e m n =>
         MonadCatch e (Strict.StateT s m) (Strict.StateT s n) where
  m `catch` h = Strict.StateT $ \ s ->
    Strict.runStateT m s `catch` \ e -> Strict.runStateT (h e) s

instance (Monoid w, MonadThrow e m) => MonadThrow e (Lazy.WriterT w m)

instance ( Monoid w
         , MonadCatch e m n
         ) => MonadCatch e (Lazy.WriterT w m) (Lazy.WriterT w n) where
  m `catch` h = Lazy.WriterT $
    Lazy.runWriterT m `catch` \ e -> Lazy.runWriterT (h e)

instance (Monoid w, MonadThrow e m) => MonadThrow e (Strict.WriterT w m)

instance ( Monoid w
         , MonadCatch e m n
         ) => MonadCatch e (Strict.WriterT w m) (Strict.WriterT w n) where
  m `catch` h = Strict.WriterT $
    Strict.runWriterT m `catch` \ e -> Strict.runWriterT (h e)

-- ---------------------------------------------------------------------------
-- Base monads

-- | Errors are represented by 'Left'.
instance MonadThrow e (Either e) where
  throw = Left

-- | The handler may return an 'Either' with a different error type @e'@.
instance MonadCatch e (Either e) (Either e') where
  Left e `catch` h = h e
  Right a `catch` _h = Right a

-- | Throws an 'IOException' via 'throwError'.
instance MonadThrow IOException IO where
  throw = throwError

-- | Handles an 'IOException' via 'catchError'.
instance MonadCatch IOException IO IO where
  catch = catchError

-- ---------------------------------------------------------------------------
-- Adapters between this module's classes and 'Error.MonadError'

-- | Wraps a monad that has an 'Error.MonadError' instance so that it can be
-- used through 'MonadThrow' and 'MonadCatch'.
newtype WrappedMonadError m a
  = WrapMonadError { unwrapMonadError :: m a }

-- Functor and Applicative are superclasses of Monad on GHC >= 7.10, so they
-- must be provided for the Monad instance to be accepted there.  They are
-- written in terms of the underlying monad's operations so that only
-- @Monad m@ is required.
instance Monad m => Functor (WrappedMonadError m) where
  fmap f = WrapMonadError . liftM f . unwrapMonadError

instance Monad m => Applicative (WrappedMonadError m) where
  pure = WrapMonadError . return
  f <*> x = WrapMonadError $ unwrapMonadError f `ap` unwrapMonadError x

instance Monad m => Monad (WrappedMonadError m) where
  return = pure
  m >>= f = WrapMonadError $ unwrapMonadError m >>= unwrapMonadError . f

instance Error.MonadError e m => MonadThrow e (WrappedMonadError m) where
  throw = WrapMonadError . throwError

instance Error.MonadError e m =>
         MonadCatch e (WrappedMonadError m) (WrappedMonadError m) where
  m `catch` h = WrapMonadError $
    unwrapMonadError m `catchError` (unwrapMonadError . h)

-- | Wraps a monad that can catch its own errors (@MonadCatch e m m@) so that
-- it can be used through 'Error.MonadError'.
newtype WrappedMonadCatch m a
  = WrapMonadCatch { unwrapMonadCatch :: m a }

-- See the note on the 'WrappedMonadError' instances: Functor and Applicative
-- are required as superclasses of Monad on GHC >= 7.10.
instance Monad m => Functor (WrappedMonadCatch m) where
  fmap f = WrapMonadCatch . liftM f . unwrapMonadCatch

instance Monad m => Applicative (WrappedMonadCatch m) where
  pure = WrapMonadCatch . return
  f <*> x = WrapMonadCatch $ unwrapMonadCatch f `ap` unwrapMonadCatch x

instance Monad m => Monad (WrappedMonadCatch m) where
  return = pure
  m >>= f = WrapMonadCatch $ unwrapMonadCatch m >>= unwrapMonadCatch . f

instance MonadCatch e m m => Error.MonadError e (WrappedMonadCatch m) where
  throwError = WrapMonadCatch . throw
  m `catchError` h = WrapMonadCatch $
    unwrapMonadCatch m `catch` (unwrapMonadCatch . h)