{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# OPTIONS_GHC -Wno-deprecations #-}

-- |
-- Module      : Control.Monad.Bayes.Class
-- Description : Types for probabilistic modelling
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- This module defines 'MonadMeasure', which can be used to represent any probabilistic program,
-- such as the following:
--
-- @
-- import Control.Monad (when)
-- import Control.Monad.Bayes.Class
--
-- model :: MonadMeasure m => m Bool
-- model = do
--   rain <- bernoulli 0.3
--   sprinkler <-
--     bernoulli $
--     if rain
--       then 0.1
--       else 0.4
--   let wetProb =
--         case (rain, sprinkler) of
--           (True,  True)  -> 0.98
--           (True,  False) -> 0.80
--           (False, True)  -> 0.90
--           (False, False) -> 0.00
--   score wetProb
--   return rain
-- @
module Control.Monad.Bayes.Class
  ( MonadDistribution,
    random,
    uniform,
    normal,
    gamma,
    beta,
    bernoulli,
    categorical,
    logCategorical,
    uniformD,
    geometric,
    poisson,
    dirichlet,
    MonadFactor,
    score,
    factor,
    condition,
    MonadMeasure,
    discrete,
    normalPdf,
    Bayesian (..),
    posterior,
    priorPredictive,
    posteriorPredictive,
    independent,
    mvNormal,
    Histogram,
    histogram,
    histogramToList,
    Distribution,
    Measure,
    Kernel,
    Log (ln, Exp),
  )
where

import Control.Arrow (Arrow (second))
import Control.Monad (replicateM, when)
import Control.Monad.Cont (ContT)
import Control.Monad.Except (ExceptT, lift)
import Control.Monad.Identity (IdentityT)
import Control.Monad.List (ListT)
import Control.Monad.Reader (ReaderT)
import Control.Monad.State (StateT)
import Control.Monad.Writer (WriterT)
import Data.Histogram qualified as H
import Data.Histogram.Fill qualified as H
import Data.Matrix
  ( Matrix,
    cholDecomp,
    colVector,
    getCol,
    multStd,
  )
import Data.Vector qualified as V
import Data.Vector.Generic as VG (Vector, map, mapM, null, sum, (!))
import Numeric.Log (Log (..))
import Statistics.Distribution
  ( ContDistr (logDensity, quantile),
    DiscreteDistr (probability),
  )
import Statistics.Distribution.Beta (betaDistr)
import Statistics.Distribution.Gamma (gammaDistr)
import Statistics.Distribution.Geometric (geometric0)
import Statistics.Distribution.Normal (normalDistr)
import Statistics.Distribution.Poisson qualified as Poisson
import Statistics.Distribution.Uniform (uniformDistr)

-- | Monads that can draw random variables.
--
-- 'random' is the only method without a default; every other method has a
-- default implementation that is ultimately expressed through 'random'.
class Monad m => MonadDistribution m where
  -- | Draw from the standard uniform distribution.
  random ::
    -- | \(\sim \mathcal{U}(0, 1)\)
    m Double

  -- | Draw from a uniform distribution.
  uniform ::
    -- | lower bound a
    Double ->
    -- | upper bound b
    Double ->
    -- | \(\sim \mathcal{U}(a, b)\).
    m Double
  uniform lo hi = draw (uniformDistr lo hi)

  -- | Draw from a normal distribution.
  normal ::
    -- | mean μ
    Double ->
    -- | standard deviation σ
    Double ->
    -- | \(\sim \mathcal{N}(\mu, \sigma^2)\)
    m Double
  normal mean sd = draw (normalDistr mean sd)

  -- | Draw from a gamma distribution.
  gamma ::
    -- | shape k
    Double ->
    -- | scale θ
    Double ->
    -- | \(\sim \Gamma(k, \theta)\)
    m Double
  gamma k theta = draw (gammaDistr k theta)

  -- | Draw from a beta distribution.
  beta ::
    -- | shape α
    Double ->
    -- | shape β
    Double ->
    -- | \(\sim \mathrm{Beta}(\alpha, \beta)\)
    m Double
  beta alpha beta' = draw (betaDistr alpha beta')

  -- | Draw from a Bernoulli distribution.
  bernoulli ::
    -- | probability p
    Double ->
    -- | \(\sim \mathrm{B}(1, p)\)
    m Bool
  bernoulli p = (< p) <$> random

  -- | Draw from a categorical distribution.
  --
  -- Calls 'error' when the vector of probabilities is empty.
  categorical ::
    Vector v Double =>
    -- | event probabilities
    v Double ->
    -- | outcome category
    m Int
  categorical ps
    | VG.null ps = error "empty input list"
    | otherwise = fromPMF (ps !)

  -- | Draw from a categorical distribution in the log domain.
  logCategorical ::
    (Vector v (Log Double), Vector v Double) =>
    -- | event probabilities
    v (Log Double) ->
    -- | outcome category
    m Int
  -- 'ln' unwraps the stored logarithm and 'exp' turns it back into a
  -- plain probability before delegating to 'categorical'.
  logCategorical logPs = categorical (VG.map (exp . ln) logPs)

  -- | Draw from a discrete uniform distribution.
  uniformD ::
    -- | observable outcomes @xs@
    [a] ->
    -- | \(\sim \mathcal{U}\{\mathrm{xs}\}\)
    m a
  uniformD xs = (xs !!) <$> categorical (V.replicate n (1 / fromIntegral n))
    where
      n = Prelude.length xs

  -- | Draw from a geometric distribution.
  geometric ::
    -- | success rate p
    Double ->
    -- | \(\sim\) number of failed Bernoulli trials with success probability p before first success
    m Int
  geometric p = discrete (geometric0 p)

  -- | Draw from a Poisson distribution.
  poisson ::
    -- | parameter λ
    Double ->
    -- | \(\sim \mathrm{Pois}(\lambda)\)
    m Int
  poisson rate = discrete (Poisson.poisson rate)

  -- | Draw from a Dirichlet distribution.
  dirichlet ::
    Vector v Double =>
    -- | concentration parameters @as@
    v Double ->
    -- | \(\sim \mathrm{Dir}(\mathrm{as})\)
    m (v Double)
  -- One gamma draw (scale 1) per concentration parameter, then rescale
  -- the draws so that they sum to one.
  dirichlet alphas = normalise <$> VG.mapM (`gamma` 1) alphas
    where
      normalise draws =
        let total = VG.sum draws
         in VG.map (/ total) draws

-- | Draw from a continuous distribution by pushing a 'random' draw through
-- the distribution's quantile function (the inverse cumulative
-- distribution function).
draw :: (ContDistr d, MonadDistribution m) => d -> m Double
draw dist = quantile dist <$> random

-- | Draw from a discrete distribution using a sequence of draws from
-- Bernoulli.
--
-- Indices are tried in increasing order starting at 0. Index @k@ is
-- accepted with probability @pmf k@ divided by the mass left over after
-- the indices that were already rejected.
--
-- Calls 'error' if an individual mass lies outside @[0, 1]@ or if the
-- left-over mass has become negative.
fromPMF :: MonadDistribution m => (Int -> Double) -> m Int
fromPMF pmf = go 0 1
  where
    -- @remaining@ is 1 minus the total mass of the rejected indices.
    go k remaining = do
      when (remaining < 0) $ error "fromPMF: total PMF above 1"
      let mass = pmf k
      when (mass < 0 || mass > 1) $ error "fromPMF: invalid probability value"
      accepted <- bernoulli (mass / remaining)
      if accepted
        then pure k
        else go (k + 1) (remaining - mass)

-- | Draw from a discrete distribution using its probability mass function.
discrete :: (DiscreteDistr d, MonadDistribution m) => d -> m Int
discrete dist = fromPMF (probability dist)

-- | Monads that can score different execution paths.
--
-- The class has a single method, 'score', with no default implementation.
class Monad m => MonadFactor m where
  -- | Record a likelihood.
  score ::
    -- | likelihood of the execution path, given as a 'Log' 'Double'
    Log Double ->
    m ()

-- | Synonym for 'score': records the given likelihood.
factor ::
  MonadFactor m =>
  -- | likelihood of the execution path
  Log Double ->
  m ()
factor weight = score weight

-- | Synonym for pretty type signatures: a value that can be run in any
-- 'MonadDistribution'. Note that @A -> Distribution B@ won't work as
-- intended: for that, use 'Kernel'.
-- Also note that the use of RankNTypes means performance may take a hit:
-- really the main point of these signatures is didactic.
type Distribution a = forall m. MonadDistribution m => m a

-- | Like 'Distribution', but for values that can be run in any
-- 'MonadMeasure'.
type Measure a = forall m. MonadMeasure m => m a

-- | A function from @a@ to a computation yielding @b@ that can be run in
-- any 'MonadMeasure'.
type Kernel a b = forall m. MonadMeasure m => a -> m b

-- | Hard conditioning: scores the execution path with 1 when the
-- condition holds and with 0 otherwise.
condition :: MonadFactor m => Bool -> m ()
condition holds = score weight
  where
    weight
      | holds = 1
      | otherwise = 0

-- | Run the given action @n@ times and collect the results in a list.
-- This is 'replicateM' under a more descriptive name.
independent :: Applicative m => Int -> m a -> m [a]
independent n action = replicateM n action

-- | Monads that support both sampling and scoring.
--
-- The class declares no methods of its own; it only combines the
-- 'MonadDistribution' and 'MonadFactor' superclasses.
class (MonadDistribution m, MonadFactor m) => MonadMeasure m

-- | Probability density function of the normal distribution, returned in
-- the log domain.
normalPdf ::
  -- | mean μ
  Double ->
  -- | standard deviation σ
  Double ->
  -- | sample x
  Double ->
  -- | relative likelihood of observing sample x in \(\mathcal{N}(\mu, \sigma^2)\)
  Log Double
normalPdf mu sigma = Exp . logDensity (normalDistr mu sigma)

-- | Draw from a multivariate normal distribution.
--
-- Takes one standard-normal draw per entry of the mean vector, multiplies
-- the draws by the Cholesky factor of the second argument and adds the
-- mean vector.
mvNormal :: MonadDistribution m => V.Vector Double -> Matrix Double -> m (V.Vector Double)
mvNormal mu bigSigma = transform <$> replicateM (length mu) (normal 0 1)
  where
    cholesky = cholDecomp bigSigma
    -- Result is the single column of @mu + L * z@.
    transform zs =
      getCol 1 (colVector mu + multStd cholesky (colVector (V.fromList zs)))

-- | A useful datatype for expressing Bayesian models over a latent
-- variable @z@ and observations @o@.
data Bayesian m z o = Bayesian
  { -- | prior over latent variable Z; p(z)
    prior :: m z,
    -- | distribution over observations given Z=z; p(o|Z=z)
    generative :: z -> m o,
    -- | p(o|z)
    likelihood :: z -> o -> Log Double
  }

-- | Posterior over the latent variable given observations: p(z|o).
--
-- Draws @z@ from the prior and scores it with the product of the
-- likelihoods of all observations under @z@.
posterior :: (MonadMeasure m, Foldable f, Functor f) => Bayesian m z o -> f o -> m z
posterior model os = do
  z <- prior model
  factor (product (likelihood model z <$> os))
  pure z

-- | Distribution over observations before seeing any data: draw a latent
-- value from the prior, then an observation from the generative model.
priorPredictive :: Monad m => Bayesian m a b -> m b
priorPredictive bm = do
  z <- prior bm
  generative bm z

-- | Distribution over new observations given observed data: draw a
-- latent value from 'posterior', then an observation from the generative
-- model.
posteriorPredictive ::
  (MonadMeasure m, Foldable f, Functor f) =>
  Bayesian m a b ->
  f b ->
  m b
posteriorPredictive bm os = do
  z <- posterior bm os
  generative bm z

-- helper funcs
--------------------

-- | Histogram type used by 'histogram' and 'histogramToList': 'H.BinD'
-- bins with 'Double' values.
type Histogram = H.Histogram H.BinD Double

-- | Build a weighted histogram from @(position, weight)@ pairs.
--
-- The bins are created with 'H.binD' from the smallest position, the
-- requested bin count @n@ and the largest position. Partial: the input
-- list must be non-empty, because 'Prelude.minimum' and 'Prelude.maximum'
-- fail on an empty list.
histogram :: Int -> [(Double, Log Double)] -> Histogram
histogram n samples = H.fillBuilder builder (second toWeight <$> samples)
  where
    positions = fst <$> samples
    bins = H.binD (Prelude.minimum positions) n (Prelude.maximum positions)
    builder = H.mkWeighted bins
    -- 'exp' acts on the 'Log' value, so 'ln' of the result is the weight
    -- as a plain 'Double'.
    toWeight = ln . exp

-- | Convert a 'Histogram' to a list of @(bin value, content)@ pairs via
-- 'H.asList'.
histogramToList :: Histogram -> [(Double, Double)]
histogramToList hist = H.asList hist

----------------------------------------------------------------------------
-- Instances that lift probabilistic effects to standard transformers.

-- | Sampling in 'IdentityT' delegates 'random' and 'bernoulli' to the
-- underlying monad; the remaining methods use the class defaults.
instance MonadDistribution m => MonadDistribution (IdentityT m) where
  random = lift random
  bernoulli p = lift (bernoulli p)

-- | Scoring in 'IdentityT' delegates to the underlying monad.
instance MonadFactor m => MonadFactor (IdentityT m) where
  score weight = lift (score weight)

instance MonadMeasure m => MonadMeasure (IdentityT m)

-- | Sampling in 'ExceptT' delegates 'random' and 'uniformD' to the
-- underlying monad; the remaining methods use the class defaults.
instance MonadDistribution m => MonadDistribution (ExceptT e m) where
  random = lift random
  uniformD xs = lift (uniformD xs)

-- | Scoring in 'ExceptT' delegates to the underlying monad.
instance MonadFactor m => MonadFactor (ExceptT e m) where
  score weight = lift (score weight)

instance MonadMeasure m => MonadMeasure (ExceptT e m)

-- | Sampling in 'ReaderT' delegates 'random' and 'bernoulli' to the
-- underlying monad; the remaining methods use the class defaults.
instance MonadDistribution m => MonadDistribution (ReaderT r m) where
  random = lift random
  bernoulli p = lift (bernoulli p)

-- | Scoring in 'ReaderT' delegates to the underlying monad.
instance MonadFactor m => MonadFactor (ReaderT r m) where
  score weight = lift (score weight)

instance MonadMeasure m => MonadMeasure (ReaderT r m)

-- | Sampling in 'WriterT' delegates 'random', 'bernoulli' and
-- 'categorical' to the underlying monad; the remaining methods use the
-- class defaults.
instance (Monoid w, MonadDistribution m) => MonadDistribution (WriterT w m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ps = lift (categorical ps)

-- | Scoring in 'WriterT' delegates to the underlying monad.
instance (Monoid w, MonadFactor m) => MonadFactor (WriterT w m) where
  score weight = lift (score weight)

instance (Monoid w, MonadMeasure m) => MonadMeasure (WriterT w m)

-- | Sampling in 'StateT' delegates 'random', 'bernoulli', 'categorical'
-- and 'uniformD' to the underlying monad; the remaining methods use the
-- class defaults.
instance MonadDistribution m => MonadDistribution (StateT s m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ps = lift (categorical ps)
  uniformD xs = lift (uniformD xs)

-- | Scoring in 'StateT' delegates to the underlying monad.
instance MonadFactor m => MonadFactor (StateT s m) where
  score weight = lift (score weight)

instance MonadMeasure m => MonadMeasure (StateT s m)

-- | Sampling in 'ListT' delegates 'random', 'bernoulli' and 'categorical'
-- to the underlying monad; the remaining methods use the class defaults.
instance MonadDistribution m => MonadDistribution (ListT m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ps = lift (categorical ps)

-- | Scoring in 'ListT' delegates to the underlying monad.
instance MonadFactor m => MonadFactor (ListT m) where
  score weight = lift (score weight)

instance MonadMeasure m => MonadMeasure (ListT m)

-- | Sampling in 'ContT' delegates 'random' to the underlying monad; every
-- other method uses the class default.
instance MonadDistribution m => MonadDistribution (ContT r m) where
  random = lift random

-- | Scoring in 'ContT' delegates to the underlying monad.
instance MonadFactor m => MonadFactor (ContT r m) where
  score weight = lift (score weight)

instance MonadMeasure m => MonadMeasure (ContT r m)