{-# LANGUAGE BangPatterns #-}
module NLP.Scores 
    ( 
      sum
    , mean
    , accuracy
    , recipRank
    , avgPrecision
    )
where
import Data.List hiding (sum)
import qualified Data.Set as Set
import Prelude hiding (sum)

-- | Add up all the numbers in a list.
--
-- Implemented as a strict left fold ('foldl''), so the running total is
-- evaluated at every step instead of accumulating a chain of unevaluated
-- additions. The sum of the empty list is @0@.
sum :: (Num a) => [a] -> a
sum xs = foldl' (+) 0 xs

-- | The arithmetic mean of a list of numbers.
--
-- The total and the element count are gathered in a single strict pass
-- over the list. For an empty list the result is zero divided by zero in
-- the result type (e.g. @NaN@ for 'Double').
mean :: (Fractional n, Real a) => [a] -> n
mean xs = realToFrac total / count
  where
    (total, count) = foldl' step (0, 0) xs
    -- Both halves of the accumulator are forced on every step.
    step (!acc, !n) x = (acc + x, n + 1)

-- | Accuracy: the fraction of positions at which the two lists hold equal
-- elements.
--
-- The lists are compared pairwise and are expected to have the same
-- length; if they do not, the surplus elements of the longer list are
-- ignored. Each position scores 1 for a match and 0 otherwise, and the
-- result is the 'mean' of those scores.
accuracy :: (Eq a, Fractional n) => [a] -> [a] -> n
accuracy xs ys = mean [ fromEnum (x == y) | (x, y) <- zip xs ys ]
 
-- | Reciprocal rank: the reciprocal of the (1-based) rank at which the
-- first argument first occurs in the list given as the second argument.
--
-- Returns @0@ when the element does not occur in the list. Only the
-- prefix of the list up to the first occurrence is inspected.
recipRank :: (Eq a, Fractional n) => a -> [a] -> n
recipRank y ys =
    case find ((== y) . snd) (zip [1 :: Integer ..] ys) of
      Nothing        -> 0
      Just (rank, _) -> 1 / fromIntegral rank

-- | Average precision of a ranked list with respect to a set of relevant
-- (gold) items.
-- <http://en.wikipedia.org/wiki/Information_retrieval#Average_precision>
--
-- The first argument is the set of relevant items, the second the ranked
-- list of retrieved items (best first). For every rank at which a relevant
-- item is found for the first time, the precision at that rank (relevant
-- items found so far divided by the rank) is added up; the total is then
-- divided by the number of relevant items.
--
-- * An empty gold set yields @0@.
--
-- * Each relevant item is counted only once: a repeated occurrence of an
--   already-found item still occupies a rank but does not count as a hit.
--
-- * The list is consumed only until every relevant item has been found,
--   so an infinite ranking is fine as long as it contains all gold items.
avgPrecision :: (Fractional n, Ord a) => Set.Set a -> [a] -> n
avgPrecision gold xs
    | Set.null gold = 0
    | otherwise     =
        go 0 (0 :: Int) gold (zip [1 :: Int ..] xs)
          / fromIntegral (Set.size gold)
  where
    -- go acc hits pending ranked:
    --   acc     - sum of the precision values collected so far
    --   hits    - number of distinct relevant items found so far
    --   pending - relevant items that have not been found yet
    --   ranked  - remaining (rank, item) pairs
    go acc hits pending ranked
        -- Everything relevant has been found; later ranks add nothing.
        | Set.null pending = acc
        | otherwise =
            case ranked of
              [] -> acc
              (rank, x) : rest
                  | x `Set.member` pending ->
                      let hits' = hits + 1
                          acc'  = acc + fromIntegral hits' / fromIntegral rank
                      -- Force the accumulators to avoid building thunks.
                      in  hits' `seq` acc' `seq`
                            go acc' hits' (Set.delete x pending) rest
                  | otherwise -> go acc hits pending rest