{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
module LLVM.Extra.Storable.Private where

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.ArithmeticPrivate as A
import qualified LLVM.Util.Proxy as LP
import qualified LLVM.Core as LLVM
import LLVM.Core (CodeGenFunction, Value)

import qualified Type.Data.Num.Decimal as TypeNum

import qualified Control.Monad.Trans.Class as MT
import qualified Control.Monad.Trans.Reader as MR
import qualified Control.Monad.Trans.State as MS
import qualified Control.Applicative.HT as App
import qualified Control.Functor.HT as FuncHT
import Control.Monad (foldM, replicateM, replicateM_, (<=<))
import Control.Applicative (Applicative, pure)

import qualified Foreign.Storable.Record.Tuple as StoreTuple
import qualified Foreign.Storable as Store
import Foreign.Storable.FixedArray (roundUp)
import Foreign.Ptr (Ptr)

import qualified Data.NonEmpty.Class as NonEmptyC
import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import Data.Orphans ()
import Data.Complex (Complex)
import Data.Word (Word8, Word16, Word32, Word64, Word)
import Data.Int  (Int8,  Int16,  Int32,  Int64)
import Data.Bool8 (Bool8)



{- |
Types whose values can be read from and written to memory
from within generated LLVM code.

The memory layout is the one described by the 'Store.Storable' superclass
(this module advances pointers by 'Store.sizeOf').
Not every 'Store.Storable' type has a compatible LLVM type,
let alone a single LLVM type that is compatible on all platforms,
hence loading and storing are explicit methods.
-}
class
   (Store.Storable a,
    Tuple.Value a,
    Tuple.Phi (Tuple.ValueOf a),
    Tuple.Undefined (Tuple.ValueOf a)) =>
      C a where
   -- | Read the value stored at the given address.
   load ::
      Value (Ptr a) ->
      CodeGenFunction r (Tuple.ValueOf a)
   -- | Write a value to the given address.
   store ::
      Tuple.ValueOf a ->
      Value (Ptr a) ->
      CodeGenFunction r ()

-- | Write a value at the given address and return the address
-- directly behind it, i.e. advanced by @sizeOf a@ bytes.
storeNext ::
   (C a, Tuple.ValueOf a ~ al, Value (Ptr a) ~ ptr) =>
   al -> ptr -> CodeGenFunction r ptr
storeNext x p = do
   store x p
   incrementPtr p

-- | Read the value at an address, transform it with a code-generating
-- action and write the result back to the same address.
modify ::
   (C a, Tuple.ValueOf a ~ al) =>
   (al -> CodeGenFunction r al) ->
   Value (Ptr a) -> CodeGenFunction r ()
modify f ptr = do
   old <- load ptr
   new <- f old
   store new ptr


-- | Like 'load', but wraps the result in a 'MultiValue.T'.
loadMultiValue ::
   (C a) => Value (Ptr a) -> CodeGenFunction r (MultiValue.T a)
loadMultiValue ptr = do
   x <- load ptr
   return (MultiValue.Cons x)

-- | Like 'store', but takes the value wrapped in a 'MultiValue.T'.
storeMultiValue ::
   (C a) => MultiValue.T a -> Value (Ptr a) -> CodeGenFunction r ()
storeMultiValue (MultiValue.Cons x) = store x

-- | Like 'storeNext', but takes the value wrapped in a 'MultiValue.T'.
storeNextMultiValue ::
   (C a) => MultiValue.T a -> Value (Ptr a) -> CodeGenFunction r (Value (Ptr a))
storeNextMultiValue (MultiValue.Cons x) = storeNext x

-- | Like 'modify', but the transformation works on 'MultiValue.T'.
modifyMultiValue ::
   (C a) =>
   (MultiValue.T a -> CodeGenFunction r (MultiValue.T a)) ->
   Value (Ptr a) -> CodeGenFunction r ()
modifyMultiValue f ptr = do
   old <- loadMultiValue ptr
   new <- f old
   storeMultiValue new ptr


-- | Load a value with a single LLVM @load@ through a bit-cast pointer.
loadPrimitive ::
   (LLVM.Storable a) => Value (Ptr a) -> CodeGenFunction r (Value a)
loadPrimitive ptr = do
   llvmPtr <- LLVM.bitcast ptr
   LLVM.load llvmPtr

-- | Store a value with a single LLVM @store@ through a bit-cast pointer.
storePrimitive ::
   (LLVM.Storable a) => Value a -> Value (Ptr a) -> CodeGenFunction r ()
storePrimitive x ptr = do
   llvmPtr <- LLVM.bitcast ptr
   LLVM.store x llvmPtr

{-
Plain numeric types:
memory is accessed with a single LLVM load or store
through a bit-cast pointer (see 'loadPrimitive' and 'storePrimitive').
-}
instance C Float where
   load = loadPrimitive
   store = storePrimitive

instance C Double where
   load = loadPrimitive
   store = storePrimitive

instance C Word where
   load = loadPrimitive
   store = storePrimitive

instance C Word8 where
   load = loadPrimitive
   store = storePrimitive

instance C Word16 where
   load = loadPrimitive
   store = storePrimitive

instance C Word32 where
   load = loadPrimitive
   store = storePrimitive

instance C Word64 where
   load = loadPrimitive
   store = storePrimitive

instance C Int where
   load = loadPrimitive
   store = storePrimitive

instance C Int8 where
   load = loadPrimitive
   store = storePrimitive

instance C Int16 where
   load = loadPrimitive
   store = storePrimitive

instance C Int32 where
   load = loadPrimitive
   store = storePrimitive

instance C Int64 where
   load = loadPrimitive
   store = storePrimitive

{- |
Not a very efficient implementation,
because we want to adapt to @sizeOf Bool@ dynamically:
the number of bytes is taken from the Haskell 'Store.Storable' instance
when the code is generated.
Unfortunately, LLVM-9's optimizer does not recognize the instruction pattern.
Better use 'Bool8' for booleans.

Loading reads all @sizeOf Bool@ bytes and yields 'True'
if any of them is non-zero.
Storing writes the same byte to each of the @sizeOf Bool@ positions:
all ones (sign extension of the boolean) for 'True' and zero for 'False'.
-}
instance C Bool where
   load ptr = do
      bytePtr <- castToBytePtr ptr
      -- Read the bytes one after another: 'incPtrState' yields the current
      -- address and moves the state on to the next byte.
      bytes <-
         flip MS.evalStateT bytePtr $
            replicateM (Store.sizeOf (False :: Bool))
               (MT.lift . LLVM.load =<< incPtrState)
      let zero = LLVM.valueOf 0
      -- OR all bytes together; the result is non-zero iff any byte is set.
      mask <- foldM A.or zero bytes
      A.cmp LLVM.CmpNE mask zero
   store b ptr = do
      bytePtr <- castToBytePtr ptr
      -- Sign extension turns True into an all-ones byte and False into 0.
      byte <- LLVM.sext b
      -- Write that byte to every byte position of the Haskell Bool.
      flip MS.evalStateT bytePtr $
         replicateM_ (Store.sizeOf (False :: Bool))
            (MT.lift . LLVM.store byte =<< incPtrState)

-- | Return the current byte pointer and advance the state
-- to the next array element, i.e. the next byte.
incPtrState :: MS.StateT BytePtr (CodeGenFunction r) BytePtr
incPtrState =
   MS.StateT $ \cur ->
      A.advanceArrayElementPtr cur >>= \next -> return (cur, next)

-- | A 'Bool8' occupies one byte; any non-zero byte is read as 'True',
-- and 'True' is written as the zero-extended boolean.
instance C Bool8 where
   load ptr = do
      bytePtr <- castToBytePtr ptr
      byte <- LLVM.load bytePtr
      A.cmp LLVM.CmpNE (LLVM.valueOf 0) byte
   store b ptr = do
      byte <- LLVM.zext b
      bytePtr <- castToBytePtr ptr
      LLVM.store byte bytePtr

-- | Complex numbers are accessed component by component,
-- through their 'Applicative' and 'Fold.Foldable' structure.
instance (C a) => C (Complex a) where
   load = loadApplicative
   store = storeFoldable



-- | Tuples wrapped in 'StoreTuple.Tuple' delegate to the 'Tuple' class.
instance (Tuple tuple) => C (StoreTuple.Tuple tuple) where
   load ptr = loadTuple ptr
   store x ptr = storeTuple x ptr

{- |
Tuples whose elements can be loaded and stored individually.
The instances in this module place every element at the next offset
that is a multiple of the element's alignment (see 'elementOffset').
-}
class
   (StoreTuple.Storable tuple,
    Tuple.Value tuple,
    Tuple.Phi (Tuple.ValueOf tuple),
    Tuple.Undefined (Tuple.ValueOf tuple)) =>
      Tuple tuple where
   -- | Read all tuple elements starting at the given address.
   loadTuple ::
      Value (Ptr (StoreTuple.Tuple tuple)) ->
      CodeGenFunction r (Tuple.ValueOf tuple)
   -- | Write all tuple elements starting at the given address.
   storeTuple ::
      Tuple.ValueOf tuple ->
      Value (Ptr (StoreTuple.Tuple tuple)) ->
      CodeGenFunction r ()

-- Pairs: the first element is accessed before the second.
instance (C a, C b) => Tuple (a,b) where
   loadTuple ptr =
      runElements ptr $ App.lift2 (,) (loadElement pa) (loadElement pb)
      where (pa, pb) = FuncHT.unzip (proxyFromElement3 ptr)
   storeTuple (a,b) ptr =
      runElements ptr $ storeElement pa a >> storeElement pb b
      where (pa, pb) = FuncHT.unzip (proxyFromElement3 ptr)

-- Triples: elements are accessed from left to right.
instance (C a, C b, C c) => Tuple (a,b,c) where
   loadTuple ptr =
      runElements ptr $
         App.lift3 (,,) (loadElement pa) (loadElement pb) (loadElement pc)
      where (pa, pb, pc) = FuncHT.unzip3 (proxyFromElement3 ptr)
   storeTuple (a,b,c) ptr =
      runElements ptr $ do
         storeElement pa a
         storeElement pb b
         storeElement pc c
      where (pa, pb, pc) = FuncHT.unzip3 (proxyFromElement3 ptr)

-- | Run an element-access action with the tuple's base address
-- (as a byte pointer) in the environment and a byte offset starting at 0.
runElements ::
   Value (Ptr a) ->
   MR.ReaderT BytePtr (MS.StateT Int (CodeGenFunction r)) c ->
   CodeGenFunction r c
runElements ptr act =
   castToBytePtr ptr >>= \base ->
   MS.evalStateT (MR.runReaderT act base) 0

-- | Load the next tuple element, whose type is selected by the proxy.
loadElement ::
   (C a) =>
   LP.Proxy a ->
   MR.ReaderT BytePtr (MS.StateT Int (CodeGenFunction r)) (Tuple.ValueOf a)
loadElement proxy = do
   ptr <- elementPtr proxy
   MT.lift (MT.lift (load ptr))

-- | Store a value as the next tuple element, whose type is selected by the proxy.
storeElement ::
   (C a) =>
   LP.Proxy a -> Tuple.ValueOf a ->
   MR.ReaderT BytePtr (MS.StateT Int (CodeGenFunction r)) ()
storeElement proxy x = do
   ptr <- elementPtr proxy
   MT.lift (MT.lift (store x ptr))

-- | Compute the address of the next tuple element:
-- the base byte pointer from the environment plus the element's
-- aligned offset, which also advances the running offset.
elementPtr ::
   (C a) =>
   LP.Proxy a ->
   MR.ReaderT BytePtr
      (MS.StateT Int (CodeGenFunction r)) (LLVM.Value (Ptr a))
elementPtr proxy = do
   base <- MR.ask
   offset <- MT.lift (elementOffset proxy)
   MT.lift $ MT.lift $
      LLVM.getElementPtr base (offset, ()) >>= castFromBytePtr

-- | Round the running byte offset up to the element's alignment,
-- return that offset and move the running offset past the element.
elementOffset ::
   (Monad m, Store.Storable a) => LP.Proxy a -> MS.StateT Int m Int
elementOffset proxy = do
   let sample = elementFromProxy proxy
       elemAlign = Store.alignment sample
       elemSize = Store.sizeOf sample
   start <- MS.gets (roundUp elemAlign)
   MS.put (start + elemSize)
   return start


-- | An @n@-element vector is accessed as @n@ consecutive elements of type @a@,
-- each @sizeOf a@ bytes apart, and converted to or from the vector value
-- by the 'Vector' class of the element type.
instance
   (TypeNum.Positive n, Vector a, Tuple.VectorValue n a,
    Tuple.Phi (Tuple.VectorValueOf n a)) =>
      C (LLVM.Vector n a) where
   load ptr = do
      elems <- loadApplicative ptr
      assembleVector (proxyFromElement3 ptr) elems
   store v ptr = do
      elems <- disassembleVector (proxyFromElement3 ptr) v
      storeFoldable elems ptr

{- |
Element types of 'LLVM.Vector' that can be converted between
a vector of individually loaded element values
and the vector value representation.
-}
class (C a) => Vector a where
   -- | Combine individual element values into a vector value.
   assembleVector ::
      (TypeNum.Positive n) =>
      LP.Proxy a ->
      LLVM.Vector n (Tuple.ValueOf a) ->
      CodeGenFunction r (Tuple.VectorValueOf n a)
   -- | Split a vector value into individual element values.
   disassembleVector ::
      (TypeNum.Positive n) =>
      LP.Proxy a ->
      Tuple.VectorValueOf n a ->
      CodeGenFunction r (LLVM.Vector n (Tuple.ValueOf a))

-- | Build an LLVM vector value by inserting the elements one by one,
-- at indices 0, 1, 2, ..., into an undefined vector.
assemblePrimitive ::
   (TypeNum.Positive n, LLVM.IsPrimitive a) =>
   LLVM.Vector n (Value a) -> CodeGenFunction r (Value (LLVM.Vector n a))
assemblePrimitive xs =
   foldM
      (\vec (k, x) -> LLVM.insertelement vec x (LLVM.valueOf k))
      (LLVM.value LLVM.undef)
      (zip [0 ..] (Fold.toList xs))

-- | Extract every element of an LLVM vector value, in index order.
disassemblePrimitive ::
   (TypeNum.Positive n, LLVM.IsPrimitive a) =>
   Value (LLVM.Vector n a) -> CodeGenFunction r (LLVM.Vector n (Value a))
disassemblePrimitive v =
   Trav.forM indices $ \k -> LLVM.extractelement v (LLVM.valueOf k)

-- | The shape given by 'pure', with positions numbered 0, 1, 2, ...
-- in traversal order.
indices :: (Applicative f, Trav.Traversable f) => f Word32
indices = snd $ Trav.mapAccumL (\k () -> (k + 1, k)) 0 (pure ())

{-
Primitive element types are represented by a single LLVM vector value;
elements are inserted and extracted one at a time
(see 'assemblePrimitive' and 'disassemblePrimitive').
-}
instance Vector Float where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Double where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Word where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Word8 where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Word16 where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Word32 where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Word64 where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Int where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Int8 where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Int16 where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Int32 where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Int64 where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Bool where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Bool8 where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v


-- | Vectors of tuples are split into (or merged from) one vector per
-- tuple component, see 'TupleVector'.
instance
   (Tuple tuple, TupleVector tuple) =>
      Vector (StoreTuple.Tuple tuple) where
   assembleVector proxy = deinterleave (fmap StoreTuple.getTuple proxy)
   disassembleVector proxy = interleave (fmap StoreTuple.getTuple proxy)


-- | Conversion between a vector of tuples and a tuple of vectors.
class TupleVector a where
   -- | Turn a vector of tuple values into a tuple of vector values.
   deinterleave ::
      (TypeNum.Positive n) =>
      LP.Proxy a ->
      LLVM.Vector n (Tuple.ValueOf a) ->
      CodeGenFunction r (Tuple.VectorValueOf n a)
   -- | Turn a tuple of vector values into a vector of tuple values.
   interleave ::
      (TypeNum.Positive n) =>
      LP.Proxy a ->
      Tuple.VectorValueOf n a ->
      CodeGenFunction r (LLVM.Vector n (Tuple.ValueOf a))

-- Pairs: each component is converted by its own 'Vector' instance,
-- the first component before the second.
instance (Vector a, Vector b) => TupleVector (a,b) where
   deinterleave proxy v =
      App.lift2 (,) (assembleVector pa a) (assembleVector pb b)
      where (pa, pb) = FuncHT.unzip proxy
            (a, b) = FuncHT.unzip v
   interleave proxy (a,b) =
      App.lift2 (App.lift2 (,))
         (disassembleVector pa a)
         (disassembleVector pb b)
      where (pa, pb) = FuncHT.unzip proxy

-- Triples: each component is converted by its own 'Vector' instance,
-- from left to right.
instance (Vector a, Vector b, Vector c) => TupleVector (a,b,c) where
   deinterleave proxy v =
      App.lift3 (,,)
         (assembleVector pa a)
         (assembleVector pb b)
         (assembleVector pc c)
      where (pa, pb, pc) = FuncHT.unzip3 proxy
            (a, b, c) = FuncHT.unzip3 v
   interleave proxy (a,b,c) =
      App.lift3 (App.lift3 (,,))
         (disassembleVector pa a)
         (disassembleVector pb b)
         (disassembleVector pc c)
      where (pa, pb, pc) = FuncHT.unzip3 proxy


{-
The instance Storable () is available since base-4.9/GHC-8.0.
For older versions we need the orphan instance from Data.Orphans.
-}
-- | Nothing is read or written for the unit type.
instance C () where
   load = const (return ())
   store () = const (return ())


-- | Implement 'load' for a wrapper type by loading the wrapped type
-- from the same address and wrapping the result with @wrapl@.
-- The @wrap@ function only fixes the types; it is never called.
loadNewtype ::
   (C a, Tuple.ValueOf a ~ al) =>
   (a -> wrapped) ->
   (al -> wrappedl) ->
   Value (Ptr wrapped) -> CodeGenFunction r wrappedl
loadNewtype wrap wrapl ptr = do
   innerPtr <- rmapPtr wrap ptr
   fmap wrapl (load innerPtr)

-- | Implement 'store' for a wrapper type by unwrapping the value with
-- @unwrapl@ and storing it as the wrapped type at the same address.
-- The @wrap@ function only fixes the types; it is never called.
storeNewtype ::
   (C a, Tuple.ValueOf a ~ al) =>
   (a -> wrapped) ->
   (wrappedl -> al) ->
   wrappedl -> Value (Ptr wrapped) -> CodeGenFunction r ()
storeNewtype wrap unwrapl y ptr =
   rmapPtr wrap ptr >>= store (unwrapl y)

-- | Reinterpret a pointer to @b@ as a pointer to @a@ (an LLVM bitcast).
-- The function argument only determines the types.
rmapPtr :: (a -> b) -> Value (Ptr b) -> CodeGenFunction r (Value (Ptr a))
rmapPtr _ ptr = LLVM.bitcast ptr


-- | Load one element per position of the 'NonEmptyC.repeat' shape of @f@,
-- from consecutive addresses that are @sizeOf a@ bytes apart.
loadTraversable ::
   (NonEmptyC.Repeat f, Trav.Traversable f, C a, Tuple.ValueOf a ~ al) =>
   Value (Ptr (f a)) -> CodeGenFunction r (f al)
loadTraversable ptr = do
   elemPtr <- castElementPtr ptr
   MS.evalStateT (Trav.sequence (NonEmptyC.repeat loadState)) elemPtr

-- | Load one element per position of the 'pure' shape of @f@,
-- from consecutive addresses that are @sizeOf a@ bytes apart.
loadApplicative ::
   (Applicative f, Trav.Traversable f, C a, Tuple.ValueOf a ~ al) =>
   Value (Ptr (f a)) -> CodeGenFunction r (f al)
loadApplicative ptr = do
   elemPtr <- castElementPtr ptr
   MS.evalStateT (Trav.sequence (pure loadState)) elemPtr

-- | Load the element at the current pointer and advance the pointer
-- by one element.
loadState ::
   (C a, Tuple.ValueOf a ~ al) =>
   MS.StateT (Value (Ptr a)) (CodeGenFunction r) al
loadState = do
   ptr <- advancePtrState
   MT.lift (load ptr)


-- | Store all elements of a container, in 'Fold.Foldable' order,
-- at consecutive addresses that are @sizeOf a@ bytes apart.
storeFoldable ::
   (Fold.Foldable f, C a, Tuple.ValueOf a ~ al) =>
   f al -> Value (Ptr (f a)) -> CodeGenFunction r ()
storeFoldable xs ptr = do
   elemPtr <- castElementPtr ptr
   MS.evalStateT (Fold.forM_ xs storeState) elemPtr

-- | Store a value at the current pointer and advance the pointer
-- by one element.
storeState ::
   (C a, Tuple.ValueOf a ~ al) =>
   al -> MS.StateT (Value (Ptr a)) (CodeGenFunction r) ()
storeState x = do
   ptr <- advancePtrState
   MT.lift (store x ptr)


-- | Replace the state by the result of a monadic action on it
-- and return the previous state.
update :: (Monad m) => (a -> m a) -> MS.StateT a m a
update f = do
   old <- MS.get
   new <- MT.lift (f old)
   MS.put new
   return old

-- | Return the current element pointer and advance the state by one element.
advancePtrState ::
   (C a, Tuple.ValueOf a ~ al, Value (Ptr a) ~ ptr) =>
   MS.StateT ptr (CodeGenFunction r) ptr
advancePtrState = update incrementPtr

-- | Advance a pointer by a run-time number of elements,
-- i.e. by @n * sizeOf a@ bytes.
advancePtr ::
   (Store.Storable a, Value (Ptr a) ~ ptr) =>
   Value Int -> ptr -> CodeGenFunction r ptr
advancePtr n ptr =
   A.mul n (LLVM.valueOf (Store.sizeOf (elementFromPtr ptr))) >>= \bytes ->
   addPointer bytes ptr

-- | Advance a pointer by a number of elements known at code-generation time.
advancePtrStatic ::
   (Store.Storable a, Value (Ptr a) ~ ptr) =>
   Int -> ptr -> CodeGenFunction r ptr
advancePtrStatic n ptr = addPointer (LLVM.valueOf bytes) ptr
   where bytes = n * Store.sizeOf (elementFromPtr ptr)

-- | Move a pointer forward by one element.
incrementPtr ::
   (Store.Storable a, Value (Ptr a) ~ ptr) =>
   ptr -> CodeGenFunction r ptr
incrementPtr ptr = advancePtrStatic 1 ptr

-- | Move a pointer backward by one element.
decrementPtr ::
   (Store.Storable a, Value (Ptr a) ~ ptr) =>
   ptr -> CodeGenFunction r ptr
decrementPtr ptr = advancePtrStatic (negate 1) ptr

-- | Add a byte offset to a pointer, keeping its element type.
addPointer :: Value Int -> Value (Ptr a) -> CodeGenFunction r (Value (Ptr a))
addPointer k ptr =
   castToBytePtr ptr >>= \bytePtr ->
   LLVM.getElementPtr bytePtr (k, ()) >>= castFromBytePtr

-- | An LLVM pointer to single bytes, used for byte-offset arithmetic.
type BytePtr = LLVM.Value (LLVM.Ptr Word8)

-- | View any pointer as a byte pointer (LLVM bitcast).
castToBytePtr :: Value (Ptr a) -> CodeGenFunction r BytePtr
castToBytePtr ptr = LLVM.bitcast ptr

-- | View a byte pointer as a pointer to any element type (LLVM bitcast).
castFromBytePtr :: BytePtr -> CodeGenFunction r (Value (Ptr a))
castFromBytePtr bytePtr = LLVM.bitcast bytePtr

-- | View a pointer to a container @f a@ as a pointer to its first element
-- (LLVM bitcast).
castElementPtr :: Value (Ptr (f a)) -> CodeGenFunction r (Value (Ptr a))
castElementPtr ptr = LLVM.bitcast ptr


-- | 'Store.sizeOf' for the type selected by a proxy.
sizeOf :: (Store.Storable a) => LP.Proxy a -> Int
sizeOf proxy = Store.sizeOf (elementFromProxy proxy)

-- | A placeholder element of the pointer's target type.
-- It is only used to select 'Store.Storable' methods such as 'Store.sizeOf'
-- from the pointer type and must never be evaluated.
-- The error message names this function (it previously said
-- "elementFromProxy", pointing debugging at the wrong helper).
elementFromPtr :: LLVM.Value (Ptr a) -> a
elementFromPtr _ = error "elementFromPtr"

-- | A placeholder element of the proxied type.
-- It is only used to select 'Store.Storable' methods and must never be evaluated.
elementFromProxy :: LP.Proxy a -> a
elementFromProxy proxy =
   case proxy of
      LP.Proxy -> error "elementFromProxy"

-- | A proxy for the innermost type argument of a doubly nested type.
proxyFromElement2 :: f (g a) -> LP.Proxy a
proxyFromElement2 = const LP.Proxy

-- | A proxy for the innermost type argument of a triply nested type,
-- e.g. the element type @a@ of @Value (Ptr (Vector n a))@.
proxyFromElement3 :: f (g (h a)) -> LP.Proxy a
proxyFromElement3 = const LP.Proxy