{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
-- | Latent Dirichlet Allocation
--
-- Simple implementation of a collapsed Gibbs sampler for LDA. This
-- library uses the topic modeling terminology (documents, words,
-- topics), even though it is generic. For example, if used for word
-- class induction, replace documents with word types, words with
-- features and topics with word classes.
module NLP.LDA
       (
       -- * Running samplers
         runSampler
       , pass
       , runLDA
       -- * Datatypes
       , Sampler
       , LDA
       , Finalized
       , Doc
       , D
       , W
       , Z
       -- * Access model information
       , docTopics
       , wordTopics
       , topics
       , alphasum
       , beta
       , topicNum
       , vSize
       , model
       , topicDocs
       , topicWords
       -- * Initialization and finalization
       , initial
       , finalize
       -- * Prediction
       , docTopicWeights
       -- * Miscellaneous
       , compress
       , Table2D
       , Table1D
       )
where

-- Standard libraries
import qualified Data.IntMap as IntMap
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector as V
import qualified Data.List as List
import Prelude hiding (sum)

-- Third-party modules
import GHC.Generics (Generic)
import Data.Random (rvarT)
import Data.RVar
import Data.Random.Distribution.Categorical
import Control.Monad.State
import Data.Random.Source.PureMT (pureMT)
import Data.Word (Word64)

-- Package modules
import NLP.LDA.Utils (count)
import NLP.LDA.UnboxedMaybeVector ()

-- Exported types

-- | Document identifier.
type D = Int

-- | Topic identifier.
type Z = Int

-- | Word identifier.
type W = Int

-- | A document: its identifier and its words, each paired with the topic
-- currently assigned to it ('Nothing' for a word that has not been
-- assigned a topic yet, i.e. whose counts are not in the model).
type Doc = (D, U.Vector (W, Maybe Z))

-- | Two-level count table: outer key, then inner key, to a count.
type Table2D = IntMap.IntMap Table1D

-- | Count table keyed by an integer identifier.
type Table1D = IntMap.IntMap Double

-- | Abstract type holding the settings and the state of the sampler.
--
-- All fields are strict. The record is rebuilt after every sampled
-- word and kept in a lazy 'State', so strict count tables make forcing
-- an 'LDA' value also force its tables to weak head normal form instead
-- of leaving them as pending updates.
data LDA = LDA
  { docTopics  :: !Table2D -- ^ Document-topic counts
  , wordTopics :: !Table2D -- ^ Word-topic counts
  , topics     :: !Table1D -- ^ Topic counts
  , alphasum   :: !Double  -- ^ alpha * K Dirichlet parameter (topic sparseness)
  , beta       :: !Double  -- ^ beta Dirichlet parameter (word sparseness)
  , topicNum   :: !Int     -- ^ Number of topics K
  , vSize      :: !Int     -- ^ Number of unique words
  } deriving (Generic)
-- | Abstract type holding the LDA model, and the inverse count tables
data Finalized = Finalized
  { model      :: LDA     -- ^ LDA model
  , topicDocs  :: Table2D -- ^ Inverse document-topic counts
  , topicWords :: Table2D -- ^ Inverse word-topic counts
  } deriving (Generic)

-- | Custom random variable representing the LDA Gibbs sampler
type Sampler a = RVarT (State LDA) a

-- Exported functions

-- | @initial k a b@ initializes model with @k@ topics, @a/k@ alpha
-- hyperparameter and @b@ beta hyperparameter. All count tables start
-- empty and the vocabulary size starts at 0.
initial :: Int -> Double -> Double -> LDA
initial k a b =
  LDA { docTopics  = IntMap.empty
      , wordTopics = IntMap.empty
      , topics     = IntMap.empty
      , alphasum   = a
      , beta       = b
      , topicNum   = k
      , vSize      = 0
      }

-- | @finalize m@ creates a finalized model from LDA model @m@, adding
-- the topic-document and topic-word tables (the document-topic and
-- word-topic tables with their keys swapped).
finalize :: LDA -> Finalized
finalize m =
  Finalized { model      = m
            , topicDocs  = invert . docTopics $ m
            , topicWords = invert . wordTopics $ m
            }

-- | @pass batch@ runs one pass of Gibbs sampling on documents in @batch@
pass :: V.Vector Doc -> Sampler (V.Vector Doc)
pass = V.mapM passOne

-- | @runSampler seed m s@ runs sampler @s@ with @seed@ and initial
-- model @m@, returning the sampler's result and the final model. The
-- random number generator used is System.Random.Mersenne.Pure64.
runSampler :: Word64 -> LDA -> Sampler a -> (a, LDA)
runSampler seed m =
  flip runState m . flip evalStateT (pureMT seed) . sampleRVarTWith lift

-- | @runLDA seed n m ds@ creates and runs an LDA sampler with @seed@
-- for @n@ passes with initial model @m@ on the batch of documents
-- @ds@. Each pass is fed the documents returned by the previous one.
-- The random number generator used is System.Random.Mersenne.Pure64.
runLDA :: Word64 -> Int -> LDA -> V.Vector Doc -> (V.Vector Doc, LDA)
runLDA seed n m ds =
  runSampler seed m . foldM (\docs _ -> pass docs) ds $ [1 .. n]

-- | Remove zero counts from every inner table of a two-level count
-- table (e.g. the document-topic table). Inner tables that become empty
-- are kept.
compress :: IntMap.IntMap (IntMap.IntMap Double)
         -> IntMap.IntMap (IntMap.IntMap Double)
compress = IntMap.map dezero

-- Private functions

-- | Run one Gibbs sampling step over every word of a single document.
-- For each word, the counts for its previous topic (if any) are
-- removed, a new topic is sampled, and the counts for the new topic are
-- added. Every word in the returned document carries 'Just' a topic.
passOne :: Doc -> Sampler Doc
passOne (d, wz) = do
  zs <- U.mapM one wz
  return (d, U.zip (U.map fst wz) (U.map Just zs))
  where
    one (w, mz) = do
      m <- lift get
      let m' = maybe m (update (-1) m d w) mz -- decrement counts
      lift $ put m'
      z <- randomZ d w                        -- sample topic
      lift $ put (update 1 m' d w z)          -- increment counts
      return z

-- | Sample a random topic for doc @d@ and word @w@ from the current
-- model, with probability proportional to 'wordTopicWeights'.
randomZ :: D -> W -> Sampler Z
randomZ d w = do
  m <- lift get
  sampleCategorical
    . fromWeightedList
    . U.toList
    . U.map swap
    . U.indexed
    . wordTopicWeights m d
    $ w

-- | @wordTopicWeights m d w@ returns the unnormalized probabilities of
-- topics @0 .. K-1@ for word @w@ in document @d@ given LDA model @m@:
--
-- > (n(z,d) + alpha) * (n(z,w) + beta) / (n(z) + V * beta)
--
-- where @alpha = alphasum m / K@, @n(z)@ is the total count of topic
-- @z@ and @V = vSize m@.
wordTopicWeights :: LDA -> D -> W -> U.Vector Double
wordTopicWeights m d w =
  let k  = topicNum m
      a  = alphasum m / fromIntegral k
      b  = beta m
      dt = IntMap.findWithDefault IntMap.empty d . docTopics $ m
      wt = IntMap.findWithDefault IntMap.empty w . wordTopics $ m
      v  = fromIntegral . vSize $ m
      -- 1 / (n(z) + V * beta). The denominator is zero when the very
      -- first word is sampled from a fresh model ('passOne' samples
      -- before it updates the counts, so vSize = 0 and every topic count
      -- is 0). The factor is then the same for all topics, so 1 is used
      -- instead of dividing by zero, which would give every topic an
      -- infinite weight.
      norm z = let den = count z (topics m) + v * b
               in if den == 0 then 1 else 1 / den
      weights = [ (count z dt + a)   -- n(z,d) + alpha
                  * (count z wt + b) -- n(z,w) + beta
                  * norm z           -- 1 / (n(z) + V * beta)
                | z <- [0 .. k - 1]
                ]
  in U.fromList weights
{-# INLINE wordTopicWeights #-}

-- | @docTopicWeights m doc@ returns unnormalized topic probabilities
-- for document @doc@ given LDA model @m@: the sum over the document's
-- words of their 'wordTopicWeights'. Stored topic assignments are
-- ignored.
docTopicWeights :: LDA -> Doc -> U.Vector Double
docTopicWeights m (d, ws) =
  U.accumulate (+) (U.replicate (topicNum m) 0)
    . U.concatMap (U.indexed . wordTopicWeights m d)
    . U.map fst
    $ ws
{-# INLINE docTopicWeights #-}

-- | Add @c@ to the counts in the model for doc @d@, word @w@ and topic
-- @z@. 'vSize' grows by one when @w@ has no entry in 'wordTopics' yet.
update :: Double -> LDA -> D -> W -> Z -> LDA
update c m d w z =
  m { docTopics  = upd c (docTopics m) d z
    , wordTopics = upd c (wordTopics m) w z
    , topics     = insertWithStrict (+) z c (topics m)
    , vSize      = vSize m + (fromEnum . IntMap.notMember w . wordTopics $ m)
    }

-- | Increment table @m@ by @c@ at key @(k, k')@, creating missing
-- entries. Both the inner table and the updated count are forced, so
-- repeated increments do not pile up unevaluated sums.
upd :: Double -> Table2D -> Int -> Int -> Table2D
upd c m k k' = insertWithStrict addToInner k (IntMap.singleton k' c) m
  where
    -- @k@ already has an inner table: bump only its @k'@ entry
    addToInner _ inner = insertWithStrict (+) k' c inner
{-# INLINE upd #-}

-- | Strict insert-with: @insertWithStrict f k new t@ stores @new@ at @k@
-- when @k@ is absent, and @f new old@ otherwise. The stored value is
-- forced to weak head normal form. This replaces @IntMap.insertWith'@,
-- which was removed from the containers package.
insertWithStrict :: (a -> a -> a) -> Int -> a -> IntMap.IntMap a -> IntMap.IntMap a
insertWithStrict f k new = IntMap.alter step k
  where
    step Nothing    = new `seq` Just new
    step (Just old) = let v = f new old in v `seq` Just v
{-# INLINE insertWithStrict #-}

-- | Draw one topic from a categorical distribution.
sampleCategorical :: Categorical Double Z -> Sampler Z
sampleCategorical = sampleRVarT . rvarT
{-# INLINE sampleCategorical #-}

-- | Drop the entries whose count is exactly zero.
dezero :: IntMap.IntMap Double -> IntMap.IntMap Double
dezero = IntMap.filter (/= 0)
{-# INLINE dezero #-}

-- | Swap the order of keys on Table2D: the entry at @(k, k')@ moves to
-- @(k', k)@.
invert :: Table2D -> Table2D
invert outer =
  List.foldl' (\acc (k, k', v) -> upd v acc k' k) IntMap.empty
    [ (k, k', v)
    | (k, inner) <- IntMap.toList outer
    , (k', v) <- IntMap.toList inner
    ]
{-# INLINE invert #-}

-- | Turn an (index, weight) pair into the (weight, value) pair expected
-- by 'fromWeightedList'.
swap :: (Int, Double) -> (Double, Int)
swap (!a, !b) = (b, a)