{-# OPTIONS_GHC -fno-warn-type-defaults #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
-- | Ensemble MCMC: an ensemble of particles is evolved by repeatedly applying
-- the 'flat' transition, and 'mcmc' runs that transition a given number of
-- times while printing every resulting ensemble.
--
-- The export list exposes the 'mcmc' driver, the 'flat' transition and the
-- 'Particle', 'Ensemble' and 'Chain' types ('Chain' without its constructor
-- or fields). It also re-exports the "Data.Sampling.Types" definitions
-- imported here, the generator helpers from "System.Random.MWC.Probability",
-- and the @ensemble@ / @particle@ helpers from "Data.Vector.Extended".
module Numeric.MCMC.Flat (
mcmc
, flat
, Particle
, Ensemble
, Chain
, module Sampling.Types
, MWC.create
, MWC.createSystemRandom
, MWC.withSystemRandom
, MWC.asGenIO
, VE.ensemble
, VE.particle
) where
import Control.Monad (replicateM)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Par (NFData)
import Control.Monad.Par.Combinator (parMap)
import Control.Monad.Par.Scheds.Sparks hiding (get)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Control.Monad.Trans.State.Strict (get, put, execStateT)
import Data.Sampling.Types as Sampling.Types hiding (Chain(..))
import qualified Data.Text as T
import qualified Data.Text.IO as T (putStrLn)
import Data.Vector (Vector)
import qualified Data.Vector as V
import qualified Data.Vector.Extended as VE (ensemble, particle)
import qualified Data.Vector.Unboxed as U
import Formatting ((%))
import qualified Formatting as F
import Pipes (Producer, lift, yield, runEffect, (>->))
import qualified Pipes.Prelude as Pipes
import System.Random.MWC.Probability as MWC
-- | Sampler state threaded through the 'flat' transition: the target being
-- sampled together with the current ensemble of particles.
data Chain = Chain {
-- | Target whose 'lTarget' is evaluated at particles when proposals are
-- scored.
chainTarget :: Target Particle
-- | Current ensemble. The field is strict, so it is forced whenever a
-- 'Chain' is constructed.
, chainPosition :: !Ensemble
}
-- | Render the current ensemble of a 'Chain' as text, one particle per line.
--
-- Only the 'chainPosition' field is rendered; the target is ignored.
render :: Chain -> T.Text
render = renderEnsemble . chainPosition
{-# INLINE render #-}
-- | Render a particle as its coordinates separated by commas.
--
-- Every coordinate is formatted with 'F.float'. An empty particle renders as
-- the empty text, and no leading or trailing comma is produced.
--
-- The pieces are joined with 'T.intercalate' rather than by appending to an
-- accumulator inside a fold: the output is the same, but the strict 'T.Text'
-- accumulator is not re-copied for every coordinate.
renderParticle :: Particle -> T.Text
renderParticle = T.intercalate "," . map (F.sformat F.float) . U.toList
{-# INLINE renderParticle #-}
-- | Render an ensemble as text with one particle per line.
--
-- Particles are rendered with 'renderParticle' and separated by a single
-- newline; there is no trailing newline, and an empty ensemble renders as the
-- empty text.
--
-- The lines are joined with 'T.intercalate' rather than by appending to an
-- accumulator inside a fold: the output is the same, but the strict 'T.Text'
-- accumulator is not re-copied for every particle.
renderEnsemble :: Ensemble -> T.Text
renderEnsemble = T.intercalate "\n" . map renderParticle . V.toList
{-# INLINE renderEnsemble #-}
-- | A single member of an ensemble: an unboxed vector of 'Double'
-- coordinates.
type Particle = U.Vector Double
-- | A collection of particles, stored as a boxed vector of 'Particle's.
type Ensemble = Vector Particle
-- | Draw the random scale factor used by 'stretch'.
--
-- A variate @u@ is drawn with 'uniform' and mapped to @0.5 * (u + 1) ^ 2@.
-- Assuming 'uniform' yields values in the unit interval, the result lies
-- between @0.5@ and @2@.
symmetric :: PrimMonad m => Prob m Double
symmetric = do
  u <- uniform
  let shifted = u + 1
  return (0.5 * shifted ^ (2 :: Int))
{-# INLINE symmetric #-}
-- | Build a proposal on the line through two particles.
--
-- Each coordinate of the result is @z * x + (1 - z) * y@, where @x@ comes
-- from the first particle and @y@ from the second. Like any 'U.zipWith', the
-- result is as long as the shorter of the two inputs.
stretch :: Particle -> Particle -> Double -> Particle
stretch p0 p1 z = U.zipWith (\x y -> z * x + (1 - z) * y) p0 p1
{-# INLINE stretch #-}
-- | Logarithm of the (unclamped) acceptance ratio for a proposal.
--
-- The value is the difference of the target's 'lTarget' at the proposal and
-- at the current particle, plus @log z@ scaled by the particle's length
-- minus one. The caller is responsible for exponentiating and clamping.
acceptProb :: Target Particle -> Particle -> Particle -> Double -> Double
acceptProb target particle proposal z = logRatio + scaling
  where
    logDensity = lTarget target
    logRatio   = logDensity proposal - logDensity particle
    scaling    = log z * (fromIntegral (U.length particle) - 1)
{-# INLINE acceptProb #-}
-- | Attempt to move one particle.
--
-- A proposal is built by 'stretch'ing the current particle @p0@ against the
-- partner @p1@ with scale @z@. It is accepted when @zc@ does not exceed
-- @min 1 (exp a)@, where @a@ is the 'acceptProb' of the proposal; otherwise
-- @p0@ is returned unchanged.
--
-- Both @p0@ and the proposal are forced before the comparison is made.
move :: Target Particle -> Particle -> Particle -> Double -> Double -> Particle
move target !p0 p1 z zc
  | accepted  = proposal
  | otherwise = p0
  where
    !proposal = stretch p0 p1 z
    threshold = min 1 (exp (acceptProb target p0 proposal z))
    accepted  = zc <= threshold
{-# INLINE move #-}
-- | Move the first @n@ particles of one sub-ensemble, each against a partner
-- drawn from a complementary sub-ensemble.
--
-- Arguments, in order: the target, the sub-ensemble whose particles are
-- moved, the complementary sub-ensemble partners are drawn from, and the
-- number @n@ of particles to move.
--
-- For every particle a scale is drawn with 'symmetric', an acceptance
-- variate with 'uniform', and a partner index with 'uniformR' over the
-- whole complement; the moves themselves are pure and are evaluated in
-- parallel with 'parMapChunk'.
--
-- Edge cases:
--
-- * @n@ is clamped to the length of the first sub-ensemble, so the unchecked
--   indexing below cannot run past its end.
--
-- * If there is nothing to move, or the complement is empty (so no partner
--   can be drawn), the particles are returned unchanged and no random
--   numbers are consumed.
--
-- * The chunk size handed to 'parMapChunk' is kept at least 1. With a
--   single particle the halved count would otherwise be 0, a chunk size
--   that makes the chunked map fail to terminate.
execute
  :: PrimMonad m
  => Target Particle
  -> Ensemble
  -> Ensemble
  -> Int
  -> Prob m Ensemble
execute target e0 e1 n
  | walkers <= 0 || pool <= 0 = return $! V.take walkers e0
  | otherwise = do
      zs  <- replicateM walkers symmetric
      zcs <- replicateM walkers uniform
      -- Partner indices are 1-based and range over the whole complement.
      js  <- U.replicateM walkers (uniformR (1, pool))
      let granularity = max 1 (walkers `div` 2)
          w0 k = e0 `V.unsafeIndex` pred k
          w1 k = e1 `V.unsafeIndex` pred (js `U.unsafeIndex` pred k)
          worker (k, z, zc) = move target (w0 k) (w1 k) z zc
          !result = runPar $
            parMapChunk granularity worker (zip3 [1..walkers] zs zcs)
      return $! V.fromList result
  where
    walkers = min n (V.length e0)
    pool    = V.length e1
{-# INLINE execute #-}

-- | One full transition of the ensemble.
--
-- The ensemble is split into two sub-ensembles at @size `div` 2@. The first
-- is moved against the second, then the second is moved against the
-- already-updated first, and the two results are concatenated (in that
-- order) to form the new ensemble.
--
-- The split keeps every particle: for an odd-sized ensemble the second
-- sub-ensemble is one particle larger than the first, so the ensemble size
-- is preserved across transitions. An ensemble with fewer than two
-- particles has no partner to move against and is left unchanged.
flat
  :: PrimMonad m
  => Transition m Chain
flat = do
  Chain {..} <- get
  let size     = V.length chainPosition
      n        = size `div` 2
      (e0, e1) = V.splitAt n chainPosition
  result0 <- lift (execute chainTarget e0 e1 (V.length e0))
  result1 <- lift (execute chainTarget e1 result0 (V.length e1))
  let !ensemble = V.concat [result0, result1]
  put $! Chain chainTarget ensemble
{-# INLINE flat #-}
-- | An endless stream of sampler states.
--
-- Starting from the given state, the 'flat' transition is run repeatedly
-- with the supplied generator, and the state reached after each transition
-- is yielded. The initial state itself is never yielded.
chain :: PrimMonad m => Chain -> Gen (PrimState m) -> Producer Chain m ()
chain start prng = go start
  where
    -- Run a single 'flat' transition from the given state.
    step current = MWC.sample (execStateT flat current) prng

    go current = do
      updated <- lift (step current)
      yield updated
      go updated
{-# INLINE chain #-}
-- | Run the sampler and print every ensemble it visits.
--
-- Arguments, in order: the number @n@ of transitions to run, the initial
-- ensemble, the function evaluated through 'lTarget' when proposals are
-- scored, and the generator to draw from.
--
-- The function is wrapped in a 'Target' whose second field is 'Nothing'.
-- After each of the @n@ transitions the resulting ensemble is rendered with
-- 'render' and written to standard output; the initial ensemble is not
-- printed.
mcmc
  :: (MonadIO m, PrimMonad m)
  => Int
  -> Ensemble
  -> (Particle -> Double)
  -> Gen (PrimState m)
  -> m ()
mcmc n chainPosition target gen =
    runEffect (chain initial gen >-> Pipes.take n >-> Pipes.mapM_ emit)
  where
    initial = Chain (Target target Nothing) chainPosition
    emit st = liftIO (T.putStrLn (render st))
{-# INLINE mcmc #-}
-- | Map a function over a list in parallel, one task per chunk.
--
-- The list is cut into consecutive chunks of at most @n@ elements, each
-- chunk is mapped with @f@ as a single 'parMap' task, and the per-chunk
-- results are concatenated in their original order.
--
-- A chunk size below 1 is treated as 1. Splitting off zero elements would
-- never consume the input, so without the clamp the chunk list would be
-- infinite and the computation would not terminate.
parMapChunk :: NFData b => Int -> (a -> b) -> [a] -> Par [b]
parMapChunk n f xs = concat <$> parMap (map f) (chunk xs)
  where
    size = max 1 n

    chunk [] = []
    chunk ys =
      let (as, bs) = splitAt size ys
      in  as : chunk bs
{-# INLINE parMapChunk #-}