{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE TypeFamilies #-}
-- |
-- Module      : Control.Prim.Monad.Throw
-- Copyright   : (c) Alexey Kuleshevich 2020
-- License     : BSD3
-- Maintainer  : Alexey Kuleshevich <alexey@kuleshevi.ch>
-- Stability   : experimental
-- Portability : non-portable
--
module Control.Prim.Monad.Throw
  ( MonadThrow(..)
  ) where

import Control.Exception
import Control.Monad.ST
import Control.Monad.ST.Unsafe
import GHC.Conc.Sync (STM(..))
import GHC.Exts
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Cont (ContT)
import Control.Monad.Trans.Except (ExceptT(..))
import Control.Monad.Trans.Identity (IdentityT)
import Control.Monad.Trans.Maybe (MaybeT(..))
import Control.Monad.Trans.Reader (ReaderT(..))
import Control.Monad.Trans.RWS.Lazy as Lazy (RWST)
import Control.Monad.Trans.RWS.Strict as Strict (RWST)
import Control.Monad.Trans.State.Lazy as Lazy (StateT)
import Control.Monad.Trans.State.Strict as Strict (StateT)
import Control.Monad.Trans.Writer.Lazy as Lazy (WriterT)
import Control.Monad.Trans.Writer.Strict as Strict (WriterT)
#if MIN_VERSION_transformers(0, 5, 3)
import Control.Monad.Trans.Accum (AccumT)
import Control.Monad.Trans.Select (SelectT)
#if MIN_VERSION_transformers(0, 5, 6)
import Control.Monad.Trans.RWS.CPS as CPS (RWST)
import Control.Monad.Trans.Writer.CPS as CPS (WriterT)
#endif
#endif


-- | A class for monads in which exceptions may be thrown.
--
-- Instances should obey the following law:
--
-- > throwM e >> x = throwM e
--
-- In other words, throwing an exception short-circuits the rest of the monadic
-- computation.
--
-- === Note
--
-- This class is identical to
-- [MonadThrow](https://hackage.haskell.org/package/exceptions/docs/Control-Monad-Catch.html#t:MonadThrow)
-- from the @exceptions@ package. It was copied here, rather than taken as a direct
-- dependency on that package, because @MonadCatch@ and @MonadMask@ are not the right
-- abstractions for exception handling in the presence of concurrency, and also because
-- that package's instances for transformers such as `MaybeT` and `ExceptT` are flawed.
class Monad m => MonadThrow m where
  -- | Throw an exception. Note that this throws when this action is run in
  -- the monad @m@, not when it is applied. It is a generalization of
  -- "Control.Prim.Exception"'s 'Control.Prim.Exception.throw'.
  throwM :: Exception e => e -> m a

-- | Throwing in 'Maybe' discards the exception value and yields 'Nothing'.
instance MonadThrow Maybe where
  throwM _ = Nothing

-- | Throwing in 'Either' converts the exception with 'toException' and returns
-- it in 'Left'. The equality constraint fixes the error type to 'SomeException'.
instance e ~ SomeException => MonadThrow (Either e) where
  throwM = Left . toException

-- | Throwing in 'IO' delegates to 'throwIO'.
instance MonadThrow IO where
  throwM = throwIO

-- | Throwing in 'ST' runs 'throwIO' and converts the resulting 'IO' action
-- into 'ST' with 'unsafeIOToST'.
instance MonadThrow (ST s) where
  throwM e = unsafeIOToST $ throwIO e

-- | Throwing in 'STM' converts the exception with 'toException' and raises it
-- with the 'raiseIO#' primop, wrapped directly in the 'STM' constructor.
instance MonadThrow STM where
  throwM e = STM $ raiseIO# (toException e)


-- | Throwing in 'ContT' lifts 'throwM' from the underlying monad.
instance MonadThrow m => MonadThrow (ContT r m) where
  throwM = lift . throwM

-- | Throwing in 'ExceptT' does not throw in the underlying monad: the exception
-- is converted with 'toException' and returned as a 'Left' value inside @m@.
-- The equality constraint fixes the error type to 'SomeException'.
instance (e ~ SomeException, Monad m) => MonadThrow (ExceptT e m) where
  throwM e = ExceptT (pure (Left (toException e)))

-- | Throwing in 'IdentityT' lifts 'throwM' from the underlying monad.
instance MonadThrow m => MonadThrow (IdentityT m) where
  throwM = lift . throwM

-- | Throwing in 'MaybeT' does not throw in the underlying monad: the exception
-- value is discarded and the computation yields 'Nothing' inside @m@.
instance Monad m => MonadThrow (MaybeT m) where
  throwM _ = MaybeT (pure Nothing)

-- | Throwing in 'ReaderT' lifts 'throwM' from the underlying monad.
instance MonadThrow m => MonadThrow (ReaderT r m) where
  throwM = lift . throwM

-- | Throwing in the lazy 'Lazy.RWST' lifts 'throwM' from the underlying monad.
instance (Monoid w, MonadThrow m) => MonadThrow (Lazy.RWST r w s m) where
  throwM = lift . throwM

-- | Throwing in the strict 'Strict.RWST' lifts 'throwM' from the underlying monad.
instance (Monoid w, MonadThrow m) => MonadThrow (Strict.RWST r w s m) where
  throwM = lift . throwM

-- | Throwing in the lazy 'Lazy.StateT' lifts 'throwM' from the underlying monad.
instance MonadThrow m => MonadThrow (Lazy.StateT s m) where
  throwM = lift . throwM

-- | Throwing in the strict 'Strict.StateT' lifts 'throwM' from the underlying monad.
instance MonadThrow m => MonadThrow (Strict.StateT s m) where
  throwM = lift . throwM

-- | Throwing in the lazy 'Lazy.WriterT' lifts 'throwM' from the underlying monad.
instance (Monoid w, MonadThrow m) => MonadThrow (Lazy.WriterT w m) where
  throwM = lift . throwM

-- | Throwing in the strict 'Strict.WriterT' lifts 'throwM' from the underlying monad.
instance (Monoid w, MonadThrow m) => MonadThrow (Strict.WriterT w m) where
  throwM = lift . throwM

#if MIN_VERSION_transformers(0, 5, 3)

-- | Throwing in 'AccumT' lifts 'throwM' from the underlying monad.
instance (Monoid w, MonadThrow m) => MonadThrow (AccumT w m) where
  throwM = lift . throwM
-- | Throwing in 'SelectT' lifts 'throwM' from the underlying monad.
instance MonadThrow m => MonadThrow (SelectT r m) where
  throwM = lift . throwM

#if MIN_VERSION_transformers(0, 5, 6)

-- | Throwing in the CPS 'CPS.RWST' lifts 'throwM' from the underlying monad.
instance MonadThrow m => MonadThrow (CPS.RWST r w st m) where
  throwM = lift . throwM
-- | Throwing in the CPS 'CPS.WriterT' lifts 'throwM' from the underlying monad.
instance MonadThrow m => MonadThrow (CPS.WriterT w m) where
  throwM = lift . throwM

#endif
#endif