-- |
-- Module      :  Mcmc.Chain.Trace
-- Description :  History of a Markov chain
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Wed May 20 09:11:25 2020.
module Mcmc.Chain.Trace
  ( Trace,
    replicateT,
    fromVectorT,
    pushT,
    headT,
    takeT,
    freezeT,
    thawT,
  )
where

import Control.Monad.Primitive
import qualified Data.Stack.Circular as C
import qualified Data.Vector as VB
import Mcmc.Chain.Link

-- NOTE: We directly refer to the 'PrimState' 'RealWorld' because, otherwise, we
-- would have two type variables in 'Chain m a'.

-- | A 'Trace' is a mutable circular stack of 'Link's, that is, of states with
-- associated priors and likelihoods.
--
-- The state token of the underlying stack is fixed to 'RealWorld', so all
-- operations on a 'Trace' run in 'IO'. Only the most recent links are
-- accessible (see 'headT' and 'takeT').
newtype Trace a = Trace {fromTrace :: C.MStack VB.Vector RealWorld (Link a)}

-- | Initialize a trace of given length by replicating the same value.
--
-- Be careful not to compute summary statistics before pushing enough values:
-- until then, the trace still contains copies of the initial link.
--
-- Call 'error' if the maximum size is zero or negative.
replicateT ::
  -- | Maximum size of the trace.
  Int ->
  -- | Link that is replicated to fill the trace.
  Link a ->
  IO (Trace a)
replicateT n l = Trace <$> C.replicate n l

-- | Create a trace from a vector. The length (maximum size) of the trace is
-- the length of the vector.
--
-- Call 'error' if the vector is empty.
fromVectorT :: VB.Vector (Link a) -> IO (Trace a)
fromVectorT xs = Trace <$> C.fromVector xs

-- | Push a 'Link' on the 'Trace'.
--
-- Returns the 'Trace' wrapping the stack produced by 'C.push'.
pushT :: Link a -> Trace a -> IO (Trace a)
pushT x t = Trace <$> C.push x (fromTrace t)
{-# INLINEABLE pushT #-}

-- | Get the most recent link of the trace (see 'C.get').
headT :: Trace a -> IO (Link a)
headT = C.get . fromTrace
{-# INLINEABLE headT #-}

-- | Get the @k@ most recent links of the trace (see 'C.take').
takeT ::
  -- | Number of links @k@ to retrieve.
  Int ->
  Trace a ->
  IO (VB.Vector (Link a))
takeT k = C.take k . fromTrace

-- | Freeze the mutable trace into an immutable circular stack, for example,
-- for storage (see 'C.freeze').
freezeT :: Trace a -> IO (C.Stack VB.Vector (Link a))
freezeT = C.freeze . fromTrace

-- | Thaw an immutable circular stack into a mutable 'Trace' (see 'C.thaw').
--
-- Use this to restore a trace that was frozen with 'freezeT'.
thawT :: C.Stack VB.Vector (Link a) -> IO (Trace a)
thawT t = Trace <$> C.thaw t