{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE Rank2Types #-}
module Synthesizer.LLVM.Causal.RingBufferForward (
   T, track, trackSkip, trackSkipHold,
   index, mapIndex,
   ) where

import qualified Synthesizer.LLVM.Causal.Private as CausalPriv
import qualified Synthesizer.LLVM.Causal.Process as Causal
import qualified Synthesizer.LLVM.Generator.Private as Sig
import Synthesizer.LLVM.RingBuffer (MemoryPtr)

import Synthesizer.LLVM.Causal.Process (($*#))
import Synthesizer.Causal.Class (($<), ($*))

import qualified LLVM.DSL.Expression as Expr
import LLVM.DSL.Expression (Exp)

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.MaybeContinuation as MaybeCont
import qualified LLVM.Extra.Maybe as Maybe
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Tuple as Tuple

import qualified LLVM.Core as LLVM
import LLVM.Core (CodeGenFunction, Value)

import qualified Control.Arrow as Arrow
import Control.Arrow ((<<<), (<<^))
import Data.Tuple.HT (mapSnd, mapPair)

import Data.Word (Word)

import Prelude hiding (length)



{- |
This type is very similar to 'Synthesizer.LLVM.RingBuffer.T'
but differs in several details:

* It stores values in time order,
  whereas 'Synthesizer.LLVM.RingBuffer.T' stores in opposite order.

* Since it stores future values it is not causal
  and can only track signal generators.

* There is no need for an initial value.

* It stores one value less than 'Synthesizer.LLVM.RingBuffer.T'
  since it is meant to provide infixes of the signal
  rather than providing the basis for a delay line.

These differences alone would not justify a new type;
you could achieve the same by combining
'Synthesizer.LLVM.RingBuffer.track'
and
'Synthesizer.LLVM.CausalParameterized.Process.skip'.
The fundamental problem of this combination is
that it requires keeping the ring buffer alive
longer than the signal that feeds it exists.
This is not possible with the current design.
That's why we provide the combination of @track@ and @skip@
in a way that does not suffer from that problem.
This functionality is critical for
'Synthesizer.LLVM.CausalParameterized.Helix.dynamic'.
-}
data T a = Cons
   { buffer :: Value (MemoryPtr a)
     -- ^ element storage, allocated with 'LLVM.arrayMalloc'
     -- by the start functions of this module
   , length :: Value Word
     -- ^ number of elements in 'buffer'
   , current :: Value Word
     -- ^ slot that 'index' 0 refers to;
     -- the next value written to the buffer overwrites this slot
   }

{- |
This function does not check for range violations.
If the ring buffer was generated by @track time@,
then the minimum index is zero and the maximum index is @time-1@.
Index zero refers to the current sample
and index @time-1@ refers to the one that is farthest in the future.
-}
index :: (Memory.C a) => MultiValue.T Word -> T a -> CodeGenFunction r a
index (MultiValue.Cons offset) rb = do
   -- The buffer is circular: shift the offset by the position of the
   -- current sample and wrap the result around the buffer length.
   shifted <- A.add (current rb) offset
   slot <- A.irem shifted (length rb)
   elemPtr <- LLVM.getElementPtr (buffer rb) (slot, ())
   Memory.load elemPtr

-- | Causal process that looks up the element at offset @k@
-- (see 'index') in every incoming ring buffer.
mapIndex :: (Memory.C a) => Exp Word -> Causal.T (T a) a
mapIndex k =
   CausalPriv.map
      (\rb -> do
         offset <- Expr.unExp k
         index offset rb)


{- |
@track time signal@ bundles @time@ successive values of @signal@.
The values can be accessed using 'index' with indices
ranging from 0 to @time-1@.

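The @time@ parameter must be positive.
With @time = 0@ the buffer has no elements,
so storing a sample into it or accessing it via 'index' is invalid.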
-}
track :: (Memory.C a) => Exp Word -> Sig.T a -> Sig.T (T a)
track time =
   -- advance by exactly one sample per step
   ($* 1) . trackSkip time

{- |
@trackSkip time input $* skips@
is like
@Process.skip (track time input) $* skips@
but this composition would require a @Memory@ constraint for 'T'
which we cannot provide.
-}
trackSkip ::
   (Memory.C a) =>
   Exp Word -> Sig.T a -> Causal.T (MultiValue.T Word) (T a)
trackSkip time (Sig.Cons next start stop) =
   CausalPriv.Cons (trackNext next) (trackStart start time) (trackStop stop)
      <<^ unpackSkip
   where
      -- The skip count arrives wrapped in a 'MultiValue.T',
      -- but 'trackNext' expects the raw LLVM value.
      unpackSkip (MultiValue.Cons n) = n

{- |
Like @trackSkip@ but repeats the last buffer content
when the end of the input signal is reached.
The returned 'Bool' flag is 'True' if a skip could be performed completely
and it is 'False' if the skip exceeds the end of the input.
That is, once a 'False' is returned all following values are tagged with 'False'.
The returned 'Word' value is the number of actually skipped values.
This lags one step behind the input of skip values.
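The actual number of skips is at most the number of requested skips.
It is smaller only when the end of the input is reached:
in the first step with a 'False' flag
it is the number of values that could still be read (possibly zero),
and in all later steps it is zero.
Conversely, zero actual skips do not imply a 'False' flag,
e.g. if no skip was requested.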

If the input signal is too short, the output is undefined.
(Before the available data the buffer will be filled with arbitrary values.)
We could fill the buffer with zeros,
but this would require an Arithmetic constraint
and the generated signal would not be very meaningful.
We could also return an empty signal if the input is too short.
However this would require a permanent check.
-}
trackSkipHold ::
   (Memory.C a) =>
   Exp Word -> Sig.T a ->
   Causal.T (MultiValue.T Word) ((MultiValue.T Bool, MultiValue.T Word), T a)
trackSkipHold time xs =
   -- Wrap flag and count into 'MultiValue.T's and force the count of
   -- the very first step to zero, since that step reports the initial
   -- filling of the buffer rather than a requested skip.
   Arrow.first
      (Arrow.second clearFirst <<^ mapPair (MultiValue.Cons, MultiValue.Cons))
   <<< trackSkipHold_ time xs
   <<^ unpackSkip
   where
      unpackSkip (MultiValue.Cons n) = n

-- | Pass values through unchanged, except that the very first output is zero.
clearFirst ::
   (MultiValue.PseudoRing a, MultiValue.Real a,
    MultiValue.IntegerConstant a, MultiValue.Select a) =>
   Causal.MV a a
clearFirst =
   Causal.zipWith keepLater $< notFirst
   where
      -- 'False' in the first step, 'True' in every later step
      notFirst = Causal.delay1 Expr.false $*# True
      keepLater later x = Expr.select later x 0

-- | Core of 'trackSkipHold' working on raw LLVM values.
-- It emits the flag and the skip count
-- without clearing the count of the first step.
trackSkipHold_ ::
   (Memory.C a) =>
   Exp Word -> Sig.T a ->
   Causal.T (Value Word) ((Value Bool, Value Word), T a)
trackSkipHold_ size (Sig.Cons sigNext sigStart sigStop) =
   CausalPriv.Cons
      (trackNextHold sigNext)
      (trackStartHold sigStart size)
      (trackStopHold sigStop)


-- | Step function of 'trackSkip'.
-- It reads as many values from the input signal as were requested
-- in the previous step (initially the whole buffer size, see 'trackStart')
-- and writes them into the ring buffer.
-- If the input signal ends, 'MaybeCont.fromMaybe' aborts the process.
trackNext ::
   (Memory.C al, Tuple.Phi z,
    Tuple.Phi state, Tuple.Undefined state) =>
   (forall z0. (Tuple.Phi z0) =>
    context -> local -> state -> MaybeCont.T r z0 (al, state)) ->
   (context, (Value Word, Value (MemoryPtr al))) -> local ->
   Value Word ->
   (Value Word, (state, Value Word)) ->
   MaybeCont.T r z (T al, (Value Word, (state, Value Word)))
trackNext next (context, (size0,ptr)) local nextSkip (pending, loopState) = do
   (stateEnd, posEnd) <-
      MaybeCont.fromMaybe $ fmap snd $
      MaybeCont.fixedLengthLoop pending loopState $ \(st, pos) -> do
         (x, st') <- next context local st
         pos' <- MaybeCont.lift $ storeNext (size0,ptr) x pos
         return (st', pos')
   -- the write position is also the slot of the current sample
   return (Cons ptr size0 posEnd, (nextSkip, (stateEnd, posEnd)))

-- | Initialization for 'trackSkip':
-- runs the signal's start code, allocates a buffer of @size@ elements
-- and requests that the first step fills the whole buffer.
trackStart ::
   (LLVM.IsSized am, Tuple.Phi state, Tuple.Undefined state) =>
   CodeGenFunction r (context, state) ->
   Exp Word ->
   CodeGenFunction r
      ((context, (Value Word, Value (LLVM.Ptr am))),
       (Value Word, (state, Value Word)))
trackStart start size = do
   (context, state) <- start
   ~(MultiValue.Cons len) <- Expr.unExp size
   mem <- LLVM.arrayMalloc len
   let global = (context, (len, mem))
       initial = (len, (state, A.zero))
   return (global, initial)

-- | Cleanup for 'trackSkip': frees the buffer,
-- then runs the signal's stop code.
trackStop ::
   (LLVM.IsType am) =>
   (context -> CodeGenFunction r ()) ->
   (context, (tl, Value (LLVM.Ptr am))) ->
   CodeGenFunction r ()
trackStop stop (context, (_, mem)) =
   LLVM.free mem >> stop context


-- | Step function of 'trackSkipHold_'.
--
-- Unlike 'trackNext' it never aborts (the whole body runs under
-- 'MaybeCont.lift'). The input signal state is kept as a 'Maybe.T':
--
-- * If the state is already 'Nothing' (the input ended in an earlier step),
--   nothing is read and the buffer is left unchanged.
--
-- * Otherwise up to @n0@ values are read and written via 'storeNext',
--   stopping early if the input signal ends.
--
-- The output contains the flag 'Maybe.isJust' of the resulting state
-- (i.e. whether the input has not ended yet)
-- and the number of values actually read in this step (@n0 - n3@).
trackNextHold ::
   (Memory.C al, Tuple.Phi z,
    Tuple.Phi state, Tuple.Undefined state) =>
   (forall z0. (Tuple.Phi z0) =>
    context -> local -> state -> MaybeCont.T r z0 (al, state)) ->
   (context, (Value Word, Value (MemoryPtr al))) -> local ->
   Value Word ->
   (Value Word, (Maybe.T state, Value Word)) ->
   MaybeCont.T r z
      (((Value Bool, Value Word), T al),
       (Value Word, (Maybe.T state, Value Word)))
trackNextHold next (context, (size0,ptr)) local nNext (n0, (mstate0, pos0)) =
      MaybeCont.lift $ do
   -- n3: number of requested values that were NOT read,
   -- pos3: write position after this step, state3: resulting signal state
   (n3, (pos3, state3)) <-
      Maybe.run mstate0
         -- input already exhausted: keep everything as it is
         (return (n0, (pos0, mstate0)))
         (\state0 ->
            Maybe.loopWithExit (n0, (state0, pos0))
               -- loop test: while values are still requested, try to read one;
               -- the loop exits with the counter, position and
               -- (possibly 'Nothing') signal state
               (\(n1, (state1, pos1)) -> do
                  cont <- A.cmp LLVM.CmpGT n1 A.zero
                  fmap (mapSnd ((,) n1 . (,) pos1)) $
                     C.ifThen cont
                        (Maybe.nothing, Maybe.just state1)
                        (do aState <-
                              MaybeCont.toMaybe $ next context local state1
                            return (aState, fmap snd aState)))
               -- loop step: store the value read and count it
               (\((a,state), (n1, (pos1, _mstate))) -> do
                  pos2 <- storeNext (size0,ptr) a pos1
                  n2 <- A.dec n1
                  return (n2, (state, pos2))))
   skipped <- A.sub n0 n3
   return (((Maybe.isJust state3, skipped), Cons ptr size0 pos3),
           (nNext, (state3, pos3)))

-- | Store a value at position @pos@ of the ring buffer
-- and return the following position, wrapping around to zero
-- when the end of the buffer is reached.
storeNext ::
   (Memory.C al) =>
   (Value Word, Value (MemoryPtr al)) ->
   al -> Value Word -> CodeGenFunction r (Value Word)
storeNext (len, mem) x pos = do
   slot <- LLVM.getElementPtr mem (pos, ())
   Memory.store x slot
   succPos <- A.inc pos
   inRange <- A.cmp LLVM.CmpLT succPos len
   C.select inRange succPos A.zero


-- | Initialization for 'trackSkipHold_'.
-- Like 'trackStart', but the signal state is wrapped in 'Maybe.just',
-- so that the step function can record the end of the input.
trackStartHold ::
   (LLVM.IsSized am,
    Tuple.Phi state, Tuple.Undefined state) =>
   CodeGenFunction r (context, state) ->
   Exp Word ->
   CodeGenFunction r
      ((context, (Value Word, Value (LLVM.Ptr am))),
       (Value Word, (Maybe.T state, Value Word)))
trackStartHold start size = do
   (context, state) <- start
   ~(MultiValue.Cons len) <- Expr.unExp size
   fmap
      (\mem -> ((context, (len, mem)), (len, (Maybe.just state, A.zero))))
      (LLVM.arrayMalloc len)

-- | Cleanup for 'trackSkipHold_': frees the buffer,
-- then runs the signal's stop code.
trackStopHold ::
   (LLVM.IsType am) =>
   (context -> CodeGenFunction r ()) ->
   (context, (Value Word, Value (LLVM.Ptr am))) ->
   CodeGenFunction r ()
trackStopHold stop (context, (_len, mem)) =
   LLVM.free mem *> stop context