-- | Compare two NE-annotated datasets.


module NLP.Nerf.Compare
( Stats (..)
, (.+.)
, compare
) where


import           Prelude hiding (span, compare)
import           Control.Applicative ((<$>))
import           Control.Monad (forM)
import qualified Control.Monad.State.Strict as ST
import qualified Control.Monad.Writer.Strict as W
import qualified Data.Traversable as Tr
import qualified Data.Set as S
import qualified Data.Map as M
import qualified Data.Char as C
import qualified Data.Text as T

import qualified Data.Named.Tree as N


-- | Confusion counts gathered when one set of NE nodes is compared
-- against another.  In this module the counts are kept separately for
-- each NE label (see 'cmpNodes').
data Stats = Stats
    { fp    :: !Int     -- ^ false positives
    , tp    :: !Int     -- ^ true positives
    , fn    :: !Int     -- ^ false negatives
    , tn    :: !Int     -- ^ true negatives ('mkStats' always sets this to 0)
    } deriving (Show, Eq, Ord)


-- | A NE represented by its label and a character-level span over which
-- the NE stretches.  White-space characters do not count when computing
-- the span.
data Node a = Node
    { label :: a            -- ^ NE label (type) of the node
    , _span :: (Int, Int)   -- ^ span covered by the node; as built by
                            -- 'toIDs' and 'mkNode' this is the start
                            -- offset paired with the offset just past
                            -- the last covered character
    } deriving (Show, Eq, Ord)


-- | The smallest span covering both argument spans: the lesser of the
-- two first components paired with the greater of the two second
-- components.
spanUnion :: (Int, Int) -> (Int, Int) -> (Int, Int)
spanUnion (begA, endA) (begB, endB) = (lo, hi)
  where
    lo = min begA begB
    hi = max endA endB


-- | Field-wise sum of two 'Stats' values.
(.+.) :: Stats -> Stats -> Stats
Stats fpA tpA fnA tnA .+. Stats fpB tpB fnB tnB =
    Stats (fpA + fpB) (tpA + tpB) (fnA + fnB) (tnA + tnB)


-- | Compare two NE-annotated datasets and accumulate the statistics
-- per NE label.  The function assumes that the two forests of each
-- pair describe the same sentence.
--
-- The first forest of a pair acts as the reference: a node found only
-- in the second forest counts as a false positive for its label, and a
-- node found only in the first one counts as a false negative.
compare
    :: Ord a
    => [ ( N.NeForest a T.Text
         , N.NeForest a T.Text) ]
    -> M.Map a Stats
compare = M.unionsWith (.+.) . map cmpPair
  where
    -- Statistics for a single pair of forests.
    cmpPair (ref, other) = cmpNodes (nodeSet ref) (nodeSet other)
    -- Set of positioned NE nodes of a forest.
    nodeSet forest = nodesF (toIDs forest)


-- | Label-sensitive comparison of two sets of 'Node's.  For every label
-- occurring in either set, the nodes carrying that label are selected
-- from both sets and compared with 'mkStats'.
cmpNodes :: Ord a => S.Set (Node a) -> S.Set (Node a) -> M.Map a Stats
cmpNodes x y = M.fromList (map statsFor (S.toList labels))
  where
    -- All labels present in at least one of the two sets.
    labels       = S.map label x `S.union` S.map label y
    statsFor l   = (l, mkStats (labelled l x) (labelled l y))
    -- Restrict a node set to the nodes with the given label.
    labelled l   = S.filter (\node -> label node == l)


-- | Label-insensitive comparison of two sets of 'Node's.  The first set
-- is the reference: nodes found only in the second set are false
-- positives, nodes found only in the first set are false negatives and
-- shared nodes are true positives.  True negatives are not counted and
-- the field is always 0.
mkStats :: Ord a => S.Set (Node a) -> S.Set (Node a) -> Stats
mkStats gold found = Stats
    { tp    = S.size (gold `S.intersection` found)
    , fp    = S.size (found S.\\ gold)
    , fn    = S.size (gold S.\\ found)
    , tn    = 0 }


-- | Replace every word with its character-level position: a pair made
-- of the running offset before the word and the offset after it.
-- White-space characters are not counted, and the offset runs on
-- through all trees of the forest, starting from 0.  NE labels are
-- left untouched.
toIDs :: N.NeForest a T.Text -> N.NeForest a (Int, Int)
toIDs ts = ST.evalState (forM ts (Tr.mapM step)) 0
  where
    -- Internal (NE) nodes pass through unchanged.
    step (Left ne)    = return (Left ne)
    -- Leaves advance the offset by their non-white-space length.
    step (Right word) = do
        start <- ST.get
        let end = start + T.length (T.filter (not . C.isSpace) word)
        ST.put end
        return (Right (start, end))


-- | Collect the nodes of all trees of the NE forest into one set.
nodesF :: Ord a => N.NeForest a (Int, Int) -> S.Set (Node a)
nodesF forest = S.unions [nodesT tree | tree <- forest]


-- | Collect the nodes of a single NE tree: run 'mkNode' on the tree
-- and keep only the set it has written, dropping the returned span.
nodesT :: Ord a => N.NeTree a (Int, Int) -> S.Set (Node a)
nodesT tree = W.execWriter (mkNode tree)


-- | Make `Node`s from a tree.  Every internal (NE) node of the tree is
-- written out as a `Node` whose span is the union of the spans of its
-- children; leaves only contribute their own span.  Returns the span
-- of the whole tree.
--
-- An NE node without children has no span.  Such a tree is treated as
-- malformed input and results in an 'error' naming this function
-- (raised when the span is demanded) instead of the anonymous
-- @Prelude.foldl1: empty list@ failure.
mkNode
    :: Ord a => N.NeTree a (Int, Int)
    -> W.Writer (S.Set (Node a)) (Int, Int)
mkNode (N.Node (Right i) _) = return i
mkNode (N.Node (Left neType) xs) = do
    span <- coverAll <$> mapM mkNode xs
    W.tell $ S.singleton $ Node neType span
    return span
  where
    -- 'spanUnion' takes a minimum and a maximum, so the direction of
    -- the fold does not influence the result.
    coverAll (s : ss) = foldr spanUnion s ss
    coverAll []       = error
        "NLP.Nerf.Compare.mkNode: NE node without children has no span"