{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeOperators #-}
module Synthesizer.LLVM.Causal.ProcessPacked where

import qualified Synthesizer.LLVM.Causal.Process as Causal
import Synthesizer.LLVM.Causal.ProcessPrivate (Core(Core), alter)

import qualified Synthesizer.LLVM.Frame.SerialVector as Serial

import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.MaybeContinuation as Maybe
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C

import qualified LLVM.Core as LLVM
import LLVM.Core
          (CodeGenFunction, Value, valueOf,
           IsSized, IsFirstClass)

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal ((:<:))
import Type.Base.Proxy (Proxy)

import qualified Control.Monad.Trans.Class as MT
import qualified Control.Monad.Trans.State as MS
import qualified Control.Arrow as Arr
import Control.Arrow ((<<<))

import Data.Word (Word)

import NumericPrelude.Numeric
import NumericPrelude.Base


{- |
Run a scalar process on packed data.
Every input vector is taken apart element by element,
each element is passed through the scalar process,
and the results are written into the output vector.
If the scalar process terminates before a vector is complete,
then that incomplete vector is dropped.
That is, if the signal length is not divisible by the chunk size,
the last chunk does not appear in the output.
-}
pack ::
   (Causal.C process,
    Serial.Read va, n ~ Serial.Size va, a ~ Serial.Element va,
    Serial.C    vb, n ~ Serial.Size vb, b ~ Serial.Element vb) =>
   process a b -> process va vb
pack = alter (\(Core nextScalar startScalar stopScalar) -> Core
   (\param vecIn state0 -> do
      readIt0 <- Maybe.lift $ Serial.readStart vecIn
      let writeIt0 = Tuple.undef
          -- 'Serial.sizeOfIterator' only gets an undefined iterator here,
          -- so the element count is expected to follow from its type alone.
          count0 =
             valueOf (fromIntegral (Serial.sizeOfIterator writeIt0) :: Word)
          -- Process one element.
          -- If the scalar process terminates, 'Maybe.toBool' yields False,
          -- which stops the loop.
          step (_alive, ((readIt, writeIt), (remaining, state))) =
             Maybe.toBool $ do
                (x, readIt') <- Maybe.lift $ Serial.readNext readIt
                (y, state') <- nextScalar param x state
                Maybe.lift $ do
                   writeIt' <- Serial.writeNext y writeIt
                   remaining' <- A.dec remaining
                   return ((readIt', writeIt'), (remaining', state'))
      ((_readItEnd, writeItEnd), (_remainingEnd, stateEnd)) <-
         Maybe.fromBool $
         C.whileLoop
            (valueOf True, ((readIt0, writeIt0), (count0, state0)))
            -- continue while the scalar process is alive
            -- and elements of the current vector are left
            (\(alive, (_iterators, (remaining, _state))) -> do
               notDone <- A.cmp LLVM.CmpGT remaining A.zero
               A.and alive notDone)
            step
      vecOut <- Maybe.lift $ Serial.writeStop writeItEnd
      return (vecOut, stateEnd))
   startScalar
   stopScalar)

{- |
Like 'pack' but duplicates the code for the scalar process.
Instead of a runtime loop, the input vector is split into
its @n@ elements at code generation time.
The scalar process is then instantiated once per element,
so its generated code appears @n@ times.
This is efficient only for simple scalar processes
(i.e. for a simple process passed as argument).
-}
packSmall ::
   (Causal.C process,
    Serial.Read va, n ~ Serial.Size va, a ~ Serial.Element va,
    Serial.C    vb, n ~ Serial.Size vb, b ~ Serial.Element vb) =>
   process a b -> process va vb
packSmall = alter (\(Core nextScalar startScalar stopScalar) -> Core
   (\param vecIn -> MS.runStateT $ do
      -- a Haskell list of elements, traversed while generating code
      elements <- MT.lift $ Maybe.lift $ Serial.extractAll vecIn
      results <- mapM (\x -> MS.StateT (nextScalar param x)) elements
      MT.lift $ Maybe.lift $ Serial.assemble results)
   startScalar
   stopScalar)


{- |
Run a packed process on scalar data.
Incoming scalars are collected in a vector.
Once a chunk is complete, that vector is handed to the packed process,
and the resulting output vector is then emitted one element per step.
In order to stay causal, we have to delay the output by @n@ samples.
The first @n@ output samples are the response of the packed process
to the vector obtained from 'Serial.writeZero'.
If the signal length is not divisible by the chunk size,
then the last incomplete chunk is never passed to the packed process.
-}
unpack ::
   (Causal.C process,
    Serial.Zero va, n ~ Serial.Size va, a ~ Serial.Element va,
    Serial.Read vb, n ~ Serial.Size vb, b ~ Serial.Element vb,
    Memory.C va, Memory.C ita, ita ~ Serial.WriteIt va,
    Memory.C vb, Memory.C itb, itb ~ Serial.ReadIt vb) =>
   process va vb -> process a b
unpack = alter (\(Core nextPacked startPacked stopPacked) -> Core
   (\param x state@((writeIt, _readIt), (remaining, inner)) -> do
      -- Hand the vector gathered so far to the packed process
      -- and set up fresh iterators for the next chunk.
      let refill = do
             vecIn <- Serial.writeStop writeIt
             (alive, (vecOut, inner')) <-
                Maybe.toBool $ nextPacked param vecIn inner
             readIt' <- Serial.readStart vecOut
             writeIt' <- Serial.writeStart
             return
                (alive,
                 ((writeIt', readIt'),
                  (valueOf $ fromIntegral $ Serial.size vecIn, inner')))
      chunkDone <- Maybe.lift $ A.cmp LLVM.CmpEQ remaining A.zero
      ((writeItCur, readItCur), (remainingCur, innerCur)) <-
         Maybe.fromBool $ C.ifThen chunkDone (valueOf True, state) refill
      Maybe.lift $ do
         writeItNext <- Serial.writeNext x writeItCur
         (y, readItNext) <- Serial.readNext readItCur
         remainingNext <- A.dec remainingCur
         return (y, ((writeItNext, readItNext), (remainingNext, innerCur))))
   -- A zero counter forces a refill on the very first step.
   (\initial -> do
      inner <- startPacked initial
      writeIt <- Serial.writeZero
      return ((writeIt, Tuple.undef), (valueOf (0::Word), inner)))
   (\(_iterators, (_remaining, inner)) -> stopPacked inner))


{- |
Phase generator shared by the packed oscillators.

The second input component is fed through 'Serial.cumulate',
starting from a zero phase.
The vector result of 'Serial.cumulate' is passed on.
Its scalar result, wrapped by 'A.signedFraction',
becomes the phase for the next chunk.
The passed-on vector is finally combined with the first input component
by 'A.addToPhase'.
NOTE(review): presumably the second component carries frequencies
and the first one phase offsets - confirm against callers.
-}
osciCore ::
   (Causal.C process,
    IsSized t, Vector.Real t, SoV.Fraction t, LLVM.IsFloating t,
    TypeNum.Positive n) =>
   process (Serial.Value n t, Serial.Value n t) (Serial.Value n t)
osciCore =
   Arr.second
      (Causal.mapAccum
         (\freqs phase -> do
            (phaseNext, phases) <- Serial.cumulate phase freqs
            phaseWrapped <- A.signedFraction phaseNext
            return (phases, phaseWrapped))
         (return A.zero))
   Arr.>>>
   Causal.zipWith A.addToPhase

{- |
Packed oscillator:
applies the waveform function to every phase vector produced by 'osciCore'.
-}
osci ::
   (Causal.C process,
    IsSized t, Vector.Real t, SoV.Fraction t, LLVM.IsFloating t,
    TypeNum.Positive n) =>
   (forall r. Serial.Value n t -> CodeGenFunction r y) ->
   process (Serial.Value n t, Serial.Value n t) y
osci wave =
   osciCore Arr.>>> Causal.map wave

{- |
Like 'osci', but the waveform function also receives a control value @c@.
The control value is passed alongside the phase vector
that 'osciCore' computes from the remaining input.
-}
shapeModOsci ::
   (Causal.C process,
    IsSized t, Vector.Real t, SoV.Fraction t, LLVM.IsFloating t,
    TypeNum.Positive n) =>
   (forall r. c -> Serial.Value n t -> CodeGenFunction r y) ->
   process (c, (Serial.Value n t, Serial.Value n t)) y
shapeModOsci wave =
   Arr.second osciCore Arr.>>> Causal.zipWith wave



{- |
Select the element at position @index@ of an LLVM array
(via 'Causal.arrayElement').
The element is then turned into a serial vector with 'Serial.upsample'.
NOTE(review): presumably 'Serial.upsample' copies the value
into every vector slot - confirm.
-}
arrayElement ::
   (Causal.C process,
    IsFirstClass a, LLVM.Value a ~ Serial.Element v, Serial.C v,
    TypeNum.Natural index, TypeNum.Natural dim,
    index :<: dim) =>
   Proxy index -> process (Value (LLVM.Array dim a)) v
arrayElement i =
   Causal.arrayElement i Arr.>>> Causal.map Serial.upsample