{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}

-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Tue Jun 16 10:18:54 2020.
--
-- Save and load chains. It is easy to save and restore the current link (state,
-- prior and likelihood), the trace, and the acceptance counts, but functions
-- such as the prior function, the likelihood function, the proposals, and the
-- monitor cannot be stored, so they have to be provided again when continuing a
-- run.

-- |
-- Module      :  Mcmc.Chain.Save
-- Description :  Save and load a Markov chain
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
module Mcmc.Chain.Save
  ( SavedChain (..),
    toSavedChain,
    fromSavedChain,
    fromSavedChainUnsafe,
  )
where

import Control.Monad
import Data.Aeson
import Data.Aeson.TH
import Data.List hiding (cycle)
import Data.Maybe
import qualified Data.Stack.Circular as C
import qualified Data.Vector as VB
import Data.Word
import Mcmc.Acceptance
import Mcmc.Chain.Chain
import Mcmc.Chain.Link
import Mcmc.Chain.Trace
import Mcmc.Cycle
import Mcmc.Internal.Random
import Mcmc.Likelihood
import Mcmc.Monitor
import Mcmc.Prior
import Mcmc.Proposal
import Prelude hiding (cycle)

-- | Storable values of a Markov chain.
--
-- Only plain data is kept here; the prior function, likelihood function,
-- proposal cycle and monitor have to be supplied again when loading (see
-- 'fromSavedChain').
--
-- See 'toSavedChain'.
data SavedChain a = SavedChain
  { -- | Identifier of the chain, if any.
    savedId :: Maybe Int,
    -- | Current link of the chain (state together with its prior and
    -- likelihood).
    savedLink :: Link a,
    -- | Current iteration.
    savedIteration :: Int,
    -- | Frozen copy of the trace.
    savedTrace :: C.Stack VB.Vector (Link a),
    -- | Acceptance counts, keyed by the index of the proposal in the cycle
    -- (starting at 0).
    savedAcceptance :: Acceptance Int,
    -- | Saved state of the random number generator.
    savedSeed :: (Word64, Word64),
    -- | One entry per proposal, in cycle order: 'Nothing' for proposals
    -- without a tuner, otherwise the tuning parameter and the auxiliary
    -- tuning parameters.
    savedTuningParameters :: [Maybe (TuningParameter, AuxiliaryTuningParameters)]
  }
  deriving (Eq, Show)

-- JSON instances ('ToJSON' and 'FromJSON') generated with the default options.
$(deriveJSON defaultOptions ''SavedChain)

-- | Save a chain.
--
-- The random number generator state and the trace are read in 'IO'. The
-- acceptance counts are re-keyed by the position of each proposal in the
-- cycle (starting at 0). For every proposal, the tuning parameters of its
-- tuner are recorded, or 'Nothing' when the proposal has no tuner.
toSavedChain ::
  Chain a ->
  IO (SavedChain a)
toSavedChain (Chain cid lnk iter trc acc gen _ _ _ cyc _) = do
  frozenSeed <- saveGen gen
  frozenTrace <- freezeT trc
  pure
    SavedChain
      { savedId = cid,
        savedLink = lnk,
        savedIteration = iter,
        savedTrace = frozenTrace,
        savedAcceptance = transformKeysA props [0 ..] acc,
        savedSeed = frozenSeed,
        savedTuningParameters = map (fmap tuningOf . prTuner) props
      }
  where
    props = ccProposals cyc
    -- Extract the storable part of a tuner.
    tuningOf t = (tTuningParameter t, tAuxiliaryTuningParameters t)

-- | Load a saved chain.
--
-- Before delegating to 'fromSavedChainUnsafe', the following sanity checks
-- are performed. If one fails, 'error' is called with a diagnostic message.
--
-- 1. The provided prior function, evaluated at the saved state, reproduces
--    the saved prior.
--
-- 2. The provided likelihood function, evaluated at the saved state,
--    reproduces the saved likelihood.
--
-- 3. The number of saved acceptance entries equals the number of proposals
--    in the provided cycle.
--
-- The functions themselves cannot be compared, but having the same prior and
-- likelihood at the last state is already a good indicator.
fromSavedChain ::
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  SavedChain a ->
  IO (Chain a)
fromSavedChain pr lh cc mn sv
  | givenPrior /= prior lastLink =
      mismatch
        "fromSave: Provided prior function does not match the saved prior."
        ("fromSave: Current prior:", prior lastLink)
        ("fromSave: Given prior:", givenPrior)
  | givenLikelihood /= likelihood lastLink =
      mismatch
        "fromSave: Provided likelihood function does not match the saved likelihood function."
        ("fromSave: Current likelihood:", likelihood lastLink)
        ("fromSave: Given likelihood:", givenLikelihood)
  | nSaved /= nGiven =
      mismatch
        "fromSave: The number of proposals does not match."
        ("fromSave: Number of saved proposals:", nSaved)
        ("fromSave: Number of given proposals:", nGiven)
  | otherwise = fromSavedChainUnsafe pr lh cc mn sv
  where
    lastLink = savedLink sv
    lastState = state lastLink
    givenPrior = pr lastState
    givenLikelihood = lh lastState
    nSaved = length $ fromAcceptance $ savedAcceptance sv
    nGiven = length $ ccProposals cc
    -- Abort with a three-line message: a headline, then the labelled saved
    -- value and the labelled value derived from the provided arguments.
    mismatch :: Show b => String -> (String, b) -> (String, b) -> c
    mismatch headline (savedLabel, savedValue) (givenLabel, givenValue) =
      error . unlines $
        [ headline,
          savedLabel <> show savedValue <> ".",
          givenLabel <> show givenValue <> "."
        ]

-- | See 'fromSavedChain' but do not perform sanity checks. Useful when
-- restarting a run with changed prior function, likelihood function or
-- proposals.
--
-- The random number generator is restored from the saved seed and the trace
-- is thawed. The saved acceptance counts, keyed by proposal index, are mapped
-- back onto the proposals of the given cycle. The saved tuning parameters are
-- applied to the proposals in cycle order. Proposals without saved tuning
-- parameters are left unchanged; this covers entries stored as 'Nothing' and
-- proposals beyond the end of the saved list.
--
-- Calls 'error' if a proposal has stored tuning parameters but
-- 'tuneWithTuningParameters' rejects them.
fromSavedChainUnsafe ::
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  SavedChain a ->
  IO (Chain a)
fromSavedChainUnsafe pr lh cc mn (SavedChain ci it i tr ac' g' ts) = do
  g <- loadGen g'
  tr' <- thawT tr
  -- The saved iteration is used for both 'Int' fields of the chain
  -- (NOTE(review): presumably the current and the start iteration of the run).
  return $ Chain ci it i tr' ac g i pr lh cc' mn
  where
    ps = ccProposals cc
    -- Saved acceptance counts are keyed by proposal index; map them back onto
    -- the proposals.
    ac = transformKeysA [0 ..] ps ac'
    -- Pad with 'Nothing' so that 'zipWith' keeps proposals for which no
    -- tuning parameters were saved, instead of silently dropping them.
    ts' = ts ++ repeat Nothing
    tunePs Nothing p = p
    tunePs (Just (x, xs)) p =
      either (error . (<> err)) id $ tuneWithTuningParameters x xs p
    -- Suffix appended to the tuning error. This must be a plain 'String';
    -- defining it via 'error' made rendering the message throw a nested error.
    err :: String
    err = "\nfromSavedChain: Proposal with stored tuning parameters is not tunable."
    cc' = cc {ccProposals = zipWith tunePs ts' ps}