{-# LANGUAGE BangPatterns #-}

-- |
-- Module      :  Mcmc.Prior
-- Description :  Types and convenience functions for computing priors
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Thu Jul 23 13:26:14 2020.
module Mcmc.Prior
  ( Prior,
    PriorFunction,
    PriorFunctionG,

    -- * Improper priors
    noPrior,
    greaterThan,
    positive,
    lessThan,
    negative,

    -- * Continuous priors
    exponential,
    gamma,
    gammaMeanVariance,
    gammaMeanOne,
    gammaShapeScaleToMeanVariance,
    gammaMeanVarianceToShapeScale,
    logNormal,
    normal,
    uniform,

    -- * Discrete priors
    poisson,

    -- * Auxiliary functions
    product',
  )
where

import Control.Monad
import Data.Maybe (fromMaybe)
import Data.Typeable
import Mcmc.Internal.SpecFunctions
import Mcmc.Statistics.Types
import Numeric.Log

-- | Generalized prior function: maps a value of type @a@ to a prior value in
-- log domain whose underlying floating point type is @b@.
type PriorFunctionG a b = a -> Log b

-- | Prior value, stored in log domain.
type Prior = Log Double

-- | Prior function returning 'Double' prior values in log domain; the special
-- case of 'PriorFunctionG' with @b ~ Double@.
type PriorFunction a = PriorFunctionG a Double

-- | Flat (improper) prior: every argument is assigned the prior value 1.0,
-- i.e., 0 in log domain. Useful for testing and debugging.
noPrior :: RealFloat b => PriorFunctionG a b
noPrior _ = 1.0
{-# SPECIALIZE noPrior :: PriorFunction Double #-}

-- | Improper uniform prior on values strictly greater than the given lower
-- boundary.
--
-- Returns 1.0 for @x > lower@ and 0.0 otherwise (including @x == lower@).
greaterThan :: RealFloat a => LowerBoundary a -> PriorFunctionG a a
greaterThan lower x = if x > lower then 1.0 else 0.0
{-# SPECIALIZE greaterThan :: Double -> PriorFunction Double #-}

-- | Improper uniform prior on the strictly positive numbers; 'greaterThan'
-- with lower boundary zero.
positive :: RealFloat a => PriorFunctionG a a
positive = greaterThan 0
{-# SPECIALIZE positive :: PriorFunction Double #-}

-- | Improper uniform prior on values strictly less than the given upper
-- boundary.
--
-- Returns 1.0 for @x < upper@ and 0.0 otherwise (including @x == upper@).
lessThan :: RealFloat a => UpperBoundary a -> PriorFunctionG a a
lessThan upper x = if x < upper then 1.0 else 0.0
{-# SPECIALIZE lessThan :: Double -> PriorFunction Double #-}

-- | Improper uniform prior on the strictly negative numbers; 'lessThan' with
-- upper boundary zero.
negative :: RealFloat a => PriorFunctionG a a
negative = lessThan 0.0
{-# SPECIALIZE negative :: PriorFunction Double #-}

-- | Exponentially distributed prior with density
-- \(f(x) = \lambda e^{-\lambda x}\) for \(x > 0\); non-positive values get
-- prior 0.
--
-- Call 'error' if the rate is zero or negative.
exponential :: RealFloat a => Rate a -> PriorFunctionG a a
exponential rate x
  | rate <= 0.0 = error "exponential: Rate is zero or negative."
  | x <= 0.0 = 0.0
  -- Multiplication in the log domain adds the exponents of both factors.
  | otherwise = Exp (log rate) * Exp (negate rate * x)
{-# SPECIALIZE exponential :: Double -> PriorFunction Double #-}

-- | Gamma distributed prior with shape \(k\) and scale \(\theta\); density
-- \(f(x) = x^{k-1} e^{-x/\theta} / (\Gamma(k) \theta^k)\) for \(x > 0\).
-- Non-positive values get prior 0.
--
-- Call 'error' if the shape or scale are zero or negative.
gamma :: (Typeable a, RealFloat a) => Shape a -> Scale a -> PriorFunctionG a a
gamma shape scale x
  | shape <= 0.0 = error "gamma: Shape is zero or negative."
  | scale <= 0.0 = error "gamma: Scale is zero or negative."
  | x <= 0.0 = 0.0
  | otherwise = Exp logDensity
  where
    logDensity =
      log x * (shape - 1.0) - (x / scale) - logGammaG shape - log scale * shape
{-# SPECIALIZE gamma :: Double -> Double -> PriorFunction Double #-}

-- | Like 'gamma', but parametrized by mean and variance. These are converted to
-- shape and scale with 'gammaMeanVarianceToShapeScale'.
gammaMeanVariance :: (Typeable a, RealFloat a) => Mean a -> Variance a -> PriorFunctionG a a
gammaMeanVariance m v = uncurry gamma (gammaMeanVarianceToShapeScale m v)
{-# SPECIALIZE gammaMeanVariance :: Double -> Double -> PriorFunction Double #-}

-- | Gamma distributed prior with the given shape and mean 1.0; the scale is
-- the reciprocal of the shape, so that shape * scale = 1.
gammaMeanOne :: (Typeable a, RealFloat a) => Shape a -> PriorFunctionG a a
gammaMeanOne shape = gamma shape scale
  where
    scale = recip shape
{-# SPECIALIZE gammaMeanOne :: Double -> PriorFunction Double #-}

-- The mean and variance of the gamma distribution with shape k and scale t are
--
-- m = k*t
--
-- v = k*t*t
--
-- Hence, the shape and scale are
--
-- k = m^2/v
--
-- t = v/m

-- | Convert shape \(k\) and scale \(t\) of the gamma distribution to its mean
-- \(k t\) and variance \(k t^2\).
gammaShapeScaleToMeanVariance :: Num a => Shape a -> Scale a -> (Mean a, Variance a)
gammaShapeScaleToMeanVariance k t = (mean, variance)
  where
    mean = k * t
    variance = mean * t
{-# SPECIALIZE gammaShapeScaleToMeanVariance :: Double -> Double -> (Double, Double) #-}

-- | Convert mean \(m\) and variance \(v\) of the gamma distribution to its
-- shape \(m^2 / v\) and scale \(v / m\).
gammaMeanVarianceToShapeScale :: Fractional a => Mean a -> Variance a -> (Shape a, Scale a)
gammaMeanVarianceToShapeScale m v = (shape, scale)
  where
    shape = m * m / v
    scale = v / m
{-# SPECIALIZE gammaMeanVarianceToShapeScale :: Double -> Double -> (Double, Double) #-}

-- | The constant \(\log \sqrt{2\pi}\), used by the normal and log normal
-- priors.
mLnSqrt2Pi :: RealFloat a => a
mLnSqrt2Pi = 0.9189385332046727417803297364056176398613974736377834128171
{-# INLINE mLnSqrt2Pi #-}

-- | Log normal distributed prior with density
-- \(f(x) = \frac{1}{x \sigma \sqrt{2\pi}} \exp\left(-\frac{(\log x - \mu)^2}{2\sigma^2}\right)\)
-- for \(x > 0\); non-positive values get prior 0.
--
-- NOTE: The log normal distribution is parametrized with the mean \(\mu\) and
-- the standard deviation \(\sigma\) of the underlying normal distribution. The
-- mean and variance of the log normal distribution itself are functions of
-- \(\mu\) and \(\sigma\), but are not the same as \(\mu\) and \(\sigma\)!
--
-- Call 'error' if the standard deviation is zero or negative.
logNormal :: RealFloat a => Mean a -> StandardDeviation a -> PriorFunctionG a a
logNormal m s x
  | s <= 0.0 = error "logNormal: Standard deviation is zero or negative."
  | x <= 0.0 = 0.0
  | otherwise = Exp $ t + e
  where
    -- Log of the normalization factor \(1 / (x \sigma \sqrt{2\pi})\).
    t = negate $ mLnSqrt2Pi + log (x * s)
    -- \(1 / (2\sigma^2)\).
    a = recip $ 2.0 * s * s
    -- Deviation of \(\log x\) from the mean of the underlying normal.
    b = log x - m
    -- Exponent of the Gaussian kernel.
    e = negate $ a * b * b
-- Specialize like all other continuous priors in this module.
{-# SPECIALIZE logNormal :: Double -> Double -> PriorFunction Double #-}

-- | Normal distributed prior with mean \(\mu\) and standard deviation
-- \(\sigma\); in log domain,
-- \(\log f(x) = -\frac{(x - \mu)^2}{2\sigma^2} - \log \sigma - \log \sqrt{2\pi}\).
--
-- Call 'error' if the standard deviation is zero or negative.
normal :: RealFloat a => Mean a -> StandardDeviation a -> PriorFunctionG a a
normal mu sigma x
  | sigma <= 0 = error "normal: Standard deviation is zero or negative."
  | otherwise = Exp (kernel - normalizer)
  where
    d = x - mu
    -- Exponent of the Gaussian kernel.
    kernel = negate (d * d / (2 * sigma * sigma))
    -- Log of the normalization constant \(\sigma \sqrt{2\pi}\).
    normalizer = mLnSqrt2Pi + log sigma
{-# SPECIALIZE normal :: Double -> Double -> PriorFunction Double #-}

-- | Uniform prior on the closed interval [a, b].
--
-- NOTE: The prior is not normalized. Values inside [a, b] (boundaries
-- included) get the prior value 1.0, not \(1/(b-a)\); values outside get 0.0.
--
-- Call 'error' if the lower boundary is greater than the upper boundary.
uniform :: RealFloat a => LowerBoundary a -> UpperBoundary a -> PriorFunctionG a a
uniform a b x
  | a > b = error "uniform: Lower boundary is greater than upper boundary."
  | x < a || x > b = 0.0
  | otherwise = 1.0
{-# SPECIALIZE uniform :: Double -> Double -> PriorFunction Double #-}

-- | Poisson distributed prior with rate \(\lambda\); probability mass
-- \(P(n) = \lambda^n e^{-\lambda} / n!\) for \(n \ge 0\). Negative counts get
-- prior 0.
--
-- Call 'error' if the rate is zero or negative.
poisson :: (RealFloat a, Typeable a) => Rate a -> PriorFunctionG Int a
poisson l n
  -- A zero rate must be rejected as well: otherwise @log 0 * 0@ yields NaN
  -- for @n == 0@.
  | l <= 0.0 = error "poisson: Rate is zero or negative."
  | n < 0 = 0.0
  | otherwise = Exp $ log l * fromIntegral n - logFactorialG n - l
{-# SPECIALIZE poisson :: Double -> PriorFunctionG Int Double #-}

-- | Product of log-domain values that stops early once the running product is
-- zero, in which case the result is 0.
--
-- Use with care: every step compares the running product with zero. This
-- adds overhead for long lists that contain no zeroes.
product' :: RealFloat a => [Log a] -> Log a
product' xs = fromMaybe 0 (prodM xs)
{-# SPECIALIZE product' :: [Log Double] -> Log Double #-}

-- Multiply log-domain values from left to right, starting at 1.0.
--
-- Before each multiplication the accumulated product is compared with zero;
-- if it is zero, the fold aborts with 'Nothing'. Otherwise the result is 'Just'
-- the full product. The final product is not checked, so a zero in the last
-- position still yields @Just 0@.
prodM :: RealFloat a => [Log a] -> Maybe (Log a)
prodM = foldM step 1.0
  where
    -- The accumulator is forced at every step, so no chain of multiplication
    -- thunks builds up.
    step !acc x
      | acc == 0.0 = Nothing
      | otherwise = Just (acc * x)
{-# SPECIALIZE prodM :: [Log Double] -> Maybe (Log Double) #-}