{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE RecordWildCards #-}
module Numeric.MCMC.Metropolis (
mcmc
, chain
, chain'
, metropolis
, module Data.Sampling.Types
, MWC.create
, MWC.createSystemRandom
, MWC.withSystemRandom
, MWC.asGenIO
) where
import Control.Monad (when, replicateM)
import Control.Monad.Codensity (lowerCodensity)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict (execStateT, get, put)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Sampling.Types (Target(..), Chain(..), Transition)
#if __GLASGOW_HASKELL__ < 710
import Data.Traversable (Traversable, traverse)
#endif
import Pipes (Producer, Consumer, yield, (>->), runEffect, await)
import qualified Pipes.Prelude as Pipes (mapM_, take)
import System.Random.MWC.Probability (Gen, Prob)
import qualified System.Random.MWC.Probability as MWC
-- | Build a candidate point by jittering every coordinate of the current
--   position with Gaussian noise.
--
--   Each coordinate @x@ is replaced by a draw from @'MWC.normal' x radial@
--   (mean @x@, standard deviation @radial@ per mwc-probability's
--   documentation). Coordinates are sampled in 'traverse' order, so the
--   container's shape is preserved.
propose
  :: (PrimMonad m, Traversable f)
  => Double             -- ^ perturbation radius
  -> f Double           -- ^ current position
  -> Prob m (f Double)
propose radial = traverse (\coordinate -> MWC.normal coordinate radial)
-- | One Metropolis transition over the chain state.
--
--   The transition runs in three steps:
--
--   1. Draw a candidate position with 'propose'.
--   2. Accept it with probability @exp (min 0 (candidateScore - chainScore))@,
--      where a NaN probability is treated as 0 (reject).
--   3. On acceptance, replace the score and position. If a tunables function
--      was supplied, also recompute the tunables from the new position; with
--      'Nothing', 'chainTunables' becomes 'Nothing'.
--
--   On rejection the state is left exactly as it was.
metropolis
  :: (Traversable f, PrimMonad m)
  => Double                 -- ^ proposal radius handed to 'propose'
  -> Maybe (f Double -> b)  -- ^ optional map from an accepted position to tunables
  -> Transition m (Chain (f Double) b)
metropolis radial tunable = do
  current <- get
  let target = chainTarget current
  candidate <- lift (propose radial (chainPosition current))
  let candidateScore = lTarget target candidate
      -- Scores enter only through exp of their difference, so they act
      -- like log-densities (presumably; confirm against callers).
      logAlpha       = min 0 (candidateScore - chainScore current)
      acceptProb     = whenNaN 0 (exp logAlpha)
  accepted <- lift (MWC.bernoulli acceptProb)
  when accepted $
    put current
      { chainScore    = candidateScore
      , chainPosition = candidate
      , chainTunables = fmap ($ candidate) tunable
      }
-- | Endless stream of chain states.
--
--   Each yielded state is produced from the previous one by a single
--   'metropolis' transition sampled with the given generator. The starting
--   state itself is never yielded.
drive
  :: (Traversable f, PrimMonad m)
  => Double
  -> Maybe (f Double -> b)
  -> Chain (f Double) b
  -> Gen (PrimState m)
  -> Producer (Chain (f Double) b) m c
drive radial tunable start prng = go start
  where
    step = metropolis radial tunable

    go current = do
      successor <- lift (MWC.sample (execStateT step current) prng)
      yield successor
      go successor
-- | Run the Metropolis chain and collect @n@ successive states in order.
--
--   The starting state is not part of the result. When @n <= 0@, the result
--   is the empty list. If a tunables function is given, it is applied to the
--   starting position and to every accepted position.
chain'
  :: (PrimMonad m, Traversable f)
  => Int                    -- ^ number of states to collect
  -> Double                 -- ^ proposal radius
  -> f Double               -- ^ starting position
  -> (f Double -> Double)   -- ^ target (scoring) function
  -> Maybe (f Double -> b)  -- ^ optional tunables function
  -> Gen (PrimState m)      -- ^ random number generator
  -> m [Chain (f Double) b]
chain' n radial position target tunable gen =
  runEffect (drive radial tunable origin gen >-> takeList n)
  where
    ctarget = Target target Nothing

    origin = Chain
      { chainTarget   = ctarget
      , chainScore    = lTarget ctarget position
      , chainPosition = position
      , chainTunables = fmap ($ position) tunable
      }

    -- Pull exactly k values from upstream. The Codensity wrapper
    -- reassociates the binds replicateM builds. NOTE(review): presumably a
    -- performance measure for the pipes Proxy monad; confirm.
    takeList :: Monad m => Int -> Consumer a m [a]
    takeList k = lowerCodensity (replicateM k (lift await))
-- | Same as 'chain'', but with no tunables function.
--
--   Every collected state has 'chainTunables' set to 'Nothing'.
chain
  :: (PrimMonad m, Traversable f)
  => Int                    -- ^ number of states to collect
  -> Double                 -- ^ proposal radius
  -> f Double               -- ^ starting position
  -> (f Double -> Double)   -- ^ target (scoring) function
  -> Gen (PrimState m)      -- ^ random number generator
  -> m [Chain (f Double) b]
chain n radial position target gen =
  chain' n radial position target Nothing gen
-- | Run the chain for at most @n@ steps, printing each state with 'print'.
--
--   States are printed as they are produced. The starting state is not
--   printed, and no tunables are tracked.
mcmc
  :: (MonadIO m, PrimMonad m, Traversable f, Show (f Double))
  => Int                    -- ^ number of states to print
  -> Double                 -- ^ proposal radius
  -> f Double               -- ^ starting position
  -> (f Double -> Double)   -- ^ target (scoring) function
  -> Gen (PrimState m)      -- ^ random number generator
  -> m ()
mcmc n radial position target gen =
  runEffect $
        drive radial Nothing origin gen
    >-> Pipes.take n
    >-> Pipes.mapM_ (\current -> liftIO (print current))
  where
    ctarget = Target target Nothing

    origin = Chain
      { chainTarget   = ctarget
      , chainScore    = lTarget ctarget position
      , chainPosition = position
      , chainTunables = Nothing
      }
-- | Substitute a fallback for NaN.
--
--   Every non-NaN input, including the infinities, passes through unchanged.
whenNaN :: RealFloat a => a -> a -> a
whenNaN fallback x = if isNaN x then fallback else x