{- | Expectation / Maximization to learn Bayesian network values -}
module Bayes.EM(
    -- * Learning function
    learnEM
    ) where

import Bayes
import Bayes.Factor
import Bayes.Sampling
import Bayes.Factor.CPT
import Bayes.FactorElimination
import Data.List (foldl')
import Data.Maybe(fromJust)

-- | Unwrap a 'Maybe', failing with a message that names the call site
-- instead of the anonymous error produced by 'fromJust'.
emFromJust :: String -> Maybe a -> a
emFromJust msg = maybe (error ("Bayes.EM." ++ msg)) id

-- | Add two per-sample results vertex by vertex.
--
-- For every vertex of the first graph, the pair stored at the same vertex in
-- the second graph is looked up and each component is combined with
-- 'cptSum'. Both graphs come from 'computeSample' applied to the same start
-- network, so the lookup is expected to always succeed.
sumG :: (FunctorWithVertex g, Graph g)
     => Sample g (CPT,CPT)
     -> Sample g (CPT,CPT)
     -> Sample g (CPT,CPT)
sumG ga gb = fmapWithVertex sumNode ga
  where
    sumNode vertex (xa, xb) =
        let (ya, yb) = emFromJust "sumG: vertex missing from second sample graph"
                                  (vertexValue gb vertex)
        in (cptSum [xa, ya], cptSum [xb, yb])

-- | Turn an accumulated pair into the new CPT by dividing the first
-- component by the second with 'cptDivide'. The vertex is ignored.
divideG :: Vertex -> (CPT,CPT) -> CPT
divideG _ (a,b) = cptDivide a b

-- | Learn network values from samples using the expectation / maximization algorithm.
--
-- A junction tree is built once from the start network. Each sample is
-- turned into per-node posterior pairs ('computeSample'), the pairs are
-- summed over all samples with a strict left fold ('sumG'), and every node's
-- new CPT is the quotient of the two sums ('divideG').
--
-- An empty sample list carries no information, so the start network is
-- returned unchanged (previously this crashed in 'foldl1').
learnEM :: (FunctorWithVertex g, NamedGraph g, FoldableWithVertex g, DirectedGraph g)
        => [[DVI]] -- ^ Samples
        -> BayesianNetwork g CPT -- ^ Start network
        -> BayesianNetwork g CPT -- ^ Network with new values learnt from the samples
learnEM samples startG =
    case map (computeSample startG jt) samples of
        []       -> startG
        (r : rs) -> fmapWithVertex divideG (foldl' sumG r rs)
  where
    jt = createJunctionTree nodeComparisonForTriangulation startG

-- | Posterior statistics of every node of the network for one sample.
--
-- The sample is entered into the junction tree once, here, so that every
-- node's posterior query below runs against the same updated tree instead of
-- re-entering the evidence per vertex.
computeSample :: (FunctorWithVertex g, Graph g) => BayesianNetwork g CPT -> JunctionTree CPT -> [DVI] -> Sample g (CPT,CPT)
computeSample startG jt s = fmapWithVertex (computeNode jt') startG
  where
    jt' = changeEvidence s jt

-- | Posterior statistics for one node given the evidence already entered in
-- the junction tree.
--
-- Returns @(a, b)@ where @a@ is the posterior over all variables of the
-- node's CPT and @b@ is the posterior over those variables minus the first
-- one (presumably the node's parents — the first variable is taken to be the
-- node itself). When the CPT has a single variable, @b@ is the unit factor
-- @1.0@.
computeNode :: JunctionTree CPT -- ^ Junction tree with the sample's evidence entered
            -> Vertex           -- ^ Vertex being processed (not needed here)
            -> CPT              -- ^ CPT stored at that vertex (supplied by 'fmapWithVertex')
            -> (CPT,CPT)
computeNode jt' _ cpt =
    case factorVariables cpt of
        []                 -> error "Bayes.EM.computeNode: CPT has no variables"
        allVars@(_ : rest) -> (posteriorOf allVars, restPosterior rest)
  where
    posteriorOf vars =
        emFromJust "computeNode: posterior not available for CPT variables"
                   (posterior jt' vars)
    restPosterior []   = factorFromScalar 1.0
    restPosterior vars = posteriorOf vars