{-# LANGUAGE GADTs, MultiParamTypeClasses, FlexibleInstances, FlexibleContexts #-}
module Data.Random.Distribution.Multinomial where

import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Binomial

-- | Sample a 'Multinomial' distribution in the plain 'RVar' monad.
--
-- @multinomial ps n@ builds @Multinomial ps n@ and samples it with 'rvar'.
-- With the 'Distribution' instance defined in this module, @ps@ are relative
-- category weights and @n@ is the number of trials. The result has one count
-- per category.
multinomial :: Distribution (Multinomial p) [a] => [p] -> a -> RVar [a]
multinomial ps n = rvar (Multinomial ps n)

-- | Sample a 'Multinomial' distribution in the 'RVarT' transformer.
--
-- This is the transformer counterpart of 'multinomial'. It builds
-- @Multinomial ps n@ and samples it with 'rvarT', so it can be used over any
-- base monad @m@.
multinomialT :: Distribution (Multinomial p) [a] => [p] -> a -> RVarT m [a]
multinomialT ps n = rvarT (Multinomial ps n)

-- | Description of a multinomial distribution.
--
-- @Multinomial weights trials@ holds a list of per-category weights and a
-- trial count. The GADT return type fixes the sampled value to be a list of
-- counts, @[trials-type]@. Consequently, a @Multinomial w r@ can only be
-- built when @r@ is a list type.
data Multinomial weight result where
    Multinomial :: [weight] -> count -> Multinomial weight [count]

-- | Sequential-binomial sampler for 'Multinomial'.
--
-- Categories are visited left to right. For a category with weight @p@, let
-- @psum@ be the sum of that weight and all weights after it. The category's
-- count is drawn as @binomialT remaining (p / psum)@ and then subtracted from
-- the remaining trials. Because each weight is divided by the total of the
-- weights not yet visited, the weights are only relative and need not sum
-- to 1.
--
-- Special cases handled by @go@ below:
--
--   * An empty weight list yields @[]@, and the trials are not reported.
--   * The last category receives all remaining trials without sampling.
--   * Once the remaining trials reach @0@, every later category gets @0@
--     without further sampling.
instance (Num a, Eq a, Fractional p, Distribution (Binomial p) a) => Distribution (Multinomial p) [a] where
    -- TODO: implement faster version based on Categorical for small n, large (length ps)
    rvarT (Multinomial ps0 t) = go t ps0 (tailSums ps0) id
        where
            -- go remaining weights suffixSums acc
            -- @acc@ is a difference list of the counts emitted so far;
            -- applying it to the final tail yields the counts in order.
            go _ []     _            f = return (f [])
            -- last category takes whatever trials are left
            go n [_]    _            f = return (f [n])
            -- no trials left: emit 0 for this category and skip sampling
            go 0 (_:ps) (_   :psums) f = go 0 ps psums (f . (0:))
            go n (p:ps) (psum:psums) f = do
                -- probability of this category conditional on not having
                -- been assigned to any earlier category
                x <- binomialT n (p / psum)
                go (n-x) ps psums (f . (x:))

            go _ _ _ _ = error "rvar/Multinomial: programming error! this case should be impossible!"

            -- less wasteful version of (map sum . tails): suffix sums of the
            -- input, ending with the sum of the empty suffix, 0. The result is
            -- one element longer than the input.
            tailSums [] = [0]
            tailSums (x:xs) = case tailSums xs of
                (s:rest) -> (x+s):s:rest
                _ -> error "rvar/Multinomial/tailSums: programming error! this case should be impossible!"