module HLearn.Models.Distributions.Multivariate.Internal.CatContainer
where
import Control.DeepSeq
import Control.Monad.Random
import Data.List
import Data.List.Extras
import Debug.Trace
import qualified Data.Map.Strict as Map
import qualified Data.Foldable as F
import HLearn.Algebra
import HLearn.Models.Distributions.Common
import HLearn.Models.Distributions.Multivariate.Internal.Ignore
import HLearn.Models.Distributions.Multivariate.Internal.Unital
import HLearn.Models.Distributions.Multivariate.Internal.Marginalization
import HLearn.Models.Distributions.Univariate.Categorical
-- | A distribution over datapoints whose first field is a categorical
-- label.  For every label seen in training the container keeps a separate
-- base distribution over the remaining fields, together with how much
-- weight that label has received.
data CatContainer label basedist prob = CatContainer
    { pdfmap   :: !(Map.Map label basedist)
      -- ^ per label: the base distribution trained on the remaining fields
      -- of the datapoints carrying that label
    , probmap  :: !(Map.Map label prob)
      -- ^ per label: accumulated weight (1 per trained datapoint, summed by
      -- 'mappend', scaled by '.*')
    , catnumdp :: !prob
      -- ^ total weight over all labels.  Strict, like the other fields, so
      -- that long chains of 'mappend' do not build up a thunk of additions.
    }
    deriving (Show,Read,Eq,Ord)
-- | Deep evaluation of a container.  Forces both label-indexed maps
-- (keys and values) and the total weight.  Previously only 'pdfmap' was
-- forced, which left 'probmap' and 'catnumdp' possibly unevaluated after
-- 'deepseq'.
instance (NFData label, NFData prob, NFData basedist) =>
    NFData (CatContainer label basedist prob)
    where
    rnf d = rnf (pdfmap d) `seq` rnf (probmap d) `seq` rnf (catnumdp d)
-- | Marker instance declaring that 'mappend' on containers commutes.
-- NOTE(review): this relies on the base distribution's '<>' being
-- commutative as well; the context only requires @Monoid basedist@.
instance
    ( Ord label
    , Num prob
    , Monoid basedist
    ) => Abelian (CatContainer label basedist prob)
-- | The empty container has no labels and zero total weight.  Combining
-- two containers merges the per-label base distributions with '<>' and
-- adds the per-label weights and the totals.
instance (Ord label, Num prob, Monoid basedist) => Monoid (CatContainer label basedist prob) where
    mempty = CatContainer
        { pdfmap   = Map.empty
        , probmap  = Map.empty
        , catnumdp = 0
        }
    mappend (CatContainer pdfsA weightsA totalA) (CatContainer pdfsB weightsB totalB) =
        CatContainer
            (Map.unionWith (<>) pdfsA pdfsB)
            (Map.unionWith (+) weightsA weightsB)
            (totalA + totalB)
-- | Inverse of a container: invert every base distribution and negate every
-- per-label weight and the total, so that @d <> inverse d@ has total 0 (as
-- 'mempty' does).  The total was previously copied unchanged, which made
-- "untraining" a datapoint increase the datapoint count instead of
-- decreasing it.
instance (Ord label, Num prob, Group basedist) => Group (CatContainer label basedist prob) where
    inverse d1 = CatContainer
        { pdfmap   = Map.map inverse (pdfmap d1)
        , probmap  = Map.map negate (probmap d1)
        , catnumdp = negate (catnumdp d1)
        }
-- | The scalar type used by the 'Module' instance is the weight type @prob@.
instance Num prob => HasRing (CatContainer label basedist prob) where
    type Ring (CatContainer label basedist prob) = prob
-- | Scaling by @r@ scales every base distribution with '.*' and multiplies
-- every per-label weight and the total by @r@.
instance
    ( Ord label
    , Num prob
    , Module basedist
    , Ring basedist ~ Ring (CatContainer label basedist prob)
    ) => Module (CatContainer label basedist prob)
    where
    r .* (CatContainer pdfs weights total) = CatContainer
        { pdfmap   = Map.map (r .*) pdfs
        , probmap  = Map.map (r *) weights
        , catnumdp = r * total
        }
-- | Training on a single datapoint @label ::: rest@ yields a container with
-- exactly one label, weight 1 and total 1.  The label's base distribution
-- is trained on the remaining fields @rest@.
instance
    ( Ord label
    , Num prob
    , HomTrainer basedist
    , Datapoint basedist ~ HList ys
    ) => HomTrainer (CatContainer label basedist prob)
    where
    type Datapoint (CatContainer label basedist prob) = label `HCons` (Datapoint basedist)
    train1dp (label:::rest) = CatContainer
        { pdfmap   = Map.singleton label conditional
        , probmap  = Map.singleton label 1
        , catnumdp = 1
        }
        where
            conditional = train1dp rest
-- | The datapoint count is the stored total weight.
instance Num prob => NumDP (CatContainer label basedist prob) where
    numdp = catnumdp
-- | Probabilities computed from a container share its weight type @prob@.
instance Probabilistic (CatContainer label basedist prob)
    where
    type Probability (CatContainer label basedist prob) = prob
-- | Density of @label ::: rest@.  It is the density the label's base
-- distribution assigns to @rest@, times the label's weight, divided by the
-- total weight.
--
-- A label with no recorded weight contributes weight 0.  A label with no
-- base distribution emits a 'trace' warning listing the known labels and
-- uses 0 as its density.
instance
    ( Ord prob, Fractional prob, Show prob, Probability basedist ~ prob
    , Ord label
    , PDF basedist
    , Datapoint basedist ~ HList ys
    , Show (Datapoint basedist)
    , Show label
    ) => PDF (CatContainer label basedist prob)
    where
    pdf dist (label:::rest) = conditional * labelWeight / catnumdp dist
        where
            labelWeight = Map.findWithDefault 0 label (probmap dist)
            conditional = maybe unseen (`pdf` rest) (Map.lookup label (pdfmap dist))
            unseen = trace ("Warning.CatContainer: label "++show label++" not found in training data: "++show (Map.keys $ pdfmap dist)) 0
-- | Marginalization with respect to the outermost variable, the label.
--
-- * 'getMargin'' is the label marginal: a 'Categorical' over the per-label
--   weights.
-- * 'marginalizeOut'' drops the label by combining all per-label base
--   distributions with 'reduce'.
-- * 'condition'' fixes the label to @dp@ and keeps only that label's base
--   distribution.  It calls 'error' if @dp@ has no base distribution.
instance
    ( NumDP basedist
    , Ring basedist ~ prob
    , Monoid basedist
    , HCons label (Datapoint basedist) ~ HList (label ': ts)
    , Ord label
    ) => Marginalize' (Nat1Box Zero) (CatContainer label basedist prob)
    where
    type Margin' (Nat1Box Zero) (CatContainer label basedist prob) = (Categorical prob label)
    getMargin' _ = Categorical . probmap

    type MarginalizeOut' (Nat1Box Zero) (CatContainer label basedist prob) = Ignore' label basedist prob
    marginalizeOut' _ = Ignore' . reduce . Map.elems . pdfmap

    condition' _ dist dp =
        Ignore' $ Map.findWithDefault (error "CatContainer.condition: Nothing") dp (pdfmap dist)
-- | Marginalization with respect to an inner variable (index @Succ n@).
-- Index @n@ is delegated to each per-label base distribution.
--
-- * 'getMargin'' scales each label's base distribution by that label's
--   weight, combines them with 'reduce', and takes margin @n@ of the result.
-- * 'marginalizeOut'' removes variable @n@ from every base distribution.
--   The weights and the total are left unchanged.
-- * 'condition'' fixes variable @n@ to @dp@ in every base distribution.  It
--   multiplies each label's weight by the density that label's margin @n@
--   assigns to @dp@.  The total 'catnumdp' is not changed.
instance
    ( Marginalize' (Nat1Box n) basedist
    , Monoid basedist
    , PDF (Margin' (Nat1Box n) basedist)
    , prob ~ Probability (Margin' (Nat1Box n) basedist)
    , prob ~ Ring basedist
    , Module basedist
    , Ord label
    , Num prob
    ) => Marginalize' (Nat1Box (Succ n)) (CatContainer label basedist prob)
    where
    type Margin' (Nat1Box (Succ n)) (CatContainer label basedist prob) = Margin' (Nat1Box n) basedist
    -- Weights and base distributions are paired by label rather than by list
    -- position.  With 'zipWith' over two separate 'Map.elems' lists, a key
    -- present in only one map would silently shift the pairing of every later
    -- label.  When the key sets agree, the result is identical.
    getMargin' _ dist = getMargin' (undefined :: Nat1Box n) $ reduce $ Map.elems $
        Map.intersectionWith (.*) (probmap dist) (pdfmap dist)

    type MarginalizeOut' (Nat1Box (Succ n)) (CatContainer label basedist prob) =
        CatContainer label (MarginalizeOut' (Nat1Box n) basedist) prob
    marginalizeOut' _ dist = dist { pdfmap = fmap (marginalizeOut' (undefined :: Nat1Box n)) $ pdfmap dist }

    condition' _ dist dp = dist
        { probmap = Map.unionWith (*) (probmap dist) conditionmap
        , pdfmap  = fmap (flip (condition' (undefined :: Nat1Box n)) dp) $ pdfmap dist
        }
        where
            -- per label: density of dp under that label's margin n
            conditionmap = fmap (\base -> pdf (getMargin' (undefined :: Nat1Box n) base) dp) $ pdfmap dist
-- | Small example training set of three @String ::: Char ::: String@
-- datapoints.  The explicit signature documents the datapoint shape that
-- 'test' trains on and silences -Wmissing-signatures.
ds :: [HList '[String, Char, String]]
ds =
    [ "test":::'g':::"foo":::HNil
    , "test":::'f':::"fok":::HNil
    , "toot":::'f':::"foo":::HNil
    ]
-- | Example three-level container (String, then Char, then String) trained
-- on 'ds'.
test :: CatContainer String (CatContainer Char (CatContainer String (Unital Double) Double) Double) Double
test = train ds