{-# LANGUAGE CPP                  #-}
{-# LANGUAGE DeriveDataTypeable   #-}
{-# LANGUAGE OverloadedStrings    #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE TypeApplications     #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- Silence the deprecation warning for lua_concat; the way it is used here is safe.
{-# OPTIONS_GHC -Wno-warnings-deprecations #-}
{-|
Module      : HsLua.Core.Error
Copyright   : © 2017-2022 Albert Krewinkel
License     : MIT
Maintainer  : Albert Krewinkel <tarleb+hslua@zeitkraut.de>

Lua exceptions and exception handling.
-}
module HsLua.Core.Error
  ( Exception (..)
  , LuaError (..)
  , Lua
  , try
  , failLua
  , throwErrorAsException
  , throwTypeMismatchError
  , changeErrorType
    -- * Helpers for hslua C wrapper functions.
  , liftLuaThrow
  , popErrorMessage
  , pushTypeMismatchError
  ) where

import Control.Applicative (Alternative (..))
import Control.Monad ((<$!>), void)
import Data.ByteString (ByteString)
import Data.Proxy (Proxy (Proxy))
import Data.Typeable (Typeable)
import Foreign.Marshal.Alloc (alloca)
import Foreign.Ptr
import HsLua.Core.Types (LuaE, liftLua)
import Lua

import qualified Control.Exception as E
import qualified Control.Monad.Catch as Catch
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as Char8
import qualified Data.ByteString.Unsafe as B
import qualified Foreign.Storable as Storable
import qualified HsLua.Core.Types as Lua
import qualified HsLua.Core.Utf8 as Utf8

#if !MIN_VERSION_base(4,13,0)
import Control.Monad.Fail (MonadFail (..))
#endif

-- | A Lua operation.
--
-- This type is suitable for most users. It uses a default exception for
-- error handling. Users who need more control over error handling can
-- use 'LuaE' with a custom error type instead.
--
-- The synonym is eta-reduced, so @Lua@ can also be used unapplied,
-- e.g. as the base monad of a transformer such as @ReaderT r Lua@.
type Lua = LuaE Exception

-- | Any type that you wish to use for error handling in HsLua must be
-- an instance of the @LuaError@ class.
--
-- The class connects a Haskell exception type with the Lua stack: it
-- says how to read such an error from the stack, how to push it onto
-- the stack, and how to create one from a plain message. Within this
-- module, 'failLua' creates errors with 'luaException', and
-- 'throwErrorAsException' reads them with 'popException'.
class E.Exception e => LuaError e where
  -- | Converts the error at the top of the stack into an exception and
  -- pops the error off the stack.
  --
  -- This function is expected to produce a valid result for any Lua
  -- value; neither a Haskell exception nor a Lua error may result when
  -- this is called.
  popException :: LuaE e e
  -- | Pushes an exception to the top of the Lua stack. The pushed Lua
  -- object is used as an error object, and it is recommended that
  -- calling @tostring()@ on the object produces an informative message.
  pushException :: e -> LuaE e ()
  -- | Creates a new exception with the given message.
  luaException :: String -> e

-- | Default Lua error type. Exceptions raised by Lua-related operations.
--
-- The constructor wraps a plain error message; 'show' renders it with
-- a @Lua exception: @ prefix.
newtype Exception = Exception { exceptionMessage :: String }
  deriving (Eq, Typeable)

-- | Renders the wrapped message prefixed with @Lua exception: @.
instance Show Exception where
  show (Exception msg) = "Lua exception: " ++ msg

-- | Uses the default methods, so values can be thrown and caught with
-- both "Control.Exception" and "Control.Monad.Catch".
instance E.Exception Exception

-- | The default error type reads and writes plain string messages.
instance LuaError Exception where
  -- Turn whatever object is on top of the stack into a message.
  -- 'popErrorMessage' falls back to a generic message when the object
  -- cannot be converted, so this does not need its own error handling.
  popException =
    Exception . Utf8.toString <$!> liftLua popErrorMessage
  {-# INLINABLE popException #-}

  -- Push the message as a Lua string, encoded via 'Utf8.fromString'.
  -- The length is passed explicitly, so no NUL terminator is needed.
  pushException (Exception msg) = Lua.liftLua $ \l ->
    B.unsafeUseAsCStringLen (Utf8.fromString msg) $ \(msgPtr, len) ->
      lua_pushlstring l msgPtr (fromIntegral len)
  {-# INLINABLE pushException #-}

  luaException = Exception
  {-# INLINABLE luaException #-}

-- | Return either the result of a Lua computation or, if an exception was
-- thrown, the error.
--
-- Only exceptions of type @e@ are caught; any other exception
-- propagates unchanged.
try :: Catch.Exception e => LuaE e a -> LuaE e (Either e a)
try = Catch.try
{-# INLINABLE try #-}

-- | Raises an exception in the Lua monad.
--
-- The exception is created from the message with 'luaException' and
-- thrown as a Haskell exception via 'Catch.throwM'; nothing is pushed
-- to the Lua stack.
failLua :: forall e a. LuaError e => String -> LuaE e a
failLua msg = Catch.throwM (luaException @e msg)
{-# INLINABLE failLua #-}

-- | Converts a Lua error at the top of the stack into a Haskell
-- exception and throws it.
--
-- The error object is removed from the stack by 'popException' before
-- the exception is thrown.
throwErrorAsException :: LuaError e => LuaE e a
throwErrorAsException = do
  err <- popException
  -- Evaluate the error to WHNF before throwing it.
  Catch.throwM $! err
{-# INLINABLE throwErrorAsException #-}

-- | Raises an exception that's appropriate when the type of a Lua
-- object at the given index did not match the expected type. The name
-- or description of the expected type is taken as an argument.
--
-- The message is first built on the Lua stack by
-- 'pushTypeMismatchError', then popped and thrown by
-- 'throwErrorAsException'.
throwTypeMismatchError :: forall e a. LuaError e
                       => ByteString -> StackIndex -> LuaE e a
throwTypeMismatchError expected idx = do
  pushTypeMismatchError expected idx
  throwErrorAsException
{-# INLINABLE throwTypeMismatchError #-}

-- | Change the error type of a computation.
--
-- The computation is run on the same Lua state. This function does not
-- catch or translate exceptions: if @op@ throws an exception of type
-- @old@, that exception propagates as is.
changeErrorType :: forall old new a. LuaE old a -> LuaE new a
changeErrorType op = Lua.liftLua $ \l -> do
  x <- Lua.runWith l op
  return $! x
{-# INLINABLE changeErrorType #-}


--
-- Orphan instances
--

-- | 'empty' fails with the message @empty@. @x '<|>' y@ runs @y@ if @x@
-- throws an exception of type @e@. Other exceptions are not caught, and
-- stack changes made by @x@ before it failed are not undone.
instance LuaError e => Alternative (LuaE e) where
  empty = failLua "empty"
  x <|> y = x `Catch.catch` (\(_ :: e) -> y)

-- | 'fail' raises an exception via 'failLua'; for example, a failed
-- pattern match in a @do@ block becomes a Lua exception.
instance LuaError e => MonadFail (LuaE e) where
  fail = failLua

--
-- Helpers
--

-- | Takes a failable HsLua function and transforms it into a
-- monadic 'LuaE' operation. Throws an exception if an error
-- occurred.
--
-- The wrapped function is expected to write a status code to the
-- supplied pointer. Any status other than @LUA_OK@ causes the error
-- object at the top of the stack to be converted and thrown.
liftLuaThrow :: forall e a. LuaError e
             => (Lua.State -> Ptr Lua.StatusCode -> IO a)
             -> LuaE e a
liftLuaThrow f = Lua.liftLua (throwOnError (Proxy @e) f)

-- | Helper function which takes an ersatz function and checks for
-- errors during its execution. If an error occurred, it is converted
-- into a 'LuaError' and thrown as an exception.
--
-- The proxy argument only selects the error type @e@. The status code
-- is read from the pointer after the function returns; the result is
-- forced to WHNF before it is returned.
throwOnError :: forall e a. LuaError e
             => Proxy e
             -> (Lua.State -> Ptr Lua.StatusCode -> IO a)
             -> Lua.State
             -> IO a
throwOnError _errProxy f l = alloca $ \statusPtr -> do
  result <- f l statusPtr
  status <- Storable.peek statusPtr
  if status == LUA_OK
    then return $! result
    else Lua.runWith l (throwErrorAsException @e)


-- | Retrieve and pop the top object as an error message. This is very
-- similar to tostring', but ensures that we don't recurse if getting
-- the message failed.
--
-- This is helpful as a \"last resort\" method when implementing
-- 'popException'. If the object cannot be converted to a string, a
-- generic fallback message is returned instead.
popErrorMessage :: Lua.State -> IO ByteString
popErrorMessage l = alloca $ \lenPtr -> do
  cstr <- hsluaL_tolstring l (-1) lenPtr
  if cstr == nullPtr
    then do
      -- NOTE(review): only one slot is popped here, while the success
      -- branch pops two. Confirm against the C implementation of
      -- hsluaL_tolstring whether the original error object is left on
      -- the stack in this case.
      lua_pop l 1
      return $ Char8.pack
        "An error occurred, but the error object could not be retrieved."
    else do
      cstrLen <- Storable.peek lenPtr
      -- Copy the bytes into Haskell memory before popping the string
      -- they point into.
      msg <- B.packCStringLen (cstr, fromIntegral cstrLen)
      lua_pop l 2  -- pop original msg and product of hsluaL_tolstring
      return msg

-- | Creates an error to notify about a Lua type mismatch and pushes it
-- to the stack.
--
-- The pushed message has the form @\<expected\> expected, got \<actual\>@.
-- Here @\<actual\>@ is the @__name@ field of the object's metatable if
-- that field is a string; otherwise it is the basic Lua type name of
-- the object.
pushTypeMismatchError :: ByteString  -- ^ name or description of expected type
                      -> StackIndex  -- ^ stack index of mismatching object
                      -> LuaE e ()
pushTypeMismatchError expected idx = liftLua $ \l -> do
  -- Push the basic type name of the object at @idx@.
  let pushtype :: IO ()
      pushtype = void $
        lua_type l idx >>= lua_typename l >>= lua_pushstring l
  -- luaL_getmetafield expects a NUL-terminated key. The copying
  -- 'B.useAsCString' guarantees the terminator, which
  -- 'B.unsafeUseAsCString' does not.
  nameType <- B.useAsCString "__name" (luaL_getmetafield l idx)
  case nameType of
    LUA_TSTRING -> return ()                -- the name is already pushed
    LUA_TNIL    -> pushtype                 -- nothing was pushed
    _           -> lua_pop l 1 >> pushtype  -- discard non-string field
  -- Push a Haskell ByteString as a Lua string; the length is explicit,
  -- so no terminator is needed here.
  let pushstring :: ByteString -> IO ()
      pushstring str = B.unsafeUseAsCStringLen str $ \(cstr, cstrLen) ->
        lua_pushlstring l cstr (fromIntegral cstrLen)
  pushstring expected
  pushstring " expected, got "
  lua_rotate l (-3) (-1)  -- move actual type to the end
  lua_concat l 3