{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE CPP #-}
{-# OPTIONS_GHC -Wall #-}
module System.Random.MWC.Probability (
module MWC
, Prob(..)
, samples
, uniform
, uniformR
, normal
, standardNormal
, isoNormal
, logNormal
, exponential
, inverseGaussian
, laplace
, gamma
, inverseGamma
, normalGamma
, weibull
, chiSquare
, beta
, gstudent
, student
, pareto
, dirichlet
, symmetricDirichlet
, discreteUniform
, zipf
, categorical
, discrete
, bernoulli
, binomial
, negativeBinomial
, multinomial
, poisson
, crp
) where
import Control.Applicative
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Data.Monoid (Sum(..))
#if __GLASGOW_HASKELL__ < 710
import Data.Foldable (Foldable)
#endif
import qualified Data.Foldable as F
import Data.List (findIndex)
import qualified Data.IntMap as IM
import System.Random.MWC as MWC hiding (uniform, uniformR)
import qualified System.Random.MWC as QMWC
import qualified System.Random.MWC.Distributions as MWC.Dist
import System.Random.MWC.CondensedTable
-- | A probabilistic computation: given a mutable MWC generator living in the
--   state thread of @m@, produce a value of type @a@ in @m@.  Run it with
--   'sample' (or 'samples' for repeated draws).
newtype Prob m a = Prob { sample :: Gen (PrimState m) -> m a }
-- | Run @model@ @n@ times against the same generator and collect the draws
--   in the order they were produced.  A non-positive @n@ yields @[]@.
samples :: PrimMonad m => Int -> Prob m a -> Gen (PrimState m) -> m [a]
samples n model gen = replicateM n (sample model gen)
{-# INLINABLE samples #-}
-- | Mapping over a model post-processes each draw; the generator is passed
--   through untouched.
instance Functor m => Functor (Prob m) where
  fmap h model = Prob $ \gen -> h <$> sample model gen
-- | 'pure' ignores the generator; '<*>' is the monadic 'ap', so the function
--   model is sampled before the argument model.
instance Monad m => Applicative (Prob m) where
  pure x = Prob (\_ -> pure x)
  mf <*> mx = ap mf mx
-- | Sequencing threads one generator through both steps: the first model is
--   sampled, and its result selects the second model, sampled with the same
--   generator.
instance Monad m => Monad (Prob m) where
  return = pure
  m >>= h = Prob (\g -> sample m g >>= \z -> sample (h z) g)
  {-# INLINABLE (>>=) #-}
-- | Pointwise arithmetic on models.  Binary operators sample the left
--   operand first, then the right one; literals are constant models.
instance (Monad m, Num a) => Num (Prob m a) where
  x + y = (+) <$> x <*> y
  x - y = (-) <$> x <*> y
  x * y = (*) <$> x <*> y
  abs x = abs <$> x
  signum x = signum <$> x
  fromInteger n = pure (fromInteger n)
-- | Lifting an action yields a model that ignores the generator.
instance MonadTrans Prob where
  lift = Prob . const
-- | IO actions run in the base monad; the generator is not consulted.
instance MonadIO m => MonadIO (Prob m) where
  liftIO = lift . liftIO
-- | 'Prob m' shares the state token of @m@, and primitive state operations
--   are delegated to @m@ without touching the generator.
instance PrimMonad m => PrimMonad (Prob m) where
  type PrimState (Prob m) = PrimState m
  primitive f = Prob (\_ -> primitive f)
  {-# INLINE primitive #-}
-- | A uniformly distributed value, using mwc-random's default range for the
--   type (see 'QMWC.uniform').
uniform :: (PrimMonad m, Variate a) => Prob m a
uniform = Prob (\gen -> QMWC.uniform gen)
{-# INLINABLE uniform #-}
-- | A uniformly distributed value within the given bounds, delegating to
--   'QMWC.uniformR'.
uniformR :: (PrimMonad m, Variate a) => (a, a) -> Prob m a
uniformR = Prob . QMWC.uniformR
{-# INLINABLE uniformR #-}
-- | Pick one element of a container uniformly at random.
--
--   The container is converted to a list once; an index in
--   @[0, length - 1]@ is drawn and that element returned.  An empty
--   container has nothing to pick, so sampling it raises an explicit error
--   instead of failing inside '!!' with an unrelated index error.
discreteUniform :: (PrimMonad m, Foldable f) => f a -> Prob m a
discreteUniform cs
  | null xs = error "mwc-probability: discreteUniform called on an empty container"
  | otherwise = do
      j <- uniformR (0, length xs - 1)
      return (xs !! j)
  where
    xs = F.toList cs
{-# INLINABLE discreteUniform #-}
-- | A normal variate with mean 0 and standard deviation 1, via
--   'MWC.Dist.standard'.
standardNormal :: PrimMonad m => Prob m Double
standardNormal = Prob $ \gen -> MWC.Dist.standard gen
{-# INLINABLE standardNormal #-}
-- | A normal variate with the given mean and standard deviation, via
--   'MWC.Dist.normal'.
normal :: PrimMonad m => Double -> Double -> Prob m Double
normal mu sigma = Prob (MWC.Dist.normal mu sigma)
{-# INLINABLE normal #-}
-- | A log-normal variate: the exponential of a 'normal' draw with the given
--   mean and standard deviation (on the log scale).
logNormal :: PrimMonad m => Double -> Double -> Prob m Double
logNormal mu sigma = fmap exp (normal mu sigma)
{-# INLINABLE logNormal #-}
-- | An exponential variate, delegating to 'MWC.Dist.exponential' with the
--   given parameter.
exponential :: PrimMonad m => Double -> Prob m Double
exponential = Prob . MWC.Dist.exponential
{-# INLINABLE exponential #-}
-- | A Laplace (double-exponential) variate with location @mu@.
--
--   The Laplace scale is taken to be @sigma / sqrt 2@, so @sigma@ is the
--   standard deviation of the result.  A uniform draw on @(-0.5, 0.5)@ is
--   pushed through the inverse CDF.
laplace :: (Floating a, Variate a, PrimMonad m) => a -> a -> Prob m a
laplace mu sigma = invCdf <$> uniformR (-0.5, 0.5)
  where
    scale = sigma / sqrt 2
    invCdf u = mu - scale * signum u * log (1 - 2 * abs u)
{-# INLINABLE laplace #-}
-- | A Weibull variate.
--
--   A uniform draw @x@ is transformed as @(-log (1 - x) / a) ** (1 / b)@.
--   The survival function is therefore @exp (-a * t ** b)@, so @b@ acts as
--   the shape parameter and @a ** (-1 / b)@ as the scale.
--   NOTE(review): confirm this (rate-like @a@, shape @b@) parameterisation is
--   the one callers expect.
weibull :: (Floating a, Variate a, PrimMonad m) => a -> a -> Prob m a
weibull a b = do
  x <- uniform
  -- The exponent must be parenthesised: '**' (infixr 8) binds tighter than
  -- '/' (infixl 7), so "... ** 1/b" would divide by b instead of taking
  -- the b-th root.
  return $ (- 1 / a * log (1 - x)) ** (1 / b)
{-# INLINABLE weibull #-}
-- | A gamma variate, delegating to 'MWC.Dist.gamma'.  mwc-random documents
--   the two arguments as shape and scale.
gamma :: PrimMonad m => Double -> Double -> Prob m Double
gamma shape scale = Prob (MWC.Dist.gamma shape scale)
{-# INLINABLE gamma #-}
-- | An inverse-gamma variate: the reciprocal of @gamma a b@.  Since 'gamma'
--   takes a scale, the result has shape @a@ and scale @1 / b@.
inverseGamma :: PrimMonad m => Double -> Double -> Prob m Double
inverseGamma a b = fmap recip (gamma a b)
{-# INLINABLE inverseGamma #-}
-- | The location component of a normal-gamma draw.
--
--   A precision @tau@ is drawn from @gamma a b@, then a normal variate with
--   mean @mu@ and standard deviation @sqrt (1 / (lambda * tau))@ is
--   returned; @tau@ itself is discarded.
normalGamma :: PrimMonad m => Double -> Double -> Double -> Double -> Prob m Double
normalGamma mu lambda a b =
  gamma a b >>= \tau -> normal mu (sqrt (recip (lambda * tau)))
{-# INLINABLE normalGamma #-}
-- | A chi-square variate with the given degrees of freedom, via
--   'MWC.Dist.chiSquare'.
chiSquare :: PrimMonad m => Int -> Prob m Double
chiSquare = Prob . MWC.Dist.chiSquare
{-# INLINABLE chiSquare #-}
-- | A beta variate built from two unit-scale gamma draws @u ~ gamma a 1@
--   and @w ~ gamma b 1@ (sampled in that order), returning @u / (u + w)@.
beta :: PrimMonad m => Double -> Double -> Prob m Double
beta a b = liftA2 ratio (gamma a 1) (gamma b 1)
  where
    ratio u w = u / (u + w)
{-# INLINABLE beta #-}
-- | A Pareto variate with minimum value @xmin@, obtained as
--   @xmin * exp y@ where @y@ is drawn from @exponential a@.
pareto :: PrimMonad m => Double -> Double -> Prob m Double
pareto a xmin = fmap ((xmin *) . exp) (exponential a)
{-# INLINABLE pareto #-}
-- | A Dirichlet variate: one unit-scale gamma draw per concentration
--   parameter, each divided by the sum of all draws.  The container shape
--   of the input is preserved.
dirichlet
  :: (Traversable f, PrimMonad m) => f Double -> Prob m (f Double)
dirichlet alphas = normalise <$> traverse (\alpha -> gamma alpha 1) alphas
  where
    normalise zs = let total = sum zs in fmap (/ total) zs
{-# INLINABLE dirichlet #-}
-- | A 'dirichlet' draw over @n@ components, all sharing concentration @a@.
symmetricDirichlet :: PrimMonad m => Int -> Double -> Prob m [Double]
symmetricDirichlet n = dirichlet . replicate n
{-# INLINABLE symmetricDirichlet #-}
-- | A Bernoulli trial: 'True' when a uniform 'Double' draw falls below @p@.
bernoulli :: PrimMonad m => Double -> Prob m Bool
bernoulli p = fmap (\u -> u < p) uniform
{-# INLINABLE bernoulli #-}
-- | A binomial variate: the number of successes among @n@ 'bernoulli'
--   trials with success probability @p@.
binomial :: PrimMonad m => Int -> Double -> Prob m Int
binomial n p = countHits <$> replicateM n (bernoulli p)
  where
    countHits = F.foldl' (\acc hit -> if hit then acc + 1 else acc) 0
{-# INLINABLE binomial #-}
-- | A negative-binomial variate, drawn as a gamma-Poisson mixture: a rate
--   is drawn from @gamma n ((1 - p) / p)@ and then a 'poisson' count with
--   that rate is returned.
negativeBinomial :: (PrimMonad m, Integral a) => a -> Double -> Prob m Int
negativeBinomial n p = gamma (fromIntegral n) ((1 - p) / p) >>= poisson
{-# INLINABLE negativeBinomial #-}
-- | Draw @n@ independent category indices according to the (unnormalised)
--   weights @ps@.  Note that this returns the indices drawn, not per-category
--   counts.
--
--   Each draw picks @z@ uniformly in the range @(0, total)@ and returns the
--   first index whose running weight exceeds @z@, so zero-weight categories
--   are never chosen.  mwc-random's floating-point 'uniformR' can return the
--   upper bound itself; in that case no running weight is strictly greater
--   than @z@, and the draw falls back to the first index that reaches the
--   total (the last category with positive weight).  Empty, all-zero or
--   otherwise unusable weight vectors still raise an error.
multinomial :: (Foldable f, PrimMonad m) => Int -> f Double -> Prob m [Int]
multinomial n ps = replicateM n draw
  where
    cumulative = scanl1 (+) (F.toList ps)
    -- Take the total from the running sums themselves, so the last
    -- cumulative weight is exactly equal to it (0 for an empty vector).
    total = F.foldl' (\_ x -> x) 0 cumulative
    draw = do
      z <- uniformR (0, total)
      case pick z of
        Just g -> return g
        Nothing -> error "mwc-probability: invalid probability vector"
    pick z = case findIndex (> z) cumulative of
      Nothing | total > 0 && z >= total -> findIndex (>= total) cumulative
      found -> found
{-# INLINABLE multinomial #-}
-- | A generalised Student's t variate with location @m@ and @k@ degrees of
--   freedom.
--
--   A variance is drawn from @inverseGamma (k / 2) (s * 2 / k)@, then a
--   'normal' variate with mean @m@ and standard deviation @sqrt variance@
--   is returned.
--   NOTE(review): because 'gamma' takes a scale, this makes @s@ behave like
--   an inverse squared scale (a precision) rather than a scale; confirm the
--   intended parameterisation (both agree at @s = 1@).
gstudent :: PrimMonad m => Double -> Double -> Double -> Prob m Double
gstudent m s k = do
  variance <- inverseGamma (k / 2) (s * 2 / k)
  normal m (sqrt variance)
{-# INLINABLE gstudent #-}
-- | A standard Student's t variate with @k@ degrees of freedom
--   ('gstudent' with location 0 and @s = 1@).
student :: PrimMonad m => Double -> Prob m Double
student k = gstudent 0 1 k
{-# INLINABLE student #-}
-- | An isotropic normal draw: one 'normal' variate per mean in @ms@, all
--   sharing standard deviation @sd@, with the container shape preserved.
isoNormal
  :: (Traversable f, PrimMonad m) => f Double -> Double -> Prob m (f Double)
isoNormal ms sd = traverse (\mu -> normal mu sd) ms
{-# INLINABLE isoNormal #-}
-- | An inverse Gaussian variate with shape @lambda@ and mean @mu@ (note the
--   argument order).
--
--   A squared standard normal draw yields a candidate root @x@.  A second,
--   uniform draw then chooses between @x@ (with probability
--   @mu / (mu + x)@) and @mu ** 2 / x@.
inverseGaussian :: PrimMonad m => Double -> Double -> Prob m Double
inverseGaussian lambda mu = do
  x <- candidate <$> standardNormal
  z <- uniform
  pure (if z <= mu / (mu + x) then x else mu ** 2 / x)
  where
    candidate nu =
      let y = nu ** 2
          s = sqrt (4 * mu * lambda * y + mu ** 2 * y ** 2)
      in mu * (1 + 1 / (2 * lambda) * (mu * y - s))
{-# INLINABLE inverseGaussian #-}
-- | A Poisson variate with rate @lambda@, sampled from an mwc-random
--   condensed lookup table.  The table is built once per @poisson lambda@
--   value and shared by every draw from that model.
poisson :: PrimMonad m => Double -> Prob m Int
poisson lambda = Prob (genFromTable (tablePoisson lambda))
{-# INLINABLE poisson #-}
-- | A single category index drawn according to the (unnormalised) weights
--   @ps@; this is a one-draw 'multinomial'.
categorical :: (Foldable f, PrimMonad m) => f Double -> Prob m Int
categorical ps = multinomial 1 ps >>= single
  where
    single [x] = return x
    single _ = error "mwc-probability: invalid probability vector"
{-# INLINABLE categorical #-}
-- | Draw one value from a list of @(weight, value)@ pairs, choosing each
--   value with probability proportional to its weight (via 'categorical').
discrete :: (Foldable f, PrimMonad m) => f (Double, a) -> Prob m a
discrete d = (xs !!) <$> categorical ps
  where
    (ps, xs) = unzip (F.toList d)
{-# INLINABLE discrete #-}
-- | A Zipf variate with exponent @a@, drawn by rejection sampling (the loop
--   appears to follow Devroye's algorithm).
--
--   The exponent must be strictly greater than 1.  At @a == 1@ (or NaN) the
--   acceptance test compares NaN and never succeeds, so the loop would spin
--   forever.  For @a < 1@ the draw degenerates.  Both cases are rejected
--   with an explicit error.
zipf :: (PrimMonad m, Integral b) => Double -> Prob m b
zipf a
  | not (a > 1) = error "mwc-probability: zipf requires an exponent greater than 1"
  | otherwise = go
  where
    b = 2 ** (a - 1)
    go = do
      u <- uniform
      v <- uniform
      let xInt = floor (u ** (- 1 / (a - 1)))
          x = fromIntegral xInt
          t = (1 + 1 / x) ** (a - 1)
      if v * x * (t - 1) / (b - 1) <= t / b
        then return xInt
        else go
{-# INLINABLE zipf #-}
-- | Table occupancy counts from a Chinese restaurant process with
--   concentration @a@ after seating @n@ customers.
--
--   The first customer starts table 0.  Each later customer is placed by
--   'crpSingle'.  The result lists the number of customers at each table in
--   table order.  A non-positive @n@ seats nobody and returns @[]@; the
--   original counting loop never terminated in that case.
crp
  :: PrimMonad m
  => Double
  -> Int
  -> Prob m [Integer]
crp a n
  | n <= 0 = pure []
  | otherwise = do
      ts <- go crpInitial 1
      pure $ F.toList (fmap getSum ts)
  where
    -- @i@ is the number of customers already seated.
    go acc i
      | i >= n = pure acc
      | otherwise = do
          acc' <- crpSingle i acc a
          go acc' (i + 1)
{-# INLINABLE crp #-}
-- | Seat one more customer in a Chinese restaurant process.
--
--   Each existing table gets weight @count / (i - 1 + a)@ and a new table
--   gets weight @a / (i - 1 + a)@.  The new table is the last entry, at
--   index equal to the current table count.  One index is drawn with
--   'categorical' and that table's count is incremented.
crpSingle :: (PrimMonad m, Integral b) =>
  Int
  -> CRPTables (Sum b)
  -> Double
  -> Prob m (CRPTables (Sum b))
crpSingle i zs a = flip crpInsert zs <$> categorical weights
  where
    denom = fromIntegral i - 1 + a
    existing = F.toList (fmap (\m -> fromIntegral (getSum m) / denom) zs)
    weights = existing ++ [a / denom]
-- | Per-table values for the Chinese restaurant process, keyed by table
--   index.  The Semigroup/Monoid instances are those of the underlying
--   'IM.IntMap' (left-biased union, empty map).
newtype CRPTables c = CRP {getCRPTables :: IM.IntMap c}
  deriving (Eq, Show, Functor, Foldable, Semigroup, Monoid)
-- | The process after its first customer: a single table 0 holding one
--   customer.
crpInitial :: CRPTables (Sum Integer)
crpInitial = CRP (IM.singleton 0 (Sum 1))
-- | Add one customer to table @k@, creating the table with count 1 if it
--   does not exist yet.
crpInsert :: Num a => IM.Key -> CRPTables (Sum a) -> CRPTables (Sum a)
crpInsert k = CRP . IM.insertWith (<>) k (Sum 1) . getCRPTables