{-# LANGUAGE BangPatterns #-}
-----------------------------------------------------------------------------
-- |
-- Module : Numeric.Probability.Distribution
-- Copyright : (c) William Yager 2015
-- License : MIT
-- Maintainer : will (dot) yager (at) gmail (dot) com
-- Stability : provisional
-- Portability : portable
--
-- This module provides a data structure and associated functions for
-- representing discrete probability distributions.
--
-- All time and space complexity metrics are given in terms of @n@. In this
-- case, @n@ refers to the number of unique outcomes inserted into the tree.
-- If one were to construct a tree by inserting a billion of the same
-- outcome, @n@ would still be 1.
--
-- The data structure is optimized for fast sampling from the distribution.
-- Sampling ranges from @O(1)@ to @O(log(n))@ depending on the distribution.
--
-- Under the hood, the distribution is represented by a perfectly balanced
-- binary tree. The tree enforces a heap property, where more likely outcomes
-- are closer to the top than less likely outcomes. Because we're more
-- likely to sample from those outcomes, we minimize the amount of time
-- spent traversing the tree.
--
-- When a duplicate outcome is inserted into the tree, the tree's "dups"
-- counter is incremented. When more than half the tree is duplicate entries,
-- the entire tree is rebuilt from scratch. Using amortized complexity
-- analysis, we can show that insertion has, at worst, @O(log(n))@ amortized
-- complexity. This prevents the size of the tree from increasing to more than
-- @O(n)@, even with many duplicate outcomes inserted.
-----------------------------------------------------------------------------

module Numeric.Probability.Distribution (
    -- * The distribution type
    Distribution,
    -- * Probability operations
    sample,
    cumulate,
    normalize,
    lookup,
    -- * Building
    empty,
    insert,
    fromList,
    -- * Modifying
    reweight,
    -- * Reducing
    toList,
    foldrWithP,
    -- * Combining
    joint,
    sum,
    -- * Debugging
    invariants
) where

import Prelude hiding (product, sum, lookup)
import Control.Monad.Random (MonadRandom, Random, getRandomRs)
import Data.Word (Word)
import qualified Data.Map.Strict as Map
import Data.List (foldl')

-- | A probability distribution with probabilities of type @p@ and
-- outcomes/events of type @o@.
data Distribution p o = Distribution
    !(DTree p o)   -- Used for sampling
    !(Map.Map o p) -- Used for looking up probability by outcome and for deduping
    !Word          -- Duplicate entry count

-- Internal sampling tree. Every node caches the total weight and the entry
-- count of the subtree rooted at it.
data DTree p o
    = Leaf
    | DTree
        !o           -- Some outcome
        !p           -- Probability weight of this outcome
        !p           -- Total probability weight contained in this subtree
        !Word        -- Number of outcomes (not necessarily distinct) in this subtree
        !(DTree p o) -- Left child
        !(DTree p o) -- Right child

-- Weight of the outcome stored at the root of a subtree; 0 for a 'Leaf'.
probOf :: Num p => DTree p o -> p
probOf Leaf = 0
probOf (DTree _ p _ _ _ _) = p

-- Cached total weight of a subtree; 0 for a 'Leaf'.
sumOf :: Num p => DTree p o -> p
sumOf Leaf = 0
sumOf (DTree _ _ s _ _ _) = s

-- Cached number of entries in a subtree; 0 for a 'Leaf'.
countOf :: DTree p o -> Word
countOf Leaf = 0
countOf (DTree _ _ _ c _ _) = c

-- | Rendered as @fromList [(outcome, probability), ...]@ using 'toList'.
-- The output is wrapped in parentheses when shown in an argument position
-- (precedence above 10), so that e.g. @show (Just dist)@ produces
-- @Just (fromList [...])@ rather than the ambiguous @Just fromList [...]@.
instance (Num p, Show p, Ord o, Show o) => Show (Distribution p o) where
    showsPrec d dist = showParen (d > 10) $
        showString "fromList " . shows (toList dist)

-- | Reweights the probabilities in the distribution based on
-- the given function. The function receives each distinct outcome together
-- with its accumulated probability. @O(n*log(n))@
reweight :: (Ord o, Num p, Ord p) => (o -> p -> p) -> Distribution p o -> Distribution p o
reweight f = fromUniqList . map update . toList
  where
    update (o, p) = (o, f o p)

-- | The sum of all probabilities in the distribution. @O(1)@
cumulate :: (Num p) => Distribution p o -> p
cumulate (Distribution tree _ _) = sumOf tree

-- | Normalizes the distribution.
-- After normalizing, @'cumulate' distribution@ is 1.
-- @O(n)@
-- Returns 'Nothing' if the distribution is empty.
normalize :: (Fractional p) => Distribution p o -> Maybe (Distribution p o)
normalize (Distribution tree members dups) = case tree of
    Leaf -> Nothing
    DTree _ _ total _ _ _ ->
        Just (Distribution (scale tree) (Map.map (/ total) members) dups)
      where
        -- Divide every node weight and every cached subtree sum by the
        -- root total; entry counts and tree shape are left untouched.
        scale Leaf = Leaf
        scale (DTree o p s c l r) =
            DTree o (p / total) (s / total) c (scale l) (scale r)

-- | Insert an outcome into the distribution.
-- Inserting @(o,p1)@ and @(o,p2)@ results in the same sampled distribution as
-- inserting @(o,p1+p2)@. Pairs with a zero weight are ignored.
-- @O(log(n))@ amortized.
--
-- Whenever twice the duplicate counter exceeds the number of entries the
-- tree held before this insertion, the distribution is rebuilt from its
-- de-duplicated outcome list.
insert :: (Ord o, Num p, Ord p) => (o,p) -> Distribution p o -> Distribution p o
insert (_, 0) dist = dist
insert (outcome, weight) (Distribution tree members dups)
    | 2 * newDups <= countOf tree = updated                      -- Not too many repeated elements
    | otherwise                   = fromUniqList (toList updated)
  where
    newDups
        | outcome `Map.member` members = dups + 1
        | otherwise                    = dups
    updated = Distribution
        (insertTree (outcome, weight) tree)
        (Map.insertWith (+) outcome weight members)
        newDups

-- | The empty distribution. @O(1)@
empty :: (Num p) => Distribution p o
empty = Distribution Leaf Map.empty 0

-- Collapses a list of pairs into a map, adding up the weights of equal
-- outcomes.
reduce :: (Ord o, Num p) => [(o,p)] -> Map.Map o p
reduce = Map.fromListWith (+)

-- | Builds a distribution from a list of (outcome, weight) pairs; weights of
-- repeated outcomes are summed first. @O(n*log(n))@
fromList :: (Ord o, Num p, Ord p) => [(o,p)] -> Distribution p o
fromList pairs = fromUniqList (Map.toList (reduce pairs))

-- | Lists each distinct outcome once, paired with its total weight, in
-- ascending order of outcome. @O(n*log(n))@
toList :: (Ord o, Num p) => Distribution p o -> [(o,p)]
toList dist = Map.toList (reduce (toRepeatList dist))

-- | Doesn't bother to remove duplicates. @O(n*log(n))@ amortized.
fromUniqList :: (Ord o, Num p, Ord p) => [(o,p)] -> Distribution p o
fromUniqList = foldl' (flip insert) empty

-- | Doesn't bother to eliminate repeats.
-- @O(n)@
toRepeatList :: Distribution p o -> [(o,p)]
toRepeatList dist = foldrWithP (:) [] dist

-- | A right-associative fold on the tree structure, including the
-- probabilities. Note that outcomes may be repeated within the data structure.
-- If you want identical outcomes to be lumped together, fold on the list
-- produced by @'toList'@. @O(n)@.
foldrWithP :: ((o,p) -> b -> b) -> b -> Distribution p o -> b
foldrWithP f b (Distribution tree _ _) = foldrTreeWithP f b tree

-- Right fold over the raw tree: the left subtree is visited first, then the
-- node itself, then the right subtree.
foldrTreeWithP :: ((o,p) -> b -> b) -> b -> DTree p o -> b
foldrTreeWithP step seed tree = go seed tree
  where
    go acc Leaf = acc
    go acc (DTree o p _ _ l r) = go (step (o, p) (go acc r)) l

-- Inserts one entry into the tree. The heavier of the new entry and the
-- current root stays at this node and the other one is pushed down into a
-- child: the left child when it holds fewer entries than the right one,
-- otherwise the right child. The cached sum and count are updated on the way.
insertTree :: (Num p, Ord p) => (o,p) -> DTree p o -> DTree p o
insertTree new@(newO, newP) node = case node of
    Leaf -> DTree newO newP newP 1 Leaf Leaf
    DTree o p s c l r
        | countOf l < countOf r -> DTree topO topP s' c' (insertTree sunk l) r
        | otherwise             -> DTree topO topP s' c' l (insertTree sunk r)
      where
        s' = s + newP
        c' = c + 1
        -- Which entry stays at this node and which one sinks further down.
        ((topO, topP), sunk)
            | newP <= p = ((o, p), new)
            | otherwise = (new, (o, p))

-- | Creates a new distribution that's the joint distribution of the two provided.
-- @O(nm*log(nm))@ amortized.
joint :: (Ord o1, Ord o2, Num p, Ord p) => Distribution p o1 -> Distribution p o2 -> Distribution p (o1, o2)
joint da db = fromList pairs
  where
    outcomesA = toList da
    outcomesB = toList db
    pairs = [ ((a, b), pa * pb) | (a, pa) <- outcomesA, (b, pb) <- outcomesB ]

-- | Creates a new distribution by summing the probabilities of the outcomes
-- in the two provided. @O((n+m)log(n+m))@ amortized.
sum :: (Ord o, Num p, Ord p) => Distribution p o -> Distribution p o -> Distribution p o
sum da db = fromList (toRepeatList da ++ toRepeatList db)

-- Returns random value in range (0,n]
-- Draws from the stream produced by @getRandomRs (0,n)@ and takes the first
-- value that is not zero.
randomPositiveUpto :: (Eq n, Num n, Random n, MonadRandom m) => n -> m n
randomPositiveUpto n = do
    candidates <- getRandomRs (0, n)
    return (head (filter (/= 0) candidates))

-- | Given an outcome, returns the probability. Note that the probability
-- is not always normalized.
-- If you want the probability to be in the 0-1
-- range, you should divide it by @cumulate dist@ (the sum of the probability
-- of all outcomes). Outcomes that are not present have probability 0.
lookup :: (Ord o, Num p) => Distribution p o -> o -> p
lookup (Distribution _ members _) outcome = Map.findWithDefault 0 outcome members

-- | Take a sample from the distribution. Can be used with e.g. @evalRand@
-- or @evalRandIO@ from @Control.Monad.Random@. @O(log(n))@ for a uniform
-- distribution (worst case), but approaches @O(1)@ with less balanced
-- distributions.
-- Returns Nothing on an empty distribution.
sample :: (Ord p, Num p, Random p, MonadRandom m) => Distribution p o -> m (Maybe o)
sample (Distribution tree _ _) = sampleTree tree

-- Samples from a subtree. A fresh index is drawn from (0, subtree total] at
-- every level; it selects the left subtree, this node's own outcome, or the
-- right subtree, according to their weights.
--
-- The cached subtree total is accumulated incrementally on insertion, so with
-- inexact (floating point) weights it can differ slightly from
-- @sumOf l + prob + sumOf r@. The drawn index may then point into a child
-- that is actually a 'Leaf'. A 'Leaf' carries no outcome, so in that case
-- this node's own outcome is returned instead of 'Nothing'; only a genuinely
-- empty tree yields 'Nothing'.
sampleTree :: (Ord p, Num p, Random p, MonadRandom m) => DTree p o -> m (Maybe o)
sampleTree Leaf = return Nothing
sampleTree (DTree event prob sum' _ l r) = do
    index <- randomPositiveUpto sum'
    let leftSum = sumOf l
        here = return (Just event)
        -- Descend into a child, unless it is empty.
        descend Leaf  = here
        descend child = sampleTree child
        result
            | index > leftSum + prob = descend r
            | index > leftSum        = here
            | otherwise              = descend l
    result

-- Checks that every node's cached count equals the size of its subtree and
-- that the sizes of its two children differ by at most one.
sizeInvariant :: (Num p, Eq p) => DTree p o -> Either String ()
sizeInvariant Leaf = Right ()
sizeInvariant (DTree _ _ _ c l r)
    | c /= countOf l + countOf r + 1 = Left "Count mismatch"
    | countOf l > countOf r + 1      = Left "Left is too heavy"
    | countOf r > countOf l + 1      = Left "Right is too heavy"
    | otherwise                      = sizeInvariant l >> sizeInvariant r

-- Currently always succeeds; the real check is disabled (see below).
sumInvariant :: (Show p, Num p, Eq p) => DTree p e -> Either String ()
sumInvariant Leaf = Right ()
sumInvariant _ = Right () -- Fails with floating point numbers, due to very small errors
--sumInvariant (DTree e p s c l r)
--    | (s /= p + sumOf l + sumOf r) = Left $ "Sum mismatch:" ++ show [s,p,sumOf l,sumOf r]
--    | otherwise = (sumInvariant l) >> (sumInvariant r)

-- Checks the heap property: no node is lighter than the root of either of
-- its children.
heapInvariant :: (Ord p, Num p) => DTree p e -> Either String ()
heapInvariant Leaf = Right ()
heapInvariant (DTree _ p _ _ l r)
    | p < probOf l = Left "Heap violation on left"
    | p < probOf r = Left "Heap violation on right"
    | otherwise    = heapInvariant l >> heapInvariant r

-- Checks that no entry with weight zero is stored in the tree.
zeroInvariant :: (Ord p, Num p) => DTree p e -> Either String ()
zeroInvariant Leaf = Right ()
zeroInvariant (DTree _ p _ _ l r)
    | p == 0    = Left "Zero value in tree"
    | otherwise = zeroInvariant l >> zeroInvariant r

-- Currently always succeeds; the real check is disabled (see below).
memberInvariant :: (Eq p, Num p, Ord o) => Distribution p o -> Either String ()
memberInvariant _ = Right () -- Fails with floating numbers due to rounding error
--memberInvariant dist@(Distribution _ members _)
--    | reduce (toList dist) == members = Right ()
--    | otherwise = Left $ "Reduction doesn't match member map"

-- | A series of tests on the internal structure of the distribution.
-- For debugging purposes. Returns @Left@ with a description of the first
-- violated invariant, or @Right ()@ if all checks pass.
invariants :: (Num p, Ord p, Show p, Ord e, Show e) => Distribution p e -> Either String ()
invariants dist@(Distribution tree _ _) = do
    sizeInvariant tree
    sumInvariant tree
    heapInvariant tree
    zeroInvariant tree
    memberInvariant dist