module Data.Random.Distribution.Discrete where
import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Uniform
import Data.Random.List (randomElement)
import Control.Monad
import Control.Applicative
import Data.List
import Data.Function
-- | Build a sampler from a list of @(weight, outcome)@ pairs: the pairs are
-- wrapped in 'Discrete' and sampled through that type's 'Distribution'
-- instance.
discrete :: Distribution (Discrete p) a => [(p,a)] -> RVar a
discrete = rvar . Discrete
-- | A discrete distribution represented as a list of @(weight, outcome)@
-- pairs.  The list is stored as given: weights are not normalized and
-- repeated outcomes are not merged by the constructor.
newtype Discrete p a
    = Discrete [(p, a)]
    deriving (Show, Eq)
-- | Sampling for 'Discrete'.
--
-- * An empty outcome list fails.
-- * Any negative weight fails.
-- * If the weights total @<= 0@ (negatives having been rejected above, this
--   means every weight is zero), an outcome is picked with 'randomElement'.
-- * Otherwise @u <- uniform 0 total@ is drawn and the first outcome with a
--   positive weight whose cumulative weight is @>= u@ is returned.
instance (Num p, Ord p, Distribution Uniform p) => Distribution (Discrete p) a where
    rvar (Discrete []) = fail "discrete distribution over empty set cannot be sampled"
    rvar (Discrete ds) = do
        let (ps, xs) = unzip ds
            cs = scanl1 (+) ps
        when (any (<0) ps) $ fail "negative probability in discrete distribution"
        let totalProb = last cs
        if totalProb <= 0
            then randomElement xs
            else do
                u <- uniform 0 totalProb
                -- A zero-weight outcome has the same cumulative weight as its
                -- predecessor, so without the @p > 0@ guard a draw landing on
                -- that boundary (e.g. @u == 0@ with a leading zero weight)
                -- would return an outcome that should be impossible.  The
                -- last positive-weight outcome has cumulative weight
                -- @totalProb >= u@, so 'head' always has a candidate provided
                -- 'uniform' stays within @[0, totalProb]@.
                -- NOTE(review): if 'uniform' is inclusive of both ends for
                -- integral @p@, the first positive outcome is still slightly
                -- over-weighted by this @>=@ test -- confirm against the
                -- Uniform instance.
                return $ head
                    [ x
                    | (p, c, x) <- zip3 ps cs xs
                    , p > 0
                    , c >= u
                    ]
-- | Apply a function to every outcome while keeping each weight in place.
-- Outcomes that become equal after mapping stay as separate entries.
instance Functor (Discrete p) where
    fmap f (Discrete ds) = Discrete (map (\(w, x) -> (w, f x)) ds)
-- | Sequencing of weighted outcome lists.
--
-- 'return' is the point distribution: one outcome with weight 1.
--
-- For bind, each outer outcome @x@ with weight @p@ is fed to the
-- continuation.  The continuation's weights are rescaled so they sum to @p@
-- (i.e. normalized, then multiplied by @p@).  When the continuation's
-- weights do not sum to a positive value, @p@ is instead split evenly among
-- its outcomes, matching the uniform fallback used when sampling a
-- 'Discrete' whose total weight is @<= 0@.  Duplicate outcomes are not
-- merged.
instance (Fractional p, Ord p) => Monad (Discrete p) where
    return x = Discrete [(1, x)]
    Discrete outer >>= f = Discrete $ do
        (p, x) <- outer
        let Discrete inner = f x
            total = sum (map fst inner)
            -- New weight for an inner entry of weight q.  The fallback
            -- branch deliberately ignores q: multiplying by q there would
            -- turn an all-zero continuation into all-zero weights and drop
            -- the mass p entirely.  (Only evaluated when inner is
            -- non-empty, so the division by its length is safe.)
            reweight q
                | total > 0 = p * q * recip total
                | otherwise = p * recip (fromIntegral (length inner))
        (q, y) <- inner
        return (reweight q, y)
-- | 'pure' is the point distribution: a single outcome with weight 1.
-- It is defined directly, the canonical form, rather than as
-- @pure = return@.  @(<*>)@ is obtained from the 'Monad' instance via 'ap',
-- so it inherits bind's per-branch reweighting.
instance (Fractional p, Ord p) => Applicative (Discrete p) where
    pure x = Discrete [(1, x)]
    (<*>) = ap
-- | Merge repeated outcomes of a 'Discrete' distribution.
--
-- Entries are sorted by outcome (so the result is ordered by outcome) and
-- grouped by equal outcomes.  Within each group, zero weights are dropped.
-- The remaining positive weights are summed into one entry and the negative
-- weights into another, so a group can yield zero, one or two entries.
-- Positive and negative totals are kept separate rather than netted.
collectDiscreteEvents :: (Ord e, Num p, Ord p) => Discrete p e -> Discrete p e
collectDiscreteEvents (Discrete ds) =
    Discrete . concatMap (uncurry combine . unzip) . groupEvents . sortEvents $ ds
    where
        groupEvents = groupBy ((==) `on` snd)
        sortEvents  = sortBy (compare `on` snd)

        -- Collapse one group's weights (all outcomes in the group are
        -- equal, so the first one stands for all of them).
        combine ws (x:_) = case partition (> 0) (filter (/= 0) ws) of
            ([],  [])  -> []
            ([],  neg) -> [(sum neg, x)]
            (pos, [])  -> [(sum pos, x)]
            (pos, neg) -> [(sum pos, x), (sum neg, x)]
        -- 'groupBy' never yields an empty group; kept only for totality.
        combine _ [] = []