{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      : WGPU.Internal.Memory
-- Description : Managing memory.
--
-- This module contains type classes used to manage marshalling of objects into
-- memory before calling C functions.
--
-- = Motivation
--
-- In many locations in the API, we have:
--
-- A type (example only) which contains a nice Haskell representation of
-- some data:
--
-- @
-- data ApiType = ApiType { things :: Vector Thing }
-- @
--
-- and a raw type which is required for a C function:
--
-- @
-- data WGPUApiType = WGPUApiType
--   { thingsCount :: 'Word8',            -- this is an array length
--     things      :: 'Ptr' WGPUApiThing  -- this is a pointer to an array
--   }
-- @
--
-- This type class constraint represents the ability to encode @ApiType@ as
-- @WGPUApiType@, performing any necessary memory allocation and freeing:
--
-- @
-- 'ToRaw' ApiType WGPUApiType
-- @
--
-- 'ToRaw' uses the 'ContT' monad so that bracketing of the memory resources
-- can be performed around some continuation that uses the memory.
--
-- In the example above, we could write a 'ToRaw' instance as follows:
--
-- @
-- instance 'ToRaw' ApiType WGPUApiType where
--   'raw' ApiType{..} = do
--     things_ptr <- 'rawArrayPtr' things
--     'pure' $ WGPUApiType
--       { thingsCount = fromIntegral . length $ things,
--         things      = things_ptr
--       }
-- @
--
-- The 'ToRawPtr' type class represents similar functionality, except that it
-- creates a pointer to a value. Thus it does both raw conversion and storing
-- the raw value in allocated memory. It exists as a separate type class so
-- that library types (eg. 'Text' and 'ByteString') can be marshalled into
-- pointers more easily.
module WGPU.Internal.Memory
  ( -- * Classes
    ToRaw (raw),
    ToRawPtr (rawPtr),

    -- * Functions
    rawArrayPtr,
    showWithPtr,
  )
where

import Control.Exception (finally)
import Control.Monad (forM_)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Cont (ContT (ContT), evalContT)
import Data.ByteString (ByteString)
import qualified Data.ByteString as ByteString
import Data.ByteString.Unsafe (unsafeUseAsCString)
import Data.Text (Text)
import qualified Data.Text as Text
import Data.Text.Encoding (encodeUtf8)
import Data.Vector.Generic (Vector)
import qualified Data.Vector.Generic as Vector
import Data.Word (Word8)
import Foreign
  ( Ptr,
    Storable,
    advancePtr,
    allocaArray,
    castPtr,
    fillBytes,
    poke,
    sizeOf,
    with,
  )
import Foreign.C (CBool, CChar, withCString)

-------------------------------------------------------------------------------
-- Type Classes

-- | Represents a value of type @a@ that can be stored as type @b@ in the
-- 'ContT' monad.
--
-- The functional dependency @a -> b@ means that each Haskell type @a@ has
-- exactly one raw representation @b@.
--
-- Implementations of this type class should bracket any resource management for
-- creating the @b@ value around the continuation. For example, memory to hold
-- elements of @b@ should be allocated and freed in a bracketed fashion.
class ToRaw a b | a -> b where
  -- | Convert a value to a raw representation, bracketing any resource
  -- management around the continuation (whose result type @c@ is left
  -- polymorphic).
  raw :: a -> ContT c IO b

-- | Represents a value of type @a@ that can be stored as type @(Ptr b)@ in the
-- 'ContT' monad.
--
-- Unlike 'ToRaw', this class has no functional dependency from @a@ to @b@, so
-- the pointee type @b@ may need to be fixed by the calling context.
--
-- Implementations of this type class should bracket resource management for
-- creating @('Ptr' b)@ around the continuation. In particular, the memory
-- allocated for @('Ptr' b)@ must be allocated before the continuation is
-- called, and freed afterward.
class ToRawPtr a b where
  -- | Produce a pointer to a raw representation of a value. Because the
  -- memory is freed once the continuation returns, the pointer must not be
  -- retained beyond the continuation.
  rawPtr :: a -> ContT c IO (Ptr b)

-------------------------------------------------------------------------------
-- Derived Functionality

-- | Return a pointer to an allocated array, populated with raw values from a
-- vector.
--
-- An array of @'Vector.length' xs@ elements is allocated with 'allocaArray'.
-- Each element is converted with 'raw' and 'poke'd at its index, in order.
-- Resources acquired by the element conversions are released when the
-- continuation returns, before the array itself is freed. For an empty
-- vector the array has zero elements.
rawArrayPtr ::
  forall v a b c.
  (ToRaw a b, Storable b, Vector v a) =>
  -- | Vector of values to store in a C array.
  v a ->
  -- | Pointer to the array with raw values stored in it.
  ContT c IO (Ptr b)
rawArrayPtr xs =
  ContT $ \k ->
    allocaArray (Vector.length xs) $ \arrayStart ->
      -- The inner ContT is run inside the allocation so that every element's
      -- bracketed resources stay alive while the continuation uses the array.
      evalContT $ do
        forM_ (zip [0 ..] (Vector.toList xs)) $ \(offset, item) -> do
          rawItem <- raw item
          liftIO (poke (advancePtr arrayStart offset) rawItem)
        liftIO (k arrayStart)

-------------------------------------------------------------------------------
-- Instances

-- | Every 'ToRaw' conversion whose raw type is 'Storable' is also a
-- 'ToRawPtr' conversion. The value is first converted with 'raw', then stored
-- in temporary memory via 'zeroingWith'. That memory is valid for the duration
-- of the continuation.
instance {-# OVERLAPPABLE #-} (Storable b, ToRaw a b) => ToRawPtr a b where
  rawPtr value = raw value >>= \rawValue -> ContT (zeroingWith rawValue)

-- | Encode a 'Bool' as a C boolean: 'False' becomes @0@ and 'True' becomes @1@.
instance ToRaw Bool CBool where
  raw False = pure 0
  raw True = pure 1

-- | Marshal 'Text' as a NUL-terminated, UTF-8 encoded C string.
--
-- The text is encoded explicitly as UTF-8. 'Foreign.C.withCString' is not
-- used because it encodes with the current locale's encoding, which may not be
-- UTF-8 (e.g. under a C/POSIX locale). The string is a copy that is freed
-- automatically once the continuation returns. If the text contains a NUL
-- character, C code will see the string as ending at that character.
instance ToRawPtr Text CChar where
  rawPtr txt = ContT (ByteString.useAsCString (encodeUtf8 txt))

-- | Expose a 'ByteString' as a pointer to its bytes, without copying.
--
-- The pointer comes from 'unsafeUseAsCString'. It shares the buffer of the
-- 'ByteString', so it must not be written through. The bytes are also not
-- guaranteed to be NUL-terminated.
instance ToRawPtr ByteString Word8 where
  rawPtr byteString = ContT $ \k ->
    unsafeUseAsCString byteString $ \cString ->
      k (castPtr cString)

-------------------------------------------------------------------------------
-- Utils

-- | Like 'with', but zeroes the memory once the action has finished.
--
-- Allocates memory for a value of type @a@ and fills it with the
-- 'Foreign.Storable' representation of the @a@ value, then passes a pointer
-- to that memory to the action. After the action returns, or throws an
-- exception, the @'sizeOf' value@ bytes at the pointer are overwritten with
-- zeros before the memory is released. The action's result, or its exception,
-- is passed on unchanged.
zeroingWith ::
  Storable a =>
  -- | Value to use.
  a ->
  -- | Action to perform with a pointer to the value.
  (Ptr a -> IO b) ->
  -- | Result of running the action.
  IO b
zeroingWith value action =
  with value $ \valuePtr ->
    -- 'finally' makes the zeroing happen on the exceptional path too.
    -- Running the action and then calling 'fillBytes' in sequence would skip
    -- the zeroing whenever the action throws.
    action valuePtr `finally` fillBytes valuePtr 0x00 (sizeOf value)

-- | Render an opaque-pointer wrapper for a 'Show' instance.
--
-- Produces @\<Name:pointer\>@, where the pointer is shown with its own 'Show'
-- instance.
showWithPtr ::
  -- | Name of the type.
  String ->
  -- | Opaque pointer that the type contains.
  Ptr a ->
  -- | Final show string.
  String
showWithPtr name ptr = concat ["<", name, ":", show ptr, ">"]