{-# LANGUAGE 
   OverloadedStrings  
 , FlexibleInstances
 , DeriveGeneric 
 , NoMonomorphismRestriction 
 , DeriveDataTypeable
 , TemplateHaskell
 , BangPatterns
 #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- | Word Class induction with LDA
--
-- This module provides functions which implement word class induction
-- using the generic LDA sampler implemented in "NLP.SwiftLDA".
--
-- You can access and set options in the @Options@ record using lenses.
-- Example:
--
-- >  import Data.Label
-- >  let options =   set passes 5 
-- >                . set beta 0.01 
-- >                . set topicNum 100 
-- >                $ defaultOptions
-- >  in learn options sentences

module Colada.WordClass 
       ( 
         -- * Running the sampler
         learn
       , defaultOptions
         -- * Extracting information
       , summary
       , summarize
       , wordTypeClasses
         -- * Class and word prediction
       , label
       , predict
         -- * Data types and associated lenses
       , WordClass       
       , ldaModel
         -- | LDA model
       , wordTypeTable
         -- | Word type string to atom and vice versa conversion tables
       , featureTable
         -- | Feature string to atom and vice versa conversion tables
       , options
         -- | Options for Gibbs sampling
       , LDA.Finalized
       , LDA.docTopics
       , LDA.wordTopics
       , LDA.topics
       , LDA.topicDocs
       , LDA.topicWords
       , Options
       , featIds
         -- | Feature ids
       , topicNum
         -- | Number of topics K
       , alphasum
         -- | Dirichlet parameter alpha*K which controls topic sparseness
       , beta
         -- | Dirichlet parameter beta which controls word sparseness
       , passes
         -- | Number of sampling passes per batch
       , repeats
         -- | Number of repeats per sentence
       , batchSize
         -- | Number of sentences per batch
       , seed
         -- | Random seed for the sampler
       , topn
         -- | Number of most probable words to return
       , initSize
       , initPasses
       , exponent
       , progressive
       , pre
       , lambda
       )
where       
  
-- Standard libraries  
import qualified Data.Text.Lazy.IO          as Text
import qualified Data.Text.Lazy             as Text
import qualified Data.Text.Lazy.Builder     as Text
import qualified Data.Text.Lazy.Builder.Int as Text
import qualified Data.Text.Lazy.Encoding    as Text
import qualified Data.Vector                as V
import qualified Data.Vector.Generic        as G
import qualified Data.Vector.Unboxed        as U
import qualified Data.IntMap                as IntMap
import qualified Data.Map                   as Map
import qualified Data.Serialize             as Serialize
import qualified Control.Monad              as M
import qualified Data.List                  as List
import qualified Data.List.Split            as Split
import qualified Data.Ord                   as Ord
import qualified Data.Foldable              as Fold
import qualified Data.Traversable           as Trav
import qualified Control.Monad.ST           as ST
import qualified Control.Monad.ST.Lazy      as LST
import Data.Function (on)
import Control.Monad.Writer       
import Data.Word           (Word32)
import Data.Typeable       (Typeable)
import Data.Data           (Data)
import Prelude                              hiding ((.), exponent)
import Control.Category    ((.))
import Control.Applicative ((<$>))
import qualified System.IO.Unsafe as Unsafe
-- Third party modules  
import qualified Control.Monad.Atom  as Atom
import qualified NLP.CoNLL  as CoNLL
import qualified Data.List.Zipper as Zipper
import GHC.Generics (Generic)
import qualified Data.Label as L
import Data.Label (get)

import qualified NLP.SwiftLDA as LDA

-- Package modules
import qualified Colada.Features as F
import qualified NLP.Symbols     as Symbols

import Debug.Trace

-- | Container for the Word Class model
data WordClass = 
  WordClass { _ldaModel      :: LDA.Finalized      -- ^ LDA model
            , _wordTypeTable :: Atom.AtomTable (U.Vector Char) 
              -- ^ Word type string to atom and vice versa conversion table
            , _featureTable  :: Atom.AtomTable (U.Vector Char)
              -- ^ Feature string to atom and vice versa conversion table
            , _options       :: Options 
              -- ^ Options the model was trained with (stored by 'learn')
            }
  deriving (Generic)

-- | Options controlling feature extraction and Gibbs sampling.
-- Use the generated lenses (field name without the underscore) to read
-- and update individual settings.
data Options = Options { _featIds    :: [Int]
                         -- ^ Feature ids: the keys looked up in each token's
                         -- feature map when building context features
                       , _topicNum   :: !Int     
                         -- ^ Number of topics K
                       , _alphasum   :: !Double  
                         -- ^ Dirichlet parameter alpha*K which controls topic
                         -- sparseness
                       , _beta       :: !Double  
                         -- ^ Dirichlet parameter beta which controls word
                         -- sparseness
                       , _passes     :: !Int     
                         -- ^ Number of sampling passes per batch
                       , _repeats    :: !Int     
                         -- ^ Number of repeats per sentence
                       , _batchSize  :: !Int    
                         -- ^ Number of sentences per batch
                       , _seed       :: !Word32
                         -- ^ Random seed for the sampler
                       , _topn       :: !Int
                         -- ^ Number of most probable words to return
                       , _initSize   :: !Int
                         -- ^ Number of leading sentences that 'learn' samples
                         -- as one initial batch
                       , _initPasses :: !Int
                         -- ^ Number of sampling passes over the initial batch
                       , _exponent   :: !(Maybe Double)
                         -- ^ Handed unchanged to @LDA.initial@; its meaning is
                         -- defined by NLP.SwiftLDA (not visible in this module)
                       , _progressive :: !Bool
                         -- ^ Whether 'learn' emits class labels while sampling
                       , _pre         :: !Bool
                         -- ^ With '_progressive': label before the first pass
                         -- over a batch instead of after the last one
                       , _lambda     :: !Double
                         -- ^ Interpolation weight used for progressive
                         -- labeling (weight of the word-type prior)
                       }
             deriving (Eq, Show, Typeable, Data, Generic)

-- Binary serialization for 'Options'.  The instance body is empty, so the
-- class's default method implementations are used (presumably the
-- Generic-based defaults, since 'Options' derives Generic; confirm against
-- the cereal version in use).
instance Serialize.Serialize Options

-- Template Haskell splice: generate Data.Label lenses for the record fields
-- of 'WordClass' and 'Options'.  The lenses carry the field names without
-- the leading underscore (e.g. @topicNum@, as listed in the export list).
$(L.mkLabels [''WordClass, ''Options])

-- | Baseline settings: 10 topics, one sampling pass per single-sentence
-- batch, context features taken from ids -1 and 1, no initial prefix and
-- no progressive labeling.  Override individual fields through the lenses.
defaultOptions :: Options
defaultOptions =
  Options
    { _featIds     = [-1,1]
    , _topicNum    = 10
    , _alphasum    = 10
    , _beta        = 0.1
    , _passes      = 1
    , _repeats     = 1
    , _batchSize   = 1
    , _seed        = 0
    , _topn        = maxBound  -- no limit on the number of returned words
    , _initSize    = 0         -- no initial batch
    , _initPasses  = 100
    , _exponent    = Nothing
    , _progressive = False
    , _pre         = False
    , _lambda      = 1.0
    }
                 
-- | @learn opts xs@ runs the LDA Gibbs sampler for word classes with
-- options @opts@ on sentences @xs@.
--
-- The first @initSize@ sentences form one initial batch which is sampled
-- for @initPasses@ passes.  The remaining sentences are grouped into
-- batches of @batchSize@ sentences, each sentence replicated @repeats@
-- times, and every batch is sampled for @passes@ passes.
--
-- Returns the resulting model together with the progressive class
-- assignments.  When @progressive@ is set, every batch contributes one
-- vector of per-token class distributions per sentence, computed either
-- before the first pass over the batch (@pre@) or after its last pass.
-- Without @progressive@ the second component is empty.
--
-- Note: progressive labeling reads the first copy of each sentence with
-- @V.head@, so it requires @repeats >= 1@.
learn :: Options 
         -> [CoNLL.Sentence] 
         -> (WordClass, [V.Vector (U.Vector Double)])
learn opts xs = 
  let ((bs_init, bs_rest), atomTabD, atomTabW) = 
        Symbols.runSymbols prepare Symbols.empty Symbols.empty
      -- Split off the initialization prefix and turn both parts into
      -- batches of symbolized sentences.
      prepare = do
        let (xs_init, xs_rest) = List.splitAt (get initSize opts) xs
        ini  <- prepareData (get initSize opts)
                            1
                            (get featIds opts)
                            xs_init
        rest <- prepareData (get batchSize opts)
                            (get repeats opts)
                            (get featIds opts)
                            xs_rest
        return (ini, rest)
      sampler :: WriterT [V.Vector (U.Vector Double)] (LST.ST s) LDA.Finalized
      sampler = do
        m <- st $ LDA.initial (U.singleton (get seed opts))
                              (get topicNum opts)
                              (get alphasum opts)
                              (get beta opts)
                              (get exponent opts)
        -- One sampling pass @i@ (out of @i_last@) over @batch@ at time
        -- step @t@; returns the batch as updated by the sampler so that
        -- it can be threaded through 'Fold.foldlM'.
        let loop t i_last batch i = do
              -- Emit class distributions for the first copy of every
              -- sentence in the batch.  The body is indented past the
              -- binding name so that the block parses without relying on
              -- the NondecreasingIndentation extension.
              let label_prog =
                    Fold.forM_ batch $ \rep -> do
                      let sent = V.head rep
                      ls <- st $ V.mapM (interpWordClasses m (get lambda opts)) sent
                      tell [ls]
              -- Either label before sampling (--pre)
              M.when (get progressive opts && i == 1 && get pre opts) label_prog
              r <- st $ Trav.forM batch $ \rep ->
                     Trav.forM rep $ \sent ->
                       LDA.pass t m sent
              -- Or label after sampling
              M.when (get progressive opts && i == i_last && not (get pre opts)) label_prog
              return $! r
        -- Initialize with batch sampler on prefix bs_init
        Fold.forM_ bs_init $ \batch ->
          Fold.foldlM (loop 1 $ get initPasses opts) batch [1..get initPasses opts]
        -- Continue sampling
        Fold.forM_ (zip [1..] bs_rest) $ \(t, batch) ->
          Fold.foldlM (loop t $ get passes opts) batch [1..get passes opts]
        st $ LDA.finalize m
      (lda, labeled) = LST.runST (runWriterT sampler)
  in (WordClass lda atomTabD atomTabW opts, labeled)

-- | Symbol-table monad used while preparing data; both kinds of symbols
-- are keyed by unboxed character vectors (see 'compress').
type Symb  = Symbols.Symbols (U.Vector Char) (U.Vector Char)
-- | A sentence: one LDA document per token (built by 'symbolize').
type Sent  = V.Vector LDA.Doc
-- | The copies of one sentence created with @V.replicate@ in 'prepareData'.
type Repeat = V.Vector Sent
-- | A batch: the repeated sentences that are sampled together.
type Batch = V.Vector Repeat

-- | Convert a stream of sentences into a stream of batches ready for
-- sampling.  Every sentence is featurized and symbolized, replicated
-- @rep@ times, and the replicated sentences are grouped into batches of
-- @bsz@ sentences (the last batch may be shorter).
--
-- A batch size below 1 is treated as 1: @Split.chunksOf@ with a
-- non-positive size produces an infinite list of empty chunks, which
-- would make the sampler loop forever.
prepareData ::   Int                          -- ^ batch size 
               -> Int                         -- ^ no. repeats 
               -> [Int]                       -- ^ feature indices
               -> [CoNLL.Sentence]            -- ^ stream of sentences
               -> Symb [Batch]                -- ^ stream of batches
prepareData bsz rep is ss = do
  ss' <- mapM (symbolize . featurize is) ss
  return $! ( map V.fromList
            . Split.chunksOf (max 1 bsz)
            . map (V.replicate rep)
            $ ss' )

-- | Extract features from a sentence.  For every token the result holds
-- the focus feature (the entry with key 0 of the token's feature map)
-- paired with its context features: for each id @i@ in @is@ that is
-- present in the map, the feature text tagged as @feature^i@.
--
-- Calls 'error' (lazily, when the focus is demanded) if a token has no
-- entry for key 0.
featurize :: [Int] 
             -> CoNLL.Sentence 
             -> [(Text.Text, [Text.Text])]
featurize is s = map mk (extractFeatures s)
  where
    mk fs = (focus, context)
      where
        focus =
          IntMap.findWithDefault
            (error "Colada.WordClass.featurize: focus feature missing") 0 fs
        context =
          [ Text.concat [f, "^", Text.pack (show i)]
          | i <- is
          , Just f <- [IntMap.lookup i fs] ]

-- | Replace the strings of a featurized sentence by symbols (ints): the
-- focus word is interned in the first symbol table, its features in the
-- second.  Every feature is paired with @Nothing@.
symbolize ::  [(Text.Text, [Text.Text])] -> Symb Sent
symbolize s = fmap V.fromList (mapM toDoc s)
  where
    toDoc (word, feats) = do
      wordAtom  <- Symbols.toAtomA (compress word)
      featAtoms <- M.forM feats (Symbols.toAtomB . compress)
      return (wordAtom, U.fromList [ (a, Nothing) | a <- featAtoms ])

-- | @summary m@ renders the word classes found in model @m@ as text,
-- without hardening the class assignments (see 'summarize').
summary :: WordClass -> Text.Text
summary m = summarize False m

-- | @summarize harden m@ renders one line per word class of model @m@:
-- the class id followed by up to ten word types with the largest positive
-- weight in that class, heaviest first.
--
-- With @harden@ set, each word type first has its whole weight moved to
-- its single heaviest class.
summarize :: Bool -> WordClass -> Text.Text
summarize harden m =
  fst $ Atom.runAtom
          (Text.unlines `M.liftM` mapM render (IntMap.toList byClass))
          (get wordTypeTable m)
  where
    -- Word type paired with its (class, weight) list.
    byDoc =
      [ (d, IntMap.toList zs)
      | (d, zs) <- IntMap.toList . LDA.docTopics . get ldaModel $ m ]

    -- Maybe harden: keep the total weight on the best class only.
    sharpen entry@(d, zs)
      | harden    = (d, [ (z, if z == best then total else 0) | (z, _) <- zs ])
      | otherwise = entry
      where
        total     = sum (map snd zs)
        (best, _) = List.maximumBy (Ord.compare `on` snd) zs

    -- Inverted view: class -> word type -> weight.
    byClass =
      IntMap.fromListWith (IntMap.unionWith (+))
        [ (z, IntMap.singleton d c)
        | (d, zs) <- map sharpen byDoc
        , (z, c)  <- zs ]

    -- One output line for a class.
    render (z, members) = do
      names <- mapM (Atom.fromAtom . fst)
                 . takeWhile ((> 0) . snd)
                 . take 10
                 . List.sortBy (flip (Ord.comparing snd))
                 . IntMap.toList
                 $ members
      return (Text.unwords (Text.pack (show z) : map (Text.pack . U.toList) names))

-- | @interpWordClasses m lambda doc@ gives the class probabilities for
-- the word type in context @doc@ according to the evolving model @m@.  It
-- interpolates the prior word type probability with the
-- context-conditioned probability using @lambda@:
--
-- > P(z) = lambda * P(z|d) + (1-lambda) * P(z|d,w)
--
-- Both weight vectors are normalized before mixing and the mixture is
-- normalized again.  A weight vector whose sum is zero, infinite or NaN
-- cannot be normalized and is replaced by the uniform distribution.
interpWordClasses ::    LDA.LDA s
                     -> Double 
                     -> LDA.Doc 
                     -> ST.ST s (U.Vector Double)
interpWordClasses m lambda doc@(d,_) = do  
  pzd  <- normalize <$> LDA.priorDocTopicWeights_ m d
  pzdw <- normalize <$> LDA.docTopicWeights_ m doc
  return $! normalize $ U.zipWith (\p q -> lambda * p + (1-lambda) * q) pzd pzdw
  where
    normalize x
      -- NaN compares false against everything, so it needs its own test;
      -- dividing by it would turn every entry into NaN.
      | total == 0 || isNaN total || isInfinite total = uniform
      | otherwise                                     = U.map (/ total) x
      where
        size    = U.length x
        total   = U.sum x
        uniform = U.replicate size (1 / fromIntegral size)
        
-- | @wordTypeClasses m@ maps every word type known to model @m@ to its
-- unnormalized distribution over word classes.
wordTypeClasses :: WordClass -> Map.Map Text.Text (IntMap.IntMap Double)
wordTypeClasses m =
  fst $ Atom.runAtom
          (Map.fromList <$> M.forM entries (\(k, v) -> do
              k' <- Atom.fromAtom k   -- atom back to its word type string
              return (decompress k', v)))
          (get wordTypeTable m)
  where
    entries = IntMap.toList (LDA.docTopics (get ldaModel m))
    
  
-- | @label noctx m s@ returns, for each word in sentence @s@, the
-- unnormalized probabilities of the word classes under model @m@.
--
-- With @noctx@ set, the features of every token are replaced by the
-- single placeholder feature @(-1, Nothing)@ before weighting
-- (presumably so that only the word type contributes; acknowledged in
-- the code as a hack).
label :: Bool -> WordClass -> CoNLL.Sentence -> V.Vector (U.Vector Double)
label noctx m s =
  fst3 (Symbols.runSymbols action (L.get wordTypeTable m) (L.get featureTable m))
  where
    model = L.get ldaModel m
    strip doc@(d, _)
      | noctx     = (d, U.singleton (-1,Nothing)) --FIXME: ugly hack
      | otherwise = doc
    action = do
      sent <- prepareSent m s
      return $! V.map (LDA.docTopicWeights model . strip) sent

-- | @predict m s@ returns, for each word position in sentence @s@, the
-- candidate words with their unnormalized probabilities given the
-- features at that position (see 'predictDoc').  At most @topn@
-- candidates are returned per position, best first.
predict :: WordClass -> CoNLL.Sentence 
           -> [V.Vector (Double, Text.Text)]
predict m s =
  fst3 (Symbols.runSymbols action (L.get wordTypeTable m) (L.get featureTable m))
  where
    model = L.get ldaModel m
    limit = get (topn . options) m
    -- Turn a candidate's atom back into its word type string.
    named (weight, atom) = do
      word <- Symbols.fromAtomA atom
      return (weight, decompress word)
    action = do
      sent <- prepareSent m s
      M.forM (V.toList sent) $ \doc ->
        -- only the feature ids of the position are used for prediction
        V.mapM named (G.convert (predictDoc limit model (U.map fst (snd doc))))

-- | Featurize a sentence with the feature ids stored in the model's
-- options and convert the result to symbols.
prepareSent :: WordClass -> CoNLL.Sentence -> Symb Sent
prepareSent m s = symbolize (featurize (L.get (featIds . options) m) s)

-- | @predictDoc n m ws@ returns unnormalized probabilities of the top @n@
-- most probable document ids given the model @m@ and words @ws@, best
-- first.  The candidate document ids are taken from the model @m@.  The
-- weights are computed according to the following formula:
--
-- > P(d|{w}) ∝ Σ_z[n(d,z)+a Σ_{w in ws}(n(w,z)+b)/(Σ_{w in V} n(w,z)+b)]
--
-- For every document the sum runs over the topics recorded for that
-- document in the model.
predictDoc ::  Int -> LDA.Finalized -> U.Vector LDA.W 
               -> U.Vector (Double, LDA.D)
predictDoc n m ws =
  U.fromList
    . take n
    . List.sortBy (flip compare)
    . map score
    . IntMap.toList
    . LDA.docTopics
    $ m
  where
    topicTotals = LDA.topics m
    a           = LDA.alphasum m / fromIntegral (LDA.topicNum m)
    b           = LDA.beta m
    vocab       = fromIntegral (LDA.wSize m)

    -- Word counts of topic z (empty when the topic has none).
    wordCounts z = IntMap.findWithDefault IntMap.empty z (LDA.topicWords m)

    -- Per topic: summed smoothed probability of the given words.
    wsums = IntMap.mapWithKey perTopic topicTotals
    perTopic z _ = U.sum (U.map (\w -> (count w wt_z + b) / denom) ws)
      where
        wt_z  = wordCounts z
        denom = count z topicTotals + b * vocab

    wsum z =
      IntMap.findWithDefault
        (error $ "Colada.WordClass.predictDoc: key not found: "
         ++ show z)
        z wsums

    -- Weight of one candidate document.
    score (d, zt_d) =
      (sum [ (c + a) * wsum z | (z, c) <- IntMap.toList zt_d ], d)

-- | Build the feature maps of a sentence: take field 0 of every token,
-- put the resulting list into a zipper and hand it to @F.featureSeq@,
-- with 'combine' as the function for merging feature values.
extractFeatures :: CoNLL.Sentence -> [IntMap.IntMap Text.Text]
extractFeatures sent =
  F.featureSeq combine (Zipper.fromList [ tok V.! 0 | tok <- V.toList sent ])
  
-- | Merge two optional feature values: when both are present they are
-- joined with a @|@ separator, otherwise whichever one is present (if
-- any) is returned.
combine :: Maybe Text.Text -> Maybe Text.Text -> Maybe Text.Text
combine l r = case (l, r) of
  (Just a, Just b) -> Just (Text.concat [a, "|", b])
  (Nothing, _)     -> r
  (_, Nothing)     -> l
  



-- | Store a text as an unboxed vector of its characters.
compress :: Text.Text -> U.Vector Char
compress txt = U.fromList (Text.unpack txt)

-- | Turn an unboxed character vector back into a text.
decompress :: U.Vector Char -> Text.Text
decompress chars = Text.pack (U.toList chars)

-- | First component of a triple.
fst3 :: (a, b, c) -> a
fst3 triple = case triple of
  (x, _, _) -> x

-- | @count z t@ looks up the count stored under key @z@ in @t@; a missing
-- key counts as 0.  A negative stored count is treated as a broken
-- invariant and raises an error.
count :: Int -> IntMap.IntMap Double -> Double
count z t
  | n < 0     = error "Colada.WordClass.count: negative count"
  | otherwise = n
  where
    n = IntMap.findWithDefault 0 z t
{-# INLINE count #-}   
       
-- | Run a strict 'ST.ST' action inside a writer over lazy ST.
st :: Monoid w => ST.ST s a -> WriterT w (LST.ST s) a
st action = lift (LST.strictToLazyST action)

-- Instances for serialization
--
-- These are orphan instances (the module switches off the orphan warning
-- at the top of the file).

-- Empty body: the class's default method implementations are used
-- (presumably Generic-based, which requires a Generic instance for
-- 'LDA.Finalized'; confirm against NLP.SwiftLDA).
instance Serialize.Serialize LDA.Finalized

-- | Lazy text is stored as its UTF-8 encoded bytes.  Decoding validates
-- the bytes: invalid UTF-8 makes the parser fail with the decoding error
-- instead of returning a text that throws when it is later evaluated.
instance Serialize.Serialize Text.Text where
  put = Serialize.put . Text.encodeUtf8
  get = Serialize.get >>= either (fail . show) return . Text.decodeUtf8'
  
-- | Unboxed character vectors are serialized through their lazy text
-- form; the vector is built eagerly when reading.
instance Serialize.Serialize (U.Vector Char) where
  put = Serialize.put . decompress
  get = Serialize.get >>= \t -> return $! compress t
       
-- Empty body: uses the class's default method implementations (presumably
-- Generic-based, which requires a Generic instance for 'Atom.AtomTable';
-- confirm against Control.Monad.Atom).
instance Serialize.Serialize (Atom.AtomTable (U.Vector Char))

-- Empty body: uses the class's default method implementations; 'WordClass'
-- derives Generic above, and its fields are covered by the other instances
-- in this module.
instance Serialize.Serialize WordClass