{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ExistentialQuantification #-}
{- |
Parameters for parameterized LLVM code
that keep track of whether a value is constant,
so that constants can be turned into LLVM constants
and take part in constant folding.
-}
module LLVM.DSL.Parameter (
   T,
   get,
   ($#),
   valueTuple,
   multiValue,
   withValue,
   withMulti,
   wordInt,
   Tuple(..),
   withTuple,
   withTuple1,
   withTuple2,

   -- * for implementation of new processes
   with,
   Tunnel(..),
   tunnel
   ) where

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Marshal as Marshal

import qualified Algebra.Transcendental as Trans
import qualified Algebra.Algebraic as Algebraic
import qualified Algebra.Field as Field
import qualified Algebra.Ring as Ring
import qualified Algebra.Additive as Additive

import qualified Control.Category as Cat
import qualified Control.Arrow as Arr
import qualified Control.Applicative as App
import qualified Control.Functor.HT as FuncHT
import Control.Applicative (pure, liftA2)

import Data.Tuple.HT (mapFst, mapPair, mapTriple)
import Data.Word (Word)

import Prelude2010
import Prelude ()

{- |
This data type is for parameters of parameterized LLVM code.
It is better than using plain functions of type @p -> a@
since it allows for numeric instances
and we can make explicit
whether a parameter is constant.

We recommend using parameters for atomic types.
Although a parameter of type @T p (a,b)@ is possible,
it means that the whole parameter is variable
if only one of the pair elements is variable.
This way you may miss opportunities for constant folding.
-}
data T p a =
     -- | The same value regardless of the parameter source @p@.
     Constant a
     -- | A value that is computed from the parameter source @p@.
   | Variable (p -> a)

{- |
View a parameter as a plain function of its source.
A constant parameter ignores the source.
-}
get :: T p a -> (p -> a)
get param =
   case param of
      Variable f -> f
      Constant a -> const a

{- |
The call @valueTuple param v@ requires
that @v@ represents the same value as @Tuple.valueOf (get param p)@ for some @p@.
However @v@ might be the result of a load operation
and @param@ might be a constant.
In this case it is more efficient to use @Tuple.valueOf (get param undefined)@
since the constant is translated to an LLVM constant
that allows for certain optimizations.

This is the main function for taking advantage of a constant parameter
in low-level implementations.
For simplicity we do not omit constant parameters in the parameter struct
since this would mean to construct types at runtime and might become ugly.
Instead we just check using 'valueTuple' at the appropriate places in LLVM code
whether a parameter is constant
and ignore the parameter from the struct in this case.
In many cases there will be no speed benefit
because the parameter will be loaded to a register anyway.
It can only lead to speed-up if subsequent optimizations
can precompute constant expressions.
Another example is @drop@ where a loop with constant loop count can be generated.
For small loop counts and simple loop bodies the loop might get unrolled.

If @param@ is constant, the result is built with 'Tuple.valueOf'
and @v@ is ignored; otherwise @v@ is returned unchanged.
-}
valueTuple ::
   (Tuple.Value tuple, Tuple.ValueOf tuple ~ value) =>
   T p tuple -> value -> value
valueTuple = genericValue Tuple.valueOf

{- |
Counterpart of 'valueTuple' for 'MultiValue.T'.
For a constant parameter, the value is rebuilt from the constant
with 'MultiValue.cons' and the supplied value is ignored.
For a variable parameter, the supplied value is returned unchanged.
-}
multiValue ::
   (MultiValue.C a) =>
   T p a -> MultiValue.T a -> MultiValue.T a
multiValue param v =
   case param of
      Variable _ -> v
      Constant a -> MultiValue.cons a

{- |
Shared implementation of 'valueTuple' and 'multiValue'.
A constant parameter is converted with the given function;
for a variable parameter the supplied value is passed through.
-}
genericValue ::
   (a -> value) ->
   T p a -> value -> value
genericValue cons (Constant a) _ = cons a
genericValue _ (Variable _) v = v

{- |
This function provides specialised variants of 'get' and 'valueTuple'
that use the unit type for constants
and thus save space in parameter structures.

The continuation receives two functions.
The first extracts the value to be stored from the parameter source.
The second turns the stored LLVM value back into the parameter value.
For a constant, @()@ is stored,
and the second function returns @Tuple.valueOf@ of the constant.
For a variable, the parameter's function is used for extraction
and the stored value is returned unchanged.
-}
{-# INLINE withValue #-}
withValue ::
   (Marshal.C tuple, Tuple.ValueOf tuple ~ value) =>
   T p tuple ->
   (forall parameters.
    (Marshal.C parameters) =>
    (p -> parameters) ->
    (Tuple.ValueOf parameters -> value) ->
    a) ->
   a
withValue (Constant a) f = f (const ()) (\() -> Tuple.valueOf a)
withValue (Variable v) f = f v id

{- |
Counterpart of 'withValue' for 'MultiValue.T':
this is 'with' using 'MultiValue.cons' to turn a constant into a value.
-}
{-# INLINE withMulti #-}
withMulti ::
   (Marshal.MV b) =>
   T p b ->
   (forall parameters.
    (Marshal.MV parameters) =>
    (p -> parameters) ->
    (MultiValue.T parameters -> MultiValue.T b) ->
    a) ->
   a
withMulti = with MultiValue.cons

{- |
Generalisation of 'withMulti' where the first argument
converts a constant into a 'MultiValue.T'.

For a constant parameter, the continuation gets @const ()@ as extractor,
so only @()@ needs to be stored.
It also gets a reconstruction function
that ignores its argument and returns the converted constant.
For a variable parameter, it gets the parameter's function and 'id'.
-}
{-# INLINE with #-}
with ::
   (Marshal.MV b) =>
   (b -> MultiValue.T b) ->
   T p b ->
   (forall parameters.
    (Marshal.MV parameters) =>
    (p -> parameters) ->
    (MultiValue.T parameters -> MultiValue.T b) ->
    a) ->
   a
with cons p f =
   case p of
      Constant b -> f (const ()) (\_ -> cons b)
      Variable v -> f v id

{- |
A parameter split into an extractor from the parameter source
and a function that turns the extracted value into the final 'MultiValue.T'.
The type of the extracted value is hidden,
but it is required to be marshallable.
-}
data Tunnel p a =
   forall s. (Marshal.MV s) =>
      Tunnel (p -> s) (MultiValue.T s -> MultiValue.T a)

{- |
Package a parameter as a 'Tunnel'.
A constant is tunneled as @()@ and rebuilt with the given constructor.
A variable passes its value through unchanged.
-}
tunnel :: (Marshal.MV a) => (a -> MultiValue.T a) -> T p a -> Tunnel p a
tunnel cons (Constant b) = Tunnel (const ()) (const (cons b))
tunnel _ (Variable v) = Tunnel v id

{- |
Convert an 'Int' parameter into a 'Word' parameter using 'fromIntegral'.
Constancy is preserved.
Negative values wrap around as usual for 'fromIntegral' into 'Word'.
-}
wordInt :: T p Int -> T p Word
wordInt (Constant n) = Constant (fromIntegral n)
wordInt (Variable f) = Variable (fromIntegral . f)

infixl 0 $#

{- |
Apply a function that expects a parameter to a plain value,
which is passed as a constant parameter.
-}
($#) :: (T p a -> b) -> (a -> b)
f $# a = f (Constant a)

{- |
(Nested) tuples of parameters that can be obtained
by splitting a single tuple-valued parameter.
'Source' is the common parameter source
and 'Composed' is the tuple type of the component values.
-}
class Tuple t where
   type Source t :: *
   type Composed t :: *
   -- | Split a tuple-valued parameter into a tuple of parameters.
   decompose :: T (Source t) (Composed t) -> t

{- |
A single parameter is its own decomposition.
-}
instance Tuple (T p a) where
   type Source (T p a) = p
   type Composed (T p a) = a
   decompose param = param

{- |
A pair of parameters is obtained by splitting a pair-valued parameter
with 'FuncHT.unzip' and decomposing both halves.
-}
instance (Tuple a, Tuple b, Source a ~ Source b) => Tuple (a,b) where
   type Composed (a,b) = (Composed a, Composed b)
   type Source (a,b) = Source a
   decompose = mapPair (decompose, decompose) . FuncHT.unzip

{- |
A triple of parameters is obtained by splitting a triple-valued parameter
with 'FuncHT.unzip3' and decomposing all three parts.
-}
instance
   (Tuple a, Tuple b, Tuple c, Source a ~ Source b, Source b ~ Source c) =>
      Tuple (a,b,c) where
   type Composed (a,b,c) = (Composed a, Composed b, Composed c)
   type Source (a,b,c) = Source a
   decompose = mapTriple (decompose, decompose, decompose) . FuncHT.unzip3

{- |
Provide all elements of a nested tuple as separate parameters.

The function receives the identity parameter 'Cat.id'
decomposed into its components via 'decompose'.

If you do not use one of the tuple elements,
you will get a type error like
@Couldn't match type `Param.Composed t0' with `Int'@.
The problem is that the type checker cannot infer
that an element is a @Parameter.T@ if it remains unused.
-}
withTuple ::
   (Tuple tuple, Source tuple ~ p, Composed tuple ~ p) =>
   (tuple -> f p) -> f p
withTuple f = idFromFunctor $ f . decompose

{- |
Apply a function to the identity parameter 'Cat.id'.
-}
idFromFunctor :: (T p p -> f p) -> f p
idFromFunctor = ($ Cat.id)

{- |
Variant of 'withTuple' for a result type constructor
that takes one more type argument.
-}
withTuple1 ::
   (Tuple tuple, Source tuple ~ p, Composed tuple ~ p) =>
   (tuple -> f p a) -> f p a
withTuple1 f = idFromFunctor1 (\params -> f (decompose params))

{- |
Apply a function to the identity parameter 'Cat.id'
(variant with one extra type argument).
-}
idFromFunctor1 :: (T p p -> f p a) -> f p a
idFromFunctor1 = ($ Cat.id)

{- |
Variant of 'withTuple' for a result type constructor
that takes two more type arguments.
-}
withTuple2 ::
   (Tuple tuple, Source tuple ~ p, Composed tuple ~ p) =>
   (tuple -> f p a b) -> f p a b
withTuple2 f = idFromFunctor2 (\params -> f (decompose params))

{- |
Apply a function to the identity parameter 'Cat.id'
(variant with two extra type arguments).
-}
idFromFunctor2 :: (T p p -> f p a b) -> f p a b
idFromFunctor2 = ($ Cat.id)

{- |
@.@ can be used for fetching a parameter from a super-parameter.
The composition is constant if the outer parameter is constant
or if the inner parameter is constant;
only the composition of two variables is variable.
-}
instance Cat.Category T where
   id = Variable id
   Constant f . _ = Constant f
   Variable f . Constant a = Constant (f a)
   Variable f . Variable g = Variable (f . g)

{- |
@arr@ is useful for lifting parameter selectors to our parameter type
without relying on the constructor.
'Arr.first' always yields a variable parameter,
because the second component is taken from the input.
-}
instance Arr.Arrow T where
   arr = Variable
   first f = Variable (mapFst (get f))

{- |
Useful for splitting @T p (a,b)@ into @T p a@ and @T p b@
using @fmap fst@ and @fmap snd@.
Mapping preserves whether a parameter is constant.
-}
instance Functor (T p) where
   fmap f (Constant a) = Constant (f a)
   fmap f (Variable g) = Variable (f . g)

{- |
Useful for combining @T p a@ and @T p b@ to @T p (a,b)@
using @liftA2 (,)@.
However, we do not recommend doing so
because the result parameter can only be constant
if both operands are constant.
-}
instance App.Applicative (T p) where
   pure a = Constant a
   Constant f <*> Constant a = Constant (f a)
   -- at least one side depends on the source, so the result does too
   f <*> a = Variable (\p -> get f p (get a p))

{- |
Binding a constant passes its value to the continuation directly.
Binding a variable yields a variable that reads the continuation's
resulting parameter at the same source.
-}
instance Monad (T p) where
   return = pure
   m >>= k =
      case m of
         Constant x -> k x
         Variable g -> Variable (\p -> get (k (g p)) p)

{- |
Arithmetic lifted pointwise;
a result is constant exactly when all operands are constant.
-}
instance Num a => Num (T p a) where
   fromInteger n = pure (fromInteger n)
   x + y = liftA2 (+) x y
   x - y = liftA2 (-) x y
   x * y = liftA2 (*) x y
   negate x = fmap negate x
   abs x = fmap abs x
   signum x = fmap signum x

{- |
Division lifted pointwise; literals become constant parameters.
-}
instance Fractional a => Fractional (T p a) where
   fromRational r = pure (fromRational r)
   x / y = liftA2 (/) x y

{- |
Elementary functions lifted pointwise;
a result is constant exactly when all operands are constant.
-}
instance Floating a => Floating (T p a) where
   pi = pure pi
   exp x = fmap exp x
   log x = fmap log x
   sqrt x = fmap sqrt x
   x ** y = liftA2 (**) x y
   logBase b x = liftA2 logBase b x
   sin x = fmap sin x
   cos x = fmap cos x
   tan x = fmap tan x
   asin x = fmap asin x
   acos x = fmap acos x
   atan x = fmap atan x
   sinh x = fmap sinh x
   cosh x = fmap cosh x
   tanh x = fmap tanh x
   asinh x = fmap asinh x
   acosh x = fmap acosh x
   atanh x = fmap atanh x

{- |
Additive structure lifted pointwise (numeric-prelude flavour).
-}
instance Additive.C a => Additive.C (T p a) where
   zero = pure Additive.zero
   x + y = liftA2 (Additive.+) x y
   x - y = liftA2 (Additive.-) x y
   negate x = fmap Additive.negate x

{- |
Ring structure lifted pointwise (numeric-prelude flavour).
The exponent of @^@ is a plain value, not a parameter.
-}
instance Ring.C a => Ring.C (T p a) where
   one = pure Ring.one
   fromInteger n = pure (Ring.fromInteger n)
   x * y = liftA2 (Ring.*) x y
   x ^ n = fmap (\y -> y Ring.^ n) x

{- |
Field structure lifted pointwise (numeric-prelude flavour).
-}
instance Field.C a => Field.C (T p a) where
   fromRational' r = pure (Field.fromRational' r)
   x / y = liftA2 (Field./) x y
   recip x = fmap Field.recip x

{- |
Algebraic functions lifted pointwise (numeric-prelude flavour).
Exponents and root orders are plain values, not parameters.
-}
instance Algebraic.C a => Algebraic.C (T p a) where
   sqrt x = fmap Algebraic.sqrt x
   root n x = fmap (Algebraic.root n) x
   x ^/ r = fmap (\y -> y Algebraic.^/ r) x

{- |
Transcendental functions lifted pointwise (numeric-prelude flavour);
a result is constant exactly when all operands are constant.
-}
instance Trans.C a => Trans.C (T p a) where
   pi = pure Trans.pi
   exp x = fmap Trans.exp x
   log x = fmap Trans.log x
   logBase b x = liftA2 Trans.logBase b x
   x ** y = liftA2 (Trans.**) x y
   sin x = fmap Trans.sin x
   cos x = fmap Trans.cos x
   tan x = fmap Trans.tan x
   asin x = fmap Trans.asin x
   acos x = fmap Trans.acos x
   atan x = fmap Trans.atan x
   sinh x = fmap Trans.sinh x
   cosh x = fmap Trans.cosh x
   tanh x = fmap Trans.tanh x
   asinh x = fmap Trans.asinh x
   acosh x = fmap Trans.acosh x
   atanh x = fmap Trans.atanh x