{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- |
-- Module: System.Random.Internal
-- Copyright: Copyright © 2019 Lars Kuhtz <lakuhtz@gmail.com>
-- License: BSD3
-- Maintainer: Lars Kuhtz <lakuhtz@gmail.com>
-- Stability: experimental
--
-- Dispatch for different PRNG implementations
--
module System.Random.Internal
( Gen
, Variate
, initialize
, uniform
, uniformR
) where

-- -------------------------------------------------------------------------- --
-- PCG
#ifdef RANDOM_PCG

import Control.Monad.Primitive
import System.Random.PCG hiding (initialize)
import qualified System.Random.PCG as PCG

-- | Create a PCG generator from an 'Int' salt.
--
-- The salt is converted with 'fromIntegral' and supplied as the second
-- argument of @PCG.initialize@; the first argument is fixed to @0@.
initialize
    :: PrimMonad m
    => Int
    -> m (Gen (PrimState m))
initialize = PCG.initialize 0 . fromIntegral

-- -------------------------------------------------------------------------- --
-- MWC
#elif defined RANDOM_MWC

import Control.Monad.Primitive
import Data.Vector
import System.Random.MWC hiding (initialize)
import qualified System.Random.MWC as MWC

-- | Create an MWC generator from an 'Int' salt.
--
-- The salt is converted with 'fromIntegral', wrapped in a one-element vector
-- via 'singleton', and that vector is handed to @MWC.initialize@.
initialize
    :: PrimMonad m
    => Int
    -> m (Gen (PrimState m))
initialize = MWC.initialize . singleton . fromIntegral

-- -------------------------------------------------------------------------- --
-- Random
#else

import Control.Monad.Primitive
import Data.STRef

#if MIN_VERSION_random(1,2,0)
import System.Random hiding (uniform, uniformR)
#else
import System.Random
#endif

-- | Constraint alias for this backend: a type is a 'Variate' exactly when it
-- has a 'Random' instance. 'uniform' and 'uniformR' require this constraint
-- on their result type.
type Variate a = (Random a)

-- | Generator handle for this backend: a mutable 'STRef' in state thread @s@
-- that holds the current 'StdGen'.
type Gen s = STRef s StdGen

-- | Create a generator for the @random@ package backend from an 'Int' salt.
--
-- The salt is turned into a 'StdGen' with 'mkStdGen' and stored in a fresh
-- 'STRef'. The 'StdGen' is forced with '$!' before it is stored, so the
-- reference never starts out holding an unevaluated thunk.
initialize
    :: PrimMonad m
    => Int
    -> m (Gen (PrimState m))
initialize salt = stToPrim $ newSTRef $! mkStdGen salt

-- | Draw a value from the given range and advance the generator.
--
-- The current 'StdGen' is read from the reference, 'randomR' produces the
-- value together with the successor generator, and the successor is written
-- back to the reference. The meaning of the range is whatever 'randomR'
-- defines for the 'Random' instance of @b@.
uniformR
    :: Variate b
    => PrimMonad m
    => (b, b)
    -> Gen (PrimState m)
    -> m b
uniformR range gen = stToPrim $ do
    -- The bang patterns force both components to WHNF before the write, so
    -- neither the result nor the stored generator is left as a thunk.
    (!r, !g) <- randomR range <$> readSTRef gen
    writeSTRef gen g
    return r

-- | Draw a value using 'random' and advance the generator.
--
-- The current 'StdGen' is read from the reference, 'random' produces the
-- value together with the successor generator, and the successor is written
-- back to the reference. The distribution of the result is whatever 'random'
-- defines for the 'Random' instance of @b@.
uniform
    :: Variate b
    => PrimMonad m
    => Gen (PrimState m)
    -> m b
uniform gen = stToPrim $ do
    -- The bang patterns force both components to WHNF before the write, so
    -- neither the result nor the stored generator is left as a thunk.
    (!r, !g) <- random <$> readSTRef gen
    writeSTRef gen g
    return r

#endif