{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- |
-- Module      :  Mcmc.Proposal.Hamiltonian.Hamiltonian
-- Description :  Hamiltonian Monte Carlo proposal
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- Creation date: Mon Jul  5 12:59:42 2021.
--
-- The Hamiltonian Monte Carlo (HMC) proposal.
--
-- The HMC proposal acts on 'Positions', a vector of floating point values. The
-- manipulated values can represent the complete state, or a subset of the
-- complete state. Functions converting the state to and from this vector have
-- to be provided; see 'HStructure'.
--
-- Even though the proposal may only act on a subset of the complete state, the
-- prior, likelihood, and Jacobian functions of the complete state have to be
-- provided; see 'HTarget'. This is because parameters not manipulated by the
-- HMC proposal still influence the prior, likelihood and Jacobian functions.
--
-- The points given above have implications on how the HMC proposal is handled:
-- Do not use 'liftProposalWith', 'liftProposal', or '(@~)' with the HMC
-- proposal; instead provide proper conversion functions with 'HStructure'.
--
-- The gradient of the log target function is calculated using automatic
-- differentiation; see the excellent
-- [ad](https://hackage.haskell.org/package/ad) package.
--
-- The desired acceptance rate is 0.65, although the dimension of the proposal
-- is high.
--
-- The speed of this proposal changes drastically with the leapfrog trajectory
-- length and the leapfrog scaling factor. Hence, the speed will change during
-- burn in.
--
-- References:
--
-- - [1] Chapter 5 of Handbook of Monte Carlo: Neal, R. M., MCMC Using
--   Hamiltonian Dynamics, In S. Brooks, A. Gelman, G. Jones, & X. Meng (Eds.),
--   Handbook of Markov Chain Monte Carlo (2011), CRC press.
--
-- - [2] Gelman, A., Carlin, J. B., Stern, H. S., & Rubin, D. B., Bayesian data
--   analysis (2014), CRC Press.
--
-- - [3] Review by Betancourt and notes: Betancourt, M., A conceptual
--   introduction to Hamiltonian Monte Carlo, arXiv:1701.02434 (2017).
--
-- - [4] Matthew D. Hoffman, Andrew Gelman (2014) The No-U-Turn Sampler:
--   Adaptively Setting Path Lengths in Hamiltonian Monte Carlo, Journal of
--   Machine Learning Research.
module Mcmc.Proposal.Hamiltonian.Hamiltonian
  ( -- * Hamiltonian Monte Carlo proposal
    HParams (..),
    defaultHParams,
    hamiltonian,
  )
where

import Data.Bifunctor
import Mcmc.Acceptance
import Mcmc.Algorithm.MHG
import Mcmc.Proposal
import Mcmc.Proposal.Hamiltonian.Common
import Mcmc.Proposal.Hamiltonian.Internal
import Mcmc.Proposal.Hamiltonian.Masses
import Numeric.AD.Double
import qualified Numeric.LinearAlgebra as L
import Numeric.Log
import System.Random.Stateful

-- | Parameters of the Hamiltonian Monte Carlo proposal.
--
-- If a parameter is 'Nothing', a default value is used (see 'defaultHParams').
data HParams = HParams
  { -- | Leapfrog scaling factor; 'Nothing' requests the default.
    hLeapfrogScalingFactor :: Maybe LeapfrogScalingFactor,
    -- | Leapfrog simulation length; 'Nothing' requests the default.
    hLeapfrogSimulationLength :: Maybe LeapfrogSimulationLength,
    -- | Masses; 'Nothing' requests the default.
    hMasses :: Maybe Masses
  }
  deriving (Show)

-- | Default parameters.
--
-- All fields are 'Nothing', so that defaults are used for every parameter:
--
-- - Estimate a reasonable leapfrog scaling factor using Algorithm 4 [4]. If all
--   fails, use 0.1.
--
-- - Leapfrog simulation length is set to 0.5.
--
-- - The mass matrix is set to the identity matrix.
defaultHParams :: HParams
defaultHParams =
  HParams
    { hLeapfrogScalingFactor = Nothing,
      hLeapfrogSimulationLength = Nothing,
      hMasses = Nothing
    }

-- Rebuild the proposal function from tuning parameters.
--
-- The auxiliary tuning parameters are decoded into internal parameters with
-- 'fromAuxiliaryTuningParameters' (which also receives the dimension); a
-- decoding failure is returned as 'Left'. The main tuning parameter is ignored.
hamiltonianPFunctionWithTuningParameters ::
  Traversable s =>
  Dimension ->
  HStructure s ->
  (s Double -> Target) ->
  TuningParameter ->
  AuxiliaryTuningParameters ->
  Either String (PFunction (s Double))
hamiltonianPFunctionWithTuningParameters d hstruct targetWith _ ts = do
  hParamsI <- fromAuxiliaryTuningParameters d ts
  pure $ hamiltonianPFunction hParamsI hstruct targetWith

-- The proposal function of the Hamiltonian Monte Carlo proposal.
--
-- One proposal step:
--
-- 1. Sample momenta with 'generateMomenta'.
--
-- 2. Sample a leapfrog scaling factor from the range @(0.9*e, 1.1*e)@, and a
--    number of leapfrog steps from a range around @la / eRan@.
--
-- 3. Run the 'leapfrog' integrator. If it fails, the proposal is rejected and
--    an expected acceptance rate of 0 is reported.
--
-- 4. Otherwise, accept or reject right here using 'mhgAccept', and report the
--    expected acceptance rate limited to the interval [0, 1].
--
-- The inverted covariance matrix and the log determinant of the covariance
-- matrix are calculated by 'hamiltonianPFunction'.
--
-- NOTE(review): In the code below, all mass-related values are taken from
-- 'HParamsI'; confirm whether the sentence above is still accurate.
hamiltonianPFunction ::
  HParamsI ->
  HStructure s ->
  (s Double -> Target) ->
  PFunction (s Double)
hamiltonianPFunction hparamsi hstruct targetWith x g = do
  p <- generateMomenta mus ms g
  eRan <- uniformRM (eL, eR) g
  -- NOTE: The NUTS paper does not sample l since l varies naturally because
  -- of epsilon. I still think it should vary because otherwise, there may be
  -- dragons due to periodicity.
  let -- Mean number of leapfrog steps for the sampled scaling factor.
      lM = la / eRan
      -- Use at least one leapfrog step.
      lL = max (1 :: Int) (floor $ 0.9 * lM)
      -- Ensure the upper bound is not smaller than the lower bound.
      lR = max lL (ceiling $ 1.1 * lM)
  lRan <- uniformRM (lL, lR) g
  case leapfrog (targetWith x) msI lRan eRan q p of
    -- NOTE: I am not sure if it is correct to set the expected acceptance rate
    -- to 0 when the leapfrog integrator fails.
    Nothing -> pure (ForceReject, Just $ AcceptanceRates 0 1)
    -- Check if next state is accepted here, because the Jacobian is included in
    -- the target function. If not: pure (x, 0.0, 1.0).
    Just (q', p', prQ, prQ') -> do
      let -- Prior of momenta.
          prP = exponentialKineticEnergy msI p
          prP' = exponentialKineticEnergy msI p'
          -- Acceptance ratio (log domain).
          r = prQ' * prP' / (prQ * prP)
      accept <- mhgAccept r g
      -- NOTE: For example, Neal page 12: In order for the Hamiltonian proposal
      -- to be in detailed balance, the momenta have to be negated before
      -- proposing the new value. That is, the negated momenta would guide the
      -- chain back to the previous state. However, we are only interested in
      -- the positions, and are not even storing the momenta.
      let pr = if accept then ForceAccept (fromVec x q') else ForceReject
          -- Limit expected acceptance rate between 0 and 1.
          ar = max 0 $ min 1 (exp $ ln r)
      pure (pr, Just $ AcceptanceRates ar 1)
  where
    (HParamsI e la ms _ _ msI mus) = hparamsi
    (HStructure _ toVec fromVec) = hstruct
    -- Current positions.
    q = toVec x
    -- Bounds of the range the leapfrog scaling factor is sampled from.
    eL = 0.9 * e
    eR = 1.1 * e

-- | Hamiltonian Monte Carlo proposal.
--
-- The structure of the state is denoted as @s@.
--
-- The target function is the product of the likelihood function and, if
-- available, the prior and Jacobian functions given by 'HTarget'. The gradient
-- of the log target function is obtained with 'grad''.
--
-- The proposal is created with a tuner if 'hTuningFunctionWith' provides a
-- tuning function for the given 'HTuningConf'.
--
-- May call 'error' during initialization. In particular, 'error' is called if
-- the internal parameters cannot be constructed from the given 'HParams', or if
-- the check of the 'HStructure' against the masses fails.
hamiltonian ::
  Traversable s =>
  HParams ->
  HTuningConf ->
  HStructure s ->
  HTarget s ->
  PName ->
  PWeight ->
  Proposal (s Double)
hamiltonian hparams htconf hstruct htarget n w =
  let -- Misc.
      desc = PDescription "Hamiltonian Monte Carlo (HMC)"
      (HStructure sample toVec fromVec) = hstruct
      -- The dimension is the length of the vectorized sample state.
      dim = L.size $ toVec sample
      -- See bottom of page 1615 in [4].
      pDim = PSpecial dim 0.65
      -- Vectorize and derive the target function.
      (HTarget mPrF lhF mJcF) = htarget
      -- Multiply the likelihood with the prior and the Jacobian, if available.
      tF y = case (mPrF, mJcF) of
        (Nothing, Nothing) -> lhF y
        (Just prF, Nothing) -> prF y * lhF y
        (Nothing, Just jcF) -> lhF y * jcF y
        (Just prF, Just jcF) -> prF y * lhF y * jcF y
      -- Value and gradient of the log target function.
      tFnG = grad' (ln . tF)
      -- Target function acting on the vectorized positions; the given state
      -- is used to convert the positions back with 'fromVec'.
      targetWith x = bimap Exp toVec . tFnG . fromVec x
      (HParams mEps mLa mMs) = hparams
      -- Fail with 'error' if the internal parameters cannot be constructed.
      hParamsI =
        either error id $
          hParamsIWith (targetWith sample) (toVec sample) mEps mLa mMs
      ps = hamiltonianPFunction hParamsI hstruct targetWith
      hamiltonianWith = Proposal n desc PSlow pDim w ps
      -- Tuning.
      ts = toAuxiliaryTuningParameters hParamsI
      tuner = do
        tfun <- hTuningFunctionWith dim toVec htconf
        let pfun = hamiltonianPFunctionWithTuningParameters dim hstruct targetWith
        pure $ Tuner 1.0 ts True True tfun pfun
   in case checkHStructureWith (hpsMasses hParamsI) hstruct of
        Just err -> error err
        Nothing -> hamiltonianWith tuner