-- | Conditional probability estimation from observed (event, context) counts.
--
-- Observations are accumulated in a 'CondObserved' trie keyed by the
-- decomposition of a 'Context'. 'estimateGeneralLinear' turns such a trie
-- into per-level (weight, probability) pairs using a 'Weighting' scheme
-- ('wittenBell' or 'simpleLinear'), and 'mkDist' collapses those pairs into a
-- single probability.
module NLP.Probability.ConditionalDistribution (
CondObserved(),
CondDistribution,
condObservation,
condObservations,
condObservationCounts,
Context(..),
estimateGeneralLinear,
Weighting,
wittenBell,
simpleLinear,
DebugDist,
mkDist
) where
import qualified Data.ListTrie.Base.Map as M
import Data.List (inits)
import Data.Monoid
import qualified NLP.Probability.SmoothTrie as ST
import NLP.Probability.Distribution
import NLP.Probability.Observation
import Data.Binary
import Text.PrettyPrint.HughesPJClass
-- | Observed event counts conditioned on a context: a smoothing trie whose
-- keys are the 'Sub' components of the context (indexed through 'SubMap')
-- and whose values are 'Counts' over events.
type CondObserved event context =
  ST.SmoothTrie (SubMap context) (Sub context) (Counts event)
-- | A conditioning context that can be broken into a list of components.
--
-- The component list is used as the key path in a 'CondObserved' trie, and
-- its non-empty prefixes are the levels 'estimateGeneralLinear' looks up.
class M.Map (SubMap a) (Sub a) => Context a where
  -- | The type of a single component of the context.
  type Sub a

  -- | The map type used to index 'Sub' components; together with 'Sub' it
  -- must satisfy the @M.Map@ superclass constraint.
  type SubMap a :: * -> * -> *

  -- | Split a context into its ordered list of components.
  decompose :: a -> [Sub a]
-- | Build a 'CondObserved' recording that @event@ was seen @count@ times in
-- @context@.
--
-- The counts come from 'observations' and are added to an empty trie with
-- 'ST.addColumn', using @decompose context@ as the key path.
condObservations :: (Context context, Event event) =>
                    event -> context -> Count -> CondObserved event context
condObservations event context count =
  ST.addColumn (decompose context) (observations event count) mempty
-- | Build a 'CondObserved' recording a single observation of @event@ in
-- @context@; equivalent to 'condObservations' with a count of @1.0@.
condObservation :: (Context context, Event event) =>
                   event -> context -> CondObserved event context
condObservation event context = condObservations event context 1.0
-- | Build a 'CondObserved' from an already-accumulated set of event counts
-- for one @context@.
--
-- The counts are added to an empty trie with 'ST.addColumn', using
-- @decompose context@ as the key path.
condObservationCounts :: (Context context, Event event) =>
                         context -> Counts event -> CondObserved event context
condObservationCounts context counts =
  ST.addColumn (decompose context) counts mempty
-- | A conditional distribution: given a context, a 'Distribution' over events.
type CondDistribution event context = context -> Distribution event
-- | An un-collapsed conditional distribution: for a context and an event, a
-- list of pairs whose products are summed by 'mkDist'. As produced by
-- 'estimateGeneralLinear', each pair is (interpolation weight, probability
-- estimate) for one level.
type DebugDist event context = context -> event -> [(Double, Double)]
-- | A scheme for choosing interpolation weights.
--
-- The input is one entry per level, in the order 'estimateGeneralLinear'
-- builds them: the longest context prefix first, ending with 'Nothing' for
-- the constant fallback level. The output is zipped with those levels, so it
-- should contain one weight per entry.
type Weighting = forall a. [Maybe (Observed a)] -> [Double]
-- | Collapse a 'DebugDist' into a 'CondDistribution' by multiplying the two
-- components of every pair and summing the products.
mkDist :: DebugDist event context -> CondDistribution event context
mkDist dd context event = sum [weight * prob | (weight, prob) <- dd context event]
-- | Estimate a linearly interpolated conditional distribution, keeping the
-- per-level terms separate (see 'mkDist' to collapse them).
--
-- For a context, the levels are the non-empty prefixes of @decompose context@
-- looked up in the finished trie, ordered from the longest prefix to the
-- shortest, followed by a final 'Nothing' level. @genLambda@ is applied to
-- that list to obtain the weights, which are zipped with each level's
-- probability for the requested event.
estimateGeneralLinear :: (Event event, Context context) =>
                         Weighting ->
                         CondObserved event context ->
                         DebugDist event context
estimateGeneralLinear genLambda cstat = perContext
  where
    -- Finish every node's counts once, shared by all queries.
    finished = fmap finish cstat

    perContext context = \event -> zip weights (map (levelProb event) levels)
      where
        -- 'inits' always yields at least the empty prefix, so dropping it
        -- is safe; the empty prefix is not used as a level.
        prefixes = drop 1 (inits (decompose context))
        levels = reverse (Nothing : map (Just . lookupLevel) prefixes)
        weights = genLambda levels

    -- Prefixes missing from the trie fall back to finished empty counts.
    lookupLevel key = ST.lookupWithDefault (finish mempty) key finished

    -- The 'Nothing' level contributes a small constant; a NaN maximum
    -- likelihood estimate is replaced by zero.
    levelProb _ Nothing = 1e-19
    levelProb event (Just dist)
      | isNaN p = 0.0
      | otherwise = p
      where
        p = mle dist event
-- | A 'Weighting' that ignores the observations and always returns the given
-- fixed list of weights.
simpleLinear :: [Double] -> Weighting
simpleLinear lambdas _ = lambdas
-- | Weight for one level, used by 'wittenBell':
-- @total / (n * unique + total)@, where @total@ and @unique@ are read from
-- the level's observations and @n@ is the caller-supplied constant.
lambdaWBC :: Int -> Observed b -> Double
lambdaWBC n eobs = seen / (fromIntegral n * distinct + seen)
  where
    seen = total eobs
    distinct = unique eobs
-- | Witten-Bell style 'Weighting' with constant @n@ (passed to 'lambdaWBC').
--
-- Walks the levels in order while tracking the probability mass not yet
-- handed out (initially @1.0@). A level with a positive total receives the
-- fraction @l@ of the remaining mass and leaves @(1 - l)@ of it for the
-- later levels; a level with no observations receives weight @0@ and leaves
-- the remaining mass untouched. The final 'Nothing' level receives whatever
-- mass is left.
--
-- One weight is produced per input level, so an empty input yields an empty
-- result. A 'Nothing' that is not the last level is treated like a level
-- with no observations (weight @0@).
wittenBell :: Int -> Weighting
wittenBell n levels = go levels 1.0
  where
    go [] _ = []
    go [Nothing] remaining = [remaining]
    go (Nothing : rest) remaining = 0.0 : go rest remaining
    go (Just cur : rest) remaining
      | total cur > 0 = l * remaining : go rest ((1 - l) * remaining)
      | otherwise = 0.0 : go rest remaining
      where
        l = lambdaWBC n cur