{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Synthesizer.LLVM.Filter.FirstOrder (
   Result(Result,lowpass_,highpass_), Parameter, parameter,
   causal, lowpassCausal, highpassCausal,
   causalInit, lowpassCausalInit, highpassCausalInit,
   causalInitPacked, lowpassCausalInitPacked, highpassCausalInitPacked,
   causalPacked, lowpassCausalPacked, highpassCausalPacked,
   causalRecursivePacked, -- for Allpass

   causalP, lowpassCausalP, highpassCausalP,
   causalInitP, lowpassCausalInitP, highpassCausalInitP,
   causalPackedP, lowpassCausalPackedP, highpassCausalPackedP,
   causalInitPackedP, lowpassCausalInitPackedP, highpassCausalInitPackedP,
   causalRecursivePackedP, -- for Allpass
   ) where

import qualified Synthesizer.Plain.Filter.Recursive.FirstOrder as FirstOrder
import Synthesizer.Plain.Filter.Recursive.FirstOrder
          (Parameter(Parameter), Result(Result,lowpass_,highpass_))

import qualified Synthesizer.Plain.Modifier as Modifier

import qualified Synthesizer.LLVM.CausalParameterized.Process as CausalP
import qualified Synthesizer.LLVM.Causal.Process as Causal
import qualified Synthesizer.LLVM.Frame.SerialVector as Serial
import qualified Synthesizer.LLVM.Simple.Value as Value

import qualified LLVM.DSL.Parameter as Param

import qualified LLVM.Extra.Storable as Storable
import qualified LLVM.Extra.Marshal as Marshal
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Arithmetic as A

import qualified LLVM.Core as LLVM

import Control.Arrow (arr, (&&&), (<<<))
import Control.Monad (liftM2, foldM)

import NumericPrelude.Numeric
import NumericPrelude.Base


-- | Phi nodes for a 'Parameter' are handled by delegating to the generic
-- 'Traversable'/'Foldable' helpers on the wrapped value.
instance (Tuple.Phi a) => Tuple.Phi (Parameter a) where
   phi bb = Tuple.phiTraversable bb
   addPhi bb = Tuple.addPhiFoldable bb

-- | An undefined 'Parameter' is built with the generic 'Tuple.undefPointed'
-- helper (from the wrapped type's 'Tuple.undef').
instance Tuple.Undefined a => Tuple.Undefined (Parameter a) where
   undef = Tuple.undefPointed

-- | In memory, a 'Parameter' has exactly the layout of the coefficient it
-- wraps; every method just wraps or unwraps the 'Parameter' constructor.
instance (Memory.C a) => Memory.C (Parameter a) where
   -- same struct type as the wrapped value
   type Struct (Parameter a) = Memory.Struct a
   load = Memory.loadNewtype Parameter
   store = Memory.storeNewtype (\(Parameter k) -> k)
   decompose = Memory.decomposeNewtype Parameter
   compose = Memory.composeNewtype (\(Parameter k) -> k)

-- | Marshalling a 'Parameter' marshals the single coefficient it wraps.
instance (Marshal.C a) => Marshal.C (Parameter a) where
   pack param = case param of Parameter k -> Marshal.pack k
   unpack struct = Parameter (Marshal.unpack struct)

-- | Storable access to a 'Parameter' goes through the newtype helpers.
-- NOTE(review): the two function arguments of 'Storable.loadNewtype' and
-- 'Storable.storeNewtype' presumably wrap the Haskell type and
-- wrap/unwrap the LLVM value type, respectively; confirm in llvm-extra.
instance (Storable.C a) => Storable.C (Parameter a) where
   load = Storable.loadNewtype Parameter Parameter
   store = Storable.storeNewtype Parameter (\(Parameter k) -> k)

-- | Flattening a 'Parameter' of 'Value.T's into registers (and back),
-- delegating to the generic 'Traversable' helpers.
instance (Value.Flatten a) => Value.Flatten (Parameter a) where
   type Registers (Parameter a) = Parameter (Value.Registers a)
   flattenCode = Value.flattenCodeTraversable
   unfoldCode = Value.unfoldCodeTraversable

-- | Flattening a 'Result' (lowpass and highpass output) into registers
-- (and back), delegating to the generic 'Traversable' helpers.
instance (Value.Flatten a) => Value.Flatten (Result a) where
   type Registers (Result a) = Result (Value.Registers a)
   flattenCode = Value.flattenCodeTraversable
   unfoldCode = Value.unfoldCodeTraversable

{-
instance LLVM.ValueTuple a => LLVM.ValueTuple (Parameter a) where
   buildTuple f = Class.buildTupleTraversable (LLVM.buildTuple f)

instance LLVM.IsTuple a => LLVM.IsTuple (Parameter a) where
   tupleDesc = Class.tupleDescFoldable
-}

-- | Converting a Haskell 'Parameter' to its LLVM value form, via the
-- generic 'Functor'-based helper applied to the wrapped coefficient.
instance (Tuple.Value a) => Tuple.Value (Parameter a) where
   type ValueOf (Parameter a) = Parameter (Tuple.ValueOf a)
   valueOf param = Tuple.valueOfFunctor param


-- | Compute the filter 'Parameter' inside LLVM code by lifting the plain
-- 'FirstOrder.parameter' formula to LLVM values.
-- (The argument is presumably the normalised cutoff frequency, as in
-- 'FirstOrder.parameter'; confirm there.)
parameter ::
   (A.Transcendental a, A.RationalConstant a) =>
   a -> LLVM.CodeGenFunction r (Parameter a)
parameter x = Value.unlift1 FirstOrder.parameter x


-- | 'FirstOrder.modifier' specialised to 'Value.T' (LLVM values), yielding
-- both the lowpass and the highpass output as a 'Result'.
-- The first type argument of 'Modifier.Simple' is presumably the filter
-- state (a single 'Value.T' v here); confirm in Synthesizer.Plain.Modifier.
modifier ::
   (a ~ A.Scalar v, A.PseudoModule v, A.IntegerConstant a) =>
   Modifier.Simple
      (Value.T v)
      (Parameter (Value.T a))
      (Value.T v) (Result (Value.T v))
modifier = FirstOrder.modifier

-- | 'FirstOrder.lowpassModifier' and 'FirstOrder.highpassModifier'
-- specialised to 'Value.T' (LLVM values); each yields a single output.
lowpassModifier, highpassModifier ::
   (a ~ A.Scalar v, A.PseudoModule v, A.IntegerConstant a) =>
   Modifier.Simple
--      (FirstOrder.State (Value.T v))
      (Value.T v)
      (Parameter (Value.T a))
      (Value.T v) (Value.T v)
lowpassModifier  = FirstOrder.lowpassModifier
highpassModifier = FirstOrder.highpassModifier

-- | Causal first-order filter: consumes a stream of (parameter, input)
-- pairs and yields both the lowpass and the highpass output as a 'Result'.
-- Polymorphic in the process type; built from 'modifier'.
causal ::
   (Causal.C process,
    A.IntegerConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v) =>
   process (Parameter a, v) (Result v)
causal = Causal.fromModifier modifier

-- | Causal first-order lowpass and highpass filters, each producing only
-- one of the two outputs of 'causal'.
--
-- Like 'causal', these are built with the class-level
-- 'Causal.fromModifier', so they really are polymorphic in @process@
-- as the signature states (the deprecated 'lowpassCausalP' and
-- 'highpassCausalP' depend on that).
lowpassCausal, highpassCausal ::
   (Causal.C process,
    A.IntegerConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v) =>
   process (Parameter a, v) v
lowpassCausal  = Causal.fromModifier lowpassModifier
highpassCausal = Causal.fromModifier highpassModifier


-- | 'FirstOrder.modifierInit' specialised to 'Value.T' (LLVM values):
-- an initialisable modifier yielding both outputs as a 'Result'.
-- The initial value has type 'Value.T' v (first type argument).
modifierInit ::
   (a ~ A.Scalar v, A.PseudoModule v, A.IntegerConstant a) =>
   Modifier.Initialized
      (Value.T v) (Value.T v)
      (Parameter (Value.T a))
      (Value.T v) (Result (Value.T v))
modifierInit = FirstOrder.modifierInit

-- | 'FirstOrder.lowpassModifierInit' and 'FirstOrder.highpassModifierInit'
-- specialised to 'Value.T' (LLVM values); each yields a single output.
lowpassModifierInit, highpassModifierInit ::
   (a ~ A.Scalar v, A.PseudoModule v, A.IntegerConstant a) =>
   Modifier.Initialized
      (Value.T v) (Value.T v)
      (Parameter (Value.T a))
      (Value.T v) (Value.T v)
lowpassModifierInit  = FirstOrder.lowpassModifierInit
highpassModifierInit = FirstOrder.highpassModifierInit

-- | Like 'causal', but the filter's initial value is given explicitly
-- and passed to 'Modifier.initialize' together with 'modifierInit'.
causalInit ::
   (Causal.C process,
    A.IntegerConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v) =>
   v -> process (Parameter a, v) (Result v)
causalInit initial =
   Causal.fromModifier $
   Modifier.initialize modifierInit $
   Value.unfold initial

-- | Like 'lowpassCausal' and 'highpassCausal', but the filter's initial
-- value is given explicitly.
--
-- Built with the class-level 'Causal.fromModifier' (as 'causalInit' is),
-- matching the signature's polymorphism in @process@.
lowpassCausalInit, highpassCausalInit ::
   (Causal.C process,
    A.IntegerConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v) =>
   v -> process (Parameter a, v) v
lowpassCausalInit =
   Causal.fromModifier .
   Modifier.initialize lowpassModifierInit . Value.unfold
highpassCausalInit =
   Causal.fromModifier .
   Modifier.initialize highpassModifierInit . Value.unfold


-- | First-order filters on serial (packed) vectors.
--
-- * 'lowpassCausalPacked': @y[n] = (1-k)*x[n] + k*y[n-1]@, starting from
--   a zero previous output
--
-- * 'highpassCausalPacked': the input minus the lowpass output
--
-- * 'causalRecursivePacked': the bare recursion @y[n] = x[n] + k*y[n-1]@,
--   starting from a zero previous output
--
-- * 'preampPacked': scales the input vector by @1-k@
lowpassCausalPacked, highpassCausalPacked, causalRecursivePacked,
      preampPacked ::
   (Causal.C process,
    Serial.C v, Serial.Element v ~ a,
    Memory.C a, A.IntegerConstant a,
    A.PseudoRing v, A.PseudoRing a) =>
   process (Parameter a, v) v
-- The non-initialised filters are the initialised ones started at zero.
lowpassCausalPacked  = lowpassCausalInitPacked A.zero
highpassCausalPacked = highpassCausalInitPacked A.zero

-- Vectorised recursion whose state (the previous output) starts at zero.
causalRecursivePacked =
   Causal.mapAccum causalRecursivePackedStep (return A.zero)

-- | Packed filters whose recursion starts from a given previous output
-- value instead of zero.
--
-- * 'causalRecursiveInitPacked': the bare recursion @y[n] = x[n] + k*y[n-1]@
--
-- * 'lowpassCausalInitPacked': the recursion applied to the input scaled
--   by @1-k@
--
-- * 'highpassCausalInitPacked': the input minus the lowpass output
lowpassCausalInitPacked, highpassCausalInitPacked, causalRecursiveInitPacked ::
   (Causal.C process,
    A.PseudoRing v, Serial.C v, Serial.Element v ~ a,
    A.PseudoRing a, A.IntegerConstant a, Memory.C a) =>
   a -> process (Parameter a, v) v
causalRecursiveInitPacked initial =
   Causal.mapAccum causalRecursivePackedStep (return initial)

lowpassCausalInitPacked initial =
   causalRecursiveInitPacked initial <<< (arr fst &&& preampPacked)
highpassCausalInitPacked initial =
   Causal.zipWith A.sub <<< (arr snd &&& lowpassCausalInitPacked initial)

-- Multiply every element of the input vector by (1 - k), the input gain of
-- the lowpass recursion.
preampPacked =
   Causal.map
      (\(Parameter k, x) -> do
         gain <- A.sub (A.fromInteger' 1) k
         gainVector <- Serial.upsample gain
         A.mul x gainVector)



{-
x = [x0, x1, x2, x3]

filter k y1 x
  = [x0 + k*y1,
     x1 + k*x0 + k^2*y1,
     x2 + k*x1 + k^2*x0 + k^3*y1,
     x3 + k*x2 + k^2*x1 + k^3*x0 + k^4*y1,
     ... ]

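f0x = modify 0 (+ k*y1) x      (add k*y1 to element 0)
f1x = f0x + k * f0x->1         (v->d: v shifted up by d elements, zero-filled)
f2x = f1x + k^2 * f1x->2       (and so on, doubling the shift while it is
                                smaller than the vector size)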
-}
-- | One vector step of the recursion @y[n] = x[n] + k*y[n-1]@.
--
-- Takes the filter parameter @k@ together with a vector of inputs, and the
-- last output of the previous vector (@y1@) as state. Returns the vector of
-- outputs together with its last element, which is the next state.
--
-- The recursion inside the vector is evaluated as a prefix scan:
-- @k*y1@ is first added to element 0, then for d = 1, 2, 4, ... (while
-- d < size) the vector is added to itself shifted up by d elements and
-- scaled by @k^d@.
causalRecursivePackedStep ::
   (A.PseudoRing v, Serial.C v, Serial.Element v ~ a, A.PseudoRing a) =>
   (Parameter a, v) -> a -> LLVM.CodeGenFunction r (v,a)
causalRecursivePackedStep =
      \(Parameter k, xk0) y1 -> do
         -- element 0 becomes x0 + k*y1
         y1k <- A.mul k y1
         xk1 <- Serial.modify A.zero (A.add y1k) xk0
         let size = Serial.size xk0
         -- k replicated into every element
         kv <- Serial.upsample k
         -- Doubling rounds: in the round with shift d, k0 holds k^d in every
         -- element; y is added to (y*k0) shifted up by d elements, and k0 is
         -- squared for the next round.
         xk2 <-
            fmap fst $
            foldM
               (\(y,k0) d ->
                  liftM2 (,)
                     (A.add y =<<
                      Serial.shiftUpMultiZero d =<<
                      A.mul y k0)
                     (A.mul k0 k0))
               (xk1,kv)
               (takeWhile (< size) $ iterate (2*) 1)
{- Unused alternative: do the replicate (Serial.upsample) in the loop
   instead of once before it.
         xk2 <-
            fmap fst $
            foldM
               (\(y,k0) d ->
                  liftM2 (,)
                     (A.add y =<<
                      Serial.shiftUpMultiZero d =<<
                      A.mul y =<<
                      Serial.upsample k0)
                     (A.mul k0 k0))
               (xk1,k)
               (takeWhile (< size) $ iterate (2*) 1)
-}
         -- the last element is the newest output and becomes the next state
         y0 <- Serial.extract (LLVM.valueOf $ fromIntegral $ size - 1) xk2
         return (xk2, y0)

{-
We can also optimize filtering with a time-varying filter parameter,
i.e. one coefficient per vector element:

k = [k0, k1, k2, k3]
x = [x0, x1, x2, x3]

filter k y1 x
  = [x0 + k0*y1,
     x1 + k1*x0 + k1*k0*y1,
     x2 + k2*x1 + k2*k1*x0 + k2*k1*k0*y1,
     x3 + k3*x2 + k3*k2*x1 + k3*k2*k1*x0 + k3*k2*k1*k0*y1,
     ... ]

f0x = modify 0 (+ k0*y1) x   (add k0*y1 to element 0)
f1x = f0x + k  * f0x->1      k'  = k * k->1
f2x = f1x + k' * f1x->2


We can even interpret vectorised first order filtering
as first order filtering with matrix coefficients.

[x0 + k0*y1,
 x1 + k1*x0 + k1*k0*y1,
 x2 + k2*x1 + k2*k1*x0 + k2*k1*k0*y1,
 x3 + k3*x2 + k3*k2*x1 + k3*k2*k1*x0 + k3*k2*k1*k0*y1]
  =
  / 1                   \   /x0\    / k0          0 0 0 \   /y1\
  | k1       1          | . |x1| +  | k1*k0       0 0 0 | . |y2|
  | k2*k1    k2    1    |   |x2|    | k2*k1*k0    0 0 0 |   |y3|
  \ k3*k2*k1 k3*k2 k3 1 /   \x3/    \ k3*k2*k1*k0 0 0 0 /   \y4/


  / 1                   \   / 1                 \   / 1          \
  | k1       1          | = |         1         | . | k1  1      |
  | k2*k1    k2    1    |   | k2*k1        1    |   |    k2  1   |
  \ k3*k2*k1 k3*k2 k3 1 /   \       k3*k2     1 /   \       k3 1 /
-}



-- | Turn a lowpass process into one yielding both outputs. The highpass
-- output is the input (second component of the process input) minus the
-- lowpass output.
addHighpass ::
   (Causal.C process, A.Additive v) =>
   process (param, v) v -> process (param, v) (Result v)
addHighpass lowpass =
   -- Historical note: before sharing was added to Simple.Value, only this
   -- formulation allowed sharing; using CausalP.fromModifier did not.
   Causal.map
      (\(l, x) ->
         fmap (\h -> Result {lowpass_ = l, highpass_ = h}) (A.sub x l))
   <<< (lowpass &&& arr snd)

-- | Packed first-order filter yielding both outputs: the lowpass from
-- 'lowpassCausalPacked' (recursion started at zero) and the highpass as
-- input minus lowpass, combined by 'addHighpass'.
causalPacked ::
   (Causal.C process,
    Serial.C v, Serial.Element v ~ a,
    Memory.C a, A.IntegerConstant a,
    A.PseudoRing v, A.PseudoRing a) =>
   process (Parameter a, v) (Result v)
causalPacked = addHighpass lowpassCausalPacked

-- | Like 'causalPacked', but the recursion starts from the given previous
-- output value.
causalInitPacked ::
   (Causal.C process,
    Serial.C v, Serial.Element v ~ a,
    Memory.C a, A.IntegerConstant a,
    A.PseudoRing v, A.PseudoRing a) =>
   a -> process (Parameter a, v) (Result v)
causalInitPacked = addHighpass . lowpassCausalInitPacked



-- | Parameterized-process version of the two-output filter, built by
-- adding a highpass output to 'lowpassCausalP'. Deprecated in favour of
-- 'causal'.
causalP ::
   (A.IntegerConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v) =>
   CausalP.T p (Parameter a, v) (Result v)
causalP = addHighpass lowpassCausalP

-- | 'lowpassCausal' and 'highpassCausal' specialised to 'CausalP.T'.
-- Deprecated; use the polymorphic versions directly.
lowpassCausalP, highpassCausalP ::
   (A.IntegerConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v) =>
   CausalP.T p (Parameter a, v) v
lowpassCausalP  = lowpassCausal
highpassCausalP = highpassCausal

-- | Two-output filter as a parameterized process whose initial value is
-- taken from a 'Param.T' instead of a plain value.
causalInitP ::
   (A.IntegerConstant a, a ~ A.Scalar v, A.PseudoModule v,
    Marshal.C vh, Tuple.ValueOf vh ~ v) =>
   Param.T p vh -> CausalP.T p (Parameter a, v) (Result v)
causalInitP initial = CausalP.fromInitializedModifier modifierInit initial

-- | Single-output lowpass and highpass filters as parameterized processes
-- whose initial value is taken from a 'Param.T'.
lowpassCausalInitP, highpassCausalInitP ::
   (A.IntegerConstant a, a ~ A.Scalar v, A.PseudoModule v,
    Marshal.C vh, Tuple.ValueOf vh ~ v) =>
   Param.T p vh -> CausalP.T p (Parameter a, v) v
lowpassCausalInitP initial =
   CausalP.fromInitializedModifier lowpassModifierInit initial
highpassCausalInitP initial =
   CausalP.fromInitializedModifier highpassModifierInit initial

-- | The packed filters specialised to 'CausalP.T'. Deprecated; use
-- 'lowpassCausalPacked', 'highpassCausalPacked' and
-- 'causalRecursivePacked' directly.
lowpassCausalPackedP, highpassCausalPackedP, causalRecursivePackedP ::
   (Serial.C v, Serial.Element v ~ a,
    Memory.C a, A.IntegerConstant a,
    A.PseudoRing v, A.PseudoRing a) =>
   CausalP.T p (Parameter a, v) v
highpassCausalPackedP = highpassCausalPacked
lowpassCausalPackedP = lowpassCausalPacked
causalRecursivePackedP = causalRecursivePacked

-- | Parameterized counterparts of the @*InitPacked@ filters: the initial
-- previous-output value is supplied as a 'Param.T' rather than as a plain
-- value.
lowpassCausalInitPackedP, highpassCausalInitPackedP,
      causalRecursiveInitPackedP ::
   (A.PseudoRing v, Serial.C v, Serial.Element v ~ a,
    A.PseudoRing a, A.IntegerConstant a,
    Marshal.C ah, Tuple.ValueOf ah ~ a) =>
   Param.T p ah -> CausalP.T p (Parameter a, v) v
-- The step ignores its unit parameter (supplied by 'return ()'); the
-- initial state presumably comes from parameter 'a' lifted by 'return'.
-- NOTE(review): confirm the argument order of CausalP.mapAccum.
causalRecursiveInitPackedP a =
   CausalP.mapAccum (\() -> causalRecursivePackedStep) return (return ()) a

highpassCausalInitPackedP a =
   Causal.zipWith A.sub <<< arr snd &&& lowpassCausalInitPackedP a
lowpassCausalInitPackedP a =
   causalRecursiveInitPackedP a <<< (arr fst &&& preampPacked)

-- | 'causalPacked' specialised to 'CausalP.T'. Deprecated; use
-- 'causalPacked' directly.
causalPackedP ::
   (Serial.C v, Serial.Element v ~ a,
    Memory.C a, A.IntegerConstant a,
    A.PseudoRing v, A.PseudoRing a) =>
   CausalP.T p (Parameter a, v) (Result v)
causalPackedP = causalPacked

-- | Two-output packed filter as a parameterized process; the recursion's
-- initial previous-output value is taken from a 'Param.T'.
causalInitPackedP ::
   (A.PseudoRing v, Serial.C v, Serial.Element v ~ a,
    A.PseudoRing a, A.IntegerConstant a,
    Marshal.C ah, Tuple.ValueOf ah ~ a) =>
   Param.T p ah -> CausalP.T p (Parameter a, v) (Result v)
causalInitPackedP = addHighpass . lowpassCausalInitPackedP


-- Only P variants with a direct polymorphic replacement are deprecated.
-- The Init*P variants take their initial value as a 'Param.T' and have no
-- drop-in replacement, so they stay undeprecated.
{-# DEPRECATED causalP                "use 'causal' instead" #-}
{-# DEPRECATED lowpassCausalP         "use 'lowpassCausal' instead" #-}
{-# DEPRECATED highpassCausalP        "use 'highpassCausal' instead" #-}
{-# DEPRECATED causalPackedP          "use 'causalPacked' instead" #-}
{-# DEPRECATED lowpassCausalPackedP   "use 'lowpassCausalPacked' instead" #-}
{-# DEPRECATED highpassCausalPackedP  "use 'highpassCausalPacked' instead" #-}
{-# DEPRECATED causalRecursivePackedP "use 'causalRecursivePacked' instead" #-}