{-# LANGUAGE RankNTypes #-}

-- |
-- Module      : Control.Monad.Bayes.Inference.PMMH
-- Description : Particle Marginal Metropolis-Hastings (PMMH)
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- Particle Marginal Metropolis-Hastings (PMMH) sampling.
--
-- Christophe Andrieu, Arnaud Doucet, and Roman Holenstein. 2010. Particle Markov chain Monte Carlo methods. /Journal of the Royal Statistical Society: Series B (Statistical Methodology)/ 72, 3 (2010), 269-342. <http://www.stats.ox.ac.uk/~doucet/andrieu_doucet_holenstein_PMCMC.pdf>
module Control.Monad.Bayes.Inference.PMMH
  ( pmmh,
    pmmhBayesianModel,
  )
where

import Control.Monad.Bayes.Class (Bayesian (generative), MonadDistribution, MonadMeasure, prior)
import Control.Monad.Bayes.Inference.MCMC (MCMCConfig, mcmc)
import Control.Monad.Bayes.Inference.SMC (SMCConfig (), smc)
import Control.Monad.Bayes.Population as Pop
  ( PopulationT,
    hoist,
    pushEvidence,
    runPopulationT,
  )
import Control.Monad.Bayes.Sequential.Coroutine (SequentialT)
import Control.Monad.Bayes.Traced.Static (TracedT)
import Control.Monad.Bayes.Weighted
import Control.Monad.Trans (lift)
import Numeric.Log (Log)

-- | Particle Marginal Metropolis-Hastings sampling.
--
-- The outer loop is a Metropolis-Hastings chain ('mcmc') over the parameter
-- model @param@.  For every parameter value drawn in that chain, the
-- parameter-conditioned @model@ is run through a particle filter ('smc').
-- The particle filter's effects are lifted from @'WeightedT' m@ into the
-- traced monad ('Pop.hoist' 'lift'). 'pushEvidence' then passes the
-- filter's evidence estimate to that monad as a score. This is how the
-- estimated marginal likelihood of the parameter reaches the MH step.
--
-- The result has one entry per sample returned by 'mcmc'. Each entry is the
-- weighted particle population (as produced by 'runPopulationT') for the
-- corresponding state of the chain.
pmmh ::
  (MonadDistribution m) =>
  -- | Configuration of the outer Metropolis-Hastings chain.
  MCMCConfig ->
  -- | Configuration of the inner particle filter.
  SMCConfig (WeightedT m) ->
  -- | Model over the parameters, explored by Metropolis-Hastings.
  TracedT (WeightedT m) a1 ->
  -- | Model conditioned on a parameter value, explored by the particle filter.
  (a1 -> SequentialT (PopulationT (WeightedT m)) a2) ->
  m [[(a2, Log Double)]]
pmmh mcmcConf smcConf param model =
  mcmc
    mcmcConf
    ( param
        >>= runPopulationT
          -- Hand the particle filter's evidence estimate to the traced
          -- monad so the MH chain sees it.
          . pushEvidence
          -- Move the filter from WeightedT m into TracedT (WeightedT m).
          . Pop.hoist lift
          . smc smcConf
          . model
    )

-- | Particle Marginal Metropolis-Hastings sampling from a 'Bayesian' model.
--
-- This is a thin wrapper around 'pmmh'. The model is polymorphic in its
-- monad, so it is instantiated twice:
--
-- * its 'prior', at @'TracedT' ('WeightedT' m)@, becomes the parameter
--   model explored by Metropolis-Hastings;
--
-- * its 'generative' component, at
--   @'SequentialT' ('PopulationT' ('WeightedT' m))@, becomes the
--   parameter-conditioned model explored by the particle filter.
pmmhBayesianModel ::
  (MonadMeasure m) =>
  -- | Configuration of the outer Metropolis-Hastings chain.
  MCMCConfig ->
  -- | Configuration of the inner particle filter.
  SMCConfig (WeightedT m) ->
  -- | The Bayesian model. It must work in any 'MonadMeasure' so that it can
  -- be instantiated at both monads used by 'pmmh'.
  (forall m'. (MonadMeasure m') => Bayesian m' a1 a2) ->
  m [[(a2, Log Double)]]
pmmhBayesianModel mcmcConf smcConf bm =
  pmmh mcmcConf smcConf (prior bm) (generative bm)