{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
module Numeric.MCMC (
concatT
, concatAllT
, sampleT
, sampleAllT
, bernoulliT
, frequency
, anneal
, mcmc
, chain
, module Data.Sampling.Types
, metropolis
, hamiltonian
, slice
, MWC.create
, MWC.createSystemRandom
, MWC.withSystemRandom
, MWC.asGenIO
, PrimMonad
, PrimState
, RealWorld
) where
import Control.Monad (replicateM)
import Control.Monad.Codensity (lowerCodensity)
import Control.Monad.Primitive (PrimMonad, PrimState, RealWorld)
import Control.Monad.Trans.State.Strict (execStateT)
import Data.Sampling.Types
import Numeric.MCMC.Anneal
import qualified Numeric.MCMC.Metropolis as M (metropolis)
import Numeric.MCMC.Hamiltonian (hamiltonian)
import Numeric.MCMC.Slice (slice)
import Pipes hiding (next)
import qualified Pipes.Prelude as Pipes
import System.Random.MWC.Probability (Gen)
import qualified System.Random.MWC.Probability as MWC
-- | Concatenate two transition operators: run the first, then run the
--   second starting from the state the first one leaves behind.
concatT :: Monad m => Transition m a -> Transition m a -> Transition m a
concatT = (>>)
-- | Concatenate a list of transition operators, running them in list order,
--   each starting from the state left by the previous one.
--
--   An empty list yields the identity transition (the state is left
--   unchanged) rather than crashing, as a partial @foldl1@ would.
concatAllT :: Monad m => [Transition m a] -> Transition m a
concatAllT = sequence_
-- | Probabilistically concatenate two transition operators: each time the
--   result is run, exactly one of the two is run, each with probability 1/2.
sampleT :: PrimMonad m => Transition m a -> Transition m a -> Transition m a
sampleT = bernoulliT 0.5
-- | Flip a coin that lands heads with probability @p@ (drawn with
--   'MWC.bernoulli'), then run the first transition operator on heads and
--   the second one otherwise.
bernoulliT
  :: PrimMonad m
  => Double          -- ^ probability of running the first operator
  -> Transition m a  -- ^ operator run when the coin lands heads
  -> Transition m a  -- ^ operator run otherwise
  -> Transition m a
bernoulliT p t0 t1 = do
  heads <- lift (MWC.bernoulli p)
  if heads then t0 else t1
-- | Probabilistically concatenate a list of transition operators: each time
--   the result is run, exactly one operator is chosen uniformly at random
--   and run.
--
--   Calls 'error' with a descriptive message on an empty list, since there
--   is nothing to choose from (instead of sampling the empty index range and
--   failing inside '!!').
sampleAllT :: PrimMonad m => [Transition m a] -> Transition m a
sampleAllT [] = error "sampleAllT: no transition operators specified"
sampleAllT ts = do
  -- j is drawn from [0, length ts - 1], so the index below is in range.
  j <- lift (MWC.uniformR (0, length ts - 1))
  ts !! j
-- | Choose a transition operator at random, with probability proportional
--   to the integer weight paired with it, and run it.
--
--   A value @n@ is drawn uniformly from @[1, total weight]@ and the first
--   operator whose cumulative weight reaches @n@ is run, so entries with
--   weight 0 are never chosen. Weights are assumed to be non-negative.
--
--   Calls 'error' if the weights sum to zero or less (including the empty
--   list), because no valid sampling range exists in that case.
frequency :: PrimMonad m => [(Int, Transition m a)] -> Transition m a
frequency xs
  | tot <= 0  = error "frequency: no distribution specified"
  | otherwise = lift (MWC.uniformR (1, tot)) >>= (`pick` xs)
  where
    tot :: Int
    tot = sum (map fst xs)

    -- Subtract each weight from the draw until it falls within the current
    -- entry's weight. With tot > 0 and n <= tot this always matches before
    -- the list runs out; the second clause only keeps the function total.
    pick n ((k, v):vs)
      | n <= k    = v
      | otherwise = pick (n - k) vs
    pick _ []     = error "frequency: no distribution specified"
-- | Run a Markov chain for @n@ steps from the given initial position and
--   print each resulting chain state to stdout (via 'print').
--
--   The initial state is scored with 'lTarget' and starts with no tunables.
--   It is not printed itself; output begins with the state produced by the
--   first transition.
mcmc
  :: (MonadIO m, PrimMonad m, Show (t a))
  => Int                           -- ^ number of states to print
  -> t a                           -- ^ initial position
  -> Transition m (Chain (t a) b)  -- ^ transition operator
  -> Target (t a)                  -- ^ target used to score positions
  -> Gen (PrimState m)             -- ^ random number generator
  -> m ()
mcmc n chainPosition transition chainTarget gen = runEffect $
      drive transition Chain {..} gen
  >-> Pipes.take n
  >-> Pipes.mapM_ (liftIO . print)
  where
    -- Remaining fields for the RecordWildCards construction above.
    chainScore    = lTarget chainTarget chainPosition
    chainTunables = Nothing
-- | Run a Markov chain for @n@ steps from the given initial position and
--   return the @n@ resulting chain states in the order they were produced.
--
--   The initial state is scored with 'lTarget' and starts with no tunables.
--   It is not included in the result.
chain
  :: (MonadIO m, PrimMonad m)
  => Int                           -- ^ number of states to collect
  -> t a                           -- ^ initial position
  -> Transition m (Chain (t a) b)  -- ^ transition operator
  -> Target (t a)                  -- ^ target used to score positions
  -> Gen (PrimState m)             -- ^ random number generator
  -> m [Chain (t a) b]
chain n chainPosition transition chainTarget gen = runEffect $
      drive transition Chain {..} gen
  >-> collect n
  where
    -- Remaining fields for the RecordWildCards construction above.
    chainScore    = lTarget chainTarget chainPosition
    chainTunables = Nothing
-- | Await exactly @size@ values from upstream and return them in the order
--   they were received (an empty list when @size <= 0@).
--
--   The 'replicateM' runs in 'Codensity' and is lowered afterwards,
--   presumably so that the repeated binds associate to the right instead of
--   nesting leftwards inside 'Proxy' — NOTE(review): confirm intent.
collect :: Monad m => Int -> Consumer a m [a]
collect size = lowerCodensity (replicateM size (lift Pipes.await))
-- | Turn a transition operator into an infinite stream of chain states.
--
--   Each step runs the transition on the current state (by sampling
--   'execStateT' with the same generator), yields the new state downstream
--   and continues from it. The producer never returns, which is why its
--   return type is fully polymorphic.
drive
  :: PrimMonad m
  => Transition m b      -- ^ transition operator
  -> b                   -- ^ starting state (not itself yielded)
  -> Gen (PrimState m)   -- ^ random number generator, reused every step
  -> Producer b m a
drive transition = loop
  where
    loop state prng = do
      next <- lift (MWC.sample (execStateT transition state) prng)
      yield next
      loop next prng
-- | The Metropolis transition operator from "Numeric.MCMC.Metropolis",
--   called with 'Nothing' for its optional @f Double -> b@ argument.
--
--   @radial@ is passed through unchanged (presumably the proposal radius —
--   NOTE(review): confirm against "Numeric.MCMC.Metropolis").
metropolis
  :: (Traversable f, PrimMonad m)
  => Double
  -> Transition m (Chain (f Double) b)
metropolis radial = M.metropolis radial Nothing