{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
module LLVM.Extra.Multi.Value.Memory where

import qualified LLVM.Extra.Multi.Value as MultiValue
import LLVM.Extra.MemoryPrivate (decomposeFromLoad, composeFromStore, )

import qualified LLVM.Core as LLVM
import LLVM.Core (CodeGenFunction, Value, )

import qualified Type.Data.Num.Decimal as TypeNum

import Foreign.StablePtr (StablePtr, )
import Foreign.Ptr (Ptr, FunPtr, castPtr, )

import Data.Complex (Complex, )
import Data.Word (Word8, Word16, Word32, Word64, )
import Data.Int (Int8, Int16, Int32, Int64, )

import Control.Applicative (pure, liftA2, liftA3, (<*>), )


-- | Multi-value types that can be moved between LLVM registers and memory.
--
-- 'Struct' names the LLVM type used as the in-memory representation.
-- An instance must provide one of 'load' / 'decompose' and one of
-- 'store' / 'compose'; the other member of each pair falls back to the
-- default defined here in terms of its partner.
class (MultiValue.C a, LLVM.IsSized (Struct a)) => C a where
   {-# MINIMAL (load|decompose), (store|compose) #-}
   -- | LLVM type of the in-memory representation of @a@.
   type Struct a :: *
   -- | Read a value through a pointer.
   -- Default: load the whole 'Struct', then 'decompose' it.
   load :: Value (Ptr (Struct a)) -> CodeGenFunction r (MultiValue.T a)
   load ptr  =  LLVM.load ptr >>= decompose
   -- | Write a value through a pointer.
   -- Default: 'compose' the value into a 'Struct', then store that.
   store :: MultiValue.T a -> Value (Ptr (Struct a)) -> CodeGenFunction r ()
   store r ptr  =  compose r >>= \s -> LLVM.store s ptr
   -- | Turn a 'Struct' register value into a multi-value.
   -- Default: derived from 'load' via 'decomposeFromLoad'.
   decompose :: Value (Struct a) -> CodeGenFunction r (MultiValue.T a)
   decompose = decomposeFromLoad load
   -- | Turn a multi-value into a 'Struct' register value.
   -- Default: derived from 'store' via 'composeFromStore'.
   compose :: MultiValue.T a -> CodeGenFunction r (Value (Struct a))
   compose = composeFromStore store

-- | 'Float' is kept in memory as itself, so every method is a primitive helper.
instance C Float where
   type Struct Float = Float
   compose = composePrimitive
   decompose = decomposePrimitive
   store = storePrimitive
   load = loadPrimitive

-- | 'Double' is kept in memory as itself, so every method is a primitive helper.
instance C Double where
   type Struct Double = Double
   compose = composePrimitive
   decompose = decomposePrimitive
   store = storePrimitive
   load = loadPrimitive

-- | 'Word8' is kept in memory as itself, so every method is a primitive helper.
instance C Word8 where
   type Struct Word8 = Word8
   compose = composePrimitive
   decompose = decomposePrimitive
   store = storePrimitive
   load = loadPrimitive

-- | 'Word16' is kept in memory as itself, so every method is a primitive helper.
instance C Word16 where
   type Struct Word16 = Word16
   compose = composePrimitive
   decompose = decomposePrimitive
   store = storePrimitive
   load = loadPrimitive

-- | 'Word32' is kept in memory as itself, so every method is a primitive helper.
instance C Word32 where
   type Struct Word32 = Word32
   compose = composePrimitive
   decompose = decomposePrimitive
   store = storePrimitive
   load = loadPrimitive

-- | 'Word64' is kept in memory as itself, so every method is a primitive helper.
instance C Word64 where
   type Struct Word64 = Word64
   compose = composePrimitive
   decompose = decomposePrimitive
   store = storePrimitive
   load = loadPrimitive

-- | 'Int8' is kept in memory as itself, so every method is a primitive helper.
instance C Int8 where
   type Struct Int8 = Int8
   compose = composePrimitive
   decompose = decomposePrimitive
   store = storePrimitive
   load = loadPrimitive

-- | 'Int16' is kept in memory as itself, so every method is a primitive helper.
instance C Int16 where
   type Struct Int16 = Int16
   compose = composePrimitive
   decompose = decomposePrimitive
   store = storePrimitive
   load = loadPrimitive

-- | 'Int32' is kept in memory as itself, so every method is a primitive helper.
instance C Int32 where
   type Struct Int32 = Int32
   compose = composePrimitive
   decompose = decomposePrimitive
   store = storePrimitive
   load = loadPrimitive

-- | 'Int64' is kept in memory as itself, so every method is a primitive helper.
instance C Int64 where
   type Struct Int64 = Int64
   compose = composePrimitive
   decompose = decomposePrimitive
   store = storePrimitive
   load = loadPrimitive

-- | A data pointer is kept in memory as itself, so every method is a
-- primitive helper.
instance (LLVM.IsType a) => C (Ptr a) where
   type Struct (Ptr a) = Ptr a
   compose = composePrimitive
   decompose = decomposePrimitive
   store = storePrimitive
   load = loadPrimitive

-- | A function pointer is kept in memory as itself, so every method is a
-- primitive helper.
instance (LLVM.IsFunction a) => C (FunPtr a) where
   type Struct (FunPtr a) = FunPtr a
   compose = composePrimitive
   decompose = decomposePrimitive
   store = storePrimitive
   load = loadPrimitive

-- | A 'StablePtr' is kept in memory as itself, so every method is a
-- primitive helper.
instance C (StablePtr a) where
   type Struct (StablePtr a) = StablePtr a
   compose = composePrimitive
   decompose = decomposePrimitive
   store = storePrimitive
   load = loadPrimitive


-- | 'load' for types whose multi-value representation is a single LLVM
-- 'Value' of the same type: load the scalar and wrap it in 'MultiValue.Cons'.
loadPrimitive ::
   (MultiValue.Repr Value a ~ Value a) =>
   Value (Ptr a) -> CodeGenFunction r (MultiValue.T a)
loadPrimitive ptr = do
   scalar <- LLVM.load ptr
   return (MultiValue.Cons scalar)

-- | 'store' for types whose multi-value representation is a single LLVM
-- 'Value' of the same type: unwrap 'MultiValue.Cons' and store the scalar.
storePrimitive ::
   (MultiValue.Repr Value a ~ Value a) =>
   MultiValue.T a -> Value (Ptr a) -> CodeGenFunction r ()
storePrimitive mv =
   case mv of
      MultiValue.Cons scalar -> LLVM.store scalar

-- | 'decompose' for single-'Value' types: wrap the value in
-- 'MultiValue.Cons'. No code is emitted.
decomposePrimitive ::
   (MultiValue.Repr Value a ~ Value a) =>
   Value a -> CodeGenFunction r (MultiValue.T a)
decomposePrimitive scalar = return (MultiValue.Cons scalar)

-- | 'compose' for single-'Value' types: unwrap 'MultiValue.Cons'.
-- No code is emitted.
composePrimitive ::
   (MultiValue.Repr Value a ~ Value a) =>
   MultiValue.T a -> CodeGenFunction r (Value a)
composePrimitive mv =
   case mv of
      MultiValue.Cons scalar -> return scalar


-- | The unit type is represented by the empty LLVM struct; all methods
-- delegate to the @*Unit@ helpers.
instance C () where
   type Struct () = LLVM.Struct ()
   compose = composeUnit
   decompose = decomposeUnit
   store = storeUnit
   load = loadUnit

-- | Loading a unit value emits no code: the pointer is ignored and
-- @MultiValue.Cons ()@ is returned.
loadUnit ::
   (MultiValue.Repr Value a ~ ()) =>
   Value (Ptr (LLVM.Struct ())) -> CodeGenFunction r (MultiValue.T a)
loadUnit = const (return (MultiValue.Cons ()))

-- | Storing a unit value emits no code; both arguments are ignored.
storeUnit ::
   MultiValue.T a -> Value (Ptr (LLVM.Struct ())) -> CodeGenFunction r ()
storeUnit _ = const (return ())

-- | Decomposing the empty struct ignores its value and returns
-- @MultiValue.Cons ()@.
decomposeUnit ::
   (MultiValue.Repr Value a ~ ()) =>
   Value (LLVM.Struct ()) -> CodeGenFunction r (MultiValue.T a)
decomposeUnit = const (return (MultiValue.Cons ()))

-- | Composing a unit value ignores the input and yields the constant
-- empty struct.
composeUnit ::
   MultiValue.T a -> CodeGenFunction r (Value (LLVM.Struct ()))
composeUnit = const (return (LLVM.value (LLVM.constStruct ())))


-- | A complex number is stored as a two-field LLVM struct. Field 0 holds
-- the first component of 'MultiValue.deconsComplex' (the real part) and
-- field 1 holds the second (the imaginary part). 'load' and 'store' use
-- the class defaults.
instance (C a) => C (Complex a) where
   type Struct (Complex a) = LLVM.Struct (Struct a, (Struct a, ()))
   decompose struct =
      liftA2 MultiValue.consComplex
         (LLVM.extractvalue struct TypeNum.d0 >>= decompose)
         (LLVM.extractvalue struct TypeNum.d1 >>= decompose)
   compose z =
      case MultiValue.deconsComplex z of
         (re, im) -> do
            reField <- compose re
            imField <- compose im
            -- start from an undefined struct and fill in each field in turn
            withRe <- LLVM.insertvalue (LLVM.value LLVM.undef) reField TypeNum.d0
            LLVM.insertvalue withRe imField TypeNum.d1


-- | A pair is stored as a two-field LLVM struct whose fields are the
-- 'Struct' representations of the components, in order. 'load' and
-- 'store' use the class defaults.
instance (C a, C b) => C (a,b) where
   type Struct (a,b) = LLVM.Struct (Struct a, (Struct b, ()))
   decompose struct =
      liftA2 MultiValue.zip
         (LLVM.extractvalue struct TypeNum.d0 >>= decompose)
         (LLVM.extractvalue struct TypeNum.d1 >>= decompose)
   compose pair =
      case MultiValue.unzip pair of
         (x, y) -> do
            fieldX <- compose x
            fieldY <- compose y
            -- fill an undefined struct field by field
            LLVM.insertvalue (LLVM.value LLVM.undef) fieldX TypeNum.d0
               >>= \partial -> LLVM.insertvalue partial fieldY TypeNum.d1

-- | A triple is stored as a three-field LLVM struct whose fields are the
-- 'Struct' representations of the components, in order. 'load' and
-- 'store' use the class defaults.
instance (C a, C b, C c) => C (a,b,c) where
   type Struct (a,b,c) = LLVM.Struct (Struct a, (Struct b, (Struct c, ())))
   decompose struct =
      liftA3 MultiValue.zip3
         (LLVM.extractvalue struct TypeNum.d0 >>= decompose)
         (LLVM.extractvalue struct TypeNum.d1 >>= decompose)
         (LLVM.extractvalue struct TypeNum.d2 >>= decompose)
   compose triple =
      case MultiValue.unzip3 triple of
         (x, y, z) -> do
            fieldX <- compose x
            fieldY <- compose y
            fieldZ <- compose z
            -- fill an undefined struct field by field
            s0 <- LLVM.insertvalue (LLVM.value LLVM.undef) fieldX TypeNum.d0
            s1 <- LLVM.insertvalue s0 fieldY TypeNum.d1
            LLVM.insertvalue s1 fieldZ TypeNum.d2

-- | A quadruple is stored as a four-field LLVM struct whose fields are the
-- 'Struct' representations of the components, in order. 'load' and
-- 'store' use the class defaults.
instance (C a, C b, C c, C d) => C (a,b,c,d) where
   type Struct (a,b,c,d) = LLVM.Struct (Struct a, (Struct b, (Struct c, (Struct d, ()))))
   decompose struct =
      pure MultiValue.zip4
         <*> (LLVM.extractvalue struct TypeNum.d0 >>= decompose)
         <*> (LLVM.extractvalue struct TypeNum.d1 >>= decompose)
         <*> (LLVM.extractvalue struct TypeNum.d2 >>= decompose)
         <*> (LLVM.extractvalue struct TypeNum.d3 >>= decompose)
   compose quad =
      case MultiValue.unzip4 quad of
         (w, x, y, z) -> do
            fieldW <- compose w
            fieldX <- compose x
            fieldY <- compose y
            fieldZ <- compose z
            -- fill an undefined struct field by field
            s0 <- LLVM.insertvalue (LLVM.value LLVM.undef) fieldW TypeNum.d0
            s1 <- LLVM.insertvalue s0 fieldX TypeNum.d1
            s2 <- LLVM.insertvalue s1 fieldY TypeNum.d2
            LLVM.insertvalue s2 fieldZ TypeNum.d3


-- | Retype a pointer to @a@ as a pointer to its 'Struct' representation.
-- This is only a 'castPtr'; no data is read or converted.
castStructPtr :: Ptr a -> Ptr (Struct a)
castStructPtr p = castPtr p