{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      : WGPU.Internal.Memory
-- Description : Managing memory.
--
-- This module contains type classes used to manage marshalling of objects into
-- memory before calling C functions.
--
-- = Motivation
--
-- In many locations in the API, we have:
--
-- A type (example only) which contains a nice Haskell representation of
-- some data:
--
-- @
-- data ApiType = ApiType { things :: Vector Thing }
-- @
--
-- and a raw type which is required for a C function:
--
-- @
-- data WGPUApiType = WGPUApiType
--   { thingsCount :: 'Word8',            -- this is an array length
--     things      :: 'Ptr' WGPUApiThing  -- this is a pointer to an array
--   }
-- @
--
-- This type class constraint represents the ability to encode @ApiType@ as
-- @WGPUApiType@, performing any necessary memory allocation and freeing:
--
-- @
-- 'ToRaw' ApiType WGPUApiType
-- @
--
-- 'ToRaw' uses the 'ContT' monad so that bracketing of the memory resources
-- can be performed around some continuation that uses the memory.
--
-- In the example above, we could write a 'ToRaw' instance as follows:
--
-- @
-- instance 'ToRaw' ApiType WGPUApiType where
--   'raw' ApiType{..} = do
--     things_ptr <- 'rawArrayPtr' things
--     'pure' $ WGPUApiType
--       { thingsCount = fromIntegral . length $ things,
--         things      = things_ptr
--       }
-- @
--
-- The 'ToRawPtr' type class represents similar functionality, except that it
-- creates a pointer to a value. Thus it does both raw conversion and storing
-- the raw value in allocated memory. It exists as a separate type class so
-- that library types (eg. 'Text' and 'ByteString') can be marshalled into
-- pointers more easily.
module WGPU.Internal.Memory
  ( -- * Classes
    ToRaw (raw),
    FromRaw (fromRaw),
    ToRawPtr (rawPtr),
    FromRawPtr (fromRawPtr),

    -- * Functions

    -- ** Internal
    evalContT,
    allocaC,
    rawArrayPtr,
    showWithPtr,
    withCZeroingAfter,

    -- ** Lifted to MonadIO
    newEmptyMVar,
    takeMVar,
    putMVar,
    freeHaskellFunPtr,
    poke,
  )
where

import Control.Concurrent (MVar)
import qualified Control.Concurrent (newEmptyMVar, putMVar, takeMVar)
import Control.Monad.Cont (ContT (ContT), callCC, runContT)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.ByteString (ByteString, packCString, useAsCString)
import Data.ByteString.Unsafe (unsafeUseAsCString)
import Data.Text (Text)
import qualified Data.Text as Text
import Data.Text.Encoding (decodeUtf8With, encodeUtf8)
import Data.Text.Encoding.Error (lenientDecode)
import Data.Vector.Generic (Vector)
import qualified Data.Vector.Generic as Vector
import Data.Word (Word8)
import Foreign
  ( FunPtr,
    Ptr,
    Storable,
    advancePtr,
    alignment,
    alloca,
    allocaArray,
    castPtr,
    nullPtr,
    peek,
    sizeOf,
  )
import qualified Foreign (fillBytes, freeHaskellFunPtr, poke)
import Foreign.C (CBool, CChar, peekCString, withCString)

-------------------------------------------------------------------------------
-- Type Classes

-- | Represents a value of type @a@ that can be stored as type @b@ in the
-- 'ContT' monad.
--
-- Implementations of this type class should bracket any resource management for
-- creating the @b@ value around the continuation. For example, memory to hold
-- elements of @b@ should be allocated and freed in a bracketed fashion.
--
-- The functional dependency @a -> b@ means that each Haskell type @a@ has a
-- single raw representation @b@.
class ToRaw a b | a -> b where
  -- | Convert a value to a raw representation, bracketing any resource
  -- management around the continuation. The raw value (and anything it
  -- points to) should only be used inside that continuation.
  raw :: a -> ContT r IO b

-- | Represents a value of type @a@ that can be stored as type @(Ptr b)@ in the
-- 'ContT' monad.
--
-- Implementations of this type class should bracket resource management for
-- creating @('Ptr' b)@ around the continuation. In particular, the memory
-- allocated for @('Ptr' b)@ must be allocated before the continuation is
-- called, and freed afterward.
--
-- Unlike 'ToRaw', there is no functional dependency here, so one Haskell type
-- may be marshalled into pointers of more than one raw type.
class ToRawPtr a b where
  -- | Produce a pointer to a raw representation of the value. Because the
  -- memory is freed after the continuation, the pointer must not escape it.
  rawPtr :: a -> ContT r IO (Ptr b)

-- | Represents a type @a@ that can be read from a raw value @b@.
--
-- The functional dependency @a -> b@ means the raw type @b@ is determined by
-- the Haskell type @a@ being read.
class FromRaw b a | a -> b where
  -- | Read a Haskell value from its raw representation. The conversion runs in
  -- any 'MonadIO', so it may perform 'IO' (eg. reading memory behind a
  -- pointer contained in the raw value).
  fromRaw :: MonadIO m => b -> m a

-- | Represents a type @a@ that can be read from a raw pointer @('Ptr' b)@.
class FromRawPtr b a where
  -- | Read a Haskell value from the memory that the pointer refers to.
  fromRawPtr :: MonadIO m => Ptr b -> m a

-------------------------------------------------------------------------------
-- Derived Functionality

-- | Return a pointer to an allocated array, populated with raw values from a
-- vector.
--
-- The array is allocated before the continuation runs and released after it
-- returns. Each element is converted with 'raw', so any memory those
-- conversions allocate also stays alive until the continuation has finished.
-- After the continuation returns normally, the array's bytes are overwritten
-- with zeros just before the array itself is released. (If the continuation
-- throws, the array is still released by 'allocaArray', but is not zeroed.)
rawArrayPtr ::
  forall v r a b.
  (ToRaw a b, Storable b, Vector v a) =>
  -- | Vector of values to store in a C array.
  v a ->
  -- | Pointer to the array with raw values stored in it.
  ContT r IO (Ptr b)
rawArrayPtr xs = do
  arrayPtr <- allocaArrayC n
  Vector.iforM_ xs $ \i x -> raw x >>= poke (advancePtr arrayPtr i)
  -- Zero the array once the rest of the computation has finished with it.
  -- This is written with 'ContT' directly rather than 'callCC': statements
  -- after a call to a 'callCC' escape continuation are never executed.
  ContT $ \k -> do
    result <- k arrayPtr
    -- The array holds @n@ elements of @sizeOf b@ bytes each; 'alignment' is
    -- not the element size and would leave part of the array untouched.
    zeroMemory arrayPtr (n * sizeOf (undefined :: b))
    pure result
  where
    n :: Int
    n = Vector.length xs
{-# INLINEABLE rawArrayPtr #-}

-------------------------------------------------------------------------------
-- Instances

-- Any value with a 'ToRaw' conversion to a 'Storable' raw type can also be
-- marshalled to a pointer: convert it with 'raw', then place the raw value in
-- memory provided by 'withCZeroingAfter' for the duration of the continuation.
instance {-# OVERLAPPABLE #-} (Storable b, ToRaw a b) => ToRawPtr a b where
  rawPtr value = do
    rawValue <- raw value
    withCZeroingAfter rawValue
  {-# INLINEABLE rawPtr #-}

-- Any type readable from a 'Storable' raw value can also be read through a
-- pointer: 'peek' the raw value first, then convert it with 'fromRaw'.
instance {-# OVERLAPPABLE #-} (Storable b, FromRaw b a) => FromRawPtr b a where
  fromRawPtr ptr = do
    rawValue <- liftIO (peek ptr)
    fromRaw rawValue
  {-# INLINEABLE fromRawPtr #-}

-- | 'False' is encoded as @0@ and 'True' as @1@ (via 'fromEnum').
instance ToRaw Bool CBool where
  raw = pure . fromIntegral . fromEnum
  {-# INLINE raw #-}

-- | Marshals 'Text' as a NUL-terminated, UTF-8 encoded C string that is valid
-- for the duration of the continuation.
--
-- The text is encoded explicitly as UTF-8. Going through 'String' and
-- 'withCString' would make the bytes depend on the process locale, so
-- non-ASCII text could be mis-encoded on non-UTF-8 systems. Interior NUL
-- characters are copied as-is, so C code sees the string truncated at the
-- first one.
instance ToRawPtr Text CChar where
  rawPtr text = ContT (useAsCString (encodeUtf8 text))
  {-# INLINEABLE rawPtr #-}

-- | Exposes the bytes of a 'ByteString' through a pointer that is valid for
-- the duration of the continuation. This goes through 'unsafeUseAsCString',
-- so the pointer refers to the 'ByteString's own buffer: it must not be
-- written to, and no terminating NUL byte is added.
-- NOTE(review): callers presumably pass the length separately — confirm.
instance ToRawPtr ByteString Word8 where
  rawPtr bytes = castPtr <$> unsafeUseAsCStringC bytes
  {-# INLINEABLE rawPtr #-}

-- | Reads a NUL-terminated C string as UTF-8 'Text'.
--
-- A null pointer is read as the empty 'Text'. The bytes are decoded explicitly
-- as UTF-8, matching how 'Text' is marshalled into C strings, instead of with
-- the locale-dependent 'peekCString'. Invalid UTF-8 sequences are replaced
-- with U+FFFD rather than raising an exception.
instance FromRaw (Ptr CChar) Text where
  fromRaw ptr
    | ptr == nullPtr = pure Text.empty
    | otherwise = liftIO (decodeUtf8With lenientDecode <$> packCString ptr)
  {-# INLINEABLE fromRaw #-}

-------------------------------------------------------------------------------
-- Continuation helpers

-- | Space for one 'Storable' value, allocated with 'alloca' around the
-- continuation and released once the continuation returns.
allocaC :: Storable a => ContT r IO (Ptr a)
allocaC = ContT $ \continuation -> alloca continuation
{-# INLINEABLE allocaC #-}

-- | Space for the given number of 'Storable' elements, allocated with
-- 'allocaArray' around the continuation.
allocaArrayC :: Storable a => Int -> ContT r IO (Ptr a)
allocaArrayC = ContT . allocaArray
{-# INLINEABLE allocaArrayC #-}

-- | 'withCString' lifted into 'ContT': the C string exists only while the
-- continuation runs.
withCStringC :: String -> ContT r IO (Ptr CChar)
withCStringC = ContT . withCString
{-# INLINEABLE withCStringC #-}

-- | 'unsafeUseAsCString' lifted into 'ContT': the pointer refers to the
-- 'ByteString's own buffer and is only valid while the continuation runs.
unsafeUseAsCStringC :: ByteString -> ContT r IO (Ptr CChar)
unsafeUseAsCStringC = ContT . unsafeUseAsCString
{-# INLINEABLE unsafeUseAsCStringC #-}

-- | Store a value in freshly allocated memory for the duration of the
-- continuation.
--
-- The memory is allocated with 'alloca' and holds @x@ while the continuation
-- runs. When the continuation returns normally, the bytes are overwritten with
-- zeros before 'alloca' releases the memory. If the continuation throws, the
-- memory is still released, but it is not zeroed first.
withCZeroingAfter :: Storable a => a -> ContT r IO (Ptr a)
withCZeroingAfter x = ContT $ \k ->
  alloca $ \ptr -> do
    Foreign.poke ptr x
    result <- k ptr
    -- Zero only after the continuation is done with the memory. This is
    -- written with 'ContT' directly rather than 'callCC': statements after a
    -- call to a 'callCC' escape continuation are never executed.
    Foreign.fillBytes ptr 0 (sizeOf x)
    pure result
{-# INLINEABLE withCZeroingAfter #-}

-------------------------------------------------------------------------------
-- Memory actions lifted to MonadIO

-- | Create an empty 'MVar' from any 'MonadIO' context.
newEmptyMVar :: MonadIO m => m (MVar a)
newEmptyMVar = liftIO $ Control.Concurrent.newEmptyMVar
{-# INLINEABLE newEmptyMVar #-}

-- | Take the contents of an 'MVar' from any 'MonadIO' context, blocking while
-- it is empty (as 'Control.Concurrent.takeMVar' does).
takeMVar :: MonadIO m => MVar a -> m a
takeMVar mvar = liftIO (Control.Concurrent.takeMVar mvar)
{-# INLINEABLE takeMVar #-}

-- | Put a value into an 'MVar' from any 'MonadIO' context, blocking while it
-- is full (as 'Control.Concurrent.putMVar' does).
putMVar :: MonadIO m => MVar a -> a -> m ()
putMVar mvar = liftIO . Control.Concurrent.putMVar mvar
{-# INLINEABLE putMVar #-}

-- | Write a 'Storable' value to memory from any 'MonadIO' context.
poke :: (MonadIO m, Storable a) => Ptr a -> a -> m ()
poke ptr = liftIO . Foreign.poke ptr
{-# INLINEABLE poke #-}

-- | Release a Haskell function pointer from any 'MonadIO' context; see
-- 'Foreign.freeHaskellFunPtr' for which pointers may be freed.
freeHaskellFunPtr :: MonadIO m => FunPtr a -> m ()
freeHaskellFunPtr funPtr = liftIO (Foreign.freeHaskellFunPtr funPtr)
{-# INLINEABLE freeHaskellFunPtr #-}

-- | Set @sz@ bytes starting at the pointer to the given byte value, from any
-- 'MonadIO' context.
fillBytes :: MonadIO m => Ptr a -> Word8 -> Int -> m ()
fillBytes ptr byte count = liftIO $ Foreign.fillBytes ptr byte count
{-# INLINEABLE fillBytes #-}

-- | Overwrite the given number of bytes, starting at the pointer, with zeros.
zeroMemory :: MonadIO m => Ptr a -> Int -> m ()
zeroMemory ptr count = fillBytes ptr 0 count
{-# INLINEABLE zeroMemory #-}

-------------------------------------------------------------------------------

-- | Run a 'ContT' computation whose answer type is its own result, finishing
-- with 'pure' as the final continuation.
evalContT :: Monad m => ContT a m a -> m a
evalContT = flip runContT pure

-------------------------------------------------------------------------------

-- | Build the 'show' string for a type that wraps an opaque pointer, in the
-- form @\<Name:0x...\>@.
showWithPtr ::
  -- | Name of the type.
  String ->
  -- | Opaque pointer that the type contains.
  Ptr a ->
  -- | Final show string.
  String
showWithPtr name ptr = concat ["<", name, ":", show ptr, ">"]
{-# INLINEABLE showWithPtr #-}