{-# LANGUAGE TypeFamilies #-}
module LLVM.Extra.ArithmeticPrivate where

import qualified LLVM.Core as LLVM
import LLVM.Core
   (CodeGenFunction, valueOf, Value,
    CmpPredicate(CmpLE, CmpGE), FPPredicate, CmpRet, CmpResult,
    IsConst, IsFirstClass, IsArithmetic, IsInteger, IsFloating,
    getElementPtr, )

import Foreign.Ptr (Ptr, )
import Data.Word (Word32, )
import Data.Int (Int32, )

import Prelude hiding (and, or, sqrt, sin, cos, exp, log, abs, min, max, )


{- |
Add two values.
This is a thin wrapper around 'LLVM.add'.
-}
add ::
   (IsArithmetic a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
add x y = LLVM.add x y

{- |
Subtract the second operand from the first.
This is a thin wrapper around 'LLVM.sub'.
-}
sub ::
   (IsArithmetic a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
sub minuend subtrahend = LLVM.sub minuend subtrahend


{- |
Add the constant one to a value.
-}
inc ::
   (IsArithmetic a, IsConst a, Num a) =>
   Value a -> CodeGenFunction r (Value a)
inc x = x `add` valueOf 1

{- |
Subtract the constant one from a value.
-}
dec ::
   (IsArithmetic a, IsConst a, Num a) =>
   Value a -> CodeGenFunction r (Value a)
dec x = x `sub` valueOf 1

{- |
Emit a @getelementptr@ on the pointer with the single index 1
(a 'Word32' constant), i.e. the address of the next element.
-}
advanceArrayElementPtr ::
   Value (Ptr a) ->
   CodeGenFunction r (Value (Ptr a))
advanceArrayElementPtr ptr =
   let oneStep = valueOf 1 :: Value Word32
   in  getElementPtr ptr (oneStep, ())

{- |
Emit a @getelementptr@ on the pointer with the single index -1,
i.e. the address of the previous element.
-}
decreaseArrayElementPtr ::
   Value (Ptr a) ->
   CodeGenFunction r (Value (Ptr a))
decreaseArrayElementPtr ptr =
   -- a signed index type is used so that -1 can be expressed
   let backStep = valueOf (-1) :: Value Int32
   in  getElementPtr ptr (backStep, ())



{- |
Multiply two values.
This is a thin wrapper around 'LLVM.mul'.
-}
mul ::
   (IsArithmetic a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
mul x y = LLVM.mul x y


{- |
Return the smaller of two values.
The first argument is selected whenever @x <= y@ holds according to
'cmp' 'CmpLE' (so ties pick @x@); otherwise the second one is selected.

This would also work for vectors,
but LLVM-3.1 crashes when actually doing this.
-}
min :: (IsFirstClass a, CmpRet a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
min x y = do
   xNotGreater <- cmp CmpLE x y
   LLVM.select xNotGreater x y

{- |
Return the larger of two values.
The first argument is selected whenever @x >= y@ holds according to
'cmp' 'CmpGE' (so ties pick @x@); otherwise the second one is selected.

NOTE(review): this uses the same compare-and-select pattern as 'min',
whose documentation mentions an LLVM-3.1 crash for vector types;
presumably the same caveat applies here — confirm.
-}
max :: (IsFirstClass a, CmpRet a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
max x y = do
   xNotLess <- cmp CmpGE x y
   LLVM.select xNotLess x y

{- |
Absolute value, computed as the larger of @x@ and its negation
(the negation is emitted first via 'LLVM.neg', then 'max' is applied).

NOTE(review): if @a@ can be an unsigned integer type, 'LLVM.neg' wraps
around and the comparison in 'max' is presumably unsigned, so the result
would not be the identity there — confirm callers only use signed
integer or floating-point types.
-}
abs :: (IsArithmetic a, CmpRet a) =>
   Value a -> CodeGenFunction r (Value a)
abs x = do
   negated <- LLVM.neg x
   max x negated


{- |
Generic sign function.
It yields @minusOne@ when @x < 0@, @one@ when @x > 0@, and zero otherwise.
The caller supplies the values that represent -1 and +1.
The @CmpResult a ~ Bool@ constraint means each comparison yields a
single 'Bool'.
-}
signumGen ::
   (IsFirstClass a,
    CmpRet a, CmpResult a ~ Bool) =>
   Value a -> Value a ->
   Value a -> CodeGenFunction r (Value a)
signumGen minusOne one x = do
   let zero = LLVM.value LLVM.zero
   isNegative <- cmp LLVM.CmpLT x zero
   isPositive <- cmp LLVM.CmpGT x zero
   -- settle the non-negative case first, then override it for negative x
   nonNegativeSign <- LLVM.select isPositive one zero
   LLVM.select isNegative minusOne nonNegativeSign

{- |
Sign of a scalar value: -1 if negative, 1 if positive, 0 otherwise.
The constants are built with 'valueOf', and the work is delegated
to 'signumGen'.
-}
signum ::
   (Num a,
    IsConst a, IsFirstClass a,
    CmpRet a, CmpResult a ~ Bool) =>
   Value a -> CodeGenFunction r (Value a)
signum x = signumGen negOne posOne x
   where
      negOne = valueOf (-1)
      posOne = valueOf 1


{- |
Turn a comparison into a selection.
This emits @decide x y@ and uses its result to choose @x@ where the
comparison holds and @y@ where it does not.
-}
cmpSelect ::
   (IsFirstClass a, CmpRet a) =>
   (Value a -> Value a -> CodeGenFunction r (Value (CmpResult a))) ->
   (Value a -> Value a -> CodeGenFunction r (Value a))
cmpSelect decide x y = do
   takeFirst <- decide x y
   LLVM.select takeFirst x y


{- |
Floating-point comparison with an explicit 'FPPredicate'.
This forwards to 'LLVM.fcmp'.
-}
fcmp ::
   (IsFloating a, CmpRet a, CmpResult a ~ b) =>
   FPPredicate -> Value a -> Value a -> CodeGenFunction r (Value b)
fcmp predicate x y = LLVM.fcmp predicate x y

{- |
Comparison with a generic 'CmpPredicate'.
This forwards to 'LLVM.cmp'.
-}
cmp ::
   (CmpRet a, CmpResult a ~ b) =>
   CmpPredicate -> Value a -> Value a -> CodeGenFunction r (Value b)
cmp predicate x y = LLVM.cmp predicate x y



{- |
Bitwise AND of two integer values.
This is a thin wrapper around 'LLVM.and'.
-}
and ::
   (IsInteger a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
and x y = LLVM.and x y

{- |
Bitwise OR of two integer values.
This is a thin wrapper around 'LLVM.or'.
-}
or ::
   (IsInteger a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
or x y = LLVM.or x y