{-# LANGUAGE RankNTypes #-}

-- |
-- Module      :  Mcmc.Monitor.ParameterBatch
-- Description :  Batch monitor parameters
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Fri May 29 11:11:49 2020.
--
-- A batch monitor prints summary statistics of a parameter collected over a
-- specific number of last iterations. The functions provided in this module
-- calculate the mean of the monitored parameter. However, custom batch monitors
-- can use more complex functions.
module Mcmc.Monitor.ParameterBatch
  ( -- * Batch parameter monitors
    MonitorParameterBatch (..),
    (>$<),
    monitorBatchMean,
    monitorBatchMeanF,
    monitorBatchMeanE,
  )
where

import qualified Data.ByteString.Builder as BB
import qualified Data.Double.Conversion.ByteString as BC
import Data.Functor.Contravariant
import qualified Data.Vector as VB

-- | Instruction about a parameter to monitor via batch means. Usually, the
-- monitored parameter is averaged over the batch size. However, arbitrary
-- functions performing more complicated analyses on the states in the batch can
-- be provided.
--
-- Convert a batch monitor from one data type to another with '(>$<)'.
--
-- For example, batch monitor the mean of the first entry of a tuple:
--
-- @
-- mon = fst >$< monitorBatchMean "First entry"
-- @
--
-- Batch monitors may be slow because the monitored parameter has to be
-- extracted from the state for each iteration.
data MonitorParameterBatch a = MonitorParameterBatch
  { -- | Name of batch monitored parameter.
    mbpName :: String,
    -- | For a given batch, extract the summary statistics.
    mbpFunc :: VB.Vector a -> BB.Builder
  }

-- | Adapt a batch monitor to a different input type: every element of the
-- batch is converted with the given function before the summary function of
-- the wrapped monitor is applied. The name of the monitor is left unchanged.
instance Contravariant MonitorParameterBatch where
  contramap f (MonitorParameterBatch n m) = MonitorParameterBatch n (m . VB.map f)

-- Arithmetic mean of a batch.
--
-- The elements are summed in their own type; only the resulting sum is
-- converted to 'Double' and divided by the batch length. For an empty batch
-- this evaluates @0 / 0@, that is, NaN.
mean :: Real a => VB.Vector a -> Double
mean xs = realToFrac (VB.sum xs) / fromIntegral (VB.length xs)
{-# SPECIALIZE mean :: VB.Vector Double -> Double #-}
{-# SPECIALIZE mean :: VB.Vector Int -> Double #-}

-- | Batch mean monitor.
--
-- Print the mean of the batch in fixed notation with eight digits after the
-- decimal point.
monitorBatchMean ::
  Real a =>
  -- | Name.
  String ->
  MonitorParameterBatch a
monitorBatchMean n = MonitorParameterBatch n (BB.byteString . BC.toFixed 8 . mean)
{-# SPECIALIZE monitorBatchMean :: String -> MonitorParameterBatch Int #-}
{-# SPECIALIZE monitorBatchMean :: String -> MonitorParameterBatch Double #-}

-- | Batch mean monitor.
--
-- Print the mean of the batch with full precision, computing the shortest
-- string of digits that correctly represents the number.
monitorBatchMeanF ::
  Real a =>
  -- | Name.
  String ->
  MonitorParameterBatch a
monitorBatchMeanF n = MonitorParameterBatch n (BB.byteString . BC.toShortest . mean)
{-# SPECIALIZE monitorBatchMeanF :: String -> MonitorParameterBatch Int #-}
{-# SPECIALIZE monitorBatchMeanF :: String -> MonitorParameterBatch Double #-}

-- | Batch mean monitor.
--
-- Print the mean of the batch in scientific (exponential) notation with eight
-- digits after the decimal point.
monitorBatchMeanE ::
  Real a =>
  -- | Name.
  String ->
  MonitorParameterBatch a
monitorBatchMeanE n = MonitorParameterBatch n (BB.byteString . BC.toExponential 8 . mean)
{-# SPECIALIZE monitorBatchMeanE :: String -> MonitorParameterBatch Int #-}
{-# SPECIALIZE monitorBatchMeanE :: String -> MonitorParameterBatch Double #-}