{-# LANGUAGE
    MultiParamTypeClasses,
    FlexibleInstances, FlexibleContexts,
    CPP
  #-}

{-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-}

module Data.Random.Distribution.Categorical
    ( Categorical
    , categorical, categoricalT
    , weightedCategorical, weightedCategoricalT
    , fromList, toList, totalWeight, numEvents
    , fromWeightedList, fromObservations
    , mapCategoricalPs, normalizeCategoricalPs
    , collectEvents, collectEventsBy
    ) where

import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Uniform

import Control.Arrow
import Control.Monad
import Control.Monad.ST
import Data.STRef

import Data.List
import Data.Function
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV

-- |Construct a 'Categorical' random variable from a list of
-- (probability, category) pairs, where the probabilities all sum to 1.
--
-- The pairs are converted with 'fromList' (no rescaling is performed) and the
-- resulting distribution is sampled with 'rvar'.
categorical :: (Num p, Distribution (Categorical p) a) => [(p,a)] -> RVar a
categorical pairs = rvar (fromList pairs)

-- |Construct a 'Categorical' random process from a list of
-- (probability, category) pairs, where the probabilities all sum to 1.
--
-- Same as 'categorical', but producing an 'RVarT' in an arbitrary base monad.
categoricalT :: (Num p, Distribution (Categorical p) a) => [(p,a)] -> RVarT m a
categoricalT pairs = rvarT (fromList pairs)

-- |Construct a 'Categorical' random variable from a list of
-- (weight, category) pairs. The weights do /not/ have to sum to 1: they are
-- normalized by 'fromWeightedList' before sampling with 'rvar'.
weightedCategorical :: (Fractional p, Eq p, Distribution (Categorical p) a) => [(p,a)] -> RVar a
weightedCategorical pairs = rvar (fromWeightedList pairs)

-- |Construct a 'Categorical' random process from a list of
-- (weight, category) pairs. The weights do /not/ have to sum to 1: they are
-- normalized by 'fromWeightedList' before sampling with 'rvarT'.
weightedCategoricalT :: (Fractional p, Eq p, Distribution (Categorical p) a) => [(p,a)] -> RVarT m a
weightedCategoricalT pairs = rvarT (fromWeightedList pairs)

-- | Construct a 'Categorical' distribution from a list of weighted categories.
--
-- Internally the weights are stored as running (cumulative) sums, in list
-- order. No rescaling is done, so use 'fromWeightedList' when the weights may
-- not sum to 1.
{-# INLINE fromList #-}
fromList :: (Num p) => [(p,a)] -> Categorical p a
fromList weighted = Categorical . V.fromList $ scanl1 accumulate weighted
    where
        -- carry the running total forward and keep the newer event
        accumulate (runningTotal, _) (w, event) = (runningTotal + w, event)

-- | Recover the (weight, category) pairs of a distribution.
--
-- The first stored entry is returned unchanged. Every later entry's
-- cumulative weight is turned back into an individual weight by subtracting
-- the cumulative weight of the entry before it.
{-# INLINE toList #-}
toList :: (Num p) => Categorical p a -> [(p,a)]
toList (Categorical ds) =
    case V.toList ds of
        []               -> []
        entries@(e : es) -> e : zipWith delta entries es
    where
        delta (before, _) (after, y) = (after - before, y)

-- | Total weight of a distribution: the last stored cumulative weight, or 0
-- when the distribution has no events.
totalWeight :: Num p => Categorical p a -> p
totalWeight (Categorical ds) =
    if V.null ds then 0 else fst (V.last ds)

-- | Number of events stored in the distribution. Zero-weight events count
-- too, unless they were removed (e.g. by 'normalizeCategoricalPs').
numEvents :: Categorical p a -> Int
numEvents cat = case cat of
    Categorical events -> V.length events

-- |Construct a 'Categorical' distribution from a list of weighted categories,
-- where the weights do not necessarily sum to 1.
--
-- The list is built with 'fromList' and then passed to
-- 'normalizeCategoricalPs', which rescales the weights and drops zero-weight
-- events.
fromWeightedList :: (Fractional p, Eq p) => [(p,a)] -> Categorical p a
fromWeightedList weighted = normalizeCategoricalPs (fromList weighted)

-- |Construct a 'Categorical' distribution from a list of observed outcomes.
-- Equivalent events will be grouped and counted, and the probabilities of each
-- event in the returned distribution will be proportional to the number of
-- occurrences of that event.
--
-- Events appear in ascending order (the observations are sorted before
-- grouping).
fromObservations :: (Fractional p, Eq p, Ord a) => [a] -> Categorical p a
fromObservations observations = fromWeightedList
    [ (genericLength run, event)
    -- 'group' never yields an empty run, so this pattern always matches
    | run@(event : _) <- group (sort observations)
    ]

-- Implementation note: the documentation of 'Categorical' below describes the
-- public interface. Internally a 'Categorical' is stored as a vector of
-- (cumulative-probability, value) pairs, so that sampling can use binary
-- search.

-- |Categorical distribution; a list of events with corresponding probabilities.
-- The sum of the probabilities must be 1, and no event should have a zero
-- or negative probability (at least, at time of sampling; very clever users
-- can do what they want with the numbers before sampling, just make sure
-- that if you're one of those clever ones, you at least eliminate negative
-- weights before sampling).
--
-- The derived 'Eq' instance compares the stored vectors element by element.
-- Two distributions that list the same events in a different order therefore
-- compare unequal.
newtype Categorical p a = Categorical (V.Vector (p, a))
    deriving (Eq)

-- | Renders as @fromList [(weight, event), ...]@, using the individual
-- (non-cumulative) weights produced by 'toList'.
instance (Num p, Show p, Show a) => Show (Categorical p a) where
    showsPrec d cat =
        showParen (d > 10) $
            showString "fromList " . showsPrec 11 (toList cat)

-- | Parses the @fromList [(weight, event), ...]@ form emitted by the 'Show'
-- instance. The parsed pairs go through 'fromList' (no normalization).
instance (Num p, Read p, Read a) => Read (Categorical p a) where
    readsPrec d = readParen (d > 10) parseFromList
      where
        parseFromList input =
            [ (fromList vals, rest)
            | ("fromList", afterKeyword) <- lex input
            , (vals, rest)               <- readsPrec 11 afterKeyword
            ]

-- | Sampling works as follows.
--
-- * Draw @u@ with 'uniformT' between 0 and the total (last cumulative) weight.
-- * Return the first event whose cumulative weight is @>= u@, found by binary
--   search.
--
-- An empty distribution is an error. A single-event distribution returns its
-- event without drawing.
instance (Fractional p, Ord p, Distribution Uniform p) => Distribution (Categorical p) a where
    rvarT (Categorical ds)
        | V.null ds  = error "categorical distribution over empty set cannot be sampled"
        | count == 1 = return (snd (V.head ds))
        | otherwise  = do
            u <- uniformT 0 (fst (V.last ds))
            return $! if u <= 0 then eventAt 0 else search u 0 (count - 1)
      where
        count = V.length ds

        -- cumulative weights are non-decreasing by construction:
        -- (i < j) ==> (cumAt i <= cumAt j)
        cumAt k   = fst (ds V.! k)
        eventAt k = snd (ds V.! k)

        -- Binary search for the first index whose cumulative weight is >= target.
        -- invariants: (lo <= hi), (target <= cumAt hi), ((lo == 0) || (cumAt lo < target))
        --   (lo only advances past indices whose cumulative weight is below target)
        -- variant:    every step either raises lo or lowers hi
        -- on exit:    for all k, (k < hi) ==> (cumAt k < target), otherwise target <= cumAt k
        search target lo hi
            | hi <= lo            = eventAt hi
            | target <= cumAt mid = search target lo mid
            | otherwise           = search target (max mid (lo + 1)) hi
          where
            -- midpoint rounded down, so (lo < hi) ==> (mid < hi)
            mid = (lo + hi) `div` 2


-- | Maps over the events, leaving the stored (cumulative) weights untouched.
instance Functor (Categorical p) where
    fmap f (Categorical ds) = Categorical (V.map (\ ~(w, x) -> (w, f x)) ds)

-- | Folds over the events in storage order; the weights are ignored.
instance Foldable (Categorical p) where
    foldMap f = foldMap f . eventsOf
      where
        eventsOf (Categorical ds) = map snd (V.toList ds)

-- | Traverses the events in storage order, keeping each event's stored weight
-- paired with its new value.
instance Traversable (Categorical p) where
    traverse f (Categorical ds) =
        Categorical . V.fromList <$> traverse step (V.toList ds)
      where
        step (w, e) = (,) w <$> f e

    sequenceA = traverse id

-- | Monadic bind multiplies each outer probability by the probabilities of
-- the distribution it leads to. The result is NOT renormalized; see the
-- discussion on '>>=' below.
instance Fractional p => Monad (Categorical p) where
    -- Canonical form: the real definition lives in 'pure'. The previous
    -- arrangement ('return' defined here, 'pure = return' in Applicative) is
    -- flagged by GHC's -Wnoncanonical-monad-instances.
    return = pure

    -- I'm not entirely sure whether this is a valid form of failure; see next
    -- set of comments.
#if __GLASGOW_HASKELL__ < 808
    fail _ = Categorical V.empty
#endif

    -- Should the normalize step be included here, or should normalization
    -- be assumed?  It seems like there is (at least) 1 valid situation where
    -- non-normal results would arise:  the distribution being modeled is
    -- "conditional" and some event arose that contradicted the assumed
    -- condition and thus was eliminated ('f' returned an empty or
    -- zero-probability consequent, possibly by 'fail'ing).
    --
    -- It seems reasonable to continue in such circumstances, but should there
    -- be any renormalization?  If so, does it make a difference when that
    -- renormalization is done?  I'm pretty sure it does, actually.  So, the
    -- normalization will be omitted here for now, as it's easier for the
    -- user (who really better know what they mean if they're returning
    -- non-normalized probability anyway) to normalize explicitly than to
    -- undo any normalization that was done automatically.
    xs >>= f = {- normalizeCategoricalPs . -} fromList $ do
        (p, x) <- toList xs
        (q, y) <- toList (f x)
        return (p * q, y)

-- | 'pure' is the point-mass distribution: a single event with probability 1.
-- '<*>' is derived from the 'Monad' instance via 'ap'.
instance Fractional p => Applicative (Categorical p) where
    pure x = Categorical (V.singleton (1, x))
    (<*>) = ap

-- |Like 'fmap', but for the probabilities of a categorical distribution.
--
-- The function is applied to each individual (non-cumulative) weight, as
-- returned by 'toList'. The result is rebuilt with 'fromList', so it is not
-- renormalized.
mapCategoricalPs :: (Num p, Num q) => (p -> q) -> Categorical p e -> Categorical q e
mapCategoricalPs f cat = fromList [ (f w, e) | (w, e) <- toList cat ]

-- |Adjust all the weights of a categorical distribution so that they
-- sum to unity and remove all events whose probability is zero.
--
-- * A distribution whose total weight is 0 becomes the empty distribution.
-- * An event is dropped when its stored cumulative weight equals the
--   cumulative weight of the last event kept before it (initially 0).
-- * Kept cumulative weights are multiplied by the reciprocal of the total.
-- * The last kept one is then set to exactly 1, presumably to absorb any
--   rounding error from the rescaling.
normalizeCategoricalPs :: (Fractional p, Eq p) => Categorical p e -> Categorical p e
normalizeCategoricalPs orig@(Categorical ds)
    | total == 0 = Categorical V.empty
    | otherwise  = runST $ do
        prevCum <- newSTRef 0     -- unscaled cumulative weight of the last kept event
        dropped <- newSTRef 0     -- how many events have been removed so far
        buf     <- V.thaw ds      -- kept events are compacted towards the front

        forM_ [0 .. V.length ds - 1] $ \i -> do
            let (c, e) = ds V.! i
            seen <- readSTRef prevCum
            if c == seen
                then modifySTRef' dropped (+ 1)
                else do
                    k <- readSTRef dropped
                    MV.write buf (i - k) (c * factor, e)
                    writeSTRef prevCum $! c

        -- pin the last kept cumulative weight to exactly 1
        kept <- (V.length ds -) <$> readSTRef dropped
        (_, finalEvent) <- MV.read buf (kept - 1)
        MV.write buf (kept - 1) (1, finalEvent)
        Categorical <$> V.unsafeFreeze (MV.unsafeSlice 0 kept buf)
    where
        total  = totalWeight orig
        factor = recip total

#if __GLASGOW_HASKELL__ < 706
-- |Strict 'modifySTRef': the new value is forced to WHNF before it is written,
-- so no chain of thunks builds up inside the reference. Only defined for old
-- compilers whose "Data.STRef" lacks it.
modifySTRef' :: STRef s a -> (a -> a) -> ST s ()
modifySTRef' ref f = readSTRef ref >>= \old -> writeSTRef ref $! f old
#endif

-- |Simplify a categorical distribution by combining equivalent events (the new
-- event will have a probability equal to the sum of all the originals).
--
-- Events are identified with 'compare'. Each merged group keeps the first
-- event of the group as its representative.
collectEvents :: (Ord e, Num p, Ord p) => Categorical p e -> Categorical p e
collectEvents = collectEventsBy compare merge
    where
        -- groups handed over by 'collectEventsBy' are never empty
        merge grp = (sum (map fst grp), snd (head grp))

-- |Simplify a categorical distribution by combining equivalent events (the new
-- event will have a weight equal to the sum of all the originals).
-- The comparator function is used to identify events to combine.  Once chosen,
-- the events and their weights are combined by the provided probability and
-- event aggregation function.
--
-- The processing steps are:
--
-- 1. Extract the individual weights with 'toList'.
-- 2. Stable-sort by event using the comparator.
-- 3. Group adjacent events that compare 'EQ'.
-- 4. Combine each group.
-- 5. Rebuild the distribution with 'fromList'.
collectEventsBy :: Num p => (e -> e -> Ordering) -> ([(p,e)] -> (p,e))-> Categorical p e -> Categorical p e
collectEventsBy compareE combine cat = fromList (map combine grouped)
    where
        grouped = groupBy sameEvent (sortBy byEvent (toList cat))
        byEvent (_, a) (_, b) = compareE a b
        sameEvent x y = byEvent x y == EQ