-- |
-- Module      :  Statistics.Distribution.Dirichlet
-- Description :  Multivariate Dirichlet distribution
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Tue Oct 20 10:10:39 2020.
module Statistics.Distribution.Dirichlet
  ( -- * Dirichlet distribution
    DirichletDistribution (ddGetParameters),
    dirichletDistribution,
    dirichletDensity,
    dirichletSample,

    -- * Symmetric Dirichlet distribution
    DirichletDistributionSymmetric (ddSymGetParameter),
    dirichletDistributionSymmetric,
    dirichletDensitySymmetric,
    dirichletSampleSymmetric,
  )
where

import qualified Data.Vector.Unboxed as V
import Numeric.Log
import Numeric.SpecFunctions
import System.Random.MWC.Distributions
import System.Random.Stateful

-- | The [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution).
--
-- Values are built with the smart constructor 'dirichletDistribution', which
-- validates the parameters and caches both the dimension and the
-- normalization constant (the inverse multivariate beta function of the
-- parameters).
data DirichletDistribution = DirichletDistribution
  { -- | The parameter vector.
    ddGetParameters :: V.Vector Double
  , -- Dimension, i.e., the number of parameters.
    _getDimension :: Int
  , -- Cached normalization constant.
    _getNormConst :: Log Double
  }
  deriving (Show, Eq)

-- Check if any element of the vector is negative or zero, that is, if the
-- vector is NOT strictly positive. NaN elements are not flagged, because every
-- ordered comparison involving NaN is False.
isNegativeOrZero :: V.Vector Double -> Bool
isNegativeOrZero = V.any isNonPositive
  where
    isNonPositive x = x <= 0

-- Inverse multivariate beta function, 1 / B(v), computed in log space as
-- @logGamma (sum v) - sum (map logGamma v)@. Does not check if parameters are
-- valid!
invBeta :: V.Vector Double -> Log Double
invBeta v = Exp (logGammaOfSum - sumOfLogGammas)
  where
    sumOfLogGammas = V.sum (V.map logGamma v)
    logGammaOfSum = logGamma (V.sum v)

-- | Create a Dirichlet distribution from the given parameter vector.
--
-- Return Left if:
--
-- - The parameter vector has fewer than two elements.
--
-- - One or more parameters are negative or zero.
dirichletDistribution :: V.Vector Double -> Either String DirichletDistribution
dirichletDistribution v
  | dim < 2 =
      Left "dirichletDistribution: Parameter vector is too short."
  | isNegativeOrZero v =
      Left "dirichletDistribution: One or more parameters are negative or zero."
  | otherwise =
      Right (DirichletDistribution v dim (invBeta v))
  where
    dim = V.length v

-- Absolute tolerance used by 'isNormalized' when checking that a value vector
-- sums to 1.0.
eps :: Double
eps = 1.0e-14

-- Check if the elements of the vector sum to 1.0 within the absolute tolerance
-- 'eps'.
--
-- The check is phrased as @<= eps@ so that a NaN sum (e.g., from a NaN
-- element) is rejected. With the former "@> eps@ then False, otherwise True"
-- formulation, NaN fell through to True because every comparison with NaN is
-- False.
isNormalized :: V.Vector Double -> Bool
isNormalized v = abs (V.sum v - 1.0) <= eps

-- | Density of the Dirichlet distribution evaluated at a given value vector.
--
-- Return 0 if:
--
-- - The value vector has a different length than the parameter vector.
--
-- - The value vector has elements being negative or zero.
--
-- - The value vector does not sum to 1.0 (with tolerance @eps = 1e-14@).
--
-- Otherwise, return the cached normalization constant times
-- @prod_i x_i^(a_i - 1)@. The product is accumulated in log space as
-- @sum_i (a_i - 1) * log x_i@, so very small densities stay representable in
-- the 'Log' domain instead of underflowing to zero.
dirichletDensity :: DirichletDistribution -> V.Vector Double -> Log Double
dirichletDensity (DirichletDistribution alphas k c) xs
  | k /= V.length xs = 0
  | isNegativeOrZero xs = 0
  | not (isNormalized xs) = 0
  | otherwise = c * Exp logXsPow
  where
    -- Use (a - 1) * log x instead of log (x ** (a - 1)). For small x the
    -- power can underflow to 0 (or overflow to Infinity when a < 1), which
    -- would make the log -Infinity (or Infinity) although the log-density
    -- itself is finite.
    logXsPow = V.sum $ V.zipWith (\a x -> (a - 1.0) * log x) alphas xs

-- | Sample a value vector from the Dirichlet distribution.
--
-- For each parameter, in order, one variate is drawn with @gamma a 1.0@.
-- The resulting vector is then divided by its sum.
dirichletSample :: StatefulGen g m => DirichletDistribution -> g -> m (V.Vector Double)
dirichletSample (DirichletDistribution alphas _ _) gen =
  normalize <$> V.mapM (\alpha -> gamma alpha 1.0 gen) alphas
  where
    normalize draws =
      let total = V.sum draws
       in V.map (/ total) draws

-- | See 'DirichletDistribution' but with parameter vector @replicate DIM VAL@.
--
-- Values are built with the smart constructor
-- 'dirichletDistributionSymmetric', which validates the dimension and the
-- parameter and caches the normalization constant.
data DirichletDistributionSymmetric = DirichletDistributionSymmetric
  { -- | The parameter VAL shared by all components.
    ddSymGetParameter :: Double
  , -- The dimension DIM.
    _symGetDimension :: Int
  , -- Cached normalization constant.
    _symGetNormConst :: Log Double
  }
  deriving (Show, Eq)

-- Inverse multivariate beta function of the symmetric parameter vector
-- @replicate k a@, computed in log space as
-- @logGamma (k * a) - k * logGamma a@. Does not check if parameters are valid!
invBetaSym :: Int -> Double -> Log Double
invBetaSym k a = Exp (logGamma (kD * a) - kD * logGamma a)
  where
    kD = fromIntegral k :: Double

-- | Create a symmetric Dirichlet distribution of given dimension and parameter.
--
-- Return Left if:
--
-- - The given dimension is smaller than two.
--
-- - The parameter is negative or zero.
dirichletDistributionSymmetric :: Int -> Double -> Either String DirichletDistributionSymmetric
dirichletDistributionSymmetric k a
  | k < 2 =
      Left "dirichletDistributionSymmetric: The dimension is smaller than two."
  | a <= 0 =
      Left "dirichletDistributionSymmetric: The parameter is negative or zero."
  | otherwise =
      Right $ DirichletDistributionSymmetric a k normConst
  where
    normConst = invBetaSym k a

-- | Density of the symmetric Dirichlet distribution evaluated at a given value
-- vector.
--
-- Return 0 if:
--
-- - The value vector has a different dimension.
--
-- - The value vector has elements being negative or zero.
--
-- - The value vector does not sum to 1.0 (with tolerance @eps = 1e-14@).
--
-- Otherwise, return the cached normalization constant times
-- @prod_i x_i^(a - 1)@. The product is accumulated in log space as
-- @(a - 1) * sum_i log x_i@, so very small densities stay representable in
-- the 'Log' domain instead of underflowing to zero.
dirichletDensitySymmetric :: DirichletDistributionSymmetric -> V.Vector Double -> Log Double
dirichletDensitySymmetric (DirichletDistributionSymmetric a k c) xs
  | k /= V.length xs = 0
  | isNegativeOrZero xs = 0
  | not (isNormalized xs) = 0
  | otherwise = c * Exp logXsPow
  where
    -- Use (a - 1) * log x instead of log (x ** (a - 1)). For small x the
    -- power can underflow to 0 (or overflow to Infinity when a < 1), which
    -- would make the log -Infinity (or Infinity) although the log-density
    -- itself is finite. The exponent is shared, so it is factored out of the
    -- sum.
    logXsPow = (a - 1.0) * V.sum (V.map log xs)

-- | Sample a value vector from the symmetric Dirichlet distribution.
--
-- Draws DIM variates with @gamma VAL 1.0@ and divides the resulting vector by
-- its sum.
dirichletSampleSymmetric ::
  StatefulGen g m =>
  DirichletDistributionSymmetric ->
  g ->
  m (V.Vector Double)
dirichletSampleSymmetric (DirichletDistributionSymmetric alpha dim _) gen =
  normalize <$> V.replicateM dim (gamma alpha 1.0 gen)
  where
    normalize draws =
      let total = V.sum draws
       in V.map (/ total) draws