{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{- |
Support for unified handling of scalars and vectors.

Attention:
The rounding and fraction functions only work
for floating point values with a magnitude of at most @maxBound :: Int32@.
This way we avoid expensive handling of cases that are presumably rare.
-}
module LLVM.Extra.ScalarOrVector (
   Fraction (truncate, fraction),
   signedFraction,
   addToPhase,
   incPhase,
   Replicate (replicate, replicateConst),
   replicateOf,
   Real (min, max, abs),
   ) where

import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Extension.X86 as X86
import qualified LLVM.Extra.Extension as Ext

import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.Arithmetic as A

import qualified Data.TypeLevel.Num as TypeNum
import Data.TypeLevel.Num (D1, )
import qualified LLVM.Core as LLVM
import LLVM.Core
   (Value, ConstValue, valueOf,
    Vector, FP128,
    IsConst, IsFloating,
    CodeGenFunction, )

import Control.Monad.HT ((<=<), )

import Data.Word (Word8, Word16, Word32, Word64, )
import Data.Int  (Int8,  Int16,  Int32,  Int64, )

import Prelude hiding (Real, replicate, min, max, abs, truncate, floor, round, )


{- Earlier multi-parameter design of 'Fraction' (kept for reference, not compiled):
class
   (IsFloating frac,
    IsInteger int,
    LLVM.NumberOfElements n frac,
    LLVM.NumberOfElements n int) =>
      Fraction n int frac | frac -> int, frac -> n, int -> n where
   fptosi :: Value frac -> CodeGenFunction r (Value int)
   fptosi = LLVM.fptosi
   sitofp :: Value int -> CodeGenFunction r (Value frac)
   sitofp = LLVM.sitofp
-}

{- Earlier design of 'Fraction' with explicit integer conversions
   (kept for reference, not compiled):
class
   (IsFloating frac) =>
      Fraction int frac | frac -> int where
   fptosi :: Value frac -> CodeGenFunction r (Value int)
   sitofp :: Value int -> CodeGenFunction r (Value frac)

instance Fraction Int32 Float where
   fptosi = LLVM.fptosi
   sitofp = LLVM.sitofp

instance Fraction Int64 Double where
   fptosi = LLVM.fptosi
   sitofp = LLVM.sitofp

instance (TypeNum.Pos n) =>
      Fraction (Vector n Int32) (Vector n Float) where
   fptosi = LLVM.fptosi
   sitofp = LLVM.sitofp

instance (TypeNum.Pos n) =>
      Fraction (Vector n Int64) (Vector n Double) where
   fptosi = LLVM.fptosi
   sitofp = LLVM.sitofp
-}


{- |
Floating-point types (scalar or vector)
that can be split into an integral and a fractional part.
See the module header for the supported range of values.
-}
class (Real a, IsFloating a) => Fraction a where
   -- | Remove the fractional digits, keeping the result as a floating-point value.
   truncate :: Value a -> CodeGenFunction r (Value a)
   -- | Non-negative fractional part, i.e. @x - floor x@.
   fraction :: Value a -> CodeGenFunction r (Value a)

-- Scalar 'Float'.
-- Each method pairs a generic implementation with an X86 extension variant
-- via 'mapAuto'; 'Ext.run' inside 'mapAuto' picks one of them,
-- presumably depending on whether the target supports the extension.
instance Fraction Float where
   -- Generic: convert to an integer with fptosi and back with sitofp;
   -- 'asTypeOf' pins the intermediate integer type to Int32,
   -- which is the source of the magnitude limit in the module header.
   -- Extension: X86.roundss with immediate 3, which in the SSE4.1 ROUNDSS
   -- encoding selects rounding toward zero
   -- (assuming the wrapper passes the immediate through unchanged).
   truncate =
      mapAuto
         (LLVM.sitofp . flip asTypeOf (undefined :: Value Int32) <=< LLVM.fptosi)
         (Ext.with X86.roundss $ \round x -> round x (valueOf 3))
   -- Generic: 'fractionGen', or 'fractionLogical' driven by the X86.cmpss
   -- comparison mask when that extension is available.
   -- Extension: x - X86.roundss x 1; immediate 1 selects rounding down
   -- in the SSE4.1 encoding, giving x - floor x.
   fraction =
      (\x ->
         fractionGen x
         `Ext.run`
         (Ext.with X86.cmpss $ \cmp ->
            fractionLogical (\modus -> curry (runScalar (uncurry (cmp modus)))) x))
      `mapAuto`
      (Ext.with X86.roundss $ \round x ->
         A.sub x =<< round x (valueOf 1))

-- Scalar 'Double'.
-- Same structure as the 'Float' instance, using the double-precision
-- X86 instructions (roundsd, cmpsd).
instance Fraction Double where
   -- Generic: round trip through Int32 (see the note on the next line),
   -- so only values with magnitude up to maxBound :: Int32 are handled.
   -- Extension: X86.roundsd with immediate 3 (round toward zero in the
   -- SSE4.1 encoding, assuming the wrapper passes it through).
   truncate =
      mapAuto
         -- X86 only converts Double to Int32, it cannot target Int64
         (LLVM.sitofp . flip asTypeOf (undefined :: Value Int32) <=< LLVM.fptosi)
         (Ext.with X86.roundsd $ \round x -> round x (valueOf 3))
   -- Generic: 'fractionGen', or 'fractionLogical' driven by the X86.cmpsd
   -- comparison mask when that extension is available.
   -- Extension: x - X86.roundsd x 1 (immediate 1 = round down), i.e. x - floor x.
   fraction =
      (\x ->
         fractionGen x
         `Ext.run`
         (Ext.with X86.cmpsd $ \cmp ->
            fractionLogical (\modus -> curry (runScalar (uncurry (cmp modus)))) x))
{-
NOTE: For Doubles it would be more efficient to convert only the lower 32 bits
of the comparison result instead of all 64 bits,
since x86 natively supports conversion only from 32-bit integers.
A possible (unused) variant:
      (Ext.with X86.cmpsd $ \cmp -> fractionLogical
         (\x y -> cmp x y >>= LLVM.bitcastUnify )
-}
      `mapAuto`
      (Ext.with X86.roundsd $ \round x ->
         A.sub x =<< round x (valueOf 1))

-- Vectors of floating-point values:
-- delegate to the implementations in "LLVM.Extra.Vector".
instance (TypeNum.Pos n, Vector.Real a, IsFloating a, IsConst a) =>
      Fraction (Vector n a) where
   fraction x = Vector.fraction x
   truncate x = Vector.truncate x


{- |
Fractional part computed as @x - truncate x@.
The result has the same sign as the argument.
This is not particularly useful on its own, but it is fast on IEEE implementations.
-}
signedFraction ::
   (Fraction a) =>
   Value a -> CodeGenFunction r (Value a)
signedFraction x = do
   integral <- truncate x
   A.sub x integral

{- |
Generic implementation of 'fraction':
take the signed fraction and add one wherever it is negative,
so the result is non-negative.
It uses a floating-point comparison and a select.
-}
fractionGen ::
   (Num a, Fraction v, Replicate a v, IsConst a, LLVM.CmpRet v b) =>
   Value v -> CodeGenFunction r (Value v)
fractionGen x = do
   frac <- signedFraction x
   nonNegative <- A.fcmp LLVM.FPOGE frac (LLVM.value LLVM.zero)
   shifted <- A.add frac (replicateOf 1)
   LLVM.select nonNegative frac shifted

{- |
Implementation of 'fraction' based on an integer-valued comparison,
as produced by the X86 compare instructions used by the callers in this module.

The comparison @cmp@ must return -1 (all bits set) for true and 0 for false.
Converting that mask with sitofp gives -1.0 or 0.0.
Subtracting it therefore adds one exactly where the signed fraction is negative.
-}
fractionLogical ::
   (Fraction a, LLVM.NumberOfElements D1 a,
    LLVM.IsInteger b, LLVM.NumberOfElements D1 b) =>
   (LLVM.FPPredicate ->
    Value a -> Value a -> CodeGenFunction r (Value b)) ->
   Value a -> CodeGenFunction r (Value a)
fractionLogical cmp x = do
   frac <- signedFraction x
   isNegative <- cmp LLVM.FPOLT frac (LLVM.value LLVM.zero)
   correction <- LLVM.sitofp isNegative
   A.sub frac correction

{- |
Advance a phase by an increment and wrap the sum with 'fraction'.
The increment (first operand) may be negative;
the phase must always be non-negative.
-}
addToPhase ::
   (Fraction a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
addToPhase increment phase =
   A.add increment phase >>= fraction

{- |
Advance a phase by an increment and wrap the sum with 'signedFraction'.
Both the increment and the phase must be non-negative,
because 'signedFraction' does not map negative sums into the unit interval.
-}
incPhase ::
   (Fraction a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
incPhase increment phase =
   A.add increment phase >>= signedFraction



{- |
Build a value of type @vector@ whose elements all equal one @scalar@.
For scalar types @vector@ equals @scalar@.
-}
class Replicate scalar vector | vector -> scalar where
   -- | Replicate a value that is only known when the generated code runs.
   replicate :: Value scalar -> CodeGenFunction r (Value vector)
   -- | Replicate a compile-time constant.
   replicateConst :: ConstValue scalar -> ConstValue vector

-- Scalar types: replicating a scalar into itself is the identity.
instance Replicate Float Float where
   replicate x = return x
   replicateConst c = c

instance Replicate Double Double where
   replicate x = return x
   replicateConst c = c

instance Replicate FP128 FP128 where
   replicate x = return x
   replicateConst c = c

instance Replicate Bool Bool where
   replicate x = return x
   replicateConst c = c

instance Replicate Int8 Int8 where
   replicate x = return x
   replicateConst c = c

instance Replicate Int16 Int16 where
   replicate x = return x
   replicateConst c = c

instance Replicate Int32 Int32 where
   replicate x = return x
   replicateConst c = c

instance Replicate Int64 Int64 where
   replicate x = return x
   replicateConst c = c

instance Replicate Word8 Word8 where
   replicate x = return x
   replicateConst c = c

instance Replicate Word16 Word16 where
   replicate x = return x
   replicateConst c = c

instance Replicate Word32 Word32 where
   replicate x = return x
   replicateConst c = c

instance Replicate Word64 Word64 where
   replicate x = return x
   replicateConst c = c
instance (TypeNum.Pos n, LLVM.IsPrimitive a) => Replicate a (Vector n a) where
   -- Put the scalar into lane 0 of a one-element vector,
   -- then broadcast lane 0 with a shuffle whose mask is all zeros.
   --
   -- History: this variant crashed LLVM-2.5 (seems to be fixed in LLVM-2.6).
   -- An earlier alternative also crashed LLVM-2.5:
   --
   -- >   v <- LLVM.insertelement (LLVM.value LLVM.undef) x (valueOf 1)
   -- >   LLVM.shufflevector v (LLVM.value LLVM.undef) (constVector $ repeat $ LLVM.constOf 1)
   --
   -- A mask of @constVector $ repeat LLVM.zero@ is also converted to
   -- 'zeroinitializer' and crashed the LLVM compiler.
   -- @replicate = Vector.replicate@ is another unused option.
   replicate x =
      singleton x >>= \v ->
      LLVM.shufflevector v (LLVM.value LLVM.undef) LLVM.zero
   -- A one-element list; presumably 'LLVM.constVector' repeats it
   -- to fill all n elements (confirm against the llvm bindings).
   replicateConst x = LLVM.constVector [x]

{- |
Wrap a scalar into a one-element vector (stored at index 0).
-}
singleton ::
   (LLVM.IsPrimitive a) =>
   Value a -> CodeGenFunction r (Value (Vector D1 a))
singleton x =
   LLVM.insertelement undefVector x (valueOf 0)
   where
      undefVector = LLVM.value LLVM.undef

{- |
Constant scalar or vector with all elements set to the given Haskell value.
-}
replicateOf ::
   (IsConst a, Replicate a v) =>
   a -> Value v
replicateOf =
   LLVM.value . replicateConst . LLVM.constOf


{- |
Arithmetic types (scalar or vector) with minimum, maximum and absolute value.
-}
class (LLVM.IsArithmetic a) => Real a where
   -- | Smaller or larger of two values, respectively.
   min, max :: Value a -> Value a -> CodeGenFunction r (Value a)
   -- | Absolute value.
   abs :: Value a -> CodeGenFunction r (Value a)


-- Scalar 'Float': generic arithmetic or the X86 single-precision
-- instructions, chosen by 'mapAuto' / 'zipAutoWith'.
instance Real Float where
   min x y = zipAutoWith A.min X86.minss x y
   max x y = zipAutoWith A.max X86.maxss x y
   abs x = mapAuto A.abs X86.absss x
   -- Other possible definitions of abs:
   -- @max x =<< LLVM.neg x@, or plain 'A.abs'.

-- Scalar 'Double': generic arithmetic or the X86 double-precision
-- instructions, chosen by 'mapAuto' / 'zipAutoWith'.
instance Real Double where
   min x y = zipAutoWith A.min X86.minsd x y
   max x y = zipAutoWith A.max X86.maxsd x y
   abs x = mapAuto A.abs X86.abssd x


-- Precedence 1, left-associative: looser than function application and the
-- usual arithmetic and composition operators, so the operands of an infix
-- `mapAuto` rarely need parentheses.
infixl 1 `mapAuto`

{- |
Some functions are intended for processing scalars
but formally take and return vectors.
This function turns such a vector function into a scalar function:
it stores the argument in element 0 of an otherwise undefined vector,
applies the operation, and reads element 0 of the result.
-}
runScalar ::
   (Vector.Access n a va, Vector.Access n b vb) =>
   (va -> CodeGenFunction r vb) ->
   (a -> CodeGenFunction r b)
runScalar op a = do
   packed <- Vector.insert (valueOf 0) a Class.undefTuple
   result <- op packed
   Vector.extract (valueOf 0) result

{- |
Run a scalar operation, preferring a vector-extension implementation.

The extension operation works on whole vectors
and is applied to the scalar argument through 'runScalar'.
'Ext.run' chooses between it and the generic scalar function,
presumably depending on whether the target supports the extension.
-}
mapAuto ::
   (Vector.Access n a va, Vector.Access n b vb) =>
   (a -> CodeGenFunction r b) ->
   Ext.T (va -> CodeGenFunction r vb) ->
   (a -> CodeGenFunction r b)
mapAuto generic ext a =
   Ext.run (generic a) viaExtension
   where
      viaExtension = Ext.with ext (\op -> runScalar op a)

{- |
Two-argument version of 'mapAuto':
the arguments are paired up, handled as one tuple by 'mapAuto',
and the generic and extension operations are uncurried accordingly.
-}
zipAutoWith ::
   (Vector.Access n a va, Vector.Access n b vb, Vector.Access n c vc) =>
   (a -> b -> CodeGenFunction r c) ->
   Ext.T (va -> vb -> CodeGenFunction r vc) ->
   (a -> b -> CodeGenFunction r c)
zipAutoWith generic ext a b =
   mapAuto (uncurry generic) (fmap uncurry ext) (a, b)


-- Signed and extended-precision scalars: generic arithmetic only.
instance Real FP128 where
   min = A.min
   max = A.max
   abs = A.abs

instance Real Int8 where
   min = A.min
   max = A.max
   abs = A.abs

instance Real Int16 where
   min = A.min
   max = A.max
   abs = A.abs

instance Real Int32 where
   min = A.min
   max = A.max
   abs = A.abs

instance Real Int64 where
   min = A.min
   max = A.max
   abs = A.abs

-- Unsigned scalars are never negative, so abs is the identity.
instance Real Word8 where
   min = A.min
   max = A.max
   abs = return

instance Real Word16 where
   min = A.min
   max = A.max
   abs = return

instance Real Word32 where
   min = A.min
   max = A.max
   abs = return

instance Real Word64 where
   min = A.min
   max = A.max
   abs = return

-- Vectors: delegate to the implementations in "LLVM.Extra.Vector".
instance (TypeNum.Pos n, Vector.Real a) =>
         Real (Vector n a) where
   abs x = Vector.abs x
   min x y = Vector.min x y
   max x y = Vector.max x y