{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{- |
Support for unified handling of scalars and vectors.

The rounding and fraction functions only work
for floating point values with maximum magnitude of @maxBound :: Int32@.
This way we save the expensive handling of cases that are possibly rare.
-}
module LLVM.Extra.ScalarOrVector (
   Fraction (truncate, fraction),


   Replicate (replicate, replicateConst),
   Real (min, max, abs, signum),
   PseudoModule (scale, scaleConst),
   ) where

import LLVM.Extra.Vector (Element, Size, )

import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Extension.X86 as X86
import qualified LLVM.Extra.Extension as Ext

import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.ArithmeticPrivate as A

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal (D1, )

import qualified LLVM.Core as LLVM
import LLVM.Core
   (Value, ConstValue, valueOf, constOf,
    CmpRet, CmpResult, NumberOfElements,
    Vector, WordN(WordN), IntN(IntN), FP128,
    IsConst, IsInteger, IsFloating,
    CodeGenFunction, )

import Control.Monad.HT ((<=<), )

import qualified Data.NonEmpty as NonEmpty
import Data.Word (Word8, Word16, Word32, Word64, )
import Data.Int  (Int8,  Int16,  Int32,  Int64, )

import Prelude hiding (Real, replicate, min, max, abs, truncate, floor, round, )

{- |
Floating point types (scalars and vectors)
that can be split into an integral and a fractional part.
-}
class (Real a, IsFloating a) => Fraction a where
   -- | Integral part of the argument.
   -- The scalar instances in this module round towards zero
   -- by converting to 'Int32' and back.
   truncate :: Value a -> CodeGenFunction r (Value a)
   -- | Fractional part of the argument.
   -- The generic implementation 'fractionGen' shifts negative
   -- remainders up by one, so that the result is non-negative.
   fraction :: Value a -> CodeGenFunction r (Value a)

{- |
Scalar 'Float' instance.

Both methods are assembled with 'mapAuto':
its first argument is the portable scalar implementation,
its second argument is an X86 specific alternative
that operates on the lowest element of a vector register.
Which of the two is emitted is decided by "LLVM.Extra.Extension".
-}
instance Fraction Float where
   truncate =
      mapAuto
         -- portable path: round trip through Int32
         (LLVM.inttofp . flip asTypeOf (undefined :: Value Int32) <=< LLVM.fptoint)
         -- roundss with rounding mode 3
         -- (presumably "round towards zero", see the Intel manual)
         (Ext.with X86.roundss $ \round x -> round x (valueOf 3))
   fraction =
      mapAuto
         -- without roundss: generic code, optionally improved by cmpss
         (\x ->
            Ext.run
               (fractionGen x)
               (Ext.with X86.cmpss $ \cmp ->
                  fractionLogical (\modus -> curry (runScalar (uncurry (cmp modus)))) x))
         -- roundss with rounding mode 1 (presumably "round down"):
         -- subtract the rounded value from the argument
         (Ext.with X86.roundss $ \round x ->
            A.sub x =<< round x (valueOf 1))

{- |
Scalar 'Double' instance.

Both methods are assembled with 'mapAuto':
its first argument is the portable scalar implementation,
its second argument is an X86 specific alternative
that operates on the lowest element of a vector register.
Which of the two is emitted is decided by "LLVM.Extra.Extension".
-}
instance Fraction Double where
   truncate =
      mapAuto
         -- X86 only converts Double to Int32, it cannot target Int64
         (LLVM.inttofp . flip asTypeOf (undefined :: Value Int32) <=< LLVM.fptoint)
         -- roundsd with rounding mode 3
         -- (presumably "round towards zero", see the Intel manual)
         (Ext.with X86.roundsd $ \round x -> round x (valueOf 3))
   fraction =
      mapAuto
         -- without roundsd: generic code, optionally improved by cmpsd
         (\x ->
            Ext.run
               (fractionGen x)
               (Ext.with X86.cmpsd $ \cmp ->
                  fractionLogical (\modus -> curry (runScalar (uncurry (cmp modus)))) x))
{-
For Doubles it would be more efficient to convert the lower 32 bit
instead of the lower 64 bit,
since x86 supports only conversion from 32 bit natively.
      (Ext.with X86.cmpsd $ \cmp -> fractionLogical
         (\x y -> cmp x y >>= LLVM.bitcast )
-}
         -- roundsd with rounding mode 1 (presumably "round down"):
         -- subtract the rounded value from the argument
         (Ext.with X86.roundsd $ \round x ->
            A.sub x =<< round x (valueOf 1))

-- | Vectors forward both operations to "LLVM.Extra.Vector".
instance
   (TypeNum.Positive n, Vector.Real a, IsFloating a, IsConst a) =>
      Fraction (Vector n a) where
   truncate v = Vector.truncate v
   fraction v = Vector.fraction v

{- |
Subtract the truncated value from the argument.
The fraction has the same sign as the argument.
This is not particularly useful but fast on IEEE implementations.
-}
signedFraction ::
   (Fraction a) =>
   Value a -> CodeGenFunction r (Value a)
signedFraction x = do
   integral <- truncate x
   A.sub x integral

{- |
Generic fraction:
take the 'signedFraction' and, if it is not greater than or equal to zero,
add one to it.
-}
fractionGen ::
   (IntegerConstant v, Fraction v, CmpRet v) =>
   Value v -> CodeGenFunction r (Value v)
fractionGen x = do
   remainder <- signedFraction x
   nonNegative <- A.fcmp LLVM.FPOGE remainder zero
   shifted <- A.add remainder (LLVM.value (constFromInteger 1))
   LLVM.select nonNegative remainder shifted

{- |
Variant of the fraction computation without @select@:
the integer result of the comparison @signedFraction x < 0@
is converted to floating point and subtracted from the signed fraction.
NOTE(review): this presumably relies on the comparison
yielding -1 for \"true\" - confirm against the supplied @cmp@.
-}
fractionLogical ::
   (Fraction a, LLVM.IsScalarOrVector a, NumberOfElements a ~ D1,
    IsInteger b, LLVM.IsScalarOrVector b, NumberOfElements b ~ D1) =>
   (LLVM.FPPredicate ->
    Value a -> Value a -> CodeGenFunction r (Value b)) ->
   Value a -> CodeGenFunction r (Value a)
fractionLogical cmp x = do
   remainder <- signedFraction x
   negativeMask <- cmp LLVM.FPOLT remainder zero
   correction <- LLVM.inttofp negativeMask
   A.sub remainder correction

{- |
Add an increment to a phase and reduce the sum with 'fraction'.
The increment (first operand) may be negative,
the phase must always be non-negative.
-}
addToPhase ::
   (Fraction a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
addToPhase d p = do
   s <- A.add d p
   fraction s

{- |
Add an increment to a phase and reduce the sum with 'signedFraction'.
Both increment and phase must be non-negative.
-}
incPhase ::
   (Fraction a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
incPhase d p = do
   s <- A.add d p
   signedFraction s

-- | Convert a floating point value to an integer using 'LLVM.fptoint'.
truncateToInt ::
   (IsFloating a, IsInteger i,
    NumberOfElements a ~ NumberOfElements i) =>
   Value a -> CodeGenFunction r (Value i)
truncateToInt x = LLVM.fptoint x

{- |
Rounds to the next integer.
For numbers of the form @n+0.5@,
we choose one of the neighbouring integers
such that the overall implementation is most efficient.

The implementation adds @0.5@ to arguments greater than zero
and @-0.5@ to all others and then applies 'truncateToInt'.
-}
roundToIntFast ::
   (IsFloating a, RationalConstant a, CmpRet a,
    IsInteger i, IntegerConstant i, CmpRet i,
    CmpResult a ~ CmpResult i,
    NumberOfElements a ~ NumberOfElements i) =>
   Value a -> CodeGenFunction r (Value i)
roundToIntFast x = do
   pos <- A.cmp LLVM.CmpGT x zero
   offset <- LLVM.select pos (ratio 0.5) (ratio (-0.5))
   truncateToInt =<< A.add x offset

{- |
Round down to an integer:
truncate first and subtract one
if the argument is less than the truncated value.
-}
floorToInt ::
   (IsFloating a, CmpRet a,
    IsInteger i, IntegerConstant i, CmpRet i,
    CmpResult a ~ CmpResult i,
    NumberOfElements a ~ NumberOfElements i) =>
   Value a -> CodeGenFunction r (Value i)
floorToInt x = do
   truncated <- truncateToInt x
   truncatedFloat <- LLVM.inttofp truncated
   isBelow <- A.cmp LLVM.CmpLT x truncatedFloat
   correction <- LLVM.select isBelow (int 1) (int 0)
   A.sub truncated correction

{- |
Split a number into the integer computed by 'floorToInt'
and the difference between the argument and that integer.
-}
splitFractionToInt ::
   (IsFloating a, CmpRet a,
    IsInteger i, IntegerConstant i, CmpRet i,
    CmpResult a ~ CmpResult i,
    NumberOfElements a ~ NumberOfElements i) =>
   Value a -> CodeGenFunction r (Value i, Value a)
splitFractionToInt x = do
   whole <- floorToInt x
   wholeFloat <- LLVM.inttofp whole
   rest <- A.sub x wholeFloat
   return (whole, rest)

{- |
Round up to an integer:
truncate first and add one
if the argument is greater than the truncated value.
-}
ceilingToInt ::
   (IsFloating a, CmpRet a,
    IsInteger i, IntegerConstant i, CmpRet i,
    CmpResult a ~ CmpResult i,
    NumberOfElements a ~ NumberOfElements i) =>
   Value a -> CodeGenFunction r (Value i)
ceilingToInt x = do
   truncated <- truncateToInt x
   truncatedFloat <- LLVM.inttofp truncated
   isAbove <- A.cmp LLVM.CmpGT x truncatedFloat
   correction <- LLVM.select isAbove (int 1) (int 0)
   A.add truncated correction

-- | The constant 'LLVM.zero' as a 'Value' of any LLVM type.
zero :: (LLVM.IsType a) => Value a
zero = LLVM.value LLVM.zero

-- | An integer literal as a constant 'Value', built with 'constFromInteger'.
int :: (IntegerConstant a) => Integer -> Value a
int n = LLVM.value (constFromInteger n)

-- | A rational literal as a constant 'Value', built with 'constFromRational'.
ratio :: (RationalConstant a) => Rational -> Value a
ratio q = LLVM.value (constFromRational q)

{- |
Maps a scalar or vector type to the type of its elements:
every primitive type listed below is mapped to itself,
whereas @Vector n a@ is mapped to @a@.
-}
type family Scalar vector :: *

-- floating point types
type instance Scalar Float  = Float
type instance Scalar Double = Double
type instance Scalar FP128  = FP128
-- boolean and integer types
type instance Scalar Bool   = Bool
type instance Scalar Int8   = Int8
type instance Scalar Int16  = Int16
type instance Scalar Int32  = Int32
type instance Scalar Int64  = Int64
type instance Scalar Word8  = Word8
type instance Scalar Word16 = Word16
type instance Scalar Word32 = Word32
type instance Scalar Word64 = Word64
-- a vector has the scalar type of its elements
type instance Scalar (Vector n a) = a

{- |
Types that can be built from a single value of their 'Scalar' type.
For the scalar instances in this module this is the identity.
-}
class Replicate vector where
   -- | an alternative is using the 'Vector.Constant' vector type
   replicate :: Value (Scalar vector) -> CodeGenFunction r (Value vector)
   -- | Counterpart of 'replicate' for constants; needs no code generation.
   replicateConst :: ConstValue (Scalar vector) -> ConstValue vector

-- For scalars the value is passed through unchanged.
instance Replicate Float  where { replicate = return; replicateConst = id }
instance Replicate Double where { replicate = return; replicateConst = id }
instance Replicate FP128  where { replicate = return; replicateConst = id }
instance Replicate Bool   where { replicate = return; replicateConst = id }
instance Replicate Int8   where { replicate = return; replicateConst = id }
instance Replicate Int16  where { replicate = return; replicateConst = id }
instance Replicate Int32  where { replicate = return; replicateConst = id }
instance Replicate Int64  where { replicate = return; replicateConst = id }
instance Replicate Word8  where { replicate = return; replicateConst = id }
instance Replicate Word16 where { replicate = return; replicateConst = id }
instance Replicate Word32 where { replicate = return; replicateConst = id }
instance Replicate Word64 where { replicate = return; replicateConst = id }
{- |
A vector is filled by putting the scalar into a one-element vector
(see 'singleton') and shuffling that vector with an all-zero index mask.
The constant variant cycles through a one-element list.
-}
instance (TypeNum.Positive n, LLVM.IsPrimitive a) => Replicate (Vector n a) where
{- crashes LLVM-2.5, seems to be fixed in LLVM-2.6 -}
   replicate x = do
      v <- singleton x
      LLVM.shufflevector v (LLVM.value LLVM.undef) LLVM.zero
{- crashes LLVM-2.5
   replicate x = do
      v <- LLVM.insertelement (LLVM.value LLVM.undef) x (valueOf 1)
      LLVM.shufflevector v (LLVM.value LLVM.undef) (constVector $ repeat $ LLVM.constOf 1)
-}
{- the (repeat zero) is also converted to 'zeroinitializer' and crashes LLVM compiler

         (constVector $ repeat LLVM.zero)
-}
{- alternative implementation
NOTE(review): kept as a comment because only one definition of 'replicate'
can be active in this instance - confirm which one is intended.
   replicate = Vector.replicate
-}
   replicateConst x = LLVM.constCyclicVector $ NonEmpty.Cons x []

-- | Put a scalar at index 0 of an otherwise undefined one-element vector.
singleton ::
   (LLVM.IsPrimitive a) =>
   Value a -> CodeGenFunction r (Value (Vector D1 a))
singleton x =
   let blank = LLVM.value LLVM.undef
       firstIndex = valueOf 0
   in  LLVM.insertelement blank x firstIndex

{- |
Turn a Haskell value of the scalar type into a constant LLVM 'Value'
of the scalar or vector type, using 'replicateConst'.
-}
replicateOf ::
   (IsConst (Scalar v), Replicate v) =>
   Scalar v -> Value v
replicateOf x =
   LLVM.value (replicateConst (LLVM.constOf x))

{- |
Minimum, maximum, absolute value and sign
for LLVM arithmetic types, both scalars and vectors.
-}
class (LLVM.IsArithmetic a) => Real a where
   -- | Minimum of two values.
   min :: Value a -> Value a -> CodeGenFunction r (Value a)
   -- | Maximum of two values.
   max :: Value a -> Value a -> CodeGenFunction r (Value a)
   -- | Absolute value; the unsigned instances in this module return the argument unchanged.
   abs :: Value a -> CodeGenFunction r (Value a)
   -- | Sign of the value.
   signum :: Value a -> CodeGenFunction r (Value a)

{- |
'min', 'max' and 'abs' pair the generic implementation
with an X86 instruction via 'zipAutoWith' and 'mapAuto', respectively.
-}
instance Real Float where
   min x y = zipAutoWith A.min X86.minss x y
   max x y = zipAutoWith A.max X86.maxss x y
   abs x = mapAuto A.abs X86.absss x
   -- alternatives for abs:
   -- abs x = max x =<< LLVM.neg x
   -- abs x = A.abs
   signum x = A.signum x

{- |
'min', 'max' and 'abs' pair the generic implementation
with an X86 instruction via 'zipAutoWith' and 'mapAuto', respectively.
-}
instance Real Double where
   min x y = zipAutoWith A.min X86.minsd x y
   max x y = zipAutoWith A.max X86.maxsd x y
   abs x = mapAuto A.abs X86.abssd x
   signum x = A.signum x

{- |
Uses the generic implementations throughout.
The constants -1 and 1 for 'signum' are obtained
by converting 'Int8' values with 'LLVM.inttofp'.
-}
instance Real FP128 where
   min x y = A.min x y
   max x y = A.max x y
   abs x = A.abs x
   signum x = do
      negOne <- LLVM.inttofp (LLVM.valueOf (-1 :: Int8))
      posOne <- LLVM.inttofp (LLVM.valueOf (1 :: Int8))
      A.signumGen negOne posOne x

infixl 1 `mapAuto`

{- |
There are functions that are intended for processing scalars
but have formally vector input and output.
This function breaks a vector function down to a scalar function
by accessing the lowest vector element:
the argument is inserted at index 0 of an undefined vector,
the vector function is applied
and element 0 of its result is extracted.
-}
runScalar ::
   (Vector.C v, Vector.C w, Size v ~ Size w) =>
   (v -> CodeGenFunction r w) ->
   (Element v -> CodeGenFunction r (Element w))
runScalar op a = do
   input <- Vector.insert (valueOf 0) a Class.undefTuple
   output <- op input
   Vector.extract (valueOf 0) output

{- |
Combine a scalar implementation with an extension
that provides the same operation on vectors.
The scalar function applied to the argument is handed to 'Ext.run'
together with the extension, which is adapted to scalars by 'runScalar'.
-}
mapAuto ::
   (Vector.C v, Vector.C w, Size v ~ Size w) =>
   (Element v -> CodeGenFunction r (Element w)) ->
   Ext.T (v -> CodeGenFunction r w) ->
   (Element v -> CodeGenFunction r (Element w))
mapAuto scalarOp extension x =
   Ext.run
      (scalarOp x)
      (Ext.with extension (\vectorOp -> runScalar vectorOp x))

{- |
Two-argument variant of 'mapAuto':
both implementations are uncurried, combined with 'mapAuto'
and applied to the pair of arguments.
-}
zipAutoWith ::
   (Vector.C u, Vector.C v, Vector.C w,
    Size u ~ Size v, Size v ~ Size w) =>
   (Element u -> Element v -> CodeGenFunction r (Element w)) ->
   Ext.T (u -> v -> CodeGenFunction r w) ->
   (Element u -> Element v -> CodeGenFunction r (Element w))
zipAutoWith scalarOp extension x y =
   mapAuto (uncurry scalarOp) (fmap uncurry extension) (x, y)

-- Signed integers use the generic implementations.
instance Real Int8   where { min = A.min; max = A.max; abs = A.abs; signum = A.signum }
instance Real Int16  where { min = A.min; max = A.max; abs = A.abs; signum = A.signum }
instance Real Int32  where { min = A.min; max = A.max; abs = A.abs; signum = A.signum }
instance Real Int64  where { min = A.min; max = A.max; abs = A.abs; signum = A.signum }
-- For unsigned integers 'abs' returns its argument unchanged.
instance Real Word8  where { min = A.min; max = A.max; abs = return; signum = A.signum }
instance Real Word16 where { min = A.min; max = A.max; abs = return; signum = A.signum }
instance Real Word32 where { min = A.min; max = A.max; abs = return; signum = A.signum }
instance Real Word64 where { min = A.min; max = A.max; abs = return; signum = A.signum }

-- | Signed integers of arbitrary width.
instance (TypeNum.Positive n) => Real (IntN n) where
   min = A.min
   max = A.max
   abs = A.abs
   signum =
      A.signumGen (LLVM.valueOf (IntN (-1))) (LLVM.valueOf (IntN 1))

-- | Unsigned integers of arbitrary width; 'abs' returns its argument unchanged.
instance (TypeNum.Positive n) => Real (WordN n) where
   min = A.min
   max = A.max
   abs = return
   -- The first argument of signumGen is left undefined.
   -- NOTE(review): presumably it is the result for negative numbers,
   -- which cannot occur for unsigned values - confirm in A.signumGen.
   signum =
      A.signumGen (LLVM.value LLVM.undef) (LLVM.valueOf (WordN 1))

-- | Vectors forward all operations to "LLVM.Extra.Vector".
instance (TypeNum.Positive n, Vector.Real a) => Real (Vector n a) where
   min x y = Vector.min x y
   max x y = Vector.max x y
   abs x = Vector.abs x
   signum x = Vector.signum x

   -- | Types whose values can be multiplied by a value of their 'Scalar' type.
   -- 'scale' works on 'Value's, 'scaleConst' on 'ConstValue's;
   -- in both cases the scalar factor is the first argument.
class
   (LLVM.IsArithmetic (Scalar v), LLVM.IsArithmetic v) =>
      PseudoModule v where
   scale :: (a ~ Scalar v) => Value a -> Value v -> CodeGenFunction r (Value v)
   scaleConst :: (a ~ Scalar v) => ConstValue a -> ConstValue v -> CodeGenFunction r (ConstValue v)

-- For scalars scaling is plain multiplication.
instance PseudoModule Word8  where { scale = LLVM.mul; scaleConst = LLVM.mul }
instance PseudoModule Word16 where { scale = LLVM.mul; scaleConst = LLVM.mul }
instance PseudoModule Word32 where { scale = LLVM.mul; scaleConst = LLVM.mul }
instance PseudoModule Word64 where { scale = LLVM.mul; scaleConst = LLVM.mul }
instance PseudoModule Int8   where { scale = LLVM.mul; scaleConst = LLVM.mul }
instance PseudoModule Int16  where { scale = LLVM.mul; scaleConst = LLVM.mul }
instance PseudoModule Int32  where { scale = LLVM.mul; scaleConst = LLVM.mul }
instance PseudoModule Int64  where { scale = LLVM.mul; scaleConst = LLVM.mul }
instance PseudoModule Float  where { scale = LLVM.mul; scaleConst = LLVM.mul }
instance PseudoModule Double where { scale = LLVM.mul; scaleConst = LLVM.mul }
-- | For vectors the scalar factor is replicated to a vector
-- and then multiplied with the vector operand.
instance (LLVM.IsArithmetic a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
         PseudoModule (Vector n a) where
   scale a v = do
      factor <- replicate a
      A.mul factor v
   scaleConst a v =
      let factor = replicateConst a `asTypeOf` v
      in  LLVM.mul factor v

-- | LLVM types for which a constant can be made from an 'Integer'.
class (LLVM.IsConst a) => IntegerConstant a where
   -- | Constant that corresponds to the given integer.
   constFromInteger :: Integer -> ConstValue a

-- Primitive number types: convert with 'fromInteger', then lift with 'constOf'.
instance IntegerConstant Word8  where constFromInteger n = constOf (fromInteger n)
instance IntegerConstant Word16 where constFromInteger n = constOf (fromInteger n)
instance IntegerConstant Word32 where constFromInteger n = constOf (fromInteger n)
instance IntegerConstant Word64 where constFromInteger n = constOf (fromInteger n)
instance IntegerConstant Int8   where constFromInteger n = constOf (fromInteger n)
instance IntegerConstant Int16  where constFromInteger n = constOf (fromInteger n)
instance IntegerConstant Int32  where constFromInteger n = constOf (fromInteger n)
instance IntegerConstant Int64  where constFromInteger n = constOf (fromInteger n)
instance IntegerConstant Float  where constFromInteger n = constOf (fromInteger n)
instance IntegerConstant Double where constFromInteger n = constOf (fromInteger n)
-- Arbitrary width integers wrap the 'Integer' directly.
instance (TypeNum.Positive n) => IntegerConstant (WordN n) where
   constFromInteger n = constOf (WordN n)
instance (TypeNum.Positive n) => IntegerConstant (IntN n) where
   constFromInteger n = constOf (IntN n)
-- | A vector constant is the replicated constant of the element type.
instance (IntegerConstant a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
         IntegerConstant (Vector n a) where
   constFromInteger n = replicateConst (constFromInteger n)

-- | LLVM types for which a constant can be made from a 'Rational'.
class (IntegerConstant a) => RationalConstant a where
   -- | Constant that corresponds to the given rational number.
   constFromRational :: Rational -> ConstValue a

-- Floating point types: convert with 'fromRational', then lift with 'constOf'.
instance RationalConstant Float  where constFromRational q = constOf (fromRational q)
instance RationalConstant Double where constFromRational q = constOf (fromRational q)
-- | A vector constant is the replicated constant of the element type.
instance (RationalConstant a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
         RationalConstant (Vector n a) where
   constFromRational q = replicateConst (constFromRational q)

-- | LLVM types that provide the constant pi.
class (RationalConstant a) => TranscendentalConstant a where
   -- | The constant pi; the scalar instances take it from Haskell's 'pi'.
   constPi :: ConstValue a

-- Floating point types lift Haskell's 'pi' with 'constOf'.
instance TranscendentalConstant Float  where { constPi = constOf pi }
instance TranscendentalConstant Double where { constPi = constOf pi }
-- | A vector constant is the replicated constant of the element type.
instance
   (TranscendentalConstant a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
      TranscendentalConstant (Vector n a) where
   constPi =
      replicateConst constPi