-- | Bayesian classification is one of the standard algorithms in machine learning.  Typically, we make the naive Bayes assumption that none of our attributes are correlated.  The Bayes data type, however, supports both naive and non-naive assumptions.

module HLearn.Models.Classifiers.Bayes
    ( Bayes
    )
    where

import Debug.Trace
import qualified Data.Map as Map
import GHC.TypeLits

import HLearn.Algebra
import HLearn.Models.Distributions
import HLearn.Models.Classifiers.Common

-------------------------------------------------------------------------------
-- data types

-- | A Bayesian classifier that is a thin wrapper around an underlying
-- distribution of type @dist@.
--
-- The first type parameter is phantom: it does not appear in the stored value.
-- The classification instances in this module use it as a proxy argument
-- (@undefined::labelLens@) when calling @getMargin@ and @condition@.
--
-- 'Monoid', @Abelian@ and @Group@ are not stock-derivable classes, so they are
-- presumably lifted from @dist@ through newtype deriving.
newtype Bayes labelLens dist = Bayes dist
    deriving
        ( Read, Show, Eq, Ord
        , Monoid, Abelian, Group
        )

-------------------------------------------------------------------------------
-- Training

-- | Training a 'Bayes' model delegates entirely to the wrapped distribution:
-- the data point is given to the @dist@ trainer and the result is wrapped in
-- the 'Bayes' constructor.
instance (Monoid dist, HomTrainer dist) => HomTrainer (Bayes labelLens dist) where
    -- A Bayes model consumes exactly the data points its distribution does.
    type Datapoint (Bayes labelLens dist) = Datapoint dist

    train1dp = Bayes . train1dp

-------------------------------------------------------------------------------
-- Classification

-- | The probability type of a 'Bayes' model is taken unchanged from the
-- wrapped distribution.
instance Probabilistic (Bayes labelLens dist) where
    type Probability (Bayes labelLens dist) = Probability dist

-- | Soft classification: returns a 'Categorical' that assigns a weight to
-- every label already present in the label margin of the trained distribution.
--
-- For a data point @dp@ and each such label @l@ the weight is
--
-- > pdf (label margin) l * pdf (dist conditioned on l) dp
--
-- NOTE(review): this looks like prior times likelihood, i.e. the unnormalised
-- numerator of Bayes' rule.  Whether the resulting 'Categorical' is normalised
-- depends on the @Categorical@ / @PDF@ implementations -- confirm there.
instance
    ( Margin labelLens dist ~ Categorical label prob
    , Ord label, Ord prob, Fractional prob
    , label ~ Label (Datapoint dist)
    , prob ~ Probability (MarginalizeOut labelLens dist)
    , Labeled (Datapoint dist)
    , Datapoint (MarginalizeOut labelLens dist) ~ Attributes (Datapoint dist)
    , PDF (MarginalizeOut labelLens dist)
    , PDF (Margin labelLens dist)
    , Marginalize labelLens dist
    ) => ProbabilityClassifier (Bayes labelLens dist)
        where
    -- The result of classification has the same type as the label margin.
    type ResultDistribution (Bayes labelLens dist) = Margin labelLens dist

    -- Re-weight every entry of the label margin in place; the old weight is
    -- ignored because 'score' recomputes it through 'pdf'.
    probabilityClassify (Bayes dist) dp = Categorical $ Map.mapWithKey score priorWeights
        where
            score lbl _ = pdf prior lbl * pdf (likelihood lbl) dp

            -- The phantom @labelLens@ is only a proxy, hence @undefined@.
            prior = getMargin (undefined::labelLens) dist
            likelihood lbl = condition (undefined::labelLens) lbl dist

            -- Unwrap the margin to get at the map keyed by label.
            Categorical priorWeights = prior

-- | Hard classification built on top of 'probabilityClassify': the
-- distribution over labels is computed first and then collapsed to a single
-- value with 'mean'.
--
-- NOTE(review): for a categorical result, 'mean' presumably selects the most
-- probable label -- confirm against the @Mean@ instance of the margin type.
instance
    ( ProbabilityClassifier (Bayes labelLens dist)
    , Label (Datapoint (Bayes labelLens dist)) ~ Datapoint (Margin labelLens dist)
    , Mean (Margin labelLens dist)
    ) => Classifier (Bayes labelLens dist)
        where
    classify model = mean . probabilityClassify model