{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE Rank2Types #-}
module Synthesizer.LLVM.CausalParameterized.ProcessPrivate where

import qualified Synthesizer.LLVM.Parameterized.SignalPrivate as Sig
import qualified LLVM.Extra.MaybeContinuation as Maybe
import qualified Synthesizer.LLVM.Parameter as Param
import qualified LLVM.Extra.Representation as Rep

import qualified LLVM.Extra.Arithmetic as A

import qualified LLVM.Core as LLVM
import LLVM.Util.Loop (Phi, )
import LLVM.Core
          (Value, valueOf, MakeValueTuple,
           IsSized, IsFirstClass, IsArithmetic, CodeGenFunction, )

import qualified Control.Arrow    as Arr
import qualified Control.Category as Cat
import qualified Control.Exception as Exc
import Control.Arrow ((^<<), (<<<), (<<^), (&&&), )
import Control.Monad (liftM2, )

import qualified Algebra.Ring     as Ring

import Data.Word (Word32, )
import Foreign.Storable.Tuple ()
import Foreign.Storable (Storable, )

import Data.Tuple.HT (swap, )

import NumericPrelude.Numeric
import NumericPrelude.Base hiding (and, iterate, map, zip, zipWith, take, takeWhile, )


{- |
A causal process with input type @a@ and output type @b@,
whose parameters are obtained from a value of type @p@.

The state type, the IO context type and all types describing the
step parameter (@nextParam*@) and the start parameter (@startParam*@)
are hidden existentially.
For each parameter there is a 'Storable' tuple, its 'MakeValueTuple'
value counterpart, and the 'Rep.Memory' packed form of that value.
-}
data T p a b =
   forall state packed size ioContext
        startParamTuple startParamValue startParamPacked startParamSize
        nextParamTuple  nextParamValue  nextParamPacked  nextParamSize.
      (Storable startParamTuple,
       Storable nextParamTuple,
       MakeValueTuple startParamTuple startParamValue,
       MakeValueTuple nextParamTuple  nextParamValue,
       Rep.Memory     startParamValue startParamPacked,
       Rep.Memory     nextParamValue  nextParamPacked,
       IsSized        startParamPacked startParamSize,
       IsSized        nextParamPacked  nextParamSize,
       Rep.Memory state packed,
       IsSized packed size) =>
   Cons
      (forall r c.
       (Phi c) =>
       nextParamValue ->
       a -> state -> Maybe.T r c (b, state))
          -- compute next value:
          -- (step parameter, current input, current state)
          -- -> (output, new state).
          -- The 'Maybe.T' result presumably lets the process
          -- terminate early (TODO confirm against MaybeContinuation).
      (forall r.
       startParamValue ->
       CodeGenFunction r state)
          -- initial state, generated from the start parameter value
      (p -> IO (ioContext, (nextParamTuple, startParamTuple)))
          {- initialization from IO monad:
          creates the IO context and extracts the step and start
          parameter tuples from @p@.
          This will be run within unsafePerformIO,
          so it must not perform observable input/output actions!
          -}
      (ioContext -> IO ())
          -- finalization of the IO context from IO monad,
          -- also run within unsafePerformIO


{- |
Construct a process that needs no IO context.

The step function receives the value obtained from @selectNext@
(through 'Param.value'), and the start function receives the value
obtained from @selectStart@.
Both parameter tuples are extracted from @p@ with 'Param.get'
during initialization.
The IO context is @()@, and finalization does nothing.
-}
simple ::
   (Storable startParamTuple,
    Storable nextParamTuple,
    MakeValueTuple startParamTuple startParamValue,
    MakeValueTuple nextParamTuple nextParamValue,
    Rep.Memory startParamValue startParamPacked,
    Rep.Memory nextParamValue nextParamPacked,
    IsSized    startParamPacked startParamSize,
    IsSized    nextParamPacked  nextParamSize,
    Rep.Memory state packed,
    IsSized packed size) =>
   (forall r c.
    (Phi c) =>
    nextParamValue ->
    a -> state -> Maybe.T r c (b, state)) ->
   (forall r.
    startParamValue ->
    CodeGenFunction r state) ->
   Param.T p nextParamTuple ->
   Param.T p startParamTuple -> T p a b
simple next start selectNext selectStart =
   Cons
      (\nextParam -> next (Param.value selectNext nextParam))
      (\startParam -> start (Param.value selectStart startParam))
      (\p -> return ((), Param.get (selectNext &&& selectStart) p))
      (\_ -> return ())


{- |
Turn a process whose input type is @()@ into a signal
by supplying @()@ as the input on every step.
The start function and the IO initialization and finalization
are passed through unchanged.
-}
toSignal :: T p () a -> Sig.T p a
toSignal proc =
   case proc of
      Cons next start create delete ->
         Sig.Cons
            (\param -> next param ())
            start
            create delete

{- |
View a signal as a process with input type @()@.
The unit input is matched but otherwise ignored.
The start function and the IO initialization and finalization
are passed through unchanged.
-}
fromSignal :: Sig.T p a -> T p () a
fromSignal sig =
   case sig of
      Sig.Cons next start create delete ->
         Cons
            (\param () -> next param)
            start
            create delete


{- |
Build a process with an accumulator from a plain code-generating step.

The step receives the step parameter value, the current input and the
current accumulator, and returns the output together with the new
accumulator.
It is lifted into 'Maybe.T' with 'Maybe.lift', so presumably the
resulting process never stops by itself (TODO confirm).
@selectParamN@ selects the step parameter.
@selectParamS@ selects the start parameter, which @start@ turns into the
initial accumulator.
-}
mapAccum ::
   (Storable pnh, MakeValueTuple pnh pnl, Rep.Memory pnl pnp, IsSized pnp pns,
    Storable psh, MakeValueTuple psh psl, Rep.Memory psl psp, IsSized psp pss,
    Rep.Memory s struct, IsSized struct sa) =>
   (forall r. pnl -> a -> s -> CodeGenFunction r (b,s)) ->
   (forall r. psl -> CodeGenFunction r s) ->
   Param.T p pnh ->
   Param.T p psh ->
   T p a b
mapAccum next start selectNext selectStart =
   simple
      (\param x acc -> Maybe.lift (next param x acc))
      start
      selectNext
      selectStart


{- |
Stateless process: run the code generator @f@ on every input,
together with the parameter value selected by @selectParamF@.
-}
map ::
   (Storable ph, MakeValueTuple ph pl, Rep.Memory pl pp, IsSized pp ps) =>
   (forall r. pl -> a -> CodeGenFunction r b) ->
   Param.T p ph ->
   T p a b
map f selectParamF =
   mapAccum
      (\param x unit -> fmap (\y -> (y, unit)) (f param x))
      (\_ -> return ())
      selectParamF
      (return ())

-- | Like 'map', but without a parameter:
-- run the code generator @f@ on every input.
mapSimple ::
   (forall r. a -> CodeGenFunction r b) ->
   T p a b
mapSimple f =
   map (\_ -> f) (return ())


-- | Run a process on a signal.
-- The signal is viewed as a process with @()@ input, composed with
-- the given process, and the result is turned back into a signal.
apply :: T p a b -> Sig.T p a -> Sig.T p b
apply proc =
   toSignal . (proc <<<) . fromSignal

-- | Pair every input with the current value of a signal.
-- The signal value is the first component,
-- and the unchanged input is the second.
feedFst :: Sig.T p a -> T p b (a,b)
feedFst sig =
   first (fromSignal sig) <<^ attachUnit
   where
      attachUnit b = ((), b)

-- | Like 'feedFst', but the signal value is the second component
-- and the unchanged input is the first.
feedSnd :: Sig.T p a -> T p b (b,a)
feedSnd =
   (swap ^<<) . feedFst


{- |
Sequential composition: the output of the first process is fed
as input to the second process.

This is very similar to 'apply',
since 'apply' can be considered to have the type
@T p a b -> T p () a -> T p () b@.

The states, the step parameters and the start parameters
of both processes are paired.
The IO contexts are created in order (first, then second)
and finalized in the same order.
If creating the second context throws an exception, the first context
is finalized before the exception propagates.
The second context is finalized even if finalizing the first one throws.
-}
compose :: T p a b -> T p b c -> T p a c
compose
      (Cons nextA startA createIOContextA deleteIOContextA)
      (Cons nextB startB createIOContextB deleteIOContextB) =
   Cons
      (\(paramA, paramB) a (sa0,sb0) ->
         do (b,sa1) <- nextA paramA a sa0
            (c,sb1) <- nextB paramB b sb0
            return (c, (sa1,sb1)))
      (\(paramA, paramB) ->
         liftM2 (,)
            (startA paramA)
            (startB paramB))
      (\p -> do
         (ca,(nextParamA,startParamA)) <- createIOContextA p
         -- do not leak the first context if the second cannot be created
         (cb,(nextParamB,startParamB)) <-
            createIOContextB p `Exc.onException` deleteIOContextA ca
         return ((ca,cb),
            ((nextParamA,  nextParamB),
             (startParamA, startParamB))))
      (\(ca,cb) ->
         -- finalize the second context even if the first finalizer fails
         deleteIOContextA ca `Exc.finally`
         deleteIOContextB cb)


-- | Apply a process to the first component of each input pair,
-- passing the second component through unchanged.
-- State, parameters and IO context are those of the wrapped process.
first :: T p b c -> T p (b, d) (c, d)
first proc =
   case proc of
      Cons next start create delete ->
         Cons
            (\param (x, rest) s0 -> do
               (y, s1) <- next param x s0
               return ((y, rest), s1))
            start
            create delete


-- | The identity passes every input through unchanged.
-- Composition delegates to 'compose' with the arguments swapped,
-- because in @g . f@ the process @f@ runs first.
instance Cat.Category (T p) where
   id = mapSimple return
   g . f = compose f g

-- | 'arr' lifts a pure function into a stateless process.
instance Arr.Arrow (T p) where
   arr f = mapSimple (\x -> return (f x))
   -- Not a recursive definition: the right-hand side is this module's
   -- top-level 'first', since the class method is imported only
   -- qualified (as @Arr.first@).
   first = first


{- |
Pass inputs through unchanged while @check@, applied to the selected
parameter value and the input, yields true.
When it yields false, 'Maybe.guard' aborts the step,
which presumably ends the process (TODO confirm).
The process has no state (@()@).
-}
takeWhile ::
   (Storable ph, MakeValueTuple ph pl, Rep.Memory pl pp, IsSized pp ps) =>
   (forall r. pl -> a -> CodeGenFunction r (Value Bool)) ->
   Param.T p ph ->
   T p a a
takeWhile check selectParam =
   simple
      (\param x () -> do
         keep <- Maybe.lift (check param x)
         Maybe.guard keep
         return (x, ()))
      return
      selectParam
      (return ())


{- |
Pass through the first @len@ inputs unchanged and then stop.

A 'Word32' counter starts at the length and is decremented on every step.
'takeWhile' continues as long as the counter is greater than zero
(unsigned comparison).
Negative lengths count as zero.
Lengths beyond the 'Word32' range saturate at @2^32-1@
instead of wrapping around modulo @2^32@.
-}
take ::
   Param.T p Int ->
   T p a a
take len =
   snd ^<<
   takeWhile (const $ A.icmp LLVM.IntULT (valueOf 0) . fst) (return ()) <<<
   feedFst
      (Sig.iterate (const A.dec) (return ())
         (clampToWord32 ^<< len))
   where
      -- Saturating conversion into the range of the 32 bit counter.
      -- Go through Integer so that the bound 2^32-1 is representable
      -- even where Int is only 32 bits wide.
      clampToWord32 :: Int -> Word32
      clampToWord32 =
         fromInteger . min 4294967295 . max 0 .
         (fromIntegral :: Int -> Integer)


{- |
Running sum of the inputs.
The parameter supplies the start value of the accumulator.
The first output value is the start value.
Thus 'integrate' delays by one sample compared with 'integrate0'.
-}
integrate ::
   (Storable a, IsArithmetic a,
    MakeValueTuple a (Value a), IsFirstClass a, IsSized a size) =>
   Param.T p a ->
   T p (Value a) (Value a)
integrate initial =
   mapAccum
      (\() x acc -> fmap ((,) acc) (A.add x acc))
      return
      (return ())
      initial

{- |
Running sum of the inputs, including the current input.
The parameter supplies the start value of the accumulator.
Unlike in 'integrate', the start value itself is not emitted:
the first output is the start value plus the first input.
-}
integrate0 ::
   (Storable a, IsArithmetic a,
    MakeValueTuple a (Value a), IsFirstClass a, IsSized a size) =>
   Param.T p a ->
   T p (Value a) (Value a)
integrate0 initial =
   mapAccum
      (\() x acc -> fmap (\total -> (total, total)) (A.add x acc))
      return
      (return ())
      initial