{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
module Control.Monad.Bayes.Inference.SMC
  ( smc,
    smcPush,
    SMCConfig (..),
  )
where
import Control.Monad.Bayes.Class (MonadDistribution, MonadMeasure)
import Control.Monad.Bayes.Population
  ( Population,
    pushEvidence,
    withParticles,
  )
import Control.Monad.Bayes.Sequential.Coroutine as Coroutine
-- | Configuration for Sequential Monte Carlo ('smc' / 'smcPush').
data SMCConfig m = SMCConfig
  { -- | Transformation applied to the particle population between steps.
    -- It must be polymorphic in the result type, hence the rank-2 field
    -- (requires @RankNTypes@).
    resampler :: forall x. Population m x -> Population m x,
    -- | Step count passed to 'Coroutine.sequentially'.
    numSteps :: Int,
    -- | Number of particles the initial population is expanded to
    -- (passed to 'withParticles').
    numParticles :: Int
  }
-- | Sequential Monte Carlo.
--
-- The initial segment of the program (before its first suspension) is run
-- under 'withParticles' 'numParticles', via 'Coroutine.hoistFirst'. The
-- resulting program is then driven by 'Coroutine.sequentially', using the
-- configured 'resampler' and 'numSteps'. See "Control.Monad.Bayes.Sequential.Coroutine"
-- for exactly where the resampler is applied.
smc ::
  MonadDistribution m =>
  SMCConfig m ->
  Coroutine.Sequential (Population m) a ->
  Population m a
smc SMCConfig {..} =
  Coroutine.sequentially resampler numSteps
    . Coroutine.hoistFirst (withParticles numParticles)
-- | Variant of 'smc' that runs 'pushEvidence' after the configured
-- 'resampler' at every point where 'smc' would apply the resampler.
-- All other configuration is passed through unchanged.
--
-- The stronger 'MonadMeasure' constraint is what makes 'pushEvidence'
-- available in addition to the 'MonadDistribution' that 'smc' needs.
smcPush ::
  MonadMeasure m => SMCConfig m -> Coroutine.Sequential (Population m) a -> Population m a
smcPush config = smc config {resampler = pushEvidence . resampler config}