{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
module Synthesizer.LLVM.Simple.Value where

import qualified LLVM.Extra.ScalarOrVector as SoV

import qualified LLVM.Extra.Arithmetic as A

import LLVM.Core hiding (zero, )
import qualified LLVM.Core as LLVM
import qualified LLVM.Util.Arithmetic as Arith

import qualified Synthesizer.Basic.Phase as Phase

import Control.Monad (liftM2, liftM3, )

import qualified Synthesizer.LLVM.Frame.Stereo as Stereo

import qualified Algebra.Transcendental as Trans
import qualified Algebra.Algebraic as Algebraic
import qualified Algebra.RealRing as RealRing
import qualified Algebra.Module as Module
import qualified Algebra.Field as Field
import qualified Algebra.Ring as Ring
import qualified Algebra.Additive as Additive

import qualified Data.Traversable as Trav

import NumericPrelude.Numeric
import NumericPrelude.Base


{-
The @r@ type parameter must be hidden and forall-quantified
inside the constructor.
Otherwise we would need an impossible type,
in which @r@ and @t@ are quantified in different scopes
while a single class constraint mentions both of them,
as in the following signature:

> osci ::
>    (RealRing.C (Value.T r t),
>     IsFirstClass t, IsSized t size, IsFloating t,
>     IsPrimitive t, IsConst t) =>
>    (forall r. Wave.T (Value.T r t) (Value.T r y)) ->
>    t -> t -> T (Value y)

-}
-- | A value of type @a@ together with the LLVM code that computes it.
-- Because @r@ is quantified inside the constructor,
-- the same 'T' can be run in any 'CodeGenFunction'.
newtype T a =
   Cons {
      -- | The code that produces the value, usable at every @r@.
      decons :: forall r. Arith.TValue r a
   }

{- |
Addition, subtraction and negation emit the corresponding
LLVM instructions ('add', 'sub', 'neg').
The element type need not be an instance of a NumericPrelude class,
so LLVM-only types such as vectors get this instance, too.
-}
instance (IsArithmetic a, IsConst a) => Additive.C (T a) where
   zero = Cons (return (value LLVM.zero))
   x + y = binop add x y
   x - y = binop sub x y
   negate x = Cons (decons x >>= neg)

-- | Products are built with LLVM's 'mul';
-- 'one' and integer literals become LLVM constants via 'valueOf'.
instance (Ring.C a, IsArithmetic a, IsConst a) =>
      Ring.C (T a) where
   one = constantValue (valueOf one)
   x * y = binop mul x y
   fromInteger n = constant (fromInteger n)

{-
Two instance declarations suffice for Module here:
scalar times scalar and scalar times vector.
Unlike Module instances on Haskell tuples,
no further cases are needed, because LLVM vectors cannot be nested.
-}
-- | Scaling a scalar is ordinary multiplication.
instance (Ring.C a, IsArithmetic a, IsConst a) =>
      Module.C (T a) (T a) where
   s *> x = s * x

-- | Scale an LLVM vector by a scalar.
-- The scalar is first broadcast to a vector with 'SoV.replicate'
-- and then multiplied element-wise with 'A.mul'.
-- The scalar's code is generated before the vector's.
instance (Ring.C a, IsArithmetic a, IsConst a, IsPrimitive a, IsPowerOf2 n) =>
      Module.C (T a) (T (Vector n a)) where
   scalar *> vector =
      Cons
         (decons scalar >>= SoV.replicate >>= \broadcast ->
          decons vector >>= A.mul broadcast)

-- | 'succ' and 'pred' add or subtract 'one'; 'toEnum' embeds the number
-- as a constant. 'fromEnum' is not supported and calls 'error'.
instance (Ring.C a, IsArithmetic a, IsConst a) => Enum (T a) where
   toEnum n = fromIntegral n
   fromEnum _ = error "CodeGenFunction Value: fromEnum"
   succ = (+ one)
   pred x = x - one

{-
instance (IsArithmetic a, Cmp a b, Num a, IsConst a) => Real (T a) where
   toRational _ = error "CodeGenFunction Value: toRational"

instance (Cmp a b, Num a, IsConst a, IsInteger a) => Integral (T a) where
   quot = binop (if (isSigned (undefined :: a)) then sdiv else udiv)
   rem  = binop (if (isSigned (undefined :: a)) then srem else urem)
   quotRem x y = (quot x y, rem x y)
   toInteger _ = error "CodeGenFunction Value: toInteger"
-}

-- | Division uses LLVM's floating-point 'fdiv';
-- rational literals become LLVM constants.
instance (Field.C a, IsConst a, IsFloating a) => Field.C (T a) where
   x / y = binop fdiv x y
   fromRational' q = constantValue (valueOf (fromRational' q))

{-
instance (Cmp a b, Fractional a, IsConst a, IsFloating a) => RealFrac (T a) where
   properFraction _ = error "CodeGenFunction Value: properFraction"
-}

-- | Square root via 'A.sqrt'.
instance (Algebraic.C a, IsConst a, IsFloating a) => Algebraic.C (T a) where
   sqrt x = lift1 A.sqrt x

-- | Transcendental functions are delegated to "LLVM.Extra.Arithmetic".
-- 'asin', 'acos' and 'atan' are not implemented and call 'error'.
instance (Trans.C a, IsConst a, IsFloating a) => Trans.C (T a) where
   pi = constant pi

   exp x = lift1 A.exp x
   log x = lift1 A.log x
   x ** y = lift2 A.pow x y

   sin x = lift1 A.sin x
   cos x = lift1 A.cos x

   asin _ = error "LLVM missing intrinsic: asin"
   acos _ = error "LLVM missing intrinsic: acos"
   atan _ = error "LLVM missing intrinsic: atan"

{-
   sinh x           = (exp x - exp (-x)) / 2
   cosh x           = (exp x + exp (-x)) / 2
   asinh x          = log (x + sqrt (x*x + 1))
   acosh x          = log (x + sqrt (x*x - 1))
   atanh x          = (log (1 + x) - log (1 - x)) / 2
-}


{- |
The constant @2*pi@ as a code-generating value.

The factor two is written as an explicit NumericPrelude 'fromInteger'
instead of a bare literal.
Since GHC 7.0, @NoImplicitPrelude@ alone no longer implies
@RebindableSyntax@, so a bare @2 :: T a@ would require a Prelude
'Num' instance for 'T', and no such instance is defined here.
With rebindable literals, the generated code is the same as before.
-}
twoPi ::
   (Trans.C a, IsConst a, IsFloating a) =>
   T a
twoPi = fromInteger 2 * pi
{-
twoPi ::
   (Cmp a b, P.Floating a, IsConst a, IsFloating a) =>
   Arith.TValue r a
twoPi = P.fromInteger 2 P.* P.pi
-}


-- | Run the code of the argument, then apply a unary LLVM operation
-- to the resulting value.
lift1 ::
   (forall r. Value a -> CodeGenFunction r (Value b)) ->
   T a -> T b
lift1 f x =
   Cons (decons x >>= f)

-- | Run the code of both arguments (first, then second),
-- then combine the results with a binary LLVM operation.
lift2 ::
   (forall r. Value a -> Value b -> CodeGenFunction r (Value c)) ->
   T a -> T b -> T c
lift2 f x y =
   Cons (decons x >>= \a -> decons y >>= f a)


-- | Wrap an existing LLVM value; the wrapped code only returns it.
constantValue :: Value a -> T a
constantValue v = Cons (return v)

-- | Embed a Haskell value as an LLVM constant.
constant :: (IsConst a) => a -> T a
constant x = constantValue (valueOf x)

-- | Generate the code of both operands (left first),
-- then apply a binary LLVM instruction to the two results.
binop ::
   (forall r. Value a -> Value b -> Arith.TValue r c) ->
   T a -> T b -> T c
binop op x y =
   Cons (decons x >>= \xv -> decons y >>= \yv -> op xv yv)


{- |
Convert between a description of a value (@value@)
and the form obtained by running its LLVM code (@register@).
The functional dependency lets @value@ determine @register@.
-}
class Flatten value register | value -> register where
   -- | Generate the code that yields the @register@ form.
   flatten :: forall r. value -> CodeGenFunction r register
   -- | Wrap a @register@ form back into a @value@.
   unfold :: register -> value

-- | Flatten every element of a traversable container, left to right.
flattenTraversable ::
   (Flatten value register, Trav.Traversable f) =>
   f value -> CodeGenFunction r (f register)
flattenTraversable xs =
   Trav.forM xs flatten

-- | Unfold every element of a functor.
unfoldFunctor ::
   (Flatten value register, Functor f) =>
   f register -> f value
unfoldFunctor xs =
   fmap unfold xs


-- | Pairs are flattened component-wise, first component first.
instance (Flatten av ar, Flatten bv br) =>
      Flatten (av,bv) (ar,br) where
   flatten (a,b) = do
      ra <- flatten a
      rb <- flatten b
      return (ra, rb)
   unfold (a,b) = (unfold a, unfold b)

-- | Triples are flattened component-wise, from left to right.
instance (Flatten av ar, Flatten bv br, Flatten cv cr) =>
      Flatten (av,bv,cv) (ar,br,cr) where
   flatten (a,b,c) = do
      ra <- flatten a
      rb <- flatten b
      rc <- flatten c
      return (ra, rb, rc)
   unfold (a,b,c) = (unfold a, unfold b, unfold c)

-- | Stereo frames are flattened channel by channel, left channel first.
instance Flatten v r =>
      Flatten (Stereo.T v) (Stereo.T r) where
   flatten s = do
      lv <- flatten (Stereo.left s)
      rv <- flatten (Stereo.right s)
      return (Stereo.cons lv rv)
   unfold s =
      Stereo.cons (unfold (Stereo.left s)) (unfold (Stereo.right s))

-- | A phase is flattened through its representative value.
instance
   (RealRing.C v, Flatten v r) =>
      Flatten (Phase.T v) r where
   flatten = flatten . Phase.toRepresentative
   -- could also be unsafeFromRepresentative
   unfold = Phase.fromRepresentative . unfold


-- | Flattening a 'T' runs its code; unfolding wraps a plain 'Value'.
instance (IsConst a) => Flatten (T a) (Value a) where
   flatten x = decons x
   unfold v = Cons (return v)
-- | The unit value needs no code.
instance Flatten () () where
   flatten u = return u
   unfold u = u