{-# LANGUAGE AllowAmbiguousTypes   #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MagicHash             #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE UnboxedTuples         #-}
{-# LANGUAGE UndecidableInstances  #-}
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- | This module provides an orphan instance of `PrimBytes` for Vulkan
--   structures (`VkStruct` types that have a `VulkanMarshal` instance).
--   This enables them to be stored in @DataFrames@ from the @easytensor@ package.
--   Thanks to the internal structure of Vulkan structures, they can be
--   manipulated inside DataFrames very efficiently (just by copying byte arrays).
--   However, @DataFrames@ may be backed by unpinned arrays, whose address can
--   change during garbage collection;
--   functions here check this and copy data to new pinned arrays if needed.
--
--   In addition to the orphan instance, this module provides a few
--   handy helper functions.
module Graphics.Vulkan.Marshal.Create.DataFrame
  ( setVec, getVec
  , fillDataFrame, withDFPtr, setDFRef
  ) where


import Foreign.Storable
import GHC.Exts                         (unsafeCoerce#)
import GHC.Base
import GHC.Ptr                          (Ptr (..))
import Graphics.Vulkan
import Graphics.Vulkan.Marshal.Create
import Graphics.Vulkan.Marshal.Internal
import Numeric.DataFrame
import Numeric.DataFrame.IO
import Numeric.Dimensions
import Numeric.PrimBytes


-- | Write an array of values in one go.
--
--   The whole vector is written with a single 'pokeByteOff' at the byte
--   offset of field @fname@ within the structure being created.
--   The @KnownBackends@ evidence obtained from 'inferKnownBackend' is needed
--   to resolve the instances of the vector type that 'pokeByteOff' relies on.
setVec :: forall fname x t
        . ( FieldType fname x ~ t
          , PrimBytes t
          , KnownDim (FieldArrayLength fname x)
          , CanWriteFieldArray fname x
          )
       => Vector t (FieldArrayLength fname x) -> CreateVkStruct x '[fname] ()
setVec v
  | Dict <- inferKnownBackend @t @'[FieldArrayLength fname x]
  = unsafeIOCreate $ \p -> pokeByteOff p (fieldOffset @fname @x) v


-- | Get an array of values, possibly without copying
--   (if vector implementation allows).
--
--   The vector is built with 'fromBytes' directly over the structure's
--   backing `ByteArray#`. Its starting offset is the structure's own offset
--   inside that array plus the byte offset of field @fname@.
getVec :: forall fname x t
        . ( FieldType fname x ~ t
          , PrimBytes t
          , KnownDim (FieldArrayLength fname x)
          , CanReadFieldArray fname x
          )
       => x -> Vector t (FieldArrayLength fname x)
getVec x
  | ba        <- unsafeByteArray x
    -- offset of the structure data from the start of its backing array
  , structOff <- minusAddr# (unsafeAddr x) (byteArrayContents# ba)
  , I# fieldOff <- fieldOffset @fname @x
  , Dict      <- inferKnownBackend @t @'[FieldArrayLength fname x]
  = fromBytes (structOff +# fieldOff) ba

-- | Vulkan structures are backed by a `ByteArray#` (see 'unsafeByteArray'),
--   so they can be moved in and out of DataFrames by plain byte copying.
--   From the `PrimBytes` point of view the structure exposes no fields.
instance VulkanMarshal (VkStruct a) => PrimBytes (VkStruct a) where
    type PrimFields (VkStruct a) = '[]

    byteSize a = case sizeOf a of I# s -> s
    {-# INLINE byteSize #-}

    byteAlign a = case alignment a of I# n -> n
    {-# INLINE byteAlign #-}

    -- Offset of the structure data from the start of its backing array.
    byteOffset a = minusAddr# (unsafeAddr a)
                              (byteArrayContents# (unsafeByteArray a))
    {-# INLINE byteOffset #-}

    getBytes = unsafeByteArray
    {-# INLINE getBytes #-}

    fromBytes = unsafeFromByteArrayOffset
    {-# INLINE fromBytes #-}

    -- Allocate a fresh structure via 'newVkData' and fill it by copying
    -- @byteSize@ bytes from @mba@ starting at @off@.
    -- 'unsafeCoerce#' reinterprets the 'IO' action as a state-token function
    -- over the caller's @s@, and the array as a 'RealWorld' one.
    readBytes mba off = unsafeCoerce# (newVkData @(VkStruct a) fill)
      where
        fill :: Ptr (VkStruct a) -> IO ()
        fill (Ptr addr) = IO $ \s ->
          (# copyMutableByteArrayToAddr# (unsafeCoerce# mba) off addr
               (byteSize @(VkStruct a) undefined) s
           , () #)

    -- Copy from the structure's backing 'ByteArray#' rather than from its raw
    -- 'Addr#'. Passing the array itself to the primop keeps it reachable for
    -- the whole copy; a bare address does not.
    writeBytes mba off a
      = copyByteArray# (unsafeByteArray a) (byteOffset a) mba off (byteSize a)

    readAddr addr = unsafeCoerce# (peek (Ptr addr) :: IO (VkStruct a))

    writeAddr a addr s
      = case unsafeCoerce# (poke (Ptr addr) a :: IO ()) s of
          (# s', () #) -> s'

    -- PrimFields is empty, so no field has an offset; -1 is returned
    -- (presumably easytensor's "no such field" marker).
    byteFieldOffset _ _ = negateInt# 1#



-- | Run some operation with a pointer to the first item in the frame.
--   All items of the frame are kept in a contiguous memory area accessed by
--   that pointer.
--
--   The function attempts to get an underlying `ByteArray#` without data copy;
--   otherwise, it creates a new pinned `ByteArray#` and passes a pointer to it.
--   Therefore:
--
--     * Sometimes, @Ptr a@ points to the original DF; sometimes, to a copied
--       one. Treat the pointed-to data as read-only.
--     * If the original DF is based on unpinned `ByteArray#`, using this
--       performs a copy anyway.
--
--   The array is kept reachable (via `touch#`) until the operation returns,
--   so the pointer must not escape the operation.
withDFPtr :: forall (a :: Type) (ds :: [Nat]) (b :: Type)
           . (PrimBytes a, Dimensions ds)
          => DataFrame a ds -> (Ptr a -> IO b) -> IO b
withDFPtr x k
  | Dict <- inferKnownBackend @a @ds
  , ba   <- getBytesPinned x
  = do
      r <- k (Ptr (byteArrayContents# ba `plusAddr#` byteOffset x))
      -- keep the pinned array alive until k is done with the pointer
      IO $ \s -> (# touch# ba s, () #)
      return r

-- | A variant of `setVkRef` that writes a pointer to a contiguous array of
--   structures.
--
--   Write a pointer to a vulkan structure - member of current structure
--    and make sure the member exists as long as this structure exists.
--
--   Prefer this function to using @unsafePtr a@, because the latter
--    does not keep the dependency information in GC, which results in
--    member structure being garbage-collected and the reference being invalid.
--
--   The data frame is first made pinned via `getBytesPinned` (copying if
--   needed). An action that `touch#`es that array is then registered alongside
--   the write.
--   NOTE(review): this builds a 'CreateVkStruct' through 'unsafeCoerce#'
--   because its constructor is hidden; it relies on that type's internal
--   representation @Ptr x -> IO (([Ptr ()], [IO ()]), a)@. Confirm whenever
--   "Graphics.Vulkan.Marshal.Create" changes.
setDFRef :: forall fname x a (ds :: [Nat])
          . ( CanWriteField fname x
            , FieldType fname x ~ Ptr a
            , PrimBytes a, Dimensions ds
            )
         => DataFrame a ds -> CreateVkStruct x '[fname] ()
setDFRef v
  | Dict <- inferKnownBackend @a @ds
  , ba   <- getBytesPinned v
  , addr <- byteArrayContents# ba `plusAddr#` byteOffset v
  = let f :: Ptr x -> IO (([Ptr ()], [IO ()]), ())
        f p = (,) ([], [IO $ \s -> (# touch# ba s, () #)])
                <$> writeField @fname @x p (Ptr addr)
    in  unsafeCoerce# f -- workaround for the hidden CreateVkStruct constr.


-- | Given the number of elements, create a new pinned DataFrame and initialize
--   it using the provided function.
--
--   The argument function is called one time with a `Ptr` pointing to the
--   beginning of a contiguous array.
--   This array is converted into a dataframe, possibly without copying.
--
--   It is safe to pass result of this function to `withDFPtr`.
fillDataFrame :: forall a
               . PrimBytes a
              => Word -> (Ptr a -> IO ()) -> IO (Vector a (XN 0))
fillDataFrame n k
    -- lift the runtime length to a type-level dimension @n@
  | Dx (D :: Dim n) <- someDimVal n
  , Dict <- inferKnownBackend @a @'[n]
  = do
      mdf <- newPinnedDataFrame
      withDataFramePtr mdf k
      XFrame <$> unsafeFreezeDataFrame @a @'[n] mdf
-- Unreachable: both guards above always match; this clause exists only
-- to make the definition total for the pattern-match checker.
fillDataFrame _ _ = error "fillDataFrame: impossible combination of arguments."