{-# LANGUAGE BangPatterns #-}

-- |
-- Module      :  Mcmc.Prior
-- Description :  Types and convenience functions for computing priors
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Thu Jul 23 13:26:14 2020.
module Mcmc.Prior
  ( Prior,
    PriorG,
    PriorFunction,
    PriorFunctionG,

    -- * Improper priors
    noPrior,
    greaterThan,
    positive,
    lessThan,
    negative,

    -- * Continuous priors
    exponential,
    gamma,
    gammaMeanVariance,
    gammaMeanOne,
    gammaShapeScaleToMeanVariance,
    gammaMeanVarianceToShapeScale,
    logNormal,
    normal,
    uniform,

    -- * Discrete priors
    poisson,

    -- * Auxiliary functions
    product',
  )
where

import Control.Monad
import Data.Maybe (fromMaybe)
import Data.Typeable
import Mcmc.Internal.SpecFunctions
import Mcmc.Statistics.Types
import Numeric.Log

-- TODO (high): Think about using a "structure" variable.
--
-- For example,
--
-- type PriorFunctionG s a = s a -> PriorG a
--
-- Many of the prior functions would need to use the Identity Functor. This may
-- be slow.

-- | Prior values are stored in log domain.
--
-- This is 'PriorG' specialized to 'Double'.
type Prior = PriorG Double

-- | Generalized prior.
--
-- A prior value over an arbitrary numeric type @a@, wrapped in 'Log' so that it
-- is stored in log domain.
type PriorG a = Log a

-- | Prior function.
--
-- Maps a value of type @a@ to a 'Prior'; this is 'PriorFunctionG' with the
-- result type fixed to 'Double'.
type PriorFunction a = PriorFunctionG a Double

-- | Generalized prior function.
--
-- Maps a value of type @a@ to a prior over the numeric type @b@.
type PriorFunctionG a b = a -> PriorG b

-- | Flat prior function. Useful for testing and debugging.
--
-- The argument is ignored; the result is always a prior of 1.0.
noPrior :: RealFloat b => PriorFunctionG a b
noPrior = const 1.0
{-# SPECIALIZE noPrior :: PriorFunction Double #-}

-- | Improper uniform prior; strictly greater than a given value.
--
-- The prior is 1.0 if the value is strictly greater than the lower boundary,
-- and 0.0 otherwise (including when the value equals the boundary).
greaterThan :: RealFloat a => LowerBoundary a -> PriorFunctionG a a
greaterThan a x
  | x > a = 1.0
  | otherwise = 0.0
{-# SPECIALIZE greaterThan :: Double -> PriorFunction Double #-}

-- | Improper uniform prior; strictly greater than zero.
--
-- See 'greaterThan'; zero itself gets a prior of 0.0.
positive :: RealFloat a => PriorFunctionG a a
positive = greaterThan 0
{-# SPECIALIZE positive :: PriorFunction Double #-}

-- | Improper uniform prior; strictly less than a given value.
--
-- The prior is 1.0 if the value is strictly less than the upper boundary, and
-- 0.0 otherwise (including when the value equals the boundary).
lessThan :: RealFloat a => UpperBoundary a -> PriorFunctionG a a
lessThan a x
  | x < a = 1.0
  | otherwise = 0.0
{-# SPECIALIZE lessThan :: Double -> PriorFunction Double #-}

-- | Improper uniform prior; strictly less than zero.
--
-- See 'lessThan'; zero itself gets a prior of 0.0.
negative :: RealFloat a => PriorFunctionG a a
negative = lessThan 0.0
{-# SPECIALIZE negative :: PriorFunction Double #-}

-- | Exponential distributed prior.
--
-- For a positive value @x@ the result is the rate multiplied by
-- @exp (negate rate * x)@, computed in log domain. Values at or below zero get
-- a prior of 0.0.
--
-- Call 'error' if the rate is zero or negative.
exponential :: RealFloat a => Rate a -> PriorFunctionG a a
exponential l x
  | l <= 0.0 = error "exponential: Rate is zero or negative."
  | x <= 0.0 = 0.0
  | otherwise = ll * Exp (negate l * x)
  where
    -- The rate, converted to log domain.
    ll = Exp $ log l
{-# SPECIALIZE exponential :: Double -> PriorFunction Double #-}

-- | Gamma distributed prior.
--
-- The distribution is parametrized by shape @k@ and scale @t@. Values at or
-- below zero get a prior of 0.0. For positive values, the log density
--
-- @log x * (k - 1) - x / t - logGamma k - k * log t@
--
-- is computed directly in log domain.
--
-- Call 'error' if the shape or scale are zero or negative.
gamma :: (Typeable a, RealFloat a) => Shape a -> Scale a -> PriorFunctionG a a
gamma k t x
  | k <= 0.0 = error "gamma: Shape is zero or negative."
  | t <= 0.0 = error "gamma: Scale is zero or negative."
  | x <= 0.0 = 0.0
  | otherwise = Exp $ log x * (k - 1.0) - (x / t) - logGammaG k - log t * k
{-# SPECIALIZE gamma :: Double -> Double -> PriorFunction Double #-}

-- | See 'gamma' but parametrized using mean and variance.
--
-- The mean and variance are converted to shape and scale with
-- 'gammaMeanVarianceToShapeScale'; the error conditions of 'gamma' apply to the
-- converted values.
gammaMeanVariance :: (Typeable a, RealFloat a) => Mean a -> Variance a -> PriorFunctionG a a
gammaMeanVariance m v = gamma k t
  where
    (k, t) = gammaMeanVarianceToShapeScale m v
{-# SPECIALIZE gammaMeanVariance :: Double -> Double -> PriorFunction Double #-}

-- | Gamma distributed prior with given shape and mean 1.0.
--
-- The scale is set to the reciprocal of the shape, so that the product of
-- shape and scale (the mean) is one. See 'gamma' for the error conditions.
gammaMeanOne :: (Typeable a, RealFloat a) => Shape a -> PriorFunctionG a a
gammaMeanOne k = gamma k (recip k)
{-# SPECIALIZE gammaMeanOne :: Double -> PriorFunction Double #-}

-- The mean and variance of the gamma distribution are
--
-- m = k*t
--
-- v = k*t*t
--
-- Hence, the shape and scale are
--
-- k = m^2/v
--
-- t = v/m

-- | Calculate mean and variance of the gamma distribution given the shape and
-- the scale.
--
-- With shape @k@ and scale @t@, the mean is @k * t@ and the variance is
-- @k * t * t@. The arguments are not validated.
gammaShapeScaleToMeanVariance :: Num a => Shape a -> Scale a -> (Mean a, Variance a)
gammaShapeScaleToMeanVariance k t = let m = k * t in (m, m * t)
{-# SPECIALIZE gammaShapeScaleToMeanVariance :: Double -> Double -> (Double, Double) #-}

-- | Calculate shape and scale of the gamma distribution given the mean and
-- the variance.
--
-- With mean @m@ and variance @v@, the shape is @m * m / v@ and the scale is
-- @v / m@. The arguments are not validated; in particular, no check is made
-- that the mean and variance are non-zero.
gammaMeanVarianceToShapeScale :: Fractional a => Mean a -> Variance a -> (Shape a, Scale a)
gammaMeanVarianceToShapeScale m v = (m * m / v, v / m)
{-# SPECIALIZE gammaMeanVarianceToShapeScale :: Double -> Double -> (Double, Double) #-}

-- Natural logarithm of the square root of 2*pi; the normalization constant
-- shared by 'normal' and 'logNormal'.
mLnSqrt2Pi :: RealFloat a => a
mLnSqrt2Pi = 0.9189385332046727417803297364056176398613974736377834128171
{-# INLINE mLnSqrt2Pi #-}

-- | Log normal distributed prior.
--
-- NOTE: The log normal distribution is parametrized with the mean \(\mu\) and
-- the standard deviation \(\sigma\) of the underlying normal distribution. The
-- mean and variance of the log normal distribution itself are functions of
-- \(\mu\) and \(\sigma\), but are not the same as \(\mu\) and \(\sigma\)!
--
-- Values at or below zero get a prior of 0.0.
--
-- Call 'error' if the standard deviation is zero or negative.
logNormal :: RealFloat a => Mean a -> StandardDeviation a -> PriorFunctionG a a
logNormal m s x
  | s <= 0.0 = error "logNormal: Standard deviation is zero or negative."
  | x <= 0.0 = 0.0
  | otherwise = Exp $ t + e
  where
    -- Log of the normalization term: -(ln (sqrt (2*pi)) + ln (x * s)).
    t = negate $ mLnSqrt2Pi + log (x * s)
    -- 1 / (2 * s^2).
    a = recip $ 2.0 * s * s
    -- Deviation of ln x from the mean of the underlying normal distribution.
    b = log x - m
    -- Exponent of the density: -(ln x - m)^2 / (2 * s^2).
    e = negate $ a * b * b
{-# SPECIALIZE logNormal :: Double -> Double -> PriorFunction Double #-}

-- | Normal distributed prior.
--
-- The distribution is parametrized by mean and standard deviation. The log
-- density is computed directly in log domain.
--
-- Call 'error' if the standard deviation is zero or negative.
normal :: RealFloat a => Mean a -> StandardDeviation a -> PriorFunctionG a a
normal m s x
  | s <= 0 = error "normal: Standard deviation is zero or negative."
  | otherwise = Exp $ (-xm * xm / (2 * s * s)) - denom
  where
    -- Deviation of the value from the mean.
    xm = x - m
    -- Log of the normalization constant: ln (sqrt (2*pi)) + ln s.
    denom = mLnSqrt2Pi + log s
{-# SPECIALIZE normal :: Double -> Double -> PriorFunction Double #-}

-- | Uniform prior on [a, b].
--
-- Both boundaries are included. The density is not normalized by the length of
-- the interval: values within the interval get a prior of 1.0, values outside
-- get a prior of 0.0.
--
-- Call 'error' if the lower boundary is greater than the upper boundary.
uniform :: RealFloat a => LowerBoundary a -> UpperBoundary a -> PriorFunctionG a a
uniform a b x
  | a > b = error "uniform: Lower boundary is greater than upper boundary."
  | x < a = 0.0
  | x > b = 0.0
  | otherwise = 1.0
{-# SPECIALIZE uniform :: Double -> Double -> PriorFunction Double #-}

-- | Poisson distributed prior.
--
-- Negative counts get a prior of 0.0. For non-negative counts @n@, the log
-- probability mass
--
-- @log l * n - logFactorial n - l@
--
-- is computed directly in log domain.
--
-- Call 'error' if the rate is zero or negative.
poisson :: (RealFloat a, Typeable a) => Rate a -> PriorFunctionG Int a
poisson l n
  -- A rate of exactly zero has to be rejected as well: otherwise @log l@ is
  -- negative infinity and the product with a count of zero is NaN.
  | l <= 0.0 = error "poisson: Rate is zero or negative."
  | n < 0 = 0.0
  | otherwise = Exp $ log l * fromIntegral n - logFactorialG n - l
{-# SPECIALIZE poisson :: Double -> PriorFunction Int #-}

-- | Intelligent product that stops when encountering a zero.
--
-- The result is zero as soon as the running product becomes zero; the
-- remaining elements are not multiplied. The product of the empty list is 1.0.
--
-- Use with care because the running product is compared against zero at every
-- step, and this can take some time if the list is long and does not contain
-- any zeroes.
product' :: RealFloat a => [Log a] -> Log a
product' = fromMaybe 0 . prodM
{-# SPECIALIZE product' :: [Log Double] -> Log Double #-}

-- Monadic product which short-circuits with 'Nothing' when the accumulated
-- product is zero before an element is multiplied in. If only the last element
-- is zero, the result is 'Just' zero.
--
-- The accumulator is forced at every step (bang pattern).
--
-- The result type could be generalized from 'Maybe' to any 'MonadPlus'.
prodM :: RealFloat a => [Log a] -> Maybe (Log a)
prodM = foldM (\ !acc x -> (acc * x) <$ guard (acc /= 0.0)) 1.0
{-# SPECIALIZE prodM :: [Log Double] -> Maybe (Log Double) #-}