-- | Kruskal-Wallis rank-based test for comparing several samples.
--
-- Exports the test itself, the two functions it is built from, and
-- re-exports "Statistics.Test.Types" for the result types.
module Statistics.Test.KruskalWallis
  ( -- * Kruskal-Wallis test
    kruskalWallisTest
    -- ** Building blocks
  , kruskalWallisRank
  , kruskalWallis
  , module Statistics.Test.Types
  ) where
import Data.Ord (comparing)
import Data.Foldable (foldMap)
import qualified Data.Vector.Unboxed as U
import Statistics.Function (sort, sortBy, square)
import Statistics.Distribution (complCumulative)
import Statistics.Distribution.ChiSquared (chiSquared)
import Statistics.Types
import Statistics.Test.Types
import Statistics.Test.Internal (rank)
import Statistics.Sample
import qualified Statistics.Sample.Internal as Sample(sum)
-- | Replace every observation by its rank within the pooled data.
--
-- All samples are merged into a single vector, the merged observations are
-- ranked by 'rank' (with @(==)@ as the equality it uses; presumably this is
-- how tied observations are detected), and the ranks are then handed back
-- grouped by the sample they came from. Each resulting vector is sorted in
-- ascending order.
--
-- Groups appear in the same order as the input samples. An empty input
-- sample contributes no observations and therefore produces no output
-- vector, so the result may be shorter than the input list.
kruskalWallisRank :: (U.Unbox a, Ord a) => [U.Vector a] -> [U.Vector Double]
kruskalWallisRank samples = splitByGroup rankedByGroup
  where
    -- Every observation paired with the 1-based index of its sample.
    labelled = foldMap (\(gid, xs) -> U.map (\x -> (gid, x)) xs)
                       (zip [(1 :: Int) ..] samples)

    -- Pooled observations ordered by value, with group ids kept in step.
    (groupIds, pooled) = U.unzip (sortBy (comparing snd) labelled)

    -- Ranks of the pooled observations, re-attached to their group ids
    -- and brought back together group by group.
    rankedByGroup = sortBy (comparing fst) (U.zip groupIds (rank (==) pooled))

    -- Peel off one run of equal group ids at a time.
    splitByGroup tagged
      | U.null tagged = []
      | otherwise     = sort (U.map snd current) : splitByGroup remaining
      where
        currentId            = fst (U.head tagged)
        (current, remaining) = U.span (\(gid, _) -> gid == currentId) tagged
-- | Compute the Kruskal-Wallis statistic for the given samples.
--
-- With the observations replaced by their ranks (see 'kruskalWallisRank'),
-- @N@ the total number of ranked observations and @(N + 1) / 2@ the average
-- rank, the result is
--
-- > (N - 1) * sum_i (n_i * (meanRank_i - avgRank)^2)
-- >         / sum_i sum_j (rank_ij - avgRank)^2
--
-- where @n_i@ is the size of group @i@. No guard is applied to the division,
-- so a zero denominator is passed straight through to floating-point
-- arithmetic.
kruskalWallis :: (U.Unbox a, Ord a) => [U.Vector a] -> Double
kruskalWallis samples = (total - 1) * betweenGroups / overall
  where
    ranked = kruskalWallisRank samples

    -- Total number of ranked observations across all groups.
    total = fromIntegral (sumWith ranked U.length)

    -- Average of all ranks.
    midRank = (total + 1) / 2

    -- Size-weighted squared deviation of each group's mean rank.
    betweenGroups = sumWith ranked groupTerm
    groupTerm rs  = fromIntegral (U.length rs) * square (mean rs - midRank)

    -- Squared deviation of every individual rank from the average rank.
    overall = sumWith ranked (Sample.sum . U.map (square . subtract midRank))
-- | Perform the Kruskal-Wallis test on the given samples.
--
-- The statistic computed by 'kruskalWallis' is compared against a
-- chi-squared distribution with @(number of samples - 1)@ degrees of
-- freedom, and the upper-tail probability is reported as the significance.
--
-- Returns 'Nothing' when the test cannot be carried out:
--
-- * fewer than two samples are supplied. This covers the empty list, and
--   also a single sample, which would require a chi-squared distribution
--   with zero degrees of freedom; 'chiSquared' does not accept that, so the
--   returned p-value would fail when evaluated;
--
-- * any sample has four or fewer elements, as the chi-squared
--   approximation is only used for larger samples;
--
-- * the statistic is NaN, which happens when the denominator in
--   'kruskalWallis' is zero. No valid p-value can be built from it.
kruskalWallisTest :: (Ord a, U.Unbox a) => [U.Vector a] -> Maybe (Test ())
kruskalWallisTest samples
  | nGroups < 2   = Nothing
  | any (<= 4) ns = Nothing
  | isNaN k       = Nothing
  -- We use the chi-squared approximation here.
  | otherwise     = Just Test { testSignificance = mkPValue $ complCumulative d k
                              , testStatistics   = k
                              , testDistribution = ()
                              }
  where
    ns      = map U.length samples
    nGroups = length ns
    k       = kruskalWallis samples
    -- Only evaluated once nGroups >= 2, so the degrees of freedom are positive.
    d       = chiSquared (nGroups - 1)
-- | Apply a function to every sample in the list and add up the results
-- with 'Prelude.sum'.
sumWith :: Num a => [Sample] -> (Sample -> a) -> a
sumWith samples f = Prelude.sum [f s | s <- samples]
{-# INLINE sumWith #-}