-- |
-- Module      : Control.Monad.MC.Sample
--
-- Random sampling inside 'MonadMC' computations: single draws from a list
-- or from the indices @[0 .. n-1]@ (uniform or weighted), subsets drawn
-- without replacement (uniform or weighted), and random permutations.
--
-- 'sampleIntSubset', 'sampleIntSubsetWithWeights' and 'shuffleInt' (and
-- the list-level functions built on them) produce their result lists
-- incrementally via @unsafeInterleaveMC@ / @unsafeInterleaveST@.  The
-- primed variants force the spine of the result list before returning it.
module Control.Monad.MC.Sample (
sample,
sampleWithWeights,
sampleSubset,
sampleSubset',
sampleSubsetWithWeights,
sampleSubsetWithWeights',
sampleInt,
sampleIntWithWeights,
sampleIntSubset,
sampleIntSubset',
sampleIntSubsetWithWeights,
sampleIntSubsetWithWeights',
shuffle,
shuffleInt,
shuffleInt',
) where
import Control.Monad
import Control.Monad.ST
import Control.Monad.MC.Base
import Control.Monad.MC.Repeat
import Control.Monad.MC.Walker
import Data.List( foldl', sort )
import Data.Vector.Unboxed( MVector, Unbox )
import qualified Data.Vector as BV
import qualified Data.Vector.Mutable as BMV
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Generic.Mutable as MV
-- | Pick one element of a list.  The position is chosen by 'sampleInt'
-- applied to the list length, so an empty list fails with
-- @"invalid argument"@.
sample :: (MonadMC m) => [a] -> m a
sample xs = sampleHelp len xs (sampleInt len)
  where
    len = length xs
-- | Pick one value from a list of @(weight, value)@ pairs.  The weights
-- are handed to 'sampleIntWithWeights', which chooses the position that
-- is returned.
sampleWithWeights :: (MonadMC m) => [(Double, a)] -> m a
sampleWithWeights wxs =
    sampleHelp count values (sampleIntWithWeights weights count)
  where
    (weights, values) = unzip wxs
    count = length values
-- | @sampleSubset xs k@ returns @k@ elements of @xs@ drawn without
-- replacement; positions come from 'sampleIntSubset', so this fails when
-- @k < 0@ or @k > length xs@.  The result list is built lazily; see
-- 'sampleSubset'' for a spine-strict version.
sampleSubset :: (MonadMC m) => [a] -> Int -> m [a]
sampleSubset xs k = sampleListHelp len xs (sampleIntSubset len k)
  where
    len = length xs
-- | Like 'sampleSubset', but the spine of the result list is forced
-- before it is returned.
sampleSubset' :: (MonadMC m) => [a] -> Int -> m [a]
sampleSubset' xs k = sampleSubset xs k >>= forceSpine
  where
    forceSpine ys = length ys `seq` return ys
-- | @sampleSubsetWithWeights wxs k@ returns @k@ values from the
-- @(weight, value)@ pairs without replacement.  The positions are chosen
-- by 'sampleIntSubsetWithWeights' from the weights.
sampleSubsetWithWeights :: (MonadMC m) => [(Double,a)] -> Int -> m [a]
sampleSubsetWithWeights wxs k =
    sampleListHelp count values (sampleIntSubsetWithWeights weights count k)
  where
    (weights, values) = unzip wxs
    count = length weights
-- | Like 'sampleSubsetWithWeights', but the spine of the result list is
-- forced before it is returned.
sampleSubsetWithWeights' :: (MonadMC m) => [(Double,a)] -> Int -> m [a]
sampleSubsetWithWeights' wxs k =
    sampleSubsetWithWeights wxs k >>= \picked -> length picked `seq` return picked
-- | Run an index-producing action and return the list element at that
-- index.  The list is copied into a boxed vector once per call, so the
-- lookup is O(1).  The @Int@ argument is ignored.  The action must only
-- yield indices in @[0, length xs)@: 'BV.unsafeIndex' does no bounds check.
sampleHelp :: (Monad m) => Int -> [a] -> m Int -> m a
sampleHelp _ xs getIndex = getIndex >>= return . pick
  where
    pick = BV.unsafeIndex (BV.fromList xs)
-- | Unboxed-vector counterpart of 'sampleHelp': run the index action and
-- return the element of @xs@ at that index.  The @Int@ argument is
-- ignored; indices must lie in @[0, length xs)@ because 'V.unsafeIndex'
-- does no bounds check.
--
-- NOTE(review): not exported and not referenced elsewhere in this file as
-- seen here; confirm whether it should be exported or removed.
sampleHelpU :: (Unbox a, Monad m) => Int -> [a] -> m Int -> m a
sampleHelpU _ xs getIndex = getIndex >>= \i -> return (V.unsafeIndex table i)
  where
    table = V.fromList xs
-- | Run an action producing a list of indices and map each index to the
-- corresponding element of @xs@ (via a boxed vector, O(1) per lookup).
-- The mapping is lazy, so a lazily produced index list stays lazy.  The
-- @Int@ argument is ignored; every index must lie in @[0, length xs)@.
sampleListHelp :: (Monad m) => Int -> [a] -> m [Int] -> m [a]
sampleListHelp _ xs getIndices =
    getIndices >>= \is -> return [ BV.unsafeIndex table i | i <- is ]
  where
    table = BV.fromList xs
-- | Unboxed-vector counterpart of 'sampleListHelp'.  The @Int@ argument
-- is ignored; every index must lie in @[0, length xs)@.
--
-- NOTE(review): not exported and not referenced elsewhere in this file as
-- seen here; confirm whether it should be exported or removed.
sampleListHelpU :: (Unbox a, Monad m) => Int -> [a] -> m [Int] -> m [a]
sampleListHelpU _ xs getIndices =
    getIndices >>= \is -> return [ V.unsafeIndex table i | i <- is ]
  where
    table = V.fromList xs
-- | @sampleInt n@ draws an integer with @uniformInt n@ (presumably uniform
-- on @[0, n)@ — defined in "Control.Monad.MC.Base").  Fails with
-- @"invalid argument"@ when @n < 1@.
sampleInt :: (MonadMC m) => Int -> m Int
sampleInt n
    | n >= 1    = uniformInt n
    | otherwise = fail "invalid argument"
-- | @sampleIntWithWeights ws n@ draws one index using the table built by
-- @computeTable n ws@ and a single @uniform 0 1@ variate looked up with
-- @indexTable@ (both from "Control.Monad.MC.Walker"; presumably Walker's
-- alias method — confirm there).  No validation of @ws@ or @n@ is done here.
sampleIntWithWeights :: (MonadMC m) => [Double] -> Int -> m Int
sampleIntWithWeights ws n = uniform 0 1 >>= return . indexTable table
  where
    table = computeTable n ws
-- | @sampleIntSubset n k@ draws @k@ distinct integers from @[0 .. n-1]@
-- without replacement.
--
-- A vector holding @0 .. n-1@ is kept; each draw takes the value at a
-- random live slot and overwrites that slot with the value in the last
-- live slot, shrinking the live region by one (a partial Fisher-Yates
-- shuffle).
--
-- Fails with @"negative subset size"@ when @k < 0@ and with
-- @"subset size is too big"@ when @k > n@.
--
-- Both the random slot choices ('unsafeInterleaveMC') and the result list
-- ('unsafeInterleaveST') are produced on demand; use 'sampleIntSubset''
-- to force the spine of the result.
sampleIntSubset :: (MonadMC m) => Int -> Int -> m [Int]
sampleIntSubset n k
    | k < 0     = fail "negative subset size"
    | k > n     = fail "subset size is too big"
    | otherwise = do
        us <- randomIndices n k
        return $ runST $ do
            ints <- MV.new n :: ST s (MVector s Int)
            -- Start from the identity arrangement 0, 1, .., n-1.
            sequence_ [ MV.unsafeWrite ints i i | i <- [0 .. n - 1] ]
            sampleIntSubsetHelp ints us (n - 1)
  where
    -- k' slot choices; the i-th (0-based) is drawn with uniformInt (n' - i),
    -- matching the size of the live region at that step.
    randomIndices n' k'
        | k' == 0   = return []
        | otherwise = unsafeInterleaveMC $ do
            u <- uniformInt n'
            us <- randomIndices (n' - 1) (k' - 1)
            return (u:us)

    -- n' is the index of the last live slot.  Each step runs only when its
    -- list cell is demanded, and the next step's thunk is created inside
    -- this one, so the reads and writes still happen in sequence.
    sampleIntSubsetHelp _ [] _ = return []
    sampleIntSubsetHelp ints (u:us) n' = unsafeInterleaveST $ do
        i <- MV.unsafeRead ints u
        -- Move the last live value into the vacated slot u.
        MV.unsafeWrite ints u =<< MV.unsafeRead ints n'
        is <- sampleIntSubsetHelp ints us (n' - 1)
        return (i:is)
-- | Like 'sampleIntSubset', but the spine of the result list is forced
-- before it is returned.
sampleIntSubset' :: (MonadMC m) => Int -> Int -> m [Int]
sampleIntSubset' n k = sampleIntSubset n k >>= \is -> length is `seq` return is
-- | @sampleIntSubsetWithWeights ws n k@ draws @k@ distinct indices from
-- @[0 .. n-1]@ without replacement.  On each draw, an index is chosen by
-- scaling a @uniform 0 1@ variate by the total remaining weight and scanning
-- cumulative weights, i.e. in proportion to its weight among the indices
-- not yet drawn.
--
-- Only the first @n@ weights are used.  @ws@ must contain at least @n@
-- elements; otherwise trailing vector slots are read without ever having
-- been written.
--
-- Fails with @"negative subset size"@ when @k < 0@ and with
-- @"subset size is too big"@ when @k > n@ (previously @k > n@ ran the
-- search past the end of the vector).
--
-- Cost: O(n log n) for the initial sort plus O(n) per draw.  The result
-- list is produced on demand via 'unsafeInterleaveST'.
sampleIntSubsetWithWeights :: (MonadMC m) => [Double] -> Int -> Int -> m [Int]
sampleIntSubsetWithWeights ws n k
    | k < 0     = fail "negative subset size"
    | k > n     = fail "subset size is too big"
    | otherwise = do
        us <- replicateMC k $ uniform 0 1
        return $ runST $ do
            ints <- MV.new n :: ST s (MVector s (Double,Int))
            sequence_ [ MV.unsafeWrite ints i wj | (i,wj) <- zip [ 0.. ] wjs ]
            -- Weights are normalised, so the initial live total is 1.
            go ints n 1 us
  where
    w_sum0 = foldl' (+) 0 $ take n ws

    -- Normalised weights paired with their indices, largest weight first,
    -- so the linear scan in findTarget tends to stop early.
    wjs = [ (w / w_sum0, j) | (w,j) <- reverse $ sort $ zip ws [ 0..n-1 ] ]

    -- One draw per uniform variate.  Live entries occupy slots
    -- [0 .. n'-1] and their weights sum (up to rounding) to w_sum.  The
    -- k <= n guard guarantees n' >= 1 whenever a variate remains, so
    -- findTarget always has a last live slot to stop at.
    go _ _ _ [] = return []
    go ints n' w_sum (u:us) = unsafeInterleaveST $ do
        (i,(w,j)) <- findTarget ints n' (u * w_sum) 0 0
        -- Remove slot i by shifting the rest of the live region left.
        shiftDown ints (i + 1) (n' - 1)
        js <- go ints (n' - 1) (w_sum - w) us
        return (j:js)

    -- First slot whose cumulative weight reaches the target.  The last live
    -- slot is returned unconditionally, which also absorbs rounding error
    -- in the running total.
    findTarget ints n' target i acc
        | i == n' - 1 = do
            wj <- MV.unsafeRead ints i
            return (i,wj)
        | otherwise = do
            (w,j) <- MV.unsafeRead ints i
            let acc' = acc + w
            if target <= acc'
                then return (i,(w,j))
                else findTarget ints n' target (i + 1) acc'

    -- Copy slots [from .. to] one position to the left.
    shiftDown ints from to =
        forM_ [ from..to ] $ \i -> do
            wj <- MV.unsafeRead ints i
            MV.unsafeWrite ints (i - 1) wj
-- | Like 'sampleIntSubsetWithWeights', but the spine of the result list is
-- forced before it is returned.
sampleIntSubsetWithWeights' :: (MonadMC m) => [Double] -> Int -> Int -> m [Int]
sampleIntSubsetWithWeights' ws n k = sampleIntSubsetWithWeights ws n k >>= strict
  where
    strict js = length js `seq` return js
-- | Randomly permute a list.
--
-- The list is copied into a mutable boxed vector.  The transpositions
-- produced by 'shuffleInt' (a Fisher-Yates sequence) are applied in order,
-- and the vector is converted back to a list.
shuffle :: (MonadMC m) => [a] -> m [a]
shuffle xs = shuffleInt n >>= \swaps -> (return . BV.toList) $ BV.create $ do
    marr <- MV.new n
    zipWithM_ (MV.unsafeWrite marr) [0 .. n - 1] xs
    mapM_ (swap marr) swaps
    return marr
  where
    n = length xs

    -- Exchange the contents of two slots; a self-swap is skipped.
    swap :: BMV.MVector s a -> (Int,Int) -> ST s ()
    swap marr (i,j)
        | i == j    = return ()
        | otherwise = do
            x <- MV.unsafeRead marr i
            y <- MV.unsafeRead marr j
            MV.unsafeWrite marr i y
            MV.unsafeWrite marr j x
-- | Unboxed-vector variant of 'shuffle' for element types with an 'Unbox'
-- instance.  It applies the transpositions from 'shuffleInt' to an unboxed
-- mutable copy of the list.
--
-- NOTE(review): not exported and not referenced elsewhere in this file as
-- seen here; confirm whether it should be exported or removed.
shuffleU :: (Unbox a, MonadMC m) => [a] -> m [a]
shuffleU xs = shuffleInt n >>= \swaps -> (return . V.toList) $ V.create $ do
    marr <- MV.new n
    zipWithM_ (MV.unsafeWrite marr) [0 .. n - 1] xs
    mapM_ (swap marr) swaps
    return marr
  where
    n = length xs

    -- Exchange the contents of two slots; a self-swap is skipped.
    swap :: (Unbox a) => MVector s a -> (Int,Int) -> ST s ()
    swap marr (i,j)
        | i == j    = return ()
        | otherwise = do
            x <- MV.unsafeRead marr i
            y <- MV.unsafeRead marr j
            MV.unsafeWrite marr i y
            MV.unsafeWrite marr j x
-- | @shuffleInt n@ returns the transpositions of a Fisher-Yates shuffle of
-- @n@ slots.  These are pairs @(i, j)@ for @i = n-1, n-2, .., 1@, in that
-- order, where @j@ is drawn with @uniformInt (i+1)@.  Applying them in order
-- to @[0 .. n-1]@ permutes it.  Returns @[]@ when @n <= 1@.
--
-- The list is generated on demand via 'unsafeInterleaveMC'; see
-- 'shuffleInt'' for a spine-strict version.
shuffleInt :: (MonadMC m) => Int -> m [(Int,Int)]
shuffleInt n = shuffleIntHelp n
  where
    shuffleIntHelp i
        | i <= 1    = return []
        | otherwise = unsafeInterleaveMC $ do
            j <- uniformInt i
            ijs <- shuffleIntHelp (i - 1)
            return ((i - 1, j) : ijs)
-- | Like 'shuffleInt', but the spine of the transposition list is forced
-- before it is returned.
shuffleInt' :: (MonadMC m) => Int -> m [(Int,Int)]
shuffleInt' n = shuffleInt n >>= \swaps -> length swaps `seq` return swaps