module Statistics.Distribution.Dirichlet
(
DirichletDistribution (ddGetParameters),
dirichletDistribution,
dirichletDensity,
dirichletSample,
DirichletDistributionSymmetric (ddSymGetParameter),
dirichletDistributionSymmetric,
dirichletDensitySymmetric,
dirichletSampleSymmetric,
)
where
import qualified Data.Vector.Unboxed as V
import Numeric.Log
import Numeric.SpecFunctions
import System.Random.MWC.Distributions
import System.Random.Stateful
-- | A Dirichlet distribution with a vector of concentration parameters
-- @alpha_1, ..., alpha_k@.
--
-- The constructor is not exported; build values with
-- 'dirichletDistribution', which validates the parameters and caches the
-- dimension and the normalizing constant so 'dirichletDensity' does not
-- recompute them. Fields are strict: the only construction site always
-- supplies fully-defined values, and strictness avoids retaining thunks.
data DirichletDistribution = DirichletDistribution
  { -- | Concentration parameters @alpha_1, ..., alpha_k@.
    ddGetParameters :: !(V.Vector Double),
    -- | Dimension @k@ (the length of the parameter vector).
    _getDimension :: !Int,
    -- | Normalizing constant @1 / B(alpha)@, stored in the log domain.
    _getNormConst :: !(Log Double)
  }
  deriving (Eq, Show)
-- | 'True' if any element of the vector is not strictly positive.
--
-- The test is @not (x > 0)@ rather than @x <= 0@ so that NaN elements are
-- caught too. Every ordered comparison involving NaN is 'False', so the
-- @x <= 0@ form let NaN parameters and NaN points through as "positive".
isNegativeOrZero :: V.Vector Double -> Bool
isNegativeOrZero = V.any (not . (> 0))
-- | Inverse of the multivariate Beta function, @1 / B(alpha)@, in the log
-- domain:
--
-- > log (1 / B(alpha)) = logGamma (sum alpha) - sum (map logGamma alpha)
--
-- Working with 'logGamma' keeps the intermediate values in log space instead
-- of forming the Gamma values themselves.
invBeta :: V.Vector Double -> Log Double
invBeta v = Exp (logGammaOfSum - sumOfLogGammas)
  where
    -- log Gamma(alpha_1 + ... + alpha_k)
    logGammaOfSum :: Double
    logGammaOfSum = logGamma (V.sum v)
    -- log Gamma(alpha_1) + ... + log Gamma(alpha_k)
    sumOfLogGammas :: Double
    sumOfLogGammas = V.sum (V.map logGamma v)
-- | Smart constructor for 'DirichletDistribution'.
--
-- Returns 'Left' with an error message when the parameter vector has fewer
-- than two elements, or when 'isNegativeOrZero' rejects it. On success the
-- dimension and the normalizing constant ('invBeta') are cached in the
-- result.
dirichletDistribution :: V.Vector Double -> Either String DirichletDistribution
dirichletDistribution v
  | V.length v < 2 =
      Left "dirichletDistribution: Parameter vector is too short."
  | isNegativeOrZero v =
      Left "dirichletDistribution: One or more parameters are negative or zero."
  | otherwise =
      Right $ DirichletDistribution v (V.length v) (invBeta v)
-- | Absolute tolerance used by 'isNormalized' when checking that the
-- elements of a vector sum to one.
eps :: Double
eps = 1e-14
-- | 'True' if the elements of the vector sum to one within the absolute
-- tolerance 'eps'.
--
-- The check is written as @<= eps@ instead of "not @> eps@": both agree for
-- ordinary numbers, but only this form reports a NaN sum as not normalized.
isNormalized :: V.Vector Double -> Bool
isNormalized v = abs (V.sum v - 1.0) <= eps
-- | Probability density of a 'DirichletDistribution' at a point, returned in
-- the log domain.
--
-- The density is zero when the point has the wrong dimension, when
-- 'isNegativeOrZero' rejects it, or when 'isNormalized' reports it does not
-- sum to one. Otherwise it is
--
-- > (1 / B(alpha)) * prod (x_i ** (alpha_i - 1))
--
-- The product is accumulated as @sum ((alpha_i - 1) * log x_i)@. The
-- previous @sum (log (x_i ** (alpha_i - 1)))@ first built the power in
-- linear space. For small @x_i@ that underflows to 0 (e.g. @1e-10 ** 39@),
-- giving @log 0 = -Infinity@ and a spurious zero density, or it overflows to
-- Infinity. Either way it defeated the point of returning a 'Log Double'.
dirichletDensity :: DirichletDistribution -> V.Vector Double -> Log Double
dirichletDensity (DirichletDistribution alphas k c) xs
  | k /= V.length xs = 0
  | isNegativeOrZero xs = 0
  | not (isNormalized xs) = 0
  | otherwise = c * Exp logXsPow
  where
    -- log (prod x_i^(alpha_i - 1)), computed without leaving log space.
    logXsPow :: Double
    logXsPow = V.sum $ V.zipWith (\a x -> (a - 1.0) * log x) alphas xs
-- | Draw one sample from a 'DirichletDistribution'.
--
-- For each parameter @alpha_i@, draws @gamma alpha_i 1.0 g@ (mwc-random's
-- 'gamma' with shape @alpha_i@ and scale 1). It then divides every draw by
-- the sum of all draws, so the result sums to one.
--
-- NOTE(review): for very small parameters every gamma draw may underflow to
-- 0, making the sum 0 and the result NaN — confirm whether callers need
-- protection against this.
dirichletSample :: StatefulGen g m => DirichletDistribution -> g -> m (V.Vector Double)
dirichletSample (DirichletDistribution alphas _ _) g = do
  ys <- V.mapM (\a -> gamma a 1.0 g) alphas
  let s :: Double
      s = V.sum ys
  return $ V.map (/ s) ys
-- | A symmetric Dirichlet distribution: all @k@ components share a single
-- concentration parameter @a@.
--
-- The constructor is not exported; build values with
-- 'dirichletDistributionSymmetric', which validates the inputs and caches
-- the normalizing constant. Fields are strict: the only construction site
-- always supplies fully-defined values.
data DirichletDistributionSymmetric = DirichletDistributionSymmetric
  { -- | The shared concentration parameter @a@.
    ddSymGetParameter :: !Double,
    -- | Dimension @k@.
    _symGetDimension :: !Int,
    -- | Normalizing constant @1 / B(a, ..., a)@, stored in the log domain.
    _symGetNormConst :: !(Log Double)
  }
  deriving (Eq, Show)
-- | Inverse of the multivariate Beta function for @k@ equal parameters
-- @a@, in the log domain:
--
-- > log (1 / B(a, ..., a)) = logGamma (k * a) - k * logGamma a
invBetaSym :: Int -> Double -> Log Double
invBetaSym k a = Exp (logGammaOfSum - sumOfLogGammas)
  where
    kD :: Double
    kD = fromIntegral k
    -- log Gamma(k * a)
    logGammaOfSum :: Double
    logGammaOfSum = logGamma (kD * a)
    -- k * log Gamma(a)
    sumOfLogGammas :: Double
    sumOfLogGammas = kD * logGamma a
-- | Smart constructor for 'DirichletDistributionSymmetric'.
--
-- Returns 'Left' with an error message when the dimension @k@ is smaller
-- than two or the parameter @a@ is not strictly positive. On success the
-- normalizing constant ('invBetaSym') is cached in the result.
dirichletDistributionSymmetric :: Int -> Double -> Either String DirichletDistributionSymmetric
dirichletDistributionSymmetric k a
  | k < 2 =
      Left "dirichletDistributionSymmetric: The dimension is smaller than two."
  -- `not (a > 0)` rather than `a <= 0`: the latter is False for NaN, which
  -- let a NaN parameter through and produced a NaN normalizing constant.
  | not (a > 0) =
      Left "dirichletDistributionSymmetric: The parameter is negative or zero."
  | otherwise =
      Right $ DirichletDistributionSymmetric a k (invBetaSym k a)
-- | Probability density of a 'DirichletDistributionSymmetric' at a point,
-- returned in the log domain.
--
-- The density is zero when the point has the wrong dimension, when
-- 'isNegativeOrZero' rejects it, or when 'isNormalized' reports it does not
-- sum to one. Otherwise it is
--
-- > (1 / B(a, ..., a)) * prod (x_i ** (a - 1))
--
-- The product is accumulated as @(a - 1) * sum (log x_i)@. The previous
-- @sum (log (x_i ** (a - 1)))@ built each power in linear space, where it
-- underflows to 0 (giving @-Infinity@ and a spurious zero density) or
-- overflows for small @x_i@.
dirichletDensitySymmetric :: DirichletDistributionSymmetric -> V.Vector Double -> Log Double
dirichletDensitySymmetric (DirichletDistributionSymmetric a k c) xs
  | k /= V.length xs = 0
  | isNegativeOrZero xs = 0
  | not (isNormalized xs) = 0
  | otherwise = c * Exp logXsPow
  where
    -- log (prod x_i^(a - 1)), computed without leaving log space.
    logXsPow :: Double
    logXsPow = (a - 1.0) * V.sum (V.map log xs)
-- | Draw one sample from a 'DirichletDistributionSymmetric'.
--
-- Draws @k@ independent values with @gamma a 1.0 g@ (mwc-random's 'gamma'
-- with shape @a@ and scale 1). It then divides each draw by their sum, so
-- the result sums to one.
--
-- NOTE(review): for very small @a@ every draw may underflow to 0, making
-- the sum 0 and the result NaN — confirm whether callers need protection.
dirichletSampleSymmetric ::
  StatefulGen g m =>
  DirichletDistributionSymmetric ->
  g ->
  m (V.Vector Double)
dirichletSampleSymmetric (DirichletDistributionSymmetric a k _) g = do
  ys <- V.replicateM k (gamma a 1.0 g)
  let s :: Double
      s = V.sum ys
  return $ V.map (/ s) ys