{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}

module HLearn.Models.Distributions.Categorical
    ( Categorical (..)
    , CategoricalParams(..)
    , dist2list
    )
    where

import Control.Monad.Random
import Data.List
import Data.List.Extras
import Debug.Trace

import qualified Data.Map.Strict as Map
import qualified Data.Foldable as F

import HLearn.Algebra
import HLearn.Models.Distributions.Common

-------------------------------------------------------------------------------
-- CategoricalParams

-- | The Categorical distribution takes no parameters
data CategoricalParams = CategoricalParams
    deriving (Read,Show,Eq)

-- | The single nullary constructor carries no data, so there is nothing to
-- force; the argument is ignored.
instance NFData CategoricalParams where
    rnf _ = ()

-- | Every 'Categorical' model reports the same (empty) parameter value.
instance Model CategoricalParams (Categorical label probtype) where
    getparams _ = CategoricalParams

-- | The default parameters are the only parameters there are.
instance DefaultModel CategoricalParams (Categorical label probtype) where
    defparams = CategoricalParams

-------------------------------------------------------------------------------
-- Categorical

-- | A categorical distribution, stored as a strict map from each sample value
-- to a @probtype@ value.  The training instances in this module store the
-- data point's weight there (or @1@ for unweighted data points).
data Categorical sampletype probtype = Categorical
    { pdfmap :: !(Map.Map sampletype probtype)
    }
    deriving (Show,Read,Eq)

-- | Lists the @(sample, value)@ pairs held in the distribution's map, in the
-- order produced by 'Map.toList'.
dist2list :: Categorical sampletype probtype -> [(sampletype,probtype)]
dist2list = Map.toList . pdfmap

-- | Fully evaluating a distribution means fully evaluating its map.
instance (NFData sampletype, NFData probtype) => NFData (Categorical sampletype probtype) where
    rnf = rnf . pdfmap

-------------------------------------------------------------------------------
-- Training

-- | Weighted training: a data point arrives paired with its weight, and the
-- resulting model maps that data point to exactly that weight.
instance (Ord label, Num probtype) => HomTrainer CategoricalParams (label,probtype) (Categorical label probtype) where
    train1dp' _ (dp, w) = Categorical (Map.singleton dp w)

-- | Unweighted training: each data point is stored with the value @1@.
instance (Ord label, Num probtype) => HomTrainer CategoricalParams label (Categorical label probtype) where
    train1dp' _ dp = Categorical (Map.singleton dp 1)
-------------------------------------------------------------------------------
-- Distribution

-- | Probability queries over the values stored in 'pdfmap'.
--
-- NOTE(review): 'pdf' and 'cdf' divide by the sum of all stored values, so
-- the division is by zero when the map is empty or its values sum to zero.
instance (Ord label, Ord prob, Floating prob, Random prob) => Distribution (Categorical label prob) label prob where

    -- | The value stored for @label@ (0 when the label is absent) divided by
    -- the sum of all stored values, plus a constant offset of 0.0001.
    -- NOTE(review): the offset presumably keeps unseen labels from getting
    -- probability zero; it also means the results do not sum to exactly 1.
    {-# INLINE pdf #-}
    pdf dist label = 0.0001 + (val / tot)
        where
            val = Map.findWithDefault 0 label (pdfmap dist)
            tot = F.foldl' (+) 0 $ pdfmap dist

    -- | Sum of the values stored under keys @<= label@, divided by the sum of
    -- all stored values.  Unlike 'pdf', no offset is added here.
    {-# INLINE cdf #-}
    cdf dist label = upTo / total
        where
            upTo  = Map.foldl' (+) 0 $ Map.filterWithKey (\k _ -> k <= label) $ pdfmap dist
            total = Map.foldl' (+) 0 $ pdfmap dist

    -- | Walks the labels in ascending key order, subtracting each label's
    -- 'pdf' from the remaining probability, and returns the first label whose
    -- 'pdf' exceeds what remains.  If the probability is never used up, the
    -- last label is returned.
    --
    -- The previous implementation also inspected the /next/ label's mass via
    -- @head xs@, which crashed on the final element (for example on any
    -- single-label distribution) and compared quantities that are unrelated
    -- for an inverse CDF.  This version is total for non-empty distributions.
    --
    -- Calls 'error' when the distribution has no labels, because there is no
    -- label that could be returned.
    {-# INLINE cdfInverse #-}
    cdfInverse dist prob = go prob pdfL
        where
            pdfL = [ (k, pdf dist k) | k <- Map.keys (pdfmap dist) ]

            go _ []          = error "HLearn.Models.Distributions.Categorical.cdfInverse: empty distribution"
            -- The last label absorbs whatever probability is left over.
            go _ [(k, _)]    = k
            go p ((k, m) : rest)
                | p < m     = k
                | otherwise = go (p - m) rest
--     cdfInverse dist prob = argmax (cdf dist) $ Map.keys $ pdfmap dist

{-    {-# INLINE mean #-}
    mean dist = fst $ argmax snd $ Map.toList $ pdfmap dist

    {-# INLINE drawSample #-}
    drawSample dist = do
        x <- getRandomR (0,1)
        return $ cdfInverse dist (x::prob)
-}

-------------------------------------------------------------------------------
-- Algebra

-- | Combining two distributions adds the values stored under equal labels;
-- labels present in only one operand are kept unchanged.  Both operands are
-- forced to weak head normal form by the bang patterns.
instance (Ord label, Num probtype{-, NFData probtype-}) => Semigroup (Categorical label probtype) where
    (<>) !d1 !d2 = {-deepseq res $-} Categorical $ res
        where
            res = Map.unionWith (+) (pdfmap d1) (pdfmap d2)

-- | The inverse replaces every stored value @x@ with @0 - x@, so that
-- combining a distribution with its inverse yields zero under every label.
instance (Ord label, Num probtype) => RegularSemigroup (Categorical label probtype) where
    inverse d1 = d1 { pdfmap = Map.map (0 -) (pdfmap d1) }

-- | The identity element is the distribution with no labels.
instance (Ord label, Num probtype) => Monoid (Categorical label probtype) where
    mempty  = Categorical Map.empty
    mappend = (<>)

-- | No methods are defined here; the instance relies on the class's defaults.
instance (Ord label, Num probtype) => Group (Categorical label probtype)