{-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# OPTIONS_GHC -Wno-type-defaults #-}
{-# OPTIONS_GHC -Wno-unused-top-binds #-}

-- |
-- This module is adapted from https://jtobin.io/giry-monad-implementation
-- and brought into the monad-bayes framework: 'Integrator' is an instance of
-- 'MonadDistribution', and weighted programs of type @Weighted Integrator@
-- can be normalized with 'normalize'.
-- It is largely intended for debugging other inference methods and for
-- didactic use, because brute-force integration of measures is only
-- practical for small programs.
module Control.Monad.Bayes.Integrator
  ( probability,
    variance,
    expectation,
    cdf,
    empirical,
    enumeratorWith,
    histogram,
    plotCdf,
    volume,
    normalize,
    Integrator,
    momentGeneratingFunction,
    cumulantGeneratingFunction,
    integrator,
    runIntegrator,
  )
where

import Control.Applicative (Applicative (..))
import Control.Foldl (Fold)
import Control.Foldl qualified as Foldl
import Control.Monad.Bayes.Class (MonadDistribution (bernoulli, random, uniformD))
import Control.Monad.Bayes.Weighted (Weighted, weighted)
import Control.Monad.Cont
  ( Cont,
    ContT (ContT),
    cont,
    runCont,
  )
import Data.Foldable (Foldable (foldl'))
import Data.Set (Set, elems)
import Numeric.Integration.TanhSinh (Result (result), trap)
import Numeric.Log (Log (ln))
import Statistics.Distribution qualified as Statistics
import Statistics.Distribution.Uniform qualified as Statistics

-- | A measure over @a@, represented by its integration operator.
--
-- The wrapped @'Cont' Double a@ is a function that takes an integrand
-- @a -> Double@ and produces a 'Double'. The 'Functor', 'Applicative' and
-- 'Monad' instances are reused from 'Cont' via newtype deriving.
newtype Integrator a = Integrator {getCont :: Cont Double a}
  deriving newtype (Functor, Applicative, Monad)

-- | Integrate a function against a measure: the first argument is the
-- integrand, the second the measure whose continuation is run on it.
--
-- 'runIntegrator' is a synonym for 'integrator'.
integrator, runIntegrator :: (a -> Double) -> Integrator a -> Double
integrator f (Integrator a) = runCont a f
runIntegrator = integrator

-- | Interpretation of the sampling primitives as integration operators.
instance MonadDistribution Integrator where
  -- Integrate against the density of the uniform distribution with
  -- parameters 0 and 1.
  random = fromDensityFunction $ Statistics.density $ Statistics.uniformDistr 0 1

  -- Two-point measure: the integrand at 'True' is weighted by @p@ and the
  -- integrand at 'False' by @1 - p@.
  bernoulli p = Integrator $ cont (\f -> p * f True + (1 - p) * f False)

  -- Every element of the list receives the same mass, @1 / length ls@.
  uniformD ls = fromMassFunction (const mass) ls
    where
      mass = 1 / fromIntegral (length ls)

-- | Build a measure on 'Double' from a density function.
--
-- The integrand is multiplied pointwise by the density @d@ and the product is
-- integrated numerically. Note that the integration bounds passed to 'trap'
-- are fixed at @0@ and @1@, so any part of the density outside @[0, 1]@ is
-- not taken into account.
fromDensityFunction :: (Double -> Double) -> Integrator Double
fromDensityFunction d =
  Integrator . cont $ \f ->
    integralWithQuadrature (\x -> f x * d x)
  where
    -- 'trap' returns a list of 'Result's, of which the last one is used.
    -- NOTE(review): presumably these are successively refined estimates and
    -- the list is never empty ('last' is partial) -- confirm against the
    -- integration package documentation.
    integralWithQuadrature :: (Double -> Double) -> Double
    integralWithQuadrature g = result . last $ trap g 0 1

-- | Build a discrete measure from a mass function and its support.
--
-- Integrating @g@ against the result sums @f x * g x@ over every element
-- @x@ of @support@, using a strict left fold starting from @0@.
fromMassFunction :: Foldable f => (a -> Double) -> f a -> Integrator a
fromMassFunction f support =
  -- An explicit '$' is used before the lambda so that this definition does
  -- not depend on the BlockArguments extension being enabled.
  Integrator . cont $ \g ->
    foldl' (\acc x -> acc + f x * g x) 0 support

-- | The empirical measure of a collection of samples.
--
-- Integrating @f@ against the result yields the arithmetic mean of @f@ over
-- the samples (sum divided by count). For an empty collection this is
-- @0 / 0@.
empirical :: Foldable f => f a -> Integrator a
empirical samples = Integrator . cont $ \f -> weightedAverage f samples
  where
    -- Mean of @f@ over the collection, computed with a single fold.
    weightedAverage :: (Foldable f, Fractional r) => (a -> r) -> f a -> r
    weightedAverage f = Foldl.fold (weightedAverageFold f)

    -- Apply @f@ to every element before averaging.
    weightedAverageFold :: Fractional r => (a -> r) -> Fold a r
    weightedAverageFold f = Foldl.premap f averageFold

    -- Sum divided by length, both accumulated in the same pass.
    averageFold :: Fractional a => Fold a a
    averageFold = (/) <$> Foldl.sum <*> Foldl.genericLength

-- | The expected value of a measure on 'Double': the integral of the
-- identity function.
expectation :: Integrator Double -> Double
expectation = integrator id

-- | The variance of a measure on 'Double', computed as the integral of
-- @x ^ 2@ minus the square of the 'expectation'.
variance :: Integrator Double -> Double
variance nu = integrator (^ 2) nu - mean ^ 2
  where
    mean = expectation nu

-- | The moment generating function of a measure, evaluated at @t@: the
-- integral of @exp (t * x)@.
momentGeneratingFunction :: Integrator Double -> Double -> Double
momentGeneratingFunction nu t = integrator (\x -> exp (t * x)) nu

-- | The cumulant generating function of a measure: the logarithm of its
-- 'momentGeneratingFunction'.
cumulantGeneratingFunction :: Integrator Double -> Double -> Double
cumulantGeneratingFunction nu = log . momentGeneratingFunction nu

-- | Turn a weighted measure into an unweighted one by scaling each outcome
-- by its weight and dividing by the total weight.
--
-- The total weight @z@ is the integral of the weights themselves; every
-- outcome's contribution is multiplied by @weight / z@. No check is made for
-- @z@ being zero, in which case the division is by zero.
normalize :: Weighted Integrator a -> Integrator a
normalize m = do
  (x, d) <- weightedM
  -- A measure on @()@ whose only effect is to rescale the integrand.
  Integrator $ cont $ \f -> (f () * toLinear d) / z
  return x
  where
    -- The underlying measure over (value, weight) pairs.
    weightedM = weighted m

    -- Convert a 'Log' weight to a plain 'Double' (same expression as used
    -- for both the numerator and the normalizing constant).
    toLinear :: Log Double -> Double
    toLinear = ln . exp

    -- Normalizing constant: the integral of the weights.
    z = integrator (toLinear . snd) weightedM

-- | The cumulative distribution function of a measure, evaluated at @x@:
-- the integral of the indicator of the closed range from negative infinity
-- to @x@.
cdf :: Integrator Double -> Double -> Double
cdf nu x = integrator (negativeInfinity `to` x) nu
  where
    negativeInfinity :: Double
    negativeInfinity = negate (1 / 0)

    -- Indicator function of the closed interval @[a, b]@.
    to :: (Num a, Ord a) => a -> a -> a -> a
    to a b k
      | k >= a && k <= b = 1
      | otherwise = 0

-- | The total mass of a measure: the integral of the constant function @1@.
volume :: Integrator Double -> Double
volume = integrator (const 1)

-- | Indicator function of list membership: @1@ when @x@ is an element of
-- @xs@, and @0@ otherwise.
containing :: (Num a, Eq b) => [b] -> b -> a
containing xs x
  | x `elem` xs = 1
  | otherwise = 0

-- | Arithmetic on measures, lifted through the 'Applicative' and 'Functor'
-- instances: binary operations combine the outcomes of both measures, unary
-- operations are mapped over the outcomes, and integer literals become
-- 'pure' values.
instance Num a => Num (Integrator a) where
  (+) = liftA2 (+)
  (-) = liftA2 (-)
  (*) = liftA2 (*)
  abs = fmap abs
  signum = fmap signum
  fromInteger = pure . fromInteger

-- | The mass a measure assigns to the half-open interval @[lower, upper)@:
-- the lower bound is included, the upper bound is excluded.
probability :: Ord a => (a, a) -> Integrator a -> Double
probability (lower, upper) = integrator inInterval
  where
    inInterval x
      | x < upper && x >= lower = 1
      | otherwise = 0

-- | Pair every element of the given set with the mass the measure assigns
-- to exactly that element (the integral of the equality indicator).
--
-- Results are produced in the order in which 'elems' lists the set.
enumeratorWith :: Ord a => Set a -> Integrator a -> [(a, Double)]
enumeratorWith ls meas =
  [ (val, integrator (isVal val) meas)
    | val <- elems ls
  ]
  where
    isVal val x = if x == val then 1 else 0

-- | Tabulate a weighted measure as a histogram.
--
-- For each bin index @x@ in @1 .. nBins@ the bin spans
-- @[transform x, transform (x + 1))@, where
-- @transform k = (k - nBins / 2) * binSize@. Each result pairs the lower
-- edge of a bin with the 'probability' of that bin under the normalized
-- model.
histogram ::
  (Enum a, Ord a, Fractional a) =>
  Int ->
  a ->
  Weighted Integrator a ->
  [(a, Double)]
histogram nBins binSize model = do
  x <- take nBins [1 ..]
  let lower = transform x
      upper = transform (x + 1)
  return (lower, probability (lower, upper) normalized)
  where
    -- Map a bin index to the position of its lower edge.
    transform k = (k - (fromIntegral nBins / 2)) * binSize

    normalized = normalize model

-- | Sample the 'cdf' of a measure at evenly spaced points.
--
-- For each index @x@ in @1 .. nBins@ the evaluation point is
-- @(x - nBins / 2) * binSize + middlePoint@; each result pairs that point
-- with the value of the 'cdf' there.
plotCdf :: Int -> Double -> Double -> Integrator Double -> [(Double, Double)]
plotCdf nBins binSize middlePoint model = do
  x <- take nBins [1 ..]
  let point = transform x
  return (point, cdf model point)
  where
    -- Map an index to its evaluation point.
    transform k = (k - (fromIntegral nBins / 2)) * binSize + middlePoint