{-# LANGUAGE FlexibleContexts #-}
module LLVM.Extra.Arithmetic (
   add, sub, inc, dec,
   mul, square, fdiv,
   idiv, irem,
   fcmp, cmp,
   and, or,
   min, max, abs,
   advanceArrayElementPtr,
   sqrt, sin, cos, exp, log, pow,
   ) where

import qualified LLVM.Core as LLVM
import LLVM.Core
   (Ptr, getElementPtr, value, valueOf, Value,
    CmpPredicate(CmpLE, CmpGE), CmpRet,
    FPPredicate,
    IsType, IsConst, IsInteger, IsFloating, IsArithmetic, IsFirstClass,
    CodeGenFunction, )

import Data.Word (Word32, )


import Prelude hiding (and, or, sqrt, sin, cos, exp, log, abs, min, max, )



-- * arithmetic with better type inference

-- | Addition of two values of one and the same arithmetic type;
-- both operands and the result are pinned to @a@.
add ::
   (IsArithmetic a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
add x y = LLVM.add x y

-- | Subtraction @x - y@ of two values of one and the same arithmetic type.
sub ::
   (IsArithmetic a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
sub x y = LLVM.sub x y

-- | Add the constant one to a value.
inc ::
   (IsArithmetic a, IsConst a, Num a) =>
   Value a -> CodeGenFunction r (Value a)
inc = (`add` valueOf 1)

-- | Subtract the constant one from a value.
dec ::
   (IsArithmetic a, IsConst a, Num a) =>
   Value a -> CodeGenFunction r (Value a)
dec = (`sub` valueOf 1)


-- | Multiplication of two values of one and the same arithmetic type.
mul ::
   (IsArithmetic a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
mul x y = LLVM.mul x y

-- | Multiply a value by itself (a single 'mul' with both operands @x@).
square ::
   (IsArithmetic a) =>
   Value a -> CodeGenFunction r (Value a)
square x = x `mul` x


-- | Floating-point division @x / y@ with operands and result of one type.
fdiv ::
   (IsFloating a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
fdiv x y = LLVM.fdiv x y

-- | Floating-point comparison under the given predicate.
-- The result type @b@ is determined by the 'CmpRet' instance for @a@.
fcmp ::
   (IsFloating a, CmpRet a b) =>
   FPPredicate -> Value a -> Value a -> CodeGenFunction r (Value b)
fcmp predicate x y = LLVM.fcmp predicate x y


-- | Generic comparison under the given 'CmpPredicate'.
-- The result type @b@ is determined by the 'CmpRet' instance for @a@.
cmp ::
   (CmpRet a b) =>
   CmpPredicate -> Value a -> Value a -> CodeGenFunction r (Value b)
cmp predicate x y = LLVM.cmp predicate x y

-- | Integer division of two values of the same integer type.
idiv ::
   (IsInteger a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
idiv x y = LLVM.idiv x y

-- | Integer remainder of two values of the same integer type.
irem ::
   (IsInteger a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
irem x y = LLVM.irem x y


-- | Bitwise conjunction of two values of the same integer type.
and ::
   (IsInteger a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
and x y = LLVM.and x y

-- | Bitwise disjunction of two values of the same integer type.
or ::
   (IsInteger a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
or x y = LLVM.or x y



{- |
Smaller of two values: emits a @<=@ comparison and selects @x@ when it
holds, otherwise @y@.

This would also work for vectors
if LLVM supported 'select' with a vector of booleans as the condition.
-}
min :: (IsFirstClass a, CmpRet a Bool) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
min x y = cmpSelect (cmp CmpLE) x y

-- | Larger of two values: emits a @>=@ comparison and selects @x@ when it
-- holds, otherwise @y@.
max :: (IsFirstClass a, CmpRet a Bool) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
max x y = cmpSelect (cmp CmpGE) x y

-- | Absolute value. Emits, in this order, the comparison @x >= 0@, the
-- negation of @x@, and a 'LLVM.select'. The select picks @x@ when the
-- comparison holds and the negation otherwise.
abs :: (IsArithmetic a, CmpRet a Bool) =>
   Value a -> CodeGenFunction r (Value a)
abs x = do
   nonNegative <- cmp CmpGE x (value LLVM.zero)
   negated <- LLVM.neg x
   LLVM.select nonNegative x negated


-- | Run the comparison @f@ on @x@ and @y@. Then emit a 'LLVM.select' that
-- yields @x@ if the comparison is true and @y@ otherwise.
cmpSelect ::
   (IsFirstClass a, CmpRet a Bool) =>
   (Value a -> Value a -> CodeGenFunction r (Value Bool)) ->
   (Value a -> Value a -> CodeGenFunction r (Value a))
cmpSelect f x y = do
   cond <- f x y
   LLVM.select cond x y



-- * pointers

-- | Emit a @getelementptr@ with the single index 1 (as a 'Word32').
-- The result is a pointer of the same type, one element past @ptr@.
advanceArrayElementPtr ::
   Value (Ptr o) ->
   CodeGenFunction r (Value (Ptr o))
advanceArrayElementPtr ptr = getElementPtr ptr (step, ())
   where
      step :: Value Word32
      step = valueOf 1



-- * transcendental functions


-- | Textual LLVM type name of a value's type, as produced by 'LLVM.typeName'.
-- It supplies the type suffix of intrinsic names such as @llvm.sin.f32@.
valueTypeName ::
   (IsType a) =>
   Value a -> String
valueTypeName =
   LLVM.typeName . typeProxy
   where
      -- Stand-in of the element type, so that 'LLVM.typeName' can dispatch
      -- on the type. It is presumably never evaluated. If it ever is, fail
      -- with a message that names this function instead of a bare
      -- 'undefined'.
      typeProxy :: Value b -> b
      typeProxy _ =
         error "LLVM.Extra.Arithmetic.valueTypeName: type proxy was evaluated"


-- | Name of the overloaded LLVM intrinsic @fn@ instantiated at the type of
-- @x@: the string @llvm.@, then @fn@, then @.@, then 'valueTypeName' of @x@
-- (e.g. @llvm.sin.f32@).
--
-- NOTE(review): this relies on 'LLVM.typeName' yielding the suffix LLVM
-- expects for overloaded intrinsics. That holds for the scalar @f32@ case
-- seen in practice but is unverified for vector types; confirm before use
-- with vectors.
intrinsicName ::
   (IsType a) =>
   String -> Value a -> String
intrinsicName fn x = "llvm." ++ fn ++ "." ++ valueTypeName x

-- | Emit a call to the unary LLVM intrinsic @fn@ at the operand's type.
-- The intrinsic is declared via 'LLVM.externFunction'; the call result is
-- passed through 'addReadNone'.
callIntrinsic1 ::
   (IsFirstClass a) =>
   String -> Value a -> CodeGenFunction r (Value a)
callIntrinsic1 fn x = do
   op <- LLVM.externFunction (intrinsicName fn x)
   LLVM.call op x >>= addReadNone

-- | Emit a call to the binary LLVM intrinsic @fn@. The type suffix is taken
-- from the first operand; both operands share one type. Otherwise it works
-- like 'callIntrinsic1'.
callIntrinsic2 ::
   (IsFirstClass a) =>
   String -> Value a -> Value a -> CodeGenFunction r (Value a)
callIntrinsic2 fn x y = do
   op <- LLVM.externFunction (intrinsicName fn x)
   LLVM.call op x y >>= addReadNone


{-
If we add the readnone attribute to the intrinsic call, LLVM 2.8 rejects the module:

$ ./dist/build/synthi-llvm-test/synthi-llvm-test
Attribute readnone only applies to the function!
  %97 = call readnone float @llvm.sin.f32(float %96)
Attribute readnone only applies to the function!
  %99 = call readnone float @llvm.exp.f32(float %98)
Attribute readnone only applies to the function!
  %102 = call readnone float @llvm.cos.f32(float %101)
Broken module found, compilation aborted!
Stack dump:
0.      Running pass 'Function Pass Manager' on module '_module'.
1.      Running pass 'Module Verifier' on function '@fillsignal'
make: *** [test] Abgebrochen
-}
-- | Intended to mark an intrinsic call as @readnone@. It is currently a
-- no-op that returns the value unchanged, because attaching the attribute
-- made LLVM 2.8 reject the module. The disabled call, kept for reference:
--
-- > LLVM.addAttributes x 0 [LLVM.ReadNoneAttribute]
--
-- NOTE(review): in LLVM's attribute numbering, index 0 may address the
-- return value rather than the function. That would match the
-- "only applies to the function" verifier error; confirm before re-enabling.
addReadNone :: Value a -> CodeGenFunction r (Value a)
addReadNone = return



-- | Unary floating-point functions. Each emits a call to the LLVM intrinsic
-- of the same name, overloaded on the operand type (e.g. @llvm.sin.f32@).
sqrt, sin, cos, exp, log ::
   (IsFloating a) =>
   Value a -> CodeGenFunction r (Value a)
sqrt x = callIntrinsic1 "sqrt" x
sin  x = callIntrinsic1 "sin"  x
cos  x = callIntrinsic1 "cos"  x
exp  x = callIntrinsic1 "exp"  x
log  x = callIntrinsic1 "log"  x

-- | @pow x y@: @x@ raised to the power @y@. It emits a call to the LLVM
-- @pow@ intrinsic, overloaded on the operand type.
pow ::
   (IsFloating a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
pow x y = callIntrinsic2 "pow" x y