{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}

module Hercules.CNix.Exception
  ( handleExceptions,
    handleExceptions',
    handleExceptionPtr,
  )
where

import Hercules.CNix.Store.Context (context)
import qualified Language.C.Inline.Cpp as C
import qualified Language.C.Inline.Cpp.Exception as C
import Protolude
import qualified System.Environment

-- Install the inline-c context exported by Hercules.CNix.Store.Context.
-- Presumably it supplies the C++ type mappings and the @fptr-ptr@ / @bs-ptr@
-- antiquoters used by the quasi-quoted block in 'handleExceptionPtr'.
-- TODO confirm against that module.
C.context context

-- Nix headers for the C++ code below. The block in 'handleExceptionPtr'
-- calls @nix::handleExceptions@, presumably declared in shared.hh.
-- Verify this against the Nix version being built against.
C.include "<nix/config.h>"
C.include "<nix/shared.hh>"
C.include "<nix/globals.hh>"

-- | Run an action and let Nix deal with any C++ @std::exception@ it throws:
-- the exception is reported the way Nix would report it, and the process
-- then exits via 'exitWith' with the status Nix chose.
--
-- The program name passed to Nix is taken from
-- 'System.Environment.getProgName'.
handleExceptions :: IO a -> IO a
handleExceptions action =
  System.Environment.getProgName
    >>= \name -> handleExceptions' exitWith (toS name) action

-- | Like 'handleExceptions', but the caller decides what happens once Nix
-- has produced an exit status.
--
-- C++ exceptions deriving from @std::exception@ are handed to Nix's own
-- exception handler. The resulting status is converted to an 'ExitCode'
-- and passed to the supplied callback. Any other exception is left to
-- propagate unchanged.
handleExceptions' ::
  -- | Invoked with the 'ExitCode' Nix would have exited with
  (ExitCode -> IO a) ->
  -- | Program (command) name reported by Nix
  Text ->
  IO a ->
  IO a
handleExceptions' onExit progName action =
  handleJust stdExceptionPtr recover action
  where
    -- Only std::exception-derived C++ exceptions carry a pointer that Nix
    -- can rethrow and classify. Returning Nothing makes 'handleJust'
    -- rethrow the exception untouched.
    stdExceptionPtr :: C.CppException -> Maybe C.CppExceptionPtr
    stdExceptionPtr ex =
      case ex of
        C.CppStdException ptr _ _ -> Just ptr
        _ -> Nothing

    -- Nix signals success with 0. Every other value becomes a failure
    -- code, which also avoids ever constructing the invalid 'ExitFailure 0'.
    toExitCode :: C.CInt -> ExitCode
    toExitCode status =
      case status of
        0 -> ExitSuccess
        _ -> ExitFailure (fromIntegral status)

    recover ptr = do
      status <- handleExceptionPtr (encodeUtf8 progName) ptr
      onExit (toExitCode status)

-- | Low-level primitive behind 'handleExceptions''.
--
-- It rethrows the C++ exception held by the pointer inside the callback
-- given to @nix::handleExceptions(programName, callback)@, and returns the
-- integer status that Nix's handler produces. The first argument is used
-- as the program name (its raw bytes become a @std::string@).
--
-- If anything escapes Nix's handler, 'C.throwBlock' re-raises it on the
-- Haskell side.
handleExceptionPtr :: ByteString -> C.CppExceptionPtr -> IO C.CInt
handleExceptionPtr name ptr =
  [C.throwBlock| int {
    const std::string programName($bs-ptr:name, $bs-len:name);
    std::exception_ptr & captured = *$fptr-ptr:(std::exception_ptr *ptr);
    return nix::handleExceptions(programName, [&]() {
      std::rethrow_exception(captured);
    });
  }|]