{-
 -      ``Data/Random/Distribution/Discrete''
 -}
{-# LANGUAGE
    MultiParamTypeClasses,
    FlexibleInstances, FlexibleContexts
  #-}

module Data.Random.Distribution.Discrete where

import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Uniform
import Data.Random.List (randomElement)

import Control.Arrow
import Control.Monad
import Control.Applicative
import Data.Foldable (Foldable(foldMap))
import Data.Traversable (Traversable(traverse, sequenceA))

import Data.List
import Data.Function

-- |Build a random variable from a list of @(weight, event)@ pairs by wrapping
-- the list in 'Discrete' and handing it to 'rvar'.
discrete :: Distribution (Discrete p) a => [(p,a)] -> RVar a
discrete = rvar . Discrete

-- |Build the empirical distribution of a list of observations: the input is
-- sorted and grouped, and every run of equal values becomes one event whose
-- weight is the length of that run.
empirical :: (Num p, Ord a) => [a] -> Discrete p a
empirical xs = Discrete (foldr tally [] (group (sort xs)))
    where
        -- An empty run contributes nothing (it is skipped, not an error).
        tally run@(x:_) rest = (genericLength run, x) : rest
        tally []        rest = rest

-- |A discrete distribution, represented as a list of @(weight, event)@ pairs.
-- The type itself places no restriction on the weights; they are only checked
-- when the distribution is sampled.
newtype Discrete p a = Discrete [(p, a)]
    deriving (Eq, Show)

-- |Sampling a 'Discrete' distribution selects one of its events using the
-- running totals of the weights.  Sampling fails when the event list is empty
-- or when any weight is negative.
instance (Num p, Ord p, Distribution Uniform p) => Distribution (Discrete p) a where
    rvar (Discrete []) = fail "discrete distribution over empty set cannot be sampled"
    rvar (Discrete events) = do
        when (any (<0) weights) $ fail "negative probability in discrete distribution"

        if total <= 0
            -- Every event is "equally impossible", so just pick an arbitrary one.
            then randomElement outcomes
            else liftM select drawNonZero
      where
        (weights, outcomes) = unzip events

        -- Running totals: the i-th entry is the combined weight of events 0..i,
        -- so the last entry is the weight of the whole distribution.
        cumulative = scanl1 (+) weights
        total      = last cumulative

        -- Draw with @uniform 0 total@, retrying for as long as the draw is 0.
        -- Rejecting 0 makes integral weights behave sensibly (although it
        -- potentially wastes up to 50% of all sampled integral values when the
        -- total weight is 1) and is still valid for fractional weights, where
        -- it only prevents zero-probability events from ever occurring.
        drawNonZero = do
            u <- uniform 0 total
            if u == 0 then drawNonZero else return u

        -- The chosen event is the first one whose running total reaches the draw.
        select u = snd . head . filter ((>= u) . fst) $ zip cumulative outcomes

-- |Mapping over a 'Discrete' distribution transforms every event and leaves
-- the weights untouched.
instance Functor (Discrete p) where
    fmap f (Discrete ds) = Discrete (map relabel ds)
        where relabel (w, e) = (w, f e)

-- |Folding a 'Discrete' distribution combines its events and ignores their
-- weights.
instance Foldable (Discrete p) where
    foldMap f (Discrete ds) = foldMap (\weighted -> f (snd weighted)) ds

-- |Traversing a 'Discrete' distribution runs an effect for each event while
-- every weight stays paired with the result for its own event.
instance Traversable (Discrete p) where
    traverse f (Discrete ds) = Discrete <$> traverse visit ds
        where visit (w, e) = (,) w <$> f e
    sequenceA (Discrete ds) = Discrete <$> traverse run ds
        where run (w, action) = (,) w <$> action

-- Each case of the outer distribution is replaced by the cases of the
-- distribution that the continuation produces for it.  We want every such
-- group of cases to keep the same relative weight as the outer case it came
-- from, so each inner weight is multiplied by the weight of that outer case.
instance Num p => Monad (Discrete p) where
    return x = Discrete [(1, x)]
    Discrete outer >>= f = Discrete
        [ (p * q, y)
        | (p, x) <- outer
        , let Discrete inner = f x
        , (q, y) <- inner
        ]

-- |The 'Applicative' operations are defined through the 'Monad' instance:
-- 'pure' is 'return', and '<*>' binds the function first, then the argument.
instance Num p => Applicative (Discrete p) where
    pure x = return x
    wf <*> wx = do
        f <- wf
        x <- wx
        return (f x)

-- |Like 'fmap', but for the weights of a discrete distribution: the given
-- function is applied to every weight and the events are left as they are.
mapDiscreteWeights :: (p -> q) -> Discrete p e -> Discrete q e
mapDiscreteWeights f (Discrete ds) = Discrete (map reweigh ds)
    where reweigh (w, e) = (f w, e)

-- |Adjust all the weights of a discrete distribution so that they
-- sum to unity.  If that is not possible (the weights sum to 0, which
-- includes the empty distribution) or not necessary (they already sum to 1),
-- the original distribution is returned unchanged.
--
-- The 'Eq' constraint is required by the comparison of the total weight
-- against 0 and 1; it is stated explicitly because 'Eq' is not a superclass
-- of 'Num' in current versions of @base@.
normalizeDiscreteWeights :: (Eq p, Fractional p) => Discrete p e -> Discrete p e
normalizeDiscreteWeights orig@(Discrete ds)
    -- For practical purposes the scale factor is strict anyway,
    -- so check if the total is 0 or 1 and, if so, skip the actual scaling part.
    | total `elem` [0,1] = orig
    | otherwise          = Discrete [ (w * scale, e) | (w, e) <- ds ]
    where
        total = sum (map fst ds)
        scale = recip total

-- |Simplify a discrete distribution by combining equivalent events.  Events
-- are matched up with 'compare'; each combined event gets a weight equal to
-- the 'sum' of the original weights and is represented by the 'head' of the
-- events that were merged.
collectDiscreteEvents :: (Ord e, Num p, Ord p) => Discrete p e -> Discrete p e
collectDiscreteEvents dist = collectDiscreteEventsBy compare sum head dist
        
-- |Simplify a discrete distribution by combining equivalent events.
--
-- The comparator function is used to identify events to combine: the
-- @(weight, event)@ pairs are sorted by event with it, and adjacent pairs
-- whose events compare as 'EQ' form one group.  For each group the weights
-- and the events are then combined (independently of one another) by the
-- provided weight and event aggregation functions.
--
-- No validation of the weights is performed here.
collectDiscreteEventsBy :: (e -> e -> Ordering) -> ([p] -> p) -> ([e] -> e)-> Discrete p e -> Discrete p e
collectDiscreteEventsBy compareE sumWeights mergeEvents (Discrete ds) =
    Discrete . map combine . groupEvents . sortEvents $ ds

    where
        sortEvents  = sortBy (compareE `on` snd)
        groupEvents = groupBy (\x y -> snd x `compareE` snd y == EQ)

        -- Split a group into its weights and its events, then aggregate each
        -- side with the caller-supplied function.
        combine     = (sumWeights *** mergeEvents) . unzip