{-# LANGUAGE CPP                  #-}
{-# LANGUAGE DeriveDataTypeable   #-}
{-# LANGUAGE OverloadedStrings    #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE TypeApplications     #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- Don't warn about lua_concat; the way it is used here is safe.
{-# OPTIONS_GHC -Wno-warnings-deprecations #-}
{-|
Module      : HsLua.Core.Error
Copyright   : © 2017-2023 Albert Krewinkel
License     : MIT
Maintainer  : Albert Krewinkel <tarleb@hslua.org>

Lua exceptions and exception handling.
-}
module HsLua.Core.Error
  ( Exception (..)
  , LuaError (..)
  , Lua
  , try
  , failLua
  , throwErrorAsException
  , throwTypeMismatchError
  , changeErrorType
    -- * Helpers for hslua C wrapper functions.
  , liftLuaThrow
  , popErrorMessage
  , pushTypeMismatchError
  ) where

import Control.Applicative (Alternative (..))
import Control.Monad ((<$!>), void)
import Data.ByteString (ByteString)
import Data.Proxy (Proxy (Proxy))
import Data.Typeable (Typeable)
import Foreign.Marshal.Alloc (alloca)
import Foreign.Ptr
import HsLua.Core.Types (LuaE, liftLua)
import Lua

import qualified Control.Exception as E
import qualified Control.Monad.Catch as Catch
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as Char8
import qualified Data.ByteString.Unsafe as B
import qualified Foreign.Storable as Storable
import qualified HsLua.Core.Types as Lua
import qualified HsLua.Core.Utf8 as Utf8

#if !MIN_VERSION_base(4,13,0)
import Control.Monad.Fail (MonadFail (..))
#endif

-- | A Lua operation.
--
-- This type is suitable for most users. It uses a default exception for
-- error handling. Users who need more control over error handling can
-- use 'LuaE' with a custom error type instead.
--
-- The synonym is eta-reduced, so @Lua@ may also be used without its
-- result-type argument, e.g. as the monad parameter of another type.
type Lua = LuaE Exception

-- | Class of error types usable with HsLua. A type can serve as the
-- error parameter of 'LuaE' only if it has a @LuaError@ instance.
class (E.Exception e) => LuaError e where
  -- | Turns the value on top of the Lua stack into an exception and
  -- removes that value from the stack.
  --
  -- Implementations must handle every possible Lua value: calling this
  -- must never lead to a Haskell exception or a Lua error.
  popException :: LuaE e e

  -- | Places an exception on top of the Lua stack, where the pushed
  -- object acts as the Lua error object. Ideally, applying @tostring()@
  -- to that object yields an informative message.
  pushException :: e -> LuaE e ()

  -- | Builds a new exception that carries the given message.
  luaException :: String -> e

-- | Default Lua error type. Exceptions raised by Lua-related operations.
newtype Exception = Exception { exceptionMessage :: String }
  deriving (Eq, Typeable)

-- | Shows the message prefixed with @"Lua exception: "@.
instance Show Exception where
  show (Exception msg) = "Lua exception: " ++ msg

-- | Uses the default method implementations.
instance E.Exception Exception

-- | The default error type stores the error object as a UTF-8 decoded
-- message string.
instance LuaError Exception where
  -- Convert the error object to a string with 'popErrorMessage' (which
  -- falls back to a fixed message instead of failing) and decode it.
  popException =
    Exception . Utf8.toString <$!> liftLua popErrorMessage
  {-# INLINABLE popException #-}

  -- Push the UTF-8 encoded message as a Lua string. The unsafe,
  -- zero-copy buffer access is fine here: the bytes are only read, and
  -- an explicit length is passed, so no NUL terminator is needed.
  pushException (Exception msg) = Lua.liftLua $ \l ->
    B.unsafeUseAsCStringLen (Utf8.fromString msg) $ \(msgPtr, z) ->
      lua_pushlstring l msgPtr (fromIntegral z)
  {-# INLINABLE pushException #-}

  luaException = Exception
  {-# INLINABLE luaException #-}

-- | Return either the result of a Lua computation or, if an exception was
-- thrown, the error.
--
-- Only exceptions of the computation's error type @e@ are caught; any
-- other exception propagates.
try :: Catch.Exception e => LuaE e a -> LuaE e (Either e a)
try = Catch.try
{-# INLINABLE try #-}

-- | Raises an exception in the Lua monad.
--
-- The exception is built from the message with 'luaException' for the
-- error type @e@ and thrown as a Haskell exception.
failLua :: forall e a. LuaError e => String -> LuaE e a
failLua msg = Catch.throwM (luaException @e msg)
{-# INLINABLE failLua #-}

-- | Converts a Lua error at the top of the stack into a Haskell
-- exception and throws it.
--
-- The error object is removed from the stack by 'popException'.
throwErrorAsException :: LuaError e => LuaE e a
throwErrorAsException = do
  err <- popException
  -- force the exception before throwing it
  Catch.throwM $! err
{-# INLINABLE throwErrorAsException #-}

-- | Raises an exception that's appropriate when the type of a Lua
-- object at the given index did not match the expected type. The name
-- or description of the expected type is taken as an argument.
--
-- The message is pushed by 'pushTypeMismatchError' and then converted
-- into an exception by 'throwErrorAsException'.
throwTypeMismatchError :: forall e a. LuaError e
                       => ByteString -> StackIndex -> LuaE e a
throwTypeMismatchError expected idx = do
  pushTypeMismatchError expected idx
  throwErrorAsException
{-# INLINABLE throwTypeMismatchError #-}

-- | Change the error type of a computation.
--
-- The computation is run on the same Lua state under the new error
-- type. Exceptions of the old error type thrown by the computation are
-- /not/ converted; they propagate unchanged.
changeErrorType :: forall old new a. LuaE old a -> LuaE new a
changeErrorType op = Lua.liftLua $ \l -> do
  x <- Lua.runWith l op
  return $! x
{-# INLINABLE changeErrorType #-}


--
-- Orphan instances
--

-- | 'empty' fails with the message @"empty"@. @x \<|\> y@ runs @y@ if
-- @x@ throws an exception of the error type @e@; exceptions of other
-- types are not caught. This instance does not reset the Lua stack
-- before running @y@.
instance LuaError e => Alternative (LuaE e) where
  empty = failLua "empty"
  x <|> y = x `Catch.catch` (\(_ :: e) -> y)

-- | 'fail' throws an exception of the error type via 'failLua'.
instance LuaError e => MonadFail (LuaE e) where
  fail = failLua

--
-- Helpers
--

-- | Takes a failable HsLua function and transforms it into a
-- monadic 'LuaE' operation. Throws an exception if an error
-- occurred.
liftLuaThrow :: forall e a. LuaError e
             => (Lua.State -> Ptr Lua.StatusCode -> IO a)
             -> LuaE e a
liftLuaThrow f = Lua.liftLua (throwOnError (Proxy @e) f)

-- | Helper function which takes an ersatz function and checks for
-- errors during its execution. If an error occurred, it is converted
-- into a 'LuaError' and thrown as an exception.
--
-- The function receives a pointer to a status code. After the call,
-- any status other than @LUA_OK@ causes the error object on top of the
-- stack to be popped and thrown via 'throwErrorAsException'. The
-- 'Proxy' argument only fixes the error type @e@.
throwOnError :: forall e a. LuaError e
             => Proxy e
             -> (Lua.State -> Ptr Lua.StatusCode -> IO a)
             -> Lua.State
             -> IO a
throwOnError _errProxy f l = alloca $ \statusPtr -> do
  result <- f l statusPtr
  status <- Storable.peek statusPtr
  if status == LUA_OK
    then return $! result
    else Lua.runWith l (throwErrorAsException @e)


-- | Retrieve and pop the top object as an error message. This is very
-- similar to @tostring'@, but ensures that we don't recurse if getting
-- the message failed.
--
-- If the object cannot be converted, it is still popped and a fixed
-- fallback message is returned instead.
--
-- This is helpful as a \"last resort\" method when implementing
-- 'popException'.
popErrorMessage :: Lua.State -> IO ByteString
popErrorMessage l = alloca $ \lenPtr -> do
  cstr <- hsluaL_tolstring l (-1) lenPtr
  if cstr == nullPtr
    then do
      lua_pop l 1
      return $ Char8.pack
        "An error occurred, but the error object could not be retrieved."
    else do
      cstrLen <- Storable.peek lenPtr
      -- Copy the bytes before popping; the buffer presumably belongs to
      -- a Lua string that may be collected once it is off the stack.
      msg <- B.packCStringLen (cstr, fromIntegral cstrLen)
      lua_pop l 2  -- pop original msg and product of hsluaL_tolstring
      return msg

-- | Creates an error to notify about a Lua type mismatch and pushes it
-- to the stack.
--
-- The pushed message reads @\<expected\> expected, got \<actual\>@, where
-- @\<actual\>@ is the object's @__name@ metafield if that is a string,
-- and the name of the object's basic Lua type otherwise.
pushTypeMismatchError :: ByteString  -- ^ name or description of expected type
                      -> StackIndex  -- ^ stack index of mismatching object
                      -> LuaE e ()
pushTypeMismatchError expected idx = liftLua $ \l -> do
  let pushtype = lua_type l idx >>= lua_typename l >>= lua_pushstring l
  -- luaL_getmetafield expects a NUL-terminated key. 'B.useAsCString'
  -- copies into a NUL-terminated buffer, whereas 'B.unsafeUseAsCString'
  -- would depend on the literal's buffer happening to be NUL-terminated.
  fieldType <- B.useAsCString "__name" (luaL_getmetafield l idx)
  case fieldType of
    LUA_TSTRING -> return ()               -- pushed the name
    LUA_TNIL    -> void pushtype           -- nothing was pushed
    _           -> lua_pop l 1 <* pushtype -- drop non-string field
  let pushstring str = B.unsafeUseAsCStringLen str $ \(cstr, cstrLen) ->
        lua_pushlstring l cstr (fromIntegral cstrLen)
  pushstring expected
  pushstring " expected, got "
  lua_rotate l (-3) (-1)  -- move actual type to the end
  lua_concat l 3