-- |
-- Module      :  Mcmc.Proposal.Hamiltonian.Nuts
-- Description :  No-U-Turn sampler (NUTS)
-- Copyright   :  2022 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- Creation date: Fri May 27 09:58:23 2022.
--
-- For a general introduction to Hamiltonian proposals, see
-- "Mcmc.Proposal.Hamiltonian.Hamiltonian".
--
-- This module implements the No-U-Turn Sampler (NUTS), as described in [4].
--
-- Work in progress.
--
-- References:
--
-- - [1] Chapter 5 of Handbook of Monte Carlo: Neal, R. M., MCMC Using
--   Hamiltonian Dynamics, In S. Brooks, A. Gelman, G. Jones, & X. Meng (Eds.),
--   Handbook of Markov Chain Monte Carlo (2011), CRC press.
--
-- - [2] Gelman, A., Carlin, J. B., Stern, H. S., & Rubin, D. B., Bayesian data
--   analysis (2014), CRC Press.
--
-- - [3] Review by Betancourt and notes: Betancourt, M., A conceptual
--   introduction to Hamiltonian Monte Carlo, arXiv, 1701–02434 (2017).
--
-- - [4] Matthew D. Hoffman, Andrew Gelman (2014) The No-U-Turn Sampler:
--   Adaptively Setting Path Lengths in Hamiltonian Monte Carlo, Journal of
--   Machine Learning Research.
module Mcmc.Proposal.Hamiltonian.Nuts
  ( NParams (..),
    defaultNParams,
    nuts,
  )
where

import Data.Bifunctor
import Mcmc.Acceptance
import Mcmc.Proposal
import Mcmc.Proposal.Hamiltonian.Common
import Mcmc.Proposal.Hamiltonian.Internal
import Mcmc.Proposal.Hamiltonian.Masses
import Numeric.AD.Double
import qualified Numeric.LinearAlgebra as L
import Numeric.Log
import System.Random.Stateful

-- Internal; Slice variable 'u'.
type SliceVariable = Log Double

-- Internal; Forward is True.
type Direction = Bool

-- Internal; Doubling step number 'j'.
type DoublingStep = Int

-- Internal; Number of leapfrog steps within the slice 'n'.
type NStepsOk = Int

-- Internal; Sum of the acceptance probabilities \(\alpha\) of the leapfrog
-- steps of a (sub)tree. Divide by 'NAlpha' to obtain the estimated acceptance
-- rate.
type Alpha = Log Double

-- Internal; Number of leapfrog steps whose acceptance probabilities are summed
-- up in 'Alpha'.
type NAlpha = Int

-- Internal; Return type of the tree building function; have a look at
-- Algorithm 3 in [4]. The components are, in order:
--
-- 1. Left-most (backward) position.
-- 2. Left-most (backward) momentum.
-- 3. Right-most (forward) position.
-- 4. Right-most (forward) momentum.
-- 5. Position proposed by the (sub)tree.
-- 6. Number of leapfrog steps within the slice.
-- 7. Summed acceptance probabilities.
-- 8. Number of leapfrog steps contributing to the summed acceptance
--    probabilities.
type BuildTreeReturnType = (Positions, Momenta, Positions, Momenta, Positions, NStepsOk, Alpha, NAlpha)

-- Internal; Largest allowed leapfrog integration error \(\Delta_{\max}\),
-- stored in exponential (log-domain) representation, so that it can directly
-- be multiplied with other 'Log Double' values. See the discussion around
-- Equation (3) in [4].
deltaMax :: Log Double
deltaMax = Exp logDeltaMax
  where
    logDeltaMax = 1000

-- Second function in Algorithm 3 and Algorithm 6, respectively in [4].
--
-- Build a (sub)tree of height @j@ (i.e., @2^j@ leapfrog steps) extending the
-- state @(q, p)@ in direction @v@. Return 'Nothing' if the (sub)tree has to be
-- discarded, that is, if the leapfrog integrator fails, if the integration
-- error exceeds 'deltaMax', or if the (sub)tree makes a U-turn. Otherwise,
-- return the components described at 'BuildTreeReturnType'.
buildTreeWith ::
  -- The exponent of the total energy of the starting state is used to
  -- calculate the expected acceptance rate 'Alpha'.
  Log Double ->
  MassesI ->
  Target ->
  IOGenM StdGen ->
  --
  Positions ->
  Momenta ->
  SliceVariable ->
  Direction ->
  DoublingStep ->
  LeapfrogScalingFactor ->
  IO (Maybe BuildTreeReturnType)
buildTreeWith expETot0 msI tfun g q p u v j e
  | j <= 0 =
      -- Base case: a single leapfrog step, forwards or backwards.
      let e' = if v then e else negate e
       in case leapfrog tfun msI 1 e' q p of
            Nothing -> pure Nothing
            Just (q', p', _, expEPot') ->
              let expEKin' = exponentialKineticEnergy msI p'
                  expETot' = expEPot' * expEKin'
                  -- Is the new state within the slice?
                  n' = if u <= expETot' then 1 else 0
                  -- Is the integration error small enough? See 'deltaMax'.
                  errorIsSmall = u < deltaMax * expETot'
                  -- Acceptance probability relative to the starting state.
                  alpha = min 1.0 (expETot' / expETot0)
               in pure $
                    if errorIsSmall
                      then Just (q', p', q', p', q', n', alpha, 1)
                      else Nothing
  -- Recursive case: build two subtrees of height j-1, one after the other.
  | otherwise = do
      mr <- buildTree q p u v (j - 1) e
      case mr of
        Nothing -> pure Nothing
        -- Here, the suffixes 'm' and 'p' stand for minus and plus,
        -- respectively.
        Just (qm, pm, qp, pp, q', n', a', na') -> do
          -- Extend the outermost state in direction v by a second subtree;
          -- only the end point in direction v changes.
          mr' <-
            if v
              then
                fmap (\(_, _, qp1, pp1, q1, n1, a1, na1) -> (qm, pm, qp1, pp1, q1, n1, a1, na1))
                  <$> buildTree qp pp u v (j - 1) e
              else
                fmap (\(qm1, pm1, _, _, q1, n1, a1, na1) -> (qm1, pm1, qp, pp, q1, n1, a1, na1))
                  <$> buildTree qm pm u v (j - 1) e
          case mr' of
            Nothing -> pure Nothing
            Just (qm'', pm'', qp'', pp'', q'', n'', a'', na'') -> do
              b <- uniformRM (0, 1) g :: IO Double
              let -- Select the proposal of the second subtree with probability
                  -- n'' / (n' + n'').
                  q''' =
                    if b < fromIntegral n'' / fromIntegral (n' + n'')
                      then q''
                      else q'
                  -- Important: Check for U-turn. This formula differs from the
                  -- formula using indicator functions in Algorithm 3. However,
                  -- check Equation (4): the criterion uses the dot products
                  -- (q+ - q-) . p- and (q+ - q-) . p+. Comparing the
                  -- element-wise product of vectors with '<' would be a
                  -- lexicographic comparison, not a sign test of the dot
                  -- product.
                  dq = qp'' - qm''
                  isUTurn = L.dot dq pm'' < 0 || L.dot dq pp'' < 0
              pure $
                if isUTurn
                  then Nothing
                  else Just (qm'', pm'', qp'', pp'', q''', n' + n'', a' + a'', na' + na'')
  where
    buildTree = buildTreeWith expETot0 msI tfun g

-- | Parameters of the NUTS proposal.
--
-- The leapfrog scaling factor and the masses are tuning parameters. The tuning
-- configuration is not part of these parameters; it is passed separately (see
-- 'HTuningConf').
data NParams = NParams
  { -- | Leapfrog scaling factor. If 'Nothing', the scaling factor is determined
    -- during initialization of the proposal (see 'defaultNParams').
    nLeapfrogScalingFactor :: Maybe LeapfrogScalingFactor,
    -- | Masses. If 'Nothing', the masses are determined during initialization
    -- of the proposal (see 'defaultNParams').
    nMasses :: Maybe Masses
  }
  deriving (Show)

-- | Default parameters.
--
-- Neither the leapfrog scaling factor nor the masses are specified; both are
-- determined during initialization of the proposal:
--
-- - Estimate a reasonable leapfrog scaling factor using Algorithm 4 [4]. If all
--   fails, use 0.1.
--
-- - The mass matrix is set to the identity matrix.
defaultNParams :: NParams
defaultNParams =
  NParams
    { nLeapfrogScalingFactor = Nothing,
      nMasses = Nothing
    }

-- Internal; Rebuild the NUTS proposal function from (tuned) auxiliary tuning
-- parameters.
--
-- The scalar tuning parameter is ignored; all tunable quantities are read from
-- the auxiliary tuning parameters. Conversion errors are propagated as 'Left'.
nutsPFunctionWithTuningParameters ::
  Traversable s =>
  Dimension ->
  HStructure s ->
  (s Double -> Target) ->
  TuningParameter ->
  AuxiliaryTuningParameters ->
  Either String (PFunction (s Double))
nutsPFunctionWithTuningParameters d hstruct targetWith _ ts =
  withParams <$> fromAuxiliaryTuningParameters d ts
  where
    withParams hParamsI = nutsPFunction hParamsI hstruct targetWith

-- Internal; Outcome of the trajectory doubling in 'nutsPFunction'.
--
-- - 'Old': No doubling step has yielded a valid subtree; no acceptance counts
--   are available, and the proposal is rejected.
--
-- - 'OldWith': The proposal is rejected; carries the acceptance counts of the
--   last completed doubling step.
--
-- - 'NewWith': The selected state is accepted; carries the acceptance counts
--   of the last completed doubling step.
data IsNew
  = Old
  | OldWith
      { _acceptanceCountsOld :: AcceptanceCounts
      }
  | NewWith
      { _acceptanceCountsNew :: AcceptanceCounts
      }

-- First function in Algorithm 3 in [4].
--
-- Draw momenta and a slice variable, then repeatedly double the trajectory in
-- a randomly chosen direction until a subtree has to be discarded, or until the
-- whole trajectory makes a U-turn. If a new state has been selected during any
-- doubling step, it is force-accepted; otherwise, the proposal is
-- force-rejected. The acceptance counts reflect the average acceptance
-- probability of the last completed doubling step.
nutsPFunction ::
  HParamsI ->
  HStructure s ->
  (s Double -> Target) ->
  PFunction (s Double)
nutsPFunction hparamsi hstruct targetWith x g = do
  p <- generateMomenta mus ms g
  uZeroOne <- uniformRM (0, 1) g :: IO Double
  -- NOTE (runtime): Here we need the target function value from the previous
  -- step. For now, I just recalculate the value, but this is, of course, slow!
  -- However, if other proposals have changed the state in between, we do need
  -- to recalculate this value.
  let q = toVec x
      expEPot = fst $ target q
      expEKin = exponentialKineticEnergy msI p
      expETot = expEPot * expEKin
      -- Slice variable, uniform on [0, expETot], stored in log domain.
      u = expETot * Exp (log uZeroOne)
      -- Build a tree of height j extending the state (q0, p0) in direction v.
      buildTree q0 p0 v j = buildTreeWith expETot msI target g q0 p0 u v j e
      -- Outer loop of Algorithm 3. This is complicated because the algorithm
      -- is written for an imperative language, and because we have two
      -- stacked monads.
      --
      -- Here, the suffixes 'm' and 'p' stand for minus and plus, respectively;
      -- 'y' is the currently selected position, and 'n' the number of valid
      -- states found so far.
      go qm pm qp pp j y n isNew = do
        v <- uniformM g :: IO Direction
        mr <-
          if v
            then -- Forwards: extend the right-most state.
              fmap (\(_, _, qp1, pp1, y1, n1, a1, na1) -> (qm, pm, qp1, pp1, y1, n1, a1, na1))
                <$> buildTree qp pp v j
            else -- Backwards: extend the left-most state.
              fmap (\(qm1, pm1, _, _, y1, n1, a1, na1) -> (qm1, pm1, qp, pp, y1, n1, a1, na1))
                <$> buildTree qm pm v j
        case mr of
          -- The new subtree is invalid; stop and keep the current selection.
          Nothing -> pure (y, isNew)
          Just (qm', pm', qp', pp', y', n', a, na) -> do
            let r = fromIntegral n' / fromIntegral n :: Double
                ar = exp (ln a) / fromIntegral na :: Double
                getCounts s = max 0 $ min 100 $ round $ s * 100
                ac =
                  if ar >= 0
                    then let cs = getCounts ar in AcceptanceCounts cs (100 - cs)
                    else error "nutsPFunction: Acceptance rate negative."
            -- Select the proposal of the new subtree with probability
            -- min(1, n'/n).
            isAccept <-
              if r > 1.0
                then pure True
                else do
                  b <- uniformRM (0, 1) g
                  pure $ b < r
            let -- If the proposal of this doubling step is rejected, the
                -- previously selected position is kept together with its
                -- status. In particular, a new state selected during an
                -- earlier doubling step must not be discarded.
                keepWith ac' = case isNew of
                  NewWith _ -> NewWith ac'
                  _ -> OldWith ac'
                (y'', isNew') =
                  if isAccept
                    then (y', NewWith ac)
                    else (y, keepWith ac)
                -- See Equation (4) in [4]. The criterion requires the dot
                -- products (q+ - q-) . p- and (q+ - q-) . p+; comparing the
                -- element-wise product of vectors with '<' would be a
                -- lexicographic comparison.
                dq = qp' - qm'
                isUTurn = L.dot dq pm' < 0 || L.dot dq pp' < 0
            if isUTurn
              then pure (y'', isNew')
              else go qm' pm' qp' pp' (j + 1) y'' (n + n') isNew'
  (x', isNew) <- go q p q p 0 q 1 Old
  pure $ case isNew of
    Old -> (ForceReject, Just $ AcceptanceCounts 0 100)
    OldWith ac -> (ForceReject, Just ac)
    NewWith ac -> (ForceAccept $ fromVec x', Just ac)
  where
    (HParamsI e _ ms _ _ msI mus) = hparamsi
    (HStructure _ toVec fromVecWith) = hstruct
    fromVec = fromVecWith x
    target = targetWith x

-- | No U-turn Hamiltonian Monte Carlo sampler (NUTS).
--
-- The structure of the state is denoted as @s@.
--
-- May call 'error' during initialization: if the initial Hamiltonian
-- parameters cannot be determined, or if 'checkHStructureWith' reports a
-- problem with the masses and the structure of the state.
nuts ::
  Traversable s =>
  NParams ->
  HTuningConf ->
  HStructure s ->
  HTarget s ->
  PName ->
  PWeight ->
  Proposal (s Double)
nuts nparams htconf hstruct htarget n w =
  case checkHStructureWith (hpsMasses hParamsI) hstruct of
    Just err -> error err
    Nothing -> Proposal n description PSlow pDim w pFunction tuner
  where
    description = PDescription "No U-turn sampler (NUTS)"
    (HStructure sample toVec fromVec) = hstruct
    dim = L.size $ toVec sample
    -- See bottom of page 1616 in [4].
    pDim = PSpecial dim 0.6
    -- Vectorize and derive the target function. The target is the likelihood
    -- multiplied with the optional prior and the optional Jacobian.
    (HTarget mPrF lhF mJcF) = htarget
    tF y = case (mPrF, mJcF) of
      (Nothing, Nothing) -> lhF y
      (Just prF, Nothing) -> prF y * lhF y
      (Nothing, Just jcF) -> lhF y * jcF y
      (Just prF, Just jcF) -> prF y * lhF y * jcF y
    tFnG = grad' (ln . tF)
    targetWith x = bimap Exp toVec . tFnG . fromVec x
    -- Initial Hamiltonian parameters.
    (NParams mEps mMs) = nparams
    hParamsI =
      case hParamsIWith (targetWith sample) (toVec sample) mEps Nothing mMs of
        Left err -> error err
        Right hParams -> hParams
    pFunction = nutsPFunction hParamsI hstruct targetWith
    -- Tuning.
    auxParams = toAuxiliaryTuningParameters hParamsI
    pFunctionWith = nutsPFunctionWithTuningParameters dim hstruct targetWith
    mkTuner tfun = Tuner 1.0 auxParams True tfun pFunctionWith
    tuner = mkTuner <$> hTuningFunctionWith dim toVec htconf