{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
module Synthesizer.LLVM.Filter.ComplexFirstOrder (
   Parameter, parameter,
   causal, causalP,
   ) where

import qualified Synthesizer.LLVM.CausalParameterized.Process as CausalP
import qualified Synthesizer.LLVM.Causal.Process as Causal
import qualified Synthesizer.LLVM.Simple.Value as Value

import qualified Synthesizer.LLVM.Frame.Stereo as Stereo
import qualified Synthesizer.LLVM.Complex as Complex

import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Class as Class
import LLVM.Extra.Class (Undefined, undefTuple, )

import qualified LLVM.Core as LLVM
import LLVM.Core (CodeGenFunction, )
import LLVM.Util.Loop (Phi, phis, addPhis, )

import Type.Data.Num.Decimal (d0, d1, d2, )

import qualified Control.Applicative as App
import Control.Applicative (liftA2, liftA3, (<*>), )

import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold

import NumericPrelude.Numeric
import NumericPrelude.Base


-- | Parameters of the complex first-order filter:
-- the real amplitude that scales the (complex) input sample,
-- and the complex coefficient that multiplies the previous state
-- (see 'next').
data Parameter a = Parameter a (Complex.T a)

-- Applies the function to the amplitude
-- and to both components of the complex coefficient.
instance Functor Parameter where
   {-# INLINE fmap #-}
   fmap g param =
      case param of
         Parameter amp coeff -> Parameter (g amp) (fmap g coeff)

-- Zip-like instance: each field of the function-valued parameter
-- is applied to the corresponding field of the argument parameter.
-- 'pure' puts the same value into all three fields.
instance App.Applicative Parameter where
   {-# INLINE pure #-}
   pure x = Parameter x (x Complex.+: x)
   {-# INLINE (<*>) #-}
   Parameter fAmp fCoeff <*> Parameter xAmp xCoeff =
      let re = Complex.real fCoeff (Complex.real xCoeff)
          im = Complex.imag fCoeff (Complex.imag xCoeff)
      in  Parameter (fAmp xAmp) (re Complex.+: im)

-- Derived from the Traversable instance, so folds visit the amplitude,
-- then the real part, then the imaginary part of the coefficient.
instance Fold.Foldable Parameter where
   {-# INLINE foldMap #-}
   foldMap f = Trav.foldMapDefault f

-- Sequences the effects in field order: amplitude, real part of the
-- coefficient, imaginary part of the coefficient.
-- This is the same order that 'parameterMemory' uses.
instance Trav.Traversable Parameter where
   {-# INLINE sequenceA #-}
   sequenceA (Parameter amp coeff) =
      liftA3
         (\a re im -> Parameter a (re Complex.+: im))
         amp (Complex.real coeff) (Complex.imag coeff)


-- Phi handling delegates to the generic Traversable/Foldable-based
-- helpers from "LLVM.Extra.Class".
instance (Phi a) => Phi (Parameter a) where
   phis bb = Class.phisTraversable bb
   addPhis bb = Class.addPhisFoldable bb

-- Undefined parameter built with 'Class.undefTuplePointed'.
-- NOTE(review): presumably this fills every field with 'undefTuple'
-- via the Applicative 'pure' above; confirm against LLVM.Extra.Class.
instance Undefined a => Undefined (Parameter a) where
   undefTuple = Class.undefTuplePointed


-- LLVM structure layout for storing a 'Parameter' in memory.
-- As used by 'parameterMemory', the three elements hold the amplitude,
-- the real part and the imaginary part of the coefficient.
type ParameterStruct a = LLVM.Struct (a, (a, (a, ())))

-- | Mapping between the fields of a 'Parameter' and the elements of
-- 'ParameterStruct': element 0 is the amplitude,
-- and elements 1 and 2 are the real and imaginary parts of the coefficient.
parameterMemory ::
   (Memory.C a) =>
   Memory.Record r (ParameterStruct (Memory.Struct a)) (Parameter a)
parameterMemory =
   liftA3 assemble
      (Memory.element amplitude d0)
      (Memory.element (Complex.real . coefficient) d1)
      (Memory.element (Complex.imag . coefficient) d2)
   where
      assemble amp re im = Parameter amp (re Complex.+: im)
      amplitude (Parameter amp _) = amp
      coefficient (Parameter _ coeff) = coeff

-- Loading, storing, decomposing and composing all go through the
-- record description 'parameterMemory'.
instance (Memory.C a) => Memory.C (Parameter a) where
   type Struct (Parameter a) = ParameterStruct (Memory.Struct a)
   load ptr = Memory.loadRecord parameterMemory ptr
   store param = Memory.storeRecord parameterMemory param
   decompose struct = Memory.decomposeRecord parameterMemory struct
   compose param = Memory.composeRecord parameterMemory param

-- The register representation keeps the 'Parameter' shape.
-- Flattening and unfolding use the Traversable-based helpers
-- from "Synthesizer.LLVM.Simple.Value".
instance (Value.Flatten a) => Value.Flatten (Parameter a) where
   type Registers (Parameter a) = Parameter (Value.Registers a)
   flattenCode param = Value.flattenCodeTraversable param
   unfoldCode regs = Value.unfoldCodeTraversable regs


-- | Compute filter parameters from a resonance and a frequency.
--
-- The amplitude is @1/reson@.
-- The coefficient is @(1 - 1/reson) * (cos w + i*sin w)@ with @w = 2*pi*freq@.
-- '_parameter' computes the same values with explicit, step-by-step
-- LLVM arithmetic; it is not exported.
parameter, _parameter ::
   (A.Transcendental a, A.RationalConstant a) =>
   a -> a -> CodeGenFunction r (Parameter a)
parameter reson freq =
   Value.flatten $ Parameter amp $ Complex.scale (1 - amp) rotation
   where
      amp = recip $ Value.unfold reson
      rotation = Complex.cis $ Value.unfold freq * Value.twoPi

_parameter reson freq = do
   amp <- A.fdiv A.one reson
   k <- A.sub A.one amp
   twoPiV <- Value.decons Value.twoPi
   w <- A.mul freq twoPiV
   cosW <- A.cos w
   kr <- A.mul k cosW
   sinW <- A.sin w
   ki <- A.mul k sinW
   return (Parameter amp (kr Complex.+: ki))


{-
Synthesizer.Plain.Filter.Recursive.FirstOrderComplex.step
cannot be reused directly, because Filt1C there has a complex amplitude,
whereas the amplitude in 'Parameter' here is real.
-}
-- | One filter step.
--
-- The left and right channels of the stereo input are treated as the real
-- and imaginary parts of a complex sample @xc@.
-- The new state is @y = amp*xc + k*state@.
-- The output frame is the real and imaginary part of @y@ as left and right
-- channel, and @y@ is also returned as the next state.
-- '_next' computes the same result component-wise; it is not exported.
next, _next ::
   (A.PseudoRing a, A.IntegerConstant a) =>
   (Parameter a, Stereo.T a) ->
   Complex.T a ->
   CodeGenFunction r (Stereo.T a, Complex.T a)
next inp state =
   Value.flatten (toStereo state y, y)
   where
      (Parameter amp k, x) = Value.unfold inp
      y =
         Complex.scale amp (Stereo.left x Complex.+: Stereo.right x)
         + k * Value.unfold state
      -- The first argument is ignored; it only ties the element type of
      -- the result to the type of 'state', so the unfolded types are fixed.
      toStereo ::
         Complex.T a -> Complex.T (Value.T a) -> Stereo.T (Value.T a)
      toStereo _ c = Stereo.cons (Complex.real c) (Complex.imag c)

_next (Parameter amp k, x) s = do
   yr <- Value.decons $
      mulV (Stereo.left x) amp + mulV kr sr - mulV ki si
   yi <- Value.decons $
      mulV (Stereo.right x) amp + mulV kr si + mulV ki sr
   return (Stereo.cons yr yi, yr Complex.+: yi)
   where
      mulV u v = Value.lift0 (A.mul u v)
      kr = Complex.real k
      ki = Complex.imag k
      sr = Complex.real s
      si = Complex.imag s


-- | Initial filter state: the complex zero.
start ::
   (A.Additive a) =>
   CodeGenFunction r (Complex.T a)
start =
   App.pure (A.zero Complex.+: A.zero)

-- | Complex first-order filter as a causal process.
-- Each input element pairs the filter parameters with a stereo sample.
-- The complex state starts at 'start' and is advanced by 'next'.
causal ::
   (Causal.C process, A.PseudoRing a, A.IntegerConstant a, Memory.C a) =>
   process
      (Parameter a, Stereo.T a)
      (Stereo.T a)
causal =
   Causal.mapAccum next start

{-# DEPRECATED causalP "use causal instead" #-}
-- | The same filter as 'causal', built with 'CausalP.mapAccumSimple'
-- for the parameterized process type 'CausalP.T'.
-- Kept for compatibility; new code should use 'causal'.
causalP ::
   (A.PseudoRing a, A.IntegerConstant a, Memory.C a) =>
   CausalP.T p
      (Parameter a, Stereo.T a)
      (Stereo.T a)
causalP =
   CausalP.mapAccumSimple next start