{-# language OverloadedStrings #-} {-# language DeriveFunctor, GeneralizedNewtypeDeriving #-} module System.Random.MWC.Probability.Transition ( -- * Transition Transition , mkTransition , runTransition -- ** Helper functions , withSeverity -- * Re-exported from `logging-effect` , Handler , WithSeverity(..), Severity(..) -- , withFDHandler, defaultBatchingOptions ) where import Control.Monad import Control.Monad.Primitive import qualified Control.Monad.State as S import Control.Monad.Trans.Class (MonadTrans(..), lift) import Control.Monad.Trans.State.Strict (StateT(..), evalStateT, execStateT, runStateT) import Control.Monad.Log (MonadLog(..), Handler, WithSeverity(..), Severity(..), LoggingT(..), runLoggingT, withFDHandler, defaultBatchingOptions, logMessage) import Data.Char import System.Random.MWC.Probability -- | A Markov transition kernel. newtype Transition message s m a = Transition ( Gen (PrimState m) -> StateT s (LoggingT message m) a ) deriving (Functor) -- | Construct a 'Transition' from sampling, state transformation and logging functions. -- -- NB: The three function arguments are used in the order in which they appear here: -- -- 1. a random sample @w :: t@ is produced, using the current state @x :: s@ as input -- -- 2. output @z :: a@ and next state @x' :: s@ are computed using @w@ and @x@ -- -- 3. a logging message is constructed, using @z@ and @x'@ as arguments. mkTransition :: Monad m => (s -> Prob m t) -- ^ Random generation -> (s -> t -> (a, s)) -- ^ (Output, Next state) -> (a -> s -> message) -- ^ Log message generation -> Transition message s m a mkTransition fm fs flog = Transition $ \gen -> do s <- S.get w <- lift . lift $ sample (fm s) gen let (a, s') = fs s w lift $ logMessage $ flog a s' S.put s' return a -- | Run a 'Transition' for a number of steps, while logging each iteration. runTransition :: Monad m => Handler m message -- ^ Logging handler -> Transition message s m a -> Int -- ^ Number of iterations -> s -- ^ Initial state -> Gen (PrimState m) -- ^ PRNG -> m [(a, s)] runTransition logf (Transition fm) n s0 g = runLoggingT (replicateM n (runStateT (fm g) s0)) logf bracketsUpp :: Show a => a -> String bracketsUpp p = unwords ["[", map toUpper (show p), "]"] -- | Render a logging message along with an annotation of its severity. withSeverity :: (t -> String) -> WithSeverity t -> String withSeverity k (WithSeverity u a ) = unwords [bracketsUpp u, k a]