{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}

-- |
-- Module      : Control.Monad.Bayes.Density.Free
-- Description : Free monad transformer over random sampling
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- 'Density' is a free monad transformer over random sampling.
module Control.Monad.Bayes.Density.Free
  ( Density,
    hoist,
    interpret,
    withRandomness,
    density,
    traced,
  )
where

import Control.Monad.Bayes.Class (MonadDistribution (random))
import Control.Monad.RWS
import Control.Monad.State (evalStateT)
import Control.Monad.Trans.Free.Church (FT, MonadFree (..), hoistFT, iterT, iterTM, liftF)
import Control.Monad.Writer (WriterT (..))
import Data.Functor.Identity (Identity, runIdentity)

-- | Random sampling functor.
--
-- A single primitive instruction: request one 'Double' and continue the
-- computation with it.
newtype SamF a = Random (Double -> a)
  deriving stock (Functor)

-- | Free monad transformer over random sampling.
--
-- Uses the Church-encoded version of the free monad ('FT') for efficiency.
-- The 'Functor', 'Applicative', 'Monad' and 'MonadTrans' instances are
-- inherited unchanged from @'FT' 'SamF'@ via newtype deriving.
newtype Density m a = Density {runDensity :: FT SamF m a}
  deriving newtype (Functor, Applicative, Monad, MonadTrans)

-- | Wrapping a layer of 'SamF' is delegated to the underlying 'FT':
-- unwrap each inner 'Density', 'wrap' in 'FT', and re-wrap the result.
instance MonadFree SamF (Density m) where
  wrap = Density . wrap . fmap runDensity

-- | 'random' emits a single 'Random' instruction whose continuation returns
-- the supplied 'Double' unchanged. Where that value actually comes from is
-- decided by the interpreter ('interpret', 'withRandomness', 'density').
instance Monad m => MonadDistribution (Density m) where
  random = Density (liftF (Random id))

-- | Hoist 'Density' through a monad morphism.
--
-- Changes the base monad from @m@ to @n@ using the supplied transformation
-- (via 'hoistFT'). The 'SamF' sampling instructions are not altered.
hoist :: (Monad m, Monad n) => (forall x. m x -> n x) -> Density m a -> Density n a
hoist f (Density m) = Density (hoistFT f m)

-- | Execute random sampling in the transformed monad.
--
-- Every 'Random' instruction is answered by a call to the underlying
-- monad's 'random'.
interpret :: MonadDistribution m => Density m a -> m a
interpret (Density m) = iterT step m
  where
    step (Random k) = random >>= k

-- | Execute computation with supplied values for random choices.
--
-- The values are consumed in order, one per random choice. Any values left
-- over when the computation finishes are discarded.
--
-- Calls 'error' if the computation makes more random choices than there are
-- values in the list.
withRandomness :: Monad m => [Double] -> Density m a -> m a
withRandomness randomness (Density m) = evalStateT (iterTM step m) randomness
  where
    -- Runs in StateT [Double] m; the state holds the values not yet consumed.
    step (Random k) = do
      xs <- get
      case xs of
        [] -> error "Density: the list of randomness was too short"
        y : ys -> put ys >> k y

-- | Execute computation with supplied values for a subset of random choices.
-- Return the output value and a record of all random choices used, whether
-- taken as input or drawn using the transformed monad.
--
-- Supplied values are used first, in order. Once they run out, further
-- choices are drawn with 'random'. Supplied values that are never needed do
-- not appear in the returned record.
density :: MonadDistribution m => [Double] -> Density m a -> m (a, [Double])
density randomness (Density m) =
  runWriterT $ evalStateT (iterTM step $ hoistFT lift m) randomness
  where
    step (Random k) = do
      -- This block runs in StateT [Double] (WriterT [Double] m).
      -- StateT propagates consumed randomness while WriterT records
      -- randomness used, whether old or new.
      xs <- get
      x <- case xs of
        [] -> random
        y : ys -> put ys >> pure y
      tell [x]
      k x

-- | Like 'density', but use an arbitrary sampling monad.
--
-- The program itself has no base effects ('Identity'). It is lifted into the
-- caller's monad, which supplies any randomness beyond the given list.
traced :: MonadDistribution m => [Double] -> Density Identity a -> m (a, [Double])
traced randomness m = density randomness (hoist (pure . runIdentity) m)