{-# LANGUAGE FlexibleInstances
           , DeriveDataTypeable
           , TemplateHaskell
           , OverloadedStrings
           , NoMonomorphismRestriction
           , FlexibleContexts #-}
module Main where

import qualified Data.Text.Lazy.IO          as Text
import qualified Data.Text.Lazy             as Text
import qualified Data.Text.Lazy.Builder     as Text
import qualified Data.Text.Lazy.Builder.Int as Text
import qualified Data.ByteString            as BS
import qualified Data.Serialize             as Serialize
import qualified Data.List                  as List
import qualified Data.Vector.Generic        as V
import qualified System.Environment         as Env
import           System.Console.CmdArgs.Explicit
import qualified Data.Label                 as L
import qualified Data.Label.Maybe           as M
import           Prelude hiding ((.))
import           Control.Category ((.))
import qualified NLP.CoNLL                  as CoNLL
import qualified Colada.WordClass           as C

-- Command line parsing

data Program =
    Help
  | Learn   { _options   :: C.Options
            , _modelPath :: FilePath }
  | Predict { _topn      :: Int
            , _modelPath :: FilePath }
  | Label   { _modelPath :: FilePath
            , _noContext :: Bool }
  deriving (Show)

$(L.mkLabels [''Program])

help :: Mode Program
help = mode "help" Help "Display help"
       (flagArg (\x _ -> Left $ "Unexpected argument " ++ x) "")
       []

predict :: Mode Program
predict = mode "predict" Predict { _topn = maxBound, _modelPath = "model" }
          "Predict words"
          (flagArg (\x p -> Right $ maybe p id (M.set modelPath x p)) "FILE")
          [ flagReq ["topn"]
            (\x p -> case safeRead x of
                       Right n  -> Right $ maybe p id (M.set topn n p)
                       Left err -> Left err)
            "NAT"
            "Number of most probable words to show"
          ]

label :: Mode Program
label = mode "label" Label { _modelPath = "model", _noContext = False }
        "Label words with classes"
        (flagArg (\x p -> Right $ maybe p id (M.set modelPath x p)) "FILE")
        [ flagNone ["no-context"] (\p -> p { _noContext = True })
          "Ignore context while labeling"
        ]

learn :: Mode Program
learn =
  let setOption = setOptionWith id
      -- Parse the flag argument with safeRead, apply f, and store the result
      -- in the given options field; parse failures propagate as Left.
      setOptionWith f field x p =
          fmap (maybe p id . flip (M.set (field . options)) p)
        . fmap f
        . safeRead
        $ x
  in mode "learn" Learn { _options = C.defaultOptions, _modelPath = "model" }
     "Learn word classes"
     (flagArg (\x p -> Right $ maybe p id (M.set modelPath x p)) "FILE")
     [ flagReq ["features"]
       (\x p -> case x of
           "unigram" -> Right . maybe p id $ M.set (C.featIds . options) [-1,1] p
           "bigram"  -> Right . maybe p id $ M.set (C.featIds . options) [-12,12] p
           _         -> Left $ "Unknown feature specification " ++ x)
       "(unigram|bigram)"
       "Feature specification"
     , flagReq ["topic-num"]   (setOption C.topicNum)  "NAT"   "Number of topics K"
     , flagReq ["alphasum"]    (setOption C.alphasum)  "FLOAT" "Parameter alpha * K"
     , flagReq ["beta"]        (setOption C.beta)      "FLOAT" "Parameter beta"
     , flagReq ["passes"]      (setOption C.passes)    "NAT"   "Passes per batch"
     , flagReq ["repeats"]     (setOption C.repeats)   "NAT"   "Repeats per sentence"
     , flagReq ["batch-size"]  (setOption C.batchSize) "NAT"   "Sentences per batch"
     , flagReq ["seed"]        (setOption C.seed)      "NAT"   "Random seed"
     , flagNone ["progressive"]
       (\p -> maybe p id . M.set (C.progressive . options) True $ p)
       "Label progressively"
     , flagReq ["lambda"]      (setOption C.lambda)    "FLOAT"
       "Interpolation parameter for progressive labeling"
     , flagReq ["init-size"]   (setOption C.initSize)  "NAT"
       "Data prefix size for batch initialization"
     , flagReq ["init-passes"] (setOption C.initPasses) "NAT"
       "Number of passes for initialization"
     , flagReq ["exponent"]    (setOptionWith Just C.exponent) "FLOAT"
       "Exponent for learning rate"
     ]

program :: Mode Program
program = modes "colada" Help "Word class learning" [learn, predict, label, help]

-- Run the program

main :: IO ()
main = do
  args <- Env.getArgs
  let opts = processValue program args
  case opts of

    Help -> print $ helpText [] HelpFormatDefault program

    Predict { _topn = n, _modelPath = p } -> do
      -- FIXME: use Data.Text.Builder instead of converting to Lists
      let format s = {-# SCC "format" #-}
            Text.unlines [ Text.concat . List.intersperse "," . map snd . V.toList $ ws
                         | ws <- s ]
      m  <- L.set (C.topn . C.options) n `fmap` parseModel p
      ss <- CoNLL.parse `fmap` Text.getContents
      Text.putStr . Text.unlines . map (format . C.predict m) $ ss

    Label { _modelPath = p, _noContext = noctx } -> do
      m  <- parseModel p
      ss <- CoNLL.parse `fmap` Text.getContents
      Text.putStr . Text.unlines
        . map (formatLabeling . V.map V.maxIndex . C.label noctx m)
        $ ss

    Learn { _options = o, _modelPath = p } -> do
      ss <- CoNLL.parse `fmap` Text.getContents
      let (m, ls) = C.learn o ss
      if L.get C.progressive o
        then Text.putStr . Text.unlines . map formatLabeling $ ls
        else do
          Text.putStr . C.summary $ m
          BS.writeFile p . Serialize.encode $ m

-- | Render a vector of class indices, one decimal index per line.
formatLabeling :: (V.Vector v Int, V.Vector v Text.Text) => v Int -> Text.Text
formatLabeling = Text.unlines . V.toList . V.map (Text.toLazyText . Text.decimal)

-- | Read and decode a serialized model, failing loudly on a decoding error.
parseModel :: FilePath -> IO C.WordClass
parseModel p =
  (either (\err -> error $ "Error reading model " ++ err) id . Serialize.decode)
  `fmap` BS.readFile p

-- | Like 'read', but returns an error message instead of throwing.
safeRead :: Read b => String -> Either String b
safeRead x = case reads x of
  [(a,"")] -> Right a
  _        -> Left $ "Couldn't parse " ++ show x
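
-- Example invocations (a sketch for reference only; the executable name is
-- taken from the "colada" mode group above, and file names and parameter
-- values are hypothetical). Sentences are read from standard input in CoNLL
-- format; "learn" writes the serialized model to the FILE argument unless
-- --progressive is given:
--
--   colada learn --features bigram --topic-num 100 model < train.conll
--   colada predict --topn 5 model < test.conll
--   colada label --no-context model < test.conll
--   colada help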