{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}

-- |
-- Module      :  Mcmc.Algorithm.MC3
-- Description :  Metropolis-coupled Markov chain Monte Carlo algorithm
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Mon Nov 23 15:20:33 2020.
--
-- The Metropolis-coupled Markov chain Monte Carlo ('MC3') algorithm.
--
-- Also known as parallel tempering.
--
-- Like any other parallel MCMC algorithm, the 'MC3' algorithm is essentially an
-- 'Mcmc.Algorithm.MHG.MHG' algorithm on the product space of all parallel
-- chains.
--
-- For example, see
--
-- - Geyer, C. J., Markov chain monte carlo maximum likelihood, Computing
--   Science and Statistics, Proceedings of the 23rd Symposium on the Interface,
--   (1991).
--
-- - Altekar, G., Dwarkadas, S., Huelsenbeck, J. P., & Ronquist, F., Parallel
--   metropolis coupled markov chain monte carlo for bayesian phylogenetic
--   inference, Bioinformatics, 20(3), 407–415 (2004).
module Mcmc.Algorithm.MC3
  ( -- * Definitions
    NChains (..),
    SwapPeriod (..),
    NSwaps (..),
    MC3Settings (..),
    MHGChains,
    ReciprocalTemperatures,

    -- * Metropolis-coupled Markov chain Monte Carlo algorithm
    MC3 (..),
    mc3,
    mc3Save,
    mc3Load,
  )
where

import Codec.Compression.GZip
import Control.Concurrent.Async hiding (link)
import Control.Monad
import Data.Aeson
import Data.Aeson.TH
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Lazy.Char8 as BL
import Data.List
import qualified Data.Map.Strict as M
import Data.Time
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
import Data.Word
import Mcmc.Acceptance
import Mcmc.Algorithm
import Mcmc.Algorithm.MHG
import Mcmc.Chain.Chain
import Mcmc.Chain.Link
import Mcmc.Chain.Save
import Mcmc.Chain.Trace
import Mcmc.Cycle
import Mcmc.Internal.Random
import Mcmc.Internal.Shuffle
import Mcmc.Likelihood
import Mcmc.Monitor
import Mcmc.Posterior
import Mcmc.Prior
import Mcmc.Proposal
import Mcmc.Settings
import Numeric.Log hiding (sum)
import System.Random.Stateful
import Text.Printf

-- | Total number of parallel chains.
--
-- Must be two or larger; this is checked by 'mc3', not by the constructor.
newtype NChains = NChains {fromNChains :: Int}
  deriving (Eq, Read, Show)

-- Derive the 'ToJSON' and 'FromJSON' instances of 'NChains' using the default
-- Aeson options.
$(deriveJSON defaultOptions ''NChains)

-- | The period of proposing state swaps between chains.
--
-- Must be one or larger; this is checked by 'mc3', not by the constructor.
newtype SwapPeriod = SwapPeriod {fromSwapPeriod :: Int}
  deriving (Eq, Read, Show)

-- Derive the 'ToJSON' and 'FromJSON' instances of 'SwapPeriod' using the
-- default Aeson options.
$(deriveJSON defaultOptions ''SwapPeriod)

-- | The number of proposed swaps at each swapping event.
--
-- Must be in @[1, NChains - 1]@; this is checked by 'mc3', not by the
-- constructor.
newtype NSwaps = NSwaps {fromNSwaps :: Int}
  deriving (Eq, Read, Show)

-- Derive the 'ToJSON' and 'FromJSON' instances of 'NSwaps' using the default
-- Aeson options.
$(deriveJSON defaultOptions ''NSwaps)

-- | MC3 settings.
--
-- The values are validated by 'mc3' when the algorithm is initialized.
data MC3Settings = MC3Settings
  { -- | The number of chains; has to be two or larger.
    mc3NChains :: NChains,
    -- | The period of swap proposals; has to be one or larger.
    mc3SwapPeriod :: SwapPeriod,
    -- | The number of swaps proposed at each swapping event; has to be in
    -- @[1, NChains - 1]@.
    mc3NSwaps :: NSwaps
  }
  deriving (Eq, Read, Show)

-- Derive the 'ToJSON' and 'FromJSON' instances of 'MC3Settings' using the
-- default Aeson options.
$(deriveJSON defaultOptions ''MC3Settings)

-- | Vector of MHG chains.
--
-- In an 'MC3' value, the chain at index 0 is the cold chain.
type MHGChains a = V.Vector (MHG a)

-- | Vector of reciprocal temperatures.
--
-- The entry at index @i@ is paired with the chain at index @i@ of the
-- corresponding 'MHGChains' vector.
type ReciprocalTemperatures = U.Vector Double

-- Serializable snapshot of an 'MC3' algorithm; see 'toSavedMC3' and
-- 'fromSavedMC3'.
--
-- The fields mirror the fields of 'MC3' one-to-one. The chains are stored as
-- 'SavedChain's, and the generator is stored as the pair of words returned by
-- 'saveGen'.
data SavedMC3 a = SavedMC3
  { savedMC3Settings :: MC3Settings,
    savedMC3Chains :: V.Vector (SavedChain a),
    savedMC3ReciprocalTemperatures :: ReciprocalTemperatures,
    savedMC3Iteration :: Int,
    savedMC3SwapAcceptance :: Acceptance Int,
    savedMC3Generator :: (Word64, Word64)
  }
  deriving (Eq, Show)

-- Derive the 'ToJSON' and 'FromJSON' instances of 'SavedMC3' using the default
-- Aeson options.
$(deriveJSON defaultOptions ''SavedMC3)

-- Convert a running 'MC3' algorithm into its serializable form.
--
-- Every chain is converted with 'toSavedChain', and the state of the random
-- number generator is extracted with 'saveGen'. All other fields are copied
-- verbatim.
toSavedMC3 ::
  MC3 a ->
  IO (SavedMC3 a)
toSavedMC3 (MC3 s mhgs bs i ac g) = do
  scs <- V.mapM (toSavedChain . fromMHG) mhgs
  g' <- saveGen g
  return $ SavedMC3 s scs bs i ac g'

-- Restore an 'MC3' algorithm from its serializable form.
--
-- The prior function, likelihood function, cycle and monitor are not part of
-- the save and have to be provided again. The given prior and likelihood
-- functions are the cold ones; they are heated separately for each chain using
-- the saved reciprocal temperatures.
--
-- NOTE: Chains and reciprocal temperatures are paired with 'zip3'. If the
-- saved vectors differ in length, the surplus elements are silently dropped.
fromSavedMC3 ::
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  SavedMC3 a ->
  IO (MC3 a)
fromSavedMC3 pr lh cc mn (SavedMC3 s scs bs i ac g') = do
  mhgs <-
    V.fromList
      <$> sequence
        [ MHG <$> fromSavedChain pf lf cc mn sc
          | (sc, pf, lf) <- zip3 (V.toList scs) prs lhs
        ]
  g <- loadGen g'
  return $ MC3 s mhgs bs i ac g
  where
    -- Reciprocal temperatures in chain order.
    betas = U.toList bs
    -- Heated prior and likelihood functions, one per chain.
    prs = map (heatFunction pr) betas
    lhs = map (heatFunction lh) betas

-- | The MC3 algorithm.
data MC3 a = MC3
  { mc3Settings :: MC3Settings,
    -- | The first chain is the cold chain with temperature 1.0.
    mc3MHGChains :: MHGChains a,
    -- | Vector of reciprocal temperatures.
    mc3ReciprocalTemperatures :: ReciprocalTemperatures,
    -- | Current iteration.
    mc3Iteration :: Int,
    -- | Number of accepted and rejected swaps. The key is the index of the left
    -- chain of the proposed swap.
    mc3SwapAcceptance :: Acceptance Int,
    -- | Random number generator used to select and to accept or reject swaps.
    mc3Generator :: IOGenM StdGen
  }

-- Every class method is delegated to the top-level function of this module
-- with the corresponding @mc3@ prefix.
instance ToJSON a => Algorithm (MC3 a) where
  aName = const "Metropolis-coupled Markov chain Monte Carlo (MC3)"
  aIteration = mc3Iteration
  aIsInvalidState = mc3IsInvalidState
  aIterate = mc3Iterate
  aAutoTune = mc3AutoTune
  aResetAcceptance = mc3ResetAcceptance
  aCleanAfterBurnIn = mc3CleanAfterBurnIn
  aSummarizeCycle = mc3SummarizeCycle
  aOpenMonitors = mc3OpenMonitors
  aExecuteMonitors = mc3ExecuteMonitors
  aStdMonitorHeader = mc3StdMonitorHeader
  aCloseMonitors = mc3CloseMonitors
  aSave = mc3Save

-- Heat a prior or likelihood function; that is, raise the values of the cold
-- function to the power of the reciprocal temperature.
--
-- Call 'error' if the reciprocal temperature is zero or negative.
heatFunction ::
  -- Cold Function.
  (a -> Log Double) ->
  -- Reciprocal temperature.
  Double ->
  -- The heated prior or likelihood function
  (a -> Log Double)
heatFunction f b
  | b <= 0 = error "heatFunction: Reciprocal temperature is zero or negative."
  -- Return the cold function unchanged to avoid taking the power.
  | b == 1.0 = f
  | otherwise = (** b') . f
  where
    -- The reciprocal temperature as a value in the log domain.
    b' = Exp $ log b

-- Set the reciprocal temperature of an MHG chain.
--
-- The prior and likelihood functions of the chain are replaced by heated
-- versions of the given cold functions, and the prior and likelihood values of
-- the current link are updated.
--
-- NOTE: The trace is not changed! In particular, the prior and likelihood
-- values are not updated for any link of the trace, and no new link is added to
-- the trace.
setReciprocalTemperature ::
  -- Cold prior function.
  PriorFunction a ->
  -- Cold likelihood function.
  LikelihoodFunction a ->
  -- New reciprocal temperature.
  Double ->
  MHG a ->
  MHG a
setReciprocalTemperature coldPrf coldLhf b a =
  MHG $
    c
      { priorFunction = prf',
        likelihoodFunction = lhf',
        link = Link x (prf' x) (lhf' x)
      }
  where
    c = fromMHG a
    -- We need twice the amount of computations compared to taking the power
    -- after calculating the posterior (pr x * lh x) ** b'. However, I don't
    -- think this is a serious problem.
    --
    -- To minimize computations, it is key to avoid modification of the
    -- reciprocal temperature for the cold chain.
    prf' = heatFunction coldPrf b
    lhf' = heatFunction coldLhf b
    -- The state itself is kept; only its prior and likelihood are recomputed.
    x = state $ link c

-- Initialize one MHG chain of an MC3 algorithm with the given reciprocal
-- temperature.
--
-- Call 'error' if the chain index is negative.
initMHG ::
  -- Cold prior function.
  PriorFunction a ->
  -- Cold likelihood function.
  LikelihoodFunction a ->
  -- Index of MHG chain.
  Int ->
  -- Reciprocal temperature.
  Double ->
  MHG a ->
  IO (MHG a)
initMHG prf lhf i beta a
  | i < 0 = error "initMHG: Chain index negative."
  -- Cold chain: The trace is left alone. The chain still passes through
  -- 'setReciprocalTemperature', which keeps the cold functions when the
  -- reciprocal temperature is 1.0.
  | i == 0 = return $ MHG c
  | otherwise = do
      -- We have to push the current link in the trace, since it is not set by
      -- 'setReciprocalTemperature'. The other links in the trace are still
      -- pointing to the link of the cold chain, but this has no effect.
      t' <- pushT l t
      return $ MHG $ c {trace = t'}
  where
    a' = setReciprocalTemperature prf lhf beta a
    c = fromMHG a'
    l = link c
    t = trace c

-- | Initialize an MC3 algorithm with a given number of chains.
--
-- All chains start at the same initial state. Each chain, and the MC3
-- algorithm itself, gets a separate random number generator split off the
-- given one.
--
-- Call 'error' if:
--
-- - The number of chains is one or lower.
--
-- - The swap period is zero or negative.
--
-- - The number of swaps is not in @[1, NChains - 1]@.
mc3 ::
  MC3Settings ->
  Settings ->
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  InitialState a ->
  StdGen ->
  IO (MC3 a)
mc3 sMc3 s pr lh cc mn i0 g
  | n < 2 = error "mc3: The number of chains must be two or larger."
  | sp < 1 = error "mc3: The swap period must be strictly positive."
  | sn < 1 || sn > n - 1 = error "mc3: The number of swaps must be in [1, NChains - 1]."
  -- Split random number generator. The stream of generators is infinite, so
  -- the first alternative is never reached.
  | otherwise = case unfoldr (Just . split) g of
      [] -> error "mc3: Bug; the stream of split generators is empty."
      (g0 : gRest) -> do
        -- Prepare MHG chains; one generator per chain.
        cs <- V.mapM (mhg s pr lh cc mn i0) (V.fromList $ take n gRest)
        hcs <- V.izipWithM (initMHG pr lh) (V.convert bs) cs
        -- Do not reuse the initial generator.
        gm <- newIOGenM g0
        -- Swaps are proposed between neighbors; the acceptance counts are keyed
        -- by the index of the left chain.
        return $ MC3 sMc3 hcs bs 0 (emptyA [0 .. n - 2]) gm
  where
    n = fromNChains $ mc3NChains sMc3
    sp = fromSwapPeriod $ mc3SwapPeriod sMc3
    sn = fromNSwaps $ mc3NSwaps sMc3
    -- NOTE: The initial choice of reciprocal temperatures is based on a few
    -- tests but otherwise pretty arbitrary.
    --
    -- NOTE: Have to 'take n' elements, because vectors are not as lazy as
    -- lists.
    bs = U.fromList $ take n $ iterate (* 0.97) 1.0

-- File name of the MC3 save of an analysis: the analysis name followed by the
-- extension @.mcmc.mc3@.
mc3Fn :: AnalysisName -> FilePath
mc3Fn (AnalysisName nm) = nm ++ ".mcmc.mc3"

-- | Save an MC3 algorithm.
--
-- The algorithm is encoded as JSON, compressed with gzip, and written to a
-- file named after the analysis with the extension @.mcmc.mc3@.
mc3Save ::
  ToJSON a =>
  AnalysisName ->
  MC3 a ->
  IO ()
mc3Save nm a = do
  savedMC3 <- toSavedMC3 a
  BL.writeFile (mc3Fn nm) $ compress $ encode savedMC3

-- | Load an MC3 algorithm.
--
-- The prior function, likelihood function, cycle and monitor are not saved and
-- have to be provided again.
--
-- Call 'error' if the save cannot be decoded.
--
-- NOTE: No backup of the save is created by this function.
--
-- See 'Mcmc.Mcmc.mcmcContinue'.
mc3Load ::
  FromJSON a =>
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  AnalysisName ->
  IO (MC3 a)
mc3Load pr lh cc mn nm = do
  savedMC3 <- eitherDecode . decompress <$> BL.readFile fn
  either error (fromSavedMC3 pr lh cc mn) savedMC3
  where
    -- fnBak = mc3Fn $ AnalysisName $ (fromAnalysisName nm ++ ".bak")
    fn = mc3Fn nm

-- Swap the states of two chains, and compute the Metropolis ratio of the swap.
--
-- Only the current links of the two chains are changed; the priors and
-- likelihoods of the swapped states are recomputed with the functions of the
-- receiving chain. Acceptance or rejection is up to the caller.
--
-- I call the chains left and right, because it is easy to think about them as
-- being left and right. Of course, the left chain may also have a larger index
-- than the right chain.
--
-- Call 'error' if an index is negative, or if the indices are equal.
swapWith ::
  -- Index i>=0 of left chain.
  Int ->
  -- Index j>=0, j/=i of right chain.
  Int ->
  MHGChains a ->
  (MHGChains a, Posterior)
swapWith i j xs
  | i < 0 = error "swapWith: Left index is negative."
  | j < 0 = error "swapWith: Right index is negative."
  | i == j = error "swapWith: Indices are equal."
  | otherwise = (xs', q)
  where
    -- Gather information from current chains.
    cl = fromMHG $ xs V.! i
    cr = fromMHG $ xs V.! j
    ll = link cl
    lr = link cr
    prl = prior ll
    prr = prior lr
    lhl = likelihood ll
    lhr = likelihood lr
    -- Swap the states.
    xl' = state lr
    xr' = state ll
    -- Compute new priors and likelihoods.
    prl' = priorFunction cl xl'
    prr' = priorFunction cr xr'
    lhl' = likelihoodFunction cl xl'
    lhr' = likelihoodFunction cr xr'
    -- Set the new links and the proposed state.
    ll' = Link xl' prl' lhl'
    lr' = Link xr' prr' lhr'
    cl' = cl {link = ll'}
    cr' = cr {link = lr'}
    xs' = xs V.// [(i, MHG cl'), (j, MHG cr')]
    -- Compute the Metropolis ratio.
    numerator = prl' * prr' * lhl' * lhr'
    denominator = prl * prr * lhl * lhr
    q = numerator / denominator

-- Propose a swap of the states of the chains with indices i and (i+1), and
-- accept or reject the swap.
--
-- The result of the proposal is recorded in the swap acceptance counts under
-- the index of the left chain. The chains are only changed if the swap is
-- accepted.
mc3ProposeSwap ::
  MC3 a ->
  -- Index of left chain.
  Int ->
  IO (MC3 a)
mc3ProposeSwap a i = do
  let cs = mc3MHGChains a
  -- -- Debug.
  -- prL = prior $ link $ fromMHG $ cs V.! i
  -- prR = prior $ link $ fromMHG $ cs V.! (i+1)
  -- lhL = likelihood $ link $ fromMHG $ cs V.! i
  -- lhR = likelihood $ link $ fromMHG $ cs V.! (i+1)
  -- 1. Sample new state and get the Metropolis ratio.
  let (!y, !r) = swapWith i (i + 1) cs
  -- 2. Accept or reject.
  accept <- mhgAccept r g
  if accept
    then do
      -- -- Debug.
      -- traceIO $ "Swap accepted: " <> show i <> " <-> " <> show (i+1)
      -- let prL' = prior $ link $ fromMHG $ y V.! i
      --     prR' = prior $ link $ fromMHG $ y V.! (i+1)
      --     lhL' = likelihood $ link $ fromMHG $ y V.! i
      --     lhR' = likelihood $ link $ fromMHG $ y V.! (i+1)
      -- traceIO $ "Log priors (left, right, before swap): " <> show (ln prL) <> " " <> show (ln prR)
      -- traceIO $ "Log priors (left, right, after swap): " <> show (ln prL') <> " " <> show (ln prR')
      -- traceIO $ "Log likelihoods (left, right, before swap): " <> show (ln lhL) <> " " <> show (ln lhR)
      -- traceIO $ "Log likelihood (left, right, after swap): " <> show (ln lhL') <> " " <> show (ln lhR')
      let !ac' = pushAccept i (mc3SwapAcceptance a)
      return $ a {mc3MHGChains = y, mc3SwapAcceptance = ac'}
    else do
      let !ac' = pushReject i (mc3SwapAcceptance a)
      return $ a {mc3SwapAcceptance = ac'}
  where
    g = mc3Generator a

-- The state of the MC3 algorithm is invalid if the state of any of its chains
-- is invalid.
mc3IsInvalidState :: ToJSON a => MC3 a -> Bool
mc3IsInvalidState a = V.any aIsInvalidState mhgs
  where
    mhgs = mc3MHGChains a

-- Perform one iteration of the MC3 algorithm:
--
-- 1. If the current iteration is a multiple of the swap period, propose swaps
--    between neighboring chains.
--
-- 2. Iterate all chains, and increment the iteration counter.
--
-- NOTE: 'mc3Iterate' is actually not parallel, but concurrent because of the IO
-- constraint of the mutable trace.
mc3Iterate ::
  ToJSON a =>
  IterationMode ->
  ParallelizationMode ->
  MC3 a ->
  IO (MC3 a)
mc3Iterate m pm a = do
  -- 1. Maybe propose swaps.
  --
  -- NOTE: Swaps have to be proposed first, because the traces are automatically
  -- updated at step 2.
  let s = mc3Settings a
  a' <-
    if mc3Iteration a `mod` fromSwapPeriod (mc3SwapPeriod s) == 0
      then do
        let n = V.length $ mc3MHGChains a
            -- Indices of all possible left chains.
            is = [0 .. n - 2]
            ns = fromNSwaps $ mc3NSwaps s
        -- Propose swaps for a random selection of left chains.
        is' <- shuffle is $ mc3Generator a
        foldM mc3ProposeSwap a (take ns is')
      else return a
  -- 2. Iterate all chains and increment iteration.
  mhgs <- case pm of
    Sequential -> V.mapM (aIterate m pm) (mc3MHGChains a')
    Parallel ->
      -- Go via a list, and use 'mapConcurrently'
      -- ("Control.Concurrent.Async").
      V.fromList
        <$> mapConcurrently (aIterate m pm) (V.toList $ mc3MHGChains a')
  let i = mc3Iteration a'
  return $ a' {mc3MHGChains = mhgs, mc3Iteration = succ i}

-- Update the reciprocal temperature of chain (i+1) such that the ratio to the
-- (already updated) reciprocal temperature of chain i is the old ratio raised
-- to the power of xi.
tuneBeta ::
  -- The old reciprocal temperatures are needed to retrieve the old ratios.
  ReciprocalTemperatures ->
  -- Index i of left chain. Change the reciprocal temperature of chain (i+1).
  Int ->
  -- Exponent xi of the reciprocal temperature ratio.
  Double ->
  -- The new reciprocal temperatures are updated incrementally using the
  -- reciprocal temperature ratios during the fold (see 'mc3AutoTune' below).
  ReciprocalTemperatures ->
  ReciprocalTemperatures
tuneBeta bsOld i xi bsNew =
  let right = i + 1
      -- Ratio of the old reciprocal temperatures of chains (i+1) and i.
      oldRatio = (bsOld U.! right) / (bsOld U.! i)
      -- The new ratio stays in (0,1) provided the old ratio is in (0,1) and xi
      -- is positive.
      newRatio = oldRatio ** xi
      -- Scale the new reciprocal temperature of the left chain.
      newRight = (bsNew U.! i) * newRatio
   in bsNew U.// [(right, newRight)]

-- Auto tune the proposals of all chains, and then the reciprocal temperatures
-- of all chains but the cold chain using the acceptance rates of the state
-- swaps.
mc3AutoTune :: ToJSON a => TuningType -> Int -> MC3 a -> IO (MC3 a)
mc3AutoTune b l a = do
  -- 1. Auto tune all chains.
  tunedChains <- V.mapM (aAutoTune b l) (mc3MHGChains a)
  -- 2. Auto tune temperatures.
  --
  -- Do not change the temperature, and the prior and likelihood functions of
  -- the cold chain. The prior and likelihood functions of the cold chain are
  -- handed to 'setReciprocalTemperature' for all other chains.
  let coldMHG = V.head tunedChains
      coldChain = fromMHG coldMHG
      reheat =
        setReciprocalTemperature
          (priorFunction coldChain)
          (likelihoodFunction coldChain)
      heatedChains =
        V.zipWith reheat (V.convert (U.tail tunedBetas)) (V.tail tunedChains)
  pure $
    a
      { mc3MHGChains = coldMHG `V.cons` heatedChains,
        mc3ReciprocalTemperatures = tunedBetas
      }
  where
    targetRate = getOptimalRate PDimensionUnknown
    swapRates = acceptanceRates (mc3SwapAcceptance a)
    oldBetas = mc3ReciprocalTemperatures a
    nChains = fromNChains (mc3NChains (mc3Settings a))
    -- We assume that the acceptance rate of state swaps between two chains is
    -- roughly proportional to the ratio of the temperatures of the chains.
    -- Hence, we focus on temperature ratios, actually reciprocal temperature
    -- ratios, which is the same. Also, by working with ratios in (0,1) of
    -- neighboring chains, we ensure the monotonicity of the reciprocal
    -- temperatures.
    --
    -- The factor (1/2) was determined by a few tests and is otherwise
    -- absolutely arbitrary.
    --
    -- If no acceptance rate is available for a swap, the exponent is 1.0, and
    -- the ratio is left unchanged.
    exponentFor k =
      maybe 1.0 (\rate -> exp ((rate - targetRate) / 2)) (swapRates M.! k)
    -- Walk over the neighboring pairs from cold to hot; each step reads the
    -- old ratio and the freshly updated reciprocal temperature of the left
    -- chain (see 'tuneBeta').
    tunedBetas =
      foldl'
        (\acc k -> tuneBeta oldBetas k (exponentFor k) acc)
        oldBetas
        [0 .. nChains - 2]

-- Reset the acceptance counts of the proposals of all chains, and the
-- acceptance counts of the state swaps.
mc3ResetAcceptance :: ToJSON a => MC3 a -> MC3 a
mc3ResetAcceptance a =
  a
    { -- 1. Reset acceptance of all chains.
      mc3MHGChains = V.map aResetAcceptance (mc3MHGChains a),
      -- 2. Reset acceptance of swaps.
      mc3SwapAcceptance = resetA (mc3SwapAcceptance a)
    }

-- Apply 'aCleanAfterBurnIn' with the given trace length to every chain, and
-- store the resulting chains.
mc3CleanAfterBurnIn :: ToJSON a => TraceLength -> MC3 a -> IO (MC3 a)
mc3CleanAfterBurnIn tl a =
  (\cleaned -> a {mc3MHGChains = cleaned})
    <$> V.mapM (aCleanAfterBurnIn tl) (mc3MHGChains a)

-- Information in cycle summary (one entry per line):
--
-- - The complete summary of the cycle of the cold chain (the first chain).
--
-- - The acceptance rate averaged across all chains, where each chain
--   contributes the mean acceptance rate of its proposals. This line is omitted
--   if any acceptance rate is not available.
--
-- - The reciprocal temperatures of the chains and the acceptance rates of the
--   state swaps.
mc3SummarizeCycle :: ToJSON a => IterationMode -> MC3 a -> BL.ByteString
mc3SummarizeCycle m a =
  BL.intercalate "\n" . concat $
    [ coldChainLines,
      averageRateLines,
      swapIntroLines,
      map swapLine [0 .. nChains - 2],
      [hLine]
    ]
  where
    chains = mc3MHGChains a
    settings = mc3Settings a
    betas = mc3ReciprocalTemperatures a
    nChains = fromNChains (mc3NChains settings)
    -- Doubles are printed in standard notation with two decimals.
    formatRate = BB.toLazyByteString . BB.formatDouble (BB.standard 2)
    -- Horizontal line of the same length as the proposal header.
    hLine = BL.replicate (BL.length proposalHeader) '-'
    coldChainLines =
      [ "MC3: Cycle of cold chain.",
        aSummarizeCycle m (V.head chains)
      ]
    -- Acceptance rates may be 'Nothing' when no proposals have been undertaken.
    -- Traversing pulls the 'Nothing's out of the inner structures, so that a
    -- single missing rate makes the complete result 'Nothing'.
    perChainRates =
      traverse (sequence . acceptanceRates . acceptance . fromMHG) chains
    -- Mean acceptance rate of the proposals of each chain.
    chainMeans =
      V.map (\rates -> sum rates / fromIntegral (length rates)) <$> perChainRates
    -- Mean of the per-chain means.
    overallMean =
      (\means -> V.sum means / fromIntegral (V.length means)) <$> chainMeans
    averageRateLines =
      maybe
        []
        ( \rate ->
            [ "MC3: Average acceptance rate across all chains: "
                <> formatRate rate
                <> "."
            ]
        )
        overallMean
    swapPeriodB =
      BB.toLazyByteString (BB.intDec (fromSwapPeriod (mc3SwapPeriod settings)))
    swapIntroLines =
      [ "MC3: Reciprocal temperatures of the chains: "
          <> BL.intercalate ", " (map formatRate (U.toList betas))
          <> ".",
        "MC3: Summary of state swaps.",
        "MC3: The swap period is " <> swapPeriodB <> ".",
        "MC3: The state swaps are executed in random order.",
        proposalHeader,
        hLine
      ]
    -- Summary line for the state swap between chains k and (k+1). The
    -- reciprocal temperature of chain (k+1) is passed as the fourth argument
    -- of 'summarizeProposal'.
    swapLine k =
      summarizeProposal
        (PName (show k ++ " <-> " ++ show (k + 1)))
        (PDescription "Swap states between chains")
        (pWeight 1)
        (Just (betas U.! (k + 1)))
        PDimensionUnknown
        (acceptanceRate k (mc3SwapAcceptance a))

-- Open the monitor of every chain. The analysis name and the zero-padded,
-- two-digit chain index are passed to 'mOpen'.
--
-- No extra monitors are opened.
mc3OpenMonitors :: ToJSON a => AnalysisName -> ExecutionMode -> MC3 a -> IO (MC3 a)
mc3OpenMonitors nm em a =
  (\opened -> a {mc3MHGChains = opened})
    <$> V.imapM openChainMonitor (mc3MHGChains a)
  where
    -- Open the monitor of the chain with the given index, and store the opened
    -- monitor in the chain.
    openChainMonitor idx (MHG chain) = do
      opened <- mOpen (fromAnalysisName nm) (printf "%02d" idx) em (monitor chain)
      pure $ MHG chain {monitor = opened}

-- Execute the monitors of all chains, and return the result of the first
-- chain.
mc3ExecuteMonitors ::
  ToJSON a =>
  Verbosity ->
  -- Starting time.
  UTCTime ->
  -- Total number of iterations.
  Int ->
  MC3 a ->
  IO (Maybe BL.ByteString)
mc3ExecuteMonitors vb t0 iTotal a =
  V.head <$> V.imapM runMonitors (mc3MHGChains a)
  where
    runMonitors idx = aExecuteMonitors (verbosityFor idx) t0 iTotal
    -- The first chain honors verbosity. All other chains are to be quiet.
    verbosityFor :: Int -> Verbosity
    verbosityFor 0 = vb
    verbosityFor _ = Quiet

-- The header of the standard monitor is the one of the first chain.
mc3StdMonitorHeader :: ToJSON a => MC3 a -> BL.ByteString
mc3StdMonitorHeader a = aStdMonitorHeader (V.head (mc3MHGChains a))

-- Close the monitors of all chains, and store the resulting chains.
mc3CloseMonitors :: ToJSON a => MC3 a -> IO (MC3 a)
mc3CloseMonitors a =
  V.mapM aCloseMonitors (mc3MHGChains a) >>= \closed ->
    pure $ a {mc3MHGChains = closed}