---------------------------------------------------------------------------
-- | Module      : Math.Statistics.Fusion
-- Copyright   : (c) 2008 Don Stewart
-- License     : BSD3
--
-- Maintainer  : dons@galois.com
-- Stability   : experimental
-- Portability : portable
--
-- Description :
--
-- A collection of commonly used statistical functions designed to
-- fuse under stream fusion, with attention paid to the generated assembly.
--
-- These are high performance replacements for various list functions, 
-- implemented in pure Haskell using stream fusion for sequences.
--
-- To illustrate the performance gap, consider the task of calculating
-- the numerically stable mean of a sequence of 1e9 double values.
--
-- Using the standard list implementation provided by the hstats
-- package, 
--
-- >    $ time ./mean 
-- >    3.141592653589793
-- >    ./mean  26.80s user 0.08s system 99% cpu 26.965 total
--
-- And this package,
--
-- >    $ time ./mean                      
-- >    3.141592653589793
-- >    ./mean  6.25s user 0.00s system 99% cpu   6.261 total
--
--------------------------------------------------------------------------

module Math.Statistics.Fusion (
         mean
       , harmonic
       , geometric
       , var
       , stddev
    ) where

import Data.Array.Vector

-- | A numerically stable (running) arithmetic mean.
--
-- Instead of summing everything and dividing once, the accumulator carries
-- the mean of the elements seen so far together with the 1-based index of
-- the element about to be absorbed. Each step nudges the running mean
-- towards the new value by @(x - m) / i@. An empty array yields @0@.
mean :: UArr Double -> Double
mean = fstT . foldlU step (T 0 1)
  where
    step (T m i) x = T (m + (x - m) / fromIntegral i) (i + 1)
{-# INLINE mean #-}
-- The INLINE pragma above is required (per the original author); the
-- module is designed so these folds fuse at their call sites.

-- | The harmonic mean: the number of elements divided by the sum of their
-- reciprocals. An empty array gives @0 / 0@, i.e. NaN.
harmonic :: UArr Double -> Double
harmonic xs =
    case foldlU step (T 0 0) xs of
        T recipSum count -> fromIntegral count / recipSum
  where
    step (T s c) x = T (s + 1 / x) (c + 1)
{-# INLINE harmonic #-}

-- | The geometric mean of a non-negative array.
--
-- Computed as @exp (mean (log x))@ rather than as the n-th root of the
-- running product. Multiplying the elements together overflows to
-- @Infinity@ (or underflows to @0@) long before the root is taken: 1024
-- copies of @2.0@ already multiply to @Infinity@. Summing logarithms keeps
-- the accumulator in range for long inputs.
--
-- * Any zero element yields @0@ (its logarithm is @-Infinity@).
-- * Negative elements yield NaN; the input must be non-negative.
-- * An empty array yields @1@ (the empty product), as before.
geometric :: UArr Double -> Double
geometric xs
    | count == 0 = 1
    | otherwise  = exp (logSum / fromIntegral count)
  where
    T logSum count = foldlU step (T 0 0) xs
    step (T s c) x = T (s + log x) (c + 1)
{-# INLINE geometric #-}

-- | A numerically stable sample variance (Welford's online algorithm).
--
-- The sum of squared deviations is divided by @n - 1@ (the unbiased sample
-- variance). A single element therefore yields NaN (@0 / 0@), and an empty
-- array yields @-0.0@.
var :: UArr Double -> Double
var = quotT1 . foldlU step (T1 1 0 0)
  where
    -- i   : 1-based index of the element being absorbed
    -- mu  : running mean of the elements seen so far
    -- m2  : running sum of squared deviations from that mean
    step (T1 i mu m2) x =
        let d   = x - mu
            mu' = mu + d / fromIntegral i
        in  T1 (i + 1) mu' (m2 + d * (x - mu'))
{-# INLINE var #-}

-- | The sample standard deviation: the square root of 'var'.
stddev :: UArr Double -> Double
stddev xs = sqrt (var xs)
{-# INLINE stddev #-}

------------------------------------------------------------------------
-- Helper code. Monomorphic unpacked accumulators.

-- The accumulators are deliberately monomorphic: with a polymorphic field
-- GHC cannot return the fold's result unboxed (see the Core comparison in
-- the block comment below).

-- | Strict, unpacked accumulator pairing a 'Double' with an 'Int' counter.
data T = T
    {-# UNPACK #-} !Double
    {-# UNPACK #-} !Int

-- | Strict, unpacked accumulator holding an 'Int' and two 'Double's.
-- 'var' uses it for (1-based index, running mean, sum of squared deviations).
data T1 = T1
    {-# UNPACK #-} !Int
    {-# UNPACK #-} !Double
    {-# UNPACK #-} !Double

-- | Project the 'Double' component out of a 'T' accumulator.
fstT :: T -> Double
fstT t = case t of
    T x _ -> x

-- | Finish a variance fold: divide the accumulated sum of squared
-- deviations by the element count minus one.
--
-- The 'Int' field is a 1-based "next index" ('var' starts it at 1), so after
-- k elements it holds k + 1. Subtracting 2 therefore gives the k - 1
-- denominator of the sample variance. (The name is poor but kept so
-- existing call sites keep working.)
quotT1 :: T1 -> Double
quotT1 (T1 next _ sumSq) = sumSq / fromIntegral (next - 2)

{-

Why the accumulators are monomorphic: compare the GHC Core generated for
the fold worker.

With the polymorphic accumulator data T a = T !a !Int:

$wfold :: Double#
               -> Int#
               -> Int#
               -> (# Double, Int# #)

and with the monomorphic accumulator data T = T !Double !Int (fields unpacked):

$wfold :: Double#
               -> Int#
               -> Int#
               -> (# Double#, Int# #)

The polymorphic version returns a boxed Double rather than an unboxed Double#,
which results in boxed returns and heap checks.

-}