{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
module Mcmc.Chain.Save
( SavedChain (..),
toSavedChain,
fromSavedChain,
fromSavedChainUnsafe,
)
where
import Control.Monad
import Data.Aeson
import Data.Aeson.TH
import Data.List hiding (cycle)
import Data.Maybe
import qualified Data.Stack.Circular as C
import qualified Data.Vector as VB
import Data.Word
import Mcmc.Acceptance
import Mcmc.Chain.Chain
import Mcmc.Chain.Link
import Mcmc.Chain.Trace
import Mcmc.Cycle
import Mcmc.Internal.Random
import Mcmc.Likelihood
import Mcmc.Monitor
import Mcmc.Prior
import Mcmc.Proposal
import Prelude hiding (cycle)
-- | A serializable snapshot of a Markov chain.
--
-- A 'Chain' carries functions (prior, likelihood, proposals, monitor) and
-- mutable state that cannot be written out directly. 'SavedChain' keeps only
-- plain data: the trace is stored as an immutable stack, the generator as a
-- pair of 'Word64', and proposal-specific values are indexed by the position
-- of the proposal in the cycle.
data SavedChain a = SavedChain
  { savedLink :: Link a
  -- ^ The current link of the chain.
  , savedIteration :: Int
  -- ^ The iteration counter of the chain.
  , savedTrace :: C.Stack VB.Vector (Link a)
  -- ^ Immutable copy of the trace.
  , savedAcceptances :: Acceptances Int
  -- ^ Acceptance counts keyed by proposal index within the cycle.
  , savedSeed :: (Word64, Word64)
  -- ^ Stored state of the random number generator.
  , savedTuningParameters :: [Maybe (TuningParameter, AuxiliaryTuningParameters)]
  -- ^ Tuning parameters per proposal, in cycle order; 'Nothing' for
  -- proposals without a tuner.
  }
  deriving (Show, Eq)

-- JSON instances; with 'defaultOptions' the object keys are the record field
-- names above, so renaming a field changes the on-disk format.
$(deriveJSON defaultOptions ''SavedChain)
-- | Freeze a running 'Chain' into a 'SavedChain' suitable for serialization.
--
-- The random generator and the mutable trace are frozen (in that order).
-- Acceptance counts and tuning parameters are re-keyed by the position of
-- each proposal in the cycle, because proposals themselves cannot be stored.
-- The remaining chain fields (the second 'Int' field, the prior and
-- likelihood functions, and the monitor) are not saved.
toSavedChain ::
  Chain a ->
  IO (SavedChain a)
toSavedChain (Chain lnk iter mutTrace accs gen _ _ _ cyc _) = do
  seed <- saveGen gen
  frozenTrace <- freezeT mutTrace
  pure (SavedChain lnk iter frozenTrace accsByIndex seed tuning)
  where
    proposals = ccProposals cyc
    -- Replace each proposal key by its index 0, 1, 2, ... in the cycle.
    accsByIndex = transformKeysA (zip proposals [0 ..]) accs
    -- One entry per proposal; 'Nothing' when the proposal has no tuner.
    tuning = map (fmap tunerParams . prTuner) proposals
    tunerParams t = (tTuningParameter t, tAuxiliaryTuningParameters t)
-- | Restore a 'Chain' from a 'SavedChain', checking consistency first.
--
-- The following checks are made, in order:
--
-- * the given prior function, evaluated at the saved state, reproduces the
--   prior stored in the saved link;
-- * likewise for the likelihood function;
-- * the number of saved acceptance entries equals the number of proposals in
--   the given cycle;
-- * the number of saved tuning-parameter entries equals the number of
--   proposals in the given cycle (otherwise proposals would be silently
--   dropped or left untuned on restore).
--
-- If a check fails, the result is an 'error' whose message describes the
-- mismatch. Otherwise this behaves like 'fromSavedChainUnsafe'.
fromSavedChain ::
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  SavedChain a ->
  IO (Chain a)
fromSavedChain pr lh cc mn sv
  | pr x /= prior it =
      failWith
        [ "Provided prior function does not match the saved prior.",
          "Current prior: " <> show (prior it) <> ".",
          "Given prior: " <> show (pr x) <> "."
        ]
  | lh x /= likelihood it =
      failWith
        [ "Provided likelihood function does not match the saved likelihood function.",
          "Current likelihood: " <> show (likelihood it) <> ".",
          "Given likelihood: " <> show (lh x) <> "."
        ]
  | nSaved /= nGiven =
      failWith
        [ "The number of proposals does not match.",
          "Number of saved proposals: " <> show nSaved <> ".",
          "Number of given proposals: " <> show nGiven <> "."
        ]
  | nTuning /= nGiven =
      failWith
        [ "The number of saved tuning parameters does not match the number of proposals.",
          "Number of saved tuning parameters: " <> show nTuning <> ".",
          "Number of given proposals: " <> show nGiven <> "."
        ]
  | otherwise = fromSavedChainUnsafe pr lh cc mn sv
  where
    it = savedLink sv
    x = state it
    nSaved = length (fromAcceptances (savedAcceptances sv))
    nGiven = length (ccProposals cc)
    nTuning = length (savedTuningParameters sv)
    -- Prefix every line with this function's name so the origin of the
    -- failure is obvious in the final message.
    failWith :: [String] -> b
    failWith ls = error (unlines (map ("fromSavedChain: " <>) ls))
-- | Restore a 'Chain' from a 'SavedChain' without consistency checks.
--
-- The generator is reloaded from the saved seed and the trace is thawed.
-- Acceptance counts are re-keyed from proposal indices back to the proposals
-- of the given cycle. Each proposal is re-tuned with its stored tuning
-- parameters, if any. The saved iteration is used both as the iteration
-- counter and as the sixth ('Int') 'Chain' field.
--
-- Calls 'error' if a proposal has stored tuning parameters but
-- 'tuneWithTuningParameters' rejects them.
fromSavedChainUnsafe ::
  PriorFunction a ->
  LikelihoodFunction a ->
  Cycle a ->
  Monitor a ->
  SavedChain a ->
  IO (Chain a)
fromSavedChainUnsafe pr lh cc mn (SavedChain it i tr ac' g' ts) = do
  g <- loadGen g'
  tr' <- thawT tr
  return $ Chain it i tr' ac g i pr lh cc' mn
  where
    ps = ccProposals cc
    -- Map proposal index 0, 1, 2, ... back to the proposal itself.
    ac = transformKeysA (zip [0 ..] ps) ac'
    tunePs mt p = case mt of
      Nothing -> p
      Just (x, xs) ->
        either (error . (<> err)) id $ tuneWithTuningParameters x xs p
    -- Appended to the message returned by 'tuneWithTuningParameters'. This
    -- must be a plain 'String': defining it via 'error' made the reported
    -- message itself contain a bottom value.
    err :: String
    err = "\nfromSavedChain: Proposal with stored tuning parameters is not tunable."
    -- Pad with 'Nothing' so that a short list of tuning parameters never
    -- drops proposals from the cycle ('zipWith' truncates to the shorter list).
    cc' = cc {ccProposals = zipWith tunePs (ts ++ repeat Nothing) ps}