module LLVM.Extra.ScalarOrVector (
Fraction (truncate, fraction),
signedFraction,
addToPhase,
incPhase,
Replicate (replicate, replicateConst),
replicateOf,
Real (min, max, abs),
) where
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Extension.X86 as X86
import qualified LLVM.Extra.Extension as Ext
import qualified LLVM.Extra.Arithmetic as A
import Data.TypeLevel.Num (D1, )
import qualified LLVM.Core as LLVM
import LLVM.Core
(Value, ConstValue, valueOf,
Vector, insertelement, constOf, constVector,
IsConst, IsFloating, IsPrimitive, IsPowerOf2,
CodeGenFunction,
FP128, )
import Control.Monad.HT ((<=<), )
import Data.Word (Word8, Word16, Word32, Word64, )
import Data.Int (Int8, Int16, Int32, Int64, )
import Prelude hiding (Real, replicate, min, max, abs, truncate, floor, round, )
-- | Floating-point types (scalars or vectors) whose values can be split
-- into an integral and a fractional part inside generated LLVM code.
--
-- 'truncate' drops the fractional part; the scalar instances in this
-- module implement it by rounding towards zero.  'fraction' yields the
-- non-negative fractional part; the scalar fallbacks compute it so that
-- the result lies in [0, 1) (up to floating-point rounding).
class (Real a, IsFloating a) => Fraction a where
   truncate, fraction :: Value a -> CodeGenFunction r (Value a)
-- | Scalar 'Float' instance.
--
-- Each method pairs a portable implementation with an X86-extension
-- variant (X86.roundss / X86.cmpss) via 'mapAuto'; the extension variant
-- works on element 0 of a vector (see 'runScalar').
instance Fraction Float where
   -- Portable path: convert to Int32 and back.  Extension path: roundss
   -- with immediate 3 (presumably round-towards-zero in the ROUNDSS
   -- immediate encoding -- confirm against the Intel manual).
   -- NOTE(review): the Int32 round trip cannot represent inputs outside
   -- the Int32 range; confirm how such inputs are meant to be handled.
   truncate =
      mapAuto
         (LLVM.sitofp . flip asTypeOf (undefined :: Value Int32) <=< LLVM.fptosi)
         (Ext.with X86.roundss $ \round x -> round x (valueOf 3))
   -- Portable path: 'fractionGen', or 'fractionLogical' driven by cmpss
   -- when that extension is usable.  Extension path: x minus
   -- roundss(x, 1), where immediate 1 presumably selects rounding down.
   fraction =
      (\x ->
         fractionGen x
         `Ext.run`
         (Ext.with X86.cmpss $ \cmp ->
             fractionLogical (\modus -> curry (runScalar (uncurry (cmp modus)))) x))
      `mapAuto`
      (Ext.with X86.roundss $ \round x ->
          A.sub x =<< round x (valueOf 1))
-- | Scalar 'Double' instance.
--
-- Each method pairs a portable implementation with an X86-extension
-- variant (X86.roundsd / X86.cmpsd) via 'mapAuto'; the extension variant
-- works on element 0 of a vector (see 'runScalar').
instance Fraction Double where
   -- Portable path: convert to an integer and back.  The integer type is
   -- Int64 rather than Int32, so that doubles with magnitude >= 2^31 (which
   -- a Double represents exactly and which are common for e.g. sample
   -- counters) do not overflow the fptosi conversion.
   -- NOTE(review): magnitudes >= 2^63 still exceed the Int64 range.
   -- Extension path: roundsd with immediate 3 (presumably round-towards-
   -- zero in the ROUNDSD immediate encoding).
   truncate =
      mapAuto
         (LLVM.sitofp . flip asTypeOf (undefined :: Value Int64) <=< LLVM.fptosi)
         (Ext.with X86.roundsd $ \round x -> round x (valueOf 3))
   -- Portable path: 'fractionGen', or 'fractionLogical' driven by cmpsd
   -- when that extension is usable.  Extension path: x minus
   -- roundsd(x, 1), where immediate 1 presumably selects rounding down.
   fraction =
      (\x ->
         fractionGen x
         `Ext.run`
         (Ext.with X86.cmpsd $ \cmp ->
             fractionLogical (\modus -> curry (runScalar (uncurry (cmp modus)))) x))
      `mapAuto`
      (Ext.with X86.roundsd $ \round x ->
          A.sub x =<< round x (valueOf 1))
-- | Vectors of floating-point elements delegate to 'Vector.fraction' and
-- 'Vector.truncate' from "LLVM.Extra.Vector".
instance (LLVM.IsPowerOf2 n, Vector.Real a, IsFloating a, IsConst a) =>
         Fraction (Vector n a) where
   fraction v = Vector.fraction v
   truncate v = Vector.truncate v
-- | Fractional part that keeps the sign of its input: @x - truncate x@.
--
-- Assuming 'truncate' rounds towards zero (as the scalar instances here
-- do), negative inputs give results in (-1, 0].  Use 'fraction' when a
-- non-negative result is required.
signedFraction ::
   (Fraction a) =>
   Value a -> CodeGenFunction r (Value a)
signedFraction x = do
   integral <- truncate x
   A.sub x integral
-- | Generic non-negative fractional part built from 'signedFraction'.
--
-- Computes @s = signedFraction x@ and selects @s@ when @s >= 0@
-- (ordered comparison), otherwise @s + 1@.
fractionGen ::
   (Num a, Fraction v, Replicate a v, IsConst a, LLVM.CmpRet v b) =>
   Value v -> CodeGenFunction r (Value v)
fractionGen x = do
   frac <- signedFraction x
   nonNegative <- A.fcmp LLVM.FPOGE frac (LLVM.value LLVM.zero)
   -- Candidate for negative lanes; 'replicateOf' broadcasts the constant
   -- one to the shape of 'v'.
   shifted <- A.add frac (replicateOf 1)
   LLVM.select nonNegative frac shifted
-- | Non-negative fractional part computed without a select.
--
-- The comparison @cmp FPOLT s 0@ is converted with 'LLVM.sitofp' and
-- subtracted from @s = signedFraction x@.  The result becomes @s + 1@
-- only if a true comparison is represented as the integer -1 (all bits
-- set); presumably the X86 cmpss/cmpsd masks passed in by the scalar
-- instances satisfy this -- confirm.
fractionLogical ::
   (Fraction a, LLVM.NumberOfElements D1 a,
    LLVM.IsInteger b, LLVM.NumberOfElements D1 b) =>
   (LLVM.FPPredicate ->
    Value a -> Value a -> CodeGenFunction r (Value b)) ->
   Value a -> CodeGenFunction r (Value a)
fractionLogical cmp x = do
   frac <- signedFraction x
   negativeMask <- cmp LLVM.FPOLT frac (LLVM.value LLVM.zero)
   correction <- LLVM.sitofp negativeMask
   A.sub frac correction
-- | Add an increment to a phase and wrap the sum with 'fraction'.
addToPhase ::
   (Fraction a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
addToPhase d p =
   A.add d p >>= fraction
-- | Like 'addToPhase', but wraps the sum with 'signedFraction'.
--
-- Both functions agree when @d + p@ is non-negative; for a negative sum
-- this one returns a negative result.
-- NOTE(review): presumably callers guarantee a non-negative increment and
-- phase -- confirm.
incPhase ::
   (Fraction a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
incPhase d p = do
   advanced <- A.add d p
   signedFraction advanced
-- | Relates a scalar type to a type obtained by broadcasting it: either
-- the scalar type itself or a vector of it.  The functional dependency
-- lets the scalar type be inferred from the target type.
class Replicate scalar vector | vector -> scalar where
   -- | Broadcast a constant.
   replicateConst :: ConstValue scalar -> ConstValue vector
   -- | Broadcast a runtime value, emitting instructions if needed.
   replicate :: Value scalar -> CodeGenFunction r (Value vector)
-- Scalar types replicate to themselves: both methods are the identity.
instance Replicate Float Float where
   replicate v = return v
   replicateConst c = id c
instance Replicate Double Double where
   replicate v = return v
   replicateConst c = id c
instance Replicate FP128 FP128 where
   replicate v = return v
   replicateConst c = id c
instance Replicate Bool Bool where
   replicate v = return v
   replicateConst c = id c
instance Replicate Int8 Int8 where
   replicate v = return v
   replicateConst c = id c
instance Replicate Int16 Int16 where
   replicate v = return v
   replicateConst c = id c
instance Replicate Int32 Int32 where
   replicate v = return v
   replicateConst c = id c
instance Replicate Int64 Int64 where
   replicate v = return v
   replicateConst c = id c
instance Replicate Word8 Word8 where
   replicate v = return v
   replicateConst c = id c
instance Replicate Word16 Word16 where
   replicate v = return v
   replicateConst c = id c
instance Replicate Word32 Word32 where
   replicate v = return v
   replicateConst c = id c
instance Replicate Word64 Word64 where
   replicate v = return v
   replicateConst c = id c
-- | Broadcasting into a vector.
--
-- At runtime the scalar is inserted into element 0 of an undefined vector,
-- and a shuffle with an all-zero mask copies it to every lane.  The
-- constant version passes a singleton list to 'LLVM.constVector'.
instance (LLVM.IsPowerOf2 n, LLVM.IsPrimitive a) => Replicate a (Vector n a) where
   replicateConst c = LLVM.constVector [c]
   replicate s =
      LLVM.insertelement (LLVM.value LLVM.undef) s (valueOf 0) >>= \lane0 ->
      LLVM.shufflevector lane0 (LLVM.value LLVM.undef) LLVM.zero
-- | Turn a Haskell constant into an LLVM value of the target shape,
-- broadcasting it with 'replicateConst'.
replicateOf ::
   (IsConst a, Replicate a v) =>
   a -> Value v
replicateOf =
   LLVM.value . replicateConst . LLVM.constOf
-- | Arithmetic types that provide code generation for minimum, maximum and
-- absolute value.
class (LLVM.IsArithmetic a) => Real a where
   min, max :: Value a -> Value a -> CodeGenFunction r (Value a)
   abs :: Value a -> CodeGenFunction r (Value a)
-- | 'Float' uses the generic A.fmin/A.fmax/A.fabs, or the scalar SSE
-- variants (minss/maxss/absss) when the extension is available.
instance Real Float where
   abs x = mapAuto A.fabs X86.absss x
   max x y = zipAutoWith A.fmax X86.maxss x y
   min x y = zipAutoWith A.fmin X86.minss x y
-- | 'Double' uses the generic A.fmin/A.fmax/A.fabs, or the scalar SSE2
-- variants (minsd/maxsd/abssd) when the extension is available.
instance Real Double where
   abs x = mapAuto A.fabs X86.abssd x
   max x y = zipAutoWith A.fmax X86.maxsd x y
   min x y = zipAutoWith A.fmin X86.minsd x y
-- Low, left-associative precedence for 'mapAuto'.  This lets it be written
-- infix between a portable implementation and its 'Ext.T' alternative, as
-- in the 'fraction' methods of the scalar 'Fraction' instances.
infixl 1 `mapAuto`
-- | Run a vector (or tuple-of-vectors) operation on a scalar (or tuple of
-- scalars).
--
-- The input goes into element 0 of otherwise undefined vector(s).  The
-- operation is applied, and element 0 of its result is extracted.
runScalar ::
   (Vector.Access n a va, Vector.Access n b vb) =>
   (va -> CodeGenFunction r vb) ->
   (a -> CodeGenFunction r b)
runScalar op a = do
   packed <- Vector.insert (valueOf 0) a LLVM.undefTuple
   result <- op packed
   Vector.extract (valueOf 0) result
-- | Combine a portable scalar implementation with an extension-based
-- vector implementation of the same unary operation.
--
-- The vector variant is adapted to scalars with 'runScalar'.  The choice
-- between the two is delegated to 'Ext.run'; presumably it uses the
-- extension when the target supports it and falls back to the first
-- argument otherwise.
mapAuto ::
   (Vector.Access n a va, Vector.Access n b vb) =>
   (a -> CodeGenFunction r b) ->
   Ext.T (va -> CodeGenFunction r vb) ->
   (a -> CodeGenFunction r b)
mapAuto portable viaExtension a =
   portable a
   `Ext.run`
   Ext.with viaExtension (flip runScalar a)
-- | Binary counterpart of 'mapAuto'.  Both arguments are treated as a
-- pair, so the same fallback-or-extension selection applies.
zipAutoWith ::
   (Vector.Access n a va, Vector.Access n b vb, Vector.Access n c vc) =>
   (a -> b -> CodeGenFunction r c) ->
   Ext.T (va -> vb -> CodeGenFunction r vc) ->
   (a -> b -> CodeGenFunction r c)
zipAutoWith f g x y =
   mapAuto (uncurry f) (fmap uncurry g) (x, y)
-- Remaining scalar types: FP128 uses the floating-point primitives,
-- signed integers use the signed ones, and for unsigned integers 'abs' is
-- the identity.
instance Real FP128 where
   abs = A.fabs
   max = A.fmax
   min = A.fmin
instance Real Int8 where
   abs = A.sabs
   max = A.smax
   min = A.smin
instance Real Int16 where
   abs = A.sabs
   max = A.smax
   min = A.smin
instance Real Int32 where
   abs = A.sabs
   max = A.smax
   min = A.smin
instance Real Int64 where
   abs = A.sabs
   max = A.smax
   min = A.smin
instance Real Word8 where
   abs v = return v
   max = A.umax
   min = A.umin
instance Real Word16 where
   abs v = return v
   max = A.umax
   min = A.umin
instance Real Word32 where
   abs v = return v
   max = A.umax
   min = A.umin
instance Real Word64 where
   abs v = return v
   max = A.umax
   min = A.umin
-- | Vectors delegate to 'Vector.min', 'Vector.max' and 'Vector.abs' from
-- "LLVM.Extra.Vector".
instance (LLVM.IsPowerOf2 n, Vector.Real a) =>
         Real (Vector n a) where
   abs v = Vector.abs v
   max u v = Vector.max u v
   min u v = Vector.min u v