{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{- |
Transfer values between Haskell and JIT generated code
in an LLVM-compatible format.
E.g. 'Bool' is stored as 'i1' and occupies a byte,
@'Vector' n 'Bool'@ is stored as a bit vector,
@'Vector' n 'Word8'@ is stored in an order depending on machine endianness,
and Haskell tuples are stored as LLVM structs.
-}
module LLVM.Extra.Nice.Value.Marshal (
   C(..),
   Struct,
   peek,
   poke,

   VectorStruct,
   Vector(..),

   with,
   EE.alloca,
   ) where

import qualified LLVM.Extra.Nice.Vector as NiceVector
import qualified LLVM.Extra.Nice.Value.Private as NiceValue
import qualified LLVM.Extra.Memory as Memory
import LLVM.Extra.Nice.Vector.Instance ()

import qualified LLVM.ExecutionEngine as EE
import qualified LLVM.Core as LLVM

import qualified Type.Data.Num.Decimal as TypeNum

import qualified Control.Functor.HT as FuncHT
import Control.Applicative (liftA2, liftA3, (<$>))

import Foreign.Storable (Storable)
import Foreign.StablePtr (StablePtr)
import Foreign.Ptr (FunPtr, Ptr)

import Data.Complex (Complex((:+)))
import Data.Word (Word8, Word16, Word32, Word64, Word)
import Data.Int  (Int8,  Int16,  Int32,  Int64)



-- | Read a value of type @a@ from memory that holds its marshalled
-- representation ('Struct' @a@).
-- The raw struct is read with 'EE.peek' and converted with 'unpack'.
peek ::
   (C a, Struct a ~ struct, EE.Marshal struct) => LLVM.Ptr struct -> IO a
peek = fmap unpack . EE.peek

-- | Write a value of type @a@ to memory in its marshalled representation
-- ('Struct' @a@).
-- The value is converted with 'pack' and stored with 'EE.poke'.
poke ::
   (C a, Struct a ~ struct, EE.Marshal struct) => LLVM.Ptr struct -> a -> IO ()
poke ptr value = EE.poke ptr (pack value)


-- | The in-memory LLVM type used to store a Haskell value of type @a@:
-- the 'Memory.Struct' of the value's 'NiceValue.Repr'.
type Struct a = Memory.Struct (NiceValue.Repr a)

-- | Haskell types that can be transferred to and from JIT-generated code.
--
-- 'pack' converts a Haskell value into its storage form 'Struct' @a@, and
-- 'unpack' converts it back.
class
   ( NiceValue.C a
   , Memory.C (NiceValue.Repr a)
   , EE.Marshal (Struct a)
   , LLVM.IsConst (Struct a)
   ) =>
      C a where
   -- | Convert a Haskell value to its LLVM-compatible storage form.
   pack :: a -> Struct a
   -- | Convert the LLVM-compatible storage form back to a Haskell value.
   unpack :: Struct a -> a

-- Scalar types: the storage form is the value itself, so both
-- conversions are the identity.
instance C Bool where
   pack = id
   unpack = id

instance C Float where
   pack = id
   unpack = id

instance C Double where
   pack = id
   unpack = id

instance C Word where
   pack = id
   unpack = id

instance C Word8 where
   pack = id
   unpack = id

instance C Word16 where
   pack = id
   unpack = id

instance C Word32 where
   pack = id
   unpack = id

instance C Word64 where
   pack = id
   unpack = id

instance C Int where
   pack = id
   unpack = id

instance C Int8 where
   pack = id
   unpack = id

instance C Int16 where
   pack = id
   unpack = id

instance C Int32 where
   pack = id
   unpack = id

instance C Int64 where
   pack = id
   unpack = id

-- Pointer-like types: transferred unchanged (identity conversions).
instance (Storable a) => C (Ptr a) where
   pack = id
   unpack = id

instance (LLVM.IsType a) => C (LLVM.Ptr a) where
   pack = id
   unpack = id

instance (LLVM.IsFunction a) => C (FunPtr a) where
   pack = id
   unpack = id

instance C (StablePtr a) where
   pack = id
   unpack = id

-- The unit value is stored wrapped in an 'LLVM.Struct'.
instance C () where
   pack unit = LLVM.Struct unit
   unpack struct = case struct of LLVM.Struct unit -> unit

-- Pairs are stored as an LLVM struct of the two marshalled components.
instance (C a, C b) => C (a,b) where
   pack (x,y) = LLVM.consStruct (pack x) (pack y)
   unpack =
      LLVM.uncurryStruct (\sx sy -> (unpack sx, unpack sy))

-- Triples are stored as an LLVM struct of the three marshalled components.
instance (C a, C b, C c) => C (a,b,c) where
   pack (x,y,z) = LLVM.consStruct (pack x) (pack y) (pack z)
   unpack =
      LLVM.uncurryStruct (\sx sy sz -> (unpack sx, unpack sy, unpack sz))

-- Quadruples are stored as an LLVM struct of the four marshalled components.
instance (C a, C b, C c, C d) => C (a,b,c,d) where
   pack (w,x,y,z) = LLVM.consStruct (pack w) (pack x) (pack y) (pack z)
   unpack =
      LLVM.uncurryStruct
         (\sw sx sy sz -> (unpack sw, unpack sx, unpack sy, unpack sz))


-- Complex numbers are stored as a two-field LLVM struct:
-- the real part followed by the imaginary part.
instance (C a) => C (Complex a) where
   pack z = case z of re :+ im -> LLVM.consStruct (pack re) (pack im)
   unpack =
      LLVM.uncurryStruct (\re im -> unpack re :+ unpack im)



-- | The in-memory LLVM type used to store an @n@-lane vector with
-- elements of type @a@: the 'Memory.Struct' of its 'NiceVector.Repr'.
type VectorStruct n a = Memory.Struct (NiceVector.Repr n a)

-- | Element types @a@ for which an @n@-lane 'LLVM.Vector' can be
-- transferred to and from JIT-generated code.
--
-- The storage form is 'VectorStruct' @n a@. For primitive elements it is
-- the 'LLVM.Vector' itself. For tuple elements it is a struct holding one
-- vector per component.
class
   ( TypeNum.Positive n
   , C a
   , NiceVector.C a
   , Memory.C (NiceVector.Repr n a)
   , EE.Marshal (VectorStruct n a)
   , LLVM.IsConst (VectorStruct n a)
   ) =>
      Vector n a where
   -- | Convert an LLVM vector to its storage form.
   packVector :: LLVM.Vector n a -> VectorStruct n a
   -- | Convert the storage form back to an LLVM vector.
   unpackVector :: VectorStruct n a -> LLVM.Vector n a

-- Vectors are marshalled through the element type's 'Vector' instance.
instance (TypeNum.Positive n, Vector n a) => C (LLVM.Vector n a) where
   pack = packVector
   unpack = unpackVector


-- Primitive element types: the storage form of an @n@-lane vector is the
-- 'LLVM.Vector' itself, so both conversions are the identity.
-- Each context multiplies the lane count @n@ by the element's bit width
-- (1 for 'Bool', 'LLVM.IntSize' for 'Int'/'Word'). NOTE(review):
-- presumably this is required by LLVM.Core to size the vector type; confirm.
instance (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D1))
      => Vector n Bool where
   packVector = id; unpackVector = id

instance (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D32))
      => Vector n Float where
   packVector = id; unpackVector = id

instance (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D64))
      => Vector n Double where
   packVector = id; unpackVector = id

instance (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: LLVM.IntSize))
      => Vector n Word where
   packVector = id; unpackVector = id

instance (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D8))
      => Vector n Word8 where
   packVector = id; unpackVector = id

instance (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D16))
      => Vector n Word16 where
   packVector = id; unpackVector = id

instance (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D32))
      => Vector n Word32 where
   packVector = id; unpackVector = id

instance (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D64))
      => Vector n Word64 where
   packVector = id; unpackVector = id

instance (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: LLVM.IntSize))
      => Vector n Int where
   packVector = id; unpackVector = id

instance (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D8))
      => Vector n Int8 where
   packVector = id; unpackVector = id

instance (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D16))
      => Vector n Int16 where
   packVector = id; unpackVector = id

instance (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D32))
      => Vector n Int32 where
   packVector = id; unpackVector = id

instance (TypeNum.Positive n, TypeNum.Natural (n TypeNum.:*: TypeNum.D64))
      => Vector n Int64 where
   packVector = id; unpackVector = id

-- A vector of pairs is stored as a struct of two vectors, one holding
-- the first components and one holding the second components.
instance (Vector n a, Vector n b) => Vector n (a,b) where
   packVector =
      (\(firsts, seconds) ->
         LLVM.consStruct (packVector firsts) (packVector seconds))
      . FuncHT.unzip
   unpackVector =
      LLVM.uncurryStruct
         (\sa sb -> liftA2 (,) (unpackVector sa) (unpackVector sb))

-- A vector of triples is stored as a struct of three vectors, one per
-- tuple component.
instance (Vector n a, Vector n b, Vector n c) => Vector n (a,b,c) where
   packVector =
      (\(firsts, seconds, thirds) ->
         LLVM.consStruct
            (packVector firsts) (packVector seconds) (packVector thirds))
      . FuncHT.unzip3
   unpackVector =
      LLVM.uncurryStruct
         (\sa sb sc ->
            liftA3 (,,) (unpackVector sa) (unpackVector sb) (unpackVector sc))


-- | Obtain storage for the marshalled form of a value via 'EE.alloca',
-- write the value into it with 'poke', then run the action on the pointer.
-- NOTE(review): the pointer's validity presumably ends when the action
-- returns (as with 'EE.alloca'); confirm before letting it escape.
with :: (C a) => a -> (LLVM.Ptr (Struct a) -> IO b) -> IO b
with value act =
   EE.alloca $ \ptr -> do
      poke ptr value
      act ptr
