{-# LANGUAGE BangPatterns #-}
module Numeric.Probability.Distribution (
Distribution,
sample,
cumulate,
normalize,
lookup,
empty,
insert,
fromList,
reweight,
toList,
foldrWithP,
joint,
sum,
invariants
) where
import Prelude hiding (product, sum, lookup)
import Control.Monad.Random (MonadRandom, Random, getRandomRs)
import Data.Word (Word)
import qualified Data.Map.Strict as Map
import Data.List (foldl')
-- | A weighted collection of outcomes @o@ with weights of type @p@.
--
-- The constructor carries three strict fields, in order:
--
--   1. the tree that 'sample' walks (an outcome may sit in several nodes),
--   2. a map from each outcome to the weight accumulated for it by 'insert',
--   3. a counter that 'insert' bumps whenever the inserted outcome is
--      already a key of that map.
data Distribution p o
    = Distribution
        !(DTree p o)
        !(Map.Map o p)
        !Word
-- | Binary tree backing a 'Distribution'.  Every field is strict.
data DTree p o
    = Leaf
      -- The empty tree.
    | DTree
        !o
        -- Outcome stored at this node.
        !p
        -- Weight of this node's own outcome.
        !p
        -- Running total maintained for the subtree rooted at this node.
        !Word
        -- Number of nodes in the subtree rooted at this node.
        !(DTree p o)
        -- Left child.
        !(DTree p o)
        -- Right child.
-- | Weight stored at the root node of a tree; @0@ for an empty tree.
probOf :: Num p => DTree p o -> p
probOf tree = case tree of
    DTree _ weight _ _ _ _ -> weight
    Leaf -> 0
-- | Subtree total recorded at the root node of a tree; @0@ for an empty tree.
sumOf :: Num p => DTree p o -> p
sumOf tree = case tree of
    DTree _ _ total _ _ _ -> total
    Leaf -> 0
-- | Node count recorded at the root node of a tree; @0@ for an empty tree.
countOf :: DTree p o -> Word
countOf tree = case tree of
    DTree _ _ _ size _ _ -> size
    Leaf -> 0
-- | Renders a distribution as @fromList [(outcome, weight), ...]@, using the
-- merged list produced by 'toList'.
--
-- The instance is written with 'showsPrec' rather than 'show' so that the
-- text is parenthesised when it appears as an argument of another shown
-- constructor (for example @Just (fromList [...])@).  At the top level
-- (precedence 0) the output is exactly @"fromList " ++ show (toList dist)@.
instance (Num p, Show p, Ord o, Show o) => Show (Distribution p o) where
    showsPrec d dist =
        -- 10 is the precedence of function application.
        showParen (d > 10) $ showString "fromList " . shows (toList dist)
-- | Rebuilds a distribution after passing every outcome and its merged
-- weight (as listed by 'toList') through the given function.
--
-- The result is assembled with 'fromUniqList', i.e. by repeated 'insert',
-- so outcomes whose new weight is @0@ are left out.
reweight :: (Ord o, Num p, Ord p) => (o -> p -> p) -> Distribution p o -> Distribution p o
reweight f dist = fromUniqList [(o, f o p) | (o, p) <- toList dist]
-- | Total weight of the distribution, read from the root of its tree
-- (@0@ when the distribution is empty).
cumulate :: (Num p) => Distribution p o -> p
cumulate dist = case dist of
    Distribution tree _ _ -> sumOf tree
-- | Divides every weight by the total recorded at the root of the tree.
--
-- Returns 'Nothing' for a distribution whose tree is empty.  Otherwise each
-- node's weight and subtree total, and every value of the per-outcome map,
-- are divided by the root total; node counts and the duplicate counter are
-- carried over unchanged.
normalize :: (Fractional p) => Distribution p o -> Maybe (Distribution p o)
normalize (Distribution tree members dups) = case tree of
    Leaf -> Nothing
    DTree _ _ total _ _ _ ->
        Just (Distribution (scaleTree tree) (Map.map scale members) dups)
      where
        scale = (/ total)
        scaleTree Leaf = Leaf
        scaleTree (DTree e p s c l r) =
            DTree e (scale p) (scale s) c (scaleTree l) (scaleTree r)
-- | Adds an outcome with the given weight to a distribution.
--
-- A weight of @0@ leaves the distribution untouched.  Otherwise the pair is
-- pushed into the tree as a new node and its weight is added to the
-- outcome's entry in the per-outcome map.  If the outcome was already a key
-- of that map the duplicate counter goes up by one; once twice that counter
-- exceeds the node count of the tree as it was before this insertion, the
-- distribution is rebuilt from its merged 'toList' so that duplicates are
-- collapsed.
insert :: (Ord o, Num p, Ord p) => (o,p) -> Distribution p o -> Distribution p o
insert (outcome, weight) dist@(Distribution tree members dups)
    | weight == 0 = dist
    | 2 * dupCount <= countOf tree = grown
    | otherwise = fromUniqList (toList grown)
  where
    dupCount
        | Map.member outcome members = dups + 1
        | otherwise = dups
    grown =
        Distribution
            (insertTree (outcome, weight) tree)
            (Map.insertWith (+) outcome weight members)
            dupCount
-- | The distribution with no outcomes: an empty tree, an empty outcome map
-- and a duplicate counter of zero.
empty :: (Num p) => Distribution p o
empty =
    Distribution
        Leaf
        Map.empty
        0
-- | Collapses a list of (outcome, weight) pairs into a map, adding together
-- the weights of pairs that share an outcome.
reduce :: (Ord o, Num p) => [(o,p)] -> Map.Map o p
reduce = Map.fromListWith (+)
-- | Builds a distribution from (outcome, weight) pairs.  Pairs that share an
-- outcome are merged with 'reduce' first, so each outcome is inserted once.
fromList :: (Ord o, Num p, Ord p) => [(o,p)] -> Distribution p o
fromList pairs = fromUniqList (Map.toList (reduce pairs))
-- | Lists each outcome once, in ascending outcome order, paired with the sum
-- of the weights of all tree nodes that hold it.
toList :: (Ord o, Num p) => Distribution p o -> [(o,p)]
toList dist = Map.toList (reduce (toRepeatList dist))
-- | Builds a distribution by inserting the pairs one at a time, left to
-- right, starting from 'empty'.  The name suggests callers pass a list whose
-- outcomes are already distinct; nothing here checks that.
fromUniqList :: (Ord o, Num p, Ord p) => [(o,p)] -> Distribution p o
fromUniqList = foldl' (flip insert) empty
-- | Lists the (outcome, weight) pair of every tree node, without merging
-- nodes that hold the same outcome.
toRepeatList :: Distribution p o -> [(o,p)]
toRepeatList dist = foldrWithP (:) [] dist
-- | Right fold over the (outcome, weight) pairs held in the distribution's
-- tree.  Pairs are visited node by node, so an outcome stored in several
-- nodes is presented several times.
foldrWithP :: ((o,p) -> b -> b) -> b -> Distribution p o -> b
foldrWithP step seed dist = case dist of
    Distribution tree _ _ -> foldrTreeWithP step seed tree
-- | Right fold over a tree in left-subtree, node, right-subtree order.
foldrTreeWithP :: ((o,p) -> b -> b) -> b -> DTree p o -> b
foldrTreeWithP step = go
  where
    go acc Leaf = acc
    -- Fold the right subtree first, then this node, then the left subtree,
    -- so the combined result reads left to right.
    go acc (DTree o p _ _ l r) = go (step (o, p) (go acc r)) l
-- | Inserts an (outcome, weight) pair into a tree as a new node.
--
-- At each node the pair with the larger weight stays at the node (the
-- resident wins ties) and the other pair continues downwards.  The
-- descending pair goes into the left child when the left child has strictly
-- fewer nodes than the right one, and into the right child otherwise.
-- Every node on the path has the inserted weight added to its total and its
-- node count increased by one.
insertTree :: (Num p, Ord p) => (o,p) -> DTree p o -> DTree p o
insertTree (newO, newP) Leaf = DTree newO newP newP 1 Leaf Leaf
insertTree new@(_, newP) (DTree o p s c l r)
    | countOf l < countOf r = DTree rootO rootP total size (insertTree sunk l) r
    | otherwise = DTree rootO rootP total size l (insertTree sunk r)
  where
    total = s + newP
    size = c + 1
    -- Decide which pair keeps this node and which one sinks further down.
    ((rootO, rootP), sunk)
        | newP <= p = ((o, p), new)
        | otherwise = (new, (o, p))
-- | Combines two distributions into one over pairs of outcomes: every
-- outcome of the first is paired with every outcome of the second, and the
-- weight of a pair is the product of the two merged weights.
joint :: (Ord o1, Ord o2, Num p, Ord p) => Distribution p o1 -> Distribution p o2 -> Distribution p (o1, o2)
joint da db = fromList $ do
    (a, pa) <- toList da
    (b, pb) <- toList db
    return ((a, b), pa * pb)
-- | Merges two distributions over the same outcome type; the weights of
-- outcomes that occur in both (or repeatedly in either) are added together
-- by 'fromList'.
sum :: (Ord o, Num p, Ord p) => Distribution p o -> Distribution p o -> Distribution p o
sum da db = fromList (concatMap toRepeatList [da, db])
-- | Draws random values from the range @(0, n)@ handed to 'getRandomRs' and
-- returns the first one that is not equal to @0@.
--
-- 'head' is applied to the stream returned by 'getRandomRs' after dropping
-- zeros; if every drawn value equals @0@ this never produces a result.
randomPositiveUpto :: (Eq n, Num n, Random n, MonadRandom m) => n -> m n
randomPositiveUpto n = firstNonZero <$> getRandomRs (0, n)
  where
    firstNonZero = head . dropWhile (== 0)
-- | Accumulated weight recorded for an outcome in the distribution's
-- per-outcome map, or @0@ when the outcome is not a key of that map.
lookup :: (Ord o, Num p) => Distribution p o -> o -> p
lookup (Distribution _ members _) outcome =
    Map.findWithDefault 0 outcome members
-- | Randomly picks an outcome by walking the distribution's tree with
-- 'sampleTree'.  Yields 'Nothing' when the tree is empty.
sample :: (Ord p, Num p, Random p, MonadRandom m) => Distribution p o -> m (Maybe o)
sample dist = case dist of
    Distribution tree _ _ -> sampleTree tree
-- | Randomly picks an outcome from a tree; 'Nothing' for an empty tree.
--
-- At each node a fresh non-zero value is drawn via 'randomPositiveUpto'
-- with the node's subtree total as the bound.  A value above the left
-- child's total plus the node's own weight descends to the right child; a
-- value above just the left child's total selects this node's outcome;
-- anything else descends to the left child.
sampleTree :: (Ord p, Num p, Random p, MonadRandom m) => DTree p o -> m (Maybe o)
sampleTree Leaf = return Nothing
sampleTree (DTree event prob total _ l r) =
    randomPositiveUpto total >>= pick
  where
    leftMass = sumOf l
    pick index
        | index > leftMass + prob = sampleTree r
        | index > leftMass = return (Just event)
        | otherwise = sampleTree l
-- | Checks the node-count bookkeeping of a tree.
--
-- For every node the stored count must equal the two children's counts plus
-- one, and neither child may hold more than one node more than the other.
-- The first violation found is reported as a 'Left' message.
sizeInvariant :: (Num p, Eq p) => DTree p o -> Either String ()
sizeInvariant Leaf = Right ()
sizeInvariant (DTree _ _ _ c l r)
    | c /= nl + nr + 1 = Left "Count mismatch"
    | nl > nr + 1 = Left "Left is too heavy"
    | nr > nl + 1 = Left "Right is too heavy"
    | otherwise = sizeInvariant l *> sizeInvariant r
  where
    nl = countOf l
    nr = countOf r
-- | Placeholder for a check of the subtree totals.  It inspects nothing
-- beyond the outermost constructor and succeeds for every tree.
sumInvariant :: (Show p, Num p, Eq p) => DTree p e -> Either String ()
sumInvariant tree = tree `seq` Right ()
-- | Checks that no node's weight is smaller than the weight at the root of
-- either of its children ('probOf' treats an empty child as weight @0@).
-- The first violation found is reported as a 'Left' message.
heapInvariant :: (Ord p, Num p) => DTree p e -> Either String ()
heapInvariant tree = case tree of
    Leaf -> Right ()
    DTree _ p _ _ l r
        | p < probOf l -> Left "Heap violation on left"
        | p < probOf r -> Left "Heap violation on right"
        | otherwise -> heapInvariant l *> heapInvariant r
-- | Checks that no node of the tree carries a weight equal to @0@.
zeroInvariant :: (Ord p, Num p) => DTree p e -> Either String ()
zeroInvariant Leaf = Right ()
zeroInvariant (DTree _ p _ _ l r) =
    if p == 0
        then Left "Zero value in tree"
        else zeroInvariant l *> zeroInvariant r
-- | Placeholder for a check relating the per-outcome map to the tree.  It
-- ignores its argument and always succeeds.
memberInvariant :: (Eq p, Num p, Ord o) => Distribution p o -> Either String ()
memberInvariant = const (Right ())
-- | Runs all internal consistency checks on a distribution, in the order
-- size, sum, heap, zero, member, and stops at the first one that fails.
-- Returns @Right ()@ when every check passes, otherwise the failing check's
-- message in 'Left'.
invariants :: (Num p, Ord p, Show p, Ord e, Show e) => Distribution p e -> Either String ()
invariants dist@(Distribution tree _ _) =
    sizeInvariant tree
        >> sumInvariant tree
        >> heapInvariant tree
        >> zeroInvariant tree
        >> memberInvariant dist