{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}

-- |
-- Module      :  Mcmc.Chain.Save
-- Description :  Save and load a Markov chain
-- Copyright   :  (c) Dominik Schrempf, 2020
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Tue Jun 16 10:18:54 2020.
--
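-- Save and load chains. The chain index, the current link (state, prior, and
-- likelihood), the iteration, the trace, the acceptance rates, the state of the
-- random number generator, and the tuning parameters of the proposals are
-- stored. Functions such as the prior, the likelihood, the proposals, and the
-- monitor cannot be stored, so they have to be provided again when continuing a
-- run.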
module Mcmc.Chain.Save
  ( SavedChain (..),
    toSavedChain,
    fromSavedChain,
  )
where

import Control.Monad
import Data.Aeson
import Data.Aeson.TH
import Data.List hiding (cycle)
import qualified Data.Map as M
import Data.Maybe
import qualified Data.Vector as VB
import qualified Data.Vector.Unboxed as VU
import Data.Word
import Mcmc.Chain.Chain
import Mcmc.Chain.Link
import Mcmc.Chain.Trace
import Mcmc.Internal.Random
import Mcmc.Monitor
import Mcmc.Proposal
import Prelude hiding (cycle)
import qualified Data.Stack.Circular as C

-- | Storable values of a Markov chain.
--
-- Functions (prior, likelihood, proposals, and monitor) cannot be stored and
-- are not part of this type. See 'toSavedChain' and 'fromSavedChain'.
data SavedChain a = SavedChain
  { -- | Index of the chain.
    savedId :: Int,
    -- | The current link of the chain (state, prior, and likelihood).
    savedLink :: Link a,
    -- | The current iteration.
    savedIteration :: Int,
    -- | The frozen trace of the chain.
    savedTrace :: C.Stack VB.Vector (Link a),
    -- | Acceptance rates. Proposals cannot be stored, so they are identified
    -- by their index in the proposal cycle.
    savedAcceptance :: Acceptance Int,
    -- | Saved state of the random number generator.
    savedSeed :: VU.Vector Word32,
    -- | Tuning parameters of the proposals in cycle order; 'Nothing' for
    -- proposals without a tuner.
    savedTuningParameters :: [Maybe Double]
  }
  deriving (Eq, Read, Show)

$(deriveJSON defaultOptions ''SavedChain)

-- | Save a chain.
--
-- The trace is frozen and the state of the random number generator is saved.
-- Acceptance rates are re-keyed from the proposals to their indices in the
-- proposal cycle. The tuning parameters of the proposals are stored in cycle
-- order. The remaining fields of the chain, including the prior function, the
-- likelihood function, and the monitor, are not stored.
toSavedChain ::
  Chain a ->
  IO (SavedChain a)
toSavedChain (Chain ci it i tr ac g _ _ _ cc _) = do
  g' <- saveGen g
  tr' <- freezeT tr
  pure $ SavedChain ci it i tr' ac' g' ts
  where
    ps = ccProposals cc
    -- Proposals themselves cannot be stored; identify them by their position
    -- in the cycle instead. 'fromSavedChain' performs the inverse mapping.
    ac' = transformKeysA ps [0 ..] ac
    -- 'Nothing' for proposals without a tuner.
    ts = map (fmap tParam . pTuner) ps

-- | Load a saved chain.
--
-- Recompute and check the prior and likelihood for the last state because the
-- functions may have changed. Of course, we cannot test for the same function,
-- but having the same prior and likelihood at the last state is already a good
-- indicator.
--
-- Acceptance rates and tuning parameters are stored by position in the
-- proposal cycle. The given cycle must therefore contain the same proposals in
-- the same order as the cycle of the saved chain. Only the number of proposals
-- can be checked. For each proposal, the tuning function passed to
-- 'tuneCycle' returns the saved tuning parameter, or 1.0 if none was saved.
--
-- Calls 'error' if the prior or the likelihood of the last state does not
-- match the saved value. It also calls 'error' if the number of proposals in
-- the given cycle differs from the number of saved tuning parameters.
fromSavedChain ::
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  SavedChain a ->
  IO (Chain a)
fromSavedChain pr lh cc mn (SavedChain ci it i tr ac' g' ts)
  | pr (state it) /= prior it =
    error "fromSavedChain: Provided prior function does not match the saved prior."
  | lh (state it) /= likelihood it =
    error "fromSavedChain: Provided likelihood function does not match the saved likelihood."
  -- Without this check, 'zip' below would silently drop proposals, and the
  -- index-keyed acceptance rates would be attributed to the wrong proposals.
  | length ps /= length ts =
    error "fromSavedChain: Provided cycle has a different number of proposals than the saved chain."
  | otherwise = do
    g <- loadGen g'
    tr' <- thawT tr
    -- The saved iteration @i@ is passed twice: as the current iteration and,
    -- presumably, as the iteration at which the continued run starts
    -- (TODO confirm against the field order of 'Chain').
    pure $ Chain ci it i tr' ac g i pr lh cc' mn
  where
    ps = ccProposals cc
    -- Acceptance rates were saved with proposal indices as keys; map them back.
    ac = transformKeysA [0 ..] ps ac'
    -- Constant tuning function: the saved parameter, or 1.0 if none was saved.
    getTuningF mt = const (fromMaybe 1.0 mt)
    cc' = tuneCycle (M.map getTuningF $ M.fromList $ zip ps ts) cc