-- | Weighted shuffling and sampling of elements.
--
-- The @...CDF@ functions operate on a 'M.Map' whose keys are cumulative
-- weights (as built by 'cdfMapFromList'); the other functions accept a list
-- of @(weight, element)@ pairs and build that map first.
module Data.Random.Shuffle.Weighted
(
-- * Shuffle
weightedShuffleCDF
, weightedShuffle
-- * Sample
, weightedSampleCDF
, weightedSample
-- * Extraction
, weightedChoiceExtractCDF
-- * CDF map construction
, cdfMapFromList
)
where
import Control.Applicative ((<$>))
import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Uniform
import Data.Random.Distribution.Uniform.Exclusive
import qualified Data.Map as M
-- | Abort with an 'error' whose message is prefixed by this module's name
-- and the name of the offending function.
--
-- The first argument is the function name, the second the description.
moduleError :: String -> String -> a
moduleError fn msg =
  error (concat ["Data.Random.Shuffle.Weighted.", fn, ": ", msg])
-- | Shuffle the elements of a CDF map.
--
-- Elements are pulled out one at a time with 'weightedChoiceExtractCDF'
-- until the map is empty, and are returned in the order they were extracted.
-- An empty map yields an empty list.
weightedShuffleCDF :: (Num w, Ord w, Distribution Uniform w, Excludable w) => M.Map w a -> RVar [a]
weightedShuffleCDF cdf
  | M.null cdf = return []
  | otherwise = do
      (remaining, picked) <- weightedChoiceExtractCDF cdf
      (picked :) <$> weightedShuffleCDF remaining
-- | Shuffle a list of @(weight, element)@ pairs.
--
-- The list is first converted to a CDF map with 'cdfMapFromList' and then
-- handed to 'weightedShuffleCDF'.
weightedShuffle :: (Num w, Ord w, Distribution Uniform w, Excludable w) => [(w, a)] -> RVar [a]
weightedShuffle weighted = weightedShuffleCDF (cdfMapFromList weighted)
-- | Draw up to @n@ elements from a CDF map without replacement.
--
-- Elements are extracted one at a time with 'weightedChoiceExtractCDF' and
-- returned in extraction order. The result is empty when @n <= 0@ or the map
-- is empty, and is shorter than @n@ when the map runs out of elements first.
weightedSampleCDF :: (Num w, Ord w, Distribution Uniform w, Excludable w) => Int -> M.Map w a -> RVar [a]
weightedSampleCDF n m
  | M.null m || n <= 0 = return []
  | otherwise = do
      (m', a) <- weightedChoiceExtractCDF m
      -- One element has been taken, so one fewer remains to be drawn.
      (a :) <$> weightedSampleCDF (n - 1) m'
-- | Draw up to @n@ elements from a list of @(weight, element)@ pairs.
--
-- The list is first converted to a CDF map with 'cdfMapFromList' and then
-- handed to 'weightedSampleCDF'.
weightedSample :: (Num w, Ord w, Distribution Uniform w, Excludable w) => Int -> [(w, a)] -> RVar [a]
weightedSample n weighted = weightedSampleCDF n (cdfMapFromList weighted)
-- | Randomly remove one element from a CDF map.
--
-- The map's keys are cumulative weights: each entry owns the interval between
-- the previous key (or 0 for the first entry) and its own key. A weight is
-- drawn with @uniformExclusive 0 wmax@, where @wmax@ is the largest key, and
-- the entry with the smallest key not less than the drawn weight is selected.
--
-- Returns the map without the selected entry, together with the selected
-- element. Keys above the removed entry are shifted down by the removed
-- entry's own weight, so every remaining entry keeps its weight.
--
-- A single-entry map returns that entry without drawing a random number.
-- Calls 'error' (via 'moduleError') on an empty map.
weightedChoiceExtractCDF :: (Num w, Ord w, Distribution Uniform w, Excludable w) => M.Map w a -> RVar (M.Map w a, a)
weightedChoiceExtractCDF m =
  case M.maxViewWithKey m of
    Nothing -> moduleError "weightedChoiceExtractCDF" "empty map"
    Just ((wmax, maxE), exceptMax)
      | M.null exceptMax -> return (exceptMax, maxE)
      | otherwise -> extract <$> uniformExclusive 0 wmax
  where
    -- Remove the entry selected by the drawn weight @w@ and close the hole it
    -- leaves in the cumulative keys.
    extract w =
      case M.minViewWithKey atOrAbove of
        -- NOTE(review): unreachable if the drawn weight never exceeds wmax;
        -- confirm against the range produced by uniformExclusive.
        Nothing ->
          moduleError "weightedChoiceExtractCDF" "sampled weight exceeds the total weight"
        Just ((k, picked), above) ->
          let -- Weight of the removed entry: distance from the previous
              -- cumulative key (0 when nothing precedes it).
              gap = k - previousKey
          in (below `M.union` M.mapKeysMonotonic (subtract gap) above, picked)
      where
        (below, exact, greater) = M.splitLookup w m
        -- Entries whose key is >= w; an exact hit on w belongs to this side.
        atOrAbove = maybe greater (\ex -> M.insert w ex greater) exact
        previousKey = maybe 0 (fst . fst) (M.maxViewWithKey below)
-- | Build a CDF map from a list of @(weight, element)@ pairs.
--
-- Leading zero-weight entries are dropped, the remaining weights are replaced
-- by their running totals, and the result is loaded into a map keyed by those
-- totals. When several entries end up with the same running total, the first
-- one is kept.
--
-- The running totals are passed to 'M.fromAscListWith', which expects
-- ascending keys; the input is not checked for that.
cdfMapFromList :: (Num w, Eq w) => [(w, a)] -> M.Map w a
cdfMapFromList weighted = M.fromAscListWith keepFirst cumulative
  where
    withoutLeadingZeros = dropWhile (\(w, _) -> w == 0) weighted
    cumulative = scanl1 accumulate withoutLeadingZeros
    -- Carry the running total forward, keeping the current element.
    accumulate (total, _) (w, x) = (total + w, x)
    -- 'M.fromAscListWith' passes the newer value first; keep the older one.
    keepFirst _ older = older