{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Mcmc.Proposal.Hamiltonian.Hamiltonian
(
HParams (..),
defaultHParams,
hamiltonian,
)
where
import Data.Bifunctor
import Mcmc.Acceptance
import Mcmc.Algorithm.MHG
import Mcmc.Proposal
import Mcmc.Proposal.Hamiltonian.Common
import Mcmc.Proposal.Hamiltonian.Internal
import Mcmc.Proposal.Hamiltonian.Masses
import Numeric.AD.Double
import qualified Numeric.LinearAlgebra as L
import Numeric.Log
import System.Random.Stateful
-- | User-facing parameters of the Hamiltonian Monte Carlo proposal.
--
-- Every field is optional. 'hamiltonian' passes all three fields, together
-- with the target function and the sample state, to 'hParamsIWith', which
-- resolves them into the internal 'HParamsI'.
data HParams = HParams
  { -- | Leapfrog scaling factor. The proposal function jitters it and passes
    -- the result to 'leapfrog'.
    hLeapfrogScalingFactor :: Maybe LeapfrogScalingFactor,
    -- | Leapfrog simulation length. The proposal function divides it by the
    -- jittered scaling factor to obtain the mean number of leapfrog steps.
    hLeapfrogSimulationLength :: Maybe LeapfrogSimulationLength,
    -- | Masses. These are used when drawing momenta.
    hMasses :: Maybe Masses
  }
  deriving (Show)
-- | Default parameters.
--
-- Every field is 'Nothing', so all values are resolved by 'hParamsIWith'
-- when the proposal is constructed with 'hamiltonian'.
defaultHParams :: HParams
defaultHParams = HParams Nothing Nothing Nothing
-- | Rebuild the Hamiltonian proposal function from tuning parameters.
--
-- This is the proposal-function constructor handed to the 'Tuner' built in
-- 'hamiltonian'.
--
-- * The scalar 'TuningParameter' is ignored.
-- * All internal parameters are decoded from the 'AuxiliaryTuningParameters'
--   with 'fromAuxiliaryTuningParameters'.
-- * A decoding failure is returned as 'Left'.
hamiltonianPFunctionWithTuningParameters ::
  Traversable s =>
  Dimension ->
  HStructure s ->
  (s Double -> Target) ->
  TuningParameter ->
  AuxiliaryTuningParameters ->
  Either String (PFunction (s Double))
hamiltonianPFunctionWithTuningParameters d hstruct targetWith _ ts = do
  hParamsI <- fromAuxiliaryTuningParameters d ts
  pure $ hamiltonianPFunction hParamsI hstruct targetWith
-- | Proposal function of the Hamiltonian Monte Carlo proposal.
--
-- One call performs the following steps:
--
-- 1. Draw momenta with 'generateMomenta'.
--
-- 2. Draw a leapfrog scaling factor @eRan@ uniformly from @[0.9 e, 1.1 e]@.
--    Here @e@ is the scaling factor stored in the 'HParamsI'.
--
-- 3. Draw the number of leapfrog steps uniformly from @[lL, lR]@. This is a
--    window of roughly +/- 10% around @la / eRan@, where @la@ is the
--    simulation length. The bounds satisfy @lL >= 1@ and @lR >= lL@, so the
--    range is never empty.
--
-- 4. Run 'leapfrog':
--
--    * If it returns 'Nothing', the proposal is force-rejected with
--      acceptance rate 0.
--    * Otherwise, accept or reject with 'mhgAccept'. The ratio is
--      @prQ' * prP' / (prQ * prP)@, combining the target values reported by
--      'leapfrog' with the 'exponentialKineticEnergy' of the momenta before
--      and after the trajectory.
hamiltonianPFunction ::
  HParamsI ->
  HStructure s ->
  (s Double -> Target) ->
  PFunction (s Double)
hamiltonianPFunction hparamsi hstruct targetWith x g = do
  p <- generateMomenta mus ms g
  eRan <- uniformRM (eL, eR) g
  let -- Mean number of leapfrog steps for the drawn scaling factor.
      lM = la / eRan
      -- Jitter the number of steps, but always take at least one step.
      lL = max (1 :: Int) (floor $ 0.9 * lM)
      lR = max lL (ceiling $ 1.1 * lM)
  lRan <- uniformRM (lL, lR) g
  case leapfrog (targetWith x) msI lRan eRan q p of
    -- 'leapfrog' produced no trajectory; reject and report rate 0.
    Nothing -> pure (ForceReject, Just $ AcceptanceRates 0 1)
    Just (q', p', prQ, prQ') -> do
      let -- Kinetic-energy terms of the initial and final momenta.
          prP = exponentialKineticEnergy msI p
          prP' = exponentialKineticEnergy msI p'
          -- Metropolis-Hastings ratio (in the log domain of 'Log').
          r = prQ' * prP' / (prQ * prP)
      accept <- mhgAccept r g
      let pr = if accept then ForceAccept (fromVec x q') else ForceReject
          -- Expected acceptance rate, clamped to [0, 1].
          ar = max 0 $ min 1 (exp $ ln r)
      pure (pr, Just $ AcceptanceRates ar 1)
  where
    (HParamsI e la ms _ _ msI mus) = hparamsi
    (HStructure _ toVec fromVec) = hstruct
    q = toVec x
    eL = 0.9 * e
    eR = 1.1 * e
-- | Hamiltonian Monte Carlo (HMC) proposal.
--
-- Building the proposal involves the following:
--
-- * Target function. The prior, likelihood and (optional) Jacobian from the
--   'HTarget' are multiplied together. The log of the product is
--   differentiated with 'grad'', giving a function that returns the target
--   value and its gradient as a vector.
--
-- * Internal parameters. The 'HParams' are resolved into 'HParamsI' with
--   'hParamsIWith'.
--
-- * Tuner. A tuner is attached when 'hTuningFunctionWith' yields a tuning
--   function for the given 'HTuningConf'.
--
-- Calls 'error' in two cases:
--
-- * 'hParamsIWith' returns 'Left'.
-- * 'checkHStructureWith' reports a problem with the masses and the
--   'HStructure'.
hamiltonian ::
  Traversable s =>
  HParams ->
  HTuningConf ->
  HStructure s ->
  HTarget s ->
  PName ->
  PWeight ->
  Proposal (s Double)
hamiltonian hparams htconf hstruct htarget n w =
  let desc = PDescription "Hamiltonian Monte Carlo (HMC)"
      (HStructure sample toVec fromVec) = hstruct
      -- Dimension of the vectorized state.
      dim = L.size $ toVec sample
      -- NOTE(review): 0.65 is presumably the target acceptance rate of the
      -- 'PSpecial' dimension; confirm against "Mcmc.Proposal".
      pDim = PSpecial dim 0.65
      -- Combine prior, likelihood and Jacobian (whichever are present).
      (HTarget mPrF lhF mJcF) = htarget
      tF y = case (mPrF, mJcF) of
        (Nothing, Nothing) -> lhF y
        (Just prF, Nothing) -> prF y * lhF y
        (Nothing, Just jcF) -> lhF y * jcF y
        (Just prF, Just jcF) -> prF y * lhF y * jcF y
      -- Log target together with its gradient (reverse-mode AD).
      tFnG = grad' (ln . tF)
      -- Vectorized target: evaluate at the state rebuilt from a vector and
      -- return the target value and the gradient as a vector.
      targetWith x = bimap Exp toVec . tFnG . fromVec x
      (HParams mEps mLa mMs) = hparams
      hParamsI =
        either error id $
          hParamsIWith (targetWith sample) (toVec sample) mEps mLa mMs
      ps = hamiltonianPFunction hParamsI hstruct targetWith
      hamiltonianWith = Proposal n desc PSlow pDim w ps
      -- Tuning: the internal parameters are stored as auxiliary tuning
      -- parameters and decoded again by the proposal-function constructor.
      ts = toAuxiliaryTuningParameters hParamsI
      tuner = do
        tfun <- hTuningFunctionWith dim toVec htconf
        let pfun = hamiltonianPFunctionWithTuningParameters dim hstruct targetWith
        -- NOTE(review): the meaning of the two 'True' flags is defined by
        -- 'Tuner' in "Mcmc.Proposal"; confirm there.
        pure $ Tuner 1.0 ts True True tfun pfun
   in case checkHStructureWith (hpsMasses hParamsI) hstruct of
        Just err -> error err
        Nothing -> hamiltonianWith tuner