{-# LANGUAGE TypeFamilies #-}
{- |
Very simple random number generator according to Knuth
which should be fast and should suffice for generating just noise.
<http://www.softpanorama.org/Algorithms/random_generators.shtml>
-}
module Synthesizer.LLVM.Random where

import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Vector as Vector

import qualified LLVM.Extra.Extension.X86 as X86
import qualified LLVM.Extra.Extension as Ext

import qualified LLVM.Extra.Arithmetic as A

import LLVM.Core
          (CodeGenFunction, Value, Vector,
           zext, trunc, lshr, value, valueOf,
           undef, constOf, constVector, bitcast, )
import qualified LLVM.Core as LLVM
import qualified Type.Data.Num.Decimal as TypeNum

import qualified Data.NonEmpty.Class as NonEmptyC
import qualified Data.Empty as Empty
import Data.NonEmpty ((!:), )
import Data.Function.HT (nest, )

import Data.Int (Int32, )
import Data.Word (Word32, Word64, )


{- |
Multiplier of the generator:
one step maps a state @s@ to @mod (s*factor) modulus@ (see 'next').
Polymorphic, such that it can be used at 32 bit and 64 bit width alike.
-}
factor :: Integral a => a
factor = 40692

{- |
Modulus of the generator, slightly below @2^31@.
The reduction schemes in this module rely on
@mod (2^31) modulus = 2^31 - modulus = 249@.
-}
modulus :: Integral a => a
modulus = 2147483399 -- 2^31-249

{-
A 32 bit state is decomposed into quotient and remainder with respect to 'split'
in order to avoid overflow on multiplication.
'split' must be chosen, such that 'splitRem' is below 2^16.
It is the smallest number with @split*factor > modulus@.
-}
split :: Word32
split = div modulus factor + 1

{- |
Excess of @split*factor@ over 'modulus',
that is, @split*factor@ is congruent to 'splitRem' modulo 'modulus'.
-}
splitRem :: Word32
splitRem = factor * split - modulus


{- |
Efficient computation of @mod (s*factor) modulus@
without Integer or Word64, as in 'next64'.

The state is decomposed as @s = sHigh*split + sLow@.
Because of @split*factor = modulus + splitRem@ it is
@s*factor = sHigh*splitRem + sLow*factor@ modulo 'modulus'.

The argument is reduced modulo 'modulus' before it is split.
For @s < modulus@ this changes nothing,
but it guarantees @sHigh <= 40691@ and thus
@splitRem*sHigh + factor*sLow < 2^32@.
Without the reduction the sum wraps around for large seeds
(e.g. @s = 2^32-1@) and the result differs from 'next64'.
-}
next :: Word32 -> Word32
next s =
   let (sHigh, sLow) = divMod (mod s modulus) split
   in  flip mod modulus $
       splitRem*sHigh + factor*sLow

{- |
Reference implementation of one generator step:
The product @s*factor@ is computed exactly in 64 bit
and then reduced modulo 'modulus'.
-}
next64 :: Word32 -> Word32
next64 s =
   let wide :: Word64
       wide = fromIntegral s
   in  fromIntegral (mod (factor * wide) modulus)

{- |
Code generator counterpart of 'next':
It emits the same quotient/remainder decomposition
with respect to 'split' using 32 bit arithmetic only.
-}
nextCG32 :: Value Word32 -> CodeGenFunction r (Value Word32)
nextCG32 s = do
   quotient <- LLVM.idiv s (valueOf split)
   sHigh <- A.mul (valueOf splitRem) quotient
   remainder <- LLVM.irem s (valueOf split)
   sLow <- A.mul (valueOf factor) remainder
   total <- A.add sHigh sLow
   A.irem total (valueOf modulus)

{- |
Code generator counterpart of 'next64':
Extend the state to 64 bit, multiply by 'factor',
take the remainder with respect to 'modulus'
and truncate back to 32 bit.
-}
nextCG64 :: Value Word32 -> CodeGenFunction r (Value Word32)
nextCG64 s = do
   wide <- zext s
   prod <- A.mul (valueOf factor) wide
   {-
   The remainder is slow on x86 since the native @div@ is not used
   since LLVM wants to prevent overflow.
   We know that there cannot be an overflow,
   but I do not know how to tell LLVM.
   -}
   reduced <- A.irem prod (valueOf (modulus :: Word64))
   trunc reduced

{- |
Generator step that avoids any division:
The exact 64 bit product @factor*s@ is split at bit 31
and the upper part is folded back using @mod (2^31) modulus@.
-}
nextCG :: Value Word32 -> CodeGenFunction r (Value Word32)
nextCG s = do
   wide <- zext s
   x <- A.mul (valueOf factor :: Value Word64) wide
   {-
   Split the 64 bit result between bit 30 and bit 31.
   We cannot split above bit 31,
   since then 'low' can be up to 2^32-1
   and then the later addition overflows.
   -}
   let twoTo31 = 2^(31::Int)
   lowBits <- trunc x
   low <- A.and (valueOf $ twoTo31-1) lowBits
   shifted <- lshr x (valueOf (31 :: Word64))
   high <- trunc shifted
   -- fac = mod (2^31) modulus
   let fac = twoTo31 - modulus
   {-
   fac < 250
   high < factor
   fac*high < factor*250
   low < 2^31
   low + fac*high
      < 2^31 + factor*250
      < 2*modulus
   Thus reduction modulo 'modulus' needs at most one subtraction.
   -}
   folded <- A.mul (valueOf fac) high
   candidate <- A.add low folded
   subtractIfPossible (valueOf modulus) candidate


{-
How to vectorise?
E.g. by repeated distribution of modulus and split at bit 31.
Can we replace div by modulus by mul with (2^31+249) ?
-}
{- |
Multiplier for advancing a state by @n@ steps at once.
It is obtained by applying 'next' @n@ times to the state 1.
-}
vectorParameter ::
   Integral a =>
   Int -> a
vectorParameter stepCount =
   let advanced :: Word32
       advanced = nest stepCount next 1
   in  fromIntegral advanced

{- |
Initial vector of generator states:
The elements are taken from the sequence
@seed, next seed, next (next seed), ...@
-}
vectorSeed ::
   (TypeNum.Positive n) =>
   Word32 -> Vector n Word32
vectorSeed =
   LLVM.cyclicVector . NonEmptyC.iterate next
-- alternatively: vector . NonEmptyC.iterate next

{- |
Identity function that only serves
for fixing the element type of a vector value to 'Word64'.
-}
vector64 :: Value (Vector n Word64) -> Value (Vector n Word64)
vector64 v = v

{- |
Advance all states of a vector by as many steps as the vector has elements.
If the extension required by 'nextVector4X86' is available,
then the vector is processed chunk-wise by that implementation,
otherwise 'nextVectorGeneric' is used.
-}
nextVector ::
   (TypeNum.Positive n) =>
   Value (Vector n Word32) ->
   CodeGenFunction r (Value (Vector n Word32))
nextVector s =
   let stepCount = Vector.size s
       accelerated =
          Ext.with nextVector4X86 $ \nextChunk ->
             Vector.mapChunks (nextChunk stepCount) s
   in  Ext.run (nextVectorGeneric s) accelerated

{- |
This needs only a third of the code of nextVectorGeneric for Vector D4
(37 instructions vs. 110 instructions)
because it arranges data more sensibly:
It de-interleaves the vector and truncates from 64 bit to 32 bit in-place.

The 'Int' argument is the number of steps
by which every state is advanced (see 'vectorParameter').
-}
nextVector4X86 ::
   Ext.T
      (Int ->
       Value (Vector TypeNum.D4 Word32) ->
       CodeGenFunction r (Value (Vector TypeNum.D4 Word32)))
nextVector4X86 =
   Ext.with X86.pmuludq $ \muludq n s -> do
   let fac = 2^(31::Int) - modulus
       stepFactor = prepConstFactor (vectorParameter n)
       foldFactor = prepConstFactor fac

       -- multiply two states by the step factor
       -- and fold the part above bit 31 back once
       mulAndReduce lanes = do
          prod0 <- muludq stepFactor lanes
          (low0, high0) <- splitVector31to64 prod0
          high0as32 <- bitcast high0
          folded0 <- muludq foldFactor high0as32
          splitVector31to64 =<< A.add low0 folded0

   evenLanes <- shuffleHoles s 0 2
   (lowEven, highEven) <- mulAndReduce evenLanes

   oddLanes <- shuffleHoles s 1 3
   (lowOdd, highOdd) <- mulAndReduce oddLanes

   low  <- truncAndInterleave2x64to4x32 lowEven  lowOdd
   high <- truncAndInterleave2x64to4x32 highEven highOdd

   -- Vector.mul is more efficient for Word32 on x86 than LLVM-2.6's mul
   folded1 <- Vector.mul (SoV.replicateOf fac) high
   prodMod <- A.add low folded1
   prodModS <- A.sub prodMod (SoV.replicateOf modulus)

   {-
   Subtraction of the modulus should make an element smaller.
   If the element becomes greater instead, then the subtraction wrapped around
   and 'min' selects the value before subtraction.
   -}
   Vector.min prodModS prodMod

{- |
Reinterpret both 2x64 bit vectors as 4x32 bit vectors
and select the elements 0 and 2 of each of them in alternating order:
@even[0], odd[0], even[2], odd[2]@.
-}
truncAndInterleave2x64to4x32 ::
   Value (Vector TypeNum.D2 Word64) ->
   Value (Vector TypeNum.D2 Word64) ->
   CodeGenFunction r (Value (Vector TypeNum.D4 Word32))
truncAndInterleave2x64to4x32 even2x64 odd2x64 = do
   evenWords <- bitcast even2x64
   oddWords  <- bitcast odd2x64
   -- indices 0..3 address the first operand, 4..7 the second one
   let selection = constVector $ fmap constOf $ 0 !: 4 !: 2 !: 6 !: Empty.Cons
   Vector.shuffleMatchPlain2 evenWords oddWords selection


{-
Variant of the pmuludq based generator step for vectors of two states.
The 'Int' argument is the number of steps
by which every state is advanced (see 'vectorParameter').

The two states are moved to the elements 0 and 2 of a four-element vector,
multiplied by the step factor and then
the part above bit 31 is folded back two times using @fac = 2^31 - modulus@.
A final conditional subtraction of 'modulus' follows.

This will access MMX registers.
-}
nextVector2X86 ::
   Ext.T
      (Int ->
       Value (Vector TypeNum.D2 Word32) ->
       CodeGenFunction r (Value (Vector TypeNum.D2 Word32)))
nextVector2X86 =
   Ext.with X86.pmuludq $ \muludq n s -> do
   -- first product, split at bit 31
   -- NOTE(review): presumably pmuludq multiplies the 32 bit elements 0 and 2
   -- to two 64 bit products; the undef holes in 'prepConstFactor'
   -- and 'shuffleHoles' suggest this - confirm against the X86 binding.
   (low0, high0) <-
      splitVector31to64 =<<
      muludq (prepConstFactor (vectorParameter n)) =<<
      shuffleHoles s 0 1
   -- fac = mod (2^31) modulus
   let fac = 2^(31::Int) - modulus
   -- first folding: low0 + fac*high0, split at bit 31 again
   (low1, high1) <-
      splitVector31to64 =<<
      A.add low0 =<<
      muludq (prepConstFactor fac) =<<
      bitcast high0

   -- second folding: low1 + fac*high1, still held in 64 bit elements
   prodMod64 <-
      A.add low1 =<<
      muludq (prepConstFactor fac) =<<
      bitcast high1

-- disabled alternative: truncate to 32 bit before the subtraction
--   prodMod <- Vector.map trunc prodMod64
--   prodModS <- A.sub prodMod (SoV.replicateOf modulus)
--   Vector.min prodModS prodMod

{-
   disabled alternative: extract the 32 bit elements 0 and 2 before the subtraction

   prodMod64as32 <- bitcast prodMod64
   prodMod <- Vector.shuffle
      (prodMod64as32 :: Value (Vector TypeNum.D4 Word32))
      (constVector $ fmap constOf $ 0!:2!:Empty.Cons)

   prodModS <- A.sub prodMod (SoV.replicateOf modulus)
-}

   -- View the two 64 bit sums as four 32 bit elements.
   -- The subtrahend is only defined at the elements 0 and 2,
   -- and only these elements are extracted at the end.
   prodMod <- bitcast prodMod64
   prodModS <- A.sub prodMod (prepConstFactor modulus)

   {-
   An element should become smaller by subtraction.
   If it becomes greater, then there was an overflow
   and 'min' chooses the value before subtraction.
   -}
   result <- Vector.min prodModS prodMod
   -- the type annotation fixes the result type of the bitcast above;
   -- the shuffle picks the elements 0 and 2 as the two result states
   LLVM.shufflevector
      (result :: Value (Vector TypeNum.D4 Word32))
      (LLVM.value LLVM.undef)
      (constVector $ fmap constOf $ 0!:2!:Empty.Cons)

{- |
Constant four-element vector
that contains the given number at the positions 0 and 2
and undefined values at the positions 1 and 3.
-}
prepConstFactor :: Word32 -> Value (Vector TypeNum.D4 Word32)
prepConstFactor x =
   let lane = constOf x
   in  value (constVector (lane !: undef !: lane !: undef !: Empty.Cons))

{- |
@shuffleHoles s j k@ moves the elements @j@ and @k@ of @s@
to the positions 0 and 2 of a four-element vector.
The positions 1 and 3 remain undefined.
-}
shuffleHoles ::
   (TypeNum.Positive n) =>
   Value (Vector n Word32) ->
   Word32 -> Word32 ->
   CodeGenFunction r (Value (Vector TypeNum.D4 Word32))
shuffleHoles s j k =
   let mask =
          constVector
             (constOf j !: undef !: constOf k !: undef !: Empty.Cons)
   in  LLVM.shufflevector s (value undef) mask

{- |
Split every 64 bit element at bit 31:
The first result contains the lowest 31 bits,
the second result contains the remaining upper bits shifted down by 31.
Both results stay 64 bit wide.
-}
splitVector31to64 ::
   (TypeNum.Positive n) =>
   Value (Vector n Word64) ->
   CodeGenFunction r (Value (Vector n Word64), Value (Vector n Word64))
splitVector31to64 x = do
   let shiftAmount = SoV.replicateOf 31 `asTypeOf` x
   low  <- A.and (SoV.replicateOf (2^(31::Int)-1)) x
   high <- lshr x shiftAmount
   return (low, high)

{-
In case of a vector random generator the factor depends on the vector size
and thus we cannot do optimizations on a constant factor as in nextCG.
Thus we just compute the product @factor*seed@ as is
(this is of type @Word32 -> Word32 -> Word64@)
and try to compute @urem@ without using LLVM's @urem@
that calls __umoddi3 on every element.
Instead we optimize on the constant modulus
and utilize that it is slightly smaller than 2^31.

We split the product:
  factor*seed = high0*2^31 + low0

Now it is
mod (factor*seed) modulus
  = mod (high0*2^31 + low0) modulus
  = mod (high0 * mod (2^31) modulus + low0) modulus
  = mod (high0 * 249 + low0) modulus

However, high0 * 249 + low0 is still too big,
it can be up to (excluding) 2^31 * 250.
Thus we repeat the split
high0 * 249 + low0 = high1 * 2^31 + low1

It is high1 < 250, and thus high1*249 < 62500,
high1 * 249 + low1 < 2*modulus.
With x = high1 * 249 + low1
we have
mod (factor*seed) modulus
  = if x<modulus
      then x
      else x-modulus


An alternative approach would be to still multiply @let p = factor*seed@ exactly,
then do an approximate division @let q = approxdiv p modulus@,
then compute @p - q*modulus@ and
do a final adjustment in order to fix rounding errors.
The approximate division could be done by a floating point multiplication
or an integer multiplication with some shifting.
But in the end we will need at least the same number of multiplications
as in the approach that is implemented here.
-}
nextVectorGeneric ::
   (TypeNum.Positive n) =>
   Value (Vector n Word32) ->
   CodeGenFunction r (Value (Vector n Word32))
nextVectorGeneric s = do
   {-
   It seems that LLVM-2.6 on x86 does not make use of the fact,
   that the upper doublewords are zero.
   It seems to implement a full 64x64 multiplication in terms of pmuludq.
   -}
   prod0 <-
      Vector.umul32to64 (SoV.replicateOf (vectorParameter (Vector.size s))) s
   (low0, high0) <- splitVector31 prod0
   -- fac = mod (2^31) modulus
   let fac :: Integral a => a
       fac = 2^(31::Int) - modulus
   -- first folding in 64 bit: fac*high0 + low0
   folded0 <- Vector.umul32to64 (SoV.replicateOf fac) high0
   low0wide <- Vector.map zext low0
   sum0 <- A.add folded0 low0wide
   (low1, high1) <- splitVector31 sum0

   -- second folding fits into 32 bit and needs at most one subtraction
   folded1 <- Vector.mul (SoV.replicateOf fac) high1
   candidate <- A.add low1 folded1
   subtractIfPossible (SoV.replicateOf modulus) candidate

{- |
@subtractIfPossible d x@ yields @A.sub x d@
unless that subtraction would underflow.
In the latter case the result is @x@ itself.

Only works for unsigned types.
-}
subtractIfPossible ::
   (SoV.Real a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
subtractIfPossible d x = do
   difference <- A.sub x d
   {-
   Subtraction should make an element smaller.
   If the element becomes greater instead, then the subtraction wrapped around
   and 'min' selects the value before subtraction.
   -}
   SoV.min x difference
   -- alternatively (slower):
   --   flip selectNonNegativeGeneric x =<< A.sub x d

{- |
Element-wise selection:
Where an element of the first vector is greater than or equal to zero,
that element is chosen,
otherwise the corresponding element of the second vector is chosen.
-}
selectNonNegativeGeneric ::
   (TypeNum.Positive n) =>
   Value (Vector n Int32) ->
   Value (Vector n Int32) ->
   CodeGenFunction r (Value (Vector n Int32))
selectNonNegativeGeneric x y =
   A.cmp LLVM.CmpGE x A.zero >>= \nonNegative ->
   Vector.select nonNegative x y


{- |
Split every 64 bit element at bit 31 and truncate both parts to 32 bit:
The first result contains the lowest 31 bits,
the second result contains the bits from bit 31 on, shifted down by 31.
-}
splitVector31 ::
   (TypeNum.Positive n) =>
   Value (Vector n Word64) ->
   CodeGenFunction r (Value (Vector n Word32), Value (Vector n Word32))
splitVector31 x = do
   truncated <- Vector.map trunc x
   low  <- A.and (SoV.replicateOf (2^(31::Int)-1)) truncated
   shifted <- lshr x (SoV.replicateOf (31 :: Word64) `asTypeOf` x)
   high <- Vector.map trunc shifted
   return (low, high)

{- |
This is the most obvious implementation:
exact 64 bit product, remainder with respect to 'modulus',
truncation to 32 bit.
Unfortunately it calls the expensive __umoddi3.
-}
nextVector64 ::
   (TypeNum.Positive n) =>
   Value (Vector n Word32) ->
   CodeGenFunction r (Value (Vector n Word32))
nextVector64 s = do
   wide <-
      Vector.umul32to64 (SoV.replicateOf (vectorParameter (Vector.size s))) s
   reduced <- A.irem wide (SoV.replicateOf modulus)
   Vector.map trunc reduced