{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{- |
Transfer values between Haskell and JIT-generated code
in an LLVM-compatible format.
E.g. 'Bool' is stored as @i1@ and occupies a byte,
@'Vector' n 'Bool'@ is stored as a bit vector,
@'Vector' n 'Word8'@ is stored in an order depending on machine endianness,
and Haskell tuples are stored as LLVM structs.
-}
module LLVM.Extra.Marshal (
   C(..),
   Struct,
   peek,
   poke,
   MV,

   VectorStruct,
   Vector(..),

   with,
   EE.alloca,
   ) where

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.ExecutionEngine as EE
import qualified LLVM.Core as LLVM

import qualified Type.Data.Num.Decimal as TypeNum

import qualified Control.Functor.HT as FuncHT
import Control.Applicative (liftA2, liftA3, (<$>))

import Foreign.Storable (Storable)
import Foreign.StablePtr (StablePtr)
import Foreign.Ptr (FunPtr, Ptr)

import Data.Word (Word8, Word16, Word32, Word64, Word)
import Data.Int  (Int8,  Int16,  Int32,  Int64)



-- | Read a value of Haskell type @a@ from a pointer to its LLVM memory
-- representation, converting the loaded struct with 'unpack'.
peek ::
   (C a, Struct a ~ struct, EE.Marshal struct) => LLVM.Ptr struct -> IO a
peek = fmap unpack . EE.peek

-- | Store a Haskell value at a pointer in its LLVM memory representation,
-- converting it with 'pack' before writing.
poke ::
   (C a, Struct a ~ struct, EE.Marshal struct) => LLVM.Ptr struct -> a -> IO ()
poke ptr x = EE.poke ptr (pack x)


-- | The memory layout ('Memory.Struct') of the LLVM value representation
-- ('Tuple.ValueOf') that belongs to the Haskell type @a@.
type Struct a = Memory.Struct (Tuple.ValueOf a)

-- | Haskell types that can be transferred to and from JIT-generated code.
--
-- The superclasses guarantee that @'Struct' a@ has a known size and can be
-- read and written through "LLVM.ExecutionEngine".
class
   (Tuple.Value a, Memory.C (Tuple.ValueOf a),
    LLVM.IsSized (Struct a), EE.Marshal (Struct a)) =>
      C a where
   -- | Convert a Haskell value into its memory representation.
   pack :: a -> Struct a
   -- | Convert a memory representation back into the Haskell value.
   unpack :: Struct a -> a

-- Scalar types: the memory representation is the Haskell type itself,
-- so both directions of the conversion are the identity.
instance C Bool   where {unpack = id; pack = id}
instance C Float  where {unpack = id; pack = id}
instance C Double where {unpack = id; pack = id}
instance C Word   where {unpack = id; pack = id}
instance C Word8  where {unpack = id; pack = id}
instance C Word16 where {unpack = id; pack = id}
instance C Word32 where {unpack = id; pack = id}
instance C Word64 where {unpack = id; pack = id}
instance C Int    where {unpack = id; pack = id}
instance C Int8   where {unpack = id; pack = id}
instance C Int16  where {unpack = id; pack = id}
instance C Int32  where {unpack = id; pack = id}
instance C Int64  where {unpack = id; pack = id}

-- Pointer-like types are likewise stored unchanged.
instance (Storable a)        => C (Ptr a)       where {unpack = id; pack = id}
instance (LLVM.IsType a)     => C (LLVM.Ptr a)  where {unpack = id; pack = id}
instance (LLVM.IsFunction a) => C (FunPtr a)    where {unpack = id; pack = id}
instance                        C (StablePtr a) where {unpack = id; pack = id}

-- | The unit value is stored wrapped in an 'LLVM.Struct'.
instance C () where
   unpack s = case s of LLVM.Struct u -> u
   pack u = LLVM.Struct u

-- | Pairs are stored as LLVM structs holding the marshalled components
-- in order.
instance
   (LLVM.IsSized (Struct a), LLVM.IsSized (Struct b), C a, C b) =>
      C (a,b) where
   pack pair =
      case pair of
         (x,y) -> LLVM.consStruct (pack x) (pack y)
   unpack =
      LLVM.uncurryStruct (\sx sy -> (unpack sx, unpack sy))

-- | Triples are stored as LLVM structs holding the marshalled components
-- in order.
instance
   (LLVM.IsSized (Struct a), LLVM.IsSized (Struct b), LLVM.IsSized (Struct c),
    C a, C b, C c) =>
      C (a,b,c) where
   pack triple =
      case triple of
         (x,y,z) -> LLVM.consStruct (pack x) (pack y) (pack z)
   unpack =
      LLVM.uncurryStruct (\sx sy sz -> (unpack sx, unpack sy, unpack sz))

-- | Quadruples are stored as LLVM structs holding the marshalled components
-- in order.
instance
   (LLVM.IsSized (Struct a), LLVM.IsSized (Struct b),
    LLVM.IsSized (Struct c), LLVM.IsSized (Struct d),
    C a, C b, C c, C d) =>
      C (a,b,c,d) where
   pack quad =
      case quad of
         (w,x,y,z) -> LLVM.consStruct (pack w) (pack x) (pack y) (pack z)
   unpack =
      LLVM.uncurryStruct
         (\sw sx sy sz -> (unpack sw, unpack sx, unpack sy, unpack sz))



-- | The memory layout of the LLVM representation of a vector with @n@ lanes
-- of element type @a@; the vector counterpart of 'Struct'.
type VectorStruct n a = Memory.Struct (Tuple.VectorValueOf n a)

-- | Element types @a@ for which @'LLVM.Vector' n a@ can be transferred
-- to and from JIT-generated code.
--
-- For tuple element types the vector is stored component-wise,
-- as a struct of vectors.
class
   (TypeNum.Positive n,
    Tuple.VectorValue n a, Memory.C (Tuple.VectorValueOf n a),
    LLVM.IsSized (VectorStruct n a), EE.Marshal (VectorStruct n a)) =>
      Vector n a where
   -- | Convert a Haskell-side vector into its memory representation.
   packVector :: LLVM.Vector n a -> VectorStruct n a
   -- | Convert a memory representation back into a Haskell-side vector.
   unpackVector :: VectorStruct n a -> LLVM.Vector n a

-- | Vectors are marshalled by delegating to the element-type specific
-- methods of the 'Vector' class.
instance
   (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: LLVM.SizeOf a),
    Vector n a) =>
      C (LLVM.Vector n a) where
   unpack = unpackVector
   pack = packVector


-- Scalar element types: the vector's memory representation is the
-- 'LLVM.Vector' itself, so both conversions are the identity.
-- Each context requires the lane count @n@ multiplied by the element's
-- bit width (e.g. 'TypeNum.D8' for 'Word8', 'LLVM.IntSize' for 'Int')
-- to be a type-level natural number.
-- NOTE(review): presumably needed to establish the 'LLVM.IsSized'
-- superclass of 'Vector'; confirm.

instance
   (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D1)) =>
      Vector n Bool where {unpackVector = id; packVector = id}

instance
   (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D32)) =>
      Vector n Float where {unpackVector = id; packVector = id}

instance
   (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D64)) =>
      Vector n Double where {unpackVector = id; packVector = id}

instance
   (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: LLVM.IntSize)) =>
      Vector n Word where {unpackVector = id; packVector = id}

instance
   (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D8)) =>
      Vector n Word8 where {unpackVector = id; packVector = id}

instance
   (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D16)) =>
      Vector n Word16 where {unpackVector = id; packVector = id}

instance
   (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D32)) =>
      Vector n Word32 where {unpackVector = id; packVector = id}

instance
   (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D64)) =>
      Vector n Word64 where {unpackVector = id; packVector = id}

instance
   (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: LLVM.IntSize)) =>
      Vector n Int where {unpackVector = id; packVector = id}

instance
   (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D8)) =>
      Vector n Int8 where {unpackVector = id; packVector = id}

instance
   (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D16)) =>
      Vector n Int16 where {unpackVector = id; packVector = id}

instance
   (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D32)) =>
      Vector n Int32 where {unpackVector = id; packVector = id}

instance
   (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D64)) =>
      Vector n Int64 where {unpackVector = id; packVector = id}

-- | A vector of pairs is stored as a struct of two vectors,
-- one per component.
instance (Vector n a, Vector n b) => Vector n (a,b) where
   unpackVector =
      LLVM.uncurryStruct
         (\sa sb -> liftA2 (,) (unpackVector sa) (unpackVector sb))
   packVector v =
      case FuncHT.unzip v of
         (va,vb) -> LLVM.consStruct (packVector va) (packVector vb)

-- | A vector of triples is stored as a struct of three vectors,
-- one per component.
instance (Vector n a, Vector n b, Vector n c) => Vector n (a,b,c) where
   unpackVector =
      LLVM.uncurryStruct
         (\sa sb sc ->
            liftA3 (,,) (unpackVector sa) (unpackVector sb) (unpackVector sc))
   packVector v =
      case FuncHT.unzip3 v of
         (va,vb,vc) ->
            LLVM.consStruct (packVector va) (packVector vb) (packVector vc)



-- | Types that are both marshallable ('C') and instances of 'MultiValue.C'.
-- The class has no methods of its own; it only bundles both constraints.
class (MultiValue.C a, C a) => MV a

-- Method-less instances; each merely asserts both superclass constraints.
-- NOTE(review): unlike 'C', there is no 'MV' instance for 'Bool' here;
-- confirm whether that omission is intentional.
instance MV Float
instance MV Double
instance MV Word
instance MV Word8
instance MV Word16
instance MV Word32
instance MV Word64
instance MV Int
instance MV Int8
instance MV Int16
instance MV Int32
instance MV Int64

instance (Storable a)        => MV (Ptr a)
instance (LLVM.IsType a)     => MV (LLVM.Ptr a)
instance (LLVM.IsFunction a) => MV (FunPtr a)
instance                        MV (StablePtr a)

instance MV ()

instance
   (LLVM.IsSized (Struct a), LLVM.IsSized (Struct b), MV a, MV b) =>
      MV (a,b)

instance
   (LLVM.IsSized (Struct a), LLVM.IsSized (Struct b), LLVM.IsSized (Struct c),
    MV a, MV b, MV c) =>
      MV (a,b,c)

instance
   (LLVM.IsSized (Struct a), LLVM.IsSized (Struct b),
    LLVM.IsSized (Struct c), LLVM.IsSized (Struct d),
    MV a, MV b, MV c, MV d) =>
      MV (a,b,c,d)


-- | Allocate room for the memory representation of a value with
-- 'EE.alloca', store the value there using 'poke', and run the given
-- action on the resulting pointer, returning the action's result.
-- NOTE(review): presumably the pointer is only valid while the action
-- runs, as with "Foreign.Marshal.Alloc" 'alloca'; confirm.
with :: (C a) => a -> (LLVM.Ptr (Struct a) -> IO b) -> IO b
with x body =
   EE.alloca $ \ptr -> do
      poke ptr x
      body ptr