{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
module LLVM.Extra.Memory (
   C(load, store, decompose, compose), modify,
   Struct,
   Record, Element, element,
   loadRecord, storeRecord, decomposeRecord, composeRecord,
   loadNewtype, storeNewtype, decomposeNewtype, composeNewtype,
   ) where

import LLVM.Extra.MemoryPrivate (decomposeFromLoad, composeFromStore, )

import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Scalar as Scalar
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Struct as Struct
import qualified LLVM.Extra.Either as Either
import qualified LLVM.Extra.Maybe as Maybe

import qualified LLVM.Core as LLVM
import LLVM.Core
   (getElementPtr0,
    extractvalue, insertvalue,
    Value, -- valueOf, Vector,
    IsType, IsSized,
    CodeGenFunction, )

import qualified Type.Data.Num.Decimal as TypeNum
import qualified Type.Data.Num.Unary as Unary
import Type.Data.Num.Decimal (d0, d1, d2, d3)
import Type.Base.Proxy (Proxy(Proxy))

import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import qualified Data.FixedLength as FixedLength
import Data.Tuple.HT (fst3, snd3, thd3, )
import Data.Word (Word)

import qualified Control.Applicative.HT as App
import Control.Monad (ap, (<=<))
import Control.Applicative (Applicative, pure, liftA2, liftA3, (<*>))

import Prelude2010 hiding (maybe, either, )
import Prelude ()


{- |
Types of LLVM values that can be transferred to and from memory.

An implementation of both 'Tuple.Value' and 'Memory.C'
must ensure that @haskellValue@ is compatible
with @Stored (Struct haskellValue)@ (which we want to call @llvmStruct@).
That is, writing and reading @llvmStruct@ by LLVM
must be the same as accessing @haskellValue@ by 'Storable' methods.
ToDo: In future we may also require Storable constraint for @llvmStruct@.

The memory representation is given by the associated type family 'Struct',
so that the memory type is determined by the value type alone,
which lets type inference work nicely.

A minimal instance defines 'load' or 'decompose',
and 'store' or 'compose';
each missing method has a default defined in terms of its counterpart.
-}
class (Tuple.Phi llvmValue, Tuple.Undefined llvmValue, IsType (Struct llvmValue), IsSized (Struct llvmValue)) =>
      C llvmValue where
   {-# MINIMAL (load|decompose), (store|compose) #-}
   -- | The LLVM type that represents @llvmValue@ in memory.
   type Struct llvmValue :: *
   -- | Read a value through a pointer to its memory representation.
   load :: Value (LLVM.Ptr (Struct llvmValue)) -> CodeGenFunction r llvmValue
   load ptr  =  decompose =<< LLVM.load ptr
   -- | Write a value through a pointer to its memory representation.
   store :: llvmValue -> Value (LLVM.Ptr (Struct llvmValue)) -> CodeGenFunction r ()
   store r ptr  =  flip LLVM.store ptr =<< compose r
   -- | Convert an LLVM value of the memory type into @llvmValue@.
   decompose :: Value (Struct llvmValue) -> CodeGenFunction r llvmValue
   decompose = decomposeFromLoad load
   -- | Convert @llvmValue@ into a single LLVM value of the memory type.
   compose :: llvmValue -> CodeGenFunction r (Value (Struct llvmValue))
   compose = composeFromStore store

-- | Load the value behind the pointer, transform it with the given
-- code generator and store the result back to the same location.
modify ::
   (C llvmValue) =>
   (llvmValue -> CodeGenFunction r llvmValue) ->
   Value (LLVM.Ptr (Struct llvmValue)) -> CodeGenFunction r ()
modify f ptr = do
   current <- load ptr
   updated <- f current
   store updated ptr


-- | The unit value is represented by an empty LLVM structure;
-- loading, storing and decomposing just return without doing anything.
instance C () where
   type Struct () = LLVM.Struct ()
   decompose _struct = pure ()
   compose _unit = pure $ LLVM.value $ LLVM.constStruct ()
   load _ptr = pure ()
   store _unit _ptr = pure ()


-- | A complete description, where the stored and the read-back value coincide.
type Record r o a = Element r o a a

{- |
Description of how the value @v@ (or a part of it)
is laid out in the LLVM type @o@; reading yields a value of type @x@.
@r@ is the result type parameter of the surrounding 'CodeGenFunction'.
Descriptions of single fields are combined
via the 'Functor' and 'Applicative' instances.
-}
data Element r o v x =
   Element
      { loadElement :: Value (LLVM.Ptr o) -> CodeGenFunction r x
        -- ^ read the described part through a pointer to @o@
      , storeElement :: Value (LLVM.Ptr o) -> v -> CodeGenFunction r ()
        -- ^ write the described part of @v@ through a pointer to @o@
      , extractElement :: Value o -> CodeGenFunction r x
        -- ^ read the described part out of a value of type @o@
      , insertElement :: v -> Value o -> CodeGenFunction r (Value o)
        -- ^ write the described part of @v@ into a value of type @o@
        -- and return the updated value
        -- (i.e. a state update on @Value o@)
      }

{- |
Describe field number @n@ of the aggregate type @o@.
@field@ selects the part of the stored value @v@ that belongs to this field.
The type of the field must be the memory type 'Struct' of that part.
-}
element ::
   (C x,
    LLVM.GetValue o n, LLVM.ValueType o n ~ Struct x,
    LLVM.GetElementPtr o (n, ()), LLVM.ElementPtrType o (n, ()) ~ Struct x) =>
   (v -> x) -> n -> Element r o v x
element field n =
   let fieldPtr ptr = getElementPtr0 ptr (n, ())
   in  Element {
          loadElement = \ptr -> fieldPtr ptr >>= load,
          storeElement = \ptr v -> fieldPtr ptr >>= store (field v),
          extractElement = \o -> extractvalue o n >>= decompose,
          insertElement = \v o -> compose (field v) >>= \xc -> insertvalue o xc n
       }

-- | Mapping only affects the value that is read back;
-- storing and inserting stay untouched.
instance Functor (Element r o v) where
   fmap f ~(Element ld st ex ins) =
      Element {
         loadElement = \ptr -> fmap f (ld ptr),
         storeElement = st,
         extractElement = \o -> fmap f (ex o),
         insertElement = ins
      }

{- |
'pure' describes no field at all:
reading returns the given value and writing does nothing.
'<*>' combines two descriptions, handling the left one before the right one.
-}
instance Applicative (Element r o v) where
   pure x =
      Element {
         loadElement = const (return x),
         storeElement = \_ptr _value -> return (),
         extractElement = const (return x),
         insertElement = const return
      }
   ef <*> ea =
      Element {
         loadElement = \ptr -> loadElement ef ptr <*> loadElement ea ptr,
         storeElement = \ptr value ->
            storeElement ef ptr value >> storeElement ea ptr value,
         extractElement = \agg -> extractElement ef agg <*> extractElement ea agg,
         insertElement = \value agg ->
            insertElement ef value agg >>= insertElement ea value
      }


-- | Read a whole record through a pointer, following the description.
loadRecord ::
   Record r o llvmValue ->
   Value (LLVM.Ptr o) -> CodeGenFunction r llvmValue
loadRecord desc ptr = loadElement desc ptr

-- | Write a whole record through a pointer, following the description.
-- The argument order matches 'store'.
storeRecord ::
   Record r o llvmValue ->
   llvmValue -> Value (LLVM.Ptr o) -> CodeGenFunction r ()
storeRecord desc = flip (storeElement desc)

-- | Split an aggregate value into the record, following the description.
decomposeRecord ::
   Record r o llvmValue ->
   Value o -> CodeGenFunction r llvmValue
decomposeRecord = extractElement

-- | Build an aggregate value from the record, starting from an undefined
-- aggregate and inserting all described fields.
composeRecord ::
   (IsType o) =>
   Record r o llvmValue ->
   llvmValue -> CodeGenFunction r (Value o)
composeRecord desc = flip (insertElement desc) (LLVM.value LLVM.undef)



-- | Layout of a pair: first component in field 0, second in field 1.
pair ::
   (C a, C b) =>
   Record r (LLVM.Struct (Struct a, (Struct b, ()))) (a, b)
pair =
   fmap (,) (element fst d0)
      <*> element snd d1

-- | Pairs are stored as two-field LLVM structures, see 'pair'.
instance (C a, C b) => C (a, b) where
   type Struct (a, b) = LLVM.Struct (Struct a, (Struct b, ()))
   decompose sm = decomposeRecord pair sm
   compose v = composeRecord pair v
   load ptr = loadRecord pair ptr
   store v ptr = storeRecord pair v ptr


-- | Layout of a triple: components in fields 0, 1 and 2, in order.
triple ::
   (C a, C b, C c) =>
   Record r (LLVM.Struct (Struct a, (Struct b, (Struct c, ())))) (a, b, c)
triple =
   liftA2 (,,) (element fst3 d0) (element snd3 d1)
      <*> element thd3 d2

-- | Triples are stored as three-field LLVM structures, see 'triple'.
instance (C a, C b, C c) => C (a, b, c) where
   type Struct (a, b, c) =
           LLVM.Struct (Struct a, (Struct b, (Struct c, ())))
   decompose sm = decomposeRecord triple sm
   compose v = composeRecord triple v
   load ptr = loadRecord triple ptr
   store v ptr = storeRecord triple v ptr


-- | Layout of a quadruple: components in fields 0 to 3, in order.
quadruple ::
   (C a, C b, C c, C d) =>
   Record r
      (LLVM.Struct (Struct a, (Struct b, (Struct c, (Struct d, ())))))
      (a, b, c, d)
quadruple =
   fmap (,,,) (element get0 d0)
      <*> element get1 d1
      <*> element get2 d2
      <*> element get3 d3
   where
      get0 (x,_,_,_) = x
      get1 (_,x,_,_) = x
      get2 (_,_,x,_) = x
      get3 (_,_,_,x) = x

-- | Quadruples are stored as four-field LLVM structures, see 'quadruple'.
instance (C a, C b, C c, C d) => C (a, b, c, d) where
   type Struct (a, b, c, d) =
           LLVM.Struct (Struct a, (Struct b, (Struct c, (Struct d, ()))))
   decompose sm = decomposeRecord quadruple sm
   compose v = composeRecord quadruple v
   load ptr = loadRecord quadruple ptr
   store v ptr = storeRecord quadruple v ptr


-- | A fixed-length sequence is stored as an LLVM array;
-- the element at position @i@ goes to array index @i@.
-- 'load' and 'store' use the class defaults.
instance
   (Unary.Natural n, C a,
    TypeNum.Natural (TypeNum.FromUnary n),
    TypeNum.Natural (TypeNum.FromUnary n TypeNum.:*: LLVM.SizeOf (Struct a)),
    LLVM.IsFirstClass (Struct a)) =>
      C (FixedLength.T n a) where
   type Struct (FixedLength.T n a) =
            LLVM.Array (TypeNum.FromUnary n) (Struct a)
   decompose arr =
      Trav.mapM
         (\i -> LLVM.extractvalue arr i >>= decompose)
         (iterateTrav (+ 1) (0 :: Word))
   compose xs =
      Fold.foldlM
         (\arr (x, i) -> do
             xc <- compose x
             LLVM.insertvalue arr xc i)
         (LLVM.value LLVM.undef)
         (FixedLength.zipWith (,) xs (iterateTrav (+ 1) (0 :: Word)))

-- | Fill the shape given by @pure ()@ with @start@, @next start@,
-- @next (next start)@, … from left to right.
iterateTrav :: (Applicative t, Trav.Traversable t) => (a -> a) -> a -> t a
iterateTrav next start =
   let step acc () = (next acc, acc)
   in  snd (Trav.mapAccumL step start (pure ()))


-- | Layout of an optional value: the flag in field 0, the payload in field 1.
maybe ::
   (C a) =>
   Record r (LLVM.Struct (Bool, (Struct a, ()))) (Maybe.T a)
maybe =
   fmap Maybe.Cons (element Maybe.isJust d0)
      <*> element Maybe.fromJust d1

-- | Optional values are stored as a flag plus payload structure, see 'maybe'.
instance (C a) => C (Maybe.T a) where
   type Struct (Maybe.T a) = LLVM.Struct (Bool, (Struct a, ()))
   decompose sm = decomposeRecord maybe sm
   compose v = composeRecord maybe v
   load ptr = loadRecord maybe ptr
   store v ptr = storeRecord maybe v ptr


-- | Layout of an alternative: the flag in field 0,
-- the left payload in field 1 and the right payload in field 2.
either ::
   (C a, C b) =>
   Record r (LLVM.Struct (Bool, (Struct a, (Struct b, ())))) (Either.T a b)
either =
   liftA2 Either.Cons (element Either.isRight d0) (element Either.fromLeft d1)
      <*> element Either.fromRight d2

-- | Alternatives are stored as a flag plus both payloads, see 'either'.
instance (C a, C b) => C (Either.T a b) where
   type Struct (Either.T a b) = LLVM.Struct (Bool, (Struct a, (Struct b, ())))
   decompose sm = decomposeRecord either sm
   compose v = composeRecord either v
   load ptr = loadRecord either ptr
   store v ptr = storeRecord either v ptr



-- | 'Scalar.T' shares the memory representation of the wrapped value.
instance (C a) => C (Scalar.T a) where
   type Struct (Scalar.T a) = Struct a
   load ptr = fmap Scalar.Cons (load ptr)
   store x ptr = store (Scalar.decons x) ptr
   decompose sm = fmap Scalar.Cons (decompose sm)
   compose x = compose (Scalar.decons x)


-- | A plain LLVM value is stored as itself;
-- 'decompose' and 'compose' just return their argument.
instance (IsSized a, LLVM.IsFirstClass a) => C (Value a) where
   type Struct (Value a) = a
   decompose v = return v
   compose v = return v
   load ptr = LLVM.load ptr
   store v ptr = LLVM.store v ptr


-- | Memory types of all fields of a nested-pair list of field values.
type family StructStruct s
-- end of the field list
type instance StructStruct () = ()
-- each field contributes its memory type
type instance StructStruct (x,xs) = (Struct x, StructStruct xs)

-- | A 'Struct.T' is stored as an LLVM structure with one field per element,
-- converted field by field starting at index 0.
instance
   (Struct.Phi s, Struct.Undefined s,
    LLVM.StructFields (StructStruct s),
    ConvertStruct (StructStruct s) TypeNum.D0 s) =>
      C (Struct.T s) where
   type Struct (Struct.T s) = LLVM.Struct (StructStruct s)
   decompose sm = fmap Struct.Cons (decomposeFields TypeNum.d0 sm)
   compose (Struct.Cons fields) = composeFields TypeNum.d0 fields

{- |
Conversion between a nested-pair list @rem@ of field values
and an LLVM structure with fields @s@,
covering the fields from index @i@ onwards.
-}
class ConvertStruct s i rem where
   -- | Extract the fields from index @i@ onwards.
   decomposeFields ::
      Proxy i -> Value (LLVM.Struct s) -> CodeGenFunction r rem
   -- | Build a structure whose fields from index @i@ onwards
   -- hold the given values.
   composeFields ::
      Proxy i -> rem -> CodeGenFunction r (Value (LLVM.Struct s))

-- | Convert the field at index @i@, then recurse on the next index.
instance
   (TypeNum.Natural i, LLVM.GetField s i, LLVM.FieldType s i ~ Struct a, C a,
    ConvertStruct s (TypeNum.Succ i) rem) =>
      ConvertStruct s i (a,rem) where
   decomposeFields i sm = do
      field <- LLVM.extractvalue sm i >>= decompose
      rest <- decomposeFields (decSucc i) sm
      return (field, rest)
   composeFields i (field, rest) = do
      partial <- composeFields (decSucc i) rest
      fieldValue <- compose field
      LLVM.insertvalue partial fieldValue i

-- | Advance a type-level index proxy to the following index.
decSucc :: Proxy n -> Proxy (TypeNum.Succ n)
decSucc p = case p of Proxy -> Proxy

-- | End of the field list: nothing is extracted, and composition starts
-- from an undefined structure into which the preceding fields are inserted.
instance (LLVM.StructFields s) => ConvertStruct s i () where
   composeFields _proxy _unit = pure $ LLVM.value LLVM.undef
   decomposeFields _proxy _struct = pure ()



-- | 'MultiValue.T' shares the memory representation of the wrapped value.
instance (MultiValue.C a, C (Tuple.ValueOf a)) => C (MultiValue.T a) where
   type Struct (MultiValue.T a) = Struct (Tuple.ValueOf a)
   load ptr = fmap MultiValue.Cons (load ptr)
   decompose sm = fmap MultiValue.Cons (decompose sm)
   store mv ptr = case mv of MultiValue.Cons a -> store a ptr
   compose mv = case mv of MultiValue.Cons a -> compose a

-- | 'MultiVector.T' shares the memory representation of the wrapped value.
instance
   (TypeNum.Positive n, MultiVector.C a, C (Tuple.VectorValueOf n a)) =>
      C (MultiVector.T n a) where
   type Struct (MultiVector.T n a) = Struct (Tuple.VectorValueOf n a)
   load ptr = fmap MultiVector.Cons (load ptr)
   decompose sm = fmap MultiVector.Cons (decompose sm)
   store mv ptr = case mv of MultiVector.Cons a -> store a ptr
   compose mv = case mv of MultiVector.Cons a -> compose a



-- | 'load' for a wrapper type: load the wrapped value and apply the wrapper.
loadNewtype ::
   (C a) =>
   (a -> llvmValue) ->
   Value (LLVM.Ptr (Struct a)) -> CodeGenFunction r llvmValue
loadNewtype wrap = fmap wrap . load

-- | 'store' for a wrapper type: unwrap the value and store what is inside.
storeNewtype ::
   (C a) =>
   (llvmValue -> a) ->
   llvmValue -> Value (LLVM.Ptr (Struct a)) -> CodeGenFunction r ()
storeNewtype unwrap = store . unwrap

-- | 'decompose' for a wrapper type: decompose the wrapped value
-- and apply the wrapper.
decomposeNewtype ::
   (C a) =>
   (a -> llvmValue) ->
   Value (Struct a) -> CodeGenFunction r llvmValue
decomposeNewtype wrap = fmap wrap . decompose

-- | 'compose' for a wrapper type: unwrap the value and compose what is inside.
composeNewtype ::
   (C a) =>
   (llvmValue -> a) ->
   llvmValue -> CodeGenFunction r (Value (Struct a))
composeNewtype unwrap = compose . unwrap