{-# LANGUAGE FlexibleContexts #-}
{- |
Some special operations on X86 processors.
If you want to use them in an algorithm,
you always have to prepare an alternative implementation
in terms of plain LLVM instructions.
You then run them with 'Ext.run',
and this driver function selects the more advanced of the two implementations.
Functions that are written this way can be found in "LLVM.Extra.Vector".
Availability of extensions is checked with the @CPUID@ instruction.
However, this only works if you compile code for the host machine;
that is, cross compilation will fail!
For cross compilation we would need access to the SubTarget detection of LLVM,
which is only available in the C++ interface in version 2.6.
-}
module LLVM.Extra.Extension.X86 (
   maxss, minss, maxps, minps,
   maxsd, minsd, maxpd, minpd,
   cmpss, cmpps, cmpsd, cmppd,
   pcmpgtb,  pcmpgtw,  pcmpgtd,  pcmpgtq,
   pcmpugtb, pcmpugtw, pcmpugtd, pcmpugtq,
   pminsb, pminsw, pminsd,
   pmaxsb, pmaxsw, pmaxsd,
   pminub, pminuw, pminud,
   pmaxub, pmaxuw, pmaxud,
   pabsb, pabsw, pabsd,
   pmuludq, pmulld,
   cvtps2dq, cvtpd2dq,
   ldmxcsr, stmxcsr, withMXCSR,
   haddps, haddpd, dpps, dppd,
   roundss, roundps, roundsd, roundpd,
   absss, abssd, absps, abspd,
   ) where

import qualified LLVM.Extra.Extension as Ext
import LLVM.Extra.ExtensionCheck.X86
          (sse1, sse2, sse3, ssse3, sse41, sse42, )

import qualified LLVM.Extra.Monad as M
import qualified LLVM.Extra.ArithmeticPrivate as A
import qualified LLVM.Core as LLVM
import LLVM.Core
   (Value, Vector, value, valueOf, constOf, constVector,
    CodeGenFunction, FPPredicate, )

import qualified Data.TypeLevel.Num as TypeNum
import Data.TypeLevel.Num (D2, D4, D8, D16, )

import Data.Bits (clearBit, complement, )
import Data.Int  (Int8, Int16, Int32, Int64, )
import Data.Word (Word8, Word16, Word32, Word64, )

import Control.Monad.HT ((<=<), )

import Foreign.Ptr (Ptr, )


-- * target dependent functions

-- | A 128-bit LLVM vector value holding four single-precision floats.
type VFloat = LLVM.Value (LLVM.Vector D4 Float)
-- | A 128-bit LLVM vector value holding two double-precision floats.
type VDouble = LLVM.Value (LLVM.Vector D2 Double)


{- |
Single-precision maximum and minimum (SSE1).
The @ss@ variants work on the lowest vector element,
the @ps@ variants on all four elements.
-}
maxss, maxps, minss, minps ::
   Ext.T (VFloat -> VFloat -> CodeGenFunction r VFloat)
maxss = Ext.intrinsic sse1 "max.ss"
maxps = Ext.intrinsic sse1 "max.ps"
minss = Ext.intrinsic sse1 "min.ss"
minps = Ext.intrinsic sse1 "min.ps"

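{- The following shorter definition does not work:
all list elements must have the same type,
so the type variable r would be unified across the four functions.
[maxss, minss, maxps, minps] =
   map (Ext.intrinsic sse1)
     ["max.ss", "min.ss", "max.ps", "min.ps"]
-}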

{- |
Double-precision maximum and minimum.
The @sd@ variants work on the lowest vector element,
the @pd@ variants on both elements.
These are SSE2 instructions, so they must be guarded by 'sse2'
(like 'cmpsd' and 'cmppd'), not by 'sse1'.
-}
maxsd, minsd, maxpd, minpd ::
   Ext.T (VDouble -> VDouble -> CodeGenFunction r VDouble)
maxsd = Ext.intrinsic sse2 "max.sd"
minsd = Ext.intrinsic sse2 "min.sd"
maxpd = Ext.intrinsic sse2 "max.pd"
minpd = Ext.intrinsic sse2 "min.pd"

{- |
Implement an LLVM floating-point comparison predicate
on top of an SSE compare intrinsic.

The intrinsic receives the two operands and an 8-bit immediate
that selects the comparison.
This table uses the immediates
0 (equal), 1 (less than), 2 (less or equal), 3 (unordered),
4 (not equal), 5 (not less than), 6 (not less or equal) and 7 (ordered).
Greater-than predicates are obtained by swapping the operands.
'LLVM.FPONE' and 'LLVM.FPUEQ' combine two comparisons bitwise.
'LLVM.FPFalse' and 'LLVM.FPT' produce constant all-zero and all-one masks
without calling the intrinsic.
The intrinsic result is reinterpreted as an integer vector
with 'LLVM.bitcastUnify'.
-}
switchFPPred ::
   (Num i, LLVM.IsConst i, LLVM.IsInteger i, LLVM.IsPrimitive i,
    LLVM.IsFirstClass v,
    TypeNum.Pos n,
    LLVM.IsSized v s, LLVM.IsSized (Vector n i) s) =>
   (Value v -> Value v -> Value Word8 -> CodeGenFunction r (Value v)) ->
   FPPredicate -> Value v -> Value v -> CodeGenFunction r (Value (Vector n i))
switchFPPred intr predicate a b =
   case predicate of
      -- constant results, no instruction needed
      LLVM.FPFalse -> return (LLVM.value LLVM.zero)
      LLVM.FPT     -> return (LLVM.value (LLVM.constVector [LLVM.constOf (-1)]))
      -- ordered predicates
      LLVM.FPOEQ   -> cmp 0 a b
      LLVM.FPOGT   -> cmp 1 b a
      LLVM.FPOGE   -> cmp 2 b a
      LLVM.FPOLT   -> cmp 1 a b
      LLVM.FPOLE   -> cmp 2 a b
      LLVM.FPONE   -> M.liftR2 A.and (cmp 7 a b) (cmp 4 a b)
      LLVM.FPORD   -> cmp 7 a b
      -- unordered predicates
      LLVM.FPUNO   -> cmp 3 a b
      LLVM.FPUEQ   -> M.liftR2 A.or (cmp 3 a b) (cmp 0 a b)
      LLVM.FPUGT   -> cmp 6 a b
      LLVM.FPUGE   -> cmp 5 a b
      LLVM.FPULT   -> cmp 6 b a
      LLVM.FPULE   -> cmp 5 b a
      LLVM.FPUNE   -> cmp 4 a b
   where
      -- call the intrinsic with immediate imm, then reinterpret the bits
      cmp imm lhs rhs = LLVM.bitcastUnify =<< intr lhs rhs (valueOf imm)

{- |
Compare the lowest single-precision element according to an LLVM predicate (SSE1).
The result is an integer mask produced by 'switchFPPred'.
-}
cmpss ::
   Ext.T (FPPredicate -> VFloat -> VFloat ->
          CodeGenFunction r (Value (Vector D4 Int32)))
cmpss = fmap switchFPPred $ Ext.intrinsic sse1 "cmp.ss"

-- | Compare all four single-precision elements (SSE1), see 'cmpss'.
cmpps ::
   Ext.T (FPPredicate -> VFloat -> VFloat ->
          CodeGenFunction r (Value (Vector D4 Int32)))
cmpps = fmap switchFPPred $ Ext.intrinsic sse1 "cmp.ps"

-- | Compare the lowest double-precision element (SSE2), see 'cmpss'.
cmpsd ::
   Ext.T (FPPredicate -> VDouble -> VDouble ->
          CodeGenFunction r (Value (Vector D2 Int64)))
cmpsd = fmap switchFPPred $ Ext.intrinsic sse2 "cmp.sd"

-- | Compare both double-precision elements (SSE2), see 'cmpss'.
cmppd ::
   Ext.T (FPPredicate -> VDouble -> VDouble ->
          CodeGenFunction r (Value (Vector D2 Int64)))
cmppd = fmap switchFPPred $ Ext.intrinsic sse2 "cmp.pd"


{- |
Signed greater-than comparison of packed integers, yielding a comparison mask.
The byte, word and doubleword variants need SSE2;
the quadword variant 'pcmpgtq' needs SSE4.2.
-}
pcmpgtb ::
   Ext.T (Value (Vector D16 Int8) -> Value (Vector D16 Int8) ->
          CodeGenFunction r (Value (Vector D16 Int8)))
pcmpgtb = Ext.intrinsic sse2 "pcmpgt.b"

pcmpgtw ::
   Ext.T (Value (Vector D8 Int16) -> Value (Vector D8 Int16) ->
          CodeGenFunction r (Value (Vector D8 Int16)))
pcmpgtw = Ext.intrinsic sse2 "pcmpgt.w"

pcmpgtd ::
   Ext.T (Value (Vector D4 Int32) -> Value (Vector D4 Int32) ->
          CodeGenFunction r (Value (Vector D4 Int32)))
pcmpgtd = Ext.intrinsic sse2 "pcmpgt.d"

pcmpgtq ::
   Ext.T (Value (Vector D2 Int64) -> Value (Vector D2 Int64) ->
          CodeGenFunction r (Value (Vector D2 Int64)))
pcmpgtq = Ext.intrinsic sse42 "pcmpgtq"


{- |
Derive an unsigned greater-than comparison from a signed one.

Both operands are shifted by @1 + maxBound/2@ of the unsigned element type
(e.g. 128 for 'Word8') and then reinterpreted as signed integers.
The signed order of the shifted values is the unsigned order
of the original values.
The mask returned by the signed comparison is reinterpreted
as a vector of the unsigned element type.
-}
pcmpuFromPcmp ::
   (TypeNum.Pos n,
    LLVM.IsPrimitive s,
    LLVM.IsPrimitive u, LLVM.IsArithmetic u, LLVM.IsConst u,
    Bounded u, Integral u,
    LLVM.IsSized (Vector n s) size,
    LLVM.IsSized (Vector n u) size) =>
   Ext.T (Value (Vector n s) -> Value (Vector n s) -> CodeGenFunction r (Value (Vector n s))) ->
   Ext.T (Value (Vector n u) -> Value (Vector n u) -> CodeGenFunction r (Value (Vector n u)))
pcmpuFromPcmp signedCmp =
   Ext.with signedCmp $ \greater a b -> do
      let bias = value (constVector [constOf (1 + div maxBound 2)])
          -- move the unsigned range onto the signed range
          toSigned v = LLVM.bitcastUnify =<< A.sub v bias
      a' <- toSigned a
      b' <- toSigned b
      LLVM.bitcastUnify =<< greater a' b'

{- |
Unsigned greater-than comparison of packed integers.
Each variant is built from the signed comparison of the same width
via 'pcmpuFromPcmp', so it requires the same extension as that comparison.
-}
pcmpugtb ::
   Ext.T (Value (Vector D16 Word8) -> Value (Vector D16 Word8) ->
          CodeGenFunction r (Value (Vector D16 Word8)))
pcmpugtb = pcmpuFromPcmp pcmpgtb

pcmpugtw ::
   Ext.T (Value (Vector D8 Word16) -> Value (Vector D8 Word16) ->
          CodeGenFunction r (Value (Vector D8 Word16)))
pcmpugtw = pcmpuFromPcmp pcmpgtw

pcmpugtd ::
   Ext.T (Value (Vector D4 Word32) -> Value (Vector D4 Word32) ->
          CodeGenFunction r (Value (Vector D4 Word32)))
pcmpugtd = pcmpuFromPcmp pcmpgtd

pcmpugtq ::
   Ext.T (Value (Vector D2 Word64) -> Value (Vector D2 Word64) ->
          CodeGenFunction r (Value (Vector D2 Word64)))
pcmpugtq = pcmpuFromPcmp pcmpgtq


{- |
Element-wise signed minimum and maximum.
The word variants need SSE2; the byte and doubleword variants need SSE4.1.
-}
pminsb ::
   Ext.T (Value (Vector D16 Int8) -> Value (Vector D16 Int8) ->
          CodeGenFunction r (Value (Vector D16 Int8)))
pminsb = Ext.intrinsic sse41 "pminsb"

pminsw ::
   Ext.T (Value (Vector D8 Int16) -> Value (Vector D8 Int16) ->
          CodeGenFunction r (Value (Vector D8 Int16)))
pminsw = Ext.intrinsic sse2 "pmins.w"

pminsd ::
   Ext.T (Value (Vector D4 Int32) -> Value (Vector D4 Int32) ->
          CodeGenFunction r (Value (Vector D4 Int32)))
pminsd = Ext.intrinsic sse41 "pminsd"


pmaxsb ::
   Ext.T (Value (Vector D16 Int8) -> Value (Vector D16 Int8) ->
          CodeGenFunction r (Value (Vector D16 Int8)))
pmaxsb = Ext.intrinsic sse41 "pmaxsb"

pmaxsw ::
   Ext.T (Value (Vector D8 Int16) -> Value (Vector D8 Int16) ->
          CodeGenFunction r (Value (Vector D8 Int16)))
pmaxsw = Ext.intrinsic sse2 "pmaxs.w"

pmaxsd ::
   Ext.T (Value (Vector D4 Int32) -> Value (Vector D4 Int32) ->
          CodeGenFunction r (Value (Vector D4 Int32)))
pmaxsd = Ext.intrinsic sse41 "pmaxsd"


{- |
Element-wise unsigned minimum and maximum.
The byte variants need SSE2; the word and doubleword variants need SSE4.1.
-}
pminub ::
   Ext.T (Value (Vector D16 Word8) -> Value (Vector D16 Word8) ->
          CodeGenFunction r (Value (Vector D16 Word8)))
pminub = Ext.intrinsic sse2 "pminu.b"

pminuw ::
   Ext.T (Value (Vector D8 Word16) -> Value (Vector D8 Word16) ->
          CodeGenFunction r (Value (Vector D8 Word16)))
pminuw = Ext.intrinsic sse41 "pminuw"

pminud ::
   Ext.T (Value (Vector D4 Word32) -> Value (Vector D4 Word32) ->
          CodeGenFunction r (Value (Vector D4 Word32)))
pminud = Ext.intrinsic sse41 "pminud"


pmaxub ::
   Ext.T (Value (Vector D16 Word8) -> Value (Vector D16 Word8) ->
          CodeGenFunction r (Value (Vector D16 Word8)))
pmaxub = Ext.intrinsic sse2 "pmaxu.b"

pmaxuw ::
   Ext.T (Value (Vector D8 Word16) -> Value (Vector D8 Word16) ->
          CodeGenFunction r (Value (Vector D8 Word16)))
pmaxuw = Ext.intrinsic sse41 "pmaxuw"

pmaxud ::
   Ext.T (Value (Vector D4 Word32) -> Value (Vector D4 Word32) ->
          CodeGenFunction r (Value (Vector D4 Word32)))
pmaxud = Ext.intrinsic sse41 "pmaxud"


{- |
Element-wise absolute value of signed integers (SSSE3).

LLVM names the 128-bit (XMM) forms of these intrinsics with a @.128@ suffix.
The unsuffixed @pabs.b@, @pabs.w@ and @pabs.d@ are the 64-bit MMX forms,
whose operand types do not match the 128-bit vectors used here.
-}
pabsb :: Ext.T (Value (Vector D16 Int8) -> CodeGenFunction r (Value (Vector D16 Int8)))
pabsb = Ext.intrinsic ssse3 "pabs.b.128"

pabsw :: Ext.T (Value (Vector D8 Int16) -> CodeGenFunction r (Value (Vector D8 Int16)))
pabsw = Ext.intrinsic ssse3 "pabs.w.128"

pabsd :: Ext.T (Value (Vector D4 Int32) -> CodeGenFunction r (Value (Vector D4 Int32)))
pabsd = Ext.intrinsic ssse3 "pabs.d.128"


{- |
Unsigned 32-bit multiplication producing two 64-bit products (SSE2).
See the PMULUDQ reference for which input elements are multiplied.
-}
pmuludq ::
   Ext.T (Value (Vector D4 Word32) -> Value (Vector D4 Word32) ->
          CodeGenFunction r (Value (Vector D2 Word64)))
pmuludq = Ext.intrinsic sse2 "pmulu.dq"

{- |
Element-wise 32-bit multiplication keeping the low 32 bits of each product (SSE4.1).
-}
pmulld ::
   Ext.T (Value (Vector D4 Word32) -> Value (Vector D4 Word32) ->
          CodeGenFunction r (Value (Vector D4 Word32)))
pmulld = Ext.intrinsic sse41 "pmulld"


-- | Convert four single-precision floats to 32-bit integers (SSE2).
cvtps2dq ::
   Ext.T (VFloat -> CodeGenFunction r (Value (Vector D4 Int32)))
cvtps2dq = Ext.intrinsic sse2 "cvtps2dq"

{- |
Convert two double-precision floats to 32-bit integers (SSE2).
The upper two integers of the result are set to zero.
There is no instruction that converts to 'Int64'.
-}
cvtpd2dq ::
   Ext.T (VDouble -> CodeGenFunction r (Value (Vector D4 Int32)))
cvtpd2dq = Ext.intrinsic sse2 "cvtpd2dq"


-- | Discard an LLVM value of unit type.
-- Used to adapt intrinsics that return @Value ()@ to @CodeGenFunction r ()@.
valueUnit :: Value () -> ()
valueUnit = const ()

{- |
MXCSR is not really supported by LLVM-2.6.
LLVM does not know about the dependency of all floating point operations
on this status register.

'ldmxcsr' loads the MXCSR register from the memory location
that the pointer refers to.
It is created with 'Ext.intrinsicAttr' and an empty attribute list
instead of 'Ext.intrinsic'.
Presumably this keeps LLVM from treating the call as free of side effects
(TODO confirm against "LLVM.Extra.Extension").
-}
ldmxcsr :: Ext.T (Value (Ptr Word32) -> CodeGenFunction r ())
ldmxcsr =
   fmap (\call -> fmap valueUnit . call) (Ext.intrinsicAttr [] sse1 "ldmxcsr")

-- | Store the MXCSR register to the memory location that the pointer refers to.
-- See 'ldmxcsr' for the caveats.
stmxcsr :: Ext.T (Value (Ptr Word32) -> CodeGenFunction r ())
stmxcsr =
   fmap (\call -> fmap valueUnit . call) (Ext.intrinsicAttr [] sse1 "stmxcsr")

{- |
Run the given code with the MXCSR register set to @mxcsr@,
and restore the previous register contents afterwards.
Both the saved and the new contents are passed through stack slots
created with 'LLVM.alloca', because 'ldmxcsr' and 'stmxcsr' operate on pointers.
-}
withMXCSR :: Word32 -> Ext.T (CodeGenFunction r a -> CodeGenFunction r a)
withMXCSR mxcsr =
   Ext.with2 ldmxcsr stmxcsr $ \load save body -> do
      -- remember the current register contents
      savedPtr <- LLVM.alloca
      save savedPtr
      -- A global constant would be nicer than this stack slot,
      -- but LLVM.createGlobal lives in the CodeGenModule monad,
      -- not in CodeGenFunction:
      --    LLVM.createGlobal True LLVM.InternalLinkage mxcsr
      wantedPtr <- LLVM.alloca
      LLVM.store (valueOf mxcsr) wantedPtr
      load wantedPtr
      result <- body
      -- restore the original setting
      load savedPtr
      return result

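{- The same list-based shortcut does not work
for the double-precision functions either, because r would be unified:
[maxsd, minsd, maxpd, minpd] =
   map (Ext.intrinsic sse2)
     ["max.sd", "min.sd", "max.pd", "min.pd"]
-}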

-- | Horizontal addition of adjacent single-precision elements (SSE3).
haddps ::
   Ext.T (VFloat -> VFloat -> CodeGenFunction r VFloat)
haddps = Ext.intrinsic sse3 "hadd.ps"

-- | Horizontal addition of adjacent double-precision elements (SSE3).
haddpd ::
   Ext.T (VDouble -> VDouble -> CodeGenFunction r VDouble)
haddpd = Ext.intrinsic sse3 "hadd.pd"

{- |
Single-precision dot product (SSE4.1).
The 'Word32' argument is presumably the instruction's immediate mask
that selects the input and output elements (see the DPPS reference).
-}
dpps ::
   Ext.T (VFloat -> VFloat -> Value Word32 -> CodeGenFunction r VFloat)
dpps = Ext.intrinsic sse41 "dpps"

-- | Double-precision dot product (SSE4.1), see 'dpps'.
dppd ::
   Ext.T (VDouble -> VDouble -> Value Word32 -> CodeGenFunction r VDouble)
dppd = Ext.intrinsic sse41 "dppd"

{- |
SSE4.1 rounding.
The 'Word32' argument is presumably the rounding-control immediate.

NOTE(review): LLVM declares @llvm.x86.sse41.round.ss@ and @round.sd@
with two vector operands plus the immediate,
whereas 'roundss' and 'roundsd' pass a single vector.
Confirm against the targeted LLVM version.
-}
roundss :: Ext.T (VFloat -> Value Word32 -> CodeGenFunction r VFloat)
roundss = Ext.intrinsic sse41 "round.ss"

roundps :: Ext.T (VFloat -> Value Word32 -> CodeGenFunction r VFloat)
roundps = Ext.intrinsic sse41 "round.ps"

roundsd :: Ext.T (VDouble -> Value Word32 -> CodeGenFunction r VDouble)
roundsd = Ext.intrinsic sse41 "round.sd"

roundpd :: Ext.T (VDouble -> Value Word32 -> CodeGenFunction r VDouble)
roundpd = Ext.intrinsic sse41 "round.pd"



{-
Not an LLVM intrinsic, but implementation-specific:
We expect floating point values to be in IEEE format,
so the most significant bit is the sign.
The absolute value can then be computed very efficiently by clearing the sign bit.
(In fact, LLVM's codegen implements negation by an XOR on the sign bit.)
-}
-- | Absolute value of the lowest single-precision element.
-- The sign bit of element 0 is cleared by bit-masking in the vector register;
-- the other elements pass through unchanged.
absss :: Ext.T (VFloat -> CodeGenFunction r VFloat)
absss =
   Ext.wrap sse1 $ \x -> do
      bits <- LLVM.bitcastUnify x
      cleared <- A.and signMask bits
      LLVM.bitcastUnify cleared
   where
      -- all bits set, except bit 31 (the sign) of element 0
      signMask :: Value (Vector D4 Word32)
      signMask =
         LLVM.value $ constVector $ map constOf $
            clearBit (complement 0) 31 : repeat (complement 0)

{-
This function works on a single Float,
but I like to do the masking in an XMM register
because usually the value is there anyway.

absss =
   flip LLVM.extractelement (valueOf 0)
     . flip asTypeOf (undefined :: VFloat)
     <=< LLVM.bitcastUnify
--        <=< A.and (LLVM.value $ constVector [constOf 0x7FFFFFFF] :: Value (Vector D4 Word32))
--        <=< A.and (LLVM.value $ constVector [constOf 0x7FFFFFFF, LLVM.undef, LLVM.undef, LLVM.undef] :: Value (Vector D4 Word32))
     <=< A.and (LLVM.value $ constVector [constOf 0x7FFFFFFF, LLVM.zero, LLVM.zero, LLVM.zero] :: Value (Vector D4 Word32))
     <=< LLVM.bitcastUnify
     . flip asTypeOf (undefined :: VFloat)
     <=< flip (LLVM.insertelement (LLVM.value LLVM.undef)) (valueOf 0)
-}
{- This moves the value to a general purpose register and performs the bit masking there
absss =
   LLVM.bitcastUnify
     <=< A.and (valueOf 0x7FFFFFFF :: Value Word32)
     <=< LLVM.bitcastUnify
-}

-- | Absolute value of the lowest double-precision element,
-- computed by clearing the sign bit of element 0; element 1 passes through.
abssd :: Ext.T (VDouble -> CodeGenFunction r VDouble)
abssd =
   Ext.wrap sse2 $ \x -> do
      bits <- LLVM.bitcastUnify x
      cleared <- A.and signMask bits
      LLVM.bitcastUnify cleared
   where
      -- all bits set, except bit 63 (the sign) of element 0
      signMask :: Value (Vector D2 Word64)
      signMask =
         LLVM.value $ constVector $ map constOf $
            clearBit (complement 0) 63 : repeat (complement 0)

-- | Absolute value of all four single-precision elements,
-- computed by clearing bit 31 of every element.
absps :: Ext.T (VFloat -> CodeGenFunction r VFloat)
absps =
   Ext.wrap sse1 $ \x -> do
      bits <- LLVM.bitcastUnify x
      cleared <- A.and signMask bits
      LLVM.bitcastUnify cleared
   where
      signMask :: Value (Vector D4 Word32)
      signMask =
         LLVM.value $ constVector [constOf (clearBit (complement 0) 31)]

-- | Absolute value of both double-precision elements,
-- computed by clearing bit 63 of every element.
abspd :: Ext.T (VDouble -> CodeGenFunction r VDouble)
abspd =
   Ext.wrap sse2 $ \x -> do
      bits <- LLVM.bitcastUnify x
      cleared <- A.and signMask bits
      LLVM.bitcastUnify cleared
   where
      signMask :: Value (Vector D2 Word64)
      signMask =
         LLVM.value $ constVector [constOf (clearBit (complement 0) 63)]


{- |
Cumulative sum:
@(a,b,c,d) -> (a,a+b,a+b+c,a+b+c+d)@

This tries to use the horizontal add cleverly,
but the generic version in the Vector module is better.
-}
_cumulate1s :: Ext.T (VFloat -> CodeGenFunction r VFloat)
_cumulate1s =
   Ext.with haddps $ \hadd v -> do
      -- (a+b, c+d, _, _); the second operand is undefined
      pairs <- hadd v (LLVM.value LLVM.undef)
      -- (a, a+b, c, c+d)
      partial <-
         LLVM.shufflevector v pairs $
            constVector $ map constOf [0,4,2,5]
      -- (0, 0, a+b, a+b); indices 4 and 5 select from the zero vector
      carry <-
         LLVM.shufflevector pairs (LLVM.value LLVM.zero) $
            constVector $ map constOf [4,5,0,0]
      A.add partial carry