{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
module LLVM.Extra.Memory (
   C(load, store, decompose, compose), modify,
   Struct,
   Record, Element, element,
   loadRecord, storeRecord, decomposeRecord, composeRecord,
   loadNewtype, storeNewtype, decomposeNewtype, composeNewtype,
   ) where

import LLVM.Extra.MemoryPrivate (decomposeFromLoad, composeFromStore, )

import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.ArithmeticPrivate as A
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Scalar as Scalar
import qualified LLVM.Extra.Array as Array
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Either as Either
import qualified LLVM.Extra.Maybe as Maybe

import qualified LLVM.Util.Proxy as LP
import qualified LLVM.Core as LLVM
import LLVM.Core
   (getElementPtr0,
    extractvalue, insertvalue,
    Value, -- valueOf, Vector,
    IsType, IsSized,
    CodeGenFunction, )

import qualified Type.Data.Num.Decimal as TypeNum
import qualified Type.Data.Num.Unary as Unary
import Type.Data.Num.Decimal (d0, d1, d2, d3)
import Type.Base.Proxy (Proxy(Proxy), )

import Foreign.StablePtr (StablePtr, )
import Foreign.Ptr (FunPtr)

import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import qualified Data.FixedLength as FixedLength
import Data.Tuple.HT (fst3, snd3, thd3, )
import Data.Word (Word8, Word16, Word32, Word64, Word)
import Data.Int  (Int8,  Int16,  Int32,  Int64, )

import qualified Control.Applicative.HT as App
import Control.Monad (ap, (<=<))
import Control.Applicative (Applicative, pure, liftA2, liftA3, (<*>))

import Prelude2010 hiding (maybe, either, )
import Prelude ()


{- |
An implementation of both 'Tuple.Value' and 'C'
must ensure that @haskellValue@ is compatible
with @Stored (Struct haskellValue)@ (which we want to call @llvmStruct@).
That is, writing and reading @llvmStruct@ by LLVM
must be the same as accessing @haskellValue@ by 'Storable' methods.
ToDo: In the future we may also require a 'Storable' constraint for @llvmStruct@.

'Struct' is an associated type family rather than an extra class parameter,
so the memory type is determined by @llvmValue@;
this lets type inference work nicely.
-}
class
   (Tuple.Phi llvmValue, Tuple.Undefined llvmValue,
    IsType (Struct llvmValue), IsSized (Struct llvmValue)) =>
      C llvmValue where
   {-# MINIMAL (load|decompose), (store|compose) #-}

   -- | LLVM type that a pointer must address in order to hold a @llvmValue@.
   type Struct llvmValue :: *

   -- | Read a value through a pointer.
   -- The default loads the whole aggregate and then splits it with 'decompose'.
   load :: Value (LLVM.Ptr (Struct llvmValue)) -> CodeGenFunction r llvmValue
   load = decompose <=< LLVM.load

   -- | Write a value through a pointer.
   -- The default assembles the aggregate with 'compose' and stores it at once.
   store :: llvmValue -> Value (LLVM.Ptr (Struct llvmValue)) -> CodeGenFunction r ()
   store x ptr = compose x >>= \agg -> LLVM.store agg ptr

   -- | Split an aggregate LLVM value into the Haskell-side representation.
   -- The default is derived from 'load' via 'decomposeFromLoad'.
   decompose :: Value (Struct llvmValue) -> CodeGenFunction r llvmValue
   decompose = decomposeFromLoad load

   -- | Assemble the Haskell-side representation into an aggregate LLVM value.
   -- The default is derived from 'store' via 'composeFromStore'.
   compose :: llvmValue -> CodeGenFunction r (Value (Struct llvmValue))
   compose = composeFromStore store

-- | Load the value behind a pointer, transform it with the given code
-- generator and store the result back through the same pointer.
modify ::
   (C llvmValue) =>
   (llvmValue -> CodeGenFunction r llvmValue) ->
   Value (LLVM.Ptr (Struct llvmValue)) -> CodeGenFunction r ()
modify f ptr = do
   old <- load ptr
   updated <- f old
   store updated ptr


-- | The unit value corresponds to an empty LLVM struct.
-- Loading, storing and decomposing produce no instructions.
instance C () where
   type Struct () = LLVM.Struct ()
   load = const (return ())
   store = \_ _ -> return ()
   decompose = const (return ())
   compose _ = return $ LLVM.value (LLVM.constStruct ())


-- | Accessors for a whole value @v@ laid out as the LLVM type @o@:
-- an 'Element' whose component type is the value type itself.
type Record r o v = Element r o v v

-- | Accessors for one component of type @x@ of a value of type @v@
-- whose LLVM representation is @o@.
-- Elements for the individual fields are combined with 'fmap' and '<*>'
-- to obtain a 'Record' for the whole value.
data Element r o v x =
   Element {
      -- | Read the component from the memory the pointer refers to.
      loadElement :: Value (LLVM.Ptr o) -> CodeGenFunction r x,
      -- | Write this component's part of @v@ through the pointer.
      storeElement :: Value (LLVM.Ptr o) -> v -> CodeGenFunction r (),
      -- | Extract the component from an aggregate LLVM value.
      extractElement :: Value o -> CodeGenFunction r x,
      -- | Insert this component's part of @v@ into an aggregate LLVM value.
      insertElement :: v -> Value o -> CodeGenFunction r (Value o)
         -- State.Monoid
         -- NOTE(review): presumably hints that 'insertElement' could be
         -- expressed as a State/Endo-style monoid; confirm before relying on it.
   }

-- | Accessors for the aggregate field at index @n@.
-- @field@ projects the component out of the whole value;
-- the component itself is handled by its own 'C' instance.
element ::
   (C x,
    LLVM.GetValue o n, LLVM.ValueType o n ~ Struct x,
    LLVM.GetElementPtr o (n, ()), LLVM.ElementPtrType o (n, ()) ~ Struct x) =>
   (v -> x) -> n -> Element r o v x
element field n =
   let fieldPtr ptr = getElementPtr0 ptr (n, ())
   in  Element {
          loadElement = \ptr -> fieldPtr ptr >>= load,
          storeElement = \ptr v -> fieldPtr ptr >>= store (field v),
          extractElement = \o -> extractvalue o n >>= decompose,
          insertElement = \v o ->
             compose (field v) >>= \x -> insertvalue o x n
       }

-- | Mapping only affects the component read by loading or extracting;
-- storing and inserting are passed through untouched.
instance Functor (Element r o v) where
   fmap g ~(Element ld st ex ins) =
      Element (fmap g . ld) st (fmap g . ex) ins

-- | 'pure' is a constant component that reads and writes nothing.
-- '<*>' runs the accessors of both operands, the function side first.
instance Applicative (Element r o v) where
   pure a =
      Element {
         loadElement = const (return a),
         storeElement = \_ _ -> return (),
         extractElement = const (return a),
         insertElement = const return
      }
   ef <*> ea =
      Element {
         loadElement = \ptr ->
            loadElement ef ptr >>= \g ->
            loadElement ea ptr >>= \a -> return (g a),
         storeElement = \ptr y ->
            storeElement ef ptr y >> storeElement ea ptr y,
         extractElement = \o ->
            extractElement ef o >>= \g ->
            extractElement ea o >>= \a -> return (g a),
         insertElement = \y o ->
            insertElement ea y =<< insertElement ef y o
      }


-- | Load a whole value through a pointer using the record's accessors.
loadRecord ::
   Record r o llvmValue ->
   Value (LLVM.Ptr o) -> CodeGenFunction r llvmValue
loadRecord record ptr = loadElement record ptr

-- | Store a whole value through a pointer using the record's accessors.
-- The argument order (value first, then pointer) matches 'store'.
storeRecord ::
   Record r o llvmValue ->
   llvmValue -> Value (LLVM.Ptr o) -> CodeGenFunction r ()
storeRecord = flip . storeElement

-- | Split an aggregate LLVM value into a whole value via the record's accessors.
decomposeRecord ::
   Record r o llvmValue ->
   Value o -> CodeGenFunction r llvmValue
decomposeRecord = extractElement

-- | Build an aggregate LLVM value by inserting every component of the value
-- into an initially undefined aggregate.
composeRecord ::
   (IsType o) =>
   Record r o llvmValue ->
   llvmValue -> CodeGenFunction r (Value o)
composeRecord m = flip (insertElement m) (LLVM.value LLVM.undef)



-- | Accessors for a pair: first component at field 0, second at field 1.
pair ::
   (C a, C b) =>
   Record r (LLVM.Struct (Struct a, (Struct b, ()))) (a, b)
pair = fmap (,) (element fst d0) <*> element snd d1

instance (C a, C b) => C (a, b) where
   type Struct (a, b) = LLVM.Struct (Struct a, (Struct b, ()))
   load ptr = loadRecord pair ptr
   store v ptr = storeRecord pair v ptr
   decompose s = decomposeRecord pair s
   compose v = composeRecord pair v


-- | Accessors for a triple: components at fields 0, 1 and 2.
triple ::
   (C a, C b, C c) =>
   Record r (LLVM.Struct (Struct a, (Struct b, (Struct c, ())))) (a, b, c)
triple =
   fmap (,,) (element fst3 d0)
      <*> element snd3 d1
      <*> element thd3 d2

instance (C a, C b, C c) => C (a, b, c) where
   type Struct (a, b, c) =
           LLVM.Struct (Struct a, (Struct b, (Struct c, ())))
   load ptr = loadRecord triple ptr
   store v ptr = storeRecord triple v ptr
   decompose s = decomposeRecord triple s
   compose v = composeRecord triple v


-- | Accessors for a quadruple: components at fields 0 to 3.
quadruple ::
   (C a, C b, C c, C d) =>
   Record r
      (LLVM.Struct (Struct a, (Struct b, (Struct c, (Struct d, ())))))
      (a, b, c, d)
quadruple =
   fmap (,,,) (element (\(x,_,_,_) -> x) d0)
      <*> element (\(_,x,_,_) -> x) d1
      <*> element (\(_,_,x,_) -> x) d2
      <*> element (\(_,_,_,x) -> x) d3

instance (C a, C b, C c, C d) => C (a, b, c, d) where
   type Struct (a, b, c, d) =
           LLVM.Struct (Struct a, (Struct b, (Struct c, (Struct d, ()))))
   load ptr = loadRecord quadruple ptr
   store v ptr = storeRecord quadruple v ptr
   decompose s = decomposeRecord quadruple s
   compose v = composeRecord quadruple v


-- | A fixed-length sequence is held in an LLVM array with one slot per element.
-- Only 'compose' and 'decompose' are defined here; 'load' and 'store'
-- use the class defaults.
instance
   (Unary.Natural n, C a,
    TypeNum.Natural (TypeNum.FromUnary n),
    TypeNum.Natural (TypeNum.FromUnary n TypeNum.:*: LLVM.SizeOf (Struct a)),
    LLVM.IsFirstClass (Struct a)) =>
      C (FixedLength.T n a) where
   type Struct (FixedLength.T n a) =
            LLVM.Array (TypeNum.FromUnary n) (Struct a)
   compose xs =
      Fold.foldlM
         (\arr (x,i) -> do
             xc <- compose x
             LLVM.insertvalue arr xc i)
         (LLVM.value LLVM.undef)
         (FixedLength.zipWith (,) xs (iterateTrav (1+) (0::Word64)))
   decompose arr =
      Trav.forM (iterateTrav (1+) (0::Word64)) $ \i ->
         LLVM.extractvalue arr i >>= decompose

-- | Fill every position of the structure produced by @pure ()@ with
-- @a0@, @f a0@, @f (f a0)@, … in traversal order.
iterateTrav :: (Applicative t, Trav.Traversable t) => (a -> a) -> a -> t a
iterateTrav f start = snd (Trav.mapAccumL step start (pure ()))
   where step acc () = (f acc, acc)


-- | Accessors for 'Maybe.T': the 'Maybe.isJust' flag at field 0
-- and the 'Maybe.fromJust' payload at field 1.
maybe ::
   (C a) =>
   Record r (LLVM.Struct (Bool, (Struct a, ()))) (Maybe.T a)
maybe =
   fmap Maybe.Cons (element Maybe.isJust d0)
      <*> element Maybe.fromJust d1

instance (C a) => C (Maybe.T a) where
   type Struct (Maybe.T a) = LLVM.Struct (Bool, (Struct a, ()))
   load ptr = loadRecord maybe ptr
   store v ptr = storeRecord maybe v ptr
   decompose s = decomposeRecord maybe s
   compose v = composeRecord maybe v


-- | Accessors for 'Either.T': the 'Either.isRight' flag at field 0,
-- the 'Either.fromLeft' payload at field 1 and 'Either.fromRight' at field 2.
either ::
   (C a, C b) =>
   Record r (LLVM.Struct (Bool, (Struct a, (Struct b, ())))) (Either.T a b)
either =
   fmap Either.Cons (element Either.isRight d0)
      <*> element Either.fromLeft d1
      <*> element Either.fromRight d2

instance (C a, C b) => C (Either.T a b) where
   type Struct (Either.T a b) = LLVM.Struct (Bool, (Struct a, (Struct b, ())))
   load ptr = loadRecord either ptr
   store v ptr = storeRecord either v ptr
   decompose s = decomposeRecord either s
   compose v = composeRecord either v



-- | 'Scalar.T' shares the memory layout of the wrapped value:
-- every method unwraps or wraps around the underlying instance.
instance (C a) => C (Scalar.T a) where
   type Struct (Scalar.T a) = Struct a
   load = fmap Scalar.Cons . load
   store = store . Scalar.decons
   decompose = fmap Scalar.Cons . decompose
   compose = compose . Scalar.decons


-- | A plain LLVM value is held in memory as itself:
-- load/store are the raw LLVM instructions, decompose/compose are identities.
instance (IsSized a, LLVM.IsFirstClass a) => C (Value a) where
   type Struct (Value a) = a
   load ptr = LLVM.load ptr
   store x ptr = LLVM.store x ptr
   decompose x = return x
   compose x = return x


-- | First-class LLVM types paired with the type used to hold them in memory,
-- plus conversions in both directions.
-- For most instances in this module 'Stored' is the type itself.
class (LLVM.IsFirstClass llvmType, IsType (Stored llvmType)) =>
      FirstClass llvmType where
   -- | In-memory representation of @llvmType@.
   type Stored llvmType :: *
   -- | Convert from the in-memory representation to @llvmType@.
   fromStorable :: Value (Stored llvmType) -> CodeGenFunction r (Value llvmType)
   -- | Convert from @llvmType@ to the in-memory representation.
   toStorable :: Value llvmType -> CodeGenFunction r (Value (Stored llvmType))

-- Primitive numeric types are stored as themselves,
-- so both conversions just return their argument.
instance FirstClass Float where
   type Stored Float = Float
   fromStorable x = return x
   toStorable x = return x
instance FirstClass Double where
   type Stored Double = Double
   fromStorable x = return x
   toStorable x = return x
instance FirstClass Int where
   type Stored Int = Int
   fromStorable x = return x
   toStorable x = return x
instance FirstClass Int8 where
   type Stored Int8 = Int8
   fromStorable x = return x
   toStorable x = return x
instance FirstClass Int16 where
   type Stored Int16 = Int16
   fromStorable x = return x
   toStorable x = return x
instance FirstClass Int32 where
   type Stored Int32 = Int32
   fromStorable x = return x
   toStorable x = return x
instance FirstClass Int64 where
   type Stored Int64 = Int64
   fromStorable x = return x
   toStorable x = return x
instance FirstClass Word where
   type Stored Word = Word
   fromStorable x = return x
   toStorable x = return x
instance FirstClass Word8 where
   type Stored Word8 = Word8
   fromStorable x = return x
   toStorable x = return x
instance FirstClass Word16 where
   type Stored Word16 = Word16
   fromStorable x = return x
   toStorable x = return x
instance FirstClass Word32 where
   type Stored Word32 = Word32
   fromStorable x = return x
   toStorable x = return x
instance FirstClass Word64 where
   type Stored Word64 = Word64
   fromStorable x = return x
   toStorable x = return x
-- Bool is held as a Word32 in memory:
-- storing zero-extends, loading compares the word against zero.
instance FirstClass Bool where
   type Stored Bool = Word32
   fromStorable w = A.cmp LLVM.CmpNE (LLVM.value LLVM.zero) w
   toStorable b = LLVM.zext b
-- Vectors and arrays are converted element by element.
instance
   (TypeNum.Positive n, LLVM.IsPrimitive a, LLVM.IsPrimitive (Stored a), FirstClass a) =>
      FirstClass (LLVM.Vector n a) where
   type Stored (LLVM.Vector n a) = LLVM.Vector n (Stored a)
   fromStorable v = Vector.map fromStorable v
   toStorable v = Vector.map toStorable v
instance
   (TypeNum.Natural n, LLVM.IsFirstClass (Stored a),
    FirstClass a, IsSized a, IsSized (Stored a)) =>
      FirstClass (LLVM.Array n a) where
   type Stored (LLVM.Array n a) = LLVM.Array n (Stored a)
   fromStorable arr = Array.map fromStorable arr
   toStorable arr = Array.map toStorable arr

-- Pointer-like values are stored unchanged.
instance (IsType a) => FirstClass (LLVM.Ptr a) where
   type Stored (LLVM.Ptr a) = LLVM.Ptr a
   fromStorable p = return p
   toStorable p = return p
instance (LLVM.IsFunction a) => FirstClass (FunPtr a) where
   type Stored (FunPtr a) = FunPtr a
   fromStorable p = return p
   toStorable p = return p
instance FirstClass (StablePtr a) where
   type Stored (StablePtr a) = StablePtr a
   fromStorable p = return p
   toStorable p = return p


-- | A struct is stored field by field: its in-memory form has field list
-- @StoredStruct s@. The conversion walks the fields via 'ConvertStruct',
-- starting at index 0.
instance
   (LLVM.IsFirstClass (LLVM.Struct s),
    IsType (LLVM.Struct (StoredStruct s)),
    ConvertStruct s TypeNum.D0 s) =>
      FirstClass (LLVM.Struct s) where
   type Stored (LLVM.Struct s) = LLVM.Struct (StoredStruct s)
   fromStorable sm =
      -- Bind a proxy whose type is not fixed yet; the 'asTypeOf' below
      -- unifies it with the field list of the resulting struct @s@.
      -- NOTE(review): presumably a workaround because ScopedTypeVariables
      -- is not enabled in this module.
      case LP.Proxy of
         sfields -> do
            s <- decomposeField sfields d0 sm
            let _ = asTypeOf (fields s) sfields
            return s
   toStorable s =
      composeField (fields s) d0 s

-- | Type-level witness of a struct's field list; the value is never inspected.
fields :: Value (LLVM.Struct s) -> LP.Proxy s
fields = const LP.Proxy


-- | Field list of the in-memory counterpart of a struct with field list @s@:
-- each field type @f@ is replaced by @Stored f@.
type family StoredStruct s :: *
-- empty field list
type instance StoredStruct () = ()
-- convert the head field, recurse on the rest
type instance StoredStruct (s,rem) = (Stored s, StoredStruct rem)

-- | Field-wise conversion between a struct with field list @s@ and its
-- stored form (field list @StoredStruct s@).
-- @i@ is the index of the current field and @rem@ the field list starting
-- at that index, i.e. the fields that remain to be converted.
class ConvertStruct s i rem where
   -- | Build the struct from its stored form, converting the fields from index @i@ on.
   decomposeField ::
      LP.Proxy rem -> Proxy i -> Value (LLVM.Struct (StoredStruct s)) ->
      CodeGenFunction r (Value (LLVM.Struct s))
   -- | Build the stored form of the struct, converting the fields from index @i@ on.
   composeField ::
      LP.Proxy rem -> Proxy i -> Value (LLVM.Struct s) ->
      CodeGenFunction r (Value (LLVM.Struct (StoredStruct s)))

-- Recursive case: field @a@ sits at index @i@.
-- The fields after it are converted first (index @Succ i@, list @rem@),
-- then field @i@ is converted and inserted into the partially built result.
instance
   (sm ~ StoredStruct s,
    FirstClass a, am ~ Stored a,
    LLVM.GetValue (LLVM.Struct s) (Proxy i),
    LLVM.GetValue (LLVM.Struct sm) (Proxy i),
    LLVM.ValueType (LLVM.Struct s) (Proxy i) ~ a,
    LLVM.ValueType (LLVM.Struct sm) (Proxy i) ~ am,
    ConvertStruct s (TypeNum.Succ i) rem) =>
      ConvertStruct s i (a,rem) where
   decomposeField flds i sm = do
      -- convert the remaining fields first
      s <- decomposeField (fmap snd flds) (decSucc i) sm
      -- then read field i from the stored struct and convert it back
      a <- fromStorable =<< LLVM.extractvalue sm i
      LLVM.insertvalue s a i
   composeField flds i s = do
      -- convert the remaining fields first
      sm <- composeField (fmap snd flds) (decSucc i) s
      -- then read field i from the struct and convert it to its stored form
      am <- toStorable =<< LLVM.extractvalue s i
      LLVM.insertvalue sm am i

-- | Advance a type-level field index by one.
decSucc :: Proxy n -> Proxy (TypeNum.Succ n)
decSucc p = case p of Proxy -> Proxy

-- Base case: no fields remain, so start from an undefined struct;
-- the recursive case inserts the preceding fields into it.
instance
   (sm ~ StoredStruct s,
    IsType (LLVM.Struct s),
    IsType (LLVM.Struct sm)) =>
      ConvertStruct s i () where
   decomposeField = \_ _ _ -> return (LLVM.value LLVM.undef)
   composeField = \_ _ _ -> return (LLVM.value LLVM.undef)


-- | 'MultiValue.T' uses the memory layout of its wrapped 'Tuple.ValueOf'.
instance (MultiValue.C a, C (Tuple.ValueOf a)) => C (MultiValue.T a) where
   type Struct (MultiValue.T a) = Struct (Tuple.ValueOf a)
   load ptr = fmap MultiValue.Cons (load ptr)
   store v = case v of MultiValue.Cons a -> store a
   decompose s = fmap MultiValue.Cons (decompose s)
   compose v = case v of MultiValue.Cons a -> compose a

-- | 'MultiVector.T' uses the memory layout of its wrapped 'Tuple.VectorValueOf'.
instance
   (TypeNum.Positive n, MultiVector.C a, C (Tuple.VectorValueOf n a)) =>
      C (MultiVector.T n a) where
   type Struct (MultiVector.T n a) = Struct (Tuple.VectorValueOf n a)
   load ptr = fmap MultiVector.Cons (load ptr)
   store v = case v of MultiVector.Cons a -> store a
   decompose s = fmap MultiVector.Cons (decompose s)
   compose v = case v of MultiVector.Cons a -> compose a



-- | 'load' for a wrapper type: load the underlying value, then wrap it.
loadNewtype ::
   (C a) =>
   (a -> llvmValue) ->
   Value (LLVM.Ptr (Struct a)) -> CodeGenFunction r llvmValue
loadNewtype wrap = fmap wrap . load

-- | 'store' for a wrapper type: unwrap the value, then store it.
storeNewtype ::
   (C a) =>
   (llvmValue -> a) ->
   llvmValue -> Value (LLVM.Ptr (Struct a)) -> CodeGenFunction r ()
storeNewtype unwrap = store . unwrap

-- | 'decompose' for a wrapper type: decompose the underlying value, then wrap it.
decomposeNewtype ::
   (C a) =>
   (a -> llvmValue) ->
   Value (Struct a) -> CodeGenFunction r llvmValue
decomposeNewtype wrap = fmap wrap . decompose

-- | 'compose' for a wrapper type: unwrap the value, then compose it.
composeNewtype ::
   (C a) =>
   (llvmValue -> a) ->
   llvmValue -> CodeGenFunction r (Value (Struct a))
composeNewtype unwrap = compose . unwrap