{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}

-- | Command-line driver for Colada word class learning.
--
-- Subcommands: @learn@, @predict@, @label@, @summary@ and @help@. The
-- @learn@, @predict@ and @label@ subcommands parse standard input with
-- 'CoNLL.parse'; models are written and read with "Data.Serialize".
module Main where

-- Composition comes from Control.Category so that the same operator
-- composes both plain functions and fclabels labels (@C.featIds . options@).
import Control.Category ((.))
import qualified Data.List as List
import Data.Maybe (fromMaybe)
import Prelude hiding ((.))
import qualified System.Environment as Env
import qualified Text.Printf as Printf

import qualified Data.ByteString as BS
import qualified Data.Serialize as Serialize
import qualified Data.Text.Lazy as Text
import qualified Data.Text.Lazy.Builder as Text
import qualified Data.Text.Lazy.Builder.Int as Text
import qualified Data.Text.Lazy.IO as Text
import qualified Data.Vector.Generic as V
import qualified Data.Vector.Unboxed as U

import qualified Data.Label as L
import qualified Data.Label.Maybe as M
import System.Console.CmdArgs.Explicit

import qualified Colada.WordClass as C
import qualified NLP.CoNLL as CoNLL

-- Command line parsing

-- | A parsed command line: the selected subcommand and its settings.
--
-- * 'Learn': learner options plus the path the model is written to.
-- * 'Predict': @_topn@ is copied into the loaded model's 'C.topn' option.
-- * 'Label': @_noContext@ is passed through to 'C.label'.
-- * 'Summary': @_harden@ is passed through to 'C.summarize'.
data Program
  = Help
  | Learn   { _options :: C.Options, _modelPath :: FilePath }
  | Predict { _topn :: Int, _modelPath :: FilePath }
  | Label   { _modelPath :: FilePath, _noContext :: Bool }
  | Summary { _modelPath :: FilePath, _harden :: Bool }
  deriving (Show)

-- Generate the labels 'options', 'modelPath', 'topn', 'noContext' and
-- 'harden' for the underscore-prefixed fields above. Everything that uses
-- them must come after this splice.
$(L.mkLabels [''Program])

-- | Apply a setter that may not apply to its argument. A typical case is a
-- partial label whose field is absent from the current constructor. When the
-- setter returns 'Nothing', the value is returned unchanged.
orKeep :: (a -> Maybe a) -> a -> a
orKeep setter x = fromMaybe x (setter x)

-- | Positional argument shared by every mode except @help@: the model file.
modelArg :: Arg Program
modelArg = flagArg (\path p -> Right (orKeep (M.set modelPath path) p)) "FILE"

-- | @help@ mode: rejects any positional argument.
help :: Mode Program
help =
  mode "help" Help "Display help"
    (flagArg (\x _ -> Left $ "Unexpected argument " ++ x) "")
    []

-- | @predict@ mode: model file plus an optional @--topn@ limit
-- (defaults to 'maxBound').
predict :: Mode Program
predict =
  mode "predict" Predict { _topn = maxBound, _modelPath = "model" }
    "Predict words"
    modelArg
    [ flagReq ["topn"] setTopn "NAT" "Number of most probable words to show" ]
  where
    setTopn x p = fmap (\n -> orKeep (M.set topn n) p) (safeRead x)

-- | @summary@ mode: model file plus the @--harden@ switch.
summary :: Mode Program
summary =
  mode "summary" Summary { _modelPath = "model", _harden = False }
    "Display summary of word classes"
    modelArg
    [ flagNone ["harden"] (orKeep (M.set harden True))
        "Harden class assignments for summary"
    ]

-- | @label@ mode: model file plus the @--no-context@ switch.
label :: Mode Program
label =
  mode "label" Label { _modelPath = "model", _noContext = False }
    "Label words with classes"
    modelArg
    [ flagNone ["no-context"] (orKeep (M.set noContext True))
        "Ignore context while labeling"
    ]

-- | @learn@ mode: output model file plus one flag per learner option,
-- starting from 'C.defaultOptions'.
learn :: Mode Program
learn =
  mode "learn" Learn { _options = C.defaultOptions, _modelPath = "model" }
    "Learn word classes"
    modelArg
    [ flagReq ["features"] setFeatures "(unigram|bigram)" "Feature specification"
    , flagReq ["topic-num"] (setOption C.topicNum) "NAT" "Number of topics K"
    , flagReq ["alphasum"] (setOption C.alphasum) "FLOAT" "Parameter alpha * K"
    , flagReq ["beta"] (setOption C.beta) "FLOAT" "Parameter beta"
    , flagReq ["passes"] (setOption C.passes) "NAT" "Passes per batch"
    , flagReq ["repeats"] (setOption C.repeats) "NAT" "Repeats per sentence"
    , flagReq ["batch-size"] (setOption C.batchSize) "NAT" "Sentences per batch"
    , flagReq ["seed"] (setOption C.seed) "NAT" "Random seed"
    , flagNone ["progressive"] (orKeep (M.set (C.progressive . options) True))
        "Label progressively"
    , flagReq ["lambda"] (setOption C.lambda) "FLOAT"
        "Interpolation parameter for progressive labeling"
    , flagReq ["init-size"] (setOption C.initSize) "NAT"
        "Data prefix size for batch initialization"
    , flagReq ["init-passes"] (setOption C.initPasses) "NAT"
        "Number of passes for initialization"
    -- The exponent field is stored wrapped in 'Just'.
    , flagReq ["exponent"] (setOptionWith Just C.exponent) "FLOAT"
        "Exponent for learning rate"
    ]
  where
    -- Parse the flag value with 'safeRead', post-process it with @f@, and
    -- store it in @field@ of the learner options.
    setOptionWith f field x p =
      fmap (\v -> orKeep (M.set (field . options) (f v)) p) (safeRead x)

    -- Same as 'setOptionWith', storing the parsed value as-is.
    setOption field = setOptionWith id field

    -- Map the symbolic feature names to the feature id lists they stand for.
    setFeatures x p = case x of
      "unigram" -> Right (orKeep (M.set (C.featIds . options) [-1, 1]) p)
      "bigram"  -> Right (orKeep (M.set (C.featIds . options) [-12, 12]) p)
      _         -> Left $ "Unknown feature specification " ++ x

-- | The top-level mode; with no subcommand it yields 'Help'.
program :: Mode Program
program =
  modes "colada" Help "Word class learning" [learn, predict, label, summary, help]

-- Run the program

-- | Parse the command line and dispatch to the selected subcommand.
main :: IO ()
main = do
  args <- Env.getArgs
  case processValue program args of
    Help ->
      print $ helpText [] HelpFormatDefault program

    Predict { _topn = n, _modelPath = p } -> do
      -- FIXME: use Data.Text.Builder instead of converting to Lists
      -- Each element of @s@ is rendered as one line: the second components
      -- of its entries, joined with commas.
      let format s = {-# SCC "format" #-} Text.unlines
            [ Text.concat . List.intersperse "," . map snd . V.toList $ ws
            | ws <- s ]
      m <- L.set (C.topn . C.options) n `fmap` parseModel p
      ss <- CoNLL.parse `fmap` Text.getContents
      Text.putStr . Text.unlines . map (format . C.predict m) $ ss

    Label { _modelPath = p, _noContext = noctx } -> do
      m <- parseModel p
      ss <- CoNLL.parse `fmap` Text.getContents
      -- Each token's label is the index of the largest entry in the vector
      -- 'C.label' returns for it (assumed to be per-class scores).
      Text.putStr . Text.unlines
        . map (formatLabeling . V.map V.maxIndex . C.label noctx m) $ ss

    Learn { _options = o, _modelPath = p } -> do
      ss <- CoNLL.parse `fmap` Text.getContents
      let (m, ls) = C.learn o ss
      -- Progressive mode prints the per-sentence labelings produced while
      -- learning; otherwise a summary of the final model is printed.
      if L.get C.progressive o
        then Text.putStr . Text.unlines . map formatFullLabeling $ ls
        else Text.putStr . C.summary $ m
      BS.writeFile p . Serialize.encode $ m

    Summary { _modelPath = p, _harden = h } -> do
      m <- parseModel p
      Text.putStr (C.summarize h m)

-- | Render a vector of class indices, one decimal index per line.
formatLabeling :: (V.Vector v Int, V.Vector v Text.Text) => v Int -> Text.Text
formatLabeling = Text.unlines . V.toList . V.map (Text.toLazyText . Text.decimal)

-- | Render a sequence of unboxed numeric vectors, one line per vector. Each
-- component is printed with three decimals, and components are separated by
-- single spaces.
formatFullLabeling
  :: (V.Vector v (U.Vector a), U.Unbox a, Printf.PrintfArg a)
  => v (U.Vector a) -> Text.Text
formatFullLabeling =
  Text.unlines
    . map (Text.unwords . map (Text.pack . Printf.printf "%.3f") . U.toList)
    . V.toList

-- | Read and decode a serialized model from the given file.
--
-- A decoding failure is raised as an 'IOError' when this action runs. It is
-- not deferred until the model is first used, so a bad model file is
-- reported before any input is processed. The message names the file.
parseModel :: FilePath -> IO C.WordClass
parseModel p = do
  bytes <- BS.readFile p
  case Serialize.decode bytes of
    Left err -> ioError (userError ("Error reading model " ++ p ++ ": " ++ err))
    Right m  -> return m

-- | Parse a command-line value with 'reads'. The value is accepted only when
-- there is exactly one parse and it consumes the whole string.
safeRead :: Read b => String -> Either String b
safeRead x = case reads x of
  [(a, "")] -> Right a
  _         -> Left $ "Couldn't parse " ++ show x