{-# LANGUAGE DeriveGeneric, BangPatterns #-}
-- | Latent Dirichlet Allocation
--
-- Simple implementation of a collapsed Gibbs sampler for LDA. This
-- library uses topic modeling terminology (documents, words,
-- topics), even though it is generic. For example, when used for
-- word class induction, replace documents with word types, words
-- with features, and topics with word classes.

module NLP.LDA
       ( -- * Running samplers
         runSampler
       , pass
       , runLDA
         -- * Datatypes
       , Sampler 
       , LDA
       , Finalized
       , Doc
       , D
       , W
       , Z
         -- * Access model information
       , docTopics
       , wordTopics
       , topics
       , alphasum
       , beta
       , topicNum
       , vSize
       , model
       , topicDocs
       , topicWords
         -- * Initialization and finalization
       , initial
       , finalize
         -- * Prediction
       , docTopicWeights
         -- * Miscellaneous
       , compress
       , Table2D
       , Table1D
       )
where       
-- Standard libraries  
import qualified Data.IntMap as IntMap  
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector as V
import qualified Data.List as List
import Prelude hiding (sum)

-- Third party modules
import GHC.Generics (Generic)
import Data.Random (rvarT)
import Data.RVar
import Data.Random.Distribution.Categorical 
import Control.Monad.State
import Data.Random.Source.PureMT (pureMT)
import Data.Word (Word64)

-- Package modules
import NLP.LDA.Utils (count)
import NLP.LDA.UnboxedMaybeVector () -- Unbox instance for Maybe values

-- Exported types
-- | Document id
type D = Int
-- | Topic id
type Z = Int
-- | Word id
type W = Int
-- | A document id paired with its words and their optional topic assignments
type Doc = (D, U.Vector (W, Maybe Z))

-- | Two-level count table (e.g. document/topic counts)
type Table2D = IntMap.IntMap Table1D
-- | One-level count table (e.g. topic counts)
type Table1D = IntMap.IntMap Double

-- | Abstract type holding the settings and the state of the sampler
data LDA = 
  LDA 
  { docTopics   :: Table2D -- ^ Document-topic counts
  , wordTopics  :: Table2D -- ^ Word-topic counts
  , topics      :: Table1D -- ^ Topic counts
  , alphasum    :: !Double -- ^ alpha * K Dirichlet parameter (topic sparseness)
  , beta        :: !Double -- ^ beta Dirichlet parameter (word sparseness)
  , topicNum    :: !Int    -- ^ Number of topics K
  , vSize       :: !Int    -- ^ Number of unique words
  } deriving (Generic)

-- | Abstract type holding the LDA model and the inverse count tables
data Finalized = 
  Finalized 
  { model :: LDA         -- ^ LDA model
  , topicDocs :: Table2D   -- ^ Inverse document-topic counts
  , topicWords :: Table2D  -- ^ Inverse word-topic counts
  }
  deriving (Generic)

-- | Custom random variable representing the LDA Gibbs sampler
type Sampler a = RVarT (State LDA) a
         
-- Exported functions  
         
-- | @initial k a b@ initializes a model with @k@ topics, alpha
-- hyperparameter @a/k@ per topic, and beta hyperparameter @b@.
initial :: Int -> Double -> Double -> LDA
initial k a b =
  LDA { docTopics  = IntMap.empty
      , wordTopics = IntMap.empty
      , topics     = IntMap.empty
      , alphasum   = a
      , beta       = b
      , topicNum   = k
      , vSize      = 0
      }
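
-- A minimal sketch of creating a model: 20 topics with a total alpha
-- of 50 (so alpha = 50/20 = 2.5 per topic) and beta = 0.01. These
-- hyperparameter values are illustrative, not recommendations.
exampleModel :: LDA
exampleModel = initial 20 50 0.01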
  
-- | @finalize m@ creates a finalized model from LDA model @m@
finalize :: LDA -> Finalized  
finalize m = 
  Finalized { model = m 
            , topicDocs =  invert . docTopics $ m
            , topicWords = invert . wordTopics $ m }
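
-- A small sketch (hypothetical helper, not exported): after
-- finalization, the inverse @topicWords@ table maps each topic to its
-- word counts, so the word counts of topic 0 can be read off directly.
wordsOfTopic0 :: Finalized -> Table1D
wordsOfTopic0 f = IntMap.findWithDefault IntMap.empty 0 (topicWords f)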



-- | @pass batch@ runs one pass of Gibbs sampling on documents in @batch@  
pass :: V.Vector Doc ->  Sampler (V.Vector Doc)
pass = V.mapM passOne


-- | @runSampler seed m s@ runs sampler @s@ with @seed@ and initial
-- model @m@. The random number generator used is
-- System.Random.Mersenne.Pure64.
runSampler :: Word64 -> LDA -> Sampler a -> (a, LDA)
runSampler seed m =    
     flip runState m
   . flip evalStateT (pureMT seed)
   . sampleRVarTWith lift 
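
-- A minimal sketch of driving the sampler by hand: one Gibbs pass
-- over a batch of documents. The seed (42) and the hyperparameters
-- are illustrative.
onePassExample :: V.Vector Doc -> (V.Vector Doc, LDA)
onePassExample batch = runSampler 42 (initial 20 50 0.01) (pass batch)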

   
-- | @runLDA seed n m ds@ creates and runs an LDA sampler with @seed@
-- for @n@ passes with initial model @m@ on the batch of documents
-- @ds@. The random number generator used is
-- System.Random.Mersenne.Pure64.
runLDA :: Word64
       -> Int
       -> LDA
       -> V.Vector Doc
       -> (V.Vector Doc, LDA)
runLDA seed n m ds = runSampler seed m $ foldM (const . pass) ds [1..n]
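
-- A hedged end-to-end sketch: two tiny documents with illustrative
-- word ids and no initial topic assignments (@Nothing@), sampled for
-- 100 passes from seed 42.
exampleRun :: (V.Vector Doc, LDA)
exampleRun = runLDA 42 100 (initial 20 50 0.01) docs
  where docs = V.fromList
          [ (0, U.fromList [(0, Nothing), (1, Nothing), (2, Nothing)])
          , (1, U.fromList [(1, Nothing), (3, Nothing)])
          ]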

-- | Remove zero counts from a count table
compress :: Table2D -> Table2D
compress = IntMap.map dezero

-- Private functions --

-- | Run a pass on a single doc
passOne :: Doc -> Sampler Doc
passOne (d, wz) = do 
  zs <- U.mapM one wz
  return (d, U.zip (U.map fst wz) (U.map Just zs))
  where one (w, mz) = do
          m <- lift get
          let m' = maybe m (update (-1) m d w) mz -- decrement counts
          lift $ put m'
          z <- randomZ d w                        -- sample topic
          lift $ put (update 1 m' d w z)          -- increment counts
          return z

-- | Sample a random topic for doc d and word w
randomZ :: D -> W -> Sampler Z
randomZ d w = do
  m <- lift get
  sampleCategorical . fromWeightedList . U.toList . U.map swap . U.indexed 
    . wordTopicWeights m d 
    $ w
    
-- | @wordTopicWeights m d w@ returns the unnormalized probabilities of
-- topics for word @w@ in document @d@ given LDA model @m@.
wordTopicWeights :: LDA -> D -> W -> U.Vector Double
wordTopicWeights m d w =
  let k = topicNum m
      a = alphasum m / fromIntegral k
      b = beta m
      dt = IntMap.findWithDefault IntMap.empty d . docTopics $ m
      wt = IntMap.findWithDefault IntMap.empty w . wordTopics $ m
      v = fromIntegral . vSize $ m
      weights = [   (count z dt + a)     -- n(z,d) + alpha 
                  * (count z wt + b)     -- n(z,w) + beta
                  * (1/(count z (topics m) + v * b))
                      -- n(z,.) + V * beta
                  | z <- [0..k-1] ]
  in U.fromList weights
{-# INLINE wordTopicWeights #-}
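
-- The weights above are the usual collapsed Gibbs conditional for LDA
-- with symmetric hyperparameters:
--
--   P(z = t | rest) ~ (n(d,t) + alpha) * (n(w,t) + beta) / (n(t) + V * beta)
--
-- where n(d,t), n(w,t) and n(t) are the current doc-topic, word-topic
-- and topic totals.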

-- | @docTopicWeights m doc@ returns the unnormalized topic
-- probabilities for document @doc@ given LDA model @m@
docTopicWeights :: LDA -> Doc -> U.Vector Double
docTopicWeights m (d, ws) = 
    U.accumulate (+) (U.replicate (topicNum m) 0)
  . U.concatMap (U.indexed . wordTopicWeights m d)
  . U.map fst 
  $ ws
{-# INLINE docTopicWeights #-}
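
-- A minimal prediction sketch (hypothetical helper, not exported):
-- normalizing @docTopicWeights@ gives a distribution over topics for a
-- document. Assumes a non-empty document so the total is non-zero.
docTopicDist :: LDA -> Doc -> U.Vector Double
docTopicDist m doc =
  let ws    = docTopicWeights m doc
      total = U.sum ws
  in U.map (/ total) ws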
  
-- | Update counts in the model corresponding to given doc, word and topic
update :: Double -> LDA -> D -> W -> Z -> LDA
update c m d w z = 
  m { docTopics  = upd c (docTopics m) d z
    , wordTopics = upd c (wordTopics m) w z
    , topics = IntMap.insertWith' (+) z c (topics m) 
      -- grow the vocabulary size when word w has not been seen before
    , vSize = vSize m + (fromEnum . IntMap.notMember w . wordTopics $ m)
    }
  
-- FIXME: define a more efficient version
-- | Increment table @m@ by @c@ at key @(k, k')@
upd :: Double -> Table2D -> Int -> Int -> Table2D
upd c m k k' = IntMap.insertWith' (flip (IntMap.unionWith (+)))
                               k 
                               (IntMap.singleton k' c)
                               m
{-# INLINE upd #-}

sampleCategorical :: Categorical Double Z -> Sampler Z
sampleCategorical = sampleRVarT . rvarT 
{-# INLINE sampleCategorical #-}

dezero :: IntMap.IntMap Double -> IntMap.IntMap Double
dezero = IntMap.filter (/=0)
{-# INLINE dezero #-}

-- | Swap the order of keys in a Table2D
invert :: Table2D -> Table2D
invert outer = 
  List.foldl' (\z (k,k',v) -> upd v z k k')  IntMap.empty 
  [ (k',k,v)
    | (k, inner) <- IntMap.toList outer
    , (k', v) <- IntMap.toList inner ]
{-# INLINE invert #-}

-- | Strictly swap a (topic, weight) pair into the (weight, topic)
-- order expected by 'fromWeightedList'
swap :: (Int, Double) -> (Double, Int)
swap (!a, !b) = (b, a)