{-# LANGUAGE BangPatterns #-}
module Counts ( Count
              , counts
              , mi
              , entropy
              , vi
              , ari
              , rankNormalize
              )
where
import Prelude hiding (sum)
import Data.List (foldl', sortBy)
import Data.Map ((!))
import qualified Data.Map as Map
import qualified Data.Map.Strict as StrictMap
import Data.Ord

-- | Number of occurrences of a value or value pair. Stored as a 'Double'
-- (rather than an integral type), which lets the measures in this module
-- divide counts and take their logarithms directly.
type Count = Double

-- | Tally how often each pair, each first component and each second
-- component occurs in a list of observations.
--
-- Returns @(joint, firstMarginal, secondMarginal)@: every pair, first
-- component and second component seen in the input is mapped to its number
-- of occurrences. Counts are accumulated with value-strict inserts, so long
-- inputs do not build chains of @(+)@ thunks inside the maps.
counts :: (Ord a,Ord b) => [(a,b)] -> (Map.Map (a,b) Count
                                      ,Map.Map a Count
                                      ,Map.Map b Count)
counts xys = foldl' step (Map.empty,Map.empty,Map.empty) xys
    where step (!cxy,!cx,!cy) (!x,!y) = ( bump (x,y) cxy
                                        , bump x cx
                                        , bump y cy )
          -- Data.Map.Strict.insertWith evaluates the new count before
          -- storing it. It replaces the deprecated Data.Map.insertWith',
          -- which current containers releases no longer provide.
          bump :: Ord k => k -> Map.Map k Count -> Map.Map k Count
          bump k = StrictMap.insertWith (+) k 1
-- | Mutual information, in bits, between the two variables whose joint and
-- marginal counts are given (in the triple shape returned by 'counts').
--
-- Computes @sum over (x,y) of p(x,y) * log2 (p(x,y) / (p(x) * p(y)))@ from
-- raw counts, with @p(x,y) = nxy / n@ and the ratio equal to
-- @nxy * n / (nx * ny)@. Joint cells with a zero count contribute nothing
-- (the @0 * log 0 = 0@ convention). Every key of the joint map must also be
-- present in both marginal maps, otherwise '!' raises an error.
mi :: (Ord a,Ord b) =>
      (Map.Map (a,b) Count
      ,Map.Map a Count
      ,Map.Map b Count) -> Double
mi (cxy,cx,cy) =
    -- Map.foldr replaces the removed Map.fold (same right-fold semantics).
    let n = Map.foldr (+) 0 cxy
        cell (x,y) nxy =
            let nx = cx ! x
                ny = cy ! y
            in  nxy / n * logBase 2 (nxy * n / nx / ny)
    in sum [ cell (x,y) nxy | ((x,y),nxy) <- Map.toList cxy, nxy /= 0 ]


-- | Shannon entropy, in bits, of the distribution given by a map of counts.
--
-- Each category's probability is its count divided by the total count.
-- Categories with a zero count contribute nothing (the @0 * log 0 = 0@
-- convention). An empty or single-category map has entropy @0@ (a positive
-- zero, not @-0.0@).
entropy :: (Ord a) => Map.Map a Count -> Double
entropy cx = sum [ f nx | nx <- Map.elems cx, nx /= 0 ]
    where n    = sum . Map.elems $ cx
          logn = logBase 2 n
          -- -p * log2 p rewritten as p * (log2 n - log2 nx). The terms are
          -- already non-negative, so no final negate is needed (negating a
          -- zero total would produce -0.0).
          f nx = nx / n * (logn - logBase 2 nx)

-- | Variation of information, in bits: @H(X) + H(Y) - 2 * I(X;Y)@.
--
-- Takes the same @(joint, firstMarginal, secondMarginal)@ count triple as
-- 'mi'. The entropies are taken from the two marginal maps.
vi :: (Ord a,Ord b) =>
      (Map.Map (a,b) Count
      ,Map.Map a Count
      ,Map.Map b Count) -> Double
vi tables@(_, cx, cy) = hx + hy - 2 * ixy
    where hx  = entropy cx
          hy  = entropy cy
          ixy = mi tables


-- | Adjusted Rand index between the two labelings summarised by a count
-- triple @(joint, firstMarginal, secondMarginal)@ as returned by 'counts'.
--
-- Uses the pair-counting form
--
-- > (index - expected) / (maxIndex - expected)
--
-- where @index@ sums @choice nij 2@ over the joint cells, the two marginal
-- pair sums are @sum (choice ni 2)@ and @sum (choice nj 2)@, @expected@ is
-- their product divided by @choice n 2@, and @maxIndex@ is their mean.
--
-- That ratio is 0/0 when there are fewer than two observations, or when both
-- labelings are the same trivial partition (everything in one group, or
-- every observation on its own). Those are cases of perfect agreement, so 1
-- is returned instead of NaN.
ari :: (Ord a,Ord b) =>
       (Map.Map (a,b) Count
       ,Map.Map a Count
       ,Map.Map b Count) -> Double
ari (cxy,cx,cy)
    | pairsTotal == 0 || maxIndex == expected = 1
    | otherwise = (index - expected) / (maxIndex - expected)
    where pairsTotal = choice (sum . Map.elems $ cx) 2
          index      = sum [ choice nij 2 | nij <- Map.elems cxy ]
          rowPairs   = sum [ choice ni 2 | ni <- Map.elems cx ]
          colPairs   = sum [ choice nj 2 | nj <- Map.elems cy ]
          expected   = rowPairs * colPairs / pairsTotal
          maxIndex   = 1/2 * (rowPairs + colPairs)

-- | Replace each item's score with a normalised reciprocal-rank weight.
--
-- Items are ranked by descending score (rank 1 is the highest score; items
-- with equal scores keep their input order). Each item gets weight
-- @1 / rank@, and the weights are divided by their total so they sum to 1.
-- The result is listed in rank order. An empty input gives an empty output.
rankNormalize :: (Ord b) => [(a, b)] -> [(a, Double)]
rankNormalize xs = [ (item, w / total) | (item, w) <- weighted ]
  where
    ranked   = map fst (sortBy (comparing (Down . snd)) xs)
    weighted = zipWith (\rank item -> (item, 1 / rank)) [1 ..] ranked
    total    = sum (map snd weighted)
                      
-- Sample input for 'counts' (e.g. for trying the measures in GHCi):
-- example = [(1,1),(1,2),(2,1),(2,2),(2,2),(2,3),(3,3),(3,3),(3,3),(3,3)]

-- | Binomial coefficient @n@ choose @k@ in floating point. It is computed as
-- the product @(n-k+1) * ... * n@ divided by @1 * ... * k@.
--
-- For whole-number arguments with @k > n >= 0@, the numerator range includes
-- 0, so the result is 0. For @k == 0@, both products are empty and the
-- result is 1.
choice :: (Enum a, Fractional a) => a -> a -> a
choice n k = fallingFactorial / kFactorial
  where fallingFactorial = foldl' (*) 1 [n - k + 1 .. n]
        kFactorial       = foldl' (*) 1 [1 .. k]
-- | Strict left-fold sum of a list. The module defines its own @sum@
-- because its imports hide 'Prelude.sum'.
--
-- The explicit signature keeps the binding polymorphic. Without it, the
-- monomorphism restriction would pin its type to whatever the module's
-- uses force.
sum :: Num a => [a] -> a
sum = foldl' (+) 0