module LLVM.Extra.ScalarOrVectorPrivate where
import qualified LLVM.Extra.ArithmeticPrivate as A
import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal (D1)
import qualified LLVM.Core as LLVM
import LLVM.Core
(Value, ConstValue, valueOf,
CmpRet, ShapeOf,
Vector, WordN, IntN, FP128,
IsConst, IsInteger, CodeGenFunction)
import qualified Data.NonEmpty as NonEmpty
import Data.Word (Word8, Word16, Word32, Word64)
import Data.Int (Int8, Int16, Int32, Int64)
import Prelude hiding (replicate)
-- | The element type of a scalar-or-vector type.
--
--   Every scalar type listed here is its own element type;
--   for @'Vector' n a@ the element type is @a@.
type family Scalar vector :: *

-- floating point scalars
type instance Scalar Float  = Float
type instance Scalar Double = Double
type instance Scalar FP128  = FP128

-- boolean scalar
type instance Scalar Bool = Bool

-- fixed-width signed integers
type instance Scalar Int8  = Int8
type instance Scalar Int16 = Int16
type instance Scalar Int32 = Int32
type instance Scalar Int64 = Int64

-- fixed-width unsigned integers
type instance Scalar Word8  = Word8
type instance Scalar Word16 = Word16
type instance Scalar Word32 = Word32
type instance Scalar Word64 = Word64

-- integers of type-level width
type instance Scalar (IntN d)  = IntN d
type instance Scalar (WordN d) = WordN d

-- vectors expose their element type
type instance Scalar (Vector n a) = a
-- | Types that can be produced from a single 'Scalar' value by copying it
--   into every position: for a scalar type there is only one position,
--   for a vector there is one per lane.
class Replicate vector where
   -- | Broadcast a value computed at run time.
   replicate :: Value (Scalar vector) -> CodeGenFunction r (Value vector)
   -- | Broadcast a compile-time constant.
   replicateConst :: ConstValue (Scalar vector) -> ConstValue vector
-- Scalar types: a scalar is its own 'Scalar', so broadcasting
-- is the identity both at run time and for constants.
instance Replicate Float     where { replicateConst = id; replicate = return }
instance Replicate Double    where { replicateConst = id; replicate = return }
instance Replicate FP128     where { replicateConst = id; replicate = return }
instance Replicate Bool      where { replicateConst = id; replicate = return }
instance Replicate Int8      where { replicateConst = id; replicate = return }
instance Replicate Int16     where { replicateConst = id; replicate = return }
instance Replicate Int32     where { replicateConst = id; replicate = return }
instance Replicate Int64     where { replicateConst = id; replicate = return }
instance Replicate Word8     where { replicateConst = id; replicate = return }
instance Replicate Word16    where { replicateConst = id; replicate = return }
instance Replicate Word32    where { replicateConst = id; replicate = return }
instance Replicate Word64    where { replicateConst = id; replicate = return }
instance Replicate (IntN d)  where { replicateConst = id; replicate = return }
instance Replicate (WordN d) where { replicateConst = id; replicate = return }
-- | Vectors are filled by broadcasting.  At run time the scalar is put
--   into lane 0 of a one-element vector; a shuffle with the all-zero
--   mask 'LLVM.zero' then selects that lane for every result lane.
--   For constants a cyclic constant vector is built from the one element.
instance
   (TypeNum.Positive n, LLVM.IsPrimitive a) =>
      Replicate (Vector n a) where
   replicate x =
      singleton x >>= \single ->
         LLVM.shufflevector single (LLVM.value LLVM.undef) LLVM.zero
   replicateConst x =
      LLVM.constCyclicVector (NonEmpty.Cons x [])
-- | Wrap a scalar into a one-element vector
--   by inserting it at index 0 of an undefined vector.
singleton ::
   (LLVM.IsPrimitive a) =>
   Value a -> CodeGenFunction r (Value (Vector D1 a))
singleton x = LLVM.insertelement undefVector x firstLane
   where
      undefVector = LLVM.value LLVM.undef
      firstLane = valueOf 0
-- | Saturating addition and subtraction, intended for unsigned integer
--   types (scalar or vector).
--
--   'uaddSat' clamps to 'maxBound' where the wrapped sum is smaller than @x@.
--   'usubSat' clamps to zero where the wrapped difference is larger than @x@.
--
--   NOTE(review): correctness relies on 'A.cmp' performing an unsigned
--   comparison for @v@; confirm this for every type used with these.
uaddSat, usubSat ::
   (IsInteger v, CmpRet v, Replicate v, Scalar v ~ a, IsConst a, Bounded a) =>
   Value v -> Value v -> CodeGenFunction r (Value v)
uaddSat x y = do
   total <- A.add x y
   -- adding a non-negative amount can only make the value smaller by wrapping
   wrapped <- A.cmp LLVM.CmpLT total x
   upper <- replicate (valueOf maxBound)
   LLVM.select wrapped upper total
usubSat x y = do
   difference <- A.sub x y
   -- subtracting can only make the value larger by wrapping below zero
   wrapped <- A.cmp LLVM.CmpGT difference x
   LLVM.select wrapped (LLVM.value LLVM.zero) difference
-- | Saturating addition and subtraction for signed integer types
--   (scalar or vector).
--
--   The wrapped result is kept where the operation cannot overflow:
--   for addition where the operands' signs differ, for subtraction where
--   they agree.  Elsewhere the result is replaced by 'maxBound'
--   if @x@ is non-negative and it wrapped below @x@,
--   or by 'minBound' if @x@ is negative and it wrapped above @x@.
saddSat, ssubSat ::
   (IsInteger v, CmpRet v, Replicate v, ShapeOf v ~ shape,
    LLVM.ShapedType shape Bool ~ bv, ShapeOf bv ~ shape, CmpRet bv,
    Scalar v ~ a, IsConst a, Bounded a) =>
   Value v -> Value v -> CodeGenFunction r (Value v)
saddSat x y = do
   result <- A.add x y
   xNonNeg <- A.cmp LLVM.CmpGE x $ LLVM.value LLVM.zero
   yNonNeg <- A.cmp LLVM.CmpGE y $ LLVM.value LLVM.zero
   signsDiffer <- A.cmp LLVM.CmpNE xNonNeg yNonNeg
   belowX <- A.cmp LLVM.CmpLT result x
   aboveX <- A.cmp LLVM.CmpGT result x
   upper <- replicate $ valueOf maxBound
   lower <- replicate $ valueOf minBound
   clampedHigh <- LLVM.select belowX upper result
   clampedLow <- LLVM.select aboveX lower result
   -- operands of different sign never overflow on addition
   LLVM.select signsDiffer result
      =<< LLVM.select xNonNeg clampedHigh clampedLow
ssubSat x y = do
   result <- A.sub x y
   xNonNeg <- A.cmp LLVM.CmpGE x $ LLVM.value LLVM.zero
   yNonNeg <- A.cmp LLVM.CmpGE y $ LLVM.value LLVM.zero
   signsAgree <- A.cmp LLVM.CmpEQ xNonNeg yNonNeg
   belowX <- A.cmp LLVM.CmpLT result x
   aboveX <- A.cmp LLVM.CmpGT result x
   upper <- replicate $ valueOf maxBound
   lower <- replicate $ valueOf minBound
   clampedHigh <- LLVM.select belowX upper result
   clampedLow <- LLVM.select aboveX lower result
   -- operands of equal sign never overflow on subtraction
   LLVM.select signsAgree result
      =<< LLVM.select xNonNeg clampedHigh clampedLow
-- | Signed saturating addition, formulated with one boolean mask
--   that says where the wrapped sum is exact.
--
--   Overflow is only possible when both operands have the same sign.
--   For non-negative operands the wrapped sum is exact iff it is not
--   smaller than @x@.  For negative operands it is exact iff it is smaller
--   than @x@.  Where the sum is not exact the result saturates to
--   'maxBound' or 'minBound', chosen by the sign of @x@.
--
--   Should agree with 'saddSat' for all inputs.
saddSatLogical ::
   (IsInteger v, CmpRet v, Replicate v, ShapeOf v ~ shape,
    LLVM.ShapedType shape Bool ~ bv, ShapeOf bv ~ shape, CmpRet bv,
    IsInteger bv,
    Scalar v ~ a, IsConst a, Bounded a) =>
   Value v -> Value v -> CodeGenFunction r (Value v)
saddSatLogical x y = do
   z <- A.add x y
   nonNegX <- A.cmp LLVM.CmpGE x $ LLVM.value LLVM.zero
   nonNegY <- A.cmp LLVM.CmpGE y $ LLVM.value LLVM.zero
   distinctSign <- A.cmp LLVM.CmpNE nonNegX nonNegY
   minBnd <- replicate $ valueOf minBound
   maxBnd <- replicate $ valueOf maxBound
   bounds <- LLVM.select nonNegX maxBnd minBnd
   -- These flags must be true when the wrapped sum is EXACT, because they
   -- feed 'correctSum' below.  (Flags that are true on overflow would keep
   -- wrapped sums and saturate ordinary ones.)
   -- x, y >= 0: the exact sum is >= x, a wrapped one is < x.
   noOverflow <- A.cmp LLVM.CmpGE z x
   -- x, y < 0: the exact sum is < x, a wrapped one is > x.
   noUnderflow <- A.cmp LLVM.CmpLT z x
   noXflow <- LLVM.select nonNegX noOverflow noUnderflow
   correctSum <- A.or distinctSign noXflow
   LLVM.select correctSum z bounds