{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
module LLVM.Extra.Memory (
   C(load, store, decompose, compose), modify, castStorablePtr,
   Record, Element, element,
   loadRecord, storeRecord, decomposeRecord, composeRecord,
   loadNewtype, storeNewtype, decomposeNewtype, composeNewtype,
   FirstClass,
   ) where

import LLVM.Extra.Class (MakeValueTuple, Undefined, )
import qualified LLVM.Extra.ArithmeticPrivate as A
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Array as Array

import qualified LLVM.Core as LLVM
import LLVM.Core
   (Struct, getElementPtr0,
    extractvalue, insertvalue,
    Value, -- valueOf, Vector,
    IsType, IsSized,
    CodeGenFunction, )
import LLVM.Util.Loop (Phi, )

import qualified Data.TypeLevel.Num as TypeNum
import Data.TypeLevel.Num (d0, d1, d2, )

import Foreign.StablePtr (StablePtr, )
import Foreign.Ptr (Ptr, castPtr, )

import Data.Word (Word8, Word16, Word32, Word64, )
import Data.Int  (Int8,  Int16,  Int32,  Int64, )

import Control.Monad (ap, )
import Control.Applicative (pure, liftA2, liftA3, )
import qualified Control.Applicative as App

import Data.Tuple.HT (fst3, snd3, thd3, )


{- |
Whoever implements both 'MakeValueTuple' and 'Memory.C' for a type
has to guarantee that @haskellValue@ and @llvmStruct@ are compatible.
That is, reading and writing @llvmStruct@ from LLVM code
must give the same result as accessing @haskellValue@
through the 'Storable' methods.
ToDo: In future we may also require a Storable constraint for llvmStruct.

The functional dependency from @llvmValue@ to @llvmStruct@
is there in order to let type inference work nicely.
-}
class (Phi llvmValue, Undefined llvmValue, IsType llvmStruct) =>
      C llvmValue llvmStruct | llvmValue -> llvmStruct where
   -- | Read an @llvmValue@ from memory.
   -- By default the whole @llvmStruct@ is loaded
   -- and then passed to 'decompose'.
   load :: Value (Ptr llvmStruct) -> CodeGenFunction r llvmValue
   load ptr  =  LLVM.load ptr >>= decompose
   -- | Write an @llvmValue@ to memory.
   -- By default the value is converted by 'compose'
   -- and the resulting @llvmStruct@ is stored as a whole.
   store :: llvmValue -> Value (Ptr llvmStruct) -> CodeGenFunction r ()
   store v ptr  =  compose v >>= \struct -> LLVM.store struct ptr
   -- | Convert the memory representation into an @llvmValue@.
   decompose :: Value llvmStruct -> CodeGenFunction r llvmValue
   -- | Convert an @llvmValue@ into its memory representation.
   compose :: llvmValue -> CodeGenFunction r (Value llvmStruct)

{- |
Load the value that the pointer refers to,
transform it by the given code generating function
and store the result at the same address.
-}
modify ::
   (C llvmValue llvmStruct) =>
   (llvmValue -> CodeGenFunction r llvmValue) ->
   Value (Ptr llvmStruct) -> CodeGenFunction r ()
modify f ptr = do
   old <- load ptr
   new <- f old
   store new ptr


{- |
A 'Record' is an 'Element'
where the type of the loaded result
equals the type of the value that is stored.
-}
type Record r o v = Element r o v v

{- |
Bundle of four access functions for an LLVM object of type @o@.
The functions for reading yield a value of type @x@,
whereas the functions for writing consume a value of type @v@.
The parameter @r@ is only passed on to 'CodeGenFunction'.
-}
data Element r o v x =
   Element {
      -- read an @x@ from memory via a pointer to the object
      loadElement :: Value (Ptr o) -> CodeGenFunction r x,
      -- write (parts of) a @v@ to memory via a pointer to the object
      storeElement :: Value (Ptr o) -> v -> CodeGenFunction r (),
      -- get an @x@ out of an object value
      extractElement :: Value o -> CodeGenFunction r x,
      -- put (parts of) a @v@ into an object value and return the updated object
      insertElement :: v -> Value o -> CodeGenFunction r (Value o)
         -- State.Monoid
   }

{- |
Describe the access to the member with index @n@ of the object @o@.
The function @field@ selects from the complete value @v@
the part @x@ that is written to this member.
Reading the member yields that part @x@.
-}
element ::
   (C x llvmStruct,
    LLVM.GetValue o n llvmStruct,
    LLVM.GetElementPtr o (n, ()) llvmStruct) =>
   (v -> x) -> n -> Element r o v x
element field n =
   Element {
      -- memory access: first compute the address of the member
      loadElement = \ptr -> getElementPtr0 ptr (n, ()) >>= load,
      storeElement = \ptr v -> getElementPtr0 ptr (n, ()) >>= store (field v),
      -- value access: operate on the member of the aggregate value
      extractElement = \o -> extractvalue o n >>= decompose,
      insertElement = \v o -> compose (field v) >>= \x -> insertvalue o x n
   }

{- |
Mapping a function over an 'Element'
transforms the results of loading and extracting.
Storing and inserting are taken over unchanged.
-}
instance Functor (Element r o v) where
   fmap f m =
      Element {
         loadElement = \ptr -> fmap f (loadElement m ptr),
         storeElement = \ptr v -> storeElement m ptr v,
         extractElement = \o -> fmap f (extractElement m o),
         insertElement = \v o -> insertElement m v o
      }

{- |
'pure' is an 'Element' that touches neither memory nor the object
and always yields the given value.
'<*>' performs the actions of both operands, the left one first,
and applies the result of the left one to the result of the right one.
-}
instance App.Applicative (Element r o v) where
   pure x =
      Element {
         loadElement = const (return x),
         storeElement = \ _ptr _v -> return (),
         extractElement = const (return x),
         insertElement = const return
      }
   f <*> x =
      Element {
         loadElement = \ptr ->
            loadElement f ptr >>= \g ->
            loadElement x ptr >>= \a ->
            return (g a),
         storeElement = \ptr y ->
            do {storeElement f ptr y; storeElement x ptr y},
         extractElement = \o ->
            extractElement f o >>= \g ->
            extractElement x o >>= \a ->
            return (g a),
         -- the object updated by the left operand is passed to the right one
         insertElement = \y o -> insertElement x y =<< insertElement f y o
      }


{- |
Read a complete record from memory
using the load function of the record description.
-}
loadRecord ::
   Record r o llvmValue ->
   Value (Ptr o) -> CodeGenFunction r llvmValue
loadRecord m ptr = loadElement m ptr

{- |
Write a complete record to memory.
This is 'storeElement' with value and pointer argument swapped.
-}
storeRecord ::
   Record r o llvmValue ->
   llvmValue -> Value (Ptr o) -> CodeGenFunction r ()
storeRecord m = flip (storeElement m)

{- |
Split an object value into the record value
using the extract function of the record description.
-}
decomposeRecord ::
   Record r o llvmValue ->
   Value o -> CodeGenFunction r llvmValue
decomposeRecord = extractElement

{- |
Build an object value from a record value.
It starts with an undefined object
and applies the insert function of the record description to it.
-}
composeRecord ::
   (IsType o) =>
   Record r o llvmValue ->
   llvmValue -> CodeGenFunction r (Value o)
composeRecord m v =
   insertElement m v blank
   where
      -- object with all members undefined
      blank = LLVM.value LLVM.undef



{- |
Record description for a pair that is stored in a struct with two members:
The first component belongs to member 0, the second one to member 1.
-}
pair ::
   (C al as, C bl bs,
    IsSized as sas, IsSized bs sbs) =>
   Record r (Struct (as, (bs, ()))) (al, bl)
pair =
   fmap (,) (element fst d0) App.<*> element snd d1

-- Pairs are accessed member by member according to the 'pair' description.
instance
      (C al as, C bl bs,
       IsSized as sas, IsSized bs sbs) =>
      C (al, bl) (Struct (as, (bs, ()))) where
   load ptr = loadRecord pair ptr
   store v ptr = storeRecord pair v ptr
   decompose struct = decomposeRecord pair struct
   compose v = composeRecord pair v


{- |
Record description for a triple that is stored in a struct with three members:
The components belong to the members 0, 1 and 2 in this order.
-}
triple ::
   (C al as, C bl bs, C cl cs,
    IsSized as sas, IsSized bs sbs, IsSized cs scs) =>
   Record r (Struct (as, (bs, (cs, ())))) (al, bl, cl)
triple =
   fmap (,,) (element fst3 d0)
      App.<*> element snd3 d1
      App.<*> element thd3 d2

-- Triples are accessed member by member according to the 'triple' description.
instance
      (C al as, C bl bs, C cl cs,
       IsSized as sas, IsSized bs sbs, IsSized cs scs) =>
      C (al, bl, cl) (Struct (as, (bs, (cs, ())))) where
   load ptr = loadRecord triple ptr
   store v ptr = storeRecord triple v ptr
   decompose struct = decomposeRecord triple struct
   compose v = composeRecord triple v


{-
This would not work for Booleans,
since on x86 LLVM's @i1@ type uses one byte in memory,
whereas Storable uses 4 bytes and 4-byte alignment.

instance (LLVM.IsFirstClass a) => C (Value a) a where
   load = LLVM.load
   store = LLVM.store
   decompose = return
   compose = return
-}


{- |
Relates a first-class LLVM type @llvmType@
to the type @llvmStruct@ that is used for it in memory.
The functional dependency lets @llvmType@ determine @llvmStruct@.
-}
class (LLVM.IsFirstClass llvmType, IsType llvmStruct) =>
      FirstClass llvmType llvmStruct | llvmType -> llvmStruct where
   -- | Convert from the memory representation to the @llvmType@ value.
   fromStorable :: Value llvmStruct -> CodeGenFunction r (Value llvmType)
   -- | Convert from the @llvmType@ value to the memory representation.
   toStorable :: Value llvmType -> CodeGenFunction r (Value llvmStruct)

-- For the following number types
-- the memory representation is the type itself,
-- thus both conversions just return their argument.

instance FirstClass Float Float where
   fromStorable = return
   toStorable = return

instance FirstClass Double Double where
   fromStorable = return
   toStorable = return

instance FirstClass Int8 Int8 where
   fromStorable = return
   toStorable = return

instance FirstClass Int16 Int16 where
   fromStorable = return
   toStorable = return

instance FirstClass Int32 Int32 where
   fromStorable = return
   toStorable = return

instance FirstClass Int64 Int64 where
   fromStorable = return
   toStorable = return

instance FirstClass Word8 Word8 where
   fromStorable = return
   toStorable = return

instance FirstClass Word16 Word16 where
   fromStorable = return
   toStorable = return

instance FirstClass Word32 Word32 where
   fromStorable = return
   toStorable = return

instance FirstClass Word64 Word64 where
   fromStorable = return
   toStorable = return
-- A Bool is kept in memory as a Word32.
instance FirstClass Bool Word32 where
   -- every stored word different from zero yields True
   fromStorable stored = A.cmp LLVM.CmpNE (LLVM.value LLVM.zero) stored
   -- the flag is extended with zeros to the full word
   toStorable flag = LLVM.zext flag
-- Vectors are converted by applying the element conversion via 'Vector.map'.
instance
   (LLVM.Pos n, LLVM.IsPrimitive a, LLVM.IsPrimitive am, FirstClass a am) =>
      FirstClass (LLVM.Vector n a) (LLVM.Vector n am) where
   fromStorable stored = Vector.map fromStorable stored
   toStorable vec = Vector.map toStorable vec
-- Arrays are converted by applying the element conversion via 'Array.map'.
instance
   (LLVM.Nat n, LLVM.IsFirstClass am,
    FirstClass a am, IsSized a asize, IsSized am amsize) =>
      FirstClass (LLVM.Array n a) (LLVM.Array n am) where
   fromStorable stored = Array.map fromStorable stored
   toStorable arr = Array.map toStorable arr

-- Pointers are stored as they are.
instance (IsType a) => FirstClass (Ptr a) (Ptr a) where
   fromStorable = return
   toStorable = return
-- Stable pointers are stored as they are.
instance FirstClass (StablePtr a) (StablePtr a) where
   fromStorable = return
   toStorable = return


{-
Structs are converted member by member
between @Struct s@ and the memory representation @Struct sm@.
The work is done by 'decomposeField' and 'composeField',
that are started at member index 'd0' with the complete field lists.
-}
instance
   (LLVM.IsFirstClass (Struct s),
    IsType (Struct sm),
    ConvertStruct s sm TypeNum.D0 s sm) =>
      FirstClass (Struct s) (Struct sm) where
   fromStorable sm =
      -- 'sfields' has no usable value, it only carries the field list type @s@.
      -- It is bound by 'case', presumably in order to keep its type monomorphic,
      -- such that 'asTypeOf' below can unify it with the type of the result.
      case undefined of
         sfields -> do
            s <- decomposeField sfields (fields sm) d0 sm
            -- this binding is never used, it only fixes the type of 'sfields'
            let _ = asTypeOf (fields s) sfields
            return s
   toStorable s =
      -- same technique as in 'fromStorable' with the roles of @s@ and @sm@ swapped
      case undefined of
         smfields -> do
            sm <- composeField (fields s) smfields d0 s
            -- this binding is never used, it only fixes the type of 'smfields'
            let _ = asTypeOf (fields sm) smfields
            return sm

{- |
Provide the field list type of a struct value.
The result is 'undefined' and must only be used for its type.
-}
fields :: Value (Struct s) -> s
fields = const undefined

{- |
Conversion of structs member by member.
@s@ and @sm@ are the complete field lists
of the struct and its memory representation, respectively.
@rem@ and @remm@ are the field lists that still have to be processed
and @i@ is the index of the first member of these remaining lists.
-}
class
   ConvertStruct s sm i rem remm |
      s -> sm, rem -> remm, s rem -> i, sm remm -> i where
   -- | Convert the members starting at index @i@
   -- from the memory representation to the struct.
   -- The first two arguments describe the remaining field lists.
   decomposeField ::
      rem -> remm ->
      i -> Value (Struct sm) ->
      CodeGenFunction r (Value (Struct s))
   -- | Convert the members starting at index @i@
   -- from the struct to the memory representation.
   -- The first two arguments describe the remaining field lists.
   composeField ::
      rem -> remm ->
      i -> Value (Struct s) ->
      CodeGenFunction r (Value (Struct sm))

{-
Non-empty remaining field list:
First all members behind index @i@ are converted,
then the member at index @i@ is converted
and inserted into the struct obtained so far.
The field list arguments are matched lazily,
since the 'FirstClass' instance for structs passes 'undefined' values for them.
-}
instance
   (LLVM.GetValue (Struct s) i a,
    LLVM.GetValue (Struct sm) i am,
    FirstClass a am,
    ConvertStruct s sm i' rem remm,
    TypeNum.Succ i i') =>
      ConvertStruct s sm i (a,rem) (am,remm) where
   decomposeField ~(_,rest) ~(_,restm) i sm = do
      partial <- decomposeField rest restm (TypeNum.succ i) sm
      stored <- LLVM.extractvalue sm i
      a <- fromStorable stored
      LLVM.insertvalue partial a i
   composeField ~(_,rest) ~(_,restm) i s = do
      partial <- composeField rest restm (TypeNum.succ i) s
      member <- LLVM.extractvalue s i
      am <- toStorable member
      LLVM.insertvalue partial am i

{-
Empty remaining field list:
There is nothing left to convert,
thus the result is a struct with all members undefined.
-}
instance
   (IsType (Struct s),
    IsType (Struct sm)) =>
      ConvertStruct s sm i () () where
   decomposeField _rem _remm _i _sm =
      return $ LLVM.value LLVM.undef
   composeField _rem _remm _i _s =
      return $ LLVM.value LLVM.undef

-- Single LLVM values are converted by the 'FirstClass' methods.
-- 'load' and 'store' are the defaults of the class.
instance (FirstClass a am) => C (Value a) am where
   decompose stored = fromStorable stored
   compose val = toStorable val


-- The unit value carries no data,
-- thus no method emits a memory access.
instance C () (Struct ()) where
   load = const (return ())
   store = \ _unit _ptr -> return ()
   decompose = const (return ())
   compose = const (return (LLVM.value LLVM.undef))

{- |
Turn a pointer to a Haskell value
into a pointer to the corresponding LLVM memory structure.
This is 'castPtr' with a type
that is restricted by the 'MakeValueTuple' and 'C' constraints.
-}
castStorablePtr ::
   (MakeValueTuple haskellValue llvmValue, C llvmValue llvmStruct) =>
   Ptr haskellValue -> Ptr llvmStruct
castStorablePtr ptr = castPtr ptr



{- |
Load a value of the wrapped type
and apply the given wrapping function to the result.
-}
loadNewtype ::
   (C a o) =>
   (a -> llvmValue) ->
   Value (Ptr o) -> CodeGenFunction r llvmValue
loadNewtype wrap =
   fmap wrap . load

{- |
Apply the given unwrapping function to the value
and store the result.
-}
storeNewtype ::
   (C a o) =>
   (llvmValue -> a) ->
   llvmValue -> Value (Ptr o) -> CodeGenFunction r ()
storeNewtype unwrap =
   store . unwrap

{- |
Decompose the memory representation into a value of the wrapped type
and apply the given wrapping function to the result.
-}
decomposeNewtype ::
   (C a o) =>
   (a -> llvmValue) ->
   Value o -> CodeGenFunction r llvmValue
decomposeNewtype wrap =
   fmap wrap . decompose

{- |
Apply the given unwrapping function to the value
and compose the memory representation from the result.
-}
composeNewtype ::
   (C a o) =>
   (llvmValue -> a) ->
   llvmValue -> CodeGenFunction r (Value o)
composeNewtype unwrap =
   compose . unwrap