{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections #-}
module NLP.Concraft.DAG.Morphosyntax.Accuracy
(
Stats(..)
, AccCfg (..)
, collect
, precision
, recall
, accuracy
) where
import Prelude hiding (Word)
import GHC.Conc (numCapabilities)
import Control.Arrow (first)
import qualified Control.Parallel.Strategies as Par
import Data.List (transpose)
import qualified Data.Foldable as F
import qualified Data.Set as S
import qualified Data.Map.Strict as M
import qualified Data.Tagset.Positional as P
import qualified Data.DAG as DAG
import NLP.Concraft.DAG.Morphosyntax
import NLP.Concraft.DAG.Morphosyntax.Ambiguous
(identifyAmbiguousSegments)
import Debug.Trace (trace)
-- | Configuration of the accuracy computation.  The type parameter @x@ is
-- the type of the extra marker paired with every positional tag.
data AccCfg x = AccCfg
  { onlyOov        :: Bool
    -- ^ When set, only segments reported as out-of-vocabulary (via 'oov')
    -- contribute to the statistics.
  , onlyAmb        :: Bool
    -- ^ When set, only edges flagged by 'identifyAmbiguousSegments'
    -- contribute to the statistics.
  , onlyMarkedWith :: S.Set x
    -- ^ When non-empty, only segments having at least one tag whose marker
    -- belongs to this set contribute to the statistics.
  , accTagset      :: P.Tagset
    -- ^ Tagset handed to 'P.expand' when 'expandTag' is enabled.
  , expandTag      :: Bool
    -- ^ Expand every chosen tag with 'P.expand' before comparing.  Has no
    -- effect when 'ignoreTag' is set, because that flag is checked first.
  , ignoreTag      :: Bool
    -- ^ Replace every chosen tag with a fixed dummy tag, so that only the
    -- markers are compared.
  , weakAcc        :: Bool
    -- ^ When set, two sets of choices are considered consistent as soon as
    -- they share an element; otherwise they must be equal.
  , discardProb0   :: Bool
    -- ^ When set, a sentence pair yields empty statistics if either DAG has
    -- a 'dagProb' below a small threshold.
  , verbose        :: Bool
    -- ^ When set, every compared pair of segments is reported via 'trace'.
  }
-- | Counters gathered while comparing gold and tagged DAGs.  All fields are
-- strict, so evaluating a 'Stats' value evaluates every counter.
data Stats = Stats
  { tp :: !Int
    -- ^ True positives.
  , fp :: !Int
    -- ^ False positives.
  , tn :: !Int
    -- ^ True negatives.
  , fn :: !Int
    -- ^ False negatives.
  , ce :: !Int
    -- ^ Edges on which both sides made a choice but the choices were
    -- inconsistent; 'goodAndBad' also counts each such edge once in 'fp'
    -- and once in 'fn'.
  }
  deriving (Eq, Ord, Show)
-- | Statistics with every counter set to zero; the neutral element of
-- 'addStats'.
zeroStats :: Stats
zeroStats = Stats { tp = 0, fp = 0, tn = 0, fn = 0, ce = 0 }
-- | Field-wise sum of two 'Stats' values.
addStats :: Stats -> Stats -> Stats
addStats (Stats tp1 fp1 tn1 fn1 ce1) (Stats tp2 fp2 tn2 fn2 ce2) =
  Stats
    (tp1 + tp2)
    (fp1 + fp2)
    (tn1 + tn2)
    (fn1 + fn2)
    (ce1 + ce2)
-- | Compare two sentence DAGs edge by edge and tally the outcome.
--
-- The first DAG is treated as the gold standard and the second as the
-- tagger output.  The DAGs are zipped with 'DAG.zipE'', so each edge of the
-- result carries an optional gold segment and an optional tagged segment.
-- Edges excluded by the 'onlyOov', 'onlyAmb' or 'onlyMarkedWith' settings
-- contribute nothing.  For the remaining edges the sets of chosen
-- interpretations (see 'choice') are compared:
--
--   * both empty: one true negative;
--   * only gold empty: one false positive;
--   * only tagged empty: one false negative;
--   * both non-empty and consistent: one true positive;
--   * both non-empty and inconsistent: one false positive, one false
--     negative and one 'ce'.
--
-- With 'discardProb0' enabled the whole pair yields 'zeroStats' when either
-- DAG has a 'dagProb' below @1e-9@.
goodAndBad
  :: (Word w, Ord x, Show x)
  => AccCfg x
  -> Sent w (P.Tag, x)
  -> Sent w (P.Tag, x)
  -> Stats
goodAndBad cfg dag1 dag2
  | discardProb0 cfg && any ((< eps) . dagProb) [dag1, dag2] = zeroStats
  | otherwise = F.foldl' addStats zeroStats (DAG.mapE gather dag)
  where
    eps = 1e-9
    dag = DAG.zipE' dag1 dag2
    ambiDag = identifyAmbiguousSegments dag

    -- Optionally report the compared segments before yielding the result.
    traceThem gold tagg result
      | verbose cfg = trace message result
      | otherwise = result
      where
        describe seg = (orth seg, choice cfg seg)
        message = concat
          [ "comparing '"
          , show (describe <$> gold)
          , "' with '"
          , show (describe <$> tagg)
          , "'"
          ]

    -- Statistics for a single edge of the zipped DAG.
    gather edgeID (gold, tagg)
      | relevant =
          traceThem gold tagg $
            classify (chosen gold) (chosen tagg)
      | otherwise = zeroStats
      where
        -- Each filter only applies when the corresponding option is on.
        relevant = and
          [ onlyOov cfg `implies` isOov
          , onlyAmb cfg `implies` isAmb
          , not (S.null (onlyMarkedWith cfg)) `implies` isMarked
          ]
        chosen mseg = maybe S.empty (choice cfg) mseg
        tagMap seg = unWMap (tags seg)
        -- The gold segment takes precedence when both sides are present.
        isOov = oov $ case gold of
          Just seg -> seg
          Nothing -> case tagg of
            Just seg -> seg
            Nothing -> error "Accuracy.goodAndBad: impossible happened"
        hasMarker m =
          any ((`S.member` onlyMarkedWith cfg) . snd) (M.keys m)
        isMarked = hasMarker $ case (gold, tagg) of
          (Nothing, Nothing) ->
            error "Accuracy.goodAndBad: impossible2 happened"
          (Just seg, Nothing) -> tagMap seg
          (Nothing, Just seg) -> tagMap seg
          (Just seg1, Just seg2) -> tagMap seg1 `M.union` tagMap seg2
        isAmb = DAG.edgeLabel edgeID ambiDag

    -- Turn the two sets of choices into a single-edge tally.
    classify goldSet taggSet =
      case (S.null goldSet, S.null taggSet) of
        (True, True) -> zeroStats {tn = 1}
        (True, False) -> zeroStats {fp = 1}
        (False, True) -> zeroStats {fn = 1}
        (False, False)
          | consistent goldSet taggSet -> zeroStats {tp = 1}
          | otherwise -> zeroStats {fp = 1, fn = 1, ce = 1}

    -- Weak accuracy only requires a common element; otherwise the sets
    -- have to be equal.
    consistent xs ys =
      if weakAcc cfg
        then not (S.null (S.intersection xs ys))
        else xs == ys
-- | Apply 'goodAndBad' to corresponding sentences of the two datasets and
-- sum the results.  Sentences are paired positionally; if the lists differ
-- in length, the surplus elements of the longer one are ignored.
goodAndBad'
  :: (Word w, Ord x, Show x)
  => AccCfg x
  -> [Sent w (P.Tag, x)]
  -> [Sent w (P.Tag, x)]
  -> Stats
goodAndBad' cfg goldData taggData =
  F.foldl' addStats zeroStats $
    zipWith (goodAndBad cfg) goldData taggData
-- | Compute the statistics for a whole dataset.
--
-- The sentence pairs are distributed with 'partition' over
-- 'numCapabilities' groups, each group is scored with 'goodAndBad'' via
-- 'Par.parMap', and the per-group results are summed.
collect
  :: (Word w, Ord x, Show x)
  => AccCfg x
  -> [Sent w (P.Tag, x)]
  -> [Sent w (P.Tag, x)]
  -> Stats
collect cfg goldData taggData =
  F.foldl' addStats zeroStats partial
  where
    pairs = zip goldData taggData
    groups = partition numCapabilities pairs
    -- Split a group back into its gold and tagged halves and score it.
    score grp =
      let (golds, taggs) = unzip grp
      in goodAndBad' cfg golds taggs
    partial = Par.parMap Par.rseq score groups
-- | Precision: @tp / (tp + fp)@.  The result is NaN when both counters are
-- zero.
precision :: Stats -> Double
precision stats = fromIntegral hits / fromIntegral predicted
  where
    hits = tp stats
    predicted = tp stats + fp stats
-- | Recall: @tp / (tp + fn)@.  The result is NaN when both counters are
-- zero.
recall :: Stats -> Double
recall stats = fromIntegral hits / fromIntegral expected
  where
    hits = tp stats
    expected = tp stats + fn stats
-- | Accuracy: @(tp + tn) / (tp + fp + tn + fn - ce)@.
--
-- 'ce' is subtracted from the denominator because an edge counted in 'ce'
-- by 'goodAndBad' is also counted in both 'fp' and 'fn'; the subtraction
-- makes such an edge count once.  The result is NaN when numerator and
-- denominator are both zero.
accuracy :: Stats -> Double
accuracy stats = fromIntegral correct / fromIntegral total
  where
    correct = tp stats + tn stats
    total = tp stats + fp stats + tn stats + fn stats - ce stats
-- | Total weight of a sentence DAG.
--
-- The weight of an edge is the sum of the weights of all its tags.  The
-- weight of a path is the product of the weights of its edges.  The result
-- is the sum of the weights of all paths that start with an initial edge
-- and end in a node without outgoing edges.
dagProb :: Sent w t -> Double
dagProb dag =
  sum . map pathMass . filter (`DAG.isInitialEdge` dag) $ DAG.dagEdges dag
  where
    -- Weight of all paths beginning with the given edge.
    pathMass e = labelMass e * continuation (DAG.endsWith e dag)
    -- Sum of the tag weights stored on the given edge.
    labelMass e = sum . M.elems . unWMap . tags $ DAG.edgeLabel e dag
    -- Weight of all paths leaving the given node; 1 for a final node.
    continuation v =
      case DAG.outgoingEdges v dag of
        [] -> 1
        es -> sum (map pathMass es)
-- | The set of interpretations chosen for a segment: the highest-weighted
-- ones according to 'best', post-processed as the configuration requests.
--
--   * With 'ignoreTag', every tag is replaced by a fixed dummy tag, so only
--     the markers remain distinguishable.
--   * Otherwise, with 'expandTag', every tag is replaced by all the tags
--     'P.expand' produces for it under 'accTagset'.
--   * Otherwise the interpretations are left unchanged.
choice :: (Ord x) => AccCfg x -> Seg w (P.Tag, x) -> S.Set (P.Tag, x)
choice cfg seg = S.fromList (adjust (best seg))
  where
    adjust pairs
      | ignoreTag cfg = map ((dummyTag,) . snd) pairs
      | expandTag cfg =
          [ (tag', x)
          | (tag, x) <- pairs
          , tag' <- P.expand (accTagset cfg) tag
          ]
      | otherwise = pairs
    dummyTag = P.Tag "AmbiSeg" M.empty
-- | The highest-weighted tags of a segment.
--
-- Returns every tag whose weight is within @1.0e-9@ of the maximum weight.
-- The result is empty when the segment has no tags or when the maximum
-- weight itself is below that threshold.
best :: Seg w t -> [t]
best seg =
  case weighted of
    [] -> []
    _ | top < eps -> []
      | otherwise -> [ t | (t, p) <- weighted, p >= top - eps ]
  where
    weighted = M.toList . unWMap . tags $ seg
    -- Only demanded when 'weighted' is non-empty.
    top = maximum (map snd weighted)
    eps = 1.0e-9
-- | Distribute the elements of a list over at most @n@ sublists in a
-- round-robin fashion: the element at position @i@ ends up in sublist
-- @i \`mod\` n@, and the relative order within each sublist is preserved.
-- An empty input yields an empty list of sublists.
--
-- A non-positive @n@ is treated as 1.  Without this clamp the input would
-- be cut into an infinite sequence of empty chunks and the function would
-- never terminate on a non-empty list.
partition :: Int -> [a] -> [[a]]
partition n = transpose . chunks
  where
    width = max 1 n
    -- Cut the list into consecutive chunks of 'width' elements; column @i@
    -- of the transposed chunks is then sublist @i@.
    chunks [] = []
    chunks xs =
      let (chunk, rest) = splitAt width xs
      in chunk : chunks rest
-- | Logical implication.  The consequent is only evaluated when the
-- antecedent holds.
implies :: Bool -> Bool -> Bool
implies True q = q
implies False _ = True