{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Synthesizer.LLVM.Filter.SecondOrder (
   Parameter, bandpassParameter,
   ParameterStruct, -- for cascade
   causalP, causalPackedP,
   ) where

import qualified Synthesizer.Plain.Filter.Recursive.SecondOrder as Filt2
import Synthesizer.Plain.Filter.Recursive.SecondOrder (Parameter(Parameter), )

import qualified Synthesizer.Plain.Modifier as Modifier

import qualified Synthesizer.LLVM.CausalParameterized.Process as CausalP
import qualified Synthesizer.LLVM.Simple.Value as Value

import qualified LLVM.Extra.Representation as Rep
import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Vector as Vector

import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Monad as M

import qualified LLVM.Core as LLVM
import LLVM.Core
   (Value, valueOf, Struct, Undefined, undefTuple,
    IsFirstClass, IsConst, IsArithmetic, IsFloating,
    Vector, IsPowerOf2, IsPrimitive, IsSized,
    CodeGenFunction, )
import LLVM.Util.Loop (Phi, phis, addPhis, )

import Data.TypeLevel.Num (d0, d1, d2, d3, d4, )
import qualified Data.TypeLevel.Num as TypeNum

import Control.Arrow (arr, (<<<), (&&&), )
import Control.Monad (liftM2, foldM, )
import Synthesizer.ApplicativeUtility (liftA4, liftA5, )

import qualified Algebra.Transcendental as Trans
-- import qualified Algebra.Field as Field
import qualified Algebra.Module as Module
import qualified Algebra.Ring as Ring

import NumericPrelude.Numeric
import NumericPrelude.Base


-- Instances that let LLVM code generation handle a 'Parameter' whose fields
-- are LLVM values like a tuple of those fields.
-- Every method just delegates to the generic helpers of "LLVM.Extra.Class"
-- (or to the field instance for 'LLVM.buildTuple').

instance LLVM.IsTuple a => LLVM.IsTuple (Parameter a) where
   tupleDesc = Class.tupleDescFoldable

instance LLVM.ValueTuple a => LLVM.ValueTuple (Parameter a) where
   buildTuple func = Class.buildTupleTraversable $ LLVM.buildTuple func

instance LLVM.MakeValueTuple h l =>
      LLVM.MakeValueTuple (Parameter h) (Parameter l) where
   valueTupleOf = Class.valueTupleOfFunctor

instance (Phi a) => Phi (Parameter a) where
   addPhis = Class.addPhisFoldable
   phis = Class.phisTraversable

instance Undefined a => Undefined (Parameter a) where
   undefTuple = Class.undefTuplePointed


-- | LLVM struct with five slots of type @a@,
-- one for each coefficient of a 'Parameter' (c0, c1, c2, d1, d2 in this order).
-- It is exported so that cascaded filters can refer to this layout.
type ParameterStruct a = Struct (a, (a, (a, (a, (a, ())))))

-- | Maps every coefficient of a 'Parameter' to its slot in 'ParameterStruct'.
parameterMemory ::
   (Rep.Memory a s, IsSized s ss) =>
   Rep.MemoryRecord r (ParameterStruct s) (Parameter a)
parameterMemory =
   liftA5 Filt2.Parameter
      (Rep.memoryElement Filt2.c0 d0)
      (Rep.memoryElement Filt2.c1 d1)
      (Rep.memoryElement Filt2.c2 d2)
      (Rep.memoryElement Filt2.d1 d3)
      (Rep.memoryElement Filt2.d2 d4)

-- All memory operations are derived from the single record description above.
instance
      (Rep.Memory a s, IsSized s ss) =>
      Rep.Memory (Parameter a) (ParameterStruct s) where
   compose = Rep.composeRecord parameterMemory
   decompose = Rep.decomposeRecord parameterMemory
   load = Rep.loadRecord parameterMemory
   store = Rep.storeRecord parameterMemory


-- Converts each field of a 'Parameter' between the @ah@ and @al@
-- representations, as defined by the field instance.
instance (Value.Flatten ah al) =>
      Value.Flatten (Parameter ah) (Parameter al) where
   unfold = Value.unfoldFunctor
   flatten = Value.flattenTraversable



-- Phi nodes and undefined values for the filter state,
-- delegated to the generic helpers of "LLVM.Extra.Class".
instance (Phi a) => Phi (Filt2.State a) where
   addPhis = Class.addPhisFoldable
   phis = Class.phisTraversable

instance Undefined a => Undefined (Filt2.State a) where
   undefTuple = Class.undefTuplePointed

-- | LLVM struct with exactly one slot per field of 'Filt2.State'
-- (u1, u2, y1, y2 in this order).
-- Previously a five-slot struct was used here, whose last slot was never
-- read or written.
type StateStruct a = Struct (a, (a, (a, (a, ()))))

-- | Maps every field of the filter state to its slot in 'StateStruct'.
stateMemory ::
   (Rep.Memory a s, IsSized s ss) =>
   Rep.MemoryRecord r (StateStruct s) (Filt2.State a)
stateMemory =
   liftA4 Filt2.State
      (Rep.memoryElement Filt2.u1 d0)
      (Rep.memoryElement Filt2.u2 d1)
      (Rep.memoryElement Filt2.y1 d2)
      (Rep.memoryElement Filt2.y2 d3)


-- All memory operations are derived from 'stateMemory',
-- so the struct type here must agree with the one in its signature.
instance
      (Rep.Memory a s, IsSized s ss) =>
      Rep.Memory (Filt2.State a) (StateStruct s) where
   load = Rep.loadRecord stateMemory
   store = Rep.storeRecord stateMemory
   decompose = Rep.decomposeRecord stateMemory
   compose = Rep.composeRecord stateMemory

-- Converts each field of the filter state between the @ah@ and @al@
-- representations, as defined by the field instance.
instance (Value.Flatten ah al) =>
      Value.Flatten (Filt2.State ah) (Filt2.State al) where
   unfold = Value.unfoldFunctor
   flatten = Value.flattenTraversable


{-# DEPRECATED bandpassParameter "only for testing, use Universal or Moog filter for production code" #-}
-- | Coefficients of a resonant second-order bandpass filter.
--
-- With @k = 1 - 1/reson@ and @w = 2*pi*cutoff@ the result is
-- @c0 = 1/reson@, @c1 = c2 = 0@, @d1 = 2*k*cos w@, @d2 = -k^2@,
-- so the feedback polynomial @z^2 - d1*z - d2@ has its roots at @k*e^(+-i*w)@.
bandpassParameter ::
   (Trans.C a, IsFloating a, IsConst a) =>
   Value a ->
   Value a ->
   CodeGenFunction r (Parameter (Value a))
bandpassParameter reson cutoff = do
   c0 <- A.fdiv (valueOf 1) reson
   k <- A.sub (valueOf 1) c0
   kSqr <- A.mul k k
   negKSqr <- LLVM.neg kSqr
   fullCircle <- Value.decons Value.twoPi
   omega <- A.mul cutoff fullCircle
   cosOmega <- A.cos omega
   kCos <- A.mul k cosOmega
   twoKCos <- A.mul (valueOf 2) kCos
   return $
      Filt2.Parameter c0 (valueOf zero) (valueOf zero) twoKCos negKSqr

-- | The generic second-order filter step of
-- "Synthesizer.Plain.Filter.Recursive.SecondOrder",
-- fixed to 'Value.T' coefficients and samples.
modifier ::
   (IsArithmetic a, IsConst a, Module.C (Value.T a) (Value.T v)) =>
   Modifier.Simple
      (Filt2.State (Value.T v))
      (Parameter (Value.T a))
      (Value.T v) (Value.T v)
modifier = Filt2.modifier

-- | Second-order filter as a causal process.
-- Each step consumes a filter parameter together with one input value
-- and produces one output value; the step itself is 'modifier'.
causalP ::
   (Ring.C a,
    IsArithmetic a, IsConst a, IsFirstClass a, IsSized a as,
    IsConst v, IsFirstClass v, IsSized v vs,
    Module.C (Value.T a) (Value.T v)) =>
   CausalP.T p
      (Parameter (Value a), Value v) (Value v)
causalP = CausalP.fromModifier modifier


{- |
Second-order filter that processes a whole vector of samples per step.
The feed-forward part is computed by 'causalNonRecursivePackedP'
and its result is then run through the feedback part 'causalRecursivePackedP'.

Vector size must be at least D2,
because both stages access the last two elements of a vector.
-}
causalPackedP,
  causalRecursivePackedP ::
   (Ring.C a,
    IsFirstClass a, IsArithmetic a, IsConst a,
    IsPowerOf2 n, IsPrimitive a, IsSized a as,
    -- presumably these stand in for IsSized (Vector n a) vas,
    -- which is what storing the vector-valued feedback state needs
    TypeNum.Mul n as vas, TypeNum.Pos vas) =>
   CausalP.T p
      (Parameter (Value a), Value (Vector n a)) (Value (Vector n a))
causalPackedP =
   let feedForward = arr fst &&& causalNonRecursivePackedP
   in  causalRecursivePackedP <<< feedForward

-- Feed-forward part of the packed filter: every output element is
-- c0*u[i] + c1*u[i-1] + c2*u[i-2].
-- The accumulator holds the last and the second-to-last input value
-- of the previous vector, so the delays reach across vector boundaries.
_causalRecursivePackedPAlt,
  causalNonRecursivePackedP ::
   (Ring.C a,
    IsFirstClass a, IsArithmetic a, IsConst a,
    IsPowerOf2 n, IsPrimitive a, IsSized a as) =>
   CausalP.T p
      (Parameter (Value a), Value (Vector n a)) (Value (Vector n a))
causalNonRecursivePackedP =
   CausalP.mapAccumSimple
      (\(param, cur) (prevLast, prevBeforeLast) -> do
         -- input delayed by one and by two samples
         -- (assumes Vector.shiftUp inserts its first argument at index 0)
         (_, delayed1) <- Vector.shiftUp prevLast cur
         (_, delayed2) <- Vector.shiftUp prevBeforeLast delayed1
         c0v <- SoV.replicate (Filt2.c0 param)
         w0  <- A.mul cur c0v
         c1v <- SoV.replicate (Filt2.c1 param)
         w1  <- A.mul delayed1 c1v
         c2v <- SoV.replicate (Filt2.c2 param)
         w2  <- A.mul delayed2 c2v
         w12 <- A.add w1 w2
         out <- A.add w0 w12
         let len = fromIntegral $ Vector.sizeInTuple cur
         lastIn       <- Vector.extract (valueOf $ len - 1) cur
         beforeLastIn <- Vector.extract (valueOf $ len - 2) cur
         return (out, (lastIn, beforeLastIn)))
      (return (LLVM.value LLVM.zero, LLVM.value LLVM.zero))

{-
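A filter of second order can be considered
as the convolution of two filters of first order.
Below, [1,r,r^2,r^3] is the impulse response of the first-order filter
with feedback coefficient r, truncated to four samples,
and * denotes convolution.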

[1,r]*[1,0,r^2] = [1,r,r^2,r^3]
[1,r,r^2,r^3] * [1,s,s^2,s^3]
 = [1,r]*[1,s]*[1,0,r^2]*[1,0,s^2]
     with
       a=r+s
       b=r*s
 = [1,a,b]*[1,0,r^2]*[1,0,s^2]
 = [1,a,b]*[1,0,a^2-2*b,0,b^2]

[1,0,0,0,r^4]*[1,0,0,0,s^4]
 = [1,0,0,0,(a^2-2*b)^2-2*b^2,0,0,0,b^4]
 = [1,0,0,0,a^4-4*a^2*b+2*b^2,0,0,0,b^4]
-}

{-
x = [x0, x1, x2, x3]

filter2 (a,-b) (y1,y2) x
  = [x0 + a*y1 - b*y2,
     x1 + a*x0 + (a^2-b)*y1 - a*b*y2,
     x2 + a*x1 + (a^2-b)*x0 + (a^3-2*a*b)*y1 + (-a^2*b+b^2)*y2,
     x3 + a*x2 + (a^2-b)*x1 + (a^3-2*a*b)*x0 + (a^4-3*a^2*b+b^2)*y1 + (-a^3*b+2*a*b^2)*y2]

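f0x = x + [a*y1 - b*y2, -b*y1, 0, 0]
      (initial conditions y1, y2 folded into the first two input samples)
f1x = f0x + a * f0x->1 + b * f0x->2
f2x = f1x + (a^2-2*b) * f1x->2 + b^2 * f1x->4
      (v->k denotes v shifted up by k positions with zeros shifted in;
       f2x equals filter2 (a,-b) (y1,y2) x above)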
-}
-- Feedback part of 'causalPackedP' (its type signature is shared with that
-- function). The accumulator is the previous output vector; the new output
-- vector is both emitted and kept as the next accumulator.
-- In the notation of the comment above: a = d1 and b = -d2.
causalRecursivePackedP =
   CausalP.mapAccumSimple
      (\(p, x0) y1v -> do
         let size = Vector.sizeInTuple x0

         d1v  <- SoV.replicate (Filt2.d1 p)
         d2v  <- SoV.replicate (Filt2.d2 p)
         -- b = -d2, the coefficient used by the FIR factors below
         d2vn <- LLVM.neg d2v

         -- last output value of the previous vector
         y1  <- Vector.extract (valueOf $ fromIntegral size - 1) y1v
         -- xk1 corresponds to f0x: the initial conditions are folded into
         -- the input. Index 0 gets d1*y1 added here; d2 times y1v shifted
         -- down by size-2 adds d2*(second-to-last) at index 0 and
         -- d2*(last) at index 1 (assuming shiftDownMultiZero moves elements
         -- towards index 0 and fills with zeros).
         xk1 <-
            Vector.modify (valueOf 0)
               (\u0 -> A.add u0 =<< A.mul (Filt2.d1 p) y1) =<<
            A.add x0 =<< A.mul d2v =<<
            Vector.shiftDownMultiZero (size - 2) y1v

         -- Commented-out alternative that bypasses the recursive stage:
         -- let xk2 = xk1
         -- Each fold step with shift d computes y + a*(y->d) + b*(y->2d)
         -- and continues with (a^2-2*b, b^2), i.e. f1x, f2x, ... above,
         -- for d = 1, 2, 4, ... while d < size.
         xk2 <-
            fmap fst $
            foldM
               (\(y,(a,b)) d ->
                  liftM2 (,)
                     (A.add y =<<
                      M.liftR2 A.add
                         {-
                         Possibility for optimization:
                         In the last step the second operand is a zero vector
                         (LLVM already optimizes this away)
                         and the first operand could be merged
                         with the second operand of the previous step.
                         -}
                         (Vector.shiftUpMultiZero d =<< A.mul y a)
                         (Vector.shiftUpMultiZero (2*d) =<< A.mul y b)) $
                  liftM2 (,)
                     (M.liftR2 A.sub
                         (A.mul a a)
                         (A.mul b (SoV.replicateOf 2)))
                     (A.mul b b))
               (xk1,(d1v,d2vn))
               (takeWhile (< size) $ iterate (2*) 1)

         return (xk2, xk2))
      (return (LLVM.value LLVM.zero))

-- Alternative feedback stage, not referenced in this module (the leading
-- underscore silences the unused-binding warning); its type signature is
-- shared with 'causalNonRecursivePackedP'.
-- Instead of the whole previous output vector, the accumulator holds only
-- the last (x1) and second-to-last (x2) output values of the previous vector.
_causalRecursivePackedPAlt =
   CausalP.mapAccumSimple
      (\(p, x0) (x1,x2) -> do
         let size = Vector.sizeInTuple x0
         -- Commented-out alternative that skips the initial conditions:
         -- let xk1 = x0
         -- Fold the initial conditions into the input (f0x in the comment
         -- above): index 1 gets d2*x1, index 0 gets d2*x2 + d1*x1.
         xk1 <-
            Vector.modify (valueOf 0)
               (\u0 ->
                  A.add u0 =<<
                  M.liftR2 A.add (A.mul (Filt2.d2 p) x2) (A.mul (Filt2.d1 p) x1)) =<<
            Vector.modify (valueOf 1)
               (\u1 -> A.add u1 =<< A.mul (Filt2.d2 p) x1)
            x0

         -- Commented-out alternative that bypasses the recursive stage:
         -- let xk2 = xk1
         d1v <- SoV.replicate (Filt2.d1 p)
         -- d2v holds the negated coefficient -d2 (b in the comment above)
         d2v <- SoV.replicate =<< LLVM.neg (Filt2.d2 p)
         -- Same FIR cascade as in causalRecursivePackedP:
         -- y + a*(y->d) + b*(y->2d), then (a,b) := (a^2-2*b, b^2).
         xk2 <-
            fmap fst $
            foldM
               (\(y,(a,b)) d ->
                  liftM2 (,)
                     (A.add y =<<
                      M.liftR2 A.add
                         (Vector.shiftUpMultiZero d =<< A.mul y a)
                         (Vector.shiftUpMultiZero (2*d) =<< A.mul y b)) $
                  liftM2 (,)
                     (M.liftR2 A.sub
                         (A.mul a a)
                         (A.mul b (SoV.replicateOf 2)))
                     (A.mul b b))
               (xk1,(d1v,d2v))
               (takeWhile (< size) $ iterate (2*) 1)

         -- keep the last two output values as the next accumulator
         y0 <- Vector.extract (valueOf $ fromIntegral size - 1) xk2
         y1 <- Vector.extract (valueOf $ fromIntegral size - 2) xk2
         return (xk2, (y0,y1)))
      (return (LLVM.value LLVM.zero, LLVM.value LLVM.zero))

{-
A filter of second order can also be represented
by a filter of first order with 2x2-matrix coefficients.

filter1 ((d1,d2), (1,0)) (y1,y2) [(x0,0), (x1,0), (x2,0), (x3,0)]

/d1i d2i\ . /d1j d2j\ = /d1i*d1j + d2i  d1i*d2j\
\ 1   0 /   \ 1   0 /   \    d1j            d2j/


With this representation we can also implement filters
with time-variant filter parameters
using a time-variant first-order filter.
-}