{-# LANGUAGE TypeFamilies #-}
module Synthesizer.LLVM.Random where
import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Core.Guided as Guided
import LLVM.Core
(CodeGenFunction, Value, Vector,
zext, trunc, lshr, valueOf)
import qualified LLVM.Core as LLVM
import qualified Type.Data.Num.Decimal as TypeNum
import qualified Data.NonEmpty.Class as NonEmptyC
import Data.Function.HT (nest)
import Data.Int (Int32)
import Data.Word (Word32, Word64)
-- | Multiplier of the multiplicative congruential generator:
-- every step maps a state @s@ to @factor * s \`mod\` 'modulus'@.
factor :: (Integral a) => a
factor = fromInteger 40692
-- | Modulus of the multiplicative congruential generator.
-- It is just below @2^31@ (@2^31 - 249@), so every reduced state
-- fits into 31 bits.
modulus :: (Integral a) => a
modulus = fromInteger 2147483399
-- | Divisor used to split a state into a high and a low part:
-- the smallest number whose product with 'factor' exceeds 'modulus'
-- (evaluates to 52775).
split :: Word32
split = div modulus factor + 1
-- | Residue of @factor * split@ modulo 'modulus'
-- (@factor * split - modulus@, evaluates to 36901).
-- It lets @factor * sHigh * split@ be replaced by @splitRem * sHigh@
-- under the modulus.
splitRem :: Word32
splitRem = factor * split - modulus
-- | Advance the scalar generator by one step.
-- The result is @factor * s \`mod\` modulus@, computed entirely in 'Word32'.
--
-- The state is decomposed as @s = sHigh*split + sLow@. Because
-- @factor*split = modulus + splitRem@, this gives
-- @factor*s ≡ splitRem*sHigh + factor*sLow (mod modulus)@.
--
-- The state is reduced modulo 'modulus' first. For reduced states the
-- intermediate sum stays below @2^32@ (at most about @3.65e9@). Without
-- this reduction, states near 'maxBound' overflowed the 'Word32' sum and
-- gave a result different from 'next64'. For all other states the
-- reduction does not change the result.
next :: Word32 -> Word32
next s =
  let (sHigh, sLow) = divMod (mod s modulus) split
  in  mod (splitRem * sHigh + factor * sLow) modulus
-- | Reference implementation of one generator step.
-- The product @factor * s@ is formed in 'Word64', where it cannot
-- overflow, then reduced modulo 'modulus' and narrowed back to 'Word32'.
next64 :: Word32 -> Word32
next64 s =
  let wide = fromIntegral s :: Word64
      reduced = mod (factor * wide) modulus
  in  fromIntegral reduced
-- | LLVM code for one generator step using only 32-bit arithmetic.
-- It uses the same split decomposition as 'next'.
--
-- The state is first reduced modulo 'modulus'. This keeps
-- @splitRem*sHigh + factor*sLow@ below @2^32@. Without it, states close
-- to @maxBound@ wrapped around and the result disagreed with
-- 'nextCG64' and 'nextCG'.
nextCG32 :: Value Word32 -> CodeGenFunction r (Value Word32)
nextCG32 s = do
  state <- A.irem s (valueOf modulus)
  sHigh <- A.mul (valueOf splitRem) =<< LLVM.idiv state (valueOf split)
  sLow <- A.mul (valueOf factor) =<< LLVM.irem state (valueOf split)
  flip A.irem (valueOf modulus) =<< A.add sHigh sLow
-- | LLVM code for one generator step via a 64-bit product.
-- The state is zero-extended, multiplied by 'factor', reduced modulo
-- 'modulus' in 64 bits, then truncated back to 32 bits.
nextCG64 :: Value Word32 -> CodeGenFunction r (Value Word32)
nextCG64 s = do
  wide <- zext s
  product64 <- A.mul (valueOf factor) wide
  reduced <- A.irem product64 (valueOf (modulus :: Word64))
  trunc reduced
-- | LLVM code for one generator step without a division instruction.
--
-- The 64-bit product @x = factor * s@ is split as @x = high*2^31 + low@.
-- Since @2^31 ≡ 2^31 - modulus (mod modulus)@, it follows that
-- @x ≡ low + (2^31 - modulus)*high@. That sum is below @2*modulus@, so one
-- conditional subtraction of 'modulus' yields the reduced state.
nextCG :: Value Word32 -> CodeGenFunction r (Value Word32)
nextCG s = do
  let twoPow31 = 2 ^ (31 :: Int)
      foldFactor = twoPow31 - modulus
  wide <- zext s
  prod <- A.mul (valueOf (factor :: Word64)) wide
  -- lower 31 bits of the product
  prodLow32 <- trunc prod
  lowBits <- A.and (valueOf (twoPow31 - 1)) prodLow32
  -- bits from position 31 upwards
  shifted <- lshr prod (valueOf (31 :: Word64))
  highBits <- trunc shifted
  folded <- A.mul (valueOf foldFactor) highBits
  combined <- A.add lowBits folded
  subtractIfPossible (valueOf modulus) combined
-- | Multiplier that advances the generator by @n@ steps at once.
-- It is obtained by applying 'next' @n@ times to @1@, which gives
-- @factor^n \`mod\` modulus@, and is converted to the requested integral type.
vectorParameter ::
  (Integral a) =>
  Int -> a
vectorParameter n =
  let advanced = nest n next 1
  in  fromIntegral advanced
-- | Build the initial state vector from the stream
-- @seed, next seed, next (next seed), ...@.
-- NOTE(review): presumably lane @i@ receives the @i@-th stream element;
-- this depends on 'LLVM.cyclicVector' filling lanes in order.
vectorSeed ::
  (TypeNum.Positive n) =>
  Word32 -> Vector n Word32
vectorSeed = LLVM.cyclicVector . NonEmptyC.iterate next
-- | Identity function restricted to vectors of 64-bit words.
-- Presumably used to pin down an otherwise ambiguous type.
vector64 :: Value (Vector n Word64) -> Value (Vector n Word64)
vector64 x = x
-- | Advance every lane of a state vector by as many steps as there are
-- lanes, without division.
--
-- Each lane is multiplied by @vectorParameter size@ in 64 bits. The product
-- is then folded twice using @2^31 ≡ 2^31 - modulus (mod modulus)@. After
-- the second fold the sum is below @2*modulus@, so one conditional
-- subtraction finishes the reduction.
nextVector ::
  (TypeNum.Positive n) =>
  Value (Vector n Word32) ->
  CodeGenFunction r (Value (Vector n Word32))
nextVector s = do
  let multiplier = SoV.replicateOf (vectorParameter (Vector.size s))
      fac :: Integral a => a
      fac = 2 ^ (31 :: Int) - modulus
  -- first fold: product split at bit 31
  (low0, high0) <- splitVector31 =<< umul32to64 multiplier s
  carry <- umul32to64 (SoV.replicateOf fac) high0
  low0Wide <- Vector.map zext low0
  -- second fold: the high part is now tiny
  (low1, high1) <- splitVector31 =<< A.add carry low0Wide
  folded <- Vector.mul (SoV.replicateOf fac) high1
  subtractIfPossible (SoV.replicateOf modulus) =<< A.add low1 folded
-- | Return @x - d@ if that is smaller than @x@, otherwise @x@.
--
-- For unsigned element types this is a conditional subtraction: when
-- @x < d@ the difference wraps around to a large value and 'SoV.min'
-- keeps @x@. NOTE(review): for signed types with positive @d@ the
-- difference is always the minimum, so it is subtracted unconditionally.
-- All callers in view pass unsigned 'Word32' values.
subtractIfPossible ::
  (SoV.Real a) =>
  Value a -> Value a -> CodeGenFunction r (Value a)
subtractIfPossible d x =
  A.sub x d >>= SoV.min x
-- | Lane-wise choice: take the lane of @x@ where it is non-negative,
-- and the lane of @y@ otherwise.
selectNonNegativeGeneric ::
  (TypeNum.Positive n) =>
  Value (Vector n Int32) ->
  Value (Vector n Int32) ->
  CodeGenFunction r (Value (Vector n Int32))
selectNonNegativeGeneric x y =
  A.cmp LLVM.CmpGE x A.zero >>= \isNonNegative ->
  LLVM.select isNonNegative x y
-- | Split each 64-bit lane at bit 31.
-- The first result holds the lower 31 bits. The second holds the value
-- shifted right by 31 and truncated to 32 bits, so bits from position 63
-- upwards are dropped.
splitVector31 ::
  (TypeNum.Positive n) =>
  Value (Vector n Word64) ->
  CodeGenFunction r (Value (Vector n Word32), Value (Vector n Word32))
splitVector31 x = do
  let lowMask = SoV.replicateOf (2 ^ (31 :: Int) - 1)
      shiftAmount = SoV.replicateOf (31 :: Word64) `asTypeOf` x
  truncated <- Vector.map trunc x
  low <- A.and lowMask truncated
  shifted <- lshr x shiftAmount
  high <- Vector.map trunc shifted
  return (low, high)
-- | Advance every lane by as many steps as there are lanes.
-- This is the reference variant: a 64-bit product reduced with an
-- integer remainder, then truncated back to 32 bits.
nextVector64 ::
  (TypeNum.Positive n) =>
  Value (Vector n Word32) ->
  CodeGenFunction r (Value (Vector n Word32))
nextVector64 s = do
  let multiplier = SoV.replicateOf (vectorParameter (Vector.size s))
  wide <- umul32to64 multiplier s
  reduced <- A.irem wide (SoV.replicateOf modulus)
  Vector.map trunc reduced
-- | Lane-wise multiplication of two 32-bit vectors into a 64-bit result.
-- Both operands are first widened with 'Guided.ext'.
-- NOTE(review): presumably this zero-extends for unsigned 'Word32',
-- so the product is exact.
umul32to64 ::
  (TypeNum.Positive n) =>
  Value (Vector n Word32) ->
  Value (Vector n Word32) ->
  CodeGenFunction r (Value (Vector n Word64))
umul32to64 x y =
  Guided.ext Guided.vector x >>= \xWide ->
  Guided.ext Guided.vector y >>= \yWide ->
  A.mul xWide yWide