{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE Rank2Types #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Synthesizer.LLVM.Filter.FirstOrder (
   Result(Result,lowpass_,highpass_), Parameter, parameter,
   causalP, lowpassCausalP, highpassCausalP,
   causalPackedP, lowpassCausalPackedP, highpassCausalPackedP,
   causalRecursivePackedP, -- for Allpass
   ) where

import qualified Synthesizer.Plain.Filter.Recursive.FirstOrder as FirstOrder
import Synthesizer.Plain.Filter.Recursive.FirstOrder
          (Parameter(Parameter), Result(Result,lowpass_,highpass_))

import qualified Synthesizer.Plain.Modifier as Modifier

import qualified Synthesizer.LLVM.CausalParameterized.Process as CausalP
import qualified LLVM.Extra.Representation as Rep
import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Vector as Vector
import qualified Synthesizer.LLVM.Simple.Value as Value

import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.Arithmetic as A

import qualified LLVM.Core as LLVM
import LLVM.Core
   (Value, valueOf, Vector, Undefined, undefTuple,
    IsFirstClass, IsConst, IsArithmetic, IsFloating,
    IsPrimitive, IsPowerOf2, IsSized,
    CodeGenFunction, )
import LLVM.Util.Loop (Phi, phis, addPhis, )

import Control.Arrow (arr, (&&&), (<<<), )
import Control.Monad (liftM2, foldM, )

import qualified Algebra.Transcendental as Trans
-- import qualified Algebra.Field as Field
import qualified Algebra.Module as Module
import qualified Algebra.Ring as Ring

import NumericPrelude.Numeric
import NumericPrelude.Base


-- | Phi nodes for a 'Parameter' delegate to the Traversable/Foldable-based
-- helpers from "LLVM.Extra.Class"; the wrapped type must itself be 'Phi'.
instance (Phi a) => Phi (Parameter a) where
   addPhis = Class.addPhisFoldable
   phis    = Class.phisTraversable

-- | Undefined 'Parameter' values, delegating to 'Class.undefTuplePointed'.
-- NOTE(review): presumably this wraps an undefined value of the element
-- type in the 'Parameter' constructor — confirm against LLVM.Extra.Class.
instance Undefined a => Undefined (Parameter a) where
   undefTuple = Class.undefTuplePointed

-- | A 'Parameter' is stored with the same memory representation @s@ as the
-- coefficient it wraps: loading/decomposing rewraps with the constructor,
-- storing/composing unwraps it first.
instance
      (Rep.Memory a s, IsSized s ss) =>
      Rep.Memory (Parameter a) s where
   load      = Rep.loadNewtype Parameter
   decompose = Rep.decomposeNewtype Parameter
   store     = Rep.storeNewtype unwrap
      where unwrap (Parameter k) = k
   compose   = Rep.composeNewtype unwrap
      where unwrap (Parameter k) = k

-- | Converting between symbolic ('Value.T') and flat LLVM parameters is done
-- element-wise via the Functor/Traversable helpers of "Synthesizer.LLVM.Simple.Value".
instance (Value.Flatten ah al) =>
      Value.Flatten (Parameter ah) (Parameter al) where
   unfold  = Value.unfoldFunctor
   flatten = Value.flattenTraversable

-- | Build a 'Parameter' value tuple by building the wrapped tuple with the
-- same argument and lifting it through 'Class.buildTupleTraversable'.
instance LLVM.ValueTuple a => LLVM.ValueTuple (Parameter a) where
   buildTuple ref = Class.buildTupleTraversable $ LLVM.buildTuple ref

-- | Tuple description of a 'Parameter', delegating to the Foldable-based
-- helper 'Class.tupleDescFoldable'.
instance LLVM.IsTuple a => LLVM.IsTuple (Parameter a) where
   tupleDesc = Class.tupleDescFoldable

-- | Turn a Haskell-side @Parameter ah@ into its LLVM counterpart
-- @Parameter al@, delegating to the Functor-based helper
-- 'Class.valueTupleOfFunctor'.
instance (LLVM.MakeValueTuple ah al) =>
      LLVM.MakeValueTuple (Parameter ah) (Parameter al) where
   valueTupleOf = Class.valueTupleOfFunctor


-- | Compute the filter coefficient in generated code. The argument is wrapped
-- with 'Value.constantValue', passed to the plain 'FirstOrder.parameter',
-- and the symbolic result is flattened back to an LLVM value.
--
-- NOTE(review): the argument is forwarded unchanged to
-- 'FirstOrder.parameter'; presumably it is a cutoff frequency (the former
-- name @reson@ suggested a resonance, which a first-order filter does not
-- have) — confirm against Synthesizer.Plain.Filter.Recursive.FirstOrder.
parameter ::
   (Trans.C a, IsConst a, IsFloating a) =>
   Value a ->
   CodeGenFunction r (Parameter (Value a))
parameter freq =
   let symbolic = FirstOrder.parameter (Value.constantValue freq)
   in  Value.flatten symbolic


-- | The plain first-order lowpass/highpass modifiers from
-- "Synthesizer.Plain.Filter.Recursive.FirstOrder", specialised to symbolic
-- LLVM values ('Value.T') so they can be lifted with 'CausalP.fromModifier'.
lowpassModifier, highpassModifier ::
   (Module.C (Value.T a) (Value.T v), IsArithmetic a, IsConst a) =>
   Modifier.Simple
      (Value.T v)
      (Parameter (Value.T a))
      (Value.T v) (Value.T v)
highpassModifier = FirstOrder.highpassModifier
lowpassModifier  = FirstOrder.lowpassModifier

-- | First-order filter that yields both the lowpass and the highpass output.
--
-- In contrast to 'CausalP.fromModifier', the lowpass recursion is evaluated
-- only once and shared: the highpass channel is derived from it as
-- @input - lowpass@.
causalP ::
   (Ring.C a, Module.C (Value.T a) (Value.T v),
    IsFirstClass a, IsSized a as, IsConst a, IsArithmetic a,
    IsFirstClass v, IsSized v vs, IsConst v, IsArithmetic v) =>
   CausalP.T p
      (Parameter (Value a), Value v) (Result (Value v))
causalP =
   CausalP.mapSimple combine <<< (lowpassCausalP &&& arr snd)
   where
      -- pair the lowpass output with its complement relative to the input
      combine (low, input) = do
         high <- A.sub input low
         return Result{FirstOrder.lowpass_ = low, FirstOrder.highpass_ = high}

-- | Scalar first-order lowpass and highpass filters, obtained by lifting the
-- plain modifiers with 'CausalP.fromModifier'. Each one runs its own
-- recursion; use 'causalP' to obtain both outputs from a shared lowpass.
lowpassCausalP, highpassCausalP ::
   (Ring.C a, Module.C (Value.T a) (Value.T v),
    IsFirstClass a, IsSized a as, IsConst a, IsArithmetic a,
    IsFirstClass v, IsSized v vs, IsConst v) =>
   CausalP.T p
      (Parameter (Value a), Value v) (Value v)
highpassCausalP = CausalP.fromModifier highpassModifier
lowpassCausalP  = CausalP.fromModifier lowpassModifier

-- | Vectorised first-order filters: @n@ samples per step, one scalar
-- coefficient @k@ per step.
--
-- * 'lowpassCausalPackedP' scales the input vector by @1-k@ and feeds it
--   through 'causalRecursivePackedP', giving @y = (1-k)*x + k*y1@ per lane.
--
-- * 'highpassCausalPackedP' is @input - lowpass@.
--
-- * 'causalRecursivePackedP' is the bare recursion @y = x + k*y1@ across
--   the lanes (exported for the allpass filter).
lowpassCausalPackedP, highpassCausalPackedP, causalRecursivePackedP ::
   (Ring.C a,
    IsFirstClass a, IsConst a, IsSized a as,
    IsPowerOf2 n, -- IsSized (Vector n a) vas,
    IsArithmetic a, IsPrimitive a) =>
   CausalP.T p
      (Parameter (Value a), Value (Vector n a)) (Value (Vector n a))
lowpassCausalPackedP =
   causalRecursivePackedP <<<
   (arr fst &&&
    CausalP.mapSimple
       (\(FirstOrder.Parameter k, x) -> do
          -- broadcast the input gain (1-k) to all lanes and apply it
          gain <- SoV.replicate =<< A.sub (valueOf 1) k
          A.mul x gain))
highpassCausalPackedP =
   CausalP.mapSimple (\(input, low) -> A.sub input low) <<<
   (arr snd &&& lowpassCausalPackedP)

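{-
x = [x0, x1, x2, x3]

filter k y1 x
  = [x0 + k*y1,
     x1 + k*x0 + k^2*y1,
     x2 + k*x1 + k^2*x0 + k^3*y1,
     x3 + k*x2 + k^2*x1 + k^3*x0 + k^4*y1,
     ... ]

f0x = modify 0 (+ k*y1) x      -- add k*y1 to element 0
f1x = f0x + k   * f0x->1
f2x = f1x + k^2 * f1x->2
...

Here v->d denotes v shifted up by d elements, with zeros shifted in.
After log2 n such steps the vector holds all filter outputs,
and its last element becomes y1 for the next vector.
-}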
-- Recursive core of the vectorised first-order filter.
-- Per lane i it computes y_i = x_i + k*y_(i-1), where y_(-1) is the last
-- lane of the previous output vector; the carried state starts at zero.
--
-- Lane 0 first gets k*y1 added. Then, for shift distances d = 1, 2, 4, ...
-- below the lane count, the vector is updated as v + k^d * (v shifted up
-- by d lanes, zero-filled). k is broadcast once and the broadcast vector is
-- squared on each pass; replicating the squared scalar inside the loop
-- would be an equivalent alternative.
causalRecursivePackedP =
   CausalP.mapAccumSimple
      (\(FirstOrder.Parameter k, x) yPrev -> do
         kyPrev <- A.mul k yPrev
         xFed <- Vector.modify (valueOf 0) (A.add kyPrev) x
         let width = Vector.sizeInTuple x
             shifts = takeWhile (< width) (iterate (2*) 1)
         kPow0 <- SoV.replicate k
         (y, _) <-
            foldM
               (\(acc, kPow) d -> do
                  shifted <- Vector.shiftUpMultiZero d =<< A.mul acc kPow
                  acc' <- A.add acc shifted
                  kPow' <- A.mul kPow kPow
                  return (acc', kPow'))
               (xFed, kPow0)
               shifts
         -- the last lane is carried over as y1 for the next vector
         yLast <- Vector.extract (valueOf (fromIntegral (width - 1))) y
         return (y, yLast))
      (return (LLVM.value LLVM.zero))

{-
We can also optimize filtering with time-varying filter parameter.

k = [k0, k1, k2, k3]
x = [x0, x1, x2, x3]

filter k y1 x
  = [x0 + k0*y1,
     x1 + k1*x0 + k1*k0*y1,
     x2 + k2*x1 + k2*k1*x0 + k2*k1*k0*y1,
     x3 + k3*x2 + k3*k2*x1 + k3*k2*k1*x0 + k3*k2*k1*k0*y1,
     ... ]

f0x = modify 0 (+ k0*y1) x    -- add k0*y1 to element 0
f1x = f0x + k  * f0x->1      k'  = k * k->1
f2x = f1x + k' * f1x->2
(v->d: v shifted up by d elements, zeros shifted in)


We can even interpret vectorised first order filtering
as first order filtering with matrix coefficients.

[x0 + k0*y1,
 x1 + k1*x0 + k1*k0*y1,
 x2 + k2*x1 + k2*k1*x0 + k2*k1*k0*y1,
 x3 + k3*x2 + k3*k2*x1 + k3*k2*k1*x0 + k3*k2*k1*k0*y1]
  =
  / 1                   \   /x0\    / k0          0 0 0 \   /y1\
  | k1       1          | . |x1| +  | k1*k0       0 0 0 | . |y2|
  | k2*k1    k2    1    |   |x2|    | k2*k1*k0    0 0 0 |   |y3|
  \ k3*k2*k1 k3*k2 k3 1 /   \x3/    \ k3*k2*k1*k0 0 0 0 /   \y4/


  / 1                   \   / 1                 \   / 1          \
  | k1       1          | = |         1         | . | k1  1      |
  | k2*k1    k2    1    |   | k2*k1        1    |   |    k2  1   |
  \ k3*k2*k1 k3*k2 k3 1 /   \       k3*k2     1 /   \       k3 1 /
-}



-- | Vectorised counterpart of 'causalP': @n@ samples per step with a scalar
-- coefficient, yielding both lowpass and highpass vectors. The highpass
-- vector is computed as @input - lowpass@, so the recursive lowpass is
-- evaluated only once.
causalPackedP ::
   (Ring.C a, IsArithmetic a, IsPrimitive a,
    IsFirstClass a, IsConst a, IsSized a as,
    IsPowerOf2 n) =>
   CausalP.T p
      (Parameter (Value a), Value (Vector n a))
      (Result (Value (Vector n a)))
causalPackedP =
   CausalP.mapSimple
      (\(low, input) ->
         fmap
            (\high -> Result{FirstOrder.lowpass_ = low, FirstOrder.highpass_ = high})
            (A.sub input low))
   <<< (lowpassCausalPackedP &&& arr snd)