{-# OPTIONS_GHC -Wall #-}

-- |
-- Module: System.Random.MWC.Probability
-- Copyright: (c) 2015 Jared Tobin
-- License: MIT
--
-- Maintainer: Jared Tobin
-- Stability: unstable
-- Portability: ghc
--
-- A probability monad based on sampling functions.
--
-- Probability distributions are abstract constructs that can be represented in
-- a variety of ways.  The sampling function representation is particularly
-- useful - it's computationally efficient, and collections of samples are
-- amenable to much practical work.
--
-- Probability monads propagate uncertainty under the hood.  An expression like
-- @'beta' 1 8 >>= 'binomial' 10@ corresponds to a beta-binomial
-- distribution in which the uncertainty captured by @'beta' 1 8@ has been
-- marginalized out.
--
-- The distribution resulting from a series of effects is called the
-- /predictive distribution/ of the model described by the corresponding
-- expression.  The monadic structure lets one piece together a hierarchical
-- structure from simpler, local conditionals:
--
-- > hierarchicalModel = do
-- >   c <- uniformR (1, 10)
-- >   d <- uniformR (1, 10)
-- >   e <- uniformR (1, 10)
-- >   f <- uniformR (1, 10)
-- >   a <- gamma c d
-- >   b <- gamma e f
-- >   p <- beta a b
-- >   n <- uniformR (5, 10)
-- >   binomial n p
--
-- (The hyperparameters are bound one at a time because 'Prob' has no
-- @MonadFail@ instance, so a refutable pattern such as
-- @[c, d, e, f] <- replicateM 4 ...@ would not typecheck.)
--
-- The functor instance for a probability monad transforms the support of the
-- distribution while leaving its density structure invariant in some sense.
-- For example, @'uniform'@ is a distribution over the 0-1 interval, but @fmap
-- (+ 1) uniform@ is the translated distribution over the 1-2 interval.
--
-- >>> sample (fmap (+ 1) uniform) gen
-- 1.5480073474340754

module System.Random.MWC.Probability (
    module MWC
  , Prob(..)
  , samples
  , uniform
  , uniformR
  , discreteUniform
  , categorical
  , standard
  , normal
  , logNormal
  , exponential
  , gamma
  , inverseGamma
  , chiSquare
  , beta
  , dirichlet
  , symmetricDirichlet
  , bernoulli
  , binomial
  , multinomial
  , student
  , isoGauss
  , poisson
  ) where

import Control.Applicative
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.Trans.Class
import Data.List (findIndex)
import System.Random.MWC as MWC hiding (uniform, uniformR)
import qualified System.Random.MWC as QMWC
import qualified System.Random.MWC.Distributions as MWC.Dist
import System.Random.MWC.CondensedTable

-- | A probability distribution characterized by a sampling function: given a
-- generator, 'sample' draws one value in the underlying monad.
--
-- >>> gen <- create
-- >>> sample uniform gen
-- 0.4208881170464097
newtype Prob m a = Prob { sample :: Gen (PrimState m) -> m a }

-- | Sample from a model @n@ times, threading the same generator through
-- every draw.
--
-- >>> samples 2 uniform gen
-- [0.6738707766845254,0.9730405951541817]
samples :: PrimMonad m => Int -> Prob m a -> Gen (PrimState m) -> m [a]
samples n model gen = replicateM n (sample model gen)

instance Monad m => Functor (Prob m) where
  fmap h (Prob f) = Prob $ liftM h . f

-- 'pure' is defined directly and 'return' is derived from it (the canonical
-- form), which avoids GHC's noncanonical-monad-instances warning.
instance Monad m => Applicative (Prob m) where
  pure x = Prob $ \_ -> return x
  (<*>) = ap

-- | Arithmetic on distributions combines independent draws pointwise, e.g.
-- @uniform + uniform@ samples two uniforms and adds them.
instance (Applicative m, Monad m, Num a) => Num (Prob m a) where
  (+)         = liftA2 (+)
  (-)         = liftA2 (-)
  (*)         = liftA2 (*)
  abs         = fmap abs
  signum      = fmap signum
  fromInteger = pure . fromInteger

instance Monad m => Monad (Prob m) where
  return = pure
  -- Both stages read from the same generator, in sequence.
  m >>= h = Prob $ \g -> do
    z <- sample m g
    sample (h z) g

instance MonadTrans Prob where
  lift m = Prob $ const m

-- | The uniform distribution (delegates to mwc-random's 'QMWC.uniform').
uniform :: (PrimMonad m, Variate a) => Prob m a
uniform = Prob QMWC.uniform

-- | The uniform distribution over the provided interval.
uniformR :: (PrimMonad m, Variate a) => (a, a) -> Prob m a
uniformR r = Prob $ QMWC.uniformR r

-- | The discrete uniform distribution: each element of the list is drawn with
-- equal probability.  Sampling from an empty list fails with an error.
discreteUniform :: PrimMonad m => [a] -> Prob m a
discreteUniform cs = do
  let n = length cs
  -- Checked inside the sampler (not by pattern-matching the argument) so that
  -- building the distribution stays lazy; only sampling an empty one fails.
  j <- if n == 0
         then error "System.Random.MWC.Probability.discreteUniform: empty list"
         else uniformR (0, n - 1)
  return $ cs !! j

-- | The standard normal distribution (a Gaussian with mean 0 and variance 1).
standard :: PrimMonad m => Prob m Double
standard = Prob MWC.Dist.standard

-- | The normal or Gaussian distribution with mean @m@ and standard deviation
-- @sd@.
normal :: PrimMonad m => Double -> Double -> Prob m Double
normal m sd = Prob $ MWC.Dist.normal m sd

-- | The log-normal distribution: @exp@ of a draw from @'normal' m sd@.
logNormal :: PrimMonad m => Double -> Double -> Prob m Double
logNormal m sd = exp <$> normal m sd

-- | The exponential distribution.  @r@ is forwarded unchanged to
-- mwc-random's 'MWC.Dist.exponential'.
exponential :: PrimMonad m => Double -> Prob m Double
exponential r = Prob $ MWC.Dist.exponential r

-- | The gamma distribution with shape @a@ and scale @b@ (mwc-random's
-- 'MWC.Dist.gamma' parameterisation).
gamma :: PrimMonad m => Double -> Double -> Prob m Double
gamma a b = Prob $ MWC.Dist.gamma a b

-- | The inverse-gamma distribution: the reciprocal of a @'gamma' a b@ draw.
inverseGamma :: PrimMonad m => Double -> Double -> Prob m Double
inverseGamma a b = recip <$> gamma a b

-- | The chi-square distribution with @k@ degrees of freedom.
chiSquare :: PrimMonad m => Int -> Prob m Double
chiSquare k = Prob $ MWC.Dist.chiSquare k

-- | The beta distribution, built from two unit-scale gamma draws
-- @u ~ gamma a 1@, @w ~ gamma b 1@ as @u / (u + w)@.
beta :: PrimMonad m => Double -> Double -> Prob m Double
beta a b = do
  u <- gamma a 1
  w <- gamma b 1
  return $ u / (u + w)

-- | The Dirichlet distribution with the given concentration parameters,
-- built by normalising one unit-scale gamma draw per parameter.
dirichlet :: PrimMonad m => [Double] -> Prob m [Double]
dirichlet as = do
  zs <- mapM (`gamma` 1) as
  return $ map (/ sum zs) zs

-- | The symmetric Dirichlet distribution (with equal concentration
-- parameters): @n@ components, each with concentration @a@.
symmetricDirichlet :: PrimMonad m => Int -> Double -> Prob m [Double]
symmetricDirichlet n a = dirichlet (replicate n a)

-- | The Bernoulli distribution: 'True' when a uniform draw falls below @p@.
bernoulli :: PrimMonad m => Double -> Prob m Bool
bernoulli p = (< p) <$> uniform

-- | The binomial distribution: the number of successes in @n@ independent
-- @'bernoulli' p@ trials.
binomial :: PrimMonad m => Int -> Double -> Prob m Int
binomial n p = liftM (length . filter id) $ replicateM n (bernoulli p)

-- | The multinomial distribution.
--
-- Draws @n@ independent, 0-based category indices according to the supplied
-- weights.  The weights need not sum to one: the draw is scaled by their
-- total, so @[1, 1, 2]@ and @[0.25, 0.25, 0.5]@ describe the same
-- distribution.  Weights are assumed to be non-negative.  Sampling fails with
-- an error if the weight list is empty (and @n > 0@) or no category can be
-- selected (e.g. NaN weights).
multinomial :: PrimMonad m => Int -> [Double] -> Prob m [Int]
multinomial n ps = replicateM n $ do
    z <- uniform
    -- mwc-random's 'uniform' for 'Double' lies in (0, 1], so 'target' lies in
    -- (0, total].  Using '>=' (rather than '>') guarantees the last bucket,
    -- whose cumulative weight is exactly 'total', is always reachable; the
    -- previous '>' comparison crashed when z == 1 or when the weights summed
    -- to slightly less than 1 through rounding.
    let target = z * total
    case findIndex (>= target) cumulative of
      Just g  -> return g
      Nothing -> error
        "System.Random.MWC.Probability.multinomial: invalid probability vector"
  where
    -- Shared across all n draws.
    cumulative = scanl1 (+) ps
    total
      | null cumulative = 0
      | otherwise       = last cumulative

-- | Student's t distribution, sampled as a scale mixture of normals: a
-- standard deviation is drawn as @sqrt@ of an @'inverseGamma' (k / 2)
-- (s * 2 / k)@ draw and then a @'normal' m sd@ value is drawn.
--
-- NOTE(review): since 'gamma' takes a scale parameter, this makes @m@ the
-- location, @k@ the degrees of freedom and @s@ behave like a precision
-- (scale = @1 / sqrt s@) - confirm this is the intended parameterisation.
student :: PrimMonad m => Double -> Double -> Double -> Prob m Double
student m s k = do
  sd <- sqrt <$> inverseGamma (k / 2) (s * 2 / k)
  normal m sd

-- | An isotropic or spherical Gaussian distribution: one independent normal
-- draw per mean in @ms@, all sharing the standard deviation @sd@.
isoGauss :: PrimMonad m => [Double] -> Double -> Prob m [Double]
isoGauss ms sd = mapM (`normal` sd) ms

-- | The Poisson distribution with rate @l@, sampled from a condensed lookup
-- table.  The table is built once per @'poisson' l@ value and reused by every
-- sample drawn from it.
poisson :: PrimMonad m => Double -> Prob m Int
poisson l = Prob $ genFromTable table where
  table = tablePoisson l

-- | A categorical distribution defined by the supplied list of probabilities:
-- a single 0-based index drawn via @'multinomial' 1@.
categorical :: PrimMonad m => [Double] -> Prob m Int
categorical ps = do
  xs <- multinomial 1 ps
  case xs of
    [x] -> return x
    _   -> error "categorical: invalid return value"