{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module LLVM.Extra.Representation (
   Memory(load, store, decompose, compose), modify, castStorablePtr,
   MemoryRecord, MemoryElement, memoryElement,
   loadRecord, storeRecord, decomposeRecord, composeRecord,
   loadNewtype, storeNewtype, decomposeNewtype, composeNewtype,

   newForeignPtrInit, newForeignPtrParam,
   newForeignPtr, withForeignPtr,
   malloc, free,
   ) where

import qualified LLVM.Core as LLVM
import LLVM.Core
   (MakeValueTuple,
    Struct, getElementPtr0,
    extractvalue, insertvalue,
    Value, valueOf, Vector,
    IsType, IsSized,
    CodeGenFunction, )
import LLVM.Util.Loop (Phi, )

import qualified Foreign.Marshal.Utils as Marshal
import qualified Foreign.ForeignPtr as FPtr
import qualified Foreign.Concurrent as FC
import Foreign.Storable (Storable, poke, )
import Foreign.Ptr (Ptr, castPtr, FunPtr, )
import Data.TypeLevel.Num (d0, d1, d2, D4, )
import Data.Word (Word32, Word64, )
-- import Data.Word (Word8, Word16, Word32, Word64, )
-- import Data.Int  (Int8,  Int16,  Int32,  Int64, )

import Control.Monad (ap, )
import Control.Applicative (pure, liftA2, liftA3, )
import qualified Control.Applicative as App

import Data.Tuple.HT (fst3, snd3, thd3, )


-- * Memory class and helper functions

{- |
Relates an LLVM-side value @llvmValue@
to the memory layout @llvmStruct@ that is used for storing it.

Whoever implements both 'MakeValueTuple' and 'Memory' for a type
must ensure that @haskellValue@ is compatible with @llvmStruct@.
That is, writing and reading @llvmStruct@ by LLVM
must be the same as accessing @haskellValue@ by 'Storable' methods.

The functional dependency @llvmValue -> llvmStruct@
is there in order to let type inference work nicely.

'load' and 'store' have default implementations
based on 'decompose' and 'compose', respectively.
-}
class (Phi llvmValue, IsType llvmStruct) =>
      Memory llvmValue llvmStruct | llvmValue -> llvmStruct where
   -- | Read a value through a pointer to its memory representation.
   load :: Value (Ptr llvmStruct) -> CodeGenFunction r llvmValue
   load ptr  =  LLVM.load ptr >>= decompose
   -- | Write a value through a pointer to its memory representation.
   store :: llvmValue -> Value (Ptr llvmStruct) -> CodeGenFunction r (Value ())
   store value ptr  =  compose value >>= \struct -> LLVM.store struct ptr
   -- | Convert the memory representation into the LLVM-side value.
   decompose :: Value llvmStruct -> CodeGenFunction r llvmValue
   -- | Convert the LLVM-side value into its memory representation.
   compose :: llvmValue -> CodeGenFunction r (Value llvmStruct)

-- | Load the value behind the pointer, transform it with the given
-- code generating function and store the result back to the same place.
modify ::
   (Memory llvmValue llvmStruct) =>
   (llvmValue -> CodeGenFunction r llvmValue) ->
   Value (Ptr llvmStruct) -> CodeGenFunction r (Value ())
modify f ptr = do
   old <- load ptr
   new <- f old
   store new ptr


-- | A 'MemoryElement' whose result type is the very record type @v@
-- that it is fed with when storing.
-- This is the type expected by 'loadRecord', 'storeRecord',
-- 'decomposeRecord' and 'composeRecord'.
type MemoryRecord r o v = MemoryElement r o v v

{- |
Bundle of accessors for a value that lives inside an LLVM object of type @o@.

Type parameters:
@r@ is the result type parameter of the enclosing 'CodeGenFunction',
@o@ is the LLVM type of the whole object,
@v@ is the type of the Haskell-side record that is passed in when writing,
@x@ is the type of what is obtained when reading.
-}
data MemoryElement r o v x =
   MemoryElement {
      -- | read through a pointer to the object
      loadElement :: Value (Ptr o) -> CodeGenFunction r x,
      -- | write through a pointer to the object, taking the data from a @v@
      storeElement :: Value (Ptr o) -> v -> CodeGenFunction r (Value ()),
      -- | read from an object that is given as a first-class value
      extractElement :: Value o -> CodeGenFunction r x,
      -- | return a copy of the object value updated with data from a @v@
      insertElement :: v -> Value o -> CodeGenFunction r (Value o)
         -- State.Monoid
   }

{- |
Describe a single field of an LLVM object.

The first argument selects the field from the Haskell-side record @v@,
the second argument is the index of the field within the LLVM object @o@.
Reading and writing of the field itself is delegated
to the 'Memory' instance of the field type.
-}
memoryElement ::
   (Memory x llvmStruct,
    LLVM.GetValue o n llvmStruct,
    LLVM.GetElementPtr o (n, ()) llvmStruct) =>
   (v -> x) -> n -> MemoryElement r o v x
memoryElement field n =
   MemoryElement {
      -- address the field first, then access it
      loadElement = \ptr ->
         getElementPtr0 ptr (n, ()) >>= load,
      storeElement = \ptr v ->
         getElementPtr0 ptr (n, ()) >>= store (field v),
      -- the same on first-class values instead of pointers
      extractElement = \o ->
         extractvalue o n >>= decompose,
      insertElement = \v o ->
         compose (field v) >>= \x -> insertvalue o x n
   }

-- | Mapping affects only the reading accessors,
-- the writing accessors are passed through unchanged.
instance Functor (MemoryElement r o v) where
   fmap f m =
      MemoryElement {
         -- reading: post-process the result
         loadElement = \ptr -> fmap f (loadElement m ptr),
         extractElement = \o -> fmap f (extractElement m o),
         -- writing: nothing to adapt
         storeElement = storeElement m,
         insertElement = insertElement m
      }

{- |
'pure' is an element that does not touch the object at all:
Reading yields the given constant and writing changes nothing.
@f \<*\> x@ performs the accesses of @f@ first and then those of @x@;
when reading, the result of @f@ is applied to the result of @x@.
-}
instance App.Applicative (MemoryElement r o v) where
   pure x =
      MemoryElement {
         loadElement = const (return x),
         -- no store instruction is generated, thus there is no result value
         storeElement = \ _ptr _v ->
            return (error "MemoryElement: undefined value" :: Value ()),
         extractElement = const (return x),
         insertElement = const return
      }
   f <*> x =
      MemoryElement {
         loadElement = \ptr ->
            loadElement f ptr >>= \g ->
            loadElement x ptr >>= \a ->
            return (g a),
         -- the result of the first store is dropped
         storeElement = \ptr y ->
            storeElement f ptr y >> storeElement x ptr y,
         extractElement = \o ->
            extractElement f o >>= \g ->
            extractElement x o >>= \a ->
            return (g a),
         -- thread the object value through both updates
         insertElement = \y o ->
            insertElement x y =<< insertElement f y o
      }


-- | Read a complete record through a pointer to the LLVM object.
loadRecord ::
   MemoryRecord r o llvmValue ->
   Value (Ptr o) -> CodeGenFunction r llvmValue
loadRecord m ptr = loadElement m ptr

-- | Write a complete record through a pointer to the LLVM object.
-- The argument order matches the one of 'store'.
storeRecord ::
   MemoryRecord r o llvmValue ->
   llvmValue -> Value (Ptr o) -> CodeGenFunction r (Value ())
storeRecord m = flip (storeElement m)

-- | Split an LLVM object that is given as a first-class value into a record.
decomposeRecord ::
   MemoryRecord r o llvmValue ->
   Value o -> CodeGenFunction r llvmValue
decomposeRecord = extractElement

-- | Build an LLVM object value from a record
-- by inserting the record into an undefined object value.
composeRecord ::
   (IsType o) =>
   MemoryRecord r o llvmValue ->
   llvmValue -> CodeGenFunction r (Value o)
composeRecord m v =
   insertElement m v blank
   where
      -- starting point for the insertions
      blank = LLVM.value LLVM.undef



-- | Memory description of a pair:
-- the first component lives in struct field 0, the second one in field 1.
pairMemory ::
   (Memory al as, Memory bl bs,
    IsSized as sas, IsSized bs sbs) =>
   MemoryRecord r (Struct (as, (bs, ()))) (al, bl)
pairMemory =
   fmap (,) (memoryElement fst d0)
      App.<*> memoryElement snd d1

-- | Pairs are stored as two-element structs, see 'pairMemory'.
instance
      (Memory al as, Memory bl bs,
       IsSized as sas, IsSized bs sbs) =>
      Memory (al, bl) (Struct (as, (bs, ()))) where
   load ptr = loadRecord pairMemory ptr
   store pair ptr = storeRecord pairMemory pair ptr
   decompose struct = decomposeRecord pairMemory struct
   compose pair = composeRecord pairMemory pair


-- | Memory description of a triple:
-- the components live in the struct fields 0, 1 and 2, in this order.
tripleMemory ::
   (Memory al as, Memory bl bs, Memory cl cs,
    IsSized as sas, IsSized bs sbs, IsSized cs scs) =>
   MemoryRecord r (Struct (as, (bs, (cs, ())))) (al, bl, cl)
tripleMemory =
   fmap (,,) (memoryElement fst3 d0)
      App.<*> memoryElement snd3 d1
      App.<*> memoryElement thd3 d2

-- | Triples are stored as three-element structs, see 'tripleMemory'.
instance
      (Memory al as, Memory bl bs, Memory cl cs,
       IsSized as sas, IsSized bs sbs, IsSized cs scs) =>
      Memory (al, bl, cl) (Struct (as, (bs, (cs, ())))) where
   load ptr = loadRecord tripleMemory ptr
   store triple ptr = storeRecord tripleMemory triple ptr
   decompose struct = decomposeRecord tripleMemory struct
   compose triple = composeRecord tripleMemory triple


-- | A first-class LLVM value is its own memory representation:
-- loading and storing map directly to the LLVM instructions
-- and no conversion is needed.
instance (LLVM.IsFirstClass a) => Memory (Value a) a where
   load ptr = LLVM.load ptr
   store x ptr = LLVM.store x ptr
   decompose x = return x
   compose x = return x

-- | The unit value is represented by an empty struct.
-- No instructions are generated for accessing it.
instance Memory () (Struct ()) where
   load = const (return ())
   -- there is no store instruction whose result we could return
   store _ _ = return (error "().store: no result" :: Value ())
   decompose = const (return ())
   compose () = return (LLVM.value LLVM.undef)

-- | Reinterpret a pointer to a Haskell value
-- as a pointer to the corresponding LLVM memory representation.
-- The class constraints determine the target type from the source type.
castStorablePtr ::
   (MakeValueTuple haskellValue llvmValue, Memory llvmValue llvmStruct) =>
   Ptr haskellValue -> Ptr llvmStruct
castStorablePtr ptr = castPtr ptr



-- | Implementation of 'load' for a wrapper type:
-- load the wrapped value and apply the wrapping function.
loadNewtype ::
   (Memory a o) =>
   (a -> llvmValue) ->
   Value (Ptr o) -> CodeGenFunction r llvmValue
loadNewtype wrap = fmap wrap . load

-- | Implementation of 'store' for a wrapper type:
-- remove the wrapper and store the wrapped value.
storeNewtype ::
   (Memory a o) =>
   (llvmValue -> a) ->
   llvmValue -> Value (Ptr o) -> CodeGenFunction r (Value ())
storeNewtype unwrap = store . unwrap

-- | Implementation of 'decompose' for a wrapper type:
-- decompose to the wrapped value and apply the wrapping function.
decomposeNewtype ::
   (Memory a o) =>
   (a -> llvmValue) ->
   Value o -> CodeGenFunction r llvmValue
decomposeNewtype wrap = fmap wrap . decompose

-- | Implementation of 'compose' for a wrapper type:
-- remove the wrapper and compose the wrapped value.
composeNewtype ::
   (Memory a o) =>
   (llvmValue -> a) ->
   llvmValue -> CodeGenFunction r (Value o)
composeNewtype unwrap = compose . unwrap




-- * ForeignPtr support

-- | Type of a function that turns a function pointer
-- into the Haskell function of type @f@.
-- Used as the type of the \"dynamic\" foreign imports below.
type Importer f = FunPtr f -> f

-- | Make a parameterless function that is given by a function pointer
-- and that returns a pointer callable from Haskell.
-- Used by 'newForeignPtrInit'.
foreign import ccall safe "dynamic" derefStartPtr ::
   Importer (IO (Ptr a))

{- |
Call the @start@ function in order to obtain a pointer
and register the @stop@ function as finalizer for that pointer.
Note the argument order: @stop@ comes first, @start@ second.
-}
newForeignPtrInit ::
   FunPtr (Ptr a -> IO ()) ->
   FunPtr (IO (Ptr a)) ->
   IO (FPtr.ForeignPtr a)
newForeignPtrInit stop start = do
   state <- derefStartPtr start
   FPtr.newForeignPtr stop state


-- | Make a function that is given by a function pointer,
-- takes a pointer argument and returns a pointer, callable from Haskell.
-- Used by 'newForeignPtrParam'.
foreign import ccall safe "dynamic" derefStartParamPtr ::
   Importer (Ptr b -> IO (Ptr a))

{-
We cannot use 'bracket' when constructing lazy StorableVector,
since this would mean that the temporary memory is freed immediately.
Instead we must add a Finalizer to the ForeignPtr.
-}
{- |
Like 'newForeignPtrInit' but the @start@ function
additionally gets a pointer to a parameter value.
The parameter is written to temporary memory by 'Marshal.with'
that is only valid while @start@ is running.
The pointer returned by @start@ gets @stop@ as finalizer.
-}
newForeignPtrParam ::
   (Storable b, MakeValueTuple b bl, Memory bl bp) =>
   FunPtr (Ptr a -> IO ()) ->
   FunPtr (Ptr bp -> IO (Ptr a)) ->
   b -> IO (FPtr.ForeignPtr a)
newForeignPtrParam stop start b = do
   state <-
      Marshal.with b $ \paramPtr ->
         derefStartParamPtr start (castStorablePtr paramPtr)
   FPtr.newForeignPtr stop state

{-
requires (Storable ap) constraint
and we have no Storable instance for Struct

newForeignPtr ::
   (Storable a, MakeValueTuple a al, Memory al ap) =>
   a -> IO (FPtr.ForeignPtr ap)
newForeignPtr a = do
   ptr <- FPtr.mallocForeignPtr
   FPtr.withForeignPtr ptr (flip poke a . castPtr)
   return ptr
-}

{- |
Allocate memory that is managed by a 'FPtr.ForeignPtr',
initialize it with the given value
and attach a finalizer that is an ordinary Haskell 'IO' action.

Rationale:
Adding the finalizer to a ForeignPtr seems to be the only way
that guarantees execution of the finalizer (not too early and not never).
However, the normal ForeignPtr finalizers must be independent from Haskell runtime.
In contrast to ForeignPtr finalizers,
addFinalizer adds finalizers to boxes, that are optimized away.
Thus finalizers are run too early or not at all.
Concurrent.ForeignPtr and using threaded execution
is the only way to get finalizers in Haskell IO.
-}
newForeignPtr ::
   Storable a =>
   IO () ->
   a -> IO (FPtr.ForeignPtr a)
newForeignPtr finalizer a = do
   fp <- FPtr.mallocForeignPtr
   -- register the Haskell finalizer before the memory is filled
   FC.addForeignPtrFinalizer fp finalizer
   FPtr.withForeignPtr fp $ \ptr -> poke ptr a
   return fp

-- | Variant of 'FPtr.withForeignPtr'
-- where the action receives the pointer
-- cast to the LLVM memory representation of the stored Haskell type.
withForeignPtr ::
   (Storable a, MakeValueTuple a al, Memory al ap) =>
   FPtr.ForeignPtr a -> (Ptr ap -> IO b) -> IO b
withForeignPtr fp func =
   FPtr.withForeignPtr fp $ \ptr -> func (castStorablePtr ptr)


{-
malloc :: (IsSized a s) => CodeGenFunction r (Value (Ptr a))
malloc = LLVM.malloc

free :: (IsSized a s) => Value (Ptr a) -> CodeGenFunction r (Value ())
free = LLVM.free
-}


-- | Layout used by 'malloc':
-- the payload @a@ followed by a pointer field.
-- 'malloc' stores the address of the originally allocated memory
-- in that pointer field.
type Aligned a = Struct (a, (Ptr (Vector D4 Float), ()))
-- | Pointer to an 'Aligned' object.
type AlignedPtr a = Ptr (Aligned a)

{- |
Returns 16 Byte aligned piece of memory.
Otherwise program crashes when vectors are part of the structure.
I think that malloc in LLVM-2.5 and LLVM-2.6 is simply buggy.

How it works:
We allocate a pad of type @Vector D4 Float@ followed by an 'Aligned' object.
Then we take the address of the 'Aligned' object,
round it down to a multiple of 16 and place the 'Aligned' object there.
The address that was returned by the underlying @LLVM.malloc@
is stored in the pointer field of the 'Aligned' object
and the address of the payload field is returned.
Memory obtained this way must be released with 'free' from this module.

NOTE(review):
This presumably relies on the pad being at least 16 bytes large,
such that the rounded address stays within the allocated memory -
to be confirmed for the targeted platforms.

FIXME:
Aligning to 16 Byte might not be appropriate for all vector types on all platforms.
Maybe we should use alignment of Storable class
in order to determine the right alignment.
-}
malloc :: (IsSized a s) => CodeGenFunction r (Value (Ptr a))
malloc =
   let m :: (IsSized a s) =>
            CodeGenFunction r (Value (Ptr (Struct (Vector D4 Float, (Aligned a, ())))))
       m = LLVM.malloc
   in  do p <- m
          -- skip pad
          p1 <- getElementPtr0 p (d1, ())
          p1int <- LLVM.ptrtoint p1
          -- go back to the last 16 byte aligned address
          -- (-16 as Word64 is a mask with the lowest four bits cleared)
          p16int <- LLVM.and (valueOf (-16) :: Value Word64) (p1int :: Value Word64)
          p16 <- LLVM.inttoptr p16int
          {-
          v has the same address as p but a different type.
          This way we avoid a recursive datatype and we also avoid a cast.
          -}
          v <- getElementPtr0 p (d0, ())
          -- remember the originally allocated address behind the payload;
          -- asTypeOf fixes the type of the pointer that inttoptr returned
          store v =<< getElementPtr0 (p16 `asTypeOf` p1) (d1, ())
          -- the caller gets the address of the payload
          getElementPtr0 p16 (d0, ())

{-
This is correct but will be optimized incorrectly.
The "optimized" code will access a pointer
that is 4 cells greater than the right pointer
for certain sizes of the record @a@.

free :: (IsSized a s) => Value (Ptr a) -> CodeGenFunction r (Value ())
free p =
   LLVM.free =<<
   load =<<
   flip getElementPtr0 (d1, ()) =<<
   (LLVM.bitcastUnify ::
      (IsSized a sa) =>
      Value (Ptr a) ->
      CodeGenFunction r (Value (AlignedPtr a))) p
-}

{- |
Release memory that was allocated by 'malloc' from this module.

Read from bottom to top:
We compute the address of the element that follows the payload @p@,
reinterpret that address as a pointer to a stored pointer,
load the stored pointer and pass it to @LLVM.free@.

NOTE(review):
This presumably finds the pointer that 'malloc' wrote
to the second field of the 'Aligned' object.
That requires that there is no padding
between the payload and the pointer field - to be confirmed.
The pointer is loaded with a different pointee type
than the one it was stored with in 'malloc'.
-}
free :: (IsSized a s) => Value (Ptr a) -> CodeGenFunction r (Value ())
free p =
   LLVM.free =<<
   load =<<
   -- only the type changes here, the address remains the same
   (LLVM.bitcastUnify ::
      (IsSized a sa) =>
      Value (Ptr a) ->
      CodeGenFunction r (Value (Ptr (AlignedPtr a)))) =<<
   -- address immediately after the payload, computed without a struct index
   LLVM.getElementPtr p (1 :: Word32, ())