{-# LANGUAGE TypeFamilies #-}
module LLVM.Extra.ScalarOrVectorPrivate where

import qualified LLVM.Extra.ArithmeticPrivate as A

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal (D1)

import qualified LLVM.Core as LLVM
import LLVM.Core
   (Value, ConstValue, valueOf,
    CmpRet, ShapeOf, Vector, WordN, IntN, FP128,
    IsConst, IsInteger, CodeGenFunction)

import qualified Data.NonEmpty as NonEmpty
import Data.Word (Word8, Word16, Word32, Word64)
import Data.Int (Int8, Int16, Int32, Int64)

import Prelude hiding (replicate)


-- | Element type of a scalar-or-vector type.
-- Every scalar type listed below maps to itself,
-- whereas @'Vector' n a@ maps to its element type @a@.
type family Scalar vector :: *

-- floating point scalars
type instance Scalar Float = Float
type instance Scalar Double = Double
type instance Scalar FP128 = FP128

-- boolean scalar
type instance Scalar Bool = Bool

-- fixed-width signed integers
type instance Scalar Int8 = Int8
type instance Scalar Int16 = Int16
type instance Scalar Int32 = Int32
type instance Scalar Int64 = Int64

-- fixed-width unsigned integers
type instance Scalar Word8 = Word8
type instance Scalar Word16 = Word16
type instance Scalar Word32 = Word32
type instance Scalar Word64 = Word64

-- integers with type-level bit width
type instance Scalar (IntN d) = IntN d
type instance Scalar (WordN d) = WordN d

-- vectors: the scalar is the element type
type instance Scalar (Vector n a) = a


-- | Turn a value of the 'Scalar' type into a value of the
-- (scalar or vector) type @vector@.
-- For all scalar instances both methods are the identity
-- ('return' and 'id', respectively).
class Replicate vector where
   -- | an alternative is using the 'Vector.Constant' vector type
   replicate :: Value (Scalar vector) -> CodeGenFunction r (Value vector)
   -- | Variant of 'replicate' that operates on constants
   -- and thus needs no code generation monad.
   replicateConst :: ConstValue (Scalar vector) -> ConstValue vector

instance Replicate Float where
   replicate = return
   replicateConst = id

instance Replicate Double where
   replicate = return
   replicateConst = id

instance Replicate FP128 where
   replicate = return
   replicateConst = id

instance Replicate Bool where
   replicate = return
   replicateConst = id

instance Replicate Int8 where
   replicate = return
   replicateConst = id

instance Replicate Int16 where
   replicate = return
   replicateConst = id

instance Replicate Int32 where
   replicate = return
   replicateConst = id

instance Replicate Int64 where
   replicate = return
   replicateConst = id

instance Replicate Word8 where
   replicate = return
   replicateConst = id

instance Replicate Word16 where
   replicate = return
   replicateConst = id

instance Replicate Word32 where
   replicate = return
   replicateConst = id

instance Replicate Word64 where
   replicate = return
   replicateConst = id

instance Replicate (IntN d) where
   replicate = return
   replicateConst = id

instance Replicate (WordN d) where
   replicate = return
   replicateConst = id

-- | Vector instance: the scalar is first put into a one-element vector
-- and then 'LLVM.shufflevector' is applied with an all-zero mask,
-- i.e. every lane of the result refers to element 0 of that vector.
instance (TypeNum.Positive n, LLVM.IsPrimitive a) =>
      Replicate (Vector n a) where
   replicate x = do
      v <- singleton x
      LLVM.shufflevector v (LLVM.value LLVM.undef) LLVM.zero
   replicateConst x = LLVM.constCyclicVector $ NonEmpty.Cons x []

-- | Wrap a scalar into a vector of size one
-- by inserting it at index 0 of an undefined vector.
singleton ::
   (LLVM.IsPrimitive a) =>
   Value a -> CodeGenFunction r (Value (Vector D1 a))
singleton x =
   LLVM.insertelement (LLVM.value LLVM.undef) x (valueOf 0)


-- | Saturating addition and subtraction that detect wrap-around
-- by comparing the wrapped result with the first operand.
-- 'uaddSat' yields 'maxBound' if the sum is less than @x@,
-- 'usubSat' yields zero if the difference is greater than @x@.
-- NOTE(review): this test is only valid for an unsigned comparison;
-- presumably 'A.cmp' picks signedness from the element type,
-- so these functions should only be used at unsigned types - confirm.
uaddSat, usubSat ::
   (IsInteger v, CmpRet v, Replicate v, Scalar v ~ a, IsConst a, Bounded a) =>
   Value v -> Value v -> CodeGenFunction r (Value v)
uaddSat x y = do
   z <- A.add x y
   -- a sum smaller than one of its summands must have wrapped around
   wrong <- A.cmp LLVM.CmpLT z x
   maxBnd <- replicate $ valueOf maxBound
   LLVM.select wrong maxBnd z

usubSat x y = do
   z <- A.sub x y
   -- a difference greater than the minuend must have wrapped around
   wrong <- A.cmp LLVM.CmpGT z x
   LLVM.select wrong (LLVM.value LLVM.zero) z


-- | Saturating addition and subtraction for signed integers.
-- The result is clamped to 'maxBound' or 'minBound'
-- instead of wrapping around.
--
-- 'saddSat' can only overflow if both operands have the same sign,
-- 'ssubSat' can only overflow if the operands have distinct signs
-- (zero counts as non-negative).
-- In the critical case the direction of a possible overflow
-- is determined by the sign of @x@.
saddSat, ssubSat ::
   (IsInteger v, CmpRet v, Replicate v,
    ShapeOf v ~ shape, LLVM.ShapedType shape Bool ~ bv,
    ShapeOf bv ~ shape, CmpRet bv,
    Scalar v ~ a, IsConst a, Bounded a) =>
   Value v -> Value v -> CodeGenFunction r (Value v)
saddSat x y = do
   z <- A.add x y
   nonNegX <- A.cmp LLVM.CmpGE x $ LLVM.value LLVM.zero
   nonNegY <- A.cmp LLVM.CmpGE y $ LLVM.value LLVM.zero
   -- operands of distinct sign can never overflow
   distinctSign <- A.cmp LLVM.CmpNE nonNegX nonNegY
   -- both non-negative: a wrapped sum is smaller than x
   overflow <- A.cmp LLVM.CmpLT z x
   -- both negative: a wrapped sum is greater than x
   underflow <- A.cmp LLVM.CmpGT z x
   maxBnd <- replicate $ valueOf maxBound
   minBnd <- replicate $ valueOf minBound
   maxSat <- LLVM.select overflow maxBnd z
   minSat <- LLVM.select underflow minBnd z
   saturated <- LLVM.select nonNegX maxSat minSat
   LLVM.select distinctSign z saturated

ssubSat x y = do
   z <- A.sub x y
   nonNegX <- A.cmp LLVM.CmpGE x $ LLVM.value LLVM.zero
   nonNegY <- A.cmp LLVM.CmpGE y $ LLVM.value LLVM.zero
   -- operands of the same sign can never overflow on subtraction
   sameSign <- A.cmp LLVM.CmpEQ nonNegX nonNegY
   -- x non-negative, y negative: a wrapped difference is smaller than x
   overflow <- A.cmp LLVM.CmpLT z x
   -- x negative, y non-negative: a wrapped difference is greater than x
   underflow <- A.cmp LLVM.CmpGT z x
   maxBnd <- replicate $ valueOf maxBound
   minBnd <- replicate $ valueOf minBound
   maxSat <- LLVM.select overflow maxBnd z
   minSat <- LLVM.select underflow minBnd z
   saturated <- LLVM.select nonNegX maxSat minSat
   LLVM.select sameSign z saturated


-- | Saturating signed addition like 'saddSat',
-- but the decision whether the wrapped sum can be used
-- is combined into a single boolean by a logical @or@
-- such that only one final 'LLVM.select' on the numeric type is needed.
--
-- The sum @z@ is kept if the operands have distinct signs
-- (overflow impossible) or if no wrap-around is observed.
-- Otherwise the bound matching the sign of @x@ is returned.
saddSatLogical ::
   (IsInteger v, CmpRet v, Replicate v,
    ShapeOf v ~ shape, LLVM.ShapedType shape Bool ~ bv,
    ShapeOf bv ~ shape, CmpRet bv, IsInteger bv,
    Scalar v ~ a, IsConst a, Bounded a) =>
   Value v -> Value v -> CodeGenFunction r (Value v)
saddSatLogical x y = do
   z <- A.add x y
   nonNegX <- A.cmp LLVM.CmpGE x $ LLVM.value LLVM.zero
   nonNegY <- A.cmp LLVM.CmpGE y $ LLVM.value LLVM.zero
   distinctSign <- A.cmp LLVM.CmpNE nonNegX nonNegY
   minBnd <- replicate $ valueOf minBound
   maxBnd <- replicate $ valueOf maxBound
   -- the bound to saturate to depends only on the sign of x
   bounds <- LLVM.select nonNegX maxBnd minBnd
   -- x >= 0: without wrap-around the sum cannot be below y
   -- (CmpGE rather than CmpGT, because x may be zero)
   noOverflow <- A.cmp LLVM.CmpGE z y
   -- x < 0: without wrap-around the sum cannot be above y
   noUnderflow <- A.cmp LLVM.CmpLE z y
   noXflow <- LLVM.select nonNegX noOverflow noUnderflow
   -- The flag must be True when the sum is valid.
   -- Formerly the overflow flag itself was or-ed in here,
   -- which returned the bound for every non-overflowing same-sign sum
   -- and the wrapped value for every overflowing one.
   correctSum <- A.or distinctSign noXflow
   LLVM.select correctSum z bounds