{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}
module Mcmc.Metropolis
( mh,
mhContinue,
)
where
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.State
import Data.Aeson
import Data.Maybe
import Mcmc.Item
import Mcmc.Mcmc
import Mcmc.Proposal
import Mcmc.Status
import Mcmc.Trace
import Numeric.Log
import System.Random.MWC
import Prelude hiding (cycle)
-- | Metropolis-Hastings acceptance ratio.
--
-- Given the unnormalized posterior of the current state (@fX@), that of the
-- candidate state (@fY@), and the two correction factors @q@ and @j@ returned
-- by the proposal, this computes @fY / fX * q * j@. All values are 'Log Double',
-- so the arithmetic is carried out in log domain.
--
-- NOTE(review): @q@ and @j@ are presumably the Hastings (proposal density)
-- ratio and the Jacobian; confirm against 'ProposalSimple'.
mhRatio :: Log Double -> Log Double -> Log Double -> Log Double -> Log Double
mhRatio fX fY q j = posteriorRatio * q * j
  where
    -- Kept as the leftmost factor so the association matches
    -- ((fY / fX) * q) * j exactly.
    posteriorRatio = fY / fX
{-# INLINE mhRatio #-}
-- | Perform one Metropolis-Hastings step with a single proposal.
--
-- A candidate is sampled from the proposal. Its prior and likelihood are
-- evaluated with the functions stored in the 'Status', and it is accepted or
-- rejected according to 'mhRatio'. On acceptance the current 'Item' is
-- replaced by the candidate. In either case the outcome is pushed onto the
-- acceptance record of this proposal.
mhPropose :: Proposal a -> Mcmc a ()
mhPropose m = do
  st <- get
  let Item xCur priorCur likeCur = item st
      gen = generator st
      -- Acceptance record updated with the outcome for proposal @m@.
      record accepted = pushA m accepted (acceptance st)
  -- Sample a candidate. NOTE(review): @qRatio@ and @jac@ are presumably the
  -- proposal (Hastings) ratio and the Jacobian; see 'ProposalSimple'.
  (!yNew, !qRatio, !jac) <- liftIO $ pSimple m xCur gen
  let !priorNew = priorF st yNew
      !likeNew = likelihoodF st yNew
      !ratio = mhRatio (priorCur * likeCur) (priorNew * likeNew) qRatio jac
  -- A ratio of at least one (log ratio >= 0) is accepted outright. Only
  -- otherwise is a uniform variate drawn from the generator and compared
  -- against the ratio.
  accepted <-
    if ln ratio >= 0.0
      then return True
      else (< exp (ln ratio)) <$> uniform gen
  put $
    if accepted
      then st {item = Item yNew priorNew likeNew, acceptance = record True}
      else st {acceptance = record False}
-- | Run a single iteration.
--
-- Every proposal in the list is applied in order with 'mhPropose'. Then the
-- current item is pushed onto the trace and the iteration counter is
-- incremented. Finally 'mcmcClean' and 'mcmcMonitorExec' are run.
mhIter :: ToJSON a => [Proposal a] -> Mcmc a ()
mhIter ps = do
  forM_ ps mhPropose
  modify $ \st ->
    st {trace = pushT (item st) (trace st), iteration = succ (iteration st)}
  mcmcClean
  mcmcMonitorExec
-- | Run @n@ iterations.
--
-- The proposal lists for all @n@ iterations are obtained from the status'
-- cycle via 'getNIterations', using the status' generator. Each list is then
-- executed with 'mhIter'.
mhNIter :: ToJSON a => Int -> Mcmc a ()
mhNIter n = do
  mcmcDebugS $ "Run " <> show n <> " iterations."
  st <- get
  plan <- liftIO $ getNIterations (cycle st) n (generator st)
  mapM_ mhIter plan
-- | Burn in for @b@ iterations, optionally with auto tuning every @t@ iterations.
--
-- * With 'Nothing', the @b@ iterations are run without tuning.
-- * With @'Just' t@, the burn in is split into chunks of @t@ iterations. While
--   more than @t@ iterations remain, a chunk runs with fresh acceptance
--   statistics. Its summary is logged at debug level and 'mcmcAutotune' is
--   called. The final chunk (at most @t@ iterations) is not followed by
--   tuning, and its summary is logged at info level.
--
-- Calls 'error' if the tuning period @t@ is zero or negative.
mhBurnInN :: ToJSON a => Int -> Maybe Int -> Mcmc a ()
mhBurnInN b Nothing = mhNIter b
mhBurnInN b (Just t)
  | t <= 0 = error "mhBurnInN: Auto tuning period smaller equal 0."
  | t < b = tuneChunk >> mhBurnInN (b - t) (Just t)
  | otherwise = finalChunk
  where
    -- A full tuning period followed by auto tuning.
    tuneChunk = do
      mcmcResetA
      mhNIter t
      mcmcSummarizeCycle >>= mcmcDebugB
      mcmcAutotune
    -- The remaining iterations; no tuning afterwards.
    finalChunk = do
      mcmcResetA
      mhNIter b
      mcmcSummarizeCycle >>= mcmcInfoB
      mcmcInfoS $
        "Acceptance ratios calculated over the last "
          <> show b
          <> " iterations."
-- | Burn in for @b@ iterations with optional auto tuning period @t@.
--
-- Does nothing when @b@ is zero. Calls 'error' when @b@ is negative.
-- Otherwise it logs the settings, delegates to 'mhBurnInN', and logs
-- completion.
mhBurnIn :: ToJSON a => Int -> Maybe Int -> Mcmc a ()
mhBurnIn b t = case compare b 0 of
  LT -> error "mhBurnIn: Negative number of burn in iterations."
  EQ -> return ()
  GT -> do
    mcmcInfoS $ "Burn in for " <> show b <> " cycles."
    mcmcDebugS $ "Auto tuning period is " <> show t <> "."
    mhBurnInN b t
    mcmcInfoB "Burn in finished."
-- | Reset the acceptance statistics, then run the chain for @n@ iterations.
mhRun :: ToJSON a => Int -> Mcmc a ()
mhRun n =
  mcmcResetA
    >> mcmcInfoS ("Run chain for " <> show n <> " iterations.")
    >> mhNIter n
-- | Complete Metropolis-Hastings run: header, cycle summary, report, burn in,
-- and the main run.
--
-- The burn-in length defaults to 0 when 'burnInIterations' is 'Nothing'.
-- The burn-in length, tuning period and number of iterations are all read
-- from the status as it is before the burn in starts.
mhT :: ToJSON a => Mcmc a ()
mhT = do
  mcmcInfoB "Metropolis-Hastings sampler."
  mcmcSummarizeCycle >>= mcmcInfoB
  mcmcReport
  (nBurn, tuning, nIter) <-
    gets $ \st ->
      (fromMaybe 0 (burnInIterations st), autoTuningPeriod st, iterations st)
  mhBurnIn nBurn tuning
  mhRun nIter
-- | Continue a chain for @dn@ additional iterations, logging a header and the
-- cycle summary first.
mhContinueT :: ToJSON a => Int -> Mcmc a ()
mhContinueT dn = do
  mcmcInfoB "Continuation of Metropolis-Hastings sampler."
  mcmcInfoS announcement
  mcmcInfoB =<< mcmcSummarizeCycle
  mhRun dn
  where
    announcement = "Run chain for " <> show dn <> " additional iterations."
-- | Continue a Markov chain for @dn@ more iterations.
--
-- The requested total 'iterations' of the status is increased by @dn@, and
-- the continuation is executed with 'mcmcRun'. Calls 'error' if @dn@ is zero
-- or negative.
mhContinue ::
  ToJSON a =>
  -- | Number of additional iterations.
  Int ->
  Status a ->
  IO (Status a)
mhContinue dn s
  | dn > 0 = mcmcRun (mhContinueT dn) extended
  | otherwise = error "mhContinue: The number of iterations is zero or negative."
  where
    extended = s {iterations = iterations s + dn}
-- | Run a fresh Metropolis-Hastings Markov chain.
--
-- The status must be at iteration 0. Otherwise a hint pointing to
-- 'mhContinue' is printed to standard output and 'error' is called.
mh ::
  ToJSON a =>
  Status a ->
  IO (Status a)
mh s
  | current == 0 = mcmcRun mhT s
  | otherwise = do
    putStrLn "To continue a Markov chain run, please use 'mhContinue'."
    error $ "mh: Current iteration " ++ show current ++ " is non-zero."
  where
    current = iteration s