-- Bayes classifier model: a newtype wrapper around a distribution type,
-- together with its training and classification instances.
--
-- Only the 'Bayes' type appears in the export list; its data constructor is
-- not exported from this module.
module HLearn.Models.Classifiers.Bayes
( Bayes
)
where
import Debug.Trace
import qualified Data.Map as Map
import GHC.TypeLits
import HLearn.Algebra
import HLearn.Models.Distributions
import HLearn.Models.Classifiers.Common
-- | A Bayes classifier, represented as a wrapper around a distribution @dist@.
--
-- The first type parameter is phantom: it does not occur on the right-hand
-- side of the definition. The classifier instances in this module call it
-- @labelLens@ and hand a value of that type to 'getMargin' and 'condition'
-- when splitting @dist@ into a label part and an attribute part.
--
-- NOTE(review): 'Monoid', 'Abelian' and 'Group' are derived on a newtype, so
-- they presumably reuse the instances of @dist@ via newtype deriving --
-- confirm against the enabled language extensions.
newtype Bayes labelLens dist = Bayes dist
    deriving
        ( Read
        , Show
        , Eq
        , Ord
        , Monoid
        , Abelian
        , Group
        )
-- | Training is delegated to the wrapped distribution.
--
-- A 'Bayes' model consumes exactly the datapoints that @dist@ consumes, and
-- training on a single datapoint trains @dist@ on it and wraps the result in
-- the 'Bayes' constructor.
instance (Monoid dist, HomTrainer dist) => HomTrainer (Bayes labelLens dist) where
    type Datapoint (Bayes labelLens dist) = Datapoint dist

    train1dp = Bayes . train1dp
-- | The probability type of a 'Bayes' model is defined to be the probability
-- type of the distribution it wraps.
instance Probabilistic (Bayes labelLens dist) where
    type Probability (Bayes labelLens dist) = Probability dist
-- | Per-label scoring of a datapoint.
--
-- The result is a 'Categorical' with one entry for every label that is a key
-- of the label margin of the wrapped distribution. The value stored for a
-- label @l@ is the product of
--
--   * the 'pdf' of the label margin at @l@, and
--
--   * the 'pdf' of the datapoint under the distribution obtained by
--     conditioning the wrapped distribution on @l@.
--
-- The products are stored exactly as computed; this function applies no
-- further rescaling to them.
instance
    ( Margin labelLens dist ~ Categorical label prob
    , Ord label, Ord prob, Fractional prob
    , label ~ Label (Datapoint dist)
    , prob ~ Probability (MarginalizeOut labelLens dist)
    , Labeled (Datapoint dist)
    , Datapoint (MarginalizeOut labelLens dist) ~ Attributes (Datapoint dist)
    , PDF (MarginalizeOut labelLens dist)
    , PDF (Margin labelLens dist)
    , Marginalize labelLens dist
    ) => ProbabilityClassifier (Bayes labelLens dist)
  where
    type ResultDistribution (Bayes labelLens dist) = Margin labelLens dist

    probabilityClassify (Bayes dist) dp = Categorical (Map.mapWithKey score seenLabels)
      where
        -- Stand-in value of the lens type; it is 'undefined' and exists only
        -- so that 'getMargin' and 'condition' are called at type @labelLens@.
        lens = undefined :: labelLens

        -- Margin of the wrapped distribution selected by the lens, and the
        -- map it carries (its keys are the labels to score).
        labelMargin = getMargin lens dist
        Categorical seenLabels = labelMargin

        -- Score for one label. The value already stored in the map is
        -- ignored; the label weight is re-read through 'pdf'.
        score l _ = pdf labelMargin l * pdf (condition lens l dist) dp
-- | Classification to a single label.
--
-- The datapoint is first scored with 'probabilityClassify'; the answer is
-- whatever 'mean' returns for the resulting distribution.
instance
    ( ProbabilityClassifier (Bayes labelLens dist)
    , Label (Datapoint (Bayes labelLens dist)) ~ Datapoint (Margin labelLens dist)
    , Mean (Margin labelLens dist)
    ) => Classifier (Bayes labelLens dist)
  where
    classify model = mean . probabilityClassify model