{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE Rank2Types #-}
module Synthesizer.LLVM.CausalParameterized.RingBufferForward (
   T, track, trackSkip, trackSkipHold,
   index,
   ) where

import qualified Synthesizer.LLVM.CausalParameterized.ProcessPrivate
                                                              as CausalPrivP
import qualified Synthesizer.LLVM.CausalParameterized.Process as CausalP
import qualified Synthesizer.LLVM.Parameterized.SignalPrivate as SigP
import qualified Synthesizer.LLVM.Parameter as Param
import Synthesizer.LLVM.CausalParameterized.Process (($<), ($*), )

import qualified LLVM.Extra.MaybeContinuation as MaybeCont
import qualified LLVM.Extra.Maybe as Maybe
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Class as Class

import qualified LLVM.Core as LLVM
import LLVM.Util.Loop (Phi, )
import LLVM.Core (CodeGenFunction, Value, )

import Control.Arrow ((<<<), )
import Control.Applicative (pure, )
import Data.Tuple.HT (mapSnd, )

import Data.Word (Word32, )
import Foreign.Storable.Tuple ()
import Foreign.Ptr (Ptr, )

import Prelude hiding (length, )


{- |
This type is very similar to 'Synthesizer.LLVM.RingBuffer.T'
but differs in several details:

* It stores values in time order,
  whereas 'Synthesizer.LLVM.RingBuffer.T' stores in opposite order.

* Since it stores future values it is not causal
  and can only track signal generators.

* There is no need for an initial value.

* It stores one value less than 'Synthesizer.LLVM.RingBuffer.T'
  since it is meant to provide infixes of the signal
  rather than providing the basis for a delay line.

These differences alone would not justify a new type;
one could achieve the same with a combination of
'Synthesizer.LLVM.RingBuffer.track'
and
'Synthesizer.LLVM.CausalParameterized.Process.skip'.
The fundamental problem of this combination is
that it requires keeping the ring buffer alive
longer than the signal that feeds it exists.
This is not possible with the current design.
That's why we provide the combination of @track@ and @skip@
in a way that does not suffer from that problem.
This functionality is critical for
'Synthesizer.LLVM.CausalParameterized.Helix.dynamic'.
-}
data T a =
   Cons
      { buffer :: Value (Ptr (Memory.Struct a))
        -- ^ start of the array holding the buffered values
      , length :: Value Word32
        -- ^ ring size; 'index' reduces positions modulo this value
      , current :: Value Word32
        -- ^ array position that corresponds to index zero
      }

{- |
This function does not check for range violations.
If the ring buffer was generated by @track time@,
then the minimum index is zero and the maximum index is @time-1@.
Index zero refers to the current sample
and index @time-1@ refers to the one that is farthest in the future.
-}
index ::
   (Memory.C a) =>
   Value Word32 -> T a -> CodeGenFunction r a
index i (Cons buf len cur) = do
   -- Shift the requested offset by the start position
   -- and wrap it around the ring size.
   offset <- A.add cur i
   k <- A.irem offset len
   slot <- LLVM.getElementPtr buf (k, ())
   Memory.load slot


{- |
@track time signal@ bundles @time@ successive values of @signal@.
The values can be accessed using 'index' with indices
ranging from 0 to @time-1@.

The @time@ parameter must be non-negative.
-}
track ::
   (Memory.C a) =>
   Param.T p Int -> SigP.T p a -> SigP.T p (T a)
-- Advance by exactly one input value per output sample.
track time = ($* 1) . trackSkip time

{- |
@trackSkip time input $* skips@
is like
@Process.skip (track time input) $* skips@
but this composition would require a @Memory@ constraint for 'T'
which we cannot provide.
-}
trackSkip ::
   (Memory.C a) =>
   Param.T p Int -> SigP.T p a -> CausalP.T p (Value Word32) (T a)
trackSkip time (SigP.Cons nextIn allocaIn startIn stopIn createIn deleteIn) =
   -- Turn the time parameter into a Word32 value that is available
   -- at code generation time (timeValue) and an extractor that
   -- pulls it out of the outer parameter at creation time (timeOf).
   Param.with (Param.word32 time) $ \timeOf timeValue ->
      CausalPrivP.Cons
         (trackNext nextIn timeValue)
         allocaIn
         (trackStart startIn timeValue)
         (trackStop stopIn)
         (trackCreate createIn timeOf)
         (trackDelete deleteIn)

{- |
Like @trackSkip@ but repeats the last buffer content
when the end of the input signal is reached.
The returned 'Bool' flag is 'True' if a skip could be performed completely
and it is 'False' if the skip exceeds the end of the input.
That is, once a 'False' is returned all following values are tagged with 'False'.
The returned 'Word32' value is the number of actually skipped values.
This lags one step behind the input of skip values.
The actual number of skips
is at most the number of requested skips.
If the flag is 'False', then the number of actual skips is zero.
The converse does not apply.

If the input signal is too short, the output is undefined.
(Buffer slots that the input did not fill contain arbitrary values.)
We could fill the buffer with zeros,
but this would require an Arithmetic constraint
and the generated signal would not be very meaningful.
We could also return an empty signal if the input is too short.
However, this would require a check in every step.
-}
trackSkipHold, trackSkipHold_ ::
   (Memory.C a) =>
   Param.T p Int -> SigP.T p a ->
   CausalP.T p (Value Word32) ((Value Bool, Value Word32), T a)
trackSkipHold time xs =
   (CausalP.zipWithSimple clearInitialFill
      $< (CausalP.delay1 (pure False) $* SigP.constant (pure True)))
   <<<
   trackSkipHold_ time xs
   where
      -- 'trackStartHold' schedules a skip of the full ring size,
      -- so the skip count of the first step only reflects the initial
      -- filling of the buffer.  'notFirst' is 'False' exactly in that step,
      -- and then the count is replaced by zero.
      clearInitialFill notFirst ((complete, skipped), buf) = do
         skipped' <- C.select notFirst skipped A.zero
         return ((complete, skipped'), buf)

-- Raw variant: the skip count of the first step includes the initial fill.
trackSkipHold_ time (SigP.Cons nextIn allocaIn startIn stopIn createIn deleteIn) =
   Param.with (Param.word32 time) $ \timeOf timeValue ->
      CausalPrivP.Cons
         (trackNextHold nextIn timeValue)
         allocaIn
         (trackStartHold startIn timeValue)
         (trackStopHold stopIn)
         (trackCreate createIn timeOf)
         (trackDelete deleteIn)


{- |
Per-sample step of 'trackSkip'.

The state is @(skipNow, (inputState, writePos))@.
First @skipNow@ values are pulled from the input and appended to the ring.
If the input ends during that, the whole step terminates.
The returned ring view starts at the new write position.
The current skip input @skipNext@ is only stored for the following step.
-}
trackNext ::
   (Memory.C al, Memory.Struct al ~ am, Phi z,
    Phi state, Class.Undefined state) =>
   (forall z0. (Phi z0) =>
    context -> local -> state -> MaybeCont.T r z0 (al, state)) ->
   (tl -> Value Word32) ->
   (context, (tl, Value (Ptr am))) -> local ->
   Value Word32 ->
   (Value Word32, (state, Value Word32)) ->
   MaybeCont.T r z (T al, (Value Word32, (state, Value Word32)))
trackNext next valueTime (context, (size, ptr)) local skipNext (skipNow, start) = do
   let ringSize = valueTime size
   (stateEnd, posEnd) <-
      MaybeCont.fromMaybe $ fmap snd $
      MaybeCont.fixedLengthLoop skipNow start $ \(s0, p0) -> do
         -- fetch one input value ...
         (x, s1) <- next context local s0
         -- ... and append it to the ring
         MaybeCont.lift $ do
            p1 <- storeNext (ringSize, ptr) x p0
            return (s1, p1)
   return (Cons ptr ringSize posEnd, (skipNext, (stateEnd, posEnd)))

{- |
Initialization for 'trackSkip'.

Starts the input signal and allocates the ring array.
The pending skip count is initialised to the ring size,
so that the first step fills the whole buffer before anything is read.
The write position starts at zero.
-}
trackStart ::
   (LLVM.IsSized am,
    Phi state, Class.Undefined state) =>
   (param -> CodeGenFunction r (context, state)) ->
   (tl -> Value Word32) ->
   (param, tl) ->
   CodeGenFunction r
      ((context, (tl, Value (Ptr am))),
       (Value Word32, (state, Value Word32)))
trackStart start valueTime (param, size) = do
   (context, state) <- start param
   let size0 = valueTime size
   -- Allocate one spare slot so that the array is never empty.
   -- With a ring size of zero, 'storeNext' still writes to slot 0
   -- on every skip, which would otherwise hit a zero-sized allocation.
   ptr <- LLVM.arrayMalloc =<< A.inc size0
   return ((context, (size,ptr)), (size0, (state, A.zero)))

{- |
Cleanup for 'trackSkip'.

Releases the ring array, then stops the input signal.
-}
trackStop ::
   (LLVM.IsType am) =>
   (context -> state -> CodeGenFunction r ()) ->
   (context, (tl, Value (Ptr am))) ->
   (Value Word32, (state, Value Word32)) ->
   CodeGenFunction r ()
trackStop stop (context, (_size, ptr)) (_skip, (state, _pos)) =
   LLVM.free ptr >> stop context state


{- |
Per-sample step of 'trackSkipHold_'.

The state is @(n0, (mstate0, pos0))@:

* @n0@ is the number of values still to be pulled from the input and stored.
  This is the skip requested in the previous step;
  'trackStartHold' initialises it to the ring size.

* @mstate0@ is the input state, or 'Maybe.nothing' once the input has ended.

* @pos0@ is the next write position in the ring array.

Values are pulled until @n0@ of them are stored or the input ends.
The output is @((inputAlive, skipped), ringView)@
with @skipped = n0 - remaining@.
The current skip input @nNext@ is only stored for the following step.
Unlike 'trackNext', the step itself never terminates.

NOTE(review): If the input ends in the middle of a skip,
@skipped@ counts the values stored before the end.
So it can be non-zero while the flag is 'False'.
This contradicts the documentation of 'trackSkipHold'; confirm which one is intended.
-}
trackNextHold ::
   (Memory.C al, Memory.Struct al ~ am, Phi z,
    Phi state, Class.Undefined state) =>
   (forall z0. (Phi z0) =>
    context -> local -> state -> MaybeCont.T r z0 (al, state)) ->
   (tl -> Value Word32) ->
   (context, (tl, Value (Ptr am))) -> local ->
   Value Word32 ->
   (Value Word32, (Maybe.T state, Value Word32)) ->
   MaybeCont.T r z
      (((Value Bool, Value Word32), T al),
       (Value Word32, (Maybe.T state, Value Word32)))
trackNextHold
   next valueTime (context, (size,ptr)) local nNext (n0, (mstate0, pos0)) =
      MaybeCont.lift $ do
   let size0 = valueTime size
   (n3, (pos3, state3)) <-
      Maybe.run mstate0
         -- input already ended: nothing is pulled, count and position are kept
         (return (n0, (pos0, mstate0)))
         (\state0 ->
            Maybe.loopWithExit (n0, (state0, pos0))
               -- Loop check: fetch the next value only while skips remain.
               -- The exit value carries the remaining count, the write position
               -- and the input state, which is 'Maybe.nothing' if the input ended.
               (\(n1, (state1, pos1)) -> do
                  cont <- A.cmp LLVM.CmpGT n1 A.zero
                  fmap (mapSnd ((,) n1 . (,) pos1)) $
                     C.ifThen cont
                        (Maybe.nothing, Maybe.just state1)
                        (do aState <-
                              MaybeCont.toMaybe $ next context local state1
                            return (aState, fmap snd aState)))
               -- Loop body: store the fetched value and count it as skipped.
               (\((a,state), (n1, (pos1, _mstate))) -> do
                  pos2 <- storeNext (size0,ptr) a pos1
                  n2 <- A.dec n1
                  return (n2, (state, pos2))))
   -- number of values actually stored in this step
   skipped <- A.sub n0 n3
   return (((Maybe.isJust state3, skipped), Cons ptr size0 pos3),
           (nNext, (state3, pos3)))

{- |
Write one value at the given position of the ring array.
Returns the following position, which wraps to zero at the ring size.
-}
storeNext ::
   (Memory.C al, Memory.Struct al ~ am) =>
   (Value Word32, Value (Ptr am)) ->
   al -> Value Word32 -> CodeGenFunction r (Value Word32)
storeNext (ringSize, ptr) x pos = do
   slot <- LLVM.getElementPtr ptr (pos, ())
   Memory.store x slot
   posNext <- A.inc pos
   inRange <- A.cmp LLVM.CmpLT posNext ringSize
   -- wrap around to the beginning of the array
   C.select inRange posNext A.zero


{- |
Initialization for 'trackSkipHold_'.

Starts the input signal and allocates the ring array.
The input state is wrapped in 'Maybe.just';
'trackNextHold' replaces it by 'Maybe.nothing' when the input ends.
The pending skip count is initialised to the ring size,
so that the first step fills the whole buffer.
The write position starts at zero.
-}
trackStartHold ::
   (LLVM.IsSized am,
    Phi state, Class.Undefined state) =>
   (param -> CodeGenFunction r (context, state)) ->
   (tl -> Value Word32) ->
   (param, tl) ->
   CodeGenFunction r
      ((context, (tl, Value (Ptr am))),
       (Value Word32, (Maybe.T state, Value Word32)))
trackStartHold start valueTime (param, size) = do
   (context, state) <- start param
   let size0 = valueTime size
   -- Allocate one spare slot so that the array is never empty.
   -- With a ring size of zero, 'storeNext' still writes to slot 0
   -- on every skip, which would otherwise hit a zero-sized allocation.
   ptr <- LLVM.arrayMalloc =<< A.inc size0
   return ((context, (size,ptr)), (size0, (Maybe.just state, A.zero)))

{- |
Cleanup for 'trackSkipHold_'.

Releases the ring array.
The input signal is stopped only if its state is still present,
that is, if the input has not ended before.
-}
trackStopHold ::
   (LLVM.IsType am) =>
   (context -> state -> CodeGenFunction r ()) ->
   (context, (tl, Value (Ptr am))) ->
   (Value Word32, (Maybe.T state, Value Word32)) ->
   CodeGenFunction r ()
trackStopHold stop (context, (_size, ptr)) (_skip, (mstate, _pos)) = do
   LLVM.free ptr
   Maybe.for mstate (stop context)


{- |
Parameter creation.

Runs the creation action of the input signal.
Its parameter is paired with the time parameter extracted from @p@.
-}
trackCreate ::
   (p -> IO (ioContext, param)) ->
   (p -> t) ->
   (p ->
   IO (ioContext, (param, t)))
trackCreate create getTime p =
   create p >>= \(ctx, prm) -> return (ctx, (prm, getTime p))

-- | Parameter deletion simply delegates to the deletion of the input signal.
trackDelete :: (ioContext -> IO ()) -> ioContext -> IO ()
trackDelete delete = delete