{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}

-- |
-- Module      :  Mcmc.MarginalLikelihood
-- Description :  Calculate the marginal likelihood
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Mon Jan 11 16:34:18 2021.
module Mcmc.MarginalLikelihood
  ( MarginalLikelihood,
    NPoints (..),
    MLAlgorithm (..),
    MLSettings (..),
    marginalLikelihood,
  )
where

import Control.Concurrent.Async hiding (link)
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.Reader
import Data.Aeson
import Data.List hiding (cycle)
import qualified Data.Map.Strict as M
import qualified Data.Vector as VB
import qualified Data.Vector.Unboxed as VU
import Mcmc.Acceptance
import Mcmc.Algorithm.MHG
import Mcmc.Chain.Chain
import Mcmc.Chain.Link
import Mcmc.Chain.Trace
import Mcmc.Cycle
import Mcmc.Environment
import Mcmc.Likelihood
import Mcmc.Logger
import Mcmc.Mcmc
import Mcmc.Monitor
import Mcmc.Prior
import Mcmc.Settings
import Numeric.Log hiding (sum)
import System.Directory
import System.Random.Stateful
import Text.Printf
import Prelude hiding (cycle)

-- | Marginal likelihood values are stored in log domain.
--
-- Use 'ln' to obtain the marginal log likelihood as a plain 'Double'.
type MarginalLikelihood = Log Double

-- Reciprocal temperature value traversed along the path integral.
--
-- At a given point, the sampler uses the likelihood function raised to the
-- power of the point; see 'sampleAtPoint'.
type Point = Double

-- | The number of points used to approximate the path integral.
--
-- The wrapped value is the total number of reciprocal temperatures laid out
-- by 'getPoints'.
newtype NPoints = NPoints
  { -- | Unwrap the number of points.
    fromNPoints :: Int
  }
  deriving (Eq, Read, Show)

-- | Algorithms to calculate the marginal likelihood.
data MLAlgorithm
  = ThermodynamicIntegration
  -- ^ Use a classical path integral. Also known as thermodynamic integration.
  -- In particular, /Annealing-Melting Integration/ is used.
  --
  -- See Lartillot, N., & Philippe, H., Computing Bayes Factors Using
  -- Thermodynamic Integration, Systematic Biology, 55(2), 195–207 (2006).
  -- http://dx.doi.org/10.1080/10635150500433722
  | SteppingStoneSampling
  -- ^ Use stepping stone sampling.
  --
  -- See Xie, W., Lewis, P. O., Fan, Y., Kuo, L., & Chen, M., Improving
  -- marginal likelihood estimation for Bayesian phylogenetic model selection,
  -- Systematic Biology, 60(2), 150–160 (2010).
  -- http://dx.doi.org/10.1093/sysbio/syq085
  --
  -- Or Fan, Y., Wu, R., Chen, M., Kuo, L., & Lewis, P. O., Choosing among
  -- partition models in bayesian phylogenetics, Molecular Biology and
  -- Evolution, 28(1), 523–532 (2010). http://dx.doi.org/10.1093/molbev/msq224
  deriving (Eq, Read, Show)

-- | Settings of the marginal likelihood estimation.
data MLSettings = MLSettings
  { -- | Name of the analysis. The samplers run at the individual points use
    -- this name extended by a point specific suffix.
    mlAnalysisName :: AnalysisName,
    -- | Algorithm used to estimate the marginal likelihood.
    mlAlgorithm :: MLAlgorithm,
    -- | Number of points along the path.
    mlNPoints :: NPoints,
    -- | Initial burn in at the starting point of the path.
    mlInitialBurnIn :: BurnInSettings,
    -- | Repetitive burn in at each point on the path.
    mlPointBurnIn :: BurnInSettings,
    -- | The number of iterations performed at each point.
    mlIterations :: Iterations,
    -- | Execution mode; also handed to the samplers run at each point.
    mlExecutionMode :: ExecutionMode,
    -- | Log mode of the marginal likelihood estimation.
    mlLogMode :: LogMode,
    -- | Verbosity of the marginal likelihood estimation. The samplers run at
    -- each point are quiet unless the verbosity is 'Debug'.
    mlVerbosity :: Verbosity
  }
  deriving (Eq, Read, Show)

-- The analysis name is stored in the 'mlAnalysisName' field.
instance HasAnalysisName MLSettings where
  getAnalysisName = mlAnalysisName

-- The execution mode is stored in the 'mlExecutionMode' field.
instance HasExecutionMode MLSettings where
  getExecutionMode = mlExecutionMode

-- The log mode is stored in the 'mlLogMode' field.
instance HasLogMode MLSettings where
  getLogMode = mlLogMode

-- The verbosity is stored in the 'mlVerbosity' field.
instance HasVerbosity MLSettings where
  getVerbosity = mlVerbosity

-- Monad of the marginal likelihood estimation: 'IO' actions with read access
-- to the environment, which carries the 'MLSettings'.
type ML a = ReaderT (Environment MLSettings) IO a

-- Shape parameter of the point distribution used by 'getPoints'. A value of
-- 0.3 is the standard choice.
alpha :: Double
alpha = 0.3

-- Distribute the points according to a skewed beta distribution with given
-- 'alpha' value. If alpha is below 1.0, more points at lower values, which is
-- desired. It is inconvenient that the reciprocal temperatures are denoted as
-- beta, and we also use the beta distribution :). Don't mix them up!
--
-- The points are returned in ascending order; the first point is 0.0 and the
-- last point is 1.0.
--
-- At least two points are required. With a single point the spacing between
-- points is undefined (division by zero, resulting in NaN), and without any
-- point there is no path to integrate along. In both cases, call 'error' with
-- a descriptive message instead of returning unusable points.
--
-- See discussion in Xie, W., Lewis, P. O., Fan, Y., Kuo, L., & Chen, M.,
-- Improving marginal likelihood estimation for bayesian phylogenetic model
-- selection, Systematic Biology, 60(2), 150–160 (2010).
-- http://dx.doi.org/10.1093/sysbio/syq085
--
-- Or Figure 1 in Höhna, S., Landis, M. J., & Huelsenbeck, J. P., Parallel power
-- posterior analyses for fast computation of marginal likelihoods in
-- phylogenetics (2017). http://dx.doi.org/10.1101/104422
getPoints :: NPoints -> [Point]
getPoints x
  | k < 2 = error $ "getPoints: At least two points are required, but got " <> show k <> "."
  | otherwise = [f i ** (1.0 / alpha) | i <- [0 .. k1]]
  where
    k = fromNPoints x
    -- Index of the last point.
    k1 = pred k
    -- Equally spaced values between 0.0 and 1.0 (inclusive).
    f :: Int -> Double
    f j = fromIntegral j / fromIntegral k1

-- Run the MHG sampler at a given point of the path.
--
-- The sampler uses the given likelihood function raised to the power of the
-- point. The iteration counter and the start of the chain are reset to zero,
-- and the likelihood of the current link is recalculated before sampling.
-- After sampling, the cycle summary is logged at debug level, and warnings
-- are issued for unavailable or extreme acceptance rates.
sampleAtPoint ::
  ToJSON a =>
  Point ->
  Settings ->
  LikelihoodFunction a ->
  MHG a ->
  ML (MHG a)
sampleAtPoint x ss lhf a = do
  sampled <- liftIO $ mcmc pointSettings (MHG pointChain)
  let sampledChain = fromMHG sampled
      acs = acceptances sampledChain
  logDebugB "sampleAtPoint: Summarize cycle."
  logDebugB $ summarizeCycle AllProposals acs (cycle sampledChain)
  -- The result is 'Nothing' if any acceptance rate is unavailable.
  case sequenceA (acceptanceRates acs) of
    Nothing -> logWarnB "Some acceptance rates are unavailable."
    Just rates -> do
      let tooLow = M.filter (<= 0.1) rates
          tooHigh = M.filter (>= 0.9) rates
      unless (M.null tooLow) $ logWarnB "Some acceptance rates are below 0.1."
      unless (M.null tooHigh) $ logWarnB "Some acceptance rates are above 0.9."
  pure sampled
  where
    -- For debugging set a proper analysis name including the point.
    pointName = sAnalysisName ss <> AnalysisName ("/" <> printf "point%.8f" x)
    pointSettings = ss {sAnalysisName = pointName}
    -- Amend the likelihood function. Don't calculate the likelihood when the
    -- point is 0.0.
    poweredLhf
      | x == 0.0 = const 1.0
      | otherwise = (** Exp (log x)) . lhf
    -- Amend the MHG algorithm.
    chain0 = fromMHG a
    link0 = link chain0
    pointChain =
      chain0
        { -- Important: Update the likelihood using the new likelihood function.
          link = link0 {likelihood = poweredLhf (state link0)},
          iteration = 0,
          start = 0,
          likelihoodFunction = poweredLhf
        }

-- Sample at all given points, one after another. The sampler obtained at one
-- point is used as starting sampler at the next point.
traversePoints ::
  ToJSON a =>
  -- Current point.
  Int ->
  NPoints ->
  [Point] ->
  Settings ->
  LikelihoodFunction a ->
  MHG a ->
  -- For each point a vector of obtained likelihoods stored in the log domain.
  ML [VU.Vector Likelihood]
traversePoints _ _ [] _ _ _ = pure []
traversePoints i k (b : bs) ss lhf a = do
  logInfoS $ "Point " <> show i <> " of " <> show (fromNPoints k) <> ": " <> show b <> "."
  sampled <- sampleAtPoint b ss lhf a
  -- Get the links sampled at this point.
  links <- liftIO $ takeT nIterations (trace (fromMHG sampled))
  -- Extract the likelihoods using the original (not powered) likelihood
  -- function.
  --
  -- NOTE: This could be sped up by mapping (** -b) on the power likelihoods.
  --
  -- NOTE: This bang is an important one, because if the likelihoods are not
  -- strictly calculated here, the complete MCMC runs are dragged along before
  -- doing so resulting in a severe memory leak.
  let !lhs = VU.convert (VB.map (lhf . state) links)
  -- Sample the other points.
  rest <- traversePoints (i + 1) k bs ss lhf sampled
  pure (lhs : rest)
  where
    nIterations = fromIterations (sIterations ss)

-- Traverse the given points and collect the likelihoods sampled at each point.
--
-- First, the MHG sampler is initialized and the initial burn in is performed
-- at the first point. Then, all points (including the first one) are sampled
-- using the point burn in and the number of iterations given in the settings.
--
-- The list of points must not be empty; otherwise 'error' is called before
-- any sampler is initialized.
mlRun ::
  ToJSON a =>
  NPoints ->
  [Point] ->
  ExecutionMode ->
  Verbosity ->
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  a ->
  StdGen ->
  -- For each point a vector of likelihoods stored in log domain.
  ML [VU.Vector Likelihood]
mlRun _ [] _ _ _ _ _ _ _ _ = error "mlRun: The list of points is empty."
mlRun k xs@(x0 : _) em vb prf lhf cc mn i0 g = do
  logDebugB "mlRun: Begin."
  s <- reader settings
  let nm = mlAnalysisName s
      is = mlIterations s
      -- Only log sub MCMC samplers when debugging.
      vb' = case vb of
        Debug -> Debug
        _ -> Quiet
      trLen = TraceMinimum $ fromIterations is
      -- The settings of the sub MCMC samplers only differ in the burn in and
      -- the number of iterations.
      mkSettings bi its = Settings nm bi its trLen em Sequential NoSave LogFileOnly vb'
      -- Initial burn in only; no iterations.
      ssI = mkSettings (mlInitialBurnIn s) (Iterations 0)
      -- Burn in and iterations at each point.
      ssP = mkSettings (mlPointBurnIn s) is
  logDebugB "mlRun: Initialize MHG algorithm."
  a0 <- liftIO $ mhg ssI prf lhf cc mn i0 g
  logDebugS $ "mlRun: Perform initial burn in at first point " <> show x0 <> "."
  a1 <- sampleAtPoint x0 ssI lhf a0
  logDebugB "mlRun: Traverse points."
  traversePoints 1 k xs ssP lhf a1

-- Numerical integration of sampled Y values over the given X values.
--
-- Despite the name, this is the trapezoidal rule: for each pair of
-- neighboring points, the sum of the two Y values is multiplied by the
-- distance between the two X values; the sum of all these terms is halved.
--
-- If the lists differ in length, the surplus elements of the longer list are
-- ignored. The integral is zero if one of the lists has less than two
-- elements.
--
-- Use lists since the number of points is expected to be low.
integrateSimpsonTriangle ::
  -- X values.
  [Point] ->
  -- Y values.
  [Double] ->
  -- Integral.
  Double
integrateSimpsonTriangle xs ys = 0.5 * foldr (+) 0 areas
  where
    -- Distances between neighboring X values.
    widths = zipWith (-) (drop 1 xs) xs
    -- Sums of neighboring Y values.
    heights = zipWith (+) ys (drop 1 ys)
    -- A right fold is used on purpose so that the terms are summed up from
    -- the right; changing the order of summation can change the rounding.
    areas = zipWith (*) heights widths

-- Estimate the marginal likelihood using thermodynamic integration.
--
-- The path is traversed concurrently in both directions using independent
-- random number generators. The reported value is the mean of the forward and
-- the backward path integrals.
tiWrapper ::
  ToJSON a =>
  MLSettings ->
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  a ->
  StdGen ->
  ML MarginalLikelihood
tiWrapper s prf lhf cc mn i0 g = do
  logInfoB "Path integral (thermodynamic integration)."
  env <- ask
  let (gForward, gBackward) = split g
      runPath points gen = runReaderT (mlRun k points em vb prf lhf cc mn i0 gen) env
  -- Parallel execution of both path integrals.
  (lhssForward, lhssBackward) <-
    lift $ concurrently (runPath bsForward gForward) (runPath bsBackward gBackward)
  logInfoEndTime

  logDebugB "tiWrapper: Calculate mean log likelihoods."
  let mlForward = integrateSimpsonTriangle bsForward (map meanLogLh lhssForward)
      -- The backward path runs from high to low values; flip the sign.
      mlBackward = negate $ integrateSimpsonTriangle bsBackward (map meanLogLh lhssBackward)
      mean = 0.5 * (mlForward + mlBackward)
  logDebugS $ "tiWrapper: Marginal log likelihood of forward integral: " ++ show mlForward
  logDebugS $ "tiWrapper: Marginal log likelihood of backward integral: " ++ show mlBackward
  logDebugS $ "tiWrapper: The mean is: " ++ show mean
  pure $ Exp mean
  where
    k = mlNPoints s
    bsForward = getPoints k
    bsBackward = reverse bsForward
    em = mlExecutionMode s
    vb = mlVerbosity s
    -- It is important to average across the log likelihoods here (and not the
    -- likelihoods): the integrand of the path integral is the expectation of
    -- the log likelihood at the respective point (see Lartillot and Philippe,
    -- 2006).
    meanLogLh :: VU.Vector Likelihood -> Double
    meanLogLh v = VU.sum (VU.map ln v) / fromIntegral (VU.length v)

-- Helper function to exponentiate log domain values with a double value.
--
-- In the log domain, this amounts to multiplying the stored logarithm with
-- the exponent.
pow' :: Log Double -> Double -> Log Double
pow' (Exp l) p = Exp (l * p)

-- Stepping stone estimator of the marginal likelihood: the product of the
-- ratios between neighboring points. Each ratio is estimated by the mean of
-- the likelihoods sampled at the lower point raised to the power of the
-- distance between the two points.
--
-- The first argument contains all points of the path; the second argument
-- contains the likelihoods sampled at all but the last point.
--
-- See Xie2010 p. 153, bottom left.
sssCalculateMarginalLikelihood :: [Point] -> [VU.Vector Likelihood] -> MarginalLikelihood
sssCalculateMarginalLikelihood xs lhss =
  product [ratio b0 b1 lhs | (b0, b1, lhs) <- zip3 xs (drop 1 xs) lhss]
  where
    -- ratio beta_{k-1} beta_k lhs_{k-1}
    ratio :: Point -> Point -> VU.Vector Likelihood -> MarginalLikelihood
    ratio bkm1 bk lhs = oneOverN * VU.sum powered
      where
        oneOverN = recip $ fromIntegral $ VU.length lhs
        powered = VU.map (`pow'` (bk - bkm1)) lhs

-- -- Numerical stability by factoring out lhMax. But no observed
-- -- improvement towards the standard version.
--
-- f bkm1 bk lhs = n1 * pow' lhMax dbeta * VU.sum lhsNormedPowered
--   where n1 = recip $ fromIntegral $ VU.length lhs
--         lhMax = VU.maximum lhs
--         dbeta = bk - bkm1
--         lhsNormed = VU.map (/lhMax) lhs
--         lhsNormedPowered = VU.map (`pow'` dbeta) lhsNormed

-- -- Computation of the log of the marginal likelihood. According to the paper,
-- -- this estimator is biased and I did not observe any improvements compared
-- -- to the direct estimator implemented above.
--
-- -- See Xie2010 p. 153, top right.
-- sssCalculateMarginalLikelihood' :: [Point] -> [VU.Vector Likelihood] -> MarginalLikelihood
-- sssCalculateMarginalLikelihood' xs lhss = Exp $ sum $ zipWith3 f xs (tail xs) lhss
--   where f :: Point -> Point -> VU.Vector Likelihood -> Double
--         -- f beta_{k-1} beta_k lhs_{k-1}
--         f bkm1 bk lhs = dbeta * llhMax + log (n1 * VU.sum lhsNormedPowered)
--           where dbeta = bk - bkm1
--                 llhMax = ln $ VU.maximum lhs
--                 n1 = recip $ fromIntegral $ VU.length lhs
--                 llhs = VU.map ln lhs
--                 llhsNormed = VU.map (\x -> x - llhMax) llhs
--                 lhsNormedPowered = VU.map (\x -> exp $ dbeta * x) llhsNormed
sssWrapper ::
  ToJSON a =>
  MLSettings ->
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  a ->
  StdGen ->
  -- Marginal likelihood estimated using stepping stone sampling.
  ML MarginalLikelihood
sssWrapper s prf lhf cc mn i0 g = do
  logInfoB "Stepping stone sampling."
  -- Sample at all points but the last one.
  lhss <- mlRun k sampledPoints em vb prf lhf cc mn i0 g
  logInfoB "The last point does not need to be sampled with stepping stone sampling."
  logDebugB "sssWrapper: Calculate marginal likelihood."
  -- The estimator requires all points, including the last one.
  pure $ sssCalculateMarginalLikelihood allPoints lhss
  where
    k = mlNPoints s
    allPoints = getPoints k
    sampledPoints = init allPoints
    em = mlExecutionMode s
    vb = mlVerbosity s

-- | Estimate the marginal likelihood.
--
-- The algorithm is chosen according to 'mlAlgorithm'. The estimate is logged
-- and returned in log domain.
--
-- When the verbosity is 'Debug', a directory named after the analysis is
-- created (if missing) before the estimation starts.
marginalLikelihood ::
  ToJSON a =>
  MLSettings ->
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  InitialState a ->
  StdGen ->
  IO MarginalLikelihood
marginalLikelihood s prf lhf cc mn i0 g = do
  -- Initialize.
  e <- initializeEnvironment s

  when (mlVerbosity s == Debug) $
    createDirectoryIfMissing True (fromAnalysisName $ mlAnalysisName s)

  -- Run.
  runReaderT
    ( do
        logInfoStartingTime
        logInfoB "Estimate marginal likelihood."
        logDebugB "marginalLikelihood: The marginal likelihood settings are:"
        logDebugS $ show s
        val <- estimate s prf lhf cc mn i0 g
        logInfoS $ "Marginal log likelihood: " ++ show (ln val)
        -- TODO (low): Simulation variance.
        logInfoS "The simulation variance is not yet available."
        pure val
    )
    e
  where
    -- Select the estimator according to the settings.
    estimate = case mlAlgorithm s of
      ThermodynamicIntegration -> tiWrapper
      SteppingStoneSampling -> sssWrapper