{-# LANGUAGE CPP #-}
{-# LANGUAGE UnicodeSyntax #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ImpredicativeTypes #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Control.Monad.Exception
  ( exception
  , forceWHNF
  , throw
  , throwIO
  , catch
  , catchJust
  , handle
  , handleJust
  , Handler(..)
  , catches
  , try
  , tryJust
  , onException
  , onExceptions
  , MonadFinally(..)
  , onEscape
  , tryAll
  , MonadMask(..)
  , mask
  , mask_
  , uninterruptibleMask
  , uninterruptibleMask_
  , bracket
  , bracket_
  , bracketOnEscape
  , bracketOnError
  , module Control.Monad.Abort.Class
  , module Control.Exception
  ) where

#if !MIN_VERSION_base(4,6,0)
import Prelude hiding (catch)
#endif
import Data.Monoid (Monoid)
import Data.Proxy (Proxy(..))
import Data.Traversable
import Data.Functor.Identity
import Control.Applicative (Applicative)
import Control.Monad (join, liftM)
import Control.Monad.Base
import Control.Monad.Trans.Class
import Control.Monad.Trans.Control
import Control.Monad.Trans.Abort hiding (abort, recover)
import Control.Monad.Trans.Finish
import Control.Monad.Trans.Maybe
import Control.Monad.Trans.List
import Control.Monad.Trans.Error
#if MIN_VERSION_transformers(0,4,0)
import Control.Monad.Trans.Except
#endif
import Control.Monad.Trans.Reader
import qualified Control.Monad.Trans.State.Lazy as L
import qualified Control.Monad.Trans.State.Strict as S
import qualified Control.Monad.Trans.Writer.Lazy as L
import qualified Control.Monad.Trans.Writer.Strict as S
import qualified Control.Monad.Trans.RWS.Lazy as L
import qualified Control.Monad.Trans.RWS.Strict as S
import Control.Monad.Abort.Class
import Control.Exception hiding (
  evaluate, throw, throwIO, catch, catchJust, handle, handleJust,
  Handler(..), catches, try, tryJust, finally, onException,
#if !MIN_VERSION_base(4,7,0)
  block, unblock, blocked,
#endif
  getMaskingState, mask, mask_, uninterruptibleMask,
  uninterruptibleMask_, bracket, bracket_, bracketOnError)
import qualified Control.Exception as E
import GHC.Base (maskAsyncExceptions#, maskUninterruptible#,
                 unmaskAsyncExceptions#)
import GHC.IO (IO(..))

-- | Throw an exception from pure code. This is 'E.throw' under a name that
-- does not clash with this module's monadic 'throw'.
exception :: Exception e => e -> α
exception e = E.throw e

-- | Evaluate a value to weak head normal form as an action in any monad
-- that has 'IO' at its base (via 'E.evaluate').
forceWHNF :: MonadBase IO μ => α -> μ α
forceWHNF x = liftBase (E.evaluate x)

-- | Raise an exception through the 'MonadAbort' interface, after wrapping
-- it in 'SomeException'.
throw :: (MonadAbort SomeException μ, Exception e) => e -> μ α
throw e = abort (toException e)

-- | Raise an exception with 'E.throwIO' in the base 'IO' monad and lift the
-- resulting action into @μ@.
throwIO :: (MonadBase IO μ, Exception e) => e -> μ α
throwIO e = liftBase (E.throwIO e)

-- | Run an action and handle exceptions of type @e@ raised by it.
-- Exceptions of any other type are re-raised as the original
-- 'SomeException' value.
catch :: (MonadRecover SomeException μ, Exception e)
      => μ α -> (e -> μ α) -> μ α
catch m h = recover m handler
  where
    handler e = case fromException e of
      Just e' -> h e'
      Nothing -> throw e

-- | Like 'catch', but a selector decides which exceptions of type @e@ are
-- handled. When it returns 'Nothing' the exception is re-raised; when it
-- returns @'Just' b@, the handler receives @b@.
catchJust :: (MonadRecover SomeException μ, Exception e)
          => (e -> Maybe β) -> μ α -> (β -> μ α) -> μ α
catchJust select m h = catch m $ \e -> case select e of
  Just b  -> h b
  Nothing -> throw e

-- | 'catch' with the handler given first.
handle :: (MonadRecover SomeException μ, Exception e)
       => (e -> μ α) -> μ α -> μ α
handle h m = catch m h

-- | 'catchJust' with the handler given before the protected action.
handleJust :: (MonadRecover SomeException μ, Exception e)
           => (e -> Maybe β) -> (β -> μ α) -> μ α -> μ α
handleJust select h m = catchJust select m h

-- | A handler for one particular exception type. The exception type is
-- existentially hidden, so handlers for different exception types can be
-- collected in one list, as consumed by 'catches' and 'onExceptions'.
data Handler μ α = forall e . Exception e => Handler (e -> μ α)

-- | Run an action with a list of handlers. The first handler (in list
-- order) whose exception type matches handles the exception. If none
-- matches, the original exception is re-raised with 'abort'.
catches :: MonadRecover SomeException μ => μ α -> [Handler μ α] -> μ α
catches m handlers = recover m dispatch
  where
    dispatch e = foldr (attempt e) (abort e) handlers
    -- Use this handler if the exception has its type, else fall through.
    attempt e (Handler h) next = maybe next h (fromException e)

-- | Run an action, returning an exception of type @e@ as 'Left' and a
-- normal result as 'Right'. Exceptions of other types are re-raised.
try :: (MonadRecover SomeException μ, Exception e) => μ α -> μ (Either e α)
try m = evaluate m >>= \r -> case r of
  Left e
    | Just e' <- fromException e -> return (Left e')
    | otherwise                  -> throw e
  Right a -> return (Right a)

-- | Like 'try', but only exceptions of type @e@ that the selector maps to
-- @'Just' b@ are returned (as @'Left' b@). All other exceptions are
-- re-raised.
tryJust :: (MonadRecover SomeException μ, Exception e)
        => (e -> Maybe β) -> μ α -> μ (Either β α)
tryJust select m = evaluate m >>= \r -> case r of
  Left e
    | Just b <- fromException e >>= select -> return (Left b)
    | otherwise                            -> throw e
  Right a -> return (Right a)

-- | Run an action. If it raises an exception of type @e@, run the handler on
-- that exception and then re-raise it. Exceptions of other types propagate
-- without running the handler.
--
-- The re-raised value is the original 'SomeException', as in
-- 'onExceptions'. It is not a copy rebuilt from the narrowed @e@ via
-- 'toException', which could differ from what was originally raised.
onException :: (MonadRecover SomeException μ, Exception e)
            => μ α -> (e -> μ β) -> μ α
onException m h = recover m $ \se -> do
  case fromException se of
    Just e  -> h e >> return ()
    Nothing -> return ()
  abort se

-- | Run an action. If it raises an exception, run the first handler (in list
-- order) whose exception type matches, then re-raise the original
-- exception. If no handler matches, the exception is re-raised directly.
onExceptions :: MonadRecover SomeException μ
             => μ α -> [Handler μ β] -> μ α
onExceptions m handlers = recover m dispatch
  where
    dispatch e = foldr (attempt e) (abort e) handlers
    -- On a type match, run the handler for its effects and re-raise.
    attempt e (Handler h) next =
      maybe next (\e' -> h e' >> abort e) (fromException e)

-- | Monads that can run a finalizer after an action, whether the action
-- completes normally or escapes early.
class (Applicative μ, Monad μ) => MonadFinally μ where
  -- | @finally' m f@ runs @m@ and then @f@. The finalizer receives @'Just' a@
  -- when @m@ completed with @a@, or 'Nothing' when it escaped. Both results
  -- are returned.
  finally' :: μ α -> (Maybe α -> μ β) -> μ (α, β)
  -- | Run the finalizer after the action and keep only the action's result.
  finally :: μ α -> μ β -> μ α
  finally m f = fst <$> finally' m (const f)

-- | 'Identity' cannot escape, so the finalizer always receives the result.
instance MonadFinally Identity where
  finally' m f = m >>= \a -> return (a, runIdentity (f (Just a)))

-- | The action runs with the caller's masking state restored, and the
-- finalizer runs masked. If the action throws, the finalizer receives
-- 'Nothing' and the exception propagates ('E.onException').
instance MonadFinally IO where
  finally' act cleanup = E.mask $ \unmask -> do
    result <- E.onException (unmask act) (cleanup Nothing)
    final  <- cleanup (Just result)
    return (result, final)

-- | In 'MaybeT' a 'Nothing' result counts as an escape. The finalizer sees
-- 'Nothing' both for that and for an escape of the underlying monad. The
-- combined result is 'Just' only if both the action and the finalizer
-- produced 'Just'.
instance MonadFinally μ => MonadFinally (MaybeT μ) where
  finally' m f = MaybeT $ do
    -- Irrefutable match: the pair is not forced here.
    ~(result, final) <- finally' (runMaybeT m) (runMaybeT . f . join)
    return ((,) <$> result <*> final)

-- | The finalizer runs once per result the list action produced, and each
-- result is paired with every result of its own finalizer run. If the
-- action produced no results or escaped, the finalizer runs once with
-- 'Nothing'.
instance MonadFinally μ => MonadFinally (ListT μ) where
  finally' m f = ListT $ do
      ~(results, finals) <- finally' (runListT m) cleanup
      return (concat (zipWith (\r fs -> map ((,) r) fs) results finals))
    where
      cleanup (Just rs@(_ : _)) = forM rs (runListT . f . Just)
      cleanup _                 = fmap pure (runListT (f Nothing))

-- | An abort counts as an escape. The finalizer receives a value only when
-- the action finished with 'Right'.
instance MonadFinally μ => MonadFinally (AbortT e μ) where
  finally' m f = AbortT $ do
      ~(outcome, final) <- finally' (runAbortT m) (runAbortT . f . completed)
      return ((,) <$> outcome <*> final)
    where
      completed (Just (Right a)) = Just a
      completed _                = Nothing

-- | Finishing early counts as an escape. The finalizer receives a value only
-- when the action finished with 'Right'.
instance MonadFinally μ => MonadFinally (FinishT β μ) where
  finally' m f = FinishT $ do
      ~(outcome, final) <- finally' (runFinishT m) (runFinishT . f . completed)
      return ((,) <$> outcome <*> final)
    where
      completed (Just (Right a)) = Just a
      completed _                = Nothing

-- | An error counts as an escape. The finalizer receives a value only when
-- the action finished with 'Right'.
instance (MonadFinally μ, Error e) => MonadFinally (ErrorT e μ) where
  finally' m f = ErrorT $ do
      ~(outcome, final) <- finally' (runErrorT m) (runErrorT . f . completed)
      return ((,) <$> outcome <*> final)
    where
      completed (Just (Right a)) = Just a
      completed _                = Nothing

#if MIN_VERSION_transformers(0,4,0)
-- | An exception value counts as an escape. The finalizer receives a value
-- only when the action finished with 'Right'.
instance MonadFinally μ => MonadFinally (ExceptT e μ) where
  finally' m f = ExceptT $ do
      ~(outcome, final) <- finally' (runExceptT m) (runExceptT . f . completed)
      return ((,) <$> outcome <*> final)
    where
      completed (Just (Right a)) = Just a
      completed _                = Nothing
#endif

-- | The environment is passed unchanged to both the action and the finalizer.
instance MonadFinally μ => MonadFinally (ReaderT r μ) where
  finally' m f = ReaderT $ \env ->
    finally' (runReaderT m env) (\mbr -> runReaderT (f mbr) env)

-- | The finalizer starts from the action's final state if the action
-- completed, otherwise from the initial state. The resulting state is the
-- finalizer's. Irrefutable patterns keep the pairs unforced, matching the
-- lazy state transformer.
instance MonadFinally μ => MonadFinally (L.StateT s μ) where
  finally' m f = L.StateT $ \s0 -> do
    ~(~(mr, _), ~(fr, sFinal)) <- finally' (L.runStateT m s0) $ \mbr ->
      let ~(arg, s1) = case mbr of
            Just ~(x, t) -> (Just x, t)
            Nothing      -> (Nothing, s0)
      in L.runStateT (f arg) s1
    return ((mr, fr), sFinal)

-- | The finalizer starts from the action's final state if the action
-- completed, otherwise from the initial state. The resulting state is the
-- finalizer's. The result pairs are matched strictly.
instance MonadFinally μ => MonadFinally (S.StateT s μ) where
  finally' m f = S.StateT $ \s0 -> do
    let cleanup (Just (a, s1)) = S.runStateT (f (Just a)) s1
        cleanup Nothing        = S.runStateT (f Nothing) s0
    ((mr, _), (fr, sFinal)) <- finally' (S.runStateT m s0) cleanup
    return ((mr, fr), sFinal)

-- | The output of the action is followed by the output of the finalizer.
-- Irrefutable patterns keep the pairs unforced, matching the lazy writer.
instance (MonadFinally μ, Monoid w) => MonadFinally (L.WriterT w μ) where
  finally' m f = L.WriterT $ do
    ~(~(mr, w1), ~(fr, w2)) <- finally' (L.runWriterT m) $ \mbr ->
      L.runWriterT (f (fmap fst mbr))
    return ((mr, fr), mappend w1 w2)

-- | The output of the action is followed by the output of the finalizer.
-- The result pairs are matched strictly.
instance (MonadFinally μ, Monoid w) => MonadFinally (S.WriterT w μ) where
  finally' m f = S.WriterT $ do
    let cleanup (Just (a, _)) = S.runWriterT (f (Just a))
        cleanup Nothing       = S.runWriterT (f Nothing)
    ((mr, w1), (fr, w2)) <- finally' (S.runWriterT m) cleanup
    return ((mr, fr), w1 `mappend` w2)

-- | Combines the reader, lazy-state and lazy-writer behaviour. The
-- environment is shared. The finalizer starts from the action's final state
-- (or the initial state if the action escaped). Outputs are concatenated
-- action-first.
instance (MonadFinally μ, Monoid w) => MonadFinally (L.RWST r w s μ) where
  finally' m f = L.RWST $ \r s0 -> do
    ~(~(mr, _, w1), ~(fr, sFinal, w2)) <-
      finally' (L.runRWST m r s0) $ \mbr ->
        let ~(arg, s1) = case mbr of
              Just ~(x, t, _) -> (Just x, t)
              Nothing         -> (Nothing, s0)
        in L.runRWST (f arg) r s1
    return ((mr, fr), sFinal, mappend w1 w2)

-- | Combines the reader, strict-state and strict-writer behaviour. The
-- environment is shared. The finalizer starts from the action's final state
-- (or the initial state if the action escaped). Outputs are concatenated
-- action-first.
instance (MonadFinally μ, Monoid w) => MonadFinally (S.RWST r w s μ) where
  finally' m f = S.RWST $ \r s0 -> do
    let cleanup (Just (a, s1, _)) = S.runRWST (f (Just a)) r s1
        cleanup Nothing           = S.runRWST (f Nothing) r s0
    ((mr, _, w1), (fr, sFinal, w2)) <- finally' (S.runRWST m r s0) cleanup
    return ((mr, fr), sFinal, w1 `mappend` w2)

-- | Run the second action only if the first one escapes, i.e. when
-- 'finally'' hands the finalizer 'Nothing'.
onEscape :: MonadFinally μ => μ α -> μ β -> μ α
onEscape m f = fst <$> finally' m cleanup
  where
    cleanup Nothing  = () <$ f
    cleanup (Just _) = return ()

-- | Run every action in order. Each later action is attached as a 'finally'
-- finalizer of the earlier ones, so it still runs if an earlier one escapes.
tryAll :: MonadFinally μ => [μ α] -> μ ()
tryAll = foldr (\m rest -> finally (() <$ m) rest) (return ())

-- Orphan instances for base's 'MaskingState'. 'Ord' and 'Bounded' are
-- required by the 'MonadMask' superclass context, and 'withMaskingState'
-- compares levels with '>'.
-- NOTE(review): the derived order follows base's constructor order
-- (Unmasked, MaskedInterruptible, MaskedUninterruptible), which is what
-- makes "stronger masking" compare greater. Confirm if base changes.
deriving instance Bounded MaskingState
deriving instance Enum MaskingState
deriving instance Ord MaskingState

-- | Monads with a notion of asynchronous-exception masking. @m@ is the type
-- of masking levels and is determined by @μ@. Levels must be ordered because
-- 'withMaskingState' compares them to decide whether a request raises the
-- current level.
class (Applicative μ, Monad μ, Ord m, Bounded m)
      => MonadMask m μ | μ -> m where
  -- | The level that 'mask' and 'mask_' switch to.
  defMaskingState :: Proxy μ -> m
  -- | The current masking level.
  getMaskingState :: μ m
  -- | Run an action at the given masking level.
  setMaskingState :: m -> μ α -> μ α

-- | 'Identity' has a single, trivial masking level. Setting it is a no-op.
instance MonadMask () Identity where
  defMaskingState _ = ()
  getMaskingState = pure ()
  setMaskingState _ action = action

-- | 'IO' uses base's three-level 'MaskingState'. Each level is applied by
-- running the raw 'IO' action under the matching GHC primop imported from
-- "GHC.Base", so the requested level is set directly rather than relative to
-- the current one.
instance MonadMask MaskingState IO where
  -- 'mask' and 'mask_' in this module switch to this level.
  defMaskingState = const MaskedInterruptible
  getMaskingState = E.getMaskingState
  -- Unwrap the IO newtype, apply the primop for the requested level, rewrap.
  setMaskingState Unmasked (IO io) = IO $ unmaskAsyncExceptions# io
  setMaskingState MaskedInterruptible (IO io) = IO $ maskAsyncExceptions# io
  setMaskingState MaskedUninterruptible (IO io) = IO $ maskUninterruptible# io

-- | Default masking level of a transformed monad @t μ@: that of the
-- underlying monad @μ@.
proxyDefMaskingState :: forall m μ t . MonadMask m μ => Proxy (t μ) -> m
proxyDefMaskingState _ = defMaskingState (Proxy :: Proxy μ)

-- | Lift 'setMaskingState' through a 'MonadTransControl' transformer. The
-- transformer layer is run inside the underlying monad at the requested
-- level, and the captured layer state is restored afterwards.
liftSetMaskingState :: (MonadTransControl t, MonadMask m μ, Monad (t μ))
                    => m -> t μ α -> t μ α
liftSetMaskingState level action =
  liftWith (\run -> setMaskingState level (run action)) >>= restoreT . return
{-# INLINE liftSetMaskingState #-}

-- | Masking is delegated to the underlying monad.
instance MonadMask m μ => MonadMask m (MaybeT μ) where
  getMaskingState = lift getMaskingState
  setMaskingState level = liftSetMaskingState level
  defMaskingState = proxyDefMaskingState

-- | Masking is delegated to the underlying monad.
instance MonadMask m μ => MonadMask m (ListT μ) where
  getMaskingState = lift getMaskingState
  setMaskingState level = liftSetMaskingState level
  defMaskingState = proxyDefMaskingState

-- | Masking is delegated to the underlying monad.
instance MonadMask m μ => MonadMask m (AbortT e μ) where
  getMaskingState = lift getMaskingState
  setMaskingState level = liftSetMaskingState level
  defMaskingState = proxyDefMaskingState

-- | Masking is delegated to the underlying monad.
instance MonadMask m μ => MonadMask m (FinishT β μ) where
  getMaskingState = lift getMaskingState
  setMaskingState level = liftSetMaskingState level
  defMaskingState = proxyDefMaskingState

-- | Masking is delegated to the underlying monad.
instance (MonadMask m μ, Error e) => MonadMask m (ErrorT e μ) where
  getMaskingState = lift getMaskingState
  setMaskingState level = liftSetMaskingState level
  defMaskingState = proxyDefMaskingState

#if MIN_VERSION_transformers(0,4,0)
-- | Masking is delegated to the underlying monad.
instance MonadMask m μ => MonadMask m (ExceptT e μ) where
  getMaskingState = lift getMaskingState
  setMaskingState level = liftSetMaskingState level
  defMaskingState = proxyDefMaskingState
#endif

-- | Masking is delegated to the underlying monad.
instance MonadMask m μ => MonadMask m (ReaderT r μ) where
  getMaskingState = lift getMaskingState
  setMaskingState level = liftSetMaskingState level
  defMaskingState = proxyDefMaskingState

-- | Masking is delegated to the underlying monad.
instance MonadMask m μ => MonadMask m (L.StateT s μ) where
  getMaskingState = lift getMaskingState
  setMaskingState level = liftSetMaskingState level
  defMaskingState = proxyDefMaskingState

-- | Masking is delegated to the underlying monad.
instance MonadMask m μ => MonadMask m (S.StateT s μ) where
  getMaskingState = lift getMaskingState
  setMaskingState level = liftSetMaskingState level
  defMaskingState = proxyDefMaskingState

-- | Masking is delegated to the underlying monad.
instance (MonadMask m μ, Monoid w) => MonadMask m (L.WriterT w μ) where
  getMaskingState = lift getMaskingState
  setMaskingState level = liftSetMaskingState level
  defMaskingState = proxyDefMaskingState

-- | Masking is delegated to the underlying monad.
instance (MonadMask m μ, Monoid w) => MonadMask m (S.WriterT w μ) where
  getMaskingState = lift getMaskingState
  setMaskingState level = liftSetMaskingState level
  defMaskingState = proxyDefMaskingState

-- | Masking is delegated to the underlying monad.
instance (MonadMask m μ, Monoid w) => MonadMask m (L.RWST r w s μ) where
  getMaskingState = lift getMaskingState
  setMaskingState level = liftSetMaskingState level
  defMaskingState = proxyDefMaskingState

-- | Masking is delegated to the underlying monad.
instance (MonadMask m μ, Monoid w) => MonadMask m (S.RWST r w s μ) where
  getMaskingState = lift getMaskingState
  setMaskingState level = liftSetMaskingState level
  defMaskingState = proxyDefMaskingState

-- | Run @body@ at masking level @target@ when that is strictly higher than
-- the current level. In that case @body@ receives a function that runs an
-- action back at the previous level. Otherwise the level is left unchanged
-- and @body@ receives 'id'. The level is never lowered by this function.
withMaskingState :: MonadMask m μ
                 => m
                 -> ((forall η β . MonadMask m η => η β -> η β) -> μ α)
                 -> μ α
withMaskingState target body = getMaskingState >>= \current ->
  if target > current
    then setMaskingState target (body (setMaskingState current))
    else body id

-- | 'withMaskingState' for a body that does not need the restore function.
withMaskingState_ :: MonadMask m μ => m -> μ α -> μ α
withMaskingState_ level action = withMaskingState level (\_ -> action)

-- | Run @body@ at the monad's default masking level ('defMaskingState'),
-- passing it a function that restores the caller's level.
mask :: forall m μ α . MonadMask m μ
     => ((forall η β . MonadMask m η => η β -> η β) -> μ α) -> μ α
mask body = withMaskingState (defMaskingState (Proxy :: Proxy μ)) body
 
-- | Run an action at the monad's default masking level ('defMaskingState').
mask_ :: forall m μ α . MonadMask m μ => μ α -> μ α
mask_ action = withMaskingState_ (defMaskingState (Proxy :: Proxy μ)) action

-- | Like 'mask', but switches to 'MaskedUninterruptible'.
uninterruptibleMask
  :: MonadMask MaskingState μ
  => ((forall η β . MonadMask MaskingState η => η β -> η β) -> μ α)
  -> μ α
uninterruptibleMask body = withMaskingState MaskedUninterruptible body

-- | Like 'mask_', but switches to 'MaskedUninterruptible'.
uninterruptibleMask_ :: MonadMask MaskingState μ => μ α -> μ α
uninterruptibleMask_ action = withMaskingState_ MaskedUninterruptible action

-- | Acquire a resource at the default masking level, run the body with the
-- caller's masking level restored, and always run @release@ afterwards
-- (via 'finally'), whether the body completes or escapes.
bracket :: (MonadFinally μ, MonadMask m μ)
        => μ α -> (α -> μ β) -> (α -> μ γ) -> μ γ
bracket acquire release use = mask $ \unmask ->
  acquire >>= \resource -> unmask (use resource) `finally` release resource

-- | 'bracket' for release and body actions that do not use the resource.
bracket_ :: (MonadFinally μ, MonadMask m μ) => μ α -> μ β -> μ γ -> μ γ
bracket_ acquire release use = bracket acquire (\_ -> release) (\_ -> use)

-- | Like 'bracket', but @release@ runs only if the body escapes
-- (see 'onEscape').
bracketOnEscape :: (MonadFinally μ, MonadMask m μ)
                => μ α -> (α -> μ β) -> (α -> μ γ) -> μ γ
bracketOnEscape acquire release use = mask $ \unmask -> do
  resource <- acquire
  onEscape (unmask (use resource)) (release resource)

-- | Like 'bracket', but @release@ runs only when the body fails with an
-- error (attached via 'onError_'). After a successful body the acquired
-- resource is left for the caller to manage.
bracketOnError :: (MonadRecover e μ, MonadMask m μ)
               => μ α -> (α -> μ β) -> (α -> μ γ) -> μ γ
bracketOnError acq release m = mask $ \restore -> do
  a <- acq
  -- Release only on the error path. Releasing after a successful body as
  -- well would make this identical to 'bracket'.
  restore (m a) `onError_` release a