{-# LANGUAGE GeneralizedNewtypeDeriving, TypeFamilies, Rank2Types, FlexibleContexts #-}
module NLP.Probability.ConditionalDistribution (  
  -- * Conditional Distributions
  --                    
  -- $CondDistDesc  
                                                CondObserved(),
                                                CondDistribution,
                                                condObservation,
                                                condObservations,
                                                condObservationCounts,
                                                Context(..), 
                                                estimateGeneralLinear,
                                                Weighting,
                                                wittenBell, 
                                                simpleLinear,
                                                DebugDist,
                                                mkDist
                                                ) where 
import qualified Data.ListTrie.Base.Map as M
import Data.List (inits)
import Data.Monoid
import qualified NLP.Probability.SmoothTrie as ST
import NLP.Probability.Distribution
import NLP.Probability.Observation 
import Data.Binary
import Text.PrettyPrint.HughesPJClass
-- $CondDistDesc
-- Say we want to estimate a conditional distribution based on a very large set of observed data.
-- Naively, we could just collect all the data and estimate one large table, but
-- that table would have few or no counts for many feasible future observations. 
--
-- In practice, we use smoothing to supplement rare contexts with data from similar, more frequently seen contexts. For instance,
-- backing off to bigram probabilities when the observed trigram counts are too sparse. 
-- Most of these smoothing techniques are special cases of general linear interpolation, which chooses the weight of 
-- each level of smoothing based on the sparsity of the current context. 
--
-- In this module, we give an implementation of this process that separates count collection
-- from the smoothing model, using a Trie. The user specifies a Context instance that relates the full conditional context
-- to a sequence of SubContexts that characterize the levels of smoothing and the transitions in the Trie. We also give a small set of smoothing techniques 
-- to combine these levels. 
--
-- This work is based on Chapter 6 of ''Foundations of Statistical Natural Language Processing'' 
-- by Chris Manning and Hinrich Schutze. 
-- 
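-- As an illustration only (this instance is not part of the module and the names are
-- made up), a trigram word context might decompose with the nearest previous word
-- first, so that the smoothing levels are the bigram context and then the full
-- trigram context. The choice of 'M.AList' for the SubMap is an arbitrary assumption;
-- any map type from "Data.ListTrie.Base.Map" over the subcontext type would do.
--
-- > data Trigram = Trigram { prevWord :: String, prevPrevWord :: String }
-- >
-- > instance Context Trigram where
-- >     type Sub Trigram    = String
-- >     type SubMap Trigram = M.AList
-- >     decompose (Trigram w1 w2) = [w1, w2]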




-- | The set of observations of an event conditioned on a context. The event type must be an instance of Event and the context type an instance of Context. 
type CondObserved event context = (ST.SmoothTrie (SubMap context) (Sub context) (Counts event))

 
-- | Events are conditioned on Contexts. When Contexts are sparse, we need a way to decompose them into simpler SubContexts. 
--   This class allows us to separate this decomposition from the collection of larger contexts. 
class (M.Map (SubMap a) (Sub a)) => Context a where 
    -- | The type of sub contexts
    type Sub a  
    -- | A map over subcontexts (for efficiency) 
    type SubMap a :: * -> * -> * 

    -- | A function to enumerate subcontexts of a context  
    decompose ::  a -> [Sub a] 


-- | A CondObserved set for a given count of observations of a single event in a single context. 
condObservations :: (Context context, Event event) => 
             event -> context -> Count -> CondObserved event context
condObservations event context count = 
    ST.addColumn decomp observed mempty 
        where observed = observations event count 
              decomp = decompose context 

-- | A CondObserved set for a single observation (count 1.0) of an event in a context. 
condObservation :: (Context context, Event event) => 
             event -> context -> CondObserved event context
condObservation event context = condObservations event context 1.0

-- | A CondObserved set built from a full table of event counts observed in a single context. 
condObservationCounts :: (Context context, Event event) => 
             context -> Counts event -> CondObserved event context
condObservationCounts context counts =
    ST.addColumn decomp counts mempty 
        where decomp = decompose context 
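
-- A sketch of count collection (illustrative only: 'Trigram' is the hypothetical
-- context instance from the module documentation, and 'collectCounts' is a made-up
-- name). Observation tries combine as a Monoid, so a corpus of (event, context)
-- pairs can simply be mconcat'd together:
--
-- > collectCounts :: (Event event) => [(event, Trigram)] -> CondObserved event Trigram
-- > collectCounts corpus = mconcat [ condObservation e ctx | (e, ctx) <- corpus ]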
    

-- | A conditional distribution maps a context to a distribution over events. 
type CondDistribution event context = context -> Distribution event

-- | A conditional distribution that, for debugging, exposes the (weight, probability)
--   pair contributed by each level of smoothing. 
type DebugDist event context = context -> event -> [(Double, Double)]

-- | A Weighting maps the observations at each level of smoothing (most specific
--   context first, ending with the final epsilon backoff level, marked Nothing)
--   to interpolation weights (lambdas). 
type Weighting = forall a. [Maybe (Observed a)] -> [Double]

-- | Collapse a DebugDist into a CondDistribution by summing the weighted probability
--   contributed by each level of smoothing. 
mkDist :: DebugDist event context -> CondDistribution event context
mkDist dd context event = sum $ map (uncurry (*)) $ dd context event

-- | General Linear Interpolation. Produces a Conditional Distribution from observations.
--   It requires a Weighting function which tells it how to weight each level of smoothing. 
--   The Weighting function can observe the counts at each level of context. 
--
--   Note: We include a final level of backoff where every event is given an epsilon likelihood. To 
--   ignore this level, just give it a lambda of 0.
estimateGeneralLinear :: (Event event, Context context) => 
                         Weighting -> 
                         CondObserved event context -> 
                         DebugDist event context
estimateGeneralLinear genLambda cstat = conFun 
    where
      conFun context = (\event -> zip lambdas $ map (probE event) stats) 
          where -- Observations at each level of smoothing: most specific
                -- context first, ending with the final epsilon backoff level (Nothing).
                stats = reverse $ 
                        Nothing : (map (\k -> Just $ ST.lookupWithDefault (finish mempty) k cstat')  $ 
                                  tail $ inits $ decompose context)
                probE event (Just dist) = if isNaN p then 0.0 else p
                    where p = mle dist event
                -- The epsilon backoff level gives every event a tiny likelihood.
                probE event Nothing = 1e-19
                lambdas = genLambda stats
      cstat' = fmap finish cstat
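
-- A sketch of assembling a smoothed conditional distribution (illustrative only:
-- 'Trigram' is the hypothetical context from the module documentation, 'counts' is
-- assumed to be collected as above, and the Witten-Bell constant 5 is an arbitrary
-- choice):
--
-- > smoothed :: (Event event) => CondObserved event Trigram -> CondDistribution event Trigram
-- > smoothed counts = mkDist (estimateGeneralLinear (wittenBell 5) counts)
-- >
-- > -- Keeping the DebugDist instead exposes each level's (weight, probability) pairs:
-- > debug :: (Event event) => CondObserved event Trigram -> DebugDist event Trigram
-- > debug = estimateGeneralLinear (wittenBell 5)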

-- | Weight each level by a fixed predefined amount. 
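--
--   For instance (the weights here are arbitrary), for a context that decomposes into
--   two subcontexts the list needs three entries: the first weight applies to the most
--   specific level, the second to the more general level, and the last to the final
--   epsilon backoff level.
--
-- > simpleLinear [0.7, 0.3, 0.0]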
simpleLinear :: [Double] -> Weighting
simpleLinear lambdas = const lambdas


-- Witten-Bell weight for a single level: t / ((n * d) + t).
lambdaWBC :: Int -> Observed b -> Double
lambdaWBC n eobs = total' / (((fromIntegral n) * distinct) + total')
    where total' = total eobs
          distinct = unique eobs

-- | Weight each level by the likelihood that a new event will be seen at that level. 
--   t / ((n * d) + t) where t is the total count, d is the number of distinct observations,
--   and n is a user defined constant.   
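--
--   For instance, a level observed with a total count of 10 over 4 distinct events,
--   with n = 1, gets the weight 10 / ((1 * 4) + 10), roughly 0.71, leaving roughly
--   0.29 of the remaining mass for the more general levels after it.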
wittenBell :: Int -> Weighting 
wittenBell n ls = wittenBell' ls 1.0
    where 
      -- The final epsilon backoff level (Nothing) absorbs whatever weight remains.
      wittenBell' [Nothing] mult = [mult]
      wittenBell' (Just cur:ls) mult = 
          if total cur > 0 then (l*mult : wittenBell' ls ((1-l)*mult)) 
          else (0.0: wittenBell' ls mult)  
              where l = lambdaWBC n cur