module Numeric.MCMC.Metropolis (
mcmc
, metropolis
, module Data.Sampling.Types
, MWC.create
, MWC.createSystemRandom
, MWC.withSystemRandom
, MWC.asGenIO
) where
import Control.Monad (when)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict (execStateT, get, put)
import Data.Sampling.Types (Target(..), Chain(..), Transition)
#if __GLASGOW_HASKELL__ < 710
import Data.Traversable (Traversable, traverse)
#endif
import GHC.Prim (RealWorld)
import Pipes (Producer, yield, (>->), runEffect)
import qualified Pipes.Prelude as Pipes (mapM_, take)
import System.Random.MWC.Probability (Gen, Prob)
import qualified System.Random.MWC.Probability as MWC
-- | Build a random-walk proposal from a position.
--
-- Each coordinate is replaced, independently, by a draw from 'MWC.normal'
-- centred on that coordinate with spread @radial@; the container's shape is
-- preserved because the draws are made with 'traverse'.
propose
  :: (PrimMonad m, Traversable f)
  => Double
  -> f Double
  -> Prob m (f Double)
propose radial = traverse (`MWC.normal` radial)
-- | One Metropolis transition with a Gaussian random-walk proposal.
--
-- Every coordinate of the current position is perturbed with spread
-- @radial@ (see 'propose'), and the proposal is scored with the chain's
-- target. It is then accepted with probability
-- @min 1 (exp (proposalScore - currentScore))@, so scores are treated as
-- log-densities. A NaN acceptance probability is replaced by 0, so such a
-- proposal is rejected. On acceptance the chain's target and tunables are
-- carried over unchanged; on rejection the state is left untouched.
metropolis
  :: (Traversable f, PrimMonad m)
  => Double
  -> Transition m (Chain (f Double) b)
metropolis radial = do
  -- Positional match mirrors the positional construction below:
  -- target, score, position, tunables.
  Chain target currentScore current tunables <- get
  proposal <- lift (propose radial current)
  let proposalScore = lTarget target proposal
      -- Log-space acceptance ratio; capping at 0 keeps the probability <= 1.
      logRatio      = proposalScore - currentScore
      acceptProb    = whenNaN 0 (exp (min 0 logRatio))
  accept <- lift (MWC.bernoulli acceptProb)
  when accept (put (Chain target proposalScore proposal tunables))
-- | An unbounded stream of Metropolis states.
--
-- Starting from the given state, each step runs one 'metropolis' transition
-- (proposal spread @radial@) against the supplied generator and yields the
-- resulting state. The starting state itself is not yielded. The same
-- generator is reused for every step. The producer never finishes on its own,
-- so consumers must bound it (e.g. with 'Pipes.take').
chain
  :: (Traversable f, PrimMonad m)
  => Double
  -> Chain (f Double) b
  -> Gen (PrimState m)
  -> Producer (Chain (f Double) b) m ()
chain radial = go
  where
    go current gen = do
      let step = execStateT (metropolis radial) current
      current' <- lift (MWC.sample step gen)
      yield current'
      go current' gen
-- | Run a Metropolis chain for @n@ steps, printing each state with 'print'.
--
-- The chain starts at the given position, scored with the supplied target
-- function. The target is wrapped as @Target target Nothing@, and the chain
-- has no tunables. Proposals use spread @radial@. Only states produced by
-- 'chain' are printed; the starting state is not.
mcmc
  :: (Traversable f, Show (f Double))
  => Int
  -> Double
  -> f Double
  -> (f Double -> Double)
  -> Gen RealWorld
  -> IO ()
mcmc n radial position0 target gen =
  runEffect (states >-> Pipes.take n >-> Pipes.mapM_ print)
  where
    states  = chain radial initial gen
    -- Positional fields: target, score, position, tunables.
    initial = Chain wrapped (lTarget wrapped position0) position0 Nothing
    wrapped = Target target Nothing
-- | Substitute a fallback for NaN. Every other value, including the
-- infinities, is returned as is.
whenNaN :: RealFloat a => a -> a -> a
whenNaN fallback x = if isNaN x then fallback else x