{-# LANGUAGE BangPatterns #-}

-- |
-- Module      :  Mcmc.Prior
-- Description :  Types and convenience functions for computing priors
-- Copyright   :  (c) Dominik Schrempf, 2021
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Thu Jul 23 13:26:14 2020.
module Mcmc.Prior
  ( Prior,
    PriorFunction,
    noPrior,

    -- * Improper priors
    largerThan,
    positive,
    lowerThan,
    negative,

    -- * Continuous priors
    exponential,
    gamma,
    gammaMeanVariance,
    gammaMeanOne,
    gammaShapeScaleToMeanVariance,
    gammaMeanVarianceToShapeScale,
    normal,
    uniform,

    -- * Discrete priors
    poisson,

    -- * Auxiliary functions
    product',
  )
where

import Control.Monad
import Data.Maybe (fromMaybe)
import Mcmc.Statistics.Types
import Numeric.Log
import qualified Statistics.Distribution as S
import qualified Statistics.Distribution.Exponential as S
import qualified Statistics.Distribution.Gamma as S
import qualified Statistics.Distribution.Normal as S
import qualified Statistics.Distribution.Poisson as S

-- | Prior values are stored in log domain.
--
-- This is 'Log' from "Numeric.Log" specialized to 'Double'. For example,
-- @Exp 0@ represents the (non-log) value 1, and the literal @0@ represents a
-- prior value of zero.
type Prior = Log Double

-- | Prior function; maps a value of type @a@ to its (log-domain) 'Prior'.
type PriorFunction a = a -> Prior

-- | Flat prior function that assigns the value 1 to every argument.
--
-- Useful for testing and debugging.
noPrior :: PriorFunction a
noPrior _ = 1.0

-- | Improper uniform prior; strictly larger than a given value.
--
-- Returns 0 when the argument is less than or equal to the boundary @a@, and 1
-- otherwise.
largerThan :: LowerBoundary -> PriorFunction Double
largerThan a x = if x <= a then 0 else 1

-- | Improper uniform prior; strictly larger than zero.
--
-- Equivalent to @'largerThan' 0@.
positive :: PriorFunction Double
positive x = largerThan 0 x

-- | Improper uniform prior; strictly lower than a given value.
--
-- Returns 0 when the argument is greater than or equal to the boundary @b@,
-- and 1 otherwise.
lowerThan :: UpperBoundary -> PriorFunction Double
lowerThan b x = if x >= b then 0 else 1

-- | Improper uniform prior; strictly lower than zero.
--
-- Equivalent to @'lowerThan' 0@.
negative :: PriorFunction Double
negative x = lowerThan 0 x

-- | Exponential distributed prior with the given rate.
--
-- The log density of the distribution is evaluated at the argument and wrapped
-- directly into the log domain, so no exponentiation happens here.
exponential :: Rate -> PriorFunction Double
exponential l = \x -> Exp (S.logDensity dist x)
  where
    -- Built once per partial application @exponential l@ and shared.
    dist = S.exponential l

-- | Gamma distributed prior with the given shape and scale.
--
-- The log density of the distribution is evaluated at the argument and wrapped
-- directly into the log domain.
gamma :: Shape -> Scale -> PriorFunction Double
gamma k t = \x -> Exp (S.logDensity dist x)
  where
    -- Built once per partial application @gamma k t@ and shared.
    dist = S.gammaDistr k t

-- | Gamma distributed prior parametrized by mean and variance; see 'gamma'.
--
-- The mean and variance are converted to shape and scale with
-- 'gammaMeanVarianceToShapeScale'.
gammaMeanVariance :: Mean -> Variance -> PriorFunction Double
gammaMeanVariance m v = gamma shape scale
  where
    (shape, scale) = gammaMeanVarianceToShapeScale m v

-- | Gamma distributed prior with given shape and mean 1.0.
--
-- The scale is the reciprocal of the shape, so that the mean
-- @shape * scale@ equals 1.
gammaMeanOne :: Shape -> PriorFunction Double
gammaMeanOne k = gamma k (recip k)

-- The mean and variance of the gamma distribution with shape k and scale t are
--
-- m = k*t
--
-- v = k*t*t
--
-- Hence, the shape and scale are
--
-- k = m^2/v
--
-- t = v/m

-- | Calculate mean and variance of the gamma distribution given the shape and
-- the scale.
--
-- With shape @k@ and scale @t@, the mean is @k*t@ and the variance is
-- @(k*t)*t@.
gammaShapeScaleToMeanVariance :: Shape -> Scale -> (Mean, Variance)
gammaShapeScaleToMeanVariance k t = (mu, mu * t)
  where
    mu = k * t

-- | Calculate shape and scale of the gamma distribution given the mean and
-- the variance.
--
-- With mean @m@ and variance @v@, the shape is @m*m/v@ and the scale is @v/m@.
-- Division by zero is not guarded against.
gammaMeanVarianceToShapeScale :: Mean -> Variance -> (Shape, Scale)
gammaMeanVarianceToShapeScale m v = (shape, scale)
  where
    shape = m * m / v
    scale = v / m

-- | Normal distributed prior with the given mean and standard deviation.
--
-- The log density of the distribution is evaluated at the argument and wrapped
-- directly into the log domain.
normal :: Mean -> StandardDeviation -> PriorFunction Double
normal m s = \x -> Exp (S.logDensity dist x)
  where
    -- Built once per partial application @normal m s@ and shared.
    dist = S.normalDistr m s

-- | Uniform prior on the closed interval [a, b].
--
-- Returns 1 for arguments with @a <= x <= b@ and 0 for arguments outside the
-- interval. The boundaries themselves are included, as the interval notation
-- states (previously they were wrongly excluded).
--
-- The value inside the interval is 1 rather than @1 / (b - a)@, i.e., the prior
-- is unnormalized. This is presumably fine for MCMC as long as @a@ and @b@ are
-- fixed (the constant cancels in acceptance ratios) — confirm before comparing
-- prior values across different intervals.
--
-- If @a > b@ the interval is empty and the prior is 0 everywhere.
uniform :: LowerBoundary -> UpperBoundary -> PriorFunction Double
uniform a b x
  | x < a = 0
  | x > b = 0
  | otherwise = 1

-- | Poisson distributed prior with the given rate.
--
-- The log probability mass of the distribution is evaluated at the argument
-- and wrapped directly into the log domain.
poisson :: Rate -> PriorFunction Int
poisson l = \n -> Exp (S.logProbability dist n)
  where
    -- Built once per partial application @poisson l@ and shared.
    dist = S.poisson l

-- | Product of prior values that short-circuits on zero.
--
-- The list is folded from the left. Once the running product is zero, the
-- traversal stops and the result is 0.
--
-- Use with care: every partial product is compared against zero. This adds
-- overhead for long lists that contain no zeroes.
product' :: [Log Double] -> Log Double
product' xs = fromMaybe 0 (prodM xs)

-- | Multiply the elements from left to right, short-circuiting with 'Nothing'
-- as soon as the running product becomes zero.
--
-- Returns @Just p@ when the product @p@ of all elements is non-zero (@Just 1@
-- for the empty list), and 'Nothing' otherwise.
--
-- The type could be generalized to any 'MonadPlus'.
prodM :: [Log Double] -> Maybe (Log Double)
prodM = foldM step 1
  where
    -- Test the updated product rather than the previous accumulator, so that
    -- the fold stops on the zero element itself instead of one element later.
    step !acc x =
      let acc' = acc * x
       in acc' <$ guard (acc' /= 0)