module Numeric.MCMC.Slice (
mcmc
, slice
, MWC.create
, MWC.createSystemRandom
, MWC.withSystemRandom
, MWC.asGenIO
) where
import Control.Monad.Trans.State.Strict (put, get, execStateT)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Control.Lens hiding (index)
import GHC.Prim (RealWorld)
import Data.Maybe (fromMaybe)
import Data.Sampling.Types
import Pipes hiding (next)
import qualified Pipes.Prelude as Pipes
import System.Random.MWC.Probability (Prob, Gen, Variate)
import qualified System.Random.MWC.Probability as MWC
-- | Run a slice-sampling chain for @n@ iterations starting at
-- @chainPosition@, printing every visited chain state with 'print'.
--
-- @radial@ is the step width used when stepping out to find each
-- coordinate's bracket. @target@ is installed as the chain's 'lTarget'
-- (the sampler treats it as a log density) with no gradient. The chain
-- starts with no tunables, and its initial score is @target@ evaluated at
-- @chainPosition@.
mcmc
  :: (Show (t a), FoldableWithIndex (Index (t a)) t, Ixed (t a),
      Num (IxValue (t a)), Variate (IxValue (t a)))
  => Int
  -> IxValue (t a)
  -> t a
  -> (t a -> Double)
  -> Gen RealWorld
  -> IO ()
mcmc n radial chainPosition target gen =
    runEffect (chain radial initial gen >-> Pipes.take n >-> Pipes.mapM_ print)
  where
    chainTarget   = Target target Nothing
    chainScore    = lTarget chainTarget chainPosition
    chainTunables = Nothing
    -- Built from the bindings above plus the chainPosition argument.
    initial       = Chain {..}
-- | An endless stream of chain states. Starting from @start@, each state is
-- produced by running one 'slice' sweep (step width @radial@) on the
-- previous state, using generator @prng@. The starting state itself is not
-- yielded.
chain
  :: (PrimMonad m, FoldableWithIndex (Index (t a)) t, Ixed (t a),
      Num (IxValue (t a)), Variate (IxValue (t a)))
  => IxValue (t a)
  -> Chain (t a) b
  -> Gen (PrimState m)
  -> Producer (Chain (t a) b) m ()
chain radial start prng = go start
  where
    go current = do
      nextState <- lift (MWC.sample (execStateT (slice radial) current) prng)
      yield nextState
      go nextState
-- | One sweep of coordinate-wise slice sampling.
--
-- Each coordinate of the chain's position is updated in turn, in the order
-- visited by 'ifor_'. For each coordinate:
--
--   * a slice height is drawn below the current log target;
--   * a bracket is found by stepping out with width @step@ ('findBracket');
--   * a new coordinate value is drawn uniformly from that bracket until it
--     lands inside the slice ('rejection').
--
-- The chain's score is recomputed after every coordinate update.
slice
  :: (PrimMonad m, FoldableWithIndex (Index (t a)) t, Ixed (t a),
      Num (IxValue (t a)), Variate (IxValue (t a)))
  => IxValue (t a)
  -> Transition m (Chain (t a) b)
slice step = do
  Chain _ _ position _ <- get
  ifor_ position $ \index _ -> do
    Chain {..} <- get
    -- Slice height on the log scale: log (u * exp score) = score + log u.
    -- mwc-random draws Doubles from (0, 1], so log u is finite and <= 0.
    -- Staying in log space avoids 'exp' underflowing to 0 for very negative
    -- log densities. That would make the height -Infinity, and findBracket
    -- would never terminate. It also avoids 'exp' overflowing to Infinity
    -- for large ones. A single uniform draw is consumed, as before.
    u <- lift MWC.uniform
    let height  = lTarget chainTarget chainPosition + log u
        bracket =
          findBracket (lTarget chainTarget) index step height chainPosition
    perturbed <- lift $
      rejection (lTarget chainTarget) index bracket height chainPosition
    let perturbedScore = lTarget chainTarget perturbed
    put (Chain chainTarget perturbedScore perturbed chainTunables)
-- | Step out from the point @xs@ along coordinate @index@ until both ends of
-- the bracket fall below the slice @height@. Returns the coordinate values
-- @(left, right)@ of the two ends.
--
-- Both ends start at @xs@. On each iteration, every end whose @target@
-- value is still not below @height@ is moved outward by the fixed @step@.
-- Ends already below @height@ stay put.
--
-- The loop does not terminate if one side never drops below @height@, for
-- example a target that is flat in that direction.
--
-- NOTE(review): the initial interval is not randomly positioned around the
-- current point, unlike Neal's (2003) stepping-out procedure. Confirm that
-- the resulting kernel leaves the target invariant.
findBracket
  :: (Ord a, Ixed s, Num (IxValue s))
  => (s -> a)   -- ^ function compared against the height (the log target in 'slice')
  -> Index s    -- ^ coordinate being bracketed
  -> IxValue s  -- ^ fixed amount each end moves per expansion
  -> a          -- ^ slice height
  -> s          -- ^ starting point
  -> (IxValue s, IxValue s)
findBracket target index step height xs = go xs xs where
  err = error "findBracket: invalid index -- please report this as a bug!"
  -- Each end's target is evaluated once per iteration.
  go !bl !br = case (target bl < height, target br < height) of
    (True, True) ->
      (fromMaybe err (bl ^? ix index), fromMaybe err (br ^? ix index))
    (True, False) ->
      go bl (expandBracketRight index step br)
    (False, True) ->
      go (expandBracketLeft index step bl) br
    (False, False) ->
      go (expandBracketLeft index step bl) (expandBracketRight index step br)
-- | Move coordinate @index@ of a point to the left by subtracting @step@
-- from it. Only that coordinate is touched.
--
-- The original passed the unit value @()@ here instead of @(-)@, which does
-- not type-check as the required @IxValue s -> IxValue s -> IxValue s@.
expandBracketLeft
  :: (Ixed s, Num (IxValue s))
  => Index s
  -> IxValue s
  -> s
  -> s
expandBracketLeft = expandBracketBy (-)
-- | Move coordinate @index@ of a point to the right by adding @step@ to it.
-- Only that coordinate is touched.
expandBracketRight
  :: (Ixed s, Num (IxValue s))
  => Index s
  -> IxValue s
  -> s
  -> s
expandBracketRight index step xs = expandBracketBy (+) index step xs
-- | Replace coordinate @index@ of @xs@ with @f coord step@, where @coord@ is
-- its current value.
--
-- The update goes through the 'ix' traversal. If @index@ is not present in
-- @xs@, the structure is returned unchanged.
expandBracketBy
  :: Ixed s
  => (IxValue s -> t -> IxValue s)
  -> Index s
  -> t
  -> s
  -> s
expandBracketBy f index step = over (ix index) (\coord -> f coord step)
-- | Rejection-sample a new value for coordinate @dimension@.
--
-- Values are drawn uniformly from @bracket@ and written into the point.
-- Drawing repeats while the resulting point's @target@ value is below
-- @height@; the first point not below @height@ is returned. Coordinates
-- other than @dimension@ are left as they were.
rejection
  :: (Ord a, PrimMonad m, Ixed b, Variate (IxValue b))
  => (b -> a)
  -> Index b
  -> (IxValue b, IxValue b)
  -> a
  -> b
  -> Prob m b
rejection target dimension bracket height = attempt
  where
    attempt current = do
      coordinate <- MWC.uniformR bracket
      let candidate = set (ix dimension) coordinate current
      if target candidate < height
        then attempt candidate
        else pure candidate