{-# LANGUAGE TypeFamilies #-}
module LLVM.Extra.ScalarOrVectorPrivate where

import qualified LLVM.Extra.ArithmeticPrivate as A

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal (D1)

import qualified LLVM.Core as LLVM
import LLVM.Core
   (Value, ConstValue, valueOf,
    CmpRet, ShapeOf,
    Vector, WordN, IntN, FP128,
    IsConst, IsInteger, CodeGenFunction)

import qualified Data.NonEmpty as NonEmpty
import Data.Word (Word8, Word16, Word32, Word64)
import Data.Int  (Int8,  Int16,  Int32,  Int64)

import Prelude hiding (replicate)


-- | Element type of a value that may be either a plain scalar or a vector.
--
-- A 'Vector' maps to its element type; every scalar type listed here maps
-- to itself.  The family is open, so further instances may exist elsewhere.
type family Scalar vector :: *

-- vectors: project to the element type
type instance Scalar (Vector n a) = a

-- floating point scalars
type instance Scalar Float = Float
type instance Scalar Double = Double
type instance Scalar FP128 = FP128

-- boolean scalars
type instance Scalar Bool = Bool

-- signed integer scalars (fixed and arbitrary width)
type instance Scalar Int8 = Int8
type instance Scalar Int16 = Int16
type instance Scalar Int32 = Int32
type instance Scalar Int64 = Int64
type instance Scalar (IntN d) = IntN d

-- unsigned integer scalars (fixed and arbitrary width)
type instance Scalar Word8 = Word8
type instance Scalar Word16 = Word16
type instance Scalar Word32 = Word32
type instance Scalar Word64 = Word64
type instance Scalar (WordN d) = WordN d


-- | Types whose values can be built from a single @'Scalar' vector@ value.
-- For scalar types the instances in this module are the identity; for
-- 'Vector' types the scalar is copied into every element.
class Replicate vector where
   -- | Build a value whose elements all equal the given scalar.
   -- An alternative is using the 'Vector.Constant' vector type.
   replicate ::
      Value (Scalar vector) ->
      CodeGenFunction r (Value vector)

   -- | Compile-time constant counterpart of 'replicate'.
   replicateConst ::
      ConstValue (Scalar vector) ->
      ConstValue vector

-- Scalar types are their own "one-element vectors": replicating a scalar
-- yields the scalar unchanged, both at run time and for constants.
instance Replicate Float where {replicate = pure; replicateConst = id}
instance Replicate Double where {replicate = pure; replicateConst = id}
instance Replicate FP128 where {replicate = pure; replicateConst = id}
instance Replicate Bool where {replicate = pure; replicateConst = id}
instance Replicate Int8 where {replicate = pure; replicateConst = id}
instance Replicate Int16 where {replicate = pure; replicateConst = id}
instance Replicate Int32 where {replicate = pure; replicateConst = id}
instance Replicate Int64 where {replicate = pure; replicateConst = id}
instance Replicate Word8 where {replicate = pure; replicateConst = id}
instance Replicate Word16 where {replicate = pure; replicateConst = id}
instance Replicate Word32 where {replicate = pure; replicateConst = id}
instance Replicate Word64 where {replicate = pure; replicateConst = id}
instance Replicate (IntN d) where {replicate = pure; replicateConst = id}
instance Replicate (WordN d) where {replicate = pure; replicateConst = id}
-- | Vectors are filled in two steps.  First the scalar is placed into
-- lane 0 of a one-element vector ('singleton').  Then it is shuffled with
-- an all-zero mask, so every lane of the @n@-element result selects
-- lane 0.  The constant variant passes a one-element list to
-- 'LLVM.constCyclicVector'; presumably that list is repeated to fill all
-- lanes.
instance
   (TypeNum.Positive n, LLVM.IsPrimitive a) =>
      Replicate (Vector n a) where
   replicate x =
      singleton x >>= \lane0 ->
         LLVM.shufflevector lane0 (LLVM.value LLVM.undef) LLVM.zero
   replicateConst x =
      LLVM.constCyclicVector (NonEmpty.Cons x [])

-- | Wrap a scalar into a one-element vector by inserting it at index 0 of
-- an otherwise undefined vector.
singleton ::
   (LLVM.IsPrimitive a) =>
   Value a -> CodeGenFunction r (Value (Vector D1 a))
singleton elemValue =
   LLVM.insertelement undefVector elemValue firstLane
   where
      undefVector = LLVM.value LLVM.undef
      firstLane = valueOf 0


-- | Saturating addition ('uaddSat') and subtraction ('usubSat'), intended
-- for unsigned element types.
--
-- Each first computes the wrap-around result.  A wrap is detected by
-- comparing that result with the first operand: a sum smaller than @x@, or
-- a difference greater than @x@, must have wrapped.  A wrapped sum is then
-- replaced by 'maxBound', and a wrapped difference by zero.
uaddSat, usubSat ::
   (IsInteger v, CmpRet v, Replicate v, Scalar v ~ a, IsConst a, Bounded a) =>
   Value v -> Value v -> CodeGenFunction r (Value v)
uaddSat x y = do
   wrapped <- A.add x y
   overflowed <- A.cmp LLVM.CmpLT wrapped x
   upperLimit <- replicate (valueOf maxBound)
   LLVM.select overflowed upperLimit wrapped
usubSat x y = do
   wrapped <- A.sub x y
   underflowed <- A.cmp LLVM.CmpGT wrapped x
   let lowerLimit = LLVM.value LLVM.zero
   LLVM.select underflowed lowerLimit wrapped

-- | Signed saturating addition ('saddSat') and subtraction ('ssubSat').
--
-- Both first compute the wrap-around result.  A wrap can only occur when
-- the operands have equal signs (addition) or different signs
-- (subtraction); in all other cases the wrapped result is already exact.
-- When a wrap is possible, the result is clamped as follows:
--
-- * for non-negative @x@, a result below @x@ becomes 'maxBound';
-- * for negative @x@, a result above @x@ becomes 'minBound'.
saddSat, ssubSat ::
   (IsInteger v, CmpRet v, Replicate v, ShapeOf v ~ shape,
    LLVM.ShapedType shape Bool ~ bv, ShapeOf bv ~ shape, CmpRet bv,
    Scalar v ~ a, IsConst a, Bounded a) =>
   Value v -> Value v -> CodeGenFunction r (Value v)

saddSat x y = do
   wrapped <- A.add x y
   xIsNonNeg <- A.cmp LLVM.CmpGE x (LLVM.value LLVM.zero)
   yIsNonNeg <- A.cmp LLVM.CmpGE y (LLVM.value LLVM.zero)
   -- operands of opposite sign: the sum cannot leave the value range
   cannotWrap <- A.cmp LLVM.CmpNE xIsNonNeg yIsNonNeg
   wentBelowX <- A.cmp LLVM.CmpLT wrapped x
   wentAboveX <- A.cmp LLVM.CmpGT wrapped x
   upperLimit <- replicate (valueOf maxBound)
   lowerLimit <- replicate (valueOf minBound)
   clampedHigh <- LLVM.select wentBelowX upperLimit wrapped
   clampedLow <- LLVM.select wentAboveX lowerLimit wrapped
   clamped <- LLVM.select xIsNonNeg clampedHigh clampedLow
   LLVM.select cannotWrap wrapped clamped

ssubSat x y = do
   wrapped <- A.sub x y
   xIsNonNeg <- A.cmp LLVM.CmpGE x (LLVM.value LLVM.zero)
   yIsNonNeg <- A.cmp LLVM.CmpGE y (LLVM.value LLVM.zero)
   -- operands of equal sign: the difference cannot leave the value range
   cannotWrap <- A.cmp LLVM.CmpEQ xIsNonNeg yIsNonNeg
   wentBelowX <- A.cmp LLVM.CmpLT wrapped x
   wentAboveX <- A.cmp LLVM.CmpGT wrapped x
   upperLimit <- replicate (valueOf maxBound)
   lowerLimit <- replicate (valueOf minBound)
   clampedHigh <- LLVM.select wentBelowX upperLimit wrapped
   clampedLow <- LLVM.select wentAboveX lowerLimit wrapped
   clamped <- LLVM.select xIsNonNeg clampedHigh clampedLow
   LLVM.select cannotWrap wrapped clamped

-- | Signed saturating addition.  This variant combines the wrap conditions
-- with a bitwise 'A.or' on the comparison masks rather than a chain of
-- 'LLVM.select's.
--
-- The wrap-around sum @z@ is kept when it is correct, which is the case
-- when either:
--
-- * @x@ and @y@ have different signs (the sum cannot wrap), or
-- * both are non-negative and @z >= y@ (no overflow), or
-- * both are negative and @z <= y@ (no underflow).
--
-- Otherwise the result is clamped to 'maxBound' (non-negative operands)
-- or 'minBound' (negative operands).
saddSatLogical ::
   (IsInteger v, CmpRet v, Replicate v, ShapeOf v ~ shape,
    LLVM.ShapedType shape Bool ~ bv, ShapeOf bv ~ shape, CmpRet bv,
    IsInteger bv,
    Scalar v ~ a, IsConst a, Bounded a) =>
   Value v -> Value v -> CodeGenFunction r (Value v)
saddSatLogical x y = do
   z <- A.add x y
   nonNegX <- A.cmp LLVM.CmpGE x $ LLVM.value LLVM.zero
   nonNegY <- A.cmp LLVM.CmpGE y $ LLVM.value LLVM.zero
   distinctSign <- A.cmp LLVM.CmpNE nonNegX nonNegY
   minBnd <- replicate $ valueOf minBound
   maxBnd <- replicate $ valueOf maxBound
   -- bound to saturate to; only used when both operands share x's sign
   bounds <- LLVM.select nonNegX maxBnd minBnd
   -- With equal signs, a wrapped sum lands on the wrong side of y:
   -- below y for non-negative operands, above y for negative ones.
   -- These masks are True when NO wrap happened.  The previous version
   -- tested for the wrap itself, which inverted the saturation.
   noOverflow <- A.cmp LLVM.CmpGE z y
   noUnderflow <- A.cmp LLVM.CmpLE z y
   noXflow <- LLVM.select nonNegX noOverflow noUnderflow
   correctSum <- A.or distinctSign noXflow
   LLVM.select correctSum z bounds