{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE TupleSections #-} module NLP.Concraft.DAG.Guess ( -- * Types Guesser (..) , putGuesser , getGuesser -- * Marginals , marginals , marginalsSent -- -- * Guessing -- , guess -- , include -- , guessSent -- * Training , TrainConf (..) , R0T (..) , train -- * Utils , schemed ) where import Prelude hiding (words) import Control.Applicative ((<$>), (<*>), pure) import Data.Binary (Binary, put, get, Put, Get) import Data.Text.Binary () import System.Console.CmdArgs import qualified Data.Set as S import qualified Data.Map.Strict as M -- import qualified Data.Vector as V import qualified Data.DAG as DAG import Data.DAG (DAG) import qualified Control.Monad.Ox as Ox -- import qualified Data.CRF.Chain1.Constrained as CRF import qualified Data.CRF.Chain1.Constrained.DAG as CRF -- import qualified Data.CRF.Chain1.Constrained.Dataset.External as CRF.Ext import qualified Numeric.SGD.Momentum as SGD -- import NLP.Concraft.Schema hiding (schematize) -- import qualified NLP.Concraft.Morphosyntax as X import NLP.Concraft.DAG.Schema hiding (schematize) import qualified NLP.Concraft.DAG.Morphosyntax as X -- | A guessing model. data Guesser t s = Guesser { schemaConf :: SchemaConf , crf :: CRF.CRF Ob s , zeroProbLab :: s , unkTagSet :: S.Set t -- ^ The tagset considered for the unknown words (TODO: a solution -- parallel and not 100% consistent with what is implemented in the CRF -- library) , simplify :: t -> s -- ^ A tag simplification function } -- instance (Ord t, Binary t, Ord s, Binary s) => Binary (Guesser t s) where -- put Guesser{..} = do -- put schemaConf -- put crf -- put zeroProbLab -- put simpliMap -- get = Guesser <$> get <*> get <*> get <*> get -- | Store the entire guessing model apart from the simplification function. putGuesser :: (Binary t, Binary s, Ord s) => Guesser t s -> Put putGuesser Guesser{..} = do put schemaConf put crf put zeroProbLab put unkTagSet -- | Get the disambiguation model, provided the simplification function. -- getGuesser :: (M.Map t T.Tag) -> Get (Guesser t) getGuesser :: (Binary t, Binary s, Ord s) => (t -> s) -> Get (Guesser t s) getGuesser smp = Guesser <$> get <*> get <*> get <*> get <*> pure smp -- -------------------------- -- -- Simplify -- -------------------------- -- -- -- -- | Simplify the given label. -- simplify :: (Ord t) => Guesser t s -> t -> s -- simplify Guesser{..} x = -- case M.lookup x simpliMap of -- Nothing -> zeroProbLab -- Just y -> y -------------------------- -- Schematize -------------------------- -- | Schematize the input sentence according to the 'schema' rules. -- TODO: looks like there is no reason at all for `Schema w t a` to -- be parametrized with `t`? schematize :: (X.Word w) => Schema w t a -> X.Sent w t -> CRF.Sent Ob t schematize schema sent = DAG.mapE f sent where f i = const $ CRF.Word (obs i) (lbs i) obs = S.fromList . Ox.execOx . schema sent lbs i | X.oov w = S.empty | otherwise = X.interpsSet w where w = DAG.edgeLabel i sent -------------------------- -- Marginals -------------------------- -- | Determine the marginal probabilities of the individual labels in the sentence. marginals :: (X.Word w, Ord t, Ord s) => CRF.Config s -> Guesser t s -> X.Sent w t -> DAG () (X.WMap t) marginals cfg gsr = fmap X.tags . marginalsSent cfg gsr -- | Replace the probabilities of the sentence labels with the marginal probabilities -- stemming from the model. marginalsSent :: (X.Word w, Ord t, Ord s) => CRF.Config s -> Guesser t s -> X.Sent w t -> X.Sent w t marginalsSent cfg gsr sent = (\new -> inject gsr new sent) . fmap tags . marginalsCRF cfg gsr $ sent where tags = X.mkWMap . M.toList . considerZero . choice -- we mix the chosen and the potential interpretations together choice w = M.unionWith (+) (CRF.unProb . CRF.choice $ w) (M.fromList . map (,0) . interps $ w) interps = S.toList . CRF.lbs . CRF.word -- if empty, we choose the zero probability label. considerZero m | M.null m = M.singleton (zeroProbLab gsr) 0 | otherwise = m -- | Ascertain the marginal probabilities of to individual labels in the sentence. marginalsCRF :: (X.Word w, Ord t, Ord s) => CRF.Config s -> Guesser t s -> X.Sent w t -> CRF.SentL Ob s marginalsCRF cfg gsr dag0 = let schema = fromConf (schemaConf gsr) dag = X.mapSent (simplify gsr) dag0 in CRF.marginals cfg (crf gsr) (schematize schema dag) -- -- | Replace the probabilities of the sentence labels with the new probabilities -- -- stemming from the CRF sentence. -- inject :: DAG () (X.WMap t) -> X.Sent w t -> X.Sent w t -- inject newSent srcSent = -- let doit (new, src) = src {X.tags = new} -- in fmap doit (DAG.zipE newSent srcSent) -- | Replace the probabilities of the sentence labels with the new probabilities -- stemming from the CRF sentence. -- -- TODO: The behavior for OOV words seems unoptimal, since all possible labels -- are taken into account, and not only the default CRF ones. Still, it's not -- necessarily a problem, maybe not even from the speed point of view. -- inject :: (Ord t, Ord s, X.Word w) => Guesser t s -> DAG () (X.WMap s) -> X.Sent w t -> X.Sent w t inject gsr newSent srcSent = let doit (target, src) = let oldTags = if X.oov (X.word src) then X.mkWMap . map (,0) . S.toList . unkTagSet $ gsr else X.tags src newTags = injectWMap gsr target oldTags in src {X.tags = newTags} in fmap doit (DAG.zipE newSent srcSent) -- -- | Replace label probabilities with the new probabilities. -- inject -- :: (Ord t, Ord s) -- => Guesser t s -- -> DAG () (X.WMap s) -- -> DAG () (X.WMap t) -- -> DAG () (X.WMap t) -- inject gsr newDat srcDag = -- let doit (newSpl, src) = injectWMap gsr newSpl src -- in fmap doit (DAG.zipE newDat srcDag) -- | Replace label probabilities with the new probabilities. injectWMap :: (Ord t, Ord s) => Guesser t s -> X.WMap s -> X.WMap t -> X.WMap t injectWMap gsr newSpl src = X.mkWMap [ ( tag , maybe 0 id $ M.lookup (simplify gsr tag) (X.unWMap newSpl) ) | (tag, _) <- M.toList (X.unWMap src) ] -- -------------------------- -- -- ??? -- -------------------------- -- -- -- -- -- -- | Determine the 'k' most probable labels for each word in the sentence. -- -- guess :: (X.Word w, Ord t) -- -- => Int -> Guesser t -> X.Sent w t -> [[t]] -- -- guess k gsr sent = -- -- let schema = fromConf (schemaConf gsr) -- -- in CRF.tagK k (crf gsr) (schematize schema sent) -- -- -- -- -- | Insert guessing results into the sentence. Only interpretations -- -- -- of OOV words will be extended. -- -- include :: (X.Word w, Ord t) => [[t]] -> X.Sent w t -> X.Sent w t -- -- include xss sent = -- -- [ word { X.tags = tags } -- -- | (word, tags) <- zip sent sentTags ] -- -- where -- -- sentTags = -- -- [ if X.oov word -- -- then addInterps (X.tags word) xs -- -- else X.tags word -- -- | (xs, word) <- zip xss sent ] -- -- addInterps wm xs = X.mkWMap -- -- $ M.toList (X.unWMap wm) -- -- ++ zip xs [0, 0 ..] -- -- -- -- -- -- -- | Combine `guess` with `include`. -- -- guessSent :: (X.Word w, Ord t) -- -- => Int -> Guesser t -- -- -> X.Sent w t -> X.Sent w t -- -- guessSent guessNum guesser sent = -- -- include (guess guessNum guesser sent) sent -- -- -------------------------- -- Training -------------------------- -- | Method of constructing the default set of labels (R0). data R0T = AnyInterps -- ^ See `CRF.anyInterps` | AnyChosen -- ^ See `CRF.anyChosen` | OovChosen -- ^ See `CRF.oovChosen` deriving (Show, Eq, Ord, Enum, Typeable, Data) -- | Training configuration. data TrainConf t s = TrainConf { schemaConfT :: SchemaConf -- | SGD parameters. , sgdArgsT :: SGD.SgdArgs -- | Store SGD dataset on disk , onDiskT :: Bool -- | R0 construction method , r0T :: R0T -- | Zero probability label , zeroProbLabel :: t -- | Label simplification function , simplifyLabel :: t -> s -- | Strip the label from irrelevant information. Used to determine othe set of -- possible tags for unknown words. , stripLabel :: t -> t -- | Guess only visible features , onlyVisible :: Bool } -- | Train guesser. train :: (X.Word w, Ord t, Ord s) => TrainConf t s -- ^ Training configuration -> IO [X.Sent w t] -- ^ Training data -> IO [X.Sent w t] -- ^ Evaluation data -> IO (Guesser t s) train TrainConf{..} trainData evalData = do let schema = fromConf schemaConfT mkR0 = case r0T of AnyInterps -> CRF.anyInterps AnyChosen -> CRF.anyChosen OovChosen -> CRF.oovChosen featExtract = if onlyVisible then const CRF.presentFeats else \r0 -> CRF.hiddenFeats r0 . map (fmap fst) tagSet <- S.unions . map (tagSetIn stripLabel) <$> trainData -- let tagMap = M.fromList -- [ (t, simplifyLabel t) -- | t <- S.toList tagSet ] crf <- CRF.train sgdArgsT onDiskT -- mkR0 (const CRF.presentFeats) -- mkR0 (\r0 -> CRF.hiddenFeats r0 . map (fmap fst)) mkR0 featExtract (schemed simplifyLabel schema <$> trainData) (schemed simplifyLabel schema <$> evalData) return $ Guesser schemaConfT crf (simplifyLabel zeroProbLabel) tagSet simplifyLabel -- | Schematized dataset. schemed :: (X.Word w, Ord t, Ord s) => (t -> s) -> Schema w s a -> [X.Sent w t] -> [CRF.SentL Ob s] schemed simpl schema = map onSent where onSent dag0 = let dag = X.mapSent simpl dag0 mkProb = CRF.mkProb . M.toList . X.unWMap . X.tags in fmap (uncurry CRF.mkWordL) $ DAG.zipE (schematize schema dag) (fmap mkProb dag) -- | Retrieve the tagset in the given sentence, provided the stripping function -- (see `stripLabel`). tagSetIn :: (Ord t) => (t -> t) -> X.Sent w t -> S.Set t tagSetIn strip dag = S.fromList [ strip tag | edgeID <- DAG.dagEdges dag , let edge = DAG.edgeLabel edgeID dag , tag <- M.keys . X.unWMap . X.tags $ edge ]