module Data.Random.Distribution.Categorical
( categorical, categoricalT
, fromList, toList
, fromWeightedList, fromObservations
, mapCategoricalPs, normalizeCategoricalPs
, collectEvents, collectEventsBy
) where
import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Uniform
import Control.Arrow
import Control.Monad
import Control.Monad.ST
import Control.Applicative
import Data.Foldable (Foldable(foldMap))
import Data.STRef
import Data.Traversable (Traversable(traverse, sequenceA))
import Data.List
import Data.Function
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
-- | Build a random variable that samples one of the given events.
--
-- Each pair is @(probability, event)@; the list is turned into a
-- 'Categorical' with 'fromList' (no normalization is performed here).
categorical :: (Num p, Distribution (Categorical p) a) => [(p,a)] -> RVar a
categorical weightedEvents = rvar (fromList weightedEvents)
-- | Monad-transformer version of 'categorical': builds the distribution
-- with 'fromList' and samples it via 'rvarT'.
categoricalT :: (Num p, Distribution (Categorical p) a) => [(p,a)] -> RVarT m a
categoricalT weightedEvents = rvarT (fromList weightedEvents)
-- | Construct a 'Categorical' from @(probability, event)@ pairs.
--
-- The probabilities are stored as running (cumulative) totals, in input
-- order.  They are not rescaled, so callers wanting a proper distribution
-- should supply values that sum to 1 (or use 'fromWeightedList').
fromList :: (Num p) => [(p,a)] -> Categorical p a
fromList = Categorical . V.fromList . scanl1 accumulate
  where
    -- Carry the running total forward, keeping the current event.
    accumulate (runningTotal, _) (w, event) = (runningTotal + w, event)
-- | Recover the @(probability, event)@ pairs of a 'Categorical', in stored
-- order, by differencing the internally stored cumulative probabilities.
toList :: (Num p) => Categorical p a -> [(p,a)]
toList (Categorical ds) = V.foldr' step [] ds
  where
    -- Entries are pushed still holding their cumulative value.  When the
    -- entry to the left arrives, the current head of the accumulator
    -- (still cumulative) is converted into an individual probability by
    -- subtracting the new entry's cumulative value.  The very first entry
    -- is left as-is, since its cumulative value is its own probability.
    step x [] = [x]
    step x@(p0, _) ((p1, y) : rest) = x : (p1 - p0, y) : rest
-- | Construct a 'Categorical' from @(weight, event)@ pairs whose weights
-- need not sum to 1; the result is rescaled by 'normalizeCategoricalPs'.
--
-- @Eq p@ is required because normalization compares probabilities.  Since
-- GHC 7.4, 'Fractional' no longer implies 'Eq'.  The previous @Ord a@
-- constraint was never used and has been dropped.
fromWeightedList :: (Fractional p, Eq p) => [(p,a)] -> Categorical p a
fromWeightedList = normalizeCategoricalPs . fromList
-- | Build an empirical distribution from a list of observations.
--
-- The observations are sorted and grouped.  Each distinct value is
-- weighted by its number of occurrences, and the weights are then
-- normalized with 'fromWeightedList'.  Events appear in ascending order.
--
-- @Eq p@ is needed transitively, because normalization compares
-- probabilities.
fromObservations :: (Fractional p, Eq p, Ord a) => [a] -> Categorical p a
fromObservations =
    -- 'head' is safe here: 'group' never yields an empty sub-list.
    fromWeightedList . map (genericLength &&& head) . group . sort
-- | A categorical (discrete, finite) distribution over events of type @a@
-- with probabilities of type @p@.
--
-- Each event is stored alongside its /cumulative/ probability (see
-- 'fromList'), which lets sampling use a binary search.  The derived 'Eq'
-- compares these stored vectors structurally.
newtype Categorical p a = Categorical (V.Vector (p, a))
    deriving (Eq)
-- | Renders as @fromList [(p, x), ...]@ using per-event probabilities
-- (obtained with 'toList'), parenthesized above application precedence.
instance (Num p, Show a) => Show (Categorical p a) where
    showsPrec d cat =
        showParen (d > 10) $ showString "fromList " . showsPrec 11 (toList cat)
-- | Sampling a 'Categorical'.
--
-- * An empty distribution 'fail's.
-- * A single-event distribution returns that event without drawing.
-- * Otherwise @u@ is drawn with @uniformT 0 total@ (where @total@ is the
--   last cumulative probability), and the first event whose cumulative
--   probability is @>= u@ is returned.  The last event is the fallback.
--
-- NOTE(review): the binary search assumes cumulative probabilities are
-- non-decreasing, i.e. no negative weights — confirm callers ensure this.
instance (Fractional p, Ord p, Distribution Uniform p) => Distribution (Categorical p) a where
    rvarT (Categorical ds)
        | V.null ds = fail "categorical distribution over empty set cannot be sampled"
        | n == 1    = return (snd (V.head ds))
        | otherwise = do
            u <- uniformT 0 (fst (V.last ds))
            -- Lower-bound binary search over [i, j].  The answer is always
            -- within [i, j]; index n - 1 acts as a sentinel when no
            -- cumulative probability reaches u.
            let p k = fst (ds V.! k)
                x k = snd (ds V.! k)
                findEvent i j
                    | i >= j    = x j
                    | p m >= u  = findEvent i m
                    -- p m < u, so m itself cannot be the answer.
                    | otherwise = findEvent (m + 1) j
                  where
                    -- Rounds down, so m < j whenever i < j; the search
                    -- therefore always shrinks.
                    m = (i + j) `div` 2
            return (findEvent 0 (n - 1))
        where
            n = V.length ds
-- | Maps over the events, leaving their stored probabilities untouched.
-- Events that become equal are not merged (see 'collectEvents').
instance Functor (Categorical p) where
    fmap f (Categorical ds) = Categorical (V.map (\ ~(p, x) -> (p, f x)) ds)
-- | Folds over the events in stored order, ignoring probabilities.
instance Foldable (Categorical p) where
    foldMap f (Categorical ds) = foldMap (\ ~(_, x) -> f x) (V.toList ds)
-- | Traverses the events in stored order, carrying each event's stored
-- probability through unchanged.
instance Traversable (Categorical p) where
    traverse f (Categorical ds) =
        fmap (Categorical . V.fromList) (traverse step (V.toList ds))
      where
        step (p, e) = fmap ((,) p) (f e)
    sequenceA = traverse id
-- | Sequencing of categorical distributions.
--
-- * 'return' is the point mass with probability 1.
-- * 'fail' yields the empty distribution.
-- * '>>=' multiplies each outer probability by the probabilities of the
--   resulting inner distribution.
--
-- The result of '>>=' is not renormalized, and duplicate events are not
-- merged; use 'normalizeCategoricalPs' / 'collectEvents' for that.
instance Num p => Monad (Categorical p) where
    return x = Categorical (V.singleton (1, x))
    -- NOTE(review): 'fail' stopped being a 'Monad' method in GHC 8.8
    -- (base 4.13); on newer compilers this must move to a 'MonadFail'
    -- instance.
    fail _ = Categorical V.empty
    xs >>= f = fromList
        [ (p * q, y)
        | (p, x) <- toList xs
        , (q, y) <- toList (f x)
        ]
-- | 'pure' is the point mass with probability 1; '<*>' is derived from the
-- 'Monad' instance.
--
-- The context is @Num p@, matching the 'Monad' instance.  Since
-- 'Applicative' became a superclass of 'Monad' (GHC >= 7.10), the
-- 'Monad' instance's @Num p@ must be sufficient to obtain this instance;
-- the former @Fractional p@ context made that impossible.
instance Num p => Applicative (Categorical p) where
    pure x = Categorical (V.singleton (1, x))
    (<*>) = ap
-- | Apply a function to every stored (cumulative) probability, leaving
-- the events as they are.
mapCategoricalPs :: (p -> q) -> Categorical p e -> Categorical q e
mapCategoricalPs f (Categorical ds) =
    Categorical (V.map (\ ~(p, e) -> (f p, e)) ds)
-- | Rescale a categorical distribution so its probabilities sum to 1, and
-- drop every event whose probability is zero.
--
-- * Kept events stay in their original order.
-- * The last cumulative probability is set to exactly 1, absorbing
--   rounding error from the rescaling; it stays attached to the last
--   /kept/ event.
-- * An empty distribution is returned unchanged.
-- * A distribution whose total weight is 0 becomes empty, since no event
--   has nonzero probability.  The previous implementation crashed here
--   with an out-of-bounds write.
--
-- @Eq p@ is required because zero-probability events are detected by
-- comparing cumulative probabilities.
normalizeCategoricalPs :: (Fractional p, Eq p) => Categorical p e -> Categorical p e
normalizeCategoricalPs orig@(Categorical ds)
    | V.null ds  = orig
    | total == 0 = Categorical V.empty
    | otherwise  = Categorical (pinLastToOne rescaled)
  where
    (total, _) = V.last ds
    scale      = recip total

    -- An event has zero probability exactly when its cumulative
    -- probability equals its predecessor's (or 0 for the first event).
    previous = V.cons 0 (V.map fst ds)
    kept     = V.map snd (V.filter (\(p0, (p, _)) -> p /= p0) (V.zip previous ds))
    rescaled = V.map (\(p, x) -> (p * scale, x)) kept

    -- Overwrite the final cumulative probability with exactly 1, keeping
    -- the event that is actually stored in that final slot.
    pinLastToOne v
        | V.null v  = v
        | otherwise = V.snoc (V.init v) (1, snd (V.last v))
-- | Apply a function to the contents of an 'STRef', forcing the new value
-- to WHNF before it is stored, so no thunk chain builds up in the ref.
--
-- NOTE(review): base >= 4.6 also exports 'modifySTRef'' from "Data.STRef",
-- which this module imports unqualified, so unqualified uses of this name
-- may be ambiguous there — confirm the targeted base version.
modifySTRef' :: STRef s a -> (a -> a) -> ST s ()
modifySTRef' ref f = readSTRef ref >>= \v -> writeSTRef ref $! f v
-- | Merge duplicate events: the probabilities of equal events are summed
-- into a single entry.  Events come out in ascending order (see
-- 'collectEventsBy').
collectEvents :: (Ord e, Num p, Ord p) => Categorical p e -> Categorical p e
collectEvents = collectEventsBy compare mergeGroup
  where
    -- Groups handed over by 'collectEventsBy' are never empty, so 'head'
    -- is safe here.
    mergeGroup grp = (sum (map fst grp), snd (head grp))
-- | Merge events considered equal by a comparison function.
--
-- The per-event probabilities (from 'toList') are stably sorted by event
-- using @compareE@.  Runs whose events compare 'EQ' are collapsed by
-- @combine@, and the results are rebuilt with 'fromList'.  No
-- renormalization is performed.
collectEventsBy :: Num p => (e -> e -> Ordering) -> ([(p,e)] -> (p,e)) -> Categorical p e -> Categorical p e
collectEventsBy compareE combine cat =
    fromList [ combine grp | grp <- groupBy sameEvent (sortBy byEvent (toList cat)) ]
  where
    byEvent = compareE `on` snd
    sameEvent a b = compareE (snd a) (snd b) == EQ