{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeFamilies #-}
module Stochastic.Distributions.Continuous(
  mkUniform
  ,mkExp
  ,mkNormal
  ,mkEmpirical
  ,Dist(..)
  ,ContinuousDistribution(..)
  ) where

import Data.Maybe
import Control.Monad.State.Lazy
--import Stochastic.Analysis
import Stochastic.Generator

import Stochastic.Distributions(UniformBase, rDouble)
import qualified Stochastic.Distributions as B(cdf, mkEmpirical, Empirical)
import Stochastic.Distribution.Continuous
import Stochastic.Tools
import Data.Number.Erf

-- | The raw uniform source viewed as a continuous distribution on [0,1].
--
-- NOTE(review): the identity 'cdf'' assumes 'rDouble' draws from [0,1]
-- — confirm. This appears to be an orphan instance: neither the class nor
-- 'UniformBase' is defined in this module.
instance ContinuousDistribution UniformBase where
  -- Draw the next double and the advanced source.
  rand = rDouble
  -- U(0,1) CDF, clamped so it remains a valid probability for any input.
  cdf _ x
    | x < 0 = 0
    | x > 1 = 1
    | otherwise = x
  -- Inverse of the identity CDF on [0,1].
  cdf' _ p = p
  degreesOfFreedom _ = 0


-- | Run a 'Dist' as a 'state' action: each 'nextG' performs exactly one
-- 'rand' draw and stores the updated distribution back into the state.
instance Generator Dist where
  type From Dist = Double
  nextG = state rand

-- | Run the raw uniform source as a 'state' action: each 'nextG' performs
-- one 'rDouble' draw and stores the advanced source back into the state.
--
-- NOTE(review): appears to be an orphan instance (neither 'Generator' nor
-- 'UniformBase' is defined in this module).
instance Generator UniformBase where
  type From UniformBase = Double
  nextG = state rDouble


-- | A continuous distribution together with the uniform source it samples
-- from. Every constructor carries its 'UniformBase' as the last field.
data Dist
  = -- | Uniform distribution drawn directly from the base source.
    Uniform UniformBase
  | -- | Exponential distribution; the 'Double' is the rate.
    Exponential Double UniformBase
  | -- | Normal distribution: mean, standard deviation, and an optional
    -- second Box-Muller variate cached for the next draw.
    Normal Double Double (Maybe Double) UniformBase
  | -- | Chi-squared distribution with the given degrees of freedom.
    ChiSquared Int UniformBase
  | -- | Empirical distribution summarised (via 'B.mkEmpirical') from
    -- observed samples.
    Empirical B.Empirical UniformBase

-- | Build an empirical distribution from raw sample values. The samples are
-- summarised by 'B.mkEmpirical' and stored alongside the uniform source.
mkEmpirical :: UniformBase -> [Double] -> Dist
mkEmpirical uni xs = Empirical summary uni
  where
    summary = B.mkEmpirical xs

-- | Exponential distribution with the given rate (mean @1 / rate@).
mkExp :: UniformBase -> Double -> Dist
mkExp uni rate = Exponential rate uni
-- | Normal distribution with the given mean and standard deviation,
-- starting with an empty Box-Muller cache.
mkNormal :: UniformBase -> Double -> Double -> Dist
mkNormal uni mu sigma = Normal mu sigma Nothing uni
-- | Uniform distribution drawn directly from the base source.
mkUniform :: UniformBase -> Dist
mkUniform = Uniform

-- | Sampling, CDF and inverse CDF for every 'Dist' variant.
--
-- 'rand' returns the sampled value together with the distribution holding
-- the advanced uniform source, so successive draws can be threaded.
--
-- NOTE(review): sampling takes 'log' of uniform draws; if 'rDouble' can
-- return exactly 0 the sample is infinite — confirm its range.
instance ContinuousDistribution Dist where
  rand (Uniform uni) = mapTuple id Uniform (rand uni)
  -- Inverse-transform sampling: -ln(U) / y ~ Exp(y) for U ~ U(0,1).
  rand (Exponential y u) =
    mapTuple (\x -> (-1.0 / y) * log x) (Exponential y) (rand u)
  -- A cached Box-Muller variate is returned without consuming any uniforms.
  rand (Normal mean dev (Just cached) uni) =
    (cached, Normal mean dev Nothing uni)
  -- Box-Muller turns two uniforms into two independent normals: return
  -- one and cache the other for the next call.
  rand (Normal mean dev Nothing uni) = (y, Normal mean dev (Just z) uni')
    where
      ([u1, u2], uni') = rands 2 uni
      radius = sqrt (-2 * log u1)
      angle = 2 * pi * u2
      y = mean + dev * radius * sin angle
      z = mean + dev * radius * cos angle
  -- Chi-squared(k) is a sum of k squared standard normals. For a Box-Muller
  -- pair, z0^2 + z1^2 = -2 ln u1, so every two degrees of freedom cost a
  -- single uniform. An odd k adds one explicitly squared normal.
  rand (ChiSquared k uni) = (evenPart + oddPart, ChiSquared k uni')
    where
      (pairs, extra) = k `quotRem` 2
      needed = pairs + 2 * extra
      (us, uni') =
        if needed <= 0 then ([], uni) else rands (fromIntegral needed) uni
      (pairUs, oddUs) = splitAt pairs us
      evenPart = sum [-2 * log u | u <- pairUs]
      oddPart = case oddUs of
        [a, b] -> -2 * log a * cos (2 * pi * b) ^ (2 :: Int)
        _ -> 0
  rand (Empirical _ _) =
    error "Stochastic.Distributions.Continuous: rand is not implemented for Empirical"

  -- All CDFs return 0 below the support so results stay in [0,1].
  cdf (Uniform _) x
    | x < 0 = 0
    | x > 1 = 1
    | otherwise = x
  cdf (Exponential y _) x
    | x < 0 = 0
    | otherwise = 1 - exp (negate (y * x))
  cdf (Normal u s _ _) x = 0.5 * (1 + erf ((x - u) / (s * sqrt 2)))
  -- Regularised lower incomplete gamma P(k/2, x/2).
  cdf (ChiSquared k _) x
    | x <= 0 = 0
    | otherwise = lower_incomplete_gamma half (x / 2) / gamma half
    where
      half = fromIntegral k / 2
  cdf (Empirical b _) x = B.cdf b x

  cdf' (Uniform _) p = p
  cdf' (Exponential y _) p = negate (log (1 - p)) / y
  cdf' (Normal u s _ _) p = u + s * sqrt 2 * inverf (2 * p - 1)
  -- No closed form: invert the monotone CDF numerically. Double an upper
  -- bound until it reaches p, then bisect [0, hi] keeping
  -- cdf lo < p <= cdf hi. Both loops are iteration-capped so a CDF that
  -- saturates numerically cannot hang.
  cdf' d@(ChiSquared _ _) p
    | p <= 0 = 0
    | p >= 1 = 1 / 0
    | otherwise = search (200 :: Int) 0 (grow (128 :: Int) 1)
    where
      grow n hi
        | n > 0 && cdf d hi < p = grow (n - 1) (2 * hi)
        | otherwise = hi
      search n lo hi
        | n <= 0 || mid <= lo || mid >= hi = mid
        | cdf d mid < p = search (n - 1) mid hi
        | otherwise = search (n - 1) lo mid
        where
          mid = lo + (hi - lo) / 2
  cdf' (Empirical _ _) _ =
    error "Stochastic.Distributions.Continuous: cdf' is not implemented for Empirical"

  -- Follows the visible pattern of counting distribution parameters
  -- (Uniform 0, Exponential rate 1, Normal mean+dev 2).
  -- NOTE(review): ChiSquared = 1 (its k) assumes that reading — confirm.
  degreesOfFreedom (Uniform _) = 0
  degreesOfFreedom (Exponential _ _) = 1
  degreesOfFreedom (Normal _ _ _ _) = 2
  degreesOfFreedom (ChiSquared _ _) = 1
  degreesOfFreedom (Empirical _ _) =
    error "Stochastic.Distributions.Continuous: degreesOfFreedom is not implemented for Empirical"