{-# LANGUAGE FlexibleContexts #-}
module LLVM.Extra.Arithmetic (
   add, sub, inc, dec,
   mul, square, fdiv,
   udiv, urem,
   fcmp, icmp,
   and, or,
   umin, umax,
   smin, smax, sabs,
   fmin, fmax, fabs,
   advanceArrayElementPtr,
   sqrt, sin, cos, exp, log, pow,
   ) where

import qualified LLVM.Core as LLVM
import LLVM.Core
   (Ptr, getElementPtr, value, valueOf, Value,
    IntPredicate(IntULE, IntSLE, IntUGE, IntSGE),
    FPPredicate(FPOLE, FPOGE),
    IsIntegerOrPointer,
    IsType, IsConst, IsInteger, IsFloating, IsArithmetic, IsFirstClass,
    CmpRet,
    CodeGenFunction, )

import Data.Word (Word32, )


import Prelude hiding (and, or, sqrt, sin, cos, exp, log, )



-- * arithmetic with better type inference

-- | Add two values of one 'IsArithmetic' type.
--
-- The signature pins both operands and the result to the same type,
-- which is what gives better type inference than calling 'LLVM.add'
-- directly.
add ::
   (IsArithmetic a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
add x y = LLVM.add x y

-- | Subtract the second value from the first; both operands and the
-- result share one 'IsArithmetic' type (a type-pinned 'LLVM.sub').
sub ::
   (IsArithmetic a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
sub minuend subtrahend = LLVM.sub minuend subtrahend

-- | Increment: add the constant @1@ (built with 'valueOf') to a value.
inc ::
   (IsArithmetic a, IsConst a, Num a) =>
   Value a -> CodeGenFunction r (Value a)
inc = flip add (valueOf 1)

-- | Decrement: subtract the constant @1@ (built with 'valueOf') from a value.
dec ::
   (IsArithmetic a, IsConst a, Num a) =>
   Value a -> CodeGenFunction r (Value a)
dec = flip sub (valueOf 1)


-- | Multiply two values of one 'IsArithmetic' type (a type-pinned
-- 'LLVM.mul').
mul ::
   (IsArithmetic a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
mul x y = LLVM.mul x y

-- | Square a value by multiplying it with itself.
square ::
   (IsArithmetic a) =>
   Value a -> CodeGenFunction r (Value a)
square v = v `mul` v


-- | Floating-point division of the first value by the second
-- (a type-pinned 'LLVM.fdiv').
fdiv ::
   (IsFloating a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
fdiv numerator denominator = LLVM.fdiv numerator denominator

-- | Compare two floating-point values with the given 'FPPredicate'.
--
-- The result type @b@ is tied to the operand type by 'CmpRet'
-- (presumably 'Bool' for scalar operands — confirm in "LLVM.Core").
fcmp ::
   (IsFloating a, CmpRet a b) =>
   FPPredicate -> Value a -> Value a -> CodeGenFunction r (Value b)
fcmp predicate x y = LLVM.fcmp predicate x y


-- | Compare two integer or pointer values with the given 'IntPredicate'.
--
-- The result type @b@ is tied to the operand type by 'CmpRet'
-- (presumably 'Bool' for scalar operands — confirm in "LLVM.Core").
icmp ::
   (IsIntegerOrPointer a, CmpRet a b) =>
   IntPredicate -> Value a -> Value a -> CodeGenFunction r (Value b)
icmp predicate x y = LLVM.icmp predicate x y

-- | Unsigned integer division (a type-pinned 'LLVM.udiv').
udiv ::
   (IsInteger a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
udiv dividend divisor = LLVM.udiv dividend divisor

-- | Unsigned integer remainder (a type-pinned 'LLVM.urem').
urem ::
   (IsInteger a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
urem dividend divisor = LLVM.urem dividend divisor


-- | Bitwise AND of two integer values (a type-pinned 'LLVM.and').
-- Shadows 'Prelude.and', which this module hides.
and ::
   (IsInteger a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
and x y = LLVM.and x y

-- | Bitwise OR of two integer values (a type-pinned 'LLVM.or').
-- Shadows 'Prelude.or', which this module hides.
or ::
   (IsInteger a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
or x y = LLVM.or x y



{- |
Unsigned minimum of two integers: an unsigned @<=@ comparison
followed by a 'LLVM.select' of the first operand when it holds.

This would also work for vectors
if LLVM supported 'select' with a vector of booleans as the condition.
-}
umin :: (IsInteger a, CmpRet a Bool) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
umin x y = cmpSelect (icmp IntULE) x y

-- | Unsigned maximum of two integers: the first operand is chosen when
-- it is unsigned-greater-or-equal to the second.
umax :: (IsInteger a, CmpRet a Bool) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
umax x y = cmpSelect (icmp IntUGE) x y


-- | Signed minimum of two integers: the first operand is chosen when
-- it is signed-less-or-equal to the second.
smin :: (IsInteger a, CmpRet a Bool) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
smin x y = cmpSelect (icmp IntSLE) x y

-- | Signed maximum of two integers: the first operand is chosen when
-- it is signed-greater-or-equal to the second.
smax :: (IsInteger a, CmpRet a Bool) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
smax x y = cmpSelect (icmp IntSGE) x y

-- | Signed absolute value of an integer: @x@ when @x >= 0@ (signed
-- comparison against 'LLVM.zero'), otherwise @'LLVM.neg' x@.
--
-- NOTE(review): with two's-complement wrap-around, negating the most
-- negative value presumably yields that same value — confirm if callers
-- rely on a non-negative result.
sabs :: (IsInteger a, CmpRet a Bool) =>
   Value a -> CodeGenFunction r (Value a)
sabs x =
   icmp IntSGE x (value LLVM.zero) >>= \nonNegative ->
   LLVM.neg x >>= LLVM.select nonNegative x


-- | Floating-point minimum: the first operand is chosen when the
-- ordered comparison @x <= y@ holds, otherwise the second.
--
-- NOTE(review): 'FPOLE' is an ordered predicate, so a NaN operand
-- presumably makes the comparison false and @y@ is returned — confirm.
fmin :: (IsFloating a, CmpRet a Bool) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
fmin x y = cmpSelect (fcmp FPOLE) x y

-- | Floating-point maximum: the first operand is chosen when the
-- ordered comparison @x >= y@ holds, otherwise the second.
--
-- NOTE(review): 'FPOGE' is an ordered predicate, so a NaN operand
-- presumably makes the comparison false and @y@ is returned — confirm.
fmax :: (IsFloating a, CmpRet a Bool) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
fmax x y = cmpSelect (fcmp FPOGE) x y

-- | Floating-point absolute value: @x@ when the ordered comparison
-- @x >= 0@ holds, otherwise @'LLVM.neg' x@.
--
-- NOTE(review): @-0.0 >= 0.0@ holds under IEEE comparison, so @-0.0@
-- appears to be returned unchanged instead of as @+0.0@. The
-- @llvm.fabs@ intrinsic would avoid this if the targeted LLVM version
-- provides it — confirm before relying on the sign of zero results.
fabs :: (IsFloating a, CmpRet a Bool) =>
   Value a -> CodeGenFunction r (Value a)
fabs x =
   fcmp FPOGE x (value LLVM.zero) >>= \nonNegative ->
   LLVM.neg x >>= LLVM.select nonNegative x


-- | Turn a comparison into a selection: run @cmp x y@ and return @x@
-- when it yields true, @y@ otherwise. Used to build the min/max
-- functions of this module.
cmpSelect ::
   (IsFirstClass a, CmpRet a Bool) =>
   (Value a -> Value a -> CodeGenFunction r (Value Bool)) ->
   (Value a -> Value a -> CodeGenFunction r (Value a))
cmpSelect cmp x y = do
   pickFirst <- cmp x y
   LLVM.select pickFirst x y



-- * pointers

-- | Address one element past @p@: a 'getElementPtr' with the single
-- index @1@ (as a 'Word32').
advanceArrayElementPtr ::
   Value (Ptr o) ->
   CodeGenFunction r (Value (Ptr o))
advanceArrayElementPtr p =
   getElementPtr p (one, ())
   where one = valueOf 1 :: Value Word32



-- * transcendental functions


-- | LLVM type name of the type carried by a 'Value', obtained by
-- applying 'LLVM.typeName' to a proxy of that type.
--
-- The proxy only lets 'LLVM.typeName' select the 'IsType' instance for
-- @a@; it is not meant to be evaluated. It fails with a named 'error'
-- rather than a bare 'undefined', so an accidental evaluation reports
-- where it came from.
valueTypeName ::
   (IsType a) =>
   Value a -> String
valueTypeName =
   LLVM.typeName . proxyOf
   where
      proxyOf :: Value b -> b
      proxyOf _ =
         error "LLVM.Extra.Arithmetic.valueTypeName: type proxy must not be evaluated"


-- | Emit a call to the unary intrinsic @llvm.\<fn\>.\<type\>@, where
-- @\<type\>@ is 'valueTypeName' of the argument. 'LLVM.ReadNoneAttribute'
-- is then attached to the call at attribute index 0, and the call's
-- result is returned.
--
-- NOTE(review): readnone is normally a function attribute; confirm
-- whether index 0 of 'LLVM.addAttributes' means the function or the
-- return value in this binding.
callIntrinsic1 ::
   (IsFirstClass a) =>
   String -> Value a -> CodeGenFunction r (Value a)
callIntrinsic1 fn x = do
   let name = concat ["llvm.", fn, ".", valueTypeName x]
   intrinsic <- LLVM.externFunction name
   result <- LLVM.call intrinsic x
   LLVM.addAttributes result 0 [LLVM.ReadNoneAttribute]
   return result

-- | Emit a call to the binary intrinsic @llvm.\<fn\>.\<type\>@, where
-- @\<type\>@ is 'valueTypeName' of the first argument (both arguments
-- share one type). 'LLVM.ReadNoneAttribute' is then attached to the
-- call at attribute index 0, and the call's result is returned.
--
-- NOTE(review): readnone is normally a function attribute; confirm
-- whether index 0 of 'LLVM.addAttributes' means the function or the
-- return value in this binding.
callIntrinsic2 ::
   (IsFirstClass a) =>
   String -> Value a -> Value a -> CodeGenFunction r (Value a)
callIntrinsic2 fn x y = do
   let name = concat ["llvm.", fn, ".", valueTypeName x]
   intrinsic <- LLVM.externFunction name
   result <- LLVM.call intrinsic x y
   LLVM.addAttributes result 0 [LLVM.ReadNoneAttribute]
   return result


-- | Elementary floating-point functions. Each one is emitted as a call
-- to the matching unary intrinsic, e.g. @llvm.sqrt.\<type\>@ for 'sqrt'.
-- These shadow the Prelude functions of the same names, which this
-- module hides.
sqrt, sin, cos, exp, log ::
   (IsFloating a) =>
   Value a -> CodeGenFunction r (Value a)
sqrt x = callIntrinsic1 "sqrt" x
sin x = callIntrinsic1 "sin" x
cos x = callIntrinsic1 "cos" x
exp x = callIntrinsic1 "exp" x
log x = callIntrinsic1 "log" x

-- | @pow base exponent@, emitted as a call to the binary intrinsic
-- @llvm.pow.\<type\>@.
pow ::
   (IsFloating a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
pow base exponent = callIntrinsic2 "pow" base exponent