{-# language OverloadedStrings #-}
{-# language DeriveFunctor, GeneralizedNewtypeDeriving #-}
{-# language FlexibleContexts #-}
module System.Random.MWC.Probability.Transition (
Transition
, mkTransition
, runTransition
, evalTransition
, execTransition
, stepConditional
, withSeverity
, L.WithSeverity(..), L.Severity(..)
, L.Handler, L.withFDHandler
, L.BatchingOptions(..), L.defaultBatchingOptions, L.withBatchedHandler
, stdout, stderr
) where
import Control.Monad
import Control.Monad.Primitive
import GHC.IO.Handle (Handle(..))
import GHC.IO.Handle.FD (stdout, stderr)
import qualified Control.Monad.State as S
import Control.Monad.Trans.Class (MonadTrans(..), lift)
import Control.Monad.Trans.State.Strict (StateT(..), runStateT)
import qualified Control.Monad.Log as L
import Data.Char
import System.Random.MWC.Probability
-- | One step of a stochastic state machine.
--
-- Given a random generator, a 'Transition' yields a strict 'StateT'
-- computation over the state @st@. That computation runs on top of
-- 'L.LoggingT', which carries log messages of type @msg@, and it produces a
-- result of type @out@. Only 'Functor' is derived.
newtype Transition msg st m out =
  Transition (Gen (PrimState m) -> StateT st (L.LoggingT msg m) out)
  deriving (Functor)
-- | A 'Transition' wraps a function, so it has no meaningful textual form.
-- Every value renders as the fixed placeholder @<Transition>@.
instance Show (Transition msg s m a) where
  showsPrec _ _ = showString "<Transition>"
-- | Assemble a 'Transition' from a sampler, a pure update and a log formatter.
--
-- Each execution of the resulting transition:
--
--   1. reads the current state,
--   2. draws one sample from the sampler applied to that state, using the
--      supplied generator,
--   3. computes @(output, next state)@ with the update function,
--   4. emits one log message built from the /next/ state, the sample and
--      the output,
--   5. stores the next state and returns the output.
mkTransition :: Monad m
  => (s -> Prob m t)           -- ^ State-dependent sampler
  -> (s -> t -> (a, s))        -- ^ Pure update: current state and sample to (output, next state)
  -> (s -> t -> a -> message)  -- ^ Log message from (next state, sample, output)
  -> Transition message s m a
mkTransition draw update render = Transition $ \gen -> do
  current <- S.get
  -- sampling happens in the base monad, two layers below StateT
  sampled <- lift (lift (sample (draw current) gen))
  let (out, next) = update current sampled
  -- the logger lives one layer below StateT, hence a single lift
  lift (L.logMessage (render next sampled out))
  out <$ S.put next
-- | Run a 'Transition' a given number of times, threading the state through.
--
-- Every step reuses the same generator. Each log message is passed to the
-- handler. The result holds the outputs of all steps, in execution order,
-- together with the final state. When the step count is zero or negative,
-- no step runs: the output list is empty and the initial state is returned.
runTransition :: Monad m
  => L.Handler m message       -- ^ Consumer for emitted log messages
  -> Transition message s m a  -- ^ Transition to iterate
  -> Int                       -- ^ Number of steps
  -> s                         -- ^ Initial state
  -> Gen (PrimState m)         -- ^ Random generator
  -> m ([a], s)
runTransition handler (Transition step) steps initial gen =
  L.runLoggingT fromInitial handler
  where
    fromInitial = runStateT (replicateM steps (step gen)) initial
-- | Like 'runTransition', but returns only the list of per-step outputs and
-- discards the final state.
evalTransition :: Monad m
  => L.Handler m message       -- ^ Consumer for emitted log messages
  -> Transition message s m a  -- ^ Transition to iterate
  -> Int                       -- ^ Number of steps
  -> s                         -- ^ Initial state
  -> Gen (PrimState m)         -- ^ Random generator
  -> m [a]
evalTransition handler transition steps initial =
  fmap fst . runTransition handler transition steps initial
-- | Like 'runTransition', but returns only the final state and discards the
-- per-step outputs.
execTransition :: Monad m
  => L.Handler m message       -- ^ Consumer for emitted log messages
  -> Transition message s m a  -- ^ Transition to iterate
  -> Int                       -- ^ Number of steps
  -> s                         -- ^ Initial state
  -> Gen (PrimState m)         -- ^ Random generator
  -> m s
execTransition handler transition steps initial =
  fmap snd . runTransition handler transition steps initial
-- | Run a 'Transition' exactly once and classify the result.
--
-- The step is executed from the given state, and its log messages go to the
-- handler. The predicate is then applied to the output, the state before the
-- step and the state after it. If the predicate holds, the left builder is
-- applied and the result is wrapped in 'Left'; otherwise the right builder is
-- applied and the result is wrapped in 'Right'.
stepConditional :: Monad m
  => (a -> s -> s -> Bool)     -- ^ Predicate on (output, old state, new state)
  -> (a -> s -> s -> l)        -- ^ Result builder used when the predicate holds
  -> (a -> s -> s -> r)        -- ^ Result builder used otherwise
  -> L.Handler m message       -- ^ Consumer for emitted log messages
  -> Transition message s m a  -- ^ Transition to run once
  -> s                         -- ^ State before the step
  -> Gen (PrimState m)         -- ^ Random generator
  -> m (Either l r)
stepConditional accept onLeft onRight handler (Transition step) start gen =
  L.runLoggingT (runStateT (step gen) start) handler >>= decide
  where
    decide (out, next)
      | accept out start next = pure (Left (onLeft out start next))
      | otherwise             = pure (Right (onRight out start next))
-- | Upper-case a value's 'show' output and wrap it in brackets, with one
-- space of padding on each side: @[ SHOWN ]@.
bracketsUpp :: Show a => a -> String
bracketsUpp p = "[ " ++ shouted ++ " ]"
  where
    shouted = map toUpper (show p)
-- | Render a severity-tagged message as @[ SEVERITY ] payload@.
--
-- The severity tag is the upper-cased 'show' of the severity wrapped in
-- padded brackets. The payload is rendered with the supplied function, and a
-- single space separates the two.
withSeverity :: (t -> String) -> L.WithSeverity t -> String
withSeverity render (L.WithSeverity sev payload) =
  bracketsUpp sev ++ ' ' : render payload