{-# language BlockArguments #-}
{-# language DataKinds #-}
{-# language FlexibleContexts #-}
{-# language GADTs #-}
{-# language LambdaCase #-}
{-# language RankNTypes #-}
{-# language TypeOperators #-}

module IO.Effects.Exception
  ( -- * Running @Program@s with exceptions
    runExceptions

    -- * Throwing
  , throwIO

    -- * Catching (with recovery)
  , catch
  , catchIO
  , catchAny
  , catchJust
  , handle
  , handleIO
  , handleAny
  , handleJust
  , try
  , tryIO
  , tryAny
  , tryJust
  , Handler(..)
  , catches

    -- * Cleanup (no recovery)
  , onException
  , bracket
  , bracket_
  , finally
  , withException
  , bracketOnError
  , bracketOnError_

    -- * Distinguishing exception types
  , isSyncException

    -- * Masking
  , mask
  , mask_
  , uninterruptibleMask
  , uninterruptibleMask_

    -- * Evaluation
  , evaluate

    -- * @Exceptions@ syntax
  , Exceptions(..)
  ) where

import Control.Exception ( Exception, IOException, SomeException( SomeException ), SomeAsyncException( SomeAsyncException ), fromException, toException )
import qualified Control.Exception as EUnsafe
import IO.Effects.Internal


-- | The primitive operations of the exception effect. 'runExceptions' gives
-- each constructor its meaning by delegating to the function of the same name
-- in "Control.Exception".
data Exceptions m a where
  -- | Throw an exception.
  ThrowIO :: Exception exc => exc -> Exceptions m r

  -- | Run an action, passing any exception of type @exc@ it throws to the
  -- given handler.
  Catch :: Exception exc => m r -> ( exc -> m r ) -> Exceptions m r

  -- | Evaluate a value as an effect (see 'EUnsafe.evaluate').
  Evaluate :: r -> Exceptions m r

  -- | Run a computation under 'EUnsafe.mask'. The computation is handed a
  -- function for restoring the previous masking state.
  Mask :: ( ( forall x. m x -> m x ) -> m r ) -> Exceptions m r

  -- | Run a computation under 'EUnsafe.uninterruptibleMask'. The computation
  -- is handed a function for restoring the previous masking state.
  UninterruptibleMask :: ( ( forall x. m x -> m x ) -> m r ) -> Exceptions m r


{-| Interpret the 'Exceptions' effect of a program.

The program is first wrapped in 'tryAny', so a synchronous exception that
would otherwise escape is returned as a 'Left' rather than thrown.

Asynchronous exceptions are /not/ captured: 'tryAny' is built on 'catch', which
rethrows anything that fails 'isSyncException', so they continue to propagate
out of @runExceptions@.

Each 'Exceptions' operation is carried out by the function of the same name
from "Control.Exception".

-}
runExceptions
  :: ProgramWithHandler Exceptions es a
  -> Program es ( Either SomeException a )
runExceptions program =
  interpretExceptions ( tryAny program )

  where

    interpretExceptions =
      interpret \case
        ThrowIO e ->
          Program ( EUnsafe.throwIO e )

        Catch action handler ->
          Program ( programToIO action `EUnsafe.catch` \e -> programToIO ( handler e ) )

        Evaluate value ->
          Program ( EUnsafe.evaluate value )

        -- The @restore@ function supplied by "Control.Exception" works on
        -- 'IO', so it is wrapped to work on 'Program' before being handed to
        -- the masked body.
        Mask body ->
          Program ( EUnsafe.mask \restore -> programToIO ( body \m -> Program ( restore ( programToIO m ) ) ) )

        UninterruptibleMask body ->
          Program ( EUnsafe.uninterruptibleMask \restore -> programToIO ( body \m -> Program ( restore ( programToIO m ) ) ) )


-- | Catch synchronous exceptions. Asynchronous exceptions (those for which
-- 'isSyncException' returns 'False') are not handed to the handler; they are
-- rethrown instead.
--
-- See also 'EUnsafe.catch' (from "Control.Exception").
catch
  :: ( Exception e, Member Exceptions es )
  => Program es a -> ( e -> Program es a ) -> Program es a
catch action recover =
  send ( Catch action guarded )

  where

    guarded e
      | isSyncException e =
          recover e

      -- An asynchronous exception is deliberately rethrown (synchronously)
      -- rather than recovered from, so that it keeps propagating.
      | otherwise =
          throwIO e


-- | 'catch' with the exception type fixed to 'IOException'.
catchIO
  :: Member Exceptions es
  => Program es a -> ( IOException -> Program es a ) -> Program es a
catchIO action handler =
  action `catch` handler


-- | 'catch' with the exception type fixed to 'SomeException', so every
-- synchronous exception reaches the handler.
catchAny
  :: Member Exceptions es
  => Program es a -> ( SomeException -> Program es a ) -> Program es a
catchAny action handler =
  action `catch` handler


-- catchDeep m =
--   catch ( m >>= evaluateDeep )

-- | A variant of 'catch' that takes a selector function. When the selector
-- returns 'Just' for a caught exception, the handler is run on the selected
-- value; when it returns 'Nothing', the exception is rethrown.
catchJust
  :: ( Member Exceptions es, Exception e )
  => ( e -> Maybe b )
  -> Program es a
  -> ( b -> Program es a )
  -> Program es a
catchJust select action recover =
  catch action \e ->
    case select e of
      Just selected ->
        recover selected

      Nothing ->
        throwIO e


-- | 'catch' with its arguments swapped: the handler comes first.
handle
  :: ( Member Exceptions es, Exception e )
  => ( e -> Program es a ) -> Program es a -> Program es a
handle handler action =
  catch action handler


-- | 'catchIO' with its arguments swapped: the handler comes first.
handleIO
  :: Member Exceptions es
  => ( IOException -> Program es a ) -> Program es a -> Program es a
handleIO handler action =
  catchIO action handler


-- | 'catchAny' with its arguments swapped: the handler comes first.
handleAny
  :: Member Exceptions es
  => ( SomeException -> Program es a ) -> Program es a -> Program es a
handleAny handler action =
  catchAny action handler


-- | 'catchJust' with its last two arguments swapped: the handler comes before
-- the action.
handleJust
  :: ( Member Exceptions es, Exception e )
  => ( e -> Maybe b )
  -> ( b -> Program es a )
  -> Program es a
  -> Program es a
handleJust select handler action =
  catchJust select action handler


-- | Run an action, returning its result in 'Right', or a caught exception of
-- type @e@ in 'Left'. It is implemented with 'catch', so unlike 'EUnsafe.try'
-- (from "Control.Exception") it does not capture asynchronous exceptions.
try
  :: ( Member Exceptions es, Exception e )
  => Program es a -> Program es ( Either e a )
try action =
  ( Right <$> action ) `catch` \e -> return ( Left e )


-- | 'try' with the exception type fixed to 'IOException'.
tryIO
  :: Member Exceptions es
  => Program es a -> Program es ( Either IOException a )
tryIO action =
  try action


-- | 'try' with the exception type fixed to 'SomeException', so every
-- synchronous exception is returned as a 'Left'.
tryAny
  :: Member Exceptions es
  => Program es a -> Program es ( Either SomeException a )
tryAny action =
  try action


-- | A variant of 'try' that takes a selector function. A caught exception for
-- which the selector returns 'Just' is returned (as the selected value) in
-- 'Left'; one for which it returns 'Nothing' is rethrown.
tryJust
  :: ( Member Exceptions es, Exception e )
  => ( e -> Maybe b ) -> Program es a -> Program es ( Either b a )
tryJust select action =
  catch ( fmap Right action ) \e ->
    case select e of
      Just selected ->
        return ( Left selected )

      Nothing ->
        throwIO e


-- | An exception handler whose exception type is hidden, so that handlers for
-- different exception types can be collected in one list (see 'catches').
-- This is 'EUnsafe.Handler' (from "Control.Exception") generalized from 'IO'
-- to an arbitrary @m@.
data Handler m a where
  Handler :: forall e m a. Exception e => ( e -> m a ) -> Handler m a


-- Pass the exception to the first handler in the list whose exception type it
-- can be converted to with 'fromException'. If no handler matches, the
-- exception is rethrown.
catchesHandler
  :: Member Exceptions es
  => [ Handler ( Program es ) a ] -> SomeException -> Program es a
catchesHandler handlers e =
  dispatch handlers

  where

    dispatch [] =
      throwIO e

    dispatch ( Handler handler : rest ) =
      case fromException e of
        Just matched ->
          handler matched

        Nothing ->
          dispatch rest


-- | Run an action with a list of 'Handler's; a caught exception goes to the
-- first handler whose type matches, and is rethrown if none does. Like
-- 'EUnsafe.catches' (from "Control.Exception"), except that it is built on
-- 'catch' and therefore does not catch asynchronous exceptions.
catches
  :: Member Exceptions es
  => Program es a -> [ Handler ( Program es ) a ] -> Program es a
catches io handlers =
  catch io ( catchesHandler handlers )


-- | Evaluate a value as an effect, by sending 'Evaluate'.
--
-- See 'EUnsafe.evaluate' (from "Control.Exception").
evaluate :: Member Exceptions es => a -> Program es a
evaluate value =
  send ( Evaluate value )


-- | Async safe version of 'EUnsafe.bracket' (from "Control.Exception").
--
-- @bracket before after thing@ runs @before@ to acquire a resource, passes it
-- to @thing@, and then always runs @after@ on the resource, whether @thing@
-- returned normally or threw any exception (synchronous or asynchronous).
--
-- @before@ runs with asynchronous exceptions masked, @thing@ runs with the
-- caller's masking state restored, and @after@ runs under
-- 'uninterruptibleMask_'. If @thing@ throws, its exception is rethrown once
-- @after@ has run, and any exception thrown by @after@ is discarded.
bracket
 :: Member Exceptions es
 => Program es a
 -> ( a -> Program es b )
 -> ( a -> Program es c )
 -> Program es c
bracket before after thing = mask \restore -> do
  x <-
    before

  -- This must intercept asynchronous exceptions as well. 'tryAny' rethrows
  -- them immediately, which would skip the cleanup and leak the resource.
  res1 <-
    tryAll ( restore ( thing x ) )

  case res1 of
    Left e1 -> do
      -- explicitly ignore exceptions from after. We know that
      -- no async exceptions were thrown there, so therefore
      -- the stronger exception must come from thing
      --
      -- https://github.com/fpco/safe-exceptions/issues/2
      _ <-
        tryAll ( uninterruptibleMask_ ( after x ) )

      throwIO e1

    Right y ->
      y <$ uninterruptibleMask_ ( after x )

  where

    -- Like 'tryAny', but built directly on 'Catch' so that asynchronous
    -- exceptions are returned in 'Left' too instead of being rethrown.
    tryAll
      :: Member Exceptions es'
      => Program es' x -> Program es' ( Either SomeException x )
    tryAll action =
      send ( Catch ( Right <$> action ) ( return . Left ) )


-- | A variant of 'bracket' for when the result of the acquire action is not
-- needed by either the release action or the body.
--
-- See also 'EUnsafe.bracket_' (from "Control.Exception").
bracket_
  :: Member Exceptions es
  => Program es a -> Program es b -> Program es c -> Program es c
bracket_ before after thing =
  bracket before ( \_ -> after ) ( \_ -> thing )


-- | Async safe version of 'EUnsafe.bracketOnError' (from "Control.Exception").
--
-- Like 'bracket', but @after@ is only run when @thing@ throws an exception
-- (synchronous or asynchronous). When @thing@ returns normally, its result is
-- returned and @after@ is not run.
--
-- If @thing@ throws, its exception is rethrown once @after@ has run under
-- 'uninterruptibleMask_'; any exception thrown by @after@ is discarded.
bracketOnError
  :: Member Exceptions es
  => Program es a -> ( a -> Program es b ) -> ( a -> Program es c ) -> Program es c
bracketOnError before after thing = mask \restore -> do
  x <-
    before

  -- This must intercept asynchronous exceptions as well. 'tryAny' rethrows
  -- them immediately, which would skip the cleanup entirely.
  res1 <-
    tryAll ( restore ( thing x ) )

  case res1 of
    Left e1 -> do
      -- Exceptions from the cleanup are ignored: the exception from @thing@
      -- is the one that gets reported.
      _ <-
        tryAll ( uninterruptibleMask_ ( after x ) )

      throwIO e1

    Right y ->
      return y

  where

    -- Like 'tryAny', but built directly on 'Catch' so that asynchronous
    -- exceptions are returned in 'Left' too instead of being rethrown.
    tryAll
      :: Member Exceptions es'
      => Program es' x -> Program es' ( Either SomeException x )
    tryAll action =
      send ( Catch ( Right <$> action ) ( return . Left ) )


-- | A variant of 'bracketOnError' for when the result of the acquire action is
-- not needed by either the cleanup action or the body.
bracketOnError_
  :: Member Exceptions es
  => Program es a -> Program es b -> Program es c -> Program es c
bracketOnError_ before after thing =
  bracketOnError before ( \_ -> after ) ( \_ -> thing )


-- | Async safe version of 'EUnsafe.finally' (from "Control.Exception").
--
-- @thing \`finally\` after@ runs @thing@ and then always runs @after@, whether
-- @thing@ returned normally or threw any exception (synchronous or
-- asynchronous).
--
-- @thing@ runs with the caller's masking state restored; @after@ runs inside
-- the surrounding 'uninterruptibleMask'. If @thing@ throws, its exception is
-- rethrown once @after@ has run, and any exception thrown by @after@ is
-- discarded.
finally
  :: Member Exceptions es
  => Program es a -> Program es b -> Program es a
finally thing after = uninterruptibleMask \restore -> do
  -- This must intercept asynchronous exceptions as well. 'tryAny' rethrows
  -- them immediately, which would skip the finaliser.
  res1 <-
    tryAll ( restore thing )

  case res1 of
    Left e1 -> do
      -- Exceptions from the finaliser are ignored: the exception from @thing@
      -- is the one that gets reported.
      _ <-
        tryAll after

      throwIO e1

    Right x ->
      x <$ after

  where

    -- Like 'tryAny', but built directly on 'Catch' so that asynchronous
    -- exceptions are returned in 'Left' too instead of being rethrown.
    tryAll
      :: Member Exceptions es'
      => Program es' x -> Program es' ( Either SomeException x )
    tryAll action =
      send ( Catch ( Right <$> action ) ( return . Left ) )


-- | Async safe version of 'EUnsafe.withException' (from "Control.Exception").
--
-- @withException thing after@ runs @thing@; if it throws an exception of type
-- @e@ (synchronous or asynchronous), @after@ is run on that exception and the
-- exception is then rethrown. When @thing@ returns normally, @after@ is not
-- run.
--
-- @thing@ runs with the caller's masking state restored; @after@ runs inside
-- the surrounding 'uninterruptibleMask'. Any exception thrown by @after@ is
-- discarded.
withException
  :: ( Member Exceptions es, Exception e )
  => Program es a -> ( e -> Program es b ) -> Program es a
withException thing after = uninterruptibleMask \restore -> do
  -- This must intercept asynchronous exceptions as well. 'try' rethrows them
  -- immediately, so @after@ would never see them, even when @e@ is
  -- 'SomeException'.
  res1 <-
    tryAll ( restore thing )

  case res1 of
    Left e1 -> do
      -- The exception from @thing@ is the one that gets reported, so anything
      -- thrown by @after@ is dropped.
      ignoreExceptions ( after e1 )

      throwIO e1

    Right x ->
      return x

  where

    -- Like 'try', but built directly on 'Catch' so that asynchronous
    -- exceptions are returned in 'Left' too instead of being rethrown.
    tryAll
      :: ( Member Exceptions es', Exception e' )
      => Program es' x -> Program es' ( Either e' x )
    tryAll action =
      send ( Catch ( Right <$> action ) ( return . Left ) )

    -- Run an action, discarding both its result and any exception it throws.
    ignoreExceptions
      :: Member Exceptions es' => Program es' x -> Program es' ()
    ignoreExceptions action =
      send ( Catch ( () <$ action ) ( \SomeException{} -> return () ) )


-- | Run an action and, should it fail, run a cleanup action that does not
-- need to inspect the exception. This is 'withException' with a handler that
-- matches 'SomeException' and ignores its contents.
--
-- See also 'EUnsafe.onException' (from "Control.Exception").
onException
  :: Member Exceptions es
  => Program es a -> Program es b -> Program es a
onException thing after =
  withException thing cleanup

  where

    cleanup SomeException{} =
      after


-- | Throw the given exception synchronously, by sending 'ThrowIO'.
--
-- See also: 'EUnsafe.throwIO' (from "Control.Exception").
throwIO :: ( Member Exceptions es, Exception e ) => e -> Program es a
throwIO e =
  send ( ThrowIO e )


-- | Run a computation by sending 'Mask'. The computation receives a function
-- that it can apply to sub-computations to restore the previous masking state.
--
-- See 'EUnsafe.mask' (from "Control.Exception").
mask
  :: Member Exceptions es
  => ( ( forall x. Program es x -> Program es x ) -> Program es a )
  -> Program es a
mask body =
  send $ Mask body


-- | Run a computation by sending 'UninterruptibleMask'. The computation
-- receives a function that it can apply to sub-computations to restore the
-- previous masking state.
--
-- See 'EUnsafe.uninterruptibleMask' (from "Control.Exception").
uninterruptibleMask
  :: Member Exceptions es
  => ( ( forall x. Program es x -> Program es x ) -> Program es a )
  -> Program es a
uninterruptibleMask body =
  send $ UninterruptibleMask body


-- | 'mask' for a computation that never needs to restore the previous masking
-- state.
--
-- See 'EUnsafe.mask_' (from "Control.Exception").
mask_ :: Member Exceptions es => Program es a -> Program es a
mask_ action =
  mask ( \_restore -> action )


-- | 'uninterruptibleMask' for a computation that never needs to restore the
-- previous masking state.
--
-- See 'EUnsafe.uninterruptibleMask_' (from "Control.Exception").
uninterruptibleMask_ :: Member Exceptions es => Program es a -> Program es a
uninterruptibleMask_ action =
  uninterruptibleMask ( \_restore -> action )


-- | Decide whether an exception is synchronous. An exception counts as
-- asynchronous, giving 'False', exactly when converting it with 'toException'
-- and back with 'fromException' yields a 'SomeAsyncException'; every other
-- exception gives 'True'.
isSyncException :: Exception e => e -> Bool
isSyncException e
  | Just SomeAsyncException{} <- fromException ( toException e ) =
      False

  | otherwise =
      True