module Synthesizer.LLVM.Random where
import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Extension.X86 as X86
import qualified LLVM.Extra.Extension as Ext
import qualified LLVM.Extra.Arithmetic as A
import LLVM.Core
(CodeGenFunction, Value, Vector,
zext, trunc, lshr, value, valueOf,
undef, constOf, constVector, bitcast, )
import qualified LLVM.Core as LLVM
import qualified Type.Data.Num.Decimal as TypeNum
import qualified Data.NonEmpty.Class as NonEmptyC
import qualified Data.Empty as Empty
import Data.NonEmpty ((!:), )
import Data.Function.HT (nest, )
import Data.Int (Int32, )
import Data.Word (Word32, Word64, )
-- | Multiplier of the congruential generator implemented by 'next':
-- one step maps a state @s@ to @factor * s@ modulo 'modulus'.
factor :: Integral a => a
factor = 40692

-- | Modulus of the congruential generator.
modulus :: Integral a => a
modulus = 2147483399

-- | Divisor used by 'next' to decompose a state as
-- @s = split * sHigh + sLow@ with @sLow < split@.
-- It is one more than @modulus \`div\` factor@, hence
-- @split * factor@ just exceeds 'modulus'.
split :: Word32
split = succ $ div modulus factor

-- | The excess @split * factor - modulus@ (= 36901).
--
-- Since @split * factor = modulus + splitRem@, the term
-- @factor * split * sHigh@ may be replaced by @splitRem * sHigh@ without
-- changing the result modulo 'modulus'.
-- @split * factor@ (2147520300) fits in a 'Word32', so no overflow occurs.
splitRem :: Word32
splitRem = split * factor - modulus
-- | One generator step on the host, @factor * s@ modulo 'modulus', using
-- only 32-bit arithmetic.
--
-- The state is split at 'split'.  The high part is multiplied by
-- 'splitRem' instead of @factor * split@, which is congruent modulo
-- 'modulus'.  For @s < modulus@ the intermediate sum stays below
-- about 3.65e9 and therefore fits in a 'Word32'.
next :: Word32 -> Word32
next s = (splitRem * hi + factor * lo) `mod` modulus
   where
      (hi, lo) = s `divMod` split
-- | Same step as 'next', but it widens to 'Word64' and takes the remainder
-- of the full product directly.
next64 :: Word32 -> Word32
next64 s =
   let wide = fromIntegral s :: Word64
   in  fromIntegral ((factor * wide) `mod` modulus)
-- | Code-generating counterpart of 'next'.
--
-- It uses the same 32-bit decomposition: the quotient by 'split' is scaled
-- by 'splitRem', the remainder is scaled by 'factor', and the sum is
-- reduced modulo 'modulus'.
nextCG32 :: Value Word32 -> CodeGenFunction r (Value Word32)
nextCG32 s = do
   quotient <- LLVM.idiv s (valueOf split)
   hiPart <- A.mul (valueOf splitRem) quotient
   remainder <- LLVM.irem s (valueOf split)
   loPart <- A.mul (valueOf factor) remainder
   total <- A.add hiPart loPart
   A.irem total (valueOf modulus)
-- | Code-generating counterpart of 'next64'.
--
-- It zero-extends the state to 64 bits, multiplies by 'factor', takes the
-- 64-bit remainder by 'modulus' and truncates back to 32 bits.
nextCG64 :: Value Word32 -> CodeGenFunction r (Value Word32)
nextCG64 s = do
   wide <- zext s
   product64 <- A.mul (valueOf factor) wide
   reduced <- A.irem product64 (valueOf (modulus :: Word64))
   trunc reduced
-- | Generator step without a division instruction.
--
-- The 64-bit product @x = factor * s@ is split into
-- @low = x .&. (2^31-1)@ and @high = x \`shiftR\` 31@.
-- Because @2^31 = modulus + fac@ with @fac = 2^31 - modulus@ (= 249),
-- @x@ is congruent to @low + fac * high@.  That value is below
-- @2 * modulus@ for any 32-bit @s@, so a single conditional subtraction
-- ('subtractIfPossible') completes the reduction.
nextCG :: Value Word32 -> CodeGenFunction r (Value Word32)
nextCG s = do
   x <- A.mul (valueOf $ factor :: Value Word64) =<< zext s
   let p2e31 = 2^(31::Int) :: Word32
   -- keep the lowest 31 bits of the product
   low <- A.and (valueOf $ p2e31 - 1) =<< trunc x
   high <- trunc =<< flip lshr (valueOf (31 :: Word64)) x
   -- 2^31 reduced modulo 'modulus'
   let fac = p2e31 - modulus
   subtractIfPossible (valueOf modulus)
      =<< A.add low
      =<< A.mul (valueOf fac) high
-- | Multiplier that advances a state by @n@ generator steps at once.
--
-- It is obtained by applying 'next' @n@ times to 1, i.e.
-- @factor^n@ modulo 'modulus', and is then converted to the requested
-- integral type.
vectorParameter ::
   Integral a =>
   Int -> a
vectorParameter n = fromIntegral stepped
   where
      stepped = nest n next 1
-- | Initial vector state.
--
-- Lane @i@ holds @seed@ advanced by @i@ steps of 'next', so the lanes are
-- consecutive states of the scalar generator.
vectorSeed ::
   (TypeNum.Positive n) =>
   Word32 -> Vector n Word32
vectorSeed = LLVM.cyclicVector . NonEmptyC.iterate next
-- | Identity function whose only purpose is to pin a vector value to
-- element type 'Word64' at a use site.
vector64 :: Value (Vector n Word64) -> Value (Vector n Word64)
vector64 v = v
-- | Advance every lane of a vector state by @Vector.size s@ generator steps.
--
-- If the x86 extension required by 'nextVector4X86' is available, the
-- vector is processed in 4-lane chunks with it.  Otherwise it falls back to
-- 'nextVectorGeneric'.  The whole vector's size is passed to each chunk,
-- because every lane must skip over the states held by all other lanes.
nextVector ::
   (TypeNum.Positive n) =>
   Value (Vector n Word32) ->
   CodeGenFunction r (Value (Vector n Word32))
nextVector s =
   let viaX86 =
          Ext.with nextVector4X86 $ \advance4 ->
             Vector.mapChunks (advance4 (Vector.size s)) s
   in  Ext.run (nextVectorGeneric s) viaX86
-- | x86 implementation of the vector step for one 4-lane chunk.
--
-- The 'Int' argument is the number of lanes of the whole state vector.
-- The chunk is advanced by @vectorParameter n@ steps.
--
-- 'X86.pmuludq' is assumed to be the SSE2 unsigned 32x32->64-bit multiply
-- of lanes 0 and 2; confirm against "LLVM.Extra.Extension.X86".  The
-- constants and inputs are therefore placed in lanes 0 and 2
-- ('prepConstFactor', 'shuffleHoles'), and even and odd lanes are
-- processed separately.
--
-- Each 64-bit product is reduced twice with @2^31 = fac@ (mod 'modulus')
-- in 'mulAndReduce'.  A final 32-bit fold and a conditional subtraction
-- finish the reduction.
nextVector4X86 ::
   Ext.T
      (Int ->
       Value (Vector TypeNum.D4 Word32) ->
       CodeGenFunction r (Value (Vector TypeNum.D4 Word32)))
nextVector4X86 =
   Ext.with X86.pmuludq $ \muludq n s -> do
      -- 2^31 reduced modulo 'modulus' (= 249)
      let fac = 2^(31::Int) - modulus :: Word32
          mulAndReduce x = do
             (low0, high0) <-
                splitVector31to64 =<<
                muludq (prepConstFactor (vectorParameter n)) x
             splitVector31to64 =<<
                A.add low0 =<<
                muludq (prepConstFactor fac) =<<
                bitcast high0
      (lowEven, highEven) <- mulAndReduce =<< shuffleHoles s 0 2
      (lowOdd, highOdd) <- mulAndReduce =<< shuffleHoles s 1 3
      low <- truncAndInterleave2x64to4x32 lowEven lowOdd
      high <- truncAndInterleave2x64to4x32 highEven highOdd
      prodMod <-
         A.add low =<<
         Vector.mul (SoV.replicateOf fac) high
      -- unsigned min(x - modulus, x): subtract once if x >= modulus
      prodModS <- A.sub prodMod (SoV.replicateOf modulus)
      Vector.min prodModS prodMod
-- | Truncate two 2x64-bit vectors to 32 bits and interleave the results as
-- @[even0, odd0, even1, odd1]@.
--
-- Both inputs are reinterpreted as 4x32.  The shuffle then picks lanes 0
-- and 2 of each, which assumes the low word of a 64-bit element comes first
-- (a little-endian target).  This function is used by the x86 path.
truncAndInterleave2x64to4x32 ::
   Value (Vector TypeNum.D2 Word64) ->
   Value (Vector TypeNum.D2 Word64) ->
   CodeGenFunction r (Value (Vector TypeNum.D4 Word32))
truncAndInterleave2x64to4x32 evens odds = do
   evenWords <- bitcast evens
   oddWords <- bitcast odds
   Vector.shuffleMatchPlain2 evenWords oddWords $
      constVector $ fmap constOf $ 0 !: 4 !: 2 !: 6 !: Empty.Cons
-- | x86 implementation of the vector step for a 2-lane vector.
--
-- The 'Int' argument is the number of lanes of the whole state vector.
-- The two lanes are spread to positions 0 and 2, where 'X86.pmuludq'
-- (assumed: SSE2 unsigned 32x32->64 multiply of lanes 0 and 2) reads them.
--
-- The 64-bit product is reduced twice with @2^31 = fac@ (mod 'modulus').
-- The result is conditionally reduced once more by 'modulus'.  Lanes 0 and
-- 2 are then gathered back into a 2-lane vector.
nextVector2X86 ::
   Ext.T
      (Int ->
       Value (Vector TypeNum.D2 Word32) ->
       CodeGenFunction r (Value (Vector TypeNum.D2 Word32)))
nextVector2X86 =
   Ext.with X86.pmuludq $ \muludq n s -> do
      (low0, high0) <-
         splitVector31to64 =<<
         muludq (prepConstFactor (vectorParameter n)) =<<
         shuffleHoles s 0 1
      -- 2^31 reduced modulo 'modulus' (= 249)
      let fac = 2^(31::Int) - modulus :: Word32
      (low1, high1) <-
         splitVector31to64 =<<
         A.add low0 =<<
         muludq (prepConstFactor fac) =<<
         bitcast high0
      prodMod64 <-
         A.add low1 =<<
         muludq (prepConstFactor fac) =<<
         bitcast high1
      prodMod <- bitcast prodMod64
      -- unsigned min(x - modulus, x): subtract once if x >= modulus
      prodModS <- A.sub prodMod (prepConstFactor modulus)
      result <- Vector.min prodModS prodMod
      LLVM.shufflevector
         (result :: Value (Vector TypeNum.D4 Word32))
         (LLVM.value LLVM.undef)
         (constVector $ fmap constOf $ 0!:2!:Empty.Cons)
-- | Constant 4-lane vector with @x@ in lanes 0 and 2 and @undef@ in lanes
-- 1 and 3.  This matches the lanes read by the pmuludq-based code.
prepConstFactor :: Word32 -> Value (Vector TypeNum.D4 Word32)
prepConstFactor x =
   let c = constOf x
   in  value $ constVector $ c !: undef !: c !: undef !: Empty.Cons
-- | Move lanes @j@ and @k@ of @s@ to positions 0 and 2 of a 4-lane vector.
-- Positions 1 and 3 are left @undef@.
shuffleHoles ::
   (TypeNum.Positive n) =>
   Value (Vector n Word32) ->
   Word32 -> Word32 ->
   CodeGenFunction r (Value (Vector TypeNum.D4 Word32))
shuffleHoles s j k =
   LLVM.shufflevector s (value undef) mask
   where
      mask =
         constVector $
            constOf j !: undef !: constOf k !: undef !: Empty.Cons
-- | Split each 64-bit lane @x@ into its lowest 31 bits and the remaining
-- bits, @x \`shiftR\` 31@.  Both parts stay 64 bits wide.
splitVector31to64 ::
   (TypeNum.Positive n) =>
   Value (Vector n Word64) ->
   CodeGenFunction r (Value (Vector n Word64), Value (Vector n Word64))
splitVector31to64 x = do
   -- mask of the lowest 31 bits: 2^31 - 1
   low <- A.and (SoV.replicateOf (2^(31::Int) - 1)) x
   high <- flip lshr (SoV.replicateOf 31 `asTypeOf` x) x
   return (low, high)
-- | Portable vector step: advance every lane by @Vector.size s@ generator
-- steps.
--
-- The 64-bit product with 'vectorParameter' is split at bit 31 and folded
-- with @2^31 = fac@ (mod 'modulus').  This is done twice.  For lanes below
-- 'modulus', the final 32-bit sum @low1 + fac * high1@ is less than
-- @2 * modulus@, so one 'subtractIfPossible' completes the reduction.
nextVectorGeneric ::
   (TypeNum.Positive n) =>
   Value (Vector n Word32) ->
   CodeGenFunction r (Value (Vector n Word32))
nextVectorGeneric s = do
   (low0, high0) <-
      splitVector31 =<<
      Vector.umul32to64 (SoV.replicateOf (vectorParameter (Vector.size s))) s
   -- 2^31 reduced modulo 'modulus' (= 249)
   let fac :: Integral a => a
       fac = 2^(31::Int) - modulus
   (low1, high1) <-
      splitVector31 =<<
      (\x -> A.add x =<< Vector.map zext low0) =<<
      Vector.umul32to64 (SoV.replicateOf fac) high0
   subtractIfPossible (SoV.replicateOf modulus)
      =<< A.add low1
      =<< Vector.mul (SoV.replicateOf fac) high1
-- | Compute @min x (x - d)@.
--
-- For the unsigned 'Word32' operands this module passes in, @x - d@ wraps
-- around to a large value when @x < d@, so @x@ is kept.  Otherwise @d@ is
-- subtracted once.  This yields @x \`mod\` d@ whenever @x < 2*d@.
subtractIfPossible ::
   (SoV.Real a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
subtractIfPossible d x = A.sub x d >>= SoV.min x
-- | Lane-wise choice: take @x@ where @x >= 0@, otherwise take @y@.
selectNonNegativeGeneric ::
   (TypeNum.Positive n) =>
   Value (Vector n Int32) ->
   Value (Vector n Int32) ->
   CodeGenFunction r (Value (Vector n Int32))
selectNonNegativeGeneric x y =
   A.cmp LLVM.CmpGE x A.zero >>= \isNonNeg ->
   Vector.select isNonNeg x y
-- | Split each 64-bit lane @x@ into 32-bit parts: the lowest 31 bits, and
-- @x \`shiftR\` 31@ truncated to 32 bits.
splitVector31 ::
   (TypeNum.Positive n) =>
   Value (Vector n Word64) ->
   CodeGenFunction r (Value (Vector n Word32), Value (Vector n Word32))
splitVector31 x = do
   -- mask of the lowest 31 bits: 2^31 - 1
   low <- A.and (SoV.replicateOf (2^(31::Int) - 1)) =<< Vector.map trunc x
   high <- Vector.map trunc =<< flip lshr (SoV.replicateOf (31 :: Word64) `asTypeOf` x) x
   return (low, high)
-- | Vector step using a 64-bit remainder instead of the 31-bit folding.
--
-- Each lane is multiplied by @vectorParameter (Vector.size s)@ into 64
-- bits, reduced with a 64-bit remainder by 'modulus', and truncated back to
-- 32 bits.
nextVector64 ::
   (TypeNum.Positive n) =>
   Value (Vector n Word32) ->
   CodeGenFunction r (Value (Vector n Word32))
nextVector64 s = do
   wide <- Vector.umul32to64 (SoV.replicateOf (vectorParameter (Vector.size s))) s
   reduced <- A.irem wide (SoV.replicateOf modulus)
   Vector.map trunc reduced