{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
module Synthesizer.LLVM.Simple.Value (
   T, decons,
   twoPi, square, sqrt,
   max, min, limit, fraction,

   (%==), (%/=), (%<), (%<=), (%>), (%>=), not,
   (%&&), (%||),
   (?), (??),

   lift0, lift1, lift2, lift3,
   unlift0, unlift1, unlift2, unlift3, unlift4, unlift5,
   constantValue, constant,
   fromInteger', fromRational',

   Flatten(flattenCode, unfoldCode), Registers,
   flatten, unfold,
   flattenCodeTraversable, unfoldCodeTraversable,
   flattenFunction,
   ) where

import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Arithmetic as A

import LLVM.Util.Loop (Phi, )
import LLVM.Core (CodeGenFunction, )
import qualified LLVM.Core as LLVM

import qualified Synthesizer.Basic.Phase as Phase

import qualified Data.Vault.Lazy as Vault
import qualified Control.Monad.Trans.Class as MT
import qualified Control.Monad.Trans.State as MS
import Control.Monad (liftM2, liftM3, )
import Control.Applicative (Applicative, pure, (<*>), )
import Control.Functor.HT (unzip, unzip3, )

import qualified Synthesizer.LLVM.Frame.Stereo as Stereo

-- import qualified Algebra.NormedSpace.Maximum   as NormedMax
import qualified Algebra.NormedSpace.Euclidean as NormedEuc
import qualified Algebra.NormedSpace.Sum       as NormedSum

import qualified Algebra.Transcendental as Trans
import qualified Algebra.Algebraic as Algebraic
import qualified Algebra.RealRing as RealRing
import qualified Algebra.Absolute as Absolute
import qualified Algebra.Module as Module
import qualified Algebra.Field as Field
import qualified Algebra.Ring as Ring
import qualified Algebra.Additive as Additive

import qualified Number.Complex as Complex

import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold

import qualified System.Unsafe as Unsafe

import qualified Prelude as P
import NumericPrelude.Numeric hiding (pi, sqrt, fromRational', fraction, )
import NumericPrelude.Base hiding (min, max, unzip, unzip3, not, )


{-
The @r@ type parameter must be hidden and forall-quantified
because otherwise we would need an impossible type
where we have to quantify for @r@ and @t@ in different scopes
while having a class constraint that involves both of them.

> osci ::
>    (RealRing.C (Value.T r t),
>     IsFirstClass t, IsSized t size, IsFloating t,
>     IsPrimitive t, IsConst t) =>
>    (forall r. Wave.T (Value.T r t) (Value.T r y)) ->
>    t -> t -> T (Value y)

-}
{- |
A value that is produced by a code generating action.
The wrapped action has to work for every @r@,
which keeps @r@ out of the type of the value
(see the comment above for the motivation).
-}
newtype T a = Cons {code :: forall r. Compute r a}

{- |
Turn a value into a plain code generating action.
The action starts with an empty vault,
that is, no results of earlier runs are reused.
-}
decons :: T a -> (forall r. LLVM.CodeGenFunction r a)
decons (Cons compute) =
   MS.evalStateT compute Vault.empty

{- |
Mapping yields a new value with its own vault key (via 'consUnique').
-}
instance Functor T where
   fmap f (Cons compute) = consUnique (fmap f compute)

{- |
'pure' is 'constantValue';
'<*>' combines the underlying actions and
registers the result under a fresh vault key.
-}
instance Applicative T where
   pure a = constantValue a
   Cons computeF <*> Cons computeX = consUnique (computeF <*> computeX)


{- |
Code generation in 'LLVM.CodeGenFunction'
with a 'Vault.Vault' as state.
The vault stores results that were already computed (see 'consKey').
-}
type Compute r =
   MS.StateT Vault.Vault (LLVM.CodeGenFunction r)

{- |
Wrap a computation in a 'T' that owns a freshly created vault key.
The key is what allows 'consKey' to recognize
that the very same value is requested a second time.

The key is created via 'Unsafe.performIO'
(presumably a wrapper around @unsafePerformIO@ - TODO confirm).
The usual precaution for functions that call @unsafePerformIO@
is to forbid inlining,
since an inlined call may perform the IO action more than once.
Here that would mean different keys for what should be one value
and thus loss of sharing.
-}
{-# NOINLINE consUnique #-}
consUnique :: (forall r. Compute r a) -> T a
consUnique code0 =
   Unsafe.performIO $
   fmap (consKey code0) Vault.newKey

{- |
Build a value that caches its result in the vault under the given key.
If the vault already contains an entry for the key, that entry is returned.
Otherwise the computation is run once and its result is stored.
-}
consKey :: (forall r. Compute r a) -> Vault.Key a -> T a
consKey code0 key =
   Cons (MS.gets (Vault.lookup key) >>= \cached ->
      case cached of
         Nothing -> do
            fresh <- code0
            MS.modify (Vault.insert key fresh)
            return fresh
         Just known -> return known)

{- |
We do not require a numeric prelude superclass.
This way LLVM-only types such as vectors can be instances, too.
-}
instance (A.Additive a) => Additive.C (T a) where
   zero = constantValue A.zero
   negate x = lift1 A.neg x
   x + y = lift2 A.add x y
   x - y = lift2 A.sub x y

{- |
Ring operations are delegated to "LLVM.Extra.Arithmetic".
Integer literals become constants via 'fromInteger''.
-}
instance (A.PseudoRing a, A.IntegerConstant a) =>
      Ring.C (T a) where
   one = constantValue A.one
   fromInteger n = fromInteger' n
   x * y = lift2 A.mul x y

{-
This instance is sufficient for Module here.
The difference from Module instances on Haskell tuples is
that LLVM vectors cannot be nested.
-}
-- | Scaling is delegated to 'A.scale'; the scalar type is @A.Scalar v@.
instance (a ~ A.Scalar v, A.PseudoModule v, A.IntegerConstant a) =>
      Module.C (T a) (T v) where
   a *> v = lift2 A.scale a v

{- |
Only the direction from Haskell numbers to code is supported.
'fromEnum' would need the value at code generation time
and thus always fails with an error.
-}
instance (A.Additive a, A.IntegerConstant a) => Enum (T a) where
   toEnum n = constantValue (A.fromInteger' (fromIntegral n))
   fromEnum _ = error "CodeGenFunction Value: fromEnum"
   succ x = x + constantValue A.one
   pred x = x - constantValue A.one

{-
instance (IsArithmetic a, Cmp a b, Num a, IsConst a) => Real (T a) where
   toRational _ = error "CodeGenFunction Value: toRational"

instance (Cmp a b, Num a, IsConst a, IsInteger a) => Integral (T a) where
   quot = lift2 idiv
   rem  = lift2 irem
   quotRem x y = (quot x y, rem x y)
   toInteger _ = error "CodeGenFunction Value: toInteger"
-}

{- |
Division is delegated to 'A.fdiv'.
The right-hand side of 'Field.fromRational''
refers to the function of the same name defined in this module.
-}
instance (A.Field a, A.RationalConstant a) => Field.C (T a) where
   x / y = lift2 A.fdiv x y
   fromRational' r = fromRational' (Field.fromRational' r)

{-
instance (Cmp a b, Fractional a, IsConst a, IsFloating a) => RealFrac (T a) where
   properFraction _ = error "CodeGenFunction Value: properFraction"
-}

{- |
'Algebraic.sqrt' maps to 'A.sqrt',
whereas general roots and rational powers are expressed using 'A.pow'.
-}
instance (A.Transcendental a, A.RationalConstant a) => Algebraic.C (T a) where
   sqrt x = lift1 A.sqrt x
   root n x =
      lift2 A.pow x expo
      where
         -- reciprocal of the root degree
         expo = 1 / fromInteger n
   x ^/ r = lift2 A.pow x (Field.fromRational' r)

{- |
Transcendental functions are delegated to "LLVM.Extra.Arithmetic".
The inverse trigonometric functions are not available
and fail with an error as soon as they are called.
-}
instance (A.Transcendental a, A.RationalConstant a) => Trans.C (T a) where
   pi = lift0 A.pi

   exp x = lift1 A.exp x
   log x = lift1 A.log x
   x ** y = lift2 A.pow x y

   sin x = lift1 A.sin x
   cos x = lift1 A.cos x

   asin _ = error "LLVM missing intrinsic: asin"
   acos _ = error "LLVM missing intrinsic: acos"
   atan _ = error "LLVM missing intrinsic: atan"


{- |
Instance for the standard Prelude class,
implemented with the same "LLVM.Extra.Arithmetic" functions
as the NumericPrelude instances.
-}
instance
   (A.PseudoRing a, A.Real a, A.IntegerConstant a) =>
      P.Num (T a) where
   fromInteger n = fromInteger' n
   x + y = lift2 A.add x y
   x - y = lift2 A.sub x y
   x * y = lift2 A.mul x y
   negate x = lift1 A.neg x
   abs x = lift1 A.abs x
   signum x = lift1 A.signum x

{- |
Instance for the standard Prelude class.
Rational literals become constants via 'fromRational''.
-}
instance
   (A.Field a, A.Real a, A.RationalConstant a) =>
      P.Fractional (T a) where
   fromRational r = fromRational' r
   x / y = lift2 A.fdiv x y

{- |
Instance for the standard Prelude class.

The unqualified functions on the right-hand sides
('exp', 'log', '/', ...) are the NumericPrelude ones,
and unqualified 'sqrt' is the function defined in this module.
The inverse trigonometric functions are not available
and fail with an error as soon as they are called.
-}
instance
   (A.Transcendental a, A.Real a, A.RationalConstant a) =>
      P.Floating (T a) where
   pi = lift0 A.pi
   sin = lift1 A.sin
   cos = lift1 A.cos
   (**) = lift2 A.pow
   exp = lift1 A.exp
   log = lift1 A.log

   {-
   Without an explicit definition the class default @x ** 0.5@ is used,
   which would call 'A.pow' instead of 'A.sqrt'
   and would differ from the 'Algebraic.C' instance.
   -}
   sqrt = lift1 A.sqrt

   asin _ = error "LLVM missing intrinsic: asin"
   acos _ = error "LLVM missing intrinsic: acos"
   atan _ = error "LLVM missing intrinsic: atan"

   -- hyperbolic functions expressed by exponential and logarithm
   sinh x  = (exp x - exp (-x)) / 2
   cosh x  = (exp x + exp (-x)) / 2
   asinh x = log (x + sqrt (x*x + 1))
   acosh x = log (x + sqrt (x*x - 1))
   atanh x = (log (1 + x) - log (1 - x)) / 2


-- | The constant @2*pi@, built from 'Trans.pi' of the 'Trans.C' instance.
twoPi ::
   (A.Transcendental a, A.RationalConstant a) =>
   T a
twoPi = 2 * Trans.pi

-- | Lifted 'A.square'.
square ::
   (A.PseudoRing a) =>
   T a -> T a
square x = lift1 A.square x

{- |
Does the same as 'Algebraic.sqrt',
but requires only the Algebraic constraint instead of Transcendental.
-}
sqrt ::
   (A.Algebraic a) =>
   T a -> T a
sqrt x = lift1 A.sqrt x


-- | Lifted 'A.min'.
min :: (A.Real a) => T a -> T a -> T a
min x y = lift2 A.min x y

-- | Lifted 'A.max'.
max :: (A.Real a) => T a -> T a -> T a
max x y = lift2 A.max x y

{- |
@limit (l,u) x@ first bounds @x@ from above by @u@
and then bounds the result from below by @l@.
-}
limit :: (A.Real a) => (T a, T a) -> T a -> T a
limit (l,u) = \x -> max l (min u x)

-- | Lifted 'A.fraction'.
fraction :: (A.Fraction a) => T a -> T a
fraction x = lift1 A.fraction x


-- | Absolute value and sign are delegated to 'A.abs' and 'A.signum'.
instance (A.Real a, A.PseudoRing a, A.IntegerConstant a) =>
      Absolute.C (T a) where
   abs x = lift1 A.abs x
   signum x = lift1 A.signum x

{-
For useful instances with different scalar and vector type,
we would need a more flexible superclass.
-}
-- | Scalar and vector type coincide; the norm is the absolute value.
instance (A.Real a, A.IntegerConstant a, a ~ A.Scalar a, A.PseudoModule a) =>
      NormedSum.C (T a) (T a) where
   norm x = lift1 A.abs x

-- | Scalar and vector type coincide; the squared norm is the square.
instance (A.Real a, A.IntegerConstant a, a ~ A.Scalar a, A.PseudoModule a) =>
      NormedEuc.Sqr (T a) (T a) where
   normSqr x = lift1 A.square x

-- | The Euclidean norm is the square root ('A.sqrt') of the squared norm.
instance
   (NormedEuc.Sqr (T a) (T v),
    A.RationalConstant a, A.Algebraic a) =>
      NormedEuc.C (T a) (T v) where
   norm v = lift1 A.sqrt (NormedEuc.normSqr v)

{-
instance (A.Real a, A.IntegerConstant a, A.PseudoModule a a) =>
      NormedMax.C (T a) (T a) where
   norm = lift1 A.abs
-}


infix  4  %==, %/=, %<, %<=, %>=, %>

{- |
Comparison operators.
Each one lifts 'LLVM.cmp' with the respective comparison predicate.
-}
(%==), (%/=), (%<), (%<=), (%>), (%>=) ::
   (LLVM.CmpRet a) =>
   T (LLVM.Value a) -> T (LLVM.Value a) -> T (LLVM.Value (LLVM.CmpResult a))
x %== y = lift2 (LLVM.cmp LLVM.CmpEQ) x y
x %/= y = lift2 (LLVM.cmp LLVM.CmpNE) x y
x %<  y = lift2 (LLVM.cmp LLVM.CmpLT) x y
x %<= y = lift2 (LLVM.cmp LLVM.CmpLE) x y
x %>  y = lift2 (LLVM.cmp LLVM.CmpGT) x y
x %>= y = lift2 (LLVM.cmp LLVM.CmpGE) x y

infixr 3  %&&
infixr 2  %||

{- |
Lazy AND:
The second operand is only evaluated if the first one is true,
since it is placed in a branch of '(?)'.
-}
(%&&) :: T (LLVM.Value Bool) -> T (LLVM.Value Bool) -> T (LLVM.Value Bool)
cond %&& rest = cond ? (rest, constant False)

{- |
Lazy OR:
The second operand is only evaluated if the first one is false,
since it is placed in a branch of '(?)'.
-}
(%||) :: T (LLVM.Value Bool) -> T (LLVM.Value Bool) -> T (LLVM.Value Bool)
cond %|| rest = cond ? (constant True, rest)

-- | Logical negation, implemented by lifting 'LLVM.inv'.
not :: T (LLVM.Value Bool) -> T (LLVM.Value Bool)
not x = lift1 LLVM.inv x


infix  0 ?
{- |
@true ? (t,f)@ evaluates @t@,
@false ? (t,f)@ evaluates @f@.
Both @t@ and @f@ can reuse interim results
that were computed before the condition,
but they cannot contribute shared results themselves,
since only one of them will be run.
Cf. '(??)'
-}
(?) ::
   (Flatten value, Registers value ~ a, Phi a) =>
   T (LLVM.Value Bool) -> (value, value) -> value
c ? (t, f) =
   unfoldCode (consUnique (do
      b <- code c
      shared <- MS.get
      -- Every branch starts from the vault as it is after the condition
      -- and its own vault changes are discarded.
      let runBranch branch = MS.evalStateT (flattenCode branch) shared
      MT.lift (C.ifThenElse b (runBranch t) (runBranch f))))

infix 0 ??
{- |
The expression @c ?? (t,f)@ evaluates both @t@ and @f@
and chooses between their components according to @c@
by means of 'LLVM.select'.
It is useful for vector values and
for sharing @t@ or @f@ with other branches of an expression.
-}
(??) ::
   (LLVM.IsFirstClass a, LLVM.CmpRet a) =>
   T (LLVM.Value (LLVM.CmpResult a)) ->
   (T (LLVM.Value a), T (LLVM.Value a)) ->
   T (LLVM.Value a)
cond ?? (onTrue, onFalse) = lift3 LLVM.select cond onTrue onFalse



-- | Turn a code generating action into a value with its own vault key.
lift0 ::
   (forall r. CodeGenFunction r a) ->
   T a
lift0 f = consUnique (MT.lift f)

{- |
Lift a unary code generating function to values.
The result is registered under a fresh vault key.
-}
lift1 ::
   (forall r. a -> CodeGenFunction r b) ->
   T a -> T b
lift1 f x =
   consUnique (code x >>= \xv -> MT.lift (f xv))

{- |
Lift a binary code generating function to values.
The operands are evaluated from left to right.
-}
lift2 ::
   (forall r. a -> b -> CodeGenFunction r c) ->
   T a -> T b -> T c
lift2 f x y =
   consUnique
      (code x >>= \xv ->
       code y >>= \yv ->
       MT.lift (f xv yv))

{- |
Lift a ternary code generating function to values.
The operands are evaluated from left to right.
-}
lift3 ::
   (forall r. a -> b -> c -> CodeGenFunction r d) ->
   T a -> T b -> T c -> T d
lift3 f x y z =
   consUnique
      (code x >>= \xv ->
       code y >>= \yv ->
       code z >>= \zv ->
       MT.lift (f xv yv zv))


-- | Variant of 'unlift0' restricted to 'T'; the same as 'decons'.
_unlift0 ::
   T a ->
   (forall r. CodeGenFunction r a)
_unlift0 value = decons value

-- | Synonym of 'flatten' that fits into the @unliftN@ naming scheme.
unlift0 ::
   (Flatten value) =>
   value ->
   (forall r. CodeGenFunction r (Registers value))
unlift0 value = flatten value

-- | 'unlift1' specialised to functions that return a 'T'.
_unlift1 ::
   (T a -> T b) ->
   (forall r. a -> CodeGenFunction r b)
_unlift1 f = unlift1 f

{- |
Convert a function on values to a code generating function.
The argument enters the function as a constant value.

This has better type inference than 'flattenFunction'.
-}
unlift1 ::
   (Flatten value) =>
   (T a -> value) ->
   (forall r. a -> CodeGenFunction r (Registers value))
unlift1 f = flatten . f . constantValue

-- | 'unlift2' specialised to functions that return a 'T'.
_unlift2 ::
   (T a -> T b -> T c) ->
   (forall r. a -> b -> CodeGenFunction r c)
_unlift2 f = unlift2 f

{- |
Like 'unlift1' for functions of two arguments.
-}
unlift2 ::
   (Flatten value) =>
   (T a -> T b -> value) ->
   (forall r. a -> b -> CodeGenFunction r (Registers value))
unlift2 f a b =
   flatten (f ta tb)
   where
      ta = constantValue a
      tb = constantValue b

{- |
Like 'unlift1' for functions of three arguments.
-}
unlift3 ::
   (Flatten value) =>
   (T a -> T b -> T c -> value) ->
   (forall r. a -> b -> c -> CodeGenFunction r (Registers value))
unlift3 f a b c =
   flatten $
   f (constantValue a) (constantValue b) (constantValue c)

{- |
Like 'unlift1' for functions of four arguments.
-}
unlift4 ::
   (Flatten value) =>
   (T a -> T b -> T c -> T d -> value) ->
   (forall r. a -> b -> c -> d -> CodeGenFunction r (Registers value))
unlift4 f a b c d =
   flatten
      (f (constantValue a) (constantValue b)
         (constantValue c) (constantValue d))

{- |
Like 'unlift1' for functions of five arguments.
-}
unlift5 ::
   (Flatten value) =>
   (T a -> T b -> T c -> T d -> T e -> value) ->
   (forall r. a -> b -> c -> d -> e -> CodeGenFunction r (Registers value))
unlift5 f a b c d e =
   flatten
      (f (constantValue a) (constantValue b) (constantValue c)
         (constantValue d) (constantValue e))


{- |
A value whose computation just returns the given content
without generating any code.
-}
constantValue :: a -> T a
constantValue x = consUnique (pure x)

-- | Convert a Haskell constant with 'LLVM.valueOf' and wrap it as value.
constant :: (LLVM.IsConst a) => a -> T (LLVM.Value a)
constant x = constantValue (LLVM.valueOf x)

-- | Constant value obtained from an 'Integer' via 'A.fromInteger''.
fromInteger' :: (A.IntegerConstant a) => Integer -> T a
fromInteger' n = constantValue (A.fromInteger' n)

-- | Constant value obtained from a 'P.Rational' via 'A.fromRational''.
fromRational' :: (A.RationalConstant a) => P.Rational -> T a
fromRational' q = constantValue (A.fromRational' q)


{- |
Structures of 'T' values (tuples, stereo frames, complex numbers, ...)
that can be converted to and from a single 'T'
holding the corresponding structure of plain contents.
-}
class Flatten value where
   -- | The type of the structure after removing all 'T' wrappers.
   type Registers value :: *
   -- | Evaluate all components and collect their results.
   flattenCode :: value -> Compute r (Registers value)
   -- | Distribute a value of a whole structure to the components.
   unfoldCode :: T (Registers value) -> value

{- |
Run 'flattenCode' starting with an empty vault.
-}
flatten ::
   (Flatten value) =>
   value -> CodeGenFunction r (Registers value)
flatten = flip MS.evalStateT Vault.empty . flattenCode

{- |
Wrap already computed contents as a constant
and distribute it with 'unfoldCode'.
-}
unfold ::
   (Flatten value) =>
   (Registers value) -> value
unfold = unfoldCode . pure

{- |
Generic implementation of 'flattenCode' for traversable containers:
Flatten the elements one after another.
-}
flattenCodeTraversable ::
   (Flatten value, Trav.Traversable f) =>
   f value -> Compute r (f (Registers value))
flattenCodeTraversable xs =
   Trav.mapM flattenCode xs

{- |
Generic implementation of 'unfoldCode' for containers
that are both traversable and applicative.
It uses the element accessors provided by 'getters'.
-}
unfoldCodeTraversable ::
   (Flatten value, Trav.Traversable f, Applicative f) =>
   T (f (Registers value)) -> f value
unfoldCodeTraversable x =
   unfoldFromGetters getters x

{- |
Given a container of accessor functions,
apply every accessor to the value of the whole structure
and unfold the extracted component.
-}
unfoldFromGetters ::
   (Functor f, Flatten b) =>
   f (a -> Registers b) -> T a -> f b
unfoldFromGetters g x =
   fmap (\get -> unfoldCode (fmap get x)) g

{- |
A container where every position holds the function
that fetches the element at this position from another container.
The positions are numbered from zero in traversal order
and the element is looked up in the list of all elements.
-}
getters ::
   (Trav.Traversable f, Applicative f) =>
   f (f a -> a)
getters =
   fmap select positions
   where
      select n container = Fold.toList container !! n
      -- container with the running index at every position
      positions = MS.evalState (Trav.sequenceA (pure tick)) 0
      -- return the current counter and increment it
      tick = MS.state (\n -> (n, succ n))


{- |
Convert a function on 'Flatten' structures
to a code generating function on the plain contents.
-}
flattenFunction ::
   (Flatten a, Flatten b) =>
   (a -> b) -> (Registers a -> CodeGenFunction r (Registers b))
flattenFunction f regs =
   flatten (f (unfold regs))

{-
This function is hardly useful,
since most functions are not of type
@(Registers a -> (forall r. CodeGenFunction r (Registers b)))@
but of type
@(forall r. Registers a -> CodeGenFunction r (Registers b))@.
We would also need a method unfoldF.
See ValueUnfoldF for some implementations.

unfoldFunction ::
   (Flatten a, Flatten b) =>
   (Registers a -> (forall r. CodeGenFunction r (Registers b))) -> (a -> b)
unfoldFunction f x =
   unfoldF (f =<< flatten x)
-}


-- | Pairs are handled component-wise, the first component is evaluated first.
instance (Flatten a, Flatten b) => Flatten (a,b) where
   type Registers (a,b) = (Registers a, Registers b)
   flattenCode (a,b) = do
      ra <- flattenCode a
      rb <- flattenCode b
      return (ra, rb)
   unfoldCode pair =
      case unzip pair of
         (first, second) -> (unfoldCode first, unfoldCode second)

-- | Triples are handled component-wise from left to right.
instance (Flatten a, Flatten b, Flatten c) => Flatten (a,b,c) where
   type Registers (a,b,c) = (Registers a, Registers b, Registers c)
   flattenCode (a,b,c) = do
      ra <- flattenCode a
      rb <- flattenCode b
      rc <- flattenCode c
      return (ra, rb, rc)
   unfoldCode triple =
      case unzip3 triple of
         (first, second, third) ->
            (unfoldCode first, unfoldCode second, unfoldCode third)

-- | Stereo frames use the generic traversable implementations.
instance Flatten a => Flatten (Stereo.T a) where
   type Registers (Stereo.T a) = Stereo.T (Registers a)
   flattenCode frame = flattenCodeTraversable frame
   unfoldCode frame = unfoldCodeTraversable frame

-- | Complex numbers: the real part is evaluated before the imaginary part.
instance Flatten a => Flatten (Complex.T a) where
   type Registers (Complex.T a) = Complex.T (Registers a)
   flattenCode s = do
      re <- flattenCode (Complex.real s)
      im <- flattenCode (Complex.imag s)
      return (re Complex.+: im)
   unfoldCode z =
      unfoldFromGetters (Complex.real Complex.+: Complex.imag) z

-- | A phase is flattened to the registers of its representative.
instance (RealRing.C a, Flatten a) => Flatten (Phase.T a) where
   type Registers (Phase.T a) = Registers a
   flattenCode =
      flattenCode . Phase.toRepresentative
   unfoldCode =
      -- could also be unsafeFromRepresentative
      Phase.fromRepresentative . unfoldCode


-- | Base case: a single value is its own flattened form.
instance Flatten (T a) where
   type Registers (T a) = a
   flattenCode value = code value
   unfoldCode value = value

-- | The unit type carries no registers and generates no code.
instance Flatten () where
   type Registers () = ()
   flattenCode u = return u
   unfoldCode _ = ()