{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Synthesizer.LLVM.Filter.Allpass (
   Parameter, parameter,
   CascadeParameter, flangerParameter, flangerParameterPlain,
   causalP, cascadeP, phaserP,
   cascadePipelineP, phaserPipelineP,
   causalPackedP, cascadePackedP, phaserPackedP,
   ) where

import Synthesizer.Plain.Filter.Recursive.Allpass (Parameter(Parameter), )
import qualified Synthesizer.Plain.Filter.Recursive.Allpass as Allpass
import qualified Synthesizer.Plain.Filter.Recursive.FirstOrder as Filt1

import qualified Synthesizer.LLVM.Filter.FirstOrder as Filt1L
import qualified Synthesizer.Plain.Modifier as Modifier

import qualified Synthesizer.LLVM.CausalParameterized.Process as CausalP
import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Representation as Rep
import qualified Synthesizer.LLVM.Simple.Value as Value

import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Core as LLVM
import LLVM.Core
   (Value, valueOf, Vector,
    IsPowerOf2, IsConst, IsArithmetic, IsPrimitive, IsFirstClass, IsFloating, IsSized,
    Undefined, undefTuple,
    CodeGenFunction, )
import LLVM.Util.Loop (Phi, phis, addPhis, )

import qualified Data.TypeLevel.Num      as TypeNum
import qualified Data.TypeLevel.Num.Sets as TypeSet

import Foreign.Storable (Storable, )

import qualified Control.Category as Cat
import qualified Control.Applicative as App
import qualified Data.Foldable as Fold
import qualified Data.Traversable as Trav
import Control.Arrow ((<<<), (^<<), (<<^), (&&&), arr, first, second, )

import qualified Algebra.Transcendental as Trans
import qualified Algebra.Field as Field
import qualified Algebra.Module as Module
import qualified Algebra.Ring as Ring

import NumericPrelude.Numeric
import NumericPrelude.Base


-- Orphan instances (the module disables orphan warnings) that make the
-- plain 'Parameter' type imported from synthesizer-core usable with
-- LLVM values.  'Parameter' wraps a single value (cf. the @Parameter k@
-- patterns below), and every method simply delegates to a generic
-- helper (…Traversable, …Foldable, …Functor, …Pointed, …Newtype) from
-- llvm-extra or "Synthesizer.LLVM.Simple.Value".

-- | Loop phi handling for the wrapped value.
instance (Phi a) => Phi (Parameter a) where
   phis = Class.phisTraversable
   addPhis = Class.addPhisFoldable

instance Undefined a => Undefined (Parameter a) where
   undefTuple = Class.undefTuplePointed

instance Class.Zero a => Class.Zero (Parameter a) where
   zeroTuple = Class.zeroTuplePointed

-- | A 'Parameter' uses the same memory representation @s@ as the value
-- it wraps; the newtype helpers only add/remove the constructor.
instance
      (Rep.Memory a s, IsSized s ss) =>
      Rep.Memory (Parameter a) s where
   load = Rep.loadNewtype Parameter
   store = Rep.storeNewtype (\(Parameter k) -> k)
   decompose = Rep.decomposeNewtype Parameter
   compose = Rep.composeNewtype (\(Parameter k) -> k)


instance LLVM.ValueTuple a => LLVM.ValueTuple (Parameter a) where
   buildTuple f = Class.buildTupleTraversable (LLVM.buildTuple f)

instance LLVM.IsTuple a => LLVM.IsTuple (Parameter a) where
   tupleDesc = Class.tupleDescFoldable

-- | Converting a Haskell-side parameter to its LLVM-value counterpart
-- converts the wrapped value.
instance (LLVM.MakeValueTuple ah al) =>
      LLVM.MakeValueTuple (Parameter ah) (Parameter al) where
   valueTupleOf = Class.valueTupleOfFunctor


-- | Flattening between 'Value.T' and LLVM values, element by element.
instance (Value.Flatten ah al) =>
      Value.Flatten (Parameter ah) (Parameter al) where
   flatten = Value.flattenTraversable
   unfold =  Value.unfoldFunctor


-- | Vector shuffles and element access are applied to the wrapped value
-- (needed e.g. when a parameter is vectorized).
instance (Vector.ShuffleMatch n v) =>
      Vector.ShuffleMatch n (Parameter v) where
   shuffleMatch = Vector.shuffleMatchTraversable

instance (Vector.Access n a v) =>
      Vector.Access n (Parameter a) (Parameter v) where
   insert  = Vector.insertTraversable
   extract = Vector.extractTraversable


{- |
Compute an allpass 'Parameter' while generating code.

Both arguments are lifted with 'Value.constantValue', passed to the plain
'Allpass.parameter' from synthesizer-core, and the result is flattened
back into LLVM values.
(Presumably @phase@ is the desired phase shift and @freq@ the frequency at
which it occurs; see 'Allpass.parameter' for the exact meaning.)
-}
parameter ::
   (Trans.C a, IsConst a, IsFloating a) =>
   Value a -> Value a ->
   CodeGenFunction r (Parameter (Value a))
parameter phase freq =
   Value.flatten (Allpass.parameter phaseV freqV)
   where
      phaseV = Value.constantValue phase
      freqV  = Value.constantValue freq


{- |
Parameter of a cascade of first-order allpass stages
that all share the same 'Allpass.Parameter'.

@n@ is a phantom type-level number giving the number of stages;
it is read with 'TypeNum.toInt' (see 'flangerParameter' and 'replicateStage').
The listed classes are derived from the wrapped 'Allpass.Parameter'
via GeneralizedNewtypeDeriving.
-}
newtype CascadeParameter n a =
   CascadeParameter (Allpass.Parameter a)
      deriving
         (Phi, Undefined, Class.Zero, Storable,
          Functor, App.Applicative, Fold.Foldable, Trav.Traversable)

-- The following instances mirror the 'Parameter' instances of this module:
-- each delegates to a generic container-based or newtype helper.

-- | Same memory representation @s@ as the wrapped value.
instance
      (Rep.Memory a s, IsSized s ss) =>
      Rep.Memory (CascadeParameter n a) s where
   load = Rep.loadNewtype CascadeParameter
   store = Rep.storeNewtype (\(CascadeParameter k) -> k)
   decompose = Rep.decomposeNewtype CascadeParameter
   compose = Rep.composeNewtype (\(CascadeParameter k) -> k)


instance LLVM.ValueTuple a => LLVM.ValueTuple (CascadeParameter n a) where
   buildTuple f = Class.buildTupleTraversable (LLVM.buildTuple f)

instance LLVM.IsTuple a => LLVM.IsTuple (CascadeParameter n a) where
   tupleDesc = Class.tupleDescFoldable

instance (LLVM.MakeValueTuple ah al) =>
      LLVM.MakeValueTuple (CascadeParameter n ah) (CascadeParameter n al) where
   valueTupleOf = Class.valueTupleOfFunctor


instance (Value.Flatten ah al) =>
      Value.Flatten (CascadeParameter n ah) (CascadeParameter n al) where
   flatten = Value.flattenTraversable
   unfold =  Value.unfoldFunctor


-- | Here the class parameter is named @m@ to keep it apart from the stage
-- count @n@ (presumably @m@ is the vector size; confirm in llvm-extra).
instance (Vector.ShuffleMatch m v) =>
      Vector.ShuffleMatch m (CascadeParameter n v) where
   shuffleMatch = Vector.shuffleMatchTraversable

instance (Vector.Access m a v) =>
      Vector.Access m (CascadeParameter n a) (CascadeParameter n v) where
   insert  = Vector.insertTraversable
   extract = Vector.extractTraversable


{- |
Code-generation-time counterpart of 'flangerParameterPlain'.

The frequency is lifted to a 'Value.T', turned into a cascade parameter
for @n@ stages (the count comes from the type of @order@; its value is
never inspected), and finally flattened into LLVM values.
-}
flangerParameter ::
   (Trans.C a, IsConst a, IsFloating a, TypeNum.Nat n) =>
   n -> Value a ->
   CodeGenFunction r (CascadeParameter n (Value a))
flangerParameter order freq =
   Value.flatten $ flangerParameterPlain order $ Value.constantValue freq

{- |
Cascade parameter for a flanger built from @n@ allpass stages,
computed on plain (non-LLVM) numbers by 'Allpass.flangerParameter'.
The stage count is taken from the type-level number @n@.
-}
flangerParameterPlain ::
   (Trans.C a, TypeNum.Nat n) =>
   n -> a -> CascadeParameter n a
flangerParameterPlain order =
   CascadeParameter . Allpass.flangerParameter (TypeNum.toInt order)


{- |
One first-order allpass step as a 'Modifier.Simple' on 'Value.T' values.

This is just 'Allpass.firstOrderModifier' at the types used in this module.
The state is a pair of values; the original code also mentioned
@Allpass.State (Value.T v)@ here as an alternative way to spell the state.
-}
modifier ::
   (IsConst a, IsArithmetic a, Module.C (Value.T a) (Value.T v)) =>
   Modifier.Simple
      (Value.T v, Value.T v)          -- state
      (Parameter (Value.T a))         -- control
      (Value.T v) (Value.T v)         -- input, output
modifier = Allpass.firstOrderModifier

{- |
A single first-order allpass stage as a causal process.

Each input pairs the filter parameter with a signal sample,
and the output is the filtered sample.
The process is 'modifier' lifted by 'CausalP.fromModifier'.
For a cascade of such stages see 'cascadeP'.
'CausalP.pipeline' can also be used to build cascades;
'cascadePipelineP' is built that way.
-}
causalP ::
   (Field.C a, Module.C (Value.T a) (Value.T v),
    IsFirstClass a, IsSized a as, IsConst a,
    IsFirstClass v, IsSized v vs, IsConst v,
    IsArithmetic a) =>
   CausalP.T p
      (Parameter (Value a), Value v) (Value v)
causalP = CausalP.fromModifier modifier


{- |
Turn a single-stage process into a cascade of @n@ copies
controlled by one shared 'CascadeParameter'.

@order@ only serves as a type witness; its count is obtained with
'TypeNum.toInt'.  Every copy receives the unwrapped 'Parameter'.
How control and signal are routed between the copies is up to
'CausalP.replicateControlled' (presumably the control is shared and the
signal flows from one copy to the next).
-}
replicateStage ::
   (TypeNum.Nat n) =>
   n ->
   CausalP.T p (Parameter a, b) b ->
   CausalP.T p (CascadeParameter n a, b) b
replicateStage order stg =
   let copies = TypeNum.toInt order
       unwrapCascade (CascadeParameter p) = p
   in  CausalP.replicateControlled copies (stg <<< first (arr unwrapCascade))

{- |
Cascade of @n@ identical first-order allpass stages ('causalP'),
all driven by the same parameter.
The stage count is the type-level number @n@.
-}
cascadeP ::
   (Field.C a, Module.C (Value.T a) (Value.T v),
    IsFirstClass a, IsSized a as, IsConst a,
    IsFirstClass v, IsSized v vs, IsConst v,
    IsArithmetic a,
    TypeNum.Nat n) =>
   CausalP.T p
      (CascadeParameter n (Value a), Value v) (Value v)
cascadeP =
   withSize $ \order -> replicateStage order causalP

{- |
Scale the signal sample by one half.

The parameter value is ignored; it is only used to fix the scalar type
of the literal @0.5@ to the parameter's element type.
-}
half ::
   (Field.C a, Module.C (Value.T a) (Value.T v),
    IsFirstClass a, IsSized a as, IsConst a,
    IsFirstClass v, IsSized v vs, IsConst v,
    IsFloating a, IsArithmetic v,
    TypeNum.Nat n) =>
   CausalP.T p
      (CascadeParameter n (Value a), Value v) (Value v)
half =
   CausalP.mapSimple $ \(p, x) ->
      let -- 'const' at this type only ties the scalar to the parameter type
          sameTypeAs :: Value.T a -> CascadeParameter n (Value a) -> Value.T a
          sameTypeAs = const
      in  Value.decons (sameTypeAs 0.5 p *> Value.constantValue x)

{- |
Phaser: the input is first halved ('half').
Then the allpass cascade output ('cascadeP') and the halved dry signal
are combined with 'CausalP.mix' (presumably summed).
-}
phaserP ::
   (Field.C a, Module.C (Value.T a) (Value.T v),
    IsFirstClass a, IsSized a as, IsConst a,
    IsFirstClass v, IsSized v vs, IsConst v,
    IsFloating a, IsArithmetic v,
    TypeNum.Nat n) =>
   CausalP.T p
      (CascadeParameter n (Value a), Value v) (Value v)
phaserP =
   CausalP.mix
   <<< (cascadeP &&& arr snd)
   <<< (arr fst &&& half)


{- |
One stage of 'cascadePipelineP'.

This is a single allpass step lifted with 'CausalP.vectorize'.
The first output forwards the cascade parameter unchanged for use by the
next stage; the second output is the filtered signal.
The ignored argument only fixes the vector size @n@.

Original author's note: emitting real vector instructions for this code
should not be hard, but LLVM 2.6 did not do so on its own.
-}
stage ::
   (IsPowerOf2 n, IsPrimitive a, IsFirstClass a,
    IsConst a, IsArithmetic a, Ring.C a,
    IsSized a sa) =>
   n ->
   CausalP.T p
      (CascadeParameter n (Value (Vector n a)), Value (Vector n a))
      (CascadeParameter n (Value (Vector n a)), Value (Vector n a))
stage _ =
   CausalP.vectorize $
      arr fst
      &&&
      (CausalP.fromModifier modifier <<< first (arr unwrapCascade))
   where
      unwrapCascade (CascadeParameter p) = p

{- |
Hand a type-level size witness for @n@ to a process builder.

Callers in this module only use the argument as a type witness,
either through 'TypeNum.toInt' or as an ignored proxy argument.
The witness must therefore never be evaluated.
If it is evaluated anyway, the 'error' below says where the bogus value
came from, whereas a bare 'undefined' would give an anonymous failure.
-}
withSize ::
   (n -> CausalP.T p (CascadeParameter n a, b) c) ->
   CausalP.T p (CascadeParameter n a, b) c
withSize f =
   f (error "Synthesizer.LLVM.Filter.Allpass.withSize: size proxy must not be evaluated")

{- |
Faster variant of 'cascadeP' that runs the @n@ stages as a vectorized
pipeline ('CausalP.pipeline' over 'stage').

Restrictions:

* the number of stages @n@ must be a power of two,

* the element type must be primitive,

* the output is delayed by the number of pipeline stages.
-}
cascadePipelineP ::
   (Field.C a, IsFirstClass a, IsSized a as,
    TypeNum.Mul n as vas, TypeSet.Pos vas,
--    IsSized (Vector n a) vas,
    IsPowerOf2 n,
    IsArithmetic a, IsPrimitive a, IsConst a) =>
   CausalP.T p
      (CascadeParameter n (Value a), Value a) (Value a)
cascadePipelineP =
   withSize $ \order ->
      arr snd <<< CausalP.pipeline (stage order)

-- | Identity process whose proxy argument fixes the vector size @n@
-- linking the element type @a@ and the vector type @v@.
vectorId ::
   (Vector.Access n a v) =>
   n -> CausalP.T p v v
vectorId = const Cat.id

{- |
Pipelined phaser: like 'phaserP', but the wet path is 'cascadePipelineP'.

The halved dry signal goes through an identity pipeline of the same vector
size ('vectorId'), presumably so that it gets the same delay as the
pipelined cascade before both are combined by 'CausalP.mix'.
-}
phaserPipelineP ::
   (Field.C a,
    IsFirstClass a, IsSized a as,
    IsSized (Vector n a) vas,
    TypeNum.Mul n as vas,
    IsPowerOf2 n,
    IsFloating a, IsPrimitive a, IsConst a) =>
   CausalP.T p
      (CascadeParameter n (Value a), Value a) (Value a)
phaserPipelineP =
   withSize $ \order ->
      CausalP.mix
      <<< (cascadePipelineP &&& (CausalP.pipeline (vectorId order) <<^ snd))
      -- previously considered dry path instead of the identity pipeline:
      -- (CausalP.delay (const zero) (const $ TypeNum.toInt order) <<^ snd)
      <<< (arr fst &&& half)


{- |
'causalP' for packed signals: each step processes one vector of samples.

'causalNonRecursivePackedP' computes the feed-forward part of each chunk.
Assuming 'Vector.shiftUp' inserts its first argument at index 0 and moves
the rest up by one, the result is @y[i] = x[i-1] + k*x[i]@.
@x[-1]@ is the last sample of the previous chunk, carried in the
accumulator, which starts at zero.
Its output, together with the negated parameter @-k@, is fed to
'Filt1L.causalRecursivePackedP', which supplies the recursive part.
-}
causalPackedP,
  causalNonRecursivePackedP ::
   (Ring.C a,
    IsFirstClass a, IsArithmetic a, IsConst a,
    IsPowerOf2 n, IsPrimitive a, IsSized a as) =>
   CausalP.T p
      (Parameter (Value a), Value (Vector n a)) (Value (Vector n a))
causalPackedP =
   Filt1L.causalRecursivePackedP
   <<< (CausalP.mapSimple negateParam &&& causalNonRecursivePackedP)
   where
      -- first-order recursive filter parameter: the negated allpass coefficient
      negateParam (Parameter k, _) = fmap Filt1.Parameter (LLVM.neg k)

causalNonRecursivePackedP =
   CausalP.mapAccumSimple
      (\(Parameter k, chunk) carried -> do
         (_, delayed) <- Vector.shiftUp carried chunk
         kChunk <- SoV.replicate k
         weighted <- A.mul chunk kChunk
         y <- A.add delayed weighted
         -- the last element of this chunk is carried into the next step
         let size = fromIntegral $ Vector.sizeInTuple chunk
         newCarried <- Vector.extract (valueOf $ size - 1) chunk
         return (y, newCarried))
      (return (LLVM.value LLVM.zero))

{- |
Packed counterparts of 'cascadeP' and 'phaserP', working on vectors of
@m@ samples.  The stage count is the type-level number @n@.

'phaserPackedP' halves the input vector first (multiplication by a
replicated 0.5).  It then combines the cascade output with the halved dry
signal using 'CausalP.mix'.
-}
cascadePackedP, phaserPackedP ::
   (Field.C a,
    IsFirstClass a, IsArithmetic a, IsConst a,
    IsPowerOf2 m, IsPrimitive a, IsSized a as,
    TypeNum.Nat n) =>
   CausalP.T p
      (CascadeParameter n (Value a), Value (Vector m a)) (Value (Vector m a))
cascadePackedP =
   withSize $ \order -> replicateStage order causalPackedP

phaserPackedP =
   CausalP.mix
   <<< (cascadePackedP &&& arr snd)
   <<< second (CausalP.mapSimple (A.mul (SoV.replicateOf 0.5)))