{-#LANGUAGE GADTs#-}
{-#LANGUAGE RankNTypes#-}
{-#LANGUAGE FlexibleInstances#-}

{-|
 Module         : Data.Stochastic.Types
 Description    : Types used for the stochastic package.
 License        : GPL-3
 Maintainer     : hackage@mail.kevinl.io
 Stability      : experimental

 This module contains the types used
 for the stochastic package.

 WARNING: In its current state, care should be
 taken when using discrete distributions,
 as it is never checked that the probabilities
 sum to 1. If they sum to less than 1, sampling
 may fail at run time; if they sum to more than 1,
 the trailing outcomes are silently under-weighted.
-}

module Data.Stochastic.Types (
  Distribution (..)
, Sampleable (..)
, Sample (..)
, StochProcess (..)
, Sampler (..)
, Mean (..)
, StDev (..)
, marsagliaTsang
) where

import Control.Monad
import Control.Monad.State
import Control.Monad.Writer

import Data.Stochastic.Internal

import qualified Data.Sequence as S

import System.Random

-- | Parameterized probability distributions over values of type @a@.
--
-- GADT syntax lets some constructors fix the result type. For example,
-- 'Normal', 'Uniform', 'Gamma' and 'Beta' only describe 'Double' values,
-- and 'Bernoulli' only describes 'Bool' values.
data Distribution a where
    -- | Normal distribution with the given mean and standard deviation.
    Normal :: Mean -> StDev -> Distribution Double
    -- | Bernoulli trial; the parameter is the probability of 'True'.
    Bernoulli :: Double -> Distribution Bool
    -- | Finite distribution given as (value, probability) pairs.
    -- The probabilities are never checked to sum to 1.
    Discrete :: [(a, Double)] -> Distribution a
    -- | Every listed value is equally likely.
    DiscreteUniform :: [a] -> Distribution a
    -- | Continuous uniform distribution; values come from @closedRnd@
    -- (presumably the closed unit interval — TODO confirm).
    Uniform :: Distribution Double
    -- | Degenerate distribution that always yields the given value.
    Certain :: a -> Distribution a
    -- | Gamma distribution, where the first parameter is the shape parameter alpha, and the second parameter is the scale parameter beta.
    Gamma :: Double -> Double -> Distribution Double
    -- | Beta distribution with shape parameters alpha and beta.
    Beta :: Double -> Double -> Distribution Double

-- | Type constructors @d@ such that a value of type @d a@ describes a
-- random @a@ that can be drawn with any 'RandomGen'.
class Sampleable d where
    -- | Draw one value from @d a@ with the supplied generator.
    -- The result pairs the value with a fresh generator to use next.
    sampleFrom :: RandomGen g => d a -> g -> (a, g)

    -- | Wrap a value so that sampling it always yields that same value.
    certainDist :: a -> d a

-- | 'Sampleable' instance for 'Distribution'. We ensure
-- that we always pass the *next* 'RandomGen' provided
-- to sampleFrom. This lets us obey the monad laws.
--
-- Every case returns the generator left over after the /last/ random
-- number it consumed, so two consecutive samples never share a draw.
instance Sampleable Distribution where
    sampleFrom da g
        = case da of
            Normal mean stdev
                -- Box-Muller uses two draws; hand back the generator advanced
                -- past both. Returning g' here would let the next sample
                -- reuse a' as its first draw.
                -> let (a, g')   = closedRnd g
                       (a', g'') = closedRnd g'
                       s = (stdev * (boxMuller a a')) + mean
                   in (s, g'')
            Bernoulli prob
                -> let (a, g') = closedRnd g
                   in (a <= prob, g')
            Discrete []
                -> error "cannot sample from empty discrete distribution"
            Discrete l
                -> let (a, g') = closedRnd g
                   in (scan a l, g')
                   where
                     -- Walk the cumulative distribution until the draw
                     -- falls inside an outcome's probability mass.
                     scan _ [] =
                         -- The list is non-empty here, so running off its
                         -- end means the probabilities sum to less than the
                         -- draw, i.e. they are not normalized.
                         error "not normalized discrete dist"
                     scan lim ((x, p) : xs)
                         | lim <= p  = x
                         | otherwise = scan (lim - p) xs
            DiscreteUniform []
                -> error "cannot sample from empty discrete distribution"
            DiscreteUniform l
                -> let (a, g') = closedRnd g
                       n = length l
                       -- Clamp the index. If the draw can be exactly 1
                       -- (closedRnd presumably samples a closed interval),
                       -- then a * n == n, one past the last valid index.
                       i = min (n - 1) (floor (a * fromIntegral n))
                   in (l !! i, g')
            Uniform
                -> closedRnd g
            Gamma alpha beta
                | alpha <= 0 || beta <= 0
                -> error "alpha and beta parameter cannot be less than or equal to zero in gamma distribution"
                | alpha < 1
                -- For shape < 1: sample with shape alpha + 1, then scale
                -- the result by uni ** (1 / alpha).
                -> let (a, g') = sampleFrom (Gamma (alpha + 1) beta) g
                       (uni, g'') = openRnd g'
                   in (a * (uni ** (1/alpha)), g'')
                | otherwise
                -> let d = alpha - (1/3)
                       c = 1 / sqrt (9 * d)
                       (m, g') = marsagliaTsang d c g
                   in (m * beta, g')
            Beta alpha beta
                -- X / (X + Y) with X ~ Gamma(alpha, 1), Y ~ Gamma(beta, 1).
                -> let (x, g') = sampleFrom (Gamma alpha 1) g
                       (y, g'') = sampleFrom (Gamma beta 1) g'
                   in (x / (x + y), g'')
            Certain val
                -> (val, snd $ openRnd g)
                -- Seemingly unnecessary, but important to obey the monad laws to always produce the same RandomGen each time we sample.
    certainDist = Certain

-- | Marsaglia and Tsang's rejection method for generating Gamma variates
-- with shape @alpha@ and scale 1.
--
-- The caller supplies the precomputed constants @d@ and @c@. The 'Gamma'
-- case of 'sampleFrom' passes @d = alpha - 1/3@ and @c = 1 / sqrt (9 * d)@.
-- Each attempt draws a standard normal @z@ and a uniform @u@, and forms
-- @v = (1 + c*z)^3@. It accepts @d * v@ when @z > -1/c@ and
-- @log u < z^2/2 + d - d*v + d*log v@. Otherwise it retries with the
-- advanced generator.
--
-- NOTE(review): the two draws are only independent if sampling from
-- 'Normal' returns a generator advanced past every number it consumed.
marsagliaTsang :: (RandomGen g) => Double -> Double -> g -> (Double, g)
marsagliaTsang d c gen
    | accept    = (d * v, gen'')
    | otherwise = marsagliaTsang d c gen''
  where
    (z, gen')  = sampleFrom (Normal 0 1) gen
    (u, gen'') = sampleFrom Uniform gen'
    v          = (1 + (c * z)) ** 3
    -- The left conjunct is checked first, so log v is only examined
    -- when v > 0.
    accept     = z > ((-1) / c)
              && log u < ((z ** 2) / 2 + d - (d * v) + (d * log v))

-- | Show instance for 'Distribution's.
--
-- Renders the constructor name followed by its parameters separated by
-- spaces, e.g. @Normal 0.0 1.0@ or @Gamma 2.0 3.0@. Every constructor is
-- handled, so 'show' never fails with a pattern-match error.
instance (Show a) => Show (Distribution a) where
    show da = case da of
        Normal mean stdev -> "Normal " ++ show mean ++ " " ++ show stdev
        Bernoulli prob -> "Bernoulli " ++ show prob
        Discrete l -> "Discrete " ++ show l
        DiscreteUniform l -> "DiscreteUniform " ++ show l
        Uniform -> "Uniform"
        Certain val -> "Certain " ++ show val
        Gamma alpha beta -> "Gamma " ++ show alpha ++ " " ++ show beta
        Beta alpha beta -> "Beta " ++ show alpha ++ " " ++ show beta

-- | 'Sample' monad: a 'State' computation that threads a random number
-- generator of type @g@ and produces a sampleable value of type @d a@
-- (for example a 'Distribution' over @a@).
--
-- The 'RandomGen' and 'Sampleable' constraints live inside the field, so
-- they must be satisfied wherever the computation is run via 'runSample'.
newtype Sample g d a
    = Sample { runSample :: (RandomGen g, Sampleable d) => State g (d a) }

-- | A stochastic process computation yielding a 'Double'.
-- The 'WriterT' layer over 'Sample' 'StdGen' 'Distribution' lets the
-- process record 'Double' values, accumulated in a 'S.Seq', as it samples.
type StochProcess
    = WriterT (S.Seq Double) (Sample StdGen Distribution) Double

-- | Monad instance for 'Sample'.
--
-- Both 'return' and '>>=' first advance the generator held in the state
-- by one step (via 'next'). Binding then runs the first computation,
-- draws a value from the resulting sampleable using the current
-- generator, and runs the continuation on that value.
instance (RandomGen g, Sampleable d) => Monad (Sample g d) where
    return x = Sample $
        modify (snd . next) >> return (certainDist x)

    ma >>= f = Sample $
        modify (snd . next) >> (runSample ma >>= \dist ->
            get >>= \gen ->
                runSample (f (fst (sampleFrom dist gen))))

-- | 'Functor' instance for 'Sample', built on its 'Monad' instance:
-- bind the computation, then 'return' the transformed result (exactly
-- what 'liftM' does, including the generator steps it implies).
instance (RandomGen g, Sampleable d) => Functor (Sample g d) where
    fmap f m = m >>= \x -> return (f x)

-- | 'Applicative' instance for 'Sample', built on its 'Monad' instance:
-- 'pure' is 'return', and '<*>' runs the function computation before the
-- argument computation (exactly what 'ap' does).
instance (RandomGen g, Sampleable d) => Applicative (Sample g d) where
    pure x = return x
    mf <*> mx = do
        f <- mf
        x <- mx
        return (f x)

-- | Shorthand for 'Sample' specialised to a 'StdGen' generator and the
-- 'Distribution' type, for shorter type annotations.
type Sampler a = Sample StdGen Distribution a

-- | Mean of a 'Normal' distribution. An alias for 'Double' that makes the
-- type of the 'Normal' constructor more informative.
type Mean = Double

-- | Standard deviation of a 'Normal' distribution. An alias for 'Double'
-- that makes the type of the 'Normal' constructor more informative.
type StDev = Double