{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}

-- |
-- Module      :  Mcmc.MarginalLikelihood
-- Description :  Calculate the marginal likelihood
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Mon Jan 11 16:34:18 2021.
module Mcmc.MarginalLikelihood
  ( MarginalLikelihood,
    NPoints (..),
    MLAlgorithm (..),
    MLSettings (..),
    marginalLikelihood,
  )
where

import Control.Concurrent (getNumCapabilities)
import Control.Concurrent.Async hiding (link)
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.Reader
import Data.Aeson
import Data.List hiding (cycle)
import qualified Data.Map.Strict as M
import qualified Data.Vector as VB
import qualified Data.Vector.Unboxed as VU
import Mcmc.Acceptance
import Mcmc.Algorithm.MHG
import Mcmc.Chain.Chain
import Mcmc.Chain.Link
import Mcmc.Chain.Trace
import Mcmc.Cycle
import Mcmc.Environment
import Mcmc.Likelihood
import Mcmc.Logger
import Mcmc.Mcmc
import Mcmc.Monitor
import Mcmc.Prior
import Mcmc.Settings
import Numeric.Log hiding (sum)
import System.Directory
import System.Random.Stateful
import Text.Printf
import Prelude hiding (cycle)

-- | Marginal likelihood values are stored in log domain.
--
-- The value is wrapped in 'Log' (from "Numeric.Log"), so that the very small
-- numbers typical for (marginal) likelihoods remain representable.
type MarginalLikelihood = Log Double

-- Reciprocal temperature value traversed along the path integral.
--
-- The points generated by 'getPoints' lie in the interval [0, 1]; at a given
-- point, 'sampleAtPoint' raises the likelihood to the power of the point.
type Point = Double

-- | Number of points used to discretize (approximate) the path integral.
--
-- Use 'fromNPoints' to obtain the underlying 'Int'.
newtype NPoints = NPoints {fromNPoints :: Int}
  deriving (Show, Read, Eq)

-- | Algorithms available to estimate the marginal likelihood.
data MLAlgorithm
  = ThermodynamicIntegration
    -- ^ Classical path integral, also known as thermodynamic integration. In
    -- particular, /Annealing-Melting Integration/ is used.
    --
    -- See Lartillot, N., & Philippe, H., Computing Bayes Factors Using
    -- Thermodynamic Integration, Systematic Biology, 55(2), 195–207 (2006).
    -- http://dx.doi.org/10.1080/10635150500433722
  | SteppingStoneSampling
    -- ^ Stepping stone sampling.
    --
    -- See Xie, W., Lewis, P. O., Fan, Y., Kuo, L., & Chen, M., Improving
    -- marginal likelihood estimation for Bayesian phylogenetic model selection,
    -- Systematic Biology, 60(2), 150–160 (2010).
    -- http://dx.doi.org/10.1093/sysbio/syq085
    --
    -- Or Fan, Y., Wu, R., Chen, M., Kuo, L., & Lewis, P. O., Choosing among
    -- partition models in bayesian phylogenetics, Molecular Biology and
    -- Evolution, 28(1), 523–532 (2010). http://dx.doi.org/10.1093/molbev/msq224
  deriving (Show, Read, Eq)

-- | Settings of the marginal likelihood estimation.
data MLSettings = MLSettings
  { -- | Name of the analysis.
    mlAnalysisName :: AnalysisName,
    -- | Algorithm used to estimate the marginal likelihood.
    mlAlgorithm :: MLAlgorithm,
    -- | Number of points used to discretize the path.
    mlNPoints :: NPoints,
    -- | Initial burn in at the starting point of the path (or at the starting
    -- point of each segment when running in parallel).
    mlInitialBurnIn :: BurnInSettings,
    -- | Burn in repeated at each point on the path.
    mlPointBurnIn :: BurnInSettings,
    -- | Number of iterations performed at each point.
    mlIterations :: Iterations,
    -- | Execution mode.
    mlExecutionMode :: ExecutionMode,
    -- | Parallelization mode.
    mlParallelizationMode :: ParallelizationMode,
    -- | Log mode.
    mlLogMode :: LogMode,
    -- | Verbosity.
    mlVerbosity :: Verbosity
  }
  deriving (Show, Read, Eq)

-- The analysis name is stored in the 'mlAnalysisName' field.
instance HasAnalysisName MLSettings where
  getAnalysisName MLSettings {mlAnalysisName = n} = n

-- The execution mode is stored in the 'mlExecutionMode' field.
instance HasExecutionMode MLSettings where
  getExecutionMode MLSettings {mlExecutionMode = em} = em

-- The log mode is stored in the 'mlLogMode' field.
instance HasLogMode MLSettings where
  getLogMode MLSettings {mlLogMode = lm} = lm

-- The verbosity is stored in the 'mlVerbosity' field.
instance HasVerbosity MLSettings where
  getVerbosity MLSettings {mlVerbosity = v} = v

-- | Monad of the marginal likelihood computations: a reader over the
-- 'Environment' built from the 'MLSettings', on top of 'IO'.
type ML a = ReaderT (Environment MLSettings) IO a

-- | Shape parameter of the skewed beta distribution used by 'getPoints' to
-- place the points on the path. Values below 1.0 put more points at low
-- reciprocal temperatures; 0.3 is the standard choice.
alpha :: Double
alpha = 0.3

-- | Reciprocal temperatures (points) at which the path integral is evaluated.
--
-- For @k@ points, the point with index @i@ (@i = 0, 1, ..., k-1@) is
-- @(i \/ (k-1)) ** (1 \/ alpha)@. That is, the points are distributed
-- according to a skewed beta distribution with shape 'alpha'. If 'alpha' is
-- below 1.0, the points cluster at low values, which is desired. The first
-- point is exactly 0.0 and the last point is exactly 1.0.
--
-- It is inconvenient that the reciprocal temperatures are usually denoted as
-- beta, and we also use the beta distribution :). Don't mix them up!
--
-- Calls 'error' if fewer than two points are requested. A path needs both end
-- points; with a single point the spacing @1 \/ (k-1)@ is a division by zero
-- and would silently produce a NaN point.
--
-- See discussion in Xie, W., Lewis, P. O., Fan, Y., Kuo, L., & Chen, M.,
-- Improving marginal likelihood estimation for bayesian phylogenetic model
-- selection, Systematic Biology, 60(2), 150–160 (2010).
-- http://dx.doi.org/10.1093/sysbio/syq085
--
-- Or Figure 1 in Höhna, S., Landis, M. J., & Huelsenbeck, J. P., Parallel power
-- posterior analyses for fast computation of marginal likelihoods in
-- phylogenetics (2017). http://dx.doi.org/10.1101/104422
getPoints :: NPoints -> [Point]
getPoints x
  | k < 2 =
      error $
        "getPoints: At least two points are required, but got "
          <> show k
          <> "."
  | otherwise = [f i ** (1.0 / alpha) | i <- [0 .. k1]]
  where
    k = fromNPoints x
    -- Index of the last point; also the denominator of the spacing.
    k1 = pred k
    f :: Int -> Double
    f j = fromIntegral j / fromIntegral k1

-- | Run the MHG sampler at a single point (reciprocal temperature) @x@ of the
-- path and return the updated algorithm.
--
-- Before sampling, the chain is amended:
--
-- * the likelihood function is replaced by the power likelihood
--   @lhf ** x@, or by the constant 1.0 when @x == 0.0@, so that the
--   likelihood is not calculated at all at the start of the path;
--
-- * the likelihood of the current link is recalculated with the amended
--   function, and the iteration counter and start are reset to 0;
--
-- * the analysis name is suffixed with the point (useful for debugging).
--
-- After sampling, a summary of the cycle is logged at debug level. Unless this
-- is the initial burn in, a warning is logged if some acceptance rates are
-- unavailable, at most 0.1, or at least 0.9.
sampleAtPoint ::
  (ToJSON a) =>
  -- Is this the initial burn in? If so, acceptance rates are not checked.
  Bool ->
  -- The point (reciprocal temperature).
  Point ->
  Settings ->
  LikelihoodFunction a ->
  MHG a ->
  ML (MHG a)
sampleAtPoint isInitialBurnIn x ss lhf a = do
  sampled <- liftIO $ mcmc ssAtPoint (MHG chainAtPoint)
  let chainSampled = fromMHG sampled
      acc = acceptances chainSampled
  logDebugB "sampleAtPoint: Summarize cycle."
  logDebugB $ summarizeCycle AllProposals acc (cycle chainSampled)
  unless isInitialBurnIn $ warnRates $ sequence $ acceptanceRates acc
  pure sampled
  where
    -- Warn about unavailable or extreme acceptance rates.
    warnRates Nothing = logWarnB "Some acceptance rates are unavailable."
    warnRates (Just rates) = do
      let hasRate p = not $ M.null $ M.filter p rates
      when (hasRate (<= 0.1)) $ logWarnB "Some acceptance rates are below 0.1."
      when (hasRate (>= 0.9)) $ logWarnB "Some acceptance rates are above 0.9."
    -- For debugging, give each point its own analysis name.
    ssAtPoint =
      ss
        { sAnalysisName =
            sAnalysisName ss <> AnalysisName ("/" <> printf "point%.8f" x)
        }
    -- The power likelihood. Don't calculate the likelihood when the point is
    -- 0.0.
    powerLhf
      | x == 0.0 = const 1.0
      | otherwise = (** Exp (log x)) . lhf
    -- The chain, amended to use the power likelihood.
    chain0 = fromMHG a
    link0 = link chain0
    chainAtPoint =
      chain0
        { -- Important: Update the likelihood using the new likelihood function.
          link = link0 {likelihood = powerLhf $ state link0},
          iteration = 0,
          start = 0,
          likelihoodFunction = powerLhf
        }

-- | Visit the points in order and sample at each one with 'sampleAtPoint'.
--
-- Each point starts from the algorithm state reached at the previous point.
-- For every point, @n@ links (with @n@ the number of iterations in the
-- settings) are taken from the trace with 'takeT'. Their states are evaluated
-- with the unmodified likelihood function @lhf@, not the power likelihood.
traversePoints ::
  (ToJSON a) =>
  NPoints ->
  [(Int, Point)] ->
  Settings ->
  LikelihoodFunction a ->
  MHG a ->
  -- For each point a vector of obtained likelihoods stored in the log domain.
  ML [VU.Vector Likelihood]
traversePoints _ [] _ _ _ = pure []
traversePoints k ((idb, b) : bs) ss lhf a = do
  logInfoS $ printf "Point %4d of %4d: %.12f." idb (fromNPoints k) b
  a' <- sampleAtPoint False b ss lhf a
  -- The links sampled at this point.
  ls <- liftIO $ takeT nIterations $ trace $ fromMHG a'
  -- Extract the likelihoods.
  --
  -- NOTE: This could be sped up by mapping (** -b) on the power likelihoods.
  --
  -- NOTE: This bang is an important one, because if the lhs are not strictly
  -- calculated here, the complete MCMC runs are dragged along before doing so
  -- resulting in a severe memory leak.
  let !lhs = VU.convert $ VB.map (lhf . state) ls
  -- Sample the remaining points, continuing from the current state.
  (lhs :) <$> traversePoints k bs ss lhf a'
  where
    nIterations = fromIterations $ sIterations ss

-- | Split a list into at most @k@ contiguous, non-empty chunks whose lengths
-- differ by at most one; longer chunks come first and the element order is
-- preserved.
nChunks :: Int -> [a] -> [[a]]
nChunks k xs =
  let sizes = chunks k (length xs)
   in chop sizes xs

-- | Sizes of @c@ chunks sharing @n@ items as evenly as possible.
--
-- The first @n `mod` c@ chunks hold one extra item. Chunks of size zero are
-- dropped. A zero @c@ throws a division-by-zero exception.
chunks :: Int -> Int -> [Int]
chunks c n = filter (> 0) (larger ++ smaller)
  where
    (q, r) = n `divMod` c
    larger = replicate r (succ q)
    smaller = replicate (c - r) q

-- | Cut a list into consecutive pieces of the given (positive) sizes.
--
-- Calls 'error' on a size that is not positive. It also calls 'error' when the
-- sizes are exhausted before the list is.
chop :: [Int] -> [a] -> [[a]]
chop [] [] = []
chop (n : ns) xs
  | n <= 0 = error "chop: n negative or zero"
  | otherwise =
      let (piece, rest) = splitAt n xs
       in piece : chop ns rest
chop _ _ = error "chop: not all list elements handled"

-- | Sample the likelihoods at the given points, possibly in parallel.
--
-- The points are split into (at most) as many contiguous segments as there
-- are threads: one for 'Sequential', the number of capabilities for
-- 'Parallel'. Each segment is traversed by 'mlRun' concurrently, and the
-- results are concatenated in the order of the points.
--
-- Each segment is seeded with its own generator, obtained by repeatedly
-- splitting @g@. Previously all segments shared the identical generator, so
-- they consumed the same random stream and produced correlated samples.
mlRunPar ::
  (ToJSON a) =>
  ParallelizationMode ->
  NPoints ->
  [(Int, Point)] ->
  ExecutionMode ->
  Verbosity ->
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  a ->
  StdGen ->
  ML [VU.Vector Likelihood]
mlRunPar pm k xs em vb prf lhf cc mn i0 g = do
  nThreads <- case pm of
    Sequential -> do
      logInfoB "Sequential execution."
      pure 1
    Parallel -> do
      n <- liftIO getNumCapabilities
      logInfoS $ "Parallel execution with " <> show n <> " cores."
      pure n
  let xsChunks = nChunks nThreads xs
      -- An infinite supply of generators; 'zip' takes one per segment.
      gs = unfoldr (Just . split) g
  r <- ask
  xss <-
    liftIO $
      mapConcurrently
        ( \(thesePoints, thisGen) ->
            runReaderT (mlRun k thesePoints em vb prf lhf cc mn i0 thisGen) r
        )
        (zip xsChunks gs)
  pure $ concat xss

-- | Traverse the given points sequentially, starting with an initial burn in
-- at the first point.
--
-- The sub MCMC samplers run sequentially, do not save, log to file only, and
-- are quiet unless the verbosity is 'Debug'.
--
-- An empty list of points yields an empty result. Previously the partial
-- 'head' crashed after the sampler had already been initialized.
mlRun ::
  (ToJSON a) =>
  NPoints ->
  [(Int, Point)] ->
  ExecutionMode ->
  Verbosity ->
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  a ->
  StdGen ->
  -- For each point a vector of likelihoods stored in log domain.
  ML [VU.Vector Likelihood]
mlRun _ [] _ _ _ _ _ _ _ _ = do
  logDebugB "mlRun: No points to traverse."
  pure []
mlRun k xs@((id0, x0) : _) em vb prf lhf cc mn i0 g = do
  logDebugB "mlRun: Begin."
  s <- reader settings
  let nm = mlAnalysisName s
      is = mlIterations s
      biI = mlInitialBurnIn s
      biP = mlPointBurnIn s
      -- Only log sub MCMC samplers when debugging.
      vb' = case vb of
        Debug -> Debug
        _ -> Quiet
      trLen = TraceMinimum $ fromIterations is
      -- Initial burn in only; no iterations.
      ssI = Settings nm biI (Iterations 0) trLen em Sequential NoSave LogFileOnly vb'
      -- Burn in and iterations at each point.
      ssP = Settings nm biP is trLen em Sequential NoSave LogFileOnly vb'
  logDebugB "mlRun: Initialize MHG algorithm."
  a0 <- liftIO $ mhg ssI prf lhf cc mn i0 g
  logInfoS $ printf "Initial burn in at point %.12f with ID %4d." x0 id0
  a1 <- sampleAtPoint True x0 ssI lhf a0
  logDebugB "mlRun: Traverse points."
  traversePoints k xs ssP lhf a1

-- Integrate piecewise linearly over consecutive points. Despite the name, this
-- is the trapezoidal rule: 0.5 * sum of (y_i + y_{i+1}) * (x_{i+1} - x_i).
-- Surplus elements of the longer list are ignored.
--
-- Use lists since the number of points is expected to be low. The areas are
-- summed with a right fold to keep the exact floating-point summation order.
integrateSimpsonTriangle ::
  -- X values.
  [Point] ->
  -- Y values.
  [Double] ->
  -- Integral.
  Double
integrateSimpsonTriangle xs ys = 0.5 * foldr (+) 0 areas
  where
    sums = zipWith (+) ys (drop 1 ys)
    widths = zipWith (-) (drop 1 xs) xs
    areas = zipWith (*) sums widths

-- Thermodynamic integration (path integral).
--
-- Two path integrals are sampled concurrently: a forward one traversing the
-- points in increasing order, and a backward one traversing them in reverse.
-- Each uses its own generator obtained by splitting @g@. At every point the
-- sampled log likelihoods are averaged, and the averages are integrated over
-- the points with 'integrateSimpsonTriangle'. The backward integral runs over
-- reversed points and is therefore negated. The returned value is the mean of
-- both integrals, stored in the log domain.
tiWrapper ::
  (ToJSON a) =>
  MLSettings ->
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  a ->
  StdGen ->
  ML MarginalLikelihood
tiWrapper s prf lhf cc mn i0 g = do
  logInfoB "Path integral (thermodynamic integration)."
  let (gForward, gBackward) = split g
      -- Sample along the given ordered points; IDs are numbered from 1.
      -- NOTE(review): both paths use the same IDs (and the same point values);
      -- if per-point output names derive from them, the concurrent runs may
      -- share files -- confirm.
      samplePath ps gen = mlRunPar pm nPoints (zip [1 ..] ps) em vb prf lhf cc mn i0 gen
  env <- ask
  -- Both path integrals are executed in parallel.
  (lhssForward, lhssBackward) <-
    lift $
      concurrently
        (runReaderT (samplePath psForward gForward) env)
        (runReaderT (samplePath psBackward gBackward) env)
  logInfoEndTime

  logDebugB "tiWrapper: Calculate mean log likelihoods."
  let mlForward = integrateAlong psForward lhssForward
      mlBackward = negate $ integrateAlong psBackward lhssBackward
      mlMean = 0.5 * (mlForward + mlBackward)
  logDebugS $ "tiWrapper: Marginal log likelihood of forward integral: " ++ show mlForward
  logDebugS $ "tiWrapper: Marginal log likelihood of backward integral: " ++ show mlBackward
  logDebugS $ "tiWrapper: The mean is: " ++ show mlMean
  pure $ Exp mlMean
  where
    nPoints = mlNPoints s
    psForward = getPoints nPoints
    psBackward = reverse psForward
    em = mlExecutionMode s
    pm = mlParallelizationMode s
    vb = mlVerbosity s
    -- Average of the log likelihoods sampled at one point. Assuming the chain
    -- at reciprocal temperature beta samples the power posterior
    -- (prior * likelihood^beta), thermodynamic integration relies on
    --   d/dbeta log Z(beta) = E_beta[log likelihood],
    -- which is why the log likelihoods (and not the likelihoods) are averaged.
    meanLogLh :: VU.Vector Likelihood -> Double
    meanLogLh v = VU.sum (VU.map ln v) / fromIntegral (VU.length v)
    -- Integrate the mean log likelihoods along the given points.
    integrateAlong :: [Point] -> [VU.Vector Likelihood] -> Double
    integrateAlong ps lhss = integrateSimpsonTriangle ps (map meanLogLh lhss)

-- Raise a value stored in the log domain to a real power. The stored
-- logarithm is multiplied by the exponent, so the computation never leaves
-- the log domain.
pow' :: Log Double -> Double -> Log Double
pow' (Exp logX) p = Exp (logX * p)

-- Stepping stone estimate of the marginal likelihood; see Xie2010 p. 153,
-- bottom left.
--
-- The points @xs@ are beta_0, beta_1, ..., beta_K. Element k-1 of @lhss@
-- holds the likelihoods sampled at beta_{k-1} (as passed by 'sssWrapper',
-- which samples all points but the last one). For each step k, the ratio is
-- estimated by the sample mean of L^(beta_k - beta_{k-1}) over the
-- likelihoods sampled at beta_{k-1}. The estimate is the product of all
-- ratios; all arithmetic is carried out in the log domain.
--
-- Surplus elements of the longer list are ignored; fewer than two points
-- yield the empty product, i.e., 1 (instead of crashing on 'tail').
sssCalculateMarginalLikelihood :: [Point] -> [VU.Vector Likelihood] -> MarginalLikelihood
sssCalculateMarginalLikelihood xs lhss = product $ zipWith3 f xs (drop 1 xs) lhss
  where
    f :: Point -> Point -> VU.Vector Likelihood -> MarginalLikelihood
    -- f beta_{k-1} beta_k lhs_{k-1}
    f bkm1 bk lhs = n1 * VU.sum lhsPowered
      where
        -- One over the number of likelihoods sampled at beta_{k-1}.
        n1 = recip $ fromIntegral $ VU.length lhs
        dbeta = bk - bkm1
        lhsPowered = VU.map (`pow'` dbeta) lhs

-- -- Improve numerical stability by factoring out lhMax. However, no
-- -- improvement over the standard version was observed.
--
-- f bkm1 bk lhs = n1 * pow' lhMax dbeta * VU.sum lhsNormedPowered
--   where n1 = recip $ fromIntegral $ VU.length lhs
--         lhMax = VU.maximum lhs
--         dbeta = bk - bkm1
--         lhsNormed = VU.map (/lhMax) lhs
--         lhsNormedPowered = VU.map (`pow'` dbeta) lhsNormed

-- -- Computation of the log of the marginal likelihood. According to the paper,
-- -- this estimator is biased and I did not observe any improvements compared
-- -- to the direct estimator implemented above.
--
-- -- See Xie2010 p. 153, top right.
-- sssCalculateMarginalLikelihood' :: [Point] -> [VU.Vector Likelihood] -> MarginalLikelihood
-- sssCalculateMarginalLikelihood' xs lhss = Exp $ sum $ zipWith3 f xs (tail xs) lhss
--   where f :: Point -> Point -> VU.Vector Likelihood -> Double
--         -- f beta_{k-1} beta_k lhs_{k-1}
--         f bkm1 bk lhs = dbeta * llhMax + log (n1 * VU.sum lhsNormedPowered)
--           where dbeta = bk - bkm1
--                 llhMax = ln $ VU.maximum lhs
--                 n1 = recip $ fromIntegral $ VU.length lhs
--                 llhs = VU.map ln lhs
--                 llhsNormed = VU.map (\x -> x - llhMax) llhs
--                 lhsNormedPowered = VU.map (\x -> exp $ dbeta * x) llhsNormed
-- Stepping stone sampling.
--
-- Samples at all points except the last one. The stepping stone estimator
-- only needs the likelihoods sampled at beta_0, ..., beta_{K-1}. The samples
-- are combined with 'sssCalculateMarginalLikelihood' over all points.
sssWrapper ::
  (ToJSON a) =>
  MLSettings ->
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  a ->
  StdGen ->
  ML MarginalLikelihood
sssWrapper s prf lhf cc mn i0 g = do
  logInfoB "Stepping stone sampling."
  -- Inform before sampling, not after.
  logInfoB "The last point does not need to be sampled with stepping stone sampling."
  -- Likelihoods (stored in the log domain) sampled at each point but the last.
  lhss <- mlRunPar pm k (zip [1 ..] bsSampled) em vb prf lhf cc mn i0 g
  -- Log the end time of the sampling phase, as done for thermodynamic
  -- integration.
  logInfoEndTime
  logDebugB "sssWrapper: Calculate marginal likelihood."
  return $ sssCalculateMarginalLikelihood bsForward lhss
  where
    k = mlNPoints s
    bsForward = getPoints k
    -- All points but the last one; 'take' is total, unlike 'init'.
    bsSampled = take (length bsForward - 1) bsForward
    em = mlExecutionMode s
    pm = mlParallelizationMode s
    vb = mlVerbosity s

-- | Estimate the marginal likelihood.
--
-- Initializes the environment from the settings and creates the analysis
-- directory named after 'mlAnalysisName'. It then runs the algorithm selected
-- by 'mlAlgorithm' (thermodynamic integration or stepping stone sampling).
-- The estimate is logged as a marginal log likelihood and returned in the log
-- domain.
marginalLikelihood ::
  (ToJSON a) =>
  MLSettings ->
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  InitialState a ->
  StdGen ->
  IO MarginalLikelihood
marginalLikelihood s prf lhf cc mn i0 g = do
  env <- initializeEnvironment s
  -- Marginal likelihood analysis directory.
  createDirectoryIfMissing True (fromAnalysisName (mlAnalysisName s))
  -- NOTE(review): the environment's log handles are not closed here; confirm
  -- this is handled elsewhere.
  runReaderT estimate env
  where
    estimate = do
      logInfoStartingTime
      logInfoB "Estimate marginal likelihood."
      logDebugB "marginalLikelihood: The marginal likelihood settings are:"
      logDebugS (show s)
      val <- case mlAlgorithm s of
        ThermodynamicIntegration -> tiWrapper s prf lhf cc mn i0 g
        SteppingStoneSampling -> sssWrapper s prf lhf cc mn i0 g
      logInfoS ("Marginal log likelihood: " ++ show (ln val))
      -- TODO (low): Simulation variance.
      logInfoS "The simulation variance is not yet available."
      pure val