{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE Rank2Types #-}
module Synthesizer.LLVM.CausalParameterized.ProcessPacked (
   CausalS.pack,
   CausalS.packSmall,
   CausalS.unpack,
   raise,
   amplify,
   amplifyStereo,
   CausalS.osciCore,
   osciSimple,
   CausalS.shapeModOsci,
   delay1,
   differentiate,
   integrate,
   CausalS.arrayElement,
   ) where

import Synthesizer.LLVM.CausalParameterized.ProcessPrivate (T)
import qualified Synthesizer.LLVM.CausalParameterized.Process as CausalP
import qualified Synthesizer.LLVM.Causal.ProcessPacked as CausalS
import qualified Synthesizer.LLVM.Causal.Process as Causal
import qualified Synthesizer.LLVM.Frame as Frame
import qualified Synthesizer.LLVM.Frame.SerialVector as Serial
import qualified Synthesizer.LLVM.Frame.Stereo as Stereo

import qualified LLVM.DSL.Parameter as Param

import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Marshal as Marshal
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Arithmetic as A

import qualified LLVM.Core as LLVM
import LLVM.Core (CodeGenFunction, Value, IsSized, IsArithmetic, IsPrimitive)

import qualified Type.Data.Num.Decimal as TypeNum

import qualified Control.Category as Cat

import Data.Tuple.HT (swap)

import NumericPrelude.Numeric
import NumericPrelude.Base hiding (and, iterate, map, zip, zipWith)


-- | Combine every chunk of a serial-vector stream with a scalar parameter.
--
-- The parameter value is first broadcast to a full vector with
-- 'Serial.upsample'. That vector is then passed as the first argument of
-- 'Frame.mix', and the incoming chunk as the second.
raise ::
   (IsArithmetic a, Marshal.C a, Tuple.ValueOf a ~ Value a, IsPrimitive a,
    TypeNum.Positive n) =>
   Param.T p a ->
   T p (Serial.Value n a) (Serial.Value n a)
raise =
   CausalP.map $ \offset chunk -> do
      offsetVec <- Serial.upsample offset
      Frame.mix offsetVec chunk

-- | Apply a scalar gain parameter to every chunk of a mono serial-vector
-- stream.
--
-- The gain is broadcast to a whole vector with 'Serial.upsample'. It is
-- then applied to the incoming chunk by 'Frame.amplifyMono'.
amplify ::
   (IsArithmetic a, Marshal.C a, Tuple.ValueOf a ~ Value a, IsPrimitive a,
    TypeNum.Positive n) =>
   Param.T p a ->
   T p (Serial.Value n a) (Serial.Value n a)
amplify =
   CausalP.map $ \gain chunk -> do
      gainVec <- Serial.upsample gain
      Frame.amplifyMono gainVec chunk

-- | Stereo counterpart of 'amplify'.
--
-- A single scalar gain parameter is broadcast with 'Serial.upsample'. It is
-- then applied to a stereo pair of serial-vector chunks by
-- 'Frame.amplifyStereo'.
amplifyStereo ::
   (IsArithmetic a, Marshal.C a, Tuple.ValueOf a ~ Value a, IsPrimitive a,
    TypeNum.Positive n) =>
   Param.T p a ->
   T p (Stereo.T (Serial.Value n a)) (Stereo.T (Serial.Value n a))
amplifyStereo =
   CausalP.map $ \gain stereoChunk -> do
      gainVec <- Serial.upsample gain
      Frame.amplifyStereo gainVec stereoChunk


-- | Former name of 'CausalS.osci', kept only for backwards compatibility.
-- New code should use 'CausalS.osci' directly.
--
-- The first argument is the waveform function applied to a vector of
-- values. The process input is a pair of serial vectors, passed through
-- unchanged to 'CausalS.osci'.
osciSimple ::
   (Causal.C process,
    Vector.Real t, SoV.Fraction t, LLVM.IsFloating t, IsSized t,
    TypeNum.Positive n) =>
   (forall r. Serial.Value n t -> CodeGenFunction r y) ->
   process (Serial.Value n t, Serial.Value n t) y
osciSimple wave = CausalS.osci wave


-- | Delay a serial-vector stream by one element.
--
-- A feedback loop carries a single element from one chunk to the next,
-- starting with the @initial@ parameter. For each chunk, 'Serial.shiftUp' is
-- called with the carried element and the incoming chunk. The second
-- component of its result becomes the output chunk, and the first component
-- becomes the new carried element.
-- NOTE(review): presumably shiftUp inserts at the front and returns the
-- element shifted out of the end; confirm against "Serial".
delay1 ::
   (Serial.C va, n ~ Serial.Size va, al ~ Serial.Element va,
    Marshal.C a, Tuple.ValueOf a ~ al) =>
   Param.T p a -> T p va va
delay1 initial =
   CausalP.loop initial $ Causal.map $
      \ ~(new, prev) -> do
         ~(carried, shifted) <- Serial.shiftUp prev new
         return (shifted, carried)

-- | First-order difference: each output is the current input minus the
-- input one element earlier.
--
-- The @initial@ parameter plays the role of the element that precedes the
-- stream (see 'delay1').
differentiate ::
   (Serial.C va, n ~ Serial.Size va, al ~ Serial.Element va,
    A.Additive va,
    Marshal.C a, Tuple.ValueOf a ~ al) =>
   Param.T p a -> T p va va
differentiate initial =
   Cat.id - delayed
   where
      delayed = delay1 initial

-- | Running sum of a serial-vector stream.
--
-- The accumulator starts at the value of the parameter argument. For each
-- chunk, 'Serial.cumulate' is called with the current accumulator and the
-- chunk. It yields the next accumulator and the chunk of outputs.
-- NOTE(review): whether each output includes the current element (an
-- inclusive vs. exclusive prefix sum) is decided by 'Serial.cumulate'; confirm.
integrate ::
   (Vector.Arithmetic a, Marshal.C a, Tuple.ValueOf a ~ Value a, IsPrimitive a,
    TypeNum.Positive n) =>
   Param.T p a ->
   T p (Serial.Value n a) (Serial.Value n a)
integrate =
   CausalP.mapAccum
      (\() chunk runningSum ->
         fmap (\(nextSum, sums) -> (sums, nextSum)) $
            Serial.cumulate runningSum chunk)
      return
      (return ())