{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleContexts #-}
module LLVM.Extra.Arithmetic (
   -- * arithmetic: generalized and improved type inference
   Additive (zero, add, sub, neg), one, inc, dec,
   PseudoRing (mul), square,
   Scalar,
   PseudoModule (scale),
   Field (fdiv),
   IntegerConstant(fromInteger'),
   RationalConstant(fromRational'),
   idiv, irem,
   FloatingComparison(fcmp), Comparison(cmp),
   CmpResult, LLVM.CmpPredicate(..),
   Logic (and, or, xor, inv),
   Real (min, max, abs, signum),
   Fraction (truncate, fraction),
   signedFraction, addToPhase, incPhase,
   -- * pointer arithmetic
   advanceArrayElementPtr,
   decreaseArrayElementPtr,
   -- * transcendental functions
   Algebraic (sqrt),
   Transcendental (pi, sin, cos, exp, log, pow),
   ) where

import qualified LLVM.Util.Intrinsic as Intrinsic
import LLVM.Extra.ArithmeticPrivate
   (inc, dec, advanceArrayElementPtr, decreaseArrayElementPtr, )

import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Core as LLVM
import LLVM.Core
   (CodeGenFunction, value, Value, ConstValue,
    IsInteger, IsFloating, IsArithmetic)

import Control.Monad (liftM2, liftM3, )

import Prelude hiding
   (Real, and, or, sqrt, sin, cos, exp, log, abs, min, max, truncate, )



{- |
This class and the classes following it
are meant for arithmetic operations on wrappers around LLVM types.
For example, you might define a fixed point fraction type as

> newtype Fixed = Fixed Int32

and then use the same methods for both floating point and fixed point arithmetic.

Unlike the arithmetic methods of the @llvm@ wrapper,
these methods require operands and result to have the same type.
Advantage: type inference determines most of the types automatically.
Disadvantage: you cannot use constant values directly,
but have to convert all of them to 'Value' first.
-}
class (Tuple.Zero a) => Additive a where
   -- | The neutral element of 'add'.
   zero :: a
   -- | Addition and subtraction; operands and result share one type.
   add, sub :: a -> a -> CodeGenFunction r a
   -- | Negation of a single operand.
   neg :: a -> CodeGenFunction r a

instance (IsArithmetic a) => Additive (Value a) where
   -- every method delegates to the corresponding LLVM.Core operation
   zero = value LLVM.zero
   add x y = LLVM.add x y
   sub x y = LLVM.sub x y
   neg x = LLVM.neg x

instance (IsArithmetic a) => Additive (ConstValue a) where
   zero = LLVM.zero
   add x y = LLVM.add x y
   sub x y = LLVM.sub x y
   -- negation is expressed as subtraction from the constant zero
   neg x = sub zero x

instance (Additive a, Additive b) => Additive (a,b) where
   -- component-wise; the first component's code is generated first
   zero = (zero, zero)
   add (x0,x1) (y0,y1) = do
      s0 <- add x0 y0
      s1 <- add x1 y1
      return (s0, s1)
   sub (x0,x1) (y0,y1) = do
      d0 <- sub x0 y0
      d1 <- sub x1 y1
      return (d0, d1)
   neg (x0,x1) = do
      n0 <- neg x0
      n1 <- neg x1
      return (n0, n1)

instance (Additive a, Additive b, Additive c) => Additive (a,b,c) where
   -- component-wise; code is generated left to right
   zero = (zero, zero, zero)
   add (x0,x1,x2) (y0,y1,y2) = do
      s0 <- add x0 y0
      s1 <- add x1 y1
      s2 <- add x2 y2
      return (s0, s1, s2)
   sub (x0,x1,x2) (y0,y1,y2) = do
      d0 <- sub x0 y0
      d1 <- sub x1 y1
      d2 <- sub x2 y2
      return (d0, d1, d2)
   neg (x0,x1,x2) = do
      n0 <- neg x0
      n1 <- neg x1
      n2 <- neg x2
      return (n0, n1, n2)


-- | 'Additive' types that also support multiplication.
class (Additive a) => PseudoRing a where
   mul :: a -> a -> CodeGenFunction r a

instance (IsArithmetic v) => PseudoRing (Value v) where
   mul x y = LLVM.mul x y

instance (IsArithmetic v) => PseudoRing (ConstValue v) where
   mul x y = LLVM.mul x y


-- | The type of the factor that 'scale' multiplies a value with.
-- For LLVM values it is the corresponding 'SoV.Scalar' type,
-- wrapped the same way as the value itself.
type family Scalar v :: *
type instance Scalar (Value x) = Value (SoV.Scalar x)
type instance Scalar (ConstValue x) = ConstValue (SoV.Scalar x)

-- | 'Additive' types that can be multiplied by a value of their 'Scalar' type.
class (PseudoRing (Scalar v), Additive v) => PseudoModule v where
   scale :: Scalar v -> v -> CodeGenFunction r v

instance (SoV.PseudoModule v) => PseudoModule (Value v) where
   scale s x = SoV.scale s x

instance (SoV.PseudoModule v) => PseudoModule (ConstValue v) where
   scale s x = SoV.scaleConst s x


-- | Types whose values can be built from an 'Integer' constant.
class IntegerConstant a where
   fromInteger' :: Integer -> a

instance SoV.IntegerConstant a => IntegerConstant (ConstValue a) where
   fromInteger' n = SoV.constFromInteger n

instance SoV.IntegerConstant a => IntegerConstant (Value a) where
   fromInteger' n = value (SoV.constFromInteger n)


-- | The constant built from the integer @1@.
one :: (IntegerConstant a) => a
one = fromInteger' (1 :: Integer)


{-
More general alternatives to 'inc' and 'dec',
but you may not like the resulting type constraints.

They only use 'add' and 'sub', so 'Additive' suffices;
a multiplicative 'PseudoRing' structure is not needed.
-}
_inc ::
   (Additive a, IntegerConstant a) =>
   a -> CodeGenFunction r a
_inc x = add x one

_dec ::
   (Additive a, IntegerConstant a) =>
   a -> CodeGenFunction r a
_dec x = sub x one


-- | Multiply a value by itself.
square ::
   (PseudoRing a) =>
   a -> CodeGenFunction r a
square x = x `mul` x


-- | 'PseudoRing' types that also support division.
-- The instances here cover floating point LLVM values.
class (PseudoRing a) => Field a where
   fdiv :: a -> a -> CodeGenFunction r a

instance (LLVM.IsFloating v) => Field (Value v) where
   fdiv x y = LLVM.fdiv x y

instance (LLVM.IsFloating v) => Field (ConstValue v) where
   fdiv x y = LLVM.fdiv x y


-- | Types whose values can be built from a 'Rational' constant.
-- Every such type can also be built from an 'Integer'.
class (IntegerConstant a) => RationalConstant a where
   fromRational' :: Rational -> a

instance SoV.RationalConstant a => RationalConstant (ConstValue a) where
   fromRational' q = SoV.constFromRational q

instance SoV.RationalConstant a => RationalConstant (Value a) where
   fromRational' q = value (SoV.constFromRational q)



{- |
Integer division of two values of the same type.
In Haskell terms this is a 'quot'.
-}
idiv ::
   (IsInteger a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
idiv x y = LLVM.idiv x y

{- |
Integer remainder of two values of the same type,
the remainder counterpart of 'idiv'.
In Haskell terms this is a 'rem'.
-}
irem ::
   (IsInteger a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
irem x y = LLVM.irem x y



-- | 'Additive' types with minimum, maximum, absolute value and sign.
class (Additive a) => Real a where
   min, max :: a -> a -> CodeGenFunction r a
   abs, signum :: a -> CodeGenFunction r a

instance (SoV.Real a) => Real (Value a) where
   min x y = SoV.min x y
   max x y = SoV.max x y
   abs x = SoV.abs x
   signum x = SoV.signum x


-- | 'Real' types that provide 'truncate' and 'fraction'.
class (Real a) => Fraction a where
   truncate, fraction :: a -> CodeGenFunction r a

instance (SoV.Fraction a) => Fraction (Value a) where
   truncate x = SoV.truncate x
   fraction x = SoV.fraction x

-- | The difference between a value and its 'truncate'd version.
signedFraction ::
   (Fraction a) =>
   a -> CodeGenFunction r a
signedFraction x = do
   t <- truncate x
   sub x t

-- | Add the increment @d@ to the phase @p@ and keep the 'fraction' of the sum.
addToPhase ::
   (Fraction a) =>
   a -> a -> CodeGenFunction r a
addToPhase d p = do
   s <- add d p
   fraction s

{- |
Add the increment @d@ to the phase @p@ and apply 'signedFraction' to the sum.

Both increment and phase must be non-negative.
NOTE(review): presumably because then the sum is non-negative
and 'signedFraction' yields the same result as 'fraction' in 'addToPhase'.
-}
incPhase ::
   (Fraction a) =>
   a -> a -> CodeGenFunction r a
incPhase d p =
   add d p >>= signedFraction


-- | Types whose values can be compared using an 'LLVM.CmpPredicate'.
class Comparison a where
   -- | The type of the result of comparing two values of type @a@.
   type CmpResult a :: *
   cmp :: LLVM.CmpPredicate -> a -> a -> CodeGenFunction r (CmpResult a)

instance (LLVM.CmpRet a) => Comparison (Value a) where
   type CmpResult (Value a) = Value (LLVM.CmpResult a)
   cmp p x y = LLVM.cmp p x y

instance (LLVM.CmpRet a) => Comparison (ConstValue a) where
   type CmpResult (ConstValue a) = ConstValue (LLVM.CmpResult a)
   cmp p x y = LLVM.cmp p x y


-- | Comparison using an 'LLVM.FPPredicate';
-- the result type is the same 'CmpResult' as for 'cmp'.
class (Comparison a) => FloatingComparison a where
   fcmp :: LLVM.FPPredicate -> a -> a -> CodeGenFunction r (CmpResult a)

instance (IsFloating a, LLVM.CmpRet a) => FloatingComparison (Value a) where
   fcmp p x y = LLVM.fcmp p x y

instance (IsFloating a, LLVM.CmpRet a) => FloatingComparison (ConstValue a) where
   fcmp p x y = LLVM.fcmp p x y



-- | Logical operations; the instances here cover integer LLVM values.
class Logic a where
   and, or, xor :: a -> a -> CodeGenFunction r a
   inv :: a -> CodeGenFunction r a

instance (LLVM.IsInteger a) => Logic (Value a) where
   and x y = LLVM.and x y
   or x y = LLVM.or x y
   xor x y = LLVM.xor x y
   inv x = LLVM.inv x

instance (LLVM.IsInteger a) => Logic (ConstValue a) where
   and x y = LLVM.and x y
   or x y = LLVM.or x y
   xor x y = LLVM.xor x y
   inv x = LLVM.inv x



-- | 'Field' types that support a square root.
class Field a => Algebraic a where
   sqrt :: a -> CodeGenFunction r a

instance (IsFloating a) => Algebraic (Value a) where
   -- implemented by 'Intrinsic.call1' with the name "sqrt"
   sqrt x = Intrinsic.call1 "sqrt" x


-- | 'Algebraic' types with the constant @pi@ and transcendental functions.
class Algebraic a => Transcendental a where
   -- | The constant @pi@.
   pi :: CodeGenFunction r a
   -- | Unary transcendental functions.
   sin, cos, exp, log :: a -> CodeGenFunction r a
   -- | Power with an exponent of the same type.
   pow :: a -> a -> CodeGenFunction r a

instance (IsFloating a, SoV.TranscendentalConstant a) => Transcendental (Value a) where
   -- the constant comes from SoV; the functions are calls via "Intrinsic"
   pi = return (value SoV.constPi)
   sin x = Intrinsic.call1 "sin" x
   cos x = Intrinsic.call1 "cos" x
   exp x = Intrinsic.call1 "exp" x
   log x = Intrinsic.call1 "log" x
   pow x y = Intrinsic.call2 "pow" x y