{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module LLVM.DSL.Expression where

import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Tuple as LLTuple
import qualified LLVM.Extra.FastMath as FastMath
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import qualified LLVM.Core as LLVM
import LLVM.Extra.Multi.Value (PatternTuple, Decomposed, Atom)

import qualified Control.Monad.HT as Monad
import Control.Monad.IO.Class (liftIO)

import qualified Data.Enum.Storable as Enum
import qualified Data.Tuple.HT as TupleHT
import qualified Data.Tuple as Tuple
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import Data.Complex (Complex((:+)))
import Data.Bool8 (Bool8)

import qualified Foreign.Storable.Record.Tuple as StTuple

import qualified Algebra.Transcendental as Trans
import qualified Algebra.Algebraic as Algebraic
import qualified Algebra.Absolute as Absolute
import qualified Algebra.Module as Module
import qualified Algebra.Field as Field
import qualified Algebra.Ring as Ring
import qualified Algebra.Additive as Additive

import System.IO.Unsafe (unsafePerformIO)

import qualified Prelude as P
import Prelude hiding
   (fst, snd, min, max, zip, unzip, zip3, unzip3,
    curry, uncurry, recip, pi, sqrt, maybe, toEnum, fromEnum, pred, succ)


{- |
A deferred expression: an LLVM code generator producing a 'MultiValue.T'.
The generator is polymorphic in the function result type @r@,
so the same expression can be emitted into any 'LLVM.CodeGenFunction'.
-}
newtype Exp a = Exp {unExp :: forall r. LLVM.CodeGenFunction r (MultiValue.T a)}


{- |
Turn a code generator into an expression.

At present this is only the 'Exp' constructor:
the generator is re-run at every use of the expression,
so a subexpression used twice is emitted twice.
'_unique' is an alternative that caches the generated value in an 'IORef';
using an 'IORef' should be thread-safe there,
because you cannot fork within 'LLVM.CodeGenFunction'.
-}
unique :: (forall r. LLVM.CodeGenFunction r (MultiValue.T a)) -> Exp a
unique gen = Exp gen

{- |
Variant of 'unique' that allocates a fresh 'IORef' and hands it to 'withKey',
so the generator runs only on the first use and later uses get the cached value.
It is not used by 'unique'.

The 'IORef' is allocated with 'unsafePerformIO'.
Following the advice in the 'unsafePerformIO' documentation,
the @NOINLINE@ pragma below keeps GHC from inlining this definition,
which could duplicate or merge (CSE) the allocation of the reference.
-}
_unique :: (forall r. LLVM.CodeGenFunction r (MultiValue.T a)) -> Exp a
_unique code = unsafePerformIO $ fmap (withKey code) $ newIORef Nothing
{-# NOINLINE _unique #-}

{- |
Expression that runs @code@ on its first use, stores the result in @ref@,
and on every later use returns the stored value without emitting code again.

NOTE(review): the cached 'MultiValue.T' is handed out to every later use,
possibly in another function or in a branch not dominated by the first use.
This is presumably why 'unique' does not use this; confirm before enabling.
-}
withKey ::
   (forall r. LLVM.CodeGenFunction r (MultiValue.T a)) ->
   IORef (Maybe (MultiValue.T a)) -> Exp a
withKey code ref =
   Exp
      (liftIO (readIORef ref) >>=
       P.maybe
          (do
             fresh <- code
             liftIO (writeIORef ref (Just fresh))
             return fresh)
          return)


{- |
Carriers of 'MultiValue.T' values: plain 'MultiValue.T' and deferred 'Exp'.
The methods lift pure functions on 'MultiValue.T' to the carrier type.
-}
class Value val where
   -- | Embed an already available value.
   lift0 :: MultiValue.T a -> val a
   -- | Lift a unary pure function.
   lift1 :: (MultiValue.T a -> MultiValue.T b) -> val a -> val b
   -- | Lift a binary pure function.
   lift2 ::
      (MultiValue.T a -> MultiValue.T b -> MultiValue.T c) ->
      val a -> val b -> val c

-- | On plain values, lifting is just function application.
instance Value MultiValue.T where
   lift0 x = x
   lift1 f = f
   lift2 f = f

{- |
On expressions, the wrapped generators are run from left to right
and the pure function is applied to their results.
-}
instance Value Exp where
   lift0 x = unique (pure x)
   lift1 f (Exp gen) = unique (fmap f gen)
   lift2 f (Exp genX) (Exp genY) =
      unique (do
         x <- genX
         y <- genY
         return (f x y))

{- |
Lift a ternary pure function.
The first two arguments are paired with 'zip'
and the uncurried function is lifted with 'lift2'.
-}
lift3 ::
   (Value val) =>
   (MultiValue.T a -> MultiValue.T b -> MultiValue.T c -> MultiValue.T d) ->
   val a -> val b -> val c -> val d
lift3 f x y z = lift2 (MultiValue.uncurry f) (zip x y) z

{- |
Lift a quaternary pure function.
The first two arguments are paired with 'zip'
and the uncurried function is lifted with 'lift3'.
-}
lift4 ::
   (Value val) =>
   (MultiValue.T a -> MultiValue.T b -> MultiValue.T c -> MultiValue.T d ->
    MultiValue.T e) ->
   val a -> val b -> val c -> val d -> val e
lift4 f w x y z = lift3 (MultiValue.uncurry f) (zip w x) y z



{- |
Lift a code-generating unary function to expressions.
The argument's code is emitted first, then @f@ runs on its value.
-}
liftM ::
   (forall r.
    MultiValue.T a ->
    LLVM.CodeGenFunction r (MultiValue.T b)) ->
   (Exp a -> Exp b)
liftM f (Exp gen) = unique (gen >>= f)

{- |
Lift a code-generating binary function to expressions.
The arguments' code is emitted left to right, then @f@ runs on the values.
-}
liftM2 ::
   (forall r.
    MultiValue.T a -> MultiValue.T b ->
    LLVM.CodeGenFunction r (MultiValue.T c)) ->
   (Exp a -> Exp b -> Exp c)
liftM2 f (Exp genX) (Exp genY) =
   unique (do
      x <- genX
      y <- genY
      f x y)

{- |
Lift a code-generating ternary function to expressions.
The arguments' code is emitted left to right, then @f@ runs on the values.
-}
liftM3 ::
   (forall r.
    MultiValue.T a -> MultiValue.T b -> MultiValue.T c ->
    LLVM.CodeGenFunction r (MultiValue.T d)) ->
   (Exp a -> Exp b -> Exp c -> Exp d)
liftM3 f (Exp genX) (Exp genY) (Exp genZ) =
   unique (do
      x <- genX
      y <- genY
      z <- genZ
      f x y z)


{- |
Turn an expression function back into a code generator function
by passing the given values as constant expressions (see 'lift0').
-}
unliftM1 ::
   (Exp a -> Exp b) ->
   MultiValue.T a -> LLVM.CodeGenFunction r (MultiValue.T b)
unliftM1 f x =
   case f (lift0 x) of Exp gen -> gen

-- | Binary version of 'unliftM1'.
unliftM2 ::
   (Exp a -> Exp b -> Exp c) ->
   MultiValue.T a -> MultiValue.T b ->
   LLVM.CodeGenFunction r (MultiValue.T c)
unliftM2 f x y =
   case f (lift0 x) (lift0 y) of Exp gen -> gen

-- | Ternary version of 'unliftM1'.
unliftM3 ::
   (Exp a -> Exp b -> Exp c -> Exp d) ->
   MultiValue.T a -> MultiValue.T b -> MultiValue.T c ->
   LLVM.CodeGenFunction r (MultiValue.T d)
unliftM3 f x y z =
   case f (lift0 x) (lift0 y) (lift0 z) of Exp gen -> gen


{- |
Lift a code generator on the underlying LLVM values ('LLTuple.ValueOf')
to expressions, via 'MultiValue.liftM'.
-}
liftTupleM ::
   (forall r.
    LLTuple.ValueOf a ->
    LLVM.CodeGenFunction r (LLTuple.ValueOf b)) ->
   (Exp a -> Exp b)
liftTupleM f = liftM (\x -> MultiValue.liftM f x)

-- | Binary version of 'liftTupleM', via 'MultiValue.liftM2'.
liftTupleM2 ::
   (forall r.
    LLTuple.ValueOf a -> LLTuple.ValueOf b ->
    LLVM.CodeGenFunction r (LLTuple.ValueOf c)) ->
   (Exp a -> Exp b -> Exp c)
liftTupleM2 f = liftM2 (\x y -> MultiValue.liftM2 f x y)

-- | Ternary version of 'liftTupleM', via 'MultiValue.liftM3'.
liftTupleM3 ::
   (forall r.
    LLTuple.ValueOf a -> LLTuple.ValueOf b -> LLTuple.ValueOf c ->
    LLVM.CodeGenFunction r (LLTuple.ValueOf d)) ->
   (Exp a -> Exp b -> Exp c -> Exp d)
liftTupleM3 f = liftM3 (\x y z -> MultiValue.liftM3 f x y z)



-- | Combine two values into a pair.
zip :: (Value val) => val a -> val b -> val (a, b)
zip x y = lift2 MultiValue.zip x y

-- | Combine three values into a triple.
zip3 :: (Value val) => val a -> val b -> val c -> val (a, b, c)
zip3 x y z = lift3 MultiValue.zip3 x y z

-- | Combine four values into a quadruple.
zip4 :: (Value val) => val a -> val b -> val c -> val d -> val (a, b, c, d)
zip4 w x y z = lift4 MultiValue.zip4 w x y z

-- | Split a pair value into its components.
unzip :: (Value val) => val (a, b) -> (val a, val b)
unzip pair = (fst pair, snd pair)

-- | Split a triple value into its components.
unzip3 :: (Value val) => val (a, b, c) -> (val a, val b, val c)
unzip3 triple = (fst3 triple, snd3 triple, thd3 triple)

-- | Split a quadruple value into its components.
unzip4 :: (Value val) => val (a, b, c, d) -> (val a, val b, val c, val d)
unzip4 quad =
   (lift1 (\v -> case v of MultiValue.Cons (a,_,_,_) -> MultiValue.Cons a) quad,
    lift1 (\v -> case v of MultiValue.Cons (_,b,_,_) -> MultiValue.Cons b) quad,
    lift1 (\v -> case v of MultiValue.Cons (_,_,c,_) -> MultiValue.Cons c) quad,
    lift1 (\v -> case v of MultiValue.Cons (_,_,_,d) -> MultiValue.Cons d) quad)


-- | First component of a pair value.
fst :: (Value val) => val (a, b) -> val a
fst pair = lift1 MultiValue.fst pair

-- | Second component of a pair value.
snd :: (Value val) => val (a, b) -> val b
snd pair = lift1 MultiValue.snd pair

-- | Apply an expression function to the first component of a pair.
mapFst :: (Exp a -> Exp b) -> Exp (a, c) -> Exp (b, c)
mapFst f = liftM (\v -> MultiValue.mapFstF (unliftM1 f) v)

-- | Apply an expression function to the second component of a pair.
mapSnd :: (Exp b -> Exp c) -> Exp (a, b) -> Exp (a, c)
mapSnd f = liftM (\v -> MultiValue.mapSndF (unliftM1 f) v)

-- | Apply two expression functions componentwise to a pair.
mapPair :: (Exp a0 -> Exp a1, Exp b0 -> Exp b1) -> Exp (a0, b0) -> Exp (a1, b1)
mapPair (f,g) x = mapFst f (mapSnd g x)

-- | Exchange the components of a pair value.
swap :: (Value val) => val (a, b) -> val (b, a)
swap pair = lift1 MultiValue.swap pair

-- | Turn a function on pair expressions into a two-argument function.
curry :: (Exp (a,b) -> c) -> (Exp a -> Exp b -> c)
curry f a b = f (zip a b)

-- | Turn a two-argument function into a function on pair expressions.
uncurry :: (Exp a -> Exp b -> c) -> (Exp (a,b) -> c)
uncurry f ab = let (a, b) = unzip ab in f a b


-- | First component of a triple value.
fst3 :: (Value val) => val (a,b,c) -> val a
fst3 triple = lift1 MultiValue.fst3 triple

-- | Second component of a triple value.
snd3 :: (Value val) => val (a,b,c) -> val b
snd3 triple = lift1 MultiValue.snd3 triple

-- | Third component of a triple value.
thd3 :: (Value val) => val (a,b,c) -> val c
thd3 triple = lift1 MultiValue.thd3 triple

-- | Apply an expression function to the first component of a triple.
mapFst3 :: (Exp a0 -> Exp a1) -> Exp (a0,b,c) -> Exp (a1,b,c)
mapFst3 f = liftM (\v -> MultiValue.mapFst3F (unliftM1 f) v)

-- | Apply an expression function to the second component of a triple.
mapSnd3 :: (Exp b0 -> Exp b1) -> Exp (a,b0,c) -> Exp (a,b1,c)
mapSnd3 f = liftM (\v -> MultiValue.mapSnd3F (unliftM1 f) v)

-- | Apply an expression function to the third component of a triple.
mapThd3 :: (Exp c0 -> Exp c1) -> Exp (a,b,c0) -> Exp (a,b,c1)
mapThd3 f = liftM (\v -> MultiValue.mapThd3F (unliftM1 f) v)

-- | Apply three expression functions componentwise to a triple.
mapTriple ::
   (Exp a0 -> Exp a1, Exp b0 -> Exp b1, Exp c0 -> Exp c1) ->
   Exp (a0,b0,c0) -> Exp (a1,b1,c1)
mapTriple (f,g,h) x = mapFst3 f (mapSnd3 g (mapThd3 h x))


-- | Wrap an expression in the storable 'StTuple.Tuple' newtype.
tuple :: Exp tuple -> Exp (StTuple.Tuple tuple)
tuple x = lift1 MultiValue.tuple x

-- | Unwrap the storable 'StTuple.Tuple' newtype.
untuple :: Exp (StTuple.Tuple tuple) -> Exp tuple
untuple x = lift1 MultiValue.untuple x


-- | Lift 'MultiValue.modify' to any 'Value' carrier.
modifyMultiValue ::
   (Value val,
    MultiValue.Compose a,
    MultiValue.Decompose pattern,
    MultiValue.PatternTuple pattern ~ tuple) =>
   pattern ->
   (Decomposed MultiValue.T pattern -> a) ->
   val tuple -> val (MultiValue.Composed a)
modifyMultiValue p f x = lift1 (MultiValue.modify p f) x

-- | Lift 'MultiValue.modify2' to any 'Value' carrier.
modifyMultiValue2 ::
   (Value val,
    MultiValue.Compose a,
    MultiValue.Decompose patternA,
    MultiValue.Decompose patternB,
    MultiValue.PatternTuple patternA ~ tupleA,
    MultiValue.PatternTuple patternB ~ tupleB) =>
   patternA ->
   patternB ->
   (Decomposed MultiValue.T patternA ->
    Decomposed MultiValue.T patternB -> a) ->
   val tupleA -> val tupleB -> val (MultiValue.Composed a)
modifyMultiValue2 pa pb f x y = lift2 (MultiValue.modify2 pa pb f) x y

-- | Lift the code-generating 'MultiValue.modifyF' to expressions.
modifyMultiValueM ::
   (MultiValue.Compose a,
    MultiValue.Decompose pattern,
    MultiValue.PatternTuple pattern ~ tuple) =>
   pattern ->
   (forall r.
    Decomposed MultiValue.T pattern ->
    LLVM.CodeGenFunction r a) ->
   Exp tuple -> Exp (MultiValue.Composed a)
modifyMultiValueM p f = liftM (\x -> MultiValue.modifyF p f x)

-- | Lift the code-generating 'MultiValue.modifyF2' to expressions.
modifyMultiValueM2 ::
   (MultiValue.Compose a,
    MultiValue.Decompose patternA,
    MultiValue.Decompose patternB,
    MultiValue.PatternTuple patternA ~ tupleA,
    MultiValue.PatternTuple patternB ~ tupleB) =>
   patternA ->
   patternB ->
   (forall r.
    Decomposed MultiValue.T patternA ->
    Decomposed MultiValue.T patternB ->
    LLVM.CodeGenFunction r a) ->
   Exp tupleA -> Exp tupleB -> Exp (MultiValue.Composed a)
modifyMultiValueM2 pa pb f =
   liftM2 (\x y -> MultiValue.modifyF2 pa pb f x y)


-- | Structures of expressions that can be merged into a single expression.
class Compose multituple where
   -- | Element type of the merged expression.
   type Composed multituple
   -- | A nested 'zip'.
   compose :: multituple -> Exp (Composed multituple)

{- |
Patterns describing how an expression of tuple type is split
into a structure of expressions.
The superclass constraint states that composing such a structure
gives back an expression of the pattern's tuple type.
-}
class
   (Composed (Decomposed Exp pattern) ~ PatternTuple pattern) =>
      Decompose pattern where
   -- | Analogous to 'MultiValue.decompose'.
   decompose ::
      pattern -> Exp (PatternTuple pattern) -> Decomposed Exp pattern


{- |
Analogous to 'MultiValue.modify':
split an expression according to a pattern,
transform the resulting structure of expressions,
and compose the result back into one expression.
-}
modify ::
   (Compose a, Decompose pattern) =>
   pattern ->
   (Decomposed Exp pattern -> a) ->
   Exp (PatternTuple pattern) -> Exp (Composed a)
modify p f x = compose (f (decompose p x))

-- | Two-argument version of 'modify', analogous to 'MultiValue.modify2'.
modify2 ::
   (Compose a, Decompose patternA, Decompose patternB) =>
   patternA ->
   patternB ->
   (Decomposed Exp patternA -> Decomposed Exp patternB -> a) ->
   Exp (PatternTuple patternA) ->
   Exp (PatternTuple patternB) -> Exp (Composed a)
modify2 pa pb f a b =
   let da = decompose pa a
       db = decompose pb b
   in  compose (f da db)



-- | A single expression is already composed.
instance Compose (Exp a) where
   type Composed (Exp a) = a
   compose e = e

-- | An atom pattern leaves the expression untouched.
instance Decompose (Atom a) where
   decompose _atom e = e



-- | The empty tuple becomes the constant unit expression.
instance Compose () where
   type Composed () = ()
   compose u = cons u

-- | Decomposing with the unit pattern yields no expressions.
instance Decompose () where
   decompose = \_ _ -> ()


-- | Compose both components and 'zip' them.
instance (Compose a, Compose b) => Compose (a,b) where
   type Composed (a,b) = (Composed a, Composed b)
   compose ~(a,b) = zip (compose a) (compose b)

-- | 'unzip' the expression and decompose each component.
instance (Decompose pa, Decompose pb) => Decompose (pa,pb) where
   decompose (pa,pb) x =
      let (a,b) = unzip x
      in  (decompose pa a, decompose pb b)


-- | Compose all three components and 'zip3' them.
instance (Compose a, Compose b, Compose c) => Compose (a,b,c) where
   type Composed (a,b,c) = (Composed a, Composed b, Composed c)
   compose ~(a,b,c) = zip3 (compose a) (compose b) (compose c)

-- | 'unzip3' the expression and decompose each component.
instance
   (Decompose pa, Decompose pb, Decompose pc) =>
      Decompose (pa,pb,pc) where
   decompose (pa,pb,pc) x =
      let (a,b,c) = unzip3 x
      in  (decompose pa a, decompose pb b, decompose pc c)


-- | Compose all four components and 'zip4' them.
instance (Compose a, Compose b, Compose c, Compose d) => Compose (a,b,c,d) where
   type Composed (a,b,c,d) = (Composed a, Composed b, Composed c, Composed d)
   compose (a,b,c,d) =
      zip4
         (compose a)
         (compose b)
         (compose c)
         (compose d)

-- | 'unzip4' the expression and decompose each component.
instance
   (Decompose pa, Decompose pb, Decompose pc, Decompose pd) =>
      Decompose (pa,pb,pc,pd) where
   decompose (pa,pb,pc,pd) x =
      let (a,b,c,d) = unzip4 x
      in  (decompose pa a, decompose pb b, decompose pc c, decompose pd d)


-- | Compose the wrapped structure and wrap the result with 'tuple'.
instance (Compose tuple) => Compose (StTuple.Tuple tuple) where
   type Composed (StTuple.Tuple tuple) = StTuple.Tuple (Composed tuple)
   compose (StTuple.Tuple inner) = tuple (compose inner)

-- | Unwrap with 'untuple', decompose, and rewrap the resulting structure.
instance (Decompose p) => Decompose (StTuple.Tuple p) where
   decompose (StTuple.Tuple p) x = StTuple.Tuple (decompose p (untuple x))


-- | Compose real and imaginary parts and combine them with 'consComplex'.
instance (Compose a) => Compose (Complex a) where
   type Composed (Complex a) = Complex (Composed a)
   compose (re:+im) = consComplex (compose re) (compose im)

-- | Split with 'deconsComplex' and decompose both parts.
instance (Decompose p) => Decompose (Complex p) where
   decompose (pRe:+pIm) x =
      let (re, im) = deconsComplex x
      in  decompose pRe re :+ decompose pIm im

{- |
Build a complex expression from its real and imaginary parts.

Such complex numbers are of limited use:
the Prelude numeric operations on 'Complex' require a 'RealFloat' instance,
which we could only provide with lots of undefined methods
(also in its superclasses).
Either define your own arithmetic or use the NumericPrelude type classes.
-}
consComplex :: Exp a -> Exp a -> Exp (Complex a)
consComplex re im = lift2 MultiValue.consComplex re im

-- | Split a complex expression into its real and imaginary parts.
deconsComplex :: Exp (Complex a) -> (Exp a, Exp a)
deconsComplex z =
   (lift1 MultiValue.realPart z,
    lift1 MultiValue.imagPart z)



-- | Constant expression built from a Haskell value.
cons :: (MultiValue.C a) => a -> Exp a
cons x = lift0 (MultiValue.cons x)

-- | The constant @()@.
unit :: Exp ()
unit = lift0 (MultiValue.cons ())

-- | Constant 'MultiValue.zero' of the type.
zero :: (MultiValue.C a) => Exp a
zero = lift0 MultiValue.zero

-- | Addition via 'MultiValue.add'.
add :: (MultiValue.Additive a) => Exp a -> Exp a -> Exp a
add x y = liftM2 MultiValue.add x y

-- | Subtraction via 'MultiValue.sub'.
sub :: (MultiValue.Additive a) => Exp a -> Exp a -> Exp a
sub x y = liftM2 MultiValue.sub x y

-- | Negation via 'MultiValue.neg'.
neg :: (MultiValue.Additive a) => Exp a -> Exp a
neg x = liftM MultiValue.neg x

-- | The constant one.
one :: (MultiValue.IntegerConstant a) => Exp a
one = fromInteger' 1

-- | Multiplication via 'MultiValue.mul'.
mul :: (MultiValue.PseudoRing a) => Exp a -> Exp a -> Exp a
mul x y = liftM2 MultiValue.mul x y

-- | Square. The argument's code is emitted once and multiplied with itself.
sqr :: (MultiValue.PseudoRing a) => Exp a -> Exp a
sqr x = liftM (\v -> MultiValue.mul v v) x

-- | Reciprocal, computed as 'one' divided by the argument.
recip :: (MultiValue.Field a, MultiValue.IntegerConstant a) => Exp a -> Exp a
recip x = fdiv one x

-- | Field division via 'MultiValue.fdiv'.
fdiv :: (MultiValue.Field a) => Exp a -> Exp a -> Exp a
fdiv x y = liftM2 MultiValue.fdiv x y

-- | Square root via 'MultiValue.sqrt'.
sqrt :: (MultiValue.Algebraic a) => Exp a -> Exp a
sqrt x = liftM MultiValue.sqrt x

-- | Power via 'MultiValue.pow'.
pow :: (MultiValue.Transcendental a) => Exp a -> Exp a -> Exp a
pow x y = liftM2 MultiValue.pow x y

-- | Integer division via 'MultiValue.idiv'.
idiv :: (MultiValue.Integral a) => Exp a -> Exp a -> Exp a
idiv x y = liftM2 MultiValue.idiv x y

-- | Integer remainder via 'MultiValue.irem'.
irem :: (MultiValue.Integral a) => Exp a -> Exp a -> Exp a
irem x y = liftM2 MultiValue.irem x y

-- | Left shift via 'MultiValue.shl'.
shl :: (MultiValue.BitShift a) => Exp a -> Exp a -> Exp a
shl x y = liftM2 MultiValue.shl x y

-- | Right shift via 'MultiValue.shr'.
shr :: (MultiValue.BitShift a) => Exp a -> Exp a -> Exp a
shr x y = liftM2 MultiValue.shr x y

-- | Constant expression from an 'Integer' literal.
fromInteger' :: (MultiValue.IntegerConstant a) => Integer -> Exp a
fromInteger' n = lift0 (MultiValue.fromInteger' n)

-- | Constant expression from a 'Rational' literal.
fromRational' :: (MultiValue.RationalConstant a) => Rational -> Exp a
fromRational' q = lift0 (MultiValue.fromRational' q)


-- | Convert a stored 'Bool8' to 'Bool'.
boolPFrom8 :: Exp Bool8 -> Exp Bool
boolPFrom8 b = lift1 MultiValue.boolPFrom8 b

-- | Convert a 'Bool' to the storable 'Bool8'.
bool8FromP :: Exp Bool -> Exp Bool8
bool8FromP b = lift1 MultiValue.bool8FromP b

-- | Integer representation of a 'Bool8', via 'MultiValue.intFromBool8'.
intFromBool8 :: (MultiValue.NativeInteger i ir) => Exp Bool8 -> Exp i
intFromBool8 b = liftM MultiValue.intFromBool8 b

-- | Floating-point representation of a 'Bool8', via 'MultiValue.floatFromBool8'.
floatFromBool8 :: (MultiValue.NativeFloating a ar) => Exp Bool8 -> Exp a
floatFromBool8 b = liftM MultiValue.floatFromBool8 b


-- | Reinterpret a raw word as an enumeration value.
toEnum ::
   (LLTuple.ValueOf w ~ LLVM.Value w) =>
   Exp w -> Exp (Enum.T w e)
toEnum w = lift1 MultiValue.toEnum w

-- | Raw word underlying an enumeration value.
fromEnum ::
   (LLTuple.ValueOf w ~ LLVM.Value w) =>
   Exp (Enum.T w e) -> Exp w
fromEnum e = lift1 MultiValue.fromEnum e

-- | Successor and predecessor, via 'MultiValue.succ' and 'MultiValue.pred'.
succ, pred ::
   (LLVM.IsArithmetic w, SoV.IntegerConstant w) =>
   Exp (Enum.T w e) -> Exp (Enum.T w e)
succ e = liftM MultiValue.succ e
pred e = liftM MultiValue.pred e


-- | Drop the fast-math wrapper.
fromFastMath :: Exp (FastMath.Number flags a) -> Exp a
fromFastMath x = lift1 FastMath.mvDenumber x

-- | Add the fast-math wrapper.
toFastMath :: Exp a -> Exp (FastMath.Number flags a)
toFastMath x = lift1 FastMath.mvNumber x


-- | Smallest value of a bounded type.
minBound :: (MultiValue.Bounded a) => Exp a
minBound = lift0 MultiValue.minBound

-- | Largest value of a bounded type.
maxBound :: (MultiValue.Bounded a) => Exp a
maxBound = lift0 MultiValue.maxBound


-- | Compare two expressions with the given LLVM predicate.
cmp ::
   (MultiValue.Comparison a) =>
   LLVM.CmpPredicate -> Exp a -> Exp a -> Exp Bool
cmp predicate x y = liftM2 (MultiValue.cmp predicate) x y

infix 4 ==*, /=*, <*, <=*, >*, >=*

-- | Comparison operators, each a fixed 'LLVM.CmpPredicate' for 'cmp'.
(==*), (/=*), (<*), (>=*), (>*), (<=*) ::
   (MultiValue.Comparison a) => Exp a -> Exp a -> Exp Bool
x ==* y = cmp LLVM.CmpEQ x y
x /=* y = cmp LLVM.CmpNE x y
x <*  y = cmp LLVM.CmpLT x y
x >=* y = cmp LLVM.CmpGE x y
x >*  y = cmp LLVM.CmpGT x y
x <=* y = cmp LLVM.CmpLE x y


-- | Minimum and maximum via 'A.min' and 'A.max'.
min, max :: (MultiValue.Real a) => Exp a -> Exp a -> Exp a
min x y = liftM2 A.min x y
max x y = liftM2 A.max x y

-- | Clamp to @(l,u)@: take 'min' with @u@ first, then 'max' with @l@.
limit :: (MultiValue.Real a) => (Exp a, Exp a) -> Exp a -> Exp a
limit (l,u) x = max l (min u x)

-- | Fractional part via 'MultiValue.fraction'.
fraction :: (MultiValue.Fraction a) => Exp a -> Exp a
fraction x = liftM MultiValue.fraction x


-- | Boolean constant 'True'.
true :: Exp Bool
true = cons True

-- | Boolean constant 'False'.
false :: Exp Bool
false = cons False

infixr 3 &&*
-- | Conjunction. Code for both operands is always emitted (no short-circuit).
(&&*) :: Exp Bool -> Exp Bool -> Exp Bool
x &&* y = liftM2 MultiValue.and x y

infixr 2 ||*
-- | Disjunction. Code for both operands is always emitted (no short-circuit).
(||*) :: Exp Bool -> Exp Bool -> Exp Bool
x ||* y = liftM2 MultiValue.or x y

-- | Negation via 'MultiValue.inv'.
not :: Exp Bool -> Exp Bool
not x = liftM MultiValue.inv x

{- |
Like 'ifThenElse', but computes both alternative expressions
and picks the result with LLVM's efficient @select@ instruction.
-}
select :: (MultiValue.Select a) => Exp Bool -> Exp a -> Exp a -> Exp a
select c x y = liftM3 MultiValue.select c x y

{- |
Conditional expression built with 'C.ifThenElse'.
The condition's code is emitted first, then the alternatives are passed on
as code generators (unlike 'select', which computes both up front).
-}
ifThenElse :: (MultiValue.C a) => Exp Bool -> Exp a -> Exp a -> Exp a
ifThenElse ec ex ey =
   unique
      (unExp ec >>= \(MultiValue.Cons c) ->
         C.ifThenElse c (unExp ex) (unExp ey))


-- | Complement via 'MultiValue.inv'.
complement :: (MultiValue.Logic a) => Exp a -> Exp a
complement x = liftM MultiValue.inv x

-- Fixities follow 'Data.Bits': .&. is infixl 7, .|. infixl 5, xor infixl 6.
infixl 7 .&.*
-- | Lifted 'MultiValue.and'.
(.&.*) :: (MultiValue.Logic a) => Exp a -> Exp a -> Exp a
x .&.* y = liftM2 MultiValue.and x y

infixl 5 .|.*
-- | Lifted 'MultiValue.or'.
(.|.*) :: (MultiValue.Logic a) => Exp a -> Exp a -> Exp a
x .|.* y = liftM2 MultiValue.or x y

infixl 6 `xor`
-- | Lifted 'MultiValue.xor'.
xor :: (MultiValue.Logic a) => Exp a -> Exp a -> Exp a
xor x y = liftM2 MultiValue.xor x y


-- | Optional value built from a presence flag and a content expression.
toMaybe :: Exp Bool -> Exp a -> Exp (Maybe a)
toMaybe flag x = lift2 MultiValue.toMaybe flag x

{- |
Eliminate an optional value, like Prelude's @maybe@:
if the flag of the optional value is set, apply @j@ to its content,
otherwise return @n@. The choice is made with 'C.ifThenElse'.
-}
maybe :: (MultiValue.C b) => Exp b -> (Exp a -> Exp b) -> Exp (Maybe a) -> Exp b
maybe n j =
   liftM
      (\m ->
         case MultiValue.splitMaybe m of
            ~(MultiValue.Cons flag, content) ->
               C.ifThenElse flag (unliftM1 j content) (unExp n))


-- | Prelude 'Num' interface for expressions.
instance
   (MultiValue.PseudoRing a, MultiValue.Real a, MultiValue.IntegerConstant a) =>
      Num (Exp a) where
   fromInteger n = fromInteger' n
   x + y = add x y
   x - y = sub x y
   negate x = neg x
   x * y = mul x y
   abs x = liftM MultiValue.abs x
   signum x = liftM MultiValue.signum x

-- | Prelude 'Fractional' interface for expressions.
instance
   (MultiValue.Field a, MultiValue.Real a, MultiValue.RationalConstant a) =>
      Fractional (Exp a) where
   fromRational q = fromRational' q
   x / y = fdiv x y

{- |
Prelude 'Floating' interface for expressions.

'asin', 'acos' and 'atan' are not supported and fail with 'error' when used.
The hyperbolic functions are expressed through 'exp', 'log' and 'sqrt'.
-}
instance
   (MultiValue.Transcendental a, MultiValue.Real a,
    MultiValue.RationalConstant a) =>
      Floating (Exp a) where
   pi = unique MultiValue.pi
   sin = liftM MultiValue.sin
   cos = liftM MultiValue.cos
   -- the right-hand side is the module-level 'sqrt', not a recursive call
   sqrt = sqrt
   (**) = pow
   exp = liftM MultiValue.exp
   log = liftM MultiValue.log

   asin _ = error "LLVM missing intrinsic: asin"
   acos _ = error "LLVM missing intrinsic: acos"
   atan _ = error "LLVM missing intrinsic: atan"

   sinh x  = (exp x - exp (-x)) / 2
   cosh x  = (exp x + exp (-x)) / 2
   -- Evaluate on |x| and restore the sign with 'signum' (asinh is odd).
   -- The direct formula log (x + sqrt (x*x + 1)) cancels for large negative x:
   -- x + sqrt (x*x + 1) rounds to 0 and the result becomes -Infinity
   -- (or NaN for x = -Infinity).
   asinh x = signum x * log (abs x + sqrt (x*x + 1))
   acosh x = log (x + sqrt (x*x - 1))
   atanh x = (log (1 + x) - log (1 - x)) / 2


{- |
NumericPrelude addition for expressions.
Only 'MultiValue.Additive' is required of the element type,
not a NumericPrelude superclass,
so LLVM-only types such as vectors get instances, too.
-}
instance (MultiValue.Additive a) => Additive.C (Exp a) where
   -- the right-hand side is the module-level 'zero', not a recursive call
   zero = zero
   x + y = add x y
   x - y = sub x y
   negate x = neg x

-- | NumericPrelude ring operations for expressions.
instance
   (MultiValue.PseudoRing a, MultiValue.IntegerConstant a) =>
      Ring.C (Exp a) where
   -- the right-hand side is the module-level 'one', not a recursive call
   one = one
   x * y = mul x y
   fromInteger n = fromInteger' n

{- |
Scaling of 'MultiValue.PseudoModule' expressions by their scalar type.
This single instance is all that is needed for 'Module.C' here.
It differs from the 'Module.C' instances for Haskell tuples
because LLVM vectors cannot be nested.
-}
instance
   (a ~ MultiValue.Scalar v,
    MultiValue.PseudoModule v, MultiValue.IntegerConstant a) =>
      Module.C (Exp a) (Exp v) where
   k *> x = liftM2 MultiValue.scale k x

-- | NumericPrelude field operations for expressions.
instance
   (MultiValue.Field a, MultiValue.RationalConstant a) =>
      Field.C (Exp a) where
   x / y = fdiv x y
   -- presumably converts NumericPrelude's rational type into the Prelude
   -- 'Rational' expected by the module-level fromRational' -- TODO confirm
   fromRational' q = fromRational' (Field.fromRational' q)

-- | NumericPrelude algebraic operations, all expressed through 'pow'.
instance
   (MultiValue.Transcendental a, MultiValue.RationalConstant a) =>
      Algebraic.C (Exp a) where
   -- the right-hand side is the module-level 'sqrt', not a recursive call
   sqrt = sqrt
   -- n-th root as x raised to the power 1/n
   root n x = pow x (recip (fromInteger' n))
   x ^/ q = pow x (Field.fromRational' q)


-- | The circle constant, two times 'Trans.pi'.
tau :: (MultiValue.Transcendental a, MultiValue.RationalConstant a) => Exp a
tau =
   let two = fromInteger' 2
   in  mul two Trans.pi

{- |
NumericPrelude transcendental functions for expressions.
'asin', 'acos' and 'atan' are not supported and fail with 'error' when used.
-}
instance
   (MultiValue.Transcendental a, MultiValue.RationalConstant a) =>
      Trans.C (Exp a) where
   pi = unique MultiValue.pi
   sin x = liftM MultiValue.sin x
   cos x = liftM MultiValue.cos x
   x ** y = pow x y
   exp x = liftM MultiValue.exp x
   log x = liftM MultiValue.log x

   asin _ = error "LLVM missing intrinsic: asin"
   acos _ = error "LLVM missing intrinsic: acos"
   atan _ = error "LLVM missing intrinsic: atan"


-- | NumericPrelude absolute value and sign for expressions.
instance
   (MultiValue.Real a, MultiValue.PseudoRing a, MultiValue.IntegerConstant a) =>
      Absolute.C (Exp a) where
   abs x = liftM MultiValue.abs x
   signum x = liftM MultiValue.signum x