module Bayes.EM(
learnEM
) where
import Bayes
import Bayes.Factor
import Bayes.Sampling
import Bayes.Factor.CPT
import Bayes.FactorElimination
import Data.List(foldl1')
import Data.Maybe(fromJust)
-- | Add two per-sample statistic graphs together, vertex by vertex.
--
-- For every vertex of the first graph, the pair stored at the same vertex
-- of the second graph is looked up (failing via 'fromJust' if absent) and
-- both components are combined with 'cptSum'.
sumG :: (FunctorWithVertex g, Graph g)
     => Sample g (CPT,CPT)  -- ^ Accumulated statistics
     -> Sample g (CPT,CPT)  -- ^ Statistics to add
     -> Sample g (CPT,CPT)
sumG ga gb = fmapWithVertex addPair ga
  where
    addPair v (leftFamily, leftParents) =
      let (rightFamily, rightParents) = fromJust (vertexValue gb v)
      in ( cptSum [leftFamily, rightFamily]
         , cptSum [leftParents, rightParents] )
-- | Turn an accumulated @(numerator, denominator)@ pair into a CPT by
-- dividing the first component by the second with 'cptDivide'.
-- The vertex argument is ignored; it exists so the function can be used
-- directly with 'fmapWithVertex'.
divideG :: Vertex
        -> (CPT,CPT)
        -> CPT
divideG _vertex pair =
  case pair of
    (numerator, denominator) -> numerator `cptDivide` denominator
-- | Perform one re-estimation pass over the CPTs of a Bayesian network
-- using a list of observations.
--
-- A junction tree is built once from @startG@. For every observation,
-- 'computeSample' produces a pair of posteriors per vertex. These pairs are
-- summed over all observations with 'sumG', and each vertex's new CPT is
-- obtained by dividing the first summed component by the second
-- ('divideG').
--
-- When @samples@ is empty there is no evidence to learn from, so @startG@
-- is returned unchanged instead of failing on an empty fold.
learnEM :: (FunctorWithVertex g, NamedGraph g, FoldableWithVertex g, DirectedGraph g)
        => [[DVI]]                -- ^ Observations, one evidence list per sample
        -> BayesianNetwork g CPT  -- ^ Network supplying the structure and current CPTs
        -> BayesianNetwork g CPT  -- ^ Network with re-estimated CPTs
learnEM samples startG =
  case map (computeSample startG jt) samples of
    []      -> startG
    -- Strict left fold: the accumulated graph is forced at each step
    -- instead of building a chain of pending 'sumG' applications.
    results -> fmapWithVertex divideG (foldl1' sumG results)
  where
    -- Built once and shared by every sample.
    jt = createJunctionTree nodeComparisonForTriangulation startG
-- | Statistics contributed by a single observation: for every vertex of
-- the network, the pair of posteriors computed by 'computeNode' under the
-- observation's evidence.
computeSample :: (FunctorWithVertex g, Graph g)
              => BayesianNetwork g CPT  -- ^ Network to map over
              -> JunctionTree CPT       -- ^ Junction tree built from that network
              -> [DVI]                  -- ^ Evidence for this observation
              -> Sample g (CPT,CPT)
computeSample startG jt s = fmapWithVertex perVertex startG
  where
    -- A single partially-applied closure reused for every vertex.
    perVertex = computeNode startG jt s
-- | Posterior statistics of one observation for one vertex.
--
-- The vertex's CPT is looked up in the network. Its variable list is split
-- into its first element and the rest (presumably the node and its parents;
-- this depends on the 'factorVariables' ordering). The function returns
-- @(all, rest)@:
--
-- * @all@ is the posterior of every variable of the CPT under the evidence.
-- * @rest@ is the posterior of the remaining variables, or the scalar factor
--   @1.0@ when the CPT has only one variable.
--
-- The evidence is applied to the junction tree /before/ the vertex arguments
-- are taken. As a result, a partial application @computeNode g jt samples@
-- (as in 'computeSample') shares one evidence-updated tree across all
-- vertices instead of re-applying the evidence for each vertex.
--
-- Fails with a descriptive 'error' in three cases: the vertex is missing
-- from the network, its CPT has no variables, or 'posterior' returns
-- 'Nothing'.
computeNode :: Graph g
            => BayesianNetwork g CPT  -- ^ Network whose CPTs define each vertex's variables
            -> JunctionTree CPT       -- ^ Junction tree built from that network
            -> [DVI]                  -- ^ Evidence for this observation
            -> Vertex                 -- ^ Vertex to compute statistics for
            -> CPT                    -- ^ Ignored; the CPT is looked up in the network
            -> (CPT,CPT)
computeNode g jt samples =
  let jt' = changeEvidence samples jt

      -- Posterior over a list of variables in the evidence-updated tree.
      posteriorOf vars =
        case posterior jt' vars of
          Just p  -> p
          Nothing -> error "Bayes.EM.computeNode: posterior returned Nothing"

      nodeStats vertex _ =
        case vertexValue g vertex of
          Nothing  -> error "Bayes.EM.computeNode: vertex not found in network"
          Just cpt ->
            case factorVariables cpt of
              [] -> error "Bayes.EM.computeNode: CPT has no variables"
              family@(_ : parents) ->
                ( posteriorOf family
                , if null parents
                    then factorFromScalar 1.0
                    else posteriorOf parents )
  in nodeStats