{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DeriveTraversable #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Synthesizer.LLVM.Filter.Allpass (
   Parameter, parameter,
   CascadeParameter, flangerParameter, flangerParameterPlain,
   causal, cascade, phaser,
   cascadePipeline, phaserPipeline,
   causalPacked, cascadePacked, phaserPacked,

   causalP, cascadeP, phaserP,
   causalPackedP, cascadePackedP, phaserPackedP,
   ) where

import Synthesizer.Plain.Filter.Recursive.Allpass (Parameter(Parameter), )
import qualified Synthesizer.Plain.Filter.Recursive.Allpass as Allpass
import qualified Synthesizer.Plain.Filter.Recursive.FirstOrder as Filt1

import qualified Synthesizer.Plain.Modifier as Modifier
import qualified Synthesizer.LLVM.Filter.FirstOrder as Filt1L

import qualified Synthesizer.LLVM.CausalParameterized.Process as CausalP
import qualified Synthesizer.LLVM.CausalParameterized.Functional as F
import qualified Synthesizer.LLVM.Causal.ProcessValue as CausalV
import qualified Synthesizer.LLVM.Causal.Process as Causal
import qualified Synthesizer.LLVM.Frame.SerialVector as Serial
import qualified Synthesizer.LLVM.Simple.Value as Value

import qualified LLVM.Extra.Multi.Vector.Memory as MultiVectorMemory
import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Scalar as Scalar
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Util.Loop as Loop
import qualified LLVM.Core as LLVM
import LLVM.Extra.Class (Undefined, undefTuple, )
import LLVM.Core (CodeGenFunction, )
import LLVM.Util.Loop (Phi, phis, addPhis, )

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Base.Proxy (Proxy(Proxy), )

import Foreign.Storable (Storable, )

import qualified Control.Category as Cat
import qualified Control.Applicative as App
import Control.Arrow ((<<<), (^<<), (<<^), (&&&), arr, first, second, )

import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import Data.Tuple.HT (mapPair, )

import qualified Algebra.Transcendental as Trans

import NumericPrelude.Numeric
import NumericPrelude.Base


-- | Phi nodes for an allpass 'Parameter',
-- delegating to the 'Traversable'/'Foldable' helpers of "LLVM.Extra.Class".
instance (Phi a) => Phi (Parameter a) where
   addPhis bb = Class.addPhisFoldable bb
   phis bb = Class.phisTraversable bb

-- | Undefined allpass 'Parameter', built via 'Class.undefTuplePointed'.
instance Undefined a => Undefined (Parameter a) where
   undefTuple = Class.undefTuplePointed

-- | Zero allpass 'Parameter', built via 'Class.zeroTuplePointed'.
instance Class.Zero a => Class.Zero (Parameter a) where
   zeroTuple = Class.zeroTuplePointed

-- | An allpass 'Parameter' is kept in memory using the structure
-- of its element type (@'Memory.Struct' a@).
instance (Memory.C a) => Memory.C (Parameter a) where
   type Struct (Parameter a) = Memory.Struct a
   -- wrapping direction
   load = Memory.loadNewtype Parameter
   decompose = Memory.decomposeNewtype Parameter
   -- unwrapping direction
   store = Memory.storeNewtype Allpass.getParameter
   compose = Memory.composeNewtype Allpass.getParameter


{-
instance LLVM.ValueTuple a => LLVM.ValueTuple (Parameter a) where
   buildTuple f = Class.buildTupleTraversable (LLVM.buildTuple f)

instance LLVM.IsTuple a => LLVM.IsTuple (Parameter a) where
   tupleDesc = Class.tupleDescFoldable
-}

-- | The value tuple of an allpass 'Parameter' is a 'Parameter' of value tuples,
-- constructed via 'Class.valueTupleOfFunctor'.
instance (Class.MakeValueTuple a) => Class.MakeValueTuple (Parameter a) where
   type ValueTuple (Parameter a) = Parameter (Class.ValueTuple a)
   valueTupleOf = Class.valueTupleOfFunctor

-- | A multi-value of an allpass parameter is represented
-- as a 'Allpass.Parameter' wrapped around the element representation;
-- all methods defer to the element type after unwrapping.
instance (MultiValue.C a) => MultiValue.C (Allpass.Parameter a) where
   type Repr f (Allpass.Parameter a) = Allpass.Parameter (MultiValue.Repr f a)
   cons (Allpass.Parameter k) = paramFromPlainValue (MultiValue.cons k)

   undef = paramFromPlainValue MultiValue.undef
   zero = paramFromPlainValue MultiValue.zero

   phis bb x =
      fmap paramFromPlainValue (MultiValue.phis bb (plainFromParamValue x))
   addPhis bb new old =
      MultiValue.addPhis bb (plainFromParamValue new) (plainFromParamValue old)

-- | Vectors of allpass parameters: every method unwraps to a vector
-- of the element type, delegates, and wraps the result again.
instance (MultiVector.C a) => MultiVector.C (Allpass.Parameter a) where
   cons xs =
      paramFromPlainVector (MultiVector.cons (fmap Allpass.getParameter xs))
   undef = paramFromPlainVector MultiVector.undef
   zero = paramFromPlainVector MultiVector.zero

   phis bb v =
      fmap paramFromPlainVector (MultiVector.phis bb (plainFromParamVector v))
   addPhis bb new old =
      MultiVector.addPhis bb (plainFromParamVector new) (plainFromParamVector old)

   shuffle is x y =
      fmap paramFromPlainVector
         (MultiVector.shuffle is (plainFromParamVector x) (plainFromParamVector y))
   extract i =
      fmap paramFromPlainValue . MultiVector.extract i . plainFromParamVector
   insert i e =
      fmap paramFromPlainVector .
      MultiVector.insert i (plainFromParamValue e) .
      plainFromParamVector

-- | Wrap every element of a multi-vector into an 'Allpass.Parameter'.
paramFromPlainVector ::
   MultiVector.T n a ->
   MultiVector.T n (Allpass.Parameter a)
paramFromPlainVector v =
   MultiVector.lift1 Allpass.Parameter v

-- | Unwrap every element of a multi-vector of 'Allpass.Parameter's.
plainFromParamVector ::
   MultiVector.T n (Allpass.Parameter a) ->
   MultiVector.T n a
plainFromParamVector v =
   MultiVector.lift1 Allpass.getParameter v

-- | Wrap a multi-value into an 'Allpass.Parameter'.
paramFromPlainValue ::
   MultiValue.T a ->
   MultiValue.T (Allpass.Parameter a)
paramFromPlainValue x =
   MultiValue.lift1 Allpass.Parameter x

-- | Unwrap a multi-value 'Allpass.Parameter'.
plainFromParamValue ::
   MultiValue.T (Allpass.Parameter a) ->
   MultiValue.T a
plainFromParamValue x =
   MultiValue.lift1 Allpass.getParameter x


-- | Vectors of allpass parameters are stored like vectors of the element type.
instance (MultiVectorMemory.C n a) => MultiVectorMemory.C n (Allpass.Parameter a) where
   type Struct n (Allpass.Parameter a) = MultiVectorMemory.Struct n a
   load ptr = fmap paramFromPlainVector (MultiVectorMemory.load ptr)
   store v = MultiVectorMemory.store (plainFromParamVector v)
   decompose s = fmap paramFromPlainVector (MultiVectorMemory.decompose s)
   compose v = MultiVectorMemory.compose (plainFromParamVector v)


-- | Registers of an allpass 'Parameter' are a 'Parameter' of the element registers;
-- flattening and unfolding go through the 'Traversable' helpers.
instance (Value.Flatten a) => Value.Flatten (Parameter a) where
   type Registers (Parameter a) = Parameter (Value.Registers a)
   flattenCode = Value.flattenCodeTraversable
   unfoldCode = Value.unfoldCodeTraversable


-- | A 'Parameter' of vectors behaves as a vector of 'Parameter's
-- with the size of the wrapped vector type.
instance (Vector.Simple v) => Vector.Simple (Parameter v) where
   type Element (Parameter v) = Parameter (Vector.Element v)
   type Size (Parameter v) = Vector.Size v
   shuffleMatch = Vector.shuffleMatchTraversable
   extract = Vector.extractTraversable

-- | Element insertion into a 'Parameter' of vectors, via 'Vector.insertTraversable'.
instance (Vector.C v) => Vector.C (Parameter v) where
   insert = Vector.insertTraversable


{- |
Code-generating counterpart of 'Allpass.parameter':
both arguments are passed on to 'Allpass.parameter' unchanged
(presumably phase shift and frequency — see its documentation).
-}
parameter ::
   (A.Transcendental a, A.RationalConstant a) =>
   a -> a -> CodeGenFunction r (Parameter a)
parameter x y =
   Value.unlift2 Allpass.parameter x y


{- |
Parameter of a cascade of allpass filters.
It wraps one 'Allpass.Parameter' that is shared by all stages.
The phantom type @n@ is the number of stages (the cascade order).
Most instances are derived from those of 'Allpass.Parameter'.
-}
newtype CascadeParameter n a =
   CascadeParameter (Allpass.Parameter a)
      deriving
         (Undefined, Class.Zero, Storable,
          Functor, App.Applicative, Fold.Foldable, Trav.Traversable)

-- | Phi nodes of a 'CascadeParameter' are those of the wrapped 'Allpass.Parameter'.
instance (Phi a) => Phi (CascadeParameter n a) where
   phis bb cp =
      case cp of
         CascadeParameter v -> fmap CascadeParameter (Loop.phis bb v)
   addPhis bb (CascadeParameter new) (CascadeParameter old) =
      Loop.addPhis bb new old


-- | A 'CascadeParameter' is kept in memory using the structure
-- of its element type (@'Memory.Struct' a@).
instance (Memory.C a) => Memory.C (CascadeParameter n a) where
   type Struct (CascadeParameter n a) = Memory.Struct a
   load = Memory.loadNewtype CascadeParameter
   decompose = Memory.decomposeNewtype CascadeParameter
   store = Memory.storeNewtype unwrap
      where unwrap (CascadeParameter k) = k
   compose = Memory.composeNewtype unwrap
      where unwrap (CascadeParameter k) = k


{-
instance LLVM.ValueTuple a => LLVM.ValueTuple (CascadeParameter n a) where
   buildTuple f = Class.buildTupleTraversable (LLVM.buildTuple f)

instance LLVM.IsTuple a => LLVM.IsTuple (CascadeParameter n a) where
   tupleDesc = Class.tupleDescFoldable
-}

-- | The value tuple of a 'CascadeParameter' keeps the order @n@
-- and maps the element type, via 'Class.valueTupleOfFunctor'.
instance (Class.MakeValueTuple a) => Class.MakeValueTuple (CascadeParameter n a) where
   type ValueTuple (CascadeParameter n a) = CascadeParameter n (Class.ValueTuple a)
   valueTupleOf = Class.valueTupleOfFunctor

-- | Multi-values of a 'CascadeParameter' share the representation
-- of the wrapped 'Allpass.Parameter'; all methods convert and delegate.
instance (MultiValue.C a) => MultiValue.C (CascadeParameter n a) where
   type Repr f (CascadeParameter n a) = MultiValue.Repr f (Allpass.Parameter a)
   cons =
      cascadeFromParamValue . MultiValue.cons . (\(CascadeParameter p) -> p)

   undef = cascadeFromParamValue MultiValue.undef
   zero = cascadeFromParamValue MultiValue.zero

   phis bb x =
      fmap cascadeFromParamValue (MultiValue.phis bb (paramFromCascadeValue x))
   addPhis bb new old =
      MultiValue.addPhis bb (paramFromCascadeValue new) (paramFromCascadeValue old)

-- | Vectors of 'CascadeParameter' share the representation of vectors
-- of 'Allpass.Parameter'; every method converts, delegates and converts back.
instance (MultiVector.C a) => MultiVector.C (CascadeParameter n a) where
   cons xs =
      cascadeFromParamVector
         (MultiVector.cons (fmap (\(CascadeParameter p) -> p) xs))
   undef = cascadeFromParamVector MultiVector.undef
   zero = cascadeFromParamVector MultiVector.zero

   phis bb v =
      fmap cascadeFromParamVector (MultiVector.phis bb (paramFromCascadeVector v))
   addPhis bb new old =
      MultiVector.addPhis bb
         (paramFromCascadeVector new) (paramFromCascadeVector old)

   shuffle is x y =
      fmap cascadeFromParamVector
         (MultiVector.shuffle is
            (paramFromCascadeVector x) (paramFromCascadeVector y))
   extract i =
      fmap cascadeFromParamValue . MultiVector.extract i . paramFromCascadeVector
   insert i e =
      fmap cascadeFromParamVector .
      MultiVector.insert i (paramFromCascadeValue e) .
      paramFromCascadeVector

-- | Reinterpret a vector of allpass parameters as cascade parameters of any order.
-- The representations coincide, so the conversion is 'id' on the representation.
cascadeFromParamVector ::
   MultiVector.T n (Allpass.Parameter a) ->
   MultiVector.T n (CascadeParameter m a)
cascadeFromParamVector v = MultiVector.lift1 id v

-- | Forget the cascade order of a vector of cascade parameters.
paramFromCascadeVector ::
   MultiVector.T n (CascadeParameter m a) ->
   MultiVector.T n (Allpass.Parameter a)
paramFromCascadeVector v = MultiVector.lift1 id v

-- | Reinterpret an allpass parameter multi-value as a cascade parameter.
cascadeFromParamValue ::
   MultiValue.T (Allpass.Parameter a) ->
   MultiValue.T (CascadeParameter m a)
cascadeFromParamValue x = MultiValue.lift1 id x

-- | Forget the cascade order of a cascade parameter multi-value.
paramFromCascadeValue ::
   MultiValue.T (CascadeParameter m a) ->
   MultiValue.T (Allpass.Parameter a)
paramFromCascadeValue x = MultiValue.lift1 id x

{- |
Vectors of 'CascadeParameter' are stored like vectors of 'Allpass.Parameter'.

The vector width @n@ is independent of the cascade order @m@.
This matches the 'MultiVector.C' instance and the conversion helpers
'cascadeFromParamVector' / 'paramFromCascadeVector', which already keep them separate.
The case @m = n@ used by 'cascadePipeline' is a special case.
-}
instance (MultiVectorMemory.C n a) => MultiVectorMemory.C n (CascadeParameter m a) where
   type Struct n (CascadeParameter m a) = MultiVectorMemory.Struct n (Allpass.Parameter a)
   load      = fmap cascadeFromParamVector . MultiVectorMemory.load
   store     = MultiVectorMemory.store . paramFromCascadeVector
   decompose = fmap cascadeFromParamVector . MultiVectorMemory.decompose
   compose   = MultiVectorMemory.compose . paramFromCascadeVector

-- | Registers of a 'CascadeParameter' keep the order @n@ and flatten the element registers,
-- via the 'Traversable' helpers.
instance (Value.Flatten a) => Value.Flatten (CascadeParameter n a) where
   type Registers (CascadeParameter n a) = CascadeParameter n (Value.Registers a)
   flattenCode = Value.flattenCodeTraversable
   unfoldCode = Value.unfoldCodeTraversable


-- | A 'CascadeParameter' of vectors behaves as a vector of 'CascadeParameter's
-- with the size of the wrapped vector type.
instance (Vector.Simple v) => Vector.Simple (CascadeParameter n v) where
   type Element (CascadeParameter n v) = CascadeParameter n (Vector.Element v)
   type Size (CascadeParameter n v) = Vector.Size v
   shuffleMatch = Vector.shuffleMatchTraversable
   extract = Vector.extractTraversable

-- | Element insertion into a 'CascadeParameter' of vectors, via 'Vector.insertTraversable'.
instance (Vector.C v) => Vector.C (CascadeParameter n v) where
   insert  = Vector.insertTraversable

-- | For the functional interface a 'CascadeParameter' is a single argument:
-- its argument form is @f (CascadeParameter n a)@ itself,
-- so 'F.makeArgs' returns its input.
type instance F.Arguments f (CascadeParameter n a) = f (CascadeParameter n a)
instance F.MakeArguments (CascadeParameter n a) where
   makeArgs args = args


{- |
Code-generating counterpart of 'flangerParameterPlain':
computes the cascade parameter of the given order for a frequency value.
-}
flangerParameter ::
   (A.Transcendental a, A.RationalConstant a, TypeNum.Natural n) =>
   Proxy n -> a ->
   CodeGenFunction r (CascadeParameter n a)
flangerParameter order freq =
   Value.unlift1 (flangerParameterPlain order) freq

{- |
Pure cascade parameter for a flanger:
the cascade order is read from the type-level number @n@
and handed to 'Allpass.flangerParameter' together with the frequency.
-}
flangerParameterPlain ::
   (Trans.C a, TypeNum.Natural n) =>
   Proxy n -> a -> CascadeParameter n a
flangerParameterPlain order freq =
   CascadeParameter (Allpass.flangerParameter stages freq)
   where stages = TypeNum.integralFromProxy order


{- |
First-order allpass step as a 'Modifier.Simple' on 'Value.T' values.
This is 'Allpass.firstOrderModifier' specialised to LLVM values.
The state is a pair of values (presumably previous input and output — confirm in Allpass),
and the control is the allpass 'Parameter'.
-}
modifier ::
   (a ~ A.Scalar v, A.PseudoModule v, A.IntegerConstant a) =>
   Modifier.Simple
      -- (Allpass.State (Value.T v))
      (Value.T v, Value.T v)
      (Parameter (Value.T a))
      (Value.T v) (Value.T v)
modifier =
   Allpass.firstOrderModifier

{- |
First-order allpass filter, obtained from 'modifier' via 'Causal.fromModifier'.
The input is a pair of filter parameter and signal sample.

For an allpass cascade see 'cascade',
or use the 'Causal.pipeline' function as 'cascadePipeline' does.
-}
causal ::
   (Causal.C process,
    A.IntegerConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v) =>
   process (Parameter a, v) v
causal =
   Causal.fromModifier modifier


{- |
Replicate an allpass stage as many times as the order given by the 'Proxy',
via 'Causal.replicateControlled'.
The control input of each copy is the 'Parameter' unwrapped from the 'CascadeParameter'.
-}
replicateStage ::
   (Causal.C process,
    TypeNum.Natural n, Phi b, Undefined b) =>
   Proxy n ->
   process (Parameter a, b) b ->
   process (CascadeParameter n a, b) b
replicateStage order stg =
   Causal.replicateControlled stages (stg <<< first (arr unwrap))
   where
      stages = TypeNum.integralFromProxy order
      unwrap (CascadeParameter p) = p

-- | Cascade of @n@ first-order allpass filters ('causal'),
-- where @n@ is the order in the type of the 'CascadeParameter';
-- all stages are driven by the same parameter (see 'replicateStage').
cascade ::
   (Causal.C process,
    A.RationalConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v,
    TypeNum.Natural n) =>
   process (CascadeParameter n a, v) v
cascade =
   replicateStage Proxy causal

-- | Scale each sample by the scalar one half (module scaling @*>@).
halfVector ::
   (Causal.C process, A.RationalConstant a, a ~ A.Scalar v, A.PseudoModule v) =>
   process v v
halfVector =
   CausalV.map (\x -> Value.fromRational' 0.5 *> x)

{- |
Phaser: the input is halved,
then the allpass 'cascade' output and the halved input are mixed.
-}
phaser ::
   (Causal.C process,
    A.RationalConstant a, A.RationalConstant v,
    a ~ A.Scalar v, A.PseudoModule v, Memory.C v,
    TypeNum.Natural n) =>
   process (CascadeParameter n a, v) v
phaser =
   Causal.mix <<<
   ((cascade &&& arr snd) <<< second halfVector)


-- | Turn a cascade-parameter multi-value into an 'Allpass.Parameter'
-- of element multi-values (the cascade order is dropped).
paramFromCascadeParam ::
   MultiValue.T (CascadeParameter n a) ->
   Allpass.Parameter (MultiValue.T a)
paramFromCascadeParam mv =
   case mv of
      MultiValue.Cons p -> fmap MultiValue.Cons p

{- |
One pipeline stage used by 'cascadePipeline'.
It operates on a vector of @n@ pairs of cascade parameter and sample.
'Causal.vectorize' (presumably) runs the scalar allpass 'causal' on every lane.
The parameter of each lane is passed through unchanged,
so that it travels along the pipeline together with the signal.

It shouldn't be too hard to use vector operations for the code we generate,
but LLVM-2.6 did not do that yet.
-}
stage ::
   (Causal.C process,
    TypeNum.Positive n, MultiVector.C a,
    MultiVector.T n (CascadeParameter n a, a) ~ v,
    MultiValue.PseudoRing a, MultiValue.IntegerConstant a,
    MultiValueMemory.C a) =>
   Proxy n -> process v v
stage _ =
   Causal.vectorize $
      -- re-assemble (parameter, filtered sample)
      uncurry MultiValue.zip
      ^<<
      -- pass the parameter through and filter the sample with it
      (arr fst &&&
       (Scalar.decons
        ^<<
        causal
        <<^
        (\(p, v) ->
           (fmap Scalar.Cons $ paramFromCascadeParam p, Scalar.Cons v))))
      <<^
      -- split each lane into parameter and sample
      MultiValue.unzip

-- | Supply the type-level size to a process builder through a 'Proxy'
-- whose type is fixed by the 'CascadeParameter' in the input.
withSize ::
   (Proxy n -> process (MultiValue.T (CascadeParameter n a), b) c) ->
   process (MultiValue.T (CascadeParameter n a), b) c
withSize = ($ Proxy)

{- |
Fast implementation of 'cascade' using vector instructions.

Compared to 'cascade' there are some restrictions:

* there must be at least one pipeline stage ('TypeNum.Positive' order),

* the element type must be primitive ('MultiVector.C'),

* the output is delayed by the number of pipeline stages.
-}
cascadePipeline ::
   (Causal.C process,
    TypeNum.Positive n, MultiVector.C a,
    MultiValue.PseudoRing a, MultiValue.IntegerConstant a,
    MultiValueMemory.C a, MultiVectorMemory.C n a) =>
   process
      (MultiValue.T (CascadeParameter n a), MultiValue.T a)
      (MultiValue.T a)
cascadePipeline = withSize $ \order ->
   -- feed (parameter, sample) pairs through the pipeline of 'stage's
   -- and keep only the sample part of the result
   MultiValue.snd
   ^<<
   Causal.pipeline (stage order)
   <<^
   uncurry MultiValue.zip

-- | Identity process on vectors whose size is fixed by the 'Proxy'.
-- Used as a pure delay line when wrapped in 'Causal.pipeline'.
vectorId ::
   (Causal.C process) =>
   Proxy n -> process (MultiVector.T n a) (MultiVector.T n a)
vectorId = \_ -> Cat.id

-- | Multiply each sample by one half.
half ::
   (Causal.C process, A.RationalConstant a, A.PseudoRing a) =>
   process a a
half =
   CausalV.map (\x -> Value.fromRational' 0.5 * x)


-- | Wrap a plain LLVM value as a multi-value of a primitive type.
multiValue ::
   (MultiValue.Repr LLVM.Value a ~ LLVM.Value a) =>
   LLVM.Value a -> MultiValue.T a
multiValue v = MultiValue.Cons v

-- | Unwrap a multi-value of a primitive type to its plain LLVM value.
unmultiValue ::
   (MultiValue.Repr LLVM.Value a ~ LLVM.Value a) =>
   MultiValue.T a -> LLVM.Value a
unmultiValue mv =
   case mv of
      MultiValue.Cons v -> v

-- | Wrap a cascade parameter of plain LLVM values as a multi-value.
multiCascadeParam ::
   (MultiValue.Repr LLVM.Value a ~ LLVM.Value a) =>
   CascadeParameter n (LLVM.Value a) ->
   MultiValue.T (CascadeParameter n a)
multiCascadeParam cp =
   case cp of
      CascadeParameter p -> MultiValue.Cons p

-- | 'phaserPipelineMulti' for plain LLVM values of a primitive type:
-- converts the inputs to multi-values and the output back.
phaserPipeline ::
   (Causal.C process,
    TypeNum.Positive n,
    MultiValue.PseudoRing a, MultiValue.RationalConstant a,
    MultiValueMemory.C a, MultiVectorMemory.C n a,
    MultiValue.Repr LLVM.Value a ~ LLVM.Value a) =>
   process
      (CascadeParameter n (LLVM.Value a), LLVM.Value a)
      (LLVM.Value a)
phaserPipeline =
   unmultiValue ^<<
      (phaserPipelineMulti <<^ mapPair (multiCascadeParam, multiValue))


{- |
Phaser built on 'cascadePipeline'.
The input is halved.
The dry path goes through a pipeline of 'vectorId' of the same size,
so it gets the same delay as the allpass cascade before both are mixed.

Alternative dry path, currently unused:
@Causal.delay (const zero) (const $ TypeNum.integralFromProxy order) <<^ snd@
-}
phaserPipelineMulti ::
   (Causal.C process,
    TypeNum.Positive n,
    MultiValue.PseudoRing a, MultiValue.RationalConstant a,
    MultiValueMemory.C a, MultiVectorMemory.C n a) =>
   process
      (MultiValue.T (CascadeParameter n a), MultiValue.T a)
      (MultiValue.T a)
phaserPipelineMulti = withSize $ \order ->
   Causal.mix <<<
   ((cascadePipeline &&& (Causal.pipeline (vectorId order) <<^ snd)) <<<
    second half)


{- |
Allpass filter ('causal') for serial vectors ('Serial.C'),
processing one chunk of samples per step.

'causalPacked' splits the first-order allpass into
the non-recursive part 'causalNonRecursivePacked'
followed by 'Filt1L.causalRecursivePacked' with parameter @-k@.
-}
causalPacked,
  causalNonRecursivePacked ::
   (Causal.C process,
    Serial.C v, Serial.Element v ~ a,
    Memory.C a, A.IntegerConstant a,
    A.PseudoRing v, A.PseudoRing a) =>
   process (Parameter a, v) v

causalPacked =
   Filt1L.causalRecursivePacked <<<
   -- the recursive part gets the negated allpass parameter
   -- and the output of the non-recursive part as input
   (Causal.map
       (\(Parameter k, _) ->
           fmap Filt1.Parameter $ A.neg k) &&&
    causalNonRecursivePacked)

{-
Non-recursive part: y = v1 + v0 * k.
v0 is the current chunk.
v1 is v0 shifted by one lane with x1 inserted (assuming 'Serial.shiftUp' semantics),
so presumably y[i] = x[i-1] + k*x[i].
The carried state x1 is the last lane of the previous chunk, initially zero.
-}
causalNonRecursivePacked =
   Causal.mapAccum
      (\(Parameter k, v0) x1 -> do
         (_,v1) <- Serial.shiftUp x1 v0
         y <- A.add v1 =<< A.mul v0 =<< Serial.upsample k
         let size = fromIntegral $ Serial.size v0
         -- remember the last lane of this chunk for the next step
         u0 <- Serial.extract (LLVM.valueOf $ size - 1) v0
         return (y, u0))
      (return A.zero)

{- |
'cascadePacked' replicates 'causalPacked' for the order of the 'CascadeParameter'.
'phaserPacked' halves the input and mixes the cascade output with the halved input.
-}
cascadePacked, phaserPacked ::
   (Causal.C process,
    TypeNum.Natural n,
    Serial.C v, Serial.Element v ~ a,
    A.PseudoRing a, A.IntegerConstant a, Memory.C a,
    A.PseudoRing v, A.RationalConstant v) =>
   process (CascadeParameter n a, v) v
cascadePacked = replicateStage Proxy causalPacked

phaserPacked =
   Causal.mix <<<
   ((cascadePacked &&& arr snd) <<<
    second (Causal.map (A.mul (A.fromRational' 0.5))))





-- | Deprecated alias of 'causal' for 'CausalP.T'.
causalP ::
   (A.RationalConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v) =>
   CausalP.T p (Parameter a, v) v
causalP = causal

-- | Deprecated alias of 'cascade' for 'CausalP.T'.
cascadeP ::
   (A.RationalConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v,
    TypeNum.Natural n) =>
   CausalP.T p (CascadeParameter n a, v) v
cascadeP = cascade

-- | Deprecated alias of 'phaser' for 'CausalP.T'.
phaserP ::
   (A.RationalConstant a, A.RationalConstant v,
    a ~ A.Scalar v, A.PseudoModule v, Memory.C v,
    TypeNum.Natural n) =>
   CausalP.T p (CascadeParameter n a, v) v
phaserP = phaser


-- | Deprecated alias of 'causalPacked' for 'CausalP.T'.
causalPackedP ::
   (Serial.C v, Serial.Element v ~ a,
    Memory.C a, A.IntegerConstant a,
    A.PseudoRing v, A.PseudoRing a) =>
   CausalP.T p (Parameter a, v) v
causalPackedP = causalPacked

-- | Deprecated aliases of 'cascadePacked' and 'phaserPacked' for 'CausalP.T'.
cascadePackedP, phaserPackedP ::
   (TypeNum.Natural n,
    Serial.C v, Serial.Element v ~ a,
    A.PseudoRing a, A.IntegerConstant a, Memory.C a,
    A.PseudoRing v, A.RationalConstant v) =>
   CausalP.T p (CascadeParameter n a, v) v
cascadePackedP = cascadePacked
phaserPackedP = phaserPacked

-- Deprecation warnings for the parameterized aliases defined above.
{-# DEPRECATED causalP          "use 'causal' instead" #-}
{-# DEPRECATED cascadeP         "use 'cascade' instead" #-}
{-# DEPRECATED phaserP          "use 'phaser' instead" #-}
{-# DEPRECATED causalPackedP    "use 'causalPacked' instead" #-}
{-# DEPRECATED cascadePackedP   "use 'cascadePacked' instead" #-}
{-# DEPRECATED phaserPackedP    "use 'phaserPacked' instead" #-}