{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DeriveTraversable #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Synthesizer.LLVM.Filter.Allpass (
   Parameter, parameter,
   CascadeParameter, flangerParameter, flangerParameterPlain,
   causal, cascade, phaser,
   cascadePipeline, phaserPipeline,
   causalPacked, cascadePacked, phaserPacked,

   causalP, cascadeP, phaserP,
   causalPackedP, cascadePackedP, phaserPackedP,
   ) where

import Synthesizer.Plain.Filter.Recursive.Allpass (Parameter(Parameter))
import qualified Synthesizer.Plain.Filter.Recursive.Allpass as Allpass
import qualified Synthesizer.Plain.Filter.Recursive.FirstOrder as Filt1

import qualified Synthesizer.Plain.Modifier as Modifier
import qualified Synthesizer.LLVM.Filter.FirstOrder as Filt1L

import qualified Synthesizer.LLVM.CausalParameterized.Process as CausalP
import qualified Synthesizer.LLVM.CausalParameterized.Functional as F
import qualified Synthesizer.LLVM.Causal.ProcessValue as CausalV
import qualified Synthesizer.LLVM.Causal.Process as Causal
import qualified Synthesizer.LLVM.Frame.SerialVector as Serial
import qualified Synthesizer.LLVM.Simple.Value as Value

import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Scalar as Scalar
import qualified LLVM.Extra.Storable as Storable
import qualified LLVM.Extra.Marshal as Marshal
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Core as LLVM
import LLVM.Core (CodeGenFunction)

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Base.Proxy (Proxy(Proxy))

import Foreign.Storable (Storable)

import qualified Control.Category as Cat
import qualified Control.Applicative as App
import Control.Arrow ((<<<), (^<<), (<<^), (&&&), arr, first, second)

import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import Data.Tuple.HT (mapPair)

import qualified Algebra.Transcendental as Trans

import NumericPrelude.Numeric
import NumericPrelude.Base


-- 'Parameter' wraps a single coefficient,
-- so phi nodes are created component-wise through its 'Traversable' instance.
instance (Tuple.Phi a) => Tuple.Phi (Parameter a) where
   addPhi = Tuple.addPhiFoldable
   phi = Tuple.phiTraversable

-- Undefined and zero values come from the generic helpers
-- 'Tuple.undefPointed' and 'Tuple.zeroPointed'.
instance Tuple.Undefined a => Tuple.Undefined (Parameter a) where
   undef = Tuple.undefPointed

instance Tuple.Zero a => Tuple.Zero (Parameter a) where
   zero = Tuple.zeroPointed

-- In memory a 'Parameter' uses exactly the struct type of its coefficient;
-- the wrapper is only added after loading and removed before storing.
instance (Memory.C a) => Memory.C (Parameter a) where
   type Struct (Parameter a) = Memory.Struct a
   load = Memory.loadNewtype Parameter
   store = Memory.storeNewtype Allpass.getParameter
   decompose = Memory.decomposeNewtype Parameter
   compose = Memory.composeNewtype Allpass.getParameter

-- Marshalling goes through the wrapped coefficient as well.
instance (Marshal.C a) => Marshal.C (Parameter a) where
   pack (Parameter coeff) = Marshal.pack coeff
   unpack s = Parameter (Marshal.unpack s)

instance (Storable.C a) => Storable.C (Parameter a) where
   load = Storable.loadNewtype Parameter Parameter
   store = Storable.storeNewtype Parameter Allpass.getParameter


{-
instance LLVM.ValueTuple a => LLVM.ValueTuple (Parameter a) where
   buildTuple f = Class.buildTupleTraversable (LLVM.buildTuple f)

instance LLVM.IsTuple a => LLVM.IsTuple (Parameter a) where
   tupleDesc = Class.tupleDescFoldable
-}

-- The LLVM value of a 'Parameter' is a 'Parameter' of the LLVM value.
instance (Tuple.Value a) => Tuple.Value (Parameter a) where
   type ValueOf (Parameter a) = Parameter (Tuple.ValueOf a)
   valueOf = Tuple.valueOfFunctor

-- A vector of 'Parameter's is first turned into a 'Parameter' holding
-- a vector ('Trav.sequenceA'), then the inner vector is converted.
instance (Tuple.VectorValue n a) => Tuple.VectorValue n (Parameter a) where
   type VectorValueOf n (Parameter a) = Parameter (Tuple.VectorValueOf n a)
   vectorValueOf params =
      fmap Tuple.vectorValueOf (Trav.sequenceA params)

-- Every method unwraps to the plain coefficient's 'MultiValue.T',
-- delegates to its instance, and wraps the result again.
instance (MultiValue.C a) => MultiValue.C (Allpass.Parameter a) where
   cons p = paramFromPlainValue (MultiValue.cons (Allpass.getParameter p))

   undef = paramFromPlainValue MultiValue.undef
   zero = paramFromPlainValue MultiValue.zero

   phi bb x =
      fmap paramFromPlainValue (MultiValue.phi bb (plainFromParamValue x))
   addPhi bb x y =
      MultiValue.addPhi bb (plainFromParamValue x) (plainFromParamValue y)

-- Every method unwraps to the plain coefficient vector,
-- delegates to its instance, and wraps the result again.
instance (MultiVector.C a) => MultiVector.C (Allpass.Parameter a) where
   cons ps =
      paramFromPlainVector (MultiVector.cons (fmap Allpass.getParameter ps))
   undef = paramFromPlainVector MultiVector.undef
   zero = paramFromPlainVector MultiVector.zero

   phi bb x =
      fmap paramFromPlainVector (MultiVector.phi bb (plainFromParamVector x))
   addPhi bb x y =
      MultiVector.addPhi bb (plainFromParamVector x) (plainFromParamVector y)

   shuffle is x y =
      fmap paramFromPlainVector
         (MultiVector.shuffle is
            (plainFromParamVector x) (plainFromParamVector y))
   extract i x =
      fmap paramFromPlainValue
         (MultiVector.extract i (plainFromParamVector x))
   insert i e x =
      fmap paramFromPlainVector
         (MultiVector.insert i (plainFromParamValue e)
            (plainFromParamVector x))

-- | Wrap the representation of a coefficient vector in 'Allpass.Parameter'.
paramFromPlainVector ::
   MultiVector.T n a ->
   MultiVector.T n (Allpass.Parameter a)
paramFromPlainVector vec = MultiVector.lift1 Allpass.Parameter vec

-- | Inverse of 'paramFromPlainVector'.
plainFromParamVector ::
   MultiVector.T n (Allpass.Parameter a) ->
   MultiVector.T n a
plainFromParamVector vec = MultiVector.lift1 Allpass.getParameter vec

-- | Wrap the representation of a single coefficient in 'Allpass.Parameter'.
paramFromPlainValue ::
   MultiValue.T a ->
   MultiValue.T (Allpass.Parameter a)
paramFromPlainValue val = MultiValue.lift1 Allpass.Parameter val

-- | Inverse of 'paramFromPlainValue'.
plainFromParamValue ::
   MultiValue.T (Allpass.Parameter a) ->
   MultiValue.T a
plainFromParamValue val = MultiValue.lift1 Allpass.getParameter val


-- Flattening to and from registers is done component-wise
-- through the 'Traversable' structure of 'Parameter'.
instance (Value.Flatten a) => Value.Flatten (Parameter a) where
   type Registers (Parameter a) = Parameter (Value.Registers a)
   flattenCode = Value.flattenCodeTraversable
   unfoldCode = Value.unfoldCodeTraversable


-- A 'Parameter' of vectors is treated as a vector of 'Parameter's:
-- the element type is wrapped, the size is the size of the wrapped vector,
-- and shuffling/extraction are lifted through the 'Traversable' structure.
instance (Vector.Simple v) => Vector.Simple (Parameter v) where
   type Element (Parameter v) = Parameter (Vector.Element v)
   type Size (Parameter v) = Vector.Size v
   shuffleMatch = Vector.shuffleMatchTraversable
   extract = Vector.extractTraversable

-- Element insertion, likewise lifted through the 'Traversable' structure.
instance (Vector.C v) => Vector.C (Parameter v) where
   insert = Vector.insertTraversable

-- In the functional process interface ("F") a 'Parameter' counts as
-- one single argument; 'makeArgs' passes it through unchanged.
type instance F.Arguments f (Parameter a) = f (Parameter a)
instance F.MakeArguments (Parameter a) where
   makeArgs = id


{- |
Compute an allpass 'Parameter' in LLVM code
by lifting the plain 'Allpass.parameter' with 'Value.unlift2'.
Both arguments are passed to 'Allpass.parameter' in the same order
(presumably phase shift first, then frequency - confirm there).
-}
parameter ::
   (A.Transcendental a, A.RationalConstant a) =>
   a -> a -> CodeGenFunction r (Parameter a)
parameter phase freq =
   Value.unlift2 Allpass.parameter phase freq


{- |
Parameter of a cascade of first-order allpass filters.

The type parameter @n@ is the number of stages in the cascade.
It exists only at the type level:
the value holds one single 'Allpass.Parameter',
which is used as the control of every stage.
-}
newtype CascadeParameter n a =
   CascadeParameter (Allpass.Parameter a)
      -- all instances are inherited from 'Allpass.Parameter'
      -- via GeneralizedNewtypeDeriving
      deriving
         (Tuple.Undefined, Tuple.Zero, Storable,
          Functor, App.Applicative, Fold.Foldable, Trav.Traversable)

-- Phi nodes are handled by the wrapped 'Allpass.Parameter'.
instance (Tuple.Phi a) => Tuple.Phi (CascadeParameter n a) where
   phi bb (CascadeParameter p) = fmap CascadeParameter (Tuple.phi bb p)
   addPhi bb (CascadeParameter p) (CascadeParameter q) = Tuple.addPhi bb p q


-- A 'CascadeParameter' is stored exactly like its coefficient;
-- the phantom stage count @n@ does not appear in memory.
instance (Memory.C a) => Memory.C (CascadeParameter n a) where
   type Struct (CascadeParameter n a) = Memory.Struct a
   load = Memory.loadNewtype CascadeParameter
   store = Memory.storeNewtype unwrap
      where unwrap (CascadeParameter p) = p
   decompose = Memory.decomposeNewtype CascadeParameter
   compose = Memory.composeNewtype unwrap
      where unwrap (CascadeParameter p) = p

instance (Marshal.C a) => Marshal.C (CascadeParameter n a) where
   pack (CascadeParameter p) = Marshal.pack p
   unpack s = CascadeParameter (Marshal.unpack s)

-- The LLVM value of a 'CascadeParameter' is already the LLVM value of
-- the wrapped 'Parameter', hence 'id' for the value conversions.
instance (Storable.C a) => Storable.C (CascadeParameter n a) where
   load = Storable.loadNewtype CascadeParameter id
   store = Storable.storeNewtype CascadeParameter id


{-
instance LLVM.ValueTuple a => LLVM.ValueTuple (CascadeParameter n a) where
   buildTuple f = Class.buildTupleTraversable (LLVM.buildTuple f)

instance LLVM.IsTuple a => LLVM.IsTuple (CascadeParameter n a) where
   tupleDesc = Class.tupleDescFoldable
-}

-- The LLVM value of a 'CascadeParameter' is a plain 'Parameter';
-- the stage count @n@ is dropped in the value representation.
instance (Tuple.Value a) => Tuple.Value (CascadeParameter n a) where
   type ValueOf (CascadeParameter n a) = Parameter (Tuple.ValueOf a)
   valueOf (CascadeParameter p) = Tuple.valueOf p

-- A vector of 'CascadeParameter's becomes a 'Parameter' holding a vector.
instance
   (Tuple.VectorValue n a) =>
      Tuple.VectorValue n (CascadeParameter m a) where
   type VectorValueOf n (CascadeParameter m a) =
            Parameter (Tuple.VectorValueOf n a)
   vectorValueOf cps =
      fmap Tuple.vectorValueOf (Trav.traverse unwrap cps)
      where unwrap (CascadeParameter p) = p

-- Every method converts to the 'Allpass.Parameter' instance,
-- delegates, and converts the result back to a 'CascadeParameter'.
instance (MultiValue.C a) => MultiValue.C (CascadeParameter n a) where
   cons (CascadeParameter p) = cascadeFromParamValue (MultiValue.cons p)

   undef = cascadeFromParamValue MultiValue.undef
   zero = cascadeFromParamValue MultiValue.zero

   phi bb x =
      fmap cascadeFromParamValue
         (MultiValue.phi bb (paramFromCascadeValue x))
   addPhi bb x y =
      MultiValue.addPhi bb
         (paramFromCascadeValue x) (paramFromCascadeValue y)

-- Every method converts to the 'Allpass.Parameter' vector instance,
-- delegates, and converts the result back to a 'CascadeParameter'.
instance (MultiVector.C a) => MultiVector.C (CascadeParameter n a) where
   cons cps =
      cascadeFromParamVector (MultiVector.cons (fmap unwrap cps))
      where unwrap (CascadeParameter p) = p
   undef = cascadeFromParamVector MultiVector.undef
   zero = cascadeFromParamVector MultiVector.zero

   phi bb x =
      fmap cascadeFromParamVector
         (MultiVector.phi bb (paramFromCascadeVector x))
   addPhi bb x y =
      MultiVector.addPhi bb
         (paramFromCascadeVector x) (paramFromCascadeVector y)

   shuffle is x y =
      fmap cascadeFromParamVector
         (MultiVector.shuffle is
            (paramFromCascadeVector x) (paramFromCascadeVector y))
   extract i x =
      fmap cascadeFromParamValue
         (MultiVector.extract i (paramFromCascadeVector x))
   insert i e x =
      fmap cascadeFromParamVector
         (MultiVector.insert i (paramFromCascadeValue e)
            (paramFromCascadeVector x))

{-
Conversions between 'Allpass.Parameter' and 'CascadeParameter'
in multi-value and multi-vector form.
The underlying representations coincide
(both 'Tuple.ValueOf' and 'Tuple.VectorValueOf' of a 'CascadeParameter'
are defined as the corresponding 'Parameter' type),
so 'id' is all that has to be lifted.
-}

-- | Reinterpret a vector of 'Allpass.Parameter's as 'CascadeParameter's.
cascadeFromParamVector ::
   MultiVector.T n (Allpass.Parameter a) ->
   MultiVector.T n (CascadeParameter m a)
cascadeFromParamVector vec = MultiVector.lift1 id vec

-- | Inverse of 'cascadeFromParamVector'.
paramFromCascadeVector ::
   MultiVector.T n (CascadeParameter m a) ->
   MultiVector.T n (Allpass.Parameter a)
paramFromCascadeVector vec = MultiVector.lift1 id vec

-- | Reinterpret an 'Allpass.Parameter' value as a 'CascadeParameter'.
cascadeFromParamValue ::
   MultiValue.T (Allpass.Parameter a) ->
   MultiValue.T (CascadeParameter m a)
cascadeFromParamValue val = MultiValue.lift1 id val

-- | Inverse of 'cascadeFromParamValue'.
paramFromCascadeValue ::
   MultiValue.T (CascadeParameter m a) ->
   MultiValue.T (Allpass.Parameter a)
paramFromCascadeValue val = MultiValue.lift1 id val

-- Flattening to and from registers is done component-wise
-- through the 'Traversable' structure of 'CascadeParameter'.
-- Unlike 'Tuple.ValueOf', the register representation keeps
-- the 'CascadeParameter' wrapper and thus the stage count @n@.
instance (Value.Flatten a) => Value.Flatten (CascadeParameter n a) where
   type Registers (CascadeParameter n a) = CascadeParameter n (Value.Registers a)
   flattenCode = Value.flattenCodeTraversable
   unfoldCode = Value.unfoldCodeTraversable


-- A 'CascadeParameter' of vectors is treated as a vector of
-- 'CascadeParameter's; element access is lifted through 'Traversable'.
instance (Vector.Simple v) => Vector.Simple (CascadeParameter n v) where
   type Element (CascadeParameter n v) = CascadeParameter n (Vector.Element v)
   type Size (CascadeParameter n v) = Vector.Size v
   shuffleMatch = Vector.shuffleMatchTraversable
   extract = Vector.extractTraversable

-- Element insertion, likewise lifted through the 'Traversable' structure.
instance (Vector.C v) => Vector.C (CascadeParameter n v) where
   insert  = Vector.insertTraversable

-- In the functional process interface ("F") a 'CascadeParameter' counts
-- as one single argument; 'makeArgs' passes it through unchanged.
type instance F.Arguments f (CascadeParameter n a) = f (CascadeParameter n a)
instance F.MakeArguments (CascadeParameter n a) where
   makeArgs = id


{- |
LLVM version of 'flangerParameterPlain':
the frequency argument is lifted into code with 'Value.unlift1'.
The proxy selects the number of cascade stages @n@.
-}
flangerParameter ::
   (A.Transcendental a, A.RationalConstant a, TypeNum.Natural n) =>
   Proxy n -> a ->
   CodeGenFunction r (CascadeParameter n a)
flangerParameter order freq =
   Value.unlift1 (flangerParameterPlain order) freq

{- |
Cascade parameter for a flanger,
computed by 'Allpass.flangerParameter' from the number of stages
(the type-level number @n@, read from the proxy) and the given frequency.
-}
flangerParameterPlain ::
   (Trans.C a, TypeNum.Natural n) =>
   Proxy n -> a -> CascadeParameter n a
flangerParameterPlain order freq =
   let stages = TypeNum.integralFromProxy order
   in  CascadeParameter (Allpass.flangerParameter stages freq)


-- The first-order allpass as a 'Modifier.Simple',
-- taken unchanged from 'Allpass.firstOrderModifier'
-- and instantiated to 'Value.T' values.
-- The first type argument is the filter state, a pair of values
-- (the commented-out line hints at 'Allpass.State');
-- the remaining ones are presumably control, input and output types.
modifier ::
   (a ~ A.Scalar v, A.PseudoModule v, A.IntegerConstant a) =>
   Modifier.Simple
      -- (Allpass.State (Value.T v))
      (Value.T v, Value.T v)
      (Parameter (Value.T a))
      (Value.T v) (Value.T v)
modifier =
   Allpass.firstOrderModifier

{- |
First-order allpass filter as a causal process.
It is built from 'Allpass.firstOrderModifier' via 'Causal.fromModifier'.
The input is a pair of the filter parameter and the signal sample.

For an allpass cascade you may use the 'Causal.pipeline' function.
-}
causal ::
   (Causal.C process,
    A.IntegerConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v) =>
   process (Parameter a, v) v
causal =
   Causal.fromModifier modifier


-- | Turn one filter stage controlled by a 'Parameter' into a chain of @n@
-- such stages controlled by a 'CascadeParameter'.
-- The chain is built with 'Causal.replicateControlled';
-- each copy receives the 'Parameter' wrapped in the 'CascadeParameter'.
replicateStage ::
   (Causal.C process,
    TypeNum.Natural n, Tuple.Phi b, Tuple.Undefined b) =>
   Proxy n ->
   process (Parameter a, b) b ->
   process (CascadeParameter n a, b) b
replicateStage order stg =
   Causal.replicateControlled stages (stg <<< first (arr unwrap))
   where
      stages = TypeNum.integralFromProxy order
      unwrap (CascadeParameter p) = p

-- | Cascade of @n@ first-order allpass stages ('causal'),
-- all controlled by the single parameter in the 'CascadeParameter'.
-- The stage count @n@ is taken from the type of the 'CascadeParameter'.
cascade ::
   (Causal.C process,
    A.RationalConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v,
    TypeNum.Natural n) =>
   process (CascadeParameter n a, v) v
cascade =
   replicateStage Proxy causal

-- | Scale every sample by 0.5, using scalar-times-vector multiplication ('*>').
halfVector ::
   (Causal.C process, A.RationalConstant a, a ~ A.Scalar v, A.PseudoModule v) =>
   process v v
halfVector = CausalV.map (Value.fromRational' 0.5 *>)

{- |
Phaser: the input is halved,
then the output of the allpass 'cascade' is mixed with the halved dry input.
-}
phaser ::
   (Causal.C process,
    A.RationalConstant a, A.RationalConstant v,
    a ~ A.Scalar v, A.PseudoModule v, Memory.C v,
    TypeNum.Natural n) =>
   process (CascadeParameter n a, v) v
phaser =
   Causal.mix <<< wetAndDry <<< second halfVector
   where
      -- filtered signal alongside the unfiltered (dry) signal
      wetAndDry = cascade &&& arr snd


-- | Split a multi-value 'CascadeParameter' into an 'Allpass.Parameter'
-- of multi-values, so the scalar filter stage can consume it.
paramFromCascadeParam ::
   MultiValue.T (CascadeParameter n a) ->
   Allpass.Parameter (MultiValue.T a)
paramFromCascadeParam mv =
   case mv of
      MultiValue.Cons p -> fmap MultiValue.Cons p

{-
One vectorized step of 'cascadePipeline'.
The scalar allpass 'causal' is applied lane-wise (via 'Causal.vectorize')
to a vector of (cascade parameter, sample) pairs.
Each output lane pairs the unchanged parameter with the filtered sample,
so the parameter moves through the pipeline together with its signal.
The proxy argument only fixes the vector size @n@,
which is also the order of the cascade parameter.

The generated code should be straightforward to turn into vector operations,
but LLVM-2.6 does not do that yet.
-}
stage ::
   (Causal.C process,
    TypeNum.Positive n, MultiVector.C a,
    MultiVector.T n (CascadeParameter n a, a) ~ v,
    MultiValue.PseudoRing a, MultiValue.IntegerConstant a,
    Marshal.C a) =>
   Proxy n -> process v v
stage _ =
   Causal.vectorize $
      -- re-attach the (unchanged) parameter to the filtered sample
      uncurry MultiValue.zip
      ^<<
      (arr fst &&&
       (Scalar.decons
        ^<<
        causal
        <<^
        -- wrap parameter and sample as 'Scalar.T' values for 'causal'
        (\(p, v) ->
           (fmap Scalar.Cons $ paramFromCascadeParam p, Scalar.Cons v))))
      <<^
      MultiValue.unzip

-- | Apply a function to a 'Proxy' for the stage count @n@.
-- The proxy's type is fixed by the process' input type,
-- so @n@ never has to be written down explicitly.
withSize ::
   (Proxy n -> process (MultiValue.T (CascadeParameter n a), b) c) ->
   process (MultiValue.T (CascadeParameter n a), b) c
withSize = ($ Proxy)

{- |
Fast implementation of 'cascade' using vector instructions.

The scalar allpass stage is vectorized to the cascade order @n@
and the vectorized stages are chained with 'Causal.pipeline'.
This comes with some restrictions:

* there must be at least one pipeline stage (@n@ is 'TypeNum.Positive'),

* the element type must be primitive (usable in a 'MultiVector.T'),

* the output is delayed by the number of pipeline stages.
-}
cascadePipeline ::
   (Causal.C process,
    TypeNum.Positive n, MultiVector.C a,
    MultiValue.PseudoRing a, MultiValue.IntegerConstant a,
    Marshal.C a, Marshal.Vector n a) =>
   process
      (MultiValue.T (CascadeParameter n a), MultiValue.T a)
      (MultiValue.T a)
cascadePipeline = withSize $ \order ->
   -- pipeline elements are (parameter, sample) pairs;
   -- only the filtered sample is returned
   MultiValue.snd
   ^<<
   Causal.pipeline (stage order)
   <<^
   uncurry MultiValue.zip

-- | Identity process on vectors of size @n@.
-- The proxy only fixes @n@, for use with 'Causal.pipeline'.
vectorId ::
   (Causal.C process) =>
   Proxy n -> process (MultiVector.T n a) (MultiVector.T n a)
vectorId _size = Cat.id

-- | Scale every sample by 0.5.
half ::
   (Causal.C process, A.RationalConstant a, A.PseudoRing a) =>
   process a a
half = CausalV.map (Value.fromRational' 0.5 *)


-- | Wrap a plain LLVM value as a 'MultiValue.T'
-- (possible because its 'Tuple.ValueOf' representation is that plain value).
multiValue ::
   (Tuple.ValueOf a ~ LLVM.Value a) =>
   LLVM.Value a -> MultiValue.T a
multiValue x = MultiValue.Cons x

-- | Inverse of 'multiValue'.
unmultiValue ::
   (Tuple.ValueOf a ~ LLVM.Value a) =>
   MultiValue.T a -> LLVM.Value a
unmultiValue mv =
   case mv of
      MultiValue.Cons x -> x

-- | Wrap a 'CascadeParameter' of plain LLVM values as a 'MultiValue.T';
-- its representation is the wrapped 'Parameter' of values.
multiCascadeParam ::
   (Tuple.ValueOf a ~ LLVM.Value a) =>
   CascadeParameter n (LLVM.Value a) ->
   MultiValue.T (CascadeParameter n a)
multiCascadeParam cp =
   case cp of
      CascadeParameter p -> MultiValue.Cons p

{- |
'phaserPipelineMulti' for plain 'LLVM.Value's:
parameter and input are wrapped as 'MultiValue.T', the output is unwrapped.
-}
phaserPipeline ::
   (Causal.C process,
    TypeNum.Positive n,
    MultiValue.PseudoRing a, MultiValue.RationalConstant a,
    Marshal.C a, Marshal.Vector n a, MultiVector.C a,
    Tuple.ValueOf a ~ LLVM.Value a) =>
   process
      (CascadeParameter n (LLVM.Value a), LLVM.Value a)
      (LLVM.Value a)
phaserPipeline =
   arr unmultiValue
   <<< phaserPipelineMulti
   <<< arr (mapPair (multiCascadeParam, multiValue))


{- |
Phaser built from 'cascadePipeline'.
The input is halved, and the allpass cascade output is mixed with
the halved dry input.
Because 'cascadePipeline' delays its output by the number of stages,
the dry signal is sent through a 'Causal.pipeline' of identity stages
with the same size, so that both branches are aligned in time.
-}
phaserPipelineMulti ::
   (Causal.C process,
    TypeNum.Positive n,
    MultiValue.PseudoRing a, MultiValue.RationalConstant a,
    Marshal.C a, Marshal.Vector n a, MultiVector.C a) =>
   process
      (MultiValue.T (CascadeParameter n a), MultiValue.T a)
      (MultiValue.T a)
phaserPipelineMulti = withSize $ \order ->
   Causal.mix <<<
   cascadePipeline &&&
   -- dry branch, delayed like the cascade
   (Causal.pipeline (vectorId order) <<^ snd) <<<
   -- alternative dry branch using an explicit delay:
--   (Causal.delay (const zero) (const $ TypeNum.integralFromProxy order) <<^ snd) <<<
   second half


{- |
First-order allpass filter for serial (packed) vectors.

It is split into a non-recursive part,
which computes @x[t-1] + k*x[t]@ chunk-wise,
and a recursive part, 'Filt1L.causalRecursivePacked',
which is controlled by the negated coefficient @-k@.
-}
causalPacked,
  causalNonRecursivePacked ::
   (Causal.C process,
    Serial.C v, Serial.Element v ~ a,
    Memory.C a, A.IntegerConstant a,
    A.PseudoRing v, A.PseudoRing a) =>
   process (Parameter a, v) v

causalPacked =
   Filt1L.causalRecursivePacked <<<
   -- parameter for the recursive part: negated allpass coefficient
   (Causal.map
       (\(Parameter k, _) ->
           fmap Filt1.Parameter $ A.neg k) &&&
    causalNonRecursivePacked)

causalNonRecursivePacked =
   Causal.mapAccum
      (\(Parameter k, v0) x1 -> do
         -- v1: current chunk delayed by one sample; presumably shiftUp
         -- inserts x1 (last sample of the previous chunk) at the bottom
         (_,v1) <- Serial.shiftUp x1 v0
         -- y = x[t-1] + k*x[t], with k presumably broadcast to all lanes
         y <- A.add v1 =<< A.mul v0 =<< Serial.upsample k
         -- keep the last sample of this chunk as state for the next chunk
         let size = fromIntegral $ Serial.size v0
         u0 <- Serial.extract (LLVM.valueOf $ size - 1) v0
         return (y, u0))
      -- before the first chunk the previous sample is zero
      (return A.zero)

{- |
Allpass cascade for serial (packed) vectors:
@n@ copies of 'causalPacked', all controlled by the single parameter
in the 'CascadeParameter'.

Unlike 'phaserPacked', this does not need @A.RationalConstant v@,
so the constraint is not required here.
-}
cascadePacked ::
   (Causal.C process,
    TypeNum.Natural n,
    Serial.C v, Serial.Element v ~ a,
    A.PseudoRing a, A.IntegerConstant a, Memory.C a,
    A.PseudoRing v) =>
   process (CascadeParameter n a, v) v
cascadePacked =
   replicateStage Proxy causalPacked

{- |
Phaser for serial (packed) vectors:
the input is halved, then the output of 'cascadePacked'
is mixed with the halved dry input.
-}
phaserPacked ::
   (Causal.C process,
    TypeNum.Natural n,
    Serial.C v, Serial.Element v ~ a,
    A.PseudoRing a, A.IntegerConstant a, Memory.C a,
    A.PseudoRing v, A.RationalConstant v) =>
   process (CascadeParameter n a, v) v
phaserPacked =
   Causal.mix <<<
   cascadePacked &&& arr snd <<<
   -- the 0.5 constant is what requires A.RationalConstant v
   second (Causal.map (A.mul (A.fromRational' 0.5)))





-- Deprecated variants with the process type fixed to 'CausalP.T'.
-- Each one is just the corresponding generic process;
-- they remain only for backwards compatibility (see the pragmas below).

-- | 'causal' specialised to 'CausalP.T'.
causalP ::
   (A.RationalConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v) =>
   CausalP.T p (Parameter a, v) v
causalP = causal

-- | 'cascade' specialised to 'CausalP.T'.
cascadeP ::
   (A.RationalConstant a, a ~ A.Scalar v, A.PseudoModule v, Memory.C v,
    TypeNum.Natural n) =>
   CausalP.T p (CascadeParameter n a, v) v
cascadeP = cascade

-- | 'phaser' specialised to 'CausalP.T'.
phaserP ::
   (A.RationalConstant a, A.RationalConstant v,
    a ~ A.Scalar v, A.PseudoModule v, Memory.C v,
    TypeNum.Natural n) =>
   CausalP.T p (CascadeParameter n a, v) v
phaserP = phaser


-- | 'causalPacked' specialised to 'CausalP.T'.
causalPackedP ::
   (Serial.C v, Serial.Element v ~ a,
    Memory.C a, A.IntegerConstant a,
    A.PseudoRing v, A.PseudoRing a) =>
   CausalP.T p (Parameter a, v) v
causalPackedP = causalPacked

-- | 'cascadePacked' and 'phaserPacked' specialised to 'CausalP.T'.
cascadePackedP, phaserPackedP ::
   (TypeNum.Natural n,
    Serial.C v, Serial.Element v ~ a,
    A.PseudoRing a, A.IntegerConstant a, Memory.C a,
    A.PseudoRing v, A.RationalConstant v) =>
   CausalP.T p (CascadeParameter n a, v) v
cascadePackedP = cascadePacked
phaserPackedP = phaserPacked

{-# DEPRECATED causalP          "use 'causal' instead" #-}
{-# DEPRECATED cascadeP         "use 'cascade' instead" #-}
{-# DEPRECATED phaserP          "use 'phaser' instead" #-}
{-# DEPRECATED causalPackedP    "use 'causalPacked' instead" #-}
{-# DEPRECATED cascadePackedP   "use 'cascadePacked' instead" #-}
{-# DEPRECATED phaserPackedP    "use 'phaserPacked' instead" #-}