{-# LANGUAGE CPP #-}
{-# LANGUAGE TemplateHaskell, GADTs #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Random.Source.PureMT
    ( PureMT, newPureMT, pureMT
    , getRandomPrimFromMTRef
    ) where
import Control.Monad.State
import Control.Monad.RWS
import qualified Control.Monad.State.Strict as S
import qualified Control.Monad.RWS.Strict as S
import Data.Random.Internal.Source
import Data.Random.Source.Internal.TH
import Data.StateRef
import System.Random.Mersenne.Pure64
-- Run one pure generator step against the PureMT held in a mutable
-- reference and return the value that step produced.
--
-- The step is applied through a single atomicModifyReference call.  Inside
-- the update function both the advanced generator and the produced value
-- are forced; the value is forced once more after the call returns.  That
-- second forcing guards against references whose atomic modify is lazy
-- (as base atomicModifyIORef is): without it, a caller that discards the
-- results would leave a growing chain of unevaluated updates inside the
-- reference.
{-# INLINE withMTRef #-}
withMTRef :: (Monad m, ModifyRef sr m PureMT) => (PureMT -> (t, PureMT)) -> sr -> m t
withMTRef thing ref = do
    !result <- atomicModifyReference ref step
    return result
  where
    -- atomicModifyReference wants (new state, result) while the generator
    -- step yields (result, new state), so swap the pair, forcing both.
    step !oldMT = case thing oldMT of (!w, !newMT) -> (newMT, w)
-- Run one pure generator step against the PureMT that is the state of a
-- MonadState monad: read the current generator (forcing it), store the
-- advanced generator, and return the produced value.
--
-- The let-bound tuple pattern is lazy as a whole; its inner bangs only
-- make both components get forced together once either one is demanded.
{-# INLINE withMTState #-}
withMTState :: MonadState PureMT m => (PureMT -> (t, PureMT)) -> m t
withMTState thing =
    get >>= \(!current) ->
        let (!value, !next) = thing current
        in put next >> return value
#ifndef MTL2
-- MonadRandom instance for the lazy State monad whose state is a PureMT.
-- Only compiled when MTL2 is not defined: under mtl 2, State s is a type
-- synonym for StateT s Identity, so this instance would presumably overlap
-- the StateT instance defined further down.  Word64 and Double draws
-- advance the state through withMTState; the monadRandom splice is assumed
-- to supply the remaining class methods (TODO confirm against
-- Data.Random.Source.Internal.TH).
$(monadRandom
    [d| instance MonadRandom (State PureMT) where
            getRandomWord64 = withMTState randomWord64
            getRandomDouble = withMTState randomDouble
     |])
-- MonadRandom instance for the strict State monad (Control.Monad.State.Strict)
-- whose state is a PureMT.  Only compiled when MTL2 is not defined,
-- presumably because under mtl 2 this type is covered by the strict StateT
-- instance further down.  Draws advance the state through withMTState.
$(monadRandom
    [d| instance MonadRandom (S.State PureMT) where
            getRandomWord64 = withMTState randomWord64
            getRandomDouble = withMTState randomDouble
     |])
-- MonadRandom instance for the lazy RWS monad with a PureMT state, for any
-- Monoid writer output w.  Only compiled when MTL2 is not defined,
-- presumably because under mtl 2 this type is covered by the RWST instance
-- further down.  Draws advance the state through withMTState.
$(monadRandom
    [d| instance Monoid w => MonadRandom (RWS r w PureMT) where
            getRandomWord64 = withMTState randomWord64
            getRandomDouble = withMTState randomDouble
     |])
-- MonadRandom instance for the strict RWS monad (Control.Monad.RWS.Strict)
-- with a PureMT state, for any Monoid writer output w.  Only compiled when
-- MTL2 is not defined, presumably because under mtl 2 this type is covered
-- by the strict RWST instance further down.  Draws go through withMTState.
$(monadRandom
    [d| instance Monoid w => MonadRandom (S.RWS r w PureMT) where
            getRandomWord64 = withMTState randomWord64
            getRandomDouble = withMTState randomDouble
     |])
#endif
-- RandomSource instance letting a Ref m2 PureMT act as an entropy source in
-- any monad m1 that has a ModifyRef instance for that reference.  Each
-- Word64 or Double draw is one withMTRef step, i.e. one
-- atomicModifyReference on the stored generator.  The randomSource splice
-- is assumed to supply the remaining class methods (TODO confirm against
-- Data.Random.Source.Internal.TH).
$(randomSource
    [d| instance (Monad m1, ModifyRef (Ref m2 PureMT) m1 PureMT) => RandomSource m1 (Ref m2 PureMT) where
            getRandomWord64From = withMTRef randomWord64
            getRandomDoubleFrom = withMTRef randomDouble
    |])
-- MonadRandom instance for the lazy StateT transformer with a PureMT state
-- over any base monad m.  Draws read and replace the generator through
-- withMTState.  Under mtl 2 this presumably also covers State PureMT, since
-- State is then a synonym for StateT over Identity.
$(monadRandom
    [d| instance Monad m => MonadRandom (StateT PureMT m) where
            getRandomWord64 = withMTState randomWord64
            getRandomDouble = withMTState randomDouble
     |])
-- MonadRandom instance for the strict StateT transformer
-- (Control.Monad.State.Strict) with a PureMT state over any base monad m.
-- Draws read and replace the generator through withMTState.
$(monadRandom
    [d| instance Monad m => MonadRandom (S.StateT PureMT m) where
            getRandomWord64 = withMTState randomWord64
            getRandomDouble = withMTState randomDouble
     |])
-- MonadRandom instance for the lazy RWST transformer with a PureMT state,
-- any Monoid writer output w, and any base monad m.  Draws read and replace
-- the generator through withMTState.
$(monadRandom
    [d| instance (Monad m, Monoid w) => MonadRandom (RWST r w PureMT m) where
            getRandomWord64 = withMTState randomWord64
            getRandomDouble = withMTState randomDouble
     |])
-- MonadRandom instance for the strict RWST transformer
-- (Control.Monad.RWS.Strict) with a PureMT state, any Monoid writer output
-- w, and any base monad m.  Draws read and replace the generator through
-- withMTState.
$(monadRandom
    [d| instance (Monad m, Monoid w) => MonadRandom (S.RWST r w PureMT m) where
            getRandomWord64 = withMTState randomWord64
            getRandomDouble = withMTState randomDouble
     |])
-- RandomSource instance letting an IORef PureMT act as an entropy source in
-- any MonadIO monad.  Each draw is one withMTRef step on the stored
-- generator.  withMTRef needs ModifyRef (IORef PureMT) m PureMT, which is
-- presumably supplied by Data.StateRef for MonadIO monads (TODO confirm).
$(randomSource
    [d| instance (MonadIO m) => RandomSource m (IORef PureMT) where
            getRandomWord64From = withMTRef randomWord64
            getRandomDoubleFrom = withMTRef randomDouble
     |])
-- RandomSource instance letting an STRef s PureMT act as an entropy source
-- in any monad m that has a ModifyRef instance for that STRef.  Each draw is
-- one withMTRef step on the stored generator.
$(randomSource
    [d| instance (Monad m, ModifyRef (STRef s PureMT) m PureMT) => RandomSource m (STRef s PureMT) where
            getRandomWord64From = withMTRef randomWord64
            getRandomDoubleFrom = withMTRef randomDouble
     |])
-- | Draw one primitive value from a PureMT stored in a mutable reference.
--
-- The @Prim@ request is interpreted as a @State PureMT@ computation via
-- @getRandomPrim@.  The resulting pure step is then applied to the
-- referenced generator through a single atomicModifyReference call.
--
-- NOTE(review): the sampled value is returned without being forced.  If the
-- reference uses a lazy atomic modify (as base atomicModifyIORef is),
-- callers that discard results may let unevaluated updates pile up in the
-- reference.  Forcing here would need a Monad m constraint; confirm whether
-- ModifyRef already implies one.
getRandomPrimFromMTRef :: ModifyRef sr m PureMT => sr -> Prim a -> m a
getRandomPrimFromMTRef ref prim =
    atomicModifyReference' ref (runState (getRandomPrim prim))
-- Atomically apply a pure step to the value held in a reference.  The step
-- returns (result, new value), the reverse of the (new value, result) order
-- atomicModifyReference works with.  The pair is therefore swapped, forcing
-- both components, before it is handed over.
atomicModifyReference' :: ModifyRef sr m a => sr -> (a -> (b, a)) -> m b
atomicModifyReference' ref getR = atomicModifyReference ref reorder
    where
        reorder current = case getR current of
            (!out, !next) -> (next, out)