{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}

-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Tue Jun 16 10:18:54 2020.
--
-- Save and load Markov chains. The current state, its prior and likelihood, and
-- the trace are easy to save and restore. The prior function, the likelihood
-- function, the proposals, and the monitor cannot reasonably be stored, so they
-- have to be provided again when continuing a run.

-- |
-- Module      :  Mcmc.Chain.Save
-- Description :  Save and load a Markov chain
-- Copyright   :  (c) Dominik Schrempf, 2021
-- License     :  GPL-3.0-or-later
module Mcmc.Chain.Save
  ( SavedChain (..),
    toSavedChain,
    fromSavedChain,
  )
where

import Control.Monad
import Data.Aeson
import Data.Aeson.TH
import Data.List hiding (cycle)
import qualified Data.Map as M
import Data.Maybe
import qualified Data.Stack.Circular as C
import qualified Data.Vector as VB
import qualified Data.Vector.Unboxed as VU
import Data.Word
import Mcmc.Chain.Chain
import Mcmc.Chain.Link
import Mcmc.Chain.Trace
import Mcmc.Internal.Random
import Mcmc.Monitor
import Mcmc.Proposal
import Prelude hiding (cycle)

-- | Storable values of a Markov chain.
--
-- The prior function, the likelihood function, the cycle of proposals, and the
-- monitor are not part of a 'SavedChain'; they have to be provided again when
-- loading. See 'toSavedChain' and 'fromSavedChain'.
data SavedChain a = SavedChain
  { -- | Identifier of the chain, if any.
    savedId :: Maybe Int,
    -- | The current link (the state together with its prior and likelihood).
    savedLink :: Link a,
    -- | The current iteration.
    savedIteration :: Int,
    -- | The frozen trace.
    savedTrace :: C.Stack VB.Vector (Link a),
    -- | Acceptance record, keyed by the position of each proposal in the
    -- cycle.
    savedAcceptance :: Acceptance Int,
    -- | Saved state of the random number generator.
    savedSeed :: VU.Vector Word32,
    -- | Tuning parameter of each proposal in cycle order; 'Nothing' for
    -- proposals without a tuner.
    savedTuningParameters :: [Maybe TuningParameter]
  }
  deriving (Show, Read, Eq)

-- Derive 'ToJSON' and 'FromJSON' instances for 'SavedChain' using Aeson's
-- 'defaultOptions', so that saved chains can be encoded to and decoded from
-- JSON.
deriveJSON defaultOptions ''SavedChain

-- | Save a chain.
--
-- The state of the random number generator is stored as a seed vector and the
-- trace is frozen. The acceptance record is re-keyed from proposals to their
-- positions in the cycle (via 'transformKeysA'). The tuning parameter of each
-- proposal of the cycle is recorded, with 'Nothing' for proposals without a
-- tuner. The prior function, the likelihood function, the cycle itself, and the
-- monitor are not stored; they must be passed to 'fromSavedChain' again.
toSavedChain ::
  Chain a ->
  IO (SavedChain a)
toSavedChain (Chain cId cLink cIter cTrace cAcc cGen _ _ _ cCycle _) = do
  seedVec <- saveGen cGen
  frozenTrace <- freezeT cTrace
  let props = ccProposals cCycle
      -- Proposals themselves cannot be stored, so key by cycle position.
      savedAcc = transformKeysA props [0 ..] cAcc
      tunings = map (fmap tParam . prTuner) props
  pure $ SavedChain cId cLink cIter frozenTrace savedAcc seedVec tunings

-- | Load a saved chain.
--
-- The prior and the likelihood of the saved link are recomputed with the
-- provided functions and compared to the saved values, because the functions
-- may have changed. Of course, we cannot test whether the functions are the
-- same, but having the same prior and likelihood at the last state is already
-- a good indicator.
--
-- Acceptance records and tuning parameters are matched to the proposals of the
-- provided cycle by position. The cycle must therefore contain the same number
-- of proposals, in the same order, as the cycle of the saved chain.
--
-- Calls 'error' if the recomputed prior or likelihood differs from the saved
-- value, or if the number of proposals in the provided cycle differs from the
-- number of saved tuning parameters.
fromSavedChain ::
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  SavedChain a ->
  IO (Chain a)
fromSavedChain pr lh cc mn (SavedChain ci it i tr ac' g' ts)
  | pr x /= prior it = mismatch "prior" (prior it) (pr x)
  | lh x /= likelihood it = mismatch "likelihood" (likelihood it) (lh x)
  | nGiven /= nSaved =
      -- Without this check, 'zip' below would silently misassign or drop
      -- tuning parameters when the cycle changed.
      error $
        "fromSavedChain: The provided cycle has "
          <> show nGiven
          <> " proposals, but the saved chain has "
          <> show nSaved
          <> "."
  | otherwise = do
      g <- loadGen g'
      tr' <- thawT tr
      return $ Chain ci it i tr' ac g i pr lh cc' mn
  where
    -- State of the last saved link.
    x = state it
    ps = ccProposals cc
    nGiven = length ps
    -- One (possibly absent) tuning parameter was saved per proposal.
    nSaved = length ts
    -- Report a mismatch between a saved value and the value recomputed with
    -- the provided function.
    mismatch what saved given =
      error $
        unlines
          [ "fromSavedChain: Provided " <> what <> " function does not match the saved " <> what <> ".",
            "fromSavedChain: Saved " <> what <> ": " <> show saved <> ".",
            "fromSavedChain: Given " <> what <> ": " <> show given <> "."
          ]
    -- Convert acceptance keys back from cycle positions to proposals.
    ac = transformKeysA [0 ..] ps ac'
    -- Restore the saved tuning parameter of each proposal; proposals without a
    -- saved parameter are set to 1.0.
    cc' = tuneCycle (M.map (const . fromMaybe 1.0) $ M.fromList $ zip ps ts) cc