{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{- |
Exponential curve with controllable decay, specified by its half-life.
-}
module Synthesizer.LLVM.Generator.Exponential2 (
   Parameter,
   parameter,
   parameterPlain,
   causalP,

   ParameterPacked,
   parameterPacked,
   parameterPackedPlain,
   causalPackedP,
   ) where

import qualified Synthesizer.LLVM.CausalParameterized.Process as CausalP
import qualified Synthesizer.LLVM.Simple.Value as Value
import qualified Synthesizer.LLVM.Frame.SerialVector as Serial
import qualified Synthesizer.LLVM.CausalParameterized.Functional as F

import qualified LLVM.DSL.Parameter as Param

import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Storable as Storable
import qualified LLVM.Extra.Marshal as Marshal
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Arithmetic as A

import qualified LLVM.Core as LLVM
import LLVM.Core
         (CodeGenFunction, Value, IsArithmetic, IsPrimitive, IsFloating, SizeOf)

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal.Number ((:*:))

import Foreign.Storable (Storable)
import qualified Foreign.Storable
-- import qualified Foreign.Storable.Record as Store
import qualified Foreign.Storable.Traversable as Store

import qualified Control.Applicative as App
import qualified Data.Foldable as Fold
import qualified Data.Traversable as Trav
import Control.Applicative (liftA2, (<*>))
import Control.Arrow (arr, (^<<), (&&&))
import Control.Monad (liftM2)

import qualified Algebra.Transcendental as Trans

import NumericPrelude.Numeric
import NumericPrelude.Base


{- |
Wrapper around the single factor of the exponential curve.
'parameterPlain' stores @0.5 ** recip halfLife@ in it
and 'causalP' multiplies the loop state by the wrapped value.
-}
newtype Parameter a = Parameter a
   deriving (Show, Storable)


-- Maps a function over the single wrapped value.
instance Functor Parameter where
   {-# INLINE fmap #-}
   fmap g wrapped =
      case wrapped of
         Parameter x -> Parameter (g x)

-- 'Parameter' behaves like an identity functor:
-- 'pure' wraps a value, '<*>' applies the wrapped function to the wrapped argument.
instance App.Applicative Parameter where
   {-# INLINE pure #-}
   pure = Parameter
   {-# INLINE (<*>) #-}
   Parameter g <*> wrapped = fmap g wrapped

-- There is exactly one element, so folding is just applying the function to it.
instance Fold.Foldable Parameter where
   {-# INLINE foldMap #-}
   foldMap f (Parameter k) = f k

-- Runs the wrapped action and re-wraps its result.
instance Trav.Traversable Parameter where
   {-# INLINE sequenceA #-}
   sequenceA wrapped =
      fmap Parameter (unwrap wrapped)
      where
         unwrap (Parameter action) = action


-- Phi node handling is delegated to the generic helpers from "LLVM.Extra.Tuple"
-- that work on any Traversable/Foldable container.
instance (Tuple.Phi a) => Tuple.Phi (Parameter a) where
   phi block = Tuple.phiTraversable block
   addPhi block = Tuple.addPhiFoldable block

-- Delegates to the generic helper 'Tuple.undefPointed'
-- (presumably wraps the component's 'Tuple.undef' -- confirm in LLVM.Extra.Tuple).
instance Tuple.Undefined a => Tuple.Undefined (Parameter a) where
   undef = Tuple.undefPointed

-- Delegates to the generic helper 'Tuple.zeroPointed'
-- (presumably wraps the component's 'Tuple.zero' -- confirm in LLVM.Extra.Tuple).
instance Tuple.Zero a => Tuple.Zero (Parameter a) where
   zero = Tuple.zeroPointed

-- 'Parameter' is a newtype, so it shares the memory layout of the wrapped type
-- and all methods only add or remove the constructor.
instance (Memory.C a) => Memory.C (Parameter a) where
   type Struct (Parameter a) = Memory.Struct a
   load = Memory.loadNewtype Parameter
   store = Memory.storeNewtype unwrap
      where unwrap (Parameter k) = k
   decompose = Memory.decomposeNewtype Parameter
   compose = Memory.composeNewtype unwrap
      where unwrap (Parameter k) = k

-- Loading and storing go through the newtype helpers of "LLVM.Extra.Storable";
-- the only work done here is wrapping and unwrapping the constructor.
instance (Storable.C a) => Storable.C (Parameter a) where
   load = Storable.loadNewtype Parameter Parameter
   store = Storable.storeNewtype Parameter unwrap
      where unwrap (Parameter k) = k

{-
instance LLVM.ValueTuple a => LLVM.ValueTuple (Parameter a) where
   buildTuple f = Class.buildTupleTraversable (LLVM.buildTuple f)

instance LLVM.IsTuple a => LLVM.IsTuple (Parameter a) where
   tupleDesc = Class.tupleDescFoldable
-}

-- The LLVM value representation keeps the 'Parameter' wrapper
-- and converts the wrapped component.
instance (Tuple.Value a) => Tuple.Value (Parameter a) where
   type ValueOf (Parameter a) = Parameter (Tuple.ValueOf a)
   valueOf wrapped = Tuple.valueOfFunctor wrapped


-- Flattening keeps the 'Parameter' wrapper around the registers of the component;
-- both directions are delegated to the Traversable based helpers.
instance (Value.Flatten a) => Value.Flatten (Parameter a) where
   type Registers (Parameter a) = Parameter (Value.Registers a)
   flattenCode wrapped = Value.flattenCodeTraversable wrapped
   unfoldCode code = Value.unfoldCodeTraversable code


-- A 'Parameter' of a vector is treated as a vector of 'Parameter's:
-- same size, element type wrapped in 'Parameter'.
instance (Vector.Simple v) => Vector.Simple (Parameter v) where
   type Element (Parameter v) = Parameter (Vector.Element v)
   type Size (Parameter v) = Vector.Size v
   shuffleMatch indices = Vector.shuffleMatchTraversable indices
   extract index = Vector.extractTraversable index

-- Element insertion is delegated to the Traversable based helper.
instance (Vector.C v) => Vector.C (Parameter v) where
   insert index = Vector.insertTraversable index


{- |
Code-generating counterpart of 'parameterPlain':
the same formula is evaluated on an LLVM value by lifting it with 'Value.unlift1'.
-}
parameter ::
   (Trans.C a, SoV.TranscendentalConstant a, IsFloating a) =>
   Value a ->
   CodeGenFunction r (Parameter (Value a))
parameter halfLife = Value.unlift1 parameterPlain halfLife

{- |
Compute the factor @0.5 ** recip halfLife@.
Multiplying a value by this factor @halfLife@ times halves it.
-}
parameterPlain ::
   (Trans.C a) =>
   a -> Parameter a
parameterPlain halfLife =
   Parameter decayPerStep
   where
      decayPerStep = 0.5 ** recip halfLife


{- |
Exponential curve as a causal process with the decay factor as input.
The loop starts with state @initial@;
in every step it pairs the current state
with the product of the input factor and the current state.
-}
causalP ::
   (Marshal.C a, Tuple.ValueOf a ~ al, A.PseudoRing al) =>
   Param.T p a ->
   CausalP.T p (Parameter al) al
causalP initial =
   CausalP.loop initial $
      arr snd
      &&&
      CausalP.zipWithSimple (\(Parameter decay) state -> A.mul decay state)


{- |
Parameters for the block-wise (packed) exponential curve.

* 'ppFeedback': factor that 'causalPackedP' multiplies with the state
  in the second component of its loop result;
  'parameterPackedPlain' fills it with @0.5 ** (n / halfLife)@ replicated,
  where @n@ is the block size.

* 'ppCurrent': factors that 'causalPackedP' multiplies with the state
  in the first component of its loop result;
  'parameterPackedPlain' builds it by iterated multiplication with
  @0.5 ** recip halfLife@ starting at @one@.
-}
data ParameterPacked a =
   ParameterPacked {ppFeedback, ppCurrent :: a}


-- Applies the function to both fields.
-- The lazy pattern keeps 'fmap' non-strict in the record,
-- exactly like access through the field selectors.
instance Functor ParameterPacked where
   {-# INLINE fmap #-}
   fmap f ~(ParameterPacked feedback current) =
      ParameterPacked (f feedback) (f current)

-- Zip-like applicative: 'pure' duplicates the value into both fields,
-- '<*>' combines feedback with feedback and current with current.
instance App.Applicative ParameterPacked where
   {-# INLINE pure #-}
   pure x = ParameterPacked x x
   {-# INLINE (<*>) #-}
   fs <*> xs = ParameterPacked feedback current
      where
         feedback = ppFeedback fs (ppFeedback xs)
         current = ppCurrent fs (ppCurrent xs)

-- Folding is derived from the 'Trav.Traversable' instance.
instance Fold.Foldable ParameterPacked where
   {-# INLINE foldMap #-}
   foldMap f packed = Trav.foldMapDefault f packed

-- Runs the feedback action first, then the current action.
-- The lazy pattern keeps the same laziness as access through the selectors.
instance Trav.Traversable ParameterPacked where
   {-# INLINE sequenceA #-}
   sequenceA ~(ParameterPacked feedback current) =
      liftA2 ParameterPacked feedback current


-- Phi node handling is delegated to the generic helpers from "LLVM.Extra.Tuple"
-- that work on any Traversable/Foldable container.
instance (Tuple.Phi a) => Tuple.Phi (ParameterPacked a) where
   phi block = Tuple.phiTraversable block
   addPhi block = Tuple.addPhiFoldable block

-- Delegates to the generic helper 'Tuple.undefPointed'
-- (presumably fills both fields with the component's 'Tuple.undef' -- confirm in LLVM.Extra.Tuple).
instance Tuple.Undefined a => Tuple.Undefined (ParameterPacked a) where
   undef = Tuple.undefPointed

-- Delegates to the generic helper 'Tuple.zeroPointed'
-- (presumably fills both fields with the component's 'Tuple.zero' -- confirm in LLVM.Extra.Tuple).
instance Tuple.Zero a => Tuple.Zero (ParameterPacked a) where
   zero = Tuple.zeroPointed


{-
storeParameter ::
   Storable a => Store.Dictionary (ParameterPacked a)
storeParameter =
   Store.run $
   liftA2 ParameterPacked
      (Store.element ppFeedback)
      (Store.element ppCurrent)

instance Storable a => Storable (ParameterPacked a) where
   sizeOf    = Store.sizeOf storeParameter
   alignment = Store.alignment storeParameter
   peek      = Store.peek storeParameter
   poke      = Store.poke storeParameter
-}

-- Storable instance built from the container structure
-- using the helpers of "Foreign.Storable.Traversable".
instance Storable a => Storable (ParameterPacked a) where
   sizeOf packed = Store.sizeOf packed
   alignment packed = Store.alignment packed
   peek ptr = Store.peekApplicative ptr
   poke ptr = Store.poke ptr


-- LLVM struct with two fields of the same type.
-- 'memory' maps field 0 to 'ppFeedback' and field 1 to 'ppCurrent'.
type ParameterPackedStruct a = LLVM.Struct (a, (a, ()))

{-
Record description for 'ParameterPacked' in LLVM memory:
'ppFeedback' is struct element 0 and 'ppCurrent' is struct element 1.
-}
memory ::
   (Memory.C a) =>
   Memory.Record r (ParameterPackedStruct (Memory.Struct a)) (ParameterPacked a)
memory =
   fmap ParameterPacked (Memory.element ppFeedback TypeNum.d0)
      <*> Memory.element ppCurrent TypeNum.d1

-- All memory operations are driven by the record description 'memory'.
instance (Memory.C a) => Memory.C (ParameterPacked a) where
   type Struct (ParameterPacked a) = ParameterPackedStruct (Memory.Struct a)
   load ptr = Memory.loadRecord memory ptr
   store packed = Memory.storeRecord memory packed
   decompose struct = Memory.decomposeRecord memory struct
   compose packed = Memory.composeRecord memory packed

-- Loading and storing are delegated to the Applicative/Foldable based helpers
-- of "LLVM.Extra.Storable".
instance (Storable.C a) => Storable.C (ParameterPacked a) where
   load ptr = Storable.loadApplicative ptr
   store packed = Storable.storeFoldable packed


{-
instance LLVM.ValueTuple a => LLVM.ValueTuple (ParameterPacked a) where
   buildTuple f = Class.buildTupleTraversable (LLVM.buildTuple f)

instance LLVM.IsTuple a => LLVM.IsTuple (ParameterPacked a) where
   tupleDesc = Class.tupleDescFoldable
-}

-- The LLVM value representation keeps the record shape
-- and converts both fields.
instance (Tuple.Value a) => Tuple.Value (ParameterPacked a) where
   type ValueOf (ParameterPacked a) = ParameterPacked (Tuple.ValueOf a)
   valueOf packed = Tuple.valueOfFunctor packed


-- Flattening keeps the record shape around the registers of the fields;
-- both directions are delegated to the Traversable based helpers.
instance (Value.Flatten a) => Value.Flatten (ParameterPacked a) where
   type Registers (ParameterPacked a) = ParameterPacked (Value.Registers a)
   flattenCode packed = Value.flattenCodeTraversable packed
   unfoldCode code = Value.unfoldCodeTraversable code

-- A 'ParameterPacked' is passed as a single argument and is not split up.
type instance F.Arguments f (ParameterPacked a) = f (ParameterPacked a)
instance F.MakeArguments (ParameterPacked a) where
   makeArgs args = args



{-
Supply the type-level size of the serial vector type @v@
as a singleton value to the given function.
-}
withSize ::
   (TypeNum.Natural n) =>
   (Serial.C v, Serial.Size v ~ n, TypeNum.Positive n) =>
   (TypeNum.Singleton n -> m (param v)) ->
   m (param v)
withSize build =
   let size = TypeNum.singleton
   in  build size

{- |
Generate code that computes the parameters of the packed exponential curve
for a given half-life.

* feedback: @0.5 ^ (n / halfLife)@ passed through 'Serial.upsample',
  where @n@ is the size of the serial vector type

* current: built by 'Serial.iterate' from the start value 1
  with multiplication by @0.5 ^ (1 / halfLife)@ as step
-}
parameterPacked ::
   (Serial.C v, Serial.Element v ~ a,
    A.PseudoRing v, A.RationalConstant v,
    A.Transcendental a, A.RationalConstant a) =>
   a -> CodeGenFunction r (ParameterPacked v)
parameterPacked halfLife = withSize $ \n -> do
   let half = A.fromRational' 0.5
       unit = A.fromInteger' 1
   -- decay over a whole block of n samples
   blockExponent <-
      A.fdiv (A.fromInteger' (TypeNum.integralFromSingleton n)) halfLife
   blockFactor <- A.pow half blockExponent
   feedback <- Serial.upsample blockFactor
   -- decay over a single sample
   sampleExponent <- A.fdiv unit halfLife
   k <- A.pow half sampleExponent
   current <- Serial.iterate (A.mul k) unit
   return (ParameterPacked feedback current)
{-
   Value.unlift1 parameterPackedPlain
-}

{-
Supply the type-level size of a 'Serial.Plain' vector
as a singleton value to the given function.
-}
withSizePlain ::
   (TypeNum.Natural n) =>
   (TypeNum.Singleton n -> param (Serial.Plain n a)) ->
   param (Serial.Plain n a)
withSizePlain build =
   let size = TypeNum.singleton
   in  build size

{- |
Haskell-side computation of the parameters of the packed exponential curve.

* feedback: @0.5 ** (n / halfLife)@ replicated over the block,
  where @n@ is the block size

* current: built by 'Serial.iteratePlain' from the start value @one@
  with multiplication by @0.5 ** recip halfLife@ as step
-}
parameterPackedPlain ::
   (Trans.C a,
    TypeNum.Positive n) =>
   a -> ParameterPacked (Serial.Plain n a)
parameterPackedPlain halfLife =
   withSizePlain $ \n ->
      let blockLength = fromInteger (TypeNum.integerFromSingleton n)
          blockDecay = 0.5 ** (blockLength / halfLife)
          sampleDecay = 0.5 ** recip halfLife
      in  ParameterPacked
             (Serial.replicate_ n blockDecay)
             (Serial.iteratePlain (sampleDecay *) one)


{-
Supply the type-level size of a 'Serial.Value' vector
as a singleton value to the given function.
-}
withSizeValue ::
   (TypeNum.Natural n) =>
   (TypeNum.Singleton n -> f (Serial.Value n a)) ->
   f (Serial.Value n a)
withSizeValue build =
   let size = TypeNum.singleton
   in  build size

{- |
Block-wise variant of 'causalP'.
The loop state starts as @initial@ replicated over the block.
Each step multiplies the state with 'ppCurrent' for the first component
of the loop result and with 'ppFeedback' for the second component
(presumably output and next state, respectively - confirm against 'CausalP.loop').
-}
causalPackedP ::
   (IsArithmetic a, SoV.IntegerConstant a,
    Marshal.C a, Tuple.ValueOf a ~ Value a,
    Marshal.Vector n a, Tuple.VectorValueOf n a ~ Value (LLVM.Vector n a),
    IsPrimitive a,
    TypeNum.Positive (n :*: SizeOf a),
    TypeNum.Positive n) =>
   Param.T p a ->
   CausalP.T p (ParameterPacked (Serial.Value n a)) (Serial.Value n a)
causalPackedP initial =
   withSizeValue $ \n ->
   CausalP.loop
      (Serial.replicate_ n ^<< initial)
      (CausalP.mapSimple $ \(p, state) -> do
         output <- A.mul (ppCurrent p) state
         nextState <- A.mul (ppFeedback p) state
         return (output, nextState))