{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE TypeFamilies #-}
module Numeric.MCMC.Slice (
mcmc
, chain
, slice
, MWC.create
, MWC.createSystemRandom
, MWC.withSystemRandom
, MWC.asGenIO
) where
import Control.Monad (replicateM)
import Control.Monad.Codensity (lowerCodensity)
import Control.Monad.Trans.State.Strict (put, get, execStateT)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Control.Lens hiding (index)
import Data.Maybe (fromMaybe)
import Data.Sampling.Types
import Pipes hiding (next)
import qualified Pipes.Prelude as Pipes
import System.Random.MWC.Probability (Prob, Gen, Variate)
import qualified System.Random.MWC.Probability as MWC
-- | Run the coordinate-wise slice sampler from @position@ for @n@ steps,
--   printing every chain state to stdout as it is produced.
--
--   * @n@        – number of states to emit (via 'Pipes.take').
--   * @radial@   – step size handed to the sampler; the bracket around the
--                  current coordinate is widened by this amount per expansion.
--   * @position@ – starting point.
--   * @target@   – log density of the distribution to sample.
--   * @gen@      – PRNG used for every random draw.
mcmc
  :: (MonadIO m, PrimMonad m,
      Show (t a), FoldableWithIndex (Index (t a)) t, Ixed (t a),
      Num (IxValue (t a)), Variate (IxValue (t a)))
  => Int
  -> IxValue (t a)
  -> t a
  -> (t a -> Double)
  -> Gen (PrimState m)
  -> m ()
mcmc n radial position target gen =
    runEffect (drive radial origin gen >-> Pipes.take n >-> Pipes.mapM_ report)
  where
    -- Each emitted state is rendered with its Show instance.
    report = liftIO . print

    -- No gradient is supplied to the Target.
    model = Target target Nothing

    -- Initial state: score is the log density at the starting point,
    -- and no tunables are attached.
    origin = Chain model (lTarget model position) position Nothing
-- | Run the slice sampler from @position@ and return the first @n@ states
--   it produces, in order, as a list.
--
--   Arguments are the same as for 'mcmc': step size @radial@, start
--   @position@, log density @target@, and the PRNG @gen@.
chain
  :: (PrimMonad m, FoldableWithIndex (Index (f a)) f, Ixed (f a)
     , Variate (IxValue (f a)), Num (IxValue (f a)))
  => Int
  -> IxValue (f a)
  -> f a
  -> (f a -> Double)
  -> Gen (PrimState m)
  -> m [Chain (f a) b]
chain n radial position target gen =
    runEffect (drive radial origin gen >-> collect n)
  where
    ctarget = Target target Nothing

    -- Start state: log density at the initial point, no tunables.
    origin = Chain ctarget (lTarget ctarget position) position Nothing

    -- Await exactly @size@ values and return them as a list. The awaits are
    -- sequenced inside Codensity and lowered back to a Consumer (presumably
    -- to keep replicateM's binds right-associated — NOTE(review): confirm).
    collect :: Monad m => Int -> Consumer a m [a]
    collect size = lowerCodensity (replicateM size (lift Pipes.await))
-- | Infinite producer of chain states: repeatedly applies one 'slice' sweep
--   (with step size @radial@) to the current state, yields the result, and
--   continues from it. The same PRNG is used for every sweep. Never returns.
drive
  :: (PrimMonad m, FoldableWithIndex (Index (t a)) t, Ixed (t a),
      Num (IxValue (t a)), Variate (IxValue (t a)))
  => IxValue (t a)
  -> Chain (t a) b
  -> Gen (PrimState m)
  -> Producer (Chain (t a) b) m c
drive radial initial prng = go initial
  where
    go current = do
      successor <- lift (MWC.sample (execStateT (slice radial) current) prng)
      yield successor
      go successor
-- | One sweep of coordinate-wise slice sampling over the chain's position.
--
--   For every index of the position, in 'ifor_' order:
--
--   1. draw a log-scale slice height below the current log density;
--   2. find a bracket along that coordinate whose ends fall below the height
--      ('findBracket', widened by @step@);
--   3. draw the coordinate uniformly inside the bracket until the log
--      density is at least the height ('rejection');
--   4. store the new position and its log density as the chain's score.
--
--   Target and tunables are carried through unchanged.
slice
  :: (PrimMonad m, FoldableWithIndex (Index (t a)) t, Ixed (t a),
      Num (IxValue (t a)), Variate (IxValue (t a)))
  => IxValue (t a)
  -> Transition m (Chain (t a) b)
slice step = do
  Chain _ _ position _ <- get
  ifor_ position $ \index _ -> do
    Chain {..} <- get
    let logDensity = lTarget chainTarget chainPosition
    -- The slice height is log (u * exp logDensity). Compute it as
    -- logDensity + log u so that exp is never formed.
    --
    -- exp underflows to 0 for logDensity < ~-745. That made the height
    -- -Infinity, and findBracket then expanded forever.
    --
    -- exp overflows to +Infinity for logDensity > ~709. That made the
    -- height +Infinity, and rejection then never accepted.
    --
    -- NOTE(review): relies on mwc-random's uniformR (0, 1) never returning
    -- exactly 0 (its documented Double range is (0,1]); confirm for the
    -- pinned version.
    u <- lift (MWC.uniformR (0, 1))
    let height  = logDensity + log u
        bracket = findBracket (lTarget chainTarget) index step height chainPosition
    perturbed <- lift $
      rejection (lTarget chainTarget) index bracket height chainPosition
    let perturbedScore = lTarget chainTarget perturbed
    put (Chain chainTarget perturbedScore perturbed chainTunables)
-- | Find an interval along coordinate @index@ whose two end points both have
--   @target < height@.
--
--   The search starts from the degenerate interval @[xs, xs]@. On each
--   iteration, every end whose target is not yet below @height@ is moved
--   outward by @step@. The result is the @(left, right)@ coordinate values.
--
--   Notes:
--
--   * There is no cap on the number of expansions. If the target never
--     drops below @height@ on some side, this does not terminate.
--   * An invalid @index@ raises the error below.
--   * NOTE(review): the interval is anchored at the current point rather than
--     randomly positioned as in Neal (2003). That looks fine for unimodal
--     slices; confirm it is intended for multimodal targets.
findBracket
  :: (Ord a, Ixed s, Num (IxValue s))
  => (s -> a)
  -> Index s
  -> IxValue s
  -> a
  -> s
  -> (IxValue s, IxValue s)
findBracket target index step height xs = go xs xs
  where
    err = error "findBracket: invalid index -- please report this as a bug!"

    coordOf ys = fromMaybe err (ys ^? ix index)

    -- The former doubling accumulator was never used, so it has been dropped;
    -- expansion always moves by @step@.
    go !bl !br
      | leftScore < height && rightScore < height =
          (coordOf bl, coordOf br)
      | leftScore < height && rightScore >= height =
          go bl (expandBracketRight index step br)
      | leftScore >= height && rightScore < height =
          go (expandBracketLeft index step bl) br
      | otherwise =
          go (expandBracketLeft index step bl) (expandBracketRight index step br)
      where
        -- Evaluate the (possibly expensive) target once per end per
        -- iteration. The guards above are otherwise unchanged, so
        -- NaN scores take the same branches as before.
        leftScore  = target bl
        rightScore = target br
-- | Move coordinate @index@ of a point down by @step@
--   (new value = old value - step).
expandBracketLeft
  :: (Ixed s, Num (IxValue s))
  => Index s
  -> IxValue s
  -> s
  -> s
expandBracketLeft = expandBracketBy (\coordinate offset -> coordinate - offset)
-- | Move coordinate @index@ of a point up by @step@
--   (new value = old value + step).
expandBracketRight
  :: (Ixed s, Num (IxValue s))
  => Index s
  -> IxValue s
  -> s
  -> s
expandBracketRight = expandBracketBy (\coordinate offset -> coordinate + offset)
-- | Replace coordinate @index@ of @xs@ with @old `f` step@.
--
--   All other coordinates are left unchanged. If @index@ is not present,
--   @xs@ is returned unchanged, since 'ix' targets nothing.
expandBracketBy
  :: Ixed s
  => (IxValue s -> t -> IxValue s)
  -> Index s
  -> t
  -> s
  -> s
expandBracketBy f index step = over (ix index) (\coordinate -> f coordinate step)
-- | Resample coordinate @dimension@ uniformly from @bracket@ until the
--   resulting point has @target >= height@, then return that point.
--
--   Only the chosen coordinate is changed; the other coordinates keep their
--   values. Each rejected draw is discarded and a fresh one is taken. There
--   is no attempt limit, so this does not terminate if no point in the
--   bracket reaches @height@.
rejection
  :: (Ord a, PrimMonad m, Ixed b, Variate (IxValue b))
  => (b -> a)
  -> Index b
  -> (IxValue b, IxValue b)
  -> a
  -> b
  -> Prob m b
rejection target dimension bracket height = propose
  where
    propose current = do
      coordinate <- MWC.uniformR bracket
      settle (set (ix dimension) coordinate current)

    settle candidate
      | target candidate < height = propose candidate
      | otherwise                 = return candidate