{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Synthesizer.LLVM.Causal.Exponential2 (
Parameter,
parameter,
parameterPlain,
multiValueParameter,
unMultiValueParameter,
causal,
ParameterPacked,
parameterPacked,
parameterPackedExp,
parameterPackedPlain,
multiValueParameterPacked,
unMultiValueParameterPacked,
causalPacked,
) where
import qualified Synthesizer.LLVM.Causal.Private as CausalPriv
import qualified Synthesizer.LLVM.Causal.Process as Causal
import qualified Synthesizer.LLVM.Causal.Functional as F
import qualified Synthesizer.LLVM.Frame.SerialVector.Plain as SerialPlain
import qualified Synthesizer.LLVM.Frame.SerialVector.Code as SerialCode
import qualified Synthesizer.LLVM.Frame.SerialVector as Serial
import qualified Synthesizer.LLVM.Frame.SerialVector.Class as SerialOld
import qualified Synthesizer.LLVM.Value as Value
import qualified LLVM.DSL.Expression as Expr
import LLVM.DSL.Expression (Exp)
import qualified LLVM.Extra.Multi.Value.Marshal as MarshalMV
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Storable as Storable
import qualified LLVM.Extra.Marshal as Marshal
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Core as LLVM
import LLVM.Core (CodeGenFunction, Value, IsFloating)
import qualified Type.Data.Num.Decimal as TypeNum
import qualified Foreign.Storable.Traversable as Store
import qualified Foreign.Storable
import Foreign.Storable (Storable)
import qualified Control.Applicative as App
import Control.Applicative (liftA2, pure, (<*>))
import Control.Arrow (arr, (&&&))
import qualified Data.Foldable as Fold
import qualified Data.Traversable as Trav
import Data.Traversable (traverse)
import qualified Algebra.Transcendental as Trans
import NumericPrelude.Numeric
import NumericPrelude.Base
-- | Wrapper around a single value of type @a@.
--
-- 'parameterPlain' fills it with @0.5 ^? recip halfLife@ and 'causal'
-- multiplies its loop state by the wrapped value.
-- The 'Storable' instance is taken over from the wrapped type
-- (newtype deriving).
newtype Parameter a = Parameter a
  deriving (Show, Storable)
-- | Maps over the single wrapped value.
instance Functor Parameter where
  {-# INLINE fmap #-}
  fmap g wrapped =
    case wrapped of
      Parameter x -> Parameter $ g x
-- | One-element container: 'pure' wraps a value,
-- '<*>' applies the wrapped function to the wrapped argument.
instance App.Applicative Parameter where
  {-# INLINE pure #-}
  pure = Parameter
  {-# INLINE (<*>) #-}
  Parameter g <*> wrapped = fmap g wrapped
-- | Folding visits the single wrapped value.
instance Fold.Foldable Parameter where
  {-# INLINE foldMap #-}
  -- Same result as 'Trav.foldMapDefault' for a one-element container:
  -- the mapped value of the only element.
  foldMap f (Parameter x) = f x
-- | Runs the wrapped action and re-wraps its result.
instance Trav.Traversable Parameter where
  {-# INLINE sequenceA #-}
  sequenceA wrapped =
    let Parameter action = wrapped
    in  fmap Parameter action
-- | Phi nodes are handled by the generic container helpers
-- of "LLVM.Extra.Tuple".
instance (Tuple.Phi a) => Tuple.Phi (Parameter a) where
  phi block = Tuple.phiTraversable block
  addPhi block = Tuple.addPhiFoldable block

-- | The undefined value is built with 'Tuple.undefPointed'.
instance Tuple.Undefined a => Tuple.Undefined (Parameter a) where
  undef = Tuple.undefPointed

-- | The zero value is built with 'Tuple.zeroPointed'.
instance Tuple.Zero a => Tuple.Zero (Parameter a) where
  zero = Tuple.zeroPointed
-- | Uses the memory layout of the wrapped type (@Struct@ is the same)
-- and only adds or removes the 'Parameter' constructor
-- via the newtype helpers of "LLVM.Extra.Memory".
instance (Memory.C a) => Memory.C (Parameter a) where
  type Struct (Parameter a) = Memory.Struct a
  load ptr = Memory.loadNewtype Parameter ptr
  store = Memory.storeNewtype unwrap
    where
      unwrap (Parameter x) = x
  decompose struct = Memory.decomposeNewtype Parameter struct
  compose = Memory.composeNewtype unwrap
    where
      unwrap (Parameter x) = x
-- | Marshals exactly like the wrapped type;
-- only the 'Parameter' constructor is removed or added.
instance (Marshal.C a) => Marshal.C (Parameter a) where
  pack wrapped =
    case wrapped of
      Parameter x -> Marshal.pack x
  unpack struct = Parameter $ Marshal.unpack struct
-- | Multi-value marshalling is the one of the wrapped type;
-- only the 'Parameter' constructor is removed or added.
instance (MarshalMV.C a) => MarshalMV.C (Parameter a) where
  pack wrapped =
    case wrapped of
      Parameter x -> MarshalMV.pack x
  unpack struct = Parameter $ MarshalMV.unpack struct
-- | Loads and stores through the instance of the wrapped type,
-- using the newtype helpers of "LLVM.Extra.Storable".
instance (Storable.C a) => Storable.C (Parameter a) where
  load ptr = Storable.loadNewtype Parameter Parameter ptr
  store = Storable.storeNewtype Parameter unwrap
    where
      unwrap (Parameter x) = x
-- | The LLVM value of a 'Parameter' is a 'Parameter' of the LLVM value
-- of its content; conversion is done by 'Tuple.valueOfFunctor'.
instance (Tuple.Value a) => Tuple.Value (Parameter a) where
  type ValueOf (Parameter a) = Parameter (Tuple.ValueOf a)
  valueOf wrapped = Tuple.valueOfFunctor wrapped
-- | Every method is lifted from the instance of the wrapped type.
-- Conversion between @Parameter (MultiValue.T a)@ and
-- @MultiValue.T (Parameter a)@ is done with
-- 'multiValueParameter' and 'unMultiValueParameter'.
instance (MultiValue.C a) => MultiValue.C (Parameter a) where
  type Repr (Parameter a) = Parameter (MultiValue.Repr a)
  cons wrapped = multiValueParameter (fmap MultiValue.cons wrapped)
  undef = multiValueParameter (pure MultiValue.undef)
  zero = multiValueParameter (pure MultiValue.zero)
  phi bb mv =
    fmap multiValueParameter $
      traverse (MultiValue.phi bb) $
        unMultiValueParameter mv
  addPhi bb a b =
    Fold.sequence_ $ liftA2 (MultiValue.addPhi bb) lhs rhs
    where
      lhs = unMultiValueParameter a
      rhs = unMultiValueParameter b
-- | Turn a 'Parameter' of multi-values into a multi-value of a 'Parameter'
-- by stripping the 'MultiValue.Cons' constructor from the content
-- and putting it around the whole wrapper.
multiValueParameter ::
  Parameter (MultiValue.T a) -> MultiValue.T (Parameter a)
multiValueParameter wrapped = MultiValue.Cons (fmap repr wrapped)
  where
    repr (MultiValue.Cons r) = r
-- | Inverse direction of 'multiValueParameter':
-- remove the outer 'MultiValue.Cons' and put it around the content instead.
unMultiValueParameter ::
  MultiValue.T (Parameter a) -> Parameter (MultiValue.T a)
unMultiValueParameter mv =
  case mv of
    MultiValue.Cons reprs -> fmap MultiValue.Cons reprs
-- | Flattening and unfolding are delegated to the traversable helpers
-- of "Synthesizer.LLVM.Value".
instance (Value.Flatten a) => Value.Flatten (Parameter a) where
  type Registers (Parameter a) = Parameter (Value.Registers a)
  flattenCode wrapped = Value.flattenCodeTraversable wrapped
  unfoldCode regs = Value.unfoldCodeTraversable regs
-- | A 'Parameter' of a vector is treated as a vector of 'Parameter's
-- with the same size; the operations come from the traversable helpers
-- of "LLVM.Extra.Vector".
instance (Vector.Simple v) => Vector.Simple (Parameter v) where
  type Element (Parameter v) = Parameter (Vector.Element v)
  type Size (Parameter v) = Vector.Size v
  shuffleMatch indices = Vector.shuffleMatchTraversable indices
  extract index = Vector.extractTraversable index
-- | Element insertion is delegated to 'Vector.insertTraversable'.
instance (Vector.C v) => Vector.C (Parameter v) where
  insert index = Vector.insertTraversable index
-- | Bundling and dissecting work on the wrapped value
-- and keep the 'Parameter' constructor around the result.
instance
  (Expr.Aggregate exp mv) =>
    Expr.Aggregate (Parameter exp) (Parameter mv) where
  type MultiValuesOf (Parameter exp) = Parameter (Expr.MultiValuesOf exp)
  type ExpressionsOf (Parameter mv) = Parameter (Expr.ExpressionsOf mv)
  bundle wrapped =
    case wrapped of
      Parameter e -> fmap Parameter (Expr.bundle e)
  dissect wrapped =
    case wrapped of
      Parameter v -> Parameter (Expr.dissect v)
-- | Code-generating counterpart of 'parameterPlain':
-- the plain computation is lifted to LLVM values with 'Value.unlift1'.
parameter ::
  (Trans.C a, SoV.TranscendentalConstant a, IsFloating a) =>
  Value a ->
  CodeGenFunction r (Parameter (Value a))
parameter halfLife = Value.unlift1 parameterPlain halfLife
-- | Compute the factor @0.5 ^? recip halfLife@ and wrap it in 'Parameter'.
--
-- Reading @^?@ as the power operation of 'Trans.C', multiplying a value
-- @halfLife@ times by this factor halves it.
parameterPlain ::
  (Trans.C a) =>
  a -> Parameter a
parameterPlain halfLife = Parameter (0.5 ^? reciprocal)
  where
    reciprocal = recip halfLife
-- | Feedback loop started with @initial@.
--
-- The loop body emits the second component of its input pair unchanged
-- (@arr snd@) and, in parallel, multiplies that component by the factor
-- stored in the 'Parameter' input.
-- NOTE(review): presumably 'Causal.loop' passes @(input, state)@ and
-- expects @(output, next state)@ - confirm against its documentation.
causal ::
  (MarshalMV.C a, MultiValue.T a ~ am, MultiValue.PseudoRing a) =>
  Exp a -> Causal.T (Parameter am) am
causal initial =
  Causal.loop initial $
    arr snd &&& CausalPriv.zipWith scaleBy
  where
    scaleBy (Parameter k) = A.mul k
-- | Pair of values of the same type used by the packed (vectorised) process.
--
-- In 'causalPacked' the state is multiplied by 'ppCurrent' to obtain the
-- output and by 'ppFeedback' to obtain the next state.
data ParameterPacked a =
  ParameterPacked
    { ppFeedback :: a
    , ppCurrent :: a
    }
-- | Maps the function over both fields.
instance Functor ParameterPacked where
  {-# INLINE fmap #-}
  fmap g packed = ParameterPacked (g feedback) (g current)
    where
      -- Lazy pattern binding: like the field selectors,
      -- it does not force @packed@ before the result is constructed.
      ParameterPacked feedback current = packed
-- | 'pure' duplicates a value into both fields,
-- '<*>' applies the functions field by field.
instance App.Applicative ParameterPacked where
  {-# INLINE pure #-}
  pure x = ParameterPacked x x
  {-# INLINE (<*>) #-}
  fs <*> xs = ParameterPacked (feedbackF feedbackX) (currentF currentX)
    where
      -- Lazy pattern bindings keep the laziness of the field selectors.
      ParameterPacked feedbackF currentF = fs
      ParameterPacked feedbackX currentX = xs
-- | Folding is derived from the 'Trav.Traversable' instance.
instance Fold.Foldable ParameterPacked where
  {-# INLINE foldMap #-}
  foldMap f packed = Trav.foldMapDefault f packed
-- | Runs the action in 'ppFeedback' first, then the one in 'ppCurrent',
-- and collects both results.
instance Trav.Traversable ParameterPacked where
  {-# INLINE sequenceA #-}
  sequenceA packed = liftA2 ParameterPacked feedback current
    where
      -- Lazy pattern binding keeps the laziness of the field selectors.
      ParameterPacked feedback current = packed
-- | Phi nodes are handled by the generic container helpers
-- of "LLVM.Extra.Tuple".
instance (Tuple.Phi a) => Tuple.Phi (ParameterPacked a) where
  phi block = Tuple.phiTraversable block
  addPhi block = Tuple.addPhiFoldable block

-- | The undefined value is built with 'Tuple.undefPointed'.
instance Tuple.Undefined a => Tuple.Undefined (ParameterPacked a) where
  undef = Tuple.undefPointed

-- | The zero value is built with 'Tuple.zeroPointed'.
instance Tuple.Zero a => Tuple.Zero (ParameterPacked a) where
  zero = Tuple.zeroPointed
-- | All methods are delegated to the generic container implementations
-- of "Foreign.Storable.Traversable".
instance Storable a => Storable (ParameterPacked a) where
  sizeOf packed = Store.sizeOf packed
  alignment packed = Store.alignment packed
  peek ptr = Store.peekApplicative ptr
  poke ptr = Store.poke ptr
-- | LLVM struct with two members of type @a@.
-- It is the @Struct@ type of the 'Memory.C' instance of 'ParameterPacked';
-- 'memory' maps 'ppFeedback' to member 0 and 'ppCurrent' to member 1.
type ParameterPackedStruct a = LLVM.Struct (a, (a, ()))
-- | Record description for storing a 'ParameterPacked' in memory:
-- 'ppFeedback' is struct member 0 and 'ppCurrent' is struct member 1.
memory ::
  (Memory.C a) =>
  Memory.Record r (ParameterPackedStruct (Memory.Struct a)) (ParameterPacked a)
memory =
  let atFeedback = Memory.element ppFeedback TypeNum.d0
      atCurrent = Memory.element ppCurrent TypeNum.d1
  in  liftA2 ParameterPacked atFeedback atCurrent
-- | All memory operations are derived from the record description 'memory'.
instance (Memory.C a) => Memory.C (ParameterPacked a) where
  type Struct (ParameterPacked a) = ParameterPackedStruct (Memory.Struct a)
  load ptr = Memory.loadRecord memory ptr
  store packed = Memory.storeRecord memory packed
  decompose struct = Memory.decomposeRecord memory struct
  compose packed = Memory.composeRecord memory packed
-- | Marshalled as the pair @(ppFeedback, ppCurrent)@.
instance (Marshal.C a) => Marshal.C (ParameterPacked a) where
  pack (ParameterPacked feedback current) =
    Marshal.pack (feedback, current)
  unpack struct = uncurry ParameterPacked (Marshal.unpack struct)
-- | Marshalled as the pair @(ppFeedback, ppCurrent)@.
instance (MarshalMV.C a) => MarshalMV.C (ParameterPacked a) where
  pack (ParameterPacked feedback current) =
    MarshalMV.pack (feedback, current)
  unpack struct = uncurry ParameterPacked (MarshalMV.unpack struct)
-- | Loading and storing are delegated to the generic container helpers
-- of "LLVM.Extra.Storable".
instance (Storable.C a) => Storable.C (ParameterPacked a) where
  load ptr = Storable.loadApplicative ptr
  store packed = Storable.storeFoldable packed
-- | The LLVM value of a 'ParameterPacked' is a 'ParameterPacked' of the
-- LLVM values of its fields; conversion is done by 'Tuple.valueOfFunctor'.
instance (Tuple.Value a) => Tuple.Value (ParameterPacked a) where
  type ValueOf (ParameterPacked a) = ParameterPacked (Tuple.ValueOf a)
  valueOf packed = Tuple.valueOfFunctor packed
-- | Every method is lifted field-wise from the instance of the field type.
-- Conversion between @ParameterPacked (MultiValue.T a)@ and
-- @MultiValue.T (ParameterPacked a)@ is done with
-- 'multiValueParameterPacked' and 'unMultiValueParameterPacked'.
instance (MultiValue.C a) => MultiValue.C (ParameterPacked a) where
  type Repr (ParameterPacked a) = ParameterPacked (MultiValue.Repr a)
  cons packed = multiValueParameterPacked (fmap MultiValue.cons packed)
  undef = multiValueParameterPacked (pure MultiValue.undef)
  zero = multiValueParameterPacked (pure MultiValue.zero)
  phi bb mv =
    fmap multiValueParameterPacked $
      traverse (MultiValue.phi bb) $
        unMultiValueParameterPacked mv
  addPhi bb a b =
    Fold.sequence_ $ liftA2 (MultiValue.addPhi bb) lhs rhs
    where
      lhs = unMultiValueParameterPacked a
      rhs = unMultiValueParameterPacked b
-- | Turn a 'ParameterPacked' of multi-values into a multi-value of a
-- 'ParameterPacked' by stripping 'MultiValue.Cons' from both fields
-- and putting it around the whole record.
multiValueParameterPacked ::
  ParameterPacked (MultiValue.T a) -> MultiValue.T (ParameterPacked a)
multiValueParameterPacked packed = MultiValue.Cons (fmap repr packed)
  where
    repr (MultiValue.Cons r) = r
-- | Inverse direction of 'multiValueParameterPacked':
-- remove the outer 'MultiValue.Cons' and put it around each field instead.
unMultiValueParameterPacked ::
  MultiValue.T (ParameterPacked a) -> ParameterPacked (MultiValue.T a)
unMultiValueParameterPacked mv =
  case mv of
    MultiValue.Cons reprs -> fmap MultiValue.Cons reprs
-- | Flattening and unfolding are delegated to the traversable helpers
-- of "Synthesizer.LLVM.Value".
instance (Value.Flatten a) => Value.Flatten (ParameterPacked a) where
  type Registers (ParameterPacked a) = ParameterPacked (Value.Registers a)
  flattenCode packed = Value.flattenCodeTraversable packed
  unfoldCode regs = Value.unfoldCodeTraversable regs
-- | Bundling and dissecting are performed separately for both fields;
-- 'bundle' processes 'ppFeedback' before 'ppCurrent'.
instance
  (Expr.Aggregate exp mv) =>
    Expr.Aggregate (ParameterPacked exp) (ParameterPacked mv) where
  type MultiValuesOf (ParameterPacked exp) =
    ParameterPacked (Expr.MultiValuesOf exp)
  type ExpressionsOf (ParameterPacked mv) =
    ParameterPacked (Expr.ExpressionsOf mv)
  bundle packed =
    liftA2 ParameterPacked (Expr.bundle feedback) (Expr.bundle current)
    where
      -- Lazy pattern binding keeps the laziness of the field selectors.
      ParameterPacked feedback current = packed
  dissect packed =
    ParameterPacked (Expr.dissect feedback) (Expr.dissect current)
    where
      -- Lazy pattern binding keeps the laziness of the field selectors.
      ParameterPacked feedback current = packed
-- | In the functional interface a 'ParameterPacked' is a single argument:
-- it is not split into its fields, so 'makeArgs' returns its input as is.
type instance F.Arguments f (ParameterPacked a) = f (ParameterPacked a)

instance F.MakeArguments (ParameterPacked a) where
  makeArgs args = args
-- | Pass the singleton value of the vector size to the given function.
--
-- The size @n@ is tied to the vector type @v@ by @SerialOld.Size v ~ n@,
-- so the callback can read the size of its own result type.
withSize ::
  (TypeNum.Natural n) =>
  (SerialOld.Write v, SerialOld.Size v ~ n, TypeNum.Positive n) =>
  (TypeNum.Singleton n -> m (param v)) ->
  m (param v)
withSize callback = callback TypeNum.singleton
-- | Generate code that computes the parameters for 'causalPacked'.
--
-- With @n@ being the vector size:
--
-- * 'ppFeedback' is @0.5 ^ (n / halfLife)@, upsampled to a vector;
--
-- * 'ppCurrent' is produced by 'SerialOld.iterate', starting at 1 and
--   multiplying by @k = 0.5 ^ (1 / halfLife)@.
--
-- The instructions are emitted in the same order as listed in the body.
parameterPacked ::
  (SerialOld.Write v, SerialOld.Element v ~ a,
   A.PseudoRing v, A.RationalConstant v,
   A.Transcendental a, A.RationalConstant a) =>
  a -> CodeGenFunction r (ParameterPacked v)
parameterPacked halfLife = withSize $ \n -> do
  let half = A.fromRational' 0.5
      size = A.fromInteger' $ TypeNum.integralFromSingleton n
  -- decay factor for a whole vector of n samples
  sizeOverHalfLife <- A.fdiv size halfLife
  blockFactor <- A.pow half sizeOverHalfLife
  feedback <- SerialOld.upsample blockFactor
  -- decay factor for a single sample
  unitOverHalfLife <- A.fdiv (A.fromInteger' 1) halfLife
  k <- A.pow half unitOverHalfLife
  current <- SerialOld.iterate (A.mul k) (A.fromInteger' 1)
  return $ ParameterPacked feedback current
-- | Pass the singleton value of the size @n@ of the resulting
-- @Serial.T n a@ to the given function.
withSizePlain ::
  (TypeNum.Positive n) =>
  (TypeNum.Singleton n -> param (Serial.T n a)) ->
  param (Serial.T n a)
withSizePlain callback = callback TypeNum.singleton
-- | Plain-value variant of the packed parameter computation.
--
-- With @n@ being the vector size:
--
-- * 'ppFeedback' replicates @0.5 ^? (n / halfLife)@;
--
-- * 'ppCurrent' is produced by 'SerialPlain.iterate', starting at 'one'
--   and multiplying by @0.5 ^? recip halfLife@.
parameterPackedPlain ::
  (TypeNum.Positive n, Trans.C a) =>
  a -> ParameterPacked (Serial.T n a)
parameterPackedPlain halfLife =
  withSizePlain $ \n ->
    let blockFactor =
          0.5 ^? (fromInteger (TypeNum.integerFromSingleton n) / halfLife)
        sampleFactor = 0.5 ^? recip halfLife
    in  ParameterPacked
          (SerialPlain.replicate blockFactor)
          (SerialPlain.iterate (sampleFactor *) one)
-- | Pass the singleton value of the size @n@ of the resulting
-- @exp (Serial.T n a)@ to the given function.
withSizeExp ::
  (TypeNum.Positive n) =>
  (TypeNum.Singleton n -> param (exp (Serial.T n a))) ->
  param (exp (Serial.T n a))
withSizeExp callback = callback TypeNum.singleton
-- | Expression variant of the packed parameter computation.
--
-- With @n@ being the vector size:
--
-- * 'ppFeedback' upsamples @0.5 ^? (n / halfLife)@;
--
-- * 'ppCurrent' is produced by 'Serial.iterate', starting at 'one'
--   and multiplying by @0.5 ^? recip halfLife@.
parameterPackedExp ::
  (TypeNum.Positive n) =>
  (MultiValue.Transcendental a, MultiValue.RationalConstant a) =>
  (MultiVector.C a) =>
  Exp a -> ParameterPacked (Exp (Serial.T n a))
parameterPackedExp halfLife =
  withSizeExp $ \n ->
    let blockFactor =
          0.5 ^? (fromInteger (TypeNum.integerFromSingleton n) / halfLife)
        sampleFactor = 0.5 ^? recip halfLife
    in  ParameterPacked
          (Serial.upsample blockFactor)
          (Serial.iterate (sampleFactor *) one)
-- | Packed (vectorised) feedback loop.
--
-- The loop state starts as the upsampled @initial@ value.
-- Each step forms a pair: first the state multiplied by 'ppCurrent',
-- then the state multiplied by 'ppFeedback'.
-- NOTE(review): presumably 'Causal.loop' takes the first component as
-- output and the second as next state - confirm against its documentation.
causalPacked ::
  (MultiVector.PseudoRing a, MultiValue.IntegerConstant a,
   TypeNum.Positive n, MarshalMV.Vector n a, MarshalMV.C a) =>
  Exp a ->
  Causal.T (ParameterPacked (SerialCode.Value n a)) (SerialCode.Value n a)
causalPacked initial =
  Causal.loop (Serial.upsample initial) (CausalPriv.map step)
  where
    step (p, state) =
      liftA2 (,)
        (A.mul (ppCurrent p) state)
        (A.mul (ppFeedback p) state)