{-# LANGUAGE TypeOperators #-}
-----------------------------------------------------------------------------
-- |
-- Module     : Control.Monad.MC.Walker
-- Copyright  : Copyright (c) , Patrick Perry
-- License    : BSD3
-- Maintainer : Patrick Perry
-- Stability  : experimental
--
-- An implementation of Walker's Alias method for sampling from discrete
-- distributions.  See section III.4 of Luc Devroye's book
-- "Non-Uniform Random Variate Generation", which is available on his
-- homepage, for a description of how it works.
module Control.Monad.MC.Walker (
    Table,
    computeTable,
    indexTable,
    tableSize,
    component,
    ) where

import Control.Monad
import Control.Monad.ST
import Data.Array.Vector

-- | The table, which represents an equiprobable mixture of two-point
-- distributions.  The @l@th entry of the table represents a mixture
-- distribution with weight @q[l]@ on @l@ and weight @(1-q[l])@ on @j[l]@.
--
-- To make sampling cheaper, the @l@th element of the underlying array
-- stores the pair @(q[l] + l) :*: j[l]@ rather than @q[l] :*: j[l]@;
-- 'component' removes the offset again.
newtype Table = T (UArr (Double :*: Int))

-- | Get the @i@th mixture component.  That is, return @q[i]@ and @j[i]@,
-- where the @i@th mixture component puts mass @q[i]@ on @i@ and mass
-- @1 - q[i]@ on @j[i]@.
component :: Table -> Int -> (Double,Int)
component (T qjs) i =
    let (q' :*: j) = indexU qjs i
        -- The stored weight is offset by the entry's index; undo that.
        q          = q' - fromIntegral i
    in (q,j)

-- | Compute the table for use in Walker's aliasing method.
--
-- The first argument is the table size @n@ and the second is the list of
-- (unnormalized) weights; only the first @n@ weights are used.  An error
-- is raised if @n@ is negative, if a weight is negative, or if the weights
-- sum to zero.
computeTable :: Int -> [Double] -> Table
computeTable n ws = runST $ do
    (qjs, sets) <- initTable n ws
    breakLarger qjs sets
    scaleTable qjs
    liftM T $ unsafeFreezeAllMU qjs

-- | Given an alias table and a number in the range [0,1),
-- get the corresponding sample in the table.
--
-- The column @floor (u*n)@ is clamped to @[0, n-1]@ so that an
-- out-of-range @u@ (for example @u == 1@) can never index outside the
-- table.  For @u@ in [0,1) the clamp has no effect.
indexTable :: Table -> Double -> Int
indexTable (T qjs) u =
    let n           = lengthU qjs
        nu          = u * fromIntegral n
        l           = max 0 (min (n - 1) (floor nu))
        (ql :*: jl) = indexU qjs l
    -- The stored value is q[l] + l and the integer part of nu is l, so
    -- this compares the fractional part of nu against q[l].
    in if nu < ql then l else jl

-- | Get the size of the table.
tableSize :: Table -> Int
tableSize (T qjs) = lengthU qjs

-- | An intermediate result for use in computing a Table.
type STTable s = MUArr (Double :*: Int) s

-- | A partition of indices into the sets /Greater/ and /Smaller/.  The
-- indices of the /Smaller/ set are stored in positions @0, ..., numSmall - 1@,
-- and the indices of the /Greater/ set are stored in positions
-- @numSmall, ..., n-1@, where @n@ is the size of the underlying array.
data STPartition s = P !(MUArr Int s) !Int

-- | Given a list of weights, @ws@, compute corresponding probabilities, @ps@,
-- and store @map (n*) ps@ in the @qs@ array.  Partition the probabilities
-- into two sets, /Greater/, and /Smaller/ based on whether or not
-- @q >= 1@ or @q < 1@.
--
-- Only the first @n@ weights are used; if @ws@ is shorter than @n@, the
-- missing weights are taken to be zero.  Every table entry @i@ starts out
-- with itself as its alias, @j[i] = i@.
--
-- Raises an error if @n@ is negative, if a used weight fails the @w >= 0@
-- test (negative or NaN), or if the used weights sum to zero.
initTable :: Int -> [Double] -> ST s (STTable s, STPartition s)
initTable n ws = do
    when (n < 0) $ error "negative table size"
    sets <- newMU n :: ST s (MUArr Int s)
    qjs  <- newMU n :: ST s (MUArr (Double :*: Int) s)

    -- Store the weights in the table and compute their total.  The weight
    -- list is padded with zeros so that all n entries are written even when
    -- ws is shorter than n; otherwise the scaling pass below would read
    -- entries that were never initialized.
    let store current (i, w)
            | w >= 0 = do
                writeMU qjs i (w :*: i)
                return $! current + w
            | otherwise = error "negative probability"
    total <- foldM store 0 (zip [0 .. n-1] (ws ++ repeat 0))
    when (total == 0) $ error "no positive probabilities given"

    -- Scale the weights to get the qs, and partition the indices into the
    -- two sets: Smaller is filled from the front of `sets`, Greater from
    -- the back.
    let scale = fromIntegral n / total
        place (smaller, greater) i = do
            p <- liftM fstS $ readMU qjs i
            let q = scale * p
            writeMU qjs i (q :*: i)
            if q < 1
                then do
                    writeMU sets smaller i
                    return (smaller + 1, greater)
                else do
                    writeMU sets greater i
                    return (smaller, greater - 1)
    nsmall <- liftM fst $ foldM place (0, n-1) [0 .. n-1]
    return (qjs, P sets nsmall)

-- Given an initialized table and partition, compute the two-point
-- distributions by splitting the larger probabilities across multiple
-- distributions.
breakLarger :: STTable s -> STPartition s -> ST s ()
breakLarger qjs (P sets nsmall0) = go nsmall0 0
  where
    n = lengthMU qjs

    -- go nsmall i: positions i .. nsmall-1 of `sets` hold the Smaller
    -- indices not yet finalized; positions nsmall .. n-1 hold Greater.
    go nsmall i
        -- Greater is empty.  Any Smaller entries left over (presumably
        -- only possible through rounding) still have themselves as alias,
        -- so they always sample their own index.
        | nsmall >= n = return ()
        -- Smaller is empty.  Every remaining Greater entry has q >= 1 and
        -- itself as alias, so it always samples its own index.  Stopping
        -- here avoids pairing a Greater index with itself.
        | i >= nsmall = return ()
        | otherwise = do
            -- choose k from Greater, l from Smaller
            k  <- readMU sets nsmall
            l  <- readMU sets i
            qk <- liftM fstS $ readMU qjs k
            ql <- liftM fstS $ readMU qjs l
            -- set j[l] := k, finalizing (q[l], j[l])
            writeMU qjs l (ql :*: k)
            -- set q[k] := q[k] - (1 - q[l])
            let qk' = qk - (1 - ql)
            writeMU qjs k (qk' :*: k)
            -- If q[k] dropped below 1, move k from Greater to Smaller by
            -- advancing the boundary past it; it is finalized later as a
            -- Smaller entry.
            go (if qk' < 1 then nsmall + 1 else nsmall) (i + 1)

-- Scale the probabilities in the table so that the lth entry
-- stores q[l] + l instead of q[l].  This helps when we are sampling
-- from the table: the integer part of u*n selects the entry l, and the
-- whole of u*n can then be compared against the stored value directly.
scaleTable :: STTable s -> ST s ()
scaleTable qjs =
    forM_ [0 .. lengthMU qjs - 1] $ \l -> do
        (ql :*: jl) <- readMU qjs l
        writeMU qjs l ((ql + fromIntegral l) :*: jl)