{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}

-- |
-- Module      :  Mcmc.Save
-- Description :  Save the state of a Markov chain
-- Copyright   :  (c) Dominik Schrempf, 2020
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Tue Jun 16 10:18:54 2020.
--
-- Save and load an MCMC run. It is easy to save and restore the current state and
-- likelihood (or the trace), but it is not feasible to store all the proposals and so
-- on, so they have to be provided again when continuing a run.
module Mcmc.Save
  ( saveStatus,
    loadStatus,
  )
where

import Codec.Compression.GZip
import Control.Monad
import Data.Aeson
import Data.Aeson.TH
import qualified Data.ByteString.Lazy.Char8 as BL
import Data.List hiding (cycle)
import qualified Data.Map as M
import Data.Maybe
import Data.Vector.Unboxed (Vector)
import Data.Word
-- TODO: Splitmix. Remove as soon as split mix is used and is available with the
-- statistics package.
import Mcmc.Item
import Mcmc.Monitor
import Mcmc.Proposal
import Mcmc.Status hiding (save)
import Mcmc.Trace
import Mcmc.Verbosity
import Numeric.Log
import System.Directory
import System.IO.Unsafe (unsafePerformIO)
import System.Random.MWC
import Prelude hiding (cycle)

-- | The part of a 'Status' that can be written to disk.
--
-- The constructor is positional on purpose. The JSON instance derived below
-- uses 'defaultOptions', so the fields are encoded in declaration order. The
-- order and types of the fields must therefore stay fixed, or previously saved
-- files will no longer decode.
data Save a
  = Save
      -- Chain variables.
      --
      -- Name of the chain.
      String
      -- Current state, together with its prior and likelihood.
      (Item a)
      -- Iteration.
      Int
      -- Trace (possibly truncated when saving).
      (Trace a)
      -- Acceptance counts, keyed by the position of the proposal in the cycle.
      (Acceptance Int)
      -- Burn in.
      (Maybe Int)
      -- Auto tune.
      (Maybe Int)
      -- Iterations.
      Int
      -- Force.
      Bool
      -- Save.
      (Maybe Int)
      -- Verbosity.
      Verbosity
      -- Seed of the random number generator at the time of saving.
      (Vector Word32)
      -- Algorithm variables.
      --
      -- Tuning parameters, in the order of the proposals in the cycle.
      [Maybe Double]

deriveJSON defaultOptions ''Save

toSave :: Status a -> Save a
toSave :: Status a -> Save a
toSave (Status String
nm Item a
it Int
i Trace a
tr Acceptance (Proposal a)
ac Maybe Int
br Maybe Int
at Int
is Bool
f Maybe Int
sv Verbosity
vb GenIO
g Maybe (Int, UTCTime)
_ Maybe Handle
_ a -> Log Double
_ a -> Log Double
_ Maybe (Cleaner a)
_ Cycle a
c Monitor a
_) =
  String
-> Item a
-> Int
-> Trace a
-> Acceptance Int
-> Maybe Int
-> Maybe Int
-> Int
-> Bool
-> Maybe Int
-> Verbosity
-> Vector Word32
-> [Maybe Double]
-> Save a
forall a.
String
-> Item a
-> Int
-> Trace a
-> Acceptance Int
-> Maybe Int
-> Maybe Int
-> Int
-> Bool
-> Maybe Int
-> Verbosity
-> Vector Word32
-> [Maybe Double]
-> Save a
Save
    String
nm
    Item a
it
    Int
i
    Trace a
tr'
    Acceptance Int
ac'
    Maybe Int
br
    Maybe Int
at
    Int
is
    Bool
f
    Maybe Int
sv
    Verbosity
vb
    Vector Word32
g'
    [Maybe Double]
ts
  where
    tr' :: Trace a
tr' = Int -> Trace a -> Trace a
forall a. Int -> Trace a -> Trace a
takeT (Int -> Maybe Int -> Int
forall a. a -> Maybe a -> a
fromMaybe Int
0 Maybe Int
sv) Trace a
tr
    ac' :: Acceptance Int
ac' = [Proposal a] -> [Int] -> Acceptance (Proposal a) -> Acceptance Int
forall k1 k2.
(Ord k1, Ord k2) =>
[k1] -> [k2] -> Acceptance k1 -> Acceptance k2
transformKeysA (Cycle a -> [Proposal a]
forall a. Cycle a -> [Proposal a]
ccProposals Cycle a
c) [Int
0 ..] Acceptance (Proposal a)
ac
    -- TODO: Splitmix. Remove as soon as split mix is used and is available with
    -- the statistics package.
    g' :: Vector Word32
g' = Seed -> Vector Word32
fromSeed (Seed -> Vector Word32) -> Seed -> Vector Word32
forall a b. (a -> b) -> a -> b
$ IO Seed -> Seed
forall a. IO a -> a
unsafePerformIO (IO Seed -> Seed) -> IO Seed -> Seed
forall a b. (a -> b) -> a -> b
$ GenIO -> IO Seed
forall (m :: * -> *). PrimMonad m => Gen (PrimState m) -> m Seed
save GenIO
g
    ts :: [Maybe Double]
ts = [(Tuner a -> Double) -> Maybe (Tuner a) -> Maybe Double
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Tuner a -> Double
forall a. Tuner a -> Double
tParam Maybe (Tuner a)
mt | Maybe (Tuner a)
mt <- (Proposal a -> Maybe (Tuner a))
-> [Proposal a] -> [Maybe (Tuner a)]
forall a b. (a -> b) -> [a] -> [b]
map Proposal a -> Maybe (Tuner a)
forall a. Proposal a -> Maybe (Tuner a)
pTuner ([Proposal a] -> [Maybe (Tuner a)])
-> [Proposal a] -> [Maybe (Tuner a)]
forall a b. (a -> b) -> a -> b
$ Cycle a -> [Proposal a]
forall a. Cycle a -> [Proposal a]
ccProposals Cycle a
c]

-- | Save a 'Status' to file.
--
-- Some important values have to be provided upon restoring the status. See
-- 'loadStatus'.
saveStatus :: ToJSON a => FilePath -> Status a -> IO ()
saveStatus :: String -> Status a -> IO ()
saveStatus String
fn Status a
s = String -> ByteString -> IO ()
BL.writeFile String
fn (ByteString -> IO ()) -> ByteString -> IO ()
forall a b. (a -> b) -> a -> b
$ ByteString -> ByteString
compress (ByteString -> ByteString) -> ByteString -> ByteString
forall a b. (a -> b) -> a -> b
$ Save a -> ByteString
forall a. ToJSON a => a -> ByteString
encode (Status a -> Save a
forall a. Status a -> Save a
toSave Status a
s)

-- | Rebuild a 'Status' from a 'Save', given the values that cannot be stored.
--
-- Arguments, in order: prior function, likelihood function, cycle, monitor,
-- cleaner, and the saved values.
--
-- The given cycle is re-tuned with the saved tuning parameters, which are
-- matched to the proposals by their position in the cycle. The acceptance
-- counts are re-keyed from proposal indices back to the proposals of the given
-- cycle. The @Maybe (Int, UTCTime)@ and @Maybe Handle@ fields of the status
-- are set to 'Nothing'.
--
-- This runs in 'IO' because restoring the MWC generator allocates a new
-- mutable generator. Doing that with 'unsafePerformIO' inside a pure value is
-- unsafe (the allocation may be shared or duplicated by the optimizer).
fromSave ::
  (a -> Log Double) ->
  (a -> Log Double) ->
  Cycle a ->
  Monitor a ->
  Maybe (Cleaner a) ->
  Save a ->
  IO (Status a)
fromSave pr lh cc m cl (Save nm it i tr ac' br at is f sv vb g' ts) = do
  -- TODO: Splitmix. Simplify as soon as split mix is used and is available
  -- with the statistics package.
  g <- restore $ toSeed g'
  return $
    Status nm it i tr ac br at is f sv vb g Nothing Nothing pr lh cl cc' m
  where
    -- Map proposal indices back to the proposals of the provided cycle.
    ac = transformKeysA [0 ..] (ccProposals cc) ac'
    -- Only proposals with a saved tuning parameter are re-tuned.
    cc' = tuneCycle (M.mapMaybe id $ M.fromList $ zip (ccProposals cc) ts) cc

-- | Load a 'Status' from file.
--
-- Important information that cannot be saved and has to be provided again when
-- a chain is restored:
-- - prior function
-- - likelihood function
-- - cleaning function
-- - cycle
-- - monitor
--
-- The prior and likelihood of the saved state are recomputed with the provided
-- functions. An error is raised if the file cannot be decoded, or if either
-- value differs from the saved one.
--
-- To avoid incomplete continued runs, the @.mcmc@ file is removed after a
-- successful load. It is only removed once all checks have passed.
loadStatus ::
  FromJSON a =>
  -- | Prior function.
  (a -> Log Double) ->
  -- | Likelihood function.
  (a -> Log Double) ->
  Cycle a ->
  Monitor a ->
  -- | Cleaner, if needed.
  Maybe (Cleaner a) ->
  -- | Path of status to load.
  FilePath ->
  IO (Status a)
loadStatus pr lh cc mn cl fn = do
  res <- eitherDecode . decompress <$> BL.readFile fn
  -- Fail right here on a decoding error instead of when the status is first
  -- inspected.
  sv <- either error return res
  s <- fromSave pr lh cc mn cl sv
  let Item x svp svl = item s
  -- Recompute and check the prior and likelihood for the last state because
  -- the functions may have changed. Of course, we cannot test for the same
  -- function, but having the same prior and likelihood at the last state is
  -- already a good indicator.
  when (pr x /= svp) $
    error "loadStatus: Provided prior function does not match the saved prior."
  when (lh x /= svl) $
    error "loadStatus: Provided likelihood function does not match the saved likelihood."
  removeFile fn
  return s