{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | This is a port of the implementation of LazyPPL: https://lazyppl.bitbucket.io/
module Control.Monad.Bayes.Sampler.Lazy where

import Control.Monad (ap)
import Control.Monad.Bayes.Class (MonadDistribution (random))
import Control.Monad.Bayes.Weighted (WeightedT, runWeightedT)
import Control.Monad.IO.Class
import Control.Monad.Identity (Identity (runIdentity))
import Control.Monad.Trans
import Numeric.Log (Log (..))
import System.Random
  ( RandomGen (split),
    getStdGen,
    newStdGen,
  )
import System.Random qualified as R

-- | An infinitely wide and infinitely deep tree of 'Double' labels, built lazily.
--
--   This is the source of randomness for 'SamplerT': each label is intended to
--   be a uniform draw from @[0,1]@ (see 'randomTree'). A flat list or stream
--   could serve the same purpose, but a tree lets separate sub-computations
--   each take their own subtree, so no part of a program needs to know in
--   advance how much randomness any other part will consume.
data Tree = Tree
  { -- | The label stored at this node.
    currentUniform :: Double,
    -- | The children of this node: an endless stream of subtrees.
    lazyUniforms :: Trees
  }

-- | An endless stream of 'Tree's, used as the children of a 'Tree' node.
--   There is no "empty" constructor, so every stream goes on forever.
data Trees = Trees
  { -- | The first subtree of the stream.
    headTree :: Tree,
    -- | Every subtree after the first.
    tailTrees :: Trees
  }

-- | The non-transformer sampler: 'SamplerT' specialised to 'Identity'.
--
--   A probability distribution over @a@ is, in effect, a function
--   @Tree -> a@: it uses up parts of the random 'Tree' as it runs.
--   Run one with 'runSampler'.
type Sampler = SamplerT Identity

-- | Run a pure 'Sampler' against a given source of randomness and return the
--   sampled value directly (the 'Identity' wrapper is removed).
runSampler :: Sampler a -> Tree -> a
runSampler sampler tree = runIdentity (runSamplerT sampler tree)

-- | A sampler producing @a@ in an underlying monad @m@.
--
--   Given a random 'Tree', it yields an @m a@, consuming parts of the tree
--   as it runs. Sequenced samplers receive disjoint parts of the tree
--   (see the 'Monad' instance).
newtype SamplerT m a = SamplerT
  { -- | Run the sampler against a particular source of randomness.
    runSamplerT :: Tree -> m a
  }
  deriving (Functor)

-- | Split a tree bijectively into two disjoint trees.
--
--   The first result is the root's first child; the second keeps the root's
--   label together with all of the remaining children.
splitTree :: Tree -> (Tree, Tree)
splitTree tree =
  case tree of
    Tree label (Trees firstChild otherChildren) ->
      (firstChild, Tree label otherChildren)

-- | Build a random 'Tree' from a seed generator, for use by the simulation
--   methods.
--
--   The root label is drawn from @gen@ with 'R.random'. The generator left
--   over after that draw seeds the children via 'randomTrees', which uses
--   'split' to give every subtree its own generator. Nothing is generated
--   until it is demanded.
randomTree :: (RandomGen g) => g -> Tree
randomTree gen = Tree label (randomTrees rest)
  where
    (label, rest) = R.random gen

-- | Build an endless, lazily generated stream of random trees.
--
--   The generator is 'split' in two: one half seeds the head tree and the
--   other seeds the rest of the stream.
randomTrees :: (RandomGen g) => g -> Trees
randomTrees gen = Trees (randomTree here) (randomTrees there)
  where
    (here, there) = split gen

instance (Monad m) => Applicative (SamplerT m) where
  -- A pure value never looks at the tree.
  pure x = SamplerT (\_ -> pure x)

  -- Defined through the 'Monad' instance (the same expansion as 'ap'), so the
  -- function and its argument are run on disjoint parts of the tree.
  samplerF <*> samplerX = do
    f <- samplerF
    x <- samplerX
    pure (f x)

-- | Binding splits the tree with 'splitTree': the first computation runs on
--   one part and the continuation runs on the other, so the two never read
--   the same labels.
instance (Monad m) => Monad (SamplerT m) where
  (SamplerT run) >>= k = SamplerT $ \tree ->
    let (left, right) = splitTree tree
     in run left >>= \a -> runSamplerT (k a) right

-- | A lifted action ignores the tree and consumes none of its randomness.
instance MonadTrans SamplerT where
  lift action = SamplerT (\_ -> action)

instance (MonadIO m) => MonadIO (SamplerT m) where
  -- Run the IO action in the base monad; the tree is not consulted.
  liftIO io = SamplerT (\_ -> liftIO io)

-- | A uniform draw is simply the root label of the tree; the children are
--   never inspected.
instance (Monad m) => MonadDistribution (SamplerT m) where
  random = SamplerT $ \(Tree {currentUniform = u}) -> pure u

-- | Runs a 'SamplerT' on a fresh random 'Tree' seeded from a new 'StdGen'.
--
--   The tree is seeded with the generator /returned/ by 'newStdGen', which
--   splits the global generator, stores one half back and hands out the other.
--   The previous version threw that result away and re-read the global
--   generator with a separate 'getStdGen'. That had two problems:
--
--   * Two concurrent calls could both perform their 'newStdGen' before either
--     'getStdGen', and then run on identical trees.
--   * The seed of one run was exactly the state the next call's 'newStdGen'
--     would split, so consecutive runs' streams were derived from each other.
--
--   Using the returned half avoids the second, separate read of global state.
runSamplerTIO :: (MonadIO m) => SamplerT m a -> m a
runSamplerTIO m = do
  gen <- liftIO newStdGen
  runSamplerT m (randomTree gen)

-- | Draw an endless stream of samples by repeating @draw@ forever.
--
--   In 'SamplerT', each repetition runs on its own part of the tree. Because
--   the list is infinite, the result is only usable in a monad whose bind can
--   be consumed lazily. In a strict monad such as 'IO' it does not terminate.
independent :: (Monad m) => m a -> m [a]
independent draw = sequence (repeat draw)

-- | Run a weighted model repeatedly on a fresh random tree, producing an
--   endless stream of @(result, weight)@ pairs, one per run.
--
--   NOTE(review): the stream comes from sequencing an infinite list in
--   @SamplerT m@. For a strict base monad such as 'IO' this appears not to
--   terminate. Confirm the intended choice of @m@ with the callers.
weightedSamples ::
  (MonadIO m) =>
  WeightedT (SamplerT m) a ->
  m [(a, Log Double)]
weightedSamples model = runSamplerTIO (independent (runWeightedT model))