{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections #-}


-- | Accuracy statistics.
--
-- Gold (reference) and tagged DAGs are compared edge by edge; the outcome of
-- every comparison is accumulated in 'Stats', from which 'precision',
-- 'recall' and 'accuracy' can be derived.
module NLP.Concraft.DAG.Morphosyntax.Accuracy
(
-- * Stats
  Stats(..)
, AccCfg (..)
, collect
, precision
, recall
, accuracy
) where


import           Prelude hiding (Word)
import           GHC.Conc (numCapabilities)

import           Control.Arrow (first)
import qualified Control.Parallel.Strategies as Par

import           Data.List (transpose)
import qualified Data.Foldable as F
import qualified Data.Set as S
import qualified Data.Map.Strict as M

import qualified Data.Tagset.Positional as P
import qualified Data.DAG as DAG

import           NLP.Concraft.DAG.Morphosyntax
import           NLP.Concraft.DAG.Morphosyntax.Ambiguous (identifyAmbiguousSegments)
-- import           NLP.Concraft.DAG.Morphosyntax.Align

-- import qualified Data.Text as T
import Debug.Trace (trace)


-- | Configuration of the accuracy computation.
data AccCfg x = AccCfg
  { onlyOov   :: Bool
    -- ^ Limit calculations to OOV words
  , onlyAmb   :: Bool
    -- ^ Limit calculations to segmentation-ambiguous words (edges marked as
    -- ambiguous by `identifyAmbiguousSegments`)
  , onlyMarkedWith :: S.Set x
    -- ^ Limit calculations to segments marked with one of the given labels
    -- (the @x@ component attached to a segment's tags, on either side);
    -- if empty, the option has no effect
  , accTagset :: P.Tagset
    -- ^ The underlying tagset (used when expanding tags)
  , expandTag :: Bool
    -- ^ Should the tags be expanded (with `P.expand`) before comparison?
  , ignoreTag :: Bool
    -- ^ Compute segmentation-level accuracy. The actually chosen tags are
    -- ignored; only information about the chosen DAG edges is relevant.
    -- Takes precedence over `expandTag`.
  , weakAcc :: Bool
    -- ^ If weak, it is enough for the sets of tags chosen for a given segment
    -- in the two datasets to overlap. Otherwise, the two sets of tags have to
    -- be identical.
  , discardProb0 :: Bool
    -- ^ Whether sentences with near-0 probability (below 1e-9, in either the
    -- gold or the tagged dataset) should be discarded from evaluation.
  , verbose :: Bool
    -- ^ Print information about the compared elements (via `trace`)
  }


-- | Edge-level evaluation counts: true/false positives, true/false negatives
-- and consistency errors.
data Stats = Stats
  { tp :: !Int
    -- ^ True positives
  , fp :: !Int
    -- ^ False positives
  , tn :: !Int
    -- ^ True negatives
  , fn :: !Int
    -- ^ False negatives
  , ce :: !Int
    -- ^ Consistency errors, i.e. the number of edges for which both `fp` and
    -- `fn` hold
  } deriving (Show, Eq, Ord)


-- | Statistics with every counter set to zero (the neutral element of
-- 'addStats').
zeroStats :: Stats
zeroStats = Stats { tp = 0, fp = 0, tn = 0, fn = 0, ce = 0 }


-- | Component-wise sum of two sets of statistics.
addStats :: Stats -> Stats -> Stats
addStats (Stats tp1 fp1 tn1 fn1 ce1) (Stats tp2 fp2 tn2 fn2 ce2) =
  Stats (tp1 + tp2) (fp1 + fp2) (tn1 + tn2) (fn1 + fn2) (ce1 + ce2)


-- | Compare a gold DAG with the corresponding tagged DAG and classify every
-- selected edge as a true/false positive/negative.
--
-- When 'discardProb0' is set and either DAG has a probability below @1e-9@
-- (see 'dagProb'), the whole pair is skipped and 'zeroStats' is returned.
goodAndBad
  :: (Word w, Ord x, Show x)
  => AccCfg x
  -> Sent w (P.Tag, x) -- ^ Gold (reference) DAG
  -> Sent w (P.Tag, x) -- ^ Tagged (to compare) DAG
  -> Stats
goodAndBad cfg dag1 dag2
  | discarded = zeroStats
  | otherwise =
      -- `DAG.zipE'` is used so that the two DAGs may be slightly different in
      -- terms of their edge sets; each edge then carries a pair of optional
      -- segments (gold, tagged).
      F.foldl' addStats zeroStats (DAG.mapE compareEdge zipped)
  where
    eps = 1e-9
    discarded =
      discardProb0 cfg && (dagProb dag1 < eps || dagProb dag2 < eps)

    zipped = DAG.zipE' dag1 dag2
    ambiguous = identifyAmbiguousSegments zipped
    markers = onlyMarkedWith cfg

    -- Score a single edge, or ignore it when it does not pass the filters.
    compareEdge edgeID pair@(gold, tagg)
      | selected edgeID pair =
          traceThem gold tagg $ score (chosen gold) (chosen tagg)
      | otherwise = zeroStats

    chosen = maybe S.empty (choice cfg)

    -- `isOov` and `isMarked` are only evaluated when the corresponding
    -- option is active (`implies` does not look at its second argument
    -- when the first one is False).
    selected edgeID pair =
      (onlyOov cfg `implies` isOov pair)
        && (onlyAmb cfg `implies` DAG.edgeLabel edgeID ambiguous)
        && (not (S.null markers) `implies` isMarked pair)

    traceThem gold tagg
      | verbose cfg = trace message
      | otherwise = id
      where
        describe seg = (orth seg, choice cfg seg)
        message =
          "comparing '" ++ show (describe <$> gold)
            ++ "' with '" ++ show (describe <$> tagg) ++ "'"

    -- OOV-ness is decided by the gold segment if present, otherwise by the
    -- tagged one.
    isOov (gold, tagg) = oov $ case gold of
      Just seg -> seg
      Nothing -> maybe (error "Accuracy.goodAndBad: impossible happened") id tagg

    -- Does any tag (on either side) carry one of the requested labels?
    isMarked pair = any ((`S.member` markers) . snd) . M.keys $ weights pair

    weights pair = case pair of
      (Just seg1, Just seg2) -> unWMap (tags seg1) `M.union` unWMap (tags seg2)
      (Just seg, Nothing) -> unWMap (tags seg)
      (Nothing, Just seg) -> unWMap (tags seg)
      (Nothing, Nothing) -> error "Accuracy.goodAndBad: impossible2 happened"

    -- Classify an edge given the sets of tags chosen on each side.
    score gold tagg = case (S.null gold, S.null tagg) of
      (True, True) -> zeroStats {tn = 1}
      (True, False) -> zeroStats {fp = 1}
      (False, True) -> zeroStats {fn = 1}
      (False, False)
        | consistent gold tagg -> zeroStats {tp = 1}
        | otherwise -> zeroStats {fp = 1, fn = 1, ce = 1}

    consistent xs ys =
      if weakAcc cfg
        then not (S.null (xs `S.intersection` ys))
        else xs == ys


-- | Sum the statistics of corresponding gold/tagged DAG pairs.  As with
-- 'zip', surplus sentences of the longer list are ignored.
goodAndBad'
  :: (Word w, Ord x, Show x)
  => AccCfg x
  -> [Sent w (P.Tag, x)]
  -> [Sent w (P.Tag, x)]
  -> Stats
goodAndBad' cfg goldData taggData =
  F.foldl' addStats zeroStats (zipWith (goodAndBad cfg) goldData taggData)


-- | Compute the accuracy statistics of the tagged dataset with respect to
-- the labeled (gold) one.  To each `P.Tag` an additional piece of information
-- `x` can be assigned, which is taken into account when computing the
-- statistics.
--
-- The sentence pairs are split into 'numCapabilities' chunks (with
-- 'partition'), which are evaluated in parallel.
collect
  :: (Word w, Ord x, Show x)
  => AccCfg x
  -> [Sent w (P.Tag, x)] -- ^ Gold dataset
  -> [Sent w (P.Tag, x)] -- ^ Tagged dataset (to be compared with the gold)
  -> Stats
collect cfg goldData taggData =
  F.foldl' addStats zeroStats perChunk
  where
    chunks = partition numCapabilities (zip goldData taggData)
    perChunk = Par.parMap Par.rseq statsOf chunks
    statsOf chunk =
      let (golds, taggeds) = unzip chunk
      in  goodAndBad' cfg golds taggeds


-- | Precision, @tp / (tp + fp)@ (NaN when the denominator is zero).
precision :: Stats -> Double
precision stats =
  fromIntegral (tp stats) / fromIntegral (tp stats + fp stats)


-- | Recall, @tp / (tp + fn)@ (NaN when the denominator is zero).
recall :: Stats -> Double
recall stats =
  fromIntegral (tp stats) / fromIntegral (tp stats + fn stats)


-- | Accuracy, @(tp + tn) / (tp + fp + tn + fn - ce)@.
accuracy :: Stats -> Double
accuracy stats = fromIntegral correct / fromIntegral total
  where
    correct = tp stats + tn stats
    -- Note that we subtract `ce` so that inconsistency errors are counted as
    -- single errors (they are accounted for twice in `fp + fn`).
    total = tp stats + fp stats + tn stats + fn stats - ce stats
------------------------------------------------------
-- Verification
------------------------------------------------------


-- | Compute the probability of the DAG, based on the probabilities assigned
-- to different edges and their labels.
--
-- The result is the sum, over all paths that start with an initial edge and
-- end in a node without outgoing edges, of the product of the edge
-- probabilities along the path.  The probability of an edge is the sum of
-- the weights of its tags.
--
-- NOTE(review): `fromNode` is not memoised, so every path is enumerated
-- separately; the cost grows with the number of paths rather than edges.
dagProb :: Sent w t -> Double
dagProb dag = sum (map pathProb initialEdges)
  where
    initialEdges = filter (`DAG.isInitialEdge` dag) (DAG.dagEdges dag)
    pathProb edgeID = edgeProb edgeID * fromNode (DAG.endsWith edgeID dag)
    edgeProb edgeID = sum . M.elems . unWMap . tags $ DAG.edgeLabel edgeID dag
    fromNode nodeID = case DAG.outgoingEdges nodeID dag of
      [] -> 1
      outs -> sum (map pathProb outs)


-- -- | Filter out the sentences with ~0 probability.
-- verifyDataset :: [Sent w t] -> [Sent w t]
-- verifyDataset =
--   filter verify
--   where
--     verify dag = dagProb dag >= eps
--     eps = 1e-9


--------------------------
-- Utils
--------------------------


-- | Select the chosen tags of a segment (its 'best' tags) as a set.
--
-- * If 'ignoreTag' is set, every tag is replaced by a dummy @AmbiSeg@ tag
--   (only the attached @x@ is kept), so that only the choice of DAG edges
--   matters.  This takes precedence over 'expandTag'.
-- * Otherwise, if 'expandTag' is set, every tag is expanded with
--   `P.expand` with respect to 'accTagset'.
choice :: (Ord x) => AccCfg x -> Seg w (P.Tag, x) -> S.Set (P.Tag, x)
choice cfg = S.fromList . adjust . best
  where
    adjust
      | ignoreTag cfg = map (first (const dummyTag))
      | expandTag cfg = concatMap expandOne
      | otherwise = id
    expandOne (tag, x) = [(tag', x) | tag' <- P.expand (accTagset cfg) tag]
    dummyTag = P.Tag "AmbiSeg" M.empty


-- | The best tags of a segment: those whose weight is within @1e-9@ of the
-- maximal weight.  No tag is selected when the segment has no tags or when
-- the maximal weight is itself below @1e-9@.
best :: Seg w t -> [t]
best seg
  | null weighted || maxProb < eps = []
  | otherwise = [t | (t, p) <- weighted, p >= maxProb - eps]
  where
    weighted = M.toList . unWMap $ tags seg
    -- Only evaluated when `weighted` is non-empty (see the guard above).
    maxProb = maximum (map snd weighted)
    eps = 1.0e-9


-- | Distribute the elements of a list round-robin over at most @n@ parts,
-- e.g. @partition 2 [1,2,3,4,5] == [[1,3,5],[2,4]]@.
--
-- A non-positive @n@ is treated as 1 (a single part): chunks of size 0 would
-- never consume the input, so the function would not terminate.
partition :: Int -> [a] -> [[a]]
partition n = transpose . chunksOf (max 1 n)
  where
    chunksOf _ [] = []
    chunksOf k xs =
      let (chunk, rest) = splitAt k xs
      in  chunk : chunksOf k rest


-- | Logical implication: @p `implies` q@ is 'False' only when @p@ holds and
-- @q@ does not.  @q@ is not evaluated when @p@ is 'False'.
implies :: Bool -> Bool -> Bool
implies p q = not p || q