{-# LANGUAGE FlexibleContexts #-}
{- |
Module : Data.Random.Shuffle.Weighted
Copyright : 2010 Aristid Breitkreuz
License : BSD3
Stability : experimental
Portability : portable
Functions for shuffling elements according to weights.
Definitions:
[Weight] A number above /0/ denoting how likely it is for an element to
end up in the first position.
[Probability] A weight normalised into the /(0,1]/ range.
[Weighted list] A list of pairs @(w, a)@, where @w@ is the weight of
element @a@.
The probability of an element getting into the first position
is equal to its weight divided by the sum of all weights, and
the probability of getting into a position other than the
first is equal to the probability of getting in the first
position when all elements in prior positions have been
removed from the weighted list.
[CDF Map] A map of /summed weights/ to elements. For example, a weighted
list @[(0.2, 'a'), (0.6, 'b'), (0.2, 'c')]@ corresponds to a
CDF map of @[(0.2, 'a'), (0.8, 'b'), (1.0, 'c')]@
(as a 'Map'). The weights are summed from left to right.
-}
module Data.Random.Shuffle.Weighted
(
-- * Shuffling
weightedShuffleCDF
, weightedShuffle
-- * Sampling
, weightedSampleCDF
, weightedSample
-- * Extraction
, weightedChoiceExtractCDF
-- * Utilities
, cdfMapFromList
)
where
import Control.Applicative ((<$>))
import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Uniform
import Data.Random.Distribution.Uniform.Exclusive
import qualified Data.Map as M
-- | Abort with 'error', prefixing the message with this module's name and
-- the name of the function that failed.
moduleError :: String -> String -> a
moduleError fn msg =
  error (concat ["Data.Random.Shuffle.Weighted.", fn, ": ", msg])
-- | Randomly shuffle a CDF map according to its weights.
--
-- Elements are extracted one at a time with 'weightedChoiceExtractCDF' until
-- the map is empty; the order of extraction is the order of the result.
weightedShuffleCDF :: (Num w, Ord w, Distribution Uniform w, Excludable w) => M.Map w a -> RVar [a]
weightedShuffleCDF cdf
  | M.null cdf = return []
  | otherwise = do
      (remaining, chosen) <- weightedChoiceExtractCDF cdf
      rest <- weightedShuffleCDF remaining
      return (chosen : rest)
-- | Randomly shuffle a weighted list according to its weights.
--
-- The list is first converted with 'cdfMapFromList' and then shuffled with
-- 'weightedShuffleCDF'.
weightedShuffle :: (Num w, Ord w, Distribution Uniform w, Excludable w) => [(w, a)] -> RVar [a]
weightedShuffle weighted = weightedShuffleCDF (cdfMapFromList weighted)
-- | Randomly draw /n/ elements from a CDF map according to its weights.
--
-- Drawing stops early when the map runs out of elements, so the result has
-- at most /n/ entries. A non-positive /n/ yields the empty list.
weightedSampleCDF :: (Num w, Ord w, Distribution Uniform w, Excludable w) => Int -> M.Map w a -> RVar [a]
weightedSampleCDF count cdf
  | M.null cdf || count <= 0 = return []
  | otherwise = do
      (remaining, chosen) <- weightedChoiceExtractCDF cdf
      rest <- weightedSampleCDF (count - 1) remaining
      return (chosen : rest)
-- | Randomly draw /n/ elements from a weighted list according to its weights.
--
-- The list is first converted with 'cdfMapFromList' and then sampled with
-- 'weightedSampleCDF'.
weightedSample :: (Num w, Ord w, Distribution Uniform w, Excludable w) => Int -> [(w, a)] -> RVar [a]
weightedSample count weighted = weightedSampleCDF count (cdfMapFromList weighted)
-- | Randomly extract an element from a CDF map according to its weights.
--
-- A point @w@ is drawn with 'uniformExclusive' between @0@ and the largest
-- key of the map, and the element with the smallest key @>= w@ is chosen.
-- The chosen element is removed and the resulting \"weight gap\" is closed:
-- every key above the chosen one is shifted down by the chosen element's own
-- weight (its key minus the preceding key, or minus @0@ if it was the first
-- entry), so all remaining elements keep their original weights.
--
-- A map with a single element returns that element without drawing a random
-- number. Calling this function on an empty map is an error.
weightedChoiceExtractCDF :: (Num w, Ord w, Distribution Uniform w, Excludable w) => M.Map w a -> RVar (M.Map w a, a)
weightedChoiceExtractCDF m =
  case M.maxViewWithKey m of
    Nothing -> moduleError "weightedChoiceExtractCDF" "empty map"
    Just ((wmax, maxE), exceptMax)
      -- Only one element left: it is the only possible choice.
      | M.null exceptMax -> return (exceptMax, maxE)
      | otherwise -> extract <$> uniformExclusive 0 wmax
  where
    extract w = (below `M.union` M.mapKeysMonotonic (subtract gap) above, chosen)
      where
        -- Keys strictly below w, the entry exactly at w (if any), and keys
        -- strictly above w.
        (below, exact, strictlyAbove) = M.splitLookup w m
        -- Everything with key >= w; its minimum is the chosen element.
        candidates = maybe strictlyAbove (\x -> M.insert w x strictlyAbove) exact
        ((k, chosen), above) =
          case M.minViewWithKey candidates of
            Just found -> found
            -- Only reachable if w is larger than every key in the map.
            Nothing -> moduleError "weightedChoiceExtractCDF" "sampled weight exceeds total weight"
        -- Cumulative weight just before the chosen element.
        lowerBound =
          case M.maxViewWithKey below of
            Nothing -> 0
            Just ((prev, _), _) -> prev
        -- The chosen element's own weight: the amount by which all later
        -- cumulative keys must shrink once it is gone.
        gap = k - lowerBound
-- | Generate a CDF map from a weighted list.
--
-- Weights are summed from left to right and each running total becomes the
-- key of the corresponding element. Leading zero-weight elements are
-- dropped. A zero-weight element later in the list produces the same key as
-- its predecessor, in which case the predecessor is kept.
--
-- Weights are expected to be non-negative; otherwise the running totals are
-- not ascending, which 'M.fromAscListWith' requires.
--
-- The 'Eq' constraint is needed for the zero test and for detecting
-- duplicate keys ('Num' alone does not provide equality).
cdfMapFromList :: (Num w, Eq w) => [(w, a)] -> M.Map w a
cdfMapFromList =
  M.fromAscListWith keepFirst
    . scanl1 accumulate
    . dropWhile ((== 0) . fst)
  where
    -- On duplicate keys the combining function receives the newer value
    -- first; keep the one that was already there.
    keepFirst _new old = old
    -- Carry the running total forward, pairing it with the current element.
    accumulate (total, _) (weight, x) = (total + weight, x)