{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE Rank2Types #-}
module Synthesizer.LLVM.CausalParameterized.ProcessPrivate where

import qualified Synthesizer.LLVM.Parameterized.SignalPrivate as Sig
import qualified Synthesizer.LLVM.Causal.ProcessPrivate as CausalPriv
import qualified Synthesizer.LLVM.Causal.Process as Causal
import qualified Synthesizer.LLVM.ForeignPtr as ForeignPtr
import Synthesizer.LLVM.Causal.ProcessPrivate (loopNext)
import Synthesizer.LLVM.Causal.Process (mapProc, zipProcWith)
import Synthesizer.LLVM.Simple.SignalPrivate (proxyFromElement2)

import qualified Synthesizer.Causal.Class as CausalClass
import qualified Synthesizer.Causal.Utility as ArrowUtil

import qualified LLVM.DSL.Parameter as Param

import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.MaybeContinuation as MaybeCont
import qualified LLVM.Extra.Marshal as Marshal
import qualified LLVM.Extra.Memory as Memory

import qualified LLVM.ExecutionEngine as EE
import qualified LLVM.Core as LLVM
import LLVM.Core (CodeGenFunction, Value, valueOf)

import Type.Data.Num.Decimal (d1)

import qualified Control.Monad.HT as M
import qualified Control.Arrow    as Arr
import qualified Control.Category as Cat
import Control.Arrow (arr, (^<<), (<<<), (&&&))
import Control.Applicative (Applicative, pure, (<*>), (<$>))
import Data.Tuple.HT (mapSnd)

import Data.Word (Word)

import Foreign.ForeignPtr (ForeignPtr, touchForeignPtr, mallocForeignPtrBytes)

import qualified System.Unsafe as Unsafe

import qualified Number.Ratio as Ratio
import qualified Algebra.Field as Field
import qualified Algebra.Ring as Ring
import qualified Algebra.Additive as Additive

import NumericPrelude.Numeric
import NumericPrelude.Base hiding (and, iterate, map, zip, zipWith, take, takeWhile, init)

import qualified Prelude as P


{- |
A causal process with parameter type @p@, input type @a@ and output type @b@,
whose per-sample step is generated as LLVM code.

The types @context@, @state@, @local@, @ioContext@ and @parameters@
are existentially hidden; the constructor fields describe how they are
created, used and released.
-}
data T p a b =
   forall context state local ioContext parameters.
      (Marshal.C parameters, Memory.C context, Memory.C state) =>
   Cons
      (forall r c.
       (Tuple.Phi c) =>
       context -> local ->
       a -> state -> MaybeCont.T r c (b, state))
          -- compute the next output value and the successor state
          -- from the context, the temporaries and one input value;
          -- the 'MaybeCont.T' result lets the process terminate
          -- (cf. 'MaybeCont.guard' in 'takeWhile')
      (forall r.
       CodeGenFunction r local)
          -- allocate temporary variables before a loop
      (forall r.
       Tuple.ValueOf parameters ->
       CodeGenFunction r (context, state))
          -- create the context and the initial state from the parameters
      (forall r.
       context -> state ->
       CodeGenFunction r ())
          -- cleanup of context and state
      (p -> IO (ioContext, parameters))
          {- initialization from the IO monad.
          This will be run within Unsafe.performIO,
          so no observable input/output actions, please!
          -}
      (ioContext -> IO ())
          -- finalization from the IO monad, also run within Unsafe.performIO


-- | Processes consuming and producing parameterized signals 'Sig.T' are 'T'.
type instance CausalClass.ProcessOf (Sig.T p) = T p

-- | Relates 'T' to parameterized signals.
-- Since "Synthesizer.Causal.Class" is imported qualified,
-- the right-hand sides below refer to the top-level 'toSignal' and
-- 'fromSignal' of this module, not recursively to the class methods.
instance CausalClass.C (T p) where
   type SignalOf (T p) = Sig.T p
   toSignal = toSignal
   fromSignal = fromSignal

-- | Generic process interface shared with the non-parameterized processes.
-- Unqualified 'simple' and 'replicateControlled' on the right-hand sides
-- are the top-level functions of this module.
instance Causal.C (T p) where
   -- No parameters: use a unit context and the constant unit parameter.
   simple next start =
      simple (\() -> next) (\() -> fmap ((,) ()) start) (pure ())

   -- Temporarily pack the step function into a 'CausalPriv.Core'
   -- (context and temporaries paired via 'uncurry', 'return' as start
   -- and 'id' as stop), transform it with @f@, and unpack the result:
   -- the new start is chained after 'start0' via 'Sig.withStart'
   -- and the new stop is applied to the state before 'stop0'.
   alter f (Cons next0 alloca start0 stop0 create delete) =
      case f (CausalPriv.Core (uncurry next0) return id) of
         CausalPriv.Core next1 start1 stop1 ->
            Cons
               (curry next1) alloca
               (Sig.withStart start0 start1)
               (\c -> stop0 c . stop1)
               create delete

   -- A fixed count is turned into a constant parameter.
   replicateControlled n = replicateControlled $ pure n


{- |
Build a process from a step function, a start function and a parameter.

The resulting process uses no temporary variables, has a trivial IO context,
and needs no cleanup of context and state.
-}
simple ::
   (Marshal.C parameters, Memory.C context, Memory.C state) =>
   (forall r c.
    (Tuple.Phi c) =>
    context -> a -> state -> MaybeCont.T r c (b, state)) ->
   (forall r.
    Tuple.ValueOf parameters ->
    CodeGenFunction r (context, state)) ->
   Param.T p parameters -> T p a b
simple f start param =
   Param.withValue param $ \getParam valueOfParam ->
      let -- closed helpers, so they stay polymorphic in the monad
          noStop _context _state = return ()
          noDelete _ioContext = return ()
      in  Cons
             (\context () a state -> f context a state)
             (return ())
             (\params -> start (valueOfParam params))
             noStop
             (\p -> return ((), getParam p))
             noDelete


-- | View a process that consumes only unit values as a signal.
toSignal :: T p () a -> Sig.T p a
toSignal proc =
   case proc of
      Cons next alloca start stop create delete ->
         Sig.Cons
            (\context local -> next context local ())
            alloca
            start stop
            create delete

-- | Turn a signal into a process that ignores its input values.
fromSignal :: Sig.T p b -> T p a b
fromSignal sig =
   case sig of
      Sig.Cons next alloca start stop create delete ->
         Cons
            (\context local _input -> next context local)
            alloca
            start stop
            create delete


{- |
Stateful per-sample map.

@next@ gets the value of the first parameter, an input value and the
current state, and returns the output value and the new state.
@start@ computes the initial state from the value of the second parameter.
-}
mapAccum ::
   (Marshal.C pnh, Tuple.ValueOf pnh ~ pnl,
    Marshal.C psh, Tuple.ValueOf psh ~ psl,
    Memory.C s) =>
   (forall r. pnl -> a -> s -> CodeGenFunction r (b,s)) ->
   (forall r. psl -> CodeGenFunction r s) ->
   Param.T p pnh ->
   Param.T p psh ->
   T p a b
mapAccum next start selectParamN selectParamS =
   simple
      (\pn a s -> MaybeCont.lift (next pn a s))
      (\(pn, ps) -> do
         s <- start ps
         return (pn, s))
      (selectParamN &&& selectParamS)


-- | Stateless per-sample map that may depend on a parameter.
map ::
   (Marshal.C ph, Tuple.ValueOf ph ~ pl) =>
   (forall r. pl -> a -> CodeGenFunction r b) ->
   Param.T p ph ->
   T p a b
map f selectParamF =
   mapAccum
      (\p a s -> do
         b <- f p a
         return (b, s))
      (\_ -> return ())
      selectParamF
      (return ())

-- | Stateless per-sample map without parameters.
mapSimple ::
   (forall r. a -> CodeGenFunction r b) ->
   T p a b
mapSimple f =
   map (\_unit a -> f a) (return ())

-- | Combine the two components of each input pair, depending on a parameter.
zipWith ::
   (Marshal.C ph, Tuple.ValueOf ph ~ pl) =>
   (forall r. pl -> a -> b -> CodeGenFunction r c) ->
   Param.T p ph ->
   T p (a,b) c
zipWith f selectParam =
   map (\p -> uncurry (f p)) selectParam

-- | Combine the two components of each input pair, without parameters.
zipWithSimple ::
   (forall r. a -> b -> CodeGenFunction r c) ->
   T p (a,b) c
zipWithSimple f =
   -- lazy pattern, matching the laziness of 'uncurry'
   mapSimple (\ ~(a, b) -> f a b)


-- | Run a process on a parameterized signal.
apply :: T p a b -> Sig.T p a -> Sig.T p b
apply proc sig = CausalClass.apply proc sig

-- | Pair each input value with a signal value, the signal value first.
feedFst :: Sig.T p a -> T p b (a,b)
feedFst sig = CausalClass.feedFst sig

-- | Pair each input value with a signal value, the signal value second.
feedSnd :: Sig.T p a -> T p b (b,a)
feedSnd sig = CausalClass.feedSnd sig


{- |
Serial composition: the output of the first process
is the input of the second one.

This is very similar to 'apply',
since 'apply' can be considered to be of type
@T p a b -> T p () a -> T p () b@.
-}
compose :: T p a b -> T p b c -> T p a c
compose
      (Cons nextA allocaA startA stopA createIOContextA deleteIOContextA)
      (Cons nextB allocaB startB stopB createIOContextB deleteIOContextB) =
   Cons
      -- step A, then B; if one of them terminates, the other one is stopped
      (composeNext MaybeCont.onFail stopA stopB nextA nextB)
      -- both processes get their own temporaries
      (M.lift2 (,) allocaA allocaB)
      (composeStart startA startB)
      (composeStop stopA stopB)
      (composeCreate createIOContextA createIOContextB)
      (composeDelete deleteIOContextA deleteIOContextB)

{- |
Step function of 'compose'.

Feeds the input to the first step and the first step's output to the
second step.  The combinator @onFail@ (in 'compose' this is
'MaybeCont.onFail') attaches a cleanup action to a step that may
terminate:

* if the first step terminates, the second process is stopped
  with its unchanged state;

* if the second step terminates, the first process is stopped
  with its already updated state.
-}
composeNext ::
   (Monad maybe) =>
   (forall x. code () -> maybe x -> maybe x) ->
   (contextA -> stateA -> code ()) ->
   (contextB -> stateB -> code ()) ->
   (contextA -> localA -> a -> stateA -> maybe (b, stateA)) ->
   (contextB -> localB -> b -> stateB -> maybe (c, stateB)) ->
   (contextA, contextB) ->
   (localA, localB) ->
   a ->
   (stateA, stateB) ->
   maybe (c, (stateA, stateB))
composeNext onFail stopA stopB nextA nextB
      (paramA, paramB) (localA, localB) a (sa0,sb0) = do
   (b,sa1) <- onFail (stopB paramB sb0) $ nextA paramA localA a sa0
   (c,sb1) <- onFail (stopA paramA sa1) $ nextB paramB localB b sb0
   return (c, (sa1,sb1))

-- | Start both processes, pairing their contexts and their states.
composeStart ::
   Monad m =>
   (paramA -> m (contextA, stateA)) ->
   (paramB -> m (contextB, stateB)) ->
   (paramA, paramB) -> m ((contextA, contextB), (stateA, stateB))
composeStart startA startB = Sig.combineStart startA startB

-- | Stop both processes, each with its own context and state.
composeStop ::
   Monad m =>
   (contextA -> stateA -> m ()) ->
   (contextB -> stateB -> m ()) ->
   (contextA, contextB) -> (stateA, stateB) -> m ()
composeStop stopA stopB = Sig.combineStop stopA stopB

-- | Create the IO contexts and parameters of both processes.
composeCreate ::
   Monad m =>
   (p -> m (ioContextA, contextA)) ->
   (p -> m (ioContextB, contextB)) ->
   p -> m ((ioContextA, ioContextB), (contextA, contextB))
composeCreate createA createB = Sig.combineCreate createA createB

-- | Release the IO contexts of both processes.
composeDelete ::
   (Monad m) =>
   (ca -> m ()) -> (cb -> m ()) -> (ca, cb) -> m ()
composeDelete deleteA deleteB = Sig.combineDelete deleteA deleteB


{- |
serial replication

But you may also use it for a parallel replication, see 'replicateParallel'.

The parameter determines the number of stages.
All stages receive the same control input @c@;
the @x@ output of one stage is the @x@ input of the next one.
A negative number of stages is treated as zero,
as in 'take'.
-}
replicateControlled ::
   (Tuple.Undefined x, Tuple.Phi x) =>
   Param.T p Int -> T p (c,x) x -> T p (c,x) x
replicateControlled
      n (Cons next alloca start stop createIOContext deleteIOContext) =
   let -- Clamp before converting to 'Word'.
       -- Otherwise a negative constant count would wrap around to a huge
       -- array length, while 'M.replicate' creates no IO context at all.
       nonNeg = max 0 ^<< n
       n32 = Param.wordInt nonNeg
   in  Cons
          (\(len, cs) ->
             replicateControlledNext next stop (Param.valueTuple n32 len, cs))
          -- we re-use the temporary variable for all stages
          alloca
          (\(len, param) ->
             replicateControlledStart start (Param.valueTuple n32 len, param))
          (\(len, cs) ->
             replicateControlledStop stop (Param.valueTuple n32 len, cs))
          (\p ->
             replicateControlledCreate $
                M.replicate (Param.get nonNeg p) (createIOContext p))
          (replicateControlledDelete deleteIOContext)

{- |
Step function of 'replicateControlled'.

The stages' contexts and states are stored as an array of @len@ structs
at @contextStates@.  The stages are stepped one after another:
the control value @c@ is passed to every stage unchanged,
while the @x@ output of one stage is the @x@ input of the next.
After a stage has stepped, its new state is written back into its struct.
If a stage terminates, all other stages are stopped via
'replicateControlledStopExcept' (attached with 'MaybeCont.onFail').
-}
replicateControlledNext ::
   (Memory.C context, Memory.C state,
    contextState ~
       LLVM.Struct (Memory.Struct context, (Memory.Struct state, ())),
    Tuple.Phi z, Tuple.Phi a, Tuple.Undefined a) =>
   (forall z0. (Tuple.Phi z0) =>
    context -> local -> (ctrl, a) -> state ->
    MaybeCont.T r z0 (a, state)) ->
   (context -> state -> CodeGenFunction r ()) ->
   (Value Word, Value (LLVM.Ptr contextState)) ->
   local ->
   (ctrl, a) ->
   () ->
   MaybeCont.T r z (a, ())
replicateControlledNext next stop (len, contextStates) local (c,a) () =
   MaybeCont.fromMaybe $ fmap (\(_,ms) -> flip (,) () <$> ms) $
      MaybeCont.arrayLoop len contextStates a $
            \contextStatePtr a0 -> do
         -- load this stage's context and current state
         (context, s0) <- MaybeCont.lift $ Memory.load contextStatePtr
         (a1,s1) <-
            MaybeCont.onFail
               (replicateControlledStopExcept
                  stop len contextStates contextStatePtr) $
            next context local (c,a0) s0
         -- field 1 of the struct holds the state; the context stays untouched
         MaybeCont.lift $
            Memory.store s1 =<< LLVM.getElementPtr0 contextStatePtr (d1, ())
         return a1

{- |
Call @stop@ on the context and state of every element of the array
at @contextStates@ (of length @len@), except for the element
at @contextStatePtr@.
-}
replicateControlledStopExcept ::
   (Memory.C a, Memory.C b,
    ab ~ LLVM.Struct (Memory.Struct a, (Memory.Struct b, ()))) =>
   (a -> b -> CodeGenFunction r ()) ->
   Value Word ->
   Value (LLVM.Ptr ab) ->
   Value (LLVM.Ptr ab) ->
   CodeGenFunction r ()
replicateControlledStopExcept stop len contextStates contextStatePtr =
   C.arrayLoop len contextStates () $ \ptr () -> do
      -- skip the excluded element by pointer comparison
      b <- A.cmp LLVM.CmpNE ptr contextStatePtr
      C.ifThen b () $ uncurry stop =<< Memory.load ptr

{- |
Alternative step function for 'replicateControlled' based on
'C.arrayLoopWithExit'.  It is not used within this module;
the leading underscore suppresses the unused-binding warning.

Unlike 'replicateControlledNext', it takes no temporaries (@local@)
and does not stop the other stages when one stage terminates.
The loop exits at the first terminating stage.
Note that the new state is stored even for that stage.

NOTE(review): confirm whether this is still needed or can be removed.
-}
_replicateControlledNext ::
   (Memory.C context, Memory.C state,
    contextState ~
       LLVM.Struct (Memory.Struct context, (Memory.Struct state, ())),
    Tuple.Phi z, Tuple.Phi a, Tuple.Undefined a) =>
   (forall z0. (Tuple.Phi z0) =>
    context -> (ctrl, a) -> state ->
    MaybeCont.T r z0 (a, state)) ->
   (Value Word, Value (LLVM.Ptr contextState)) ->
   (ctrl, a) ->
   () ->
   MaybeCont.T r z (a, ())
_replicateControlledNext next (len, contextStates) (c,a) () =
   fmap (flip (,) ()) $ MaybeCont.fromBool $ fmap snd $
   C.arrayLoopWithExit len contextStates (valueOf True, a) $
         \contextStatePtr (_,a0) -> do
      (context, s0) <- Memory.load contextStatePtr
      (cont, (a1,s1)) <- MaybeCont.toBool $ next context (c,a0) s0
      Memory.store s1 =<< LLVM.getElementPtr0 contextStatePtr (d1, ())
      return (cont, (cont,a1))

{- |
Start function of 'replicateControlled'.

Allocates an array of @len@ context/state structs with 'LLVM.arrayMalloc'.
The i-th element is filled by running @start@ on the i-th parameter
loaded from @params@.
The array is released again by 'replicateControlledStop'.
-}
replicateControlledStart ::
   (Memory.C a, Memory.C b) =>
   (a -> CodeGenFunction r b) ->
   (Value Word, Value (LLVM.Ptr (Memory.Struct a))) ->
   CodeGenFunction r ((Value Word, Value (LLVM.Ptr (Memory.Struct b))), ())
replicateControlledStart start (len, params) = do
   contextStates <- LLVM.arrayMalloc len
   C.arrayLoop2 len params contextStates () $ \paramPtr statePtr () ->
      flip Memory.store statePtr =<< start =<< Memory.load paramPtr
   return ((len, contextStates), ())

{- |
Stop function of 'replicateControlled'.

Calls @stop@ on the context and state of every stage,
then frees the array allocated by 'replicateControlledStart'.
-}
replicateControlledStop ::
   (Memory.C a, Memory.C b,
    ab ~ LLVM.Struct (Memory.Struct a, (Memory.Struct b, ()))) =>
   (a -> b -> CodeGenFunction r ()) ->
   (Value Word, Value (LLVM.Ptr ab)) ->
   () ->
   CodeGenFunction r ()
replicateControlledStop stop (len, contextStates) () = do
   C.arrayLoop len contextStates () $ \contextStatePtr () ->
      uncurry stop =<< Memory.load contextStatePtr
   LLVM.free contextStates


{- |
IO-side initialization of 'replicateControlled'.

Runs the IO-context constructors of all stages.  The packed parameters
are then copied into a freshly allocated foreign array.
The result contains:

* the list of IO contexts together with the 'ForeignPtr', which
  'replicateControlledDelete' touches to keep the array alive;

* the number of stages and the raw array pointer,
  as passed to the generated code.
-}
replicateControlledCreate ::
   (Monad m, Marshal.C b, Marshal.Struct b ~ struct) =>
   m [(a, b)] ->
   m (([a], ForeignPtr.MemoryPtr struct), (Word, LLVM.Ptr struct))
replicateControlledCreate createIOContexts = do
   (ioContexts, params) <- M.lift unzip createIOContexts
   let len = length params
   -- The allocation runs lazily within Unsafe.performIO.
   -- The reference to 'fptr' inside its own definition presumably serves
   -- only to fix the element type for 'EE.sizeOfArray' via
   -- 'proxyFromElement2' (NOTE(review): confirm it is not forced).
   let fptr = Unsafe.performIO $ do
         fptr0 <-
            mallocForeignPtrBytes $ EE.sizeOfArray (proxyFromElement2 fptr) len
         ForeignPtr.with fptr0 $ flip EE.pokeList (fmap Marshal.pack params)
         return fptr0
   return ((ioContexts, fptr),
           (fromIntegral len,
            EE.castFromStoredPtr $ Unsafe.foreignPtrToPtr fptr))

{- |
IO-side finalization of 'replicateControlled'.

Releases the IO contexts of all stages, in list order.
Afterwards it touches the 'ForeignPtr' of the parameter array.
This keeps the array alive at least until then, because the generated
code holds only a raw pointer to it.
-}
replicateControlledDelete ::
   (a -> IO ()) ->
   ([a], ForeignPtr b) -> IO ()
replicateControlledDelete deleteIOContext (ioContexts, fptr) =
   mapM_ deleteIOContext ioContexts >> touchForeignPtr fptr


-- | Identity process and serial composition (right-to-left).
instance Cat.Category (T p) where
   id = mapSimple (\x -> return x)
   g . f = compose f g

-- | Pure functions become per-sample maps.
instance Arr.Arrow (T p) where
   arr f = mapSimple (\x -> return (f x))
   first = Causal.first


-- | Map over the output values.
instance Functor (T p a) where
   fmap f proc = ArrowUtil.map f proc

-- | Constant output and pointwise application of output functions.
instance Applicative (T p a) where
   pure x = ArrowUtil.pure x
   pf <*> px = ArrowUtil.apply pf px


-- | Pointwise arithmetic on the output values.
instance (A.Additive b) => Additive.C (T p a b) where
   zero = pure A.zero
   negate x = mapProc A.neg x
   x + y = zipProcWith A.add x y
   x - y = zipProcWith A.sub x y

instance (A.PseudoRing b, A.IntegerConstant b) => Ring.C (T p a b) where
   one = pure A.one
   fromInteger = pure . A.fromInteger'
   x * y = zipProcWith A.mul x y

instance (A.Field b, A.RationalConstant b) => Field.C (T p a b) where
   fromRational' x = pure (A.fromRational' (Ratio.toRational98 x))
   x / y = zipProcWith A.fdiv x y


-- | Pointwise arithmetic on the output values, for the standard Prelude classes.
instance (A.PseudoRing b, A.Real b, A.IntegerConstant b) => P.Num (T p a b) where
   fromInteger = pure . A.fromInteger'
   negate x = mapProc A.neg x
   x + y = zipProcWith A.add x y
   x - y = zipProcWith A.sub x y
   x * y = zipProcWith A.mul x y
   abs x = mapProc A.abs x
   signum x = mapProc A.signum x

instance (A.Field b, A.Real b, A.RationalConstant b) => P.Fractional (T p a b) where
   fromRational = pure . A.fromRational'
   x / y = zipProcWith A.fdiv x y


{- |
Not quite the loop of ArrowLoop
because we need a delay of one time step
and thus an initialization value.

For a real ArrowLoop.loop, that is a zero-delay loop,
we would formally need a MonadFix instance of CodeGenFunction.
But this will not become reality, since LLVM is not able to re-order code
in a way that allows to access a result before creating the input.

The fed-back value becomes part of the state.  Its initial value is read
from the parameter @initial@: the IO-side parameters are extended by that
value, and the start function puts it next to the wrapped process' state.
-}
loop ::
   (Marshal.C c, Tuple.ValueOf c ~ cl) =>
   Param.T p c -> T p (a,cl) (b,cl) -> T p a b
loop initial (Cons next alloca start stop createIOContext deleteIOContext) =
   Param.withValue initial $ \getInitial valueInitial -> Cons
      (curry $ loopNext $ uncurry next)
      alloca
      -- state is (fed-back value, state of the wrapped process)
      (\(i,p) -> fmap (mapSnd ((,) (valueInitial i))) $ start p)
      (loopStop stop)
      (\p -> do
         (ctx, param) <- createIOContext p
         return (ctx, (getInitial p, param)))
      deleteIOContext

-- | Stop function of 'loop': ignore the fed-back value
-- and stop the wrapped process with its own state.
loopStop :: (context -> state -> m) -> context -> (c, state) -> m
loopStop stop ctx fedAndState =
   case fedAndState of
      (_fed, s) -> stop ctx s


{- |
Pass input values through unchanged as long as @check@ holds.
The process terminates at the first value for which @check@ yields false.
-}
takeWhile ::
   (Marshal.C ph, Tuple.ValueOf ph ~ pl) =>
   (forall r. pl -> a -> CodeGenFunction r (Value Bool)) ->
   Param.T p ph ->
   T p a a
takeWhile check selectParam =
   simple
      (\p a () -> do
         continue <- MaybeCont.lift (check p a)
         MaybeCont.guard continue
         return (a, ()))
      (\p -> return (p, ()))
      selectParam


{- |
Pass input values through unchanged while a counter, starting at @len@
and decremented per sample, is positive.
Negative lengths are clamped to zero before conversion to 'Word'.
-}
take ::
   Param.T p Int ->
   T p a a
take len =
   let remaining =
          Sig.iterate (const A.dec) (return ())
             (Param.wordInt (max 0 ^<< len))
   in  snd ^<<
       Causal.takeWhile (\x -> A.cmp LLVM.CmpLT A.zero (fst x)) <<<
       feedFst remaining


{- |
Running sum of the input values, starting from the given initial value.

The first output value is the initial value.
Thus 'integrate' delays by one sample compared with 'integrateSync'.
-}
integrate ::
   (Marshal.C a, Tuple.ValueOf a ~ al, A.Additive al) =>
   Param.T p a ->
   T p al al
integrate initial =
   -- output the old sum, feed back the new one
   loop initial (arr snd &&& zipWithSimple A.add)

-- | Running sum of the input values that includes the current input value.
integrateSync ::
   (Marshal.C a, Tuple.ValueOf a ~ al, A.Additive al) =>
   Param.T p a ->
   T p al al
integrateSync initial =
   -- output and feed back the new sum
   loop initial (arr (\s -> (s, s)) <<< zipWithSimple A.add)