-----------------------------------------------------------------------------
-- |
-- Module     : Control.Monad.MC.Walker
-- Copyright  : Copyright (c) 2010, Patrick Perry <patperry@gmail.com>
-- License    : BSD3
-- Maintainer : Patrick Perry <patperry@gmail.com>
-- Stability  : experimental
--
-- An implementation of Walker's Alias method for sampling from discrete
-- distributions.  See section III.4 of Luc Devroye's book
-- "Non-Uniform Random Variate Generation", which is available on his
-- homepage, for a description of how it works.
module Control.Monad.MC.Walker (
    Table,
    computeTable,
    indexTable,
    tableSize,
    component,
    ) where

import Control.Monad
import Control.Monad.ST
import Data.Vector.Unboxed( Vector, MVector )
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Generic.Mutable as MV

-- | The table, which represents an equiprobable mixture of two-point
-- distributions.  The @l@th entry of the table represents a mixture
-- distribution with weight @q[l]@ on @l@ and weight @(1-q[l])@ on @j[l]@.
-- The @l@th element of the table stores the pair @q[l] :*: j[l]@.
newtype Table = T (Vector (Double, Int))

-- | Get the @i@th mixture component.  That is, return @q[i]@ and @j[i]@,
-- where the @i@th mixture component puts mass @q[i]@ on @i@ and mass
-- @1 - q[i]@ on @j[i]@.
component :: Table -> Int -> (Double,Int)
component (T qjs) i = let
    (q', j) =  V.unsafeIndex qjs i
    q = q' - fromIntegral i
    in (q,j)

-- | Compute the table for use in Walker's aliasing method.
computeTable :: Int -> [Double] -> Table
computeTable n ws = T $ V.create $ do
    (qjs, sets) <- initTable n ws
    breakLarger qjs sets
    scaleTable qjs
    return qjs

-- | Given an alias table and a number in the range [0,1),
-- get the corresponding sample in the table.
indexTable :: Table -> Double -> Int
indexTable (T qjs) u = let
    n  = V.length qjs
    nu = u * fromIntegral n
    l  = floor nu
    (ql,jl) = V.unsafeIndex qjs l
    in if nu < ql then l else jl

-- | Get the size of the table
tableSize :: Table -> Int
tableSize (T qjs) = V.length qjs

-- | An intermediate result for use in computing a Table.
type STTable s = MVector s (Double, Int)

-- | A partition of indices into the sets /Greater/ and /Smaller/.  The
-- indices of the /Smaller/ set are stored in positions @0, ..., numSmall - 1@,
-- and the indices of the /Greater/ set are stored in positions
-- @numSmall, ..., n-1@, where @n@ is the size of the underlying array.
data STPartition s = P !(MVector s Int)
                       !Int

-- | Given a list of weights, @ws@, compute corresponding probabilities, @ps@,
-- and store @map (n*) ps@ in the @qs@ array.  Partition the probabilities
-- into two sets, /Greater/, and /Smaller/ based on whether or not
-- @q >= 1@ or @q < 1@.
initTable :: Int -> [Double] -> ST s (STTable s, STPartition s)
initTable n ws = do
    when (n < 0) $ fail "negative table size"
    sets <- MV.new n :: ST s (MVector s Int)
    qjs  <- MV.new n :: ST s (MVector s (Double, Int))

    -- Store the weights in the table and compute their total.
    total <-
        foldM (\current (i,w) -> do
                  if w >= 0
                      then do
                          MV.unsafeWrite qjs i (w,i)
                          return $! current + w
                      else
                          fail $ "negative probability" )
              0
              (zip [0 .. n-1] ws)

    when (total == 0) $ fail "no positive probabilities given"

    -- scale the weights to get the qs, and partition the probabilites
    -- into the two sets
    let scale = fromIntegral n / total
    nsmall <- liftM fst $
        foldM (\(smaller,greater) i -> do
               p <- liftM fst $ MV.unsafeRead qjs i
               let q = scale*p
               MV.unsafeWrite qjs i (q,i)
               if q < 1
                   then do
                       MV.unsafeWrite sets smaller i
                       return (smaller+1,greater)
                   else do
                       MV.unsafeWrite sets greater i
                       return (smaller,greater-1) )
              (0,n-1)
              [0 .. n-1]

    return $ (qjs, P sets nsmall)


-- Given an initialized table and partition, compute the two-point
-- distributions by splitting the larger probabilites sccross multiple
-- distribions.
breakLarger :: STTable s -> STPartition s -> ST s ()
breakLarger qjs (P sets nsmall) | nsmall == 0 = return ()
                                | otherwise   = let
    n = MV.length qjs
    breakLargerHelp nsmall' i | nsmall' == n = return ()
                              | i == n       = return ()
                              | otherwise    = do
        -- while Greater is not empty
        -- choose k from Greater, l from Smaller
        k  <- MV.unsafeRead sets $ nsmall'
        l  <- MV.unsafeRead sets $ i
        qk <- liftM fst $ MV.unsafeRead qjs k
        ql <- liftM fst $ MV.unsafeRead qjs l

        -- set jl := k, finalize (ql,jl)
        let jl = k
        MV.unsafeWrite qjs l (ql,jl)

        -- set qk := qk - (1-ql)
        let qk' = qk - (1-ql)
        MV.unsafeWrite qjs k (qk',k)

        -- if qk' < 1, move k from Greater to Smaller
        let nsmall'' = if qk' < 1 then nsmall'+1 else nsmall'

        breakLargerHelp nsmall'' (i+1)
    in
        breakLargerHelp nsmall 0

-- Scale the probabilities in the table so that the lth entry
-- stores q[l] + l instead of q[l].  This helps when we are sampling
-- from the table.
scaleTable :: STTable s -> ST s ()
scaleTable qjs = let
    n = MV.length qjs in
    forM_ [ 0..(n-1) ] $ \l -> do
        (ql, jl) <- MV.unsafeRead qjs l
        MV.unsafeWrite qjs l ((ql + fromIntegral l), jl)