module Control.Monad.Random (
RandPicker (..),
MonadRand,
Rand (..),
evalRand, execRand,
rand, oneOf, inRange,
fromFreqs, withFreq,
RandT (..),
evalRandT, execRandT
) where
import System.Random
import Data.Array.IArray ((!), listArray, Array)
import Control.Applicative
import Control.Monad
import Control.Arrow (first, second)
import Control.Monad.Trans
import Control.Monad.Reader.Class
import Control.Monad.State.Class
import Control.Monad.Writer.Class
import qualified Data.IntervalMap as IM
-- | Monads able to execute a pure 'Rand' computation, supplying the
-- generator from their own context and keeping its successor.
class RandPicker m where
  -- | Execute a 'Rand' computation inside @m@.
  pick :: Rand a -> m a
-- | Constraint shorthand: a monad that can also run 'Rand' actions via 'pick'.
type MonadRand m = (RandPicker m, Monad m)
-- | Run a 'Rand' computation against the global 'StdGen'.
--
-- 'getStdRandom' reads the global generator, runs the computation and
-- stores the successor generator as one atomic update. A separate
-- 'getStdGen' followed by 'setStdGen' lets two threads read the same
-- generator state and therefore draw the same values.
instance RandPicker IO where
  pick r = getStdRandom (runRand r)
-- | A pure random computation producing a value of type @a@.
--
-- The computation works for any 'RandomGen': given a generator it returns
-- the result together with the updated generator. The @forall g.@ is
-- written out explicitly. Relying on @g@ being quantified implicitly
-- through the 'RandomGen' context alone is not accepted by newer GHC
-- releases.
newtype Rand a = Rand { runRand :: forall g. RandomGen g => g -> (a, g) }
-- | Sequencing threads the generator: the second computation is run with
-- the generator left behind by the first.
instance Monad Rand where
  return x = Rand (\ g -> (x, g))
  m >>= k = Rand (\ g ->
    let (a, g1) = runRand m g
    in runRand (k a) g1)
-- | Map over the produced value; the generator passes through untouched.
instance Functor Rand where
  fmap f m = Rand (\ g -> first f (runRand m g))
-- | Applicative structure derived from the 'Monad' instance: the function
-- is drawn first, then the argument, threading one generator through both.
instance Applicative Rand where
  pure x = Rand (\ g -> (x, g))
  (<*>) = ap
-- | A 'Rand' computation runs inside 'Rand' as itself.
instance RandPicker Rand where
  pick r = r
-- | Run a 'Rand' computation with the given generator, keeping only the
-- result and discarding the updated generator.
evalRand :: RandomGen g => Rand a -> g -> a
evalRand v = fst . runRand v
-- | Run a 'Rand' computation with the given generator, keeping only the
-- updated generator and discarding the result.
execRand :: RandomGen g => Rand a -> g -> g
execRand v = snd . runRand v
-- | Pick one element of a non-empty list, choosing its index with
-- 'inRange' over @[0, length xs - 1]@.
--
-- The list is converted to an 'Array' so that the chosen index is looked
-- up in constant time. Calls 'error' when given an empty list.
oneOf :: [a] -> Rand a
oneOf [] = error "Control.Monad.Random.oneOf: empty list"
oneOf xs = fmap (arr !) (inRange ixBounds)
  where
    -- Inclusive bounds: the last valid index is one less than the length.
    ixBounds = (0, length xs - 1)
    arr = (listArray ixBounds :: [a] -> Array Int a) xs
-- | Draw a value with its 'Random' instance's default generator function,
-- 'random'.
rand :: Random a => Rand a
rand = Rand (\ g -> random g)
-- | Draw a value within the given closed @(low, high)@ range using
-- 'randomR'.
inRange :: Random a => (a, a) -> Rand a
inRange bounds = Rand (\ g -> randomR bounds g)
-- | Choose a value at random, weighting each candidate by its frequency.
--
-- Candidates whose frequency is not strictly positive are discarded. The
-- remaining frequencies are laid end to end as consecutive intervals:
--
-- * @[0, f1)@, @[f1, f1 + f2)@, and so on;
-- * the last interval is closed, so it ends exactly at the total.
--
-- One raw generator output ('next') is scaled linearly from 'genRange'
-- onto @[0, total]@. The candidate whose interval contains that point is
-- returned.
--
-- Calls 'error' if no candidate has a positive frequency.
fromFreqs :: Real b => [(a, b)] -> Rand a
fromFreqs fs = Rand (\ g ->
    let (lo, hi) = genRange g
        (i, g') = next g
        -- Convert to 'Rational' before subtracting. When the generator's
        -- range spans (nearly) all of 'Int', e.g. @(minBound, maxBound)@,
        -- computing @hi - lo@ or @i - lo@ in 'Int' overflows.
        width = toRational hi - toRational lo
        point = (freqSum / width) * (toRational i - toRational lo)
    in case IM.containing intervalMap point of
         [(_, x)] -> (x, g')
         _ -> error "Index not in the map.")
  where
    elems = preprocess fs
    freqSum = sum (map snd elems)
    intervalMap = IM.fromAscList (computeIntervals 0 elems)
    -- Keep only positive frequencies, converted to exact rationals.
    preprocess = map (second toRational) . filter ((> 0) . snd)
    -- Build ascending, non-overlapping intervals. The final one is closed
    -- so the largest scaled point (the total itself) is still covered.
    computeIntervals _ [] =
      error "Control.Monad.Random.fromFreqs: no positive frequencies"
    computeIntervals lower [(v, f)] =
      [(IM.ClosedInterval lower (lower + f), v)]
    computeIntervals lower ((v, f) : rest) =
      let upper = lower + f
      in (IM.IntervalCO lower upper, v) : computeIntervals upper rest
-- | Pair a value with its relative frequency, for building the input of
-- 'fromFreqs'. @x `withFreq` n@ is simply @(x, n)@.
withFreq :: Real b => a -> b -> (a, b)
withFreq x freq = (x, freq)
-- | Random computation transformer over an inner monad @m@.
--
-- Given any 'RandomGen', it runs an @m@ action producing the result paired
-- with the updated generator. The @forall g.@ is spelled out explicitly.
-- Leaving @g@ to be quantified implicitly through its context alone is not
-- accepted by newer GHC releases.
newtype RandT m a = RandT { runRandT :: forall g. RandomGen g => g -> m (a, g) }
-- | Map over the result inside the inner functor, leaving the generator
-- component of each pair untouched.
instance Functor m => Functor (RandT m) where
  fmap f r = RandT (\ g -> first f <$> runRandT r g)
-- | Only 'Applicative' is available for @m@, so the generator cannot be
-- threaded from the function's run into the argument's run. Instead the
-- generator is 'split':
--
-- * the function side runs with the left half, and its leftover
--   generator is returned;
-- * the argument side runs with the right half, and its leftover
--   generator is dropped.
instance Applicative m => Applicative (RandT m) where
  pure x = RandT (\ g -> pure (x, g))
  rf <*> rx = RandT (\ g ->
    let (gl, gr) = split g
        step (h, gOut) res = (h (fst res), gOut)
    in liftA2 step (runRandT rf gl) (runRandT rx gr))
-- | Sequencing threads the generator through the inner monad: the
-- continuation runs with the generator produced by the first computation.
instance Monad m => Monad (RandT m) where
  return x = RandT (\ g -> return (x, g))
  r >>= f = RandT (\ g -> do
    (x, g') <- runRandT r g
    runRandT (f x) g')
  -- NOTE(review): 'fail' stopped being a 'Monad' method in base-4.13
  -- (GHC 8.8); newer compilers need this moved into a 'MonadFail' instance.
  fail err = RandT (const (fail err))
-- | Run a pure 'Rand' computation with the transformer's current generator
-- and return its result and successor generator in the inner monad.
instance Monad m => RandPicker (RandT m) where
  pick r = RandT (\ gen -> return (runRand r gen))
-- | Lifting an inner action leaves the generator unchanged.
instance MonadTrans RandT where
  lift m = RandT (\ g -> do
    x <- m
    return (x, g))
-- | Reader operations are delegated to the underlying monad.
--
-- 'local' must run the wrapped computation itself under the modified
-- environment. The previous version ran it first with the unmodified
-- environment and applied 'local' only to a trailing 'return', so the
-- modification never took effect. The generator is threaded through
-- unchanged.
instance MonadReader r m => MonadReader r (RandT m) where
  ask = lift ask
  local f m = RandT (\ g -> local f (runRandT m g))
-- | Writer operations are delegated to the underlying monad.
--
-- 'listen' and 'pass' reshuffle the tuples so the generator stays paired
-- with the result while the output (or output transformer) is exposed to
-- the inner monad's 'listen' / 'pass'.
instance MonadWriter w m => MonadWriter w (RandT m) where
  tell w = lift (tell w)
  listen r = RandT (\ g ->
    listen (runRandT r g) >>= \ ((x, g'), w) -> return ((x, w), g'))
  pass r = RandT (\ g ->
    pass (runRandT r g >>= \ ((x, f), g') -> return ((x, g'), f)))
-- | State operations are lifted from the underlying monad unchanged.
instance MonadState s m => MonadState s (RandT m) where
  get = lift get
  put s = lift (put s)
-- | IO actions are lifted through the underlying monad.
instance MonadIO m => MonadIO (RandT m) where
  liftIO io = lift (liftIO io)
-- | Choice between random computations, delegated to the underlying
-- 'Alternative'.
--
-- The generator is 'split', and each branch runs with its own half. This
-- instance is required because 'MonadPlus' has had 'Alternative' as a
-- superclass since base-4.8.
instance Alternative m => Alternative (RandT m) where
  empty = RandT (\ _ -> empty)
  a <|> b = RandT (\ g -> let (g', g'') = split g in
    (a `runRandT` g') <|> (b `runRandT` g''))

-- | Failure and choice delegated to the underlying 'MonadPlus'. As with
-- 'Alternative', each branch of 'mplus' gets its own half of a 'split'
-- generator.
instance MonadPlus m => MonadPlus (RandT m) where
  mzero = lift mzero
  mplus a b = RandT (\ g -> let (g', g'') = split g in
    (a `runRandT` g') `mplus` (b `runRandT` g''))
-- | Run a 'RandT' computation with the given generator, keeping only the
-- result and discarding the final generator.
evalRandT :: (RandomGen g, Monad m) => RandT m a -> g -> m a
evalRandT r g = do
  (x, _) <- runRandT r g
  return x
-- | Run a 'RandT' computation with the given generator, keeping only the
-- final generator and discarding the result.
execRandT :: (RandomGen g, Monad m) => RandT m a -> g -> m g
execRandT r g = do
  (_, g') <- runRandT r g
  return g'