{-# OPTIONS_GHC -Wall #-}

-- |
-- Module: System.Random.MWC.Probability
-- Copyright: (c) 2015 Jared Tobin
-- License: MIT
--
-- Maintainer: Jared Tobin <jared@jtobin.ca>
-- Stability: unstable
-- Portability: ghc
--
-- A probability monad based on sampling functions.
--
-- Probability distributions are abstract constructs that can be represented in
-- a variety of ways.  The sampling function representation is particularly
-- useful - it's computationally efficient, and collections of samples are
-- amenable to much practical work.
--
-- Probability monads propagate uncertainty under the hood.  An expression like
-- @'beta' 1 8 >>= 'binomial' 10@ corresponds to a
-- <https://en.wikipedia.org/wiki/Beta-binomial_distribution beta-binomial>
-- distribution in which the uncertainty captured by @'beta' 1 8@ has been
-- marginalized out.
--
-- The distribution resulting from a series of effects is called the
-- /predictive distribution/ of the model described by the corresponding
-- expression.  The monadic structure lets one piece together a hierarchical
-- structure from simpler, local conditionals:
--
-- > hierarchicalModel = do
-- >   [c, d, e, f] <- replicateM 4 $ uniformR (1, 10)
-- >   a <- gamma c d
-- >   b <- gamma e f
-- >   p <- beta a b
-- >   n <- uniformR (5, 10)
-- >   binomial n p
--
-- The functor instance for a probability monad transforms the support of the
-- distribution while leaving its density structure invariant in some sense.
-- For example, @'uniform'@ is a distribution over the 0-1 interval, but @fmap
-- (+ 1) uniform@ is the translated distribution over the 1-2 interval.
--
-- >>> sample (fmap (+ 1) uniform) gen
-- 1.5480073474340754

module System.Random.MWC.Probability (
    module MWC
  , Prob(..)
  , samples

  , uniform
  , uniformR
  , discreteUniform
  , categorical
  , standard
  , normal
  , logNormal
  , exponential
  , gamma
  , inverseGamma
  , chiSquare
  , beta
  , dirichlet
  , symmetricDirichlet
  , bernoulli
  , binomial
  , multinomial
  , student
  , isoGauss
  , poisson
  ) where

import Control.Applicative
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.Trans.Class
import Data.List (findIndex)
import System.Random.MWC as MWC hiding (uniform, uniformR)
import qualified System.Random.MWC as QMWC
import qualified System.Random.MWC.Distributions as MWC.Dist
import System.Random.MWC.CondensedTable

-- | A probability distribution characterized by a sampling function.
--
-- A value of type @'Prob' m a@ wraps a function that takes a 'Gen' whose
-- state token matches the monad @m@ and produces a value of type @a@ in @m@.
-- The 'sample' field unwraps that function, so @'sample' model gen@ runs the
-- sampling function of @model@ against @gen@.
--
-- >>> gen <- create
-- >>> sample uniform gen
-- 0.4208881170464097
newtype Prob m a = Prob { sample :: Gen (PrimState m) -> m a }

-- | Draw @n@ values from a model using the supplied generator, collecting
--   them in a list.
--
-- >>> samples 2 uniform gen
-- [0.6738707766845254,0.9730405951541817]
samples :: PrimMonad m => Int -> Prob m a -> Gen (PrimState m) -> m [a]
samples n model gen = replicateM n draw
  where
    -- A single draw from the model against the shared generator.
    draw = sample model gen

-- | Mapping a function over a distribution applies it to each value that the
--   underlying sampling function produces.
instance Monad m => Functor (Prob m) where
  fmap h model = Prob $ \gen -> liftM h (sample model gen)

-- | Defined in terms of the 'Monad' operations on 'Prob'.
instance Monad m => Applicative (Prob m) where
  pure x    = return x
  pf <*> px = ap pf px

-- | Arithmetic on distributions is lifted: binary operators combine a value
--   sampled from each operand, unary functions are mapped over the sampled
--   value, and integer literals become constant distributions.
instance (Applicative m, Monad m, Num a) => Num (Prob m a) where
  p + q         = liftA2 (+) p q
  p - q         = liftA2 (-) p q
  p * q         = liftA2 (*) p q
  abs p         = abs <$> p
  signum p      = signum <$> p
  fromInteger k = pure (fromInteger k)

-- | 'return' ignores the generator and yields its argument; '>>=' samples
--   the first distribution, feeds the result to the continuation, and
--   samples the resulting distribution with the same generator.
instance Monad m => Monad (Prob m) where
  return x = Prob $ \_ -> return x
  m >>= h  = Prob $ \gen ->
    sample m gen >>= \z -> sample (h z) gen

-- | Lifting an action yields a distribution that ignores the generator and
--   simply runs that action.
instance MonadTrans Prob where
  lift = Prob . const

-- | The uniform distribution over a type, as provided by that type's
--   'Variate' instance.
uniform :: (PrimMonad m, Variate a) => Prob m a
uniform = Prob $ \gen -> QMWC.uniform gen

-- | The uniform distribution over the interval described by the supplied
--   pair of bounds.
uniformR :: (PrimMonad m, Variate a) => (a, a) -> Prob m a
uniformR = Prob . QMWC.uniformR

-- | The discrete uniform distribution over the elements of a list.
--
--   An index is drawn uniformly from @0@ to @length cs - 1@ and the element
--   at that position is returned.  A uniform distribution over no elements
--   does not exist, so sampling from an empty list raises an error with a
--   descriptive message.
discreteUniform :: PrimMonad m => [a] -> Prob m a
discreteUniform cs
  | null cs   = error "discreteUniform: empty list"
  | otherwise = do
      j <- uniformR (0, length cs - 1)
      return $ cs !! j

-- | The standard normal distribution (a Gaussian with mean 0 and variance 1).
standard :: PrimMonad m => Prob m Double
standard = Prob $ \gen -> MWC.Dist.standard gen

-- | The normal or Gaussian distribution.  Both arguments are forwarded, in
--   order, to 'MWC.Dist.normal'.
normal :: PrimMonad m => Double -> Double -> Prob m Double
normal m sd = Prob $ \gen -> MWC.Dist.normal m sd gen

-- | The log-normal distribution, obtained by exponentiating a draw from
--   @'normal' m sd@.
logNormal :: PrimMonad m => Double -> Double -> Prob m Double
logNormal m sd = fmap exp (normal m sd)

-- | The exponential distribution.  The parameter is passed straight through
--   to 'MWC.Dist.exponential'.
exponential :: PrimMonad m => Double -> Prob m Double
exponential = Prob . MWC.Dist.exponential

-- | The gamma distribution.  Both parameters are forwarded, in order, to
--   'MWC.Dist.gamma'.
gamma :: PrimMonad m => Double -> Double -> Prob m Double
gamma a b = Prob $ \gen -> MWC.Dist.gamma a b gen

-- | The inverse-gamma distribution, obtained by taking the reciprocal of a
--   draw from @'gamma' a b@.
inverseGamma :: PrimMonad m => Double -> Double -> Prob m Double
inverseGamma a b = fmap recip (gamma a b)

-- | The chi-square distribution.  The parameter is passed straight through
--   to 'MWC.Dist.chiSquare'.
chiSquare :: PrimMonad m => Int -> Prob m Double
chiSquare = Prob . MWC.Dist.chiSquare

-- | The beta distribution, built from two gamma draws @u@ and @w@ (with
--   first parameters @a@ and @b@ and second parameter 1) as @u / (u + w)@.
beta :: PrimMonad m => Double -> Double -> Prob m Double
beta a b = ratio <$> gamma a 1 <*> gamma b 1
  where
    ratio u w = u / (u + w)

-- | The Dirichlet distribution.  One gamma variate is drawn per
--   concentration parameter and the draws are divided by their sum.
dirichlet :: PrimMonad m => [Double] -> Prob m [Double]
dirichlet as = do
  zs <- mapM (\a -> gamma a 1) as
  let total = sum zs
  return [z / total | z <- zs]

-- | The symmetric Dirichlet distribution: a 'dirichlet' whose @n@
--   concentration parameters are all equal to the supplied value.
symmetricDirichlet :: PrimMonad m => Int -> Double -> Prob m [Double]
symmetricDirichlet n = dirichlet . replicate n

-- | The Bernoulli distribution.  Yields 'True' when a 'uniform' variate is
--   strictly less than @p@.
bernoulli :: PrimMonad m => Double -> Prob m Bool
bernoulli p = do
  u <- uniform
  return (u < p)

-- | The binomial distribution.  Counts the successes among @n@ draws from
--   @'bernoulli' p@.
binomial :: PrimMonad m => Int -> Double -> Prob m Int
binomial n p = do
  trials <- replicateM n (bernoulli p)
  return (length (filter id trials))

-- | The multinomial distribution.
--
--   Performs @n@ draws according to the probability vector @ps@ and returns
--   the zero-based category index selected by each draw (one list element
--   per draw, rather than a vector of per-category counts).
--
--   Each draw picks the first index whose cumulative probability strictly
--   exceeds a 'uniform' variate.  If no such index exists - @ps@ is empty,
--   or its running total never exceeds the variate - an error is raised when
--   the distribution is sampled.
multinomial :: PrimMonad m => Int -> [Double] -> Prob m [Int]
multinomial n ps = do
  let cumulative = scanl1 (+) ps
  replicateM n $ do
    z <- uniform
    -- Match totally: a lazy @let Just g = ...@ would defer the failure and
    -- report it as an opaque irrefutable-pattern error.
    case findIndex (> z) cumulative of
      Just g  -> return g
      Nothing -> error "multinomial: invalid probability vector"

-- | Student's t distribution.
--
--   Draws a value from @'inverseGamma' (k / 2) (s * 2 / k)@ and uses its
--   square root as the second argument of a 'normal' whose first argument
--   is @m@.
student :: PrimMonad m => Double -> Double -> Double -> Prob m Double
student m s k = do
  variance <- inverseGamma (k / 2) (s * 2 / k)
  normal m (sqrt variance)

-- | An isotropic or spherical Gaussian distribution.  Each component is an
--   independent 'normal' draw centred on the corresponding element of @ms@,
--   all sharing the same @sd@.
isoGauss :: PrimMonad m => [Double] -> Double -> Prob m [Double]
isoGauss ms sd = forM ms $ \m -> normal m sd

-- | The Poisson distribution.  Samples are generated from the condensed
--   table that 'tablePoisson' builds for the supplied parameter.
poisson :: PrimMonad m => Double -> Prob m Int
poisson l = Prob (genFromTable (tablePoisson l))

-- | A categorical distribution defined by the supplied list of
--   probabilities.  Implemented as a single 'multinomial' draw, yielding the
--   index that draw selects.
categorical :: PrimMonad m => [Double] -> Prob m Int
categorical ps = multinomial 1 ps >>= pick
  where
    -- A single-draw multinomial is expected to yield exactly one index.
    pick [x] = return x
    pick _   = error "categorical: invalid return value"