{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}

-- |
-- Module      :  Mcmc.Algorithm.MHG
-- Description :  Metropolis-Hastings-Green algorithm
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Tue May  5 20:11:30 2020.
--
-- The Metropolis-Hastings-Green ('MHG') algorithm.
--
-- For example, see Geyer, C. J., Introduction to Markov chain Monte Carlo, In
-- Handbook of Markov Chain Monte Carlo (pp. 45) (2011). CRC press.
module Mcmc.Algorithm.MHG
  ( MHG (..),
    mhg,
    mhgSave,
    mhgLoad,
    mhgLoadUnsafe,
    MHGRatio,
    mhgAccept,
  )
where

import Codec.Compression.GZip
import Control.Monad
import Control.Monad.IO.Class
import Control.Parallel.Strategies
import Data.Aeson
import qualified Data.ByteString.Lazy.Char8 as BL
import Data.Maybe
import Data.Time
import qualified Data.Vector as VB
import Mcmc.Acceptance
import Mcmc.Algorithm
import Mcmc.Chain.Chain
import Mcmc.Chain.Link
import Mcmc.Chain.Save
import Mcmc.Chain.Trace
import Mcmc.Cycle
import Mcmc.Likelihood
import Mcmc.Monitor
import Mcmc.Posterior
import Mcmc.Prior hiding (uniform)
import Mcmc.Proposal
import Mcmc.Settings
import Numeric.Log
import System.Random.Stateful
import Prelude hiding (cycle)

-- | The Metropolis-Hastings-Green (MHG) algorithm.
--
-- A thin wrapper around a single 'Chain'; all algorithm operations below
-- unwrap the chain, act on it, and wrap it again.
newtype MHG a = MHG
  { fromMHG :: Chain a
  }

-- | Every algorithm operation is delegated to the corresponding @mhg*@
-- function of this module.
instance (ToJSON a) => Algorithm (MHG a) where
  aName _ = "Metropolis-Hastings-Green (MHG)"
  aIteration (MHG c) = iteration c
  aIsInvalidState = mhgIsInvalidState
  aIterate = mhgIterate
  aAutoTune = mhgAutoTune
  aResetAcceptance = mhgResetAcceptance
  aCleanAfterBurnIn = mhgCleanAfterBurnIn
  aSummarizeCycle = mhgSummarizeCycle
  aOpenMonitors = mhgOpenMonitors
  aExecuteMonitors = mhgExecuteMonitors
  aStdMonitorHeader = mhgStdMonitorHeader
  aCloseMonitors = mhgCloseMonitors
  aSave = mhgSave

-- Calculate required length of trace. The length may be larger during burn in,
-- because the tuners of some proposals (e.g., HMC, NUTS) require the states of
-- the last tuning interval.
--
-- The result is the largest of
--
-- - the minimum trace length requested by the user (1 for 'TraceAuto'),
--
-- - the longest auto-tuning period of the burn in, but only if the cycle
--   requires the trace and burn in settings are given ('Nothing' after burn in),
--
-- - the batch sizes of all batch monitors.
getTraceLength ::
  Maybe BurnInSettings ->
  TraceLength ->
  Monitor a ->
  Cycle a ->
  Int
getTraceLength burnIn tl mn cc =
  maximum (userMinimum : burnInLength : map getMonitorBatchSize (mBatches mn))
  where
    userMinimum = case tl of
      TraceAuto -> 1
      TraceMinimum n -> n
    burnInLength
      | not (ccRequireTrace cc) = 0
      | otherwise = case burnIn of
          Just (BurnInWithAutoTuning _ n) -> n
          -- The leading 0 keeps 'maximum' total for empty period lists.
          Just (BurnInWithCustomAutoTuning ns ms) -> maximum (0 : ns ++ ms)
          _ -> 0

-- | Initialize an MHG algorithm.
--
-- The initial link holds the initial state together with its prior and
-- likelihood. The trace is filled with copies of this link; its length is
-- determined by the burn in settings, the trace length settings, the monitor,
-- and the cycle. The acceptance counters start empty for all proposals of the
-- cycle, and the given generator is wrapped into a mutable 'IOGenM'.
--
-- NOTE: Computation in the 'IO' Monad is necessary because the trace is
-- mutable.
mhg ::
  Settings ->
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  InitialState a ->
  StdGen ->
  IO (MHG a)
mhg s pr lh cc mn i0 g = do
  -- The trace is a mutable vector and the mutable state needs to be handled by
  -- a monad.
  tr <- replicateT traceLength initialLink
  gen <- newIOGenM g
  -- Both 'Int' arguments (the iteration counter and the start iteration) are 0
  -- for a fresh chain.
  pure . MHG $
    Chain initialLink 0 tr initialAcceptances gen 0 pr lh cc mn
  where
    initialLink = Link i0 (pr i0) (lh i0)
    initialAcceptances = emptyA (ccProposals cc)
    traceLength = getTraceLength (Just (sBurnIn s)) (sTraceLength s) mn cc

-- File name of the saved MHG chain: the analysis name with the suffix
-- @.mcmc.mhg@.
mhgFn :: AnalysisName -> FilePath
mhgFn (AnalysisName analysis) = analysis <> ".mcmc.mhg"

-- | Save an MHG algorithm.
--
-- The chain is converted to a 'SavedChain', encoded as JSON, compressed with
-- GZip, and written to @<analysis name>.mcmc.mhg@.
mhgSave ::
  (ToJSON a) =>
  AnalysisName ->
  MHG a ->
  IO ()
mhgSave nm (MHG c) =
  toSavedChain c >>= BL.writeFile (mhgFn nm) . compress . encode

-- | Load an MHG algorithm.
--
-- Read the chain written by 'mhgSave' from @<analysis name>.mcmc.mhg@,
-- decompress and decode it, and restore it with 'fromSavedChain' using the
-- given prior function, likelihood function, cycle, and monitor. Sanity checks
-- are performed while restoring (see 'mhgLoadUnsafe' for the variant without
-- them).
--
-- Calls 'error' if the saved file cannot be decoded.
--
-- NOTE: No backup of the save file is created.
--
-- See 'Mcmc.Mcmc.mcmcContinue'.
mhgLoad ::
  (FromJSON a) =>
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  AnalysisName ->
  IO (MHG a)
mhgLoad = mhgLoadWith fromSavedChain

-- | Like 'mhgLoad' but do not perform sanity checks.
--
-- The saved chain is restored with 'fromSavedChainUnsafe'.
--
-- NOTE: No backup of the save file is created.
--
-- Useful when restarting a run with changed prior function, likelihood function
-- or proposals. Use with care!
mhgLoadUnsafe ::
  (FromJSON a) =>
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  AnalysisName ->
  IO (MHG a)
mhgLoadUnsafe = mhgLoadWith fromSavedChainUnsafe

-- Shared loader of 'mhgLoad' and 'mhgLoadUnsafe'.
--
-- Reads @<analysis name>.mcmc.mhg@, decompresses and decodes it into a
-- 'SavedChain', and restores the chain with the given conversion function.
-- A decoding failure is reported by calling 'error' with the decoder's message.
--
-- Nice type :-).
mhgLoadWith ::
  (FromJSON a) =>
  (PriorFunction a -> LikelihoodFunction a -> Cycle a -> Monitor a -> SavedChain a -> IO (Chain a)) ->
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  AnalysisName ->
  IO (MHG a)
mhgLoadWith restore pr lh cc mn nm = do
  compressed <- BL.readFile (mhgFn nm)
  case eitherDecode (decompress compressed) of
    Left err -> error err
    Right saved -> MHG <$> restore pr lh cc mn saved

-- | Metropolis-Hastings-Green ratio.
--
-- MHG ratios are stored in the log domain (see 'Log').
type MHGRatio = Log Double

-- The MHG ratio. This implementation has the following properties:
--
-- - The kernel ratio and the Jacobian are checked carefully and must be
--   strictly positive, finite numbers; zero, infinite, or NaN values cause a
--   call to 'error'.
--
-- - The ratio is 'Infinity' if fX is zero. In this case, the proposal is always
--   accepted.
--
-- - The ratio is 'NaN' if fY and fX are zero. In this case, the proposal is
--   always rejected.
--
-- This means that a chain in a state with posterior probability zero (fX=0) can
-- only move if a state with non-zero posterior probability is proposed.
-- Otherwise it is stuck. Therefore, I print a warning when the posterior
-- probability is zero in the beginning of the MCMC run. This is probably not
-- the best behavior, but see below.
--
-- There is a discrepancy between authors saying that one should (a) always
-- accept the new state when the current posterior is zero (Chapter 4 of [1],
-- [2]), or (b) almost surely reject the proposal when either fY or q are zero
-- (Chapter 1 of [1]).
--
-- Since I trust the author of Chapter 1 (Charles Geyer) I choose to follow
-- option (b). However, Option (a) is more user-friendly.
--
-- [1] Handbook of Markov chain Monte Carlo (2011), CRC press.
--
-- [2] Dellaportas, P., & Roberts, G. O., An introduction to MCMC, Lecture Notes
-- in Statistics, (), 1–41 (2003).
-- http://dx.doi.org/10.1007/978-0-387-21811-3_1.
mhgRatio :: Posterior -> Posterior -> KernelRatio -> Jacobian -> MHGRatio
mhgRatio fX fY q j
  | q == 0.0 = error "mhgRatio: Kernel ratio is negative infinity. Use 'ForceReject'."
  | q == 1.0 / 0.0 = error "mhgRatio: Kernel ratio is infinity. Use 'ForceAccept'."
  -- NaN compares unequal to everything, including itself, so a guard such as
  -- @q == 0.0 / 0.0@ can never succeed. Test the log-domain value directly.
  | isNaN (ln q) = error "mhgRatio: Kernel ratio is NaN."
  | j == 0.0 = error "mhgRatio: Jacobian is negative infinity. Use 'ForceReject'."
  | j == 1.0 / 0.0 = error "mhgRatio: Jacobian is infinity. Use 'ForceAccept'."
  | isNaN (ln j) = error "mhgRatio: Jacobian is NaN."
  | otherwise = fY / fX * q * j
{-# INLINE mhgRatio #-}

-- | Accept or reject a proposal with given MHG ratio?
--
-- A ratio of at least one (log value @>= 0@) is accepted without drawing a
-- random number. Otherwise a uniform number @b@ in @[0, 1]@ is drawn from the
-- generator and the proposal is accepted iff @b@ is smaller than the ratio. A
-- NaN ratio is therefore always rejected.
mhgAccept :: MHGRatio -> IOGenM StdGen -> IO Bool
mhgAccept r g
  | logRatio >= 0.0 = pure True
  | otherwise = (< exp logRatio) <$> uniformRM (0, 1) g
  where
    logRatio = ln r

-- Propose a new state with the given proposal and accept or reject it with the
-- MHG rule. The acceptance counters of the proposal are updated in both cases;
-- the trace and the iteration counter are left untouched.
mhgPropose :: MHG a -> Proposal a -> IO (MHG a)
mhgPropose (MHG c) p = do
  -- 1. Sample new state. The second component is handed to the acceptance
  -- counters.
  (!pres, !mcs) <- liftIO $ prFunction p x g
  -- 2. Define new prior and likelihood calculation functions. Avoid actual
  -- calculation of the values.
  --
  -- Most often, parallelization is not helpful, because the prior and
  -- likelihood functions are too fast; see
  -- https://stackoverflow.com/a/46603680/3536806.
  let evalPrLh y =
        (priorFunction c y, likelihoodFunction c y)
          `using` parTuple2 rdeepseq rdeepseq
      acceptWith y pY lY =
        let !ac' = pushAccept mcs p ac
         in pure $ MHG $ c {link = Link y pY lY, acceptances = ac'}
      rejectIt =
        let !ac' = pushReject mcs p ac
         in pure $ MHG $ c {acceptances = ac'}
  -- 3. Accept or reject.
  case pres of
    -- 3a. Rejection is inevitable: skip prior, likelihood, and MHG ratio.
    ForceReject -> rejectIt
    ForceAccept y ->
      let (pY, lY) = evalPrLh y
       in acceptWith y pY lY
    Propose y q j
      -- A zero kernel ratio or Jacobian rejects the move outright.
      | q <= 0.0 || j <= 0.0 -> rejectIt
      | otherwise -> do
          -- 3b. Calculate Metropolis-Hastings-Green ratio. The bang forces the
          -- ratio (and thus its sanity checks) before a random number is drawn.
          let (pY, lY) = evalPrLh y
              !r = mhgRatio (pX * lX) (pY * lY) q j
          accepted <- mhgAccept r g
          if accepted then acceptWith y pY lY else rejectIt
  where
    (Link x pX lX) = link c
    ac = acceptances c
    g = generator c

-- Append the current link to the trace and advance the iteration counter by
-- one.
mhgPush :: MHG a -> IO (MHG a)
mhgPush (MHG c) = do
  newTrace <- pushT (link c) (trace c)
  pure $ MHG c {trace = newTrace, iteration = succ (iteration c)}

-- Check if the current state is invalid.
--
-- The state is considered invalid if, in the log domain,
--
-- - the prior, the likelihood, or the posterior (prior times likelihood) is
--   NaN or infinite, or
--
-- - the likelihood or the posterior is exactly zero.
--
-- NOTE(review): a log value of exactly zero corresponds to a probability of
-- one, so a likelihood of exactly 1 (for example a constant likelihood) is
-- flagged as invalid here. Confirm that this is intended.
mhgIsInvalidState :: MHG a -> Bool
mhgIsInvalidState (MHG c) =
  any notFinite [lnPrior, lnLikelihood, lnPosterior]
    || any (== 0) [lnLikelihood, lnPosterior]
  where
    current = link c
    lnPrior = ln (prior current)
    lnLikelihood = ln (likelihood current)
    lnPosterior = ln (prior current * likelihood current)
    notFinite v = isNaN v || isInfinite v

-- Ignore the number of capabilities. I have tried a lot of stuff, but the MHG
-- algorithm is just inherently sequential. Parallelization can be achieved by
-- having parallel prior and/or likelihood functions, or by using algorithms
-- running parallel chains such as 'MC3'.
--
-- One iteration prepares the proposals of the cycle, applies them one after
-- another with 'mhgPropose', and finally pushes the resulting link onto the
-- trace with 'mhgPush'.
mhgIterate :: IterationMode -> ParallelizationMode -> MHG a -> IO (MHG a)
mhgIterate m _ a@(MHG c) =
  prepareProposals m (cycle c) (generator c)
    >>= foldM mhgPropose a
    >>= mhgPush

-- Auto tune the proposals of the cycle.
--
-- Intermediate tuning never provides the trace, and is skipped entirely when
-- the cycle has no intermediate tuners. Other tuning types provide the states
-- of the last @n@ links of the trace if the cycle requires the trace.
mhgAutoTune :: TuningType -> Int -> MHG a -> IO (MHG a)
mhgAutoTune tt n (MHG c)
  | tt `elem` [IntermediateTuningFastProposalsOnly, IntermediateTuningAllProposals] =
      pure . MHG $ if ccHasIntermediateTuners cc then tuneWith Nothing else c
  | otherwise = do
      states <-
        if ccRequireTrace cc
          then Just . VB.map state <$> takeT n (trace c)
          else pure Nothing
      pure . MHG $ tuneWith states
  where
    cc = cycle c
    tuneWith mTrace = c {cycle = autoTuneCycle tt (acceptances c) mTrace cc}

-- Reset the acceptance counters of the chain with 'resetA'.
mhgResetAcceptance :: ResetAcceptance -> MHG a -> MHG a
mhgResetAcceptance a (MHG c) = MHG c {acceptances = resetA a (acceptances c)}

-- Replace the trace with a new one that contains only the most recent links,
-- using the trace length required without burn in settings.
mhgCleanAfterBurnIn :: TraceLength -> MHG a -> IO (MHG a)
mhgCleanAfterBurnIn tl (MHG c) = do
  shortened <- fromVectorT =<< takeT keep (trace c)
  pure $ MHG c {trace = shortened}
  where
    keep = getTraceLength Nothing tl (monitor c) (cycle c)

-- Summarize the cycle together with the current acceptance counters.
mhgSummarizeCycle :: IterationMode -> MHG a -> BL.ByteString
mhgSummarizeCycle m (MHG c) = summarizeCycle m (acceptances c) (cycle c)

-- Open the monitors, using the analysis name as file name prefix and an empty
-- suffix.
mhgOpenMonitors ::
  AnalysisName ->
  ExecutionMode ->
  MHG a ->
  IO (MHG a)
mhgOpenMonitors nm em (MHG c) = withMonitor <$> mOpen prefix "" em (monitor c)
  where
    prefix = fromAnalysisName nm
    withMonitor m' = MHG c {monitor = m'}

-- Execute the monitors for the current iteration of the chain.
mhgExecuteMonitors ::
  Verbosity ->
  -- Starting time.
  UTCTime ->
  -- Total number of iterations.
  Int ->
  MHG a ->
  IO (Maybe BL.ByteString)
mhgExecuteMonitors vb t0 iTotal (MHG c) =
  mExec vb (iteration c) (start c) t0 (trace c) iTotal (monitor c)

-- Header of the standard output monitor of the chain.
mhgStdMonitorHeader :: MHG a -> BL.ByteString
mhgStdMonitorHeader = msHeader . mStdOut . monitor . fromMHG

-- Close the monitors of the chain and store the returned monitor.
mhgCloseMonitors :: MHG a -> IO (MHG a)
mhgCloseMonitors (MHG c) = withMonitor <$> mClose (monitor c)
  where
    withMonitor m' = MHG c {monitor = m'}