module Control.Monad.MC.GSLBase (
MC(..),
runMC,
evalMC,
execMC,
unsafeInterleaveMC,
MCT(..),
runMCT,
evalMCT,
execMCT,
unsafeInterleaveMCT,
liftMCT,
RNG,
Seed,
mt19937,
mt19937WithState,
rngName,
rngSize,
rngState,
getRNG,
setRNG,
uniform,
uniformInt,
normal,
exponential,
levy,
levySkew,
poisson,
cauchy,
beta,
logistic,
pareto,
weibull,
gamma,
multinomial,
dirichlet,
) where
import Control.Applicative ( Applicative(..), (<$>) )
import Control.Monad ( ap, liftM, MonadPlus(..) )
import Control.Monad.Cont ( MonadCont(..) )
import Control.Monad.Error ( MonadError(..) )
import Control.Monad.Reader ( MonadReader(..) )
import Control.Monad.State ( MonadState(..) )
import Control.Monad.Writer ( MonadWriter(..) )
import Control.Monad.Trans ( MonadTrans(..), MonadIO(..) )
import Data.Word
import System.IO.Unsafe ( unsafePerformIO, unsafeInterleaveIO )
import qualified Data.Vector.Storable as VS
import qualified GSL.Random.Gen as GSL
import GSL.Random.Dist
-- | A Monte Carlo computation: an IO action parameterised by a mutable GSL
-- generator. Actions run against the generator they are given and may modify
-- it (for example 'setRNG' copies new state into it). Use 'runMC', 'evalMC'
-- or 'execMC' to run one purely.
newtype MC a = MC (GSL.RNG -> IO a)
-- | Run a Monte Carlo computation purely.
--
-- The supplied generator is cloned first, so the caller's 'RNG' is never
-- modified. The computation runs against the clone, and the clone (in its
-- final state) is returned alongside the result.
runMC :: MC a -> RNG -> (a, RNG)
runMC (MC action) (RNG seedGen) = unsafePerformIO $
    GSL.cloneRNG seedGen >>= \gen ->
    action gen >>= \result ->
    return (result, RNG gen)
-- | Run a Monte Carlo computation and keep only its result
-- (the first component of 'runMC').
evalMC :: MC a -> RNG -> a
evalMC g = fst . runMC g
-- | Run a Monte Carlo computation and keep only the final generator
-- (the second component of 'runMC').
execMC :: MC a -> RNG -> RNG
execMC g = snd . runMC g
-- | Defer an 'MC' action with 'unsafeInterleaveIO'. Its effect on the shared
-- generator then happens only when the result is demanded, so the order of
-- draws depends on evaluation order.
unsafeInterleaveMC :: MC a -> MC a
unsafeInterleaveMC (MC m) = MC (unsafeInterleaveIO . m)
-- | Map over the result of the underlying IO action; the generator is
-- passed through unchanged.
instance Functor MC where
    fmap f (MC m) = MC (fmap f . m)
-- | Sequencing threads the same generator through both steps, so the second
-- step sees whatever state the first left behind.
instance Monad MC where
    return x = MC (const (return x))

    MC step >>= next = MC $ \gen -> do
        x <- step gen
        let MC step' = next x
        step' gen

    -- Failure is the failure of the underlying IO action.
    fail msg = MC (const (fail msg))
-- | Applicative structure derived from the 'Monad' instance: effects run
-- left to right on the shared generator.
instance Applicative MC where
    pure x = MC (const (return x))

    mf <*> mx = do
        f <- mf
        x <- mx
        return (f x)
-- | Monte Carlo transformer over an inner monad @m@.
--
-- Given the mutable GSL generator, the IO part runs and yields an action in
-- the inner monad. Later binds run their own IO parts with 'unsafePerformIO'
-- inside the inner monad's bind (see the 'Monad' instance).
newtype MCT m a = MCT (GSL.RNG -> IO (m a))
-- | Run an 'MCT' computation against a clone of the supplied generator,
-- leaving the caller's 'RNG' unmodified. Each inner result is paired with the
-- cloned generator.
runMCT :: (Monad m) => MCT m a -> RNG -> m (a,RNG)
runMCT (MCT action) (RNG seedGen) = unsafePerformIO $ do
    gen   <- GSL.cloneRNG seedGen
    inner <- action gen
    return (liftM (\x -> (x, RNG gen)) inner)
-- | Like 'runMCT' but keeps only the result. The pair is projected lazily.
evalMCT :: (Monad m) => MCT m a -> RNG -> m a
evalMCT g r = liftM fst (runMCT g r)
-- | Like 'runMCT' but keeps only the final generator. The pair is projected
-- lazily.
execMCT :: (Monad m) => MCT m a -> RNG -> m RNG
execMCT g r = liftM snd (runMCT g r)
-- | Embed a plain 'MC' computation into 'MCT'. The IO part runs against the
-- generator, and its result is wrapped with the inner monad's 'return'.
liftMCT :: (Monad m) => MC a -> MCT m a
liftMCT (MC g) = MCT (liftM return . g)
-- | Defer the IO part of an 'MCT' action with 'unsafeInterleaveIO'. Its
-- generator effects happen only when the inner action is demanded.
unsafeInterleaveMCT :: (Monad m) => MCT m a -> MCT m a
unsafeInterleaveMCT (MCT g) = MCT (unsafeInterleaveIO . g)
-- | Map over the result inside the inner monad, after the IO part has run.
instance (Monad m) => Functor (MCT m) where
    fmap f (MCT g) = MCT (liftM (liftM f) . g)
-- | Sequencing for 'MCT'.
--
-- For @g >>= k@, the IO part of @g@ runs against the generator and yields an
-- inner action @ma@. The continuation @k@ is attached with the inner monad's
-- '>>='. Its IO part is executed with 'unsafePerformIO', on the same generator,
-- when the inner bind evaluates the continuation.
instance (Monad m) => Monad (MCT m) where
    return a = MCT $ \_ -> return (return a)

    (MCT g) >>= k =
        MCT $ \r -> do
            ma <- g r
            return $ ma >>= \a ->
                let (MCT m') = k a
                in unsafePerformIO $ m' r

    -- Delegate failure to the inner monad, equivalent to @lift . fail@ and
    -- matching how 'throwError' is lifted. With IO's 'fail' here, e.g. a failed
    -- pattern match in @MCT Maybe@ would raise an IO exception while the
    -- computation is being run instead of producing 'Nothing'.
    fail str = MCT $ \_ -> return (fail str)
-- | Applicative structure derived from the 'Monad' instance.
instance (Monad m) => Applicative (MCT m) where
    pure x = MCT (const (return (return x)))

    mf <*> mx = do
        f <- mf
        x <- mx
        return (f x)
-- | Choice for 'MCT', delegating to the inner monad's 'MonadPlus'.
instance (MonadPlus m) => MonadPlus (MCT m) where
    -- The inner monad's zero, wrapped without touching the generator. This
    -- must be @return mzero@. Using IO's 'mzero' here would throw an IO
    -- exception when the computation runs, so @runMCT mzero@ and
    -- @mzero `mplus` x@ would crash instead of yielding the inner zero.
    mzero = MCT $ \_ -> return mzero

    -- Both branches' IO parts run eagerly. The left branch uses the live
    -- generator, and the right uses a clone taken before the left runs, so
    -- both branches start from the same generator state. The inner results are
    -- combined with the inner 'mplus'.
    (MCT m) `mplus` (MCT n) =
        MCT $ \r -> do
            r' <- GSL.cloneRNG r
            mr <- m r
            nr <- n r'
            return (mr `mplus` nr)
-- | Lifting an inner action ignores the generator and performs no IO.
instance MonadTrans MCT where
    lift = MCT . const . return
-- | 'callCC' for 'MCT', delegating to the inner monad's 'callCC'.
--
-- The inner escape continuation @k@ is handed to @f@ wrapped as an 'MCT'
-- action that ignores the generator. The body's IO part is executed with
-- 'unsafePerformIO', against the same generator @r@, when the inner 'callCC'
-- evaluates it.
-- NOTE(review): generator effects are therefore ordered by the inner monad's
-- evaluation, not by the outer IO sequence. Confirm this is acceptable.
instance (MonadCont m) => MonadCont (MCT m) where
    callCC f = MCT $ \r ->
        return $ callCC $ \k ->
            let (MCT m) = f (\a -> MCT $ \_ -> return (k a))
            in unsafePerformIO (m r)
-- | Error handling for 'MCT', delegating to the inner monad.
--
-- 'throwError' lifts the inner monad's 'throwError'. 'catchError' first runs
-- the IO part of the protected action, then installs the handler with the
-- inner monad's 'catchError'. Exceptions raised in IO while running @g r@ are
-- therefore not caught here. The handler's IO part runs via 'unsafePerformIO'
-- on the same generator @r@, only when the inner monad invokes it.
instance (MonadError e m) => MonadError e (MCT m) where
    throwError = lift . throwError
    (MCT g) `catchError` h = MCT $ \r -> do
        ma <- g r
        return $ ma `catchError` \e ->
            let (MCT m') = h e
            in unsafePerformIO (m' r)
-- | IO is lifted through the inner monad's 'liftIO', not through the
-- generator-threading IO layer.
instance (MonadIO m) => MonadIO (MCT m) where
    liftIO io = lift (liftIO io)
-- | Reader operations act on the inner monad. 'local' modifies the environment
-- of the inner action produced after the IO part has run.
instance (MonadReader r m) => MonadReader r (MCT m) where
    ask = lift ask
    local f (MCT g) = MCT (liftM (local f) . g)
-- | State operations are lifted from the inner monad.
instance (MonadState s m) => MonadState s (MCT m) where
    get = lift get
    put s = lift (put s)
-- | Writer operations act on the inner monad. 'listen' and 'pass' wrap the
-- inner action produced after the IO part has run.
instance (MonadWriter w m) => MonadWriter w (MCT m) where
    tell w = lift (tell w)
    listen (MCT g) = MCT (liftM listen . g)
    pass (MCT g) = MCT (liftM pass . g)
-- | Wrapper around a GSL generator handle. The functions in this module clone
-- the handle before running computations ('runMC', 'runMCT', 'getRNG'), so a
-- given 'RNG' value is not mutated by them.
newtype RNG = RNG GSL.RNG
-- | Seed value passed to @GSL.setSeed@ by 'mt19937'.
type Seed = Word64
-- | Name of the generator algorithm, as reported by @GSL.getName@.
rngName :: RNG -> String
rngName (RNG gen) = unsafePerformIO (GSL.getName gen)
-- | Size of the generator state as reported by @GSL.getSize@, converted to
-- 'Int'. NOTE(review): presumably a byte count; confirm against the
-- gsl-random docs.
rngSize :: RNG -> Int
rngSize (RNG gen) = unsafePerformIO (liftM fromIntegral (GSL.getSize gen))
-- | Snapshot of the generator state as raw bytes (via @GSL.getState@). It can
-- be fed back to 'mt19937WithState' for an MT19937 generator.
rngState :: RNG -> [Word8]
rngState (RNG gen) = unsafePerformIO (GSL.getState gen)
-- | Capture the current generator as an independent 'RNG' (a clone), so later
-- draws in the computation do not affect it.
getRNG :: MC RNG
getRNG = MC (fmap RNG . GSL.cloneRNG)
-- | Replace the computation's generator state by copying from the given
-- 'RNG' with @GSL.copyRNG@. The source is left untouched.
setRNG :: RNG -> MC ()
setRNG (RNG src) = MC (\dst -> GSL.copyRNG dst src)
-- | A fresh Mersenne Twister (MT19937) generator seeded with the given value.
mt19937 :: Seed -> RNG
mt19937 s = unsafePerformIO $
    GSL.newRNG GSL.mt19937 >>= \gen ->
    GSL.setSeed gen s >>
    return (RNG gen)
-- | A fresh MT19937 generator restored from a raw state snapshot, as produced
-- by 'rngState'.
mt19937WithState :: [Word8] -> RNG
mt19937WithState xs = unsafePerformIO $
    GSL.newRNG GSL.mt19937 >>= \gen ->
    GSL.setState gen xs >>
    return (RNG gen)
-- | @uniform a b@ draws a uniformly distributed 'Double' between @a@ and @b@.
--
-- The range @0, 1@ uses the generator's own uniform draw (@GSL.getUniform@).
-- Every other range goes through @getFlat a b@.
-- NOTE(review): endpoint inclusion is whatever those GSL routines provide;
-- confirm against the gsl-random docs.
uniform :: Double -> Double -> MC Double
uniform a b
    | a == 0 && b == 1 = liftRan0 GSL.getUniform
    | otherwise        = liftRan2 getFlat a b
-- | Draw an integer with @GSL.getUniformInt@ using the given bound.
-- NOTE(review): presumably in @[0, n-1]@; confirm against the gsl-random docs.
uniformInt :: Int -> MC Int
uniformInt n = liftRan1 GSL.getUniformInt n
-- | @normal mu sigma@ draws from a normal distribution with location @mu@ and
-- scale @sigma@.
--
-- When @sigma == 1@ the unit Gaussian sampler is used. Otherwise the scaled
-- Gaussian sampler is called with @sigma@. Both use the ratio method. The
-- location is added only when @mu /= 0@.
normal :: Double -> Double -> MC Double
normal mu sigma
    | mu == 0   = centred
    | otherwise = (mu +) <$> centred
  where
    centred
        | sigma == 1 = liftRan0 getUGaussianRatioMethod
        | otherwise  = liftRan1 getGaussianRatioMethod sigma
-- | Exponential draw via GSL's @getExponential@.
-- NOTE(review): parameter meaning (mean vs. rate) follows GSL; confirm.
exponential :: Double -> MC Double
exponential a = liftRan1 getExponential a

-- | Poisson draw via GSL's @getPoisson@.
poisson :: Double -> MC Int
poisson a = liftRan1 getPoisson a

-- | Levy draw via GSL's @getLevy@ (two parameters, passed in order).
levy :: Double -> Double -> MC Double
levy a b = liftRan2 getLevy a b

-- | Skewed Levy draw via GSL's @getLevySkew@ (three parameters, in order).
levySkew :: Double -> Double -> Double -> MC Double
levySkew a b c = liftRan3 getLevySkew a b c

-- | Cauchy draw via GSL's @getCauchy@.
cauchy :: Double -> MC Double
cauchy a = liftRan1 getCauchy a

-- | Beta draw via GSL's @getBeta@ (two parameters, in order).
beta :: Double -> Double -> MC Double
beta a b = liftRan2 getBeta a b

-- | Logistic draw via GSL's @getLogistic@.
logistic :: Double -> MC Double
logistic a = liftRan1 getLogistic a

-- | Pareto draw via GSL's @getPareto@ (two parameters, in order).
pareto :: Double -> Double -> MC Double
pareto a b = liftRan2 getPareto a b

-- | Weibull draw via GSL's @getWeibull@ (two parameters, in order).
weibull :: Double -> Double -> MC Double
weibull a b = liftRan2 getWeibull a b

-- | Gamma draw via GSL's @getGamma@ (two parameters, in order).
gamma :: Double -> Double -> MC Double
gamma a b = liftRan2 getGamma a b

-- | Multinomial draw via GSL's @getMultinomial@, given a trial count and a
-- vector of weights.
multinomial :: Int -> VS.Vector Double -> MC (VS.Vector Int)
multinomial n ps = liftRan2 getMultinomial n ps

-- | Dirichlet draw via GSL's @getDirichlet@, given a parameter vector.
dirichlet :: VS.Vector Double -> MC (VS.Vector Double)
dirichlet alphas = liftRan1 getDirichlet alphas
-- | Wrap a GSL sampler taking only the generator as an 'MC' action.
liftRan0 :: (GSL.RNG -> IO a) -> MC a
liftRan0 f = MC f

-- | Wrap a one-parameter GSL sampler; the parameter follows the generator.
liftRan1 :: (GSL.RNG -> a -> IO b) -> a -> MC b
liftRan1 f a = MC (flip f a)

-- | Wrap a two-parameter GSL sampler; parameters follow the generator.
liftRan2 :: (GSL.RNG -> a -> b -> IO c) -> a -> b -> MC c
liftRan2 f a b = MC (\gen -> f gen a b)

-- | Wrap a three-parameter GSL sampler; parameters follow the generator.
liftRan3 :: (GSL.RNG -> a -> b -> c -> IO d) -> a -> b -> c -> MC d
liftRan3 f a b c = MC (\gen -> f gen a b c)