{-
 -      ``Data/Random/Distribution/Discrete''
 -}
{-# LANGUAGE
    MultiParamTypeClasses,
    FlexibleInstances, FlexibleContexts
  #-}

module Data.Random.Distribution.Discrete where

import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Uniform
import Data.Random.List (randomElement)

import Control.Arrow
import Control.Monad
import Control.Applicative
import Data.Foldable (Foldable(foldMap))
import Data.Traversable (Traversable(traverse, sequenceA))

import Data.List
import Data.Function

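-- |A random variable drawing from the discrete distribution given by the
-- supplied list of (weight, value) pairs.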
discrete :: Distribution (Discrete p) a => [(p,a)] -> RVar a
discrete ps = rvar (Discrete ps)

-- |Construct a discrete distribution from a list of observed values: each
-- distinct value becomes one event, weighted by its number of occurrences.
empirical :: (Num p, Ord a) => [a] -> Discrete p a
empirical xs = Discrete bins
    where bins = [ (genericLength bin, x)
                 | bin@(x:_) <- group (sort xs)
                 ]
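
-- An illustrative example (values chosen arbitrarily):
--
-- > empirical "mississippi" == Discrete [(4,'i'),(1,'m'),(2,'p'),(4,'s')]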

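-- |A discrete distribution over an explicit, finite set of events, each
-- given as a (weight, value) pair.  The weights need not sum to 1; sampling
-- divides by their total, and 'normalizeDiscreteWeights' rescales them
-- explicitly.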
newtype Discrete p a = Discrete [(p, a)]
    deriving (Eq, Show)

instance (Num p, Ord p, Distribution Uniform p) => Distribution (Discrete p) a where
    rvar (Discrete []) = fail "discrete distribution over empty set cannot be sampled"
    rvar (Discrete ds) = do
        let (ps, xs) = unzip ds
            cs = scanl1 (+) ps
        when (any (<0) ps) $ fail "negative probability in discrete distribution"
        let totalWeight = last cs
        if  totalWeight <= 0
            -- if all events are "equally impossible", just pick an arbitrary one.
            then randomElement xs
            else do
                let getU = do
                        u <- uniform 0 totalWeight
                        -- reject 0; this causes integral weights to behave sensibly (although it 
                        -- potentially wastes up to 50% of all sampled integral values in the
                        -- case where totalWeight = 1) and is still valid for fractional weights
                        -- (it only prevents zero-probability events from ever occurring, which 
                        -- is reasonable).
                        if u == 0 then getU
                                  else return u
                u <- getU
                return $ head
                    [ x
                    | (c,x) <- zip cs xs
                    , c >= u
                    ]
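
-- An illustrative trace of the selection step above (weights chosen
-- arbitrarily): with weights [1,2,3] attached to 'a', 'b', 'c', the
-- cumulative sums are scanl1 (+) [1,2,3] == [1,3,6], and a draw of u = 3
-- from (0,6] selects the first event whose cumulative weight reaches u:
--
-- > head [ x | (c,x) <- zip [1,3,6] "abc", c >= 3 ]  ==  'b'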

instance Functor (Discrete p) where
    fmap f (Discrete ds) = Discrete [(p, f x) | (p, x) <- ds]

instance Foldable (Discrete p) where
    foldMap f (Discrete ds) = foldMap (f . snd) ds

instance Traversable (Discrete p) where
    traverse f (Discrete ds) = Discrete <$> traverse (\(p,e) -> (\e -> (p,e)) <$> f e) ds
    sequenceA  (Discrete ds) = Discrete <$> traverse (\(p,e) -> (\e -> (p,e)) <$>   e) ds

-- For (>>=), each group of cases in 'f x' arising from a single case of 'x'
-- should keep the same relative weight as the case of 'x' it came from;
-- multiplying the two weights accomplishes exactly that.
instance Num p => Monad (Discrete p) where
    return x = Discrete [(1, x)]
    (Discrete x) >>= f = Discrete $ do
        (p, x) <- x
        let Discrete fx = f x
        (q, x) <- fx
        return (p * q, x)
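
-- A worked example of the weighting above (values chosen arbitrarily):
-- binding a fair two-way choice into outcome-dependent distributions
-- multiplies the weights, so each branch keeps its relative weight.
--
-- > (Discrete [(1,True),(1,False)] >>= \heads ->
-- >     if heads then Discrete [(3,'x'),(1,'y')] else return 'z')
-- >         ==  Discrete [(3,'x'),(1,'y'),(1,'z')]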

instance Num p => Applicative (Discrete p) where
    pure = return
    (<*>) = ap

-- |Like 'fmap', but for the weights of a discrete distribution.
mapDiscreteWeights :: (p -> q) -> Discrete p e -> Discrete q e
mapDiscreteWeights f (Discrete ds) = Discrete [(f p, x) | (p, x) <- ds]

-- |Adjust all the weights of a discrete distribution so that they 
-- sum to unity.  If not possible, returns the original distribution 
-- unchanged.
normalizeDiscreteWeights :: (Fractional p) => Discrete p e -> Discrete p e
normalizeDiscreteWeights orig@(Discrete ds) =
    -- The total weight must be computed to obtain the scale factor anyway,
    -- so check whether it is 0 (normalization impossible) or 1 (already
    -- normalized) and, if so, skip the actual scaling.
    if ws `elem` [0,1]
        then orig
        else Discrete
                [ (w * scale, e)
                | (w, e) <- ds
                ]
    where
        ws = sum (map fst ds)
        scale = recip ws
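
-- An illustrative example (values chosen arbitrarily):
--
-- > normalizeDiscreteWeights (Discrete [(1,'a'),(3,'b')])
-- >     == Discrete [(0.25,'a'),(0.75,'b')]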

-- |Simplify a discrete distribution by combining equivalent events (the new
-- event will have a weight equal to the sum of all the originals).
collectDiscreteEvents :: (Ord e, Num p, Ord p) => Discrete p e -> Discrete p e
collectDiscreteEvents = collectDiscreteEventsBy compare sum head

-- |Simplify a discrete distribution by combining equivalent events (the new
-- event will have a weight equal to the sum of all the originals).
-- The comparator function is used to identify events to combine.  Once chosen,
-- the events and their weights are combined (independently) by the provided
-- weight and event aggregation functions.
collectDiscreteEventsBy :: (e -> e -> Ordering) -> ([p] -> p) -> ([e] -> e) -> Discrete p e -> Discrete p e
collectDiscreteEventsBy compareE sumWeights mergeEvents (Discrete ds) =
    Discrete . map ((sumWeights *** mergeEvents) . unzip) . groupEvents . sortEvents $ ds
    where
        groupEvents = groupBy (\x y -> snd x `compareE` snd y == EQ)
        sortEvents  = sortBy (compareE `on` snd)
        weight (p,x)
            | p < 0     = error "negative probability in discrete distribution"
            | otherwise = p
        event ((p,x):_) = x
        combine (ps, xs) = (sumWeights ps, mergeEvents xs)
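
-- An illustrative example (values chosen arbitrarily): equal events are
-- merged and their weights summed.
--
-- > collectDiscreteEvents (Discrete [(1,'a'),(2,'b'),(3,'a')])
-- >     == Discrete [(4,'a'),(2,'b')]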