{-# LANGUAGE
    MultiParamTypeClasses,
    FlexibleInstances, FlexibleContexts
  #-}

module Data.Random.Distribution.Categorical
    ( Categorical
    , categorical, categoricalT
    , weightedCategorical, weightedCategoricalT
    , fromList, toList, totalWeight, numEvents
    , fromWeightedList, fromObservations
    , mapCategoricalPs, normalizeCategoricalPs
    , collectEvents, collectEventsBy
    ) where

import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Uniform

import Control.Arrow
import Control.Monad
import Control.Monad.ST
import Control.Applicative
import Data.Foldable (Foldable(foldMap))
import Data.STRef
import Data.Traversable (Traversable(traverse, sequenceA))

import Data.List
import Data.Function
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV

-- |Build an 'RVar' that samples from the 'Categorical' distribution described
-- by a list of (probability, category) pairs.  The probabilities are expected
-- to sum to 1.
categorical :: (Num p, Distribution (Categorical p) a) => [(p,a)] -> RVar a
categorical events = rvar (fromList events)

-- |Build a random process ('RVarT') that samples from the 'Categorical'
-- distribution described by a list of (probability, category) pairs.  The
-- probabilities are expected to sum to 1.
categoricalT :: (Num p, Distribution (Categorical p) a) => [(p,a)] -> RVarT m a
categoricalT events = rvarT (fromList events)

-- |Build an 'RVar' that samples from a 'Categorical' distribution described
-- by a list of (weight, category) pairs.  The weights do not have to sum
-- to 1: the distribution is built with 'fromWeightedList', which normalizes
-- them.
weightedCategorical :: (Fractional p, Eq p, Distribution (Categorical p) a) => [(p,a)] -> RVar a
weightedCategorical events = rvar (fromWeightedList events)

-- |Build a random process ('RVarT') that samples from a 'Categorical'
-- distribution described by a list of (weight, category) pairs.  The weights
-- do not have to sum to 1: the distribution is built with 'fromWeightedList',
-- which normalizes them.
weightedCategoricalT :: (Fractional p, Eq p, Distribution (Categorical p) a) => [(p,a)] -> RVarT m a
weightedCategoricalT events = rvarT (fromWeightedList events)

-- | Construct a 'Categorical' distribution from a list of weighted categories.
--
-- The weights are stored as a running (cumulative) sum, in the order given;
-- no normalization or filtering is performed here.
{-# INLINE fromList #-}
fromList :: (Num p) => [(p,a)] -> Categorical p a
fromList xs = Categorical . V.fromList $ scanl1 accumulate xs
    where
        -- carry the running total forward, keep the newer event
        accumulate (running, _) (weight, event) = (running + weight, event)

-- | Recover the list of (weight, category) pairs from a 'Categorical',
-- undoing the cumulative-sum encoding used internally.
{-# INLINE toList #-}
toList :: (Num p) => Categorical p a -> [(p,a)]
toList (Categorical ds) = V.foldr' step [] ds
    where
        -- The head of the accumulator still holds a cumulative weight; once
        -- its predecessor is known it is replaced by the difference.
        step entry acc = case acc of
            []                -> [entry]
            (hi, ev) : rest   -> case entry of
                (lo, _) -> entry : (hi - lo, ev) : rest

-- | Sum of all weights in the distribution: the last stored cumulative
-- weight, or 0 when there are no events.
totalWeight :: Num p => Categorical p a -> p
totalWeight (Categorical ds) =
    if V.null ds
        then 0
        else fst (V.last ds)

-- | Number of (weight, event) entries stored in the distribution.
numEvents :: Categorical p a -> Int
numEvents cat = case cat of
    Categorical ds -> V.length ds

-- |Construct a 'Categorical' distribution from a list of weighted categories
-- whose weights need not sum to 1.  The list is first loaded with 'fromList'
-- and the result is then passed through 'normalizeCategoricalPs'.
fromWeightedList :: (Fractional p, Eq p) => [(p,a)] -> Categorical p a
fromWeightedList weighted = normalizeCategoricalPs (fromList weighted)

-- |Construct a 'Categorical' distribution from a list of observed outcomes.
-- The observations are sorted and grouped into runs of equal events; each
-- run contributes one event whose weight is the length of the run, and the
-- weights are then normalized by 'fromWeightedList'.
fromObservations :: (Fractional p, Eq p, Ord a) => [a] -> Categorical p a
fromObservations observed = fromWeightedList
    -- 'group' never yields an empty run, so 'head' is safe here
    [ (genericLength run, head run) | run <- group (sort observed) ]

-- Implementation note: the Haddock comment below describes the public
-- interface.  Internally, Categorical is stored as a vector of
-- (cumulative-probability, value) pairs, so that sampling can take advantage
-- of binary search.  'fromList' builds that cumulative encoding and 'toList'
-- undoes it.

-- |Categorical distribution; a list of events with corresponding probabilities.
-- The sum of the probabilities must be 1, and no event should have a zero 
-- or negative probability (at least, at time of sampling; very clever users
-- can do what they want with the numbers before sampling, just make sure 
-- that if you're one of those clever ones, you at least eliminate negative 
-- weights before sampling).
--
-- The derived 'Eq' instance compares the stored vectors of
-- (cumulative-probability, value) pairs.
newtype Categorical p a = Categorical (V.Vector (p, a))
    deriving Eq

-- Rendered as an application of 'fromList' to the (weight, event) list, with
-- parentheses added when the surrounding precedence requires them.
instance (Num p, Show p, Show a) => Show (Categorical p a) where
    showsPrec prec cat =
        showParen (prec > 10) $
            showString "fromList " . showsPrec 11 (toList cat)

-- Parses the format produced by the 'Show' instance: the token @fromList@
-- followed by a list of (weight, event) pairs.
instance (Num p, Read p, Read a) => Read (Categorical p a) where
  readsPrec prec = readParen (prec > 10) $ \input ->
      [ (fromList vals, rest)
      | ("fromList", valStr) <- lex input
      , (vals,       rest)   <- readsPrec 11 valStr
      ]

-- Sampling from a 'Categorical':
--
--  * an empty distribution cannot be sampled and 'fail's in the 'RVarT' monad;
--  * a single-event distribution returns that event without drawing a
--    random number;
--  * otherwise a value @u@ is drawn with 'uniformT' between 0 and the last
--    stored cumulative weight, and a binary search over the cumulative
--    weights selects the first event whose cumulative weight is >= @u@.
instance (Fractional p, Ord p, Distribution Uniform p) => Distribution (Categorical p) a where
    rvarT (Categorical ds)
        | V.null ds = fail "categorical distribution over empty set cannot be sampled"
        | n == 1    = return (snd (V.head ds))
        | otherwise = do
            -- the upper bound is the total (last cumulative) weight, so the
            -- stored weights are not rescaled before sampling
            u <- uniformT 0 (fst (V.last ds))
            
            let -- by construction, p is monotone; (i < j) ==> (p i <= p j)
                -- p i: cumulative weight stored at index i
                -- x i: event stored at index i
                p i = fst (ds V.! i)
                x i = snd (ds V.! i)
                
                --  findEvent
                -- ===========
                -- invariants: (i <= j), (u <= p j), ((i == 0) || (p i < u))
                --  (the last one means 'i' does not increase unless it bounds 'p' below 'u')
                -- variant: either i increases or j decreases.
                -- upon termination: ∀ k. if (k < j) then (p k < u) else (u <= p k)
                --  (that is, the chosen event 'x j' is the first one whose 
                --   associated cumulative probability 'p j' is greater than 
                --   or equal to 'u')
                findEvent i j
                    | j <= i    = x j
                    | u <= p m  = findEvent i m
                    | otherwise = findEvent (max m (i+1)) j
                    where
                        -- midpoint rounding down
                        -- (i < j) ==> (m < j)
                        -- when m == i the 'max' above still advances i, so
                        -- the search always makes progress
                        m = (i + j) `div` 2
            
            -- u <= 0 short-circuits to the first event; '$!' forces the
            -- selected event to weak head normal form before returning it
            return $! if u <= 0 then x 0 else findEvent 0 (n-1)
        where n = V.length ds


-- Maps over the events only; the stored cumulative weights are untouched.
instance Functor (Categorical p) where
    fmap f (Categorical ds) = Categorical (V.map onEvent ds)
        where
            -- lazy pattern: the pair is not forced until a component is used
            onEvent ~(p, e) = (p, f e)

-- Folds over the events in storage order, ignoring the weights.
instance Foldable (Categorical p) where
    foldMap f (Categorical ds) = foldMap onEntry (V.toList ds)
        where
            onEntry entry = f (snd entry)

-- Traverses the events in storage order; each event keeps the cumulative
-- weight it was stored with.
instance Traversable (Categorical p) where
    traverse f (Categorical ds) = rebuild <$> traverse visit (V.toList ds)
        where
            rebuild      = Categorical . V.fromList
            visit (p, e) = (,) p <$> f e
    sequenceA  (Categorical ds) = rebuild <$> traverse visit (V.toList ds)
        where
            rebuild      = Categorical . V.fromList
            visit (p, e) = (,) p <$> e

instance Num p => Monad (Categorical p) where
    -- A certain outcome: a single event carrying weight 1.
    return x = Categorical (V.singleton (1, x))
    
    -- Failure is represented by the empty distribution.  It is not entirely
    -- clear that this is a valid form of failure; see the discussion on
    -- '>>=' below.
    fail _ = Categorical V.empty
    
    -- Every (outer event, inner event) combination becomes one event whose
    -- weight is the product of the two weights.
    --
    -- The result is deliberately NOT passed through 'normalizeCategoricalPs'.
    -- Non-normalized results can arise legitimately: when modelling a
    -- "conditional" distribution, an event may contradict the assumed
    -- condition and be eliminated ('f' returns an empty or zero-probability
    -- consequent, possibly by 'fail'ing).  Continuing in that situation
    -- seems reasonable, but whether and when to renormalize is unclear, and
    -- the point at which it is done most likely changes the result.  It is
    -- easier for the user (who had better know what they mean when returning
    -- non-normalized probabilities) to normalize explicitly than to undo a
    -- normalization that was applied automatically.
    xs >>= f = fromList
        [ (p * q, y)
        | (p, x) <- toList xs
        , (q, y) <- toList (f x)
        ]

-- 'Applicative' structure derived from the 'Monad' instance.
--
-- The context is @Num p@, the same as the 'Monad' instance's.  On compilers
-- where 'Applicative' is a superclass of 'Monad', a stronger context here
-- (such as @Fractional p@) would make the @Num p => Monad@ instance
-- unsatisfiable.  Neither 'return' nor 'ap' needs more than @Num p@.
instance Num p => Applicative (Categorical p) where
    pure = return
    (<*>) = ap

-- |Like 'fmap', but for the probabilities of a categorical distribution.
-- The function is applied to each individual (non-cumulative) weight, and
-- the distribution is then rebuilt with 'fromList'.
mapCategoricalPs :: (Num p, Num q) => (p -> q) -> Categorical p e -> Categorical q e
mapCategoricalPs f cat = fromList (map reweight (toList cat))
    where
        -- lazy pattern: the pair is not forced until a component is used
        reweight ~(p, e) = (f p, e)

-- |Adjust all the weights of a categorical distribution so that they 
-- sum to unity and remove all events whose probability is zero.
--
-- A distribution whose total weight is zero (including the empty one) is
-- mapped to the empty distribution.
normalizeCategoricalPs :: (Fractional p, Eq p) => Categorical p e -> Categorical p e
normalizeCategoricalPs orig@(Categorical ds)
    | ps == 0   = Categorical V.empty
    | otherwise = runST $ do
        -- cumulative weight of the most recently kept entry (0 before any)
        lastP       <- newSTRef 0
        -- number of entries dropped so far
        nDups       <- newSTRef 0
        -- scratch copy, compacted in place as entries are dropped
        normalized  <- V.thaw ds
        
        let n           = V.length ds
            -- Strictly bump the drop counter.  This is done inline rather
            -- than through a helper named modifySTRef': newer versions of
            -- "Data.STRef" export a function of that name, and an
            -- unqualified use would then be ambiguous.
            skip        = do
                d <- readSTRef nDups
                writeSTRef nDups $! d + 1
            -- store entry i at its position after compaction
            save i p x  = do
                d <- readSTRef nDups
                MV.write normalized (i-d) (p, x)
        
        forM_ [0..n-1] $ \i -> do
            let (p,x) = ds V.! i
            p0 <- readSTRef lastP
            -- an unchanged cumulative weight means this event has zero
            -- weight of its own, so it is dropped
            if p == p0
                then skip
                else do
                    save i (p * scale) x
                    writeSTRef lastP $! p
        
        -- force last element to 1
        -- (ps /= 0 guarantees at least one entry was kept, so n' >= 1)
        d <- readSTRef nDups
        let n' = n-d
        (_,lastX) <- MV.read normalized (n'-1)
        MV.write normalized (n'-1) (1,lastX)
        Categorical <$> V.unsafeFreeze (MV.unsafeSlice 0 n' normalized)
    where
        ps = totalWeight orig
        scale = recip ps

-- |Strict variant of 'modifySTRef': the new value is evaluated to weak head
-- normal form before it is written back, so no chain of unevaluated
-- updates builds up inside the reference.
modifySTRef' :: STRef s a -> (a -> a) -> ST s ()
modifySTRef' ref f =
    readSTRef ref >>= \old -> writeSTRef ref $! f old

-- |Simplify a categorical distribution by merging events that compare equal
-- under 'compare'.  Each merged event carries the sum of the weights of the
-- originals and is represented by the first event of its group.
collectEvents :: (Ord e, Num p, Ord p) => Categorical p e -> Categorical p e
collectEvents = collectEventsBy compare merge
    where
        merge grp = (sum (map fst grp), head (map snd grp))
        
-- |Simplify a categorical distribution by combining equivalent events.
--
-- The (weight, event) pairs are sorted by event using the supplied
-- comparator, adjacent pairs whose events compare 'EQ' are grouped, and each
-- group is collapsed into a single (weight, event) pair by the supplied
-- aggregation function.  The distribution is then rebuilt from the results.
collectEventsBy :: Num p => (e -> e -> Ordering) -> ([(p,e)] -> (p,e))-> Categorical p e -> Categorical p e
collectEventsBy compareE combine cat =
    fromList [ combine grp | grp <- groupBy sameEvent ordered ]
    where
        byEvent a b   = compareE (snd a) (snd b)
        sameEvent a b = byEvent a b == EQ
        ordered       = sortBy byEvent (toList cat)