{-# LANGUAGE MultiWayIf #-}

{- Copyright 2014 Romain Edelmann

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. -}


-- | This modules defines types and functions for manipulating
--   finite discrete probability distributions.
module Data.Distribution.Core
    ( -- * Probability
      Probability
      -- * Distribution
    , Distribution
    , toMap
    , toList
      -- ** Properties
    , size
    , support
      -- ** Creation
    , fromList
    , always
    , uniform
    , withProbability
      -- ** Transformation
    , select
    , assuming
      -- ** Combination
    , combine
      -- ** Sequences
      -- *** Independant experiments
    , trials
    , times
      -- *** Dependant experiments
    , andThen
    , on
    ) where

import Control.Arrow (second)
import qualified Data.Function as F
import Data.List (tails, groupBy, sortBy, find)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Monoid
import Data.Ord (comparing)
import Data.Set (Set)


-- | Probability. Should be between 0 and 1.
type Probability = Rational

-- | Distribution over values of type @a@.
--
--   Due to their internal representations, @Distribution@ can not have
--   @Functor@ or @Monad@ instances.
--   However, 'select' is the equivalent of @fmap@ for distributions
--   and 'always' and 'andThen' are respectively the equivalent of @return@
--   and @>>=@.
newtype Distribution a = Distribution
    { toMap :: Map a Probability
      -- ^ Converts the distribution to a mapping from values to their
      --   probability. Values with probability @0@ are not included
      --   in the resulting mapping.
    } deriving Eq

instance Show a => Show (Distribution a) where
    show d = "fromList " ++ show (toList d)

-- | A distribution @d1@ is less than some other distribution @d2@
--   if the smallest value that has a different probability
--   in @d1@ and @d2@ is more probable in @d1@.
--
--   By convention, empty distributions are less than
--   everything except themselves.
instance Ord a => Ord (Distribution a) where
    compare d1 d2 = case (toList d1, toList d2) of
        ([], []) -> EQ
        ([], _)  -> LT
        (_, [])  -> GT
        (xs, ys) -> case find (uncurry (/=)) $ zip xs ys of
            Nothing -> EQ
            Just ((x, p), (y, q)) -> case compare x y of
                EQ -> compare q p
                c  -> c

-- | Lifts the bounds to the distributions that return them
--   with probability one.
--
--   Note that the degenerate distributions of size @0@ will
--   be less than the distribution @minBound@.
--
--   Appart from that, all other distributions d have
--   the property that @minBound <= d <= maxBound@ if
--   this property holds on the values of the distribution.
instance Bounded a => Bounded (Distribution a) where
    minBound = always minBound
    maxBound = always maxBound

-- | Literals are interpreted as distributions that always
--   return the given value.
--
--   >>> 42 == always 42
--   True
--
--   Binary operations on distributions are defined to
--   be the binary operation on each pair of elements.
--
--   For this reason, @(+)@ and @(*)@ are not related in the same way
--   as they are on natural numbers.
--
--   For instance, it is not always the case that:
--   @3 * d == d + d + d@
--
--   >>> let d = uniform [0, 1]
--   >>> 3 * d
--   fromList [(0,1 % 2),(3,1 % 2)]
--   >>> d + d + d
--   fromList [(0,1 % 8),(1,3 % 8),(2,3 % 8),(3,1 % 8)]
--
--   For this particular behavior, see the `times` function.
instance (Ord a, Num a) => Num (Distribution a) where
    fromInteger = always . fromInteger
    abs = select abs
    signum = select signum
    negate = select negate
    d1 + d2 = d1 `andThen` (+) `on` d2
    d1 - d2 = d1 `andThen` (-) `on` d2
    d1 * d2 = d1 `andThen` (*) `on` d2

-- Binary operations on distributions are defined to
-- be the binary operation on each pair of elements.
instance (Ord a, Fractional a) => Fractional (Distribution a) where
    fromRational = always . fromRational
    d1 / d2 = d1 `andThen` (/) `on` d2
    recip = select recip

-- Binary operations on distributions are defined to
-- be the binary operation on each pair of element.
instance (Ord a, Floating a) => Floating (Distribution a) where
    pi = always pi
    exp = select exp
    sqrt = select sqrt
    log = select log
    d1 ** d2 = d1 `andThen` (**) `on` d2
    d1 `logBase` d2 = d1 `andThen` logBase `on` d2
    sin = select sin
    tan = select tan
    cos = select cos
    asin = select asin
    atan = select atan
    acos = select acos
    sinh = select sinh
    tanh = select tanh
    cosh = select cosh
    asinh = select asinh
    atanh = select atanh
    acosh = select acosh

instance (Ord a, Monoid a) => Monoid (Distribution a) where
    mempty = always mempty
    mappend d1 d2 = d1 `andThen` mappend `on` d2

-- | Converts the distribution to a list of increasing values whose probability
--   is greater than @0@. To each value is associated its probability.
toList :: Distribution a -> [(a, Probability)]
toList (Distribution xs) = Map.toAscList xs


-- Properties


-- | Returns the number of elements with non-zero probability
--   in the distribution.
size :: Distribution a -> Int
size = Map.size . toMap

-- | Values in the distribution with non-zero probability.
support :: Distribution a -> Set a
support = Map.keysSet . toMap


-- Creation


-- | Distribution that assigns to each @value@ from the given @(value, weight)@
--   pairs a probability proportional to @weight@.
--
--   >>> fromList [('A', 1), ('B', 2), ('C', 1)]
--   fromList [('A',1 % 4),('B',1 % 2),('C',1 % 4)]
--
--   Values may appear multiple times in the list. In this case, their total
--   weight is the sum of the different associated weights.
--   Values whose total weight is zero or negative are ignored.
fromList :: (Ord a, Real p) => [(a, p)] -> Distribution a
fromList xs = Distribution $ Map.fromDistinctAscList $ zip vs scaledPs
  where
    as = map aggregate $ groupBy ((==) `F.on` fst) $ sortBy (comparing fst) xs
      where
        aggregate ys = let (v : _, qs) = unzip ys in
            (v, fromRational $ toRational $ sum qs)
    (vs, ps) = unzip $ filter ((> 0) . snd) as
    t = sum ps
    scaledPs = if t /= 1 then map (/ t) ps else ps

-- | Distribution that assigns to @x@ the probability of @1@.
--
-- >>> always 0
-- fromList [(0,1 % 1)]
--
-- >>> always 42
-- fromList [(42,1 % 1)]
always :: a -> Distribution a
always x = Distribution $ Map.singleton x 1

-- | Uniform distribution over the values.
--   The probability of each element is proportional
--   to its number of appearance in the list.
--
--   >>> uniform [1 .. 6]
--   fromList [(1,1 % 6),(2,1 % 6),(3,1 % 6),(4,1 % 6),(5,1 % 6),(6,1 % 6)]
uniform :: (Ord a) => [a] -> Distribution a
uniform xs = fromList $ fmap (\ x -> (x, p)) xs
  where
    p = 1 / toRational (length xs)

-- | @True@ with given probability and @False@ with complementary probability.
withProbability :: Real p => p -> Distribution Bool
withProbability p = fromList [(False, 1 - p'), (True, p')]
  where
    p' = fromRational $ max 0 $ min 1 $ toRational p


-- Transformation


-- | Applies a function to the values in the distribution.
--
--   >>> select abs $ uniform [-1, 0, 1]
--   fromList [(0,1 % 3),(1,2 % 3)]
select :: Ord b => (a -> b) -> Distribution a -> Distribution b
select f (Distribution xs) = Distribution $ Map.mapKeysWith (+) f xs

-- | Returns a new distribution conditioning on the predicate holding
--   on the value.
--
--   >>> assuming (> 2) $ uniform [1 .. 6]
--   fromList [(3,1 % 4),(4,1 % 4),(5,1 % 4),(6,1 % 4)]
--
--   Note that the resulting distribution will be empty
--   if the predicate does not hold on any of the values.
--
--   >>> assuming (> 7) $ uniform [1 .. 6]
--   fromList []
assuming :: (a -> Bool) -> Distribution a -> Distribution a
assuming f (Distribution xs) = Distribution $ fmap adjust filtered
  where
    filtered = Map.filterWithKey (const . f) xs
    adjust x = x * (1 / total)
    total = sum $ Map.elems filtered


-- Combination


-- | Combines multiple weighted distributions into a single distribution.
--
--   The probability of each element is the weighted sum of the element's
--   probability in every distribution.
--
--   >>> combine [(always 2, 1 / 3), (uniform [1..6], 2 / 3)]
--   fromList [(1,1 % 9),(2,4 % 9),(3,1 % 9),(4,1 % 9),(5,1 % 9),(6,1 % 9)]
--
--   Note that the weights do not have to sum up to @1@. Distributions with
--   negative or null weight will be ignored.
combine :: (Ord a, Real p) => [(Distribution a, p)] -> Distribution a
combine dws = Distribution $ Map.unionsWith (+) $ zipWith go ds ps
  where
    (ds, ws) = unzip $ filter ((> 0) . snd) $ map (second toRational) dws
    w = sum ws
    ps = map (/ w) ws
    go (Distribution xs) p = fmap (* p) xs


-- Sequences


-- | Binomial distribution.
--   Assigns to each number of successes its probability.
--
--   >>> trials 2 $ uniform [True, False]
--   fromList [(0,1 % 4),(1,1 % 2),(2,1 % 4)]
trials :: Int -> Distribution Bool -> Distribution Int
trials n d = Distribution $ Map.fromDistinctAscList $ if
    | p == 1    -> [(n, 1)]
    | p == 0    -> [(0, 1)]
    | otherwise -> zip outcomes probs
  where
    p = fromMaybe 0 $ Map.lookup True $ toMap d
    q = 1 - p

    ps = take (n + 1) $ iterate (* p) 1
    qs = reverse $ take (n + 1) $ iterate (* q) 1

    probs = zipWith (*) pascalRow $ zipWith (*) ps qs

    outcomes = [0 .. n]

    pascalRow = fmap (fromRational . toRational) $
        scanl ( \ c k -> c * (n' + 1 - k) `div` k) 1 [1 .. n']
      where
        n' = toInteger n

-- | Takes `n` samples from the distribution and returns the distribution
--   of their sum.
--
--   >>> times 2 $ uniform [1 .. 3]
--   fromList [(2,1 % 9),(3,2 % 9),(4,1 % 3),(5,2 % 9),(6,1 % 9)]
--
--   This function makes use of the more efficient @trials@ functions
--   for input distributions of size @2@.
--
--   >>> size $ times 10000 $ uniform [1, 10]
--   10001
times :: (Num a, Ord a) => Int -> Distribution a -> Distribution a
n `times` d
    | s == 0 = d
    | n <= 0 = always 0
    | s == 1 = select (* n') d
    | s == 2 = case toList d of  -- Performs Bernoulli trials. (efficiency)
        [(a, p), (b, q)] -> select (go a b) $ trials n $ withProbability p
        _ -> error "times: size seems not to be properly defined."
    |otherwise = mult n
  where
    s = Map.size $ toMap d
    n' = fromInteger $ toInteger n
    go a b k = k' * a + (n' - k') * b
      where
        k' = fromInteger $ toInteger k

    -- Computes @k `times` d@ using a divide and conquer approach.
    mult 1 = d
    mult k = if r == 0 then twice d' else twice d' + d
      where
        d' = mult k'
        (k', r) = k `quotRem` 2

    -- Computes @d + d@ more efficiently.
    twice (Distribution xs) = Distribution $ Map.unionsWith (+) $ do
        (x, p) : ys <- init $ tails $ Map.toAscList xs
        return $ Map.fromDistinctAscList $ (:) (x + x, p * p) $ do
            (y, q) <- ys
            let p' = 2 * p * q
            return (y + x, p')

-- | Computes for each value in the distribution a new distribution, and then
--   combines those distributions, giving each the weight of the original value.
--
--   >>> uniform [1 .. 3] `andThen` (\ n -> uniform [1 .. n])
--   fromList [(1,11 % 18),(2,5 % 18),(3,1 % 9)]
--
--   See the 'on' function for a convenient way to chain distributions.
infixl 7 `andThen`
andThen :: Ord b => Distribution a -> (a -> Distribution b) -> Distribution b
andThen (Distribution xs) f = Distribution $
    Map.unionsWith (+) $ fmap go $ Map.toList xs
  where
    go (x, p) = fmap (* p) $ toMap $ f x

-- | Utility to partially apply a function on a distribution.
--   A use case for 'on' is to use it in conjunction with 'andThen'
--   to combine distributions.
--
--   >>> uniform [1 .. 3] `andThen` (+) `on` uniform [1 .. 2]
--   fromList [(2,1 % 6),(3,1 % 3),(4,1 % 3),(5,1 % 6)]
infixl 8 `on`
on :: Ord c => (a -> b -> c) -> Distribution b -> a -> Distribution c
on f d x = select (f x) d