{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ExistentialQuantification #-}
module Data.Array.Knead.Parameter where

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.Memory as Memory
import Foreign.Storable.Tuple ()
import Foreign.Storable (Storable, )

import qualified Control.Category as Cat
import qualified Control.Arrow as Arr
import qualified Control.Applicative as App
import Control.Applicative (pure, liftA2, )

import Data.Tuple.HT (mapFst, )
import Data.Word (Word32, )


{- |
This data type is for parameters of parameterized signal generators and causal processes.
It is better than using plain functions of type @p -> a@
since it allows for numeric instances
and we can make explicit,
whether a parameter is constant.

We recommend using parameters for atomic types.
Although a parameter of type @T p (a,b)@ is possible,
it means that the whole parameter is variable
if only one of the pair elements is variable.
This way you may miss optimizations.
-}
data T p a
   = Constant a        -- ^ a value that is the same for every parameter record
   | Variable (p -> a) -- ^ a value computed from the parameter record @p@


-- | Evaluate a parameter for a concrete parameter record.
-- A 'Constant' ignores the record; a 'Variable' applies its selector to it.
get :: T p a -> (p -> a)
get param =
   case param of
      Constant a -> \_ -> a
      Variable f -> f


{- |
The call @valueTuple param v@ requires
that @v@ represents the same value as @valueTupleOf (get param p)@ for some @p@.
However @v@ might be the result of a load operation
and @param@ might be a constant.
In this case it is more efficient to use @valueTupleOf (get param undefined)@
since the constant is translated to an LLVM constant
that allows for certain optimizations.

This is the main function for taking advantage of a constant parameter
in low-level implementations.
For simplicity we do not omit constant parameters in the parameter struct
since this would mean to construct types at runtime and might become ugly.
Instead we just check using 'valueTuple' at the corresponding places in LLVM code
whether a parameter is constant
and ignore the parameter from the struct in this case.
In many cases there will be no speed benefit
because the parameter will be loaded to a register anyway.
It can only lead to speed-up if subsequent optimizations
can precompute constant expressions.
Another example is 'drop' where a loop with constant loop count can be generated.
For small loop counts and simple loop bodies the loop might get unrolled.
-}
valueTuple ::
   (Class.MakeValueTuple tuple, Class.ValueTuple tuple ~ value) =>
   T p tuple -> value -> value
-- Constant parameters are rebuilt with 'Class.valueTupleOf';
-- for variable parameters the supplied (e.g. loaded) value is returned as-is.
valueTuple param loaded = genericValue Class.valueTupleOf param loaded

-- | 'MultiValue.T' counterpart of 'valueTuple'.
-- For a 'Constant' parameter the result is rebuilt from the Haskell value
-- with 'MultiValue.cons' and the supplied value is ignored;
-- for a 'Variable' parameter the supplied value is returned unchanged.
multiValue ::
   (MultiValue.C a) =>
   T p a -> MultiValue.T a -> MultiValue.T a
multiValue param loaded = genericValue MultiValue.cons param loaded

-- | Shared implementation of 'valueTuple' and 'multiValue'.
-- For @Constant a@ the result is @cons a@ (the third argument is not used);
-- for a 'Variable' the third argument is returned as-is.
genericValue ::
   (a -> value) ->
   T p a -> value -> value
genericValue cons (Constant a) _ = cons a
genericValue _ (Variable _) v = v


{- |
This function provides specialised variants of 'get' and 'valueTuple'
that use the unit type for constants
and thus save space in parameter structures.
-}
withTuple ::
   (Storable tuple, Class.MakeValueTuple tuple,
    Class.ValueTuple tuple ~ value, Memory.C value) =>
   T p tuple ->
   (forall parameters.
    (Storable parameters,
     Class.MakeValueTuple parameters,
     Memory.C (Class.ValueTuple parameters)) =>
    (p -> parameters) ->
    (Class.ValueTuple parameters -> value) ->
    a) ->
   a
withTuple param k =
   case param of
      -- Store only unit; the value is rebuilt from the Haskell constant.
      Constant a -> k (const ()) (\() -> Class.valueTupleOf a)
      -- Store the parameter itself and use the (loaded) value directly.
      Variable v -> k v id

-- | 'with' specialised to 'MultiValue.cons' as the constant builder.
withMulti ::
   (Storable b, MultiValueMemory.C b) =>
   T p b ->
   (forall parameters.
    (Storable parameters,
     MultiValueMemory.C parameters) =>
    (p -> parameters) ->
    (MultiValue.T parameters -> MultiValue.T b) ->
    a) ->
   a
withMulti param k = with MultiValue.cons param k

-- | Continuation-passing form of 'tunnel'.
-- The continuation receives a selector for the value to store in the
-- parameter structure and a function that turns the stored value back into
-- the parameter's value. For a 'Constant' the stored value is unit and the
-- result is rebuilt with the supplied @cons@.
with ::
   (Storable b, MultiValueMemory.C b) =>
   (b -> MultiValue.T b) ->
   T p b ->
   (forall parameters.
    (Storable parameters,
     MultiValueMemory.C parameters) =>
    (p -> parameters) ->
    (MultiValue.T parameters -> MultiValue.T b) ->
    a) ->
   a
with cons p f =
   case tunnel cons p of
      Tunnel fetch build -> f fetch build


-- | A parameter split into a selector for the stored representation @t@
-- (hidden existentially) and a function from the stored value to the
-- parameter's value.
data Tunnel p a =
   forall t. (Storable t, MultiValueMemory.C t) =>
      Tunnel
         (p -> t)
         (MultiValue.T t -> MultiValue.T a)

-- | Package a parameter as a 'Tunnel'.
-- A 'Constant' stores unit and rebuilds its value with @cons@,
-- ignoring the stored value; a 'Variable' stores the selected value itself.
tunnel ::
   (Storable a, MultiValueMemory.C a) =>
   (a -> MultiValue.T a) -> T p a -> Tunnel p a
tunnel cons (Constant b) = Tunnel (const ()) (const (cons b))
tunnel _ (Variable v) = Tunnel v id


-- | Convert an 'Int' parameter to 'Word32' via 'fromIntegral',
-- keeping a constant parameter constant.
-- Note: 'fromIntegral' from 'Int' to 'Word32' wraps modulo 2^32,
-- so negative or too large values are not rejected.
word32 :: T p Int -> T p Word32
word32 param = fmap toWord32 param
   where
      toWord32 :: Int -> Word32
      toWord32 = fromIntegral


infixl 0 $#

-- | Apply a function that expects a parameter to a plain value,
-- wrapping the value as a constant parameter with 'pure'.
($#) :: (T p a -> b) -> (a -> b)
f $# a = f (pure a)


{- |
@.@ can be used for fetching a parameter from a super-parameter.
-}
instance Cat.Category T where
   id = Variable (\p -> p)
   -- A constant outer parameter never looks at the inner one;
   -- a variable outer parameter applied to a constant stays constant.
   outer . inner =
      case outer of
         Constant c -> Constant c
         Variable h ->
            case inner of
               Constant a -> Constant (h a)
               Variable k -> Variable (\p -> h (k p))

{- |
@arr@ is useful for lifting parameter selectors to our parameter type
without relying on the constructor.
-}
instance Arr.Arrow T where
   arr selector = Variable selector
   -- The result depends on the second tuple component, so it is always
   -- variable; the first component is mapped with the evaluated parameter.
   first param = Variable (mapFst (get param))



{- |
Useful for splitting @T p (a,b)@ into @T p a@ and @T p b@
using @fmap fst@ and @fmap snd@.
-}
instance Functor (T p) where
   -- Mapping preserves constness.
   fmap f param =
      case param of
         Constant a -> Constant (f a)
         Variable g -> Variable (\p -> f (g p))

{- |
Useful for combining @T p a@ and @T p b@ to @T p (a,b)@
using @liftA2 (,)@.
However, we do not recommend doing so
because the result parameter can only be constant
if both operands are constant.
-}
instance App.Applicative (T p) where
   pure = Constant
   -- Only two constants combine to a constant; otherwise evaluate per record.
   pf <*> pa =
      case (pf, pa) of
         (Constant f, Constant a) -> Constant (f a)
         _ -> Variable (\p -> get pf p (get pa p))

-- | Binding a constant stays as constant as the continuation's result;
-- binding a variable yields a variable parameter that feeds the same
-- record both to the first parameter and to the continuation's result.
instance Monad (T p) where
   return = pure
   param >>= k =
      case param of
         Constant x -> k x
         Variable g -> Variable (\p -> get (k (g p)) p)


-- | Arithmetic lifted pointwise; results are constant
-- exactly when all operands are constant, and literals are constants.
instance Num a => Num (T p a) where
   x + y = liftA2 (+) x y
   x - y = liftA2 (-) x y
   x * y = liftA2 (*) x y
   negate x = fmap negate x
   abs x = fmap abs x
   signum x = fmap signum x
   fromInteger n = pure (fromInteger n)

-- | Division lifted pointwise; rational literals are constants.
instance Fractional a => Fractional (T p a) where
   x / y = liftA2 (/) x y
   fromRational r = pure (fromRational r)