{-# LANGUAGE TemplateHaskell, DeriveGeneric, DeriveDataTypeable,
             FlexibleInstances #-}
module Colada.WordClass
(
learn
, defaultOptions
, summary
, summarize
, wordTypeClasses
, label
, predict
, WordClass
, ldaModel
, wordTypeTable
, featureTable
, options
, LDA.Finalized
, LDA.docTopics
, LDA.wordTopics
, LDA.topics
, LDA.topicDocs
, LDA.topicWords
, Options
, featIds
, topicNum
, alphasum
, beta
, passes
, repeats
, batchSize
, seed
, topn
, initSize
, initPasses
, exponent
, progressive
, pre
, lambda
)
where
import qualified Data.Text.Lazy.IO as Text
import qualified Data.Text.Lazy as Text
import qualified Data.Text.Lazy.Builder as Text
import qualified Data.Text.Lazy.Builder.Int as Text
import qualified Data.Text.Lazy.Encoding as Text
import qualified Data.Vector as V
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Unboxed as U
import qualified Data.IntMap as IntMap
import qualified Data.Map as Map
import qualified Data.Serialize as Serialize
import qualified Control.Monad as M
import qualified Data.List as List
import qualified Data.List.Split as Split
import qualified Data.Ord as Ord
import qualified Data.Foldable as Fold
import qualified Data.Traversable as Trav
import qualified Control.Monad.ST as ST
import qualified Control.Monad.ST.Lazy as LST
import Data.Function (on)
import Control.Monad.Writer
import Data.Word (Word32)
import Data.Typeable (Typeable)
import Data.Data (Data)
import Prelude hiding ((.), exponent)
import Control.Category ((.))
import Control.Applicative ((<$>))
import qualified Control.Monad.Atom as Atom
import qualified NLP.CoNLL as CoNLL
import qualified Data.List.Zipper as Zipper
import GHC.Generics (Generic)
import qualified Data.Label as L
import Data.Label (get)
import qualified NLP.SwiftLDA as LDA
import qualified Colada.Features as F
import qualified NLP.Symbols as Symbols
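-- | A trained word class model: the finalized LDA sampler together with
-- the atom tables interning word types and features, and the options
-- the model was trained with.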
data WordClass =
WordClass { _ldaModel :: LDA.Finalized
, _wordTypeTable :: Atom.AtomTable (U.Vector Char)
, _featureTable :: Atom.AtomTable (U.Vector Char)
, _options :: Options
}
deriving (Generic)
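-- | Training settings: the number of word classes ('_topicNum'), the
-- Dirichlet hyper-parameters '_alphasum' and '_beta', batching and
-- sampling-pass counts, and flags controlling progressive labeling.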
data Options = Options { _featIds :: [Int]
, _topicNum :: !Int
, _alphasum :: !Double
, _beta :: !Double
, _passes :: !Int
, _repeats :: !Int
, _batchSize :: !Int
, _seed :: !Word32
, _topn :: !Int
, _initSize :: !Int
, _initPasses :: !Int
, _exponent :: !(Maybe Double)
, _progressive :: !Bool
, _pre :: !Bool
, _lambda :: !Double
}
deriving (Eq, Show, Typeable, Data, Generic)
instance Serialize.Serialize Options
$(L.mkLabels [''WordClass, ''Options])
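-- | Default options: ten word classes, a single sampling pass over
-- singleton batches, and no separate initialization phase.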
defaultOptions :: Options
defaultOptions = Options { _featIds = [-1,1]
, _topicNum = 10
, _alphasum = 10
, _beta = 0.1
, _passes = 1
, _repeats = 1
, _batchSize = 1
, _seed = 0
, _topn = maxBound
, _initSize = 0
, _initPasses = 100
, _exponent = Nothing
, _progressive = False
, _pre = False
, _lambda = 1.0
}
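-- | Train a word class model on a sequence of sentences. Returns the
-- model together with a lazily produced list of per-sentence label
-- vectors (one class distribution per token), which is non-empty only
-- when the '_progressive' option is set.
--
-- A minimal usage sketch (@sentences@ is assumed to be provided by the
-- caller, e.g. parsed via "NLP.CoNLL"):
--
-- > let opts = L.set topicNum 50 defaultOptions
-- >     (model, _) = learn opts sentences
-- > Text.putStr (summary model)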
learn :: Options
-> [CoNLL.Sentence]
-> (WordClass, [V.Vector (U.Vector Double)])
learn opts xs =
let ((bs_init, bs_rest), atomTabD, atomTabW) =
Symbols.runSymbols prepare Symbols.empty Symbols.empty
prepare = do
let (xs_init, xs_rest) = List.splitAt (get initSize opts) xs
ini <- prepareData (get initSize opts)
1
(get featIds opts)
xs_init
rest <- prepareData (get batchSize opts)
(get repeats opts)
(get featIds opts)
xs_rest
return (ini, rest)
sampler :: WriterT [V.Vector (U.Vector Double)] (LST.ST s) LDA.Finalized
sampler = do
m <- st $ LDA.initial (U.singleton (get seed opts))
(get topicNum opts)
(get alphasum opts)
(get beta opts)
(get exponent opts)
let loop t i_last batch i = do
let label_prog = do
Fold.forM_ batch $ \rep -> do
let sent = V.head rep
ls <- st $ V.mapM (interpWordClasses m (get lambda opts)) sent
tell [ls]
M.when (get progressive opts && i == 1 && get pre opts) label_prog
r <- st $ Trav.forM batch $ \rep -> do
Trav.forM rep $ \sent -> do
LDA.pass t m sent
M.when (get progressive opts && i == i_last && not (get pre opts)) label_prog
return $! r
Fold.forM_ bs_init $ \batch -> do
Fold.foldlM (loop 1 $ get initPasses opts) batch [1..get initPasses opts]
Fold.forM_ (zip [1..] bs_rest) $ \(t, batch) -> do
Fold.foldlM (loop t $ get passes opts) batch [1..get passes opts]
st $ LDA.finalize m
(lda, labeled) = LST.runST (runWriterT sampler)
in (WordClass lda atomTabD atomTabW opts, labeled)
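-- Internal shorthands: a sentence is a vector of LDA documents (one per
-- token); each sentence is replicated into a 'Repeat', and repeats are
-- grouped into batches.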
type Symb = Symbols.Symbols (U.Vector Char) (U.Vector Char)
type Sent = V.Vector LDA.Doc
type Repeat = V.Vector Sent
type Batch = V.Vector Repeat
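-- | Featurize and intern sentences, replicate each one 'rep' times, and
-- group the results into batches of 'bsz' sentences.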
prepareData :: Int
-> Int
-> [Int]
-> [CoNLL.Sentence]
-> Symb [Batch]
prepareData bsz rep is ss = do
ss' <- mapM symbolize . map (featurize is) $ ss
return $! (map V.fromList . Split.chunksOf bsz . map (V.replicate rep) $ ss')
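-- | Turn a sentence into (focus word, context features) pairs. The
-- feature map is indexed by position: key 0 holds the focus word, and
-- each requested feature id i is rendered as "f^i" so the same string
-- at different positions stays distinct.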
featurize :: [Int]
-> CoNLL.Sentence
-> [(Text.Text, [Text.Text])]
featurize is s =
let mk fs =
let d = IntMap.findWithDefault
(error "parseData: focus feature missing") 0 fs
ws = [ Text.concat [f,"^",Text.pack . show $ i ]
| i <- is , Just f <- [IntMap.lookup i fs] ]
in (d, ws)
in map mk . extractFeatures $ s
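-- | Intern a featurized sentence: each token becomes an LDA document
-- whose identifier is the focus word's atom and whose words are its
-- feature atoms, initially without topic assignments.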
symbolize :: [(Text.Text, [Text.Text])] -> Symb Sent
symbolize s = V.fromList <$> mapM doc s
where doc (d, ws) = do
da <- Symbols.toAtomA . compress $ d
was <- mapM (Symbols.toAtomB . compress) ws
return (da, U.fromList $ zip was (repeat Nothing))
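-- | Summarize a model: one line per word class, listing up to ten of
-- its most frequent word types. 'summarize True' first hardens each
-- word type's class distribution to its single best class.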
summary :: WordClass -> Text.Text
summary = summarize False
summarize :: Bool -> WordClass -> Text.Text
summarize harden m =
let format (z,cs) = do
cs' <- mapM (Atom.fromAtom . fst)
. takeWhile ((>0) . snd)
. take 10
. List.sortBy (flip $ Ord.comparing snd) . IntMap.toList
$ cs
return . Text.unwords $ Text.pack (show z)
: map (Text.pack . U.toList) cs'
in fst . flip Atom.runAtom (get wordTypeTable m)
. M.liftM Text.unlines
. mapM format
. IntMap.toList
. IntMap.fromListWith (IntMap.unionWith (+))
. concatMap (\(d,zs) -> [ (z, IntMap.singleton d c) | (z,c) <- zs ])
. (if harden
then
map (\(d,zs) -> let s = sum . map snd $ zs
(z',_) = List.maximumBy (Ord.compare `on` snd) zs
in (d, [ (z, if z == z' then s else 0) | (z,_) <- zs ]))
else id)
. IntMap.toList
. IntMap.map IntMap.toList
. LDA.docTopics
. get ldaModel
$ m
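-- | Interpolate, with weight lambda, between a token's prior class
-- distribution and the distribution conditioned on its context
-- features, renormalizing the result.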
interpWordClasses :: LDA.LDA s
-> Double
-> LDA.Doc
-> ST.ST s (U.Vector Double)
interpWordClasses m lambda doc@(d,_) = do
pzd <- normalize <$> LDA.priorDocTopicWeights_ m d
pzdw <- normalize <$> LDA.docTopicWeights_ m doc
return $! normalize $ U.zipWith (\p q -> lambda * p + (1 - lambda) * q) pzd pzdw
where normalize x =
let uniform = U.replicate (U.length x) (1 / (fromIntegral (U.length x)))
in case U.sum x of
0 -> uniform
-- an infinite sum (overflow) also falls back to the uniform distribution
s | s >= 1/0 -> uniform
s -> U.map (/s) x
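-- | The class distribution of every word type seen in training, keyed
-- by the word type's text.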
wordTypeClasses :: WordClass -> Map.Map Text.Text (IntMap.IntMap Double)
wordTypeClasses m =
fst . flip Atom.runAtom (get wordTypeTable m)
. fmap Map.fromList
. mapM (\(k,v) -> do k' <- Atom.fromAtom k ; return (decompress k',v))
. IntMap.toList
. LDA.docTopics
. get ldaModel
$ m
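-- | Label each token of a sentence with a vector of class weights.
-- When the first argument is 'True', the token's context features are
-- replaced with a single dummy feature, so labels depend on the word
-- type alone.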
label :: Bool -> WordClass -> CoNLL.Sentence -> V.Vector (U.Vector Double)
label noctx m s = fst3 $ Symbols.runSymbols label'
(L.get wordTypeTable m)
(L.get featureTable m)
where dectx doc@(d, _) = if noctx
then (d, U.singleton (1,Nothing))
else doc
label' = do
let fm = L.get ldaModel m
s' <- prepareSent m s
return $! V.map (LDA.docTopicWeights fm . dectx)
$ s'
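-- | For each token, predict the '_topn' most likely word types given
-- only the token's context features, most likely first.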
predict :: WordClass -> CoNLL.Sentence
-> [V.Vector (Double, Text.Text)]
predict m s = fst3 $ Symbols.runSymbols predict' (L.get wordTypeTable m)
(L.get featureTable m)
where predict' = do
let fm = L.get ldaModel m
s' <- prepareSent m s
let ws = map ( G.convert
. predictDoc (get (topn . options) m) fm
. docToWs )
. V.toList
$ s'
mapM (V.mapM fromAtom) ws
docToWs = U.map fst . snd
fromAtom (n,w) = do w' <- Symbols.fromAtomA w
return (n, decompress w')
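-- | Featurize and intern a sentence using the feature ids the model was
-- trained with.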
prepareSent :: WordClass -> CoNLL.Sentence -> Symb Sent
prepareSent m = symbolize . featurize (L.get (featIds . options) m)
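-- | Score every word type d seen in training against a bag of context
-- features ws, using the smoothed LDA predictive weights, and keep the
-- n highest-scoring types:
--
--   score(d) = sum_z (n(z,d) + alpha) * sum_{w in ws} (n(w,z) + beta) / (n(z) + beta*V)
--
-- where n(.) are the model's counts and V is the feature vocabulary
-- size.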
predictDoc :: Int -> LDA.Finalized -> U.Vector LDA.W
-> U.Vector (Double, LDA.D)
predictDoc n m ws =
let k = LDA.topicNum m
a = LDA.alphasum m / fromIntegral k
b = LDA.beta m
v = fromIntegral . LDA.wSize $ m
zt = LDA.topics m
wsums =
IntMap.fromList
[ (z, U.sum . U.map (\w -> (count w wt_z + b) / denom) $ ws)
| z <- IntMap.keys zt
, let wt_z = IntMap.findWithDefault IntMap.empty z . LDA.topicWords
$ m
denom = count z zt + b * v ]
wsum z = IntMap.findWithDefault
(error $ "Colada.WordClass.predictDoc: key not found: "
++ show z)
z wsums
in U.fromList . take n . List.sortBy (flip compare)
$ [ ( sum [ (c + a) * wsum z | (z,c) <- IntMap.toList zt_d ] , d)
| (d, zt_d) <- IntMap.toList . LDA.docTopics $ m
]
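-- | Compute a feature map for every token position, based on the first
-- column of the sentence; 'F.featureSeq' walks a zipper over the
-- tokens, merging overlapping features with 'combine'.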
extractFeatures :: CoNLL.Sentence -> [IntMap.IntMap Text.Text]
extractFeatures = F.featureSeq combine
. Zipper.fromList
. map (V.! 0)
. V.toList
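-- | Merge two optional feature values, joining both with a "|"
-- separator when present.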
combine :: Maybe Text.Text -> Maybe Text.Text -> Maybe Text.Text
combine (Just a) (Just b) = Just $ Text.concat [a, "|", b]
combine (Just a) Nothing = Just a
combine Nothing (Just b) = Just b
combine Nothing Nothing = Nothing
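-- Pack 'Text' into an unboxed 'Char' vector (and back) for use as a
-- compact atom key.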
compress :: Text.Text -> U.Vector Char
compress = U.fromList . Text.unpack
decompress :: U.Vector Char -> Text.Text
decompress = Text.pack . U.toList
fst3 :: (a, b, c) -> a
fst3 (a,_,_) = a
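-- | Look up a count, defaulting to 0; negative counts indicate a
-- corrupted model and raise an error.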
count :: Int -> IntMap.IntMap Double -> Double
count z t = case IntMap.findWithDefault 0 z t of
n | n < 0 -> error "Colada.WordClass.count: negative count"
n -> n
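-- Lift a strict 'ST' action into the lazy-'ST' writer stack used by the
-- sampler.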
st :: Monoid w => ST.ST s a -> WriterT w (LST.ST s) a
st = lift . LST.strictToLazyST
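-- Orphan 'Serialize' instances needed to persist a trained model.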
instance Serialize.Serialize LDA.Finalized
instance Serialize.Serialize Text.Text where
put = Serialize.put . Text.encodeUtf8
get = Text.decodeUtf8 `fmap` Serialize.get
instance Serialize.Serialize (U.Vector Char) where
put v = Serialize.put (Text.pack . U.toList $ v)
get =
do t <- Serialize.get
return $! U.fromList . Text.unpack $ t
instance Serialize.Serialize (Atom.AtomTable (U.Vector Char))
instance Serialize.Serialize WordClass