{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
module LLVM.Extra.Memory (
   C(load, store, decompose, compose), modify, castStorablePtr,
   Struct,
   Record, Element, element,
   loadRecord, storeRecord, decomposeRecord, composeRecord,
   loadNewtype, storeNewtype, decomposeNewtype, composeNewtype,
   FirstClass, Stored,
   ) where

import LLVM.Extra.Class (MakeValueTuple, ValueTuple, Undefined, )
import LLVM.Extra.MemoryPrivate (decomposeFromLoad, composeFromStore, )

import qualified LLVM.Extra.Multi.Vector.Memory as MultiVectorMemory
import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.ArithmeticPrivate as A
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Scalar as Scalar
import qualified LLVM.Extra.Array as Array
import qualified LLVM.Extra.Either as Either
import qualified LLVM.Extra.Maybe as Maybe

import qualified LLVM.Util.Proxy as LP
import qualified LLVM.Core as LLVM
import LLVM.Util.Loop (Phi, )
import LLVM.Core
   (getElementPtr0,
    extractvalue, insertvalue,
    Value, -- valueOf, Vector,
    IsType, IsSized,
    CodeGenFunction, )

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal (d0, d1, d2, )
import Type.Base.Proxy (Proxy(Proxy), )

import Foreign.StablePtr (StablePtr, )
import Foreign.Ptr (FunPtr, Ptr, castPtr, )

import Data.Word (Word8, Word16, Word32, Word64, )
import Data.Int  (Int8,  Int16,  Int32,  Int64, )

import Control.Monad (ap, )
import Control.Applicative (pure, liftA2, liftA3, )
import qualified Control.Applicative as App

import Data.Tuple.HT (fst3, snd3, thd3, )

import Prelude hiding (maybe, either, )


{- |
An implementation of both 'MakeValueTuple' and 'Memory.C'
must ensure that @haskellValue@ is compatible
with @Struct (ValueTuple haskellValue)@ (which we want to call @llvmStruct@).
That is, writing and reading @llvmStruct@ by LLVM
must be the same as accessing @haskellValue@ by 'Storable' methods.
ToDo: In future we may also require Storable constraint for @llvmStruct@.

The memory layout is given by the associated type 'Struct',
so it is determined by @llvmValue@ alone.

It suffices to define either 'load' or 'decompose'
and either 'store' or 'compose'.
The remaining methods have defaults
that are expressed in terms of the defined ones.
-}
class (Phi llvmValue, Undefined llvmValue, IsType (Struct llvmValue), IsSized (Struct llvmValue)) =>
      C llvmValue where
   {-# MINIMAL (load|decompose), (store|compose) #-}
   -- | LLVM type that represents @llvmValue@ in memory.
   type Struct llvmValue :: *
   -- | Read a value through a pointer to its memory representation.
   load :: Value (Ptr (Struct llvmValue)) -> CodeGenFunction r llvmValue
   -- Default: fetch the whole struct with LLVM.load and take it apart.
   load ptr  =  decompose =<< LLVM.load ptr
   -- | Write a value through a pointer to its memory representation.
   store :: llvmValue -> Value (Ptr (Struct llvmValue)) -> CodeGenFunction r ()
   -- Default: build the whole struct and write it with LLVM.store.
   store r ptr  =  flip LLVM.store ptr =<< compose r
   -- | Convert a memory representation held in a 'Value' to @llvmValue@.
   decompose :: Value (Struct llvmValue) -> CodeGenFunction r llvmValue
   -- Default: derived from 'load' by 'decomposeFromLoad'.
   decompose = decomposeFromLoad load
   -- | Convert @llvmValue@ to its memory representation held in a 'Value'.
   compose :: llvmValue -> CodeGenFunction r (Value (Struct llvmValue))
   -- Default: derived from 'store' by 'composeFromStore'.
   compose = composeFromStore store

{- |
Load the value behind the pointer,
transform it with the given code generator
and store the result at the same address.
-}
modify ::
   (C llvmValue) =>
   (llvmValue -> CodeGenFunction r llvmValue) ->
   Value (Ptr (Struct llvmValue)) -> CodeGenFunction r ()
modify f ptr = do
   old <- load ptr
   new <- f old
   store new ptr


{- |
The unit value carries no data.
It is represented by an empty LLVM struct
and loading or storing it generates no memory access.
-}
instance C () where
   type Struct () = LLVM.Struct ()
   load _ptr = return ()
   store _unit _ptr = return ()
   decompose _struct = return ()
   compose _unit = return $ LLVM.value (LLVM.constStruct ())


{- |
An 'Element' that reads back the same type @v@ that it writes,
that is, a description of a complete record stored in object @o@.
-}
type Record r o v = Element r o v v

{- |
Description of how a component of type @x@
is read from an LLVM object of type @o@
and how a value of type @v@ contributes to writing that object.
Every access is available in two variants:
through a pointer to @o@ and on a @Value o@ held directly.
-}
data Element r o v x =
   Element {
      -- | Read the component through a pointer to the object.
      loadElement :: Value (Ptr o) -> CodeGenFunction r x,
      -- | Write the part of @v@ that belongs to this element
      -- through a pointer to the object.
      storeElement :: Value (Ptr o) -> v -> CodeGenFunction r (),
      -- | Read the component from an object value.
      extractElement :: Value o -> CodeGenFunction r x,
      -- | Return the object value updated with the part of @v@
      -- that belongs to this element.
      insertElement :: v -> Value o -> CodeGenFunction r (Value o)
         -- State.Monoid
   }

{- |
Describe a single field of an LLVM object.
The first argument selects the field's content from the complete value,
the second argument is the index of the field within the object.
Pointer based accesses address the field with 'getElementPtr0',
value based accesses use 'extractvalue' and 'insertvalue'.
-}
element ::
   (C x,
    LLVM.GetValue o n, LLVM.ValueType o n ~ Struct x,
    LLVM.GetElementPtr o (n, ()), LLVM.ElementPtrType o (n, ()) ~ Struct x) =>
   (v -> x) -> n -> Element r o v x
element field n =
   Element {
      loadElement = \objPtr ->
         getElementPtr0 objPtr (n, ()) >>= load,
      storeElement = \objPtr whole ->
         getElementPtr0 objPtr (n, ()) >>= store (field whole),
      extractElement = \obj ->
         extractvalue obj n >>= decompose,
      insertElement = \whole obj ->
         compose (field whole) >>= \part -> insertvalue obj part n
   }

{- |
Mapping only affects what is read from an element.
The writing functions are passed on unchanged.
-}
instance Functor (Element r o v) where
   fmap f e =
      Element {
         loadElement = \ptr -> fmap f (loadElement e ptr),
         storeElement = \ptr whole -> storeElement e ptr whole,
         extractElement = \obj -> fmap f (extractElement e obj),
         insertElement = \whole obj -> insertElement e whole obj
      }

{- |
'pure' is an element that accesses nothing:
reading yields the given constant and writing leaves the object untouched.
'<*>' lets two elements operate on the same object.
Reading and writing process the left operand first, then the right one.
-}
instance App.Applicative (Element r o v) where
   pure x =
      Element {
         loadElement = const (return x),
         storeElement = \ _ptr _whole -> return (),
         extractElement = const (return x),
         insertElement = const return
      }
   ef <*> ex =
      Element {
         loadElement = \ptr ->
            loadElement ef ptr >>= \g ->
            loadElement ex ptr >>= \a ->
            return (g a),
         storeElement = \ptr whole ->
            storeElement ef ptr whole >> storeElement ex ptr whole,
         extractElement = \obj ->
            extractElement ef obj >>= \g ->
            extractElement ex obj >>= \a ->
            return (g a),
         insertElement = \whole obj ->
            insertElement ex whole =<< insertElement ef whole obj
      }


-- | Read a complete record through a pointer to its object.
loadRecord ::
   Record r o llvmValue ->
   Value (Ptr o) -> CodeGenFunction r llvmValue
loadRecord record ptr = loadElement record ptr

{- |
Write a complete record through a pointer to its object.
The argument order is the one of 'store': value first, pointer second.
-}
storeRecord ::
   Record r o llvmValue ->
   llvmValue -> Value (Ptr o) -> CodeGenFunction r ()
storeRecord record = flip (storeElement record)

-- | Read a complete record from an object value.
decomposeRecord ::
   Record r o llvmValue ->
   Value o -> CodeGenFunction r llvmValue
decomposeRecord = extractElement

{- |
Build an object value from a complete record.
The fields are inserted into an initially undefined object.
-}
composeRecord ::
   (IsType o) =>
   Record r o llvmValue ->
   llvmValue -> CodeGenFunction r (Value o)
composeRecord record whole =
   insertElement record whole $ LLVM.value LLVM.undef



-- | Record description of a pair: a struct of the two component structs.
pair ::
   (C a, C b) =>
   Record r (LLVM.Struct (Struct a, (Struct b, ()))) (a, b)
pair =
   fmap (,) (element fst d0)
      App.<*> element snd d1

-- | Pairs are stored as described by the 'pair' record.
instance (C a, C b) => C (a, b) where
   type Struct (a, b) = LLVM.Struct (Struct a, (Struct b, ()))
   load ptr = loadRecord pair ptr
   store ab ptr = storeRecord pair ab ptr
   decompose struct = decomposeRecord pair struct
   compose ab = composeRecord pair ab


-- | Record description of a triple: a struct of the three component structs.
triple ::
   (C a, C b, C c) =>
   Record r (LLVM.Struct (Struct a, (Struct b, (Struct c, ())))) (a, b, c)
triple =
   fmap (,,) (element fst3 d0)
      App.<*> element snd3 d1
      App.<*> element thd3 d2

-- | Triples are stored as described by the 'triple' record.
instance (C a, C b, C c) => C (a, b, c) where
   type Struct (a, b, c) =
           LLVM.Struct (Struct a, (Struct b, (Struct c, ())))
   load ptr = loadRecord triple ptr
   store abc ptr = storeRecord triple abc ptr
   decompose struct = decomposeRecord triple struct
   compose abc = composeRecord triple abc


{- |
Record description of 'Maybe.T':
field 0 holds the 'Maybe.isJust' flag as 'Word32',
field 1 holds the 'Maybe.fromJust' payload.
-}
maybe ::
   (C a) =>
   Record r (LLVM.Struct (Word32, (Struct a, ()))) (Maybe.T a)
maybe =
   fmap Maybe.Cons (element Maybe.isJust d0)
      App.<*> element Maybe.fromJust d1

-- | 'Maybe.T' values are stored as described by the 'maybe' record.
instance (C a) => C (Maybe.T a) where
   type Struct (Maybe.T a) = LLVM.Struct (Word32, (Struct a, ()))
   load ptr = loadRecord maybe ptr
   store m ptr = storeRecord maybe m ptr
   decompose struct = decomposeRecord maybe struct
   compose m = composeRecord maybe m


{- |
Record description of 'Either.T':
field 0 holds the 'Either.isRight' flag as 'Word32',
field 1 the 'Either.fromLeft' payload
and field 2 the 'Either.fromRight' payload.
-}
either ::
   (C a, C b) =>
   Record r (LLVM.Struct (Word32, (Struct a, (Struct b, ())))) (Either.T a b)
either =
   fmap Either.Cons (element Either.isRight d0)
      App.<*> element Either.fromLeft d1
      App.<*> element Either.fromRight d2

-- | 'Either.T' values are stored as described by the 'either' record.
instance (C a, C b) => C (Either.T a b) where
   type Struct (Either.T a b) = LLVM.Struct (Word32, (Struct a, (Struct b, ())))
   load ptr = loadRecord either ptr
   store e ptr = storeRecord either e ptr
   decompose struct = decomposeRecord either struct
   compose e = composeRecord either e



{- |
'Scalar.T' shares the memory representation of the type it wraps.
All methods only add or remove the wrapper.
-}
instance (C a) => C (Scalar.T a) where
   type Struct (Scalar.T a) = Struct a
   load ptr = loadNewtype Scalar.Cons ptr
   store x ptr = storeNewtype Scalar.decons x ptr
   decompose struct = decomposeNewtype Scalar.Cons struct
   compose x = composeNewtype Scalar.decons x


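{-
A generic instance for all first-class types like the one sketched below
would not work for Booleans,
since on x86 LLVM's @i1@ type uses one byte in memory,
whereas Storable uses 4 bytes and 4 byte alignment.
For this reason the instance is disabled.
(The sketch is written for a two-parameter variant of class 'C'.)

instance (LLVM.IsFirstClass a) => C (Value a) a where
   load = LLVM.load
   store = LLVM.store
   decompose = return
   compose = return
-}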


{- |
First-class LLVM types together with the type
that is used for them in memory.
The two types may differ, e.g. 'Bool' is stored as 'Word32'.
-}
class (LLVM.IsFirstClass llvmType, IsType (Stored llvmType)) =>
      FirstClass llvmType where
   -- | Type that represents @llvmType@ in memory.
   type Stored llvmType :: *
   -- | Convert from the memory representation to the working representation.
   fromStorable :: Value (Stored llvmType) -> CodeGenFunction r (Value llvmType)
   -- | Convert from the working representation to the memory representation.
   toStorable :: Value llvmType -> CodeGenFunction r (Value (Stored llvmType))

-- Floating point and fixed-width integer types are stored as they are.
-- Both conversions are the identity.

instance FirstClass Float where
   type Stored Float = Float
   fromStorable = return
   toStorable = return

instance FirstClass Double where
   type Stored Double = Double
   fromStorable = return
   toStorable = return

instance FirstClass Int8 where
   type Stored Int8 = Int8
   fromStorable = return
   toStorable = return

instance FirstClass Int16 where
   type Stored Int16 = Int16
   fromStorable = return
   toStorable = return

instance FirstClass Int32 where
   type Stored Int32 = Int32
   fromStorable = return
   toStorable = return

instance FirstClass Int64 where
   type Stored Int64 = Int64
   fromStorable = return
   toStorable = return

instance FirstClass Word8 where
   type Stored Word8 = Word8
   fromStorable = return
   toStorable = return

instance FirstClass Word16 where
   type Stored Word16 = Word16
   fromStorable = return
   toStorable = return

instance FirstClass Word32 where
   type Stored Word32 = Word32
   fromStorable = return
   toStorable = return

instance FirstClass Word64 where
   type Stored Word64 = Word64
   fromStorable = return
   toStorable = return

{- |
Booleans are stored as 'Word32'.
Reading compares the stored word with zero using 'LLVM.CmpNE',
writing extends the flag with 'LLVM.zext'.
-}
instance FirstClass Bool   where
   type Stored Bool = Word32
   fromStorable word = A.cmp LLVM.CmpNE (LLVM.value LLVM.zero) word
   toStorable flag = LLVM.zext flag
-- | Vectors are converted element by element.
instance
   (TypeNum.Positive n, LLVM.IsPrimitive a, LLVM.IsPrimitive (Stored a), FirstClass a) =>
      FirstClass (LLVM.Vector n a) where
   type Stored (LLVM.Vector n a) = LLVM.Vector n (Stored a)
   fromStorable stored = Vector.map fromStorable stored
   toStorable vec = Vector.map toStorable vec
-- | Arrays are converted element by element.
instance
   (TypeNum.Natural n, LLVM.IsFirstClass (Stored a),
    FirstClass a, IsSized a, IsSized (Stored a)) =>
      FirstClass (LLVM.Array n a) where
   type Stored (LLVM.Array n a) = LLVM.Array n (Stored a)
   fromStorable stored = Array.map fromStorable stored
   toStorable arr = Array.map toStorable arr

-- Pointer types are stored as they are. Both conversions are the identity.

instance (IsType a) => FirstClass (Ptr a) where
   type Stored (Ptr a) = Ptr a
   fromStorable = return
   toStorable = return

instance (LLVM.IsFunction a) => FirstClass (FunPtr a) where
   type Stored (FunPtr a) = FunPtr a
   fromStorable = return
   toStorable = return

instance FirstClass (StablePtr a) where
   type Stored (StablePtr a) = StablePtr a
   fromStorable = return
   toStorable = return


{- |
A struct is stored as a struct of the 'Stored' types of its fields,
see 'StoredStruct'.
The conversion is delegated to 'ConvertStruct',
starting at field index zero with the complete field list @s@.
-}
instance
   (LLVM.IsFirstClass (LLVM.Struct s),
    IsType (LLVM.Struct (StoredStruct s)),
    ConvertStruct s TypeNum.D0 s) =>
      FirstClass (LLVM.Struct s) where
   type Stored (LLVM.Struct s) = LLVM.Struct (StoredStruct s)
   fromStorable sm =
      -- The proxy carries only the type of the field list.
      case LP.Proxy of
         sfields -> do
            s <- decomposeField sfields d0 sm
            -- Tie the type of the proxy to the field list of the result.
            let _ = asTypeOf (fields s) sfields
            return s
   toStorable s =
      -- Here the field list type is taken from the argument.
      composeField (fields s) d0 s

-- | Proxy for the field list of a struct value. The value is not inspected.
fields :: Value (LLVM.Struct s) -> LP.Proxy s
fields = const LP.Proxy


{- |
Field list of the memory representation of a struct:
every field type is replaced by its 'Stored' type.
-}
type family StoredStruct s :: *
-- Empty field list.
type instance StoredStruct () = ()
-- Convert the first field type and recurse on the remaining ones.
type instance StoredStruct (s,rem) = (Stored s, StoredStruct rem)

{- |
Field-wise conversion between a struct with field list @s@
and its memory representation.
@rem@ is the list of fields that still need to be converted
and @i@ is the index of the first of them.
-}
class ConvertStruct s i rem where
   -- | Convert the fields from index @i@ on
   -- from memory representation to working representation.
   decomposeField ::
      LP.Proxy rem -> Proxy i -> Value (LLVM.Struct (StoredStruct s)) ->
      CodeGenFunction r (Value (LLVM.Struct s))
   -- | Convert the fields from index @i@ on
   -- from working representation to memory representation.
   composeField ::
      LP.Proxy rem -> Proxy i -> Value (LLVM.Struct s) ->
      CodeGenFunction r (Value (LLVM.Struct (StoredStruct s)))

{- |
Recursive case:
First convert all fields after index @i@,
then convert field @i@ and insert it into the result of the recursion.
-}
instance
   (sm ~ StoredStruct s,
    FirstClass a, am ~ Stored a,
    LLVM.GetValue (LLVM.Struct s) (Proxy i),
    LLVM.GetValue (LLVM.Struct sm) (Proxy i),
    LLVM.ValueType (LLVM.Struct s) (Proxy i) ~ a,
    LLVM.ValueType (LLVM.Struct sm) (Proxy i) ~ am,
    ConvertStruct s (TypeNum.Succ i) rem) =>
      ConvertStruct s i (a,rem) where
   decomposeField flds i sm = do
      -- remaining fields, starting at the successor index
      s <- decomposeField (fmap snd flds) (decSucc i) sm
      -- field i: read the stored form and convert it
      a <- fromStorable =<< LLVM.extractvalue sm i
      LLVM.insertvalue s a i
   composeField flds i s = do
      -- remaining fields, starting at the successor index
      sm <- composeField (fmap snd flds) (decSucc i) s
      -- field i: read the working form and convert it
      am <- toStorable =<< LLVM.extractvalue s i
      LLVM.insertvalue sm am i

-- | Proxy for the successor of a type-level number.
decSucc :: Proxy n -> Proxy (TypeNum.Succ n)
decSucc p =
   case p of
      Proxy -> Proxy

{- |
Base case: No fields are left.
The result is an undefined struct
that the recursive instances fill field by field.
-}
instance
   (sm ~ StoredStruct s,
    IsType (LLVM.Struct s),
    IsType (LLVM.Struct sm)) =>
      ConvertStruct s i () where
   decomposeField _flds _i _sm =
      return $ LLVM.value LLVM.undef
   composeField _flds _i _s =
      return $ LLVM.value LLVM.undef


{- |
A plain LLVM value is stored in its 'Stored' representation.
Only 'decompose' and 'compose' are given,
'load' and 'store' use the class defaults.
-}
instance (FirstClass a, IsSized (Stored a)) => C (Value a) where
   type Struct (Value a) = Stored a
   decompose stored = fromStorable stored
   compose val = toStorable val


-- | All methods are forwarded to "LLVM.Extra.Multi.Value.Memory".
instance (MultiValueMemory.C a) => C (MultiValue.T a) where
   type Struct (MultiValue.T a) = MultiValueMemory.Struct a
   load ptr         = MultiValueMemory.load ptr
   store x ptr      = MultiValueMemory.store x ptr
   decompose struct = MultiValueMemory.decompose struct
   compose x        = MultiValueMemory.compose x


-- | All methods are forwarded to "LLVM.Extra.Multi.Vector.Memory".
instance (MultiVectorMemory.C n a) => C (MultiVector.T n a) where
   type Struct (MultiVector.T n a) = MultiVectorMemory.Struct n a
   load ptr         = MultiVectorMemory.load ptr
   store x ptr      = MultiVectorMemory.store x ptr
   decompose struct = MultiVectorMemory.decompose struct
   compose x        = MultiVectorMemory.compose x


{- |
Reinterpret a pointer to a Haskell value
as a pointer to the memory representation of the corresponding LLVM value.
This is a plain 'castPtr' with a more specific type.
-}
castStorablePtr ::
   (MakeValueTuple haskellValue, C (ValueTuple haskellValue)) =>
   Ptr haskellValue -> Ptr (Struct (ValueTuple haskellValue))
castStorablePtr ptr = castPtr ptr



-- | 'load' for a wrapper type: load the wrapped value and apply the wrapper.
loadNewtype ::
   (C a) =>
   (a -> llvmValue) ->
   Value (Ptr (Struct a)) -> CodeGenFunction r llvmValue
loadNewtype wrap = fmap wrap . load

-- | 'store' for a wrapper type: remove the wrapper and store the content.
storeNewtype ::
   (C a) =>
   (llvmValue -> a) ->
   llvmValue -> Value (Ptr (Struct a)) -> CodeGenFunction r ()
storeNewtype unwrap = store . unwrap

-- | 'decompose' for a wrapper type:
-- decompose the wrapped value and apply the wrapper.
decomposeNewtype ::
   (C a) =>
   (a -> llvmValue) ->
   Value (Struct a) -> CodeGenFunction r llvmValue
decomposeNewtype wrap = fmap wrap . decompose

-- | 'compose' for a wrapper type: remove the wrapper and compose the content.
composeNewtype ::
   (C a) =>
   (llvmValue -> a) ->
   llvmValue -> CodeGenFunction r (Value (Struct a))
composeNewtype unwrap = compose . unwrap