{-# LANGUAGE FlexibleContexts #-}
module LLVM.Extra.ArithmeticPrivate where

import qualified LLVM.Core as LLVM
import LLVM.Core
   (CodeGenFunction, value, valueOf, Value,
    CmpPredicate(CmpLE, CmpGE), FPPredicate, CmpRet,
    IsConst, IsFirstClass, IsArithmetic, IsInteger, IsFloating,
    Ptr, getElementPtr, )

import Data.Word (Word32, )

import Prelude hiding (and, or, sqrt, sin, cos, exp, log, abs, min, max, )


{- |
Add two values by delegating to 'LLVM.add'.
-}
add ::
   (IsArithmetic a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
add x y = LLVM.add x y

{- |
Subtract the second value from the first by delegating to 'LLVM.sub'.
-}
sub ::
   (IsArithmetic a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
sub x y = LLVM.sub x y


{- |
Add the constant one to a value (using 'add' and 'valueOf').
-}
inc ::
   (IsArithmetic a, IsConst a, Num a) =>
   Value a -> CodeGenFunction r (Value a)
inc x = x `add` valueOf 1

{- |
Subtract the constant one from a value (using 'sub' and 'valueOf').
-}
dec ::
   (IsArithmetic a, IsConst a, Num a) =>
   Value a -> CodeGenFunction r (Value a)
dec x = x `sub` valueOf 1

{- |
Step a pointer forward by one element of its pointee type.
This uses 'getElementPtr' with a single 32-bit index of 1.
-}
advanceArrayElementPtr ::
   Value (Ptr a) ->
   CodeGenFunction r (Value (Ptr a))
advanceArrayElementPtr ptr =
   let step = valueOf (1 :: Word32)
   in  getElementPtr ptr (step, ())



{- |
Multiply two values by delegating to 'LLVM.mul'.
-}
mul ::
   (IsArithmetic a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
mul x y = LLVM.mul x y


{- |
Minimum of two values: yields @x@ when @'cmp' 'CmpLE' x y@ holds
(so ties yield @x@), otherwise @y@.

This would also work for vectors
if LLVM supported 'select' with a vector of booleans as the condition.
-}
min :: (IsFirstClass a, CmpRet a Bool) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
min x y = cmpSelect (cmp CmpLE) x y

{- |
Maximum of two values: yields @x@ when @'cmp' 'CmpGE' x y@ holds
(so ties yield @x@), otherwise @y@.
-}
max :: (IsFirstClass a, CmpRet a Bool) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
max x y = cmpSelect (cmp CmpGE) x y

{- |
Absolute value: yields @x@ when @'cmp' 'CmpGE' x 0@ holds,
otherwise the negation of @x@.

The negation is always emitted; 'LLVM.select' merely picks a result.
NOTE(review): for floating-point inputs, the handling of -0.0 and NaN
depends on how 'cmp' maps 'CmpGE' to an LLVM predicate; confirm.
-}
abs :: (IsArithmetic a, CmpRet a Bool) =>
   Value a -> CodeGenFunction r (Value a)
abs x = do
   nonNegative <- cmp CmpGE x (value LLVM.zero)
   negated <- LLVM.neg x
   LLVM.select nonNegative x negated


{- |
Turn a comparison into a binary selection.
This runs @f x y@ and, via 'LLVM.select', yields @x@ if the resulting
boolean is true and @y@ otherwise.
-}
cmpSelect ::
   (IsFirstClass a, CmpRet a Bool) =>
   (Value a -> Value a -> CodeGenFunction r (Value Bool)) ->
   (Value a -> Value a -> CodeGenFunction r (Value a))
cmpSelect f x y = do
   pickFirst <- f x y
   LLVM.select pickFirst x y


{- |
Compare two values with an explicit 'FPPredicate' by delegating to
'LLVM.fcmp'.
The 'IsFloating' constraint restricts the operands to floating-point types.
-}
fcmp ::
   (IsFloating a, CmpRet a b) =>
   FPPredicate -> Value a -> Value a -> CodeGenFunction r (Value b)
fcmp predicate x y = LLVM.fcmp predicate x y

{- |
Compare two values with a 'CmpPredicate' by delegating to 'LLVM.cmp'.
-}
cmp ::
   (CmpRet a b) =>
   CmpPredicate -> Value a -> Value a -> CodeGenFunction r (Value b)
cmp predicate x y = LLVM.cmp predicate x y



{- |
Combine two integer values with 'LLVM.and'.
-}
and ::
   (IsInteger a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
and x y = x `LLVM.and` y

{- |
Combine two integer values with 'LLVM.or'.
-}
or ::
   (IsInteger a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
or x y = x `LLVM.or` y