{-# LANGUAGE
    MultiParamTypeClasses,
    FlexibleInstances, FlexibleContexts,
    UndecidableInstances, BangPatterns
  #-}

{-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-}

module Data.Random.Distribution.Gamma
    ( Gamma(..)
    , gamma, gammaT

    , Erlang(..)
    , erlang, erlangT

    , mtGamma
    ) where

import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Uniform
import Data.Random.Distribution.Normal

import Data.Ratio

import Numeric.SpecFunctions

-- | Sample a gamma-distributed variate with shape @a@ and scale @b@.
--
-- Derived from Marsaglia & Tsang, "A Simple Method for Generating Gamma
-- Variables", ACM Transactions on Mathematical Software, Vol 26, No 3
-- (2000), pp. 363-372.
--
-- * For @a >= 1@ the squeeze/rejection loop of the paper is run directly.
--
-- * For @0 <= a < 1@ the shape is boosted to @a + 1@ and the scale is
--   multiplied by @u ** (1/a)@ for a fresh standard uniform @u@. This uses
--   the identity that @X * U^(1/a)@ is Gamma(a) when @X@ is Gamma(a+1).
--
-- A negative or NaN shape is rejected with 'error'. Without this guard, a
-- NaN shape makes every acceptance test fail, so the sampler loops forever.
-- A shape of @-Infinity@ recurses forever through the boost step, and other
-- negative shapes yield meaningless values.
{-# SPECIALIZE mtGamma :: Double -> Double -> RVarT m Double #-}
{-# SPECIALIZE mtGamma :: Float  -> Float  -> RVarT m Float  #-}
mtGamma
    :: (Floating a, Ord a,
        Distribution StdUniform a,
        Distribution Normal a)
    => a            -- ^ shape
    -> a            -- ^ scale
    -> RVarT m a
mtGamma a b
    -- 'not (a >= 0)' holds for negative shapes and also for NaN, since every
    -- ordered comparison with NaN is False. Only 'Ord' is needed to test it.
    | not (a >= 0) = error "mtGamma: shape parameter must be non-negative"
    | a < 1 = do
        u <- stdUniformT
        -- Boost: sample Gamma(a+1) with the scale pre-multiplied by u^(1/a).
        -- '$!' forces the new scale before recursing (presumably to avoid
        -- building a thunk).
        mtGamma (1 + a) $! (b * u ** recip a)
    | otherwise = go
    where
        -- Constants of the Marsaglia-Tsang method: d = a - 1/3 and
        -- c = 1 / sqrt (9 d). 'fromRational' keeps 1/3 exact for any
        -- Fractional type.
        !d = a - fromRational (1 % 3)
        !c = recip (sqrt (9 * d))

        go = do
            x <- stdNormalT
            let !v = 1 + c * x
            -- The candidate is d * v^3, so v must be positive; otherwise retry.
            if v <= 0
                then go
                else do
                    u <- stdUniformT
                    let !x_2 = x * x
                        !x_4 = x_2 * x_2
                        v3 = v * v * v
                        dv = d * v3
                    -- Cheap squeeze test first, then the exact log test.
                    if u < 1 - 0.0331 * x_4
                            || log u < 0.5 * x_2 + d - dv + d * log v3
                        then return (b * dv)
                        else go

-- | A gamma-distributed random variable with the given shape and scale.
--
-- This is simply 'rvar' applied to a 'Gamma' distribution value. See
-- 'gammaT' for the transformer-polymorphic variant.
{-# SPECIALIZE gamma :: Float  -> Float  -> RVar Float  #-}
{-# SPECIALIZE gamma :: Double -> Double -> RVar Double #-}
gamma :: (Distribution Gamma a) => a -> a -> RVar a
gamma shape scale = rvar $ Gamma shape scale

-- | Transformer-polymorphic version of 'gamma'. It is a gamma-distributed
-- 'RVarT' with the given shape and scale, for any base monad @m@.
gammaT :: (Distribution Gamma a) => a -> a -> RVarT m a
gammaT shape scale = rvarT $ Gamma shape scale

-- | An Erlang-distributed random variable whose shape is the given value.
erlang :: (Distribution (Erlang a) b) => a -> RVar b
erlang = rvar . Erlang

-- | Transformer-polymorphic version of 'erlang', usable in any base monad @m@.
erlangT :: (Distribution (Erlang a) b) => a -> RVarT m b
erlangT = rvarT . Erlang

-- | The gamma distribution, given by its shape followed by its scale.
data Gamma a
    = Gamma
        a   -- ^ shape
        a   -- ^ scale
-- | The Erlang distribution, described only by its shape. The type
-- parameter @b@ does not appear in the representation; it only fixes the
-- type of the sampled values.
newtype Erlang a b
    = Erlang a  -- ^ shape

-- | Gamma variates are drawn with 'mtGamma' (Marsaglia & Tsang's method).
instance (Floating a, Ord a, Distribution Normal a, Distribution StdUniform a) => Distribution Gamma a where
    {-# SPECIALIZE instance Distribution Gamma Double #-}
    {-# SPECIALIZE instance Distribution Gamma Float #-}
    rvarT g = case g of
        Gamma shape scale -> mtGamma shape scale

-- | The CDF of @Gamma shape scale@ at @x@ is the normalized lower incomplete
-- gamma function evaluated at @(shape, x / scale)@.
--
-- A gamma distribution has no mass below zero, so a non-positive scaled
-- argument yields 0 directly. 'incompleteGamma' itself is only defined for a
-- non-negative second argument. Depending on the math-functions version,
-- passing it a negative value raises an error or returns a value that is not
-- a probability.
instance (Real a, Distribution Gamma a) => CDF Gamma a where
    cdf (Gamma a b) x
        | z <= 0    = 0
        | otherwise = incompleteGamma (realToFrac a) z
        where
            z :: Double
            z = realToFrac x / realToFrac b

-- | Erlang variates are unit-scale gamma variates. The integral shape is
-- converted with 'fromIntegral' and passed to 'mtGamma'.
instance (Integral a, Floating b, Ord b, Distribution Normal b, Distribution StdUniform b) => Distribution (Erlang a) b where
    rvarT (Erlang k) = mtGamma shape 1
        where
            shape = fromIntegral k

-- | The CDF of an Erlang distribution with shape @k@ (and unit scale) at @x@
-- is the normalized lower incomplete gamma function evaluated at @(k, x)@.
--
-- There is no probability mass below zero, so non-positive @x@ yields 0
-- directly. 'incompleteGamma' is only defined for a non-negative second
-- argument. Depending on the math-functions version, a negative value raises
-- an error or returns a value that is not a probability.
instance (Integral a, Real b, Distribution (Erlang a) b) => CDF (Erlang a) b where
    cdf (Erlang a) x
        | z <= 0    = 0
        | otherwise = incompleteGamma (fromIntegral a) z
        where
            z :: Double
            z = realToFrac x