```{-# LANGUAGE DeriveDataTypeable #-}
-- |
-- Module    : Statistics.Distribution.Binomial
-- Copyright : (c) 2009 Bryan O'Sullivan
-- License   : BSD3
--
-- Maintainer  : bos@serpentine.com
-- Stability   : experimental
-- Portability : portable
--
-- The binomial distribution.  This is the discrete probability
-- distribution of the number of successes in a sequence of /n/
-- independent yes\/no experiments, each of which yields success with
-- probability /p/.

module Statistics.Distribution.Binomial
(
BinomialDistribution
-- * Constructors
, binomial
-- * Accessors
, bdTrials
, bdProbability
) where

import Control.Exception (assert)
import qualified Data.Vector.Unboxed as U
import Data.Int (Int64)
import Data.Typeable (Typeable)
import Statistics.Constants (m_epsilon)
import qualified Statistics.Distribution as D
import Statistics.Distribution.Normal (standard)
import Statistics.Math (choose, logFactorial)

-- | The binomial distribution.
data BinomialDistribution = BD {
bdTrials      :: {-# UNPACK #-} !Int
-- ^ Number of trials.
, bdProbability :: {-# UNPACK #-} !Double
-- ^ Probability.
} deriving (Eq, Read, Show, Typeable)

instance D.Distribution BinomialDistribution where
density    = density
cumulative = cumulative
quantile   = quantile

instance D.Variance BinomialDistribution where
variance = variance

instance D.Mean BinomialDistribution where
mean = mean

density :: BinomialDistribution -> Double -> Double
density (BD n p) x
| not (isIntegral x) = integralError "density"
| n == 0             = 1
| x < 0 || x > n'    = 0
| n <= 50 || x < 2   = sign * p'' ** x' * (n `choose` fx) * q'' ** nx'
| otherwise          = sign * exp (x' * log p'' + nx' * log q'' + lf)
where sign = oddX * oddNX
(x',p',q') | x > n' / 2 = (n'-x, q, p)
| otherwise  = (x,    p, q)
oddX | p' < 0 && odd fx     = -1
| otherwise            = 1
oddNX | q' < 0 && odd nx    = -1
| otherwise           = 1
p'' = abs p'
q'' = abs q'
q   = 1 - p
nx  = n - fx
nx' = fromIntegral nx
fx  = floor x'
n'  = fromIntegral n
lf  = logFactorial n - logFactorial nx - logFactorial fx

cumulative :: BinomialDistribution -> Double -> Double
cumulative d x
| isIntegral x = U.sum . U.map (density d . fromIntegral) . U.enumFromTo (0::Int) . floor \$ x
| otherwise    = integralError "cumulative"

isIntegral :: Double -> Bool
isIntegral x = x == floorf x

floorf :: Double -> Double
floorf = fromIntegral . (floor :: Double -> Int64)

quantile :: BinomialDistribution -> Double -> Double
quantile dist@(BD n p) prob
| isNaN prob = prob
| p == 1     = n'
| n' < 1e5   = fst (search 1 y0 z0)
| otherwise  = let dy = floorf (n' / 1000)
in  narrow dy (search dy y0 z0)
where q  = 1 - p
n' = fromIntegral n
y0 = n' `min` floorf (µ + σ * (d + γ * (d * d - 1) / 6) + 0.5)
where µ  = n' * p
σ  = sqrt (n' * p * q)
d = D.quantile standard prob
γ  = (q - p) / σ
z0 = cumulative dist y0
search dy y1 z1 | z0 >= prob' = left y1 z1
| otherwise   = right y1
where
prob' = prob * (1 - 64 * m_epsilon)
left y oldZ | y == 0 || z < prob' = (y, oldZ)
| otherwise           = left (max 0 y') z
where z  = cumulative dist y'
y' = y - dy
right y | y' >= n' || z >= prob' = (y', z)
| otherwise              = right y'
where z  = cumulative dist y'
y' = y + dy
narrow dy (y,z) | dy <= 1 || dy' <= n'/1e15 = y
| otherwise                 = narrow dy' (search dy y z)
where dy' = floorf (dy / 100)

mean :: BinomialDistribution -> Double
mean (BD n p) = fromIntegral n * p
{-# INLINE mean #-}

variance :: BinomialDistribution -> Double
variance (BD n p) = fromIntegral n * p * (1 - p)
{-# INLINE variance #-}

binomial :: Int                 -- ^ Number of trials.
-> Double              -- ^ Probability.
-> BinomialDistribution
binomial n p =
assert (n > 0) .
assert (p > 0 && p < 1) \$
BD n p
{-# INLINE binomial #-}

integralError :: String -> a
integralError f = error ("Statistics.Distribution.Binomial." ++ f ++
": non-integer-valued input")
```