{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{- |
Support for unified handling of scalars and vectors.

Attention:
The rounding and fraction functions only work
for floating point values with a maximum magnitude of @maxBound :: Int32@,
because the generic scalar fallback of 'truncate'
converts through 'Int32' (for 'Float' as well as for 'Double').
This way we avoid expensive handling of cases that are presumably rare.
-}
module LLVM.Extra.ScalarOrVector (
   Fraction (truncate, fraction),
   signedFraction,
   addToPhase,
   incPhase,
   Replicate (replicate, replicateConst),
   replicateOf,
   Real (min, max, abs),
   ) where

import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Extension.X86 as X86
import qualified LLVM.Extra.Extension as Ext
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.Arithmetic as A

import qualified Data.TypeLevel.Num as TypeNum
import Data.TypeLevel.Num (D1, )

import qualified LLVM.Core as LLVM
import LLVM.Core
   (Value, ConstValue, valueOf, Vector, FP128,
    IsConst, IsFloating,
    CodeGenFunction, )

import Control.Monad.HT ((<=<), )

import Data.Word (Word8, Word16, Word32, Word64, )
import Data.Int  (Int8, Int16, Int32, Int64, )

import Prelude hiding (Real, replicate, min, max, abs, truncate, floor, round, )


{-
Disabled alternative design (not compiled):
a Fraction class that abstracts over the float/int conversion pair.

class (IsFloating frac, IsInteger int,
       LLVM.NumberOfElements n frac, LLVM.NumberOfElements n int) =>
      Fraction n int frac | frac -> int, frac -> n, int -> n where
   fptosi :: Value frac -> CodeGenFunction r (Value int)
   fptosi = LLVM.fptosi
   sitofp :: Value int -> CodeGenFunction r (Value frac)
   sitofp = LLVM.sitofp
-}

{-
Another disabled alternative design (not compiled).

class (IsFloating frac) => Fraction int frac | frac -> int where
   fptosi :: Value frac -> CodeGenFunction r (Value int)
   sitofp :: Value int -> CodeGenFunction r (Value frac)

instance Fraction Int32 Float where
   fptosi = LLVM.fptosi
   sitofp = LLVM.sitofp

instance Fraction Int64 Double where
   fptosi = LLVM.fptosi
   sitofp = LLVM.sitofp

instance (TypeNum.Pos n) => Fraction (Vector n Int32) (Vector n Float) where
   fptosi = LLVM.fptosi
   sitofp = LLVM.sitofp

instance (TypeNum.Pos n) => Fraction (Vector n Int64) (Vector n Double) where
   fptosi = LLVM.fptosi
   sitofp = LLVM.sitofp
-}


{- |
Rounding functions for floating point scalars and vectors.

For 'Float' and 'Double' the instances pair SSE-based code
(operating on the lowest vector element)
with generic LLVM code, and let 'mapAuto' \/ 'Ext.run' choose between them.
The generic code is restricted to magnitudes up to @maxBound :: Int32@.
Vector instances delegate to "LLVM.Extra.Vector".
-}
class (Real a, IsFloating a) => Fraction a where
   -- | Round towards zero.
   truncate :: Value a -> CodeGenFunction r (Value a)
   -- | Fractional part @x - floor x@; non-negative also for negative @x@.
   fraction :: Value a -> CodeGenFunction r (Value a)

instance Fraction Float where
   -- Extension code: roundss with immediate 3 (round towards zero).
   -- Generic code: fptosi to Int32 and sitofp back, which truncates as well;
   -- the asTypeOf only pins the intermediate integer type to Int32.
   truncate =
      mapAuto
         (LLVM.sitofp . flip asTypeOf (undefined :: Value Int32) <=< LLVM.fptosi)
         (Ext.with X86.roundss $ \round x -> round x (valueOf 3))
   -- Extension code: x - roundss x with immediate 1 (round down), i.e. x - floor x.
   -- Generic code: fractionLogical with the cmpss compare
   -- (adapted to scalars via runScalar),
   -- or fractionGen if the compare extension is not selected by Ext.run.
   fraction =
      (\x ->
         fractionGen x
         `Ext.run`
         (Ext.with X86.cmpss $ \cmp ->
            fractionLogical
               (\modus -> curry (runScalar (uncurry (cmp modus))))
               x))
      `mapAuto`
      (Ext.with X86.roundss $ \round x -> A.sub x =<< round x (valueOf 1))

instance Fraction Double where
   -- Same structure as the Float instance, using roundsd.
   truncate =
      mapAuto
         -- X86 only converts Double to Int32, it cannot target Int64
         (LLVM.sitofp . flip asTypeOf (undefined :: Value Int32) <=< LLVM.fptosi)
         (Ext.with X86.roundsd $ \round x -> round x (valueOf 3))
   -- Same structure as the Float instance, using cmpsd and roundsd.
   --
   -- NOTE: For Doubles it would be more efficient to convert only the lower
   -- 32 bits of the compare mask instead of the lower 64 bits,
   -- since x86 supports only conversion from 32 bit natively.
   -- Unfinished sketch of that idea:
   --    Ext.with X86.cmpsd $ \cmp ->
   --       fractionLogical (\x y -> cmp x y >>= LLVM.bitcastUnify ...)
   fraction =
      (\x ->
         fractionGen x
         `Ext.run`
         (Ext.with X86.cmpsd $ \cmp ->
            fractionLogical
               (\modus -> curry (runScalar (uncurry (cmp modus))))
               x))
      `mapAuto`
      (Ext.with X86.roundsd $ \round x -> A.sub x =<< round x (valueOf 1))

-- Vectors delegate to the implementations in LLVM.Extra.Vector.
instance (TypeNum.Pos n, Vector.Real a, IsFloating a, IsConst a) =>
      Fraction (Vector n a) where
   truncate = Vector.truncate
   fraction = Vector.fraction


{- |
Signed fractional part @x - truncate x@.
The fraction has the same sign as the argument.
This is not particularly useful, but fast on IEEE implementations.
-}
signedFraction ::
   (Fraction a) =>
   Value a -> CodeGenFunction r (Value a)
signedFraction x =
   A.sub x =<< truncate x

{- |
Generic implementation of 'fraction' that needs no CPU extension:
compute the 'signedFraction' and, where it is negative, add 1.
The choice between the two is made with 'LLVM.select'
(element-wise if @v@ is a vector type).
-}
fractionGen ::
   (Num a, Fraction v, Replicate a v, IsConst a, LLVM.CmpRet v b) =>
   Value v -> CodeGenFunction r (Value v)
fractionGen x =
   do xf <- signedFraction x
      b <- A.fcmp LLVM.FPOGE xf (LLVM.value LLVM.zero)
      LLVM.select b xf =<< A.add xf (replicateOf 1)

{- |
Implementation of 'fraction' for single-element types
that avoids 'LLVM.select'.
It needs a comparison that returns an integer mask;
the instances above pass an adapted SSE compare (@cmpss@ \/ @cmpsd@).

Assuming the comparison yields -1 (all bits set) for true and 0 for false,
subtracting the converted mask adds 1 exactly
where the signed fraction is negative.
-}
fractionLogical ::
   (Fraction a, LLVM.NumberOfElements D1 a,
    LLVM.IsInteger b, LLVM.NumberOfElements D1 b) =>
   (LLVM.FPPredicate -> Value a -> Value a -> CodeGenFunction r (Value b)) ->
   Value a -> CodeGenFunction r (Value a)
fractionLogical cmp x =
   do xf <- signedFraction x
      b <- cmp LLVM.FPOLT xf (LLVM.value LLVM.zero)
      A.sub xf =<< LLVM.sitofp b


{- |
Advance the phase (second operand) by the increment (first operand)
and wrap the sum with 'fraction'.
The increment may be negative;
the phase must always be non-negative.
-}
addToPhase ::
   (Fraction a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
addToPhase d p =
   fraction =<< A.add d p

{- |
Like 'addToPhase', but both increment and phase must be non-negative.
Then their sum is non-negative,
so the cheaper 'signedFraction' yields the same result as 'fraction'.
-}
incPhase ::
   (Fraction a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
incPhase d p =
   signedFraction =<< A.add d p


{- |
Turn a scalar into a scalar-or-vector value.
For scalar types both methods are the identity;
for vector types every element receives a copy of the scalar.
The functional dependency @vector -> scalar@
lets the element type be inferred from the target type.
-}
class Replicate scalar vector | vector -> scalar where
   replicate :: Value scalar -> CodeGenFunction r (Value vector)
   replicateConst :: ConstValue scalar -> ConstValue vector

instance Replicate Float  Float  where replicate = return; replicateConst = id;
instance Replicate Double Double where replicate = return; replicateConst = id;
instance Replicate FP128  FP128  where replicate = return; replicateConst = id;
instance Replicate Bool   Bool   where replicate = return; replicateConst = id;
instance Replicate Int8   Int8   where replicate = return; replicateConst = id;
instance Replicate Int16  Int16  where replicate = return; replicateConst = id;
instance Replicate Int32  Int32  where replicate = return; replicateConst = id;
instance Replicate Int64  Int64  where replicate = return; replicateConst = id;
instance Replicate Word8  Word8  where replicate = return; replicateConst = id;
instance Replicate Word16 Word16 where replicate = return; replicateConst = id;
instance Replicate Word32 Word32 where replicate = return; replicateConst = id;
instance Replicate Word64 Word64 where replicate = return; replicateConst = id;

instance (TypeNum.Pos n, LLVM.IsPrimitive a) => Replicate a (Vector n a) where
   -- Put the scalar into element 0 of a one-element vector,
   -- then shuffle with an all-zero mask,
   -- so that every result element takes that element 0.
   -- This crashes LLVM-2.5, but seems to be fixed in LLVM-2.6.
   replicate x =
      do v <- singleton x
         LLVM.shufflevector v (LLVM.value LLVM.undef) LLVM.zero
   {- Disabled variant, crashes LLVM-2.5:

   replicate x =
      do v <- LLVM.insertelement (LLVM.value LLVM.undef) x (valueOf 1)
         LLVM.shufflevector v (LLVM.value LLVM.undef)
            (constVector $ repeat $ LLVM.constOf 1)
   -}
   {- Using (constVector $ repeat LLVM.zero) as shuffle mask does not help:
      the (repeat zero) is also converted to 'zeroinitializer'
      and crashes the LLVM compiler.
   -}
   {- Disabled variant:
   replicate = Vector.replicate
   -}
   -- NOTE(review): this relies on LLVM.constVector extending the
   -- one-element list to the vector length n; the llvm binding documents
   -- that it replicates the list. Confirm for the binding version in use.
   replicateConst x = LLVM.constVector [x];

-- | Put a scalar into element 0 of an otherwise undefined one-element vector.
singleton ::
   (LLVM.IsPrimitive a) =>
   Value a -> CodeGenFunction r (Value (Vector D1 a))
singleton x =
   LLVM.insertelement (LLVM.value LLVM.undef) x (valueOf 0)

{- |
Constant scalar-or-vector value
whose elements all equal the given Haskell value.
-}
replicateOf ::
   (IsConst a, Replicate a v) =>
   a -> Value v
replicateOf a =
   LLVM.value (replicateConst (LLVM.constOf a))


{- |
Minimum, maximum and absolute value for scalars and vectors.

'Float' and 'Double' pair the operations from "LLVM.Extra.Extension.X86"
(applied to the lowest vector element) with the generic operations
from "LLVM.Extra.Arithmetic", via 'mapAuto' \/ 'zipAutoWith'.
Unsigned integer types implement 'abs' as the identity.
-}
class (LLVM.IsArithmetic a) => Real a where
   min :: Value a -> Value a -> CodeGenFunction r (Value a)
   max :: Value a -> Value a -> CodeGenFunction r (Value a)
   abs :: Value a -> CodeGenFunction r (Value a)

instance Real Float where
   min = zipAutoWith A.min X86.minss
   max = zipAutoWith A.max X86.maxss
   abs = mapAuto A.abs X86.absss
   -- Disabled alternatives for abs:
   -- abs x = max x =<< LLVM.neg x
   -- abs x = A.abs

instance Real Double where
   min = zipAutoWith A.min X86.minsd
   max = zipAutoWith A.max X86.maxsd
   abs = mapAuto A.abs X86.abssd

-- Fixity for the infix use of mapAuto, as in the 'fraction' methods above.
infixl 1 `mapAuto`

{- |
Some extension functions are intended for processing scalars,
but formally take and return vectors.
'runScalar' turns such a vector function into a scalar function.
It inserts the scalar into element 0 of an undefined vector
(or of a tuple of vectors, as needed by 'zipAutoWith'),
applies the vector function,
and extracts element 0 of the result.
-}
runScalar ::
   (Vector.Access n a va, Vector.Access n b vb) =>
   (va -> CodeGenFunction r vb) ->
   (a -> CodeGenFunction r b)
runScalar op a =
   Vector.extract (valueOf 0)
      =<< op
      =<< Vector.insert (valueOf 0) a Class.undefTuple

{- |
Combine a generic scalar implementation @f@
with a vector-based extension @g@ that is run on the lowest element
via 'runScalar'.
'Ext.run' receives the generic code first and the extension code second.
Presumably it emits the extension code when the target supports it
and the generic code otherwise; see "LLVM.Extra.Extension".
-}
mapAuto ::
   (Vector.Access n a va, Vector.Access n b vb) =>
   (a -> CodeGenFunction r b) ->
   Ext.T (va -> CodeGenFunction r vb) ->
   (a -> CodeGenFunction r b)
mapAuto f g a =
   Ext.run (f a) $
   Ext.with g $ \op -> runScalar op a

{- |
Two-argument version of 'mapAuto'.
Both implementations are uncurried,
so that 'mapAuto' handles the argument pair as a single value.
-}
zipAutoWith ::
   (Vector.Access n a va, Vector.Access n b vb, Vector.Access n c vc) =>
   (a -> b -> CodeGenFunction r c) ->
   Ext.T (va -> vb -> CodeGenFunction r vc) ->
   (a -> b -> CodeGenFunction r c)
zipAutoWith f g =
   curry $ mapAuto (uncurry f) (fmap uncurry g)

instance Real FP128  where min = A.min; max = A.max; abs = A.abs;
instance Real Int8   where min = A.min; max = A.max; abs = A.abs;
instance Real Int16  where min = A.min; max = A.max; abs = A.abs;
instance Real Int32  where min = A.min; max = A.max; abs = A.abs;
instance Real Int64  where min = A.min; max = A.max; abs = A.abs;
-- Unsigned values are their own absolute value.
instance Real Word8  where min = A.min; max = A.max; abs = return;
instance Real Word16 where min = A.min; max = A.max; abs = return;
instance Real Word32 where min = A.min; max = A.max; abs = return;
instance Real Word64 where min = A.min; max = A.max; abs = return;

-- Vectors delegate to the implementations in LLVM.Extra.Vector.
instance (TypeNum.Pos n, Vector.Real a) => Real (Vector n a) where
   min = Vector.min
   max = Vector.max
   abs = Vector.abs