{-# LANGUAGE TypeFamilies #-}
module Stochastic.Distributions(
  UniformBase(rDouble)
  ,stdBase
  ,seededBase
  ,Empirical(..)
  ,mkEmpirical
  ) where

import Data.Word
import System.IO
import System.IO.Error (eofErrorType, mkIOError)

import Control.Monad.State.Lazy
import Data.Binary.Get
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy as LBS
import System.Random

import Stochastic.Tools
import Stochastic.Uniform

-- | A generator value: 'rDouble' pairs a 'Double' with another
-- 'UniformBase' (presumably the state to draw from next).
--
-- NOTE(review): the range of the 'Double' is not fixed by this type;
-- presumably it is a uniform draw — confirm against the producers.
data UniformBase = UniformBase { rDouble :: (Double, UniformBase) }


-- | Read exactly eight raw bytes from a handle and decode them as a
-- 'Word64' in host byte order.
--
-- The bytes are fetched with a single 'B.hGet', which bypasses the
-- handle's 'TextEncoding' and newline translation. The result is
-- therefore the raw byte value even if the handle was not opened in
-- binary mode.
--
-- Throws an EOF 'IOError' (satisfying 'System.IO.Error.isEOFError') when
-- fewer than eight bytes are available.
readWord64 :: Handle -> IO Word64
readWord64 h = do
  bytes <- B.hGet h 8
  if B.length bytes < 8
    -- Short read: report it the same way 'hGetChar' reports EOF.
    then ioError $ mkIOError eofErrorType "readWord64" (Just h) Nothing
    else return $ runGet getWord64host (LBS.fromStrict bytes)

-- | The default generator for a given integer seed; this simply delegates
-- to 'xorshift128plus'.
stdBase :: Integer -> UniformRandom
stdBase = xorshift128plus

-- | Create a 'stdBase' generator seeded from the operating system.
--
-- Eight bytes are read from @/dev/urandom@ with 'readWord64' and widened to
-- an 'Integer' seed. @/dev/urandom@ does not block waiting for entropy and
-- is the recommended source for seeding (see random(4)). The path must
-- not carry a trailing slash: opening a device node as @"…/"@ fails with
-- ENOTDIR.
seededBase :: IO UniformRandom
seededBase =
  stdBase . toInteger <$> withBinaryFile "/dev/urandom" ReadMode readWord64

-- | An empirical distribution; 'mkEmpirical' builds one from samples.
data Empirical = Empirical
  { degreesOfFreedom :: Int
    -- ^ As built by 'mkEmpirical': the number of histogram intervals.
  , cdf  :: Double -> Double
    -- ^ As built by 'mkEmpirical': the histogram CDF ('empiricalCDF'),
    -- 0 below the histogram's range and 1 at or above it.
  , cdf' :: Double -> Double
    -- ^ As built by 'mkEmpirical': @empiricalCDF'@ over the same
    -- histogram; presumably the inverse CDF — TODO confirm.
  }

-- | Map a cumulative value @x@ back onto the histogram's axis.
--
-- Selects the first interval @i@ whose 'cum_frequence' exceeds @x@ and
-- returns
--
-- > upper_bound i - (cum_frequence i - x) / slope i
--
-- NOTE(review): as written this is not the algebraic inverse of the
-- interpolation in 'empiricalCDF', which measures from 'lower_bound' and
-- also scales by 'rel_frequence'. Confirm which is intended against the
-- histogram definition in "Stochastic.Tools".
--
-- Calls 'error', naming @x@, when no interval's 'cum_frequence' exceeds
-- @x@ (including the empty-histogram case).
empiricalCDF' hist x =
  case filter ((> x) . cum_frequence) hist of
    (bin:_) -> upper_bound bin - (cum_frequence bin - x) / slope bin
    []      ->
      error $ "empiricalCDF': no interval has cum_frequence > " ++ show x
                 
-- | Evaluate the histogram-based CDF at @x@.
--
-- Returns 0 below the first interval's 'lower_bound' and 1 at or above the
-- last interval's 'upper_bound'. Otherwise it takes the first interval @y@
-- with @lower_bound y <= x <= upper_bound y@ and returns
--
-- > cum_frequence y + (x - lower_bound y) * slope y * rel_frequence y
--
-- The cut-offs treat the first and last list elements as the extreme
-- intervals, so the list is presumably sorted by bound. TODO: confirm that
-- 'fIHistogram' guarantees this.
--
-- Calls 'error' for an empty histogram. It also calls 'error' when @x@
-- lies in no interval (a gap between intervals, or NaN); that message
-- lists every interval's bounds and frequency.
empiricalCDF [] _ = error "empiricalCDF: empty histogram"
empiricalCDF hist@(firstBin:_) x
  | x <  lower_bound firstBin    = 0
  | x >= upper_bound (last hist) = 1
  | otherwise =
      case filter covers hist of
        (y:_) -> interpolate y
        []    ->
          error $ "x " ++ show x ++ " not in histogram range\n" ++ binTable
  where
    covers y = lower_bound y <= x && upper_bound y >= x
    -- Linear interpolation from the interval's cumulative frequency.
    interpolate y =
      cum_frequence y + (x - lower_bound y) * slope y * rel_frequence y
    -- One "(lower, upper, frequency)" line per interval, for diagnostics.
    binTable =
      unlines [show (lower_bound b, upper_bound b, frequency b) | b <- hist]
    
-- | Build an 'Empirical' distribution from raw samples.
--
-- The samples are binned once with 'fIHistogram'. The number of intervals
-- becomes 'degreesOfFreedom', and the same histogram backs both 'cdf'
-- (via 'empiricalCDF') and @cdf'@ (via @empiricalCDF'@).
--
-- NOTE(review): behaviour for an empty sample list depends on what
-- 'fIHistogram' returns for @[]@ — confirm.
mkEmpirical :: [Double] -> Empirical
mkEmpirical samples = Empirical
  { degreesOfFreedom = length hist
  , cdf              = empiricalCDF hist
  , cdf'             = empiricalCDF' hist
  }
  where
    hist = fIHistogram samples