module Data.Random.Distribution.Discrete where
import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Uniform
import Data.Random.List (randomElement)
import Control.Arrow
import Control.Monad
import Control.Applicative
import Data.Foldable (Foldable(foldMap))
import Data.Traversable (Traversable(traverse, sequenceA))
import Data.List
import Data.Function
-- | Sample one event from a list of (weight, event) pairs by wrapping the
-- list in 'Discrete' and sampling that distribution.
discrete :: Distribution (Discrete p) a => [(p,a)] -> RVar a
discrete = rvar . Discrete
-- | Build the empirical distribution of a sample: every distinct value of the
-- input becomes one event, weighted by how many times it occurs. Events come
-- out in ascending order (the input is sorted before grouping).
empirical :: (Num p, Ord a) => [a] -> Discrete p a
empirical xs = Discrete (concatMap tally (group (sort xs)))
  where
    -- 'group' never yields an empty run; the first equation just keeps this total.
    tally [] = []
    tally run@(x : _) = [(genericLength run, x)]
-- | A discrete distribution given as a list of (weight, event) pairs.
--
-- Weights are relative: the 'Distribution' instance draws an event with
-- probability proportional to its weight, so they need not sum to 1.
-- The same event may appear in several entries; 'collectDiscreteEvents'
-- merges such duplicates.
newtype Discrete p a = Discrete [(p, a)]
    deriving (Eq, Show)
-- | Sample an event with probability proportional to its weight.
--
-- * An empty distribution cannot be sampled (fails).
-- * Any negative weight makes sampling fail.
-- * If all weights are zero, an event is chosen uniformly at random.
instance (Num p, Ord p, Distribution Uniform p) => Distribution (Discrete p) a where
    rvar (Discrete []) = fail "discrete distribution over empty set cannot be sampled"
    rvar (Discrete ds) = do
        let (ps, xs) = unzip ds
            -- running totals of the weights; the last one is the total weight
            cs = scanl1 (+) ps
        when (any (<0) ps) $ fail "negative probability in discrete distribution"
        let totalWeight = last cs
        if totalWeight <= 0
            -- no weight anywhere: fall back to a uniform choice over the events
            then randomElement xs
            else do
                -- Draw u from (0, totalWeight]. Rejecting 0 means an event with
                -- zero weight (whose running total equals its predecessor's)
                -- can never be the first one with a running total >= u.
                let getU = do
                        u <- uniform 0 totalWeight
                        if u == 0 then getU else return u
                u <- getU
                -- Total lookup instead of 'head': with non-finite weights (e.g.
                -- NaN, which passes the negativity check) no running total may
                -- satisfy the comparison, and we want a descriptive failure.
                case find ((>= u) . fst) (zip cs xs) of
                    Just (_, x) -> return x
                    Nothing -> fail "discrete distribution: no event selected (are all weights finite?)"
-- | Map over the events; weights are left untouched.
instance Functor (Discrete p) where
    fmap f (Discrete ds) = Discrete $ do
        (p, x) <- ds
        return (p, f x)
-- | Fold over the events in list order, ignoring their weights.
instance Foldable (Discrete p) where
    foldMap f (Discrete ds) = foldMap f (map snd ds)
-- | Traverse the events in list order, carrying each weight along unchanged.
instance Traversable (Discrete p) where
    traverse f (Discrete ds) =
        Discrete <$> traverse (\(p, x) -> (,) p <$> f x) ds
    sequenceA (Discrete ds) =
        Discrete <$> traverse (\(p, action) -> (,) p <$> action) ds
-- | Sequencing multiplies weights: each outer entry (p, x) is replaced by the
-- entries of @f x@, with every inner weight q scaled to @p * q@.
-- No duplicate events are merged; see 'collectDiscreteEvents' for that.
instance Num p => Monad (Discrete p) where
    -- canonical definition: 'pure' (in the Applicative instance) is the real one
    return = pure
    Discrete outer >>= f = Discrete $ do
        (p, x) <- outer
        let Discrete inner = f x
        (q, y) <- inner
        return (p * q, y)

-- | 'pure' is a single event with weight 1; '<*>' is derived from '>>='.
instance Num p => Applicative (Discrete p) where
    pure x = Discrete [(1, x)]
    (<*>) = ap
-- | Apply a function to every weight, keeping events and their order as-is.
mapDiscreteWeights :: (p -> q) -> Discrete p e -> Discrete q e
mapDiscreteWeights f (Discrete ds) = Discrete (reweight ds)
  where
    reweight [] = []
    reweight ((p, x) : rest) = (f p, x) : reweight rest
-- | Rescale the weights so that they sum to 1.
--
-- The distribution is returned unchanged when its weights already sum to 1,
-- or when they sum to 0 (no rescaling is possible in that case).
--
-- 'Eq' is required explicitly: 'Num' (and hence 'Fractional') has not
-- implied 'Eq' since GHC 7.4, and the sum is compared against 0 and 1.
normalizeDiscreteWeights :: (Fractional p, Eq p) => Discrete p e -> Discrete p e
normalizeDiscreteWeights orig@(Discrete ds)
    | total == 0 || total == 1 = orig
    | otherwise = Discrete [ (w * scale, e) | (w, e) <- ds ]
  where
    total = sum (map fst ds)
    scale = recip total
-- | Merge entries that carry the same event (by 'compare') into one entry
-- whose weight is the sum of the merged weights. The first event of each
-- group is kept as its representative.
collectDiscreteEvents :: (Ord e, Num p, Ord p) => Discrete p e -> Discrete p e
collectDiscreteEvents dist = collectDiscreteEventsBy compare sum firstOf dist
  where
    -- groups produced by 'groupBy' are never empty
    firstOf = head
-- | Merge entries whose events compare 'EQ' under the given ordering.
--
-- Entries are sorted by event (stably, via 'sortBy'), adjacent equal events
-- are grouped, and each group collapses to a single entry. The entry's
-- weight is the weight-combiner applied to the group's weights, and its
-- event is the event-combiner applied to the group's events. Both combiners
-- only ever receive non-empty lists. Weights are not validated here.
collectDiscreteEventsBy
    :: (e -> e -> Ordering) -- ^ ordering used to sort and group events
    -> ([p] -> p)           -- ^ combines the weights of one group
    -> ([e] -> e)           -- ^ combines the events of one group
    -> Discrete p e
    -> Discrete p e
collectDiscreteEventsBy compareE sumWeights mergeEvents (Discrete ds) =
    Discrete . map ((sumWeights *** mergeEvents) . unzip) . groupEvents . sortEvents $ ds
  where
    sortEvents = sortBy (compareE `on` snd)
    groupEvents = groupBy (\x y -> compareE (snd x) (snd y) == EQ)