{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Synthesizer.LLVM.Filter.FirstOrder (
   Result(Result,lowpass_,highpass_), Parameter, parameter,
   causal, lowpassCausal, highpassCausal,
   causalInit, lowpassCausalInit, highpassCausalInit,
   causalInitPacked, lowpassCausalInitPacked, highpassCausalInitPacked,
   causalPacked, lowpassCausalPacked, highpassCausalPacked,
   causalRecursivePacked, -- for Allpass

   causalP, lowpassCausalP, highpassCausalP,
   causalInitP, lowpassCausalInitP, highpassCausalInitP,
   causalPackedP, lowpassCausalPackedP, highpassCausalPackedP,
   causalInitPackedP, lowpassCausalInitPackedP, highpassCausalInitPackedP,
   causalRecursivePackedP, -- for Allpass
   ) where

import qualified Synthesizer.Plain.Filter.Recursive.FirstOrder as FirstOrder
import Synthesizer.Plain.Filter.Recursive.FirstOrder
          (Parameter(Parameter), Result(Result))

import qualified Synthesizer.Plain.Modifier as Modifier

import qualified Synthesizer.LLVM.CausalParameterized.Process as CausalP
import qualified Synthesizer.LLVM.Causal.Process as Causal
import qualified Synthesizer.LLVM.Frame.SerialVector as Serial
import qualified Synthesizer.LLVM.Parameter as Param
import qualified Synthesizer.LLVM.Simple.Value as Value

import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.Arithmetic as A
import LLVM.Extra.Class (Undefined, undefTuple, )

import qualified LLVM.Core as LLVM
import LLVM.Util.Loop (Phi, phis, addPhis, )

import Foreign.Storable (Storable, )

import Control.Arrow (arr, (&&&), (<<<), )
import Control.Monad (liftM2, foldM, )

import NumericPrelude.Numeric
import NumericPrelude.Base


-- | Phi nodes for a 'Parameter' are those of the wrapped value; both
-- methods delegate to the generic helpers of "LLVM.Extra.Class".
instance (Phi a) => Phi (Parameter a) where
   addPhis = Class.addPhisFoldable
   phis = Class.phisTraversable

-- | An undefined 'Parameter' is obtained by delegating to
-- 'Class.undefTuplePointed' on the wrapped element type.
instance Undefined a => Undefined (Parameter a) where
   undefTuple = Class.undefTuplePointed

-- | The memory layout of a 'Parameter' is that of the wrapped coefficient:
-- every method only adds or strips the 'Parameter' constructor around the
-- corresponding operation for @a@.
instance (Memory.C a) => Memory.C (Parameter a) where
   type Struct (Parameter a) = Memory.Struct a
   decompose = Memory.decomposeNewtype Parameter
   compose = Memory.composeNewtype unParameter
      where unParameter (Parameter k) = k
   load = Memory.loadNewtype Parameter
   store = Memory.storeNewtype unParameter
      where unParameter (Parameter k) = k

-- | A 'Parameter' of flattenable values is flattened element-wise,
-- using the 'Traversable'-based helpers from "Synthesizer.LLVM.Simple.Value".
instance (Value.Flatten a) => Value.Flatten (Parameter a) where
   type Registers (Parameter a) = Parameter (Value.Registers a)
   unfoldCode = Value.unfoldCodeTraversable
   flattenCode = Value.flattenCodeTraversable

-- | A filter 'Result' (lowpass and highpass output) is flattened
-- component-wise, using the 'Traversable'-based helpers.
instance (Value.Flatten a) => Value.Flatten (Result a) where
   type Registers (Result a) = Result (Value.Registers a)
   unfoldCode = Value.unfoldCodeTraversable
   flattenCode = Value.flattenCodeTraversable

{-
instance LLVM.ValueTuple a => LLVM.ValueTuple (Parameter a) where
   buildTuple f = Class.buildTupleTraversable (LLVM.buildTuple f)

instance LLVM.IsTuple a => LLVM.IsTuple (Parameter a) where
   tupleDesc = Class.tupleDescFoldable
-}

-- | Host-side 'Parameter' values are turned into LLVM value tuples by
-- mapping 'Class.valueTupleOf' over the wrapped value
-- (via 'Class.valueTupleOfFunctor').
instance (Class.MakeValueTuple a) => Class.MakeValueTuple (Parameter a) where
   type ValueTuple (Parameter a) = Parameter (Class.ValueTuple a)
   valueTupleOf = Class.valueTupleOfFunctor


-- | Compute the filter coefficient in generated code by lifting the plain
-- 'FirstOrder.parameter' with 'Value.unlift1'.
--
-- NOTE(review): the argument is presumably the cutoff frequency relative
-- to the sample rate, as for the plain implementation — confirm there.
parameter ::
   (A.Transcendental a, A.RationalConstant a) =>
   a -> LLVM.CodeGenFunction r (Parameter a)
parameter cutoff =
   Value.unlift1 FirstOrder.parameter cutoff


-- | The plain first-order filter modifier 'FirstOrder.modifier',
-- specialised to 'Value.T' so that it can be compiled into a causal
-- process (see 'causal').  Each step yields both outputs as a 'Result'.
modifier ::
   (a ~ A.Scalar v, A.PseudoModule v, A.IntegerConstant a) =>
   Modifier.Simple
      (Value.T v)
      (Parameter (Value.T a))
      (Value.T v) (Result (Value.T v))
modifier  = FirstOrder.modifier

-- | Single-output variants of 'modifier': the plain lowpass resp. highpass
-- modifiers specialised to 'Value.T'.
lowpassModifier, highpassModifier ::
   (a ~ A.Scalar v, A.PseudoModule v, A.IntegerConstant a) =>
   Modifier.Simple
--      (FirstOrder.State (Value.T v))
      (Value.T v)
      (Parameter (Value.T a))
      (Value.T v) (Value.T v)
lowpassModifier  = FirstOrder.lowpassModifier
highpassModifier = FirstOrder.highpassModifier

-- | First-order filter producing lowpass and highpass output at once,
-- compiled from 'modifier'.  The input is a pair of the filter
-- 'Parameter' and the current sample.
causal ::
   (Causal.C process,
    A.IntegerConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v) =>
   process (Parameter a, v) (Result v)
causal = Causal.fromModifier modifier

-- | First-order lowpass resp. highpass filter as a causal process.
-- The input is a pair of the filter 'Parameter' and the current sample.
--
-- Built with the generic 'Causal.fromModifier', exactly like 'causal', so
-- that the result is polymorphic in the 'Causal.C' instance promised by
-- the signature instead of going through the parameterized-process module.
lowpassCausal, highpassCausal ::
   (Causal.C process,
    A.IntegerConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v) =>
   process (Parameter a, v) v
lowpassCausal  = Causal.fromModifier lowpassModifier
highpassCausal = Causal.fromModifier highpassModifier


-- | Initialisable variant of 'modifier': 'FirstOrder.modifierInit'
-- specialised to 'Value.T'.  The initial filter state is supplied later
-- via 'Modifier.initialize' (see 'causalInit').
modifierInit ::
   (a ~ A.Scalar v, A.PseudoModule v, A.IntegerConstant a) =>
   Modifier.Initialized
      (Value.T v) (Value.T v)
      (Parameter (Value.T a))
      (Value.T v) (Result (Value.T v))
modifierInit = FirstOrder.modifierInit

-- | Initialisable single-output variants: the plain lowpass resp. highpass
-- initialised modifiers specialised to 'Value.T'.
lowpassModifierInit, highpassModifierInit ::
   (a ~ A.Scalar v, A.PseudoModule v, A.IntegerConstant a) =>
   Modifier.Initialized
      (Value.T v) (Value.T v)
      (Parameter (Value.T a))
      (Value.T v) (Value.T v)
lowpassModifierInit  = FirstOrder.lowpassModifierInit
highpassModifierInit = FirstOrder.highpassModifierInit

-- | Like 'causal', but the filter state is initialised with the given
-- value (through 'Modifier.initialize' applied to 'modifierInit').
causalInit ::
   (Causal.C process,
    A.IntegerConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v) =>
   v -> process (Parameter a, v) (Result v)
causalInit initial =
   Causal.fromModifier
      (Modifier.initialize modifierInit (Value.unfold initial))

-- | Like 'lowpassCausal' resp. 'highpassCausal', but the filter state is
-- initialised with the given value.
--
-- Uses the generic 'Causal.fromModifier', exactly like 'causalInit', so
-- that the result is polymorphic in the 'Causal.C' instance promised by
-- the signature.
lowpassCausalInit, highpassCausalInit ::
   (Causal.C process,
    A.IntegerConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v) =>
   v -> process (Parameter a, v) v
lowpassCausalInit =
   Causal.fromModifier .
   Modifier.initialize lowpassModifierInit . Value.unfold
highpassCausalInit =
   Causal.fromModifier .
   Modifier.initialize highpassModifierInit . Value.unfold


-- | Filters working on serial vectors of samples, with one filter
-- 'Parameter' per vector (it is replicated across the vector).
--
-- * 'lowpassCausalPacked' scales the input by @1-k@ ('preampPacked') and
--   feeds it into the recursive part.
--
-- * 'highpassCausalPacked' subtracts the lowpass output from the input.
--
-- * 'causalRecursivePacked' runs the recursion @y = x + k*y1@ with the
--   previous output starting at zero.
lowpassCausalPacked, highpassCausalPacked, causalRecursivePacked,
      preampPacked ::
   (Causal.C process,
    Serial.C v, Serial.Element v ~ a,
    Memory.C a, A.IntegerConstant a,
    A.PseudoRing v, A.PseudoRing a) =>
   process (Parameter a, v) v
lowpassCausalPacked =
   causalRecursivePacked <<< (arr fst &&& preampPacked)
highpassCausalPacked =
   Causal.zipWith A.sub <<< (arr snd &&& lowpassCausalPacked)

causalRecursivePacked = causalRecursiveInitPacked A.zero

-- | Like the packed filters above, but the previous output @y1@ that enters
-- the recursion for the first vector is given explicitly.
-- The accumulator of 'Causal.mapAccum' is the last output element of the
-- previous vector (see 'causalRecursivePackedStep').
lowpassCausalInitPacked, highpassCausalInitPacked, causalRecursiveInitPacked ::
   (Causal.C process,
    A.PseudoRing v, Serial.C v, Serial.Element v ~ a,
    A.PseudoRing a, A.IntegerConstant a, Memory.C a) =>
   a -> process (Parameter a, v) v
lowpassCausalInitPacked y1 =
   causalRecursiveInitPacked y1 <<< (arr fst &&& preampPacked)
highpassCausalInitPacked y1 =
   Causal.zipWith A.sub <<< (arr snd &&& lowpassCausalInitPacked y1)

causalRecursiveInitPacked y1 =
   Causal.mapAccum causalRecursivePackedStep (return y1)

-- Scale each input vector by (1-k) before it enters the recursive part of
-- the packed lowpass.  The gain is computed once as a scalar and turned
-- into a vector with 'Serial.upsample'.
preampPacked =
   Causal.map
      (\(Parameter k, x) -> do
         gain <- A.sub (A.fromInteger' 1) k
         gainVec <- Serial.upsample gain
         A.mul x gainVec)



{-
x = [x0, x1, x2, x3]

filter k y1 x
  = [x0 + k*y1,
     x1 + k*x0 + k^2*y1,
     x2 + k*x1 + k^2*x0 + k^3*y1,
     x3 + k*x2 + k^2*x1 + k^3*x0 + k^4*y1,
     ... ]

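f0x = x with k*y1 added to element 0, i.e. [x0 + k*y1, x1, x2, x3]
f1x = f0x + k * f0x->1
f2x = f1x + k^2 * f1x->2

where v->n denotes v shifted up by n elements with zeros shifted in.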
-}
-- | One step of the vectorised recursion @y[i] = x[i] + k*y[i-1]@ over a
-- whole serial vector, with a filter coefficient @k@ that is constant
-- within the vector.
--
-- Arguments: the 'Parameter' @k@ paired with the input vector, and the last
-- output @y1@ of the previous vector.  Result: the output vector together
-- with its last element, which serves as @y1@ for the next vector.
--
-- The feedback @k*y1@ is first added to element 0.  A doubling loop then
-- processes shift distances 1, 2, 4, ... (all below the vector size).  Each
-- round adds the current vector, multiplied by the current power of @k@ and
-- shifted up by the distance, while squaring that power.  After
-- ceiling(log2 size) rounds every element holds the full weighted sum
-- written out in the comment above.
causalRecursivePackedStep ::
   (A.PseudoRing v, Serial.C v, Serial.Element v ~ a, A.PseudoRing a) =>
   (Parameter a, v) -> a -> LLVM.CodeGenFunction r (v,a)
causalRecursivePackedStep =
      \(Parameter k, xk0) y1 -> do
         -- feedback from the previous vector only enters element 0
         y1k <- A.mul k y1
         xk1 <- Serial.modify A.zero (A.add y1k) xk0
         let size = Serial.size xk0
         -- k as a vector; squared after every round (k, k^2, k^4, ...)
         kv <- Serial.upsample k
         -- round with distance d: y := y + shiftUp d (y * k^d)
         xk2 <-
            fmap fst $
            foldM
               (\(y,k0) d ->
                  liftM2 (,)
                     (A.add y =<<
                      Serial.shiftUpMultiZero d =<<
                      A.mul y k0)
                     (A.mul k0 k0))
               (xk1,kv)
               (takeWhile (< size) $ iterate (2*) 1)
{- Alternative: keep k as a scalar and replicate it in every loop iteration.
         xk2 <-
            fmap fst $
            foldM
               (\(y,k0) d ->
                  liftM2 (,)
                     (A.add y =<<
                      Serial.shiftUpMultiZero d =<<
                      A.mul y =<<
                      Serial.upsample k0)
                     (A.mul k0 k0))
               (xk1,k)
               (takeWhile (< size) $ iterate (2*) 1)
-}
         -- the last element is the state carried to the next vector
         y0 <- Serial.extract (LLVM.valueOf $ fromIntegral $ size - 1) xk2
         return (xk2, y0)

{-
We can also optimize filtering with a time-varying filter parameter.

k = [k0, k1, k2, k3]
x = [x0, x1, x2, x3]

filter k y1 x
  = [x0 + k0*y1,
     x1 + k1*x0 + k1*k0*y1,
     x2 + k2*x1 + k2*k1*x0 + k2*k1*k0*y1,
     x3 + k3*x2 + k3*k2*x1 + k3*k2*k1*x0 + k3*k2*k1*k0*y1,
     ... ]

f0x = x with k0*y1 added to element 0
f1x = f0x + k  * f0x->1      k'  = k * k->1
f2x = f1x + k' * f1x->2


We can even interpret vectorised first order filtering
as first order filtering with matrix coefficients.

[x0 + k0*y1,
 x1 + k1*x0 + k1*k0*y1,
 x2 + k2*x1 + k2*k1*x0 + k2*k1*k0*y1,
 x3 + k3*x2 + k3*k2*x1 + k3*k2*k1*x0 + k3*k2*k1*k0*y1]
  =
  / 1                   \   /x0\    / k0          0 0 0 \   /y1\
  | k1       1          | . |x1| +  | k1*k0       0 0 0 | . |y2|
  | k2*k1    k2    1    |   |x2|    | k2*k1*k0    0 0 0 |   |y3|
  \ k3*k2*k1 k3*k2 k3 1 /   \x3/    \ k3*k2*k1*k0 0 0 0 /   \y4/


  / 1                   \   / 1                 \   / 1          \
  | k1       1          | = |         1         | . | k1  1      |
  | k2*k1    k2    1    |   | k2*k1        1    |   |    k2  1   |
  \ k3*k2*k1 k3*k2 k3 1 /   \       k3*k2     1 /   \       k3 1 /
-}



-- | Extend a lowpass process to one that delivers both outputs.
-- The highpass output is the input sample minus the lowpass output.
--
-- Before sharing was added to "Synthesizer.LLVM.Simple.Value", only this
-- implementation allowed sharing; using 'CausalP.fromModifier' did not.
addHighpass ::
   (Causal.C process, A.Additive v) =>
   process (param, v) v -> process (param, v) (Result v)
addHighpass lowpass =
   Causal.map
      (\(lo, input) ->
         fmap
            (\hi -> Result{FirstOrder.lowpass_ = lo, FirstOrder.highpass_ = hi})
            (A.sub input lo))
   <<< (lowpass &&& arr snd)

-- | Packed first-order filter yielding both outputs.  The lowpass output
-- comes from 'lowpassCausalPacked'; 'addHighpass' derives the highpass
-- output from it as input minus lowpass.
causalPacked ::
   (Causal.C process,
    Serial.C v, Serial.Element v ~ a,
    Memory.C a, A.IntegerConstant a,
    A.PseudoRing v, A.PseudoRing a) =>
   process (Parameter a, v) (Result v)
causalPacked = addHighpass lowpassCausalPacked

-- | Like 'causalPacked', but the previous output that enters the recursion
-- for the first vector is given explicitly.
causalInitPacked ::
   (Causal.C process,
    Serial.C v, Serial.Element v ~ a,
    Memory.C a, A.IntegerConstant a,
    A.PseudoRing v, A.PseudoRing a) =>
   a -> process (Parameter a, v) (Result v)
causalInitPacked = addHighpass . lowpassCausalInitPacked



-- | Deprecated: both filter outputs for 'CausalP.T'.
-- Unlike 'causal' (which runs 'modifier'), this derives the highpass output
-- from 'lowpassCausalP' via 'addHighpass'.
causalP ::
   (A.IntegerConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v) =>
   CausalP.T p (Parameter a, v) (Result v)
causalP = addHighpass lowpassCausalP

-- | Deprecated aliases of 'lowpassCausal' and 'highpassCausal', fixed to
-- the parameterized process type 'CausalP.T'.
lowpassCausalP, highpassCausalP ::
   (A.IntegerConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v) =>
   CausalP.T p (Parameter a, v) v
lowpassCausalP  = lowpassCausal
highpassCausalP = highpassCausal

-- | Like 'causalInit', but for 'CausalP.T'.  The initial filter state is
-- taken from the process parameter: a 'Param.T' of a host value @vh@ whose
-- value-tuple type is @v@.  Built with 'CausalP.fromInitializedModifier'.
causalInitP ::
   (A.IntegerConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v,
    Storable vh, Class.MakeValueTuple vh, Class.ValueTuple vh ~ v) =>
   Param.T p vh -> CausalP.T p (Parameter a, v) (Result v)
causalInitP = CausalP.fromInitializedModifier modifierInit

-- | Single-output variants of 'causalInitP': lowpass resp. highpass with
-- the initial filter state taken from the process parameter.
lowpassCausalInitP, highpassCausalInitP ::
   (A.IntegerConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v,
    Storable vh, Class.MakeValueTuple vh, Class.ValueTuple vh ~ v) =>
   Param.T p vh -> CausalP.T p (Parameter a, v) v
lowpassCausalInitP = CausalP.fromInitializedModifier lowpassModifierInit
highpassCausalInitP = CausalP.fromInitializedModifier highpassModifierInit

-- | Deprecated aliases of the packed filters, fixed to 'CausalP.T'.
lowpassCausalPackedP, highpassCausalPackedP, causalRecursivePackedP ::
   (Serial.C v, Serial.Element v ~ a,
    Memory.C a, A.IntegerConstant a,
    A.PseudoRing v, A.PseudoRing a) =>
   CausalP.T p (Parameter a, v) v
highpassCausalPackedP = highpassCausalPacked
lowpassCausalPackedP = lowpassCausalPacked
causalRecursivePackedP = causalRecursivePacked

-- | Packed filters for 'CausalP.T' where the previous output entering the
-- recursion for the first vector is taken from the process parameter.
-- The recursion itself uses no parameter (hence the unit parameter
-- @return ()@); its accumulator is initialised with the parameter value
-- unchanged (@return@).
lowpassCausalInitPackedP, highpassCausalInitPackedP,
      causalRecursiveInitPackedP ::
   (A.PseudoRing v, Serial.C v, Serial.Element v ~ a,
    A.PseudoRing a, A.IntegerConstant a, Memory.C a,
    Storable ah, Class.MakeValueTuple ah, Class.ValueTuple ah ~ a) =>
   Param.T p ah -> CausalP.T p (Parameter a, v) v
lowpassCausalInitPackedP initial =
   causalRecursiveInitPackedP initial <<< (arr fst &&& preampPacked)
highpassCausalInitPackedP initial =
   Causal.zipWith A.sub <<< (arr snd &&& lowpassCausalInitPackedP initial)

causalRecursiveInitPackedP initial =
   CausalP.mapAccum
      (\() -> causalRecursivePackedStep)
      return
      (return ())
      initial

-- | Deprecated alias of 'causalPacked', fixed to 'CausalP.T'.
causalPackedP ::
   (Serial.C v, Serial.Element v ~ a,
    Memory.C a, A.IntegerConstant a,
    A.PseudoRing v, A.PseudoRing a) =>
   CausalP.T p (Parameter a, v) (Result v)
causalPackedP = causalPacked

-- | Packed filter with both outputs for 'CausalP.T'.  The previous output
-- entering the recursion for the first vector comes from the process
-- parameter; the highpass output is added by 'addHighpass'.
causalInitPackedP ::
   (A.PseudoRing v, Serial.C v, Serial.Element v ~ a,
    A.PseudoRing a, A.IntegerConstant a, Memory.C a,
    Storable ah, Class.MakeValueTuple ah, Class.ValueTuple ah ~ a) =>
   Param.T p ah -> CausalP.T p (Parameter a, v) (Result v)
causalInitPackedP = addHighpass . lowpassCausalInitPackedP


-- Each of the *P variants below has a generic counterpart (named in the
-- message) that works for any 'Causal.C' process, including 'CausalP.T'.
-- The initialised *P variants (causalInitP, causalInitPackedP, ...) take a
-- 'Param.T' and are not deprecated.
{-# DEPRECATED causalP                "use 'causal' instead" #-}
{-# DEPRECATED lowpassCausalP         "use 'lowpassCausal' instead" #-}
{-# DEPRECATED highpassCausalP        "use 'highpassCausal' instead" #-}
{-# DEPRECATED causalPackedP          "use 'causalPacked' instead" #-}
{-# DEPRECATED lowpassCausalPackedP   "use 'lowpassCausalPacked' instead" #-}
{-# DEPRECATED highpassCausalPackedP  "use 'highpassCausalPacked' instead" #-}
{-# DEPRECATED causalRecursivePackedP "use 'causalRecursivePacked' instead" #-}