{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{- |
Support for unified handling of scalars and vectors.

Attention:
The rounding and fraction functions only work
for floating point values with maximum magnitude of @maxBound :: Int32@.
This way we avoid the expensive handling of cases that presumably occur rarely.
-}
module LLVM.Extra.ScalarOrVector (
   Fraction (truncate, fraction),
   signedFraction,
   addToPhase,
   incPhase,

   truncateToInt,
   floorToInt,
   ceilingToInt,
   roundToIntFast,
   splitFractionToInt,

   Scalar,
   Replicate (replicate, replicateConst),
   replicateOf,
   Real (min, max, abs, signum),
   PseudoModule (scale, scaleConst),
   IntegerConstant(constFromInteger),
   RationalConstant(constFromRational),
   TranscendentalConstant(constPi),
   ) where

import LLVM.Extra.Vector (Element, Size, )

import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Extension.X86 as X86
import qualified LLVM.Extra.Extension as Ext

import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.ArithmeticPrivate as A

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal (D1, )

import qualified LLVM.Core as LLVM
import LLVM.Core
   (Value, ConstValue, valueOf, constOf,
    CmpRet, CmpResult, NumberOfElements,
    Vector, FP128,
    IsConst, IsInteger, IsFloating,
    CodeGenFunction, )

import Control.Monad.HT ((<=<), )

import qualified Data.NonEmpty as NonEmpty
import Data.Word (Word8, Word16, Word32, Word64, )
import Data.Int  (Int8,  Int16,  Int32,  Int64, )

import Prelude hiding (Real, replicate, min, max, abs, truncate, floor, round, )



{- |
Floating point types (scalars and vectors)
that can be split into an integral and a fractional part.
As stated in the module header, the methods are only meant for values
whose magnitude fits into 'Int32'.
-}
class (Real a, IsFloating a) => Fraction a where
   -- | The argument with the fractional part removed.
   -- The scalar instances in this module
   -- convert to 'Int32' and back or use an X86 rounding instruction.
   truncate :: Value a -> CodeGenFunction r (Value a)
   -- | The fractional part of the argument.
   -- The scalar instances in this module produce a non-negative result,
   -- in contrast to 'signedFraction'.
   fraction :: Value a -> CodeGenFunction r (Value a)

-- Scalar 'Float' instance.
-- Every method combines a generic implementation
-- with an X86 extension based one by means of 'mapAuto'.
instance Fraction Float where
   truncate =
      mapAuto
         -- generic variant: convert to Int32 and back to floating point
         (LLVM.inttofp . flip asTypeOf (undefined :: Value Int32) <=< LLVM.fptoint)
         -- X86 variant: roundss with rounding mode immediate 3
         -- (presumably "round toward zero" in the SSE4.1 encoding - TODO confirm)
         (Ext.with X86.roundss $ \round x -> round x (valueOf 3))
   fraction =
      -- generic variant: 'fractionGen' is handed to Ext.run as first argument
      -- together with a cmpss based 'fractionLogical'
      -- (presumably Ext.run uses its first argument
      -- when the extension is not available - TODO confirm)
      (\x ->
         fractionGen x
         `Ext.run`
         (Ext.with X86.cmpss $ \cmp ->
            fractionLogical (\modus -> curry (runScalar (uncurry (cmp modus)))) x))
      `mapAuto`
      -- X86 variant: subtract the result of roundss
      -- with rounding mode immediate 1 from the argument
      -- (presumably "round down", i.e. floor - TODO confirm)
      (Ext.with X86.roundss $ \round x ->
         A.sub x =<< round x (valueOf 1))

-- Scalar 'Double' instance.
-- It has the same structure as the 'Float' instance
-- but uses the double precision X86 primitives (roundsd, cmpsd).
instance Fraction Double where
   truncate =
      mapAuto
         -- generic variant: convert to Int32 and back to floating point
         -- X86 only converts Double to Int32, it cannot target Int64
         (LLVM.inttofp . flip asTypeOf (undefined :: Value Int32) <=< LLVM.fptoint)
         -- X86 variant: roundsd with rounding mode immediate 3
         -- (presumably "round toward zero" in the SSE4.1 encoding - TODO confirm)
         (Ext.with X86.roundsd $ \round x -> round x (valueOf 3))
   fraction =
      -- generic variant: 'fractionGen' is handed to Ext.run as first argument
      -- together with a cmpsd based 'fractionLogical'
      -- (presumably Ext.run uses its first argument
      -- when the extension is not available - TODO confirm)
      (\x ->
         fractionGen x
         `Ext.run`
         (Ext.with X86.cmpsd $ \cmp ->
            fractionLogical (\modus -> curry (runScalar (uncurry (cmp modus)))) x))
{-
For Doubles it would be more efficient to convert the lower 32 bit
instead of the lower 64 bit,
since x86 supports only conversion from 32 bit natively.
      (Ext.with X86.cmpsd $ \cmp -> fractionLogical
         (\x y -> cmp x y >>= LLVM.bitcast )
-}
      `mapAuto`
      -- X86 variant: subtract the result of roundsd
      -- with rounding mode immediate 1 from the argument
      -- (presumably "round down", i.e. floor - TODO confirm)
      (Ext.with X86.roundsd $ \round x ->
         A.sub x =<< round x (valueOf 1))

-- Vector instance: both methods are delegated to "LLVM.Extra.Vector".
instance (TypeNum.Positive n, Vector.Real a, IsFloating a, IsConst a) =>
      Fraction (Vector n a) where
   truncate v = Vector.truncate v
   fraction v = Vector.fraction v


{- |
Subtract the truncated value from the argument.
The resulting fraction has the same sign as the argument.
This is not particularly useful, but it is fast on IEEE implementations.
-}
signedFraction ::
   (Fraction a) =>
   Value a -> CodeGenFunction r (Value a)
signedFraction x = do
   whole <- truncate x
   A.sub x whole

{- |
Portable computation of the non-negative fraction:
Compute the signed fraction and add one to it if it is not
ordered-greater-or-equal to zero.
-}
fractionGen ::
   (IntegerConstant v, Fraction v, CmpRet v) =>
   Value v -> CodeGenFunction r (Value v)
fractionGen x = do
   frac <- signedFraction x
   nonNegative <- A.fcmp LLVM.FPOGE frac zero
   wrapped <- A.add frac (LLVM.value (constFromInteger 1))
   LLVM.select nonNegative frac wrapped

{- |
Computation of the fraction using a comparison
that delivers its result as an integer value.
The comparison result for @signedFraction x < 0@
is converted to floating point and subtracted from the signed fraction.
This presumably relies on the comparison yielding @-1@ (all bits set)
for "true" and @0@ for "false" - to be confirmed against the @cmp@ passed in.
-}
fractionLogical ::
   (Fraction a, LLVM.IsScalarOrVector a, NumberOfElements a ~ D1,
    IsInteger b, LLVM.IsScalarOrVector b, NumberOfElements b ~ D1) =>
   (LLVM.FPPredicate ->
    Value a -> Value a -> CodeGenFunction r (Value b)) ->
   Value a -> CodeGenFunction r (Value a)
fractionLogical cmp x = do
   frac <- signedFraction x
   mask <- cmp LLVM.FPOLT frac zero
   correction <- LLVM.inttofp mask
   A.sub frac correction

{- |
Add an increment to a phase and reduce the sum using 'fraction'.
The increment (first operand) may be negative,
the phase (second operand) must always be non-negative.
-}
addToPhase ::
   (Fraction a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
addToPhase d p = do
   advanced <- A.add d p
   fraction advanced

{- |
Add an increment to a phase and reduce the sum using 'signedFraction'.
Both the increment and the phase must be non-negative.
-}
incPhase ::
   (Fraction a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
incPhase d p = do
   advanced <- A.add d p
   signedFraction advanced



{- |
Convert a floating point value to an integer value
by means of 'LLVM.fptoint'.
-}
truncateToInt ::
   (IsFloating a, IsInteger i,
    NumberOfElements a ~ NumberOfElements i) =>
   Value a -> CodeGenFunction r (Value i)
truncateToInt x = LLVM.fptoint x

{- |
Rounds to the nearest integer.
For numbers of the form @n+0.5@,
we choose one of the neighbouring integers
such that the overall implementation is most efficient.
The implementation adds @0.5@ to positive arguments
and @-0.5@ to all other arguments and then applies 'truncateToInt'.
-}
roundToIntFast ::
   (IsFloating a, RationalConstant a, CmpRet a,
    IsInteger i, IntegerConstant i, CmpRet i,
    CmpResult a ~ CmpResult i,
    NumberOfElements a ~ NumberOfElements i) =>
   Value a -> CodeGenFunction r (Value i)
roundToIntFast x = do
   isPositive <- A.cmp LLVM.CmpGT x zero
   offset <- LLVM.select isPositive (ratio 0.5) (ratio (-0.5))
   shifted <- A.add x offset
   truncateToInt shifted

{- |
Largest integer not greater than the argument.
The argument is truncated with 'truncateToInt'
and the result is decremented
if the argument is less than the truncated value.
-}
floorToInt ::
   (IsFloating a, CmpRet a,
    IsInteger i, IntegerConstant i, CmpRet i,
    CmpResult a ~ CmpResult i,
    NumberOfElements a ~ NumberOfElements i) =>
   Value a -> CodeGenFunction r (Value i)
floorToInt x = do
   truncated <- truncateToInt x
   truncatedFloat <- LLVM.inttofp truncated
   isBelow <- A.cmp LLVM.CmpLT x truncatedFloat
   decrement <- LLVM.select isBelow (int 1) (int 0)
   A.sub truncated decrement

{- |
Split a number into its integral part according to 'floorToInt'
and the remaining fraction,
that is the argument minus the integral part.
-}
splitFractionToInt ::
   (IsFloating a, CmpRet a,
    IsInteger i, IntegerConstant i, CmpRet i,
    CmpResult a ~ CmpResult i,
    NumberOfElements a ~ NumberOfElements i) =>
   Value a -> CodeGenFunction r (Value i, Value a)
splitFractionToInt x = do
   whole <- floorToInt x
   wholeFloat <- LLVM.inttofp whole
   rest <- A.sub x wholeFloat
   return (whole, rest)

{- |
Smallest integer not less than the argument.
The argument is truncated with 'truncateToInt'
and the result is incremented
if the argument is greater than the truncated value.
-}
ceilingToInt ::
   (IsFloating a, CmpRet a,
    IsInteger i, IntegerConstant i, CmpRet i,
    CmpResult a ~ CmpResult i,
    NumberOfElements a ~ NumberOfElements i) =>
   Value a -> CodeGenFunction r (Value i)
ceilingToInt x = do
   truncated <- truncateToInt x
   truncatedFloat <- LLVM.inttofp truncated
   isAbove <- A.cmp LLVM.CmpGT x truncatedFloat
   increment <- LLVM.select isAbove (int 1) (int 0)
   A.add truncated increment


-- | The LLVM zero constant as a 'Value'.
zero :: (LLVM.IsType a) => Value a
zero = LLVM.value LLVM.zero

-- | A 'Value' for an integer literal, built with 'constFromInteger'.
int :: (IntegerConstant a) => Integer -> Value a
int n = LLVM.value (constFromInteger n)

-- | A 'Value' for a rational literal, built with 'constFromRational'.
ratio :: (RationalConstant a) => Rational -> Value a
ratio q = LLVM.value (constFromRational q)


{- |
Maps a vector type to the type of its elements
and maps each of the listed scalar types to itself.
-}
type family Scalar vector :: *

-- scalar types are their own scalar type
type instance Scalar Float  = Float
type instance Scalar Double = Double
type instance Scalar FP128  = FP128
type instance Scalar Bool   = Bool
type instance Scalar Int8   = Int8
type instance Scalar Int16  = Int16
type instance Scalar Int32  = Int32
type instance Scalar Int64  = Int64
type instance Scalar Word8  = Word8
type instance Scalar Word16 = Word16
type instance Scalar Word32 = Word32
type instance Scalar Word64 = Word64
-- the scalar type of a vector is its element type
type instance Scalar (Vector n a) = a



{- |
Types that can be constructed from a single value of their 'Scalar' type.
For the scalar instances in this module this is the identity,
for vectors the value is written to every element.
-}
class Replicate vector where
   -- | Turn a scalar value into a value of type @vector@.
   -- An alternative is using the 'Vector.Constant' vector type.
   replicate :: Value (Scalar vector) -> CodeGenFunction r (Value vector)
   -- | Counterpart of 'replicate' for constants.
   replicateConst :: ConstValue (Scalar vector) -> ConstValue vector

-- For scalar types replication is the identity.
instance Replicate Float  where { replicate = return; replicateConst = id }
instance Replicate Double where { replicate = return; replicateConst = id }
instance Replicate FP128  where { replicate = return; replicateConst = id }
instance Replicate Bool   where { replicate = return; replicateConst = id }
instance Replicate Int8   where { replicate = return; replicateConst = id }
instance Replicate Int16  where { replicate = return; replicateConst = id }
instance Replicate Int32  where { replicate = return; replicateConst = id }
instance Replicate Int64  where { replicate = return; replicateConst = id }
instance Replicate Word8  where { replicate = return; replicateConst = id }
instance Replicate Word16 where { replicate = return; replicateConst = id }
instance Replicate Word32 where { replicate = return; replicateConst = id }
instance Replicate Word64 where { replicate = return; replicateConst = id }
-- Vector instance.
-- The exact instruction sequence of 'replicate' was chosen
-- with respect to crashes of old LLVM versions;
-- the rejected alternatives are kept below as comments.
instance (TypeNum.Positive n, LLVM.IsPrimitive a) => Replicate (Vector n a) where
{- crashes LLVM-2.5, seems to be fixed in LLVM-2.6 -}
   -- Put the scalar into a one-element vector
   -- and shuffle it with an all-zero index mask (LLVM.zero),
   -- such that every element of the result is taken from index 0.
   replicate x = do
      v <- singleton x
      LLVM.shufflevector v (LLVM.value LLVM.undef) LLVM.zero
{- crashes LLVM-2.5
   replicate x = do
      v <- LLVM.insertelement (LLVM.value LLVM.undef) x (valueOf 1)
      LLVM.shufflevector v (LLVM.value LLVM.undef) (constVector $ repeat $ LLVM.constOf 1)
-}
{- the (repeat zero) is also converted to 'zeroinitializer' and crashes LLVM compiler

         (constVector $ repeat LLVM.zero)
-}
{-
   replicate = Vector.replicate
-}
   -- Build the constant vector from a one-element non-empty list
   -- (presumably constCyclicVector repeats the list cyclically
   -- until the vector is filled - TODO confirm).
   replicateConst x = LLVM.constCyclicVector $ NonEmpty.Cons x []

{- |
Create a vector with one element
by inserting the value at index 0 of an undefined vector.
-}
singleton ::
   (LLVM.IsPrimitive a) =>
   Value a -> CodeGenFunction r (Value (Vector D1 a))
singleton x =
   LLVM.insertelement undefVector x (valueOf 0)
   where
      undefVector = LLVM.value LLVM.undef

{- |
Turn a Haskell value of the scalar type
into a replicated constant 'Value'
by means of 'LLVM.constOf' and 'replicateConst'.
-}
replicateOf ::
   (IsConst (Scalar v), Replicate v) =>
   Scalar v -> Value v
replicateOf x =
   LLVM.value (replicateConst (LLVM.constOf x))


{- |
Arithmetic types that support minimum, maximum,
absolute value and sign computations.
The methods generate LLVM code for the respective operation.
-}
class (LLVM.IsArithmetic a) => Real a where
   -- | minimum of two values
   min :: Value a -> Value a -> CodeGenFunction r (Value a)
   -- | maximum of two values
   max :: Value a -> Value a -> CodeGenFunction r (Value a)
   -- | absolute value; the identity for the unsigned @Word@ instances
   abs :: Value a -> CodeGenFunction r (Value a)
   -- | sign of a value
   signum :: Value a -> CodeGenFunction r (Value a)


-- 'Float' instance:
-- min, max and abs pair the generic implementation from
-- "LLVM.Extra.ArithmeticPrivate" with an X86 extension based one
-- (see 'zipAutoWith' and 'mapAuto'), signum is always generic.
instance Real Float  where
   min = zipAutoWith A.min X86.minss
   max = zipAutoWith A.max X86.maxss
   abs = mapAuto     A.abs X86.absss
   -- alternative implementations of abs that are not in use:
   -- abs x = max x =<< LLVM.neg x
   -- abs x = A.abs
   signum = A.signum

-- 'Double' instance:
-- min, max and abs pair the generic implementation from
-- "LLVM.Extra.ArithmeticPrivate" with an X86 extension based one
-- (see 'zipAutoWith' and 'mapAuto'), signum is always generic.
instance Real Double where
   min = zipAutoWith A.min X86.minsd
   max = zipAutoWith A.max X86.maxsd
   abs = mapAuto     A.abs X86.abssd
   signum = A.signum

-- 'FP128' instance: only generic implementations.
instance Real FP128  where
   min = A.min
   max = A.max
   abs = A.abs
   -- The constants -1 and 1 are obtained
   -- by converting Int8 values to floating point.
   signum x = do
      negOne <- LLVM.inttofp (LLVM.valueOf (-1 :: Int8))
      posOne <- LLVM.inttofp (LLVM.valueOf (1 :: Int8))
      A.signumGen negOne posOne x


infixl 1 `mapAuto`

{- |
Some operations are meant for scalars
but formally expect and return vectors.
This function turns such a vector operation into a scalar operation:
The scalar is inserted at index 0 of an undefined vector,
the operation is applied
and the element at index 0 of the result is extracted.
-}
runScalar ::
   (Vector.C v, Vector.C w, Size v ~ Size w) =>
   (v -> CodeGenFunction r w) ->
   (Element v -> CodeGenFunction r (Element w))
runScalar op a = do
   packed <- Vector.insert (valueOf 0) a Class.undefTuple
   result <- op packed
   Vector.extract (valueOf 0) result

{- |
Combine a generic scalar implementation (first argument)
with an extension based implementation
that formally operates on vectors (second argument).
The vector operation is adapted to scalars using 'runScalar'
and both alternatives are passed to 'Ext.run'.
-}
mapAuto ::
   (Vector.C v, Vector.C w, Size v ~ Size w) =>
   (Element v -> CodeGenFunction r (Element w)) ->
   Ext.T (v -> CodeGenFunction r w) ->
   (Element v -> CodeGenFunction r (Element w))
mapAuto f g a =
   Ext.run generic specialised
   where
      generic = f a
      specialised = Ext.with g (\op -> runScalar op a)

{- |
Variant of 'mapAuto' for operations with two arguments.
Both implementations are uncurried,
such that 'mapAuto' can process the pair of arguments.
-}
zipAutoWith ::
   (Vector.C u, Vector.C v, Vector.C w,
    Size u ~ Size v, Size v ~ Size w) =>
   (Element u -> Element v -> CodeGenFunction r (Element w)) ->
   Ext.T (u -> v -> CodeGenFunction r w) ->
   (Element u -> Element v -> CodeGenFunction r (Element w))
zipAutoWith f g x y =
   mapAuto (uncurry f) (fmap uncurry g) (x, y)


-- Signed integers: all methods are the generic implementations.
instance Real Int8   where { min = A.min; max = A.max; abs = A.abs; signum = A.signum }
instance Real Int16  where { min = A.min; max = A.max; abs = A.abs; signum = A.signum }
instance Real Int32  where { min = A.min; max = A.max; abs = A.abs; signum = A.signum }
instance Real Int64  where { min = A.min; max = A.max; abs = A.abs; signum = A.signum }
-- Unsigned integers: the absolute value is the identity.
instance Real Word8  where { min = A.min; max = A.max; abs = return; signum = A.signum }
instance Real Word16 where { min = A.min; max = A.max; abs = return; signum = A.signum }
instance Real Word32 where { min = A.min; max = A.max; abs = return; signum = A.signum }
instance Real Word64 where { min = A.min; max = A.max; abs = return; signum = A.signum }

-- Vector instance: all methods are delegated to "LLVM.Extra.Vector".
instance (TypeNum.Positive n, Vector.Real a) =>
         Real (Vector n a) where
   min x y = Vector.min x y
   max x y = Vector.max x y
   abs x = Vector.abs x
   signum x = Vector.signum x



{- |
Types that can be multiplied by a value of their 'Scalar' type.
For the scalar instances in this module this is plain multiplication,
for vectors the scalar factor is replicated before the multiplication.
-}
class
   (LLVM.IsArithmetic (Scalar v), LLVM.IsArithmetic v) =>
      PseudoModule v where
   -- | Multiply the second operand by the scalar first operand.
   scale :: (a ~ Scalar v) => Value a -> Value v -> CodeGenFunction r (Value v)
   -- | Counterpart of 'scale' for constants.
   scaleConst :: (a ~ Scalar v) => ConstValue a -> ConstValue v -> CodeGenFunction r (ConstValue v)

-- For scalar types scaling is plain multiplication.
instance PseudoModule Word8  where { scale = LLVM.mul; scaleConst = LLVM.mul }
instance PseudoModule Word16 where { scale = LLVM.mul; scaleConst = LLVM.mul }
instance PseudoModule Word32 where { scale = LLVM.mul; scaleConst = LLVM.mul }
instance PseudoModule Word64 where { scale = LLVM.mul; scaleConst = LLVM.mul }
instance PseudoModule Int8   where { scale = LLVM.mul; scaleConst = LLVM.mul }
instance PseudoModule Int16  where { scale = LLVM.mul; scaleConst = LLVM.mul }
instance PseudoModule Int32  where { scale = LLVM.mul; scaleConst = LLVM.mul }
instance PseudoModule Int64  where { scale = LLVM.mul; scaleConst = LLVM.mul }
instance PseudoModule Float  where { scale = LLVM.mul; scaleConst = LLVM.mul }
instance PseudoModule Double where { scale = LLVM.mul; scaleConst = LLVM.mul }
-- Vector instance: the scalar factor is replicated to a vector
-- and then multiplied element-wise with the vector operand.
instance (LLVM.IsArithmetic a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
         PseudoModule (Vector n a) where
   scale a v = do
      factor <- replicate a
      A.mul factor v
   -- asTypeOf fixes the type of the replicated constant to the one of v
   scaleConst a v = LLVM.mul (replicateConst a `asTypeOf` v) v



{- |
Types for which an LLVM constant can be created from an 'Integer'.
-}
class (LLVM.IsConst a) => IntegerConstant a where
   constFromInteger :: Integer -> ConstValue a

-- Scalar types: convert with 'fromInteger' and wrap with 'constOf'.
instance IntegerConstant Word8  where { constFromInteger n = constOf (fromInteger n) }
instance IntegerConstant Word16 where { constFromInteger n = constOf (fromInteger n) }
instance IntegerConstant Word32 where { constFromInteger n = constOf (fromInteger n) }
instance IntegerConstant Word64 where { constFromInteger n = constOf (fromInteger n) }
instance IntegerConstant Int8   where { constFromInteger n = constOf (fromInteger n) }
instance IntegerConstant Int16  where { constFromInteger n = constOf (fromInteger n) }
instance IntegerConstant Int32  where { constFromInteger n = constOf (fromInteger n) }
instance IntegerConstant Int64  where { constFromInteger n = constOf (fromInteger n) }
instance IntegerConstant Float  where { constFromInteger n = constOf (fromInteger n) }
instance IntegerConstant Double where { constFromInteger n = constOf (fromInteger n) }

-- Vectors: replicate the constant of the element type.
instance (IntegerConstant a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
         IntegerConstant (Vector n a) where
   constFromInteger n = replicateConst (constFromInteger n)


{- |
Types for which an LLVM constant can be created from a 'Rational'.
-}
class (IntegerConstant a) => RationalConstant a where
   constFromRational :: Rational -> ConstValue a

-- Scalar types: convert with 'fromRational' and wrap with 'constOf'.
instance RationalConstant Float  where { constFromRational q = constOf (fromRational q) }
instance RationalConstant Double where { constFromRational q = constOf (fromRational q) }

-- Vectors: replicate the constant of the element type.
instance (RationalConstant a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
         RationalConstant (Vector n a) where
   constFromRational q = replicateConst (constFromRational q)


{- |
Types that provide the constant pi as LLVM constant.
-}
class (RationalConstant a) => TranscendentalConstant a where
   constPi :: ConstValue a

-- Scalar types: wrap Haskell's 'pi' with 'constOf'.
instance TranscendentalConstant Float  where { constPi = constOf pi }
instance TranscendentalConstant Double where { constPi = constOf pi }

-- Vectors: replicate the constant of the element type.
instance (TranscendentalConstant a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
         TranscendentalConstant (Vector n a) where
   constPi = replicateConst constPi