{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{- |
Support for unified handling of scalars and vectors.

Attention:
The rounding and fraction functions only work
for floating point values with maximum magnitude of @maxBound :: Int32@.
This way we save expensive handling of possibly seldom cases.
-}
module LLVM.Extra.ScalarOrVector (
   Fraction (truncate, fraction),
   signedFraction,
   addToPhase,
   incPhase,

   Scalar,
   Replicate (replicate, replicateConst),
   replicateOf,
   Real (min, max, abs, signum),
   PseudoModule (scale, scaleConst),
   IntegerConstant(constFromInteger),
   RationalConstant(constFromRational),
   TranscendentalConstant(constPi),
   ) where

import LLVM.Extra.Vector (Element, Size, )
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Extension.X86 as X86
import qualified LLVM.Extra.Extension as Ext
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.ArithmeticPrivate as A

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal (D1, )

import qualified LLVM.Core as LLVM
import LLVM.Core
   (Value, ConstValue, valueOf, constOf,
    Vector, FP128,
    IsConst, IsFloating,
    CodeGenFunction, )

import Control.Monad.HT ((<=<), )

import qualified Data.NonEmpty as NonEmpty
import Data.Word (Word8, Word16, Word32, Word64, )
import Data.Int (Int8, Int16, Int32, Int64, )

import Prelude hiding
   (Real, replicate, min, max, abs, truncate, floor, round, )


{- |
Floating point types that can be split
into an integral and a fractional part.
See the module header for the supported range of magnitudes.
-}
class (Real a, IsFloating a) => Fraction a where
   -- | Integral part of the argument, still as a floating point value.
   truncate :: Value a -> CodeGenFunction r (Value a)
   -- | Fractional part of the argument.
   -- The generic implementation 'fractionGen' adds one
   -- whenever the signed fraction is negative.
   fraction :: Value a -> CodeGenFunction r (Value a)

{-
Each scalar method consists of a generic implementation
and an alternative based on an X86 extension, combined by 'mapAuto'.
The generic 'truncate' converts to 'Int32' and back.
-}
instance Fraction Float where
   truncate =
      mapAuto
         (LLVM.inttofp . flip asTypeOf (undefined :: Value Int32)
             <=< LLVM.fptoint)
         (Ext.with X86.roundss $ \roundOp x -> roundOp x (valueOf 3))
   fraction =
      mapAuto
         (\x ->
            Ext.run (fractionGen x) $
            Ext.with X86.cmpss $ \cmp ->
               fractionLogical
                  (\modus a b -> runScalar (uncurry (cmp modus)) (a, b))
                  x)
         (Ext.with X86.roundss $ \roundOp x ->
            roundOp x (valueOf 1) >>= A.sub x)

instance Fraction Double where
   truncate =
      mapAuto
         -- X86 only converts Double to Int32, it cannot target Int64
         (LLVM.inttofp . flip asTypeOf (undefined :: Value Int32)
             <=< LLVM.fptoint)
         (Ext.with X86.roundsd $ \roundOp x -> roundOp x (valueOf 3))
   fraction =
      mapAuto
         (\x ->
            Ext.run (fractionGen x) $
            Ext.with X86.cmpsd $ \cmp ->
               fractionLogical
                  (\modus a b -> runScalar (uncurry (cmp modus)) (a, b))
                  x)
{-
For Doubles it would be more efficient
to convert the lower 32 bit instead of the lower 64 bit,
since x86 supports only conversion from 32 bit natively.
            (Ext.with X86.cmpsd $ \cmp ->
                fractionLogical (\x y -> cmp x y >>= LLVM.bitcast )
-}
         (Ext.with X86.roundsd $ \roundOp x ->
            roundOp x (valueOf 1) >>= A.sub x)

instance
   (TypeNum.Positive n, Vector.Real a, IsFloating a, IsConst a) =>
      Fraction (Vector n a) where
   truncate = Vector.truncate
   fraction = Vector.fraction


{- |
The fraction has the same sign as the argument.
This is not particularly useful but fast on IEEE implementations.
-}
signedFraction ::
   (Fraction a) =>
   Value a -> CodeGenFunction r (Value a)
signedFraction x = truncate x >>= A.sub x

{-
Generic fraction:
take the signed fraction and add one if it is below zero.
-}
fractionGen ::
   (IntegerConstant v, Fraction v, LLVM.CmpRet v) =>
   Value v -> CodeGenFunction r (Value v)
fractionGen x = do
   signedPart <- signedFraction x
   nonNegative <- A.fcmp LLVM.FPOGE signedPart (LLVM.value LLVM.zero)
   shifted <- A.add signedPart (LLVM.value $ constFromInteger 1)
   LLVM.select nonNegative signedPart shifted

{-
Variant of 'fractionGen' that avoids the select:
The integer result of the \"less than zero\" comparison
is converted to floating point and subtracted from the signed fraction.
NOTE(review): presumably the comparison yields -1 for true,
such that the subtraction adds one - confirm against the X86 cmp wrappers.
-}
fractionLogical ::
   (Fraction a, LLVM.IsScalarOrVector a, LLVM.NumberOfElements a ~ D1,
    LLVM.IsInteger b, LLVM.IsScalarOrVector b, LLVM.NumberOfElements b ~ D1) =>
   (LLVM.FPPredicate -> Value a -> Value a -> CodeGenFunction r (Value b)) ->
   Value a -> CodeGenFunction r (Value a)
fractionLogical cmp x = do
   signedPart <- signedFraction x
   negativeFlag <- cmp LLVM.FPOLT signedPart (LLVM.value LLVM.zero)
   correction <- LLVM.inttofp negativeFlag
   A.sub signedPart correction


{- |
increment (first operand) may be negative,
phase must always be non-negative
-}
addToPhase ::
   (Fraction a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
addToPhase d p = A.add d p >>= fraction

{- |
both increment and phase must be non-negative
-}
incPhase ::
   (Fraction a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
incPhase d p = A.add d p >>= signedFraction


{- |
Element type of a vector type;
scalar types are mapped to themselves.
-}
type family Scalar vector :: *

type instance Scalar Float  = Float
type instance Scalar Double = Double
type instance Scalar FP128  = FP128
type instance Scalar Bool   = Bool
type instance Scalar Int8   = Int8
type instance Scalar Int16  = Int16
type instance Scalar Int32  = Int32
type instance Scalar Int64  = Int64
type instance Scalar Word8  = Word8
type instance Scalar Word16 = Word16
type instance Scalar Word32 = Word32
type instance Scalar Word64 = Word64
type instance Scalar (Vector n a) = a


{- |
Turn a scalar into a value of type @vector@.
For the scalar instances this is the identity.
-}
class Replicate vector where
   -- | an alternative is using the 'Vector.Constant' vector type
   replicate :: Value (Scalar vector) -> CodeGenFunction r (Value vector)
   replicateConst :: ConstValue (Scalar vector) -> ConstValue vector

instance Replicate Float where
   replicate = return
   replicateConst = id

instance Replicate Double where
   replicate = return
   replicateConst = id

instance Replicate FP128 where
   replicate = return
   replicateConst = id

instance Replicate Bool where
   replicate = return
   replicateConst = id

instance Replicate Int8 where
   replicate = return
   replicateConst = id

instance Replicate Int16 where
   replicate = return
   replicateConst = id

instance Replicate Int32 where
   replicate = return
   replicateConst = id

instance Replicate Int64 where
   replicate = return
   replicateConst = id

instance Replicate Word8 where
   replicate = return
   replicateConst = id

instance Replicate Word16 where
   replicate = return
   replicateConst = id

instance Replicate Word32 where
   replicate = return
   replicateConst = id

instance Replicate Word64 where
   replicate = return
   replicateConst = id

instance
   (TypeNum.Positive n, LLVM.IsPrimitive a) =>
      Replicate (Vector n a) where
   {- crashes LLVM-2.5, seems to be fixed in LLVM-2.6 -}
   -- Put the scalar into a one-element vector
   -- and shuffle it with an all-zero index mask.
   replicate x =
      singleton x >>= \single ->
         LLVM.shufflevector single (LLVM.value LLVM.undef) LLVM.zero
   {- crashes LLVM-2.5
   replicate x = do
      v <- LLVM.insertelement (LLVM.value LLVM.undef) x (valueOf 1)
      LLVM.shufflevector v (LLVM.value LLVM.undef) (constVector $ repeat $ LLVM.constOf 1)
   -}
   {- the (repeat zero) is also converted to 'zeroinitializer'
      and crashes LLVM compiler
         (constVector $ repeat LLVM.zero)
   -}
   {-
   replicate = Vector.replicate
   -}
   replicateConst x = LLVM.constCyclicVector $ NonEmpty.Cons x []

{-
Insert the scalar at index 0 of an otherwise undefined one-element vector.
-}
singleton ::
   (LLVM.IsPrimitive a) =>
   Value a -> CodeGenFunction r (Value (Vector D1 a))
singleton scalar =
   LLVM.insertelement (LLVM.value LLVM.undef) scalar (valueOf 0)

{- |
Replicated constant, created from a Haskell value.
-}
replicateOf ::
   (IsConst (Scalar v), Replicate v) =>
   Scalar v -> Value v
replicateOf x = LLVM.value (replicateConst (LLVM.constOf x))


{- |
Ordering related operations
that are available for both scalars and vectors.
-}
class (LLVM.IsArithmetic a) => Real a where
   min :: Value a -> Value a -> CodeGenFunction r (Value a)
   max :: Value a -> Value a -> CodeGenFunction r (Value a)
   abs :: Value a -> CodeGenFunction r (Value a)
   signum :: Value a -> CodeGenFunction r (Value a)

instance Real Float where
   min = zipAutoWith A.min X86.minss
   max = zipAutoWith A.max X86.maxss
   abs = mapAuto A.abs X86.absss
   -- abs x = max x =<< LLVM.neg x
   -- abs x = A.abs
   signum = A.signum

instance Real Double where
   min = zipAutoWith A.min X86.minsd
   max = zipAutoWith A.max X86.maxsd
   abs = mapAuto A.abs X86.abssd
   signum = A.signum

instance Real FP128 where
   min = A.min
   max = A.max
   abs = A.abs
   -- The constants -1 and 1 are obtained
   -- by converting Int8 values to floating point.
   signum x = do
      negOne <- LLVM.inttofp (LLVM.valueOf (-1 :: Int8))
      posOne <- LLVM.inttofp (LLVM.valueOf ( 1 :: Int8))
      A.signumGen negOne posOne x


infixl 1 `mapAuto`

{- |
There are functions that are intended for processing scalars
but have formally vector input and output.
This function breaks vector function down to a scalar function
by accessing the lowest vector element.
-}
runScalar ::
   (Vector.C v, Vector.C w, Size v ~ Size w) =>
   (v -> CodeGenFunction r w) ->
   (Element v -> CodeGenFunction r (Element w))
runScalar op a = do
   packed <- Vector.insert (valueOf 0) a Class.undefTuple
   processed <- op packed
   Vector.extract (valueOf 0) processed

{-
Combine a generic scalar implementation (first argument)
with an extension based vector operation (second argument)
that is applied to the scalar by means of 'runScalar'.
NOTE(review): presumably 'Ext.run' chooses the extension
when it is available and the generic code otherwise - confirm.
-}
mapAuto ::
   (Vector.C v, Vector.C w, Size v ~ Size w) =>
   (Element v -> CodeGenFunction r (Element w)) ->
   Ext.T (v -> CodeGenFunction r w) ->
   (Element v -> CodeGenFunction r (Element w))
mapAuto f g a =
   Ext.run (f a) (Ext.with g (\op -> runScalar op a))

{-
Two-argument counterpart of 'mapAuto',
implemented by pairing the arguments.
-}
zipAutoWith ::
   (Vector.C u, Vector.C v, Vector.C w,
    Size u ~ Size v, Size v ~ Size w) =>
   (Element u -> Element v -> CodeGenFunction r (Element w)) ->
   Ext.T (u -> v -> CodeGenFunction r w) ->
   (Element u -> Element v -> CodeGenFunction r (Element w))
zipAutoWith f g a b =
   mapAuto (uncurry f) (fmap uncurry g) (a, b)


instance Real Int8 where
   min = A.min
   max = A.max
   signum = A.signum
   abs = A.abs

instance Real Int16 where
   min = A.min
   max = A.max
   signum = A.signum
   abs = A.abs

instance Real Int32 where
   min = A.min
   max = A.max
   signum = A.signum
   abs = A.abs

instance Real Int64 where
   min = A.min
   max = A.max
   signum = A.signum
   abs = A.abs

-- For the unsigned types 'abs' leaves the value as it is.
instance Real Word8 where
   min = A.min
   max = A.max
   signum = A.signum
   abs = return

instance Real Word16 where
   min = A.min
   max = A.max
   signum = A.signum
   abs = return

instance Real Word32 where
   min = A.min
   max = A.max
   signum = A.signum
   abs = return

instance Real Word64 where
   min = A.min
   max = A.max
   signum = A.signum
   abs = return

instance (TypeNum.Positive n, Vector.Real a) => Real (Vector n a) where
   min = Vector.min
   max = Vector.max
   abs = Vector.abs
   signum = Vector.signum


{- |
Multiplication of a scalar with a scalar or vector.
-}
class
   (LLVM.IsArithmetic (Scalar v), LLVM.IsArithmetic v) =>
      PseudoModule v where
   scale ::
      (a ~ Scalar v) =>
      Value a -> Value v -> CodeGenFunction r (Value v)
   scaleConst ::
      (a ~ Scalar v) =>
      ConstValue a -> ConstValue v -> CodeGenFunction r (ConstValue v)

instance PseudoModule Word8 where
   scale = LLVM.mul
   scaleConst = LLVM.mul

instance PseudoModule Word16 where
   scale = LLVM.mul
   scaleConst = LLVM.mul

instance PseudoModule Word32 where
   scale = LLVM.mul
   scaleConst = LLVM.mul

instance PseudoModule Word64 where
   scale = LLVM.mul
   scaleConst = LLVM.mul

instance PseudoModule Int8 where
   scale = LLVM.mul
   scaleConst = LLVM.mul

instance PseudoModule Int16 where
   scale = LLVM.mul
   scaleConst = LLVM.mul

instance PseudoModule Int32 where
   scale = LLVM.mul
   scaleConst = LLVM.mul

instance PseudoModule Int64 where
   scale = LLVM.mul
   scaleConst = LLVM.mul

instance PseudoModule Float where
   scale = LLVM.mul
   scaleConst = LLVM.mul

instance PseudoModule Double where
   scale = LLVM.mul
   scaleConst = LLVM.mul

instance
   (LLVM.IsArithmetic a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
      PseudoModule (Vector n a) where
   -- The scalar factor is replicated to a vector of the type of @v@
   -- and then multiplied with @v@.
   scale a v = do
      factor <- replicate a
      LLVM.mul (factor `asTypeOf` v) v
   scaleConst a v =
      LLVM.mul (replicateConst a `asTypeOf` v) v


{- |
Constants that are created from a Haskell 'Integer'.
-}
class (LLVM.IsConst a) => IntegerConstant a where
   constFromInteger :: Integer -> ConstValue a

instance IntegerConstant Word8 where
   constFromInteger n = constOf (fromInteger n)

instance IntegerConstant Word16 where
   constFromInteger n = constOf (fromInteger n)

instance IntegerConstant Word32 where
   constFromInteger n = constOf (fromInteger n)

instance IntegerConstant Word64 where
   constFromInteger n = constOf (fromInteger n)

instance IntegerConstant Int8 where
   constFromInteger n = constOf (fromInteger n)

instance IntegerConstant Int16 where
   constFromInteger n = constOf (fromInteger n)

instance IntegerConstant Int32 where
   constFromInteger n = constOf (fromInteger n)

instance IntegerConstant Int64 where
   constFromInteger n = constOf (fromInteger n)

instance IntegerConstant Float where
   constFromInteger n = constOf (fromInteger n)

instance IntegerConstant Double where
   constFromInteger n = constOf (fromInteger n)

instance
   (IntegerConstant a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
      IntegerConstant (Vector n a) where
   constFromInteger n = replicateConst (constFromInteger n)


{- |
Constants that are created from a Haskell 'Rational'.
-}
class (IntegerConstant a) => RationalConstant a where
   constFromRational :: Rational -> ConstValue a

instance RationalConstant Float where
   constFromRational q = constOf (fromRational q)

instance RationalConstant Double where
   constFromRational q = constOf (fromRational q)

instance
   (RationalConstant a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
      RationalConstant (Vector n a) where
   constFromRational q = replicateConst (constFromRational q)


{- |
Types that provide the constant pi.
-}
class (RationalConstant a) => TranscendentalConstant a where
   constPi :: ConstValue a

instance TranscendentalConstant Float where
   constPi = constOf pi

instance TranscendentalConstant Double where
   constPi = constOf pi

instance
   (TranscendentalConstant a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
      TranscendentalConstant (Vector n a) where
   constPi = replicateConst constPi