{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
module Synthesizer.LLVM.Filter.ComplexFirstOrder (
   Parameter, parameter,
   causal, causalP,
   ) where

import qualified Synthesizer.LLVM.CausalParameterized.Process as CausalP
import qualified Synthesizer.LLVM.Causal.Process as Causal
import qualified Synthesizer.LLVM.Simple.Value as Value

import qualified Synthesizer.LLVM.Frame.Stereo as Stereo
import qualified Synthesizer.LLVM.Complex as Complex

import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Tuple as Tuple

import qualified LLVM.Core as LLVM
import LLVM.Core (CodeGenFunction)

import Type.Data.Num.Decimal (d0, d1, d2)

import qualified Control.Applicative as App
import Control.Applicative (liftA2, liftA3, (<*>))

import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold

import NumericPrelude.Numeric
import NumericPrelude.Base


-- | Coefficients of the complex first-order filter.
data Parameter a =
   Parameter
      a              -- ^ real gain that scales the (complex-interpreted) input
      (Complex.T a)  -- ^ complex coefficient that multiplies the filter state

-- Applies the function to the gain and, via the functor of 'Complex.T',
-- to the coefficient.
instance Functor Parameter where
   {-# INLINE fmap #-}
   fmap f = \(Parameter gain coeff) -> Parameter (f gain) (fmap f coeff)

-- Componentwise application: the gain, the real part and the imaginary
-- part of the coefficient are each combined independently.
instance App.Applicative Parameter where
   {-# INLINE pure #-}
   pure x = Parameter x (x Complex.+: x)
   {-# INLINE (<*>) #-}
   Parameter fGain fCoeff <*> Parameter xGain xCoeff =
      let re = Complex.real fCoeff (Complex.real xCoeff)
          im = Complex.imag fCoeff (Complex.imag xCoeff)
      in  Parameter (fGain xGain) (re Complex.+: im)

-- Folding is derived from the Traversable instance, so the elements are
-- visited in the order: gain, real part, imaginary part.
instance Fold.Foldable Parameter where
   {-# INLINE foldMap #-}
   foldMap f = Trav.foldMapDefault f

-- Sequences the gain first, then the real and imaginary parts of the
-- coefficient.
instance Trav.Traversable Parameter where
   {-# INLINE sequenceA #-}
   sequenceA (Parameter gain coeff) =
      let coeffA =
             liftA2 (Complex.+:) (Complex.real coeff) (Complex.imag coeff)
      in  liftA2 Parameter gain coeffA


-- Phi handling is delegated to the generic helpers 'Tuple.phiTraversable'
-- and 'Tuple.addPhiFoldable' (presumably componentwise, via the
-- Traversable and Foldable instances of 'Parameter').
instance Tuple.Phi a => Tuple.Phi (Parameter a) where
   addPhi = Tuple.addPhiFoldable
   phi = Tuple.phiTraversable

-- An undefined 'Parameter' is obtained from 'Tuple.undefPointed'
-- (NOTE(review): presumably 'pure' of an undefined component - confirm).
instance (Tuple.Undefined a) => Tuple.Undefined (Parameter a) where
   undef = Tuple.undefPointed


-- | LLVM struct used to store a 'Parameter': gain, real part and imaginary
-- part of the coefficient, in that order.
type ParameterStruct a = LLVM.Struct (a, (a, (a, ())))

-- | Memory record of a 'Parameter'. The gain goes to struct element 0, and
-- the real and imaginary parts of the coefficient go to elements 1 and 2.
parameterMemory ::
   (Memory.C a) =>
   Memory.Record r (ParameterStruct (Memory.Struct a)) (Parameter a)
parameterMemory =
   liftA3 assemble
      (Memory.element gainOf d0)
      (Memory.element (Complex.real . coeffOf) d1)
      (Memory.element (Complex.imag . coeffOf) d2)
   where
      assemble gain re im = Parameter gain (re Complex.+: im)
      gainOf (Parameter gain _) = gain
      coeffOf (Parameter _ coeff) = coeff

-- Every memory operation is derived from the record 'parameterMemory'.
instance (Memory.C a) => Memory.C (Parameter a) where
   type Struct (Parameter a) = ParameterStruct (Memory.Struct a)
   compose = Memory.composeRecord parameterMemory
   decompose = Memory.decomposeRecord parameterMemory
   load = Memory.loadRecord parameterMemory
   store = Memory.storeRecord parameterMemory

-- The register form of a 'Parameter' is a 'Parameter' of register forms.
-- Flattening and unfolding use the Traversable-based helpers from
-- "Synthesizer.LLVM.Simple.Value".
instance (Value.Flatten a) => Value.Flatten (Parameter a) where
   type Registers (Parameter a) = Parameter (Value.Registers a)
   unfoldCode = Value.unfoldCodeTraversable
   flattenCode = Value.flattenCodeTraversable


-- | Compute filter coefficients from a resonance and a frequency.
--
-- The input gain is @1/reson@. The feedback coefficient is
-- @(1 - 1/reson) * exp(i*2*pi*freq)@.
-- NOTE(review): @freq@ is presumably normalized to the sample rate
-- (cycles per sample) - confirm with callers.
--
-- '_parameter' is an unexported variant. It emits the same arithmetic as
-- explicit instructions, in a fixed order.
parameter, _parameter ::
   (A.Transcendental a, A.RationalConstant a) =>
   a -> a -> CodeGenFunction r (Parameter a)
parameter reson freq =
   Value.flatten (Parameter gain coeff)
   where
      gain = recip (Value.unfold reson)
      angle = Value.unfold freq * Value.twoPi
      coeff = Complex.scale (1 - gain) (Complex.cis angle)

_parameter reson freq = do
   gain  <- A.fdiv A.one reson
   k     <- A.sub A.one gain
   angle <- Value.decons Value.twoPi >>= A.mul freq
   kr    <- A.cos angle >>= A.mul k
   ki    <- A.sin angle >>= A.mul k
   return (Parameter gain (kr Complex.+: ki))


{-
Synthesizer.Plain.Filter.Recursive.FirstOrderComplex.step
cannot be used directly, because Filt1C has a complex amplitude,
whereas the amplitude (gain) in 'Parameter' here is real.
-}
-- | One step of the filter.
--
-- Take input gain @amp@, feedback coefficient @k@, stereo input @(l, r)@
-- and state @s@. The new state is @y = amp*(l + i*r) + k*s@, and the
-- stereo output is @(real y, imag y)@.
--
-- '_next' is an unexported variant that writes out the complex
-- multiplication with explicit instructions.
next, _next ::
   (A.PseudoRing a, A.IntegerConstant a) =>
   (Parameter a, Stereo.T a) ->
   Complex.T a ->
   CodeGenFunction r (Stereo.T a, Complex.T a)
next inp state =
   Value.flatten (toStereo state y, y)
   where
      -- The first argument is only a type proxy. It ties the element type
      -- of 'y' to the element type of the state, which otherwise could not
      -- be inferred through the non-injective 'Value.Registers'.
      toStereo ::
         Complex.T a -> Complex.T (Value.T a) -> Stereo.T (Value.T a)
      toStereo _ c = Stereo.cons (Complex.real c) (Complex.imag c)
      (Parameter amp k, x) = Value.unfold inp
      xc = Stereo.left x Complex.+: Stereo.right x
      y = Complex.scale amp xc + k * Value.unfold state

_next (Parameter amp k, x) s = do
   let mul u v = Value.lift0 (A.mul u v)
       (kr, ki) = (Complex.real k, Complex.imag k)
       (sr, si) = (Complex.real s, Complex.imag s)
   -- real part:      amp*l + (kr*sr - ki*si)
   yr <- Value.decons $ mul (Stereo.left x) amp + mul kr sr - mul ki si
   -- imaginary part: amp*r + (kr*si + ki*sr)
   yi <- Value.decons $ mul (Stereo.right x) amp + mul kr si + mul ki sr
   return (Stereo.cons yr yi, yr Complex.+: yi)


-- | Initial filter state: zero in both the real and the imaginary part.
start ::
   (A.Additive a) =>
   CodeGenFunction r (Complex.T a)
start = App.pure (A.zero Complex.+: A.zero)

-- | The complex first-order filter as a causal process.
--
-- Each input is a pair of filter coefficients and a stereo sample. The
-- complex state is threaded through 'next', starting from 'start'.
causal ::
   (Causal.C process, Memory.C a, A.IntegerConstant a, A.PseudoRing a) =>
   process (Parameter a, Stereo.T a) (Stereo.T a)
causal = Causal.mapAccum next start

{-# DEPRECATED causalP "use causal instead" #-}
-- | Parameterized-process variant of 'causal'. It uses the same step
-- function and initial state, and is deprecated in favour of 'causal'.
causalP ::
   (Memory.C a, A.IntegerConstant a, A.PseudoRing a) =>
   CausalP.T p (Parameter a, Stereo.T a) (Stereo.T a)
causalP = CausalP.mapAccumSimple next start