-- |
-- Module      :  Numeric.Stats
-- Copyright   :  (c) OleksandrZhabenko 2020
-- License     :  MIT
-- Stability   :  Experimental
-- Maintainer  :  olexandr543@yahoo.com
--
-- Very basic descriptive statistics. The functions are tail-recursive and strict in their accumulators.

{-# LANGUAGE BangPatterns, MagicHash #-}

module Numeric.Stats where

import GHC.Exts
import GHC.Prim

-- | Compensated (Neumaier) summation of a list of 'Float' values that works
-- on unboxed 'Float#' accumulators.
-- Inspired by: https://www.mat.univie.ac.at/~neum/scan/01.pdf
--
-- Arguments: the list to be summed, the running compensation (the rounding
-- error collected so far) and the running sum. Pass @0.0#@ for both
-- accumulators to sum a whole list. When the list is exhausted the result is
-- the running sum plus the compensation.
sumNeumaierF :: [Float] -> Float# -> Float# -> Float
sumNeumaierF ((F# !x):xs) !c !sum1 = sumNeumaierF xs c1 t
  where
    !t = plusFloat# sum1 x
    -- The rounding error of @sum1 + x@ is computed on its own and only then
    -- added to @c@. Floating-point addition is not associative: adding @c@ to
    -- the large intermediate difference first would round @c@ away.
    !lost = if isTrue# (geFloat# (fabsFloat# sum1) (fabsFloat# x))
              then plusFloat# (minusFloat# sum1 t) x
              else plusFloat# (minusFloat# x t) sum1
    !c1 = plusFloat# c lost
sumNeumaierF _ c sum1 = F# (plusFloat# sum1 c)

-- | Compensated (Neumaier) summation of a list.
-- Inspired by: https://www.mat.univie.ac.at/~neum/scan/01.pdf
--
-- Arguments: the list to be summed, the running compensation (the rounding
-- error collected so far) and the running sum. Pass @0@ for both accumulators
-- to sum a whole list. When the list is exhausted the result is the running
-- sum plus the compensation.
--
-- The rounding error of every addition is computed separately and only then
-- added to the compensation. The grouping matters: @c + (sum1 - t) + x@ parses
-- as @(c + (sum1 - t)) + x@, which rounds the compensation away and makes the
-- cancellation errors grow just like for the usual 'sum'.
sumNeumaier :: RealFrac a => [a] -> a -> a -> a
sumNeumaier (!x:xs) !c !sum1 = sumNeumaier xs c1 t
  where
    !t = sum1 + x
    -- The low-order part lost by @sum1 + x@; the larger operand goes first.
    !lost = if abs sum1 >= abs x
              then (sum1 - t) + x
              else (x - t) + sum1
    !c1 = c + lost
sumNeumaier _ c sum1 = sum1 + c

-- | Sums a list of 'Float' values with 'sumNeumaierF', starting with both the
-- compensation and the running sum equal to zero.
sum_NF :: [Float] -> Float
sum_NF xs = sumNeumaierF xs zero zero
  where zero = 0.0#
{-# INLINE sum_NF #-}

-- | Tail-recursive worker for the arithmetic mean.
--
-- Arguments: the remaining list, the running sum, the running count of the
-- elements and the mean computed so far. Every element updates the sum and the
-- count and replaces the mean by their quotient; for an exhausted list the
-- last argument is returned as it is.
mean2 :: RealFrac a => [a] -> a -> a -> a -> a
mean2 [] _ _ current = current
mean2 (y:ys) !total !count _ = mean2 ys total' count' (total' / count')
  where
    total' = total + y
    count' = count + 1

-- | Similar to 'mean2', but specialised to 'Float' and working on the GHC
-- unlifted 'Float#' type from the @ghc-prim@ package.
--
-- Arguments: the remaining list, the running sum, the running count of the
-- elements and the mean computed so far; for an exhausted list the last
-- argument is boxed and returned.
mean2F :: [Float] -> Float# -> Float# -> Float# -> Float
mean2F [] _ _ current = F# current
mean2F (F# y : ys) total count _ = mean2F ys total' count' (divideFloat# total' count')
  where
    !total' = plusFloat# total y
    !count' = plusFloat# count 1.0#

-- | One-pass, tail-recursive computation of the pair (mean, dispersion), where
-- the dispersion is the mean of the squares minus the square of the mean.
-- Is vulnerable to the floating-point cancellation errors.
--
-- Arguments: the remaining list, the running sum, the running sum of squares,
-- the running count of the elements, the mean and the dispersion computed so
-- far. For an exhausted list the last two arguments are returned as a pair.
meanWithDispersion :: (RealFrac a, Floating a) => [a] -> a -> a -> a -> a -> a -> (a,a)
meanWithDispersion [] _ _ _ !m !d = (m, d)
meanWithDispersion (!y:ys) !s1 !s2 !l1 !_m1 !_d =
  meanWithDispersion ys total totalSq count avg (totalSq / count - avg ** 2)
  where
    total = s1 + y
    totalSq = s2 + y ** 2
    count = l1 + 1
    avg = total / count

-- | One-pass, tail-recursive computation of the pair (mean, dispersion) for a
-- list of 'Float' values, working on unboxed 'Float#' accumulators. The
-- dispersion is the mean of the squares minus the square of the mean, the same
-- formula as in 'meanWithDispersion'. Is vulnerable to the floating-point
-- cancellation errors.
--
-- Arguments: the remaining list, the running sum, the running sum of squares,
-- the running count of the elements, the mean and the dispersion computed so
-- far. Pass @0.0#@ for all the accumulators to process a whole list. For an
-- exhausted list the last two arguments are boxed and returned as a pair.
meanWithDispersionF :: [Float] -> Float# -> Float# -> Float# -> Float# -> Float# -> (Float,Float)
meanWithDispersionF ((F# !x):xs) !s1 !s2 !l1 _ _ =
  meanWithDispersionF xs total totalSq count avg
    -- Divide the sum of squares by the count first and only then subtract the
    -- squared mean; the squared mean must not be divided by the count.
    (minusFloat# (divideFloat# totalSq count) (timesFloat# avg avg))
  where
    !total = plusFloat# s1 x
    -- Squares are taken by multiplication instead of 'powerFloat#'.
    !totalSq = plusFloat# s2 (timesFloat# x x)
    !count = plusFloat# l1 1.0#
    !avg = divideFloat# total count
meanWithDispersionF _ _ _ _ !m !d = (F# m, F# d)

-- | The arithmetic mean of a list. Calls the tail-recursive 'mean2' and passes
-- zero for the running sum, the count of the elements and the initial mean.
mean :: RealFrac a => [a] -> a
mean xs = mean2 xs zero zero zero
  where zero = 0.0
{-# INLINE mean #-}

-- | The arithmetic mean of a list of 'Float' values. Uses 'mean2F' inside and
-- passes zero for the running sum, the count of the elements and the initial mean.
meanF :: [Float] -> Float
meanF xs = mean2F xs zero zero zero
  where zero = 0.0#
{-# INLINE meanF #-}

-- | Returns the pair produced by 'meanWithDispersion' for the list when all
-- five of its accumulator arguments start from zero.
meanWithDisp :: (RealFrac a, Floating a) => [a] -> (a,a)
meanWithDisp xs = meanWithDispersion xs zero zero zero zero zero
  where zero = 0.0
{-# INLINE meanWithDisp #-}

-- | Returns the pair produced by 'meanWithDispersionF' for the list of 'Float'
-- values when all five of its accumulator arguments start from zero.
meanWithDispF :: [Float] -> (Float, Float)
meanWithDispF xs = meanWithDispersionF xs zero zero zero zero zero
  where zero = 0.0#
{-# INLINE meanWithDispF #-}