{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}

-- |
-- Module      :  Mcmc.Cycle
-- Description :  A cycle is a list of proposals
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- Creation date: Thu Jul  8 17:56:03 2021.
module Mcmc.Cycle
  ( -- * Cycles
    Order (..),
    Cycle (ccProposals, ccRequireTrace, ccHasIntermediateTuners),
    cycleFromList,
    setOrder,
    IterationMode (..),
    prepareProposals,
    autoTuneCycle,

    -- * Output
    summarizeCycle,
  )
where

import Control.Applicative
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Lazy.Char8 as BL
import Data.List
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Vector as VB
import Mcmc.Acceptance
import Mcmc.Internal.Shuffle
import Mcmc.Proposal
import System.Random.Stateful

-- | Execution order of the 'Proposal's in a 'Cycle'. The total number of
-- 'Proposal's executed per 'Cycle' may differ between 'Order's (e.g., compare
-- 'RandomO' and 'RandomReversibleO').
data Order
  = -- | Shuffle the 'Proposal's in the 'Cycle'. The 'Proposal's are replicated
    -- according to their weights and executed in random order. A 'Proposal'
    -- with weight @w@ is executed exactly @w@ times per iteration.
    RandomO
  | -- | Execute the 'Proposal's sequentially, in the order they appear in the
    -- 'Cycle'. A 'Proposal' with weight @w>1@ is repeated immediately @w@
    -- times (the repetitions are not appended to the end of the list).
    SequentialO
  | -- | Similar to 'RandomO'. However, a reversed copy of the list of shuffled
    -- 'Proposal's is appended such that the resulting Markov chain is
    -- reversible.
    --
    -- Note: the total number of 'Proposal's executed per cycle is twice the
    -- number of 'RandomO'.
    RandomReversibleO
  | -- | Similar to 'SequentialO'. However, a reversed copy of the list of
    -- sequentially ordered 'Proposal's is appended such that the resulting
    -- Markov chain is reversible.
    SequentialReversibleO
  deriving (Eq, Show)

-- Describe the order as human-readable text.
--
-- For the two reversible orders, the description of the corresponding
-- non-reversible order is reused and followed, on a new line, by a sentence
-- about the appended reversed copy.
describeOrder :: Order -> BL.ByteString
describeOrder RandomO = "The proposals are executed in random order."
describeOrder SequentialO = "The proposals are executed sequentially."
describeOrder RandomReversibleO =
  BL.intercalate
    "\n"
    [ describeOrder RandomO,
      "A reversed copy of the shuffled proposals is appended to ensure reversibility."
    ]
describeOrder SequentialReversibleO =
  BL.intercalate
    "\n"
    [ describeOrder SequentialO,
      "A reversed copy of the sequential proposals is appended to ensure reversibility."
    ]

-- | In brief, a 'Cycle' is a list of proposals.
--
-- The state of the Markov chain will be logged only after all 'Proposal's in
-- the 'Cycle' have been completed, and the iteration counter will be increased
-- by one. The order in which the 'Proposal's are executed is specified by
-- 'Order'. The default is 'RandomO'.
--
-- No proposals with the same name and description are allowed in a 'Cycle', so
-- that they can be uniquely identified.
data Cycle a = Cycle
  { -- | The proposals of the cycle, as given to 'cycleFromList' (not yet
    -- replicated according to their weights).
    ccProposals :: [Proposal a],
    -- | The order in which the proposals are executed; see 'setOrder'.
    ccOrder :: Order,
    -- | Does the cycle require the trace when auto tuning? See 'tRequireTrace'.
    ccRequireTrace :: Bool,
    -- | Does the cycle include proposals that can be tuned every iteration?
    -- See 'tSuitableForIntermediateTuning'.
    ccHasIntermediateTuners :: Bool
  }

-- | Create a 'Cycle' from a list of 'Proposal's; use 'RandomO', but see 'setOrder'.
--
-- The flags 'ccRequireTrace' and 'ccHasIntermediateTuners' are set if any
-- proposal has a tuner with 'tRequireTrace' or
-- 'tSuitableForIntermediateTuning' set, respectively.
--
-- Call 'error' if the list is empty, or if the list contains duplicates
-- (proposals that compare equal with '=='). In the latter case, the names and
-- descriptions of the offending proposals are listed in the error message.
cycleFromList :: [Proposal a] -> Cycle a
cycleFromList [] =
  error "cycleFromList: Received an empty list but cannot create an empty Cycle."
cycleFromList xs
  | length uniqueXs == length xs =
      Cycle xs RandomO (any needsTrace xs) (any isIntermediate xs)
  | otherwise =
      error $ "\n" ++ msg ++ "cycleFromList: Proposals are not unique."
  where
    uniqueXs = nub xs
    -- The surplus copies, i.e., what 'nub' has dropped.
    removedXs = xs \\ uniqueXs
    -- One line per surplus proposal: shown name, space, shown description.
    describeRemoved p = show (prName p) ++ " " ++ show (prDescription p)
    msg = unlines $ map describeRemoved removedXs
    -- Proposals without tuner never satisfy a tuner predicate.
    tunerSatisfies f p = maybe False f (prTuner p)
    needsTrace = tunerSatisfies tRequireTrace
    isIntermediate = tunerSatisfies tSuitableForIntermediateTuning

-- | Set the order of 'Proposal's in a 'Cycle'.
--
-- All other fields of the 'Cycle' are left unchanged.
setOrder :: Order -> Cycle a -> Cycle a
setOrder o c = c {ccOrder = o}

-- | Use all proposals, or use fast proposals only?
data IterationMode
  = -- | Use every proposal of the cycle.
    AllProposals
  | -- | Only use proposals with speed 'PFast'.
    FastProposals
  deriving (Eq)

-- | Replicate 'Proposal's according to their weights and possibly shuffle them.
--
-- With 'FastProposals', only proposals with speed 'PFast' are used. The
-- generator is only used for the orders 'RandomO' and 'RandomReversibleO'.
--
-- Call 'error' if no proposals are left after filtering and replication.
prepareProposals :: StatefulGen g m => IterationMode -> Cycle a -> g -> m [Proposal a]
prepareProposals m (Cycle xs o _ _) g
  | null ps = error $ "prepareProposals: " <> msg
  | otherwise = case o of
      RandomO -> shuffle ps g
      SequentialO -> return ps
      RandomReversibleO -> do
        -- Shuffle once, then mirror the very same permutation.
        psR <- shuffle ps g
        return $ psR ++ reverse psR
      SequentialReversibleO -> return $ ps ++ reverse ps
  where
    msg :: String
    msg = case m of
      FastProposals -> "no fast proposals found"
      AllProposals -> "no proposals found"
    isSelected p = case m of
      AllProposals -> True
      -- Only use proposal if it is fast.
      FastProposals -> prSpeed p == PFast
    -- Each selected proposal is repeated in place according to its weight.
    !ps =
      concat
        [ replicate (fromPWeight $ prWeight p) p
          | p <- xs,
            isSelected p
        ]

-- The number of proposals executed per cycle. It is the sum of the weights of
-- the proposals selected by the iteration mode, and it depends on the order:
-- the reversible orders execute every proposal twice.
getNProposalsPerCycle :: IterationMode -> Cycle a -> Int
getNProposalsPerCycle m (Cycle xs o _ _) = case o of
  RandomO -> once
  SequentialO -> once
  RandomReversibleO -> 2 * once
  SequentialReversibleO -> 2 * once
  where
    -- Proposals taken into account for the given iteration mode.
    xs' = case m of
      AllProposals -> xs
      FastProposals -> filter (\x -> prSpeed x == PFast) xs
    -- Number of proposals in one non-reversed pass.
    once = sum $ map (fromPWeight . prWeight) xs'

-- See 'tuneWithTuningParameters' and 'Tuner'.
--
-- Tune a proposal using the given acceptance rate and trace (if any).
--
-- Proposals without tuner are returned unchanged. Otherwise, the tuning type,
-- the suitability of the tuner for intermediate tuning, and the speed of the
-- proposal decide what happens:
--
-- - Intermediate tuning is only performed by tuners suitable for intermediate
--   tuning; it fails if a trace is provided.
--
-- - Normal and last tuning fail if the tuner requires the trace but no trace
--   is provided.
--
-- - The @...FastProposalsOnly@ tuning types only tune proposals with speed
--   'PFast'.
--
-- - In all other cases, the proposal is returned unchanged.
--
-- Return 'Left' with an error message if tuning fails.
tuneWithChainParameters ::
  TuningType ->
  Maybe AcceptanceRate ->
  Maybe (VB.Vector a) ->
  Proposal a ->
  Either String (Proposal a)
tuneWithChainParameters tt mar mxs p = case prTuner p of
  Nothing -> Right p
  Just (Tuner t ts rt it fT _) -> case (tt, it, prSpeed p) of
    (IntermediateTuningFastProposalsOnly, True, PFast) -> tuneIntermediate
    (IntermediateTuningAllProposals, True, _) -> tuneIntermediate
    (NormalTuningFastProposalsOnly, _, PFast) -> tuneNormally
    (NormalTuningAllProposals, _, _) -> tuneNormally
    -- Like the other "fast proposals only" tuning types, leave slow proposals
    -- untouched.
    (LastTuningFastProposalsOnly, _, PFast) -> tuneNormally
    (LastTuningAllProposals, _, _) -> tuneNormally
    _ -> Right p
    where
      hasTrace = isJust mxs
      err m = Left $ "tuneWithChainParameters: " <> m
      tuneIntermediate =
        if hasTrace
          then err "intermediate tuning but trace provided"
          else tune
      -- Here, @rt@ tells if the tuner requires the trace.
      tuneNormally =
        if rt && not hasTrace
          then err "trace required"
          else tune
      -- Compute the new tuning parameters with the tuning function of the
      -- tuner, and apply them to the proposal.
      tune =
        let (t', ts') = fT tt (prDimension p) mar mxs (t, ts)
         in tuneWithTuningParameters t' ts' p

-- (_, False, Just _) ->
-- (True, _, Just _) ->
-- (False, True, Nothing) ->
-- _ ->

-- | Calculate acceptance rates and auto tune the 'Proposal's in the 'Cycle'.
--
-- Do not change 'Proposal's that are not tuneable.
--
-- Call 'error' if
--
-- - a trace is provided but the cycle does not require one ('ccRequireTrace');
--
-- - the proposals stored in the acceptances do not match the proposals in the
--   cycle;
--
-- - tuning of one of the proposals fails.
autoTuneCycle :: TuningType -> Acceptances (Proposal a) -> Maybe (VB.Vector a) -> Cycle a -> Cycle a
autoTuneCycle tt a mxs c
  | isJust mxs && not (ccRequireTrace c) = err "trace provided but not required"
  | sort (M.keys $ fromAcceptances a) /= sort ps =
      err "proposals in map and cycle do not match"
  | otherwise = c {ccProposals = map tuneF ps}
  where
    err :: String -> b
    err msg = error $ "autoTuneCycle: " <> msg
    ps = ccProposals c
    tuneF p =
      let (_, _, mar, mtr) = acceptanceRate p a
          -- Favor the expected rate, if available.
          mr = mtr <|> mar
       in either error id $ tuneWithChainParameters tt mr mxs p

-- | Summarize the 'Proposal's in the 'Cycle'. Also report acceptance rates.
--
-- The summary consists of a title line, the number of proposals performed per
-- iteration (which depends on the iteration mode and the order), a
-- description of the order, and a table with one row per proposal of the
-- cycle. The table is enclosed by horizontal lines.
summarizeCycle :: IterationMode -> Acceptances (Proposal a) -> Cycle a -> BL.ByteString
summarizeCycle m a c =
  BL.intercalate "\n" $
    [ "Summary of proposal(s) in cycle.",
      nProposalsFullStr,
      describeOrder (ccOrder c),
      proposalHeader,
      proposalHLine
    ]
      ++ map summarize ps
      ++ [proposalHLine]
  where
    ps = ccProposals c
    -- One table row for a proposal.
    summarize p =
      summarizeProposal
        (prName p)
        (prDescription p)
        (prWeight p)
        (tTuningParameter <$> prTuner p)
        (prDimension p)
        (ar p)
    nProposals = getNProposalsPerCycle m c
    nProposalsStr = BB.toLazyByteString $ BB.intDec nProposals
    nProposalsFullStr = case nProposals of
      1 -> nProposalsStr <> " proposal is performed per iteration."
      _ -> nProposalsStr <> " proposals are performed per iterations."
    ar pr = acceptanceRate pr a
    -- The horizontal line is as wide as the table header.
    proposalHLine = BL.replicate (BL.length proposalHeader) '-'