{-# LANGUAGE BangPatterns #-}
module Mcmc.Metropolis
( mh,
mhContinue,
)
where
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.State.Strict
import Data.Aeson
import Data.Maybe
import Mcmc.Item
import Mcmc.Mcmc
import Mcmc.Move
import Mcmc.Status
import Mcmc.Tools.Shuffle
import Mcmc.Trace
import Numeric.Log
import System.Random.MWC
import Prelude hiding (cycle)
-- | Metropolis-Hastings acceptance ratio for a proposal that is not
-- necessarily symmetric.
--
-- With @x@ the current and @y@ the proposed state, the result is
--
-- > postY * qBackward / postX / qForward
--
-- evaluated left to right on 'Log' 'Double' values.
mhRatio ::
  -- | Prior times likelihood at the current state @x@.
  Log Double ->
  -- | Prior times likelihood at the proposed state @y@.
  Log Double ->
  -- | Proposal density evaluated as @q x y@.
  Log Double ->
  -- | Proposal density evaluated as @q y x@.
  Log Double ->
  Log Double
mhRatio postX postY qForward qBackward =
  -- Keep the original left-associative order so the log-space arithmetic
  -- rounds exactly as before.
  ((postY * qBackward) / postX) / qForward
{-# INLINE mhRatio #-}
-- | Metropolis acceptance ratio for a symmetric proposal.
--
-- @mhRatioSymmetric postX postY@ is @postY / postX@, where @postX@ and
-- @postY@ are prior times likelihood at the current and proposed states.
mhRatioSymmetric :: Log Double -> Log Double -> Log Double
mhRatioSymmetric = flip (/)
{-# INLINE mhRatioSymmetric #-}
-- | Perform one Metropolis-Hastings update of the chain with the given move.
--
-- A candidate @y@ is drawn with the move's sampler ('mvSample') from the
-- current state @x@. It is then scored with the status' 'priorF' and
-- 'likelihoodF'. The acceptance ratio is computed with 'mhRatio' when the
-- move carries a density ('mvDensity'), and with 'mhRatioSymmetric'
-- otherwise.
--
-- A ratio whose logarithm is non-negative is accepted without drawing a
-- random number. Otherwise a uniform variate from the chain's generator is
-- compared against the ratio. On acceptance the current 'Item' is replaced
-- by the candidate. In both cases the outcome is recorded with 'pushA'.
mhMove :: Move a -> Mcmc a ()
mhMove m = do
  st <- get
  let Item x pX lX = item st
      simple = mvSimple m
      gen = generator st
      acc = acceptance st
      postX = pX * lX
  !y <- liftIO $ mvSample simple x gen
  let !pY = priorF st y
      !lY = likelihoodF st y
      postY = pY * lY
      -- Without a density the proposal is treated as symmetric.
      !r = case mvDensity simple of
        Nothing -> mhRatioSymmetric postX postY
        Just q -> mhRatio postX postY (q x y) (q y x)
  -- The generator is only advanced when the ratio is below one.
  accepted <-
    if ln r >= 0.0
      then pure True
      else (< exp (ln r)) <$> uniform gen
  if accepted
    then put st {item = Item y pY lY, acceptance = pushA m True acc}
    else put st {acceptance = pushA m False acc}
-- | Produce move orderings for a number of iterations.
--
-- Every move of the cycle is first repeated 'mvWeight' times. The resulting
-- pool is handed to 'shuffleN' together with the count and the generator.
-- The pool is forced as soon as the cycle is supplied.
getNCycles :: Cycle a -> Int -> GenIO -> IO [[Move a]]
getNCycles c =
  let !pool = concatMap (\mv -> replicate (mvWeight mv) mv) (fromCycle c)
   in shuffleN pool
-- | Run one iteration of the chain.
--
-- Each move of the given list is applied in order with 'mhMove'. The
-- resulting item is then pushed onto the trace ('pushT'), the iteration
-- counter is incremented, and 'mcmcMonitorExec' is run.
mhIter :: ToJSON a => [Move a] -> Mcmc a ()
mhIter mvs = do
  forM_ mvs mhMove
  st <- get
  put st {trace = pushT (item st) (trace st), iteration = succ (iteration st)}
  mcmcMonitorExec
-- | Run @n@ iterations of the chain.
--
-- The move orderings for all @n@ iterations are obtained from 'getNCycles'
-- (using the status' cycle and generator) before the first iteration runs.
-- 'mhIter' is then applied to each ordering in turn.
--
-- NOTE(review): depending on how 'shuffleN' builds its result, all @n@
-- orderings may be held in memory at once; worth confirming for large @n@.
mhNIter :: ToJSON a => Int -> Mcmc a ()
mhNIter n = do
  cyc <- gets cycle
  gen <- gets generator
  mapM_ mhIter =<< liftIO (getNCycles cyc n gen)
-- | Run @b@ burn-in iterations, optionally with auto tuning.
--
-- With 'Nothing' all @b@ iterations run without tuning. With @Just t@ the
-- burn-in runs in chunks of @t@ iterations, calling 'mcmcAutotune' with
-- the chunk length after each chunk; the last chunk may be shorter than
-- @t@. Calls 'error' when @t@ is not positive.
mhBurnInN :: ToJSON a => Int -> Maybe Int -> Mcmc a ()
mhBurnInN b Nothing = mhNIter b
mhBurnInN b (Just t)
  | t <= 0 = error "mhBurnInN: Auto tuning period smaller equal 0."
  | b <= t = mhNIter b >> mcmcAutotune b
  | otherwise = do
      mhNIter t
      mcmcAutotune t
      mhBurnInN (b - t) (Just t)
-- | Burn-in phase of the sampler.
--
-- Calls 'error' for a negative @b@ and does nothing for @b == 0@.
-- Otherwise it prints a start message, runs 'mcmcMonitorHeader', performs
-- the burn-in via 'mhBurnInN' and prints a completion message. When an
-- auto-tuning period was given it then runs 'mcmcSummarizeCycle'. Finally
-- the acceptance counters are reset with 'resetA'.
mhBurnIn :: ToJSON a => Int -> Maybe Int -> Mcmc a ()
mhBurnIn b t = case compare b 0 of
  LT -> error "mhBurnIn: Negative number of burn in iterations."
  EQ -> return ()
  GT -> do
    liftIO . putStrLn $ "-- Burn in for " <> show b <> " cycles."
    mcmcMonitorHeader
    mhBurnInN b t
    liftIO $ putStrLn "-- Burn in finished."
    -- Summarize the cycle only when auto tuning was requested.
    when (isJust t) $ mcmcSummarizeCycle t
    st <- get
    put st {acceptance = resetA (acceptance st)}
-- | Announce and run @n@ iterations.
--
-- Prints a start message, then runs 'mcmcMonitorHeader', then runs @n@
-- iterations with 'mhNIter'.
mhRun :: ToJSON a => Int -> Mcmc a ()
mhRun n = announce >> mcmcMonitorHeader >> mhNIter n
  where
    announce = liftIO . putStrLn $ "-- Run chain for " <> show n <> " iterations."
-- | Run a fresh Metropolis-Hastings chain from start to end.
--
-- The burn-in length (defaulting to 0), auto-tuning period and number of
-- iterations are read from the status as it is before 'mcmcInit' runs.
-- The sequence is:
-- 'mcmcInit', 'mcmcReport', 'mcmcSummarizeCycle', burn-in ('mhBurnIn'),
-- the main run ('mhRun'), and 'mcmcClose'.
mhT :: ToJSON a => Mcmc a ()
mhT = do
  st <- get
  let burnIn = fromMaybe 0 (burnInIterations st)
      tuning = autoTuningPeriod st
      total = iterations st
  liftIO $ putStrLn "-- Start of Metropolis-Hastings sampler."
  mcmcInit
  mcmcReport
  mcmcSummarizeCycle Nothing
  mhBurnIn burnIn tuning
  mhRun total
  mcmcClose
-- | Continue an existing chain for @dn@ more iterations.
--
-- Adds @dn@ to the status' 'iterations' field, then runs 'mcmcInit' and
-- 'mcmcSummarizeCycle'. It then runs @dn@ iterations with 'mhRun' and
-- finishes with 'mcmcClose'.
--
-- NOTE(review): 'mhRun' prints its own "Run chain for" line, so the
-- iteration count is announced twice.
mhContinueT :: ToJSON a => Int -> Mcmc a ()
mhContinueT dn = do
  liftIO $ do
    putStrLn "-- Continue Metropolis-Hastings sampler."
    putStrLn $ "-- Run chain for " <> show dn <> " additional iterations."
  st <- get
  put st {iterations = iterations st + dn}
  mcmcInit
  mcmcSummarizeCycle Nothing
  mhRun dn
  mcmcClose
-- | Continue a Markov chain for @dn@ additional iterations.
--
-- Calls 'error' when @dn@ is zero or negative.
mhContinue ::
  ToJSON a =>
  -- | Number of additional iterations; must be positive.
  Int ->
  -- | Status of the chain to continue.
  Status a ->
  IO (Status a)
mhContinue dn
  | dn > 0 = execStateT (mhContinueT dn)
  | otherwise = error "mhContinue: The number of iterations is zero or negative."
-- | Run a Metropolis-Hastings sampler on a fresh chain.
--
-- The status' iteration counter must be zero. Otherwise a hint pointing
-- to 'mhContinue' is printed and 'error' is called.
mh ::
  ToJSON a =>
  Status a ->
  IO (Status a)
mh s
  | current == 0 = execStateT mhT s
  | otherwise = do
      putStrLn "To continue a Markov chain run, please use 'mhContinue'."
      error $ "mh: Current iteration " ++ show current ++ " is non-zero."
  where
    current = iteration s