{-# LANGUAGE CPP, Rank2Types #-}
module RBR.Stats (UniVar(..),uniVar,uv,uv_add,uv_del
              ,Quantiles(..),quantiles) where -- ,histogram,display) where

import Data.List (foldl', group, sort, sortBy)
-- import FiniteMap

-- | Statistical summaries that know how many samples they were computed from.
--   The class is internal to this module: neither it nor 'samples' appears
--   in the export list.
class Statistic s where
    -- | Number of samples behind the summary.
    samples :: s -> Int

-- | Univariate summary statistics.  The field descriptions below state what
--   'uniVar' stores in each field (n is the sample count).
data UniVar = UV { uv_samples  :: Int  -- ^ sample count n
                 , mean     :: Double  -- ^ sum of the samples divided by n
                 , stdev    :: Double  -- ^ square root of 'variance'
                 , variance :: Double  -- ^ variance with an (n-1) denominator
                 , skewness :: Double  -- ^ sum of cubed deviations from the mean / (stdev^3 * n)
                 , kurtosis :: Double  -- ^ sum of 4th-power deviations / (stdev^4 * n), minus 3
                 , sumSq    :: Double  -- ^ raw sum of the squared samples (not mean-centred)
                 , coeffVar :: Double  -- ^ 'stdev' divided by 'mean'
                 , stdErrMn :: Double  -- ^ 'stdev' divided by sqrt n
                 }

-- | The sample count of a 'UniVar' is its 'uv_samples' field.
instance Statistic UniVar where
    samples u = uv_samples u

-- | Renders one @label: value@ line per statistic, laid out by 'adjust'.
instance Show UniVar where
    show u = adjust rows
      where
        -- One table row: the label and the shown value.
        row :: Show a => String -> a -> [String]
        row label value = [label, show value]

        rows =
            [ row "Samples"        (samples u)
            , row "Mean"           (mean u)
            , row "Standard dev"   (stdev u)
            , row "Variance"       (variance u)
            , row "Skewness"       (skewness u)
            , row "Kurtosis"       (kurtosis u)
            , row "Sum of squares" (sumSq u)
            , row "Coeff. of var"  (coeffVar u)
            , row "Std err mean"   (stdErrMn u)
            ]

-- | Lay out a two-column table as text, one @label: value@ line per row.
--   Each label is followed by a colon and padded with spaces so that values
--   start in the same column for labels of up to 15 characters; longer
--   labels get no padding.  Every row must have exactly two cells, anything
--   else is reported with 'error'.
adjust :: [[String]] -> String
adjust = concatMap line
  where
    line [label, value] =
        label ++ ":" ++ replicate (15 - length label) ' ' ++ value ++ "\n"
    line _ = error "Pattern matching error in 'adjust'"

-- | Running totals from which 'uniVar' derives its statistics: the sample
--   count followed by the sums of x, x^2, x^3 and x^4 (the components are
--   updated in that order by 'uv_upd').
type UVTMP = (Int,Double,Double,Double,Double)

-- | Derive summary statistics from running power sums (more or less the
--   univariate function from SAS).
--
--   The argument is @(n, sum x, sum x^2, sum x^3, sum x^4)@.  The variance
--   uses an @n-1@ denominator; skewness and kurtosis divide the expanded
--   central-moment sums by @stdev^3 * n@ and @stdev^4 * n@, and 3 is
--   subtracted from the kurtosis.  There is no guard for tiny samples: the
--   divisions by @n@ and @n-1@ are carried out as written.
uniVar :: (Int, Double, Double, Double, Double) -> UniVar
uniVar (count, s1, s2, s3, s4) =
    UV { uv_samples = count
       , mean       = avg
       , stdev      = sd
       , variance   = var
       , skewness   = (s3 - 3*avg*s2 + 2*avg3*n) / (sd*sd*sd*n)
       , kurtosis   = (s4 - 4*avg*s3 + 6*avg2*s2 - 4*avg3*s1 + n*avg*avg3) / (sd*sd*sd*sd*n) - 3
       , sumSq      = s2
       , coeffVar   = sd / avg
       , stdErrMn   = sd / sqrt n
       }
  where
    n    = fromIntegral count
    avg  = s1 / n
    avg2 = avg * avg
    avg3 = avg * avg2
    var  = (s2 - avg2*n) / (n - 1)
    sd   = sqrt var

-- | Fold one sample into the running totals.
uv_add :: Real a => a -> UVTMP -> UVTMP
uv_add sample totals = uv_upd (+) sample totals

-- | Take one sample back out of the running totals.
uv_del :: Real a => a -> UVTMP -> UVTMP
uv_del sample totals = uv_upd (-) sample totals

-- | Shared worker for 'uv_add' and 'uv_del': combines the count with 1 and
--   each power sum with the matching power of the sample, using @op@.
--
--   @op@ is applied to the 'Int' count as well as to the 'Double' sums, so
--   it has to be passed polymorphically; that rank-2 argument is what the
--   @Rank2Types@ pragma at the top of the module is for.
--
--   The incoming totals are forced before the new tuple is returned, so the
--   components never accumulate more than one pending update.
uv_upd :: (Real a) => (forall b . Real b => b -> b -> b)
       -> a -> UVTMP -> UVTMP
uv_upd op sample (cnt, s1, s2, s3, s4) =
    cnt `seq` s1 `seq` s2 `seq` s3 `seq` s4 `seq`
        (op cnt 1, op s1 d, op s2 d2, op s3 d3, op s4 d4)
  where
    -- The sample converted to the type of the sums, and its powers.
    d  = fromRational (toRational sample)
    d2 = d * d
    d3 = d2 * d
    d4 = d3 * d


-- | Accumulate the running totals ('UVTMP') of a whole list of samples,
--   starting from all-zero totals.
--
--   The samples are added front to back with a strict left fold.  'uv_add'
--   pattern-matches on the totals it is given, so the former right fold had
--   to set up one pending call per element before the first addition could
--   happen; 'foldl'' forces the totals at every step instead, so the list is
--   consumed as it goes.  Because the samples are now summed first-to-last
--   rather than last-to-first, the floating-point sums may differ from the
--   old result in their last bits.
uv :: Real a => [a] -> UVTMP
uv ds = foldl' (flip uv_add) (0,0,0,0,0) ds

-- | Order statistics of a sample.  The field descriptions below state what
--   'quantiles' stores in each field.
data Quantiles = Qs { wsamples   :: Int   -- ^ number of samples
                 , smallest  :: Double    -- ^ first element of the sorted sample
                 , quartile1 :: Double    -- ^ first quartile
                 , median    :: Double    -- ^ median (second quartile)
                 , mode      :: [Double]  -- ^ every value that occurs with the highest frequency
                 , quartile3 :: Double    -- ^ third quartile
                 , greatest  :: Double    -- ^ last element of the sorted sample
                 }

-- | The sample count of a 'Quantiles' is its 'wsamples' field.
instance Statistic Quantiles where
    samples q = wsamples q

-- | Renders one @label: value@ line per order statistic, laid out by 'adjust'.
instance Show Quantiles where
    show w = adjust (zipWith cells labels values)
      where
        cells label value = [label, value]

        labels =
            ["Samples", "Smallest", "Q1", "Median", "Modes", "Q3", "Greatest"]

        -- Shown values, in the same order as 'labels'.
        values =
            [ show (samples w)
            , show (smallest w)
            , show (quartile1 w)
            , show (median w)
            , show (mode w)
            , show (quartile3 w)
            , show (greatest w)
            ]

-- | Order statistics of a sample: size, extremes, quartiles, median and modes.
--
--   The input does not have to be sorted.  A quantile whose rank falls
--   exactly between two elements (sample size divisible by 4 for the
--   quartiles, by 2 for the median) is the average of those two neighbours;
--   otherwise the element at the truncated rank is taken.
--
--   'mode' lists every value that attains the highest frequency, smallest
--   value first.
--
--   An empty sample has no order statistics: 'wsamples' is 0 and forcing any
--   other field raises @error "quantiles: empty sample"@ (these cases used to
--   fail with an anonymous @head@ / @!!@ error instead).
quantiles :: [Double] -> Quantiles
quantiles ds = Qs n lo (rankAt 1 4) (rankAt 1 2) modes (rankAt 3 4) hi
  where
    n      = length ds
    sorted = sort ds

    emptySample :: a
    emptySample = error "quantiles: empty sample"

    -- Extremes of the sorted sample.
    (lo, hi) = case sorted of
        []      -> (emptySample, emptySample)
        (x : _) -> (x, last sorted)

    -- Value at fraction num/den of the way through the sorted sample.
    -- For a non-empty sample both indices are in range: num < den gives
    -- q < n, and r == 0 with n >= 1 gives q >= 1.
    rankAt :: Int -> Int -> Double
    rankAt num den
        | n == 0    = emptySample
        | r == 0    = (sorted !! (q - 1) + sorted !! q) / 2.0
        | otherwise = sorted !! q
      where
        (q, r) = (num * n) `quotRem` den

    -- (frequency, value) pairs, most frequent first.  The sort is stable, so
    -- values with equal frequency stay in ascending order.
    ranked = sortOn (negate . fst) [ (length g, v) | g@(v : _) <- group sorted ]

    -- All values whose frequency equals that of the most frequent one.
    modes = case ranked of
        []             -> emptySample
        ((top, _) : _) -> map snd (takeWhile ((== top) . fst) ranked)

-- | Sort a list by comparing the keys that @key@ extracts from its elements.
--   Elements with equal keys keep their original relative order, as with
--   'sortBy'.
sortOn :: Ord b => (a->b) -> [a] -> [a]
sortOn key = sortBy byKey
  where
    byKey l r = key l `compare` key r

{-
type Histogram = FiniteMap Double Int

-- | histogram builds a histogram given the list of midpoints
histogram :: [Double] -> [Double] -> Histogram
histogram ms xs = foldl (insert ms) emptyFM' xs
    where
    emptyFM' = foldl (\s v -> addToFM s v 0) emptyFM ms
    insert :: [Double] -> Histogram -> Double -> Histogram
    insert (m1:m2:ms) s x = if abs (m1-x) <= abs (m2-x)
                            then addToFM_C (+) s m1 1 else insert (m2:ms) s x
    insert [m1] s x = addToFM_C (+) s m1 1
    insert [] _ _ = error "Must provide at least one midpoint"

-- todo: speed up with strict foldl'

display :: Histogram -> String
display h = unlines $ map disp1 $ fmToList h
    where disp1 (v,n) = show v ++ (take (7-(length $ show v)) (repeat ' ')) ++
                        ": " ++ (take n $ repeat '*')

-}