module System.Random.MWC.Probability (
module MWC
, Prob(..)
, samples
, uniform
, uniformR
, discreteUniform
, categorical
, standard
, normal
, logNormal
, exponential
, gamma
, inverseGamma
, chiSquare
, beta
, dirichlet
, symmetricDirichlet
, bernoulli
, binomial
, multinomial
, student
, isoGauss
, poisson
) where
import Control.Applicative
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import qualified Data.Foldable as F
import Data.List (findIndex)
import System.Random.MWC as MWC hiding (uniform, uniformR)
import qualified System.Random.MWC as QMWC
import qualified System.Random.MWC.Distributions as MWC.Dist
import System.Random.MWC.CondensedTable
-- | A probability distribution over values of type @a@, represented by its
--   sampling function: given a generator whose state token matches the
--   monad @m@, 'sample' produces an action in @m@ yielding one value.
newtype Prob m a = Prob
  { sample :: Gen (PrimState m) -> m a
  }
-- | Run the sampling function of a distribution @n@ times against the same
--   generator and collect the results, in order, in a list.
samples :: PrimMonad m => Int -> Prob m a -> Gen (PrimState m) -> m [a]
samples n model gen = replicateM n draw
  where
    -- One draw from the model using the caller's generator.
    draw = sample model gen
-- | Mapping over a distribution applies the function to the value produced
--   by each draw; the generator is threaded through unchanged.
instance Monad m => Functor (Prob m) where
  fmap h (Prob f) = Prob $ \gen -> liftM h (f gen)
-- | Applicative structure derived from the 'Monad' instance: 'pure' is
--   'return' and '<*>' is 'ap', so effects run left to right.
instance Monad m => Applicative (Prob m) where
  pure = return
  pf <*> px = pf `ap` px
-- | Numeric operations lifted pointwise over distributions.
--
--   Binary operators draw from the left operand first and then from the
--   right operand (via 'liftA2'), combining the two sampled values; unary
--   operations are mapped over a single draw, and integer literals become
--   constant distributions.
instance (Applicative m, Monad m, Num a) => Num (Prob m a) where
  (+)         = liftA2 (+)
  -- Subtraction must be defined explicitly: 'Num' requires either '(-)' or
  -- 'negate', and no 'negate' is given here.
  (-)         = liftA2 (-)
  (*)         = liftA2 (*)
  abs         = fmap abs
  signum      = fmap signum
  fromInteger = pure . fromInteger
-- | Sequencing of distributions: the same generator is handed to the first
--   distribution and then to the distribution chosen from its result.
instance Monad m => Monad (Prob m) where
  -- A constant distribution: ignores the generator entirely.
  return x = Prob $ \_ -> return x

  m >>= h = Prob $ \g ->
    sample m g >>= \z -> sample (h z) g
-- | Lift an action of the underlying monad into 'Prob'; the lifted action
--   never touches the generator.
instance MonadTrans Prob where
  lift m = Prob (\_ -> m)
-- | Run an 'IO' action inside 'Prob' by lifting it into the underlying
--   monad; the generator is ignored.
instance MonadIO m => MonadIO (Prob m) where
  liftIO io = Prob $ \_ -> liftIO io
-- | 'Prob' over a primitive monad is itself a primitive monad sharing the
--   same state token; primitive actions are lifted from the underlying monad.
--
--   NOTE(review): the associated type instance below needs the
--   @TypeFamilies@ extension; no LANGUAGE pragma is visible in this part of
--   the file — confirm it is enabled.
instance PrimMonad m => PrimMonad (Prob m) where
  type PrimState (Prob m) = PrimState m
  primitive f = lift (primitive f)
-- | Uniform distribution over a 'Variate' type. Delegates to
--   "System.Random.MWC"'s @uniform@, so the range is whatever the 'Variate'
--   instance for @a@ defines.
uniform :: (PrimMonad m, Variate a) => Prob m a
uniform = Prob $ \gen -> QMWC.uniform gen
-- | Uniform distribution over the supplied pair of bounds. Delegates to
--   "System.Random.MWC"'s @uniformR@, which determines how the bounds are
--   interpreted for each 'Variate' instance.
uniformR :: (PrimMonad m, Variate a) => (a, a) -> Prob m a
uniformR bounds = Prob (QMWC.uniformR bounds)
-- | Pick one element of a container uniformly at random.
--
--   The container is flattened to a list once; an index is drawn with
--   'uniformR' from @(0, length - 1)@ and the element at that position is
--   returned.
--
--   Calls 'error' if the container is empty, since there is nothing to
--   choose from.
discreteUniform :: (PrimMonad m, Foldable f) => f a -> Prob m a
discreteUniform cs
  | null xs   = error "discreteUniform: empty container"
  | otherwise = do
      -- Upper bound is the last valid index, hence the "- 1".
      j <- uniformR (0, length xs - 1)
      return $ xs !! j
  where
    xs = F.toList cs
-- | The standard normal distribution, as sampled by
--   "System.Random.MWC.Distributions"' @standard@.
standard :: PrimMonad m => Prob m Double
standard = Prob $ \gen -> MWC.Dist.standard gen
-- | A normal distribution. Both arguments are passed, in order, to
--   "System.Random.MWC.Distributions"' @normal@.
normal :: PrimMonad m => Double -> Double -> Prob m Double
normal m sd = Prob (MWC.Dist.normal m sd)
-- | A log-normal distribution: the exponential of a draw from
--   @'normal' m sd@.
logNormal :: PrimMonad m => Double -> Double -> Prob m Double
logNormal m sd = fmap exp (normal m sd)
-- | An exponential distribution. The argument is passed straight to
--   "System.Random.MWC.Distributions"' @exponential@.
exponential :: PrimMonad m => Double -> Prob m Double
exponential r = Prob (MWC.Dist.exponential r)
-- | A gamma distribution. Both arguments are passed, in order, to
--   "System.Random.MWC.Distributions"' @gamma@.
gamma :: PrimMonad m => Double -> Double -> Prob m Double
gamma a b = Prob (MWC.Dist.gamma a b)
-- | An inverse-gamma distribution: the reciprocal of a draw from
--   @'gamma' a b@.
inverseGamma :: PrimMonad m => Double -> Double -> Prob m Double
inverseGamma a b = fmap recip (gamma a b)
-- | A chi-square distribution. The argument is passed straight to
--   "System.Random.MWC.Distributions"' @chiSquare@.
chiSquare :: PrimMonad m => Int -> Prob m Double
chiSquare k = Prob (MWC.Dist.chiSquare k)
-- | A beta distribution, built from two gamma draws: with
--   @u ~ gamma a 1@ and @w ~ gamma b 1@ (drawn in that order), the result
--   is @u / (u + w)@.
beta :: PrimMonad m => Double -> Double -> Prob m Double
beta a b = ratio <$> gamma a 1 <*> gamma b 1
  where
    ratio u w = u / (u + w)
-- | A Dirichlet distribution: one @gamma a 1@ draw per element of the
--   input (in container order), each divided by the sum of all the draws.
dirichlet
  :: (Foldable f, PrimMonad m) => f Double -> Prob m [Double]
dirichlet as = do
  draws <- mapM (\a -> gamma a 1) (F.toList as)
  let total = sum draws
  return [z / total | z <- draws]
-- | A Dirichlet distribution of dimension @n@ in which every component
--   uses the same parameter @a@.
symmetricDirichlet :: PrimMonad m => Int -> Double -> Prob m [Double]
symmetricDirichlet n a = dirichlet concentrations
  where
    concentrations = replicate n a
-- | A Bernoulli distribution: 'True' exactly when a 'uniform' 'Double' draw
--   is strictly less than @p@.
bernoulli :: PrimMonad m => Double -> Prob m Bool
bernoulli p = do
  u <- uniform
  return (u < p)
-- | A binomial distribution: the number of 'True' outcomes among @n@
--   independent @'bernoulli' p@ draws.
binomial :: PrimMonad m => Int -> Double -> Prob m Int
binomial n p = liftM countSuccesses trials
  where
    trials = replicateM n (bernoulli p)
    countSuccesses = length . filter id
-- | Draw @n@ category indices (zero-based) according to the weights @ps@.
--
--   Each draw takes a 'uniform' value @z@ and returns the index of the first
--   running total of @ps@ that is strictly greater than @z@. The result is
--   the list of the @n@ drawn indices, not a vector of per-category counts.
--
--   The weights are not validated or normalised; they are assumed to sum
--   to one. A draw that is not below any running total — e.g. when
--   floating-point rounding leaves the final total slightly under one — is
--   assigned to the last category instead of failing a pattern match. Note
--   that weights summing well below one will therefore silently favour the
--   last category.
--
--   Calls 'error' when a draw is demanded from an empty weight vector.
multinomial :: (Foldable f, PrimMonad m) => Int -> f Double -> Prob m [Int]
multinomial n ps = replicateM n drawIndex
  where
    cumulative = scanl1 (+) (F.toList ps)
    lastIndex  = length cumulative - 1

    drawIndex = do
      z <- uniform
      return (pick z)

    -- Total replacement for the former partial @let Just g = ...@ binding.
    pick z = case findIndex (> z) cumulative of
      Just g -> g
      Nothing
        | null cumulative -> error "multinomial: empty probability vector"
        | otherwise       -> lastIndex
-- | A Student-t style distribution built as a scale mixture: a value is
--   drawn from @'inverseGamma' (k / 2) (s * 2 / k)@, its square root is
--   taken, and that is used as the second argument of @'normal' m@.
student :: PrimMonad m => Double -> Double -> Double -> Prob m Double
student m s k =
  inverseGamma (k / 2) (s * 2 / k) >>= \v -> normal m (sqrt v)
-- | An isotropic Gaussian: one independent @'normal' mu sd@ draw for each
--   element @mu@ of the input, all sharing the same @sd@, returned in
--   container order.
isoGauss :: (Foldable f, PrimMonad m) => f Double -> Double -> Prob m [Double]
isoGauss ms sd = mapM (\mu -> normal mu sd) (F.toList ms)
-- | A Poisson distribution, sampled by building a condensed lookup table
--   with 'tablePoisson' for the given parameter and drawing from it with
--   'genFromTable'.
poisson :: PrimMonad m => Double -> Prob m Int
poisson l = Prob (genFromTable (tablePoisson l))
-- | A categorical distribution: a single 'multinomial' draw, returning the
--   one category index it produces.
--
--   Calls 'error' if 'multinomial' does not return exactly one element.
categorical :: (Foldable f, PrimMonad m) => f Double -> Prob m Int
categorical ps = multinomial 1 ps >>= single
  where
    single [x] = return x
    single _   = error "categorical: invalid return value"