{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{- |
Support for unified handling of scalars and vectors.

Attention:
The rounding and fraction functions only work
for floating-point values with a magnitude of at most @maxBound :: Int32@.
This way we avoid expensive handling of cases that are probably rare.
-}
module LLVM.Extra.ScalarOrVector (
   Fraction (truncate, fraction),
   signedFraction,
   addToPhase,
   incPhase,

   truncateToInt,
   floorToInt,
   ceilingToInt,
   roundToIntFast,
   splitFractionToInt,

   Scalar,
   Replicate (replicate, replicateConst),
   replicateOf,
   Real (min, max, abs, signum),
   Saturated(addSat, subSat),
   PseudoModule (scale, scaleConst),
   IntegerConstant(constFromInteger),
   RationalConstant(constFromRational),
   TranscendentalConstant(constPi),
   ) where

import qualified LLVM.Extra.ScalarOrVectorPrivate as Priv
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.ArithmeticPrivate as A
import LLVM.Extra.ScalarOrVectorPrivate
   (Scalar, Replicate(replicate, replicateConst))

import qualified LLVM.Util.Intrinsic as Intrinsic
import qualified LLVM.Util.Proxy as LP
import qualified LLVM.Core as LLVM
import LLVM.Core
   (Value, ConstValue, constOf,
    CmpRet, CmpResult, ShapeOf,
    Vector, WordN(WordN), IntN(IntN), FP128,
    IsConst, IsInteger, IsFloating,
    CodeGenFunction, )

import qualified Type.Data.Num.Decimal as TypeNum

import Data.Word (Word8, Word16, Word32, Word64, )
import Data.Int  (Int8,  Int16,  Int32,  Int64, )
import Data.Maybe (fromMaybe)

import Prelude hiding (Real, replicate, min, max, abs, truncate)



-- | Floating-point scalars and vectors that can be split into
-- an integral and a fractional part.
--
-- As stated in the module header, only values with a magnitude of
-- at most @maxBound :: Int32@ are supported by the derived functions.
class (Real a, IsFloating a) => Fraction a where
   -- | Round towards zero, keeping the floating-point type.
   truncate :: Value a -> CodeGenFunction r (Value a)
   -- | Fractional part of the argument.
   -- Unlike 'signedFraction' it is presumably always non-negative,
   -- since 'addToPhase' relies on it for negative increments.
   -- NOTE(review): confirm against 'A.fraction'.
   fraction :: Value a -> CodeGenFunction r (Value a)

-- Scalars: truncation via the LLVM intrinsic,
-- fraction via the generic arithmetic helper.
instance Fraction Float where
   fraction = A.fraction
   truncate = Intrinsic.truncate

instance Fraction Double where
   fraction = A.fraction
   truncate = Intrinsic.truncate

-- Vectors: delegate to the implementations in "LLVM.Extra.Vector".
instance (TypeNum.Positive n, Vector.Real a, IsFloating a, IsConst a) =>
      Fraction (Vector n a) where
   fraction = Vector.fraction
   truncate = Vector.truncate


{- |
Fractional part computed as @x - truncate x@.
The result has the same sign as the argument.
This is not particularly useful, but it is fast on IEEE implementations.
-}
signedFraction ::
   (Fraction a) =>
   Value a -> CodeGenFunction r (Value a)
signedFraction x = do
   integral <- truncate x
   A.sub x integral

-- | Generic, non-exported construction of a fractional part:
-- take the 'signedFraction' and add one unless it is non-negative
-- (ordered @>=@ comparison against zero).
_fractionGen ::
   (IntegerConstant v, Fraction v, CmpRet v) =>
   Value v -> CodeGenFunction r (Value v)
_fractionGen x = do
   xf <- signedFraction x
   nonNegative <- A.fcmp LLVM.FPOGE xf zero
   shifted <- A.add xf (LLVM.value $ constFromInteger 1)
   LLVM.select nonNegative xf shifted

-- | Non-exported variant of '_fractionGen' for comparison functions
-- that return an integer result instead of a boolean:
-- the comparison result (is the 'signedFraction' below zero?)
-- is converted to floating point and subtracted from the signed fraction.
--
-- NOTE(review): this yields a non-negative fraction only if @cmp@
-- returns -1 (all bits set, converted as signed) for true; confirm.
_fractionLogical ::
   (Fraction a, LLVM.IsPrimitive a,
    IsInteger b, LLVM.IsPrimitive b) =>
   (LLVM.FPPredicate ->
    Value a -> Value a -> CodeGenFunction r (Value b)) ->
   Value a -> CodeGenFunction r (Value a)
_fractionLogical cmp x = do
   xf <- signedFraction x
   isNegative <- cmp LLVM.FPOLT xf zero
   correction <- LLVM.inttofp isNegative
   A.sub xf correction

{- |
Add an increment (first operand) to a phase (second operand)
and wrap the sum back using 'fraction'.
The increment may be negative;
the phase must always be non-negative.
-}
addToPhase ::
   (Fraction a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
addToPhase d p = A.add d p >>= fraction

{- |
Like 'addToPhase', but wraps the sum using 'signedFraction'.
Since that keeps the sign of its argument,
both increment and phase must be non-negative.
-}
incPhase ::
   (Fraction a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
incPhase d p = A.add d p >>= signedFraction



-- | Convert a floating-point value to an integer type of the same shape,
-- rounding towards zero.
-- NOTE(review): LLVM's fptosi/fptoui yield poison if the result
-- does not fit the target type.
truncateToInt ::
   (IsFloating a, IsInteger i, ShapeOf a ~ ShapeOf i) =>
   Value a -> CodeGenFunction r (Value i)
truncateToInt x = LLVM.fptoint x

{- |
Round to the nearest integer.

Values of the form @n+0.5@ lie exactly between two integers;
we pick whichever neighbouring integer keeps the implementation cheapest.
Currently @0.5@ is added to positive arguments and @-0.5@ to all others
before truncating towards zero, so halves are rounded away from zero.

NOTE(review): because the addition is performed in floating point,
arguments just below @n+0.5@ may be rounded up as well.
-}
roundToIntFast ::
   (IsFloating a, RationalConstant a, CmpRet a,
    IsInteger i, IntegerConstant i, CmpRet i,
    CmpResult a ~ CmpResult i,
    ShapeOf a ~ ShapeOf i) =>
   Value a -> CodeGenFunction r (Value i)
roundToIntFast x = do
   positive <- A.cmp LLVM.CmpGT x zero
   offset <- LLVM.select positive (ratio 0.5) (ratio (-0.5))
   shifted <- A.add x offset
   truncateToInt shifted

-- | Round towards negative infinity.
--
-- Truncates first, then subtracts one wherever the argument lies
-- below its truncation, i.e. where truncation rounded upwards.
floorToInt ::
   (IsFloating a, CmpRet a,
    IsInteger i, IntegerConstant i, CmpRet i,
    CmpResult a ~ CmpResult i,
    ShapeOf a ~ ShapeOf i) =>
   Value a -> CodeGenFunction r (Value i)
floorToInt x = do
   truncated <- truncateToInt x
   truncatedF <- LLVM.inttofp truncated
   roundedUp <- A.cmp LLVM.CmpLT x truncatedF
   correction <- LLVM.select roundedUp (int 1) (int 0)
   A.sub truncated correction

-- | Split a value into its floor (as an integer)
-- and the remainder @x - floor x@ (as a floating-point value).
splitFractionToInt ::
   (IsFloating a, CmpRet a,
    IsInteger i, IntegerConstant i, CmpRet i,
    CmpResult a ~ CmpResult i,
    ShapeOf a ~ ShapeOf i) =>
   Value a -> CodeGenFunction r (Value i, Value a)
splitFractionToInt x = do
   whole <- floorToInt x
   wholeF <- LLVM.inttofp whole
   rest <- A.sub x wholeF
   return (whole, rest)

-- | Round towards positive infinity.
--
-- Truncates first, then adds one wherever the argument lies
-- above its truncation, i.e. where truncation rounded downwards.
ceilingToInt ::
   (IsFloating a, CmpRet a,
    IsInteger i, IntegerConstant i, CmpRet i,
    CmpResult a ~ CmpResult i,
    ShapeOf a ~ ShapeOf i) =>
   Value a -> CodeGenFunction r (Value i)
ceilingToInt x = do
   truncated <- truncateToInt x
   truncatedF <- LLVM.inttofp truncated
   roundedDown <- A.cmp LLVM.CmpGT x truncatedF
   correction <- LLVM.select roundedDown (int 1) (int 0)
   A.add truncated correction


-- | The zero constant of an LLVM type, lifted to a value.
zero :: (LLVM.IsType a) => Value a
zero = LLVM.value LLVM.zero

-- | An integer literal as a constant value
-- (replicated into vectors via 'constFromInteger').
int :: (IntegerConstant a) => Integer -> Value a
int n = LLVM.value (constFromInteger n)

-- | A rational literal as a constant value
-- (replicated into vectors via 'constFromRational').
ratio :: (RationalConstant a) => Rational -> Value a
ratio q = LLVM.value (constFromRational q)



-- | Turn a Haskell scalar into a constant LLVM value of type @v@,
-- using 'replicateConst' (for vectors presumably one copy per element).
replicateOf ::
   (IsConst (Scalar v), Replicate v) =>
   Scalar v -> Value v
replicateOf x = LLVM.value (replicateConst (LLVM.constOf x))


-- | Order- and sign-related operations on arithmetic scalars and vectors.
class (LLVM.IsArithmetic a) => Real a where
   -- | Smaller and larger of two values, respectively.
   min, max :: Value a -> Value a -> CodeGenFunction r (Value a)
   -- | Absolute value.
   abs :: Value a -> CodeGenFunction r (Value a)
   -- | Sign of the value, presumably encoded as -1, 0 or 1.
   signum :: Value a -> CodeGenFunction r (Value a)


-- Floating-point types: min, max and abs map to LLVM intrinsics.
instance Real Float where
   abs = Intrinsic.abs
   min = Intrinsic.min
   max = Intrinsic.max
   signum = A.signum

instance Real Double where
   abs = Intrinsic.abs
   min = Intrinsic.min
   max = Intrinsic.max
   signum = A.signum

-- For FP128 the constants -1 and 1 passed to 'A.signumGen'
-- are obtained by converting 8-bit integer constants.
instance Real FP128 where
   abs = Intrinsic.abs
   min = Intrinsic.min
   max = Intrinsic.max
   signum x = do
      negOne <- LLVM.inttofp (LLVM.valueOf (-1 :: Int8))
      posOne <- LLVM.inttofp (LLVM.valueOf (1 :: Int8))
      A.signumGen negOne posOne x


-- Fixed-width integers use the generic arithmetic helpers.
-- For unsigned types the absolute value is the identity.
instance Real Int8 where
   min = A.min
   max = A.max
   signum = A.signum
   abs = A.abs

instance Real Int16 where
   min = A.min
   max = A.max
   signum = A.signum
   abs = A.abs

instance Real Int32 where
   min = A.min
   max = A.max
   signum = A.signum
   abs = A.abs

instance Real Int64 where
   min = A.min
   max = A.max
   signum = A.signum
   abs = A.abs

instance Real Word8 where
   min = A.min
   max = A.max
   signum = A.signum
   abs = return

instance Real Word16 where
   min = A.min
   max = A.max
   signum = A.signum
   abs = return

instance Real Word32 where
   min = A.min
   max = A.max
   signum = A.signum
   abs = return

instance Real Word64 where
   min = A.min
   max = A.max
   signum = A.signum
   abs = return

-- Arbitrary-width integers: 'signum' chooses between explicit constants.
-- For unsigned 'WordN' the negative branch is presumably never taken,
-- so 'LLVM.undef' serves as a placeholder for -1.
instance (TypeNum.Positive n) => Real (IntN n) where
   min = A.min
   max = A.max
   abs = A.abs
   signum = A.signumGen (LLVM.valueOf (IntN (-1))) (LLVM.valueOf (IntN 1))
instance (TypeNum.Positive n) => Real (WordN n) where
   min = A.min
   max = A.max
   abs = return
   signum = A.signumGen (LLVM.value LLVM.undef) (LLVM.valueOf (WordN 1))

-- Vectors delegate to the implementations in "LLVM.Extra.Vector".
instance (TypeNum.Positive n, Vector.Real a) => Real (Vector n a) where
   signum = Vector.signum
   abs = Vector.abs
   max = Vector.max
   min = Vector.min


-- | Integer types with saturating addition and subtraction,
-- i.e. results are clamped to the type's range instead of wrapping.
class (IsInteger a) => Saturated a where
   -- | Saturating addition.
   addSat :: Value a -> Value a -> CodeGenFunction r (Value a)
   -- | Saturating subtraction.
   subSat :: Value a -> Value a -> CodeGenFunction r (Value a)

-- All instances share the same signedness-based dispatch;
-- the proxy only fixes the type that 'LLVM.isSigned' inspects.
instance Saturated Int8 where
   addSat = addSatProxy LP.Proxy
   subSat = subSatProxy LP.Proxy
instance Saturated Int16 where
   addSat = addSatProxy LP.Proxy
   subSat = subSatProxy LP.Proxy
instance Saturated Int32 where
   addSat = addSatProxy LP.Proxy
   subSat = subSatProxy LP.Proxy
instance Saturated Int64 where
   addSat = addSatProxy LP.Proxy
   subSat = subSatProxy LP.Proxy
instance Saturated Word8 where
   addSat = addSatProxy LP.Proxy
   subSat = subSatProxy LP.Proxy
instance Saturated Word16 where
   addSat = addSatProxy LP.Proxy
   subSat = subSatProxy LP.Proxy
instance Saturated Word32 where
   addSat = addSatProxy LP.Proxy
   subSat = subSatProxy LP.Proxy
instance Saturated Word64 where
   addSat = addSatProxy LP.Proxy
   subSat = subSatProxy LP.Proxy
instance (TypeNum.Positive d) => Saturated (IntN d) where
   addSat = addSatProxy LP.Proxy
   subSat = subSatProxy LP.Proxy
instance (TypeNum.Positive d) => Saturated (WordN d) where
   addSat = addSatProxy LP.Proxy
   subSat = subSatProxy LP.Proxy
instance
   (TypeNum.Positive n, LLVM.IsPrimitive a,
    Saturated a, Bounded a, CmpRet a, IsConst a) =>
      Saturated (Vector n a) where
   addSat = addSatProxy LP.Proxy
   subSat = subSatProxy LP.Proxy

-- | Pick a saturating operation according to the signedness of @v@.
-- If "LLVM.Util.Intrinsic" offers the corresponding intrinsic
-- (a 'Just' value), it is used; otherwise the implementation from
-- "LLVM.Extra.ScalarOrVectorPrivate" serves as fallback.
addSatProxy, subSatProxy ::
   (IsInteger v, CmpRet v, Replicate v, ShapeOf v ~ shape,
    LLVM.ShapedType shape Bool ~ bv, ShapeOf bv ~ shape, CmpRet bv,
    Scalar v ~ a, IsConst a, Bounded a) =>
   LP.Proxy v -> Value v -> Value v -> CodeGenFunction r (Value v)
addSatProxy proxy
   | LLVM.isSigned proxy = fromMaybe Priv.saddSat Intrinsic.maybeSAddSat
   | otherwise           = fromMaybe Priv.uaddSat Intrinsic.maybeUAddSat
subSatProxy proxy
   | LLVM.isSigned proxy = fromMaybe Priv.ssubSat Intrinsic.maybeSSubSat
   | otherwise           = fromMaybe Priv.usubSat Intrinsic.maybeUSubSat



-- | Multiplication of a value by a factor of its scalar type.
class
   (LLVM.IsArithmetic (Scalar v), LLVM.IsArithmetic v) =>
      PseudoModule v where
   -- | Multiply the second operand by the scalar factor.
   scale ::
      (a ~ Scalar v) =>
      Value a -> Value v -> CodeGenFunction r (Value v)
   -- | Counterpart of 'scale' operating on constants.
   scaleConst ::
      (a ~ Scalar v) =>
      ConstValue a -> ConstValue v -> CodeGenFunction r (ConstValue v)

-- Scalar types: scaling is plain multiplication.
instance PseudoModule Word8 where
   scale = LLVM.mul
   scaleConst = LLVM.mul
instance PseudoModule Word16 where
   scale = LLVM.mul
   scaleConst = LLVM.mul
instance PseudoModule Word32 where
   scale = LLVM.mul
   scaleConst = LLVM.mul
instance PseudoModule Word64 where
   scale = LLVM.mul
   scaleConst = LLVM.mul
instance PseudoModule Int8 where
   scale = LLVM.mul
   scaleConst = LLVM.mul
instance PseudoModule Int16 where
   scale = LLVM.mul
   scaleConst = LLVM.mul
instance PseudoModule Int32 where
   scale = LLVM.mul
   scaleConst = LLVM.mul
instance PseudoModule Int64 where
   scale = LLVM.mul
   scaleConst = LLVM.mul
instance PseudoModule Float where
   scale = LLVM.mul
   scaleConst = LLVM.mul
instance PseudoModule Double where
   scale = LLVM.mul
   scaleConst = LLVM.mul
-- Vectors: replicate the scalar factor, then multiply with the vector.
instance (LLVM.IsArithmetic a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
         PseudoModule (Vector n a) where
   scale a v = do
      factor <- replicate a
      A.mul factor v
   scaleConst a v =
      let factor = replicateConst a `asTypeOf` v
      in  LLVM.mul factor v



-- | Types whose LLVM constants can be built from an 'Integer' literal.
-- Fixed-width Haskell types go through 'fromInteger',
-- so out-of-range literals wrap around for the integer types.
class (LLVM.IsConst a) => IntegerConstant a where
   constFromInteger :: Integer -> ConstValue a

instance IntegerConstant Word8  where constFromInteger n = constOf (fromInteger n)
instance IntegerConstant Word16 where constFromInteger n = constOf (fromInteger n)
instance IntegerConstant Word32 where constFromInteger n = constOf (fromInteger n)
instance IntegerConstant Word64 where constFromInteger n = constOf (fromInteger n)
instance IntegerConstant Int8   where constFromInteger n = constOf (fromInteger n)
instance IntegerConstant Int16  where constFromInteger n = constOf (fromInteger n)
instance IntegerConstant Int32  where constFromInteger n = constOf (fromInteger n)
instance IntegerConstant Int64  where constFromInteger n = constOf (fromInteger n)
instance IntegerConstant Float  where constFromInteger n = constOf (fromInteger n)
instance IntegerConstant Double where constFromInteger n = constOf (fromInteger n)
instance (TypeNum.Positive n) => IntegerConstant (WordN n) where
   constFromInteger k = constOf (WordN k)
instance (TypeNum.Positive n) => IntegerConstant (IntN n)  where
   constFromInteger k = constOf (IntN k)
-- Vectors: build the element constant and replicate it.
instance (IntegerConstant a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
         IntegerConstant (Vector n a) where
   constFromInteger k = replicateConst (constFromInteger k)


-- | Types whose LLVM constants can be built from a 'Rational' literal.
class (IntegerConstant a) => RationalConstant a where
   constFromRational :: Rational -> ConstValue a

instance RationalConstant Float  where constFromRational q = constOf (fromRational q)
instance RationalConstant Double where constFromRational q = constOf (fromRational q)
-- Vectors: build the element constant and replicate it.
instance (RationalConstant a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
         RationalConstant (Vector n a) where
   constFromRational q = replicateConst (constFromRational q)


-- | Types that provide the constant pi.
class (RationalConstant a) => TranscendentalConstant a where
   constPi :: ConstValue a

-- Scalars take Haskell's 'pi' at the respective precision.
instance TranscendentalConstant Float  where constPi = constOf pi
instance TranscendentalConstant Double where constPi = constOf pi
-- Vectors replicate the element constant.
instance (TranscendentalConstant a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
         TranscendentalConstant (Vector n a) where
   constPi = replicateConst constPi