{-# LANGUAGE BangPatterns               #-}
{-# LANGUAGE CPP                        #-}
{-# LANGUAGE DeriveDataTypeable         #-}
{-# LANGUAGE DeriveGeneric              #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE ForeignFunctionInterface   #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash                  #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE UnboxedTuples              #-}
#if __GLASGOW_HASKELL__ >= 707
{-# LANGUAGE RoleAnnotations            #-}
#endif
-- | Pure Haskell implementation of the PCG32 generator: a 64-bit linear
-- congruential state combined with the 32-bit XSH RR output permutation.
--
-- A mutable generator ('Gen') keeps its state and increment in a
-- 'MutableByteArray'. The immutable 'SetSeq' (also exported as
-- 'FrozenGen') holds the same two words, can be converted to and from a
-- 'Gen' with 'restore' and 'save', and is an instance of 'RandomGen'.
module System.Random.PCG.Pure
  ( -- * Generator
    Gen, GenIO, GenST
  , create, createSystemRandom, initialize, withSystemRandom

    -- * Generating values
  , Variate (..)
  , advance, retract

    -- * Seeds
  , FrozenGen, save, restore, seed, initFrozen

    -- * Type-restricted versions
    -- ** uniform
  , uniformW8, uniformW16, uniformW32, uniformW64
  , uniformI8, uniformI16, uniformI32, uniformI64
  , uniformF, uniformD, uniformBool

    -- ** uniformR
  , uniformRW8, uniformRW16, uniformRW32, uniformRW64
  , uniformRI8, uniformRI16, uniformRI32, uniformRI64
  , uniformRF, uniformRD, uniformRBool

    -- ** uniformB
  , uniformBW8, uniformBW16, uniformBW32, uniformBW64
  , uniformBI8, uniformBI16, uniformBI32, uniformBI64
  , uniformBF, uniformBD, uniformBBool

    -- * Pure interface
  , SetSeq
  , next'
  , advanceSetSeq
  ) where
import Control.Monad.Primitive
import Data.Bits
import Data.Data
import Data.Primitive.ByteArray
import Foreign
import GHC.Generics
import System.Random.PCG.Class
import System.Random
-- | A mutable generator usable in 'IO'.
type GenIO = Gen RealWorld

-- | A mutable generator usable in @ST s@ (as @GenST s@).
type GenST = Gen

-- | A mutable PCG32 generator. The byte array holds two 'Word64's: the
--   current LCG state at index 0 and the increment at index 1 (see 'save'
--   and 'restore').
newtype Gen s = G (MutableByteArray s)

-- | Immutable snapshot of a 'Gen', as produced by 'save'.
type FrozenGen = SetSeq
-- | Immutable PCG32 generator: the 64-bit LCG state followed by the
--   increment added on every step. Generators built by 'start' always
--   have an odd increment.
data SetSeq = SetSeq
  {-# UNPACK #-} !Word64 -- current LCG state
  {-# UNPACK #-} !Word64 -- increment
  deriving (Show, Ord, Eq, Data, Typeable, Generic)
-- | Stored as two consecutive 'Word64's (16 bytes, 8-byte aligned):
--   the state at element 0 and the increment at element 1.
instance Storable SetSeq where
  sizeOf _ = 16
  {-# INLINE sizeOf #-}
  alignment _ = 8
  {-# INLINE alignment #-}
  poke p (SetSeq st incr) = do
    let wp = castPtr p
    pokeElemOff wp 0 st
    pokeElemOff wp 1 incr
  {-# INLINE poke #-}
  peek p =
    let wp = castPtr p
    in  peekElemOff wp 0 >>= \st ->
        peekElemOff wp 1 >>= \incr ->
        return (SetSeq st incr)
  {-# INLINE peek #-}
-- | The fixed initial generator used by 'create'. These are the same two
--   constants as @PCG32_INITIALIZER@ in the reference C implementation
--   (state 9600629759793949339, increment 15726070495360670683).
seed :: SetSeq
seed = SetSeq 0x853c49e6748fea9b 0xda3e39cb94b95bdb
-- | Strict, unpacked pair of a successor 64-bit state and a 32-bit
--   result, as returned by 'pair' and 'bounded'.
data Pair = Pair
  {-# UNPACK #-} !Word64 -- successor state
  {-# UNPACK #-} !Word32 -- 32-bit result
-- | The 64-bit LCG multiplier (6364136223846793005 in decimal).
multiplier :: Word64
multiplier = 0x5851f42d4c957f2d
-- | One LCG step: the state that follows the current one, i.e.
--   @multiplier * s + inc@ with wrap-around modulo 2^64.
state :: SetSeq -> Word64
state g = case g of
  SetSeq s inc -> inc + multiplier * s
{-# INLINE state #-}
-- | PCG32 output permutation (XSH RR). The state is xorshifted right by
--   18, bits 27..58 are kept as a 32-bit word, and that word is rotated
--   right by the top 5 bits of the state.
output :: Word64 -> Word32
output s = xorshifted `rotateR` amount
  where
    amount     = fromIntegral (s `shiftR` 59) :: Int
    xorshifted = fromIntegral (((s `shiftR` 18) `xor` s) `shiftR` 27) :: Word32
{-# INLINE output #-}
-- | Step a frozen generator once, returning the successor state together
--   with the output computed from the current (pre-step) state.
pair :: SetSeq -> Pair
pair g@(SetSeq cur _) = Pair nxt out
  where
    nxt = state g
    out = output cur
{-# INLINE pair #-}
-- | Draw a value uniformly from @[0, b)@ by rejection sampling. The result
--   is paired with the state after the last draw used.
--
--   Outputs below @negate b `mod` b@ (that is, @(2^32 - b) `mod` b@) are
--   rejected so that the final reduction modulo @b@ is unbiased. When
--   @b == 0@ the threshold computation throws a divide-by-zero exception.
bounded :: Word32 -> SetSeq -> Pair
bounded b (SetSeq s0 inc) = draw s0
  where
    threshold = negate b `mod` b
    draw !cur =
      let Pair nxt r = pair (SetSeq cur inc)
      in  if r < threshold then draw nxt else Pair nxt (r `mod` b)
{-# INLINE bounded #-}
-- | Jump an LCG ahead by @delta@ steps using O(log delta) multiplications.
--   Each step is @x -> mult * x + plus@ with wrap-around modulo 2^64. The
--   square-and-multiply scheme is the one used by the reference PCG code
--   (after Brown, "Random Number Generation with Arbitrary Stride", 1994).
advancing
  :: Word64 -- ^ number of steps to advance
  -> Word64 -- ^ starting state
  -> Word64 -- ^ LCG multiplier
  -> Word64 -- ^ LCG increment
  -> Word64 -- ^ state after advancing
advancing delta0 st mult0 plus0 = loop delta0 mult0 plus0 1 0
  where
    -- (accMult, accPlus) is the affine map for the low bits of delta0
    -- consumed so far; (curMult, curPlus) is the map for 2^k steps, where
    -- k is the number of bits consumed.
    loop delta curMult curPlus accMult accPlus
      | delta == 0 = accMult * st + accPlus
      | otherwise  =
          loop (delta `shiftR` 1)
               (curMult * curMult)
               ((curMult + 1) * curPlus)
               accMult'
               accPlus'
      where
        lowBitSet = testBit delta 0
        accMult'  = if lowBitSet then accMult * curMult else accMult
        accPlus'  = if lowBitSet then accPlus * curMult + curPlus else accPlus
-- | Advance a frozen generator by the given number of steps in
--   logarithmic time, keeping its increment unchanged.
advanceSetSeq :: Word64 -> FrozenGen -> FrozenGen
advanceSetSeq d g@(SetSeq _ inc) = SetSeq (advanceSetSeq' d g) inc

-- | Like 'advanceSetSeq', but returns only the advanced LCG state.
advanceSetSeq' :: Word64 -> FrozenGen -> Word64
advanceSetSeq' d (SetSeq s inc) = advancing d s multiplier inc
-- | Build a frozen generator from an initial state word @a@ and a sequence
--   selector @b@. The increment is @2*b + 1@ (always odd). The state is
--   one LCG step from @a + increment@, which is the result of the reference
--   C seeding routine @pcg32_srandom_r@.
start :: Word64 -> Word64 -> SetSeq
start a b =
  let incr = 1 .|. (b `shiftL` 1)
  in  SetSeq (state (SetSeq (a + incr) incr)) incr
{-# INLINE start #-}
-- | Pure single step: the 32-bit output for the current state, paired
--   with the generator advanced by one step (same increment).
next' :: SetSeq -> (Word32, SetSeq)
next' g = case g of
  SetSeq _ inc ->
    let Pair nxt out = pair g
    in  (out, SetSeq nxt inc)
{-# INLINE next' #-}
-- | Copy the state (index 0) and increment (index 1) of a mutable
--   generator into an immutable 'SetSeq'.
save :: PrimMonad m => Gen (PrimState m) -> m SetSeq
save (G arr) =
  readByteArray arr 0 >>= \st ->
  readByteArray arr 1 >>= \incr ->
  return (SetSeq st incr)
{-# INLINE save #-}
-- | Allocate a fresh 16-byte mutable generator initialised from a frozen
--   one: the state is written at index 0 and the increment at index 1.
restore :: PrimMonad m => FrozenGen -> m (Gen (PrimState m))
restore (SetSeq st incr) =
  newByteArray 16 >>= \arr ->
  writeByteArray arr 0 st >>
  writeByteArray arr 1 incr >>
  (return $! G arr)
{-# INLINE restore #-}
-- | Build a frozen generator from an initial state and a sequence
--   selector (an alias for 'start').
--
-- >>> initFrozen 0 0
-- SetSeq 6364136223846793006 1
initFrozen :: Word64 -> Word64 -> SetSeq
initFrozen st sq = start st sq

-- | Create a mutable generator from the fixed 'seed'.
create :: PrimMonad m => m (Gen (PrimState m))
create = restore seed

-- | Create a mutable generator from an initial state and a sequence
--   selector; equivalent to @restore (initFrozen a b)@.
initialize :: PrimMonad m => Word64 -> Word64 -> m (Gen (PrimState m))
initialize a b = restore $ initFrozen a b
-- | Seed a new generator from two 'sysRandom' words and pass it to the
--   given action. The first word is the initial state and the second is
--   the sequence selector, as in 'initialize'.
withSystemRandom :: (GenIO -> IO a) -> IO a
withSystemRandom f =
  sysRandom >>= \st ->
  sysRandom >>= \sq ->
  initialize st sq >>= f

-- | Create a generator seeded from 'sysRandom'.
createSystemRandom :: IO GenIO
createSystemRandom = withSystemRandom return
-- | Advance a mutable generator by @u@ steps in logarithmic time, leaving
--   its increment unchanged. One step corresponds to one 32-bit output, so
--   the 'uniform2' method of the 'Generator' instance consumes two steps.
advance :: PrimMonad m => Word64 -> Gen (PrimState m) -> m ()
advance u g@(G a) = save g >>= writeByteArray a 0 . advanceSetSeq' u
{-# INLINE advance #-}

-- | Move a mutable generator back by @u@ steps. This is implemented as
--   advancing by @negate u@ steps, which wraps modulo 2^64.
retract :: PrimMonad m => Word64 -> Gen (PrimState m) -> m ()
retract u g = advance (negate u) g
{-# INLINE retract #-}
-- | Mutable generator instance. Every method reads the state (index 0)
--   and increment (index 1) from the byte array and writes back only the
--   state.
instance (PrimMonad m, s ~ PrimState m) => Generator (Gen s) m where
  -- One step: store the successor state and apply f to the output of the
  -- state that was current on entry.
  uniform1 f (G a) = do
    s   <- readByteArray a 0
    inc <- readByteArray a 1
    writeByteArray a 0 $! s * multiplier + inc
    return $! f (output s)
  {-# INLINE uniform1 #-}

  -- Two steps: apply f to the outputs of the current state and of its
  -- successor, then store the state two steps ahead.
  uniform2 f (G a) = do
    s   <- readByteArray a 0
    inc <- readByteArray a 1
    let !s' = s * multiplier + inc
    writeByteArray a 0 $! s' * multiplier + inc
    return $! f (output s) (output s')
  {-# INLINE uniform2 #-}

  -- Bounded draw: 'bounded' may consume several steps (rejection
  -- sampling); the state after the accepted draw is written back.
  uniform1B f b g@(G a) = do
    ss <- save g
    let Pair s' r = bounded b ss
    writeByteArray a 0 s'
    return $! f r
  {-# INLINE uniform1B #-}
-- | Pure 'RandomGen' interface on a frozen generator. 'next' consumes two
--   32-bit outputs and combines them with 'wordsTo64Bit'. 'split' consumes
--   four outputs and returns the generator advanced past them, together
--   with a new generator built by 'start' from those four outputs.
instance RandomGen FrozenGen where
  next g0@SetSeq{} = (wordsTo64Bit w1 w2, g2)
    where
      (w1, g1) = next' g0
      (w2, g2) = next' g1
  {-# INLINE next #-}
  split g0@SetSeq{} = (g4, start (wordsTo64Bit w1 w2) (wordsTo64Bit w3 w4))
    where
      (w1, g1) = next' g0
      (w2, g2) = next' g1
      (w3, g3) = next' g2
      (w4, g4) = next' g3
  {-# INLINE split #-}