{-# LANGUAGE DataKinds            #-}
{-# LANGUAGE FlexibleContexts     #-}
{-# LANGUAGE FlexibleInstances    #-}
{-# LANGUAGE MagicHash            #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

-- | The 'MonadRaise' class, which is an effect for
--   early escape / happy path programming with an exception side channel

module Control.Monad.Raise.Class
  ( MonadRaise (..)
  , ErrorCase
  ) where

import           Control.Exception

import           Control.Monad.Catch.Pure
import           Control.Monad.Cont
import           Control.Monad.ST

import           Control.Monad.Trans.Except
import           Control.Monad.Trans.Identity
import           Control.Monad.Trans.Maybe
import           Control.Monad.Trans.Reader

import qualified Control.Monad.RWS.Lazy       as Lazy
import qualified Control.Monad.RWS.Strict     as Strict

import qualified Control.Monad.State.Lazy     as Lazy
import qualified Control.Monad.State.Strict   as Strict

import qualified Control.Monad.Writer.Lazy    as Lazy
import qualified Control.Monad.Writer.Strict  as Strict

import           Data.Kind
import           Data.WorldPeace
import           Data.WorldPeace.Subset.Class

import           GHC.Base
import           GHC.Conc
import           GHC.IO

-- $setup
--
-- >>> :set -XDataKinds
-- >>> :set -XFlexibleContexts
-- >>> :set -XTypeApplications
--
-- >>> import Data.WorldPeace

-- | Raise semantics, like a type-directed @MonadThrow@.
--   Not unlike @MonadError@ with an in-built open variant.
--
--   Each instance fixes, via 'Errors', the type-level list of errors that
--   may be raised in that monad.
class Monad m => MonadRaise m where
  -- | The type-level list of error types that can be raised in @m@.
  --   'ErrorCase' wraps this list in an 'OpenUnion'.
  type Errors m :: [Type]

  -- | Raise an error
  --
  -- The error may be of any type @err@ satisfying @'Subset' err ('ErrorCase' m)@.
  --
  -- NOTE(review): this comment previously said a @Proxy@ gives a type hint to
  -- the type checker, but the signature below takes no proxy argument.
  -- If you have a case where it can be inferred, see 'Control.Monad.Raise.raise''.
  --
  -- ==== __Examples__
  --
  -- >>> data FooErr  = FooErr  deriving Show
  -- >>> data BarErr  = BarErr  deriving Show
  -- >>> data QuuxErr = QuuxErr deriving Show
  -- >>>
  -- >>> type MyErrs  = '[FooErr, BarErr]
  -- >>>
  -- >>> :{
  --  goesBoom :: Int -> Either (OpenUnion MyErrs) Int
  --  goesBoom x =
  --    if x > 50
  --      then return x
  --      else raise FooErr
  -- :}
  --
  -- >>> goesBoom 42
  -- Left (Identity FooErr)
  --
  -- >>> :{
  --  maybeBoom :: Int -> Maybe Int
  --  maybeBoom x =
  --    if x > 50
  --      then return x
  --      else raise ()
  -- :}
  --
  -- >>> maybeBoom 42
  -- Nothing
  raise :: Subset err (ErrorCase m) => err -> m a

-- | Type alias representing the concrete union of the monad's errors:
--   an 'OpenUnion' over the type-level list @'Errors' m@.
type ErrorCase m = OpenUnion (Errors m)

-- | Raising in the list monad yields the empty list.
--
-- The only declared error is @()@.
instance MonadRaise [] where
  type Errors [] = '[()]

  -- The raised value is ignored: a list has nowhere to store it.
  raise _ = []

-- | Raising in 'Maybe' yields 'Nothing'.
--
-- The only declared error is @()@.
instance MonadRaise Maybe where
  type Errors Maybe = '[()]

  -- The raised value is ignored: 'Nothing' carries no payload.
  raise _ = Nothing

-- | Raising in @'Either' ('OpenUnion' errs)@ returns the error on the 'Left'.
--
-- The declared errors are exactly the members of the union, @errs@.
instance MonadRaise (Either (OpenUnion errs)) where
  type Errors (Either (OpenUnion errs)) = errs

  -- 'include' widens the raised value into the full @OpenUnion errs@
  -- before it is wrapped in 'Left'.
  raise = Left . include

-- | Raising in 'IO' hands the error to the runtime's exception mechanism.
--
-- The only declared error is 'IOException'.
instance MonadRaise IO where
  type Errors IO = '[IOException]

  -- The value is passed directly to the 'raiseIO#' primop, wrapped in the
  -- 'IO' constructor.
  -- NOTE(review): unlike 'throwIO', nothing here wraps the value with
  -- 'toException' -- confirm that downstream handlers can still catch it.
  raise = IO . raiseIO#

-- | Raising in @'ST' s@ delegates to the 'IO' instance.
--
-- The only declared error is 'IOException', matching 'IO'.
instance MonadRaise (ST s) where
  type Errors (ST s) = '[IOException]

  -- The inner 'raise' runs at type @IO a@; 'GHC.IO.unsafeIOToST' then
  -- converts that action into @ST s a@.
  raise = GHC.IO.unsafeIOToST . raise

-- | Raising in 'STM' hands the error to the runtime's exception mechanism.
--
-- The only declared error is 'IOException'.
instance MonadRaise STM where
  type Errors STM = '[IOException]

  -- Same shape as the 'IO' instance: the value is passed directly to the
  -- 'raiseIO#' primop, here wrapped in the 'STM' constructor.
  -- NOTE(review): the value is not wrapped with 'toException' -- confirm
  -- that downstream handlers can still catch it.
  raise = STM . raiseIO#

-- | Raising in @'ExceptT' ('OpenUnion' errs) m@ stores the error in the
--   'ExceptT' layer.
--
-- The declared errors are the members of the union, @errs@.
instance (MonadRaise m, Contains (Errors m) errs)
  => MonadRaise (ExceptT (OpenUnion errs) m) where
    type Errors (ExceptT (OpenUnion errs) m) = errs

    -- The inner 'raise' is the one for @Either (OpenUnion errs)@; its result
    -- is injected into @m@ with 'pure' and rewrapped as 'ExceptT'.
    -- The 'raise' of @m@ itself is not called here.
    raise = ExceptT . pure . raise

-- | 'IdentityT' adds no errors of its own: it reuses the error set of the
--   underlying monad and raises there.
instance MonadRaise m => MonadRaise (IdentityT m) where
  type Errors (IdentityT m) = Errors m

  -- Raise in @m@, then 'lift' the action through the transformer.
  raise = lift . raise

-- | 'MaybeT' reuses the error set of the underlying monad, which must
--   include @()@.
instance (MonadRaise m, () `IsMember` Errors m) => MonadRaise (MaybeT m) where
  type Errors (MaybeT m) = Errors m

  -- The inner 'raise' runs in @m@ at result type @Maybe a@, so the error
  -- goes to the underlying monad; this definition never builds a 'Nothing'.
  raise err = MaybeT $ raise err

-- | 'ReaderT' adds no errors of its own: it reuses the error set of the
--   underlying monad and raises there.
instance MonadRaise m => MonadRaise (ReaderT cfg m) where
  type Errors (ReaderT cfg m) = Errors m

  -- Raise in @m@, then 'lift' the action through the transformer.
  raise = lift . raise

-- | 'CatchT' reuses the error set of the underlying monad and raises there.
instance MonadRaise m => MonadRaise (CatchT m) where
  type Errors (CatchT m) = Errors m

  -- Raise in @m@, then 'lift' the action through the transformer.
  -- This definition does not use 'CatchT''s own exception channel.
  raise = lift . raise

-- | 'ContT' adds no errors of its own: it reuses the error set of the
--   underlying monad and raises there.
instance MonadRaise m => MonadRaise (ContT r m) where
  type Errors (ContT r m) = Errors m

  -- Raise in @m@, then 'lift' the action through the transformer.
  raise = lift . raise

-- | Lazy 'Lazy.StateT' adds no errors of its own: it reuses the error set of
--   the underlying monad and raises there.
instance MonadRaise m => MonadRaise (Lazy.StateT s m) where
  type Errors (Lazy.StateT s m) = Errors m

  -- Raise in @m@, then 'lift' the action through the transformer.
  raise = lift . raise

-- | Strict 'Strict.StateT' adds no errors of its own: it reuses the error set
--   of the underlying monad and raises there.
instance MonadRaise m => MonadRaise (Strict.StateT s m) where
  type Errors (Strict.StateT s m) = Errors m

  -- Raise in @m@, then 'lift' the action through the transformer.
  raise = lift . raise

-- | Lazy 'Lazy.WriterT' adds no errors of its own: it reuses the error set of
--   the underlying monad and raises there.
instance (MonadRaise m, Monoid w) => MonadRaise (Lazy.WriterT w m) where
  type Errors (Lazy.WriterT w m) = Errors m

  -- Raise in @m@, then 'lift' the action through the transformer.
  raise = lift . raise

-- | Strict 'Strict.WriterT' adds no errors of its own: it reuses the error
--   set of the underlying monad and raises there.
instance (MonadRaise m, Monoid w) => MonadRaise (Strict.WriterT w m) where
  type Errors (Strict.WriterT w m) = Errors m

  -- Raise in @m@, then 'lift' the action through the transformer.
  raise = lift . raise

-- | Lazy 'Lazy.RWST' adds no errors of its own: it reuses the error set of
--   the underlying monad and raises there.
instance (MonadRaise m, Monoid w) => MonadRaise (Lazy.RWST r w s m) where
  type Errors (Lazy.RWST r w s m) = Errors m

  -- Raise in @m@, then 'lift' the action through the transformer.
  raise = lift . raise

-- | Strict 'Strict.RWST' adds no errors of its own: it reuses the error set
--   of the underlying monad and raises there.
instance (MonadRaise m, Monoid w) => MonadRaise (Strict.RWST r w s m) where
  type Errors (Strict.RWST r w s m) = Errors m

  -- Raise in @m@, then 'lift' the action through the transformer.
  raise = lift . raise