{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}

-- |
-- Module: Numeric.MCMC
-- Copyright: (c) 2015 Jared Tobin
-- License: MIT
--
-- Maintainer: Jared Tobin <jared@jtobin.ca>
-- Stability: unstable
-- Portability: ghc
--
-- This module presents a simple combinator language for Markov transition
-- operators that are useful in MCMC.
--
-- Any transition operators sharing the same stationary distribution and
-- obeying the Markov and reversibility properties can be combined in a couple
-- of ways, such that the resulting operator preserves the stationary
-- distribution and desirable properties amenable for MCMC.
--
-- We can deterministically concatenate operators end-to-end, or sample from
-- a collection of them according to some probability distribution.  See
-- <http://www.stat.umn.edu/geyer/f05/8931/n1998.pdf Geyer, 2005> for details.
--
-- The result is a simple grammar for building composite, property-preserving
-- transition operators from existing ones:
--
-- @
-- transition ::= primitive <transition>
--              | concatT transition transition
--              | sampleT transition transition
-- @
--
-- In addition to the above, this module provides a number of combinators for
-- building composite transition operators.  It re-exports a number of
-- production-quality transition operators from the /mighty-metropolis/,
-- /speedy-slice/, and /hasty-hamiltonian/ libraries.
--
-- Markov chains can then be run over arbitrary 'Target's using whatever
-- transition operator is desired.
--
-- > import Numeric.MCMC
-- > import Data.Sampling.Types
-- >
-- > target :: [Double] -> Double
-- > target [x0, x1] = negate (5 * (x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
-- >
-- > rosenbrock :: Target [Double]
-- > rosenbrock = Target target Nothing
-- >
-- > transition :: Transition IO (Chain [Double] b)
-- > transition =
-- >   concatT
-- >     (sampleT (metropolis 0.5) (metropolis 1.0))
-- >     (sampleT (slice 2.0) (slice 3.0))
-- >
-- > main :: IO ()
-- > main = withSystemRandom . asGenIO $ mcmc 10000 [0, 0] transition rosenbrock
--
-- See the attached test suite for other examples.

module Numeric.MCMC (
    concatT
  , concatAllT
  , sampleT
  , sampleAllT
  , bernoulliT
  , frequency
  , anneal
  , mcmc
  , chain

  -- * Re-exported
  , module Data.Sampling.Types

  , metropolis
  , hamiltonian
  , slice

  , MWC.create
  , MWC.createSystemRandom
  , MWC.withSystemRandom
  , MWC.asGenIO

  , PrimMonad
  , PrimState
  , RealWorld
  ) where

import Control.Monad (replicateM)
import Control.Monad.Codensity (lowerCodensity)
import Control.Monad.Primitive (PrimMonad, PrimState, RealWorld)
import Control.Monad.Trans.State.Strict (execStateT)
import Data.Sampling.Types
import Numeric.MCMC.Anneal
import qualified Numeric.MCMC.Metropolis as M (metropolis)
import Numeric.MCMC.Hamiltonian (hamiltonian)
import Numeric.MCMC.Slice (slice)
import Pipes hiding (next)
import qualified Pipes.Prelude as Pipes
import System.Random.MWC.Probability (Gen)
import qualified System.Random.MWC.Probability as MWC

-- | Deterministically concat two transition operators: run the first one,
--   then run the second one starting from whatever state the first produced.
concatT :: Monad m => Transition m a -> Transition m a -> Transition m a
concatT before after = do
  before
  after

-- | Deterministically concat a list of transition operators together,
--   running them in order from left to right.
--
--   The empty list yields the no-op transition, which leaves the chain state
--   unchanged (and therefore trivially preserves any stationary
--   distribution), instead of failing at runtime.
concatAllT :: Monad m => [Transition m a] -> Transition m a
-- 'sequence_' is @foldr (>>) (return ())@; since '>>' is associative this
-- agrees with left-to-right chaining for every non-empty list and is total
-- on the empty one.  Relies on 'Transition' having result type @()@, as it
-- does in Data.Sampling.Types.
concatAllT = sequence_

-- | Probabilistically concat two transition operators: each time the result
--   is applied, one of the two is chosen with equal probability and run.
--
--   Equivalent to 'bernoulliT' with success probability @0.5@.
sampleT :: PrimMonad m => Transition m a -> Transition m a -> Transition m a
sampleT t0 t1 = bernoulliT 0.5 t0 t1

-- | Probabilistically concat two transition operators using a Bernoulli draw
--   with success probability @p@: on success the first operator is run,
--   otherwise the second.  A fresh draw is made every time the resulting
--   operator is applied.
--
--   'sampleT' is the special case @p = 0.5@.
bernoulliT
  :: PrimMonad m
  => Double
  -> Transition m a
  -> Transition m a
  -> Transition m a
bernoulliT p t0 t1 = lift (MWC.bernoulli p) >>= choose
  where
    -- Map the coin flip onto the operator to run.
    choose heads = if heads then t0 else t1

-- | Probabilistically concat transition operators together via a uniform
--   distribution: each time the result is applied, one operator is picked
--   uniformly at random from the list and run.
--
--   Fails with an explicit error if the list is empty, since there is nothing
--   to sample from.
sampleAllT :: PrimMonad m => [Transition m a] -> Transition m a
sampleAllT [] = error "sampleAllT: no transition operators supplied"
sampleAllT ts = do
  -- Non-empty here, so the range (0, lastIx) is valid and the index is in
  -- bounds for '!!'.
  j <- lift (MWC.uniformR (0, lastIx))
  ts !! j
  where
    lastIx = length ts - 1

-- | Probabilistically concat transition operators together using the supplied
--   frequency distribution.
--
--   Each pair is @(weight, operator)@.  Every time the result is applied, an
--   operator is chosen with probability proportional to its weight and run.
--   Entries with weight zero are never chosen.
--
--   This is more-or-less a copy of @Test.QuickCheck.frequency@, applied to
--   transition operators.  It fails with an error if any weight is negative,
--   or if the weights do not sum to a positive total (including the empty
--   list).
frequency :: PrimMonad m => [(Int, Transition m a)] -> Transition m a
frequency xs
  -- Validate before drawing.  With tot <= 0 the range (1, tot) is empty and
  -- the draw would be meaningless.  Negative weights would silently skew the
  -- selection.
  | any ((< 0) . fst) xs = error "frequency: negative weight"
  | tot <= 0             = error "frequency: no distribution specified"
  | otherwise            = lift (MWC.uniformR (1, tot)) >>= (`pick` xs)
  where
    tot = sum (map fst xs)

    -- Walk the list, subtracting weights until the draw falls within one.
    -- Once the guards above pass, 1 <= n <= tot holds, so the fallthrough
    -- equation is unreachable.
    pick n ((k, v) : vs)
      | n <= k    = v
      | otherwise = pick (n - k) vs
    pick _ _ = error "frequency: no distribution specified"

-- | Trace @n@ iterations of a Markov chain and stream them to stdout.
--
--   The chain starts at the supplied position, scored against the supplied
--   target, with no tunables.  Each state produced by the transition operator
--   is printed with 'print' as it is generated.  The starting state itself is
--   not printed.
--
--   Example (output is random):
--
-- > withSystemRandom . asGenIO $ mcmc 3 [0, 0] (metropolis 0.5) rosenbrock
-- > -0.48939312153007863,0.13290702689491818
-- > 1.4541485365128892e-2,-0.4859905564050404
-- > 0.22487398491619448,-0.29769783186855125
mcmc
  :: (MonadIO m, PrimMonad m, Show (t a))
  => Int
  -> t a
  -> Transition m (Chain (t a) b)
  -> Target (t a)
  -> Gen (PrimState m)
  -> m ()
mcmc n position transition target gen =
  runEffect $
        drive transition origin gen
    >-> Pipes.take n
    >-> Pipes.mapM_ (liftIO . print)
  where
    -- Initial chain state: the start position scored under the target.
    origin = Chain
      { chainTarget   = target
      , chainScore    = lTarget target position
      , chainPosition = position
      , chainTunables = Nothing
      }

-- | Trace @n@ iterations of a Markov chain and collect them in a list.
--
--   The chain starts at the supplied position, scored against the supplied
--   target, with no tunables.  The returned list holds the @n@ states
--   produced by the transition operator, in order.  The starting state itself
--   is not included.
--
--   Example:
--
-- > results <- withSystemRandom . asGenIO $ chain 3 [0, 0] (metropolis 0.5) rosenbrock
chain
  :: (MonadIO m, PrimMonad m)
  => Int
  -> t a
  -> Transition m (Chain (t a) b)
  -> Target (t a)
  -> Gen (PrimState m)
  -> m [Chain (t a) b]
chain n position transition target gen =
  runEffect $ drive transition origin gen >-> gather n
  where
    -- Initial chain state: the start position scored under the target.
    origin = Chain
      { chainTarget   = target
      , chainScore    = lTarget target position
      , chainPosition = position
      , chainTunables = Nothing
      }

    -- Await exactly @count@ values from upstream and return them in order.
    -- The 'Codensity' wrapper is presumably there to keep the replicated
    -- binds right-associated.  NOTE(review): confirm the performance intent.
    gather :: Monad m => Int -> Consumer a m [a]
    gather count = lowerCodensity (replicateM count (lift Pipes.await))

-- A Markov chain driven by an arbitrary transition operator.
--
-- Starting from the supplied state, repeatedly run the transition operator
-- (sampling with the given generator) and yield each resulting state.  The
-- stream never terminates on its own, and the starting state is not yielded.
drive
  :: PrimMonad m
  => Transition m b
  -> b
  -> Gen (PrimState m)
  -> Producer b m a
drive transition start prng = go start
  where
    go current = do
      stepped <- lift (MWC.sample (execStateT transition current) prng)
      yield stepped
      go stepped

-- | A generic Metropolis transition operator.
--
--   Forwards the supplied @radial@ parameter to the operator from
--   "Numeric.MCMC.Metropolis" and passes 'Nothing' as its optional
--   @f Double -> b@ argument.
metropolis
  :: (Traversable f, PrimMonad m)
  => Double
  -> Transition m (Chain (f Double) b)
metropolis radial = M.metropolis radial noTunables
  where
    noTunables = Nothing