module Control.Monad.MC.Walker (
Table,
computeTable,
indexTable,
tableSize,
component,
) where
import Control.Monad
import Control.Monad.ST
import Data.Vector.Unboxed( Vector, MVector )
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Generic.Mutable as MV
-- | An alias table for Walker's alias method, as built by 'computeTable'.
--
-- Slot @i@ stores @(q + i, j)@. Here @q@ is the probability of keeping
-- index @i@ once slot @i@ has been chosen, and @j@ is the alias index used
-- otherwise. The @+ i@ offset lets 'indexTable' draw both the slot and the
-- coin flip from one uniform variate.
newtype Table =
    T (Vector (Double, Int))
-- | Inspect slot @i@ of an alias table.
--
-- Returns @(q, j)@:
--
-- * @q@ is the probability that slot @i@ yields @i@ itself.
-- * @j@ is the alias yielded otherwise.
--
-- The stored first component carries a @+ i@ offset, added when the table
-- is built; this function removes it.
--
-- The index is not bounds-checked. Callers must ensure
-- @0 <= i < 'tableSize' t@.
component :: Table -> Int -> (Double,Int)
component (T qjs) i =
    let (q', j) = V.unsafeIndex qjs i
        -- Undo the slot-index offset stored alongside each probability.
        q = q' - fromIntegral i
    in (q, j)
-- | Build an alias table for sampling the indices @0 .. n-1@ with
-- probabilities proportional to the weights @ws@.
--
-- Only the first @n@ weights are used. The weights need not be
-- normalised. Construction fails (via 'fail') if @n@ is negative, if a
-- weight is negative, or if the weights sum to zero.
computeTable :: Int -> [Double] -> Table
computeTable n ws = T (V.create build)
  where
    -- The explicit signature keeps 'build' polymorphic in the state
    -- thread, as 'V.create' requires.
    build :: ST s (STTable s)
    build = do
        (table, parts) <- initTable n ws
        breakLarger table parts
        scaleTable table
        pure table
-- | Map a uniform variate @u@ in @[0,1)@ to a sampled index using the
-- alias table.
--
-- @u * n@ is split into two parts:
--
-- * the slot @l@ (its integer part);
-- * a coin flip (its fractional part).
--
-- Slot @l@ is returned when the fractional part is below the slot's
-- probability; otherwise its alias is returned. The table stores each
-- probability offset by its slot index, so the comparison is made directly
-- against @u * n@.
--
-- The slot is clamped to @[0, n-1]@. This ensures that a variate outside
-- @[0,1)@ (for example exactly @1@, or a negative value) never makes the
-- unchecked lookup read outside the table.
indexTable :: Table -> Double -> Int
indexTable (T qjs) u =
    let n  = V.length qjs
        nu = u * fromIntegral n
        -- For u in [0,1), floor nu is already in range. The clamp only
        -- matters for out-of-range variates, which would otherwise make
        -- V.unsafeIndex read out of bounds.
        l  = max 0 (min (n - 1) (floor nu))
        (ql, jl) = V.unsafeIndex qjs l
    in if nu < ql then l else jl
-- | Number of slots in the table, i.e. the @n@ given to 'computeTable'.
tableSize :: Table -> Int
tableSize table = case table of
    T entries -> V.length entries
-- | Mutable alias table used while building a 'Table'. It holds one
-- @(probability, alias)@ pair per slot.
type STTable s =
    MVector s (Double, Int)
-- | Slot partition produced by 'initTable'.
--
-- The vector holds every slot index. The "small" slots (scaled probability
-- below 1) occupy positions @0 .. k-1@, and the "large" slots follow them.
-- The 'Int' field is @k@, the number of small slots.
data STPartition s =
    P !(MVector s Int) !Int
-- | Allocate and fill the working table and slot partition for
-- 'computeTable'.
--
-- Each weight is scaled by @n / total@, so the scaled values sum to @n@.
-- The result is stored as @(q, i)@ in slot @i@, so every slot initially
-- aliases itself. Slots with @q < 1@ are recorded from the front of the
-- partition vector, and slots with @q >= 1@ from the back. The returned
-- count is the number of small slots.
--
-- Only the first @n@ weights are read. If fewer than @n@ are supplied,
-- the missing ones are treated as zero.
--
-- Fails (via 'fail') when:
--
-- * @n@ is negative;
-- * any weight read is negative or NaN;
-- * the weights sum to zero.
initTable :: Int -> [Double] -> ST s (STTable s, STPartition s)
initTable n ws = do
    when (n < 0) $ fail "negative table size"
    sets <- MV.new n :: ST s (MVector s Int)
    qjs  <- MV.new n :: ST s (MVector s (Double, Int))

    -- Store the raw weights and accumulate their total strictly.
    -- Padding with zeros ensures every slot is written here before the
    -- rescaling pass below reads it, even when fewer than n weights are
    -- given.
    total <- foldM (\current (i, w) ->
                        if w >= 0
                            then do
                                MV.unsafeWrite qjs i (w, i)
                                return $! current + w
                            else
                                fail $ "negative probability")
                   0
                   (zip [0 .. n - 1] (ws ++ repeat 0))
    when (total == 0) $ fail "no positive probabilities given"

    -- Rescale so the values sum to n. Then partition the slots:
    -- "small" slots (q < 1) are filled from position 0 upwards, and
    -- "large" slots (q >= 1) from position n-1 downwards.
    let scale = fromIntegral n / total
    nsmall <- liftM fst $
        foldM (\(smaller, greater) i -> do
                   p <- liftM fst $ MV.unsafeRead qjs i
                   let q = scale * p
                   MV.unsafeWrite qjs i (q, i)
                   if q < 1
                       then do
                           MV.unsafeWrite sets smaller i
                           return (smaller + 1, greater)
                       else do
                           MV.unsafeWrite sets greater i
                           return (smaller, greater - 1))
              (0, n - 1)
              [0 .. n - 1]
    return (qjs, P sets nsmall)
-- | Pair each "small" slot with a "large" one. This is the core step of
-- Walker's alias method.
--
-- Small slots are consumed in order from the front of the partition
-- vector. On each step:
--
-- * The current large slot @k@ becomes the alias of the small slot @l@.
-- * @k@'s probability is reduced by the mass @1 - q_l@ it donated.
-- * If @k@ drops below 1, the first-large position is advanced past it.
--   This reclassifies @k@ as small, so it is processed later in that role.
--
-- The loop stops when either the small or the large slots are exhausted.
breakLarger :: STTable s -> STPartition s -> ST s ()
breakLarger qjs (P sets nsmall)
    | nsmall == 0 = return ()
    | otherwise   = go nsmall 0
  where
    n = MV.length qjs

    -- Loop state:
    --   nsmall' is the position in sets of the current large slot;
    --   i is the position in sets of the next small slot to fill.
    go nsmall' i
        | nsmall' == n = return ()
        | i == n       = return ()
        | otherwise    = do
            k  <- MV.unsafeRead sets nsmall'
            l  <- MV.unsafeRead sets i
            qk <- liftM fst $ MV.unsafeRead qjs k
            ql <- liftM fst $ MV.unsafeRead qjs l
            -- Slot l keeps probability ql and falls back to k otherwise.
            MV.unsafeWrite qjs l (ql, k)
            -- k gave (1 - ql) of its mass to complete slot l.
            let qk' = qk - (1 - ql)
            MV.unsafeWrite qjs k (qk', k)
            let nsmall'' = if qk' < 1 then nsmall' + 1 else nsmall'
            go nsmall'' (i + 1)
-- | Add each slot's own index to its stored probability. Slot @l@ changes
-- from @(q, j)@ to @(q + l, j)@.
--
-- With this offset, 'indexTable' can compare one scaled variate @u * n@
-- against the stored value instead of separating slot and coin flip.
-- 'component' subtracts the offset again.
scaleTable :: STTable s -> ST s ()
scaleTable qjs =
    forM_ [0 .. MV.length qjs - 1] $ \l -> do
        (ql, jl) <- MV.unsafeRead qjs l
        MV.unsafeWrite qjs l (ql + fromIntegral l, jl)