{-# LANGUAGE Strict #-}

module Data.SpirV.Reflect.FFI
  ( Module
  , load
  , loadBytes
  ) where

import Control.Exception (ErrorCall(..), onException, throwIO)
import Control.Monad.IO.Class (MonadIO(..))
import Data.ByteString (ByteString)
import Data.ByteString qualified as ByteString
import Data.ByteString.Unsafe (unsafeUseAsCStringLen)
import Data.SpirV.Reflect.FFI.Internal qualified as C
import Foreign (callocBytes, castPtr, free)
import System.Mem.Weak (addFinalizer)

import Data.SpirV.Reflect.Module (Module)

-- | Read a SPIR-V binary from the given path and reflect it into a 'Module'.
--
-- The file is read with the strict 'ByteString.readFile' and the bytes are
-- passed to 'loadBytes'. Any exception raised by either step propagates.
load :: MonadIO io => FilePath -> io Module
load path = liftIO do
  spirv <- ByteString.readFile path
  loadBytes spirv

-- | Reflect an in-memory SPIR-V binary into a 'Module'.
--
-- The bytes are passed to @spvReflectCreateShaderModule2@ with
-- 'C.SpvReflectModuleFlagNoCopy', so the C side uses the 'ByteString' buffer
-- in place. 'unsafeUseAsCStringLen' only guarantees that buffer for the
-- duration of its callback, which is why 'C.inflateModule' runs inside it.
--
-- Ownership of the @callocBytes@-allocated struct:
--
-- * if creation or inflation throws, the struct is freed (and, for a failed
--   inflation, 'C.destroyShaderModule' is called first);
-- * on success, a finalizer attached to the returned 'Module' calls
--   'C.destroyShaderModule' and then frees the struct itself.
--
-- Throws 'ErrorCall' if the reflection library returns a non-success result.
loadBytes :: MonadIO io => ByteString -> io Module
loadBytes bytes = liftIO do
  unsafeUseAsCStringLen bytes \(code, size) -> do
    smPtr <- callocBytes C.shaderModuleSize
    m <- flip onException (free smPtr) do
      res <-
        C.createShaderModule2
          C.SpvReflectModuleFlagNoCopy
          (fromIntegral size)
          (castPtr code)
          smPtr
      case res of
        C.SpvReflectResultSuccess ->
          C.inflateModule smPtr `onException` C.destroyShaderModule smPtr
        err ->
          throwIO $ ErrorCall ("spvReflectCreateShaderModule2:" <> show err)
    -- Registered outside the 'onException' scope above: an exception arriving
    -- after this point must not free 'smPtr' while the finalizer still owns it.
    -- NOTE(review): 'addFinalizer' on a non-primitive value like 'Module' can
    -- fire earlier than expected (see System.Mem.Weak docs); confirm 'Module'
    -- holds no pointers into 'smPtr' memory, or destroy eagerly instead.
    addFinalizer m do
      C.destroyShaderModule smPtr
      -- The struct was allocated by us, not by the library; release it too.
      free smPtr
    pure m