{- | Expectation / Maximization (EM) learning of the conditional probability
tables (CPT) of a Bayesian network from samples

-}
module Bayes.EM( 
    -- * Learning function
      learnEM
    ) where 

import Bayes
import Bayes.Factor 
import Bayes.Sampling 
import Bayes.Factor.CPT
import Bayes.FactorElimination
import Data.List(foldl')
import Data.Maybe(fromJust)


-- | Add two graphs of factor pairs vertex by vertex.
--
-- The result has the shape of the first graph. At each of its vertices the
-- pair is added component-wise (with 'cptSum') to the pair stored at the same
-- vertex in the second graph. The lookup in the second graph uses 'fromJust',
-- so a vertex missing from it is a runtime error.
sumG :: (FunctorWithVertex g, Graph g) 
     => Sample g (CPT,CPT) 
     -> Sample g (CPT,CPT) 
     -> Sample g (CPT,CPT)
sumG ga gb = fmapWithVertex addAtVertex ga 
  where 
    addAtVertex v (leftFst, leftSnd) = 
        (cptSum [leftFst, rightFst], cptSum [leftSnd, rightSnd])
      where 
        -- Pair held by the second graph for the same vertex
        (rightFst, rightSnd) = fromJust (vertexValue gb v)

-- | Divide the first factor of a pair by the second one using 'cptDivide'.
--
-- The vertex is not used; the argument only gives the function the shape
-- expected by 'fmapWithVertex'.
divideG :: Vertex 
        -> (CPT,CPT) 
        -> CPT
divideG _vertex (dividend, divisor) = dividend `cptDivide` divisor 

-- | Learn network values from samples using the expectation / maximization algorithm.
--
-- This function performs one pass and does not iterate: each sample is turned
-- into a graph of factor pairs using the start network and its junction tree
-- ('computeSample'), the graphs are added vertex by vertex ('sumG') and each
-- resulting pair is divided ('divideG') to give the new value of the vertex.
--
-- When the list of samples is empty there is nothing to learn from and the
-- start network is returned unchanged.
learnEM :: (FunctorWithVertex g, NamedGraph g, FoldableWithVertex g, DirectedGraph g) 
        => [[DVI]] -- ^ Samples
        -> BayesianNetwork g CPT -- ^ Start network
        -> BayesianNetwork g CPT -- ^ Network with new values learnt from the samples
learnEM samples startG = 
    case map (computeSample startG jt) samples of 
      -- No sample: avoid the partial fold and keep the start values
      [] -> startG 
      -- Strict left fold so that the accumulated graph is not a chain of thunks
      (firstResult:otherResults) -> 
          fmapWithVertex divideG (foldl' sumG firstResult otherResults)
  where 
    -- Built once from the start network and shared by all the samples
    jt = createJunctionTree nodeComparisonForTriangulation startG 

-- | Build, for one sample, the graph of factor pairs: every vertex of the
-- start network is replaced by the pair computed by 'computeNode' for that
-- vertex, using the given junction tree and the sample.
computeSample :: (FunctorWithVertex g, Graph g) 
              => BayesianNetwork g CPT 
              -> JunctionTree CPT 
              -> [DVI] 
              -> Sample g (CPT,CPT)
computeSample startG jt s = fmapWithVertex pairForVertex startG
  where 
    -- Per-vertex function specialised to this network, tree and sample
    pairForVertex = computeNode startG jt s

-- | Compute, for one vertex, the pair of factors used by the learning pass.
--
-- The sample is installed as evidence in the junction tree with
-- 'changeEvidence'. The first component of the result is the posterior over
-- all the variables of the CPT stored for the vertex in the network; the
-- second is the posterior over the same variables without the first one, or
-- the scalar factor 1.0 when no other variable remains.
-- (The first variable is presumably the variable of the node itself and the
-- remaining ones its parents -- TODO confirm against 'factorVariables'.)
--
-- The CPT received as last argument is ignored: the CPT is looked up in the
-- network by vertex. The lookups rely on 'fromJust', so a vertex without
-- value or a posterior that cannot be computed is a runtime error.
computeNode :: Graph g 
            => BayesianNetwork g CPT 
            -> JunctionTree CPT 
            -> [DVI] 
            -> Vertex 
            -> CPT 
            -> (CPT,CPT)
computeNode g jt samples = 
    -- The tree updated with the evidence depends only on the sample, so it is
    -- bound before the per-vertex function: a partial application to the first
    -- three arguments shares it between all the vertices instead of updating
    -- the evidence again for each of them.
    let jt' = changeEvidence samples jt
        nodeFactors vertex _ = 
            case factorVariables (fromJust . vertexValue g $ vertex) of 
              [] -> error "Bayes.EM.computeNode: the CPT of a vertex has no variable"
              family@(_:others) -> 
                  let a = fromJust $ posterior jt' family 
                      b = if null others 
                            then factorFromScalar 1.0 
                            else fromJust $ posterior jt' others 
                  in 
                  (a,b)
    in 
    nodeFactors