{-# LANGUAGE OverloadedStrings
           , FlexibleInstances
           , DeriveGeneric
           , NoMonomorphismRestriction
           , DeriveDataTypeable
           , TemplateHaskell
           , BangPatterns
  #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

-- | Word Class induction with LDA
--
-- This module provides functions that implement word class induction
-- using the generic algorithm implemented in Colada.LDA.
--
-- You can access and set options in the @Options@ record using lenses.
-- Example:
--
-- > import Data.Label
-- > let options = set passes 5
-- >             . set beta 0.01
-- >             . set topicNum 100
-- >             $ defaultOptions
-- > in learn options sentences

module Colada.WordClass
  ( -- * Running the sampler
    learn
  , defaultOptions
    -- * Extracting information
  , summary
  , summarize
  , wordTypeClasses
    -- * Class and word prediction
  , label
  , predict
    -- * Data types and associated lenses
  , WordClass
  , ldaModel        -- ^ LDA model
  , wordTypeTable   -- ^ Word type string to atom and vice versa conversion tables
  , featureTable    -- ^ Feature string to atom and vice versa conversion tables
  , options         -- ^ Options for Gibbs sampling
  , LDA.Finalized
  , LDA.docTopics
  , LDA.wordTopics
  , LDA.topics
  , LDA.topicDocs
  , LDA.topicWords
  , Options
  , featIds         -- ^ Feature ids
  , topicNum        -- ^ Number of topics K
  , alphasum        -- ^ Dirichlet parameter alpha*K which controls topic sparseness
  , beta            -- ^ Dirichlet parameter beta which controls word sparseness
  , passes          -- ^ Number of sampling passes per batch
  , repeats         -- ^ Number of repeats per sentence
  , batchSize       -- ^ Number of sentences per batch
  , seed            -- ^ Random seed for the sampler
  , topn            -- ^ Number of most probable words to return
  , initSize
  , initPasses
  , exponent
  , progressive
  , pre
  , lambda
  )
where

-- Standard libraries
import qualified Data.Text.Lazy.IO          as Text
import qualified Data.Text.Lazy             as Text
import qualified Data.Text.Lazy.Builder     as Text
import qualified Data.Text.Lazy.Builder.Int as Text
import qualified Data.Text.Lazy.Encoding    as Text
import qualified Data.Vector                as V
import qualified Data.Vector.Generic        as G
import qualified Data.Vector.Unboxed        as U
import qualified Data.IntMap                as IntMap
import qualified Data.Map                   as Map
import qualified Data.Serialize             as Serialize
import qualified Control.Monad              as M
import qualified Data.List                  as List
import qualified Data.List.Split            as Split
import qualified Data.Ord                   as Ord
import qualified Data.Foldable              as Fold
import qualified Data.Traversable           as Trav
import qualified Control.Monad.ST           as ST
import qualified Control.Monad.ST.Lazy      as LST
import           Data.Function              (on)
import           Control.Monad.Writer
import           Data.Word                  (Word32)
import           Data.Typeable              (Typeable)
import           Data.Data                  (Data)
import           Prelude                    hiding ((.), exponent)
import           Control.Category           ((.))
import           Control.Applicative        ((<$>))
import qualified System.IO.Unsafe           as Unsafe

-- Third party modules
import qualified Control.Monad.Atom         as Atom
import qualified NLP.CoNLL                  as CoNLL
import qualified Data.List.Zipper           as Zipper
import           GHC.Generics               (Generic)
import qualified Data.Label                 as L
import           Data.Label                 (get)
import qualified NLP.SwiftLDA               as LDA

-- Package modules
import qualified Colada.Features            as F
import qualified NLP.Symbols                as Symbols
import           Debug.Trace

-- | Container for the Word Class model
data WordClass =
  WordClass { _ldaModel      :: LDA.Finalized -- ^ LDA model
            , _wordTypeTable :: Atom.AtomTable (U.Vector Char)
            , _featureTable  :: Atom.AtomTable (U.Vector Char)
            , _options       :: Options
            } deriving (Generic)

data Options =
  Options { _featIds     :: [Int]
          , _topicNum    :: !Int
          , _alphasum    :: !Double
          , _beta        :: !Double
          , _passes      :: !Int
          , _repeats     :: !Int
          , _batchSize   :: !Int
          , _seed        :: !Word32
          , _topn        :: !Int
          , _initSize    :: !Int
          , _initPasses  :: !Int
          , _exponent    :: !(Maybe Double)
          , _progressive :: !Bool
          , _pre         :: !Bool
          , _lambda      :: !Double
          } deriving (Eq, Show, Typeable, Data, Generic)

instance Serialize.Serialize Options

$(L.mkLabels [''WordClass, ''Options])

defaultOptions :: Options
defaultOptions =
  Options { _featIds     = [-1,1]
          , _topicNum    = 10
          , _alphasum    = 10
          , _beta        = 0.1
          , _passes      = 1
          , _repeats     = 1
          , _batchSize   = 1
          , _seed        = 0
          , _topn        = maxBound
          , _initSize    = 0
          , _initPasses  = 100
          , _exponent    = Nothing
          , _progressive = False
          , _pre         = False
          , _lambda      = 1.0
          }
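-- A minimal end-to-end sketch of driving the sampler. The corpus
-- @sents@ and the particular option values here are assumptions for
-- illustration, not part of this module:
--
-- > demo :: [CoNLL.Sentence] -> IO ()
-- > demo sents =
-- >   let opts       = L.set topicNum 50
-- >                  . L.set passes 2
-- >                  $ defaultOptions
-- >       (model, _) = learn opts sents
-- >   in Text.putStrLn (summary model)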
-- | @learn options xs@ runs the LDA Gibbs sampler for word classes
-- with @options@ on sentences @xs@, and returns the resulting model
-- together with progressive class assignments.
learn :: Options -> [CoNLL.Sentence] -> (WordClass, [V.Vector (U.Vector Double)])
learn opts xs =
  let ((bs_init, bs_rest), atomTabD, atomTabW) =
        Symbols.runSymbols prepare Symbols.empty Symbols.empty
      prepare = do
        let (xs_init, xs_rest) = List.splitAt (get initSize opts) xs
        ini  <- prepareData (get initSize opts) 1 (get featIds opts) xs_init
        rest <- prepareData (get batchSize opts) (get repeats opts)
                            (get featIds opts) xs_rest
        return (ini, rest)
      sampler :: WriterT [V.Vector (U.Vector Double)] (LST.ST s) LDA.Finalized
      sampler = do
        m <- st $ LDA.initial (U.singleton (get seed opts))
                              (get topicNum opts)
                              (get alphasum opts)
                              (get beta opts)
                              (get exponent opts)
        let loop t i_last batch i = do
              let label_prog = do
                    Fold.forM_ batch $ \rep -> do
                      let sent = V.head rep
                      ls <- st $ V.mapM (interpWordClasses m (get lambda opts)) sent
                      tell [ls]
              -- Either label before sampling (--pre)
              M.when (get progressive opts && i == 1 && get pre opts)
                     label_prog
              r <- st $ Trav.forM batch $ \rep -> do
                     Trav.forM rep $ \sent -> do
                       LDA.pass t m sent
              -- Or label after sampling
              M.when (get progressive opts && i == i_last && not (get pre opts))
                     label_prog
              return $! r
        -- Initialize with batch sampler on the prefix bs_init
        Fold.forM_ bs_init $ \batch -> do
          Fold.foldlM (loop 1 $ get initPasses opts) batch [1 .. get initPasses opts]
        -- Continue sampling
        Fold.forM_ (zip [1..] bs_rest) $ \(t, batch) -> do
          Fold.foldlM (loop t $ get passes opts) batch [1 .. get passes opts]
        st $ LDA.finalize m
      (lda, labeled) = LST.runST (runWriterT sampler)
  in (WordClass lda atomTabD atomTabW opts, labeled)

type Symb   = Symbols.Symbols (U.Vector Char) (U.Vector Char)
type Sent   = V.Vector LDA.Doc
type Repeat = V.Vector Sent
type Batch  = V.Vector Repeat

-- | Convert a stream of sentences into a stream of batches ready for
-- sampling.
prepareData :: Int              -- ^ batch size
            -> Int              -- ^ no. of repeats
            -> [Int]            -- ^ feature indices
            -> [CoNLL.Sentence] -- ^ stream of sentences
            -> Symb [Batch]     -- ^ stream of batches
prepareData bsz rep is ss = do
  ss' <- mapM symbolize . map (featurize is) $ ss
  return $! ( map V.fromList
            . Split.chunksOf bsz
            . map (V.replicate rep)
            $ ss' )

-- | Extract features from a sentence
featurize :: [Int] -> CoNLL.Sentence -> [(Text.Text, [Text.Text])]
featurize is s =
  let mk fs =
        let d  = IntMap.findWithDefault
                   (error "Colada.WordClass.featurize: focus feature missing")
                   0 fs
            ws = [ Text.concat [f, "^", Text.pack . show $ i]
                 | i <- is
                 , Just f <- [IntMap.lookup i fs]
                 ]
        in (d, ws)
  in map mk . extractFeatures $ s
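-- As a concrete illustration (a hypothetical token, for exposition
-- only): with feature indices [-1,1], a token whose focus feature
-- (offset 0) is "dog", with "the" at offset -1 and "barks" at
-- offset 1, contributes the pair
--
-- > ("dog", ["the^-1", "barks^1"])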
-- | Convert text strings into symbols (ints)
symbolize :: [(Text.Text, [Text.Text])] -> Symb Sent
symbolize s = V.fromList <$> mapM doc s
  where doc (d, ws) = do
          da  <- Symbols.toAtomA . compress $ d
          was <- mapM (Symbols.toAtomB . compress) ws
          return (da, U.fromList $ zip was (repeat Nothing))

-- | @summary m@ returns a textual summary of the word classes found in
-- model @m@.
summary :: WordClass -> Text.Text
summary = summarize False

-- | @summarize harden m@ returns a textual summary of the word classes
-- found in model @m@; when @harden@ is set, each word type's whole
-- count mass is assigned to its most probable class.
summarize :: Bool -> WordClass -> Text.Text
summarize harden m =
  let format (z, cs) = do
        cs' <- mapM (Atom.fromAtom . fst)
               . takeWhile ((>0) . snd)
               . take 10
               . List.sortBy (flip $ Ord.comparing snd)
               . IntMap.toList
               $ cs
        return . Text.unwords $ Text.pack (show z) : map (Text.pack . U.toList) cs'
  in   fst
     . flip Atom.runAtom (get wordTypeTable m)
     . M.liftM Text.unlines
     . mapM format
     . IntMap.toList
     . IntMap.fromListWith (IntMap.unionWith (+))
     . concatMap (\(d, zs) -> [ (z, IntMap.singleton d c) | (z, c) <- zs ])
     -- Maybe harden
     . (if harden
           then map (\(d, zs) ->
                  let s       = sum . map snd $ zs
                      (z', _) = List.maximumBy (Ord.compare `on` snd) zs
                  in (d, [ (z, if z == z' then s else 0) | (z, _) <- zs ]))
           else id)
     . IntMap.toList
     . IntMap.map IntMap.toList
     . LDA.docTopics
     . get ldaModel
     $ m

-- | @interpWordClasses m lambda doc@ gives the class probabilities for
-- the word type in context @doc@ according to the evolving model @m@.
-- It interpolates the prior word type probability with the
-- context-conditioned probabilities using @lambda@:
--
-- > P_interp(z) = lambda * P(z|d) + (1-lambda) * P(z|d,w)
interpWordClasses :: LDA.LDA s -> Double -> LDA.Doc -> ST.ST s (U.Vector Double)
interpWordClasses m lambda doc@(d, _) = do
  pzd  <- normalize <$> LDA.priorDocTopicWeights_ m d
  pzdw <- normalize <$> LDA.docTopicWeights_ m doc
  return $! normalize
          $ U.zipWith (\p q -> lambda * p + (1-lambda) * q) pzd pzdw
  where normalize x =
          let uniform = U.replicate (U.length x)
                                    (1 / (fromIntegral (U.length x)))
          in case U.sum x of
               0            -> uniform
               s | s >= 1/0 -> uniform  -- guard against infinite sums
               s            -> U.map (/s) x
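-- A worked instance of the interpolation: with lambda = 0.25, a
-- normalized prior P(z|d) of [0.5, 0.5] and a context-conditioned
-- P(z|d,w) of [0.9, 0.1], the result is
--
-- > 0.25 * [0.5, 0.5] + 0.75 * [0.9, 0.1] = [0.8, 0.2]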
-- | @wordTypeClasses m@ returns a Map from word types to unnormalized
-- distributions over word classes.
wordTypeClasses :: WordClass -> Map.Map Text.Text (IntMap.IntMap Double)
wordTypeClasses m =
    fst
  . flip Atom.runAtom (get wordTypeTable m)
  . fmap Map.fromList
  . mapM (\(k, v) -> do k' <- Atom.fromAtom k
                        return (decompress k', v))
  . IntMap.toList
  . LDA.docTopics
  . get ldaModel
  $ m

-- | @label noctx m s@ returns, for each word in sentence @s@, the
-- unnormalized probabilities of the word classes. When @noctx@ is set,
-- the context features are ignored and only the word type prior is used.
label :: Bool -> WordClass -> CoNLL.Sentence -> V.Vector (U.Vector Double)
label noctx m s =
  fst3 $ Symbols.runSymbols label' (L.get wordTypeTable m) (L.get featureTable m)
  where
    dectx doc@(d, _) =
      if noctx
         then (d, U.singleton (-1, Nothing)) -- FIXME: ugly hack
         else doc
    label' = do
      let fm = L.get ldaModel m
      s' <- prepareSent m s
      return $! V.map (LDA.docTopicWeights fm . dectx) $ s'

-- | @predict m s@ returns, for each word in sentence @s@, the
-- unnormalized probabilities of words given the predicted word class.
predict :: WordClass -> CoNLL.Sentence -> [V.Vector (Double, Text.Text)]
predict m s =
  fst3 $ Symbols.runSymbols predict' (L.get wordTypeTable m) (L.get featureTable m)
  where
    predict' = do
      let fm = L.get ldaModel m
      s' <- prepareSent m s
      let ws = map ( G.convert
                   . predictDoc (get (topn . options) m) fm
                   . docToWs )
               . V.toList
               $ s'
      mapM (V.mapM fromAtom) ws
    docToWs = U.map fst . snd
    fromAtom (n, w) = do
      w' <- Symbols.fromAtomA w
      return (n, decompress w')

prepareSent :: WordClass -> CoNLL.Sentence -> Symb Sent
prepareSent m = symbolize . featurize (L.get (featIds . options) m)

-- | @predictDoc n m ws@ returns unnormalized probabilities of the top
-- @n@ most probable document ids given the model @m@ and words @ws@.
-- The candidate document ids are taken from the model @m@. The weights
-- are computed according to the following formula:
--
-- > P(d|ws) ∝ Σ_z (n(d,z) + a) · Σ_{w ∈ ws} (n(w,z) + b) / (Σ_{w' ∈ V} n(w',z) + b·|V|)
predictDoc :: Int -> LDA.Finalized -> U.Vector LDA.W -> U.Vector (Double, LDA.D)
predictDoc n m ws =
  let k     = LDA.topicNum m
      a     = LDA.alphasum m / fromIntegral k
      b     = LDA.beta m
      v     = fromIntegral . LDA.wSize $ m
      zt    = LDA.topics m
      wsums = IntMap.fromList
                [ (z, U.sum . U.map (\w -> (count w wt_z + b) / denom) $ ws)
                | z <- IntMap.keys zt
                , let wt_z  = IntMap.findWithDefault IntMap.empty z
                              . LDA.topicWords $ m
                      denom = count z zt + b * v
                ]
      wsum z = IntMap.findWithDefault
                 (error $ "Colada.WordClass.predictDoc: key not found: " ++ show z)
                 z wsums
  in   U.fromList
     . take n
     . List.sortBy (flip compare)
     $ [ ( sum [ (c + a) * wsum z | (z, c) <- IntMap.toList zt_d ]
         , d )
       | (d, zt_d) <- IntMap.toList . LDA.docTopics $ m
       ]

extractFeatures :: CoNLL.Sentence -> [IntMap.IntMap Text.Text]
extractFeatures =
  F.featureSeq combine . Zipper.fromList . map (V.! 0) . V.toList

combine :: Maybe Text.Text -> Maybe Text.Text -> Maybe Text.Text
combine (Just a) (Just b) = Just $ Text.concat [a, "|", b]
combine (Just a) Nothing  = Just a
combine Nothing  (Just b) = Just b
combine Nothing  Nothing  = Nothing

compress :: Text.Text -> U.Vector Char
compress = U.fromList . Text.unpack

decompress :: U.Vector Char -> Text.Text
decompress = Text.pack . U.toList

fst3 :: (a, b, c) -> a
fst3 (a, _, _) = a

count :: Int -> IntMap.IntMap Double -> Double
count z t =
  case IntMap.findWithDefault 0 z t of
    n | n < 0 -> error "Colada.WordClass.count: negative count"
    n         -> n
{-# INLINE count #-}

st :: Monoid w => ST.ST s a -> WriterT w (LST.ST s) a
st = lift . LST.strictToLazyST

-- Instances for serialization
instance Serialize.Serialize LDA.Finalized

instance Serialize.Serialize Text.Text where
  put = Serialize.put . Text.encodeUtf8
  get = Text.decodeUtf8 `fmap` Serialize.get

instance Serialize.Serialize (U.Vector Char) where
  put v = Serialize.put (Text.pack . U.toList $ v)
  get   = do t <- Serialize.get
             return $! U.fromList . Text.unpack $ t

instance Serialize.Serialize (Atom.AtomTable (U.Vector Char))

instance Serialize.Serialize WordClass
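-- A sketch of persisting and restoring a model via the Serialize
-- instances above. The strict ByteString import and the helper names
-- are assumptions, not part of this module:
--
-- > import qualified Data.ByteString as B
-- >
-- > saveModel :: FilePath -> WordClass -> IO ()
-- > saveModel f = B.writeFile f . Serialize.encode
-- >
-- > loadModel :: FilePath -> IO (Either String WordClass)
-- > loadModel f = Serialize.decode `fmap` B.readFile f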