{-# LANGUAGE NoImplicitPrelude #-}
module Synthesizer.LLVM.Parameter where

import qualified LLVM.Core as LLVM

import qualified Algebra.Transcendental as Trans
import qualified Algebra.Algebraic as Algebraic
import qualified Algebra.Field as Field
import qualified Algebra.Ring as Ring
import qualified Algebra.Additive as Additive

import qualified Control.Category as Cat
import qualified Control.Arrow as Arr
import qualified Control.Applicative as App
import Control.Applicative (pure, liftA2, )

import Data.Tuple.HT (mapFst, )

import NumericPrelude.Numeric
import Prelude (fmap, error, (.), const, id, Functor, Monad, )
import qualified Prelude as P


{- |
This data type is for parameters of parameterized signal generators and causal processes.
It is better than using plain functions of type @p -> a@,
since it allows for numeric instances
and makes explicit whether a parameter is constant.

We recommend using parameters of atomic types.
Although a parameter of type @T p (a,b)@ is possible,
the whole parameter is variable
even if only one of the pair elements is variable.
This way you may miss optimizations.
-}
data T p a
   = Constant a
     -- ^ The value does not depend on the parameter @p@.
   | Variable (p -> a)
     -- ^ The value is computed from the parameter @p@.


-- | Turn a parameter into a plain function of the parameter @p@,
-- forgetting whether it was constant.
get :: T p a -> (p -> a)
get param =
   case param of
      -- a constant ignores whatever parameter it is given
      Constant c -> \_ -> c
      Variable f -> f


{- |
The call @value param v@ requires
that @v@ represents the same value as @valueTupleOf (get param p)@ for some @p@.
However, @v@ might be the result of a load operation
and @param@ might be a constant.
In this case it is more efficient to use @valueTupleOf (get param undefined)@,
since the constant is translated to an LLVM constant
that allows for certain optimizations.

This is the main function for taking advantage of a constant parameter
in low-level implementations.
For simplicity we do not omit constant parameters from the parameter struct,
since this would mean constructing types at runtime and might become ugly.
Instead we just use 'value' at the appropriate places in LLVM code
to check whether a parameter is constant,
and in that case we ignore the parameter from the struct.
In many cases there will be no speed benefit,
because the parameter will be loaded into a register anyway.
It can only lead to a speed-up if subsequent optimizations
can precompute constant expressions.
Another example is 'drop', where a loop with a constant loop count can be generated.
For small loop counts and simple loop bodies the loop might get unrolled.
-}
value ::
   LLVM.MakeValueTuple tuple value =>
   T p tuple -> value -> value
value param v =
   case param of
      -- constant: emit it directly as an LLVM constant, ignoring @v@
      Constant c -> LLVM.valueTupleOf c
      -- variable: only the supplied value is meaningful
      Variable _ -> v


{- |
Composition with @.@ can be used for fetching a parameter from a super-parameter.
A constant on the left yields that constant,
and a variable applied to a constant is folded into a constant,
so constness is preserved wherever possible.
-}
instance Cat.Category T where
   id = Variable (\x -> x)
   outer . inner =
      case outer of
         -- the outer parameter ignores its input, so 'inner' is not needed
         Constant c -> Constant c
         Variable f ->
            case inner of
               Constant a -> Constant (f a)
               Variable g -> Variable (\x -> f (g x))

{- |
@arr@ is useful for lifting parameter selectors to our parameter type
without relying on the constructor.

'first' always yields a 'Variable' parameter,
because the second component of the result is taken from the input.
@&&&@ and @***@ are defined explicitly,
so that combining two constant parameters yields a constant parameter.
The class defaults are built from 'first' and @arr@,
and would always produce a 'Variable'.
-}
instance Arr.Arrow T where
   arr = Variable
   first f = Variable (mapFst (get f))
   -- constant if and only if both operands are constant
   f &&& g = liftA2 (,) f g
   -- two constants ignore the input pair entirely
   Constant a *** Constant b = Constant (a, b)
   -- lazy pattern, matching the laziness of the default definition
   f *** g = Variable (\ ~(x, y) -> (get f x, get g y))



{- |
Maps over the parameter value; a constant stays constant.
Useful for splitting @T p (a,b)@ into @T p a@ and @T p b@
using @fmap fst@ and @fmap snd@.
-}
instance Functor (T p) where
   fmap f param =
      case param of
         Constant a -> Constant (f a)
         Variable g -> Variable (\x -> f (g x))

{- |
Useful for combining @T p a@ and @T p b@ into @T p (a,b)@
using @liftA2 (,)@.
However, we do not recommend doing so,
because the result parameter can only be constant
if both operands are constant.
-}
instance App.Applicative (T p) where
   pure = Constant
   pf <*> px =
      case (pf, px) of
         (Constant f, Constant a) -> Constant (f a)
         -- as soon as one side is variable, the result must be variable
         _ -> Variable (\p -> get pf p (get px p))

-- | Binding a constant continues directly with its value;
-- binding a variable defers the continuation until the parameter is supplied.
instance Monad (T p) where
   return = pure
   m >>= k =
      case m of
         Constant a -> k a
         Variable g -> Variable (\p -> get (k (g p)) p)


-- | Element-wise lifting; a result is constant
-- exactly when all operands are constant.
instance Additive.C a => Additive.C (T p a) where
   zero     = Constant zero
   negate x = fmap negate x
   x + y    = liftA2 (+) x y
   x - y    = liftA2 (-) x y

-- | Element-wise lifting; integer literals become constant parameters.
instance Ring.C a => Ring.C (T p a) where
   one           = Constant one
   x * y         = liftA2 (*) x y
   x ^ n         = fmap (^ n) x
   fromInteger n = Constant (fromInteger n)

-- | Element-wise lifting; rational literals become constant parameters.
instance Field.C a => Field.C (T p a) where
   x / y           = liftA2 (/) x y
   recip x         = fmap recip x
   fromRational' r = Constant (fromRational' r)

-- | Roots and fractional powers are applied to the parameter value.
instance Algebraic.C a => Algebraic.C (T p a) where
   x ^/ r   = fmap (^/ r) x
   sqrt x   = fmap sqrt x
   root n x = fmap (Algebraic.root n) x

-- | Transcendental functions are applied to the parameter value;
-- 'pi' is a constant parameter.
instance Trans.C a => Trans.C (T p a) where
   pi          = Constant pi
   exp x       = fmap exp x
   log x       = fmap log x
   logBase b x = liftA2 logBase b x
   x ** y      = liftA2 (**) x y
   sin x       = fmap sin x
   tan x       = fmap tan x
   cos x       = fmap cos x
   asin x      = fmap asin x
   atan x      = fmap atan x
   acos x      = fmap acos x
   sinh x      = fmap sinh x
   tanh x      = fmap tanh x
   cosh x      = fmap cosh x
   asinh x     = fmap asinh x
   atanh x     = fmap atanh x
   acosh x     = fmap acosh x


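{-
Instances for Haskell98 numeric type classes
that are useful when working together with other libraries on fixed types.
In Haskell98, 'Num' requires 'Eq' and 'Show' as superclasses,
which is why the stub 'Eq' and 'Show' instances below exist.
-}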
-- | Stub instance: parameters cannot be compared in general,
-- because a 'Variable' holds a function.
-- '==' is defined as 'error', so any comparison fails at runtime;
-- it exists only to satisfy the superclass requirement named in the message.
instance P.Eq a => P.Eq (T p a) where
   (==) = error "Synthesizer.LLVM.Parameter: Num instance requires Eq but we cannot define that"

-- | Placeholder instance: every parameter, constant or variable,
-- is shown as the module name.
instance P.Show a => P.Show (T p a) where
   show = const "Synthesizer.LLVM.Parameter"

-- | Haskell98 counterpart of the NumericPrelude instances:
-- operations are lifted element-wise and literals become constant parameters.
instance P.Num a => P.Num (T p a) where
   x + y         = liftA2 (P.+) x y
   x - y         = liftA2 (P.-) x y
   x * y         = liftA2 (P.*) x y
   negate x      = fmap P.negate x
   abs x         = fmap P.abs x
   signum x      = fmap P.signum x
   fromInteger n = Constant (P.fromInteger n)

-- | Haskell98 counterpart of the 'Field.C' instance:
-- division is lifted element-wise and rational literals become constants.
-- 'recip' lifts the element type's own 'P.recip', as the 'Field.C' instance does.
-- Without it, the class default @1 / x@ would be used.
instance P.Fractional a => P.Fractional (T p a) where
   (/) = liftA2 (P./)
   recip = fmap P.recip
   fromRational = pure . P.fromRational