{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, FlexibleContexts, CPP #-}
{-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-}
module Data.Random.Distribution.Categorical
    ( Categorical
    , categorical, categoricalT
    , weightedCategorical, weightedCategoricalT
    , fromList, toList, totalWeight, numEvents
    , fromWeightedList, fromObservations
    , mapCategoricalPs, normalizeCategoricalPs
    , collectEvents, collectEventsBy
    ) where

import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Uniform

import Control.Arrow
import Control.Monad
import Control.Monad.ST
import Data.Foldable (Foldable(foldMap))
import Data.STRef
import Data.Traversable (Traversable(traverse, sequenceA))
import Data.List
import Data.Function
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV

-- |Construct a 'Categorical' random variable from a list of probabilities
-- and categories, where the probabilities all sum to 1.
categorical :: (Num p, Distribution (Categorical p) a) => [(p,a)] -> RVar a
categorical = rvar . fromList

-- |Construct a 'Categorical' random process from a list of probabilities
-- and categories, where the probabilities all sum to 1.
categoricalT :: (Num p, Distribution (Categorical p) a) => [(p,a)] -> RVarT m a
categoricalT = rvarT . fromList

-- |Construct a 'Categorical' random variable from a list of weights
-- and categories. The weights do /not/ have to sum to 1.
weightedCategorical :: (Fractional p, Eq p, Distribution (Categorical p) a) => [(p,a)] -> RVar a
weightedCategorical = rvar . fromWeightedList

-- |Construct a 'Categorical' random process from a list of weights
-- and categories. The weights do /not/ have to sum to 1.
weightedCategoricalT :: (Fractional p, Eq p, Distribution (Categorical p) a) => [(p,a)] -> RVarT m a
weightedCategoricalT = rvarT . fromWeightedList

-- |Construct a 'Categorical' distribution from a list of weighted categories.
{-# INLINE fromList #-}
fromList :: (Num p) => [(p,a)] -> Categorical p a
fromList xs = Categorical (V.fromList (scanl1 f xs))
    where f (p0, _) (p1, y) = (p0 + p1, y)

{-# INLINE toList #-}
toList :: (Num p) => Categorical p a -> [(p,a)]
toList (Categorical ds) = V.foldr' g [] ds
    where
        g x [] = [x]
        g x@(p0,_) ((p1, y):xs) = x : (p1-p0,y) : xs

totalWeight :: Num p => Categorical p a -> p
totalWeight (Categorical ds)
    | V.null ds = 0
    | otherwise = fst (V.last ds)

numEvents :: Categorical p a -> Int
numEvents (Categorical ds) = V.length ds

-- |Construct a 'Categorical' distribution from a list of weighted categories,
-- where the weights do not necessarily sum to 1.
fromWeightedList :: (Fractional p, Eq p) => [(p,a)] -> Categorical p a
fromWeightedList = normalizeCategoricalPs . fromList

-- |Construct a 'Categorical' distribution from a list of observed outcomes.
-- Equivalent events will be grouped and counted, and the probabilities of each
-- event in the returned distribution will be proportional to the number of
-- occurrences of that event.
fromObservations :: (Fractional p, Eq p, Ord a) => [a] -> Categorical p a
fromObservations = fromWeightedList . map (genericLength &&& head) . group . sort
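
-- A minimal usage sketch (illustrative only, not part of the original API):
-- builds a biased three-sided die from unnormalized weights, and the same kind
-- of distribution from raw observations. The names 'exampleDie',
-- 'exampleCounts', and 'exampleVar' are hypothetical; 'rvar' turns a
-- 'Categorical' into an 'RVar' that can be sampled with the usual random-fu
-- sampling functions.
exampleDie :: Categorical Double Int
exampleDie = fromWeightedList [(1, 1), (2, 2), (3, 3)]
-- toList exampleDie ~ [(1/6,1),(1/3,2),(1/2,3)], up to floating-point rounding

exampleCounts :: Categorical Double Char
exampleCounts = fromObservations "abacca"
-- 'a' occurs 3 times, 'c' twice, 'b' once, so the weights are 1/2, 1/3, 1/6

exampleVar :: RVar Int
exampleVar = rvar exampleDie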

-- The following description refers to the public interface. For those reading
-- the code, in the actual implementation Categorical is stored as a vector of
-- (cumulative-probability, value) pairs, so that sampling can take advantage of
-- binary search.

-- |Categorical distribution; a list of events with corresponding probabilities.
-- The sum of the probabilities must be 1, and no event should have a zero
-- or negative probability (at least, at time of sampling; very clever users
-- can do what they want with the numbers before sampling, just make sure
-- that if you're one of those clever ones, you at least eliminate negative
-- weights before sampling).
newtype Categorical p a = Categorical (V.Vector (p, a))
    deriving Eq

instance (Num p, Show p, Show a) => Show (Categorical p a) where
    showsPrec p cat = showParen (p > 10)
        ( showString "fromList "
        . showsPrec 11 (toList cat)
        )

instance (Num p, Read p, Read a) => Read (Categorical p a) where
    readsPrec p = readParen (p > 10) $ \str -> do
        ("fromList", valStr) <- lex str
        (vals, rest) <- readsPrec 11 valStr
        return (fromList vals, rest)

instance (Fractional p, Ord p, Distribution Uniform p) => Distribution (Categorical p) a where
    rvarT (Categorical ds)
        | V.null ds = error "categorical distribution over empty set cannot be sampled"
        | n == 1    = return (snd (V.head ds))
        | otherwise = do
            u <- uniformT 0 (fst (V.last ds))

            let -- by construction, p is monotone; (i < j) ==> (p i <= p j)
                p i = fst (ds V.! i)
                x i = snd (ds V.! i)

                -- findEvent
                -- =========
                -- invariants: (i <= j), (u <= p j), ((i == 0) || (p i < u))
                -- (the last one means 'i' does not increase unless it bounds 'p' below 'u')
                -- variant: either i increases or j decreases.
                -- upon termination: ∀ k. if (k < j) then (p k < u) else (u <= p k)
                -- (that is, the chosen event 'x j' is the first one whose
                -- associated cumulative probability 'p j' is greater than
                -- or equal to 'u')
                findEvent i j
                    | j <= i    = x j
                    | u <= p m  = findEvent i m
                    | otherwise = findEvent (max m (i+1)) j
                    where
                        -- midpoint rounding down
                        -- (i < j) ==> (m < j)
                        m = (i + j) `div` 2

            return $! if u <= 0 then x 0 else findEvent 0 (n-1)
        where n = V.length ds

instance Functor (Categorical p) where
    fmap f (Categorical ds) = Categorical (V.map (second f) ds)

instance Foldable (Categorical p) where
    foldMap f (Categorical ds) = foldMap (f . snd) (V.toList ds)

instance Traversable (Categorical p) where
    traverse f (Categorical ds) = Categorical . V.fromList <$> traverse (\(p,e) -> (\e' -> (p,e')) <$> f e) (V.toList ds)
    sequenceA  (Categorical ds) = Categorical . V.fromList <$> traverse (\(p,e) -> (\e' -> (p,e')) <$> e    ) (V.toList ds)

instance Fractional p => Monad (Categorical p) where
    return x = Categorical (V.singleton (1, x))

    -- I'm not entirely sure whether this is a valid form of failure; see next
    -- set of comments.
#if __GLASGOW_HASKELL__ < 808
    fail _ = Categorical V.empty
#endif

    -- Should the normalize step be included here, or should normalization
    -- be assumed?  It seems like there is (at least) 1 valid situation where
    -- non-normal results would arise: the distribution being modeled is
    -- "conditional" and some event arose that contradicted the assumed
    -- condition and thus was eliminated ('f' returned an empty or
    -- zero-probability consequent, possibly by 'fail'ing).
    --
    -- It seems reasonable to continue in such circumstances, but should there
    -- be any renormalization?  If so, does it make a difference when that
    -- renormalization is done?  I'm pretty sure it does, actually.  So, the
    -- normalization will be omitted here for now, as it's easier for the
    -- user (who really better know what they mean if they're returning
    -- non-normalized probability anyway) to normalize explicitly than to
    -- undo any normalization that was done automatically.
    xs >>= f = {- normalizeCategoricalPs . -} fromList $ do
        (p, x) <- toList xs
        (q, y) <- toList (f x)

        return (p * q, y)
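
-- A small illustrative sketch (not part of the original module) of what the
-- 'Monad' instance above computes: binding a distribution over coins into a
-- coin-dependent distribution over outcomes yields the joint (mixture)
-- distribution, with probabilities multiplied along each path. The names
-- 'exampleCoin' and 'exampleMixture' are hypothetical.
exampleCoin :: Categorical Double Bool
exampleCoin = fromList [(0.5, True), (0.5, False)]

exampleMixture :: Categorical Double Int
exampleMixture = do
    heads <- exampleCoin
    if heads
        then fromList [(1.0, 0)]
        else fromList [(0.5, 1), (0.5, 2)]
-- toList exampleMixture == [(0.5,0),(0.25,1),(0.25,2)]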

instance Fractional p => Applicative (Categorical p) where
    pure = return
    (<*>) = ap

-- |Like 'fmap', but for the probabilities of a categorical distribution.
mapCategoricalPs :: (Num p, Num q) => (p -> q) -> Categorical p e -> Categorical q e
mapCategoricalPs f = fromList . map (first f) . toList

-- |Adjust all the weights of a categorical distribution so that they
-- sum to unity and remove all events whose probability is zero.
normalizeCategoricalPs :: (Fractional p, Eq p) => Categorical p e -> Categorical p e
normalizeCategoricalPs orig@(Categorical ds)
    | ps == 0   = Categorical V.empty
    | otherwise = runST $ do
        lastP       <- newSTRef 0
        nDups       <- newSTRef 0
        normalized  <- V.thaw ds

        let n = V.length ds
            skip = modifySTRef' nDups (1+)
            save i p x = do
                d <- readSTRef nDups
                MV.write normalized (i-d) (p, x)

        sequence_
            [ do
                let (p,x) = ds V.! i
                p0 <- readSTRef lastP
                if p == p0
                    then skip
                    else do
                        save i (p * scale) x
                        writeSTRef lastP $! p
            | i <- [0..n-1]
            ]

        -- force last element to 1
        d <- readSTRef nDups
        let n' = n-d
        (_,lastX) <- MV.read normalized (n'-1)
        MV.write normalized (n'-1) (1,lastX)
        Categorical <$> V.unsafeFreeze (MV.unsafeSlice 0 n' normalized)
    where
        ps = totalWeight orig
        scale = recip ps

#if __GLASGOW_HASKELL__ < 706
-- |strict 'modifySTRef'
modifySTRef' :: STRef s a -> (a -> a) -> ST s ()
modifySTRef' x f = do
    v <- readSTRef x
    let fv = f v
    fv `seq` writeSTRef x fv
#endif

-- |Simplify a categorical distribution by combining equivalent events (the new
-- event will have a probability equal to the sum of all the originals).
collectEvents :: (Ord e, Num p, Ord p) => Categorical p e -> Categorical p e
collectEvents = collectEventsBy compare ((sum *** head) . unzip)

-- |Simplify a categorical distribution by combining equivalent events (the new
-- event will have a weight equal to the sum of all the originals).
-- The comparator function is used to identify events to combine.  Once chosen,
-- the events and their weights are combined by the provided probability and
-- event aggregation function.
collectEventsBy :: Num p => (e -> e -> Ordering) -> ([(p,e)] -> (p,e)) -> Categorical p e -> Categorical p e
collectEventsBy compareE combine = fromList . map combine . groupEvents . sortEvents . toList
    where
        groupEvents = groupBy (\x y -> snd x `compareE` snd y == EQ)
        sortEvents  = sortBy (compareE `on` snd)
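
-- A brief illustrative sketch (not part of the original module) of
-- 'collectEvents': duplicate events are merged and their weights summed.
-- The name 'exampleMerged' is hypothetical.
exampleMerged :: Categorical Double Char
exampleMerged = collectEvents (fromList [(0.25, 'a'), (0.5, 'b'), (0.25, 'a')])
-- toList exampleMerged == [(0.5,'a'),(0.5,'b')]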