{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{- |
Support for unified handling of scalars and vectors.

Attention:
The rounding and fraction functions only work
for floating point values with maximum magnitude of @maxBound :: Int32@.
This way we save expensive handling of possibly rare cases.
-}
module LLVM.Extra.ScalarOrVector (
   Fraction (truncate, fraction),
   signedFraction,
   addToPhase,
   incPhase,
   truncateToInt,
   floorToInt,
   ceilingToInt,
   roundToIntFast,
   splitFractionToInt,
   Scalar,
   Replicate (replicate, replicateConst),
   replicateOf,
   Real (min, max, abs, signum),
   Saturated(addSat, subSat),
   PseudoModule (scale, scaleConst),
   IntegerConstant(constFromInteger),
   RationalConstant(constFromRational),
   TranscendentalConstant(constPi),
   ) where

import qualified LLVM.Extra.ScalarOrVectorPrivate as Priv
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.ArithmeticPrivate as A
import LLVM.Extra.ScalarOrVectorPrivate
          (Scalar, Replicate(replicate, replicateConst))

import qualified LLVM.Util.Intrinsic as Intrinsic
import qualified LLVM.Util.Proxy as LP

import qualified LLVM.Core as LLVM
import LLVM.Core
          (Value, ConstValue, constOf,
           CmpRet, CmpResult, ShapeOf,
           Vector, WordN(WordN), IntN(IntN), FP128,
           IsConst, IsInteger, IsFloating,
           CodeGenFunction, )

import qualified Type.Data.Num.Decimal as TypeNum

import Data.Word (Word8, Word16, Word32, Word64, Word)
import Data.Int  (Int8, Int16, Int32, Int64, )
import Data.Maybe (fromMaybe)

import Prelude hiding (Real, replicate, min, max, abs, truncate)


{- |
Floating point types (scalars or vectors)
that support truncation and extraction of the fractional part.
-}
class (Real a, IsFloating a) => Fraction a where
   -- | Remove the fractional part, keeping the floating point type.
   -- 'signedFraction' relies on this rounding towards zero.
   truncate :: Value a -> CodeGenFunction r (Value a)
   -- | Fractional part of the argument.
   -- NOTE(review): 'addToPhase' expects this to be non-negative
   -- (i.e. @x - floor x@, cf. '_fractionGen') - confirm against A.fraction.
   fraction :: Value a -> CodeGenFunction r (Value a)

instance Fraction Float where
   truncate = Intrinsic.truncate
   fraction = A.fraction

instance Fraction Double where
   truncate = Intrinsic.truncate
   fraction = A.fraction

instance
   (TypeNum.Positive n, Vector.Real a, IsFloating a, IsConst a) =>
      Fraction (Vector n a) where
   truncate = Vector.truncate
   fraction = Vector.fraction

{- |
The fraction has the same sign as the argument.
This is not particularly useful but fast on IEEE implementations.
-}
signedFraction :: (Fraction a) => Value a -> CodeGenFunction r (Value a)
signedFraction x = A.sub x =<< truncate x

{- |
Generic non-negative fraction built from 'signedFraction':
negative signed fractions are shifted into the non-negative range by adding one.
-}
_fractionGen ::
   (IntegerConstant v, Fraction v, CmpRet v) =>
   Value v -> CodeGenFunction r (Value v)
_fractionGen x = do
   xf <- signedFraction x
   b <- A.fcmp LLVM.FPOGE xf zero
   LLVM.select b xf =<< A.add xf (int 1)

{- |
Non-negative fraction computed without a @select@:
the comparison result for @xf < 0@ is converted to floating point
and subtracted from the signed fraction.
NOTE(review): this presumably relies on @cmp@ encoding \"true\" as -1
(all bits set, converted as signed), so that subtracting it adds one
to negative fractions - confirm for the comparison passed in.
-}
_fractionLogical ::
   (Fraction a, LLVM.IsPrimitive a, IsInteger b, LLVM.IsPrimitive b) =>
   (LLVM.FPPredicate -> Value a -> Value a -> CodeGenFunction r (Value b)) ->
   Value a -> CodeGenFunction r (Value a)
_fractionLogical cmp x = do
   xf <- signedFraction x
   b <- cmp LLVM.FPOLT xf zero
   A.sub xf =<< LLVM.inttofp b

{- |
Add an increment to a phase and wrap the sum using 'fraction'.
The increment (first operand) may be negative,
the phase must always be non-negative.
-}
addToPhase :: (Fraction a) => Value a -> Value a -> CodeGenFunction r (Value a)
addToPhase d p = fraction =<< A.add d p

{- |
Like 'addToPhase' but wraps using 'signedFraction'.
Both increment and phase must be non-negative,
because 'signedFraction' keeps the sign of the sum.
-}
incPhase :: (Fraction a) => Value a -> Value a -> CodeGenFunction r (Value a)
incPhase d p = signedFraction =<< A.add d p

{- |
Convert a floating point value to an integer, rounding towards zero
(plain 'LLVM.fptoint' conversion).
-}
truncateToInt ::
   (IsFloating a, IsInteger i, ShapeOf a ~ ShapeOf i) =>
   Value a -> CodeGenFunction r (Value i)
truncateToInt = LLVM.fptoint

{- |
Rounds to the nearest integer.
For numbers of the form @n+0.5@,
we choose one of the neighbouring integers
such that the overall implementation is most efficient.

Implementation: add @0.5@ to positive and @-0.5@ to other values,
then truncate towards zero.
-}
roundToIntFast ::
   (IsFloating a, RationalConstant a, CmpRet a,
    IsInteger i, IntegerConstant i, CmpRet i,
    CmpResult a ~ CmpResult i, ShapeOf a ~ ShapeOf i) =>
   Value a -> CodeGenFunction r (Value i)
roundToIntFast x = do
   pos <- A.cmp LLVM.CmpGT x zero
   truncateToInt =<< A.add x =<< LLVM.select pos (ratio 0.5) (ratio (-0.5))

{- |
Round towards negative infinity and convert to an integer.
Truncates first and subtracts one
wherever the truncated value lies above the input.
-}
floorToInt ::
   (IsFloating a, CmpRet a,
    IsInteger i, IntegerConstant i, CmpRet i,
    CmpResult a ~ CmpResult i, ShapeOf a ~ ShapeOf i) =>
   Value a -> CodeGenFunction r (Value i)
floorToInt x = do
   i <- truncateToInt x
   lt <- A.cmp LLVM.CmpLT x =<< LLVM.inttofp i
   A.sub i =<< LLVM.select lt (int 1) (int 0)

{- |
Split a value into its 'floorToInt' integer part
and the remaining non-negative floating point fraction @x - floor x@.
-}
splitFractionToInt ::
   (IsFloating a, CmpRet a,
    IsInteger i, IntegerConstant i, CmpRet i,
    CmpResult a ~ CmpResult i, ShapeOf a ~ ShapeOf i) =>
   Value a -> CodeGenFunction r (Value i, Value a)
splitFractionToInt x = do
   i <- floorToInt x
   frac <- A.sub x =<< LLVM.inttofp i
   return (i, frac)

{- |
Round towards positive infinity and convert to an integer.
Truncates first and adds one
wherever the truncated value lies below the input.
-}
ceilingToInt ::
   (IsFloating a, CmpRet a,
    IsInteger i, IntegerConstant i, CmpRet i,
    CmpResult a ~ CmpResult i, ShapeOf a ~ ShapeOf i) =>
   Value a -> CodeGenFunction r (Value i)
ceilingToInt x = do
   i <- truncateToInt x
   gt <- A.cmp LLVM.CmpGT x =<< LLVM.inttofp i
   A.add i =<< LLVM.select gt (int 1) (int 0)


-- | The zero constant of any LLVM type, as a value.
zero :: (LLVM.IsType a) => Value a
zero = LLVM.value LLVM.zero

-- | An integer constant of a scalar or vector type, as a value.
int :: (IntegerConstant a) => Integer -> Value a
int = LLVM.value . constFromInteger

-- | A rational constant of a scalar or vector type, as a value.
ratio :: (RationalConstant a) => Rational -> Value a
ratio = LLVM.value . constFromRational


-- | Constant value with the given Haskell scalar replicated via 'replicateConst'.
replicateOf :: (IsConst (Scalar v), Replicate v) => Scalar v -> Value v
replicateOf = LLVM.value . replicateConst . LLVM.constOf


-- | Comparison-based and sign-related operations on scalars and vectors.
class (LLVM.IsArithmetic a) => Real a where
   min :: Value a -> Value a -> CodeGenFunction r (Value a)
   max :: Value a -> Value a -> CodeGenFunction r (Value a)
   abs :: Value a -> CodeGenFunction r (Value a)
   signum :: Value a -> CodeGenFunction r (Value a)

instance Real Float where
   min = Intrinsic.min
   max = Intrinsic.max
   abs = Intrinsic.abs
   signum = A.signum

instance Real Double where
   min = Intrinsic.min
   max = Intrinsic.max
   abs = Intrinsic.abs
   signum = A.signum

instance Real FP128 where
   min = Intrinsic.min
   max = Intrinsic.max
   abs = Intrinsic.abs
   -- NOTE(review): the constants -1 and 1 are obtained by converting
   -- Int8 values, presumably because FP128 has no direct Haskell literal.
   signum x = do
      minusOne <- LLVM.inttofp $ LLVM.valueOf (-1 :: Int8)
      one <- LLVM.inttofp $ LLVM.valueOf ( 1 :: Int8)
      A.signumGen minusOne one x

instance Real Int    where min = A.min; max = A.max; signum = A.signum; abs = A.abs
instance Real Int8   where min = A.min; max = A.max; signum = A.signum; abs = A.abs
instance Real Int16  where min = A.min; max = A.max; signum = A.signum; abs = A.abs
instance Real Int32  where min = A.min; max = A.max; signum = A.signum; abs = A.abs
instance Real Int64  where min = A.min; max = A.max; signum = A.signum; abs = A.abs
-- Unsigned values are their own absolute value.
instance Real Word   where min = A.min; max = A.max; signum = A.signum; abs = return
instance Real Word8  where min = A.min; max = A.max; signum = A.signum; abs = return
instance Real Word16 where min = A.min; max = A.max; signum = A.signum; abs = return
instance Real Word32 where min = A.min; max = A.max; signum = A.signum; abs = return
instance Real Word64 where min = A.min; max = A.max; signum = A.signum; abs = return

instance (TypeNum.Positive n) => Real (IntN n) where
   min = A.min; max = A.max; abs = A.abs
   signum = A.signumGen (LLVM.valueOf $ IntN (-1)) (LLVM.valueOf $ IntN 1)

instance (TypeNum.Positive n) => Real (WordN n) where
   min = A.min; max = A.max; abs = return
   -- NOTE(review): the "minus one" result is left undefined,
   -- presumably because an unsigned value is never negative
   -- and A.signumGen would therefore never select it.
   signum = A.signumGen (LLVM.value LLVM.undef) (LLVM.valueOf $ WordN 1)

instance (TypeNum.Positive n, Vector.Real a) => Real (Vector n a) where
   min = Vector.min
   max = Vector.max
   abs = Vector.abs
   signum = Vector.signum


{- |
Saturating addition and subtraction on integer scalars and vectors.
The LLVM intrinsic is used when 'Intrinsic' provides one,
otherwise a fallback implementation from "LLVM.Extra.ScalarOrVectorPrivate".
-}
class (IsInteger a) => Saturated a where
   addSat, subSat :: Value a -> Value a -> CodeGenFunction r (Value a)

instance Saturated Int    where addSat = addSatProxy LP.Proxy; subSat = subSatProxy LP.Proxy
instance Saturated Int8   where addSat = addSatProxy LP.Proxy; subSat = subSatProxy LP.Proxy
instance Saturated Int16  where addSat = addSatProxy LP.Proxy; subSat = subSatProxy LP.Proxy
instance Saturated Int32  where addSat = addSatProxy LP.Proxy; subSat = subSatProxy LP.Proxy
instance Saturated Int64  where addSat = addSatProxy LP.Proxy; subSat = subSatProxy LP.Proxy
instance Saturated Word   where addSat = addSatProxy LP.Proxy; subSat = subSatProxy LP.Proxy
instance Saturated Word8  where addSat = addSatProxy LP.Proxy; subSat = subSatProxy LP.Proxy
instance Saturated Word16 where addSat = addSatProxy LP.Proxy; subSat = subSatProxy LP.Proxy
instance Saturated Word32 where addSat = addSatProxy LP.Proxy; subSat = subSatProxy LP.Proxy
instance Saturated Word64 where addSat = addSatProxy LP.Proxy; subSat = subSatProxy LP.Proxy

instance (TypeNum.Positive d) => Saturated (IntN d) where
   addSat = addSatProxy LP.Proxy; subSat = subSatProxy LP.Proxy

instance (TypeNum.Positive d) => Saturated (WordN d) where
   addSat = addSatProxy LP.Proxy; subSat = subSatProxy LP.Proxy

instance
   (TypeNum.Positive n, LLVM.IsPrimitive a, Saturated a,
    Bounded a, CmpRet a, IsConst a) =>
      Saturated (Vector n a) where
   addSat = addSatProxy LP.Proxy; subSat = subSatProxy LP.Proxy

{- |
Pick the signed or unsigned saturating operation according to the
signedness of the proxied type, preferring the intrinsic
('Intrinsic.maybeSAddSat' etc.) and falling back to the 'Priv' implementation
when the intrinsic is not available ('Nothing').
-}
addSatProxy, subSatProxy ::
   (IsInteger v, CmpRet v, Replicate v,
    ShapeOf v ~ shape, LLVM.ShapedType shape Bool ~ bv,
    ShapeOf bv ~ shape, CmpRet bv,
    Scalar v ~ a, IsConst a, Bounded a) =>
   LP.Proxy v -> Value v -> Value v -> CodeGenFunction r (Value v)
addSatProxy proxy =
   if LLVM.isSigned proxy
      then fromMaybe Priv.saddSat Intrinsic.maybeSAddSat
      else fromMaybe Priv.uaddSat Intrinsic.maybeUAddSat
subSatProxy proxy =
   if LLVM.isSigned proxy
      then fromMaybe Priv.ssubSat Intrinsic.maybeSSubSat
      else fromMaybe Priv.usubSat Intrinsic.maybeUSubSat


{- |
Multiplication of a scalar or vector by a scalar ('Scalar' of the type).
For vectors the scalar factor is first replicated to all elements.
-}
class (LLVM.IsArithmetic (Scalar v), LLVM.IsArithmetic v) => PseudoModule v where
   scale ::
      (a ~ Scalar v) =>
      Value a -> Value v -> CodeGenFunction r (Value v)
   scaleConst ::
      (a ~ Scalar v) =>
      ConstValue a -> ConstValue v -> CodeGenFunction r (ConstValue v)

instance PseudoModule Word   where scale = LLVM.mul; scaleConst = LLVM.mul
instance PseudoModule Word8  where scale = LLVM.mul; scaleConst = LLVM.mul
instance PseudoModule Word16 where scale = LLVM.mul; scaleConst = LLVM.mul
instance PseudoModule Word32 where scale = LLVM.mul; scaleConst = LLVM.mul
instance PseudoModule Word64 where scale = LLVM.mul; scaleConst = LLVM.mul
instance PseudoModule Int    where scale = LLVM.mul; scaleConst = LLVM.mul
instance PseudoModule Int8   where scale = LLVM.mul; scaleConst = LLVM.mul
instance PseudoModule Int16  where scale = LLVM.mul; scaleConst = LLVM.mul
instance PseudoModule Int32  where scale = LLVM.mul; scaleConst = LLVM.mul
instance PseudoModule Int64  where scale = LLVM.mul; scaleConst = LLVM.mul
instance PseudoModule Float  where scale = LLVM.mul; scaleConst = LLVM.mul
instance PseudoModule Double where scale = LLVM.mul; scaleConst = LLVM.mul

instance
   (LLVM.IsArithmetic a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
      PseudoModule (Vector n a) where
   scale a v = flip A.mul v =<< replicate a
   scaleConst a v = LLVM.mul (replicateConst a `asTypeOf` v) v


-- | Types whose constants can be created from an 'Integer' literal.
-- Vector constants replicate the scalar constant.
class (LLVM.IsConst a) => IntegerConstant a where
   constFromInteger :: Integer -> ConstValue a

instance IntegerConstant Word   where constFromInteger = constOf . fromInteger
instance IntegerConstant Word8  where constFromInteger = constOf . fromInteger
instance IntegerConstant Word16 where constFromInteger = constOf . fromInteger
instance IntegerConstant Word32 where constFromInteger = constOf . fromInteger
instance IntegerConstant Word64 where constFromInteger = constOf . fromInteger
instance IntegerConstant Int    where constFromInteger = constOf . fromInteger
instance IntegerConstant Int8   where constFromInteger = constOf . fromInteger
instance IntegerConstant Int16  where constFromInteger = constOf . fromInteger
instance IntegerConstant Int32  where constFromInteger = constOf . fromInteger
instance IntegerConstant Int64  where constFromInteger = constOf . fromInteger
instance IntegerConstant Float  where constFromInteger = constOf . fromInteger
instance IntegerConstant Double where constFromInteger = constOf . fromInteger

instance (TypeNum.Positive n) => IntegerConstant (WordN n) where
   constFromInteger = constOf . WordN

instance (TypeNum.Positive n) => IntegerConstant (IntN n) where
   constFromInteger = constOf . IntN

instance
   (IntegerConstant a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
      IntegerConstant (Vector n a) where
   constFromInteger = replicateConst . constFromInteger


-- | Types whose constants can be created from a 'Rational' literal.
-- Vector constants replicate the scalar constant.
class (IntegerConstant a) => RationalConstant a where
   constFromRational :: Rational -> ConstValue a

instance RationalConstant Float  where constFromRational = constOf . fromRational
instance RationalConstant Double where constFromRational = constOf . fromRational

instance
   (RationalConstant a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
      RationalConstant (Vector n a) where
   constFromRational = replicateConst . constFromRational


-- | Types providing the constant pi.
-- Vector constants replicate the scalar constant.
class (RationalConstant a) => TranscendentalConstant a where
   constPi :: ConstValue a

instance TranscendentalConstant Float  where constPi = constOf pi
instance TranscendentalConstant Double where constPi = constOf pi

instance
   (TranscendentalConstant a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
      TranscendentalConstant (Vector n a) where
   constPi = replicateConst constPi