module Numeric.MCMC.Metropolis (
mcmc
, chain
, metropolis
, module Data.Sampling.Types
, MWC.create
, MWC.createSystemRandom
, MWC.withSystemRandom
, MWC.asGenIO
) where
import Control.Monad (when, replicateM)
import Control.Monad.Codensity (lowerCodensity)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict (execStateT, get, put)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Sampling.Types (Target(..), Chain(..), Transition)
#if __GLASGOW_HASKELL__ < 710
import Data.Traversable (Traversable, traverse)
#endif
import Pipes (Producer, Consumer, yield, (>->), runEffect, await)
import qualified Pipes.Prelude as Pipes (mapM_, take)
import System.Random.MWC.Probability (Gen, Prob)
import qualified System.Random.MWC.Probability as MWC
-- | Build a proposal by perturbing every coordinate of a position
--   independently.
--
--   Each coordinate is redrawn with 'MWC.normal', passing the current
--   coordinate as the first argument (the mean) and @radial@ as the second
--   (the standard deviation), so proposals are centred on the current
--   position.
propose
  :: (PrimMonad m, Traversable f)
  => Double              -- ^ spread handed to 'MWC.normal' for every coordinate
  -> f Double            -- ^ current position
  -> Prob m (f Double)   -- ^ randomly perturbed position
propose radial position = traverse jitter position
  where
    -- Redraw a single coordinate around its current value.
    jitter coordinate = MWC.normal coordinate radial
-- | A single Metropolis transition over the current 'Chain'.
--
--   A proposal is generated by 'propose' around the current position and
--   scored with the chain's target.  It is accepted with probability
--   @exp (min 0 (proposalScore - chainScore))@; on acceptance the chain state
--   is replaced by the proposal and its score, otherwise the state is left
--   untouched.
metropolis
  :: (Traversable f, PrimMonad m)
  => Double                              -- ^ spread passed on to 'propose'
  -> Transition m (Chain (f Double) b)
metropolis radial = do
  Chain {..} <- get
  proposal <- lift (propose radial chainPosition)
  let proposalScore = lTarget chainTarget proposal
      -- The acceptance ratio is the exponential of the score *difference*,
      -- capped at 1 by 'min 0'.  The difference can be NaN (for instance when
      -- both scores are the same infinity), in which case the proposal is
      -- rejected outright by mapping the probability to 0.
      acceptProb    = whenNaN 0 (exp (min 0 (proposalScore - chainScore)))
  accept <- lift (MWC.bernoulli acceptProb)
  -- Only an accepted proposal changes the state; the target and the tunables
  -- are carried over unchanged.
  when accept (put (Chain chainTarget proposalScore proposal chainTunables))
-- | An endless stream of chain states.
--
--   Starting from the supplied chain, a 'metropolis' transition is run
--   against the generator, the resulting state is yielded downstream, and the
--   process repeats from that state.  The starting chain itself is never
--   yielded, and the producer never returns on its own.
drive
  :: (Traversable f, PrimMonad m)
  => Double                              -- ^ spread passed on to 'metropolis'
  -> Chain (f Double) b                  -- ^ starting chain
  -> Gen (PrimState m)                   -- ^ generator used for every step
  -> Producer (Chain (f Double) b) m c
drive radial = go
  where
    -- Advance one step, emit the result, then continue from it with the
    -- same generator.
    go current gen =
      lift (MWC.sample (execStateT (metropolis radial) current) gen)
        >>= \updated -> yield updated >> go updated gen
-- | Run a Metropolis chain and return its trace as a list.
--
--   The starting chain is built from the given position and target (with no
--   tunables), fed to 'drive', and the first @n@ states that come out are
--   collected.  Because 'drive' only yields states produced by a transition,
--   the starting state itself is not part of the result.
chain
  :: (PrimMonad m, Traversable f)
  => Int                      -- ^ number of states to collect
  -> Double                   -- ^ spread passed on to 'drive'
  -> f Double                 -- ^ starting position
  -> (f Double -> Double)     -- ^ target function used to score positions
  -> Gen (PrimState m)        -- ^ generator
  -> m [Chain (f Double) b]
chain n radial position target gen =
    runEffect (drive radial start gen >-> gather n)
  where
    -- The target function wrapped for the chain; the second field is left
    -- empty.
    wrapped = Target target Nothing

    start = Chain
      { chainTarget   = wrapped
      , chainScore    = lTarget wrapped position
      , chainPosition = position
      , chainTunables = Nothing
      }

    -- Await exactly @count@ upstream values and return them in order.
    gather :: Monad m => Int -> Consumer a m [a]
    gather count = lowerCodensity (replicateM count (lift Pipes.await))
-- | Run a Metropolis chain and print its trace.
--
--   The starting chain is assembled from the given position and target (with
--   no tunables) and handed to 'drive'; the first @n@ states it yields are
--   each printed with 'print'.  Nothing is returned.
mcmc
  :: (MonadIO m, PrimMonad m, Traversable f, Show (f Double))
  => Int                      -- ^ number of states to print
  -> Double                   -- ^ spread passed on to 'drive'
  -> f Double                 -- ^ starting position
  -> (f Double -> Double)     -- ^ target function used to score positions
  -> Gen (PrimState m)        -- ^ generator
  -> m ()
mcmc n radial chainPosition target gen =
    runEffect (drive radial origin gen >-> Pipes.take n >-> Pipes.mapM_ emit)
  where
    -- The bindings below are named after the 'Chain' fields so that the
    -- wildcard construction picks them up, together with the
    -- @chainPosition@ argument.
    chainTarget   = Target target Nothing
    chainScore    = lTarget chainTarget chainPosition
    chainTunables = Nothing

    origin = Chain {..}

    -- Print one state in the base monad.
    emit state = liftIO (print state)
-- | Substitute a fallback for a NaN.
--
--   Returns @val@ when @x@ is NaN and @x@ itself otherwise (infinities
--   included).
whenNaN :: RealFloat a => a -> a -> a
whenNaN val x = if isNaN x then val else x