module Numeric.Probability.Distribution (
Distribution,
sample,
cumulate,
normalize,
empty,
insert,
fromList,
reweight,
toList,
foldrWithP,
joint,
sum,
invariants
) where
import Prelude hiding (product, sum)
import Control.Monad.Random (MonadRandom, Random, getRandomRs)
import Data.Word (Word)
import Data.Set (Set, member)
import qualified Data.Set as Set
import qualified Data.Map.Strict as Map
import Data.List (foldl')
-- | A finite distribution over outcomes of type @o@ with weights of type @p@.
-- Weights need not sum to 1; see 'normalize'.
--
-- Fields, in order:
--
-- * the weight tree;
-- * the set of outcomes that have been inserted so far;
-- * a counter of insertions whose outcome was already in that set. Such
--   insertions add another tree node for the same outcome; 'insert' rebuilds
--   the distribution, merging those nodes, once the counter grows too large.
data Distribution p o = Distribution !(DTree p o) !(Set o) !Word
-- | Weight tree: a max-heap on weight whose children's node counts differ by
-- at most one.
--
-- A node @DTree o p s c l r@ stores outcome @o@ with weight @p@, the total
-- weight @s@ of every node in its subtree, the number of nodes @c@ in its
-- subtree, and the children @l@ and @r@.
data DTree p o = Leaf
               | DTree !o !p !p !Word !(DTree p o) !(DTree p o)

-- | Outcome stored at the root node. A 'Leaf' stores no outcome, so this
-- fails on 'Leaf' with an explicit message.
outcomeOf :: DTree p o -> o
outcomeOf (DTree o _ _ _ _ _) = o
outcomeOf Leaf = error "Numeric.Probability.Distribution.outcomeOf: Leaf has no outcome"

-- | Weight of the root node; 0 for 'Leaf'.
probOf :: Num p => DTree p o -> p
probOf Leaf = 0
probOf (DTree _ p _ _ _ _) = p

-- | Total weight of the whole subtree; 0 for 'Leaf'.
sumOf :: Num p => DTree p o -> p
sumOf Leaf = 0
sumOf (DTree _ _ s _ _ _) = s

-- | Number of nodes in the subtree; 0 for 'Leaf'.
countOf :: DTree p o -> Word
countOf Leaf = 0
countOf (DTree _ _ _ c _ _) = c

-- | Left child; an empty tree has an empty child.
leftOf :: DTree p o -> DTree p o
leftOf Leaf = Leaf
leftOf (DTree _ _ _ _ l _) = l

-- | Right child; an empty tree has an empty child.
rightOf :: DTree p o -> DTree p o
rightOf Leaf = Leaf
rightOf (DTree _ _ _ _ _ r) = r
-- | Renders as @fromList [(outcome, weight), ...]@ using the merged,
-- outcome-sorted pairs from 'toList'. The text is parenthesised when shown
-- as an argument (precedence above 10), e.g. @Just (fromList [...])@, so
-- nested output stays unambiguous. Top-level output is unchanged.
instance (Num p, Show p, Ord o, Show o) => Show (Distribution p o) where
  showsPrec d dist =
    showParen (d > 10) $ showString "fromList " . shows (toList dist)
-- | Replace each outcome's merged weight with @f (outcome, weight)@.
-- Outcomes whose new weight is 0 are dropped (see 'insert').
reweight :: (Ord o, Num p, Ord p) => ((o,p) -> p) -> Distribution p o -> Distribution p o
reweight f dist = fromUniqList [ (o, f (o, p)) | (o, p) <- toList dist ]
-- | Total weight of the distribution: the subtree total stored at the root
-- of the tree, or 0 for an empty distribution.
cumulate :: (Num p) => Distribution p o -> p
cumulate dist = case dist of
  Distribution tree _ _ -> sumOf tree
-- | Rescale every weight by the distribution's total weight, so the weights
-- sum to 1 (approximately, for floating-point @p@). Errors on an empty
-- distribution. The outcome set and duplicate counter are kept as-is.
normalize :: (Fractional p) => Distribution p o -> Distribution p o
normalize (Distribution Leaf _ _) = error "Can't normalize empty distribution"
normalize (Distribution tree members dups) =
  -- For a non-empty tree, the root's subtree total is the total weight.
  Distribution (normalize' (sumOf tree) tree) members dups

-- | Divide every node's weight and subtree total by @total@. Node counts and
-- tree shape are untouched, so balance is preserved; heap order is preserved
-- as long as @total@ is positive.
normalize' :: (Fractional p) => p -> DTree p o -> DTree p o
normalize' _ Leaf = Leaf
normalize' total (DTree e p s c l r) =
  DTree e (p / total) (s / total) c (normalize' total l) (normalize' total r)
-- | Add weight @p'@ for outcome @o'@. A zero weight leaves the distribution
-- unchanged.
--
-- Every call adds a new tree node, even when @o'@ was already inserted; such
-- repeats are counted. Once the repeat count exceeds half the node count of
-- the tree before this insertion, the distribution is rebuilt through
-- 'toList' / 'fromUniqList', which merges repeated outcomes into one node
-- each and resets the counter.
insert :: (Ord o, Num p, Ord p) => (o,p) -> Distribution p o -> Distribution p o
insert (_, 0) dist = dist
insert (o', p') (Distribution tree seen dupCount)
  | dupCount' * 2 > countOf tree = fromUniqList (toList grown)
  | otherwise                    = grown
  where
    dupCount'
      | o' `member` seen = dupCount + 1
      | otherwise        = dupCount
    grown = Distribution (insertTree (o', p') tree) (Set.insert o' seen) dupCount'
-- | The distribution with no outcomes: an empty tree, no recorded outcomes
-- and a zero repeat counter.
empty :: (Num p) => Distribution p o
empty = Distribution Leaf noOutcomes 0
  where
    noOutcomes = Set.empty
-- | Collapse a list of weighted outcomes into a map from outcome to total
-- weight, adding together the weights of repeated outcomes.
-- 'Map.fromListWith' folds strictly from the left with 'Map.insertWith', so
-- the summation order is that of a left-to-right scan of the input.
reduce :: (Ord o, Num p) => [(o,p)] -> Map.Map o p
reduce = Map.fromListWith (+)
-- | Build a distribution from weighted outcomes. Repeated outcomes have their
-- weights added; outcomes whose total weight is 0 are dropped by 'insert'.
fromList :: (Ord o, Num p, Ord p) => [(o,p)] -> Distribution p o
fromList pairs = fromUniqList (Map.assocs (reduce pairs))
-- | The (outcome, weight) pairs in ascending outcome order, with weights of
-- repeated tree nodes for the same outcome added together.
toList :: (Ord o, Num p) => Distribution p o -> [(o,p)]
toList dist = Map.assocs (reduce (toRepeatList dist))
-- | Insert the pairs one by one, left to right, starting from 'empty'.
-- Intended for lists without repeated outcomes (such as 'toList' output);
-- repeats are still accepted and handled by 'insert'.
fromUniqList :: (Ord o, Num p, Ord p) => [(o,p)] -> Distribution p o
fromUniqList = foldl' (flip insert) empty
-- | Every tree node as an (outcome, weight) pair, in the order visited by
-- 'foldrWithP'. An outcome stored in several nodes appears several times.
toRepeatList :: Distribution p o -> [(o,p)]
toRepeatList dist = foldrWithP cons [] dist
  where
    cons pair rest = pair : rest
-- | Right fold over the (outcome, weight) pair of every node in the
-- distribution's tree (see 'foldrTreeWithP' for the visiting order).
foldrWithP :: ((o,p) -> b -> b) -> b -> Distribution p o -> b
foldrWithP step seed dist = case dist of
  Distribution tree _ _ -> foldrTreeWithP step seed tree
-- | Right fold over the (outcome, weight) pairs of a tree, in in-order
-- sequence: left subtree, then the node, then the right subtree.
foldrTreeWithP :: ((o,p) -> b -> b) -> b -> DTree p o -> b
foldrTreeWithP step = go
  where
    go acc Leaf = acc
    go acc (DTree o p _ _ l r) = go (step (o, p) (go acc r)) l
-- | Insert a weighted outcome into the tree, keeping it a max-heap on weight
-- and keeping the children's node counts within one of each other.
--
-- Of the new pair and the node's current pair, the one with the larger
-- weight stays at this node (the current pair on a tie). The other is pushed
-- into whichever child has fewer nodes (the right child on a tie). The
-- node's subtree total and count are updated for the added pair.
insertTree :: (Num p, Ord p) => (o,p) -> DTree p o -> DTree p o
insertTree (o', p') Leaf = DTree o' p' p' 1 Leaf Leaf
insertTree (o', p') (DTree o p s c l r)
  | countOf l < countOf r = DTree kept keptP s' c' (insertTree sunk l) r
  | otherwise             = DTree kept keptP s' c' l (insertTree sunk r)
  where
    -- A two-way split ending in 'otherwise' cannot fall through. Separate
    -- @p' <= p@ / @p' > p@ guards both fail when a weight is unordered
    -- (e.g. a Double NaN), which would be a pattern-match failure.
    ((kept, keptP), sunk)
      | p' <= p   = ((o, p), (o', p'))
      | otherwise = ((o', p'), (o, p))
    s' = s + p'
    c' = c + 1
-- | Product distribution: each pair @(a, b)@ gets weight @pa * pb@, where
-- @pa@ and @pb@ are the merged weights of @a@ and @b@.
joint :: (Ord o1, Ord o2, Num p, Ord p) => Distribution p o1 -> Distribution p o2 -> Distribution p (o1, o2)
joint da db = fromList $ do
  (a, pa) <- toList da
  (b, pb) <- toList db
  return ((a, b), pa * pb)
-- | Combine two distributions by adding their weights outcome by outcome.
sum :: (Ord o, Num p, Ord p) => Distribution p o -> Distribution p o -> Distribution p o
sum da db = fromList (concatMap toRepeatList [da, db])
-- | Draw a non-zero value using @getRandomRs (0, n)@, rejecting zeros.
-- For positive @n@ this is a value in (0, n], assuming the 'Random'
-- instance samples the closed range — NOTE(review): confirm for custom
-- weight types.
--
-- Fails with an error when @n == 0@. Every draw would then be 0, so the
-- rejection step would never produce a value and the caller would hang.
randomPositiveUpto :: (Eq n, Num n, Random n, MonadRandom m) => n -> m n
randomPositiveUpto n
  | n == 0 = error "randomPositiveUpto: upper bound must be non-zero"
  | otherwise = do
      draws <- getRandomRs (0, n)
      case dropWhile (== 0) draws of
        (x : _) -> return x
        -- getRandomRs yields an infinite stream, so this cannot happen.
        [] -> error "randomPositiveUpto: impossible, getRandomRs ended"
-- | Draw a random outcome from the distribution (see 'sampleTree'). Errors
-- on an empty distribution.
sample :: (Ord p, Num p, Random p, MonadRandom m) => Distribution p o -> m o
sample dist = case dist of
  Distribution tree _ _ -> sampleTree tree
-- | Draw one outcome from a tree, choosing each node with probability
-- proportional to its weight (assuming positive weights).
--
-- At each node a value @index@ is drawn from (0, total], where @total@ is
-- the node's stored subtree weight. Then:
--
-- * up to the left subtree's weight, the left subtree is chosen;
-- * up to that plus this node's weight, this node is chosen;
-- * otherwise the right subtree is chosen.
--
-- A chosen subtree is sampled recursively with a fresh draw.
sampleTree :: (Ord p, Num p, Random p, MonadRandom m) => DTree p o -> m o
sampleTree Leaf = error "Error: Can't sample an empty distribution"
sampleTree (DTree event prob total _ l r) = do
  index <- randomPositiveUpto total
  let leftMass = sumOf l
      choose
        | index <= leftMass        = sampleTree l
        | index <= leftMass + prob = return event
        -- With floating-point weights the stored total can exceed
        -- leftMass + prob + (right subtree total) by rounding error. A draw
        -- landing in that gap must not descend into an empty right subtree.
        | Leaf <- r                = return event
        | otherwise                = sampleTree r
  choose
-- | Check, at every node, that the stored count equals the children's counts
-- plus one, and that the children's counts differ by at most one. Returns
-- the first violation found (this node before its left and right subtrees).
sizeInvariant :: (Num p, Eq p) => DTree p o -> Either String ()
sizeInvariant Leaf = Right ()
sizeInvariant (DTree _ _ _ c l r)
  | c /= nl + nr + 1 = Left "Count mismatch"
  | nl > nr + 1      = Left "Left is too heavy"
  | nr > nl + 1      = Left "Right is too heavy"
  | otherwise        = sizeInvariant l *> sizeInvariant r
  where
    nl = countOf l
    nr = countOf r
-- | Subtree-sum invariant. Currently a no-op: it forces the tree to its
-- outermost constructor and always succeeds.
--
-- NOTE(review): the intended check is presumably
-- @s == p + sumOf l + sumOf r@ at every node. Exact equality would likely
-- fail for floating-point weights, because 'insertTree' accumulates totals
-- in insertion order — confirm before enabling.
sumInvariant :: (Show p, Num p, Eq p) => DTree p e -> Either String ()
sumInvariant tree = tree `seq` Right ()
-- | Check that no node's weight is smaller than either child's weight (an
-- empty child counts as weight 0). Checks left before right, and this node
-- before its subtrees.
heapInvariant :: (Ord p, Num p) => DTree p e -> Either String ()
heapInvariant Leaf = Right ()
heapInvariant (DTree _ parent _ _ l r) = do
  requireNotAbove l "Heap violation on left"
  requireNotAbove r "Heap violation on right"
  heapInvariant l
  heapInvariant r
  where
    requireNotAbove child msg
      | parent < probOf child = Left msg
      | otherwise             = Right ()
-- | Check that no node stores a weight equal to 0.
zeroInvariant :: (Ord p, Num p) => DTree p e -> Either String ()
zeroInvariant Leaf = Right ()
zeroInvariant (DTree _ p _ _ l r)
  | p /= 0    = zeroInvariant l *> zeroInvariant r
  | otherwise = Left "Zero value in tree"
-- | Run the tree checks in order (size/balance, subtree sums, heap order,
-- no zero weights) and return the first failure. The outcome set and the
-- repeat counter stored in the distribution are not checked.
invariants :: (Num p, Ord p, Show p, Ord e, Show e) => Distribution p e -> Either String ()
invariants (Distribution tree _ _) =
  mapM_ ($ tree) [sizeInvariant, sumInvariant, heapInvariant, zeroInvariant]