{-# LANGUAGE BangPatterns #-}
module Mcmc.Proposal.Hamiltonian.Internal
(
HParamsI (..),
hParamsIWith,
toAuxiliaryTuningParameters,
fromAuxiliaryTuningParameters,
findReasonableEpsilon,
hTuningFunctionWith,
checkHStructureWith,
generateMomenta,
exponentialKineticEnergy,
Target,
leapfrog,
)
where
import Control.Monad
import Control.Monad.ST
import Data.Foldable
import Data.Maybe
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Unboxed as VU
import Mcmc.Proposal
import Mcmc.Proposal.Hamiltonian.Common
import Mcmc.Proposal.Hamiltonian.Masses
import qualified Numeric.LinearAlgebra as L
import Numeric.Log
import System.Random.Stateful
-- | State of the dual averaging scheme that tunes the leapfrog scaling factor.
-- These values change on every tuning step (see 'hTuningFunctionWith').
data TParamsVar = TParamsVar
  { tpvLeapfrogScalingFactorMean :: LeapfrogScalingFactor
  -- ^ Weighted geometric mean of the scaling factors proposed so far.
  , tpvHStatistics :: Double
  -- ^ Running weighted average of (optimal rate - observed acceptance rate).
  , tpvCurrentTuningStep :: Double
  -- ^ Tuning step counter. It is kept as a 'Double' and used directly in the
  -- floating-point update formulas.
  }
  deriving (Show)
-- | Initial dual averaging state: mean scaling factor 1, H statistic 0, and
-- tuning step counter 1.
tParamsVar :: TParamsVar
tParamsVar = TParamsVar epsMean0 h0 step0
  where
    epsMean0 = 1.0
    h0 = 0.0
    step0 = 1.0
-- | Constants of the dual averaging scheme. They are set once by
-- 'tParamsFixedWith' and not changed by tuning. The update formulas in
-- 'hTuningFunctionWith' appear to follow Hoffman & Gelman (2014),
-- Algorithm 5.
data TParamsFixed = TParamsFixed
  { tpfEps0 :: Double
  -- ^ Scaling factor the scheme starts from. The reported tuning parameter is
  -- the current scaling factor divided by this value.
  , tpfMu :: Double
  -- ^ Log-scale shrinkage target for the scaling factor.
  , tpfGa :: Double
  -- ^ Divisor of @sqrt m * h@ in the scaling factor update (gamma).
  , tpfT0 :: Double
  -- ^ Offset added to the step counter when weighting the H statistic (t0).
  , tpfKa :: Double
  -- ^ Exponent of the step-dependent weight @m ** (-ka)@ used for the mean
  -- scaling factor (kappa).
  }
  deriving (Show)
-- | Dual averaging constants for a starting scaling factor @eps@:
-- @mu = log (10 * eps)@, @gamma = 0.15@, @t0 = 10@, and @kappa = 0.75@.
tParamsFixedWith :: LeapfrogScalingFactor -> TParamsFixed
tParamsFixedWith eps =
  let muTarget = log (10 * eps)
   in TParamsFixed eps muTarget 0.15 10 0.75
-- | Internal, validated parameters of the Hamiltonian proposal. Values are
-- built by 'hParamsIWith' and round-tripped through
-- 'toAuxiliaryTuningParameters' / 'fromAuxiliaryTuningParameters'.
data HParamsI = HParamsI
  { hpsLeapfrogScalingFactor :: LeapfrogScalingFactor
  -- ^ Current leapfrog scaling factor.
  , hpsLeapfrogSimulationLength :: LeapfrogSimulationLength
  -- ^ Leapfrog simulation length. Tuning does not change it.
  , hpsMasses :: Masses
  -- ^ Mass matrix.
  , hpsTParamsVar :: TParamsVar
  -- ^ Dual averaging state.
  , hpsTParamsFixed :: TParamsFixed
  -- ^ Dual averaging constants.
  , hpsMassesI :: MassesI
  -- ^ Derived from 'hpsMasses' via 'getMassesI'. Presumably the inverse mass
  -- matrix; TODO confirm.
  , hpsMu :: Mu
  -- ^ Derived from 'hpsMasses' via 'getMus'. Passed as the mean to
  -- 'generateMomenta'.
  }
  deriving (Show)
-- | Fallback leapfrog scaling factor. 'findReasonableEpsilon' returns it when
-- the initial leapfrog step fails.
defaultLeapfrogScalingFactor :: LeapfrogScalingFactor
defaultLeapfrogScalingFactor = 1e-1
-- | Leapfrog simulation length that 'hParamsIWith' uses when none is given.
defaultLeapfrogSimulationLength :: LeapfrogSimulationLength
defaultLeapfrogSimulationLength = 5e-1
-- | Identity mass matrix of the given dimension.
defaultMassesWith :: Int -> Masses
defaultMassesWith = L.trustSym . L.ident
-- | Validate user-provided Hamiltonian parameters and fill in defaults.
--
-- * The position vector must not be empty. Its length @d@ is the dimension.
-- * Without masses, the @d x d@ identity matrix is used. Given masses must
--   have strictly positive diagonal entries (after 'cleanMatrix'), must be
--   square, and must have @d@ rows.
-- * Without a leapfrog simulation length, 'defaultLeapfrogSimulationLength'
--   is used. A given length must be strictly positive.
-- * Without a leapfrog scaling factor, 'findReasonableEpsilon' searches for
--   one. It uses a generator seeded with 42, so the result is deterministic.
--   A given factor must be strictly positive.
--
-- On failure, returns 'Left' with a message prefixed by @"hParamsIWith: "@.
hParamsIWith ::
  Target ->
  Positions ->
  Maybe LeapfrogScalingFactor ->
  Maybe LeapfrogSimulationLength ->
  Maybe Masses ->
  Either String HParamsI
hParamsIWith htarget p mEps mLa mMs = do
  d <- case VS.length p of
    0 -> eWith "Empty position vector."
    n -> Right n
  ms <- case mMs of
    Nothing -> Right $ defaultMassesWith d
    Just msGiven -> do
      let ms' = cleanMatrix $ L.unSym msGiven
          diagonalMs = L.toList $ L.takeDiag ms'
      when (any (<= 0) diagonalMs) $
        eWith "Some diagonal masses are zero or negative."
      let nrows = L.rows ms'
          ncols = L.cols ms'
      when (nrows /= ncols) $
        eWith "Mass matrix is not square."
      -- Without this check, a size mismatch passes validation and only
      -- surfaces later as an hmatrix runtime error (e.g., while searching
      -- for a scaling factor below).
      when (nrows /= d) $
        eWith "Mass matrix and position vector have different sizes."
      Right msGiven
  la <- case mLa of
    Nothing -> Right defaultLeapfrogSimulationLength
    Just l
      | l <= 0 -> eWith "Leapfrog simulation length is zero or negative."
      | otherwise -> Right l
  eps <- case mEps of
    Nothing -> Right $ runST $ do
      g <- newSTGenM $ mkStdGen 42
      findReasonableEpsilon htarget ms p g
    Just e
      | e <= 0 -> eWith "Leapfrog scaling factor is zero or negative."
      | otherwise -> Right e
  let msI = getMassesI ms
      mus = getMus ms
      tParamsFixed = tParamsFixedWith eps
  pure $ HParamsI eps la ms tParamsVar tParamsFixed msI mus
  where
    eWith :: String -> Either String b
    eWith m = Left $ "hParamsIWith: " <> m
-- | Serialize internal parameters into a flat unboxed vector.
--
-- Layout:
--
-- * ten scalars: scaling factor, simulation length, mean scaling factor,
--   H statistic, tuning step, eps0, mu, gamma, t0, kappa;
-- * followed by the masses, encoded with 'massesToVector'.
--
-- 'hpsMassesI' and 'hpsMu' are not stored. 'fromAuxiliaryTuningParameters'
-- recomputes them from the masses.
toAuxiliaryTuningParameters :: HParamsI -> AuxiliaryTuningParameters
toAuxiliaryTuningParameters hps =
  VU.fromList scalars VU.++ massesToVector (hpsMasses hps)
  where
    tpv = hpsTParamsVar hps
    tpf = hpsTParamsFixed hps
    scalars =
      [ hpsLeapfrogScalingFactor hps
      , hpsLeapfrogSimulationLength hps
      , tpvLeapfrogScalingFactorMean tpv
      , tpvHStatistics tpv
      , tpvCurrentTuningStep tpv
      , tpfEps0 tpf
      , tpfMu tpf
      , tpfGa tpf
      , tpfT0 tpf
      , tpfKa tpf
      ]
-- | Inverse of 'toAuxiliaryTuningParameters' for dimension @d@.
--
-- Expects exactly @10 + d * d@ entries: ten scalars followed by the masses.
-- The derived inverse masses and momentum means are recomputed from the
-- decoded masses. Returns 'Left' with an explanatory message on a length
-- mismatch.
fromAuxiliaryTuningParameters :: Dimension -> AuxiliaryTuningParameters -> Either String HParamsI
fromAuxiliaryTuningParameters d xs
  | VU.length xs /= nScalars + d * d =
      Left "fromAuxiliaryTuningParameters: Dimension mismatch."
  | VU.length massVec /= d * d =
      Left "fromAuxiliaryTuningParameters: Masses dimension mismatch."
  | otherwise = decode $ VU.toList scalarVec
  where
    nScalars :: Int
    nScalars = 10
    (scalarVec, massVec) = VU.splitAt nScalars xs
    ms = vectorToMasses d massVec
    decode [eps, la, epsMean, h, m, eps0, mu, ga, t0, ka] =
      Right $
        HParamsI
          eps
          la
          ms
          (TParamsVar epsMean h m)
          (TParamsFixed eps0 mu ga t0 ka)
          (getMassesI ms)
          (getMus ms)
    decode _ = Left "fromAuxiliaryTuningParameters: Impossible dimension mismatch."
-- | Heuristic search for an initial leapfrog scaling factor. It appears to
-- follow Hoffman & Gelman (2014), Algorithm 4.
--
-- 1. Draw one momentum vector and simulate a single leapfrog step with
--    scaling factor 1.
-- 2. If the density ratio (proposed / initial) exceeds 0.5, keep doubling the
--    scaling factor; otherwise keep halving it.
-- 3. Stop at the first scaling factor whose ratio crosses 0.5 and return it.
--
-- If the very first leapfrog step fails, 'defaultLeapfrogScalingFactor' is
-- returned. If a later step fails, the last scaling factor with a successful
-- step is returned.
findReasonableEpsilon ::
  (StatefulGen g m) =>
  Target ->
  Masses ->
  Positions ->
  g ->
  m LeapfrogScalingFactor
findReasonableEpsilon t ms q g = do
  p <- generateMomenta mu ms g
  case leapfrog t msI 1 eI q p of
    Nothing -> pure defaultLeapfrogScalingFactor
    Just (_, p', prQ, prQ') -> do
      let -- Unnormalized joint density of the initial state (q, p).
          prInitial :: Log Double
          prInitial = prQ * exponentialKineticEnergy msI p
          -- Ratio of the joint density of a proposed state to the initial one.
          ratio :: Log Double -> Momenta -> Double
          ratio prQNew pNew =
            exp $ ln $ prQNew * exponentialKineticEnergy msI pNew / prInitial
          rI :: Double
          rI = ratio prQ' p'
          a :: Double
          a = if rI > 0.5 then 1 else (-1)
          -- Invariant: @r@ is the density ratio obtained with scaling factor
          -- @e@. The next factor is simulated before recursing; previously
          -- the old factor was re-simulated, which overshot by one step.
          go :: Double -> Double -> Double
          go e r
            | r ** a > 2 ** negate a =
                let e' = (2 ** a) * e
                 in case leapfrog t msI 1 e' q p of
                      Nothing -> e
                      Just (_, p'', _, prQ'') -> go e' (ratio prQ'' p'')
            | otherwise = e
      pure $ go eI rI
  where
    eI :: LeapfrogScalingFactor
    eI = 1.0
    msI = getMassesI ms
    mu = getMus ms
-- | Tuning function of the Hamiltonian proposal.
--
-- Returns 'Nothing' when neither the leapfrog parameters nor the masses are
-- tuned. Otherwise, the returned function decodes the auxiliary tuning
-- parameters (dimension @n@) and then depends on the 'TuningType':
--
-- * Fast-proposal-only intermediate and normal tuning steps raise an 'error'
--   (this proposal is slow).
-- * Masses are left unchanged on 'IntermediateTuningAllProposals'. Otherwise
--   they are tuned from the trace (which must be present) according to the
--   'HTuneMasses' setting. The current tuning step, rounded, is the
--   'SmoothingParameter'.
-- * The scaling factor is left unchanged on 'LastTuningFastProposalsOnly' or
--   when leapfrog tuning is disabled. Otherwise one dual averaging update is
--   done with the acceptance rate (which must be present) and the optimal
--   rate for the proposal dimension. On 'LastTuningAllProposals', the
--   averaged scaling factor replaces the current iterate.
--
-- The tuning step counter is always incremented. The returned tuning
-- parameter is the new scaling factor divided by the initial one.
hTuningFunctionWith ::
  Dimension ->
  (a -> Positions) ->
  HTuningConf ->
  Maybe (TuningFunction a)
hTuningFunctionWith _ _ (HTuningConf HNoTuneLeapfrog HNoTuneMasses) = Nothing
hTuningFunctionWith n toVec (HTuningConf lc mc) = Just tuneWith
  where
    err :: String -> b
    err msg = error $ "hTuningFunctionWith: " <> msg

    -- A single tuning step; the bang forces the incoming auxiliary vector.
    tuneWith tt pdim mar mxs (_, !ts) = case tt of
      IntermediateTuningFastProposalsOnly ->
        err "fast intermediate tuning step but slow proposal"
      NormalTuningFastProposalsOnly ->
        err "fast normal tuning step but slow proposal"
      _ -> (eps'' / tpfEps0 tpf, toAuxiliaryTuningParameters hps')
      where
        hps = either error id $ fromAuxiliaryTuningParameters n ts
        tpv = hpsTParamsVar hps
        tpf = hpsTParamsFixed hps
        epsMean = tpvLeapfrogScalingFactorMean tpv
        h = tpvHStatistics tpv
        m = tpvCurrentTuningStep tpv

        -- Masses.
        current = (hpsMasses hps, hpsMassesI hps)
        smoothing = SmoothingParameter $ round m
        samples = fromMaybe (err "empty trace") mxs
        (ms', msI') = case tt of
          IntermediateTuningAllProposals -> current
          _ -> case mc of
            HNoTuneMasses -> current
            HTuneDiagonalMassesOnly ->
              tuneDiagonalMassesOnly smoothing toVec samples current
            HTuneAllMasses ->
              tuneAllMasses smoothing toVec samples current

        -- Leapfrog scaling factor (dual averaging).
        unchanged = (hpsLeapfrogScalingFactor hps, epsMean, h)
        ar = fromMaybe (err "no acceptance rate") mar
        delta = getOptimalRate pdim
        c = recip (m + tpfT0 tpf)
        hNew = (1.0 - c) * h + c * (delta - ar)
        epsNew = exp $ tpfMu tpf - (sqrt m / tpfGa tpf) * hNew
        w = m ** negate (tpfKa tpf)
        epsMeanNew = (epsNew ** w) * (epsMean ** (1 - w))
        dualAveraging
          | tt == LastTuningAllProposals = (epsMeanNew, epsMeanNew, hNew)
          | otherwise = (epsNew, epsMeanNew, hNew)
        (eps'', epsMean'', h'') = case (tt, lc) of
          (LastTuningFastProposalsOnly, _) -> unchanged
          (_, HNoTuneLeapfrog) -> unchanged
          (_, HTuneLeapfrog) -> dualAveraging

        hps' =
          hps
            { hpsLeapfrogScalingFactor = eps'',
              hpsMasses = ms',
              hpsTParamsVar = TParamsVar epsMean'' h'' (m + 1.0),
              hpsMassesI = msI'
            }
-- | Sanity checks of a Hamiltonian structure against a mass matrix.
--
-- Returns an error message in two cases:
--
-- * converting the sample state to a vector and back does not reproduce it
--   (compared element-wise via 'toList');
-- * the vector length differs from the number of rows of the mass matrix.
--
-- Returns 'Nothing' when both checks pass.
checkHStructureWith :: (Foldable s) => Masses -> HStructure s -> Maybe String
checkHStructureWith ms (HStructure x toVec fromVec)
  | roundTrip /= toList x =
      failWith "'fromVectorWith x (toVector x) /= x' for sample state."
  | L.size xVec /= L.rows (L.unSym ms) =
      failWith "Mass matrix and 'toVector x' have different sizes for sample state."
  | otherwise = Nothing
  where
    xVec = toVec x
    roundTrip = toList $ fromVec x xVec
    failWith :: String -> Maybe String
    failWith = Just . ("checkHStructureWith: " <>)
-- | Draw one momentum vector from a multivariate normal distribution with
-- mean @mu@ and covariance matrix @masses@.
--
-- An integer seed is drawn from the stateful generator. hmatrix's
-- 'L.gaussianSample' then produces a single row, which is flattened to a
-- vector.
generateMomenta ::
  (StatefulGen g m) =>
  Mu ->
  Masses ->
  g ->
  m Momenta
generateMomenta mu masses gen = drawWith <$> uniformM gen
  where
    drawWith seed = L.flatten $ L.gaussianSample seed 1 mu masses
-- | @exp (-0.5 * p^T msI p)@ in log domain, i.e. the exponential of the
-- negative kinetic energy of the momenta @p@. @msI@ is presumably the inverse
-- mass matrix; TODO confirm.
exponentialKineticEnergy ::
  MassesI ->
  Momenta ->
  Log Double
exponentialKineticEnergy msI xs = Exp $ (-0.5) * quadratic
  where
    quadratic = xs L.<.> (msI L.!#> xs)
-- | Target function. It maps positions to a pair:
--
-- * the (unnormalized) target density in log domain;
-- * a vector that the leapfrog momentum step adds, scaled by the step size,
--   to the momenta (presumably the gradient of the log density; confirm with
--   callers).
type Target = Positions -> (Log Double, Positions)
-- | Simulate a leapfrog trajectory of @l@ steps with scaling factor @eps@.
--
-- 1. A half step of the momenta.
-- 2. @l - 1@ alternating full position and full momentum steps.
-- 3. A final full position step and a final half momentum step.
--
-- Returns the final positions and momenta, together with the target values at
-- the initial and at the final positions. Returns 'Nothing' as soon as the
-- target value at a visited position is not strictly positive.
leapfrog ::
  Target ->
  MassesI ->
  LeapfrogTrajectoryLength ->
  LeapfrogScalingFactor ->
  Positions ->
  Momenta ->
  Maybe (Positions, Momenta, Log Double, Log Double)
leapfrog tF msI l eps q p = do
  -- First half step of the momenta.
  (x, pHalf) <- checked $ leapfrogStepMomenta (0.5 * eps) tF q p
  -- L-1 full steps of positions and momenta.
  (qLM1, pLM1Half) <- go (l - 1) q pHalf
  -- Final full step of the positions.
  let qL = leapfrogStepPositions msI eps qLM1 pLM1Half
  -- Final half step of the momenta.
  (x', pL) <- checked $ leapfrogStepMomenta (0.5 * eps) tF qL pLM1Half
  pure (qL, pL, x, x')
  where
    -- Keep a momentum step only if the target value is strictly positive.
    checked :: (Log Double, Momenta) -> Maybe (Log Double, Momenta)
    checked r@(v, _)
      | v > 0.0 = Just r
      | otherwise = Nothing

    go :: LeapfrogTrajectoryLength -> Positions -> Momenta -> Maybe (Positions, Momenta)
    go n qs ps
      | n <= 0 = Just (qs, ps)
      | otherwise = do
          let qs' = leapfrogStepPositions msI eps qs ps
          -- The momentum update must start from the current momenta @ps@,
          -- not from the initial momenta @p@.
          (_, ps') <- checked $ leapfrogStepMomenta eps tF qs' ps
          go (n - 1) qs' ps'
-- | Leapfrog momentum step.
--
-- Evaluates the target at @q@ and adds @eps@ times the returned vector to the
-- momenta. The target value at @q@ is returned as well, so callers can check
-- it.
leapfrogStepMomenta ::
  LeapfrogScalingFactor ->
  Target ->
  Positions ->
  Momenta ->
  (Log Double, Momenta)
leapfrogStepMomenta eps tf q p =
  let (val, grad) = tf q
   in (val, p + L.scale eps grad)
-- | Leapfrog position step: @q + msI #> (eps * p)@.
leapfrogStepPositions ::
  MassesI ->
  LeapfrogScalingFactor ->
  Positions ->
  Momenta ->
  Positions
leapfrogStepPositions msI eps q p = q + displacement
  where
    displacement = msI L.!#> L.scale eps p