{-# LANGUAGE
    MultiParamTypeClasses,
    FlexibleInstances, FlexibleContexts,
    UndecidableInstances, GADTs
  #-}

module Data.Random.Distribution.Dirichlet where

import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Gamma

import Data.List

-- | Sample from a Dirichlet distribution by normalising independent Gamma
-- variates.
--
-- For each concentration parameter @alpha@, one variate is drawn with
-- @'gammaT' alpha 1@. Every component uses the same second argument
-- (presumably a unit scale; see "Data.Random.Distribution.Gamma").
-- The resulting vector is then multiplied by the reciprocal of its sum,
-- so the components add up to 1.
--
-- Edge cases, neither of which draws any random numbers:
--
-- * @[]@ yields @[]@.
--
-- * A single parameter yields @[1]@, since a one-component Dirichlet is
--   degenerate at 1.
fractionalDirichlet :: (Fractional a, Distribution Gamma a) => [a] -> RVarT m [a]
fractionalDirichlet []  = return []
fractionalDirichlet [_] = return [1]
fractionalDirichlet alphas = do
    xs <- traverse (\alpha -> gammaT alpha 1) alphas
    -- At least two parameters reach this equation, so xs is non-empty and
    -- foldl1' is safe.
    let total = foldl1' (+) xs
    return (map (* recip total) xs)

-- | Sample from the 'Dirichlet' distribution with the given concentration
-- parameters, in the 'RVar' monad.
--
-- This is just 'rvar' applied to @'Dirichlet' alphas@. The sampling
-- algorithm is whatever the @'Distribution' 'Dirichlet' [a]@ instance in
-- scope provides.
dirichlet :: Distribution Dirichlet [a] => [a] -> RVar [a]
dirichlet alphas = rvar (Dirichlet alphas)

-- | Monad-transformer variant of 'dirichlet': sample from the 'Dirichlet'
-- distribution with the given concentration parameters in 'RVarT' over any
-- base monad @m@.
--
-- This is just 'rvarT' applied to @'Dirichlet' alphas@.
dirichletT :: Distribution Dirichlet [a] => [a] -> RVarT m [a]
dirichletT alphas = rvarT (Dirichlet alphas)

-- | The Dirichlet distribution, wrapping its parameter.
--
-- The 'Distribution' instance in this module is for @Dirichlet [a]@, where
-- the list holds one concentration parameter per component.
newtype Dirichlet a = Dirichlet a
    deriving (Eq, Show)

-- | Dirichlet sampling for a list of concentration parameters, delegating to
-- 'fractionalDirichlet', which normalises one Gamma variate per parameter.
instance (Fractional a, Distribution Gamma a) => Distribution Dirichlet [a] where
    rvarT (Dirichlet alphas) = fractionalDirichlet alphas