{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{- |
Some special operations on X86 processors.
If you want to use them in algorithms
you will always have to prepare an alternative implementation
in terms of plain LLVM instructions.
You will then run them with 'Ext.run'
and this driver function then selects the most advanced of both implementations.
Functions that are written this way can be found in "LLVM.Extra.Vector".
Availability of extensions is checked with the @CPUID@ instruction.
However, this only works if you compile code for the host machine;
that is, cross compilation will fail!
For cross compilation we would need access to the SubTarget detection of LLVM
that is only available in the C++ interface in version 2.6.
-}
module LLVM.Extra.Extension.X86 (
   X86.maxss, X86.minss, X86.maxps, X86.minps,
   X86.maxsd, X86.minsd, X86.maxpd, X86.minpd,
   cmpss, cmpps, cmpsd, cmppd,
   cmpps256, cmppd256,
   pcmpgtb, pcmpgtw, pcmpgtd, pcmpgtq,
   pcmpugtb, pcmpugtw, pcmpugtd, pcmpugtq,
   pminsb, pminsw, pminsd,
   pmaxsb, pmaxsw, pmaxsd,
   pminub, pminuw, pminud,
   pmaxub, pmaxuw, pmaxud,
   pabsb, pabsw, pabsd,
   pmuludq, pmuldq, pmulld,
   cvtps2dq, cvtpd2dq,
   cvtdq2ps, cvtdq2pd,
   ldmxcsr, stmxcsr, withMXCSR,
   X86.haddps, X86.haddpd,
   X86.dpps, X86.dppd,
   roundss, X86.roundps,
   roundsd, X86.roundpd,
   absss, abssd, absps, abspd,
   ) where

import qualified LLVM.Extra.Extension.X86Auto as X86
import qualified LLVM.Extra.Extension as Ext
import LLVM.Extra.Extension.X86Auto (
          V2Double, V4Float,
          V2Int64, V2Word64,
          V4Int32, V4Word32,
          V8Int16, V8Word16,
          V16Int8, V16Word8,
          )
import LLVM.Extra.ExtensionCheck.X86 (sse1, sse2, sse41, sse42, )

import qualified LLVM.Extra.Monad as M
import qualified LLVM.Extra.ArithmeticPrivate as A

import qualified LLVM.Core as LLVM
import LLVM.Core
          (Value, Vector, valueOf, constOf, vector,
           CodeGenFunction, FPPredicate, )

import qualified Type.Data.Num.Decimal as TypeNum

import qualified Data.NonEmpty.Class as NonEmptyC
import qualified Data.Empty as Empty
import Data.NonEmpty ((!:), )
import Data.Bits (clearBit, complement, )
import Data.Word (Word8, Word32, Word64, )

import Control.Monad.HT ((<=<), )
import Control.Applicative (pure, )

import Foreign.Ptr (Ptr, )


{- |
Turn a compare primitive that is controlled by an immediate 'Word8'
into a comparison that is controlled by an LLVM 'FPPredicate'.

The primitive @g@ receives both operands and the immediate
and returns a value of the operand type @v@;
that result is bit-cast to an integer vector of the same size.

Immediates as they are used below:
@0@ equal, @1@ less than, @2@ less or equal, @3@ unordered,
@4@ not equal, @5@ not less than, @6@ not less or equal, @7@ ordered.
Predicates without a direct immediate are expressed
by swapping the operands or by combining two comparisons.
-}
switchFPPred ::
   (Num i, LLVM.IsConst i, LLVM.IsInteger i, LLVM.IsPrimitive i,
    LLVM.IsFirstClass v,
    TypeNum.Positive n,
    LLVM.IsSized v, LLVM.IsSized (Vector n i),
    LLVM.SizeOf v ~ LLVM.SizeOf (Vector n i)) =>
   (Value v -> Value v -> Value Word8 -> CodeGenFunction r (Value v)) ->
   FPPredicate -> Value v -> Value v ->
   CodeGenFunction r (Value (Vector n i))
switchFPPred g p x y =
   let f i x0 y0 = LLVM.bitcast =<< g x0 y0 (valueOf i)
   in  case p of
          LLVM.FPFalse -> return (LLVM.value LLVM.zero)
          LLVM.FPOEQ -> f 0 x y
          -- "greater" is "less" with the operands exchanged
          LLVM.FPOGT -> f 1 y x
          LLVM.FPOGE -> f 2 y x
          LLVM.FPOLT -> f 1 x y
          LLVM.FPOLE -> f 2 x y
          -- ordered and not equal
          LLVM.FPONE -> M.liftR2 A.and (f 7 x y) (f 4 x y)
          LLVM.FPORD -> f 7 x y
          LLVM.FPUNO -> f 3 x y
          -- unordered or equal
          LLVM.FPUEQ -> M.liftR2 A.or (f 3 x y) (f 0 x y)
          LLVM.FPUGT -> f 6 x y
          LLVM.FPUGE -> f 5 x y
          -- the unordered "less" variants are the negated ones, operands exchanged
          LLVM.FPULT -> f 6 y x
          LLVM.FPULE -> f 5 y x
          LLVM.FPUNE -> f 4 x y
          -- always true: every element is -1
          LLVM.FPT -> return (valueOf $ pure (-1))

-- | 'X86.cmpss' selected by an 'FPPredicate' instead of an immediate.
cmpss :: Ext.T (FPPredicate -> V4Float -> V4Float -> CodeGenFunction r V4Int32)
cmpss = fmap switchFPPred X86.cmpss

-- | 'X86.cmpps' selected by an 'FPPredicate' instead of an immediate.
cmpps :: Ext.T (FPPredicate -> V4Float -> V4Float -> CodeGenFunction r V4Int32)
cmpps = fmap switchFPPred X86.cmpps

-- | 'X86.cmpsd' selected by an 'FPPredicate' instead of an immediate.
cmpsd :: Ext.T (FPPredicate -> V2Double -> V2Double -> CodeGenFunction r V2Int64)
cmpsd = fmap switchFPPred X86.cmpsd

-- | 'X86.cmppd' selected by an 'FPPredicate' instead of an immediate.
cmppd :: Ext.T (FPPredicate -> V2Double -> V2Double -> CodeGenFunction r V2Int64)
cmppd = fmap switchFPPred X86.cmppd

-- | 'X86.cmpps256' selected by an 'FPPredicate' instead of an immediate.
cmpps256 ::
   Ext.T (FPPredicate -> X86.V8Float -> X86.V8Float ->
          CodeGenFunction r X86.V8Int32)
cmpps256 = fmap switchFPPred X86.cmpps256

-- | 'X86.cmppd256' selected by an 'FPPredicate' instead of an immediate.
cmppd256 ::
   Ext.T (FPPredicate -> X86.V4Double -> X86.V4Double ->
          CodeGenFunction r X86.V4Int64)
cmppd256 = fmap switchFPPred X86.cmppd256


-- | Greater-than comparison on sixteen 'Data.Int.Int8' elements (SSE2 intrinsic @pcmpgt.b@).
pcmpgtb :: Ext.T (V16Int8 -> V16Int8 -> CodeGenFunction r V16Int8)
pcmpgtb = Ext.intrinsic sse2 "pcmpgt.b"

-- | Greater-than comparison on eight 'Data.Int.Int16' elements (SSE2 intrinsic @pcmpgt.w@).
pcmpgtw :: Ext.T (V8Int16 -> V8Int16 -> CodeGenFunction r V8Int16)
pcmpgtw = Ext.intrinsic sse2 "pcmpgt.w"

-- | Greater-than comparison on four 'Data.Int.Int32' elements (SSE2 intrinsic @pcmpgt.d@).
pcmpgtd :: Ext.T (V4Int32 -> V4Int32 -> CodeGenFunction r V4Int32)
pcmpgtd = Ext.intrinsic sse2 "pcmpgt.d"

-- | Greater-than comparison on two 'Data.Int.Int64' elements (SSE4.2 intrinsic @pcmpgtq@).
pcmpgtq :: Ext.T (V2Int64 -> V2Int64 -> CodeGenFunction r V2Int64)
pcmpgtq = Ext.intrinsic sse42 "pcmpgtq"


{- |
Derive a comparison on unsigned elements from a comparison on signed elements.

Both operands are shifted by @1 + div maxBound 2@,
i.e. half of the unsigned range,
and then reinterpreted as signed vectors.
This maps the unsigned order onto the signed order,
so that the signed comparison can be used unchanged.
The result is reinterpreted as an unsigned vector.
-}
pcmpuFromPcmp ::
   (TypeNum.Positive n,
    LLVM.IsPrimitive s,
    LLVM.IsPrimitive u, LLVM.IsArithmetic u, LLVM.IsConst u,
    Bounded u, Integral u,
    LLVM.IsSized (Vector n s), LLVM.IsSized (Vector n u),
    LLVM.SizeOf (Vector n s) ~ LLVM.SizeOf (Vector n u)) =>
   Ext.T (Value (Vector n s) -> Value (Vector n s) ->
          CodeGenFunction r (Value (Vector n s))) ->
   Ext.T (Value (Vector n u) -> Value (Vector n u) ->
          CodeGenFunction r (Value (Vector n u)))
pcmpuFromPcmp pcmp =
   Ext.with pcmp $ \cmp x y -> do
      -- bias that moves the middle of the unsigned range to zero
      let offset = valueOf $ pure (1 + div maxBound 2)
      xa <- LLVM.bitcast =<< A.sub x offset
      ya <- LLVM.bitcast =<< A.sub y offset
      LLVM.bitcast =<< cmp xa ya

-- | Unsigned counterpart of 'pcmpgtb', built with 'pcmpuFromPcmp'.
pcmpugtb :: Ext.T (V16Word8 -> V16Word8 -> CodeGenFunction r V16Word8)
pcmpugtb = pcmpuFromPcmp pcmpgtb

-- | Unsigned counterpart of 'pcmpgtw', built with 'pcmpuFromPcmp'.
pcmpugtw :: Ext.T (V8Word16 -> V8Word16 -> CodeGenFunction r V8Word16)
pcmpugtw = pcmpuFromPcmp pcmpgtw

-- | Unsigned counterpart of 'pcmpgtd', built with 'pcmpuFromPcmp'.
pcmpugtd :: Ext.T (V4Word32 -> V4Word32 -> CodeGenFunction r V4Word32)
pcmpugtd = pcmpuFromPcmp pcmpgtd

-- | Unsigned counterpart of 'pcmpgtq', built with 'pcmpuFromPcmp'.
pcmpugtq :: Ext.T (V2Word64 -> V2Word64 -> CodeGenFunction r V2Word64)
pcmpugtq = pcmpuFromPcmp pcmpgtq


-- | Aliases for 'X86.pminsb128' and 'X86.pmaxsb128'.
pminsb, pmaxsb :: Ext.T (V16Int8 -> V16Int8 -> CodeGenFunction r V16Int8)
pminsb = X86.pminsb128
pmaxsb = X86.pmaxsb128

-- | Aliases for 'X86.pminsw128' and 'X86.pmaxsw128'.
pminsw, pmaxsw :: Ext.T (V8Int16 -> V8Int16 -> CodeGenFunction r V8Int16)
pminsw = X86.pminsw128
pmaxsw = X86.pmaxsw128

-- | Aliases for 'X86.pminsd128' and 'X86.pmaxsd128'.
pminsd, pmaxsd :: Ext.T (V4Int32 -> V4Int32 -> CodeGenFunction r V4Int32)
pminsd = X86.pminsd128
pmaxsd = X86.pmaxsd128

-- | Aliases for 'X86.pminub128' and 'X86.pmaxub128'.
pminub, pmaxub :: Ext.T (V16Word8 -> V16Word8 -> CodeGenFunction r V16Word8)
pminub = X86.pminub128
pmaxub = X86.pmaxub128

-- | Aliases for 'X86.pminuw128' and 'X86.pmaxuw128'.
pminuw, pmaxuw :: Ext.T (V8Word16 -> V8Word16 -> CodeGenFunction r V8Word16)
pminuw = X86.pminuw128
pmaxuw = X86.pmaxuw128

-- | Aliases for 'X86.pminud128' and 'X86.pmaxud128'.
pminud, pmaxud :: Ext.T (V4Word32 -> V4Word32 -> CodeGenFunction r V4Word32)
pminud = X86.pminud128
pmaxud = X86.pmaxud128


-- | Alias for 'X86.pabsb128'.
pabsb :: Ext.T (V16Int8 -> CodeGenFunction r V16Int8)
pabsb = X86.pabsb128

-- | Alias for 'X86.pabsw128'.
pabsw :: Ext.T (V8Int16 -> CodeGenFunction r V8Int16)
pabsw = X86.pabsw128

-- | Alias for 'X86.pabsd128'.
pabsd :: Ext.T (V4Int32 -> CodeGenFunction r V4Int32)
pabsd = X86.pabsd128


-- | Alias for 'X86.pmuludq128'; two vectors of four 'Word32' yield two 'Word64'.
pmuludq :: Ext.T (V4Word32 -> V4Word32 -> CodeGenFunction r V2Word64)
pmuludq = X86.pmuludq128

-- | Alias for 'X86.pmuldq128'; two vectors of four 'Data.Int.Int32' yield two 'Data.Int.Int64'.
pmuldq :: Ext.T (V4Int32 -> V4Int32 -> CodeGenFunction r V2Int64)
pmuldq = X86.pmuldq128

{- |
Element-wise multiplication of four 'Word32' values.
This is the plain LLVM multiplication, merely tagged as requiring SSE4.1;
the dedicated intrinsic is kept below for reference.
-}
pmulld :: Ext.T (V4Word32 -> V4Word32 -> CodeGenFunction r V4Word32)
pmulld = Ext.wrap sse41 LLVM.mul
-- pmulld = Ext.intrinsic sse41 "pmulld"


-- | Alias for 'X86.cvtps2dq'.
cvtps2dq :: Ext.T (V4Float -> CodeGenFunction r V4Int32)
cvtps2dq = X86.cvtps2dq

-- | the upper two integers are set to zero, there is no instruction that converts to Int64
cvtpd2dq :: Ext.T (V2Double -> CodeGenFunction r V4Int32)
cvtpd2dq = X86.cvtpd2dq

-- | Alias for 'X86.cvtdq2ps'.
cvtdq2ps :: Ext.T (V4Int32 -> CodeGenFunction r V4Float)
cvtdq2ps = X86.cvtdq2ps

-- | the upper two integers are ignored, there is no instruction that converts from Int64
cvtdq2pd :: Ext.T (V4Int32 -> CodeGenFunction r V2Double)
cvtdq2pd = X86.cvtdq2pd


-- | Discard a unit-typed LLVM value; used to give the MXCSR intrinsics a @()@ result.
valueUnit :: Value () -> ()
valueUnit _ = ()

{- |
MXCSR is not really supported by LLVM-2.6.
LLVM does not know about the dependency of all floating point operations
on this status register.
-}
ldmxcsr :: Ext.T (Value (Ptr Word32) -> CodeGenFunction r ())
ldmxcsr = fmap (fmap valueUnit .) $ Ext.intrinsicAttr [] sse1 "ldmxcsr"

-- | SSE1 intrinsic @stmxcsr@, taking a pointer to a 'Word32'. See the caveat at 'ldmxcsr'.
stmxcsr :: Ext.T (Value (Ptr Word32) -> CodeGenFunction r ())
stmxcsr = fmap (fmap valueUnit .) $ Ext.intrinsicAttr [] sse1 "stmxcsr"

{- |
Run a piece of generated code with a temporarily altered MXCSR register.

The current register content is saved to a stack slot using 'stmxcsr',
the given value is written to a second stack slot and loaded using 'ldmxcsr',
then the wrapped code is emitted
and finally the saved content is loaded back.
The result of the wrapped code is returned.
-}
withMXCSR :: Word32 -> Ext.T (CodeGenFunction r a -> CodeGenFunction r a)
withMXCSR mxcsr =
   Ext.with2 ldmxcsr stmxcsr $ \ ld st f -> do
      -- save the current register content
      mxcsrOld <- LLVM.alloca
      st mxcsrOld
      -- the instruction needs its operand in memory
      mxcsrFloor <- LLVM.alloca
      LLVM.store (valueOf mxcsr) mxcsrFloor
      {- unfortunately, createGlobal is a function in the CodeGenModule monad
      mxcsrFloor <- LLVM.createGlobal True LLVM.InternalLinkage mxcsr
      -}
      ld mxcsrFloor
      r <- f
      -- restore the saved register content
      ld mxcsrOld
      return r


{-
[maxsd, minsd, maxpd, minpd] =
   map (Ext.intrinsic sse2) ["max.ss", "min.ss", "max.ps", "min.ps"]

NOTE(review): the bound names are the double precision variants
whereas the strings name the single precision ones -
presumably the strings should read "max.sd", "min.sd", "max.pd", "min.pd".
Confirm before reviving this code.
-}


-- | 'X86.roundss' with its first vector argument fixed to an undefined value.
roundss :: Ext.T (V4Float -> Value Word32 -> CodeGenFunction r V4Float)
roundss = fmap (\f -> f (LLVM.value LLVM.undef)) X86.roundss

-- | 'X86.roundsd' with its first vector argument fixed to an undefined value.
roundsd :: Ext.T (V2Double -> Value Word32 -> CodeGenFunction r V2Double)
roundsd = fmap (\f -> f (LLVM.value LLVM.undef)) X86.roundsd


{-
Not an LLVM intrinsic but implementation specific:
We expect that floating point values are in IEEE format
and thus the most significant bit is the sign.
The absolute value can be computed very efficiently
by clearing the sign bit.
Actually, LLVM's codegen implements neg by an XOR on the sign bit.
-}
{- |
Clear bit 31 of the first element of the vector;
the mask leaves all bits of the remaining elements untouched.
-}
absss :: Ext.T (V4Float -> CodeGenFunction r V4Float)
absss =
   Ext.wrap sse1 $
      LLVM.bitcast
      <=< A.and
             (LLVM.valueOf $ vector $
              (flip clearBit 31 $ complement 0) !: NonEmptyC.repeat (complement 0)
                 :: V4Word32)
      <=< LLVM.bitcast

{-
This function works on a single Float,
but I like to do the masking in an XMM register
because usually the value is there anyway.

absss =
   flip LLVM.extractelement (valueOf 0)
   . flip asTypeOf (undefined :: V4Float)
   <=< LLVM.bitcast
--   <=< A.and (LLVM.value $ constVector [constOf 0x7FFFFFFF] :: V4Word32)
--   <=< A.and (LLVM.value $ constVector [constOf 0x7FFFFFFF, LLVM.undef, LLVM.undef, LLVM.undef] :: V4Word32)
   <=< A.and (LLVM.value $ constVector [constOf 0x7FFFFFFF, LLVM.zero, LLVM.zero, LLVM.zero] :: V4Word32)
   <=< LLVM.bitcast
   . flip asTypeOf (undefined :: V4Float)
   <=< flip (LLVM.insertelement (LLVM.value LLVM.undef)) (valueOf 0)
-}

{-
This moves the value to a general purpose register and performs the bit masking there

absss =
   LLVM.bitcast
   <=< A.and (valueOf 0x7FFFFFFF :: Value Word32)
   <=< LLVM.bitcast
-}

{- |
Clear bit 63 of the first element of the vector;
the mask leaves all bits of the second element untouched.
-}
abssd :: Ext.T (V2Double -> CodeGenFunction r V2Double)
abssd =
   Ext.wrap sse2 $
      LLVM.bitcast
      <=< A.and
             (LLVM.valueOf $ vector $
              (flip clearBit 63 $ complement 0) !: complement 0 !: Empty.Cons
                 :: V2Word64)
      <=< LLVM.bitcast

-- | Bitwise AND of every element of an integer vector with the same constant.
mask ::
   (TypeNum.Positive n,
    LLVM.IsConst w, LLVM.IsPrimitive w, LLVM.IsInteger w) =>
   w -> Value (Vector n w) -> CodeGenFunction r (Value (Vector n w))
mask x = A.and (LLVM.valueOf $ pure x)

-- | Clear bit 31 of every element of a 'Float' vector of arbitrary size.
absps ::
   (TypeNum.Positive n) =>
   Ext.T (Value (Vector n Float) -> CodeGenFunction r (Value (Vector n Float)))
absps =
   Ext.wrap sse1 $
      LLVM.bitcastElements
      <=< mask (flip clearBit 31 $ complement 0 :: Word32)
      <=< LLVM.bitcastElements

-- | Clear bit 63 of every element of a 'Double' vector of arbitrary size.
abspd ::
   (TypeNum.Positive n) =>
   Ext.T (Value (Vector n Double) -> CodeGenFunction r (Value (Vector n Double)))
abspd =
   Ext.wrap sse2 $
      LLVM.bitcastElements
      <=< mask (flip clearBit 63 $ complement 0 :: Word64)
      <=< LLVM.bitcastElements


{- |
cumulative sum: @(a,b,c,d) -> (a,a+b,a+b+c,a+b+c+d)@

I try to cleverly use horizontal add,
but the generic version in the Vector module is better.
-}
_cumulate1s :: Ext.T (V4Float -> CodeGenFunction r V4Float)
_cumulate1s =
   Ext.with X86.haddps $ \haddp x -> do
      -- presumably y = (a+b, c+d, undefined, undefined) - depends on haddps
      y <- haddp x (LLVM.value LLVM.undef)
      -- indices 0..3 select from x, 4..7 from y: z = (x0, y0, x2, y1)
      z <- LLVM.shufflevector x y $
              constOf $ vector $ 0!:4!:2!:5!:Empty.Cons
      -- indices 4 and 5 select zeros: offset = (0, 0, y0, y0)
      offset <- LLVM.shufflevector y (LLVM.value LLVM.zero) $
                   constOf $ vector $ 4!:5!:0!:0!:Empty.Cons
      A.add z offset