{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}

-- |
-- Module      : Control.Monad.Bayes.Enumerator
-- Description : Exhaustive enumeration of discrete random variables
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
module Control.Monad.Bayes.Enumerator
  ( Enumerator,
    logExplicit,
    explicit,
    evidence,
    mass,
    compact,
    enumerator,
    enumerate,
    expectation,
    normalForm,
    toEmpirical,
    toEmpiricalWeighted,
    normalizeWeights,
    enumerateToDistribution,
    removeZeros,
    fromList,
  )
where

import Control.Applicative (Alternative)
import Control.Arrow (second)
import Control.Monad.Bayes.Class
  ( MonadDistribution (bernoulli, categorical, logCategorical, random),
    MonadFactor (..),
    MonadMeasure,
  )
import Control.Monad.Writer
import Data.AEq (AEq, (===), (~==))
import Data.List (sortOn)
import Data.Map qualified as Map
import Data.Maybe (fromMaybe)
import Data.Ord (Down (Down))
import Data.Vector qualified as VV
import Data.Vector.Generic qualified as V
import Numeric.Log as Log (Log (..), sum)

-- | An exact inference monad that integrates discrete random variables by
-- enumerating all execution paths.
--
-- Internally every execution path is a (value, weight) pair: the list monad
-- branches over outcomes, and the 'Product' writer multiplies the
-- log-domain weights accumulated along each path.
newtype Enumerator a = Enumerator (WriterT (Product (Log Double)) [] a)
  deriving newtype (Functor, Applicative, Monad, Alternative, MonadPlus)

-- | Sampling in 'Enumerator' branches over every outcome, weighting each
-- branch by its probability (stored in log domain).
instance MonadDistribution Enumerator where
  -- 'random' has infinite support and cannot be enumerated, so it is
  -- rejected at runtime.
  random = error "Infinitely supported random variables not supported in Enumerator"

  -- Two branches: True with weight p, False with weight 1 - p.
  bernoulli p =
    fromList
      [ (True, Exp (log p)),
        (False, Exp (log (1 - p)))
      ]

  -- One branch per index, weighted by the corresponding vector entry.
  categorical v =
    fromList [(i, Exp (log w)) | (i, w) <- zip [0 ..] (V.toList v)]

-- | Conditioning does not branch: it yields a single path whose weight is
-- the given factor, which the 'Product' writer multiplies into the weight
-- of the current execution path.
instance MonadFactor Enumerator where
  score weight = fromList [((), weight)]

-- | Sampling and conditioning are both available, so 'Enumerator' is a
-- 'MonadMeasure'.
instance MonadMeasure Enumerator

-- | Construct an 'Enumerator' directly from its execution paths: each value
-- is paired with the log-domain weight of the path that produces it.
fromList :: [(a, Log Double)] -> Enumerator a
fromList paths = Enumerator (WriterT [(x, Product w) | (x, w) <- paths])

-- | Returns the posterior as a list of value-weight pairs, one per execution
-- path, without any post-processing such as normalization or aggregation.
logExplicit :: Enumerator a -> [(a, Log Double)]
logExplicit (Enumerator m) = [(x, getProduct w) | (x, w) <- runWriterT m]

-- | Same as 'logExplicit', only weights are converted from log-domain to
-- plain 'Double's. No normalization or aggregation is performed.
explicit :: Enumerator a -> [(a, Double)]
explicit = map (second (exp . ln)) . logExplicit

-- | Returns the model evidence: the sum of the weights of all execution
-- paths, accumulated in log domain.
evidence :: Enumerator a -> Log Double
evidence model = Log.sum [w | (_, w) <- logExplicit model]

-- | Normalized probability mass of a specific value.
--
-- Values outside the support (or with zero mass) get 0. The lookup table is
-- built once per @mass d@ and shared by every value it is applied to, so
-- each query is an O(log k) map lookup instead of a linear scan.
mass :: Ord a => Enumerator a -> a -> Double
mass d = massOf
  where
    massOf x = fromMaybe 0 (Map.lookup x table)
    -- 'enumerator' merges equal values with 'compact', so keys are unique
    -- and 'Map.fromList' drops nothing.
    table = Map.fromList (enumerator d)

-- | Aggregate weights of equal values by summing them.
--
-- The resulting list is sorted in descending order of weight. The sort is
-- stable and runs over an ascending-by-value list, so values with equal
-- weight keep ascending value order.
compact :: (Num r, Ord a, Ord r) => [(a, r)] -> [(a, r)]
compact = sortOn (Down . snd) . Map.toAscList . Map.fromListWith (+)

-- | Aggregate and normalize the weights of an 'Enumerator'.
--
-- Path weights are first normalized to sum to one (in log domain) and then
-- converted to 'Double'. Equal values are then merged with 'compact', and
-- values whose normalized weight is exactly zero are dropped. As with
-- 'compact', the result is sorted by descending weight.
--
-- NOTE(review): if every path has weight zero, the normalizing constant is
-- zero; the resulting weights are presumably NaN, and the zero filter does
-- not remove them.
enumerator, enumerate :: Ord a => Enumerator a -> [(a, Double)]
enumerator d = filter ((/= 0) . snd) $ compact (zip xs ws)
  where
    (xs, ws) = second (map (exp . ln) . normalize) $ unzip (logExplicit d)

-- | Deprecated synonym for 'enumerator'; prefer calling 'enumerator'.
enumerate model = enumerator model

-- | Expectation of @f@ under the normalized distribution.
--
-- Weights are normalized in log domain, converted to 'Double', and used to
-- weight @f@ applied to each path's value.
expectation :: (a -> Double) -> Enumerator a -> Double
expectation f model =
  Prelude.sum [f x * exp (ln w) | (x, w) <- normalizeWeights (logExplicit model)]

-- | Divide every element by the sum of all elements. The sum is computed
-- once and shared by all divisions.
normalize :: Fractional b => [b] -> [b]
normalize xs =
  let total = Prelude.sum xs
   in fmap (/ total) xs

-- | Divide all weights by their sum, keeping the values and their order.
normalizeWeights :: Fractional b => [(a, b)] -> [(a, b)]
normalizeWeights pairs = zip (map fst pairs) (normalize (map snd pairs))

-- | The 'explicit' weights merged with 'compact', with zero-weight values
-- removed. Weights are not normalized.
normalForm :: Ord a => Enumerator a -> [(a, Double)]
normalForm model = [entry | entry@(_, w) <- compact (explicit model), w /= 0]

-- | Empirical distribution of a sample. Each distinct value is paired with
-- its relative frequency, and the list is ordered by descending frequency
-- (the order produced by 'compact').
toEmpirical :: (Fractional b, Ord a, Ord b) => [a] -> [(a, b)]
toEmpirical = normalizeWeights . compact . map (\x -> (x, 1))

-- | Weighted empirical distribution. Weights of equal values are summed
-- with 'compact' and then normalized to sum to one.
toEmpiricalWeighted :: (Fractional b, Ord a, Ord b) => [(a, b)] -> [(a, b)]
toEmpiricalWeighted samples = normalizeWeights (compact samples)

-- | Turn an exact 'Enumerator' into a sampler in any 'MonadDistribution'.
--
-- The raw (unnormalized) log-domain path weights from 'logExplicit' are
-- passed to 'logCategorical' to draw an index. The value of the path at
-- that index is returned.
enumerateToDistribution :: (MonadDistribution n) => Enumerator a -> n a
enumerateToDistribution model = do
  let (support, logprobs) = unzip (logExplicit model)
      -- Keep the support in a vector so the sampled index is looked up in
      -- O(1) rather than by walking the list with '!!'.
      supportVec = VV.fromList support
  i <- logCategorical (VV.fromList logprobs)
  pure (supportVec VV.! i)

-- | Drop execution paths whose weight is exactly zero. All other paths are
-- kept unchanged and in their original order.
removeZeros :: Enumerator a -> Enumerator a
removeZeros (Enumerator (WriterT paths)) =
  Enumerator (WriterT [path | path@(_, Product w) <- paths, w /= 0])

-- | Two enumerators are equal when their 'normalForm's coincide, i.e. they
-- have the same values with the same aggregated weights.
--
-- Weights are compared as-is ('normalForm' does not normalize), so
-- enumerators that differ only by a constant weight factor are generally
-- not equal.
instance Ord a => Eq (Enumerator a) where
  p == q = lhs == rhs
    where
      lhs = normalForm p
      rhs = normalForm q

-- | Approximate equality of enumerators, built on their 'normalForm's.
--
-- '===' requires identical value lists and compares the weight lists with
-- '==='. '~==' first discards entries whose weight is approximately zero,
-- then requires identical value lists and compares the weights with '~=='.
instance Ord a => AEq (Enumerator a) where
  p === q = xs == ys && ps === qs
    where
      (xs, ps) = unzip (normalForm p)
      (ys, qs) = unzip (normalForm q)

  p ~== q = xs == ys && ps ~== qs
    where
      -- Entries with near-zero weight are treated as absent.
      nonNegligible e = filter (\(_, w) -> not (w ~== 0)) (normalForm e)
      (xs, ps) = unzip (nonNegligible p)
      (ys, qs) = unzip (nonNegligible q)