{- |
A 'Marshal' class that is compatible with LLVM's data layout.
The most prominent difference from 'Foreign.Storable.Storable' is that LLVM's @i1@
occupies one byte in memory, whereas the 'Foreign.Storable.Storable' instance of
Haskell's 'Bool' uses a 32-bit word.
In addition, this class supports 'Data.Struct', 'Data.Vector' and 'Data.Array'.
-}
module LLVM.ExecutionEngine.Marshal (
    Marshal(..),
    sizeOf,
    alignment,
    StructFields,
    sizeOfArray,
    pokeList,
    ) where

import qualified LLVM.Core.Vector as Vector ()
import qualified LLVM.Core.Data as Data
import qualified LLVM.Core.Type as Type
import qualified LLVM.Util.Proxy as LP
import qualified LLVM.ExecutionEngine.Target as Target
import LLVM.ExecutionEngine.Target (TargetData)

import qualified LLVM.FFI.Core as FFI

import qualified Type.Data.Num.Decimal.Number as Dec
import Type.Base.Proxy (Proxy(Proxy))

import qualified Foreign.Storable as Store
import Foreign.StablePtr (StablePtr)
import Foreign.Ptr (Ptr, FunPtr, castPtr, plusPtr)

import System.IO.Unsafe (unsafePerformIO)

import qualified Control.Monad.Trans.State as MS
import Control.Applicative (liftA2, pure)

import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import Data.Int (Int8, Int16, Int32, Int64)
import Data.Word (Word8, Word16, Word32, Word64)



-- | The 'TargetData' returned by 'Target.getTargetData'. Every size,
-- alignment and offset query in this module is answered against it.
--
-- 'unsafePerformIO' is used so that those queries can be pure functions.
-- The NOINLINE pragma keeps this a single shared top-level value. Without
-- it, GHC is free to inline the 'unsafePerformIO' call into use sites and
-- run 'Target.getTargetData' more than once.
--
-- NOTE(review): this relies on 'Target.getTargetData' having no observable
-- side effects and always describing the same target; confirm.
targetData :: TargetData
targetData = unsafePerformIO Target.getTargetData
{-# NOINLINE targetData #-}


-- | Size in bytes of one value of the proxied type, as reported by
-- 'Target.storeSizeOfType' for 'targetData'.
sizeOf :: (Type.IsType a) => LP.Proxy a -> Int
sizeOf proxy =
   let ref = Type.unsafeTypeRef proxy
   in  Target.storeSizeOfType targetData ref

-- | Alignment in bytes of the proxied type, as reported by
-- 'Target.abiAlignmentOfType' for 'targetData'.
alignment :: (Type.IsType a) => LP.Proxy a -> Int
alignment proxy =
   let ref = Type.unsafeTypeRef proxy
   in  Target.abiAlignmentOfType targetData ref

-- | Bytes occupied by @n@ consecutive elements of the proxied type: @n@
-- times the element's 'Target.abiSizeOfType'. Unlike 'sizeOf', this uses
-- the ABI size, which is the stride between elements.
sizeOfArray :: (Type.IsType a) => LP.Proxy a -> Int -> Int
sizeOfArray proxy n = n * elemSize
  where
    elemSize = Target.abiSizeOfType targetData (Type.unsafeTypeRef proxy)


-- | Values that can be written to and read from memory at an address
-- typed with the corresponding LLVM type (see the 'Type.IsType' superclass).
class (Type.IsType a) => Marshal a where
    -- | Store a value at the given address.
    poke :: Ptr a -> a -> IO ()
    -- | Load a value from the given address.
    peek :: Ptr a -> IO a

-- | 'peek' for types whose 'Store.Storable' representation is used as is.
peekPrimitive :: (Store.Storable a) => Ptr a -> IO a
peekPrimitive ptr = Store.peek ptr

-- | 'poke' for types whose 'Store.Storable' representation is used as is.
pokePrimitive :: (Store.Storable a) => Ptr a -> a -> IO ()
pokePrimitive ptr x = Store.poke ptr x

-- Scalar and pointer types: both directions delegate straight to their
-- 'Store.Storable' instances through 'peekPrimitive' / 'pokePrimitive'.

instance Marshal Float where
    peek = peekPrimitive
    poke = pokePrimitive

instance Marshal Double where
    peek = peekPrimitive
    poke = pokePrimitive

instance Marshal Int8 where
    peek = peekPrimitive
    poke = pokePrimitive

instance Marshal Int16 where
    peek = peekPrimitive
    poke = pokePrimitive

instance Marshal Int32 where
    peek = peekPrimitive
    poke = pokePrimitive

instance Marshal Int64 where
    peek = peekPrimitive
    poke = pokePrimitive

instance Marshal Word8 where
    peek = peekPrimitive
    poke = pokePrimitive

instance Marshal Word16 where
    peek = peekPrimitive
    poke = pokePrimitive

instance Marshal Word32 where
    peek = peekPrimitive
    poke = pokePrimitive

instance Marshal Word64 where
    peek = peekPrimitive
    poke = pokePrimitive

instance (Type.IsType a) => Marshal (Ptr a) where
    peek = peekPrimitive
    poke = pokePrimitive

instance (Type.IsFunction a) => Marshal (FunPtr a) where
    peek = peekPrimitive
    poke = pokePrimitive

instance Marshal (StablePtr a) where
    peek = peekPrimitive
    poke = pokePrimitive

-- | 'Bool' is marshalled through a single 'Word8' rather than its wider
-- 'Store.Storable' representation. Writing stores 1 for 'True' and 0 for
-- 'False'; reading treats any non-zero byte as 'True'.
instance Marshal Bool where
    peek ptr = do
        byte <- Store.peek (castBoolPtr ptr)
        return (byte /= 0)
    poke ptr b =
        Store.poke (castBoolPtr ptr) (if b then 1 else 0)

-- | View a 'Bool' address as the address of the byte that holds it.
castBoolPtr :: Ptr Bool -> Ptr Word8
castBoolPtr ptr = castPtr ptr

-- | LLVM arrays. Writing unwraps the element list and stores it via
-- 'pokeArray'; reading goes through 'peekArray'. Both methods are kept
-- point-free so the partially applied helpers can be shared.
instance
    (Type.Natural n, Marshal a, Type.IsSized a) =>
        Marshal (Data.Array n a) where
    poke = pokeArray elements
      where elements (Data.Array xs) = xs
    peek = peekArray Proxy LP.Proxy

-- | LLVM vectors. Writing flattens the vector with 'Fold.toList' and stores
-- the elements via 'pokeArray'; reading goes through 'peekVector'.
instance
    (Type.Positive n, Marshal a, Type.IsPrimitive a) =>
        Marshal (Data.Vector n a) where
    poke = pokeArray Fold.toList
    peek = peekVector Proxy LP.Proxy

-- | Read the @n@ elements of an array, starting at the array's address.
-- Consecutive elements are 'Target.abiSizeOfType' bytes apart and are read
-- in ascending address order.
peekArray ::
    (Type.Natural n, Marshal a) =>
    Proxy n -> LP.Proxy a ->
    Ptr (Data.Array n a) -> IO (Data.Array n a)
peekArray n proxy = \ptr ->
    let count = Dec.integralFromProxy n
        slots = take count (iterate (`plusPtr` step) (castElemPtr ptr))
    in  fmap Data.Array (mapM peek slots)
  where
    -- Bound outside the pointer lambda: it depends only on the element type.
    step = Target.abiSizeOfType targetData (Type.unsafeTypeRef proxy)

-- | Read the elements of a vector one slot at a time. The current slot
-- pointer is threaded through a 'MS.StateT' while traversing a @'pure' ()@
-- vector, so every position is filled in traversal order. Slots are
-- 'Target.abiSizeOfType' bytes apart.
--
-- NOTE(review): LLVM may bit-pack vectors whose elements are not
-- byte-sized (e.g. @\<n x i1\>@). The per-element stride used here would
-- then not match for @Data.Vector n Bool@; confirm.
peekVector ::
    (Type.Positive n, Marshal a) =>
    Proxy n -> LP.Proxy a ->
    Ptr (Data.Vector n a) -> IO (Data.Vector n a)
peekVector _n proxy = \ptr ->
    MS.evalStateT (Trav.traverse (\() -> peekNext) (pure ())) (castElemPtr ptr)
  where
    step = Target.abiSizeOfType targetData (Type.unsafeTypeRef proxy)
    -- Read at the current slot, then advance the slot by one element.
    peekNext = MS.StateT $ \slot ->
        fmap (\x -> (x, plusPtr slot step)) (peek slot)

-- | Store a container's elements at consecutive slots, starting at the
-- container's address. The first argument flattens the container to its
-- element list.
pokeArray :: (Marshal a) => (f a -> [a]) -> Ptr (f a) -> f a -> IO ()
pokeArray flatten ptr container =
    pokeList (castElemPtr ptr) (flatten container)

-- | Store a list of elements at consecutive slots starting at the given
-- address; see 'pokeListAux' for the stride.
pokeList :: (Marshal a) => Ptr a -> [a] -> IO ()
pokeList = pokeListAux elemProxy
  where elemProxy = LP.Proxy

-- | Store each list element in turn, starting at the given address and
-- advancing by the element type's 'Target.abiSizeOfType' after each one.
-- The proxy only fixes the element type.
pokeListAux :: (Marshal a) => LP.Proxy a -> Ptr a -> [a] -> IO ()
pokeListAux proxy = \ptr xs ->
    mapM_ (uncurry poke) (zip (iterate (`plusPtr` step) ptr) xs)
  where
    step = Target.abiSizeOfType targetData (Type.unsafeTypeRef proxy)

-- | View a container's address as the address of its first element.
castElemPtr :: Ptr (f a) -> Ptr a
castElemPtr ptr = castPtr ptr


-- | LLVM structs. The struct's type reference is looked up once per method
-- (outside the pointer lambda). Fields are then handled from index 0 by
-- 'peekStruct' / 'pokeStruct'.
instance (StructFields fields) => Marshal (Data.Struct fields) where
    peek = withPtrProxy $ \proxy ->
        let peekFields = peekStruct (Type.unsafeTypeRef proxy) 0
        in  \ptr -> fmap Data.Struct (peekFields ptr)
    poke = withPtrProxy $ \proxy ->
        let pokeFields = pokeStruct (Type.unsafeTypeRef proxy) 0
        in  \ptr (Data.Struct xs) -> pokeFields ptr xs

-- | Give an action a proxy for the pointee type of the pointer it will later
-- receive. This is deliberately not eta-expanded, so that work the action
-- does before taking the pointer can be shared between calls.
withPtrProxy :: (LP.Proxy a -> Ptr a -> b) -> Ptr a -> b
withPtrProxy = ($ LP.Proxy)

-- | Field lists of a struct that can be read and written field by field.
-- The instances locate each field with 'Target.offsetOfElement'.
class (Type.StructFields fields) => StructFields fields where
    -- | Read the fields from the given index onward.
    peekStruct ::
        FFI.TypeRef  -- ^ type of the whole struct
        -> Int       -- ^ index of the first field to read
        -> Ptr struct -- ^ address of the struct
        -> IO fields
    -- | Write the fields from the given index onward.
    pokeStruct ::
        FFI.TypeRef  -- ^ type of the whole struct
        -> Int       -- ^ index of the first field to write
        -> Ptr struct -- ^ address of the struct
        -> fields
        -> IO ()

-- | A non-empty field list: handle the head field at the offset of element
-- @i@, then recurse on the tail with index @i+1@. The offset and the
-- recursive call are bound outside the pointer lambda.
instance
    (Marshal a, Type.IsSized a, StructFields as) =>
        StructFields (a,as) where
    peekStruct typeRef i = \ptr ->
        liftA2 (,) (peek (plusPtr ptr offset)) (peekRest ptr)
      where
        offset = Target.offsetOfElement targetData typeRef i
        peekRest = peekStruct typeRef (i+1)
    pokeStruct typeRef i = \ptr (x, rest) -> do
        poke (plusPtr ptr offset) x
        pokeRest ptr rest
      where
        offset = Target.offsetOfElement targetData typeRef i
        pokeRest = pokeStruct typeRef (i+1)

-- | The empty field list ends the recursion: nothing is read or written.
instance StructFields () where
    peekStruct _ _ _ = pure ()
    pokeStruct _ _ _ () = pure ()