{-# LANGUAGE RecordWildCards #-}
module NLP.Concraft.DAG2
(
Concraft (..)
, saveModel
, loadModel
, Anno
, replace
, findOptimalPaths
, disambPath
, guessMarginals
, disambMarginals
, disambProbs
, guess
, guessSent
, tag
, train
, prune
) where
import System.IO (hClose)
import Control.Applicative ((<$>), (<*>))
import Control.Arrow (first)
import Control.Monad (when, guard)
import qualified Data.Foldable as F
import qualified Data.Set as S
import qualified Data.Map.Strict as M
import Data.Binary (Binary, put, get, Put, Get)
import qualified Data.Binary as Binary
import Data.Binary.Put (runPut)
import Data.Binary.Get (runGet)
import Data.Aeson
import qualified System.IO.Temp as Temp
import qualified Data.ByteString.Lazy as BL
import qualified Codec.Compression.GZip as GZip
import Data.DAG (DAG, EdgeID)
import qualified Data.DAG as DAG
import qualified Data.Tagset.Positional as P
import NLP.Concraft.Format.Temp
import qualified NLP.Concraft.DAG.Morphosyntax as X
import NLP.Concraft.DAG.Morphosyntax (Sent, WMap)
import qualified NLP.Concraft.DAG.Guess as G
import qualified NLP.Concraft.DAG.Disamb as D
-- | Version tag of the serialized model format.  'putModel' writes it as the
-- first item of the stream and 'getModel' refuses to decode a stream whose
-- leading tag differs from this value.
modelVersion :: String
modelVersion = "dag2:0.11"
-- | A complete tagging model: a guesser and a disambiguation model bundled
-- together with the tagset and the guessing parameter they were built with.
data Concraft t = Concraft
  { tagset :: P.Tagset
    -- ^ Tagset stored alongside the model; on load it is handed to the
    -- tag-mapping function supplied to 'getModel'.
  , guessNum :: Int
    -- ^ Parameter given to 'guessSent' when 'train' prepares the data for
    -- the disambiguation model.
  , guesser :: G.Guesser t P.Tag
    -- ^ The guessing model.
  , disamb :: D.Disamb t
    -- ^ The disambiguation model.
  }
-- | Serialize a model.  The fields are written in a fixed order -- version
-- tag, tagset, guessing parameter, guesser, disambiguation model -- which is
-- the order in which 'getModel' reads them back.
putModel :: (Ord t, Binary t) => Concraft t -> Put
putModel model = do
  put modelVersion
  put (tagset model)
  put (guessNum model)
  G.putGuesser (guesser model)
  D.putDisamb (disamb model)
-- | Deserialize a model written by 'putModel'.
--
-- The stream must start with a version tag equal to 'modelVersion';
-- otherwise decoding fails with a message naming both versions.  The failure
-- is reported through the 'Get' monad's own 'fail', so it surfaces as an
-- ordinary decoding error (with 'runGet' it is still raised as an exception).
getModel
  :: (Ord t, Binary t)
  => (P.Tagset -> t -> P.Tag)
     -- ^ Tag-mapping function; it is applied to the decoded tagset and the
     -- result is passed to both the guesser and the disamb decoders.
  -> Get (Concraft t)
getModel smp = do
  comp <- get
  when (comp /= modelVersion) $ fail $
    "Incompatible model version: " ++ comp ++
    ", expected: " ++ modelVersion
  tagset <- get
  -- Remaining fields, in the order in which 'putModel' wrote them.
  Concraft tagset <$> get <*> G.getGuesser (smp tagset) <*> D.getDisamb (smp tagset)
-- | Write a model to the given file: the model is serialized with
-- 'putModel' and the resulting bytes are gzip-compressed.
saveModel :: (Ord t, Binary t) => FilePath -> Concraft t -> IO ()
saveModel path concraft =
  BL.writeFile path compressed
  where
    compressed = GZip.compress serialized
    serialized = runPut (putModel concraft)
-- | Read a model from a gzip-compressed file produced by 'saveModel'.
--
-- The first argument is the tag-mapping function forwarded to 'getModel'.
-- The decoded model is forced to weak head normal form before it is
-- returned.
loadModel :: (Ord t, Binary t) => (P.Tagset -> t -> P.Tag) -> FilePath -> IO (Concraft t)
loadModel smp path = do
  contents <- BL.readFile path
  let model = runGet (getModel smp) (GZip.decompress contents)
  return $! model
-- | An annotation DAG: every edge is labelled with a map from keys of type
-- @a@ (tags, in this module) to values of type @b@ (e.g. probabilities or
-- Boolean markers).  The first 'DAG' parameter is fixed to @()@ --
-- presumably the node label type; confirm against "Data.DAG".
type Anno a b = DAG () (M.Map a b)
-- | Overwrite the tags of every sentence edge with the corresponding map
-- from the annotation DAG.  Annotation and sentence are paired up with
-- 'DAG.zipE' (presumably they must share the same structure -- confirm
-- against "Data.DAG").
replace :: (Ord t) => Anno t Double -> Sent w t -> Sent w t
replace anno sent =
  overwrite <$> DAG.zipE anno sent
  where
    overwrite (weights, seg) = seg { X.tags = X.fromMap weights }
-- | Project a sentence onto its annotation: each edge is replaced by the
-- plain map underlying its tags.
extract :: Sent w t -> Anno t Double
extract = fmap (\seg -> X.unWMap (X.tags seg))
-- | Enumerate the paths through the DAG that consist solely of
-- (edge, tag) pairs whose weight is at least @1 - eps@.
--
-- A path starts on an initial edge and ends on a final edge; at every edge
-- one tag satisfying the weight condition is selected.  All combinations
-- are returned.
findOptimalPaths :: Anno t Double -> [[(EdgeID, t)]]
findOptimalPaths dag =
  concat
    [ pathsFrom start
    | start <- DAG.dagEdges dag
    , DAG.isInitialEdge start dag
    ]
  where
    -- All qualifying paths beginning at the given edge: first the ones
    -- continuing through a successor edge, then the ones ending here.
    pathsFrom e = continued e ++ finished e
    continued e =
      [ (e, t) : rest
      | (t, w) <- M.toList (DAG.edgeLabel e dag)
      , w >= threshold
      , next <- DAG.nextEdges e dag
      , rest <- pathsFrom next
      ]
    finished e =
      [ [(e, t)]
      | DAG.isFinalEdge e dag
      , (t, w) <- M.toList (DAG.edgeLabel e dag)
      , w >= threshold
      ]
    threshold = 1.0 - eps
    eps = 1.0e-9
-- | Turn a path into Boolean markers over an annotation: on every edge,
-- a tag is mapped to 'True' iff the path assigns exactly that tag to that
-- edge.  Tags on edges which the path does not mention are all 'False'.
disambPath :: (Ord t) => [(EdgeID, t)] -> Anno t Double -> Anno t Bool
disambPath path =
  DAG.mapE markEdge
  where
    chosen = M.fromList path
    markEdge eid tags =
      let picked = M.lookup eid chosen
      in  M.mapWithKey (\t _ -> Just t == picked) tags
-- | Annotate a sentence with the result of 'G.marginals' for the given
-- guesser, unwrapped to plain maps.
guessMarginals :: (X.Word w, Ord t) => G.Guesser t P.Tag -> Sent w t -> Anno t Double
guessMarginals gsr =
  let annotate = G.marginals gsr
  in  \sent -> X.unWMap <$> annotate sent
-- | 'disambProbs' specialised to the 'D.Marginals' probability type.
disambMarginals :: (X.Word w, Ord t) => D.Disamb t -> Sent w t -> Anno t Double
disambMarginals dmb = disambProbs D.Marginals dmb
-- | Annotate a sentence with the result of 'D.probs' for the given
-- probability type and disambiguation model, unwrapped to plain maps.
disambProbs :: (X.Word w, Ord t) => D.ProbType -> D.Disamb t -> Sent w t -> Anno t Double
disambProbs typ dmb =
  let annotate = D.probs typ dmb
  in  \sent -> X.unWMap <$> annotate sent
-- | Apply 'X.trim' with the given parameter to the tags of every edge
-- flagged by 'X.oov'; all other edges are left untouched.
-- (Presumably 'X.trim' keeps the @k@ most probable tags -- confirm against
-- "NLP.Concraft.DAG.Morphosyntax".)
trimOOV :: (X.Word w, Ord t) => Int -> Sent w t -> Sent w t
trimOOV k =
  fmap limit
  where
    limit seg
      | X.oov seg = seg { X.tags = X.trim k (X.tags seg) }
      | otherwise = seg
-- | Run the guesser over a sentence: the tags of each edge are replaced by
-- the guesser's marginals, after which OOV edges are trimmed with the given
-- parameter (see 'trimOOV').
guessSent :: (X.Word w, Ord t) => Int -> G.Guesser t P.Tag -> Sent w t -> Sent w t
guessSent k gsr sent =
  trimOOV k guessed
  where
    guessed = replace (guessMarginals gsr sent) sent
-- | Like 'guessSent', but return only the resulting annotation rather than
-- the whole sentence.
guess :: (X.Word w, Ord t) => Int -> G.Guesser t P.Tag -> Sent w t -> Anno t Double
guess k gsr sent = extract (guessSent k gsr sent)
-- | Tag a sentence with a full model: first guess (with the given guessing
-- parameter and the model's guesser), then compute disambiguation
-- marginals with the model's disamb component.
tag :: (X.Word w, Ord t) => Int -> Concraft t -> Sent w t -> Anno t Double
tag k crf =
  rescore . preGuess
  where
    preGuess = guessSent k (guesser crf)
    rescore = disambMarginals (disamb crf)
-- | Train a full model.
--
-- The guesser is trained first, on the raw training and evaluation data.
-- Both data sets are then re-read and passed through the freshly trained
-- guesser ('guessSent' with the given guessing parameter), and the
-- disambiguation model is trained on the guessed data.  A temporary
-- directory (template @.guessed@, created under the current directory) is
-- set up for the duration of training and handed to 'withTemp'.
train
  :: (X.Word w, Ord t)
  => P.Tagset             -- ^ Tagset stored in the resulting model
  -> Int                  -- ^ Guessing parameter (see 'guessSent')
  -> G.TrainConf t P.Tag  -- ^ Guesser training configuration
  -> D.TrainConf t        -- ^ Disamb training configuration
  -> IO [Sent w t]        -- ^ Action yielding the training data
  -> IO [Sent w t]        -- ^ Action yielding the evaluation data
  -> IO (Concraft t)
train tagset guessNum guessConf disambConf trainR'IO evalR'IO =
  Temp.withTempDirectory "." ".guessed" $ \tmpDir -> do
    putStrLn "\n===== Train guessing model ====="
    gsr <- G.train guessConf trainR'IO evalR'IO
    let reguess = guessSent guessNum gsr
        stash = withTemp tagset tmpDir
    trainG <- fmap (map reguess) trainR'IO
    evalG <- fmap (map reguess) evalR'IO
    stash "train" trainG $ \trainG'IO ->
      stash "eval" evalG $ \evalG'IO -> do
        putStrLn "\n===== Train disambiguation model ====="
        dmb <- D.train disambConf trainG'IO evalG'IO
        return (Concraft tagset guessNum gsr dmb)
-- | Run a handler with an action that yields the given sentences.
--
-- For an empty list the handler is called directly.  Otherwise a temporary
-- file is created in the given directory (from the given name template) and
-- its handle is closed at once; the handler is then called with an action
-- returning the in-memory list.  Nothing is written to or read from the
-- file, and the tagset argument is not used.
--
-- NOTE(review): the temporary file and the tagset parameter look like
-- leftovers of a disk-backed implementation.  Both are kept so that the
-- interface and the I/O side effects stay as they were; only the unused
-- local conversions were removed.
withTemp
  :: P.Tagset                   -- ^ Tagset (currently unused)
  -> FilePath                   -- ^ Directory for the temporary file
  -> String                     -- ^ File name template
  -> [Sent w t]                 -- ^ Sentences to hand over
  -> (IO [Sent w t] -> IO a)    -- ^ Handler
  -> IO a
withTemp _ _ _ [] handler = handler (return [])
withTemp _tagset dir tmpl xs handler =
  Temp.withTempFile dir tmpl $ \_tmpPath tmpHandle -> do
    hClose tmpHandle
    handler (return xs)
-- | Apply 'D.prune' with the given parameter to the disambiguation model;
-- the remaining components of the model are left unchanged.
prune :: Double -> Concraft t -> Concraft t
prune x concraft = concraft { disamb = D.prune x (disamb concraft) }