{-# language OverloadedStrings #-}
{-# language DeriveFunctor, GeneralizedNewtypeDeriving #-}
{-# language FlexibleContexts #-}
module System.Random.MWC.Probability.Transition (
Transition
, mkTransition
, runTransition
, evalTransition
, execTransition
, stepConditional
, withSeverity
) where
import Control.Monad
import Control.Monad.Primitive
import qualified Control.Monad.State as S
import Control.Monad.Trans.Class (MonadTrans(..), lift)
import Control.Monad.Trans.State.Strict (StateT(..), evalStateT, execStateT, runStateT)
import qualified Control.Monad.Log as L
import Data.Char
import System.Random.MWC.Probability
-- | A computation that, once handed a random generator, runs in a state
-- layer over @s@ stacked on a logging layer carrying entries of type
-- @message@ in the base monad @m@, and produces a value of type @a@.
newtype Transition message s m a
  = Transition (Gen (PrimState m) -> StateT s (L.LoggingT message m) a)
  deriving (Functor)
-- | A 'Transition' wraps a function, so every value is rendered as the same
-- fixed placeholder regardless of its contents.
instance Show (Transition msg s m a) where
  show = const "<Transition>"
-- | Assemble a 'Transition' from a sampling rule, a state-update rule and a
-- log-message rule.
--
-- Each step samples from the distribution selected by the current state,
-- computes the output and the next state from that sample, logs a message
-- built from the output and the /next/ state, and finally stores the next
-- state and returns the output.
mkTransition :: Monad m =>
     (s -> Prob m t)      -- ^ Distribution to sample, selected by the current state
  -> (s -> t -> (a, s))   -- ^ Current state and sample to (output, next state)
  -> (a -> s -> message)  -- ^ Log entry built from the output and the next state
  -> Transition message s m a
mkTransition fm fs flog = Transition $ \gen -> StateT $ \s0 -> do
  -- This block runs in the logging layer; the state is threaded by hand.
  w <- lift (sample (fm s0) gen)
  let (a, s1) = fs s0 w
  L.logMessage (flog a s1)
  pure (a, s1)
-- | Run a 'Transition' @n@ times in sequence, threading the state from one
-- step to the next, and return both the collected outputs and the final
-- state. Log messages are passed to the supplied handler.
runTransition :: Monad m =>
     L.Handler m message        -- ^ Handler for emitted log messages
  -> Transition message s m a   -- ^ Transition to iterate
  -> Int                        -- ^ Number of steps
  -> s                          -- ^ Initial state
  -> Gen (PrimState m)          -- ^ Random generator
  -> m ([a], s)                 -- ^ Outputs of every step and the final state
runTransition logf (Transition fm) n s0 g = L.runLoggingT logged logf
  where
    steps = replicateM n (fm g)
    logged = runStateT steps s0
-- | Run a 'Transition' @n@ times in sequence and return only the collected
-- outputs, discarding the final state. Log messages are passed to the
-- supplied handler.
evalTransition :: Monad m =>
     L.Handler m message        -- ^ Handler for emitted log messages
  -> Transition message s m a   -- ^ Transition to iterate
  -> Int                        -- ^ Number of steps
  -> s                          -- ^ Initial state
  -> Gen (PrimState m)          -- ^ Random generator
  -> m [a]                      -- ^ Outputs of every step
evalTransition logf (Transition fm) n s0 g =
  flip L.runLoggingT logf $
    flip evalStateT s0 $
      replicateM n (fm g)
-- | Run a 'Transition' @n@ times in sequence and return only the final
-- state, discarding the outputs. Log messages are passed to the supplied
-- handler.
execTransition :: Monad m =>
     L.Handler m message        -- ^ Handler for emitted log messages
  -> Transition message s m a   -- ^ Transition to iterate
  -> Int                        -- ^ Number of steps
  -> s                          -- ^ Initial state
  -> Gen (PrimState m)          -- ^ Random generator
  -> m s                        -- ^ State after the last step
execTransition logf (Transition fm) n s0 g =
  let iterated = replicateM n (fm g)
  in execStateT iterated s0 `L.runLoggingT` logf
-- | Perform a single step of a 'Transition' and classify the outcome.
--
-- The predicate receives the step's output, the state before the step and
-- the state after it. When it holds the result is a 'Left' built by the
-- first projection, otherwise a 'Right' built by the second one; both
-- projections receive the same three arguments as the predicate.
stepConditional :: Monad m =>
     (a -> s -> s -> Bool)      -- ^ Predicate on output, old state, new state
  -> (a -> s -> s -> l)         -- ^ Result when the predicate holds
  -> (a -> s -> s -> r)         -- ^ Result when the predicate fails
  -> L.Handler m message        -- ^ Handler for emitted log messages
  -> Transition message s m a   -- ^ Transition to step once
  -> s                          -- ^ State before the step
  -> Gen (PrimState m)          -- ^ Random generator
  -> m (Either l r)
stepConditional q fleft fright logf (Transition fm) s g =
  L.runLoggingT (runStateT (fm g) s) logf >>= decide
  where
    -- The guard is evaluated when the action is sequenced, not deferred.
    decide (a, s')
      | q a s s'  = pure (Left (fleft a s s'))
      | otherwise = pure (Right (fright a s s'))
-- | Render a value with 'show', upper-case every character, and surround the
-- result with square brackets separated from it by single spaces.
bracketsUpp :: Show a => a -> String
bracketsUpp p = "[ " ++ map toUpper (show p) ++ " ]"
-- | Format a severity-tagged message as a single line: the severity,
-- upper-cased and bracketed, followed by a space and the payload rendered
-- with the supplied function.
withSeverity :: (t -> String) -> L.WithSeverity t -> String
withSeverity k (L.WithSeverity severity payload) =
  bracketsUpp severity ++ " " ++ k payload