{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
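{- |
Exponential curve with controllable decay.
The decay rate is given as a half-life measured in samples:
after @halfLife@ samples the curve has fallen to half its value.
-}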
module Synthesizer.LLVM.Generator.Exponential2 (
   Parameter,
   parameter,
   parameterPlain,
   causalP,

   ParameterPacked,
   parameterPacked,
   parameterPackedPlain,
   causalPackedP,
   ) where

import qualified Synthesizer.LLVM.CausalParameterized.Process as CausalP
import qualified Synthesizer.LLVM.Simple.Value as Value
import qualified Synthesizer.LLVM.Parameter as Param

import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Representation as Rep

import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Core as LLVM
import LLVM.Core
   (Value, valueOf, Vector,
    IsPowerOf2, IsConst, IsArithmetic, IsPrimitive, IsFirstClass, IsFloating, IsSized,
    Undefined, undefTuple,
    CodeGenFunction, )
import LLVM.Util.Loop (Phi, phis, addPhis, )

import qualified Data.TypeLevel.Num      as TypeNum
import qualified Data.TypeLevel.Num.Sets as TypeSet

import Foreign.Storable (Storable, )

import qualified Control.Applicative as App
import qualified Data.Foldable as Fold
import qualified Data.Traversable as Trav
import Control.Applicative (liftA2, (<*>), )
import Control.Arrow ((^<<), )

import qualified Algebra.Transcendental as Trans
import qualified Algebra.Field as Field
import qualified Algebra.Ring as Ring

import NumericPrelude.Numeric
import NumericPrelude.Base


-- | Decay parameter of the scalar exponential curve.
--
-- Wraps the factor by which the curve is multiplied once per sample;
-- see 'parameterPlain' and 'causalP'.
newtype Parameter a = Parameter a


-- Apply the function to the single wrapped value.
instance Functor Parameter where
   {-# INLINE fmap #-}
   fmap g (Parameter x) = Parameter $ g x

-- Parameter is a one-field container: 'pure' is just the constructor and
-- '<*>' applies the wrapped function to the wrapped argument.
instance App.Applicative Parameter where
   {-# INLINE pure #-}
   pure = Parameter
   {-# INLINE (<*>) #-}
   Parameter g <*> Parameter x = Parameter $ g x

-- Folding visits exactly the one wrapped value.
instance Fold.Foldable Parameter where
   {-# INLINE foldMap #-}
   foldMap g (Parameter x) = g x

-- Traversal runs the effect for the single field and re-wraps its result;
-- 'sequenceA' follows from the class default (traverse id).
instance Trav.Traversable Parameter where
   {-# INLINE traverse #-}
   traverse g (Parameter x) = fmap Parameter (g x)


-- Phi handling delegates to the generic helpers from LLVM.Extra.Class.
instance (Phi a) => Phi (Parameter a) where
   addPhis bb = Class.addPhisFoldable bb
   phis bb = Class.phisTraversable bb

-- Undefined value for a Parameter, built by the generic
-- Class.undefTuplePointed helper.
instance Undefined a => Undefined (Parameter a) where
   undefTuple = Class.undefTuplePointed

-- Zero value for a Parameter, built by the generic
-- Class.zeroTuplePointed helper.
instance Class.Zero a => Class.Zero (Parameter a) where
   zeroTuple = Class.zeroTuplePointed

-- A Parameter uses the memory representation @s@ of the wrapped value:
-- reading (load/decompose) wraps with the constructor,
-- writing (store/compose) unwraps first.
instance
      (Rep.Memory a s, IsSized s ss) =>
      Rep.Memory (Parameter a) s where
   load = Rep.loadNewtype Parameter
   decompose = Rep.decomposeNewtype Parameter
   store = Rep.storeNewtype (\(Parameter x) -> x)
   compose = Rep.composeNewtype (\(Parameter x) -> x)


-- Delegates to the element's LLVM.buildTuple via Class.buildTupleTraversable.
instance LLVM.ValueTuple a => LLVM.ValueTuple (Parameter a) where
   buildTuple fn = Class.buildTupleTraversable $ LLVM.buildTuple fn

-- Type description delegates to the generic Foldable-based helper.
instance LLVM.IsTuple a => LLVM.IsTuple (Parameter a) where
   tupleDesc p = Class.tupleDescFoldable p

-- Converts a Haskell-side Parameter to its LLVM value counterpart
-- element-wise, using the generic Functor-based helper.
instance (LLVM.MakeValueTuple ah al) =>
      LLVM.MakeValueTuple (Parameter ah) (Parameter al) where
   valueTupleOf p = Class.valueTupleOfFunctor p


-- Flattening and unfolding are delegated to the generic
-- Traversable/Functor helpers of Synthesizer.LLVM.Simple.Value.
instance (Value.Flatten ah al) =>
      Value.Flatten (Parameter ah) (Parameter al) where
   unfold p = Value.unfoldFunctor p
   flatten p = Value.flattenTraversable p


-- Shuffling is applied to the wrapped vector via the generic helper.
instance (Vector.ShuffleMatch n v) =>
      Vector.ShuffleMatch n (Parameter v) where
   shuffleMatch arg = Vector.shuffleMatchTraversable arg

-- Lane access (extract/insert) passes through to the wrapped vector.
instance (Vector.Access n a v) =>
      Vector.Access n (Parameter a) (Parameter v) where
   extract i = Vector.extractTraversable i
   insert i = Vector.insertTraversable i


-- | Compute the per-sample decay factor in generated LLVM code,
-- from a half-life (in samples) given as an LLVM value.
-- The formula is shared with 'parameterPlain'.
parameter ::
   (Trans.C a, IsConst a, IsFloating a) =>
   Value a ->
   CodeGenFunction r (Parameter (Value a))
parameter =
   Value.flatten . parameterPlain . Value.constantValue

-- | Per-sample decay factor @0.5 ** recip halfLife@.
--
-- Multiplying a value by this factor @halfLife@ times halves it;
-- 'causalP' applies it once per sample.
parameterPlain ::
   (Trans.C a) =>
   a -> Parameter a
parameterPlain halfLife =
   Parameter factor
   where
      factor = 0.5 ** recip halfLife


-- | Scalar exponential curve starting at @initial@.
--
-- Each sample outputs the current state. The state is then multiplied by
-- the decay factor that arrives on the input for that sample.
causalP ::
   (IsFirstClass a, IsSized a size,
    IsArithmetic a, IsConst a,
    Storable a, LLVM.MakeValueTuple a (Value a)) =>
   Param.T p a ->
   CausalP.T p (Parameter (Value a)) (Value a)
causalP initial =
   CausalP.mapAccum step return (return ()) initial
   where
      -- Returns (output, next state): emit the old state, carry the decayed one.
      step () (Parameter decay) state = do
         decayed <- A.mul decay state
         return (state, decayed)


-- | Decay parameters for the vectorised curve ('causalPackedP'),
-- which produces one vector chunk of samples per step.
data ParameterPacked a =
   ParameterPacked {
      -- | Factor multiplied into the state once per vector chunk.
      ppFeedback :: a,
      -- | Per-lane factors that turn the state into one chunk of output.
      ppCurrent :: a
   }


-- Map over both fields. The lazy pattern keeps this as non-strict in the
-- argument as an accessor-based definition would be.
instance Functor ParameterPacked where
   {-# INLINE fmap #-}
   fmap g ~(ParameterPacked feedback current) =
      ParameterPacked (g feedback) (g current)

-- Field-wise (zip-like) applicative: 'pure' fills both fields and '<*>'
-- applies each function field to the corresponding argument field.
-- Lazy patterns avoid forcing either constructor.
instance App.Applicative ParameterPacked where
   {-# INLINE pure #-}
   pure x = ParameterPacked {ppFeedback = x, ppCurrent = x}
   {-# INLINE (<*>) #-}
   ~(ParameterPacked f0 f1) <*> ~(ParameterPacked x0 x1) =
      ParameterPacked (f0 x0) (f1 x1)

-- Folding is derived from the Traversable instance (feedback, then current).
instance Fold.Foldable ParameterPacked where
   {-# INLINE foldMap #-}
   foldMap g = Trav.foldMapDefault g

-- Run the feedback effect first, then the current effect,
-- and rebuild the record from both results.
instance Trav.Traversable ParameterPacked where
   {-# INLINE sequenceA #-}
   sequenceA ~(ParameterPacked feedback current) =
      liftA2 ParameterPacked feedback current


-- Phi handling delegates to the generic helpers from LLVM.Extra.Class.
instance (Phi a) => Phi (ParameterPacked a) where
   addPhis bb = Class.addPhisFoldable bb
   phis bb = Class.phisTraversable bb

-- Undefined value for both fields, built by the generic
-- Class.undefTuplePointed helper.
instance Undefined a => Undefined (ParameterPacked a) where
   undefTuple = Class.undefTuplePointed

-- Zero value for both fields, built by the generic
-- Class.zeroTuplePointed helper.
instance Class.Zero a => Class.Zero (ParameterPacked a) where
   zeroTuple = Class.zeroTuplePointed


-- | Memory layout of 'ParameterPacked' as a two-element struct:
-- 'ppFeedback' is element 0 and 'ppCurrent' is element 1.
memory ::
   (Rep.Memory l s, IsSized s ss) =>
   Rep.MemoryRecord r (LLVM.Struct (s, (s, ()))) (ParameterPacked l)
memory =
   liftA2 ParameterPacked feedbackField currentField
   where
      feedbackField = Rep.memoryElement ppFeedback TypeNum.d0
      currentField  = Rep.memoryElement ppCurrent  TypeNum.d1

-- Stored as a struct of two equal element types; every method is driven
-- by the field layout in 'memory'.
instance
      (Rep.Memory l s, IsSized s ss) =>
      Rep.Memory (ParameterPacked l) (LLVM.Struct (s, (s, ()))) where
   decompose = Rep.decomposeRecord memory
   compose = Rep.composeRecord memory
   load = Rep.loadRecord memory
   store = Rep.storeRecord memory


-- Delegates to the element's LLVM.buildTuple via Class.buildTupleTraversable.
instance LLVM.ValueTuple a => LLVM.ValueTuple (ParameterPacked a) where
   buildTuple fn = Class.buildTupleTraversable $ LLVM.buildTuple fn

-- Type description delegates to the generic Foldable-based helper.
instance LLVM.IsTuple a => LLVM.IsTuple (ParameterPacked a) where
   tupleDesc p = Class.tupleDescFoldable p

-- Converts both fields to their LLVM value counterparts with the
-- generic Functor-based helper.
instance (LLVM.MakeValueTuple ah al) =>
      LLVM.MakeValueTuple (ParameterPacked ah) (ParameterPacked al) where
   valueTupleOf p = Class.valueTupleOfFunctor p


-- Flattening and unfolding are delegated to the generic
-- Traversable/Functor helpers of Synthesizer.LLVM.Simple.Value.
instance (Value.Flatten ah al) =>
      Value.Flatten (ParameterPacked ah) (ParameterPacked al) where
   unfold p = Value.unfoldFunctor p
   flatten p = Value.flattenTraversable p


-- Shuffling is applied to both field vectors via the generic helper.
instance (Vector.ShuffleMatch m v) =>
      Vector.ShuffleMatch m (ParameterPacked v) where
   shuffleMatch arg = Vector.shuffleMatchTraversable arg

-- Lane access (extract/insert) is performed on both field vectors.
instance (Vector.Access m a v) =>
      Vector.Access m (ParameterPacked a) (ParameterPacked v) where
   extract i = Vector.extractTraversable i
   insert i = Vector.insertTraversable i



-- | Run a code generator that depends on the vector size @n@.
--
-- The argument passed to @f@ is a bottom value that only fixes the type
-- @n@. Callers such as 'parameterPacked' are expected to inspect only
-- its type (via 'TypeNum.toInt'). Forcing it raises a descriptive error
-- instead of the anonymous "Prelude.undefined".
withSize ::
   (n -> m (param (Value (Vector n a)))) ->
   m (param (Value (Vector n a)))
withSize f =
   f (error "Synthesizer.LLVM.Generator.Exponential2.withSize: size proxy must not be evaluated")

-- | Compute the vectorised decay parameters in generated LLVM code.
--
-- For vector size @n@ and per-sample factor @k = 0.5 ** recip halfLife@
-- this yields @ppFeedback@ = @0.5 ** (n / halfLife)@ in every lane
-- and @ppCurrent@ = @[1, k, k^2, ...]@, mirroring 'parameterPackedPlain'.
parameterPacked ::
   (Trans.C a, IsConst a, IsFloating a,
    IsPrimitive a, IsPowerOf2 n) =>
   Value a ->
   CodeGenFunction r (ParameterPacked (Value (Vector n a)))
parameterPacked halfLife = withSize $ \n -> do
   -- Decay across one whole chunk of n samples, broadcast to all lanes.
   chunkExponent <-
      A.fdiv (valueOf $ fromIntegral $ TypeNum.toInt n) halfLife
   chunkFactor <- A.pow (valueOf 0.5) chunkExponent
   feedback <- SoV.replicate chunkFactor
   -- Per-sample decay k; the lanes of 'current' are successive powers of k.
   sampleExponent <- A.fdiv (valueOf 1) halfLife
   k <- A.pow (valueOf 0.5) sampleExponent
   current <- Vector.iterate (A.mul k) (valueOf 1)
   return (ParameterPacked feedback current)
{- Disabled alternative, delegating to parameterPackedPlain:
   Value.flatten $ parameterPackedPlain $
   Value.constantValue halfLife
-}

-- | Pure counterpart of 'withSize': build a value that depends on the
-- vector size @n@.
--
-- The argument passed to @f@ is a bottom value that only fixes the type
-- @n@. Callers such as 'parameterPackedPlain' are expected to inspect only
-- its type (via 'TypeNum.toInt'). Forcing it raises a descriptive error
-- instead of the anonymous "Prelude.undefined".
withSizePlain ::
   (n -> param (Vector n a)) ->
   param (Vector n a)
withSizePlain f =
   f (error "Synthesizer.LLVM.Generator.Exponential2.withSizePlain: size proxy must not be evaluated")

-- | Vectorised decay parameters computed on the Haskell side.
--
-- @ppFeedback@ is built from the single chunk factor
-- @0.5 ** (n / halfLife)@. @ppCurrent@ holds the powers
-- @[1, k, k^2, ...]@ of the per-sample factor @k = 0.5 ** recip halfLife@.
parameterPackedPlain ::
   (Trans.C a,
    IsPowerOf2 n) =>
   a -> ParameterPacked (Vector n a)
parameterPackedPlain halfLife =
   withSizePlain $ \n ->
      let chunkFactor  = 0.5 ** (fromIntegral (TypeNum.toInt n) / halfLife)
          sampleFactor = 0.5 ** recip halfLife
      in  ParameterPacked
             (LLVM.vector [chunkFactor])
             (LLVM.vector $ iterate (sampleFactor *) one)


-- | Vectorised exponential curve starting at @initial@.
--
-- Per step, the output chunk is the state scaled lane-wise by 'ppCurrent'.
-- The state is then scaled by 'ppFeedback' to advance one whole chunk.
causalPackedP ::
   (IsFirstClass a, IsSized a size,
    IsArithmetic a, IsConst a,
    Storable a, LLVM.MakeValueTuple a (Value a),
    IsPrimitive a, IsPowerOf2 n,
    TypeNum.Mul n size pss, TypeNum.Pos pss) =>
   Param.T p a ->
   CausalP.T p (ParameterPacked (Value (Vector n a))) (Value (Vector n a))
causalPackedP initial =
   CausalP.mapAccum step return (return ()) (broadcast ^<< initial)
   where
      -- LLVM.vector gets a one-element list; as in parameterPackedPlain,
      -- this is relied on to fill every lane with x.
      broadcast x = LLVM.vector [x]
      -- Returns (output chunk, next state).
      step () p state = do
         next <- A.mul (ppFeedback p) state
         out  <- A.mul (ppCurrent p) state
         return (out, next)