{-# LANGUAGE TemplateHaskell #-}

-- |
-- Module      :  Mcmc.Acceptance
-- Description :  Handle acceptance rates
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- Creation date: Thu Jul  8 18:12:07 2021.
module Mcmc.Acceptance
  ( -- * Acceptance rates
    AcceptanceRate,
    AcceptanceCounts (..),
    AcceptanceRates (..),
    Acceptance,
    Acceptances (fromAcceptances),
    emptyA,
    pushAccept,
    pushReject,
    ResetAcceptance (..),
    resetA,
    transformKeysA,
    acceptanceRate,
    acceptanceRates,
  )
where

import Data.Aeson
import Data.Aeson.TH
import Data.Foldable
import qualified Data.Map.Strict as M

-- | Acceptance rate.
--
-- Used both for actual rates, computed as @accepted / (accepted + rejected)@,
-- and for expected rates, computed as the mean of the recorded expected
-- acceptance rates (see 'acceptanceRate').
type AcceptanceRate = Double

-- | Tally of the proposals that were accepted and rejected.
data AcceptanceCounts = AcceptanceCounts
  { -- | How many proposals were accepted.
    nAccepted :: !Int,
    -- | How many proposals were rejected.
    nRejected :: !Int
  }
  deriving (Show, Eq, Ord)

$(deriveJSON defaultOptions ''AcceptanceCounts)

-- | Expected acceptance rates, accumulated as a running sum.
--
-- Proposals based on Hamiltonian dynamics use expected acceptance rates, not
-- counts.
data AcceptanceRates = AcceptanceRates
  { -- | Sum of all recorded expected acceptance rates.
    totalAcceptanceRate :: !Double,
    -- | Number of expected acceptance rates summed in 'totalAcceptanceRate'.
    nAcceptanceRates :: !Int
  }
  deriving (Show, Eq)

$(deriveJSON defaultOptions ''AcceptanceRates)

-- | Stored actual acceptance counts and maybe expected acceptance rates.
--
-- Both fields are strict, like the fields of 'AcceptanceCounts' and
-- 'AcceptanceRates'. Values of this type are rebuilt by every call to
-- 'pushAccept' and 'pushReject'. With lazy fields, the merged expected rates
-- would be kept as a growing chain of unevaluated updates until they are
-- finally read.
data Acceptance
  = A
      !AcceptanceCounts
      -- ^ Actual numbers of accepted and rejected proposals.
      !(Maybe AcceptanceRates)
      -- ^ Expected acceptance rates, if any have been recorded.
  deriving (Show, Eq)

$(deriveJSON defaultOptions ''Acceptance)

-- Count one more accepted proposal, and merge the given expected acceptance
-- rates (if any) into the stored ones.
addAccept :: Maybe AcceptanceRates -> Acceptance -> Acceptance
addAccept new (A (AcceptanceCounts nAcc nRej) old) = A counts rates
  where
    counts = AcceptanceCounts (nAcc + 1) nRej
    rates = addAcceptanceRates new old

-- Count one more rejected proposal, and merge the given expected acceptance
-- rates (if any) into the stored ones.
addReject :: Maybe AcceptanceRates -> Acceptance -> Acceptance
addReject new (A (AcceptanceCounts nAcc nRej) old) = A counts rates
  where
    counts = AcceptanceCounts nAcc (nRej + 1)
    rates = addAcceptanceRates new old

-- Merge two optional expected acceptance rates. When both are present, the
-- totals and the numbers of summed rates are added; otherwise whichever one is
-- present (or 'Nothing') is returned.
addAcceptanceRates :: Maybe AcceptanceRates -> Maybe AcceptanceRates -> Maybe AcceptanceRates
addAcceptanceRates mx my = case (mx, my) of
  (Nothing, _) -> my
  (_, Nothing) -> mx
  (Just (AcceptanceRates tx nx), Just (AcceptanceRates ty ny)) ->
    Just $ AcceptanceRates (tx + ty) (nx + ny)

-- | Acceptance bookkeeping: one 'Acceptance' (accepted and rejected counts,
-- plus optional expected acceptance rates) per key @k@.
newtype Acceptances k = Acceptances
  { fromAcceptances :: M.Map k Acceptance
  }
  deriving (Eq, Show)

-- Serialize through the wrapped map, so the JSON representation is exactly
-- that of the underlying @M.Map k Acceptance@.
instance ToJSONKey k => ToJSON (Acceptances k) where
  toJSON = toJSON . fromAcceptances
  toEncoding = toEncoding . fromAcceptances

-- Parse the underlying @M.Map k Acceptance@ and wrap it.
instance (Ord k, FromJSONKey k) => FromJSON (Acceptances k) where
  parseJSON = fmap Acceptances . parseJSON

-- | Create a fresh storage of accepted/rejected values.
--
-- Every given key starts with zero accepted and zero rejected proposals, and
-- without expected acceptance rates. Duplicate keys are collapsed into one
-- entry.
emptyA :: Ord k => [k] -> Acceptances k
emptyA = Acceptances . M.fromList . map (\key -> (key, fresh))
  where
    fresh = A (AcceptanceCounts 0 0) Nothing

-- | For key @k@, record an accepted proposal, and merge the given expected
-- acceptance rates (if any) into the stored ones.
--
-- If @k@ is not present, the storage is returned unchanged.
pushAccept :: Ord k => Maybe AcceptanceRates -> k -> Acceptances k -> Acceptances k
pushAccept mr k (Acceptances m) = Acceptances (M.adjust (addAccept mr) k m)

-- | For key @k@, record a rejected proposal, and merge the given expected
-- acceptance rates (if any) into the stored ones.
--
-- If @k@ is not present, the storage is returned unchanged.
pushReject :: Ord k => Maybe AcceptanceRates -> k -> Acceptances k -> Acceptances k
pushReject mr k (Acceptances m) = Acceptances (M.adjust (addReject mr) k m)

-- | Which stored acceptance information 'resetA' discards.
data ResetAcceptance
  = ResetEverything
  -- ^ Discard the actual acceptance counts and the expected acceptance rates.
  | ResetExpectedRatesOnly
  -- ^ Keep the actual acceptance counts; discard only the expected acceptance
  -- rates.

-- | Reset the stored acceptance information as specified by the given
-- 'ResetAcceptance'. The set of keys is preserved.
resetA :: Ord k => ResetAcceptance -> Acceptances k -> Acceptances k
resetA what (Acceptances m) = case what of
  ResetEverything -> emptyA (M.keys m)
  ResetExpectedRatesOnly -> Acceptances (M.map dropExpected m)
  where
    dropExpected (A counts _) = A counts Nothing

-- Build a new map whose keys are renamed according to the given
-- @(oldKey, newKey)@ pairs.
--
-- Keys of the input map that do not appear in the list are dropped. If several
-- pairs share the same new key, the value from the last such pair wins.
--
-- Calls 'error' with a descriptive message if an old key from the list is not
-- present in the input map (instead of the opaque message of 'M.!').
transformKeys :: (Ord k1, Ord k2) => [(k1, k2)] -> M.Map k1 v -> M.Map k2 v
transformKeys ks m = foldl' insrt M.empty ks
  where
    insrt m' (k1, k2) = case M.lookup k1 m of
      Just v -> M.insert k2 v m'
      Nothing -> error "transformKeys: Key not found in map."

-- | Rename keys according to the given @(oldKey, newKey)@ pairs.
--
-- Keys not mentioned in the list are not present in the returned
-- 'Acceptances'. Fails with an 'error' if an old key in the list is not
-- present.
transformKeysA :: (Ord k1, Ord k2) => [(k1, k2)] -> Acceptances k1 -> Acceptances k2
transformKeysA ks (Acceptances m) = Acceptances (transformKeys ks m)

-- | Compute acceptance counts, and actual and expected acceptance rates for a
-- specific proposal.
--
-- Return @(nAccepts, nRejects, actualRate, expectedRate)@, where
--
-- - @actualRate@ is @nAccepts / (nAccepts + nRejects)@, or 'Nothing' if no
--   proposals have been accepted or rejected (division by zero);
--
-- - @expectedRate@ is the mean of the recorded expected acceptance rates, or
--   'Nothing' if none have been recorded or the stored number of expected
--   rates is not positive (division by zero).
--
-- Calls 'error' if the key is not present.
acceptanceRate ::
  Ord k =>
  k ->
  Acceptances k ->
  -- | (nAccepts, nRejects, actualRate, expectedRate)
  (Int, Int, Maybe AcceptanceRate, Maybe AcceptanceRate)
acceptanceRate k a = case fromAcceptances a M.!? k of
  Just (A (AcceptanceCounts na nr) mrs) -> (na, nr, mar, mtr)
    where
      s = na + nr
      mar = if s <= 0 then Nothing else Just $ fromIntegral na / fromIntegral s
      -- Guard the denominator like for the actual rate; otherwise an empty
      -- 'AcceptanceRates' would produce NaN or Infinity.
      mtr = case mrs of
        Just (AcceptanceRates xs n) | n > 0 -> Just $ xs / fromIntegral n
        _ -> Nothing
  Nothing -> error "acceptanceRate: Key not found in map."

-- | Actual acceptance rate of every proposal, @accepted / (accepted + rejected)@.
--
-- A rate is 'Nothing' when the proposal has been neither accepted nor rejected
-- (division by zero). Expected acceptance rates are ignored.
acceptanceRates :: Acceptances k -> M.Map k (Maybe AcceptanceRate)
acceptanceRates (Acceptances m) = M.map actualRate m
  where
    actualRate (A (AcceptanceCounts na nr) _)
      | total <= 0 = Nothing
      | otherwise = Just (fromIntegral na / fromIntegral total)
      where
        total = na + nr