{-# LANGUAGE ExistentialQuantification #-}

module Stochastic.Distributions.Discrete(
  mkBinomial
  ,mkBernoulli
  ,mkPoisson
  ,mkZipF
  ,mkGeometric
  ,Stochastic.Distributions.Discrete.Dist(..)
  ,module Stochastic.Distribution.Discrete
  
  ) where

import Stochastic.Uniform
import Data.Maybe
import Control.Monad.State.Lazy
import Stochastic.Tools
import Stochastic.Generators.Discrete
import Stochastic.Distributions ()
import Stochastic.Distribution.Discrete
import Stochastic.Generator
import System.Random
import qualified Stochastic.Distributions.Continuous as C

-- | Parameters for each supported discrete distribution. 'Binomial' and
-- 'ZipF' additionally carry a precomputed 'DiscreteCache' of per-outcome rows.
data Dist
  = Uniform Int Int                    -- ^ bounds @a@ and @b@
  | Poisson Double                     -- ^ mean / rate @y@
  | Geometric Double                   -- ^ success probability @p@
  | Bernoulli Double                   -- ^ probability parameter @p@
  | Binomial Int Double DiscreteCache  -- ^ trials @n@, probability @p@, cache
  | ZipF Int Double DiscreteCache      -- ^ rank count @n@, exponent, cache

-- | A distribution paired with the uniform source that its draws consume.
data Sample
  = Sample Dist UniformRandom

-- | Lets a 'Sample' act as a 'RandomGen': every 'next' is one draw from the
-- wrapped distribution via 'rand'.
-- NOTE(review): only 'next' is defined here; confirm the random version in
-- use supplies a default for 'split'.
instance RandomGen Sample where
  next = rand

-- | Binomial distribution over @n@ trials with probability @p@; the result
-- still needs its uniform source.
--
-- The cache holds one @(k, pmf k, cdf k)@ row for every k in 0..n, stored in
-- descending order of k. The pmf is computed in log space:
--
-- > log pmf 0 = n * log (1 - p)
-- > log pmf k = log pmf (k-1) + log p - log (1 - p) - log k + log (n - k + 1)
--
-- i.e. pmf k / pmf (k-1) = (n-k+1)/k * p/(1-p).
--
-- * @p == 1@ is special-cased (all mass on k = n) because the log recurrence
--   evaluates @-inf - -inf@ there and yields NaN.
-- * A negative @n@ yields an empty cache rather than recursing forever.
mkBinomial :: Double -> Int -> UniformRandom -> Sample
mkBinomial p n = Sample (Binomial n p cache)
  where
    nd = toDbl n
    cache :: DiscreteCache
    cache
      | p == 1    = [ (k, atTop k, atTop k) | k <- [n, n - 1 .. 0] ]
      | otherwise = reverse (zip3 ks pmfs (scanl1 (+) pmfs))
    -- Degenerate case: pmf and cdf are both 1 at k = n and 0 below it.
    atTop k = if k == n then 1 else 0
    ks = [0 .. n]
    -- scanl always emits the seed, but zip3 truncates to ks (empty if n < 0).
    lpmfs = scanl step (nd * log (1 - p)) [1 .. n]
    step lprev k =
      let kd = toDbl k
      in lprev + log p - log (1 - p) - log kd + log (nd - kd + 1)
    pmfs = map exp lpmfs
        
-- | Precomputed table of @(k, pmf k, cdf k)@ rows, one per outcome @k@.
-- The stored values are plain probabilities, not logs ('mkBinomial' works in
-- log space internally and exponentiates). 'mkBinomial' and 'mkZipF' build
-- the list in descending order of @k@.
type DiscreteCache = [(Int, Double, Double)]

-- | Pair a 'Uniform' distribution with bounds @a@ and @b@ with a uniform
-- source.
mkUniform :: Int -> Int -> UniformRandom -> Sample
mkUniform a = Sample . Uniform a

-- | Pair a 'Bernoulli' distribution with parameter @p@ with a uniform source.
mkBernoulli :: Double -> UniformRandom -> Sample
mkBernoulli = Sample . Bernoulli

-- | Pair a 'Poisson' distribution with rate @y@ with a uniform source.
mkPoisson :: Double -> UniformRandom -> Sample
mkPoisson = Sample . Poisson

-- | Pair a 'Geometric' distribution with success probability @p@ with a
-- uniform source.
mkGeometric :: Double -> UniformRandom -> Sample
mkGeometric = Sample . Geometric

-- | Zipf distribution over ranks 1..n with exponent @slope@; the result
-- still needs its uniform source.
--
-- pmf k = 1 / (k ** slope * H), where H is the sum of the first @n@ terms of
-- @harmonics slope@. The cache holds one @(k, pmf k, cdf k)@ row per rank, in
-- descending order of k. A non-positive @n@ yields an empty cache instead of
-- recursing forever.
mkZipF :: Int -> Double -> UniformRandom -> Sample
mkZipF n slope = Sample (ZipF n slope cache)
  where
    hns = sum $ take n $ harmonics slope
    cache :: DiscreteCache
    cache = reverse (zip3 ks pmfs (scanl1 (+) pmfs))
    ks = [1 .. n]
    -- For k = 1 this is 1 / (1 * hns) = 1 / hns exactly.
    pmfs = [ 1 / (toDbl k ** slope * hns) | k <- ks ]


-- | Inverse-transform sampling for each 'Dist'. Each branch draws from
-- 'C.rand' (presumably a uniform Double in [0, 1) - TODO confirm), maps it
-- to an outcome and rebuilds the 'Sample' around the advanced source.
instance DiscreteSample Sample where
  entropy (Sample _ u) = entropy u

  -- Outcome = number of cached k (0..n) whose cdf lies below u.
  rand (Sample (Binomial n p cache) g0) =
    mapTuple
    (\u -> length $ filter (cdfBelow u) cache)
    (Sample $ Binomial n p cache)
    (C.rand g0)
  rand (Sample (Geometric p) g0) =
    mapTuple
    (\u -> ceiling $ log u / log (1 - p))
    (Sample $ Geometric p)
    (C.rand g0)
  -- Counts exponential inter-arrival draws accumulated while their running
  -- sum stays below 1.0 (boundary handling is whatever 'foldGenWhile' does).
  rand (Sample (Poisson y) g0) =
    let f (x, g1) = (C.expTransform y x, g1) in
    mapTuple
    length
    (Sample $ Poisson y)
    (foldGenWhile (f . C.rand) (+) 0.0 (< 1.0) g0)
  -- 1 when u < p, so for uniform u outcome 1 has probability p, agreeing
  -- with 'pmf', 'cdf' and 'cdf''.
  rand (Sample (Bernoulli p) g0) =
    mapTuple
    (\u -> if u < p then 1 else 0)
    (Sample $ Bernoulli p)
    (C.rand g0)
  -- Ranks start at 1, hence the offset.
  rand (Sample (ZipF n slope cache) u0) =
    mapTuple
    (\u -> 1 + length (filter (cdfBelow u) cache))
    (Sample $ ZipF n slope cache)
    (C.rand u0)
  rand (Sample (Uniform a b) g0) =
    mapTuple
    (\x -> truncate (toDbl (b - a) * x + toDbl a))
    (Sample $ Uniform a b)
    (C.rand g0)

-- | A 'Sample' answers distribution queries through its 'Dist'.
instance DiscreteDistribution Sample where
  cdf  (Sample d _) = cdf  d
  cdf' (Sample d _) = cdf' d
  pmf  (Sample d _) = pmf  d

-- | pmf, cdf and inverse cdf for each 'Dist'. Outcomes outside a
-- distribution's support get pmf 0, and the cdf is clamped to 0 / 1 there.
instance DiscreteDistribution Dist where
  -- P(X <= x) = e^{-y} * sum_{i=0}^{x} y^i / i!
  cdf (Poisson y) x =
    (1/(exp y)) * (sum [ (y ** (toDbl i)) / (fromInteger $ fac i) | i <- [0..x]])
  -- Support starts at 1; the guard also keeps (^) off negative exponents.
  cdf (Geometric p) x
    | x <= 0    = 0
    | otherwise = 1 - (1-p)^x
  cdf (Bernoulli p) x
    | x < 0     = 0
    | x >= 1    = 1
    | otherwise = 1 - p
  cdf (Binomial n _ cache) x
    | x < 0     = 0
    | x >= n    = 1
    | otherwise = maybe 1 (\(_, _, c) -> c) (lookupCache x cache)
  cdf (ZipF n _ cache) x
    | x <= 0    = 0
    | x >= n    = 1
    | otherwise = maybe 1 (\(_, _, c) -> c) (lookupCache x cache)
  -- Support is a .. b-1, which is what 'rand' / 'cdf'' produce from u in
  -- [0, 1). Assumes a < b.
  cdf (Uniform a b) x
    | x < a      = 0
    | x >= b - 1 = 1
    | otherwise  = toDbl (x - a + 1) / toDbl (b - a)

  -- Smallest k >= 0 with cdf k >= x, accumulating pmf k = pmf (k-1) * y / k
  -- from pmf 0 = e^{-y}. Stops once past the mean with the pmf underflowed
  -- to 0, so x >= 1 (or NaN) cannot loop forever.
  cdf' (Poisson y) x = go 0 p0 p0
    where
      p0 = exp (negate y)
      go :: Int -> Double -> Double -> Int
      go k pk ck
        | ck >= x                = k
        | toDbl k > y && pk <= 0 = k
        | otherwise              =
            let k'  = k + 1
                pk' = pk * y / toDbl k'
            in go k' pk' (ck + pk')
  cdf' (Geometric p) x =
    ceiling $ (log (1-x)) / (log (1-p))
  -- cdf 0 = 1 - p, so anything above that maps to outcome 1.
  cdf' (Bernoulli p) x
    | x > 1 - p = 1
    | otherwise = 0
  cdf' (Binomial n _ cache) x = invertCache n cache x
  cdf' (ZipF n _ cache) x = invertCache n cache x
  cdf' (Uniform a b) x = truncate $ toDbl (b-a) * x + toDbl a

  pmf (Poisson y) x
    | x < 0     = 0
    | otherwise = (y^x) / ( exp y * (fromInteger $ fac x) )
  -- P(X = x) = p (1-p)^(x-1) for x >= 1, matching cdf = 1 - (1-p)^x.
  pmf (Geometric p) x
    | x < 1     = 0
    | otherwise = p * (1-p)^(x-1)
  pmf (Bernoulli p) x
    | x == 0    = 1 - p
    | x == 1    = p
    | otherwise = 0
  pmf (Binomial _ _ cache) x = maybe 0 (\(_, r, _) -> r) (lookupCache x cache)
  pmf (ZipF _ _ cache) k = maybe 0 (\(_, r, _) -> r) (lookupCache k cache)
  pmf (Uniform a b) x
    | x < a || x >= b = 0
    | otherwise       = 1 / toDbl (b - a)

-- | The cache row for outcome @k@, if the table has one.
lookupCache :: Int -> DiscreteCache -> Maybe (Int, Double, Double)
lookupCache k = listToMaybe . filter (\(w, _, _) -> w == k)

-- | Smallest cached outcome whose cdf exceeds @u@ (the cdf is non-decreasing,
-- so this inverts it regardless of list order). Falls back to @top@ when
-- rounding leaves every cached cdf at or below @u@, instead of failing.
invertCache :: Int -> DiscreteCache -> Double -> Int
invertCache top cache u =
  case [ k | (k, _, c) <- cache, c > u ] of
    [] -> top
    ks -> minimum ks

-- | True when a cache row's cdf lies strictly below @u@; 'rand' counts such
-- rows to turn a uniform draw into an outcome.
cdfBelow :: Double -> (Int, Double, Double) -> Bool
cdfBelow u (_, _, c) = c < u


-- | Split a list into its head and tail. Intended only for non-empty (e.g.
-- infinite) lists; an empty list raises an explicit, labelled error instead
-- of an anonymous pattern-match failure.
myUncons :: [a] -> (a, [a])
myUncons (x:xs) = (x, xs)
myUncons []     = error "Stochastic.Distributions.Discrete.myUncons: empty list"

-- | Convert an integral count to 'Double' for this module's floating-point
-- formulas. The explicit signature replaces the type previously fixed
-- implicitly by the monomorphism restriction.
toDbl :: Integral a => a -> Double
toDbl = fromIntegral