{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}

-- |
-- Module      : Control.Monad.Bayes.Inference.SMC
-- Description : Sequential Monte Carlo (SMC)
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- Sequential Monte Carlo (SMC) sampling.
--
-- Arnaud Doucet and Adam M. Johansen. 2011. A tutorial on particle filtering and smoothing: fifteen years later. In /The Oxford Handbook of Nonlinear Filtering/, Dan Crisan and Boris Rozovskii (Eds.). Oxford University Press, Chapter 8.
module Control.Monad.Bayes.Inference.SMC
  ( smc,
    smcPush,
    SMCConfig (..),
  )
where

import Control.Monad.Bayes.Class (MonadDistribution, MonadMeasure)
import Control.Monad.Bayes.Population
  ( PopulationT,
    pushEvidence,
    withParticles,
  )
import Control.Monad.Bayes.Sequential.Coroutine as Coroutine

-- | Configuration shared by the SMC drivers in this module.
data SMCConfig m = SMCConfig
  { -- | Transformation applied to the population by
    -- 'Coroutine.sequentially' at each step (e.g. a resampling scheme).
    -- It must be polymorphic in the particle type because the type of
    -- the intermediate results changes from step to step.
    resampler :: forall x. PopulationT m x -> PopulationT m x,
    -- | Number of steps passed to 'Coroutine.sequentially'.
    -- NOTE(review): presumably the number of suspension points of the
    -- model to step through; confirm against the Coroutine module.
    numSteps :: Int,
    -- | Population size, passed to 'withParticles' for the initial step.
    numParticles :: Int
  }

-- | Sequential importance resampling: a generic SMC template whose
-- resampler, step count and population size come from the 'SMCConfig'.
--
-- The program is first transformed with
-- @'Coroutine.hoistFirst' ('withParticles' 'numParticles')@, which creates
-- the population of 'numParticles' particles. Judging by its name,
-- 'Coroutine.hoistFirst' affects only the computation up to the first
-- suspension (NOTE(review): confirm in the Coroutine module).
-- 'Coroutine.sequentially' then runs the result for 'numSteps' steps,
-- using 'resampler' as the per-step population transformation.
smc ::
  (MonadDistribution m) =>
  SMCConfig m ->
  Coroutine.SequentialT (PopulationT m) a ->
  PopulationT m a
smc SMCConfig {..} =
  Coroutine.sequentially resampler numSteps
    . Coroutine.hoistFirst (withParticles numParticles)

-- | Sequential Monte Carlo that, at each step, applies the configured
-- 'resampler' and then 'pushEvidence'.
--
-- This is 'smc' with the resampler replaced by
-- @'pushEvidence' . 'resampler' config@. The step count and population size
-- are taken unchanged from @config@. The intent is that the accumulated
-- population weight is pushed as a score into the underlying
-- 'MonadMeasure' (NOTE(review): exact normalisation behaviour is defined by
-- 'pushEvidence' in "Control.Monad.Bayes.Population").
smcPush ::
  (MonadMeasure m) => SMCConfig m -> Coroutine.SequentialT (PopulationT m) a -> PopulationT m a
smcPush config =
  smc
    SMCConfig
      { -- Build the config explicitly rather than with record-update syntax,
        -- because 'resampler' is a rank-2 field.
        resampler = pushEvidence . resampler config,
        numSteps = numSteps config,
        numParticles = numParticles config
      }