{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}

-- |
-- Module      :  Mcmc.Metropolis
-- Description :  Metropolis-Hastings at its best
-- Copyright   :  (c) Dominik Schrempf 2020
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Tue May  5 20:11:30 2020.
--
-- Metropolis-Hastings algorithm.
module Mcmc.Metropolis
  ( mh,
    mhContinue,
  )
where

import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.State
import Data.Aeson
import Data.Maybe
import Mcmc.Item
import Mcmc.Mcmc
import Mcmc.Proposal
import Mcmc.Status
import Mcmc.Trace
import Numeric.Log
import System.Random.MWC
import Prelude hiding (cycle)

-- The Metropolis-Hastings ratio
--
-- > r = fY / fX * q * j
--
-- computed in log domain ('Log' 'Double'), where
--
--   - fX is the prior times the likelihood of the current state,
--   - fY is the prior times the likelihood of the proposed state,
--   - q = qYX / qXY * jXY is the proposal ratio (see 'ProposalSimple'), and
--   - j is the Jacobian.
--
-- Special values:
--
--   - 'Infinity' if fX is zero (and fY, q, j are positive). In this case, the
--     proposal is always accepted.
--   - 'NaN' if fX and (fY or q) are zero. In this case, the proposal is always
--     rejected.
--
-- There is a discrepancy between authors saying that one should (a) always
-- accept the new state when the current posterior is zero (Chapter 4 of the
-- Handbook of Markov Chain Monte Carlo), or (b) almost surely reject the
-- proposal when either fY or q are zero (Chapter 1). Since I trust the author
-- of Chapter 1 (Charles Geyer) I choose to follow option (b).
mhRatio :: Log Double -> Log Double -> Log Double -> Log Double -> Log Double
mhRatio fX fY q j = posteriorRatio * q * j
  where
    -- The evaluation order ((fY / fX) * q) * j is deliberate; floating point
    -- operations (additions in log domain) are not associative.
    posteriorRatio = fY / fX
{-# INLINE mhRatio #-}

-- Propose a new state with the given proposal, then accept or reject it using
-- the Metropolis-Hastings ratio.
--
-- On acceptance, the current item is replaced by the proposed state together
-- with its prior and likelihood. In both cases, the outcome is pushed to the
-- acceptance counts of the proposal.
mhPropose :: Proposal a -> Mcmc a ()
mhPropose m = do
  s <- get
  let Item x pX lX = item s
      gen = generator s
      accs = acceptance s
  -- Draw a candidate together with the proposal ratio and the Jacobian.
  (!y, !q, !j) <- liftIO $ pSimple m x gen
  -- Prior, likelihood, and Metropolis-Hastings ratio of the candidate.
  let !pY = priorF s y
      !lY = likelihoodF s y
      !r = mhRatio (pX * lX) (pY * lY) q j
      logR = ln r
      acceptY = put $ s {item = Item y pY lY, acceptance = pushA m True accs}
      rejectY = put $ s {acceptance = pushA m False accs}
  -- Accept right away when r >= 1 (no random number is drawn); otherwise
  -- accept with probability r. A NaN ratio fails both comparisons and is
  -- rejected.
  if logR >= 0.0
    then acceptY
    else do
      u <- uniform gen
      if u < exp logR then acceptY else rejectY

-- TODO: Splitmix. Split the generator here. See SaveSpec -> mhContinue.

-- Run one iteration; perform all proposals of one cycle in turn.
--
-- Afterwards, the current item is pushed to the trace, the iteration counter
-- is incremented, and 'mcmcClean' as well as 'mcmcMonitorExec' are executed.
mhIter :: ToJSON a => [Proposal a] -> Mcmc a ()
mhIter ps = do
  forM_ ps mhPropose
  modify $ \s ->
    s
      { trace = pushT (item s) (trace s),
        iteration = succ (iteration s)
      }
  mcmcClean
  mcmcMonitorExec

-- Run n iterations.
--
-- The proposals of each iteration are obtained from the cycle of the status
-- with 'getNIterations', which is given the random number generator of the
-- status.
mhNIter :: ToJSON a => Int -> Mcmc a ()
mhNIter n = do
  mcmcDebugS $ "Run " <> show n <> " iterations."
  s <- get
  proposalLists <- liftIO $ getNIterations (cycle s) n (generator s)
  mapM_ mhIter proposalLists

-- Burn in for b iterations and, optionally, auto tune.
--
-- Without an auto tuning period, b iterations are run plainly. With an auto
-- tuning period @Just t@, the burn in is split into chunks of at most t
-- iterations, and the acceptance counts are reset before every chunk. After
-- every chunk but the last, the cycle summary is logged at debug level and the
-- proposals are auto tuned. After the last chunk, the cycle summary is logged
-- at info level.
--
-- Calls 'error' if the auto tuning period is zero or negative.
mhBurnInN :: ToJSON a => Int -> Maybe Int -> Mcmc a ()
mhBurnInN b Nothing = mhNIter b
mhBurnInN b (Just t)
  | t <= 0 = error "mhBurnInN: Auto tuning period smaller equal 0."
  | b <= t = do
    -- Last chunk; no auto tuning afterwards.
    mcmcResetA
    mhNIter b
    mcmcInfoB =<< mcmcSummarizeCycle
    mcmcInfoS $
      "Acceptance ratios calculated over the last " <> show b <> " iterations."
  | otherwise = do
    mcmcResetA
    mhNIter t
    mcmcDebugB =<< mcmcSummarizeCycle
    mcmcAutotune
    mhBurnInN (b - t) (Just t)

-- Burn in for b iterations with optional auto tuning period t (see
-- 'mhBurnInN'), surrounded by log messages.
--
-- Does nothing if b is zero; calls 'error' if b is negative.
mhBurnIn :: ToJSON a => Int -> Maybe Int -> Mcmc a ()
mhBurnIn b t = case compare b 0 of
  LT -> error "mhBurnIn: Negative number of burn in iterations."
  EQ -> return ()
  GT -> do
    mcmcInfoS $ "Burn in for " <> show b <> " cycles."
    mcmcDebugS $ "Auto tuning period is " <> show t <> "."
    mhBurnInN b t
    mcmcInfoB "Burn in finished."

-- Run the chain for n iterations, after resetting the acceptance counts with
-- 'mcmcResetA'.
mhRun :: ToJSON a => Int -> Mcmc a ()
mhRun n = do
  mcmcResetA
  mcmcInfoS runMsg
  -- Disabled: print the standard output monitor header every 100 iterations.
  --
  -- let (m, r) = n `quotRem` 100
  -- replicateM_ m $ do
  --   mcmcMonitorStdOutHeader
  --   mhNIter 100
  -- when (r > 0) $ do
  --   mcmcMonitorStdOutHeader
  --   mhNIter r
  mhNIter n
  where
    runMsg = "Run chain for " <> show n <> " iterations."

-- A complete Metropolis-Hastings run.
--
-- The steps are:
--
--   - log the sampler name and the cycle summary;
--   - call 'mcmcReport';
--   - burn in ('mhBurnIn'), for zero iterations if the status specifies no
--     burn in iterations;
--   - run the chain for the number of iterations stored in the status.
mhT :: ToJSON a => Mcmc a ()
mhT = do
  mcmcInfoB "Metropolis-Hastings sampler."
  mcmcInfoB =<< mcmcSummarizeCycle
  mcmcReport
  s <- get
  let burnIn = fromMaybe 0 $ burnInIterations s
  mhBurnIn burnIn (autoTuningPeriod s)
  mhRun (iterations s)

-- Continue a chain for dn additional iterations. The run is preceded by log
-- messages and the cycle summary.
mhContinueT :: ToJSON a => Int -> Mcmc a ()
mhContinueT dn =
  mcmcInfoB "Continuation of Metropolis-Hastings sampler."
    >> mcmcInfoS runMsg
    >> (mcmcSummarizeCycle >>= mcmcInfoB)
    >> mhRun dn
  where
    runMsg = "Run chain for " <> show dn <> " additional iterations."

-- | Continue a Markov chain for a given number of Metropolis-Hastings steps.
--
-- At the moment, when an MCMC run is continued, the old @.mcmc@ file is
-- deleted. This behavior may change in the future.
--
-- This means that an interrupted continuation also breaks previous runs. This
-- step is necessary because, otherwise, incomplete monitor files are left on
-- disk, if a continuation is canceled. Subsequent continuations would append to
-- the incomplete monitor files and produce garbage.
--
-- Calls 'error' if the number of additional steps is zero or negative.
mhContinue ::
  ToJSON a =>
  -- | Additional number of Metropolis-Hastings steps.
  Int ->
  -- | Loaded status of the Markov chain.
  Status a ->
  IO (Status a)
mhContinue dn s
  | dn > 0 = mcmcRun (mhContinueT dn) extended
  | otherwise = error "mhContinue: The number of iterations is zero or negative."
  where
    -- The total number of iterations grows by the number of additional steps.
    extended = s {iterations = iterations s + dn}

-- | Run a Markov chain for a given number of Metropolis-Hastings steps.
--
-- The status has to be fresh, that is, its current iteration has to be zero.
-- Otherwise, a hint to use 'mhContinue' is printed to standard output and
-- 'error' is called.
mh ::
  ToJSON a =>
  -- | Initial status of the Markov chain (current iteration zero).
  Status a ->
  IO (Status a)
mh s
  | i == 0 = mcmcRun mhT s
  | otherwise = do
    putStrLn "To continue a Markov chain run, please use 'mhContinue'."
    error $ "mh: Current iteration " <> show i <> " is non-zero."
  where
    i = iteration s