{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}

-- |
-- Module      : Control.Monad.Bayes.Weighted
-- Description : Probability monad accumulating the likelihood
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- 'Weighted' is an instance of 'MonadFactor'. Apply a 'MonadDistribution' transformer to
-- obtain a 'MonadMeasure' that can execute probabilistic models.
module Control.Monad.Bayes.Weighted
  ( Weighted,
    withWeight,
    weighted,
    extractWeight,
    unweighted,
    applyWeight,
    hoist,
    runWeighted,
  )
where

import Control.Monad.Bayes.Class
  ( MonadDistribution,
    MonadFactor (..),
    MonadMeasure,
    factor,
  )
import Control.Monad.State (MonadIO, MonadTrans, StateT (..), lift, mapStateT, modify)
import Numeric.Log (Log)

-- | Execute the program using the prior distribution, while accumulating likelihood.
--
-- The likelihood is carried as @'StateT' ('Log' Double)@ state. Every class
-- instance listed in the deriving clause is reused unchanged from the
-- corresponding @StateT (Log Double) m@ instance via newtype deriving.
newtype Weighted m a = Weighted (StateT (Log Double) m a)
  -- StateT is more efficient than WriterT
  deriving newtype (Functor, Applicative, Monad, MonadIO, MonadTrans, MonadDistribution)

-- | Conditioning multiplies the accumulated likelihood by the given factor.
instance Monad m => MonadFactor (Weighted m) where
  score w = Weighted (modify (* w))

-- | Over a 'MonadDistribution', 'Weighted' can both sample (derived instance)
-- and condition ('MonadFactor' instance), which makes it a 'MonadMeasure'.
instance (MonadDistribution m) => MonadMeasure (Weighted m)

-- | Obtain an explicit value of the likelihood for a given value.
--
-- Runs the program with the likelihood initialised to 1, so a program that
-- never scores is returned with weight 1. 'runWeighted' is a synonym for
-- 'weighted'.
weighted, runWeighted :: Weighted m a -> m (a, Log Double)
weighted (Weighted m) = runStateT m 1
runWeighted = weighted

-- | Compute the sample and discard the weight.
--
-- This operation introduces bias: the accumulated likelihood is thrown away,
-- so the result reflects only the sampling done by the program.
unweighted :: Functor m => Weighted m a -> m a
unweighted = fmap fst . weighted

-- | Compute the weight and discard the sample.
--
-- The program is still run in full; only its result value is dropped.
extractWeight :: Functor m => Weighted m a -> m (Log Double)
extractWeight = fmap snd . weighted

-- | Embed a random variable with explicitly given likelihood.
--
-- Runs the underlying computation and multiplies its reported weight into
-- the accumulated likelihood, so that
--
-- > weighted . withWeight = id
withWeight :: (Monad m) => m (a, Log Double) -> Weighted m a
withWeight m = Weighted $ do
  (x, w) <- lift m
  modify (* w)
  return x

-- | Use the weight as a factor in the transformed monad.
--
-- Runs the program to obtain its result and accumulated likelihood, then
-- passes that likelihood to 'factor' in the underlying monad before
-- returning the result.
applyWeight :: MonadFactor m => Weighted m a -> m a
applyWeight m = do
  (x, w) <- weighted m
  factor w
  return x

-- | Apply a transformation to the transformed monad.
--
-- The natural transformation is applied to the underlying computation via
-- 'mapStateT'; the likelihood state is passed through unchanged.
hoist :: (forall x. m x -> n x) -> Weighted m a -> Weighted n a
hoist t (Weighted m) = Weighted $ mapStateT t m