{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}

-- |
-- Module      :  Mcmc.MarginalLikelihood
-- Description :  Calculate the marginal likelihood
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Mon Jan 11 16:34:18 2021.
module Mcmc.MarginalLikelihood
  ( MarginalLikelihood,
    NPoints (..),
    MLAlgorithm (..),
    MLSettings (..),
    marginalLikelihood,
  )
where

import Control.Concurrent (getNumCapabilities)
import Control.Concurrent.Async hiding (link)
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.Reader
import Data.Aeson
import Data.List hiding (cycle)
import qualified Data.Map.Strict as M
import qualified Data.Vector as VB
import qualified Data.Vector.Unboxed as VU
import Mcmc.Acceptance
import Mcmc.Algorithm.MHG
import Mcmc.Chain.Chain
import Mcmc.Chain.Link
import Mcmc.Chain.Trace
import Mcmc.Cycle
import Mcmc.Environment
import Mcmc.Likelihood
import Mcmc.Logger
import Mcmc.Mcmc
import Mcmc.Monitor
import Mcmc.Prior
import Mcmc.Settings
import Numeric.Log hiding (sum)
import System.Directory
import System.Random.Stateful
import Text.Printf
import Prelude hiding (cycle)

-- | Marginal likelihood values. They are stored in the log domain (see
-- "Numeric.Log") to avoid underflow of very small likelihoods.
type MarginalLikelihood = Log Double

-- A reciprocal temperature (often called beta) on the path of the integral.
-- The points produced by 'getPoints' lie in the interval [0, 1].
type Point = Double

-- | The number of points used to approximate the path integral.
newtype NPoints = NPoints {fromNPoints :: Int}
  deriving (Eq, Read, Show)

-- | Algorithms to calculate the marginal likelihood.
data MLAlgorithm
  = -- | Use a classical path integral, also known as thermodynamic
    -- integration. In particular, /Annealing-Melting Integration/ is used.
    --
    -- See Lartillot, N., & Philippe, H., Computing Bayes Factors Using
    -- Thermodynamic Integration, Systematic Biology, 55(2), 195–207 (2006).
    -- http://dx.doi.org/10.1080/10635150500433722
    ThermodynamicIntegration
  | -- | Use stepping stone sampling.
    --
    -- See Xie, W., Lewis, P. O., Fan, Y., Kuo, L., & Chen, M., Improving
    -- marginal likelihood estimation for Bayesian phylogenetic model selection,
    -- Systematic Biology, 60(2), 150–160 (2011).
    -- http://dx.doi.org/10.1093/sysbio/syq085
    --
    -- Or Fan, Y., Wu, R., Chen, M., Kuo, L., & Lewis, P. O., Choosing among
    -- partition models in bayesian phylogenetics, Molecular Biology and
    -- Evolution, 28(1), 523–532 (2011). http://dx.doi.org/10.1093/molbev/msq224
    SteppingStoneSampling
  deriving (Eq, Read, Show)

-- | Settings of the marginal likelihood estimation.
data MLSettings = MLSettings
  { -- | Name of the analysis. The samplers run at the individual points are
    -- named after it, with a suffix identifying the point.
    mlAnalysisName :: AnalysisName,
    -- | Algorithm used to estimate the marginal likelihood.
    mlAlgorithm :: MLAlgorithm,
    -- | Number of points (reciprocal temperatures) along the path.
    mlNPoints :: NPoints,
    -- | Initial burn in at the starting point of the path (or each segment if
    -- running in parallel).
    mlInitialBurnIn :: BurnInSettings,
    -- | Repetitive burn in at each point on the path.
    mlPointBurnIn :: BurnInSettings,
    -- | The number of iterations performed at each point.
    mlIterations :: Iterations,
    -- | Execution mode; it is also handed on to the samplers at the points.
    mlExecutionMode :: ExecutionMode,
    -- | Whether the points are sampled sequentially or in parallel chunks.
    mlParallelizationMode :: ParallelizationMode,
    -- | Log mode of the estimation. The samplers at the individual points
    -- always log to file only.
    mlLogMode :: LogMode,
    -- | Verbosity. The samplers at the individual points only log when this is
    -- 'Debug'.
    mlVerbosity :: Verbosity
  }
  deriving (Eq, Read, Show)

-- The analysis name is stored directly in the settings record.
instance HasAnalysisName MLSettings where
  getAnalysisName = mlAnalysisName

-- The execution mode is stored directly in the settings record.
instance HasExecutionMode MLSettings where
  getExecutionMode = mlExecutionMode

-- The log mode is stored directly in the settings record.
instance HasLogMode MLSettings where
  getLogMode = mlLogMode

-- The verbosity is stored directly in the settings record.
instance HasVerbosity MLSettings where
  getVerbosity = mlVerbosity

-- Monad of the marginal likelihood estimation: read access to the estimation
-- environment on top of IO.
type ML = ReaderT (Environment MLSettings) IO

-- Shape parameter of the skewed beta distribution used by 'getPoints' to
-- place the points. Alpha=0.3 is the standard choice.
alpha :: Double
alpha = 0.3

-- Distribute the points according to a skewed beta distribution with given
-- 'alpha' value. Point @i@ (for @i = 0, ..., k-1@) is @(i / (k-1)) ** (1/alpha)@,
-- so the first point is exactly 0.0 and the last point is exactly 1.0. If
-- alpha is below 1.0, more points are placed at lower values, which is
-- desired. It is inconvenient that the reciprocal temperatures are denoted as
-- beta, and we also use the beta distribution :). Don't mix them up!
--
-- At least two points are required; otherwise the spacing @1 / (k-1)@ is
-- undefined (one point used to yield @[NaN]@, and zero points an empty path).
--
-- See discussion in Xie, W., Lewis, P. O., Fan, Y., Kuo, L., & Chen, M.,
-- Improving marginal likelihood estimation for bayesian phylogenetic model
-- selection, Systematic Biology, 60(2), 150–160 (2011).
-- http://dx.doi.org/10.1093/sysbio/syq085
--
-- Or Figure 1 in Höhna, S., Landis, M. J., & Huelsenbeck, J. P., Parallel power
-- posterior analyses for fast computation of marginal likelihoods in
-- phylogenetics (2017). http://dx.doi.org/10.1101/104422
getPoints :: NPoints -> [Point]
getPoints (NPoints k)
  -- Fail fast instead of silently running a whole analysis on NaN points.
  | k < 2 = error "getPoints: At least two points are required."
  | otherwise = [(fromIntegral i / fromIntegral k1) ** (1.0 / alpha) | i <- [0 .. k1]]
  where
    k1 = k - 1

-- | Run the MCMC sampler at a single point (reciprocal temperature) @x@.
--
-- The chain of the given algorithm is reset: its iteration counter and start
-- are set to 0, and its likelihood function is replaced by the power
-- likelihood @lhf ** x@. At @x == 0.0@ the likelihood is not evaluated at all
-- and is the constant 1. The likelihood of the current link is recomputed
-- with the new function.
--
-- After sampling, a cycle summary is logged at debug level. Unless this is
-- the initial burn in, warnings are logged when acceptance rates are
-- unavailable, at or below 0.1, or at or above 0.9.
sampleAtPoint ::
  ToJSON a =>
  Bool ->
  Point ->
  Settings ->
  LikelihoodFunction a ->
  MHG a ->
  ML (MHG a)
sampleAtPoint isInitialBurnIn x ss lhf a = do
  aSampled <- liftIO $ mcmc ssPoint (MHG chainPoint)
  let chainSampled = fromMHG aSampled
      acs = acceptances chainSampled
  logDebugB "sampleAtPoint: Summarize cycle."
  logDebugB $ summarizeCycle AllProposals acs $ cycle chainSampled
  unless isInitialBurnIn $ checkRates $ sequence $ acceptanceRates acs
  return aSampled
  where
    -- Warn about missing, very low, or very high acceptance rates.
    checkRates Nothing = logWarnB "Some acceptance rates are unavailable."
    checkRates (Just rates) = do
      unless (M.null $ M.filter (<= 0.1) rates) $
        logWarnB "Some acceptance rates are below 0.1."
      unless (M.null $ M.filter (>= 0.9) rates) $
        logWarnB "Some acceptance rates are above 0.9."
    -- For debugging, give every point its own analysis name.
    pointName :: Point -> AnalysisName
    pointName y = sAnalysisName ss <> AnalysisName ("/" <> printf "point%.8f" y)
    ssPoint = ss {sAnalysisName = pointName x}
    -- Power likelihood at this point. Don't calculate the likelihood when the
    -- point is 0.0.
    lhfPowered
      | x == 0.0 = const 1.0
      | otherwise = (** Exp (log x)) . lhf
    chainIn = fromMHG a
    linkIn = link chainIn
    chainPoint =
      chainIn
        { -- Important: Update the likelihood using the new likelihood function.
          link = linkIn {likelihood = lhfPowered $ state linkIn},
          iteration = 0,
          start = 0,
          likelihoodFunction = lhfPowered
        }

-- | Sample at each point in turn.
--
-- Every point starts from the algorithm state reached at the previous point.
-- For each point, the (unpowered) likelihoods of the last @sIterations ss@
-- links of the trace are returned, in the order of the given points.
traversePoints ::
  ToJSON a =>
  NPoints ->
  [(Int, Point)] ->
  Settings ->
  LikelihoodFunction a ->
  MHG a ->
  -- For each point a vector of obtained likelihoods stored in the log domain.
  ML [VU.Vector Likelihood]
traversePoints (NPoints nTotal) points ss lhf = go points
  where
    nSamples = fromIterations $ sIterations ss
    go [] _ = return []
    go ((idb, b) : rest) alg = do
      logInfoS (printf "Point %4d of %4d: %.12f." idb nTotal b :: String)
      alg' <- sampleAtPoint False b ss lhf alg
      -- Get the links sampled at this point.
      links <- liftIO $ takeT nSamples $ trace $ fromMHG alg'
      -- Extract the likelihoods.
      --
      -- NOTE: This could be sped up by mapping (** -b) on the power likelihoods.
      --
      -- NOTE: This bang is an important one, because if the lhs are not
      -- strictly calculated here, the complete MCMC runs are dragged along
      -- before doing so resulting in a severe memory leak.
      let !lhs = VU.convert (VB.map (lhf . state) links) :: VU.Vector Likelihood
      -- Sample the remaining points, continuing from the current state.
      (lhs :) <$> go rest alg'

-- Split a list into at most @k@ contiguous, non-empty chunks whose lengths
-- differ by at most one; the longer chunks come first. Assumes @k > 0@.
nChunks :: Int -> [a] -> [[a]]
nChunks k xs = chop sizes xs
  where
    sizes = chunks k (length xs)

-- Distribute @n@ items over @c@ chunks as evenly as possible. The first
-- @n `mod` c@ chunks receive one extra item, and empty chunks are dropped.
chunks :: Int -> Int -> [Int]
chunks c n = filter (> 0) (larger ++ smaller)
  where
    (q, r) = n `divMod` c
    larger = replicate r (q + 1)
    smaller = replicate (c - r) q

-- Cut a list into consecutive pieces of the given (positive) lengths. Errors
-- out on a non-positive length, or when lengths are exhausted before the list.
chop :: [Int] -> [a] -> [[a]]
chop [] [] = []
chop (n : ns) xs
  | n > 0 =
      let (piece, remainder) = splitAt n xs
       in piece : chop ns remainder
  | otherwise = error "chop: n negative or zero"
chop _ _ = error "chop: not all list elements handled"

-- | Sample all given points, either sequentially or split into contiguous
-- chunks that are processed concurrently (one chunk per capability).
--
-- Every chunk runs 'mlRun', that is, an initial burn in followed by a
-- traversal of its points. Each chunk gets its own, independent random
-- generator derived from @g@. Previously all chunks reused @g@, so parallel
-- segments shared identical random streams. The results are concatenated in
-- the order of the points.
mlRunPar ::
  ToJSON a =>
  ParallelizationMode ->
  NPoints ->
  [(Int, Point)] ->
  ExecutionMode ->
  Verbosity ->
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  a ->
  StdGen ->
  ML [VU.Vector Likelihood]
mlRunPar pm k xs em vb prf lhf cc mn i0 g = do
  nThreads <- case pm of
    Sequential -> do
      logInfoB "Sequential execution."
      pure 1
    Parallel -> do
      n <- liftIO getNumCapabilities
      logInfoS $ "Parallel execution with " <> show n <> " cores."
      pure n
  let xsChunks = nChunks nThreads xs
      gs = splitGens (length xsChunks) g
  r <- ask
  xss <-
    liftIO $
      mapConcurrently
        (\(thesePoints, gChunk) -> runReaderT (mlRun k thesePoints em vb prf lhf cc mn i0 gChunk) r)
        (zip xsChunks gs)
  pure $ concat xss
  where
    -- Derive independent generators by repeated splitting. A single chunk
    -- keeps the original generator, so sequential runs are unaffected.
    splitGens :: Int -> StdGen -> [StdGen]
    splitGens nGens gen
      | nGens <= 1 = [gen]
      | otherwise = take nGens $ unfoldr (Just . split) gen

-- | Sample a contiguous list of points with a single chain.
--
-- An MHG algorithm is initialized and burned in at the first point using
-- 'mlInitialBurnIn'. All points, including the first one, are then traversed
-- with 'mlPointBurnIn' and 'mlIterations'. The sub-samplers never save,
-- always log to file only, and are quiet unless the verbosity is 'Debug'.
--
-- An empty list of points yields an empty result; this previously crashed
-- with @head []@.
mlRun ::
  ToJSON a =>
  NPoints ->
  [(Int, Point)] ->
  ExecutionMode ->
  Verbosity ->
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  a ->
  StdGen ->
  -- For each point a vector of likelihoods stored in log domain.
  ML [VU.Vector Likelihood]
mlRun _ [] _ _ _ _ _ _ _ _ = pure []
mlRun k xs@((id0, x0) : _) em vb prf lhf cc mn i0 g = do
  logDebugB "mlRun: Begin."
  s <- reader settings
  let nm = mlAnalysisName s
      is = mlIterations s
      biI = mlInitialBurnIn s
      biP = mlPointBurnIn s
      -- Only log sub MCMC samplers when debugging.
      vb' = case vb of
        Debug -> Debug
        _ -> Quiet
      trLen = TraceMinimum $ fromIterations is
      -- The initial run only burns in; it performs no regular iterations.
      ssI = Settings nm biI (Iterations 0) trLen em Sequential NoSave LogFileOnly vb'
      ssP = Settings nm biP is trLen em Sequential NoSave LogFileOnly vb'
  logDebugB "mlRun: Initialize MHG algorithm."
  a0 <- liftIO $ mhg ssI prf lhf cc mn i0 g
  let msg = printf "Initial burn in at point %.12f with ID %4d." x0 id0 :: String
  logInfoS msg
  a1 <- sampleAtPoint True x0 ssI lhf a0
  logDebugB "mlRun: Traverse points."
  traversePoints k xs ssP lhf a1

-- Integrate piecewise linearly over the given points, i.e., apply the
-- trapezoidal rule: 0.5 * sum of (y_i + y_{i+1}) * (x_{i+1} - x_i). Consecutive
-- pairs are used until either list runs out; decreasing x values give a
-- negative integral. Use lists since the number of points is expected to be
-- low.
integrateSimpsonTriangle ::
  -- X values.
  [Point] ->
  -- Y values.
  [Double] ->
  -- Integral.
  Double
integrateSimpsonTriangle xs ys = 0.5 * foldr (+) 0 areas
  where
    heights = zipWith (+) ys (drop 1 ys)
    widths = zipWith (-) (drop 1 xs) xs
    -- Summed from the right with 'foldr' to keep the floating point order.
    areas = zipWith (*) heights widths

-- | Estimate the marginal likelihood with thermodynamic integration.
--
-- The path is traversed forwards (annealing, from 0 to 1) and backwards
-- (melting, from 1 to 0) concurrently, with independent generators. The
-- result is the mean of both integral estimates, returned in the log domain.
tiWrapper ::
  ToJSON a =>
  MLSettings ->
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  a ->
  StdGen ->
  ML MarginalLikelihood
tiWrapper s prf lhf cc mn i0 g = do
  logInfoB "Path integral (thermodynamic integration)."
  env <- ask
  let (gForward, gBackward) = split g
      -- Sample all points of one direction of the path.
      runPath gen bs = runReaderT (mlRunPar pm k (zip [1 ..] bs) em vb prf lhf cc mn i0 gen) env
  -- Parallel execution of both path integrals.
  (lhssForward, lhssBackward) <-
    lift $ concurrently (runPath gForward bsForward) (runPath gBackward bsBackward)
  logInfoEndTime

  logDebugB "tiWrapper: Calculate mean log likelihoods."
  -- The log likelihoods (not the likelihoods) have to be averaged: the
  -- derivative of the log of the normalizing constant of the power posterior
  -- with respect to beta is the expected log likelihood under that power
  -- posterior. Integrating these expectations from 0 to 1 thus yields the log
  -- marginal likelihood (Lartillot & Philippe, 2006).
  let meanLogLh :: VU.Vector Likelihood -> Double
      meanLogLh v = VU.sum (VU.map ln v) / fromIntegral (VU.length v)
      mlForward = integrateSimpsonTriangle bsForward (map meanLogLh lhssForward)
      -- The backward points decrease, so the integral is negative.
      mlBackward = negate $ integrateSimpsonTriangle bsBackward (map meanLogLh lhssBackward)
      mean = 0.5 * (mlForward + mlBackward)
  logDebugS $ "tiWrapper: Marginal log likelihood of forward integral: " ++ show mlForward
  logDebugS $ "tiWrapper: Marginal log likelihood of backward integral: " ++ show mlBackward
  logDebugS $ "tiWrapper: The mean is: " ++ show mean
  return $ Exp mean
  where
    k = mlNPoints s
    bsForward = getPoints k
    bsBackward = reverse bsForward
    em = mlExecutionMode s
    pm = mlParallelizationMode s
    vb = mlVerbosity s

-- Raise a log domain value to a real power: @pow' x p@ represents @x ^ p@
-- and is computed by scaling the logarithm of @x@ by @p@.
pow' :: Log Double -> Double -> Log Double
pow' (Exp logX) p = Exp (logX * p)

-- Stepping stone estimate of the marginal likelihood.
--
-- For each pair of consecutive points beta_{k-1}, beta_k the ratio is
-- estimated as (1/n) * sum_i L_i ^ (beta_k - beta_{k-1}), where the n
-- likelihoods L_i are those collected at beta_{k-1}. The estimate is the
-- product of all ratios (computed in the log domain).
--
-- See Xie2010 p. 153, bottom left.
sssCalculateMarginalLikelihood :: [Point] -> [VU.Vector Likelihood] -> MarginalLikelihood
sssCalculateMarginalLikelihood xs lhss = product $ zipWith3 ratio xs (drop 1 xs) lhss
  where
    -- ratio beta_{k-1} beta_k lhs_{k-1}
    ratio :: Point -> Point -> VU.Vector Likelihood -> MarginalLikelihood
    ratio bkm1 bk lhs = recip nSamples * VU.sum (VU.map (`pow'` dbeta) lhs)
      where
        nSamples = fromIntegral $ VU.length lhs
        dbeta = bk - bkm1

-- -- Improve numerical stability by factoring out lhMax. However, no
-- -- improvement over the standard version was observed.
--
-- f bkm1 bk lhs = n1 * pow' lhMax dbeta * VU.sum lhsNormedPowered
--   where n1 = recip $ fromIntegral $ VU.length lhs
--         lhMax = VU.maximum lhs
--         dbeta = bk - bkm1
--         lhsNormed = VU.map (/lhMax) lhs
--         lhsNormedPowered = VU.map (`pow'` dbeta) lhsNormed

-- -- Computation of the log of the marginal likelihood. According to the paper,
-- -- this estimator is biased and I did not observe any improvements compared
-- -- to the direct estimator implemented above.
--
-- -- See Xie2010 p. 153, top right.
-- sssCalculateMarginalLikelihood' :: [Point] -> [VU.Vector Likelihood] -> MarginalLikelihood
-- sssCalculateMarginalLikelihood' xs lhss = Exp $ sum $ zipWith3 f xs (tail xs) lhss
--   where f :: Point -> Point -> VU.Vector Likelihood -> Double
--         -- f beta_{k-1} beta_k lhs_{k-1}
--         f bkm1 bk lhs = dbeta * llhMax + log (n1 * VU.sum lhsNormedPowered)
--           where dbeta = bk - bkm1
--                 llhMax = ln $ VU.maximum lhs
--                 n1 = recip $ fromIntegral $ VU.length lhs
--                 llhs = VU.map ln lhs
--                 llhsNormed = VU.map (\x -> x - llhMax) llhs
--                 lhsNormedPowered = VU.map (\x -> exp $ dbeta * x) llhsNormed
-- | Estimate the marginal likelihood using stepping stone sampling.
--
-- Likelihood values are collected with 'mlRunPar' at every point returned by
-- 'getPoints' except the last one. The sampled points are numbered
-- consecutively, starting at 1. The collected values are then combined by
-- 'sssCalculateMarginalLikelihood' using the full list of points.
--
-- The parallelization mode, execution mode, verbosity and number of points
-- are all taken from the given 'MLSettings'.
sssWrapper ::
  ToJSON a =>
  MLSettings ->
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  a ->
  StdGen ->
  ML MarginalLikelihood
sssWrapper s prf lhf cc mn i0 g = do
  logInfoB "Stepping stone sampling."
  lhss <-
    mlRunPar
      (mlParallelizationMode s)
      nPoints
      sampledPoints
      (mlExecutionMode s)
      (mlVerbosity s)
      prf
      lhf
      cc
      mn
      i0
      g
  logInfoB "The last point does not need to be sampled with stepping stone sampling."
  logDebugB "sssWrapper: Calculate marginal likelihood."
  pure (sssCalculateMarginalLikelihood allPoints lhss)
  where
    nPoints :: NPoints
    nPoints = mlNPoints s
    -- Every point of the path, in the order produced by 'getPoints'.
    allPoints :: [Point]
    allPoints = getPoints nPoints
    -- The final point is not sampled, so it is dropped with 'init'.
    -- NOTE(review): 'init' is partial and fails if 'getPoints' returns [].
    sampledPoints :: [(Int, Point)]
    sampledPoints = zip [1 ..] (init allPoints)

-- | Estimate the marginal likelihood.
--
-- The estimator is selected by 'mlAlgorithm': either thermodynamic
-- integration or stepping stone sampling. Before estimating, the analysis
-- directory (named after 'mlAnalysisName') is created if it does not exist.
-- The natural logarithm of the estimate is logged, and the estimate itself
-- (in the log domain) is returned. The simulation variance is not computed.
marginalLikelihood ::
  ToJSON a =>
  MLSettings ->
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  InitialState a ->
  StdGen ->
  IO MarginalLikelihood
marginalLikelihood s prf lhf cc mn i0 g = do
  -- Initialize the environment (settings and logging) for this analysis.
  env <- initializeEnvironment s
  -- The analysis directory must exist before any output is written.
  createDirectoryIfMissing True analysisDir
  runReaderT estimate env
  where
    analysisDir :: String
    analysisDir = fromAnalysisName $ mlAnalysisName s

    -- Dispatch to the estimator requested in the settings.
    runAlgorithm :: ML MarginalLikelihood
    runAlgorithm = case mlAlgorithm s of
      ThermodynamicIntegration -> tiWrapper s prf lhf cc mn i0 g
      SteppingStoneSampling -> sssWrapper s prf lhf cc mn i0 g

    estimate :: ML MarginalLikelihood
    estimate = do
      logInfoStartingTime
      logInfoB "Estimate marginal likelihood."
      logDebugB "marginalLikelihood: The marginal likelihood settings are:"
      logDebugS (show s)
      val <- runAlgorithm
      logInfoS ("Marginal log likelihood: " ++ show (ln val))
      -- TODO (low): Simulation variance.
      logInfoS "The simulation variance is not yet available."
      pure val