-- |
-- Module      :  Mcmc.Chain.Trace
-- Description :  History of a Markov chain
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Wed May 20 09:11:25 2020.
module Mcmc.Chain.Trace
  ( Trace,
    replicateT,
    fromVectorT,
    pushT,
    headT,
    takeT,
    freezeT,
    thawT,
  )
where

import Control.Monad.Primitive
import qualified Data.Stack.Circular as C
import qualified Data.Vector as VB
import Mcmc.Chain.Link

-- NOTE: We refer directly to the 'PrimState' 'RealWorld' because, otherwise,
-- we would have two type variables in 'Chain m a'.

-- | A 'Trace' is a mutable circular stack of 'Link's, that is, of states with
-- associated priors and likelihoods.
--
-- Being a circular stack, a 'Trace' has a fixed maximum size, which is set on
-- creation (see 'replicateT' and 'fromVectorT').
newtype Trace a = Trace
  { -- | The underlying mutable circular stack, living in the 'IO' state thread
    -- ('RealWorld').
    fromTrace :: C.MStack VB.Vector RealWorld (Link a)
  }

-- | Initialize a trace of given length by replicating the same link.
--
-- Be careful not to compute summary statistics before pushing enough values;
-- until then, the trace contains copies of the initial link.
--
-- Call 'error' if the maximum size is zero or negative (see 'C.replicate').
replicateT :: Int -> Link a -> IO (Trace a)
replicateT n l = Trace <$> C.replicate n l

-- | Create a trace from a vector of links (see 'C.fromVector').
--
-- The maximum size of the trace is the length of the vector.
--
-- Call 'error' if the vector is empty.
fromVectorT :: VB.Vector (Link a) -> IO (Trace a)
fromVectorT xs = Trace <$> C.fromVector xs

-- | Push a 'Link' on the 'Trace' (see 'C.push').
--
-- Returns a 'Trace' wrapping the stack that 'C.push' returns.
pushT :: Link a -> Trace a -> IO (Trace a)
pushT x t = Trace <$> C.push x (fromTrace t)
{-# INLINEABLE pushT #-}

-- | Get the most recent link of the trace (see 'C.get').
headT :: Trace a -> IO (Link a)
headT = C.get . fromTrace
{-# INLINEABLE headT #-}

-- | Get the @k@ most recent links of the trace (see 'C.take').
--
-- NOTE(review): behavior when @k@ is negative or exceeds the maximum size is
-- determined by 'C.take'; confirm against the @circular@ documentation.
takeT :: Int -> Trace a -> IO (VB.Vector (Link a))
takeT k = C.take k . fromTrace

-- | Freeze the mutable trace into an immutable circular stack, for example
-- for storage (see 'C.freeze').
freezeT :: Trace a -> IO (C.Stack VB.Vector (Link a))
freezeT = C.freeze . fromTrace

-- | Thaw an immutable circular stack into a mutable trace (see 'C.thaw').
--
-- This is the counterpart of 'freezeT'.
thawT :: C.Stack VB.Vector (Link a) -> IO (Trace a)
thawT t = Trace <$> C.thaw t