{-# LANGUAGE BangPatterns #-}
-- | Scoring functions commonly used for evaluation of NLP
-- systems. Most functions in this module work on lists, but some take
-- a precomputed table of 'Counts'. This will give a speedup if you
-- want to compute multiple scores on the same data. For example to
-- compute the Mutual Information, Variation of Information and the
-- Adjusted Rand Index on the same pair of clusterings:
--
-- >>> let cs = counts $ zip "abcabc" "abaaba"
-- >>> mapM_ (print . ($ cs)) [mi, ari, vi]
--
module NLP.Scores 
    ( 
      -- * Scores for classification and ranking
      accuracy
    , recipRank
    , avgPrecision
      -- * Scores for clustering
    , ari
    , mi
    , vi
      -- * Auxiliary types and functions
    , Count
    , Counts
    , counts
    , sum
    , mean
    , jaccard
    , entropy
    )
where
import Data.List hiding (sum)
import qualified Data.Map as Map
import qualified Data.Map.Strict as MapStrict
import qualified Data.Set as Set
import Prelude hiding (sum)

-- | Accuracy: the fraction of positions at which the two lists hold equal
-- elements. The lists should have the same length; any extra elements of
-- the longer list are ignored. If either list is empty the result is
-- @0/0@ (NaN for 'Double').
accuracy :: (Eq a, Fractional n) => [a] -> [a] -> n
accuracy xs ys = realToFrac hits / total
  where
    -- One strict pass: number of matching positions, number of positions.
    (hits, total) = foldl' tally (0 :: Int, 0) (zipWith (==) xs ys)
    tally (h, t) same =
      let h' = if same then h + 1 else h
          t' = t + 1
      in h' `seq` t' `seq` (h', t')
{-# SPECIALIZE accuracy :: [Double] -> [Double] -> Double #-}

-- | Reciprocal rank: one over the 1-based position of the first occurrence
-- of the first argument in the list given as the second argument, or 0
-- when it does not occur at all.
recipRank :: (Eq a, Fractional n) => a -> [a] -> n
recipRank y ys = maybe 0 (\i -> 1 / fromIntegral (i + 1)) (elemIndex y ys)
{-# SPECIALIZE recipRank :: Double -> [Double] -> Double #-}

-- | Average precision of a ranked list against a set of relevant items.
-- <http://en.wikipedia.org/wiki/Information_retrieval#Average_precision>
--
-- Each 1-based position @r@ holding a relevant item contributes the
-- precision at @r@ (relevant hits so far divided by @r@). The total is
-- divided by the number of relevant items. The scan stops as soon as the
-- running hit count would exceed that number, which can only happen when
-- the ranking repeats relevant items. An empty gold set gives 0.
avgPrecision :: (Fractional n, Ord a) => Set.Set a -> [a] -> n
avgPrecision gold xs
  | Set.null gold = 0
  | otherwise     = walk 0 0 (zip [1 :: Int ..] xs) / fromIntegral relevant
  where
    relevant = Set.size gold
    walk acc _    [] = acc
    walk acc hits ((rank, x) : rest)
      | not (x `Set.member` gold) = walk acc hits rest
      | hits' > relevant          = acc
      | otherwise                 =
          let acc' = acc + fromIntegral hits' / fromIntegral rank
          in acc' `seq` walk acc' hits' rest
      where hits' = hits + 1
{-# SPECIALIZE avgPrecision :: (Ord a) => Set.Set a -> [a] -> Double #-}

-- | Mutual information, in bits: MI(X,Y) = H(X) - H(X|Y) = H(Y) - H(Y|X).
-- Also known as information gain. Probabilities are estimated from the
-- table, so each observed pair (x,y) contributes
-- p(x,y) * log_2(p(x,y) / (p(x) p(y))).
mi :: (Ord a, Ord b) => Counts a b -> Double
mi (Counts cxy cx cy) = Map.foldlWithKey' step 0 cxy
  where
    total = Map.foldl' (+) 0 cxy
    step acc (P x y) nxy =
      let nx = cx Map.! x
          ny = cy Map.! y
      in acc + nxy / total * logBase 2 (nxy * total / nx / ny)

-- | Variation of information: VI(X,Y) = H(X) + H(Y) - 2 MI(X,Y), built
-- from the entropies of the two marginal count tables and 'mi'.
vi :: (Ord a, Ord b) => Counts a b -> Double
vi cs = hx + hy - 2 * mi cs
  where
    hx = entropy (Map.elems (marginalFst cs))
    hy = entropy (Map.elems (marginalSnd cs))

-- | Adjusted Rand Index: <http://en.wikipedia.org/wiki/Rand_index>
--
-- ARI = (Index - Expected) / (MaxIndex - Expected). Write C(m,2) for the
-- number of pairs among m items, a_i and b_j for the marginal counts, and
-- n for their total. Then:
--
-- * Index    = SUM_ij C(n_ij, 2), taken over the joint counts
--
-- * Expected = SUM_i C(a_i, 2) * SUM_j C(b_j, 2) / C(n, 2)
--
-- * MaxIndex = (SUM_i C(a_i, 2) + SUM_j C(b_j, 2)) / 2
--
-- Degenerate inputs give a non-finite result; for instance, fewer than
-- two items gives NaN.
ari :: (Ord a, Ord b) => Counts a b -> Double
ari (Counts cxy cx cy) = (index - expected) / (maxIndex - expected)
  where
    pairs m  = choice m 2
    allPairs = pairs (sum (Map.elems cx))
    index    = sum (map pairs (Map.elems cxy))
    rowPairs = sum (map pairs (Map.elems cx))
    colPairs = sum (map pairs (Map.elems cy))
    expected = rowPairs * colPairs / allPairs
    maxIndex = 1/2 * (rowPairs + colPairs)

-- | A single count, stored as a 'Double'.
type Count = Double

-- | Count table: counts of each pair together with the counts of each
-- component on its own. Build one with 'counts'.
data Counts a b = Counts
  { joint       :: !(Map.Map (P a b) Count) -- ^ Counts of both components
  , marginalFst :: !(Map.Map a Count)       -- ^ Counts of the first component
  , marginalSnd :: !(Map.Map b Count)       -- ^ Counts of the second component
  }

-- | Strict pair; the key type of the joint count table.
data P a b = P !a !b
  deriving (Eq, Ord)

-- | A count table with no observations: all three maps are empty.
empty :: (Ord a, Ord b) => Counts a b
empty = Counts
  { joint       = Map.empty
  , marginalFst = Map.empty
  , marginalSnd = Map.empty
  }

-- | The sum of a list of numbers, computed with a strict left fold. The
-- accumulator is evaluated at every step, so long lists do not build up a
-- chain of pending additions.
sum :: (Num a) => [a] -> a
sum xs = foldl' (+) 0 xs
{-# SPECIALIZE sum :: [Double] -> Double #-}
{-# SPECIALIZE sum :: [Int] -> Int #-}
{-# INLINE sum #-}

-- | The arithmetic mean of a list of numbers, computed in a single strict
-- pass over the list. The empty list yields @0/0@ (NaN for 'Double').
mean :: (Fractional n, Real a) => [a] -> n
mean = go 0 0
  where
    -- running total (in the element type) and running length (in n)
    go total count []       = realToFrac total / count
    go total count (y : ys) =
      let total' = total + y
          count' = count + 1
      in total' `seq` count' `seq` go total' count' ys
{-# SPECIALIZE mean :: [Double] -> Double #-}

-- | The binomial coefficient C(n,k) = PROD_{i=1}^{k} (n-k+i)/i, computed as
-- the product n(n-1)...(n-k+1) divided by the product 1*2*...*k. The
-- factor lists are built by enumeration. Within this module, callers pass
-- whole-number counts with k = 2.
choice :: (Enum b, Fractional b) => b -> b -> b
choice n k = falling / factorial
  where
    falling   = productOf [n - k + 1 .. n]
    factorial = productOf [1 .. k]
    productOf = foldl' (*) 1
{-# SPECIALIZE choice :: Double -> Double -> Double #-}

-- | Jaccard coefficient of two sets:
-- J(A,B) = |A intersection B| / |A union B|.
-- The union size is obtained as |A| + |B| - |A intersection B|. Two empty
-- sets give @0/0@ (NaN for 'Double').
jaccard :: (Fractional n, Ord a) => Set.Set a -> Set.Set a -> n
jaccard a b = fromIntegral common / fromIntegral combined
  where
    common   = Set.size (Set.intersection a b)
    combined = Set.size a + Set.size b - common
{-# SPECIALIZE jaccard :: (Ord a) => Set.Set a -> Set.Set a -> Double #-}

-- | Entropy in bits: H(X) = -SUM_i P(X=i) log_2(P(X=i)), where P(X=i) is
-- the i-th count divided by the sum of all counts.
--
-- Zero counts contribute nothing, following the usual convention
-- 0 * log 0 = 0. An empty list, or one whose counts are all zero, has
-- entropy 0.
entropy :: [Count] -> Double
entropy cx = sum [ term nx | nx <- cx, nx /= 0 ]
  where
    n    = sum cx
    logn = logBase 2 n
    -- -p * log_2 p with p = nx / n, written as p * (log_2 n - log_2 nx).
    -- Skipping nx == 0 above matters: logBase 2 0 is -Infinity, and
    -- 0 * (-Infinity) would turn the whole sum into NaN.
    term nx = nx / n * (logn - logBase 2 nx)

-- | Creates count table 'Counts' from a list of pairs. The table holds
-- the count of every distinct pair, the count of every first component,
-- and the count of every second component.
counts :: (Ord a, Ord b) => [(a,b)] -> Counts a b
counts xys = foldl' f empty xys
    where f cs@(Counts cxy cx cy) (!x,!y) =
            -- Data.Map.Strict's insertWith evaluates each updated count,
            -- so the tables hold numbers rather than growing chains of (+)
            -- thunks. It works on the same Map type as Data.Map and
            -- replaces the primed insertWith', which current versions of
            -- containers no longer provide.
            cs { joint       = MapStrict.insertWith (+) (P x y) 1 cxy
               , marginalFst = MapStrict.insertWith (+) x 1 cx
               , marginalSnd = MapStrict.insertWith (+) y 1 cy }